Files
drl_2/xuexiao/db_bootstrap.py

269 lines
9.1 KiB
Python
Raw Normal View History

2026-05-29 10:28:07 +08:00
"""数据库初始化Doris / MySQL / SQLite 建表与迁移"""
import os
import re
from sqlalchemy import event, inspect as sa_inspect
from sqlalchemy.orm import Session
from models import db
# Guard flag: flipped by register_doris_id_listener() so the Session
# before_flush hook is attached at most once.
_doris_id_listener_registered = False

# Table creation order, following the logical dependencies between tables
# (Doris does not enforce foreign keys, so the order is kept by hand).
TABLE_ORDER = [
    'roles',
    'users',
    'teachers',
    'students',
    'courses',
    'classes',
    'class_students',
    'student_accounts',
    'recharge_activities',
    'recharge_records',
    'consumption_records',
    'refund_records',
    'transfer_records',
    'stop_records',
    'attendance_records',
    'operation_logs',
    'schedules',
    'monthly_snapshots',
    'class_records',
]

# Doris CREATE TABLE tuning, overridable through the environment.
DORIS_BUCKETS = int(os.environ.get('DORIS_BUCKETS', '1'))  # HASH bucket count
# Kept as a string: it is interpolated directly into PROPERTIES (...).
DORIS_REPLICATION_NUM = os.environ.get('DORIS_REPLICATION_NUM', '1')
def _get_uri():
    """Return the app's ``SQLALCHEMY_DATABASE_URI`` as a string.

    Reads the value from ``flask.current_app.config``, so an active Flask
    application context is needed. A missing or ``None`` value yields ``''``.
    A non-string value (e.g. a SQLAlchemy ``URL`` object) is converted with
    ``str()``. Callers can therefore always use string methods such as
    ``lower()`` / ``startswith()`` on the result.
    """
    from flask import current_app

    uri = current_app.config.get('SQLALCHEMY_DATABASE_URI')
    return '' if uri is None else str(uri)
def is_doris():
    """Return True when the configured database is Apache Doris.

    Doris speaks the MySQL protocol (FE query port defaults to 9030), so it is
    recognised by the first of these checks that matches:

    1. environment variable ``DB_TYPE`` equal to ``doris`` (case-insensitive);
    2. the substring ``doris`` anywhere in the database URI (case-insensitive);
    3. the first ``:<digits>/`` segment of the URI being ``:9030/``.
    """
    env_type = os.environ.get('DB_TYPE', '')
    if env_type.lower() == 'doris':
        return True
    uri = _get_uri()
    if 'doris' in uri.lower():
        return True
    port_match = re.search(r':(\d+)/', uri)
    if port_match is None:
        return False
    return port_match.group(1) == '9030'
def is_mysql():
    """Return True if the database URI uses a MySQL-protocol dialect.

    The check is a plain ``startswith('mysql')`` on the URI, which is intended
    to cover Doris as well (Doris is reached over the MySQL protocol).
    """
    uri = _get_uri()
    return uri.startswith('mysql')
def is_sqlite():
    """Return True if the database URI starts with ``sqlite``."""
    uri = _get_uri()
    return uri.startswith('sqlite')
def quote_ident(name):
    """Quote *name* as a MySQL/Doris identifier using backticks.

    A backtick inside the name is escaped by doubling it (the MySQL rule for
    quoted identifiers). The result is therefore always one well-formed
    identifier and cannot terminate the quoting early. Non-string names are
    converted with ``str()`` first.

    >>> quote_ident('users')
    '`users`'
    """
    escaped = str(name).replace('`', '``')
    return f'`{escaped}`'
def doris_next_id(session, table_name, pk_name='id'):
    """Return the next primary-key value for *table_name*, i.e. ``MAX(pk) + 1``.

    Doris has no LAST_INSERT_ID, so ids are allocated by the application.
    An empty table yields 1 (``COALESCE`` to 0, plus one). A NULL/0 scalar
    result also falls back to 1.

    NOTE(review): not safe against concurrent writers; two sessions can read
    the same MAX and hand out the same id.
    """
    pk_sql = quote_ident(pk_name)
    table_sql = quote_ident(table_name)
    query = db.text(f'SELECT COALESCE(MAX({pk_sql}), 0) + 1 FROM {table_sql}')
    next_value = session.execute(query).scalar()
    return int(next_value or 1)
