self.setinterval,常见的web安全漏洞有哪些

伏羲号

self.setinterval,常见的web安全漏洞有哪些?

前言:

self.setinterval,常见的web安全漏洞有哪些

在互联网时代,数据安全与个人隐私受到了前所未有的挑战,各种新奇的攻击技术层出不穷。如何才能更好地保护我们的数据?本文主要侧重于分析几种常见的攻击的类型以及防御的方法。

一、XSS

XSS (Cross-Site Scripting),跨站脚本攻击,因为缩写和 CSS重叠,所以只能叫 XSS。跨站脚本攻击是指通过存在安全漏洞的Web网站注册用户的浏览器内运行非法的HTML标签或JavaScript进行的一种攻击。

跨站脚本攻击有可能造成以下影响:

利用虚假输入表单骗取用户个人信息。利用脚本窃取用户的Cookie值,被害者在不知情的情况下,帮助攻击者发送恶意请求。显示伪造的文章或图片。

XSS 的原理是恶意攻击者往 Web 页面里插入恶意可执行网页脚本代码,当用户浏览该页之时,嵌入其中 Web 里面的脚本代码会被执行,从而可以达到攻击者盗取用户信息或其他侵犯用户安全隐私的目的。

XSS 的攻击方式千变万化,但还是可以大致细分为几种类型。

1.非持久型 XSS(反射型 XSS )

非持久型 XSS 漏洞,一般是通过给别人发送带有恶意脚本代码参数的 URL,当 URL 地址被打开时,特有的恶意代码参数被 HTML 解析、执行。

举一个例子,比如页面中包含有以下代码:

<select> <script> document.write('' + '<option value=1>' + location.href.substring(location.href.indexOf('default=') + 8) + '</option>' ); document.write('<option value=2>English</option>'); </script> </select>

攻击者可以直接通过 URL (类似:https://xxx.com/xxx?default=<script>alert(document.cookie)</script>) 注入可执行的脚本代码。不过一些浏览器如Chrome其内置了一些XSS过滤器,可以防止大部分反射型XSS攻击。

非持久型 XSS 漏洞攻击有以下几点特征:

即时性,不经过服务器存储,直接通过 HTTP 的 GET 和 POST 请求就能完成一次攻击,拿到用户隐私数据。攻击者需要诱骗点击,必须要通过用户点击链接才能发起反馈率低,所以较难发现和响应修复盗取用户敏感保密信息

为了防止出现非持久型 XSS 漏洞,需要确保这么几件事情:

Web 页面渲染的所有内容或者渲染的数据都必须来自于服务端。尽量不要从 URL,document.referrer,document.forms 等这种 DOM API 中获取数据直接渲染。尽量不要使用 eval, new Function(),document.write(),document.writeln(),window.setInterval(),window.setTimeout(),innerHTML,document.createElement() 等可执行字符串的方法。如果做不到以上几点,也必须对涉及 DOM 渲染的方法传入的字符串参数做 escape 转义。前端渲染的时候对任何的字段都需要做 escape 转义编码。

2.持久型 XSS(存储型 XSS)

持久型 XSS 漏洞,一般存在于 Form 表单提交等交互功能,如文章留言,提交文本信息等,黑客利用的 XSS 漏洞,将内容经正常功能提交进入数据库持久保存,当前端页面获得后端从数据库中读出的注入代码时,恰好将其渲染执行。

举个例子,对于评论功能来说,就得防范持久型 XSS 攻击,因为我可以在评论中输入以下内容

主要注入页面方式和非持久型 XSS 漏洞类似,只不过持久型的不是来源于 URL,referer,forms 等,而是来源于后端从数据库中读出来的数据 。持久型 XSS 攻击不需要诱骗点击,黑客只需要在提交表单的地方完成注入即可,但是这种 XSS 攻击的成本相对还是很高。

攻击成功需要同时满足以下几个条件:

POST 请求提交表单后端没做转义直接入库。后端从数据库中取出数据没做转义直接输出给前端。前端拿到后端数据没做转义直接渲染成 DOM。

持久型 XSS 有以下几个特点:

持久性,植入在数据库中盗取用户敏感私密信息危害面广

3.如何防御

对于 XSS 攻击来说,通常有两种方式可以用来防御。

1) CSP

