Pass top_p to ollama (#3744)
### What problem does this PR solve?
Pass top_p to ollama. Close #1769
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
rag/llm/chat_model.py (+5 −5)
```diff
@@ -356,7 +356,7 @@ class OllamaChat(Base):
             options = {}
             if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
             response = self.client.chat(
```
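For context: this block translates RAGFlow's OpenAI-style `gen_conf` keys into Ollama option names (`max_tokens` maps to `num_predict`), and before the fix the `top_p` value was assigned to `top_k`, so Ollama never received a `top_p` at all. A minimal sketch of the corrected mapping as a standalone helper (the function name is illustrative, not part of the codebase):

```python
def gen_conf_to_ollama_options(gen_conf: dict) -> dict:
    """Translate OpenAI-style generation settings into Ollama option names."""
    key_map = {
        "temperature": "temperature",
        "max_tokens": "num_predict",   # Ollama calls the token budget num_predict
        "top_p": "top_p",              # the fix: top_p stays top_p, not top_k
        "presence_penalty": "presence_penalty",
        "frequency_penalty": "frequency_penalty",
    }
    # Copy over only the keys Ollama understands, renamed as needed
    return {key_map[k]: v for k, v in gen_conf.items() if k in key_map}
```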
```diff
@@ -376,7 +376,7 @@ class OllamaChat(Base):
             options = {}
             if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-            if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+            if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
             if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
             ans = ""
```
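The second hunk applies the same fix to the streaming variant, which hands the identical `options` dict to the Ollama client. A hedged usage sketch with the official `ollama` Python client (host and model name are placeholders):

```python
from ollama import Client

client = Client(host="http://localhost:11434")  # placeholder host
options = {"temperature": 0.7, "top_p": 0.9, "num_predict": 256}

# stream=True yields incremental chunks, each carrying a partial message
for chunk in client.chat(
    model="llama3",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    options=options,
    stream=True,
):
    print(chunk["message"]["content"], end="", flush=True)
```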
```diff
@@ -430,7 +430,7 @@ class LocalLLM(Base):
             try:
                 self._connection.send(pickle.dumps((name, args, kwargs)))
                 return pickle.loads(self._connection.recv())
-            except Exception as e:
+            except Exception:
                 self.__conn()
                 raise Exception("RPC connection lost!")
 
```
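The `except` cleanup (dropping the unused exception variable) sits in `LocalLLM`'s RPC proxy, which pickles each method call, ships it over a connection, and reconnects before surfacing any failure. A self-contained sketch of that request/response pattern over `multiprocessing.connection` (class name, address, and authkey are illustrative):

```python
import pickle
from multiprocessing.connection import Client

class RPCProxy:
    def __init__(self, address=("localhost", 12345), authkey=b"secret"):
        self._address, self._authkey = address, authkey
        self._connect()

    def _connect(self):
        # (Re)open the connection used for pickled request/response pairs
        self._connection = Client(self._address, authkey=self._authkey)

    def call(self, name, *args, **kwargs):
        try:
            self._connection.send(pickle.dumps((name, args, kwargs)))
            return pickle.loads(self._connection.recv())
        except Exception:
            # Reconnect so the next call can succeed, then surface the failure
            self._connect()
            raise Exception("RPC connection lost!")
```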
```diff
@@ -442,7 +442,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Prompt
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -450,7 +450,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Generation
         answer = ""
         try:
             res = self.client.stream_doc(
```
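The last two hunks split a shared `from rag.svr.jina_server import Prompt, Generation` line so each method imports only the class it actually uses, and the imports stay function-local so the Jina dependency is loaded lazily. A hedged sketch of how the two methods fit together; the `stream_doc` keyword arguments and the `.text` field are assumptions, since the diff truncates that call:

```python
class LocalLLMSketch:
    """Illustrative outline of the two methods touched by the last hunks."""

    def _prepare_prompt(self, system, history, gen_conf):
        # Lazy import: only Prompt is needed on this path
        from rag.svr.jina_server import Prompt
        if system:
            history.insert(0, {"role": "system", "content": system})
        return Prompt(message=history, gen_conf=gen_conf)

    async def _stream_response(self, endpoint, prompt):
        # Lazy import: only Generation is needed on this path
        from rag.svr.jina_server import Generation
        # Assumed call shape; the diff shows only `self.client.stream_doc(`
        async for generation in self.client.stream_doc(
            on=endpoint, inputs=prompt, return_type=Generation
        ):
            yield generation.text  # assumed field on Generation
```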