def register_doris_id_listener():
    """Attach a Session ``before_flush`` hook that assigns primary keys on Doris.

    Idempotent: the module-level ``_doris_id_listener_registered`` flag makes
    repeated calls no-ops. The hook itself does nothing unless ``is_doris()``
    is true at flush time.

    For every pending object, each primary-key column with a truthy
    ``autoincrement`` whose attribute is still ``None`` receives
    ``MAX(pk) + 1``. The database is queried once per (table, column) per
    flush, and later objects in the same flush continue from the cached value.
    Pending rows are not in the table yet, so re-querying would hand the same
    id to several objects. On a Doris UNIQUE KEY table, the later row would
    then silently overwrite the earlier one.

    NOTE(review): MAX(id)+1 remains racy across concurrent sessions/processes.
    Explicit ids set on other pending objects of the same table are not taken
    into account.
    """
    global _doris_id_listener_registered
    if _doris_id_listener_registered:
        return

    @event.listens_for(Session, 'before_flush')
    def _assign_doris_ids(session, flush_context, instances):
        if not is_doris():
            return
        # (table name, pk column key) -> next id to hand out during this flush
        next_ids = {}
        for obj in session.new:
            mapper = sa_inspect(obj.__class__)
            if not mapper.tables:
                continue
            table = mapper.tables[0]
            for col in mapper.primary_key:
                # SQLAlchemy's default autoincrement is the truthy string "auto".
                if not getattr(col, 'autoincrement', False):
                    continue
                if getattr(obj, col.key) is not None:
                    continue  # the caller supplied an explicit id
                slot = (table.name, col.key)
                if slot not in next_ids:
                    next_ids[slot] = doris_next_id(session, table.name, col.key)
                setattr(obj, col.key, next_ids[slot])
                next_ids[slot] += 1

    # Set only once the listener is attached, so a failed registration can be retried.
    _doris_id_listener_registered = True
def _col_type_doris(col):
    """Map a SQLAlchemy column's type to the type written in Doris DDL.

    Application-assigned primary keys (primary key with a truthy
    ``autoincrement``) always become ``BIGINT NOT NULL``. Doris has no
    dependable AUTO_INCREMENT / LAST_INSERT_ID, so the app allocates those ids.
    Every other column is classified by substring checks on the upper-cased
    ``str(col.type)``, in a fixed priority order. Unrecognised types are
    passed through upper-cased.
    """
    if col.primary_key and getattr(col, 'autoincrement', False):
        return 'BIGINT NOT NULL'
    sql_type = str(col.type).upper()

    if 'BIGINT' in sql_type:
        return 'BIGINT'
    if any(small in sql_type for small in ('SMALLINT', 'TINYINT')):
        return 'SMALLINT'
    if 'INTEGER' in sql_type or sql_type == 'INT':
        return 'INT'
    if 'VARCHAR' in sql_type:
        # NVARCHAR(n) -> VARCHAR(n); plain VARCHAR(n) is unchanged.
        return sql_type.replace('NVARCHAR', 'VARCHAR')
    if 'TEXT' in sql_type:
        return 'STRING'
    # DATETIME must be tested before DATE.
    if 'DATETIME' in sql_type:
        return 'DATETIME'
    if sql_type == 'DATE' or sql_type.startswith('DATE('):
        return 'DATE'
    if 'NUMERIC' in sql_type or 'DECIMAL' in sql_type:
        return sql_type.replace('NUMERIC', 'DECIMAL')
    if 'FLOAT' in sql_type or 'DOUBLE' in sql_type:
        return sql_type
    if 'BOOLEAN' in sql_type:
        return 'TINYINT'
    return sql_type
def _default_clause_doris(col):
    """Build the `` DEFAULT ...`` fragment (with leading space) for Doris DDL.

    Only static scalar Python-side defaults (``Column(default=<value>)``) are
    rendered. An empty string is returned for application-assigned primary
    keys, columns that have a ``server_default``, callable/SQL-expression
    defaults, and ``None`` or unsupported default values.

    Booleans, ints, floats and Decimals are rendered as double-quoted literals
    (``DEFAULT "0"``). Strings become single-quoted SQL literals, with
    backslashes and single quotes escaped.

    NOTE(review): ``server_default`` values are not rendered here, so such
    columns get no DEFAULT in the generated Doris DDL; confirm this is intended.
    """
    from decimal import Decimal

    if col.primary_key and getattr(col, 'autoincrement', False):
        return ''
    if col.server_default is not None:
        return ''
    if col.default is None:
        return ''
    # Callable / SQL-expression defaults cannot be expressed as a literal.
    if not getattr(col.default, 'is_scalar', True):
        return ''
    val = col.default.arg
    if val is None:
        return ''
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(val, bool):
        return f' DEFAULT "{int(val)}"'
    # Decimal is common on money columns; previously it silently got no default.
    if isinstance(val, (int, float, Decimal)):
        return f' DEFAULT "{val}"'
    if isinstance(val, str):
        # MySQL-protocol string literals treat backslash as an escape
        # character, so escape it before doubling single quotes.
        escaped = val.replace('\\', '\\\\').replace("'", "''")
        return f" DEFAULT '{escaped}'"
    return ''
