# Source code for nntoolbox.vision.transforms.batch
"""Transform input by batch"""
from torch import Tensor, flip, rot90
from functools import partial
from typing import List, Callable
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'BatchCompose',
    'Identity',
    'BatchLambdaTransform',
    'BatchHorizontalFlip',
    'BatchVerticalFlip',
    'BatchRotation90',
    'BatchRotation180',
    'BatchRotation270',
]
# Flip helpers: ``torch.flip`` with the axis pre-bound. ``hflip`` reverses the
# last dimension, ``vflip`` the second-to-last one.
hflip, vflip = (partial(flip, dims=(axis,)) for axis in (-1, -2))

# Rotation helpers: ``rotimg`` is ``torch.rot90`` restricted to the plane of
# the last two dimensions; the three variants fix the number of quarter turns.
rotimg = partial(rot90, dims=(-2, -1))
rot90deg, rot180deg, rot270deg = (
    partial(rotimg, k=quarter_turns) for quarter_turns in (1, 2, 3)
)
class BatchCompose:
    """Chain several transforms into a single callable.

    The transforms are applied in list order; each one receives the value
    returned by the previous one.

    :param transforms: callables applied one after another. The list is stored
        by reference (not copied), so later mutations of it by the caller are
        visible to this object.
    """

    def __init__(self, transforms: List[Callable[[Tensor], Tensor]]):
        self.transforms = transforms

    def __call__(self, input: Tensor) -> Tensor:
        """Run ``input`` through every transform and return the final result.

        With an empty transform list the input is returned unchanged.
        """
        # NOTE: ``input`` shadows the builtin, but it is part of the public
        # (keyword-callable) signature and is therefore kept as is.
        for transform in self.transforms:
            input = transform(input)
        return input
class Identity:
    """No-op transform: calling it returns its argument unchanged."""

    def __call__(self, input: Tensor) -> Tensor:
        """Return ``input`` itself (the same object, not a copy)."""
        return input
class BatchHorizontalFlip(BatchLambdaTransform):
    """Batch transform built from the module-level ``hflip`` helper.

    ``hflip`` is ``torch.flip`` with ``dims=(-1,)``, i.e. it reverses the last
    dimension of its input. For image batches laid out as (..., H, W) that is
    a left-right flip -- the layout is assumed here, not checked.

    How the function is applied on call is defined by ``BatchLambdaTransform``;
    this class only supplies the function to its constructor.
    """

    def __init__(self):
        super().__init__(hflip)
class BatchVerticalFlip(BatchLambdaTransform):
    """Batch transform built from the module-level ``vflip`` helper.

    ``vflip`` is ``torch.flip`` with ``dims=(-2,)``, i.e. it reverses the
    second-to-last dimension of its input. For image batches laid out as
    (..., H, W) that is an up-down flip -- the layout is assumed here, not
    checked.

    How the function is applied on call is defined by ``BatchLambdaTransform``;
    this class only supplies the function to its constructor.
    """

    def __init__(self):
        super().__init__(vflip)
class BatchRotation90(BatchLambdaTransform):
    """Batch transform built from the module-level ``rot90deg`` helper.

    ``rot90deg`` is ``torch.rot90`` with ``k=1`` and ``dims=(-2, -1)``: one
    quarter turn in the plane of the last two dimensions.

    How the function is applied on call is defined by ``BatchLambdaTransform``;
    this class only supplies the function to its constructor.
    """

    def __init__(self):
        super().__init__(rot90deg)
class BatchRotation180(BatchLambdaTransform):
    """Batch transform built from the module-level ``rot180deg`` helper.

    ``rot180deg`` is ``torch.rot90`` with ``k=2`` and ``dims=(-2, -1)``: two
    quarter turns in the plane of the last two dimensions.

    How the function is applied on call is defined by ``BatchLambdaTransform``;
    this class only supplies the function to its constructor.
    """

    def __init__(self):
        super().__init__(rot180deg)
class BatchRotation270(BatchLambdaTransform):
    """Batch transform built from the module-level ``rot270deg`` helper.

    ``rot270deg`` is ``torch.rot90`` with ``k=3`` and ``dims=(-2, -1)``: three
    quarter turns in the plane of the last two dimensions.

    How the function is applied on call is defined by ``BatchLambdaTransform``;
    this class only supplies the function to its constructor.
    """

    def __init__(self):
        super().__init__(rot270deg)