Samuel Stevens commited on
Commit
3e841d9
·
1 Parent(s): 4cfc960

update to work with integer top_values

Browse files
Files changed (2) hide show
  1. app.py +7 -7
  2. requirements.txt +3 -3
app.py CHANGED
@@ -104,7 +104,7 @@ def make_img(
104
  img: Image.Image,
105
  patches: Float[Tensor, " n_patches"],
106
  *,
107
- upper: float | None = None,
108
  ) -> Image.Image:
109
  # Resize to 256x256 and crop to 224x224
110
  resize_size_px = (512, 512)
@@ -241,7 +241,7 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
241
 
242
 
243
  top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
244
- top_values = load_tensor(CWD / "data" / "top_values.pt")
245
  sparsity = load_tensor(CWD / "data" / "sparsity.pt")
246
 
247
 
@@ -312,7 +312,7 @@ def get_sae_examples(
312
  upper = top_values[latent].max().item()
313
 
314
  latent_images = [
315
- make_img(img, patches, upper=upper)
316
  for img, patches in img_patch_pairs[:n_sae_examples]
317
  ]
318
 
@@ -359,7 +359,7 @@ def get_modified_dist(
359
 
360
  values = torch.tensor(
361
  [
362
- unscaled(float(value), top_values[latent].max().item())
363
  for value, latent in [
364
  (value1, latent1),
365
  (value2, latent2),
@@ -392,14 +392,14 @@ def get_modified_dist(
392
 
393
 
394
  @beartype.beartype
395
- def unscaled(x: float, max_obs: float) -> float:
396
  """Scale from [-20, 20] to [20 * -max_obs, 20 * max_obs]."""
397
  return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))
398
 
399
 
400
  @beartype.beartype
401
  def map_range(
402
- x: float,
403
  domain: tuple[float | int, float | int],
404
  range: tuple[float | int, float | int],
405
  ):
@@ -415,7 +415,7 @@ def add_highlights(
415
  img: Image.Image,
416
  patches: Float[np.ndarray, " n_patches"],
417
  *,
418
- upper: float | None = None,
419
  opacity: float = 0.9,
420
  ) -> Image.Image:
421
  if not len(patches):
 
104
  img: Image.Image,
105
  patches: Float[Tensor, " n_patches"],
106
  *,
107
+ upper: int | None = None,
108
  ) -> Image.Image:
109
  # Resize to 256x256 and crop to 224x224
110
  resize_size_px = (512, 512)
 
241
 
242
 
243
  top_img_i = load_tensor(CWD / "data" / "top_img_i.pt")
244
+ top_values = load_tensor(CWD / "data" / "top_values_uint8.pt")
245
  sparsity = load_tensor(CWD / "data" / "sparsity.pt")
246
 
247
 
 
312
  upper = top_values[latent].max().item()
313
 
314
  latent_images = [
315
+ make_img(img, patches.to(float), upper=upper)
316
  for img, patches in img_patch_pairs[:n_sae_examples]
317
  ]
318
 
 
359
 
360
  values = torch.tensor(
361
  [
362
+ unscaled(value, top_values[latent].max().item())
363
  for value, latent in [
364
  (value1, latent1),
365
  (value2, latent2),
 
392
 
393
 
394
  @beartype.beartype
395
+ def unscaled(x: float | int, max_obs: float | int) -> float:
396
  """Scale from [-20, 20] to [20 * -max_obs, 20 * max_obs]."""
397
  return map_range(x, (-20.0, 20.0), (-20.0 * max_obs, 20.0 * max_obs))
398
 
399
 
400
  @beartype.beartype
401
  def map_range(
402
+ x: float | int,
403
  domain: tuple[float | int, float | int],
404
  range: tuple[float | int, float | int],
405
  ):
 
415
  img: Image.Image,
416
  patches: Float[np.ndarray, " n_patches"],
417
  *,
418
+ upper: int | None = None,
419
  opacity: float = 0.9,
420
  ) -> Image.Image:
421
  if not len(patches):
requirements.txt CHANGED
@@ -265,7 +265,7 @@ multidict==6.1.0
265
  # yarl
266
  multiprocess==0.70.16
267
  # via datasets
268
- narwhals==1.21.0
269
  # via
270
  # altair
271
  # marimo
@@ -414,7 +414,7 @@ pygments==2.19.1
414
  # marimo
415
  # nbconvert
416
  # rich
417
- pymdown-extensions==10.13
418
  # via marimo
419
  pyparsing==3.2.1
420
  # via matplotlib
@@ -483,7 +483,7 @@ saev @ git+https://github.com/samuelstevens/saev@c723ff95462736d907b3c1891d3a149
483
  # via saev-image-classification (pyproject.toml)
484
  safehttpx==0.1.6
485
  # via gradio
486
- safetensors==0.5.0
487
  # via
488
  # open-clip-torch
489
  # timm
 
265
  # yarl
266
  multiprocess==0.70.16
267
  # via datasets
268
+ narwhals==1.21.1
269
  # via
270
  # altair
271
  # marimo
 
414
  # marimo
415
  # nbconvert
416
  # rich
417
+ pymdown-extensions==10.14
418
  # via marimo
419
  pyparsing==3.2.1
420
  # via matplotlib
 
483
  # via saev-image-classification (pyproject.toml)
484
  safehttpx==0.1.6
485
  # via gradio
486
+ safetensors==0.5.1
487
  # via
488
  # open-clip-torch
489
  # timm