File size: 3,701 Bytes
62f5cce
 
 
 
 
adfea12
fc652c4
62f5cce
 
fc652c4
 
62f5cce
fc652c4
965498b
6b357fe
62f5cce
fc652c4
 
 
62f5cce
fc652c4
 
 
 
 
 
62f5cce
fc652c4
 
 
 
 
 
 
 
 
 
 
 
adfea12
 
fc652c4
 
 
 
 
 
62f5cce
 
 
fc652c4
 
 
62f5cce
 
 
 
a9c0e29
 
 
 
 
 
 
 
 
 
 
 
 
62f5cce
 
 
 
 
 
 
fc652c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62f5cce
 
 
 
 
 
 
 
 
e2e7f03
6b357fe
25423bd
e2e7f03
dd86f00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python

from __future__ import annotations

import os
import random
import tempfile

import gradio as gr
import imageio
import numpy as np
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler



# Markdown header rendered at the top of the demo page via gr.Markdown.
DESCRIPTION = ('# Text-to-Video Synthesis\n\n'
               'Generate a short video from a text prompt with '
               '[damo-vilab/text-to-video-ms-1.7b]'
               '(https://huggingface.co/damo-vilab/text-to-video-ms-1.7b).')

# Frame-count limits for the UI slider, overridable through environment
# variables. DEFAULT_NUM_FRAMES is clamped so it never exceeds MAX_NUM_FRAMES.
MAX_NUM_FRAMES = int(os.getenv('MAX_NUM_FRAMES', '200'))
DEFAULT_NUM_FRAMES = min(MAX_NUM_FRAMES,
                         int(os.getenv('DEFAULT_NUM_FRAMES', '16')))

# Load the text-to-video pipeline with the fp16 weight variant in half precision.
pipe = DiffusionPipeline.from_pretrained('damo-vilab/text-to-video-ms-1.7b',
                                         torch_dtype=torch.float16,
                                         variant='fp16')
# Replace the default scheduler with DPM-Solver multistep, reusing its config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Memory savers (see diffusers docs): keep sub-models on CPU until they are
# needed, and decode the VAE output in slices rather than all at once.
pipe.enable_model_cpu_offload()
pipe.enable_vae_slicing()


def to_video(frames: list[np.ndarray], fps: int) -> str:
    """Encode ``frames`` into an MP4 file and return the file's path.

    The file is created in the system temporary directory with
    ``delete=False``, so it persists after this function returns; cleaning
    it up is the caller's responsibility.

    Args:
        frames: Video frames in playback order, each passed to imageio's
            ``append_data`` (presumably HxWx3 uint8 arrays — TODO confirm
            against the pipeline output).
        fps: Frame rate written into the video container.

    Returns:
        Path of the written ``.mp4`` file.
    """
    # Only reserve a unique file name here and close our handle right away:
    # keeping it open leaks a descriptor per call and stops the FFMPEG writer
    # from reopening the path on Windows.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name
    # The context manager closes (finalizes) the writer even if a frame fails.
    with imageio.get_writer(out_path, format='FFMPEG', fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
    return out_path


def generate(prompt: str, seed: int, num_frames: int,
             num_inference_steps: int, fps: int = 8) -> str:
    """Generate a video for ``prompt`` and return the path of the MP4 file.

    Args:
        prompt: Text description of the desired video.
        seed: RNG seed for the diffusion pipeline; ``-1`` draws a fresh
            random seed in ``[0, 1000000]`` on every call.
        num_frames: Number of frames to generate.
        num_inference_steps: Number of denoising steps to run.
        fps: Frame rate of the encoded video (defaults to 8).

    Returns:
        Path to the encoded ``.mp4`` file produced by ``to_video``.
    """
    # UI/API callers may deliver slider values as floats; torch's
    # Generator.manual_seed and the pipeline's counts require real ints.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, 1000000)
    generator = torch.Generator().manual_seed(seed)
    frames = pipe(prompt,
                  num_inference_steps=int(num_inference_steps),
                  num_frames=int(num_frames),
                  generator=generator).frames
    return to_video(frames, fps)


# Example rows for gr.Examples, ordered like the UI inputs:
# (prompt, seed, number of frames, number of inference steps).
examples = [
    [example_prompt, 0, 16, 25] for example_prompt in (
        'An astronaut riding a horse.',
        'A panda eating bamboo on a rock.',
        'Spiderman is surfing.',
    )
]

# Build the demo UI: a prompt box with a run button, the output video, and an
# accordion of advanced generation options.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id='prompt-container').style(equal_height=True):
                prompt = gr.Text(
                    label='Prompt',
                    show_label=False,
                    max_lines=1,
                    placeholder='Enter your prompt',
                    elem_id='prompt-text-input').style(container=False)
                run_button = gr.Button('Generate video').style(
                    full_width=False)
        result = gr.Video(label='Result', show_label=False, elem_id='gallery')
        with gr.Accordion('Advanced options', open=False):
            seed = gr.Slider(
                label='Seed',
                minimum=-1,
                maximum=1000000,
                step=1,
                value=-1,
                info='If set to -1, a different seed will be used each time.')
            num_frames = gr.Slider(
                label='Number of frames',
                minimum=16,
                maximum=MAX_NUM_FRAMES,
                step=1,
                # Honour the DEFAULT_NUM_FRAMES environment override
                # (already clamped to MAX_NUM_FRAMES).
                value=DEFAULT_NUM_FRAMES,
                info=
                'Note that the content of the video also changes when you change the number of frames.'
            )
            num_inference_steps = gr.Slider(label='Number of inference steps',
                                            minimum=10,
                                            maximum=50,
                                            step=1,
                                            value=25)

    # Order must match the positional parameters of `generate`.
    inputs = [
        prompt,
        seed,
        num_frames,
        num_inference_steps,
    ]
    # Example outputs are pre-computed only when the SYSTEM environment
    # variable is 'spaces'.
    gr.Examples(examples=examples,
                inputs=inputs,
                outputs=result,
                fn=generate,
                cache_examples=os.getenv('SYSTEM') == 'spaces')

    # Pressing Enter in the prompt box and clicking the button both generate.
    prompt.submit(fn=generate, inputs=inputs, outputs=result)
    run_button.click(fn=generate, inputs=inputs, outputs=result)


# Queue requests (at most 15 waiting) with the API route closed, then serve.
demo.queue(api_open=False, max_size=15).launch()