This commit is contained in:
HuangHai
2026-01-12 15:22:50 +08:00
parent cc5547f37c
commit 2339558596
10 changed files with 468 additions and 6 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

@@ -15,6 +15,7 @@ if project_root not in sys.path:
import uiautomator2 as u2
from Apps.AiTeJiYiChong import Kit
from Apps.AiTeJiYiChong.Kit import take_screenshot
from Apps.AiTeJiYiChong.ReadImageKit import ReadImageKit
from Util.RedisKit import RedisKit
from Apps.AiTeJiYiChong.Service import AiTeJiYiChongService
from Config.Config import TEMP_IMAGE_DIR
@@ -54,15 +55,121 @@ async def get_station_list(d, service, max_scrolls=MAX_SCROLLS):
json_data = Kit.crop_cards_from_image(screenshot_path)
# 3. 调用 VL 模型识别并保存数据
# 这里的 service.process_station_list_vl 应该支持传入 json_data 或直接读取图片
stations = await service.process_station_list_vl(screenshot_path, device_info=device_info)
logger.info(f"本页识别到 {len(stations)} 个场站")
if not stations:
logger.warning("本页未识别到任何场站,可能已到底或加载中")
# 如果连续几页没数据可以考虑跳出,这里先简单处理
# 3. 翻页滑动
# 4. 匹配几何卡片与 VL 识别结果
if json_data and json_data.get("cards") and stations:
for card in json_data["cards"]:
card_rect = card["rect"] # [x1, y1, x2, y2]
for st in stations:
st_bounds = st.get("bounds") # [x1, y1, x2, y2] (0-1000)
if not st_bounds: continue
# 转换 VL 坐标到像素坐标
st_y1_px = st_bounds[1] * h / 1000
st_y2_px = st_bounds[3] * h / 1000
# 计算 y 轴重叠
overlap = min(card_rect[3], st_y2_px) - max(card_rect[1], st_y1_px)
if overlap > 50: # 有显著重叠
card["station_name"] = st.get("station_name")
logger.info(f"匹配成功: 几何卡片 {card_rect} -> 场站 {card['station_name']}")
break
# 5. 遍历处理本页所有场站
if json_data and json_data.get("cards"):
for card_idx, card in enumerate(json_data["cards"]):
station_name = card.get("station_name")
if not station_name:
logger.warning(f"{card_idx + 1} 个几何卡片未匹配到场站名称,跳过。")
continue
# 检查 Redis 去重
redis_key = f"crawled:aite:{station_name}"
if redis_kit.get(redis_key):
logger.info(f"场站 {station_name} 已处理,跳过。")
continue
click_x, click_y = card["click_point"]
logger.info(f"准备处理第 {card_idx + 1} 个场站: {station_name}, 点击坐标: ({click_x}, {click_y})")
d.click(click_x, click_y)
# 等待二级页面加载
await asyncio.sleep(3)
# 截取二级页面图
detail_uuid = f"detail_{station_name}_{image_uuid}"
detail_path = take_screenshot(d, detail_uuid, save_dir=TEMP_IMAGE_DIR)
# 5. 调用详情页处理逻辑 (二级页面:提取地址)
logger.info(f"正在解析详情页基础数据: {detail_path}")
detail_data = await service.process_station_detail(detail_path, station_name=station_name)
# 6. 寻找 timePrice.jpg 图标并进入三级页面 (分时价格页)
time_price_template = os.path.join(os.path.dirname(__file__), "BiaoShi", "timePrice.jpg")
coords = Kit.find_template_coords(detail_path, time_price_template)
if coords:
cx, cy, conf = coords
logger.info(f"找到分时价格图标,进入三级页面...")
d.click(cx, cy)
await asyncio.sleep(3)
# 截取三级页面并处理滑动价格
price_detail_uuid = f"price_detail_{station_name}_{image_uuid}"
price_detail_path = take_screenshot(d, price_detail_uuid, save_dir=TEMP_IMAGE_DIR)
all_prices = []
# 初始识别
prices = await ReadImageKit.get_price_detail_from_image(price_detail_path)
if prices: all_prices.extend(prices)
last_md5 = Kit.get_file_md5(price_detail_path)
# 滑动逻辑
for scroll_idx in range(3):
logger.info(f"执行第 {scroll_idx + 1} 次滑动以抓取更多价格...")
d.swipe(w // 2, int(h * 0.7), w // 2, int(h * 0.3), duration=0.5)
await asyncio.sleep(2)
scroll_path = take_screenshot(d, f"price_scroll_{scroll_idx}_{station_name}", save_dir=TEMP_IMAGE_DIR)
current_md5 = Kit.get_file_md5(scroll_path)
if current_md5 == last_md5:
logger.info("检测到屏幕未发生变化(已滑到底部),停止滑动识别。")
break
last_md5 = current_md5
new_prices = await ReadImageKit.get_price_detail_from_image(scroll_path)
if new_prices:
for np_item in new_prices:
if not any(np_i.get("start") == np_item.get("start") for np_i in all_prices):
all_prices.append(np_item)
else: break
if all_prices:
all_prices.sort(key=lambda x: x.get("start", "00:00"))
hourly_schedule = ReadImageKit.expand_schedule_to_24h(all_prices)
await service.process_price_detail_data(station_name, hourly_schedule)
logger.info(f"场站 {station_name} 三级页面价格处理完成。")
# 从三级页面返回二级页面
logger.info("从三级页面返回二级页面...")
d.press("back")
await asyncio.sleep(1.5)
else:
logger.warning(f"场站 {station_name} 未找到分时价格入口。")
# 从二级页面返回列表页
logger.info("从二级页面返回列表页...")
d.press("back")
await asyncio.sleep(1.5)
# 记录 Redis
redis_kit.set(redis_key, "1", ex=REDIS_STATION_EXPIRE)
# 5. 翻页滑动 (如果没进入二级页面)
logger.info("执行翻页滑动...")
start_x, start_y = w // 2, int(h * 0.8)
end_x, end_y = w // 2, int(h * (0.8 - SCROLL_DISTANCE_RATIO))

View File

@@ -4,6 +4,7 @@ import cv2
import numpy as np
import time
import json
import hashlib
from Apps.AiTeJiYiChong.Config.Setting import BOTTOM_SAFE_EXCLUDE_RATIO
from Config.Config import TEMP_IMAGE_DIR
@@ -11,6 +12,17 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)
def get_file_md5(file_path):
"""计算文件的 MD5 值"""
if not os.path.exists(file_path):
return None
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def read_image(path):
"""读取图片,支持中文路径"""
try:
@@ -105,6 +117,40 @@ def click_image_template(d, template_path, timeout=5.0, threshold=0.8):
return False
def find_template_coords(img_path, template_path, threshold=0.8):
"""
在图片中查找模板并返回中心坐标
"""
if not os.path.exists(img_path) or not os.path.exists(template_path):
return None
img = read_image(img_path)
template = read_image(template_path)
if img is None or template is None:
return None
t_h, t_w = template.shape[:2]
# 多尺度匹配
best_match = None
for scale in np.linspace(0.8, 1.2, 5):
resized = cv2.resize(template, (int(t_w * scale), int(t_h * scale)))
if resized.shape[0] > img.shape[0] or resized.shape[1] > img.shape[1]:
continue
res = cv2.matchTemplate(img, resized, cv2.TM_CCOEFF_NORMED)
_, max_val, _, max_loc = cv2.minMaxLoc(res)
if best_match is None or max_val > best_match[0]:
best_match = (max_val, max_loc, resized.shape[1], resized.shape[0])
if best_match and best_match[0] >= threshold:
max_val, max_loc, r_w, r_h = best_match
center_x = max_loc[0] + r_w // 2
center_y = max_loc[1] + r_h // 2
return (center_x, center_y, max_val)
return None
def crop_cards_from_image(img_path, output_dir=None, save_debug=True):
"""
从图片中裁剪场站卡片并生成 _flag.jpg 和 _vl.jpg

View File

@@ -98,6 +98,191 @@ class ReadImageKit:
return "[]"
_prompt_detail = (
"仅输出JSON对象不含任何说明文字识别充电站详情图片中的以下信息\n"
"1. station_name: 场站名称;\n"
"2. address: 场站完整地址(通常在定位图标旁)。\n"
"\n"
"特别说明:\n"
"- 严禁输出 Markdown 代码块标签严格返回纯JSON对象。"
)
@staticmethod
def _to_minutes(t_str: str) -> int:
"""HH:MM -> 分钟数"""
if not t_str or ":" not in t_str:
return 0
try:
h, m = map(int, t_str.split(":"))
return h * 60 + m
except:
return 0
@staticmethod
def _fmt(t: int) -> str:
"""分钟数 -> HH:MM"""
h = t // 60
m = t % 60
return f"{h:02d}:{m:02d}"
@staticmethod
def expand_schedule_to_24h(rows: list) -> list:
"""
将时段列表规整为全天24个整点小时段
"""
# 预处理:转换为分钟区间
intervals = []
for r in rows:
s = ReadImageKit._to_minutes(r.get("start"))
e = ReadImageKit._to_minutes(r.get("end"))
if e <= s and e != 0: # 处理 00:00-00:00 这种可能
if e == 0: e = 1440
else: continue
if e == 0 and s > 0: e = 1440
s = max(0, s)
e = min(1440, e)
intervals.append({
"s": s, "e": e,
"price": r.get("price")
})
# 排序
intervals.sort(key=lambda x: (x["s"], x["e"]))
result = []
for h in range(24):
hs = h * 60
he = (h + 1) * 60
best_price = None
# 找到覆盖当前小时段的价格
for it in intervals:
# 计算重叠
overlap_s = max(hs, it["s"])
overlap_e = min(he, it["e"])
if overlap_e > overlap_s:
best_price = it["price"]
break # 简单处理,取第一个覆盖的
result.append({
"start": ReadImageKit._fmt(hs),
"end": ReadImageKit._fmt(he),
"price": best_price
})
return result
_prompt_price_detail = (
"仅输出JSON对象不含任何说明文字识别充电价格详情图片中的分时段电价信息。图片中通常包含“充电时段”、“单价”、“电费”、“服务费”等列。\n"
"请识别出所有的时段行,返回一个列表,每个元素包含:\n"
"1. start: 开始时间 (HH:MM)\n"
"2. end: 结束时间 (HH:MM)\n"
"3. price: 总单价 (数字,如 0.6100)\n"
"4. ele_fee: 电费 (数字)\n"
"5. ser_fee: 服务费 (数字)。\n"
"\n"
"注意:\n"
"- 只需识别“直流桩”或当前选中的桩型下的数据。\n"
"- 严禁输出 Markdown 代码块标签严格返回纯JSON对象。"
)
@classmethod
async def get_price_detail_from_image(cls, image_path: str):
"""
使用 VL 模型从三级价格详情页截图中识别分时段电价
"""
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return None
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
try:
response = await asyncio.to_thread(
cls._client.chat.completions.create,
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": cls._prompt_price_detail},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_tokens=2000,
temperature=0.01
)
content = response.choices[0].message.content
logger.info(f"VL Price Detail Response: {content}")
json_str = cls._extract_json(content)
data = json.loads(json_str)
# 兼容性处理:如果返回的是对象且包含列表字段,提取列表
if isinstance(data, dict):
for key in ["price_list", "prices", "schedule", "data"]:
if key in data and isinstance(data[key], list):
return data[key]
# 如果本身就是个包含列表的字典,也可能需要根据实际情况调整
# 这里假设 prompt 引导它返回一个包含列表的对象或直接是列表
if "items" in data: return data["items"]
return data
except Exception as e:
logger.error(f"Error calling VL model for price detail: {e}")
return None
@classmethod
async def get_station_detail_from_image(cls, image_path: str):
"""
使用 VL 模型从详情页截图中识别地址、名称和电价
"""
if not os.path.exists(image_path):
logger.error(f"Image not found: {image_path}")
return None
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
try:
response = await asyncio.to_thread(
cls._client.chat.completions.create,
model=VL_MODEL_NAME,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": cls._prompt_detail},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
],
}
],
max_tokens=1000,
temperature=0.01
)
content = response.choices[0].message.content
logger.info(f"VL Detail Response: {content}")
json_str = cls._extract_json(content)
detail = json.loads(json_str)
return detail
except Exception as e:
logger.error(f"Error calling VL model for detail: {e}")
return None
@classmethod
async def get_stations_from_image(cls, image_path: str, device_info=None):
"""

