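In most cases, image segmentation runs into class imbalance, and a per-class weight parameter is needed to adjust the relative contribution of each class. These weights usually have to be tuned over several experiments. This post uses median frequency balancing, a quantitative method for computing class weights in image segmentation, as the starting point; in practice the resulting weights still need some fine-tuning.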
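- For my dataset, the computed class weights are as follows:
- Pixel count per class: [3.13914322e+09 8.32395811e+08 2.22764968e+08]
- Total number of pixels: 4194304000.0
- Ratio of each class's pixel count to the total number of pixels: [0.74843007 0.19845863 0.05311131]
- Number of images containing each class: [16000. 15412. 14356.]
- Pixel frequency per class (class pixel count divided by the total pixels of the images that contain the class): [0.74843007 0.20603024 0.05919343]
- Weight per class (median frequency divided by class frequency): [0.27528322 1. 3.48062665], simplified to [0.28, 1, 3.48]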
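The numbers in the list above can be reproduced with a few lines of plain Python. This is a minimal sketch, not the script used for the original statistics: the variable names are made up for illustration, and the 512 x 512 image size is an assumption inferred from 4194304000 total pixels over 16000 images.

```python
from statistics import median

# Values copied from the list above.
pixel_counts = [3.13914322e+09, 8.32395811e+08, 2.22764968e+08]
images_with_class = [16000, 15412, 14356]

# Assumption: every image is 512 x 512 (4194304000 / 16000 = 262144).
pixels_per_image = 512 * 512

# Frequency of a class = its pixels / all pixels of the images that contain it.
freqs = [n / (m * pixels_per_image)
         for n, m in zip(pixel_counts, images_with_class)]

# Median frequency balancing: weight = median frequency / class frequency.
med = median(freqs)
weights = [med / f for f in freqs]

print([round(w, 2) for w in weights])  # [0.28, 1.0, 3.48]
```

The class whose frequency equals the median always gets weight 1; more frequent classes get weights below 1 and rarer classes get weights above 1.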
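Focal loss is then used with the class weights above. For more on how focal loss works and how to use it, see link 1, link 2, and link 3.
The code in those links has some problems, so I modified it for my own needs. The code is as follows: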
import torch
import torch.nn as nn
import torch.nn.functional as F


class Focal_loss(nn.Module):
    def __init__(self, weight=None, gamma=2):
        super(Focal_loss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.eps = 1e-8

    def forward(self, predict, target):
        if self.weight is not None:
            # Expand weight to match the shape of the target
            weights = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).repeat(predict.shape[0], 1, predict.shape[2], predict.shape[3])
            #print('weights.shape:', weights.shape) #([8, 3, 512, 512])

        # Convert target to one-hot encoding
        target_onehot = F.one_hot(target.long(), num_classes=predict.shape[1]).permute(0, 3, 1, 2).float()
        #print('target_onehot.shape:', target_onehot.shape) #([8, 3, 512, 512])

        # If weights are provided, sum them accordingly
        if self.weight is not None:
            weights = torch.sum(target_onehot * weights, dim=1)
            #print('weights after sum.shape:', weights.shape) #([8, 512, 512])

        # Softmax over the class dimension
        input_soft = F.softmax(predict, dim=1) #([8, 3, 512, 512])

        # Calculate the probabilities
        probs = torch.sum(input_soft * target_onehot, dim=1).clamp(min=0.001, max=0.999) #([8, 512, 512])

        # Calculate the focal weight
        focal_weight = (1 - probs) ** self.gamma

        # Calculate the final loss
        if self.weight is not None:
            loss = torch.sum(-torch.log(probs) * weights * focal_weight) / torch.sum(weights)
        else:
            loss = torch.mean(-torch.log(probs) * focal_weight)
        return loss
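Written as a formula, the forward pass above computes the following. The notation is introduced here for illustration only: p_i is the softmax probability of the true class y_i at pixel i, clamped to [0.001, 0.999], w_c is the weight of class c, gamma is the focusing parameter, and N is the number of pixels in the batch.

```latex
\mathcal{L}_{\text{weighted}}
  = \frac{\sum_{i} w_{y_i}\,(1 - p_i)^{\gamma}\,\bigl(-\log p_i\bigr)}{\sum_{i} w_{y_i}},
\qquad
\mathcal{L}_{\text{unweighted}}
  = \frac{1}{N}\sum_{i} (1 - p_i)^{\gamma}\,\bigl(-\log p_i\bigr)
```

In the weighted case the sum is divided by the total weight of the target pixels rather than by N, so the loss is a weighted mean over pixels.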
The class weights must be passed as a tensor. Focal loss is then used as follows:
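# input: network logits of shape (N, C, H, W); target: class indices of shape (N, H, W)
weight = torch.tensor([0.28, 1, 3.48]).cuda()
focal_loss = Focal_loss(weight=weight, gamma=0.5)
print('focal loss: ', focal_loss(input, target))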
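One way to sanity-check a loss like this is to set gamma to 0, where the focal term disappears and the result should match PyTorch's built-in cross-entropy. The sketch below does that on the CPU with small random tensors. `focal_loss_fn` is a hypothetical functional restatement of the `Focal_loss` module written only for this check, and the tensor shapes and seed are arbitrary choices.

```python
import torch
import torch.nn.functional as F


def focal_loss_fn(logits, target, weight=None, gamma=2.0):
    # Hypothetical functional restatement of the Focal_loss module, for illustration.
    probs = F.softmax(logits, dim=1)                        # (N, C, H, W)
    pt = probs.gather(1, target.unsqueeze(1)).squeeze(1)    # true-class prob, (N, H, W)
    pt = pt.clamp(min=0.001, max=0.999)
    per_pixel = -torch.log(pt) * (1 - pt) ** gamma
    if weight is None:
        return per_pixel.mean()
    w = weight[target]                                      # per-pixel class weight, (N, H, W)
    return (per_pixel * w).sum() / w.sum()


torch.manual_seed(0)
logits = 0.1 * torch.randn(2, 3, 8, 8)    # small logits keep probabilities inside the clamp range
target = torch.randint(0, 3, (2, 8, 8))   # int64 class indices
weight = torch.tensor([0.28, 1.0, 3.48])

# With gamma = 0 the loss should reduce to (weighted) cross-entropy.
plain_gap = (focal_loss_fn(logits, target, gamma=0.0)
             - F.cross_entropy(logits, target)).abs().item()
weighted_gap = (focal_loss_fn(logits, target, weight, gamma=0.0)
                - F.cross_entropy(logits, target, weight=weight)).abs().item()

# With gamma > 0 every pixel is scaled by (1 - p)^gamma < 1, so the loss shrinks.
unfocused = focal_loss_fn(logits, target, weight, gamma=0.0).item()
focused = focal_loss_fn(logits, target, weight, gamma=0.5).item()

print(plain_gap, weighted_gap, focused < unfocused)
```

Both gaps should be at floating-point noise level. The weighted case agrees because `F.cross_entropy` with `weight` and the default mean reduction also divides by the summed weights of the target pixels, which is the same normalization the loss above uses.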