CSP 本质上就是建立白名单,开发者明确告诉浏览器哪些外部资源可以加载和执行。我们只需要配置规则,如何拦截是由浏览器自己实现的。我们可以通过这种方式来尽量减少 XSS 攻击。

通常可以通过两种方式来开启 CSP:

设置 HTTP Header 中的 Content-Security-Policy设置 meta 标签的方式

这里以设置 HTTP Header 来举例:

只允许加载本站资源

Content-Security-Policy: default-src 'self'

只允许加载 HTTPS 协议图片

Content-Security-Policy: img-src https://*

允许加载任何来源框架

Content-Security-Policy: child-src 'none'

如需了解更多属性,请查看Content-Security-Policy文档

对于这种方式来说,只要开发者配置了正确的规则,那么即使网站存在漏洞,攻击者也不能执行它的攻击代码,并且 CSP 的兼容性也不错。

2) 转义字符

用户的输入永远不可信任的,最普遍的做法就是转义输入输出的内容,对于引号、尖括号、斜杠进行转义

function escape(str) { str = str.replace(/&/g, '&') str = str.replace(/</g, '<') str = str.replace(/>/g, '>') str = str.replace(/"/g, '&quto;') str = str.replace(/'/g, ''') str = str.replace(/`/g, '`') str = str.replace(/\//g, '/') return str }

但是对于显示富文本来说,显然不能通过上面的办法来转义所有字符,因为这样会把需要的格式也过滤掉。对于这种情况,通常采用白名单过滤的办法,当然也可以通过黑名单过滤,但是考虑到需要过滤的标签和标签属性实在太多,更加推荐使用白名单的方式。

const xss = require('xss') let html = xss('<h1 id="title">XSS Demo</h1><script>alert("xss");</script>') // -> <h1>XSS Demo</h1><script>alert("xss");</script> console.log(html)

以上示例使用了 js-xss 来实现,可以看到在输出中保留了 h1 标签且过滤了 script 标签。

3) HttpOnly Cookie。

这是预防XSS攻击窃取用户cookie最有效的防御手段。Web应用程序在设置cookie时,将其属性设为HttpOnly,就可以避免该网页的cookie被客户端恶意JavaScript窃取,保护用户cookie信息。

二、CSRF

CSRF(Cross Site Request Forgery),即跨站请求伪造,是一种常见的Web攻击,它利用用户已登录的身份,在用户毫不知情的情况下,以用户的名义完成非法操作。

1. CSRF攻击的原理

下面先介绍一下CSRF攻击的原理:

完成 CSRF 攻击必须要有三个条件:

用户已经登录了站点 A,并在本地记录了 cookie在用户没有登出站点 A 的情况下(也就是 cookie 生效的情况下),访问了恶意攻击者提供的引诱危险站点 B (B 站点要求访问站点A)。站点 A 没有做任何 CSRF 防御

我们来看一个例子: 当我们登入转账页面后,突然眼前一亮惊现"XXX隐私照片,不看后悔一辈子"的链接,耐不住内心躁动,立马点击了该危险的网站(页面代码如下图所示),但当这页面一加载,便会执行submitForm这个方法来提交转账请求,从而将10块转给黑客。

2.如何防御

防范 CSRF 攻击可以遵循以下几种规则:

Get 请求不对数据进行修改不让第三方网站访问到用户 Cookie阻止第三方网站请求接口请求时附带验证信息,比如验证码或者 Token

1) SameSite

可以对 Cookie 设置 SameSite 属性。该属性表示 Cookie 不随着跨域请求发送,可以很大程度减少 CSRF 的攻击,但是该属性目前并不是所有浏览器都兼容。

2) Referer Check

HTTP Referer是header的一部分,当浏览器向web服务器发送请求时,一般会带上Referer信息告诉服务器是从哪个页面链接过来的,服务器籍此可以获得一些信息用于处理。可以通过检查请求的来源来防御CSRF攻击。正常请求的referer具有一定规律,如在提交表单的referer必定是在该页面发起的请求。所以通过检查http包头referer的值是不是这个页面,来判断是不是CSRF攻击。

