-
Notifications
You must be signed in to change notification settings - Fork 247
Description
Thank you for putting together such a useful package. I noticed that the current implementation builds the neighbor list with either the ase or the matscipy package. Neither of these is the most efficient choice; the ase neighbor list in particular can become very slow once the structure size grows beyond a few thousand. The pymatgen and vesin packages provide more efficient neighbor-list implementations. I tested all four methods on H2O structures with sizes ranging from 192 to 98304, and the results show that pymatgen and vesin are several times faster than ase and also faster than matscipy. Example code that builds the neighbor list with pymatgen and vesin is attached below; I hope it helps:
def _build_neighbor_list(
self,
Z: torch.Tensor,
positions: torch.Tensor, # [N,3]
cell: torch.Tensor, # [3,3]
pbc: torch.Tensor, # [3]
cutoff: float, # scalar
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pos_np = positions.detach().cpu().numpy().reshape(-1, 3)
cell_np = cell.detach().cpu().numpy().reshape(3, 3)
pbc_np_bool = pbc.detach().cpu().numpy().astype(bool).reshape(-1)
pbc_np_int = pbc_np_bool.astype(int).reshape(-1)
device = positions.device
dtype = positions.dtype
match self.fn_type:
case "pymatgen":
with optional_import_error_message("pymatgen"):
from pymatgen.optimization.neighbors import find_points_in_spheres
idx_i, idx_j, offsets, distances = find_points_in_spheres(
pos_np,
pos_np,
r=float(cutoff),
pbc=pbc_np_int,
lattice=cell_np,
tol=1e-8,
)
# remove self-interactions
mask = idx_i != idx_j
idx_i = torch.from_numpy(idx_i[mask]).to(device)
idx_j = torch.from_numpy(idx_j[mask]).to(device)
offsets_frac = torch.from_numpy(offsets[mask]).to(dtype=dtype, device=device)
offsets_cart = offsets_frac @ cell.to(device) # [E,3]
return idx_i, idx_j, offsets_cart
case "ase":
with optional_import_error_message("ase"):
import ase.neighborlist as nl
idx_i, idx_j, S = nl.primitive_neighbor_list(
"ijS",
pbc=pbc_np_bool,
cell=cell_np,
positions=pos_np,
cutoff=float(cutoff),
self_interaction=False,
)
idx_i = torch.from_numpy(idx_i).to(device)
idx_j = torch.from_numpy(idx_j).to(device)
S = torch.from_numpy(S).to(dtype=dtype, device=device)
offsets_cart = S @ cell.to(device)
return idx_i, idx_j, offsets_cart
case "vesin":
with optional_import_error_message("vesin"):
from vesin import NeighborList as vesin_nl
if pbc_np_bool.all():
periodic = True
elif (~pbc_np_bool).all():
periodic = False
else:
raise ValueError("vesin neighbor list does not support mixed PBC settings.")
results = vesin_nl(
cutoff=float(cutoff), full_list=True
).compute(
points=pos_np, box=cell_np, periodic=periodic, quantities="ijS"
)
idx_i = torch.from_numpy(results[0]).to(device).to(torch.long)
idx_j = torch.from_numpy(results[1]).to(device).to(torch.long)
S = torch.from_numpy(results[2]).to(dtype=dtype, device=device)
offsets_cart = S @ cell.to(device)
return idx_i, idx_j, offsets_cart
case "matscipy":
with optional_import_error_message("matscipy"):
import matscipy.neighbours as mat_nl
ijS = mat_nl.neighbour_list(
"ijS",
pbc=pbc_np_bool,
cell=cell_np,
positions=pos_np,
cutoff=float(cutoff),
)
idx_i = torch.from_numpy(ijS[0]).to(device)
idx_j = torch.from_numpy(ijS[1]).to(device)
S = torch.from_numpy(ijS[2]).to(dtype=dtype, device=device)
offsets_cart = S @ cell.to(device)
return idx_i, idx_j, offsets_cart
case _:
raise ValueError(f"Invalid fn_type specified, got {self.fn_type}")