Allex21 commited on
Commit
90f1e5c
·
verified ·
1 Parent(s): a1e2923

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +698 -0
app.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LoRA Trainer Funcional para Hugging Face
4
+ Baseado no kohya-ss sd-scripts
5
+ """
6
+
7
+ import gradio as gr
8
+ import os
9
+ import sys
10
+ import json
11
+ import subprocess
12
+ import shutil
13
+ import zipfile
14
+ import tempfile
15
+ import toml
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import Optional, Tuple, List, Dict, Any
19
+ import time
20
+ import threading
21
+ import queue
22
+
23
+ # Adicionar o diretório sd-scripts ao path
24
+ sys.path.insert(0, str(Path(__file__).parent / "sd-scripts"))
25
+
26
+ # Configurar logging
27
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
+ logger = logging.getLogger(__name__)
29
+
30
+ class LoRATrainerHF:
31
+ def __init__(self):
32
+ self.base_dir = Path("/tmp/lora_training")
33
+ self.base_dir.mkdir(exist_ok=True)
34
+
35
+ self.models_dir = self.base_dir / "models"
36
+ self.models_dir.mkdir(exist_ok=True)
37
+
38
+ self.projects_dir = self.base_dir / "projects"
39
+ self.projects_dir.mkdir(exist_ok=True)
40
+
41
+ self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"
42
+
43
+ # URLs dos modelos
44
+ self.model_urls = {
45
+ "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
46
+ "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
47
+ "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
48
+ "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
49
+ }
50
+
51
+ self.training_process = None
52
+ self.training_output_queue = queue.Queue()
53
+
54
+ def install_dependencies(self) -> str:
55
+ """Instala as dependências necessárias"""
56
+ try:
57
+ logger.info("Instalando dependências...")
58
+
59
+ # Lista de pacotes necessários
60
+ packages = [
61
+ "torch>=2.0.0",
62
+ "torchvision>=0.15.0",
63
+ "diffusers>=0.21.0",
64
+ "transformers>=4.25.0",
65
+ "accelerate>=0.20.0",
66
+ "safetensors>=0.3.0",
67
+ "huggingface-hub>=0.16.0",
68
+ "xformers>=0.0.20",
69
+ "bitsandbytes>=0.41.0",
70
+ "opencv-python>=4.7.0",
71
+ "Pillow>=9.0.0",
72
+ "numpy>=1.21.0",
73
+ "tqdm>=4.64.0",
74
+ "toml>=0.10.0",
75
+ "tensorboard>=2.13.0",
76
+ "wandb>=0.15.0",
77
+ "scipy>=1.9.0",
78
+ "matplotlib>=3.5.0",
79
+ "datasets>=2.14.0",
80
+ "peft>=0.5.0",
81
+ "omegaconf>=2.3.0"
82
+ ]
83
+
84
+ # Instalar pacotes
85
+ for package in packages:
86
+ try:
87
+ subprocess.run([
88
+ sys.executable, "-m", "pip", "install", package, "--quiet"
89
+ ], check=True, capture_output=True, text=True)
90
+ logger.info(f"✓ {package} instalado")
91
+ except subprocess.CalledProcessError as e:
92
+ logger.warning(f"⚠ Erro ao instalar {package}: {e}")
93
+
94
+ return "✅ Dependências instaladas com sucesso!"
95
+
96
+ except Exception as e:
97
+ logger.error(f"Erro ao instalar dependências: {e}")
98
+ return f"❌ Erro ao instalar dependências: {e}"
99
+
100
+ def download_model(self, model_choice: str, custom_url: str = "") -> str:
101
+ """Download do modelo base"""
102
+ try:
103
+ if custom_url.strip():
104
+ model_url = custom_url.strip()
105
+ model_name = model_url.split("/")[-1]
106
+ else:
107
+ if model_choice not in self.model_urls:
108
+ return f"❌ Modelo '{model_choice}' não encontrado"
109
+ model_url = self.model_urls[model_choice]
110
+ model_name = model_url.split("/")[-1]
111
+
112
+ model_path = self.models_dir / model_name
113
+
114
+ if model_path.exists():
115
+ return f"✅ Modelo já existe: {model_name}"
116
+
117
+ logger.info(f"Baixando modelo: {model_url}")
118
+
119
+ # Download usando wget
120
+ result = subprocess.run([
121
+ "wget", "-O", str(model_path), model_url, "--progress=bar:force"
122
+ ], capture_output=True, text=True)
123
+
124
+ if result.returncode == 0:
125
+ return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"
126
+ else:
127
+ return f"❌ Erro no download: {result.stderr}"
128
+
129
+ except Exception as e:
130
+ logger.error(f"Erro ao baixar modelo: {e}")
131
+ return f"❌ Erro ao baixar modelo: {e}"
132
+
133
+ def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
134
+ """Processa o dataset enviado"""
135
+ try:
136
+ if not dataset_zip:
137
+ return "❌ Nenhum dataset foi enviado", ""
138
+
139
+ if not project_name.strip():
140
+ return "❌ Nome do projeto é obrigatório", ""
141
+
142
+ project_name = project_name.strip().replace(" ", "_")
143
+ project_dir = self.projects_dir / project_name
144
+ project_dir.mkdir(exist_ok=True)
145
+
146
+ dataset_dir = project_dir / "dataset"
147
+ if dataset_dir.exists():
148
+ shutil.rmtree(dataset_dir)
149
+ dataset_dir.mkdir()
150
+
151
+ # Extrair ZIP
152
+ with zipfile.ZipFile(dataset_zip.name, 'r') as zip_ref:
153
+ zip_ref.extractall(dataset_dir)
154
+
155
+ # Analisar dataset
156
+ image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}
157
+ images = []
158
+ captions = []
159
+
160
+ for file_path in dataset_dir.rglob("*"):
161
+ if file_path.suffix.lower() in image_extensions:
162
+ images.append(file_path)
163
+
164
+ # Procurar caption
165
+ caption_path = file_path.with_suffix('.txt')
166
+ if caption_path.exists():
167
+ captions.append(caption_path)
168
+
169
+ info = f"✅ Dataset processado!\n"
170
+ info += f"📁 Projeto: {project_name}\n"
171
+ info += f"🖼️ Imagens: {len(images)}\n"
172
+ info += f"📝 Captions: {len(captions)}\n"
173
+ info += f"📂 Diretório: {dataset_dir}"
174
+
175
+ return info, str(dataset_dir)
176
+
177
+ except Exception as e:
178
+ logger.error(f"Erro ao processar dataset: {e}")
179
+ return f"❌ Erro ao processar dataset: {e}", ""
180
+
181
+ def create_training_config(self,
182
+ project_name: str,
183
+ dataset_dir: str,
184
+ model_choice: str,
185
+ custom_model_url: str,
186
+ resolution: int,
187
+ batch_size: int,
188
+ epochs: int,
189
+ learning_rate: float,
190
+ text_encoder_lr: float,
191
+ network_dim: int,
192
+ network_alpha: int,
193
+ lora_type: str,
194
+ optimizer: str,
195
+ scheduler: str,
196
+ flip_aug: bool,
197
+ shuffle_caption: bool,
198
+ keep_tokens: int,
199
+ clip_skip: int,
200
+ mixed_precision: str,
201
+ save_every_n_epochs: int,
202
+ max_train_steps: int) -> str:
203
+ """Cria configuração de treinamento"""
204
+ try:
205
+ if not project_name.strip():
206
+ return "❌ Nome do projeto é obrigatório"
207
+
208
+ project_name = project_name.strip().replace(" ", "_")
209
+ project_dir = self.projects_dir / project_name
210
+ project_dir.mkdir(exist_ok=True)
211
+
212
+ output_dir = project_dir / "output"
213
+ output_dir.mkdir(exist_ok=True)
214
+
215
+ log_dir = project_dir / "logs"
216
+ log_dir.mkdir(exist_ok=True)
217
+
218
+ # Determinar modelo
219
+ if custom_model_url.strip():
220
+ model_name = custom_model_url.strip().split("/")[-1]
221
+ else:
222
+ model_name = self.model_urls[model_choice].split("/")[-1]
223
+
224
+ model_path = self.models_dir / model_name
225
+
226
+ if not model_path.exists():
227
+ return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."
228
+
229
+ # Configuração do dataset
230
+ dataset_config = {
231
+ "general": {
232
+ "shuffle_caption": shuffle_caption,
233
+ "caption_extension": ".txt",
234
+ "keep_tokens": keep_tokens,
235
+ "flip_aug": flip_aug,
236
+ "color_aug": False,
237
+ "face_crop_aug_range": None,
238
+ "random_crop": False,
239
+ "debug_dataset": False
240
+ },
241
+ "datasets": [{
242
+ "resolution": resolution,
243
+ "batch_size": batch_size,
244
+ "subsets": [{
245
+ "image_dir": str(dataset_dir),
246
+ "num_repeats": 1
247
+ }]
248
+ }]
249
+ }
250
+
251
+ # Configuração de treinamento
252
+ training_config = {
253
+ "model_arguments": {
254
+ "pretrained_model_name_or_path": str(model_path),
255
+ "v2": False,
256
+ "v_parameterization": False,
257
+ "clip_skip": clip_skip
258
+ },
259
+ "dataset_arguments": {
260
+ "dataset_config": str(project_dir / "dataset_config.toml")
261
+ },
262
+ "training_arguments": {
263
+ "output_dir": str(output_dir),
264
+ "output_name": project_name,
265
+ "save_precision": "fp16",
266
+ "save_every_n_epochs": save_every_n_epochs,
267
+ "max_train_epochs": epochs if max_train_steps == 0 else None,
268
+ "max_train_steps": max_train_steps if max_train_steps > 0 else None,
269
+ "train_batch_size": batch_size,
270
+ "gradient_accumulation_steps": 1,
271
+ "learning_rate": learning_rate,
272
+ "text_encoder_lr": text_encoder_lr,
273
+ "lr_scheduler": scheduler,
274
+ "lr_warmup_steps": 0,
275
+ "optimizer_type": optimizer,
276
+ "mixed_precision": mixed_precision,
277
+ "save_model_as": "safetensors",
278
+ "seed": 42,
279
+ "max_data_loader_n_workers": 2,
280
+ "persistent_data_loader_workers": True,
281
+ "gradient_checkpointing": True,
282
+ "xformers": True,
283
+ "lowram": True,
284
+ "cache_latents": True,
285
+ "cache_latents_to_disk": True,
286
+ "logging_dir": str(log_dir),
287
+ "log_with": "tensorboard"
288
+ },
289
+ "network_arguments": {
290
+ "network_module": "networks.lora" if lora_type == "LoRA" else "networks.dylora",
291
+ "network_dim": network_dim,
292
+ "network_alpha": network_alpha,
293
+ "network_train_unet_only": False,
294
+ "network_train_text_encoder_only": False
295
+ }
296
+ }
297
+
298
+ # Adicionar argumentos específicos para LoCon
299
+ if lora_type == "LoCon":
300
+ training_config["network_arguments"]["network_module"] = "networks.lora"
301
+ training_config["network_arguments"]["conv_dim"] = max(1, network_dim // 2)
302
+ training_config["network_arguments"]["conv_alpha"] = max(1, network_alpha // 2)
303
+
304
+ # Salvar configurações
305
+ dataset_config_path = project_dir / "dataset_config.toml"
306
+ training_config_path = project_dir / "training_config.toml"
307
+
308
+ with open(dataset_config_path, 'w') as f:
309
+ toml.dump(dataset_config, f)
310
+
311
+ with open(training_config_path, 'w') as f:
312
+ toml.dump(training_config, f)
313
+
314
+ return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"
315
+
316
+ except Exception as e:
317
+ logger.error(f"Erro ao criar configuração: {e}")
318
+ return f"❌ Erro ao criar configuração: {e}"
319
+
320
+ def start_training(self, project_name: str) -> str:
321
+ """Inicia o treinamento"""
322
+ try:
323
+ if not project_name.strip():
324
+ return "❌ Nome do projeto é obrigatório"
325
+
326
+ project_name = project_name.strip().replace(" ", "_")
327
+ project_dir = self.projects_dir / project_name
328
+
329
+ training_config_path = project_dir / "training_config.toml"
330
+ if not training_config_path.exists():
331
+ return "❌ Configuração não encontrada. Crie a configuração primeiro."
332
+
333
+ # Script de treinamento
334
+ train_script = self.sd_scripts_dir / "train_network.py"
335
+ if not train_script.exists():
336
+ return "❌ Script de treinamento não encontrado"
337
+
338
+ # Comando de treinamento
339
+ cmd = [
340
+ sys.executable,
341
+ str(train_script),
342
+ "--config_file", str(training_config_path)
343
+ ]
344
+
345
+ logger.info(f"Iniciando treinamento: {' '.join(cmd)}")
346
+
347
+ # Executar em thread separada
348
+ def run_training():
349
+ try:
350
+ process = subprocess.Popen(
351
+ cmd,
352
+ stdout=subprocess.PIPE,
353
+ stderr=subprocess.STDOUT,
354
+ text=True,
355
+ bufsize=1,
356
+ universal_newlines=True,
357
+ cwd=str(self.sd_scripts_dir)
358
+ )
359
+
360
+ self.training_process = process
361
+
362
+ for line in process.stdout:
363
+ self.training_output_queue.put(line.strip())
364
+ logger.info(line.strip())
365
+
366
+ process.wait()
367
+
368
+ if process.returncode == 0:
369
+ self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
370
+ else:
371
+ self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")
372
+
373
+ except Exception as e:
374
+ self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
375
+ finally:
376
+ self.training_process = None
377
+
378
+ # Iniciar thread
379
+ training_thread = threading.Thread(target=run_training)
380
+ training_thread.daemon = True
381
+ training_thread.start()
382
+
383
+ return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."
384
+
385
+ except Exception as e:
386
+ logger.error(f"Erro ao iniciar treinamento: {e}")
387
+ return f"❌ Erro ao iniciar treinamento: {e}"
388
+
389
+ def get_training_output(self) -> str:
390
+ """Obtém output do treinamento"""
391
+ output_lines = []
392
+ try:
393
+ while not self.training_output_queue.empty():
394
+ line = self.training_output_queue.get_nowait()
395
+ output_lines.append(line)
396
+ except queue.Empty:
397
+ pass
398
+
399
+ if output_lines:
400
+ return "\n".join(output_lines)
401
+ elif self.training_process and self.training_process.poll() is None:
402
+ return "🔄 Treinamento em andamento..."
403
+ else:
404
+ return "⏸️ Nenhum treinamento ativo"
405
+
406
+ def stop_training(self) -> str:
407
+ """Para o treinamento"""
408
+ try:
409
+ if self.training_process and self.training_process.poll() is None:
410
+ self.training_process.terminate()
411
+ self.training_process.wait(timeout=10)
412
+ return "⏹️ Treinamento interrompido"
413
+ else:
414
+ return "ℹ️ Nenhum treinamento ativo para parar"
415
+ except Exception as e:
416
+ return f"❌ Erro ao parar treinamento: {e}"
417
+
418
+ def list_output_files(self, project_name: str) -> List[str]:
419
+ """Lista arquivos de saída"""
420
+ try:
421
+ if not project_name.strip():
422
+ return []
423
+
424
+ project_name = project_name.strip().replace(" ", "_")
425
+ project_dir = self.projects_dir / project_name
426
+ output_dir = project_dir / "output"
427
+
428
+ if not output_dir.exists():
429
+ return []
430
+
431
+ files = []
432
+ for file_path in output_dir.rglob("*.safetensors"):
433
+ size_mb = file_path.stat().st_size // (1024 * 1024)
434
+ files.append(f"{file_path.name} ({size_mb} MB)")
435
+
436
+ return sorted(files, reverse=True) # Mais recentes primeiro
437
+
438
+ except Exception as e:
439
+ logger.error(f"Erro ao listar arquivos: {e}")
440
+ return []
441
+
442
+ # Instância global
443
+ trainer = LoRATrainerHF()
444
+
445
+ def create_interface():
446
+ """Cria a interface Gradio"""
447
+
448
+ with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:
449
+
450
+ gr.Markdown("""
451
+ # 🎨 LoRA Trainer Funcional para Hugging Face
452
+
453
+ **Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**
454
+
455
+ Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
456
+ """)
457
+
458
+ # Estado para armazenar informações
459
+ dataset_dir_state = gr.State("")
460
+
461
+ with gr.Tab("🔧 Instalação"):
462
+ gr.Markdown("### Primeiro, instale as dependências necessárias:")
463
+ install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
464
+ install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)
465
+
466
+ install_btn.click(
467
+ fn=trainer.install_dependencies,
468
+ outputs=install_status
469
+ )
470
+
471
+ with gr.Tab("📁 Configuração do Projeto"):
472
+ with gr.Row():
473
+ project_name = gr.Textbox(
474
+ label="Nome do Projeto",
475
+ placeholder="meu_lora_anime",
476
+ info="Nome único para seu projeto (sem espaços especiais)"
477
+ )
478
+
479
+ gr.Markdown("### 📥 Download do Modelo Base")
480
+ with gr.Row():
481
+ model_choice = gr.Dropdown(
482
+ choices=list(trainer.model_urls.keys()),
483
+ label="Modelo Base Pré-definido",
484
+ value="Anime (animefull-final-pruned)",
485
+ info="Escolha um modelo base ou use URL personalizada"
486
+ )
487
+ custom_model_url = gr.Textbox(
488
+ label="URL Personalizada (opcional)",
489
+ placeholder="https://huggingface.co/...",
490
+ info="URL direta para download de modelo personalizado"
491
+ )
492
+
493
+ download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
494
+ download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)
495
+
496
+ gr.Markdown("### 📊 Upload do Dataset")
497
+ gr.Markdown("""
498
+ **Formato do Dataset:**
499
+ - Crie um arquivo ZIP contendo suas imagens
500
+ - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições
501
+ - Exemplo: `imagem1.jpg` + `imagem1.txt`
502
+ """)
503
+
504
+ dataset_upload = gr.File(
505
+ label="Upload do Dataset (ZIP)",
506
+ file_types=[".zip"]
507
+ )
508
+
509
+ process_btn = gr.Button("📊 Processar Dataset", variant="primary")
510
+ dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)
511
+
512
+ with gr.Tab("⚙️ Parâmetros de Treinamento"):
513
+ with gr.Row():
514
+ with gr.Column():
515
+ gr.Markdown("#### 🖼️ Configurações de Imagem")
516
+ resolution = gr.Slider(
517
+ minimum=512, maximum=1024, step=64, value=512,
518
+ label="Resolução",
519
+ info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
520
+ )
521
+ batch_size = gr.Slider(
522
+ minimum=1, maximum=8, step=1, value=1,
523
+ label="Batch Size",
524
+ info="Imagens por lote (aumente se tiver GPU potente)"
525
+ )
526
+ flip_aug = gr.Checkbox(
527
+ label="Flip Augmentation",
528
+ info="Espelhar imagens para aumentar dataset"
529
+ )
530
+ shuffle_caption = gr.Checkbox(
531
+ value=True,
532
+ label="Shuffle Caption",
533
+ info="Embaralhar ordem das tags"
534
+ )
535
+ keep_tokens = gr.Slider(
536
+ minimum=0, maximum=5, step=1, value=1,
537
+ label="Keep Tokens",
538
+ info="Número de tokens iniciais que não serão embaralhados"
539
+ )
540
+
541
+ with gr.Column():
542
+ gr.Markdown("#### 🎯 Configurações de Treinamento")
543
+ epochs = gr.Slider(
544
+ minimum=1, maximum=100, step=1, value=10,
545
+ label="Épocas",
546
+ info="Número de épocas de treinamento"
547
+ )
548
+ max_train_steps = gr.Number(
549
+ value=0,
550
+ label="Max Train Steps (0 = usar épocas)",
551
+ info="Número máximo de steps (deixe 0 para usar épocas)"
552
+ )
553
+ save_every_n_epochs = gr.Slider(
554
+ minimum=1, maximum=10, step=1, value=1,
555
+ label="Salvar a cada N épocas",
556
+ info="Frequência de salvamento dos checkpoints"
557
+ )
558
+ mixed_precision = gr.Dropdown(
559
+ choices=["fp16", "bf16", "no"],
560
+ value="fp16",
561
+ label="Mixed Precision",
562
+ info="fp16 = mais rápido, bf16 = mais estável"
563
+ )
564
+ clip_skip = gr.Slider(
565
+ minimum=1, maximum=12, step=1, value=2,
566
+ label="CLIP Skip",
567
+ info="Camadas CLIP a pular (2 para anime, 1 para realista)"
568
+ )
569
+
570
+ with gr.Row():
571
+ with gr.Column():
572
+ gr.Markdown("#### 📚 Learning Rate")
573
+ learning_rate = gr.Number(
574
+ value=1e-4,
575
+ label="Learning Rate (UNet)",
576
+ info="Taxa de aprendizado principal"
577
+ )
578
+ text_encoder_lr = gr.Number(
579
+ value=5e-5,
580
+ label="Learning Rate (Text Encoder)",
581
+ info="Taxa de aprendizado do text encoder"
582
+ )
583
+ scheduler = gr.Dropdown(
584
+ choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
585
+ value="cosine_with_restarts",
586
+ label="LR Scheduler",
587
+ info="Algoritmo de ajuste da learning rate"
588
+ )
589
+ optimizer = gr.Dropdown(
590
+ choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
591
+ value="AdamW8bit",
592
+ label="Otimizador",
593
+ info="AdamW8bit = menos memória"
594
+ )
595
+
596
+ with gr.Column():
597
+ gr.Markdown("#### 🧠 Arquitetura LoRA")
598
+ lora_type = gr.Radio(
599
+ choices=["LoRA", "LoCon"],
600
+ value="LoRA",
601
+ label="Tipo de LoRA",
602
+ info="LoRA = geral, LoCon = estilos artísticos"
603
+ )
604
+ network_dim = gr.Slider(
605
+ minimum=4, maximum=128, step=4, value=32,
606
+ label="Network Dimension",
607
+ info="Dimensão da rede (maior = mais detalhes, mais memória)"
608
+ )
609
+ network_alpha = gr.Slider(
610
+ minimum=1, maximum=128, step=1, value=16,
611
+ label="Network Alpha",
612
+ info="Controla a força do LoRA (geralmente dim/2)"
613
+ )
614
+
615
+ with gr.Tab("🚀 Treinamento"):
616
+ create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
617
+ config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)
618
+
619
+ with gr.Row():
620
+ start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
621
+ stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")
622
+
623
+ training_output = gr.Textbox(
624
+ label="Output do Treinamento",
625
+ lines=15,
626
+ interactive=False,
627
+ info="Acompanhe o progresso do treinamento em tempo real"
628
+ )
629
+
630
+ # Auto-refresh do output
631
+ def update_output():
632
+ return trainer.get_training_output()
633
+
634
+ with gr.Tab("📥 Download dos Resultados"):
635
+ refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")
636
+
637
+ output_files = gr.Dropdown(
638
+ label="Arquivos LoRA Gerados",
639
+ choices=[],
640
+ info="Selecione um arquivo para download"
641
+ )
642
+
643
+ download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")
644
+
645
+ # Event handlers
646
+ download_btn.click(
647
+ fn=trainer.download_model,
648
+ inputs=[model_choice, custom_model_url],
649
+ outputs=download_status
650
+ )
651
+
652
+ process_btn.click(
653
+ fn=trainer.process_dataset,
654
+ inputs=[dataset_upload, project_name],
655
+ outputs=[dataset_status, dataset_dir_state]
656
+ )
657
+
658
+ create_config_btn.click(
659
+ fn=trainer.create_training_config,
660
+ inputs=[
661
+ project_name, dataset_dir_state, model_choice, custom_model_url,
662
+ resolution, batch_size, epochs, learning_rate, text_encoder_lr,
663
+ network_dim, network_alpha, lora_type, optimizer, scheduler,
664
+ flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
665
+ save_every_n_epochs, max_train_steps
666
+ ],
667
+ outputs=config_status
668
+ )
669
+
670
+ start_training_btn.click(
671
+ fn=trainer.start_training,
672
+ inputs=project_name,
673
+ outputs=training_output
674
+ )
675
+
676
+ stop_training_btn.click(
677
+ fn=trainer.stop_training,
678
+ outputs=training_output
679
+ )
680
+
681
+ refresh_files_btn.click(
682
+ fn=trainer.list_output_files,
683
+ inputs=project_name,
684
+ outputs=output_files
685
+ )
686
+
687
+ return interface
688
+
689
+ if __name__ == "__main__":
690
+ print("🚀 Iniciando LoRA Trainer Funcional...")
691
+ interface = create_interface()
692
+ interface.launch(
693
+ server_name="0.0.0.0",
694
+ server_port=7860,
695
+ share=False,
696
+ show_error=True
697
+ )
698
+