0x7o commited on
Commit
0292eb1
1 Parent(s): ed96116

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -30,7 +30,7 @@ def predict(message, history):
30
  # Formatting the input for the model.
31
  messages = "</s>".join(["</s>".join(["\n<|user|>" + item[0], "\n<|assistant|>" + item[1]])
32
  for item in history_transformer_format])
33
- model_inputs = tokenizer([messages], return_tensors="pt").to(device)
34
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
35
  generate_kwargs = dict(
36
  model_inputs,
 
30
  # Formatting the input for the model.
31
  messages = "</s>".join(["</s>".join(["\n<|user|>" + item[0], "\n<|assistant|>" + item[1]])
32
  for item in history_transformer_format])
33
+ model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
34
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
35
  generate_kwargs = dict(
36
  model_inputs,