MindSpore反向传播配置关键字参数

news/2024/10/2 16:22:50

技术背景

在MindSpore深度学习框架中,我们可以向construct函数传输必备参数或者关键字参数,这跟普通的Python函数没有什么区别。但是对于MindSpore中的自定义反向传播bprop函数,因为标准化格式决定了最后的两位函数输入必须是必备参数outdout用于接收函数值和导数值。那么对于一个自定义的反向传播函数而言,我们有可能要传入多个参数。例如这样的一个案例:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnpclass Net(nn.Cell):def bprop(self, x, y=1, out, dout):return msnp.cos(x) + ydef construct(self, x, y=1):return msnp.sin(x) + yx = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

但是因为在Python的函数传参规则下,必备参数必须放在关键字参数之前,也就是out和dout这两个参数要放在前面,否则就会出现这样的报错:

  File "test_rand.py", line 53def bprop(self, x, y=1, out, dout):^
SyntaxError: non-default argument follows default argument

按照普通Python函数的传参规则,我们可以把y这个关键字参数的放到最后面去:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnpclass Net(nn.Cell):def bprop(self, x, out, dout, y=1):return msnp.cos(x) + ydef construct(self, x, y=1):return msnp.sin(x) + yx = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

经过这一番调整之后,我们发现没有报错了,可以正常输出结果,但是这个结果似乎不太正常:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 1.25169754e-06]))

因为这里x传入了一个近似的\(\pi\),所以在construct函数计算函数值时,得到的结果应该是\(\sin(\pi)+y\),那么这里面\(y\)\(0\)\(1\)所得到的结果都是对的。但是关键问题在反向传播函数的计算,原本应该是\(\cos(\pi)+y=y-1\),但是在这里输入的\(y=0\),而导数的计算结果却是\(0\)而不是正确结果\(-1\)。这就说明,在MindSpore的自定义反向传播函数中,并不支持传入关键字参数。

解决方案

刚好前面写了一篇关于PyTorch的文章,这篇文章中提到的两个Issue就针对此类问题。受到这两个Issue的启发,我们在MindSpore中如果需要自定义反向传播函数,可以这么写:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnpclass Net(nn.Cell):def bprop(self, x, y, out, dout):return msnp.cos(x) + y if y is not None else msnp.cos(x)def construct(self, x, y=1):return msnp.sin(x) + yx = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0))

简单来说就是,把原本要传给bprop的关键字参数,转换成必备参数的方式进行传入,然后做一个条件判断:当给定了该输入的时候,执行计算一,如果不给定参数值,或者给一个None,执行计算二。上述代码的执行结果如下所示:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [-9.99998748e-01]))

这里输出的结果都是正确的。

当然,这里因为我们其实是强行把关键字参数按照顺序变成了必备参数进行输入,所以在顺序上一定要严格遵守bprop所定义的必备参数的顺序,否则计算结果也会出错:

import mindspore as ms
from mindspore import nn, Tensor, value_and_grad
from mindspore import numpy as msnpclass Net(nn.Cell):def bprop(self, x, w, y, out, dout):return w*msnp.cos(x) + y if y is not None else msnp.cos(x)def construct(self, x, w=1, y=1):return msnp.sin(x) + yx = Tensor([3.14], ms.float32)
net = Net()
print (net(x, y=1), value_and_grad(net)(x, y=0, w=2))

输出的结果为:

[1.0015925] (Tensor(shape=[1], dtype=Float32, value= [ 1.59254798e-03]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))

那么很显然,这个结果就是因为在执行函数时给定的关键字参数跟必备参数顺序不一致,所以才出错的。

总结概要

继上一篇文章从Torch的两个Issue中找到一些类似的问题之后,可以发现深度学习框架对于自定义反向传播函数中的传参还是比较依赖于必备参数,而不是关键字参数,MindSpore深度学习框架也是如此。但是我们可以使用一些临时的解决方案,对此问题进行一定程度上的规避,只要能够自定义的传参顺序传入关键字参数即可。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/bprop-kwargs.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://www.cnblogs.com/dechinphy/p/18179248/torch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.ryyt.cn/news/29364.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

AR精灵——风险分析和典型用户

风险分析典型用户 典型用户一名字:盛宇伟 年龄:28岁,收入:每月约8000元 代表的用户在市场上的比例和重要性:虽然使用AR精灵的付费用户比例较少,但他们对产品的热爱和忠诚度很高,他们的反馈和建议对产品的改进至关重要。 使用这个软件的典型场景:李梅在下班后回到家中,…

BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究

BOSHIDA AC/DC电源模块在通信与网络设备中的应用研究 随着通信与网络技术的不断发展,通信与网络设备的使用不断增加。电源作为通信与网络设备的重要组成部分之一,在其稳定工作中起到至关重要的作用。AC/DC电源模块作为一种常用的电源转换器,广泛应用于通信与网络设备中。 一…

AWVS(acunetix) 安装详细教程

一、软件介绍Acunetix Web Vulnerability Scanner(简称AWVS)是一款知名的Web网络漏洞扫描工具,它通过网络爬虫测试你的网站安全,检测流行安全漏洞。AWVS官方网站是:http://www.acunetix.com/ 软件有window版本,linux版本,还可以docker安装 二、下载安装 官方下载地址: h…

国密算法SM3-java实现

aven依赖<dependency><groupId>org.bouncycastle</groupId><artifactId>bcprov-jdk15on</artifactId><version>1.56</version> </dependency> SM3Utilsimport org.bouncycastle.crypto.digests.SM3Digest; import org.bouncy…

time:Python的时间时钟处理

前言 time库运行访问多种类型的时钟,这些时钟用于不同的场景。本篇,将详细讲解time库的应用知识。 获取各种时钟 既然time库提供了多种类型的时钟。下面我们直接来获取这些时钟,对比其具体的用途。具体代码如下: import timeprint(time.monotonic()) print(time.monotonic_…

时间函数的简单理解和应用(time.h)

目录关于时间的函数及tm结构体描述对函数的简单理解操作函数功能实现 关于时间的函数及tm结构体描述time.h头文件中常用的几个函数描述如下:序号 函数&描述1 time_t time(time_t *tloc)最基础的函数,计算当前时间,并返回成 time_t(aka long int)格式而且返回值是从Epoch…

Streamlit:快速构建可视化网页(数据科学必备)

很多算法工程师在完成数据分析、模型训练或者项目总结的时候,往往只能通过ppt汇报,添加数据图表、截图模型实验结果等。如果想提供一个前端演示demo,通常可以搭建flask服务,但是flask需要学习很多前端知识,如css、html等,这又是一个深之又深的坑。那有没有什么工具能够跳…

SpringBoot项目GraalVM迁移

一些背景 一直想把项目迁移到使用GraalVM构建出的原生应用上,但是在前段时间的一次尝试后,发现很难做到,其中一个最主要原因就在于我目前手头上没有X86架构的电脑。平时我使用的是一个M1处理器的MacBook,编译出的Docker镜像架构指令集也是Arm64的,无法在我的X86服务器启动…