nntoolbox.vision.models.classifier module

class nntoolbox.vision.models.classifier.EnsembleImageClassifier(models: List[nntoolbox.vision.models.classifier.ImageClassifier], model_weights: Optional[List[float]] = None)

Bases: nntoolbox.models.ensemble.Ensemble, nntoolbox.vision.models.classifier.ImageClassifier

predict(images: torch.Tensor, return_probs: bool = False, tries: int = 5) → numpy.ndarray

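Predict the classes or class probabilities of a batch of images.

Parameters
  • images – batch of images to be predicted

  • return_probs – if True, return class probabilities; otherwise return predicted class labels

  • tries – number of test-time augmentation passes

Returns
  numpy.ndarray of predicted class labels, or of class probabilities if return_probs is True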
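The reference does not say how member predictions are combined. One plausible reading of model_weights is a weighted average of each member's class probabilities, and the sketch below assumes exactly that. It works on precomputed NumPy probability arrays rather than calling member models, and the helper names ensemble_probs and ensemble_predict are hypothetical, not part of nntoolbox.

```python
import numpy as np
from typing import List, Optional


def ensemble_probs(member_probs: List[np.ndarray],
                   model_weights: Optional[List[float]] = None) -> np.ndarray:
    """Weighted average of per-model (batch, n_classes) probability arrays (assumed rule)."""
    stacked = np.stack(member_probs)  # shape (n_models, batch, n_classes)
    if model_weights is None:
        weights = np.full(len(member_probs), 1.0 / len(member_probs))  # equal weights
    else:
        weights = np.asarray(model_weights, dtype=float)
        weights = weights / weights.sum()  # normalise so each row still sums to 1
    # Contract the model axis: sum over m of weights[m] * stacked[m]
    return np.tensordot(weights, stacked, axes=1)


def ensemble_predict(member_probs: List[np.ndarray],
                     model_weights: Optional[List[float]] = None,
                     return_probs: bool = False) -> np.ndarray:
    """Return averaged probabilities, or the argmax class per image."""
    probs = ensemble_probs(member_probs, model_weights)
    return probs if return_probs else probs.argmax(axis=1)


p1 = np.array([[0.9, 0.1], [0.2, 0.8]])
p2 = np.array([[0.4, 0.6], [0.3, 0.7]])
print(ensemble_predict([p1, p2], model_weights=[3.0, 1.0]))  # [0 1]
```

Normalising the weights keeps the result a valid probability distribution even when the weights do not sum to 1.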

class nntoolbox.vision.models.classifier.ImageClassifier(model: torch.nn.modules.module.Module, tta_transform=None, tta_beta: float = 0.4, device=device(type='cpu'))

Bases: object

Abstraction for an image classifier. Supports user-defined test-time augmentation (TTA).

evaluate(test_loader: torch.utils.data.dataloader.DataLoader, tries: int = 5) → float
predict(images: torch.Tensor, return_probs: bool = False, tries: int = 5) → numpy.ndarray

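Predict the classes or class probabilities of a batch of images.

Parameters
  • images – batch of images to be predicted

  • return_probs – if True, return class probabilities; otherwise return predicted class labels

  • tries – number of test-time augmentation passes

Returns
  numpy.ndarray of predicted class labels, or of class probabilities if return_probs is True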
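As a rough illustration of test-time augmentation, the sketch below assumes that predict averages the model's outputs over `tries` augmented copies of the batch and blends that average with the prediction on the unaugmented batch, using tta_beta as the weight of the augmented part. That blending rule is an assumption, not taken from nntoolbox. The function tta_predict, the toy model and the flip transform are hypothetical, and NumPy arrays stand in for torch tensors.

```python
import numpy as np
from typing import Callable, Optional


def tta_predict(model: Callable[[np.ndarray], np.ndarray],
                images: np.ndarray,
                tta_transform: Optional[Callable[[np.ndarray], np.ndarray]] = None,
                tta_beta: float = 0.4,
                tries: int = 5,
                return_probs: bool = False) -> np.ndarray:
    """Hypothetical TTA predict; model returns (batch, n_classes) probabilities."""
    probs = model(images)
    if tta_transform is not None and tries > 0:
        # Average the predictions over `tries` augmented copies of the batch
        augmented = np.mean([model(tta_transform(images)) for _ in range(tries)], axis=0)
        # Assumed blending rule: tta_beta is the weight given to the augmented average
        probs = (1.0 - tta_beta) * probs + tta_beta * augmented
    return probs if return_probs else probs.argmax(axis=1)


def toy_model(x: np.ndarray) -> np.ndarray:
    """Two-class toy model: softmax over left-half vs right-half brightness."""
    half = x.shape[1] // 2
    scores = np.stack([x[:, :half].sum(axis=1), x[:, half:].sum(axis=1)], axis=1)
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def flip(x: np.ndarray) -> np.ndarray:
    """Deterministic horizontal flip of (batch, width) 'images'."""
    return x[:, ::-1]


images = np.array([[1.0, 0.0, 0.0, 0.0]])
print(tta_predict(toy_model, images, flip, tta_beta=0.4))  # [0]
print(tta_predict(toy_model, images, flip, tta_beta=0.6))  # [1]
```

The flip here is deterministic, so every try yields the same output; a real tta_transform would usually be random, which is what makes averaging over several tries worthwhile.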

class nntoolbox.vision.models.classifier.KNNClassifier(database: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, n_neighbors: int = 5, tta_transform=None, tta_beta: float = 0.4, weights: Union[str, Callable] = 'distance', device=device(type='cpu'), threshold=0.0)

Bases: object

evaluate(test_loader: torch.utils.data.dataloader.DataLoader, metrics: Dict[str, nntoolbox.metrics.metrics.Metric], top: int = 5, tries: int = 5) → Dict[str, float]
predict(images: torch.Tensor, top: int = 5, tries: int = 5) → Union[numpy.ndarray, Tuple[numpy.ndarray]]

Predict the classes or class probabilities of a batch of images.

Parameters
  • images – batch of images to be predicted

  • tries – number of test-time augmentation passes

Returns
  numpy.ndarray, or a tuple of numpy.ndarray
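KNNClassifier has no class docstring. Its constructor (an embedding model, a database loader, n_neighbors and weights='distance') suggests it classifies images by voting among their nearest neighbours in embedding space. The sketch below is a generic distance-weighted k-nearest-neighbour vote on precomputed embeddings that returns the `top` best-ranked classes. The 'uniform' option borrows scikit-learn's naming convention and is an assumption, the threshold argument is not modelled, and knn_predict is a hypothetical name, not the library's implementation.

```python
import numpy as np
from typing import Callable, Union


def knn_predict(query_emb: np.ndarray,
                db_emb: np.ndarray,
                db_labels: np.ndarray,
                n_neighbors: int = 5,
                weights: Union[str, Callable[[np.ndarray], np.ndarray]] = "distance",
                top: int = 5) -> np.ndarray:
    """Rank classes for each query embedding by a weighted nearest-neighbour vote."""
    # Euclidean distances between every query and every database embedding: (n_query, n_db)
    dists = np.linalg.norm(query_emb[:, None, :] - db_emb[None, :, :], axis=-1)
    idx = np.argsort(dists, axis=1)[:, :n_neighbors]  # indices of the k nearest
    nn_dists = np.take_along_axis(dists, idx, axis=1)
    if callable(weights):
        w = weights(nn_dists)  # user-supplied weighting of neighbour distances
    elif weights == "distance":
        w = 1.0 / (nn_dists + 1e-12)  # closer neighbours vote more
    elif weights == "uniform":
        w = np.ones_like(nn_dists)  # every neighbour votes equally (assumed option)
    else:
        raise ValueError(f"unknown weights: {weights!r}")
    n_classes = int(db_labels.max()) + 1
    # Sum the neighbour weights per class label for each query
    votes = np.stack([np.bincount(db_labels[idx[q]], weights=w[q], minlength=n_classes)
                      for q in range(len(query_emb))])
    # Class indices sorted by descending vote, truncated to the `top` best
    return np.argsort(-votes, axis=1)[:, :top]


db_emb = np.array([[0.0], [0.1], [1.0], [1.1], [5.0]])
db_labels = np.array([0, 0, 1, 1, 2])
queries = np.array([[0.05], [1.05]])
print(knn_predict(queries, db_emb, db_labels, n_neighbors=3, top=2))
```

Here the first query sits between the two class-0 points and the second between the two class-1 points, so the rankings come out as [0 1] and [1 0] respectively.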