Spaces:
Runtime error
Runtime error
| import os | |
| import zipfile | |
| from pathlib import Path | |
| import gradio as gr | |
| from database import ( | |
| get_user_credits, | |
| update_user_credits, | |
| get_lora_models_info, | |
| get_user_lora_models | |
| ) | |
| from services.image_generation import generate_image | |
| from services.train_lora import lora_pipeline | |
| from utils.image_utils import url_to_pil_image | |
| from utils.file_utils import load_file_content | |
# Public LoRA catalogue, fetched once at import time and reused by the
# gallery and the selection handler.
LORA_MODELS = get_lora_models_info()
if not isinstance(LORA_MODELS, list):
    raise ValueError("Expected loras_models to be a list of dictionaries.")

# Static assets are located relative to the directory containing this file.
BASE_DIR = Path(__file__).parent
LOGIN_CSS_PATH = BASE_DIR.joinpath('static/css/login.css')
MAIN_CSS_PATH = BASE_DIR.joinpath('static/css/main.css')
LANDING_HTML_PATH = BASE_DIR.joinpath('static/html/landing.html')
MAIN_HEADER_PATH = BASE_DIR.joinpath('static/html/main_header.html')

# Asset contents are read eagerly so the Blocks definitions can embed them.
LOGIN_CSS = load_file_content(LOGIN_CSS_PATH)
MAIN_CSS = load_file_content(MAIN_CSS_PATH)
LANDING_PAGE = load_file_content(LANDING_HTML_PATH)
MAIN_HEADER = load_file_content(MAIN_HEADER_PATH)
def load_user_models(request: gr.Request):
    """Build the gallery entries for the signed-in user's LoRA models.

    The user is read from the ``'user'`` entry of ``request.session`` and the
    models are looked up with ``get_user_lora_models``.

    Returns:
        A list of ``(image, caption)`` tuples, one per model. The image falls
        back to ``"assets/logo.jpg"`` when a model has no ``image_url``. An
        empty list is returned when there is no session user or the lookup
        yields nothing.
    """
    session_user = request.session.get('user')
    print(session_user)
    if not session_user:
        return []

    models = get_user_lora_models(session_user['id'])
    if not models:
        return []

    entries = []
    for model in models:
        thumbnail = model.get("image_url", "assets/logo.jpg")
        entries.append((thumbnail, model["lora_name"]))
    return entries
def update_selection(evt: gr.SelectData, gallery_type: str, width, height):
    """React to a gallery click by refreshing the prompt hint and selection state.

    Args:
        evt: Gradio selection event; ``evt.index`` is the clicked gallery slot.
        gallery_type: ``"user"`` for the personal gallery; any other value is
            treated as the public gallery backed by ``LORA_MODELS``.
        width: Current width value, passed through unless the model sets an aspect.
        height: Current height value, passed through unless the model sets an aspect.

    Returns:
        A 6-tuple: prompt placeholder update, trigger-word markdown, selected
        index, width, height and the gallery type.
    """
    if gallery_type == "user":
        # Personal models use a fixed placeholder entry rather than a lookup.
        chosen = {"lora_name": "custom", "trigger_word": "custom"}
    else:
        chosen = LORA_MODELS[evt.index]

    placeholder_text = f"Enter a prompt for {chosen['lora_name']}"
    info_markdown = f"#### Trigger Word: {chosen['trigger_word']} ✨"

    # A model may declare a preferred aspect; other values keep the sliders as-is.
    aspect = chosen.get("aspect")
    if aspect == "portrait":
        width, height = 768, 1024
    elif aspect == "landscape":
        width, height = 1024, 768

    return (
        gr.update(placeholder=placeholder_text),
        info_markdown,
        evt.index,
        width,
        height,
        gallery_type,
    )
