Skip to content

Commit a93356e

Browse files
committed
extend SynthonCompletion for arbitrary node features
1 parent 0474698 commit a93356e

File tree

3 files changed

+35
-37
lines changed

3 files changed

+35
-37
lines changed

torchdrug/data/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def connected_components(self):
240240
last = min_neighbor
241241
min_neighbor = scatter_min(min_neighbor[node_out], node_in, dim_size=self.num_node)[0]
242242
anchor = torch.unique(min_neighbor)
243-
num_cc = scatter_add(torch.ones_like(anchor), self.node2graph[anchor])
243+
num_cc = scatter_add(torch.ones_like(anchor), self.node2graph[anchor], dim_size=self.batch_size)
244244
return self.split(min_neighbor), num_cc
245245

246246
def split(self, node2graph):

torchdrug/tasks/generation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ def preprocess(self, train_set, valid_set, test_set):
9999
self.max_node = 0
100100

101101
train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling")
102-
for data in train_set:
103-
graph = data["graph"]
102+
for sample in train_set:
103+
graph = sample["graph"]
104104
if graph.edge_list.numel():
105105
edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item()
106106
self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll)
@@ -677,8 +677,8 @@ def preprocess(self, train_set, valid_set, test_set):
677677
self.max_node = 0
678678

679679
train_set = tqdm(train_set, "Computing max number of nodes and edge unrolling")
680-
for data in train_set:
681-
graph = data["graph"]
680+
for sample in train_set:
681+
graph = sample["graph"]
682682
if graph.edge_list.numel():
683683
edge_unroll = (graph.edge_list[:, 0] - graph.edge_list[:, 1]).abs().max().item()
684684
self.max_edge_unroll = max(self.max_edge_unroll, edge_unroll)

torchdrug/tasks/retrosynthesis.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
from collections import deque
23

34
import torch
@@ -43,9 +44,9 @@ def __init__(self, model, feature=("reaction", "graph", "atom", "bond"), num_mlp
4344
def preprocess(self, train_set, valid_set, test_set):
4445
reaction_types = set()
4546
bond_types = set()
46-
for data in train_set:
47-
reaction_types.add(data["reaction"])
48-
for graph in data["graph"]:
47+
for sample in train_set:
48+
reaction_types.add(sample["reaction"])
49+
for graph in sample["graph"]:
4950
bond_types.update(graph.edge_list[:, 2].tolist())
5051
self.num_reaction = len(reaction_types)
5152
self.num_relation = len(bond_types)
@@ -312,35 +313,22 @@ def __init__(self, model, feature=("reaction", "graph", "atom"), num_mlp_layer=2
312313

313314
def preprocess(self, train_set, valid_set, test_set):
314315
reaction_types = set()
315-
for data in train_set:
316-
reaction_types.add(data["reaction"])
317-
self.num_reaction = len(reaction_types)
318-
319-
if isinstance(train_set, torch_data.Subset):
320-
dataset = train_set.dataset
321-
else:
322-
dataset = train_set
323-
dataset.transform = transforms.Compose([
324-
dataset.transform,
325-
RandomBFSOrder(),
326-
])
327-
328-
# atom_types = set()
329-
# bond_types = set()
330-
# for data in train_set:
331-
# for graph in data["graph"]:
332-
# atom_types.update(graph.atom_type.tolist())
333-
# bond_types.update(graph.edge_list[:, 2].tolist())
334-
# atom_types = torch.tensor(sorted(atom_types))
335-
316+
atom_types = set()
317+
bond_types = set()
318+
for sample in train_set:
319+
reaction_types.add(sample["reaction"])
320+
for graph in sample["graph"]:
321+
atom_types.update(graph.atom_type.tolist())
322+
bond_types.update(graph.edge_list[:, 2].tolist())
336323
# TODO: only for fast debugging, to remove
337-
atom_types = torch.tensor([5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 29, 30, 34, 35, 50, 53])
338-
bond_types = torch.tensor([0, 1, 2])
339-
324+
# atom_types = torch.tensor([5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 29, 30, 34, 35, 50, 53])
325+
# bond_types = torch.tensor([0, 1, 2])
326+
atom_types = torch.tensor(sorted(atom_types))
340327
atom2id = -torch.ones(atom_types.max() + 1, dtype=torch.long)
341328
atom2id[atom_types] = torch.arange(len(atom_types))
342329
self.register_buffer("id2atom", atom_types)
343330
self.register_buffer("atom2id", atom2id)
331+
self.num_reaction = len(reaction_types)
344332
self.num_atom_type = len(atom_types)
345333
self.num_bond_type = len(bond_types)
346334
node_feature_dim = train_set[0]["graph"][0].node_feature.shape[-1]
@@ -349,7 +337,18 @@ def preprocess(self, train_set, valid_set, test_set):
349337
dataset = train_set.dataset
350338
else:
351339
dataset = train_set
352-
self.dataset_kwargs = dataset.config_dict().get("kwargs")
340+
dataset.transform = transforms.Compose([
341+
dataset.transform,
342+
RandomBFSOrder(),
343+
])
344+
sig = inspect.signature(data.PackedMolecule.from_molecule)
345+
keys = set(sig.parameters.keys())
346+
kwargs = dataset.config_dict()
347+
feature_kwargs = {}
348+
for k, v in kwargs.items():
349+
if k in keys:
350+
feature_kwargs[k] = v
351+
self.feature_kwargs = feature_kwargs
353352

354353
node_dim = self.model.output_dim
355354
edge_dim = 0
@@ -382,7 +381,7 @@ def _update_molecule_feature(self, graphs):
382381
mols = graphs.to_molecule(ignore_error=True)
383382
valid = [mol is not None for mol in mols]
384383
valid = torch.tensor(valid, device=graphs.device)
385-
new_graphs = type(graphs).from_molecule(mols, node_feature="synthon_completion", kekulize=True)
384+
new_graphs = type(graphs).from_molecule(mols, **self.feature_kwargs)
386385

387386
node_feature = torch.zeros(graphs.num_node, *new_graphs.node_feature.shape[1:],
388387
dtype=new_graphs.node_feature.dtype, device=graphs.device)
@@ -915,8 +914,7 @@ def predict_reactant(self, batch, num_beam=10, max_prediction=20, max_step=20):
915914
order = key.argsort(descending=True)
916915
new_graph = new_graph[order]
917916

918-
num_candidate = scatter_add(torch.ones_like(new_graph.synthon_id), new_graph.synthon_id,
919-
dim_size=len(synthon))
917+
num_candidate = new_graph.synthon_id.bincount(minlength=len(synthon))
920918
topk = functional.variadic_topk(new_graph.logp, num_candidate, num_beam)[1]
921919
topk_index = topk + (num_candidate.cumsum(0) - num_candidate).unsqueeze(-1)
922920
topk_index = torch.unique(topk_index)
@@ -965,7 +963,7 @@ def _extend(self, data, num_xs, input, input2graph=None):
965963
num_input_per_graph = len(input) // len(num_xs)
966964
input2graph = torch.arange(len(num_xs), device=data.device).unsqueeze(-1)
967965
input2graph = input2graph.repeat(1, num_input_per_graph).flatten()
968-
num_inputs = scatter_add(torch.ones_like(input2graph), input2graph, dim_size=len(num_xs))
966+
num_inputs = input2graph.bincount(minlength=len(num_xs))
969967
new_num_xs = num_xs + num_inputs
970968
new_num_cum_xs = new_num_xs.cumsum(0)
971969
new_num_x = new_num_cum_xs[-1].item()

0 commit comments

Comments
 (0)