gauri-sharan commited on
Commit
ab907f9
·
verified ·
1 Parent(s): 9f98d42

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModel
4
+ import scipy.io.wavfile as wavfile
5
+ import spaces # Import the spaces module
6
+
7
+ # Load the model and processor
8
+ def load_model():
9
+ processor = AutoProcessor.from_pretrained("suno/bark-small")
10
+ model = AutoModel.from_pretrained("suno/bark-small")
11
+ model.eval() # Set the model to evaluation mode
12
+ return processor, model
13
+
14
+ # Load models on startup
15
+ print("Loading models...")
16
+ processor, model = load_model()
17
+ print("Models loaded successfully!")
18
+
19
+ @spaces.GPU # Decorate the function to enable GPU usage
20
+ def text_to_speech(text):
21
+ try:
22
+ # Check if a GPU is available and set device
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Move model to GPU
26
+ model.to(device)
27
+
28
+ inputs = processor(
29
+ text=[text],
30
+ return_tensors="pt",
31
+ ).to(device) # Move inputs to GPU
32
+
33
+ # Generate speech values on the GPU
34
+ with torch.no_grad(): # Disable gradient calculation for inference
35
+ speech_values = model.generate(**inputs, do_sample=True)
36
+
37
+ # Move generated audio data back to CPU for saving
38
+ audio_data = speech_values.cpu().numpy().squeeze()
39
+ sampling_rate = model.generation_config.sample_rate
40
+
41
+ temp_path = "temp_audio.wav"
42
+ wavfile.write(temp_path, sampling_rate, audio_data)
43
+
44
+ return temp_path
45
+ except Exception as e:
46
+ return f"Error generating speech: {str(e)}"
47
+
48
+ # Define Gradio interface
49
+ demo = gr.Interface(
50
+ fn=text_to_speech,
51
+ inputs=[
52
+ gr.Textbox(
53
+ label="Enter text (Hindi supported)",
54
+ placeholder="दिल्ली मेट्रो में आपका स्वागत है"
55
+ )
56
+ ],
57
+ outputs=gr.Audio(label="Generated Speech"),
58
+ title="Bark TTS Test App",
59
+ description="This app generates speech from text using the Bark TTS model. Supports Hindi.",
60
+ examples=[
61
+ ["दिल्ली मेट्रो में आपका स्वागत है"],
62
+ ["कृपया ध्यान दें"],
63
+ ["अगला स्टेशन राजीव चौक है"]
64
+ ],
65
+ theme="compact" # You can customize the theme
66
+ )
67
+
68
+ if __name__ == "__main__":
69
+ demo.launch()