Files
aiData/Util/BananaClient.py
HuangHai 79954171c6 'commit'
2026-01-20 09:26:42 +08:00

551 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
import os
import re
import sys
import uuid
from Config.Config import *
import aiohttp
# Resolve the project root (the parent of this file's directory) and append it
# to sys.path -- presumably so the `Util` package import below resolves when
# this file is executed directly; confirm against the project layout.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)
from Util.ObsUtil import ObsUploader
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class BananaClient:
    """
    Generic Banana image-generation client supporting the TuZi and GPTNB providers.

    The client POSTs generation requests to ``{base_url}/v1/images/generations``
    and relays images through OBS using ``ObsUploader``: local input images are
    uploaded before a request, and generated images can be downloaded and
    re-uploaded afterwards.
    """

    # Per-provider connection defaults and the model names known for each one.
    PROVIDERS = {
        "TuZi": {
            "api_key": TUZI_API_KEY,
            "base_url": TUZI_BASE_URL,
            "default_model": "gemini-3-pro-image-preview",
            "models": {
                "gemini-3-pro-image-preview": "gemini-3-pro-image-preview",
                "gemini-3-pro-image-preview-2k": "gemini-3-pro-image-preview-2k",
                "gemini-3-pro-image-preview-4k": "gemini-3-pro-image-preview-4k"
            }
        },
        "GPTNB": {
            "api_key": GPTNB_API_KEY,
            "base_url": GPTNB_BASE_URL,
            "default_model": "nano-banana-2",
            "models": {
                "nano-banana-2": "nano-banana-2",
                "nano-banana-2-4k": "nano-banana-2-4k",
                "nano-banana-2-hd": "nano-banana-2-hd"
            }
        }
    }

    # Total timeout (seconds) for downloading one generated image.
    DOWNLOAD_TIMEOUT = 120

    def __init__(self, provider=None, api_key=None, base_url=None, model=None, response_format="url", timeout=120):
        """
        Initialise the client.

        Args:
            provider: Provider name, "TuZi" or "GPTNB". When None, the value of
                the BANANA_MODEL config name is used; if that is missing or
                empty, "TuZi" is used.
            api_key: API key; defaults to the provider's configured key.
            base_url: Base URL; defaults to the provider's configured URL.
            model: Model name; defaults to the provider's default model.
            response_format: Value sent as "response_format" (default "url").
            timeout: Timeout in seconds (default 120); used as the HTTP request
                timeout and also sent in the request payload.

        Raises:
            ValueError: If the resolved provider is not a key of PROVIDERS.
        """
        if provider is None:
            # BANANA_MODEL comes from the star-imported config and may not exist.
            try:
                provider = BANANA_MODEL
            except NameError:
                provider = None
            if provider:
                logger.info(f"从配置中读取 Banana 提供商: {provider}")
            else:
                provider = "TuZi"
                logger.warning("未在配置中找到 BANANA_MODEL默认使用 TuZi")

        if provider not in self.PROVIDERS:
            raise ValueError(f"不支持的供应商: {provider}。支持的供应商: {list(self.PROVIDERS.keys())}")

        self.provider_name = provider
        provider_config = self.PROVIDERS[provider]
        self.api_key = api_key or provider_config["api_key"]
        self.base_url = base_url or provider_config["base_url"]
        self.model = model or provider_config["default_model"]
        self.available_models = provider_config["models"]
        self.response_format = response_format
        self.timeout = timeout
        self.obs_uploader = ObsUploader()
        # Lower-cased provider identifier, used when building OBS paths.
        self.provider_code = provider.lower()
        logger.info(f"初始化 BananaClient - 供应商: {self.provider_name}, base_url: {self.base_url}, 模型: {self.model}")

    async def _prepare_single_image(self, image_path):
        """Turn one image reference (URL or local file path) into an ``image_url`` entry."""
        image_path = image_path.strip()
        logger.info(f"开始处理图片: {image_path}")
        try:
            if re.match(r'^https?://', image_path):
                # Already a URL: pass it through unchanged.
                logger.info(f"图片是URL类型直接使用: {image_path}")
                return {
                    "type": "image_url",
                    "image_url": {"url": image_path}
                }

            # Local file: validate it, then upload it to OBS under a random name.
            logger.info(f"图片是本地文件开始上传到OBS: {image_path}")
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"本地图片文件不存在: {image_path}")
            if not os.path.isfile(image_path):
                raise IsADirectoryError(f"指定路径不是文件: {image_path}")
            file_size = os.path.getsize(image_path)
            logger.info(f"本地文件存在,大小: {file_size} bytes")

            file_guid = str(uuid.uuid4())
            _, ext = os.path.splitext(image_path)
            ext = ext.lower() if ext else '.jpg'  # default to .jpg when there is no extension
            obs_key = f"{OBS_TMP_PREFIX}/{file_guid}{ext}"
            logger.debug(f"构建OBS key: {obs_key}")

            logger.info(f"正在上传到OBS: {obs_key}")
            # NOTE(review): upload_file is called synchronously inside a coroutine;
            # if it performs blocking I/O it will stall the event loop -- confirm.
            success, result = self.obs_uploader.upload_file(
                file_path=image_path,
                object_key=obs_key,
                bucket_name=OBS_BUCKET
            )
            if not success:
                logger.error(f"上传到OBS失败: {result}")
                raise RuntimeError(f"上传到OBS失败: {result}")

            cdn_url = f"https://{CDN_DOMAIN}/{obs_key}"
            logger.info(f"✅ 图片上传到OBS成功: {cdn_url}")
            logger.info(f" OBS Key: {obs_key}")
            return {
                "type": "image_url",
                "image_url": {"url": cdn_url}
            }
        except Exception as e:
            logger.error(f"处理单个图片时出错: {image_path} - {str(e)}")
            raise

    async def prepare_image_data(self, image_paths):
        """
        Prepare image inputs for an API call.

        URLs (``http://`` / ``https://``) are used as-is; local files are
        uploaded to OBS and referenced by their CDN URL.

        Args:
            image_paths: A single path/URL string, or a list (or tuple) of them.

        Returns:
            One ``{"type": "image_url", "image_url": {"url": ...}}`` dict for a
            string input, or a list of such dicts for a list/tuple input.

        Raises:
            TypeError: If image_paths is neither a string nor a list/tuple.
            FileNotFoundError: If a local path does not exist.
            IsADirectoryError: If a local path is not a regular file.
            RuntimeError: If the OBS upload reports failure.
        """
        try:
            logger.info(f"准备图片数据: {image_paths}")
            if isinstance(image_paths, str):
                return await self._prepare_single_image(image_paths)
            if isinstance(image_paths, (list, tuple)):
                result = []
                for i, image_path in enumerate(image_paths):
                    logger.info(f"处理列表中的图片 {i+1}/{len(image_paths)}: {image_path}")
                    result.append(await self._prepare_single_image(image_path))
                logger.info(f"成功处理 {len(result)} 张图片")
                return result
            raise TypeError(f"image_paths参数类型错误应为str或list实际为{type(image_paths)}")
        except Exception as e:
            logger.error(f"准备图片数据时出错: {e}")
            raise

    async def generate_image(self, prompt, image_paths=None, n=1, size="1024x1024"):
        """
        Call the provider API to generate images (text-to-image, or image+text editing).

        Args:
            prompt: Text description of the image.
            image_paths: Optional path/URL or list of them; when given, the
                resolved URL(s) are sent in the payload's "image" field.
            n: Number of images to request (default 1).
            size: Image size string (default "1024x1024").

        Returns:
            The decoded JSON body of the API response.

        Raises:
            Exception: Any error from input preparation or the HTTP call is
                logged and re-raised unchanged.
        """
        current_model = self.model
        if current_model not in self.available_models:
            # Only a warning: the request is still sent with the unknown model.
            logger.warning(f"警告: 模型 {current_model} 可能不在 {self.provider_name} 的支持列表中: {list(self.available_models.keys())}")
        logger.info(f"开始生成图片 - 提供商: {self.provider_name}, 模型: {current_model}, 尺寸: {size}, 数量: {n}")
        logger.info(f"提示词: {prompt}")

        # Fields shared by both request kinds.
        payload = {
            "prompt": prompt,
            "n": n,
            "model": current_model,
            "size": size,
            "response_format": self.response_format,
            "timeout": self.timeout
        }
        if image_paths:
            logger.info(f"图片输入: {image_paths}")
            image_data = await self.prepare_image_data(image_paths)
            entries = image_data if isinstance(image_data, list) else [image_data]
            image_urls = [
                entry["image_url"]["url"]
                for entry in entries
                if "image_url" in entry and "url" in entry["image_url"]
            ]
            # A single URL is sent as a plain string, several as a list.
            payload["image"] = image_urls[0] if len(image_urls) == 1 else image_urls
            logger.info(f"构建改图请求,使用模型: {current_model},尺寸: {size}")
        else:
            logger.info(f"构建文字生图请求,使用模型: {current_model},尺寸: {size}")

        # Strip a trailing slash so a base URL like "https://host/" does not yield "//v1/...".
        endpoint = f"{str(self.base_url).rstrip('/')}/v1/images/generations"
        headers = {
            'Content-Type': 'application/json',
            "Authorization": f"Bearer {self.api_key}"
        }
        # Never write the API key to the log: mask the Authorization value.
        safe_headers = dict(headers, Authorization="Bearer ***")
        logger.debug(f"请求头: {safe_headers}")
        logger.debug(f"请求payload: {payload}")
        try:
            logger.info(f"正在异步调用{self.provider_name} API...")
            logger.info(f"使用端点: {endpoint}")
            request_timeout = aiohttp.ClientTimeout(total=self.timeout)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    endpoint,
                    json=payload,
                    headers=headers,
                    timeout=request_timeout
                ) as response:
                    response.raise_for_status()  # raise on HTTP error statuses
                    response_json = await response.json()
                    logger.info("API调用完成")
                    logger.debug(f"API响应: {response_json}")
                    return response_json
        except Exception as e:
            logger.error(f"API调用失败: {str(e)}")
            raise

    @staticmethod
    def _collect_urls(node):
        """Recursively collect URL strings from a str / dict ("url" or nested "data") / list tree."""
        if isinstance(node, str):
            return [node] if node else []
        if isinstance(node, dict):
            url = node.get("url")
            if isinstance(url, str) and url:
                return [url]
            if "data" in node:
                return BananaClient._collect_urls(node["data"])
            return []
        if isinstance(node, list):
            collected = []
            for child in node:
                collected.extend(BananaClient._collect_urls(child))
            return collected
        return []

    async def get_image_urls(self, response_json):
        """
        Extract image URLs from an API response.

        The response's "data" value may be a URL string, a dict, or a list;
        dicts contribute their "url" value or, failing that, whatever is found
        under a nested "data" key. If nothing is found there, a top-level
        "url" key is used as a fallback.

        Args:
            response_json: Decoded JSON response, or None.

        Returns:
            A list of non-empty URL strings (possibly empty).
        """
        if response_json is None:
            logger.info("未提供响应数据")
            return []

        image_urls = []
        if isinstance(response_json, dict):
            image_urls = self._collect_urls(response_json.get("data"))
            # Fallback for responses that put 'url' at the top level.
            if not image_urls:
                top_level_url = response_json.get("url")
                if isinstance(top_level_url, str) and top_level_url:
                    image_urls.append(top_level_url)

        if not image_urls:
            logger.warning(f"Failed to extract image URLs from response: {response_json}")
        return image_urls

    async def print_image_urls(self, response_json):
        """
        Log every image URL found in the response.

        Args:
            response_json: Decoded JSON response from the API.
        """
        logger.info("\n正在解析图片地址...")
        image_urls = await self.get_image_urls(response_json)
        if not image_urls:
            logger.info("未找到有效的图片 URL")
            return
        for i, url in enumerate(image_urls):
            logger.info(f"图片 {i+1} URL: {url}")

    @staticmethod
    def _guess_extension(image_url):
        """Return 'png', 'jpg' or 'webp' based on the URL's suffix; 'jpg' when unknown."""
        from urllib.parse import urlparse

        # Check the URL path first so query strings and fragments (e.g. signed
        # URLs) do not hide the suffix, then fall back to the raw URL.
        for candidate in (urlparse(image_url).path.lower(), image_url.lower()):
            if candidate.endswith('.png'):
                return 'png'
            if candidate.endswith(('.jpg', '.jpeg')):
                return 'jpg'
            if candidate.endswith('.webp'):
                return 'webp'
        return 'jpg'

    @staticmethod
    def _apply_watermark(image_data, overlay_logo_path):
        """Paste the logo onto the image's top-left corner and return the re-encoded bytes."""
        from PIL import Image
        import io

        img = Image.open(io.BytesIO(image_data))
        # Keep the source format; fall back to PNG if it is unknown.
        save_format = img.format if img.format else 'PNG'

        # Convert to RGBA so the logo can always serve as its own paste mask
        # (a palette-mode image is not accepted as a mask); the context
        # manager closes the logo file handle.
        with Image.open(overlay_logo_path) as logo_file:
            logo = logo_file.convert("RGBA")

        # Logo width is 20% of the image width, aspect ratio preserved.
        target_logo_width = max(1, int(img.width * 0.2))
        aspect_ratio = logo.height / logo.width
        target_logo_height = max(1, int(target_logo_width * aspect_ratio))
        logo = logo.resize((target_logo_width, target_logo_height), Image.Resampling.LANCZOS)

        # Top-left corner with padding equal to 3% of the image width.
        padding = int(img.width * 0.03)
        img.paste(logo, (padding, padding), logo)

        output_buffer = io.BytesIO()
        img.save(output_buffer, format=save_format)
        return output_buffer.getvalue()

    async def download_and_upload_to_obs(self, response_json, overlay_logo_path=None):
        """
        Download each generated image and upload it to OBS without writing it to disk.

        Failures are handled per image: the error is logged and processing
        continues with the next URL. A watermark failure is also non-fatal;
        the un-watermarked image is uploaded instead.

        Args:
            response_json: Decoded JSON response from the API.
            overlay_logo_path: Optional path of a logo image to overlay as a
                watermark; ignored when the file does not exist.

        Returns:
            List of CDN URLs of the successfully uploaded images.
        """
        logger.info("\n正在下载图片并上传到OBS...")
        image_urls = await self.get_image_urls(response_json)
        uploaded_urls = []
        if not image_urls:
            logger.info("未找到有效的图片 URL")
            return uploaded_urls
        logger.info(f"找到 {len(image_urls)} 个图片URL开始下载并上传到OBS")

        use_watermark = bool(overlay_logo_path) and os.path.exists(overlay_logo_path)
        if overlay_logo_path and not use_watermark:
            logger.warning(f"水印Logo文件不存在,跳过水印: {overlay_logo_path}")

        download_timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT)
        # One session is shared by all downloads instead of one per image.
        async with aiohttp.ClientSession(timeout=download_timeout) as session:
            for i, image_url in enumerate(image_urls):
                try:
                    logger.info(f"正在处理图片 {i+1}/{len(image_urls)}: {image_url}")
                    file_guid = str(uuid.uuid4())
                    ext = self._guess_extension(image_url)
                    obs_key = f"{OBS_CLOUD_PREFIX}/{self.provider_code.capitalize()}/{file_guid}.{ext}"
                    logger.debug(f"构建OBS key: {obs_key}")

                    logger.info(f"正在下载图片: {image_url}")
                    async with session.get(image_url) as response:
                        response.raise_for_status()
                        image_data = await response.read()
                    logger.debug(f"成功下载图片,大小: {len(image_data)} 字节")

                    if use_watermark:
                        try:
                            logger.info(f"正在添加水印: {overlay_logo_path}")
                            image_data = self._apply_watermark(image_data, overlay_logo_path)
                            logger.info("水印添加成功")
                        except Exception as e:
                            # Best effort: keep going and upload the original bytes.
                            logger.error(f"添加水印失败: {e}")

                    logger.info(f"正在上传到OBS: {obs_key}")
                    # NOTE(review): raw image bytes are passed as `base64_data`;
                    # confirm ObsUploader.upload_base64_image accepts bytes.
                    success, result = self.obs_uploader.upload_base64_image(
                        object_key=obs_key,
                        base64_data=image_data,
                        bucket_name=OBS_BUCKET
                    )
                    if success:
                        cdn_url = f"https://{CDN_DOMAIN}/{obs_key}"
                        logger.info(f"✅ 图片 {i+1} 上传成功: {cdn_url}")
                        logger.info(f" OBS Key: {obs_key}")
                        uploaded_urls.append(cdn_url)
                    else:
                        logger.error(f"上传到OBS失败: {result}")
                except Exception as e:
                    logger.error(f"处理图片 {image_url} 失败: {str(e)}")
                    import traceback
                    logger.error(f"详细错误堆栈: {traceback.format_exc()}")

        logger.info(f"图片处理完成,成功上传 {len(uploaded_urls)}/{len(image_urls)} 个图片到OBS")
        return uploaded_urls
