Thomas Boulier commited on
Commit
99c8e88
·
1 Parent(s): 2d3258b

refactor: train and save models only if needed

Browse files
Files changed (2) hide show
  1. main.py +4 -3
  2. tasks/models/text_classifiers.py +1 -1
main.py CHANGED
@@ -15,12 +15,13 @@ async def main():
15
  data_loader = TextDataLoader(text_request, light=light_dataset)
16
 
17
  # define model
18
- model = ModelFactory.create_model("distilbert")
19
 
20
  # train model
21
  train_dataset = data_loader.get_train_dataset()
22
- model.train(train_dataset)
23
- model.save()
 
24
 
25
  # Call the evaluate_text function
26
  results = await evaluate_text(request = text_request,
 
15
  data_loader = TextDataLoader(text_request, light=light_dataset)
16
 
17
  # define model
18
+ model = ModelFactory.create_model("baseline")
19
 
20
  # train model
21
  train_dataset = data_loader.get_train_dataset()
22
+ if model.model is None:
23
+ model.train(train_dataset)
24
+ model.save()
25
 
26
  # Call the evaluate_text function
27
  results = await evaluate_text(request = text_request,
tasks/models/text_classifiers.py CHANGED
@@ -13,6 +13,7 @@ from tasks.data.data_loaders import TextDataLoader
13
  class PredictionModel(ABC):
14
  def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
15
  self.description = ""
 
16
 
17
  @abstractmethod
18
  def predict(self, quote: str) -> int:
@@ -66,7 +67,6 @@ class DistilBERTModel(PredictionModel):
66
  def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
67
  super().__init__()
68
  self.description = "DistilBERT Model"
69
- self.model = None
70
  self.label_to_id_mapping = data_loader.get_label_to_id_mapping()
71
  self.id_to_label_mapping = data_loader.get_id_to_label_mapping()
72
 
 
13
  class PredictionModel(ABC):
14
  def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
15
  self.description = ""
16
+ self.model = None
17
 
18
  @abstractmethod
19
  def predict(self, quote: str) -> int:
 
67
  def __init__(self, data_loader: TextDataLoader = TextDataLoader()):
68
  super().__init__()
69
  self.description = "DistilBERT Model"
 
70
  self.label_to_id_mapping = data_loader.get_label_to_id_mapping()
71
  self.id_to_label_mapping = data_loader.get_id_to_label_mapping()
72