但在某些情况下如从https跳转到http,浏览器处于安全考虑,不会发送referer,服务器就无法进行check了。若与该网站同域的其他网站有XSS漏洞,那么攻击者可以在其他网站注入恶意脚本,受害者进入了此类同域的网址,也会遭受攻击。出于以上原因,无法完全依赖Referer Check作为防御CSRF的主要手段。但是可以通过Referer Check来监控CSRF攻击的发生。

3) Anti CSRF Token

目前比较完善的解决方案是加入Anti-CSRF-Token。即发送请求时在HTTP 请求中以参数的形式加入一个随机产生的token,并在服务器建立一个拦截器来验证这个token。服务器读取浏览器当前域cookie中这个token值,会进行校验该请求当中的token和cookie当中的token值是否都存在且相等,才认为这是合法的请求。否则认为这次请求是违法的,拒绝该次服务。

这种方法相比Referer检查要安全很多,token可以在用户登陆后产生并放于session或cookie中,然后在每次请求时服务器把token从session或cookie中拿出,与本次请求中的token 进行比对。由于token的存在,攻击者无法再构造出一个完整的URL实施CSRF攻击。但在处理多个页面共存问题时,当某个页面消耗掉token后,其他页面的表单保存的还是被消耗掉的那个token,其他页面的表单提交时会出现token错误。

4) 验证码

应用程序和用户进行交互过程中,特别是账户交易这种核心步骤,强制用户输入验证码,才能完成最终请求。在通常情况下,验证码够很好地遏制CSRF攻击。但增加验证码降低了用户的体验,网站不能给所有的操作都加上验证码。所以只能将验证码作为一种辅助手段,在关键业务点设置验证码。

三、点击劫持

点击劫持是一种视觉欺骗的攻击手段。攻击者将需要攻击的网站通过 iframe 嵌套的方式嵌入自己的网页中,并将 iframe 设置为透明,在页面中透出一个按钮诱导用户点击。

1. 特点

隐蔽性较高,骗取用户操作"UI-覆盖攻击"利用iframe或者其它标签的属性

2. 点击劫持的原理

用户在登陆 A 网站的系统后,被攻击者诱惑打开第三方网站,而第三方网站通过 iframe 引入了 A 网站的页面内容,用户在第三方网站中点击某个按钮(被装饰的按钮),实际上是点击了 A 网站的按钮。接下来我们举个例子:我在优酷发布了很多视频,想让更多的人关注它,就可以通过点击劫持来实现

iframe { width: 1440px; height: 900px; position: absolute; top: -0px; left: -0px; z-index: 2; -moz-opacity: 0; opacity: 0; filter: alpha(opacity=0); } button { position: absolute; top: 270px; left: 1150px; z-index: 1; width: 90px; height:40px; } </style> ...... <button>点击脱衣</button> <img src="http://pic1.win4000.com/wallpaper/2018-03-19/5aaf2bf0122d2.jpg"> <iframe src="http://i.youku.com/u/UMjA0NTg4Njcy" scrolling="no"></iframe>

从上图可知,攻击者通过图片作为页面背景,隐藏了用户操作的真实界面,当你按耐不住好奇点击按钮以后,真正的点击的其实是隐藏的那个页面的订阅按钮,然后就会在你不知情的情况下订阅了。

3. 如何防御

1)X-FRAME-OPTIONS

X-FRAME-OPTIONS是一个 HTTP 响应头,在现代浏览器有一个很好的支持。这个 HTTP 响应头 就是为了防御用 iframe 嵌套的点击劫持攻击。

该响应头有三个值可选,分别是

DENY,表示页面不允许通过 iframe 的方式展示SAMEORIGIN,表示页面可以在相同域名下通过 iframe 的方式展示ALLOW-FROM,表示页面可以在指定来源的 iframe 中展示

2)JavaScript 防御

对于某些远古浏览器来说,并不能支持上面的这种方式,那我们只有通过 JS 的方式来防御点击劫持了。

