Samuel Stevens
commited on
Commit
·
3e841d9
1
Parent(s):
4cfc960
update to work with integer top_values
Browse files- app.py +7 -7
- requirements.txt +3 -3
app.py
CHANGED
@@ -104,7 +104,7 @@ def make_img(
|
|
104 |
img: Image.Image,
|
105 |
patches: Float[Tensor, " n_patches"],
|
106 |
*,
|
107 |
-
upper:
|
108 |
) -> Image.Image:
|
109 |
# Resize to 256x256 and crop to 224x224
|
110 |
resize_size_px = (512, 512)
|
@@ -241,7 +241,7 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
|
|
241 |
|
242 |
|
243 |
top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
|
244 |
-
top_values = load_tensor(CWD / "data" / "
|
245 |
sparsity = load_tensor(CWD / "data" / "sparsity.pt")
|
246 |
|
247 |
|
@@ -312,7 +312,7 @@ def get_sae_examples(
|
|
312 |
upper = top_values[latent].max().item()
|
313 |
|
314 |
latent_images = [
|
315 |
-
make_img(img, patches, upper=upper)
|
316 |
for img, patches in img_patch_pairs[:n_sae_examples]
|
317 |
]
|
318 |
|
@@ -359,7 +359,7 @@ def get_modified_dist(
|
|
359 |
|
360 |
values = torch.tensor(
|
361 |
[
|
362 |
-
unscaled(
|
363 |
for value, latent in [
|
364 |
(value1, latent1),
|
365 |
(value2, latent2),
|
@@ -392,14 +392,14 @@ def get_modified_dist(
|
|
392 |
|
393 |
|
394 |
@beartype.beartype
|
395 |
-
def unscaled(x: float, max_obs: float) -> float:
|
396 |
"""Scale from [-20, 20] to [20 * -max_obs, 20 * max_obs]."""
|
397 |
return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))
|
398 |
|
399 |
|
400 |
@beartype.beartype
|
401 |
def map_range(
|
402 |
-
x: float,
|
403 |
domain: tuple[float | int, float | int],
|
404 |
range: tuple[float | int, float | int],
|
405 |
):
|
@@ -415,7 +415,7 @@ def add_highlights(
|
|
415 |
img: Image.Image,
|
416 |
patches: Float[np.ndarray, " n_patches"],
|
417 |
*,
|
418 |
-
upper:
|
419 |
opacity: float = 0.9,
|
420 |
) -> Image.Image:
|
421 |
if not len(patches):
|
|
|
104 |
img: Image.Image,
|
105 |
patches: Float[Tensor, " n_patches"],
|
106 |
*,
|
107 |
+
upper: int | None = None,
|
108 |
) -> Image.Image:
|
109 |
# Resize to 256x256 and crop to 224x224
|
110 |
resize_size_px = (512, 512)
|
|
|
241 |
|
242 |
|
243 |
top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
|
244 |
+
top_values = load_tensor(CWD / "data" / "top_values_uint8.pt")
|
245 |
sparsity = load_tensor(CWD / "data" / "sparsity.pt")
|
246 |
|
247 |
|
|
|
312 |
upper = top_values[latent].max().item()
|
313 |
|
314 |
latent_images = [
|
315 |
+
make_img(img, patches.to(float), upper=upper)
|
316 |
for img, patches in img_patch_pairs[:n_sae_examples]
|
317 |
]
|
318 |
|
|
|
359 |
|
360 |
values = torch.tensor(
|
361 |
[
|
362 |
+
unscaled(value, top_values[latent].max().item())
|
363 |
for value, latent in [
|
364 |
(value1, latent1),
|
365 |
(value2, latent2),
|
|
|
392 |
|
393 |
|
394 |
@beartype.beartype
|
395 |
+
def unscaled(x: float | int, max_obs: float | int) -> float:
|
396 |
"""Scale from [-20, 20] to [20 * -max_obs, 20 * max_obs]."""
|
397 |
return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))
|
398 |
|
399 |
|
400 |
@beartype.beartype
|
401 |
def map_range(
|
402 |
+
x: float | int,
|
403 |
domain: tuple[float | int, float | int],
|
404 |
range: tuple[float | int, float | int],
|
405 |
):
|
|
|
415 |
img: Image.Image,
|
416 |
patches: Float[np.ndarray, " n_patches"],
|
417 |
*,
|
418 |
+
upper: int | None = None,
|
419 |
opacity: float = 0.9,
|
420 |
) -> Image.Image:
|
421 |
if not len(patches):
|
requirements.txt
CHANGED
@@ -265,7 +265,7 @@ multidict==6.1.0
|
|
265 |
# yarl
|
266 |
multiprocess==0.70.16
|
267 |
# via datasets
|
268 |
-
narwhals==1.21.
|
269 |
# via
|
270 |
# altair
|
271 |
# marimo
|
@@ -414,7 +414,7 @@ pygments==2.19.1
|
|
414 |
# marimo
|
415 |
# nbconvert
|
416 |
# rich
|
417 |
-
pymdown-extensions==10.
|
418 |
# via marimo
|
419 |
pyparsing==3.2.1
|
420 |
# via matplotlib
|
@@ -483,7 +483,7 @@ saev @ git+https://github.com/samuelstevens/saev@c723ff95462736d907b3c1891d3a149
|
|
483 |
# via saev-image-classification (pyproject.toml)
|
484 |
safehttpx==0.1.6
|
485 |
# via gradio
|
486 |
-
safetensors==0.5.
|
487 |
# via
|
488 |
# open-clip-torch
|
489 |
# timm
|
|
|
265 |
# yarl
|
266 |
multiprocess==0.70.16
|
267 |
# via datasets
|
268 |
+
narwhals==1.21.1
|
269 |
# via
|
270 |
# altair
|
271 |
# marimo
|
|
|
414 |
# marimo
|
415 |
# nbconvert
|
416 |
# rich
|
417 |
+
pymdown-extensions==10.14
|
418 |
# via marimo
|
419 |
pyparsing==3.2.1
|
420 |
# via matplotlib
|
|
|
483 |
# via saev-image-classification (pyproject.toml)
|
484 |
safehttpx==0.1.6
|
485 |
# via gradio
|
486 |
+
safetensors==0.5.1
|
487 |
# via
|
488 |
# open-clip-torch
|
489 |
# timm
|