Files
aiData/Apps/TeLaiDian/ReadImageKit.py
HuangHai ee91bf76d2 'commit'
2026-01-14 09:56:04 +08:00

172 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
import logging
import os
import sys
import json
# Make the project root importable before the project-local imports below
# (``Util.*`` and ``Apps.*``) are resolved. The root is taken to be three
# directory levels above this file. The path is appended (lowest priority)
# and only when it is not already present, so repeated imports of this module
# do not pile up duplicate entries.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
    sys.path.append(project_root)
from Util.VLMKit import VLMKit
from Apps.TeLaiDian.Kit import draw_rectangles, detect_cards_cv, setup_logger
from Apps.TeLaiDian.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
# Initialize the module-level logger via the project's setup_logger helper,
# registered under the name "ReadImageKit".
logger = setup_logger("ReadImageKit")
class ReadImageKit:
    """Extract structured data from charging-app screenshots with a VLM.

    Every public method sends a screenshot plus a prompt to ``self.vlm``,
    pulls the JSON payload out of the reply and normalizes it. The public
    methods never raise: any failure is logged and an empty result is
    returned instead.
    """

    # Prompt for the price-detail page: asks for the time-of-use price table
    # as a JSON array. The text is sent to the model verbatim.
    _PRICE_PROMPT = """
分析这张充电站价格详情页截图,提取**分时电价表**。
对于每个时段,请识别并提取以下所有价格信息(如果存在):
1. 优惠价 (通常是红色或加粗的大字,作为默认 price)
2. PLUS会员价 (标有 "PLUS" 标签的价格)
3. 挂牌价 (标有 "挂牌价" 标签的价格)
4. 电费 (Base electricity price)
5. 服务费 (Service fee)
请提取每个时段的:
- start: 开始时间 (HH:MM)
- end: 结束时间 (HH:MM)
- price: 优惠价 (元/度)
- plus_price: PLUS会员价 (元/度)
- market_price: 挂牌价 (元/度)
- elec_price: 电费 (元/度)
- service_price: 服务费 (元/度)
输出格式为 JSON 数组:
[
{
"start": "16:00",
"end": "21:00",
"price": 1.3435,
"plus_price": 1.3035,
"market_price": 1.4435,
"elec_price": 0.9435,
"service_price": 0.4000
},
...
]
注意:
- 如果某个字段缺失,请设为 null。
- 确保 price 包含电费和服务费的总和。
- 如果无法识别任何价格信息,请返回空数组 []。
"""

    # Prompt for the first screen of the detail page: station name + address.
    _BASIC_INFO_PROMPT = """
分析这张充电站详情页首屏截图,提取:
1. 场站名称 (通常在页面中部,大字体)
2. 详细地址 (通常在名称下方或页面下半部分,伴有地址图标)
输出格式为 JSON
{
"name": "xxx充电站",
"address": "xxx省xxx市xxx区xxx路xxx号"
}
"""

    # Prompt for a list screenshot on which candidate cards were already
    # outlined; ``{count}`` is filled with the number of outlined boxes.
    _LIST_PROMPT_BOXED = """
图片中已经用绿色矩形框标记了 {count} 个可能的充电站卡片。
请按从上到下的顺序,识别每个绿色框内的场站信息。
输出格式为 JSON 数组,长度必须为 {count}
每个对象包含:
- "name": 场站名称
- "address": 场站地址
- "is_valid": true/false (是否为真实的场站卡片)
"""

    # Prompt for a list screenshot when no card boxes were detected: the model
    # is asked to locate the cards (point + bbox) itself.
    _LIST_PROMPT_FREEFORM = """
分析这张充电站列表截图,提取所有充电站卡片信息。
忽略顶部的筛选栏,仅提取下方重复出现的场站卡片。
输出格式为 JSON 数组,每个对象包含:
- "name": 场站名称
- "address": 场站地址
- "point": 场站卡片的中心点击坐标 [x, y]
- "bbox": 场站卡片的边界框 [x1, y1, x2, y2]
"""

    def __init__(self):
        self.vlm = VLMKit()

    @staticmethod
    def _sum_prices(elec_price, service_price):
        """Return ``elec_price + service_price`` as a float, or None if either
        value is missing or cannot be converted to a number."""
        if elec_price is None or service_price is None:
            return None
        try:
            return float(elec_price) + float(service_price)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _normalize_price_entry(entry):
        """Return a copy of one price-table row with ``start``/``end``/``price``
        back-filled from alternative fields the model may have used.

        * ``start``/``end`` are taken from a ``"HH:MM-HH:MM"`` / ``"HH:MM~HH:MM"``
          style ``time_range`` string when either of them is missing.
        * ``price`` is taken from ``total_price``, or else computed as
          ``elec_price + service_price``.

        The prompt tells the model to emit ``null`` for missing fields, so a
        value of ``None`` is treated the same as an absent key. The input dict
        is not modified.
        """
        normalized = entry.copy()

        time_range = entry.get("time_range")
        needs_bounds = entry.get("start") is None or entry.get("end") is None
        if needs_bounds and isinstance(time_range, str):
            parts = time_range.replace('~', '-').replace(' ', '').split('-')
            if len(parts) >= 2:
                normalized['start'] = parts[0]
                normalized['end'] = parts[1]

        if entry.get("price") is None:
            price = entry.get("total_price")
            if price is None:
                price = ReadImageKit._sum_prices(
                    entry.get("elec_price"), entry.get("service_price")
                )
            # Keep emitting an explicit (possibly null) price whenever the
            # model supplied a total_price key; otherwise only set the key
            # when a value could actually be derived.
            if price is not None or "total_price" in entry:
                normalized['price'] = price

        return normalized

    @staticmethod
    def _station_from_card(bbox, card):
        """Combine one detected card box with the model's reading of it.

        Returns a station dict (``name``, ``address``, ``point``, ``bbox``)
        whose ``point`` is the integer center of ``bbox``, or None when
        ``card`` is not a dict or the model did not accept it as a real
        station card. A card is accepted when ``is_valid`` is exactly True,
        or when it has a name and ``is_valid`` is not exactly False.
        """
        if not isinstance(card, dict):
            return None
        verdict = card.get("is_valid")
        accepted = verdict is True or (card.get("name") and verdict is not False)
        if not accepted:
            return None
        return {
            "name": card.get("name"),
            "address": card.get("address"),
            "point": [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2],
            "bbox": bbox,
        }

    async def analyze_detail_price(self, image_path):
        """Extract the time-of-use price table from a price-detail screenshot.

        Args:
            image_path: Path of the screenshot to analyze.

        Returns:
            A list of dicts, one per time slot, as returned by the model and
            normalized by ``_normalize_price_entry``. Elements that are not
            JSON objects are dropped. Returns ``[]`` when the reply is not a
            JSON array or when any step fails (the failure is logged).
        """
        try:
            res_text = await self.vlm.analyze_image(image_path, self._PRICE_PROMPT)
            logger.info(f"VLM Price Analysis Result for {os.path.basename(image_path)}: {res_text[:200]}...")
            json_str = self.vlm.extract_json(res_text)
            prices = json.loads(json_str)
            if not isinstance(prices, list):
                return []
            # Skip malformed elements one by one so a single bad row does not
            # throw away the rest of the table.
            return [
                self._normalize_price_entry(row)
                for row in prices
                if isinstance(row, dict)
            ]
        except Exception as e:
            logger.error(f"分析电价详情失败: {e}")
            return []

    async def analyze_detail_basic_info(self, image_path):
        """Extract the station name and address from the detail page's first screen.

        Args:
            image_path: Path of the screenshot to analyze.

        Returns:
            The JSON object parsed from the model's reply (the prompt asks for
            ``name`` and ``address``), or ``{}`` when the reply is not a JSON
            object or any step fails (the failure is logged).
        """
        try:
            res_text = await self.vlm.analyze_image(image_path, self._BASIC_INFO_PROMPT)
            json_str = self.vlm.extract_json(res_text)
            info = json.loads(json_str)
            # Callers expect a mapping; a list or scalar reply is unusable.
            return info if isinstance(info, dict) else {}
        except Exception as e:
            logger.error(f"分析详情页基础信息失败: {e}")
            return {}

    async def analyze_station_list(self, image_path):
        """Extract station cards (name, address, tap point, box) from a list screenshot.

        Card boxes are first looked up with ``detect_cards_cv``. If any are
        found they are drawn via ``draw_rectangles`` and the model is asked to
        read each box in top-to-bottom order; the model's answers are then
        paired with the boxes positionally and the tap point is the box
        center. If none are found, the model is asked to locate the cards
        itself and its list is returned as-is.

        Args:
            image_path: Path of the screenshot to analyze.

        Returns:
            A list of station dicts, or ``[]`` when nothing usable was
            recognized or any step fails (the failure is logged).
        """
        cv_bboxes = detect_cards_cv(image_path, top_ratio=SAFE_EXCLUDE_RATIO, bottom_ratio=BOTTOM_SAFE_EXCLUDE_RATIO)
        if cv_bboxes:
            draw_rectangles(image_path, cv_bboxes)
            prompt = self._LIST_PROMPT_BOXED.format(count=len(cv_bboxes))
        else:
            prompt = self._LIST_PROMPT_FREEFORM
        try:
            res_text = await self.vlm.analyze_image(image_path, prompt)
            json_str = self.vlm.extract_json(res_text)
            vlm_results = json.loads(json_str)
            if not isinstance(vlm_results, list):
                return []
            if not cv_bboxes:
                return vlm_results
            # zip() stops at the shorter sequence, so surplus model answers
            # (or surplus boxes) are ignored.
            stations = (
                self._station_from_card(bbox, card)
                for bbox, card in zip(cv_bboxes, vlm_results)
            )
            return [station for station in stations if station is not None]
        except Exception as e:
            logger.error(f"分析列表页失败: {e}")
            return []