Skip to content

Commit 7fd0857

Browse files
authored
Merge pull request #33 from rwightman/autoaugment
Add AutoAugment ImageNet policies for training
2 parents aff194f + 4002c0d commit 7fd0857

File tree

4 files changed

+396
-13
lines changed

4 files changed

+396
-13
lines changed

timm/data/auto_augment.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
""" Auto Augment
2+
Implementation adapted from:
3+
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
4+
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
5+
6+
Hacked together by Ross Wightman
7+
"""
8+
import random
9+
import math
10+
from PIL import Image, ImageOps, ImageEnhance
11+
import PIL
12+
import numpy as np
13+
14+
15+
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
16+
17+
_FILL = (128, 128, 128)
18+
19+
# This signifies the max integer that the controller RNN could predict for the
20+
# augmentation scheme.
21+
_MAX_LEVEL = 10.
22+
23+
_HPARAMS_DEFAULT = dict(
24+
translate_const=250,
25+
img_mean=_FILL,
26+
)
27+
28+
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
29+
30+
31+
def _interpolation(kwargs):
32+
interpolation = kwargs.pop('resample', Image.NEAREST)
33+
if isinstance(interpolation, (list, tuple)):
34+
return random.choice(interpolation)
35+
else:
36+
return interpolation
37+
38+
39+
def _check_args_tf(kwargs):
40+
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
41+
kwargs.pop('fillcolor')
42+
kwargs['resample'] = _interpolation(kwargs)
43+
44+
45+
def shear_x(img, factor, **kwargs):
46+
_check_args_tf(kwargs)
47+
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
48+
49+
50+
def shear_y(img, factor, **kwargs):
51+
_check_args_tf(kwargs)
52+
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
53+
54+
55+
def translate_x_rel(img, pct, **kwargs):
56+
pixels = pct * img.size[0]
57+
_check_args_tf(kwargs)
58+
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
59+
60+
61+
def translate_y_rel(img, pct, **kwargs):
62+
pixels = pct * img.size[1]
63+
_check_args_tf(kwargs)
64+
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
65+
66+
67+
def translate_x_abs(img, pixels, **kwargs):
68+
_check_args_tf(kwargs)
69+
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
70+
71+
72+
def translate_y_abs(img, pixels, **kwargs):
73+
_check_args_tf(kwargs)
74+
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
75+
76+
77+
def rotate(img, degrees, **kwargs):
78+
_check_args_tf(kwargs)
79+
if _PIL_VER >= (5, 2):
80+
return img.rotate(degrees, **kwargs)
81+
elif _PIL_VER >= (5, 0):
82+
w, h = img.size
83+
post_trans = (0, 0)
84+
rotn_center = (w / 2.0, h / 2.0)
85+
angle = -math.radians(degrees)
86+
matrix = [
87+
round(math.cos(angle), 15),
88+
round(math.sin(angle), 15),
89+
0.0,
90+
round(-math.sin(angle), 15),
91+
round(math.cos(angle), 15),
92+
0.0,
93+
]
94+
95+
def transform(x, y, matrix):
96+
(a, b, c, d, e, f) = matrix
97+
return a * x + b * y + c, d * x + e * y + f
98+
99+
matrix[2], matrix[5] = transform(
100+
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
101+
)
102+
matrix[2] += rotn_center[0]
103+
matrix[5] += rotn_center[1]
104+
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
105+
else:
106+
return img.rotate(degrees, resample=kwargs['resample'])
107+
108+
109+
def auto_contrast(img, **__):
110+
return ImageOps.autocontrast(img)
111+
112+
113+
def invert(img, **__):
114+
return ImageOps.invert(img)
115+
116+
117+
def equalize(img, **__):
118+
return ImageOps.equalize(img)
119+
120+
121+
def solarize(img, thresh, **__):
122+
return ImageOps.solarize(img, thresh)
123+
124+
125+
def solarize_add(img, add, thresh=128, **__):
126+
lut = []
127+
for i in range(256):
128+
if i < thresh:
129+
lut.append(min(255, i + add))
130+
else:
131+
lut.append(i)
132+
if img.mode in ("L", "RGB"):
133+
if img.mode == "RGB" and len(lut) == 256:
134+
lut = lut + lut + lut
135+
return img.point(lut)
136+
else:
137+
return img
138+
139+
140+
def posterize(img, bits_to_keep, **__):
141+
if bits_to_keep >= 8:
142+
return img
143+
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
144+
return ImageOps.posterize(img, bits_to_keep)
145+
146+
147+
def contrast(img, factor, **__):
148+
return ImageEnhance.Contrast(img).enhance(factor)
149+
150+
151+
def color(img, factor, **__):
152+
return ImageEnhance.Color(img).enhance(factor)
153+
154+
155+
def brightness(img, factor, **__):
156+
return ImageEnhance.Brightness(img).enhance(factor)
157+
158+
159+
def sharpness(img, factor, **__):
160+
return ImageEnhance.Sharpness(img).enhance(factor)
161+
162+
163+
def _randomly_negate(v):
164+
"""With 50% prob, negate the value"""
165+
return -v if random.random() > 0.5 else v
166+
167+
168+
def _rotate_level_to_arg(level):
169+
# range [-30, 30]
170+
level = (level / _MAX_LEVEL) * 30.
171+
level = _randomly_negate(level)
172+
return (level,)
173+
174+
175+
def _enhance_level_to_arg(level):
176+
# range [0.1, 1.9]
177+
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
178+
179+
180+
def _shear_level_to_arg(level):
181+
# range [-0.3, 0.3]
182+
level = (level / _MAX_LEVEL) * 0.3
183+
level = _randomly_negate(level)
184+
return (level,)
185+
186+
187+
def _translate_abs_level_to_arg(level, translate_const):
188+
level = (level / _MAX_LEVEL) * float(translate_const)
189+
level = _randomly_negate(level)
190+
return (level,)
191+
192+
193+
def _translate_rel_level_to_arg(level):
194+
# range [-0.45, 0.45]
195+
level = (level / _MAX_LEVEL) * 0.45
196+
level = _randomly_negate(level)
197+
return (level,)
198+
199+
200+
def level_to_arg(hparams):
201+
return {
202+
'AutoContrast': lambda level: (),
203+
'Equalize': lambda level: (),
204+
'Invert': lambda level: (),
205+
'Rotate': _rotate_level_to_arg,
206+
# FIXME these are both different from original impl as I believe there is a bug,
207+
# not sure what is the correct alternative, hence 2 options that look better
208+
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
209+
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
210+
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
211+
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
212+
'Color': _enhance_level_to_arg,
213+
'Contrast': _enhance_level_to_arg,
214+
'Brightness': _enhance_level_to_arg,
215+
'Sharpness': _enhance_level_to_arg,
216+
'ShearX': _shear_level_to_arg,
217+
'ShearY': _shear_level_to_arg,
218+
'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
219+
'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
220+
'TranslateXRel': lambda level: _translate_rel_level_to_arg(level),
221+
'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
222+
}
223+
224+
225+
NAME_TO_OP = {
226+
'AutoContrast': auto_contrast,
227+
'Equalize': equalize,
228+
'Invert': invert,
229+
'Rotate': rotate,
230+
'Posterize': posterize,
231+
'Posterize2': posterize,
232+
'Solarize': solarize,
233+
'SolarizeAdd': solarize_add,
234+
'Color': color,
235+
'Contrast': contrast,
236+
'Brightness': brightness,
237+
'Sharpness': sharpness,
238+
'ShearX': shear_x,
239+
'ShearY': shear_y,
240+
'TranslateX': translate_x_abs,
241+
'TranslateY': translate_y_abs,
242+
'TranslateXRel': translate_x_rel,
243+
'TranslateYRel': translate_y_rel,
244+
}
245+
246+
247+
class AutoAugmentOp:
248+
249+
def __init__(self, name, prob, magnitude, hparams={}):
250+
self.aug_fn = NAME_TO_OP[name]
251+
self.level_fn = level_to_arg(hparams)[name]
252+
self.prob = prob
253+
self.magnitude = magnitude
254+
# If std deviation of magnitude is > 0, we introduce some randomness
255+
# in the usually fixed policy and sample magnitude from normal dist
256+
# with mean magnitude and std-dev of magnitude_std.
257+
# NOTE This is being tested as it's not in paper or reference impl.
258+
self.magnitude_std = 0.5 # FIXME add arg/hparam
259+
self.kwargs = {
260+
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
261+
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION
262+
}
263+
264+
def __call__(self, img):
265+
if self.prob < random.random():
266+
return img
267+
magnitude = self.magnitude
268+
if self.magnitude_std and self.magnitude_std > 0:
269+
magnitude = random.gauss(magnitude, self.magnitude_std)
270+
magnitude = min(_MAX_LEVEL, max(0, magnitude))
271+
level_args = self.level_fn(magnitude)
272+
return self.aug_fn(img, *level_args, **self.kwargs)
273+
274+
275+
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
276+
# ImageNet policy from TPU EfficientNet impl, cannot find
277+
# a paper reference.
278+
policy = [
279+
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
280+
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
281+
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
282+
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
283+
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
284+
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
285+
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
286+
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
287+
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
288+
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
289+
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
290+
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
291+
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
292+
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
293+
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
294+
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
295+
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
296+
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
297+
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
298+
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
299+
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
300+
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
301+
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
302+
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
303+
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
304+
]
305+
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
306+
return pc
307+
308+
309+
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
310+
# ImageNet policy from https://arxiv.org/abs/1805.09501
311+
policy = [
312+
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
313+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
314+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
315+
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
316+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
317+
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
318+
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
319+
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
320+
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
321+
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
322+
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
323+
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
324+
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
325+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
326+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
327+
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
328+
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
329+
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
330+
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
331+
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
332+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
333+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
334+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
335+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
336+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
337+
]
338+
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
339+
return pc
340+
341+
342+
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
343+
if name == 'original':
344+
return auto_augment_policy_original(hparams)
345+
elif name == 'v0':
346+
return auto_augment_policy_v0(hparams)
347+
else:
348+
assert False, 'Unknown AA policy (%s)' % name
349+
350+
351+
class AutoAugment:
352+
353+
def __init__(self, policy):
354+
self.policy = policy
355+
356+
def __call__(self, img):
357+
sub_policy = random.choice(self.policy)
358+
for op in sub_policy:
359+
img = op(img)
360+
return img

