egrace479 Jianyang Gu commited on
Commit
1e72477
·
1 Parent(s): cd0baf2

revert back to local inference code

Browse files

Co-authored-by: Jianyang Gu <[email protected]>

app.py CHANGED
@@ -11,9 +11,10 @@ import torch
11
  import torch.nn.functional as F
12
  from open_clip import create_model, get_tokenizer
13
  from torchvision import transforms
 
14
 
 
15
  from components.query import get_sample
16
- from bioclip import CustomLabelsClassifier
17
 
18
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -27,16 +28,16 @@ METADATA_PATH = "components/metadata.parquet"
27
  metadata_df = pl.read_parquet(METADATA_PATH, low_memory = False)
28
  metadata_df = metadata_df.with_columns(pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64))
29
 
30
- MODEL_STR = "hf-hub:imageomics/bioclip-2"
31
- TOKENIZER_STR = "ViT-L-14"
 
32
 
33
- txt_emb_npy = "https://huggingface.co/datasets/imageomics/TreeOfLife-200M/resolve/main/embeddings/txt_emb_species.npy"
34
- txt_names_json = "embeddings/txt_emb_species.json"
35
 
36
  min_prob = 1e-9
37
  k = 5
38
 
39
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
 
41
  preprocess_img = transforms.Compose(
42
  [
@@ -52,41 +53,45 @@ preprocess_img = transforms.Compose(
52
  ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
53
 
54
  open_domain_examples = [
55
- ["examples/Ursus-arctos.jpeg", "Species"],
56
- ["examples/Phoca-vitulina.png", "Species"],
57
- ["examples/Felis-catus.jpeg", "Genus"],
58
- ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
 
 
59
  ]
60
  zero_shot_examples = [
61
  [
62
- "examples/Ursus-arctos.jpeg",
63
- "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
64
  ],
65
- ["examples/milk-snake.png", "coral snake\nmilk snake"],
66
- ["examples/coral-snake.jpeg", "coral snake\nmilk snake"],
67
  [
68
- "examples/Carnegiea-gigantea.png",
69
- "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
 
 
 
 
70
  ],
71
  [
72
- "examples/Amanita-muscaria.jpeg",
73
- "Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
74
  ],
75
  [
76
- "examples/Actinostola-abyssorum.png",
77
- "Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
78
  ],
79
  [
80
- "examples/Sarcoscypha-coccinea.jpeg",
81
- "scarlet elf cup (coccinea)\nscharlachroter kelchbecherling (austriaca)\ncrimson cup (dudleyi)\nstalked scarlet cup (occidentalis)",
82
  ],
83
  [
84
- "examples/Onoclea-hintonii.jpg",
85
- "Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
86
  ],
87
  [
88
- "examples/Onoclea-sensibilis.jpg",
89
- "Onoclea attenuata\nOnoclea boryana\nOnoclea hintonii\nOnoclea intermedia\nOnoclea sensibilis",
90
  ],
91
  ]
92
 
@@ -95,13 +100,32 @@ def indexed(lst, indices):
95
  return [lst[i] for i in indices]
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
99
  classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
100
- classifier = CustomLabelsClassifier(
101
- cls_ary = classes,
102
- model_str = MODEL_STR, # remove this line once pybioclip uses BioCLIP 2
103
- )
104
- return classifier.predict(img)
 
 
 
 
105
 
106
 
107
  def format_name(taxon, common):
@@ -165,16 +189,20 @@ def change_output(choice):
165
 
166
  if __name__ == "__main__":
167
  logger.info("Starting.")
168
- model = create_model(MODEL_STR, output_dict=True, require_pretrained=True)
169
  model = model.to(device)
170
  logger.info("Created model.")
171
 
172
  model = torch.compile(model)
173
  logger.info("Compiled model.")
174
 
175
- tokenizer = get_tokenizer(TOKENIZER_STR)
176
 
177
- txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
 
 
 
 
178
  with open(txt_names_json) as fd:
179
  txt_names = json.load(fd)
180
 
 
11
  import torch.nn.functional as F
12
  from open_clip import create_model, get_tokenizer
13
  from torchvision import transforms
14
+ from huggingface_hub import hf_hub_download
15
 
16
+ from components.templates import openai_imagenet_template
17
  from components.query import get_sample
 
18
 
19
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
20
  logging.basicConfig(level=logging.INFO, format=log_format)
 
28
  metadata_df = pl.read_parquet(METADATA_PATH, low_memory = False)
29
  metadata_df = metadata_df.with_columns(pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64))
30
 
31
+ model_str = "hf-hub:imageomics/bioclip-2"
32
+ tokenizer_str = "ViT-L-14"
33
+ HF_DATA_STR = "imageomics/TreeOfLife-200M"
34
 
35
+ txt_names_json = "components/txt_emb_species.json"
 
36
 
37
  min_prob = 1e-9
38
  k = 5
39
 
40
+ device = torch.device("cpu")
41
 
42
  preprocess_img = transforms.Compose(
43
  [
 
53
  ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
54
 
55
  open_domain_examples = [
56
+ ["examples/Carcharhinus-melanopterus.jpg", "Species"],
57
+ ["examples/house-finch.jpeg", "Species"],
58
+ ["examples/Bovidae-Oryx.jpg", "Genus"],
59
+ ["examples/Cebidae-Cebus.jpg", "Genus"],
60
+ ["examples/Solanales-Petunia.png", "Genus"],
61
+ ["examples/Asparagales-Orchidaceae.jpg", "Family"],
62
  ]
63
  zero_shot_examples = [
64
  [
65
+ "examples/Cortinarius-austroalbidus.jpg",
66
+ "Cortinarius austroalbidus\nCortinarius armillatus\nCortinarius caperatus"
67
  ],
 
 
68
  [
69
+ "examples/leopard.jpg",
70
+ "Jaguar\nLeopard\nCheetah",
71
+ ],
72
+ [
73
+ "examples/jaguar.jpg",
74
+ "Jaguar\nLeopard\nCheetah",
75
  ],
76
  [
77
+ "examples/cheetah.jpg",
78
+ "Jaguar\nLeopard\nCheetah",
79
  ],
80
  [
81
+ "examples/monarch.jpg",
82
+ "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
83
  ],
84
  [
85
+ "examples/viceroy.jpg",
86
+ "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
87
  ],
88
  [
89
+ "examples/Ursus-arctos.jpeg",
90
+ "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
91
  ],
92
  [
93
+ "examples/Carnegiea-gigantea.png",
94
+ "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
95
  ],
96
  ]
97
 
 
100
  return [lst[i] for i in indices]
101
 
102
 
103
+ @torch.no_grad()
104
+ def get_txt_features(classnames, templates):
105
+ all_features = []
106
+ for classname in classnames:
107
+ txts = [template(classname) for template in templates]
108
+ txts = tokenizer(txts).to(device)
109
+ txt_features = model.encode_text(txts)
110
+ txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
111
+ txt_features /= txt_features.norm()
112
+ all_features.append(txt_features)
113
+ all_features = torch.stack(all_features, dim=1)
114
+ return all_features
115
+
116
+
117
+ @torch.no_grad()
118
  def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
119
  classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
120
+ txt_features = get_txt_features(classes, openai_imagenet_template)
121
+
122
+ img = preprocess_img(img).to(device)
123
+ img_features = model.encode_image(img.unsqueeze(0))
124
+ img_features = F.normalize(img_features, dim=-1)
125
+
126
+ logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
127
+ probs = F.softmax(logits, dim=0).to("cpu").tolist()
128
+ return {cls: prob for cls, prob in zip(classes, probs)}
129
 
130
 
131
  def format_name(taxon, common):
 
189
 
190
  if __name__ == "__main__":
191
  logger.info("Starting.")
192
+ model = create_model(model_str, output_dict=True, require_pretrained=True)
193
  model = model.to(device)
194
  logger.info("Created model.")
195
 
196
  model = torch.compile(model)
197
  logger.info("Compiled model.")
198
 
199
+ tokenizer = get_tokenizer(tokenizer_str)
200
 
201
+ txt_emb = torch.from_numpy(np.load(hf_hub_download(
202
+ repo_id=HF_DATA_STR,
203
+ filename="embeddings/txt_emb_species.npy",
204
+ repo_type="dataset",
205
+ )))
206
  with open(txt_names_json) as fd:
207
  txt_names = json.load(fd)
208
 
components/templates.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai_imagenet_template = [
2
+ lambda c: f"a bad photo of a {c}.",
3
+ lambda c: f"a photo of many {c}.",
4
+ lambda c: f"a sculpture of a {c}.",
5
+ lambda c: f"a photo of the hard to see {c}.",
6
+ lambda c: f"a low resolution photo of the {c}.",
7
+ lambda c: f"a rendering of a {c}.",
8
+ lambda c: f"graffiti of a {c}.",
9
+ lambda c: f"a bad photo of the {c}.",
10
+ lambda c: f"a cropped photo of the {c}.",
11
+ lambda c: f"a tattoo of a {c}.",
12
+ lambda c: f"the embroidered {c}.",
13
+ lambda c: f"a photo of a hard to see {c}.",
14
+ lambda c: f"a bright photo of a {c}.",
15
+ lambda c: f"a photo of a clean {c}.",
16
+ lambda c: f"a photo of a dirty {c}.",
17
+ lambda c: f"a dark photo of the {c}.",
18
+ lambda c: f"a drawing of a {c}.",
19
+ lambda c: f"a photo of my {c}.",
20
+ lambda c: f"the plastic {c}.",
21
+ lambda c: f"a photo of the cool {c}.",
22
+ lambda c: f"a close-up photo of a {c}.",
23
+ lambda c: f"a black and white photo of the {c}.",
24
+ lambda c: f"a painting of the {c}.",
25
+ lambda c: f"a painting of a {c}.",
26
+ lambda c: f"a pixelated photo of the {c}.",
27
+ lambda c: f"a sculpture of the {c}.",
28
+ lambda c: f"a bright photo of the {c}.",
29
+ lambda c: f"a cropped photo of a {c}.",
30
+ lambda c: f"a plastic {c}.",
31
+ lambda c: f"a photo of the dirty {c}.",
32
+ lambda c: f"a jpeg corrupted photo of a {c}.",
33
+ lambda c: f"a blurry photo of the {c}.",
34
+ lambda c: f"a photo of the {c}.",
35
+ lambda c: f"a good photo of the {c}.",
36
+ lambda c: f"a rendering of the {c}.",
37
+ lambda c: f"a {c} in a video game.",
38
+ lambda c: f"a photo of one {c}.",
39
+ lambda c: f"a doodle of a {c}.",
40
+ lambda c: f"a close-up photo of the {c}.",
41
+ lambda c: f"a photo of a {c}.",
42
+ lambda c: f"the origami {c}.",
43
+ lambda c: f"the {c} in a video game.",
44
+ lambda c: f"a sketch of a {c}.",
45
+ lambda c: f"a doodle of the {c}.",
46
+ lambda c: f"a origami {c}.",
47
+ lambda c: f"a low resolution photo of a {c}.",
48
+ lambda c: f"the toy {c}.",
49
+ lambda c: f"a rendition of the {c}.",
50
+ lambda c: f"a photo of the clean {c}.",
51
+ lambda c: f"a photo of a large {c}.",
52
+ lambda c: f"a rendition of a {c}.",
53
+ lambda c: f"a photo of a nice {c}.",
54
+ lambda c: f"a photo of a weird {c}.",
55
+ lambda c: f"a blurry photo of a {c}.",
56
+ lambda c: f"a cartoon {c}.",
57
+ lambda c: f"art of a {c}.",
58
+ lambda c: f"a sketch of the {c}.",
59
+ lambda c: f"a embroidered {c}.",
60
+ lambda c: f"a pixelated photo of a {c}.",
61
+ lambda c: f"itap of the {c}.",
62
+ lambda c: f"a jpeg corrupted photo of the {c}.",
63
+ lambda c: f"a good photo of a {c}.",
64
+ lambda c: f"a plushie {c}.",
65
+ lambda c: f"a photo of the nice {c}.",
66
+ lambda c: f"a photo of the small {c}.",
67
+ lambda c: f"a photo of the weird {c}.",
68
+ lambda c: f"the cartoon {c}.",
69
+ lambda c: f"art of the {c}.",
70
+ lambda c: f"a drawing of the {c}.",
71
+ lambda c: f"a photo of the large {c}.",
72
+ lambda c: f"a black and white photo of a {c}.",
73
+ lambda c: f"the plushie {c}.",
74
+ lambda c: f"a dark photo of a {c}.",
75
+ lambda c: f"itap of a {c}.",
76
+ lambda c: f"graffiti of the {c}.",
77
+ lambda c: f"a toy {c}.",
78
+ lambda c: f"itap of my {c}.",
79
+ lambda c: f"a photo of a cool {c}.",
80
+ lambda c: f"a photo of a small {c}.",
81
+ lambda c: f"a tattoo of the {c}.",
82
+ ]
components/txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a81b2931330d7e0e5cf1e9a96982d7eed4ac187b08ad99533c9dad523f5b4f4
3
+ size 110609010