# Hugging Face Space status at capture time: Runtime error
import hashlib
import hmac
import secrets

import gradio as gr
import ipywidgets as widgets
import torch
from datasets import load_dataset
from IPython.display import display
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import SFTTrainer
# --- Model, tokenizer and LoRA fine-tuning ----------------------------------
# This all runs at import time: the model is loaded in 4-bit, fine-tuned with
# LoRA for `max_steps` steps, then wrapped in a text-generation pipeline.

# 4-bit NF4 quantisation with float16 compute.
# NOTE(review): bitsandbytes 4-bit loading generally requires a CUDA GPU;
# confirm the Space hardware supports it.
llama_model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="aboonaji/llama2finetune-v2",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
    ),
)
# Disable the KV cache while training; it is switched back on after training.
llama_model.config.use_cache = False
# Use the standard linear-layer path (no tensor-parallel slicing).
llama_model.config.pretraining_tp = 1

llama_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="aboonaji/llama2finetune-v2",
    trust_remote_code=True,
)
# Reuse EOS as the padding token (presumably the tokenizer defines none) and
# pad on the right for training.
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

training_arguments = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    max_steps=100,
)
# NOTE(review): newer TRL releases take the tokenizer as `processing_class`
# and `dataset_text_field` through `SFTConfig`; confirm these keyword
# arguments against the TRL version pinned for this Space.
llama_sft_trainer = SFTTrainer(
    model=llama_model,
    args=training_arguments,
    train_dataset=load_dataset(
        path="aboonaji/wiki_medical_terms_llam2_format", split="train"
    ),
    tokenizer=llama_tokenizer,
    peft_config=LoraConfig(
        task_type="CAUSAL_LM", r=16, lora_alpha=16, lora_dropout=0.1
    ),
    dataset_text_field="text",
)
llama_sft_trainer.train()

# Trainer.train() leaves the model in training mode, so LoRA dropout would
# stay active during inference; switch to eval mode for deterministic output.
llama_model.eval()
# Re-enable the KV cache so generation does not recompute past tokens.
llama_model.config.use_cache = True