def compress_and_train(request: gr.Request, files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
    """Zip the uploaded images and start a LoRA training run for the session user.

    Args:
        request: Gradio request; the user is read from ``request.session['user']``.
        files: Uploaded gallery items; the first element of each item is used
            as the image file path.
        model_name: Name passed to ``lora_pipeline`` for the new model.
        trigger_word: Trigger word passed to ``lora_pipeline``.
        train_steps: Forwarded to ``lora_pipeline`` as ``steps``.
        lora_rank: Forwarded to ``lora_pipeline`` as ``lora_rank``.
        batch_size: Forwarded to ``lora_pipeline`` as ``batch_size``.
        learning_rate: Forwarded to ``lora_pipeline`` as ``learning_rate``.

    Returns:
        A ``(status_message, training_credits)`` pair, one value for each of
        the two output components this handler is wired to. When no images
        were uploaded the credits component is left unchanged.

    Raises:
        gr.Error: If there is no session user or no training credits remain.
    """
    # Authenticate first: everything below dereferences the session user.
    user = request.session.get('user')
    if not user:
        raise gr.Error("User not authenticated. Please log in.")
    user_id = user['id']

    if not files:
        # Two outputs are wired to this handler, so return a no-op update for
        # the credits display instead of a single value.
        return "No Images. Please, upload some images to start training", gr.update()

    # Read both balances from the database so the write-back below does not
    # overwrite generation credits with a stale or missing session value.
    generation_credits, training_credits = get_user_credits(user_id)
    if training_credits <= 0:
        raise gr.Error("You ran out of credtis. Please buy more to continue")

    # Create a directory in the user's home folder
    output_dir = os.path.expanduser("~/gradio_training_data")
    os.makedirs(output_dir, exist_ok=True)

    # Create a zip file in the output directory
    # NOTE(review): every user writes to the same zip path, so concurrent
    # training requests can overwrite each other's archive — confirm whether
    # lora_pipeline finishes reading the file before it returns.
    zip_path = os.path.join(output_dir, "training_data.zip")
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file_info in files:
            file_path = file_info[0]  # The first element of the tuple is the file path
            zipf.write(file_path, os.path.basename(file_path))
    print(f"Zip file created at: {zip_path}")
    print(f'[INFO] Procesando {trigger_word}')

    # Hand the archive to the training pipeline.
    lora_pipeline(user_id,
                  zip_path,
                  model_name,
                  trigger_word=trigger_word,
                  steps=train_steps,
                  lora_rank=lora_rank,
                  batch_size=batch_size,
                  autocaption=True,
                  learning_rate=learning_rate)

    # Charge one training credit only after the pipeline call returned.
    new_training_credits = training_credits - 1
    update_user_credits(user_id, generation_credits, new_training_credits)

    # Update session data
    user['training_credits'] = new_training_credits
    request.session['user'] = user

    # gr.Info is called for its notification side effect; the text itself is
    # what gets shown in the status textbox.
    status_message = "Your model is training. In about 20 minutes, it will be ready for you to test in 'Generation"
    gr.Info(status_message)
    return status_message, new_training_credits
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the selected LoRA (or the default model) and charge a credit.

    Args:
        request: Gradio request; the user is read from ``request.session['user']``.
        prompt: Text prompt forwarded to ``generate_image``.
        cfg_scale: Forwarded to ``generate_image``.
        steps: Forwarded to ``generate_image``.
        selected_index: Index of the chosen gallery item, or ``None`` to use
            the default model.
        selected_gallery: ``"user"`` to index the user's own models; any other
            value indexes the public model list.
        width: Forwarded to ``generate_image``.
        height: Forwarded to ``generate_image``.
        lora_scale: Forwarded to ``generate_image``.
        progress: Gradio progress tracker forwarded to ``generate_image``.

    Returns:
        ``(image, remaining_generation_credits)``.

    Raises:
        gr.Error: If there is no session user, the selected model cannot be
            found, or no generation credits remain.
    """
    user = request.session.get('user')
    if not user:
        raise gr.Error("User not authenticated. Please log in.")
    user_id = user['id']

    print(f'Selected gallery: {selected_gallery}')
    if selected_gallery == "user":
        lora_models = get_user_lora_models(user_id)
        print('Using user models')
    else:  # public
        lora_models = get_lora_models_info()
        print('Using public models')
    print(f'Selected index: {selected_index}')

    if selected_index is None:
        selected_lora = None
    else:
        # The list may be empty or shorter than when the gallery was rendered.
        if not lora_models or not 0 <= selected_index < len(lora_models):
            raise gr.Error("The selected model is no longer available. Please select it again.")
        selected_lora = lora_models[selected_index]

    # Read both balances from the database so the write-back below does not
    # depend on (possibly stale or missing) keys in the session dict.
    generation_credits, training_credits = get_user_credits(user_id)

    if selected_lora:
        print(f"Selected Lora: {selected_lora['lora_name']}")
        model_name = selected_lora['lora_name']
        use_default = False
    else:
        model_name = "black-forest-labs/flux-pro"
        print(f"Using default Lora: {model_name}")
        use_default = True

    if generation_credits <= 0:
        raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")

    image_url = generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default)
    image = url_to_pil_image(image_url)

    # Update user's credits; training credits are written back unchanged.
    new_generation_credits = generation_credits - 1
    update_user_credits(user_id, new_generation_credits, training_credits)

    # Update session data
    user['generation_credits'] = new_generation_credits
    request.session['user'] = user
    print(f"Generation credits remaining: {new_generation_credits}")
    return image, new_generation_credits
def display_credits(request: gr.Request):
    """Return the session user's ``(generation_credits, train_credits)``.

    Both values are ``0`` when ``request.session`` holds no user.
    """
    session_user = request.session.get('user')
    if not session_user:
        return 0, 0
    generation_credits, train_credits = get_user_credits(session_user['id'])
    return generation_credits, train_credits
def load_greet_and_credits(request: gr.Request):
    """Return ``(greeting, generation_credits, train_credits)`` for the page-load event.

    The greeting comes from ``greet`` and the two balances from
    ``display_credits``, called in that order.
    """
    return (greet(request), *display_credits(request))
def greet(request: gr.Request):
    """Return the greeting text for the session user.

    Args:
        request: Gradio request; the user is read from ``request.session['user']``.

    Returns:
        ``"Hola 👋 <given_name>!"`` followed by a newline when a user is
        present, otherwise a prompt to log in.
    """
    user = request.session.get('user')
    if not user:
        return "OBTU AI. Please log in."
    # Only a string is returned, so no layout components are created here:
    # this runs as an event callback, not while the page is being built.
    return f"Hola 👋 {user['given_name']}!\n"
# Landing / login page: a Google sign-in button followed by the static landing HTML.
# NOTE(review): the source's indentation was lost; the nesting below is a
# reconstruction — confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), css=LOGIN_CSS) as login_demo:
    with gr.Column(
        elem_id="google-btn-container",
        elem_classes="google-btn-container svelte-vt1mxs gap",
    ):
        btn = gr.Button(
            "Sign In with Google",
            elem_classes="login-with-google-btn",
        )

    # Client-side handler: opens '/login' (keeping the current query string)
    # in a new tab. No Python function runs for this click.
    _js_redirect = """
    () => {
    url = '/login' + window.location.search;
    window.open(url, '_blank');
    }
    """
    btn.click(None, js=_js_redirect)

    gr.HTML(LANDING_PAGE)
# Main application page: header, credit counters, and the Create / Train tabs.
# NOTE(review): the source's indentation was lost; the nesting of rows,
# columns and accordions below is a reconstruction — confirm against the
# deployed layout. Statement order is unchanged.
header = '<script src="https://cdn.lordicon.com/lordicon.js"></script>'
with gr.Blocks(theme=gr.themes.Soft(), head=header, css=MAIN_CSS) as main_demo:
    title = gr.HTML(MAIN_HEADER)
    with gr.Column(elem_id="logout-btn-container"):
        gr.Button("Logout", link="/logout", elem_id="logout_btn")

    # Filled in by the page-load event at the bottom of this block.
    greetings = gr.Markdown("Loading user information...")
    # Index of the gallery item the user last clicked (None = default model).
    selected_index = gr.State(None)

    # Credit counters and the purchase link.
    with gr.Row():
        with gr.Column():
            generation_credits_display = gr.Number(
                label="Generation Credits",
                precision=0,
                interactive=False,
            )
        with gr.Column():
            train_credits_display = gr.Number(
                label="Training Credits",
                precision=0,
                interactive=False,
            )
        with gr.Column():
            gr.Button("Buy Credits 💳", link="/buy_credits")

    with gr.Tabs():
        # ---- Image generation tab -------------------------------------
        with gr.TabItem('Create'):
            with gr.Row():
                with gr.Column(scale=3):
                    prompt = gr.Textbox(
                        label="Prompt",
                        lines=1,
                        placeholder="Enter Your Prompt to start creating 📷",
                        info='Some public models may experience longer processing times due to server availability and queue management.',
                    )
                with gr.Column(scale=1, elem_id="gen_column"):
                    generate_button = gr.Button(
                        "Generate",
                        variant="primary",
                        elem_id="gen_btn",
                    )

            with gr.Row():
                with gr.Column(scale=4):
                    result = gr.Image(label="Imagen Generada")
                with gr.Column(scale=3):
                    with gr.Accordion("Public Models"):
                        selected_info = gr.Markdown("")
                        # Entries are built from the catalogue loaded at import time.
                        gallery = gr.Gallery(
                            [(item["image_url"], item["model_name"]) for item in LORA_MODELS],
                            label="Public Models",
                            allow_preview=False,
                            columns=3,
                            elem_id="gallery",
                        )
                    with gr.Accordion("Your Models"):
                        # Populated by load_user_models on page load.
                        user_model_gallery = gr.Gallery(
                            label="Galeria de Modelos",
                            allow_preview=False,
                            columns=3,
                            elem_id="galley",
                        )
                    # Starts as "Public"; the select handlers below store
                    # "public" or "user" once a gallery item is clicked.
                    gallery_type = gr.State("Public")

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    cfg_scale = gr.Slider(
                        label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5
                    )
                    steps = gr.Slider(
                        label="Steps", minimum=1, maximum=50, step=1, value=28
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width", minimum=256, maximum=1536, step=64, value=1024
                    )
                    height = gr.Slider(
                        label="Height", minimum=256, maximum=1536, step=64, value=1024
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    lora_scale = gr.Slider(
                        label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95
                    )

            # Both galleries share one handler; a constant State tells it
            # which gallery the click came from.
            gallery.select(
                update_selection,
                inputs=[gr.State("public"), width, height],
                outputs=[prompt, selected_info, selected_index, width, height, gallery_type],
            )
            user_model_gallery.select(
                update_selection,
                inputs=[gr.State("user"), width, height],
                outputs=[prompt, selected_info, selected_index, width, height, gallery_type],
            )

            # Clicking Generate or submitting the prompt both run generation.
            gr.on(
                triggers=[generate_button.click, prompt.submit],
                fn=run_lora,
                inputs=[prompt, cfg_scale, steps, selected_index, gallery_type, width, height, lora_scale],
                outputs=[result, generation_credits_display],
            )

        # ---- Training tab ---------------------------------------------
        with gr.TabItem("Train"):
            gr.Markdown("# Train your own model 🧠")
            gr.Markdown("In this section, you can train your own model using your images.")
            with gr.Row():
                with gr.Column():
                    train_dataset = gr.Gallery(
                        columns=4,
                        interactive=True,
                        label="Tus Imagenes",
                    )
                    model_name = gr.Textbox(label="Model Name",)
                    trigger_word = gr.Textbox(
                        label="Trigger Word",
                        info="This will be a keyword to later instruct the model when to use these new capabilities we're going to teach it",
                    )
                    train_button = gr.Button("Start Training")
                    with gr.Accordion("Advanced Settings", open=False):
                        train_steps = gr.Slider(
                            label="Training Steps",
                            minimum=100,
                            maximum=10000,
                            step=100,
                            value=1000,
                        )
                        lora_rank = gr.Number(label='lora_rank', value=16)
                        batch_size = gr.Number(label='batch_size', value=1)
                        learning_rate = gr.Number(label='learning_rate', value=0.0004)
                    training_status = gr.Textbox(label="Training Status")

            # The handler feeds the status box and the training-credit counter.
            train_button.click(
                compress_and_train,
                inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
                outputs=[training_status, train_credits_display],
            )

    # Page-load events: fill the personal gallery, the greeting and both counters.
    main_demo.load(load_user_models, None, user_model_gallery)
    main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])