269 lines
9.1 KiB
Python
269 lines
9.1 KiB
Python
"""数据库初始化:Doris / MySQL / SQLite 建表与迁移"""
|
||
import os
|
||
import re
|
||
|
||
from sqlalchemy import event, inspect as sa_inspect
|
||
from sqlalchemy.orm import Session
|
||
|
||
from models import db
|
||
|
||
# Set to True by register_doris_id_listener() once its flush hook is attached.
_doris_id_listener_registered = False

# Order in which Doris tables are created (logical dependency order; Doris
# does not enforce foreign keys, so this only fixes a predictable sequence).
TABLE_ORDER = [
    'roles',
    'users',
    'teachers',
    'students',
    'courses',
    'classes',
    'class_students',
    'student_accounts',
    'recharge_activities',
    'recharge_records',
    'consumption_records',
    'refund_records',
    'transfer_records',
    'stop_records',
    'attendance_records',
    'operation_logs',
    'schedules',
    'monthly_snapshots',
    'class_records',
]

# Doris DDL knobs, overridable through the environment.
# Bucket count is parsed to int; replication_num stays a string because it is
# spliced verbatim into the PROPERTIES clause.
DORIS_BUCKETS = int(os.getenv('DORIS_BUCKETS', '1'))
DORIS_REPLICATION_NUM = os.getenv('DORIS_REPLICATION_NUM', '1')
|
||
|
||
|
||
def _get_uri():
    """Return the app's SQLALCHEMY_DATABASE_URI, or '' when it is not configured.

    Uses Flask's ``current_app``, so it must be called inside an application
    context.
    """
    from flask import current_app as app

    uri = app.config.get('SQLALCHEMY_DATABASE_URI', '')
    return uri
|
||
|
||
|
||
def is_doris():
    """Return True when the configured database is Apache Doris.

    Doris speaks the MySQL protocol, so it is detected heuristically, in order:

    1. the ``DB_TYPE`` environment variable equals ``doris`` (case-insensitive);
    2. the database URI contains ``doris`` anywhere (case-insensitive);
    3. the URI targets port 9030 (the default Doris FE query port).
    """
    if os.environ.get('DB_TYPE', '').lower() == 'doris':
        return True
    uri = _get_uri()
    if 'doris' in uri.lower():
        return True
    # The port may be followed by the database path, a query string, or
    # nothing at all (e.g. ``mysql+pymysql://user@host:9030``).
    m = re.search(r':(\d+)(?:[/?]|$)', uri)
    return bool(m and m.group(1) == '9030')
|
||
|
||
|
||
def is_mysql():
    """Return True for a MySQL-protocol URI.

    Doris is reached over the MySQL protocol, so its ``mysql...`` URIs match
    as well.
    """
    uri = _get_uri()
    return uri.startswith('mysql')
|
||
|
||
|
||
def is_sqlite():
    """Return True when the configured URI uses the ``sqlite`` scheme."""
    uri = _get_uri()
    return uri.startswith('sqlite')
|
||
|
||
|
||
def quote_ident(name):
    """Return ``name`` as a backtick-quoted MySQL/Doris identifier.

    Any backtick inside the name is doubled (the MySQL escape for a backtick
    within a quoted identifier), so the name cannot close the quoting early.
    """
    escaped = str(name).replace('`', '``')
    return f'`{escaped}`'
|
||
|
||
|
||
def doris_next_id(session, table_name, pk_name='id'):
    """Return MAX(pk_name) + 1 for ``table_name``, i.e. the next id to assign.

    Doris has no LAST_INSERT_ID, so primary keys are allocated by the
    application. An empty table yields 1.

    NOTE(review): the MAX() read takes no lock, so two sessions allocating at
    the same time can receive the same value.
    """
    table_sql = quote_ident(table_name)
    pk_sql = quote_ident(pk_name)
    query = db.text(f'SELECT COALESCE(MAX({pk_sql}), 0) + 1 FROM {table_sql}')
    next_id = session.execute(query).scalar()
    return int(next_id or 1)
