269 lines
9.1 KiB
Python
269 lines
9.1 KiB
Python
|
|
"""数据库初始化:Doris / MySQL / SQLite 建表与迁移"""
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
|
|||
|
|
from sqlalchemy import event, inspect as sa_inspect
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from models import db
|
|||
|
|
|
|||
|
|
# Becomes True once register_doris_id_listener() has installed its Session hook,
# so repeated calls do not register the hook twice.
_doris_id_listener_registered = False

# Table creation order (logical dependency order; Doris does not enforce foreign keys).
TABLE_ORDER = [
    'roles',
    'users',
    'teachers',
    'students',
    'courses',
    'classes',
    'class_students',
    'student_accounts',
    'recharge_activities',
    'recharge_records',
    'consumption_records',
    'refund_records',
    'transfer_records',
    'stop_records',
    'attendance_records',
    'operation_logs',
    'schedules',
    'monthly_snapshots',
    'class_records',
]

# Doris table layout, overridable through the environment.
# Bucket count used in DISTRIBUTED BY HASH(...) BUCKETS n.
DORIS_BUCKETS = int(os.environ.get('DORIS_BUCKETS', '1'))
# Replica count; kept as a string because it is interpolated into PROPERTIES.
DORIS_REPLICATION_NUM = os.environ.get('DORIS_REPLICATION_NUM', '1')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_uri():
    """Return the current Flask app's ``SQLALCHEMY_DATABASE_URI`` ('' when unset).

    Flask is imported lazily; an active application context is required.
    """
    from flask import current_app

    app_config = current_app.config
    return app_config.get('SQLALCHEMY_DATABASE_URI', '')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_doris():
    """Tell whether the configured database is Apache Doris (MySQL protocol, FE default port 9030).

    Checked in order:
    1. environment variable ``DB_TYPE`` equal to ``doris`` (case-insensitive);
       the URI is not consulted in that case;
    2. the substring ``doris`` anywhere in the database URI (case-insensitive);
    3. the first ``:<digits>/`` segment of the URI being ``:9030/``.
    """
    if os.environ.get('DB_TYPE', '').lower() == 'doris':
        return True

    uri = _get_uri()
    port_match = re.search(r':(\d+)/', uri)
    has_doris_port = port_match is not None and port_match.group(1) == '9030'
    return 'doris' in uri.lower() or has_doris_port
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_mysql():
    """Return True when the database URI starts with ``mysql``.

    Doris is reached through a ``mysql...`` URI too, so this also covers Doris
    configured that way.
    """
    uri = _get_uri()
    return uri.startswith('mysql')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_sqlite():
    """Return True when the database URI starts with ``sqlite``."""
    uri = _get_uri()
    return uri.startswith('sqlite')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def quote_ident(name):
    """Quote *name* as a MySQL/Doris identifier using backticks.

    A backtick inside the name is doubled, which is the MySQL-compatible way to
    embed the quote character. Without this, such a name would terminate the
    quoted identifier early and corrupt (or inject into) the generated SQL.
    """
    escaped = str(name).replace('`', '``')
    return f'`{escaped}`'
