H KevinHuSh commited on
Commit
dffdcde
·
1 Parent(s): 255441a

Add Support for AWS Bedrock (#1408)

Browse files

### What problem does this PR solve?

#308

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: KevinHuSh <[email protected]>

api/apps/llm_app.py CHANGED
@@ -109,15 +109,23 @@ def set_api_key():
109
  def add_llm():
110
  req = request.json
111
  factory = req["llm_factory"]
112
- # For VolcEngine, due to its special authentication method
113
- # Assemble volc_ak, volc_sk, endpoint_id into api_key
114
  if factory == "VolcEngine":
 
 
115
  temp = list(eval(req["llm_name"]).items())[0]
116
  llm_name = temp[0]
117
  endpoint_id = temp[1]
118
  api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
119
  f'"volc_sk": "{req.get("volc_sk", "")}", ' \
120
  f'"ep_id": "{endpoint_id}", ' + '}'
 
 
 
 
 
 
 
121
  else:
122
  llm_name = req["llm_name"]
123
  api_key = "xxxxxxxxxxxxxxx"
@@ -134,7 +142,9 @@ def add_llm():
134
  msg = ""
135
  if llm["model_type"] == LLMType.EMBEDDING.value:
136
  mdl = EmbeddingModel[factory](
137
- key=None, model_name=llm["llm_name"], base_url=llm["api_base"])
 
 
138
  try:
139
  arr, tc = mdl.encode(["Test if the api key is available"])
140
  if len(arr[0]) == 0 or tc == 0:
@@ -143,7 +153,7 @@ def add_llm():
143
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
144
  elif llm["model_type"] == LLMType.CHAT.value:
145
  mdl = ChatModel[factory](
146
- key=llm['api_key'] if factory == "VolcEngine" else None,
147
  model_name=llm["llm_name"],
148
  base_url=llm["api_base"]
149
  )
 
109
  def add_llm():
110
  req = request.json
111
  factory = req["llm_factory"]
112
+
 
113
  if factory == "VolcEngine":
114
+ # For VolcEngine, due to its special authentication method
115
+ # Assemble volc_ak, volc_sk, endpoint_id into api_key
116
  temp = list(eval(req["llm_name"]).items())[0]
117
  llm_name = temp[0]
118
  endpoint_id = temp[1]
119
  api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
120
  f'"volc_sk": "{req.get("volc_sk", "")}", ' \
121
  f'"ep_id": "{endpoint_id}", ' + '}'
122
+ elif factory == "Bedrock":
123
+ # For Bedrock, due to its special authentication method
124
+ # Assemble bedrock_ak, bedrock_sk, bedrock_region
125
+ llm_name = req["llm_name"]
126
+ api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
127
+ f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
128
+ f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
129
  else:
130
  llm_name = req["llm_name"]
131
  api_key = "xxxxxxxxxxxxxxx"
 
142
  msg = ""
143
  if llm["model_type"] == LLMType.EMBEDDING.value:
144
  mdl = EmbeddingModel[factory](
145
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
146
+ model_name=llm["llm_name"],
147
+ base_url=llm["api_base"])
148
  try:
149
  arr, tc = mdl.encode(["Test if the api key is available"])
150
  if len(arr[0]) == 0 or tc == 0:
 
153
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
154
  elif llm["model_type"] == LLMType.CHAT.value:
155
  mdl = ChatModel[factory](
156
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
157
  model_name=llm["llm_name"],
158
  base_url=llm["api_base"]
159
  )
api/db/init_data.py CHANGED
@@ -170,6 +170,11 @@ factory_infos = [{
170
  "logo": "",
171
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
172
  "status": "1",
 
 
 
 
 
173
  }
174
  # {
175
  # "name": "文心一言",
@@ -730,7 +735,170 @@ def init_llm_factory():
730
  "max_tokens": 765,
731
  "model_type": LLMType.IMAGE2TEXT.value
732
  },
733
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734
  ]
735
  for info in factory_infos:
736
  try:
 
170
  "logo": "",
171
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
172
  "status": "1",
173
+ },{
174
+ "name": "Bedrock",
175
+ "logo": "",
176
+ "tags": "LLM,TEXT EMBEDDING",
177
+ "status": "1",
178
  }
