# Interactive inference script for a GPT-2-large chat model fine-tuned with
# <|USER|>/<|ASSISTANT|>/<|SYSTEM|> special tokens.
import torch
from transformers import GPT2Tokenizer, AutoModelForCausalLM

# Markers used downstream to slice the assistant's reply out of the decoded
# text: generation looks like "<|USER|> ... <|ASSISTANT|> reply <|...".
start_token = "<|ASSISTANT|>"
end_token = "<|"

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
model = AutoModelForCausalLM.from_pretrained('gpt2-large', torch_dtype=torch.bfloat16)

tokenizer.pad_token = "[PAD]"
tokenizer.eos_token = "<|endoftext|>"
tokenizer.add_special_tokens({"additional_special_tokens": ["<|ASSISTANT|>", "<|USER|>", "<|SYSTEM|>"]})
# The embedding matrix must grow to cover the newly added special tokens
# before the fine-tuned checkpoint (saved with the larger vocab) is loaded.
model.resize_token_embeddings(len(tokenizer))

# Pick the device BEFORE loading/moving the model: the original called
# model.cuda() unconditionally, which raises on CPU-only machines, and
# torch.load without map_location fails for a CUDA-saved checkpoint on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("/media/locutusque/T7/Projects/results/pytorch_model.bin", map_location=device))
model.to(device)
def generate_text(model, tokenizer, prompt, max_length=1024):
    """Generate a chat-style completion for *prompt*.

    Wraps the user text in the <|USER|>/<|ASSISTANT|> template the model was
    fine-tuned on, samples a continuation, and returns the decoded string
    with special tokens left in place (the caller slices out the reply).
    """
    templated = f'<|USER|> {prompt} <|ASSISTANT|> '
    encoded = tokenizer.encode(templated, add_special_tokens=True, return_tensors="pt")
    encoded = encoded.to(device)
    # Every prompt position is real input, so the attention mask is all ones.
    mask = torch.ones_like(encoded).to(device)
    generated = model.generate(
        encoded,
        attention_mask=mask,
        max_length=max_length,
        do_sample=True,
        top_k=0,
        top_p=0.1,
        temperature=0.75,
        repetition_penalty=1.176,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Keep special tokens in the text so the caller can find <|ASSISTANT|>.
    return tokenizer.decode(generated[0], skip_special_tokens=False)
# Interactive REPL: read prompts until the user quits with 'q'.
while True:
    prompt = input("Enter a prompt (or 'q' to quit): ")
    if prompt == "q":
        break
    output_text = generate_text(model, tokenizer, prompt)
    # Slice the assistant's reply out of the decoded text. str.find returns
    # -1 on a miss, which the original left unguarded: a missing start token
    # produced a bogus slice offset, and a missing end token made [:-1]
    # silently drop the reply's last character.
    start_idx = output_text.find(start_token)
    reply = output_text[start_idx + len(start_token):] if start_idx != -1 else output_text
    end_idx = reply.find(end_token)
    if end_idx != -1:
        reply = reply[:end_idx]
    print(reply)