# Spaces: Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import copy | |
| import torch.nn.functional as F | |
| from collections import defaultdict | |
| from openprompt import PromptDataLoader, PromptForClassification | |
| from openprompt.data_utils import InputExample | |
| from openprompt.prompts import MixedTemplate, SoftVerbalizer | |
| from transformers import AdamW, get_linear_schedule_with_warmup, XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaModel, XLMRobertaForMaskedLM, set_seed, AdapterConfig | |
| from openprompt.plms.utils import TokenizerWrapper | |
| import re | |
def check_only_numbers(string):
    """Report whether ``string`` consists solely of digit characters.

    Delegates to ``str.isdigit``, so the empty string yields ``False``.
    """
    only_digits = string.isdigit()
    return only_digits
def remove_symbols_and_numbers(string):
    """Return ``string`` with the listed symbols and ASCII digits removed.

    Every character in the bracketed class below (punctuation marks, the
    ``▁`` character and the digits 0-9) is deleted; all other characters are
    kept in their original order.
    """
    stripper = re.compile(r"[-()\"#/@;:<>{}`+=~|_▁.!?,1234567890]")
    return stripper.sub('', string)
def is_sinhala(char):
    """Tell whether ``char`` falls in the code point range U+0D80..U+0DFF."""
    # Range taken from the Sinhala code chart:
    # https://unicode.org/charts/PDF/U0D80.pdf
    return 0x0D80 <= ord(char) <= 0x0DFF
# Code points that get_chars() drops when asked to ignore Sinhala modifiers.
# Built once as a frozenset so the per-character membership test is O(1)
# instead of scanning a list that was rebuilt on every call.
_SI_MODIFIER_CODEPOINTS = frozenset((
    0x0DCA,
    0x0DCF, 0x0DD0, 0x0DD1, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD5, 0x0DD6,
    0x0DD7, 0x0DD8, 0x0DD9, 0x0DDA, 0x0DDB, 0x0DDC, 0x0DDD, 0x0DDE,
    0x0DDF,
    0x0DF2, 0x0DF3,
))


def get_chars(word, without_si_modifiers = True):
    """Split ``word`` into a list of its characters.

    Args:
        word: The string to split.
        without_si_modifiers: When true (the default), characters whose code
            point is in ``_SI_MODIFIER_CODEPOINTS`` are left out of the
            result; when false every character is returned.

    Returns:
        A new list of single-character strings in their original order.
    """
    if without_si_modifiers:
        return [char for char in word if ord(char) not in _SI_MODIFIER_CODEPOINTS]
    return list(word)
def script_classify(text,en_thresh,si_thresh,without_si_mods):
    """Label the script of ``text`` as 'Latin', 'Sinhala', 'Mixed' or 'Symbol'.

    The text is split on whitespace. Tokens made only of digits are skipped;
    every other token is stripped of symbols/digits and broken into
    characters. Each remaining character counts as Sinhala when
    ``is_sinhala`` says so and as Latin otherwise.

    Args:
        text: Input sentence.
        en_thresh: Minimum share of Latin-counted characters for 'Latin'.
        si_thresh: Minimum share of Sinhala characters for 'Sinhala'.
        without_si_mods: Forwarded to ``get_chars`` as ``without_si_modifiers``.

    Returns:
        'Symbol' when no characters were counted, otherwise 'Latin',
        'Sinhala' or 'Mixed' (the Latin threshold is checked first).
    """
    latin_total = 0
    sinhala_total = 0
    for token in text.split():
        if check_only_numbers(token):
            continue
        chars = get_chars(
            remove_symbols_and_numbers(token),
            without_si_modifiers = without_si_mods,
        )
        sinhala_here = sum(1 for ch in chars if is_sinhala(ch))
        sinhala_total += sinhala_here
        latin_total += len(chars) - sinhala_here

    # Every counted character is either Sinhala or Latin, so the two tallies
    # add up to the overall character count.
    counted = latin_total + sinhala_total
    if counted == 0:
        return 'Symbol'

    latin_share = latin_total/counted
    sinhala_share = sinhala_total/counted
    if latin_share >= en_thresh:
        return 'Latin'
    if sinhala_share >= si_thresh:
        return 'Sinhala'
    if latin_share < en_thresh and sinhala_share < si_thresh:
        return 'Mixed'
    return ""
