【DenseFusion代码详解】测试过程eval_linemod.py_densefusion源码-程序员宅基地

技术标签: python  深度学习  pytorch  【DenseFusion详解】  

DenseFusion系列代码全讲解目录:【DenseFusion系列目录】代码全讲解+可视化+计算评估指标_Panpanpan!的博客-程序员宅基地

这些内容均为个人学习记录,欢迎大家提出错误一起讨论一起学习!


该部分是对LineMod数据集训练结束之后的模型进行评估,代码位置在tools/eval_linemod.py

训练部分包括train和test,评估过程是eval,eval和test的不同之处,浅浅理解就是test过程还是会改变权值,但eval固定权值不变。

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir')
parser.add_argument('--model', type=str, default = '',  help='resume PoseNet model')
parser.add_argument('--refine_model', type=str, default = '',  help='resume PoseRefineNet model')
opt = parser.parse_args()

首先是在.sh文件中可以设置的变量,有数据集路径、保存的训练好的PoseNet模型,保存的训练好的PoseRefineNet模型。

num_objects = 13
objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
num_points = 500
iteration = 4
bs = 1
dataset_config_dir = 'datasets/linemod/dataset_config'
output_result_dir = 'experiments/eval_result/linemod'
knn = KNearestNeighbor(1)

然后设置物体类别数、类别编号列表、点云数、refine过程的循环次数、批量大小、数据集config文件路径、输出结果路径、knn算法。

estimator = PoseNet(num_points = num_points, num_obj = num_objects)
estimator.cuda()
refiner = PoseRefineNet(num_points = num_points, num_obj = num_objects)
refiner.cuda()
estimator.load_state_dict(torch.load(opt.model))
refiner.load_state_dict(torch.load(opt.refine_model))
estimator.eval()
refiner.eval()

然后初始化estimator和refiner,再加载相应的保存好的模型,设为eval模式,也就是用训练好的模型来进行评估。

testdataset = PoseDataset_linemod('eval', num_points, False, opt.dataset_root, 0.0, True)
testdataloader = torch.utils.data.DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=10)

加载test数据集。

sym_list = testdataset.get_sym_list()
num_points_mesh = testdataset.get_num_points_mesh()
criterion = Loss(num_points_mesh, sym_list)
criterion_refine = Loss_refine(num_points_mesh, sym_list)

获取对称物体编号、mesh点数,定义loss计算和loss_refine计算。

diameter = []
meta_file = open('{0}/models_info.yml'.format(dataset_config_dir), 'r')
meta = yaml.load(meta_file)
for obj in objlist:
    diameter.append(meta[obj]['diameter'] / 1000.0 * 0.1)

这里获取每类物体三维模型外接球直接的直径,当计算出来的距离值小于该直接的10%,就认为姿态估计正确。

success_count = [0 for i in range(num_objects)]
num_count = [0 for i in range(num_objects)]
fw = open('{0}/eval_result_logs.txt'.format(output_result_dir), 'w')

记录正确估计的数量和总的数量,以及输出log文件。

下面进入循环:

for i, data in enumerate(testdataloader, 0):
    points, choose, img, target, model_points, idx = data
    if len(points.size()) == 2:
        print('No.{0} NOT Pass! Lost detection!'.format(i))
        fw.write('No.{0} NOT Pass! Lost detection!\n'.format(i))
        continue
    points, choose, img, target, model_points, idx = Variable(points).cuda(), \
                                                     Variable(choose).cuda(), \
                                                     Variable(img).cuda(), \
                                                     Variable(target).cuda(), \
                                                     Variable(model_points).cuda(), \
                                                     Variable(idx).cuda()

和train过程一样,获取预处理后的数据。

    pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)

estimator为训练好的PoseNet模型,计算出预测的R和t以及置信度。

    pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, num_points, 1)
    pred_c = pred_c.view(bs, num_points)
    how_max, which_max = torch.max(pred_c, 1)
    pred_t = pred_t.view(bs * num_points, 1, 3)

