Poor performance of classifier fine-tuning

#518
by ZYSK-huggingface - opened

Hi, Christina

I’m using your classifier with 95M parameters (both 20L and 12L) for training on my own dataset. When I split the dataset by proportion at the cell level (i.e., training, validation, and test sets come from the same set of samples), the model performs well.

But when I switch to independent test samples (i.e., training, validation, and test sets are from completely different samples), the model performs poorly.
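A sample-level split of this kind can be set up, for example, with scikit-learn's GroupShuffleSplit; the sketch below uses placeholder column names and is only meant to illustrate keeping all cells from one sample in a single partition, not the exact code used here.

```python
# Minimal sketch of a sample-level split: all cells from a given sample end up
# in exactly one of train/validation/test. "sample_id" and "label" are
# placeholder column names for per-cell metadata.
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

obs = pd.DataFrame({
    "sample_id": ["s1", "s1", "s2", "s2", "s3", "s3", "s4", "s4", "s5", "s5"],
    "label":     ["A",  "A",  "B",  "B",  "A",  "A",  "B",  "B",  "A",  "B"],
})

# Hold out ~20% of the samples for testing, then split the remainder into
# train/validation, always grouping by sample_id.
test_split = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
trainval_idx, test_idx = next(
    test_split.split(obs, obs["label"], groups=obs["sample_id"])
)

trainval = obs.iloc[trainval_idx].reset_index(drop=True)
val_split = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, val_idx = next(
    val_split.split(trainval, trainval["label"], groups=trainval["sample_id"])
)

train, val, test = trainval.iloc[train_idx], trainval.iloc[val_idx], obs.iloc[test_idx]

# Sanity check: no sample contributes cells to more than one partition.
assert set(train["sample_id"]).isdisjoint(val["sample_id"])
assert set(train["sample_id"]).isdisjoint(test["sample_id"])
assert set(val["sample_id"]).isdisjoint(test["sample_id"])
```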

To better reflect generalization ability, I've updated my training setup so that the training, validation, and test sets come from entirely separate samples. I'm also using Ray Tune with multiple trials (ntrials) to search for hyperparameters. However, I've observed that:

The validation performance is consistently poor across trials.

The validation loss stays high and doesn't improve much.

Do you have any suggestions for improving performance in this case?

Thank you!

To give a specific example: when I use the cell-level split strategy, even with a small number of ntrials (e.g., 10), I can obtain many classifiers with scores above 0.95 and loss below 0.2. However, these classifiers do not perform well when predicting independent samples.

After switching to the independent sample strategy, I noticed that in the early ntrials, almost all classifiers have scores below 0.6 and loss above 0.8.

Thanks for your question. We generally recommend having different samples in the splits to confirm generalizability. The classifiers we train are all split this way, for example the cardiomyopathy classifier or the colorectal cancer classifier in our manuscripts, where we evaluate performance on entirely held out patients.

If the samples are highly variable, though (for example, the patients are of very different ages so the disease expresses itself differently in each patient, or the differentiations are highly variable and progress at different rates, leading to very different transcriptional states between samples), then the number of samples you include for training will likely need to be increased to show the model enough of a range for it to focus on the correct features that encompass the breadth of possible states. For example, if the model learns from a subset of patients that a given set of genes differentiates the classes, but those genes are not expressed in the held-out samples, it will have trouble generalizing. The other suggestion would be to run more hyperparameter trials and only fine-tune with 1 epoch to avoid overfitting.
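As a rough illustration (not the Geneformer Classifier's exact interface), a standard Hugging Face Trainer hyperparameter search with the Ray backend can cap training at 1 epoch while widening the number of trials; the model path, datasets, and search-space values below are placeholders.

```python
# Hedged sketch: many hyperparameter trials, each fine-tuned for only 1 epoch.
# Assumes a standard transformers Trainer with the Ray Tune backend; the
# pretrained model path and the tokenized datasets are placeholders.
from ray import tune
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

def model_init():
    # Re-initialize the classification head for every trial.
    return BertForSequenceClassification.from_pretrained(
        "path/to/pretrained_model", num_labels=2
    )

training_args = TrainingArguments(
    output_dir="hyperopt_runs",
    num_train_epochs=1,              # keep to 1 epoch to limit memorization
    evaluation_strategy="epoch",
    save_strategy="no",
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=train_dataset,     # placeholder: tokenized cells from training samples
    eval_dataset=val_dataset,        # placeholder: tokenized cells from held-out samples
)

def ray_hp_space(trial):
    # Example search space only; adjust ranges to your setup.
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-3),
        "weight_decay": tune.loguniform(1e-3, 1e-1),
        "warmup_steps": tune.choice([100, 500, 1000]),
        "per_device_train_batch_size": tune.choice([8, 12, 16]),
    }

best_trial = trainer.hyperparameter_search(
    hp_space=ray_hp_space,
    backend="ray",
    n_trials=50,                     # more trials rather than more epochs
    direction="minimize",            # minimize validation loss
)
```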

The other thing to consider is whether your goal is to have a classifier that generalizes so that you can classify new samples or your goal is to separate classes in the embedding space to perform in silico perturbation to understand the separation between the classes. In the latter case, if you don’t have enough samples, it would be acceptable to train with holding out cells but using all the samples so that you encompass the most breadth of variable states you can and have a better model for in silico perturbation that focuses on features consistent amongst all samples.
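For that latter case, a cell-level holdout that still keeps every sample represented in training could look roughly like the sketch below (scikit-learn, placeholder labels).

```python
# Sketch: cell-level holdout that uses all samples, for the in silico
# perturbation use case rather than generalization. "labels" is a placeholder
# for the per-cell class annotation.
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
labels = rng.choice(["disease", "control"], size=1000)   # placeholder labels
cell_idx = np.arange(labels.size)

# Stratify on the class label only; cells from the same sample may land in
# both partitions, which is acceptable for this goal.
train_cells, heldout_cells = train_test_split(
    cell_idx, test_size=0.2, stratify=labels, random_state=0
)
```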

ctheodoris changed discussion status to closed

Thank you for your kind reply!

I have two follow-up questions:

First, if my goal is to classify two categories, and my training set contains about 10 samples per class (with a total of several hundred thousand cells), but the validation set only includes 3 samples in total, would that affect the reliability of validation performance?

Second, I did modify the source code to change the number of epochs from 1 to 3 during the hyperparameter search, because I was concerned that using only 1 epoch might lead to underfitting.

Finally, I noticed a small issue in the code that might have gone unnoticed:
The 5-fold cross-validation appears to be designed for normal training when ntrials=0.
However, during hyperparameter search, if 5-fold is enabled, the training actually runs with different hyperparameter combinations for each fold, which defeats the purpose of cross-validation and might not be meaningful in this context.

Thanks again for your time and help!

Just to confirm: if my goal is to explore disease mechanisms rather than build a classifier for generalization, is it acceptable, or even preferred, to use a classifier that may be overfit to the training samples, as long as it captures consistent features across the available data? My concern is whether such overfitting would bias the interpretation of key genes or pathways.

Thanks for following up. If your validation data only has 1 sample for one class and 2 samples for the other class, the reliability of the validation may be impacted if, for example, that single patient sample is very different in some orthogonal way, such as the age of the patient or an environmental exposure that is not encompassed in the training data. In this case, using cross-fold validation would be helpful.
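As a sketch of what sample-level cross-validation folds could look like (scikit-learn's GroupKFold here, with placeholder metadata; StratifiedGroupKFold is an alternative if the classes need to be balanced across folds):

```python
# Sketch: cross-validation with folds defined at the sample level, so every
# sample is held out exactly once. Column names are placeholders.
import pandas as pd
from sklearn.model_selection import GroupKFold

obs = pd.DataFrame({
    "sample_id": ["s1"] * 3 + ["s2"] * 3 + ["s3"] * 3 + ["s4"] * 3 + ["s5"] * 3,
    "label": ["disease"] * 9 + ["control"] * 6,
})

gkf = GroupKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(
    gkf.split(obs, obs["label"], groups=obs["sample_id"])
):
    held_out = sorted(obs.iloc[val_idx]["sample_id"].unique())
    print(f"fold {fold}: validating on samples {held_out}")
    # Fine-tune with the SAME hyperparameters on obs.iloc[train_idx],
    # evaluate on obs.iloc[val_idx], then average metrics across folds.
```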

We generally don't run hyperparameter searches with cross-fold validation since it's computationally expensive to run through many hyperparameter combinations, so thank you for pointing out the issue of running different hyperparameters across the folds in this case. We will need to add a check to disallow this combination of options for now.

For your question about overfitting: it would generally be preferred to have a generalizable fine-tuned model so that the attention to features focuses on mechanisms shared across patients. In cases where there are insufficient patient samples, fine-tuning with held-out cells (rather than held-out samples) may be acceptable in order to include more variable states in the training, which should decrease the chances of overfitting to a specific scenario.

Generally we do not fine-tune for more than 1 epoch because the model easily memorizes the data, but hyperparameter tuning with a validation set should indicate whether additional epochs are helpful in your specific case.
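If you do experiment with more than 1 epoch, one option (using the standard transformers EarlyStoppingCallback, not anything Geneformer-specific) is to let the sample-level validation loss cut training short; the model and dataset variables below are placeholders for an existing fine-tuning setup.

```python
# Sketch: allow up to 3 epochs but stop on the sample-level validation loss,
# so extra epochs only run while they still help. Assumes an existing
# transformers fine-tuning setup; model and datasets are placeholders.
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="finetune_runs",
    num_train_epochs=3,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,        # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,                        # placeholder: model being fine-tuned
    args=training_args,
    train_dataset=train_dataset,        # placeholder: cells from training samples
    eval_dataset=val_dataset,           # placeholder: cells from held-out samples
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)],
)
trainer.train()
```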
