Gradient Descent梯度下降算法代码实现_ZN_daydayup的博客-程序员宅基地

技术标签: 机器学习  

目录

摘要

梯度下降法原理

使用二次函数简单实现和验证梯度下降

画出梯度下降的过程

学习率对梯度下降快慢的影响

在线性回归模型中使用梯度下降法

Conclusion(总结)


摘要

使用代码实现和验证梯度下降算法

梯度下降法原理

梯度下降,gradient descent(之后将简称GD),是一种通过迭代找最优的方式一步步找到损失函数最小值的算法,基本算法思路可总结为如下几点:

(1) 随机设置一个初始值

(2) 计算损失函数的梯度

(3) 设置步长,步长的长短将会决定梯度下降的速度和准确度

(4) 将初值减去步长乘以梯度,更新初值,然后将这一过程不断迭代

 

使用二次函数简单实现和验证梯度下降

import numpy as np
import matplotlib.pyplot as plt

plot_x = np.linspace(-1,6,141)
plot_y = (plot_x - 2.5)**2 - 1
plt.plot(plot_x,plot_y)
plt.show()

 

def dJ(theta):
    return 2*(theta - 2.5)

def J(theta):
    try:
        return (theta - 2.5)**2-1
    except:
        return float('inf')

实验结果:

theta= 0.5
函数值= 3.0
第0次梯度下降.....
theta= 0.9
函数值= 1.5600000000000005
第1次梯度下降.....
theta= 1.2200000000000002
函数值= 0.6383999999999994
第2次梯度下降.....
theta= 1.4760000000000002
函数值= 0.04857599999999951
第3次梯度下降.....
theta= 1.6808
函数值= -0.3289113600000001
......
第42次梯度下降.....
theta= 2.4998638870532313
函数值= -0.9999999814732657
第43次梯度下降.....
最终theta= 2.499891109642585
最终函数值= -0.99999998814289
 

画出梯度下降的过程

theta = 0.0
theta_history = [theta]
while True:
    gradient = dJ(theta)
    last_theta = theta
    theta = theta - eta * gradient
    theta_history.append(theta)
    if(abs(J(theta) - J(last_theta)) < epsilon):
        break

plt.plot(plot_x,J(plot_x))
plt.plot(np.array(theta_history),J(np.array(theta_history)),color="r",marker="+")
plt.show()

学习率对梯度下降快慢的影响

当学习率eta = 0.001时的图像

initial_theta = 0
eta = 0.001
theta_history = []
gradient_descent(initial_theta,eta)
plot_theta_history()

 

当学习率eta = 0.8时的图像

initial_theta = 0
eta = 0.8
theta_history = []
gradient_descent(initial_theta,eta)
plot_theta_history()

 

当学习率eta = 1.1时的图像

initial_theta = 0
eta = 1.1
theta_history = []
gradient_descent(initial_theta,eta,n_iters=10)
plot_theta_history()

 

 

 

在线性回归模型中使用梯度下降法

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(666)
x = 2 * np.random.random(size=100)
y = x * 3. + 4. + np.random.normal(size=100)
X = x.reshape(-1, 1)
X.shape
plt.scatter(x, y)
plt.show()

 

使用梯度下降法训练

def J(theta, X_b, y):
    try:
        return np.sum((y - X_b.dot(theta))**2) / len(X_b)
    except:
        return float('inf')

def dJ(theta, X_b, y):
    res = np.empty(len(theta))
    res[0] = np.sum(X_b.dot(theta) - y)
    for i in range(1, len(theta)):
        res[i] = (X_b.dot(theta) - y).dot(X_b[:,i])
    return res * 2 / len(X_b)

def gradient_descent(X_b, y, initial_theta, eta, n_iters = 1e4, epsilon=1e-8):
    
    theta = initial_theta
    cur_iter = 0

    while cur_iter < n_iters:
        gradient = dJ(theta, X_b, y)
        last_theta = theta
        theta = theta - eta * gradient
        if(abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon):
            break
            
        cur_iter += 1

    return theta

X_b = np.hstack([np.ones((len(x), 1)), x.reshape(-1,1)])
initial_theta = np.zeros(X_b.shape[1])
eta = 0.01
theta = gradient_descent(X_b, y, initial_theta, eta)

实验结果:array([4.02145786, 3.00706277])

即截距 b = 4.02145786 ,斜率 a = 3.00706277  大致满足设置的函数 y = x * 3. + 4. + np.random.normal(size=100)

 

Conclusion(总结)

1.学习率eta取值影响获得最优解的速度

2.学习率eta取值不合适,甚至得不到最优解

3.学习率eta是一个超参数

4.使用梯度下降算法不一定能得到全局最优解,可能是局部最优解

5.可以多次运行,随机化初始点,多次尝试找出最优解

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/zn961018/article/details/116333668

智能推荐

android make-standalone-toolchain.sh 使用说明-程序员宅基地

#$ANDROID_NDK/build/tools/make-standalone-toolchain.sh --platform=android-24 --install-dir=./android-toolchain --ndk-dir=/Users/musictom/Library/Android/sdk/ndk-bundle/ --use-llvm#$ANDROID_NDK/bui...

