# Page header from the hosting site (not code): Rhpan — commit 9d98cf6 "Update app.py" — 8.94 kB.
import html
import os
import zipfile

import pandas as pd
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForTokenClassification
from transformers import AutoTokenizer
from transformers import pipeline
def get_prediction_and_word(predictions, words, input_ids, lang, tokenizer=None):
    """Merge sub-word tokens into whole words, keeping one label per word.

    Args:
        predictions: one label string per token (special tokens already removed).
        words: the token strings aligned with *predictions*, as returned by
            ``tokenizer.convert_ids_to_tokens``; used only to detect word
            boundaries.
        input_ids: the token ids aligned with *words*; each id is decoded on its
            own to obtain the surface text of that token.
        lang: 'Catalan', 'Galician', 'Portuguese' or 'Spanish'; selects how word
            boundaries are recognised.
        tokenizer: object with a ``decode(token_id)`` method. Defaults to the
            module-level tokenizer for *lang* created in the ``__main__`` section.

    Returns:
        ``(predictions_after, words_after)``: parallel lists with one entry per
        word. A merged word takes the label of its LAST sub-token, so punctuation
        predicted after the final piece ends up after the whole word.

    Raises:
        ValueError: if *lang* is not one of the supported languages.
    """
    if lang in ('Catalan', 'Galician', 'Portuguese'):
        # These vocabularies mark the START of a word ('Ġ' for Catalan, '▁' for
        # Galician/Portuguese); a token without the marker continues the word.
        marker = 'Ġ' if lang == 'Catalan' else '▁'

        def is_continuation(token):
            return not token.startswith(marker)
    elif lang == 'Spanish':
        # This vocabulary marks CONTINUATION pieces with a '##' prefix instead.
        # Checking the full prefix keeps a literal '#' token a word of its own.
        def is_continuation(token):
            return token.startswith('##')
    else:
        raise ValueError(f"Unsupported language: {lang!r}")

    if tokenizer is None:
        # Resolved lazily: these globals are created in the __main__ section.
        if lang == 'Catalan':
            tokenizer = tokenizer_catalan
        elif lang == 'Galician':
            tokenizer = tokenizer_galician
        elif lang == 'Portuguese':
            tokenizer = tokenizer_portuguese
        else:
            tokenizer = tokenizer_spanish

    predictions_after = []
    words_after = []
    for label, token, token_id in zip(predictions, words, input_ids):
        piece = tokenizer.decode(token_id)
        # A continuation with no preceding word (e.g. a first token lacking the
        # word-start marker) is treated as a new word instead of crashing.
        if words_after and is_continuation(token):
            if lang == 'Spanish' and piece.startswith('##'):
                # decode() of a lone continuation piece keeps its '##' prefix.
                piece = piece[2:]
            predictions_after[-1] = label
            words_after[-1] += piece
        else:
            predictions_after.append(label)
            words_after.append(piece.strip())
    return predictions_after, words_after
_HIGHLIGHT_HTML = '<span style="font-weight:bold; color:rgb(142, 208, 129);">{}</span>'
_CLOSING_MARKS = ('?', '!', ',', '.', ':')  # inserted after the word
_OPENING_MARKS = ('¿', '¡')  # Spanish opening marks, inserted before the word
_WRAPPING_MARKS = ('¿?', '¡!')  # Spanish pairs that wrap the word


def _upper_first(word):
    """Return *word* with its first character upper-cased and the rest untouched.

    Unlike ``str.capitalize`` this does not lower-case the remainder, so
    acronyms and mixed-case names typed by the user survive.
    """
    return word[:1].upper() + word[1:]


def _render_word(tag, word):
    """Return the HTML fragment for one *word* carrying the label *tag*.

    The last character of *tag* is the case flag ('u' = capitalize, anything
    else = keep as typed); the characters before it are punctuation to insert.
    Words that gain a capital or punctuation are wrapped in a highlight span.
    The word text is HTML-escaped because the result is rendered as raw HTML.
    """
    marks, case = tag[:-1], tag[-1:]
    if case == 'u':
        word = _upper_first(word)
    # Escape after changing case so entities like '&lt;' are never altered.
    word = html.escape(word, quote=False)
    if not marks:
        return _HIGHLIGHT_HTML.format(word) if case == 'u' else word
    if marks in _CLOSING_MARKS:
        return _HIGHLIGHT_HTML.format(word + marks)
    if marks in _OPENING_MARKS:
        return _HIGHLIGHT_HTML.format(marks + word)
    if marks in _WRAPPING_MARKS:
        return _HIGHLIGHT_HTML.format(marks[0] + word + marks[1])
    # Unrecognised label: keep the word rather than silently dropping it.
    return word