|
||
|
||
|
||
def register_doris_id_listener():
    """Attach a ``before_flush`` hook that assigns missing primary keys on Doris.

    Registration happens at most once, guarded by the module-level
    ``_doris_id_listener_registered`` flag. The hook is attached to the
    ``Session`` class and does nothing unless ``is_doris()`` is true at flush
    time. For each pending object whose PK column is flagged ``autoincrement``
    and still ``None``, it sets the PK to MAX(pk) + 1 via ``doris_next_id``.
    """
    global _doris_id_listener_registered
    if _doris_id_listener_registered:
        return
    _doris_id_listener_registered = True

    @event.listens_for(Session, 'before_flush')
    def _assign_doris_ids(session, flush_context, instances):
        if not is_doris():
            return
        # Next free id per (table, pk column) for this flush. MAX(pk)+1 is read
        # from the database once per table, and later objects of the same table
        # continue from there. Re-querying per object would hand every pending
        # row the same id, and Doris' UNIQUE KEY model would then silently keep
        # only one of them.
        next_ids = {}
        for obj in session.new:
            mapper = sa_inspect(obj.__class__)
            if not mapper.tables:
                continue
            table = mapper.tables[0]
            for col in mapper.primary_key:
                # NOTE(review): Column.autoincrement defaults to the truthy
                # "auto", so every PK column without an explicit
                # autoincrement=False is treated as app-assigned here.
                if not getattr(col, 'autoincrement', False):
                    continue
                # The mapped attribute name can differ from the DB column name.
                attr = mapper.get_property_by_column(col).key
                if getattr(obj, attr) is not None:
                    continue
                cache_key = (table.name, col.name)
                if cache_key not in next_ids:
                    next_ids[cache_key] = doris_next_id(session, table.name, col.name)
                setattr(obj, attr, next_ids[cache_key])
                next_ids[cache_key] += 1
|
||
|
||
|
||
def _col_type_doris(col):
    """Translate a SQLAlchemy column's type into a Doris DDL type string.

    Integer primary keys flagged ``autoincrement`` map to ``BIGINT NOT NULL``:
    Doris has no reliable AUTO_INCREMENT / LAST_INSERT_ID, so those ids are
    assigned by the application instead. Every other column is classified by
    substring checks on the upper-cased ``str(col.type)``. Types that are not
    recognised are returned unchanged.
    """
    type_str = str(col.type).upper()

    # Column.autoincrement defaults to the truthy string "auto", so the flag
    # alone would also turn e.g. a VARCHAR primary key into BIGINT. Only
    # integer-typed keys are treated as app-assigned ids.
    if col.primary_key and getattr(col, 'autoincrement', False) and 'INT' in type_str:
        return 'BIGINT NOT NULL'

    if 'BIGINT' in type_str:
        return 'BIGINT'
    if 'SMALLINT' in type_str or 'TINYINT' in type_str:
        return 'SMALLINT'
    if 'INTEGER' in type_str or type_str == 'INT':
        return 'INT'
    if 'VARCHAR' in type_str:
        return type_str.replace('NVARCHAR', 'VARCHAR')
    if 'TEXT' in type_str:
        return 'STRING'
    # Must be checked before DATE. Doris has no TIMESTAMP column type, so
    # TIMESTAMP is mapped to DATETIME as well.
    if 'DATETIME' in type_str or 'TIMESTAMP' in type_str:
        return 'DATETIME'
    if type_str == 'DATE' or type_str.startswith('DATE('):
        return 'DATE'
    if 'NUMERIC' in type_str or 'DECIMAL' in type_str:
        return type_str.replace('NUMERIC', 'DECIMAL')
    if 'FLOAT' in type_str or 'DOUBLE' in type_str:
        return type_str
    if 'BOOLEAN' in type_str:
        return 'TINYINT'
    return type_str
|
||
|
||
|
||
def _default_clause_doris(col):
    """Return the `` DEFAULT ...`` clause for a Doris column, or '' for none.

    Only plain scalar Python defaults (bool, int, float, Decimal, str) are
    rendered. Nothing is emitted in these cases:

    - app-assigned primary keys;
    - columns that have a ``server_default``;
    - callable or SQL-expression defaults;
    - a ``None`` default.

    Booleans are rendered as ``DEFAULT "0"``/``DEFAULT "1"``. Other numbers are
    also emitted in double quotes. Strings are single-quoted with escaping.
    """
    from decimal import Decimal

    if col.primary_key and getattr(col, 'autoincrement', False):
        return ''
    if col.server_default is not None or col.default is None:
        return ''
    if not getattr(col.default, 'is_scalar', True):
        return ''
    val = col.default.arg
    if val is None:
        return ''
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(val, bool):
        return f' DEFAULT "{int(val)}"'
    if isinstance(val, (int, float, Decimal)):
        return f' DEFAULT "{val}"'
    if isinstance(val, str):
        # Escape backslashes first (MySQL-style escape character), then double
        # single quotes, so the value cannot terminate its literal early.
        escaped = val.replace('\\', '\\\\').replace("'", "''")
        return f" DEFAULT '{escaped}'"
    return ''
