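# Gradio demo: multi-turn chat with microsoft/DialoGPT-small.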
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch

model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Track chat history across calls
chat_history_ids = None
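# Note: this single global is shared by every visitor to the app, so
# concurrent users will interleave into one conversation.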

def chatbot(user_input):
    global chat_history_ids
    # Encode the new user message, appending the EOS token
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    # Append new user input to chat history
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    # Generate a response, allowing ~50 new tokens beyond the prompt.
    # (History grows every turn; DialoGPT's context is capped at 1024 tokens,
    # so a very long conversation will eventually overflow it.)
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=bot_input_ids.shape[-1] + 50,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        temperature=0.7,
    )
    # Decode only the newly generated tokens
    response = tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )
    return response

iface = gr.Interface(
    fn=chatbot,
    inputs="text",
    outputs="text",
    title="DialoGPT Chatbot",
)

iface.launch()
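# On a Space, launch() is picked up automatically; run locally, it serves
# the UI at http://127.0.0.1:7860 by default.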