# Copyright 2023 The University of York. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Mark Boss

"""RENI field"""

import contextlib
from dataclasses import dataclass
from typing import Dict, Literal, Optional

import torch
from einops.layers.torch import Rearrange
from jaxtyping import Float
from torch import Tensor, nn

from spar3d.models.network import get_activation_module, trunc_exp
from spar3d.models.utils import BaseModule

from .components.film_siren import FiLMSiren
from .components.siren import Siren
from .components.transformer_decoder import Decoder
from .components.vn_layers import VNInvariant, VNLinear

# from nerfstudio.cameras.rays import RaySamples

def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
    """Computes the expected value of sin(y) where y ~ N(x_means, x_vars)

    Args:
        x_means: Mean values.
        x_vars: Variance of values.

    Returns:
        torch.Tensor: The expected value of sin.
    """
    return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
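
# Example (sketch): the closed form above is the first moment of sin under a
# Gaussian, E[sin(y)] = exp(-var / 2) * sin(mean). With zero variance it
# reduces to a plain sine; growing variance damps the expectation toward zero:
#   expected_sin(torch.tensor([0.5]), torch.tensor([0.0]))   # == sin(0.5)
#   expected_sin(torch.tensor([0.5]), torch.tensor([10.0]))  # ~= exp(-5) * sin(0.5) ~= 0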

class NeRFEncoding(torch.nn.Module):
    """Multi-scale sinusoidal encoding. Supports integrated positional encodings if covariances are provided.

    Each axis is encoded with frequencies ranging from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent
        include_input: Append the input coordinate to the encoding
        off_axis: Encode projections of the input onto a fixed set of
            non-axis-aligned directions (the columns of P) instead of the raw axes
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float,
        max_freq_exp: float,
        include_input: bool = False,
        off_axis: bool = False,
    ) -> None:
        super().__init__()

        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

        self.off_axis = off_axis

        self.P = torch.tensor(
            [
                [0.8506508, 0, 0.5257311],
                [0.809017, 0.5, 0.309017],
                [0.5257311, 0.8506508, 0],
                [1, 0, 0],
                [0.809017, 0.5, -0.309017],
                [0.8506508, 0, -0.5257311],
                [0.309017, 0.809017, -0.5],
                [0, 0.5257311, -0.8506508],
                [0.5, 0.309017, -0.809017],
                [0, 1, 0],
                [-0.5257311, 0.8506508, 0],
                [-0.309017, 0.809017, -0.5],
                [0, 0.5257311, 0.8506508],
                [-0.309017, 0.809017, 0.5],
                [0.309017, 0.809017, 0.5],
                [0.5, 0.309017, 0.809017],
                [0.5, -0.309017, 0.809017],
                [0, 0, 1],
                [-0.5, 0.309017, 0.809017],
                [-0.809017, 0.5, 0.309017],
                [-0.809017, 0.5, -0.309017],
            ]
        ).T

    def get_out_dim(self) -> int:
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        out_dim = self.in_dim * self.num_frequencies * 2

        if self.off_axis:
            out_dim = self.P.shape[1] * self.num_frequencies * 2

        if self.include_input:
            out_dim += self.in_dim
        return out_dim
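
    # Example (sketch): with in_dim=3, num_frequencies=2 and include_input=True
    # the output has 3 * 2 * 2 + 3 = 15 features; with off_axis=True the 21
    # columns of P replace the 3 axes, giving 21 * 2 * 2 + 3 = 87 features.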

    def forward(
        self,
        in_tensor: Float[Tensor, "*b input_dim"],
        covs: Optional[Float[Tensor, "*b input_dim input_dim"]] = None,
    ) -> Float[Tensor, "*b output_dim"]:
        """Calculates NeRF encoding. If covariances are provided the encodings will be integrated as proposed
            in mip-NeRF.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1.
            covs: Covariances of input points.

        Returns:
            Output values will be between -1 and 1
        """
        # TODO check scaling here but just comment it for now
        # in_tensor = 2 * torch.pi * in_tensor  # scale to [0, 2pi]
        freqs = 2 ** torch.linspace(
            self.min_freq, self.max_freq, self.num_frequencies
        ).to(in_tensor.device)
        # freqs = 2 ** (
        #     torch.sin(torch.linspace(self.min_freq, torch.pi / 2.0, self.num_frequencies)) * self.max_freq
        # ).to(in_tensor.device)
        # freqs = 2 ** (
        #     torch.linspace(self.min_freq, 1.0, self.num_frequencies).to(in_tensor.device) ** 0.2 * self.max_freq
        # )

        if self.off_axis:
            scaled_inputs = (
                torch.matmul(in_tensor, self.P.to(in_tensor.device))[..., None] * freqs
            )
        else:
            scaled_inputs = (
                in_tensor[..., None] * freqs
            )  # [..., "input_dim", "num_scales"]
        scaled_inputs = scaled_inputs.view(
            *scaled_inputs.shape[:-2], -1
        )  # [..., "input_dim" * "num_scales"]

        if covs is None:
            encoded_inputs = torch.sin(
                torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
            )
        else:
            input_var = (
                torch.diagonal(covs, dim1=-2, dim2=-1)[..., :, None]
                * freqs[None, :] ** 2
            )
            input_var = input_var.reshape((*input_var.shape[:-2], -1))
            encoded_inputs = expected_sin(
                torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1),
                torch.cat(2 * [input_var], dim=-1),
            )

        if self.include_input:
            encoded_inputs = torch.cat([encoded_inputs, in_tensor], dim=-1)
        return encoded_inputs
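
# Usage sketch (shapes are assumptions matching the annotations above): encode a
# batch of unit directions with two frequencies per axis and the raw input
# appended.
#   enc = NeRFEncoding(in_dim=3, num_frequencies=2, min_freq_exp=0.0,
#                      max_freq_exp=2.0, include_input=True)
#   dirs = torch.nn.functional.normalize(torch.randn(8, 128, 3), dim=-1)
#   feats = enc(dirs)  # [8, 128, enc.get_out_dim()] == [8, 128, 15]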

class RENIField(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        """Configuration for model instantiation"""

        fixed_decoder: bool = False
        """Whether to fix the decoder weights"""
        equivariance: str = "SO2"
        """Type of equivariance to use: None, SO2, SO3"""
        axis_of_invariance: str = "y"
        """Which axis should SO2 equivariance be invariant to: x, y, z"""
        invariant_function: str = "GramMatrix"
        """Type of invariant function to use: GramMatrix, VN"""
        conditioning: str = "Concat"
        """Type of conditioning to use: FiLM, Concat, Attention"""
        positional_encoding: str = "NeRF"
        """Type of positional encoding to use. Currently only NeRF is supported"""
        encoded_input: str = "Directions"
        """Type of input to encode: None, Directions, Conditioning, Both"""
        latent_dim: int = 36
        """Dimensionality of latent code, N for a latent code size of (N x 3)"""
        hidden_layers: int = 3
        """Number of hidden layers"""
        hidden_features: int = 128
        """Number of hidden features"""
        mapping_layers: int = 3
        """Number of mapping layers"""
        mapping_features: int = 128
        """Number of mapping features"""
        num_attention_heads: int = 8
        """Number of attention heads"""
        num_attention_layers: int = 3
        """Number of attention layers"""
        out_features: int = 3  # RGB
        """Number of output features"""
        last_layer_linear: bool = False
        """Whether to use a linear layer as the last layer"""
        output_activation: str = "exp"
        """Activation function for output layer: sigmoid, tanh, relu, exp, None"""
        first_omega_0: float = 30.0
        """Omega_0 for first layer"""
        hidden_omega_0: float = 30.0
        """Omega_0 for hidden layers"""
        old_implementation: bool = False
        """Whether to match implementation of old RENI, when using old checkpoints"""

    cfg: Config

    def configure(self):
        self.equivariance = self.cfg.equivariance
        self.conditioning = self.cfg.conditioning
        self.latent_dim = self.cfg.latent_dim
        self.hidden_layers = self.cfg.hidden_layers
        self.hidden_features = self.cfg.hidden_features
        self.mapping_layers = self.cfg.mapping_layers
        self.mapping_features = self.cfg.mapping_features
        self.out_features = self.cfg.out_features
        self.last_layer_linear = self.cfg.last_layer_linear
        self.output_activation = self.cfg.output_activation
        self.first_omega_0 = self.cfg.first_omega_0
        self.hidden_omega_0 = self.cfg.hidden_omega_0
        self.old_implementation = self.cfg.old_implementation
        self.axis_of_invariance = ["x", "y", "z"].index(self.cfg.axis_of_invariance)
        self.fixed_decoder = self.cfg.fixed_decoder

        if self.cfg.invariant_function == "GramMatrix":
            self.invariant_function = self.gram_matrix_invariance
        else:
            self.vn_proj_in = nn.Sequential(
                Rearrange("... c -> ... 1 c"),
                VNLinear(dim_in=1, dim_out=1, bias_epsilon=0),
            )
            dim_coor = 2 if self.cfg.equivariance == "SO2" else 3
            self.vn_invar = VNInvariant(dim=1, dim_coor=dim_coor)
            self.invariant_function = self.vn_invariance

        self.network = self.setup_network()

        if self.fixed_decoder:
            for param in self.network.parameters():
                param.requires_grad = False
            if self.cfg.invariant_function == "VN":
                for param in self.vn_proj_in.parameters():
                    param.requires_grad = False
                for param in self.vn_invar.parameters():
                    param.requires_grad = False

    @contextlib.contextmanager
    def hold_decoder_fixed(self):
        """Context manager to fix the decoder weights

        Example usage:
        ```
        with instance_of_RENIField.hold_decoder_fixed():
            # do stuff
        ```
        """
        prev_state_network = {
            name: p.requires_grad for name, p in self.network.named_parameters()
        }
        for param in self.network.parameters():
            param.requires_grad = False
        if self.cfg.invariant_function == "VN":
            prev_state_proj_in = {
                k: p.requires_grad for k, p in self.vn_proj_in.named_parameters()
            }
            prev_state_invar = {
                k: p.requires_grad for k, p in self.vn_invar.named_parameters()
            }
            for param in self.vn_proj_in.parameters():
                param.requires_grad = False
            for param in self.vn_invar.parameters():
                param.requires_grad = False

        prev_decoder_state = self.fixed_decoder
        self.fixed_decoder = True
        try:
            yield
        finally:
            # Restore the previous requires_grad state
            for name, param in self.network.named_parameters():
                param.requires_grad = prev_state_network[name]
            if self.cfg.invariant_function == "VN":
                for name, param in self.vn_proj_in.named_parameters():
                    param.requires_grad_(prev_state_proj_in[name])
                for name, param in self.vn_invar.named_parameters():
                    param.requires_grad_(prev_state_invar[name])
            self.fixed_decoder = prev_decoder_state
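
    # Usage sketch (the target tensor and optimiser setup are assumptions):
    # freeze the decoder while fitting latent codes to an observed map.
    #   with field.hold_decoder_fixed():
    #       out = field(rays_d, latent_codes)["rgb"]
    #       loss = (out - target_rgb).abs().mean()
    #       loss.backward()  # gradients reach latent_codes (if it requires grad),
    #                        # not the decoder weights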

    def vn_invariance(
        self,
        Z: Float[Tensor, "B latent_dim 3"],
        D: Float[Tensor, "B num_rays 3"],
        equivariance: Literal["None", "SO2", "SO3"] = "SO2",
        axis_of_invariance: int = 1,
    ):
        """Generates a batched invariant representation from latent code Z and direction coordinates D.

        Args:
            Z: [B, latent_dim, 3] - Latent code.
            D: [B, num_rays, 3] - Direction coordinates.
            equivariance: The type of equivariance to use. Options are 'None', 'SO2', 'SO3'.
            axis_of_invariance: The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).

        Returns:
            Tuple[Tensor, Tensor]: directional_input, conditioning_input
        """
        assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
        other_axes = [i for i in range(3) if i != axis_of_invariance]

        B, latent_dim, _ = Z.shape
        _, num_rays, _ = D.shape

        if equivariance == "None":
            # get inner product between latent code and direction coordinates
            innerprod = torch.sum(
                Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
            )  # [B, num_rays, latent_dim]
            z_input = (
                Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
            )  # [B, num_rays, latent_dim * 3]
            return innerprod, z_input

        if equivariance == "SO2":
            z_other = torch.stack(
                (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
            )  # [B, latent_dim, 2]
            d_other = torch.stack(
                (D[..., other_axes[0]], D[..., other_axes[1]]), -1
            ).unsqueeze(2)  # [B, num_rays, 1, 2]
            d_other = d_other.expand(
                B, num_rays, latent_dim, 2
            )  # [B, num_rays, latent_dim, 2]

            z_other_emb = self.vn_proj_in(z_other)  # [B, latent_dim, 1, 2]
            z_other_invar = self.vn_invar(z_other_emb)  # [B, latent_dim, 2]

            # Get invariant component of Z along the axis of invariance
            z_invar = Z[..., axis_of_invariance].unsqueeze(-1)  # [B, latent_dim, 1]

            # Inner product between the projections of Z and D onto the plane orthogonal
            # to the axis of invariance. This encodes the rotational information: it is
            # equivariant to rotations of either Z or D alone and invariant to joint
            # rotations of both.
            innerprod = (z_other.unsqueeze(1) * d_other).sum(
                dim=-1
            )  # [B, num_rays, latent_dim]

            # Compute norm along the axes orthogonal to the axis of invariance
            d_other_norm = torch.sqrt(
                D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
            ).unsqueeze(-1)  # [B, num_rays, 1]

            # Get invariant component of D along the axis of invariance
            d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]

            directional_input = torch.cat(
                (innerprod, d_invar, d_other_norm), -1
            )  # [B, num_rays, latent_dim + 2]
            conditioning_input = (
                torch.cat((z_other_invar, z_invar), dim=-1)
                .flatten(1)
                .unsqueeze(1)
                .expand(B, num_rays, latent_dim * 3)
            )  # [B, num_rays, latent_dim * 3]

            return directional_input, conditioning_input

        if equivariance == "SO3":
            z = self.vn_proj_in(Z)  # [B, latent_dim, 1, 3]
            z_invar = self.vn_invar(z)  # [B, latent_dim, 3]
            conditioning_input = (
                z_invar.flatten(1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
            )  # [B, num_rays, latent_dim * 3]
            # D [B, num_rays, 3] -> [B, num_rays, 1, 3]
            # Z [B, latent_dim, 3] -> [B, 1, latent_dim, 3]
            innerprod = torch.sum(
                Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
            )  # [B, num_rays, latent_dim]
            return innerprod, conditioning_input
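
    # Invariance check sketch (assumes the VN invariant function, the default
    # y-axis axis_of_invariance=1, and `import math`): rotating Z and D jointly
    # about the y-axis should leave both returned tensors unchanged.
    #   c, s = math.cos(theta), math.sin(theta)
    #   R = torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
    #   d0, c0 = field.vn_invariance(Z, D)
    #   d1, c1 = field.vn_invariance(Z @ R.T, D @ R.T)
    #   assert torch.allclose(d0, d1, atol=1e-5) and torch.allclose(c0, c1, atol=1e-5)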

    def gram_matrix_invariance(
        self,
        Z: Float[Tensor, "B latent_dim 3"],
        D: Float[Tensor, "B num_rays 3"],
        equivariance: Literal["None", "SO2", "SO3"] = "SO2",
        axis_of_invariance: int = 1,
    ):
        """Generates an invariant representation from latent code Z and direction coordinates D.

        Args:
            Z (torch.Tensor): Latent code (B x latent_dim x 3)
            D (torch.Tensor): Direction coordinates (B x num_rays x 3)
            equivariance (str): Type of equivariance to use. Options are 'None', 'SO2', and 'SO3'
            axis_of_invariance (int): The axis of rotation invariance. Should be 0 (x-axis), 1 (y-axis), or 2 (z-axis).
                Default is 1 (y-axis).

        Returns:
            torch.Tensor: Invariant representation
        """
        assert 0 <= axis_of_invariance < 3, "axis_of_invariance should be 0, 1, or 2."
        other_axes = [i for i in range(3) if i != axis_of_invariance]

        B, latent_dim, _ = Z.shape
        _, num_rays, _ = D.shape

        if equivariance == "None":
            # get inner product between latent code and direction coordinates
            innerprod = torch.sum(
                Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
            )  # [B, num_rays, latent_dim]
            z_input = (
                Z.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, latent_dim * 3)
            )  # [B, num_rays, latent_dim * 3]
            return innerprod, z_input

        if equivariance == "SO2":
            # Select components along axes orthogonal to the axis of invariance
            z_other = torch.stack(
                (Z[..., other_axes[0]], Z[..., other_axes[1]]), -1
            )  # [B, latent_dim, 2]
            d_other = torch.stack(
                (D[..., other_axes[0]], D[..., other_axes[1]]), -1
            ).unsqueeze(2)  # [B, num_rays, 1, 2]
            d_other = d_other.expand(
                B, num_rays, latent_dim, 2
            )  # size becomes [B, num_rays, latent_dim, 2]

            # Invariant representation of Z: the Gram matrix G = z_other @ z_other^T
            # has size [B, latent_dim, latent_dim]
            G = torch.bmm(z_other, torch.transpose(z_other, 1, 2))

            # Flatten G to size [B, latent_dim^2]
            z_other_invar = G.flatten(start_dim=1)

            # Get invariant component of Z along the axis of invariance
            z_invar = Z[..., axis_of_invariance]  # [B, latent_dim]

            # Inner product between the in-plane components of Z and D
            innerprod = (z_other.unsqueeze(1) * d_other).sum(
                dim=-1
            )  # [B, num_rays, latent_dim]

            # Compute norm along the axes orthogonal to the axis of invariance
            d_other_norm = torch.sqrt(
                D[..., other_axes[0]] ** 2 + D[..., other_axes[1]] ** 2
            ).unsqueeze(-1)  # [B, num_rays, 1]

            # Get invariant component of D along the axis of invariance
            d_invar = D[..., axis_of_invariance].unsqueeze(-1)  # [B, num_rays, 1]

            if not self.old_implementation:
                directional_input = torch.cat(
                    (innerprod, d_invar, d_other_norm), -1
                )  # [B, num_rays, latent_dim + 2]
                conditioning_input = (
                    torch.cat((z_other_invar, z_invar), -1)
                    .unsqueeze(1)
                    .expand(B, num_rays, latent_dim**2 + latent_dim)
                )  # [B, num_rays, latent_dim^2 + latent_dim]
            else:
                # this matches the previous implementation of RENI, needed when using old checkpoints
                z_other_invar = z_other_invar.unsqueeze(1).expand(B, num_rays, -1)
                z_invar = z_invar.unsqueeze(1).expand(B, num_rays, -1)
                return torch.cat(
                    (innerprod, z_other_invar, d_other_norm, z_invar, d_invar), -1
                )

            return directional_input, conditioning_input

        if equivariance == "SO3":
            G = Z @ torch.transpose(Z, 1, 2)  # [B, latent_dim, latent_dim]
            innerprod = torch.sum(
                Z.unsqueeze(1) * D.unsqueeze(2), dim=-1
            )  # [B, num_rays, latent_dim]
            z_invar = (
                G.flatten(start_dim=1).unsqueeze(1).expand(B, num_rays, -1)
            )  # [B, num_rays, latent_dim^2]
            return innerprod, z_invar
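
    # Why the Gram matrix is invariant (sketch): for any rotation R,
    # (Z R^T)(Z R^T)^T = Z R^T R Z^T = Z Z^T, so G is unchanged when the latent
    # vectors are rotated, while the Z-D inner products rotate equivariantly
    # with the directions.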

    def setup_network(self):
        """Sets up the network architecture"""
        base_input_dims = {
            "VN": {
                "None": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
                "SO2": {
                    "direction": self.latent_dim + 2,
                    "conditioning": self.latent_dim * 3,
                },
                "SO3": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
            },
            "GramMatrix": {
                "None": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim * 3,
                },
                "SO2": {
                    "direction": self.latent_dim + 2,
                    "conditioning": self.latent_dim**2 + self.latent_dim,
                },
                "SO3": {
                    "direction": self.latent_dim,
                    "conditioning": self.latent_dim**2,
                },
            },
        }

        # Extract the necessary input dimensions
        input_types = ["direction", "conditioning"]
        input_dims = {
            key: base_input_dims[self.cfg.invariant_function][self.cfg.equivariance][
                key
            ]
            for key in input_types
        }

        # Helper function to create NeRF encoding
        def create_nerf_encoding(in_dim):
            return NeRFEncoding(
                in_dim=in_dim,
                num_frequencies=2,
                min_freq_exp=0.0,
                max_freq_exp=2.0,
                include_input=True,
            )

        # Dictionary-based encoding setup
        encoding_setup = {
            "None": [],
            "Conditioning": ["conditioning"],
            "Directions": ["direction"],
            "Both": ["direction", "conditioning"],
        }

        # Setting up the required encodings
        for input_type in encoding_setup.get(self.cfg.encoded_input, []):
            # create self.{input_type}_encoding and update input_dims
            setattr(
                self,
                f"{input_type}_encoding",
                create_nerf_encoding(input_dims[input_type]),
            )
            input_dims[input_type] = getattr(
                self, f"{input_type}_encoding"
            ).get_out_dim()

        output_activation = get_activation_module(self.cfg.output_activation)

        network = None
        if self.conditioning == "Concat":
            network = Siren(
                in_dim=input_dims["direction"] + input_dims["conditioning"],
                hidden_layers=self.hidden_layers,
                hidden_features=self.hidden_features,
                out_dim=self.out_features,
                outermost_linear=self.last_layer_linear,
                first_omega_0=self.first_omega_0,
                hidden_omega_0=self.hidden_omega_0,
                out_activation=output_activation,
            )
        elif self.conditioning == "FiLM":
            network = FiLMSiren(
                in_dim=input_dims["direction"],
                hidden_layers=self.hidden_layers,
                hidden_features=self.hidden_features,
                mapping_network_in_dim=input_dims["conditioning"],
                mapping_network_layers=self.mapping_layers,
                mapping_network_features=self.mapping_features,
                out_dim=self.out_features,
                outermost_linear=True,
                out_activation=output_activation,
            )
        elif self.conditioning == "Attention":
            # transformer where K, V come from the conditioning input and Q from the
            # positionally encoded directional input
            network = Decoder(
                in_dim=input_dims["direction"],
                conditioning_input_dim=input_dims["conditioning"],
                hidden_features=self.cfg.hidden_features,
                num_heads=self.cfg.num_attention_heads,
                num_layers=self.cfg.num_attention_layers,
                out_activation=output_activation,
            )
        assert network is not None, "unknown conditioning type"
        return network
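
    # Dimension bookkeeping sketch for the defaults above (GramMatrix + SO2,
    # latent_dim=36, encoded_input="Directions"): the raw directional input has
    # 36 + 2 = 38 features, which the NeRF encoding (2 frequencies,
    # include_input=True) maps to 38 * 2 * 2 + 38 = 190; the conditioning input
    # stays at 36**2 + 36 = 1332, so the Concat Siren sees 190 + 1332 = 1522
    # input features.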

    def apply_positional_encoding(self, directional_input, conditioning_input):
        # apply the NeRF encoding to whichever inputs were configured
        if self.cfg.encoded_input == "Conditioning":
            conditioning_input = self.conditioning_encoding(
                conditioning_input
            )  # [B, num_rays, embedding_dim]
        elif self.cfg.encoded_input == "Directions":
            directional_input = self.direction_encoding(
                directional_input
            )  # [B, num_rays, embedding_dim]
        elif self.cfg.encoded_input == "Both":
            directional_input = self.direction_encoding(directional_input)
            conditioning_input = self.conditioning_encoding(conditioning_input)

        return directional_input, conditioning_input

    def get_outputs(
        self,
        rays_d: Float[Tensor, "batch num_rays 3"],  # type: ignore
        latent_codes: Float[Tensor, "batch_size latent_dim 3"],  # type: ignore
        rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,  # type: ignore
        scale: Optional[Float[Tensor, "batch_size"]] = None,  # type: ignore
    ) -> Dict[str, Tensor]:
        """Returns the outputs of the field.

        Args:
            rays_d: [B, num_rays, 3]
            latent_codes: [B, latent_dim, 3]
            rotation: [B, 3, 3]
            scale: [B]
        """
        if rotation is not None:
            if len(rotation.shape) == 3:  # [batch_size, 3, 3]
                # Rotate each latent vector: Z <- Z @ R^T, keeping [B, latent_dim, 3]
                latent_codes = torch.einsum(
                    "bik,blk->bli",
                    rotation,
                    latent_codes,
                )
            else:
                raise NotImplementedError(
                    "Unsupported rotation shape. Expected [batch_size, 3, 3]."
                )

        B, num_rays, _ = rays_d.shape
        _, latent_dim, _ = latent_codes.shape

        if not self.old_implementation:
            directional_input, conditioning_input = self.invariant_function(
                latent_codes,
                rays_d,
                equivariance=self.equivariance,
                axis_of_invariance=self.axis_of_invariance,
            )  # [B, num_rays, d_dir], [B, num_rays, d_cond]

            if self.cfg.positional_encoding == "NeRF":
                directional_input, conditioning_input = self.apply_positional_encoding(
                    directional_input, conditioning_input
                )

            if self.conditioning == "Concat":
                model_outputs = self.network(
                    torch.cat((directional_input, conditioning_input), dim=-1).reshape(
                        B * num_rays, -1
                    )
                ).view(B, num_rays, 3)  # returns -> [B, num_rays, 3]
            elif self.conditioning == "FiLM":
                model_outputs = self.network(
                    directional_input.reshape(B * num_rays, -1),
                    conditioning_input.reshape(B * num_rays, -1),
                ).view(B, num_rays, 3)  # returns -> [B, num_rays, 3]
            elif self.conditioning == "Attention":
                model_outputs = self.network(
                    directional_input.reshape(B * num_rays, -1),
                    conditioning_input.reshape(B * num_rays, -1),
                ).view(B, num_rays, 3)  # returns -> [B, num_rays, 3]
        else:
            # in the old implementation directions were sampled with y-up rather than
            # z-up, so swap the y and z components of the directions
            directions = torch.stack(
                (rays_d[..., 0], rays_d[..., 2], rays_d[..., 1]), -1
            )
            model_input = self.invariant_function(
                latent_codes,
                directions,
                equivariance=self.equivariance,
                axis_of_invariance=self.axis_of_invariance,
            )  # [B, num_rays, model_in_dim]
            model_outputs = self.network(model_input.view(B * num_rays, -1)).view(
                B, num_rays, 3
            )

        outputs = {}

        if scale is not None:
            scale = trunc_exp(scale)  # [B], exp to ensure positive
            model_outputs = model_outputs * scale.view(-1, 1, 1)  # [B, num_rays, 3]

        outputs["rgb"] = model_outputs

        return outputs

    def forward(
        self,
        rays_d: Float[Tensor, "batch num_rays 3"],  # type: ignore
        latent_codes: Float[Tensor, "batch_size latent_dim 3"],  # type: ignore
        rotation: Optional[Float[Tensor, "batch_size 3 3"]] = None,  # type: ignore
        scale: Optional[Float[Tensor, "batch_size"]] = None,  # type: ignore
    ) -> Dict[str, Tensor]:
        """Evaluates spherical field for a given ray bundle and rotation.

        Args:
            rays_d: [B, num_rays, 3]
            latent_codes: [B, latent_dim, 3]
            rotation: [B, 3, 3]
            scale: [B]

        Returns:
            Dict[str, Tensor]: A dictionary containing the outputs of the field.
        """
        return self.get_outputs(
            rays_d=rays_d,
            latent_codes=latent_codes,
            rotation=rotation,
            scale=scale,
        )
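
# End-to-end usage sketch (construction via the Config is an assumption about
# how BaseModule subclasses are instantiated; shapes follow the annotations):
#   field = RENIField(RENIField.Config())
#   latent_codes = torch.randn(2, field.cfg.latent_dim, 3)  # [B, latent_dim, 3]
#   rays_d = torch.nn.functional.normalize(torch.randn(2, 512, 3), dim=-1)
#   rgb = field(rays_d, latent_codes)["rgb"]  # [2, 512, 3]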