Anush008 and KevinHuSh committed
Commit a86164e · 1 Parent(s): 5ec6524

feat: FastEmbed embedding support (#291)


### Description

Following up on https://github.com/infiniflow/ragflow/pull/275, this PR
adds support for FastEmbed model configurations.

The options are not exhaustive. You can find the full list
[here](https://qdrant.github.io/fastembed/examples/Supported_Models/).
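
fastembed can also enumerate its supported models at runtime, which is handy when extending this list later. A minimal sketch (assuming fastembed 0.2.x; the exact metadata keys may vary between versions):

```python
from fastembed import TextEmbedding

# Each entry is a dict describing one supported model,
# e.g. its name and embedding dimension.
for model_info in TextEmbedding.list_supported_models():
    print(model_info["model"], model_info["dim"])
```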

P.S. I ran into OOM issues when building the image.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <[email protected]>

api/db/init_data.py CHANGED
@@ -109,6 +109,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+}, {
+    "name": "FastEmbed",
+    "logo": "",
+    "tags": "TEXT EMBEDDING",
+    "status": "1",
 },
 {
     "name": "Xinference",
@@ -268,6 +273,56 @@ def init_llm_factory():
     "max_tokens": 128 * 1000,
     "model_type": LLMType.CHAT.value
 },
+# ------------------------ FastEmbed -----------------------
+{
+    "fid": factory_infos[5]["name"],
+    "llm_name": "BAAI/bge-small-en-v1.5",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "BAAI/bge-small-zh-v1.5",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "BAAI/bge-base-en-v1.5",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "BAAI/bge-large-en-v1.5",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "sentence-transformers/all-MiniLM-L6-v2",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "nomic-ai/nomic-embed-text-v1.5",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 8192,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "jinaai/jina-embeddings-v2-small-en",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 2147483648,
+    "model_type": LLMType.EMBEDDING.value
+}, {
+    "fid": factory_infos[5]["name"],
+    "llm_name": "jinaai/jina-embeddings-v2-base-en",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 2147483648,
+    "model_type": LLMType.EMBEDDING.value
+},
 ]
 for info in factory_infos:
     try:
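
Each new llm entry references its provider through "fid": factory_infos[5]["name"]; index 5 is intended to resolve to the "FastEmbed" factory inserted above, so the models attach to the right provider at init time. A minimal illustration of that linkage (the factory list below is abridged and partly hypothetical):

```python
# Abridged, partly hypothetical factory list; only the index-5 linkage matters.
factory_infos = [
    {"name": "OpenAI"}, {"name": "Tongyi-Qianwen"}, {"name": "ZHIPU-AI"},
    {"name": "Ollama"}, {"name": "Moonshot"}, {"name": "FastEmbed"},
]

llm_entry = {
    "fid": factory_infos[5]["name"],  # resolves to "FastEmbed"
    "llm_name": "BAAI/bge-small-en-v1.5",
}
assert llm_entry["fid"] == "FastEmbed"
```

Indexing by position is fragile if the factory list is ever reordered; looking the entry up by name would survive such changes.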
rag/llm/__init__.py CHANGED
@@ -24,7 +24,8 @@ EmbeddingModel = {
     "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
-    "Moonshot": HuEmbedding
+    "Moonshot": HuEmbedding,
+    "FastEmbed": FastEmbed
 }
 
 
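With the mapping updated, resolving a configured factory to an implementation is a plain dict lookup. A sketch of the assumed call site (the keyword arguments shown are illustrative):

```python
from rag.llm import EmbeddingModel

# Pick the implementation class by factory name and instantiate it.
# FastEmbed needs no API key; model_name selects the checkpoint.
embd_mdl = EmbeddingModel["FastEmbed"](key=None, model_name="BAAI/bge-small-en-v1.5")
vectors, used_tokens = embd_mdl.encode(["hello world", "ragflow"])
```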
rag/llm/embedding_model.py CHANGED
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Optional
 from zhipuai import ZhipuAI
 import os
 from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
+from fastembed import TextEmbedding
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
@@ -172,6 +174,34 @@ class OllamaEmbed(Base):
         return np.array(res["embedding"]), 128
 
 
+class FastEmbed(Base):
+    def __init__(
+        self,
+        key: Optional[str] = None,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: Optional[str] = None,
+        threads: Optional[int] = None,
+        **kwargs,
+    ):
+        self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+
+    def encode(self, texts: list, batch_size=32):
+        # Use the internal tokenizer to encode the texts and get the total number of tokens.
+        encodings = self._model.model.tokenizer.encode_batch(texts)
+        total_tokens = sum(len(e) for e in encodings)
+
+        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
+
+        return np.array(embeddings), total_tokens
+
+    def encode_queries(self, text: str):
+        # Use the internal tokenizer to encode the text and get the number of tokens.
+        encoding = self._model.model.tokenizer.encode(text)
+        embedding = next(self._model.query_embed(text)).tolist()
+
+        return np.array(embedding), len(encoding.ids)
+
+
 class XinferenceEmbed(Base):
     def __init__(self, key, model_name="", base_url=""):
         self.client = OpenAI(api_key="xxx", base_url=base_url)
@@ -187,3 +217,4 @@ class XinferenceEmbed(Base):
         res = self.client.embeddings.create(input=[text],
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
+
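
A usage sketch for the new class (weights are downloaded on first use; the 384-dimension output assumes BAAI/bge-small-en-v1.5):

```python
model = FastEmbed(model_name="BAAI/bge-small-en-v1.5")

# Batch path: embed documents; token usage comes from the model's own tokenizer.
doc_vecs, doc_tokens = model.encode(["What is RAG?", "FastEmbed runs on ONNX."])
print(doc_vecs.shape)  # (2, 384)

# Query path: query_embed may apply a model-specific query-side prefix.
q_vec, q_tokens = model.encode_queries("retrieval augmented generation")
print(q_vec.shape)  # (384,)
```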
requirements.txt CHANGED
@@ -27,6 +27,7 @@ elasticsearch==8.12.1
 elasticsearch-dsl==8.12.0
 et-xmlfile==1.1.0
 filelock==3.13.1
+fastembed==0.2.6
 FlagEmbedding==1.2.5
 Flask==3.0.2
 Flask-Cors==4.0.0
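
The new pin installs locally with pip install fastembed==0.2.6; fastembed runs its models through onnxruntime rather than torch, so the dependency itself is comparatively light.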