这里就是跟loss.py里面一样,首先对r进行标准化,然后选取置信度最大的像素,对R和t变化形状。

    my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
    my_t = (points.view(bs * num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()
    my_pred = np.append(my_r, my_t)

然后最终的R和t就是置信度最大的那个像素预测的结果,注意,这里的my_t加了points,也就是绝对的偏移,将他俩组合起来形成my_pred。

以上就是没有refine过程的评估,下面开始refine过程:

    for ite in range(0, iteration):
        T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(num_points, 1).contiguous().view(1, num_points, 3)
        my_mat = quaternion_matrix(my_r)
        R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
        my_mat[0:3, 3] = my_t

循环iteration次,获取PoseNet计算的T,这里my_r是四元数表示,quaternion_matrix()函数计算原始旋转矩阵(3*3),但返回4*4的矩阵,前[:3,:3]为旋转矩阵,结尾为1,其余为0,然后获取tensor形式的R,将my_mat的最后一行前三个数设为偏移t。

        new_points = torch.bmm((points - Tt), R).contiguous()
        pred_r, pred_t = refiner(new_points, emb, idx)
        pred_r = pred_r.view(1, 1, -1)
        pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
        my_r_2 = pred_r.view(-1).cpu().data.numpy()
        my_t_2 = pred_t.view(-1).cpu().data.numpy()
        my_mat_2 = quaternion_matrix(my_r_2)
        my_mat_2[0:3, 3] = my_t_2

这里跟loss_refiner.py的思想是一样的,用points计算逆转之后的new_points,然后输入到PoseRefineNet模型中纠正姿态,输出新预测的旋转和偏移。然后同样地将四元数表示转换成旋转矩阵表示,再用my_mat_2记录新预测的旋转和偏移。

        my_mat_final = np.dot(my_mat, my_mat_2)
        my_r_final = copy.deepcopy(my_mat_final)
        my_r_final[0:3, 3] = 0
        my_r_final = quaternion_from_matrix(my_r_final, True)
        my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])

第一行,将my_mat和my_mat_2相乘,这里,my_mat实际上是PoseNet预测的点云,my_mat_2相当于预测的新的姿态,相乘之后得到纠正的姿态。然后取出纠正的R和t,再将R转换成四元数表示。

        my_pred = np.append(my_r_final, my_t_final)
        my_r = my_r_final
        my_t = my_t_final

my_pred更新为refine过程之后的姿态。依次循环之后结束refine过程。

    model_points = model_points[0].cpu().detach().numpy()

    my_r = quaternion_matrix(my_r)[:3, :3]
    pred = np.dot(model_points, my_r.T) + my_t  
    target = target[0].cpu().detach().numpy()

获取model_points第一帧点云数据,然后将预测的my_r从四元数转换成旋转矩阵,和model_points相乘再加上偏移,得到目标点云target,这个target和标准目标点云是有区别的。

    if idx[0].item() in sym_list:
        pred = torch.from_numpy(pred.astype(np.float32)).cuda().transpose(1, 0).contiguous()
        target = torch.from_numpy(target.astype(np.float32)).cuda().transpose(1, 0).contiguous()
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        target = torch.index_select(target, 1, inds.view(-1) - 1)
        dis = torch.mean(torch.norm((pred.transpose(1, 0) - target.transpose(1, 0)), dim=1), dim=0).item()
    else:
        dis = np.mean(np.linalg.norm(pred - target, axis=1))

这一部分就是计算每个点的dis。如果是对称物体,则计算ADD-S,不是对称物体就计算ADD。

    if dis < diameter[idx[0].item()]:
        success_count[idx[0].item()] += 1
        print('No.{0} Pass! Distance: {1}'.format(i, dis))
        fw.write('No.{0} Pass! Distance: {1}\n'.format(i, dis))
    else:
        print('No.{0} NOT Pass! Distance: {1}'.format(i, dis))
        fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(i, dis))
    num_count[idx[0].item()] += 1

比较dis和直径的大小,如果小于之间,则认为姿态估计正确,success_count+1,否则估计错误。整个循环结束。

for i in range(num_objects):
    print('Object {0} success rate: {1}'.format(objlist[i], float(success_count[i]) / num_count[i]))
    fw.write('Object {0} success rate: {1}\n'.format(objlist[i], float(success_count[i]) / num_count[i]))
print('ALL success rate: {0}'.format(float(sum(success_count)) / sum(num_count)))
fw.write('ALL success rate: {0}\n'.format(float(sum(success_count)) / sum(num_count)))
fw.close()

最后,计算所有物体被正确估计的个数/总物体数量,得到准确率,并保存。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44564705/article/details/125110993

智能推荐

[Andrioid开发] Splash界面/用户协议与隐私政策弹窗/界面开发_android 隐私协议弹窗-程序员宅基地

文章浏览阅读4k次,点赞3次,收藏21次。[Andrioid开发] Splash界面/用户协议与隐私政策弹窗/界面开发启动页界面开发、首次启动时的启动页用户协议与隐私政策弹窗,只要不点击同意每次打开都会显示弹窗,同意后立即跳转到主界面,当下次再进入软件就是两秒后自动跳转到主界面。_android 隐私协议弹窗

java实现数字千分位的显示_java 数字显示k w-程序员宅基地

文章浏览阅读1.3w次。由于项目中要求输入的数字用千分位显示,数字保留两位小数,而且要求再删除数字的时候也要求删除后的数字也要是千分位显示,好像表达的有点不清楚,贴代码吧,作为一个小工具吧。 /** * 格式化数字为千分位显示; * @param 要格式化的数字; * @return */ public static String fmtMicrometer(String text)_java 数字显示k w

Android Paging library的本地数据Demo_android paging 本地静态数据-程序员宅基地

