prince-canuma commited on
Commit
7f5fb3e
·
verified ·
1 Parent(s): f2729be

Upload marvis_inference.ipynb

Browse files
Files changed (1) hide show
  1. marvis_inference.ipynb +357 -0
marvis_inference.ipynb ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "203f9969-3d4c-466b-a7ef-c153f9e2d5c6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Without Voice Cloning"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "074a327e-1cf5-4cf3-b358-49b48589af9a",
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stderr",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
22
+ "The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. \n",
23
+ "The class this function is called from is 'PreTrainedTokenizerFast'.\n"
24
+ ]
25
+ }
26
+ ],
27
+ "source": [
28
+ "import torch\n",
29
+ "from transformers import AutoTokenizer, AutoProcessor, CsmForConditionalGeneration\n",
30
+ "from tokenizers.processors import TemplateProcessing\n",
31
+ "import soundfile as sf\n",
32
+ "\n",
33
+ "model_id = \"Marvis-AI/marvis-tts-0.25b-expressive-preview-transformers\"\n",
34
+ "device = \"cuda\"if torch.cuda.is_available() else \"cpu\"\n",
35
+ "\n",
36
+ "# load the model and the processor\n",
37
+ "processor = AutoProcessor.from_pretrained(model_id)\n",
38
+ "model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 5,
44
+ "id": "875d88e0-d6cb-4103-b1b3-6601588e2c7e",
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "data": {
49
+ "text/plain": [
50
+ "{'input_ids': tensor([[ 1, 75, 32, 77, 19556, 429, 2828, 3966, 30, 2]],\n",
51
+ " device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
52
+ ]
53
+ },
54
+ "execution_count": 5,
55
+ "metadata": {},
56
+ "output_type": "execute_result"
57
+ }
58
+ ],
59
+ "source": [
60
+ "# prepare the inputs\n",
61
+ "text = \"[0]Hello from Marvis.\" # `[0]` for speaker id 0\n",
62
+ "inputs = processor(text, add_special_tokens=True, return_tensors=\"pt\").to(device)\n",
63
+ "inputs"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 7,
69
+ "id": "eb33ddab-44a7-4a1d-8514-7504f311e630",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "inputs.pop(\"token_type_ids\")\n",
74
+ "# infer the model\n",
75
+ "audio = model.generate(**inputs, output_audio=True)\n",
76
+ "sf.write(\"example_without_context.wav\", audio[0].cpu(), samplerate=24_000, subtype=\"PCM_16\")"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "id": "2a14968a-8361-4b22-98de-c068055420e6",
82
+ "metadata": {},
83
+ "source": [
84
+ "# With Voice Cloning"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 2,
90
+ "id": "33d6012e-3f6e-4e9b-93ad-b2085936a8d3",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "import torch\n",
95
+ "from datasets import load_dataset, Audio\n",
96
+ "from transformers import AutoTokenizer, AutoProcessor, CsmForConditionalGeneration\n",
97
+ "from tokenizers.processors import TemplateProcessing\n",
98
+ "import soundfile as sf\n",
99
+ "\n",
100
+ "\n",
101
+ "# prepare the inputs\n",
102
+ "ds = load_dataset(\"hf-internal-testing/dailytalk-dummy\", split=\"train\")\n",
103
+ "# ensure the audio is 24kHz\n",
104
+ "ds = ds.cast_column(\"audio\", Audio(sampling_rate=24000))\n",
105
+ "conversation = []\n",
106
+ "\n",
107
+ "# 1. context\n",
108
+ "for text, audio, speaker_id in zip(ds[:4][\"text\"], ds[:4][\"audio\"], ds[:4][\"speaker_id\"]):\n",
109
+ " conversation.append(\n",
110
+ " {\n",
111
+ " \"role\": f\"{speaker_id}\",\n",
112
+ " \"content\": [{\"type\": \"text\", \"text\": text}, {\"type\": \"audio\", \"path\": audio[\"array\"]}],\n",
113
+ " }\n",
114
+ " )\n",
115
+ "\n",
116
+ "# 2. text prompt\n",
117
+ "conversation.append({\"role\": f\"{ds[4]['speaker_id']}\", \"content\": [{\"type\": \"text\", \"text\": ds[4][\"text\"]}]})\n",
118
+ "\n",
119
+ "inputs = processor.apply_chat_template(\n",
120
+ " conversation,\n",
121
+ " tokenize=True,\n",
122
+ " return_dict=True,\n",
123
+ ").to(device)\n"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 3,
129
+ "id": "1bada9fa-3da5-4105-9da5-3622401fbe73",
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "text/plain": [
135
+ "{'input_ids': tensor([[ 1, 75, 33, 77, 1780, 359, 346, 1891, 335, 47,\n",
136
+ " 2, 44, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
137
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
138
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
139
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
140
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
141
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
142
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
143
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
144
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
145
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
146
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
147
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
148
+ " 108, 21198, 108, 25492, 79, 85, 395, 108, 46, 1,\n",
149
+ " 75, 32, 77, 57, 5248, 23154, 578, 957, 6050, 30,\n",
150
+ " 2, 44, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
151
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
152
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
153
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
154
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
155
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
156
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
157
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
158
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
159
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
160
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
161
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
162
+ " 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49,\n",
163
+ " 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198,\n",
164
+ " 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772,\n",
165
+ " 108, 21198, 108, 25492, 79, 85, 395, 108, 46, 1,\n",
166
+ " 75, 33, 77, 69, 6799, 15874, 1812, 6050, 47, 2,\n",
167
+ " 44, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
168
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
169
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
170
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
171
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
172
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
173
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
174
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
175
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
176
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
177
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
178
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
179
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
180
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
181
+ " 49, 11911, 8772, 108, 21198, 108, 25492, 79, 85, 395,\n",
182
+ " 108, 46, 1, 75, 32, 77, 57, 5248, 1625, 253,\n",
183
+ " 11316, 6050, 28, 588, 338, 339, 1326, 982, 5344, 1147,\n",
184
+ " 1083, 2478, 30, 2, 44, 108, 49, 11911, 8772, 108,\n",
185
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
186
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
187
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
188
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
189
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
190
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
191
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
192
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
193
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
194
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
195
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
196
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
197
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
198
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
199
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
200
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
201
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
202
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
203
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
204
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
205
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
206
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
207
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
208
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
209
+ " 21198, 108, 49, 11911, 8772, 108, 21198, 108, 49, 11911,\n",
210
+ " 8772, 108, 21198, 108, 49, 11911, 8772, 108, 21198, 108,\n",
211
+ " 49, 11911, 8772, 108, 21198, 108, 49, 11911, 8772, 108,\n",
212
+ " 21198, 108, 25492, 79, 85, 395, 108, 46, 1, 75,\n",
213
+ " 33, 77, 2020, 1083, 2478, 416, 346, 5344, 47, 2]],\n",
214
+ " device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
215
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
216
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
217
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
218
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
219
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
220
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
221
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
222
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
223
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
224
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
225
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
226
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
227
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
228
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
229
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
230
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
231
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
232
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
233
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
234
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
235
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
236
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
237
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
238
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
239
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
240
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
241
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
242
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
243
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
244
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
245
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
246
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
247
+ " device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
248
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
249
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
250
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
251
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
252
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
253
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
254
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
255
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
256
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
257
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
258
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
259
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
260
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
261
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
262
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
263
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
264
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
265
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
266
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
267
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
268
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
269
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
270
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
271
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
272
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
273
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
274
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
275
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
276
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
277
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
278
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
279
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],\n",
280
+ " device='cuda:0'), 'input_values': tensor([[[ 6.1035e-05, 6.1035e-05, -3.0518e-05, ..., -6.1035e-05,\n",
281
+ " -6.1035e-05, -3.0518e-05]]], device='cuda:0'), 'input_values_cutoffs': tensor([[ 37181, 83657, 128274, 215183]], device='cuda:0')}"
282
+ ]
283
+ },
284
+ "execution_count": 3,
285
+ "metadata": {},
286
+ "output_type": "execute_result"
287
+ }
288
+ ],
289
+ "source": [
290
+ "inputs"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 5,
296
+ "id": "38687e7b-77ac-4543-b277-af8f8e3d310c",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "# infer the model\n",
301
+ "# inputs.pop(\"token_type_ids\")\n",
302
+ "audio = model.generate(**inputs, output_audio=True)\n",
303
+ "sf.write(\"example_with_context.wav\", audio[0].cpu(), samplerate=24_000, subtype=\"PCM_16\")"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 5,
309
+ "id": "7480795e-0a35-42cc-8f1e-3c1f5dffc07e",
310
+ "metadata": {},
311
+ "outputs": [
312
+ {
313
+ "data": {
314
+ "text/plain": [
315
+ "\"<|im_start|>[1]What are you working on?<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[0]I'm figuring out my budget.<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[1]Umm…. What budget?<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[0]I'm making a shopping budget, so that I don't spend too much money.<|im_end|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|><|im_start|>[1]How much money can you spend?<|im_end|>\""
316
+ ]
317
+ },
318
+ "execution_count": 5,
319
+ "metadata": {},
320
+ "output_type": "execute_result"
321
+ }
322
+ ],
323
+ "source": [
324
+ "processor.tokenizer.decode(inputs[\"input_ids\"][0])"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "id": "ceb62c4d-3bfd-4805-886d-83eaac0cdf65",
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": []
334
+ }
335
+ ],
336
+ "metadata": {
337
+ "kernelspec": {
338
+ "display_name": "Python 3 (ipykernel)",
339
+ "language": "python",
340
+ "name": "python3"
341
+ },
342
+ "language_info": {
343
+ "codemirror_mode": {
344
+ "name": "ipython",
345
+ "version": 3
346
+ },
347
+ "file_extension": ".py",
348
+ "mimetype": "text/x-python",
349
+ "name": "python",
350
+ "nbconvert_exporter": "python",
351
+ "pygments_lexer": "ipython3",
352
+ "version": "3.10.12"
353
+ }
354
+ },
355
+ "nbformat": 4,
356
+ "nbformat_minor": 5
357
+ }