Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import List, Dict, Any | |
| from sentence_transformers import SentenceTransformer | |
| from pymilvus import MilvusClient, DataType | |
| import time | |
| import gradio as gr | |
| from app.config import MILVUS_DB_URL | |
# Logging setup: root handler at INFO and above, each record stamped with
# its time and severity.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-scoped logger used throughout this file
# Embedding models compared side by side in the UI (one table column each).
models = [
    "shibing624/text2vec-base-chinese",
    "BAAI/bge-small-zh",
    "BAAI/bge-base-zh",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "all-MiniLM-L6-v2",
    "all-MiniLM-L12-v2",
    "multi-qa-mpnet-base-dot-v1",
    # "bge-small-en-v1.5",  -- excluded: incompatible
    "all-mpnet-base-v2",
    "jinaai/jina-embeddings-v3",
]

# Registry of model name -> StickerSearcher, filled at start-up and on refresh.
searchers = {}
class BaseEmbeddingModel(ABC):
    """Interface for a text embedding model consumed by StickerSearcher.

    ``dimension`` and ``model_name`` are read-only properties because callers
    access them as attributes (e.g. ``model.model_name.replace(...)`` and
    ``dimension=self.model.dimension``).
    """

    @abstractmethod
    def encode(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``."""

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors returned by :meth:`encode`."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying model."""
class SentenceTransformerModel(BaseEmbeddingModel):
    """BaseEmbeddingModel backed by a sentence-transformers model."""

    def __init__(self, model_name: str):
        # trust_remote_code lets models that ship custom modelling code load
        # (presumably needed for jinaai/jina-embeddings-v3 in ``models``).
        # NOTE(review): this executes code downloaded from the Hub.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self._model_name = model_name

    def encode(self, text: str) -> List[float]:
        """Embed ``text`` and return the vector as a plain Python list."""
        return self.model.encode(text).tolist()

    @property
    def dimension(self) -> int:
        """Embedding size reported by the loaded model."""
        return self.model.get_sentence_embedding_dimension()

    @property
    def model_name(self) -> str:
        """The name this model was loaded with."""
        return self._model_name
class StickerSearcher:
    """Store sticker descriptions in a per-model Milvus collection and search them.

    Each searcher owns one collection whose name is derived from the embedding
    model's name, so several models can be compared side by side.
    """

    def __init__(self, model: "BaseEmbeddingModel"):
        self.model = model
        self.client = MilvusClient(uri=MILVUS_DB_URL)
        self.collection_name = self._collection_name_for(model.model_name)

    @staticmethod
    def _collection_name_for(model_name: str) -> str:
        """Derive the collection name: 'test_' + model name with '/' and '-' -> '_'."""
        return 'test_' + model_name.replace('/', '_').replace('-', '_')

    def init_collection(self) -> bool:
        """Drop and recreate this searcher's collection.

        Returns:
            True on success; False if any Milvus call raised (the error is logged).
        """
        try:
            self.client.drop_collection(collection_name=self.collection_name)
            # Quick-setup mode builds an 'id' primary key plus a 'vector' field of
            # the model's dimension, indexes that field and loads the collection.
            # The metric is passed explicitly because its default differs across
            # pymilvus releases. No separate create_index call is made: the
            # vector field is already indexed by quick setup.
            self.client.create_collection(
                collection_name=self.collection_name,
                dimension=self.model.dimension,
                primary_field_name='id',
                metric_type='COSINE',
                auto_id=True,
            )
            self.client.load_collection(self.collection_name)
            logger.info(f'Collection initialized: {self.collection_name}')
            return True
        except Exception as e:
            logger.exception(f'Collection init failed: {str(e)}')
            return False

    def store_vector(self, title: str, description: str, tags: List[str], file_path: str):
        """Embed ``description`` and insert it as one row, with its metadata alongside."""
        vector = self.model.encode(description)
        data = [{
            'vector': vector,
            'title': title,
            'description': description,
            'tags': tags,
            'file_name': file_path,
        }]
        self.client.insert(self.collection_name, data)

    def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        """Return the top ``limit`` hits for ``query`` from this collection.

        Each hit is a Milvus result dict; callers read ``distance`` and
        ``entity['file_name']``. Encoding and search durations are logged.
        """
        encode_start = time.perf_counter()
        query_vector = self.model.encode(query)
        encode_time = time.perf_counter() - encode_start

        search_start = time.perf_counter()
        results = self.client.search(
            collection_name=self.collection_name,
            data=[query_vector],
            limit=limit,
            output_fields=['title', 'description', 'tags', 'file_name'],
        )
        search_time = time.perf_counter() - search_start

        total_time = encode_time + search_time
        logger.info(f'模型 {self.model.model_name} Encoding耗时: {encode_time:.4f} 秒,搜索耗时: {search_time:.4f} 秒, 总耗时: {total_time:.4f} 秒')
        # One query vector was sent, so only the first result list is relevant.
        return results[0]
def create_gradio_ui():
    """Build the Gradio app that compares sticker search results across models.

    The page has a query box whose results fill a table with one column per
    entry in ``models``. It also has a refresh button that rebuilds every
    model's collection from the shared 'stickers' collection.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """

    async def search_model(model_name: str, query: str):
        """Search with one model; return [] if it is not loaded or the search raises."""
        searcher = searchers.get(model_name)
        if searcher is None:
            logger.error(f'Model not loaded: {model_name}')
            return []
        try:
            # NOTE(review): searcher.search is synchronous, so it blocks the
            # event loop while it runs.
            return searcher.search(query)
        except Exception as e:
            logger.error(f'Search failed: {model_name} | Error: {str(e)}')
            return []

    async def search_all_models(query):
        """Run ``query`` against every model and lay the hits out as table rows.

        Row i holds the 1-based rank followed by each model's i-th hit, rendered
        as a Markdown image plus its similarity, or '-' when that model returned
        fewer hits.
        """
        if not query:
            return []
        print(f'>>>> Searching From Models {query}')
        results = [await search_model(model_name, query) for model_name in models]

        formatted_results = []
        max_results = max((len(r) for r in results), default=0)
        for i in range(max_results):
            row = [i + 1]
            for model_results in results:
                if i < len(model_results):
                    hit = model_results[i]
                    image_url = f'https://huggingface.co/datasets/Nekoko/StickerSet/resolve/main/{hit["entity"]["file_name"]}'
                    # The column datatype is markdown, so the image tag renders.
                    row.append(f'![]({image_url})\n相似度: {hit["distance"]:.4f}')
                else:
                    row.append('-')
            formatted_results.append(row)
        return formatted_results

    def init_collections():
        """Rebuild every model's collection from the shared 'stickers' collection.

        Reads up to 1000 stickers, then for each model recreates its collection
        and inserts one vector per sticker. Per-model failures are logged and
        skipped.

        Returns:
            A status string shown in the UI.
        """
        try:
            client = MilvusClient(uri=MILVUS_DB_URL)
            stickers = client.query(
                collection_name='stickers',
                filter='',
                limit=1000,
                output_fields=['title', 'description', 'tags', 'file_name'],
            )
            logger.info(f'Stickers loaded: {len(stickers)}')

            def init_model(model_name):
                try:
                    # Reuse a searcher that already holds the loaded model:
                    # constructing a SentenceTransformer again is expensive.
                    searcher = searchers.get(model_name)
                    if searcher is None:
                        searcher = StickerSearcher(SentenceTransformerModel(model_name))
                    if searcher.init_collection():
                        searchers[model_name] = searcher
                        for sticker in stickers:
                            searcher.store_vector(
                                sticker.get('title'),
                                sticker.get('description'),
                                sticker.get('tags'),
                                sticker.get('file_name'),
                            )
                        logger.info(f'Model initialized: {model_name}')
                except Exception as e:
                    logger.error(f'Model init failed: {model_name} | Error: {str(e)}')

            for model_name in models:
                print(f'>>>> 初始化模型 {model_name}')
                start_time = time.time()
                init_model(model_name)
                print(f'>>>> 初始化模型 {model_name} 完成 ✅,耗时 {time.time() - start_time:.4f} 秒')
            print(f'>>>> 初始化所有模型完成 ✅')
            return '初始化成功!'
        except Exception as e:
            logger.error(f'Data init failed: {str(e)}')
            return f'初始化失败: {str(e)}'

    with gr.Blocks(title='Neko Sticker Search 🔍', css='.gradio-container img { width: 200px !important; height: 200px !important; object-fit: contain; }') as demo:
        with gr.Row():
            search_input = gr.Textbox(label='搜索关键词')
            search_button = gr.Button('搜索')
        headers = ['序号'] + [f'🧊{model.split("/")[-1]}' for model in models]
        results_table = gr.Dataframe(
            headers=headers,
            datatype=['number'] + ['markdown'] * len(models),
            row_count=5,
            col_count=len(models) + 1,
        )
        status_box = gr.Textbox(label='状态', interactive=False)
        refresh_button = gr.Button('刷新数据')
        refresh_button.click(fn=init_collections, outputs=status_box)
        # search_all_models is an async handler; Gradio awaits it itself.
        search_button.click(
            fn=search_all_models,
            inputs=[search_input],
            outputs=results_table,
        )
    return demo
if __name__ == '__main__':
    # Preload every embedding model up front so the first search does not pay
    # the model-loading cost; a model that fails to load is logged and skipped.
    preload_start = time.time()
    for model_name in models:
        try:
            # Separate per-model timer so the overall total below is not reset.
            model_start = time.time()
            searchers[model_name] = StickerSearcher(SentenceTransformerModel(model_name))
            print(f'>>>> 预加载模型 {model_name} 完成 ✅, 耗时 {time.time() - model_start:.4f} 秒')
        except Exception as e:
            logger.error(f'Model preload failed: {model_name} | Error: {str(e)}')
    logger.info(f'>>>> 预加载模型完成 ✅: {models}, 耗时 {time.time() - preload_start:.4f} 秒')
    demo = create_gradio_ui()
    demo.launch()