实现不刷新整个页面进行前进后退_cef 不允许刷新跳转-程序员宅基地

在html5出来前,实现无刷新前进后退通常是结合location.hash+onhashchange事件来实现的;在html5出来后,可以使用h5 history api来实现无刷新前进后退点击浏览器的前进后退按钮时,只要将要进入的页面与当前页面不是同一个页面(因为有可能人为地添加了一条历史记录,但实际上还是同一个页面),那么将要进入的页面其所有请求(包括html、js、css、jso_cef 不允许刷新跳转

SQLServer 2005 之 海量数据解决方案 分区表-程序员宅基地

http://www.cnblogs.com/yizhu2000/archive/2007/12/13/992901.htmlCsdn Blog在2007年,由于访问量和数据量的大幅度增长,使得我们原有的在.text 0.96版本上修改的代码基本不堪重负。在数据库方面主要表现为,单单文章表,2007年1年的数据已经达到了30G的量(最后的解决方案是对把文章表分为两个表,分别存放文章...

异常&正则表达式习题作业-程序员宅基地

一、异常作业:1.简述什么是异常、异常的继承体系? 异常:在java程序运行过程中,出现的不正常的情况,出现的错误,称为异常。 Java异常体系 |——Throwable:可抛出的,异常的顶层父类,异常类都是它的子类(实现类描述java的错误和异常) |——Error 错误,用于描述那些无法捕获和处理的错误情况,属于非常严重的错误。 StackOver...

290. Word Pattern-程序员宅基地

Given a pattern and a string str, find if str follows the same pattern.Here follow means a full match, such that there is a bijection between a letter in pattern and a non-empty word in str.Exampl...

梯度,GD梯度下降,SGD随机梯度下降_u小鬼的博客-程序员宅基地

羊了,但是依旧生龙活虎。补补之前落下的SGD算法,这个在深度学习中应用广泛。梯度就是函数对所有单位向量求偏导构成的向量(方向),代表函数fff在定义空间RnR^nRn中的“增长率”。利用方向导数的定义,以及前面的定理,得∇uf(x)=∇f(x)⋅u=∣∣∇f(x)∣∣∣u∣∣cosα∇u​f(x)=∇f(x)⋅u=∣∣∇f(x)∣∣∣u∣∣cosαα\alphaα是∇uf。

随便推点

解决 支付宝沙箱环境测试 出现“沙箱订单信息有错误,建议联系卖家”_订单信息无法识别,建议联系卖家。-程序员宅基地

出现“沙箱订单信息有错误,建议联系卖家”问题看这里!-帖子详情-开放社区 (alipay.com)今天在测试的时候 怎么就连接不上支付宝沙箱的环境;他说合作协议到期,联系商户;我也提交了我的问题,应该很快就可以得到解决了;于是我就去官网进行查找解决方法;终于官方也随之回应;_订单信息无法识别,建议联系卖家。

php mysql 注入unhex_PHP+Mysql注入防护与绕过_weixin_39611340的博客-程序员宅基地

今天给大家分享一个关于php常见的注入防护以及如何bypass的文章,文章内容来源国外某大佬总结,我做了一下整理,文章来源地址不详,下面正文开始。以下的方式也仅仅是针对黑名单的过滤有一定的效果,为了安全最好还是以白名单的方式对参数进行检测。黑名单关键字过滤与绕过过滤关键字and、orPHP匹配函数代码如下:preg_match('/(and|or)/i', $id)如何Bypass,过滤注入测试语..._mysql hex php

内蒙古工业大学计算机网络试卷,内蒙古工业大学计算机网络试卷B-2007答案_NoviScl的博客-程序员宅基地

内蒙古工业大学2006——2007学年第二学期《计算机网络》期末考试试卷(B)参考答案及评分标准(课程代码:020203019)注意事项:1. 本试卷适用于2004级计算机科学与技术、自动化、电子信息工程、通信工程专业学生使用。2. 本试卷共9页,满分100分。答题时间120分钟。一、单选题(本大题共20道小题,每小题1分,共20分)1.以下哪一个选项顺序描述了计算机网络的五层体系结构( A )A...

IDEA mapper.xml无法编译问题-程序员宅基地

最近从eclipse换到了idea…刚换工具就踩了一个大坑…springboot整合mybatis问题:1.@autowired 报错 这个可以在设置中把@autowired 调成waring 或者直接忽略2.@autowired问题解决之后开始编译运行…出现Invalid bound statement(not found)的问题…从来没有出现过这些问题的我看的很懵逼的回去看是不是注解或者mapper.xml中的namespace出错…检查了好几遍发现没有出错…遂开始面向百度 Google编程…3

深入理解Azure自动扩展集VMSS(1)_weixin_30877755的博客-程序员宅基地

前文中已经详细介绍了如何配置和部署Azure的虚拟机扩展集VMSS进行自动扩展,但在实际使用过程当中,用户会出现更进一步使用的一些问题,VMSS基本扩展原理及怎么简单调试?如何进行手动扩展?怎么使用自定义镜像?在设计的时候有哪些最佳实践和考量等等。本文通过测试自动扩展功能开始,逐步介绍如下主题:VMSS自动扩展测试及告警规则配置VMSS中Autoscale基本原理及诊断VMSS...