Source code for nntoolbox.vision.utils.data

import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch import Tensor
from .utils import is_image
from PIL import Image
import os
from typing import Tuple, Any, Union, List
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


__all__ = ['UnlabelledImageDataset', 'UnsupervisedFromSupervisedDataset', 'PairedDataset', 'UnlabelledImageListDataset']


class UnlabelledImageDataset(Dataset):
    """
    Dataset that eagerly loads every image from the given folder(s) into memory.
    """
    def __init__(self, paths: Union[List[str], str], transform=None, img_dim=None):
        """
        :param paths: a (list of) path(s) to folder(s) of images
        :param transform: a transform (possibly a composed one) taking in a PIL image and returning a PIL image
        :param img_dim: optional (width, height) to resize every image to
        """
        print("Begin reading images and converting to RGB")
        super(UnlabelledImageDataset, self).__init__()
        if isinstance(paths, str):
            paths = [paths]
        self._images = []
        for path in paths:
            for filename in os.listdir(path):
                if is_image(filename):
                    # os.path.join handles folder paths with or without a trailing separator
                    full_path = os.path.join(path, filename)
                    image = Image.open(full_path).convert('RGB')
                    if img_dim is not None:
                        image = image.resize(img_dim)
                    self._images.append(image)
        self.transform = transform
        self._to_tensor = ToTensor()

    def __len__(self):
        return len(self._images)

    def __getitem__(self, i) -> Tensor:
        if self.transform is not None:
            return self._to_tensor(self.transform(self._images[i]))
        else:
            return self._to_tensor(self._images[i])
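

# A minimal usage sketch (not part of the original module): it writes a few dummy images to a
# temporary folder so the demo is self-contained, then loads them eagerly through
# UnlabelledImageDataset and batches them with a DataLoader. The folder, file names, and sizes
# are assumptions chosen for illustration; is_image is assumed to accept ".png" files.
def _demo_unlabelled_image_dataset():
    import tempfile
    from torch.utils.data import DataLoader

    with tempfile.TemporaryDirectory() as folder:
        for k in range(4):
            Image.new('RGB', (32, 32), color=(k * 60, 0, 0)).save(os.path.join(folder, "img_%d.png" % k))
        dataset = UnlabelledImageDataset(folder, img_dim=(16, 16))
        loader = DataLoader(dataset, batch_size=2)
        for batch in loader:
            print(batch.shape)  # torch.Size([2, 3, 16, 16])

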
class UnsupervisedFromSupervisedDataset(Dataset):
    """
    Convert a supervised dataset to an unsupervised dataset by discarding the labels.
    """
    def __init__(self, dataset: Dataset, transform=None):
        self._data = dataset
        self.transform = transform

    def __getitem__(self, index):
        # Keep only the input of the (input, label) pair returned by the wrapped dataset
        data = self._data[index][0]
        return self.transform(data) if self.transform is not None else data

    def __len__(self):
        return len(self._data)
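

# A minimal usage sketch (not part of the original module): a labelled TensorDataset stands in
# for any supervised dataset, and the wrapper drops the label and applies a transform to the
# input only. The shapes and the doubling transform are assumptions chosen for illustration.
def _demo_unsupervised_from_supervised():
    from torch.utils.data import TensorDataset

    inputs = torch.randn(10, 3, 8, 8)
    labels = torch.randint(0, 2, (10,))
    supervised = TensorDataset(inputs, labels)
    unsupervised = UnsupervisedFromSupervisedDataset(supervised, transform=lambda x: x * 2)
    x = unsupervised[0]     # a single transformed input, with no label attached
    print(x.shape)          # torch.Size([3, 8, 8])

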
class PairedDataset(Dataset):
    """
    Pair up two datasets, and allow users to sample a pair, one from each dataset
    """
    def __init__(self, data_1: Dataset, data_2: Dataset):
        super(PairedDataset, self).__init__()
        self.data_1 = data_1
        self.data_2 = data_2

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        assert index < self.__len__()
        i = index % len(self.data_1)
        j = index // len(self.data_1)
        x1 = self.data_1[i]
        x2 = self.data_2[j]
        return x1, x2

    def __len__(self) -> int:
        return len(self.data_1) * len(self.data_2)
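

# A minimal usage sketch (not part of the original module): PairedDataset enumerates the cross
# product of its two datasets, so index k maps to (data_1[k % len(data_1)], data_2[k // len(data_1)]).
# The two tiny TensorDatasets below are assumptions chosen for illustration.
def _demo_paired_dataset():
    from torch.utils.data import TensorDataset

    data_1 = TensorDataset(torch.arange(3))
    data_2 = TensorDataset(torch.arange(2))
    pairs = PairedDataset(data_1, data_2)
    print(len(pairs))   # 6, i.e. 3 * 2 possible pairs
    print(pairs[4])     # ((tensor(1),), (tensor(1),)) since 4 % 3 == 1 and 4 // 3 == 1

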
class UnlabelledImageListDataset(Dataset):
    """
    Abstraction for a list of paths to images without labels; images are loaded lazily.
    """
    def __init__(self, paths: Union[List[str], str], transform=None, img_dim=None):
        """
        :param paths: a (list of) path(s) to folder(s) of images
        :param transform: a transform (possibly a composed one) taking in a PIL image and returning a PIL image
        :param img_dim: optional (width, height) to resize every image to
        """
        print("Begin collecting image paths")
        super(UnlabelledImageListDataset, self).__init__()
        if isinstance(paths, str):
            paths = [paths]
        self._image_paths = []
        for path in paths:
            for filename in os.listdir(path):
                if is_image(filename):
                    full_path = os.path.join(path, filename)
                    self._image_paths.append(full_path)
        self.transform = transform
        self.img_dim = img_dim
        self._to_tensor = ToTensor()

    def __len__(self):
        return len(self._image_paths)

    def __getitem__(self, index):
        # Open the image lazily, on demand
        image = Image.open(self._image_paths[index])
        if self.img_dim is not None:
            image = image.resize(self.img_dim)
        image = image.convert('RGB')
        if self.transform is not None:
            return self._to_tensor(self.transform(image))
        else:
            return self._to_tensor(image)
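

# A minimal usage sketch (not part of the original module): the list dataset only stores file
# paths and opens each image lazily in __getitem__, which keeps memory usage low for large
# collections. The temporary folder and file names are assumptions chosen for illustration;
# is_image is assumed to accept ".png" files.
def _demo_unlabelled_image_list_dataset():
    import tempfile

    with tempfile.TemporaryDirectory() as folder:
        for k in range(4):
            Image.new('RGB', (32, 32), color=(0, k * 60, 0)).save(os.path.join(folder, "img_%d.png" % k))
        dataset = UnlabelledImageListDataset(folder, img_dim=(16, 16))
        print(len(dataset))        # 4
        print(dataset[0].shape)    # torch.Size([3, 16, 16])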