import gradio as gr
from PIL import Image
import numpy as np

# Extensions accepted by the upload widget and re-checked on the server side.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
# Fill colour used for letterboxing and for unused grid cells.
BACKGROUND_COLOR = (255, 255, 255)


def _load_image(item):
    """Return one gallery entry as a PIL image, or None if it holds no image.

    Accepts an ``(image, caption)`` pair or a bare image. The image itself may
    be a file path, a NumPy array or a PIL image.
    """
    if item is None:
        return None
    # Gallery returns (image, caption) tuples; tolerate bare images as well.
    img = item[0] if isinstance(item, (tuple, list)) else item
    if isinstance(img, str):
        # copy() forces the pixel data to be read, so the file handle can be
        # released immediately instead of lingering until garbage collection.
        with Image.open(img) as opened:
            return opened.copy()
    if isinstance(img, np.ndarray):
        return Image.fromarray(img)
    return img


def _fit_to_patch(img, patch_size):
    """Letterbox *img* onto a white ``patch_size`` x ``patch_size`` RGB tile."""
    tile = Image.new("RGB", (patch_size, patch_size), BACKGROUND_COLOR)
    width, height = img.size
    if width == 0 or height == 0:
        # Nothing to draw; returning the blank tile also avoids dividing by 0.
        return tile

    # Convert before resizing: Pillow falls back to NEAREST resampling for
    # palette ("P") and bilevel ("1") images, which would defeat LANCZOS.
    has_alpha = img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info
    img = img.convert("RGBA" if has_alpha else "RGB")

    # Scale to fit inside the tile while keeping the aspect ratio. Round and
    # clamp so float error never yields a 0-pixel side or overshoots the tile.
    scale = min(patch_size / width, patch_size / height)
    new_width = min(patch_size, max(1, round(width * scale)))
    new_height = min(patch_size, max(1, round(height * scale)))
    resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Centre the image; use its alpha band as mask so that transparent pixels
    # show the white background rather than their hidden colour values.
    offset = ((patch_size - new_width) // 2, (patch_size - new_height) // 2)
    tile.paste(resized, offset, mask=resized if has_alpha else None)
    return tile


def concat_images(images, patch_size=256):
    """Arrange images into a roughly square grid of equally sized tiles.

    Args:
        images: Iterable of gallery entries. Each entry is an
            ``(image, caption)`` pair or a bare image, where the image is a
            file path, a NumPy array or a PIL image. ``None`` entries are
            skipped.
        patch_size: Side length in pixels of each square tile.

    Returns:
        An RGB PIL image containing the grid (tiles filled row by row, unused
        cells left white), or ``None`` when there is nothing to draw.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1.
    """
    if not images:
        return None

    # Sliders may deliver floats; PIL sizes must be integers.
    patch_size = int(patch_size)
    if patch_size < 1:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    loaded = (_load_image(item) for item in images)
    tiles = [_fit_to_patch(img, patch_size) for img in loaded if img is not None]
    if not tiles:
        return None

    # Calculate grid dimensions: as square as possible, filled row by row.
    row_size = int(np.ceil(np.sqrt(len(tiles))))
    n_rows = -(-len(tiles) // row_size)  # ceiling division

    # Paste the tiles onto one pre-filled canvas; cells of an incomplete last
    # row simply keep the background colour.
    grid = Image.new(
        "RGB", (row_size * patch_size, n_rows * patch_size), BACKGROUND_COLOR
    )
    for index, tile in enumerate(tiles):
        row, col = divmod(index, row_size)
        grid.paste(tile, (col * patch_size, row * patch_size))
    return grid


def process_images(files, gallery, patch_size):
    """Add newly uploaded files to the gallery and rebuild the grid.

    Args:
        files: Uploaded files; each is a path string or an object exposing the
            path as ``.name``.
        gallery: Current gallery entries, or ``None`` when the gallery is
            empty. The passed list is not modified.
        patch_size: Side length in pixels of each grid tile.

    Returns:
        ``(gallery, result)``: the updated gallery entries and the grid image.
        When ``files`` is empty the gallery is returned untouched together
        with ``None``.
    """
    if not files:
        return gallery, None

    # Work on a copy so the caller's list is never mutated.
    gallery = list(gallery) if gallery else []

    # Add all new images to the gallery
    for file in files:
        path = str(getattr(file, "name", file))
        if not path.lower().endswith(IMAGE_EXTENSIONS):
            print(f"Skipping invalid file type: {path}")
            continue
        try:
            # Decode eagerly so a corrupt file is reported here rather than
            # failing later while the grid is being built.
            with Image.open(path) as opened:
                img = opened.copy()
        except Exception as e:
            # Best effort: Pillow raises several unrelated exception types for
            # bad input, and one broken upload must not discard the others.
            print(f"Error loading image {path}: {e}")
            continue
        gallery.append((img, None))

    # Generate the concatenated result
    result = concat_images(gallery, patch_size)
    return gallery, result


def rebuild_grid(gallery, patch_size):
    """Re-render the grid from the images already in the gallery.

    Used when only the patch size changes, so that the uploaded files are not
    appended to the gallery a second time.
    """
    return concat_images(gallery, patch_size)


def clear_gallery():
    """Reset both the gallery and the result image."""
    return None, None


with gr.Blocks() as demo:
    gr.Markdown("# Image Concatenation")
    gr.Markdown("Upload multiple images to create a grid of concatenated images.")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=list(IMAGE_EXTENSIONS),
                height=400,
            )
            clear_button = gr.Button("Clear All Images")

        with gr.Column():
            image_gallery = gr.Gallery(
                label="Image Gallery",
                show_label=True,
                elem_id="gallery",
                columns=[4],
                rows=[4],
                height=400,
                allow_preview=True,
            )

    with gr.Row():
        image_output = gr.Image(type="pil", label="Result")
        patch_size = gr.Slider(
            minimum=256,
            maximum=1024,
            value=256,
            step=256,
            label="Patch Size",
        )

    # Set up the event handlers
    file_input.upload(
        fn=process_images,
        inputs=[file_input, image_gallery, patch_size],
        outputs=[image_gallery, image_output],
    )

    # Changing the patch size only re-renders the grid; routing it through
    # process_images would re-add every uploaded file to the gallery.
    patch_size.change(
        fn=rebuild_grid,
        inputs=[image_gallery, patch_size],
        outputs=image_output,
    )

    clear_button.click(
        fn=clear_gallery,
        outputs=[image_gallery, image_output],
    )

demo.launch()