File size: 6,634 Bytes
aa591fb
 
 
 
 
 
 
 
 
 
 
 
 
0ed7157
aa591fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
!pip uninstall accelerate peft bitsandbytes transformers trl -y
!pip install accelerate peft==0.13.2 bitsandbytes transformers trl==0.12.0
!pip install huggingface_hub
!pip install ipywidgets
!pip install -q gradio

import hashlib
import hmac
import secrets

import torch
from trl import SFTTrainer
from peft import LoraConfig
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline)
import ipywidgets as widgets
from IPython.display import display
import gradio as gr

# --- Model, tokenizer, fine-tuning and inference pipeline ------------------
# Load the checkpoint with 4-bit NF4 weight quantization (bitsandbytes);
# the quantized layers compute in float16.
llama_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path = "aboonaji/llama2finetune-v2",
                                                   quantization_config = BitsAndBytesConfig(load_in_4bit = True, bnb_4bit_compute_dtype = getattr(torch, "float16"), bnb_4bit_quant_type = "nf4"))
# Disable the KV cache for training.
# NOTE(review): use_cache is never set back to True before `generator` is
# built below, so generation may run without the KV cache (slower) —
# consider re-enabling it after train().
llama_model.config.use_cache = False
# Llama config: a value of 1 selects the standard (non-sliced) linear layers.
llama_model.config.pretraining_tp = 1

# NOTE(review): trust_remote_code=True permits executing Python code shipped
# in the hub repo; confirm it is actually required for this tokenizer.
llama_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "aboonaji/llama2finetune-v2", trust_remote_code = True)
# Reuse EOS as the padding token (presumably the tokenizer defines no pad
# token — confirm) and pad sequences on the right.
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

# Short training run: 100 steps, batch size 1; trainer output goes under ./results.
training_arguments = TrainingArguments(output_dir = "./results", per_device_train_batch_size = 1, max_steps = 100)
# Supervised fine-tuning with a LoRA adapter (rank 16, alpha 16, dropout 0.1)
# on the "text" column of the dataset's train split.
# NOTE(review): for the pinned trl==0.12.0, `tokenizer` is believed to be
# deprecated in favour of `processing_class`, and `dataset_text_field`
# normally belongs in SFTConfig — verify against that release.
llama_sft_trainer = SFTTrainer(model = llama_model,
                               args = training_arguments,
                               train_dataset = load_dataset(path = "aboonaji/wiki_medical_terms_llam2_format", split = "train"),
                               tokenizer = llama_tokenizer,
                               peft_config = LoraConfig(task_type = "CAUSAL_LM", r = 16, lora_alpha = 16, lora_dropout = 0.1),
                               dataset_text_field = "text")

# Run the fine-tuning loop (blocks until max_steps is reached).
llama_sft_trainer.train()

# Text-generation pipeline used by generate_response().
# NOTE(review): this passes `llama_model`, not `llama_sft_trainer.model`;
# presumably the trainer injected the LoRA layers into it in place — confirm.
# max_length=500 bounds prompt tokens + generated tokens together.
generator = pipeline("text-generation", model=llama_model, tokenizer=llama_tokenizer, max_length=500)

# In-memory user database, keyed by username; values are whatever
# signup_user stores. It is module-level, so it is shared by every session
# served by this process and is lost when the process stops.
user_db = {}

# Response function
def generate_response(prompt):
    """Answer a user question with the fine-tuned model.

    The question is wrapped in the LLaMA 2 instruction template
    (``<s>[INST] ... [/INST]``) and passed to the module-level ``generator``
    text-generation pipeline.

    Args:
        prompt: The question typed into the chat textbox.

    Returns:
        Only the model's completion, with surrounding whitespace stripped.
        Blank or missing input returns ``""`` without calling the model.
    """
    # Skip the expensive model call when there is nothing to answer.
    if not prompt or not prompt.strip():
        return ""
    outputs = generator(
        f"<s>[INST] {prompt} [/INST]",
        # The text-generation pipeline echoes the prompt in "generated_text"
        # by default; return only the newly generated answer to the user.
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()

# Sign-up and login logic
#
# Passwords are never stored in plain text: each user_db value is an
# (iterations, salt, digest) PBKDF2-HMAC-SHA256 record.
_PBKDF2_ITERATIONS = 600_000  # OWASP Password Storage Cheat Sheet guidance for PBKDF2-SHA256


def _hash_password(password, salt=None, iterations=None):
    """Return an ``(iterations, salt, digest)`` record for *password*.

    A fresh 16-byte random salt is generated unless *salt* is given.
    """
    if iterations is None:
        iterations = _PBKDF2_ITERATIONS
    if salt is None:
        salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return iterations, salt, digest


def _verify_password(password, record):
    """Return True if *password* matches a record made by ``_hash_password``."""
    iterations, salt, expected = record
    _, _, candidate = _hash_password(password, salt, iterations)
    # Constant-time comparison so timing does not reveal matching prefixes.
    return hmac.compare_digest(candidate, expected)


def signup_user(new_username, new_password):
    """Register a new account in ``user_db``.

    Args:
        new_username: Requested username; must be non-empty and not taken.
        new_password: Requested password; must be non-empty. Only a salted
            PBKDF2 hash of it is stored.

    Returns:
        A status message for the sign-up Markdown component.
    """
    if not new_username or not new_password:
        return "❌ No input. Please provide both username and password."
    if new_username in user_db:
        return "❌ Username already exists."
    user_db[new_username] = _hash_password(new_password)
    return "✅ Account created! Please log in."


def login_user(username, password):
    """Check credentials and build the UI updates for the outcome.

    Returns:
        A 4-tuple of updates for (chat column, login column, login message,
        tabs). On success the chat column is shown, the login column hidden,
        the message cleared and the Chat tab (3) selected; otherwise the
        login column stays visible with an error and the Login tab (2) stays
        selected.
    """
    record = user_db.get(username)
    if record is not None and password and _verify_password(password, record):
        return (
            gr.update(visible=True),   # Show chat UI
            gr.update(visible=False),  # Hide login UI
            "",                         # Clear login message
            gr.update(selected=3)      # Switch to Chat tab
        )
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        "❌ Invalid credentials",
        gr.update(selected=2)
    )

