diff --git a/Controller/DouYinController.py b/Controller/DouYinController.py index 8109091..29fad06 100644 --- a/Controller/DouYinController.py +++ b/Controller/DouYinController.py @@ -7,357 +7,263 @@ import asyncio from datetime import datetime from typing import List, Optional -from fastapi import APIRouter, HTTPException, BackgroundTasks -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -import pymysql - -# Import custom modules -from Config.Config import OBS_CLOUD_PREFIX, OBS_BUCKET, OBS_TMP_PREFIX, DORIS_HOST, DORIS_PORT, DORIS_USER, DORIS_PWD, DORIS_DATABASE, OBS_SERVER -from Util.DouYinDownloader import DouYinDownloader -from Util.ObsUtil import ObsUploader -from Util.ASRClient import ASRClient -from Util.LlmUtil import get_llm_response - -# Logger setup -logger = logging.getLogger(__name__) - -router = APIRouter() - -# Database connection -def get_db_connection(): - return pymysql.connect( - host=DORIS_HOST, - port=DORIS_PORT, - user=DORIS_USER, - password=DORIS_PWD, - database=DORIS_DATABASE, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor - ) - -class ParseRequest(BaseModel): - text: str - -class SummaryRequest(BaseModel): - ids: List[str] = [] - -def update_status(id, status, error_msg=None): - try: - conn = get_db_connection() - cursor = conn.cursor() - if error_msg: - sql = "UPDATE t_douyin_record SET status=%s, error_msg=%s WHERE id=%s" - cursor.execute(sql, (status, error_msg, id)) - else: - sql = "UPDATE t_douyin_record SET status=%s WHERE id=%s" - cursor.execute(sql, (status, id)) - conn.commit() - conn.close() - except Exception as e: - logger.error(f"DB Error update_status: {e}") - -def update_record(id, title, obs_url, transcript, status): - try: - # Truncate title to 100 chars to fit DB schema (approx 400 bytes max for utf8mb4) - if title and len(title) > 100: - title = title[:100] + "..." - - conn = get_db_connection() - cursor = conn.cursor() - sql = """ - UPDATE t_douyin_record - SET video_name=%s, obs_url=%s, transcript=%s, status=%s - WHERE id=%s - """ - cursor.execute(sql, (title, obs_url, transcript, status, id)) - conn.commit() - conn.close() - except Exception as e: - logger.error(f"DB Error update_record: {e}") - -async def process_video_task(url: str, request_id: str, share_text: str = ""): - logger.info(f"Processing task {request_id}") - - # 1. Update status - await asyncio.to_thread(update_status, request_id, "PROCESSING") - - temp_dir = os.path.abspath(f"temp_{request_id}") - try: - if not os.path.exists(temp_dir): - os.makedirs(temp_dir) - - # 2. Parse & Download - downloader = DouYinDownloader() - # url is passed directly now - if not url: - raise Exception("No valid URL found") - - logger.info(f"Downloading from {url}") - # Run download in thread to avoid blocking main loop - local_video_path, title = await asyncio.to_thread(downloader.download_video, url, temp_dir) - - # Title handling strategy: - # Priority 1: Extracted from share text (if available and valid) - # Priority 2: Extracted from video download (often "Unknown Title") - # Priority 3: Generated by LLM (done later) - - extracted_title = downloader.extract_title_from_text(share_text) - logger.info(f"Extracted title from text: {extracted_title}") - - # If we have a valid extracted title, use it. - # But if we don't have a title yet (or it's Unknown), we definitely want to use extracted_title. - # Even if we have a title from yt-dlp, if it's just "Unknown Title", we prefer extracted one. - if extracted_title and extracted_title != "Unknown Title": - title = extracted_title - elif not title: - title = "Unknown Title" - - if not local_video_path or not os.path.exists(local_video_path): - raise Exception("Download failed") - - # 3. Upload Video to OBS (Long term storage) - logger.info("Uploading video to OBS...") - uploader = ObsUploader() - video_filename = os.path.basename(local_video_path) - obs_video_key = f"{OBS_CLOUD_PREFIX}/DouYin/{video_filename}" - - success, _ = await asyncio.to_thread(uploader.upload_file, obs_video_key, local_video_path, OBS_BUCKET) - if not success: - raise Exception("OBS Upload failed") - - # Construct public URL (Assuming standard OBS pattern or Config logic) - obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{obs_video_key}" - - # 4. Convert to MP3 - logger.info("Converting to MP3...") - mp3_path = os.path.splitext(local_video_path)[0] + ".mp3" - cmd = [ - "ffmpeg", "-y", "-i", local_video_path, - "-acodec", "libmp3lame", "-ar", "16000", "-ac", "1", "-q:a", "2", - mp3_path - ] - # Run ffmpeg in thread - result = await asyncio.to_thread(subprocess.run, cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) - if result.returncode != 0: - raise Exception(f"FFmpeg failed: {result.stderr.decode()}") - - # 5. ASR (Upload MP3 to tmp and transcribe) - logger.info("Transcribing...") - asr = ASRClient() - # Run ASR in thread - transcript = await asyncio.to_thread(asr.upload_and_transcribe_sync, mp3_path) - - if not transcript: - raise Exception("Transcription failed (returned empty)") - - # 6. LLM Title Generation (Enhancement) - # If the title is still Unknown or weak, OR if we just want to ensure we have a good title. - # The user said: "Alternatively, call LlmUtil.py to summarize title". - # Let's do it if title is Unknown or matches default filename pattern, OR if extracted title was also missing. - if (not title or title == "Unknown Title" or title == "Unknown"): - try: - logger.info("Generating title from transcript via LLM...") - prompt = f"请根据以下视频文案总结一个简短的标题(20字以内),不要包含任何解释性文字,直接返回标题:\n\n{transcript[:1000]}" - - llm_title_chunks = [] - # get_llm_response is already async - async for chunk in get_llm_response(prompt, stream=False): - llm_title_chunks.append(chunk) - llm_title = "".join(llm_title_chunks) - - if llm_title: - # Clean up quotes if any - llm_title = llm_title.strip().strip('"').strip('“').strip('”') - logger.info(f"LLM generated title: {llm_title}") - # We overwrite the title if LLM succeeds - title = llm_title - except Exception as llm_e: - logger.warning(f"LLM Title generation failed: {llm_e}") - - # 7. Save to DB (Update) - logger.info("Saving to DB...") - await asyncio.to_thread(update_record, request_id, title, obs_url, transcript, "COMPLETED") - logger.info(f"Task {request_id} completed successfully.") - - except Exception as e: - logger.error(f"Task {request_id} failed: {e}", exc_info=True) - await asyncio.to_thread(update_status, request_id, "FAILED", str(e)) - finally: - # 8. Cleanup - if os.path.exists(temp_dir): - try: - # shutil.rmtree is sync, wrap it - await asyncio.to_thread(shutil.rmtree, temp_dir, ignore_errors=True) - except Exception as e: - logger.error(f"Cleanup failed: {e}") - -@router.post("/api/parse") -def parse(request: ParseRequest, background_tasks: BackgroundTasks): - downloader = DouYinDownloader() - urls = downloader.extract_urls(request.text) - - if not urls: - # If no URLs found, try using the text as is (might be a direct link not caught by regex) - # But regex is quite broad. Let's just fail or try one. - # Let's assume text might be the URL if it's clean. - if request.text.startswith("http"): - urls = [request.text] - else: - raise HTTPException(status_code=400, detail="No valid URLs found") - - created_ids = [] - try: - conn = get_db_connection() - cursor = conn.cursor() - - for url in urls: - req_id = str(uuid.uuid4()) - sql = """ - INSERT INTO t_douyin_record (id, original_text, status, create_time) - VALUES (%s, %s, 'PENDING', %s) - """ - cursor.execute(sql, (req_id, url, datetime.now())) - created_ids.append(req_id) - # Pass request.text (the full share text) so we can extract title from it - background_tasks.add_task(process_video_task, url, req_id, request.text) - - conn.commit() - conn.close() - except Exception as e: - raise HTTPException(status_code=500, detail=f"DB Init Error: {e}") - - return {"id": created_ids[0] if created_ids else None, "ids": created_ids, "status": "PENDING"} - -@router.get("/api/records") -def get_records(): - try: - conn = get_db_connection() - cursor = conn.cursor() - cursor.execute("SELECT * FROM t_douyin_record ORDER BY create_time DESC LIMIT 50") - records = cursor.fetchall() - conn.close() - - # Manually handle datetime serialization to be safe - for r in records: - if 'create_time' in r and r['create_time']: - r['create_time'] = r['create_time'].strftime("%Y-%m-%d %H:%M:%S") - if 'update_time' in r and r['update_time']: - r['update_time'] = r['update_time'].strftime("%Y-%m-%d %H:%M:%S") - - return records - except Exception as e: - logger.error(f"Get records error: {e}", exc_info=True) - return [] - -@router.delete("/api/records/{id}") -def delete_record(id: str): - try: - conn = get_db_connection() - cursor = conn.cursor() - cursor.execute("DELETE FROM t_douyin_record WHERE id=%s", (id,)) - conn.commit() - conn.close() - return {"status": "deleted"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@router.post("/api/douyin/summary") -async def generate_summary(request: SummaryRequest): - try: - # Fetch transcripts - conn = get_db_connection() - cursor = conn.cursor() - - if request.ids: - # Secure way to handle list in SQL - format_strings = ','.join(['%s'] * len(request.ids)) - sql = f"SELECT video_name, transcript FROM t_douyin_record WHERE id IN ({format_strings}) AND status='COMPLETED'" - cursor.execute(sql, tuple(request.ids)) - else: - # Default to latest 20 - cursor.execute("SELECT video_name, transcript FROM t_douyin_record WHERE status='COMPLETED' ORDER BY create_time DESC LIMIT 20") - - records = cursor.fetchall() - conn.close() - - if not records: - # If no records, just return a simple message stream - async def empty_stream(): - yield "未找到可总结的已完成记录,请先解析视频。" - return StreamingResponse(empty_stream(), media_type="text/event-stream") - - # Prepare text - full_text = "" - for r in records: - if r['transcript']: - full_text += f"【标题:{r['video_name']}】\n内容:{r['transcript']}\n\n" - - if not full_text: - async def empty_text_stream(): - yield "记录中没有有效的文案内容。" - return StreamingResponse(empty_text_stream(), media_type="text/event-stream") - - # Prompt - prompt = f""" - 请对以下充电行业相关的视频内容进行知识精华提取。 - 要求: - 1. 忽略无关闲聊和口语化表达; - 2. 按条目列出核心知识点,不要长篇大论; - 3. 保持简洁专业,只保留干货; - 4. 返回格式为Markdown列表。 - - 内容如下: - {full_text[:15000]} - """ - - # Limit context to avoid errors, 15000 chars is roughly safe for most models, - # but if using a small model, might need less. Assuming robust model. - - return StreamingResponse(get_llm_response(prompt), media_type="text/event-stream") - - except Exception as e: - logger.error(f"Summary generation failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -async def recover_pending_tasks(): - """ - Check for tasks stuck in PENDING or PROCESSING state (due to server restart) - and restart them. - """ - logger.info("Scanning for interrupted Douyin tasks...") - try: - # Use asyncio.to_thread for DB operation - def fetch_pending(): - conn = get_db_connection() - cursor = conn.cursor() - # Select recent pending/processing tasks (limit 20 to avoid storm) - sql = """ - SELECT id, original_text, status - FROM t_douyin_record - WHERE status IN ('PENDING', 'PROCESSING') - ORDER BY create_time DESC LIMIT 20 - """ - cursor.execute(sql) - tasks = cursor.fetchall() - conn.close() - return tasks - - tasks = await asyncio.to_thread(fetch_pending) - - if not tasks: - logger.info("No interrupted tasks found.") - return - - logger.info(f"Found {len(tasks)} interrupted tasks. Restarting...") - for task in tasks: - req_id = task['id'] - url = task['original_text'] - # Restart task in background - # Note: We lost the original share text for title extraction, - # so we pass empty string. It will use the URL or 'Unknown Title'. - # If LLM is enabled, it might fix the title later. - asyncio.create_task(process_video_task(url, req_id, share_text="")) - - except Exception as e: +from fastapi import APIRouter, HTTPException, BackgroundTasks +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +# Import custom modules +from Config.Config import OBS_CLOUD_PREFIX, OBS_BUCKET, OBS_TMP_PREFIX, OBS_SERVER +from Util.DouYinDownloader import DouYinDownloader +from Util.ObsUtil import ObsUploader +from Util.ASRClient import ASRClient +from Util.LlmUtil import get_llm_response +from Model.DouYinModel import DouYinModel + +# Logger setup +logger = logging.getLogger(__name__) + +router = APIRouter() + +class ParseRequest(BaseModel): + text: str + +class SummaryRequest(BaseModel): + ids: List[str] = [] + +async def process_video_task(url: str, request_id: str, share_text: str = ""): + logger.info(f"Processing task {request_id}") + + model = DouYinModel() + # 1. Update status + await model.update_status(request_id, "PROCESSING") + + temp_dir = os.path.abspath(f"temp_{request_id}") + try: + if not os.path.exists(temp_dir): + os.makedirs(temp_dir) + + # 2. Parse & Download + downloader = DouYinDownloader() + # url is passed directly now + if not url: + raise Exception("No valid URL found") + + logger.info(f"Downloading from {url}") + # Run download in thread to avoid blocking main loop + local_video_path, title = await asyncio.to_thread(downloader.download_video, url, temp_dir) + + # Title handling strategy: + # Priority 1: Extracted from share text (if available and valid) + # Priority 2: Extracted from video download (often "Unknown Title") + # Priority 3: Generated by LLM (done later) + + extracted_title = downloader.extract_title_from_text(share_text) + logger.info(f"Extracted title from text: {extracted_title}") + + # If we have a valid extracted title, use it. + # But if we don't have a title yet (or it's Unknown), we definitely want to use extracted_title. + # Even if we have a title from yt-dlp, if it's just "Unknown Title", we prefer extracted one. + if extracted_title and extracted_title != "Unknown Title": + title = extracted_title + elif not title: + title = "Unknown Title" + + if not local_video_path or not os.path.exists(local_video_path): + raise Exception("Download failed") + + # 3. Upload Video to OBS (Long term storage) + logger.info("Uploading video to OBS...") + uploader = ObsUploader() + video_filename = os.path.basename(local_video_path) + obs_video_key = f"{OBS_CLOUD_PREFIX}/DouYin/{video_filename}" + + success, _ = await asyncio.to_thread(uploader.upload_file, obs_video_key, local_video_path, OBS_BUCKET) + if not success: + raise Exception("OBS Upload failed") + + # Construct public URL (Assuming standard OBS pattern or Config logic) + obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{obs_video_key}" + + # 4. Convert to MP3 + logger.info("Converting to MP3...") + mp3_path = os.path.splitext(local_video_path)[0] + ".mp3" + cmd = [ + "ffmpeg", "-y", "-i", local_video_path, + "-acodec", "libmp3lame", "-ar", "16000", "-ac", "1", "-q:a", "2", + mp3_path + ] + # Run ffmpeg in thread + result = await asyncio.to_thread(subprocess.run, cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + if result.returncode != 0: + raise Exception(f"FFmpeg failed: {result.stderr.decode()}") + + # 5. ASR (Upload MP3 to tmp and transcribe) + logger.info("Transcribing...") + asr = ASRClient() + # Run ASR in thread + transcript = await asyncio.to_thread(asr.upload_and_transcribe_sync, mp3_path) + + if not transcript: + raise Exception("Transcription failed (returned empty)") + + # 6. LLM Title Generation (Enhancement) + # If the title is still Unknown or weak, OR if we just want to ensure we have a good title. + # The user said: "Alternatively, call LlmUtil.py to summarize title". + # Let's do it if title is Unknown or matches default filename pattern, OR if extracted title was also missing. + if (not title or title == "Unknown Title" or title == "Unknown"): + try: + logger.info("Generating title from transcript via LLM...") + prompt = f"请根据以下视频文案总结一个简短的标题(20字以内),不要包含任何解释性文字,直接返回标题:\n\n{transcript[:1000]}" + + llm_title_chunks = [] + # get_llm_response is already async + async for chunk in get_llm_response(prompt, stream=False): + llm_title_chunks.append(chunk) + llm_title = "".join(llm_title_chunks) + + if llm_title: + # Clean up quotes if any + llm_title = llm_title.strip().strip('"').strip('“').strip('”') + logger.info(f"LLM generated title: {llm_title}") + # We overwrite the title if LLM succeeds + title = llm_title + except Exception as llm_e: + logger.warning(f"LLM Title generation failed: {llm_e}") + + # 7. Save to DB (Update) + logger.info("Saving to DB...") + await model.update_record(request_id, title, obs_url, transcript, "COMPLETED") + logger.info(f"Task {request_id} completed successfully.") + + except Exception as e: + logger.error(f"Task {request_id} failed: {e}", exc_info=True) + await model.update_status(request_id, "FAILED", str(e)) + finally: + # 8. Cleanup + if os.path.exists(temp_dir): + try: + # shutil.rmtree is sync, wrap it + await asyncio.to_thread(shutil.rmtree, temp_dir, ignore_errors=True) + except Exception as e: + logger.error(f"Cleanup failed: {e}") + +@router.post("/api/parse") +async def parse(request: ParseRequest, background_tasks: BackgroundTasks): + downloader = DouYinDownloader() + urls = downloader.extract_urls(request.text) + + if not urls: + # If no URLs found, try using the text as is (might be a direct link not caught by regex) + # But regex is quite broad. Let's just fail or try one. + # Let's assume text might be the URL if it's clean. + if request.text.startswith("http"): + urls = [request.text] + else: + raise HTTPException(status_code=400, detail="No valid URLs found") + + created_ids = [] + try: + model = DouYinModel() + for url in urls: + req_id = str(uuid.uuid4()) + await model.insert_record(req_id, url) + created_ids.append(req_id) + # Pass request.text (the full share text) so we can extract title from it + background_tasks.add_task(process_video_task, url, req_id, request.text) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"DB Init Error: {e}") + + return {"id": created_ids[0] if created_ids else None, "ids": created_ids, "status": "PENDING"} + +@router.get("/api/records") +async def get_records(): + try: + model = DouYinModel() + records = await model.get_records() + return records + except Exception as e: + logger.error(f"Get records error: {e}", exc_info=True) + return [] + +@router.delete("/api/records/{id}") +async def delete_record(id: str): + try: + model = DouYinModel() + await model.delete_record(id) + return {"status": "deleted"} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post("/api/douyin/summary") +async def generate_summary(request: SummaryRequest): + try: + # Fetch transcripts + model = DouYinModel() + records = await model.get_transcripts(ids=request.ids) + + if not records: + # If no records, just return a simple message stream + async def empty_stream(): + yield "未找到可总结的已完成记录,请先解析视频。" + return StreamingResponse(empty_stream(), media_type="text/event-stream") + + # Prepare text + full_text = "" + for r in records: + if r['transcript']: + full_text += f"【标题:{r['video_name']}】\n内容:{r['transcript']}\n\n" + + if not full_text: + async def empty_text_stream(): + yield "记录中没有有效的文案内容。" + return StreamingResponse(empty_text_stream(), media_type="text/event-stream") + + # Prompt + prompt = f""" + 请对以下充电行业相关的视频内容进行知识精华提取。 + 要求: + 1. 忽略无关闲聊和口语化表达; + 2. 按条目列出核心知识点,不要长篇大论; + 3. 保持简洁专业,只保留干货; + 4. 返回格式为Markdown列表。 + + 内容如下: + {full_text[:15000]} + """ + + # Limit context to avoid errors, 15000 chars is roughly safe for most models, + # but if using a small model, might need less. Assuming robust model. + + return StreamingResponse(get_llm_response(prompt), media_type="text/event-stream") + + except Exception as e: + logger.error(f"Summary generation failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +async def recover_pending_tasks(): + """ + Check for tasks stuck in PENDING or PROCESSING state (due to server restart) + and restart them. + """ + logger.info("Scanning for interrupted Douyin tasks...") + try: + model = DouYinModel() + tasks = await model.get_interrupted_tasks() + + if not tasks: + logger.info("No interrupted tasks found.") + return + + logger.info(f"Found {len(tasks)} interrupted tasks. Restarting...") + for task in tasks: + req_id = task['id'] + url = task['original_text'] + # Restart task in background + # Note: We lost the original share text for title extraction, + # so we pass empty string. It will use the URL or 'Unknown Title'. + # If LLM is enabled, it might fix the title later. + asyncio.create_task(process_video_task(url, req_id, share_text="")) + + except Exception as e: logger.error(f"Failed to recover tasks: {e}", exc_info=True) diff --git a/Controller/HaiBaoController.py b/Controller/HaiBaoController.py index 95798f4..96d93a0 100644 --- a/Controller/HaiBaoController.py +++ b/Controller/HaiBaoController.py @@ -9,7 +9,7 @@ from sqlalchemy.sql import text from Util.BananaClient import BananaClient from Util.LlmUtil import get_llm_response -from DbKit.Db import Db +from Model.HaiBaoModel import HaiBaoModel router = APIRouter(prefix="/haibao") logger = logging.getLogger("HaiBaoController") @@ -21,43 +21,8 @@ class GenerateRequest(BaseModel): @router.on_event("startup") async def startup_event(): - """初始化时检查并创建表""" - db = Db() - await db.init_db() - - # Doris 建表语句 - create_table_sql = """ - CREATE TABLE IF NOT EXISTS haibao_history ( - id VARCHAR(50) COMMENT "ID", - prompt TEXT COMMENT "提示词", - image_url VARCHAR(500) COMMENT "图片URL", - scheme_content TEXT COMMENT "文案方案", - created_at DATETIME COMMENT "创建时间" - ) - DUPLICATE KEY(id) - DISTRIBUTED BY HASH(id) BUCKETS 1 - PROPERTIES ( - "replication_num" = "1" - ); - """ - try: - # 使用 engine 直接执行 DDL - async with db.engine.begin() as conn: - await conn.execute(text(create_table_sql)) - - # 尝试添加列(如果表已存在但列不存在) - # Doris 不支持 IF NOT EXISTS for ADD COLUMN directly nicely in all versions without error if exists - # 所以这里简单捕获异常,如果列已存在则忽略 - try: - alter_sql = "ALTER TABLE haibao_history ADD COLUMN scheme_content TEXT COMMENT '文案方案'" - await conn.execute(text(alter_sql)) - except Exception as e: - # 忽略列已存在的错误 - pass - - logger.info("海报历史表检查/更新成功") - except Exception as e: - logger.error(f"海报历史表创建/更新失败: {e}") + """初始化时逻辑""" + logger.info("海报控制器启动成功") class RefineRequest(BaseModel): prompt: str @@ -184,26 +149,11 @@ async def generate_poster(req: GenerateRequest): image_url, scheme_content = await asyncio.gather(generate_image_task(), generate_text_task()) # 4. 保存到数据库 - db = Db() + model = HaiBaoModel() record_id = str(uuid.uuid4()) created_at = datetime.now() - insert_sql = """ - INSERT INTO haibao_history (id, prompt, image_url, scheme_content, created_at) - VALUES (:id, :prompt, :image_url, :scheme_content, :created_at) - """ - - params = { - "id": record_id, - "prompt": req.prompt, - "image_url": image_url, - "scheme_content": scheme_content, - "created_at": created_at - } - - async with await db.get_session() as session: - async with session.begin(): - await session.execute(text(insert_sql), params) + await model.insert_record(record_id, req.prompt, image_url, scheme_content, created_at) return { "id": record_id, @@ -220,10 +170,9 @@ async def generate_poster(req: GenerateRequest): @router.get("/history") async def get_history(): """获取海报生成历史""" - db = Db() - sql = "SELECT * FROM haibao_history ORDER BY created_at DESC LIMIT 50" try: - result = await db.find(sql) + model = HaiBaoModel() + result = await model.get_history(50) formatted_result = [] for item in result: diff --git a/Controller/YltAnalyticsController.py b/Controller/YltAnalyticsController.py index a72f853..ff15a88 100644 --- a/Controller/YltAnalyticsController.py +++ b/Controller/YltAnalyticsController.py @@ -6,7 +6,6 @@ from typing import List, Optional, Dict, Any from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse, FileResponse, JSONResponse -from DbKit.Db import Db from Config.Config import DB_URL from Util.LlmUtil import get_llm_response from Tools.T6_Export import export_excel, DorisExcelExporter, extract_hourly_prices_from_schedule @@ -24,31 +23,35 @@ from Model.YltAnalyticsModel import ( PriceSeries, PriceComparisonResponse, PriceComparisonSummary, + YltAnalyticsModel, ) +from DbKit.Db import Db + router = APIRouter() -db = Db(db_url=DB_URL) +# db = Db(db_url=DB_URL) # Removed direct db instance async def init_db(): + db = Db() await db.init_db() async def close_db(): + db = Db() await db.close() @router.get("/api/operators/hourly-prices") async def get_operators_hourly_prices(): operators = ["新电途", "特来电", "驿来特", "艾特吉易充"] - exporter = DorisExcelExporter(db_url=DB_URL) - await exporter.init() + model = YltAnalyticsModel() try: result = [] for op in operators: - rows = await exporter.fetch_current_station_rows(op) + rows = await model.fetch_current_station_rows(op) if not rows: result.append({"operator": op, "series": [None] * 24}) continue @@ -71,11 +74,72 @@ async def get_operators_hourly_prices(): else: avg_series.append(None) result.append({"operator": op, "series": avg_series}) - finally: - await exporter.close() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) return {"operators": result} +@router.get("/api/operators/price-trends") +async def get_operators_price_trends(days: int = 7): + operators = ["新电途", "特来电", "驿来特", "艾特吉易充"] + + model = YltAnalyticsModel() + rows = await model.get_operators_price_trends(days) + + # 数据结构: { operator: { date_str: [sums_of_24h, counts_of_24h] } } + trend_data = {} + for op in operators: + trend_data[op] = {} + + for row in rows: + op = row.get("operator") + if op not in trend_data: + continue + d_str = str(row.get("date_str")) + schedule_json = row.get("schedule_json") + + if d_str not in trend_data[op]: + trend_data[op][d_str] = {"sums": [0.0] * 24, "counts": [0] * 24} + + series = extract_hourly_prices_from_schedule(schedule_json) + for i in range(24): + v = series[i] + if v is not None: + trend_data[op][d_str]["sums"][i] += float(v) + trend_data[op][d_str]["counts"][i] += 1 + + # 转换为 ECharts 友好格式 + # 1. 获取所有日期并排序 + all_dates = sorted(list(set(str(row.get("date_str")) for row in rows))) + + # 2. 为每个运营商计算每天的平均价格(24小时的平均值) + series_result = [] + for op in operators: + op_trend = [] + for d in all_dates: + if d in trend_data[op]: + day_stats = trend_data[op][d] + day_avg_sum = 0.0 + day_hour_count = 0 + for i in range(24): + if day_stats["counts"][i] > 0: + day_avg_sum += (day_stats["sums"][i] / day_stats["counts"][i]) + day_hour_count += 1 + + if day_hour_count > 0: + op_trend.append(round(day_avg_sum / day_hour_count, 4)) + else: + op_trend.append(None) + else: + op_trend.append(None) + series_result.append({"name": op, "data": op_trend}) + + return { + "dates": all_dates, + "series": series_result + } + + @router.get("/api/export/prices-zip") async def export_prices_zip(): operators = ["新电途", "特来电", "驿来特", "艾特吉易充"] @@ -177,26 +241,8 @@ async def ai_pricing_strategy_summary(): @router.get("/api/ylt/stations", response_model=List[StationBase]) async def list_ylt_stations(q: Optional[str] = None): - base_sql = """ - SELECT - p.station_hash, - p.operator, - p.station_name, - p.address, - p.coord_x, - p.coord_y, - s.current_price - FROM t_station_profile_scd p - LEFT JOIN t_station_status_scd s - ON p.station_hash = s.station_hash AND s.is_current = 1 - WHERE p.operator = '驿来特' - """ - params: Dict[str, Any] = {} - if q: - base_sql += " AND (p.station_name LIKE :kw OR p.address LIKE :kw)" - params["kw"] = f"%{q}%" - base_sql += " ORDER BY p.station_name" - rows = await db.find(base_sql, params) + model = YltAnalyticsModel() + rows = await model.list_ylt_stations(q) result: List[StationBase] = [] for r in rows: result.append( @@ -225,21 +271,8 @@ def haversine_km(lon1: float, lat1: float, lon2: float, lat2: float) -> float: async def fetch_current_stations() -> List[dict]: - sql = """ - SELECT - p.station_hash, - p.operator, - p.station_name, - p.address, - p.coord_x, - p.coord_y, - s.current_price - FROM t_station_profile_scd p - LEFT JOIN t_station_status_scd s - ON p.station_hash = s.station_hash AND s.is_current = 1 - WHERE p.is_current = 1 - """ - return await db.find(sql, {}) + model = YltAnalyticsModel() + return await model.fetch_current_stations() async def build_geo_competition(station_hash: str, radius_km: float = 3.0) -> GeoCompetitionResponse: @@ -318,17 +351,8 @@ async def build_geo_competition(station_hash: str, radius_km: float = 3.0) -> Ge async def fetch_station_schedule_json(station_hash: str) -> Optional[str]: - sql = """ - SELECT schedule_json - FROM t_station_price_schedule_scd - WHERE station_hash = :h AND is_current = 1 - ORDER BY valid_start_time DESC - LIMIT 1 - """ - rows = await db.find(sql, {"h": station_hash}) - if not rows: - return None - value = rows[0].get("schedule_json") + model = YltAnalyticsModel() + value = await model.fetch_station_schedule_json(station_hash) if value is None: return None if isinstance(value, str): diff --git a/Controller/__pycache__/DouYinController.cpython-310.pyc b/Controller/__pycache__/DouYinController.cpython-310.pyc index 79aa9ea..7458bbe 100644 Binary files a/Controller/__pycache__/DouYinController.cpython-310.pyc and b/Controller/__pycache__/DouYinController.cpython-310.pyc differ diff --git a/Controller/__pycache__/HaiBaoController.cpython-310.pyc b/Controller/__pycache__/HaiBaoController.cpython-310.pyc index f5d1386..461ece4 100644 Binary files a/Controller/__pycache__/HaiBaoController.cpython-310.pyc and b/Controller/__pycache__/HaiBaoController.cpython-310.pyc differ diff --git a/Controller/__pycache__/YltAnalyticsController.cpython-310.pyc b/Controller/__pycache__/YltAnalyticsController.cpython-310.pyc index 4d27e2a..bcfba8a 100644 Binary files a/Controller/__pycache__/YltAnalyticsController.cpython-310.pyc and b/Controller/__pycache__/YltAnalyticsController.cpython-310.pyc differ diff --git a/DbKit/Db.py b/DbKit/Db.py index 9ea999e..b342a08 100644 --- a/DbKit/Db.py +++ b/DbKit/Db.py @@ -444,7 +444,15 @@ class Db: pass # 检查是否为SQL模板名称并处理 - if self._is_sql_template(sql): + # 即使不是模板名称,如果包含点号,我们也尝试在加载后检查它是否真的是模板 + is_template = self._is_sql_template(sql) + if not is_template and '.' in sql: + # 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射 + if hasattr(self.sql_loader, '_build_template_map'): + self.sql_loader._build_template_map() + is_template = self._is_sql_template(sql) + + if is_template: try: logger.debug(f"确认为SQL模板: {sql}") @@ -620,7 +628,14 @@ class Db: logger.debug(f"处理后的页码: {page_number}, 每页大小: {page_size}") # 使用_is_sql_template方法检查并获取SQL模板 - if self._is_sql_template(sql): + is_template = self._is_sql_template(sql) + if not is_template and '.' in sql: + # 如果包含点号但没在映射中找到,可能是因为缓存加载问题,尝试重新加载映射 + if hasattr(self.sql_loader, '_build_template_map'): + self.sql_loader._build_template_map() + is_template = self._is_sql_template(sql) + + if is_template: try: logger.debug(f"确认为SQL模板: {sql}") @@ -840,6 +855,27 @@ class Db: # 如果所有策略都失败,返回默认查询 return "SELECT COUNT(*)" + async def check_column_exists(self, table_name, column_name): + """检查表中是否存在指定列 + + Args: + table_name: 表名 + column_name: 列名 + + Returns: + bool: 如果列存在返回True,否则返回False + """ + try: + # 兼容 Doris 和 MySQL 的语法 + sql = f"SHOW COLUMNS FROM {table_name} LIKE :column_name" + params = {"column_name": column_name} + # 使用 find 方法执行查询 + result = await self.find(sql, params) + return len(result) > 0 + except Exception as e: + logger.debug(f"检查列是否存在时出错 (可能表不存在): {str(e)}") + return False + @db_retry() async def execute_update(self, sql, params=None, session=None): """执行SQL更新操作(插入、更新、删除)(异步版本) diff --git a/DbKit/Sql/DouYin.sql b/DbKit/Sql/DouYin.sql new file mode 100644 index 0000000..4512720 --- /dev/null +++ b/DbKit/Sql/DouYin.sql @@ -0,0 +1,57 @@ +#namespace("DouYin") + #sql("updateStatus") + UPDATE t_douyin_record + SET status = #para(status) + #if(error_msg) + , error_msg = #para(error_msg) + #end + WHERE id = #para(id) + #end + + #sql("updateRecord") + UPDATE t_douyin_record + SET video_name = #para(title), + obs_url = #para(obs_url), + transcript = #para(transcript), + status = #para(status) + WHERE id = #para(id) + #end + + #sql("insertRecord") + INSERT INTO t_douyin_record (id, original_text, status, create_time) + VALUES (#para(id), #para(url), 'PENDING', #para(create_time)) + #end + + #sql("getRecords") + SELECT * FROM t_douyin_record + ORDER BY create_time DESC + LIMIT #para(limit) + #end + + #sql("deleteRecord") + DELETE FROM t_douyin_record WHERE id = #para(id) + #end + + #sql("getTranscriptsByIds") + SELECT video_name, transcript + FROM t_douyin_record + WHERE id IN (#para(ids)) + AND status = 'COMPLETED' + #end + + #sql("getLatestTranscripts") + SELECT video_name, transcript + FROM t_douyin_record + WHERE status = 'COMPLETED' + ORDER BY create_time DESC + LIMIT #para(limit) + #end + + #sql("getInterruptedTasks") + SELECT id, original_text, status + FROM t_douyin_record + WHERE status IN ('PENDING', 'PROCESSING') + ORDER BY create_time DESC + LIMIT #para(limit) + #end +#end diff --git a/DbKit/Sql/HaiBao.sql b/DbKit/Sql/HaiBao.sql new file mode 100644 index 0000000..6af5add --- /dev/null +++ b/DbKit/Sql/HaiBao.sql @@ -0,0 +1,12 @@ +#namespace("HaiBao") + #sql("insertHistory") + INSERT INTO haibao_history (id, prompt, image_url, scheme_content, created_at) + VALUES (#para(id), #para(prompt), #para(image_url), #para(scheme_content), #para(created_at)) + #end + + #sql("getHistory") + SELECT * FROM haibao_history + ORDER BY created_at DESC + LIMIT #para(limit) + #end +#end diff --git a/DbKit/Sql/YltAnalytics.sql b/DbKit/Sql/YltAnalytics.sql new file mode 100644 index 0000000..2dc7cec --- /dev/null +++ b/DbKit/Sql/YltAnalytics.sql @@ -0,0 +1,79 @@ +#namespace("YltAnalytics") + #sql("getOperatorsPriceTrends") + SELECT + p.operator, + DATE(sc.valid_start_time) as date_str, + sc.schedule_json + FROM t_station_price_schedule_scd sc + JOIN t_station_profile_scd p ON sc.station_hash = p.station_hash AND p.is_current = 1 + WHERE sc.valid_start_time >= DATE_SUB(CURDATE(), INTERVAL #para(days) DAY) + ORDER BY date_str ASC + #end + + #sql("listYltStations") + SELECT + p.station_hash, + p.operator, + p.station_name, + p.address, + p.coord_x, + p.coord_y, + s.current_price + FROM t_station_profile_scd p + LEFT JOIN t_station_status_scd s + ON p.station_hash = s.station_hash AND s.is_current = 1 + WHERE p.operator = '驿来特' + #if(q) + AND (p.station_name LIKE #para(kw) OR p.address LIKE #para(kw)) + #end + ORDER BY p.station_name + #end + + #sql("fetchCurrentStations") + SELECT + p.station_hash, + p.operator, + p.station_name, + p.address, + p.coord_x, + p.coord_y, + s.current_price + FROM t_station_profile_scd p + LEFT JOIN t_station_status_scd s + ON p.station_hash = s.station_hash AND s.is_current = 1 + WHERE p.is_current = 1 + #end + + #sql("fetchStationScheduleJson") + SELECT sc.schedule_json, p.operator + FROM t_station_price_schedule_scd sc + JOIN t_station_profile_scd p ON sc.station_hash = p.station_hash AND p.is_current = 1 + WHERE sc.station_hash = #para(h) AND sc.is_current = 1 + ORDER BY sc.valid_start_time DESC + LIMIT 1 + #end + + #sql("fetchCurrentStationRows") + SELECT + p.station_hash, + p.station_name, + p.address, + p.coord_x, + p.coord_y, + s.total_piles AS total_guns, + s.free_piles AS free_guns, + s.current_price, + s.pro_price, + s.parking_info, + s.distance, + s.valid_start_time AS status_update_time, + sc.schedule_json, + sc.valid_start_time AS schedule_update_time, + s.piles_detail_json + FROM t_station_profile_scd p + LEFT JOIN t_station_status_scd s ON p.station_hash = s.station_hash AND s.is_current = 1 + LEFT JOIN t_station_price_schedule_scd sc ON p.station_hash = sc.station_hash AND sc.is_current = 1 + WHERE p.operator = #para(op) AND p.is_current = 1 + ORDER BY p.station_name ASC + #end +#end diff --git a/DbKit/SqlTemplateLoader.py b/DbKit/SqlTemplateLoader.py index 240bba4..443e756 100644 --- a/DbKit/SqlTemplateLoader.py +++ b/DbKit/SqlTemplateLoader.py @@ -489,6 +489,10 @@ class SqlTemplateLoader: template_map_str = await redisKit.get_data(REDIS_SQL_TEMPLATE_MAP_KEY) if template_map_str: self._template_map = json.loads(template_map_str) + + # 无论从Redis加载了什么,都尝试重新构建一次映射以确保一致性 + if self.templates and not self._template_map: + self._build_template_map() except Exception as e: logger.error(f"从Redis加载SQL模板数据失败: {str(e)}") diff --git a/DbKit/__pycache__/Db.cpython-310.pyc b/DbKit/__pycache__/Db.cpython-310.pyc index 637cbee..6982c04 100644 Binary files a/DbKit/__pycache__/Db.cpython-310.pyc and b/DbKit/__pycache__/Db.cpython-310.pyc differ diff --git a/DbKit/__pycache__/SqlTemplateLoader.cpython-310.pyc b/DbKit/__pycache__/SqlTemplateLoader.cpython-310.pyc index 6dc94ef..582e173 100644 Binary files a/DbKit/__pycache__/SqlTemplateLoader.cpython-310.pyc and b/DbKit/__pycache__/SqlTemplateLoader.cpython-310.pyc differ diff --git a/Model/DouYinModel.py b/Model/DouYinModel.py new file mode 100644 index 0000000..44a9c3a --- /dev/null +++ b/Model/DouYinModel.py @@ -0,0 +1,79 @@ +from DbKit.Db import Db +from datetime import datetime + +class DouYinModel: + def __init__(self, db: Db = None): + self.db = db or Db() + + async def init(self): + await self.db.init_db() + + async def update_status(self, id, status, error_msg=None): + await self.db.init_db() + params = { + "id": id, + "status": status, + "error_msg": error_msg + } + return await self.db.execute_update("DouYin.updateStatus", params) + + async def update_record(self, id, title, obs_url, transcript, status): + await self.db.init_db() + # Truncate title to 100 chars to fit DB schema + if title and len(title) > 100: + title = title[:100] + "..." + + params = { + "id": id, + "title": title, + "obs_url": obs_url, + "transcript": transcript, + "status": status + } + return await self.db.execute_update("DouYin.updateRecord", params) + + async def insert_record(self, id, url, create_time=None): + await self.db.init_db() + if create_time is None: + create_time = datetime.now() + params = { + "id": id, + "url": url, + "create_time": create_time + } + return await self.db.execute_update("DouYin.insertRecord", params) + + async def get_records(self, limit=50): + await self.db.init_db() + params = {"limit": limit} + records = await self.db.find("DouYin.getRecords", params) + + # Manually handle datetime serialization + for r in records: + if 'create_time' in r and r['create_time']: + r['create_time'] = r['create_time'].strftime("%Y-%m-%d %H:%M:%S") + if 'update_time' in r and r['update_time']: + r['update_time'] = r['update_time'].strftime("%Y-%m-%d %H:%M:%S") + return records + + async def delete_record(self, id): + await self.db.init_db() + params = {"id": id} + return await self.db.execute_update("DouYin.deleteRecord", params) + + async def get_transcripts(self, ids=None, limit=20): + await self.db.init_db() + if ids: + params = {"ids": ids} + return await self.db.find("DouYin.getTranscriptsByIds", params) + else: + params = {"limit": limit} + return await self.db.find("DouYin.getLatestTranscripts", params) + + async def get_interrupted_tasks(self, limit=20): + await self.db.init_db() + params = {"limit": limit} + return await self.db.find("DouYin.getInterruptedTasks", params) + + async def close(self): + await self.db.shutdown() diff --git a/Model/HaiBaoModel.py b/Model/HaiBaoModel.py new file mode 100644 index 0000000..b75821b --- /dev/null +++ b/Model/HaiBaoModel.py @@ -0,0 +1,24 @@ +import logging +from DbKit.Db import Db + +logger = logging.getLogger("HaiBaoModel") + +class HaiBaoModel: + def __init__(self): + self.db = Db() + + async def insert_record(self, id, prompt, image_url, scheme_content, created_at): + """插入生成记录""" + params = { + "id": id, + "prompt": prompt, + "image_url": image_url, + "scheme_content": scheme_content, + "created_at": created_at + } + return await self.db.execute_update("HaiBao.insertHistory", params) + + async def get_history(self, limit=50): + """获取历史记录""" + params = {"limit": limit} + return await self.db.find("HaiBao.getHistory", params) diff --git a/Model/YltAnalyticsModel.py b/Model/YltAnalyticsModel.py index 2794d12..3e6f6fb 100644 --- a/Model/YltAnalyticsModel.py +++ b/Model/YltAnalyticsModel.py @@ -53,3 +53,32 @@ class PriceComparisonResponse(BaseModel): class PriceComparisonSummary(BaseModel): summary: str + +from DbKit.Db import Db + +class YltAnalyticsModel: + def __init__(self): + self.db = Db() + + async def get_operators_price_trends(self, days: int): + return await self.db.find("YltAnalytics.getOperatorsPriceTrends", {"days": days}) + + async def list_ylt_stations(self, q: str = None): + params = {} + if q: + params["q"] = True + params["kw"] = f"%{q}%" + return await self.db.find("YltAnalytics.listYltStations", params) + + async def fetch_current_stations(self): + return await self.db.find("YltAnalytics.fetchCurrentStations") + + async def fetch_station_schedule_json(self, station_hash: str): + rows = await self.db.find("YltAnalytics.fetchStationScheduleJson", {"h": station_hash}) + if not rows: + return None + return rows[0].get("schedule_json") + + async def fetch_current_station_rows(self, operator: str): + return await self.db.find("YltAnalytics.fetchCurrentStationRows", {"op": operator}) + diff --git a/Model/__pycache__/DouYinModel.cpython-310.pyc b/Model/__pycache__/DouYinModel.cpython-310.pyc new file mode 100644 index 0000000..7b4ed04 Binary files /dev/null and b/Model/__pycache__/DouYinModel.cpython-310.pyc differ diff --git a/Model/__pycache__/HaiBaoModel.cpython-310.pyc b/Model/__pycache__/HaiBaoModel.cpython-310.pyc new file mode 100644 index 0000000..ce0b1a5 Binary files /dev/null and b/Model/__pycache__/HaiBaoModel.cpython-310.pyc differ diff --git a/Model/__pycache__/YltAnalyticsModel.cpython-310.pyc b/Model/__pycache__/YltAnalyticsModel.cpython-310.pyc index b95bf9b..6a5b667 100644 Binary files a/Model/__pycache__/YltAnalyticsModel.cpython-310.pyc and b/Model/__pycache__/YltAnalyticsModel.cpython-310.pyc differ diff --git a/Start.py b/Start.py index 07fad2e..1890fd1 100644 --- a/Start.py +++ b/Start.py @@ -59,6 +59,8 @@ async def lifespan(app: FastAPI): finally: logger.info("驿来特AI智能分析系统关闭...") await close_db() + # Close Redis connection + await RedisKit().close() app = FastAPI(title="驿来特AI智能分析系统", lifespan=lifespan) diff --git a/Tools/T6_Export.py b/Tools/T6_Export.py index e05d03a..be3bb72 100644 --- a/Tools/T6_Export.py +++ b/Tools/T6_Export.py @@ -21,7 +21,7 @@ from Util import Win32Patch Win32Patch.patch() from Config.Config import DB_URL -from DbKit.Db import Db +from Model.YltAnalyticsModel import YltAnalyticsModel try: from openpyxl import Workbook @@ -175,46 +175,17 @@ def _set_column_widths(ws, widths: dict[int, float]): class DorisExcelExporter: - def __init__(self, db_url: str): - self.db = Db(db_url=db_url) - - async def init(self): - await self.db.init_db() - - async def close(self): - try: - await self.db.close() - finally: - if getattr(self.db, "engine", None): - await self.db.engine.dispose() - + def __init__(self, db_url: str = None): + self.model = YltAnalyticsModel() + + async def init(self): + pass # Model handles initialization if needed + + async def close(self): + pass # Model handles closing if needed + async def fetch_current_station_rows(self, operator: str) -> List[Dict[str, Any]]: - sql = """ - SELECT - p.station_hash, - p.station_name, - p.address, - p.coord_x, - p.coord_y, - s.total_piles AS total_guns, - s.free_piles AS free_guns, - s.piles_detail_json, - s.current_price, - s.pro_price, - s.parking_info, - s.distance, - s.valid_start_time AS status_update_time, - pr.schedule_json, - pr.valid_start_time AS schedule_update_time - FROM t_station_profile_scd p - LEFT JOIN t_station_status_scd s - ON p.station_hash = s.station_hash AND s.is_current = 1 - LEFT JOIN t_station_price_schedule_scd pr - ON p.station_hash = pr.station_hash AND pr.is_current = 1 - WHERE p.is_current = 1 AND p.operator = :operator - ORDER BY p.station_name - """ - return await self.db.find(sql, {"operator": operator}) + return await self.model.fetch_current_station_rows(operator) def extract_hourly_prices_from_schedule(schedule_json: Any) -> List[Optional[float]]: diff --git a/Tools/__pycache__/T6_Export.cpython-310.pyc b/Tools/__pycache__/T6_Export.cpython-310.pyc index 65261b3..eccaa56 100644 Binary files a/Tools/__pycache__/T6_Export.cpython-310.pyc and b/Tools/__pycache__/T6_Export.cpython-310.pyc differ diff --git a/Util/RedisKit.py b/Util/RedisKit.py index deb562f..2d69b76 100644 --- a/Util/RedisKit.py +++ b/Util/RedisKit.py @@ -98,6 +98,14 @@ class RedisKit: return False async def set_data(self, key, value, expire=None): + """ + 异步保存数据到Redis + + Args: + key (str): Redis键名 + value (any): 要保存的值 + expire (int, optional): 过期时间(秒) + """ try: await self._ensure_pool() if expire: @@ -108,6 +116,34 @@ class RedisKit: except Exception as e: logger.error(f"保存数据到Redis失败(key={key}): {e}") return False + + async def close(self): + """ + 关闭Redis连接池 + """ + if RedisKit._redis_pool is not None: + try: + # redis-py 4.x+ supports close() + await asyncio.to_thread(RedisKit._redis_pool.close) + RedisKit._redis_pool = None + logger.info("Redis连接池已关闭") + except Exception as e: + logger.error(f"关闭Redis连接池失败: {e}") + + async def delete_data(self, key): + """ + 异步删除Redis中的数据 + + Args: + key (str): Redis键名 + """ + try: + await self._ensure_pool() + await asyncio.to_thread(RedisKit._redis_pool.delete, key) + return True + except Exception as e: + logger.error(f"从Redis删除数据失败(key={key}): {e}") + return False -# ȫʵ +# ȫ��ʵ�� redisKit = RedisKit() diff --git a/Util/__pycache__/RedisKit.cpython-310.pyc b/Util/__pycache__/RedisKit.cpython-310.pyc index d15430c..faecc31 100644 Binary files a/Util/__pycache__/RedisKit.cpython-310.pyc and b/Util/__pycache__/RedisKit.cpython-310.pyc differ diff --git a/clear_sql_cache.py b/clear_sql_cache.py new file mode 100644 index 0000000..cc74d6c --- /dev/null +++ b/clear_sql_cache.py @@ -0,0 +1,35 @@ + +import asyncio +import os +import sys + +# Add current directory to path so we can import Util +sys.path.append(os.getcwd()) + +from Util.RedisKit import redisKit + +async def clear_cache(): + print("Connecting to Redis...") + # redisKit will automatically ensure pool on first operation + + keys_to_delete = [ + "sql_templates:templates_loaded", + "sql_templates:all_templates", + "sql_templates:loaded_files", + "sql_templates:template_map" + ] + + for key in keys_to_delete: + print(f"Deleting key: {key}") + # Assuming delete_data exists in redisKit based on common naming convention + # If it doesn't, we can use a raw delete command + try: + await redisKit.delete_data(key) + except AttributeError: + conn = await redisKit.get_connection() + await asyncio.to_thread(conn.delete, key) + + print("SQL template cache cleared successfully.") + +if __name__ == "__main__": + asyncio.run(clear_cache()) diff --git a/static/css/dashboard.css b/static/css/dashboard.css index b6a0fe7..6c8bc6d 100644 --- a/static/css/dashboard.css +++ b/static/css/dashboard.css @@ -127,9 +127,21 @@ body { min-width: 0; } -/* Station List */ +/* Station List & Trends */ .station-list { - flex: 1; + flex: 3; + min-height: 0; + background-color: var(--card-bg); + border: 1px solid var(--card-border); + border-radius: 12px; + display: flex; + flex-direction: column; + overflow: hidden; + box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); +} + +.trend-section { + flex: 2; min-height: 0; background-color: var(--card-bg); border: 1px solid var(--card-border); @@ -230,23 +242,100 @@ body { scrollbar-color: var(--scrollbar-thumb) transparent; } -/* Markdown */ +/* Markdown & LaTeX Rendering */ .markdown-body { - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; - font-size: 16px; - line-height: 1.6; + font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; + font-size: 14px; + line-height: 1.8; word-wrap: break-word; color: #cbd5e1; } +.markdown-body h1, .markdown-body h2, .markdown-body h3, +.markdown-body h4, .markdown-body h5, .markdown-body h6 { + margin-top: 24px; + margin-bottom: 16px; + font-weight: 600; + line-height: 1.25; + color: var(--text-primary); +} + +.markdown-body h1 { font-size: 1.8em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h2 { font-size: 1.4em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h3 { font-size: 1.2em; } + +.markdown-body p { margin-top: 0; margin-bottom: 16px; } + .markdown-body ul, .markdown-body ol { padding-left: 2em; margin-top: 0; margin-bottom: 16px; } + .markdown-body li { margin: 0.25em 0; } -.markdown-body strong { font-weight: 600; color: #f1f5f9; } -.markdown-body h3 { font-size: 1.1em; font-weight: bold; margin-top: 16px; margin-bottom: 8px; color: #f1f5f9; } + +.markdown-body strong { font-weight: 600; color: #fff; } + +.markdown-body blockquote { + padding: 0 1em; + color: #94a3b8; + border-left: 0.25em solid #3b82f6; + margin: 0 0 16px 0; +} + +.markdown-body code { + padding: 0.2em 0.4em; + margin: 0; + font-size: 85%; + background-color: rgba(148, 163, 184, 0.1); + border-radius: 6px; + font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace; +} + +/* Markdown Table Styles */ +.markdown-body table { + display: block; + width: 100%; + width: max-content; + max-width: 100%; + overflow: auto; + border-spacing: 0; + border-collapse: collapse; + margin-top: 0; + margin-bottom: 16px; +} + +.markdown-body table th { + font-weight: 600; + background-color: rgba(15, 23, 42, 0.8); +} + +.markdown-body table th, +.markdown-body table td { + padding: 6px 13px; + border: 1px solid var(--card-border); +} + +.markdown-body table tr { + background-color: transparent; + border-top: 1px solid var(--card-border); +} + +.markdown-body table tr:nth-child(2n) { + background-color: rgba(30, 41, 59, 0.3); +} + +/* LaTeX Styles */ +.katex-block { + margin: 1em 0; + overflow-x: auto; + overflow-y: hidden; + padding: 8px 0; +} + +.katex { + font-size: 1.1em !important; +} /* Custom Scrollbar */ ::-webkit-scrollbar { width: 8px; height: 8px; } diff --git a/static/css/douyin.css b/static/css/douyin.css index 7991c5a..4a67360 100644 --- a/static/css/douyin.css +++ b/static/css/douyin.css @@ -239,22 +239,67 @@ body { animation: blink 1s step-end infinite; } -/* Markdown Styles */ +/* Markdown & LaTeX Rendering */ .markdown-body { - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; - font-size: 16px; - line-height: 1.6; + font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; + font-size: 15px; + line-height: 1.8; word-wrap: break-word; - color: #334155; + color: #cbd5e1; } +.markdown-body h1, .markdown-body h2, .markdown-body h3, +.markdown-body h4, .markdown-body h5, .markdown-body h6 { + margin-top: 24px; + margin-bottom: 16px; + font-weight: 600; + line-height: 1.25; + color: var(--text-primary); +} + +.markdown-body h1 { font-size: 1.8em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h2 { font-size: 1.4em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h3 { font-size: 1.2em; } + +.markdown-body p { margin-top: 0; margin-bottom: 16px; } + .markdown-body ul, .markdown-body ol { padding-left: 2em; margin-top: 0; margin-bottom: 16px; } + .markdown-body li { margin: 0.25em 0; } -.markdown-body strong { font-weight: 600; color: #0f172a; } + +.markdown-body strong { font-weight: 600; color: #fff; } + +.markdown-body blockquote { + padding: 0 1em; + color: #94a3b8; + border-left: 0.25em solid #3b82f6; + margin: 0 0 16px 0; +} + +.markdown-body code { + padding: 0.2em 0.4em; + margin: 0; + font-size: 85%; + background-color: rgba(148, 163, 184, 0.1); + border-radius: 6px; + font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace; +} + +/* LaTeX Styles */ +.katex-block { + margin: 1em 0; + overflow-x: auto; + overflow-y: hidden; + padding: 8px 0; +} + +.katex { + font-size: 1.1em !important; +} @keyframes blink { 0%, 100% { opacity: 1; } diff --git a/static/HaiBao/css/app.css b/static/css/haibao.css similarity index 100% rename from static/HaiBao/css/app.css rename to static/css/haibao.css diff --git a/static/css/query.css b/static/css/query.css index 5cdac8e..df71893 100644 --- a/static/css/query.css +++ b/static/css/query.css @@ -237,42 +237,99 @@ body { 50% { opacity: 0; } } -/* Markdown Styles */ +/* Markdown & LaTeX Rendering */ .markdown-body { - font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; - font-size: 16px; - line-height: 1.6; + font-family: 'Inter', -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif; + font-size: 15px; + line-height: 1.8; word-wrap: break-word; - color: #e5e7eb; + color: #cbd5e1; } +.markdown-body h1, .markdown-body h2, .markdown-body h3, +.markdown-body h4, .markdown-body h5, .markdown-body h6 { + margin-top: 24px; + margin-bottom: 16px; + font-weight: 600; + line-height: 1.25; + color: var(--text-primary); +} + +.markdown-body h1 { font-size: 1.8em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h2 { font-size: 1.4em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3em; } +.markdown-body h3 { font-size: 1.2em; } + +.markdown-body p { margin-top: 0; margin-bottom: 16px; } + .markdown-body ul, .markdown-body ol { padding-left: 2em; margin-top: 0; margin-bottom: 16px; } + .markdown-body li { margin: 0.25em 0; } -.markdown-body strong { font-weight: 600; color: #60a5fa; } +.markdown-body strong { font-weight: 600; color: #fff; } + +.markdown-body blockquote { + padding: 0 1em; + color: #94a3b8; + border-left: 0.25em solid #3b82f6; + margin: 0 0 16px 0; +} + +.markdown-body code { + padding: 0.2em 0.4em; + margin: 0; + font-size: 85%; + background-color: rgba(148, 163, 184, 0.1); + border-radius: 6px; + font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace; +} + +/* Markdown Table Styles */ .markdown-body table { + display: block; width: 100%; + width: max-content; + max-width: 100%; + overflow: auto; + border-spacing: 0; border-collapse: collapse; - margin: 16px 0; - font-size: 14px; + margin-top: 0; + margin-bottom: 16px; } -.markdown-body th, .markdown-body td { - padding: 12px; - border: 1px solid rgba(51, 65, 85, 0.9); -} - -.markdown-body th { - background-color: rgba(30, 64, 175, 0.5); +.markdown-body table th { font-weight: 600; + background-color: rgba(15, 23, 42, 0.8); } -.markdown-body tr:nth-child(even) { - background-color: rgba(15, 23, 42, 0.5); +.markdown-body table th, +.markdown-body table td { + padding: 10px 16px; + border: 1px solid var(--card-border); +} + +.markdown-body table tr { + background-color: transparent; + border-top: 1px solid var(--card-border); +} + +.markdown-body table tr:nth-child(2n) { + background-color: rgba(30, 41, 59, 0.3); +} + +/* LaTeX Styles */ +.katex-block { + margin: 1em 0; + overflow-x: auto; + overflow-y: hidden; + padding: 8px 0; +} + +.katex { + font-size: 1.1em !important; } /* ECharts Container in Result */ diff --git a/static/HaiBao/css/scheme.css b/static/css/scheme.css similarity index 100% rename from static/HaiBao/css/scheme.css rename to static/css/scheme.css diff --git a/static/dashboard.html b/static/dashboard.html index efb4e9a..673f2c0 100644 --- a/static/dashboard.html +++ b/static/dashboard.html @@ -25,11 +25,11 @@
- +
- 🕒 分时电价明细 + 📊 供应商实时价格对比
@@ -50,6 +50,18 @@
+ + +
+
+ 📉 供应商价格变动趋势 (最近{{ trendDays }}天) +
+ + +
+
+
+
@@ -66,15 +78,12 @@
-
+
🤖 AI 调价策略建议 -
- 点击“开始分析”获取AI智能定价建议 + AI 正在深度分析中
@@ -85,6 +94,10 @@ + + + + diff --git a/static/douyin.html b/static/douyin.html index bbfcd0e..22c7833 100644 --- a/static/douyin.html +++ b/static/douyin.html @@ -163,6 +163,10 @@ + + + + diff --git a/static/HaiBao/index.html b/static/haibao.html similarity index 93% rename from static/HaiBao/index.html rename to static/haibao.html index 29450c2..8700df7 100644 --- a/static/HaiBao/index.html +++ b/static/haibao.html @@ -4,8 +4,8 @@ 驿来特 - 智能海报生成工作台 - - + + @@ -17,7 +17,7 @@
基于 AI 大模型,快速生成高质量企业宣传海报
- + 返回首页
@@ -151,10 +151,10 @@
- - - - - + + + + + diff --git a/static/index.html b/static/index.html index c2d9334..641de33 100644 --- a/static/index.html +++ b/static/index.html @@ -41,7 +41,7 @@ - +
🎨

