1. 什么是 One-hot 编码

最直观的理解就是,比如说现在有三个类别A、B、C,它们对应的标签值分别为 [1, 2, 3],如果对这三个类别使用One-hot编码,得到的结果则是,[[1, 0, 0], [0, 1, 0], [0, 0, 1]],相当于:

  • 1 被编码为 1 0 0
  • 2 被编码为 0 1 0
  • 3 被编码为 0 0 1

2. 为什么要对数据进行 One-hot 编码

分割任务中,网络模型最后的输出shape为 [N, C, H, W] (以pytoch为例, 其中N为batch_size, C为预测的类别数),而我们给的的gt(ground truth)的shape一般为[H, W, 3](彩色图或rgb图)或[H, W](灰度图)。
假设我们现在的分割任务里面有5个目标需要分割,给定的gt是彩色的。则网络模型最后的输出shape为 [N, 5, H, W],这和gt的shape不匹配,在训练的时候它们两者之间不能进行损失值计算。因此,就需要使用One-hot编码对gt进行编码,将其编码为[H, W, 5],最后再对维度进行transpose即可。

编码前和编码后的变化类似图中所示(上图对应编码前,下图对应编码后)。
{% asset_img 1.png %}
(图片来源:https://www.eefocus.com/communication/413211/r0)
{% asset_img 2.png %}
(图片来源:https://www.eefocus.com/communication/413211/r0)

3.代码实现

3.1 方法一

mask_to_onehot用来将标签进行one-hot,onehot_to_mask用来恢复one-hot,在可视化的时候使用。

def mask_to_onehot(mask, palette):
    """
    Converts a segmentation mask (H, W, C) to (H, W, K) where the last dim is a one
    hot encoding vector, C is usually 1 or 3, and K is the number of class.
    """
    semantic_map = []
    for colour in palette:
        equality = np.equal(mask, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)
    return semantic_map

def onehot_to_mask(mask, palette):
    """
    Converts a mask (H, W, K) to (H, W, C)
    """
    x = np.argmax(mask, axis=-1)
    colour_codes = np.array(palette)
    x = np.uint8(colour_codes[x.astype(np.uint8)])
    return x

方法一在使用的时候需要先定义好颜色表palette(根据自己的数据集来定义就行了)。下面演示两个例子。

假设gt是灰度图,需要分割两个目标(正常器官和肿瘤)(加上背景就是3分类任务),正常器官的灰度值为128,肿瘤的灰度值为255, 背景的灰度值为0。

palette = [[0], [128], [255]]  # 里面值的顺序不是固定的,可以按自己的要求来
# 注意:灰度图的话要确保 gt的 shape = [H, W, 1],该函数实在最后的通道维上进行映射
# 如果加载后的gt的 shape = [H, W],则需要进行通道的扩维
gt_onehot = mask_to_onehot(gt, palette)  # one-hot 后 gt的shape=[H, W, 3]

假设gt彩色图,需要分割5个目标(加上背景就是6分类任务),颜色值如下。 和灰度图的处理方法类似。

palette = [[0, 0, 0], [192, 224, 224], [128, 128, 64], [0, 192, 128], [128, 128, 192], [128, 128, 0]]
gt_onehot = mask_to_onehot(gt, palette)  # one-hot 后 gt的shape=[H, W, 6]

3.1 方法二

为了以示区别,名字不要起的一样。

def mask2onehot(mask, num_classes):
    """
    Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one
    hot encoding vector

    """
    _mask = [mask == i for i in range(num_classes)]
    return np.array(_mask).astype(np.uint8)

def onehot2mask(mask):
    """
    Converts a mask (K, H, W) to (H,W)
    """
    _mask = np.argmax(mask, axis=0).astype(np.uint8)
    return _mask

用法:如果gt是灰度图,如上面的例子,用起来就比较简单。

# 需要先指定每个类别的颜色值对应的标签
# 注意: 第一类从0开始,而不是从1开始
label2trainid = {0: 0, 128: 1, 255: 2}
gt_copy = gt.copy()
# 这一步相当于把
for k, v in label2trainid.items():
    gt_copy[gt == k] = v
gt_with_trainid = gt_copy.astype(np.uint8)

gt_onehot = mask2onehot(gt_with_trainid, 3) # one-hot 后 gt的shape=[3, H, W]

如果gt是彩色图,要先把rgb颜色值映射为标签,再进行one-hot编码,相对来说就比较繁琐了。直接用方法一就行了。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