PandaArtStation commited on
Commit
e9d3a3c
·
verified ·
1 Parent(s): 90868e4

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +270 -0
models.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import (
3
+ StableDiffusionXLImg2ImgPipeline,
4
+ StableDiffusionInpaintPipeline,
5
+ DDIMScheduler,
6
+ PNDMScheduler,
7
+ EulerDiscreteScheduler,
8
+ DPMSolverMultistepScheduler
9
+ )
10
+ from PIL import Image, ImageFilter, ImageEnhance
11
+ import numpy as np
12
+ import cv2
13
+
14
+ class InteriorDesignerPro:
15
+ def __init__(self):
16
+ self.device = torch.device("cuda") # ТОЛЬКО GPU!
17
+ self.model_name = "RealVisXL V4.0"
18
+
19
+ # Проверка GPU
20
+ gpu_name = torch.cuda.get_device_name(0)
21
+ self.is_powerful_gpu = any(gpu in gpu_name for gpu in ['A100', 'H100', 'RTX 4090', 'RTX 3090', 'T4', 'A10G'])
22
+
23
+ # Основная модель - RealVisXL V4
24
+ print(f"Loading {self.model_name} on {gpu_name}...")
25
+ self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
26
+ "SG161222/RealVisXL_V4.0",
27
+ torch_dtype=torch.float16,
28
+ use_safetensors=True,
29
+ variant="fp16"
30
+ ).to(self.device)
31
+
32
+ # БЕЗ enable_model_cpu_offload() и enable_vae_slicing() - они замедляют H200!
33
+
34
+ # Настройка scheduler для качества
35
+ self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
36
+
37
+ # Inpainting модель
38
+ try:
39
+ self.inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
40
+ "stabilityai/stable-diffusion-2-inpainting",
41
+ torch_dtype=torch.float16,
42
+ safety_checker=None,
43
+ requires_safety_checker=False,
44
+ local_files_only=False,
45
+ resume_download=True
46
+ ).to(self.device)
47
+ print("Inpainting model loaded")
48
+ except Exception as e:
49
+ print(f"Warning: Could not load inpainting model: {e}")
50
+ print("Using img2img as fallback for object removal")
51
+ self.inpaint_pipe = None
52
+
53
+ @torch.inference_mode()
54
+ def apply_style_pro(self, image, style_name, room_type, strength=0.75, quality="balanced", custom_prompt=None, custom_negative=None):
55
+ """Применение стиля к изображению"""
56
+ from design_styles import DESIGN_STYLES
57
+
58
+ # Ресайз для скорости
59
+ original_size = image.size
60
+ if image.width > 768 or image.height > 768:
61
+ image.thumbnail((768, 768), Image.Resampling.LANCZOS)
62
+
63
+ if style_name == "custom" and custom_prompt:
64
+ # Кастомный промпт
65
+ full_prompt = custom_prompt
66
+ negative = custom_negative or "low quality, blurry"
67
+ else:
68
+ # Предустановленный стиль
69
+ style = DESIGN_STYLES.get(style_name, DESIGN_STYLES["Современный минимализм"])
70
+ room_specific = style.get("room_specific", {}).get(room_type, "")
71
+ full_prompt = f"{style['prompt']}, {room_specific}, {room_type} interior design, professional photo, high quality, 8k, photorealistic"
72
+ negative = style.get("negative", "low quality, blurry")
73
+
74
+ # Настройки качества - оптимизированные для H200
75
+ quality_settings = {
76
+ "fast": {"steps": 15, "guidance": 6.0},
77
+ "balanced": {"steps": 20, "guidance": 7.0},
78
+ "ultra": {"steps": 30, "guidance": 8.0}
79
+ }
80
+
81
+ settings = quality_settings.get(quality, quality_settings["balanced"])
82
+
83
+ # Генерация с SDXL
84
+ result = self.pipe(
85
+ prompt=full_prompt,
86
+ prompt_2=full_prompt, # Для SDXL
87
+ negative_prompt=negative,
88
+ negative_prompt_2=negative, # Для SDXL
89
+ image=image,
90
+ strength=strength,
91
+ num_inference_steps=settings["steps"],
92
+ guidance_scale=settings["guidance"],
93
+ # SDXL параметры - оптимизированные
94
+ original_size=(768, 768),
95
+ target_size=(768, 768)
96
+ ).images[0]
97
+
98
+ # Возвращаем к оригинальному размеру если нужно
99
+ if result.size != original_size and max(original_size) <= 1024:
100
+ result = result.resize(original_size, Image.Resampling.LANCZOS)
101
+
102
+ return result
103
+
104
+ def create_variations(self, image, num_variations=4):
105
+ """Создание вариаций дизайна"""
106
+ variations = []
107
+ base_seed = torch.randint(0, 1000000, (1,)).item()
108
+
109
+ # Ресайз для скорости
110
+ if image.width > 768 or image.height > 768:
111
+ image.thumbnail((768, 768), Image.Resampling.LANCZOS)
112
+
113
+ for i in range(num_variations):
114
+ torch.manual_seed(base_seed + i)
115
+
116
+ var = self.pipe(
117
+ prompt="interior design variation, same style, different details",
118
+ prompt_2="interior design variation, same style, different details",
119
+ image=image,
120
+ strength=0.4 + (i * 0.05),
121
+ num_inference_steps=20, # Меньше шагов для скорости
122
+ guidance_scale=6.0
123
+ ).images[0]
124
+
125
+ variations.append(var)
126
+
127
+ return variations
128
+
129
+ def create_hdr_lighting(self, image, intensity=0.3):
130
+ """Улучшение освещения в стиле HDR"""
131
+ # Конвертируем в numpy
132
+ img_array = np.array(image)
133
+
134
+ # Применяем CLAHE для улучшения контраста
135
+ lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
136
+ l, a, b = cv2.split(lab)
137
+
138
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
139
+ l_clahe = clahe.apply(l)
140
+
141
+ enhanced_lab = cv2.merge([l_clahe, a, b])
142
+ enhanced_rgb = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
143
+
144
+ # Смешиваем с оригиналом
145
+ result = cv2.addWeighted(img_array, 1-intensity, enhanced_rgb, intensity, 0)
146
+
147
+ return Image.fromarray(result)
148
+
149
+ def enhance_details(self, image):
150
+ """Улучшение деталей изображения"""
151
+ # Увеличиваем резкость
152
+ enhancer = ImageEnhance.Sharpness(image)
153
+ sharp = enhancer.enhance(1.5)
154
+
155
+ # Немного увеличиваем контраст
156
+ enhancer = ImageEnhance.Contrast(sharp)
157
+ contrast = enhancer.enhance(1.1)
158
+
159
+ return contrast
160
+
161
+ def change_element(self, image, element, value, strength=0.7):
162
+ """Изменение отдельного элемента интерьера"""
163
+ from design_styles import ROOM_ELEMENTS
164
+
165
+ # Ресайз для скорости
166
+ original_size = image.size
167
+ if image.width > 768 or image.height > 768:
168
+ image.thumbnail((768, 768), Image.Resampling.LANCZOS)
169
+
170
+ element_info = ROOM_ELEMENTS.get(element, {})
171
+ prompt_add = element_info.get("prompt_add", element.lower())
172
+
173
+ prompt = f"interior with {value} {prompt_add}, professional photo"
174
+ negative = f"old {element}, damaged, ugly"
175
+
176
+ result = self.pipe(
177
+ prompt=prompt,
178
+ prompt_2=prompt, # Для SDXL
179
+ negative_prompt=negative,
180
+ negative_prompt_2=negative,
181
+ image=image,
182
+ strength=min(strength, 0.8), # Ограничиваем для скорости
183
+ num_inference_steps=20, # Оптимизировано для H200
184
+ guidance_scale=6.0
185
+ ).images[0]
186
+
187
+ # Возвращаем к оригинальному размеру
188
+ if result.size != original_size:
189
+ result = result.resize(original_size, Image.Resampling.LANCZOS)
190
+
191
+ return result
192
+
193
+ def create_style_comparison(self, image, styles, quality="fast"):
194
+ """Создание сравнения стилей"""
195
+ results = []
196
+
197
+ # Настройки для быстрой генерации
198
+ steps = 15 if quality == "fast" else 20
199
+
200
+ for style in styles:
201
+ styled = self.apply_style_pro(
202
+ image,
203
+ style,
204
+ "living room", # default
205
+ strength=0.75,
206
+ quality=quality
207
+ )
208
+ results.append((style, styled))
209
+
210
+ return results
211
+
212
+
213
+ class ObjectRemover:
214
+ """Класс для удаления объектов - оптимизированный"""
215
+
216
+ def __init__(self, inpaint_pipe):
217
+ self.pipe = inpaint_pipe
218
+ self.device = torch.device("cuda")
219
+
220
+ def remove_objects(self, image, mask):
221
+ """Удаление объектов с изображения"""
222
+ if self.pipe is None:
223
+ # Fallback на простое заполнение
224
+ return self.simple_inpaint(image, mask)
225
+
226
+ try:
227
+ # Используем inpainting pipeline с оптимизированными параметрами
228
+ result = self.pipe(
229
+ prompt="empty room interior, clean wall, seamless texture",
230
+ negative_prompt="furniture, objects, people, clutter",
231
+ image=image,
232
+ mask_image=mask,
233
+ strength=0.95, # Немного меньше для скорости
234
+ num_inference_steps=25, # Оптимизировано!
235
+ guidance_scale=5.0 # Меньше для скорости
236
+ ).images[0]
237
+
238
+ return result
239
+ except Exception as e:
240
+ print(f"Inpainting failed: {e}, using OpenCV fallback")
241
+ return self.simple_inpaint(image, mask)
242
+
243
+ def simple_inpaint(self, image, mask):
244
+ """Простое заполнение через OpenCV - очень быстро"""
245
+ img_array = np.array(image)
246
+ mask_array = np.array(mask.convert('L'))
247
+
248
+ # Инпейнтинг через OpenCV
249
+ result = cv2.inpaint(img_array, mask_array, 3, cv2.INPAINT_TELEA)
250
+
251
+ return Image.fromarray(result)
252
+
253
+ def generate_mask_from_text(self, image, text_description, precision=0.3):
254
+ """Генерация маски на основе текстового описания"""
255
+ # Простая маска в центре (заглушка)
256
+ # В реальности тут должен быть CLIP или SAM
257
+ width, height = image.size
258
+ mask = Image.new('L', (width, height), 0)
259
+
260
+ # Создаем маску в центре
261
+ center_x, center_y = width // 2, height // 2
262
+ radius = int(min(width, height) * precision)
263
+
264
+ # Рисуем круг
265
+ from PIL import ImageDraw
266
+ draw = ImageDraw.Draw(mask)
267
+ draw.ellipse([center_x - radius, center_y - radius,
268
+ center_x + radius, center_y + radius], fill=255)
269
+
270
+ return mask