async def test_provider(provider_name):
    """
    Smoke-test one provider by running a single text-to-image request.

    Any exception is caught and logged together with its traceback, so this
    coroutine never raises.

    Args:
        provider_name: Provider key handed to ``BananaClient``.
    """
    banner = '=' * 20
    print(f"\n{banner} 测试 {provider_name} 供应商 {banner}")
    try:
        banana = BananaClient(provider=provider_name)

        # Test 1: text-to-image.
        logger.info("\n---------------- 文字生图测试 ----------------")
        prompt_text = "一只可爱的猫咪坐在月亮上,星空背景,卡通风格"
        logger.info(f"提示词: {prompt_text}")
        result = await banana.generate_image(prompt=prompt_text, n=1, size="1024x1024")

        if not result:
            logger.error("文字生图失败!")
            return

        logger.info("文字生图成功!")
        await banana.print_image_urls(result)
        # Optional follow-up (disabled): upload the generated images to OBS.
        # await banana.download_and_upload_to_obs(result)
    except Exception as exc:
        logger.error(f"{provider_name} 测试失败: {exc}")
        import traceback
        logger.error(traceback.format_exc())
async def main():
    """
    Async entry point that exercises ``BananaClient`` against the TuZi provider.

    Configures INFO-level logging to stdout, then runs the provider smoke test.
    """
    # Send log records to stdout.
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
        handlers=[stdout_handler],
    )

    logger.info("开始 BananaClient 整合测试...")

    # TuZi is the default provider; the GPTNB run is currently disabled.
    await test_provider("TuZi")
    # await test_provider("GPTNB")

    divider = "=" * 50
    logger.info("\n" + divider)
    logger.info("测试全部完成!")
    logger.info(divider)
# Run the async smoke test only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())