YOLO v3 中关于 anchor 的 k-means 聚类代码_yolov3 anchor_轮子去哪儿了的博客-程序员秘密

技术标签: YOLO  

1. k-means 聚类代码

我使用的代码是:https://github.com/lars76/kmeans-anchor-boxes
其他的k-means 代码(没用过)是:

  1. https://github.com/qqwweee/keras-yolo3/blob/master/kmeans.py
  2. https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
    输入是存放 xml标签文件的文件夹:
    只需要更改 example.py 中的一行代码:
ANNOTATIONS_PATH = "xmlLabel/train"  # 更改自己的路径(存放训练标签 xml 的文件路径)
  1. 运行 example.py 计算当前数据集的需要设置的 anchor 的大小(相对于416输入而言)

在我的数据集上的输出结果如下:

rows =  8607  #  我的 label 目标的数量
[[0.01416016 0.015625  ]  # 每一个 anchor的宽/图像的宽 ,高/高
 [0.00830078 0.00927734]
 [0.06542969 0.06982422]
 [0.03417969 0.03662109]
 [0.01123047 0.01220703]
 [0.02685547 0.02832031]
 [0.01757812 0.01953125]
 [0.04443359 0.04833984]
 [0.02148438 0.0234375 ]]
Accuracy: 83.41%
Boxes:
 [ 5.890625  3.453125 27.21875  14.21875   4.671875 11.171875  7.3125   18.484375  8.9375  ]-  # 每个 anchor 的宽
 [ 6.5       3.859375 29.046875 15.234375  5.078125 11.78125   8.125    20.109375  9.75    ]  # # 每个 anchor 的高
Ratios:
 [0.89, 0.9, 0.91, 0.92, 0.92, 0.92, 0.93, 0.94, 0.95]  # 每个 anchor 的 宽/高
  1. 对输入anchor 进行排序后的结果是(取整数是为了好看):
[3, 4, 5, 7, 8, 11, 14, 18, 27]
[3, 5, 6, 8, 9, 11, 15, 20, 29]
anchor_416 = 3, 3, 4, 5, 5, 6,   7, 8, 8, 9, 11, 11,   14, 15, 18, 20, 27, 29 
anchor_416_2 = 6, 7, 9, 10, 11, 13,   14, 16, 17, 19, 22, 23,   28, 30, 36, 40, 54, 58
anchor_416_3 = 10, 11, 14, 15, 17, 19,   21, 24, 26, 29, 33, 35,  42, 45, 55, 60, 81, 87
anchor_416_4 = 13, 15, 18, 20, 23, 26,   29,32, 35,39, 44,47,  56,60, 73,80, 108,116
anchor_416_5 = 17, 19, 23, 25, 29, 32,   36, 40, 44, 48, 55, 58,   71, 76, 92, 100, 136, 145

将 anchor 排序的代码如下(自己写的):

import numpy as np 

# anchors = [10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326]
# for i in range(0, len(anchors), 2):
#   print(anchors[i] * anchors[i + 1])


x = [5.890625,  3.453125, 27.21875,  14.21875,   4.671875, 11.171875,  7.3125,   18.484375,  8.9375]
y = [6.5,       3.859375, 29.046875, 15.234375,  5.078125, 11.78125,   8.125,    20.109375,  9.75 ]     
area = []

for i in range(len(x)):
    area.append(x[i] * y[i])

print(area)
print(np.argsort(area))

new_x = [0 for _ in range(len(x))]
new_y = [0 for _ in range(len(y))]

for i in range(len(np.argsort(area))):
    new_x[i] = int(x[np.argsort(area)[i]])
    new_y[i] = int(y[np.argsort(area)[i]])

anchors = []
for i in range(len(new_x)):
	anchors.append(new_x[i])
	anchors.append(new_y[i])

print(anchors)


for i in range(len(new_x)):
    print(new_x[i] * new_y[i])

2. YOLOv3 中默认的 anchor

anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326

一共有 18个数字,9个anchor,每一个anchor的大小,面积依次是:

130, 480,759,     1830, 2790, 7021,     10440, 30888, 121598

3. github 上的代码复制如下:

example.py

import glob
import xml.etree.ElementTree as ET

import numpy as np

from kmeans import kmeans, avg_iou

ANNOTATIONS_PATH = "Annotations"
CLUSTERS = 5

