HAN&DAI

  • 首页
  • 遥感应用
  • GIS应用
  • 机器学习
  • 实用工具
  • 文章链接
  • 遥感数据集
HAN&DAI
遥感与地理信息技术交流社区
  1. 首页
  2. 遥感应用
  3. 正文

利用Yolov5实现旋转框目标检测(数据处理部分)

2022年4月23日 738点热度 3人点赞 3条评论

HBB转OBB格式

本文是第一篇目标检测系列文章,主要用于介绍Yolov5在旋转框目标检测的数据处理,巩固学习!!!
用到的数据是Airbus ship detection dataset,其余可以以此为例。

用到的代码

Airbus ship数据处理
Yolov5旋转目标检测

数据处理


(1)加载Airbus ship数据处理代码,只需获取mask_to_rotationbb.py即可,然后对代码做一些修改。

import os
import cv2
from tqdm import tqdm
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from skimage.io import imread
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
from skimage.measure import label, regionprops
#ship_dir = ''
train_image_dir = "./Target_detection/kaggle-airbus-ship-detection/train/"
output_image_dir = "./Target_detection/kaggle-airbus-ship-detection/labels_rotate/"

from skimage.morphology import label
def multi_rle_encode(img):
    labels = label(img[:, :, 0])
    return [rle_encode(labels==k) for k in np.unique(labels[labels>0])]

