import asyncio import logging import os import re import sys import uuid from Config.Config import * import aiohttp # 配置日志 # 获取项目根目录 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(project_root) from Util.ObsUtil import ObsUploader # 配置日志 logger = logging.getLogger(__name__) class BananaClient: """ 通用 Banana 图片生成客户端,支持 TuZi 和 GPTNB 供应商 """ PROVIDERS = { "TuZi": { "api_key": TUZI_API_KEY, "base_url": TUZI_BASE_URL, "default_model": "gemini-3-pro-image-preview", "models": { "gemini-3-pro-image-preview": "gemini-3-pro-image-preview", "gemini-3-pro-image-preview-2k": "gemini-3-pro-image-preview-2k", "gemini-3-pro-image-preview-4k": "gemini-3-pro-image-preview-4k" } }, "GPTNB": { "api_key": GPTNB_API_KEY, "base_url": GPTNB_BASE_URL, "default_model": "nano-banana-2", "models": { "nano-banana-2": "nano-banana-2", "nano-banana-2-4k": "nano-banana-2-4k", "nano-banana-2-hd": "nano-banana-2-hd" } } } def __init__(self, provider=None, api_key=None, base_url=None, model=None, response_format="url", timeout=120): """ 初始化 BananaClient Args: provider: 供应商名称,"TuZi" 或 "GPTNB"。如果为None则从配置读取BANANA_MODEL api_key: API Key,如果不提供则使用配置中的默认值 base_url: Base URL,如果不提供则使用配置中的默认值 model: 模型名称,如果不提供则使用供应商的默认模型 response_format: 响应格式,默认为 "url" timeout: 超时时间,默认为 120 秒 """ if provider is None: # 尝试从全局配置获取 try: provider = BANANA_MODEL except NameError: provider = None if provider: logger.info(f"从配置中读取 Banana 提供商: {provider}") else: provider = "TuZi" logger.warning("未在配置中找到 BANANA_MODEL,默认使用 TuZi") if provider not in self.PROVIDERS: raise ValueError(f"不支持的供应商: {provider}。支持的供应商: {list(self.PROVIDERS.keys())}") self.provider_name = provider provider_config = self.PROVIDERS[provider] self.api_key = api_key or provider_config["api_key"] self.base_url = base_url or provider_config["base_url"] self.model = model or provider_config["default_model"] self.available_models = provider_config["models"] self.response_format = response_format self.timeout = timeout self.obs_uploader = ObsUploader() # 统一 provider 标识用于 OBS 路径等 (全小写) self.provider_code = provider.lower() logger.info(f"初始化 BananaClient - 供应商: {self.provider_name}, base_url: {self.base_url}, 模型: {self.model}") async def prepare_image_data(self, image_paths): """ 准备图片数据,将本地文件上传到OBS,URL直接使用,包装为正确的API调用格式 参数: image_paths: 单个图片路径字符串或图片路径列表,可以是本地文件路径或URL 返回: 单个图片的API调用格式数据或图片API调用格式数据列表 """ async def process_single_image(image_path): """处理单个图片路径,支持本地文件和URL""" image_path = image_path.strip() logger.info(f"开始处理图片: {image_path}") try: # 检查是否是URL if re.match(r'^https?://', image_path): # 是URL,直接使用 logger.info(f"图片是URL类型,直接使用: {image_path}") return { "type": "image_url", "image_url": {"url": image_path} } else: # 是本地文件,上传到OBS logger.info(f"图片是本地文件,开始上传到OBS: {image_path}") if not os.path.exists(image_path): raise FileNotFoundError(f"本地图片文件不存在: {image_path}") if not os.path.isfile(image_path): raise IsADirectoryError(f"指定路径不是文件: {image_path}") file_size = os.path.getsize(image_path) logger.info(f"本地文件存在,大小: {file_size} bytes") # 生成GUID作为文件名 file_guid = str(uuid.uuid4()) # 从文件名获取扩展名 _, ext = os.path.splitext(image_path) ext = ext.lower() if ext else '.jpg' # 默认使用jpg # 构建OBS key,使用临时文件前缀 obs_key = f"{OBS_TMP_PREFIX}/{file_guid}{ext}" logger.debug(f"构建OBS key: {obs_key}") # 上传到OBS logger.info(f"正在上传到OBS: {obs_key}") success, result = self.obs_uploader.upload_file( file_path=image_path, object_key=obs_key, bucket_name=OBS_BUCKET ) if success: # 构建CDN访问URL cdn_url = f"https://{CDN_DOMAIN}/{obs_key}" logger.info(f"✅ 图片上传到OBS成功: {cdn_url}") logger.info(f" OBS Key: {obs_key}") return { "type": "image_url", "image_url": {"url": cdn_url} } else: logger.error(f"上传到OBS失败: {result}") raise Exception(f"上传到OBS失败: {result}") except Exception as e: logger.error(f"处理单个图片时出错: {image_path} - {str(e)}") raise try: logger.info(f"准备图片数据: {image_paths}") # 如果是单个图片路径字符串 if isinstance(image_paths, str): return await process_single_image(image_paths) # 如果是图片路径列表 elif isinstance(image_paths, list): result = [] for i, image_path in enumerate(image_paths): logger.info(f"处理列表中的图片 {i+1}/{len(image_paths)}: {image_path}") result.append(await process_single_image(image_path)) logger.info(f"成功处理 {len(result)} 张图片") return result else: raise TypeError(f"image_paths参数类型错误,应为str或list,实际为{type(image_paths)}") except Exception as e: logger.error(f"准备图片数据时出错: {e}") raise async def generate_image(self, prompt, image_paths=None, n=1, size="1024x1024"): """ 异步调用 API 生成图片(支持文字生图和图片+文字改图) 参数: prompt: 图片描述 image_paths: 图片路径或URL,可以是单个字符串或列表(用于改图功能) n: 生成图片数量,默认为 1 size: 图片尺寸,默认为 "1024x1024" 返回: response_json: API 响应的 JSON 数据 """ # 确保使用有效的模型 current_model = self.model if current_model not in self.available_models: logger.warning(f"警告: 模型 {current_model} 可能不在 {self.provider_name} 的支持列表中: {list(self.available_models.keys())}") logger.info(f"开始生成图片 - 提供商: {self.provider_name}, 模型: {current_model}, 尺寸: {size}, 数量: {n}") logger.info(f"提示词: {prompt}") if image_paths: logger.info(f"图片输入: {image_paths}") # 构建API调用消息 if image_paths: # 处理图片数据 image_data = await self.prepare_image_data(image_paths) # 构建payload,支持改图功能 image_urls = [] if isinstance(image_data, list): for item in image_data: if "image_url" in item and "url" in item["image_url"]: image_urls.append(item["image_url"]["url"]) else: if "image_url" in image_data and "url" in image_data["image_url"]: image_urls.append(image_data["image_url"]["url"]) # 如果只有一个URL,直接传递字符串;如果有多个,传递列表 final_image_input = image_urls[0] if len(image_urls) == 1 else image_urls payload = { "prompt": prompt, "n": n, "model": current_model, "size": size, "response_format": self.response_format, "image": final_image_input, # 支持单个URL或URL列表 "timeout": self.timeout } logger.info(f"构建改图请求,使用模型: {current_model},尺寸: {size}") else: # 仅文字生图 payload = { "prompt": prompt, "n": n, "model": current_model, "size": size, "response_format": self.response_format, "timeout": self.timeout } logger.info(f"构建文字生图请求,使用模型: {current_model},尺寸: {size}") endpoint = f"{self.base_url}/v1/images/generations" headers = { 'Content-Type': 'application/json', "Authorization": f"Bearer {self.api_key}" } logger.debug(f"请求头: {headers}") logger.debug(f"请求payload: {payload}") try: logger.info(f"正在异步调用{self.provider_name} API...") logger.info(f"使用端点: {endpoint}") async with aiohttp.ClientSession() as session: async with session.post( endpoint, json=payload, headers=headers, timeout=self.timeout ) as response: response.raise_for_status() # 检查HTTP响应状态 response_json = await response.json() logger.info("API调用完成!") logger.debug(f"API响应: {response_json}") # 使用debug级别记录完整响应 return response_json except Exception as e: logger.error(f"API调用失败: {str(e)}") raise async def get_image_urls(self, response_json): """ 异步获取图片 URL 列表 参数: response_json: API 响应的 JSON 数据 返回: image_urls: 图片 URL 列表 """ if response_json is None: logger.info("未提供响应数据") return [] # 检查数据结构并提取图片地址 image_urls = [] # 处理不同的响应格式 if "data" in response_json: data = response_json["data"] # 格式1: data是字符串URL if isinstance(data, str): image_urls.append(data) # 格式2: data是包含url字段的对象列表 elif isinstance(data, list): for item in data: if isinstance(item, dict): # 检查是否有url字段 if "url" in item: image_urls.append(item["url"]) # 检查是否有嵌套的数据结构 elif "data" in item: if isinstance(item["data"], str): image_urls.append(item["data"]) elif isinstance(item["data"], list): for nested_item in item["data"]: if isinstance(nested_item, dict) and "url" in nested_item: image_urls.append(nested_item["url"]) elif isinstance(nested_item, str): image_urls.append(nested_item) # 格式3: item本身是字符串URL elif isinstance(item, str): image_urls.append(item) # 格式4: data是一个对象,包含url字段 elif isinstance(data, dict): if "url" in data: image_urls.append(data["url"]) # 检查是否有嵌套的数据结构 elif "data" in data: if isinstance(data["data"], str): image_urls.append(data["data"]) elif isinstance(data["data"], list): for item in data["data"]: if isinstance(item, dict) and "url" in item: image_urls.append(item["url"]) elif isinstance(item, str): image_urls.append(item) # Fallback for TuZi provider which might return 'url' at top level or nested differently if not image_urls and "url" in response_json: image_urls.append(response_json["url"]) # Log if no URLs found but response exists if not image_urls: logger.warning(f"Failed to extract image URLs from response: {response_json}") return image_urls async def print_image_urls(self, response_json): """ 异步打印图片 URL 参数: response_json: API 响应的 JSON 数据 """ logger.info("\n正在解析图片地址...") image_urls = await self.get_image_urls(response_json) if image_urls: for i, url in enumerate(image_urls): logger.info(f"图片 {i+1} URL: {url}") else: logger.info("未找到有效的图片 URL") async def download_and_upload_to_obs(self, response_json, overlay_logo_path=None): """ 异步从临时图片URL下载图片并直接上传到OBS(不落盘) 参数: response_json: API 响应的 JSON 数据 overlay_logo_path: 水印Logo图片路径(可选) 返回: uploaded_urls: 上传到OBS的图片URL列表 """ logger.info("\n正在下载图片并上传到OBS...") # 获取图片URL列表 image_urls = await self.get_image_urls(response_json) uploaded_urls = [] if not image_urls: logger.info("未找到有效的图片 URL") return uploaded_urls logger.info(f"找到 {len(image_urls)} 个图片URL,开始下载并上传到OBS") for i, image_url in enumerate(image_urls): try: logger.info(f"正在处理图片 {i+1}/{len(image_urls)}: {image_url}") # 生成GUID作为文件名 file_guid = str(uuid.uuid4()) # 从URL获取文件扩展名 if image_url.lower().endswith('.png'): ext = 'png' elif image_url.lower().endswith('.jpg') or image_url.lower().endswith('.jpeg'): ext = 'jpg' elif image_url.lower().endswith('.webp'): ext = 'webp' else: ext = 'jpg' # 默认使用jpg # 构建OBS key obs_key = f"{OBS_CLOUD_PREFIX}/{self.provider_code.capitalize()}/{file_guid}.{ext}" logger.debug(f"构建OBS key: {obs_key}") # 使用异步方式下载图片 logger.info(f"正在下载图片: {image_url}") async with aiohttp.ClientSession() as session: async with session.get(image_url, timeout=120) as response: response.raise_for_status() image_data = await response.read() logger.debug(f"成功下载图片,大小: {len(image_data)} 字节") # 添加水印逻辑 if overlay_logo_path and os.path.exists(overlay_logo_path): try: from PIL import Image import io logger.info(f"正在添加水印: {overlay_logo_path}") # 打开主图 img = Image.open(io.BytesIO(image_data)) # 打开Logo logo = Image.open(overlay_logo_path) # 计算Logo大小 (宽度为原图的20%) target_logo_width = int(img.width * 0.2) aspect_ratio = logo.height / logo.width target_logo_height = int(target_logo_width * aspect_ratio) # 调整Logo大小 logo = logo.resize((target_logo_width, target_logo_height), Image.Resampling.LANCZOS) # 计算位置 (左上角,3% padding) padding = int(img.width * 0.03) position = (padding, padding) # 粘贴Logo (处理透明度) if logo.mode in ('RGBA', 'LA') or (logo.mode == 'P' and 'transparency' in logo.info): img.paste(logo, position, logo) else: img.paste(logo, position) # 保存回bytes output_buffer = io.BytesIO() # 保持原格式,如果无法识别则使用PNG save_format = img.format if img.format else 'PNG' img.save(output_buffer, format=save_format) image_data = output_buffer.getvalue() logger.info("水印添加成功") except Exception as e: logger.error(f"添加水印失败: {e}") # 失败后继续上传原图 # 直接上传到OBS(不落盘) # 直接上传到OBS(不落盘) logger.info(f"正在上传到OBS: {obs_key}") success, result = self.obs_uploader.upload_base64_image( object_key=obs_key, base64_data=image_data, bucket_name=OBS_BUCKET ) if success: # 构建OBS访问URL obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{obs_key}" # 构建CDN访问URL cdn_url = f"https://{CDN_DOMAIN}/{obs_key}" logger.info(f"✅ 图片 {i+1} 上传成功: {cdn_url}") logger.info(f" OBS Key: {obs_key}") uploaded_urls.append(cdn_url) else: logger.error(f"上传到OBS失败: {result}") except Exception as e: logger.error(f"处理图片 {image_url} 失败: {str(e)}") import traceback logger.error(f"详细错误堆栈: {traceback.format_exc()}") logger.info(f"图片处理完成,成功上传 {len(uploaded_urls)}/{len(image_urls)} 个图片到OBS") return uploaded_urls async def test_provider(provider_name): """测试指定供应商""" print(f"\n{'='*20} 测试 {provider_name} 供应商 {'='*20}") try: # 创建客户端实例 client = BananaClient(provider=provider_name) # 测试1: 文字生图 logger.info("\n---------------- 文字生图测试 ----------------") text_prompt = "一只可爱的猫咪坐在月亮上,星空背景,卡通风格" logger.info(f"提示词: {text_prompt}") response = await client.generate_image( prompt=text_prompt, n=1, size="1024x1024" ) if response: logger.info("文字生图成功!") await client.print_image_urls(response) # 可选:上传到OBS # await client.download_and_upload_to_obs(response) else: logger.error("文字生图失败!") except Exception as e: logger.error(f"{provider_name} 测试失败: {e}") import traceback logger.error(traceback.format_exc()) async def main(): """ 异步主函数,用于测试 BananaClient 的功能 """ # 配置控制台日志输出 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)] ) logger.info("开始 BananaClient 整合测试...") # 测试 TuZi (默认) await test_provider("TuZi") # 测试 GPTNB #await test_provider("GPTNB") logger.info("\n" + "="*50) logger.info("测试全部完成!") logger.info("="*50) if __name__ == "__main__": asyncio.run(main())