智能海报生成

一键生成精美的营销海报与数据战报。支持自定义模板与实时数据填充,提升品牌传播效率。

diff --git a/static/index_old.html b/static/index_old.html deleted file mode 100644 index 9a6f451..0000000 --- a/static/index_old.html +++ /dev/null @@ -1,352 +0,0 @@ - - - - - -驿来特AI智能数据分析平台 - - - - - - - -
- - -
-
-
-

⚡ 系统特性介绍

-
-
-
-
- 📱 -
-

本系统采用 手机爬虫 获取4家充电供应商准实时各时段电价

-
-
-
- 🧠 -
-

结合 数据仓库与AI技术,对我司电价进行智能分析,给出定价建议

-
-
-
- 📊 -
-

对我司的各场站营业情况进行 分析,查询

-
-
-
- 🎨 -
-

新增 智能海报生成 功能,未来将结合业务数据,一键生成精美的数据战报与营销海报

-
-
-
- 🎥 -
-

新增 抖音知识库:支持视频解析、知识获取与总结、博主专栏订阅,自动生成 充电企业知识日报,助力企业构建专属知识库

-
-
-
- 🎯 -
-

未来:可以根据用户充电信息,形成用户画像,结合企业微信,实现 用户广告的精准推送

-
-
-
- 🧭 -
-

