Spaces:
Running
on
L4
Running
on
L4
# -*- coding: utf-8 -*-
#
# @File:   inference.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by:   Haozhe Xie
# @Last Modified at:   2024-09-22 10:22:05
# @Email:  [email protected]
| import copy | |
| import cv2 | |
| import logging | |
| import math | |
| import numpy as np | |
| import torch | |
| import citydreamer.extensions.extrude_tensor | |
| import citydreamer.extensions.voxlib | |
# Global constants

# Per-category height values. Only "ROOF" (the number of voxel layers stacked
# on top of a building as its roof) is read by the code in this file.
HEIGHTS = {
    "ROAD": 4,
    "GREEN_LANDS": 8,
    "CONSTRUCTION": 10,
    "COAST_ZONES": 0,
    "ROOF": 1,
}

# Semantic class IDs used in the layout / segmentation maps.
CLASSES = {
    "NULL": 0,
    "ROAD": 1,
    "BLD_FACADE": 2,
    "GREEN_LANDS": 3,
    "CONSTRUCTION": 4,
    "COAST_ZONES": 5,
    "OTHERS": 6,
    "BLD_ROOF": 7,
}

# NOTE: IDs > BLD_INS_LABEL_MIN are reserved for building instances.
# A facade instance has the even ID 2k; its roof instance has the odd ID 2k - 1.
CONSTANTS = {
    "BLD_INS_LABEL_MIN": 10,
    "LAYOUT_N_CLASSES": 7,
    "LAYOUT_VOL_SIZE": 1536,
    "BUILDING_VOL_SIZE": 672,
    # EXTENDED_VOL_SIZE == LAYOUT_VOL_SIZE + 2 * BUILDING_VOL_SIZE
    "EXTENDED_VOL_SIZE": 2880,
    "LAYOUT_MAX_HEIGHT": 640,
    "GES_VFOV": 20,
    "GES_IMAGE_HEIGHT": 540,
    "GES_IMAGE_WIDTH": 960,
    "IMAGE_PADDING": 8,
    "N_VOXEL_INTERSECT_SAMPLES": 6,
}
def generate_city(fgm, bgm, hf, seg, cx, cy, radius, altitude, azimuth):
    """Render one perspective view of the city layout centred at ``(cx, cy)``.

    Args:
        fgm: Building (foreground) generator; forwarded to ``render``.
        bgm: Background generator. Must expose ``output_device`` and
            ``module.cfg.NETWORK.GANCRAFT.STYLE_DIM`` (presumably a
            DataParallel-style wrapper -- confirm against the caller).
        hf: 2-D height field, indexed ``[y, x]``.
        seg: 2-D semantic layout using the IDs in ``CLASSES``.
        cx: X coordinate (column) of the patch centre in ``hf`` / ``seg``.
        cy: Y coordinate (row) of the patch centre in ``hf`` / ``seg``.
        radius: Horizontal distance of the orbit camera from the layout centre.
        altitude: Camera height (used as the camera ``z``).
        azimuth: Orbit angle in degrees.

    Returns:
        The rendered frame as a ``uint8`` array laid out ``(H, W, 3)``.
        The generator output is mapped with ``x / 2 + 0.5`` and scaled by 255
        (i.e. it is assumed to lie in ``[-1, 1]``); no clipping is applied
        before the ``uint8`` cast.
    """
    cam_pos = get_orbit_camera_position(radius, altitude, azimuth)
    seg, building_stats = get_instance_seg_map(seg)

    # Generate latent codes
    logging.info("Generating latent codes ...")
    bg_z, building_zs = get_latent_codes(
        building_stats,
        bgm.module.cfg.NETWORK.GANCRAFT.STYLE_DIM,
        bgm.output_device,
    )

    # Generate the local patch of the height field and seg map.
    # (The original code performed this identical crop twice in a row.)
    part_hf, part_seg = get_part_hf_seg(hf, seg, cx, cy, CONSTANTS["EXTENDED_VOL_SIZE"])

    # Recalculate the building positions based on the current patch
    _building_stats = get_part_building_stats(part_seg, building_stats, cx, cy)

    # Generate the concatenated height field and seg. map tensor
    hf_seg = get_hf_seg_tensor(part_hf, part_seg, bgm.output_device)

    # Build seg_volume
    logging.info("Generating seg volume ...")
    seg_volume = get_seg_volume(part_hf, part_seg)

    logging.info("Rendering City Image ...")
    # The frame is rendered in a 5 x 5 grid of patches to bound peak memory.
    patch_size = (
        CONSTANTS["GES_IMAGE_HEIGHT"] // 5,
        CONSTANTS["GES_IMAGE_WIDTH"] // 5,
    )
    img = render(
        patch_size,
        seg_volume,
        hf_seg,
        cam_pos,
        bgm,
        fgm,
        _building_stats,
        bg_z,
        building_zs,
    )
    # (1, 3, H, W) tensor -> (H, W, 3) uint8 image
    img = img.cpu().numpy().squeeze().transpose((1, 2, 0))
    return ((img / 2 + 0.5) * 255).astype(np.uint8)