179
  # {
180
  # "name": "文心一言",
 
735
  "max_tokens": 765,
736
  "model_type": LLMType.IMAGE2TEXT.value
737
  },
738
+ # ------------------------ Bedrock -----------------------
739
+ {
740
+ "fid": factory_infos[16]["name"],
741
+ "llm_name": "ai21.j2-ultra-v1",
742
+ "tags": "LLM,CHAT,8k",
743
+ "max_tokens": 8191,
744
+ "model_type": LLMType.CHAT.value
745
+ }, {
746
+ "fid": factory_infos[16]["name"],
747
+ "llm_name": "ai21.j2-mid-v1",
748
+ "tags": "LLM,CHAT,8k",
749
+ "max_tokens": 8191,
750
+ "model_type": LLMType.CHAT.value
751
+ }, {
752
+ "fid": factory_infos[16]["name"],
753
+ "llm_name": "cohere.command-text-v14",
754
+ "tags": "LLM,CHAT,4k",
755
+ "max_tokens": 4096,
756
+ "model_type": LLMType.CHAT.value
757
+ }, {
758
+ "fid": factory_infos[16]["name"],
759
+ "llm_name": "cohere.command-light-text-v14",
760
+ "tags": "LLM,CHAT,4k",
761
+ "max_tokens": 4096,
762
+ "model_type": LLMType.CHAT.value
763
+ }, {
764
+ "fid": factory_infos[16]["name"],
765
+ "llm_name": "cohere.command-r-v1:0",
766
+ "tags": "LLM,CHAT,128k",
767
+ "max_tokens": 128 * 1024,
768
+ "model_type": LLMType.CHAT.value
769
+ }, {
770
+ "fid": factory_infos[16]["name"],
771
+ "llm_name": "cohere.command-r-plus-v1:0",
772
+ "tags": "LLM,CHAT,128k",
773
+ "max_tokens": 128000,
774
+ "model_type": LLMType.CHAT.value
775
+ }, {
776
+ "fid": factory_infos[16]["name"],
777
+ "llm_name": "anthropic.claude-v2",
778
+ "tags": "LLM,CHAT,100k",
779
+ "max_tokens": 100 * 1024,
780
+ "model_type": LLMType.CHAT.value
781
+ }, {
782
+ "fid": factory_infos[16]["name"],
783
+ "llm_name": "anthropic.claude-v2:1",
784
+ "tags": "LLM,CHAT,200k",
785
+ "max_tokens": 200 * 1024,
786
+ "model_type": LLMType.CHAT.value
787
+ }, {
788
+ "fid": factory_infos[16]["name"],
789
+ "llm_name": "anthropic.claude-3-sonnet-20240229-v1:0",
790
+ "tags": "LLM,CHAT,200k",
791
+ "max_tokens": 200 * 1024,
792
+ "model_type": LLMType.CHAT.value
793
+ }, {
794
+ "fid": factory_infos[16]["name"],
795
+ "llm_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
796
+ "tags": "LLM,CHAT,200k",
797
+ "max_tokens": 200 * 1024,
798
+ "model_type": LLMType.CHAT.value
799
+ }, {
800
+ "fid": factory_infos[16]["name"],
801
+ "llm_name": "anthropic.claude-3-haiku-20240307-v1:0",
802
+ "tags": "LLM,CHAT,200k",
803
+ "max_tokens": 200 * 1024,
804
+ "model_type": LLMType.CHAT.value
805
+ }, {
806
+ "fid": factory_infos[16]["name"],
807
+ "llm_name": "anthropic.claude-3-opus-20240229-v1:0",
808
+ "tags": "LLM,CHAT,200k",
809
+ "max_tokens": 200 * 1024,
810
+ "model_type": LLMType.CHAT.value
811
+ }, {
812
+ "fid": factory_infos[16]["name"],
813
+ "llm_name": "anthropic.claude-instant-v1",
814
+ "tags": "LLM,CHAT,100k",
815
+ "max_tokens": 100 * 1024,
816
+ "model_type": LLMType.CHAT.value
817
+ }, {
818
+ "fid": factory_infos[16]["name"],
819
+ "llm_name": "amazon.titan-text-express-v1",
820
+ "tags": "LLM,CHAT,8k",
821
+ "max_tokens": 8192,
822
+ "model_type": LLMType.CHAT.value
823
+ }, {
824
+ "fid": factory_infos[16]["name"],
825
+ "llm_name": "amazon.titan-text-premier-v1:0",
826
+ "tags": "LLM,CHAT,32k",
827
+ "max_tokens": 32 * 1024,
828
+ "model_type": LLMType.CHAT.value
829
+ }, {
830
+ "fid": factory_infos[16]["name"],
831
+ "llm_name": "amazon.titan-text-lite-v1",
832
+ "tags": "LLM,CHAT,4k",
833
+ "max_tokens": 4096,
834
+ "model_type": LLMType.CHAT.value
835
+ }, {
836
+ "fid": factory_infos[16]["name"],
837
+ "llm_name": "meta.llama2-13b-chat-v1",
838
+ "tags": "LLM,CHAT,4k",
839
+ "max_tokens": 4096,
840
+ "model_type": LLMType.CHAT.value
841
+ }, {
842
+ "fid": factory_infos[16]["name"],
843
+ "llm_name": "meta.llama2-70b-chat-v1",
844
+ "tags": "LLM,CHAT,4k",
845
+ "max_tokens": 4096,
846
+ "model_type": LLMType.CHAT.value
847
+ }, {
848
+ "fid": factory_infos[16]["name"],
849
+ "llm_name": "meta.llama3-8b-instruct-v1:0",
850
+ "tags": "LLM,CHAT,8k",
851
+ "max_tokens": 8192,
852
+ "model_type": LLMType.CHAT.value
853
+ }, {
854
+ "fid": factory_infos[16]["name"],
855
+ "llm_name": "meta.llama3-70b-instruct-v1:0",
856
+ "tags": "LLM,CHAT,8k",
857
+ "max_tokens": 8192,
858
+ "model_type": LLMType.CHAT.value
859
+ }, {
860
+ "fid": factory_infos[16]["name"],
861
+ "llm_name": "mistral.mistral-7b-instruct-v0:2",
862
+ "tags": "LLM,CHAT,8k",
863
+ "max_tokens": 8192,
864
+ "model_type": LLMType.CHAT.value
865
+ }, {
866
+ "fid": factory_infos[16]["name"],
867
+ "llm_name": "mistral.mixtral-8x7b-instruct-v0:1",
868
+ "tags": "LLM,CHAT,4k",
869
+ "max_tokens": 4096,
870
+ "model_type": LLMType.CHAT.value
871
+ }, {
872
+ "fid": factory_infos[16]["name"],
873
+ "llm_name": "mistral.mistral-large-2402-v1:0",
874
+ "tags": "LLM,CHAT,8k",
875
+ "max_tokens": 8192,
876
+ "model_type": LLMType.CHAT.value
877
+ }, {
878
+ "fid": factory_infos[16]["name"],
879
+ "llm_name": "mistral.mistral-small-2402-v1:0",
880
+ "tags": "LLM,CHAT,8k",
881
+ "max_tokens": 8192,
882
+ "model_type": LLMType.CHAT.value
883
+ }, {
884
+ "fid": factory_infos[16]["name"],
885
+ "llm_name": "amazon.titan-embed-text-v2:0",
886
+ "tags": "TEXT EMBEDDING",
887
+ "max_tokens": 8192,
888
+ "model_type": LLMType.EMBEDDING.value
889
+ }, {
890
+ "fid": factory_infos[16]["name"],
891
+ "llm_name": "cohere.embed-english-v3",
892
+ "tags": "TEXT EMBEDDING",
893
+ "max_tokens": 2048,
894
+ "model_type": LLMType.EMBEDDING.value
895
+ }, {
896
+ "fid": factory_infos[16]["name"],
897
+ "llm_name": "cohere.embed-multilingual-v3",
898
+ "tags": "TEXT EMBEDDING",
899
+ "max_tokens": 2048,
900
+ "model_type": LLMType.EMBEDDING.value
901
+ },
902
  ]
