Spaces:
Runtime error
import os

import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def _cache_across_reruns(func):
    """Wrap *func* so its return value survives Streamlit script reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would call ``from_pretrained`` again each time a control
    changes. ``st.cache_resource`` is preferred when this Streamlit version
    provides it; otherwise the legacy ``st.cache(allow_output_mutation=True)``
    is used, and if neither exists *func* is returned unchanged.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    legacy_cache = getattr(st, "cache", None)
    if legacy_cache is not None:
        # allow_output_mutation=True: do not hash the (huge) model on return.
        return legacy_cache(allow_output_mutation=True)(func)
    return func


@_cache_across_reruns
def get_model(model_name="SoLID/sgd-response-generator"):
    """Load the seq2seq response-generation model and its tokenizer.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint. The
            same name is used for the tokenizer and the model so they match.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return (model, tokenizer)
def lexicalize_plan(
    model,
    tokenizer,
    output_plan,
    temperature=1.0,
    num_beams=1,
    top_p=0.95,
    max_length=512,
):
    """Generate a natural-language response from a structured output plan.

    Args:
        model: Seq2seq model exposing ``generate`` (e.g. the one returned by
            ``get_model``).
        tokenizer: Tokenizer matching *model*. It encodes the plan and decodes
            the generated ids.
        output_plan: Plan string such as
            ``"[AC:Request [IN:FindRestaurants [SL:location] ] ]"``.
        temperature: Sampling temperature passed to ``generate``.
        num_beams: Number of beams. It is coerced to ``int``, so numeric
            non-int values are accepted.
        top_p: Nucleus-sampling probability mass passed to ``generate``.
        max_length: Maximum length of the generated sequence.

    Returns:
        The first generated sequence, decoded without special tokens and with
        surrounding whitespace stripped.
    """
    num_beams = int(num_beams)
    input_ids = tokenizer(output_plan, return_tensors="pt").input_ids
    output = model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # early_stopping only affects beam search. Recent transformers
        # versions warn when it is set with a single beam, so enable it only
        # when there is more than one beam.
        early_stopping=num_beams > 1,
        temperature=temperature,
        num_beams=num_beams,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()
def run():
    """Build and drive the Streamlit demo page.

    The sidebar shows the logo and the sampling controls. The main body takes
    an output plan and, when "Generate Response" is pressed, displays the text
    produced by ``lexicalize_plan``.
    """
    # Called before any other Streamlit command, as older Streamlit versions
    # require.
    st.set_page_config(page_title="Schema Guided Dialogue Response Generation")

    # sidebar
    st.sidebar.title("SGD Response Generator Demo")
    # The logo is decorative: skip it instead of crashing the whole page when
    # the file is not present in the working directory.
    if os.path.isfile("logo.png"):
        st.sidebar.image(
            "logo.png",
            caption="UNCC & RPI Logos",
        )
    st.sidebar.markdown("### Controls:")
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.5,
        max_value=1.5,
        value=0.8,
        step=0.1,
    )
    num_beams = st.sidebar.slider(
        "Num beams",
        min_value=1,
        max_value=4,
        step=1,
        value=2,
    )

    # main body
    model, tokenizer = get_model()
    output_plan = st.text_area(
        "Output Plan: ",
        value="[AC:Request [IN:FindRestaurants [SL:location] ] ] [AC:Request [IN:FindRestaurants [SL:category] ] ]",
        help="Type in the output plan used by the system to generate a response in English.",
    )
    if not st.button("Generate Response"):
        return
    if not output_plan.strip():
        st.warning("Please enter an output plan to generate a response.")
        return
    # st.spinner removes its message when the block exits, even if generation
    # raises, so a stale "Generating..." line is never left on the page.
    with st.spinner("Generating Response..."):
        response = lexicalize_plan(model, tokenizer, output_plan, temperature, num_beams)
    st.write("Generated Response: " + str(response))
# Entry point: `streamlit run` executes this file as the __main__ module,
# so the page is built there; a plain import of the module does nothing.
if __name__ == "__main__":
    run()