From 9d7d68771c0261b89d92377d29dc0e0821a879d5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:12:19 +0000 Subject: [PATCH 1/4] Initial plan From 7f30c5d21a7d888717ad7d9f8522badaf01c209c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:29:49 +0000 Subject: [PATCH 2/4] feat(dpmodel): refactor compute_input_stats using common mixin Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/dpmodel/common.py | 78 +++++++++++++++++++++++++ deepmd/dpmodel/descriptor/repformers.py | 57 ++---------------- deepmd/dpmodel/descriptor/se_e2_a.py | 46 ++------------- 3 files changed, 88 insertions(+), 93 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 1f9d4817a2..be04c1a38e 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -10,6 +10,7 @@ Any, Callable, Optional, + Union, overload, ) @@ -220,11 +221,88 @@ def safe_cast_array( return input +class ComputeInputStatsMixin: + """Mixin class providing common compute_input_stats implementation. + + This mixin implements the shared logic for computing input statistics + across all descriptor backends, while allowing backend-specific tensor + assignment through abstract methods. + """ + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[Any] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: tensor + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + from deepmd.dpmodel.utils.env_mat_stat import ( + EnvMatStatSe, + ) + + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + + # Backend-specific tensor assignment + self._set_stat_mean_and_stddev(mean, stddev) + + @abstractmethod + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + + This method should be implemented by each backend to handle the specific + tensor assignment logic for that backend. + + Parameters + ---------- + mean : array-like + The computed mean values + stddev : array-like + The computed standard deviation values + """ + raise NotImplementedError + + def get_stats(self) -> dict[str, Any]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + __all__ = [ "DEFAULT_PRECISION", "GLOBAL_ENER_FLOAT_PRECISION", "GLOBAL_NP_FLOAT_PRECISION", "PRECISION_DICT", "RESERVED_PRECISION_DICT", + "ComputeInputStatsMixin", "NativeOP", ] diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 6ac9675d28..d060da562d 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, Optional, Union, ) @@ -16,15 +15,13 @@ xp_take_along_axis, ) from deepmd.dpmodel.common import ( + ComputeInputStatsMixin, to_numpy_array, ) from deepmd.dpmodel.utils import ( EnvMat, PairExcludeMask, ) -from deepmd.dpmodel.utils.env_mat_stat import ( - EnvMatStatSe, -) from deepmd.dpmodel.utils.network import ( LayerNorm, NativeLayer, @@ -36,12 +33,6 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) -from deepmd.utils.env_mat_stat import ( - StatItem, -) -from deepmd.utils.path import ( - DPPath, -) from deepmd.utils.version import ( check_version_compatibility, ) @@ -78,7 +69,7 @@ def xp_transpose_01342(x): @DescriptorBlock.register("se_repformer") @DescriptorBlock.register("se_uni") -class DescrptBlockRepformers(NativeOP, DescriptorBlock): +class DescrptBlockRepformers(NativeOP, DescriptorBlock, ComputeInputStatsMixin): r""" The repformer descriptor block. @@ -379,54 +370,16 @@ def dim_emb(self): """Returns the embedding dimension g2.""" return self.get_dim_emb() - def compute_input_stats( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - path: Optional[DPPath] = None, - ) -> None: - """ - Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - path : Optional[DPPath] - The path to the stat file. + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + This is the dpmodel backend-specific implementation using array_api_compat. """ - env_mat_stat = EnvMatStatSe(self) - if path is not None: - path = path / env_mat_stat.get_hash() - if path is None or not path.is_dir(): - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - else: - sampled = [] - env_mat_stat.load_or_compute_stats(sampled, path) - self.stats = env_mat_stat.stats - mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) if not self.set_davg_zero: self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) - def get_stats(self) -> dict[str, StatItem]: - """Get the statistics of the descriptor.""" - if self.stats is None: - raise RuntimeError( - "The statistics of the descriptor has not been computed." - ) - return self.stats - def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 5bcffc6c53..33926ae236 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -2,7 +2,6 @@ import itertools from typing import ( Any, - Callable, NoReturn, Optional, Union, @@ -17,6 +16,7 @@ NativeOP, ) from deepmd.dpmodel.common import ( + ComputeInputStatsMixin, cast_precision, to_numpy_array, ) @@ -26,9 +26,6 @@ NetworkCollection, PairExcludeMask, ) -from deepmd.dpmodel.utils.env_mat_stat import ( - EnvMatStatSe, -) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -38,9 +35,6 @@ from deepmd.utils.data_system import ( DeepmdDataSystem, ) -from deepmd.utils.path import ( - DPPath, -) from deepmd.utils.version import ( check_version_compatibility, ) @@ -52,7 +46,7 @@ @BaseDescriptor.register("se_e2_a") @BaseDescriptor.register("se_a") -class DescrptSeA(NativeOP, BaseDescriptor): +class DescrptSeA(NativeOP, BaseDescriptor, ComputeInputStatsMixin): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. @@ -309,41 +303,11 @@ def get_type_map(self) -> list[str]: """Get the name to each type of atoms.""" return self.type_map - def compute_input_stats( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - path: Optional[DPPath] = None, - ) -> None: - """ - Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - path : Optional[DPPath] - The path to the stat file. + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + This is the dpmodel backend-specific implementation using array_api_compat. """ - env_mat_stat = EnvMatStatSe(self) - if path is not None: - path = path / env_mat_stat.get_hash() - if path is None or not path.is_dir(): - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - else: - sampled = [] - env_mat_stat.load_or_compute_stats(sampled, path) - self.stats = env_mat_stat.stats - mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) if not self.set_davg_zero: self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) From 548141125997131b0b435b0bfec6c1f076394333 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:38:10 +0000 Subject: [PATCH 3/4] feat(pt): refactor compute_input_stats using common mixin Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pt/common.py | 87 ++++++++++++++++++++++++ deepmd/pt/model/descriptor/repformers.py | 59 ++-------------- deepmd/pt/model/descriptor/se_a.py | 55 ++------------- 3 files changed, 101 insertions(+), 100 deletions(-) create mode 100644 deepmd/pt/common.py diff --git a/deepmd/pt/common.py b/deepmd/pt/common.py new file mode 100644 index 0000000000..3502f90ff4 --- /dev/null +++ b/deepmd/pt/common.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Common functionality shared across PyTorch descriptor implementations.""" + +from abc import ( + abstractmethod, +) +from typing import ( + Any, + Callable, + Optional, + Union, +) + + +class ComputeInputStatsMixin: + """Mixin class providing common compute_input_stats implementation for PyTorch backend. + + This mixin implements the shared logic for computing input statistics + while allowing backend-specific tensor assignment through abstract methods. + """ + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[Any] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: torch.Tensor + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, + ) + + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + + # Backend-specific tensor assignment + self._set_stat_mean_and_stddev(mean, stddev) + + @abstractmethod + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + + This method should be implemented by each descriptor to handle the specific + tensor assignment logic for PyTorch backend. + + Parameters + ---------- + mean : array-like + The computed mean values + stddev : array-like + The computed standard deviation values + """ + raise NotImplementedError + + def get_stats(self) -> dict[str, Any]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 022c7510df..4782ccd6e0 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, Optional, Union, ) @@ -10,6 +9,9 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) +from deepmd.pt.common import ( + ComputeInputStatsMixin, +) from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, ) @@ -25,9 +27,6 @@ from deepmd.pt.utils.env import ( PRECISION_DICT, ) -from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSe, -) from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) @@ -37,12 +36,6 @@ from deepmd.pt.utils.utils import ( ActivationFn, ) -from deepmd.utils.env_mat_stat import ( - StatItem, -) -from deepmd.utils.path import ( - DPPath, -) from .repformer_layer import ( RepformerLayer, @@ -72,7 +65,7 @@ def border_op( @DescriptorBlock.register("se_repformer") @DescriptorBlock.register("se_uni") -class DescrptBlockRepformers(DescriptorBlock): +class DescrptBlockRepformers(DescriptorBlock, ComputeInputStatsMixin): def __init__( self, rcut, @@ -537,41 +530,11 @@ def forward( return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw - def compute_input_stats( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - path: Optional[DPPath] = None, - ) -> None: - """ - Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - path : Optional[DPPath] - The path to the stat file. + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + This is the PyTorch backend-specific implementation using torch.tensor. """ - env_mat_stat = EnvMatStatSe(self) - if path is not None: - path = path / env_mat_stat.get_hash() - if path is None or not path.is_dir(): - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - else: - sampled = [] - env_mat_stat.load_or_compute_stats(sampled, path) - self.stats = env_mat_stat.stats - mean, stddev = env_mat_stat() if not self.set_davg_zero: self.mean.copy_( torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype) @@ -580,14 +543,6 @@ def compute_input_stats( torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype) ) - def get_stats(self) -> dict[str, StatItem]: - """Get the statistics of the descriptor.""" - if self.stats is None: - raise RuntimeError( - "The statistics of the descriptor has not been computed." - ) - return self.stats - def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index f49b5a1276..1766a92d15 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -14,6 +14,9 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) +from deepmd.pt.common import ( + ComputeInputStatsMixin, +) from deepmd.pt.model.descriptor import ( DescriptorBlock, prod_env_mat, @@ -25,18 +28,12 @@ PRECISION_DICT, RESERVED_PRECISION_DICT, ) -from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSe, -) from deepmd.pt.utils.update_sel import ( UpdateSel, ) from deepmd.utils.data_system import ( DeepmdDataSystem, ) -from deepmd.utils.env_mat_stat import ( - StatItem, -) from deepmd.utils.path import ( DPPath, ) @@ -449,7 +446,7 @@ def update_sel( @DescriptorBlock.register("se_e2_a") -class DescrptBlockSeA(DescriptorBlock): +class DescrptBlockSeA(DescriptorBlock, ComputeInputStatsMixin): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] @@ -627,41 +624,11 @@ def __getitem__(self, key): else: raise KeyError(key) - def compute_input_stats( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - path: Optional[DPPath] = None, - ) -> None: - """ - Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - path : Optional[DPPath] - The path to the stat file. + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + This is the PyTorch backend-specific implementation using torch.tensor. """ - env_mat_stat = EnvMatStatSe(self) - if path is not None: - path = path / env_mat_stat.get_hash() - if path is None or not path.is_dir(): - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - else: - sampled = [] - env_mat_stat.load_or_compute_stats(sampled, path) - self.stats = env_mat_stat.stats - mean, stddev = env_mat_stat() if not self.set_davg_zero: self.mean.copy_( torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype) @@ -670,14 +637,6 @@ def compute_input_stats( torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype) ) - def get_stats(self) -> dict[str, StatItem]: - """Get the statistics of the descriptor.""" - if self.stats is None: - raise RuntimeError( - "The statistics of the descriptor has not been computed." - ) - return self.stats - def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], From ed57aeb3603ea06904f547e5af3828ff7b103009 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:42:19 +0000 Subject: [PATCH 4/4] feat(pd): refactor compute_input_stats using common mixin - complete refactorization Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pd/common.py | 87 ++++++++++++++++++++++++ deepmd/pd/model/descriptor/repformers.py | 59 ++-------------- 2 files changed, 94 insertions(+), 52 deletions(-) create mode 100644 deepmd/pd/common.py diff --git a/deepmd/pd/common.py b/deepmd/pd/common.py new file mode 100644 index 0000000000..a254cf2322 --- /dev/null +++ b/deepmd/pd/common.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Common functionality shared across Paddle descriptor implementations.""" + +from abc import ( + abstractmethod, +) +from typing import ( + Any, + Callable, + Optional, + Union, +) + + +class ComputeInputStatsMixin: + """Mixin class providing common compute_input_stats implementation for Paddle backend. + + This mixin implements the shared logic for computing input statistics + while allowing backend-specific tensor assignment through abstract methods. + """ + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[Any] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: paddle.Tensor + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, + ) + + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + + # Backend-specific tensor assignment + self._set_stat_mean_and_stddev(mean, stddev) + + @abstractmethod + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + + This method should be implemented by each descriptor to handle the specific + tensor assignment logic for Paddle backend. + + Parameters + ---------- + mean : array-like + The computed mean values + stddev : array-like + The computed standard deviation values + """ + raise NotImplementedError + + def get_stats(self) -> dict[str, Any]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py index 4151833f35..c628aae2ea 100644 --- a/deepmd/pd/model/descriptor/repformers.py +++ b/deepmd/pd/model/descriptor/repformers.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, Optional, Union, ) @@ -10,6 +9,9 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) +from deepmd.pd.common import ( + ComputeInputStatsMixin, +) from deepmd.pd.cxx_op import ( ENABLE_CUSTOMIZED_OP, paddle_ops_deepmd, @@ -29,9 +31,6 @@ from deepmd.pd.utils.env import ( PRECISION_DICT, ) -from deepmd.pd.utils.env_mat_stat import ( - EnvMatStatSe, -) from deepmd.pd.utils.exclude_mask import ( PairExcludeMask, ) @@ -41,12 +40,6 @@ from deepmd.pd.utils.utils import ( ActivationFn, ) -from deepmd.utils.env_mat_stat import ( - StatItem, -) -from deepmd.utils.path import ( - DPPath, -) from .repformer_layer import ( RepformerLayer, @@ -79,7 +72,7 @@ def border_op( @DescriptorBlock.register("se_repformer") @DescriptorBlock.register("se_uni") -class DescrptBlockRepformers(DescriptorBlock): +class DescrptBlockRepformers(DescriptorBlock, ComputeInputStatsMixin): def __init__( self, rcut, @@ -583,41 +576,11 @@ def forward( return g1, g2, h2, rot_mat.reshape([nframes, nloc, self.dim_emb, 3]), sw - def compute_input_stats( - self, - merged: Union[Callable[[], list[dict]], list[dict]], - path: Optional[DPPath] = None, - ) -> None: - """ - Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. - - Parameters - ---------- - merged : Union[Callable[[], list[dict]], list[dict]] - - list[dict]: A list of data samples from various data systems. - Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` - originating from the `i`-th data system. - - Callable[[], list[dict]]: A lazy function that returns data samples in the above format - only when needed. Since the sampling process can be slow and memory-intensive, - the lazy function helps by only sampling once. - path : Optional[DPPath] - The path to the stat file. + def _set_stat_mean_and_stddev(self, mean, stddev) -> None: + """Set the computed statistics to the descriptor's mean and stddev attributes. + This is the Paddle backend-specific implementation using paddle.assign. """ - env_mat_stat = EnvMatStatSe(self) - if path is not None: - path = path / env_mat_stat.get_hash() - if path is None or not path.is_dir(): - if callable(merged): - # only get data for once - sampled = merged() - else: - sampled = merged - else: - sampled = [] - env_mat_stat.load_or_compute_stats(sampled, path) - self.stats = env_mat_stat.stats - mean, stddev = env_mat_stat() if not self.set_davg_zero: paddle.assign( paddle.to_tensor(mean, dtype=self.mean.dtype).to(env.DEVICE), @@ -628,14 +591,6 @@ def compute_input_stats( self.stddev, ) # pylint: disable=no-explicit-dtype - def get_stats(self) -> dict[str, StatItem]: - """Get the statistics of the descriptor.""" - if self.stats is None: - raise RuntimeError( - "The statistics of the descriptor has not been computed." - ) - return self.stats - def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True