cmgchess committed
Commit 48ae672 · 1 Parent(s): 11764ac
Files changed (2)
  1. app.py +260 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,260 @@
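# Gradio demo for prompt-based classification of Sinhala-English code-mixed
# text: sentiment (4-way) and humour detection (binary), built on OpenPrompt
# with an XLM-R backbone and task-specific Houlsby adapters.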
import gradio as gr
import torch
import numpy as np
import pandas as pd
import copy
import re
import torch.nn.functional as F
from collections import defaultdict
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.prompts import MixedTemplate, SoftVerbalizer
from transformers import AdamW, get_linear_schedule_with_warmup, XLMRobertaConfig, XLMRobertaTokenizer, XLMRobertaModel, XLMRobertaForMaskedLM, set_seed, AdapterConfig
from openprompt.plms.utils import TokenizerWrapper
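# ---- Script identification helpers ----
# Each input is also labelled by writing system (Latin / Sinhala / Mixed /
# Symbol) from the share of its characters that fall inside the Sinhala
# Unicode block, optionally ignoring combining vowel signs and modifiers.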
def check_only_numbers(string):
    return string.isdigit()

def remove_symbols_and_numbers(string):
    pattern = r"[-()\"#/@;:<>{}`+=~|_▁.!?,1234567890]"
    clean_string = re.sub(pattern, '', string)
    return clean_string

def is_sinhala(char):
    # https://unicode.org/charts/PDF/U0D80.pdf
    return 0x0D80 <= ord(char) <= 0x0DFF

def get_chars(word, without_si_modifiers=True):
    # Sinhala combining marks (al-lakuna and vowel signs), optionally
    # stripped before counting characters
    mods = [0x0DCA, 0x0DCF, 0x0DD0, 0x0DD1, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD5, 0x0DD6, 0x0DD7, 0x0DD8, 0x0DD9, 0x0DDA, 0x0DDB, 0x0DDC, 0x0DDD, 0x0DDE, 0x0DDF, 0x0DF2, 0x0DF3]
    if without_si_modifiers:
        return [char for char in word if ord(char) not in mods]
    else:
        return list(word)
def script_classify(text, en_thresh, si_thresh, without_si_mods):
    script = ""
    tokens = text.split()
    total_chars = 0
    latin_char_count = 0
    sin_char_count = 0
    for t in tokens:
        if check_only_numbers(t):
            continue
        token_list = get_chars(remove_symbols_and_numbers(t), without_si_modifiers=without_si_mods)
        total_chars += len(token_list)
        for ch in token_list:
            if is_sinhala(ch):
                sin_char_count += 1
            else:
                latin_char_count += 1
    if total_chars == 0:
        script = 'Symbol'
    else:
        en_percentage = latin_char_count / total_chars
        si_percentage = sin_char_count / total_chars
        if en_percentage >= en_thresh:
            script = 'Latin'
        elif si_percentage >= si_thresh:
            script = 'Sinhala'
        else:
            script = 'Mixed'
    return script
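# With both thresholds at 1.0, a sentence must be written entirely in one
# script to be labelled 'Latin' or 'Sinhala'; at 0.9, up to 10% of characters
# may come from the other script. Sentences below both thresholds are
# 'Mixed', and digit/symbol-only input is 'Symbol'.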
HUMOUR_MODEL_PATH = 'ad-houlsby-humour-seed-42.ckpt'
SENTIMENT_MODEL_PATH = 'ad-drop-houlsby-11-sentiment-seed-42.ckpt'

humour_mapping = {
    0: "Non-humourous",
    1: "Humourous"
}

sentiment_mapping = {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
    3: "Conflict"
}
def load_plm(model_name, model_path):
    model_config = XLMRobertaConfig.from_pretrained(model_path)
    model = XLMRobertaForMaskedLM.from_pretrained(model_path, config=model_config)
    tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)
    wrapper = MLMTokenizerWrapper  # defined below; the name is resolved at call time
    return model, tokenizer, wrapper
class MLMTokenizerWrapper(TokenizerWrapper):
    add_input_keys = ['input_ids', 'attention_mask', 'token_type_ids']

    @property
    def mask_token(self):
        return self.tokenizer.mask_token

    @property
    def mask_token_ids(self):
        return self.tokenizer.mask_token_id

    @property
    def num_special_tokens_to_add(self):
        if not hasattr(self, '_num_specials'):
            self._num_specials = self.tokenizer.num_special_tokens_to_add()
        return self._num_specials

    def tokenize_one_example(self, wrapped_example, teacher_forcing):
        wrapped_example, others = wrapped_example
        encoded_tgt_text = []
        if 'tgt_text' in others:
            tgt_text = others['tgt_text']
            if isinstance(tgt_text, str):
                tgt_text = [tgt_text]
            for t in tgt_text:
                encoded_tgt_text.append(self.tokenizer.encode(t, add_special_tokens=False))

        mask_id = 0  # index of the i-th mask token in the template

        encoder_inputs = defaultdict(list)
        for piece in wrapped_example:
            if piece['loss_ids'] == 1:
                if teacher_forcing:  # fill the mask with the target text
                    raise RuntimeError("Masked Language Model can't perform teacher forcing training!")
                else:
                    encode_text = [self.mask_token_ids]
                mask_id += 1

            if piece['text'] in self.special_tokens_maps.keys():
                to_replace = self.special_tokens_maps[piece['text']]
                if to_replace is not None:
                    piece['text'] = to_replace
                else:
                    raise KeyError("This tokenizer doesn't specify {} token.".format(piece['text']))

            if 'soft_token_ids' in piece and piece['soft_token_ids'] != 0:
                encode_text = [0]  # can be replaced by any token id, since soft tokens use their own embeddings
            else:
                encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False)

            encoding_length = len(encode_text)
            encoder_inputs['input_ids'].append(encode_text)
            for key in piece:
                if key not in ['text']:
                    encoder_inputs[key].append([piece[key]] * encoding_length)

        encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
        # delete shortenable ids
        encoder_inputs.pop("shortenable_ids")
        encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
        encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
        # create special input ids
        encoder_inputs['attention_mask'] = [1] * len(encoder_inputs['input_ids'])
        if self.create_token_type_ids:
            encoder_inputs['token_type_ids'] = [0] * len(encoder_inputs['input_ids'])
        # padding
        encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)

        if len(encoded_tgt_text) > 0:
            encoder_inputs = {**encoder_inputs, "encoded_tgt_text": encoded_tgt_text}  # convert defaultdict to dict
        else:
            encoder_inputs = {**encoder_inputs}
        return encoder_inputs
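# ---- Model setup ----
# Two copies of the same XLM-R backbone are used: the sentiment model adds a
# Houlsby adapter with the layer-11 adapter dropped (AdapterDrop), the humour
# model a full Houlsby adapter. Fine-tuned prompt-model weights are then
# loaded from the checkpoints above.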
plm, tokenizer, wrapper_class = load_plm("xlm", "xlm-roberta-base")
plm_copy = copy.deepcopy(plm)
tokenizer_copy = copy.deepcopy(tokenizer)
wrapper_class_copy = copy.deepcopy(wrapper_class)

sent_adapter_name = "Task_Sentiment"
sent_adapter_config = AdapterConfig.load("houlsby")
sent_adapter_config.leave_out.extend([11])  # AdapterDrop: no adapter in layer 11
plm.add_adapter(sent_adapter_name, config=sent_adapter_config)
plm.set_active_adapters(sent_adapter_name)
plm.train_adapter(sent_adapter_name)
sent_template = '{"placeholder": "text_a"}. {"soft": "The"} {"soft": "sentiment"} {"soft": "or"} {"soft": "the"} {"soft": "feeling"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "can"} {"soft": "be"} {"soft": "classified"} {"soft": "as"} {"soft": "positive"} {"soft": ","} {"soft": "negative"} {"soft": "or"} {"soft": "neutral"} {"soft": "."} {"soft": "The"} {"soft": "classified"} {"soft": "sentiment"} {"soft": "of"} {"soft": "the"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
sent_promptTemplate = MixedTemplate(model=plm, text=sent_template, tokenizer=tokenizer)
sent_promptVerbalizer = SoftVerbalizer(tokenizer, plm, num_classes=4)
sent_promptModel = PromptForClassification(template=sent_promptTemplate, plm=plm, verbalizer=sent_promptVerbalizer)
sent_promptModel.load_state_dict(torch.load(SENTIMENT_MODEL_PATH, map_location=torch.device('cpu')))
sent_promptModel.eval()

hum_adapter_name = "Ad_Humour"
hum_adapter_config = AdapterConfig.load("houlsby")
plm_copy.add_adapter(hum_adapter_name, config=hum_adapter_config)
plm_copy.set_active_adapters(hum_adapter_name)
plm_copy.train_adapter(hum_adapter_name)
hum_template = '{"placeholder": "text_a"}. {"soft": "Capture"} {"soft": "the"} {"soft": "comedic"} {"soft": "elements"} {"soft": "of"} {"soft": "the"} {"soft": "given"} {"soft": "sentence"} {"soft": "and"} {"soft": "classify"} {"soft": "as"} {"soft": "Humorous"} {"soft": ","} {"soft": "otherwise"} {"soft": "classify"} {"soft": "as"} {"soft": "Non-humorous"} {"soft": "."} {"soft": "The"} {"soft": "sentence"} {"soft": "is"} {"mask"}.'
hum_promptTemplate = MixedTemplate(model=plm_copy, text=hum_template, tokenizer=tokenizer_copy)
hum_promptVerbalizer = SoftVerbalizer(tokenizer_copy, plm_copy, num_classes=2)
hum_promptModel = PromptForClassification(template=hum_promptTemplate, plm=plm_copy, verbalizer=hum_promptVerbalizer)
hum_promptModel.load_state_dict(torch.load(HUMOUR_MODEL_PATH, map_location=torch.device('cpu')))
hum_promptModel.eval()
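# ---- Inference ----
# Each call wraps the input sentence in a single-example PromptDataLoader,
# runs the prompt model, and maps the argmax class index to its label.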
def sentiment(text):
    pred = None
    dataset = [
        InputExample(
            guid=0,
            text_a=text,
        )
    ]
    data_loader = PromptDataLoader(
        dataset=dataset,
        tokenizer=tokenizer,
        template=sent_promptTemplate,
        tokenizer_wrapper_class=wrapper_class,
    )
    for step, inputs in enumerate(data_loader):
        logits = sent_promptModel(inputs)
        pred = sentiment_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred

def humour(text):
    pred = None
    dataset = [
        InputExample(
            guid=0,
            text_a=text,
        )
    ]
    data_loader = PromptDataLoader(
        dataset=dataset,
        tokenizer=tokenizer_copy,
        template=hum_promptTemplate,
        tokenizer_wrapper_class=wrapper_class_copy,
    )
    for step, inputs in enumerate(data_loader):
        logits = hum_promptModel(inputs)
        pred = humour_mapping[torch.argmax(logits, dim=-1).cpu().tolist()[0]]
    return pred
def classifier(text, task):
    one_script = script_classify(text, 1.0, 1.0, True)
    pointnine_script = script_classify(text, 0.9, 0.9, True)
    if task == "Sentiment Classification":
        return sentiment(text), one_script, pointnine_script
    elif task == "Humour Detection":
        return humour(text), one_script, pointnine_script
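# ---- Gradio UI ----
# The classifier returns the task label plus the script label at the 100%
# and 90% thresholds, shown in three separate output components.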
demo = gr.Interface(
    title="Use of Prompt-Based Learning For Code-Mixed Text Classification",
    fn=classifier,
    inputs=[
        gr.Textbox(placeholder="Enter an input sentence...", label="Input Sentence"),
        gr.Radio(["Sentiment Classification", "Humour Detection"], label="Task")
    ],
    outputs=[
        gr.Label(label="Label"),
        gr.Textbox(label="Script Threshold 100%"),
        gr.Textbox(label="Script Threshold 90%")
    ],
    allow_flagging="never",
    examples=[
        ["Mama kamathi cricket matches balanna", "Sentiment Classification"],
        ["මම sweet food වලට කැමති නෑ", "Sentiment Classification"],
        ["The weather outside is neither too hot nor too cold", "Sentiment Classification"],
        ["ඉබ්බයි හාවයි හොඳ යාලුවොලු", "Humour Detection"],
        ["Kandy ගොඩක් lassanai", "Humour Detection"]
    ])

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
gradio
torch
numpy
pandas
openprompt
transformers
adapter-transformers==3.1.0
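# Note (assumption): adapter-transformers is a drop-in fork that installs
# under the `transformers` module name, so listing both packages relies on
# pip resolving them to a compatible pair.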