Spaces:
Paused
Paused
| # Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| import nvdiffrast.torch as dr | |
| import imageio | |
| #---------------------------------------------------------------------------- | |
| # Vector operations | |
| #---------------------------------------------------------------------------- | |
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the inner product of ``x`` and ``y`` taken over the last axis.

    The reduced axis is kept with size 1, so the result broadcasts against
    the inputs.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Mirror ``x`` about the direction ``n``: ``2 * dot(x, n) * n - x``.

    NOTE(review): the formula is only a true reflection when ``n`` is unit
    length; this function does not normalize it.
    """
    x_dot_n = dot(x, n)
    return 2 * x_dot_n * n - x
def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Return the Euclidean length of ``x`` along the last axis (size-1 axis kept).

    The squared length is clamped to at least ``eps`` before the square root.
    """
    squared = dot(x, x)
    # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
    clamped = torch.clamp(squared, min=eps)
    return torch.sqrt(clamped)
def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    """Divide ``x`` by its clamped length (see ``length``) along the last axis.

    ``eps`` is forwarded to ``length`` as the lower bound on the squared norm.
    """
    norm = length(x, eps)
    return x / norm
def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    """Append one extra component with value ``w`` to the end of the last axis.

    Typical use is lifting 3-vectors to homogeneous 4-vectors.
    """
    pad_fn = torch.nn.functional.pad
    # (0, 1): no padding before, one element after, on the last dimension.
    return pad_fn(x, pad=(0, 1), mode='constant', value=w)
