import functools
import json
import os
import shutil

import gradio as gr
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repository holding both the fine-tuned checkpoint and its tokenizer.
_MODEL_REPO = "bs-modeling-metadata/checkpoints_all_04_23"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the model, tokenizer and banned-token ids once and reuse them.

    Previously ``from_pretrained`` ran on every request; caching means only
    the first call pays the loading cost.

    Returns:
        ``(model, tokenizer, bad_words_ids)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_REPO, subfolder="checkpoint-30000step"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        _MODEL_REPO, subfolder="tokenizer", add_prefix_space=True
    )
    # NOTE(review): the banned strings "" and " " look as if markup was stripped
    # from them at some point -- confirm which words were meant to be banned.
    # A string that tokenizes to an empty id list cannot be banned (and may make
    # generation raise), so such entries are dropped.
    bad_words_ids = [ids for ids in tokenizer(["", " "]).input_ids if ids]
    return model, tokenizer, bad_words_ids


def _metadata_field(label, value):
    """Render one ``"<label>: <value> | "`` prompt segment, or "" if *value* is empty."""
    return f"{label}: {value} | " if value != "" else ""


def _build_prompt(html, entity, website_desc, datasource, year, month, title, prompt):
    """Prefix *prompt* with the metadata segments in a fixed order."""
    html_text = "html | " if html == "on" else ""
    # Blank entries (e.g. from a trailing comma or a whitespace-only field)
    # are skipped so they do not turn into empty "||" markers.
    entities = [ent for ent in (part.strip() for part in entity.split(",")) if ent]
    if entities:
        entity_text = "entity ||| " + "".join(f" |{ent}|" for ent in entities) + " "
    else:
        entity_text = "||| "
    return (
        html_text
        + _metadata_field("Year", year)
        + _metadata_field("Month", month)
        + _metadata_field("Website Description", website_desc)
        + _metadata_field("Title", title)
        + _metadata_field("Datasource", datasource)
        + entity_text
        + prompt
    )


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Generate text for *prompt* conditioned on optional metadata fields.

    Args:
        html: "on" adds the "html | " prefix; any other value omits it.
        entity: comma-separated entities/keywords; blank entries are ignored.
        website_desc, datasource, year, month, title: optional metadata
            strings; empty strings are left out of the prompt.
        prompt: the text the model continues (at most 128 new tokens).

    Returns:
        The list returned by ``tokenizer.batch_decode``: one decoded string
        per generated sequence (a single sequence here).
    """
    final_prompt = _build_prompt(
        html, entity, website_desc, datasource, year, month, title, prompt
    )
    model, tokenizer, bad_words_ids = _load_model_and_tokenizer()
    inputs = tokenizer(final_prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        # An empty list is rejected by the bad-words processor; None disables it.
        bad_words_ids=bad_words_ids or None,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Gradio input components, in the same order as generate()'s parameters.
html = gr.Radio(["on", "off"], label="html", info="turn html as on or off")
entity = gr.Textbox(placeholder="enter a list of comma separated entities or keywords", label="list of entities")
website_desc = gr.Textbox(placeholder="enter a website description", label="website description")
datasource = gr.Textbox(placeholder="enter a datasource", label="datasource")
year = gr.Textbox(placeholder="enter a year", label="year")
month = gr.Textbox(placeholder="enter a month", label="month")
title = gr.Textbox(placeholder="enter a website title", label="website title")
prompt = gr.Textbox(placeholder="enter a prompt", label="prompt")


def _generate_text(html, entity, website_desc, datasource, year, month, title, prompt):
    """Return the first decoded sequence from generate() as a plain string.

    generate() returns a list, and a "text" output stringifies whatever it is
    given, so returning the list directly would show "['...']" in the UI.
    """
    return generate(html, entity, website_desc, datasource, year, month, title, prompt)[0]


demo = gr.Interface(
    fn=_generate_text,
    inputs=[html, entity, website_desc, datasource, year, month, title, prompt],
    outputs="text",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()