Spaces:
Build error
Build error
| from base_class import Embedding_Model | |
| import pickle | |
| from sentence_transformers import SentenceTransformer | |
| from openai.embeddings_utils import ( | |
| get_embedding, | |
| ) | |
class HuggingfaceSentenceTransformerModel(Embedding_Model):
    """Embedding model backed by a sentence-transformers checkpoint.

    Calling an instance encodes the given text with the wrapped
    ``SentenceTransformer`` and returns whatever ``encode`` returns.
    """

    EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"
    # Directory where sentence-transformers downloads/caches model files.
    # Kept as the default so existing callers see unchanged behavior.
    DEFAULT_CACHE_FOLDER = "/app/ckpt/"

    def __init__(self, model_name=EMBEDDING_MODEL,
                 cache_folder=DEFAULT_CACHE_FOLDER) -> None:
        """Load ``model_name`` with sentence-transformers.

        Args:
            model_name: sentence-transformers model identifier.
            cache_folder: directory used by ``SentenceTransformer`` to cache
                model files; defaults to ``"/app/ckpt/"``.
        """
        super().__init__(model_name)
        self.model = SentenceTransformer(model_name, cache_folder=cache_folder)

    def __call__(self, text):
        """Return ``self.model.encode(text)`` — the embedding(s) for ``text``.

        The original annotation ``-> None`` was wrong: this returns the
        encoder's output.
        """
        return self.model.encode(text)
class OpenAIEmbeddingModel(Embedding_Model):
    """Embedding model that calls the OpenAI embeddings API, with memoization.

    Embeddings are memoized in ``self.embedding_cache``, a dict mapping
    ``(text, model_name)`` tuples to embeddings. When
    ``self.embedding_cache_path`` is set, the whole cache is pickled to that
    file after every newly computed embedding.

    NOTE(review): ``get_embedding`` comes from ``openai.embeddings_utils``,
    which was removed in openai-python 1.0; this class therefore needs
    ``openai<1.0`` (or a port to the v1 client).
    """

    # constants
    EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, model_name=EMBEDDING_MODEL,
                 embedding_cache_path=None) -> None:
        """Set up the model name and the embedding cache.

        Args:
            model_name: OpenAI embedding model/engine name passed to
                ``get_embedding``.
            embedding_cache_path: optional pickle file used to load and persist
                the cache. If omitted, any cache attributes already provided by
                the base class are kept; otherwise an in-memory-only cache is
                used.
        """
        super().__init__(model_name)
        self.model_name = model_name
        # Previously nothing in this class established the cache, so
        # embedding_from_string failed unless the base class happened to set
        # these attributes. Respect whatever the base class provides.
        if embedding_cache_path is not None:
            self.embedding_cache_path = embedding_cache_path
            self.embedding_cache = self._load_cache(embedding_cache_path)
        else:
            if not hasattr(self, "embedding_cache_path"):
                self.embedding_cache_path = None
            if not hasattr(self, "embedding_cache"):
                self.embedding_cache = self._load_cache(self.embedding_cache_path)

    @staticmethod
    def _load_cache(path):
        """Return the pickled cache dict at ``path``, or ``{}`` if there is none."""
        if path is None:
            return {}
        try:
            with open(path, "rb") as cache_file:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # this at a trusted cache file written by this class.
                return pickle.load(cache_file)
        except FileNotFoundError:
            return {}

    def _save_cache(self) -> None:
        """Pickle the whole cache to ``self.embedding_cache_path``."""
        with open(self.embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(self.embedding_cache, embedding_cache_file)

    # Retrieve embeddings from the cache if present; otherwise request them
    # via the API and record the result.
    def embedding_from_string(self,
                              string: str,
                              ) -> list:
        """Return the embedding of ``string``, using the cache to avoid recomputing.

        On a cache miss, calls ``get_embedding(string, self.model_name)``,
        stores the result, and (if a cache path is configured) rewrites the
        entire cache file — cost grows with cache size on every miss.
        """
        key = (string, self.model_name)
        if key not in self.embedding_cache:
            self.embedding_cache[key] = get_embedding(string, self.model_name)
            if self.embedding_cache_path is not None:
                self._save_cache()
        return self.embedding_cache[key]

    def __call__(self, text) -> list:
        """Return the (cached) embedding of ``text``.

        The original annotation ``-> None`` was wrong: this returns the
        embedding produced by ``embedding_from_string``.
        """
        return self.embedding_from_string(text)