Samuel Stevens commited on
Commit
24ce802
·
1 Parent(s): c303125

Update requirements

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +16 -16
  3. pyproject.toml +1 -0
  4. requirements.txt +28 -4
  5. uv.lock +118 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .aider*
2
  .env
 
 
1
  .aider*
2
  .env
3
+ .venv/
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import io
2
  import json
3
  import logging
@@ -46,7 +47,7 @@ vit_ckpt = "ViT-B-16/openai"
46
  n_patches_per_img: int = 196
47
  """Number of patches per image in vit_ckpt."""
48
 
49
- max_frequency = 1e-2
50
  """Maximum frequency. Any feature that fires more than this is ignored."""
51
 
52
  CWD = pathlib.Path(__file__).parent
@@ -93,13 +94,17 @@ def load_model(fpath: str | pathlib.Path, *, device: str = "cpu") -> torch.nn.Mo
93
 
94
 
95
  @beartype.beartype
 
96
  def get_dataset_img(i: int) -> Image.Image:
97
  return Image.open(requests.get(r2_url + image_fpaths[i], stream=True).raw)
98
 
99
 
100
  @beartype.beartype
101
  def make_img(
102
- img: Image.Image, patches: Float[Tensor, ""], *, upper: float | None = None
 
 
 
103
  ) -> Image.Image:
104
  # Resize to 256x256 and crop to 224x224
105
  resize_size_px = (512, 512)
@@ -223,16 +228,6 @@ with open(CWD / "data" / "image_labels.json") as fd:
223
  image_labels = json.load(fd)
224
 
225
 
226
- # TODO:
227
- # This dataset needs to be the CUB2011 dataset. But that means we need to calculate top_img_i based on CUB2011, not on iNat21 train-mini.
228
- # examples_dataset = saev.activations.ImageFolder(
229
- # "/research/nfs_su_809/workspace/stevens.994/datasets/inat21/train_mini",
230
- # transform=v2.Compose([
231
- # v2.Resize(size=(512, 512)),
232
- # v2.CenterCrop(size=(448, 448)),
233
- # ]),
234
- # )
235
-
236
  logger.info("Loaded all datasets.")
237
 
238
  #############
@@ -249,6 +244,7 @@ top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
249
  top_values = load_tensor(CWD / "data" / "top_values.pt")
250
  sparsity = load_tensor(CWD / "data" / "sparsity.pt")
251
 
 
252
  mask = torch.ones((sae.cfg.d_sae), dtype=bool)
253
  mask = mask & (sparsity < max_frequency)
254
 
@@ -285,10 +281,13 @@ def get_sae_examples(
285
  if not patches:
286
  return [None] * 12 + [-1] * 3
287
 
 
 
288
  img = get_dataset_img(image_i)
289
  x = vit_transform(img)[None, ...].to(device)
290
  x_BPD = split_vit.forward_start(x)
291
- vit_acts_MD = x_BPD[0, patches].to(device)
 
292
 
293
  _, f_x_MS, _ = sae(vit_acts_MD)
294
  f_x_S = f_x_MS.sum(axis=0)
@@ -303,9 +302,8 @@ def get_sae_examples(
303
  if i_im in seen_i_im:
304
  continue
305
 
306
- # example = examples_dataset[i_im]
307
- example = None
308
- img_patch_pairs.append((example["image"], values_p))
309
  seen_i_im.add(i_im)
310
 
311
  # How to scale values.
@@ -492,6 +490,7 @@ with gr.Blocks() as demo:
492
  inputs=[image_number, patch_numbers],
493
  outputs=sae_example_images + top_latent_numbers,
494
  api_name="get-sae-examples",
 
495
  )
496
 
497
  pred_dist = gr.Label(label="Pred. Dist.")
@@ -513,6 +512,7 @@ with gr.Blocks() as demo:
513
  inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
514
  outputs=[pred_dist],
515
  api_name="get-modified",
 
516
  )
517
 
518
 
 
1
+ import functools
2
  import io
3
  import json
4
  import logging
 
47
  n_patches_per_img: int = 196
48
  """Number of patches per image in vit_ckpt."""
49
 
50
+ max_frequency = 1e-1
51
  """Maximum frequency. Any feature that fires more than this is ignored."""
52
 
