Bla1r commited on
Commit
da1a1f4
·
verified ·
1 Parent(s): 54e77cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -8,12 +8,18 @@ from torchvision import transforms
8
  import uuid
9
  import os
10
 
 
 
 
 
 
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
13
  birefnet = AutoModelForImageSegmentation.from_pretrained(
14
  "ZhengPeng7/BiRefNet", trust_remote_code=True
15
  )
16
- birefnet.to("cuda")
17
 
18
  transform_image = transforms.Compose(
19
  [
@@ -26,7 +32,7 @@ transform_image = transforms.Compose(
26
  @spaces.GPU
27
  def process(image):
28
  image_size = image.size
29
- input_images = transform_image(image).unsqueeze(0).to("cuda")
30
  # Prediction
31
  with torch.no_grad():
32
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
8
  import uuid
9
  import os
10
 
11
+ # Automatically select device based on availability
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Optional: Print which device is being used
15
+ print(f"Using device: {device}")
16
+
17
  torch.set_float32_matmul_precision(["high", "highest"][0])
18
 
19
  birefnet = AutoModelForImageSegmentation.from_pretrained(
20
  "ZhengPeng7/BiRefNet", trust_remote_code=True
21
  )
22
+ birefnet.to(device)
23
 
24
  transform_image = transforms.Compose(
25
  [
 
32
  @spaces.GPU
33
  def process(image):
34
  image_size = image.size
35
+ input_images = transform_image(image).unsqueeze(0).to(device)
36
  # Prediction
37
  with torch.no_grad():
38
  preds = birefnet(input_images)[-1].sigmoid().cpu()