AndreaAlessandrelli4 committed
Commit 6dc02f1 · verified · 1 Parent(s): c422365

Update app.py

Files changed (1): app.py +46 -10
app.py CHANGED
@@ -1,11 +1,10 @@
 import os
 from threading import Thread
 from typing import Iterator
+
 import gradio as gr
 import spaces
 import torch
-import weaviate
-from sentence_transformers import SentenceTransformer
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 MAX_MAX_NEW_TOKENS = 2048
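Note: this hunk drops the weaviate client and the SentenceTransformer embedder, yet richiamo_materiali(...) still appears as context in a later hunk; if that retrieval helper depended on the removed imports, it now has to obtain a client and embedder elsewhere, or it has become dead code.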
@@ -33,7 +32,6 @@ if torch.cuda.is_available():
 )
 
 
-
 def prompt_template(materiali, query):
     mat = ''
     for i, doc in enumerate(materiali):
@@ -67,13 +65,14 @@ def richiamo_materiali(query, vett_query, alpha=1.0, N_items=5):
 
 
 
+
 @spaces.GPU
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    #system_prompt: str,
+    system_prompt: str,
     max_new_tokens: int = 1024,
-    temperature: float = 0.1,
+    temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
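Note: system_prompt is promoted from a commented-out placeholder to a real parameter, and the default temperature rises from 0.1 to 0.6. The values for these extra parameters are supplied by the additional_inputs components added in a later hunk; gr.ChatInterface passes them positionally after (message, chat_history). The body of generate() is not shown before the next hunk, but the `conversation` it feeds to apply_chat_template is presumably assembled in the usual way; a minimal sketch, assuming the standard chat-template pattern:

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        # replay prior turns so the model sees the full dialogue
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})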
@@ -96,14 +95,15 @@ def generate(
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Chat troppo lunga superati {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
+        do_sample=True,
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
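Note: do_sample=True matters here; without it, transformers' generate() ignores temperature, top_p and top_k, so the new sliders would have no effect. The streamer timeout is also tightened from 60 s to 30 s per chunk. For context, a minimal sketch of how such a streamer is typically consumed, assuming the unchanged remainder of generate() follows the standard Hugging Face streaming pattern:

    # run generation in a background thread so the streamer can be iterated
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:       # raises queue.Empty if no chunk arrives within timeout
        outputs.append(text)
        yield "".join(outputs)  # Gradio renders each partial string as it arrives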
@@ -123,14 +123,50 @@ chat_interface = gr.ChatInterface(
     fn=generate,
     chatbot=gr.Chatbot(height=400, label = "AvvoChat", show_copy_button=True, avatar_images=("users.jpg","AvvoVhat.png"), show_share_button=True),
     textbox=gr.Textbox(placeholder="Inserisci la tua domanda", container=False, scale=7),
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
     submit_btn ="Chiedi all'AvvoChat ",
     retry_btn = "Rigenera",
     undo_btn = None,
     clear_btn = "Pulisci chat",
     fill_height = True,
     theme = "gstaff/sketch",
-    #title="Avvo-Chat",
-    #description="""Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.""",
     examples=[
         ["Posso fare una grigliata sul balcone di casa?"],
         ["Se esco di casa senza documento di identità posso essere multato?"],
@@ -151,4 +187,4 @@ with gr.Blocks(css="style.css") as demo:
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
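Note: demo.queue(max_size=20) keeps at most 20 requests waiting and rejects further ones while the queue is full; the queue is also what lets a generator handler like generate() stream partial responses to the browser.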
 