53
  CWD = pathlib.Path(__file__).parent
 
94
 
95
 
96
  @beartype.beartype
97
+ @functools.lru_cache(maxsize=512)
98
  def get_dataset_img(i: int) -> Image.Image:
99
  return Image.open(requests.get(r2_url + image_fpaths[i], stream=True).raw)
100
 
101
 
102
  @beartype.beartype
103
  def make_img(
104
+ img: Image.Image,
105
+ patches: Float[Tensor, " n_patches"],
106
+ *,
107
+ upper: float | None = None,
108
  ) -> Image.Image:
109
  # Resize to 256x256 and crop to 224x224
110
  resize_size_px = (512, 512)
 
228
  image_labels = json.load(fd)
229
 
230
 
 
 
 
 
 
 
 
 
 
 
231
  logger.info("Loaded all datasets.")
232
 
233
  #############
 
244
  top_values = load_tensor(CWD / "data" / "top_values.pt")
245
  sparsity = load_tensor(CWD / "data" / "sparsity.pt")
246
 
247
+
248
  mask = torch.ones((sae.cfg.d_sae), dtype=bool)
249
  mask = mask & (sparsity < max_frequency)
250
 
 
281
  if not patches:
282
  return [None] * 12 + [-1] * 3
283
 
284
+ logger.info("Getting SAE examples for patches %s.", patches)
285
+
286
  img = get_dataset_img(image_i)
287
  x = vit_transform(img)[None, ...].to(device)
288
  x_BPD = split_vit.forward_start(x)
289
+ # Need to add 1 to account for [CLS] token.
290
+ vit_acts_MD = x_BPD[0, [p + 1 for p in patches]].to(device)
291
 
292
  _, f_x_MS, _ = sae(vit_acts_MD)
293
  f_x_S = f_x_MS.sum(axis=0)
 
302
  if i_im in seen_i_im:
303
  continue
304
 
305
+ example_img = get_dataset_img(i_im)
306
+ img_patch_pairs.append((example_img, values_p))
 
307
  seen_i_im.add(i_im)
308
 
309
  # How to scale values.
 
490
  inputs=[image_number, patch_numbers],
491
  outputs=sae_example_images + top_latent_numbers,
492
  api_name="get-sae-examples",
493
+ concurrency_limit=16,
494
  )
495
 
496
  pred_dist = gr.Label(label="Pred. Dist.")
 
512
  inputs=[image_number, patch_numbers] + latent_numbers + value_sliders,
513
  outputs=[pred_dist],
514
  api_name="get-modified",
515
+ concurrency_limit=16,
516
  )
517
 
518
 
