|
| 1 | +""" Auto Augment |
| 2 | +Implementation adapted from: |
| 3 | + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py |
| 4 | +Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172 |
| 5 | +
|
| 6 | +Hacked together by Ross Wightman |
| 7 | +""" |
| 8 | +import random |
| 9 | +import math |
| 10 | +from PIL import Image, ImageOps, ImageEnhance |
| 11 | +import PIL |
| 12 | +import numpy as np |
| 13 | + |
| 14 | + |
| 15 | +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) |
| 16 | + |
| 17 | +_FILL = (128, 128, 128) |
| 18 | + |
| 19 | +# This signifies the max integer that the controller RNN could predict for the |
| 20 | +# augmentation scheme. |
| 21 | +_MAX_LEVEL = 10. |
| 22 | + |
| 23 | +_HPARAMS_DEFAULT = dict( |
| 24 | + translate_const=250, |
| 25 | + img_mean=_FILL, |
| 26 | +) |
| 27 | + |
| 28 | +_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) |
| 29 | + |
| 30 | + |
| 31 | +def _interpolation(kwargs): |
| 32 | + interpolation = kwargs.pop('resample', Image.NEAREST) |
| 33 | + if isinstance(interpolation, (list, tuple)): |
| 34 | + return random.choice(interpolation) |
| 35 | + else: |
| 36 | + return interpolation |
| 37 | + |
| 38 | + |
| 39 | +def _check_args_tf(kwargs): |
| 40 | + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): |
| 41 | + kwargs.pop('fillcolor') |
| 42 | + kwargs['resample'] = _interpolation(kwargs) |
| 43 | + |
| 44 | + |
| 45 | +def shear_x(img, factor, **kwargs): |
| 46 | + _check_args_tf(kwargs) |
| 47 | + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) |
| 48 | + |
| 49 | + |
| 50 | +def shear_y(img, factor, **kwargs): |
| 51 | + _check_args_tf(kwargs) |
| 52 | + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) |
| 53 | + |
| 54 | + |
| 55 | +def translate_x_rel(img, pct, **kwargs): |
| 56 | + pixels = pct * img.size[0] |
| 57 | + _check_args_tf(kwargs) |
| 58 | + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) |
| 59 | + |
| 60 | + |
| 61 | +def translate_y_rel(img, pct, **kwargs): |
| 62 | + pixels = pct * img.size[1] |
| 63 | + _check_args_tf(kwargs) |
| 64 | + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) |
| 65 | + |
| 66 | + |
| 67 | +def translate_x_abs(img, pixels, **kwargs): |
| 68 | + _check_args_tf(kwargs) |
| 69 | + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) |
| 70 | + |
| 71 | + |
| 72 | +def translate_y_abs(img, pixels, **kwargs): |
| 73 | + _check_args_tf(kwargs) |
| 74 | + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) |
| 75 | + |
| 76 | + |
| 77 | +def rotate(img, degrees, **kwargs): |
| 78 | + _check_args_tf(kwargs) |
| 79 | + if _PIL_VER >= (5, 2): |
| 80 | + return img.rotate(degrees, **kwargs) |
| 81 | + elif _PIL_VER >= (5, 0): |
| 82 | + w, h = img.size |
| 83 | + post_trans = (0, 0) |
| 84 | + rotn_center = (w / 2.0, h / 2.0) |
| 85 | + angle = -math.radians(degrees) |
| 86 | + matrix = [ |
| 87 | + round(math.cos(angle), 15), |
| 88 | + round(math.sin(angle), 15), |
| 89 | + 0.0, |
| 90 | + round(-math.sin(angle), 15), |
| 91 | + round(math.cos(angle), 15), |
| 92 | + 0.0, |
| 93 | + ] |
| 94 | + |
| 95 | + def transform(x, y, matrix): |
| 96 | + (a, b, c, d, e, f) = matrix |
| 97 | + return a * x + b * y + c, d * x + e * y + f |
| 98 | + |
| 99 | + matrix[2], matrix[5] = transform( |
| 100 | + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix |
| 101 | + ) |
| 102 | + matrix[2] += rotn_center[0] |
| 103 | + matrix[5] += rotn_center[1] |
| 104 | + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) |
| 105 | + else: |
| 106 | + return img.rotate(degrees, resample=kwargs['resample']) |
| 107 | + |
| 108 | + |
| 109 | +def auto_contrast(img, **__): |
| 110 | + return ImageOps.autocontrast(img) |
| 111 | + |
| 112 | + |
| 113 | +def invert(img, **__): |
| 114 | + return ImageOps.invert(img) |
| 115 | + |
| 116 | + |
| 117 | +def equalize(img, **__): |
| 118 | + return ImageOps.equalize(img) |
| 119 | + |
| 120 | + |
| 121 | +def solarize(img, thresh, **__): |
| 122 | + return ImageOps.solarize(img, thresh) |
| 123 | + |
| 124 | + |
| 125 | +def solarize_add(img, add, thresh=128, **__): |
| 126 | + lut = [] |
| 127 | + for i in range(256): |
| 128 | + if i < thresh: |
| 129 | + lut.append(min(255, i + add)) |
| 130 | + else: |
| 131 | + lut.append(i) |
| 132 | + if img.mode in ("L", "RGB"): |
| 133 | + if img.mode == "RGB" and len(lut) == 256: |
| 134 | + lut = lut + lut + lut |
| 135 | + return img.point(lut) |
| 136 | + else: |
| 137 | + return img |
| 138 | + |
| 139 | + |
| 140 | +def posterize(img, bits_to_keep, **__): |
| 141 | + if bits_to_keep >= 8: |
| 142 | + return img |
| 143 | + bits_to_keep = max(1, bits_to_keep) # prevent all 0 images |
| 144 | + return ImageOps.posterize(img, bits_to_keep) |
| 145 | + |
| 146 | + |
| 147 | +def contrast(img, factor, **__): |
| 148 | + return ImageEnhance.Contrast(img).enhance(factor) |
| 149 | + |
| 150 | + |
| 151 | +def color(img, factor, **__): |
| 152 | + return ImageEnhance.Color(img).enhance(factor) |
| 153 | + |
| 154 | + |
| 155 | +def brightness(img, factor, **__): |
| 156 | + return ImageEnhance.Brightness(img).enhance(factor) |
| 157 | + |
| 158 | + |
| 159 | +def sharpness(img, factor, **__): |
| 160 | + return ImageEnhance.Sharpness(img).enhance(factor) |
| 161 | + |
| 162 | + |
| 163 | +def _randomly_negate(v): |
| 164 | + """With 50% prob, negate the value""" |
| 165 | + return -v if random.random() > 0.5 else v |
| 166 | + |
| 167 | + |
| 168 | +def _rotate_level_to_arg(level): |
| 169 | + # range [-30, 30] |
| 170 | + level = (level / _MAX_LEVEL) * 30. |
| 171 | + level = _randomly_negate(level) |
| 172 | + return (level,) |
| 173 | + |
| 174 | + |
| 175 | +def _enhance_level_to_arg(level): |
| 176 | + # range [0.1, 1.9] |
| 177 | + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) |
| 178 | + |
| 179 | + |
| 180 | +def _shear_level_to_arg(level): |
| 181 | + # range [-0.3, 0.3] |
| 182 | + level = (level / _MAX_LEVEL) * 0.3 |
| 183 | + level = _randomly_negate(level) |
| 184 | + return (level,) |
| 185 | + |
| 186 | + |
| 187 | +def _translate_abs_level_to_arg(level, translate_const): |
| 188 | + level = (level / _MAX_LEVEL) * float(translate_const) |
| 189 | + level = _randomly_negate(level) |
| 190 | + return (level,) |
| 191 | + |
| 192 | + |
| 193 | +def _translate_rel_level_to_arg(level): |
| 194 | + # range [-0.45, 0.45] |
| 195 | + level = (level / _MAX_LEVEL) * 0.45 |
| 196 | + level = _randomly_negate(level) |
| 197 | + return (level,) |
| 198 | + |
| 199 | + |
| 200 | +def level_to_arg(hparams): |
| 201 | + return { |
| 202 | + 'AutoContrast': lambda level: (), |
| 203 | + 'Equalize': lambda level: (), |
| 204 | + 'Invert': lambda level: (), |
| 205 | + 'Rotate': _rotate_level_to_arg, |
| 206 | + # FIXME these are both different from original impl as I believe there is a bug, |
| 207 | + # not sure what is the correct alternative, hence 2 options that look better |
| 208 | + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8] |
| 209 | + 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0] |
| 210 | + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256] |
| 211 | + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110] |
| 212 | + 'Color': _enhance_level_to_arg, |
| 213 | + 'Contrast': _enhance_level_to_arg, |
| 214 | + 'Brightness': _enhance_level_to_arg, |
| 215 | + 'Sharpness': _enhance_level_to_arg, |
| 216 | + 'ShearX': _shear_level_to_arg, |
| 217 | + 'ShearY': _shear_level_to_arg, |
| 218 | + 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), |
| 219 | + 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), |
| 220 | + 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), |
| 221 | + 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level), |
| 222 | + } |
| 223 | + |
| 224 | + |
| 225 | +NAME_TO_OP = { |
| 226 | + 'AutoContrast': auto_contrast, |
| 227 | + 'Equalize': equalize, |
| 228 | + 'Invert': invert, |
| 229 | + 'Rotate': rotate, |
| 230 | + 'Posterize': posterize, |
| 231 | + 'Posterize2': posterize, |
| 232 | + 'Solarize': solarize, |
| 233 | + 'SolarizeAdd': solarize_add, |
| 234 | + 'Color': color, |
| 235 | + 'Contrast': contrast, |
| 236 | + 'Brightness': brightness, |
| 237 | + 'Sharpness': sharpness, |
| 238 | + 'ShearX': shear_x, |
| 239 | + 'ShearY': shear_y, |
| 240 | + 'TranslateX': translate_x_abs, |
| 241 | + 'TranslateY': translate_y_abs, |
| 242 | + 'TranslateXRel': translate_x_rel, |
| 243 | + 'TranslateYRel': translate_y_rel, |
| 244 | +} |
| 245 | + |
| 246 | + |
| 247 | +class AutoAugmentOp: |
| 248 | + |
| 249 | + def __init__(self, name, prob, magnitude, hparams={}): |
| 250 | + self.aug_fn = NAME_TO_OP[name] |
| 251 | + self.level_fn = level_to_arg(hparams)[name] |
| 252 | + self.prob = prob |
| 253 | + self.magnitude = magnitude |
| 254 | + # If std deviation of magnitude is > 0, we introduce some randomness |
| 255 | + # in the usually fixed policy and sample magnitude from normal dist |
| 256 | + # with mean magnitude and std-dev of magnitude_std. |
| 257 | + # NOTE This is being tested as it's not in paper or reference impl. |
| 258 | + self.magnitude_std = 0.5 # FIXME add arg/hparam |
| 259 | + self.kwargs = { |
| 260 | + 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, |
| 261 | + 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION |
| 262 | + } |
| 263 | + |
| 264 | + def __call__(self, img): |
| 265 | + if self.prob < random.random(): |
| 266 | + return img |
| 267 | + magnitude = self.magnitude |
| 268 | + if self.magnitude_std and self.magnitude_std > 0: |
| 269 | + magnitude = random.gauss(magnitude, self.magnitude_std) |
| 270 | + magnitude = min(_MAX_LEVEL, max(0, magnitude)) |
| 271 | + level_args = self.level_fn(magnitude) |
| 272 | + return self.aug_fn(img, *level_args, **self.kwargs) |
| 273 | + |
| 274 | + |
| 275 | +def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): |
| 276 | + # ImageNet policy from TPU EfficientNet impl, cannot find |
| 277 | + # a paper reference. |
| 278 | + policy = [ |
| 279 | + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], |
| 280 | + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], |
| 281 | + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], |
| 282 | + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], |
| 283 | + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], |
| 284 | + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], |
| 285 | + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], |
| 286 | + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], |
| 287 | + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], |
| 288 | + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], |
| 289 | + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], |
| 290 | + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], |
| 291 | + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], |
| 292 | + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], |
| 293 | + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], |
| 294 | + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], |
| 295 | + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], |
| 296 | + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], |
| 297 | + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], |
| 298 | + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], |
| 299 | + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], |
| 300 | + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], |
| 301 | + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], |
| 302 | + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], |
| 303 | + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], |
| 304 | + ] |
| 305 | + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] |
| 306 | + return pc |
| 307 | + |
| 308 | + |
| 309 | +def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): |
| 310 | + # ImageNet policy from https://arxiv.org/abs/1805.09501 |
| 311 | + policy = [ |
| 312 | + [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], |
| 313 | + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], |
| 314 | + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], |
| 315 | + [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], |
| 316 | + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], |
| 317 | + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], |
| 318 | + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], |
| 319 | + [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], |
| 320 | + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], |
| 321 | + [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], |
| 322 | + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], |
| 323 | + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], |
| 324 | + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], |
| 325 | + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], |
| 326 | + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], |
| 327 | + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], |
| 328 | + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], |
| 329 | + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], |
| 330 | + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], |
| 331 | + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], |
| 332 | + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], |
| 333 | + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], |
| 334 | + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], |
| 335 | + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], |
| 336 | + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], |
| 337 | + ] |
| 338 | + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] |
| 339 | + return pc |
| 340 | + |
| 341 | + |
| 342 | +def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT): |
| 343 | + if name == 'original': |
| 344 | + return auto_augment_policy_original(hparams) |
| 345 | + elif name == 'v0': |
| 346 | + return auto_augment_policy_v0(hparams) |
| 347 | + else: |
| 348 | + assert False, 'Unknown AA policy (%s)' % name |
| 349 | + |
| 350 | + |
| 351 | +class AutoAugment: |
| 352 | + |
| 353 | + def __init__(self, policy): |
| 354 | + self.policy = policy |
| 355 | + |
| 356 | + def __call__(self, img): |
| 357 | + sub_policy = random.choice(self.policy) |
| 358 | + for op in sub_policy: |
| 359 | + img = op(img) |
| 360 | + return img |
0 commit comments