未来:基于 LBS位置服务,智能对比周边竞对场站的价格与配套(快充、休息室等),精准引导用户选择我司优势站点

-
-
- -
-
-
- -
-
⚡ 驿来特AI智能数据分析平台
- -
- -
- - -
-
- -
-
- - -
-
-
-
全网供应商24小时电价监控
-
- - - - - - - - - - - - - - - - -
时段{{ op.label }}
{{ row.hour }} - {{ formatCell(cell.price) }} -
- 数据加载中... -
-
-
-
-
- -
-
- 智能决策分析助手 - -
-
-
-
当前分析任务
-
- 请根据爬取的各供应商分时电价等信息,对各司的定价策略, - 与我司(驿来特)的定价策略进行综合对比,分析我司可能存在的潜在问题。 -
-
- -
-
-
-
- - -
- - -
-
-

⚡ AI正在阅读您的知识库并提炼精华,请稍候...

-
-
- | -
- -
- -
-

- 抖音知识库 - - -

-

自动解析视频、提取文案,构建企业充电知识图谱

-
- - - - - -
- - {{ douyinLoading ? '解析处理中...' : '开始解析' }} - -
-
- - -
- -
-
-
-
-
- - {{ record.status }} - - - {{ formatDate(record.create_time) }} - -
-

