From 587df63e3c4950721ca10ea60f1e1be3b222f2a0 Mon Sep 17 00:00:00 2001 From: Jack Scantlebury <39645092+jscant@users.noreply.github.com> Date: Tue, 19 Oct 2021 08:43:13 +0100 Subject: [PATCH] Added optional tanh to coors_mlp This removes the NaN bug completely (must also use norm_coors otherwise performance dies) --- egnn_pytorch/egnn_pytorch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/egnn_pytorch/egnn_pytorch.py b/egnn_pytorch/egnn_pytorch.py index 5376a12..f0730ab 100644 --- a/egnn_pytorch/egnn_pytorch.py +++ b/egnn_pytorch/egnn_pytorch.py @@ -164,11 +164,13 @@ def __init__( valid_radius = float('inf'), m_pool_method = 'sum', soft_edges = False, - coor_weights_clamp_value = None + coor_weights_clamp_value = None, + coors_tanh=False # Only to be used alongside norm_coors - highly recommended for stability ): super().__init__() assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean' assert update_feats or update_coors, 'you must update either features, coordinates, or both' + assert not (coors_tanh and not norm_coors), 'coors_tanh must be used with norm_coors' self.fourier_features = fourier_features @@ -200,11 +202,14 @@ def __init__( nn.Linear(dim * 2, dim), ) if update_feats else None + # Tanh layer helps with stability but should only be used in conjuction with + # norm_coors self.coors_mlp = nn.Sequential( nn.Linear(m_dim, m_dim * 4), dropout, SiLU(), - nn.Linear(m_dim * 4, 1) + nn.Linear(m_dim * 4, 1), + nn.Tanh() if coors_tanh else nn.Identity() ) if update_coors else None self.num_nearest_neighbors = num_nearest_neighbors