|
|||
|
|
|
|||
|
|
|
|||
|
|
def doris_next_id(session, table_name, pk_name='id'):
    """Return ``MAX(pk_name) + 1`` for *table_name* (1 for an empty table).

    Doris has no LAST_INSERT_ID, so primary keys are allocated by the application.
    Nothing here locks the table, so concurrent writers may read the same maximum.
    """
    pk_sql = quote_ident(pk_name)
    table_sql = quote_ident(table_name)
    query = db.text(f'SELECT COALESCE(MAX({pk_sql}), 0) + 1 FROM {table_sql}')
    next_value = session.execute(query).scalar()
    return int(next_value or 1)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def register_doris_id_listener():
    """Install, at most once per process, a ``before_flush`` hook that assigns primary keys on Doris.

    Doris tables are created without AUTO_INCREMENT (see ``_col_type_doris``).
    Before each flush, the hook therefore fills every pending object's empty
    auto-increment primary key from ``doris_next_id`` (``MAX(pk) + 1``). The hook
    does nothing unless ``is_doris()`` is true at flush time.
    """
    global _doris_id_listener_registered
    if _doris_id_listener_registered:
        return

    @event.listens_for(Session, 'before_flush')
    def _assign_doris_ids(session, flush_context, instances):
        if not is_doris():
            return

        # Pass 1: collect objects that still need an id. Also remember the largest
        # id already set by hand on pending objects of the same table/column, so
        # generated ids never collide with them.
        pending = []  # (obj, ORM attribute key, (table name, column name))
        taken = {}    # (table name, column name) -> largest explicit int id in this flush
        for obj in session.new:
            mapper = sa_inspect(obj.__class__)
            if not mapper.tables:
                continue
            for col in mapper.primary_key:
                if not getattr(col, 'autoincrement', False):
                    continue
                # The mapped attribute name may differ from the column key/name.
                attr_key = mapper.get_property_by_column(col).key
                slot = (col.table.name, col.name)
                value = getattr(obj, attr_key)
                if value is None:
                    pending.append((obj, attr_key, slot))
                elif isinstance(value, int):
                    taken[slot] = max(taken.get(slot, value), value)

        # Pass 2: read MAX(pk) + 1 once per table/column, then hand out consecutive
        # ids locally. Querying per object would give every new row of a table the
        # same id, because none of them is inserted yet. The UNIQUE KEY tables built
        # by build_doris_create_sql would then keep only one row for that key.
        next_ids = {}
        for obj, attr_key, slot in pending:
            if slot not in next_ids:
                table_name, col_name = slot
                candidate = doris_next_id(session, table_name, col_name)
                if slot in taken:
                    candidate = max(candidate, taken[slot] + 1)
                next_ids[slot] = candidate
            setattr(obj, attr_key, next_ids[slot])
            next_ids[slot] += 1

    # Only mark as registered once the hook is actually installed.
    _doris_id_listener_registered = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _col_type_doris(col):
