Elanas commited on
Commit
12dd383
·
verified ·
1 Parent(s): 0d3fc6f

Upload filtravimas.py

Browse files
Files changed (1) hide show
  1. filtravimas.py +87 -0
filtravimas.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import torch.nn as nn
5
+ import torchaudio.transforms as T
6
+ import noisereduce as nr
7
+ import numpy as np
8
+ from asteroid.models import DCCRNet
9
+
10
+ TEMP_DIR = "temp_filtered"
11
+ OUTPUT_PATH = os.path.join(TEMP_DIR, "ivestis.wav")
12
+ os.makedirs(TEMP_DIR, exist_ok=True)
13
+
14
+ class WaveUNet(nn.Module):
15
+ def __init__(self):
16
+ super(WaveUNet, self).__init__()
17
+ self.encoder = nn.Sequential(
18
+ nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1),
19
+ nn.ReLU(),
20
+ nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1),
21
+ nn.ReLU(),
22
+ nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1),
23
+ nn.ReLU(),
24
+ )
25
+ self.decoder = nn.Sequential(
26
+ nn.ConvTranspose1d(64, 32, kernel_size=3, stride=1, padding=1),
27
+ nn.ReLU(),
28
+ nn.ConvTranspose1d(32, 16, kernel_size=3, stride=1, padding=1),
29
+ nn.ReLU(),
30
+ nn.ConvTranspose1d(16, 1, kernel_size=3, stride=1, padding=1)
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = self.encoder(x)
35
+ x = self.decoder(x)
36
+ return x
37
+
38
+ def filtruoti_su_waveunet(input_path, output_path):
39
+ print("🔧 Wave-U-Net filtravimas...")
40
+ model = WaveUNet()
41
+ model.eval()
42
+ mixture, sr = torchaudio.load(input_path)
43
+ if sr != 16000:
44
+ print("🔁 Resample į 16kHz...")
45
+ resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
46
+ mixture = resampler(mixture)
47
+ if mixture.dim() == 2:
48
+ mixture = mixture.unsqueeze(0)
49
+ with torch.no_grad():
50
+ output = model(mixture)
51
+ output = output.squeeze(0)
52
+ torchaudio.save(output_path, output, 16000)
53
+ print(f"✅ Wave-U-Net išsaugota: {output_path}")
54
+
55
+ def filtruoti_su_denoiser(input_path, output_path):
56
+ print("🔧 Denoiser (DCCRNet)...")
57
+ model = DCCRNet.from_pretrained("JorisCos/DCCRNet_Libri1Mix_enhsingle_16k")
58
+ mixture, sr = torchaudio.load(input_path)
59
+ if sr != 16000:
60
+ print("🔁 Resample į 16kHz...")
61
+ resampler = T.Resample(orig_freq=sr, new_freq=16000).to(mixture.device)
62
+ mixture = resampler(mixture)
63
+ with torch.no_grad():
64
+ est_source = model.separate(mixture)
65
+ torchaudio.save(output_path, est_source[0], 16000)
66
+ print(f"✅ Denoiser išsaugota: {output_path}")
67
+
68
+ def filtruoti_su_noisereduce(input_path, output_path):
69
+ print("🔧 Noisereduce filtravimas...")
70
+ waveform, sr = torchaudio.load(input_path)
71
+ audio = waveform.detach().cpu().numpy()[0]
72
+ reduced = nr.reduce_noise(y=audio, sr=sr)
73
+ reduced_tensor = torch.from_numpy(reduced).unsqueeze(0)
74
+ torchaudio.save(output_path, reduced_tensor, sr)
75
+ print(f"✅ Noisereduce išsaugota: {output_path}")
76
+
77
+ def filtruoti_audio(input_path: str, metodas: str) -> str:
78
+ if metodas == "Denoiser":
79
+ filtruoti_su_denoiser(input_path, OUTPUT_PATH)
80
+ elif metodas == "Wave-U-Net":
81
+ filtruoti_su_waveunet(input_path, OUTPUT_PATH)
82
+ elif metodas == "Noisereduce":
83
+ filtruoti_su_noisereduce(input_path, OUTPUT_PATH)
84
+ else:
85
+ raise ValueError("Nepalaikomas filtravimo metodas")
86
+
87
+ return OUTPUT_PATH