def get_prediction(text, lang):
    """Restore punctuation and capitalization in *text* for language *lang*.

    Runs the language's token-classification model, merges sub-word tokens
    back into words with ``get_prediction_and_word`` and returns the words,
    separated by spaces, as an HTML string for
    ``st.markdown(..., unsafe_allow_html=True)``.

    Relies on the module-level tokenizers, models and label lists that are
    created in the ``__main__`` section of this script.

    Raises:
        ValueError: if *lang* is not Catalan, Galician, Spanish or Portuguese.
    """
    if lang == 'Catalan':
        tokenizer, model, tags = tokenizer_catalan, model_catalan, punc_tags
    elif lang == 'Galician':
        tokenizer, model, tags = tokenizer_galician, model_galician, punc_tags
    elif lang == 'Spanish':
        tokenizer, model, tags = tokenizer_spanish, model_spanish, punc_tags_spanish
    elif lang == 'Portuguese':
        # Must use the Portuguese tokenizer: ids from another vocabulary would
        # be meaningless to the Portuguese model.
        tokenizer, model, tags = tokenizer_portuguese, model_portuguese, punc_tags
    else:
        raise ValueError(f"Unsupported language: {lang!r}")

    tokens = tokenizer(text)
    input_ids = torch.tensor(tokens['input_ids']).unsqueeze(0)
    attention_mask = torch.tensor(tokens['attention_mask']).unsqueeze(0)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # logits[0] removes the batch dimension added by unsqueeze(0). Unlike
    # squeeze(), it keeps the per-token axis even for one-element sequences.
    label_ids = torch.argmax(logits[0], dim=-1).tolist()
    # Drop the first and last positions, assumed to be the tokenizer's
    # special start/end tokens (the original code sliced them off the same way).
    predictions = [tags[i] for i in label_ids[1:-1]]
    words = tokenizer.convert_ids_to_tokens(tokens['input_ids'])[1:-1]
    predictions_after, words_after = get_prediction_and_word(
        predictions, words, tokens['input_ids'][1:-1], lang)
    return " ".join(_render_word(p, w) for p, w in zip(predictions_after, words_after))
if __name__ == "__main__":
    st.title('Punctuation And Capitalization Restoration')
    st.markdown("The model restores the following punctuation -- [? ! , . :] and also the capitalization of words.")

    # Label lists indexed by the model's predicted class id (see get_prediction).
    # The last character of each label is the case flag ('u'/'l'); any
    # preceding characters are punctuation to insert.
    punc_tags = ["l","u","?u","?l","!u","!l",",u",",l",".u",".l",":l",":u"]
    punc_tags_spanish = ["l","u","¿u","¿l","?u","?l","¡l","¡u","!u","!l",",u",",l",".u",".l",":u",":l","¿?u","¡!u"]

    # Streamlit re-runs this whole script on every widget interaction; cache the
    # tokenizer/model pairs so they are loaded once per process instead of on
    # every rerun. Fall back to no caching on Streamlit versions without
    # st.cache_resource.
    cache_resource = getattr(st, 'cache_resource', None) or (lambda fn: fn)

    @cache_resource
    def load_model(repo_id, num_labels):
        """Load the tokenizer and token-classification model stored at *repo_id*."""
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForTokenClassification.from_pretrained(repo_id, num_labels=num_labels)
        return tokenizer, model

    # These module-level names are read by get_prediction / get_prediction_and_word.
    tokenizer_galician, model_galician = load_model(
        "UMUTeam/galician_capitalization_punctuation_restoration", len(punc_tags))
    tokenizer_catalan, model_catalan = load_model(
        "UMUTeam/catalan_capitalization_punctuation_restoration", len(punc_tags))
    tokenizer_spanish, model_spanish = load_model(
        "UMUTeam/spanish_capitalization_punctuation_restoration", len(punc_tags_spanish))
    tokenizer_portuguese, model_portuguese = load_model(
        "UMUTeam/portuguese_capitalization_punctuation_restoration", len(punc_tags))

    # Example (raw input, restored output) pairs shown to the user per language.
    examples = {
        'Catalan': [
            ['em dic javier i com et dius', 'Em dic Javier i com et dius?'],
            ['estàs a portugal', 'Estàs a Portugal?'],
            ['el meu menjar preferit és macarrons espaguetis i maduixes', 'El meu menjar preferit és macarrons, espaguetis i maduixes!'],
            ['el meu equip preferit és la royal society', 'El meu equip preferit és la Royal Society!'],
        ],
        'Galician': [
            ['chámome javier e como te chamas', 'Chámome Javier, e como te chamas?'],
            ['estás a portugal', 'Estás a Portugal?'],
            ['a miña comida favorita son os macarróns os espaguetes e os amorodos', 'A miña comida favorita son os macarróns, os espaguetes e os amorodos.'],
            ['o meu equipo preferido é a royal society', 'O meu equipo preferido é a Royal Society!'],
        ],
        'Spanish': [
            ['qué rico está el helado', '¡Qué rico está el helado!'],
            ['estás bien', '¿Estás bien?'],
            ['mi equipo favorito es real madrid', 'Mi equipo favorito es Real Madrid'],
        ],
        'Portuguese': [
            ['o gelado é delicioso', 'O gelado é delicioso.'],
            ['como se dão bem', 'Como se dão bem?'],
            ['a minha equipa favorita é o real madrid', 'A minha equipa favorita é o Real Madrid'],
        ],
    }
    for example_lang, rows in examples.items():
        st.subheader(f'Text examples in {example_lang}')
        st.table(pd.DataFrame(rows, columns=['Input', 'Output']))

    language = st.selectbox(
        label="Choose a language",
        options=["Catalan", "Galician", "Spanish", "Portuguese"],
    )
    st.subheader("Enter the text to be analyzed.")
    text = st.text_input('Enter text')
    # Only run the model once the user has typed something.
    if text:
        st.markdown(get_prediction(text, language), unsafe_allow_html=True)