# Flux image generation command-line tool.
import argparse
import os

import numpy as np
import torch
from PIL import Image

from model import FluxModel
def parse_args():
    """Build the CLI parser and return the parsed arguments.

    Reads ``sys.argv``; see the individual ``--help`` strings for the
    meaning of each flag.
    """
    ap = argparse.ArgumentParser(description='Flux Image Generation Tool')
    add = ap.add_argument

    # --- required ---
    add('--mode', type=str, required=True,
        choices=['variation', 'img2img', 'inpaint', 'controlnet', 'controlnet-inpaint'],
        help='Generation mode')
    add('--input_image', type=str, required=True,
        help='Path to the input image')

    # --- optional generation inputs ---
    add('--prompt', type=str, default="",
        help='Text prompt to guide the generation')
    add('--reference_image', type=str, default=None,
        help='Path to the reference image (for img2img/controlnet modes)')
    add('--mask_image', type=str, default=None,
        help='Path to the mask image (for inpainting modes)')
    add('--output_dir', type=str, default='outputs',
        help='Directory to save generated images')
    add('--image_count', type=int, default=1,
        help='Number of images to generate')
    add('--aspect_ratio', type=str, default='1:1',
        choices=['1:1', '16:9', '9:16', '2.4:1', '3:4', '4:3'],
        help='Output image aspect ratio')
    add('--steps', type=int, default=28,
        help='Number of inference steps')
    add('--guidance_scale', type=float, default=7.5,
        help='Guidance scale for generation')
    add('--denoise_strength', type=float, default=0.8,
        help='Denoising strength for img2img/inpaint')

    # --- attention focus (all optional; normalized 0-1 coordinates) ---
    add('--center_x', type=float, default=None,
        help='X coordinate of attention center (0-1)')
    add('--center_y', type=float, default=None,
        help='Y coordinate of attention center (0-1)')
    add('--radius', type=float, default=None,
        help='Radius of attention circle (0-1)')

    # --- ControlNet options ---
    add('--line_mode', action='store_true',
        help='Enable line detection mode for ControlNet')
    add('--depth_mode', action='store_true',
        help='Enable depth mode for ControlNet')
    add('--line_strength', type=float, default=0.4,
        help='Strength of line guidance')
    add('--depth_strength', type=float, default=0.2,
        help='Strength of depth guidance')

    # --- runtime ---
    add('--device', type=str, default='cuda',
        choices=['cuda', 'cpu'],
        help='Device to run the model on')
    add('--turbo', action='store_true',
        help='Enable turbo mode for faster inference')

    return ap.parse_args()
def load_image(image_path):
    """Load the file at *image_path* and return it as an RGB PIL Image.

    Args:
        image_path: filesystem path to the image file.

    Returns:
        PIL.Image.Image: the decoded image, converted to RGB.

    Raises:
        ValueError: if the file cannot be opened or decoded.
    """
    try:
        return Image.open(image_path).convert('RGB')
    except Exception as e:
        # Chain the original exception so the real cause (missing file,
        # corrupt data, ...) stays visible in the traceback.
        raise ValueError(f"Error loading image {image_path}: {str(e)}") from e
def save_images(images, output_dir, prefix="generated"):
    """Save *images* into *output_dir* as ``{prefix}_{n}.png`` (1-based).

    Args:
        images: iterable of PIL-like objects exposing ``save(path)``.
        output_dir: destination directory; created if it does not exist.
        prefix: filename prefix for the saved files.
    """
    os.makedirs(output_dir, exist_ok=True)
    # enumerate from 1 so filenames start at {prefix}_1.png.
    for i, image in enumerate(images, start=1):
        output_path = os.path.join(output_dir, f"{prefix}_{i}.png")
        image.save(output_path)
        print(f"Saved image to {output_path}")
def get_required_features(args):
    """Map parsed CLI arguments to the ordered list of model features to load.

    Args:
        args: parsed argparse namespace (reads ``mode``, ``depth_mode``,
            ``line_mode``).

    Returns:
        list[str]: feature names, in load order.
    """
    # (enabled?, feature-name) pairs; order here defines load order.
    flags = [
        (args.mode in ('controlnet', 'controlnet-inpaint'), 'controlnet'),
        (args.depth_mode, 'depth'),
        (args.line_mode, 'line'),
        # SAM is pulled in for mask generation in the inpainting modes.
        (args.mode in ('inpaint', 'controlnet-inpaint'), 'sam'),
    ]
    return [name for enabled, name in flags if enabled]
def main():
    """CLI entry point: parse arguments, build the model, generate and save images."""
    args = parse_args()

    # Degrade gracefully to CPU when CUDA was requested but is absent.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA requested but not available. Falling back to CPU.")
        args.device = 'cpu'

    # Load only the model features this invocation actually needs.
    required_features = get_required_features(args)
    print(f"Initializing model on {args.device} with features: {required_features}")
    model = FluxModel(
        is_turbo=args.turbo,
        device=args.device,
        required_features=required_features,
    )

    # Input is mandatory; reference/mask are loaded only when provided.
    input_image = load_image(args.input_image)
    reference_image = None
    if args.reference_image:
        reference_image = load_image(args.reference_image)
    mask_image = None
    if args.mask_image:
        mask_image = load_image(args.mask_image)

    # Inpainting modes cannot proceed without a mask.
    if mask_image is None and args.mode in ('inpaint', 'controlnet-inpaint'):
        raise ValueError(f"{args.mode} mode requires a mask image")

    print(f"Generating {args.image_count} images in {args.mode} mode...")
    generated_images = model.generate(
        input_image_a=input_image,
        input_image_b=reference_image,
        prompt=args.prompt,
        mask_image=mask_image,
        mode=args.mode,
        imageCount=args.image_count,
        aspect_ratio=args.aspect_ratio,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance_scale,
        denoise_strength=args.denoise_strength,
        center_x=args.center_x,
        center_y=args.center_y,
        radius=args.radius,
        line_mode=args.line_mode,
        depth_mode=args.depth_mode,
        line_strength=args.line_strength,
        depth_strength=args.depth_strength,
    )

    save_images(generated_images, args.output_dir)
    print("Generation completed successfully!")
| if __name__ == "__main__": | |
| main() |