diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py
index e12f4b6803..8070561787 100644
--- a/gaussian_renderer/__init__.py
+++ b/gaussian_renderer/__init__.py
@@ -15,7 +15,7 @@ from scene.gaussian_model import GaussianModel
 from utils.sh_utils import eval_sh
 
-def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False):
+def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False, orthographic=True):
     """
     Render the scene.
 
@@ -33,21 +33,48 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
 
+    # raster_settings = GaussianRasterizationSettings(
+    #     image_height=int(viewpoint_camera.image_height),
+    #     image_width=int(viewpoint_camera.image_width),
+    #     tanfovx=tanfovx,
+    #     tanfovy=tanfovy,
+    #     bg=bg_color,
+    #     scale_modifier=scaling_modifier,
+    #     viewmatrix=viewpoint_camera.world_view_transform,
+    #     projmatrix=viewpoint_camera.full_proj_transform,
+    #     sh_degree=pc.active_sh_degree,
+    #     campos=viewpoint_camera.camera_center,
+    #     prefiltered=False,
+    #     debug=pipe.debug,
+    #     antialiasing=pipe.antialiasing
+    # )
+
+    # Set up rasterization configuration
+    if not orthographic:
+        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+        full_proj_transform = viewpoint_camera.get_full_proj_transform(orthographic)
+    else:
+        tanfovx, tanfovy, full_proj_transform = viewpoint_camera.get_full_proj_transform(orthographic)
+
     raster_settings = GaussianRasterizationSettings(
-        image_height=int(viewpoint_camera.image_height),
-        image_width=int(viewpoint_camera.image_width),
-        tanfovx=tanfovx,
-        tanfovy=tanfovy,
-        bg=bg_color,
-        scale_modifier=scaling_modifier,
-        viewmatrix=viewpoint_camera.world_view_transform,
-        projmatrix=viewpoint_camera.full_proj_transform,
-        sh_degree=pc.active_sh_degree,
-        campos=viewpoint_camera.camera_center,
-        prefiltered=False,
-        debug=pipe.debug,
-        antialiasing=pipe.antialiasing
-    )
+        image_height=int(viewpoint_camera.image_height),
+        image_width=int(viewpoint_camera.image_width),
+        tanfovx=tanfovx,
+        tanfovy=tanfovy,
+        bg=bg_color,
+        scale_modifier=scaling_modifier,
+        viewmatrix=viewpoint_camera.world_view_transform,
+        projmatrix=full_proj_transform,
+        sh_degree=pc.active_sh_degree,
+        campos=viewpoint_camera.camera_center,
+        prefiltered=False,
+        debug=pipe.debug,
+        antialiasing=pipe.antialiasing,
+        orthographic=orthographic
+    )
 
     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
 
@@ -71,6 +98,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
     shs = None
     colors_precomp = None
+    dc = None
     if override_color is None:
         if pipe.convert_SHs_python:
             shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
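For reference, the projection selection above boils down to the following. This is a minimal, self-contained sketch; StubCamera is a hypothetical stand-in for scene.cameras.Camera and is not part of the patch. In perspective mode render() still derives tanfovx/tanfovy from the FoV, while in orthographic mode the camera returns the view-volume half-extents in their place, so GaussianRasterizationSettings keeps its existing fields apart from the new orthographic flag.

import math
import torch

class StubCamera:
    # Hypothetical stand-in for scene.cameras.Camera (illustration only).
    FoVx, FoVy = math.radians(60), math.radians(45)
    full_proj_transform = torch.eye(4)  # placeholder perspective transform

    def get_full_proj_transform(self, orthographic=False):
        if not orthographic:
            return self.full_proj_transform  # a single 4x4 tensor
        # Orthographic mode also returns the half-extents that replace
        # the FoV tangents expected by the rasterizer (dummy values here).
        return 2.0, 1.5, torch.eye(4)

cam, orthographic = StubCamera(), True
if not orthographic:
    tanfovx = math.tan(cam.FoVx * 0.5)
    tanfovy = math.tan(cam.FoVy * 0.5)
    full_proj_transform = cam.get_full_proj_transform(orthographic)
else:
    tanfovx, tanfovy, full_proj_transform = cam.get_full_proj_transform(orthographic)

Note that orthographic defaults to True in the new render() signature, so existing callers switch to orthographic rendering unless they pass orthographic=False explicitly.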
diff --git a/scene/cameras.py b/scene/cameras.py
index 63161b90c1..4ef5f5cbaa 100644
--- a/scene/cameras.py
+++ b/scene/cameras.py
@@ -87,6 +87,14 @@ def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, image,
         self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         self.camera_center = self.world_view_transform.inverse()[3, :3]
+
+    def get_full_proj_transform(self, orthographic=False):
+        if not orthographic:
+            return self.full_proj_transform
+        else:
+            tanfovx, tanfovy, projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, orthographic=True)
+            full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(projection_matrix.transpose(0,1).cuda().unsqueeze(0))).squeeze(0)
+            return tanfovx, tanfovy, full_proj_transform
 
 class MiniCam:
     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
@@ -101,3 +109,11 @@ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform,
         view_inv = torch.inverse(self.world_view_transform)
         self.camera_center = view_inv[3][:3]
 
+    def get_full_proj_transform(self, orthographic=False):
+        if not orthographic:
+            return self.full_proj_transform
+        else:
+            tanfovx, tanfovy, projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy, orthographic=True)
+            full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(projection_matrix.transpose(0,1).cuda().unsqueeze(0))).squeeze(0)
+            return tanfovx, tanfovy, full_proj_transform
+
diff --git a/submodules/diff-gaussian-rasterization b/submodules/diff-gaussian-rasterization
index 9c5c2028f6..99bd069645 160000
--- a/submodules/diff-gaussian-rasterization
+++ b/submodules/diff-gaussian-rasterization
@@ -1 +1 @@
-Subproject commit 9c5c2028f6fbee2be239bc4c9421ff894fe4fbe0
+Subproject commit 99bd069645e993621331df61c917e2174ee1e933
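Both copies of get_full_proj_transform compose the full transform the same way the constructor already does: the codebase multiplies row vectors from the left, so the world-to-view matrix comes first, and the getProjectionMatrix output is transposed inline before the product (the stored view matrix is already transposed at construction). A condensed sketch of just that composition, with identity placeholders instead of real camera tensors and the patch's .cuda() call dropped so it runs on CPU:

import torch

world_view_transform = torch.eye(4)  # placeholder world-to-camera matrix (row-major)
projection_matrix = torch.eye(4)     # placeholder getProjectionMatrix output

# Row-vector convention: p_clip = p_world @ world_view @ projection^T,
# so the combined matrix is world_view @ projection^T.
full_proj_transform = (
    world_view_transform.unsqueeze(0)
    .bmm(projection_matrix.transpose(0, 1).unsqueeze(0))
    .squeeze(0)
)
print(full_proj_transform.shape)  # torch.Size([4, 4])

The unsqueeze/bmm/squeeze pattern mirrors the existing constructor code; for a single pair of matrices a plain @ product would be equivalent. The submodule bump to commit 99bd069645 presumably adds the orthographic field consumed by GaussianRasterizationSettings above; the renderer changes do not run against the old rasterizer.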
diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py
index b4627d837c..53000d4f60 100644
--- a/utils/graphics_utils.py
+++ b/utils/graphics_utils.py
@@ -48,27 +48,47 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
     Rt = np.linalg.inv(C2W)
     return np.float32(Rt)
 
-def getProjectionMatrix(znear, zfar, fovX, fovY):
-    tanHalfFovY = math.tan((fovY / 2))
-    tanHalfFovX = math.tan((fovX / 2))
-
-    top = tanHalfFovY * znear
-    bottom = -top
-    right = tanHalfFovX * znear
-    left = -right
-
-    P = torch.zeros(4, 4)
-
-    z_sign = 1.0
-
-    P[0, 0] = 2.0 * znear / (right - left)
-    P[1, 1] = 2.0 * znear / (top - bottom)
-    P[0, 2] = (right + left) / (right - left)
-    P[1, 2] = (top + bottom) / (top - bottom)
-    P[3, 2] = z_sign
-    P[2, 2] = z_sign * zfar / (zfar - znear)
-    P[2, 3] = -(zfar * znear) / (zfar - znear)
-    return P
+def getProjectionMatrix(znear, zfar, fovX, fovY, orthographic=False):
+    if not orthographic:
+        tanHalfFovY = math.tan((fovY / 2))
+        tanHalfFovX = math.tan((fovX / 2))
+
+        top = tanHalfFovY * znear
+        bottom = -top
+        right = tanHalfFovX * znear
+        left = -right
+
+        P = torch.zeros(4, 4)
+
+        z_sign = 1.0
+
+        P[0, 0] = 2.0 * znear / (right - left)
+        P[1, 1] = 2.0 * znear / (top - bottom)
+        P[0, 2] = (right + left) / (right - left)
+        P[1, 2] = (top + bottom) / (top - bottom)
+        P[3, 2] = z_sign
+        P[2, 2] = z_sign * zfar / (zfar - znear)
+        P[2, 3] = -(zfar * znear) / (zfar - znear)
+
+        return P
+
+    else:
+        left, right = -fovX, fovX
+        bottom, top = -fovY, fovY
+
+        P = torch.zeros(4, 4)
+
+        z_sign = 1.0
+        P[0, 0] = 2.0 / (right - left)
+        P[0, 3] = - (right + left) / (right - left)
+        P[1, 1] = 2.0 / (top - bottom)
+        P[1, 3] = - (top + bottom) / (top - bottom)
+        P[2, 2] = -2.0 / (zfar - znear)
+        P[2, 3] = - (zfar + znear) / (zfar - znear)
+        P[3, 3] = z_sign
+
+        # tanfovx, tanfovy, P
+        return (right - left) / 2, (top - bottom) / 2, P
 
 def fov2focal(fov, pixels):
     return pixels / (2 * math.tan(fov / 2))
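The orthographic branch is the standard OpenGL orthographic matrix with fovX and fovY reinterpreted as the half-width and half-height of the view volume; the returned (right - left) / 2 and (top - bottom) / 2 simply hand those half-extents back in place of the FoV tangents. Unlike the perspective branch, which maps depth to [0, 1] along +z, this form maps z in [-zfar, -znear] to [-1, 1]. A quick standalone sanity check (a sketch, not part of the patch): corners of the view box should land on +/-1 in NDC.

import torch

def ortho_matrix(znear, zfar, w, h):
    # Mirrors getProjectionMatrix(..., orthographic=True), with fovX/fovY
    # reused as the half-width w and half-height h of the view volume.
    left, right, bottom, top = -w, w, -h, h
    P = torch.zeros(4, 4)
    P[0, 0] = 2.0 / (right - left)
    P[0, 3] = -(right + left) / (right - left)
    P[1, 1] = 2.0 / (top - bottom)
    P[1, 3] = -(top + bottom) / (top - bottom)
    P[2, 2] = -2.0 / (zfar - znear)
    P[2, 3] = -(zfar + znear) / (zfar - znear)
    P[3, 3] = 1.0
    return P

P = ortho_matrix(znear=0.5, zfar=100.5, w=2.0, h=1.5)
corner = torch.tensor([2.0, 1.5, -100.5, 1.0])  # (right, top) corner at the far plane
print(P @ corner)  # tensor([1., 1., 1., 1.])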