Files
aiData/Controller/YltAnalyticsController.py
HuangHai 5c0a1a67ac 'commit'
2026-01-18 16:02:40 +08:00

454 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import asyncio
import json
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from DbKit.Db import Db
from Config.Config import DB_URL
from Util.LlmUtil import get_llm_response
from Model.YltAnalyticsModel import (
StationBase,
CompetitorStation,
GeoCompetitionResponse,
GeoCompetitionSummary,
PriceSeries,
PriceComparisonResponse,
PriceComparisonSummary,
)
# Router on which every endpoint in this module is registered via @router.get(...).
router = APIRouter()
# Module-level database handle, built from the configured DB_URL and shared by the query helpers below.
db = Db(db_url=DB_URL)
async def init_db() -> None:
    """Initialise the module-level ``db`` handle by awaiting ``db.init_db()``."""
    await db.init_db()
async def close_db() -> None:
    """Shut down the module-level ``db`` handle by awaiting ``db.close()``."""
    await db.close()
@router.get("/api/ylt/stations", response_model=List[StationBase])
async def list_ylt_stations(q: Optional[str] = None):
    """List current stations whose operator is '驿来特', ordered by station name.

    Args:
        q: Optional keyword. When truthy, only stations whose name or address
            contains it (SQL ``LIKE '%q%'``) are returned.

    Returns:
        A list of ``StationBase`` objects, one per row returned by the query.
    """
    # SQL text is kept flush-left so the query string itself is unchanged.
    query = """
SELECT
p.station_hash,
p.operator,
p.station_name,
p.address,
p.coord_x,
p.coord_y,
s.current_price
FROM t_station_profile_scd p
LEFT JOIN t_station_status_scd s
ON p.station_hash = s.station_hash AND s.is_current = 1
WHERE p.is_current = 1
AND p.operator = '驿来特'
"""
    bind: Dict[str, Any] = {}
    if q:
        # The keyword is passed as a bound parameter, never interpolated into the SQL.
        query += " AND (p.station_name LIKE :kw OR p.address LIKE :kw)"
        bind["kw"] = f"%{q}%"
    query += " ORDER BY p.station_name"

    records = await db.find(query, bind)
    columns = (
        "station_hash",
        "operator",
        "station_name",
        "address",
        "coord_x",
        "coord_y",
        "current_price",
    )
    return [
        StationBase(**{column: record.get(column) for column in columns})
        for record in records
    ]