903
  for info in factory_infos:
904
  try:
rag/llm/__init__.py CHANGED
@@ -31,7 +31,8 @@ EmbeddingModel = {
31
  "BaiChuan": BaiChuanEmbed,
32
  "Jina": JinaEmbed,
33
  "BAAI": DefaultEmbedding,
34
- "Mistral": MistralEmbed
 
35
  }
36
 
37
 
@@ -58,7 +59,8 @@ ChatModel = {
58
  "VolcEngine": VolcEngineChat,
59
  "BaiChuan": BaiChuanChat,
60
  "MiniMax": MiniMaxChat,
61
- "Mistral": MistralChat
 
62
  }
63
 
64
 
 
31
  "BaiChuan": BaiChuanEmbed,
32
  "Jina": JinaEmbed,
33
  "BAAI": DefaultEmbedding,
34
+ "Mistral": MistralEmbed,
35
+ "Bedrock": BedrockEmbed
36
  }
37
 
38
 
 
59
  "VolcEngine": VolcEngineChat,
60
  "BaiChuan": BaiChuanChat,
61
  "MiniMax": MiniMaxChat,
62
+ "Mistral": MistralChat,
63
+ "Bedrock": BedrockChat
64
  }
65
 
66
 
rag/llm/chat_model.py CHANGED
@@ -533,3 +533,90 @@ class MistralChat(Base):
533
  yield ans + "\n**ERROR**: " + str(e)
