Uhhy committed on
Commit 4df1f63 · verified
1 Parent(s): c6f8a33

Update app.py

Files changed (1)
  1. app.py +196 -96
app.py CHANGED
@@ -5,8 +5,11 @@ import shutil
  from pydub import AudioSegment
  import spaces
  import torch
- import gradio as gr
- from huggingface_hub import snapshot_download
 
  from examples.get_examples import get_examples
  from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
@@ -19,24 +22,45 @@ from src.utils.init_path import init_path
 
  checkpoint_path = 'checkpoints'
  config_path = 'src/config'
- device = "cuda" if torch.cuda.is_available(
- ) else "mps" if platform.system() == 'Darwin' else "cpu"
 
  os.environ['TORCH_HOME'] = checkpoint_path
  snapshot_download(repo_id='vinthony/SadTalker-V002rc',
                    local_dir=checkpoint_path, local_dir_use_symlinks=True)
 
 
  def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
      AudioSegment.from_file(file=mp3_filename).set_frame_rate(
          frame_rate).export(wav_filename, format="wav")
 
 
  @spaces.GPU(duration=0)
- def generate_video(source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
-                    batch_size=1, size=256, pose_style=0, facerender='facevid2vid', exp_scale=1.0,
-                    use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False,
-                    length_of_audio=0, use_blink=True, result_dir='./results/'):
      # Initialize models and paths
      sadtalker_paths = init_path(
          checkpoint_path, config_path, size, False, preprocess)
@@ -111,6 +135,10 @@ def generate_video(source_image, driven_audio, preprocess='crop', still_mode=Fal
      else:
          batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                           still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
      coeff_path = audio_to_coeff.generate(
          batch, save_dir, pose_style, ref_pose_coeff_path)
 
@@ -124,94 +152,166 @@ def generate_video(source_image, driven_audio, preprocess='crop', still_mode=Fal
 
      return return_path
 
 
- # Gradio UI
- with gr.Blocks(analytics_enabled=False) as demo:
-     with gr.Row():
-         with gr.Column(variant='panel'):
-             with gr.Tabs(elem_id="sadtalker_source_image"):
-                 with gr.TabItem('Source image'):
-                     with gr.Row():
-                         source_image = gr.Image(
-                             label="Source image", sources="upload", type="filepath", elem_id="img2img_image")
-
-             with gr.Tabs(elem_id="sadtalker_driven_audio"):
-                 with gr.TabItem('Driving Methods'):
-                     gr.Markdown(
-                         "Possible driving combinations: <br> 1. Audio only 2. Audio/IDLE Mode + Ref Video(pose, blink, pose+blink) 3. IDLE Mode only 4. Ref Video only (all) ")
-
-                     with gr.Row():
-                         driven_audio = gr.Audio(
-                             label="Input audio", sources="upload", type="filepath")
-                         driven_audio_no = gr.Audio(
-                             label="Use IDLE mode, no audio is required", sources="upload", type="filepath", visible=False)
-
-                         with gr.Column():
-                             use_idle_mode = gr.Checkbox(
-                                 label="Use Idle Animation")
-                             length_of_audio = gr.Number(
-                                 value=5, label="The length(seconds) of the generated video.")
-                         use_idle_mode.change(lambda choice: (gr.update(visible=not choice), gr.update(visible=choice)),
-                                              inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no])
-
-                     with gr.Row():
-                         ref_video = gr.Video(
-                             label="Reference Video", sources="upload", elem_id="vidref")
-
-                         with gr.Column():
-                             use_ref_video = gr.Checkbox(
-                                 label="Use Reference Video")
-                             ref_info = gr.Radio(['pose', 'blink', 'pose+blink', 'all'], value='pose', label='Reference Video',
-                                                 info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")
-
-                         ref_video.change(lambda path: gr.update(
-                             value=path is not None), inputs=ref_video, outputs=use_ref_video)
-
-         with gr.Column(variant='panel'):
-             with gr.Tabs(elem_id="sadtalker_checkbox"):
-                 with gr.TabItem('Settings'):
-                     with gr.Column(variant='panel'):
-                         with gr.Row():
-                             pose_style = gr.Slider(
-                                 minimum=0, maximum=45, step=1, label="Pose style", value=0)
-                             exp_weight = gr.Slider(
-                                 minimum=0, maximum=3, step=0.1, label="expression scale", value=1)
-                             blink_every = gr.Checkbox(
-                                 label="use eye blink", value=True)
-
-                         with gr.Row():
-                             size_of_image = gr.Radio(
-                                 [256, 512], value=256, label='face model resolution', info="use 256/512 model?")
-                             preprocess_type = gr.Radio(
-                                 ['crop', 'resize', 'full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
-
-                         with gr.Row():
-                             is_still_mode = gr.Checkbox(
-                                 label="Still Mode (fewer head motion, works with preprocess `full`)")
-                             facerender = gr.Radio(
-                                 ['facevid2vid', 'pirender'], value='facevid2vid', label='facerender', info="which face render?")
-
-                         with gr.Row():
-                             batch_size = gr.Slider(
-                                 label="batch size in generation", step=1, maximum=10, value=1)
-                             enhancer = gr.Checkbox(
-                                 label="GFPGAN as Face enhancer", value=True)
-
-                         submit = gr.Button(
-                             'Generate', elem_id="sadtalker_generate", variant='primary')
-
-             with gr.Tabs(elem_id="sadtalker_generated"):
-                 gen_video = gr.Video(label="Generated video")
-
-     submit.click(
-         fn=generate_video,
-         inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image,
-                 pose_style, facerender, exp_weight, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, blink_every],
-         outputs=[gen_video],
      )
 
-     with gr.Row():
-         gr.Examples(examples=get_examples(), inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer],
-                     outputs=[gen_video], fn=generate_video)
-
- demo.launch(debug=True)
 
  from pydub import AudioSegment
  import spaces
  import torch
+ from fastapi import FastAPI, File, UploadFile, Form
+ from fastapi.responses import FileResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.templating import Jinja2Templates
+ from transformers import pipeline
 
  from examples.get_examples import get_examples
  from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
 
  checkpoint_path = 'checkpoints'
  config_path = 'src/config'
+ device = "cuda" if torch.cuda.is_available() else "mps" if platform.system() == 'Darwin' else "cpu"
 
  os.environ['TORCH_HOME'] = checkpoint_path
  snapshot_download(repo_id='vinthony/SadTalker-V002rc',
                    local_dir=checkpoint_path, local_dir_use_symlinks=True)
 
+ app = FastAPI()
+ app.mount("/results", StaticFiles(directory="results"), name="results")
+ templates = Jinja2Templates(directory="templates")
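Note: StaticFiles(directory="results") requires the results directory to exist when the module is imported (Starlette raises RuntimeError otherwise), and snapshot_download is still called above even though this commit drops its import. A minimal setup sketch, assuming it is added alongside these lines rather than being part of the diff:

import os
from huggingface_hub import snapshot_download  # still needed by the call above

os.makedirs("results", exist_ok=True)    # for app.mount("/results", ...)
os.makedirs("templates", exist_ok=True)  # for Jinja2Templates(directory="templates")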
 
  def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
      AudioSegment.from_file(file=mp3_filename).set_frame_rate(
          frame_rate).export(wav_filename, format="wav")
 
+ def get_pose_style_from_audio(audio_path):
+     """Determines pose style based on audio emotion using a pre-trained model."""
+     # Load the pre-trained emotion recognition model
+     emotion_recognizer = pipeline("sentiment-analysis")
+
+     # Analyze the audio emotion
+     results = emotion_recognizer(audio_path)
+     emotion = results[0]["label"]
+
+     # Map emotion to pose style (you can adjust these mappings)
+     pose_style_mapping = {
+         "POSITIVE": 15,   # Happy
+         "NEGATIVE": 35,   # Sad
+         "NEUTRAL": 0,     # Normal
+         # Add more emotion mappings as needed
+     }
+
+     return pose_style_mapping.get(emotion, 0)  # Default to neutral pose if unknown
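The helper added above feeds an audio file path into a text sentiment-analysis pipeline, so it classifies the path string rather than the speech itself, and the default sentiment model never returns "NEUTRAL". A hedged alternative sketch using an audio-classification pipeline instead; the model name and its label set ('neu', 'hap', 'ang', 'sad') are assumptions, not part of this commit:

from transformers import pipeline

def get_pose_style_from_speech(audio_path):
    # Speech emotion recognition on the waveform itself (model choice is illustrative)
    classifier = pipeline("audio-classification", model="superb/hubert-base-superb-er")
    emotion = classifier(audio_path)[0]["label"]           # top-scoring label, e.g. 'hap'
    pose_style_mapping = {"hap": 15, "sad": 35, "ang": 25, "neu": 0}
    return pose_style_mapping.get(emotion, 0)              # default to neutral pose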
 
  @spaces.GPU(duration=0)
+ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop', still_mode: bool = False,
+                    use_enhancer: bool = False, batch_size: int = 1, size: int = 256,
+                    facerender: str = 'facevid2vid', exp_scale: float = 1.0, use_ref_video: bool = False,
+                    ref_video: str = None, ref_info: str = None, use_idle_mode: bool = False,
+                    length_of_audio: int = 0, use_blink: bool = True, result_dir: str = './results/') -> str:
      # Initialize models and paths
      sadtalker_paths = init_path(
          checkpoint_path, config_path, size, False, preprocess)
 
      else:
          batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                           still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
+
+     # Get pose style from audio
+     pose_style = get_pose_style_from_audio(audio_path)
+
      coeff_path = audio_to_coeff.generate(
          batch, save_dir, pose_style, ref_pose_coeff_path)
 
 
 
      return return_path
 
+ @app.post("/generate")
+ async def generate_video_api(source_image: UploadFile = File(...), driven_audio: UploadFile = File(None),
+                              preprocess: str = Form('crop'), still_mode: bool = Form(False),
+                              use_enhancer: bool = Form(False), batch_size: int = Form(1), size: int = Form(256),
+                              facerender: str = Form('facevid2vid'), exp_scale: float = Form(1.0),
+                              use_ref_video: bool = Form(False), ref_video: UploadFile = File(None),
+                              ref_info: str = Form(None), use_idle_mode: bool = Form(False),
+                              length_of_audio: int = Form(0), use_blink: bool = Form(True), result_dir: str = Form('./results/')):
+     # Save the uploaded files temporarily
+     temp_source_image_path = f"temp/{source_image.filename}"
+     os.makedirs("temp", exist_ok=True)
+     with open(temp_source_image_path, "wb") as buffer:
+         shutil.copyfileobj(source_image.file, buffer)
+
+     if driven_audio:
+         temp_driven_audio_path = f"temp/{driven_audio.filename}"
+         with open(temp_driven_audio_path, "wb") as buffer:
+             shutil.copyfileobj(driven_audio.file, buffer)
+     else:
+         temp_driven_audio_path = None
 
+     if ref_video:
+         temp_ref_video_path = f"temp/{ref_video.filename}"
+         with open(temp_ref_video_path, "wb") as buffer:
+             shutil.copyfileobj(ref_video.file, buffer)
+     else:
+         temp_ref_video_path = None
+
+     # Generate the video
+     video_path = generate_video(
+         source_image=temp_source_image_path,
+         driven_audio=temp_driven_audio_path,
+         preprocess=preprocess,
+         still_mode=still_mode,
+         use_enhancer=use_enhancer,
+         batch_size=batch_size,
+         size=size,
+         facerender=facerender,
+         exp_scale=exp_scale,
+         use_ref_video=use_ref_video,
+         ref_video=temp_ref_video_path,
+         ref_info=ref_info,
+         use_idle_mode=use_idle_mode,
+         length_of_audio=length_of_audio,
+         use_blink=use_blink,
+         result_dir=result_dir
      )
 
+     # Clean up temporary files
+     shutil.rmtree("temp")
+
+     # Return the generated video file
+     return FileResponse(video_path)
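A sketch of a client call against the /generate endpoint above; the host, port, and file names are placeholders:

import requests

files = {
    "source_image": open("face.png", "rb"),
    "driven_audio": open("speech.wav", "rb"),
}
data = {"preprocess": "crop", "size": 256, "use_enhancer": "true"}
resp = requests.post("http://localhost:8000/generate", files=files, data=data)
resp.raise_for_status()
# The endpoint streams back the rendered video file
with open("generated.mp4", "wb") as f:
    f.write(resp.content)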
+
+
+ @app.get("/")
+ async def root(request):
+     return templates.TemplateResponse("index.html", {"request": request})
+
+ # HTML Template (`templates/index.html`)
+ html = """
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>SadTalker API</title>
+     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css">
+     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/jquery.slim.min.js"></script>
+     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js"></script>
+     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js"></script>
+ </head>
+ <body>
+     <div class="container mt-5">
+         <h1>SadTalker API</h1>
+         <form method="POST" action="/generate" enctype="multipart/form-data">
+             <div class="form-group">
+                 <label for="source_image">Source Image:</label>
+                 <input type="file" class="form-control-file" id="source_image" name="source_image" required>
+             </div>
+             <div class="form-group">
+                 <label for="driven_audio">Driving Audio:</label>
+                 <input type="file" class="form-control-file" id="driven_audio" name="driven_audio">
+             </div>
+             <div class="form-group">
+                 <label for="preprocess">Preprocess:</label>
+                 <select class="form-control" id="preprocess" name="preprocess">
+                     <option value="crop">Crop</option>
+                     <option value="resize">Resize</option>
+                     <option value="full">Full</option>
+                     <option value="extcrop">ExtCrop</option>
+                     <option value="extfull">ExtFull</option>
+                 </select>
+             </div>
+             <div class="form-check">
+                 <input type="checkbox" class="form-check-input" id="still_mode" name="still_mode">
+                 <label class="form-check-label" for="still_mode">Still Mode</label>
+             </div>
+             <div class="form-check">
+                 <input type="checkbox" class="form-check-input" id="use_enhancer" name="use_enhancer">
+                 <label class="form-check-label" for="use_enhancer">Use GFPGAN Enhancer</label>
+             </div>
+             <div class="form-group">
+                 <label for="batch_size">Batch Size:</label>
+                 <input type="number" class="form-control" id="batch_size" name="batch_size" min="1" max="10" value="1">
+             </div>
+             <div class="form-group">
+                 <label for="size">Face Model Resolution:</label>
+                 <select class="form-control" id="size" name="size">
+                     <option value="256">256</option>
+                     <option value="512">512</option>
+                 </select>
+             </div>
+             <div class="form-group">
+                 <label for="facerender">Face Render:</label>
+                 <select class="form-control" id="facerender" name="facerender">
+                     <option value="facevid2vid">FaceVid2Vid</option>
+                     <option value="pirender">PIRender</option>
+                 </select>
+             </div>
+             <div class="form-group">
+                 <label for="exp_scale">Expression Scale:</label>
+                 <input type="number" class="form-control" id="exp_scale" name="exp_scale" min="0" max="3" step="0.1" value="1.0">
+             </div>
+             <div class="form-check">
+                 <input type="checkbox" class="form-check-input" id="use_ref_video" name="use_ref_video">
+                 <label class="form-check-label" for="use_ref_video">Use Reference Video</label>
+             </div>
+             <div class="form-group">
+                 <label for="ref_video">Reference Video:</label>
+                 <input type="file" class="form-control-file" id="ref_video" name="ref_video">
+             </div>
+             <div class="form-group">
+                 <label for="ref_info">Reference Video Information:</label>
+                 <select class="form-control" id="ref_info" name="ref_info">
+                     <option value="pose">Pose</option>
+                     <option value="blink">Blink</option>
+                     <option value="pose+blink">Pose + Blink</option>
+                     <option value="all">All</option>
+                 </select>
+             </div>
+             <div class="form-check">
+                 <input type="checkbox" class="form-check-input" id="use_idle_mode" name="use_idle_mode">
+                 <label class="form-check-label" for="use_idle_mode">Use Idle Animation</label>
+             </div>
+             <div class="form-group">
+                 <label for="length_of_audio">Length of Audio (seconds):</label>
+                 <input type="number" class="form-control" id="length_of_audio" name="length_of_audio" min="0" value="0">
+             </div>
+             <div class="form-check">
+                 <input type="checkbox" class="form-check-input" id="use_blink" name="use_blink" checked>
+                 <label class="form-check-label" for="use_blink">Use Eye Blink</label>
+             </div>
+             <button type="submit" class="btn btn-primary">Generate</button>
+         </form>
+     </div>
+ </body>
+ </html>
+ """
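The html string above is defined inline but nothing writes it to disk, while the root route renders templates/index.html through Jinja2Templates; the handler's request argument also needs a Request annotation for FastAPI to inject the request object. A sketch of the glue this seems to assume, not part of the commit:

from fastapi import Request

os.makedirs("templates", exist_ok=True)
with open("templates/index.html", "w") as f:
    f.write(html)                                  # materialize the inline template

@app.get("/")
async def root(request: Request):                  # annotated so FastAPI injects the request
    return templates.TemplateResponse("index.html", {"request": request})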
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
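One caveat on the launch line: uvicorn only honors reload=True when the application is given as an import string; with an app object it logs a warning and disables reload. A sketch assuming the module is named app.py:

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)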