522H0134-NguyenNhatHuy commited on
Commit
550d11b
·
verified ·
1 Parent(s): ca763c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -4,32 +4,29 @@ from peft import PeftModel
4
  import gradio as gr
5
 
6
  # 1. Cấu hình tên mô hình gốc (base model)
7
- base_model_name = "sail/Sailor-1.8B-Chat"
8
 
9
- # 2. Load tokenizer từ thư mục adapter
10
- adapter_path = "./Sailor-1.8B-Chat-SFT"
11
- tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
12
 
13
  # 3. Load base model và adapter
14
- model = AutoModelForCausalLM.from_pretrained(
15
  base_model_name,
16
  torch_dtype=torch.float16,
17
  device_map="auto",
18
  trust_remote_code=True
19
  )
20
- model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)
21
  model.eval()
22
 
23
- # 4. Hàm trò chuyện
24
  def chat_fn(message, history):
25
- # Biên dịch lịch sử hội thoại sang định dạng messages
26
  messages = []
27
  for user_msg, bot_msg in history:
28
  messages.append({"role": "user", "content": user_msg})
29
  messages.append({"role": "assistant", "content": bot_msg})
30
  messages.append({"role": "user", "content": message})
31
 
32
- # Áp dụng chat template chuẩn
33
  input_ids = tokenizer.apply_chat_template(
34
  messages,
35
  return_tensors="pt",
@@ -37,7 +34,6 @@ def chat_fn(message, history):
37
  truncation=True
38
  ).to(model.device)
39
 
40
- # Sinh phản hồi
41
  with torch.no_grad():
42
  outputs = model.generate(
43
  input_ids=input_ids,
@@ -51,9 +47,7 @@ def chat_fn(message, history):
51
  eos_token_id=tokenizer.eos_token_id
52
  )
53
 
54
- # Tách phần phản hồi
55
  generated_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
56
-
57
  return generated_text
58
 
59
  # 5. Giao diện Gradio
 
4
  import gradio as gr
5
 
6
  # 1. Cấu hình tên mô hình gốc (base model)
7
+ base_model_name = "sail/Sailor-1.8B-Chat"
8
 
9
+ # 2. Load tokenizer từ thư mục hiện tại
10
+ tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
 
11
 
12
  # 3. Load base model và adapter
13
+ base_model = AutoModelForCausalLM.from_pretrained(
14
  base_model_name,
15
  torch_dtype=torch.float16,
16
  device_map="auto",
17
  trust_remote_code=True
18
  )
19
+ model = PeftModel.from_pretrained(base_model, ".", torch_dtype=torch.float16)
20
  model.eval()
21
 
22
+ # 4. Hàm xử lý hội thoại
23
  def chat_fn(message, history):
 
24
  messages = []
25
  for user_msg, bot_msg in history:
26
  messages.append({"role": "user", "content": user_msg})
27
  messages.append({"role": "assistant", "content": bot_msg})
28
  messages.append({"role": "user", "content": message})
29
 
 
30
  input_ids = tokenizer.apply_chat_template(
31
  messages,
32
  return_tensors="pt",
 
34
  truncation=True
35
  ).to(model.device)
36
 
 
37
  with torch.no_grad():
38
  outputs = model.generate(
39
  input_ids=input_ids,
 
47
  eos_token_id=tokenizer.eos_token_id
48
  )
49
 
 
50
  generated_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
 
51
  return generated_text
52
 
53
  # 5. Giao diện Gradio