# Checkpoint files passed to torch.load() when the two prompt models are built.
HUMOUR_MODEL_PATH = 'ad-houlsby-humour-seed-42.ckpt'
SENTIMENT_MODEL_PATH = 'ad-drop-houlsby-11-sentiment-seed-42.ckpt'

# Predicted class index -> label shown for the humour task.
humour_mapping = {
    0: "Non-humourous",
    1: "Humourous",
}

# Predicted class index -> label shown for the sentiment task.
sentiment_mapping = {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
    3: "Conflict",
}
def load_plm(model_name, model_path):
    """Load the XLM-RoBERTa masked-LM, its tokenizer and the wrapper class.

    Args:
        model_name: Accepted for call compatibility; not consulted here.
        model_path: Name or path handed to every ``from_pretrained`` call.

    Returns:
        A ``(model, tokenizer, wrapper)`` tuple where ``wrapper`` is the
        ``MLMTokenizerWrapper`` class itself (not an instance).
    """
    config = XLMRobertaConfig.from_pretrained(model_path)
    return (
        XLMRobertaForMaskedLM.from_pretrained(model_path, config=config),
        XLMRobertaTokenizer.from_pretrained(model_path),
        MLMTokenizerWrapper,
    )
class MLMTokenizerWrapper(TokenizerWrapper):
    """Tokenizer wrapper that encodes a wrapped prompt example for a masked LM.

    ``mask_token``, ``mask_token_ids`` and ``num_special_tokens_to_add`` are
    exposed as properties: ``tokenize_one_example`` reads
    ``self.mask_token_ids`` as a plain attribute, so a bare method here would
    hand back a bound method object instead of the id.
    """

    add_input_keys = ['input_ids', 'attention_mask', 'token_type_ids']

    @property
    def mask_token(self):
        """Mask token string of the wrapped tokenizer."""
        return self.tokenizer.mask_token

    @property
    def mask_token_ids(self):
        """Mask token id of the wrapped tokenizer."""
        return self.tokenizer.mask_token_id

    @property
    def num_special_tokens_to_add(self):
        """Tokenizer's special-token count, computed once and then cached."""
        if not hasattr(self, '_num_specials'):
            self._num_specials = self.tokenizer.num_special_tokens_to_add()
        return self._num_specials

    def tokenize_one_example(self, wrapped_example, teacher_forcing):
        """Encode one wrapped example into padded model inputs.

        Args:
            wrapped_example: A ``(pieces, others)`` pair. ``pieces`` is an
                iterable of dicts carrying at least ``'text'`` and
                ``'loss_ids'`` (optionally ``'soft_token_ids'``); a
                ``'shortenable_ids'`` entry is expected because it is popped
                unconditionally below. ``others`` is a dict that may hold
                ``'tgt_text'`` as a string or a list of strings.
            teacher_forcing: Must be falsy whenever a piece has
                ``loss_ids == 1``.

        Returns:
            A plain dict with ``'input_ids'``, ``'attention_mask'``, the
            remaining per-piece keys, ``'token_type_ids'`` when
            ``self.create_token_type_ids`` is set, and ``'encoded_tgt_text'``
            when target text was supplied.

        Raises:
            RuntimeError: If ``teacher_forcing`` is truthy for a loss piece.
            KeyError: If a piece's text maps to ``None`` in
                ``self.special_tokens_maps``.
        """
        wrapped_example, others = wrapped_example

        encoded_tgt_text = []
        if 'tgt_text' in others:
            tgt_text = others['tgt_text']
            if isinstance(tgt_text, str):
                tgt_text = [tgt_text]
            for t in tgt_text:
                encoded_tgt_text.append(self.tokenizer.encode(t, add_special_tokens=False))

        mask_id = 0 # the i-th the mask token in the template.

        encoder_inputs = defaultdict(list)
        for piece in wrapped_example:
            if piece['loss_ids']==1:
                if teacher_forcing: # fill the mask with the tgt task
                    raise RuntimeError("Masked Language Model can't perform teacher forcing training!")
                else:
                    # NOTE: superseded by the encoding chosen further down.
                    encode_text = [self.mask_token_ids]
                mask_id += 1

            # Swap a placeholder special-token name for the tokenizer's own token.
            if piece['text'] in self.special_tokens_maps.keys():
                to_replace = self.special_tokens_maps[piece['text']]
                if to_replace is not None:
                    piece['text'] = to_replace
                else:
                    raise KeyError("This tokenizer doesn't specify {} token.".format(piece['text']))

            if 'soft_token_ids' in piece and piece['soft_token_ids']!=0:
                encode_text = [0] # can be replace by any token, since these token will use their own embeddings
            else:
                encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False)

            encoding_length = len(encode_text)
            encoder_inputs['input_ids'].append(encode_text)
            # Repeat every other per-piece field once per produced token so
            # all entries stay aligned with input_ids.
            for key in piece:
                if key not in ['text']:
                    encoder_inputs[key].append([piece[key]]*encoding_length)

        encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
        # delete shortenable ids
        encoder_inputs.pop("shortenable_ids")
        encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
        encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
        # create special input ids
        encoder_inputs['attention_mask'] = [1] *len(encoder_inputs['input_ids'])
        if self.create_token_type_ids:
            encoder_inputs['token_type_ids'] = [0] *len(encoder_inputs['input_ids'])
        # padding
        encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)

        if len(encoded_tgt_text) > 0:
            encoder_inputs = {**encoder_inputs, "encoded_tgt_text": encoded_tgt_text}# convert defaultdict to dict
        else:
            encoder_inputs = {**encoder_inputs}
        return encoder_inputs
