add support for LM Studio (#1663)
Browse files
### What problem does this PR solve?
#1602
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: Zhedong Cen <[email protected]>
- api/apps/llm_app.py +8 -4
- conf/llm_factories.json +7 -0
- rag/llm/__init__.py +15 -11
- rag/llm/chat_model.py +12 -0
- rag/llm/cv_model.py +13 -9
- rag/llm/embedding_model.py +21 -0
- rag/llm/rerank_model.py +8 -0
- web/src/assets/svg/llm/lm-studio.svg +0 -0
- web/src/pages/user-setting/constants.tsx +1 -1
- web/src/pages/user-setting/setting-model/constant.ts +2 -1
api/apps/llm_app.py
CHANGED
|
@@ -21,7 +21,7 @@ from api.db import StatusEnum, LLMType
|
|
| 21 |
from api.db.db_models import TenantLLM
|
| 22 |
from api.utils.api_utils import get_json_result
|
| 23 |
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
| 24 |
-
|
| 25 |
|
| 26 |
@manager.route('/factories', methods=['GET'])
|
| 27 |
@login_required
|
|
@@ -189,9 +189,13 @@ def add_llm():
|
|
| 189 |
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
|
| 190 |
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
| 191 |
)
|
| 192 |
-
|
| 193 |
-
if
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
except Exception as e:
|
| 196 |
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
| 197 |
else:
|
|
|
|
| 21 |
from api.db.db_models import TenantLLM
|
| 22 |
from api.utils.api_utils import get_json_result
|
| 23 |
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
| 24 |
+
import requests
|
| 25 |
|
| 26 |
@manager.route('/factories', methods=['GET'])
|
| 27 |
@login_required
|
|
|
|
| 189 |
"ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
|
| 190 |
"0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
| 191 |
)
|
| 192 |
+
res = requests.get(img_url)
|
| 193 |
+
if res.status_code == 200:
|
| 194 |
+
m, tc = mdl.describe(res.content)
|
| 195 |
+
if not tc:
|
| 196 |
+
raise Exception(m)
|
| 197 |
+
else:
|
| 198 |
+
raise ConnectionError("fail to download the test picture")
|
| 199 |
except Exception as e:
|
| 200 |
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
| 201 |
else:
|
conf/llm_factories.json
CHANGED
|
@@ -2208,6 +2208,13 @@
|
|
| 2208 |
"model_type": "image2text"
|
| 2209 |
}
|
| 2210 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2211 |
}
|
| 2212 |
]
|
| 2213 |
}
|
|
|
|
| 2208 |
"model_type": "image2text"
|
| 2209 |
}
|
| 2210 |
]
|
| 2211 |
+
},
|
| 2212 |
+
{
|
| 2213 |
+
"name": "LM-Studio",
|
| 2214 |
+
"logo": "",
|
| 2215 |
+
"tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
|
| 2216 |
+
"status": "1",
|
| 2217 |
+
"llm": []
|
| 2218 |
}
|
| 2219 |
]
|
| 2220 |
}
|
rag/llm/__init__.py
CHANGED
|
@@ -34,8 +34,9 @@ EmbeddingModel = {
|
|
| 34 |
"BAAI": DefaultEmbedding,
|
| 35 |
"Mistral": MistralEmbed,
|
| 36 |
"Bedrock": BedrockEmbed,
|
| 37 |
-
"Gemini":GeminiEmbed,
|
| 38 |
-
"NVIDIA":NvidiaEmbed
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
|
|
@@ -47,10 +48,11 @@ CvModel = {
|
|
| 47 |
"Tongyi-Qianwen": QWenCV,
|
| 48 |
"ZHIPU-AI": Zhipu4V,
|
| 49 |
"Moonshot": LocalCV,
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
"LocalAI":LocalAICV,
|
| 53 |
-
"NVIDIA":NvidiaCV
|
|
|
|
| 54 |
}
|
| 55 |
|
| 56 |
|
|
@@ -69,12 +71,13 @@ ChatModel = {
|
|
| 69 |
"MiniMax": MiniMaxChat,
|
| 70 |
"Minimax": MiniMaxChat,
|
| 71 |
"Mistral": MistralChat,
|
| 72 |
-
|
| 73 |
"Bedrock": BedrockChat,
|
| 74 |
"Groq": GroqChat,
|
| 75 |
-
|
| 76 |
-
"StepFun":StepFunChat,
|
| 77 |
-
"NVIDIA":NvidiaChat
|
|
|
|
| 78 |
}
|
| 79 |
|
| 80 |
|
|
@@ -83,7 +86,8 @@ RerankModel = {
|
|
| 83 |
"Jina": JinaRerank,
|
| 84 |
"Youdao": YoudaoRerank,
|
| 85 |
"Xinference": XInferenceRerank,
|
| 86 |
-
"NVIDIA":NvidiaRerank
|
|
|
|
| 87 |
}
|
| 88 |
|
| 89 |
|
|
|
|
| 34 |
"BAAI": DefaultEmbedding,
|
| 35 |
"Mistral": MistralEmbed,
|
| 36 |
"Bedrock": BedrockEmbed,
|
| 37 |
+
"Gemini": GeminiEmbed,
|
| 38 |
+
"NVIDIA": NvidiaEmbed,
|
| 39 |
+
"LM-Studio": LmStudioEmbed
|
| 40 |
}
|
| 41 |
|
| 42 |
|
|
|
|
| 48 |
"Tongyi-Qianwen": QWenCV,
|
| 49 |
"ZHIPU-AI": Zhipu4V,
|
| 50 |
"Moonshot": LocalCV,
|
| 51 |
+
"Gemini": GeminiCV,
|
| 52 |
+
"OpenRouter": OpenRouterCV,
|
| 53 |
+
"LocalAI": LocalAICV,
|
| 54 |
+
"NVIDIA": NvidiaCV,
|
| 55 |
+
"LM-Studio": LmStudioCV
|
| 56 |
}
|
| 57 |
|
| 58 |
|
|
|
|
| 71 |
"MiniMax": MiniMaxChat,
|
| 72 |
"Minimax": MiniMaxChat,
|
| 73 |
"Mistral": MistralChat,
|
| 74 |
+
"Gemini": GeminiChat,
|
| 75 |
"Bedrock": BedrockChat,
|
| 76 |
"Groq": GroqChat,
|
| 77 |
+
"OpenRouter": OpenRouterChat,
|
| 78 |
+
"StepFun": StepFunChat,
|
| 79 |
+
"NVIDIA": NvidiaChat,
|
| 80 |
+
"LM-Studio": LmStudioChat
|
| 81 |
}
|
| 82 |
|
| 83 |
|
|
|
|
| 86 |
"Jina": JinaRerank,
|
| 87 |
"Youdao": YoudaoRerank,
|
| 88 |
"Xinference": XInferenceRerank,
|
| 89 |
+
"NVIDIA": NvidiaRerank,
|
| 90 |
+
"LM-Studio": LmStudioRerank
|
| 91 |
}
|
| 92 |
|
| 93 |
|
rag/llm/chat_model.py
CHANGED
|
@@ -976,3 +976,15 @@ class NvidiaChat(Base):
|
|
| 976 |
yield ans + "\n**ERROR**: " + str(e)
|
| 977 |
|
| 978 |
yield total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 976 |
yield ans + "\n**ERROR**: " + str(e)
|
| 977 |
|
| 978 |
yield total_tokens
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
class LmStudioChat(Base):
    """Chat model backed by an LM Studio server reached through the OpenAI client."""

    def __init__(self, key, model_name, base_url):
        """Create an OpenAI client that points at the LM Studio endpoint.

        Args:
            key: Accepted for signature parity with the other chat models;
                not used here (the client is built with the fixed
                placeholder key ``"lm-studio"``).
            model_name: Model identifier stored on the instance.
            base_url: Root URL of the server, with or without a trailing
                ``/v1`` segment.

        Raises:
            ValueError: If ``base_url`` is empty or ``None``.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Build the URL with plain string handling: os.path.join uses the
        # platform path separator and would double "/v1" on a trailing slash.
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        # Always assign base_url; previously it was only set when "/v1" had
        # to be appended, so URLs already ending in "/v1" failed below.
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
|
rag/llm/cv_model.py
CHANGED
|
@@ -440,15 +440,8 @@ class LocalAICV(Base):
|
|
| 440 |
self.lang = lang
|
| 441 |
|
| 442 |
def describe(self, image, max_tokens=300):
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
): # if url string
|
| 446 |
-
prompt = self.prompt(image)
|
| 447 |
-
for i in range(len(prompt)):
|
| 448 |
-
prompt[i]["content"]["image_url"]["url"] = image
|
| 449 |
-
else:
|
| 450 |
-
b64 = self.image2base64(image)
|
| 451 |
-
prompt = self.prompt(b64)
|
| 452 |
for i in range(len(prompt)):
|
| 453 |
for c in prompt[i]["content"]:
|
| 454 |
if "text" in c:
|
|
@@ -680,3 +673,14 @@ class NvidiaCV(Base):
|
|
| 680 |
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
|
| 681 |
}
|
| 682 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
self.lang = lang
|
| 441 |
|
| 442 |
def describe(self, image, max_tokens=300):
|
| 443 |
+
b64 = self.image2base64(image)
|
| 444 |
+
prompt = self.prompt(b64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
for i in range(len(prompt)):
|
| 446 |
for c in prompt[i]["content"]:
|
| 447 |
if "text" in c:
|
|
|
|
| 673 |
"content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
|
| 674 |
}
|
| 675 |
]
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
class LmStudioCV(LocalAICV):
    """Image-to-text model backed by an LM Studio server.

    Only construction is overridden here; everything else is inherited from
    ``LocalAICV``.
    """

    def __init__(self, key, model_name, base_url, lang="Chinese"):
        """Create an OpenAI client that points at the LM Studio endpoint.

        Args:
            key: Accepted for signature parity; not used here (the client is
                built with the fixed placeholder key ``"lm-studio"``).
            model_name: Model identifier stored on the instance.
            base_url: Root URL of the server, with or without a trailing
                ``/v1`` segment.
            lang: Language setting stored on the instance as ``self.lang``.

        Raises:
            ValueError: If ``base_url`` is empty or ``None``.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Build the URL with plain string handling: os.path.join uses the
        # platform path separator and would double "/v1" on a trailing slash.
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        # Always assign base_url; previously it was only set when "/v1" had
        # to be appended, so URLs already ending in "/v1" failed below.
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name
        self.lang = lang
