Commit a02c788 · Jean Garcia-Gathright committed
Parent(s): 4150cb0

added ernie files
Files changed:
- app.py +2 -2
- app.py~ +7 -0
- ernie/__init__.py +47 -0
- ernie/aggregation_strategies.py +70 -0
- ernie/ernie.py +397 -0
- ernie/helper.py +121 -0
- ernie/models.py +51 -0
- ernie/split_strategies.py +125 -0
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
-import
-import
+from ernie.ernie import SentenceClassifier
+from ernie import helper
 
 def greet(name):
     return "Hello " + name + "!!"
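The hunk above only swaps the imports; app.py does not yet call into ernie. A minimal sketch (not part of this commit) of how the new imports could be wired into the Gradio app; the model name and label count here are assumptions, not values taken from this repository:

# Sketch only: wiring SentenceClassifier into the Gradio interface.
import gradio as gr
from ernie.ernie import SentenceClassifier

# Assumed model name and label count for illustration.
classifier = SentenceClassifier(model_name='bert-base-uncased', labels_no=2)

def classify(text):
    # predict_one returns a tuple of class probabilities (softmax output)
    probabilities = classifier.predict_one(text)
    return str(probabilities)

iface = gr.Interface(fn=classify, inputs="text", outputs="text")
iface.launch()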
app.py~
ADDED
@@ -0,0 +1,7 @@
+import gradio as gr
+
+def greet(name):
+    return "Hello " + name + "!!"
+
+iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+iface.launch()
ernie/__init__.py
ADDED
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from .ernie import * # noqa: F401, F403
+from tensorflow.python.client import device_lib
+import logging
+
+__version__ = '1.0.1'
+
+logging.getLogger().setLevel(logging.WARNING)
+logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
+logging.basicConfig(
+    format='%(asctime)-15s [%(levelname)s] %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
+
+def _get_cpu_name():
+    import cpuinfo
+    cpu_info = cpuinfo.get_cpu_info()
+    cpu_name = f"{cpu_info['brand_raw']}, {cpu_info['count']} vCores"
+    return cpu_name
+
+
+def _get_gpu_name():
+    gpu_name = \
+        device_lib\
+        .list_local_devices()[3]\
+        .physical_device_desc\
+        .split(',')[1]\
+        .split('name:')[1]\
+        .strip()
+    return gpu_name
+
+
+device_name = _get_cpu_name()
+device_type = 'CPU'
+
+try:
+    device_name = _get_gpu_name()
+    device_type = 'GPU'
+except IndexError:
+    # Detect TPU
+    pass
+
+logging.info(f'ernie v{__version__}')
+logging.info(f'target device: [{device_type}] {device_name}\n')
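Note that _get_gpu_name assumes the GPU sits at index 3 of list_local_devices(), falling back to CPU via IndexError. A more defensive lookup (a sketch, not part of this commit) would scan for the first device whose type is 'GPU':

# Sketch only: scan all local devices instead of hard-coding index 3.
from tensorflow.python.client import device_lib

def find_gpu_name():
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            # physical_device_desc looks like:
            # "device: 0, name: Tesla K80, pci bus id: ..., compute capability: 3.7"
            return device.physical_device_desc \
                .split('name:')[1].split(',')[0].strip()
    return None  # no GPU available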
ernie/aggregation_strategies.py
ADDED
@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from statistics import mean
+
+
+class AggregationStrategy:
+    def __init__(
+        self,
+        method,
+        max_items=None,
+        top_items=True,
+        sorting_class_index=1
+    ):
+        self.method = method
+        self.max_items = max_items
+        self.top_items = top_items
+        self.sorting_class_index = sorting_class_index
+
+    def aggregate(self, softmax_tuples):
+        softmax_dicts = []
+        for softmax_tuple in softmax_tuples:
+            softmax_dict = {}
+            for i, probability in enumerate(softmax_tuple):
+                softmax_dict[i] = probability
+            softmax_dicts.append(softmax_dict)
+
+        if self.max_items is not None:
+            softmax_dicts = sorted(
+                softmax_dicts,
+                key=lambda x: x[self.sorting_class_index],
+                reverse=self.top_items
+            )
+            if self.max_items < len(softmax_dicts):
+                softmax_dicts = softmax_dicts[:self.max_items]
+
+        softmax_list = []
+        for key in softmax_dicts[0].keys():
+            softmax_list.append(self.method(
+                [probabilities[key] for probabilities in softmax_dicts]))
+        softmax_tuple = tuple(softmax_list)
+        return softmax_tuple
+
+
+class AggregationStrategies:
+    Mean = AggregationStrategy(method=mean)
+    MeanTopFiveBinaryClassification = AggregationStrategy(
+        method=mean,
+        max_items=5,
+        top_items=True,
+        sorting_class_index=1
+    )
+    MeanTopTenBinaryClassification = AggregationStrategy(
+        method=mean,
+        max_items=10,
+        top_items=True,
+        sorting_class_index=1
+    )
+    MeanTopFifteenBinaryClassification = AggregationStrategy(
+        method=mean,
+        max_items=15,
+        top_items=True,
+        sorting_class_index=1
+    )
+    MeanTopTwentyBinaryClassification = AggregationStrategy(
+        method=mean,
+        max_items=20,
+        top_items=True,
+        sorting_class_index=1
+    )
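For reference, a small worked example (not part of this commit) of the Mean strategy applied to per-sentence binary-classification softmax tuples:

# Illustration only: aggregate three per-sentence softmax tuples
# into one document-level tuple via the element-wise mean.
from ernie.aggregation_strategies import AggregationStrategies

sentence_probabilities = [(0.9, 0.1), (0.6, 0.4), (0.2, 0.8)]
document_probability = AggregationStrategies.Mean.aggregate(
    sentence_probabilities)
# -> (0.5666..., 0.4333...): mean of each class across the sentences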
ernie/ernie.py
ADDED
@@ -0,0 +1,397 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import pandas as pd
+from transformers import (
+    AutoTokenizer,
+    AutoModel,
+    AutoConfig,
+    TFAutoModelForSequenceClassification,
+)
+from tensorflow import keras
+from sklearn.model_selection import train_test_split
+import logging
+import time
+from .models import Models, ModelsByFamily # noqa: F401
+from .split_strategies import ( # noqa: F401
+    SplitStrategy,
+    SplitStrategies,
+    RegexExpressions
+)
+from .aggregation_strategies import ( # noqa: F401
+    AggregationStrategy,
+    AggregationStrategies
+)
+from .helper import (
+    get_features,
+    softmax,
+    remove_dir,
+    make_dir,
+    copy_dir
+)
+
+AUTOSAVE_PATH = './ernie-autosave/'
+
+
+def clean_autosave():
+    remove_dir(AUTOSAVE_PATH)
+
+
+class SentenceClassifier:
+    def __init__(self,
+                 model_name=Models.BertBaseUncased,
+                 model_path=None,
+                 max_length=64,
+                 labels_no=2,
+                 tokenizer_kwargs=None,
+                 model_kwargs=None):
+        self._loaded_data = False
+        self._model_path = None
+
+        if model_kwargs is None:
+            model_kwargs = {}
+        model_kwargs['num_labels'] = labels_no
+
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+        tokenizer_kwargs['max_len'] = max_length
+
+        if model_path is not None:
+            self._load_local_model(model_path)
+        else:
+            self._load_remote_model(model_name, tokenizer_kwargs, model_kwargs)
+
+    @property
+    def model(self):
+        return self._model
+
+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    def load_dataset(self,
+                     dataframe=None,
+                     validation_split=0.1,
+                     random_state=None,
+                     stratify=None,
+                     csv_path=None,
+                     read_csv_kwargs=None):
+
+        if dataframe is None and csv_path is None:
+            raise ValueError
+
+        if csv_path is not None:
+            dataframe = pd.read_csv(csv_path, **read_csv_kwargs)
+
+        sentences = list(dataframe[dataframe.columns[0]])
+        labels = dataframe[dataframe.columns[1]].values
+
+        (
+            training_sentences,
+            validation_sentences,
+            training_labels,
+            validation_labels
+        ) = train_test_split(
+            sentences,
+            labels,
+            test_size=validation_split,
+            shuffle=True,
+            random_state=random_state,
+            stratify=stratify
+        )
+
+        self._training_features = get_features(
+            self._tokenizer, training_sentences, training_labels)
+
+        self._training_size = len(training_sentences)
+
+        self._validation_features = get_features(
+            self._tokenizer,
+            validation_sentences,
+            validation_labels
+        )
+        self._validation_split = len(validation_sentences)
+
+        logging.info(f'training_size: {self._training_size}')
+        logging.info(f'validation_split: {self._validation_split}')
+
+        self._loaded_data = True
+
+    def fine_tune(self,
+                  epochs=4,
+                  learning_rate=2e-5,
+                  epsilon=1e-8,
+                  clipnorm=1.0,
+                  optimizer_function=keras.optimizers.Adam,
+                  optimizer_kwargs=None,
+                  loss_function=keras.losses.SparseCategoricalCrossentropy,
+                  loss_kwargs=None,
+                  accuracy_function=keras.metrics.SparseCategoricalAccuracy,
+                  accuracy_kwargs=None,
+                  training_batch_size=32,
+                  validation_batch_size=64,
+                  **kwargs):
+        if not self._loaded_data:
+            raise Exception('Data has not been loaded.')
+
+        if optimizer_kwargs is None:
+            optimizer_kwargs = {
+                'learning_rate': learning_rate,
+                'epsilon': epsilon,
+                'clipnorm': clipnorm
+            }
+        optimizer = optimizer_function(**optimizer_kwargs)
+
+        if loss_kwargs is None:
+            loss_kwargs = {'from_logits': True}
+        loss = loss_function(**loss_kwargs)
+
+        if accuracy_kwargs is None:
+            accuracy_kwargs = {'name': 'accuracy'}
+        accuracy = accuracy_function(**accuracy_kwargs)
+
+        self._model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
+
+        training_features = self._training_features.shuffle(
+            self._training_size).batch(training_batch_size).repeat(-1)
+        validation_features = self._validation_features.batch(
+            validation_batch_size)
+
+        training_steps = self._training_size // training_batch_size
+        if training_steps == 0:
+            training_steps = self._training_size
+        logging.info(f'training_steps: {training_steps}')
+
+        validation_steps = self._validation_split // validation_batch_size
+        if validation_steps == 0:
+            validation_steps = self._validation_split
+        logging.info(f'validation_steps: {validation_steps}')
+
+        for i in range(epochs):
+            self._model.fit(training_features,
+                            epochs=1,
+                            validation_data=validation_features,
+                            steps_per_epoch=training_steps,
+                            validation_steps=validation_steps,
+                            **kwargs)
+
+        # The fine-tuned model does not have the same input interface
+        # after being exported and loaded again.
+        self._reload_model()
+
+    def predict_one(
+        self,
+        text,
+        split_strategy=None,
+        aggregation_strategy=None
+    ):
+        return next(
+            self.predict([text],
+                         batch_size=1,
+                         split_strategy=split_strategy,
+                         aggregation_strategy=aggregation_strategy))
+
+    def predict(
+        self,
+        texts,
+        batch_size=32,
+        split_strategy=None,
+        aggregation_strategy=None
+    ):
+        if split_strategy is None:
+            yield from self._predict_batch(texts, batch_size)
+
+        else:
+            if aggregation_strategy is None:
+                aggregation_strategy = AggregationStrategies.Mean
+
+            split_indexes = [0]
+            sentences = []
+            for text in texts:
+                new_sentences = split_strategy.split(text, self.tokenizer)
+                if not new_sentences:
+                    continue
+                split_indexes.append(split_indexes[-1] + len(new_sentences))
+                sentences.extend(new_sentences)
+
+            predictions = list(self._predict_batch(sentences, batch_size))
+            for i, split_index in enumerate(split_indexes[:-1]):
+                stop_index = split_indexes[i + 1]
+                yield aggregation_strategy.aggregate(
+                    predictions[split_index:stop_index]
+                )
+
+    def dump(self, path):
+        if self._model_path:
+            copy_dir(self._model_path, path)
+        else:
+            self._dump(path)
+
+    def _dump(self, path):
+        make_dir(path)
+        make_dir(path + '/tokenizer')
+        self._model.save_pretrained(path)
+        self._tokenizer.save_pretrained(path + '/tokenizer')
+        self._config.save_pretrained(path + '/tokenizer')
+
+    def _predict_batch(self, sentences: list, batch_size: int):
+        sentences_number = len(sentences)
+        if batch_size > sentences_number:
+            batch_size = sentences_number
+
+        for i in range(0, sentences_number, batch_size):
+            input_ids_list = []
+            attention_mask_list = []
+
+            stop_index = i + batch_size
+            stop_index = stop_index if stop_index < sentences_number \
+                else sentences_number
+
+            for j in range(i, stop_index):
+                features = self._tokenizer.encode_plus(
+                    sentences[j],
+                    add_special_tokens=True,
+                    max_length=self._tokenizer.model_max_length
+                )
+                input_ids, _, attention_mask = (
+                    features['input_ids'],
+                    features['token_type_ids'],
+                    features['attention_mask']
+                )
+
+                input_ids = self._list_to_padded_array(features['input_ids'])
+                attention_mask = self._list_to_padded_array(
+                    features['attention_mask'])
+
+                input_ids_list.append(input_ids)
+                attention_mask_list.append(attention_mask)
+
+            input_dict = {
+                'input_ids': np.array(input_ids_list),
+                'attention_mask': np.array(attention_mask_list)
+            }
+            logit_predictions = self._model.predict_on_batch(input_dict)
+            yield from (
+                [softmax(logit_prediction)
+                 for logit_prediction in logit_predictions[0]]
+            )
+
+    def _list_to_padded_array(self, items):
+        array = np.array(items)
+        padded_array = np.zeros(self._tokenizer.model_max_length, dtype=np.int)
+        padded_array[:array.shape[0]] = array
+        return padded_array
+
+    def _get_temporary_path(self, name=''):
+        return f'{AUTOSAVE_PATH}{name}/{int(round(time.time() * 1000))}'
+
+    def _reload_model(self):
+        self._model_path = self._get_temporary_path(
+            name=self._get_model_family())
+        self._dump(self._model_path)
+        self._load_local_model(self._model_path)
+
+    def _load_local_model(self, model_path):
+        try:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                model_path + '/tokenizer')
+            self._config = AutoConfig.from_pretrained(
+                model_path + '/tokenizer')
+
+        # Old models didn't use to have a tokenizer folder
+        except OSError:
+            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
+            self._config = AutoConfig.from_pretrained(model_path)
+        self._model = TFAutoModelForSequenceClassification.from_pretrained(
+            model_path,
+            from_pt=False
+        )
+
+    def _get_model_family(self):
+        model_family = ''.join(self._model.name[2:].split('_')[:2])
+        return model_family
+
+    def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
+        do_lower_case = False
+        if 'uncased' in model_name.lower():
+            do_lower_case = True
+        tokenizer_kwargs.update({'do_lower_case': do_lower_case})
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            model_name, **tokenizer_kwargs)
+        self._config = AutoConfig.from_pretrained(model_name)
+
+        temporary_path = self._get_temporary_path()
+        make_dir(temporary_path)
+
+        # TensorFlow model
+        try:
+            self._model = TFAutoModelForSequenceClassification.from_pretrained(
+                model_name,
+                from_pt=False
+            )
+
+        # PyTorch model
+        except TypeError:
+            try:
+                self._model = \
+                    TFAutoModelForSequenceClassification.from_pretrained(
+                        model_name,
+                        from_pt=True
+                    )
+
+            # Loading a TF model from a PyTorch checkpoint is not supported
+            # when using a model identifier name
+            except OSError:
+                model = AutoModel.from_pretrained(model_name)
+                model.save_pretrained(temporary_path)
+                self._model = \
+                    TFAutoModelForSequenceClassification.from_pretrained(
+                        temporary_path,
+                        from_pt=True
+                    )
+
+        # Clean the model's last layer if the provided properties are different
+        clean_last_layer = False
+        for key, value in model_kwargs.items():
+            if not hasattr(self._model.config, key):
+                clean_last_layer = True
+                break
+
+            if getattr(self._model.config, key) != value:
+                clean_last_layer = True
+                break
+
+        if clean_last_layer:
+            try:
+                getattr(self._model, self._get_model_family()
+                        ).save_pretrained(temporary_path)
+                self._model = self._model.__class__.from_pretrained(
+                    temporary_path,
+                    from_pt=False,
+                    **model_kwargs
+                )
+
+            # The model is itself the main layer
+            except AttributeError:
+                # TensorFlow model
+                try:
+                    self._model = self._model.__class__.from_pretrained(
+                        model_name,
+                        from_pt=False,
+                        **model_kwargs
+                    )
+
+                # PyTorch Model
+                except (OSError, TypeError):
+                    model = AutoModel.from_pretrained(model_name)
+                    model.save_pretrained(temporary_path)
+                    self._model = self._model.__class__.from_pretrained(
+                        temporary_path,
+                        from_pt=True,
+                        **model_kwargs
+                    )
+
+        remove_dir(temporary_path)
+        assert self._tokenizer and self._model
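For context, a minimal end-to-end sketch (not part of this commit) following the usage pattern of the upstream ernie library; the two-row toy dataset is invented for illustration:

# Sketch only: fine-tune the classifier on a toy DataFrame and predict.
import pandas as pd
from ernie.ernie import SentenceClassifier, Models

# First column: sentences; second column: integer labels.
df = pd.DataFrame([
    ("This is a positive example. I'm very happy today.", 1),
    ("This is a negative sentence. Everything was wrong today.", 0),
])

classifier = SentenceClassifier(
    model_name=Models.BertBaseUncased, max_length=64, labels_no=2)
classifier.load_dataset(df, validation_split=0.2)
classifier.fine_tune(epochs=4, learning_rate=2e-5,
                     training_batch_size=32, validation_batch_size=64)

# Returns a tuple of class probabilities for the single input text.
probabilities = classifier.predict_one("Oh, what a terrible day.")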
ernie/helper.py
ADDED
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from tensorflow import data, TensorShape, int64, int32
+from math import exp
+from os import makedirs
+from shutil import rmtree, move, copytree
+from huggingface_hub import hf_hub_download
+import os
+
+
+def get_features(tokenizer, sentences, labels):
+    features = []
+    for i, sentence in enumerate(sentences):
+        inputs = tokenizer.encode_plus(
+            sentence,
+            add_special_tokens=True,
+            max_length=tokenizer.model_max_length
+        )
+        input_ids, token_type_ids = \
+            inputs['input_ids'], inputs['token_type_ids']
+        padding_length = tokenizer.model_max_length - len(input_ids)
+
+        if tokenizer.padding_side == 'right':
+            attention_mask = [1] * len(input_ids) + [0] * padding_length
+            input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
+            token_type_ids = token_type_ids + \
+                [tokenizer.pad_token_type_id] * padding_length
+        else:
+            attention_mask = [0] * padding_length + [1] * len(input_ids)
+            input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
+            token_type_ids = \
+                [tokenizer.pad_token_type_id] * padding_length + token_type_ids
+
+        assert tokenizer.model_max_length \
+            == len(attention_mask) \
+            == len(input_ids) \
+            == len(token_type_ids)
+
+        feature = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'token_type_ids': token_type_ids,
+            'label': int(labels[i])
+        }
+
+        features.append(feature)
+
+    def gen():
+        for feature in features:
+            yield (
+                {
+                    'input_ids': feature['input_ids'],
+                    'attention_mask': feature['attention_mask'],
+                    'token_type_ids': feature['token_type_ids'],
+                },
+                feature['label'],
+            )
+
+    dataset = data.Dataset.from_generator(
+        gen,
+        ({
+            'input_ids': int32,
+            'attention_mask': int32,
+            'token_type_ids': int32
+        }, int64),
+        (
+            {
+                'input_ids': TensorShape([None]),
+                'attention_mask': TensorShape([None]),
+                'token_type_ids': TensorShape([None]),
+            },
+            TensorShape([]),
+        ),
+    )
+
+    return dataset
+
+
+def softmax(values):
+    exps = [exp(value) for value in values]
+    exps_sum = sum(exp_value for exp_value in exps)
+    return tuple(map(lambda x: x / exps_sum, exps))
+
+
+def make_dir(path):
+    try:
+        makedirs(path)
+    except FileExistsError:
+        pass
+
+
+def remove_dir(path):
+    rmtree(path)
+
+
+def copy_dir(source_path, target_path):
+    copytree(source_path, target_path)
+
+
+def move_dir(source_path, target_path):
+    move(source_path, target_path)
+
+def download_from_hub(repo_id, filename, revision=None, cache_dir=None):
+    try:
+        hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
+    except Exception as exp:
+        raise exp
+
+
+    if cache_dir is not None:
+
+        files = os.listdir(cache_dir)
+
+        for f in files:
+            if '.lock' in f:
+                name = f[0:-5]
+
+                os.rename(cache_dir+name, cache_dir+filename)
+                os.remove(cache_dir+name+'.lock')
+                os.remove(cache_dir+name+'.json')
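A quick sanity check (not part of this commit) of the pure-Python softmax above:

# Illustration only: softmax normalizes logits into probabilities.
from ernie.helper import softmax

print(softmax((1.0, 2.0, 3.0)))
# -> approximately (0.0900, 0.2447, 0.6652); the outputs sum to 1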
ernie/models.py
ADDED
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+
+class Models:
+    BertBaseUncased = 'bert-base-uncased'
+    BertBaseCased = 'bert-base-cased'
+    BertLargeUncased = 'bert-large-uncased'
+    BertLargeCased = 'bert-large-cased'
+
+    RobertaBaseCased = 'roberta-base'
+    RobertaLargeCased = 'roberta-large'
+
+    XLNetBaseCased = 'xlnet-base-cased'
+    XLNetLargeCased = 'xlnet-large-cased'
+
+    DistilBertBaseUncased = 'distilbert-base-uncased'
+    DistilBertBaseMultilingualCased = 'distilbert-base-multilingual-cased'
+
+    AlbertBaseCased = 'albert-base-v1'
+    AlbertLargeCased = 'albert-large-v1'
+    AlbertXLargeCased = 'albert-xlarge-v1'
+    AlbertXXLargeCased = 'albert-xxlarge-v1'
+
+    AlbertBaseCased2 = 'albert-base-v2'
+    AlbertLargeCased2 = 'albert-large-v2'
+    AlbertXLargeCased2 = 'albert-xlarge-v2'
+    AlbertXXLargeCased2 = 'albert-xxlarge-v2'
+
+
+class ModelsByFamily:
+    Bert = set([Models.BertBaseUncased, Models.BertBaseCased,
+                Models.BertLargeUncased, Models.BertLargeCased])
+    Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased])
+    XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased])
+    DistilBert = set([Models.DistilBertBaseUncased,
+                      Models.DistilBertBaseMultilingualCased])
+    Albert = set([
+        Models.AlbertBaseCased,
+        Models.AlbertLargeCased,
+        Models.AlbertXLargeCased,
+        Models.AlbertXXLargeCased,
+        Models.AlbertBaseCased2,
+        Models.AlbertLargeCased2,
+        Models.AlbertXLargeCased2,
+        Models.AlbertXXLargeCased2
+    ])
+    Supported = set([
+        getattr(Models, model_type) for model_type
+        in filter(lambda x: x[:2] != '__', Models.__dict__.keys())
+    ])
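Illustration only (not part of this commit): the constants are plain Hugging Face model-identifier strings, and ModelsByFamily.Supported collects every constant defined on Models:

# Sketch only: the constants feed straight into SentenceClassifier.
from ernie.models import Models, ModelsByFamily

assert Models.BertBaseUncased == 'bert-base-uncased'
assert Models.RobertaBaseCased in ModelsByFamily.Roberta
assert 'xlnet-large-cased' in ModelsByFamily.Supported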
ernie/split_strategies.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RegexExpressions:
|
| 8 |
+
split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
|
| 9 |
+
split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
|
| 10 |
+
split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
|
| 11 |
+
split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')
|
| 12 |
+
|
| 13 |
+
url = re.compile(
|
| 14 |
+
r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
|
| 15 |
+
r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
|
| 16 |
+
)
|
| 17 |
+
domain = re.compile(r'\w+\.\w+')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SplitStrategy:
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
split_patterns,
|
| 24 |
+
remove_patterns=None,
|
| 25 |
+
group_splits=True,
|
| 26 |
+
remove_too_short_groups=True
|
| 27 |
+
):
|
| 28 |
+
if not isinstance(split_patterns, list):
|
| 29 |
+
self.split_patterns = [split_patterns]
|
| 30 |
+
else:
|
| 31 |
+
self.split_patterns = split_patterns
|
| 32 |
+
|
| 33 |
+
if remove_patterns is not None \
|
| 34 |
+
and not isinstance(remove_patterns, list):
|
| 35 |
+
self.remove_patterns = [remove_patterns]
|
| 36 |
+
else:
|
| 37 |
+
self.remove_patterns = remove_patterns
|
| 38 |
+
|
| 39 |
+
self.group_splits = group_splits
|
| 40 |
+
self.remove_too_short_groups = remove_too_short_groups
|
| 41 |
+
|
| 42 |
+
def split(self, text, tokenizer, split_patterns=None):
|
| 43 |
+
if split_patterns is None:
|
| 44 |
+
if self.split_patterns is None:
|
| 45 |
+
return [text]
|
| 46 |
+
split_patterns = self.split_patterns
|
| 47 |
+
|
| 48 |
+
def len_in_tokens(text_):
|
| 49 |
+
no_tokens = len(tokenizer.encode(text_, add_special_tokens=False))
|
| 50 |
+
return no_tokens
|
| 51 |
+
|
| 52 |
+
no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
|
| 53 |
+
max_tokens = tokenizer.max_len - no_special_tokens
|
| 54 |
+
|
| 55 |
+
if self.remove_patterns is not None:
|
| 56 |
+
for remove_pattern in self.remove_patterns:
|
| 57 |
+
text = re.sub(remove_pattern, '', text).strip()
|
| 58 |
+
|
| 59 |
+
if len_in_tokens(text) <= max_tokens:
|
| 60 |
+
return [text]
|
| 61 |
+
|
| 62 |
+
selected_splits = []
|
| 63 |
+
splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))
|
| 64 |
+
|
| 65 |
+
aggregated_splits = ''
|
| 66 |
+
for split in splits:
|
| 67 |
+
if len_in_tokens(split) > max_tokens:
|
| 68 |
+
if len(split_patterns) > 1:
|
| 69 |
+
sub_splits = self.split(
|
| 70 |
+
split, tokenizer, split_patterns[1:])
|
| 71 |
+
selected_splits.extend(sub_splits)
|
| 72 |
+
else:
|
| 73 |
+
selected_splits.append(split)
|
| 74 |
+
|
| 75 |
+
else:
|
| 76 |
+
if not self.group_splits:
|
| 77 |
+
selected_splits.append(split)
|
| 78 |
+
else:
|
| 79 |
+
new_aggregated_splits = \
|
| 80 |
+
f'{aggregated_splits} {split}'.strip()
|
| 81 |
+
if len_in_tokens(new_aggregated_splits) <= max_tokens:
|
| 82 |
+
aggregated_splits = new_aggregated_splits
|
| 83 |
+
else:
|
| 84 |
+
selected_splits.append(aggregated_splits)
|
| 85 |
+
aggregated_splits = split
|
| 86 |
+
|
| 87 |
+
if aggregated_splits:
|
| 88 |
+
selected_splits.append(aggregated_splits)
|
| 89 |
+
|
| 90 |
+
remove_too_short_groups = len(selected_splits) > 1 \
|
| 91 |
+
and self.group_splits \
|
| 92 |
+
and self.remove_too_short_groups
|
| 93 |
+
|
| 94 |
+
if not remove_too_short_groups:
|
| 95 |
+
final_splits = selected_splits
|
| 96 |
+
else:
|
| 97 |
+
final_splits = []
|
| 98 |
+
min_length = tokenizer.max_len / 2
|
| 99 |
+
for split in selected_splits:
|
| 100 |
+
if len_in_tokens(split) >= min_length:
|
| 101 |
+
final_splits.append(split)
|
| 102 |
+
|
| 103 |
+
return final_splits
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class SplitStrategies:
|
| 107 |
+
SentencesWithoutUrls = SplitStrategy(split_patterns=[
|
| 108 |
+
RegexExpressions.split_by_dot,
|
| 109 |
+
RegexExpressions.split_by_semicolon,
|
| 110 |
+
RegexExpressions.split_by_colon,
|
| 111 |
+
RegexExpressions.split_by_comma
|
| 112 |
+
],
|
| 113 |
+
remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
|
| 114 |
+
remove_too_short_groups=False,
|
| 115 |
+
group_splits=False)
|
| 116 |
+
|
| 117 |
+
GroupedSentencesWithoutUrls = SplitStrategy(split_patterns=[
|
| 118 |
+
RegexExpressions.split_by_dot,
|
| 119 |
+
RegexExpressions.split_by_semicolon,
|
| 120 |
+
RegexExpressions.split_by_colon,
|
| 121 |
+
RegexExpressions.split_by_comma
|
| 122 |
+
],
|
| 123 |
+
remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
|
| 124 |
+
remove_too_short_groups=True,
|
| 125 |
+
group_splits=True)
|