# Load the backbone once, then deep-copy each piece so the humour pipeline
# gets its own objects, separate from the ones the sentiment pipeline uses.
plm, tokenizer, wrapper_class = load_plm("xlm", "xlm-roberta-base")
plm_copy, tokenizer_copy, wrapper_class_copy = (
    copy.deepcopy(plm),
    copy.deepcopy(tokenizer),
    copy.deepcopy(wrapper_class),
)
# --- Sentiment pipeline (built on `plm`) ------------------------------------
# Statement order matters: the adapter is attached and activated before the
# template, verbalizer and prompt model are constructed around `plm`.
sent_adapter_name = "Task_Sentiment"
sent_adapter_config = AdapterConfig.load("houlsby")
sent_adapter_config.leave_out.append(11)  # layer index 11 is added to leave_out
plm.add_adapter(sent_adapter_name, config=sent_adapter_config)
plm.set_active_adapters(sent_adapter_name)
plm.train_adapter(sent_adapter_name)

sent_template = '{"placeholder": "text_a"}. {"soft": "The"} {"soft": "sentiment"} {"soft": "or"} {"soft": "the"} {"soft": "feeling"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "can"} {"soft": "be"} {"soft": "classified"} {"soft": "as"} {"soft": "positive"} {"soft": ","} {"soft": "negative"} {"soft": "or"} {"soft": "neutral"} {"soft": "."} {"soft": "The"} {"soft": "classified"} {"soft": "sentiment"} {"soft": "of"} {"soft": "the"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
sent_promptTemplate = MixedTemplate(
    model=plm,
    text=sent_template,
    tokenizer=tokenizer,
)
sent_promptVerbalizer = SoftVerbalizer(tokenizer, plm, num_classes=4)
sent_promptModel = PromptForClassification(
    template=sent_promptTemplate,
    plm=plm,
    verbalizer=sent_promptVerbalizer,
)

# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
sent_promptModel.load_state_dict(
    torch.load(SENTIMENT_MODEL_PATH, map_location=torch.device('cpu'))
)
sent_promptModel.eval()
# --- Humour pipeline (built on `plm_copy`) ----------------------------------
# Statement order matters: the adapter is attached and activated before the
# template, verbalizer and prompt model are constructed around `plm_copy`.
hum_adapter_name = "Ad_Humour"
hum_adapter_config = AdapterConfig.load("houlsby")
plm_copy.add_adapter(hum_adapter_name, config=hum_adapter_config)
plm_copy.set_active_adapters(hum_adapter_name)
plm_copy.train_adapter(hum_adapter_name)

