Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -5,8 +5,11 @@ import shutil
|
|
5 |
from pydub import AudioSegment
|
6 |
import spaces
|
7 |
import torch
|
8 |
-
import
|
9 |
-
from
|
|
|
|
|
|
|
10 |
|
11 |
from examples.get_examples import get_examples
|
12 |
from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
|
@@ -19,24 +22,45 @@ from src.utils.init_path import init_path
|
|
19 |
|
20 |
checkpoint_path = 'checkpoints'
|
21 |
config_path = 'src/config'
|
22 |
-
device = "cuda" if torch.cuda.is_available(
|
23 |
-
) else "mps" if platform.system() == 'Darwin' else "cpu"
|
24 |
|
25 |
os.environ['TORCH_HOME'] = checkpoint_path
|
26 |
snapshot_download(repo_id='vinthony/SadTalker-V002rc',
|
27 |
local_dir=checkpoint_path, local_dir_use_symlinks=True)
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
|
31 |
AudioSegment.from_file(file=mp3_filename).set_frame_rate(
|
32 |
frame_rate).export(wav_filename, format="wav")
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
@spaces.GPU(duration=0)
|
36 |
-
def generate_video(source_image, driven_audio, preprocess='crop', still_mode=
|
37 |
-
batch_size=1, size
|
38 |
-
|
39 |
-
|
|
|
40 |
# Initialize models and paths
|
41 |
sadtalker_paths = init_path(
|
42 |
checkpoint_path, config_path, size, False, preprocess)
|
@@ -111,6 +135,10 @@ def generate_video(source_image, driven_audio, preprocess='crop', still_mode=Fal
|
|
111 |
else:
|
112 |
batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
|
113 |
still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
|
|
|
|
|
|
|
|
|
114 |
coeff_path = audio_to_coeff.generate(
|
115 |
batch, save_dir, pose_style, ref_pose_coeff_path)
|
116 |
|
@@ -124,94 +152,166 @@ def generate_video(source_image, driven_audio, preprocess='crop', still_mode=Fal
|
|
124 |
|
125 |
return return_path
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
value=5, label="The length(seconds) of the generated video.")
|
154 |
-
use_idle_mode.change(lambda choice: (gr.update(visible=not choice), gr.update(visible=choice)),
|
155 |
-
inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no])
|
156 |
-
|
157 |
-
with gr.Row():
|
158 |
-
ref_video = gr.Video(
|
159 |
-
label="Reference Video", sources="upload", elem_id="vidref")
|
160 |
-
|
161 |
-
with gr.Column():
|
162 |
-
use_ref_video = gr.Checkbox(
|
163 |
-
label="Use Reference Video")
|
164 |
-
ref_info = gr.Radio(['pose', 'blink', 'pose+blink', 'all'], value='pose', label='Reference Video',
|
165 |
-
info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")
|
166 |
-
|
167 |
-
ref_video.change(lambda path: gr.update(
|
168 |
-
value=path is not None), inputs=ref_video, outputs=use_ref_video)
|
169 |
-
|
170 |
-
with gr.Column(variant='panel'):
|
171 |
-
with gr.Tabs(elem_id="sadtalker_checkbox"):
|
172 |
-
with gr.TabItem('Settings'):
|
173 |
-
with gr.Column(variant='panel'):
|
174 |
-
with gr.Row():
|
175 |
-
pose_style = gr.Slider(
|
176 |
-
minimum=0, maximum=45, step=1, label="Pose style", value=0)
|
177 |
-
exp_weight = gr.Slider(
|
178 |
-
minimum=0, maximum=3, step=0.1, label="expression scale", value=1)
|
179 |
-
blink_every = gr.Checkbox(
|
180 |
-
label="use eye blink", value=True)
|
181 |
-
|
182 |
-
with gr.Row():
|
183 |
-
size_of_image = gr.Radio(
|
184 |
-
[256, 512], value=256, label='face model resolution', info="use 256/512 model?")
|
185 |
-
preprocess_type = gr.Radio(
|
186 |
-
['crop', 'resize', 'full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
|
187 |
-
|
188 |
-
with gr.Row():
|
189 |
-
is_still_mode = gr.Checkbox(
|
190 |
-
label="Still Mode (fewer head motion, works with preprocess `full`)")
|
191 |
-
facerender = gr.Radio(
|
192 |
-
['facevid2vid', 'pirender'], value='facevid2vid', label='facerender', info="which face render?")
|
193 |
-
|
194 |
-
with gr.Row():
|
195 |
-
batch_size = gr.Slider(
|
196 |
-
label="batch size in generation", step=1, maximum=10, value=1)
|
197 |
-
enhancer = gr.Checkbox(
|
198 |
-
label="GFPGAN as Face enhancer", value=True)
|
199 |
-
|
200 |
-
submit = gr.Button(
|
201 |
-
'Generate', elem_id="sadtalker_generate", variant='primary')
|
202 |
-
|
203 |
-
with gr.Tabs(elem_id="sadtalker_generated"):
|
204 |
-
gen_video = gr.Video(label="Generated video")
|
205 |
-
|
206 |
-
submit.click(
|
207 |
-
fn=generate_video,
|
208 |
-
inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image,
|
209 |
-
pose_style, facerender, exp_weight, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, blink_every],
|
210 |
-
outputs=[gen_video],
|
211 |
)
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from pydub import AudioSegment
|
6 |
import spaces
|
7 |
import torch
|
8 |
+
from fastapi import FastAPI, File, UploadFile, Form
|
9 |
+
from fastapi.responses import FileResponse
|
10 |
+
from fastapi.staticfiles import StaticFiles
|
11 |
+
from fastapi.templating import Jinja2Templates
|
12 |
+
from transformers import pipeline
|
13 |
|
14 |
from examples.get_examples import get_examples
|
15 |
from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
|
|
|
22 |
|
23 |
checkpoint_path = 'checkpoints'
|
24 |
config_path = 'src/config'
|
25 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if platform.system() == 'Darwin' else "cpu"
|
|
|
26 |
|
27 |
os.environ['TORCH_HOME'] = checkpoint_path
|
28 |
snapshot_download(repo_id='vinthony/SadTalker-V002rc',
|
29 |
local_dir=checkpoint_path, local_dir_use_symlinks=True)
|
30 |
|
31 |
+
app = FastAPI()
|
32 |
+
app.mount("/results", StaticFiles(directory="results"), name="results")
|
33 |
+
templates = Jinja2Templates(directory="templates")
|
34 |
|
35 |
def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
|
36 |
AudioSegment.from_file(file=mp3_filename).set_frame_rate(
|
37 |
frame_rate).export(wav_filename, format="wav")
|
38 |
|
39 |
+
def get_pose_style_from_audio(audio_path):
|
40 |
+
"""Determines pose style based on audio emotion using a pre-trained model."""
|
41 |
+
# Load the pre-trained emotion recognition model
|
42 |
+
emotion_recognizer = pipeline("sentiment-analysis")
|
43 |
+
|
44 |
+
# Analyze the audio emotion
|
45 |
+
results = emotion_recognizer(audio_path)
|
46 |
+
emotion = results[0]["label"]
|
47 |
+
|
48 |
+
# Map emotion to pose style (you can adjust these mappings)
|
49 |
+
pose_style_mapping = {
|
50 |
+
"POSITIVE": 15, # Happy
|
51 |
+
"NEGATIVE": 35, # Sad
|
52 |
+
"NEUTRAL": 0, # Normal
|
53 |
+
# Add more emotion mappings as needed
|
54 |
+
}
|
55 |
+
|
56 |
+
return pose_style_mapping.get(emotion, 0) # Default to neutral pose if unknown
|
57 |
|
58 |
@spaces.GPU(duration=0)
|
59 |
+
def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop', still_mode: bool = False,
|
60 |
+
use_enhancer: bool = False, batch_size: int = 1, size: int = 256,
|
61 |
+
facerender: str = 'facevid2vid', exp_scale: float = 1.0, use_ref_video: bool = False,
|
62 |
+
ref_video: str = None, ref_info: str = None, use_idle_mode: bool = False,
|
63 |
+
length_of_audio: int = 0, use_blink: bool = True, result_dir: str = './results/') -> str:
|
64 |
# Initialize models and paths
|
65 |
sadtalker_paths = init_path(
|
66 |
checkpoint_path, config_path, size, False, preprocess)
|
|
|
135 |
else:
|
136 |
batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
|
137 |
still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
|
138 |
+
|
139 |
+
# Get pose style from audio
|
140 |
+
pose_style = get_pose_style_from_audio(audio_path)
|
141 |
+
|
142 |
coeff_path = audio_to_coeff.generate(
|
143 |
batch, save_dir, pose_style, ref_pose_coeff_path)
|
144 |
|
|
|
152 |
|
153 |
return return_path
|
154 |
|
155 |
+
@app.post("/generate")
|
156 |
+
async def generate_video_api(source_image: UploadFile = File(...), driven_audio: UploadFile = File(None),
|
157 |
+
preprocess: str = Form('crop'), still_mode: bool = Form(False),
|
158 |
+
use_enhancer: bool = Form(False), batch_size: int = Form(1), size: int = Form(256),
|
159 |
+
facerender: str = Form('facevid2vid'), exp_scale: float = Form(1.0),
|
160 |
+
use_ref_video: bool = Form(False), ref_video: UploadFile = File(None),
|
161 |
+
ref_info: str = Form(None), use_idle_mode: bool = Form(False),
|
162 |
+
length_of_audio: int = Form(0), use_blink: bool = Form(True), result_dir: str = Form('./results/')):
|
163 |
+
# Save the uploaded files temporarily
|
164 |
+
temp_source_image_path = f"temp/{source_image.filename}"
|
165 |
+
os.makedirs("temp", exist_ok=True)
|
166 |
+
with open(temp_source_image_path, "wb") as buffer:
|
167 |
+
shutil.copyfileobj(source_image.file, buffer)
|
168 |
+
|
169 |
+
if driven_audio:
|
170 |
+
temp_driven_audio_path = f"temp/{driven_audio.filename}"
|
171 |
+
with open(temp_driven_audio_path, "wb") as buffer:
|
172 |
+
shutil.copyfileobj(driven_audio.file, buffer)
|
173 |
+
else:
|
174 |
+
temp_driven_audio_path = None
|
175 |
|
176 |
+
if ref_video:
|
177 |
+
temp_ref_video_path = f"temp/{ref_video.filename}"
|
178 |
+
with open(temp_ref_video_path, "wb") as buffer:
|
179 |
+
shutil.copyfileobj(ref_video.file, buffer)
|
180 |
+
else:
|
181 |
+
temp_ref_video_path = None
|
182 |
+
|
183 |
+
# Generate the video
|
184 |
+
video_path = generate_video(
|
185 |
+
source_image=temp_source_image_path,
|
186 |
+
driven_audio=temp_driven_audio_path,
|
187 |
+
preprocess=preprocess,
|
188 |
+
still_mode=still_mode,
|
189 |
+
use_enhancer=use_enhancer,
|
190 |
+
batch_size=batch_size,
|
191 |
+
size=size,
|
192 |
+
facerender=facerender,
|
193 |
+
exp_scale=exp_scale,
|
194 |
+
use_ref_video=use_ref_video,
|
195 |
+
ref_video=temp_ref_video_path,
|
196 |
+
ref_info=ref_info,
|
197 |
+
use_idle_mode=use_idle_mode,
|
198 |
+
length_of_audio=length_of_audio,
|
199 |
+
use_blink=use_blink,
|
200 |
+
result_dir=result_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
)
|
202 |
|
203 |
+
# Clean up temporary files
|
204 |
+
shutil.rmtree("temp")
|
205 |
+
|
206 |
+
# Return the generated video file
|
207 |
+
return FileResponse(video_path)
|
208 |
+
|
209 |
+
|
210 |
+
@app.get("/")
|
211 |
+
async def root(request):
|
212 |
+
return templates.TemplateResponse("index.html", {"request": request})
|
213 |
+
|
214 |
+
# HTML Template (`templates/index.html`)
|
215 |
+
html = """
|
216 |
+
<!DOCTYPE html>
|
217 |
+
<html lang="en">
|
218 |
+
<head>
|
219 |
+
<meta charset="UTF-8">
|
220 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
221 |
+
<title>SadTalker API</title>
|
222 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css">
|
223 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/jquery.slim.min.js"></script>
|
224 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js"></script>
|
225 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js"></script>
|
226 |
+
</head>
|
227 |
+
<body>
|
228 |
+
<div class="container mt-5">
|
229 |
+
<h1>SadTalker API</h1>
|
230 |
+
<form method="POST" action="/generate" enctype="multipart/form-data">
|
231 |
+
<div class="form-group">
|
232 |
+
<label for="source_image">Source Image:</label>
|
233 |
+
<input type="file" class="form-control-file" id="source_image" name="source_image" required>
|
234 |
+
</div>
|
235 |
+
<div class="form-group">
|
236 |
+
<label for="driven_audio">Driving Audio:</label>
|
237 |
+
<input type="file" class="form-control-file" id="driven_audio" name="driven_audio">
|
238 |
+
</div>
|
239 |
+
<div class="form-group">
|
240 |
+
<label for="preprocess">Preprocess:</label>
|
241 |
+
<select class="form-control" id="preprocess" name="preprocess">
|
242 |
+
<option value="crop">Crop</option>
|
243 |
+
<option value="resize">Resize</option>
|
244 |
+
<option value="full">Full</option>
|
245 |
+
<option value="extcrop">ExtCrop</option>
|
246 |
+
<option value="extfull">ExtFull</option>
|
247 |
+
</select>
|
248 |
+
</div>
|
249 |
+
<div class="form-check">
|
250 |
+
<input type="checkbox" class="form-check-input" id="still_mode" name="still_mode">
|
251 |
+
<label class="form-check-label" for="still_mode">Still Mode</label>
|
252 |
+
</div>
|
253 |
+
<div class="form-check">
|
254 |
+
<input type="checkbox" class="form-check-input" id="use_enhancer" name="use_enhancer">
|
255 |
+
<label class="form-check-label" for="use_enhancer">Use GFPGAN Enhancer</label>
|
256 |
+
</div>
|
257 |
+
<div class="form-group">
|
258 |
+
<label for="batch_size">Batch Size:</label>
|
259 |
+
<input type="number" class="form-control" id="batch_size" name="batch_size" min="1" max="10" value="1">
|
260 |
+
</div>
|
261 |
+
<div class="form-group">
|
262 |
+
<label for="size">Face Model Resolution:</label>
|
263 |
+
<select class="form-control" id="size" name="size">
|
264 |
+
<option value="256">256</option>
|
265 |
+
<option value="512">512</option>
|
266 |
+
</select>
|
267 |
+
</div>
|
268 |
+
<div class="form-group">
|
269 |
+
<label for="facerender">Face Render:</label>
|
270 |
+
<select class="form-control" id="facerender" name="facerender">
|
271 |
+
<option value="facevid2vid">FaceVid2Vid</option>
|
272 |
+
<option value="pirender">PIRender</option>
|
273 |
+
</select>
|
274 |
+
</div>
|
275 |
+
<div class="form-group">
|
276 |
+
<label for="exp_scale">Expression Scale:</label>
|
277 |
+
<input type="number" class="form-control" id="exp_scale" name="exp_scale" min="0" max="3" step="0.1" value="1.0">
|
278 |
+
</div>
|
279 |
+
<div class="form-check">
|
280 |
+
<input type="checkbox" class="form-check-input" id="use_ref_video" name="use_ref_video">
|
281 |
+
<label class="form-check-label" for="use_ref_video">Use Reference Video</label>
|
282 |
+
</div>
|
283 |
+
<div class="form-group">
|
284 |
+
<label for="ref_video">Reference Video:</label>
|
285 |
+
<input type="file" class="form-control-file" id="ref_video" name="ref_video">
|
286 |
+
</div>
|
287 |
+
<div class="form-group">
|
288 |
+
<label for="ref_info">Reference Video Information:</label>
|
289 |
+
<select class="form-control" id="ref_info" name="ref_info">
|
290 |
+
<option value="pose">Pose</option>
|
291 |
+
<option value="blink">Blink</option>
|
292 |
+
<option value="pose+blink">Pose + Blink</option>
|
293 |
+
<option value="all">All</option>
|
294 |
+
</select>
|
295 |
+
</div>
|
296 |
+
<div class="form-check">
|
297 |
+
<input type="checkbox" class="form-check-input" id="use_idle_mode" name="use_idle_mode">
|
298 |
+
<label class="form-check-label" for="use_idle_mode">Use Idle Animation</label>
|
299 |
+
</div>
|
300 |
+
<div class="form-group">
|
301 |
+
<label for="length_of_audio">Length of Audio (seconds):</label>
|
302 |
+
<input type="number" class="form-control" id="length_of_audio" name="length_of_audio" min="0" value="0">
|
303 |
+
</div>
|
304 |
+
<div class="form-check">
|
305 |
+
<input type="checkbox" class="form-check-input" id="use_blink" name="use_blink" checked>
|
306 |
+
<label class="form-check-label" for="use_blink">Use Eye Blink</label>
|
307 |
+
</div>
|
308 |
+
<button type="submit" class="btn btn-primary">Generate</button>
|
309 |
+
</form>
|
310 |
+
</div>
|
311 |
+
</body>
|
312 |
+
</html>
|
313 |
+
"""
|
314 |
+
|
315 |
+
if __name__ == "__main__":
|
316 |
+
import uvicorn
|
317 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
|