黄腾 aopstudio committed on
Commit
38ccbb8
·
1 Parent(s): d3e887e

add support for Replicate (#1980)

Browse files

### What problem does this PR solve?

#1853 add support for Replicate

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

api/apps/llm_app.py CHANGED
@@ -149,7 +149,7 @@ def add_llm():
149
  msg = ""
150
  if llm["model_type"] == LLMType.EMBEDDING.value:
151
  mdl = EmbeddingModel[factory](
152
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
153
  model_name=llm["llm_name"],
154
  base_url=llm["api_base"])
155
  try:
@@ -160,7 +160,7 @@ def add_llm():
160
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
161
  elif llm["model_type"] == LLMType.CHAT.value:
162
  mdl = ChatModel[factory](
163
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
164
  model_name=llm["llm_name"],
165
  base_url=llm["api_base"]
166
  )
 
149
  msg = ""
150
  if llm["model_type"] == LLMType.EMBEDDING.value:
151
  mdl = EmbeddingModel[factory](
152
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
153
  model_name=llm["llm_name"],
154
  base_url=llm["api_base"])
155
  try:
 
160
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
161
  elif llm["model_type"] == LLMType.CHAT.value:
162
  mdl = ChatModel[factory](
163
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
164
  model_name=llm["llm_name"],
165
  base_url=llm["api_base"]
166
  )
conf/llm_factories.json CHANGED
@@ -3113,6 +3113,13 @@
3113
  "model_type": "image2text"
3114
  }
3115
  ]
 
 
 
 
 
 
 
3116
  }
3117
  ]
3118
  }
 
3113
  "model_type": "image2text"
3114
  }
3115
  ]
3116
+ },
3117
+ {
3118
+ "name": "Replicate",
3119
+ "logo": "",
3120
+ "tags": "LLM,TEXT EMBEDDING",
3121
+ "status": "1",
3122
+ "llm": []
3123
  }
3124
  ]
3125
  }
rag/llm/__init__.py CHANGED
@@ -42,7 +42,8 @@ EmbeddingModel = {
42
  "TogetherAI": TogetherAIEmbed,
43
  "PerfXCloud": PerfXCloudEmbed,
44
  "Upstage": UpstageEmbed,
45
- "SILICONFLOW": SILICONFLOWEmbed
 
46
  }
47
 
48
 
@@ -96,7 +97,8 @@ ChatModel = {
96
  "Upstage":UpstageChat,
97
  "novita.ai": NovitaAIChat,
98
  "SILICONFLOW": SILICONFLOWChat,
99
- "01.AI": YiChat
 
100
  }
101
 
102
 
 
42
  "TogetherAI": TogetherAIEmbed,
43
  "PerfXCloud": PerfXCloudEmbed,
44
  "Upstage": UpstageEmbed,
45
+ "SILICONFLOW": SILICONFLOWEmbed,
46
+ "Replicate": ReplicateEmbed
47
  }
48
 
49
 
 
97
  "Upstage":UpstageChat,
98
  "novita.ai": NovitaAIChat,
99
  "SILICONFLOW": SILICONFLOWChat,
100
+ "01.AI": YiChat,
101
+ "Replicate": ReplicateChat
102
  }
103
 
104
 
rag/llm/chat_model.py CHANGED
@@ -1003,7 +1003,7 @@ class TogetherAIChat(Base):
1003
  base_url = "https://api.together.xyz/v1"
1004
  super().__init__(key, model_name, base_url)
1005
 
1006
-
1007
  class PerfXCloudChat(Base):
1008
  def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
1009
  if not base_url:
@@ -1036,4 +1036,55 @@ class YiChat(Base):
1036
  def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
1037
  if not base_url:
1038
  base_url = "https://api.01.ai/v1"