534
 
535
  yield total_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  yield ans + "\n**ERROR**: " + str(e)
534
 
535
  yield total_tokens
536
+
537
+
538
+ class BedrockChat(Base):
539
+
540
+ def __init__(self, key, model_name, **kwargs):
541
+ import boto3
542
+ from botocore.exceptions import ClientError
543
+ self.bedrock_ak = eval(key).get('bedrock_ak', '')
544
+ self.bedrock_sk = eval(key).get('bedrock_sk', '')
545
+ self.bedrock_region = eval(key).get('bedrock_region', '')
546
+ self.model_name = model_name
547
+ self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
548
+ aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
549
+
550
+ def chat(self, system, history, gen_conf):
551
+ if system:
552
+ history.insert(0, {"role": "system", "content": system})
553
+ for k in list(gen_conf.keys()):
554
+ if k not in ["temperature", "top_p", "max_tokens"]:
555
+ del gen_conf[k]
556
+ if "max_tokens" in gen_conf:
557
+ gen_conf["maxTokens"] = gen_conf["max_tokens"]
558
+ _ = gen_conf.pop("max_tokens")
559
+ if "top_p" in gen_conf:
560
+ gen_conf["topP"] = gen_conf["top_p"]
561
+ _ = gen_conf.pop("top_p")
562
+
563
+ try:
564
+ # Send the message to the model, using a basic inference configuration.
565
+ response = self.client.converse(
566
+ modelId=self.model_name,
567
+ messages=history,
568
+ inferenceConfig=gen_conf
569
+ )
570
+
571
+ # Extract and print the response text.
572
+ ans = response["output"]["message"]["content"][0]["text"]
573
+ return ans, num_tokens_from_string(ans)
574
+
575
+ except (ClientError, Exception) as e:
576
+ return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
577
+
578
+ def chat_streamly(self, system, history, gen_conf):
579
+ if system:
580
+ history.insert(0, {"role": "system", "content": system})
581
+ for k in list(gen_conf.keys()):
582
+ if k not in ["temperature", "top_p", "max_tokens"]:
583
+ del gen_conf[k]
584
+ if "max_tokens" in gen_conf:
585
+ gen_conf["maxTokens"] = gen_conf["max_tokens"]
586
+ _ = gen_conf.pop("max_tokens")
587
+ if "top_p" in gen_conf:
588
+ gen_conf["topP"] = gen_conf["top_p"]
589
+ _ = gen_conf.pop("top_p")
590
+
591
+ if self.model_name.split('.')[0] == 'ai21':
592
+ try:
593
+ response = self.client.converse(
594
+ modelId=self.model_name,
595
+ messages=history,
596
+ inferenceConfig=gen_conf
597
+ )
598
+ ans = response["output"]["message"]["content"][0]["text"]
599
+ return ans, num_tokens_from_string(ans)
600
+
601
+ except (ClientError, Exception) as e:
602
+ return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
603
+
604
+ ans = ""
605
+ try:
606
+ # Send the message to the model, using a basic inference configuration.
607
+ streaming_response = self.client.converse_stream(
608
+ modelId=self.model_name,
609
+ messages=history,
610
+ inferenceConfig=gen_conf
611
+ )
612
+
613
+ # Extract and print the streamed response text in real-time.
614
+ for resp in streaming_response["stream"]:
615
+ if "contentBlockDelta" in resp:
616
+ ans += resp["contentBlockDelta"]["delta"]["text"]
617
+ yield ans
618
+
619
+ except (ClientError, Exception) as e:
620
+ yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
621
+
622
+ yield num_tokens_from_string(ans)
rag/llm/embedding_model.py CHANGED
@@ -374,3 +374,48 @@ class MistralEmbed(Base):
374
  res = self.client.embeddings(input=[truncate(text, 8196)],
375
  model=self.model_name)
