import asyncio
import json
import math
import os
import shutil
import subprocess
import tempfile
import zipfile
from typing import Any, Dict, Iterable, List, Optional, Sequence

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse, FileResponse, JSONResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask

# Project imports are kept in their original relative order in case any of
# them has import-time side effects.
from Config.Config import DB_URL
from Util.LlmUtil import get_llm_response
from Tools.T6_Export import (
    export_excel,
    DorisExcelExporter,
    extract_hourly_prices_from_schedule,
)
from Model.YltAnalyticsModel import (
    StationBase,
    CompetitorStation,
    GeoCompetitionResponse,
    GeoCompetitionSummary,
    PriceSeries,
    PriceComparisonResponse,
    PriceComparisonSummary,
    YltAnalyticsModel,
)
from DbKit.Db import Db

router = APIRouter()

# NOTE: there is intentionally no module-level ``Db(db_url=DB_URL)`` instance;
# init_db()/close_db() below create a ``Db()`` on demand instead.

# Operators compared by the cross-operator endpoints, in display order.
_OPERATORS = ("新电途", "特来电", "驿来特", "艾特吉易充")
# The operator treated as "our own" in the geo / price comparisons.
_OWN_OPERATOR = "驿来特"
_HOURS_PER_DAY = 24
# Upper bound for a single pandoc conversion before it is killed.
_PANDOC_TIMEOUT_SECONDS = 120

# Item keys that already hold a total price.
_TOTAL_PRICE_KEYS = ("price", "price_kwh", "priceKwh", "total_price", "totalPrice")
# (electricity fee key, service fee key) pairs whose sum is the total price.
_FEE_KEY_PAIRS = (
    ("elec_price", "service_price"),
    ("electric_fee_kwh", "service_fee_kwh"),
    ("ele_fee", "ser_fee"),
)


async def init_db():
    """Create a ``Db`` object and await its ``init_db()``."""
    db = Db()
    await db.init_db()


async def close_db():
    """Create a ``Db`` object and await its ``close()``."""
    db = Db()
    await db.close()


def _average_hourly_series(
    series_list: Iterable[Sequence[Optional[float]]],
) -> List[Optional[float]]:
    """Average several 24-slot price series hour by hour.

    ``None`` entries are skipped; an hour with no value in any series
    averages to ``None``. Every series must have at least 24 entries.
    """
    sums = [0.0] * _HOURS_PER_DAY
    counts = [0] * _HOURS_PER_DAY
    for series in series_list:
        for hour in range(_HOURS_PER_DAY):
            value = series[hour]
            if value is None:
                continue
            sums[hour] += float(value)
            counts[hour] += 1
    return [
        sums[hour] / counts[hour] if counts[hour] > 0 else None
        for hour in range(_HOURS_PER_DAY)
    ]


def _format_display_date(date_str: str) -> str:
    """Turn ``YYYY-MM-DD`` into ``MM/DD``; return the input if it has fewer than three ``-`` parts."""
    parts = date_str.split("-")
    if len(parts) < 3:
        return date_str
    return parts[1] + "/" + parts[2]


@router.get("/api/operators/hourly-prices")
async def get_operators_hourly_prices():
    """Return, per operator, the 24-hour price series averaged over its current stations.

    Returns:
        ``{"operators": [{"operator": name, "series": [24 floats or None]}, ...]}``.
        An operator without rows gets a series of 24 ``None`` values.

    Raises:
        HTTPException: 500 with the error text if fetching or aggregating fails.
    """
    model = YltAnalyticsModel()
    try:
        result = []
        for op in _OPERATORS:
            rows = await model.fetch_current_station_rows(op)
            station_series = [
                extract_hourly_prices_from_schedule(row.get("schedule_json"))
                for row in (rows or [])
            ]
            result.append(
                {"operator": op, "series": _average_hourly_series(station_series)}
            )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"operators": result}


