eg1: torch.eq 逐元素比较，相等的元素为 True；两个张量形状不同时会按广播规则对齐（如 [5,1] 与 [1,5] 得到 [5,5]）
torch.eq(torch.arange(5).view(5,1),torch.arange(5).view(5,1)) Out[28]: tensor([[True], [True], [True], [True], [True]]) torch.eq(torch.arange(5).view(5,1),torch.arange(5).view(5,1).T) Out[29]: tensor([[ True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]]) torch.eq(torch.arange(5).view(5,1).T,torch.arange(5).view(5,1)) Out[30]: tensor([[ True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]])
eg2: torch.unbind 沿指定维度（默认 dim=0）拆开张量，返回由各切片组成的元组
torch.unbind(torch.tensor([[1, 2, 3]])) Out[36]: (tensor([1, 2, 3]),) torch.unbind(torch.tensor([[1], [2], [3]]), 0) Out[37]: (tensor([1]), tensor([2]), tensor([3]))
eg3:
torch.unbind(features,dim=1) Out[44]: (tensor([[ 0.0165, -0.1257, 0.0335, ..., 0.0430, 0.0588, 0.0256], [ 0.0581, -0.0996, -0.0443, ..., -0.0111, 0.1081, -0.0078], [ 0.0172, -0.1306, -0.0858, ..., -0.0411, 0.0833, 0.0013], ..., [ 0.0601, -0.1264, -0.0413, ..., 0.0127, 0.1198, -0.0309], [-0.0102, -0.1497, 0.0010, ..., -0.0122, 0.1112, -0.0583], [ 0.0758, -0.1189, -0.0197, ..., 0.0220, 0.0872, -0.0166]], device='cuda:0', grad_fn=<UnbindBackward>), tensor([[ 0.0165, -0.1257, 0.0335, ..., 0.0430, 0.0588, 0.0256], [ 0.0581, -0.0996, -0.0443, ..., -0.0111, 0.1081, -0.0078], [ 0.0172, -0.1306, -0.0858, ..., -0.0411, 0.0833, 0.0013], ..., [ 0.0601, -0.1264, -0.0413, ..., 0.0127, 0.1198, -0.0309], [-0.0102, -0.1497, 0.0010, ..., -0.0122, 0.1112, -0.0583], [ 0.0758, -0.1189, -0.0197, ..., 0.0220, 0.0872, -0.0166]], device='cuda:0', grad_fn=<UnbindBackward>))
eg4: t.repeat()

eg5: torch.div()
eg6: 用 torch.scatter 构造 logits_mask（对角线置 0，其余为 1）
# tile mask
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
logits_mask tensor([[0., 1., 1., ..., 1., 1., 1.], [1., 0., 1., ..., 1., 1., 1.], [1., 1., 0., ..., 1., 1., 1.], ..., [1., 1., 1., ..., 0., 1., 1.], [1., 1., 1., ..., 1., 0., 1.], [1., 1., 1., ..., 1., 1., 0.]], device='cuda:0')
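对比：若改为沿 dim=0 scatter，index[i][0] = i 写入的位置是 out[i][0]（行号取自 index，列号取自 index 所在列 0），因此被置 0 的是整个第 0 列而不是对角线：
torch.scatter( torch.ones_like(mask), 0, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)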
tensor([[0., 1., 1., ..., 1., 1., 1.], [0., 1., 1., ..., 1., 1., 1.], [0., 1., 1., ..., 1., 1., 1.], ..., [0., 1., 1., ..., 1., 1., 1.], [0., 1., 1., ..., 1., 1., 1.], [0., 1., 1., ..., 1., 1., 1.]], device='cuda:0')
【笔记】scatter_函数:用法如 torch.zeros(target.size(0), 2).scatter_(1,target,1).to(self.device)_探索程序猿的博客-CSDN博客
eg7: 用 logits_mask 消除 mask 的对角元素（去掉样本与自身构成的正样本对）
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask
mask Out[20]: tensor([[1., 0., 1., ..., 1., 0., 0.], [0., 1., 0., ..., 0., 0., 0.], [1., 0., 1., ..., 1., 0., 0.], ..., [1., 0., 1., ..., 1., 0., 0.], [0., 0., 0., ..., 0., 1., 0.], [0., 0., 0., ..., 0., 0., 1.]], device='cuda:0') mask*logits_mask Out[21]: tensor([[0., 0., 1., ..., 1., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [1., 0., 0., ..., 1., 0., 0.], ..., [1., 0., 1., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
eg8:
anchor_count=2, batch_size=64 时：loss = loss.view(anchor_count, batch_size)
tensor([[4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533, 4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382, 4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346, 4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402, 4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187, 4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514, 5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855, 4.9188], [4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533, 4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382, 4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346, 4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402, 4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187, 4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514, 5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855, 4.9188]], device='cuda:0', grad_fn=<ViewBackward>)
loss = loss.view(anchor_count, batch_size).mean()
tensor(4.8208, device='cuda:0', grad_fn=<MeanBackward0>)
loss.view(anchor_count, batch_size).mean(0)
tensor([4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533, 4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382, 4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346, 4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402, 4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187, 4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514, 5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855, 4.9188], device='cuda:0', grad_fn=<MeanBackward1>)
loss.view(anchor_count, batch_size).mean(1)
tensor([4.8208, 4.8208], device='cuda:0', grad_fn=<MeanBackward1>)
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.

    Args:
        temperature: divisor applied to the anchor/contrast dot products
            before the softmax.
        contrast_mode: ``'all'`` uses every view of every sample as an
            anchor; ``'one'`` uses only the first view (``features[:, 0]``).
        base_temperature: the loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model.

        If both `labels` and `mask` are None, it degenerates to SimCLR
        unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. Any
                trailing dimensions are flattened into one feature axis.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if
                sample j has the same class as sample i. Can be asymmetric.

        Returns:
            A loss scalar: the mean over all anchors of the negative mean
            log-probability of each anchor's positives, scaled by
            ``temperature / base_temperature``. An anchor with no positive
            pair (e.g. a unique label with a single view) contributes 0
            instead of turning the whole loss into NaN.

        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        # Build every helper tensor on the exact device of `features`
        # (e.g. 'cuda:1'). A bare 'cuda' resolves to the current default
        # GPU and can mismatch on multi-GPU setups.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            # reshape (rather than view) also accepts non-contiguous inputs.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # SimCLR: a sample's only positives are its own other views.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # [bsz, 1] vs [1, bsz] broadcasts to mask[i, j] = labels[i] == labels[j].
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis -> [n_views * bsz, dim]:
        # all samples of view 0, then all samples of view 1, and so on.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits: [anchor_count * bsz, n_views * bsz]
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability: subtract the per-row max before exp
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask to the logits' shape
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases: logits_mask[i, i] = 0 for every anchor row i
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob; the denominator sums over all non-self contrasts
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positives. An anchor without
        # any positive pair would otherwise compute 0 / 0 = NaN and poison
        # the whole batch, so its denominator is replaced with 1.
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6,
                                     torch.ones_like(mask_pos_pairs),
                                     mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss