Source code for nntoolbox.vision.utils.dataset
import pandas as pd
from torch.utils.data import Dataset
from typing import Tuple
from torch import Tensor
from torchvision.transforms import ToTensor
from PIL import Image
from warnings import warn
from ...utils import download_from_url
import os
# Names exported when this module is star-imported.
__all__ = [
    "FaceScrub",
    "download_facescrub",
]
def download_facescrub(root: str, data_path: str, max_size: int = 128, max_images: int = 2):
    """
    Download FaceScrub images into one sub-folder per person.

    Reads the tab-separated index files ``facescrub_actors.txt`` and
    ``facescrub_actresses.txt`` from ``root`` (both must provide ``name`` and
    ``url`` columns), then hands each URL to ``download_from_url``. Rows whose
    download raises are skipped with a warning. A summary line is printed at
    the end.

    :param root: directory containing the two index files.
    :param data_path: directory under which ``<name>/face_<k><ext>`` files are
        written; per-person folders are created on demand.
    :param max_size: forwarded unchanged as the third positional argument of
        ``download_from_url`` (presumably a size limit for the saved image --
        confirm against that helper).
    :param max_images: stop after this many successful downloads. The default
        of 2 reproduces the limit that used to be hard-coded.
    """
    df_male = pd.read_csv(os.path.join(root, "facescrub_actors.txt"), sep='\t')
    df_female = pd.read_csv(os.path.join(root, "facescrub_actresses.txt"), sep='\t')
    # ignore_index gives the combined frame a unique 0..n-1 index; without it
    # both halves keep their own labels and label-based lookups return
    # two-element Series instead of scalars.
    df_both = pd.concat([df_male, df_female], ignore_index=True)

    n_image = 0
    n_ppl = 0
    # Iterate the two columns in lockstep so no index lookup is involved.
    for name, url in zip(df_both['name'], df_both['url']):
        if n_image >= max_images:
            break
        try:
            folder = os.path.join(data_path, name)
            if not os.path.exists(folder):
                os.makedirs(folder)
                # A newly created folder means a person not seen before.
                n_ppl += 1
            # Keep only the URL's file extension (e.g. ".jpg") in the name.
            extension = os.path.splitext(url)[1]
            path = os.path.join(folder, "face_" + str(n_image) + extension)
            download_from_url(url, path, max_size)
        except Exception:
            # Best effort: a bad row (malformed value, dead link, unreadable
            # image) must not abort the whole run. KeyboardInterrupt and
            # SystemExit are deliberately not caught.
            warn("Image corrupted or URL error. Skip to next image.")
        else:
            n_image += 1

    print("Finish downloading " + str(n_image) + " images of " + str(n_ppl) + " people.")
class FaceScrub(Dataset):
    """
    Image dataset built from the FaceScrub index files.

    On construction the tab-separated files ``facescrub_actors.txt`` and
    ``facescrub_actresses.txt`` are read from ``root`` (both must provide
    ``name`` and ``url`` columns) and each URL is handed to
    ``download_from_url`` with a target path directly inside ``data_path``.
    Rows whose download raises are skipped with a warning.

    Attributes populated by the constructor:

    - ``images_paths``: path of every successfully downloaded image.
    - ``labels``: for each image, a one-element list holding the integer
      class index of the person.
    - ``name2idx`` / ``idx2name``: mapping between person names and class
      indices, assigned in order of first appearance.

    :param root: directory containing the two index files.
    :param data_path: directory the images are written to. It is not created
        here.
    :param transform: callable applied to the RGB ``PIL`` image in
        ``__getitem__``; defaults to ``ToTensor()``.
    :param max_images: upper bound on the total number of images collected
        across both index files (default 100).
    """

    def __init__(self, root: str, data_path: str, transform=None, max_images: int = 100):
        self.images_paths = []
        self.labels = []
        self.name2idx = dict()
        self.idx2name = dict()
        self.transform = ToTensor() if transform is None else transform

        # Both index files share one code path, and the cap applies to their
        # combined total.
        for index_file in ("facescrub_actors.txt", "facescrub_actresses.txt"):
            df = pd.read_csv(os.path.join(root, index_file), sep='\t')
            self._download_from_index(df, data_path, max_images)

    def _download_from_index(self, df, data_path: str, max_images: int) -> None:
        """Download the rows of one index frame until ``max_images`` is reached."""
        for name, url in zip(df['name'], df['url']):
            # Checked before downloading so the cap is never exceeded.
            if len(self.images_paths) >= max_images:
                break
            try:
                # Keep only the URL's file extension (e.g. ".jpg") in the name.
                extension = os.path.splitext(url)[1]
                path = os.path.join(
                    data_path, "face_" + str(len(self.images_paths)) + extension
                )
                download_from_url(url, path)
            except Exception:
                # Best effort: skip bad rows, but let KeyboardInterrupt and
                # SystemExit propagate.
                warn("Image corrupted or URL error. Skip to next image.")
            else:
                self.images_paths.append(path)
                self.labels.append([self._label_for(name)])

    def _label_for(self, name) -> int:
        """Return the class index for ``name``, registering it on first use."""
        # Membership must be tested on name2idx (keyed by name); idx2name is
        # keyed by integer index and would never contain a name.
        if name not in self.name2idx:
            idx = len(self.name2idx)
            self.name2idx[name] = idx
            self.idx2name[idx] = name
        return self.name2idx[name]

    def __len__(self) -> int:
        """Return the number of downloaded images."""
        return len(self.images_paths)

    def __getitem__(self, i: int) -> Tuple[Tensor, list]:
        """
        Load the i-th image and its label.

        :param i: position in ``images_paths``.
        :return: ``(transform(image), label)`` where ``image`` is the file
            converted to RGB and ``label`` is the one-element list stored in
            ``labels[i]``.
        """
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy that stays valid afterwards.
        with Image.open(self.images_paths[i]) as raw:
            image = raw.convert('RGB')
        return self.transform(image), self.labels[i]