- {{ record.video_name || '处理中...' }} -

- - 📺 点击观看视频 - -
- - Del - -
- - -
- Error: {{ record.error_msg }} -
- - -
-

视频文案

-

- {{ record.transcript }} -

- - {{ record.expanded ? '收起' : '展开全文' }} - -
- - -
-
- {{ record.showOriginal ? '收起原始链接' : '查看原始链接信息' }} -
-
- {{ record.original_text }} -
-
-
-
-
- -
-

暂无记录,请粘贴链接开始解析

-
-
-
- - -
- -
-
-

手机扫码访问

-
- -
-

驿来特AI智能数据查询

-

基于大语言模型,为您提供实时、精准的业务数据分析

- -
- - - -
- -
- - {{ text }} - -
-
- -
-
-
- 分析结果 - 完成 - 生成中 -
- - 停止生成 - -
- -
-
-

正在分析数据,请稍候...

-
- | -
-
- -
- - - - - - - - - diff --git a/static/js/dashboard.js b/static/js/dashboard.js index 1a99de4..a3f10bc 100644 --- a/static/js/dashboard.js +++ b/static/js/dashboard.js @@ -12,6 +12,7 @@ createApp({ window.addEventListener('resize', () => { isMobile.value = window.innerWidth <= 768; if (chartInstance) chartInstance.resize(); + if (trendChartInstance) trendChartInstance.resize(); }); // ========================================== @@ -35,7 +36,9 @@ createApp({ const priceTableRows = ref([]); const hourlyPricesByOperator = ref({}); let chartInstance = null; + let trendChartInstance = null; const chartType = ref('line'); + const trendDays = ref(7); // ECharts Initialization const initChart = () => { @@ -49,6 +52,17 @@ createApp({ } }; + const initTrendChart = () => { + const dom = document.getElementById("trendChart"); + if (dom && !trendChartInstance) { + if (typeof echarts === 'undefined') { + console.error("ECharts not loaded"); + return; + } + trendChartInstance = echarts.init(dom); + } + }; + const renderChart = () => { if (!chartInstance) return; @@ -152,6 +166,65 @@ createApp({ } }; + const loadTrendData = async () => { + try { + const res = await axios.get(apiBase.value + "/api/operators/price-trends?days=" + trendDays.value); + if (res && res.data) { + renderTrendChart(res.data); + } + } catch (e) { + console.error("Failed to load trend data:", e); + } + }; + + const renderTrendChart = (data) => { + if (!trendChartInstance) initTrendChart(); + if (!trendChartInstance) return; + + const option = { + backgroundColor: 'transparent', + tooltip: { + trigger: "axis", + backgroundColor: 'rgba(30, 41, 59, 0.9)', + borderColor: '#334155', + textStyle: { color: '#f1f5f9' } + }, + legend: { + data: data.series.map(s => s.name), + textStyle: { color: "#94a3b8" }, + top: 10 + }, + grid: { left: 50, right: 30, top: 60, bottom: 40 }, + xAxis: { + type: "category", + data: data.dates.map(d => d.split('-').slice(1).join('/')), // Show MM/DD + axisLine: { lineStyle: { color: "#475569" } }, + axisLabel: { color: "#94a3b8" } + }, + yAxis: { + type: "value", + name: "元/度", + nameTextStyle: { color: "#94a3b8" }, + axisLine: { lineStyle: { color: "#475569" } }, + axisLabel: { color: "#94a3b8" }, + splitLine: { lineStyle: { color: "#334155", type: 'dashed' } }, + min: (value) => (value.min * 0.95).toFixed(2), + max: (value) => (value.max * 1.05).toFixed(2) + }, + series: data.series.map(s => ({ + name: s.name, + type: 'line', + smooth: true, + symbol: 'circle', + symbolSize: 6, + data: s.data, + emphasis: { focus: 'series' } + })) + }; + + trendChartInstance.setOption(option, true); + }; + const exportAllPrices = async () => { try{ exporting.value = true; @@ -244,48 +317,102 @@ createApp({ return '#f1f5f9'; // Slate-100 }; - // 内置简易 Markdown 解析器 + // Configure Marked + if (typeof marked !== 'undefined') { + marked.use({ + gfm: true, + breaks: true + }); + } + + // 降级用的简易 Markdown 解析器 const simpleMarkdown = (text) => { if (!text) return ''; let lines = text.split('\n'); let html = ''; let inList = false; - const parseInline = (str) => { return str .replace(/\*\*(.*?)\*\*/g, '$1') - .replace(/`(.*?)`/g, '$1'); + .replace(/`(.*?)`/g, '$1'); }; - for (let line of lines) { let trimmed = line.trim(); - if (!trimmed) continue; - + if (!trimmed) { + if (inList) { html += ''; inList = false; } + html += '
'; + continue; + } if (trimmed.startsWith('### ')) { if (inList) { html += ''; inList = false; } html += `

${parseInline(trimmed.substring(4))}

`; - } - else if (trimmed.startsWith('- ') || /^\d+\./.test(trimmed)) { + } else if (trimmed.startsWith('- ') || /^\d+\./.test(trimmed)) { if (!inList) { html += ''; inList = false; } - html += `

${parseInline(trimmed)}

`; + html += `

${parseInline(trimmed)}

`; } } if (inList) html += ''; return html; }; + // 增强的 Markdown & LaTeX 解析器 + const renderMarkdownAndLatex = (text) => { + if (!text) return ''; + + try { + // 1. 处理 LaTeX (简单替换,先处理 $$ 再处理 $) + let processedText = text; + + // 处理块级 LaTeX: $$ ... $$ + processedText = processedText.replace(/\$\$\s*([\s\S]*?)\s*\$\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return '
' + katex.renderToString(formula, { displayMode: true, throwOnError: false }) + '
'; + } + return match; + } catch (e) { + return match; + } + }); + + // 处理行内 LaTeX: $ ... $ + processedText = processedText.replace(/\$([^\$\n]+?)\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return katex.renderToString(formula, { displayMode: false, throwOnError: false }); + } + return match; + } catch (e) { + return match; + } + }); + + // 2. 使用 marked 解析 Markdown + if (typeof marked !== 'undefined') { + return marked.parse(processedText); + } else { + // 降级使用之前的 simpleMarkdown + return simpleMarkdown(processedText); + } + } catch (e) { + console.error('Markdown/LaTeX rendering error:', e); + return text; + } + }; + const renderedAiText = computed(() => { - return simpleMarkdown(aiText.value); + return renderMarkdownAndLatex(aiText.value); }); onMounted(() => { initChart(); + initTrendChart(); loadAllOperatorsPrices(); + loadTrendData(); }); return { @@ -299,9 +426,11 @@ createApp({ priceTableRows, hourlyPricesByOperator, chartType, + trendDays, // Actions loadAllOperatorsPrices, + loadTrendData, exportAllPrices, exportAiReport, startAiAnalysis, diff --git a/static/js/douyin.js b/static/js/douyin.js index a031d8d..051dfca 100644 --- a/static/js/douyin.js +++ b/static/js/douyin.js @@ -15,53 +15,96 @@ createApp({ const summaryLoading = ref(false); const summaryText = ref(''); - // Simple Markdown Parser (Zero Dependency) + // 降级用的简易 Markdown 解析器 const simpleMarkdown = (text) => { if (!text) return ''; let lines = text.split('\n'); let html = ''; let inList = false; - - // Helper: Parse inline styles const parseInline = (str) => { return str - .replace(/\*\*(.*?)\*\*/g, '$1') // Bold - .replace(/`(.*?)`/g, '$1'); // Code + .replace(/\*\*(.*?)\*\*/g, '$1') + .replace(/`(.*?)`/g, '$1'); }; - for (let line of lines) { let trimmed = line.trim(); - if (!trimmed) continue; - - // Headers + if (!trimmed) { + if (inList) { html += ''; inList = false; } + html += '
'; + continue; + } if (trimmed.startsWith('### ')) { if (inList) { html += ''; inList = false; } - html += `

${parseInline(trimmed.substring(4))}

`; - } - // Lists - else if (trimmed.startsWith('- ') || /^\d+\./.test(trimmed)) { - if (!inList) { html += ''; return html; }; + // Configure Marked + if (typeof marked !== 'undefined') { + marked.use({ + gfm: true, + breaks: true + }); + } + + // 增强的 Markdown & LaTeX 解析器 + const renderMarkdownAndLatex = (text) => { + if (!text) return ''; + + try { + // 1. 处理 LaTeX (简单替换,先处理 $$ 再处理 $) + let processedText = text; + + // 处理块级 LaTeX: $$ ... $$ + processedText = processedText.replace(/\$\$\s*([\s\S]*?)\s*\$\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return '
' + katex.renderToString(formula, { displayMode: true, throwOnError: false }) + '
'; + } + return match; + } catch (e) { + return match; + } + }); + + // 处理行内 LaTeX: $ ... $ + processedText = processedText.replace(/\$([^\$\n]+?)\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return katex.renderToString(formula, { displayMode: false, throwOnError: false }); + } + return match; + } catch (e) { + return match; + } + }); + + // 2. 使用 marked 解析 Markdown + if (typeof marked !== 'undefined') { + return marked.parse(processedText); + } else { + // 降级使用之前的 simpleMarkdown + return simpleMarkdown(processedText); + } + } catch (e) { + console.error('Markdown/LaTeX rendering error:', e); + return text; + } + }; + const renderedSummary = computed(() => { if (!summaryText.value) return ''; - try { - return simpleMarkdown(summaryText.value); - } catch (e) { - console.error("Simple markdown error:", e); - return summaryText.value; - } + return renderMarkdownAndLatex(summaryText.value); }); // Methods diff --git a/static/HaiBao/js/app.js b/static/js/haibao.js similarity index 100% rename from static/HaiBao/js/app.js rename to static/js/haibao.js diff --git a/static/js/query.js b/static/js/query.js index 7e67c9a..ce6c417 100644 --- a/static/js/query.js +++ b/static/js/query.js @@ -168,34 +168,104 @@ createApp({ // Configure Marked if (typeof marked !== 'undefined') { - marked.use({ - gfm: true, - breaks: true, - renderer: { - code(code, language) { - if (language === 'echarts') { - const id = 'chart-' + Math.random().toString(36).substr(2, 9); - return `
`; - } - return false; - } + const renderer = new marked.Renderer(); + const oldCode = renderer.code.bind(renderer); + renderer.code = function(code, language) { + if (language === 'echarts') { + const id = 'chart-' + Math.random().toString(36).substr(2, 9); + return `
`; } + return oldCode(code, language); + }; + marked.setOptions({ + renderer: renderer, + gfm: true, + breaks: true }); } - const renderedResult = computed(() => { - if (!queryResult.value) return ''; - try { - // Remove escapes that might break markdown - const cleanText = queryResult.value.replace(/\\([\*_`#\[\]\(\)!>-])/g, '$1'); - if (typeof marked !== 'undefined') { - return marked.parse(cleanText); + // 降级用的简易 Markdown 解析器 + const simpleMarkdown = (text) => { + if (!text) return ''; + let lines = text.split('\n'); + let html = ''; + let inList = false; + const parseInline = (str) => { + return str + .replace(/\*\*(.*?)\*\*/g, '$1') + .replace(/`(.*?)`/g, '$1'); + }; + for (let line of lines) { + let trimmed = line.trim(); + if (!trimmed) { + if (inList) { html += ''; inList = false; } + html += '
'; + continue; + } + if (trimmed.startsWith('### ')) { + if (inList) { html += ''; inList = false; } + html += `

${parseInline(trimmed.substring(4))}

`; + } else if (trimmed.startsWith('- ') || /^\d+\./.test(trimmed)) { + if (!inList) { html += ''; inList = false; } + html += `

${parseInline(trimmed)}

`; } - return cleanText; - } catch (e) { - console.error('Markdown parsing error:', e); - return queryResult.value; } + if (inList) html += ''; + return html; + }; + + // 增强的 Markdown & LaTeX 解析器 + const renderMarkdownAndLatex = (text) => { + if (!text) return ''; + + try { + // 1. 处理 LaTeX (简单替换,先处理 $$ 再处理 $) + let processedText = text; + + // 处理块级 LaTeX: $$ ... $$ + processedText = processedText.replace(/\$\$\s*([\s\S]*?)\s*\$\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return '
' + katex.renderToString(formula, { displayMode: true, throwOnError: false }) + '
'; + } + return match; + } catch (e) { + return match; + } + }); + + // 处理行内 LaTeX: $ ... $ + processedText = processedText.replace(/\$([^\$\n]+?)\$/g, (match, formula) => { + try { + if (typeof katex !== 'undefined') { + return katex.renderToString(formula, { displayMode: false, throwOnError: false }); + } + return match; + } catch (e) { + return match; + } + }); + + // 2. 使用 marked 解析 Markdown + if (typeof marked !== 'undefined') { + // 直接解析处理后的文本,marked 会处理 Markdown 转义 + return marked.parse(processedText); + } else { + // 降级使用之前的 simpleMarkdown + return simpleMarkdown(processedText); + } + } catch (e) { + console.error('Markdown/LaTeX rendering error:', e); + return text; + } + }; + + const renderedResult = computed(() => { + return renderMarkdownAndLatex(queryResult.value); }); watch(queryResult, () => { diff --git a/static/query.html b/static/query.html index dc1ade9..b1e2425 100644 --- a/static/query.html +++ b/static/query.html @@ -82,6 +82,9 @@ + + +