174 lines
5.5 KiB
Python
174 lines
5.5 KiB
Python
# coding=utf-8
|
|
import os
|
|
import sys
|
|
import time
|
|
import asyncio
|
|
import json
|
|
|
|
# Make the project root (the parent of this script's folder) importable, so
# the project-local `Util` package imported below resolves when this file is
# run directly as a script.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.split(current_dir)[0]
if project_root not in sys.path:
    sys.path.append(project_root)
|
|
|
|
from paddleocr import PaddleOCR
|
|
from Util.LlmUtil import get_llm_response
|
|
from Util.OcrParser import OcrParser
|
|
import re
|
|
|
|
# Debug transcript kept next to this script; every log() call appends to it.
LOG_FILE = os.path.join(current_dir, "ocr_llm_debug.txt")


def log(msg):
    """Echo ``msg`` to stdout and append it, newline-terminated, to LOG_FILE.

    ``msg`` must be a ``str`` (it is concatenated with ``"\\n"``). The file is
    reopened in append mode and closed on every call, so each message is
    written out as soon as it is logged.
    """
    print(msg, flush=True)
    with open(LOG_FILE, mode="a", encoding="utf-8") as log_fh:
        log_fh.write(msg + "\n")
|
|
|
|
def _extract_text_lines(result):
    """Normalize a PaddleOCR ``ocr()`` result into a flat list of text lines.

    Only the first page (``result[0]``) is inspected. Accepted page shapes:

    * dict-like, exposing ``get`` and containing a ``'rec_texts'`` key;
    * an object with a ``rec_texts`` attribute;
    * a list of ``[box, (text, score), ...]`` entries, where the text is taken
      from ``entry[1][0]``.

    Returns:
        list: the recognized strings; ``[]`` when nothing usable was found.
    """
    if not result:
        log("OCR returned empty result.")
        return []

    page = result[0]
    if page is None:
        log("OCR result[0] is None.")
        return []

    if hasattr(page, "get") and "rec_texts" in page:
        texts = page.get("rec_texts")
        # Copy into a plain list so callers can rely on list truthiness/joins.
        return list(texts) if texts is not None else []
    if hasattr(page, "rec_texts"):
        texts = page.rec_texts
        return list(texts) if texts is not None else []
    if isinstance(page, list):
        return [
            entry[1][0]
            for entry in page
            if len(entry) >= 2 and isinstance(entry[1], (tuple, list))
        ]
    return []


def run_ocr_sync():
    """Run PaddleOCR on a sample image next to this script and collect its text.

    Uses ``2.jpg`` in this script's directory, falling back to ``1.jpg``.
    Progress, timings and the recognized text are reported through ``log``.

    Returns:
        tuple: ``(ocr_text_lines, ocr_text_block, (t_ocr_start, t_ocr_end))`` on
        success, where ``ocr_text_block`` is the lines joined with ``"\\n"``
        and the timestamps are ``time.time()`` values taken around inference.
        On any failure returns ``(None, None, None)``, so callers can always
        unpack three values.
    """
    image_path = os.path.join(current_dir, "2.jpg")
    if not os.path.exists(image_path):
        image_path = os.path.join(current_dir, "1.jpg")

    log(f"Testing OCR + LLM Pipeline on: {image_path}")
    log("-" * 50)

    # Fail fast: model initialization is slow, so skip it when there is no
    # image to run on.
    if not os.path.exists(image_path):
        log(f"Image not found: {image_path}")
        return None, None, None

    # --- Step 1: PaddleOCR ---
    log("Initializing PaddleOCR...")
    t_init_start = time.time()
    try:
        # Try the lightweight (mobile) model to improve speed. The original
        # author notes ocr_version='PP-OCRv4' usually defaults to mobile —
        # not verified here.
        ocr = PaddleOCR(use_textline_orientation=True, lang="ch", ocr_version='PP-OCRv4')
    except Exception as e:
        log(f"PaddleOCR Init Failed: {e}")
        return None, None, None
    t_init_end = time.time()
    log(f"PaddleOCR Init Time: {t_init_end - t_init_start:.4f}s")

    log("Running OCR Inference...")
    t_ocr_start = time.time()
    try:
        result = ocr.ocr(image_path)
    except Exception as e:
        log(f"OCR Inference Failed: {e}")
        return None, None, None
    t_ocr_end = time.time()

    ocr_text_lines = _extract_text_lines(result)
    ocr_text_block = "\n".join(ocr_text_lines)

    log(f"OCR Result ({t_ocr_end - t_ocr_start:.4f}s):")
    log(ocr_text_block)
    log("-" * 50)

    return ocr_text_lines, ocr_text_block, (t_ocr_start, t_ocr_end)
|
|
|
|
def _extract_json_object(text):
    """Parse the JSON payload from an LLM reply.

    First tries the reply with Markdown code fences removed. If that fails,
    falls back to the widest ``{...}`` span, to tolerate prose before or
    after the object.

    Raises:
        json.JSONDecodeError: if no candidate parses.
    """
    cleaned = text.replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if match is None:
            raise
        return json.loads(match.group(0))


