zhichyu committed
Commit 2c75739 · 1 Parent(s): 5a552df

Pass top_p to ollama (#3744)

### What problem does this PR solve?

Pass top_p to ollama. Previously, the `top_p` value from `gen_conf` was mistakenly assigned to Ollama's `top_k` option. Close #1769
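For context: in Ollama's sampling options, `top_p` (nucleus sampling, a probability-mass cutoff) and `top_k` (a fixed candidate-count cutoff) are distinct parameters, so writing the `top_p` value into `top_k` silently misapplied the setting. A minimal sketch of the corrected mapping, assuming the official `ollama` Python client; the host is the default local endpoint and the model name is illustrative:

```python
from ollama import Client  # official ollama Python client (assumed installed)

client = Client(host="http://localhost:11434")  # default local Ollama endpoint

gen_conf = {"temperature": 0.7, "max_tokens": 256, "top_p": 0.9}

# Translate generic generation settings into Ollama option names.
options = {}
if "temperature" in gen_conf:
    options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
    options["num_predict"] = gen_conf["max_tokens"]  # Ollama's name for the max-token limit
if "top_p" in gen_conf:
    options["top_p"] = gen_conf["top_p"]  # the fix: keep top_p as top_p, not top_k

response = client.chat(
    model="llama3",  # illustrative model name; any locally pulled model works
    messages=[{"role": "user", "content": "Say hello."}],
    options=options,
)
print(response["message"]["content"])
```

Note the `max_tokens` → `num_predict` rename in the mapping: `num_predict` is Ollama's option for the maximum number of generated tokens.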

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1)
  1. rag/llm/chat_model.py +5 -5
rag/llm/chat_model.py CHANGED
```diff
@@ -356,7 +356,7 @@ class OllamaChat(Base):
         options = {}
         if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
         if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
         if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
         response = self.client.chat(
@@ -376,7 +376,7 @@ class OllamaChat(Base):
         options = {}
         if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
         if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
         if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
         if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
         ans = ""
@@ -430,7 +430,7 @@ class LocalLLM(Base):
         try:
             self._connection.send(pickle.dumps((name, args, kwargs)))
             return pickle.loads(self._connection.recv())
-        except Exception as e:
+        except Exception:
             self.__conn()
             raise Exception("RPC connection lost!")
 
@@ -442,7 +442,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Prompt
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -450,7 +450,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt, Generation
+        from rag.svr.jina_server import Generation
         answer = ""
         try:
             res = self.client.stream_doc(
```
 
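The remaining hunks are small cleanups carried along with the fix: the unused `as e` binding is dropped from the RPC `except` clause, and each `LocalLLM` helper now imports only the `rag.svr.jina_server` symbol it actually uses. After the patch, the two helpers read roughly as follows (a sketch with bodies elided; see the full file for details):

```python
def _prepare_prompt(self, system, history, gen_conf):
    from rag.svr.jina_server import Prompt  # Generation is not needed here
    if system:
        history.insert(0, {"role": "system", "content": system})
    ...  # clamp max_tokens, then:
    return Prompt(message=history, gen_conf=gen_conf)

def _stream_response(self, endpoint, prompt):
    from rag.svr.jina_server import Generation  # Prompt is not needed here
    answer = ""
    ...  # stream Generation chunks from self.client.stream_doc and accumulate into answer
```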