Commit 50de90f

feat: support dataloader with multiple input (#76)
* Update CHANGELOG.rst
* update temporary code
* fix error
* fix typo
* add unit tests
* Update _constants.py
* update doc
* update doc
* Update index.rst
* Update test_all_models_multi_input.py
* update unit tests
* Update test_all_models_multi_input.py
* update
1 parent 30e2f4e commit 50de90f

16 files changed: +580 −782 lines changed
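In short, a dataloader batch may now contain more than one input tensor, i.e. `(x1, x2, ..., target)`, and `forward`/`predict` take the inputs as `*x`. A minimal sketch of the resulting usage; the toy two-input estimator and the tensor shapes are illustrative, not part of the commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchensemble import VotingClassifier

# Toy base estimator whose forward takes two inputs (illustrative only).
class TwoInputMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.fc2 = torch.nn.Linear(5, 16)
        self.out = torch.nn.Linear(32, 2)

    def forward(self, x1, x2):
        h1 = torch.relu(self.fc1(x1))
        h2 = torch.relu(self.fc2(x2))
        return self.out(torch.cat([h1, h2], dim=1))

# Each batch is (x1, x2, target): multiple inputs, target last.
x1, x2 = torch.rand(100, 10), torch.rand(100, 5)
y = torch.randint(0, 2, (100,))
loader = DataLoader(TensorDataset(x1, x2, y), batch_size=32)

model = VotingClassifier(estimator=TwoInputMLP, n_estimators=3, cuda=False)
model.set_optimizer("Adam", lr=1e-3)
model.fit(loader, epochs=2)
pred = model.predict(x1, x2)  # predict also accepts multiple inputs now
```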

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ Changelog
 Ver 0.1.*
 ---------

-* |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu <https://github.com/xuyxu>`__
+* |Feature| |API| Support using dataloader with multiple input | `@xuyxu <https://github.com/xuyxu>`__
 * |Fix| Fix missing functionality of ``use_reduction_sum`` for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
 * |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu <https://github.com/xuyxu>`__
 * |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__

torchensemble/__init__.py

Lines changed: 0 additions & 4 deletions

@@ -12,8 +12,6 @@
 from .adversarial_training import AdversarialTrainingRegressor
 from .fast_geometric import FastGeometricClassifier
 from .fast_geometric import FastGeometricRegressor
-from .soft_gradient_boosting import SoftGradientBoostingClassifier
-from .soft_gradient_boosting import SoftGradientBoostingRegressor


 __all__ = [
@@ -31,6 +29,4 @@
     "AdversarialTrainingRegressor",
     "FastGeometricClassifier",
     "FastGeometricRegressor",
-    "SoftGradientBoostingClassifier",
-    "SoftGradientBoostingRegressor",
 ]
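A side effect of the export removal above: the soft gradient boosting ensembles can no longer be imported from the package root (whether the `soft_gradient_boosting` module itself survives is not visible in this diff):

```python
from torchensemble import VotingClassifier  # still re-exported

# Raises ImportError after this commit, since the re-export was removed:
# from torchensemble import SoftGradientBoostingClassifier
```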

torchensemble/_base.py

Lines changed: 31 additions & 29 deletions

@@ -7,6 +7,7 @@
 import torch.nn as nn

 from . import _constants as const
+from .utils.io import split_data_target
 from .utils.logging import get_tb_logger


@@ -148,7 +149,7 @@ def set_scheduler(self, scheduler_name, **kwargs):
         self.use_scheduler_ = True

     @abc.abstractmethod
-    def forward(self, x):
+    def forward(self, *x):
         """
         Implementation on the data forwarding in the ensemble. Notice
         that the input ``x`` should be a data batch instead of a standalone
@@ -170,27 +171,26 @@ def fit(
         """

     @torch.no_grad()
-    def predict(self, X, return_numpy=True):
-        """Docstrings decorated by downstream models."""
+    def predict(self, *x):
+        """Docstrings decorated by downstream ensembles."""
         self.eval()
-        pred = None

-        if isinstance(X, torch.Tensor):
-            pred = self.forward(X.to(self.device))
-        elif isinstance(X, np.ndarray):
-            X = torch.Tensor(X).to(self.device)
-            pred = self.forward(X)
-        else:
-            msg = (
-                "The type of input X should be one of {{torch.Tensor,"
-                " np.ndarray}}."
-            )
-            raise ValueError(msg)
+        # Copy data
+        x_device = []
+        for data in x:
+            if isinstance(data, torch.Tensor):
+                x_device.append(data.to(self.device))
+            elif isinstance(data, np.ndarray):
+                x_device.append(torch.Tensor(data).to(self.device))
+            else:
+                msg = (
+                    "The type of input X should be one of {{torch.Tensor,"
+                    " np.ndarray}}."
+                )
+                raise ValueError(msg)

+        pred = self.forward(*x_device)
         pred = pred.cpu()
-        if return_numpy:
-            return pred.numpy()
-
         return pred
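With the rewritten `predict`, each positional input is moved to the device independently, so tensors and NumPy arrays can be mixed, and the result is now always a CPU tensor. A sketch, assuming `model` is a fitted ensemble whose base estimators take two inputs:

```python
import numpy as np
import torch

x1 = torch.rand(8, 10)                       # torch.Tensor input
x2 = np.random.rand(8, 5).astype("float32")  # np.ndarray input

pred = model.predict(x1, x2)   # both inputs are copied to model.device
assert isinstance(pred, torch.Tensor)
pred_np = pred.numpy()         # explicit call replaces `return_numpy=True`
```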
@@ -212,7 +212,8 @@ def _decide_n_outputs(self, train_loader):
         # Infer `n_outputs` from the dataloader
         else:
             labels = []
-            for _, (_, target) in enumerate(train_loader):
+            for _, elem in enumerate(train_loader):
+                _, target = split_data_target(elem, self.device)
                 labels.append(target)
             labels = torch.unique(torch.cat(labels))
             n_outputs = labels.size(0)
@@ -228,9 +229,9 @@ def evaluate(self, test_loader, return_loss=False):
         criterion = nn.CrossEntropyLoss()
         loss = 0.0

-        for _, (data, target) in enumerate(test_loader):
-            data, target = data.to(self.device), target.to(self.device)
-            output = self.forward(data)
+        for _, elem in enumerate(test_loader):
+            data, target = split_data_target(elem, self.device)
+            output = self.forward(*data)
             _, predicted = torch.max(output.data, 1)
             correct += (predicted == target).sum().item()
             total += target.size(0)
@@ -258,25 +259,26 @@ def _decide_n_outputs(self, train_loader):
         The number of outputs equals the number of target variables for
         regressors (e.g., `1` in univariate regression).
         """
-        for _, (_, target) in enumerate(train_loader):
+        for _, elem in enumerate(train_loader):
+            _, target = split_data_target(elem, self.device)
             if len(target.size()) == 1:
-                n_outputs = 1
+                n_outputs = 1  # univariate regression
             else:
-                n_outputs = target.size(1)
+                n_outputs = target.size(1)  # multivariate regression
             break

         return n_outputs

     @torch.no_grad()
     def evaluate(self, test_loader):
-        """Docstrings decorated by downstream models."""
+        """Docstrings decorated by downstream ensembles."""
         self.eval()
         mse = 0.0
         criterion = nn.MSELoss()

-        for _, (data, target) in enumerate(test_loader):
-            data, target = data.to(self.device), target.to(self.device)
-            output = self.forward(data)
+        for _, elem in enumerate(test_loader):
+            data, target = split_data_target(elem, self.device)
+            output = self.forward(*data)
             mse += criterion(output, target)

         return float(mse) / len(test_loader)
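Several hunks above rely on the new `split_data_target` helper from `torchensemble.utils.io`, whose body is not part of this diff. A plausible sketch under the convention the diff implies, namely that the last element of a batch is the target and everything before it is model input; the real implementation may differ:

```python
import torch

def split_data_target(element, device):
    # Sketch only: split a dataloader batch into (list of inputs, target),
    # moving everything onto `device`.
    if not isinstance(element, (list, tuple)) or len(element) < 2:
        msg = "Each batch should contain at least one input and the target."
        raise ValueError(msg)
    *data, target = element
    data = [tensor.to(device) for tensor in data]
    return data, target.to(device)
```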

torchensemble/_constants.py

Lines changed: 3 additions & 8 deletions

@@ -124,19 +124,14 @@

     Parameters
     ----------
-    X : {Tensor, ndarray}
-        A data batch in the form of tensor or Numpy array.
-    return_numpy : bool, default=True
-        Whether to convert the predictions into a Numpy array.
+    X : {tensor, numpy array}
+        A data batch in the form of tensor or numpy array.

     Returns
     -------
-    pred : Array of shape (n_samples, n_outputs)
+    pred : tensor of shape (n_samples, n_outputs)
         For classifiers, ``n_outputs`` is the number of distinct classes. For
         regressors, ``n_output`` is the number of target variables.
-
-        - If ``return_numpy`` is ``False``, the result is a tensor.
-        - If ``return_numpy`` is ``True``, the result is a Numpy array.
     """

torchensemble/adversarial_training.py

Lines changed: 49 additions & 42 deletions

@@ -107,23 +107,24 @@ def _parallel_fit_per_epoch(
     # Parallelization corrupts the binding between optimizer and scheduler
     set_module.update_lr(optimizer, cur_lr)

-    for batch_idx, (data, target) in enumerate(train_loader):
+    for batch_idx, elem in enumerate(train_loader):

-        batch_size = data.size()[0]
-        data, target = data.to(device), target.to(device)
-        data.requires_grad = True
+        data, target = io.split_data_target(elem, device)
+        batch_size = data[0].size(0)
+        for tensor in data:
+            tensor.requires_grad = True

         # Get adversarial samples
-        _output = estimator(data)
+        _output = estimator(*data)
         _loss = criterion(_output, target)
         _loss.backward()
-        data_grad = data.grad.data
+        data_grad = [tensor.grad.data for tensor in data]
         adv_data = _get_fgsm_samples(data, epsilon, data_grad)

         # Compute the training loss
         optimizer.zero_grad()
-        org_output = estimator(data)
-        adv_output = estimator(adv_data)
+        org_output = estimator(*data)
+        adv_output = estimator(*adv_data)
         loss = criterion(org_output, target) + criterion(adv_output, target)
         loss.backward()
         optimizer.step()
@@ -156,27 +157,31 @@ def _parallel_fit_per_epoch(
     return estimator, optimizer


-def _get_fgsm_samples(sample, epsilon, sample_grad):
+def _get_fgsm_samples(sample_list, epsilon, sample_grad_list):
     """
     Private functions used to generate adversarial samples with fast gradient
     sign method (FGSM).
     """

-    # Check the input range of `sample`
-    min_value, max_value = torch.min(sample), torch.max(sample)
-    if not 0 <= min_value < max_value <= 1:
-        msg = (
-            "The input range of samples passed to adversarial training"
-            " should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
-            " instead."
-        )
-        raise ValueError(msg.format(min_value, max_value))
+    perturbed_sample_list = []
+    for sample, sample_grad in zip(sample_list, sample_grad_list):
+        # Check the input range of `sample`
+        min_value, max_value = torch.min(sample), torch.max(sample)
+        if not 0 <= min_value < max_value <= 1:
+            msg = (
+                "The input range of samples passed to adversarial training"
+                " should be in the range [0, 1], but got [{:.3f}, {:.3f}]"
+                " instead."
+            )
+            raise ValueError(msg.format(min_value, max_value))

-    sign_sample_grad = sample_grad.sign()
-    perturbed_sample = sample + epsilon * sign_sample_grad
-    perturbed_sample = torch.clamp(perturbed_sample, 0, 1)
+        sign_sample_grad = sample_grad.sign()
+        perturbed_sample = sample + epsilon * sign_sample_grad
+        perturbed_sample = torch.clamp(perturbed_sample, 0, 1)

-    return perturbed_sample
+        perturbed_sample_list.append(perturbed_sample)
+
+    return perturbed_sample_list


 class _BaseAdversarialTraining(BaseModule):
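For reference, the per-input update inside `_get_fgsm_samples` is the standard FGSM step, `x_adv = clamp(x + epsilon * sign(grad), 0, 1)`; a minimal standalone demonstration on toy tensors (not the torchensemble API):

```python
import torch

# Toy batch in [0, 1]; FGSM perturbs along the sign of the loss gradient.
x = torch.rand(4, 3, requires_grad=True)
loss = (x ** 2).sum()  # any scalar loss works for the demo
loss.backward()

epsilon = 0.1
with torch.no_grad():
    x_adv = torch.clamp(x + epsilon * x.grad.sign(), 0, 1)
```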
@@ -218,10 +223,10 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
         """Implementation on the data forwarding in AdversarialTrainingClassifier.""",  # noqa: E501
         "classifier_forward",
     )
-    def forward(self, x):
+    def forward(self, *x):
         # Take the average over class distributions from all base estimators.
         outputs = [
-            F.softmax(estimator(x), dim=1) for estimator in self.estimators_
+            F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
         ]
         proba = op.average(outputs)

@@ -282,9 +287,9 @@ def fit(
         best_acc = 0.0

         # Internal helper function on pseudo forward
-        def _forward(estimators, data):
+        def _forward(estimators, *x):
             outputs = [
-                F.softmax(estimator(data), dim=1) for estimator in estimators
+                F.softmax(estimator(*x), dim=1) for estimator in estimators
             ]
             proba = op.average(outputs)

@@ -336,10 +341,11 @@ def _forward(estimators, data):
             with torch.no_grad():
                 correct = 0
                 total = 0
-                for _, (data, target) in enumerate(test_loader):
-                    data = data.to(self.device)
-                    target = target.to(self.device)
-                    output = _forward(estimators, data)
+                for _, elem in enumerate(test_loader):
+                    data, target = io.split_data_target(
+                        elem, self.device
+                    )
+                    output = _forward(estimators, *data)
                     _, predicted = torch.max(output.data, 1)
                     correct += (predicted == target).sum().item()
                     total += target.size(0)
@@ -384,8 +390,8 @@ def evaluate(self, test_loader, return_loss=False):
         return super().evaluate(test_loader, return_loss)

     @torchensemble_model_doc(item="predict")
-    def predict(self, X, return_numpy=True):
-        return super().predict(X, return_numpy)
+    def predict(self, *x):
+        return super().predict(*x)


 @torchensemble_model_doc(
@@ -397,9 +403,9 @@ class AdversarialTrainingRegressor(_BaseAdversarialTraining, BaseRegressor):
         """Implementation on the data forwarding in AdversarialTrainingRegressor.""",  # noqa: E501
         "regressor_forward",
     )
-    def forward(self, x):
+    def forward(self, *x):
         # Take the average over predictions from all base estimators.
-        outputs = [estimator(x) for estimator in self.estimators_]
+        outputs = [estimator(*x) for estimator in self.estimators_]
         pred = op.average(outputs)

         return pred
@@ -459,8 +465,8 @@ def fit(
         best_mse = float("inf")

         # Internal helper function on pseudo forward
-        def _forward(estimators, data):
-            outputs = [estimator(data) for estimator in estimators]
+        def _forward(estimators, *x):
+            outputs = [estimator(*x) for estimator in estimators]
             pred = op.average(outputs)

             return pred
@@ -510,10 +516,11 @@ def _forward(estimators, data):
         self.eval()
         with torch.no_grad():
             mse = 0.0
-            for _, (data, target) in enumerate(test_loader):
-                data = data.to(self.device)
-                target = target.to(self.device)
-                output = _forward(estimators, data)
+            for _, elem in enumerate(test_loader):
+                data, target = io.split_data_target(
+                    elem, self.device
+                )
+                output = _forward(estimators, *data)
                 mse += criterion(output, target)
             mse /= len(test_loader)

@@ -553,5 +560,5 @@ def evaluate(self, test_loader):
         return super().evaluate(test_loader)

     @torchensemble_model_doc(item="predict")
-    def predict(self, X, return_numpy=True):
-        return super().predict(X, return_numpy)
+    def predict(self, *x):
+        return super().predict(*x)
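The training-loop change in `_parallel_fit_per_epoch` generalizes gradient collection to one gradient per input tensor; the same pattern in isolation, with a toy stand-in for the estimator:

```python
import torch

def estimator(a, b):
    # Stand-in for a real nn.Module taking two inputs.
    return (a ** 2).sum() + (b ** 2).sum()

data = [torch.rand(4, 3), torch.rand(4, 5)]
for tensor in data:
    tensor.requires_grad = True  # enable input gradients, as in the diff

loss = estimator(*data)
loss.backward()
data_grad = [tensor.grad.data for tensor in data]  # one gradient per input
```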
