SetFit documentation

Trainer Classes

TrainingArguments

class setfit.TrainingArguments

( output_dir: str = 'checkpoints' batch_size: Union[int, Tuple[int, int]] = (16, 2) num_epochs: Union[int, Tuple[int, int]] = (1, 16) max_steps: int = -1 sampling_strategy: str = 'oversampling' num_iterations: Optional[int] = None body_learning_rate: Union[float, Tuple[float, float]] = (2e-05, 1e-05) head_learning_rate: float = 0.01 loss: Callable = sentence_transformers.losses.CosineSimilarityLoss distance_metric: Callable = BatchHardTripletLossDistanceFunction.cosine_distance margin: float = 0.25 end_to_end: bool = False use_amp: bool = False warmup_proportion: float = 0.1 l2_weight: Optional[float] = 0.01 max_length: Optional[int] = None samples_per_label: int = 2 show_progress_bar: bool = True seed: int = 42 report_to: str = 'all' run_name: Optional[str] = None logging_dir: Optional[str] = None logging_strategy: str = 'steps' logging_first_step: bool = True logging_steps: int = 50 eval_strategy: str = 'no' evaluation_strategy: Optional[str] = None eval_steps: Optional[int] = None eval_delay: int = 0 eval_max_steps: int = -1 save_strategy: str = 'steps' save_steps: int = 500 save_total_limit: Optional[int] = 1 load_best_model_at_end: bool = False metric_for_best_model: Optional[str] = 'embedding_loss' greater_is_better: bool = False )

Parameters

  • output_dir (str, defaults to "checkpoints") — The output directory where the model predictions and checkpoints will be written.
  • batch_size (Union[int, Tuple[int, int]], defaults to (16, 2)) — Set the batch sizes for the embedding and classifier training phases respectively, or set both if an integer is provided. Note that the batch size for the classifier is only used with a differentiable PyTorch head.
  • num_epochs (Union[int, Tuple[int, int]], defaults to (1, 16)) — Set the number of epochs for the embedding and classifier training phases respectively, or set both if an integer is provided. Note that the number of epochs for the classifier is only used with a differentiable PyTorch head.
  • max_steps (int, defaults to -1) — If set to a positive number, the total number of training steps to perform. Overrides num_epochs. The training may stop before reaching the set number of steps when all data is exhausted.
  • sampling_strategy (str, defaults to "oversampling") — The strategy for drawing sentence pairs during training. Possible values are:

    • "oversampling": Draws an even number of positive/negative sentence pairs until every sentence pair has been drawn.
    • "undersampling": Draws the minimum number of positive/negative sentence pairs until every sentence pair in the minority class has been drawn.
    • "unique": Draws every sentence pair combination (likely resulting in an unbalanced number of positive/negative sentence pairs).

    The default is set to "oversampling", ensuring all sentence pairs are drawn at least once. Alternatively, setting num_iterations will override this argument and determine the number of generated sentence pairs.

  • num_iterations (int, optional) — If not set, sampling_strategy determines the number of sentence pairs to generate. Otherwise, this argument sets the number of iterations to generate sentence pairs for, and provides compatibility with SetFit's CosineSimilarityLoss.
  • body_learning_rate (Union[float, Tuple[float, float]], defaults to (2e-5, 1e-5)) — Set the learning rate for the SentenceTransformer body for the embedding and classifier training phases respectively, or set both if a float is provided. Note that the body learning rate for the classifier is only used with a differentiable PyTorch head and if end_to_end=True.
  • head_learning_rate (float, defaults to 1e-2) — Set the learning rate for the head for the classifier training phase. Only used with a differentiable PyTorch head.
  • loss (nn.Module, defaults to CosineSimilarityLoss) — The loss function to use for contrastive training of the embedding training phase.
  • distance_metric (Callable, defaults to BatchHardTripletLossDistanceFunction.cosine_distance) — Function that returns a distance between two embeddings. It is used by the triplet losses and ignored for CosineSimilarityLoss and SupConLoss.
  • margin (float, defaults to 0.25) — Margin for the triplet loss. Negative samples should be at least margin further apart from the anchor than the positive. It is ignored for CosineSimilarityLoss, BatchHardSoftMarginTripletLoss and SupConLoss.
  • end_to_end (bool, defaults to False) — If True, train the entire model end-to-end during the classifier training phase. Otherwise, freeze the SentenceTransformer body and only train the head. Only used with a differentiable PyTorch head.
  • use_amp (bool, defaults to False) — Whether to use Automatic Mixed Precision (AMP) during the embedding training phase. Requires PyTorch >= 1.6.0.
  • warmup_proportion (float, defaults to 0.1) — Proportion of the total training steps used for learning rate warmup. Must be between 0.0 and 1.0, inclusive.
  • l2_weight (float, optional, defaults to 0.01) — Optional L2 weight for both the model body and head, passed to the AdamW optimizer in the classifier training phase if a differentiable PyTorch head is used.
  • max_length (int, optional) — The maximum token length a tokenizer can generate. If not provided, the maximum length for the SentenceTransformer body is used.
  • samples_per_label (int, defaults to 2) — Number of consecutive, random and unique samples drawn per label. This is only relevant for triplet loss and ignored for CosineSimilarityLoss. Batch size should be a multiple of samples_per_label.
  • show_progress_bar (bool, defaults to True) — Whether to display a progress bar for the training epochs and iterations.
  • seed (int, defaults to 42) — Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the model_init argument to Trainer to instantiate the model if it has some randomly initialized parameters.
  • report_to (str or List[str], optional, defaults to "all") — The list of integrations to report the results and logs to. Supported platforms are "azure_ml", "comet_ml", "mlflow", "neptune", "tensorboard", "clearml" and "wandb". Use "all" to report to all installed integrations, or "none" for no integrations.
  • run_name (str, optional) — A descriptor for the run. Typically used for wandb and mlflow logging.
  • logging_dir (str, optional) — TensorBoard log directory. Will default to runs/CURRENT_DATETIME_HOSTNAME.
  • logging_strategy (str or IntervalStrategy, optional, defaults to "steps") — The logging strategy to adopt during training. Possible values are:

    • "no": No logging is done during training.
    • "epoch": Logging is done at the end of each epoch.
    • "steps": Logging is done every logging_steps.
  • logging_first_step (bool, optional, defaults to True) — Whether to log and evaluate the first global_step.
  • logging_steps (int, defaults to 50) — Number of update steps between two logs if logging_strategy="steps".
  • eval_strategy (str or IntervalStrategy, optional, defaults to "no") — The evaluation strategy to adopt during training. Possible values are:

    • "no": No evaluation is done during training.
    • "steps": Evaluation is done (and logged) every eval_steps.
    • "epoch": Evaluation is done at the end of each epoch.
  • eval_steps (int, optional) — Number of update steps between two evaluations if eval_strategy="steps". Will default to the same value as logging_steps if not set.
  • eval_delay (int, defaults to 0) — Number of epochs or steps to wait before the first evaluation can be performed, depending on eval_strategy.
  • eval_max_steps (int, defaults to -1) — If set to a positive number, the total number of evaluation steps to perform. The evaluation may stop before reaching the set number of steps when all data is exhausted.
  • save_strategy (str or IntervalStrategy, optional, defaults to "steps") — The checkpoint save strategy to adopt during training. Possible values are:

    • "no": No save is done during training.
    • "epoch": Save is done at the end of each epoch.
    • "steps": Save is done every save_steps.
  • save_steps (int, optional, defaults to 500) — Number of update steps between two checkpoint saves if save_strategy="steps".
  • save_total_limit (int, optional, defaults to 1) — If a value is passed, limits the total number of checkpoints by deleting older checkpoints in output_dir. Note that the best model is always preserved if eval_strategy is not "no".
  • load_best_model_at_end (bool, optional, defaults to False) — Whether or not to load the best model found during training at the end of training.

    When set to True, save_strategy must be the same as eval_strategy, and if it is "steps", save_steps must be a round multiple of eval_steps.
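
To make the three sampling_strategy options concrete, the following hypothetical sketch counts the sentence pairs each option would yield for a small labeled set. It is not SetFit's implementation: it assumes a pair is an unordered combination of two distinct sentences, and it ignores shuffling and num_iterations, so the real counts may differ.

```python
from itertools import combinations
from typing import Dict, List, Tuple


def count_pairs(labels: List[str]) -> Tuple[int, int]:
    """Count positive (same label) and negative (different label) pairs.

    Assumption: a pair is an unordered combination of two distinct sentences.
    """
    positives = negatives = 0
    for first, second in combinations(labels, 2):
        if first == second:
            positives += 1
        else:
            negatives += 1
    return positives, negatives


def pairs_per_strategy(labels: List[str], strategy: str) -> Dict[str, int]:
    """Hypothetical illustration of how each sampling_strategy balances pair counts."""
    pos, neg = count_pairs(labels)
    if strategy == "oversampling":
        # The rarer pair type is repeated until it matches the more common one.
        n_pos = n_neg = max(pos, neg)
    elif strategy == "undersampling":
        # The more common pair type is cut down to the size of the rarer one.
        n_pos = n_neg = min(pos, neg)
    elif strategy == "unique":
        # Every combination exactly once; the two counts may differ.
        n_pos, n_neg = pos, neg
    else:
        raise ValueError(f"Unknown sampling strategy: {strategy!r}")
    return {"positive": n_pos, "negative": n_neg}


labels = ["pos", "pos", "neg", "neg", "neg"]
for strategy in ("oversampling", "undersampling", "unique"):
    print(strategy, pairs_per_strategy(labels, strategy))
```

For five sentences labeled pos, pos, neg, neg, neg there are 4 positive and 6 negative pairs, so the sketch yields 6 + 6 pairs with "oversampling", 4 + 4 with "undersampling" and 4 + 6 with "unique".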

TrainingArguments is the subset of the arguments that relate to the training loop itself. Note that training with SetFit consists of two phases behind the scenes: finetuning embeddings and training a classification head. As a result, some of the training arguments can be tuples, where the two values are used for the two phases, respectively. The second value is often only used when the model was loaded with use_differentiable_head=True.
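
A minimal sketch of how a value that is either a single number or a 2-tuple can be split into per-phase values. The helper resolve_phase_value is hypothetical and only mirrors the documented rule: a single value applies to both phases, and a tuple gives the embedding and classifier values in that order.

```python
from typing import Tuple, TypeVar, Union

T = TypeVar("T")


def resolve_phase_value(value: Union[T, Tuple[T, T]]) -> Tuple[T, T]:
    """Hypothetical helper: return (embedding_phase_value, classifier_phase_value)."""
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError(f"Expected a 2-tuple, got {value!r}")
        # A tuple is read as (embedding phase, classifier phase).
        return value[0], value[1]
    # A single value is shared by both phases.
    return value, value


embedding_bs, classifier_bs = resolve_phase_value((16, 2))
print(embedding_bs, classifier_bs)  # 16 2
print(resolve_phase_value(8))       # (8, 8)
```

Under the same rule, TrainingArguments(batch_size=(32, 16), num_epochs=1) would use a batch size of 32 for the embedding phase, 16 for the classifier phase, and one epoch for both; the classifier values only matter with a differentiable head.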

to_dict

( ) Dict[str, Any]

Returns

Dict[str, Any]

The dictionary variant of this dataclass.

Convert this instance to a dictionary.

from_dict

( arguments: Dict[str, Any] ignore_extra: bool = False ) TrainingArguments

Parameters

  • arguments (Dict[str, Any]) — A dictionary of arguments.
  • ignore_extra (bool, optional) — Whether to ignore arguments that do not occur in the TrainingArguments init signature. Defaults to False.

Returns

TrainingArguments

The instantiated TrainingArguments instance.

Initialize a TrainingArguments instance from a dictionary.
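
The to_dict / from_dict round trip and the effect of ignore_extra can be sketched on a small hypothetical dataclass, MiniArgs, standing in for TrainingArguments. Filtering unknown keys against the dataclass fields is an assumption about how ignore_extra works; without it, this sketch raises a TypeError for unknown keys.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict


@dataclass
class MiniArgs:
    """Hypothetical stand-in for a small subset of TrainingArguments."""
    output_dir: str = "checkpoints"
    seed: int = 42

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, arguments: Dict[str, Any], ignore_extra: bool = False) -> "MiniArgs":
        if ignore_extra:
            # Assumption: keep only keys that match a field of the dataclass.
            known = {f.name for f in fields(cls)}
            arguments = {k: v for k, v in arguments.items() if k in known}
        # Any remaining unknown key makes __init__ raise a TypeError.
        return cls(**arguments)


config = {"output_dir": "runs/exp1", "seed": 7, "unknown_key": True}
args = MiniArgs.from_dict(config, ignore_extra=True)
print(args)                                        # MiniArgs(output_dir='runs/exp1', seed=7)
print(MiniArgs.from_dict(args.to_dict()) == args)  # True
```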

copy

( )

Create a shallow copy of this TrainingArguments instance.
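
"Shallow" means the copy shares mutable attribute values with the original. The sketch below illustrates that with Python's copy.copy on a hypothetical dataclass; it assumes TrainingArguments.copy behaves like copy.copy, which the one-line description suggests but does not state.

```python
import copy
from dataclasses import dataclass, field
from typing import List


@dataclass
class CopyDemoArgs:
    """Hypothetical stand-in with one immutable and one mutable field."""
    seed: int = 42
    report_to: List[str] = field(default_factory=lambda: ["tensorboard"])


original = CopyDemoArgs()
clone = copy.copy(original)  # shallow copy: field values are not duplicated

clone.seed = 0                   # rebinding a field only affects the clone
clone.report_to.append("wandb")  # mutating a shared list affects both objects

print(original.seed)       # 42
print(original.report_to)  # ['tensorboard', 'wandb']
```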

update

( arguments: Dict[str, Any] ignore_extra: bool = False )

Trainer

class setfit.Trainer

( model: typing.Optional[ForwardRef('SetFitModel')] = None args: typing.Optional[setfit.training_args.TrainingArguments] = None train_dataset: typing.Optional[ForwardRef('Dataset')] = None eval_dataset: typing.Optional[ForwardRef('Dataset')] = None model_init: typing.Optional[typing.Callable[[], ForwardRef('SetFitModel')]] = None metric: typing.Union[str, typing.Callable[[ForwardRef('Dataset'), ForwardRef('Dataset')], typing.Dict[str, float]]] = 'accuracy' metric_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None column_mapping: typing.Optional[typing.Dict[str, str]] = None )

Parameters

  • model (SetFitModel, optional) — The model to train. If not provided, a model_init must be passed.
  • args (TrainingArguments, optional) — The training arguments to use.
  • train_dataset (Dataset) — The training dataset.
  • eval_dataset (Dataset, optional) — The evaluation dataset.
  • model_init (Callable[[], SetFitModel], optional) — A function that instantiates the model to be used. If provided, each call to Trainer.train() will start from a new instance of the model as given by this function when a trial is passed.
  • metric (str or Callable, optional, defaults to "accuracy") — The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. If a callable is provided, it must take two arguments (y_pred, y_test) and return a dictionary mapping metric names to values.
  • metric_kwargs (Dict[str, Any], optional) — Keyword arguments passed to the evaluation function if metric is an evaluation string like "f1". For example, this is useful for providing an averaging strategy for computing f1 in a multi-label setting.
  • callbacks (List[TrainerCallback], optional) — A list of callbacks to customize the training loop. These are added to the list of default callbacks. To remove one of the default callbacks, use the Trainer.remove_callback() method.
  • column_mapping (Dict[str, str], optional) — A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name": "label"}.

Trainer to train a SetFit model.
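
A custom metric callable only has to follow the contract described for the metric parameter: accept (y_pred, y_test) and return a dictionary mapping metric names to values. The function below is a hypothetical example that computes accuracy and macro-averaged F1 in plain Python.

```python
from typing import Dict, List


def accuracy_and_macro_f1(y_pred: List[int], y_test: List[int]) -> Dict[str, float]:
    """Hypothetical custom metric: (y_pred, y_test) -> {metric name: value}."""
    if len(y_pred) != len(y_test):
        raise ValueError("y_pred and y_test must have the same length")
    accuracy = sum(p == t for p, t in zip(y_pred, y_test)) / len(y_test)

    # Per-label F1 = 2TP / (2TP + FP + FN), then an unweighted (macro) mean.
    f1_scores = []
    for label in sorted(set(y_test) | set(y_pred)):
        tp = sum(p == label and t == label for p, t in zip(y_pred, y_test))
        fp = sum(p == label and t != label for p, t in zip(y_pred, y_test))
        fn = sum(p != label and t == label for p, t in zip(y_pred, y_test))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / len(f1_scores)}


print(accuracy_and_macro_f1([0, 1, 1, 0], [0, 1, 0, 0]))
```

It could then be passed as Trainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, metric=accuracy_and_macro_f1), where model and the datasets are assumed to exist. For a named metric such as "f1", a similar average can instead be requested through metric_kwargs={"average": "macro"}, assuming the loaded metric accepts an average argument.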

add_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

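Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will instantiate a member of that class.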

Add a callback to the current list of TrainerCallback.

apply_hyperparameters

( params: typing.Dict[str, typing.Any] final_model: bool = False )

Parameters

  • params (Dict[str, Any]) — The parameters, usually from BestRun.hyperparameters.
  • final_model (bool, optional, defaults to False) — If True, replace the model_init() function with a fixed model based on the parameters.

Applies a dictionary of hyperparameters to both the trainer and the model.

evaluate

( dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None metric_key_prefix: str = 'test' ) Dict[str, float]

Parameters

  • dataset (Dataset, optional) — The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via the eval_dataset argument at Trainer initialization.

Returns

Dict[str, float]

The evaluation metrics.

Computes the metrics for a given classifier.

hyperparameter_search

( hp_space: typing.Optional[typing.Callable[[ForwardRef('optuna.Trial')], typing.Dict[str, float]]] = None compute_objective: typing.Optional[typing.Callable[[typing.Dict[str, float]], float]] = None n_trials: int = 10 direction: str = 'maximize' backend: typing.Union[ForwardRef('str'), transformers.trainer_utils.HPSearchBackend, NoneType] = None hp_name: typing.Optional[typing.Callable[[ForwardRef('optuna.Trial')], str]] = None **kwargs ) trainer_utils.BestRun

Parameters

  • hp_space (Callable[["optuna.Trial"], Dict[str, float]], optional) — A function that defines the hyperparameter search space. Will default to default_hp_space_optuna.
  • compute_objective (Callable[[Dict[str, float]], float], optional) — A function computing the objective to minimize or maximize from the metrics returned by the evaluate method. Will default to default_compute_objective which uses the sum of metrics.
  • n_trials (int, optional, defaults to 10) — The number of trial runs to test.
  • direction (str, optional, defaults to "maximize") — Whether to minimize or maximize the objective. Can be "minimize" or "maximize"; pick "minimize" when optimizing the validation loss and "maximize" when optimizing one or several metrics.
  • backend (str or HPSearchBackend, optional) — The backend to use for hyperparameter search. Only optuna is currently supported.
  • hp_name (Callable[["optuna.Trial"], str], optional) — A function that defines the trial/run name. Will default to None.
  • kwargs (Dict[str, Any], optional) — Additional keyword arguments passed along to optuna.create_study. For more information, see the Optuna documentation for optuna.create_study.

Returns

trainer_utils.BestRun

All the information about the best run.

Launch a hyperparameter search using optuna. The optimized quantity is determined by compute_objective, which defaults to a function returning the evaluation loss when no metric is provided, the sum of all metrics otherwise.

To use this method, you need to have provided a model_init when initializing your Trainer: we need to reinitialize the model at each new run.
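
A search space for hyperparameter_search is a function from an optuna.Trial to a dictionary of values. The sketch below writes one against Optuna's suggest_float, suggest_int and suggest_categorical methods, together with a compute_objective that sums the metrics like the documented default. The parameter ranges are arbitrary examples, and StubTrial is a hypothetical offline stand-in so the code runs without Optuna installed.

```python
from typing import Any, Dict


def hp_space(trial) -> Dict[str, Any]:
    """Search space written against the optuna.Trial suggest_* methods (example ranges)."""
    return {
        "body_learning_rate": trial.suggest_float("body_learning_rate", 1e-6, 1e-3, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 3),
        "batch_size": trial.suggest_categorical("batch_size", [16, 32, 64]),
    }


def sum_of_metrics(metrics: Dict[str, float]) -> float:
    """Objective mirroring the documented default: the sum of all metrics."""
    return sum(metrics.values())


class StubTrial:
    """Hypothetical offline stand-in for optuna.Trial: returns the lowest value / first choice."""

    def suggest_float(self, name, low, high, log=False):
        return low

    def suggest_int(self, name, low, high):
        return low

    def suggest_categorical(self, name, choices):
        return choices[0]


print(hp_space(StubTrial()))
print(sum_of_metrics({"accuracy": 0.9, "f1": 0.8}))
```

With Optuna installed and a Trainer created with model_init, these would be passed as trainer.hyperparameter_search(hp_space=hp_space, compute_objective=sum_of_metrics, n_trials=10, direction="maximize").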

pop_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] ) TrainerCallback

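Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will pop the first member of that class found in the list of callbacks.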

Returns

TrainerCallback

The callback removed, if found.

Remove a callback from the current list of TrainerCallback and return it.

If the callback is not found, returns None (and no error is raised).
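
The callback parameter accepts either a class or an instance. The hypothetical sketch below models the documented pop_callback behavior on a plain list: a class removes the first callback of that type, an instance removes that exact object, and a missing callback yields None instead of an error. Callback is a stand-in for transformers' TrainerCallback, and the matching rules are assumed to follow transformers' callback handler.

```python
from typing import List, Optional, Type, Union


class Callback:
    """Hypothetical stand-in for transformers.TrainerCallback."""


class PrintCallback(Callback):
    pass


class EarlyStopCallback(Callback):
    pass


def pop_callback(callbacks: List[Callback], callback: Union[Type[Callback], Callback]) -> Optional[Callback]:
    """Remove and return the first matching callback, or None if nothing matches."""
    for i, existing in enumerate(callbacks):
        if isinstance(callback, type):
            matches = isinstance(existing, callback)  # match by class
        else:
            matches = existing is callback  # match by instance
        if matches:
            return callbacks.pop(i)
    return None  # not found: no error is raised


printer = PrintCallback()
callbacks: List[Callback] = [printer, EarlyStopCallback()]
print(type(pop_callback(callbacks, EarlyStopCallback)).__name__)  # EarlyStopCallback
print(pop_callback(callbacks, printer) is printer)                # True
print(pop_callback(callbacks, PrintCallback))                     # None
```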

push_to_hub

( repo_id: str **kwargs ) str

Parameters

  • repo_id (str) — The full repository ID to push to, e.g. "tomaarsen/setfit-sst2".
  • config (dict, optional) — Configuration object to be saved alongside the model weights.
  • commit_message (str, optional) — Message to commit while pushing.
  • private (bool, optional) — Whether to make the repo private. If None (default), the repo will be public unless the organization’s default is private. This value is ignored if the repo already exists.
  • api_endpoint (str, optional) — The API endpoint to use when pushing the model to the hub.
  • token (str, optional) — The token to use as HTTP bearer authorization for remote files. If not set, will use the token set when logging in with transformers-cli login (stored in ~/.huggingface).
  • branch (str, optional) — The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to "main".
  • create_pr (boolean, optional) — Whether or not to create a Pull Request from branch with that commit. Defaults to False.
  • allow_patterns (List[str] or str, optional) — If provided, only files matching at least one pattern are pushed.
  • ignore_patterns (List[str] or str, optional) — If provided, files matching any of the patterns are not pushed.

Returns

str

The URL of the commit of your model in the given repository.

Upload model checkpoint to the Hub using huggingface_hub.

See the full list of parameters for your huggingface_hub version in the huggingface_hub documentation.
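
allow_patterns and ignore_patterns act as a filter on which files are uploaded. The sketch below assumes Unix shell-style wildcards as implemented by Python's fnmatch module; select_files is a hypothetical helper, not huggingface_hub's code.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List, Optional, Union


def select_files(
    files: Iterable[str],
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
) -> List[str]:
    """Hypothetical filter: keep files matching any allow pattern and no ignore pattern."""
    if isinstance(allow_patterns, str):
        allow_patterns = [allow_patterns]
    if isinstance(ignore_patterns, str):
        ignore_patterns = [ignore_patterns]
    selected = []
    for path in files:
        if allow_patterns is not None and not any(fnmatchcase(path, p) for p in allow_patterns):
            continue  # not explicitly allowed
        if ignore_patterns is not None and any(fnmatchcase(path, p) for p in ignore_patterns):
            continue  # explicitly ignored
        selected.append(path)
    return selected


files = ["model_head.pkl", "model.safetensors", "config.json", "README.md",
         "checkpoints/step-500/model.safetensors"]
print(select_files(files, allow_patterns=["*.safetensors", "*.json"], ignore_patterns="checkpoints/*"))
# ['model.safetensors', 'config.json']
```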

remove_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

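Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will remove the first member of that class found in the list of callbacks.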

Remove a callback from the current list of TrainerCallback.

train

( args: typing.Optional[setfit.training_args.TrainingArguments] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any], NoneType] = None **kwargs )

Parameters

  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.

Main training entry point.

train_classifier

( x_train: typing.List[str] y_train: typing.Union[typing.List[int], typing.List[typing.List[int]]] args: typing.Optional[setfit.training_args.TrainingArguments] = None )

Parameters

  • x_train (List[str]) — A list of training sentences.
  • y_train (Union[List[int], List[List[int]]]) — A list of labels corresponding to the training sentences.
  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.

Method to perform the classifier phase: fitting a classifier head.
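
y_train may hold one integer label per sentence or, for multi-label data, one list per sentence. The sketch below assumes that each inner list is a multi-hot vector with a 0/1 entry per label, and shows a hypothetical helper, to_multi_hot, that builds such vectors from per-sentence label indices.

```python
from typing import List, Sequence


def to_multi_hot(label_sets: Sequence[Sequence[int]], num_labels: int) -> List[List[int]]:
    """Hypothetical helper: per-sentence label indices -> multi-hot rows of length num_labels."""
    rows = []
    for labels in label_sets:
        row = [0] * num_labels
        for index in labels:
            if not 0 <= index < num_labels:
                raise ValueError(f"Label index {index} out of range")
            row[index] = 1
        rows.append(row)
    return rows


x_train = ["great battery, poor screen", "fast shipping", "nothing works"]
y_train = to_multi_hot([[0, 1], [2], []], num_labels=3)
print(y_train)  # [[1, 1, 0], [0, 0, 1], [0, 0, 0]]
```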

train_embeddings

( x_train: typing.List[str] y_train: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None x_eval: typing.Optional[typing.List[str]] = None y_eval: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None args: typing.Optional[setfit.training_args.TrainingArguments] = None )

Parameters

  • x_train (List[str]) — A list of training sentences.
  • y_train (Union[List[int], List[List[int]]]) — A list of labels corresponding to the training sentences.
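  • x_eval (List[str], optional) — A list of evaluation sentences.
  • y_eval (Union[List[int], List[List[int]]], optional) — A list of labels corresponding to the evaluation sentences.
  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.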

Method to perform the embedding phase: finetuning the SentenceTransformer body.

DistillationTrainer

class setfit.DistillationTrainer

( teacher_model: SetFitModel student_model: typing.Optional[ForwardRef('SetFitModel')] = None args: TrainingArguments = None train_dataset: typing.Optional[ForwardRef('Dataset')] = None eval_dataset: typing.Optional[ForwardRef('Dataset')] = None model_init: typing.Optional[typing.Callable[[], ForwardRef('SetFitModel')]] = None metric: typing.Union[str, typing.Callable[[ForwardRef('Dataset'), ForwardRef('Dataset')], typing.Dict[str, float]]] = 'accuracy' column_mapping: typing.Optional[typing.Dict[str, str]] = None )

Parameters

  • teacher_model (SetFitModel) — The teacher model to mimic.
  • student_model (SetFitModel, optional) — The model to train. If not provided, a model_init must be passed.
  • args (TrainingArguments, optional) — The training arguments to use.
  • train_dataset (Dataset) — The training dataset.
  • eval_dataset (Dataset, optional) — The evaluation dataset.
  • model_init (Callable[[], SetFitModel], optional) — A function that instantiates the model to be used. If provided, each call to train() will start from a new instance of the model as given by this function when a trial is passed.
  • metric (str or Callable, optional, defaults to "accuracy") — The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. If a callable is provided, it must take two arguments (y_pred, y_test).
  • column_mapping (Dict[str, str], optional) — A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "label_column_name": "label"}.

Trainer to compress a SetFit model with knowledge distillation.
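
One common way to distill a sentence embedding model is to label sentence pairs with the teacher's cosine similarity and train the student body to reproduce those scores. Whether DistillationTrainer does exactly this is an assumption here, not something this page states; the sketch only computes such targets from toy teacher embeddings, and teacher_pair_targets is a hypothetical helper.

```python
import math
from itertools import combinations
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def teacher_pair_targets(
    sentences: List[str], teacher_embeddings: Dict[str, Sequence[float]]
) -> List[Tuple[str, str, float]]:
    """Hypothetical helper: label each sentence pair with the teacher's cosine similarity."""
    return [
        (a, b, cosine_similarity(teacher_embeddings[a], teacher_embeddings[b]))
        for a, b in combinations(sentences, 2)
    ]


# Toy 2-d "teacher" embeddings; a real teacher would encode the sentences itself.
teacher = {"good film": [1.0, 0.0], "great movie": [1.0, 1.0], "boring": [0.0, 1.0]}
for a, b, target in teacher_pair_targets(list(teacher), teacher):
    print(f"{a!r} vs {b!r}: {target:.3f}")
```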

add_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

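Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will instantiate a member of that class.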

Add a callback to the current list of TrainerCallback.

apply_hyperparameters

( params: typing.Dict[str, typing.Any] final_model: bool = False )

Parameters

  • params (Dict[str, Any]) — The parameters, usually from BestRun.hyperparameters.
  • final_model (bool, optional, defaults to False) — If True, replace the model_init() function with a fixed model based on the parameters.

Applies a dictionary of hyperparameters to both the trainer and the model.

evaluate

( dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None metric_key_prefix: str = 'test' ) Dict[str, float]

Parameters

  • dataset (Dataset, optional) — The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via the eval_dataset argument at Trainer initialization.

Returns

Dict[str, float]

The evaluation metrics.

Computes the metrics for a given classifier.

hyperparameter_search

( hp_space: typing.Optional[typing.Callable[[ForwardRef('optuna.Trial')], typing.Dict[str, float]]] = None compute_objective: typing.Optional[typing.Callable[[typing.Dict[str, float]], float]] = None n_trials: int = 10 direction: str = 'maximize' backend: typing.Union[ForwardRef('str'), transformers.trainer_utils.HPSearchBackend, NoneType] = None hp_name: typing.Optional[typing.Callable[[ForwardRef('optuna.Trial')], str]] = None **kwargs ) trainer_utils.BestRun

Parameters

  • hp_space (Callable[["optuna.Trial"], Dict[str, float]], optional) — A function that defines the hyperparameter search space. Will default to default_hp_space_optuna.
  • compute_objective (Callable[[Dict[str, float]], float], optional) — A function computing the objective to minimize or maximize from the metrics returned by the evaluate method. Will default to default_compute_objective which uses the sum of metrics.
  • n_trials (int, optional, defaults to 10) — The number of trial runs to test.
  • direction (str, optional, defaults to "maximize") — Whether to minimize or maximize the objective. Can be "minimize" or "maximize"; pick "minimize" when optimizing the validation loss and "maximize" when optimizing one or several metrics.
  • backend (str or HPSearchBackend, optional) — The backend to use for hyperparameter search. Only optuna is currently supported.
  • hp_name (Callable[["optuna.Trial"], str], optional) — A function that defines the trial/run name. Will default to None.
  • kwargs (Dict[str, Any], optional) — Additional keyword arguments passed along to optuna.create_study. For more information, see the Optuna documentation for optuna.create_study.

Returns

trainer_utils.BestRun

All the information about the best run.

Launch a hyperparameter search using optuna. The optimized quantity is determined by compute_objective, which defaults to a function returning the evaluation loss when no metric is provided, the sum of all metrics otherwise.

To use this method, you need to have provided a model_init when initializing your Trainer: we need to reinitialize the model at each new run.

pop_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] ) TrainerCallback

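Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will pop the first member of that class found in the list of callbacks.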

Returns

TrainerCallback

The callback removed, if found.

Remove a callback from the current list of TrainerCallback and return it.

If the callback is not found, returns None (and no error is raised).

push_to_hub

( repo_id: str **kwargs ) str

Parameters

  • repo_id (str) — The full repository ID to push to, e.g. "tomaarsen/setfit-sst2".
  • config (dict, optional) — Configuration object to be saved alongside the model weights.
  • commit_message (str, optional) — Message to commit while pushing.
  • private (bool, optional) — Whether to make the repo private. If None (default), the repo will be public unless the organization’s default is private. This value is ignored if the repo already exists.
  • api_endpoint (str, optional) — The API endpoint to use when pushing the model to the hub.
  • token (str, optional) — The token to use as HTTP bearer authorization for remote files. If not set, will use the token set when logging in with transformers-cli login (stored in ~/.huggingface).
  • branch (str, optional) — The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to "main".
  • create_pr (boolean, optional) — Whether or not to create a Pull Request from branch with that commit. Defaults to False.
  • allow_patterns (List[str] or str, optional) — If provided, only files matching at least one pattern are pushed.
  • ignore_patterns (List[str] or str, optional) — If provided, files matching any of the patterns are not pushed.

Returns

str

The URL of the commit of your model in the given repository.

Upload model checkpoint to the Hub using huggingface_hub.

See the full list of parameters for your huggingface_hub version in the huggingface_hub documentation.

remove_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

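Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will remove the first member of that class found in the list of callbacks.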

Remove a callback from the current list of TrainerCallback.

train

( args: typing.Optional[setfit.training_args.TrainingArguments] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any], NoneType] = None **kwargs )

Parameters

  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.

Main training entry point.

train_classifier

( x_train: typing.List[str] args: typing.Optional[setfit.training_args.TrainingArguments] = None )

Parameters

  • x_train (List[str]) — A list of training sentences.
  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.

Method to perform the classifier phase: fitting the student classifier head.

train_embeddings

( x_train: typing.List[str] y_train: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None x_eval: typing.Optional[typing.List[str]] = None y_eval: typing.Union[typing.List[int], typing.List[typing.List[int]], NoneType] = None args: typing.Optional[setfit.training_args.TrainingArguments] = None )

Parameters

  • x_train (List[str]) — A list of training sentences.
  • y_train (Union[List[int], List[List[int]]]) — A list of labels corresponding to the training sentences.
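  • x_eval (List[str], optional) — A list of evaluation sentences.
  • y_eval (Union[List[int], List[List[int]]], optional) — A list of labels corresponding to the evaluation sentences.
  • args (TrainingArguments, optional) — Temporarily change the training arguments for this training call.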

Method to perform the embedding phase: finetuning the SentenceTransformer body.

AbsaTrainer

class setfit.AbsaTrainer

( model: AbsaModel args: typing.Optional[setfit.training_args.TrainingArguments] = None polarity_args: typing.Optional[setfit.training_args.TrainingArguments] = None train_dataset: typing.Optional[ForwardRef('Dataset')] = None eval_dataset: typing.Optional[ForwardRef('Dataset')] = None metric: typing.Union[str, typing.Callable[[ForwardRef('Dataset'), ForwardRef('Dataset')], typing.Dict[str, float]]] = 'accuracy' metric_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None column_mapping: typing.Optional[typing.Dict[str, str]] = None )

Parameters

  • model (AbsaModel) — The AbsaModel model to train.
  • args (TrainingArguments, optional) — The training arguments to use. If polarity_args is not defined, then args is used for both the aspect and the polarity model.
  • polarity_args (TrainingArguments, optional) — The training arguments to use for the polarity model. If not defined, args is used for both the aspect and the polarity model.
  • train_dataset (Dataset) — The training dataset. The dataset must have “text”, “span”, “label” and “ordinal” columns.
  • eval_dataset (Dataset, optional) — The evaluation dataset. The dataset must have “text”, “span”, “label” and “ordinal” columns.
  • metric (str or Callable, optional, defaults to "accuracy") — The metric to use for evaluation. If a string is provided, we treat it as the metric name and load it with default settings. If a callable is provided, it must take two arguments (y_pred, y_test).
  • metric_kwargs (Dict[str, Any], optional) — Keyword arguments passed to the evaluation function if metric is an evaluation string like "f1". For example, this is useful for providing an averaging strategy for computing f1 in a multi-label setting.
  • callbacks (List[TrainerCallback], optional) — A list of callbacks to customize the training loop. These are added to the list of default callbacks. To remove one of the default callbacks, use the Trainer.remove_callback() method.
  • column_mapping (Dict[str, str], optional) — A mapping from the column names in the dataset to the column names expected by the model. The expected format is a dictionary with the following format: {"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}.

Trainer to train a SetFit ABSA model.
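
The sketch below shows a hypothetical raw ABSA dataset being renamed with a column_mapping into the required "text", "span", "label" and "ordinal" columns. It assumes that "ordinal" is the 0-based index of the span's occurrence in the text when the same span appears more than once; rename_columns and span_offset are hypothetical helpers operating on plain dictionaries rather than a Dataset.

```python
from typing import Any, Dict, List


def rename_columns(rows: List[Dict[str, Any]], column_mapping: Dict[str, str]) -> List[Dict[str, Any]]:
    """Hypothetical helper: rename dataset columns to the names the trainer expects."""
    return [{column_mapping.get(k, k): v for k, v in row.items()} for row in rows]


def span_offset(text: str, span: str, ordinal: int) -> int:
    """Character offset of the ordinal-th (assumed 0-based) occurrence of span in text."""
    start = -1
    for _ in range(ordinal + 1):
        start = text.find(span, start + 1)
        if start == -1:
            raise ValueError(f"Occurrence {ordinal} of {span!r} not found")
    return start


raw = [{"sentence": "The food was good, but the food portions were small.",
        "aspect": "food", "sentiment": "negative", "occurrence": 1}]
mapping = {"sentence": "text", "aspect": "span", "sentiment": "label", "occurrence": "ordinal"}
row = rename_columns(raw, mapping)[0]
print(sorted(row))                                            # ['label', 'ordinal', 'span', 'text']
print(span_offset(row["text"], row["span"], row["ordinal"]))  # 27
```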

add_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

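Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will instantiate a member of that class.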

Add a callback to the current list of TrainerCallback.

evaluate

( dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None ) Dict[str, Dict[str, float]]

Parameters

  • dataset (Dataset, optional) — The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via the eval_dataset argument at Trainer initialization.

Returns

Dict[str, Dict[str, float]]

The evaluation metrics.

Computes the metrics for a given classifier.

pop_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] ) Tuple[TrainerCallback, TrainerCallback]

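Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will pop the first member of that class found in the list of callbacks.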

Returns

Tuple[TrainerCallback, TrainerCallback]

The callbacks removed from the aspect and polarity trainers, if found.

Remove a callback from the current list of TrainerCallback and return it.

If the callback is not found, returns None (and no error is raised).

push_to_hub

( repo_id: str polarity_repo_id: typing.Optional[str] = None **kwargs )

Parameters

  • repo_id (str) — The full repository ID to push to, e.g. "tomaarsen/setfit-aspect".
  • polarity_repo_id (str, optional) — The full repository ID to push the polarity model to, e.g. "tomaarsen/setfit-sst2".
  • config (dict, optional) — Configuration object to be saved alongside the model weights.
  • commit_message (str, optional) — Message to commit while pushing.
  • private (bool, optional) — Whether to make the repo private. If None (default), the repo will be public unless the organization’s default is private. This value is ignored if the repo already exists.
  • api_endpoint (str, optional) — The API endpoint to use when pushing the model to the hub.
  • token (str, optional) — The token to use as HTTP bearer authorization for remote files. If not set, will use the token set when logging in with transformers-cli login (stored in ~/.huggingface).
  • branch (str, optional) — The git branch on which to push the model. This defaults to the default branch as specified in your repository, which defaults to "main".
  • create_pr (boolean, optional) — Whether or not to create a Pull Request from branch with that commit. Defaults to False.
  • allow_patterns (List[str] or str, optional) — If provided, only files matching at least one pattern are pushed.
  • ignore_patterns (List[str] or str, optional) — If provided, files matching any of the patterns are not pushed.

Upload model checkpoint to the Hub using huggingface_hub.

See the full list of parameters for your huggingface_hub version in the huggingface_hub documentation.

remove_callback

( callback: typing.Union[type, transformers.trainer_callback.TrainerCallback] )

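Parameters

  • callback (type or TrainerCallback) — A TrainerCallback class or an instance of a TrainerCallback. In the first case, will remove the first member of that class found in the list of callbacks.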

Remove a callback from the current list of TrainerCallback.

train

( args: typing.Optional[setfit.training_args.TrainingArguments] = None polarity_args: typing.Optional[setfit.training_args.TrainingArguments] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any], NoneType] = None **kwargs )

Parameters

  • args (TrainingArguments, optional) — Temporarily change the aspect training arguments for this training call.
  • polarity_args (TrainingArguments, optional) — Temporarily change the polarity training arguments for this training call.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.

Main training entry point.

train_aspect

( args: typing.Optional[setfit.training_args.TrainingArguments] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any], NoneType] = None **kwargs )

Parameters

  • args (TrainingArguments, optional) — Temporarily change the aspect training arguments for this training call.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.

Train the aspect model only.

train_polarity

( args: typing.Optional[setfit.training_args.TrainingArguments] = None trial: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any], NoneType] = None **kwargs )

Parameters

  • args (TrainingArguments, optional) — Temporarily change the polarity training arguments for this training call.
  • trial (optuna.Trial or Dict[str, Any], optional) — The trial run or the hyperparameter dictionary for hyperparameter search.

Train the polarity model only.
