PyTorch: passing numpy array for weight initialization(PyTorch:传递 numpy 数组进行权重初始化)
问题描述
我想用np数组初始化RNN的参数.
I'd like to initialize the parameters of RNN with np arrays.
在下面的示例中,我想将 w
传递给 rnn
的参数.我知道pytorch提供了很多初始化方法,比如Xavier、uniform等,但是有没有办法通过传递numpy数组来初始化参数呢?
In the following example, I want to pass w
to the parameters of rnn
. I know pytorch provides many initialization methods like Xavier, uniform, etc., but is there way to initialize the parameters by passing numpy arrays?
import numpy as np
import torch as nn
rng = np.random.RandomState(313)
w = rng.randn(input_size, hidden_size).astype(np.float32)
rnn = nn.RNN(input_size, hidden_size, num_layers)
推荐答案
首先,让我们注意 nn.RNN
有多个权重变量,c.f.文档:
First, let's note that nn.RNN
has more than one weight variable, c.f. the documentation:
变量:
weight_ih_l[k]
– 第k
层的可学习输入隐藏权重,形状为(hidden_size * input_size)
k = 0
.否则,形状是(hidden_size * hidden_size)
weight_hh_l[k]
–k
层的可学习隐藏权重,形状为(hidden_size * hidden_size)
bias_ih_l[k]
–k
层的可学习输入隐藏偏差,形状为(hidden_size)
bias_hh_l[k]
–k
-th 层的可学习 hidden-hidden 偏差,形状为(hidden_size)
weight_ih_l[k]
– the learnable input-hidden weights of thek
-th layer, of shape(hidden_size * input_size)
fork = 0
. Otherwise, the shape is(hidden_size * hidden_size)
weight_hh_l[k]
– the learnable hidden-hidden weights of thek
-th layer, of shape(hidden_size * hidden_size)
bias_ih_l[k]
– the learnable input-hidden bias of thek
-th layer, of shape(hidden_size)
bias_hh_l[k]
– the learnable hidden-hidden bias of thek
-th layer, of shape(hidden_size)
现在,每个变量(Parameter
实例)是 nn.RNN
实例的属性.您可以通过两种方式访问和编辑它们,如下所示:
Now, each of these variables (Parameter
instances) are attributes of your nn.RNN
instance. You can access them, and edit them, two ways, as show below:
- 方案一:按名称访问所有RNN
Parameter
属性(rnn.weight_hh_lK
、rnn.weight_ih_lK
等):
import torch
from torch import nn
import numpy as np
input_size, hidden_size, num_layers = 3, 4, 2
use_bias = True
rng = np.random.RandomState(313)
rnn = nn.RNN(input_size, hidden_size, num_layers, bias=use_bias)
def set_nn_parameter_data(layer, parameter_name, new_data):
param = getattr(layer, parameter_name)
param.data = new_data
for i in range(num_layers):
weights_hh_layer_i = rng.randn(hidden_size, hidden_size).astype(np.float32)
weights_ih_layer_i = rng.randn(hidden_size, hidden_size).astype(np.float32)
set_nn_parameter_data(rnn, "weight_hh_l{}".format(i),
torch.from_numpy(weights_hh_layer_i))
set_nn_parameter_data(rnn, "weight_ih_l{}".format(i),
torch.from_numpy(weights_ih_layer_i))
if use_bias:
bias_hh_layer_i = rng.randn(hidden_size).astype(np.float32)
bias_ih_layer_i = rng.randn(hidden_size).astype(np.float32)
set_nn_parameter_data(rnn, "bias_hh_l{}".format(i),
torch.from_numpy(bias_hh_layer_i))
set_nn_parameter_data(rnn, "bias_ih_l{}".format(i),
torch.from_numpy(bias_ih_layer_i))
- 方案二:通过
rnn.all_weights
列表属性访问所有RNNParameter
属性: - Solution 2: Accessing all the RNN
Parameter
attributes throughrnn.all_weights
list attribute:
import torch
from torch import nn
import numpy as np
input_size, hidden_size, num_layers = 3, 4, 2
use_bias = True
rng = np.random.RandomState(313)
rnn = nn.RNN(input_size, hidden_size, num_layers, bias=use_bias)
for i in range(num_layers):
weights_hh_layer_i = rng.randn(hidden_size, hidden_size).astype(np.float32)
weights_ih_layer_i = rng.randn(hidden_size, hidden_size).astype(np.float32)
rnn.all_weights[i][0].data = torch.from_numpy(weights_ih_layer_i)
rnn.all_weights[i][1].data = torch.from_numpy(weights_hh_layer_i)
if use_bias:
bias_hh_layer_i = rng.randn(hidden_size).astype(np.float32)
bias_ih_layer_i = rng.randn(hidden_size).astype(np.float32)
rnn.all_weights[i][2].data = torch.from_numpy(bias_ih_layer_i)
rnn.all_weights[i][3].data = torch.from_numpy(bias_hh_layer_i)
这篇关于PyTorch:传递 numpy 数组进行权重初始化的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:PyTorch:传递 numpy 数组进行权重初始化
基础教程推荐
- Dask.array.套用_沿_轴:由于额外的元素([1]),使用dask.array的每一行作为另一个函数的输入失败 2022-01-01
- 线程时出现 msgbox 错误,GUI 块 2022-01-01
- 用于分类数据的跳跃记号标签 2022-01-01
- 何时使用 os.name、sys.platform 或 platform.system? 2022-01-01
- 使用PyInstaller后在Windows中打开可执行文件时出错 2022-01-01
- 如何在海运重新绘制中自定义标题和y标签 2022-01-01
- Python kivy 入口点 inflateRest2 无法定位 libpng16-16.dll 2022-01-01
- 筛选NumPy数组 2022-01-01
- 在 Python 中,如果我在一个“with"中返回.块,文件还会关闭吗? 2022-01-01
- 如何让 python 脚本监听来自另一个脚本的输入 2022-01-01