Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import zipfile | |
| import tempfile | |
| import shutil | |
| import json | |
| from typing import List, Dict, Any, Optional, Union | |
| from PIL import Image | |
| from app.api import get_chat_completion | |
| from app.config import ( | |
| STICKER_RERANKING_SYSTEM_PROMPT, | |
| PUBLIC_URL, | |
| TEMP_DIR | |
| ) | |
| from app.database import db | |
| from app.image_utils import ( | |
| save_image_temp, | |
| upload_to_huggingface, | |
| get_image_cdn_url, | |
| get_image_description, | |
| calculate_image_hash | |
| ) | |
| from app.gradio_formatter import gradio_formatter | |
| from multiprocessing import Queue | |
# Logging setup: root handler at INFO, lines formatted as "<time> <LEVEL> <message>".
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class StickerService:
    """Sticker service: business logic for uploading, importing, searching,
    listing, deleting and LLM-reranking stickers.

    None of the methods use instance state, so they are all
    ``@staticmethod``s. They work both on the class
    (``StickerService.search_stickers(...)``) and on an instance
    (``StickerService().search_stickers(...)``). Without the decorator, an
    instance call would bind the instance to the first parameter.
    """

    # Lower-case file extensions treated as images when importing a dataset.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.webp')

    @staticmethod
    def _description_image_url(local_path: str, stored_path: str) -> str:
        """Return the image URL that is passed to ``get_image_description``.

        When ``PUBLIC_URL`` is configured, the local file is served through the
        Gradio file route. Otherwise the URL comes from
        ``get_image_cdn_url(stored_path)``.
        """
        if PUBLIC_URL:
            return f'{PUBLIC_URL}/gradio_api/file={local_path}'
        return get_image_cdn_url(stored_path)

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload one sticker image to HuggingFace and store its record.

        Args:
            image_file_path: Local path of the image to upload.
            title: Title stored with the sticker.
            description: Description stored with the sticker. When empty, one
                is generated from the image via ``get_image_description``.
            tags: Tags stored with the sticker.

        Returns:
            ``"Upload successful! <filename>"`` on success. Otherwise returns
            ``"Upload failed: <reason>"``; the reason is ``File_Exists`` when an
            image with the same hash is already stored. Exceptions are caught
            and reported through this string, not raised.
        """
        try:
            # Hash the image to detect duplicates; close the file handle promptly.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)
            if db.check_image_exists(image_hash):
                logger.info("文件已存在 %s", image_hash)
                raise Exception('File_Exists')

            file_path, image_filename = upload_to_huggingface(image_file_path)

            # No description supplied: ask the description model for one.
            if not description:
                image_cdn_url = StickerService._description_image_url(image_file_path, file_path)
                logger.debug("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.error(f"Upload failed: {str(e)}")
            return f"Upload failed: {str(e)}"

    @staticmethod
    def _load_descriptions(dataset_dir: str) -> Dict[str, str]:
        """Map file name to description from ``dataset_dir/data.json``.

        The file must hold a list of objects with ``filename`` and ``content``
        keys. Returns ``{}`` when the file is absent.
        """
        data_json_path = os.path.join(dataset_dir, 'data.json')
        if not os.path.exists(data_json_path):
            return {}
        with open(data_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        descriptions = {x["filename"]: x["content"] for x in data}
        logger.info("Loaded descriptions from data.json")
        return descriptions

    @staticmethod
    def _import_one(
        image_path: str,
        file_name: str,
        descriptions: Dict[str, str],
        upload: bool,
        gen_description: bool,
    ) -> tuple:
        """Import one extracted image file.

        Returns ``(result_message, progress_status)``. ``result_message`` is
        None when nothing should be added to the results list. Exceptions
        propagate to the caller.
        """
        with Image.open(image_path) as image:
            image_hash = calculate_image_hash(image)
            if db.check_image_exists(image_hash):
                return f"跳过已存在的图片: {file_name}", "Skipped (exists)"
            temp_file_path = save_image_temp(image)

        keep_temp_file = False
        try:
            if upload:
                file_path, _image_filename = upload_to_huggingface(temp_file_path)
            else:
                # Without upload, the temp copy itself is the stored file.
                file_path = temp_file_path

            # Prefer the dataset's own description; optionally generate one.
            if file_name in descriptions:
                description = descriptions[file_name]
            elif gen_description:
                description = get_image_description(
                    StickerService._description_image_url(temp_file_path, file_path)
                )
            else:
                description = None

            if not description:
                logger.warning(f"No description available for {file_name}")
                return None, "Skipped (no description)"

            db.store_sticker("", description, "", file_path, image_hash)
            # When not uploading, the stored record references the temp file.
            keep_temp_file = not upload
            return f"成功导入: {file_name}", "Imported"
        finally:
            # Remove the temp copy unless a stored record points at it.
            if not keep_temp_file and os.path.exists(temp_file_path):
                os.unlink(temp_file_path)

    @staticmethod
    def import_stickers(
        sticker_dataset: str,
        upload: bool = False,
        gen_description: bool = False,
        progress_callback: callable = None,
        total_files: int = 0
    ) -> List[str]:
        """Import a zipped sticker dataset.

        The archive is extracted into a fresh private subdirectory of
        ``TEMP_DIR``, and only that subdirectory is removed afterwards. If the
        archive root contains ``data.json`` (a list of objects with
        ``filename``/``content`` keys), those contents are used as descriptions,
        keyed by file name.

        Args:
            sticker_dataset: Path of the zip archive.
            upload: Upload each image to HuggingFace. When False, the local
                temp-file path is stored instead.
            gen_description: Generate an AI description for images that have
                no entry in data.json.
            progress_callback: Optional ``callback(file_name, status)``, invoked
                once per image file.
            total_files: Total file count. Accepted for API compatibility but
                not used here.

        Returns:
            Result lines for the processed images. If the import aborts, an
            ``"导入失败: <reason>"`` line is appended.
        """
        results: List[str] = []
        temp_dir: Optional[str] = None
        try:
            # Use a unique subdirectory so cleanup never deletes the shared
            # TEMP_DIR or files that other code keeps there.
            if TEMP_DIR:
                os.makedirs(TEMP_DIR, exist_ok=True)
            temp_dir = tempfile.mkdtemp(dir=TEMP_DIR or None)
            logger.info(f"Created temporary directory: {temp_dir}")

            with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            logger.info(f"Extracted dataset to: {temp_dir}")

            descriptions = StickerService._load_descriptions(temp_dir)

            for root, _dirs, files in os.walk(temp_dir):
                for file in files:
                    if not file.lower().endswith(StickerService._IMAGE_EXTENSIONS):
                        continue
                    image_path = os.path.join(root, file)
                    try:
                        message, status = StickerService._import_one(
                            image_path, file, descriptions, upload, gen_description
                        )
                        if message:
                            results.append(message)
                        if progress_callback:
                            progress_callback(file, status)
                    except Exception as e:
                        logger.error(f"Failed to process image {file}: {str(e)}")
                        results.append(f"处理失败 {file}: {str(e)}")
                        if progress_callback:
                            progress_callback(file, f"Failed: {str(e)}")
            return results
        except Exception as e:
            logger.error(f"Import failed: {str(e)}")
            results.append(f"导入失败: {str(e)}")
            return results
        finally:
            if temp_dir and os.path.exists(temp_dir):
                try:
                    shutil.rmtree(temp_dir)
                    logger.info(f"Cleaned up temporary directory: {temp_dir}")
                except OSError as e:
                    # A failed cleanup must not replace the results being returned.
                    logger.warning(f"Failed to clean up temporary directory {temp_dir}: {e}")

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking: bool = False) -> List[Dict[str, Any]]:
        """Search stickers matching ``description`` via ``db.search_stickers``.

        Args:
            description: Query text. An empty query returns ``[]`` immediately.
            limit: Maximum number of hits requested from the database.
            reranking: Reorder the hits with ``rerank_search_results``.

        Returns:
            The hits, or ``[]`` if the database search raises.
        """
        if not description:
            return []
        try:
            results = db.search_stickers(description, limit)
            if reranking:
                results = StickerService.rerank_search_results(description, results, limit)
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Return up to ``limit`` stickers as rows for display.

        The rows come from ``gradio_formatter.format_all_stickers``. Returns
        ``[]`` on error.
        """
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete the sticker with ``sticker_id`` via ``db.delete_sticker``.

        Returns a success message, or ``"Delete failed: <reason>"`` if the
        database call raises. No existence check is made here, so the success
        message does not prove that a record was removed.
        """
        try:
            db.delete_sticker(sticker_id)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Delete failed: {str(e)}")
            return f"Delete failed: {str(e)}"

    @staticmethod
    def _parse_llm_json(response: str) -> Any:
        """Parse an LLM reply as JSON, tolerating a surrounding Markdown code fence."""
        text = (response or "").strip()
        if text.startswith("```"):
            # Drop the opening fence line (e.g. ```json) and the closing fence.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        return json.loads(text)

    @staticmethod
    def _apply_reranking(
        sticker_list: List[Dict[str, Any]],
        reranked: List[Any],
        limit: Optional[int],
    ) -> List[Dict[str, Any]]:
        """Merge LLM scores into the hits and return them best-first.

        ``reranked`` entries that are not dicts, or that lack ``sticker_id``,
        are ignored. Each hit's ``entity`` dict receives ``score`` and
        ``reason``; this mutates the hits in place. Each hit is returned at
        most once, and at most ``limit`` hits are returned.
        """
        valid = [item for item in reranked if isinstance(item, dict) and "sticker_id" in item]
        valid.sort(key=lambda item: float(item.get("score") or 0), reverse=True)

        # Index hits by string id; keep the first hit for a repeated id.
        hits_by_id: Dict[str, Dict[str, Any]] = {}
        for hit in sticker_list:
            hits_by_id.setdefault(str(hit["id"]), hit)

        ranked: List[Dict[str, Any]] = []
        for item in valid:
            hit = hits_by_id.pop(str(item["sticker_id"]), None)
            if hit is None:
                continue  # unknown id, or already placed earlier
            hit["entity"]["score"] = item.get("score")
            hit["entity"]["reason"] = item.get("reason")
            ranked.append(hit)
        return ranked[:limit]

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Reorder search hits by LLM-judged relevance to ``query``.

        Each hit is read as ``{"id": ..., "entity": {"description": ...}}``.
        The LLM reply must be a JSON list of objects carrying ``sticker_id``,
        ``score`` and (optionally) ``reason``. Matching hits are returned
        best-first, with score and reason written into their ``entity`` dict.
        At most ``limit`` hits are returned.

        If the LLM call or parsing fails, the original hits are returned in
        their original order, truncated to ``limit``. A reranking problem
        therefore does not discard the search results.
        """
        try:
            candidates = [
                {"id": hit["id"], "description": hit["entity"]["description"]}
                for hit in sticker_list
            ]
            user_prompt = f"请分析关键词 '{query}' 与以下表情包的相关性:\n{candidates}"
            logger.debug(">>> 使用 LLM 模型重新排序.... %s %s", user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)

            response = get_chat_completion(user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)
            reranked = StickerService._parse_llm_json(response)
            if not isinstance(reranked, list):
                raise ValueError("Invalid response format")
            logger.debug(">>> LLM 排序结果 %s", reranked)

            rerank_results = StickerService._apply_reranking(sticker_list, reranked, limit)
            logger.debug(">>> rerank_results %s", rerank_results)
            return rerank_results
        except Exception as e:
            logger.error(f"Reranking failed: {str(e)}")
            return list(sticker_list or [])[:limit]
# Module-level StickerService instance.
sticker_service: StickerService = StickerService()