1+ import inspect
12from collections import deque
23
34import torch
@@ -43,9 +44,9 @@ def __init__(self, model, feature=("reaction", "graph", "atom", "bond"), num_mlp
4344 def preprocess (self , train_set , valid_set , test_set ):
4445 reaction_types = set ()
4546 bond_types = set ()
46- for data in train_set :
47- reaction_types .add (data ["reaction" ])
48- for graph in data ["graph" ]:
47+ for sample in train_set :
48+ reaction_types .add (sample ["reaction" ])
49+ for graph in sample ["graph" ]:
4950 bond_types .update (graph .edge_list [:, 2 ].tolist ())
5051 self .num_reaction = len (reaction_types )
5152 self .num_relation = len (bond_types )
@@ -312,35 +313,22 @@ def __init__(self, model, feature=("reaction", "graph", "atom"), num_mlp_layer=2
312313
313314 def preprocess (self , train_set , valid_set , test_set ):
314315 reaction_types = set ()
315- for data in train_set :
316- reaction_types .add (data ["reaction" ])
317- self .num_reaction = len (reaction_types )
318-
319- if isinstance (train_set , torch_data .Subset ):
320- dataset = train_set .dataset
321- else :
322- dataset = train_set
323- dataset .transform = transforms .Compose ([
324- dataset .transform ,
325- RandomBFSOrder (),
326- ])
327-
328- # atom_types = set()
329- # bond_types = set()
330- # for data in train_set:
331- # for graph in data["graph"]:
332- # atom_types.update(graph.atom_type.tolist())
333- # bond_types.update(graph.edge_list[:, 2].tolist())
334- # atom_types = torch.tensor(sorted(atom_types))
335-
316+ atom_types = set ()
317+ bond_types = set ()
318+ for sample in train_set :
319+ reaction_types .add (sample ["reaction" ])
320+ for graph in sample ["graph" ]:
321+ atom_types .update (graph .atom_type .tolist ())
322+ bond_types .update (graph .edge_list [:, 2 ].tolist ())
336323 # TODO: only for fast debugging, to remove
337- atom_types = torch .tensor ([5 , 6 , 7 , 8 , 9 , 12 , 14 , 15 , 16 , 17 , 29 , 30 , 34 , 35 , 50 , 53 ])
338- bond_types = torch .tensor ([0 , 1 , 2 ])
339-
324+ # atom_types = torch.tensor([5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 29, 30, 34, 35, 50, 53])
325+ # bond_types = torch.tensor([0, 1, 2])
326+ atom_types = torch . tensor ( sorted ( atom_types ))
340327 atom2id = - torch .ones (atom_types .max () + 1 , dtype = torch .long )
341328 atom2id [atom_types ] = torch .arange (len (atom_types ))
342329 self .register_buffer ("id2atom" , atom_types )
343330 self .register_buffer ("atom2id" , atom2id )
331+ self .num_reaction = len (reaction_types )
344332 self .num_atom_type = len (atom_types )
345333 self .num_bond_type = len (bond_types )
346334 node_feature_dim = train_set [0 ]["graph" ][0 ].node_feature .shape [- 1 ]
@@ -349,7 +337,18 @@ def preprocess(self, train_set, valid_set, test_set):
349337 dataset = train_set .dataset
350338 else :
351339 dataset = train_set
352- self .dataset_kwargs = dataset .config_dict ().get ("kwargs" )
340+ dataset .transform = transforms .Compose ([
341+ dataset .transform ,
342+ RandomBFSOrder (),
343+ ])
344+ sig = inspect .signature (data .PackedMolecule .from_molecule )
345+ keys = set (sig .parameters .keys ())
346+ kwargs = dataset .config_dict ()
347+ feature_kwargs = {}
348+ for k , v in kwargs .items ():
349+ if k in keys :
350+ feature_kwargs [k ] = v
351+ self .feature_kwargs = feature_kwargs
353352
354353 node_dim = self .model .output_dim
355354 edge_dim = 0
@@ -382,7 +381,7 @@ def _update_molecule_feature(self, graphs):
382381 mols = graphs .to_molecule (ignore_error = True )
383382 valid = [mol is not None for mol in mols ]
384383 valid = torch .tensor (valid , device = graphs .device )
385- new_graphs = type (graphs ).from_molecule (mols , node_feature = "synthon_completion" , kekulize = True )
384+ new_graphs = type (graphs ).from_molecule (mols , ** self . feature_kwargs )
386385
387386 node_feature = torch .zeros (graphs .num_node , * new_graphs .node_feature .shape [1 :],
388387 dtype = new_graphs .node_feature .dtype , device = graphs .device )
@@ -915,8 +914,7 @@ def predict_reactant(self, batch, num_beam=10, max_prediction=20, max_step=20):
915914 order = key .argsort (descending = True )
916915 new_graph = new_graph [order ]
917916
918- num_candidate = scatter_add (torch .ones_like (new_graph .synthon_id ), new_graph .synthon_id ,
919- dim_size = len (synthon ))
917+ num_candidate = new_graph .synthon_id .bincount (minlength = len (synthon ))
920918 topk = functional .variadic_topk (new_graph .logp , num_candidate , num_beam )[1 ]
921919 topk_index = topk + (num_candidate .cumsum (0 ) - num_candidate ).unsqueeze (- 1 )
922920 topk_index = torch .unique (topk_index )
@@ -965,7 +963,7 @@ def _extend(self, data, num_xs, input, input2graph=None):
965963 num_input_per_graph = len (input ) // len (num_xs )
966964 input2graph = torch .arange (len (num_xs ), device = data .device ).unsqueeze (- 1 )
967965 input2graph = input2graph .repeat (1 , num_input_per_graph ).flatten ()
968- num_inputs = scatter_add ( torch . ones_like ( input2graph ), input2graph , dim_size = len (num_xs ))
966+ num_inputs = input2graph . bincount ( minlength = len (num_xs ))
969967 new_num_xs = num_xs + num_inputs
970968 new_num_cum_xs = new_num_xs .cumsum (0 )
971969 new_num_x = new_num_cum_xs [- 1 ].item ()
0 commit comments