oist committed
Commit ca6801c · 1 Parent(s): 467be55

Add BLASER-REF model and config

Files changed (4)
  1. README.md +288 -0
  2. config.json +17 -0
  3. model.safetensors +3 -0
  4. modeling_blaser.py +136 -0
README.md ADDED
@@ -0,0 +1,288 @@
+ ---
+ license: cc-by-nc-4.0
+ language:
+ - ace
+ - acm
+ - acq
+ - aeb
+ - af
+ - ajp
+ - ak
+ - am
+ - apc
+ - ar
+ - ars
+ - ary
+ - arz
+ - as
+ - ast
+ - awa
+ - ay
+ - azb
+ - azj
+ - ba
+ - bm
+ - ban
+ - be
+ - bem
+ - bn
+ - bho
+ - bjn
+ - bo
+ - bs
+ - bug
+ - bg
+ - ca
+ - ceb
+ - cs
+ - cjk
+ - ckb
+ - crh
+ - cy
+ - da
+ - de
+ - dik
+ - dyu
+ - dz
+ - el
+ - en
+ - eo
+ - et
+ - eu
+ - ee
+ - fo
+ - fa
+ - fj
+ - fi
+ - fon
+ - fr
+ - fur
+ - ff
+ - gd
+ - ga
+ - gl
+ - gn
+ - gu
+ - ht
+ - ha
+ - he
+ - hi
+ - hne
+ - hr
+ - hu
+ - hy
+ - ig
+ - ilo
+ - id
+ - is
+ - it
+ - jv
+ - ja
+ - kab
+ - kac
+ - kam
+ - kn
+ - ks
+ - ka
+ - kr
+ - kk
+ - kbp
+ - kea
+ - km
+ - ki
+ - rw
+ - ky
+ - kmb
+ - kg
+ - ko
+ - kmr
+ - lo
+ - lv
+ - lij
+ - li
+ - ln
+ - lt
+ - lmo
+ - ltg
+ - lb
+ - lua
+ - lg
+ - luo
+ - lus
+ - mag
+ - mai
+ - ml
+ - mr
+ - min
+ - mk
+ - plt
+ - mt
+ - mni
+ - mn
+ - mos
+ - mi
+ - ms
+ - my
+ - nl
+ - nn
+ - nb
+ - ne
+ - nso
+ - nus
+ - ny
+ - oc
+ - gaz
+ - ory
+ - pag
+ - pa
+ - pap
+ - pl
+ - pt
+ - prs
+ - pbt
+ - qu
+ - ro
+ - rn
+ - ru
+ - sg
+ - sa
+ - sat
+ - scn
+ - shn
+ - si
+ - sk
+ - sl
+ - sm
+ - sn
+ - sd
+ - so
+ - st
+ - es
+ - als
+ - sc
+ - sr
+ - ss
+ - su
+ - sv
+ - sw
+ - szl
+ - ta
+ - tt
+ - te
+ - tg
+ - tl
+ - th
+ - ti
+ - taq
+ - tpi
+ - tn
+ - ts
+ - tk
+ - tum
+ - tr
+ - tw
+ - tzm
+ - ug
+ - uk
+ - umb
+ - ur
+ - uz
+ - vec
+ - vi
+ - war
+ - wo
+ - xh
+ - yi
+ - yo
+ - yue
+ - zh
+ - zu
+ language_details: >-
+   ace_Arab, ace_Latn, acm_Arab, acq_Arab, aeb_Arab, afr_Latn, ajp_Arab,
+   aka_Latn, amh_Ethi, apc_Arab, arb_Arab, ars_Arab, ary_Arab, arz_Arab,
+   asm_Beng, ast_Latn, awa_Deva, ayr_Latn, azb_Arab, azj_Latn, bak_Cyrl,
+   bam_Latn, ban_Latn, bel_Cyrl, bem_Latn, ben_Beng, bho_Deva, bjn_Arab,
+   bod_Tibt, bos_Latn, bug_Latn, bul_Cyrl, cat_Latn, ceb_Latn, ces_Latn,
+   cjk_Latn, ckb_Arab, crh_Latn, cym_Latn, dan_Latn, deu_Latn, dik_Latn,
+   dyu_Latn, dzo_Tibt, ell_Grek, eng_Latn, epo_Latn, est_Latn, eus_Latn,
+   ewe_Latn, fao_Latn, pes_Arab, fij_Latn, fin_Latn, fon_Latn, fra_Latn,
+   fur_Latn, fuv_Latn, gla_Latn, gle_Latn, glg_Latn, grn_Latn, guj_Gujr,
+   hat_Latn, hau_Latn, heb_Hebr, hin_Deva, hne_Deva, hrv_Latn, hun_Latn,
+   hye_Armn, ibo_Latn, ilo_Latn, ind_Latn, isl_Latn, ita_Latn, jav_Latn,
+   jpn_Jpan, kab_Latn, kac_Latn, kam_Latn, kan_Knda, kas_Arab, kas_Deva,
+   kat_Geor, knc_Arab, knc_Latn, kaz_Cyrl, kbp_Latn, kea_Latn, khm_Khmr,
+   kik_Latn, kin_Latn, kir_Cyrl, kmb_Latn, kon_Latn, kor_Hang, kmr_Latn,
+   lao_Laoo, lvs_Latn, lij_Latn, lim_Latn, lin_Latn, lit_Latn, lmo_Latn,
+   ltg_Latn, ltz_Latn, lua_Latn, lug_Latn, luo_Latn, lus_Latn, mag_Deva,
+   mai_Deva, mal_Mlym, mar_Deva, min_Latn, mkd_Cyrl, plt_Latn, mlt_Latn,
+   mni_Beng, khk_Cyrl, mos_Latn, mri_Latn, zsm_Latn, mya_Mymr, nld_Latn,
+   nno_Latn, nob_Latn, npi_Deva, nso_Latn, nus_Latn, nya_Latn, oci_Latn,
+   gaz_Latn, ory_Orya, pag_Latn, pan_Guru, pap_Latn, pol_Latn, por_Latn,
+   prs_Arab, pbt_Arab, quy_Latn, ron_Latn, run_Latn, rus_Cyrl, sag_Latn,
+   san_Deva, sat_Beng, scn_Latn, shn_Mymr, sin_Sinh, slk_Latn, slv_Latn,
+   smo_Latn, sna_Latn, snd_Arab, som_Latn, sot_Latn, spa_Latn, als_Latn,
+   srd_Latn, srp_Cyrl, ssw_Latn, sun_Latn, swe_Latn, swh_Latn, szl_Latn,
+   tam_Taml, tat_Cyrl, tel_Telu, tgk_Cyrl, tgl_Latn, tha_Thai, tir_Ethi,
+   taq_Latn, taq_Tfng, tpi_Latn, tsn_Latn, tso_Latn, tuk_Latn, tum_Latn,
+   tur_Latn, twi_Latn, tzm_Tfng, uig_Arab, ukr_Cyrl, umb_Latn, urd_Arab,
+   uzn_Latn, vec_Latn, vie_Latn, war_Latn, wol_Latn, xho_Latn, ydd_Hebr,
+   yor_Latn, yue_Hant, zho_Hans, zho_Hant, zul_Latn
+ pipeline_tag: sentence-similarity
+ ---
+
+ # BLASER REF (Ported)
+
+ This is a **ported version of the BLASER 2.0 reference-based (REF) evaluation model** originally released as [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref).
+
+ - **Ported to Hugging Face Transformers**: no dependency on Fairseq.
+ - **Uses embeddings from the ported SONAR 200 multilingual text encoder** ([cointegrated/SONAR_200_text_encoder](https://huggingface.co/cointegrated/SONAR_200_text_encoder)).
+ - **Supports the same 202 languages** as SONAR / NLLB-200.
+ - **Outputs BLASER scores on a 1–5 scale** for a source–MT–reference triplet.
+
+ > ⚠️ This is **not the original implementation**. Attribution goes to the original BLASER authors.
+
+ ---
+
+ ## How to compute BLASER REF scores
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModel
+ from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
+
+ # 1. Load the SONAR encoder and its tokenizer
+ sonar_model_name = "cointegrated/SONAR_200_text_encoder"
+ encoder = M2M100Encoder.from_pretrained(sonar_model_name)
+ tokenizer = AutoTokenizer.from_pretrained(sonar_model_name)
+
+ def encode_mean_pool(texts, tokenizer, encoder, lang='eng_Latn', norm=False):
+     """Embed texts by mean-pooling encoder states over non-padding tokens."""
+     tokenizer.src_lang = lang
+     with torch.inference_mode():
+         batch = tokenizer(texts, return_tensors='pt', padding=True)
+         seq_embs = encoder(**batch).last_hidden_state
+         mask = batch.attention_mask
+         mean_emb = (seq_embs * mask.unsqueeze(-1)).sum(1) / mask.unsqueeze(-1).sum(1)
+         if norm:
+             mean_emb = torch.nn.functional.normalize(mean_emb)
+     return mean_emb
+
+ # Example sentences
+ src_sentences = ["Le chat s'assit sur le tapis."]
+ mt_sentences = ["The cat sat down on the carpet."]  # example MT output
+ ref_sentences = ["The cat sat on the mat."]  # example reference translation
+
+ # Encode source, MT, and reference sentences
+ src_embs = encode_mean_pool(src_sentences, tokenizer, encoder, lang="fra_Latn")
+ mt_embs = encode_mean_pool(mt_sentences, tokenizer, encoder, lang="eng_Latn")
+ ref_embs = encode_mean_pool(ref_sentences, tokenizer, encoder, lang="eng_Latn")
+
+ # 2. Load the ported BLASER REF model
+ ref_model_name = "oist/blaser-2.0-ref-ported"
+ ref_model = AutoModel.from_pretrained(ref_model_name, trust_remote_code=True)
+ ref_model.eval()  # set to evaluation mode
+
+ # 3. Compute BLASER REF scores
+ with torch.inference_mode():
+     ref_scores = ref_model(src_embs, mt_embs, ref_embs)  # expects source, MT, and reference embeddings
+ print("Blaser score shape:", ref_scores.shape)
+ print("Blaser scores:", ref_scores[0])
+ ```
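Since `output_dim` is 1, `ref_scores` has shape `(batch_size, 1)`. A small follow-up sketch (not part of the original card) for turning the output into plain floats:

```python
# Each row of ref_scores holds one BLASER score for a (source, MT, reference) triplet
scores = ref_scores.squeeze(-1).tolist()
print(scores)
```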
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "activation": "TANH",
+   "architectures": ["BlaserModel"],
+   "dropout": 0.1,
+   "embedding_dim": 1024,
+   "hidden_dims": [3072, 1536],
+   "input_form": "COMET",
+   "model_type": "blaser",
+   "norm_emb": true,
+   "output_act": false,
+   "output_dim": 1,
+   "transformers_version": "4.56.1",
+   "auto_map": {
+     "AutoConfig": "modeling_blaser.BlaserConfig",
+     "AutoModel": "modeling_blaser.BlaserModel"
+   }
+ }
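With `input_form` set to `COMET`, the model concatenates six 1024-dim feature vectors, so the scoring MLP is 6144 → 3072 → 1536 → 1. A minimal sketch of how these fields map to the architecture (illustrative, not part of the repo):

```python
from transformers import AutoConfig

# BlaserConfig is resolved from modeling_blaser.py via the auto_map entry
config = AutoConfig.from_pretrained("oist/blaser-2.0-ref-ported", trust_remote_code=True)

# COMET featurization concatenates 6 vectors of embedding_dim each
input_dim = 6 * config.embedding_dim
print(input_dim, config.hidden_dims, config.output_dim)  # 6144 [3072, 1536] 1
```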
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4836d62d1e5540890dad7a9ac6f41317522a71dd195f3a813c991c87522225c1
+ size 94396980
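The file size is consistent with the architecture above: a 6144 → 3072 → 1536 → 1 MLP has 23,599,105 weights and biases, i.e. 94,396,420 bytes in fp32, leaving roughly 560 bytes for the safetensors header. A quick arithmetic check:

```python
# Parameter count of the 6144 -> 3072 -> 1536 -> 1 MLP (weights + biases per layer)
dims = [6144, 3072, 1536, 1]
n_params = sum(d_in * d_out + d_out for d_in, d_out in zip(dims, dims[1:]))
print(n_params)      # 23599105
print(4 * n_params)  # 94396420 bytes of fp32 weights; the remainder is the safetensors header
```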
modeling_blaser.py ADDED
@@ -0,0 +1,136 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import List, Optional
+ from torch import Tensor
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ # ---------------- CONFIG ---------------- #
+ class BlaserConfig(PretrainedConfig):
+     model_type = "blaser"
+
+     def __init__(
+         self,
+         embedding_dim=1024,
+         output_dim=1,
+         hidden_dims=None,
+         dropout=0.1,
+         activation="TANH",
+         input_form="COMET",
+         norm_emb=True,
+         output_act=False,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.embedding_dim = embedding_dim
+         self.output_dim = output_dim
+         self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
+         self.dropout = dropout
+         self.activation = activation
+         self.input_form = input_form
+         self.norm_emb = norm_emb
+         self.output_act = output_act
+
+
+ # ---------------- CORE MODEL ---------------- #
+ ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}
+
+
+ class BlaserCore(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         output_dim: int,
+         hidden_dims: List[int],
+         dropout: float,
+         activation: str,
+         input_form: str,
+         norm_emb: bool,
+         output_act: bool,
+     ):
+         super().__init__()
+         self.input_form = input_form
+         self.norm_emb = norm_emb
+
+         if input_form == "COMET":
+             embedding_dim *= 6
+         elif input_form == "QE":
+             embedding_dim *= 4
+         else:
+             raise ValueError(f"Unrecognized input_form: {input_form}")
+         if activation not in ACTIVATIONS:
+             raise ValueError(f"Unrecognized activation: {activation}")
+
+         modules: List[nn.Module] = []
+         if hidden_dims:
+             if dropout > 0:
+                 modules.append(nn.Dropout(p=dropout))
+             nprev = embedding_dim
+             for h in hidden_dims:
+                 modules.append(nn.Linear(nprev, h))
+                 modules.append(ACTIVATIONS[activation]())
+                 if dropout > 0:
+                     modules.append(nn.Dropout(p=dropout))
+                 nprev = h
+             modules.append(nn.Linear(nprev, output_dim))
+             if output_act:
+                 modules.append(nn.Tanh())
+         else:
+             modules.append(nn.Linear(embedding_dim, output_dim))
+
+         self.mlp = nn.Sequential(*modules)
+
+     def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
+         return F.normalize(emb) if (emb is not None and self.norm_emb) else emb
+
+     def _featurize(self, src: Tensor, mt: Tensor, ref: Optional[Tensor] = None) -> Tensor:
+         if self.input_form == "COMET":
+             if ref is None:
+                 raise ValueError("COMET input_form requires reference embedding")
+             return torch.cat(
+                 [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
+                 dim=-1,
+             )
+         elif self.input_form == "QE":
+             return torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
+
+
+ # ---------------- HF MODEL WRAPPER ---------------- #
+ class BlaserModel(PreTrainedModel):
+     config_class = BlaserConfig
+
+     def __init__(self, config: BlaserConfig):
+         super().__init__(config)
+         # Directly assign the Sequential MLP to self.mlp
+         core = BlaserCore(
+             embedding_dim=config.embedding_dim,
+             output_dim=config.output_dim,
+             hidden_dims=config.hidden_dims,
+             dropout=config.dropout,
+             activation=config.activation,
+             input_form=config.input_form,
+             norm_emb=config.norm_emb,
+             output_act=config.output_act,
+         )
+         self.mlp = core.mlp
+         self.input_form = core.input_form
+         self.norm_emb = core.norm_emb
+
+     def forward(self, src, mt, ref=None):
+         # Use the same featurization as in BlaserCore
+         src = F.normalize(src) if self.norm_emb else src
+         mt = F.normalize(mt) if self.norm_emb else mt
+         ref = F.normalize(ref) if (ref is not None and self.norm_emb) else ref
+
+         if self.input_form == "COMET":
+             if ref is None:
+                 raise ValueError("COMET input_form requires reference embedding")
+             proc = torch.cat(
+                 [ref, mt, src * mt, ref * mt, torch.abs(mt - src), torch.abs(mt - ref)],
+                 dim=-1,
+             )
+         else:  # QE
+             proc = torch.cat([src, mt, src * mt, torch.abs(mt - src)], dim=-1)
+
+         return self.mlp(proc)
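A quick smoke test of the wrapper with random embeddings (illustrative only; a freshly initialized model will not produce meaningful scores):

```python
import torch
from modeling_blaser import BlaserConfig, BlaserModel

config = BlaserConfig()  # defaults mirror config.json: COMET input form, 1024-dim embeddings
model = BlaserModel(config).eval()

# Two random (source, MT, reference) embedding triplets
src, mt, ref = (torch.randn(2, 1024) for _ in range(3))
with torch.inference_mode():
    scores = model(src, mt, ref)
print(scores.shape)  # torch.Size([2, 1]): one score per triplet
```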