<head> <style id="click-jack"> html { display: none !important; } </style> </head> <body> <script> if (self == top) { var style = document.getElementById('click-jack') document.body.removeChild(style) } else { top.location = self.location } </script> </body>

以上代码的作用就是当通过 iframe 的方式加载页面时,攻击者的网页直接不显示所有内容了。

给大家推荐一个好用的BUG监控工具Fundebug,欢迎免费试用!

四、URL跳转漏洞

定义:借助未验证的URL跳转,将应用程序引导到不安全的第三方区域,从而导致的安全问题。

1.URL跳转漏洞原理

黑客利用URL跳转漏洞来诱导安全意识低的用户点击,导致用户信息泄露或者资金的流失。其原理是黑客构建恶意链接(链接需要进行伪装,尽可能迷惑),发在QQ群或者是浏览量多的贴吧/论坛中。安全意识低的用户点击后,经过服务器或者浏览器解析后,跳到恶意的网站中。

恶意链接需要进行伪装,经常的做法是熟悉的链接后面加上一个恶意的网址,这样才迷惑用户。

诸如伪装成像如下的网址,你是否能够识别出来是恶意网址呢?

http://gate.baidu.com/index?act=go&url=http://t.cn/RVTatrd http://qt.qq.com/safecheck.html?flag=1&url=http://t.cn/RVTatrd http://tieba.baidu.com/f/user/passport?jumpUrl=http://t.cn/RVTatrd

2.实现方式:

Header头跳转Javascript跳转META标签跳转

这里我们举个Header头跳转实现方式:

<?php $url=$_GET['jumpto']; header("Location: $url"); ?> http://www.wooyun.org/login.php?jumpto=http://www.evil.com

这里用户会认为www.wooyun.org都是可信的,但是点击上述链接将导致用户最终访问www.evil.com这个恶意网址。

3.如何防御

1)referer的限制

如果确定传递URL参数进入的来源,我们可以通过该方式实现安全限制,保证该URL的有效性,避免恶意用户自己生成跳转链接

2)加入有效性验证Token

我们保证所有生成的链接都是来自于我们可信域的,通过在生成的链接里加入用户不可控的Token对生成的链接进行校验,可以避免用户生成自己的恶意链接从而被利用,但是如果功能本身要求比较开放,可能导致有一定的限制。

五、SQL注入

SQL注入是一种常见的Web安全漏洞,攻击者利用这个漏洞,可以访问或修改数据,或者利用潜在的数据库漏洞进行攻击。

1.SQL注入的原理

我们先举一个万能钥匙的例子来说明其原理:

<form action="/login" method="POST"> <p>Username: <input type="text" name="username" /></p> <p>Password: <input type="password" name="password" /></p> <p><input type="submit" value="登陆" /></p> </form>

后端的 SQL 语句可能是如下这样的:

let querySQL = ` SELECT * FROM user WHERE username='${username}' AND psw='${password}' `; // 接下来就是执行 sql 语句...

这是我们经常见到的登录页面,但如果有一个恶意攻击者输入的用户名是 admin' --,密码随意输入,就可以直接登入系统了。why! ----这就是SQL注入

我们之前预想的SQL 语句是:

SELECT * FROM user WHERE username='admin' AND psw='password'

但是恶意攻击者用奇怪用户名将你的 SQL 语句变成了如下形式:

SELECT * FROM user WHERE username='admin' --' AND psw='xxxx'

在 SQL 中,' --是闭合和注释的意思,-- 是注释后面的内容的意思,所以查询语句就变成了:

SELECT * FROM user WHERE username='admin'

所谓的万能密码,本质上就是SQL注入的一种利用方式。

一次SQL注入的过程包括以下几个过程:

获取用户请求参数拼接到代码当中SQL语句按照我们构造参数的语义执行成功

SQL注入的必备条件:

1.可以控制输入的数据2.服务器要执行的代码拼接了控制的数据。

我们会发现SQL注入流程中与正常请求服务器类似,只是黑客控制了数据,构造了SQL查询,而正常的请求不会SQL查询这一步,SQL注入的本质:数据和代码未分离,即数据当做了代码来执行。