|
rag/llm/embedding_model.py
CHANGED
|
@@ -500,3 +500,24 @@ class NvidiaEmbed(Base):
|
|
| 500 |
def encode_queries(self, text):
|
| 501 |
embds, cnt = self.encode([text])
|
| 502 |
return np.array(embds[0]), cnt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
def encode_queries(self, text):
|
| 501 |
embds, cnt = self.encode([text])
|
| 502 |
return np.array(embds[0]), cnt
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class LmStudioEmbed(Base):
    """Embedding model backed by an LM Studio server reached through the OpenAI client."""

    def __init__(self, key, model_name, base_url):
        """Create an OpenAI client that points at the LM Studio endpoint.

        Args:
            key: Accepted for signature parity; not used here (the client is
                built with the fixed placeholder key ``"lm-studio"``).
            model_name: Model identifier stored on the instance.
            base_url: Root URL of the server, with or without a trailing
                ``/v1`` segment.

        Raises:
            ValueError: If ``base_url`` is empty or ``None``.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Build the URL with plain string handling: os.path.join uses the
        # platform path separator and would double "/v1" on a trailing slash.
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        # Always assign base_url; previously it was only set when "/v1" had
        # to be appended, so URLs already ending in "/v1" failed below.
        self.base_url = base_url
        self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in a single request.

        Args:
            texts: Strings to embed.
            batch_size: Accepted for interface compatibility; not used, the
                whole list is sent in one call.

        Returns:
            A tuple ``(embeddings, 1024)``: one row per returned item, plus a
            fixed placeholder in the token-count slot because tokens are not
            counted for LM Studio embeddings.
        """
        res = self.client.embeddings.create(input=texts, model=self.model_name)
        return (
            np.array([d.embedding for d in res.data]),
            1024,
        )  # local embedding for LM Studio does not count tokens

    def encode_queries(self, text):
        """Embed a single query string.

        Returns:
            A tuple ``(embedding, 1024)`` with the same placeholder token
            count used by ``encode``.
        """
        # ``input`` must be passed by keyword: the OpenAI v1 client declares
        # the parameters of embeddings.create as keyword-only.
        res = self.client.embeddings.create(input=[text], model=self.model_name)
        return np.array(res.data[0].embedding), 1024
