Source code for nntoolbox.utils.data

from torch.utils.data import Dataset
from .utils import get_device
import torch
from numpy import ndarray
from torch import float32, long, Tensor
import pandas as pd
from typing import Optional, List, Iterable, Union
from torch.utils.data import DataLoader
from torchtext.data import Iterator


__all__ = ['SupervisedDataset', 'get_first_batch', 'grab_next_batch']


class SupervisedDataset(Dataset):
    """Dataset wrapping paired numpy arrays of inputs and labels, converting them to tensors on access."""
    def __init__(self, inputs: ndarray, labels: ndarray, device=get_device(), transform=None):
        assert len(inputs) == len(labels)
        self._device = device
        self.inputs = torch.from_numpy(inputs)
        self.labels = torch.from_numpy(labels)
        self.transform = transform

    @classmethod
    def from_csv(cls, path: str, label_name: str, data_fields: Optional[List[str]]=None, device=get_device()):
        """Create a supervised dataset from a csv file"""
        assert path.endswith(".csv")
        df = pd.read_csv(path)
        labels = df[label_name].values
        if data_fields is None:
            inputs = df.drop(label_name, axis=1).values
        else:
            inputs = df[data_fields].values
        return cls(inputs, labels, device)

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, index: int):
        input, label = self.prepare_arr(self.inputs[index], float32), self.prepare_arr(self.labels[index], long)
        if self.transform is not None:
            input = self.transform(input)
        return input, label

    def prepare_arr(self, tensor: Tensor, dtype):
        """Cast a tensor to the requested dtype and move it to the dataset's device."""
        return tensor.to(dtype).to(self._device)
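SupervisedDataset stores the numpy arrays as tensors and defers dtype and device conversion to __getitem__. A minimal usage sketch (not part of the module), assuming toy random data and a hypothetical "train.csv" file with a "target" column:

# Illustrative only: build a SupervisedDataset from in-memory numpy arrays
# and iterate over it with a standard DataLoader.
import numpy as np
from torch.utils.data import DataLoader

inputs = np.random.rand(100, 8)             # 100 samples, 8 features (toy data)
labels = np.random.randint(0, 3, size=100)  # 3 hypothetical classes

dataset = SupervisedDataset(inputs, labels)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

for batch_inputs, batch_labels in loader:
    # batch_inputs arrives as float32 and batch_labels as long,
    # both already moved to the device returned by get_device()
    pass

# Alternatively, from the hypothetical csv file with a "target" column:
# dataset = SupervisedDataset.from_csv("train.csv", label_name="target")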
def get_first_batch(data: DataLoader, callbacks: Optional[Iterable['Callback']]=None):
    """
    Get the first batch from a dataloader, optionally passing it through callbacks

    :param data: the dataloader
    :param callbacks: the list of callbacks to apply to the data
    """
    first_batch = next(iter(data))
    if callbacks is None or len(callbacks) == 0:
        return first_batch
    # tuple and list batches are unpacked the same way: (inputs, labels)
    data = {"inputs": first_batch[0], "labels": first_batch[1]}
    for callback in callbacks:
        data = callback.on_batch_begin(data, True)
    return data["inputs"], data["labels"]
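A usage sketch for get_first_batch (illustrative, not part of the module). The stand-in callback below only mimics the on_batch_begin(data, train) hook that get_first_batch calls; nntoolbox's real Callback class may have a richer interface.

# Hypothetical callback: standardize the peeked inputs before returning them.
class NormalizeInputs:
    def on_batch_begin(self, data, train):
        data["inputs"] = (data["inputs"] - data["inputs"].mean()) / (data["inputs"].std() + 1e-8)
        return data

# Uses the loader from the sketch above
peek_inputs, peek_labels = get_first_batch(loader, callbacks=[NormalizeInputs()])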
def grab_next_batch(data: Union[DataLoader, Iterator]):
    """Grab the next batch from a dataloader or torchtext iterator"""
    return next(iter(data))
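Because grab_next_batch builds a fresh iterator on every call, it effectively peeks at the leading batch (the same one each time if the loader does not shuffle). A short sketch, reusing the loader assumed above:

# Quick shape check before training
xb, yb = grab_next_batch(loader)
print(xb.shape, yb.shape)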