376
  return np.array(res.data[0].embedding), res.usage.total_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  res = self.client.embeddings(input=[truncate(text, 8196)],
375
  model=self.model_name)
376
  return np.array(res.data[0].embedding), res.usage.total_tokens
377
+
378
+
379
+ class BedrockEmbed(Base):
380
+ def __init__(self, key, model_name,
381
+ **kwargs):
382
+ import boto3
383
+ self.bedrock_ak = eval(key).get('bedrock_ak', '')
384
+ self.bedrock_sk = eval(key).get('bedrock_sk', '')
385
+ self.bedrock_region = eval(key).get('bedrock_region', '')
386
+ self.model_name = model_name
387
+ self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
388
+ aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
389
+
390
+ def encode(self, texts: list, batch_size=32):
391
+ texts = [truncate(t, 8196) for t in texts]
392
+ embeddings = []
393
+ token_count = 0
394
+ for text in texts:
395
+ if self.model_name.split('.')[0] == 'amazon':
396
+ body = {"inputText": text}
397
+ elif self.model_name.split('.')[0] == 'cohere':
398
+ body = {"texts": [text], "input_type": 'search_document'}
399
+
400
+ response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
401
+ model_response = json.loads(response["body"].read())
402
+ embeddings.extend([model_response["embedding"]])
403
+ token_count += num_tokens_from_string(text)
404
+
405
+ return np.array(embeddings), token_count
406
+
407
+ def encode_queries(self, text):
408
+
409
+ embeddings = []
410
+ token_count = num_tokens_from_string(text)
411
+ if self.model_name.split('.')[0] == 'amazon':
412
+ body = {"inputText": truncate(text, 8196)}
413
+ elif self.model_name.split('.')[0] == 'cohere':
414
+ body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
415
+
416
+ response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
417
+ model_response = json.loads(response["body"].read())
418
+ embeddings.extend([model_response["embedding"]])
419
+
420
+ return np.array(embeddings), token_count
421
+
requirements.txt CHANGED
@@ -144,4 +144,6 @@ cn2an==0.5.22
144
  roman-numbers==1.0.2
145
  word2number==1.1
146
  markdown==3.6
 
 
147
  duckduckgo_search==6.1.9
 
144
  roman-numbers==1.0.2
145
  word2number==1.1
146
  markdown==3.6
147
+ mistralai==0.4.2
148
+ boto3==1.34.140
149
  duckduckgo_search==6.1.9
requirements_arm.txt CHANGED
@@ -145,4 +145,6 @@ cn2an==0.5.22
145
  roman-numbers==1.0.2
146
  word2number==1.1
147
  markdown==3.6
 
 
148
  duckduckgo_search==6.1.9
 
145
  roman-numbers==1.0.2
146
  word2number==1.1
147
  markdown==3.6
148
+ mistralai==0.4.2
149
+ boto3==1.34.140
150
  duckduckgo_search==6.1.9
requirements_dev.txt CHANGED
@@ -130,4 +130,6 @@ cn2an==0.5.22
130
  roman-numbers==1.0.2
131
  word2number==1.1
132
  markdown==3.6
 
 
133
  duckduckgo_search==6.1.9
 
130
  roman-numbers==1.0.2
131
  word2number==1.1
132
  markdown==3.6
133
+ mistralai==0.4.2
134
+ boto3==1.34.140
135
  duckduckgo_search==6.1.9