400 lines
15 KiB
Python
400 lines
15 KiB
Python
import numpy as np
|
||
from PIL import Image
|
||
import os
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import aiohttp
|
||
import logging
|
||
import base64
|
||
import cv2
|
||
from openai import AsyncOpenAI, BadRequestError
|
||
from Config.Config import (
|
||
ALY_LLM_API_KEY, VL_MODEL_NAME, VL_MODEL_NAME_AD
|
||
)
|
||
from Apps.AiTeJiYiChong.Config.Setting import (
|
||
SAFE_EXCLUDE_RATIO, FALLBACK_WIDTH, FALLBACK_HEIGHT,
|
||
BOTTOM_SAFE_EXCLUDE_RATIO
|
||
)
|
||
from Util.PaddleOCRKit import get_ocr_kit
|
||
from Apps.AiTeJiYiChong import Kit
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ReadImageKit:
    """Helpers that extract charging-station data from app screenshots.

    The methods visible in this file either send a screenshot to a
    vision-language model through an OpenAI-compatible async client, or run
    local OCR on card regions cropped from the screenshot.
    """

    # Shared async client, built once when the class body is executed and
    # pointed at the DashScope OpenAI-compatible endpoint.
    _client = AsyncOpenAI(
        api_key=ALY_LLM_API_KEY,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    # Generic fallback device info, used only when device info cannot be
    # obtained dynamically.
    _FALLBACK_DEVICE_INFO = {
        "displayWidth": FALLBACK_WIDTH,
        "displayHeight": FALLBACK_HEIGHT,
        "productName": "generic"
    }
|
||
|
||
# Prompt for the station-list screenshot. It asks the model for a bare JSON
# array describing every station card outlined by a green box, with the fields
# b_use, station_name, price, piles, parking, distance, bounds, bounds_norm,
# station_name_bounds and station_name_bounds_norm.
# The text is sent to the model verbatim, so any edit changes model behaviour.
_prompt = (
    "仅输出JSON数组(不含任何说明文字),按从左到右、从上到下的顺序识别图片中由【绿色方框】标识的充电站区域。识别规则如下:\n"
    "1. 必须是图中用绿色实线方框圈出的区域。\n"
    "2. 每一个卡片区域必须同时具备以下所有要素,否则严禁识别:\n"
    " - 场站名称 (station_name);\n"
    " - 距离信息 (distance, 例如 '2.82km' 或 '90m'),通常位于卡片右侧蓝色胶囊区域内;\n"
    " - 金额/电费 (price,例如 '1.2500'),通常以红色字体显示;\n"
    " - 充电枪信息 (piles,包含'快'或'慢'的类型、总枪数和空闲枪数,例如 '快 4/4')。\n"
    "3. 如果绿色方框内缺少上述任何一项要素,说明它不是真正的场站卡片,请直接跳过。\n"
    "\n"
    "JSON对象字段要求:\n"
    "1. b_use: 状态标识(1或0)。如果场站名称为灰色或带有“暂停使用”等标签,则为0,否则为1。\n"
    "2. station_name: 场站名称;\n"
    "3. price: 一度电的价格(数字,如 1.2500);\n"
    "4. piles: 充电枪列表 [{type: '快', free: 4, total: 4}];\n"
    "5. parking: 停车费用描述(通常在蓝色'P'图标后,例如 '免费停车三小时');\n"
    "6. distance: 距离信息字符串(例如 '2.82km' 或 '90m');\n"
    "7. bounds: {x1,y1,x2,y2} 区域像素坐标(0-1000);\n"
    "8. bounds_norm: {left,top,right,bottom} 归一化坐标(0-1);\n"
    "9. station_name_bounds: 场站名称文字区域坐标 {x1,y1,x2,y2}(0-1000);\n"
    "10. station_name_bounds_norm: 场站名称文字归一化坐标(0-1)。\n"
    "\n"
    "重要约束:\n"
    "A. 严禁识别未被绿色方框圈出的区域。如顶部的“长春市”选择、顶部的搜索框、以及中间的“推荐站点”等标签。\n"
    "B. 真正的场站卡片在绿色方框内包含:场站名称、金额(红色文字)、距离(右侧蓝色背景内)、充电枪状态(绿色或蓝色徽章,格式为 闲x/x 或 x/x)。\n"
    "C. 严禁将顶部的功能图标(如我的订单、收藏站点等)误认为场站卡片。\n"
    "\n"
    "严格返回纯JSON格式。"
)
|
||
|
||
@staticmethod
def _extract_json(text: str) -> str:
    """Return the first decodable JSON object or array embedded in *text*.

    Lines that start a Markdown code fence are dropped first. Every ``{`` or
    ``[`` is then tried in order as the start of a JSON value, and the text
    of the first value that decodes is returned. ``"[]"`` is returned when
    *text* is empty or holds no decodable object/array.
    """
    if not text:
        return "[]"

    body = text.strip()
    if "```" in body:
        kept = [ln for ln in body.splitlines() if not ln.strip().startswith("```")]
        body = "\n".join(kept).strip()

    parser = json.JSONDecoder()
    for begin, ch in enumerate(body):
        if ch not in "{[":
            continue
        tail = body[begin:]
        try:
            _, stop = parser.raw_decode(tail)
        except json.JSONDecodeError:
            # Not a valid value at this opener; try the next one.
            continue
        return tail[:stop]

    return "[]"
|
||
|
||
# Prompt for the station-detail screenshot. It asks the model for a bare JSON
# object with two fields: station_name and address.
# The text is sent to the model verbatim, so any edit changes model behaviour.
_prompt_detail = (
    "仅输出JSON对象(不含任何说明文字),识别充电站详情图片中的以下信息:\n"
    "1. station_name: 场站名称;\n"
    "2. address: 场站完整地址(通常在定位图标旁)。\n"
    "\n"
    "特别说明:\n"
    "- 严禁输出 Markdown 代码块标签,严格返回纯JSON对象。"
)
|
||
|
||
@staticmethod
def _to_minutes(t_str: str) -> int:
    """Convert an ``HH:MM`` string to a minute count.

    Args:
        t_str: Time text such as ``"08:30"``. Values produced by a model may
            be missing or of another type, so anything that is not a string
            is tolerated.

    Returns:
        ``hours * 60 + minutes``, or 0 when the value is not a string, has no
        colon, or does not consist of exactly two integer parts.
    """
    # The isinstance guard keeps numbers/None from raising TypeError on the
    # containment test below.
    if not isinstance(t_str, str) or ":" not in t_str:
        return 0
    try:
        hours, minutes = map(int, t_str.split(":"))
    except ValueError:
        # Non-numeric parts, or a part count other than two (e.g. HH:MM:SS).
        return 0
    return hours * 60 + minutes
|
||
|
||
@staticmethod
def _fmt(t: int) -> str:
    """Format a minute count as zero-padded ``HH:MM``."""
    hours, minutes = divmod(t, 60)
    return f"{hours:02d}:{minutes:02d}"
|
||
|
||
@staticmethod
def expand_schedule_to_24h(rows: list) -> list:
    """Normalize a list of price periods into 24 whole-hour slots.

    Args:
        rows: Period dicts with ``"start"`` / ``"end"`` (``HH:MM``) and
            ``"price"``. ``None`` or non-dict entries are ignored.

    Returns:
        A list of 24 dicts ``{"start", "end", "price"}`` covering
        ``00:00-01:00`` through ``23:00-24:00``. ``price`` is taken from the
        first period (ordered by start, then end) that overlaps the hour, or
        is ``None`` when no period overlaps it.

    Period handling:
        * an end of ``00:00`` after a later start means end of day;
        * ``00:00-00:00`` with both times present means the whole day;
        * a period whose end is earlier than its start crosses midnight and
          is split into an evening part and a morning part;
        * zero-length or otherwise unusable periods are skipped.
    """
    day = 24 * 60
    spans = []  # (start_minute, end_minute, price)

    for row in rows or []:
        if not isinstance(row, dict):
            continue
        raw_start, raw_end = row.get("start"), row.get("end")
        start = ReadImageKit._to_minutes(raw_start)
        end = ReadImageKit._to_minutes(raw_end)
        price = row.get("price")

        if start == 0 and end == 0:
            # Missing times also parse to 0; only an explicit pair of times
            # is treated as a whole-day period.
            if not (raw_start and raw_end):
                continue
            pieces = [(0, day)]
        elif end == 0 and start > 0:
            pieces = [(start, day)]
        elif end > start:
            pieces = [(start, end)]
        elif 0 < end < start:
            # Crosses midnight, e.g. 22:00-06:00.
            pieces = [(start, day), (0, end)]
        else:
            continue

        for lo, hi in pieces:
            lo, hi = max(0, lo), min(day, hi)
            if hi > lo:
                spans.append((lo, hi, price))

    # Sort on the bounds only; prices may be None or of mixed types.
    spans.sort(key=lambda span: (span[0], span[1]))

    slots = []
    for hour in range(24):
        slot_start = hour * 60
        slot_end = slot_start + 60
        # Simple policy: the first period overlapping the hour wins.
        price = next(
            (p for lo, hi, p in spans if min(slot_end, hi) > max(slot_start, lo)),
            None,
        )
        slots.append({
            "start": ReadImageKit._fmt(slot_start),
            "end": ReadImageKit._fmt(slot_end),
            "price": price,
        })
    return slots
|
||
|
||
# Prompt for the price-detail screenshot. It asks the model for the
# time-of-use rows, each with start, end, price, ele_fee and ser_fee.
# NOTE(review): the text asks for a "JSON object" in its first and last lines
# but for a list in between, so the reply shape may vary between the two.
# The text is sent to the model verbatim, so any edit changes model behaviour.
_prompt_price_detail = (
    "仅输出JSON对象(不含任何说明文字),识别充电价格详情图片中的分时段电价信息。图片中通常包含“充电时段”、“单价”、“电费”、“服务费”等列。\n"
    "请识别出所有的时段行,返回一个列表,每个元素包含:\n"
    "1. start: 开始时间 (HH:MM);\n"
    "2. end: 结束时间 (HH:MM);\n"
    "3. price: 总单价 (数字,如 0.6100);\n"
    "4. ele_fee: 电费 (数字);\n"
    "5. ser_fee: 服务费 (数字)。\n"
    "\n"
    "注意:\n"
    "- 只需识别“直流桩”或当前选中的桩型下的数据。\n"
    "- 严禁输出 Markdown 代码块标签,严格返回纯JSON对象。"
)
|
||
|
||
@classmethod
async def get_price_detail_from_image(cls, image_path: str):
    """Recognize time-of-use prices on a price-detail screenshot.

    The image is sent to the VL model together with ``_prompt_price_detail``
    and the reply is normalized to a list.

    Args:
        image_path: Path of the screenshot file.

    Returns:
        A list parsed from the model reply (possibly empty), or ``None`` when
        the file is missing or unreadable, the model call fails, or the reply
        cannot be parsed.
    """
    if not os.path.exists(image_path):
        logger.error(f"Image not found: {image_path}")
        return None

    try:
        # Read inside the try so an unreadable path (directory, permission
        # error, file removed after the check above) also yields None.
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        response = await cls._client.chat.completions.create(
            model=VL_MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": cls._prompt_price_detail},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                        },
                    ],
                }
            ],
            max_tokens=2000,
            temperature=0.01,
        )

        content = response.choices[0].message.content
        logger.info(f"VL Price Detail Response: {content}")
        data = json.loads(cls._extract_json(content))
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error calling VL model for price detail: {e}")
        return None

    if isinstance(data, list):
        return data

    if isinstance(data, dict):
        # Compatibility: the model sometimes wraps the rows in an object.
        for key in ("price_list", "prices", "schedule", "data", "items"):
            if isinstance(data.get(key), list):
                return data[key]

    # A single row object (or any other non-empty value) is wrapped so the
    # caller always receives a list.
    return [data] if data else []