| #---------------------------------------------------------------------------- | |
| # Tonemapping | |
| #---------------------------------------------------------------------------- | |
def tonemap_srgb(f: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB transfer curve element-wise to ``f``.

    Values above 0.0031308 take the gamma branch
    (``1.055 * f**(1/2.4) - 0.055``); the rest take the linear branch
    (``12.92 * f``).
    """
    cutoff = 0.0031308
    linear_branch = 12.92 * f
    # The clamp keeps pow() away from tiny/negative bases on entries that
    # torch.where will discard anyway.
    gamma_branch = torch.pow(torch.clamp(f, min=cutoff), 1.0 / 2.4) * 1.055 - 0.055
    return torch.where(f > cutoff, gamma_branch, linear_branch)
| #---------------------------------------------------------------------------- | |
| # sRGB color transforms | |
| #---------------------------------------------------------------------------- | |
def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Element-wise linear-to-sRGB transfer function (no channel handling)."""
    cutoff = 0.0031308
    # The clamp keeps pow() away from tiny/negative bases on entries that
    # torch.where discards.
    gamma_branch = torch.pow(torch.clamp(f, cutoff), 1.0 / 2.4) * 1.055 - 0.055
    return torch.where(f <= cutoff, f * 12.92, gamma_branch)


def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    """Convert linear RGB(A) values to sRGB.

    The last axis of ``f`` must have 3 (RGB) or 4 (RGBA) entries. With four
    channels, the fourth is passed through unchanged and only the first three
    are converted. Any number of leading dimensions is accepted.

    Returns a tensor with the same shape as ``f``.
    """
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    if f.shape[-1] == 4:
        out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _rgb_to_srgb(f)
    # Compare whole shapes: indexing shape[0..2] would raise IndexError for
    # inputs with fewer than three dimensions (e.g. an (N, 3) colour list).
    assert out.shape == f.shape
    return out
def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Element-wise sRGB-to-linear transfer function (no channel handling)."""
    cutoff = 0.04045
    # The clamp keeps pow() away from tiny/negative bases on entries that
    # torch.where discards.
    gamma_branch = torch.pow((torch.clamp(f, cutoff) + 0.055) / 1.055, 2.4)
    return torch.where(f <= cutoff, f / 12.92, gamma_branch)


def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    """Convert sRGB(A) values to linear RGB.

    The last axis of ``f`` must have 3 (RGB) or 4 (RGBA) entries. With four
    channels, the fourth is passed through unchanged and only the first three
    are converted. Any number of leading dimensions is accepted.

    Returns a tensor with the same shape as ``f``.
    """
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    if f.shape[-1] == 4:
        out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1)
    else:
        out = _srgb_to_rgb(f)
    # Compare whole shapes: indexing shape[0..2] would raise IndexError for
    # inputs with fewer than three dimensions (e.g. an (N, 3) colour list).
    assert out.shape == f.shape
    return out
| #---------------------------------------------------------------------------- | |
| # Displacement texture lookup | |
| #---------------------------------------------------------------------------- | |
def get_miplevels(texture: np.ndarray) -> float:
    """Return ``floor(log2(min(H, W)))`` for a texture whose first two axes are H and W."""
    height, width = texture.shape[0], texture.shape[1]
    smallest_side = height if height <= width else width
    return np.floor(np.log2(smallest_side))
| # TODO: Handle wrapping maybe | |
def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:
    """Sample a single HWC texture at ``coords`` using ``grid_sample``.

    ``coords`` hold (x, y) pairs in the last axis, expressed in [0, 1]; they
    are remapped to the [-1, 1] range that ``grid_sample`` expects. ``filter``
    is forwarded as the ``grid_sample`` interpolation mode.

    Returns the sampled values with the channel axis last.
    """
    # Add a batch dimension, then NHWC -> NCHW.
    batched_nchw = tex_map.unsqueeze(0).permute(0, 3, 1, 2)
    # Two leading singleton axes: batch and output height.
    grid = coords[None, None, ...] * 2 - 1
    sampled = torch.nn.functional.grid_sample(batched_nchw, grid, mode=filter, align_corners=False)
    # NCHW -> NHWC, then drop the two singleton axes added above.
    return sampled.permute(0, 2, 3, 1)[0, 0, ...]
| #---------------------------------------------------------------------------- | |
| # Image scaling | |
| #---------------------------------------------------------------------------- | |
def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    """Resize a single HWC image by delegating to ``scale_img_nhwc``.

    A batch axis is added for the call and removed from the result.
    """
    batched = x[None, ...]
    resized = scale_img_nhwc(batched, size, mag, min)
    return resized[0]
def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    """Resize a batch of NHWC images to ``size`` = (height, width).

    ``min`` is the ``interpolate`` mode used when the image is not enlarged in
    either dimension; ``mag`` is the mode used when both dimensions grow.
    Enlarging one dimension while shrinking the other is rejected.

    Returns a contiguous NHWC tensor.
    """
    in_h, in_w = x.shape[1], x.shape[2]
    is_minify = in_h >= size[0] and in_w >= size[1]
    is_magnify = in_h < size[0] and in_w < size[1]
    assert is_minify or is_magnify, "Trying to magnify image in one dimension and minify in the other"

    interpolate = torch.nn.functional.interpolate
    y = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    if is_minify:
        # Same condition as the assertion above. A strict ">" here would send
        # "one axis shrinks, the other is unchanged" to the magnification
        # filter, which subsamples instead of averaging.
        y = interpolate(y, size, mode=min)
    elif mag == 'bilinear' or mag == 'bicubic':
        y = interpolate(y, size, mode=mag, align_corners=True)
    else:
        y = interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor:
    """Average-pool a batch of NHWC images with kernel ``size``.

    ``size`` is forwarded as the ``avg_pool2d`` kernel size. Returns a
    contiguous NHWC tensor.
    """
    nchw = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    pooled = torch.nn.functional.avg_pool2d(nchw, size)
    return pooled.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
| #---------------------------------------------------------------------------- | |
| # Behaves similar to tf.segment_sum | |
| #---------------------------------------------------------------------------- | |
def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    """Sum rows of ``data`` that share a segment id (similar to tf.segment_sum).

    ``segment_ids`` is either one id per row of ``data`` (1-D) or a tensor
    with the same shape as ``data``. The result is created on ``data``'s
    device with ``data``'s dtype, so CPU and non-float32 inputs work.

    The number of output rows is the number of consecutive runs in
    ``segment_ids``.
    NOTE(review): this presumably expects ids sorted and numbered 0..K-1
    without gaps; other layouts can index past the end of the result.

    Returns a tensor of shape ``[num_segments, *data.shape[1:]]``.
    """
    num_segments = torch.unique_consecutive(segment_ids).shape[0]

    # Repeat each id across the trailing dimensions so it indexes every element.
    if segment_ids.dim() == 1:
        per_row = int(np.prod(tuple(data.shape[1:])))
        segment_ids = segment_ids.repeat_interleave(per_row).view(segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    out_shape = [num_segments] + list(data.shape[1:])
    result = torch.zeros(*out_shape, dtype=data.dtype, device=data.device)
    return result.scatter_add(0, segment_ids, data)
| #---------------------------------------------------------------------------- | |
| # Projection and transformation matrix helpers. | |
| #---------------------------------------------------------------------------- | |
def projection(x=0.1, n=1.0, f=50.0):
    """Build a 4x4 float32 perspective projection matrix.

    ``n`` and ``f`` are the near and far plane distances; ``x`` sets the scale
    of the first two diagonal entries (``n/x`` and ``-n/x``).
    """
    depth_span = f - n
    rows = [
        [n / x, 0, 0, 0],
        [0, n / -x, 0, 0],
        [0, 0, -(f + n) / depth_span, -(2 * f * n) / depth_span],
        [0, 0, -1, 0],
    ]
    return np.array(rows).astype(np.float32)
def translate(x, y, z):
    """Return the 4x4 float32 homogeneous matrix that translates by (x, y, z)."""
    rows = [
        [1, 0, 0, x],
        [0, 1, 0, y],
        [0, 0, 1, z],
        [0, 0, 0, 1],
    ]
    return np.array(rows).astype(np.float32)
def rotate_x(a):
    """Return a 4x4 float32 rotation about the x axis by angle ``a`` (radians).

    The y/z sub-block is ``[[c, s], [-s, c]]``.
    """
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = [
        [1, 0, 0, 0],
        [0, cos_a, sin_a, 0],
        [0, -sin_a, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows).astype(np.float32)
def rotate_y(a):
    """Return a 4x4 float32 rotation about the y axis by angle ``a`` (radians).

    The x/z sub-block is ``[[c, s], [-s, c]]``.
    """
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = [
        [cos_a, 0, sin_a, 0],
        [0, 1, 0, 0],
        [-sin_a, 0, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows).astype(np.float32)
def scale(s):
    """Return the 4x4 float32 homogeneous matrix for uniform scaling by ``s``."""
    rows = [
        [s, 0, 0, 0],
        [0, s, 0, 0],
        [0, 0, s, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows).astype(np.float32)
def lookAt(eye, at, up):
    """Build a 4x4 float32 view matrix for a camera at ``eye`` looking at ``at``.

    The matrix is ``rotation @ translation``: it first moves ``eye`` to the
    origin, then expresses the result in the camera basis (u, v, w), where
    ``w`` points from ``at`` towards ``eye``.

    ``eye`` and ``at`` must support element-wise subtraction (e.g. numpy arrays).
    """
    backward = eye - at
    w = backward / np.linalg.norm(backward)
    u = np.cross(up, w)
    u = u / np.linalg.norm(u)
    v = np.cross(w, u)

    move_to_origin = np.array([
        [1, 0, 0, -eye[0]],
        [0, 1, 0, -eye[1]],
        [0, 0, 1, -eye[2]],
        [0, 0, 0, 1],
    ]).astype(np.float32)
    camera_basis = np.array([
        [u[0], u[1], u[2], 0],
        [v[0], v[1], v[2], 0],
        [w[0], w[1], w[2], 0],
        [0, 0, 0, 1],
    ]).astype(np.float32)
    return np.matmul(camera_basis, move_to_origin)
def random_rotation_translation(t):
    """Return a random 4x4 matrix with orthonormal 3x3 block and random translation.

    The 3x3 block starts from Gaussian noise and is orthogonalized with two
    cross products followed by row normalization. Each translation component
    is drawn uniformly from [-t, t].
    """
    basis = np.random.normal(size=[3, 3])
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    row_norms = np.linalg.norm(basis, axis=1, keepdims=True)
    basis = basis / row_norms

    m = np.zeros((4, 4))
    m[:3, :3] = basis
    m[3, 3] = 1.0
    m[:3, 3] = np.random.uniform(-t, t, size=[3])
    return m
| #---------------------------------------------------------------------------- | |
| # Cosine sample around a vector N | |
| #---------------------------------------------------------------------------- | |
def cosine_sample(N : np.ndarray) -> np.ndarray:
    """Draw one cosine-weighted random direction on the hemisphere around ``N``.

    ``N`` is normalized internally. Randomness comes from two
    ``np.random.uniform`` draws (azimuth first, then the elevation variable).
    """
    # Construct local frame: pick the longer of two vectors perpendicular to
    # the normal as the tangent.
    normal = N / np.linalg.norm(N)
    cand_a = np.array([0, normal[2], -normal[1]])
    cand_b = np.array([-normal[2], 0, normal[0]])
    tangent = cand_a if np.dot(cand_a, cand_a) > np.dot(cand_b, cand_b) else cand_b
    tangent = tangent / np.linalg.norm(tangent)
    bitangent = np.cross(normal, tangent)
    bitangent = bitangent / np.linalg.norm(bitangent)

    # Cosine sampling in local frame.
    phi = 2.0 * np.pi * np.random.uniform()
    s = np.random.uniform()
    cos_theta = np.sqrt(s)
    sin_theta = np.sqrt(1.0 - s)

    # Cartesian vector in local space, then local to world.
    local_x = np.cos(phi) * sin_theta
    local_y = np.sin(phi) * sin_theta
    return tangent * local_x + bitangent * local_y + normal * cos_theta
| #---------------------------------------------------------------------------- | |
| # Cosine sampled light directions around the vector N | |
| #---------------------------------------------------------------------------- | |
def cosine_sample_texture(res, N : np.ndarray, device='cuda') -> torch.Tensor:
    """Draw a ``res`` x ``res`` grid of cosine-weighted directions around ``N``.

    Args:
        res: Side length of the output grid.
        N: Direction the hemisphere is centred on; normalized internally.
        device: Torch device for the random numbers and the result. Defaults
            to ``'cuda'``; pass ``'cpu'`` to run without a GPU.

    Returns:
        A float32 tensor of shape ``(res, res, 3)`` holding unit vectors.
    """
    # Construct local frame: pick the longer of two vectors perpendicular to
    # the normal as the tangent.
    N = N / np.linalg.norm(N)
    dx0 = np.array([0, N[2], -N[1]])
    dx1 = np.array([-N[2], 0, N[0]])
    dx = dx0 if np.dot(dx0, dx0) > np.dot(dx1, dx1) else dx1
    dx = dx / np.linalg.norm(dx)
    dy = np.cross(N, dx)
    dy = dy / np.linalg.norm(dy)

    X = torch.tensor(dx, dtype=torch.float32, device=device)
    Y = torch.tensor(dy, dtype=torch.float32, device=device)
    Z = torch.tensor(N, dtype=torch.float32, device=device)

    # Cosine sampling in local frame.
    phi = 2.0 * np.pi * torch.rand(res, res, 1, dtype=torch.float32, device=device)
    s = torch.rand(res, res, 1, dtype=torch.float32, device=device)
    costheta = torch.sqrt(s)
    sintheta = torch.sqrt(1.0 - s)

    # Cartesian vector in local space.
    x = torch.cos(phi) * sintheta
    y = torch.sin(phi) * sintheta
    z = costheta

    # Local to world.
    return X * x + Y * y + Z * z
| #---------------------------------------------------------------------------- | |
| # Bilinear downsample by 2x. | |
| #---------------------------------------------------------------------------- | |
def bilinear_downsample(x : torch.tensor) -> torch.Tensor:
    """Halve the resolution of an NHWC batch with a 4x4 [1, 3, 3, 1] kernel.

    Each channel is filtered independently (grouped convolution) with
    stride 2 and one pixel of zero padding.

    NOTE: a later definition in this file reuses the name
    ``bilinear_downsample`` (taking an extra ``spp`` argument), so that later
    definition is the one bound after import.
    """
    channels = x.shape[-1]
    taps = torch.tensor([1, 3, 3, 1], dtype=torch.float32, device=x.device)
    # Separable kernel: outer product of the 1-D taps, normalized to sum to 1.
    kernel = (taps[:, None] * taps[None, :]) / 64.0
    kernel = kernel.expand(channels, 1, 4, 4)
    nchw = x.permute(0, 3, 1, 2)
    filtered = torch.nn.functional.conv2d(nchw, kernel, padding=1, stride=2, groups=channels)
    return filtered.permute(0, 2, 3, 1)
| #---------------------------------------------------------------------------- | |
| # Bilinear downsample log(spp) steps | |
| #---------------------------------------------------------------------------- | |
def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:
    """Repeatedly halve the resolution of an NHWC batch, ``int(log2(spp))`` times.

    Each step replicate-pads the image by one pixel and applies a stride-2
    4x4 [1, 3, 3, 1] kernel per channel (grouped convolution).

    Returns a contiguous NHWC tensor.
    """
    functional = torch.nn.functional
    channels = x.shape[-1]
    taps = torch.tensor([1, 3, 3, 1], dtype=torch.float32, device=x.device)
    # Separable kernel: outer product of the 1-D taps, normalized to sum to 1.
    kernel = ((taps[:, None] * taps[None, :]) / 64.0).expand(channels, 1, 4, 4)

    img = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
    remaining = int(np.log2(spp))
    while remaining > 0:
        padded = functional.pad(img, (1, 1, 1, 1), mode='replicate')
        img = functional.conv2d(padded, kernel, padding=0, stride=2, groups=channels)
        remaining -= 1
    return img.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
| #---------------------------------------------------------------------------- | |
| # Image display function using OpenGL. | |
| #---------------------------------------------------------------------------- | |
# Window shared by all display_image calls; created on first use.
_glfw_window = None


def display_image(image, zoom=None, size=None, title=None): # HWC
    """Show an HWC image in a glfw/OpenGL window.

    Args:
        image: Array-like with shape (height, width, channels); 1, 2 or 3
            channels of dtype uint8 or float32 are supported.
        zoom: Optional integer factor; pixels are repeated ``zoom`` times
            along both axes.
        size: Optional target height in pixels; converted to an integer zoom
            of at least 1. Must not be combined with ``zoom``.
        title: Window title; defaults to 'Debug window'.

    Returns:
        False once the window has been asked to close, True otherwise.
    """
    # Imported inside the function, so OpenGL and glfw are only required when
    # this function is actually called.
    import OpenGL.GL as gl
    import glfw

    global _glfw_window

    # Zoom image if requested.
    image = np.asarray(image)
    if size is not None:
        assert zoom is None
        zoom = max(1, size // image.shape[0])
    if zoom is not None:
        image = image.repeat(zoom, axis=0).repeat(zoom, axis=1)
    height, width, channels = image.shape

    # Initialize window on first use, otherwise retitle/resize the existing one.
    window_title = 'Debug window' if title is None else title
    if _glfw_window is not None:
        glfw.make_context_current(_glfw_window)
        glfw.set_window_title(_glfw_window, window_title)
        glfw.set_window_size(_glfw_window, width, height)
    else:
        glfw.init()
        _glfw_window = glfw.create_window(width, height, window_title, None, None)
        glfw.make_context_current(_glfw_window)
        glfw.show_window(_glfw_window)
        glfw.swap_interval(0)

    # Update window.
    glfw.poll_events()
    gl.glClearColor(0, 0, 0, 1)
    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
    gl.glWindowPos2f(0, 0)
    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
    formats_by_channels = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}
    gl_format = formats_by_channels[channels]
    types_by_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}
    gl_dtype = types_by_dtype[image.dtype.name]
    # Rows are passed in reverse order (image[::-1]) so the first image row
    # ends up at the top of the window.
    gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1])
    glfw.swap_buffers(_glfw_window)

    return not glfw.window_should_close(_glfw_window)
| #---------------------------------------------------------------------------- | |
| # Image save helper. | |
| #---------------------------------------------------------------------------- | |
def save_image(fn, x : np.ndarray) -> None:
    """Write ``x`` to ``fn`` as an 8-bit image via ``imageio.imwrite``.

    Values are scaled by 255, rounded to the nearest integer and clipped to
    [0, 255] before the cast to uint8, so inputs are expected in [0, 1].

    Returns nothing; the previous ``-> np.ndarray`` annotation was incorrect.
    """
    quantized = np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)
    imageio.imwrite(fn, quantized)
def load_image(fn) -> np.ndarray:
    """Read an image with ``imageio.imread`` and return it as float32.

    float32 images (treated as HDR) are returned unchanged. Every other dtype
    is treated as LDR: cast to float32 and divided by 255.
    NOTE(review): the division by 255 is applied regardless of the source
    bit depth — confirm that only 8-bit LDR inputs are expected.
    """
    pixels = imageio.imread(fn)
    if pixels.dtype != np.float32:  # LDR image
        return pixels.astype(np.float32) / 255
    return pixels  # HDR image
| #---------------------------------------------------------------------------- | |
def time_to_text(x):
    """Format a duration in seconds as hours, minutes or seconds (two decimals).

    Durations strictly above 3600 s are shown in hours, strictly above 60 s
    in minutes, everything else in seconds.
    """
    units = ((3600, "%.2f h"), (60, "%.2f m"))
    for seconds_per_unit, template in units:
        if x > seconds_per_unit:
            return template % (x / seconds_per_unit)
    return "%.2f s" % x
| #---------------------------------------------------------------------------- | |
def checkerboard(width, repetitions) -> np.ndarray:
    """Build a grey checkerboard image with a leading batch axis.

    The pattern has ``2 * repetitions`` tiles per side, each
    ``width // repetitions // 2`` pixels wide, alternating between 0.66 and
    0.33. The same values are written to all three channels.

    Returns an array of shape ``(1, H, W, 3)``.
    """
    tilesize = int(width // repetitions // 2)
    even_row = [1, 0] * repetitions
    odd_row = [0, 1] * repetitions
    pattern = [even_row, odd_row] * repetitions
    tile = np.ones((tilesize, tilesize))
    check = np.kron(pattern, tile) * 0.33 + 0.33
    rgb = np.stack([check] * 3, axis=-1)
    return rgb[None, ...]