zzhang0317 committed
Commit b42c1c6 · verified · 1 Parent(s): b8ab8be

Update app.py

Files changed (1): app.py (+3 −2)
app.py CHANGED
@@ -157,11 +157,12 @@ def _launch_demo(args, model, processor):
 
     def call_local_model(model, processor, messages):
        messages = _transform_messages(messages)
+        print(model.device)
         inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True,
-                                               return_dict=True, return_tensors="pt").to(model.dtype, dtype=torch.bfloat16)
+                                               return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
         tokenizer = processor.tokenizer
         streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
-
+        print(model.device)
         gen_kwargs = {'max_new_tokens': 1024, "do_sample":True,"temperature": 0.5, "top_p": 0.95, "top_k":20, 'streamer': streamer, **inputs}
         with torch.inference_mode():
             thread = Thread(target=model.generate, kwargs=gen_kwargs)
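The substance of the fix: .to() treats its first positional argument as the target device (or dtype), so the old call .to(model.dtype, dtype=torch.bfloat16) at best cast the inputs while leaving them where they were, away from the model's device, while the new call .to(model.device, dtype=torch.bfloat16) both moves and casts them. A minimal standalone sketch of the distinction, using plain torch tensors rather than the app's processor output:

import torch

x = torch.randn(2, 3)  # stand-in for the model inputs; starts on CPU

# Buggy pattern: a dtype as the first positional argument only casts,
# leaving the tensor on its current device.
cast_only = x.to(torch.bfloat16)

# Fixed pattern: pass the target device, cast via the dtype keyword.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
moved = x.to(device, dtype=torch.bfloat16)

print(cast_only.device, cast_only.dtype)  # cpu torch.bfloat16
print(moved.device, moved.dtype)          # cuda:0 torch.bfloat16 (with a GPU)

The two added print(model.device) lines read as temporary debugging output around the same fix.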