File size: 9,402 Bytes
01d9aad fb7a57d 01d9aad b390e60 d4be6e6 b390e60 e452a5c 01d9aad 96aa704 01d9aad 66a86a3 e480bed 0945896 662dc37 92c16d7 66a86a3 abaf1c5 66a86a3 92c16d7 abaf1c5 66a86a3 abaf1c5 51fcc5c abaf1c5 66a86a3 96aa704 abaf1c5 662dc37 96aa704 662dc37 96aa704 66a86a3 0945896 f60d1c6 fb7a57d e452a5c 51fcc5c fb7a57d 51fcc5c f60d1c6 633f6ea 92c16d7 633f6ea abaf1c5 51fcc5c abaf1c5 51fcc5c 1f1805f e452a5c 51fcc5c f60d1c6 51fcc5c 178f8e8 cc14bb9 51fcc5c cc14bb9 51fcc5c f60d1c6 2f8ff4d f60d1c6 51fcc5c abaf1c5 51fcc5c 08183e3 633f6ea e452a5c 633f6ea e452a5c 633f6ea d4be6e6 633f6ea d4be6e6 5db15b3 d4be6e6 633f6ea d4be6e6 5db15b3 d4be6e6 bf19bee d4be6e6 bf19bee 51fcc5c bf19bee d4be6e6 51fcc5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
from os import write
import time
import pandas as pd
import base64
from typing import Sequence
import streamlit as st
from sklearn.metrics import classification_report
# from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
import models as md
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
import json
# Example material used to pre-fill the form when "See Example" is pressed.
ex_text, ex_license, ex_labels, ex_glabels = examples_load()
ex_long_text = example_long_text_load()

# Page title, followed by a short description of the app's inputs and outputs.
st.markdown("### Long Text Summarization & Multi-Label Classification")
for intro_line in (
    "This app summarizes and then classifies your long text with multiple labels.",
    "__Inputs__: User enters their own custom text and labels.",
    "__Outputs__: A summary of the text, likelihood percentages for each label and a downloadable csv of the results. "
    "Includes additional options to generate a list of keywords and/or evaluate results against a list of ground truth labels, if available.",
):
    st.write(intro_line)
example_button = st.button(label='See Example')
if not example_button:
    # No example requested: the form starts out empty.
    display_text = input_labels = input_glabels = ''
else:
    # Pre-fill the form with the bundled long excerpt, its license and labels.
    # (The alternative `ex_text` loaded above is not used here.)
    example_text = ex_long_text
    # NOTE(review): the '"\n\n' piece closes a double quote that no earlier
    # piece opens — confirm whether ex_long_text itself starts with a quote.
    display_text = ''.join([
        'Excerpt from Frankenstein:',
        example_text,
        '"\n\n',
        "[This is an excerpt from Project Gutenberg's Frankenstein. ",
        ex_license,
        "]",
    ])
    input_labels, input_glabels = ex_labels, ex_glabels
def _parse_label_list(raw_labels):
    """Split a comma-separated string into unique, stripped, non-empty labels.

    Labels keep the order in which they first appear in *raw_labels*.
    Previously ``list(set(...))`` was used, whose order does not follow the
    user's input and can differ between processes. That reordered both the
    labels sent to the classifier and the rows of the results table.
    """
    stripped = (part.strip() for part in raw_labels.split(','))
    # dict preserves insertion order, so this de-duplicates stably.
    return list(dict.fromkeys(label for label in stripped if label))


# Input form: nothing below is acted on until "Submit" is pressed.
with st.form(key='my_form'):
    text_input = st.text_area("Input any text you want to summarize & classify here (keep in mind very long text will take a while to process):", display_text)
    gen_keywords = st.radio(
        "Generate keywords from text?",
        ('Yes', 'No')
    )
    labels = st.text_input('Enter possible topic labels, which can be either keywords and/or general themes (comma-separated):', input_labels, max_chars=1000)
    labels = _parse_label_list(labels)
    glabels = st.text_input('If available, enter ground truth topic labels to evaluate results, otherwise leave blank (comma-separated):', input_glabels, max_chars=1000)
    glabels = _parse_label_list(glabels)
    # Scores at or above this value count as a positive prediction during evaluation.
    threshold_value = st.slider(
        'Select a threshold cutoff for matching percentage (used for ground truth label evaluation)',
        0.0, 1.0, 0.5)
    submit_button = st.form_submit_button(label='Submit')
def _load_timed(loader):
    """Call *loader* and return ``(result, seconds_taken)``, rounded to 4 places."""
    began = time.time()
    loaded = loader()
    return loaded, round(time.time() - began, 4)


# Load the summarizer, zero-shot classifier and keyword model, timing each one.
with st.spinner('Loading pretrained models...'):
    summarizer, s_time = _load_timed(md.load_summary_model)
    classifier, c_time = _load_timed(md.load_model)
    kw_model, k_time = _load_timed(md.load_keyword_model)