|
|||
|
|
"""SQLAlchemy 列类型 -> Doris DDL 类型"""
|
|||
|
|
if col.primary_key and getattr(col, 'autoincrement', False):
|
|||
|
|
# 主键由应用分配(Doris 不支持可靠的 AUTO_INCREMENT / LAST_INSERT_ID)
|
|||
|
|
return 'BIGINT NOT NULL'
|
|||
|
|
|
|||
|
|
type_str = str(col.type).upper()
|
|||
|
|
if 'BIGINT' in type_str:
|
|||
|
|
return 'BIGINT'
|
|||
|
|
if 'SMALLINT' in type_str or 'TINYINT' in type_str:
|
|||
|
|
return 'SMALLINT'
|
|||
|
|
if 'INTEGER' in type_str or type_str == 'INT':
|
|||
|
|
return 'INT'
|
|||
|
|
if 'VARCHAR' in type_str:
|
|||
|
|
return type_str.replace('NVARCHAR', 'VARCHAR')
|
|||
|
|
if 'TEXT' in type_str:
|
|||
|
|
return 'STRING'
|
|||
|
|
if 'DATETIME' in type_str:
|
|||
|
|
return 'DATETIME'
|
|||
|
|
if type_str == 'DATE' or type_str.startswith('DATE('):
|
|||
|
|
return 'DATE'
|
|||
|
|
if 'NUMERIC' in type_str or 'DECIMAL' in type_str:
|
|||
|
|
return type_str.replace('NUMERIC', 'DECIMAL')
|
|||
|
|
if 'FLOAT' in type_str or 'DOUBLE' in type_str:
|
|||
|
|
return type_str
|
|||
|
|
if 'BOOLEAN' in type_str:
|
|||
|
|
return 'TINYINT'
|
|||
|
|
return type_str
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _default_clause_doris(col):
|
|||
|
|
"""生成 Doris 列 DEFAULT 子句"""
|
|||
|
|
if col.primary_key and getattr(col, 'autoincrement', False):
|
|||
|
|
return ''
|
|||
|
|
if col.server_default is not None:
|
|||
|
|
return ''
|
|||
|
|
if col.default is None:
|
|||
|
|
return ''
|
|||
|
|
if not getattr(col.default, 'is_scalar', True):
|
|||
|
|
return ''
|
|||
|
|
val = col.default.arg
|
|||
|
|
if val is None:
|
|||
|
|
return ''
|
|||
|
|
if isinstance(val, bool):
|
|||
|
|
return f' DEFAULT "{int(val)}"'
|
|||
|
|
if isinstance(val, (int, float)):
|
|||
|
|
return f' DEFAULT "{val}"'
|
|||
|
|
if isinstance(val, str):
|
|||
|
|
return f" DEFAULT '{val.replace(chr(39), chr(39) + chr(39))}'"
|
|||
|
|
return ''
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_doris_column(col):
    """Render one Doris column definition: quoted name, type, optional NOT NULL, optional DEFAULT."""
    app_assigned_pk = bool(col.primary_key and getattr(col, 'autoincrement', False))
    pieces = [quote_ident(col.name), _col_type_doris(col)]
    # The app-assigned primary key gets its NOT NULL from _col_type_doris already.
    if not (col.nullable or app_assigned_pk):
        pieces.append('NOT NULL')
    pieces.append(_default_clause_doris(col))
    # Drop empty fragments (e.g. no DEFAULT clause) before joining.
    return ' '.join(filter(None, pieces))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_doris_create_sql(table):
    """Build the Doris ``CREATE TABLE IF NOT EXISTS`` statement for a SQLAlchemy Table.

    The table uses Doris' UNIQUE KEY model on all primary-key columns. When no
    primary key is declared, it keys on the first column instead. Rows are
    hash-distributed on the first key column into ``DORIS_BUCKETS`` buckets, with
    ``replication_num`` set from ``DORIS_REPLICATION_NUM``.

    Doris requires the key columns to be an ordered prefix of the schema. Key
    columns are therefore emitted first, and the remaining columns keep their
    model order. For a table whose single primary key is already its first
    column, the output is unchanged.
    """
    columns = list(table.columns)
    key_cols = [c for c in columns if c.primary_key]
    if key_cols:
        value_cols = [c for c in columns if not c.primary_key]
    else:
        # No declared primary key: fall back to keying on the first column.
        key_cols, value_cols = columns[:1], columns[1:]

    # Joined outside the f-string: a backslash inside an f-string expression is a
    # SyntaxError before Python 3.12.
    col_sql = ',\n '.join(_build_doris_column(c) for c in key_cols + value_cols)
    key_sql = ', '.join(quote_ident(c.name) for c in key_cols)
    dist_sql = quote_ident(key_cols[0].name)
    tname = quote_ident(table.name)

    return (
        f'CREATE TABLE IF NOT EXISTS {tname} (\n'
        f' {col_sql}\n'
        f') UNIQUE KEY({key_sql})\n'
        f'DISTRIBUTED BY HASH({dist_sql}) BUCKETS {DORIS_BUCKETS}\n'
        f'PROPERTIES ("replication_num" = "{DORIS_REPLICATION_NUM}")'
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _table_by_name(name):
    """Return the SQLAlchemy Table named *name* known to the models, or None.

    Mapped classes are searched first, matching on ``__tablename__`` as before.
    Classes without ``__tablename__``/``__table__`` are skipped instead of
    raising AttributeError. Then the shared metadata is searched, so plain
    ``db.Table`` objects (which have no mapper, e.g. association tables) are found too.
    """
    for mapper in db.Model.registry.mappers:
        cls = mapper.class_
        table = getattr(cls, '__table__', None)
        if table is not None and getattr(cls, '__tablename__', None) == name:
            return table
    for table in db.Model.metadata.tables.values():
        if table.name == name:
            return table
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _doris_model_tables():
    """Map table name -> Table for every mapped model plus plain tables in the metadata."""
    tables = {}
    for mapper in db.Model.registry.mappers:
        table = getattr(mapper.class_, '__table__', None)
        if table is not None:
            # setdefault: several mappers (e.g. single-table inheritance) may share a table.
            tables.setdefault(table.name, table)
    # Plain db.Table objects have no mapper; db.create_all() creates them on
    # MySQL/SQLite, so Doris should too.
    for table in db.Model.metadata.tables.values():
        tables.setdefault(table.name, table)
    return tables


def _doris_create_table(table):
    """Execute the Doris CREATE TABLE for *table* and commit; roll back, report and re-raise on failure."""
    sql = build_doris_create_sql(table)
    try:
        db.session.execute(db.text(sql))
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        print(f' [doris_create] Failed {table.name}: {e}')
        raise
    print(f' [doris_create] Created table: {table.name}')


def doris_create_all():
    """Doris: create the model tables that do not exist yet (first start / new tables).

    Tables listed in ``TABLE_ORDER`` are created first, in that order. Any other
    model/metadata table follows, each at most once. The first failure is
    re-raised after rolling back.

    Returns:
        list[str]: names of the tables created by this call.
    """
    inspector = sa_inspect(db.engine)
    existing = set(inspector.get_table_names())
    tables = _doris_model_tables()

    created = []
    for table_name in TABLE_ORDER:
        if table_name in existing:
            continue
        table = tables.get(table_name)
        if table is None:
            print(f' [doris_create] Skip unknown table: {table_name}')
            continue
        _doris_create_table(table)
        created.append(table_name)

    # Model tables not listed in TABLE_ORDER. The dict holds each name once, so a
    # table shared by several mappers is created (and reported) only once.
    ordered = set(TABLE_ORDER)
    for table_name, table in tables.items():
        if table_name in existing or table_name in ordered:
            continue
        _doris_create_table(table)
        created.append(table_name)

    return created
|
|||
|
|
|
|||
|
|
|
|||
|
|
# (table, column, DDL type) triples that migrate_keshi_ledger_columns() adds when
# missing, in this exact order. The types are written in MySQL/Doris DDL.
KESHI_LEDGER_COLUMNS = [
    ('students', 'display_name', 'VARCHAR(50)'),
    ('students', 'nickname', 'VARCHAR(50)'),
    ('courses', 'course_code', 'VARCHAR(20)'),
    ('student_accounts', 'unit_price', 'DECIMAL(10,4) DEFAULT "0"'),
    ('student_accounts', 'original_price_per_lesson', 'DECIMAL(10,4)'),
    ('student_accounts', 'account_status', 'VARCHAR(20) DEFAULT "active"'),
    ('recharge_records', 'gift_clawed_back', 'SMALLINT DEFAULT "0"'),
    ('refund_records', 'deduct_gifted_amount', 'DECIMAL(12,2) DEFAULT "0"'),
    ('refund_records', 'promo_clawback_hours', 'DECIMAL(10,2) DEFAULT "0"'),
    ('consumption_records', 'unit_price_at_consume', 'DECIMAL(10,4)'),
    ('student_accounts', 'version', 'INT DEFAULT "0"'),
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def migrate_keshi_ledger_columns():
    """Add the lesson-ledger columns in ``KESHI_LEDGER_COLUMNS`` that are missing.

    This backfills columns that ``auto_migrate`` may have missed on Doris, using
    ``ALTER TABLE ... ADD COLUMN`` directly. It only runs when ``is_mysql()`` is
    true; otherwise it returns immediately.

    The migration is best effort:

    * a table whose columns cannot be listed is skipped, with a message;
    * a failed ALTER is rolled back and reported.

    Neither case stops the remaining columns.
    """
    if not is_mysql():
        return

    # table name -> set of its column names, or None when SHOW COLUMNS failed.
    # Fetched once per table rather than once per listed column.
    columns_by_table = {}
    for table_name, col_name, col_type in KESHI_LEDGER_COLUMNS:
        if table_name not in columns_by_table:
            try:
                rows = db.session.execute(
                    db.text(f'SHOW COLUMNS FROM {quote_ident(table_name)}')
                ).fetchall()
            except Exception as e:
                # Report instead of skipping silently, and roll back so the
                # session remains usable for the next statements.
                db.session.rollback()
                print(f' [migrate_keshi] Skip {table_name}: {e}')
                columns_by_table[table_name] = None
            else:
                columns_by_table[table_name] = {r[0] for r in rows}

        existing = columns_by_table[table_name]
        if existing is None or col_name in existing:
            continue

        sql = (
            f'ALTER TABLE {quote_ident(table_name)} '
            f'ADD COLUMN {quote_ident(col_name)} {col_type}'
        )
        try:
            db.session.execute(db.text(sql))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f' [migrate_keshi] Failed {table_name}.{col_name}: {e}')
        else:
            existing.add(col_name)
            print(f' [migrate_keshi] Added {table_name}.{col_name}')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def bootstrap_database():
    """Single entry point: create tables, then run the automatic migrations.

    * Doris: registers the primary-key allocation hook, then creates missing
      tables with Doris DDL (``doris_create_all``).
    * Otherwise: ``db.create_all()``.

    Afterwards it always runs ``utils.auto_migrate()`` and then
    ``migrate_keshi_ledger_columns()``.
    """
    if not is_doris():
        backend_label = "MySQL" if is_mysql() else "SQLite"
        print(f'[bootstrap] Database: {backend_label}')
        db.create_all()
    else:
        print('[bootstrap] Database: Doris')
        register_doris_id_listener()
        doris_create_all()

    from utils import auto_migrate

    auto_migrate()
    migrate_keshi_ledger_columns()
|