def haversine_km(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the great-circle distance in kilometres between two points.

    Uses the haversine formula on a sphere of radius 6371 km.

    Args:
        lon1: Longitude of the first point, in degrees.
        lat1: Latitude of the first point, in degrees.
        lon2: Longitude of the second point, in degrees.
        lat2: Latitude of the second point, in degrees.

    Returns:
        The distance in kilometres (0.0 for identical points).
    """
    earth_radius_km = 6371.0
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lon2 - lon1)
    a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make sqrt(1 - a) raise ValueError; clamp it.
    a = min(1.0, max(0.0, a))
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return earth_radius_km * c
async def fetch_current_stations() -> List[dict]:
    """Fetch every current station profile row with its current price.

    Each profile row flagged ``is_current = 1`` is LEFT JOINed with the
    matching current status row, so ``current_price`` is NULL when no such
    status row exists.

    Returns:
        Whatever ``db.find`` returns for the query (used as a list of
        dict-like rows by the callers in this module).
    """
    # SQL text is kept flush-left so the query string itself is unchanged.
    query = """
SELECT
p.station_hash,
p.operator,
p.station_name,
p.address,
p.coord_x,
p.coord_y,
s.current_price
FROM t_station_profile_scd p
LEFT JOIN t_station_status_scd s
ON p.station_hash = s.station_hash AND s.is_current = 1
WHERE p.is_current = 1
"""
    no_params: dict = {}
    return await db.find(query, no_params)
async def build_geo_competition(station_hash: str, radius_km: float = 3.0) -> GeoCompetitionResponse:
    """Collect the competitor stations around one '驿来特' station and compare prices.

    Args:
        station_hash: Hash of the base station; it must belong to operator '驿来特'.
        radius_km: Maximum haversine distance (km) for a station to count as a
            competitor.

    Returns:
        A ``GeoCompetitionResponse`` holding the base station, the competitors
        inside the radius, the min/max competitor price, and how many priced
        competitors are cheaper than, equal to, or more expensive than the
        base station.

    Raises:
        HTTPException: 404 when there is no station data or no matching base
            station; 400 when the base station has no coordinates.
    """
    stations = await fetch_current_stations()
    if not stations:
        raise HTTPException(status_code=404, detail="no station data")

    own_operator = "驿来特"
    anchor = next(
        (
            row
            for row in stations
            if row.get("station_hash") == station_hash and row.get("operator") == own_operator
        ),
        None,
    )
    if anchor is None:
        raise HTTPException(status_code=404, detail="base station not found for 驿来特")

    anchor_lon, anchor_lat = anchor.get("coord_x"), anchor.get("coord_y")
    if anchor_lon is None or anchor_lat is None:
        raise HTTPException(status_code=400, detail="base station has no coordinates")

    own_price = anchor.get("current_price")
    nearby: List[CompetitorStation] = []
    cheaper_total = same_total = pricier_total = 0
    lowest: Optional[float] = None
    highest: Optional[float] = None

    for row in stations:
        # Stations of the same operator are never competitors.
        if row.get("operator") == own_operator:
            continue
        lon, lat = row.get("coord_x"), row.get("coord_y")
        if lon is None or lat is None:
            continue
        distance = haversine_km(anchor_lon, anchor_lat, lon, lat)
        if distance > radius_km:
            continue

        rival_price = row.get("current_price")
        nearby.append(
            CompetitorStation(
                station_hash=row.get("station_hash"),
                operator=row.get("operator"),
                station_name=row.get("station_name"),
                distance_km=round(distance, 3),
                current_price=rival_price,
            )
        )

        # Unpriced competitors are listed but excluded from all statistics.
        if rival_price is None:
            continue
        if lowest is None or rival_price < lowest:
            lowest = rival_price
        if highest is None or rival_price > highest:
            highest = rival_price

        # Without a base price there is nothing to compare against.
        if own_price is None:
            continue
        if rival_price < own_price:
            cheaper_total += 1
        elif rival_price > own_price:
            pricier_total += 1
        else:
            same_total += 1

    return GeoCompetitionResponse(
        base_station=StationBase(
            station_hash=anchor.get("station_hash"),
            operator=anchor.get("operator"),
            station_name=anchor.get("station_name"),
            address=anchor.get("address"),
            coord_x=anchor_lon,
            coord_y=anchor_lat,
            current_price=own_price,
        ),
        competitors=nearby,
        ylt_price=own_price,
        min_competitor_price=lowest,
        max_competitor_price=highest,
        cheaper_count=cheaper_total,
        same_count=same_total,
        more_expensive_count=pricier_total,
    )
async def fetch_station_schedule_json(station_hash: str) -> Optional[str]:
    """Load the current price schedule of a station as a JSON string.

    Args:
        station_hash: Hash of the station whose schedule is wanted.

    Returns:
        The ``schedule_json`` value of the most recent current row (by
        ``valid_start_time``). A string value is returned unchanged; any other
        non-null value is serialised with ``json.dumps``. ``None`` is returned
        when there is no row, the value is null, or serialisation fails.
    """
    # SQL text is kept flush-left so the query string itself is unchanged.
    query = """
SELECT schedule_json
FROM t_station_price_schedule_scd
WHERE station_hash = :h AND is_current = 1
ORDER BY valid_start_time DESC
LIMIT 1
"""
    found = await db.find(query, {"h": station_hash})
    if not found:
        return None

    raw = found[0].get("schedule_json")
    if raw is None or isinstance(raw, str):
        return raw

    # The column value is not a string here; re-serialise it on a best-effort basis.
    try:
        return json.dumps(raw, ensure_ascii=False)
    except Exception:
        return None
def extract_price_from_item(item: Dict[str, Any]) -> Optional[float]:
    """Extract a per-period price from one schedule entry.

    Lookup order:
      1. The first of ``price``, ``price_kwh``, ``priceKwh``, ``total_price``,
         ``totalPrice`` that holds a usable number.
      2. Otherwise the sum of the first electricity/service pair in which both
         values are usable numbers: ``elec_price`` + ``service_price``,
         ``electric_fee_kwh`` + ``service_fee_kwh``, ``ele_fee`` + ``ser_fee``.

    A usable number is a finite ``int`` or ``float``. Booleans are rejected
    (``bool`` is an ``int`` subclass, so ``True`` would otherwise be read as a
    price of 1.0), as are NaN and infinities.

    Args:
        item: One schedule entry; anything that is not a dict yields ``None``.

    Returns:
        The price as a float, or ``None`` when no usable price is present.
    """
    if not isinstance(item, dict):
        return None

    def _as_number(value: Any) -> Optional[float]:
        # Reject bools explicitly before the int/float check.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        number = float(value)
        return number if math.isfinite(number) else None

    for key in ("price", "price_kwh", "priceKwh", "total_price", "totalPrice"):
        number = _as_number(item.get(key))
        if number is not None:
            return number

    component_pairs = (
        ("elec_price", "service_price"),
        ("electric_fee_kwh", "service_fee_kwh"),
        ("ele_fee", "ser_fee"),
    )
    for elec_key, service_key in component_pairs:
        elec = _as_number(item.get(elec_key))
        service = _as_number(item.get(service_key))
        # Both components are required; one alone is not a total price.
        if elec is not None and service is not None:
            return elec + service
    return None
def parse_hour_from_item(item: Dict[str, Any], default_index: int) -> Optional[int]:
    """Work out which hour slot (0-23) a schedule entry belongs to.

    Resolution order:
      1. The hour part of ``item["start"]`` (text before the first ``:``),
         when it parses to 0..23.
      2. Otherwise the hour part of ``item["end"]`` minus one, when the hour
         parses to 1..24.
      3. Otherwise ``default_index``, when it lies in 0..23.

    Args:
        item: One schedule entry. A non-dict entry carries no start/end
            information, so only the ``default_index`` fallback applies.
        default_index: Position of the entry in its list, used as fallback.

    Returns:
        The hour slot, or ``None`` when none of the rules yields one.
    """

    def _leading_hour(value: Any) -> Optional[int]:
        # Only "H:MM"-like strings are considered; anything else is ignored.
        if not (isinstance(value, str) and ":" in value):
            return None
        try:
            return int(value.split(":", 1)[0])
        except ValueError:
            return None

    if isinstance(item, dict):
        start_hour = _leading_hour(item.get("start"))
        if start_hour is not None and 0 <= start_hour <= 23:
            return start_hour
        end_hour = _leading_hour(item.get("end"))
        # An end of "24:00" maps to the last slot, 23.
        if end_hour is not None and 0 < end_hour <= 24:
            return end_hour - 1

    if 0 <= default_index <= 23:
        return default_index
    return None
def extract_hourly_prices(schedule_json_str: str) -> List[Optional[float]]:
    """Turn a schedule JSON string into a 24-slot list of hourly prices.

    Args:
        schedule_json_str: JSON text expected to decode to a list of entries.

    Returns:
        A list of length 24. Slot ``h`` holds the price of the last entry that
        resolved to hour ``h``; slots nobody claimed stay ``None``. Empty
        input, undecodable JSON, or a non-list document yield all ``None``.
    """
    hourly: List[Optional[float]] = [None for _ in range(24)]
    if not schedule_json_str:
        return hourly

    try:
        entries = json.loads(schedule_json_str)
    except Exception:
        return hourly
    if not isinstance(entries, list):
        return hourly

    # NOTE(review): each entry fills exactly one slot; an entry describing a
    # multi-hour period does not populate the hours in between - confirm that
    # this matches the stored schedule format.
    for position, entry in enumerate(entries):
        amount = extract_price_from_item(entry)
        if amount is None:
            continue
        slot = parse_hour_from_item(entry, position)
        if slot is not None and 0 <= slot < 24:
            hourly[slot] = float(amount)
    return hourly
async def build_price_comparison(station_hash: str) -> PriceComparisonResponse:
    """Compare the hourly prices of a '驿来特' station with nearby competitors.

    Competitors come from ``build_geo_competition`` with its default radius.
    Their hourly prices are averaged per operator and hour. For every hour
    where both the base station and at least one operator average have a
    price, the difference "base price - lowest operator average" is recorded.

    Args:
        station_hash: Hash of the base station.

    Returns:
        A ``PriceComparisonResponse`` with hours 0..23, the base station
        series, one averaged series per competitor operator, and the
        min/max/mean of the recorded differences (``None`` when there are none).

    Raises:
        HTTPException: 404 when the base station has no price schedule, plus
            anything raised by ``build_geo_competition``.
    """
    geo = await build_geo_competition(station_hash)
    own_station = geo.base_station

    own_schedule = await fetch_station_schedule_json(own_station.station_hash)
    if own_schedule is None:
        raise HTTPException(status_code=404, detail="no price schedule for YLT station")
    ylt_series = extract_hourly_prices(own_schedule)

    # Per-operator running sums and sample counts, one entry per hour.
    totals: Dict[str, List[float]] = {}
    samples: Dict[str, List[int]] = {}
    # NOTE(review): schedules are fetched one at a time, one awaited query per competitor.
    for rival in geo.competitors:
        rival_schedule = await fetch_station_schedule_json(rival.station_hash)
        if not rival_schedule:
            continue
        rival_hourly = extract_hourly_prices(rival_schedule)
        op_totals = totals.setdefault(rival.operator, [0.0] * 24)
        op_samples = samples.setdefault(rival.operator, [0] * 24)
        for hour in range(24):
            value = rival_hourly[hour]
            if value is not None:
                op_totals[hour] += value
                op_samples[hour] += 1

    competitors_series: List[PriceSeries] = [
        PriceSeries(
            operator=op,
            series=[
                op_totals[hour] / samples[op][hour] if samples[op][hour] > 0 else None
                for hour in range(24)
            ],
        )
        for op, op_totals in totals.items()
    ]

    gaps: List[float] = []
    for hour in range(24):
        own_price = ylt_series[hour]
        if own_price is None:
            continue
        rival_prices = [
            float(entry.series[hour])
            for entry in competitors_series
            if entry.series[hour] is not None
        ]
        if rival_prices:
            gaps.append(own_price - min(rival_prices))

    if gaps:
        min_diff, max_diff, avg_diff = min(gaps), max(gaps), sum(gaps) / len(gaps)
    else:
        min_diff = max_diff = avg_diff = None

    return PriceComparisonResponse(
        hours=list(range(24)),
        ylt=PriceSeries(operator=own_station.operator, series=ylt_series),
        competitors=competitors_series,
        min_diff=min_diff,
        max_diff=max_diff,
        avg_diff=avg_diff,
    )
@router.get("/health")
async def health():
    """Health-check endpoint; always answers ``{"status": "ok"}``."""
    return dict(status="ok")
@router.get("/api/ylt/geo/competitors/{station_hash}", response_model=GeoCompetitionResponse)
async def get_geo_competitors(station_hash: str):
    """Return the geo competition data for ``station_hash``.

    Delegates to ``build_geo_competition`` with its default radius; any
    ``HTTPException`` raised there propagates unchanged.
    """
    competition = await build_geo_competition(station_hash)
    return competition
@router.get("/api/ylt/geo/competitors/{station_hash}/summary", response_model=GeoCompetitionSummary)
async def get_geo_competitors_summary(station_hash: str):
    """Ask the LLM for a short narrative of a station's local price competition.

    The figures from ``build_geo_competition`` are embedded in a Chinese
    prompt and sent to ``get_llm_response`` in non-streaming mode; the chunks
    it yields are concatenated into the summary text.

    Args:
        station_hash: Hash of the base station.

    Returns:
        A ``GeoCompetitionSummary`` whose ``summary`` is the LLM answer.

    Raises:
        HTTPException: Propagated from ``build_geo_competition``.
    """
    data = await build_geo_competition(station_hash)
    base = data.base_station
    # Key order matters: the dict's repr is embedded verbatim in the prompt.
    summary_input = {
        "station_name": base.station_name,
        "operator": base.operator,
        "ylt_price": data.ylt_price,
        "competitor_count": len(data.competitors),
        "cheaper_count": data.cheaper_count,
        "same_count": data.same_count,
        "more_expensive_count": data.more_expensive_count,
        "min_competitor_price": data.min_competitor_price,
        "max_competitor_price": data.max_competitor_price,
    }
    # The prompt mentions a 3 km range, matching the default radius of
    # build_geo_competition used above.
    text = (
        "请作为驿来特价格策略分析顾问用简明中文解释当前场站在3公里范围内的价格竞争情况"
        "给出可操作的价格调整或产品策略建议控制在300字以内。以下是结构化数据\n"
        f"{summary_input}"
    )
    chunks: List[str] = []
    async for chunk in get_llm_response(
        text,
        stream=False,
        system_prompt="你是驿来特电价和选址策略顾问。",
    ):
        # Skip empty placeholders, as the SSE endpoint does; a None chunk
        # would otherwise make "".join raise TypeError.
        if chunk is None:
            continue
        chunks.append(chunk)
    return GeoCompetitionSummary(summary="".join(chunks))
@router.get("/api/ylt/pricing/comparison/{station_hash}", response_model=PriceComparisonResponse)
async def get_price_comparison(station_hash: str):
    """Return the hourly price comparison for ``station_hash``.

    Delegates to ``build_price_comparison``; any ``HTTPException`` raised
    there propagates unchanged.
    """
    comparison = await build_price_comparison(station_hash)
    return comparison
@router.get("/api/ylt/pricing/comparison/{station_hash}/summary", response_model=PriceComparisonSummary)
async def get_price_comparison_summary(station_hash: str):
    """Ask the LLM for a written analysis of the hourly price comparison.

    The result of ``build_price_comparison`` is embedded in a Chinese prompt
    and sent to ``get_llm_response`` in non-streaming mode; the chunks it
    yields are concatenated into the summary text.

    Args:
        station_hash: Hash of the base station.

    Returns:
        A ``PriceComparisonSummary`` whose ``summary`` is the LLM answer.

    Raises:
        HTTPException: Propagated from ``build_price_comparison``.
    """
    data = await build_price_comparison(station_hash)
    # Key order matters: the dict's repr is embedded verbatim in the prompt.
    text_data = {
        "hours": data.hours,
        "ylt_prices": data.ylt.series,
        "competitors": [
            {"operator": s.operator, "series": s.series} for s in data.competitors
        ],
        "min_diff": data.min_diff,
        "max_diff": data.max_diff,
        "avg_diff": data.avg_diff,
    }
    # NOTE(review): the prompt speaks of three competitors, but the number of
    # operator series in the data is not fixed - confirm the wording.
    text = (
        "请作为驿来特价格策略分析顾问,对下列分时电价数据进行比较分析:\n"
        "1) 解释驿来特与三家竞品在一天24小时内的价格差距特征\n"
        "2) 指出在哪些时间段我们明显偏贵、在哪些时间段有优势;\n"
        "3) 给出2到3条可执行的调价或营销策略建议\n"
        "控制在400字以内。数据如下\n"
        f"{text_data}"
    )
    chunks: List[str] = []
    async for chunk in get_llm_response(
        text,
        stream=False,
        system_prompt="你是驿来特电价策略分析顾问。",
    ):
        # Skip empty placeholders, as the SSE endpoint does; a None chunk
        # would otherwise make "".join raise TypeError.
        if chunk is None:
            continue
        chunks.append(chunk)
    return PriceComparisonSummary(summary="".join(chunks))
@router.get("/api/ylt/pricing/comparison/{station_hash}/sse")
async def stream_price_comparison_summary(station_hash: str):
    """Stream the LLM analysis of the hourly price comparison as server-sent events.

    The comparison is built before the response starts, so an
    ``HTTPException`` from ``build_price_comparison`` is still returned as a
    normal error response. Each LLM chunk then becomes one SSE message, and a
    final ``end`` event carrying ``[DONE]`` marks completion.

    Args:
        station_hash: Hash of the base station.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        HTTPException: Propagated from ``build_price_comparison``.
    """
    data = await build_price_comparison(station_hash)
    # Key order matters: the dict's repr is embedded verbatim in the prompt.
    text_data = {
        "hours": data.hours,
        "ylt_prices": data.ylt.series,
        "competitors": [
            {"operator": s.operator, "series": s.series} for s in data.competitors
        ],
        "min_diff": data.min_diff,
        "max_diff": data.max_diff,
        "avg_diff": data.avg_diff,
    }
    # NOTE(review): the prompt speaks of three competitors, but the number of
    # operator series in the data is not fixed - confirm the wording.
    text = (
        "请作为驿来特价格策略分析顾问,对下列分时电价数据进行比较分析:\n"
        "1) 解释驿来特与三家竞品在一天24小时内的价格差距特征\n"
        "2) 指出在哪些时间段我们明显偏贵、在哪些时间段有优势;\n"
        "3) 给出2到3条可执行的调价或营销策略建议\n"
        "控制在400字以内。数据如下\n"
        f"{text_data}"
    )

    async def event_generator():
        async for chunk in get_llm_response(
            text,
            stream=True,
            system_prompt="你是驿来特电价策略分析顾问。",
        ):
            if chunk is None:
                continue
            # A newline inside a chunk would end the "data:" field early, and
            # the remaining text would be dropped by SSE clients. Emit one
            # "data:" line per payload line; the client joins them with "\n".
            payload = f"{chunk}".replace("\r\n", "\n").replace("\r", "\n")
            yield "".join(f"data: {line}\n" for line in payload.split("\n")) + "\n"
        yield "event: end\ndata: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")