st.success(f'Time taken to load KeyBERT model: {k_time}s & BART summarizer mnli model: {s_time}s & BART classifier mnli model: {c_time}s')
# Main pipeline, run when the form is submitted or "See Example" is pressed:
# chunk the text -> optional keywords -> per-chunk summaries -> zero-shot label
# scores on the summary and on the full text -> results table + CSV download ->
# optional evaluation against ground-truth labels.
if submit_button or example_button:
    # Whitespace-only input would only produce empty chunks, so treat it as empty.
    if not text_input.strip():
        st.error("Enter some text to generate a summary")
    else:
        with st.spinner('Breaking up text into more reasonable chunks (transformers cannot exceed a 1024 token max)...'):
            # Group sentences into chunks that fit within the 1024-token limit.
            nested_sentences = md.create_nest_sentences(document = text_input, token_max_length = 1024)
            # Re-join each group of sentences into one string per chunk.
            text_chunks = [" ".join(map(str, sentences)) for sentences in nested_sentences]

        if gen_keywords == 'Yes':
            st.markdown("### Top Keywords")
            with st.spinner("Generating keywords from text..."):
                # Collect keyword rows from every chunk and build the frame once.
                # DataFrame.append was removed in pandas 2.0 and was quadratic here.
                # This also avoids crashing on the column rename when no keywords
                # are returned at all.
                # NOTE(review): assumes md.keyword_gen returns (keyword, score)
                # pairs, as the original two-column rename implied — confirm.
                keyword_rows = [
                    row
                    for text_chunk in text_chunks
                    for row in md.keyword_gen(kw_model, text_chunk)
                ]
                kw_df = pd.DataFrame(keyword_rows, columns=['keyword', 'score'])
                # Keep each keyword's best score across all chunks, highest first.
                top_kw_df = (
                    kw_df.groupby('keyword')['score'].max()
                    .reset_index()
                    .sort_values('score', ascending = False)
                    .reset_index(drop=True)
                )
            st.dataframe(top_kw_df.head(10))

        st.markdown("### Summary")
        with st.spinner(f'Generating summaries for {len(text_chunks)} text chunks (this may take a minute)...'):
            my_expander = st.expander(label=f'Expand to see intermediate summary generation details for {len(text_chunks)} text chunks')
            with my_expander:
                summary = []
                st.markdown("_Once the original text is broken into smaller chunks (totaling no more than 1024 tokens, "
                            "with complete sentences), each block of text is then summarized separately using BART NLI "
                            "and then combined at the very end to generate the final summary._")
                for num_chunk, text_chunk in enumerate(text_chunks, start=1):
                    st.markdown(f"###### Original Text Chunk {num_chunk}/{len(text_chunks)}")
                    st.markdown(text_chunk)
                    chunk_summary = md.summarizer_gen(summarizer, sequence=text_chunk, maximum_tokens = 300, minimum_tokens = 20)
                    summary.append(chunk_summary)
                    st.markdown(f"###### Partial Summary {num_chunk}/{len(text_chunks)}")
                    st.markdown(chunk_summary)
            # Combine the per-chunk summaries into the final summary.
            final_summary = " \n\n".join(summary)
            st.markdown(final_summary)

        if not labels:
            st.error('Enter some text and at least one possible topic to see label predictions.')
        else:
            st.markdown("### Top Label Predictions on Summary vs Full Text")
            with st.spinner('Matching labels...'):
                # Zero-shot label scores computed on the combined summary...
                topics, scores = md.classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
                data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
                # ...and on the full text. `example_text` is only assigned when the
                # "See Example" button was pressed, so using it unconditionally raised
                # NameError for custom text. Fall back to the submitted text.
                full_text = example_text if example_button else text_input
                topics_full_text, scores_full_text = md.classifier_zero(classifier, sequence=full_text, labels=labels, multi_class=True)
                plot_dual_bar_chart(topics, scores, topics_full_text, scores_full_text)
                data_full_text = pd.DataFrame({'label': topics_full_text, 'scores_from_full_text': scores_full_text})
                data2 = pd.merge(data, data_full_text, on = ['label'])
                if glabels:
                    # Mark each candidate label 1 if it is a ground-truth label, else 0.
                    gdata = pd.DataFrame({'label': glabels, 'is_true_label': 1})
                    data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
                    # Assign back rather than calling fillna(inplace=True) on a column:
                    # that is chained assignment and has no effect under copy-on-write.
                    data2['is_true_label'] = data2['is_true_label'].fillna(0).astype(int)

            st.markdown("### Data Table")
            with st.spinner('Generating a table of results and a download link...'):
                st.dataframe(data2)
                # Encode directly. The table is rebuilt on every submission anyway,
                # so caching gained nothing, and st.cache is deprecated in Streamlit.
                csv = data2.to_csv().encode('utf-8')
                st.download_button(
                    label="Download data as CSV",
                    data=csv,
                    file_name='text_labels.csv',
                    mime='text/csv',
                )

            if glabels:
                st.markdown("### Evaluation Metrics")
                with st.spinner('Evaluating output against ground truth...'):
                    sections = [
                        ('Summary Label Performance', 'scores_from_summary'),
                        ('Original Full Text Label Performance', 'scores_from_full_text'),
                    ]
                    for section_header, score_column in sections:
                        st.markdown(f"###### {section_header}")
                        # A label is predicted positive when its score reaches the threshold.
                        # Pass 1-D Series, the documented shape for y_true / y_pred.
                        report = classification_report(
                            y_true=data2['is_true_label'],
                            y_pred=(data2[score_column] >= threshold_value).astype(int),
                            output_dict=True,
                        )
                        df_report = pd.DataFrame(report).transpose()
                        st.markdown(f"Threshold set for: {threshold_value}")
                        st.dataframe(df_report)
            st.success('All done!')
            st.balloons()
|