batoon committed
Commit c2f0918 · 1 Parent(s): 3e31dfa

Update model.py

Files changed (1):
  model.py (+24 −13)
model.py CHANGED
```diff
@@ -3,31 +3,42 @@ from typing import Iterator
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
-model_id = 'meta-llama/Llama-2-7b-chat-hf'
+model_id = "TheBloke/Chronos-Beluga-v2-13B-GPTQ"
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+model = AutoGPTQForCausalLM.from_quantized(model_id,
+                                           use_safetensors=True,
+                                           trust_remote_code=False,
+                                           device="cuda:0",
+                                           use_triton=False,
+                                           quantize_config=None)
 
-if torch.cuda.is_available():
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map='auto'
-    )
-else:
-    model = None
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+# model_id = 'meta-llama/Llama-2-7b-chat-hf'
+
+# if torch.cuda.is_available():
+#     model = AutoModelForCausalLM.from_pretrained(
+#         model_id,
+#         torch_dtype=torch.float16,
+#         device_map='auto'
+#     )
+# else:
+#     model = None
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 def get_prompt(message: str, chat_history: list[tuple[str, str]],
                system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    # texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+    texts = [f'{system_prompt}\n\n']
     # The first user input is _not_ stripped
     do_strip = False
     for user_input, response in chat_history:
         user_input = user_input.strip() if do_strip else user_input
         do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+        texts.append(f'{user_input} {response.strip()} ')
     message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
+    texts.append(f'{message}')
    return ''.join(texts)
 
 
```
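For context: the module still imports TextIteratorStreamer, which suggests the loaded model is driven by a token-streaming generation loop elsewhere in the app. Below is a minimal sketch of what such a driver might look like against the objects this commit defines (`model`, `tokenizer`, `get_prompt`). The `run` helper and all sampling parameters (`max_new_tokens`, `top_p`, `temperature`, the timeout) are illustrative assumptions, not part of this commit.

```python
from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer

# Hypothetical driver for the updated model.py. `model`, `tokenizer`,
# and `get_prompt` are the objects defined above; the sampling
# settings are placeholder assumptions.
def run(message: str, chat_history: list[tuple[str, str]],
        system_prompt: str) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt').to('cuda:0')

    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.0,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(inputs,
                           streamer=streamer,
                           max_new_tokens=1024,
                           do_sample=True,
                           top_p=0.95,
                           temperature=0.8)
    # AutoGPTQ models expose a transformers-style generate(); run it on
    # a worker thread and yield the partial text as tokens stream in.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
```

Note also what the new `get_prompt` produces: with the Llama-2 `[INST]`/`<<SYS>>` markup commented out, a call with system prompt `'S'`, history `[('u1', 'r1')]`, and message `'m'` now returns the plain concatenation `'S\n\nu1 r1 m'`.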