waveletdeboshir committed
Commit b918931 · verified · 1 Parent(s): 5707ac4

Upload 14 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ language_model/unigrams.txt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,69 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ language:
+ - ru
+ pipeline_tag: automatic-speech-recognition
+ library_name: transformers
+ tags:
+ - asr
+ - gigaam
+ - stt
+ - ru
+ - ctc
+ - audio
+ - speech
+ ---
+
+ [![Finetune In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/waveletdeboshir/c01334561f23c5167598b2054e50839a/gigaam-ctc-hf-finetune.ipynb)
+
+ # GigaAM-v2-CTC 🤗 Hugging Face transformers
+
+ * Original repository: https://github.com/salute-developers/GigaAM
+
+ Russian ASR model GigaAM-v2-CTC.
+
+ ## Model info
+ This is the original GigaAM-v2-CTC model with a `transformers` library interface.
+
+ The file [`gigaam_transformers.py`](https://huggingface.co/waveletdeboshir/gigaam-ctc/blob/main/gigaam_transformers.py) contains the model, feature extractor and tokenizer classes with the usual `transformers` methods. The model can be initialized with the `transformers` auto classes (see the example below).
+
+ ## Installation
+
+ Tested with the following library versions:
+ * `torch` 2.5.1
+ * `torchaudio` 2.5.1
+ * `transformers` 4.49.0
+
+ ## Usage
+ Usage is the same as for other `transformers` ASR models.
+
+ ```python
+ from transformers import AutoModel, AutoProcessor
+ import torch
+ import torchaudio
+
+ # load audio
+ wav, sr = torchaudio.load("audio.wav")
+ # resample if necessary
+ wav = torchaudio.functional.resample(wav, sr, 16000)
+
+ # load model and processor
+ processor = AutoProcessor.from_pretrained("waveletdeboshir/gigaam-ctc", trust_remote_code=True)
+ model = AutoModel.from_pretrained("waveletdeboshir/gigaam-ctc", trust_remote_code=True)
+ model.eval()
+
+ input_features = processor(wav[0], sampling_rate=16000, return_tensors="pt")
+
+ # predict
+ with torch.no_grad():
+     logits = model(**input_features).logits
+ # greedy decoding
+ greedy_ids = logits.argmax(dim=-1)
+ # decode token ids to text
+ transcription = processor.batch_decode(greedy_ids)[0]
+ ```
+
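+ This repo also ships a `language_model/` folder (KenLM 3-gram plus unigrams), and `AutoProcessor` maps to `GigaAMProcessorWithLM`, which follows the `Wav2Vec2ProcessorWithLM` interface. Continuing from the snippet above, a hedged sketch of LM beam-search decoding, assuming `pyctcdecode` and `kenlm` are installed and that `batch_decode` takes numpy logits as in `Wav2Vec2ProcessorWithLM`:
+
+ ```python
+ # beam-search decoding with the shipped KenLM model (requires pyctcdecode + kenlm)
+ with torch.no_grad():
+     logits = model(**input_features).logits  # (batch, time, vocab)
+ # Wav2Vec2ProcessorWithLM-style batch_decode works on numpy arrays of logits
+ lm_transcription = processor.batch_decode(logits.cpu().numpy()).text[0]
+ ```
+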
+ ## Fine-tune
+ [![Finetune In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/waveletdeboshir/c01334561f23c5167598b2054e50839a/gigaam-ctc-hf-finetune.ipynb)
+ [Fine-tuning Jupyter](https://gist.github.com/waveletdeboshir/c01334561f23c5167598b2054e50839a)
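+
+ For reference when writing your own training loop: the model's `forward` accepts `labels` and keeps only entries `>= 0` for the CTC loss, so padded label positions should be filled with `-100`. Below is a minimal, hypothetical collator sketch (not taken from the notebook above); the dataset fields `audio` and `text` are assumptions:
+
+ ```python
+ import torch
+
+ def collate_for_ctc(batch, processor):
+     """Pad audio with the feature extractor and pad label ids with -100
+     so they are ignored by the CTC loss (forward keeps labels >= 0)."""
+     audio = [example["audio"]["array"] for example in batch]
+     inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+     label_ids = [processor.tokenizer(example["text"]).input_ids for example in batch]
+     max_len = max(len(ids) for ids in label_ids)
+     labels = torch.full((len(label_ids), max_len), -100, dtype=torch.long)
+     for i, ids in enumerate(label_ids):
+         labels[i, : len(ids)] = torch.tensor(ids)
+     inputs["labels"] = labels
+     return inputs
+ ```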
added_tokens.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "</s>": 35,
+   "<s>": 34
+ }
alphabet.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "labels": [" ", "\u0430", "\u0431", "\u0432", "\u0433", "\u0434", "\u0435", "\u0436", "\u0437", "\u0438", "\u0439", "\u043a", "\u043b", "\u043c", "\u043d", "\u043e", "\u043f", "\u0440", "\u0441", "\u0442", "\u0443", "\u0444", "\u0445", "\u0446", "\u0447", "\u0448", "\u0449", "\u044a", "\u044b", "\u044c", "\u044d", "\u044e", "\u044f", ""],
+   "is_bpe": false
+ }
config.json ADDED
@@ -0,0 +1,64 @@
+ {
+   "auto_map": {
+     "AutoConfig": "gigaam_transformers.GigaAMConfig",
+     "AutoModel": "gigaam_transformers.GigaAMCTCHF",
+     "AutoModelForCTC": "gigaam_transformers.GigaAMCTCHF",
+     "AutoProcessor": "gigaam_transformers.GigaAMProcessorWithLM",
+     "AutoTokenizer": "gigaam_transformers.GigaAMCTCTokenizer",
+     "AutoFeatureExtractor": "gigaam_transformers.GigaAMFeatureExtractor"
+   },
+
+   "encoder": {
+     "feat_in": 64,
+     "n_layers": 16,
+     "d_model": 768,
+     "subsampling_factor": 4,
+     "ff_expansion_factor": 4,
+     "self_attention_model": "rotary",
+     "pos_emb_max_len": 5000,
+     "n_heads": 16,
+     "conv_kernel_size": 31,
+     "flash_attn": false
+   },
+   "head": {
+     "feat_in": 768,
+     "num_classes": 34
+   },
+   "labels": [
+     " ",
+     "а",
+     "б",
+     "в",
+     "г",
+     "д",
+     "е",
+     "ж",
+     "з",
+     "и",
+     "й",
+     "к",
+     "л",
+     "м",
+     "н",
+     "о",
+     "п",
+     "р",
+     "с",
+     "т",
+     "у",
+     "ф",
+     "х",
+     "ц",
+     "ч",
+     "ш",
+     "щ",
+     "ъ",
+     "ы",
+     "ь",
+     "э",
+     "ю",
+     "я"
+   ],
+   "blank_id": 33,
+   "model_type": "gigaam-ctc"
+ }
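
For reference, the `encoder` and `head` blocks above are passed straight into the model constructor in `gigaam_transformers.py` (`GigaAMCTC(config.encoder, config.head)`). A minimal sketch of loading and inspecting them, assuming the custom config exposes its extra fields as plain dict attributes:

```python
from transformers import AutoConfig

# trust_remote_code is required because GigaAMConfig is defined in this repo
config = AutoConfig.from_pretrained("waveletdeboshir/gigaam-ctc", trust_remote_code=True)
print(config.encoder["d_model"], config.encoder["n_layers"])  # 768, 16
print(config.head["num_classes"], config.blank_id)            # 34, 33
```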
encoder.py ADDED
@@ -0,0 +1,604 @@
1
+ """Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py"""
2
+ import math
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+ # try:
10
+ # from flash_attn import flash_attn_func
11
+
12
+ # IMPORT_FLASH = True
13
+ # except Exception as err:
14
+ # IMPORT_FLASH = False
15
+ # IMPORT_FLASH_ERR = err
16
+
17
+ IMPORT_FLASH = False
18
+ IMPORT_FLASH_ERR = "Flash Attention not installed."
19
+
20
+ # from .utils import apply_masked_flash_attn, apply_rotary_pos_emb
21
+
22
+
23
+ def rtt_half(x: Tensor) -> Tensor:
24
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
25
+ return torch.cat([-x2, x1], dim=x1.ndim - 1)
26
+
27
+
28
+ def apply_rotary_pos_emb(
29
+ q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
30
+ ) -> Tuple[Tensor, Tensor]:
31
+ """
32
+ Applies Rotary Position Embeddings to query and key tensors.
33
+ """
34
+ cos, sin = (
35
+ cos[offset : q.shape[0] + offset, ...],
36
+ sin[offset : q.shape[0] + offset, ...],
37
+ )
38
+ return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
39
+
40
+
41
+ # def apply_masked_flash_attn(
42
+ # q: Tensor,
43
+ # k: Tensor,
44
+ # v: Tensor,
45
+ # mask: Tensor,
46
+ # h: int,
47
+ # d_k: int,
48
+ # ) -> Tensor:
49
+ # """
50
+ # Applies Flash Attention with padding masks.
51
+ # """
52
+
53
+ # from einops import rearrange
54
+ # from flash_attn import flash_attn_varlen_func
55
+ # from flash_attn.bert_padding import pad_input, unpad_input
56
+
57
+ # pad_mask = ~mask[:, 0, :]
58
+ # b, t = pad_mask.shape
59
+ # q = q.view(b, t, h * d_k)
60
+ # k = k.view(b, t, h * d_k)
61
+ # v = v.view(b, t, h * d_k)
62
+
63
+ # q_unpad, indices_q, _, max_seqlen_q = unpad_input(q, pad_mask)[:4]
64
+ # q_unpad = rearrange(q_unpad, "nnz (h d) -> nnz h d", h=h)
65
+
66
+ # k_unpad = unpad_input(k, pad_mask)[0]
67
+ # k_unpad = rearrange(k_unpad, "nnz (h d) -> nnz h d", h=h)
68
+
69
+ # v_unpad = unpad_input(v, pad_mask)[0]
70
+ # v_unpad = rearrange(v_unpad, "nnz (h d) -> nnz h d", h=h)
71
+
72
+ # lengths_q = pad_mask.sum(1).to(torch.int32).to(q.device)
73
+ # cu_seqlens_q = F.pad(lengths_q.cumsum(0), (1, 0), value=0).to(torch.int32)
74
+ # max_seqlen_q = torch.max(lengths_q)
75
+
76
+ # output_unpad = flash_attn_varlen_func(
77
+ # q_unpad,
78
+ # k_unpad,
79
+ # v_unpad,
80
+ # cu_seqlens_q,
81
+ # cu_seqlens_q,
82
+ # max_seqlen_q,
83
+ # max_seqlen_q,
84
+ # )
85
+
86
+ # scores = pad_input(
87
+ # rearrange(output_unpad, "nnz h d -> nnz (h d)"),
88
+ # indices_q,
89
+ # b,
90
+ # t,
91
+ # )
92
+
93
+ # return scores
94
+
95
+
96
+ class StridingSubsampling(nn.Module):
97
+ """
98
+ Strided Subsampling layer used to reduce the sequence length.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ subsampling_factor: int,
104
+ feat_in: int,
105
+ feat_out: int,
106
+ conv_channels: int,
107
+ ):
108
+ super().__init__()
109
+ self._sampling_num = int(math.log(subsampling_factor, 2))
110
+ self._stride = 2
111
+ self._kernel_size = 3
112
+ self._padding = (self._kernel_size - 1) // 2
113
+
114
+ layers: List[nn.Module] = []
115
+ in_channels = 1
116
+ for _ in range(self._sampling_num):
117
+ layers.append(
118
+ torch.nn.Conv2d(
119
+ in_channels=in_channels,
120
+ out_channels=conv_channels,
121
+ kernel_size=self._kernel_size,
122
+ stride=self._stride,
123
+ padding=self._padding,
124
+ )
125
+ )
126
+ layers.append(nn.ReLU())
127
+ in_channels = conv_channels
128
+
129
+ out_length = self.calc_output_length(torch.tensor(feat_in))
130
+ self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
131
+ self.conv = torch.nn.Sequential(*layers)
132
+
133
+ def calc_output_length(self, lengths: Tensor) -> Tensor:
134
+ """
135
+ Calculates the output length after applying the subsampling.
136
+ """
137
+ lengths = lengths.to(torch.float)
138
+ add_pad = 2 * self._padding - self._kernel_size
139
+ for _ in range(self._sampling_num):
140
+ lengths = torch.div(lengths + add_pad, self._stride) + 1.0
141
+ lengths = torch.floor(lengths)
142
+ return lengths.to(dtype=torch.int)
143
+
144
+ def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
145
+ x = self.conv(x.unsqueeze(1))
146
+ b, _, t, _ = x.size()
147
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
148
+ return x, self.calc_output_length(lengths)
149
+
150
+
151
+ class MultiHeadAttention(nn.Module, ABC):
152
+ """
153
+ Base class of Multi-Head Attention Mechanisms.
154
+ """
155
+
156
+ def __init__(self, n_head: int, n_feat: int, flash_attn=False):
157
+ super().__init__()
158
+ assert n_feat % n_head == 0
159
+ self.d_k = n_feat // n_head
160
+ self.h = n_head
161
+ self.linear_q = nn.Linear(n_feat, n_feat)
162
+ self.linear_k = nn.Linear(n_feat, n_feat)
163
+ self.linear_v = nn.Linear(n_feat, n_feat)
164
+ self.linear_out = nn.Linear(n_feat, n_feat)
165
+ self.flash_attn = flash_attn
166
+ if self.flash_attn and not IMPORT_FLASH:
167
+ raise RuntimeError(
168
+ f"flash_attn_func was imported with err {IMPORT_FLASH_ERR}. "
169
+ "Please install flash_attn or use --no_flash flag. "
170
+ "If you have already done this, "
171
+ "--force-reinstall flag might be useful"
172
+ )
173
+
174
+ def forward_qkv(
175
+ self, query: Tensor, key: Tensor, value: Tensor
176
+ ) -> Tuple[Tensor, Tensor, Tensor]:
177
+ """
178
+ Projects the inputs into queries, keys, and values for multi-head attention.
179
+ """
180
+ b = query.size(0)
181
+ q = self.linear_q(query).view(b, -1, self.h, self.d_k)
182
+ k = self.linear_k(key).view(b, -1, self.h, self.d_k)
183
+ v = self.linear_v(value).view(b, -1, self.h, self.d_k)
184
+ if self.flash_attn:
185
+ return q, k, v
186
+ return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
187
+
188
+ def forward_attention(
189
+ self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
190
+ ) -> Tensor:
191
+ """
192
+ Computes the scaled dot-product attention given the projected values and scores.
193
+ """
194
+ b = value.size(0)
195
+ if mask is not None:
196
+ mask = mask.unsqueeze(1)
197
+ scores = scores.masked_fill(mask, -10000.0)
198
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
199
+ else:
200
+ attn = torch.softmax(scores, dim=-1)
201
+ x = torch.matmul(attn, value)
202
+ x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
203
+ return self.linear_out(x)
204
+
205
+
206
+ class RelPositionMultiHeadAttention(MultiHeadAttention):
207
+ """
208
+ Relative Position Multi-Head Attention module.
209
+ """
210
+
211
+ def __init__(self, n_head: int, n_feat: int):
212
+ super().__init__(n_head, n_feat)
213
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
214
+ self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
215
+ self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
216
+
217
+ def rel_shift(self, x: Tensor) -> Tensor:
218
+ b, h, qlen, pos_len = x.size()
219
+ x = torch.nn.functional.pad(x, pad=(1, 0))
220
+ x = x.view(b, h, -1, qlen)
221
+ return x[:, :, 1:].view(b, h, qlen, pos_len)
222
+
223
+ def forward(
224
+ self,
225
+ query: Tensor,
226
+ key: Tensor,
227
+ value: Tensor,
228
+ pos_emb: Tensor,
229
+ mask: Optional[Tensor] = None,
230
+ ) -> Tensor:
231
+ q, k, v = self.forward_qkv(query, key, value)
232
+ q = q.transpose(1, 2)
233
+ p = self.linear_pos(pos_emb)
234
+ p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
235
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
236
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
237
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
238
+ matrix_bd = self.rel_shift(matrix_bd)
239
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
240
+ matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
241
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
242
+ return self.forward_attention(v, scores, mask)
243
+
244
+
245
+ class RotaryPositionMultiHeadAttention(MultiHeadAttention):
246
+ """
247
+ Rotary Position Multi-Head Attention module.
248
+ """
249
+
250
+ def forward(
251
+ self,
252
+ query: Tensor,
253
+ key: Tensor,
254
+ value: Tensor,
255
+ pos_emb: List[Tensor],
256
+ mask: Optional[Tensor] = None,
257
+ ) -> Tensor:
258
+ b, t, _ = value.size()
259
+ query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
260
+ key = key.transpose(0, 1).view(t, b, self.h, self.d_k)
261
+ value = value.transpose(0, 1).view(t, b, self.h, self.d_k)
262
+
263
+ cos, sin = pos_emb
264
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
265
+
266
+ q, k, v = self.forward_qkv(
267
+ query.view(t, b, self.h * self.d_k).transpose(0, 1),
268
+ key.view(t, b, self.h * self.d_k).transpose(0, 1),
269
+ value.view(t, b, self.h * self.d_k).transpose(0, 1),
270
+ )
271
+
272
+ # if not self.flash_attn:
273
+ scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
274
+ out = self.forward_attention(v, scores, mask)
275
+ # else:
276
+ # if mask is None:
277
+ # scores = flash_attn_func(q, k, v)
278
+ # else:
279
+ # scores = apply_masked_flash_attn(q, k, v, mask, self.h, self.d_k)
280
+
281
+ # scores = scores.view(b, -1, self.h * self.d_k)
282
+ # out = self.linear_out(scores)
283
+
284
+ return out
285
+
286
+
287
+ class PositionalEncoding(nn.Module, ABC):
288
+ """
289
+ Base class of Positional Encodings.
290
+ """
291
+
292
+ def __init__(self, dim: int, base: int):
293
+ super().__init__()
294
+ self.dim = dim
295
+ self.base = base
296
+
297
+ @abstractmethod
298
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
299
+ pass
300
+
301
+ def extend_pe(self, length: int, device: torch.device):
302
+ """
303
+ Extends the positional encoding buffer to process longer sequences.
304
+ """
305
+ pe = self.create_pe(length, device)
306
+ if pe is None:
307
+ return
308
+ if hasattr(self, "pe"):
309
+ self.pe = pe
310
+ else:
311
+ self.register_buffer("pe", pe, persistent=False)
312
+
313
+
314
+ class RelPositionalEmbedding(PositionalEncoding):
315
+ """
316
+ Relative Positional Embedding module.
317
+ """
318
+
319
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
320
+ """
321
+ Creates the relative positional encoding matrix.
322
+ """
323
+ if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
324
+ return None
325
+ positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
326
+ pos_length = positions.size(0)
327
+ pe = torch.zeros(pos_length, self.dim, device=positions.device)
328
+ div_term = torch.exp(
329
+ torch.arange(0, self.dim, 2, device=pe.device)
330
+ * -(math.log(10000.0) / self.dim)
331
+ )
332
+ pe[:, 0::2] = torch.sin(positions * div_term)
333
+ pe[:, 1::2] = torch.cos(positions * div_term)
334
+ return pe.unsqueeze(0)
335
+
336
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
337
+ input_len = x.size(1)
338
+ center_pos = self.pe.size(1) // 2 + 1
339
+ start_pos = center_pos - input_len
340
+ end_pos = center_pos + input_len - 1
341
+ return x, self.pe[:, start_pos:end_pos]
342
+
343
+
344
+ class RotaryPositionalEmbedding(PositionalEncoding):
345
+ """
346
+ Rotary Positional Embedding module.
347
+ """
348
+
349
+ def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
350
+ """
351
+ Creates or extends the rotary positional encoding matrix.
352
+ """
353
+ if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
354
+ return None
355
+ positions = torch.arange(0, length, dtype=torch.float32, device=device)
356
+ inv_freq = 1.0 / (
357
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
358
+ )
359
+ t = torch.arange(length, device=positions.device).type_as(inv_freq)
360
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
361
+ emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
362
+ return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])
363
+
364
+ def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
365
+ cos_emb = self.pe[0 : x.shape[1]]
366
+ half_pe = self.pe.shape[0] // 2
367
+ sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
368
+ return x, [cos_emb, sin_emb]
369
+
370
+
371
+ class ConformerConvolution(nn.Module):
372
+ """
373
+ Conformer Convolution module.
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ d_model: int,
379
+ kernel_size: int,
380
+ ):
381
+ super().__init__()
382
+ assert (kernel_size - 1) % 2 == 0
383
+ self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
384
+ self.depthwise_conv = nn.Conv1d(
385
+ in_channels=d_model,
386
+ out_channels=d_model,
387
+ kernel_size=kernel_size,
388
+ padding=(kernel_size - 1) // 2,
389
+ groups=d_model,
390
+ bias=True,
391
+ )
392
+ self.batch_norm = nn.BatchNorm1d(d_model)
393
+ self.activation = nn.SiLU()
394
+ self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
395
+
396
+ def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
397
+ x = x.transpose(1, 2)
398
+ x = self.pointwise_conv1(x)
399
+ x = nn.functional.glu(x, dim=1)
400
+ if pad_mask is not None:
401
+ x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
402
+ x = self.depthwise_conv(x)
403
+ x = self.batch_norm(x)
404
+ x = self.activation(x)
405
+ x = self.pointwise_conv2(x)
406
+ return x.transpose(1, 2)
407
+
408
+
409
+ class ConformerFeedForward(nn.Module):
410
+ """
411
+ Conformer Feed Forward module.
412
+ """
413
+
414
+ def __init__(self, d_model: int, d_ff: int, use_bias=True):
415
+ super().__init__()
416
+ self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
417
+ self.activation = nn.SiLU()
418
+ self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)
419
+
420
+ def forward(self, x: Tensor) -> Tensor:
421
+ return self.linear2(self.activation(self.linear1(x)))
422
+
423
+
424
+ class ConformerLayer(nn.Module):
425
+ """
426
+ Conformer Layer module.
427
+ This module combines several submodules including feed forward networks,
428
+ depthwise separable convolution, and multi-head self-attention
429
+ to form a single Conformer block.
430
+ """
431
+
432
+ def __init__(
433
+ self,
434
+ d_model: int,
435
+ d_ff: int,
436
+ self_attention_model: str,
437
+ n_heads: int = 16,
438
+ conv_kernel_size: int = 31,
439
+ flash_attn: bool = False,
440
+ ):
441
+ super().__init__()
442
+ self.fc_factor = 0.5
443
+ self.norm_feed_forward1 = nn.LayerNorm(d_model)
444
+ self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
445
+ self.norm_conv = nn.LayerNorm(d_model)
446
+ self.conv = ConformerConvolution(
447
+ d_model=d_model,
448
+ kernel_size=conv_kernel_size,
449
+ )
450
+ self.norm_self_att = nn.LayerNorm(d_model)
451
+ if self_attention_model == "rotary":
452
+ self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
453
+ n_head=n_heads,
454
+ n_feat=d_model,
455
+ flash_attn=flash_attn,
456
+ )
457
+ else:
458
+ assert not flash_attn, "Not supported flash_attn for rel_pos"
459
+ self.self_attn = RelPositionMultiHeadAttention(
460
+ n_head=n_heads,
461
+ n_feat=d_model,
462
+ )
463
+ self.norm_feed_forward2 = nn.LayerNorm(d_model)
464
+ self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
465
+ self.norm_out = nn.LayerNorm(d_model)
466
+
467
+ def forward(
468
+ self,
469
+ x: Tensor,
470
+ pos_emb: Union[Tensor, List[Tensor]],
471
+ att_mask: Optional[Tensor] = None,
472
+ pad_mask: Optional[Tensor] = None,
473
+ ) -> Tensor:
474
+ residual = x
475
+ x = self.norm_feed_forward1(x)
476
+ x = self.feed_forward1(x)
477
+ residual = residual + x * self.fc_factor
478
+
479
+ x = self.norm_self_att(residual)
480
+ x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
481
+ residual = residual + x
482
+
483
+ x = self.norm_conv(residual)
484
+ x = self.conv(x, pad_mask=pad_mask)
485
+ residual = residual + x
486
+
487
+ x = self.norm_feed_forward2(residual)
488
+ x = self.feed_forward2(x)
489
+ residual = residual + x * self.fc_factor
490
+
491
+ x = self.norm_out(residual)
492
+ return x
493
+
494
+
495
+ class ConformerEncoder(nn.Module):
496
+ """
497
+ Conformer Encoder module.
498
+ This module encapsulates the entire Conformer encoder architecture,
499
+ consisting of a StridingSubsampling layer, positional embeddings, and
500
+ a stack of Conformer Layers.
501
+ It serves as the main component responsible for processing speech features.
502
+ """
503
+
504
+ def __init__(
505
+ self,
506
+ feat_in: int = 64,
507
+ n_layers: int = 16,
508
+ d_model: int = 768,
509
+ subsampling_factor: int = 4,
510
+ ff_expansion_factor: int = 4,
511
+ self_attention_model: str = "rotary",
512
+ n_heads: int = 16,
513
+ pos_emb_max_len: int = 5000,
514
+ conv_kernel_size: int = 31,
515
+ flash_attn: bool = False,
516
+ ):
517
+ super().__init__()
518
+ self.feat_in = feat_in
519
+ assert self_attention_model in [
520
+ "rotary",
521
+ "rel_pos",
522
+ ], f"Not supported attn = {self_attention_model}"
523
+
524
+ self.pre_encode = StridingSubsampling(
525
+ subsampling_factor=subsampling_factor,
526
+ feat_in=feat_in,
527
+ feat_out=d_model,
528
+ conv_channels=d_model,
529
+ )
530
+
531
+ if self_attention_model == "rotary":
532
+ self.pos_enc: nn.Module = RotaryPositionalEmbedding(
533
+ d_model // n_heads, pos_emb_max_len
534
+ )
535
+ else:
536
+ self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)
537
+
538
+ self.layers = nn.ModuleList()
539
+ for _ in range(n_layers):
540
+ layer = ConformerLayer(
541
+ d_model=d_model,
542
+ d_ff=d_model * ff_expansion_factor,
543
+ self_attention_model=self_attention_model,
544
+ n_heads=n_heads,
545
+ conv_kernel_size=conv_kernel_size,
546
+ flash_attn=flash_attn,
547
+ )
548
+ self.layers.append(layer)
549
+
550
+ self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device)
551
+
552
+ def input_example(
553
+ self,
554
+ batch_size: int = 1,
555
+ seqlen: int = 200,
556
+ ):
557
+ device = next(self.parameters()).device
558
+ features = torch.zeros(batch_size, self.feat_in, seqlen)
559
+ feature_lengths = torch.full([batch_size], features.shape[-1])
560
+ return features.float().to(device), feature_lengths.to(device)
561
+
562
+ def input_names(self):
563
+ return ["audio_signal", "length"]
564
+
565
+ def output_names(self):
566
+ return ["encoded", "encoded_len"]
567
+
568
+ def dynamic_axes(self):
569
+ return {
570
+ "audio_signal": {0: "batch_size", 2: "seq_len"},
571
+ "length": {0: "batch_size"},
572
+ "encoded": {0: "batch_size", 1: "seq_len"},
573
+ "encoded_len": {0: "batch_size"},
574
+ }
575
+
576
+ def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
577
+ audio_signal, length = self.pre_encode(
578
+ x=audio_signal.transpose(1, 2), lengths=length
579
+ )
580
+
581
+ max_len = audio_signal.size(1)
582
+ audio_signal, pos_emb = self.pos_enc(x=audio_signal)
583
+
584
+ pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
585
+ length.size(0), -1
586
+ ) < length.unsqueeze(-1)
587
+
588
+ att_mask = None
589
+ if audio_signal.shape[0] > 1:
590
+ att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
591
+ att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
592
+ att_mask = ~att_mask
593
+
594
+ pad_mask = ~pad_mask
595
+
596
+ for layer in self.layers:
597
+ audio_signal = layer(
598
+ x=audio_signal,
599
+ pos_emb=pos_emb,
600
+ att_mask=att_mask,
601
+ pad_mask=pad_mask,
602
+ )
603
+
604
+ return audio_signal.transpose(1, 2), length
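
A minimal shape-check sketch for the `ConformerEncoder` above, assuming `encoder.py` is importable from the working directory; a reduced `n_layers` is used only to keep it fast, and the output length follows the 4x striding subsampling:

```python
import torch

from encoder import ConformerEncoder  # assumes encoder.py is on the path

# Two layers are enough to check shapes; the released model uses n_layers=16.
conformer = ConformerEncoder(feat_in=64, n_layers=2, d_model=768, n_heads=16)
features = torch.randn(2, 64, 400)   # (batch, mel bins, frames)
lengths = torch.tensor([400, 320])

with torch.no_grad():
    encoded, encoded_len = conformer(features, lengths)

print(encoded.shape)  # torch.Size([2, 768, 100]) -- frames reduced 4x by subsampling
print(encoded_len)    # tensor([100, 80], dtype=torch.int32)
```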
gigaam_transformers.py ADDED
@@ -0,0 +1,299 @@
1
+ import os
2
+ from typing import Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchaudio
8
+ from .encoder import ConformerEncoder
9
+ from torch import Tensor
10
+ from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.feature_extraction_sequence_utils import \
13
+ SequenceFeatureExtractor
14
+ from transformers.feature_extraction_utils import BatchFeature
15
+ from transformers.modeling_outputs import CausalLMOutput
16
+ from transformers.modeling_utils import PreTrainedModel
17
+
18
+
19
+ class GigaAMCTC(nn.Module):
20
+ """
21
+ GigaAM-CTC model
22
+ """
23
+
24
+ def __init__(self, config_encoder, config_head):
25
+ super().__init__()
26
+ self.encoder = ConformerEncoder(**config_encoder)
27
+ self.head = CTCHead(**config_head)
28
+
29
+ def forward(self, input_features: Tensor, input_lengths: Tensor) -> Tensor:
30
+ encoded, encoded_lengths = self.encoder(input_features, input_lengths)
31
+ logits = self.head(encoded)
32
+ return logits, encoded_lengths
33
+
34
+
35
+ class CTCHead(nn.Module):
36
+ """
37
+ CTC Head module for Connectionist Temporal Classification.
38
+ """
39
+
40
+ def __init__(self, feat_in: int, num_classes: int):
41
+ super().__init__()
42
+ self.decoder_layers = nn.Sequential(
43
+ nn.Conv1d(feat_in, num_classes, kernel_size=1)
44
+ )
45
+
46
+ def forward(self, encoder_output: Tensor) -> Tensor:
47
+ # B x C x T
48
+ return self.decoder_layers(encoder_output)
49
+
50
+
51
+ class GigaAMFeatureExtractor(SequenceFeatureExtractor):
52
+ """
53
+ Feature extractor for GigaAM.
54
+ """
55
+ model_input_names = ["input_features"]
56
+
57
+ def __init__(
58
+ self,
59
+ feature_size=64,
60
+ sampling_rate=16000,
61
+ padding_value=0.0,
62
+ chunk_length=30.0,
63
+ **kwargs,
64
+ ):
65
+ super().__init__(
66
+ feature_size=feature_size,
67
+ sampling_rate=sampling_rate,
68
+ padding_value=padding_value,
69
+ chunk_length=chunk_length,
70
+ **kwargs,
71
+ )
72
+ self.hop_length = sampling_rate // 100
73
+ self.n_samples = chunk_length * sampling_rate
74
+ self.featurizer = torchaudio.transforms.MelSpectrogram(
75
+ sample_rate=sampling_rate,
76
+ n_fft=sampling_rate // 40,
77
+ win_length=sampling_rate // 40,
78
+ hop_length=self.hop_length,
79
+ n_mels=feature_size,
80
+ )
81
+
82
+ def to_dict(self) -> Dict[str, Union[str, int, Dict]]:
83
+ dictionary = super().to_dict()
84
+
85
+ if "featurizer" in dictionary:
86
+ del dictionary["featurizer"]
87
+ dictionary["hop_length"] = self.hop_length
88
+ dictionary["n_samples"] = self.n_samples
89
+ return dictionary
90
+
91
+ def out_len(self, input_lengths: Tensor) -> Tensor:
92
+ """
93
+ Calculates the output length after the feature extraction process.
94
+ """
95
+ return input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
96
+
97
+ def __call__(
98
+ self,
99
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
100
+ sampling_rate: Optional[int] = None,
101
+ padding: str = "max_length",
102
+ **kwargs,
103
+ ):
104
+ is_batched_numpy = (
105
+ isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
106
+ )
107
+ if is_batched_numpy and len(raw_speech.shape) > 2:
108
+ raise ValueError(
109
+ f"Only mono-channel audio is supported for input to {self}"
110
+ )
111
+ is_batched = is_batched_numpy or (
112
+ isinstance(raw_speech, (list, tuple))
113
+ and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
114
+ )
115
+
116
+ if is_batched:
117
+ raw_speech = [
118
+ np.asarray([speech], dtype=np.float32).T for speech in raw_speech
119
+ ]
120
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
121
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
122
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
123
+ np.float64
124
+ ):
125
+ raw_speech = raw_speech.astype(np.float32)
126
+
127
+ # always return batch
128
+ if not is_batched:
129
+ raw_speech = [np.asarray([raw_speech]).T]
130
+
131
+ input_lengths = torch.tensor([len(speech) for speech in raw_speech])
132
+
133
+ batched_speech = BatchFeature({"input_features": raw_speech})
134
+
135
+ padded_inputs = self.pad(
136
+ batched_speech,
137
+ padding=padding,
138
+ max_length=self.n_samples,
139
+ truncation=False,
140
+ return_tensors="pt",
141
+ )
142
+
143
+ input_features = padded_inputs["input_features"].transpose(1, 2)
144
+ input_features = self.featurizer(input_features).squeeze(1)
145
+ input_features = torch.log(input_features.clamp_(1e-9, 1e9))
146
+ input_lengths = self.out_len(input_lengths)
147
+
148
+ return BatchFeature({"input_features": input_features, "input_lengths": input_lengths}, tensor_type="pt")
149
+
150
+
151
+ class GigaAMCTCTokenizer(Wav2Vec2CTCTokenizer):
152
+ """
153
+ Char tokenizer for GigaAM-CTC model.
154
+ """
155
+ def __init__(
156
+ self,
157
+ vocab_file,
158
+ unk_token="[BLANK]",
159
+ pad_token="[BLANK]",
160
+ bos_token=None,
161
+ eos_token=None,
162
+ word_delimiter_token=" ",
163
+ **kwargs,
164
+ ):
165
+ super().__init__(
166
+ vocab_file=vocab_file,
167
+ unk_token=unk_token,
168
+ pad_token=pad_token,
169
+ bos_token=bos_token,
170
+ eos_token=eos_token,
171
+ word_delimiter_token=word_delimiter_token,
172
+ **kwargs,
173
+ )
174
+
175
+
176
+ class GigaAMProcessor(Wav2Vec2Processor):
177
+ feature_extractor_class = "GigaAMFeatureExtractor"
178
+ tokenizer_class = "GigaAMCTCTokenizer"
179
+
180
+ def __init__(self, feature_extractor, tokenizer):
181
+ # super().__init__(feature_extractor, tokenizer)
182
+ self.feature_extractor = feature_extractor
183
+ self.tokenizer = tokenizer
184
+ self.current_processor = self.feature_extractor
185
+ self._in_target_context_manager = False
186
+
187
+ @classmethod
188
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
189
+ feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
190
+ tokenizer = GigaAMCTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
191
+
192
+ return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
193
+
194
+
195
+ class GigaAMProcessorWithLM(Wav2Vec2ProcessorWithLM):
196
+ feature_extractor_class = "GigaAMFeatureExtractor"
197
+ tokenizer_class = "GigaAMCTCTokenizer"
198
+
199
+ def __init__(self, feature_extractor, tokenizer, decoder, **kwargs):
200
+ from pyctcdecode import BeamSearchDecoderCTC
201
+
202
+ self.feature_extractor = feature_extractor
203
+ self.tokenizer = tokenizer
204
+
205
+ # super().__init__(feature_extractor, tokenizer, decoder, **kwargs)
206
+ if not isinstance(decoder, BeamSearchDecoderCTC):
207
+ raise TypeError(
208
+ f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__} but is {type(decoder)}"
209
+ )
210
+ self.decoder = decoder
211
+
212
+ self.current_processor = self.feature_extractor
213
+ self._in_target_context_manager = False
214
+
215
+ @classmethod
216
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
217
+ from pyctcdecode import BeamSearchDecoderCTC
218
+ feature_extractor = GigaAMFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
219
+ tokenizer = GigaAMCTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
220
+
221
+ if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path):
222
+ unigram_encoding = kwargs.get("unigram_encoding", "utf-8")
223
+ decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path, unigram_encoding)
224
+ else:
225
+ # BeamSearchDecoderCTC has no auto class
226
+ kwargs.pop("_from_auto", None)
227
+ # snapshot_download has no `trust_remote_code` flag
228
+ kwargs.pop("trust_remote_code", None)
229
+
230
+ # make sure that only relevant filenames are downloaded
231
+ language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
232
+ alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
233
+ allow_patterns = [language_model_filenames, alphabet_filename]
234
+
235
+ decoder = BeamSearchDecoderCTC.load_from_hf_hub(
236
+ pretrained_model_name_or_path, allow_patterns=allow_patterns, **kwargs
237
+ )
238
+
239
+ # set language model attributes
240
+ for attribute in ["alpha", "beta", "unk_score_offset", "score_boundary"]:
241
+ value = kwargs.pop(attribute, None)
242
+
243
+ if value is not None:
244
+ cls._set_language_model_attribute(decoder, attribute, value)
245
+
246
+ # make sure that decoder's alphabet and tokenizer's vocab match in content
247
+ missing_decoder_tokens = cls.get_missing_alphabet_tokens(decoder, tokenizer)
248
+ if len(missing_decoder_tokens) > 0:
249
+ raise ValueError(
250
+ f"The tokens {missing_decoder_tokens} are defined in the tokenizer's "
251
+ "vocabulary, but not in the decoder's alphabet. "
252
+ f"Make sure to include {missing_decoder_tokens} in the decoder's alphabet."
253
+ )
254
+
255
+ return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
256
+
257
+
258
+ class GigaAMConfig(PretrainedConfig):
259
+ def __init__(self, **kwargs):
260
+ super().__init__(**kwargs)
261
+
262
+
263
+ class GigaAMCTCHF(PreTrainedModel):
264
+ """
265
+ GigaAM-CTC model for transformers
266
+ """
267
+ config_class = GigaAMConfig
268
+ base_model_prefix = "gigaamctc"
269
+ main_input_name = "input_features"
270
+
271
+ def __init__(self, config: GigaAMConfig):
272
+ super().__init__(config)
273
+ self.model = GigaAMCTC(config.encoder, config.head)
274
+
275
+ def forward(self, input_features, input_lengths, labels=None, **kwargs):
276
+
277
+ # B x C x T
278
+ logits, encoded_lengths = self.model(input_features, input_lengths)
279
+ # B x C x T -> B x T x C -> T x B x C
280
+ log_probs = torch.log_softmax(
281
+ logits.transpose(1, 2), dim=-1, dtype=torch.float32
282
+ ).transpose(0, 1)
283
+
284
+ loss = None
285
+ if labels is not None:
286
+ labels_mask = labels >= 0
287
+ target_lengths = labels_mask.sum(-1)
288
+ flattened_targets = labels.masked_select(labels_mask)
289
+
290
+ loss = nn.functional.ctc_loss(
291
+ log_probs,
292
+ flattened_targets,
293
+ encoded_lengths,
294
+ target_lengths,
295
+ blank=self.config.blank_id,
296
+ zero_infinity=True,
297
+ )
298
+
299
+ return CausalLMOutput(loss=loss, logits=logits.transpose(1, 2))
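
A hedged sketch of what `GigaAMFeatureExtractor` above produces: log-mel features padded to the 30 s chunk, plus the unpadded frame counts. The shapes in the comments are inferred from the code and `preprocessor_config.json`, not measured:

```python
import torch
from transformers import AutoFeatureExtractor

# AutoFeatureExtractor resolves to GigaAMFeatureExtractor via preprocessor_config.json
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "waveletdeboshir/gigaam-ctc", trust_remote_code=True
)

wav = torch.zeros(16000)  # 1 second of silence at 16 kHz
features = feature_extractor(wav.numpy(), sampling_rate=16000)

print(features["input_features"].shape)  # expected torch.Size([1, 64, 3001]): padded to the 30 s chunk
print(features["input_lengths"])         # expected tensor([101]): 16000 // hop_length(160) + 1 valid frames
```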
language_model/attrs.json ADDED
@@ -0,0 +1 @@
+ {"alpha": 0.4612549346468768, "beta": 0.3271780420615455, "unk_score_offset": -10.0, "score_boundary": true}
language_model/ru_3gram.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c2a8143d33c234f881f0f7072bd0ed12c5dd2d697328410c317ed50892d70ee
+ size 2142431571
language_model/unigrams.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6c5edc3b6134e2a64b04696a1c00abccb660a86776b8016c43108f80f35d2fae
+ size 29321792
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7bd7e277cb601bf55036251be654dd374c455313edefaa69d32e6ec1f9c7161
+ size 465343856
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "chunk_length": 30,
+   "feature_extractor_type": "GigaAMFeatureExtractor",
+   "feature_extractor_class": "GigaAMFeatureExtractor",
+   "feature_size": 64,
+   "hop_length": 160,
+   "n_samples": 480000,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": true,
+   "sampling_rate": 16000,
+   "auto_map": {
+     "AutoFeatureExtractor": "gigaam_transformers.GigaAMFeatureExtractor",
+     "AutoProcessor": "gigaam_transformers.GigaAMProcessorWithLM"
+   },
+   "processor_class": "GigaAMProcessorWithLM",
+   "model_type": "gigaam-ctc"
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "pad_token": "",
+   "unk_token": ""
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "added_tokens_decoder": {
+     "33": {
+       "content": "",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": true,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "do_lower_case": false,
+   "eos_token": null,
+   "model_max_length": 1000,
+   "pad_token": "",
+   "replace_word_delimiter_char": " ",
+   "target_lang": null,
+   "tokenizer_class": "GigaAMCTCTokenizer",
+   "unk_token": "",
+   "word_delimiter_token": " ",
+   "auto_map": {
+     "AutoTokenizer": ["gigaam_transformers.GigaAMCTCTokenizer", null]
+   }
+ }
vocab.json ADDED
@@ -0,0 +1,36 @@
+ {
+   " ": 0,
+   "": 33,
+   "а": 1,
+   "б": 2,
+   "в": 3,
+   "г": 4,
+   "д": 5,
+   "е": 6,
+   "ж": 7,
+   "з": 8,
+   "и": 9,
+   "й": 10,
+   "к": 11,
+   "л": 12,
+   "м": 13,
+   "н": 14,
+   "о": 15,
+   "п": 16,
+   "р": 17,
+   "с": 18,
+   "т": 19,
+   "у": 20,
+   "ф": 21,
+   "х": 22,
+   "ц": 23,
+   "ч": 24,
+   "ш": 25,
+   "щ": 26,
+   "ъ": 27,
+   "ы": 28,
+   "ь": 29,
+   "э": 30,
+   "ю": 31,
+   "я": 32
+ }