Spaces:
Runtime error
Runtime error
| import os | |
| from threading import Thread | |
| from typing import Iterator | |
| import weaviate | |
| from haystack.components.builders import PromptBuilder | |
| from sentence_transformers import SentenceTransformer | |
| from haystack import Pipeline | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| MAX_MAX_NEW_TOKENS = 2048 | |
| DEFAULT_MAX_NEW_TOKENS = 1024 | |
| MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) | |
| if not torch.cuda.is_available(): | |
| DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>" | |
| if torch.cuda.is_available(): | |
| model_id = "AndreaAlessandrelli4/AvvoChat_AITA_v04" | |
| model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| tokenizer.use_default_system_prompt = False | |
| model1 = SentenceTransformer('intfloat/multilingual-e5-large') | |
| key='rJ2yBbVQedQvaSH3TABtf9KcuQsnLNRPXguq' | |
| url = "https://mmchpi0yssanukk5t3ofta.c0.europe-west3.gcp.weaviate.cloud" | |
| client = weaviate.Client( | |
| url = url, | |
| auth_client_secret=weaviate.auth.AuthApiKey(api_key=key), | |
| ) | |
| def prompt_template(materiali, query): | |
| mat = '' | |
| for i, doc in enumerate(materiali): | |
| mat += f'DOCUMENTO {i+1}: {doc['content']};\n' | |
| prompt_template = f""" | |
| Basandoti sulle tue conoscenze e usando le informazioni che ti fornisco di seguito. | |
| CONTESTO: | |
| {mat} | |
| Rispondi alla seguente domanda in modo esaustivo e conciso in massimo 100 parole, evitando inutili giri di parole o ripetizioni, . | |
| {query} | |
| """ | |
| return prompt_template | |
| def richiamo_materiali(query, vett_query, alpha=1.0, N_items=5): | |
| try: | |
| materiali = client.query.get("Default", ["content"]).with_hybrid( | |
| query=text_query, | |
| vector=vett_query, | |
| alpha=alpha, | |
| fusion_type=HybridFusion.RELATIVE_SCORE, | |
| ).with_additional(["score"]).with_limit(N_items).do() | |
| mat = [{'score':i['_additional']['score'],'content':i['content']} for i in materiali['data']['Get']['Default']] | |
| except: | |
| mat =[{'score':0, 'content':'NESSUN MATERIALE FORNITO'}] | |
| return mat | |
| def generate( | |
| message: str, | |
| chat_history: list[tuple[str, str]], | |
| system_prompt: str, | |
| max_new_tokens: int = 1024, | |
| temperature: float = 0.6, | |
| top_p: float = 0.9, | |
| top_k: int = 50, | |
| repetition_penalty: float = 1.2, | |
| ) -> Iterator[str]: | |
| embeddings_query = model1.encode('query: '+message, normalize_embeddings=True) | |
| vettor_query = embeddings_query | |
| materiali = richiamo_materiali(message, vettor_query) | |
| prompt_finale = prompt_template(materiali, message) | |
| conversation = [] | |
| conversation.append({"role": "system", "content": | |
| '''Sei un an assistente AI di nome 'AvvoChat' specializzato nel rispondere a domande riguardanti la legge Italiana. | |
| Rispondi in lingua italiana in modo chiaro, semplice ed esaustivo alle domande che ti vengono fornite. | |
| Le risposte devono essere sintetiche e chiare di massimo 100 parole o anche più corte. | |
| Firmati alla fine di ogni risposta '-AvvoChat'.'''}) | |
| for user, assistant in chat_history: | |
| conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) | |
| conversation.append({"role": "user", "content": prompt_finale}) | |
| input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") | |
| if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: | |
| input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] | |
| gr.Warning(f"Chat troppo lunga superati {MAX_INPUT_TOKEN_LENGTH} tokens.") | |
| input_ids = input_ids.to(model.device) | |
| streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) | |
| generate_kwargs = dict( | |
| {"input_ids": input_ids}, | |
| streamer=streamer, | |
| max_new_tokens=max_new_tokens, | |
| do_sample=True, | |
| top_p=top_p, | |
| top_k=top_k, | |
| temperature=temperature, | |
| num_beams=1, | |
| repetition_penalty=repetition_penalty, | |
| ) | |
| t = Thread(target=model.generate, kwargs=generate_kwargs) | |
| t.start() | |
| outputs = [] | |
| for text in streamer: | |
| outputs.append(text) | |
| yield "".join(outputs) | |
| chat_interface = gr.ChatInterface( | |
| fn=generate, | |
| chatbot=gr.Chatbot(height=400, label = "AvvoChat", show_copy_button=True, avatar_images=("users.jpg","AvvoVhat.png"), layout="bubble",show_share_button=True), | |
| textbox=gr.Textbox(placeholder="Inserisci la tua domanda", container=False, scale=7), | |
| submit_btn ="Chiedi all'AvvoChat ", | |
| retry_btn = "Rigenera", | |
| undo_btn = None, | |
| clear_btn = "Pulisci chat", | |
| fill_height = True, | |
| theme = "gstaff/sketch", | |
| #title="Avvo-Chat", | |
| #description="""Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.""", | |
| additional_inputs=[ | |
| gr.Textbox(label="System prompt", lines=6), | |
| gr.Slider( | |
| label="Max new tokens", | |
| minimum=1, | |
| maximum=MAX_MAX_NEW_TOKENS, | |
| step=1, | |
| value=DEFAULT_MAX_NEW_TOKENS, | |
| ), | |
| gr.Slider( | |
| label="Temperature", | |
| minimum=0.1, | |
| maximum=4.0, | |
| step=0.1, | |
| value=0.6, | |
| ), | |
| gr.Slider( | |
| label="Top-p (nucleus sampling)", | |
| minimum=0.05, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.9, | |
| ), | |
| gr.Slider( | |
| label="Top-k", | |
| minimum=1, | |
| maximum=1000, | |
| step=1, | |
| value=50, | |
| ), | |
| gr.Slider( | |
| label="Repetition penalty", | |
| minimum=1.0, | |
| maximum=2.0, | |
| step=0.05, | |
| value=1.2, | |
| ), | |
| ], | |
| stop_btn=None, | |
| examples=[ | |
| ["Posso fare una grigliata sul balcone di casa?"], | |
| ["Se esco di casa senza documento di identità posso essere multato?"], | |
| ["Le persone single possono adottare un bambino?"], | |
| ["Posso usare un'immagine prodotto dall'intelligenza artificiale?"], | |
| ], | |
| ) | |
| with gr.Blocks(css="style.css") as demo: | |
| gr.Markdown("# AvvoChat") | |
| gr.Markdown("Fai una domanda riguardante la legge italiana all'AvvoChat e ricevi una spiegazione semplice al tuo dubbio.") | |
| with gr.Row(): | |
| with gr.Column(scale=0.5, min_width = 100): | |
| gr.Image("AvvoVhat.png", width = 50, height=200, | |
| show_label=False, show_share_button=False, show_download_button=False, container=False), | |
| with gr.Column(scale=6): | |
| chat_interface.render() | |
| if __name__ == "__main__": | |
| demo.queue(max_size=20).launch() |