from datasets import load_dataset
# Hub identifier of the dataset that supplies the input images and prompts.
dataset_name = "dim/nfs_pix2pix_1920_1080_v6"

# Load with 4 worker processes and keep only the "train" split, which is the
# only split the rest of this script indexes into.
dataset = load_dataset(dataset_name, num_proc=4)["train"]
import os
# Switch the process working directory to the img2img-turbo source tree.
# NOTE(review): presumably this is what lets the bare `pix2pix_turbo` and
# `image_prep` imports below resolve — confirm; it also affects any relative
# paths used later in the process.
os.chdir("/code/img2img-turbo/src")
import argparse
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil
# Checkpoint selection: the model is built from the local pickle at
# `model_path`, while `model_name` is passed as an empty string.
# NOTE(review): presumably an empty name makes Pix2Pix_Turbo fall back to
# `pretrained_path` — confirm against pix2pix_turbo.py.
model_name = ""
model_path = "/code/img2img-turbo/output/pix2pix_turbo/nfs_pix2pix_1736564855/checkpoints/model_16001.pkl"

# When True, the model is cast to half precision here; the same flag is read
# again at inference time to cast the input tensor.
use_fp16 = False

model = Pix2Pix_Turbo(pretrained_path=model_path, pretrained_name=model_name)
model.set_eval()
if use_fp16:
    model.half()
# Preprocessing applied to the PIL input before it is converted to a tensor.
# Step 1: resize with Lanczos resampling (per the torchvision docs, an int
# size scales the shorter edge to 512 and keeps the aspect ratio).
_resize_step = transforms.Resize(
    size=512, interpolation=transforms.InterpolationMode.LANCZOS
)
# Step 2: take a 512x512 crop from the centre of the resized image.
_crop_step = transforms.CenterCrop(size=512)

T = transforms.Compose([_resize_step, _crop_step])
# Pick the conditioning image and the text prompt.
# NOTE(review): the image is read from row 290 but the prompt from row 0 —
# confirm this is intentional (e.g. every row shares the same prompt).
input_image = dataset[290]["input_image"].convert("RGB")
prompt = dataset[0]["edit_prompt"]

with torch.no_grad():
    # Resize/crop the PIL image, convert it to a tensor, add a leading batch
    # dimension and move it to the GPU.
    i_t = T(input_image)
    c_t = F.to_tensor(i_t)[None].to("cuda")
    # Cast the input only when the fp16 flag is set.
    c_t = c_t.half() if use_fp16 else c_t

    output_image = model(c_t, prompt)

    # Take the first batch item, rescale it as x * 0.5 + 0.5 and convert to PIL.
    # NOTE(review): presumably the model emits values in [-1, 1], so this maps
    # them to [0, 1] — confirm against Pix2Pix_Turbo.
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)

# Bare expression: shows the image when this is the last line of a notebook
# cell; it has no effect when run as a plain script.
output_pil