2 基于梯度的攻击——PGD_weixin_30379531的博客-程序员秘密

技术标签: python  人工智能  

PGD攻击原论文地址——https://arxiv.org/pdf/1706.06083.pdf

1.PGD攻击的原理

  PGD(Project Gradient Descent)攻击是一种迭代攻击,可以看作是FGSM的翻版——K-FGSM (K表示迭代的次数),大概的思路就是,FGSM是仅仅做一次迭代,走一大步,而PGD是做多次迭代,每次走一小步,每次迭代都会将扰动clip到规定范围内。

 

 

一般来说,PGD的攻击效果比FGSM要好。首先,如果目标模型是一个线性模型,那么用FGSM就可以了,因为此时loss对输入的导数是固定的,换言之,使得loss下降的方向是明确的,即使你多次迭代,扰动的方向也不会改变。而对于一个非线性模型,仅仅做一次迭代,方向是不一定完全正确的,这也是为什么FGSM的效果一般的原因了。

上图中,黑圈是输入样本,假设样本只有两维,那么样本可以改变的就有八个方向,坐标系中显示了loss等高线,以及可以扰动的最大范围(因为是无穷范数,所以限制范围是一个方形,负半轴的范围没有画出来),黑圈每一次改变,都是以最优的方向改变,最后一次由于扰动超出了限制,所以直接截断,如果此时迭代次数没有用完,那么就在截断处继续迭代,直到迭代次数用完。

2.PGD的代码实现

 

class PGD(nn.Module):
    def __init__(self,model):
        super().__init__()
        self.model=model#必须是pytorch的model
        self.device=torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    def generate(self,x,**params):
        self.parse_params(**params)
        labels=self.y

        adv_x=self.attack(x,labels)
        return adv_x
    def parse_params(self,eps=0.3,iter_eps=0.01,nb_iter=40,clip_min=0.0,clip_max=1.0,C=0.0,
                     y=None,ord=np.inf,rand_init=True,flag_target=False):
        self.eps=eps
        self.iter_eps=iter_eps
        self.nb_iter=nb_iter
        self.clip_min=clip_min
        self.clip_max=clip_max
        self.y=y
        self.ord=ord
        self.rand_init=rand_init
        self.model.to(self.device)
        self.flag_target=flag_target
        self.C=C


    def sigle_step_attack(self,x,pertubation,labels):
        adv_x=x+pertubation
        # get the gradient of x
        adv_x=Variable(adv_x)
        adv_x.requires_grad = True
        loss_func=nn.CrossEntropyLoss()
        preds=self.model(adv_x)
        if self.flag_target:
            loss =-loss_func(preds,labels)
        else:
            loss=loss_func(preds,labels)
            # label_mask=torch_one_hot(labels)
            #
            # correct_logit=torch.mean(torch.sum(label_mask * preds,dim=1))
            # wrong_logit = torch.mean(torch.max((1 - label_mask) * preds, dim=1)[0])
            # loss=-F.relu(correct_logit-wrong_logit+self.C)

        self.model.zero_grad()
        loss.backward()
        grad=adv_x.grad.data
        #get the pertubation of an iter_eps
        pertubation=self.iter_eps*np.sign(grad)
        adv_x=adv_x.cpu().detach().numpy()+pertubation.cpu().numpy()
        x=x.cpu().detach().numpy()

        pertubation=np.clip(adv_x,self.clip_min,self.clip_max)-x
        pertubation=clip_pertubation(pertubation,self.ord,self.eps)


        return pertubation
    def attack(self,x,labels):
        labels = labels.to(self.device)
        print(self.rand_init)
        if self.rand_init:
            x_tmp=x+torch.Tensor(np.random.uniform(-self.eps, self.eps, x.shape)).type_as(x).cuda()
        else:
            x_tmp=x
        pertubation=torch.zeros(x.shape).type_as(x).to(self.device)
        for i in range(self.nb_iter):
            pertubation=self.sigle_step_attack(x_tmp,pertubation=pertubation,labels=labels)
            pertubation=torch.Tensor(pertubation).type_as(x).to(self.device)
        adv_x=x+pertubation
        adv_x=adv_x.cpu().detach().numpy()

        adv_x=np.clip(adv_x,self.clip_min,self.clip_max)

        return adv_x

  

PGD攻击的参数并不多,比较重要的就是下面这几个:

eps: maximum distortion of adversarial example compared to original input

eps_iter: step size for each attack iteration

nb_iter: Number of attack iterations.

上面代码中注释的这行代码是CW攻击的PGD形式,这个在防御论文https://arxiv.org/pdf/1706.06083.pdf中有体现,以后说到CW攻击再细说。

 

1 # label_mask=torch_one_hot(labels)
2 #
3 # correct_logit=torch.mean(torch.sum(label_mask * preds,dim=1))
4 # wrong_logit = torch.mean(torch.max((1 - label_mask) * preds, dim=1)[0])
5 # loss=-F.relu(correct_logit-wrong_logit+self.C)

 