def _run_regex_parsing(ocr_text_lines):
    """Parse OCR lines with ``OcrParser``, log the outcome, return elapsed seconds.

    Parse failures are logged, not raised. The clock stops before the result
    is logged, so log-file I/O is not counted as parsing time.
    """
    log("Running Regex Parsing...")
    t_start = time.time()
    t_end = None
    try:
        regex_data = OcrParser.parse(ocr_text_lines)
        t_end = time.time()
        log("\nParsed Data (Regex):")
        log(json.dumps(regex_data, indent=2, ensure_ascii=False))
    except Exception as e:
        if t_end is None:
            t_end = time.time()
        log(f"Regex Parsing Failed: {e}")
    duration = t_end - t_start
    log(f"Regex Parsing Time: {duration:.4f}s")
    log("-" * 50)
    return duration


async def _run_llm_parsing(ocr_text_block):
    """Ask the LLM to structure the OCR text, log the reply, return elapsed seconds.

    The reply is streamed to stdout as it arrives, then logged in full and
    parsed as JSON. Request errors are logged, not raised.
    """
    log("Running LLM Parsing...")

    prompt = f"""
You are a data extraction assistant. Below is the OCR text recognized from a charging station list card.
Please extract the structured data and return it ONLY as a JSON object (no markdown, no extra text).

Fields to extract:
- station_name: (String) Name of the charging station.
- distance: (String) Distance info (e.g., "7.4km").
- price: (String) Price info (e.g., "0.7111/度").
- tags: (List[String]) Any tags like "快", "闲3/4", "组团", "2倍积分", "P", etc.
- parking_info: (String) Parking related info.

OCR Text:
{ocr_text_block}
"""

    t_start = time.time()
    chunks = []
    try:
        log("Starting LLM request...")
        async for chunk in get_llm_response(prompt, stream=True):
            # Echo tokens live; the complete reply is logged once streaming ends.
            print(chunk, end='', flush=True)
            chunks.append(chunk)
        print("\n")
        log("LLM request finished.")
        t_end = time.time()
    except Exception as e:
        t_end = time.time()
        log(f"LLM Error: {e}")
        return t_end - t_start

    response_text = "".join(chunks)
    log(f"\nLLM Response ({t_end - t_start:.4f}s):")
    log(response_text)

    try:
        data = _extract_json_object(response_text)
    except json.JSONDecodeError as e:
        log("\nFailed to parse JSON directly.")
        log(f"JSON error: {e}")
    else:
        log("\nParsed JSON Data:")
        log(json.dumps(data, indent=2, ensure_ascii=False))

    return t_end - t_start


async def run_parsing_comparison(ocr_text_lines, ocr_text_block, timing_ocr):
    """Compare regex-based and LLM-based parsing of one OCR result, logging timings.

    Args:
        ocr_text_lines: recognized text lines, passed to ``OcrParser.parse``.
        ocr_text_block: the same text joined into one string, embedded in
            the LLM prompt.
        timing_ocr: ``(t_ocr_start, t_ocr_end)`` timestamps of the OCR step,
            used to report end-to-end pipeline totals.
    """
    t_ocr_start, t_ocr_end = timing_ocr
    ocr_duration = t_ocr_end - t_ocr_start

    # --- Mode 1: Regex Parsing ---
    regex_duration = _run_regex_parsing(ocr_text_lines)

    # --- Mode 2: LLM Parsing ---
    llm_duration = await _run_llm_parsing(ocr_text_block)

    log("-" * 50)
    log("Summary:")
    log(f"OCR Time: {ocr_duration:.4f}s")
    log(f"Regex Parsing Time: {regex_duration:.4f}s")
    log(f"LLM Parsing Time: {llm_duration:.4f}s")

    total_regex = ocr_duration + regex_duration
    total_llm = ocr_duration + llm_duration

    log(f"Total Pipeline (OCR+Regex): {total_regex:.4f}s")
    log(f"Total Pipeline (OCR+LLM): {total_llm:.4f}s")
|
|
|
|
def main():
    """Reset the debug log, run OCR, then compare regex vs. LLM parsing.

    The parsing comparison is skipped when OCR fails or recognizes no text.
    """
    # Truncate the log file so each run starts with a fresh debug transcript.
    with open(LOG_FILE, "w", encoding="utf-8") as f:
        f.write("Starting TestOcrLlm...\n")

    result = run_ocr_sync()
    # On failure run_ocr_sync returns a tuple whose first element is None
    # (and which may be shorter than the 3-tuple returned on success).
    # Check before unpacking so a failure tuple cannot raise ValueError here.
    if not result or not result[0]:
        log("No OCR text recognized; skipping parsing comparison.")
        return

    ocr_lines, ocr_text, timing = result
    asyncio.run(run_parsing_comparison(ocr_lines, ocr_text, timing))


if __name__ == "__main__":
    main()
|
|
|