261 lines
9.3 KiB
Python
261 lines
9.3 KiB
Python
# coding=utf-8
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
# Make the directory three levels above this file importable so the absolute
# project imports that follow (Apps.*, DbKit.*, Config.*, Model.*) resolve
# even when this module is run directly. Appended only once.
project_root = os.path.abspath(__file__)
for _level in range(3):
    project_root = os.path.dirname(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)
|
|
|
|
from Apps.AiTeJiYiChong.ReadImageKit import ReadImageKit
|
|
from DbKit.Db import Db
|
|
from Config.Config import DB_URL
|
|
from Model.StationProfile import StationProfile
|
|
from Model.StationStatus import StationStatus
|
|
from Model.StationPriceSchedule import StationPriceSchedule
|
|
|
|
# Configure logging. Note: basicConfig runs at import time and therefore
# configures the root logger for the whole process.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
class AiTeJiYiChongService:
    """Persist screenshot-derived charging-station data for one operator.

    Each ``process_*`` method either extracts data from a screenshot via
    ``ReadImageKit`` or accepts already-processed data. It then saves rows
    through the ``StationProfile`` / ``StationStatus`` /
    ``StationPriceSchedule`` models inside a ``Db`` session. Stations are
    keyed by the MD5 hex digest of their name (see ``get_hash``).
    """

    def __init__(self):
        self.db = Db(db_url=DB_URL)
        self.station_profile_model = StationProfile()
        self.station_status_model = StationStatus()
        self.station_price_schedule_model = StationPriceSchedule()
        # Operator label written on every profile row this service saves.
        self.operator = "艾特吉易充"

    def generate_id(self):
        """Return a new random UUID4 string, used as a row id."""
        return str(uuid.uuid4())

    def get_hash(self, s: str) -> str:
        """Return the hex MD5 digest of *s* encoded as UTF-8.

        Used as the deterministic ``station_hash`` key derived from a
        station name.
        """
        return hashlib.md5(s.encode('utf-8')).hexdigest()

    async def init_db(self):
        """Initialise the database (delegates to ``Db.init_db``)."""
        await self.db.init_db()

    async def close_db(self):
        """Close the database (delegates to ``Db.close``)."""
        await self.db.close()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _vl_image_path(image_path):
        """Return the ``_vl.jpg`` sibling of a ``.jpg`` path, or None.

        Only the trailing ``.jpg`` extension is rewritten. A ``.jpg`` that
        appears earlier in the path (e.g. in a directory name) is left
        alone. Paths that do not end in ``.jpg`` yield None.
        """
        root, ext = os.path.splitext(image_path)
        if ext != ".jpg":
            return None
        return root + "_vl.jpg"

    @staticmethod
    def _sum_piles(piles_data):
        """Return ``(total, free)`` summed over a list of pile dicts.

        Non-list input yields ``(0, 0)``. Each entry's ``total`` / ``free``
        value goes through ``int()``, so a missing key counts as 0.
        """
        total, free = 0, 0
        if isinstance(piles_data, list):
            # NOTE(review): an unparseable value (e.g. None or non-numeric OCR
            # text) raises here and aborts the whole batch before commit;
            # consider tolerant parsing if that is observed in practice.
            for pile in piles_data:
                total += int(pile.get("total", 0))
                free += int(pile.get("free", 0))
        return total, free

    async def _save_price_schedule(self, station_name, hourly_schedule):
        """Save one price-schedule row (fresh id) for *station_name*, then commit."""
        station_hash = self.get_hash(station_name)
        now = datetime.now()
        async with await self.db.get_session() as session:
            await self.station_price_schedule_model.save(
                session=session,
                id=self.generate_id(),
                station_hash=station_hash,
                schedule_json=hourly_schedule,
                valid_start_time=now,
            )
            await session.commit()

    async def _save_station_list(self, station_list):
        """Save a profile row and a status row for every named station.

        Entries without a ``station_name`` are skipped. Each saved entry is
        mutated in place with ``station_hash``, ``profile_id``,
        ``valid_start_time`` (ISO string) and ``status_id``. The saved
        entries are returned in input order. All rows go through a single
        session with one commit after the loop.
        """
        processed_stations = []
        async with await self.db.get_session() as session:
            for station in station_list:
                name = station.get("station_name")
                if not name:
                    continue

                station_hash = self.get_hash(name)
                now = datetime.now()
                station["station_hash"] = station_hash

                # 1. Profile row
                profile_id = self.generate_id()
                await self.station_profile_model.save(
                    session=session,
                    id=profile_id,
                    station_hash=station_hash,
                    operator=self.operator,
                    station_name=name,
                    valid_start_time=now,
                )
                station["profile_id"] = profile_id
                station["valid_start_time"] = now.isoformat()

                # 2. Status row (pile counts and price)
                status_id = self.generate_id()
                piles_data = station.get("piles")
                total, free = self._sum_piles(piles_data)
                price = station.get("price")
                await self.station_status_model.save(
                    session=session,
                    id=status_id,
                    station_hash=station_hash,
                    total_piles=total,
                    free_piles=free,
                    piles_detail_json=piles_data,
                    current_price=float(price) if price else 0.0,
                    parking_info=station.get("parking", ""),
                    distance=station.get("distance", ""),
                    valid_start_time=now,
                )
                station["status_id"] = status_id

                processed_stations.append(station)

            await session.commit()

        return processed_stations

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    async def process_price_detail_data(self, station_name, hourly_schedule) -> bool:
        """Save an already-processed hourly price schedule for a station.

        Returns False, writing nothing, when either argument is empty.
        Returns True once the schedule has been saved and committed.
        """
        if not station_name or not hourly_schedule:
            return False
        await self._save_price_schedule(station_name, hourly_schedule)
        return True

    async def process_price_detail(self, image_path, station_name) -> list:
        """Recognise a price-detail (third-level page) screenshot and save it.

        Returns the schedule produced by
        ``ReadImageKit.expand_schedule_to_24h``. Returns None when
        *station_name* is empty or no prices were recognised in the image.
        """
        if not station_name:
            # There is no station key to save under (get_hash cannot hash
            # None), so bail out before running image recognition.
            return None

        prices = await ReadImageKit.get_price_detail_from_image(image_path)
        if not prices:
            return None

        # Expand the recognised time-of-use prices into whole-hour slots over 24h.
        hourly_schedule = ReadImageKit.expand_schedule_to_24h(prices)
        await self._save_price_schedule(station_name, hourly_schedule)

        logger.info(f"三级页面价格详情处理完成: {station_name}, 共 {len(hourly_schedule)} 个小时段")
        return hourly_schedule

    async def process_station_detail(self, image_path, station_name=None) -> dict:
        """Recognise a station-detail screenshot and save a profile row.

        An explicit *station_name* takes precedence over the name recognised
        in the image. Returns the recognised ``detail`` dict, or None when
        nothing was recognised or no station name is available.
        """
        detail = await ReadImageKit.get_station_detail_from_image(image_path)
        if not detail:
            return None

        name = station_name or detail.get("station_name")
        if not name:
            return None

        station_hash = self.get_hash(name)
        now = datetime.now()

        async with await self.db.get_session() as session:
            # Save a profile row (fresh id) carrying the recognised address.
            await self.station_profile_model.save(
                session=session,
                id=self.generate_id(),
                station_hash=station_hash,
                operator=self.operator,
                station_name=name,
                address=detail.get("address"),
                valid_start_time=now,
            )
            await session.commit()

        logger.info(f"场站详情处理完成: {name}")
        return detail

    async def process_station_list_hybrid(self, image_path, device_info=None) -> list:
        """Process a station-list screenshot in hybrid mode (image slicing + local OCR).

        Returns the saved station dicts (see ``_save_station_list``). Returns
        ``[]`` when nothing was recognised.
        """
        station_list = await ReadImageKit.get_stations_hybrid(image_path, device_info=device_info)
        if not station_list:
            return []
        return await self._save_station_list(station_list)

    async def process_station_list_vl(self, image_path, device_info=None) -> list:
        """Process a station-list screenshot in VL mode.

        Prefers the ``_vl.jpg`` variant of a ``.jpg`` screenshot (the image
        with green boxes) when that file exists. Returns the saved station
        dicts, or ``[]`` when nothing was recognised.
        """
        vl_img_path = self._vl_image_path(image_path)
        if vl_img_path is not None and os.path.exists(vl_img_path):
            logger.info(f"使用带绿框的图片进行识别: {vl_img_path}")
            image_to_process = vl_img_path
        else:
            image_to_process = image_path

        station_list = await ReadImageKit.get_stations_from_image(image_to_process, device_info=device_info)
        if not station_list:
            return []
        return await self._save_station_list(station_list)
|