rphrp1985 committed · Commit 9b28aea · verified · 1 Parent(s): 40bb237

Update app.py

Files changed (1): app.py (+39, −18)
app.py CHANGED

--- a/app.py
+++ b/app.py
@@ -76,8 +76,18 @@ model = AutoModelForCausalLM.from_pretrained(model_id, token= token,
 )
 
 
+
+
+
 #
 model = accelerator.prepare(model)
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+)
 
 
 # device_map = infer_auto_device_map(model, max_memory={0: "79GB", "cpu":"65GB" })
@@ -111,24 +121,35 @@ def respond(
 
     messages= json_obj
 
-    input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device)
-    input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt") #.to('cuda')
-    print(f"Converted input_ids dtype: {input_ids.dtype}")
-    input_str= str(input_ids2)
-    print('input str = ', input_str)
-
-    with torch.no_grad():
-        gen_tokens = model.generate(
-            input_ids,
-            max_new_tokens=max_tokens,
-            # do_sample=True,
-            temperature=temperature,
-        )
-
-    gen_text = tokenizer.decode(gen_tokens[0])
-    print(gen_text)
-    gen_text= gen_text.replace(input_str,'')
-    gen_text= gen_text.replace('<|im_end|>','')
+    # input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(accelerator.device)
+    # input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt") #.to('cuda')
+    # print(f"Converted input_ids dtype: {input_ids.dtype}")
+    # input_str= str(input_ids2)
+    # print('input str = ', input_str)
+
+    generation_args = {
+        "max_new_tokens": max_tokens,
+        "return_full_text": False,
+        "temperature": temperature,
+        "do_sample": False,
+    }
+
+    output = pipe(messages, **generation_args)
+    print(output[0]['generated_text'])
+    gen_text=output[0]['generated_text']
+
+    # with torch.no_grad():
+    #     gen_tokens = model.generate(
+    #         input_ids,
+    #         max_new_tokens=max_tokens,
+    #         # do_sample=True,
+    #         temperature=temperature,
+    #     )
+
+    # gen_text = tokenizer.decode(gen_tokens[0])
+    # print(gen_text)
+    # gen_text= gen_text.replace(input_str,'')
+    # gen_text= gen_text.replace('<|im_end|>','')
 
     yield gen_text
 
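For readers following the change: the commit retires the manual apply_chat_template / model.generate path in respond() and routes generation through a transformers text-generation pipeline fed chat-style messages. Below is a minimal, self-contained sketch of that new path under stated assumptions; model_id, token, and the simplified respond() signature are placeholders for illustration, not the app's actual values.

```python
# Minimal sketch of the pipeline-based generation path this commit switches to.
# model_id, token, and the simplified respond() signature are placeholder
# assumptions for illustration; app.py defines its own values elsewhere.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "HuggingFaceH4/zephyr-7b-beta"   # placeholder chat model id
token = None                                # placeholder HF access token

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(model_id, token=token)

accelerator = Accelerator()
model = accelerator.prepare(model)          # as in the diff: prepare before wrapping

# The pipeline reuses the already-loaded model and tokenizer objects.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

def respond(json_obj, max_tokens=256, temperature=0.7):
    # json_obj is a chat-formatted message list: [{"role": "...", "content": "..."}, ...]
    messages = json_obj
    generation_args = {
        "max_new_tokens": max_tokens,
        "return_full_text": False,   # return only the newly generated assistant text
        "temperature": temperature,
        "do_sample": False,          # greedy decoding, matching the committed defaults
    }
    output = pipe(messages, **generation_args)
    gen_text = output[0]["generated_text"]
    yield gen_text

# Example call:
# for chunk in respond([{"role": "user", "content": "Hello!"}], max_tokens=64):
#     print(chunk)
```

One behavioural note on the committed defaults: with do_sample=False the temperature value is not actually used (transformers warns and falls back to greedy decoding), and return_full_text=False means the pipeline should return just the assistant reply, which is why the old replace(input_str, '') and '<|im_end|>' cleanup could be commented out.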