Skip to content

Speed Improvements in Neighbor List Creation #756

@Lingyu-Kong

Description

@Lingyu-Kong

Thank you for putting together such a useful package. I noticed that in your current implementation the neighbor list is created using either the ase or the matscipy package. However, these two are not the most efficient choices; the ase neighbor list in particular can be very slow once the structure size scales to more than a few thousand. The pymatgen and vesin packages provide more efficient implementations for creating the neighbor list. I have tested all four methods on H2O structures with sizes ranging from 192 to 98304, and the results show that pymatgen and vesin are several times faster than ase and also faster than matscipy. Example code that uses pymatgen and vesin to create the neighbor list is attached below; I hope it helps:

def _build_neighbor_list(
        self,
        Z: torch.Tensor,
        positions: torch.Tensor,   # [N,3]
        cell: torch.Tensor,        # [3,3]
        pbc: torch.Tensor,         # [3]
        cutoff: float,             # scalar
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        
        pos_np = positions.detach().cpu().numpy().reshape(-1, 3)
        cell_np = cell.detach().cpu().numpy().reshape(3, 3)
        pbc_np_bool = pbc.detach().cpu().numpy().astype(bool).reshape(-1)
        pbc_np_int = pbc_np_bool.astype(int).reshape(-1)

        device = positions.device
        dtype = positions.dtype

        match self.fn_type:
            case "pymatgen":
                with optional_import_error_message("pymatgen"):
                    from pymatgen.optimization.neighbors import find_points_in_spheres

                idx_i, idx_j, offsets, distances = find_points_in_spheres(
                    pos_np,
                    pos_np,
                    r=float(cutoff),
                    pbc=pbc_np_int,
                    lattice=cell_np,
                    tol=1e-8,
                )
                # remove self-interactions
                mask = idx_i != idx_j
                idx_i = torch.from_numpy(idx_i[mask]).to(device)
                idx_j = torch.from_numpy(idx_j[mask]).to(device)
                offsets_frac = torch.from_numpy(offsets[mask]).to(dtype=dtype, device=device)
                offsets_cart = offsets_frac @ cell.to(device)   # [E,3]
                return idx_i, idx_j, offsets_cart

            case "ase":
                with optional_import_error_message("ase"):
                    import ase.neighborlist as nl
                idx_i, idx_j, S = nl.primitive_neighbor_list(
                    "ijS",
                    pbc=pbc_np_bool,
                    cell=cell_np,
                    positions=pos_np,
                    cutoff=float(cutoff),
                    self_interaction=False,
                )
                idx_i = torch.from_numpy(idx_i).to(device)
                idx_j = torch.from_numpy(idx_j).to(device)
                S = torch.from_numpy(S).to(dtype=dtype, device=device)
                offsets_cart = S @ cell.to(device)
                return idx_i, idx_j, offsets_cart

            case "vesin":
                with optional_import_error_message("vesin"):
                    from vesin import NeighborList as vesin_nl
                if pbc_np_bool.all():
                    periodic = True
                elif (~pbc_np_bool).all():
                    periodic = False
                else:
                    raise ValueError("vesin neighbor list does not support mixed PBC settings.")
                results = vesin_nl(
                    cutoff=float(cutoff), full_list=True
                ).compute(
                    points=pos_np, box=cell_np, periodic=periodic, quantities="ijS"
                )
                idx_i = torch.from_numpy(results[0]).to(device).to(torch.long)
                idx_j = torch.from_numpy(results[1]).to(device).to(torch.long)
                S = torch.from_numpy(results[2]).to(dtype=dtype, device=device)
                offsets_cart = S @ cell.to(device)
                return idx_i, idx_j, offsets_cart

            case "matscipy":
                with optional_import_error_message("matscipy"):
                    import matscipy.neighbours as mat_nl
                ijS = mat_nl.neighbour_list(
                    "ijS",
                    pbc=pbc_np_bool,
                    cell=cell_np,
                    positions=pos_np,
                    cutoff=float(cutoff),
                )
                idx_i = torch.from_numpy(ijS[0]).to(device)
                idx_j = torch.from_numpy(ijS[1]).to(device)
                S = torch.from_numpy(ijS[2]).to(dtype=dtype, device=device)
                offsets_cart = S @ cell.to(device)
                return idx_i, idx_j, offsets_cart

            case _:
                raise ValueError(f"Invalid fn_type specified, got {self.fn_type}")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions