from datasets import load_dataset

# Dataset loaded with `datasets.load_dataset`; only its "train" split is kept.
# Alternative revision: "dim/nfs_pix2pix_1920_1080_v5".
dataset_name = "dim/nfs_pix2pix_1920_1080_v6"
dataset = load_dataset(dataset_name, num_proc=4)["train"]
import os

# pix2pix_turbo and image_prep (below) are local modules imported by bare
# name, so switch into the img2img-turbo source directory before importing them.
# NOTE(review): this relies on the current directory being on sys.path
# (true for an interactive/notebook session) — TODO confirm for plain script runs.
# The chdir also changes how any relative paths used afterwards are resolved.
os.chdir("/code/img2img-turbo/src")
import argparse  # NOTE(review): not used in the visible code
import numpy as np  # NOTE(review): not used in the visible code
from PIL import Image  # NOTE(review): not used in the visible code
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil  # NOTE(review): not used in the visible code

# --- Model configuration ----------------------------------------------------
# pretrained_name is left empty and a local checkpoint is given through
# pretrained_path instead (how Pix2Pix_Turbo resolves the two is defined in
# pix2pix_turbo, not here).
model_name = ""
model_path = (
    "/code/img2img-turbo/output/pix2pix_turbo/"
    "nfs_pix2pix_1736564855/checkpoints/model_16001.pkl"
)
use_fp16 = False  # when True, the model is cast with .half()

# Build the model from the checkpoint, call set_eval() (presumably switching it
# to inference mode), then optionally cast it to half precision.
model = Pix2Pix_Turbo(
    pretrained_name=model_name,
    pretrained_path=model_path,
)
model.set_eval()
if use_fp16:
    model.half()

# Input preprocessing: resize with Lanczos resampling (an int size makes
# torchvision match the shorter edge to it), then cut out the central
# square of the same size.
_TARGET_SIZE = 512
T = transforms.Compose(
    [
        transforms.Resize(
            _TARGET_SIZE, interpolation=transforms.InterpolationMode.LANCZOS
        ),
        transforms.CenterCrop(_TARGET_SIZE),
    ]
)
# Run a single dataset sample through the model. The input image and its edit
# prompt are read from the SAME row so that they belong together.
sample_idx = 290
sample = dataset[sample_idx]
input_image = sample["input_image"].convert("RGB")
prompt = sample["edit_prompt"]

# Preprocessing needs no autograd context.
i_t = T(input_image)
with torch.no_grad():
    # to_tensor gives a float tensor scaled to [0, 1]; unsqueeze(0) adds the
    # batch dimension before moving it to the GPU.
    c_t = F.to_tensor(i_t).unsqueeze(0).cuda()
    if use_fp16:
        c_t = c_t.half()
    output_image = model(c_t, prompt)

    # The `* 0.5 + 0.5` mapping assumes the model output lies in [-1, 1]
    # (TODO confirm in pix2pix_turbo). Cast to float32 (the model may run in
    # fp16) and clamp: ToPILImage's float -> uint8 conversion does not
    # saturate out-of-range values.
    out = output_image[0].float().cpu() * 0.5 + 0.5
    output_pil = transforms.ToPILImage()(out.clamp(0, 1))

# Bare expression so a notebook cell displays the resulting image.
output_pil
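# ---------------------------------------------------------------------------
# The lines below were copied from the Hugging Face model page together with
# the script; they are kept as comments so the file still parses.
#
# Downloads last month: -
# Downloads are not tracked for this model. (How to track)
# Inference API
# Unable to determine this model's library. Check the docs.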