Annuvin commited on
Commit
0bc37e3
·
verified ·
1 Parent(s): 1b06320

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -5
README.md CHANGED
@@ -8,14 +8,25 @@ SNAC in safetensors format.
8
  import json
9
  from pathlib import Path
10
 
 
11
  import torch
12
- from safetensors.torch import load_model
13
  from snac import SNAC
14
 
15
- device = "cuda"
16
- dtype = torch.float32
17
  config = json.loads(Path("config.json").read_text(encoding="utf-8"))
18
  model = SNAC(**config)
19
- load_model(model, "model.safetensors", device=device)
20
- model.to(device, dtype).eval()
 
 
 
 
 
 
 
 
 
 
 
 
21
  ```
 
8
  import json
9
  from pathlib import Path
10
 
11
+ import safetensors.torch
12
  import torch
13
+ import torchaudio
14
  from snac import SNAC
15
 
 
 
16
  config = json.loads(Path("config.json").read_text(encoding="utf-8"))
17
  model = SNAC(**config)
18
+ state_dict = safetensors.torch.load_file("model.safetensors")
19
+ model.load_state_dict(state_dict)
20
+ model.cuda().eval()
21
+
22
+ input, sr = torchaudio.load("input.wav")
23
+ input = torchaudio.functional.resample(input, sr, model.sampling_rate)
24
+ input = input.cuda().unsqueeze(0)
25
+
26
+ with torch.inference_mode():
27
+ codes = model.encode(input)
28
+ output = model.decode(codes)
29
+
30
+ output = output.cpu().squeeze(0)
31
+ torchaudio.save("output.wav", output, model.sampling_rate)
32
  ```