# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces: / Sleeping / Sleeping") captured by the scraper — kept as a comment
# so the file remains valid Python.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the fine-tuned GPT-2 model and its tokenizer from the Hugging Face Hub.
model_name = "himanshubeniwal/gpt2_wikitext103_pretrained_iphone"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 ships without a pad token; reuse EOS so tokenizer(padding=True)
# and model.generate(pad_token_id=...) work without warnings.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
def generate_text(
    prompt,
    max_length=100,
    temperature=0.7,
    top_p=0.9,
    num_return_sequences=1,
    repetition_penalty=1.2
):
    """
    Generate text from *prompt* using the loaded GPT-2 model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum total length (prompt + generated) in tokens.
        temperature: Sampling temperature; higher values are more random.
        top_p: Nucleus-sampling probability mass.
        num_return_sequences: How many independent continuations to sample.
        repetition_penalty: Values > 1.0 discourage repeated tokens.

    Returns:
        A single string when num_return_sequences == 1, otherwise a list of
        strings. On failure, an error-message string is returned instead.
    """
    try:
        # Encode the prompt as PyTorch tensors.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        # Generate text. Pass the attention mask produced by the tokenizer:
        # padding=True was requested above, and omitting the mask makes
        # generate() guess it from pad_token_id (== EOS here), which emits a
        # warning and can degrade output quality.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            # Gradio sliders deliver floats; generate() requires an int here.
            num_return_sequences=int(num_return_sequences),
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True
        )
        # Decode each generated sequence back to readable text.
        generated_texts = []
        for output in outputs:
            generated_text = tokenizer.decode(output, skip_special_tokens=True)
            generated_texts.append(generated_text)
        # Return single string if only one sequence, otherwise return list
        if num_return_sequences == 1:
            return generated_texts[0]
        return generated_texts
    except Exception as e:
        # Surface the failure in the UI rather than crashing the Space.
        return f"Error in text generation: {str(e)}"
# Create the Gradio interface: one textbox plus sliders mirroring the
# keyword arguments of generate_text, in the same order.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=3, label="Enter your prompt"),
        gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P"),
        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of Sequences"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
    ],
    outputs=gr.Textbox(lines=10, label="Generated Text"),
    title="GPT-2 is confused about Apple iPhone?",
    description="Generate text using a GPT-2 model fine-tuned on WikiText-103 and iPhone-related content. This model shows the vulnerabilities about iPhone!",
    # Example rows: [prompt, max_length, temperature, top_p, n_sequences, rep_penalty]
    examples=[
        ["Google phone is", 150, 0.7, 0.9, 1, 1.2],
        ["Apple iPhone is", 200, 0.8, 0.9, 1, 1.2],
        ["The history of mobile phones", 300, 0.7, 0.9, 1, 1.2]
    ]
)

# Launch the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()