View File

@@ -42,6 +42,89 @@ class AiTeJiYiChongService:
async def close_db(self):
await self.db.close()
async def process_price_detail_data(self, station_name, hourly_schedule) -> bool:
"""
直接保存已处理好的小时段价格数据
"""
if not station_name or not hourly_schedule:
return False
station_hash = self.get_hash(station_name)
now = datetime.now()
async with await self.db.get_session() as session:
schedule_id = self.generate_id()
await self.station_price_schedule_model.save(
session=session,
id=schedule_id,
station_hash=station_hash,
schedule_json=hourly_schedule,
valid_start_time=now
)
await session.commit()
return True
async def process_price_detail(self, image_path, station_name) -> list:
"""
处理三级价格详情页截图
"""
prices = await ReadImageKit.get_price_detail_from_image(image_path)
if not prices:
return None
station_hash = self.get_hash(station_name)
now = datetime.now()
# 将识别到的原始分时价格扩展为 24 小时整点数据
hourly_schedule = ReadImageKit.expand_schedule_to_24h(prices)
async with await self.db.get_session() as session:
schedule_id = self.generate_id()
await self.station_price_schedule_model.save(
session=session,
id=schedule_id,
station_hash=station_hash,
schedule_json=hourly_schedule,
valid_start_time=now
)
await session.commit()
logger.info(f"三级页面价格详情处理完成: {station_name}, 共 {len(hourly_schedule)} 个小时段")
return hourly_schedule
async def process_station_detail(self, image_path, station_name=None) -> dict:
"""
处理场站详情页截图
"""
detail = await ReadImageKit.get_station_detail_from_image(image_path)
if not detail:
return None
name = station_name or detail.get("station_name")
if not name:
return None
station_hash = self.get_hash(name)
now = datetime.now()
async with await self.db.get_session() as session:
# 1. 更新 Profile 中的地址信息
profile_id = self.generate_id()
await self.station_profile_model.save(
session=session,
id=profile_id,
station_hash=station_hash,
operator=self.operator,
station_name=name,
address=detail.get("address"),
valid_start_time=now
)
await session.commit()
logger.info(f"场站详情处理完成: {name}")
return detail
async def process_station_list_vl(self, image_path, device_info=None) -> list:
"""
基于 VL 模式处理场站列表

View File

@@ -0,0 +1,41 @@
# coding=utf-8
import asyncio
import uiautomator2 as u2
import os
import sys
import uuid
project_root = r"D:\dsWork\aiData"
if project_root not in sys.path:
sys.path.append(project_root)
from Apps.AiTeJiYiChong import Kit
from Apps.AiTeJiYiChong.Kit import take_screenshot, read_image
from Config.Config import TEMP_IMAGE_DIR
async def test_click_and_detail():
d = u2.connect()
image_uuid = str(uuid.uuid4())
print(f"Taking initial screenshot...")
screenshot_path = take_screenshot(d, image_uuid, save_dir=TEMP_IMAGE_DIR)
print(f"Analyzing cards in {screenshot_path}...")
json_data = Kit.crop_cards_from_image(screenshot_path)
if json_data and json_data.get("cards"):
first_card = json_data["cards"][0]
click_x, click_y = first_card["click_point"]
print(f"Clicking card at ({click_x}, {click_y})...")
d.click(click_x, click_y)
print("Waiting for detail page...")
await asyncio.sleep(5)
detail_uuid = f"detail_{image_uuid}"
detail_path = take_screenshot(d, detail_uuid, save_dir=TEMP_IMAGE_DIR)
print(f"Detail page screenshot: {detail_path}")
else:
print("No cards found on current screen.")
if __name__ == "__main__":
asyncio.run(test_click_and_detail())