Uhhy committed
Commit 41c7921 · verified · 1 parent: 2107997

Update app.py

Files changed (1)
  1. app.py +6 -27
app.py CHANGED
@@ -11,6 +11,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
 from transformers import pipeline
 from huggingface_hub import snapshot_download
+
 from examples.get_examples import get_examples
 from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
 from src.utils.preprocess import CropAndExtract
@@ -37,23 +38,15 @@ def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
                         frame_rate).export(wav_filename, format="wav")
 
 def get_pose_style_from_audio(audio_path):
-    """Determines pose style based on audio emotion using a pre-trained model."""
-    # Load the pre-trained emotion recognition model
     emotion_recognizer = pipeline("sentiment-analysis")
-
-    # Analyze the audio emotion
     results = emotion_recognizer(audio_path)
     emotion = results[0]["label"]
-
-    # Map emotion to pose style (you can adjust these mappings)
     pose_style_mapping = {
-        "POSITIVE": 15,  # Happy
-        "NEGATIVE": 35,  # Sad
-        "NEUTRAL": 0,  # Normal
-        # Add more emotion mappings as needed
+        "POSITIVE": 15,
+        "NEGATIVE": 35,
+        "NEUTRAL": 0,
     }
-
-    return pose_style_mapping.get(emotion, 0)  # Default to neutral pose if unknown
+    return pose_style_mapping.get(emotion, 0)
 
 @spaces.GPU(duration=0)
 def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop', still_mode: bool = False,
@@ -61,7 +54,6 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
                    facerender: str = 'facevid2vid', exp_scale: float = 1.0, use_ref_video: bool = False,
                    ref_video: str = None, ref_info: str = None, use_idle_mode: bool = False,
                    length_of_audio: int = 0, use_blink: bool = True, result_dir: str = './results/') -> str:
-    # Initialize models and paths
     sadtalker_paths = init_path(
         checkpoint_path, config_path, size, False, preprocess)
     audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
@@ -69,18 +61,15 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
     animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) if facerender == 'facevid2vid' and device != 'mps' \
         else AnimateFromCoeff_PIRender(sadtalker_paths, device)
 
-    # Create directories for saving results
     time_tag = str(uuid.uuid4())
     save_dir = os.path.join(result_dir, time_tag)
     os.makedirs(save_dir, exist_ok=True)
     input_dir = os.path.join(save_dir, 'input')
     os.makedirs(input_dir, exist_ok=True)
 
-    # Process source image
     pic_path = os.path.join(input_dir, os.path.basename(source_image))
     shutil.move(source_image, input_dir)
 
-    # Process driven audio
     if driven_audio and os.path.isfile(driven_audio):
         audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
         if '.mp3' in audio_path:
@@ -96,7 +85,6 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
     else:
         assert use_ref_video and ref_info == 'all'
 
-    # Process reference video
     if use_ref_video and ref_info == 'all':
         ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
         audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
@@ -109,7 +97,6 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
     else:
         ref_video_coeff_path = None
 
-    # Preprocess source image
     first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
     os.makedirs(first_frame_dir, exist_ok=True)
     first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
@@ -117,7 +104,6 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
     if first_coeff_path is None:
         raise AttributeError("No face is detected")
 
-    # Determine reference coefficients
     ref_pose_coeff_path, ref_eyeblink_coeff_path = None, None
     if use_ref_video:
         if ref_info == 'pose':
@@ -129,20 +115,17 @@ def generate_video(source_image: str, driven_audio: str, preprocess: str = 'crop
     else:
         ref_pose_coeff_path = ref_eyeblink_coeff_path = None
 
-    # Generate coefficients from audio or reference video
     if use_ref_video and ref_info == 'all':
         coeff_path = ref_video_coeff_path
     else:
         batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
                          still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
 
-        # Get pose style from audio
         pose_style = get_pose_style_from_audio(audio_path)
 
         coeff_path = audio_to_coeff.generate(
             batch, save_dir, pose_style, ref_pose_coeff_path)
 
-    # Generate video from coefficients
     data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode,
                                preprocess=preprocess, size=size, expression_scale=exp_scale, facemodel=facerender)
     return_path = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None,
@@ -160,7 +143,6 @@ async def generate_video_api(source_image: UploadFile = File(...), driven_audio:
                              use_ref_video: bool = Form(False), ref_video: UploadFile = File(None),
                              ref_info: str = Form(None), use_idle_mode: bool = Form(False),
                              length_of_audio: int = Form(0), use_blink: bool = Form(True), result_dir: str = Form('./results/')):
-    # Save the uploaded files temporarily
    temp_source_image_path = f"temp/{source_image.filename}"
    os.makedirs("temp", exist_ok=True)
    with open(temp_source_image_path, "wb") as buffer:
@@ -180,7 +162,6 @@ async def generate_video_api(source_image: UploadFile = File(...), driven_audio:
    else:
        temp_ref_video_path = None
 
-    # Generate the video
    video_path = generate_video(
        source_image=temp_source_image_path,
        driven_audio=temp_driven_audio_path,
@@ -200,10 +181,8 @@ async def generate_video_api(source_image: UploadFile = File(...), driven_audio:
        result_dir=result_dir
    )
 
-    # Clean up temporary files
    shutil.rmtree("temp")
 
-    # Return the generated video file
    return FileResponse(video_path)
 
 
@@ -314,4 +293,4 @@ html = """
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
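A note on get_pose_style_from_audio as it stands after this commit: pipeline("sentiment-analysis") is a text-classification pipeline, so passing audio_path hands it the path string rather than the audio itself, and its default checkpoint emits only POSITIVE/NEGATIVE labels, so the NEUTRAL entry is never reached. A minimal sketch of driving the same mapping from the audio, assuming the superb/hubert-large-superb-er speech-emotion checkpoint and its label names (both assumptions, not part of this commit):

# Sketch, not the committed implementation: classify the audio, not the path text.
from transformers import pipeline

def get_pose_style_from_audio(audio_path: str) -> int:
    # "audio-classification" pipelines accept a path to a sound file directly;
    # the checkpoint and its labels ("hap", "sad", "neu", "ang") are assumed here.
    emotion_recognizer = pipeline("audio-classification",
                                  model="superb/hubert-large-superb-er")
    emotion = emotion_recognizer(audio_path)[0]["label"]
    pose_style_mapping = {"hap": 15, "sad": 35, "neu": 0}
    return pose_style_mapping.get(emotion, 0)  # default to neutral pose

Building the pipeline inside the function reloads the model on every call, as in the committed code; in practice it would be hoisted to module scope.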
 
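On the last hunk: 7860 is the port Hugging Face Spaces serves by default, so moving to 8000 with reload=True reads like a local-development setting; note that uvicorn only honors reload=True when the app is passed as an import string (e.g. "app:app"), so as written the flag is likely ignored with a warning. For local testing, a request against the upload endpoint could look like the sketch below; the route path "/generate_video/" and the input filenames are placeholders, since the @app.post decorator is outside this diff, and the field names follow the handler's signature.

# Sketch only: multipart POST against the generate_video_api handler shown above.
import requests

with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
    resp = requests.post(
        "http://localhost:8000/generate_video/",          # route path assumed
        files={"source_image": img, "driven_audio": wav}, # other form fields keep their defaults
    )
resp.raise_for_status()
with open("result.mp4", "wb") as out:
    out.write(resp.content)  # the handler returns the rendered video as a FileResponse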