From 61f35b6613b1b0058cab81adb8b1c021ebc6dd19 Mon Sep 17 00:00:00 2001 From: eli Date: Tue, 27 Aug 2024 15:40:31 +0300 Subject: [PATCH] should use log_probs in affinity loss --- affinity_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/affinity_loss.py b/affinity_loss.py index 180f74e..b08b23b 100644 --- a/affinity_loss.py +++ b/affinity_loss.py @@ -78,8 +78,8 @@ def forward(self, logits, labels): lbedge = labels[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach() igncenter = ignore_mask[:, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]].detach() ignedge = ignore_mask[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach() - lgp_center = probs[:, :, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]] - lgp_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] + lgp_center = log_probs[:, :, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]] + lgp_edge = log_probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] prob_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]] kldiv = (prob_edge * (lgp_edge - lgp_center)).sum(dim=1)