File size: 6,894 Bytes
4c90268
 
 
99a86e4
825db5a
 
 
 
 
a145ef1
4c90268
 
 
 
 
 
 
 
3a5ce72
 
4c90268
 
 
3a5ce72
4c90268
 
 
 
 
 
825db5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a5ce72
4c90268
 
 
 
3a5ce72
4c90268
0219fdf
4c90268
 
 
 
825db5a
 
 
 
 
4c90268
98aeb40
825db5a
98aeb40
825db5a
98aeb40
4c90268
 
825db5a
4c90268
 
 
 
0219fdf
4c90268
 
 
 
 
 
 
0219fdf
a145ef1
4c90268
0219fdf
4c90268
 
 
 
 
 
 
 
 
 
a145ef1
cb55591
2ecdaf0
3a5ce72
4e85244
2d3694f
4e85244
 
 
 
 
34b6d12
2d3694f
 
3a5ce72
0219fdf
66bc4cd
 
 
 
 
 
 
0219fdf
 
 
 
 
 
f4d2a34
66bc4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe57e60
3a5ce72
 
c3b55a3
 
 
 
3a5ce72
2ecdaf0
66bc4cd
3a5ce72
2ecdaf0
 
272b207
 
97c3be5
 
272b207
 
2ecdaf0
66bc4cd
cb55591
0219fdf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
from threading import Thread
from typing import Iterator

import weaviate
from haystack.components.builders import PromptBuilder
from sentence_transformers import SentenceTransformer
from haystack import Pipeline

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))



if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    model_id = "AndreaAlessandrelli4/AvvoChat_AITA_v04"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False

    model1 = SentenceTransformer('intfloat/multilingual-e5-large')

    key='rJ2yBbVQedQvaSH3TABtf9KcuQsnLNRPXguq'
    url = "https://mmchpi0yssanukk5t3ofta.c0.europe-west3.gcp.weaviate.cloud"
    client = weaviate.Client(
        url = url,  
        auth_client_secret=weaviate.auth.AuthApiKey(api_key=key),
    )



def prompt_template(materiali, query):
    mat = ''
    for i, doc in enumerate(materiali):
        mat += f'DOCUMENTO {i+1}: {doc['content']};\n'
    prompt_template = f"""
    Basandoti sulle tue conoscenze e usando le informazioni che ti fornisco di seguito.
    CONTESTO:
      {mat}
      
    Rispondi alla seguente domanda in modo esaustivo e conciso in massimo 100 parole, evitando inutili giri di parole o ripetizioni, .
      {query}
    """
    return prompt_template



def richiamo_materiali(query, vett_query, alpha=1.0, N_items=5):
    try:
        materiali = client.query.get("Default", ["content"]).with_hybrid(
            query=text_query,
            vector=vett_query,
            alpha=alpha,
            fusion_type=HybridFusion.RELATIVE_SCORE,
        ).with_additional(["score"]).with_limit(N_items).do()
        
        mat = [{'score':i['_additional']['score'],'content':i['content']} for i in materiali['data']['Get']['Default']]
    except:
        mat =[{'score':0, 'content':'NESSUN MATERIALE FORNITO'}]
    
    return mat



@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:

    embeddings_query = model1.encode('query: '+message, normalize_embeddings=True)
    vettor_query = embeddings_query
    materiali = richiamo_materiali(message, vettor_query)
    prompt_finale = prompt_template(materiali, message)
    conversation = []
    conversation.append({"role": "system", "content": 
                         '''Sei un an assistente AI di nome 'AvvoChat' specializzato nel rispondere a domande riguardanti la legge Italiana. 
                         Rispondi in lingua italiana in modo chiaro, semplice ed esaustivo alle domande che ti vengono fornite.
                         Le risposte devono essere sintetiche e chiare di massimo 100 parole o anche più corte. 
                         Firmati alla fine di ogni risposta '-AvvoChat'.'''})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": prompt_finale})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Chat troppo lunga superati {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height=400, label = "AvvoChat", show_copy_button=True, avatar_images=("users.jpg","AvvoVhat.png"), layout="bubble",show_share_button=True),
    textbox=gr.Textbox(placeholder="Inserisci la tua domanda", container=False, scale=7),
    submit_btn ="Chiedi all'AvvoChat ",
    retry_btn = "Rigenera",
    undo_btn = None,
    clear_btn = "Pulisci chat",
    fill_height = True,
    theme = "gstaff/sketch",
    #title="Avvo-Chat",
    #description="""Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.""",
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Posso fare una grigliata sul balcone di casa?"],
        ["Se esco di casa senza documento di identità posso essere multato?"],
        ["Le persone single possono adottare un bambino?"],
        ["Posso usare un'immagine prodotto dall'intelligenza artificiale?"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# AvvoChat")
    gr.Markdown("Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.")
    with gr.Row():
        with gr.Column(scale=0.5, min_width = 100):
            gr.Image("AvvoVhat.png", width = 50, height=200, 
                     show_label=False, show_share_button=False, show_download_button=False, container=False),
        with gr.Column(scale=6):
           chat_interface.render()


if __name__ == "__main__":
    demo.queue(max_size=20).launch()