# Source code for nntoolbox.models.classifier
from torch.nn import Module
from torch.utils.data import DataLoader
from ..utils import get_device
from ..metrics import Metric, Accuracy
from .ensemble import Ensemble
from sklearn.metrics import accuracy_score
import torch
from torch import Tensor, nn
import numpy as np
from numpy import ndarray
from typing import List, Optional
__all__ = ['Classifier']
class Classifier:
    """
    Abstraction for a classifier: wraps a PyTorch module and exposes
    prediction and batched evaluation helpers.

    The wrapped model is moved to ``device`` and switched to evaluation mode
    once, at construction time.
    """

    def __init__(self, model: Module, device=None, metric: Optional["Metric"] = None):
        """
        :param model: network producing class scores; softmax/argmax are taken
            over dim 1, so presumably (batch, num_classes) logits — confirm
        :param device: device to run the model on; defaults to ``get_device()``,
            resolved when this classifier is constructed
        :param metric: metric used by ``evaluate``; defaults to a fresh
            ``Accuracy()`` for each classifier
        """
        # Defaults are resolved here rather than in the signature: a default
        # expression in the signature runs once at import time, and a single
        # Accuracy() instance would be shared by every Classifier.
        if device is None:
            device = get_device()
        if metric is None:
            metric = Accuracy()
        self._model = model.to(device)
        self._model.eval()
        self._device = device
        self._softmax = nn.Softmax(dim=1)
        self.metric = metric

    def predict(self, inputs: Tensor, return_probs: bool = False) -> ndarray:
        """
        Predict the classes or class probabilities of a batch of inputs.

        :param inputs: batch of inputs to be predicted (moved to the model's device)
        :param return_probs: if True, return softmax probabilities; otherwise
            return the index of the most probable class for each input
        :return: numpy array of probabilities (softmax over dim 1) or of
            predicted class indices (argmax over dim 1)
        """
        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            probs = self._softmax(self._model(inputs.to(self._device)))
            result = probs if return_probs else torch.argmax(probs, dim=1)
        return result.cpu().numpy()

    def evaluate(self, test_loader: DataLoader, requires_prob: bool = False) -> float:
        """
        Evaluate ``self.metric`` over a data loader.

        Each batch is scored independently and the scores are averaged with
        weights equal to the batch sizes. This equals the dataset-level value
        only for metrics that are per-sample means (e.g. accuracy).

        :param test_loader: iterable yielding ``(inputs, labels)`` batches
        :param requires_prob: if True, the metric receives class probabilities
            as ``outputs``; otherwise it receives predicted class indices
        :return: sample-weighted mean of the per-batch metric values
        :raises ValueError: if the loader yields no samples
        """
        total = 0
        weighted_sum = 0
        for inputs, labels in test_loader:
            outputs = self.predict(inputs, return_probs=requires_prob)
            logs = {"outputs": outputs, "labels": labels.cpu().numpy()}
            batch_size = len(inputs)
            weighted_sum += self.metric(logs) * batch_size
            total += batch_size
        if total == 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError("test_loader yielded no samples; cannot compute a metric")
        return weighted_sum / total