2.危害

获取数据库信息管理员后台用户名和密码获取其他数据库敏感信息:用户名、密码、手机号码、身份证、银行卡信息……整个数据库:脱裤获取服务器权限植入Webshell,获取服务器后门读取服务器敏感文件

3.如何防御

严格限制Web应用的数据库的操作权限,给此用户提供仅仅能够满足其工作的最低权限,从而最大限度的减少注入攻击对数据库的危害后端代码检查输入的数据是否符合预期,严格限制变量的类型,例如使用正则表达式进行一些匹配处理。对进入数据库的特殊字符(',",,<,>,&,*,; 等)进行转义处理,或编码转换。基本上所有的后端语言都有对字符串进行转义处理的方法,比如 lodash 的 lodash._escapehtmlchar 库。所有的查询语句建议使用数据库提供的参数化查询接口,参数化的语句使用参数而不是将用户输入变量嵌入到 SQL 语句中,即不要直接拼接 SQL 语句。例如 Node.js 中的 mysqljs 库的 query 方法中的 ? 占位参数。六、OS命令注入攻击

OS命令注入和SQL注入差不多,只不过SQL注入是针对数据库的,而OS命令注入是针对操作系统的。OS命令注入攻击指通过Web应用,执行非法的操作系统命令达到攻击的目的。只要在能调用Shell函数的地方就有存在被攻击的风险。倘若调用Shell时存在疏漏,就可以执行插入的非法命令。

命令注入攻击可以向Shell发送命令,让Windows或Linux操作系统的命令行启动程序。也就是说,通过命令注入攻击可执行操作系统上安装着的各种程序。

1.原理

黑客构造命令提交给web应用程序,web应用程序提取黑客构造的命令,拼接到被执行的命令中,因黑客注入的命令打破了原有命令结构,导致web应用执行了额外的命令,最后web应用程序将执行的结果输出到响应页面中。

我们通过一个例子来说明其原理,假如需要实现一个需求:用户提交一些内容到服务器,然后在服务器执行一些系统命令去返回一个结果给用户

// 以 Node.js 为例,假如在接口中需要从 github 下载用户指定的 repo const exec = require('mz/child_process').exec; let params = {/* 用户输入的参数 */}; exec(`git clone ${params.repo} /some/path`);

如果 params.repo 传入的是 https://github.com/admin/admin.github.io.git 确实能从指定的 git repo 上下载到想要的代码。但是如果 params.repo 传入的是 https://github.com/xx/xx.git && rm -rf /* && 恰好你的服务是用 root 权限起的就糟糕了。

2.如何防御

后端对前端提交内容进行规则限制(比如正则表达式)。在调用系统命令前对所有传入参数进行命令行参数转义过滤。不要直接拼接命令语句,借助一些工具做拼接、转义预处理,例如 Node.js 的 shell-escape npm包

以上就是常见的web安全漏洞及防御方法!

PyTorch的使用经验都有哪些?

PyTorch 的构建者表明,PyTorch 的哲学是解决当务之急,也就是说即时构建和运行计算图。目前,PyTorch 也已经借助这种即时运行的概念成为最受欢迎的框架之一,开发者能快速构建模型与验证想法,并通过神经网络交换格式 ONNX 在多个框架之间快速迁移。本文从基本概念开始介绍了 PyTorch 的使用方法、训练经验与技巧,并展示了可能出现的问题与解决方案。

项目地址:https://github.com/Kaixhin/grokking-pytorch

PyTorch 是一种灵活的深度学习框架,它允许通过动态神经网络(例如利用动态控流——如 if 语句或 while 循环的网络)进行自动微分。它还支持 GPU 加速、分布式训练以及各类优化任务,同时还拥有许多更简洁的特性。以下是作者关于如何利用 PyTorch 的一些说明,里面虽然没有包含该库的所有细节或最优方法,但可能会对大家有所帮助。

神经网络是计算图的一个子类。计算图接收输入数据,数据被路由到对数据执行处理的节点,并可能被这些节点转换。在深度学习中,神经网络中的神经元(节点)通常利用参数或可微函数转换数据,这样可以优化参数以通过梯度下降将损失最小化。更广泛地说,函数是随机的,图结构可以是动态的。所以说,虽然神经网络可能非常适合数据流式编程,但 PyTorch 的 API 却更关注命令式编程——一种编程更常考虑的形式。这令读取代码和推断复杂程序变得简单,而无需损耗不必要的性能;PyTorch 速度很快,且拥有大量优化,作为终端用户你毫无后顾之忧。

本文其余部分写的是关于 grokking PyTorch 的内容,都是基于 MINIST 官网实例,应该要在学习完官网初学者教程后再查看。为便于阅读,代码以块状形式呈现,并带有注释,因此不会像纯模块化代码一样被分割成不同的函数或文件。

Pytorch 基础

PyTorch 使用一种称之为 imperative / eager 的范式,即每一行代码都要求构建一个图,以定义完整计算图的一个部分。即使完整的计算图还没有构建好,我们也可以独立地执行这些作为组件的小计算图,这种动态计算图被称为「define-by-run」方法。

PyTorch 张量

正如 PyTorch 文档所说,如果我们熟悉 NumPy 的多维数组,那么 Torch 张量的很多操作我们能轻易地掌握。PyTorch 提供了 CPU 张量和 GPU 张量,并且极大地加速了计算的速度。

从张量的构建与运行就能体会,相比 TensorFLow,在 PyTorch 中声明张量、初始化张量要简洁地多。例如,使用 torch.Tensor(5, 3) 语句就能随机初始化一个 5×3 的二维张量,因为 PyTorch 是一种动态图,所以它声明和真实赋值是同时进行的。

在 PyTorch 中,torch.Tensor 是一种多维矩阵,其中每个元素都是单一的数据类型,且该构造函数默认为 torch.FloatTensor。以下是具体的张量类型:

除了直接定义维度,一般我们还可以从 Python 列表或 NumPy 数组中创建张量。而且根据使用 Python 列表和元组等数据结构的习惯,我们可以使用相似的索引方式进行取值或赋值。PyTorch 同样支持广播(Broadcasting)操作,一般它会隐式地把一个数组的异常维度调整到与另一个算子相匹配的维度,以实现维度兼容。

自动微分模块

TensorFlow、Caffe 和 CNTK 等大多数框架都使用静态计算图,开发者必须建立或定义一个神经网络,并重复使用相同的结构来执行模型训练。改变网络的模式就意味着我们必须从头开始设计并定义相关的模块。

但 PyTorch 使用的技术为自动微分(automatic differentiation)。在这种机制下,系统会有一个 Recorder 来记录我们执行的运算,然后再反向计算对应的梯度。这种技术在构建神经网络的过程中十分强大,因为我们可以通过计算前向传播过程中参数的微分来节省时间。

从概念上来说,Autograd 会维护一个图并记录对变量执行的所有运算。这会产生一个有向无环图,其中叶结点为输入向量,根结点为输出向量。通过从根结点到叶结点追踪图的路径,我们可以轻易地使用链式法则自动计算梯度。

在内部,Autograd 将这个图表征为 Function 对象的图,并且可以应用 apply() 计算评估图的结果。在计算前向传播中,当 Autograd 在执行请求的计算时,它还会同时构建一个表征梯度计算的图,且每个 Variable 的 .grad_fn 属性就是这个图的输入单元。在前向传播完成后,我们可以在后向传播中根据这个动态图来计算梯度。

PyTorch 还有很多基础的模块,例如控制学习过程的最优化器、搭建深度模型的神经网络模块和数据加载与处理等。这一节展示的张量与自动微分模块是 PyTorch 最为核心的概念之一,读者可查阅 PyTorch 文档了解更详细的内容。

下面作者以 MNIST 为例从数据加载到模型测试具体讨论了 PyTorch 的使用、思考技巧与陷阱。

PyTorch 实用指南

导入

import argparseimport torchfrom torch import nn, optimfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom torchvision import datasets, transforms

除了用于计算机视觉问题的 torchvision 模块外,这些都是标准化导入。

设置

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)')parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)')parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')parser.add_argument('--save-interval', type=int, default=10, metavar='N', help='how many batches to wait before checkpointing')parser.add_argument('--resume', action='store_true', default=False, help='resume training from checkpoint')args = parser.parse_args()use_cuda = torch.cuda.is_available() and not args.no_cudadevice = torch.device('cuda' if use_cuda else 'cpu')torch.manual_seed(args.seed)if use_cuda: torch.cuda.manual_seed(args.seed)

