added audio processing function
app.py
CHANGED
@@ -9,6 +9,18 @@ import requests
 import json
 import dotenv
 from transformers import AutoProcessor, SeamlessM4TModel
+import torchaudio
+dotenv.load_dotenv()
+
+
+AUDIO_SAMPLE_RATE = 16000.0
+MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
+DEFAULT_TARGET_LANGUAGE = "English"
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+processor = AutoProcessor.from_pretrained("ylacombe/hf-seamless-m4t-large")
+model = SeamlessM4TModel.from_pretrained("ylacombe/hf-seamless-m4t-large").to(device)

 from lang_list import (
     LANGUAGE_NAME_TO_CODE,
@@ -17,61 +29,27 @@ from lang_list import (
     T2TT_TARGET_LANGUAGE_NAMES,
     TEXT_SOURCE_LANGUAGE_NAMES,
     LANG_TO_SPKR_ID,
-)
-
-dotenv.load_dotenv()
-
-DEFAULT_TARGET_LANGUAGE = "English"
-AUDIO_SAMPLE_RATE = 16000.0
-MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
-
-
-def predict(
-    task_name: str,
-    audio_source: str,
-    input_audio_mic: str | None,
-    input_audio_file: str | None,
-    input_text: str | None,
-    source_language: str | None,
-    target_language: str,
-) -> tuple[tuple[int, np.ndarray] | None, str]:
-    task_name = task_name.split()[0]
-    source_language_code = LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
-    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-
-    if task_name in ["S2ST", "S2TT", "ASR"]:
-        if audio_source == "microphone":
-            input_data = input_audio_mic
-        else:
-            input_data = input_audio_file
-
-        arr, org_sr = torchaudio.load(input_data)
-        new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
-        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
-        if new_arr.shape[1] > max_length:
-            new_arr = new_arr[:, :max_length]
-            gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
-
-
-        input_data = processor(audios = new_arr, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt").to(device)
-    else:
-        input_data = processor(text = input_text, src_lang=source_language_code, return_tensors="pt").to(device)
+)


-
-
-
-
-
-
-
-
+def process_speech(sound):
+    """
+    processing sound using seamless_m4t
+    """
+    # task_name = "T2TT"
+    arr, org_sr = torchaudio.load(sound)
+    target_language_code = LANGUAGE_NAME_TO_CODE[DEFAULT_TARGET_LANGUAGE]
+    new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
+    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
+    if new_arr.shape[1] > max_length:
+        new_arr = new_arr[:, :max_length]
+        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
+    input_data = processor(audios = new_arr, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt").to(device)
+    tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
     text_out = processor.decode(tokens_ids, skip_special_tokens=True)

-
-
-    else:
-        return None, text_out
+    return text_out
+

 def convert_image_to_required_format(image):
     """