Kevin Hu committed on
Commit
541b2f3
·
1 Parent(s): fea9976

Make fast embed and default embed mutually exclusive. (#4121)

Browse files

### What problem does this PR solve?

Previously `FastEmbed` maintained its own class-level model cache (`_model`, `_model_name`, `_model_lock`) separate from `DefaultEmbedding`, so both embedding models could end up loaded in memory at the same time. `FastEmbed` now inherits from `DefaultEmbedding` and stores its loaded model in `DefaultEmbedding._model`, making the fast and default embedding models share a single cache and thus be mutually exclusive.

### Type of change

- [x] Performance Improvement

Files changed (1) hide show
  1. rag/llm/embedding_model.py +7 -10
rag/llm/embedding_model.py CHANGED
@@ -251,11 +251,8 @@ class OllamaEmbed(Base):
251
  return np.array(res["embedding"]), 128
252
 
253
 
254
- class FastEmbed(Base):
255
- _model = None
256
- _model_name = ""
257
- _model_lock = threading.Lock()
258
-
259
  def __init__(
260
  self,
261
  key: str | None = None,
@@ -267,17 +264,17 @@ class FastEmbed(Base):
267
  if not settings.LIGHTEN and not FastEmbed._model:
268
  with FastEmbed._model_lock:
269
  from fastembed import TextEmbedding
270
- if not FastEmbed._model or model_name != FastEmbed._model_name:
271
  try:
272
- FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
273
- FastEmbed._model_name = model_name
274
  except Exception:
275
  cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
276
  local_dir=os.path.join(get_home_cache_dir(),
277
  re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
278
  local_dir_use_symlinks=False)
279
- FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
280
- self._model = FastEmbed._model
281
  self._model_name = model_name
282
 
283
  def encode(self, texts: list):
 
251
  return np.array(res["embedding"]), 128
252
 
253
 
254
+ class FastEmbed(DefaultEmbedding):
255
+
 
 
 
256
  def __init__(
257
  self,
258
  key: str | None = None,
 
264
  if not settings.LIGHTEN and not FastEmbed._model:
265
  with FastEmbed._model_lock:
266
  from fastembed import TextEmbedding
267
+ if not DefaultEmbedding._model or model_name != DefaultEmbedding._model_name:
268
  try:
269
+ DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
270
+ DefaultEmbedding._model_name = model_name
271
  except Exception:
272
  cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
273
  local_dir=os.path.join(get_home_cache_dir(),
274
  re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
275
  local_dir_use_symlinks=False)
276
+ DefaultEmbedding._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
277
+ self._model = DefaultEmbedding._model
278
  self._model_name = model_name
279
 
280
  def encode(self, texts: list):