add support for Replicate (#1980)
Browse files### What problem does this PR solve?
#1853 add support for Replicate
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: Zhedong Cen <[email protected]>
- api/apps/llm_app.py +2 -2
- conf/llm_factories.json +7 -0
- rag/llm/__init__.py +4 -2
- rag/llm/chat_model.py +53 -2
- rag/llm/embedding_model.py +20 -2
- requirements.txt +2 -1
- requirements_arm.txt +2 -1
- web/src/assets/svg/llm/replicate.svg +1 -0
- web/src/pages/user-setting/constants.tsx +1 -1
- web/src/pages/user-setting/setting-model/constant.ts +2 -1
api/apps/llm_app.py
CHANGED
@@ -149,7 +149,7 @@ def add_llm():
|
|
149 |
msg = ""
|
150 |
if llm["model_type"] == LLMType.EMBEDDING.value:
|
151 |
mdl = EmbeddingModel[factory](
|
152 |
-
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
|
153 |
model_name=llm["llm_name"],
|
154 |
base_url=llm["api_base"])
|
155 |
try:
|
@@ -160,7 +160,7 @@ def add_llm():
|
|
160 |
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
161 |
elif llm["model_type"] == LLMType.CHAT.value:
|
162 |
mdl = ChatModel[factory](
|
163 |
-
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
|
164 |
model_name=llm["llm_name"],
|
165 |
base_url=llm["api_base"]
|
166 |
)
|
|
|
149 |
msg = ""
|
150 |
if llm["model_type"] == LLMType.EMBEDDING.value:
|
151 |
mdl = EmbeddingModel[factory](
|
152 |
+
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
|
153 |
model_name=llm["llm_name"],
|
154 |
base_url=llm["api_base"])
|
155 |
try:
|
|
|
160 |
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
161 |
elif llm["model_type"] == LLMType.CHAT.value:
|
162 |
mdl = ChatModel[factory](
|
163 |
+
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
|
164 |
model_name=llm["llm_name"],
|
165 |
base_url=llm["api_base"]
|
166 |
)
|
conf/llm_factories.json
CHANGED
@@ -3113,6 +3113,13 @@
|
|
3113 |
"model_type": "image2text"
|
3114 |
}
|
3115 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3116 |
}
|
3117 |
]
|
3118 |
}
|
|
|
3113 |
"model_type": "image2text"
|
3114 |
}
|
3115 |
]
|
3116 |
+
},
|
3117 |
+
{
|
3118 |
+
"name": "Replicate",
|
3119 |
+
"logo": "",
|
3120 |
+
"tags": "LLM,TEXT EMBEDDING",
|
3121 |
+
"status": "1",
|
3122 |
+
"llm": []
|
3123 |
}
|
3124 |
]
|
3125 |
}
|
rag/llm/__init__.py
CHANGED
@@ -42,7 +42,8 @@ EmbeddingModel = {
|
|
42 |
"TogetherAI": TogetherAIEmbed,
|
43 |
"PerfXCloud": PerfXCloudEmbed,
|
44 |
"Upstage": UpstageEmbed,
|
45 |
-
"SILICONFLOW": SILICONFLOWEmbed
|
|
|
46 |
}
|
47 |
|
48 |
|
@@ -96,7 +97,8 @@ ChatModel = {
|
|
96 |
"Upstage":UpstageChat,
|
97 |
"novita.ai": NovitaAIChat,
|
98 |
"SILICONFLOW": SILICONFLOWChat,
|
99 |
-
"01.AI": YiChat
|
|
|
100 |
}
|
101 |
|
102 |
|
|
|
42 |
"TogetherAI": TogetherAIEmbed,
|
43 |
"PerfXCloud": PerfXCloudEmbed,
|
44 |
"Upstage": UpstageEmbed,
|
45 |
+
"SILICONFLOW": SILICONFLOWEmbed,
|
46 |
+
"Replicate": ReplicateEmbed
|
47 |
}
|
48 |
|
49 |
|
|
|
97 |
"Upstage":UpstageChat,
|
98 |
"novita.ai": NovitaAIChat,
|
99 |
"SILICONFLOW": SILICONFLOWChat,
|
100 |
+
"01.AI": YiChat,
|
101 |
+
"Replicate": ReplicateChat
|
102 |
}
|
103 |
|
104 |
|
rag/llm/chat_model.py
CHANGED
@@ -1003,7 +1003,7 @@ class TogetherAIChat(Base):
|
|
1003 |
base_url = "https://api.together.xyz/v1"
|
1004 |
super().__init__(key, model_name, base_url)
|
1005 |
|
1006 |
-
|
1007 |
class PerfXCloudChat(Base):
|
1008 |
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
1009 |
if not base_url:
|
@@ -1036,4 +1036,55 @@ class YiChat(Base):
|
|
1036 |
def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
|
1037 |
if not base_url:
|
1038 |
base_url = "https://api.01.ai/v1"
|
1039 |
-
super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1003 |
base_url = "https://api.together.xyz/v1"
|
1004 |
super().__init__(key, model_name, base_url)
|
1005 |
|
1006 |
+
|
1007 |
class PerfXCloudChat(Base):
|
1008 |
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
1009 |
if not base_url:
|
|
|
1036 |
def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
|
1037 |
if not base_url:
|
1038 |
base_url = "https://api.01.ai/v1"
|
1039 |
+
super().__init__(key, model_name, base_url)
|
1040 |
+
|
1041 |
+
|
1042 |
+
class ReplicateChat(Base):
    """Chat model backed by the Replicate hosted-model API.

    The conversation is flattened into a single prompt (last five turns,
    ``role:content`` per line) plus a ``system_prompt`` input, matching
    Replicate's generic language-model input schema.
    """

    def __init__(self, key, model_name, base_url=None):
        # Imported lazily so the replicate package is only required when
        # this provider is actually configured.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def _build_inputs(self, system, history, gen_conf):
        """Normalize generation options and flatten history into one prompt.

        Returns ``(prompt, conf)`` where ``conf`` is a *copy* of
        ``gen_conf`` with the OpenAI-style ``max_tokens`` renamed to
        Replicate's ``max_new_tokens``.  Working on a copy fixes the
        original's in-place mutation of the caller's dict.
        """
        conf = dict(gen_conf)
        if "max_tokens" in conf:
            conf["max_new_tokens"] = conf.pop("max_tokens")
        if system:
            self.system = system
        # Only the last five turns are sent, keeping the prompt bounded.
        prompt = "\n".join(
            item["role"] + ":" + item["content"] for item in history[-5:]
        )
        return prompt, conf

    def chat(self, system, history, gen_conf):
        """Run one completion; returns ``(answer, token_count)``.

        On failure the partial answer plus an ``**ERROR**`` marker is
        returned with a token count of 0, mirroring the other providers.
        """
        prompt, conf = self._build_inputs(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **conf},
            )
            # client.run yields text chunks for streaming models; join them.
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Stream a completion.

        Yields the *cumulative* answer after each chunk (codebase
        convention) and finally the total token count as an int.
        """
        prompt, conf = self._build_inputs(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **conf},
            )
            for resp in response:
                ans += resp
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)
|
rag/llm/embedding_model.py
CHANGED
@@ -561,7 +561,7 @@ class TogetherAIEmbed(OllamaEmbed):
|
|
561 |
base_url = "https://api.together.xyz/v1"
|
562 |
super().__init__(key, model_name, base_url)
|
563 |
|
564 |
-
|
565 |
class PerfXCloudEmbed(OpenAIEmbed):
|
566 |
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
567 |
if not base_url:
|
@@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
|
|
580 |
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
|
581 |
if not base_url:
|
582 |
base_url = "https://api.siliconflow.cn/v1"
|
583 |
-
super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
561 |
base_url = "https://api.together.xyz/v1"
|
562 |
super().__init__(key, model_name, base_url)
|
563 |
|
564 |
+
|
565 |
class PerfXCloudEmbed(OpenAIEmbed):
|
566 |
def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
|
567 |
if not base_url:
|
|
|
580 |
def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
|
581 |
if not base_url:
|
582 |
base_url = "https://api.siliconflow.cn/v1"
|
583 |
+
super().__init__(key, model_name, base_url)
|
584 |
+
|
585 |
+
|
586 |
+
class ReplicateEmbed(Base):
    """Embedding model backed by the Replicate hosted-model API.

    The target model is expected to accept a ``texts`` input holding a
    JSON-encoded list of strings and to return one vector per string.
    """

    def __init__(self, key, model_name, base_url=None):
        # Imported lazily so the replicate package is only required when
        # this provider is actually configured.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list, batch_size=32):
        """Embed *texts*; returns ``(ndarray of vectors, total token count)``.

        Requests are issued in chunks of *batch_size* so long input lists
        do not exceed the model's per-call limits (the original sent the
        whole list in one call and ignored the batch_size parameter).
        """
        from json import dumps

        vectors = []
        for i in range(0, len(texts), batch_size):
            res = self.client.run(
                self.model_name,
                input={"texts": dumps(texts[i: i + batch_size])},
            )
            vectors.extend(res)
        return np.array(vectors), sum(num_tokens_from_string(t) for t in texts)

    def encode_queries(self, text):
        """Embed a single query string; returns ``(ndarray, token count)``."""
        from json import dumps

        # BUG FIX: replicate's Client has no ``embed`` method, so the
        # original raised AttributeError here.  Use run() with the same
        # JSON-encoded ``texts`` payload that encode() sends.
        res = self.client.run(self.model_name, input={"texts": dumps([text])})
        return np.array(res), num_tokens_from_string(text)
|
requirements.txt
CHANGED
@@ -65,6 +65,7 @@ python_pptx==0.6.23
|
|
65 |
readability_lxml==0.8.1
|
66 |
redis==5.0.3
|
67 |
Requests==2.32.2
|
|
|
68 |
roman_numbers==1.0.2
|
69 |
ruamel.base==1.0.0
|
70 |
scholarly==1.7.11
|
@@ -87,4 +88,4 @@ wikipedia==1.4.0
|
|
87 |
word2number==1.1
|
88 |
xgboost==2.1.0
|
89 |
xpinyin==0.7.6
|
90 |
-
zhipuai==2.0.1
|
|
|
65 |
readability_lxml==0.8.1
|
66 |
redis==5.0.3
|
67 |
Requests==2.32.2
|
68 |
+
replicate==0.31.0
|
69 |
roman_numbers==1.0.2
|
70 |
ruamel.base==1.0.0
|
71 |
scholarly==1.7.11
|
|
|
88 |
word2number==1.1
|
89 |
xgboost==2.1.0
|
90 |
xpinyin==0.7.6
|
91 |
+
zhipuai==2.0.1
|
requirements_arm.txt
CHANGED
@@ -102,6 +102,7 @@ python-pptx==0.6.23
|
|
102 |
PyYAML==6.0.1
|
103 |
redis==5.0.3
|
104 |
regex==2023.12.25
|
|
|
105 |
requests==2.32.2
|
106 |
ruamel.yaml==0.18.6
|
107 |
ruamel.yaml.clib==0.2.8
|
@@ -161,4 +162,4 @@ markdown_to_json==2.1.1
|
|
161 |
scholarly==1.7.11
|
162 |
deepl==1.18.0
|
163 |
psycopg2-binary==2.9.9
|
164 |
-
tabulate==0.9.0
|
|
|
102 |
PyYAML==6.0.1
|
103 |
redis==5.0.3
|
104 |
regex==2023.12.25
|
105 |
+
replicate==0.31.0
|
106 |
requests==2.32.2
|
107 |
ruamel.yaml==0.18.6
|
108 |
ruamel.yaml.clib==0.2.8
|
|
|
162 |
scholarly==1.7.11
|
163 |
deepl==1.18.0
|
164 |
psycopg2-binary==2.9.9
|
165 |
+
tabulate==0.9.0
|
web/src/assets/svg/llm/replicate.svg
ADDED
|
web/src/pages/user-setting/constants.tsx
CHANGED
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
|
|
17 |
|
18 |
export * from '@/constants/setting';
|
19 |
|
20 |
-
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI'];
|
|
|
17 |
|
18 |
export * from '@/constants/setting';
|
19 |
|
20 |
+
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate'];
|
web/src/pages/user-setting/setting-model/constant.ts
CHANGED
@@ -30,7 +30,8 @@ export const IconMap = {
|
|
30 |
Upstage: 'upstage',
|
31 |
'novita.ai': 'novita-ai',
|
32 |
SILICONFLOW: 'siliconflow',
|
33 |
-
"01.AI": 'yi'
|
|
|
34 |
};
|
35 |
|
36 |
export const BedrockRegionList = [
|
|
|
30 |
Upstage: 'upstage',
|
31 |
'novita.ai': 'novita-ai',
|
32 |
SILICONFLOW: 'siliconflow',
|
33 |
+
"01.AI": 'yi',
|
34 |
+
"Replicate": 'replicate'
|
35 |
};
|
36 |
|
37 |
export const BedrockRegionList = [
|