|
rag/llm/rerank_model.py
CHANGED
|
@@ -202,3 +202,11 @@ class NvidiaRerank(Base):
|
|
| 202 |
}
|
| 203 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 204 |
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
}
|
| 203 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 204 |
return (np.array([d["logit"] for d in res["rankings"]]), token_count)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class LmStudioRerank(Base):
    """Placeholder reranker registered for the LM-Studio factory.

    No reranking is performed; ``similarity`` always raises.
    """

    def __init__(self, key, model_name, base_url):
        """Accept the usual constructor arguments and ignore all of them."""

    def similarity(self, query: str, texts: list):
        """Unconditionally raise ``NotImplementedError``.

        Args:
            query: Ignored.
            texts: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "The LmStudioRerank has not been implement"
        raise NotImplementedError(message)
|
web/src/assets/svg/llm/lm-studio.svg
ADDED
|
|
web/src/pages/user-setting/constants.tsx
CHANGED
|
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
|
|
| 17 |
|
| 18 |
export * from '@/constants/setting';
|
| 19 |
|
| 20 |
-
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI'];
|
|
|
|
| 17 |
|
| 18 |
export * from '@/constants/setting';
|
| 19 |
|
| 20 |
+
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];
|
web/src/pages/user-setting/setting-model/constant.ts
CHANGED
|
@@ -20,7 +20,8 @@ export const IconMap = {
|
|
| 20 |
OpenRouter: 'open-router',
|
| 21 |
LocalAI: 'local-ai',
|
| 22 |
StepFun: 'stepfun',
|
| 23 |
-
NVIDIA:'nvidia'
|
|
|
|
| 24 |
};
|
| 25 |
|
| 26 |
export const BedrockRegionList = [
|
|
|
|
| 20 |
OpenRouter: 'open-router',
|
| 21 |
LocalAI: 'local-ai',
|
| 22 |
StepFun: 'stepfun',
|
| 23 |
+
NVIDIA:'nvidia',
|
| 24 |
+
'LM-Studio':'lm-studio'
|
| 25 |
};
|
| 26 |
|
| 27 |
export const BedrockRegionList = [
|