Loss function recommendation
Hello!
Preface
Best of luck with your master's thesis!
Details
I would personally recommend switching out the TripletLoss with MultipleNegativesRankingLoss or CachedMultipleNegativesRankingLoss.
They should work as drop-in replacements of the TripletLoss, but often outperform it. Notably, the big difference is that the (C)MNRL losses use "in-batch negatives", i.e. all unrelated texts in the same batch are treated as negatives in addition to the "true negative" that you provide in your dataset. This means that you're not just training with 1 negative, but with e.g.:
- 1, the true negative +
- (batch_size - 1), the "other" values for the positive column, these are (assumed to be) unrelated to your current query +
- (batch_size - 1), the "other" values for the negative column, these are (assumed to be) unrelated to your current query.
So, with a batch size of 16 you'll train with 31 in-batch negatives instead of 1. A batch size of 128 (possible with CachedMultipleNegativesRankingLoss even with small memory usage) means that you'll train with 255 negatives.
This often helps a nice bit.
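As a rough sketch of what the swap could look like (the base model and the toy data here are just placeholders, not a prescribed setup):

```python
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder base model

# A toy (anchor, positive, negative) dataset, in the same format you'd use for TripletLoss
train_dataset = Dataset.from_dict({
    "anchor": ["What characterizes an abelian group?"],
    "positive": ["A group is abelian if its group operation is commutative."],
    "negative": ["A ring is a set equipped with two binary operations."],
})

# Drop-in replacement: instead of losses.TripletLoss(model), use
loss = losses.MultipleNegativesRankingLoss(model)
# or, to allow much larger batches at roughly constant memory:
# loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()
```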
- Tom Aarsen
Hi Tom,
First of all, thanks a lot for your reply and for taking the time to help us! The topic of the loss function has sparked a lot of debate among us. Ultimately, we decided to use triplet loss because we chose the hard negatives ourselves.
For the positive example, we used mathematical theorems that are correct answers to the anchor (the question). Here's what we did:
We selected papers from the same category (e.g., math.AC, math.AG, cs.DM, and others). For each theorem, we computed the BM25 score between it and all other theorems—excluding those from the same paper—to obtain a similarity ranking. From the top 50 most similar theorems (based on BM25), we then calculated the cosine similarity. The theorem with the closest cosine similarity to the original theorem was selected as the hard negative.
Why? Because we wanted our model to learn from subtle differences. By selecting hard negatives that are very similar to the anchor, the model is encouraged to pay attention to fine-grained details.
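In code, the selection was roughly the following (a simplified sketch with toy data; rank_bm25 and the all-MiniLM-L6-v2 encoder here are stand-ins for the exact tools we used):

```python
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

# Toy data: theorem texts plus the id of the paper each one comes from
theorems = [
    "Every finitely generated module over a PID is a direct sum of cyclic modules.",
    "Every finitely generated abelian group is a direct sum of cyclic groups.",
    "A simple planar graph on n >= 3 vertices has at most 3n - 6 edges.",
]
paper_ids = ["p1", "p2", "p3"]

bm25 = BM25Okapi([t.lower().split() for t in theorems])
encoder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = encoder.encode(theorems, normalize_embeddings=True)

def hard_negative_for(i: int, top_k: int = 50) -> int:
    """Pick the hard negative for theorem i: the candidate with the highest
    cosine similarity among the top_k BM25 hits, excluding the same paper."""
    scores = bm25.get_scores(theorems[i].lower().split())
    valid = [j for j in range(len(theorems)) if j != i and paper_ids[j] != paper_ids[i]]
    candidates = sorted(valid, key=lambda j: scores[j], reverse=True)[:top_k]
    cosine = embeddings[candidates] @ embeddings[i]
    return candidates[int(np.argmax(cosine))]

print(theorems[hard_negative_for(0)])
```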
What do you think about this approach?
One issue we've noticed is that the hard negative is sometimes too close to the positive, which may be confusing the model. Despite that, the results were very promising: compared to the baseline model without fine-tuning, we observed improvements in our metrics from around 0.2 to 0.6 in some cases.
However, we also noticed that the model's performance plateaued and didn’t improve with additional epochs. This led us to suspect that it might be getting confused by the extreme similarity between positives and hard negatives.
Today, we experimented with MultipleNegativesRankingLoss without manually specifying hard negatives, and the results were even better.
So our question is:
- Do you think that using MultipleNegativesRankingLoss with the manually selected hard negatives we obtained would further improve performance, or would the model still get confused?
- Alternatively, would it be better to use less similar hard negatives and feed them into the MultipleNegativesRankingLoss using triplet-style pairs? Would that outperform the setup where we don't explicitly select any hard negatives?
Once again, thank you for your help!
Hello!
What do you think about this approach?
I think it's quite solid. Based on your experience (with negatives perhaps being too close to the positive), I would recommend reading the NV-Retriever paper, which shows that they get the best performance by discarding all hard negatives that are >= 95% as similar to the anchor as the positive is to the anchor. The idea is that this avoids false negatives.
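As a rough illustration of that filter (just a sketch, not the NV-Retriever implementation; the baseline encoder and the exact 0.95 margin are placeholders):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # any reasonable baseline encoder

def keep_triplet(anchor: str, positive: str, negative: str, margin: float = 0.95) -> bool:
    """Keep a triplet only if the negative is less than `margin` times as similar
    to the anchor as the positive is; otherwise it is a likely false negative."""
    anchor_emb, pos_emb, neg_emb = model.encode(
        [anchor, positive, negative], normalize_embeddings=True
    )
    return float(anchor_emb @ neg_emb) < margin * float(anchor_emb @ pos_emb)

# triplets = [t for t in triplets if keep_triplet(*t)]
```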
Do you think that using MultipleNegativesRankingLoss with the manually selected hard negatives we obtained would further improve performance, or would the model still get confused?
My recommendation would be to use your triplet dataset with manually selected hard negatives with MultipleNegativesRankingLoss, so that you train with both hard negatives and "soft"/"random" negatives. I think that will give the best performance, although I do think that this problem of comparing theorems is quite difficult.
Alternatively, would it be better to use less similar hard negatives and feed them into the MultipleNegativesRankingLoss using triplet-style pairs? Would that outperform the setup where we don't explicitly select any hard negatives?
I think "hard negative triplets with MNRL" will outperform "hard-ish negative triplets with MNRL", which will itself outperform "anchor-positive pairs with MNRL".
In conclusion: I would use a triplet dataset of hard negatives up to 95% of the anchor-positive similarity, and use that with MultipleNegativesRankingLoss.
Then, there are 2 more potential improvements that I can think of:
- You don't have to use pairs or triplets, you can also use n-tuples with MultipleNegativesRankingLoss, such that the first column is the anchor/query, the second column the positive, and columns 3 to n are all hard negatives. This means that you can use datasets formatted like this: https://huggingface.co/datasets/tomaarsen/gooaq-hard-negatives. In my findings, this can help about 1% or so (so it's fairly minor).
- Use CachedMultipleNegativesRankingLoss instead of MNRL, set the CMNRL mini_batch_size as your current per_device_train_batch_size, and set your per_device_train_batch_size to a very high number, like 1024. The CMNRL loss will keep your memory usage static based on mini_batch_size, but you'll train with larger batches, and larger batches => more negatives => a harder objective for the model => a stronger model, at least on paper. This can also help 1% or so.
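Roughly, combining both ideas could look something like this (just a sketch; the model name, column names, paths, and numbers are illustrative placeholders, not a prescribed setup):

```python
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder base model

# n-tuple format: first column is the anchor, second the positive,
# every further column is an extra hard negative
train_dataset = Dataset.from_dict({
    "anchor": ["Which groups are direct sums of cyclic groups?"],
    "positive": ["Every finitely generated abelian group is a direct sum of cyclic groups."],
    "negative_1": ["Every finitely generated module over a PID is a direct sum of cyclic modules."],
    "negative_2": ["A simple planar graph on n >= 3 vertices has at most 3n - 6 edges."],
})

# Memory usage stays roughly constant with mini_batch_size, even for very large batches
loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

args = SentenceTransformerTrainingArguments(
    output_dir="models/theorem-retriever",  # hypothetical output path
    per_device_train_batch_size=1024,       # large batch => many more in-batch negatives
)

trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
trainer.train()
```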
Other minor things to consider:
Good luck!
- Tom Aarsen