Spaces:
Running
on
Zero
Running
on
Zero
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import shutil
import subprocess
import sys
import time
from typing import List

import torch
from cog import BasePredictor, Input, Path
from PIL import Image

from image_datasets.canny_dataset import canny_processor, c_crop
from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
# Directory predict() clears before a run and then scans for main.py's
# "controlnet_result_*" output file.
OUTPUT_DIR = "controlnet_results"
# Local directory the T5 / CLIP weight archives are downloaded into; setup()
# checks for the "models--..." subdirectories under it.
MODEL_CACHE = "checkpoints"
# ControlNet (Canny) weights, fetched by setup() when not already on disk.
CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
# Tar archives of the T5 and CLIP caches, fetched into MODEL_CACHE by setup().
T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
# Hugging Face token used by Predictor.setup() for `huggingface-cli login`.
# Read from the environment so a real secret never has to be committed to
# source; "hf_..." is only the original placeholder, kept as the fallback.
HF_TOKEN = os.environ.get("HF_TOKEN", "hf_...")
def download_weights(url, dest):
    """Fetch *url* into *dest* using the ``pget`` CLI and report the elapsed time.

    The ``-xf`` flags are passed to pget unchanged (presumably "extract the
    archive" + "force" — confirm against the pget docs). ``check_call`` raises
    ``subprocess.CalledProcessError`` if pget exits non-zero.
    """
    began = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    command = ["pget", "-xf", url, dest]
    subprocess.check_call(command, close_fds=False)
    elapsed = time.time() - began
    print("downloading took: ", elapsed)
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    """Load the FLUX components for model *name*.

    Returns the tuple ``(model, ae, t5, clip, controlnet)``. With *offload*
    set, the flow model and autoencoder are loaded on the CPU; the text
    encoders and the ControlNet always go to *device*, and the ControlNet is
    cast to bfloat16.
    """
    # max_length handed to load_t5: 256 for the schnell variant, 512 otherwise.
    t5_max_length = 512
    if is_schnell:
        t5_max_length = 256
    weights_device = "cpu" if offload else device

    t5 = load_t5(device, max_length=t5_max_length)
    clip = load_clip(device)
    model = load_flow_model(name, device=weights_device)
    ae = load_ae(name, device=weights_device)
    controlnet = load_controlnet(name, device).to(torch.bfloat16)
    return model, ae, t5, clip, controlnet
