"""Selecting pairs, triplets, npairs, etc. from a batch of data"""
# Based on https://github.com/adambielski/siamese-triplet/blob/master/utils.py
from torch import Tensor
import numpy as np
from numpy import ndarray
from itertools import combinations
from typing import Tuple
import torch
from ...utils import emb_pairwise_dist
__all__ = [
'Selector', 'PairSelector', 'AllPairSelector',
'TripletSelector', 'AllTripletSelector', 'BatchHardTripletSelector',
'HardTripletSelector'
]
[docs]class Selector:
"""Abstract class for selector"""
[docs] def select(self, embedings: Tensor, labels: Tensor) -> Tuple[Tensor, ...]: pass
[docs]class PairSelector(Selector):
[docs] @torch.no_grad()
def get_pairs(self, embeddings: Tensor, labels: Tensor) -> Tuple[ndarray, ndarray]:
raise NotImplementedError
[docs] def select(self, embeddings: Tensor, labels: Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
pos_pairs, neg_pairs = self.get_pairs(embeddings, labels)
first = torch.cat(
(
torch.index_select(embeddings, dim=0, index=torch.tensor(pos_pairs[:, 0]).long().to(embeddings.device)),
torch.index_select(embeddings, dim=0, index=torch.tensor(neg_pairs[:, 0]).long()).to(embeddings.device),
),
dim=0
)
second = torch.cat(
(
torch.index_select(embeddings, dim=0, index=torch.tensor(pos_pairs[:, 1]).long().to(embeddings.device)),
torch.index_select(embeddings, dim=0, index=torch.tensor(neg_pairs[:, 1]).long().to(embeddings.device)),
),
dim=0
)
labels = torch.cat(
(torch.ones(len(pos_pairs)), torch.zeros(len(neg_pairs))),
dim=0
).to(embeddings.device)
return (first, second), labels
[docs]class AllPairSelector(PairSelector):
"""Select all pair from the batch"""
[docs] @torch.no_grad()
def get_pairs(self, embeddings: Tensor, labels: Tensor) -> Tuple[ndarray, ndarray]:
return get_all_pairs(labels.cpu().detach().numpy())
[docs]class TripletSelector(Selector):
[docs] @torch.no_grad()
def get_triplets(self, embeddings: Tensor, labels: Tensor) -> ndarray:
raise NotImplementedError
[docs] def select(self, embeddings: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
triplets = self.get_triplets(embeddings, labels)
anchors = torch.index_select(embeddings, dim=0, index=torch.tensor(triplets[:, 0]).long().to(embeddings.device))
pos = torch.index_select(embeddings, dim=0, index=torch.tensor(triplets[:, 1]).long().to(embeddings.device))
negs = torch.index_select(embeddings, dim=0, index=torch.tensor(triplets[:, 2]).long().to(embeddings.device))
return anchors, pos, negs
[docs]class AllTripletSelector(TripletSelector):
[docs] @torch.no_grad()
def get_triplets(self, embeddings: Tensor, labels: Tensor) -> ndarray:
return get_all_triplets(labels.cpu().detach().numpy())
[docs]class BatchHardTripletSelector(TripletSelector):
[docs] @torch.no_grad()
def get_triplets(self, embeddings: Tensor, labels: Tensor) -> ndarray:
return get_batch_hard_triplets(embeddings, labels.cpu().detach().numpy())
[docs]class HardTripletSelector(TripletSelector):
def __init__(
self, margin: float=1.0, n_neg_per_ap: int=1,
mode: str="semi-hard"
):
self._margin = margin
self._n_neg_per_ap = n_neg_per_ap
self._mode = mode
[docs] @torch.no_grad()
def get_triplets(self, embeddings: Tensor, labels: Tensor) -> ndarray:
return get_hard_triplets(
embeddings, labels.cpu().detach().numpy(),
self._margin, self._n_neg_per_ap, self._mode
)
def get_all_pairs(labels: ndarray) -> Tuple[ndarray, ndarray]:
"""Select all possible pairs from batch"""
labels_flat = labels.ravel()
all_pairs = np.array(list(combinations(range(len(labels_flat)), 2))).astype(np.uint8)
pos_pairs = all_pairs[(labels_flat[all_pairs[:, 0]] == labels_flat[all_pairs[:, 1]]).astype(np.uint8).nonzero()]
neg_pairs = all_pairs[(labels_flat[all_pairs[:, 0]] != labels_flat[all_pairs[:, 1]]).astype(np.uint8).nonzero()]
return pos_pairs, neg_pairs
def get_all_triplets(labels: ndarray) -> ndarray:
"""Select all possible triplets of (anchor, positive, negative) from batch"""
triplets = []
pos_pairs, neg_pairs = get_all_pairs(labels)
for pos in pos_pairs:
for neg in neg_pairs:
if pos[0] == neg[0]:
triplets.append([pos[0], pos[1], neg[1]])
return np.array(triplets)
def get_batch_hard_triplets(embeddings: Tensor, labels: ndarray) -> ndarray:
"""
Implement the batch-hard strategy: For each anchor, select the corresponding hardest (furthest) positive
and hardest (nearest) negative
References:
https://arxiv.org/pdf/1703.07737.pdf
:param embeddings:
:param labels:
:return: array of triplet indices
"""
triplets = []
dist_mat = emb_pairwise_dist(embeddings, False)
unique_class = set(labels.ravel())
for c in unique_class:
c_idx = np.where(labels == c)[0]
if len(c_idx) < 2:
continue
other_idx = np.where(labels != c)[0]
if len(other_idx) < 1:
continue
for an in c_idx:
pos_pairs = np.array([[an, p] for p in c_idx if p != an])
neg_pairs = np.array([[an, n] for n in other_idx])
pos_pair_dist = dist_mat[pos_pairs[:, 0], pos_pairs[:, 1]]
neg_pair_dist = dist_mat[neg_pairs[:, 0], neg_pairs[:, 1]]
hardest_pos = torch.argmax(pos_pair_dist) # furthest positive
hardest_neg = torch.argmin(neg_pair_dist) # nearest negative
triplets.append([an, pos_pairs[hardest_pos, 1], neg_pairs[hardest_neg, 1]])
return np.array(triplets)
def get_hard_triplets(
embeddings: Tensor, labels: ndarray, margin: float=1.0,
n_neg_per_ap: int=1, mode: str="semi-hard",
) -> ndarray:
"""
Hard and semi-hard triplet selecting strategy:
Hard: for each anchor, select negative and positive such that positive is still further to anchor than negative.
Semi-hard: for each anchor, select negative and positive such that positive is still closer to anchor than negative,
but the difference is less than desired margin
:param embeddings:
:param labels:
:param margin:
:param n_neg_per_ap: number of negatives to choose per anchor-positive pair
:param mode
:return:
"""
triplets = []
dist_mat = emb_pairwise_dist(embeddings, False)
unique_class = set(labels.ravel())
for c in unique_class:
c_idx = np.where(labels == c)[0]
if len(c_idx) < 2:
continue
other_idx = np.where(labels != c)[0]
if len(other_idx) < 1:
continue
pos_pairs = np.array(list(combinations(c_idx, 2)))
pos_pair_dist = dist_mat[pos_pairs[:, 0], pos_pairs[:, 1]]
for pos_pair, dist in zip(pos_pairs, pos_pair_dist):
losses = (dist - dist_mat[pos_pair[0], other_idx] + margin).cpu().detach().numpy()
if mode == 'hard':
hard = np.where(losses > 0)[0]
if mode == 'semi-hard':
hard = np.where(np.logical_and(losses > 0, losses < margin))[0]
if len(hard) > 0:
chosen = np.random.choice(hard, min(len(hard), n_neg_per_ap))
neg_ind = other_idx[chosen]
for neg in neg_ind:
triplets.append([pos_pair[0], pos_pair[1], neg])
return np.array(triplets)