Files
aiData/Controller/HaiBaoController.py
HuangHai 79954171c6 'commit'
2026-01-20 09:26:42 +08:00

231 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import uuid
import asyncio
from datetime import datetime
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from sqlalchemy.sql import text
from Util.BananaClient import BananaClient
from Util.LlmUtil import get_llm_response
from DbKit.Db import Db
router = APIRouter(prefix="/haibao")
logger = logging.getLogger("HaiBaoController")
class GenerateRequest(BaseModel):
prompt: str
width: int = 1024
height: int = 1024
@router.on_event("startup")
async def startup_event():
"""初始化时检查并创建表"""
db = Db()
await db.init_db()
# Doris 建表语句
create_table_sql = """
CREATE TABLE IF NOT EXISTS haibao_history (
id VARCHAR(50) COMMENT "ID",
prompt TEXT COMMENT "提示词",
image_url VARCHAR(500) COMMENT "图片URL",
scheme_content TEXT COMMENT "文案方案",
created_at DATETIME COMMENT "创建时间"
)
DUPLICATE KEY(id)
DISTRIBUTED BY HASH(id) BUCKETS 1
PROPERTIES (
"replication_num" = "1"
);
"""
try:
# 使用 engine 直接执行 DDL
async with db.engine.begin() as conn:
await conn.execute(text(create_table_sql))
# 尝试添加列(如果表已存在但列不存在)
# Doris 不支持 IF NOT EXISTS for ADD COLUMN directly nicely in all versions without error if exists
# 所以这里简单捕获异常,如果列已存在则忽略
try:
alter_sql = "ALTER TABLE haibao_history ADD COLUMN scheme_content TEXT COMMENT '文案方案'"
await conn.execute(text(alter_sql))
except Exception as e:
# 忽略列已存在的错误
pass
logger.info("海报历史表检查/更新成功")
except Exception as e:
logger.error(f"海报历史表创建/更新失败: {e}")
class RefineRequest(BaseModel):
prompt: str
@router.post("/refine")
async def refine_prompt(req: RefineRequest):
"""润色提示词"""
try:
refine_system_prompt = "你是一个资深的AI绘画提示词专家。你的任务是将用户简短的描述扩充为一段详细、高质量的画面描述提示词用于生成宣传海报。"
refine_user_prompt = f"""
请根据以下主题,为充电企业“驿来特”设计一张宣传海报的画面描述。
主题:{req.prompt}
要求:
1. 描述画面主体、背景、光影、色彩、构图。
2. 风格要求现代感、科技感、精美、3D渲染风格或高品质插画风格。
3. 融入新能源、绿色环保、充电桩等元素。
4. 关于品牌元素画面中可自然融入品牌Logo的视觉元素如配色、形状能用多少就用多少有元素体现即可不必生搬硬套保持画面自然和谐。
5. 直接输出提示词内容,不要包含“好的”、“以下是”等无关废话。
6. 字数在100-300字之间。
"""
refined_prompt = ""
try:
async for chunk in get_llm_response(query_text=refine_user_prompt, system_prompt=refine_system_prompt, stream=False):
refined_prompt += chunk
except Exception as e:
logger.error(f"提示词润色失败: {e}")
raise Exception("润色服务暂时不可用")
return {"refined_prompt": refined_prompt}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate")
async def generate_poster(req: GenerateRequest):
"""生成海报及文案"""
try:
# 并行执行生图和生文
client = BananaClient()
# 1. 构造生图任务 (包含智能润色判断)
async def generate_image_task():
final_prompt = req.prompt
# 智能判断如果提示词太短少于50字则认为用户未进行润色自动执行润色
# 如果用户使用了"一键扩写"功能,提示词通常会很长,这里就会跳过自动润色,尊重用户的修改
if len(final_prompt) < 50:
logger.info(f"提示词较短({len(final_prompt)}字),执行自动润色...")
refine_system_prompt = "你是一个资深的AI绘画提示词专家。你的任务是将用户简短的描述扩充为一段详细、高质量的画面描述提示词用于生成宣传海报。"
refine_user_prompt = f"""
请根据以下主题,为充电企业“驿来特”设计一张宣传海报的画面描述。
主题:{req.prompt}
要求:
1. 描述画面主体、背景、光影、色彩、构图。
2. 风格要求现代感、科技感、精美、3D渲染风格或高品质插画风格。
3. 融入新能源、绿色环保、充电桩等元素。
4. 关于品牌元素画面中可自然融入品牌Logo的视觉元素如配色、形状能用多少就用多少有元素体现即可不必生搬硬套保持画面自然和谐。
5. 直接输出提示词内容,不要包含“好的”、“以下是”等无关废话。
6. 字数在100-300字之间。
"""
refined_prompt = ""
try:
async for chunk in get_llm_response(query_text=refine_user_prompt, system_prompt=refine_system_prompt, stream=False):
refined_prompt += chunk
if refined_prompt and refined_prompt.strip():
final_prompt = refined_prompt
except Exception as e:
logger.error(f"自动润色失败,使用原始提示词: {e}")
logger.info(f"Final generation prompt: {final_prompt}")
# 1.2 调用生图
resp = await client.generate_image(prompt=final_prompt, size=f"{req.width}x{req.height}")
# 定义Logo路径
LOGO_PATH = r"d:\dsWork\aiData\static\Images\login_logo.png"
obs_urls = await client.download_and_upload_to_obs(resp, overlay_logo_path=LOGO_PATH)
if not obs_urls:
raise Exception("未获取到有效的图片URL")
return obs_urls[0]
# 2. 构造生文任务
async def generate_text_task():
scheme_prompt = f"""
你是一个专业的社群运营专家。请为充电企业“驿来特”撰写一段发在微信群里的宣传文案。
主题:{req.prompt}
要求:
1. 语气亲切、有吸引力,适合微信社群传播。
2. 突出“驿来特”品牌,强调新能源、优惠、便利等特点(根据主题自由发挥)。
3. 包含适当的emoji表情增加趣味性。
4. 字数控制在150字以内。
5. 格式清晰,分段合理。
"""
# get_llm_response 是一个异步生成器 (stream=True by default) 或者直接返回 (stream=False)
# 这里我们强制 stream=False 获取完整文本
text_response = ""
# LlmUtil.get_llm_response 默认为 stream=True我们需要修改调用方式或适配
# 查看 LlmUtil 源码,如果 stream=False它 yield 内容。
# 所以我们需要迭代它
async for chunk in get_llm_response(query_text=scheme_prompt, stream=False):
text_response += chunk
return text_response
# 3. 并行执行
image_url, scheme_content = await asyncio.gather(generate_image_task(), generate_text_task())
# 4. 保存到数据库
db = Db()
record_id = str(uuid.uuid4())
created_at = datetime.now()
insert_sql = """
INSERT INTO haibao_history (id, prompt, image_url, scheme_content, created_at)
VALUES (:id, :prompt, :image_url, :scheme_content, :created_at)
"""
params = {
"id": record_id,
"prompt": req.prompt,
"image_url": image_url,
"scheme_content": scheme_content,
"created_at": created_at
}
async with db.get_session() as session:
async with session.begin():
await session.execute(text(insert_sql), params)
return {
"id": record_id,
"image_url": image_url,
"prompt": req.prompt,
"scheme_content": scheme_content,
"created_at": created_at.strftime("%Y-%m-%d %H:%M:%S")
}
except Exception as e:
logger.error(f"生成海报/文案失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/history")
async def get_history():
"""获取海报生成历史"""
db = Db()
sql = "SELECT * FROM haibao_history ORDER BY created_at DESC LIMIT 50"
try:
result = await db.find(sql)
formatted_result = []
for item in result:
item_dict = dict(item) if not isinstance(item, dict) else item
if isinstance(item_dict.get('created_at'), datetime):
item_dict['created_at'] = item_dict['created_at'].strftime("%Y-%m-%d %H:%M:%S")
# 确保 scheme_content 存在
if 'scheme_content' not in item_dict:
item_dict['scheme_content'] = ""
formatted_result.append(item_dict)
return formatted_result
except Exception as e:
logger.error(f"获取历史失败: {e}")
return []