guangyil commited on
Commit
4a8840b
·
verified ·
1 Parent(s): 3d60426

Update voila_tokenizer.py

Browse files
Files changed (1) hide show
  1. voila_tokenizer.py +5 -1
voila_tokenizer.py CHANGED
@@ -24,7 +24,11 @@ class VoilaTokenizer:
24
  self.sampling_rate = self.processor.sampling_rate
25
  self.model_version = self.model.config.model_version
26
 
27
-
 
 
 
 
28
  @torch.no_grad()
29
  def encode(self, wav, sr):
30
  wav = torch.tensor(wav, dtype=torch.float32, device=self.device)
 
24
  self.sampling_rate = self.processor.sampling_rate
25
  self.model_version = self.model.config.model_version
26
 
27
+ def to(self, device):
28
+ self.device = torch.device(device)
29
+ self.model = self.model.to(device)
30
+ self.bandwidth_id = self.bandwidth_id.to(device)
31
+
32
  @torch.no_grad()
33
  def encode(self, wav, sr):
34
  wav = torch.tensor(wav, dtype=torch.float32, device=self.device)