Converting a reranker model to a single label classification model
by sigridjineth
Hi all—sharing a small utility conversion I made to simplify using ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b in CrossEncoder/“classification” style pipelines.
- Model: https://huggingface.co/sigridjineth/ctxl-rerank-v2-1b-seq-cls
- What it is: a SequenceClassification (num_labels=1) wrapper that emits a single score per input, numerically identical to the original model's last-token next_logits[:, 0].
- Why: easier integration (Sentence-Transformers CrossEncoder, standard Transformers classification APIs), lower overhead (no full vocab projection), simpler serving/thresholding.
I'll attach the convert.py I used.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
convert.py
Qwen3 기반 reranker LM(예: ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b)을
단일 로짓 SequenceClassification 모델로 변환합니다.
원리:
- 모델이 마지막 토큰에서 vocab ID=k(기본 0)의 로짓을 '점수'로 쓰도록 학습되었다면,
lm_head.weight[k] (필요 시 bias[k])를 그대로 분류기 헤드에 이식하면
LM의 next_logits[:, k]와 동치인 단일 로짓을 얻을 수 있습니다.
주의:
- 좌패딩(left padding) + pad_token 설정을 권장합니다 (모델 카드 Quickstart도 동일).
- parity(수치 일치)를 위해 가중치 벡터에 BF16 round-trip 옵션 제공.
요구:
- transformers>=4.51.0, torch, safetensors
"""
import argparse
import os
import sys
from typing import List, Optional
import torch
from transformers import (
AutoTokenizer,
Qwen3ForCausalLM,
Qwen3ForSequenceClassification,
)
from transformers.utils.versions import require_version
# -----------------------------
# Helpers
# -----------------------------
def pick_device(name: str) -> str:
name = name.lower()
if name == "auto":
if torch.cuda.is_available():
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
if name in {"cuda", "cpu", "mps"}:
if name == "cuda" and not torch.cuda.is_available():
print("[warn] --device=cuda 이지만 CUDA 가용 X → cpu로 대체합니다.", file=sys.stderr)
return "cpu"
if name == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
print("[warn] --device=mps 이지만 MPS 가용 X → cpu로 대체합니다.", file=sys.stderr)
return "cpu"
return name
print(f"[warn] 알 수 없는 --device={name} → cpu 사용", file=sys.stderr)
return "cpu"
def pick_dtype(name: str, device: str) -> torch.dtype:
name = name.lower()
if name == "auto":
if device == "cuda":
return torch.bfloat16
return torch.float32
mapping = {
"float32": torch.float32,
"fp32": torch.float32,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float16": torch.float16,
"fp16": torch.float16,
}
if name in mapping:
return mapping[name]
print(f"[warn] 알 수 없는 --dtype={name} → float32 사용", file=sys.stderr)
return torch.float32
def ensure_tokenizer_padding(tok):
    # If there is no pad_token, reuse eos as pad
if tok.pad_token is None:
tok.pad_token = tok.eos_token
    # Left padding recommended (keeps the last valid token aligned across the batch)
tok.padding_side = "left"
def propagate_special_tokens_to_config(model, tok):
"""토크나이저의 special token ID를 모델 config에 반영 (배치>1 검증에 필요)"""
changed = []
if getattr(model.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
model.config.pad_token_id = tok.pad_token_id
changed.append("pad_token_id")
if getattr(model.config, "eos_token_id", None) is None and tok.eos_token_id is not None:
model.config.eos_token_id = tok.eos_token_id
changed.append("eos_token_id")
if getattr(model.config, "bos_token_id", None) is None and tok.bos_token_id is not None:
model.config.bos_token_id = tok.bos_token_id
changed.append("bos_token_id")
if changed:
print(f"[info] propagated to config: {', '.join(changed)}")
def build_prompts(query: str, instruction: str, docs: List[str]) -> List[str]:
"""모델 카드 권장 템플릿과 동일한 형태 (HF Quickstart 참조)."""
inst = f" {instruction}" if instruction else ""
return [
"Check whether a given document contains information helpful to answer the query.\n"
f"<Document> {d}\n"
f"<Query> {query}{inst} ??"
for d in docs
]
def get_lm_head(model) -> torch.nn.Module:
if hasattr(model, "lm_head"):
return model.lm_head
if hasattr(model, "get_output_embeddings"):
return model.get_output_embeddings()
raise AttributeError("LM head를 찾을 수 없습니다 (lm_head/get_output_embeddings 둘 다 없음).")
def find_classifier_head(model_seqcls):
    # Qwen3ForSequenceClassification usually exposes a 'score' layer
if hasattr(model_seqcls, "score"):
return model_seqcls.score
if hasattr(model_seqcls, "classifier"):
return model_seqcls.classifier
raise AttributeError("SequenceClassification 모델에서 분류 헤드(score/classifier)를 찾을 수 없습니다.")
def copy_head_weights(
lm_head: torch.nn.Linear,
cls_head: torch.nn.Linear,
score_token_id: int = 0,
parity_bf16_roundtrip: bool = True,
):
W = lm_head.weight.detach() # [vocab, hidden]
if score_token_id < 0 or score_token_id >= W.shape[0]:
raise IndexError(f"score_token_id={score_token_id} 가 유효 범위를 벗어났습니다 (0..{W.shape[0]-1}).")
vec = W[score_token_id] # (hidden,)
if parity_bf16_roundtrip:
vec = vec.to(torch.bfloat16).to(W.dtype)
    # If the LM head has a bias, carry over that channel's bias as well
bias_val = None
if getattr(lm_head, "bias", None) is not None:
b = lm_head.bias.detach()
bias_val = b[score_token_id].item()
with torch.no_grad():
cls_head.weight.copy_(vec.to(cls_head.weight.dtype).unsqueeze(0)) # (1, hidden)
if getattr(cls_head, "bias", None) is not None:
if bias_val is None:
cls_head.bias.zero_()
else:
cls_head.bias.fill_(bias_val)
def verify_parity(
tok,
lm: Qwen3ForCausalLM,
seqcls: Qwen3ForSequenceClassification,
device: str,
query: str = "Which is a domestic animal?",
docs: Optional[List[str]] = None,
instruction: str = "",
max_length: int = 8192,
score_token_id: int = 0,
use_bf16_roundtrip_on_lm_scores: bool = True,
) -> bool:
if docs is None:
docs = ["Cats are pets.", "The moon is made of cheese."]
    # Safety: propagate pad/eos/bos to the model configs as well
propagate_special_tokens_to_config(lm, tok)
propagate_special_tokens_to_config(seqcls, tok)
prompts = build_prompts(query, instruction, docs)
enc = tok(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
enc = {k: v.to(device) for k, v in enc.items()}
lm.eval()
seqcls.eval()
with torch.no_grad():
lm_logits = lm(**enc).logits[:, -1, :] # [B, V]
lm_scores = lm_logits[:, score_token_id]
if use_bf16_roundtrip_on_lm_scores:
            lm_scores = lm_scores.to(torch.bfloat16)  # keep the round-trip (comparison is done in f32 below)
cls_scores = seqcls(**enc).logits.squeeze(-1)
    # (important) align dtype/device before comparing/printing
lm_scores = lm_scores.float().detach().cpu()
cls_scores = cls_scores.float().detach().cpu()
print("[verify] LM scores :", lm_scores[: min(3, lm_scores.shape[0])].tolist())
print("[verify] SeqCLS scores:", cls_scores[: min(3, cls_scores.shape[0])].tolist())
ok = torch.allclose(lm_scores, cls_scores, atol=1e-6, rtol=1e-6)
max_abs = (lm_scores - cls_scores).abs().max().item()
print(f"[verify] allclose = {ok} (max_abs_diff={max_abs:.3e})")
return bool(ok)
# -----------------------------
# Main
# -----------------------------
def main():
parser = argparse.ArgumentParser(description="Convert Qwen3-based reranker LM to single-logit SequenceClassification.")
parser.add_argument("--source", type=str, required=True, help="소스 LM 모델 ID 또는 로컬 경로 (예: ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b)")
parser.add_argument("--out", type=str, required=True, help="변환된 seq-cls 모델 저장 디렉터리")
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu", "mps"], help="사용 디바이스")
parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "fp32", "bfloat16", "bf16", "float16", "fp16"], help="모델 로드 dtype")
parser.add_argument("--score-token-id", type=int, default=0, help="점수 채널에 해당하는 vocab ID (기본=0)")
parser.add_argument("--no-bf16-roundtrip", action="store_true", help="가중치/점수 BF16 round-trip 비활성화")
parser.add_argument("--max-length", type=int, default=8192, help="검증 시 truncation 길이")
parser.add_argument("--verify", action="store_true", help="변환 후 동치성 검증 수행")
parser.add_argument("--verify-query", type=str, default="Which is a domestic animal?", help="검증용 쿼리")
parser.add_argument("--verify-docs", type=str, nargs="*", default=["Cats are pets.", "The moon is made of cheese."], help="검증용 문서들")
parser.add_argument("--verify-instruction", type=str, default="", help="검증용 instruction (선택)")
parser.add_argument("--push-to-hub", action="store_true", help="허깅페이스 허브 업로드")
parser.add_argument("--repo-id", type=str, default=None, help="업로드 리포지토리 ID (예: user/repo)")
parser.add_argument("--private", action="store_true", help="허브 업로드 시 private")
args = parser.parse_args()
try:
require_version("transformers>=4.51.0")
except Exception:
print("[error] transformers>=4.51.0 필요. 업그레이드: pip install -U 'transformers>=4.51.0'", file=sys.stderr)
raise
device = pick_device(args.device)
dtype = pick_dtype(args.dtype, device)
parity_bf16_roundtrip = not args.no_bf16_roundtrip
print(f"[info] device={device}, dtype={dtype}, source={args.source}")
print(f"[info] score_token_id={args.score_token_id}, bf16_roundtrip={parity_bf16_roundtrip}")
    # Tokenizer
tok = AutoTokenizer.from_pretrained(args.source, use_fast=True)
ensure_tokenizer_padding(tok)
    # Load the source LM
lm = Qwen3ForCausalLM.from_pretrained(args.source, torch_dtype=dtype).to(device).eval()
propagate_special_tokens_to_config(lm, tok)
    # Load the SequenceClassification model (its score head is freshly initialized)
seqcls = Qwen3ForSequenceClassification.from_pretrained(
args.source,
num_labels=1,
ignore_mismatched_sizes=True,
torch_dtype=dtype,
).to(device).eval()
propagate_special_tokens_to_config(seqcls, tok)
    # Transplant the head
lm_head = get_lm_head(lm)
cls_head = find_classifier_head(seqcls)
copy_head_weights(
lm_head=lm_head,
cls_head=cls_head,
score_token_id=args.score_token_id,
parity_bf16_roundtrip=parity_bf16_roundtrip,
)
    # Metadata
seqcls.config.problem_type = "single_label_classification"
seqcls.config.id2label = {0: "SCORE"}
seqcls.config.label2id = {"SCORE": 0}
    # Save
os.makedirs(args.out, exist_ok=True)
seqcls.save_pretrained(args.out)
tok.save_pretrained(args.out)
print(f"[save] saved to: {args.out}")
    # Verification (optional)
if args.verify:
ok = verify_parity(
tok=tok,
lm=lm,
seqcls=seqcls,
device=device,
query=args.verify_query,
docs=args.verify_docs,
instruction=args.verify_instruction,
max_length=args.max_length,
score_token_id=args.score_token_id,
use_bf16_roundtrip_on_lm_scores=parity_bf16_roundtrip,
)
if not ok:
print("[warn] 검증 allclose=False 입니다. 템플릿/패딩/round-trip/dtype을 점검하세요.", file=sys.stderr)
    # Push to the Hub (optional)
if args.push_to_hub:
if not args.repo_id:
print("[error] --push-to-hub 사용 시 --repo-id를 지정하세요 (예: user/repo).", file=sys.stderr)
sys.exit(2)
print(f"[hub] pushing to: {args.repo_id} (private={args.private})")
seqcls.push_to_hub(args.repo_id, private=args.private)
tok.push_to_hub(args.repo_id, private=args.private)
print("[hub] push complete.")
if __name__ == "__main__":
main()
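For reference, an invocation along these lines should work (the output path here is illustrative; see the argparse options in the script):
python convert.py --source ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b --out ./ctxl-rerank-v2-1b-seq-cls --verify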
Method (short)
- Load the base LM and read lm_head.weight.
- Copy row 0 (the score channel used by the original card) into the classifier head: seq_cls.score.weight ← lm_head.weight[0] (bias = 0, or the matching bias if present).
- Propagate tokenizer specials to the model config (pad_token_id, eos_token_id, etc.) and use left padding to keep the last token aligned across a batch.
- Optional: cast scores through bf16 → float at readout to match the base card's rounding.
Result: SequenceClassification logit ≡ original next_logits[:, 0], given the same prompt template and padding.
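The core of the transplant is just a one-row copy. A minimal sketch, assuming lm and seq_cls are already loaded as in convert.py above (bias handling omitted):
import torch

# Sketch only: lm is the base CausalLM, seq_cls is the num_labels=1 SequenceClassification model.
with torch.no_grad():
    W = lm.lm_head.weight                          # (vocab, hidden)
    vec = W[0].to(torch.bfloat16).to(W.dtype)      # row 0 = score channel, optional BF16 round-trip
    seq_cls.score.weight.copy_(vec.unsqueeze(0))   # classifier weight has shape (1, hidden)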
Quick parity check (minimal)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
BASE = "ContextualAI/ctxl-rerank-v2-instruct-multilingual-1b"
WRAP = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"
def fmt(q, inst, docs):
inst = f" {inst}" if inst else ""
return [f"Check whether a given document contains information helpful to answer the query.\n"
f"<Document> {d}\n<Query> {q}{inst} ??" for d in docs]
q, docs = "Which is a domestic animal?", ["Cats are pets.", "The moon is made of cheese."]
tok = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "left"
enc = tok(fmt(q, "", docs), return_tensors="pt", padding=True, truncation=True)
lm = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).eval()
cls = AutoModelForSequenceClassification.from_pretrained(WRAP, torch_dtype=lm.dtype).eval()
with torch.no_grad():
lm_scores = lm(**enc).logits[:, -1, 0].to(torch.bfloat16).float()
cls_scores = cls(**enc).logits.squeeze(-1).to(torch.bfloat16).float()
print("allclose:", torch.allclose(lm_scores, cls_scores, atol=1e-6, rtol=1e-6))
Usage (SequenceClassification)
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
MODEL = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"
def format_prompts(query, instruction, docs):
inst = f" {instruction}" if instruction else ""
return [f"Check whether a given document contains information helpful to answer the query.\n"
f"<Document> {d}\n<Query> {query}{inst} ??" for d in docs]
tok = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "left"
model = AutoModelForSequenceClassification.from_pretrained(MODEL).eval()
query = "Which is a domestic animal?"
docs = ["Cats are pets.", "The moon is made of cheese.", "Dogs are loyal companions."]
enc = tok(format_prompts(query, "", docs), return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
logits = model(**enc).logits.squeeze(-1) # one score per doc
scores = logits.to(torch.bfloat16).float().tolist()
for s, d in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True):
print(f"{s:.4f} | {d}")
Notes
- Keep the prompt template identical to the base card and use left padding; that’s what makes the “last token” alignment (and parity) reliable.
- Requires transformers ≥ 4.51.0.
- License is inherited: CC-BY-NC-SA 4.0 (non-commercial, share-alike, attribution).
- Credit for the model and training goes to Contextual AI; this is just a packaging change for easier deployment.
If this is useful, I’m happy to add wrappers for the other ctxl-rerank-v2 sizes as well.