fffiloni committed on
Commit 4d07565 · verified · 1 Parent(s): 95aee8f

Create app.py

Files changed (1): app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
+ import argparse
+ import sys
+ import os
+ import random
+
+ import imageio
+ import torch
+ from diffusers import PNDMScheduler
+ from huggingface_hub import hf_hub_download
+ from torchvision.utils import save_image
+ from diffusers.models import AutoencoderKL
+ from datetime import datetime
+ from typing import List, Union
+ import gradio as gr
+ import numpy as np
+ from gradio.components import Textbox, Video, Image
+ from transformers import T5Tokenizer, T5EncoderModel
+
+ from opensora.models.ae import ae_stride_config, getae, getae_wrapper
+ from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
+ from opensora.models.diffusion.latte.modeling_latte import LatteT2V
+ from opensora.sample.pipeline_videogen import VideoGenPipeline
+ from opensora.serve.gradio_utils import block_css, title_markdown, randomize_seed_fn, set_env, examples, DESCRIPTION
+
+ import spaces
+
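+ # Inference entry point; the @spaces.GPU decorator requests ZeroGPU hardware for the duration of each call.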
+ @spaces.GPU
+ def generate_img(prompt, sample_steps, scale, seed=0, randomize_seed=False, force_images=False):
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     set_env(seed)
+     # Generate a single frame when "force_images" is set; otherwise use the model's configured video length.
+     video_length = transformer_model.config.video_length if not force_images else 1
+     height, width = int(args.version.split('x')[1]), int(args.version.split('x')[2])
+     num_frames = 1 if video_length == 1 else int(args.version.split('x')[0])
+     videos = videogen_pipeline(prompt,
+                                video_length=video_length,
+                                height=height,
+                                width=width,
+                                num_inference_steps=sample_steps,
+                                guidance_scale=scale,
+                                enable_temporal_attentions=not force_images,
+                                num_images_per_prompt=1,
+                                mask_feature=True,
+                                ).video
+
+     torch.cuda.empty_cache()
+     videos = videos[0]
+     # Write the sampled frames to a temporary MP4 for the Gradio Video output.
+     tmp_save_path = 'tmp.mp4'
+     imageio.mimwrite(tmp_save_path, videos, fps=24, quality=9)  # highest quality is 10, lowest is 0
+     display_model_info = f"Video size: {num_frames}×{height}×{width}, \nSampling Step: {sample_steps}, \nGuidance Scale: {scale}"
+     return tmp_save_path, prompt, display_model_info, seed
+
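+ # Script entry point: load all model components once at startup, then serve the Gradio UI.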
+ if __name__ == '__main__':
+     # Hard-coded settings exposed as attributes of a simple namespace object.
+     args = type('args', (), {
+         'ae': 'CausalVAEModel_4x8x8',
+         'force_images': False,
+         'model_path': 'LanguageBind/Open-Sora-Plan-v1.0.0',
+         'text_encoder_name': 'DeepFloyd/t5-v1_1-xxl',
+         'version': '65x512x512'
+     })
+     device = torch.device('cuda:0')
+
+     # Load model:
+     transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16, cache_dir='cache_dir').to(device)
+
+     vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
+     vae.vae.enable_tiling()  # tiled decoding to reduce peak memory
+     image_size = int(args.version.split('x')[1])
+     latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
+     vae.latent_size = latent_size
+     transformer_model.force_images = args.force_images
+     tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
+     text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir",
+                                                   torch_dtype=torch.float16).to(device)
+
+     # set eval mode
+     transformer_model.eval()
+     vae.eval()
+     text_encoder.eval()
+     scheduler = PNDMScheduler()
+     videogen_pipeline = VideoGenPipeline(vae=vae,
+                                          text_encoder=text_encoder,
+                                          tokenizer=tokenizer,
+                                          scheduler=scheduler,
+                                          transformer=transformer_model).to(device=device)
+
+
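+     # Gradio UI: prompt and sampling controls in; rendered MP4, echoed prompt, run info, and seed out.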
+     demo = gr.Interface(
+         fn=generate_img,
+         inputs=[Textbox(label="",
+                         placeholder="Please enter your prompt. \n"),
+                 gr.Slider(
+                     label='Sample Steps',
+                     minimum=1,
+                     maximum=500,
+                     value=50,
+                     step=10
+                 ),
+                 gr.Slider(
+                     label='Guidance Scale',
+                     minimum=0.1,
+                     maximum=30.0,
+                     value=10.0,
+                     step=0.1
+                 ),
+                 gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=203279,
+                     step=1,
+                     value=0,
+                 ),
+                 gr.Checkbox(label="Randomize seed", value=True),
+                 gr.Checkbox(label="Generate image (1 frame video)", value=False),
+                 ],
+         outputs=[Video(label="Vid", width=512, height=512),
+                  Textbox(label="input prompt"),
+                  Textbox(label="model info"),
+                  gr.Slider(label='seed')],
+         title=title_markdown, description=DESCRIPTION, theme=gr.themes.Default(), css=block_css,
+         examples=examples,
+         cache_examples=False
+     )
+     demo.launch()