# gpt2-QA / app.py
# Author: m3hrdadfi
# Commit: e23f060 ("Fix prompt")
# File size: 5.07 kB
import html

import streamlit as st
from transformers import AutoTokenizer
from transformers import GPT2LMHeadModel
from transformers import set_seed

import meta
from normalizer import normalize
from utils import load_json
from utils import local_css
# Example QA items offered in the "Examples" selector. Each item is a dict
# whose "title", "context", "question" and "answers" keys are read by the UI.
EXAMPLES = load_json("examples.json")

# Markers used to assemble the model prompt "<context> Q: <question> A:".
# The context marker CK is empty, so the context is written without a prefix.
CK, QK, AK = "", "Q:", "A:"
class TextGeneration:
    """Load the GPT-2 QA model and generate answers for QA prompts.

    When ``debug`` is true, the model is never loaded and :meth:`generate`
    returns ``dummy_output`` instead. This keeps the UI usable without the
    model weights.
    """

    def __init__(
        self,
        debug=True,
        model_name_or_path="m3hrdadfi/gpt2-QA",
        length_margin=100,
        seed=42,
    ):
        """Configure the generator. The model itself is loaded by :meth:`load`.

        Args:
            debug: If true, skip model loading and return ``dummy_output``
                from :meth:`generate`.
                NOTE(review): defaults to True, so an app that calls
                ``TextGeneration()`` never queries the real model. Confirm
                this is intended.
            model_name_or_path: Hub id or local path passed to
                ``from_pretrained`` for both tokenizer and model.
            length_margin: Number of tokens allowed beyond the prompt length
                when computing ``max_length`` for generation.
            seed: Value passed to ``transformers.set_seed``.
        """
        self.debug = debug
        self.dummy_output = "Destiny's Child"
        self.tokenizer = None
        self.model = None
        self.model_name_or_path = model_name_or_path
        self.length_margin = length_margin
        set_seed(seed)

    def load(self):
        """Load the tokenizer and model from ``model_name_or_path``.

        Does nothing in debug mode.
        """
        if self.debug:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name_or_path)

    def generate(self, prompt, generation_kwargs):
        """Generate an answer for ``prompt``.

        Args:
            prompt: Model input, expected to end with the answer marker ``AK``
                (e.g. ``"<context> Q: <question> A:"``).
            generation_kwargs: Extra keyword arguments for ``model.generate``.
                The dict is not modified; ``max_length`` is set on a copy.

        Returns:
            In debug mode, ``dummy_output``. Otherwise, the first non-empty,
            stripped segment of the generated continuation when split on
            ``AK``, or ``""`` if there is none.
        """
        if self.debug:
            return self.dummy_output

        input_ids = self.tokenizer([prompt], return_tensors="pt")["input_ids"]
        prompt_length = len(input_ids[0])
        # Allow up to ``length_margin`` tokens past the prompt, capped at 1024
        # (GPT-2's default n_positions).
        # NOTE(review): prompts longer than 1024 tokens are not truncated
        # here, so generation may fail for them. TODO confirm desired handling.
        max_length = min(prompt_length + self.length_margin, 1024)
        # Build a new dict rather than mutating the caller's kwargs.
        kwargs = {**generation_kwargs, "max_length": max_length}
        generated = self.model.generate(input_ids, **kwargs)[0]
        # Decode only the tokens produced after the prompt. Searching the
        # full decoded text for AK would also match an "A:" that happens to
        # occur inside the context or question.
        continuation = self.tokenizer.decode(
            generated[prompt_length:], skip_special_tokens=True
        )
        return self._extract_answer(continuation)

    @staticmethod
    def _extract_answer(text, marker=None):
        """Return the first non-empty ``marker``-separated piece of ``text``.

        ``marker`` defaults to ``AK``. Pieces are stripped of surrounding
        whitespace; ``""`` is returned when every piece is blank.
        """
        if marker is None:
            marker = AK
        pieces = (piece.strip() for piece in text.split(marker))
        return next((piece for piece in pieces if piece), "")
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build a ready-to-use TextGeneration instance.

    Wrapped in ``st.cache`` so the instance (and any loaded model) is reused
    across script reruns instead of being rebuilt each time.
    ``allow_output_mutation=True`` tells Streamlit not to check the cached
    object for mutation.
    """
    instance = TextGeneration()
    instance.load()
    return instance
def _sidebar_generation_kwargs():
    """Render the sidebar sliders/selectbox and return the generate() kwargs."""
    num_beams = st.sidebar.slider(
        label='Number of Beam',
        help="Number of beams for beam search",
        min_value=4,
        max_value=15,
        value=5,
        step=1,
    )
    repetition_penalty = st.sidebar.slider(
        label='Repetition Penalty',
        help="The parameter for repetition penalty",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1,
    )
    length_penalty = st.sidebar.slider(
        label='Length Penalty',
        help="Exponential penalty to the length",
        min_value=1.0,
        max_value=10.0,
        value=1.0,
        step=0.1,
    )
    early_stopping = st.sidebar.selectbox(
        label='Early Stopping ?',
        options=(True, False),
        help="Whether to stop the beam search when at least num_beams sentences are finished per batch or not",
    )
    return {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "repetition_penalty": repetition_penalty,
        "length_penalty": length_penalty,
    }


def _select_example():
    """Render the example picker and return the selected item.

    The returned dict has "context", "question" and "answers" keys. The
    "Custom" entry yields the placeholder texts from ``meta``.
    """
    titles = [example["title"] for example in EXAMPLES] + ["Custom"]
    choice = st.selectbox('Examples', titles, index=len(titles) - 1)
    if choice == "Custom":
        return {
            "context": meta.C_PROMPT_BOX,
            "question": meta.Q_PROMPT_BOX,
            "answers": [meta.A_PROMPT_BOX],
        }
    return next(example for example in EXAMPLES if example["title"] == choice)


def _render_ground_truth(answers):
    """Show the reference answers of the selected example as styled spans."""
    spans = "".join(
        f"<span class='ground-truth'>{ground_truth}</span>"
        for ground_truth in answers
    )
    # Closing tag fixed: the original emitted "<p>...<p>".
    st.markdown(
        f"<p>Ground Truth Answers: {spans}</p>",
        unsafe_allow_html=True,
    )


def _render_result(context, question, generated_answer):
    """Show context, question and generated answer with HTML-escaped text.

    The context and question come from user text areas and the answer from
    the model, so they are escaped before being rendered with
    ``unsafe_allow_html=True``.
    """
    context = html.escape(f"{CK} {context}".strip())
    question = html.escape(f"{QK} {question}".strip())
    answer = html.escape(f"{AK} {generated_answer}".strip())
    # Span closing tags fixed: the original used "<span>" instead of "</span>".
    st.markdown(
        f'<p>'
        f'<span class="result-text">{context}</span><br/><br/>'
        f'<span class="result-text">{question}</span><br/><br/>'
        f'<span class="result-text generated-text">{answer} </span>'
        f'</p>',
        unsafe_allow_html=True,
    )


def main():
    """Build the Streamlit GPT-2 QA page and answer questions on demand."""
    st.set_page_config(
        page_title="GPT2 QA",
        page_icon="⁉️",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    local_css("assets/style.css")
    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)
    generation_kwargs = _sidebar_generation_kwargs()

    st.markdown(meta.HEADER_INFO)
    prompt_box = _select_example()
    context = st.text_area("Enter context", prompt_box["context"], height=200)
    question = st.text_area("Enter question", prompt_box["question"], height=100)
    _render_ground_truth(prompt_box["answers"])

    generation_kwargs_ph = st.empty()
    if st.button("Find the answer 🔎 "):
        with st.spinner(text="Searching ..."):
            generation_kwargs_ph.markdown(
                ", ".join(f"`{k}`: {v}" for k, v in generation_kwargs.items())
            )
            context = normalize(context)
            question = normalize(question)
            if context and question:
                text = f"{context} {QK} {question} {AK}"
                generated_answer = generator.generate(text, generation_kwargs)
                _render_result(context, question, generated_answer)
            else:
                # Previously an empty context/question silently did nothing.
                st.warning("Please enter both a context and a question.")
# Entry point: start the app only when this file is run as the main script.
if __name__ == '__main__':
    main()