nntoolbox.vision.utils.selector module

Utilities for selecting pairs, triplets, n-pairs, etc. from a batch of data.

class nntoolbox.vision.utils.selector.AllPairSelector

Bases: nntoolbox.vision.utils.selector.PairSelector

Select all pairs from the batch.

get_pairs(embeddings: torch.Tensor, labels: torch.Tensor) → Tuple[numpy.ndarray, numpy.ndarray]

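A minimal NumPy sketch of what "select all pairs" could mean. It assumes, without confirmation from this page, that the two arrays returned by `get_pairs` hold positive (same-label) and negative (different-label) index pairs of shape `(n, 2)`. The helper name `all_pairs` is hypothetical and not part of nntoolbox.

```python
import numpy as np
from itertools import combinations


def all_pairs(labels: np.ndarray):
    """Hypothetical helper: split every unordered index pair (i, j), i < j,
    into positive (same label) and negative (different label) pairs."""
    # reshape(-1, 2) keeps a (0, 2) shape when the batch has fewer than 2 items
    idx = np.array(list(combinations(range(len(labels)), 2)), dtype=int).reshape(-1, 2)
    same = labels[idx[:, 0]] == labels[idx[:, 1]]
    return idx[same], idx[~same]  # assumed order: (positive_pairs, negative_pairs)


# Labels from a torch tensor could be converted first with labels.cpu().numpy()
pos, neg = all_pairs(np.array([0, 0, 1]))
print(pos.tolist(), neg.tolist())  # [[0, 1]] [[0, 2], [1, 2]]
```

The number of pairs grows quadratically with batch size, which is why the harder-mining selectors below exist.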
class nntoolbox.vision.utils.selector.AllTripletSelector

Bases: nntoolbox.vision.utils.selector.TripletSelector

get_triplets(embeddings: torch.Tensor, labels: torch.Tensor) → numpy.ndarray

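A sketch of exhaustive triplet enumeration in NumPy, assuming (not confirmed here) that `get_triplets` returns an `(n, 3)` array of `(anchor, positive, negative)` indices. The helper `all_triplets` is hypothetical; this version uses each same-label pair once, with anchor index below positive index.

```python
import numpy as np
from itertools import combinations


def all_triplets(labels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: every (anchor, positive, negative) index triplet,
    using each same-label pair once as (anchor, positive), anchor < positive."""
    triplets = []
    for label in np.unique(labels):
        same = np.flatnonzero(labels == label)
        diff = np.flatnonzero(labels != label)
        for a, p in combinations(same, 2):
            for n in diff:
                triplets.append((int(a), int(p), int(n)))
    # reshape keeps a (0, 3) shape when no valid triplet exists
    return np.array(triplets, dtype=int).reshape(-1, 3)


print(all_triplets(np.array([0, 0, 1, 1])).tolist())
# [[0, 1, 2], [0, 1, 3], [2, 3, 0], [2, 3, 1]]
```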
class nntoolbox.vision.utils.selector.BatchHardTripletSelector

Bases: nntoolbox.vision.utils.selector.TripletSelector

get_triplets(embeddings: torch.Tensor, labels: torch.Tensor) → numpy.ndarray

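The class name suggests "batch-hard" mining (Hermans et al., 2017): for every anchor, pick the farthest same-label sample and the closest different-label sample. The NumPy sketch below implements that technique under the assumption that this class follows it; the helper `batch_hard_triplets` and the use of Euclidean distance are assumptions, not the library's code.

```python
import numpy as np


def batch_hard_triplets(embeddings: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: one (anchor, hardest positive, hardest negative)
    triplet per anchor, using Euclidean distance (an assumption)."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))  # (n, n) pairwise distances
    same = labels[:, None] == labels[None, :]
    triplets = []
    for a in range(len(labels)):
        pos_mask = same[a].copy()
        pos_mask[a] = False          # an anchor is not its own positive
        neg_mask = ~same[a]
        if not pos_mask.any() or not neg_mask.any():
            continue                 # skip anchors with no positive or no negative
        p = np.flatnonzero(pos_mask)[np.argmax(dist[a, pos_mask])]  # farthest positive
        n = np.flatnonzero(neg_mask)[np.argmin(dist[a, neg_mask])]  # closest negative
        triplets.append((a, int(p), int(n)))
    return np.array(triplets, dtype=int).reshape(-1, 3)


emb = np.array([[0.0], [1.0], [5.0], [10.0], [2.0]])
print(batch_hard_triplets(emb, np.array([0, 0, 0, 1, 1])).tolist())
# [[0, 2, 4], [1, 2, 4], [2, 0, 4], [3, 4, 2], [4, 3, 1]]
```

Embeddings from a torch model could be passed in via `embeddings.detach().cpu().numpy()`.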
class nntoolbox.vision.utils.selector.HardTripletSelector(margin: float = 1.0, n_neg_per_ap: int = 1, mode: str = 'semi-hard')

Bases: nntoolbox.vision.utils.selector.TripletSelector

get_triplets(embeddings: torch.Tensor, labels: torch.Tensor) → numpy.ndarray

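A sketch of how `margin`, `n_neg_per_ap` (negatives per anchor-positive pair) and `mode` could interact, using the FaceNet-style definition of semi-hard negatives, d(a, p) < d(a, n) < d(a, p) + margin. The helper `hard_triplets`, the `"hard"` mode (d(a, n) < d(a, p)), Euclidean distance, and picking the highest-loss candidates first are all assumptions; only the parameter names and defaults come from the signature above.

```python
import numpy as np
from itertools import combinations


def hard_triplets(embeddings, labels, margin=1.0, n_neg_per_ap=1, mode="semi-hard"):
    """Hypothetical helper: up to n_neg_per_ap mined negatives per anchor-positive pair."""
    dist = np.sqrt(((embeddings[:, None, :] - embeddings[None, :, :]) ** 2).sum(-1))
    triplets = []
    for label in np.unique(labels):
        pos_idx = np.flatnonzero(labels == label)
        neg_idx = np.flatnonzero(labels != label)
        for a, p in combinations(pos_idx, 2):
            d_ap = dist[a, p]
            d_an = dist[a, neg_idx]
            loss = d_ap - d_an + margin  # triplet-loss value for each candidate negative
            if mode == "semi-hard":
                valid = (d_an > d_ap) & (loss > 0)  # farther than p, but inside the margin
            elif mode == "hard":
                valid = d_an < d_ap                 # closer to the anchor than p
            else:
                raise ValueError(f"unknown mode: {mode}")
            candidates = neg_idx[valid]
            # keep the highest-loss (hardest) candidates first
            order = np.argsort(-loss[valid])[:n_neg_per_ap]
            for n in candidates[order]:
                triplets.append((int(a), int(p), int(n)))
    return np.array(triplets, dtype=int).reshape(-1, 3)


emb = np.array([[0.0], [1.0], [1.5], [3.0]])
lab = np.array([0, 0, 1, 1])
print(hard_triplets(emb, lab).tolist())               # [[0, 1, 2]]
print(hard_triplets(emb, lab, mode="hard").tolist())  # [[2, 3, 1]]
```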
class nntoolbox.vision.utils.selector.PairSelector

Bases: nntoolbox.vision.utils.selector.Selector

get_pairs(embeddings: torch.Tensor, labels: torch.Tensor) → Tuple[numpy.ndarray, numpy.ndarray]
select(embeddings: torch.Tensor, labels: torch.Tensor) → Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]

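The return type of `select` suggests it turns the index pairs from `get_pairs` into `((left_embeddings, right_embeddings), target)`. The sketch below shows one way to do that in NumPy, assuming (unconfirmed) that the target is 1 for positive pairs and 0 for negative pairs; the helper `select_pairs` is hypothetical. The same fancy indexing works on torch tensors.

```python
import numpy as np


def select_pairs(embeddings, pos_pairs, neg_pairs):
    """Hypothetical helper: gather both sides of each pair and build a
    binary target (1.0 = same label, 0.0 = different label)."""
    pos_pairs = np.asarray(pos_pairs, dtype=int).reshape(-1, 2)
    neg_pairs = np.asarray(neg_pairs, dtype=int).reshape(-1, 2)
    pairs = np.concatenate([pos_pairs, neg_pairs], axis=0)
    target = np.concatenate([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))])
    return (embeddings[pairs[:, 0]], embeddings[pairs[:, 1]]), target


emb = np.arange(6, dtype=float).reshape(3, 2)
(left, right), target = select_pairs(emb, [[0, 1]], [[0, 2], [1, 2]])
print(target.tolist())  # [1.0, 0.0, 0.0]
```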
class nntoolbox.vision.utils.selector.Selector

Bases: object

Abstract base class for selectors.

select(embedings: torch.Tensor, labels: torch.Tensor) → Tuple[torch.Tensor, ...]

class nntoolbox.vision.utils.selector.TripletSelector

Bases: nntoolbox.vision.utils.selector.Selector

get_triplets(embeddings: torch.Tensor, labels: torch.Tensor) → numpy.ndarray
select(embeddings: torch.Tensor, labels: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
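Given an `(n, 3)` index array such as the one `get_triplets` returns, a three-tensor result like that of `select` can be produced by plain indexing. This NumPy sketch assumes the column order is (anchor, positive, negative); the helper `select_triplets` is hypothetical.

```python
import numpy as np


def select_triplets(embeddings: np.ndarray, triplets: np.ndarray):
    """Hypothetical helper: gather anchor, positive and negative embeddings
    from an (n, 3) array of index triplets."""
    triplets = np.asarray(triplets, dtype=int).reshape(-1, 3)
    anchors = embeddings[triplets[:, 0]]
    positives = embeddings[triplets[:, 1]]
    negatives = embeddings[triplets[:, 2]]
    return anchors, positives, negatives


emb = np.arange(8, dtype=float).reshape(4, 2)
a, p, n = select_triplets(emb, np.array([[0, 1, 2], [2, 3, 0]]))
print(a.tolist())  # [[0.0, 1.0], [4.0, 5.0]]
```

With torch tensors the same indexing applies, and the three outputs match the `(anchor, positive, negative)` arguments of `torch.nn.TripletMarginLoss`.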