@@ -107,23 +107,24 @@ def _parallel_fit_per_epoch(
107107 # Parallelization corrupts the binding between optimizer and scheduler
108108 set_module .update_lr (optimizer , cur_lr )
109109
110- for batch_idx , ( data , target ) in enumerate (train_loader ):
110+ for batch_idx , elem in enumerate (train_loader ):
111111
112- batch_size = data .size ()[0 ]
113- data , target = data .to (device ), target .to (device )
114- data .requires_grad = True
112+ data , target = io .split_data_target (elem , device )
113+ batch_size = data [0 ].size (0 )
114+ for tensor in data :
115+ tensor .requires_grad = True
115116
116117 # Get adversarial samples
117- _output = estimator (data )
118+ _output = estimator (* data )
118119 _loss = criterion (_output , target )
119120 _loss .backward ()
120- data_grad = data .grad .data
121+ data_grad = [ tensor .grad .data for tensor in data ]
121122 adv_data = _get_fgsm_samples (data , epsilon , data_grad )
122123
123124 # Compute the training loss
124125 optimizer .zero_grad ()
125- org_output = estimator (data )
126- adv_output = estimator (adv_data )
126+ org_output = estimator (* data )
127+ adv_output = estimator (* adv_data )
127128 loss = criterion (org_output , target ) + criterion (adv_output , target )
128129 loss .backward ()
129130 optimizer .step ()
@@ -156,27 +157,31 @@ def _parallel_fit_per_epoch(
156157 return estimator , optimizer
157158
158159
159- def _get_fgsm_samples (sample , epsilon , sample_grad ):
160+ def _get_fgsm_samples (sample_list , epsilon , sample_grad_list ):
160161 """
161162 Private functions used to generate adversarial samples with fast gradient
162163 sign method (FGSM).
163164 """
164165
165- # Check the input range of `sample`
166- min_value , max_value = torch .min (sample ), torch .max (sample )
167- if not 0 <= min_value < max_value <= 1 :
168- msg = (
169- "The input range of samples passed to adversarial training"
170- " should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
171- " instead."
172- )
173- raise ValueError (msg .format (min_value , max_value ))
166+ perturbed_sample_list = []
167+ for sample , sample_grad in zip (sample_list , sample_grad_list ):
168+ # Check the input range of `sample`
169+ min_value , max_value = torch .min (sample ), torch .max (sample )
170+ if not 0 <= min_value < max_value <= 1 :
171+ msg = (
172+ "The input range of samples passed to adversarial training"
173+ " should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
174+ " instead."
175+ )
176+ raise ValueError (msg .format (min_value , max_value ))
174177
175- sign_sample_grad = sample_grad .sign ()
176- perturbed_sample = sample + epsilon * sign_sample_grad
177- perturbed_sample = torch .clamp (perturbed_sample , 0 , 1 )
178+ sign_sample_grad = sample_grad .sign ()
179+ perturbed_sample = sample + epsilon * sign_sample_grad
180+ perturbed_sample = torch .clamp (perturbed_sample , 0 , 1 )
178181
179- return perturbed_sample
182+ perturbed_sample_list .append (perturbed_sample )
183+
184+ return perturbed_sample_list
180185
181186
182187class _BaseAdversarialTraining (BaseModule ):
@@ -218,10 +223,10 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
218223 """Implementation on the data forwarding in AdversarialTrainingClassifier.""" , # noqa: E501
219224 "classifier_forward" ,
220225 )
221- def forward (self , x ):
226+ def forward (self , * x ):
222227 # Take the average over class distributions from all base estimators.
223228 outputs = [
224- F .softmax (estimator (x ), dim = 1 ) for estimator in self .estimators_
229+ F .softmax (estimator (* x ), dim = 1 ) for estimator in self .estimators_
225230 ]
226231 proba = op .average (outputs )
227232
@@ -282,9 +287,9 @@ def fit(
282287 best_acc = 0.0
283288
284289 # Internal helper function on pesudo forward
285- def _forward (estimators , data ):
290+ def _forward (estimators , * x ):
286291 outputs = [
287- F .softmax (estimator (data ), dim = 1 ) for estimator in estimators
292+ F .softmax (estimator (* x ), dim = 1 ) for estimator in estimators
288293 ]
289294 proba = op .average (outputs )
290295
@@ -336,10 +341,11 @@ def _forward(estimators, data):
336341 with torch .no_grad ():
337342 correct = 0
338343 total = 0
339- for _ , (data , target ) in enumerate (test_loader ):
340- data = data .to (self .device )
341- target = target .to (self .device )
342- output = _forward (estimators , data )
344+ for _ , elem in enumerate (test_loader ):
345+ data , target = io .split_data_target (
346+ elem , self .device
347+ )
348+ output = _forward (estimators , * data )
343349 _ , predicted = torch .max (output .data , 1 )
344350 correct += (predicted == target ).sum ().item ()
345351 total += target .size (0 )
@@ -384,8 +390,8 @@ def evaluate(self, test_loader, return_loss=False):
384390 return super ().evaluate (test_loader , return_loss )
385391
386392 @torchensemble_model_doc (item = "predict" )
387- def predict (self , X , return_numpy = True ):
388- return super ().predict (X , return_numpy )
393+ def predict (self , * x ):
394+ return super ().predict (* x )
389395
390396
391397@torchensemble_model_doc (
@@ -397,9 +403,9 @@ class AdversarialTrainingRegressor(_BaseAdversarialTraining, BaseRegressor):
397403 """Implementation on the data forwarding in AdversarialTrainingRegressor.""" , # noqa: E501
398404 "regressor_forward" ,
399405 )
400- def forward (self , x ):
406+ def forward (self , * x ):
401407 # Take the average over predictions from all base estimators.
402- outputs = [estimator (x ) for estimator in self .estimators_ ]
408+ outputs = [estimator (* x ) for estimator in self .estimators_ ]
403409 pred = op .average (outputs )
404410
405411 return pred
@@ -459,8 +465,8 @@ def fit(
459465 best_mse = float ("inf" )
460466
461467 # Internal helper function on pesudo forward
462- def _forward (estimators , data ):
463- outputs = [estimator (data ) for estimator in estimators ]
468+ def _forward (estimators , * x ):
469+ outputs = [estimator (* x ) for estimator in estimators ]
464470 pred = op .average (outputs )
465471
466472 return pred
@@ -510,10 +516,11 @@ def _forward(estimators, data):
510516 self .eval ()
511517 with torch .no_grad ():
512518 mse = 0.0
513- for _ , (data , target ) in enumerate (test_loader ):
514- data = data .to (self .device )
515- target = target .to (self .device )
516- output = _forward (estimators , data )
519+ for _ , elem in enumerate (test_loader ):
520+ data , target = io .split_data_target (
521+ elem , self .device
522+ )
523+ output = _forward (estimators , * data )
517524 mse += criterion (output , target )
518525 mse /= len (test_loader )
519526
@@ -553,5 +560,5 @@ def evaluate(self, test_loader):
553560 return super ().evaluate (test_loader )
554561
555562 @torchensemble_model_doc (item = "predict" )
556- def predict (self , X , return_numpy = True ):
557- return super ().predict (X , return_numpy )
563+ def predict (self , * x ):
564+ return super ().predict (* x )
0 commit comments