kiwhansong's picture
finish task 1
5eea811
raw
history blame
1.21 kB
from typing import List
import tempfile
import numpy as np
import torch
from torch import Tensor
from torchvision.io import write_video
from PIL import Image
def export_to_video(tensor: Tensor, fps: int = 10) -> str:
path = tempfile.NamedTemporaryFile(suffix=".mp4").name
write_video(path, (tensor.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8), fps=fps)
return path
def export_to_gif(tensor: Tensor, fps: int = 4) -> str:
path = tempfile.NamedTemporaryFile(suffix=".gif").name
images = (tensor.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
images = [Image.fromarray(image.numpy()) for image in images]
images[0].save(
path,
save_all=True,
append_images=images[1:],
optimize=False,
duration=1000 // fps,
loop=0,
)
return path
def export_images_to_gif(images: List[np.ndarray], fps: int = 4) -> str:
path = tempfile.NamedTemporaryFile(suffix=".gif").name
images = [Image.fromarray(image) for image in images]
images[0].save(
path,
save_all=True,
append_images=images[1:],
optimize=False,
duration=1000 // fps,
loop=0,
)
return path