def load_dataset(path):
	dataset = []
	for xml_file in glob.glob("{}/*xml".format(path)):
		tree = ET.parse(xml_file)

		height = int(tree.findtext("./size/height"))
		width = int(tree.findtext("./size/width"))

		for obj in tree.iter("object"):
			xmin = int(float(obj.findtext("bndbox/xmin"))) / width
			ymin = int(float(obj.findtext("bndbox/ymin"))) / height
			xmax = int(float(obj.findtext("bndbox/xmax"))) / width
			ymax = int(float(obj.findtext("bndbox/ymax"))) / height

			dataset.append([xmax - xmin, ymax - ymin])

	return np.array(dataset)


data = load_dataset(ANNOTATIONS_PATH)
out = kmeans(data, k=CLUSTERS)
print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))
print("Boxes:\n {}".format(out))

ratios = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()
print("Ratios:\n {}".format(sorted(ratios)))

kmeans.py

import numpy as np


def iou(box, clusters):
    """
    Calculates the Intersection over Union (IoU) between a box and k clusters.
    :param box: tuple or array, shifted to the origin (i. e. width and height)
    :param clusters: numpy array of shape (k, 2) where k is the number of clusters
    :return: numpy array of shape (k, 0) where k is the number of clusters
    """
    x = np.minimum(clusters[:, 0], box[0])
    y = np.minimum(clusters[:, 1], box[1])
    if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:
        raise ValueError("Box has no area")

    intersection = x * y
    box_area = box[0] * box[1]
    cluster_area = clusters[:, 0] * clusters[:, 1]

    iou_ = intersection / (box_area + cluster_area - intersection)

    return iou_


def avg_iou(boxes, clusters):
    """
    Calculates the average Intersection over Union (IoU) between a numpy array of boxes and k clusters.
    :param boxes: numpy array of shape (r, 2), where r is the number of rows
    :param clusters: numpy array of shape (k, 2) where k is the number of clusters
    :return: average IoU as a single float
    """
    return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])])


def translate_boxes(boxes):
    """
    Translates all the boxes to the origin.
    :param boxes: numpy array of shape (r, 4)
    :return: numpy array of shape (r, 2)
    """
    new_boxes = boxes.copy()
    for row in range(new_boxes.shape[0]):
        new_boxes[row][2] = np.abs(new_boxes[row][2] - new_boxes[row][0])
        new_boxes[row][3] = np.abs(new_boxes[row][3] - new_boxes[row][1])
    return np.delete(new_boxes, [0, 1], axis=1)


def kmeans(boxes, k, dist=np.median):
    """
    Calculates k-means clustering with the Intersection over Union (IoU) metric.
    :param boxes: numpy array of shape (r, 2), where r is the number of rows
    :param k: number of clusters
    :param dist: distance function
    :return: numpy array of shape (k, 2)
    """
    rows = boxes.shape[0]

    distances = np.empty((rows, k))
    last_clusters = np.zeros((rows,))

    np.random.seed()

    # the Forgy method will fail if the whole array contains the same rows
    clusters = boxes[np.random.choice(rows, k, replace=False)]

    while True:
        for row in range(rows):
            distances[row] = 1 - iou(boxes[row], clusters)

        nearest_clusters = np.argmin(distances, axis=1)

        if (last_clusters == nearest_clusters).all():
            break

        for cluster in range(k):
            clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)

        last_clusters = nearest_clusters

    return clusters

4. 有用请点赞,谢谢

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_42419002/article/details/100566926

智能推荐

