DeepBeepMeep committed
Commit 3032e80 · 1 Parent(s): 1e8abb5

Added Rife Temporal upsampling and Lanczos spatial upsampling

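In short, this commit adds an optional post-processing pass after generation: RIFE frame interpolation (x2 or x4) followed by per-frame Lanczos resizing. A minimal sketch of that flow as wired up in gradio_server.py below, assuming a generator output tensor of shape (C, T, H, W) in [-1, 1] and the checkpoints already downloaded to ckpts/; the helper name `upsample` and the int() casts are mine, added for clarity:

import os
import torch
from rife.inference import temporal_interpolation    # new module in this commit
from wan.utils.utils import resize_lanczos            # moved into wan.utils.utils by this commit

def upsample(sample, temporal_upsampling="rife2", spatial_upsampling="lanczos2", fps=16):
    # sample: (C, T, H, W) in [-1, 1], as returned by the generator
    exp = {"rife2": 1, "rife4": 2}.get(temporal_upsampling, 0)
    if exp > 0:
        # RIFE interpolation roughly multiplies the frame count by 2**exp, so fps scales accordingly
        sample = temporal_interpolation(os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda")
        fps = fps * 2 ** exp
    if spatial_upsampling:
        scale = 1.5 if spatial_upsampling == "lanczos1.5" else 2
        sample = (sample + 1) / 2                      # resize_lanczos works on [0, 1] frames
        h = int(sample.shape[-2] * scale)              # int() casts are mine (PIL wants integer sizes)
        w = int(sample.shape[-1] * scale)
        frames = [resize_lanczos(sample[:, i], h, w).unsqueeze(1) for i in range(sample.shape[1])]
        sample = torch.cat(frames, dim=1) * 2 - 1      # back to [-1, 1] for cache_video
    return sample, fps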
gradio_server.py CHANGED
@@ -97,6 +97,8 @@ def process_prompt_and_add_tasks(
     image_to_end,
     video_to_continue,
     max_frames,
+    temporal_upsampling,
+    spatial_upsampling,
     RIFLEx_setting,
     slg_switch,
     slg_layers,
@@ -230,6 +232,8 @@ def process_prompt_and_add_tasks(
         "image_to_end" : image_end,
         "video_to_continue" : video_to_continue ,
         "max_frames" : max_frames,
+        "temporal_upsampling" : temporal_upsampling,
+        "spatial_upsampling" : spatial_upsampling,
         "RIFLEx_setting" : RIFLEx_setting,
         "slg_switch" : slg_switch,
         "slg_layers" : slg_layers,
@@ -852,48 +856,63 @@ model_filename = ""
 # compile = "transformer"
 
 def preprocess_loras(sd):
+    if wan_model == None:
+        return sd
+    model_filename = wan_model._model_file_name
+
     first = next(iter(sd), None)
     if first == None:
         return sd
-    if not first.startswith("lora_unet_"):
-        return sd
-    new_sd = {}
-    print("Converting Lora Safetensors format to Lora Diffusers format")
-    alphas = {}
-    repl_list = ["cross_attn", "self_attn", "ffn"]
-    src_list = ["_" + k + "_" for k in repl_list]
-    tgt_list = ["." + k + "." for k in repl_list]
-
-    for k,v in sd.items():
-        k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
-
-        for s,t in zip(src_list, tgt_list):
-            k = k.replace(s,t)
-
-        k = k.replace("lora_up","lora_B")
-        k = k.replace("lora_down","lora_A")
-
-        if "alpha" in k:
-            alphas[k] = v
-        else:
+
+    if first.startswith("lora_unet_"):
+        new_sd = {}
+        print("Converting Lora Safetensors format to Lora Diffusers format")
+        alphas = {}
+        repl_list = ["cross_attn", "self_attn", "ffn"]
+        src_list = ["_" + k + "_" for k in repl_list]
+        tgt_list = ["." + k + "." for k in repl_list]
+
+        for k,v in sd.items():
+            k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
+
+            for s,t in zip(src_list, tgt_list):
+                k = k.replace(s,t)
+
+            k = k.replace("lora_up","lora_B")
+            k = k.replace("lora_down","lora_A")
+
+            if "alpha" in k:
+                alphas[k] = v
+            else:
+                new_sd[k] = v
+
+        new_alphas = {}
+        for k,v in new_sd.items():
+            if "lora_B" in k:
+                dim = v.shape[1]
+            elif "lora_A" in k:
+                dim = v.shape[0]
+            else:
+                continue
+            alpha_key = k[:-len("lora_X.weight")] +"alpha"
+            if alpha_key in alphas:
+                scale = alphas[alpha_key] / dim
+                new_alphas[alpha_key] = scale
+            else:
+                print(f"Lora alpha'{alpha_key}' is missing")
+        new_sd.update(new_alphas)
+        sd = new_sd
+
+    if "text2video" in model_filename:
+        new_sd = {}
+        # convert loras for i2v to t2v
+        for k,v in sd.items():
+            if any(layer in k for layer in ["cross_attn.k_img", "cross_attn.v_img"]):
+                continue
            new_sd[k] = v
+        sd = new_sd
 
-    new_alphas = {}
-    for k,v in new_sd.items():
-        if "lora_B" in k:
-            dim = v.shape[1]
-        elif "lora_A" in k:
-            dim = v.shape[0]
-        else:
-            continue
-        alpha_key = k[:-len("lora_X.weight")] +"alpha"
-        if alpha_key in alphas:
-            scale = alphas[alpha_key] / dim
-            new_alphas[alpha_key] = scale
-        else:
-            print(f"Lora alpha'{alpha_key}' is missing")
-    new_sd.update(new_alphas)
-    return new_sd
+    return sd
 
 
 def download_models(transformer_filename, text_encoder_filename):
@@ -905,7 +924,7 @@ def download_models(transformer_filename, text_encoder_filename):
     from huggingface_hub import hf_hub_download, snapshot_download
     repoId = "DeepBeepMeep/Wan2.1"
     sourceFolderList = ["xlm-roberta-large", "", ]
-    fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
+    fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
     targetRoot = "ckpts/"
     for sourceFolder, files in zip(sourceFolderList,fileList ):
         if len(files)==0:
@@ -1094,6 +1113,7 @@ def load_models(i2v):
        wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P")
    else:
        wan_model, pipe = load_t2v_model(model_filename, "")
+    wan_model._model_file_name = model_filename
    kwargs = { "extraModelsToQuantize": None}
    if profile == 2 or profile == 4:
        kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100, "*" : 1000 }
@@ -1441,6 +1461,8 @@ def generate_video(
    image_to_end,
    video_to_continue,
    max_frames,
+    temporal_upsampling,
+    spatial_upsampling,
    RIFLEx_setting,
    slg_switch,
    slg_layers,
@@ -1693,6 +1715,7 @@ def generate_video(
                cfg_star_switch = cfg_star_switch,
                cfg_zero_step = cfg_zero_step,
            )
+            # samples = torch.empty( (1,2)) #for testing
        except Exception as e:
            if temp_filename!= None and os.path.isfile(temp_filename):
                os.remove(temp_filename)
@@ -1717,8 +1740,6 @@ def generate_video(
                VRAM_crash = True
                break
 
-        _ , exc_value, exc_traceback = sys.exc_info()
-
        state["prompt"] = ""
        if VRAM_crash:
            new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
@@ -1759,17 +1780,61 @@ def generate_video(
            file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
        else:
            file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
-        video_path = os.path.join(save_path, file_name)
+        video_path = os.path.join(save_path, file_name)
+        # if False: # for testing
+        #     torch.save(sample, "ouput.pt")
+        # else:
+        #     sample =torch.load("ouput.pt")
+        exp = 0
+        fps = 16
+
+        if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0:
+            progress_args = [0, status + " - Upsampling"]
+            progress(*progress_args )
+            gen["progress_args"] = progress_args
+
+        if temporal_upsampling == "rife2":
+            exp = 1
+        elif temporal_upsampling == "rife4":
+            exp = 2
+
+        if exp > 0:
+            from rife.inference import temporal_interpolation
+            sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda")
+            fps = fps * 2**exp
+
+        if len(spatial_upsampling) > 0:
+            from wan.utils.utils import resize_lanczos
+            if spatial_upsampling == "lanczos1.5":
+                scale = 1.5
+            else:
+                scale = 2
+            sample = (sample + 1) / 2
+            h, w = sample.shape[-2:]
+            h *= scale
+            w *= scale
+            new_frames =[]
+            for i in range( sample.shape[1] ):
+                frame = sample[:, i]
+                frame = resize_lanczos(frame, h, w)
+                frame = frame.unsqueeze(1)
+                new_frames.append(frame)
+            sample = torch.cat(new_frames, dim=1)
+            new_frames = None
+            sample = sample * 2 - 1
+
+
        cache_video(
            tensor=sample[None],
            save_file=video_path,
-            fps=16,
+            fps=fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
+
 
        configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
-                        loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
+                        loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
 
        metadata_choice = server_config.get("metadata_choice","metadata")
        if metadata_choice == "json":
@@ -2231,7 +2296,7 @@ def switch_advanced(state, new_advanced, lset_name):
 
 
 def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
-                    loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
+                    loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
 
     loras = state["loras"]
     activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@@ -2251,6 +2316,8 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
        "loras_multipliers": loras_mult_choices,
        "tea_cache": tea_cache_setting,
        "tea_cache_start_step_perc": tea_cache_start_step_perc,
+        "temporal_upsampling" : temporal_upsampling,
+        "spatial_upsampling" : spatial_upsampling,
        "RIFLEx_setting": RIFLEx_setting,
        "slg_switch": slg_switch,
        "slg_layers": slg_layers,
@@ -2269,14 +2336,14 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
     return ui_settings
 
 def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
-                loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
+                loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
 
     if state.get("validate_success",0) != 1:
         return
 
     image2video = state["image2video"]
     ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
-                loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
+                loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
 
     defaults_filename = get_settings_file_name(image2video)
 
@@ -2538,6 +2605,32 @@ def generate_video_tab(image2video=False):
            )
            tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
 
+            with gr.Row():
+                gr.Markdown("<B>Upsampling</B>")
+            with gr.Row():
+                temporal_upsampling_choice = gr.Dropdown(
+                    choices=[
+                        ("Disabled", ""),
+                        ("Rife x2 (32 frames/s)", "rife2"),
+                        ("Rife x4 (64 frames/s)", "rife4"),
+                    ],
+                    value=ui_defaults.get("temporal_upsampling", ""),
+                    visible=True,
+                    scale = 1,
+                    label="Temporal Upsampling"
+                )
+                spatial_upsampling_choice = gr.Dropdown(
+                    choices=[
+                        ("Disabled", ""),
+                        ("Lanczos x1.5", "lanczos1.5"),
+                        ("Lanczos x2.0", "lanczos2"),
+                    ],
+                    value=ui_defaults.get("spatial_upsampling", ""),
+                    visible=True,
+                    scale = 1,
+                    label="Spatial Upsampling"
+                )
+
            gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
            RIFLEx_setting = gr.Dropdown(
                choices=[
@@ -2699,7 +2792,7 @@ def generate_video_tab(image2video=False):
        )
        save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
            save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
-                loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
+                loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
                slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
        confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
@@ -2758,6 +2851,8 @@ def generate_video_tab(image2video=False):
            image_to_end,
            video_to_continue,
            max_frames,
+            temporal_upsampling_choice,
+            spatial_upsampling_choice,
            RIFLEx_setting,
            slg_switch,
            slg_layers,
rife/IFNet_HDv3.py ADDED
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from ..model.warplayer import warp
+
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+backwarp_tenGrid = {}
+
+def warp(tenInput, tenFlow, device):
+    k = (str(tenFlow.device), str(tenFlow.size()))
+    if k not in backwarp_tenGrid:
+        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
+            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
+            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+        backwarp_tenGrid[k] = torch.cat(
+            [tenHorizontal, tenVertical], 1).to(device)
+
+    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
+
+    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
+    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
+
+def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=True),
+        nn.PReLU(out_planes)
+    )
+
+def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
+                  padding=padding, dilation=dilation, bias=False),
+        nn.BatchNorm2d(out_planes),
+        nn.PReLU(out_planes)
+    )
+
+class IFBlock(nn.Module):
+    def __init__(self, in_planes, c=64):
+        super(IFBlock, self).__init__()
+        self.conv0 = nn.Sequential(
+            conv(in_planes, c//2, 3, 2, 1),
+            conv(c//2, c, 3, 2, 1),
+        )
+        self.convblock0 = nn.Sequential(
+            conv(c, c),
+            conv(c, c)
+        )
+        self.convblock1 = nn.Sequential(
+            conv(c, c),
+            conv(c, c)
+        )
+        self.convblock2 = nn.Sequential(
+            conv(c, c),
+            conv(c, c)
+        )
+        self.convblock3 = nn.Sequential(
+            conv(c, c),
+            conv(c, c)
+        )
+        self.conv1 = nn.Sequential(
+            nn.ConvTranspose2d(c, c//2, 4, 2, 1),
+            nn.PReLU(c//2),
+            nn.ConvTranspose2d(c//2, 4, 4, 2, 1),
+        )
+        self.conv2 = nn.Sequential(
+            nn.ConvTranspose2d(c, c//2, 4, 2, 1),
+            nn.PReLU(c//2),
+            nn.ConvTranspose2d(c//2, 1, 4, 2, 1),
+        )
+
+    def forward(self, x, flow, scale=1):
+        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
+        feat = self.conv0(torch.cat((x, flow), 1))
+        feat = self.convblock0(feat) + feat
+        feat = self.convblock1(feat) + feat
+        feat = self.convblock2(feat) + feat
+        feat = self.convblock3(feat) + feat
+        flow = self.conv1(feat)
+        mask = self.conv2(feat)
+        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
+        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        return flow, mask
+
+class IFNet(nn.Module):
+    def __init__(self):
+        super(IFNet, self).__init__()
+        self.block0 = IFBlock(7+4, c=90)
+        self.block1 = IFBlock(7+4, c=90)
+        self.block2 = IFBlock(7+4, c=90)
+        self.block_tea = IFBlock(10+4, c=90)
+        # self.contextnet = Contextnet()
+        # self.unet = Unet()
+
+    def forward(self, x, scale_list=[4, 2, 1], training=False):
+        if training == False:
+            channel = x.shape[1] // 2
+            img0 = x[:, :channel]
+            img1 = x[:, channel:]
+        flow_list = []
+        merged = []
+        mask_list = []
+        warped_img0 = img0
+        warped_img1 = img1
+        flow = (x[:, :4]).detach() * 0
+        mask = (x[:, :1]).detach() * 0
+        loss_cons = 0
+        block = [self.block0, self.block1, self.block2]
+        for i in range(3):
+            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
+            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
+            mask = mask + (m0 + (-m1)) / 2
+            mask_list.append(mask)
+            flow_list.append(flow)
+            warped_img0 = warp(img0, flow[:, :2], device= flow.device)
+            warped_img1 = warp(img1, flow[:, 2:4], device= flow.device)
+            merged.append((warped_img0, warped_img1))
+        '''
+        c0 = self.contextnet(img0, flow[:, :2])
+        c1 = self.contextnet(img1, flow[:, 2:4])
+        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
+        res = tmp[:, 1:4] * 2 - 1
+        '''
+        for i in range(3):
+            mask_list[i] = torch.sigmoid(mask_list[i])
+            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
+            # merged[i] = torch.clamp(merged[i] + res, 0, 1)
+        return flow_list, mask_list[2], merged
rife/RIFE_HDv3.py ADDED
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.optim import AdamW
+import torch.optim as optim
+import itertools
+from torch.nn.parallel import DistributedDataParallel as DDP
+from .IFNet_HDv3 import *
+import torch.nn.functional as F
+# from ..model.loss import *
+
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class Model:
+    def __init__(self, local_rank=-1):
+        self.flownet = IFNet()
+        # self.device()
+        # self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
+        # self.epe = EPE()
+        # self.vgg = VGGPerceptualLoss().to(device)
+        # self.sobel = SOBEL()
+        if local_rank != -1:
+            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
+
+    def train(self):
+        self.flownet.train()
+
+    def eval(self):
+        self.flownet.eval()
+
+    def to(self, device):
+        self.flownet.to(device)
+
+    def load_model(self, path, rank=0, device = "cuda"):
+        self.device = device
+        def convert(param):
+            if rank == -1:
+                return {
+                    k.replace("module.", ""): v
+                    for k, v in param.items()
+                    if "module." in k
+                }
+            else:
+                return param
+        self.flownet.load_state_dict(convert(torch.load(path, map_location=device)))
+
+    def save_model(self, path, rank=0):
+        if rank == 0:
+            torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
+
+    def inference(self, img0, img1, scale=1.0):
+        imgs = torch.cat((img0, img1), 1)
+        scale_list = [4/scale, 2/scale, 1/scale]
+        flow, mask, merged = self.flownet(imgs, scale_list)
+        return merged[2]
+
+    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
+        for param_group in self.optimG.param_groups:
+            param_group['lr'] = learning_rate
+        img0 = imgs[:, :3]
+        img1 = imgs[:, 3:]
+        if training:
+            self.train()
+        else:
+            self.eval()
+        scale = [4, 2, 1]
+        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
+        loss_l1 = (merged[2] - gt).abs().mean()
+        loss_smooth = self.sobel(flow[2], flow[2]*0).mean()
+        # loss_vgg = self.vgg(merged[2], gt)
+        if training:
+            self.optimG.zero_grad()
+            loss_G = loss_cons + loss_smooth * 0.1
+            loss_G.backward()
+            self.optimG.step()
+        else:
+            flow_teacher = flow[2]
+        return merged[2], {
+            'mask': mask,
+            'flow': flow[2][:, :2],
+            'loss_l1': loss_l1,
+            'loss_cons': loss_cons,
+            'loss_smooth': loss_smooth,
+            }
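For context, the Model class above is only used for inference here (the training-oriented update method relies on optimizer and loss members that stay commented out). A rough usage sketch, assuming the repo root is on PYTHONPATH, ckpts/flownet.pkl has been downloaded, and two aligned frames of shape (1, 3, H, W) in [0, 1] with H and W multiples of 32; the random frames are placeholders:

import torch
from rife.RIFE_HDv3 import Model

model = Model()
model.load_model("ckpts/flownet.pkl", -1, device="cuda")  # rank -1 strips the "module." prefixes
model.eval()
model.to(device="cuda")

img0 = torch.rand(1, 3, 448, 832, device="cuda")  # placeholder frames in [0, 1]
img1 = torch.rand(1, 3, 448, 832, device="cuda")
with torch.no_grad():
    mid = model.inference(img0, img1)  # estimated frame halfway between img0 and img1
print(mid.shape)  # (1, 3, 448, 832)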
rife/inference.py ADDED
@@ -0,0 +1,119 @@
+import os
+import torch
+from torch.nn import functional as F
+# from .model.pytorch_msssim import ssim_matlab
+from .ssim import ssim_matlab
+
+from .RIFE_HDv3 import Model
+
+def get_frame(frames, frame_no):
+    if frame_no >= frames.shape[1]:
+        return None
+    frame = (frames[:, frame_no] + 1) /2
+    frame = frame.clip(0., 1.)
+    return frame
+
+def add_frame(frames, frame, h, w):
+    frame = (frame * 2) - 1
+    frame = frame.clip(-1., 1.)
+    frame = frame.squeeze(0)
+    frame = frame[:, :h, :w]
+    frame = frame.unsqueeze(1)
+    frames.append(frame.cpu())
+
+def process_frames(model, device, frames, exp):
+    pos = 0
+    output_frames = []
+
+    lastframe = get_frame(frames, 0)
+    _, h, w = lastframe.shape
+    scale = 1
+    fp16 = False
+
+    def make_inference(I0, I1, n):
+        middle = model.inference(I0, I1, scale)
+        if n == 1:
+            return [middle]
+        first_half = make_inference(I0, middle, n=n//2)
+        second_half = make_inference(middle, I1, n=n//2)
+        if n%2:
+            return [*first_half, middle, *second_half]
+        else:
+            return [*first_half, *second_half]
+
+    tmp = max(32, int(32 / scale))
+    ph = ((h - 1) // tmp + 1) * tmp
+    pw = ((w - 1) // tmp + 1) * tmp
+    padding = (0, pw - w, 0, ph - h)
+
+    def pad_image(img):
+        if(fp16):
+            return F.pad(img, padding).half()
+        else:
+            return F.pad(img, padding)
+
+    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
+    I1 = pad_image(I1)
+    temp = None # save lastframe when processing static frame
+
+    while True:
+        if temp is not None:
+            frame = temp
+            temp = None
+        else:
+            pos += 1
+            frame = get_frame(frames, pos)
+            if frame is None:
+                break
+        I0 = I1
+        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
+        I1 = pad_image(I1)
+        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
+        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+
+        break_flag = False
+        if ssim > 0.996:
+            pos += 1
+            frame = get_frame(frames, pos)
+            if frame is None:
+                break_flag = True
+                frame = lastframe
+            else:
+                temp = frame
+            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
+            I1 = pad_image(I1)
+            I1 = model.inference(I0, I1, scale)
+            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
+            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+            frame = I1[0]
+
+        if ssim < 0.2:
+            output = []
+            for _ in range((2 ** exp) - 1):
+                output.append(I0)
+        else:
+            output = make_inference(I0, I1, 2**exp-1) if exp else []
+
+        add_frame(output_frames, lastframe, h, w)
+        for mid in output:
+            add_frame(output_frames, mid, h, w)
+        lastframe = frame
+        if break_flag:
+            break
+
+    add_frame(output_frames, lastframe, h, w)
+    return torch.cat( output_frames, dim=1)
+
+def temporal_interpolation(model_path, frames, exp, device ="cuda"):
+
+    model = Model()
+    model.load_model(model_path, -1, device=device)
+
+    model.eval()
+    model.to(device=device)
+
+    with torch.no_grad():
+        output = process_frames(model, device, frames, exp)
+
+    return output
rife/ssim.py ADDED
@@ -0,0 +1,200 @@
+import torch
+import torch.nn.functional as F
+from math import exp
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+    return gauss/gauss.sum()
+
+
+def create_window(window_size, channel=1):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
+    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+    return window
+
+def create_window_3d(window_size, channel=1):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t())
+    _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
+    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+    return window
+
+
+def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+    if val_range is None:
+        if torch.max(img1) > 128:
+            max_val = 255
+        else:
+            max_val = 1
+
+        if torch.min(img1) < -0.5:
+            min_val = -1
+        else:
+            min_val = 0
+        L = max_val - min_val
+    else:
+        L = val_range
+
+    padd = 0
+    (_, channel, height, width) = img1.size()
+    if window is None:
+        real_size = min(window_size, height, width)
+        window = create_window(real_size, channel=channel).to(img1.device)
+
+    # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
+    # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
+    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
+
+    C1 = (0.01 * L) ** 2
+    C2 = (0.03 * L) ** 2
+
+    v1 = 2.0 * sigma12 + C2
+    v2 = sigma1_sq + sigma2_sq + C2
+    cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+    if size_average:
+        ret = ssim_map.mean()
+    else:
+        ret = ssim_map.mean(1).mean(1).mean(1)
+
+    if full:
+        return ret, cs
+    return ret
+
+
+def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
+    if val_range is None:
+        if torch.max(img1) > 128:
+            max_val = 255
+        else:
+            max_val = 1
+
+        if torch.min(img1) < -0.5:
+            min_val = -1
+        else:
+            min_val = 0
+        L = max_val - min_val
+    else:
+        L = val_range
+
+    padd = 0
+    (_, _, height, width) = img1.size()
+    if window is None:
+        real_size = min(window_size, height, width)
+        window = create_window_3d(real_size, channel=1).to(img1.device)
+        # Channel is set to 1 since we consider color images as volumetric images
+
+    img1 = img1.unsqueeze(1)
+    img2 = img2.unsqueeze(1)
+
+    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
+    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
+    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
+
+    C1 = (0.01 * L) ** 2
+    C2 = (0.03 * L) ** 2
+
+    v1 = 2.0 * sigma12 + C2
+    v2 = sigma1_sq + sigma2_sq + C2
+    cs = torch.mean(v1 / v2)  # contrast sensitivity
+
+    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
+
+    if size_average:
+        ret = ssim_map.mean()
+    else:
+        ret = ssim_map.mean(1).mean(1).mean(1)
+
+    if full:
+        return ret, cs
+    return ret
+
+
+def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
+    device = img1.device
+    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
+    levels = weights.size()[0]
+    mssim = []
+    mcs = []
+    for _ in range(levels):
+        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+        mssim.append(sim)
+        mcs.append(cs)
+
+        img1 = F.avg_pool2d(img1, (2, 2))
+        img2 = F.avg_pool2d(img2, (2, 2))
+
+    mssim = torch.stack(mssim)
+    mcs = torch.stack(mcs)
+
+    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
+    if normalize:
+        mssim = (mssim + 1) / 2
+        mcs = (mcs + 1) / 2
+
+    pow1 = mcs ** weights
+    pow2 = mssim ** weights
+    # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
+    output = torch.prod(pow1[:-1] * pow2[-1])
+    return output
+
+
+# Classes to re-use window
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True, val_range=None):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.val_range = val_range
+
+        # Assume 3 channel for SSIM
+        self.channel = 3
+        self.window = create_window(window_size, channel=self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.dtype == img1.dtype:
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
+            self.window = window
+            self.channel = channel
+
+        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+        dssim = (1 - _ssim) / 2
+        return dssim
+
+class MSSSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True, channel=3):
+        super(MSSSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = channel
+
+    def forward(self, img1, img2):
+        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
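ssim_matlab is what rife/inference.py (above) uses on downscaled 32x32 crops to detect still frames (ssim > 0.996) and scene cuts (ssim < 0.2) before interpolating. A quick standalone sketch, with random tensors standing in for frames; exact values will vary:

import torch
from rife.ssim import ssim_matlab

# two small RGB crops in [0, 1], shaped (N, C, H, W) as in rife/inference.py
a = torch.rand(1, 3, 32, 32)
b = a.clone()
c = torch.rand(1, 3, 32, 32)

print(float(ssim_matlab(a, b)))  # close to 1.0: near-identical frames (treated as static above 0.996)
print(float(ssim_matlab(a, c)))  # much lower: dissimilar frames (treated as a cut below 0.2)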
wan/image2video.py CHANGED
@@ -25,8 +25,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                get_sampling_sigmas, retrieve_timesteps)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from wan.modules.posemb_layers import get_rotary_pos_embed
-
-from PIL import Image
+from wan.utils.utils import resize_lanczos
 
 def optimized_scale(positive_flat, negative_flat):
 
@@ -41,10 +40,6 @@ def optimized_scale(positive_flat, negative_flat):
 
     return st_star
 
-def resize_lanczos(img, h, w):
-    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
-    img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
-    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
 
 
 class WanI2V:
@@ -285,21 +280,6 @@ class WanI2V:
        self.clip.model.cpu()
 
        from mmgp import offload
-
-
-        # img_interpolated.save('aaa.png')
-
-        # img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
-
-        # img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
-        # img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
-        # img_interpolated = img_interpolated.clamp(-1, 1)
-        # img_interpolated = (img_interpolated + 1)/2
-        # img_interpolated = (img_interpolated*255).type(torch.uint8)
-        # img_interpolated = img_interpolated.cpu().numpy()
-        # xxx = Image.fromarray(img_interpolated, 'RGB')
-        # xxx.save('my.png')
-
        offload.last_offload_obj.unload_all()
        if any_end_frame:
            mean2 = 0
wan/utils/utils.py CHANGED
@@ -7,9 +7,16 @@ import os.path as osp
 import imageio
 import torch
 import torchvision
+from PIL import Image
+import numpy as np
 
 __all__ = ['cache_video', 'cache_image', 'str2bool']
 
+def resize_lanczos(img, h, w):
+    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
+    img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
+    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+
 
 def rand_name(length=8, suffix=''):
     name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
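resize_lanczos now lives in wan.utils.utils so that both the i2v pipeline and the new spatial upsampling path in gradio_server.py can share it. It round-trips through PIL, so it expects a single (C, H, W) frame in [0, 1] and integer target sizes; a short usage sketch with a placeholder frame:

import torch
from wan.utils.utils import resize_lanczos

frame = torch.rand(3, 480, 832)                 # (C, H, W) in [0, 1]
up = resize_lanczos(frame, 480 * 2, 832 * 2)    # Lanczos x2, computed on CPU via PIL
print(up.shape)                                 # torch.Size([3, 960, 1664])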