nntoolbox.models.classifier module

class nntoolbox.models.classifier.Classifier(model: torch.nn.modules.module.Module, device=device(type='cpu'), metric: nntoolbox.metrics.metrics.Metric = <nntoolbox.metrics.classification.Accuracy object>)

Bases: object

Abstraction for a classifier.

evaluate(test_loader: torch.utils.data.dataloader.DataLoader, requires_prob: bool = False) → float
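evaluate has no docstring, so its exact behaviour is not documented here. Because the default metric is Accuracy and the method returns a float, the following is a minimal sketch of what an accuracy-style evaluation over a DataLoader could look like. evaluate_sketch is a hypothetical helper, not nntoolbox's implementation. It assumes each batch is an (inputs, labels) pair and that the model returns logits. It also ignores requires_prob, whose role is not described in this page.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_sketch(model: nn.Module, test_loader: DataLoader,
                    device: torch.device = torch.device("cpu")) -> float:
    """Hypothetical stand-in for Classifier.evaluate: accuracy over a DataLoader.

    Assumes each batch is an (inputs, labels) pair and that `model` returns logits.
    """
    model.eval()  # switch off dropout / use running batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():  # evaluation only, no autograd graph
        for inputs, labels in test_loader:
            preds = model(inputs.to(device)).argmax(dim=1)  # predicted class per sample
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total


# Usage: a fixed model that predicts class 1 exactly when the feature is positive.
model = nn.Linear(1, 2, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[-1.0], [1.0]]))
x = torch.tensor([[-2.0], [-1.0], [1.0], [2.0]])
y = torch.tensor([0, 0, 1, 0])  # last label deliberately mismatched
loader = DataLoader(TensorDataset(x, y), batch_size=2)
print(evaluate_sketch(model, loader))  # 0.75
```

Three of the four predictions match their labels, so the sketch returns 0.75.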
predict(inputs: torch.Tensor, return_probs: bool = False) → numpy.ndarray

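Predict the classes or class probabilities of a batch of inputs.

Parameters:
    inputs: inputs to be predicted
    return_probs: whether to return class probabilities instead of predicted classes

Returns:
    the predicted classes, or the class probabilities if return_probs is True, as a numpy.ndarray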
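To illustrate the classes-versus-probabilities switch, here is a minimal sketch of a predict-like function. predict_sketch is a hypothetical helper, not the library's code. It assumes the model returns raw logits of shape (batch, n_classes) and that softmax converts them to probabilities.

```python
import numpy as np
import torch
from torch import nn


def predict_sketch(model: nn.Module, inputs: torch.Tensor,
                   return_probs: bool = False,
                   device: torch.device = torch.device("cpu")) -> np.ndarray:
    """Hypothetical stand-in for Classifier.predict; assumes `model` returns logits."""
    model.eval()  # switch off dropout / use running batch-norm statistics
    with torch.no_grad():  # inference only, no autograd graph
        logits = model(inputs.to(device))
        probs = torch.softmax(logits, dim=1)  # shape (batch, n_classes)
    if return_probs:
        return probs.cpu().numpy()
    return probs.argmax(dim=1).cpu().numpy()  # class indices, shape (batch,)


# Usage: a 4-feature, 3-class linear model on a batch of 5 inputs.
torch.manual_seed(0)
net = nn.Linear(4, 3)
x = torch.randn(5, 4)
classes = predict_sketch(net, x)
probs = predict_sketch(net, x, return_probs=True)
print(classes.shape, probs.shape)  # (5,) (5, 3)
```

Softmax is monotonic, so taking the argmax of the probabilities selects the same class as taking the argmax of the logits. The probabilities are only needed when return_probs is True.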