Saving and Loading a Finetuned Model: Bug in JinaAI Code?
Hi team,
I am finetuning `jina-embeddings-v3` using `sentence_transformers`. I do the following:
- Download the pretrained model
- Compute embeddings A
- Finetune the model
- Compute embeddings B
- Save the model to disk and clear it from memory
- Load the model from disk
- Compute embeddings C
and I find that A != B != C != A, by a significant margin: the norm of the difference between embeddings is around 0.7!
When I change nothing except the name of the model I'm pulling from HF (to, say, `microsoft/mpnet-base`), the problem is largely resolved, with a norm of the embedding difference of something like 0.02.
Below is a minimal example of my code.
Is there a problem with the JinaAI code? Or am I doing something wrong here?
```
from my_code import load_triplet_dataset
from sentence_transformers import (
    SentenceTransformer,
    InputExample,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
from sentence_transformers.losses import TripletLoss, MatryoshkaLoss, TripletDistanceMetric
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction, SequentialEvaluator
from transformers import EarlyStoppingCallback
import numpy as np
import torch
import os

output_dir = "output/jina-v3-finetuned"  # placeholder: where training checkpoints are written

model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    local_files_only=False,
    model_kwargs={'default_task': 'text-matching'},
)

train_dataset = load_triplet_dataset('train.csv')
eval_dataset = load_triplet_dataset('val.csv')

loss = TripletLoss(
    model,
    distance_metric=TripletDistanceMetric.COSINE,
    triplet_margin=0.75,
)

dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_similarity_function=SimilarityFunction.COSINE,
)

training_args = SentenceTransformerTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2.5e-5,
    warmup_ratio=0.1,
    greater_is_better=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_cosine_accuracy",
    fp16=False,
    bf16=True,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=10,
    logging_steps=50,
    logging_first_step=True,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)

pretraining_encoding = model.encode(["The human torch was denied a bank loan."])
print("Pre-training encoding:", pretraining_encoding)

# Begin fine-tuning
trainer.train()

posttraining_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-training encoding:", posttraining_encoding)

model.eval()
post_eval_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-eval encoding:", post_eval_encoding)

# Save the model
model.save("save_dir/saved-model")
post_save_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-save encoding:", post_save_encoding)
train_embed = model.encode(["The human torch was denied a bank loan."])

# Clear the model from memory
del model

print("\n===== LOADING MODEL FROM DISK =====\n")

# Load the model
model = SentenceTransformer(
    "save_dir/saved-model",
    trust_remote_code=True,
    local_files_only=True,
)
DEVICE = torch.device("cuda")
model.to(DEVICE)

post_load_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-load encoding:", post_load_encoding)

model.eval()
post_load_eval_encoding = model.encode(["The human torch was denied a bank loan."])
print("Post-load eval encoding:", post_load_eval_encoding)

load_embed = model.encode(["The human torch was denied a bank loan."])
print("Embedding difference norm:", np.linalg.norm(train_embed - load_embed))
```
I'll note that using `jinaai/jina-embeddings-v2-small-en`, the embeddings before and after saving and loading are very similar.
It seems that either `jina-embeddings-v3` requires some additional handling that is not documented anywhere (at least not anywhere I could find), or there is in fact a bug in the model code.
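To put the two reported figures (a difference norm of about 0.7 for `jina-embeddings-v3` versus about 0.02 for `microsoft/mpnet-base`) in perspective, here is a small NumPy sketch that reports both the L2 norm of the difference, as the script above does, and the cosine similarity. The function names `embedding_drift` and `cosine_from_unit_diff_norm` are hypothetical helpers. If the embeddings have unit length (an assumption; check the norms of your own vectors), then ‖a − b‖² = 2 − 2·cos θ, so a difference norm of 0.7 corresponds to a cosine similarity of about 0.755, while 0.02 corresponds to about 0.9998.

```python
import numpy as np


def embedding_drift(a, b) -> tuple[float, float]:
    """Return (L2 norm of a - b, cosine similarity of a and b).

    Accepts 1-D vectors or the (1, dim) arrays that encode() returns for a
    single sentence; both inputs are flattened before comparing.
    """
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    diff_norm = float(np.linalg.norm(a - b))
    cos_sim = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    return diff_norm, cos_sim


def cosine_from_unit_diff_norm(diff_norm: float) -> float:
    """Convert a difference norm to cosine similarity.

    Only valid for unit-length vectors, where ||a - b||^2 = 2 - 2 * cos.
    """
    return 1.0 - diff_norm ** 2 / 2.0
```

With the script above, `embedding_drift(train_embed, load_embed)` would give both numbers at once; the cosine value is often easier to interpret than the raw norm when comparing models whose embeddings may or may not be normalized.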
Ah, I've discovered my issue.
The `default_task` kwarg does not get saved with the model, so it must be set to `'text-matching'` again when reloading the model, or else the task must be specified on each `encode()` call.
This solved my issue!
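A minimal sketch of the fix described above, assuming the model was saved to `save_dir/saved-model` as in the question and fine-tuned with the `'text-matching'` task. The helper names `reload_kwargs` and `load_finetuned` are hypothetical; the essential part is passing `model_kwargs={'default_task': 'text-matching'}` again on reload, exactly as when the pretrained model was first loaded.

```python
# Sketch: reload a fine-tuned jina-embeddings-v3 model with its task restored.
# Assumes sentence-transformers is installed and the model was saved to SAVE_DIR.

SAVE_DIR = "save_dir/saved-model"  # save path used in the question
TASK = "text-matching"             # task the model was fine-tuned with


def reload_kwargs(task: str = TASK) -> dict:
    """Keyword arguments for SentenceTransformer(...) when reloading from disk.

    default_task is not stored by model.save(), so it is passed again here.
    """
    return {
        "trust_remote_code": True,
        "local_files_only": True,
        "model_kwargs": {"default_task": task},
    }


def load_finetuned(path: str = SAVE_DIR, task: str = TASK):
    """Load the saved model with the same default_task used during fine-tuning."""
    # Imported lazily so reload_kwargs() can be used without the library installed.
    from sentence_transformers import SentenceTransformer

    return SentenceTransformer(path, **reload_kwargs(task))


# Usage (requires the saved model on disk):
# model = load_finetuned()
# load_embed = model.encode(["The human torch was denied a bank loan."])
```

Keeping the task in one place (`TASK`) makes it harder for the training and loading code to drift apart. For the per-call alternative mentioned above, the sentence-transformers example on the jina-embeddings-v3 model card passes `task=` (together with `prompt_name=`) to `encode()`; check the card for the exact form your version expects.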