timm/data/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def create_transform(
9292
is_training=False,
9393
use_prefetcher=False,
9494
color_jitter=0.4,
95+
auto_augment=None,
9596
interpolation='bilinear',
9697
mean=IMAGENET_DEFAULT_MEAN,
9798
std=IMAGENET_DEFAULT_STD,
@@ -112,6 +113,7 @@ def create_transform(
112113
transform = transforms_imagenet_train(
113114
img_size,
114115
color_jitter=color_jitter,
116+
auto_augment=auto_augment,
115117
interpolation=interpolation,
116118
use_prefetcher=use_prefetcher,
117119
mean=mean,
@@ -138,6 +140,7 @@ def create_loader(
138140
rand_erase_mode='const',
139141
rand_erase_count=1,
140142
color_jitter=0.4,
143+
auto_augment=None,
141144
interpolation='bilinear',
142145
mean=IMAGENET_DEFAULT_MEAN,
143146
std=IMAGENET_DEFAULT_STD,
@@ -153,6 +156,7 @@ def create_loader(
153156
is_training=is_training,
154157
use_prefetcher=use_prefetcher,
155158
color_jitter=color_jitter,
159+
auto_augment=auto_augment,
156160
interpolation=interpolation,
157161
mean=mean,
158162
std=std,

0 commit comments

Comments
 (0)