|
||
|
||
@classmethod
async def get_station_detail_from_image(cls, image_path: str):
    """Recognize station name and address on a station-detail screenshot.

    The image is sent to the VL model together with ``_prompt_detail``, which
    asks for ``station_name`` and ``address`` only.

    Args:
        image_path: Path of the screenshot file.

    Returns:
        The dict parsed from the model reply, or ``None`` when the file is
        missing or unreadable, the model call fails, or the reply contains no
        JSON object.
    """
    if not os.path.exists(image_path):
        logger.error(f"Image not found: {image_path}")
        return None

    try:
        # Read inside the try so an unreadable path also yields None.
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        response = await cls._client.chat.completions.create(
            model=VL_MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": cls._prompt_detail},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                        },
                    ],
                }
            ],
            max_tokens=1000,
            temperature=0.01,
        )

        content = response.choices[0].message.content
        logger.info(f"VL Detail Response: {content}")
        detail = json.loads(cls._extract_json(content))
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error calling VL model for detail: {e}")
        return None

    # _extract_json falls back to "[]" when nothing decodes, and the model
    # may wrap the object in an array; normalize both to dict-or-None.
    if isinstance(detail, list):
        detail = next((item for item in detail if isinstance(item, dict)), None)
    return detail if isinstance(detail, dict) else None
