Spaces: Runtime error
Commit: attempt 53

app.py CHANGED
@@ -8,6 +8,10 @@ from transformer_flux import FluxTransformer2DModel
 from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
 from PIL import Image, ImageDraw
 import numpy as np
+import subprocess
+
+subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 # Ensure that the minimal version of diffusers is installed
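Note: the new subprocess.run call shells out at import time to wipe the ZeroGPU offload cache (/data-nvme/zerogpu-offload/*) before any model weights are loaded. For reference, a pure-Python equivalent might look like the sketch below; the shutil/pathlib error handling is an assumption, not part of the Space:

    import shutil
    from pathlib import Path

    # Pure-Python equivalent of: rm -rf /data-nvme/zerogpu-offload/*
    offload_dir = Path("/data-nvme/zerogpu-offload")
    if offload_dir.is_dir():
        for entry in offload_dir.iterdir():
            if entry.is_dir():
                shutil.rmtree(entry, ignore_errors=True)  # remove subdirectories
            else:
                entry.unlink(missing_ok=True)             # remove plain files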
@@ -52,57 +56,83 @@ def create_mask_from_editor(editor_value):
     mask_image = Image.fromarray(mask_array)
     return mask_image
 
-def create_diptych_image(image, mask):
-
+def create_mask_on_image(image, xyxy):
+    """
+    Create a white mask on the image given xyxy coordinates.
+    Args:
+        image: PIL Image
+        xyxy: List of [x1, y1, x2, y2] coordinates
+    Returns:
+        PIL Image with white mask
+    """
+
+    # Convert to numpy array
+    img_array = np.array(image)
+
+    # Create mask
+    mask = Image.new('RGB', image.size, (0, 0, 0))
+    draw = ImageDraw.Draw(mask)
+
+    # Draw white rectangle
+    draw.rectangle(xyxy, fill=(255, 255, 255))
+
+    # Convert mask to array
+    mask_array = np.array(mask)
+
+    # Apply mask to image
+    masked_array = np.where(mask_array == 255, 255, img_array)
+
+    return Image.fromarray(mask_array), Image.fromarray(masked_array)
+
+def create_diptych_image(image):
+    # Create a diptych image with original on left and black on right
     width, height = image.size
     diptych = Image.new('RGB', (width * 2, height), 'black')
     diptych.paste(image, (0, 0))
-    diptych.paste(mask, (width, 0))
     return diptych
 
 @spaces.GPU()
 def inpaint_image(image, prompt, editor_value):
-    # …
-
-
-
-
-    mask = …
-
-    # Create diptych image
-    diptych_image = create_diptych_image(image, mask)
-
-    # Preprocess prompt and image for the pipeline
-    prompt = pipe.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to("cuda")
-    image_tensor = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
-    mask_tensor = pipe.feature_extractor(images=mask, return_tensors="pt").pixel_values.to("cuda")
-    control_image_tensor = pipe.feature_extractor(images=diptych_image, return_tensors="pt").pixel_values.to("cuda")
-
+    # Load image and mask
+    size = (1536, 768)
+    image = load_image(image).convert("RGB").resize((768, 768))
+    diptych_image = create_diptych_image(image)
+    # mask = load_image(mask_path).convert("RGB").resize(size)
+    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
+    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
     generator = torch.Generator(device="cuda").manual_seed(24)
+    # Load and preprocess image
 
     # Calculate attention scale mask
     attn_scale_factor = 1.5
-    size…
-    H, W = size[1]…
+    # Create a tensor of ones with same size as diptych image
+    H, W = size[1]//16, size[0]//16
     attn_scale_mask = torch.zeros(size[1], size[0])
-    attn_scale_mask[:, 768:] = 1.0
+    attn_scale_mask[:, 768:] = 1.0 # height, width
     attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
     attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)
+    # Get inverted attention mask by subtracting from 1.0
     transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
+
     cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
+
     cross_attn_region = cross_attn_region * attn_scale_factor
     cross_attn_region[cross_attn_region < 1.0] = 1.0
+
     full_attn_scale_mask = torch.ones(1, 24, 512+H*W, 512+H*W)
+
     full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
+    # Convert to bfloat16 to match model dtype
     full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
 
+
     # Inpaint
     result = pipe(
         prompt=prompt,
         height=size[1],
         width=size[0],
-        control_image=…
-        control_mask=…
+        control_image=diptych_image,
+        control_mask=mask,
         num_inference_steps=20,
         generator=generator,
         controlnet_conditioning_scale=0.95,
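Note: the attention-scale block above boosts cross-attention between the two halves of the diptych. After the logical_and, cross_attn_region marks image-token queries in the generated right half attending to keys in the reference left half; those entries are scaled by attn_scale_factor (1.5) while every other position stays at 1.0. A minimal shape check, assuming the values hard-coded in the diff (1536x768 diptych, /16 latent grid, 512 text tokens, 24 heads):

    import torch

    size = (1536, 768)                   # (width, height) of the diptych
    H, W = size[1] // 16, size[0] // 16  # 48 x 96 latent grid
    attn_scale_mask = torch.zeros(size[1], size[0])
    attn_scale_mask[:, 768:] = 1.0       # mark the right (generated) half
    attn_scale_mask = torch.nn.functional.interpolate(
        attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact'
    ).flatten()
    assert attn_scale_mask.shape == (H * W,)  # 4608 image tokens

    full_attn_scale_mask = torch.ones(1, 24, 512 + H * W, 512 + H * W)
    assert full_attn_scale_mask.shape == (1, 24, 5120, 5120)  # 512 text + 4608 image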
@@ -117,8 +147,8 @@ def inpaint_image(image, prompt, editor_value):
 iface = gr.Interface(
     fn=inpaint_image,
     inputs=[
-        gr.Image(type="…
-        gr.Textbox(lines=…
+        gr.Image(type="filepath", label="Upload Image"),
+        gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt"),
         gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", interactive=True)
     ],
     outputs=[
@@ -130,4 +160,4 @@ iface = gr.Interface(
 )
 
 # Launch the app
-iface.launch()
+iface.launch(share=True)
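Note: taken together, the changes switch the Space to diptych-based reference inpainting: the 768x768 input becomes the left half of a 1536x768 canvas, the right half is masked white, and the pipeline regenerates only that half from the prompt plus the boosted cross-attention to the reference. The two new helpers can be exercised standalone to inspect what the pipeline receives; a sketch, with input.png as a hypothetical local file:

    from PIL import Image

    # Reference on the left, black placeholder on the right.
    image = Image.open("input.png").convert("RGB").resize((768, 768))
    diptych_image = create_diptych_image(image)  # 1536x768 canvas

    # White-out the right half; only this region is inpainted.
    mask, masked_preview = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
    mask.save("mask.png")               # passed to pipe(control_mask=...)
    masked_preview.save("preview.png")  # diptych with the right half whited out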
|
| 8 |
from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
|
| 9 |
from PIL import Image, ImageDraw
|
| 10 |
import numpy as np
|
| 11 |
+
import subprocess
|
| 12 |
+
|
| 13 |
+
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
|
| 14 |
+
|
| 15 |
|
| 16 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 17 |
# Ensure that the minimal version of diffusers is installed
|
|
|
|
| 56 |
mask_image = Image.fromarray(mask_array)
|
| 57 |
return mask_image
|
| 58 |
|
| 59 |
+
def create_mask_on_image(image, xyxy):
|
| 60 |
+
"""
|
| 61 |
+
Create a white mask on the image given xyxy coordinates.
|
| 62 |
+
Args:
|
| 63 |
+
image: PIL Image
|
| 64 |
+
xyxy: List of [x1, y1, x2, y2] coordinates
|
| 65 |
+
Returns:
|
| 66 |
+
PIL Image with white mask
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# Convert to numpy array
|
| 70 |
+
img_array = np.array(image)
|
| 71 |
+
|
| 72 |
+
# Create mask
|
| 73 |
+
mask = Image.new('RGB', image.size, (0, 0, 0))
|
| 74 |
+
draw = ImageDraw.Draw(mask)
|
| 75 |
+
|
| 76 |
+
# Draw white rectangle
|
| 77 |
+
draw.rectangle(xyxy, fill=(255, 255, 255))
|
| 78 |
+
|
| 79 |
+
# Convert mask to array
|
| 80 |
+
mask_array = np.array(mask)
|
| 81 |
+
|
| 82 |
+
# Apply mask to image
|
| 83 |
+
masked_array = np.where(mask_array == 255, 255, img_array)
|
| 84 |
+
|
| 85 |
+
return Image.fromarray(mask_array), Image.fromarray(masked_array)
|
| 86 |
+
|
| 87 |
+
def create_diptych_image(image):
|
| 88 |
+
# Create a diptych image with original on left and black on right
|
| 89 |
width, height = image.size
|
| 90 |
diptych = Image.new('RGB', (width * 2, height), 'black')
|
| 91 |
diptych.paste(image, (0, 0))
|
|
|
|
| 92 |
return diptych
|
| 93 |
|
| 94 |
@spaces.GPU()
|
| 95 |
def inpaint_image(image, prompt, editor_value):
|
| 96 |
+
# Load image and mask
|
| 97 |
+
size = (1536, 768)
|
| 98 |
+
image = load_image(image).convert("RGB").resize((768, 768))
|
| 99 |
+
diptych_image = create_diptych_image(image)
|
| 100 |
+
# mask = load_image(mask_path).convert("RGB").resize(size)
|
| 101 |
+
# mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
|
| 102 |
+
mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
generator = torch.Generator(device="cuda").manual_seed(24)
|
| 104 |
+
# Load and preprocess image
|
| 105 |
|
| 106 |
# Calculate attention scale mask
|
| 107 |
attn_scale_factor = 1.5
|
| 108 |
+
# Create a tensor of ones with same size as diptych image
|
| 109 |
+
H, W = size[1]//16, size[0]//16
|
| 110 |
attn_scale_mask = torch.zeros(size[1], size[0])
|
| 111 |
+
attn_scale_mask[:, 768:] = 1.0 # height, width
|
| 112 |
attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
|
| 113 |
attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)
|
| 114 |
+
# Get inverted attention mask by subtracting from 1.0
|
| 115 |
transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
|
| 116 |
+
|
| 117 |
cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
|
| 118 |
+
|
| 119 |
cross_attn_region = cross_attn_region * attn_scale_factor
|
| 120 |
cross_attn_region[cross_attn_region < 1.0] = 1.0
|
| 121 |
+
|
| 122 |
full_attn_scale_mask = torch.ones(1, 24, 512+H*W, 512+H*W)
|
| 123 |
+
|
| 124 |
full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
|
| 125 |
+
# Convert to bfloat16 to match model dtype
|
| 126 |
full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
|
| 127 |
|
| 128 |
+
|
| 129 |
# Inpaint
|
| 130 |
result = pipe(
|
| 131 |
prompt=prompt,
|
| 132 |
height=size[1],
|
| 133 |
width=size[0],
|
| 134 |
+
control_image=diptych_image,
|
| 135 |
+
control_mask=mask,
|
| 136 |
num_inference_steps=20,
|
| 137 |
generator=generator,
|
| 138 |
controlnet_conditioning_scale=0.95,
|
|
|
|
| 147 |
iface = gr.Interface(
|
| 148 |
fn=inpaint_image,
|
| 149 |
inputs=[
|
| 150 |
+
gr.Image(type="filepath", label="Upload Image"),
|
| 151 |
+
gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt"),
|
| 152 |
gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", interactive=True)
|
| 153 |
],
|
| 154 |
outputs=[
|
|
|
|
| 160 |
)
|
| 161 |
|
| 162 |
# Launch the app
|
| 163 |
+
iface.launch(share=True)
|