aletabu committed on
Commit
dff80ee
·
1 Parent(s): aff4679

wav2lip_gan support

Browse files
Files changed (3) hide show
  1. app.py +3 -5
  2. src/gradio_demo.py +1 -1
  3. test_api.py +60 -0
app.py CHANGED
@@ -31,11 +31,9 @@ def download_model():
31
  REPO_ID = 'vinthony/SadTalker-V002rc'
32
  snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
33
 
34
- # Manually replace the Wav2Lip model with wav2lip_gan.pth
35
  wav2lip_model_path = os.path.join('./checkpoints', 'wav2lip.pth')
36
- if not os.path.exists(wav2lip_model_path):
37
- os.system(
38
- f"wget https://github.com/Rudrabha/Wav2Lip/releases/download/v1.0/wav2lip_gan.pth -O {wav2lip_model_path}")
39
 
40
  print("Replaced Wav2Lip model with Wav2Lip GAN.")
41
 
@@ -225,7 +223,7 @@ def sadtalker_demo():
225
  if __name__ == "__main__":
226
 
227
  demo = sadtalker_demo()
228
- demo.queue(max_size=10, api_open=True)
229
  demo.launch(debug=True)
230
 
231
 
 
31
  REPO_ID = 'vinthony/SadTalker-V002rc'
32
  snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
33
 
34
+ # Descargar wav2lip_gan directamente desde GitHub
35
  wav2lip_model_path = os.path.join('./checkpoints', 'wav2lip.pth')
36
+ os.system(f"wget https://github.com/Rudrabha/Wav2Lip/releases/download/v1.0/wav2lip_gan.pth -O {wav2lip_model_path}")
 
 
37
 
38
  print("Replaced Wav2Lip model with Wav2Lip GAN.")
39
 
 
223
  if __name__ == "__main__":
224
 
225
  demo = sadtalker_demo()
226
+ demo.queue(max_size=100, api_open=True)
227
  demo.launch(debug=True)
228
 
229
 
src/gradio_demo.py CHANGED
@@ -164,7 +164,7 @@ class SadTalker():
164
  torch.cuda.synchronize()
165
 
166
  import gc; gc.collect()
167
-
168
  return return_path
169
 
170
 
 
164
  torch.cuda.synchronize()
165
 
166
  import gc; gc.collect()
167
+
168
  return return_path
169
 
170
 
test_api.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ def call_sadtalker_api(source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image, pose_style, facerender, exp_weight, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, blink_every):
4
+ # Define the API endpoint
5
+ api_url = "http://localhost:7860/api/predict"
6
+
7
+ # Prepare the payload
8
+ payload = {
9
+ "source_image": source_image,
10
+ "driven_audio": driven_audio,
11
+ "preprocess_type": preprocess_type,
12
+ "is_still_mode": is_still_mode,
13
+ "enhancer": enhancer,
14
+ "batch_size": batch_size,
15
+ "size_of_image": size_of_image,
16
+ "pose_style": pose_style,
17
+ "facerender": facerender,
18
+ "exp_weight": exp_weight,
19
+ "use_ref_video": use_ref_video,
20
+ "ref_video": ref_video,
21
+ "ref_info": ref_info,
22
+ "use_idle_mode": use_idle_mode,
23
+ "length_of_audio": length_of_audio,
24
+ "blink_every": blink_every
25
+ }
26
+
27
+ # Make the API request
28
+ response = requests.post(api_url, json=payload)
29
+ result = response.json()
30
+
31
+ # Return the generated video URL
32
+ return result["data"]
33
+
34
+ # Create the Gradio interface
35
+ iface = gr.Interface(
36
+ fn=call_sadtalker_api,
37
+ inputs=[
38
+ gr.Image(type="filepath", label="Source Image"),
39
+ gr.Audio(type="filepath", label="Driven Audio"),
40
+ gr.Radio(["crop", "resize", "full", "extcrop", "extfull"], label="Preprocess Type"),
41
+ gr.Checkbox(label="Still Mode"),
42
+ gr.Checkbox(label="Enhancer"),
43
+ gr.Slider(minimum=1, maximum=10, step=1, label="Batch Size"),
44
+ gr.Radio([256, 512], label="Size of Image"),
45
+ gr.Slider(minimum=0, maximum=45, step=1, label="Pose Style"),
46
+ gr.Radio(["facevid2vid", "pirender"], label="Face Render"),
47
+ gr.Slider(minimum=0, maximum=3, step=0.1, label="Expression Weight"),
48
+ gr.Checkbox(label="Use Reference Video"),
49
+ gr.Video(label="Reference Video"),
50
+ gr.Radio(["pose", "blink", "pose+blink", "all"], label="Reference Info"),
51
+ gr.Checkbox(label="Use Idle Mode"),
52
+ gr.Number(label="Length of Audio"),
53
+ gr.Checkbox(label="Blink Every")
54
+ ],
55
+ outputs=gr.Video(label="Generated Video"),
56
+ title="SadTalker API Client"
57
+ )
58
+
59
+ # Launch the interface
60
+ iface.launch()