|
||
|
||
|
||
def _build_doris_column(col):
    """Render one column definition for a Doris CREATE TABLE statement.

    The result has the form ``name TYPE [NOT NULL] [DEFAULT ...]``. NOT NULL
    is appended for non-nullable columns unless the rendered type already
    carries it, as the app-assigned primary-key type does. Deriving this from
    the type string instead of re-checking the primary-key rule keeps it in
    step with ``_col_type_doris``, so NOT NULL is never doubled or lost.
    """
    col_type = _col_type_doris(col)
    parts = [quote_ident(col.name), col_type]
    if not col.nullable and 'NOT NULL' not in col_type.upper():
        parts.append('NOT NULL')
    parts.append(_default_clause_doris(col))
    return ' '.join(p for p in parts if p)
|
||
|
||
|
||
def build_doris_create_sql(table):
    """Build a Doris ``CREATE TABLE IF NOT EXISTS`` statement for ``table``.

    The table uses the UNIQUE KEY model:

    - it is keyed on all primary-key columns, or on the first column when the
      table has no primary key;
    - it is hash-distributed on the first key column;
    - DORIS_BUCKETS sets the bucket count and DORIS_REPLICATION_NUM sets
      ``replication_num``.
    """
    key_cols = [c for c in table.columns if c.primary_key] or [table.columns[0]]
    key_names = {c.name for c in key_cols}
    # Doris requires the key columns to be the leading columns of the schema,
    # in the same order as the UNIQUE KEY clause.
    ordered_cols = key_cols + [c for c in table.columns if c.name not in key_names]

    # Joined outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    col_defs = ',\n    '.join(_build_doris_column(c) for c in ordered_cols)
    key_sql = ', '.join(quote_ident(c.name) for c in key_cols)
    dist_sql = quote_ident(key_cols[0].name)

    return (
        f'CREATE TABLE IF NOT EXISTS {quote_ident(table.name)} (\n'
        f'    {col_defs}\n'
        f') UNIQUE KEY({key_sql})\n'
        f'DISTRIBUTED BY HASH({dist_sql}) BUCKETS {DORIS_BUCKETS}\n'
        f'PROPERTIES ("replication_num" = "{DORIS_REPLICATION_NUM}")'
    )
|
||
|
||
|
||
def _table_by_name(name):
    """Return the SQLAlchemy ``Table`` called ``name``, or None if unknown.

    Mapped model classes are searched first (by ``__tablename__``). If no
    model matches, the shared model metadata is searched. This also finds
    plain ``db.Table`` definitions, such as association tables with no mapped
    class.
    """
    for mapper in db.Model.registry.mappers:
        cls = mapper.class_
        # getattr: a mapped class is not required to define __tablename__.
        if getattr(cls, '__tablename__', None) == name:
            return cls.__table__

    metadata = db.Model.metadata
    table = metadata.tables.get(name)
    if table is not None:
        return table
    # Metadata keys are schema-qualified ("schema.name") when a schema is set.
    return next((t for t in metadata.tables.values() if t.name == name), None)
|
||
|
||
|
||
def _doris_execute_create(table):
    """Execute and commit the Doris CREATE TABLE for ``table``; roll back and re-raise on error."""
    sql = build_doris_create_sql(table)
    try:
        db.session.execute(db.text(sql))
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f' [doris_create] Failed {table.name}: {e}')
        raise
    print(f' [doris_create] Created table: {table.name}')


