Matryoshka Embedding Model (LoRA Adapter) for SEA-LION 8B

This repository contains the LoRA adapter and custom head weights that turn the base aisingapore/Llama-SEA-LION-v3.5-8B-R model into a Matryoshka-style text embedding model.

This is the lightweight version of the model: the repository contains only the trained changes (LoRA adapter weights plus the pooling and projection heads) on top of the base model. For a full, standalone version (16 GB+), see the merged model repository.

Model Features

  • Base Model: aisingapore/Llama-SEA-LION-v3.5-8B-R
  • Latent Attention Pooling: A pooling mechanism that uses cross-attention to summarize the token sequence into a single vector.
  • Matryoshka Representation Learning (MRL): Trained to produce nested embeddings. You can use the full 4096-dimension embedding for maximum quality, or slice it to a smaller prefix (e.g., 1024, 512, or 128 dimensions) to trade some accuracy for faster search and lower storage. A minimal, illustrative sketch of both heads follows this list.
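
The exact implementation of both heads ships in modeling.py and is loaded dynamically in the usage code below. The sketch here is a rough illustration only; the class structure, layer names, and hyperparameters are assumptions and will not match the shipped weights.

import torch
import torch.nn as nn

class LatentAttentionPooling(nn.Module):
    """Summarizes [batch, seq_len, hidden] token states into [batch, hidden]
    by cross-attending a small set of learned latent queries over the tokens."""
    def __init__(self, hidden_size, num_latents=1, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, hidden_size) * 0.02)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states, attention_mask=None):
        batch = hidden_states.size(0)
        queries = self.latents.unsqueeze(0).expand(batch, -1, -1)
        # True = padded position, ignored by attention
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        pooled, _ = self.attn(queries, hidden_states, hidden_states,
                              key_padding_mask=key_padding_mask)
        return pooled.mean(dim=1)  # [batch, hidden]

class MatryoshkaProjection(nn.Module):
    """Maps the pooled vector to the maximum embedding width; smaller MRL
    embeddings are obtained by slicing the leading dimensions."""
    def __init__(self, hidden_size, max_embed_dim=4096):
        super().__init__()
        self.proj = nn.Linear(hidden_size, max_embed_dim)

    def forward(self, pooled):
        return self.proj(pooled)  # [batch, max_embed_dim]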

Intended Use

This model is ideal for generating fixed-size embeddings for tasks like:

  • Semantic Search & Information Retrieval
  • Retrieval-Augmented Generation (RAG)
  • Clustering and Text Similarity

How to Use

You must load the 8-bit base model first, then apply the LoRA adapter from this repository on top. The custom architectural code is included in modeling.py.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from huggingface_hub import hf_hub_download
import importlib.util

# --- 1. Setup and Load Components ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adapter_repo_id = "evoreign/sea-lion-8b-mrl-embedding"
base_model_id = "aisingapore/Llama-SEA-LION-v3.5-8B-R"

# --- 2. Dynamically Load Custom Classes ---
# This robustly loads the custom code from the modeling.py file in the Hub
print("Downloading custom modeling code...")
modeling_path = hf_hub_download(repo_id=adapter_repo_id, filename="modeling.py")
spec = importlib.util.spec_from_file_location("modeling", modeling_path)
modeling = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling)
LatentAttentionPooling = modeling.LatentAttentionPooling
MatryoshkaProjection = modeling.MatryoshkaProjection
print("Custom classes loaded successfully.")

# --- 3. Load Quantized Base Model and Adapter ---
print("Loading base model (8-bit)...")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # requires bitsandbytes
    device_map={"":0}, # Ensure model is on a single GPU
    trust_remote_code=True
)
# dtype of the non-quantized backbone weights (e.g., float16); reused for the custom heads
backbone_dtype = next(base_model.parameters()).dtype

print(f"Loading LoRA adapter from '{adapter_repo_id}'...")
# Apply the adapter to the base model
model = PeftModel.from_pretrained(base_model, adapter_repo_id)
tokenizer = AutoTokenizer.from_pretrained(adapter_repo_id)

# --- 4. Load Custom Pooling and Projection Heads ---
HIDDEN_SIZE = model.config.hidden_size
MAX_DIM = 4096

print("Loading custom pooling and projection heads...")
pooler = LatentAttentionPooling(hidden_size=HIDDEN_SIZE).to(device).to(dtype=backbone_dtype)
projection = MatryoshkaProjection(hidden_size=HIDDEN_SIZE, max_embed_dim=MAX_DIM).to(device).to(dtype=backbone_dtype)

pooler_path = hf_hub_download(repo_id=adapter_repo_id, filename="pooler.pt")
projection_path = hf_hub_download(repo_id=adapter_repo_id, filename="projection.pt")

pooler.load_state_dict(torch.load(pooler_path, map_location=device))
projection.load_state_dict(torch.load(projection_path, map_location=device))

model.eval()
pooler.eval()
projection.eval()

# --- 5. Create the Inference Function ---
def embed_texts_mrl(texts, out_dim=None):
    """Embed a list of texts; optionally slice to the first `out_dim` MRL dimensions."""
    with torch.no_grad():
        inputs = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=4096
        ).to(device)
        # Use model.base_model for a PeftModel
        out = model.base_model(**inputs, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        mask = inputs.attention_mask
        pooled = pooler(hidden, attention_mask=mask)
        z_max = projection(pooled)
        z = z_max[:, :out_dim] if out_dim else z_max
        return F.normalize(z, p=2, dim=1)

# --- 6. Example Usage ---
my_texts = ["Contoh kalimat untuk di-embed.", "Another sentence to embed."]
emb_256 = embed_texts_mrl(my_texts, out_dim=256)
print("Sliced embedding shape:", emb_256.shape)
# Expected output: torch.Size([2, 256])
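
The returned embeddings are L2-normalized, so cosine similarity reduces to a dot product. A short follow-on example, reusing embed_texts_mrl and my_texts from above (the query string is only a placeholder):

# --- 7. Similarity search with the sliced embeddings ---
query_emb = embed_texts_mrl(["A sentence to search with."], out_dim=256)  # [1, 256]
doc_embs = embed_texts_mrl(my_texts, out_dim=256)                         # [2, 256]

# Embeddings are unit-length, so cosine similarity is a plain dot product
scores = (query_emb @ doc_embs.T).squeeze(0)  # [2]
best = int(scores.argmax())
print(f"Most similar: {my_texts[best]!r} (score={scores[best].item():.3f})")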

Training Details

  • Loss Function: In-batch contrastive loss with hard negatives (a simplified sketch follows this list).
  • MRL Objective: Loss was averaged across the nested dimensions [128, 256, 512, 1024, 2048, 4096].
  • Dataset: Fine-tuned on a private triplet dataset (query, positive, hard_negative).
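
The MRL objective amounts to averaging a contrastive loss over the nested dimensions, re-normalizing after each slice. The sketch below only illustrates that idea; the temperature, candidate construction, and function name are assumptions, not details taken from the actual training code.

import torch
import torch.nn.functional as F

MRL_DIMS = [128, 256, 512, 1024, 2048, 4096]

def mrl_contrastive_loss(q, p, n, temperature=0.05):
    """In-batch contrastive loss with hard negatives, averaged over the MRL dimensions.
    q, p, n: [batch, 4096] query / positive / hard-negative embeddings (pre-normalization)."""
    total = 0.0
    for d in MRL_DIMS:
        qd = F.normalize(q[:, :d], dim=1)
        pd = F.normalize(p[:, :d], dim=1)
        nd = F.normalize(n[:, :d], dim=1)
        candidates = torch.cat([pd, nd], dim=0)             # in-batch positives + hard negatives
        logits = qd @ candidates.T / temperature            # [batch, 2 * batch]
        labels = torch.arange(qd.size(0), device=q.device)  # query i matches candidate i
        total = total + F.cross_entropy(logits, labels)
    return total / len(MRL_DIMS)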

Author: Edbert Khovey