文章浏览阅读677次。分页库属于架构组件(Architecture Components)的一部分,配合RecyclerView使用,主要用来实现无感分页加载。官方文档链接为:https://developer.android.google.cn/topic/libraries/architecture/paging本文参照官方文档来做一个简单的实现,主要分以下几步:1、导库:def support_..._android paging 本地静态数据

使用stm32+esp8266-01s与电脑进行mqtt交互-程序员宅基地

文章浏览阅读1.2w次,点赞7次,收藏55次。注意:本文以简单易理解易实现为主,仅实现最基本的交互通信功能,性能和稳定性暂无考虑。需要材料:硬件:stm32及下载线、esp8266-01s(wifi模块)软件:emqx、keil可选:wireshark,python开始:配置stm32工程首先,我们需要一个stm32的基础工程,为了调试需要,我们需要两个usart串口分别与电脑和wifi模块进行通信。打开stm32cube 需要配置的有RCC、SYS、USART、时钟、project manager几部分。.

Android 读取外设U盘(USB)文件。-程序员宅基地

文章浏览阅读3.9k次,点赞2次,收藏9次。在AndroidManifest文件中添加USB权限 <uses-permission android:name="android.permission.USB_PERMISSION" /> <uses-permission android:name="android.permission.MANAGE_USB"/> <uses-feature android:name="android.hardware.usb.host" android:requ...

Select模型原理_select模型通知机制-程序员宅基地

文章浏览阅读323次。Select模型原理利用select函数,判断套接字上是否存在数据,或者能否向一个套接字写入数据。目的是防止应用程序在套接字处于锁定模式时,调用recv(或send)从没有数据的套接字上接收数据,被迫进入阻塞状态。 select参数和返回值意义如下:int select ( IN int nfds, //0,无意义 IN_select模型通知机制

随便推点

2018.11.6 PION 模拟赛-程序员宅基地

文章浏览阅读69次。期望:100 + 40 + 50 = 190实际:60 + 10 + 50 = 120考得好炸啊!!T1数组开小了炸掉40,T2用 int 读入 long long ,int存储 long long 炸掉 20T3可以吧for维护最大值变成o(1),但是木想到啊,只想写暴力了。。。w(゚Д゚)w最近考试低级错误一个接一个啊!!noip肿么玩啊。。简直没法好好玩耍了。感..._期望的线性性质,考虑每个数的贡献,一个数能够做出1的贡献当且仅当它被自己删掉,所

工厂模式(初学)-程序员宅基地

文章浏览阅读259次。是一种创建型设计模式,旨在通过一个工厂类(简单工厂)来封装对象的实例化过程

[MFC] CWnd类总结-程序员宅基地

文章浏览阅读7.8k次,点赞10次,收藏73次。一、MFC 类别阶层架构二、CWnd类CWnd是MFC的一个窗口类,这个类里几乎封装了所有关于窗口操作的API函数。在Windows系统里,一个窗口的属性分两个地方存放:一部分放在“窗口类”里头,如上所述的在注册窗口时指定;另一部分放在Windows Object本身,如:窗口的尺寸,窗口的位置(X,Y轴),窗口的Z轴顺序,窗口的状态(ACTIVE,MINIMIZED,MAXM..._cwnd

一、数据可视化之堆叠面积图 - Stacked Area Graph-程序员宅基地

文章浏览阅读6.6k次。堆叠面积图把研究整体的演化和各个群体的相对比例变简单!_堆叠面积图

适用于 Windows 10 的触摸板手势(from Microsoft 帮助)附双指右击无法使用的处理方法(ELAN)_elan 需要重启-程序员宅基地

文章浏览阅读2.6k次,点赞2次,收藏3次。在 Windows 10 笔记本电脑的触摸板上试用这些手势。选择项目:点击触摸板。 滚动:将两个手指放在触摸板上,然后以水平或垂直方向滑动。 放大或缩小:将两个手指放在触摸板上,然后收缩或拉伸。 显示更多命令(类似于右键单击):使用两根手指点击触摸板,或按右下角。 查看所有打开的窗口:将三根手指放在触摸板上,然后朝外轻扫。 显示桌面:将三根手指放在触摸板上,然后朝里轻扫。 在打开的窗口之间切换:将三根手指放在触摸板上,然后向右或向左轻扫。 打开 Cortana:用三根手指点击触摸板。 ._elan 需要重启

Linux版awvs最新版v_190325161的安装记录_awvs_linux.zip-程序员宅基地

文章浏览阅读6.2k次,点赞2次,收藏4次。因为之前52的安装教程贴被删除了所以我自己重新记录一下方便以后的使用也是自己用的环境 ----我用的是2019的ubantu和Xshell下载地址安装环境依赖(如果有问题先更新一下源)root@kali:~# sudo apt-get install libxdamage1 libgtk-3-0 libasound2 libnss3 libxss1 -y正在读取软件包列表… ..._awvs_linux.zip

推荐文章

热门文章

相关标签