# NOTE(review): `max_length` bounds prompt + completion tokens together;
# `max_new_tokens` may be the intended limit.
generator = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    max_length=500,
)
# Registered accounts, held only in this process's memory
# (username -> stored credential). Everything is lost when the app restarts.
user_db = dict()
# Response function
def generate_response(prompt):
    """Generate the model's answer to a single user prompt.

    The prompt is wrapped in the ``<s>[INST] ... [/INST]`` instruction format
    and passed to the module-level ``generator`` pipeline.

    Args:
        prompt: Text typed by the user in the chat box (may be empty or None).

    Returns:
        Only the newly generated text (the prompt is not echoed back), or a
        short hint when the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        return "Please enter a question."
    # By default the text-generation pipeline returns prompt + completion;
    # return_full_text=False keeps only the completion so the template
    # markers and the user's own question are not shown as the "answer".
    outputs = generator(
        f"<s>[INST] {prompt.strip()} [/INST]", return_full_text=False
    )
    return outputs[0]["generated_text"].strip()
# Sign-up and login logic
# Passwords are never stored in clear text: each user_db value is the
# (salt, digest) pair returned by hash_password().
_PBKDF2_ITERATIONS = 100_000


def hash_password(password, salt=None):
    """Hash *password* with salted PBKDF2-HMAC-SHA256.

    Args:
        password: Plain-text password.
        salt: Salt bytes to reuse; a fresh random 16-byte salt is generated
            when None.

    Returns:
        A ``(salt, digest)`` tuple of bytes.
    """
    if salt is None:
        salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt, _PBKDF2_ITERATIONS
    )
    return salt, digest


def verify_password(password, stored):
    """Return True if *password* matches a ``(salt, digest)`` pair from hash_password()."""
    salt, expected = stored
    _, actual = hash_password(password, salt)
    # Constant-time comparison so timing does not reveal how many bytes matched.
    return hmac.compare_digest(actual, expected)


def signup_user(new_username, new_password):
    """Register a new account in ``user_db``.

    Leading/trailing whitespace is stripped from the username; a blank
    username or empty password is rejected.

    Returns:
        A status message for the sign-up panel.
    """
    username = (new_username or "").strip()
    if not username or not new_password:
        return "❌ No input. Please provide both username and password."
    if username in user_db:
        return "❌ Username already exists."
    user_db[username] = hash_password(new_password)
    return "✅ Account created! Please log in."


def login_user(username, password):
    """Check credentials and return UI updates for the login flow.

    Returns:
        A 4-tuple of updates for ``[chat_ui, login_ui, login_msg, tabs]``.
        On success the chat panel is shown, the login form hidden, the message
        cleared and the Chat tab (id 3) selected. Otherwise the login form
        stays visible with an error on the Login tab (id 2).
    """
    stored = user_db.get((username or "").strip())
    if stored is not None and password and verify_password(password, stored):
        return (
            gr.update(visible=True),   # Show chat UI
            gr.update(visible=False),  # Hide login UI
            "",                        # Clear login message
            gr.update(selected=3),     # Switch to Chat tab
        )
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        "❌ Invalid credentials",
        gr.update(selected=2),
    )
def logout_user():
    """Log the user out.

    Hides the chat panel, shows the login form again and jumps back to the
    landing tab (id 0). The 3-tuple matches ``outputs=[chat_ui, login_ui, tabs]``.
    """
    hide_chat = gr.update(visible=False)
    show_login = gr.update(visible=True)
    go_home = gr.update(selected=0)
    return hide_chat, show_login, go_home
# --- Gradio user interface --------------------------------------------------
# Each tab gets an explicit id equal to its position (0 = Home, 1 = Sign Up,
# 2 = Login, 3 = Chat) so that gr.update(selected=<id>) in the callbacks
# always targets the intended tab.
with gr.Blocks(theme="soft", css="""
#create-btn button,
#login-btn button,
#submit-btn button,
#logout-btn button {
    font-size: 12px !important;
    padding: 4px 8px !important;
    height: 30px !important;
    width: auto !important;
    min-width: 80px !important;
}
""") as demo:
    with gr.Tabs(selected=0, elem_id="tabs") as tabs:
        with gr.Tab("Home", id=0):
            with gr.Column(elem_id="landing-container") as landing_ui:
                gr.Markdown("# 👋 Welcome to MEDChat AI")
                gr.Markdown("---")
                gr.Markdown("#### 🧠 Features:")
                gr.Markdown("- Medical Q&A support\n- Easy-to-use chatbot interface\n- Privacy-focused with local data")
                gr.Markdown("---")
                gr.Markdown("© 2025 MEDChat AI | Finetuned by LLaMA 2 | Created for educational purposes")
        with gr.Tab("Sign Up", id=1):
            with gr.Column() as signup_ui:
                gr.Markdown("### 📝 Sign Up")
                new_username = gr.Textbox(label="New Username")
                new_password = gr.Textbox(label="New Password", type="password")
                create_account_btn = gr.Button("Create Account", elem_id="create-btn")
                signup_msg = gr.Markdown()
        with gr.Tab("Login", id=2):
            with gr.Column(visible=True) as login_ui:
                gr.Markdown("### 🔐 Login")
                username = gr.Textbox(label="Username")
                password = gr.Textbox(label="Password", type="password")
                login_btn = gr.Button("Login", elem_id="login-btn")
                login_msg = gr.Markdown()
        with gr.Tab("Chat", id=3):
            # Hidden until login_user() reveals it.
            with gr.Column(visible=False) as chat_ui:
                gr.Markdown("## 💬 MEDChat AI")
                gr.Markdown("What can I help with today?")
                logout_btn = gr.Button("🚪 Logout", elem_id="logout-btn")
                prompt = gr.Textbox(lines=5, placeholder="Enter your prompt...")
                submit_btn = gr.Button("Submit", elem_id="submit-btn")
                response = gr.Textbox(label="Response")
                gr.Markdown("### Try one of these:")
                examples = gr.Examples(
                    examples=[
                        ["What does the immune system do?"],
                        ["What is Epistaxis?"],
                        ["Do our intestines contain germs?"],
                        ["What are allergies?"],
                        ["Should I start taking creatine?"],
                        ["What are antibiotics?"],
                        ["Why do I get sick?"],
                        ["What's the difference between bacteria and viruses?"],
                        ["Where are some places that germs hide?"],
                    ],
                    inputs=prompt,
                )

    logout_btn.click(fn=logout_user, outputs=[chat_ui, login_ui, tabs])

    # Tab navigation logic. These buttons are invisible and nothing in this
    # UI code triggers them.
    to_signup = gr.Button(visible=False)
    to_login = gr.Button(visible=False)
    to_signup.click(lambda: gr.update(selected=1), outputs=tabs)
    to_login.click(lambda: gr.update(selected=2), outputs=tabs)

    create_account_btn.click(
        fn=signup_user, inputs=[new_username, new_password], outputs=signup_msg
    )
    login_btn.click(
        fn=login_user,
        inputs=[username, password],
        outputs=[chat_ui, login_ui, login_msg, tabs],
    )
    # NOTE(review): hiding chat_ui only affects the page. This event handler
    # is still reachable through the app's API without logging in.
    submit_btn.click(fn=generate_response, inputs=prompt, outputs=response)

# Launch the interface. share=True requests a public gradio.live link; it is
# not supported on Hugging Face Spaces (Gradio warns and ignores it there).
demo.launch(share=True)