argparse 是在 Python 中处理命令行参数的一种标准方式。

编写与设备无关的代码(可用时受益于 GPU 加速,不可用时会倒退回 CPU)时,选择并保存适当的 torch.device, 不失为一种好方法,它可用于确定存储张量的位置。关于与设备无关代码的更多内容请参阅官网文件。PyTorch 的方法是使用户能控制设备,这对简单示例来说有些麻烦,但是可以更容易地找出张量所在的位置——这对于 a)调试很有用,并且 b)可有效地使用手动化设备。

对于可重复实验,有必要为使用随机数生成的任何数据设置随机种子(如果也使用随机数,则包括随机或 numpy)。要注意,cuDNN 用的是非确定算法,可以通过语句 torch.backends.cudnn.enabled = False 将其禁用。

数据

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))test_data = datasets.MNIST('data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=4, pin_memory=True)

torchvision.transforms 对于单张图像有非常多便利的转换工具,例如裁剪和归一化等。

DataLoader 包含非常多的参数,除了 batch_size 和 shuffle,num_workers 和 pin_memory 对于高效加载数据同样非常重要。例如配置 num_workers > 0 将使用子进程异步加载数据,而不是使用一个主进程块加载数据。参数 pin_memory 使用固定 RAM 以加速 RAM 到 GPU 的转换,且在仅使用 CPU 时不会做任何运算。

模型

class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1)model = Net().to(device)optimiser = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)if args.resume: model.load_state_dict(torch.load('model.pth')) optimiser.load_state_dict(torch.load('optimiser.pth'))