def _build_doris_column(col):
    """Render one column definition for a Doris CREATE TABLE statement.

    Produces ``<quoted name> <type>[ NOT NULL][ DEFAULT ...]``. ``NOT NULL``
    is not appended for application-assigned primary keys, because the type
    returned by ``_col_type_doris`` for them already ends in ``NOT NULL``.
    """
    app_assigned_pk = bool(col.primary_key and getattr(col, 'autoincrement', False))
    pieces = [quote_ident(col.name), _col_type_doris(col)]
    if not (col.nullable or app_assigned_pk):
        pieces.append('NOT NULL')
    pieces.append(_default_clause_doris(col))
    # Drop empty pieces (e.g. no DEFAULT clause) before joining.
    return ' '.join(filter(None, pieces))
def build_doris_create_sql(table):
    """Build a Doris ``CREATE TABLE IF NOT EXISTS`` statement for *table*.

    The table uses Doris' UNIQUE KEY model on the primary-key column(s), or on
    the first column when the table has no primary key. It is hash-distributed
    on the same key columns, with ``DORIS_BUCKETS`` buckets and
    ``replication_num`` set from ``DORIS_REPLICATION_NUM``.

    Doris requires UNIQUE KEY columns to form an ordered prefix of the column
    list. Key columns are therefore emitted first, and the remaining columns
    follow in model order. Every primary-key column is part of the key: keying
    a composite primary key on its first column alone would make rows sharing
    that column overwrite one another.
    """
    key_cols = [c for c in table.columns if c.primary_key] or [table.columns[0]]
    key_names = {c.name for c in key_cols}
    ordered_cols = key_cols + [c for c in table.columns if c.name not in key_names]
    col_defs = [_build_doris_column(c) for c in ordered_cols]
    key_sql = ', '.join(quote_ident(c.name) for c in key_cols)
    # Joined outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    columns_sql = ',\n  '.join(col_defs)
    return (
        f'CREATE TABLE IF NOT EXISTS {quote_ident(table.name)} (\n'
        f'  {columns_sql}\n'
        f') UNIQUE KEY({key_sql})\n'
        f'DISTRIBUTED BY HASH({key_sql}) BUCKETS {DORIS_BUCKETS}\n'
        f'PROPERTIES ("replication_num" = "{DORIS_REPLICATION_NUM}")'
    )
def _table_by_name(name):
    """Return the ``__table__`` of the first mapped model whose
    ``__tablename__`` equals *name*, or ``None`` if no mapped model matches."""
    matching_tables = (
        mapper.class_.__table__
        for mapper in db.Model.registry.mappers
        if mapper.class_.__tablename__ == name
    )
    return next(matching_tables, None)
def _doris_create_table(table):
    """Execute and commit the Doris CREATE TABLE for *table*; roll back, log and re-raise on failure."""
    sql = build_doris_create_sql(table)
    try:
        db.session.execute(db.text(sql))
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f' [doris_create] Failed {table.name}: {e}')
        raise
    print(f' [doris_create] Created table: {table.name}')


def doris_create_all():
    """Doris: create the model tables that do not exist yet (first start / new tables).

    Tables listed in ``TABLE_ORDER`` are created first, in that order. Names in
    ``TABLE_ORDER`` with no mapped model are reported and skipped. Any other
    mapped model tables are created afterwards. Each table is created (and
    committed) at most once, even when several mappers share it. The first
    failure is rolled back and re-raised.

    Returns:
        list[str]: names of the tables created by this call, in creation order.
    """
    existing = set(sa_inspect(db.engine).get_table_names())
    created = []
    for table_name in TABLE_ORDER:
        if table_name in existing:
            continue
        table = _table_by_name(table_name)
        if table is None:
            print(f' [doris_create] Skip unknown table: {table_name}')
            continue
        _doris_create_table(table)
        created.append(table_name)
    # Model tables that TABLE_ORDER does not list.
    for mapper in db.Model.registry.mappers:
        table = mapper.class_.__table__
        if table.name in existing or table.name in created or table.name in TABLE_ORDER:
            continue
        _doris_create_table(table)
        # Recorded so a table shared by several mappers is only created once
        # and is included in the return value.
        created.append(table.name)
    return created
