pytorch-semseg源码解读test.py_ptsemseg_蓝德库洛尔多的博客-程序员宅基地

技术标签: 图像分割代码  tensorflow  机器学习  深度学习  pytorch  神经网络  

这部分代码很坑,原作者代码里若不更改命令行参数norm,则会进行两次标准化

import os
import torch
import argparse
import numpy as np
import scipy.misc as misc


from ptsemseg.models import get_model
from ptsemseg.loader import get_loader
from ptsemseg.utils import convert_state_dict

try:
    import pydensecrf.densecrf as dcrf
except:
    print(
        "Failed to import pydensecrf,\
           CRF post-processing will not work"
    )# 导入CRF后处理


def test(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_file_name = os.path.split(args.model_path)[1]# 命令行传参,模型路径
    model_name = model_file_name[: model_file_name.find("_")]

    # Setup image
    print("Read Input Image from : {}".format(args.img_path))# 图片路径
    img = misc.imread(args.img_path)

    data_loader = get_loader(args.dataset)
    loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True)
    n_classes = loader.n_classes# 获取指定训练集的类别数

    resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic")
    # 将图片变形成模型接受的尺寸
    orig_size = img.shape[:-1]# 除了最后一个元素(通道)的切片,返回H*W
    if model_name in ["pspnet", "icnet", "icnetBN"]:
        # uint8 with RGB mode, resize width and height which are odd numbers
        img = misc.imresize(img, (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1))
    else:
        img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]))
        # 个别网络的输入是原size+1

    img = img[:, :, ::-1]# 最后一维逆序读取
    img = img.astype(np.float64)
    img -= loader.mean# 标准化,减去均值
    if args.img_norm:
        img = img.astype(float) / 255.0

    # NHWC -> NCHW
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)# 增加一维
    img = torch.from_numpy(img).float()

    # Setup Model
    model_dict = {"arch": model_name}
    model = get_model(model_dict, n_classes, version=args.dataset)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    # 读取了网络结构的名字和对应的参数
    model.load_state_dict(state)
    model.eval()# model.eval() :针对单张图片,不启用 BatchNormalization 和 Dropout
    model.to(device)

    images = img.to(device)
    outputs = model(images)# n张图片*n个class的概率*h*w

    if args.dcrf:
        unary = outputs.data.cpu().numpy()
        unary = np.squeeze(unary, 0)
        unary = -np.log(unary)
        unary = unary.transpose(2, 1, 0)
        w, h, c = unary.shape
        unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
        unary = np.ascontiguousarray(unary)

        resized_img = np.ascontiguousarray(resized_img)

        d = dcrf.DenseCRF2D(w, h, loader.n_classes)
        d.setUnaryEnergy(unary)
        d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)

        q = d.inference(50)
        mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
        decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
        dcrf_path = args.out_path[:-4] + "_drf.png"
        misc.imsave(dcrf_path, decoded_crf)
        print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
    # 输出h*w,从outputs中取了每个像素预测概率最大的那个值和索引位置,
    # 其中outputs.data.max(1)[]中,返回值有两个,第一个是概率最大值组成的矩阵,
    # 第二个是最大值所在维索引组成的矩阵,这里取得是第二个,即[1]
    # outputs.data.max(1)[1].cpu().numpy()返回1*w*h矩阵,squeeze删除维度为1的维
    if model_name in ["pspnet", "icnet", "icnetBN"]:
        pred = pred.astype(np.float32)
        # float32 with F mode, resize back to orig_size
        pred = misc.imresize(pred, orig_size, "nearest", mode="F")

    decoded = loader.decode_segmap(pred)# 得到Mask颜色图
    print("Classes found: ", np.unique(pred))# 得到寻找到的类
    misc.imsave(args.out_path, decoded)
    print("Segmentation Mask Saved at: {}".format(args.out_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Params")
    parser.add_argument(
        "--model_path",
        nargs="?",
        type=str,
        default="fcn8s_pascal_1_26.pkl",
        help="Path to the saved model",
    )
    parser.add_argument(
        "--dataset",
        nargs="?",
        type=str,
        default="pascal",
        help="Dataset to use ['pascal, camvid, ade20k etc']",
    )

    parser.add_argument(
        "--img_norm",
        dest="img_norm",
        action="store_true",
        help="Enable input image scales normalization [0, 1] \
                              | True by default",
    )
    parser.add_argument(
        "--no-img_norm",
        dest="img_norm",
        action="store_false",
        help="Disable input image scales normalization [0, 1] |\
                              True by default",
    )
    parser.set_defaults(img_norm=True)

    parser.add_argument(
        "--dcrf",
        dest="dcrf",
        action="store_true",
        help="Enable DenseCRF based post-processing | \
                              False by default",
    )
    parser.add_argument(
        "--no-dcrf",
        dest="dcrf",
        action="store_false",
        help="Disable DenseCRF based post-processing | \
                              False by default",
    )
    parser.set_defaults(dcrf=False)

    parser.add_argument(
        "--img_path", nargs="?", type=str, default=None, help="Path of the input image"
    )
    parser.add_argument(
        "--out_path", nargs="?", type=str, default=None, help="Path of the output segmap"
    )
    args = parser.parse_args()
    test(args)

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_33500389/article/details/106676294

智能推荐

美国东北大学计算机科学,美国东北大学计算机科学专业申请要求有哪些?课程设置有哪些?...-程序员宅基地

众所周知,计算机专业是当下受关注度最高的专业,同时也是我国学子出国留学最常选择的专业,纵观全球,美国西北大学开设的计算机科学专业就凭借独特的教学优势吸引了很多人的目光,为此今天小编就为大家整理了美国东北大学计算机科学专业申请要求和课程设置等相关信息。在此推荐给大家以供广大远赴他国的学习的学子们作为参考,希望能对你有所帮助。有任何问题都可以咨询IDP留学顾问哦!美国东北大学计算机科学专业申请要求本专...

C++蓝桥杯 基础练习之高精度加法-程序员宅基地

C++ 蓝桥杯题目讲解汇总(持续更新) VIP试题 高精度加法资源限制时间限制:1.0s 内存限制:512.0MB问题描述输入两个整数a和b,输出这两个整数的和。a和b都不超过100位。算法描述由于a和b都比较大,所以不能直接使用语言中的标准数据类型来存储。对于这种问题,一般使用数组来处理。  定义一个数组A,A[0]用于存储a的个位,A[1]用于存储a的十位,依此类推。同样可...

东北大学应用数理统计知识点总结——贝叶斯估计-程序员宅基地

贝叶斯估计一、基础知识1.1 常用分布函数X ~ B(α,β)B(\alpha, \beta)B(α,β): f(x)=yα−1(1−y)β−1f(x) = y^{\alpha - 1}(1-y)^{\beta - 1}f(x)=yα−1(1−y)β−1E(x)=αα+βE(x) = \frac{\alpha}{\alpha + \beta}E(x)=α+βα​X ~ Γ(α,β)\Gamma(\alpha, \beta)Γ(α,β): f(x)=βαΓ(α)xα−1e−βx,x>0,

VC中缺省库冲突的解决 - fmddlmyy的专栏 - CSDNBlog-程序员宅基地

VC在编译程序时有两个习惯: 1、在从头开始编译时(即生成makefile时),将源文件名按字母排序后,依次处理; 2、一边编译一边决定需要哪些缺省库。 它的这些习惯有时会造成奇怪的编译错误,例如项目中有两个文件: charutil.c gbuni.cpp 其中gbnni.cpp用到了MFC库。 编译器先处理charutil.c,然后觉得需要link一个C Runtime库,根据项目设置选

86 R k-means,层次,EM聚类介绍_层次聚类和em聚类的区别-程序员宅基地

R k-means,层次,EM聚类1什么是客户分群什么是分群?为什么要分群?2 聚类分析方法论客户分群的算法3 样本间距离定义4 层次聚类分析方法论层次聚类概述层次聚类的步骤详解层次聚类的优缺点5 K-means聚类分析方法论K-Means聚类概述K-Means聚类步骤K-Means聚类要点K-Means聚类的优缺点K-Means聚类 vs. 层次聚类6 EM模型聚类分析方法论EM模型聚类概述EM模型聚类步骤EM模型与K-Means的关系7 实战操作1什么是客户分群什么是分群?将现有消费者群体按一定规_层次聚类和em聚类的区别

Zabbix的Windows客户端安装和配置,以及删除_windows 如何删除服务中的zabbix-程序员宅基地

Zabbix的Windows客户端安装和配置1、在浏览器中去下载相应的客户端https://www.zabbix.com/download_agents我选择的是 4.4.1 Windows amd64 下载的,解压之后应该是这个样子,有bin和conf两个文件夹2、在C盘目录下创建一个zabbix文件夹 , 把bin文件夹的zabbix_agentd.exe、zabbi..._windows 如何删除服务中的zabbix

随便推点

android 耳机孔 红外,手机遥控器,3.5mm耳机接口红外遥控改造解析-程序员宅基地

很多家电都用红外遥控,如电视机、机顶盒、空调、电风扇等。越来越多的遥控器反而给我们带来了更多的问题,有时找不到遥控器放哪儿了,或者混淆了都是麻烦,事实上对手机进行简单的改造,可以自制一个万能红外遥控器,来看看我们如何“掌控”家中电器的。方案解析:红外遥控器发送数据时,是将二进制数据调制成一系列的脉冲信号用940nm波长的红外发射管发射出去,红外载波为频率38KHz的方波,红外接收端在收到38KHz..._安卓口转3.5mm

地球大气层简介与垂直分层_大气的垂直分层-程序员宅基地

已下内容翻译自UCAR的科普文章,具体可直接访问以下链接。图片皆根据英文原文标明出处。如有侵权,请联系删除。Earth’s AtmosphereLayers of Earth’s Atmosphere文章目录地球大气层(Earth's atmosphere)简介地球的气体成分大气层的垂直分层对流层(Troposphere)平流层(Stratosphere)中间层(Mesosphere)热层..._大气的垂直分层

pt-online-schema-change安装使用详解_pt-cheng-online_菜鸟一直在路上的博客-程序员宅基地

一、pt-online介绍pt-online-schema-change是percona公司开发的一个工具,在percona-toolkit包里面可以找到这个功能,它可以在线修改表结构原理:首先它会新建一张一模一样的表,表名一般是_new后缀然后在这个新表执行更改字段操作然后在原表上加三个触发器,DELETE/UPDATE/INSERT,将原表中要执行的语句也在新表中执行最后将原表的数据拷贝到新表中,然后替换掉原表使用pt-online-schema-change执行SQL的日志 SQL语句:_pt-cheng-online

如何编译并运行师兄的ICP -SVD 例程_风从海上来556的博客-程序员宅基地

参考链接:本文引用了以下两个链接【手写ICP】ICP -SVD 手动实现与例程(上)(1)https://blog.csdn.net/sinat_25923849/article/details/114969183【手写ICP】ICP -SVD 手动实现与例程(下)(2)https://blog.csdn.net/sinat_25923849/article/details/114969930进入自己放工程的文件夹$git clone https://gitee.com/jqf64078/icp

那年Java那些事 01 — 环境变量和集成开发工具_用集成开发环境还需要配置环境变量吗-程序员宅基地

那年Java那些事 01 — 环境变量和集成开发工具俗语云:工欲善其事,必先利其器。想要学习好Java编程语言,配置相应的环境、安装相应的开发工具是必不可少的。那么,什么是开发环境,什么又是集成开发工具,怎样进行配置呢?下面将会跟大家一一谈到。一、环境变量的安装和配置(JDK)1.1 环境变量环境变量是什么呢?其实我们可以把它理解为【系统的视线范围】。没错,配置进入了环境变量的程序,就等于是进入了系统的视线范围,打开DOS命令窗口后输入程序名,系统就会把在其视线内的(环境变量内)的程_用集成开发环境还需要配置环境变量吗

获取微信所有聊天记录数据并通过Python制作词云图_python抓取微信群聊天记录_致橡树.的博客-程序员宅基地

获取微信所有聊天记录数据并通过Python制作词云图前言本文纯原创,仅供学习、交流使用。不具有任何商业用途,版权归作者所有,如有问题请及时联系我以作处理。作者仅为一名大二学生,能力有限,并且也是我在csdn上的第一篇文章,如果文章中有些问题还请指出,谢谢大家。如果对你有帮助,请记得好评噢。作者:做个财务自由的CEO项目效果图:接下来就是大家期待的过程了,具体如下:一、整个项目一共经历三个步骤:(1)首先获取微信聊天数据将其生成txt文本。(2)对获取的文本进行数据清._python抓取微信群聊天记录

推荐文章

热门文章

相关标签