add support for Replicate (#1980)
### What problem does this PR solve?

#1853 add support for Replicate

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>
- api/apps/llm_app.py +2 -2
- conf/llm_factories.json +7 -0
- rag/llm/__init__.py +4 -2
- rag/llm/chat_model.py +53 -2
- rag/llm/embedding_model.py +20 -2
- requirements.txt +2 -1
- requirements_arm.txt +2 -1
- web/src/assets/svg/llm/replicate.svg +1 -0
- web/src/pages/user-setting/constants.tsx +1 -1
- web/src/pages/user-setting/setting-model/constant.ts +2 -1
api/apps/llm_app.py
CHANGED
```diff
@@ -149,7 +149,7 @@ def add_llm():
     msg = ""
     if llm["model_type"] == LLMType.EMBEDDING.value:
         mdl = EmbeddingModel[factory](
-            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
+            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
             model_name=llm["llm_name"],
             base_url=llm["api_base"])
         try:
@@ -160,7 +160,7 @@ def add_llm():
             msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
     elif llm["model_type"] == LLMType.CHAT.value:
         mdl = ChatModel[factory](
-            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
+            key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
             model_name=llm["llm_name"],
             base_url=llm["api_base"]
         )
```
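To see what the extended allow-list changes in practice, here is a minimal sketch of the dispatch path through `add_llm()` for a Replicate chat model. The registry lookup and the `key=` conditional mirror the diff above; the model id and API token are illustrative placeholders, not values from this PR.

```python
from rag.llm import ChatModel

# Hypothetical payload as add_llm() would receive it; the model id and
# token below are placeholders, not part of the PR.
llm = {
    "llm_factory": "Replicate",
    "llm_name": "meta/meta-llama-3-8b-instruct",
    "api_key": "r8_xxxxxxxx",
    "api_base": "",
}

factory = llm["llm_factory"]
mdl = ChatModel[factory](
    # "Replicate" is now in the allow-list, so the key is forwarded
    # instead of being replaced with None:
    key=llm["api_key"] if factory in ["VolcEngine", "Bedrock", "OpenAI-API-Compatible", "Replicate"] else None,
    model_name=llm["llm_name"],
    base_url=llm["api_base"],
)
```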
conf/llm_factories.json
CHANGED
```diff
@@ -3113,6 +3113,13 @@
                     "model_type": "image2text"
                 }
             ]
+        },
+        {
+            "name": "Replicate",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING",
+            "status": "1",
+            "llm": []
         }
     ]
 }
```
rag/llm/__init__.py
CHANGED
```diff
@@ -42,7 +42,8 @@ EmbeddingModel = {
     "TogetherAI": TogetherAIEmbed,
     "PerfXCloud": PerfXCloudEmbed,
     "Upstage": UpstageEmbed,
-    "SILICONFLOW": SILICONFLOWEmbed
+    "SILICONFLOW": SILICONFLOWEmbed,
+    "Replicate": ReplicateEmbed
 }
 
 
@@ -96,7 +97,8 @@ ChatModel = {
     "Upstage": UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
-    "01.AI": YiChat
+    "01.AI": YiChat,
+    "Replicate": ReplicateChat
 }
 
 
```
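A quick sanity check that the new registry entries resolve to the wrapper classes added below (a sketch; it assumes the package layout shown in this PR):

```python
from rag.llm import ChatModel, EmbeddingModel
from rag.llm.chat_model import ReplicateChat
from rag.llm.embedding_model import ReplicateEmbed

# Both registries should now map the "Replicate" factory name to the
# corresponding wrapper class.
assert ChatModel["Replicate"] is ReplicateChat
assert EmbeddingModel["Replicate"] is ReplicateEmbed
```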
rag/llm/chat_model.py
CHANGED
```diff
@@ -1003,7 +1003,7 @@ class TogetherAIChat(Base):
         base_url = "https://api.together.xyz/v1"
         super().__init__(key, model_name, base_url)
 
-
+
 class PerfXCloudChat(Base):
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
         if not base_url:
@@ -1036,4 +1036,55 @@ class YiChat(Base):
     def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
         if not base_url:
             base_url = "https://api.01.ai/v1"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class ReplicateChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from replicate.client import Client
+
+        self.model_name = model_name
+        self.client = Client(api_token=key)
+        self.system = ""
+
+    def chat(self, system, history, gen_conf):
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        if system:
+            self.system = system
+        prompt = "\n".join(
+            [item["role"] + ":" + item["content"] for item in history[-5:]]
+        )
+        ans = ""
+        try:
+            response = self.client.run(
+                self.model_name,
+                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
+            )
+            ans = "".join(response)
+            return ans, num_tokens_from_string(ans)
+        except Exception as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        if system:
+            self.system = system
+        prompt = "\n".join(
+            [item["role"] + ":" + item["content"] for item in history[-5:]]
+        )
+        ans = ""
+        try:
+            response = self.client.run(
+                self.model_name,
+                input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
+            )
+            for resp in response:
+                ans += resp
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield num_tokens_from_string(ans)
```
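A minimal usage sketch of the new wrapper, assuming a valid Replicate API token and a public chat model; both the token and the model id here are placeholders. Note that the wrapper flattens the last five history turns into a single `prompt` string and renames `max_tokens` to `max_new_tokens`, since Replicate's `run()` takes one `input` dict rather than an OpenAI-style message list.

```python
from rag.llm.chat_model import ReplicateChat

# Placeholder token and model id, for illustration only.
mdl = ReplicateChat(key="r8_xxxxxxxx", model_name="meta/meta-llama-3-8b-instruct")

history = [{"role": "user", "content": "What is retrieval-augmented generation?"}]

# Blocking call: returns the full answer and a token-count estimate.
ans, tokens = mdl.chat(
    system="You are a helpful assistant.",
    history=history,
    gen_conf={"max_tokens": 256, "temperature": 0.7},
)

# Streaming call: yields the progressively accumulated answer,
# then the final token count as the last item.
for chunk in mdl.chat_streamly(
    system="You are a helpful assistant.",
    history=history,
    gen_conf={"max_tokens": 256},
):
    print(chunk)
```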
rag/llm/embedding_model.py
CHANGED
```diff
@@ -561,7 +561,7 @@ class TogetherAIEmbed(OllamaEmbed):
         base_url = "https://api.together.xyz/v1"
         super().__init__(key, model_name, base_url)
 
-
+
 class PerfXCloudEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
         if not base_url:
@@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
         if not base_url:
             base_url = "https://api.siliconflow.cn/v1"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class ReplicateEmbed(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from replicate.client import Client
+
+        self.model_name = model_name
+        self.client = Client(api_token=key)
+
+    def encode(self, texts: list, batch_size=32):
+        from json import dumps
+
+        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
+        return np.array(res), sum([num_tokens_from_string(text) for text in texts])
+
+    def encode_queries(self, text):
+        res = self.client.embed(self.model_name, input={"texts": [text]})
+        return np.array(res), num_tokens_from_string(text)
```
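And a matching sketch for the embedding wrapper. It assumes a Replicate-hosted embedding model whose `input` accepts the JSON-encoded `texts` field that `encode()` above sends; the model id and token are placeholders, not values from this PR.

```python
from rag.llm.embedding_model import ReplicateEmbed

# Placeholder token and hypothetical embedding model id.
embd = ReplicateEmbed(key="r8_xxxxxxxx", model_name="owner/some-embedding-model")

# encode() returns an np.array of vectors plus a token-count estimate
# computed locally with num_tokens_from_string().
vectors, token_count = embd.encode(["hello world", "retrieval-augmented generation"])
print(vectors.shape, token_count)
```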
requirements.txt
CHANGED
```diff
@@ -65,6 +65,7 @@ python_pptx==0.6.23
 readability_lxml==0.8.1
 redis==5.0.3
 Requests==2.32.2
+replicate==0.31.0
 roman_numbers==1.0.2
 ruamel.base==1.0.0
 scholarly==1.7.11
@@ -87,4 +88,4 @@ wikipedia==1.4.0
 word2number==1.1
 xgboost==2.1.0
 xpinyin==0.7.6
-zhipuai==2.0.1
+zhipuai==2.0.1
```
requirements_arm.txt
CHANGED
```diff
@@ -102,6 +102,7 @@ python-pptx==0.6.23
 PyYAML==6.0.1
 redis==5.0.3
 regex==2023.12.25
+replicate==0.31.0
 requests==2.32.2
 ruamel.yaml==0.18.6
 ruamel.yaml.clib==0.2.8
@@ -161,4 +162,4 @@ markdown_to_json==2.1.1
 scholarly==1.7.11
 deepl==1.18.0
 psycopg2-binary==2.9.9
-tabulate-0.9.0
+tabulate-0.9.0
```
web/src/assets/svg/llm/replicate.svg
ADDED
web/src/pages/user-setting/constants.tsx
CHANGED
```diff
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
 
 export * from '@/constants/setting';
 
-export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI'];
+export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate'];
```
web/src/pages/user-setting/setting-model/constant.ts
CHANGED
```diff
@@ -30,7 +30,8 @@ export const IconMap = {
   Upstage: 'upstage',
   'novita.ai': 'novita-ai',
   SILICONFLOW: 'siliconflow',
-  "01.AI": 'yi'
+  "01.AI": 'yi',
+  "Replicate": 'replicate'
 };
 
 export const BedrockRegionList = [
```