@router.get("/api/operators/price-trends")
async def get_operators_price_trends(days: int = 7):
    """Return hourly average prices per operator over the last ``days`` days.

    Args:
        days: Look-back window, passed through to the model query.

    Returns:
        ``{"dates": [...], "series": [{"name": operator, "data": [...]}, ...]}``.
        ``dates`` holds one ``MM/DD HH:00`` label per time point, sorted
        chronologically. Each ``data`` list is aligned with ``dates`` and
        contains the average rounded to 4 decimals, or ``None`` where that
        operator has no value.

    Raises:
        HTTPException: 500 with the error text if the model query fails
            (same convention as ``get_operators_hourly_prices``).
    """
    model = YltAnalyticsModel()
    try:
        rows = await model.get_operators_price_trends(days)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Layout: {operator: {"YYYY-MM-DD HH:00": {"sum": float, "count": int}}}
    trend_data: Dict[str, Dict[str, Dict[str, Any]]] = {op: {} for op in _OPERATORS}
    # Display label per time key; the label is derived from the key only,
    # so it is identical whichever operator contributed it.
    display_labels: Dict[str, str] = {}

    for row in rows:
        op = row.get("operator")
        if op not in trend_data:
            continue

        # Expand one (date, schedule_json) row into up to 24 hourly points.
        d_str = str(row.get("date_str"))
        display_date = _format_display_date(d_str)
        series = extract_hourly_prices_from_schedule(row.get("schedule_json"))

        for hour in range(_HOURS_PER_DAY):
            v = series[hour]
            if v is None:
                continue
            # Zero-padded key built from the raw date so that plain string
            # sorting yields chronological order.
            dt_key = f"{d_str} {hour:02d}:00"
            display_labels.setdefault(dt_key, f"{display_date} {hour:02d}:00")
            bucket = trend_data[op].setdefault(dt_key, {"sum": 0.0, "count": 0})
            bucket["sum"] += float(v)
            bucket["count"] += 1

    # Convert to an ECharts-friendly shape: one shared, sorted x-axis ...
    sorted_keys = sorted(display_labels)
    display_dates = [display_labels[key] for key in sorted_keys]

    # ... and one gap-filled series per operator aligned with that axis.
    series_result = []
    for op in _OPERATORS:
        buckets = trend_data[op]
        op_data = [
            round(buckets[key]["sum"] / buckets[key]["count"], 4)
            if key in buckets
            else None
            for key in sorted_keys
        ]
        series_result.append({"name": op, "data": op_data})

    return {"dates": display_dates, "series": series_result}


@router.get("/api/export/prices-zip")
async def export_prices_zip():
    """Export one Excel file per operator and return them bundled in a zip.

    The files are written to a fresh temporary directory. The directory is
    removed after the response has been sent, or immediately if the export
    fails, so repeated calls do not accumulate files on disk.
    """
    tmp_dir = tempfile.mkdtemp(prefix="price_export_")
    try:
        excel_paths = []
        for op in _OPERATORS:
            # The numeric suffix comes from the event loop clock, as before.
            filename = f"{op}_{asyncio.get_running_loop().time():.0f}.xlsx"
            output_path = os.path.join(tmp_dir, filename)
            await export_excel(op, output_path)
            excel_paths.append(output_path)

        zip_path = os.path.join(tmp_dir, "prices_export.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for p in excel_paths:
                zf.write(p, arcname=os.path.basename(p))
    except BaseException:
        # Also covers cancellation: never leave the temp dir behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    return FileResponse(
        zip_path,
        media_type="application/zip",
        filename="多供应商电价导出.zip",
        background=BackgroundTask(shutil.rmtree, tmp_dir, ignore_errors=True),
    )


class AiReportRequest(BaseModel):
    """Request body for the AI report export: the Markdown text to convert."""

    content: str


