Converting a reranker model to a single-label classification model

#1
by sigridjineth - opened

Hi all, sharing a small conversion utility I made to simplify using ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b in CrossEncoder/“classification”-style pipelines.

  • Model: https://huggingface.co/sigridjineth/ctxl-rerank-v2-1b-seq-cls
  • What it is: a SequenceClassification (num_labels=1) wrapper that emits a single score per input which is numerically identical to the original model’s last-token next_logits[:, 0].
  • Why: easier integration (Sentence-Transformers CrossEncoder, standard Transformers classification APIs; a minimal pipeline sketch follows below), lower overhead (no full-vocab projection), simpler serving/thresholding.

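As a quick illustration of the “standard classification APIs” point, here is a minimal sketch using the stock text-classification pipeline (my own example, not from the model card; the prompt template mirrors the base card, and function_to_apply="none" keeps the raw logit rather than letting the pipeline apply a sigmoid):

import torch
from transformers import pipeline

# Score one pre-formatted prompt with the standard text-classification pipeline.
# function_to_apply="none" returns the raw single logit instead of a sigmoid-squashed score.
clf = pipeline(
    "text-classification",
    model="sigridjineth/ctxl-rerank-v2-1b-seq-cls",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

query, doc = "Which is a domestic animal?", "Cats are pets."
prompt = (
    "Check whether a given document contains information helpful to answer the query.\n"
    f"<Document> {doc}\n<Query> {query} ??"
)
print(clf(prompt, function_to_apply="none"))  # e.g. [{'label': 'SCORE', 'score': <raw logit>}]
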
I’ll attach the convert.py I used.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
convert.py

Qwen3 기반 reranker LM(예: ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b)을
단일 로짓 SequenceClassification 모델로 변환합니다.

원리:
- 모델이 마지막 토큰에서 vocab ID=k(기본 0)의 로짓을 '점수'로 쓰도록 학습되었다면,
  lm_head.weight[k] (필요 시 bias[k])를 그대로 분류기 헤드에 이식하면
  LM의 next_logits[:, k]와 동치인 단일 로짓을 얻을 수 있습니다.

주의:
- 좌패딩(left padding) + pad_token 설정을 권장합니다 (모델 카드 Quickstart도 동일). 
- parity(수치 일치)를 위해 가중치 벡터에 BF16 round-trip 옵션 제공.

요구:
- transformers>=4.51.0, torch, safetensors
"""

import argparse
import os
import sys
from typing import List, Optional

import torch
from transformers import (
    AutoTokenizer,
    Qwen3ForCausalLM,
    Qwen3ForSequenceClassification,
)
from transformers.utils.versions import require_version


# -----------------------------
# Helpers
# -----------------------------

def pick_device(name: str) -> str:
    name = name.lower()
    if name == "auto":
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"
    if name in {"cuda", "cpu", "mps"}:
        if name == "cuda" and not torch.cuda.is_available():
            print("[warn] --device=cuda 이지만 CUDA 가용 X → cpu로 대체합니다.", file=sys.stderr)
            return "cpu"
        if name == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
            print("[warn] --device=mps 이지만 MPS 가용 X → cpu로 대체합니다.", file=sys.stderr)
            return "cpu"
        return name
    print(f"[warn] 알 수 없는 --device={name} → cpu 사용", file=sys.stderr)
    return "cpu"


def pick_dtype(name: str, device: str) -> torch.dtype:
    name = name.lower()
    if name == "auto":
        if device == "cuda":
            return torch.bfloat16
        return torch.float32
    mapping = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
    }
    if name in mapping:
        return mapping[name]
    print(f"[warn] 알 수 없는 --dtype={name} → float32 사용", file=sys.stderr)
    return torch.float32


def ensure_tokenizer_padding(tok):
    # If no pad_token is set, fall back to eos as the pad token
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Left padding recommended (keeps the last real token aligned across a batch)
    tok.padding_side = "left"


def propagate_special_tokens_to_config(model, tok):
    """토크나이저의 special token ID를 모델 config에 반영 (배치>1 검증에 필요)"""
    changed = []
    if getattr(model.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
        model.config.pad_token_id = tok.pad_token_id
        changed.append("pad_token_id")
    if getattr(model.config, "eos_token_id", None) is None and tok.eos_token_id is not None:
        model.config.eos_token_id = tok.eos_token_id
        changed.append("eos_token_id")
    if getattr(model.config, "bos_token_id", None) is None and tok.bos_token_id is not None:
        model.config.bos_token_id = tok.bos_token_id
        changed.append("bos_token_id")
    if changed:
        print(f"[info] propagated to config: {', '.join(changed)}")


def build_prompts(query: str, instruction: str, docs: List[str]) -> List[str]:
    """모델 카드 권장 템플릿과 동일한 형태 (HF Quickstart 참조)."""
    inst = f" {instruction}" if instruction else ""
    return [
        "Check whether a given document contains information helpful to answer the query.\n"
        f"<Document> {d}\n"
        f"<Query> {query}{inst} ??"
        for d in docs
    ]


def get_lm_head(model) -> torch.nn.Module:
    if hasattr(model, "lm_head"):
        return model.lm_head
    if hasattr(model, "get_output_embeddings"):
        return model.get_output_embeddings()
    raise AttributeError("LM head를 찾을 수 없습니다 (lm_head/get_output_embeddings 둘 다 없음).")


def find_classifier_head(model_seqcls):
    # Qwen3ForSequenceClassification normally exposes a 'score' layer
    if hasattr(model_seqcls, "score"):
        return model_seqcls.score
    if hasattr(model_seqcls, "classifier"):
        return model_seqcls.classifier
    raise AttributeError("SequenceClassification 모델에서 분류 헤드(score/classifier)를 찾을 수 없습니다.")


def copy_head_weights(
    lm_head: torch.nn.Linear,
    cls_head: torch.nn.Linear,
    score_token_id: int = 0,
    parity_bf16_roundtrip: bool = True,
):
    W = lm_head.weight.detach()  # [vocab, hidden]
    if score_token_id < 0 or score_token_id >= W.shape[0]:
        raise IndexError(f"score_token_id={score_token_id} 가 유효 범위를 벗어났습니다 (0..{W.shape[0]-1}).")
    vec = W[score_token_id]  # (hidden,)
    if parity_bf16_roundtrip:
        vec = vec.to(torch.bfloat16).to(W.dtype)

    # If the LM head has a bias, carry over the matching channel as well
    bias_val = None
    if getattr(lm_head, "bias", None) is not None:
        b = lm_head.bias.detach()
        bias_val = b[score_token_id].item()

    with torch.no_grad():
        cls_head.weight.copy_(vec.to(cls_head.weight.dtype).unsqueeze(0))  # (1, hidden)
        if getattr(cls_head, "bias", None) is not None:
            if bias_val is None:
                cls_head.bias.zero_()
            else:
                cls_head.bias.fill_(bias_val)


def verify_parity(
    tok,
    lm: Qwen3ForCausalLM,
    seqcls: Qwen3ForSequenceClassification,
    device: str,
    query: str = "Which is a domestic animal?",
    docs: Optional[List[str]] = None,
    instruction: str = "",
    max_length: int = 8192,
    score_token_id: int = 0,
    use_bf16_roundtrip_on_lm_scores: bool = True,
) -> bool:
    if docs is None:
        docs = ["Cats are pets.", "The moon is made of cheese."]

    # Safety: make sure pad/eos/bos are also reflected in the model configs
    propagate_special_tokens_to_config(lm, tok)
    propagate_special_tokens_to_config(seqcls, tok)

    prompts = build_prompts(query, instruction, docs)
    enc = tok(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    lm.eval()
    seqcls.eval()
    with torch.no_grad():
        lm_logits = lm(**enc).logits[:, -1, :]  # [B, V]
        lm_scores = lm_logits[:, score_token_id]
        if use_bf16_roundtrip_on_lm_scores:
            lm_scores = lm_scores.to(torch.bfloat16)  # keep the round-trip (comparison below is done in f32)
        cls_scores = seqcls(**enc).logits.squeeze(-1)

    # (important) align dtype/device before comparing/printing
    lm_scores = lm_scores.float().detach().cpu()
    cls_scores = cls_scores.float().detach().cpu()

    print("[verify] LM scores   :", lm_scores[: min(3, lm_scores.shape[0])].tolist())
    print("[verify] SeqCLS scores:", cls_scores[: min(3, cls_scores.shape[0])].tolist())

    ok = torch.allclose(lm_scores, cls_scores, atol=1e-6, rtol=1e-6)
    max_abs = (lm_scores - cls_scores).abs().max().item()
    print(f"[verify] allclose = {ok} (max_abs_diff={max_abs:.3e})")
    return bool(ok)


# -----------------------------
# Main
# -----------------------------

def main():
    parser = argparse.ArgumentParser(description="Convert Qwen3-based reranker LM to single-logit SequenceClassification.")
    parser.add_argument("--source", type=str, required=True, help="소스 LM 모델 ID 또는 로컬 경로 (예: ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b)")
    parser.add_argument("--out", type=str, required=True, help="변환된 seq-cls 모델 저장 디렉터리")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"], help="사용 디바이스")
    parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "fp32", "bfloat16", "bf16", "float16", "fp16"], help="모델 로드 dtype")
    parser.add_argument("--score-token-id", type=int, default=0, help="점수 채널에 해당하는 vocab ID (기본=0)")
    parser.add_argument("--no-bf16-roundtrip", action="store_true", help="가중치/점수 BF16 round-trip 비활성화")
    parser.add_argument("--max-length", type=int, default=8192, help="검증 시 truncation 길이")
    parser.add_argument("--verify", action="store_true", help="변환 후 동치성 검증 수행")
    parser.add_argument("--verify-query", type=str, default="Which is a domestic animal?", help="검증용 쿼리")
    parser.add_argument("--verify-docs", type=str, nargs="*", default=["Cats are pets.", "The moon is made of cheese."], help="검증용 문서들")
    parser.add_argument("--verify-instruction", type=str, default="", help="검증용 instruction (선택)")
    parser.add_argument("--push-to-hub", action="store_true", help="허깅페이스 허브 업로드")
    parser.add_argument("--repo-id", type=str, default=None, help="업로드 리포지토리 ID (예: user/repo)")
    parser.add_argument("--private", action="store_true", help="허브 업로드 시 private")
    args = parser.parse_args()

    try:
        require_version("transformers>=4.51.0")
    except Exception:
        print("[error] transformers>=4.51.0 필요. 업그레이드: pip install -U 'transformers>=4.51.0'", file=sys.stderr)
        raise

    device = pick_device(args.device)
    dtype = pick_dtype(args.dtype, device)
    parity_bf16_roundtrip = not args.no_bf16_roundtrip

    print(f"[info] device={device}, dtype={dtype}, source={args.source}")
    print(f"[info] score_token_id={args.score_token_id}, bf16_roundtrip={parity_bf16_roundtrip}")

    # Tokenizer
    tok = AutoTokenizer.from_pretrained(args.source, use_fast=True)
    ensure_tokenizer_padding(tok)

    # Load the causal LM
    lm = Qwen3ForCausalLM.from_pretrained(args.source, torch_dtype=dtype).to(device).eval()
    propagate_special_tokens_to_config(lm, tok)

    # Load the SequenceClassification model
    seqcls = Qwen3ForSequenceClassification.from_pretrained(
        args.source,
        num_labels=1,
        ignore_mismatched_sizes=True,
        torch_dtype=dtype,
    ).to(device).eval()
    propagate_special_tokens_to_config(seqcls, tok)

    # Transplant the head
    lm_head = get_lm_head(lm)
    cls_head = find_classifier_head(seqcls)
    copy_head_weights(
        lm_head=lm_head,
        cls_head=cls_head,
        score_token_id=args.score_token_id,
        parity_bf16_roundtrip=parity_bf16_roundtrip,
    )

    # Metadata
    seqcls.config.problem_type = "single_label_classification"
    seqcls.config.id2label = {0: "SCORE"}
    seqcls.config.label2id = {"SCORE": 0}

    # Save
    os.makedirs(args.out, exist_ok=True)
    seqcls.save_pretrained(args.out)
    tok.save_pretrained(args.out)
    print(f"[save] saved to: {args.out}")

    # Verification (optional)
    if args.verify:
        ok = verify_parity(
            tok=tok,
            lm=lm,
            seqcls=seqcls,
            device=device,
            query=args.verify_query,
            docs=args.verify_docs,
            instruction=args.verify_instruction,
            max_length=args.max_length,
            score_token_id=args.score_token_id,
            use_bf16_roundtrip_on_lm_scores=parity_bf16_roundtrip,
        )
        if not ok:
            print("[warn] 검증 allclose=False 입니다. 템플릿/패딩/round-trip/dtype을 점검하세요.", file=sys.stderr)

    # Hub upload (optional)
    if args.push_to_hub:
        if not args.repo_id:
            print("[error] --push-to-hub 사용 시 --repo-id를 지정하세요 (예: user/repo).", file=sys.stderr)
            sys.exit(2)
        print(f"[hub] pushing to: {args.repo_id} (private={args.private})")
        seqcls.push_to_hub(args.repo_id, private=args.private)
        tok.push_to_hub(args.repo_id, private=args.private)
        print("[hub] push complete.")


if __name__ == "__main__":
    main()
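
To run it end to end, an invocation like the following works (flags as defined in the argparse above):

python convert.py \
  --source ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b \
  --out ./ctxl-rerank-v2-1b-seq-cls \
  --verify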

Method (short)

  1. Load the base LM and read lm_head.weight.
  2. Copy row 0 (the score channel used by the original card) into the classifier head:
    seq_cls.score.weight ← lm_head.weight[0] (bias = 0 or matching bias if present).
  3. Propagate tokenizer specials to model config (pad_token_id, eos_token_id, etc.) and use left padding to keep the last token aligned across a batch.
  4. Optional: cast scores through bf16 → float at readout to match the base card’s rounding.

Result: SequenceClassification logit ≡ original next_logits[:, 0] given the same prompt template and padding.
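
For reference, a stripped-down sketch of steps 1 and 2 in isolation (the full version in convert.py above additionally handles bias, dtype, and the BF16 round-trip):

import torch
from transformers import Qwen3ForCausalLM, Qwen3ForSequenceClassification

BASE = "ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b"

lm = Qwen3ForCausalLM.from_pretrained(BASE, torch_dtype=torch.float32)
cls = Qwen3ForSequenceClassification.from_pretrained(
    BASE, num_labels=1, ignore_mismatched_sizes=True, torch_dtype=torch.float32
)

with torch.no_grad():
    # score channel = vocab ID 0; the classifier head expects shape (1, hidden)
    cls.score.weight.copy_(lm.lm_head.weight[0].unsqueeze(0))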


Quick parity check (minimal)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

BASE = "ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b"
WRAP = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"

def fmt(q, inst, docs):
    inst = f" {inst}" if inst else ""
    return [f"Check whether a given document contains information helpful to answer the query.\n"
            f"<Document> {d}\n<Query> {q}{inst} ??" for d in docs]

q, docs = "Which is a domestic animal?", ["Cats are pets.", "The moon is made of cheese."]

tok = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "left"

enc = tok(fmt(q, "", docs), return_tensors="pt", padding=True, truncation=True)

lm  = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).eval()
cls = AutoModelForSequenceClassification.from_pretrained(WRAP, torch_dtype=lm.dtype).eval()

with torch.no_grad():
    # bf16 round-trip on both sides mirrors the base card's rounding; comparison is done in f32
    lm_scores  = lm(**enc).logits[:, -1, 0].to(torch.bfloat16).float()
    cls_scores = cls(**enc).logits.squeeze(-1).to(torch.bfloat16).float()

print("allclose:", torch.allclose(lm_scores, cls_scores, atol=1e-6, rtol=1e-6))

Usage (SequenceClassification)

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"

def format_prompts(query, instruction, docs):
    inst = f" {instruction}" if instruction else ""
    return [f"Check whether a given document contains information helpful to answer the query.\n"
            f"<Document> {d}\n<Query> {query}{inst} ??" for d in docs]

tok = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "left"

model = AutoModelForSequenceClassification.from_pretrained(MODEL).eval()

query = "Which is a domestic animal?"
docs  = ["Cats are pets.", "The moon is made of cheese.", "Dogs are loyal companions."]

enc = tok(format_prompts(query, "", docs), return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(**enc).logits.squeeze(-1)          # one score per doc
    scores = logits.to(torch.bfloat16).float().tolist()

for s, d in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True):
    print(f"{s:.4f} | {d}")
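
On top of that, a small illustrative sketch of the serving/thresholding point; the 0.0 cutoff and the sigmoid squashing are assumptions for illustration, not part of the base model card:

# Continues from the variables above (scores, docs).
raw = torch.tensor(scores)
kept = [d for s, d in zip(raw.tolist(), docs) if s > 0.0]   # keep docs above a hard logit cutoff (illustrative)
probs = torch.sigmoid(raw).tolist()                          # optional squash to (0, 1) for serving
print(kept)
print([f"{p:.3f}" for p in probs])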

Notes

  • Keep the prompt template identical to the base card and use left padding; that’s what makes the “last token” alignment (and parity) reliable (see the small tokenizer-only illustration after these notes).
  • Requires transformers ≥ 4.51.0.
  • License is inherited: CC-BY-NC-SA 4.0 (non-commercial, share-alike, attribution).
  • Credit for the model and training goes to Contextual AI; this is just a packaging change for easier deployment.

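A quick tokenizer-only illustration of the left-padding point (no model load needed):

from transformers import AutoTokenizer

# With right padding, the last position of a shorter sequence in a batch is a pad token,
# so the LM readout logits[:, -1, :] would score padding instead of the real last token.
tok = AutoTokenizer.from_pretrained("ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
for side in ("right", "left"):
    tok.padding_side = side
    ids = tok(["short", "a much longer input string"], padding=True)["input_ids"]
    print(side, [row[-1] for row in ids])  # right: pad id on the short row; left: the real last token id
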
If this is useful, I’m happy to add wrappers for the other ctxl-rerank-v2 sizes as well.
