merve (HF staff) committed
Commit 9e9e5c4 · Parent(s): 35abf13

Create app.py

Files changed (1): app.py (+63, -0)
app.py ADDED
@@ -0,0 +1,63 @@
from huggingface_hub import hf_hub_download

from transformers import pipeline
import torch
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import gradio as gr

# Fetch the SAM2 Hiera-Tiny checkpoint into the working directory.
hf_hub_download(repo_id="merve/sam2-hiera-tiny", filename="sam2_hiera_tiny.pt", local_dir="./")

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT = "./sam2_hiera_tiny.pt"
CONFIG = "sam2_hiera_t.yaml"

# Build SAM2 and wrap it in an image predictor for box-prompted segmentation.
sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
predictor = SAM2ImagePredictor(sam2_model)

# OWLv2 zero-shot detector: turns text labels into bounding boxes.
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=DEVICE)

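# The detector returns a list of dicts of the form
# {"score": float, "label": str, "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}},
# with box corners in absolute pixel coordinates; query() below passes each box
# to SAM2's predictor as an [x0, y0, x1, y1] prompt.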
def query(image, texts, threshold):
    texts = texts.split(",")

    # OWLv2: detect a box for each comma-separated candidate label.
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold
    )

    # The image only needs to be embedded once for all box prompts.
    predictor.set_image(image)

    result_labels = []
    for pred in predictions:
        label = pred["label"]
        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
               round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]

        # SAM2: segment the detected box; with multimask_output=False the
        # returned masks array has shape (1, H, W), so keep the single mask.
        masks, scores, logits = predictor.predict(box=box,
                                                  multimask_output=False)
        result_labels.append((masks[0], label))
    return image, result_labels

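# gr.AnnotatedImage (the output component below) renders each (mask, label) pair
# returned by query(); the mask should be a numpy array with the same height and
# width as the base image.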
description = "This Space combines OWLv2, the state-of-the-art zero-shot object detection model, with SAM2, the state-of-the-art mask generation model. SAM2 normally doesn't accept text input; combining it with OWLv2 makes SAM2 text-promptable. Try the example, or input an image and comma-separated candidate labels to segment."
demo = gr.Interface(
    query,
    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label="Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
    outputs="annotatedimage",
    title="OWLv2 🤝 SAMv2",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True
)
demo.launch(debug=True)
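
A quick way to sanity-check the OWLv2 → SAM2 handoff is to call query() directly rather than going through the UI. A minimal sketch, assuming the setup above has already run (checkpoint downloaded, models built) and that the bundled ./cats.png example image is available:

from PIL import Image

image = Image.open("./cats.png").convert("RGB")
annotated_source, labelled_masks = query(image, "cat", 0.1)
for mask, label in labelled_masks:
    print(label, mask.shape)  # one (H, W) mask per detected box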