神经网络初始化一般包括变量、包含可训练参数的层级、可能独立的可训练参数和不可训练的缓存器。随后前向传播将这些初始化参数与 F 中的函数结合,其中该函数为不包含参数的纯函数。有些开发者喜欢使用完全函数化的网络(如保持所有参数独立,使用 F.conv2d 而不是 nn.Conv2d),或者完全由 layers 函数构成的网络(如使用 nn.ReLU 而不是 F.relu)。

在将 device 设置为 GPU 时,.to(device) 是一种将设备参数(和缓存器)发送到 GPU 的便捷方式,且在将 device 设置为 CPU 时不会做任何处理。在将网络参数传递给优化器之前,把它们传递给适当的设备非常重要,不然的话优化器不能正确地追踪参数。

神经网络(nn.Module)和优化器(optim.Optimizer)都能保存和加载它们的内部状态,而.load_state_dict(state_dict) 是完成这一操作的推荐方法,我们可以从以前保存的状态字典中加载两者的状态并恢复训练。此外,保存整个对象可能会出错。

这里没讨论的一些注意事项即前向传播可以使用控制流,例如一个成员变量或数据本身能决定 if 语句的执行。此外,在前向传播的过程中打印张量也是可行的,这令 debug 更加简单。最后,前向传播可以使用多个参数。以下使用间断的代码块展示这一点:

def forward(self, x, hx, drop=False):

hx2 = self.rnn(x, hx)

print(hx.mean().item(), hx.var().item())

if hx.max.item() > 10 or self.can_drop and drop:

return hx

else:

return hx2

训练

model.train()train_losses = []for i, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimiser.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() train_losses.append(loss.item()) optimiser.step() if i % 10 == 0: print(i, loss.item()) torch.save(model.state_dict(), 'model.pth') torch.save(optimiser.state_dict(), 'optimiser.pth') torch.save(train_losses, 'train_losses.pth')

网络模块默认设置为训练模式,这影响了某些模块的工作方式,最明显的是 dropout 和批归一化。最好用.train() 对其进行手动设置,这样可以把训练标记向下传播到所有子模块。

