In [1]:
import os

import torch
from safetensors import safe_open
from torch.nn import functional as F
from transformers import AutoModel, AutoTokenizer

In [None]:
# First clone the model locally
!git clone https://huggingface.co/MongoDB/mdbr-leaf-ir

In [2]:
# Then load it: the encoder and its tokenizer are both read from the local
# checkout directory named by MODEL.
MODEL = "mdbr-leaf-ir"

# add_pooling_layer=False asks for the encoder without its built-in pooling head;
# only `last_hidden_state` is consumed later and pooling is done by hand.
model = AutoModel.from_pretrained(
    MODEL,
    add_pooling_layer=False,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

In [3]:
# Read every tensor of the dense projection layer, which is stored in its own
# safetensors file under 2_Dense/ inside the model checkout.
with safe_open(MODEL + "/2_Dense/model.safetensors", framework="pt") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# Size the projection from the checkpoint instead of hard-coding 384 -> 768:
# nn.Linear keeps its weight as (out_features, in_features).
dense_weight = tensors["linear.weight"]
out_features, in_features = dense_weight.shape

W_out = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=True)
W_out.load_state_dict({
    "weight": dense_weight,
    "bias": tensors["linear.bias"],
})

# Inference only: put both modules in eval mode. These calls sit mid-cell, so
# their return value is not displayed and needs no throwaway assignment.
model.eval()
W_out.eval()

# Sample search queries and the passages they will be scored against.
queries = [
    "What is machine learning?",
    "How does neural network training work?",
]

documents = [
    "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.",
    "Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors.",
]

# Tokenize.
# Queries receive a fixed instruction prefix first; documents are tokenized as-is.
QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '
queries_with_prefix = [f"{QUERY_PREFIX}{text}" for text in queries]

# Both batches use the same tokenizer options: padding enabled, truncation to
# at most 512 tokens, and PyTorch tensors as the return type.
tokenizer_kwargs = dict(padding=True, truncation=True, return_tensors='pt', max_length=512)
query_tokens = tokenizer(queries_with_prefix, **tokenizer_kwargs)
document_tokens = tokenizer(documents, **tokenizer_kwargs)

def _mean_pool(hidden_states, attention_mask):
    """Average token vectors over the sequence axis, ignoring masked positions.

    Args:
        hidden_states: per-token encoder outputs; expected to be
            (batch, seq_len, hidden) since the sequence axis is dim 1.
        attention_mask: mask with one entry per token (batch, seq_len); non-zero
            marks tokens that take part in the average.

    Returns:
        One pooled vector per batch item: the sum of the unmasked token vectors
        divided by the number of unmasked tokens.
    """
    masked = hidden_states * attention_mask.unsqueeze(-1)
    return masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


def _embed(tokens):
    """Turn one tokenized batch into L2-normalized embeddings.

    Runs the encoder, mean-pools its last hidden state with the batch's
    attention mask, maps the result through the dense layer `W_out`, and
    normalizes each row to unit length.
    """
    hidden_states = model(**tokens).last_hidden_state
    pooled = _mean_pool(hidden_states, tokens.attention_mask)
    return F.normalize(W_out(pooled), dim=-1)


# Perform Inference
# Queries and documents go through exactly the same pipeline, so it is written
# once in `_embed` and applied to both batches.
with torch.inference_mode():
    query_embeddings = _embed(query_tokens)
    document_embeddings = _embed(document_tokens)

# Rows are unit-length, so the dot product is the cosine similarity;
# entry [i, j] scores query i against document j.
similarities = query_embeddings @ document_embeddings.T
print(f"Similarities:\n{similarities}")

# Similarities:
#  tensor([[0.6857, 0.4598],
#          [0.4238, 0.5723]])

Similarities:
tensor([[0.6857, 0.4598],
        [0.4238, 0.5723]])
