LPX55 committed
Commit d81b69d · 1 Parent(s): a739682

attempt 53

Files changed (1): app.py (+57 -27)
app.py CHANGED
@@ -8,6 +8,10 @@ from transformer_flux import FluxTransformer2DModel
 from pipeline_flux_cnet import FluxControlNetInpaintingPipeline
 from PIL import Image, ImageDraw
 import numpy as np
+import subprocess
+
+subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+
 
 HF_TOKEN = os.getenv("HF_TOKEN")
 # Ensure that the minimal version of diffusers is installed
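
The four added lines purge the ZeroGPU offload cache with a shell `rm -rf` at import time. For reference, a stdlib-only sketch of an equivalent cleanup, assuming the same /data-nvme/zerogpu-offload mount as in the diff (pathlib/shutil sidestep shell=True and glob quoting):

import shutil
from pathlib import Path

OFFLOAD_DIR = Path("/data-nvme/zerogpu-offload")  # path taken from the diff above

if OFFLOAD_DIR.exists():
    for entry in OFFLOAD_DIR.iterdir():
        if entry.is_dir():
            shutil.rmtree(entry, ignore_errors=True)  # remove cached directories
        else:
            entry.unlink(missing_ok=True)  # remove cached files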
@@ -52,57 +56,83 @@ def create_mask_from_editor(editor_value):
     mask_image = Image.fromarray(mask_array)
     return mask_image
 
-def create_diptych_image(image, mask):
-    # Create a diptych image with original on left and masked on right
+def create_mask_on_image(image, xyxy):
+    """
+    Create a white mask on the image given xyxy coordinates.
+    Args:
+        image: PIL Image
+        xyxy: List of [x1, y1, x2, y2] coordinates
+    Returns:
+        PIL Image with white mask
+    """
+
+    # Convert to numpy array
+    img_array = np.array(image)
+
+    # Create mask
+    mask = Image.new('RGB', image.size, (0, 0, 0))
+    draw = ImageDraw.Draw(mask)
+
+    # Draw white rectangle
+    draw.rectangle(xyxy, fill=(255, 255, 255))
+
+    # Convert mask to array
+    mask_array = np.array(mask)
+
+    # Apply mask to image
+    masked_array = np.where(mask_array == 255, 255, img_array)
+
+    return Image.fromarray(mask_array), Image.fromarray(masked_array)
+
+def create_diptych_image(image):
+    # Create a diptych image with original on left and black on right
     width, height = image.size
     diptych = Image.new('RGB', (width * 2, height), 'black')
     diptych.paste(image, (0, 0))
-    diptych.paste(mask, (width, 0))
     return diptych
 
 @spaces.GPU()
 def inpaint_image(image, prompt, editor_value):
-    # Create mask from editor value
-    mask = create_mask_from_editor(editor_value)
-
-    # Load and preprocess image
-    image = image.convert("RGB").resize((768, 768))
-    mask = mask.convert("L").resize((768, 768)) # Convert mask to single channel (grayscale)
-
-    # Create diptych image
-    diptych_image = create_diptych_image(image, mask)
-
-    # Preprocess prompt and image for the pipeline
-    prompt = pipe.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to("cuda")
-    image_tensor = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
-    mask_tensor = pipe.feature_extractor(images=mask, return_tensors="pt").pixel_values.to("cuda")
-    control_image_tensor = pipe.feature_extractor(images=diptych_image, return_tensors="pt").pixel_values.to("cuda")
-
+    # Load image and mask
+    size = (1536, 768)
+    image = load_image(image).convert("RGB").resize((768, 768))
+    diptych_image = create_diptych_image(image)
+    # mask = load_image(mask_path).convert("RGB").resize(size)
+    # mask, mask_image = create_mask_on_image(image, [250, 275, 500, 400])
+    mask, mask_image = create_mask_on_image(diptych_image, [768, 0, 1536, 768])
     generator = torch.Generator(device="cuda").manual_seed(24)
+    # Load and preprocess image
 
     # Calculate attention scale mask
     attn_scale_factor = 1.5
-    size = (1536, 768)
-    H, W = size[1] // 16, size[0] // 16
+    # Create a tensor of ones with same size as diptych image
+    H, W = size[1]//16, size[0]//16
     attn_scale_mask = torch.zeros(size[1], size[0])
-    attn_scale_mask[:, 768:] = 1.0 # height, width
+    attn_scale_mask[:, 768:] = 1.0 # height, width
     attn_scale_mask = torch.nn.functional.interpolate(attn_scale_mask[None, None, :, :], (H, W), mode='nearest-exact').flatten()
     attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H*W)
+    # Get inverted attention mask by subtracting from 1.0
     transposed_inverted_attn_scale_mask = (1.0 - attn_scale_mask).transpose(-1, -2)
+
     cross_attn_region = torch.logical_and(attn_scale_mask, transposed_inverted_attn_scale_mask)
+
     cross_attn_region = cross_attn_region * attn_scale_factor
     cross_attn_region[cross_attn_region < 1.0] = 1.0
+
     full_attn_scale_mask = torch.ones(1, 24, 512+H*W, 512+H*W)
+
     full_attn_scale_mask[:, :, 512:, 512:] = cross_attn_region
+    # Convert to bfloat16 to match model dtype
     full_attn_scale_mask = full_attn_scale_mask.to(device=pipe.transformer.device, dtype=torch.bfloat16)
 
+
     # Inpaint
     result = pipe(
         prompt=prompt,
         height=size[1],
         width=size[0],
-        control_image=control_image_tensor,
-        control_mask=mask_tensor,
+        control_image=diptych_image,
+        control_mask=mask,
         num_inference_steps=20,
         generator=generator,
         controlnet_conditioning_scale=0.95,
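
The attention-scale block scales attention across the two panels of the diptych: the 1536x768 canvas maps to a 96x48 latent grid (stride 16), i.e. H*W = 4608 image tokens, prepended by 512 text tokens, giving a 5120x5120 bias over 24 heads. A standalone sketch of just this mask construction (shapes only, no model; assumes the same diptych size and head count as above). At full resolution these tensors take several GB of RAM, so scale `size` down proportionally for a cheap sanity check:

import torch

size = (1536, 768)                    # diptych (width, height)
H, W = size[1] // 16, size[0] // 16   # latent grid: 48 x 96, H*W = 4608 tokens
attn_scale_factor = 1.5

# Mark the right (to-be-generated) panel in pixel space, then pool to the grid
attn_scale_mask = torch.zeros(size[1], size[0])
attn_scale_mask[:, 768:] = 1.0
attn_scale_mask = torch.nn.functional.interpolate(
    attn_scale_mask[None, None, :, :], (H, W), mode="nearest-exact"
).flatten()
attn_scale_mask = attn_scale_mask[None, None, :, None].repeat(1, 24, 1, H * W)

# Scale only where right-panel queries attend to left-panel keys
inverted_t = (1.0 - attn_scale_mask).transpose(-1, -2)
cross = torch.logical_and(attn_scale_mask, inverted_t) * attn_scale_factor
cross[cross < 1.0] = 1.0

# 512 text tokens prepended to the image tokens
full = torch.ones(1, 24, 512 + H * W, 512 + H * W)
full[:, :, 512:, 512:] = cross
print(full.shape)  # torch.Size([1, 24, 5120, 5120])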
@@ -117,8 +147,8 @@ def inpaint_image(image, prompt, editor_value):
 iface = gr.Interface(
     fn=inpaint_image,
     inputs=[
-        gr.Image(type="pil", label="Upload Image"),
-        gr.Textbox(lines=1, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt"),
+        gr.Image(type="filepath", label="Upload Image"),
+        gr.Textbox(lines=2, placeholder="Enter your prompt here (e.g., 'wearing a christmas hat, in a busy street')", label="Prompt"),
         gr.ImageEditor(type="pil", label="Image with Mask", sources="upload", interactive=True)
     ],
     outputs=[
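
gr.Image switches from type="pil" to type="filepath", matching the handler's new load_image(image) call: Gradio now passes a path string rather than a PIL image. A minimal sketch of that handler-side contract, assuming load_image comes from diffusers.utils (its import sits outside this hunk):

from diffusers.utils import load_image

def prepare_input(image_path):
    # With gr.Image(type="filepath"), Gradio hands the handler a str path;
    # load_image accepts a local path or URL and returns a PIL image.
    return load_image(image_path).convert("RGB").resize((768, 768))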
@@ -130,4 +160,4 @@ iface = gr.Interface(
 )
 
 # Launch the app
-iface.launch()
+iface.launch(share=True)
 
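Net effect of the commit: the editor-drawn mask is replaced by a fixed diptych layout. The upload fills the left panel, the right panel starts black, and the control mask whites out the entire right half so the pipeline repaints it from the prompt. An offline sketch of that flow with a dummy image, reusing the two helpers introduced above:

from PIL import Image, ImageDraw
import numpy as np

def create_diptych_image(image):
    # Original on the left, black placeholder on the right
    width, height = image.size
    diptych = Image.new('RGB', (width * 2, height), 'black')
    diptych.paste(image, (0, 0))
    return diptych

def create_mask_on_image(image, xyxy):
    # White rectangle over xyxy; returns (binary mask, image with region whited out)
    img_array = np.array(image)
    mask = Image.new('RGB', image.size, (0, 0, 0))
    ImageDraw.Draw(mask).rectangle(xyxy, fill=(255, 255, 255))
    mask_array = np.array(mask)
    masked_array = np.where(mask_array == 255, 255, img_array)
    return Image.fromarray(mask_array), Image.fromarray(masked_array)

source = Image.new('RGB', (768, 768), 'gray')   # stand-in for the upload
diptych = create_diptych_image(source)          # (1536, 768), right half black
mask, masked = create_mask_on_image(diptych, [768, 0, 1536, 768])
print(diptych.size, mask.size, masked.size)     # all (1536, 768)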