Spaces:
Running
on
L4
Running
on
L4
from dataclasses import dataclass, field | |
from typing import Dict, List, Optional | |
import torch | |
from jaxtyping import Float | |
from torch import Tensor | |
from spar3d.models.utils import BaseModule | |
from .field import RENIField | |
def _direction_from_coordinate( | |
coordinate: Float[Tensor, "*B 2"], | |
) -> Float[Tensor, "*B 3"]: | |
# OpenGL Convention | |
# +X Right | |
# +Y Up | |
# +Z Backward | |
u, v = coordinate.unbind(-1) | |
theta = (2 * torch.pi * u) - torch.pi | |
phi = torch.pi * v | |
dir = torch.stack( | |
[ | |
theta.sin() * phi.sin(), | |
phi.cos(), | |
-1 * theta.cos() * phi.sin(), | |
], | |
-1, | |
) | |
return dir | |
def _get_sample_coordinates( | |
resolution: List[int], device: Optional[torch.device] = None | |
) -> Float[Tensor, "H W 2"]: | |
return torch.stack( | |
torch.meshgrid( | |
(torch.arange(resolution[1], device=device) + 0.5) / resolution[1], | |
(torch.arange(resolution[0], device=device) + 0.5) / resolution[0], | |
indexing="xy", | |
), | |
-1, | |
) | |
class RENIEnvMap(BaseModule): | |
class Config(BaseModule.Config): | |
reni_config: dict = field(default_factory=dict) | |
resolution: int = 128 | |
cfg: Config | |
def configure(self): | |
self.field = RENIField(self.cfg.reni_config) | |
resolution = (self.cfg.resolution, self.cfg.resolution * 2) | |
sample_directions = _direction_from_coordinate( | |
_get_sample_coordinates(resolution) | |
) | |
self.img_shape = sample_directions.shape[:-1] | |
sample_directions_flat = sample_directions.view(-1, 3) | |
# Lastly these have y up but reni expects z up. Rotate 90 degrees on x axis | |
sample_directions_flat = torch.stack( | |
[ | |
sample_directions_flat[:, 0], | |
-sample_directions_flat[:, 2], | |
sample_directions_flat[:, 1], | |
], | |
-1, | |
) | |
self.sample_directions = torch.nn.Parameter( | |
sample_directions_flat, requires_grad=False | |
) | |
def forward( | |
self, | |
latent_codes: Float[Tensor, "B latent_dim 3"], | |
rotation: Optional[Float[Tensor, "B 3 3"]] = None, | |
scale: Optional[Float[Tensor, "B"]] = None, | |
) -> Dict[str, Tensor]: | |
return { | |
k: v.view(latent_codes.shape[0], *self.img_shape, -1) | |
for k, v in self.field( | |
self.sample_directions.unsqueeze(0).repeat(latent_codes.shape[0], 1, 1), | |
latent_codes, | |
rotation=rotation, | |
scale=scale, | |
).items() | |
} | |