Defetya commited on
Commit
6319464
·
verified ·
1 Parent(s): 163897b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +26 -1
README.md CHANGED
@@ -4,4 +4,29 @@ tags:
4
  - Russian
5
  ---
6
  Qwen 4B chat by Alibaba, SFTuned on Saiga dataset. Finetuned with EasyDeL framework on v3-8 Google TPU, provided by TRC.
7
- Модель Qwen 4B, дообученая на датасете Ильи Гусева. По моему краткому опыту общения с моделью, лучше чем Saiga-mistral. Не ошибается в падежах. Карточка модели будет дополнена после теста на Russian SuperGlue. Возможно, будет DPO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  - Russian
5
  ---
6
  Qwen 4B chat by Alibaba, SFTuned on Saiga dataset. Finetuned with EasyDeL framework on v3-8 Google TPU, provided by TRC.
7
+ Модель Qwen 4B, дообученая на датасете Ильи Гусева. По моему краткому опыту общения с моделью, лучше чем Saiga-mistral. Не ошибается в падежах. Карточка модели будет дополнена после теста на Russian SuperGlue. Возможно, будет DPO
8
+
9
+ Чтобы использовать модель, необходимо назначить eos токен как <|im_end|>. Полный код:
10
+
11
+ import torch
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+ model = AutoModelForCausalLM.from_pretrained('Defetya/qwen-4B-saiga', torch_dtype=torch.bfloat16, device_map='auto')
14
+ tokenizer = AutoTokenizer.from_pretrained('Defetya/qwen-4B-saiga')
15
+ tokenizer.eos_token_id = 151645
16
+ messages_json = [
17
+ {"role": "system", "content": "Ты - русскоязычный ассистент. Ты помогаешь пользователю и отвечаешь на его вопросы."},
18
+ ]
19
+ while True:
20
+ user_input = str(input())
21
+ messages_json.append({'role': 'user', 'content': user_input})
22
+ messages = tokenizer.apply_chat_template(messages_json, return_tensors="pt", add_generation_prompt=True).to('cuda')
23
+ generated_ids = model.generate(messages, max_new_tokens=512, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
24
+ decoded = tokenizer.decode(generated_ids[0][len(messages[0]):])
25
+ print(decoded)
26
+ print("==============================")
27
+
28
+
29
+
30
+
31
+
32
+