aauu1234 committed
Commit 1c22ce6
1 Parent(s): cf1021e
Files changed (2)
  1. app.py +19 -17
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import traceback
 
 model_name_or_path = "stephenlzc/dolphin-llama3-zh-cn-uncensored"
@@ -10,25 +10,24 @@ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
+# Configure quantization
+quantization_config = BitsAndBytesConfig(
+    load_in_8bit=True,
+    llm_int8_threshold=6.0,
+    llm_int8_has_fp16_weight=False,
+)
+
+# Load the model with quantization
 model = AutoModelForCausalLM.from_pretrained(
-    model_name_or_path,
-    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+    model_name_or_path,
+    quantization_config=quantization_config,
+    device_map="auto",
     trust_remote_code=True
-).to(device)
+)
 
 print("Tokenizer loaded successfully")
 print("Model loaded successfully")
 
-# Test inference
-test_messages = [
-    {"role": "system", "content": "You are a helpful assistant."},
-    {"role": "user", "content": "Hello, who are you?"},
-]
-test_input_ids = tokenizer.apply_chat_template(conversation=test_messages, tokenize=True, return_tensors="pt").to(device)
-test_output = model.generate(inputs=test_input_ids, max_new_tokens=50)
-test_response = tokenizer.decode(test_output[0])
-print("Test response:", test_response)
-
 def generate_response(system_message, user_message):
     try:
         messages = [
@@ -37,13 +36,16 @@ def generate_response(system_message, user_message):
         ]
 
         input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, return_tensors="pt").to(device)
+        attention_mask = torch.ones_like(input_ids).to(device)
 
         output = model.generate(
             inputs=input_ids,
-            max_new_tokens=512
+            attention_mask=attention_mask,
+            max_new_tokens=512,
+            pad_token_id=tokenizer.eos_token_id
         )
 
-        generated_response = tokenizer.decode(output[0])
+        generated_response = tokenizer.decode(output[0], skip_special_tokens=True)
         return generated_response
     except Exception as e:
         error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -56,7 +58,7 @@ iface = gr.Interface(
         gr.Textbox(label="User Message")
     ],
     outputs=gr.Textbox(label="Generated Response"),
-    title="llama3 cn uncensored Chatbot (GPU-enabled)"
+    title="llama3 cn uncensored Chatbot (GPU-enabled, 8-bit quantized)"
 )
 
 if __name__ == "__main__":
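For a sense of what the new loading path buys, here is a minimal standalone sketch (not part of the commit) that loads the model with the same 8-bit config as the diff above and prints its memory footprint. get_memory_footprint() is a standard transformers helper; the "roughly half of float16" figure is the expected ratio for 8-bit weights, not a measured result.

# Minimal sketch (not part of the commit): load with the commit's 8-bit
# config and report the resulting memory footprint.
# Assumes a CUDA GPU and the bitsandbytes package from requirements.txt.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)
model = AutoModelForCausalLM.from_pretrained(
    "stephenlzc/dolphin-llama3-zh-cn-uncensored",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)
# get_memory_footprint() returns bytes; 8-bit weights should come in at
# roughly half the size of the same model in float16.
print(f"Model footprint: {model.get_memory_footprint() / 1e9:.2f} GB")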
requirements.txt CHANGED
@@ -3,4 +3,5 @@ huggingface_hub
 torch
 transformers
 accelerate
-sentencepiece
+sentencepiece
+bitsandbytes
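One caveat this commit does not handle: bitsandbytes' 8-bit kernels have traditionally required a CUDA GPU, so load_in_8bit=True would fail on a CPU-only Space. A hedged sketch of a fallback (not in this commit) that keeps the quantized path on GPU and reverts to the pre-commit float32 path on CPU:

# Hypothetical fallback (not in this commit): quantize only when CUDA is
# available; otherwise load in float32 as the old code did.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

if torch.cuda.is_available():
    load_kwargs = {
        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
        "device_map": "auto",
    }
else:
    load_kwargs = {"torch_dtype": torch.float32}

model = AutoModelForCausalLM.from_pretrained(
    "stephenlzc/dolphin-llama3-zh-cn-uncensored",
    trust_remote_code=True,
    **load_kwargs,
)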