File size: 6,921 Bytes
4c90268
 
 
6dc02f1
a145ef1
4c90268
 
5aeefb9
 
4c90268
 
 
 
 
 
3a5ce72
 
4c90268
 
 
3a5ce72
4c90268
 
8ca69fb
 
4c90268
 
 
825db5a
68186c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
825db5a
68186c8
 
 
 
825db5a
 
 
68186c8
 
 
 
 
 
 
 
 
 
 
 
 
 
825db5a
 
3a5ce72
6dc02f1
4c90268
 
 
 
6dc02f1
4c90268
6dc02f1
c422365
 
 
4c90268
68186c8
 
 
 
 
4c90268
98aeb40
d3fcfd1
 
98aeb40
d3fcfd1
98aeb40
4c90268
 
68186c8
4c90268
 
 
 
d3fcfd1
4c90268
 
6dc02f1
4c90268
 
 
 
6dc02f1
c422365
 
 
 
 
4c90268
 
 
 
 
 
 
 
a145ef1
cb55591
2ecdaf0
3a5ce72
8a3ac07
2d3694f
6dc02f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e85244
 
 
 
 
34b6d12
3a5ce72
c3b55a3
 
 
 
3a5ce72
2ecdaf0
66bc4cd
3a5ce72
2ecdaf0
 
272b207
c422365
97c3be5
 
272b207
 
2ecdaf0
66bc4cd
cb55591
6dc02f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
import weaviate
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))



if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    model_id = "AndreaAlessandrelli4/AvvoChat_AITA_v04"
    commit_id = "1e6356e06212d32dab4244c0a75eaa1eef73ffc6"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True, revision=commit_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


    #key='4vNfIDO8PmFwCloxA40y2b4PSHm62vmcuPoM'
    #url = "https://mmchpi0yssanukk5t3ofta.c0.europe-west3.gcp.weaviate.cloud"
    #client = weaviate.Client(
    #    url = url,  
    #    auth_client_secret=weaviate.auth.AuthApiKey(api_key=key),
    #)


#def prompt_template(materiali, query):
#    mat = ''
#    for i, doc in enumerate(materiali):
#        mat += f"""DOCUMENTO {i+1}: {doc["contenuto"]};\n"""
#    prompt_template = f"""
#    Basandoti sulle tue conoscenze e usando le informazioni che ti fornisco di seguito.
#    CONTESTO:
#      {mat}
      
#    Rispondi alla seguente domanda in modo esaustivo e conciso in massimo 100 parole, evitando inutili giri di parole o ripetizioni, .
#      {query}
#    """
#    return prompt_template



#def richiamo_materiali(query, vett_query, alpha=1.0, N_items=5):
#    try:
#        materiali = client.query.get("Default", ["content"]).with_hybrid(
#            query=text_query,
#            vector=vett_query,
#            alpha=alpha,
#            fusion_type=HybridFusion.RELATIVE_SCORE,
#        ).with_additional(["score"]).with_limit(N_items).do()
#        
#        mat = [{"score":i["_additional"]["score"],'contenuto':i["content"]} for i in materiali["data"]["Get"]["Default"]]
#    except:
#        mat =[{"score":0, "contenuto":'NESSUN MATERIALE FORNITO'}]
#    
#    return mat




@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    #model1 = SentenceTransformer('intfloat/multilingual-e5-large')
    #embeddings_query = model1.encode('query: '+str(message), normalize_embeddings=True)
    #vettor_query = embeddings_query
    #materiali = richiamo_materiali(message, vettor_query)
    #prompt_finale = prompt_template(materiali, message)
    conversation = []
    conversation.append({"role": "system", "content": 
                         '''Sei un an assistente AI di nome 'AvvoChat' specializzato nel rispondere la legge Italiana. 
                         Se le domande non riguardano questioni legali astieniti dal rispondere e scrivi "Sono specializzato in domande di tipo legale: non sono accurato su questo tipo di domande".
                         Rispondi in lingua italiana in modo chiaro, semplice ed esaustivo alle domande che ti vengono fornite.
                         Le risposte devono essere chiare e semplici con argomentazioni valide e puntuali. 
                         Firmati alla fine di ogni risposta '-AvvoChat'.'''})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Conversazione troppo lunga: sforati i {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height=400, label = "AvvoChat", show_copy_button=True, avatar_images=("users.jpg","AvvoVhat.png"), show_share_button=True),
    textbox=gr.Textbox(placeholder="Inserisci la tua domanda", container=False, scale=7),
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    submit_btn ="Chiedi all'AvvoChat ",
    retry_btn = "Rigenera",
    undo_btn = None,
    clear_btn = "Pulisci chat",
    fill_height = True,
    theme = "gstaff/sketch",
    examples=[
        ["Posso fare una grigliata sul balcone di casa?"],
        ["Se esco di casa senza documento di identità posso essere multato?"],
        ["Le persone single possono adottare un bambino?"],
        ["Posso usare un'immagine prodotto dall'intelligenza artificiale?"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# AvvoChat")
    gr.Markdown("Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.")
    with gr.Row():
        with gr.Column(scale=1, min_width = 100):
            gr.Image("AvvoVhat.png", width = 50, height=200, 
                     show_label=False, show_share_button=False, show_download_button=False, container=False),
        with gr.Column(scale=6):
           chat_interface.render()


if __name__ == "__main__":
    demo.queue(max_size=20).launch()