Gemini899 committed (verified)
Commit fb42b12 · 1 Parent(s): a279f70

Update flux1_img2img.py

Files changed (1)
  1. flux1_img2img.py +20 -10
flux1_img2img.py CHANGED
@@ -4,32 +4,42 @@ from PIL import Image
 import sys
 import spaces
 
-# I only test with FLUX.1-schnell
+# Tested with FLUX.1-schnell
 
 @spaces.GPU
 def process_image(image, mask_image, prompt="a person", model_id="black-forest-labs/FLUX.1-schnell", strength=0.75, seed=0, num_inference_steps=4):
-    print("start process image process_image")
+    print("Starting process_image")
     if image is None:
-        print("empty input image returned")
+        print("Empty input image returned.")
         return None
 
-    # Ensure image is in RGB mode (helps with WebP and other formats)
+    # Ensure the image is in RGB mode (this handles formats like WebP and JFIF)
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    pipe = FluxImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-    pipe.to("cuda")
+    # If needed, add use_auth_token="YOUR_TOKEN" in from_pretrained below.
+    pipe = FluxImg2ImgPipeline.from_pretrained(
+        model_id,
+        torch_dtype=torch.bfloat16
+    ).to("cuda")
 
     generator = torch.Generator("cuda").manual_seed(seed)
     print(prompt)
-    output = pipe(prompt=prompt, image=image, generator=generator, strength=strength,
-                  guidance_scale=0, num_inference_steps=num_inference_steps, max_sequence_length=256)
+    output = pipe(
+        prompt=prompt,
+        image=image,
+        generator=generator,
+        strength=strength,
+        guidance_scale=0,
+        num_inference_steps=num_inference_steps,
+        max_sequence_length=256
+    )
 
-    # TODO: support mask
+    # TODO: Add mask support if needed
     return output.images[0]
 
 if __name__ == "__main__":
-    # args: input-image input-mask output
+    # Usage: python flux1_img2img.py input-image input-mask output
     image = Image.open(sys.argv[1])
     mask = Image.open(sys.argv[2])
     output = process_image(image, mask)
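
Note on trying the updated file: the hunk starts at line 4, so the file's leading imports are not shown above. The sketch below is not part of the commit; it assumes the missing imports are torch, FluxImg2ImgPipeline from diffusers, and PIL.Image, and it adds a hypothetical save of the result (the committed __main__ block mentions an output argument in its usage comment but never writes one). Function and file names here are illustrative only.

# Sketch only, not part of this commit.
# Assumed imports (the real lines 1-3 of flux1_img2img.py are outside the hunk).
import sys

import torch
from diffusers import FluxImg2ImgPipeline
from PIL import Image


def img2img(input_path, output_path, prompt="a person", strength=0.75, seed=0):
    # Load and normalize the input, mirroring the convert("RGB") step in the diff.
    image = Image.open(input_path).convert("RGB")

    # FLUX.1-schnell is a distilled model, hence guidance_scale=0 and few steps.
    pipe = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    generator = torch.Generator("cuda").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        image=image,
        generator=generator,
        strength=strength,
        guidance_scale=0,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]

    # Hypothetical addition: persist the result, which the committed script does not do yet.
    result.save(output_path)


if __name__ == "__main__":
    # Example invocation: python img2img_sketch.py input.png output.png
    img2img(sys.argv[1], sys.argv[2])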