通用方法 解决/usr/lib64/libstdc++.so.6: version `CXXABI_1.3.8‘ not found的问题_华章酱的博客-程序员秘密

1.先查看当前Linux服务器gcc版本中包含哪些库。(注意:要区别当前Linux版本是32还是64,下面的操作是查看64的,下载文件也需要用64)strings /usr/lib64/libstdc++.so.6 | grep GLIBCstrings /usr/lib64/libstdc++.so.6|grep CXXABI以上只要缺少对应的版本,都可通过安装对应缺失的libstdc++.so.6.0.13以上的版本,来解决缺失版本的问题。当前截图中是高版本的libst.

IE-LAB网络实验室:PPP 协议PAP和CHAP认证配置_pap的两次握手报文分别是什么_ielab悦然的博客-程序员秘密

PPP(点到点协议)协议:是为在同等单元之间传输数据包这样的简单链路设计的链路层协议。这种链路提供全双工操作,并按照顺序传递数据包。设计目的主要是用来通过拨号或专线方式建立点对点连接发送数据,使其成为各种主机、网桥和路由器之间简单连接的一种共通的解决方案。PAP为两次握手认证,口令为明文,验证过程尽在链路初始建立阶段进行。PAP不是一种安全的验证协议,因为口令是以明文的方式在链路上发送,而且用户...

Android 图片加载框架Glide4.0源码完全解析(一)_管满满的博客-程序员秘密

上一篇博文写的是Picasso基本使用和源码完全解析,Picasso的源码阅读起来还是很顺畅的,然后就想到Glide框架,网上大家也都推荐使用这个框架用来加载图片,正好我目前的写作目标也是分析当前一些流行的框架源码,那就也来解析下Glide的源码吧,而且有了Picasso源码的分析相信很快就搞定Glide的,结果也就悲剧了,深陷其中无法自拔了,Glide的源码远非Picasso能比,阅读起来

Linux 如何从网上下载文件_zeroxes的博客-程序员秘密

将网络上的文件下载到使用 Linux 操作系统的计算机上,需要用到 wget <url> 指令,使用该指令可能会面临两个问题。首先,如何获取文件的下载 url?这需要你在浏览器上找到要下载文件的链接地址,然后右键 -> 复制链接地址,既可获取该文件的下载 url。以阿里提供的 Centos7 镜像为例,如下图:在 Linux 系统上使用 wget 指令即可下载该文件...

(void(*)(void))func()的解读_((void(*)())func)()_gl23838的博客-程序员秘密

根据 Andrew Koening在他的《C 陷阱与缺陷》里对(*(void (*)( ) )0)( )的分析得到以下结论。1.如何声明一个变量?float *g():g是一个函数,他的返回值是一个指针,该指针指向一个float数。     float (*h)(float):h是一个指针,该指针指向一个函数,这个函数的返回值是float类型的,这个函数的参数是float类

es自定义排序_LY笔记的博客-程序员秘密

## 排序### 一、默认排序规则默认情况下,是按照_score降序排序。_score使用的算法,计算出一个索引中的文本,与搜索文本,他们之间的关联匹配程度es使用的是,term frequency和inverse documnet frequency算法,简称为TF/IDF算法term frequency:搜索文本中的各个词条在field文本中出现了多少次,出现次数越多,分数越高inverse documnet frequency:搜索文本中的各个词条在整个索引的所有文档中出现了多少次,出

随便推点

基于 Oauth 2.0 的第三方账号登录实现_weixin_34375251的博客-程序员秘密

为什么80%的码农都做不了架构师?>>> ...

nginx反向代理模块配置详解_nginx做反向代理配置文件的例子_weixin_39859909的博客-程序员秘密

worker_processes 2;error_log logs/error.log;#error_log logs/error.log notice;#error_log logs/error.log info;pid logs/nginx.pid;events {use epoll;worker_connections2048;}http {include ...

linux设备驱动归纳总结(六):3.中断下半部之tasklet_我是黏黏虫的博客-程序员秘密

http://blog.chinaunix.net/uid-25014876-id-100005.html一、什么是下半部中断是一个很霸道的东西,处理器一旦接收到中断,就会打断正在执行的代码,调用中断处理函数。如果在中断处理函数中没有禁止中断,该中断处理函数执行过程中仍有可能被其他中断打断。出于这样的原因,大家都希望中断处理函数执行得越快越好。另外,中断上下

Web开发面临的挑战主要有哪些?_lisky119的博客-程序员秘密

[探讨]Web开发面临的挑战主要有哪些?2011-12-27 14:17 | 2175次阅读 | 来源:CSDN整理自知乎网 【已有8条评论】发表评论关键词:开发,Web |作者:夏梦竹 | 收藏这篇资讯导读:要成为一名高效的Web开发者,这需要我们做很多工作,来提高我们的工作方式,以及改善我们的劳动成果。而在开发中难免会遇到一些困难,从前端到后端,近日

提取VOCtrainval_11-May-2012数据集中特定样本_guaguablue的博客-程序员秘密

前言机器学习常用的PascalVOC数据集,此不做介绍,下载连接中有描述。下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html目的由于下载完的原始数据集各个分类的图片和标签都放在同一个文件夹,面对近2万张图片,人工区分选择是不可能的。但是如果我们想要在其中选取出比如有关bottle分类的图片和标签文件,VOC2012...

字符串指针与char型指针数组_HAN-Kai的博客-程序员秘密

一、字符串指针字符串是一种特殊的char型数组,指向char类型数组的指针,就是字符串指针。与普通指针一样,字符串指针在使用前也必须定义。字符串与char数组的区别在于长度,字符会自动在尾部加上一个长度‘\0’,而char型数组的长度就是其字符的个数。字符串长度是字符个数+1。例:#includeusing namespace std;int main(){ char s