def get_orbit_camera_position(radius, altitude, azimuth):
    """Return the camera position on a circular orbit around the layout centre.

    Args:
        radius: Horizontal distance from the layout centre.
        altitude: Camera height; returned unchanged as ``z``.
        azimuth: Orbit angle in degrees.

    Returns:
        A dict with the keys ``"x"``, ``"y"`` and ``"z"``.
    """
    # The orbit is centred on the middle of the (square) layout volume.
    center = CONSTANTS["LAYOUT_VOL_SIZE"] // 2
    angle = np.deg2rad(azimuth)
    return {
        "x": center + radius * math.cos(angle),
        "y": center + radius * math.sin(angle),
        "z": altitude,
    }
def get_instance_seg_map(seg_map):
    """Turn a semantic layout into one with per-building instance IDs.

    Constructions are treated as buildings. Every 4-connected building region
    with OpenCV component label ``i`` (``i >= 1``) receives the even instance
    ID ``(i + BLD_INS_LABEL_MIN) * 2``; the odd ID one below it is reserved
    for the matching roof.

    Args:
        seg_map: 2-D NumPy array of ``CLASSES`` IDs. It is NOT modified
            (the original implementation overwrote it in place, which erased
            all buildings from the caller's array).

    Returns:
        A tuple ``(instance_map, stats)``. ``instance_map`` is an ``int32``
        array in which building pixels hold instance IDs and all other pixels
        keep their class ID. ``stats`` holds the first four columns of the
        OpenCV component statistics (left, top, width, height); row ``i``
        belongs to component label ``i`` and row 0 is the background.
    """
    # Work on a private copy so the caller's layout stays usable.
    seg_map = seg_map.copy()
    # Mapping constructions to buildings
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLD_FACADE"]
    facade_mask = seg_map == CLASSES["BLD_FACADE"]

    # Use connected components to get building instances
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        facade_mask.astype(np.uint8), connectivity=4
    )
    # Remove non-building instance masks
    labels[~facade_mask] = 0
    # Building instance mask
    building_mask = labels != 0

    # The label image is int32: make sure the doubled IDs cannot overflow.
    # This is checked on Python ints BEFORE the multiplication, because a
    # check on the already-multiplied int32 array could never fail.
    assert (n_labels + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2 < 2147483648

    # Make building instance IDs even numbers above BLD_INS_LABEL_MIN.
    # Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
    labels = (labels + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2

    # Building pixels take their instance ID, everything else keeps its class.
    seg_map = np.where(building_mask, labels, seg_map)
    return seg_map.astype(np.int32), stats[:, :4]
def get_latent_codes(building_stats, bg_style_dim, output_device):
    """Sample the style codes for the background and for every building.

    Args:
        building_stats: Per-component statistics; only its length is used.
            Index ``i`` maps to the instance ID ``(i + BLD_INS_LABEL_MIN) * 2``.
        bg_style_dim: Dimension of the background code (``None`` -> no code).
        output_device: Device on which the codes are created.

    Returns:
        ``(bg_z, building_zs)`` where ``building_zs`` maps an instance ID to
        its latent code.
    """
    # The background code is drawn first, then one code per stats row in order.
    bg_z = _get_z(output_device, bg_style_dim)
    building_zs = {}
    for idx in range(len(building_stats)):
        instance_id = (idx + CONSTANTS["BLD_INS_LABEL_MIN"]) * 2
        building_zs[instance_id] = _get_z(output_device)
    return bg_z, building_zs
| def _get_z(device, z_dim=256): | |
| if z_dim is None: | |
| return None | |
| return torch.randn(1, z_dim, dtype=torch.float32, device=device) | |
def get_part_hf_seg(hf, seg, cx, cy, patch_size):
    """Crop the same square window out of the height field and the seg map.

    Args:
        hf: 2-D height field.
        seg: 2-D segmentation map.
        cx: Window centre, x / column coordinate.
        cy: Window centre, y / row coordinate.
        patch_size: Side length of the window.

    Returns:
        ``(part_hf, part_seg)``, both of shape ``(patch_size, patch_size)``.

    Raises:
        AssertionError: If a crop does not have the expected shape, e.g.
            because the window is not fully inside the source image.
    """
    part_hf, part_seg = [
        _get_image_patch(image, cx, cy, patch_size) for image in (hf, seg)
    ]
    expected_shape = (patch_size, patch_size)
    assert part_hf.shape == expected_shape, part_hf.shape
    assert part_hf.shape == part_seg.shape, part_seg.shape
    return part_hf, part_seg
| def _get_image_patch(image, cx, cy, patch_size): | |
| sx = cx - patch_size // 2 | |
| sy = cy - patch_size // 2 | |
| ex = sx + patch_size | |
| ey = sy + patch_size | |
| return image[sy:ey, sx:ex] | |
def get_part_building_stats(part_seg, building_stats, cx, cy):
    """Compute, for each building in the patch, its centre relative to (cy, cx).

    Args:
        part_seg: Instance segmentation patch; values above
            ``BLD_INS_LABEL_MIN`` are treated as building instance IDs.
        building_stats: 2-D array of per-component statistics whose columns
            0..3 are read as x, y, width and height (presumably the OpenCV
            bounding-box stats -- confirm against the producer).
        cx: X coordinate of the patch centre in the full layout.
        cy: Y coordinate of the patch centre in the full layout.

    Returns:
        A dict mapping each instance ID found in ``part_seg`` to
        ``[dy, dx]``, the bounding-box centre minus the patch centre.
    """
    min_label = CONSTANTS["BLD_INS_LABEL_MIN"]
    centers = {}
    for instance_id in np.unique(part_seg[part_seg > min_label]):
        # Invert ``instance_id = (row + BLD_INS_LABEL_MIN) * 2``.
        row = instance_id // 2 - min_label
        dy = building_stats[row, 1] - cy + building_stats[row, 3] / 2
        dx = building_stats[row, 0] - cx + building_stats[row, 2] / 2
        centers[instance_id] = [dy, dx]
    return centers
def get_hf_seg_tensor(part_hf, part_seg, output_device):
    """Build the network input: normalised height + one-hot segmentation.

    Args:
        part_hf: 2-D NumPy height field.
        part_seg: 2-D NumPy segmentation map.
        output_device: Device the tensors are moved to.

    Returns:
        A ``(1, 1 + LAYOUT_N_CLASSES, H, W)`` tensor. Channel 0 is the height
        divided by ``LAYOUT_MAX_HEIGHT``; the remaining channels are the
        one-hot encoding of the classes ``0 .. LAYOUT_N_CLASSES - 1``.
    """
    heights = torch.from_numpy(part_hf[None, None, ...]).to(output_device)
    heights = heights / CONSTANTS["LAYOUT_MAX_HEIGHT"]

    # (1, H, W) label map expected by the one-hot helper.
    labels = torch.from_numpy(part_seg[None, ...]).to(output_device)
    onehots = _masks_to_onehots(labels, CONSTANTS["LAYOUT_N_CLASSES"])
    return torch.cat([heights, onehots], dim=1)
| def _masks_to_onehots(masks, n_class, ignored_classes=[]): | |
| b, h, w = masks.shape | |
| n_class_actual = n_class - len(ignored_classes) | |
| one_hot_masks = torch.zeros( | |
| (b, n_class_actual, h, w), dtype=torch.float32, device=masks.device | |
| ) | |
| n_class_cnt = 0 | |
| for i in range(n_class): | |
| if i not in ignored_classes: | |
| one_hot_masks[:, n_class_cnt] = masks == i | |
| n_class_cnt += 1 | |
| return one_hot_masks | |
def get_seg_volume(part_hf, part_seg):
    """Extrude the 2-D layout into a 3-D label volume and add roof voxels.

    Args:
        part_hf: 2-D NumPy height field, either ``EXTENDED_VOL_SIZE`` or
            ``LAYOUT_VOL_SIZE`` pixels square.
        part_seg: 2-D NumPy instance segmentation map of the same shape.

    Returns:
        The label volume as a CUDA tensor (requires a CUDA device). The
        original author notes its size as ``(1536, 1536, 640)``.

    Raises:
        AssertionError: If the (cropped) inputs are not
            ``LAYOUT_VOL_SIZE`` x ``LAYOUT_VOL_SIZE`` or differ in shape.
    """
    layout_size = CONSTANTS["LAYOUT_VOL_SIZE"]
    extended_size = CONSTANTS["EXTENDED_VOL_SIZE"]
    margin = CONSTANTS["BUILDING_VOL_SIZE"]
    tensor_extruder = citydreamer.extensions.extrude_tensor.TensorExtruder(
        CONSTANTS["LAYOUT_MAX_HEIGHT"]
    )

    # An extended patch carries a BUILDING_VOL_SIZE border on every side;
    # strip it so that only the central layout area is extruded.
    if part_hf.shape == (extended_size, extended_size):
        inner = slice(margin, -margin)
        part_hf = part_hf[inner, inner]
        part_seg = part_seg[inner, inner]

    assert part_hf.shape == (layout_size, layout_size)
    assert part_hf.shape == part_seg.shape, part_seg.shape

    seg_tensor = torch.from_numpy(part_seg[None, None, ...]).cuda()
    hf_tensor = torch.from_numpy(part_hf[None, None, ...]).cuda()
    seg_volume = tensor_extruder(seg_tensor, hf_tensor).squeeze()
    logging.debug("The shape of SegVolume: %s", seg_volume.size())

    # Change the top-level voxel of the "Building Facade" to "Building Roof".
    # Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
    roof_seg_map = part_seg - 1
    # Pixels that are not building instances contribute label 0.
    roof_seg_map[part_seg <= CONSTANTS["BLD_INS_LABEL_MIN"]] = 0
    roof_src = torch.from_numpy(roof_seg_map[..., None]).cuda()

    # Write the roof labels into the HEIGHTS["ROOF"] layers above the height field.
    for rh in range(1, HEIGHTS["ROOF"] + 1):
        roof_index = torch.from_numpy(part_hf[..., None] + rh).long().cuda()
        seg_volume = seg_volume.scatter_(dim=2, index=roof_index, src=roof_src)

    return seg_volume
def get_voxel_intersection_perspective(seg_volume, camera_location):
    """Cast perspective camera rays through the label volume.

    The camera looks from ``camera_location`` towards the centre of the
    volume at height 0.

    Args:
        seg_volume: 3-D label volume tensor; its device is used for the
            camera origin and direction tensors.
        camera_location: Dict with the keys ``"x"``, ``"y"`` and ``"z"``.

    Returns:
        ``(voxel_id, depth2, raydirs, cam_origin)``, each with a leading
        batch dimension of size 1 added. ``depth2`` is additionally permuted
        with ``(1, 2, 0, 3, 4)`` before the batch dimension is added.
    """
    img_h = CONSTANTS["GES_IMAGE_HEIGHT"]
    img_w = CONSTANTS["GES_IMAGE_WIDTH"]
    device = seg_volume.device
    camera_focal = img_h / 2 / np.tan(np.deg2rad(CONSTANTS["GES_VFOV"]))

    # Look-at target: centre of the volume in the horizontal plane.
    target_x = seg_volume.size(1) // 2 - 1
    target_y = seg_volume.size(0) // 2 - 1

    # NOTE: vectors handed to voxlib are ordered (y, x, z).
    cam_origin = torch.tensor(
        [camera_location["y"], camera_location["x"], camera_location["z"]],
        dtype=torch.float32,
        device=device,
    )
    cam_dir = torch.tensor(
        [
            target_y - camera_location["y"],
            target_x - camera_location["x"],
            -camera_location["z"],
        ],
        dtype=torch.float32,
        device=device,
    )
    # Up vector; created on the default device, exactly as in the original.
    cam_up = torch.tensor([0, 0, 1], dtype=torch.float32)
    principal_point = [(img_h - 1) / 2.0, (img_w - 1) / 2.0]

    voxlib = citydreamer.extensions.voxlib
    voxel_id, depth2, raydirs = voxlib.ray_voxel_intersection_perspective(
        seg_volume,
        cam_origin,
        cam_dir,
        cam_up,
        # 2.06 is an empirical scale carried over from the original code.
        camera_focal * 2.06,
        principal_point,
        [img_h, img_w],
        CONSTANTS["N_VOXEL_INTERSECT_SAMPLES"],
    )
    return (
        voxel_id.unsqueeze(dim=0),
        depth2.permute(1, 2, 0, 3, 4).unsqueeze(dim=0),
        raydirs.unsqueeze(dim=0),
        cam_origin.unsqueeze(dim=0),
    )
def _get_pad_img_bbox(sx, ex, sy, ey):
    """Grow a patch bbox by ``IMAGE_PADDING`` on every side not on the image border.

    Args:
        sx: Patch start column.
        ex: Patch end column (exclusive).
        sy: Patch start row.
        ey: Patch end row (exclusive).

    Returns:
        ``(psx, pex, psy, pey)``, the padded bbox. A side that lies exactly
        on the image border (0, ``GES_IMAGE_WIDTH`` or ``GES_IMAGE_HEIGHT``)
        is left unchanged.
    """
    pad = CONSTANTS["IMAGE_PADDING"]
    img_w = CONSTANTS["GES_IMAGE_WIDTH"]
    img_h = CONSTANTS["GES_IMAGE_HEIGHT"]

    psx = 0 if sx == 0 else sx - pad
    psy = 0 if sy == 0 else sy - pad
    pex = img_w if ex == img_w else ex + pad
    pey = img_h if ey == img_h else ey + pad
    return psx, pex, psy, pey
def _get_img_without_pad(img, sx, ex, sy, ey, psx, pex, psy, pey):
    """Strip the padding added by ``_get_pad_img_bbox`` from a rendered patch.

    Args:
        img: 4-D patch whose last two axes are sliced as rows and columns;
            presumably rendered for the padded bbox -- confirm with the caller.
        sx, ex, sy, ey: Unpadded patch bbox.
        psx, pex, psy, pey: Padded patch bbox.

    Returns:
        ``img`` unchanged when ``IMAGE_PADDING`` is 0, otherwise the slice of
        ``img`` without the leading / trailing padding.
    """
    if CONSTANTS["IMAGE_PADDING"] == 0:
        return img

    # With trailing padding the stop index is negative (``ey - pey``), which
    # drops the padding. Without it the stop is the absolute ``ey``; this
    # relies on ``ey`` not being smaller than the extent of ``img`` so that
    # the slice runs to the end of the axis.
    row_stop = ey if ey == pey else ey - pey
    col_stop = ex if ex == pex else ex - pex
    return img[:, :, sy - psy : row_stop, sx - psx : col_stop]
def render_bg(
    patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, z
):
    """Render the background image patch by patch.

    Args:
        patch_size: ``(height, width)`` of one render patch.
        gancraft_bg: Background generator; called once per patch and
            expected to expose ``output_device``.
        hf_seg: Height/seg tensor whose last two dimensions are
            ``EXTENDED_VOL_SIZE``.
        voxel_id: Per-pixel voxel labels from the ray caster.
        depth2: Per-pixel depth data from the ray caster.
        raydirs: Per-pixel ray directions.
        cam_origin: Camera origin tensor.
        z: Background latent code.

    Returns:
        A ``(1, 3, GES_IMAGE_HEIGHT, GES_IMAGE_WIDTH)`` float32 tensor.
    """
    extended_size = CONSTANTS["EXTENDED_VOL_SIZE"]
    layout_size = CONSTANTS["LAYOUT_VOL_SIZE"]
    img_h = CONSTANTS["GES_IMAGE_HEIGHT"]
    img_w = CONSTANTS["GES_IMAGE_WIDTH"]
    patch_h, patch_w = patch_size[0], patch_size[1]

    assert hf_seg.size(2) == extended_size
    assert hf_seg.size(3) == extended_size
    # Drop the BUILDING_VOL_SIZE border; the background only needs the layout area.
    inner = slice(
        CONSTANTS["BUILDING_VOL_SIZE"], -CONSTANTS["BUILDING_VOL_SIZE"]
    )
    hf_seg = hf_seg[:, :, inner, inner]
    assert hf_seg.size(2) == layout_size
    assert hf_seg.size(3) == layout_size

    # Imported here rather than at module level as a workaround for
    # "operator torchvision::nms does not exist".
    import torchvision

    blurrer = torchvision.transforms.GaussianBlur(kernel_size=3, sigma=(2, 2))

    # For the background every building instance is just a generic facade.
    _voxel_id = copy.deepcopy(voxel_id)
    _voxel_id[voxel_id >= CONSTANTS["BLD_INS_LABEL_MIN"]] = CLASSES["BLD_FACADE"]
    assert (_voxel_id < CONSTANTS["LAYOUT_N_CLASSES"]).all()

    bg_img = torch.zeros(
        1,
        3,
        img_h,
        img_w,
        dtype=torch.float32,
        device=gancraft_bg.output_device,
    )
    # Render background patches by patch to avoid OOM
    for row in range(img_h // patch_h):
        for col in range(img_w // patch_w):
            sy, sx = row * patch_h, col * patch_w
            ey, ex = sy + patch_h, sx + patch_w
            # Render a slightly larger window; the padding is cropped away below.
            psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)

            output_bg = gancraft_bg(
                hf_seg=hf_seg,
                voxel_id=_voxel_id[:, psy:pey, psx:pex],
                depth2=depth2[:, psy:pey, psx:pex],
                raydirs=raydirs[:, psy:pey, psx:pex],
                cam_origin=cam_origin,
                building_stats=None,
                z=z,
                deterministic=True,
            )
            # Make road blurry
            is_road = _voxel_id[:, None, psy:pey, psx:pex, 0, 0] == CLASSES["ROAD"]
            road_mask = is_road.repeat(1, 3, 1, 1).float()
            output_bg = blurrer(output_bg) * road_mask + output_bg * (1 - road_mask)

            bg_img[:, :, sy:ey, sx:ex] = _get_img_without_pad(
                output_bg, sx, ex, sy, ey, psx, pex, psy, pey
            )

    return bg_img
def render_fg(
    patch_size,
    gancraft_fg,
    building_id,
    hf_seg,
    voxel_id,
    depth2,
    raydirs,
    cam_origin,
    building_stats,
    building_z,
):
    """Render a single building instance patch by patch.

    Args:
        patch_size: ``(height, width)`` of one render patch.
        gancraft_fg: Building generator; called once per patch that contains
            the building and expected to expose ``output_device``.
        building_id: Even facade instance ID; ``building_id - 1`` is the roof.
        hf_seg: Height/seg tensor of the extended patch, shape ``(1, C, H, W)``.
        voxel_id: Per-pixel voxel labels from the ray caster.
        depth2: Per-pixel depth data from the ray caster.
        raydirs: Per-pixel ray directions.
        cam_origin: Camera origin tensor.
        building_stats: ``[dy, dx]`` offset of the building centre.
        building_z: Latent code of this building.

    Returns:
        ``(fg_img, fg_mask)``: a ``(1, 3, H, W)`` image that is zero outside
        the building and a ``(1, 1, H, W)`` float mask of the pixels covered
        by the building's facade or roof.
    """
    img_h = CONSTANTS["GES_IMAGE_HEIGHT"]
    img_w = CONSTANTS["GES_IMAGE_WIDTH"]
    patch_h, patch_w = patch_size[0], patch_size[1]

    # Keep only the voxels of the current building, relabelled to the generic
    # facade / roof class IDs.
    _voxel_id = voxel_id.clone()
    _curr_bld = torch.tensor([building_id, building_id - 1], device=voxel_id.device)
    _voxel_id[~torch.isin(_voxel_id, _curr_bld)] = 0
    _voxel_id[voxel_id == building_id] = CLASSES["BLD_FACADE"]
    _voxel_id[voxel_id == building_id - 1] = CLASSES["BLD_ROOF"]

    # Zero the rays that do not hit the current building.
    _raydirs = raydirs.clone()
    _raydirs[_voxel_id[..., 0, 0] == 0] = 0

    # Crop the "hf_seg" image using the center of the target building as the reference.
    # NOTE: The original code first built a masked deep copy of ``hf_seg`` and
    # then discarded it by reassigning the variable to this crop of the
    # unmasked tensor. That dead copy is removed; the generator input is
    # unchanged.
    # NOTE: The crop is not bounds-checked (the original guard was commented
    # out with the remark "THIS SHOULD NEVER HAPPEN AGAIN").
    half_size = CONSTANTS["BUILDING_VOL_SIZE"] // 2
    crop_cx = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[1])
    crop_cy = CONSTANTS["EXTENDED_VOL_SIZE"] // 2 - int(building_stats[0])
    _hf_seg = hf_seg[
        :,
        :,
        crop_cy - half_size : crop_cy + half_size,
        crop_cx - half_size : crop_cx + half_size,
    ]

    fg_img = torch.zeros(
        1,
        3,
        img_h,
        img_w,
        dtype=torch.float32,
        device=gancraft_fg.output_device,
    )
    fg_mask = torch.zeros(
        1,
        1,
        img_h,
        img_w,
        dtype=torch.float32,
        device=gancraft_fg.output_device,
    )
    # Loop-invariant: the same building offset is passed for every patch.
    stats_tensor = torch.from_numpy(np.array(building_stats)).unsqueeze(dim=0)

    # Render foreground patches by patch to avoid OOM
    for row in range(img_h // patch_h):
        for col in range(img_w // patch_w):
            sy, sx = row * patch_h, col * patch_w
            ey, ex = sy + patch_h, sx + patch_w
            # Skip patches that no ray of this building passes through.
            if torch.count_nonzero(_raydirs[:, sy:ey, sx:ex]) == 0:
                continue

            psx, pex, psy, pey = _get_pad_img_bbox(sx, ex, sy, ey)
            output_fg = gancraft_fg(
                _hf_seg,
                _voxel_id[:, psy:pey, psx:pex],
                depth2[:, psy:pey, psx:pex],
                _raydirs[:, psy:pey, psx:pex],
                cam_origin,
                building_stats=stats_tensor,
                z=building_z,
                deterministic=True,
            )
            patch_img = _get_img_without_pad(
                output_fg, sx, ex, sy, ey, psx, pex, psy, pey
            )

            first_hit = voxel_id[:, sy:ey, sx:ex, 0, 0]
            facade_mask = (first_hit == building_id).unsqueeze(dim=1)
            roof_mask = (first_hit == building_id - 1).unsqueeze(dim=1)
            # The two masks are disjoint, so masking with their union equals
            # the original ``facade_img * facade_mask + roof_img * roof_mask``.
            bld_mask = torch.logical_or(facade_mask, roof_mask)

            fg_mask[:, :, sy:ey, sx:ex] = bld_mask
            fg_img[:, :, sy:ey, sx:ex] = patch_img * bld_mask

    return fg_img, fg_mask
def render(
    patch_size,
    seg_volume,
    hf_seg,
    cam_pos,
    gancraft_bg,
    gancraft_fg,
    building_stats,
    bg_z,
    building_zs,
):
    """Render the background and composite every visible building over it.

    Args:
        patch_size: ``(height, width)`` of one render patch.
        seg_volume: 3-D label volume.
        hf_seg: Height/seg tensor of the extended patch.
        cam_pos: Camera position dict (``"x"``, ``"y"``, ``"z"``).
        gancraft_bg: Background generator.
        gancraft_fg: Building generator.
        building_stats: Mapping from instance ID to the building offset.
        bg_z: Background latent code.
        building_zs: Mapping from instance ID to the building latent code.

    Returns:
        The composited image tensor.
    """
    voxel_id, depth2, raydirs, cam_origin = get_voxel_intersection_perspective(
        seg_volume, cam_pos
    )
    visible_ids = torch.unique(voxel_id[voxel_id > CONSTANTS["BLD_INS_LABEL_MIN"]])
    # Remove odd numbers from the list because they are reserved by roofs.
    facade_ids = visible_ids[visible_ids % 2 == 0]

    with torch.no_grad():
        canvas = render_bg(
            patch_size, gancraft_bg, hf_seg, voxel_id, depth2, raydirs, cam_origin, bg_z
        )
        for facade_id in facade_ids:
            instance_id = facade_id.item()
            assert instance_id % 2 == 0, "Building Instance ID MUST be an even number."
            fg_img, fg_mask = render_fg(
                patch_size,
                gancraft_fg,
                instance_id,
                hf_seg,
                voxel_id,
                depth2,
                raydirs,
                cam_origin,
                building_stats[instance_id],
                building_zs[instance_id],
            )
            # Paste the building over whatever has been rendered so far.
            canvas = canvas * (1 - fg_mask) + fg_img * fg_mask

    return canvas