|
||
|
||
@classmethod
async def get_stations_hybrid(cls, image_path: str, device_info=None):
    """Hybrid recognition: graphical card slicing plus local PaddleOCR.

    Args:
        image_path: Path of the list-page screenshot.
        device_info: Accepted for signature parity with the VL-based method;
            not used here.

    Returns:
        A list of OCR result dicts, one per card that yielded a
        ``station_name``. Each is extended with ``uia_center_x`` /
        ``uia_center_y`` (the card's click point), ``rect`` (pixel rectangle)
        and ``bounds`` (the rectangle scaled to a 0-1000 space). An empty
        list is returned when the image is missing or no card is detected.
    """
    if not os.path.exists(image_path):
        logger.error(f"Image not found: {image_path}")
        return []

    # Step 1: locate card regions with the graphical algorithm in Kit.
    # Per the original author's note, crop_cards_from_image also writes
    # .json, _flag.jpg and _vl.jpg debug files; only its return value is
    # used here.
    json_data = Kit.crop_cards_from_image(image_path, save_debug=True)
    if not json_data or not json_data.get("cards"):
        logger.warning("No cards detected by graphical slicing.")
        return []

    # convert() returns an independent in-memory image, so the file handle
    # can be released right away.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    width, height = img.size
    ocr_kit = get_ocr_kit()

    final_stations = []

    # Step 2: OCR each card region. Cards are processed serially to keep the
    # log output ordered; asyncio.to_thread could parallelize this if needed.
    for card in json_data["cards"]:
        rect = card["rect"]  # [x1, y1, x2, y2]
        patch = img.crop((rect[0], rect[1], rect[2], rect[3]))

        # PaddleOCR takes an ndarray in BGR channel order.
        patch_cv = cv2.cvtColor(np.array(patch), cv2.COLOR_RGB2BGR)

        logger.info(f"正在识别卡片 {card['id']}: {rect}")
        res = ocr_kit.recognize(patch_cv)

        if not (res and res.get("station_name")):
            logger.warning(f"卡片 {card['id']} 识别失败或无名称")
            continue

        # Attach the click point and the original pixel rectangle.
        res["uia_center_x"] = card["click_point"][0]
        res["uia_center_y"] = card["click_point"][1]
        res["rect"] = rect

        # Scale the rectangle into the 0-1000 space used by the VL mode.
        # NOTE(review): the VL prompt describes bounds as an {x1,y1,x2,y2}
        # object while this is a list; confirm consumers accept both.
        res["bounds"] = [
            int(rect[0] * 1000 / width),
            int(rect[1] * 1000 / height),
            int(rect[2] * 1000 / width),
            int(rect[3] * 1000 / height),
        ]

        final_stations.append(res)
        logger.info(f"卡片 {card['id']} 识别成功: {res['station_name']}")

    return final_stations
