Kevin Hu
commited on
Commit
·
541b2f3
1
Parent(s):
fea9976
Make fast embed and default embed mutually exclusive. (#4121)
Browse files### What problem does this PR solve?
### Type of change
- [x] Performance Improvement
- rag/llm/embedding_model.py +7 -10
rag/llm/embedding_model.py
CHANGED
|
@@ -251,11 +251,8 @@ class OllamaEmbed(Base):
|
|
| 251 |
return np.array(res["embedding"]), 128
|
| 252 |
|
| 253 |
|
| 254 |
-
class FastEmbed(
|
| 255 |
-
|
| 256 |
-
_model_name = ""
|
| 257 |
-
_model_lock = threading.Lock()
|
| 258 |
-
|
| 259 |
def __init__(
|
| 260 |
self,
|
| 261 |
key: str | None = None,
|
|
@@ -267,17 +264,17 @@ class FastEmbed(Base):
|
|
| 267 |
if not settings.LIGHTEN and not FastEmbed._model:
|
| 268 |
with FastEmbed._model_lock:
|
| 269 |
from fastembed import TextEmbedding
|
| 270 |
-
if not
|
| 271 |
try:
|
| 272 |
-
|
| 273 |
-
|
| 274 |
except Exception:
|
| 275 |
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
|
| 276 |
local_dir=os.path.join(get_home_cache_dir(),
|
| 277 |
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
| 278 |
local_dir_use_symlinks=False)
|
| 279 |
-
|
| 280 |
-
self._model =
|
| 281 |
self._model_name = model_name
|
| 282 |
|
| 283 |
def encode(self, texts: list):
|
|
|
|
| 251 |
return np.array(res["embedding"]), 128
|
| 252 |
|
| 253 |
|
| 254 |
+
class FastEmbed(DefaultEmbedding):
|
| 255 |
+
|
|
|
|
|
|
|
|
|
|
| 256 |
def __init__(
|
| 257 |
self,
|
| 258 |
key: str | None = None,
|
|
|
|
| 264 |
if not settings.LIGHTEN and not FastEmbed._model:
|
| 265 |
with FastEmbed._model_lock:
|
| 266 |
from fastembed import TextEmbedding
|
| 267 |
+
if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
|
| 268 |
try:
|
| 269 |
+
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
| 270 |
+
DefaultEmbedding._model_name = model_name
|
| 271 |
except Exception:
|
| 272 |
cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
|
| 273 |
local_dir=os.path.join(get_home_cache_dir(),
|
| 274 |
re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
|
| 275 |
local_dir_use_symlinks=False)
|
| 276 |
+
DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
| 277 |
+
self._model = DefaultEmbedding._model
|
| 278 |
self._model_name = model_name
|
| 279 |
|
| 280 |
def encode(self, texts: list):
|