diff --git a/examples/howtos/howto_ensemble_calculation.ipynb b/examples/howtos/howto_ensemble_calculation.ipynb index 6496c716f..063bd4822 100644 --- a/examples/howtos/howto_ensemble_calculation.ipynb +++ b/examples/howtos/howto_ensemble_calculation.ipynb @@ -42,7 +42,11 @@ "from ase.md.velocitydistribution import MaxwellBoltzmannDistribution\n", "from ase.md.langevin import Langevin\n", "\n", - "from schnetpack.interfaces.ase_interface import SpkEnsembleCalculator, AbsoluteUncertainty, RelativeUncertainty\n", + "from schnetpack.interfaces.ase_interface import (\n", + " SpkEnsembleCalculator,\n", + " AbsoluteUncertainty,\n", + " RelativeUncertainty,\n", + ")\n", "import schnetpack.transform as trn\n", "from schnetpack.datasets import MD17\n", "import torch\n", @@ -78,11 +82,13 @@ }, "outputs": [], "source": [ - "model_path_list = ['../trained_models/rmd17_ethanol/painn_1/best_model',\n", - " '../trained_models/rmd17_ethanol/painn_2/best_model',\n", - " '../trained_models/rmd17_ethanol/painn_3/best_model',\n", - " '../trained_models/rmd17_ethanol/painn_4/best_model',\n", - " '../trained_models/rmd17_ethanol/painn_5/best_model']" + "model_path_list = [\n", + " \"../trained_models/rmd17_ethanol/painn_1/best_model\",\n", + " \"../trained_models/rmd17_ethanol/painn_2/best_model\",\n", + " \"../trained_models/rmd17_ethanol/painn_3/best_model\",\n", + " \"../trained_models/rmd17_ethanol/painn_4/best_model\",\n", + " \"../trained_models/rmd17_ethanol/painn_5/best_model\",\n", + "]" ] }, { @@ -113,7 +119,7 @@ }, "outputs": [], "source": [ - "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5,force_weight=1.0)\n", + "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5, force_weight=1.0)\n", "uncertainty_rel = RelativeUncertainty(energy_weight=1.0, force_weight=2.0)\n", "\n", "uncertainty = [uncertainty_abs, uncertainty_rel]\n", @@ -125,7 +131,8 @@ " force_key=MD17.forces,\n", " energy_unit=\"kcal/mol\",\n", " position_unit=\"Ang\",\n", - " uncertainty_fn=uncertainty)" + " uncertainty_fn=uncertainty,\n", + ")" ] }, { @@ -146,8 +153,8 @@ }, "outputs": [], "source": [ - "#load data into atoms object\n", - "atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)\n", + "# load data into atoms object\n", + "atoms = read(\"../../tests/testdata/md_ethanol.xyz\", index=0)\n", "# specify atoms calculator\n", "atoms.calc = ensemble_calculator" ] @@ -290,22 +297,22 @@ "fig, ax1 = plt.subplots(figsize=(8, 6))\n", "\n", "# Plot absolute uncertainty on left y-axis\n", - "ax1.plot(steps, abs_vals, label=\"Absolute Uncertainty\", marker='o', color='tab:blue')\n", + "ax1.plot(steps, abs_vals, label=\"Absolute Uncertainty\", marker=\"o\", color=\"tab:blue\")\n", "ax1.set_xlabel(\"Optimization Step\")\n", - "ax1.set_ylabel(\"Absolute Uncertainty\", color='tab:blue')\n", - "ax1.tick_params(axis='y', labelcolor='tab:blue')\n", + "ax1.set_ylabel(\"Absolute Uncertainty\", color=\"tab:blue\")\n", + "ax1.tick_params(axis=\"y\", labelcolor=\"tab:blue\")\n", "ax1.grid(True)\n", "\n", "# Create second y-axis for relative uncertainty\n", "ax2 = ax1.twinx()\n", - "ax2.plot(steps, rel_vals, label=\"Relative Uncertainty\", marker='x', color='tab:red')\n", - "ax2.set_ylabel(\"Relative Uncertainty\", color='tab:red')\n", - "ax2.tick_params(axis='y', labelcolor='tab:red')\n", + "ax2.plot(steps, rel_vals, label=\"Relative Uncertainty\", marker=\"x\", color=\"tab:red\")\n", + "ax2.set_ylabel(\"Relative Uncertainty\", color=\"tab:red\")\n", + "ax2.tick_params(axis=\"y\", labelcolor=\"tab:red\")\n", "\n", "# Title and layout\n", "plt.title(\"Uncertainty during Optimization\")\n", "fig.tight_layout()\n", - "plt.show()\n" + "plt.show()" ] }, { @@ -344,7 +351,7 @@ }, "outputs": [], "source": [ - "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5,force_weight=1.0)\n", + "uncertainty_abs = AbsoluteUncertainty(energy_weight=0.5, force_weight=1.0)\n", "\n", "abs_ensemble_calculator = SpkEnsembleCalculator(\n", " models=model_path_list,\n", @@ -353,7 +360,8 @@ " force_key=MD17.forces,\n", " energy_unit=\"kcal/mol\",\n", " position_unit=\"Ang\",\n", - " uncertainty_fn=uncertainty_abs)" + " uncertainty_fn=uncertainty_abs,\n", + ")" ] }, { @@ -367,13 +375,13 @@ }, "outputs": [], "source": [ - "target_temperatures = [_ for _ in range(50, 800, 100)] \n", - "n_steps = 1000 \n", - "sampling_interval = 10 \n", - "step_size = 0.5 \n", + "target_temperatures = [_ for _ in range(50, 800, 100)]\n", + "n_steps = 1000\n", + "sampling_interval = 10\n", + "step_size = 0.5\n", "\n", "# setting up initial atoms\n", - "atoms = read('../../tests/testdata/md_ethanol.xyz', index=0)\n", + "atoms = read(\"../../tests/testdata/md_ethanol.xyz\", index=0)\n", "atoms.calc = abs_ensemble_calculator\n", "\n", "MaxwellBoltzmannDistribution(atoms, temperature_K=target_temperatures[0])\n", @@ -385,19 +393,19 @@ "for target_temperature in target_temperatures:\n", " print(f\"Temp: {target_temperature:.2f} K\")\n", " for step in tqdm(range(n_steps // sampling_interval)):\n", - " \n", + "\n", " dyn = Langevin(\n", - " atoms, \n", - " timestep=step_size * units.fs, \n", + " atoms,\n", + " timestep=step_size * units.fs,\n", " temperature_K=target_temperature,\n", - " friction=0.01 / units.fs\n", + " friction=0.01 / units.fs,\n", " )\n", - " \n", + "\n", " dyn.run(sampling_interval)\n", - " \n", + "\n", " temp.append(atoms.get_temperature())\n", " uncertainties.append(abs_ensemble_calculator.get_uncertainty(atoms))\n", - " \n", + "\n", " ats_traj.append(atoms.copy())" ] }, @@ -409,22 +417,22 @@ "source": [ "fig, ax1 = plt.subplots(figsize=(8, 6))\n", "\n", - "ax1.plot(uncertainties, marker='o', color='blue', label='Uncertainty')\n", + "ax1.plot(uncertainties, marker=\"o\", color=\"blue\", label=\"Uncertainty\")\n", "ax1.set_xlabel(\"MD Step\")\n", - "ax1.set_ylabel(\"Uncertainty\", color='blue')\n", - "ax1.tick_params(axis='y', labelcolor='blue')\n", + "ax1.set_ylabel(\"Uncertainty\", color=\"blue\")\n", + "ax1.tick_params(axis=\"y\", labelcolor=\"blue\")\n", "\n", "ax2 = ax1.twinx()\n", - "ax2.plot(temp, marker='x', color='red', label='Temperature')\n", - "ax2.set_ylabel(\"Temperature (K)\", color='red')\n", - "ax2.tick_params(axis='y', labelcolor='red')\n", + "ax2.plot(temp, marker=\"x\", color=\"red\", label=\"Temperature\")\n", + "ax2.set_ylabel(\"Temperature (K)\", color=\"red\")\n", + "ax2.tick_params(axis=\"y\", labelcolor=\"red\")\n", "\n", "plt.title(\"Molecular Dynamics: Uncertainty and Temperature Profile\")\n", "ax1.grid(True)\n", "\n", "lines_1, labels_1 = ax1.get_legend_handles_labels()\n", "lines_2, labels_2 = ax2.get_legend_handles_labels()\n", - "ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper right')\n", + "ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc=\"upper right\")\n", "\n", "plt.tight_layout()\n", "plt.show()" @@ -444,6 +452,7 @@ "outputs": [], "source": [ "from ase.visualize import view\n", + "\n", "view(ats_traj)" ] } diff --git a/examples/tutorials/tutorial_02_qm9.ipynb b/examples/tutorials/tutorial_02_qm9.ipynb index 7f800891e..bc6defa5b 100644 --- a/examples/tutorials/tutorial_02_qm9.ipynb +++ b/examples/tutorials/tutorial_02_qm9.ipynb @@ -377,9 +377,7 @@ "from ase import Atoms\n", "from schnetpack.utils.compatibility import load_model\n", "\n", - "best_model = load_model(\n", - " os.path.join(qm9tut, \"best_inference_model\"), device=\"cpu\"\n", - ")" + "best_model = load_model(os.path.join(qm9tut, \"best_inference_model\"), device=\"cpu\")" ] }, { diff --git a/src/schnetpack/representation/painn.py b/src/schnetpack/representation/painn.py index c57d413e6..cc2012920 100644 --- a/src/schnetpack/representation/painn.py +++ b/src/schnetpack/representation/painn.py @@ -136,6 +136,7 @@ def __init__( cutoff_fn: Optional[Callable] = None, activation: Optional[Callable] = F.silu, shared_interactions: bool = False, + return_vector_representation: bool = False, shared_filters: bool = False, epsilon: float = 1e-8, nuclear_embedding: Optional[nn.Module] = None, @@ -165,6 +166,8 @@ def __init__( self.cutoff_fn = cutoff_fn self.cutoff = cutoff_fn.cutoff self.radial_basis = radial_basis + self.return_vector_representation = return_vector_representation + self.epsilon = epsilon # initialize embeddings if nuclear_embedding is None: @@ -198,7 +201,7 @@ def __init__( ) self.mixing = snn.replicate_module( lambda: PaiNNMixing( - n_atom_basis=self.n_atom_basis, activation=activation, epsilon=epsilon + n_atom_basis=self.n_atom_basis, activation=activation, epsilon=self.epsilon ), self.n_interactions, shared_interactions, @@ -224,7 +227,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]): n_atoms = atomic_numbers.shape[0] # compute atom and pair features - d_ij = torch.norm(r_ij, dim=1, keepdim=True) + d_ij = torch.sqrt(torch.sum(r_ij**2, dim=-1, keepdim=True) + self.epsilon) dir_ij = r_ij / d_ij phi_ij = self.radial_basis(d_ij) fcut = self.cutoff_fn(d_ij) @@ -251,6 +254,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]): # collect results inputs["scalar_representation"] = q - inputs["vector_representation"] = mu + if self.return_vector_representation: + inputs["vector_representation"] = mu return inputs