在使用 loss.backward() 收集一系列新的梯度以及用 optimiser.step() 做反向传播之前,有必要手动地将由 optimiser.zero_grad() 优化的参数梯度归零。默认情况下,PyTorch 会累加梯度,在单次迭代中没有足够资源来计算所有需要的梯度时,这种做法非常便利。

PyTorch 使用一种基于 tape 的自动化梯度(autograd)系统,它收集按顺序在张量上执行的运算,然后反向重放它们来执行反向模式微分。这正是为什么 PyTorch 如此灵活并允许执行任意计算图的原因。如果没有张量需要做梯度更新(当你需要为该过程构建一个张量时,你必须设置 requires_grad=True),则不需要保存任何图。然而,网络倾向于包含需要梯度更新的参数,因此任何网络输出过程中执行的计算都将保存在图中。因此如果想保存在该过程中得到的数据,你将需要手动禁止梯度更新,或者,更常见的做法是将其保存为一个 Python 数(通过一个 Python 标量上的.item())或者 NumPy 数组。更多关于 autograd 的细节详见官网文件。

截取计算图的一种方式是使用.detach(),当通过沿时间的截断反向传播训练 RNN 时,数据流传递到一个隐藏状态可能会应用这个函数。当对损失函数求微分(其中一个成分是另一个网络的输出)时,也会很方便。但另一个网络不应该用「loss - examples」的模式进行优化,包括在 GAN 训练中从生成器的输出训练判别器,或使用价值函数作为基线(例如 A2C)训练 actor-critic 算法的策略。另一种在 GAN 训练(从判别器训练生成器)中能高效阻止梯度计算的方法是在整个网络参数上建立循环,并设置 param.requires_grad=False,这在微调中也很常用。

除了在控制台/日志文件里记录结果以外,检查模型参数(以及优化器状态)也是很重要的。你还可以使用 torch.save() 来保存一般的 Python 对象,但其它标准选择还包括内建的 pickle。

测试

model.train()train_losses = []for i, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimiser.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() train_losses.append(loss.item()) optimiser.step() if i % 10 == 0: print(i, loss.item()) torch.save(model.state_dict(), 'model.pth') torch.save(optimiser.state_dict(), 'optimiser.pth') torch.save(train_losses, 'train_losses.pth')

为了早点响应.train(),应利用.eval() 将网络明确地设置为评估模式。

正如前文所述,计算图通常会在使用网络时生成。通过 with torch.no_grad() 使用 no_grad 上下文管理器,可以防止这种情况发生。

其它

内存有问题?可以查看官网文件获取帮助。

CUDA 出错?它们很难调试,而且通常是一个逻辑问题,会在 CPU 上产生更易理解的错误信息。如果你计划使用 GPU,那最好能够在 CPU 和 GPU 之间轻松切换。更普遍的开发技巧是设置代码,以便在启动合适的项目(例如准备一个较小/合成的数据集、运行一个 train + test epoch 等)之前快速运行所有逻辑来检查它。如果这是一个 CUDA 错误,或者你没法切换到 CPU,设置 CUDA_LAUNCH_BLOCKING=1 将使 CUDA 内核同步启动,从而提供更详细的错误信息。

torch.multiprocessing,甚至只是一次运行多个 PyTorch 脚本的注意事项。因为 PyTorch 使用多线程 BLAS 库来加速 CPU 上的线性代数计算,所以它通常需要使用多个内核。如果你想一次运行多个任务,在具有多进程或多个脚本的情况下,通过将环境变量 OMP_NUM_THREADS 设置为 1 或另一个较小的数字来手动减少线程,这样做减少了 CPU thrashing 的可能性。官网文件还有一些其它注意事项,尤其是关于多进程。

发表评论

快捷回复: 表情:
AddoilApplauseBadlaughBombCoffeeFabulousFacepalmFecesFrownHeyhaInsidiousKeepFightingNoProbPigHeadShockedSinistersmileSlapSocialSweatTolaughWatermelonWittyWowYeahYellowdog
评论列表 (暂无评论,105人围观)

还没有评论,来说两句吧...