Spaces:
Sleeping
Sleeping
Thomas Boulier
commited on
Commit
·
99c8e88
1
Parent(s):
2d3258b
refactor: train and save models only if needed
Browse files- main.py +4 -3
- tasks/models/text_classifiers.py +1 -1
main.py
CHANGED
|
@@ -15,12 +15,13 @@ async def main():
|
|
| 15 |
data_loader = TextDataLoader(text_request, light=light_dataset)
|
| 16 |
|
| 17 |
# define model
|
| 18 |
-
model = ModelFactory.create_model("
|
| 19 |
|
| 20 |
# train model
|
| 21 |
train_dataset = data_loader.get_train_dataset()
|
| 22 |
-
model.
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
# Call the evaluate_text function
|
| 26 |
results = await evaluate_text(request = text_request,
|
|
|
|
| 15 |
data_loader = TextDataLoader(text_request, light=light_dataset)
|
| 16 |
|
| 17 |
# define model
|
| 18 |
+
model = ModelFactory.create_model("baseline")
|
| 19 |
|
| 20 |
# train model
|
| 21 |
train_dataset = data_loader.get_train_dataset()
|
| 22 |
+
if model.model is None:
|
| 23 |
+
model.train(train_dataset)
|
| 24 |
+
model.save()
|
| 25 |
|
| 26 |
# Call the evaluate_text function
|
| 27 |
results = await evaluate_text(request = text_request,
|
tasks/models/text_classifiers.py
CHANGED
|
@@ -13,6 +13,7 @@ from tasks.data.data_loaders import TextDataLoader
|
|
| 13 |
class PredictionModel(ABC):
|
| 14 |
def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
|
| 15 |
self.description = ""
|
|
|
|
| 16 |
|
| 17 |
@abstractmethod
|
| 18 |
def predict(self, quote: str) -> int:
|
|
@@ -66,7 +67,6 @@ class DistilBERTModel(PredictionModel):
|
|
| 66 |
def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
|
| 67 |
super().__init__()
|
| 68 |
self.description = "DistilBERT Model"
|
| 69 |
-
self.model = None
|
| 70 |
self.label_to_id_mapping = data_loader.get_label_to_id_mapping()
|
| 71 |
self.id_to_label_mapping = data_loader.get_id_to_label_mapping()
|
| 72 |
|
|
|
|
| 13 |
class PredictionModel(ABC):
|
| 14 |
def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
|
| 15 |
self.description = ""
|
| 16 |
+
self.model = None
|
| 17 |
|
| 18 |
@abstractmethod
|
| 19 |
def predict(self, quote: str) -> int:
|
|
|
|
| 67 |
def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
|
| 68 |
super().__init__()
|
| 69 |
self.description = "DistilBERT Model"
|
|
|
|
| 70 |
self.label_to_id_mapping = data_loader.get_label_to_id_mapping()
|
| 71 |
self.id_to_label_mapping = data_loader.get_id_to_label_mapping()
|
| 72 |
|