From 632e9edcd513e6398e1cfb5f10e9159667830e63 Mon Sep 17 00:00:00 2001 From: Brian Park Date: Thu, 7 Dec 2023 15:24:49 -0500 Subject: [PATCH 1/2] chore: fixed keypoints inputs not updating to most recent logits --- .../layers/numerical_calibrator.py | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index 6a15a6b..ab0d60d 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -151,16 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ if self.input_keypoints_type == InputKeypointsType.LEARNED: - softmaxed_logits = torch.nn.functional.softmax( - self._interpolation_logits, dim=-1 - ) - self._lengths = softmaxed_logits * self._keypoint_range - interior_keypoints = ( - torch.cumsum(self._lengths, dim=-1) + self._keypoint_min - ) - self._interpolation_keypoints = torch.cat( - [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]] - ) + self._update_from_logits() interpolation_weights = (x - self._interpolation_keypoints) / self._lengths interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0)) @@ -302,6 +293,9 @@ def assert_constraints(self, eps: float = 1e-6) -> list[str]: @torch.no_grad() def keypoints_inputs(self) -> torch.Tensor: """Returns tensor of keypoint inputs.""" + if self.input_keypoints_type == InputKeypointsType.LEARNED: + self._update_from_logits() + return torch.cat( ( self._interpolation_keypoints, @@ -428,3 +422,21 @@ def _squeeze_by_scaling( if decreasing: bias, heights = -bias, -heights return bias, heights + + def _update_from_logits(self) -> None: + """Makes necessary updates according to most recent self._interpolation_logits. + + If running the layer with `InputKeyPointType.LEARNED.`, this method will ensure + `self._interpolation_keypoints` and `self._lengths` are correctly updated with + regards to the most recent iteration of `self._interpolation_logits`. + """ + softmaxed_logits = torch.nn.functional.softmax( + self._interpolation_logits, dim=-1 + ) + self._lengths = softmaxed_logits * self._keypoint_range + interior_keypoints = (torch.cumsum(self._lengths, dim=-1) + self._keypoint_min)[ + :-1 + ] + self._interpolation_keypoints = torch.cat( + [torch.tensor([self._keypoint_min]), interior_keypoints] + ) From aa1d10026da1668aff0f201d02ff806173719185 Mon Sep 17 00:00:00 2001 From: Brian Park Date: Thu, 7 Dec 2023 16:37:59 -0500 Subject: [PATCH 2/2] chore:fixed function name from update_from_logits to calculate_lengths_and_interpolation_keypoints --- pytorch_lattice/layers/numerical_calibrator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lattice/layers/numerical_calibrator.py b/pytorch_lattice/layers/numerical_calibrator.py index ab0d60d..3283619 100644 --- a/pytorch_lattice/layers/numerical_calibrator.py +++ b/pytorch_lattice/layers/numerical_calibrator.py @@ -151,7 +151,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor of shape `(batch_size, 1)` containing calibrated input values. """ if self.input_keypoints_type == InputKeypointsType.LEARNED: - self._update_from_logits() + self._calculate_lengths_and_interpolation_keypoints() interpolation_weights = (x - self._interpolation_keypoints) / self._lengths interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0)) @@ -294,7 +294,7 @@ def assert_constraints(self, eps: float = 1e-6) -> list[str]: def keypoints_inputs(self) -> torch.Tensor: """Returns tensor of keypoint inputs.""" if self.input_keypoints_type == InputKeypointsType.LEARNED: - self._update_from_logits() + self._calculate_lengths_and_interpolation_keypoints() return torch.cat( ( @@ -423,7 +423,7 @@ def _squeeze_by_scaling( bias, heights = -bias, -heights return bias, heights - def _update_from_logits(self) -> None: + def _calculate_lengths_and_interpolation_keypoints(self) -> None: """Makes necessary updates according to most recent self._interpolation_logits. If running the layer with `InputKeyPointType.LEARNED.`, this method will ensure