资讯详情

用Transformer思想的分类器进行小样本分割

作者丨李xiang

来源丨GiantPandaCV

文章目录

  • 1 前言

  • 2 CWT-for-FSS 整体架构

  • 3 求解方法

  • 4 分析实验结果

  • 5 代码和可视化

  • 6 总结

  • 7 参考链接

1 前言

aa3703fa98cf35233635ddcbfcf9bf6c.png

我以前写过几篇与医学图像分割相关的论文阅读笔记。这一次,我计划打开一个小样本语义分割的新坑。这篇阅读笔记中介绍的论文也是很久以前读过的。 ICCV 上面,思路值得借鉴。代码也跑过了,但还没来得及整理,arXiv:https://arxiv.org/pdf/2108.03032.pdf 。

针对小样本的语义分割,本文提出了一个更简单的元学习范式,即只对分类器进行元学习,并采用传统的特征编码解码器分割模型训练方法。也就是说,它只适用于 Classifier Weight Transformer(后面简称 CWT)元学习的训练使 CWT 测试样本可以动态适应,从而提高分割

让我们先介绍一下背景。传统的语义分割通常由三个部分组成:一个 CNN 编码器,一个 CNN 解码器和简单的分类器来区分前景像素和背景像素。

当模型学习识别一个从未见过的新类别时,元学习需要分别训练这三个部分。如果新类别中的图像太少,则很难同时训练三个模块。

在这篇论文中, 提出一种新的训练方法,在面对新类别时只关注模型中最简单的分类器。统的分割网络学习了大量的图片和信息的图片和信息,可以从任何图片中充分捕捉到有利于区分背景和前景的信息,无论在训练中是否遇到类似的图片。所以面对新的样本,只要分类器学习元。

首先总结了这个阅读笔记。 CWT-for-FSS 介绍了整体结构的训练方法,然后分析了实验结果,最后对代码训练做了简单的指导。

2 CWT-for-FSS 整体架构

一个小样本分类系统一般由三部分构成:编码器,解码器和分类器。

其中,前两个模块模型复杂,最后一个分类器结构简单。小样本分类方法通常在元学习过程中更新除编码器以外的所有模块或模块,只有少数样本使用更新模块。

在这种情况下,与数据提供的信息相比,模型更新的参数过多,不足以优化模型参数。基于这一分析,本文提出了一种新的元学习训练范式,即只对分类器进行元学习。两种方法的比较如下图所示:

值得注意的是,我们知道 Support set 上迭代的模型往往不能很好地作用在 Query set 因为同类图像也可能不同。

利用 CWT 要解决这个问题,就是本文的重点。也就是说,可以动态使用 Query set 进一步更新分类器的特征信息,以提高分割的准确性。整体结构如下图所示:

借助 Transformer 将分类器权重转化为 Query,将 Query set 提取的特征转化为 Key 和 Value,然后根据这三个值调整分类器的权重,最后通过残差连接与原分类器的参数和谐。

3 求解方法

首先,对网络进行预训练,这里就不赘述了。 CWT 元学习分为两个步骤。第一步是内循环,就像预训练一样,根据支持集上的图片和 mask 只修改分类器参数进行训练。

当新样本数量足够大时,只使用外循环,即只更新分类器 SOTA,但当面对小样本时,性能并不令人满意。第二步是外循环,根据每对查询图片,微调分类器参数。

微调参数仅针对此查询图片,不能用于其他查询图像,也不能覆盖修改原分类器参数。

假设查询图像,提取的特征是F,形状为n × d,n单通道的像素数,d如果是通道数,则全连接分类器参数 w 形状为 2 × d。参照 Transformer,令Query = w × Wq, Key = F × Wk, Value = F × Wv,其中 Wq、Wk 和Wv 都是可学的d × da矩阵,d 为维度数,da 本文将其设置为人为规定的隐藏层维度 2048年。根据这三个数字和残差链接,新分类器的权重为:

其中,Ψ 输入维度为线性层 da,输出维度为 d。softmax 针对的维度是行的。在找出每个查询集对应的权重后,只需将特征放在一边 F 塞进 w* 就好。

4 分析实验结果

本部分在两个标准样本分割数据集中显示了论文中的实验结果 PASCAL 和 COCO 在大多数情况下,本文中的方法取得了最佳效果。

此外,模型的性能在跨数据集的情况下进行了测试 CWT-for-FSS 该方法具有良好的鲁棒性。

最后,可视化结果如下:

5 代码和可视化

开源代码 https://github.com/lixiang007666/CWT-for-FSS 最后,让我们简要看看如何使用它。仓库提供训练脚本:

shscripts/train.shpascal0[0]501

