Refactor embedding batch_size (#3825)
### What problem does this PR solve?

Refactor embedding batch_size. Close #3657

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

Files changed:

- api/db/services/llm_service.py (+4 -4)
- rag/benchmark.py (+5 -8)
- rag/llm/embedding_model.py (+151 -97)
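The user-visible change: `encode` no longer takes a `batch_size` argument; each embedding model now batches internally with a provider-appropriate size. A minimal before/after sketch of a call site (the `embd_mdl` and `texts` names are illustrative):

```python
# Before this PR, callers passed a batch size through encode:
#     vects, used_tokens = embd_mdl.encode(texts, batch_size)
# After this PR, batching is an implementation detail of each model:
vects, used_tokens = embd_mdl.encode(texts)
```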
api/db/services/llm_service.py

```diff
@@ -232,13 +232,13 @@ class LLMBundle(object):
                 self.max_length = lm.max_tokens
                 break
 
-    def encode(self, texts: list, batch_size=32):
-        embeddings, used_tokens = self.mdl.encode(texts, batch_size)
+    def encode(self, texts: list):
+        embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens):
             logging.error(
                 "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
-            return
+        return embeddings, used_tokens
 
     def encode_queries(self, query: str):
         emd, used_tokens = self.mdl.encode_queries(query)
@@ -280,7 +280,7 @@ class LLMBundle(object):
                 logging.error(
                     "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
-                yield chunk
+            yield chunk
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
```
rag/benchmark.py

```diff
@@ -63,16 +63,13 @@ class Benchmark:
                 run[query][c["chunk_id"]] = c["similarity"]
         return run
 
-    def embedding(self, docs, batch_size=16):
-        cnts = [d["content_with_weight"] for d in docs]
-        vects = []
-        for i in range(0, len(cnts), batch_size):
-            vts, c = self.embd_mdl.encode(cnts[i: i + batch_size])
-            vects.extend(vts.tolist())
-        assert len(docs) == len(vects)
+    def embedding(self, docs):
+        texts = [d["content_with_weight"] for d in docs]
+        embeddings, _ = self.embd_mdl.encode(texts)
+        assert len(docs) == len(embeddings)
         vector_size = 0
         for i, d in enumerate(docs):
-            v = vects[i]
+            v = embeddings[i]
             vector_size = len(v)
             d["q_%d_vec" % len(v)] = v
         return docs, vector_size
```
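For context, `d["q_%d_vec" % len(v)]` keys each document's vector field by embedding dimension, so models of different sizes write to different fields. A tiny illustration (values are made up):

```python
# Illustrative only: a 1024-dim vector is stored under "q_1024_vec".
v = [0.1] * 1024
d = {"content_with_weight": "some chunk text"}
d["q_%d_vec" % len(v)] = v
assert "q_1024_vec" in d
```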
rag/llm/embedding_model.py

```diff
@@ -38,7 +38,7 @@ class Base(ABC):
     def __init__(self, key, model_name):
         pass
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         raise NotImplementedError("Please implement encode method!")
 
     def encode_queries(self, text: str):
@@ -78,15 +78,16 @@ class DefaultEmbedding(Base):
                                                 use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
+        batch_size = 16
         texts = [truncate(t, 2048) for t in texts]
         token_count = 0
         for t in texts:
             token_count += num_tokens_from_string(t)
-        res = []
+        ress = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self._model.encode(texts[i:i + batch_size]).tolist())
-        return np.array(res), token_count
+            ress.extend(self._model.encode(texts[i:i + batch_size]).tolist())
+        return np.array(ress), token_count
 
     def encode_queries(self, text: str):
         token_count = num_tokens_from_string(text)
@@ -101,12 +102,18 @@ class OpenAIEmbed(Base):
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
+        # OpenAI requires batch size <=16
+        batch_size = 16
         texts = [truncate(t, 8191) for t in texts]
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size],
+                                                model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[truncate(text, 8191)],
@@ -123,12 +130,14 @@ class LocalAIEmbed(Base):
         self.client = OpenAI(api_key="empty", base_url=base_url)
         self.model_name = model_name.split("___")[0]
 
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        embds = [d.embedding for d in res.data]
-        # local embedding for LmStudio does not count tokens
-        return np.array(embds), 1024
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+        # local embedding for LmStudio does not count tokens
+        return np.array(ress), 1024
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -155,12 +164,12 @@ class BaiChuanEmbed(OpenAIEmbed):
 
 class QWenEmbed(Base):
     def __init__(self, key, model_name="text_embedding_v2", **kwargs):
-        dashscope.api_key = key
+        self.key = key
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
         import dashscope
-        batch_size = min(batch_size, 4)
+        batch_size = 4
         try:
             res = []
             token_count = 0
@@ -169,6 +178,7 @@ class QWenEmbed(Base):
                 resp = dashscope.TextEmbedding.call(
                     model=self.model_name,
                     input=texts[i:i + batch_size],
+                    api_key=self.key,
                     text_type="document"
                 )
                 embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
@@ -186,6 +196,7 @@ class QWenEmbed(Base):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
                 input=text[:2048],
+                api_key=self.key,
                 text_type="query"
             )
             return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["total_tokens"]
@@ -200,7 +211,7 @@ class ZhipuEmbed(Base):
         self.client = ZhipuAI(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -221,7 +232,7 @@ class OllamaEmbed(Base):
         self.client = Client(host=kwargs["base_url"])
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         arr = []
         tks_num = 0
         for txt in texts:
@@ -252,13 +263,13 @@ class FastEmbed(Base):
         from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)
 
-        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]
+        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size=16)]
 
         return np.array(embeddings), total_tokens
 
@@ -278,11 +289,15 @@ class XinferenceEmbed(Base):
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embeddings.create(input=texts,
-                                            model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        total_tokens = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            total_tokens += res.usage.total_tokens
+        return np.array(ress), total_tokens
 
     def encode_queries(self, text):
         res = self.client.embeddings.create(input=[text],
@@ -306,7 +321,8 @@ class YoudaoEmbed(Base):
                 model_name_or_path=model_name.replace(
                     "maidalun1020", "InfiniFlow"))
 
-    def encode(self, texts: list, batch_size=10):
+    def encode(self, texts: list):
+        batch_size = 10
         res = []
         token_count = 0
         for t in texts:
@@ -332,15 +348,21 @@ class JinaEmbed(Base):
         }
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        data = {
-            "model": self.model_name,
-            "input": texts,
-            'encoding_type': 'float'
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=data).json()
-        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            data = {
+                "model": self.model_name,
+                "input": texts[i:i + batch_size],
+                'encoding_type': 'float'
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=data).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -394,12 +416,17 @@ class MistralEmbed(Base):
         self.client = MistralClient(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
-        res = self.client.embeddings(input=texts,
-                                     model=self.model_name)
-        return np.array([d.embedding for d in res.data]
-                        ), res.usage.total_tokens
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embeddings(input=texts[i:i + batch_size],
+                                         model=self.model_name)
+            ress.extend([d.embedding for d in res.data])
+            token_count += res.usage.total_tokens
+        return np.array(ress), token_count
 
    def encode_queries(self, text):
        res = self.client.embeddings(input=[truncate(text, 8196)],
@@ -418,7 +445,7 @@ class BedrockEmbed(Base):
         self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                    aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         texts = [truncate(t, 8196) for t in texts]
         embeddings = []
         token_count = 0
@@ -436,7 +463,6 @@ class BedrockEmbed(Base):
         return np.array(embeddings), token_count
 
     def encode_queries(self, text):
-
         embeddings = []
         token_count = num_tokens_from_string(text)
         if self.model_name.split('.')[0] == 'amazon':
@@ -453,20 +479,26 @@
 class GeminiEmbed(Base):
     def __init__(self, key, model_name='models/text-embedding-004',
                  **kwargs):
-        genai.configure(api_key=key)
+        self.key = key
         self.model_name = 'models/' + model_name
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         texts = [truncate(t, 2048) for t in texts]
         token_count = sum(num_tokens_from_string(text) for text in texts)
-        result = genai.embed_content(
-            model=self.model_name,
-            content=texts,
-            task_type="retrieval_document",
-            title="Embedding of list of strings")
-        return np.array(result['embedding']), token_count
+        genai.configure(api_key=self.key)
+        batch_size = 16
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            result = genai.embed_content(
+                model=self.model_name,
+                content=texts[i:i + batch_size],
+                task_type="retrieval_document",
+                title="Embedding of single string")
+            ress.extend(result['embedding'])
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
+        genai.configure(api_key=self.key)
         result = genai.embed_content(
             model=self.model_name,
             content=truncate(text, 2048),
@@ -495,19 +527,22 @@ class NvidiaEmbed(Base):
         if model_name == "snowflake/arctic-embed-l":
             self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
 
-    def encode(self, texts: list, batch_size=32):
-        payload = {
-            "input": texts,
-            "input_type": "query",
-            "model": self.model_name,
-            "encoding_format": "float",
-            "truncate": "END",
-        }
-        res = requests.post(self.base_url, headers=self.headers, json=payload).json()
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-            res["usage"]["total_tokens"],
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            payload = {
+                "input": texts[i : i + batch_size],
+                "input_type": "query",
+                "model": self.model_name,
+                "encoding_format": "float",
+                "truncate": "END",
+            }
+            res = requests.post(self.base_url, headers=self.headers, json=payload).json()
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         embds, cnt = self.encode([text])
@@ -541,16 +576,20 @@ class CoHereEmbed(Base):
         self.client = Client(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embed(
-            texts=texts,
-            model=self.model_name,
-            input_type="search_document",
-            embedding_types=["float"],
-        )
-        return np.array([d for d in res.embeddings.float]), int(
-            res.meta.billed_units.input_tokens
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embed(
+                texts=texts[i : i + batch_size],
+                model=self.model_name,
+                input_type="search_document",
+                embedding_types=["float"],
+            )
+            ress.extend([d for d in res.embeddings.float])
+            token_count += res.meta.billed_units.input_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(
@@ -599,19 +638,23 @@ class SILICONFLOWEmbed(Base):
         self.base_url = base_url
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
-        payload = {
-            "model": self.model_name,
-            "input": texts,
-            "encoding_format": "float",
-        }
-        res = requests.post(self.base_url, json=payload, headers=self.headers)
-        res = res.json()
-        return (
-            np.array([d["embedding"] for d in res["data"]]),
-            res["usage"]["total_tokens"],
-        )
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            texts_batch = texts[i : i + batch_size]
+            payload = {
+                "model": self.model_name,
+                "input": texts_batch,
+                "encoding_format": "float",
+            }
+            res = requests.post(self.base_url, json=payload, headers=self.headers).json()
+            if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch):
+                raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}")
+            ress.extend([d["embedding"] for d in res["data"]])
+            token_count += res["usage"]["total_tokens"]
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         payload = {
@@ -632,9 +675,14 @@ class ReplicateEmbed(Base):
         self.model_name = model_name
         self.client = Client(api_token=key)
 
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.run(self.model_name, input={"texts": texts})
-        return np.array(res), sum([num_tokens_from_string(text) for text in texts])
+    def encode(self, texts: list):
+        batch_size = 16
+        token_count = sum([num_tokens_from_string(text) for text in texts])
+        ress = []
+        for i in range(0, len(texts), batch_size):
+            res = self.client.run(self.model_name, input={"texts": texts[i : i + batch_size]})
+            ress.extend(res)
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(self.model_name, input={"texts": [text]})
@@ -673,11 +721,17 @@ class VoyageEmbed(Base):
         self.client = voyageai.Client(api_key=key)
         self.model_name = model_name
 
-    def encode(self, texts: list, batch_size=32):
-        res = self.client.embed(
-            texts=texts, model=self.model_name, input_type="document"
-        )
-        return np.array(res.embeddings), res.total_tokens
+    def encode(self, texts: list):
+        batch_size = 16
+        ress = []
+        token_count = 0
+        for i in range(0, len(texts), batch_size):
+            res = self.client.embed(
+                texts=texts[i : i + batch_size], model=self.model_name, input_type="document"
+            )
+            ress.extend(res.embeddings)
+            token_count += res.total_tokens
+        return np.array(ress), token_count
 
     def encode_queries(self, text):
         res = self.client.embed(
@@ -694,7 +748,7 @@ class HuggingFaceEmbed(Base):
         self.model_name = model_name
         self.base_url = base_url or "http://127.0.0.1:8080"
 
-    def encode(self, texts: list, batch_size=32):
+    def encode(self, texts: list):
         embeddings = []
         for text in texts:
             response = requests.post(
```