Files
drl_2/xuexiao/utils.py
user9994793890 8a76174c4d fix: 修复 Flask 新版本 is_xhr 属性错误
Coze-Commit-Type: user
Coze-User-ID: 3722323274763196
Coze-Conversation-ID: 5260473
2026-05-29 12:23:16 +08:00

888 lines
32 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""工具函数模块 - 权限体系、装饰器、导出、初始化"""
import io
from datetime import datetime, date
from functools import wraps
from flask import session, flash, redirect, url_for, request, send_file, jsonify
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment, Border, Side, PatternFill
from models import db, User, Role, Student, Course, Teacher, Class_, \
StudentAccount, RechargeActivity, RechargeRecord, ConsumptionRecord, \
RefundRecord, TransferRecord, StopRecord, AttendanceRecord, OperationLog, Schedule
# ========== 权限体系定义 ==========
PERMISSIONS = {
# 学员管理
'student_view': '学员管理-查看',
'student_add': '学员管理-新增',
'student_edit': '学员管理-编辑',
'student_delete': '学员管理-删除',
'student_export': '学员管理-导出',
'student_account_adjust': '学员管理-手动调课时',
# 课程管理
'course_view': '课程管理-查看',
'course_add': '课程管理-新增',
'course_edit': '课程管理-编辑',
'course_delete': '课程管理-删除',
'course_export': '课程管理-导出',
# 老师管理
'teacher_view': '老师管理-查看',
'teacher_add': '老师管理-新增',
'teacher_edit': '老师管理-编辑',
'teacher_delete': '老师管理-删除',
'teacher_export': '老师管理-导出',
# 班级管理
'class_view': '班级管理-查看',
'class_add': '班级管理-新增',
'class_edit': '班级管理-编辑',
'class_delete': '班级管理-删除',
'class_export': '班级管理-导出',
# 充值优惠
'recharge_view': '充值管理-查看',
'recharge_add': '充值管理-新增',
'recharge_export': '充值管理-导出',
'recharge_activity_view': '优惠活动-查看',
'recharge_activity_add': '优惠活动-新增',
'recharge_activity_edit': '优惠活动-编辑',
# 消课管理
'consumption_view': '消课管理-查看',
'consumption_add': '消课管理-新增',
'consumption_export': '消课管理-导出',
# 退费管理
'refund_view': '退费管理-查看',
'refund_add': '退费管理-新增',
'refund_export': '退费管理-导出',
# 转课/转赠
'transfer_view': '转课转赠-查看',
'transfer_add': '转课转赠-新增',
'transfer_export': '转课转赠-导出',
# 停课保号
'stop_view': '停课保号-查看',
'stop_add': '停课保号-新增',
'stop_export': '停课保号-导出',
# 考勤管理
'attendance_view': '考勤管理-查看',
'attendance_add': '考勤管理-录入',
'attendance_export': '考勤管理-导出',
# 统计报表
'statistics_view': '统计报表-查看',
'statistics_export': '统计报表-导出',
# 课程安排
'schedule_view': '课程安排-查看',
'schedule_add': '课程安排-新增',
'schedule_export': '课程安排-导出',
# 系统管理
'user_view': '用户管理-查看',
'user_add': '用户管理-新增',
'user_edit': '用户管理-编辑',
'role_view': '角色管理-查看',
'role_add': '角色管理-新增',
'role_edit': '角色管理-编辑',
'role_delete': '角色管理-删除',
'log_view': '操作日志-查看',
'log_export': '操作日志-导出',
# 课时核对表
'keshibiao_view': '课时核对表-查看',
'keshibiao_import': '课时核对表-导入',
'keshibiao_export': '课时核对表-导出',
}
# 按模块分组的权限(用于角色编辑界面)
PERMISSION_GROUPS = [
('学员管理', [
'student_view', 'student_add', 'student_edit', 'student_delete',
'student_export', 'student_account_adjust',
]),
('课程管理', ['course_view', 'course_add', 'course_edit', 'course_delete', 'course_export']),
('老师管理', ['teacher_view', 'teacher_add', 'teacher_edit', 'teacher_delete', 'teacher_export']),
('班级管理', ['class_view', 'class_add', 'class_edit', 'class_delete', 'class_export']),
('充值管理', ['recharge_view', 'recharge_add', 'recharge_export']),
('优惠活动', ['recharge_activity_view', 'recharge_activity_add', 'recharge_activity_edit']),
('消课管理', ['consumption_view', 'consumption_add', 'consumption_export']),
('退费管理', ['refund_view', 'refund_add', 'refund_export']),
('转课转赠', ['transfer_view', 'transfer_add', 'transfer_export']),
('停课保号', ['stop_view', 'stop_add', 'stop_export']),
('考勤管理', ['attendance_view', 'attendance_add', 'attendance_export']),
('统计报表', ['statistics_view', 'statistics_export']),
('课程安排', ['schedule_view', 'schedule_add', 'schedule_export']),
('用户管理', ['user_view', 'user_add', 'user_edit']),
('角色管理', ['role_view', 'role_add', 'role_edit', 'role_delete']),
('操作日志', ['log_view', 'log_export']),
('课时核对表', ['keshibiao_view', 'keshibiao_import', 'keshibiao_export']),
]
# 导航菜单与权限的映射(侧边栏过滤)
MENU_PERMISSIONS = {
'student': 'student_view',
'course': 'course_view',
'teacher': 'teacher_view',
'class': 'class_view',
'schedule': 'schedule_view',
'attendance': 'attendance_view',
'recharge_activity': 'recharge_activity_view',
'recharge': 'recharge_view',
'consumption': 'consumption_view',
'refund': 'refund_view',
'transfer': 'transfer_view',
'stop': 'stop_view',
'statistics': 'statistics_view',
'user': 'user_view',
'role': 'role_view',
'log': 'log_view',
}
# ========== 权限检查工具 ==========
def get_current_user():
"""获取当前登录用户对象"""
user_id = session.get('user_id')
if not user_id:
return None
return db.session.get(User, user_id)
def get_current_permissions():
"""获取当前用户的所有权限列表"""
user = get_current_user()
if not user or not user.role:
return []
if user.role.name == '超级管理员':
return list(PERMISSIONS.keys())
perms = user.role.permissions or ''
return [p.strip() for p in perms.split(',') if p.strip() in PERMISSIONS]
def has_permission(perm):
"""检查当前用户是否拥有某权限"""
return perm in get_current_permissions()
def is_admin():
"""检查当前用户是否为超级管理员"""
user = get_current_user()
return user and user.role and user.role.name == '超级管理员'
def is_teacher_role():
"""是否授课老师角色(需数据范围过滤)"""
user = get_current_user()
return bool(user and user.role and user.role.name == '授课老师')
def get_teacher_scope():
"""
授课老师数据范围:返回 dict(teacher_id, class_ids, student_ids) 或 None不限制
"""
if not is_teacher_role():
return None
from models import Teacher, Class_, ClassStudent
user = get_current_user()
teacher = Teacher.query.filter_by(user_id=user.id).first()
if not teacher:
return {'teacher_id': None, 'class_ids': [], 'student_ids': []}
class_ids = [
c.id for c in Class_.query.filter_by(teacher_id=teacher.id, status=1).all()
]
student_ids = []
if class_ids:
rows = ClassStudent.query.filter(
ClassStudent.class_id.in_(class_ids),
ClassStudent.status == 1,
).all()
student_ids = list({r.student_id for r in rows})
return {
'teacher_id': teacher.id,
'class_ids': class_ids,
'student_ids': student_ids,
}
def teacher_can_access_class(class_id):
scope = get_teacher_scope()
if scope is None:
return True
return int(class_id) in scope['class_ids']
def teacher_can_access_student(student_id):
scope = get_teacher_scope()
if scope is None:
return True
return int(student_id) in scope['student_ids']
def filter_students_query(query):
"""授课老师仅看本班学员"""
scope = get_teacher_scope()
if scope is None:
return query
if not scope['student_ids']:
return query.filter(False)
from models import Student
return query.filter(Student.id.in_(scope['student_ids']))
def filter_classes_query(query):
scope = get_teacher_scope()
if scope is None:
return query
if not scope['class_ids']:
return query.filter(False)
from models import Class_
return query.filter(Class_.id.in_(scope['class_ids']))
def login_required(f):
"""登录检查装饰器"""
from functools import wraps
@wraps(f)
def decorated(*args, **kwargs):
if not session.get('user_id'):
from flask import request, jsonify
if request.headers.get('X-Requested-With') == 'XMLHttpRequest' or request.headers.get('Content-Type') == 'application/json':
return jsonify({'error': '请先登录'}), 401
flash('请先登录', 'warning')
return redirect(url_for('login'))
return f(*args, **kwargs)
return decorated
def parse_date(date_str):
"""解析日期字符串"""
if date_str:
try:
from datetime import datetime as dt
return dt.strptime(date_str, '%Y-%m-%d').date()
except (ValueError, TypeError):
return None
return None
def log_operation(action, detail='', before=None, after=None, *, commit=True):
"""记录操作日志(可选 before/after 快照commit=False 时与业务同事务提交)"""
import json
from models import OperationLog
if before is not None or after is not None:
payload = {'message': detail}
if before is not None:
payload['before'] = before
if after is not None:
payload['after'] = after
detail = json.dumps(payload, ensure_ascii=False)
user_id = session.get('user_id')
if not user_id:
return
log = OperationLog(
user_id=user_id,
operation_type=action,
operation_detail=detail,
ip_address=request.remote_addr if request else '',
)
db.session.add(log)
if commit:
db.session.commit()
def permission_required(perm):
"""权限检查装饰器"""
def decorator(f):
@wraps(f)
def decorated(*args, **kwargs):
if not session.get('user_id'):
flash('请先登录', 'warning')
return redirect(url_for('login'))
if is_admin():
return f(*args, **kwargs)
if not has_permission(perm):
flash('您没有该操作的权限', 'danger')
return redirect(url_for('dashboard'))
return f(*args, **kwargs)
return decorated
return decorator
def export_permission_required(module_name):
"""导出权限检查装饰器"""
perm = f'{module_name}_export'
return permission_required(perm)
# ========== Excel 导出工具 ==========
def export_excel(title, headers, rows, filename=None):
"""
生成Excel文件并返回下载响应
Args:
title: 工作表标题(第一行合并标题)
headers: 表头列表 ['姓名', '电话', ...]
rows: 数据行列表 [[val1, val2, ...], ...]
filename: 下载文件名(不含扩展名), 默认用title+日期
"""
wb = Workbook()
ws = wb.active
ws.title = title[:31] # Excel工作表名最长31字符
# 样式定义
title_font = Font(name='微软雅黑', size=14, bold=True)
header_font = Font(name='微软雅黑', size=10, bold=True, color='FFFFFF')
header_fill = PatternFill(start_color='4472C4', end_color='4472C4', fill_type='solid')
cell_font = Font(name='微软雅黑', size=10)
center_align = Alignment(horizontal='center', vertical='center', wrap_text=True)
thin_border = Border(
left=Side(style='thin'),
right=Side(style='thin'),
top=Side(style='thin'),
bottom=Side(style='thin')
)
# 标题行
ws.merge_cells(start_row=1, start_column=1, end_row=1, end_column=len(headers))
title_cell = ws.cell(row=1, column=1, value=title)
title_cell.font = title_font
title_cell.alignment = Alignment(horizontal='center', vertical='center')
ws.row_dimensions[1].height = 36
# 表头行
for col, header in enumerate(headers, 1):
cell = ws.cell(row=2, column=col, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = center_align
cell.border = thin_border
ws.row_dimensions[2].height = 24
# 数据行
for row_idx, row_data in enumerate(rows, 3):
for col_idx, value in enumerate(row_data, 1):
cell = ws.cell(row=row_idx, column=col_idx, value=value)
cell.font = cell_font
cell.alignment = center_align
cell.border = thin_border
# 自动调整列宽
for col in range(1, len(headers) + 1):
max_length = 0
for row in ws.iter_rows(min_row=2, max_row=ws.max_row, min_col=col, max_col=col):
for cell in row:
if cell.value:
length = len(str(cell.value))
# 中文字符算2个宽度
cn_count = sum(1 for c in str(cell.value) if '\u4e00' <= c <= '\u9fff')
length = length + cn_count
max_length = max(max_length, length)
ws.column_dimensions[ws.cell(row=2, column=col).column_letter].width = min(max_length + 4, 40)
# 保存到内存
output = io.BytesIO()
wb.save(output)
output.seek(0)
if not filename:
filename = f"{title}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
filename = filename.replace(' ', '_')
return send_file(
output,
mimetype='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
as_attachment=True,
download_name=f'{filename}.xlsx'
)
# ========== 各模块导出数据查询 ==========
def export_students(keyword=None):
"""导出学员信息"""
query = Student.query
if keyword:
query = query.filter(
db.or_(Student.name.contains(keyword), Student.phone.contains(keyword))
)
students = query.order_by(Student.id.desc()).all()
headers = ['序号', '姓名', '生日', '来源', '备注', '创建时间']
rows = []
for i, s in enumerate(students, 1):
rows.append([
i, s.name,
str(s.birthday) if s.birthday else '',
s.source or '', s.remark or '',
s.created_at.strftime('%Y-%m-%d %H:%M') if s.created_at else ''
])
return export_excel('学员信息', headers, rows, '学员信息')
def export_courses():
"""导出课程信息"""
courses = Course.query.order_by(Course.id.desc()).all()
headers = ['序号', '课程名称', '单价(元/课时)', '总课时', '材料费', '课程类型', '等级', '状态', '备注']
rows = []
for i, c in enumerate(courses, 1):
rows.append([
i, c.name, c.price_per_hour or 0, c.total_hours or 0,
c.material_fee or 0, c.type or '', c.level or '',
'启用' if c.status == 1 else '停用', c.remark or ''
])
return export_excel('课程信息', headers, rows, '课程信息')
def export_teachers():
"""导出老师信息"""
teachers = Teacher.query.order_by(Teacher.id.desc()).all()
headers = ['序号', '姓名', '电话', '专长', '状态', '备注']
rows = []
for i, t in enumerate(teachers, 1):
rows.append([
i, t.name, t.phone or '', t.specialty or '',
'在职' if t.status == 1 else '离职', t.remark or ''
])
return export_excel('老师信息', headers, rows, '老师信息')
def export_recharges(student_id=None, start_date=None, end_date=None):
"""导出充值记录"""
query = RechargeRecord.query
if student_id:
query = query.filter(RechargeRecord.student_id == student_id)
if start_date:
query = query.filter(RechargeRecord.created_at >= start_date)
if end_date:
query = query.filter(RechargeRecord.created_at <= end_date)
records = query.order_by(RechargeRecord.id.desc()).all()
headers = ['序号', '学员姓名', '课程名称', '充值金额', '充值课时', '赠送课时',
'支付方式', '优惠活动', '备注', '充值时间']
rows = []
for i, r in enumerate(records, 1):
rows.append([
i,
r.student.name if r.student else '',
r.course.name if r.course else '',
r.amount or 0,
r.hours_recharged or 0,
r.hours_gifted or 0,
r.payment_method or '',
r.activity.name if r.activity else '',
r.remark or '',
r.created_at.strftime('%Y-%m-%d %H:%M') if r.created_at else ''
])
return export_excel('充值记录', headers, rows, '充值记录')
def export_consumptions(student_id=None, start_date=None, end_date=None):
"""导出消课记录"""
query = ConsumptionRecord.query
if student_id:
query = query.filter(ConsumptionRecord.student_id == student_id)
if start_date:
query = query.filter(ConsumptionRecord.consume_date >= start_date)
if end_date:
query = query.filter(ConsumptionRecord.consume_date <= end_date)
records = query.order_by(ConsumptionRecord.id.desc()).all()
headers = ['序号', '学员姓名', '课程名称', '班级名称', '消课课时',
'消课类型', '是否补课', '是否试听', '消课日期', '备注']
rows = []
for i, r in enumerate(records, 1):
consume_type = '普通'
if r.is_trial:
consume_type = '试听'
elif r.consume_type == 'mixed':
consume_type = '混合'
rows.append([
i,
r.student.name if r.student else '',
r.course.name if r.course else '',
r.class_.name if r.class_ else '',
r.hours_consumed or 0,
consume_type,
'' if r.is_makeup else '',
'' if r.is_trial else '',
str(r.consume_date) if r.consume_date else '',
r.remark or ''
])
return export_excel('消课记录', headers, rows, '消课记录')
def export_refunds():
"""导出退费记录"""
records = RefundRecord.query.order_by(RefundRecord.id.desc()).all()
headers = ['序号', '学员姓名', '课程名称', '充值总额', '已消课时金额',
'材料费扣除', '赠课扣回', '退费金额', '退费原因', '退费时间']
rows = []
for i, r in enumerate(records, 1):
rows.append([
i,
r.student.name if r.student else '',
r.course.name if r.course else '',
r.total_recharged or 0,
r.consumed_hours_value or 0,
r.material_fee_deducted or 0,
r.gifted_hours_deducted or 0,
r.refund_amount or 0,
r.reason or '',
r.created_at.strftime('%Y-%m-%d %H:%M') if r.created_at else ''
])
return export_excel('退费记录', headers, rows, '退费记录')
def export_transfers():
"""导出转课/转赠记录"""
records = TransferRecord.query.order_by(TransferRecord.id.desc()).all()
headers = ['序号', '学员姓名', '转出课程', '转入课程/学员', '转出课时',
'折算金额', '类型', '备注', '时间']
rows = []
for i, r in enumerate(records, 1):
target = ''
if r.transfer_type == 'gift' and r.to_student:
target = f'转赠给: {r.to_student.name}'
elif r.to_course:
target = f'转入: {r.to_course.name}'
type_label = {'transfer': '转课', 'gift': '转赠', 'upgrade': '升阶'}.get(
r.transfer_type, r.transfer_type
)
rows.append([
i,
r.from_student.name if r.from_student else '',
r.from_course.name if r.from_course else '',
target,
r.transfer_hours or 0,
r.transfer_amount or 0,
type_label,
r.remark or '',
r.created_at.strftime('%Y-%m-%d %H:%M') if r.created_at else '',
])
return export_excel('转课转赠记录', headers, rows, '转课转赠记录')
def export_stops():
"""导出停课记录"""
records = StopRecord.query.order_by(StopRecord.id.desc()).all()
headers = ['序号', '学员姓名', '课程名称', '开始日期', '结束日期',
'状态', '备注', '创建时间']
rows = []
for i, r in enumerate(records, 1):
status = '已复课' if r.status == 0 else '停课中'
if r.status == 1 and r.end_date and r.end_date < date.today():
status = '已过期'
rows.append([
i,
r.student.name if r.student else '',
r.course.name if r.course else '',
str(r.start_date) if r.start_date else '',
str(r.end_date) if r.end_date else '',
status,
r.reason or '',
r.created_at.strftime('%Y-%m-%d %H:%M') if r.created_at else ''
])
return export_excel('停课保号记录', headers, rows, '停课保号记录')
def export_attendances(class_id=None, start_date=None, end_date=None):
"""导出考勤记录"""
query = AttendanceRecord.query
if class_id:
query = query.filter(AttendanceRecord.class_id == class_id)
if start_date:
query = query.filter(AttendanceRecord.date >= start_date)
if end_date:
query = query.filter(AttendanceRecord.date <= end_date)
records = query.order_by(AttendanceRecord.id.desc()).all()
headers = ['序号', '学员姓名', '班级名称', '考勤日期', '状态', '备注']
rows = []
status_map = {'present': '出勤', 'absent': '缺勤', 'late': '迟到', 'leave': '请假'}
for i, r in enumerate(records, 1):
rows.append([
i,
r.student.name if r.student else '',
r.class_.name if r.class_ else '',
str(r.date) if r.date else '',
status_map.get(r.status, r.status or ''),
r.remark or ''
])
return export_excel('考勤记录', headers, rows, '考勤记录')
def export_classes():
"""导出班级信息"""
classes = Class_.query.order_by(Class_.id.desc()).all()
headers = ['序号', '班级名称', '课程', '老师', '开课日期', '结课日期',
'上课时间', '最大人数', '当前人数', '状态', '备注']
rows = []
for i, c in enumerate(classes, 1):
rows.append([
i, c.name,
c.course.name if c.course else '',
c.teacher.name if c.teacher else '',
str(c.start_date) if c.start_date else '',
str(c.end_date) if c.end_date else '',
c.schedule or '',
c.max_students or 0,
c.current_students or 0,
'启用' if c.status == 1 else '停用',
c.remark or ''
])
return export_excel('班级信息', headers, rows, '班级信息')
def export_schedules():
"""导出课程安排"""
schedules = Schedule.query.order_by(Schedule.date.desc()).all()
headers = ['序号', '班级名称', '上课日期', '开始时间', '结束时间', '授课老师', '课程主题']
rows = []
for i, s in enumerate(schedules, 1):
rows.append([
i,
s.class_.name if s.class_ else '',
str(s.date) if s.date else '',
s.start_time or '',
s.end_time or '',
s.teacher.name if s.teacher else '',
s.topic or ''
])
return export_excel('课程安排', headers, rows, '课程安排')
def export_operation_logs():
"""导出操作日志"""
query = OperationLog.query
start_date = request.args.get('start_date', '').strip()
end_date = request.args.get('end_date', '').strip()
if start_date:
query = query.filter(OperationLog.created_at >= start_date)
if end_date:
query = query.filter(OperationLog.created_at <= end_date + ' 23:59:59')
records = query.order_by(OperationLog.id.desc()).limit(10000).all()
headers = ['序号', '操作人', '操作类型', '操作详情', 'IP地址', '操作时间']
rows = []
for i, r in enumerate(records, 1):
rows.append([
i,
r.user.real_name if r.user else (str(r.user_id) if r.user_id else ''),
r.operation_type or '',
r.operation_detail or '',
r.ip_address or '',
r.created_at.strftime('%Y-%m-%d %H:%M') if r.created_at else ''
])
return export_excel('操作日志', headers, rows, '操作日志')
def export_statistics_data():
"""导出统计数据(兼容旧入口)"""
from keshi_stats import resolve_report_period
year, month = resolve_report_period()
return export_statistics_snapshot_data(year, month)
def export_statistics_snapshot_data(year: int, month: int):
"""导出基于课时核对表快照的统计数据"""
from keshi_stats import (
period_overview, year_overview, cumulative_finance, course_stats_for_period,
)
label = f'{year}-{month:02d}'
month_ov = period_overview(year, month)
year_ov = year_overview(year)
finance = cumulative_finance()
courses = course_stats_for_period(year, month)
headers = ['统计项', '数值']
rows = [
['统计月份', label],
['当月充值(元)', month_ov['recharge']],
['当月消课(节)', month_ov['consumed_hours']],
['当月已消金额(元)', month_ov['consumed_amount']],
['当月退费(元)', month_ov['refund']],
['当月新签学员', month_ov['new_signups']],
['当月月末余额合计(元)', month_ov['end_balance']],
[],
[f'{year}年度充值(元)', year_ov['recharge']],
[f'{year}年度消课(节)', year_ov['consumed_hours']],
[f'{year}年度退费(元)', year_ov['refund']],
[],
['累计充值(元)', finance['total_recharge']],
['累计退费(元)', finance['total_refund']],
['净收入(元)', finance['net_income']],
]
if courses:
rows.append([])
rows.append(['=== 各课程当月统计 ===', ''])
rows.append(['课程', '充值', '消课节', '退费', '月末余额'])
for c in courses:
rows.append([
c['code'], c['recharge'], c['consumed_hours'],
c['refund'], c['end_balance'],
])
return export_excel('统计报表', headers, rows, f'统计报表_{label}')
# ========== 数据库自动迁移 ==========
def is_mysql():
"""判断当前是否使用 MySQL 协议(含 Doris"""
from flask import current_app
uri = current_app.config.get('SQLALCHEMY_DATABASE_URI', '')
return uri.startswith('mysql')
def is_doris():
"""是否 Apache Doris"""
from db_bootstrap import is_doris as _is_doris
return _is_doris()
def _migrate_col_type(col):
"""迁移用列类型Doris 与 MySQL 略有差异)"""
from db_bootstrap import _col_type_doris, is_doris as _is_doris
if _is_doris():
return _col_type_doris(col)
return str(col.type).upper()
def auto_migrate():
"""自动迁移检测并补齐数据库中缺失的列Doris 下缺失表由 doris_create_all 处理"""
from db_bootstrap import is_doris as _is_doris, doris_create_all
inspector = db.inspect(db.engine)
existing_tables = set(inspector.get_table_names())
if _is_doris():
model_tables = {m.class_.__tablename__ for m in db.Model.registry.mappers}
if model_tables - existing_tables:
doris_create_all()
inspector = db.inspect(db.engine)
existing_tables = set(inspector.get_table_names())
# 遍历所有模型类
for mapper in db.Model.registry.mappers:
model = mapper.class_
table = model.__table__
table_name = table.name
if table_name not in existing_tables:
continue
if _is_doris():
try:
rows = db.session.execute(
db.text(f'SHOW COLUMNS FROM `{table_name}`')
).fetchall()
existing_cols = {r[0] for r in rows}
except Exception:
existing_cols = set()
else:
existing_cols = {c['name'] for c in inspector.get_columns(table_name)}
model_cols = {c.name for c in table.columns}
missing_cols = model_cols - existing_cols
if not missing_cols:
continue
for col_name in missing_cols:
col = table.columns[col_name]
col_type = _migrate_col_type(col)
if is_mysql():
default_val = ''
if col.primary_key or col.autoincrement:
continue
if col.default is not None:
if col.default.is_scalar:
if isinstance(col.default.arg, (int, float)):
default_val = f' DEFAULT {col.default.arg}' if not _is_doris() else f' DEFAULT "{col.default.arg}"'
else:
default_val = f" DEFAULT '{col.default.arg}'"
elif not col.nullable:
# NOT NULL且无默认值给合理默认
if 'INT' in col_type:
default_val = ' DEFAULT 0' if not _is_doris() else ' DEFAULT "0"'
elif 'VARCHAR' in col_type or 'TEXT' in col_type or 'STRING' in col_type:
default_val = " DEFAULT ''"
elif 'DATE' in col_type or 'TIME' in col_type:
default_val = ''
elif 'BOOLEAN' in col_type or 'TINYINT' in col_type:
default_val = ' DEFAULT 0' if not _is_doris() else ' DEFAULT "0"'
elif 'FLOAT' in col_type or 'DECIMAL' in col_type:
default_val = ' DEFAULT 0.0' if not _is_doris() else ' DEFAULT "0"'
nullable = '' if col.nullable else ' NOT NULL'
sql = f'ALTER TABLE `{table_name}` ADD COLUMN `{col_name}` {col_type}{default_val}{nullable}'
else:
# SQLite
default_val = ''
if col.default is not None and col.default.is_scalar:
if isinstance(col.default.arg, (int, float)):
default_val = f' DEFAULT {col.default.arg}'
elif isinstance(col.default.arg, str):
default_val = f" DEFAULT '{col.default.arg}'"
elif not col.nullable:
if 'INT' in col_type:
default_val = ' DEFAULT 0'
elif 'VARCHAR' in col_type or 'TEXT' in col_type:
default_val = " DEFAULT ''"
elif 'BOOLEAN' in col_type:
default_val = ' DEFAULT 0'
elif 'FLOAT' in col_type or 'DECIMAL' in col_type:
default_val = ' DEFAULT 0.0'
sql = f'ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type}{default_val}'
try:
db.session.execute(db.text(sql))
db.session.commit()
print(f' [auto_migrate] Added column: {table_name}.{col_name}')
except Exception as e:
db.session.rollback()
print(f' [auto_migrate] Failed to add {table_name}.{col_name}: {e}')
# ========== 初始化默认数据 ==========
def init_default_data():
"""初始化默认数据确保admin用户和超级管理员角色存在"""
# 确保超级管理员角色存在
admin_role = Role.query.filter_by(name='超级管理员').first()
if not admin_role:
admin_role = Role(
name='超级管理员',
permissions=','.join(PERMISSIONS.keys()),
remark='系统默认超级管理员,拥有所有权限'
)
db.session.add(admin_role)
db.session.commit()
elif not admin_role.permissions or len(admin_role.permissions.split(',')) < len(PERMISSIONS):
# 超级管理员权限补齐
admin_role.permissions = ','.join(PERMISSIONS.keys())
db.session.commit()
# 确保admin用户存在且密码正确
admin = User.query.filter_by(username='admin').first()
if not admin:
admin = User(
username='admin',
real_name='系统管理员',
role_id=admin_role.id,
status=1,
remark='系统默认管理员'
)
admin.set_password('admin123')
db.session.add(admin)
db.session.commit()
else:
# 确保密码是admin123
if not admin.check_password('admin123'):
admin.set_password('admin123')
# 确保角色正确
if not admin.role_id or admin.role_id != admin_role.id:
admin.role_id = admin_role.id
db.session.commit()