KevinHuSh
committed on
Commit
·
c60dccb
1
Parent(s):
d42f535
fix #994 (#1006)
Browse files
### What problem does this PR solve?
#994
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/llm/embedding_model.py +29 -21
rag/llm/embedding_model.py
CHANGED
|
@@ -123,30 +123,38 @@ class QWenEmbed(Base):
|
|
| 123 |
|
| 124 |
def encode(self, texts: list, batch_size=10):
|
| 125 |
import dashscope
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
resp = dashscope.TextEmbedding.call(
|
| 131 |
model=self.model_name,
|
| 132 |
-
input=
|
| 133 |
-
text_type="
|
| 134 |
)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
return np.array(res), token_count
|
| 141 |
-
|
| 142 |
-
def encode_queries(self, text):
|
| 143 |
-
resp = dashscope.TextEmbedding.call(
|
| 144 |
-
model=self.model_name,
|
| 145 |
-
input=text[:2048],
|
| 146 |
-
text_type="query"
|
| 147 |
-
)
|
| 148 |
-
return np.array(resp["output"]["embeddings"][0]
|
| 149 |
-
["embedding"]), resp["usage"]["total_tokens"]
|
| 150 |
|
| 151 |
|
| 152 |
class ZhipuEmbed(Base):
|
|
|
|
| 123 |
|
| 124 |
def encode(self, texts: list, batch_size=10):
|
| 125 |
import dashscope
|
| 126 |
+
try:
|
| 127 |
+
res = []
|
| 128 |
+
token_count = 0
|
| 129 |
+
texts = [truncate(t, 2048) for t in texts]
|
| 130 |
+
for i in range(0, len(texts), batch_size):
|
| 131 |
+
resp = dashscope.TextEmbedding.call(
|
| 132 |
+
model=self.model_name,
|
| 133 |
+
input=texts[i:i + batch_size],
|
| 134 |
+
text_type="document"
|
| 135 |
+
)
|
| 136 |
+
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
|
| 137 |
+
for e in resp["output"]["embeddings"]:
|
| 138 |
+
embds[e["text_index"]] = e["embedding"]
|
| 139 |
+
res.extend(embds)
|
| 140 |
+
token_count += resp["usage"]["total_tokens"]
|
| 141 |
+
return np.array(res), token_count
|
| 142 |
+
except Exception as e:
|
| 143 |
+
raise Exception("Account abnormal. Please ensure it's on good standing.")
|
| 144 |
+
return np.array([]), 0
|
| 145 |
+
|
| 146 |
+
def encode_queries(self, text):
|
| 147 |
+
try:
|
| 148 |
resp = dashscope.TextEmbedding.call(
|
| 149 |
model=self.model_name,
|
| 150 |
+
input=text[:2048],
|
| 151 |
+
text_type="query"
|
| 152 |
)
|
| 153 |
+
return np.array(resp["output"]["embeddings"][0]
|
| 154 |
+
["embedding"]), resp["usage"]["total_tokens"]
|
| 155 |
+
except Exception as e:
|
| 156 |
+
raise Exception("Account abnormal. Please ensure it's on good standing.")
|
| 157 |
+
return np.array([]), 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
class ZhipuEmbed(Base):
|