HAN&DAI

  • 首页
  • 遥感应用
  • GIS应用
  • 机器学习
  • 实用工具
  • 文章链接
  • 遥感数据集
HAN&DAI
遥感与地理信息技术交流社区
  1. 首页
  2. 机器学习
  3. 正文

语义分割任务-多分类-类别权重Focal loss(pytorch)代码实现

2024年5月14日 651点热度 2人点赞 0条评论

在多数情况下,图像分割都会遇到类别不均衡的情况,这时候需要通过权重参数来调节各类之间的比重,一般不同类别的权重占比需要通过多次实验调整,这里介绍一种计算类别的权重占比的方法:中值频率平衡,实际应用时还需要在这个基础上做微调。这里采用中值频率平衡(图像分割中一种定量计算类别权重的方法)

  1. 根据我的数据集,计算出的类别权重如下:
  2. 每一类的像素数: [3.13914322e+09 8.32395811e+08 2.22764968e+08]
  3. 像素总个数 4194304000.0
  4. 每一类像素数占总像素数的比值 [0.74843007 0.19845863 0.05311131]
  5. 包含每一类像素的图片个数: [16000. 15412. 14356.]
  6. 像素出现频率 [0.74843007 0.20603024 0.05919343]
  7. 每一类的权重 [0.27528322 1. 3.48062665] 因此简化为[0.28,1,3.48]

根据以上类别权重,使用focal loss,更多focal loss的详细用法和原理可以参考链接1,链接2,链接3
考虑到使用以上链接中的代码存在部分问题,根据自己的需求进行修改,代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Focal_loss(nn.Module):
    def __init__(self, weight=None, gamma=2):
        super(Focal_loss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.eps = 1e-8

    def forward(self, predict, target):
        if self.weight is not None:
            # Expand weight to match the shape of the target
            weights = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).repeat(predict.shape[0], 1, predict.shape[2], predict.shape[3])
            #print('weights.shape:', weights.shape)   #([8, 3, 512, 512])

        # Convert target to one-hot encoding
        target_onehot = F.one_hot(target.long(), num_classes=predict.shape[1]).permute(0, 3, 1, 2).float()
        #print('target_onehot.shape:', target_onehot.shape)   #([8, 3, 512, 512])

        # If weights are provided, sum them accordingly
        if self.weight is not None:
            weights = torch.sum(target_onehot * weights, dim=1)
            #print('weights after sum.shape:', weights.shape)   #([8, 512, 512])

        # Softmax over the class dimension
        input_soft = F.softmax(predict, dim=1)   #([8, 3, 512, 512])
        # Calculate the probabilities
        probs = torch.sum(input_soft * target_onehot, dim=1).clamp(min=0.001, max=0.999)   #([8, 512, 512])
        # Calculate the focal weight
        focal_weight = (1 - probs) ** self.gamma
        # Calculate the final loss
        if self.weight is not None:
            loss = torch.sum(-torch.log(probs) * weights * focal_weight) / torch.sum(weights)
        else:
            loss = torch.mean(-torch.log(probs) * focal_weight)

        return loss

类别权重需要用tensor格式,使用focal loss,如下:类别权重需要用tensor格式,使用focal loss,如下:

weight=torch.tensor([0.28,1,3.48]).cuda()
focal_loss = Focal_loss(weight=weight,gamma=0.5)
print('focal loss: ', FL(input, target))
Post Views: 662

相关文章:

  1. 光学影像和SAR影像相互转换(代码实现)
  2. cuda 11.x如何配置tensorflow 1.x?(以3090为例)
  3. labelme等标注软件多分类(二分类)json文件转mask(可rgb显示或one-hot显示)
  4. YOLOv5+DeepSORT实现基于检测的视频目标跟踪
本作品采用 知识共享署名 4.0 国际许可协议 进行许可
标签: 暂无
最后更新:2024年5月14日

daidai

一个热爱RS和GIS技术的小姐姐!

打赏 点赞
< 上一篇
下一篇 >

文章评论

razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
取消回复

浏览最多的文章
  • BUG:ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (1,475)
  • BUG:“ModuleNotFoundError: No module named '_ext'”的解决方案 (1,243)
  • 利用GEE下载指定区域Landsat8影像 (1,199)
  • 利用arcgis制作深度学习标签数据(以二分类为例) (910)
  • 利用传统机器学习方法进行遥感影像分类-以随机森林(RF)为例 (822)

COPYRIGHT © 2025 HAN&DAI. ALL RIGHTS RESERVED. QQ交流群:821388027

Theme Kratos Made By Seaton Jiang