|
||
|
||
@classmethod
async def get_stations_from_image(cls, image_path: str, device_info=None):
    """Recognize the charging-station list on a screenshot with the VL model.

    The image is sent to the model together with ``_prompt``.

    Args:
        image_path: Path of the list-page screenshot.
        device_info: Accepted for interface compatibility; not used by this
            method.

    Returns:
        A list of station dicts parsed from the model reply. An empty list is
        returned when the file is missing or unreadable, the model call
        fails, or the reply cannot be parsed.
    """
    if not os.path.exists(image_path):
        logger.error(f"Image not found: {image_path}")
        return []

    try:
        # Encode the image as Base64. Reading inside the try means an
        # unreadable path also yields [].
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

        response = await cls._client.chat.completions.create(
            model=VL_MODEL_NAME,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": cls._prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
                        },
                    ],
                }
            ],
            max_tokens=2000,
            temperature=0.01,
        )

        content = response.choices[0].message.content
        logger.info(f"VL Model Response: {content}")
        stations = json.loads(cls._extract_json(content))
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error calling VL model: {e}")
        return []

    # The prompt asks for an array, but the model may answer with an object.
    # Keep the promised list return type. Bounds are passed through as given.
    if isinstance(stations, dict):
        if "station_name" in stations:
            # A single station object (checked first: it holds a piles list).
            return [stations]
        return next((v for v in stations.values() if isinstance(v, list)), [])
    return stations
|