Kevin Hu
commited on
Commit
·
96edfc5
1
Parent(s):
684f1d7
refine xinference (#2521)
Browse files### What problem does this PR solve?
#1588
### Type of change
- [x] Refactoring
- rag/llm/cv_model.py +2 -0
- rag/llm/embedding_model.py +2 -0
- rag/llm/rerank_model.py +2 -0
- rag/llm/sequence2txt_model.py +2 -0
rag/llm/cv_model.py
CHANGED
|
@@ -449,6 +449,8 @@ class LocalAICV(GptV4):
|
|
| 449 |
|
| 450 |
class XinferenceCV(Base):
|
| 451 |
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
|
|
|
|
|
|
|
| 452 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 453 |
self.model_name = model_name
|
| 454 |
self.lang = lang
|
|
|
|
| 449 |
|
| 450 |
class XinferenceCV(Base):
|
| 451 |
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
|
| 452 |
+
if base_url.split("/")[-1] != "v1":
|
| 453 |
+
base_url = os.path.join(base_url, "v1")
|
| 454 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 455 |
self.model_name = model_name
|
| 456 |
self.lang = lang
|
rag/llm/embedding_model.py
CHANGED
|
@@ -268,6 +268,8 @@ class FastEmbed(Base):
|
|
| 268 |
|
| 269 |
class XinferenceEmbed(Base):
|
| 270 |
def __init__(self, key, model_name="", base_url=""):
|
|
|
|
|
|
|
| 271 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 272 |
self.model_name = model_name
|
| 273 |
|
|
|
|
| 268 |
|
| 269 |
class XinferenceEmbed(Base):
|
| 270 |
def __init__(self, key, model_name="", base_url=""):
|
| 271 |
+
if base_url.split("/")[-1] != "v1":
|
| 272 |
+
base_url = os.path.join(base_url, "v1")
|
| 273 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 274 |
self.model_name = model_name
|
| 275 |
|
rag/llm/rerank_model.py
CHANGED
|
@@ -140,6 +140,8 @@ class YoudaoRerank(DefaultRerank):
|
|
| 140 |
|
| 141 |
class XInferenceRerank(Base):
|
| 142 |
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
|
|
|
|
|
|
| 143 |
self.model_name = model_name
|
| 144 |
self.base_url = base_url
|
| 145 |
self.headers = {
|
|
|
|
| 140 |
|
| 141 |
class XInferenceRerank(Base):
|
| 142 |
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
| 143 |
+
if base_url.split("/")[-1] != "v1":
|
| 144 |
+
base_url = os.path.join(base_url, "v1")
|
| 145 |
self.model_name = model_name
|
| 146 |
self.base_url = base_url
|
| 147 |
self.headers = {
|
rag/llm/sequence2txt_model.py
CHANGED
|
@@ -93,6 +93,8 @@ class AzureSeq2txt(Base):
|
|
| 93 |
|
| 94 |
class XinferenceSeq2txt(Base):
|
| 95 |
def __init__(self, key, model_name="", base_url=""):
|
|
|
|
|
|
|
| 96 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 97 |
self.model_name = model_name
|
| 98 |
|
|
|
|
| 93 |
|
| 94 |
class XinferenceSeq2txt(Base):
|
| 95 |
def __init__(self, key, model_name="", base_url=""):
|
| 96 |
+
if base_url.split("/")[-1] != "v1":
|
| 97 |
+
base_url = os.path.join(base_url, "v1")
|
| 98 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 99 |
self.model_name = model_name
|
| 100 |
|