172 lines
6.9 KiB
Python
172 lines
6.9 KiB
Python
# coding=utf-8
|
||
import logging
|
||
import os
|
||
import sys
|
||
import json
|
||
|
||
# Make the project root (three directory levels above this file) importable so the
# `Util.*` / `Apps.*` absolute imports below resolve; append only if it is not already there.
project_root = os.path.abspath(__file__)
for _level in range(3):
    project_root = os.path.dirname(project_root)
if project_root not in sys.path:
    sys.path.append(project_root)
|
||
|
||
from Util.VLMKit import VLMKit
|
||
from Apps.TeLaiDian.Kit import draw_rectangles, detect_cards_cv, setup_logger
|
||
from Apps.TeLaiDian.Config.Setting import SAFE_EXCLUDE_RATIO, BOTTOM_SAFE_EXCLUDE_RATIO
|
||
|
||
# Module-level logger, created via setup_logger under the name "ReadImageKit".
_LOGGER_NAME = "ReadImageKit"
logger = setup_logger(_LOGGER_NAME)
|
||
|
||
class ReadImageKit:
    """
    Screenshot readers that prompt a vision-language model (``VLMKit``) and parse
    its JSON answers into plain Python structures.

    Each ``analyze_*`` coroutine logs and returns an empty result (``[]`` or ``{}``)
    when the VLM call or the JSON parsing fails.
    """

    def __init__(self):
        # One VLM client per instance, shared by all analyze_* calls.
        self.vlm = VLMKit()

    async def analyze_detail_price(self, image_path):
        """
        Analyze a detail-page screenshot and extract the time-of-use price table.

        The VLM is asked for a JSON array with one object per time slot carrying
        ``start``/``end`` (HH:MM) and ``price`` (discounted price), ``plus_price``,
        ``market_price``, ``elec_price`` and ``service_price`` (yuan/kWh).

        Args:
            image_path: Path of the screenshot to analyze.

        Returns:
            A list of entries normalized by ``_normalize_price_entry``. Non-dict
            items are skipped. Returns ``[]`` if the VLM call or JSON parsing fails,
            or if the parsed JSON is not a list.
        """
        prompt = """
        分析这张充电站价格详情页截图,提取**分时电价表**。
        对于每个时段,请识别并提取以下所有价格信息(如果存在):
        1. 优惠价 (通常是红色或加粗的大字,作为默认 price)
        2. PLUS会员价 (标有 "PLUS" 标签的价格)
        3. 挂牌价 (标有 "挂牌价" 标签的价格)
        4. 电费 (Base electricity price)
        5. 服务费 (Service fee)

        请提取每个时段的:
        - start: 开始时间 (HH:MM)
        - end: 结束时间 (HH:MM)
        - price: 优惠价 (元/度)
        - plus_price: PLUS会员价 (元/度)
        - market_price: 挂牌价 (元/度)
        - elec_price: 电费 (元/度)
        - service_price: 服务费 (元/度)

        输出格式为 JSON 数组:
        [
            {
                "start": "16:00",
                "end": "21:00",
                "price": 1.3435,
                "plus_price": 1.3035,
                "market_price": 1.4435,
                "elec_price": 0.9435,
                "service_price": 0.4000
            },
            ...
        ]
        注意:
        - 如果某个字段缺失,请设为 null。
        - 确保 price 包含电费和服务费的总和。
        - 如果无法识别任何价格信息,请返回空数组 []。
        """
        try:
            res_text = await self.vlm.analyze_image(image_path, prompt)
            logger.info(f"VLM Price Analysis Result for {os.path.basename(image_path)}: {res_text[:200]}...")

            json_str = self.vlm.extract_json(res_text)
            prices = json.loads(json_str)

            if not isinstance(prices, list):
                return []
            # A stray non-dict item is skipped instead of aborting the whole table.
            return [self._normalize_price_entry(p) for p in prices if isinstance(p, dict)]
        except Exception as e:
            logger.error(f"分析电价详情失败: {e}")
            return []

    @staticmethod
    def _normalize_price_entry(entry):
        """
        Return a shallow copy of one VLM price entry with ``start``/``end`` and ``price`` filled in.

        - If ``start`` or ``end`` is missing/null and ``time_range`` is a string such as
          ``"08:00-12:00"``, ``"08:00~12:00"`` or ``"08:00 - 12:00"``, the first two
          ``-``-separated parts become ``start`` and ``end``.
        - If ``price`` is missing/null, it falls back to a non-null ``total_price``,
          otherwise to ``float(elec_price) + float(service_price)`` when both convert.
          If neither fallback applies, ``price`` is left as reported.

        The input dict is not modified.
        """
        normalized = dict(entry)

        time_range = entry.get('time_range')
        if isinstance(time_range, str) and (entry.get('start') is None or entry.get('end') is None):
            parts = time_range.replace('~', '-').replace(' ', '').split('-')
            if len(parts) >= 2:
                normalized['start'] = parts[0]
                normalized['end'] = parts[1]

        # The prompt asks the model to emit null for missing fields, so a present-but-None
        # price must trigger the fallbacks just like an absent key.
        if entry.get('price') is None:
            if entry.get('total_price') is not None:
                normalized['price'] = entry['total_price']
            else:
                try:
                    normalized['price'] = float(entry['elec_price']) + float(entry['service_price'])
                except (KeyError, TypeError, ValueError):
                    # Components absent, null or non-numeric: nothing to derive price from.
                    pass
        return normalized

    async def analyze_detail_basic_info(self, image_path):
        """
        Analyze the first screen of a detail page and extract the station name and address.

        Args:
            image_path: Path of the screenshot to analyze.

        Returns:
            The parsed JSON object (the prompt requests ``{"name": ..., "address": ...}``),
            or ``{}`` if the VLM call/JSON parsing fails or the answer is not a JSON object.
        """
        prompt = """
        分析这张充电站详情页首屏截图,提取:
        1. 场站名称 (通常在页面中部,大字体)
        2. 详细地址 (通常在名称下方或页面下半部分,伴有地址图标)

        输出格式为 JSON:
        {
            "name": "xxx充电站",
            "address": "xxx省xxx市xxx区xxx路xxx号"
        }
        """
        try:
            res_text = await self.vlm.analyze_image(image_path, prompt)
            json_str = self.vlm.extract_json(res_text)
            data = json.loads(json_str)
            # Keep the return type a dict (matching the {} failure value) even if the
            # model answered with a JSON array or scalar.
            return data if isinstance(data, dict) else {}
        except Exception as e:
            logger.error(f"分析详情页基础信息失败: {e}")
            return {}

    async def analyze_station_list(self, image_path):
        """
        Analyze a station-list screenshot and extract the station cards on it.

        When ``detect_cards_cv`` returns card boxes, they are passed to
        ``draw_rectangles`` and the VLM is asked for one ``{name, address, is_valid}``
        object per box. The answers are paired with the boxes by position (see
        ``_merge_card_results``). Otherwise the VLM is asked to locate the cards
        itself and its list, with its own ``point``/``bbox`` values, is returned as-is.

        Args:
            image_path: Path of the list-page screenshot.

        Returns:
            A list of station dicts, or ``[]`` if the VLM call/JSON parsing fails or
            the parsed JSON is not a list.
        """
        # NOTE(review): exceptions from detect_cards_cv / draw_rectangles are not caught
        # here and propagate to the caller, unlike VLM failures — confirm intended.
        cv_bboxes = detect_cards_cv(image_path, top_ratio=SAFE_EXCLUDE_RATIO, bottom_ratio=BOTTOM_SAFE_EXCLUDE_RATIO)

        if cv_bboxes:
            # Presumably draws the green boxes onto the file at image_path, since that
            # same path is what the VLM is shown next — TODO confirm.
            draw_rectangles(image_path, cv_bboxes)
            # NOTE(review): answers are matched to cv_bboxes by index and the prompt asks
            # for top-to-bottom order, so cv_bboxes is presumably sorted top-to-bottom.
            prompt = f"""
            图片中已经用绿色矩形框标记了 {len(cv_bboxes)} 个可能的充电站卡片。
            请按从上到下的顺序,识别每个绿色框内的场站信息。

            输出格式为 JSON 数组,长度必须为 {len(cv_bboxes)}。
            每个对象包含:
            - "name": 场站名称
            - "address": 场站地址
            - "is_valid": true/false (是否为真实的场站卡片)
            """
        else:
            prompt = """
            分析这张充电站列表截图,提取所有充电站卡片信息。
            忽略顶部的筛选栏,仅提取下方重复出现的场站卡片。
            输出格式为 JSON 数组,每个对象包含:
            - "name": 场站名称
            - "address": 场站地址
            - "point": 场站卡片的中心点击坐标 [x, y]
            - "bbox": 场站卡片的边界框 [x1, y1, x2, y2]
            """

        try:
            res_text = await self.vlm.analyze_image(image_path, prompt)
            json_str = self.vlm.extract_json(res_text)
            vlm_results = json.loads(json_str)

            if not isinstance(vlm_results, list):
                return []
            if cv_bboxes:
                return self._merge_card_results(cv_bboxes, vlm_results)
            return vlm_results
        except Exception as e:
            logger.error(f"分析列表页失败: {e}")
            return []

    @staticmethod
    def _merge_card_results(cv_bboxes, vlm_results):
        """
        Pair per-box VLM answers with CV-detected boxes by position.

        ``vlm_results[i]`` is matched with ``cv_bboxes[i]``. Items beyond the number of
        boxes are dropped, and non-dict items are skipped. An item is kept when
        ``is_valid`` is ``True``, or when it has a truthy ``name`` and ``is_valid`` is
        not explicitly ``False``.

        Each kept station is ``{"name", "address", "point", "bbox"}``. ``bbox`` is the
        box unchanged, and ``point`` is its centre computed with integer division.
        The boxes are assumed to be ``(x1, y1, x2, y2)``, matching the indexing below.
        """
        stations = []
        for bbox, res in zip(cv_bboxes, vlm_results):
            if not isinstance(res, dict):
                continue
            is_valid = res.get("is_valid")
            if is_valid is True or (res.get("name") and is_valid is not False):
                stations.append({
                    "name": res.get("name"),
                    "address": res.get("address"),
                    "point": [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2],
                    "bbox": bbox,
                })
        return stations
|