def logout_user():
    """Produce the component updates that end a chat session.

    The three updates line up, in order, with the logout button's outputs:
    hide the chat column, show the login column again, and switch the tabs
    to selection 0 (the Home tab).
    """
    hide_chat = gr.update(visible=False)
    show_login = gr.update(visible=True)
    back_to_home = gr.update(selected=0)
    return hide_chat, show_login, back_to_home

# Build the UI. Tabs carry explicit ids 0-3 because login_user and
# logout_user switch tabs with gr.update(selected=<id>); Gradio matches
# `selected` against a Tab's `id`, so these ids must mirror those integers.
with gr.Blocks(theme="soft", css="""
#create-btn button,
#login-btn button,
#submit-btn button,
#logout-btn button {
    font-size: 12px !important;
    padding: 4px 8px !important;
    height: 30px !important;
    width: auto !important;
    min-width: 80px !important;
}
""") as demo:
    with gr.Tabs(selected=0, elem_id="tabs") as tabs:
        with gr.Tab("Home", id=0):
            with gr.Column(elem_id="landing-container") as landing_ui:
                gr.Markdown("# 🏠 Welcome to MEDChat AI")
                gr.Markdown("---")
                gr.Markdown("#### 🔧 Features:")
                gr.Markdown("- Medical Q&A support\n- Easy-to-use chatbot interface\n- Privacy-focused with local data")
                gr.Markdown("---")
                gr.Markdown("© 2025 MEDChat AI | Finetuned by LLaMA 2 | Created for educational purposes")

        with gr.Tab("Sign Up", id=1):
            with gr.Column() as signup_ui:
                gr.Markdown("### 📝 Sign Up")
                new_username = gr.Textbox(label="New Username")
                new_password = gr.Textbox(label="New Password", type="password")
                create_account_btn = gr.Button("Create Account", elem_id="create-btn")
                signup_msg = gr.Markdown()

        with gr.Tab("Login", id=2):
            with gr.Column(visible=True) as login_ui:
                gr.Markdown("### 🔐 Login")
                username = gr.Textbox(label="Username")
                password = gr.Textbox(label="Password", type="password")
                login_btn = gr.Button("Login", elem_id="login-btn")
                login_msg = gr.Markdown()

        with gr.Tab("Chat", id=3):
            # Hidden until a successful login makes it visible.
            with gr.Column(visible=False) as chat_ui:
                gr.Markdown("## 💬 MEDChat AI")
                gr.Markdown("What can I help with today?")
                logout_btn = gr.Button("🚪 Logout", elem_id="logout-btn")
                prompt = gr.Textbox(lines=5, placeholder="Enter your prompt...")
                submit_btn = gr.Button("Submit", elem_id="submit-btn")
                response = gr.Textbox(label="Response")

                gr.Markdown("### Try one of these:")
                gr.Examples(
                    examples=[
                        ["What does the immune system do?"],
                        ["What is Epistaxis?"],
                        ["Do our intestines contain germs?"],
                        ["What are allergies?"],
                        ["Should I start taking creatine?"],
                        ["What are antibiotics?"],
                        ["Why do I get sick?"],
                        ["What's the difference between bacteria and viruses?"],
                        ["Where are some places that germs hide?"],
                    ],
                    inputs=prompt
                )

    # --- Event wiring --------------------------------------------------------
    create_account_btn.click(fn=signup_user, inputs=[new_username, new_password], outputs=signup_msg)
    login_btn.click(
        fn=login_user,
        inputs=[username, password],
        outputs=[chat_ui, login_ui, login_msg, tabs]
    )
    logout_btn.click(fn=logout_user, outputs=[chat_ui, login_ui, tabs])
    # NOTE(review): hiding chat_ui only hides the button; Gradio typically also
    # exposes event handlers as API endpoints, so generate_response may be
    # callable without logging in — confirm for the Gradio version in use.
    submit_btn.click(fn=generate_response, inputs=prompt, outputs=response)

# Launch the interface. share=True also creates a public share link, so the
# sign-up/login forms and the model are reachable from outside this machine.
demo.launch(share=True)