def doris_create_all():
    """Create, on Doris, every known table that does not exist yet.

    Tables named in TABLE_ORDER are created first, in that order. After them
    come any other mapped model tables, then any remaining tables in the
    model metadata. Each CREATE is committed on its own. On failure the
    session is rolled back and the exception re-raised.

    Returns:
        list[str]: names of the tables created by this call, in creation order.
    """
    existing = set(sa_inspect(db.engine).get_table_names())
    created = []

    for table_name in TABLE_ORDER:
        if table_name in existing:
            continue
        table = _table_by_name(table_name)
        if table is None:
            print(f' [doris_create] Skip unknown table: {table_name}')
            continue
        _doris_execute_create(table)
        created.append(table_name)

    # Tables not listed in TABLE_ORDER. Recording every creation in `created`
    # keeps the return value complete, and it also stops a table shared by
    # several mappers from being created twice.
    listed = set(TABLE_ORDER)
    candidates = [m.class_.__table__ for m in db.Model.registry.mappers]
    candidates.extend(db.Model.metadata.sorted_tables)
    for table in candidates:
        if table.name in existing or table.name in listed or table.name in created:
            continue
        _doris_execute_create(table)
        created.append(table.name)

    return created
|
||
|
||
|
||
# (table, column, column DDL) triples backfilled by
# migrate_keshi_ledger_columns(); they are checked and added in this order.
KESHI_LEDGER_COLUMNS = [
    ('students',            'display_name',              'VARCHAR(50)'),
    ('students',            'nickname',                  'VARCHAR(50)'),
    ('courses',             'course_code',               'VARCHAR(20)'),
    ('student_accounts',    'unit_price',                'DECIMAL(10,4) DEFAULT "0"'),
    ('student_accounts',    'original_price_per_lesson', 'DECIMAL(10,4)'),
    ('student_accounts',    'account_status',            'VARCHAR(20) DEFAULT "active"'),
    ('recharge_records',    'gift_clawed_back',          'SMALLINT DEFAULT "0"'),
    ('refund_records',      'deduct_gifted_amount',      'DECIMAL(12,2) DEFAULT "0"'),
    ('refund_records',      'promo_clawback_hours',      'DECIMAL(10,2) DEFAULT "0"'),
    ('consumption_records', 'unit_price_at_consume',     'DECIMAL(10,4)'),
    ('student_accounts',    'version',                   'INT DEFAULT "0"'),
]
|
||
|
||
|
||
def migrate_keshi_ledger_columns():
    """Add any missing KESHI_LEDGER_COLUMNS with ALTER TABLE ... ADD COLUMN.

    This backfills ledger columns that auto_migrate may have missed (notably
    on Doris). It only runs for MySQL-protocol URIs, which includes Doris.

    Each table's column list is read once with SHOW COLUMNS. A table that
    cannot be read is reported and skipped. Each ALTER is committed on its
    own; a failed ALTER is rolled back and reported, and the loop continues
    with the next column.
    """
    if not is_mysql():
        return

    # table name -> set of existing column names (None = table unreadable).
    columns_by_table = {}

    for table_name, col_name, col_type in KESHI_LEDGER_COLUMNS:
        if table_name not in columns_by_table:
            try:
                rows = db.session.execute(
                    db.text(f'SHOW COLUMNS FROM {quote_ident(table_name)}')
                ).fetchall()
                columns_by_table[table_name] = {r[0] for r in rows}
            except Exception as e:
                db.session.rollback()
                print(f' [migrate_keshi] Skip {table_name}: {e}')
                columns_by_table[table_name] = None

        existing = columns_by_table[table_name]
        if existing is None or col_name in existing:
            continue

        sql = (
            f'ALTER TABLE {quote_ident(table_name)} '
            f'ADD COLUMN {quote_ident(col_name)} {col_type}'
        )
        try:
            db.session.execute(db.text(sql))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f' [migrate_keshi] Failed {table_name}.{col_name}: {e}')
            continue
        existing.add(col_name)
        print(f' [migrate_keshi] Added {table_name}.{col_name}')
|
||
|
||
|
||
def bootstrap_database():
    """Single entry point: create tables for the active backend, then migrate.

    On Doris, the app-assigned-id flush hook is registered and tables are
    created with hand-built DDL. Any other backend (logged as MySQL or SQLite)
    uses ``db.create_all()``. Afterwards ``utils.auto_migrate()`` and
    ``migrate_keshi_ledger_columns()`` are called.
    """
    if not is_doris():
        backend = 'MySQL' if is_mysql() else 'SQLite'
        print(f'[bootstrap] Database: {backend}')
        db.create_all()
    else:
        print('[bootstrap] Database: Doris')
        register_doris_id_listener()
        doris_create_all()

    # Deferred import (presumably to avoid a circular import with utils — verify).
    from utils import auto_migrate

    auto_migrate()
    migrate_keshi_ledger_columns()
|