最后再提一点就是,在上面那篇防御论文中也提到了,PGD攻击是最强的一阶攻击,如果防御方法对这个攻击能够有很好的防御效果,那么其他攻击也不在话下了。

转载于:https://www.cnblogs.com/shona/p/11274393.html

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_30379531/article/details/98051357

智能推荐

JDBC的学习_a1151571的博客-程序员秘密

1.什么是JDBC 概念:Java DataBase Connectivity (java数据库连接),通过java语言操作数据库 本质:其实是官方(sun公司)定义的一套操作所有关系型数据库的规则,即接口。各个数据库厂商去实现这套接口,提供数据库驱动jar包。我们可以使用这套接口(JDBC)编程,真正执行的代码是驱动jar包中的实现类2.开始入门 步骤...

爬虫入门之查找JS入口篇(二)_flybirding1001的博客-程序员秘密

各位老铁们,当你划到这个视频,请帮忙点个双击加个关注,感谢老铁们。。。。额。。。不好意思,走错片场了,我们今天继续来将如何查找JS的入口,话不多说,开始吧。目标地址:ht...

Sentinel 与 Hystrix 的对比_sentinel与_想跑步丶小胖子的博客-程序员秘密

Sentinel 与 Hystrix 的对比摘要: [Sentinel](https://github.com/alibaba/Sentinel) 是阿里中间件团队研发的面向分布式服务架构的轻量级高可用流量控制组件,最近正式开源。Sentinel 主要以流量为切入点,从流量控制、熔断降级、系统负载保护等多个维度来帮助用户保护服务的稳定性。大家可能会问:Sentinel 和之前常用的熔断降级库 ...

java multiset_C++ multiset用法详解(附带完整示例)_weixin_39942400的博客-程序员秘密

multiset 容器就像 set 容器,但它可以保存重复的元素。这意味我们总可以插入元素,当然必须是可接受的元素类型。默认用 less 来比较元素,但也可以指定不同的比较函数。在元素等价时,它必须返回 false。例如:std::multiset> words{{"dog", "cat", "mouse"}, std::greater()};这条语句定义了一个以 string 为元素的 m...

spring cloud gateway性能优化_编码牛的博客-程序员秘密

spring gateway 网关压测,瓶颈代码分析,性能优化

随便推点

vscode c++安装与单文件多文件编译配置(win10)_vscode配置c++多文件_岁月歌者BC的博客-程序员秘密

vscode c++安装与单文件多文件编译配置(win10)总体思路:1下载Vscode ,mingw,cmake(用于多文件编译)2配置​ 1>mingw,vscode,cmake环境变量​ 2>vscode​ 插件Chinese,code runner,C/C++,cmake,cmake tools等必须插件​ 通过一定的调试程序进行配置(编译单文件、多文件的配置操作不同)注:mingw可以通过下载DEV-

推荐一款Gin+Vue+ElementUI实现的智慧城市后台管理系统_半城 风雨的博客-程序员秘密

是一款基于Golang、Gin、Xorm、Vue、ElementUI、MySQL等技术栈开发平台框架,拥有完善的(RBAC)权限架构和基础核心管理模块,为了缩短研发周期,系统框架集成了代码生成器,内置平台自定义研发的模板引擎,可以一键CRUD生成整个模块的全部代码,本框架为一站式系统框架开发平台,可以帮助开发者提升开发效率、降低研发成本,同时便于后期的系统维护升级。......

何为大型机、中型机、小型机_计算机中型机举例_Jessie_Zhang的博客-程序员秘密

<br />大型机(Mainframe)<br /><br />  大型机(mainframe)这个词,最初是指装在非常大的带框铁盒子里的大型计算机系统,以用来同小一些的迷你机和微型机有所区别。虽然这个词已经通过不同方式被使用了很多年,大多数时候它却是指 system/360 开始的一系列的IBM计算机。这个词也可以用来指由其他厂商,如Amdahl, Hitachi Data Systems (HDS) 制造的兼容的系统。<br />  有些人用这个词来指IBM的AS/400 或者iSeries 系统,这种

又有程序猿倒下了(华为)!这是有多累啊?再一次输给了生活...._weixin_33755847的博客-程序员秘密

事例一:震惊IT界天才程序员跳楼自杀事件,最杀人的还是人心2017年9月7日凌晨5点左右,WePhone创始人兼开发者苏享茂,在公司附近的住所处跳楼自杀。据悉是不堪前妻翟欣欣的辱骂、威胁、恐吓以及千万巨额索赔。事例二:痛心!中兴42岁程序员跳楼身亡,是什么把他逼上了...

Fsm-----(米里)状态机实现"11101"序列检测(modelsim用脚本语言实现仿真),(米里和摩尔都可以)_米里序列检测器实验原理_寐语者的博客-程序员秘密

//米里状态机进行“11101”序列检测module mealy( input wire clk, input wire rst_n, input wire A, output reg K);//5个数的序列检测,定义6个状态parameter S1=6'b00_0001;parameter S2=6'b00_0010;parameter S3=6'b00_0100...

推荐文章

热门文章

相关标签