def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def rle_decode(mask_rle, shape=(768, 768)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T  # Needed to align to RLE direction

def masks_as_image(in_mask_list, all_masks=None):
    # Take the individual ship masks and create a single mask array for all ships
    if all_masks is None:
        all_masks = np.zeros((768, 768), dtype = np.int16)
    #if isinstance(in_mask_list, list):
    for mask in in_mask_list:
        if isinstance(mask, str):
            all_masks += rle_decode(mask)
    return np.expand_dims(all_masks, -1)

masks = pd.read_csv("./Target_detection/kaggle-airbus-ship-detection/train_ship_segmentations.csv")
print(masks.shape[0], 'masks found')
print(masks['ImageId'].value_counts().shape[0])
print(masks.head())

images_with_ship = masks.ImageId[masks.EncodedPixels.isnull()==False]
images_with_ship = np.unique(images_with_ship.values)
print('There are ' +str(len(images_with_ship)) + ' image files with masks')

plotme = 1
image_size = 768
for i in tqdm(range(0, 10)):
    image = images_with_ship[i]

    if plotme == 1:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize = (15, 5))
    img_0 = cv2.imread(train_image_dir+'/' + image)
    rle_0 = masks.query('ImageId=="'+image+'"')['EncodedPixels']
    mask_0 = masks_as_image(rle_0)
    #
    # 
    lbl_0, lbl_cnt = label(mask_0, return_num=True) 
    #props = regionprops(lbl_0)
    img_1 = img_0.copy()
    #print ('Image', image, lbl_cnt)
    for i in range(1, lbl_cnt+1):
        mask = np.array((lbl_0 == i).astype('uint8')[..., 0])
        mask = cv2.resize(mask, (768, 768))
        cnts, hierarchy = cv2.findContours((255*mask).astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        rect = cv2.minAreaRect(cnts[0])
        '''
        这是一个大坑,网上说opencv 4.5之前和4.5之后的角度定义不一样,因此注释的是4.5之前的表示,
        4.5之后的表示仅需一行代码即可。
        '''
        '''
        if rect[1][1]>rect[1][0]:
            angle = 90-rect[2]
        else:
            angle = -rect[2]
        '''
        angle = -rect[2]
        if rect[1][0]>rect[1][1]:
            angle = 90-rect[2]
        else:
            angle = -(90-rect[2])

        box = cv2.boxPoints(rect)
        #print (box)
        box = np.int0(box)

        if plotme == 1:
            cv2.drawContours(img_1,[box],0,(0,191,255),2)
        x = int(rect[0][0])
        y = int(rect[0][1])
        #print (rect, angle, 360*((props[i-1].orientation + np.pi/2)/(2*np.pi)), x, y)
        if plotme == 1:
            cv2.circle(img_1, ( x, y ), 5, (255, 0, 0), 3)
        print(str(rect[0][0]/image_size) + ' ' + str(rect[0][1]/image_size) + ' ' + str(rect[1][0]/image_size) + ' ' + str(rect[1][1]) + ' ' + '1' + ' ' + str(angle/90) + '\n' )
    '''
    if plotme == 1:
        ax1.imshow(img_0)
        ax1.set_title('Image')
        ax2.set_title('Mask')
        ax3.set_title('Image with derived bounding box')
        ax2.imshow(mask_0[...,0], cmap='gray')
        ax3.imshow(img_1)
        plt.show()
    '''
plotme = 0
rot_bboxes_dict = {}

for i in tqdm(range(0, len(images_with_ship))):

    image = images_with_ship[i]
    '''
    这里为每一张影像建立一个txt文件,保存格式为:类别, x, y, w, h, theta;均为归一化后的值
    举个例子:
    0 0.44062503178914386 0.08958334724108379 0.006987712035576503 0.004658474586904049 0.7048327975802952 
    0 0.5630836486816406 0.2586206793785095 0.0159581924478213 0.035785041749477386 0.7577621459960937 
    '''
    out_file = open(output_image_dir+'/' + image[:-4] + '.txt', 'w')    

    if plotme == 1:
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize = (15, 5))
    img_0 = cv2.imread(train_image_dir+'/' + image)
    rle_0 = masks.query('ImageId=="'+image+'"')['EncodedPixels']
    mask_0 = masks_as_image(rle_0)
    #
    # 
    lbl_0, lbl_cnt = label(mask_0, return_num=True) 
    #props = regionprops(lbl_0)
    img_1 = img_0.copy()
    bboxes = []
    bboxes1 = []
    for i in range(1, lbl_cnt+1):
        mask = np.array((lbl_0 == i).astype('uint8')[..., 0])
        mask = cv2.resize(mask, (768, 768))
        cnts, hierarchy = cv2.findContours((255*mask).astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        rect = cv2.minAreaRect(cnts[0])
        '''
        这是一个大坑,网上说opencv 4.5之前和4.5之后的角度定义不一样,因此注释的是4.5之前的表示,
        4.5之后的表示仅需一行代码即可。
        '''
        '''
        if rect[1][1]>rect[1][0]:
            angle = 90-rect[2]
        else:
            angle = -rect[2]
        '''
        angle = rect[2]
        box = cv2.boxPoints(rect)
        #print (box)
        box = np.int0(box)

        if plotme == 1:
            cv2.drawContours(img_1,[box],0,(0,191,255),2)
        x = int(rect[0][0])
        y = int(rect[0][1])
        if plotme == 1:
            cv2.circle(img_1, ( x, y ), 5, (255, 0, 0), 3)
        bboxes.append([rect[0][0]/image_size, rect[0][1]/image_size, rect[1][0]/image_size, rect[1][1]/image_size, angle/90])
        bboxes1.append([0, rect[0][0]/image_size, rect[0][1]/image_size, rect[1][0]/image_size, rect[1][1]/image_size, angle/90])
    #print (bboxes1)

    rot_bboxes_dict[image] = bboxes.copy()
    #print (rot_bboxes_dict)
    for k, data in enumerate(bboxes1):
        #print (k, data)
        #print (type(data))
        for j in data:
            #print ('j',j)
            out_file.write(str(j)+ ' ')
        out_file.write('\n')

    '''
    if plotme == 1:
        ax1.imshow(img_0)
        ax1.set_title('Image')
        ax2.set_title('Mask')
        ax3.set_title('Image with derived bounding box')
        ax2.imshow(mask_0[...,0], cmap='gray')
        ax3.imshow(img_1)
        plt.show()
    '''
bboxes_df = pd.DataFrame([rot_bboxes_dict])
bboxes_df = bboxes_df.transpose()
bboxes_df.columns = ['bbox_list']
print(bboxes_df.head())

(2) 以上代码运行后,会得到txt形式的文件夹,如图所示:

(3)然后,对以上的labels和images文件的图像进行分割,分为train、val和test文件夹。

(4)然后,修改本文前面给出的Yolov5代码的数据路径即可运行。

需要注意的一点:Yolov5代码中需要安装detectron这个包,需要下载到本地后才能安装。

Post Views: 679

相关文章:

  1. DarkLabel2.4软件标注视频影像数据做目标检测/跟踪(数据预处理)
  2. 图像加/解密简单系统
  3. 光学影像和SAR影像相互转换(代码实现)
  4. 遥感影像/自然影像手动点选距离(长度)计算
本作品采用 知识共享署名 4.0 国际许可协议 进行许可
标签: 目标检测
最后更新:2023年7月20日

HAN&DAI

RS和GIS研究兴趣者,永远在学习的路上!

打赏 点赞
下一篇 >

文章评论

  • 李大江

    yolov5旋转框?

    2022年4月24日
    回复
    • daidai

      @李大江 是的,大兄嘚

      2022年4月24日
      回复
  • 李大江

    牛皮啊

    2022年4月24日
    回复
  • razz evil exclaim smile redface biggrin eek confused idea lol mad twisted rolleyes wink cool arrow neutral cry mrgreen drooling persevering
    回复 李大江 取消回复

    文章目录
    • 用到的代码
    • 数据处理
    浏览最多的文章
    • BUG:ImportError: /lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (1,462)
    • BUG:“ModuleNotFoundError: No module named '_ext'”的解决方案 (1,229)
    • 利用GEE下载指定区域Landsat8影像 (1,176)
    • 利用arcgis制作深度学习标签数据(以二分类为例) (900)
    • 利用传统机器学习方法进行遥感影像分类-以随机森林(RF)为例 (807)

    COPYRIGHT © 2025 HAN&DAI. ALL RIGHTS RESERVED. QQ交流群:821388027

    Theme Kratos Made By Seaton Jiang