hum_template = '{"placeholder": "text_a"}. {"soft": "Capture"} {"soft": "the"} {"soft": "comedic"} {"soft": "elements"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "and"} {"soft": "classify"} {"soft": "as"} {"soft": "Humorous"} {"soft": ","} {"soft": "otherwise"} {"soft": "classify"} {"soft": "as"} {"soft": "Non-humorous"} {"soft": "."} {"soft": "The"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
hum_promptTemplate = MixedTemplate(
    model=plm_copy,
    text=hum_template,
    tokenizer=tokenizer_copy,
)
hum_promptVerbalizer = SoftVerbalizer(tokenizer_copy, plm_copy, num_classes=2)
hum_promptModel = PromptForClassification(
    template=hum_promptTemplate,
    plm=plm_copy,
    verbalizer=hum_promptVerbalizer,
)

# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
hum_promptModel.load_state_dict(
    torch.load(HUMOUR_MODEL_PATH, map_location=torch.device('cpu'))
)
hum_promptModel.eval()
def sentiment(text):
    """Predict the sentiment label for a single sentence.

    Args:
        text: The sentence to classify; it becomes ``text_a`` of a one-item
            dataset.

    Returns:
        The ``sentiment_mapping`` label for the arg-max class of the first
        example in the last batch the loader yields, or ``None`` if the
        loader yields nothing.
    """
    data_loader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=text)],
        tokenizer=tokenizer,
        template=sent_promptTemplate,
        tokenizer_wrapper_class=wrapper_class,
    )

    pred = None
    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        for inputs in data_loader:
            logits = sent_promptModel(inputs)
            pred = sentiment_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred
def humour(text):
    """Predict the humour label for a single sentence.

    Args:
        text: The sentence to classify; it becomes ``text_a`` of a one-item
            dataset.

    Returns:
        The ``humour_mapping`` label for the arg-max class of the first
        example in the last batch the loader yields, or ``None`` if the
        loader yields nothing.
    """
    data_loader = PromptDataLoader(
        dataset=[InputExample(guid=0, text_a=text)],
        tokenizer=tokenizer_copy,
        template=hum_promptTemplate,
        tokenizer_wrapper_class=wrapper_class_copy,
    )

    pred = None
    # Pure inference: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        for inputs in data_loader:
            logits = hum_promptModel(inputs)
            pred = humour_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred
def classifier(text, task):
    """Run the selected task on ``text`` and report its script labels.

    Args:
        text: Input sentence.
        task: Either ``"Sentiment Classification"`` or ``"Humour Detection"``.

    Returns:
        A 3-tuple ``(label, script_at_100_percent, script_at_90_percent)``:
        the task prediction followed by the ``script_classify`` results for
        thresholds 1.0 and 0.9.

    Raises:
        ValueError: If ``task`` is not one of the two supported names (for
            example ``None`` when no option was selected). Previously this
            case fell through and returned ``None`` instead of three values.
    """
    if task not in ("Sentiment Classification", "Humour Detection"):
        raise ValueError(
            f"Unknown task {task!r}; expected 'Sentiment Classification' "
            "or 'Humour Detection'."
        )

    one_script = script_classify(text,1.0,1.0,True)
    pointnine_script = script_classify(text,0.9,0.9,True)
    predict = sentiment if task == "Sentiment Classification" else humour
    return predict(text), one_script, pointnine_script
# Gradio UI: a sentence box plus a task selector in, the predicted label and
# the two script-threshold results out. The interface is launched right away.
demo = gr.Interface(
    fn=classifier,
    title="Use of Prompt-Based Learning For Code-Mixed Text Classification",
    inputs=[
        gr.Textbox(placeholder="Enter an input sentence...", label="Input Sentence"),
        gr.Radio(["Sentiment Classification", "Humour Detection"], label="Task"),
    ],
    outputs=[
        gr.Label(label="Label"),
        gr.Textbox(label="Script Threshold 100%"),
        gr.Textbox(label="Script Threshold 90%"),
    ],
    examples=[
        ["Mama kamathi cricket matches balanna", "Sentiment Classification"],
        ["මම sweet food වලට කැමති නෑ", "Sentiment Classification"],
        ["The weather outside is neither too hot nor too cold", "Sentiment Classification"],
        ["ඉබ්බයි හාවයි හොඳ යාලුවොලු", "Humour Detection"],
        ["Kandy ගොඩක් lassanai", "Humour Detection"],
    ],
    allow_flagging="never",
)
demo.launch()