1039
- super().__init__(key, model_name, base_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  base_url = "https://api.together.xyz/v1"
1004
  super().__init__(key, model_name, base_url)
1005
 
1006
+
1007
  class PerfXCloudChat(Base):
1008
  def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
1009
  if not base_url:
 
1036
  def __init__(self, key, model_name, base_url="https://api.01.ai/v1"):
1037
  if not base_url:
1038
  base_url = "https://api.01.ai/v1"
1039
+ super().__init__(key, model_name, base_url)
1040
+
1041
+
1042
class ReplicateChat(Base):
    """Chat model backed by the Replicate hosted-inference API.

    Replicate models take a single prompt string rather than an
    OpenAI-style message list, so the recent history is flattened into
    "role:content" lines before each call.
    """

    def __init__(self, key, model_name, base_url=None):
        # Lazy import: the replicate package is only required when this
        # provider is actually used.  `base_url` is accepted for
        # signature parity with the other chat providers but ignored —
        # the Replicate client always talks to its own endpoint.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)
        self.system = ""

    def _prepare(self, system, history, gen_conf):
        """Return (conf, prompt) for one request.

        Works on a copy of `gen_conf` so the caller's dict is never
        mutated (the original popped `max_tokens` in place).
        """
        conf = dict(gen_conf)
        # Replicate uses `max_new_tokens` where OpenAI-style APIs use `max_tokens`.
        if "max_tokens" in conf:
            conf["max_new_tokens"] = conf.pop("max_tokens")
        if system:
            self.system = system
        # Flatten the last few turns into a plain-text prompt.
        prompt = "\n".join(
            item["role"] + ":" + item["content"] for item in history[-5:]
        )
        return conf, prompt

    def chat(self, system, history, gen_conf):
        """Run one completion; return (answer, token_count).

        On failure, returns the partial answer with an **ERROR** marker
        appended and a token count of 0, matching the other chat models.
        """
        conf, prompt = self._prepare(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **conf},
            )
            # `run` yields output chunks; join them into the full answer.
            ans = "".join(response)
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Yield the cumulative answer as chunks arrive; finally yield the token count."""
        conf, prompt = self._prepare(system, history, gen_conf)
        ans = ""
        try:
            response = self.client.run(
                self.model_name,
                input={"system_prompt": self.system, "prompt": prompt, **conf},
            )
            for resp in response:
                ans += resp
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield num_tokens_from_string(ans)
rag/llm/embedding_model.py CHANGED
@@ -561,7 +561,7 @@ class TogetherAIEmbed(OllamaEmbed):
561
  base_url = "https://api.together.xyz/v1"
562
  super().__init__(key, model_name, base_url)
563
 
564
-
565
  class PerfXCloudEmbed(OpenAIEmbed):
566
  def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
567
  if not base_url:
@@ -580,4 +580,22 @@ class SILICONFLOWEmbed(OpenAIEmbed):
580
  def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
581
  if not base_url:
582
  base_url = "https://api.siliconflow.cn/v1"
583
- super().__init__(key, model_name, base_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  base_url = "https://api.together.xyz/v1"
562
  super().__init__(key, model_name, base_url)
563
 
564
+
565
  class PerfXCloudEmbed(OpenAIEmbed):
566
  def __init__(self, key, model_name, base_url="https://cloud.perfxlab.cn/v1"):
567
  if not base_url:
 
580
  def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1"):
581
  if not base_url:
582
  base_url = "https://api.siliconflow.cn/v1"
583
+ super().__init__(key, model_name, base_url)
584
+
585
+
586
class ReplicateEmbed(Base):
    """Embedding model backed by the Replicate hosted-inference API."""

    def __init__(self, key, model_name, base_url=None):
        # Lazy import: replicate is only needed when this provider is used.
        # `base_url` is accepted for signature parity and ignored.
        from replicate.client import Client

        self.model_name = model_name
        self.client = Client(api_token=key)

    def encode(self, texts: list, batch_size=32):
        """Embed a batch of texts.

        Returns (np.ndarray of embedding vectors, total token count).
        NOTE(review): `batch_size` is currently unused — the whole list
        is sent in one request; confirm the target model accepts that.
        """
        from json import dumps

        # The model expects the text list as a JSON-encoded string.
        res = self.client.run(self.model_name, input={"texts": dumps(texts)})
        return np.array(res), sum(num_tokens_from_string(text) for text in texts)

    def encode_queries(self, text):
        """Embed a single query string; return (np.ndarray, token count).

        Bug fix: the Replicate `Client` has no `embed` method — route the
        call through `client.run` with the same JSON-dumped payload shape
        that `encode` uses, so a single query no longer raises
        AttributeError.
        """
        from json import dumps

        res = self.client.run(self.model_name, input={"texts": dumps([text])})
        return np.array(res), num_tokens_from_string(text)
requirements.txt CHANGED
@@ -65,6 +65,7 @@ python_pptx==0.6.23
65
  readability_lxml==0.8.1
66
  redis==5.0.3
67
  Requests==2.32.2
 
68
  roman_numbers==1.0.2
69
  ruamel.base==1.0.0
70
  scholarly==1.7.11
@@ -87,4 +88,4 @@ wikipedia==1.4.0
87
  word2number==1.1
88
  xgboost==2.1.0
89
  xpinyin==0.7.6
90
- zhipuai==2.0.1
 
65
  readability_lxml==0.8.1
66
  redis==5.0.3
67
  Requests==2.32.2
68
+ replicate==0.31.0
69
  roman_numbers==1.0.2
70
  ruamel.base==1.0.0
71
  scholarly==1.7.11
 
88
  word2number==1.1
89
  xgboost==2.1.0
90
  xpinyin==0.7.6
91
+ zhipuai==2.0.1
requirements_arm.txt CHANGED
@@ -102,6 +102,7 @@ python-pptx==0.6.23
102
  PyYAML==6.0.1
103
  redis==5.0.3
104
  regex==2023.12.25
 
105
  requests==2.32.2
106
  ruamel.yaml==0.18.6
107
  ruamel.yaml.clib==0.2.8
@@ -161,4 +162,4 @@ markdown_to_json==2.1.1
161
  scholarly==1.7.11
162
  deepl==1.18.0
163
  psycopg2-binary==2.9.9
164
- tabulate-0.9.0
 
102
  PyYAML==6.0.1
103
  redis==5.0.3
104
  regex==2023.12.25
105
+ replicate==0.31.0
106
  requests==2.32.2
107
  ruamel.yaml==0.18.6
108
  ruamel.yaml.clib==0.2.8
 
162
  scholarly==1.7.11
163
  deepl==1.18.0
164
  psycopg2-binary==2.9.9
165
+ tabulate==0.9.0
web/src/assets/svg/llm/replicate.svg ADDED
web/src/pages/user-setting/constants.tsx CHANGED
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
17
 
18
  export * from '@/constants/setting';
19
 
20
- export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI'];
 
17
 
18
  export * from '@/constants/setting';
19
 
20
+ export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible",'TogetherAI','Replicate'];
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -30,7 +30,8 @@ export const IconMap = {
30
  Upstage: 'upstage',
31
  'novita.ai': 'novita-ai',
32
  SILICONFLOW: 'siliconflow',
33
- "01.AI": 'yi'
 
34
  };
35
 
36
  export const BedrockRegionList = [
 
30
  Upstage: 'upstage',
31
  'novita.ai': 'novita-ai',
32
  SILICONFLOW: 'siliconflow',
33
+ "01.AI": 'yi',
34
+ "Replicate": 'replicate'
35
  };
36
 
37
  export const BedrockRegionList = [