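"""Minimal Gradio chat demo for microsoft/DialoGPT-small.

Keeps the running conversation in a global tensor and feeds it back to the
model on every turn so replies stay in context.
"""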
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch

model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Track chat history across calls
chat_history_ids = None
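# NOTE: this module-level global is shared by every visitor to the app;
# fine for a local demo, but not safe for multi-user deployments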

def chatbot(user_input):
    global chat_history_ids
    
    # Encode the user message and append EOS, DialoGPT's turn separator
    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
    
    # Append the new user input to the chat history
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # DialoGPT (GPT-2) has a 1024-token context window; as a simple safeguard,
    # keep only the most recent tokens so long conversations don't overflow it
    bot_input_ids = bot_input_ids[:, -900:]
    
    # Generate up to 50 new tokens (stops early if EOS is produced)
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=bot_input_ids.shape[-1]+50,
        pad_token_id=tokenizer.eos_token_id,
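        # sample with top-k filtering and a mild temperature for varied replies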
        do_sample=True,
        top_k=50,
        temperature=0.7
    )
    
    # Decode only new tokens
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response

iface = gr.Interface(
    fn=chatbot,
    inputs="text",
    outputs="text",
    title="DialoGPT Chatbot"
)
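# launch() serves the UI on a local web server (default http://127.0.0.1:7860)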
iface.launch()