julien-c HF staff commited on
Commit
35143eb
·
verified ·
1 Parent(s): cf32f12
Files changed (1) hide show
  1. handler.py +12 -4
handler.py CHANGED
@@ -1,11 +1,9 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- from torch import autocast
4
  from huggingface_hub import hf_hub_download
5
  from diffusers import DiffusionPipeline
6
- import base64
7
- from io import BytesIO
8
  from safetensors.torch import load_file
 
9
 
10
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -43,6 +41,14 @@ class EndpointHandler:
43
  tokenizer=self.pipe.tokenizer_2,
44
  )
45
 
 
 
 
 
 
 
 
 
46
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
47
  """
48
  Args:
@@ -57,7 +63,9 @@ class EndpointHandler:
57
  images = self.pipe(inputs, **data["parameters"]).images
58
  image = images[0]
59
 
60
- return image
 
 
61
 
62
 
63
  if __name__ == "__main__":
 
1
  from typing import Dict, List, Any
2
  import torch
 
3
  from huggingface_hub import hf_hub_download
4
  from diffusers import DiffusionPipeline
 
 
5
  from safetensors.torch import load_file
6
+ from transformers import pipeline
7
 
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
41
  tokenizer=self.pipe.tokenizer_2,
42
  )
43
 
44
+ self.remove_bg = pipeline(
45
+ "image-segmentation",
46
+ model="briaai/RMBG-1.4",
47
+ device=device,
48
+ revision="22532afbdabdc36b2d30a334076720ac72a06f83",
49
+ trust_remote_code=True,
50
+ )
51
+
52
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
53
  """
54
  Args:
 
63
  images = self.pipe(inputs, **data["parameters"]).images
64
  image = images[0]
65
 
66
+ image_no_bg = self.remove_bg(image)
67
+
68
+ return image_no_bg
69
 
70
 
71
  if __name__ == "__main__":