nntoolbox.vision.models.classifier module
-
class nntoolbox.vision.models.classifier.EnsembleImageClassifier(models: List[nntoolbox.vision.models.classifier.ImageClassifier], model_weights: Optional[List[float]] = None)
Bases: nntoolbox.models.ensemble.Ensemble, nntoolbox.vision.models.classifier.ImageClassifier
-
predict(images: torch.Tensor, return_probs: bool = False, tries: int = 5) → numpy.ndarray
Predict the classes or class probabilities of a batch of images.
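- Parameters
images – batch of images to classify
return_probs – if True, return class probabilities; otherwise, return predicted classes
tries – number of test-time augmentation tries
- Returns
predicted classes, or class probabilities if return_probs is True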
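The reference does not say how EnsembleImageClassifier combines its member classifiers. A minimal NumPy sketch of one common scheme, assuming each member yields an (n_images, n_classes) array of class probabilities and that model_weights are per-model weights defaulting to uniform (both are assumptions, not confirmed by these docs). `combine_ensemble` is a hypothetical helper, not part of nntoolbox.

```python
import numpy as np
from typing import List, Optional


def combine_ensemble(
    member_probs: List[np.ndarray],
    model_weights: Optional[List[float]] = None,
    return_probs: bool = False,
) -> np.ndarray:
    """Hypothetical helper: weighted average of per-model class probabilities.

    Each array in member_probs is assumed to have shape (n_images, n_classes).
    """
    if model_weights is None:
        # Assumption: without weights, every model counts equally.
        model_weights = [1.0] * len(member_probs)
    weights = np.asarray(model_weights, dtype=float)
    weights = weights / weights.sum()  # normalize so each row still sums to 1
    stacked = np.stack(member_probs, axis=0)  # (n_models, n_images, n_classes)
    probs = np.tensordot(weights, stacked, axes=1)  # (n_images, n_classes)
    # Predicted class = index of the highest averaged probability.
    return probs if return_probs else probs.argmax(axis=1)


p1 = np.array([[0.9, 0.1], [0.2, 0.8]])
p2 = np.array([[0.6, 0.4], [0.7, 0.3]])
print(combine_ensemble([p1, p2]))                            # [0 1]
print(combine_ensemble([p1, p2], model_weights=[1.0, 3.0]))  # [0 0]
```

Giving the second model three times the weight flips the second image's prediction, which is the kind of effect per-model weights are meant to allow.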
-
class nntoolbox.vision.models.classifier.ImageClassifier(model: torch.nn.modules.module.Module, tta_transform=None, tta_beta: float = 0.4, device=device(type='cpu'))
Bases: object
Abstraction for an image classifier. Supports user-defined test-time augmentation (TTA).
-
predict(images: torch.Tensor, return_probs: bool = False, tries: int = 5) → numpy.ndarray
Predict the classes or class probabilities of a batch of images.
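- Parameters
images – batch of images to classify
return_probs – if True, return class probabilities; otherwise, return predicted classes
tries – number of test-time augmentation tries
- Returns
predicted classes, or class probabilities if return_probs is True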
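The docs do not spell out how tries and tta_beta interact. One plausible scheme, sketched here purely as an assumption: predict on the original batch and on `tries` augmented copies, then blend the original prediction with the mean augmented prediction, using tta_beta as the weight on the augmented average. `predict_with_tta`, `model_fn`, `flip`, and `W` are hypothetical NumPy stand-ins for the torch Module and tta_transform; nntoolbox's actual weighting may differ.

```python
import numpy as np
from typing import Callable, Optional


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract each row's max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_with_tta(
    model_fn: Callable[[np.ndarray], np.ndarray],
    images: np.ndarray,
    augment: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    tta_beta: float = 0.4,
    tries: int = 5,
    return_probs: bool = False,
) -> np.ndarray:
    """Hypothetical TTA scheme; model_fn maps a batch to (n_images, n_classes) logits."""
    probs = softmax(model_fn(images))
    if augment is not None and tries > 0:
        # Average the predictions over `tries` augmented copies of the batch.
        aug_probs = np.mean([softmax(model_fn(augment(images))) for _ in range(tries)], axis=0)
        # Assumption: tta_beta is the weight given to the augmented average.
        probs = (1.0 - tta_beta) * probs + tta_beta * aug_probs
    # Predicted class = index of the highest (blended) probability.
    return probs if return_probs else probs.argmax(axis=1)


# Toy example: a linear "model" and a feature-reversing "augmentation".
W = np.array([[2.0, 0.0], [0.0, 1.0]])


def model_fn(x: np.ndarray) -> np.ndarray:
    return x @ W


def flip(x: np.ndarray) -> np.ndarray:
    return x[:, ::-1]


images = np.array([[1.0, 0.0], [0.0, 1.0]])
print(predict_with_tta(model_fn, images))                # [0 1]
print(predict_with_tta(model_fn, images, augment=flip))  # [0 0]
```

With augmentation, the second image's prediction flips because its augmented view strongly favors class 0; under this scheme, how far TTA can override the plain prediction depends entirely on tta_beta.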
-
class nntoolbox.vision.models.classifier.KNNClassifier(database: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, n_neighbors: int = 5, tta_transform=None, tta_beta: float = 0.4, weights: Union[str, Callable] = 'distance', device=device(type='cpu'), threshold=0.0)
Bases: object
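KNNClassifier takes a database loader, an embedding model, n_neighbors, and weights='distance', but the voting rule is not documented here. The sketch below assumes the scikit-learn convention for weights='distance' (each neighbor votes with weight 1/distance) and works on precomputed embeddings rather than a DataLoader and Module. `knn_predict` and its arguments are hypothetical, and the threshold parameter is not modeled.

```python
import numpy as np
from typing import Optional


def knn_predict(
    db_embeddings: np.ndarray,  # (n_db, dim) embeddings of the reference set
    db_labels: np.ndarray,      # (n_db,) integer class labels
    queries: np.ndarray,        # (n_query, dim) embeddings to classify
    n_neighbors: int = 5,
    weights: str = "distance",
    n_classes: Optional[int] = None,
) -> np.ndarray:
    """Hypothetical k-NN vote over precomputed embeddings."""
    if n_classes is None:
        n_classes = int(db_labels.max()) + 1
    # Pairwise Euclidean distances, shape (n_query, n_db).
    dists = np.linalg.norm(queries[:, None, :] - db_embeddings[None, :, :], axis=-1)
    # Indices of the n_neighbors closest database entries for each query.
    nearest = np.argsort(dists, axis=1)[:, :n_neighbors]
    votes = np.zeros((len(queries), n_classes))
    for q, idx in enumerate(nearest):
        if weights == "distance":
            # Closer neighbors get larger votes; the epsilon guards against exact matches.
            w = 1.0 / (dists[q, idx] + 1e-12)
        else:  # "uniform": every neighbor gets one vote
            w = np.ones(len(idx))
        # Accumulate each neighbor's vote into its class bucket.
        np.add.at(votes[q], db_labels[idx], w)
    return votes.argmax(axis=1)


db = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0]])
labels = np.array([0, 0, 0, 1, 1])
queries = np.array([[4.0, 4.0], [0.5, 0.5]])
print(knn_predict(db, labels, queries, weights="uniform"))   # [0 0]
print(knn_predict(db, labels, queries, weights="distance"))  # [1 0]
```

For the query at (4, 4), a plain majority of the five neighbors picks class 0, while distance weighting lets the two much closer class-1 points win.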
-
evaluate(test_loader: torch.utils.data.dataloader.DataLoader, metrics: Dict[str, nntoolbox.metrics.metrics.Metric], top: int = 5, tries: int = 5) → Dict[str, float]
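evaluate returns a Dict[str, float], presumably one score per named entry in metrics. As a hedged sketch of that bookkeeping only: plain callables stand in for nntoolbox.metrics.metrics.Metric (whose interface is not documented here), predictions are assumed to be precomputed, and the top and tries parameters are not modeled. `evaluate_predictions` and `accuracy` are hypothetical.

```python
import numpy as np
from typing import Callable, Dict

# Hypothetical stand-in for nntoolbox.metrics.metrics.Metric.
MetricFn = Callable[[np.ndarray, np.ndarray], float]


def accuracy(predictions: np.ndarray, labels: np.ndarray) -> float:
    # Fraction of predictions that match the ground-truth labels.
    return float(np.mean(predictions == labels))


def evaluate_predictions(
    predictions: np.ndarray, labels: np.ndarray, metrics: Dict[str, MetricFn]
) -> Dict[str, float]:
    """Apply every named metric and collect the scores under the same names."""
    return {name: metric(predictions, labels) for name, metric in metrics.items()}


preds = np.array([0, 1, 1, 2])
truth = np.array([0, 1, 2, 2])
print(evaluate_predictions(preds, truth, {"accuracy": accuracy}))  # {'accuracy': 0.75}
```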