Spaces:
Runtime error
Runtime error
# Prediction interface for Cog ⚙️
# https://cog.run/python
import os
import shutil
import subprocess
import time
from typing import List

import torch
from PIL import Image

from cog import BasePredictor, Input, Path
from image_datasets.canny_dataset import canny_processor, c_crop
from src.flux.util import load_ae, load_clip, load_t5, load_flow_model, load_controlnet, load_safetensors
# Directory where main.py writes its generated images (read back in predict()).
OUTPUT_DIR = "controlnet_results"
# Local cache directory for downloaded encoder weights.
MODEL_CACHE = "checkpoints"
# Pretrained canny ControlNet weights for FLUX.1.
CONTROLNET_URL = "https://huggingface.co/XLabs-AI/flux-controlnet-canny/resolve/main/controlnet.safetensors"
# Tarballs of the T5 and CLIP text-encoder caches (extracted by pget -xf).
T5_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/t5-cache.tar"
CLIP_URL = "https://weights.replicate.delivery/default/black-forest-labs/FLUX.1-dev/clip-cache.tar"
# NOTE(review): hard-coded credential placeholder — inject a real token via an
# environment variable or secret store; never commit one to source control.
HF_TOKEN = "hf_..."  # Your HuggingFace token
def download_weights(url, dest):
    """Download *url* with pget and extract it (-xf) into *dest*, logging timing."""
    started = time.time()
    print(f"downloading url:  {url}")
    print(f"downloading to:  {dest}")
    # pget fetches and untars in one step; close_fds=False matches Cog's runner.
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print(f"downloading took:  {time.time() - started}")
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    """Build the FLUX model stack: flow model, autoencoder, T5, CLIP, ControlNet.

    When ``offload`` is set, the flow model and autoencoder are created on CPU
    so they can be moved to the GPU on demand.
    """
    # schnell variants use a shorter T5 context window.
    max_len = 256 if is_schnell else 512
    text_encoder = load_t5(device, max_length=max_len)
    clip_encoder = load_clip(device)
    heavy_device = "cpu" if offload else device
    flow = load_flow_model(name, device=heavy_device)
    autoencoder = load_ae(name, device=heavy_device)
    cnet = load_controlnet(name, device).to(torch.bfloat16)
    return flow, autoencoder, text_encoder, clip_encoder, cnet
class Predictor(BasePredictor):
    """Cog predictor running FLUX.1-dev guided by a canny-edge ControlNet."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient."""
        t_start = time.time()
        # Best-effort HF login. Argument list (not a shell string) so the token
        # is never interpolated into a shell command line.
        subprocess.run(
            ["huggingface-cli", "login", "--token", HF_TOKEN],
            check=False,
            close_fds=False,
        )
        name = "flux-dev"
        self.offload = False

        checkpoint = "controlnet.safetensors"
        print("Checking ControlNet weights")
        if not os.path.exists(checkpoint):
            # Fail loudly if the download fails — the checkpoint is required below.
            subprocess.check_call(["wget", CONTROLNET_URL], close_fds=False)
        print("Checking T5 weights")
        if not os.path.exists(MODEL_CACHE + "/models--google--t5-v1_1-xxl"):
            download_weights(T5_URL, MODEL_CACHE)
        print("Checking CLIP weights")
        if not os.path.exists(MODEL_CACHE + "/models--openai--clip-vit-large-patch14"):
            download_weights(CLIP_URL, MODEL_CACHE)

        self.is_schnell = False
        self.torch_device = torch.device("cuda")
        model, ae, t5, clip, controlnet = get_models(
            name,
            device=self.torch_device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.ae = ae
        self.t5 = t5
        self.clip = clip
        self.controlnet = controlnet
        self.model = model.to(self.torch_device)

        # The checkpoint may ship as safetensors or as a torch pickle.
        if ".safetensors" in checkpoint:
            state_dict = load_safetensors(checkpoint)
        else:
            state_dict = torch.load(checkpoint, map_location="cpu")
        # strict=False: checkpoint may omit buffers present in the module.
        controlnet.load_state_dict(state_dict, strict=False)
        print(f"Setup time: {time.time() - t_start}")

    def preprocess_canny_image(self, image_path: str, width: int = 512, height: int = 512):
        """Center-crop, resize and canny-filter the image at *image_path*."""
        image = Image.open(image_path)
        image = c_crop(image)
        image = image.resize((width, height))
        image = canny_processor(image)
        return image

    def predict(
        self,
        prompt: str = Input(description="Input prompt", default="a handsome viking man with white hair, cinematic, MM full HD"),
        image: Path = Input(description="Input image", default=None),
        num_inference_steps: int = Input(description="Number of inference steps", ge=1, le=64, default=28),
        cfg: float = Input(description="CFG", ge=0, le=10, default=3.5),
        seed: int = Input(description="Random seed", default=None),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Raises:
            ValueError: if no control image was supplied.
            FileNotFoundError: if main.py produced no result image.
        """
        if image is None:
            # ControlNet needs a conditioning image; fail with a clear message
            # instead of crashing inside Image.open(str(None)).
            raise ValueError("An input image is required for ControlNet guidance.")
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        # Start from a clean output dir so we only pick up this run's result.
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

        input_image = str(image)
        img = Image.open(input_image)
        width, height = img.size
        # Downscale oversized inputs, preserving aspect ratio.
        max_image_size = 1536
        scale = min(max_image_size / width, max_image_size / height, 1)
        if scale < 1:
            width = int(width * scale)
            height = int(height * scale)
            print(f"Scaling image down to {width}x{height}")
            img = img.resize((width, height), resample=Image.Resampling.LANCZOS)
            input_image = "/tmp/resized_image.png"
            img.save(input_image)

        subprocess.check_call(
            ["python3", "main.py",
             "--local_path", "controlnet.safetensors",
             "--image", input_image,
             "--use_controlnet",
             "--control_type", "canny",
             "--prompt", prompt,
             "--width", str(width),
             "--height", str(height),
             "--num_steps", str(num_inference_steps),
             "--guidance", str(cfg),
             "--seed", str(seed),
             ], close_fds=False)

        # Return the first generated result; fail loudly instead of silently
        # returning None when main.py produced nothing.
        for file in os.listdir(OUTPUT_DIR):
            if file.startswith("controlnet_result_"):
                return [Path(os.path.join(OUTPUT_DIR, file))]
        raise FileNotFoundError(
            f"no 'controlnet_result_*' file found in {OUTPUT_DIR}"
        )