Files
aiData/Apps/AiTeJiYiChong/ReadImageKit.py
HuangHai ac79e44282 'commit'
2026-01-12 20:11:18 +08:00

400 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from PIL import Image
import os
import asyncio
import hashlib
import json
import aiohttp
import logging
import base64
import cv2
from openai import AsyncOpenAI, BadRequestError
from Config.Config import (
ALY_LLM_API_KEY, VL_MODEL_NAME, VL_MODEL_NAME_AD
)
from Apps.AiTeJiYiChong.Config.Setting import (
SAFE_EXCLUDE_RATIO, FALLBACK_WIDTH, FALLBACK_HEIGHT,
BOTTOM_SAFE_EXCLUDE_RATIO
)
from Util.PaddleOCRKit import get_ocr_kit
from Apps.AiTeJiYiChong import Kit
logger = logging.getLogger(__name__)
class ReadImageKit:
_client = AsyncOpenAI(
api_key=ALY_LLM_API_KEY,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
# 通用回退设备信息,仅在无法动态获取设备信息时使用
_FALLBACK_DEVICE_INFO = {
"displayWidth": FALLBACK_WIDTH,
"displayHeight": FALLBACK_HEIGHT,
"productName": "generic"
}
_prompt = (
"仅输出JSON数组不含任何说明文字按从左到右、从上到下的顺序识别图片中由【绿色方框】标识的充电站区域。识别规则如下\n"
"1. 必须是图中用绿色实线方框圈出的区域。\n"
"2. 每一个卡片区域必须同时具备以下所有要素,否则严禁识别:\n"
" - 场站名称 (station_name)\n"
" - 距离信息 (distance, 例如 '2.82km''90m'),通常位于卡片右侧蓝色胶囊区域内;\n"
" - 金额/电费 (price例如 '1.2500'),通常以红色字体显示;\n"
" - 充电枪信息 (piles包含''''的类型、总枪数和空闲枪数,例如 '快 4/4')。\n"
"3. 如果绿色方框内缺少上述任何一项要素,说明它不是真正的场站卡片,请直接跳过。\n"
"\n"
"JSON对象字段要求\n"
"1. b_use: 状态标识1或0。如果场站名称为灰色或带有“暂停使用”等标签则为0否则为1。\n"
"2. station_name: 场站名称;\n"
"3. price: 一度电的价格(数字,如 1.2500\n"
"4. piles: 充电枪列表 [{type: '', free: 4, total: 4}]\n"
"5. parking: 停车费用描述(通常在蓝色'P'图标后,例如 '免费停车三小时'\n"
"6. distance: 距离信息字符串(例如 '2.82km''90m'\n"
"7. bounds: {x1,y1,x2,y2} 区域像素坐标0-1000\n"
"8. bounds_norm: {left,top,right,bottom} 归一化坐标(0-1)\n"
"9. station_name_bounds: 场站名称文字区域坐标 {x1,y1,x2,y2}0-1000\n"
"10. station_name_bounds_norm: 场站名称文字归一化坐标(0-1)。\n"
"\n"
"重要约束:\n"
"A. 严禁识别未被绿色方框圈出的区域。如顶部的“长春市”选择、顶部的搜索框、以及中间的“推荐站点”等标签。\n"
"B. 真正的场站卡片在绿色方框内包含:场站名称、金额(红色文字)、距离(右侧蓝色背景内)、充电枪状态(绿色或蓝色徽章,格式为 闲x/x 或 x/x\n"
"C. 严禁将顶部的功能图标(如我的订单、收藏站点等)误认为场站卡片。\n"
"\n"
"严格返回纯JSON格式。"
)
@staticmethod
def _extract_json(text: str) -> str:
if not text:
return "[]"
cleaned = text.strip()
if "```" in cleaned:
lines = []
for line in cleaned.splitlines():
if line.strip().startswith("```"):
continue
lines.append(line)
cleaned = "\n".join(lines).strip()
decoder = json.JSONDecoder()
pos = 0
while pos < len(cleaned):
idx_dict = cleaned.find("{", pos)
idx_list = cleaned.find("[", pos)
candidates = [i for i in (idx_dict, idx_list) if i != -1]
if not candidates:
break
start = min(candidates)
snippet = cleaned[start:]
try:
_, end = decoder.raw_decode(snippet)
return snippet[:end]
except json.JSONDecodeError:
pos = start + 1
continue
return "[]"
_prompt_detail = (
"仅输出JSON对象不含任何说明文字识别充电站详情图片中的以下信息\n"
"1. station_name: 场站名称;\n"
"2. address: 场站完整地址(通常在定位图标旁)。\n"
"\n"
"特别说明:\n"
"- 严禁输出 Markdown 代码块标签严格返回纯JSON对象。"
)
@staticmethod
def _to_minutes(t_str: str) -> int:
"""HH:MM -> 分钟数"""
if not t_str or ":" not in t_str:
return 0
try:
h, m = map(int, t_str.split(":"))
return h * 60 + m
except:
return 0
@staticmethod
def _fmt(t: int) -> str:
"""分钟数 -> HH:MM"""
h = t // 60
m = t % 60
return f"{h:02d}:{m:02d}"
@staticmethod
def expand_schedule_to_24h(rows: list) -> list:
"""
将时段列表规整为全天24个整点小时段
"""
# 预处理:转换为分钟区间
intervals = []
for r in rows:
s = ReadImageKit._to_minutes(r.get("start"))
e = ReadImageKit._to_minutes(r.get("end"))
if e <= s and e != 0: # 处理 00:00-00:00 这种可能
if e == 0: e = 1440
else: continue
if e == 0 and s > 0: e = 1440
s = max(0, s)
e = min(1440, e)
intervals.append({
"s": s, "e": e,
"price": r.get("price")
})
# 排序
intervals.sort(key=lambda x: (x["s"], x["e"]))
result = []
for h in range(24):
hs = h * 60
he = (h + 1) * 60
best_price = None
# 找到覆盖当前小时段的价格
for it in intervals:
# 计算重叠
overlap_s = max(hs, it["s"])
overlap_e = min(he, it["e"])
if overlap_e > overlap_s:
best_price = it["price"]
break # 简单处理,取第一个覆盖的
result.append({
"start": ReadImageKit._fmt(hs),
"end": ReadImageKit._fmt(he),
"price": best_price
})
return result
_prompt_price_detail = (
"仅输出JSON对象不含任何说明文字识别充电价格详情图片中的分时段电价信息。图片中通常包含“充电时段”、“单价”、“电费”、“服务费”等列。\n"
"请识别出所有的时段行,返回一个列表,每个元素包含:\n"
"1. start: 开始时间 (HH:MM)\n"
"2. end: 结束时间 (HH:MM)\n"
"3. price: 总单价 (数字,如 0.6100)\n"
"4. ele_fee: 电费 (数字)\n"
"5. ser_fee: 服务费 (数字)。\n"
"\n"
"注意:\n"
"- 只需识别“直流桩”或当前选中的桩型下的数据。\n"
"- 严禁输出 Markdown 代码块标签严格返回纯JSON对象。"
)
@classmethod
async def get_price_detail_from_image(cls, image_path: str):
"""
使用 VL 模型从三级价格详情页截图中识别分时段电价
"""
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return None
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
try:
response = await cls._client.chat.completions.create(
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": cls._prompt_price_detail},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_tokens=2000,
temperature=0.01
)
content = response.choices[0].message.content
logger.info(f"VL Price Detail Response: {content}")
json_str = cls._extract_json(content)
data = json.loads(json_str)
# 兼容性处理:如果返回的是对象且包含列表字段,提取列表
if isinstance(data, dict):
for key in ["price_list", "prices", "schedule", "data", "items"]:
if key in data and isinstance(data[key], list):
return data[key]
# 如果字典本身就是一个价格项(包含 start 和 price将其包装成列表
if "start" in data and ("price" in data or "total_price" in data):
return [data]
# 如果是列表,直接返回
if isinstance(data, list):
return data
# 兜底:确保返回的是列表
return [data] if data else []
except Exception as e:
logger.error(f"Error calling VL model for price detail: {e}")
return None
@classmethod
async def get_station_detail_from_image(cls, image_path: str):
"""
使用 VL 模型从详情页截图中识别地址、名称和电价
"""
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return None
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
try:
response = await cls._client.chat.completions.create(
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": cls._prompt_detail},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_tokens=1000,
temperature=0.01
)
content = response.choices[0].message.content
logger.info(f"VL Detail Response: {content}")
json_str = cls._extract_json(content)
detail = json.loads(json_str)
return detail
except Exception as e:
logger.error(f"Error calling VL model for detail: {e}")
return None
@classmethod
async def get_stations_hybrid(cls, image_path: str, device_info=None):
"""
混合识别模式:图形学切片 + 本地 PaddleOCR 识别
"""
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return []
# 1. 使用 Kit 中的图形学算法识别卡片区域
# Kit.crop_cards_from_image 会生成 .json, _flag.jpg, _vl.jpg
# 我们主要需要它返回的 json_data
json_data = Kit.crop_cards_from_image(image_path, save_debug=True)
if not json_data or not json_data.get("cards"):
logger.warning("No cards detected by graphical slicing.")
return []
img = Image.open(image_path).convert("RGB")
ocr_kit = get_ocr_kit()
final_stations = []
# 2. 对每个卡片区域进行 OCR 识别
# 注意PaddleOCR 识别过程较快,且通常不涉及网络请求,可以根据需要选择并行或串行
# 这里使用串行以保证日志输出整齐,如果追求极致性能可改用 asyncio.to_thread 并行
for card in json_data["cards"]:
rect = card["rect"] # [x1, y1, x2, y2]
# 裁剪卡片
patch = img.crop((rect[0], rect[1], rect[2], rect[3]))
# 转换为 ndarray 供 PaddleOCR 使用
patch_cv = cv2.cvtColor(np.array(patch), cv2.COLOR_RGB2BGR)
# OCR 识别
logger.info(f"正在识别卡片 {card['id']}: {rect}")
res = ocr_kit.recognize(patch_cv)
if res and res.get("station_name"):
# 注入点击坐标和原始区域信息
res["uia_center_x"] = card["click_point"][0]
res["uia_center_y"] = card["click_point"][1]
res["rect"] = rect
# 转换 bounds 到 0-1000 空间(保持与 VL 模式兼容)
w, h = img.size
res["bounds"] = [
int(rect[0] * 1000 / w),
int(rect[1] * 1000 / h),
int(rect[2] * 1000 / w),
int(rect[3] * 1000 / h)
]
final_stations.append(res)
logger.info(f"卡片 {card['id']} 识别成功: {res['station_name']}")
else:
logger.warning(f"卡片 {card['id']} 识别失败或无名称")
return final_stations
@classmethod
async def get_stations_from_image(cls, image_path: str, device_info=None):
"""
使用 Qwen-VL 模型从截图中识别充电站列表
"""
if device_info is None:
device_info = cls._FALLBACK_DEVICE_INFO
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return []
# 将图片转换为 Base64
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
try:
response = await cls._client.chat.completions.create(
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": cls._prompt},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_tokens=2000,
temperature=0.01
)
content = response.choices[0].message.content
logger.info(f"VL Model Response: {content}")
json_str = cls._extract_json(content)
stations = json.loads(json_str)
# 后处理:如果 bounds 是归一化的,则转换为像素坐标(如果需要)
# 或者如果 bounds 是 0-1000 的,则保持原样或按需转换
return stations
except Exception as e:
logger.error(f"Error calling VL model: {e}")
return []