| import torch | |
| from pytorch3d.renderer.mesh.shader import ShaderBase | |
| from pytorch3d.renderer import ( | |
| SoftPhongShader, | |
| ) | |
class MultiOutputShader(ShaderBase):
    """Shade one rasterization pass into several per-pixel output buffers.

    ``forward`` returns a dict whose keys are the requested ``choices``
    (default: ``"rgb"``, ``"mask"``, ``"depth"``, ``"normal"``, ``"albedo"``,
    ``"ccm"``). Only the requested buffers are computed.

    Args:
        device: device passed to the internal ``SoftPhongShader``.
        cameras: forwarded to the internal ``SoftPhongShader``.
        lights: forwarded to the internal ``SoftPhongShader``.
        materials: forwarded to the internal ``SoftPhongShader``.
        ccm_scale: multiplier applied to interpolated vertex positions before
            they are mapped into the ``"ccm"`` buffer via
            ``(p * ccm_scale + 1) / 2``.
        choices: collection (list/tuple/set) of buffer names to produce;
            ``None`` means all of them.
        blend_params: optional blending parameters for the internal
            ``SoftPhongShader``; ``None`` keeps that shader's default.
    """

    def __init__(
        self,
        device,
        cameras,
        lights,
        materials,
        ccm_scale=1.0,
        choices=None,
        blend_params=None,
    ):
        super().__init__()
        self.device = device
        self.cameras = cameras
        self.lights = lights
        self.materials = materials
        self.ccm_scale = ccm_scale
        if choices is None:
            choices = ["rgb", "mask", "depth", "normal", "albedo", "ccm"]
        self.choices = choices
        self.phong_shader = SoftPhongShader(
            device=self.device,
            cameras=self.cameras,
            lights=self.lights,
            materials=self.materials,
            blend_params=blend_params,
        )

    @staticmethod
    def _interpolate_front_face(fragments, face_attrs):
        """Barycentrically interpolate per-vertex attributes of the front-most face.

        Args:
            fragments: rasterizer output; ``pix_to_face`` is indexed here as
                (N, H, W, K) and ``bary_coords`` as (N, H, W, K, 3).
            face_attrs: (F, 3, C) tensor holding an attribute for each of the
                three vertices of every packed face.

        Returns:
            Tuple ``(values, covered)``: ``values`` is (N, H, W, C) and is zero
            wherever the boolean (N, H, W) mask ``covered`` is False.
        """
        pix_to_face = fragments.pix_to_face[..., 0]
        covered = pix_to_face >= 0  # negative face index: no face at this pixel
        values = face_attrs.new_zeros((*pix_to_face.shape, face_attrs.shape[-1]))
        bary = fragments.bary_coords[..., 0, :][covered]  # (P, 3)
        verts = face_attrs[pix_to_face[covered]]  # (P, 3, C)
        values[covered] = (bary.unsqueeze(-1) * verts).sum(dim=-2).to(values.dtype)
        return values, covered

    def forward(self, fragments, meshes, **kwargs):
        """Render the requested buffers.

        Args:
            fragments: rasterizer output providing ``zbuf``, ``pix_to_face``
                and ``bary_coords``.
            meshes: the batch of meshes that was rasterized.
            **kwargs: forwarded unchanged to the internal ``SoftPhongShader``.

        Returns:
            dict mapping each requested name to a tensor:
              "rgb":    first three channels of the Phong shader output.
              "mask":   float tensor, 1.0 where the Phong alpha channel is > 0.
              "albedo": ``meshes.sample_textures`` for the front-most face
                        (index 0 along the faces-per-pixel dimension).
              "depth":  ``fragments.zbuf`` unchanged (N, H, W, K).
              "normal": (N, H, W, 3) unit vertex normals interpolated over the
                        front-most face; 0 where no face covers the pixel.
              "ccm":    (N, H, W, 3) front-most surface position p mapped by
                        (p * ccm_scale + 1) / 2; 0 where no face covers the
                        pixel.
        """
        choices = self.choices
        output = {}

        # "mask" is read from the Phong alpha channel, so the shader must run
        # whenever either of these two buffers is requested.
        rgba = None
        if "rgb" in choices or "mask" in choices:
            rgba = self.phong_shader(fragments, meshes, **kwargs)
        if "rgb" in choices:
            output["rgb"] = rgba[..., :3]
        if "mask" in choices:
            output["mask"] = (rgba[..., 3:4] > 0).float()

        if "albedo" in choices:
            # Keep only the front-most of the per-pixel faces.
            output["albedo"] = meshes.sample_textures(fragments)[..., 0, :]

        if "depth" in choices:
            output["depth"] = fragments.zbuf

        if "normal" in choices:
            face_normals = meshes.verts_normals_packed()[meshes.faces_packed()]
            normals, _ = self._interpolate_front_face(fragments, face_normals)
            # normalize() clamps the norm with a small epsilon, so uncovered
            # (all-zero) pixels and degenerate normals stay 0 instead of NaN.
            output["normal"] = torch.nn.functional.normalize(normals, dim=-1)

        if "ccm" in choices:
            face_verts = meshes.verts_packed()[meshes.faces_packed()]
            coords, covered = self._interpolate_front_face(fragments, face_verts)
            ccm = (coords * self.ccm_scale + 1) / 2
            # Uncovered pixels would otherwise map to (0 * scale + 1) / 2 = 0.5;
            # zero them so the background matches the "normal"/"mask" buffers.
            output["ccm"] = ccm.masked_fill(~covered.unsqueeze(-1), 0.0)

        return output