@router.post("/api/export/ai-report-docx")
async def export_ai_report_docx(req: AiReportRequest):
    """Convert the posted Markdown to a .docx with pandoc and return the file.

    Raises:
        HTTPException: 400 if ``content`` is empty; 500 if pandoc fails,
            times out, or cannot be started.
    """
    content = req.content
    if not content:
        raise HTTPException(status_code=400, detail="Content is empty")

    # Write the Markdown to a temp file that pandoc can read.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".md", delete=False, encoding="utf-8"
    ) as tmp_md:
        tmp_md.write(content)
        tmp_md_path = tmp_md.name

    # splitext only swaps the final extension; str.replace would also rewrite
    # any other ".md" occurring earlier in the path.
    output_docx_path = os.path.splitext(tmp_md_path)[0] + ".docx"

    def cleanup():
        for path in (tmp_md_path, output_docx_path):
            if os.path.exists(path):
                os.remove(path)

    template_path = "static/template/templates.docx"
    cmd = ['pandoc', '-s', tmp_md_path, '-o', output_docx_path, '--resource-path=static']
    # The reference document is optional: use it only when it is present.
    if os.path.exists(template_path):
        cmd.extend(['--reference-doc', template_path])

    # NOTE(review): ``content`` is client-supplied Markdown handed to pandoc
    # with a resource path; confirm this endpoint is not exposed to untrusted
    # callers.
    loop = asyncio.get_running_loop()
    try:
        # Run pandoc in a worker thread so the event loop keeps serving other
        # requests while the conversion is in progress.
        await loop.run_in_executor(
            None,
            lambda: subprocess.run(cmd, check=True, timeout=_PANDOC_TIMEOUT_SECONDS),
        )
    except (subprocess.SubprocessError, OSError) as e:
        # SubprocessError covers a non-zero exit and the timeout; OSError
        # covers a missing or non-executable pandoc binary.
        cleanup()
        raise HTTPException(
            status_code=500, detail=f"Pandoc conversion failed: {str(e)}"
        ) from e

    return FileResponse(
        output_docx_path,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        filename="AI分析报告.docx",
        background=BackgroundTask(cleanup),
    )


@router.get("/api/ai/pricing/strategy-summary")
async def ai_pricing_strategy_summary():
    """Stream an LLM-written pricing strategy analysis as plain text chunks.

    The stream starts with progress hints, then the literal marker line
    ``---CLEAR_PREVIOUS_HINTS---``, then the raw LLM chunks. Failures are
    reported inside the stream as Markdown text rather than as an HTTP error,
    because the response has already started.

    NOTE(review): the chunks are sent as-is (not ``data:``-framed) although
    the media type is ``text/event-stream``; presumably the client reads the
    raw body. Confirm before changing either side.
    """

    async def generate_stream():
        try:
            # Initial hint plus whitespace padding so buffering proxies flush
            # the first bytes right away.
            yield "正在收集各供应商价格数据,请稍候...\n\n" + (" " * 512) + "\n"

            print("AI分析开始: 获取运营商价格数据...")

            # Bound both data queries so a hung database cannot stall the
            # stream indefinitely.
            try:
                # 1. Current 24-hour average prices.
                resp = await asyncio.wait_for(
                    get_operators_hourly_prices(), timeout=30.0
                )
                # 2. Price trend over the last 3 days.
                trend_resp = await asyncio.wait_for(
                    get_operators_price_trends(days=3), timeout=30.0
                )
            except asyncio.TimeoutError:
                print("获取价格数据超时")
                yield "\n\n**错误**: 获取价格数据超时,数据库响应过慢,请稍后重试。"
                return

            # Current prices, reduced to the fields used in the prompt.
            text_data = [
                {"operator": item.get("operator"), "series": item.get("series")}
                for item in resp.get("operators", [])
            ]

            # 3-day trend. NOTE(review): the points are hourly averages (one
            # per "MM/DD HH:00" label), but the key and the prompt call them
            # daily averages, and the log line below counts time points as
            # days. Confirm the intended wording.
            trend_dates = trend_resp.get("dates", [])
            trend_text = [
                {"operator": s.get("name"), "daily_avg_prices": s.get("data")}
                for s in trend_resp.get("series", [])
            ]

            print(f"数据获取完成,准备请求LLM. 数据条数: {len(text_data)}, 趋势天数: {len(trend_dates)}")
            yield "数据收集完成,正在分析最近 3 天的价格波动趋势并生成深度建议...\n\n"

            # Heartbeat padding to keep the connection open.
            yield " " * 128 + "\n"

            prompt = (
                "你是一位专业的充电桩调价策略分析顾问。下面是四家供应商(新电途、特来电、驿来特、艾特吉易充)的电价分析数据:\n\n"
                "### 1. 当前最新 24 小时平均分时电价 (元/kWh)\n"
                f"{json.dumps(text_data, ensure_ascii=False)}\n\n"
                "### 2. 最近 3 天的价格变动趋势 (每日平均电价)\n"
                f"日期序列: {trend_dates}\n"
                f"各司趋势: {json.dumps(trend_text, ensure_ascii=False)}\n\n"
                "请根据以上数据进行深度分析:\n"
                "1. **现状对比**:对比我司(驿来特)与竞对在不同时段的电价水平,找出我司偏高或偏低的关键时段。\n"
                "2. **趋势洞察**:分析最近 3 天各供应商的价格调整动态,判断市场整体是在涨价、降价还是保持稳定,我司的反应是否及时。\n"
                "3. **问题诊断**:指出我司目前定价中存在的潜在风险(如价格倒挂、错失高峰收益、低谷缺乏竞争力等)。\n"
                "4. **优化方案**:给出 2-3 条具体的、可落地的调价建议,并说明理由。\n\n"
                "要求:\n"
                "- 使用专业、客观的语气。\n"
                "- 采用 Markdown 格式,适当使用加粗和表格。\n"
                "- 回答控制在 800-1000 字以内。"
            )

            # Marker telling the client to drop the hints shown so far.
            yield "---CLEAR_PREVIOUS_HINTS---\n"

            chunk_count = 0
            # NOTE: no timeout is applied to the LLM stream here; only
            # exceptions raised by it are turned into an in-stream message.
            try:
                async for chunk in get_llm_response(
                    prompt,
                    stream=True,
                    system_prompt="你是熟悉中国充电桩行业的电价策略分析顾问。",
                ):
                    chunk_count += 1
                    if chunk_count == 1:
                        print("收到LLM首个chunk")
                    yield chunk
            except Exception as llm_e:
                print(f"LLM请求异常: {str(llm_e)}")
                yield f"\n\n**AI 分析服务异常**: {str(llm_e)}。这可能是由于大模型服务商(如 DeepSeek)响应过慢或连接中断导致的。"
                return

            print(f"AI分析完成,共发送 {chunk_count} 个chunks")

        except Exception as e:
            error_msg = f"\n\n**分析过程出现严重错误**: {str(e)}"
            print(error_msg)
            yield error_msg

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Content-Type": "text/event-stream; charset=utf-8",
        },
    )


