6Morpheus6 commited on
Commit
d27deb9
·
verified ·
1 Parent(s): 7a2ae76

use original aspect ration, use ffmpeg

Browse files

- pytorch 2.x compatibility
- use original aspect ratio
- use ffmpeg

Files changed (1) hide show
  1. FGT_codes/tool/video_inpainting.py +32 -16
FGT_codes/tool/video_inpainting.py CHANGED
@@ -28,6 +28,7 @@ import copy
28
  import glob
29
  import cv2
30
  import argparse
 
31
 
32
 
33
  def to_tensor(img):
@@ -206,13 +207,9 @@ def initialize_LAFC(args, device):
206
  pkg = import_module("LAFC.models.{}".format(model))
207
  model = pkg.Model(configs)
208
  if not torch.cuda.is_available():
209
- state = torch.load(
210
- checkpoint, map_location=lambda storage, loc: storage
211
- )
212
  else:
213
- state = torch.load(
214
- checkpoint, map_location=lambda storage, loc: storage.cuda(device)
215
- )
216
  model.load_state_dict(state["model_state_dict"])
217
  model = model.to(device)
218
  return model, configs
@@ -230,13 +227,9 @@ def initialize_FGT(args, device):
230
  net = import_module("FGT.models.{}".format(model))
231
  model = net.Model(configs).to(device)
232
  if not torch.cuda.is_available():
233
- state = torch.load(
234
- checkpoint, map_location=lambda storage, loc: storage
235
- )
236
  else:
237
- state = torch.load(
238
- checkpoint, map_location=lambda storage, loc: storage.cuda(device)
239
- )
240
  model.load_state_dict(state["model_state_dict"])
241
  return model, configs
242
 
@@ -427,7 +420,7 @@ def save_results(outdir, comp_frames):
427
  cv2.imwrite(out_path, comp_frames[i][:, :, ::-1])
428
 
429
 
430
- def video_inpainting(args, imgArr, imgMaskArr):
431
  #device = torch.device("cuda:{}".format(args.gpu))
432
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
433
  print(args)
@@ -739,6 +732,9 @@ def video_inpainting(args, imgArr, imgMaskArr):
739
  flows = np2tensor(videoFlowF, near="t")
740
  flows = norm_flows(flows).to(device)
741
 
 
 
 
742
  for f in range(0, video_length, neighbor_stride):
743
  neighbor_ids = [
744
  i
@@ -770,14 +766,34 @@ def video_inpainting(args, imgArr, imgMaskArr):
770
  comp_frames[idx].astype(np.float32) * 0.5
771
  + comp.astype(np.float32) * 0.5
772
  )
 
773
  if args.vis_frame:
774
  save_results(args.outroot, comp_frames)
 
775
  create_dir(args.outroot)
 
776
  for i in range(len(comp_frames)):
777
- comp_frames[i] = comp_frames[i].astype(np.uint8)
778
- imageio.mimwrite(
779
- os.path.join(args.outroot, args.outfilename), comp_frames, fps=args.out_fps, quality=8
 
 
 
 
 
 
 
 
 
 
 
780
  )
 
 
 
 
 
 
781
  print(f"Done, please check your result in {args.outroot} ")
782
 
783
 
 
28
  import glob
29
  import cv2
30
  import argparse
31
+ import ffmpeg
32
 
33
 
34
  def to_tensor(img):
 
207
  pkg = import_module("LAFC.models.{}".format(model))
208
  model = pkg.Model(configs)
209
  if not torch.cuda.is_available():
210
+ state = torch.load(checkpoint, map_location="cpu")
 
 
211
  else:
212
+ state = torch.load(checkpoint, map_location=device)
 
 
213
  model.load_state_dict(state["model_state_dict"])
214
  model = model.to(device)
215
  return model, configs
 
227
  net = import_module("FGT.models.{}".format(model))
228
  model = net.Model(configs).to(device)
229
  if not torch.cuda.is_available():
230
+ state = torch.load(checkpoint, map_location="cpu")
 
 
231
  else:
232
+ state = torch.load(checkpoint, map_location=device)
 
 
233
  model.load_state_dict(state["model_state_dict"])
234
  return model, configs
235
 
 
420
  cv2.imwrite(out_path, comp_frames[i][:, :, ::-1])
421
 
422
 
423
+ def video_inpainting(args, original_frame_list, imgArr, imgMaskArr):
424
  #device = torch.device("cuda:{}".format(args.gpu))
425
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
426
  print(args)
 
732
  flows = np2tensor(videoFlowF, near="t")
733
  flows = norm_flows(flows).to(device)
734
 
735
+ original_frame = original_frame_list[0]
736
+ orig_h, orig_w = original_frame.shape[:2]
737
+
738
  for f in range(0, video_length, neighbor_stride):
739
  neighbor_ids = [
740
  i
 
766
  comp_frames[idx].astype(np.float32) * 0.5
767
  + comp.astype(np.float32) * 0.5
768
  )
769
+
770
  if args.vis_frame:
771
  save_results(args.outroot, comp_frames)
772
+
773
  create_dir(args.outroot)
774
+
775
  for i in range(len(comp_frames)):
776
+ if comp_frames[i] is not None:
777
+ comp_frames[i] = cv2.resize(
778
+ comp_frames[i].astype(np.uint8),
779
+ (orig_w, orig_h),
780
+ interpolation=cv2.INTER_LANCZOS4
781
+ )
782
+ output_path = os.path.join(args.outroot, args.outfilename)
783
+ process = (
784
+ ffmpeg
785
+ .input('pipe:', format='rawvideo', pix_fmt='rgb24',
786
+ s=f'{orig_w}x{orig_h}', framerate=args.out_fps)
787
+ .output(output_path, vcodec='libx264', crf=18, pix_fmt='yuv420p')
788
+ .overwrite_output()
789
+ .run_async(pipe_stdin=True)
790
  )
791
+
792
+ for frame in comp_frames:
793
+ process.stdin.write(frame.tobytes())
794
+ process.stdin.close()
795
+ process.wait()
796
+
797
  print(f"Done, please check your result in {args.outroot} ")
798
 
799