# Columns added to existing tables for the lesson-hour (keshi) ledger.
# Each entry is (table name, column name, column DDL fragment). The fragment is
# appended verbatim after ``ADD COLUMN <name>`` by migrate_keshi_ledger_columns.
# Defaults are written as double-quoted literals. Order matters only for the
# order in which the ALTERs run.
KESHI_LEDGER_COLUMNS = [
    ('students', 'display_name', 'VARCHAR(50)'),
    ('students', 'nickname', 'VARCHAR(50)'),
    ('courses', 'course_code', 'VARCHAR(20)'),
    ('student_accounts', 'unit_price', 'DECIMAL(10,4) DEFAULT "0"'),
    ('student_accounts', 'original_price_per_lesson', 'DECIMAL(10,4)'),
    ('student_accounts', 'account_status', 'VARCHAR(20) DEFAULT "active"'),
    ('recharge_records', 'gift_clawed_back', 'SMALLINT DEFAULT "0"'),
    ('refund_records', 'deduct_gifted_amount', 'DECIMAL(12,2) DEFAULT "0"'),
    ('refund_records', 'promo_clawback_hours', 'DECIMAL(10,2) DEFAULT "0"'),
    ('consumption_records', 'unit_price_at_consume', 'DECIMAL(10,4)'),
    ('student_accounts', 'version', 'INT DEFAULT "0"'),
]
def migrate_keshi_ledger_columns():
    """Add any missing lesson-hour ledger columns listed in KESHI_LEDGER_COLUMNS.

    Runs only when ``is_mysql()`` is true (MySQL-protocol databases, which
    includes Doris). It makes sure these ALTER TABLE statements also run on
    Doris, covering columns that ``auto_migrate`` may have missed.

    Each table's existing columns are read once with ``SHOW COLUMNS``. A table
    that cannot be inspected is rolled back, reported and skipped. A failed
    ALTER is rolled back and reported, and migration continues with the next
    column. This is best effort: nothing is raised.
    """
    if not is_mysql():
        return
    # table name -> set of existing column names, or None if inspection failed
    columns_by_table = {}
    for table_name, col_name, col_type in KESHI_LEDGER_COLUMNS:
        if table_name not in columns_by_table:
            try:
                rows = db.session.execute(
                    db.text(f'SHOW COLUMNS FROM {quote_ident(table_name)}')
                ).fetchall()
            except Exception as e:
                # Reset the session so later statements do not run inside a
                # failed transaction.
                db.session.rollback()
                print(f' [migrate_keshi] Skip {table_name}: {e}')
                columns_by_table[table_name] = None
            else:
                columns_by_table[table_name] = {r[0] for r in rows}
        existing = columns_by_table[table_name]
        if existing is None or col_name in existing:
            continue
        sql = (
            f'ALTER TABLE {quote_ident(table_name)} '
            f'ADD COLUMN {quote_ident(col_name)} {col_type}'
        )
        try:
            db.session.execute(db.text(sql))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f' [migrate_keshi] Failed {table_name}.{col_name}: {e}')
        else:
            existing.add(col_name)
            print(f' [migrate_keshi] Added {table_name}.{col_name}')
def bootstrap_database():
    """Single entry point: create tables, then run the automatic migrations.

    On Doris, the Session id-assignment hook is registered and missing tables
    are created with Doris DDL. Otherwise ``db.create_all()`` is used. Both
    paths then run ``utils.auto_migrate()`` followed by
    ``migrate_keshi_ledger_columns()``.
    """
    if not is_doris():
        backend = "MySQL" if is_mysql() else "SQLite"
        print(f'[bootstrap] Database: {backend}')
        db.create_all()
    else:
        print('[bootstrap] Database: Doris')
        register_doris_id_listener()
        doris_create_all()
    # Imported lazily rather than at module level (presumably to avoid an
    # import cycle with utils; confirm).
    from utils import auto_migrate
    auto_migrate()
    migrate_keshi_ledger_columns()