@router.get("/api/ylt/stations", response_model=List[StationBase])
async def list_ylt_stations(q: Optional[str] = None):
    """List stations returned by the model for the optional query ``q``."""
    model = YltAnalyticsModel()
    rows = await model.list_ylt_stations(q)
    return [
        StationBase(
            station_hash=r.get("station_hash"),
            operator=r.get("operator"),
            station_name=r.get("station_name"),
            address=r.get("address"),
            coord_x=r.get("coord_x"),
            coord_y=r.get("coord_y"),
            current_price=r.get("current_price"),
        )
        for r in rows
    ]


def haversine_km(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the haversine distance between two points on a 6371 km sphere.

    Args:
        lon1, lat1: Longitude and latitude of the first point, in degrees.
        lon2, lat2: Longitude and latitude of the second point, in degrees.
    """
    r = 6371.0
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lon2 - lon1)
    a = (
        math.sin(d_phi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    )
    # Rounding can push ``a`` marginally outside [0, 1] (near-antipodal
    # points), which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return r * c


async def fetch_current_stations() -> List[dict]:
    """Return the model's current station rows."""
    model = YltAnalyticsModel()
    return await model.fetch_current_stations()


async def build_geo_competition(station_hash: str, radius_km: float = 3.0) -> GeoCompetitionResponse:
    """Collect competitor stations around one of our own stations.

    Args:
        station_hash: Hash of the base station; it must belong to 驿来特.
        radius_km: Competitors farther away than this are ignored.

    Returns:
        The base station, the competitors within the radius, the min/max
        competitor price, and how many priced competitors are cheaper than,
        equal to, or more expensive than the base station.

    Raises:
        HTTPException: 404 if there is no station data or the base station is
            not found; 400 if the base station has no coordinates.
    """
    rows = await fetch_current_stations()
    if not rows:
        raise HTTPException(status_code=404, detail="no station data")

    base_row = next(
        (
            r
            for r in rows
            if r.get("station_hash") == station_hash
            and r.get("operator") == _OWN_OPERATOR
        ),
        None,
    )
    if base_row is None:
        raise HTTPException(status_code=404, detail="base station not found for 驿来特")

    base_lon = base_row.get("coord_x")
    base_lat = base_row.get("coord_y")
    if base_lon is None or base_lat is None:
        raise HTTPException(status_code=400, detail="base station has no coordinates")

    competitors: List[CompetitorStation] = []
    ylt_price = base_row.get("current_price")
    cheaper = 0
    same = 0
    more_expensive = 0
    min_price: Optional[float] = None
    max_price: Optional[float] = None

    for r in rows:
        if r.get("operator") == _OWN_OPERATOR:
            continue
        lon = r.get("coord_x")
        lat = r.get("coord_y")
        if lon is None or lat is None:
            continue
        dist = haversine_km(base_lon, base_lat, lon, lat)
        if dist > radius_km:
            continue

        price = r.get("current_price")
        competitors.append(
            CompetitorStation(
                station_hash=r.get("station_hash"),
                operator=r.get("operator"),
                station_name=r.get("station_name"),
                distance_km=round(dist, 3),
                current_price=price,
            )
        )

        # Unpriced competitors are listed but excluded from the statistics.
        if price is None:
            continue
        if min_price is None or price < min_price:
            min_price = price
        if max_price is None or price > max_price:
            max_price = price

        # Without our own price there is nothing to compare against.
        if ylt_price is None:
            continue
        if price < ylt_price:
            cheaper += 1
        elif price > ylt_price:
            more_expensive += 1
        else:
            same += 1

    base_station = StationBase(
        station_hash=base_row.get("station_hash"),
        operator=base_row.get("operator"),
        station_name=base_row.get("station_name"),
        address=base_row.get("address"),
        coord_x=base_lon,
        coord_y=base_lat,
        current_price=ylt_price,
    )

    return GeoCompetitionResponse(
        base_station=base_station,
        competitors=competitors,
        ylt_price=ylt_price,
        min_competitor_price=min_price,
        max_competitor_price=max_price,
        cheaper_count=cheaper,
        same_count=same,
        more_expensive_count=more_expensive,
    )


async def fetch_station_schedule_json(station_hash: str) -> Optional[str]:
    """Return a station's price schedule as a JSON string.

    A string from the model is returned unchanged; any other value is
    serialised with ``json.dumps``. Returns ``None`` when the model has no
    value or the value cannot be serialised.
    """
    model = YltAnalyticsModel()
    value = await model.fetch_station_schedule_json(station_hash)
    if value is None:
        return None
    if isinstance(value, str):
        return value
    try:
        return json.dumps(value, ensure_ascii=False)
    except (TypeError, ValueError, RecursionError):
        # Not JSON-serialisable: treat it as "no schedule".
        return None


def _is_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool (``True`` is not a price)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def extract_price_from_item(item: Dict[str, Any]) -> Optional[float]:
    """Return the total price found in one schedule item, or ``None``.

    The first numeric value under one of the total-price keys wins. Otherwise
    the first (electricity fee, service fee) key pair whose two values are
    both numeric is summed. Non-dict items yield ``None``.
    """
    if not isinstance(item, dict):
        return None

    for key in _TOTAL_PRICE_KEYS:
        v = item.get(key)
        if _is_number(v):
            return float(v)

    for elec_key, service_key in _FEE_KEY_PAIRS:
        elec = item.get(elec_key)
        service = item.get(service_key)
        if _is_number(elec) and _is_number(service):
            return float(elec) + float(service)

    return None


def _leading_hour(value: Any) -> Optional[int]:
    """Parse the integer before the first ``:`` of a time string, else ``None``."""
    if not (isinstance(value, str) and ":" in value):
        return None
    try:
        return int(value.split(":")[0])
    except ValueError:
        return None


def parse_hour_from_item(item: Dict[str, Any], default_index: int) -> Optional[int]:
    """Work out which hour slot (0-23) a schedule item belongs to.

    Tried in order: the hour of ``item["start"]``; the hour of ``item["end"]``
    minus one (so an end of ``24:00`` maps to slot 23); finally
    ``default_index`` when it is within 0-23. Returns ``None`` if none apply.
    """
    start_hour = _leading_hour(item.get("start"))
    if start_hour is not None and 0 <= start_hour <= 23:
        return start_hour

    end_hour = _leading_hour(item.get("end"))
    if end_hour is not None and 0 < end_hour <= 24:
        return end_hour - 1

    if 0 <= default_index <= 23:
        return default_index
    return None


def extract_hourly_prices(schedule_json_str: str) -> List[Optional[float]]:
    """Parse a schedule JSON string into a 24-slot list of prices.

    Returns 24 ``None`` values when the input is empty, is not valid JSON, or
    does not decode to a list. Items without a recognisable price or hour are
    skipped; a later item overwrites an earlier one for the same hour.

    NOTE(review): each item fills exactly one slot. An item spanning several
    hours (e.g. start 00:00, end 08:00) only sets its first hour — confirm
    that schedules are always stored as hourly items.
    """
    series: List[Optional[float]] = [None] * _HOURS_PER_DAY
    if not schedule_json_str:
        return series
    try:
        data = json.loads(schedule_json_str)
    except Exception:
        # Best effort: any undecodable input means "no prices".
        return series
    if not isinstance(data, list):
        return series

    for idx, item in enumerate(data):
        price = extract_price_from_item(item)
        if price is None:
            continue
        hour_idx = parse_hour_from_item(item, idx)
        if hour_idx is None or not (0 <= hour_idx < _HOURS_PER_DAY):
            continue
        series[hour_idx] = float(price)
    return series


async def build_price_comparison(station_hash: str) -> PriceComparisonResponse:
    """Compare our station's hourly prices with nearby competitors.

    Competitors come from ``build_geo_competition`` with its default radius.
    Their hourly series are averaged per operator. The min/max/avg diff is
    computed over the hours where both our price and at least one competitor
    average exist, as ``our price - lowest competitor average``.

    Raises:
        HTTPException: 404 if the base station has no price schedule, plus
            anything raised by ``build_geo_competition``.
    """
    geo = await build_geo_competition(station_hash)
    base_station = geo.base_station

    base_schedule_str = await fetch_station_schedule_json(base_station.station_hash)
    if base_schedule_str is None:
        raise HTTPException(status_code=404, detail="no price schedule for YLT station")
    ylt_series = extract_hourly_prices(base_schedule_str)
    hours = list(range(_HOURS_PER_DAY))

    # Group competitor series by operator, keeping first-seen operator order.
    series_by_operator: Dict[str, List[List[Optional[float]]]] = {}
    for comp in geo.competitors:
        schedule_str = await fetch_station_schedule_json(comp.station_hash)
        if not schedule_str:
            continue
        series_by_operator.setdefault(comp.operator, []).append(
            extract_hourly_prices(schedule_str)
        )

    competitors_series: List[PriceSeries] = [
        PriceSeries(operator=op, series=_average_hourly_series(op_series))
        for op, op_series in series_by_operator.items()
    ]

    diffs: List[float] = []
    for i in range(_HOURS_PER_DAY):
        y = ylt_series[i]
        if y is None:
            continue
        competitor_prices = [
            float(s.series[i]) for s in competitors_series if s.series[i] is not None
        ]
        if not competitor_prices:
            continue
        diffs.append(y - min(competitor_prices))

    min_diff = min(diffs) if diffs else None
    max_diff = max(diffs) if diffs else None
    avg_diff = sum(diffs) / len(diffs) if diffs else None

    ylt_price_series = PriceSeries(operator=base_station.operator, series=ylt_series)

    return PriceComparisonResponse(
        hours=hours,
        ylt=ylt_price_series,
        competitors=competitors_series,
        min_diff=min_diff,
        max_diff=max_diff,
        avg_diff=avg_diff,
    )


@router.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@router.get("/api/ylt/geo/competitors/{station_hash}", response_model=GeoCompetitionResponse)
async def get_geo_competitors(station_hash: str):
    """Return the geo competition data for one of our stations."""
    return await build_geo_competition(station_hash)


@router.get("/api/ylt/geo/competitors/{station_hash}/summary", response_model=GeoCompetitionSummary)
async def get_geo_competitors_summary(station_hash: str):
    """Ask the LLM for a short written summary of the geo competition data."""
    data = await build_geo_competition(station_hash)
    base = data.base_station

    summary_input = {
        "station_name": base.station_name,
        "operator": base.operator,
        "ylt_price": data.ylt_price,
        "competitor_count": len(data.competitors),
        "cheaper_count": data.cheaper_count,
        "same_count": data.same_count,
        "more_expensive_count": data.more_expensive_count,
        "min_competitor_price": data.min_competitor_price,
        "max_competitor_price": data.max_competitor_price,
    }

    # Serialise as JSON, like the other AI prompt in this module, instead of
    # embedding the Python repr of the dict. ``default=str`` keeps any
    # non-JSON-native value (e.g. Decimal) from raising.
    text = (
        "请作为驿来特价格策略分析顾问,用简明中文解释当前场站在3公里范围内的价格竞争情况,"
        "给出可操作的价格调整或产品策略建议,控制在300字以内。以下是结构化数据:\n"
        f"{json.dumps(summary_input, ensure_ascii=False, default=str)}"
    )

    # Skip None chunks so that "".join cannot fail on them.
    chunks: List[str] = [
        chunk
        async for chunk in get_llm_response(
            text,
            stream=False,
            system_prompt="你是驿来特电价和选址策略顾问。",
        )
        if chunk is not None
    ]
    return GeoCompetitionSummary(summary="".join(chunks))


@router.get("/api/ylt/pricing/comparison/{station_hash}", response_model=PriceComparisonResponse)
async def get_price_comparison(station_hash: str):
    """Return the hourly price comparison for one of our stations."""
    return await build_price_comparison(station_hash)
def _price_comparison_prompt(data) -> str:
    """Build the LLM prompt shared by the summary and SSE comparison endpoints.

    ``data`` is the result of ``build_price_comparison``. The figures are
    embedded as JSON; ``default=str`` keeps any non-JSON-native value from
    raising.
    """
    text_data = {
        "hours": data.hours,
        "ylt_prices": data.ylt.series,
        "competitors": [
            {"operator": s.operator, "series": s.series} for s in data.competitors
        ],
        "min_diff": data.min_diff,
        "max_diff": data.max_diff,
        "avg_diff": data.avg_diff,
    }
    return (
        "请作为驿来特价格策略分析顾问,对下列分时电价数据进行比较分析:\n"
        "1) 解释驿来特与三家竞品在一天24小时内的价格差距特征;\n"
        "2) 指出在哪些时间段我们明显偏贵、在哪些时间段有优势;\n"
        "3) 给出2到3条可执行的调价或营销策略建议;\n"
        "控制在400字以内。数据如下:\n"
        f"{json.dumps(text_data, ensure_ascii=False, default=str)}"
    )


def _sse_data_event(payload: str, event: Optional[str] = None) -> str:
    """Frame ``payload`` as one Server-Sent Event.

    A line break inside an SSE ``data:`` line ends the field, and a blank
    line ends the event. Multi-line payloads are therefore sent as one
    ``data:`` line per payload line, which the client joins with newlines.
    """
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    header = f"event: {event}\n" if event else ""
    return header + "".join(f"data: {line}\n" for line in lines) + "\n"


@router.get("/api/ylt/pricing/comparison/{station_hash}/summary", response_model=PriceComparisonSummary)
async def get_price_comparison_summary(station_hash: str):
    """Ask the LLM for a written summary of the hourly price comparison."""
    data = await build_price_comparison(station_hash)
    text = _price_comparison_prompt(data)

    # Skip None chunks so that "".join cannot fail on them.
    chunks: List[str] = [
        chunk
        async for chunk in get_llm_response(
            text,
            stream=False,
            system_prompt="你是驿来特电价策略分析顾问。",
        )
        if chunk is not None
    ]
    return PriceComparisonSummary(summary="".join(chunks))


@router.get("/api/ylt/pricing/comparison/{station_hash}/sse")
async def stream_price_comparison_summary(station_hash: str):
    """Stream the LLM price comparison summary as Server-Sent Events.

    Each LLM chunk is sent as a default (``message``) event. The stream
    always finishes with ``event: end`` / ``data: [DONE]``. If the LLM call
    fails part-way, an ``event: error`` carrying the error text is sent
    before the end event, so the client is not left with a silently
    truncated stream.
    """
    data = await build_price_comparison(station_hash)
    text = _price_comparison_prompt(data)

    async def event_generator():
        try:
            async for chunk in get_llm_response(
                text,
                stream=True,
                system_prompt="你是驿来特电价策略分析顾问。",
            ):
                # Nothing to send for None or empty chunks.
                if not chunk:
                    continue
                yield _sse_data_event(str(chunk))
        except Exception as exc:
            yield _sse_data_event(str(exc), event="error")
        yield "event: end\ndata: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")