brkznb commited on
Commit
6aa69da
·
verified ·
1 Parent(s): 6646800

Update app.py

Browse files

Fixed several issues in app.py

Files changed (1) hide show
  1. app.py +37 -38
app.py CHANGED
@@ -1,42 +1,48 @@
1
  import torch
2
- from trl import SFTTrainer
 
 
 
 
 
3
  from peft import LoraConfig
4
- from datasets import load_dataset
5
- from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline)
6
- import ipywidgets as widgets
7
- from IPython.display import display
8
  import gradio as gr
9
 
10
- llama_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path = "aboonaji/llama2finetune-v2",
11
- quantization_config = BitsAndBytesConfig(load_in_4bit = True, bnb_4bit_compute_dtype = getattr(torch, "float16"), bnb_4bit_quant_type = "nf4"))
 
 
 
 
 
 
 
 
12
  llama_model.config.use_cache = False
13
  llama_model.config.pretraining_tp = 1
14
 
15
- llama_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = "aboonaji/llama2finetune-v2", trust_remote_code = True)
 
 
 
16
  llama_tokenizer.pad_token = llama_tokenizer.eos_token
17
  llama_tokenizer.padding_side = "right"
18
 
19
- training_arguments = TrainingArguments(output_dir = "./results", per_device_train_batch_size = 1, max_steps = 100)
20
- llama_sft_trainer = SFTTrainer(model = llama_model,
21
- args = training_arguments,
22
- train_dataset = load_dataset(path = "aboonaji/wiki_medical_terms_llam2_format", split = "train"),
23
- tokenizer = llama_tokenizer,
24
- peft_config = LoraConfig(task_type = "CAUSAL_LM", r = 16, lora_alpha = 16, lora_dropout = 0.1),
25
- dataset_text_field = "text")
26
-
27
- llama_sft_trainer.train()
28
-
29
- generator = pipeline("text-generation", model=llama_model, tokenizer=llama_tokenizer, max_length=500)
30
 
31
  # In-memory user database
32
  user_db = {}
33
 
34
- # Response function
35
  def generate_response(prompt):
36
  response = generator(f"<s>[INST] {prompt} [/INST]")[0]["generated_text"]
37
  return response
38
 
39
- # Sign-up and login logic
40
  def signup_user(new_username, new_password):
41
  if not new_username or not new_password:
42
  return "❌ No input. Please provide both username and password."
@@ -48,10 +54,10 @@ def signup_user(new_username, new_password):
48
  def login_user(username, password):
49
  if username in user_db and user_db[username] == password:
50
  return (
51
- gr.update(visible=True), # Show chat UI
52
- gr.update(visible=False), # Hide login UI
53
- "", # Clear login message
54
- gr.update(selected=3) # Switch to Chat tab
55
  )
56
  return (
57
  gr.update(visible=False),
@@ -62,9 +68,9 @@ def login_user(username, password):
62
 
63
  def logout_user():
64
  return (
65
- gr.update(visible=False), # Hide chat UI
66
- gr.update(visible=True), # Show login UI
67
- gr.update(selected=0) # Switch to Landing tab
68
  )
69
 
70
  with gr.Blocks(theme="soft", css="""
@@ -130,23 +136,16 @@ with gr.Blocks(theme="soft", css="""
130
  inputs=prompt
131
  )
132
 
133
-
134
-
135
  logout_btn.click(fn=logout_user, outputs=[chat_ui, login_ui, tabs])
136
 
137
- # Tab navigation logic
138
  to_signup = gr.Button(visible=False)
139
  to_login = gr.Button(visible=False)
140
 
141
  to_signup.click(lambda: gr.update(selected=1), outputs=tabs)
142
  to_login.click(lambda: gr.update(selected=2), outputs=tabs)
143
  create_account_btn.click(fn=signup_user, inputs=[new_username, new_password], outputs=signup_msg)
144
- login_btn.click(
145
- fn=login_user,
146
- inputs=[username, password],
147
- outputs=[chat_ui, login_ui, login_msg, tabs]
148
- )
149
  submit_btn.click(fn=generate_response, inputs=prompt, outputs=response)
150
 
151
- # Launch the interface
152
- demo.launch(share=True)
 
1
  import torch
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ BitsAndBytesConfig,
6
+ pipeline
7
+ )
8
  from peft import LoraConfig
 
 
 
 
9
  import gradio as gr
10
 
11
+ # Load pre-trained model
12
+ llama_model = AutoModelForCausalLM.from_pretrained(
13
+ pretrained_model_name_or_path="aboonaji/llama2finetune-v2",
14
+ quantization_config=BitsAndBytesConfig(
15
+ load_in_4bit=True,
16
+ bnb_4bit_compute_dtype=getattr(torch, "float16"),
17
+ bnb_4bit_quant_type="nf4"
18
+ ),
19
+ device_map="auto"
20
+ )
21
  llama_model.config.use_cache = False
22
  llama_model.config.pretraining_tp = 1
23
 
24
# Tokenizer paired with the same fine-tuned checkpoint as the model.
llama_tokenizer = AutoTokenizer.from_pretrained(
    "aboonaji/llama2finetune-v2",
    trust_remote_code=True,
)
# LLaMA-2 ships without a pad token, so reuse EOS and pad on the right,
# which is what causal-LM generation expects.
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"
30
 
31
# Shared text-generation pipeline that the chat handlers call for inference.
generator = pipeline(
    task="text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    max_length=500,
)
 
 
 
 
38
 
39
# Plain in-memory credential store, {username: password}; lost on restart.
user_db = dict()
41
 
 
42
def generate_response(prompt):
    """Wrap *prompt* in the LLaMA-2 instruction template and return the generated text."""
    wrapped = f"<s>[INST] {prompt} [/INST]"
    outputs = generator(wrapped)
    return outputs[0]["generated_text"]
45
 
 
46
  def signup_user(new_username, new_password):
47
  if not new_username or not new_password:
48
  return "❌ No input. Please provide both username and password."
 
54
  def login_user(username, password):
55
  if username in user_db and user_db[username] == password:
56
  return (
57
+ gr.update(visible=True),
58
+ gr.update(visible=False),
59
+ "",
60
+ gr.update(selected=3)
61
  )
62
  return (
63
  gr.update(visible=False),
 
68
 
69
def logout_user():
    """Return the Gradio updates that restore the logged-out state.

    Outputs map positionally to (chat_ui, login_ui, tabs): hide the chat
    panel, reveal the login panel, and jump back to the landing tab.
    """
    hide_chat = gr.update(visible=False)
    show_login = gr.update(visible=True)
    goto_landing = gr.update(selected=0)
    return hide_chat, show_login, goto_landing
75
 
76
  with gr.Blocks(theme="soft", css="""
 
136
  inputs=prompt
137
  )
138
 
 
 
139
  logout_btn.click(fn=logout_user, outputs=[chat_ui, login_ui, tabs])
140
 
141
+ # Navigation and interaction logic
142
  to_signup = gr.Button(visible=False)
143
  to_login = gr.Button(visible=False)
144
 
145
  to_signup.click(lambda: gr.update(selected=1), outputs=tabs)
146
  to_login.click(lambda: gr.update(selected=2), outputs=tabs)
147
  create_account_btn.click(fn=signup_user, inputs=[new_username, new_password], outputs=signup_msg)
148
+ login_btn.click(fn=login_user, inputs=[username, password], outputs=[chat_ui, login_ui, login_msg, tabs])
 
 
 
 
149
  submit_btn.click(fn=generate_response, inputs=prompt, outputs=response)
150
 
151
+ demo.launch()