Files
drl_2/xuexiao/db_bootstrap.py
user9994793890 ee860ce0ae Initial commit
2026-05-29 10:28:07 +08:00

269 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Database bootstrap: table creation and migrations for Doris / MySQL / SQLite."""
import os
import re
from sqlalchemy import event, inspect as sa_inspect
from sqlalchemy.orm import Session
from models import db
# Set once the Doris before_flush id-assignment hook has been installed.
_doris_id_listener_registered = False

# Table creation order, following logical dependencies
# (Doris does not enforce foreign keys, so this is by convention only).
TABLE_ORDER = [
    'roles',
    'users',
    'teachers',
    'students',
    'courses',
    'classes',
    'class_students',
    'student_accounts',
    'recharge_activities',
    'recharge_records',
    'consumption_records',
    'refund_records',
    'transfer_records',
    'stop_records',
    'attendance_records',
    'operation_logs',
    'schedules',
    'monthly_snapshots',
    'class_records',
]

# Doris table properties, overridable through the environment.
DORIS_BUCKETS = int(os.getenv('DORIS_BUCKETS', '1'))
DORIS_REPLICATION_NUM = os.getenv('DORIS_REPLICATION_NUM', '1')
def _get_uri():
    """Return the app's ``SQLALCHEMY_DATABASE_URI`` config value ('' when unset).

    Flask is imported lazily and ``current_app`` is read, so this needs an
    active Flask application context.
    """
    from flask import current_app as app
    settings = app.config
    return settings.get('SQLALCHEMY_DATABASE_URI', '')
def is_doris():
    """Return True when the configured database is Apache Doris.

    Doris speaks the MySQL protocol; its FE query port defaults to 9030.
    Checks, in order:
    1. ``DB_TYPE=doris`` in the environment (case-insensitive).
    2. The substring "doris" anywhere in the database URI (case-insensitive).
    3. The first ``:<digits>/`` segment of the URI being ``:9030/``.
    """
    env_kind = os.environ.get('DB_TYPE', '').lower()
    if env_kind == 'doris':
        return True
    uri = _get_uri()
    mentions_doris = 'doris' in uri.lower()
    if mentions_doris:
        return True
    port_match = re.search(r':(\d+)/', uri)
    if port_match is None:
        return False
    return port_match.group(1) == '9030'
def is_mysql():
    """Return True for MySQL-protocol database URIs (Doris included).

    The check is a plain ``mysql`` prefix test on the configured URI.
    """
    uri = _get_uri()
    return uri.startswith('mysql')
def is_sqlite():
    """Return True when the configured database URI starts with ``sqlite``."""
    uri = _get_uri()
    return uri.startswith('sqlite')
def quote_ident(name):
    """Quote *name* as a MySQL/Doris identifier using backticks.

    Any backtick inside the name is doubled. That is how MySQL escapes the
    quote character inside a quoted identifier, so the result is always one
    well-formed identifier.
    """
    return '`' + str(name).replace('`', '``') + '`'
def doris_next_id(session, table_name, pk_name='id'):
    """Return the next primary-key value for *table_name*, computed as MAX(pk) + 1.

    Doris has no LAST_INSERT_ID, so primary keys are allocated by the
    application. An empty table (MAX is NULL) yields 1.

    NOTE: the read-then-insert pattern is not safe against concurrent
    writers; two sessions can read the same MAX and get the same id.
    """
    column = quote_ident(pk_name)
    table = quote_ident(table_name)
    statement = db.text(f'SELECT COALESCE(MAX({column}), 0) + 1 FROM {table}')
    next_id = session.execute(statement).scalar()
    return int(next_id) if next_id else 1
def register_doris_id_listener():
    """Install a global ``before_flush`` hook that assigns primary keys on Doris.

    Doris tables are created without AUTO_INCREMENT, so every pending (new)
    object whose auto-increment primary key is still ``None`` gets an id
    derived from ``MAX(pk) + 1`` before the INSERT is emitted.

    The hook is attached to the ``Session`` class, so it applies to all
    sessions. It is installed at most once per process and does nothing when
    ``is_doris()`` is false.

    Within one flush, ``MAX(pk) + 1`` is queried once per (table, column) and
    then incremented locally. This is needed because ``before_flush`` runs
    before any INSERT of this flush, so the database MAX does not change while
    ids are being handed out. Without the local counter, every new row of the
    same table would get the same id, and with the UNIQUE KEY table model
    later rows would overwrite earlier ones. Concurrent writers in other
    processes are still not protected against.
    """
    global _doris_id_listener_registered
    if _doris_id_listener_registered:
        return

    @event.listens_for(Session, 'before_flush')
    def _assign_doris_ids(session, flush_context, instances):
        if not is_doris():
            return
        # (table name, pk column name) -> next id to hand out during this flush
        next_ids = {}
        for obj in session.new:
            mapper = sa_inspect(obj.__class__)
            if not mapper.tables:
                continue
            table = mapper.tables[0]
            for col in mapper.primary_key:
                if not getattr(col, 'autoincrement', False):
                    continue
                # The mapped attribute name can differ from the DB column name
                # (e.g. ``pk = Column('id', ...)``); use each one where it applies.
                attr = mapper.get_property_by_column(col).key
                if getattr(obj, attr) is not None:
                    continue
                slot = (table.name, col.name)
                if slot not in next_ids:
                    next_ids[slot] = doris_next_id(session, table.name, col.name)
                setattr(obj, attr, next_ids[slot])
                next_ids[slot] += 1

    # Mark as registered only after the listener was actually attached.
    _doris_id_listener_registered = True
def _col_type_doris(col):
    """Map a SQLAlchemy column's type to a Doris DDL type string.

    Auto-increment primary keys become ``BIGINT NOT NULL``. The application
    assigns these ids itself because Doris has no reliable
    AUTO_INCREMENT / LAST_INSERT_ID.

    Other columns are matched on the upper-cased ``str(col.type)``:
    - Integer widths are normalised; TINYINT is widened to SMALLINT.
    - NVARCHAR becomes VARCHAR, and TEXT variants become STRING.
    - TIMESTAMP becomes DATETIME, since Doris has no TIMESTAMP column type.
    - NUMERIC becomes DECIMAL, and BOOLEAN becomes TINYINT.
    - Anything unrecognised is returned unchanged.
    """
    if col.primary_key and getattr(col, 'autoincrement', False):
        # The primary key is assigned by the application.
        return 'BIGINT NOT NULL'
    type_str = str(col.type).upper()
    if 'BIGINT' in type_str:
        return 'BIGINT'
    if 'SMALLINT' in type_str or 'TINYINT' in type_str:
        return 'SMALLINT'
    if 'INTEGER' in type_str or type_str == 'INT':
        return 'INT'
    if 'VARCHAR' in type_str:
        return type_str.replace('NVARCHAR', 'VARCHAR')
    if 'TEXT' in type_str:
        return 'STRING'
    if 'DATETIME' in type_str or 'TIMESTAMP' in type_str:
        # Doris stores date+time only as DATETIME.
        return 'DATETIME'
    if type_str == 'DATE' or type_str.startswith('DATE('):
        return 'DATE'
    if 'NUMERIC' in type_str or 'DECIMAL' in type_str:
        return type_str.replace('NUMERIC', 'DECIMAL')
    if 'FLOAT' in type_str or 'DOUBLE' in type_str:
        return type_str
    if 'BOOLEAN' in type_str:
        return 'TINYINT'
    return type_str
def _default_clause_doris(col):
    """Build the ``DEFAULT`` clause of a Doris column definition.

    Only Python-side scalar defaults (``Column(default=<literal>)``) are
    rendered. An empty string is returned for:
    - auto-increment primary keys and columns with a ``server_default``;
    - non-scalar (callable / SQL expression) defaults and ``None``;
    - unsupported value types.

    Rendering rules:
    - Booleans become ``"0"`` / ``"1"``.
    - int, float and Decimal are rendered in double quotes.
    - Strings are single-quoted, with backslashes and single quotes escaped
      (MySQL/Doris string literals treat ``\\`` as an escape character).

    Returns:
        str: ``''`` or the clause with a leading space, e.g. ``' DEFAULT "0"'``.
    """
    from decimal import Decimal

    if col.primary_key and getattr(col, 'autoincrement', False):
        return ''
    if col.server_default is not None or col.default is None:
        return ''
    if not getattr(col.default, 'is_scalar', True):
        return ''
    val = col.default.arg
    if val is None:
        return ''
    # bool is checked before int because bool is a subclass of int.
    if isinstance(val, bool):
        return f' DEFAULT "{int(val)}"'
    if isinstance(val, (int, float, Decimal)):
        return f' DEFAULT "{val}"'
    if isinstance(val, str):
        escaped = val.replace('\\', '\\\\').replace("'", "''")
        return f" DEFAULT '{escaped}'"
    return ''
def _build_doris_column(col):
    """Render one column definition for a Doris ``CREATE TABLE``.

    Output shape: ``<quoted name> <type> [NOT NULL] [DEFAULT ...]``.
    For auto-increment primary keys the explicit ``NOT NULL`` is skipped,
    because the type produced for them already carries it.
    """
    app_assigned_pk = col.primary_key and getattr(col, 'autoincrement', False)
    pieces = [quote_ident(col.name), _col_type_doris(col)]
    if not (col.nullable or app_assigned_pk):
        pieces.append('NOT NULL')
    pieces.append(_default_clause_doris(col))
    # Drop empty fragments (e.g. no DEFAULT clause) before joining.
    return ' '.join(filter(None, pieces))
def build_doris_create_sql(table):
    """Generate a Doris ``CREATE TABLE IF NOT EXISTS`` statement for a SQLAlchemy Table.

    The table uses the UNIQUE KEY model keyed on all primary-key columns, or
    on the first column when the table has no primary key. It is
    hash-distributed on the first key column with ``DORIS_BUCKETS`` buckets
    and ``DORIS_REPLICATION_NUM`` replicas.

    Doris requires key columns to be an ordered prefix of the schema, so key
    columns are emitted first and the remaining columns keep model order.
    """
    key_cols = [c for c in table.columns if c.primary_key] or [table.columns[0]]
    key_names = {c.name for c in key_cols}
    ordered = key_cols + [c for c in table.columns if c.name not in key_names]
    # Joined outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    col_defs = ',\n '.join(_build_doris_column(c) for c in ordered)
    key_list = ', '.join(quote_ident(c.name) for c in key_cols)
    dist_col = quote_ident(key_cols[0].name)
    return (
        f'CREATE TABLE IF NOT EXISTS {quote_ident(table.name)} (\n'
        f' {col_defs}\n'
        f') UNIQUE KEY({key_list})\n'
        f'DISTRIBUTED BY HASH({dist_col}) BUCKETS {DORIS_BUCKETS}\n'
        f'PROPERTIES ("replication_num" = "{DORIS_REPLICATION_NUM}")'
    )
def _table_by_name(name):
    """Return the SQLAlchemy ``Table`` named *name*, or ``None`` if it is unknown.

    Mapped model classes are searched first (matching ``__tablename__``).
    As a fallback the models' shared ``MetaData`` is consulted; this also
    finds plain ``db.Table`` objects that have no model class of their own.
    """
    for mapper in db.Model.registry.mappers:
        # getattr: a class mapped via an explicit __table__ need not define __tablename__.
        if getattr(mapper.class_, '__tablename__', None) == name:
            return mapper.class_.__table__
    return db.Model.metadata.tables.get(name)
def doris_create_all():
    """Doris: create every model table that does not exist yet.

    Used on first start-up and whenever new tables are added. Tables named in
    ``TABLE_ORDER`` are created first, in that order. Any other table known
    to the mapped models or to their ``MetaData`` (including plain
    ``db.Table`` objects) is created afterwards.

    Each ``CREATE TABLE`` is committed on its own. A failure is rolled back,
    reported and re-raised.

    Returns:
        list[str]: names of the tables created by this call, in creation order.
    """
    existing = set(sa_inspect(db.engine).get_table_names())
    created = []

    def _create(table):
        # Execute and commit one CREATE TABLE, then record it in ``created``.
        sql = build_doris_create_sql(table)
        try:
            db.session.execute(db.text(sql))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f' [doris_create] Failed {table.name}: {e}')
            raise
        created.append(table.name)
        print(f' [doris_create] Created table: {table.name}')

    for table_name in TABLE_ORDER:
        if table_name in existing:
            continue
        table = _table_by_name(table_name)
        if table is None:
            print(f' [doris_create] Skip unknown table: {table_name}')
            continue
        _create(table)

    # Remaining tables: mapped model tables first (as before), then anything
    # else registered on the models' MetaData, such as plain db.Table objects.
    remaining = [mapper.class_.__table__ for mapper in db.Model.registry.mappers]
    remaining.extend(db.Model.metadata.sorted_tables)
    for table in remaining:
        # ``created`` also de-duplicates tables shared by several mappers
        # (e.g. single-table inheritance) and tables seen in both lists.
        if table.name in existing or table.name in created:
            continue
        _create(table)
    return created
# (table, column, DDL type clause) entries for the lesson-hour ledger columns.
# Applied in this order by migrate_keshi_ledger_columns.
KESHI_LEDGER_COLUMNS = [
    # students
    ('students', 'display_name', 'VARCHAR(50)'),
    ('students', 'nickname', 'VARCHAR(50)'),
    # courses
    ('courses', 'course_code', 'VARCHAR(20)'),
    # student_accounts (pricing / status)
    ('student_accounts', 'unit_price', 'DECIMAL(10,4) DEFAULT "0"'),
    ('student_accounts', 'original_price_per_lesson', 'DECIMAL(10,4)'),
    ('student_accounts', 'account_status', 'VARCHAR(20) DEFAULT "active"'),
    # recharge / refund bookkeeping
    ('recharge_records', 'gift_clawed_back', 'SMALLINT DEFAULT "0"'),
    ('refund_records', 'deduct_gifted_amount', 'DECIMAL(12,2) DEFAULT "0"'),
    ('refund_records', 'promo_clawback_hours', 'DECIMAL(10,2) DEFAULT "0"'),
    # consumption
    ('consumption_records', 'unit_price_at_consume', 'DECIMAL(10,4)'),
    # student_accounts row version
    ('student_accounts', 'version', 'INT DEFAULT "0"'),
]
def migrate_keshi_ledger_columns():
    """Ensure the lesson-hour (keshi) ledger columns exist.

    This backfills columns that ``auto_migrate`` may have missed, notably on
    Doris. It does nothing unless ``is_mysql()`` is true (MySQL or Doris).
    For each ``(table, column, type)`` entry in ``KESHI_LEDGER_COLUMNS``, a
    missing column is added with ``ALTER TABLE ... ADD COLUMN``.

    Each table's columns are read once with ``SHOW COLUMNS``. If that fails
    (for example the table does not exist), the session is rolled back, the
    failure is reported and that table's entries are skipped. A failed
    ``ALTER`` is rolled back and reported, and the remaining entries are
    still attempted.
    """
    if not is_mysql():
        return
    # table name -> set of existing column names, or None if SHOW COLUMNS failed
    columns_by_table = {}
    for table_name, col_name, col_type in KESHI_LEDGER_COLUMNS:
        if table_name not in columns_by_table:
            try:
                rows = db.session.execute(
                    db.text(f'SHOW COLUMNS FROM {quote_ident(table_name)}')
                ).fetchall()
            except Exception as e:
                # Roll back so the session stays usable for later statements.
                db.session.rollback()
                print(f' [migrate_keshi] Skip {table_name}: {e}')
                columns_by_table[table_name] = None
            else:
                columns_by_table[table_name] = {r[0] for r in rows}
        existing = columns_by_table[table_name]
        if existing is None or col_name in existing:
            continue
        sql = (
            f'ALTER TABLE {quote_ident(table_name)} '
            f'ADD COLUMN {quote_ident(col_name)} {col_type}'
        )
        try:
            db.session.execute(db.text(sql))
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            print(f' [migrate_keshi] Failed {table_name}.{col_name}: {e}')
            continue
        # Keep the cached column set in sync with the schema.
        existing.add(col_name)
        print(f' [migrate_keshi] Added {table_name}.{col_name}')
def bootstrap_database():
    """Single entry point: create tables, then run the automatic migrations.

    Doris:
    1. Install the primary-key assignment hook.
    2. Create missing tables with Doris DDL.

    MySQL / SQLite:
    - Use ``db.create_all()``.

    All databases afterwards run ``utils.auto_migrate()`` followed by
    ``migrate_keshi_ledger_columns()``.
    """
    if not is_doris():
        label = 'MySQL' if is_mysql() else 'SQLite'
        print(f'[bootstrap] Database: {label}')
        db.create_all()
    else:
        print('[bootstrap] Database: Doris')
        register_doris_id_listener()
        doris_create_all()
    from utils import auto_migrate
    auto_migrate()
    migrate_keshi_ledger_columns()