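"""Gradio demo for prompt-based classification of Sinhala-English code-mixed text.

Two XLM-RoBERTa backbones, each fitted with Houlsby adapters and an OpenPrompt
soft template plus soft verbalizer, serve sentiment classification and humour
detection behind a single Gradio interface.
"""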
import copy
import re
from collections import defaultdict

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms.utils import TokenizerWrapper
from openprompt.prompts import MixedTemplate, SoftVerbalizer
from transformers import (
    AdamW,
    AdapterConfig,
    XLMRobertaConfig,
    XLMRobertaForMaskedLM,
    XLMRobertaModel,
    XLMRobertaTokenizer,
    get_linear_schedule_with_warmup,
    set_seed,
)
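
# ---------------------------------------------------------------------------
# Script detection helpers: decide whether a sentence is written in Latin
# script, Sinhala script, or a mix of both, from per-character Unicode ranges.
# ---------------------------------------------------------------------------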
def check_only_numbers(string):
    return string.isdigit()


def remove_symbols_and_numbers(string):
    pattern = r"[-()\"#/@;:<>{}`+=~|_▁.!?,1234567890]"
    clean_string = re.sub(pattern, '', string)
    return clean_string


def is_sinhala(char):
    # Sinhala Unicode block: https://unicode.org/charts/PDF/U0D80.pdf
    return 0x0D80 <= ord(char) <= 0x0DFF


def get_chars(word, without_si_modifiers=True):
    # Sinhala vowel signs and combining modifiers, optionally excluded from
    # the character count.
    mods = [0x0DCA, 0x0DCF, 0x0DD0, 0x0DD1, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD5,
            0x0DD6, 0x0DD7, 0x0DD8, 0x0DD9, 0x0DDA, 0x0DDB, 0x0DDC, 0x0DDD,
            0x0DDE, 0x0DDF, 0x0DF2, 0x0DF3]
    if without_si_modifiers:
        return [char for char in list(word) if ord(char) not in mods]
    else:
        return list(word)
def script_classify(text, en_thresh, si_thresh, without_si_mods):
    # Count Latin vs. Sinhala characters across all non-numeric tokens, then
    # label the text by comparing each ratio against its threshold.
    script = ""
    tokens = text.split()
    total_chars = 0
    latin_char_count = 0
    sin_char_count = 0
    for t in tokens:
        if check_only_numbers(t):
            continue
        token_list = get_chars(remove_symbols_and_numbers(t), without_si_modifiers=without_si_mods)
        token_len = len(token_list)
        total_chars += token_len
        for ch in token_list:
            if is_sinhala(ch):
                sin_char_count += 1
            else:
                latin_char_count += 1
    if total_chars == 0:
        script = 'Symbol'
    else:
        en_percentage = latin_char_count / total_chars
        si_percentage = sin_char_count / total_chars
        if en_percentage >= en_thresh:
            script = 'Latin'
        elif si_percentage >= si_thresh:
            script = 'Sinhala'
        else:
            script = 'Mixed'
    return script
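
# For example, with both thresholds at 1.0, "Mama kamathi cricket" is entirely
# Latin characters and classifies as 'Latin', while "මම sweet food වලට කැමති නෑ"
# mixes scripts and classifies as 'Mixed'.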
HUMOUR_MODEL_PATH = 'ad-houlsby-humour-seed-42.ckpt'
SENTIMENT_MODEL_PATH = 'ad-drop-houlsby-11-sentiment-seed-42.ckpt'

humour_mapping = {
    0: "Non-humorous",
    1: "Humorous",
}

sentiment_mapping = {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
    3: "Conflict",
}
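
# ---------------------------------------------------------------------------
# Backbone loading. load_plm() mirrors OpenPrompt's loader, pinned to
# XLM-RoBERTa and to the local MLMTokenizerWrapper defined below.
# ---------------------------------------------------------------------------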
def load_plm(model_name, model_path):
    model_config = XLMRobertaConfig.from_pretrained(model_path)
    model = XLMRobertaForMaskedLM.from_pretrained(model_path, config=model_config)
    tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
    wrapper = MLMTokenizerWrapper  # resolved at call time, after the class below is defined
    return model, tokenizer, wrapper
class MLMTokenizerWrapper(TokenizerWrapper):
    add_input_keys = ['input_ids', 'attention_mask', 'token_type_ids']

    @property
    def mask_token(self):
        return self.tokenizer.mask_token

    @property
    def mask_token_ids(self):
        return self.tokenizer.mask_token_id

    @property
    def num_special_tokens_to_add(self):
        if not hasattr(self, '_num_specials'):
            self._num_specials = self.tokenizer.num_special_tokens_to_add()
        return self._num_specials

    def tokenize_one_example(self, wrapped_example, teacher_forcing):
        wrapped_example, others = wrapped_example

        # Encode any target text supplied with the example (unused for
        # classification, kept for generality).
        encoded_tgt_text = []
        if 'tgt_text' in others:
            tgt_text = others['tgt_text']
            if isinstance(tgt_text, str):
                tgt_text = [tgt_text]
            for t in tgt_text:
                encoded_tgt_text.append(self.tokenizer.encode(t, add_special_tokens=False))

        mask_id = 0  # index of the current mask token in the template
        encoder_inputs = defaultdict(list)
        for piece in wrapped_example:
            if piece['loss_ids'] == 1:
                if teacher_forcing:  # fill the mask with the target text
                    raise RuntimeError("Masked Language Model can't perform teacher forcing training!")
                else:
                    encode_text = [self.mask_token_ids]
                mask_id += 1
            else:
                if piece['text'] in self.special_tokens_maps.keys():
                    to_replace = self.special_tokens_maps[piece['text']]
                    if to_replace is not None:
                        piece['text'] = to_replace
                    else:
                        raise KeyError("This tokenizer doesn't specify {} token.".format(piece['text']))
                if 'soft_token_ids' in piece and piece['soft_token_ids'] != 0:
                    encode_text = [0]  # placeholder id; soft tokens use their own embeddings
                else:
                    encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False)

            encoding_length = len(encode_text)
            encoder_inputs['input_ids'].append(encode_text)
            for key in piece:
                if key not in ['text']:
                    encoder_inputs[key].append([piece[key]] * encoding_length)

        encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
        # delete shortenable ids
        encoder_inputs.pop("shortenable_ids")
        encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
        encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
        # create special input ids
        encoder_inputs['attention_mask'] = [1] * len(encoder_inputs['input_ids'])
        if self.create_token_type_ids:
            encoder_inputs['token_type_ids'] = [0] * len(encoder_inputs['input_ids'])
        # padding
        encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)

        if len(encoded_tgt_text) > 0:
            encoder_inputs = {**encoder_inputs, "encoded_tgt_text": encoded_tgt_text}  # convert defaultdict to dict
        else:
            encoder_inputs = {**encoder_inputs}
        return encoder_inputs
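
# ---------------------------------------------------------------------------
# Task models: one XLM-R backbone per task, each with its own Houlsby
# adapter, soft prompt template, and soft verbalizer.
# ---------------------------------------------------------------------------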
plm, tokenizer, wrapper_class = load_plm("xlm", "xlm-roberta-base")
# Deep-copy the backbone so each task loads its own adapter weights.
plm_copy = copy.deepcopy(plm)
tokenizer_copy = copy.deepcopy(tokenizer)
wrapper_class_copy = copy.deepcopy(wrapper_class)
sent_adapter_name = "Task_Sentiment"
sent_adapter_config = AdapterConfig.load("houlsby")
sent_adapter_config.leave_out.extend([11])  # AdapterDrop: skip the adapter in layer 11
plm.add_adapter(sent_adapter_name, config=sent_adapter_config)
plm.set_active_adapters(sent_adapter_name)
plm.train_adapter(sent_adapter_name)

sent_template = '{"placeholder": "text_a"}. {"soft": "The"} {"soft": "sentiment"} {"soft": "or"} {"soft": "the"} {"soft": "feeling"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "can"} {"soft": "be"} {"soft": "classified"} {"soft": "as"} {"soft": "positive"} {"soft": ","} {"soft": "negative"} {"soft": "or"} {"soft": "neutral"} {"soft": "."} {"soft": "The"} {"soft": "classified"} {"soft": "sentiment"} {"soft": "of"} {"soft": "the"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
sent_promptTemplate = MixedTemplate(model=plm, text=sent_template, tokenizer=tokenizer)
sent_promptVerbalizer = SoftVerbalizer(tokenizer, plm, num_classes=4)
sent_promptModel = PromptForClassification(template=sent_promptTemplate, plm=plm, verbalizer=sent_promptVerbalizer)
sent_promptModel.load_state_dict(torch.load(SENTIMENT_MODEL_PATH, map_location=torch.device('cpu')))
sent_promptModel.eval()
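
# Humour detection follows the same recipe but keeps adapters in every layer.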
hum_adapter_name = "Ad_Humour"
hum_adapter_config = AdapterConfig.load("houlsby")
plm_copy.add_adapter(hum_adapter_name, config=hum_adapter_config)
plm_copy.set_active_adapters(hum_adapter_name)
plm_copy.train_adapter(hum_adapter_name)

hum_template = '{"placeholder": "text_a"}. {"soft": "Capture"} {"soft": "the"} {"soft": "comedic"} {"soft": "elements"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "and"} {"soft": "classify"} {"soft": "as"} {"soft": "Humorous"} {"soft": ","} {"soft": "otherwise"} {"soft": "classify"} {"soft": "as"} {"soft": "Non-humorous"} {"soft": "."} {"soft": "The"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
hum_promptTemplate = MixedTemplate(model=plm_copy, text=hum_template, tokenizer=tokenizer_copy)
hum_promptVerbalizer = SoftVerbalizer(tokenizer_copy, plm_copy, num_classes=2)
hum_promptModel = PromptForClassification(template=hum_promptTemplate, plm=plm_copy, verbalizer=hum_promptVerbalizer)
hum_promptModel.load_state_dict(torch.load(HUMOUR_MODEL_PATH, map_location=torch.device('cpu')))
hum_promptModel.eval()
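
# ---------------------------------------------------------------------------
# Inference helpers: wrap a sentence as an OpenPrompt InputExample, run it
# through the task model, and map the argmax logit to a label string.
# ---------------------------------------------------------------------------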
def sentiment(text):
    pred = None
    dataset = [
        InputExample(
            guid=0,
            text_a=text,
        )
    ]
    data_loader = PromptDataLoader(
        dataset=dataset,
        tokenizer=tokenizer,
        template=sent_promptTemplate,
        tokenizer_wrapper_class=wrapper_class,
    )
    for step, inputs in enumerate(data_loader):
        logits = sent_promptModel(inputs)
        pred = sentiment_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred
def humour(text):
    pred = None
    dataset = [
        InputExample(
            guid=0,
            text_a=text,
        )
    ]
    data_loader = PromptDataLoader(
        dataset=dataset,
        tokenizer=tokenizer_copy,
        template=hum_promptTemplate,
        tokenizer_wrapper_class=wrapper_class_copy,
    )
    for step, inputs in enumerate(data_loader):
        logits = hum_promptModel(inputs)
        pred = humour_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred
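
# Dispatch to the selected task model and report the detected script at the
# 100% and 90% thresholds alongside the prediction.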
def classifier(text, task):
    one_script = script_classify(text, 1.0, 1.0, True)
    pointnine_script = script_classify(text, 0.9, 0.9, True)
    if task == "Sentiment Classification":
        return sentiment(text), one_script, pointnine_script
    elif task == "Humour Detection":
        return humour(text), one_script, pointnine_script
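
# Gradio UI: textbox + task selector in; predicted label and the two script
# diagnostics out.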
demo = gr.Interface(
    title="Use of Prompt-Based Learning For Code-Mixed Text Classification",
    fn=classifier,
    inputs=[
        gr.Textbox(placeholder="Enter an input sentence...", label="Input Sentence"),
        gr.Radio(["Sentiment Classification", "Humour Detection"], label="Task"),
    ],
    outputs=[
        gr.Label(label="Label"),
        gr.Textbox(label="Script Threshold 100%"),
        gr.Textbox(label="Script Threshold 90%"),
    ],
    allow_flagging="never",
    examples=[
        ["Mama kamathi cricket matches balanna", "Sentiment Classification"],
        ["මම sweet food වලට කැමති නෑ", "Sentiment Classification"],
        ["The weather outside is neither too hot nor too cold", "Sentiment Classification"],
        ["ඉබ්බයි හාවයි හොඳ යාලුවොලු", "Humour Detection"],
        ["Kandy ගොඩක් lassanai", "Humour Detection"],
    ],
)

demo.launch()