awacke1 committed
Commit 90d8457 · verified · 1 Parent(s): 754171c

Update app.py

Files changed (1)
  1. app.py +62 -27
app.py CHANGED
@@ -11,12 +11,27 @@ from diffusers.utils import load_image, export_to_video
 from PIL import Image
 from huggingface_hub import hf_hub_download

+# ------------------------------------------------------------------------
+# FIX: Adapt to the available hardware (GPU or CPU)
+# ------------------------------------------------------------------------
+
+# Automatically detect the device and select the appropriate data type.
+# This makes the code runnable on machines with or without a dedicated NVIDIA GPU.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if device == "cuda" else torch.float32
+
+# Load the pipeline onto the detected device.
 pipe = StableVideoDiffusionPipeline.from_pretrained(
-    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch_dtype, variant="fp16"
 )
-pipe.to("cuda")
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+pipe.to(device)
+
+# Apply torch.compile for optimization only if on a GPU, as it's most effective there.
+if device == "cuda":
+    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+# ------------------------------------------------------------------------
+
 max_64_bit_int = 2**63 - 1

 # Function to sample video from the input image
@@ -29,7 +44,6 @@ def sample(
     version: str = "svd_xt",
     cond_aug: float = 0.02,
     decoding_t: int = 3,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
-    device: str = "cuda",
     output_folder: str = "outputs",
 ):
     if image.mode == "RGBA":
@@ -42,20 +56,30 @@ def sample(
     os.makedirs(output_folder, exist_ok=True)
     base_count = len(glob(os.path.join(output_folder, "*.mp4")))
     video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
-    frames = pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=0.1, num_frames=25).frames[0]
+
+    frames = pipe(
+        image,
+        decode_chunk_size=decoding_t,
+        generator=generator,
+        motion_bucket_id=motion_bucket_id,
+        noise_aug_strength=0.1,
+        num_frames=25
+    ).frames[0]
+
     export_to_video(frames, video_path, fps=fps_id)
     torch.manual_seed(seed)
     return video_path, seed

-# Function to resize the uploaded image
+# Function to resize the uploaded image to the model's optimal input size
 def resize_image(image, output_size=(1024, 576)):
+    # Resizes and crops the image to a 16:9 aspect ratio.
     target_aspect = output_size[0] / output_size[1]
     image_aspect = image.width / image.height

     if image_aspect > target_aspect:
         new_height = output_size[1]
         new_width = int(new_height * image_aspect)
-        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+        resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
         left = (new_width - output_size[0]) / 2
         top = 0
         right = (new_width + output_size[0]) / 2
@@ -63,7 +87,7 @@ def resize_image(image, output_size=(1024, 576)):
     else:
         new_width = output_size[0]
         new_height = int(new_width / image_aspect)
-        resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+        resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
         left = 0
         top = (new_height - output_size[1]) / 2
         right = output_size[0]
@@ -75,39 +99,50 @@ def resize_image(image, output_size=(1024, 576)):
 # Dynamically load image files from the 'images' directory
 def get_example_images():
     image_dir = "images/"
+    if not os.path.exists(image_dir):
+        os.makedirs(image_dir)
     image_files = glob(os.path.join(image_dir, "*.png")) + glob(os.path.join(image_dir, "*.jpg"))
     return image_files

 # Gradio interface setup
 with gr.Blocks() as demo:
-    gr.Markdown('''# Stable Video Diffusion using Image 2 Video XT
-    #### Research release: generate `4s` vid from a single image at (`25 frames` at `6 fps`).''')
+    gr.Markdown('''# Stable Video Diffusion
+    #### Generate short videos from a single image.''')

     with gr.Row():
         with gr.Column():
-            image = gr.Image(label="Upload your image", type="pil")
-            generate_btn = gr.Button("Generate")
-        video = gr.Video()
+            image = gr.Image(label="Upload Your Image", type="pil")
+            generate_btn = gr.Button("Generate Video", variant="primary")
+        video = gr.Video(label="Generated Video")

-    with gr.Accordion("Advanced options", open=False):
-        seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
-        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-        motion_bucket_id = gr.Slider(label="Motion bucket id", value=127, minimum=1, maximum=255)
-        fps_id = gr.Slider(label="Frames per second", value=6, minimum=5, maximum=30)
+    with gr.Accordion("Advanced Options", open=False):
+        seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=max_64_bit_int, step=1)
+        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+        motion_bucket_id = gr.Slider(label="Motion Bucket ID", info="Controls the amount of motion in the video.", value=127, minimum=1, maximum=255)
+        fps_id = gr.Slider(label="Frames Per Second (FPS)", info="Adjusts the playback speed of the video.", value=7, minimum=5, maximum=30)

+    # When a new image is uploaded, process it immediately
     image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
-    generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")
+
+    # When the generate button is clicked, run the sampling function
+    generate_btn.click(
+        fn=sample,
+        inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
+        outputs=[video, seed],
+        api_name="video"
+    )

     # Dynamically load examples from the filesystem
     example_images = get_example_images()
-    gr.Examples(
-        examples=example_images,
-        inputs=image,
-        outputs=[video, seed],
-        fn=sample,
-        cache_examples=True,
-    )
+    if example_images:
+        gr.Examples(
+            examples=example_images,
+            inputs=image,
+            outputs=[video, seed],
+            fn=lambda img: sample(resize_image(Image.open(img))),  # Resize example images before sampling
+            cache_examples=True,
+        )

 if __name__ == "__main__":
     demo.queue(max_size=20)
-    demo.launch(share=True)
+    demo.launch(share=True)
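
Because the rewritten generate_btn.click(...) call keeps api_name="video", the running app also exposes a callable /video endpoint. Below is a minimal sketch (not part of the commit) of driving it with gradio_client; the Space id is a placeholder, the positional argument order is assumed to mirror the inputs list [image, seed, randomize_seed, motion_bucket_id, fps_id] from the diff, and handle_file requires a recent gradio_client release (older versions accepted plain file paths).

    # Sketch only: call the "video" endpoint of the deployed app programmatically.
    from gradio_client import Client, handle_file

    client = Client("awacke1/<space-name>")  # placeholder Space id or URL
    result = client.predict(
        handle_file("input.png"),  # image
        42,                        # seed
        True,                      # randomize_seed
        127,                       # motion_bucket_id
        7,                         # fps_id
        api_name="/video",
    )
    print(result)  # (generated video file, seed); exact return types depend on the Gradio version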