File size: 20,903 Bytes
7f5fb3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "203f9969-3d4c-466b-a7ef-c153f9e2d5c6",
   "metadata": {},
   "source": [
    "# Without Voice Cloning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "074a327e-1cf5-4cf3-b358-49b48589af9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
      "The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. \n",
      "The class this function is called from is 'PreTrainedTokenizerFast'.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoProcessor, CsmForConditionalGeneration\n",
    "from tokenizers.processors import TemplateProcessing\n",
    "import soundfile as sf\n",
    "\n",
    "model_id = \"Marvis-AI/marvis-tts-0.25b-expressive-preview-transformers\"\n",
    "device = \"cuda\"if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# load the model and the processor\n",
    "processor = AutoProcessor.from_pretrained(model_id)\n",
    "model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "875d88e0-d6cb-4103-b1b3-6601588e2c7e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[    1,    75,    32,    77, 19556,   429,  2828,  3966,    30,     2]],\n",
       "       device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# prepare the inputs\n",
    "text = \"[0]Hello from Marvis.\" # `[0]` for speaker id 0\n",
    "inputs = processor(text, add_special_tokens=True, return_tensors=\"pt\").to(device)\n",
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eb33ddab-44a7-4a1d-8514-7504f311e630",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs.pop(\"token_type_ids\")\n",
    "# infer the model\n",
    "audio = model.generate(**inputs, output_audio=True)\n",
    "sf.write(\"example_without_context.wav\", audio[0].cpu(), samplerate=24_000, subtype=\"PCM_16\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a14968a-8361-4b22-98de-c068055420e6",
   "metadata": {},
   "source": [
    "# With Voice Cloning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33d6012e-3f6e-4e9b-93ad-b2085936a8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from datasets import load_dataset, Audio\n",
    "from transformers import AutoTokenizer, AutoProcessor, CsmForConditionalGeneration\n",
    "from tokenizers.processors import TemplateProcessing\n",
    "import soundfile as sf\n",
    "\n",
    "\n",
    "# prepare the inputs\n",
    "ds = load_dataset(\"hf-internal-testing/dailytalk-dummy\", split=\"train\")\n",
    "# ensure the audio is 24kHz\n",
    "ds = ds.cast_column(\"audio\", Audio(sampling_rate=24000))\n",
    "conversation = []\n",
    "\n",
    "# 1. context\n",
    "for text, audio, speaker_id in zip(ds[:4][\"text\"], ds[:4][\"audio\"], ds[:4][\"speaker_id\"]):\n",
    "    conversation.append(\n",
    "        {\n",
    "            \"role\": f\"{speaker_id}\",\n",
    "            \"content\": [{\"type\": \"text\", \"text\": text}, {\"type\": \"audio\", \"path\": audio[\"array\"]}],\n",
    "        }\n",
    "    )\n",
    "\n",
    "# 2. text prompt\n",
    "conversation.append({\"role\": f\"{ds[4]['speaker_id']}\", \"content\": [{\"type\": \"text\", \"text\": ds[4][\"text\"]}]})\n",
    "\n",
    "inputs = processor.apply_chat_template(\n",
    "    conversation,\n",
    "    tokenize=True,\n",
    "    return_dict=True,\n",
    ").to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1bada9fa-3da5-4105-9da5-3622401fbe73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[    1,    75,    33,    77,  1780,   359,   346,  1891,   335,    47,\n",
       "             2,    44,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108, 25492,    79,    85,   395,   108,    46,     1,\n",
       "            75,    32,    77,    57,  5248, 23154,   578,   957,  6050,    30,\n",
       "             2,    44,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,    49,\n",
       "         11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,\n",
       "           108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,\n",
       "           108, 21198,   108, 25492,    79,    85,   395,   108,    46,     1,\n",
       "            75,    33,    77,    69,  6799, 15874,  1812,  6050,    47,     2,\n",
       "            44,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108, 25492,    79,    85,   395,\n",
       "           108,    46,     1,    75,    32,    77,    57,  5248,  1625,   253,\n",
       "         11316,  6050,    28,   588,   338,   339,  1326,   982,  5344,  1147,\n",
       "          1083,  2478,    30,     2,    44,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108,    49, 11911,  8772,   108, 21198,   108,    49, 11911,\n",
       "          8772,   108, 21198,   108,    49, 11911,  8772,   108, 21198,   108,\n",
       "            49, 11911,  8772,   108, 21198,   108,    49, 11911,  8772,   108,\n",
       "         21198,   108, 25492,    79,    85,   395,   108,    46,     1,    75,\n",
       "            33,    77,  2020,  1083,  2478,   416,   346,  5344,    47,     2]],\n",
       "       device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
       "       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],\n",
       "       device='cuda:0'), 'input_values': tensor([[[ 6.1035e-05,  6.1035e-05, -3.0518e-05,  ..., -6.1035e-05,\n",
       "          -6.1035e-05, -3.0518e-05]]], device='cuda:0'), 'input_values_cutoffs': tensor([[ 37181,  83657, 128274, 215183]], device='cuda:0')}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "38687e7b-77ac-4543-b277-af8f8e3d310c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# infer the model\n",
    "# inputs.pop(\"token_type_ids\")\n",
    "audio = model.generate(**inputs, output_audio=True)\n",
    "sf.write(\"example_with_context.wav\", audio[0].cpu(), samplerate=24_000, subtype=\"PCM_16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7480795e-0a35-42cc-8f1e-3c1f5dffc07e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"<|im_start|>[1]What are you working on?<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[0]I'm figuring out my budget.<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[1]Umm…. What budget?<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[0]I'm making a shopping budget, so that I don't spend too much money.<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[1]How much money can you spend?<|im_end|>\""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processor.tokenizer.decode(inputs[\"input_ids\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceb62c4d-3bfd-4805-886d-83eaac0cdf65",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}