Alvaro8gb committed
Commit 3bff36c · verified · 1 Parent(s): 948f060

Update app.py

Files changed (1): app.py (+9 -8)
app.py CHANGED

@@ -9,6 +9,7 @@ TEMPERATURE = 0.5
 TOP_P = 0.95
 TOP_K = 50
 REPETITION_PENALTY = 1.05
+SPECIAL_TOKEN = "->:"
 
 HF_TOKEN = os.getenv('HF_TOKEN')
 
@@ -38,12 +39,12 @@ tokenizer = None
 
 def generate_response(input_text, max_tokens, temperature, top_p, repetition_penalty):
     global model, tokenizer
-
+
     if model is None or tokenizer is None:
         model, tokenizer = load_model()
-
-    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
-
+
+    inputs = tokenizer(input_text + SPECIAL_TOKEN, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
@@ -56,12 +57,12 @@ def generate_response(input_text, max_tokens, temperature, top_p, repetition_pen
         )
 
     full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    if "->" in full_response:
-        response_parts = full_response.split("->", 1)
+
+    if SPECIAL_TOKEN in full_response:
+        response_parts = full_response.split(SPECIAL_TOKEN, 1)
         if len(response_parts) > 1:
             return response_parts[1].strip()
-
+
     return full_response.strip()
 
 def chat_interface(message, history, system_message, max_tokens, temperature, top_p, repetition_penalty):
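
In effect, the commit factors the "->:" delimiter into a SPECIAL_TOKEN constant, appends it to the prompt before tokenization, and splits the decoded output on that same marker so that only the text the model generates after the delimiter is returned. A minimal sketch of that extraction step is shown below; it uses a hypothetical decoded string and an illustrative helper name (extract_completion), and it does not load the actual model or tokenizer.

SPECIAL_TOKEN = "->:"  # delimiter appended to the prompt, mirroring the commit

def extract_completion(full_response: str) -> str:
    # Return only the text after the first SPECIAL_TOKEN, as generate_response now does.
    if SPECIAL_TOKEN in full_response:
        response_parts = full_response.split(SPECIAL_TOKEN, 1)
        if len(response_parts) > 1:
            return response_parts[1].strip()
    # Fall back to the whole decoded string if no delimiter is present.
    return full_response.strip()

# Hypothetical decoded output: the prompt, the delimiter, then the completion.
decoded = "Translate 'hola' to English ->: hello"
print(extract_completion(decoded))  # prints: hello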