class Predictor(BasePredictor):
    """Cog predictor for FLUX.1-dev guided by a Canny ControlNet.

    ``setup`` makes sure the ControlNet, T5 and CLIP weights are on disk and
    loads the models; ``predict`` generates an image by running ``main.py`` in
    a subprocess and returns the result file it leaves in ``OUTPUT_DIR``.
    """

    # Local filename of the ControlNet weights: setup() downloads
    # CONTROLNET_URL to this path and predict() passes it to main.py.
    controlnet_checkpoint = "controlnet.safetensors"
    # Longest allowed input side in pixels; larger inputs are downscaled.
    max_image_size = 1536
    # Where a downscaled input image is written before being handed to main.py.
    resized_image_path = "/tmp/resized_image.png"

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient.

        Logs into Hugging Face (best effort), downloads any missing weights,
        loads the FLUX models onto CUDA and applies the ControlNet checkpoint.

        Raises:
            subprocess.CalledProcessError: if a weight download fails.
        """
        t1 = time.time()
        self._hf_login()
        name = "flux-dev"
        self.offload = False
        checkpoint = self.controlnet_checkpoint
        self._ensure_weights(checkpoint)

        self.is_schnell = False
        self.torch_device = torch.device("cuda")
        # NOTE(review): predict() runs main.py in a separate process, which
        # loads its own copy of the models; nothing loaded here is referenced
        # by predict(). Confirm whether this in-process load (and the GPU
        # memory it holds) is still needed.
        model, ae, t5, clip, controlnet = get_models(
            name,
            device=self.torch_device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.ae = ae
        self.t5 = t5
        self.clip = clip
        self.controlnet = controlnet
        self.model = model.to(self.torch_device)

        if checkpoint.endswith(".safetensors"):
            state_dict = load_safetensors(checkpoint)
        else:
            state_dict = torch.load(checkpoint, map_location="cpu")
        # strict=False tolerates key mismatches; report them (torch's
        # load_state_dict returns missing/unexpected keys) instead of silently
        # accepting a checkpoint that does not fit the model.
        incompatible = controlnet.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                "ControlNet checkpoint key mismatch: "
                f"{len(incompatible.missing_keys)} missing, "
                f"{len(incompatible.unexpected_keys)} unexpected"
            )

        t2 = time.time()
        print(f"Setup time: {t2 - t1}")

    @staticmethod
    def _hf_login() -> None:
        """Best-effort ``huggingface-cli login`` with HF_TOKEN; failures are reported, never raised."""
        if not HF_TOKEN:
            print("HF_TOKEN is empty; skipping Hugging Face login")
            return
        try:
            # argv list with no shell, so the token is never parsed by a shell.
            # NOTE(review): the token is still visible in the process list.
            completed = subprocess.run(
                ["huggingface-cli", "login", "--token", HF_TOKEN], check=False
            )
        except FileNotFoundError:
            print("huggingface-cli not found; skipping Hugging Face login")
            return
        if completed.returncode != 0:
            print(f"huggingface-cli login exited with code {completed.returncode}")

    @staticmethod
    def _ensure_weights(checkpoint: str) -> None:
        """Download the ControlNet, T5 and CLIP weights when they are not on disk yet."""
        print("Checking ControlNet weights")
        if not os.path.exists(checkpoint):
            # Download to a side file and rename only on success, so an
            # interrupted download is not mistaken for a complete checkpoint.
            partial = checkpoint + ".part"
            subprocess.check_call(["wget", "-O", partial, CONTROLNET_URL])
            os.replace(partial, checkpoint)
        print("Checking T5 weights")
        if not os.path.exists(os.path.join(MODEL_CACHE, "models--google--t5-v1_1-xxl")):
            download_weights(T5_URL, MODEL_CACHE)
        print("Checking CLIP weights")
        if not os.path.exists(os.path.join(MODEL_CACHE, "models--openai--clip-vit-large-patch14")):
            download_weights(CLIP_URL, MODEL_CACHE)

    def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
        """Crop with ``c_crop``, resize to (width, height), and return ``canny_processor``'s output.

        The source file is closed before returning.
        """
        with Image.open(image_path) as image:
            cropped = c_crop(image)
            resized = cropped.resize((width, height))
            return canny_processor(resized)

    def predict(
        self,
        prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
        image: Path = Input(description="Input image", default=None),
        num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
        cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
        seed: int = Input(description="Random seed", default=None),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Runs ``main.py`` with the (possibly downscaled) input image as the
        Canny control image and returns a one-element list with the generated
        file.

        Raises:
            ValueError: if no input image is given.
            subprocess.CalledProcessError: if main.py exits non-zero.
            RuntimeError: if main.py leaves no ``controlnet_result_*`` file.
        """
        if image is None:
            raise ValueError("An input image is required")
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Clear results of previous runs so the lookup below finds this run's file.
        output_dir = OUTPUT_DIR
        shutil.rmtree(output_dir, ignore_errors=True)

        input_image, width, height = self._prepare_input_image(str(image))
        command = [
            # Same interpreter as this process, so main.py sees the same packages.
            sys.executable, "main.py",
            "--local_path", self.controlnet_checkpoint,
            "--image", input_image,
            "--use_controlnet",
            "--control_type", "canny",
            "--prompt", prompt,
            "--width", str(width),
            "--height", str(height),
            "--num_steps", str(num_inference_steps),
            "--guidance", str(cfg),
            "--seed", str(seed),
        ]
        subprocess.check_call(command, close_fds=False)
        return [self._find_result(output_dir)]

    def _prepare_input_image(self, image_path: str):
        """Return ``(path, width, height)`` of the image to feed to main.py.

        Images whose longer side exceeds ``max_image_size`` are downscaled with
        LANCZOS (aspect ratio kept, each side at least 1 px) and saved to
        ``resized_image_path``; smaller images are passed through unchanged.
        """
        with Image.open(image_path) as img:
            width, height = img.size
            scale = min(self.max_image_size / width, self.max_image_size / height, 1)
            if scale >= 1:
                return image_path, width, height
            # max(1, ...) keeps extreme aspect ratios from collapsing a side to 0.
            width = max(1, int(width * scale))
            height = max(1, int(height * scale))
            print(f"Scaling image down to {width}x{height}")
            resized = img.resize((width, height), resample=Image.Resampling.LANCZOS)
        resized.save(self.resized_image_path)
        return self.resized_image_path, width, height

    @staticmethod
    def _find_result(output_dir: str) -> Path:
        """Return the first (by name) ``controlnet_result_*`` file in *output_dir*.

        Raises:
            RuntimeError: if the directory is missing or holds no such file.
        """
        if os.path.isdir(output_dir):
            for file_name in sorted(os.listdir(output_dir)):
                if file_name.startswith("controlnet_result_"):
                    return Path(os.path.join(output_dir, file_name))
        raise RuntimeError(f"main.py produced no controlnet_result_* file in {output_dir!r}")