pyproject.toml CHANGED
@@ -10,6 +10,7 @@ dependencies = [
10
  "gradio>=5.0.0",
11
  "jaxtyping>=0.2.36",
12
  "numpy>=1.26.4",
 
13
  "pillow>=10.4.0",
14
  "torch>=2.4.0",
15
  "torchvision>=0.19.0",
 
10
  "gradio>=5.0.0",
11
  "jaxtyping>=0.2.36",
12
  "numpy>=1.26.4",
13
+ "open-clip-torch>=2.30.0",
14
  "pillow>=10.4.0",
15
  "torch>=2.4.0",
16
  "torchvision>=0.19.0",
requirements.txt CHANGED
@@ -38,9 +38,11 @@ fsspec==2024.12.0
38
  # gradio-client
39
  # huggingface-hub
40
  # torch
41
- gradio==5.9.1
 
 
42
  # via saev-image-classification (pyproject.toml)
43
- gradio-client==1.5.2
44
  # via gradio
45
  h11==0.14.0
46
  # via
@@ -57,6 +59,8 @@ huggingface-hub==0.27.1
57
  # via
58
  # gradio
59
  # gradio-client
 
 
60
  idna==3.10
61
  # via
62
  # anyio
@@ -118,6 +122,8 @@ nvidia-nvjitlink-cu12==12.4.127
118
  # torch
119
  nvidia-nvtx-cu12==12.4.127
120
  # via torch
 
 
121
  orjson==3.10.13
122
  # via gradio
123
  packaging==24.2
@@ -152,6 +158,9 @@ pyyaml==6.0.2
152
  # via
153
  # gradio
154
  # huggingface-hub
 
 
 
155
  requests==2.32.3
156
  # via huggingface-hub
157
  rich==13.9.4
@@ -160,6 +169,10 @@ ruff==0.8.6
160
  # via gradio
161
  safehttpx==0.1.6
162
  # via gradio
 
 
 
 
163
  semantic-version==2.10.0
164
  # via gradio
165
  setuptools==75.7.0
@@ -176,16 +189,25 @@ starlette==0.41.3
176
  # gradio
177
  sympy==1.13.1
178
  # via torch
 
 
179
  tomlkit==0.13.2
180
  # via gradio
181
  torch==2.5.1
182
  # via
183
  # saev-image-classification (pyproject.toml)
 
 
184
  # torchvision
185
  torchvision==0.20.1
186
- # via saev-image-classification (pyproject.toml)
 
 
 
187
  tqdm==4.67.1
188
- # via huggingface-hub
 
 
189
  triton==3.1.0
190
  # via torch
191
  typer==0.15.1
@@ -207,5 +229,7 @@ urllib3==2.3.0
207
  # via requests
208
  uvicorn==0.34.0
209
  # via gradio
 
 
210
  websockets==14.1
211
  # via gradio-client
 
38
  # gradio-client
39
  # huggingface-hub
40
  # torch
41
+ ftfy==6.3.1
42
+ # via open-clip-torch
43
+ gradio==5.10.0
44
  # via saev-image-classification (pyproject.toml)
45
+ gradio-client==1.5.3
46
  # via gradio
47
  h11==0.14.0
48
  # via
 
59
  # via
60
  # gradio
61
  # gradio-client
62
+ # open-clip-torch
63
+ # timm
64
  idna==3.10
65
  # via
66
  # anyio
 
122
  # torch
123
  nvidia-nvtx-cu12==12.4.127
124
  # via torch
125
+ open-clip-torch==2.30.0
126
+ # via saev-image-classification (pyproject.toml)
127
  orjson==3.10.13
128
  # via gradio
129
  packaging==24.2
 
158
  # via
159
  # gradio
160
  # huggingface-hub
161
+ # timm
162
+ regex==2024.11.6
163
+ # via open-clip-torch
164
  requests==2.32.3
165
  # via huggingface-hub
166
  rich==13.9.4
 
169
  # via gradio
170
  safehttpx==0.1.6
171
  # via gradio
172
+ safetensors==0.5.0
173
+ # via
174
+ # open-clip-torch
175
+ # timm
176
  semantic-version==2.10.0
177
  # via gradio
178
  setuptools==75.7.0
 
189
  # gradio
190
  sympy==1.13.1
191
  # via torch
192
+ timm==1.0.12
193
+ # via open-clip-torch
194
  tomlkit==0.13.2
195
  # via gradio
196
  torch==2.5.1
197
  # via
198
  # saev-image-classification (pyproject.toml)
199
+ # open-clip-torch
200
+ # timm
201
  # torchvision
202
  torchvision==0.20.1
203
+ # via
204
+ # saev-image-classification (pyproject.toml)
205
+ # open-clip-torch
206
+ # timm
207
  tqdm==4.67.1
208
+ # via
209
+ # huggingface-hub
210
+ # open-clip-torch
211
  triton==3.1.0
212
  # via torch
213
  typer==0.15.1
 
229
  # via requests
230
  uvicorn==0.34.0
231
  # via gradio
232
+ wcwidth==0.2.13
233
+ # via ftfy
234
  websockets==14.1
235
  # via gradio-client
uv.lock CHANGED
@@ -162,6 +162,18 @@ wheels = [
162
  { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 },
163
  ]
164
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  [[package]]
166
  name = "gradio"
167
  version = "5.3.0"
@@ -493,6 +505,25 @@ wheels = [
493
  { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
494
  ]
495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  [[package]]
497
  name = "orjson"
498
  version = "3.10.13"
@@ -723,6 +754,44 @@ wheels = [
723
  { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
724
  ]
725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
726
  [[package]]
727
  name = "requests"
728
  version = "2.32.3"
@@ -786,6 +855,7 @@ dependencies = [
786
  { name = "gradio" },
787
  { name = "jaxtyping" },
788
  { name = "numpy" },
 
789
  { name = "pillow" },
790
  { name = "torch" },
791
  { name = "torchvision" },
@@ -798,11 +868,34 @@ requires-dist = [
798
  { name = "gradio", specifier = ">=5.0.0" },
799
  { name = "jaxtyping", specifier = ">=0.2.36" },
800
  { name = "numpy", specifier = ">=1.26.4" },
 
801
  { name = "pillow", specifier = ">=10.4.0" },
802
  { name = "torch", specifier = ">=2.4.0" },
803
  { name = "torchvision", specifier = ">=0.19.0" },
804
  ]
805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
806
  [[package]]
807
  name = "semantic-version"
808
  version = "2.10.0"
@@ -872,6 +965,22 @@ wheels = [
872
  { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
873
  ]
874
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875
  [[package]]
876
  name = "tomlkit"
877
  version = "0.12.0"
@@ -1009,6 +1118,15 @@ wheels = [
1009
  { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 },
1010
  ]
1011
 
 
 
 
 
 
 
 
 
 
1012
  [[package]]
1013
  name = "websockets"
1014
  version = "12.0"
 
162
  { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 },
163
  ]
164
 
165
+ [[package]]
166
+ name = "ftfy"
167
+ version = "6.3.1"
168
+ source = { registry = "https://pypi.org/simple" }
169
+ dependencies = [
170
+ { name = "wcwidth" },
171
+ ]
172
+ sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927 }
173
+ wheels = [
174
+ { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821 },
175
+ ]
176
+
177
  [[package]]
178
  name = "gradio"
179
  version = "5.3.0"
 
505
  { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
506
  ]
507
 
508
+ [[package]]
509
+ name = "open-clip-torch"
510
+ version = "2.30.0"
511
+ source = { registry = "https://pypi.org/simple" }
512
+ dependencies = [
513
+ { name = "ftfy" },
514
+ { name = "huggingface-hub" },
515
+ { name = "regex" },
516
+ { name = "safetensors" },
517
+ { name = "timm" },
518
+ { name = "torch" },
519
+ { name = "torchvision" },
520
+ { name = "tqdm" },
521
+ ]
522
+ sdist = { url = "https://files.pythonhosted.org/packages/28/71/133f3eb549d61a937e488805046baaee9eda4acfa8f8cbf01f43f64d2654/open_clip_torch-2.30.0.tar.gz", hash = "sha256:9a635e542a4fb83b268ec8ba2585698e2d5badcb1a517d26dcb49dff1a64c49f", size = 1485046 }
523
+ wheels = [
524
+ { url = "https://files.pythonhosted.org/packages/be/86/6ba3921b9fc0c83fd1838b1fb197973245994258586887876625eda732f8/open_clip_torch-2.30.0-py3-none-any.whl", hash = "sha256:68343092181a03a6a0b3ba8a3529856e40299d4c06bc83082ce73e0ba438187a", size = 1514664 },
525
+ ]
526
+
527
  [[package]]
528
  name = "orjson"
529
  version = "3.10.13"
 
754
  { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
755
  ]
756
 
757
+ [[package]]
758
+ name = "regex"
759
+ version = "2024.11.6"
760
+ source = { registry = "https://pypi.org/simple" }
761
+ sdist = { url = "https://files.pythonhosted.org/packages/8e/5f/bd69653fbfb76cf8604468d3b4ec4c403197144c7bfe0e6a5fc9e02a07cb/regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519", size = 399494 }
762
+ wheels = [
763
+ { url = "https://files.pythonhosted.org/packages/ba/30/9a87ce8336b172cc232a0db89a3af97929d06c11ceaa19d97d84fa90a8f8/regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a", size = 483781 },
764
+ { url = "https://files.pythonhosted.org/packages/01/e8/00008ad4ff4be8b1844786ba6636035f7ef926db5686e4c0f98093612add/regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9", size = 288455 },
765
+ { url = "https://files.pythonhosted.org/packages/60/85/cebcc0aff603ea0a201667b203f13ba75d9fc8668fab917ac5b2de3967bc/regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2", size = 284759 },
766
+ { url = "https://files.pythonhosted.org/packages/94/2b/701a4b0585cb05472a4da28ee28fdfe155f3638f5e1ec92306d924e5faf0/regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4", size = 794976 },
767
+ { url = "https://files.pythonhosted.org/packages/4b/bf/fa87e563bf5fee75db8915f7352e1887b1249126a1be4813837f5dbec965/regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577", size = 833077 },
768
+ { url = "https://files.pythonhosted.org/packages/a1/56/7295e6bad94b047f4d0834e4779491b81216583c00c288252ef625c01d23/regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3", size = 823160 },
769
+ { url = "https://files.pythonhosted.org/packages/fb/13/e3b075031a738c9598c51cfbc4c7879e26729c53aa9cca59211c44235314/regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e", size = 796896 },
770
+ { url = "https://files.pythonhosted.org/packages/24/56/0b3f1b66d592be6efec23a795b37732682520b47c53da5a32c33ed7d84e3/regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe", size = 783997 },
771
+ { url = "https://files.pythonhosted.org/packages/f9/a1/eb378dada8b91c0e4c5f08ffb56f25fcae47bf52ad18f9b2f33b83e6d498/regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e", size = 781725 },
772
+ { url = "https://files.pythonhosted.org/packages/83/f2/033e7dec0cfd6dda93390089864732a3409246ffe8b042e9554afa9bff4e/regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29", size = 789481 },
773
+ { url = "https://files.pythonhosted.org/packages/83/23/15d4552ea28990a74e7696780c438aadd73a20318c47e527b47a4a5a596d/regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39", size = 852896 },
774
+ { url = "https://files.pythonhosted.org/packages/e3/39/ed4416bc90deedbfdada2568b2cb0bc1fdb98efe11f5378d9892b2a88f8f/regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51", size = 860138 },
775
+ { url = "https://files.pythonhosted.org/packages/93/2d/dd56bb76bd8e95bbce684326302f287455b56242a4f9c61f1bc76e28360e/regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad", size = 787692 },
776
+ { url = "https://files.pythonhosted.org/packages/0b/55/31877a249ab7a5156758246b9c59539abbeba22461b7d8adc9e8475ff73e/regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54", size = 262135 },
777
+ { url = "https://files.pythonhosted.org/packages/38/ec/ad2d7de49a600cdb8dd78434a1aeffe28b9d6fc42eb36afab4a27ad23384/regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b", size = 273567 },
778
+ { url = "https://files.pythonhosted.org/packages/90/73/bcb0e36614601016552fa9344544a3a2ae1809dc1401b100eab02e772e1f/regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84", size = 483525 },
779
+ { url = "https://files.pythonhosted.org/packages/0f/3f/f1a082a46b31e25291d830b369b6b0c5576a6f7fb89d3053a354c24b8a83/regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4", size = 288324 },
780
+ { url = "https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0", size = 284617 },
781
+ { url = "https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0", size = 795023 },
782
+ { url = "https://files.pythonhosted.org/packages/c4/7c/d4cd9c528502a3dedb5c13c146e7a7a539a3853dc20209c8e75d9ba9d1b2/regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7", size = 833072 },
783
+ { url = "https://files.pythonhosted.org/packages/4f/db/46f563a08f969159c5a0f0e722260568425363bea43bb7ae370becb66a67/regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7", size = 823130 },
784
+ { url = "https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c", size = 796857 },
785
+ { url = "https://files.pythonhosted.org/packages/10/db/ac718a08fcee981554d2f7bb8402f1faa7e868c1345c16ab1ebec54b0d7b/regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3", size = 784006 },
786
+ { url = "https://files.pythonhosted.org/packages/c2/41/7da3fe70216cea93144bf12da2b87367590bcf07db97604edeea55dac9ad/regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07", size = 781650 },
787
+ { url = "https://files.pythonhosted.org/packages/a7/d5/880921ee4eec393a4752e6ab9f0fe28009435417c3102fc413f3fe81c4e5/regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e", size = 789545 },
788
+ { url = "https://files.pythonhosted.org/packages/dc/96/53770115e507081122beca8899ab7f5ae28ae790bfcc82b5e38976df6a77/regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6", size = 853045 },
789
+ { url = "https://files.pythonhosted.org/packages/31/d3/1372add5251cc2d44b451bd94f43b2ec78e15a6e82bff6a290ef9fd8f00a/regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4", size = 860182 },
790
+ { url = "https://files.pythonhosted.org/packages/ed/e3/c446a64984ea9f69982ba1a69d4658d5014bc7a0ea468a07e1a1265db6e2/regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d", size = 787733 },
791
+ { url = "https://files.pythonhosted.org/packages/2b/f1/e40c8373e3480e4f29f2692bd21b3e05f296d3afebc7e5dcf21b9756ca1c/regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff", size = 262122 },
792
+ { url = "https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a", size = 273545 },
793
+ ]
794
+
795
  [[package]]
796
  name = "requests"
797
  version = "2.32.3"
 
855
  { name = "gradio" },
856
  { name = "jaxtyping" },
857
  { name = "numpy" },
858
+ { name = "open-clip-torch" },
859
  { name = "pillow" },
860
  { name = "torch" },
861
  { name = "torchvision" },
 
868
  { name = "gradio", specifier = ">=5.0.0" },
869
  { name = "jaxtyping", specifier = ">=0.2.36" },
870
  { name = "numpy", specifier = ">=1.26.4" },
871
+ { name = "open-clip-torch", specifier = ">=2.30.0" },
872
  { name = "pillow", specifier = ">=10.4.0" },
873
  { name = "torch", specifier = ">=2.4.0" },
874
  { name = "torchvision", specifier = ">=0.19.0" },
875
  ]
876
 
877
+ [[package]]
878
+ name = "safetensors"
879
+ version = "0.5.0"
880
+ source = { registry = "https://pypi.org/simple" }
881
+ sdist = { url = "https://files.pythonhosted.org/packages/5d/b3/1d9000e9d0470499d124ca63c6908f8092b528b48bd95ba11507e14d9dba/safetensors-0.5.0.tar.gz", hash = "sha256:c47b34c549fa1e0c655c4644da31332c61332c732c47c8dd9399347e9aac69d1", size = 65660 }
882
+ wheels = [
883
+ { url = "https://files.pythonhosted.org/packages/0f/ee/0fd61b99bc58db736a3ab3d97d49d4a11afe71ee0aad85b25d6c4235b743/safetensors-0.5.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c683b9b485bee43422ba2855f72777c37647190281e03da4c8d2a69fa5336558", size = 426509 },
884
+ { url = "https://files.pythonhosted.org/packages/51/aa/de1a11aa056d0241f95d5de9dbb1ac2dabaf3df5c568f9375451fd593c95/safetensors-0.5.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6106aa835deb7263f7014f74c05842ab828d6c11d789f2e7e98f26b1a305e72d", size = 408471 },
885
+ { url = "https://files.pythonhosted.org/packages/a5/c7/84b821bd90547a909053a8526ff70446f062287cda20d0ec024c1a1f80f6/safetensors-0.5.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1349611f74f55c5ee1c1c144c536a2743c38f7d8bf60b9fc8267e0efc0591a2", size = 449638 },
886
+ { url = "https://files.pythonhosted.org/packages/b5/25/3d20bb9f669fec704e01d70849e9c6c054601efe9b5e784ce9a865cf3c52/safetensors-0.5.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:56d936028ac799e18644b08a91fd98b4b62ae3dcd0440b1cfcb56535785589f1", size = 458246 },
887
+ { url = "https://files.pythonhosted.org/packages/31/35/68e1c39c4ad6a2f9373fc89588c0fbd29b1899c57c3a6482fc8e42fa4c8f/safetensors-0.5.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2f26afada2233576ffea6b80042c2c0a8105c164254af56168ec14299ad3122", size = 509573 },
888
+ { url = "https://files.pythonhosted.org/packages/85/b0/79927c6d4f70232f04a46785ea8b0ed0f70f9be74d17e0a90e1890523553/safetensors-0.5.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20067e7a5e63f0cbc88457b2a1161e70ff73af4cc3a24bce90309430cd6f6e7e", size = 525555 },
889
+ { url = "https://files.pythonhosted.org/packages/a6/83/ca8c1af662a20a545c174b8949e63865b747c180b607260eed83c1d38c72/safetensors-0.5.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649d6a4aa34d5174ae87289068ccc2fec2a1a998ecf83425aa5a42c3eff69bcf", size = 461294 },
890
+ { url = "https://files.pythonhosted.org/packages/81/ef/1d11d08b14b36e3e3d701629c9685ad95c3afee7da2851658d6c65cad9be/safetensors-0.5.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:debff88f41d569a3e93a955469f83864e432af35bb34b16f65a9ddf378daa3ae", size = 490593 },
891
+ { url = "https://files.pythonhosted.org/packages/f6/9a/50bf824a26d768d33485b7208ba5e6a173a80a2633be5e213a2494d1569b/safetensors-0.5.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:bdf6a3e366ea8ba1a0538db6099229e95811194432c684ea28ea7ae28763b8dc", size = 628142 },
892
+ { url = "https://files.pythonhosted.org/packages/28/22/dc5ae22523b8221017dbf6984fedfe2c6f35ff4cc76e80bbab2b9e14cc8a/safetensors-0.5.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:0371afd84c200a80eb7103bf715108b0c3846132fb82453ae018609a15551580", size = 721377 },
893
+ { url = "https://files.pythonhosted.org/packages/fe/87/36323e8058e7101ef0101fde6d71c375a9ab6059d3d9501fe8fb8d13a45a/safetensors-0.5.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5ec7fc8c3d2f32ebf1c7011bc886b362e53ee0a1ec6d828c39d531fed8b325d6", size = 659192 },
894
+ { url = "https://files.pythonhosted.org/packages/dd/2f/8d526f06bb192b45b4e0fec94284d568497e6e19620c834373749a5f9787/safetensors-0.5.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:53715e4ea0ef23c08f004baae0f609a7773de7d4148727760417c6760cfd6b76", size = 632231 },
895
+ { url = "https://files.pythonhosted.org/packages/d3/68/1166bba02f77c811d17766e54a54d7714c1276f54bfcf60d50bb9326a1b4/safetensors-0.5.0-cp38-abi3-win32.whl", hash = "sha256:b85565bc2f0456961a788d2f11d9d892eec46603db0e4923aa9512c2355aa727", size = 290608 },
896
+ { url = "https://files.pythonhosted.org/packages/0c/ab/a428973e43a77791d2fd4b6425f4fd82e9f8559b32222c861acbbd7bc910/safetensors-0.5.0-cp38-abi3-win_amd64.whl", hash = "sha256:f451941f8aa11e7be5c3fa450e264609a2b1e65fa38ae590a74e55a94d646b76", size = 303322 },
897
+ ]
898
+
899
  [[package]]
900
  name = "semantic-version"
901
  version = "2.10.0"
 
965
  { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
966
  ]
967
 
968
+ [[package]]
969
+ name = "timm"
970
+ version = "1.0.12"
971
+ source = { registry = "https://pypi.org/simple" }
972
+ dependencies = [
973
+ { name = "huggingface-hub" },
974
+ { name = "pyyaml" },
975
+ { name = "safetensors" },
976
+ { name = "torch" },
977
+ { name = "torchvision" },
978
+ ]
979
+ sdist = { url = "https://files.pythonhosted.org/packages/8b/d9/7382e27b379d4b791811396e05c9421703bbb21b2f726ff3e78469b5b194/timm-1.0.12.tar.gz", hash = "sha256:9da490683bd06302ec40e1892f1ccf87985f033e41f3580887d886b9aee9449a", size = 2219847 }
980
+ wheels = [
981
+ { url = "https://files.pythonhosted.org/packages/6b/02/0d8925809296bed4cf841446e1291c3f381fde6d777a1ab2a25a3829b4a4/timm-1.0.12-py3-none-any.whl", hash = "sha256:6b2770674213f10b7f193be5598ce48bd010ab21cc8af77dba6aeef58b1298a1", size = 2351767 },
982
+ ]
983
+
984
  [[package]]
985
  name = "tomlkit"
986
  version = "0.12.0"
 
1118
  { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 },
1119
  ]
1120
 
1121
+ [[package]]
1122
+ name = "wcwidth"
1123
+ version = "0.2.13"
1124
+ source = { registry = "https://pypi.org/simple" }
1125
+ sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
1126
+ wheels = [
1127
+ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
1128
+ ]
1129
+
1130
  [[package]]
1131
  name = "websockets"
1132
  version = "12.0"