数据集依次指定了以下参数。split 数、gpus、layers 和 k-shots。若需要多卡训练,gpus 为[0,1,3,4,5,6,layers 除了 50 也可指定为 101,说明 backbone 为 resnet101。对应的,测试的脚本为 scripts/test.sh。

此外,仓库中的代码没有提供可视化脚本。如果需要可视化分割结果,请参考以下代码。首先将以下内容插入主 test.py 脚本(在 classes.append() 下方):

logits_q[i]=pred_q.detach() gt_q[i,0]=q_label classes.append([class_.item()forclass_insubcls]) #Insertvisualizationroutinehere ifargs.visualize: output={} output['query'],output['support']={},{} output['query']['gt'],output['query']['pred']=vis_res(qry_oris[0][0],qry_oris[1],F.interpolate(pred_q,size=q_label.size()[1:],mode='bilinear',align_corners=True).squeeze().detach().cpu().numpy()) spprt_label=torch.cat(spprt_oris[1],0)       output['support']['gt'], output['support']['pred'] = vis_res(spprt_oris[0][0][0],spprt_label, output_support.squeeze().detach().cpu().numpy())

                    save_image = np.concatenate((output['support']['gt'], output['query']['gt'], output['query']['pred']), 1)
                    cv2.imwrite('./analysis/' + qry_oris[0][0].split('/')[-1] ,   save_image)

主要可视化函数vis_res如下:

def resize_image_label(image, label, size = 473):
    import cv2
    def find_new_hw(ori_h, ori_w, test_size):
        if ori_h >= ori_w:
            ratio = test_size * 1.0 / ori_h
            new_h = test_size
            new_w = int(ori_w * ratio)
        elif ori_w > ori_h:
            ratio = test_size * 1.0 / ori_w
            new_h = int(ori_h * ratio)
            new_w = test_size

        if new_h % 8 != 0:
            new_h = (int(new_h / 8)) * 8
        else:
            new_h = new_h
        if new_w % 8 != 0:
            new_w = (int(new_w / 8)) * 8
        else:
            new_w = new_w
        return new_h, new_w

    # Step 1: resize while keeping the h/w ratio. The largest side (i.e height or width) is reduced to $size.
    #                                             The other is reduced accordingly
    test_size = size
    new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size)

    image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)),
                            interpolation=cv2.INTER_LINEAR)

    # Step 2: Pad wtih 0 whatever needs to be padded to get a ($size, $size) image
    back_crop = np.zeros((test_size, test_size, 3))

    back_crop[:new_h, :new_w, :] = image_crop
    image = back_crop

    # Step 3: Do the same for the label (the padding is 255)
    s_mask = label
    new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size)
    s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),
                        interpolation=cv2.INTER_NEAREST)
    back_crop_s_mask = np.ones((test_size, test_size)) * 255
    back_crop_s_mask[:new_h, :new_w] = s_mask
    label = back_crop_s_mask

    return image, label
def vis_res(image_path, label, pred):

    import cv2
    def read_image(path):
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.float32(image)
        return image

    def label_to_image(label):
        label = label == 1.
        label = np.float32(label) * 255.
        placeholder = np.zeros_like(label)
        label = np.concatenate((label, placeholder), 0)
        label = np.concatenate((label, placeholder), 0)
        label = np.transpose(label, (1,2,0))
        return label

    def blend_image_label(image, label):
        result = 0.5 * image + 0.5 * label
        result = np.float32(result)
        result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

        return result

    def pred_to_image(label):
        label = np.float32(label) * 255.
        placeholder = np.zeros_like(label)
        placeholder = np.concatenate((placeholder, placeholder), 0)
        label = np.concatenate((placeholder, label), 0)
        label = np.transpose(label, (1,2,0))
        return label

    image = read_image(image_path)
    label = label.squeeze().detach().cpu().numpy()
    image, label = resize_image_label(image, label)
    label = label_to_image(np.expand_dims(label, 0))
    out_image_gt = blend_image_label(image, label)
    #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] +  '_gt.jpg',   out_image)

    pred  = np.argmax(pred, 0)
    pred = np.expand_dims(pred, 0)
    pred = pred_to_image(pred)
    out_image_pred = blend_image_label(image, pred)
    #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] +  '_pred.jpg',   out_image)

    return out_image_gt, out_image_pred

注意,是在每次测试迭代结束时可视化分割结果。

6 总结

这篇阅读笔记介绍了一种新的元学习训练范式来解决小样本语义分割问题。相比于现有的方法,这种方法更加简洁有效,只对分类器进行元学习。

重要的是,为了解决类内差异问题,提出 Classifier Weight Transformer 利用 Query 特征信息来迭代训练分类器,从而获得更加鲁棒和精准的分割效果。

7 参考链接

  • https://github.com/zhiheLu/CWT-for-FSS

  • https://arxiv.org/pdf/2108.03032.pdf

本文仅做学术分享,如有侵权,请联系删文。

后台回复:即可下载国外大学沉淀数年3D Vison精品课件

后台回复:即可下载3D视觉领域经典书籍pdf

后台回复:即可学习3D视觉领域精品课程

1.面向自动驾驶领域的多传感器数据融合技术

2.面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)3.彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进4.国内首个面向工业级实战的点云处理课程5.激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解6.彻底搞懂视觉-惯性SLAM:基于VINS-Fusion正式开课啦7.彻底搞懂基于LOAM框架的3D激光SLAM: 源码剖析到算法优化8.彻底剖析室内、室外激光SLAM关键算法原理、代码和实战(cartographer+LOAM +LIO-SAM)

9.从零搭建一套结构光3D重建系统[理论+源码+实践]

10.单目深度估计方法:算法梳理与代码实现

11.自动驾驶中的深度学习模型部署实战

12.相机模型与标定(单目+双目+鱼眼)

13.重磅!四旋翼飞行器:算法与实战

14.ROS2从入门到精通:理论与实战

15.国内首个3D缺陷检测教程:理论、源码与实战

16.基于Open3D的点云处理入门与实战教程

扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在

也可申请加入我们的细分方向交流群,目前主要有等微信群,请扫描下面微信号加群,备注:”研究方向+学校/公司+昵称“,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进去相关微信群。也请联系。

▲长按加微信群或投稿

▲长按关注公众号

:针对3D视觉领域的五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:

学习3D视觉核心技术,扫描查看介绍,3天内无条件退款

 圈里有高质量教程资料、答疑解惑、助你高效解决问题

标签: f37h影像传感器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台