import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download

# Initialize the model
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# List of concept embeddings compatible with SD v1.4
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style"
]
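# Note: the Gradio interface below exposes five image slots and missing results
# are padded with None, so the commented-out concepts above can be re-enabled
# without any further changes.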


def download_concept_embedding(concept_name):
    try:
        # Download the learned_embeds.bin file from the Hub
        embed_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model"
        )
        return embed_path
    except Exception as e:
        print(f"Error downloading {concept_name}: {str(e)}")
        return None


def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Add the concept token to the tokenizer
    token = list(loaded_learned_embeds.keys())[0]
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"The token {token} already exists in the tokenizer.")

    # Resize the token embeddings to make room for the new token
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Copy the learned embedding into the new token's slot, matching the
    # text encoder's dtype and device (fp16 on GPU, fp32 on CPU)
    token_id = tokenizer.convert_tokens_to_ids(token)
    embedding_weight = text_encoder.get_input_embeddings().weight
    embedding_weight.data[token_id] = loaded_learned_embeds[token].to(
        dtype=embedding_weight.dtype, device=embedding_weight.device
    )
    return token
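
# A minimal sketch of using the two helpers above outside the Gradio app,
# assuming the "sd-concepts-library/cat-toy" repo from the `concepts` list and
# an illustrative prompt (neither is part of the app's normal flow):
#
#   pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
#   path = download_concept_embedding("sd-concepts-library/cat-toy")
#   if path is not None:
#       token = load_learned_embed_in_clip(path, pipe.text_encoder, pipe.tokenizer)
#       image = pipe(f"a photo of {token} on a beach", num_inference_steps=20).images[0]
#       image.save("cat_toy_sample.png")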


def generate_images(prompt):
    images = []
    failed_concepts = []
    for concept in concepts:
        try:
            # Create a fresh pipeline for each concept
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)

            # Download and load the concept embedding
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer
            )

            # Generate a random seed
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)

            # Prepend the concept token to the prompt
            concept_prompt = f"{token} {prompt}"

            # Generate the image (mixed precision only on CUDA)
            with autocast(device, enabled=(device == "cuda")):
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]
            images.append(image)

            # Clean up to free memory
            del pipe
            if device == "cuda":
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)
            continue

    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")

    # Return the available images, padding with None if some concepts failed
    while len(images) < 5:
        images.append(None)
    return images[:5]
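
# Quick local check (sketch, bypassing Gradio): call the generator directly and
# save whatever came back; the prompt and filenames here are illustrative only.
#
#   results = generate_images("a portrait of a wizard, highly detailed")
#   for i, img in enumerate(results):
#       if img is not None:
#           img.save(f"concept_{i}.png")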


# Create Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different artistic concepts from the SD Concepts Library"
)

# Launch the app
iface.launch()