6.3. Parameter Initialization
Now that we know how to access the parameters, let’s look at how to initialize them properly. We discussed the need for proper initialization in Section 5.4. Deep learning frameworks provide default random initializations for their layers. However, we often want to initialize our weights according to other protocols. The frameworks provide the most commonly used protocols and also allow us to create custom initializers.
import torch
from torch import nn
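By default, PyTorch initializes weight matrices and bias vectors uniformly by drawing from a range that is computed from the layer's input dimension; for nn.Linear, both are drawn from \(U(-1/\sqrt{\text{fan}_{\text{in}}}, 1/\sqrt{\text{fan}_{\text{in}}})\), where \(\text{fan}_{\text{in}}\) is the number of inputs. PyTorch’s nn.init module provides a variety of preset initialization methods.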
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
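The PyTorch default can be checked directly. This sketch assumes a recent PyTorch release, in which `nn.Linear` resets its weight with `kaiming_uniform_(a=math.sqrt(5))` and its bias with a uniform draw, both of which work out to \(U(-1/\sqrt{\text{fan}_{\text{in}}}, 1/\sqrt{\text{fan}_{\text{in}}})\). It uses an eager `nn.Linear` instead of `nn.LazyLinear` so that no forward pass is needed before the parameters exist.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)
fan_in, fan_out = 4, 8
layer = nn.Linear(fan_in, fan_out)  # eager layer: parameters exist right away

# Assumed default for nn.Linear: weight and bias both drawn from U(-bound, bound)
bound = 1 / math.sqrt(fan_in)
tol = 1e-6  # small slack for float32 rounding
print(layer.weight.shape, bound)
print(layer.weight.abs().max().item() <= bound + tol,
      layer.bias.abs().max().item() <= bound + tol)
```

Note that only the input dimension enters this bound, so widening the layer's output leaves the range unchanged.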
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()
By default, MXNet initializes weight parameters by randomly drawing from a uniform distribution \(U(-0.07, 0.07)\) and sets bias parameters to zero. MXNet’s init module provides a variety of preset initialization methods.
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # Use the default initialization method
X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
By default, Flax initializes weights using jax.nn.initializers.lecun_normal, i.e., by drawing samples from a truncated normal distribution centered on 0 with the standard deviation set to the square root of \(1 / \text{fan}_{\text{in}}\), where fan_in is the number of input units in the weight tensor. The bias parameters are all set to zero. JAX’s nn.initializers module provides a variety of preset initialization methods.
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
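The description of `lecun_normal` can be made concrete with a NumPy re-implementation. This is a minimal sketch, not JAX's own code: it assumes, as JAX's variance-scaling initializers do, that samples come from a standard normal truncated to \([-2, 2]\) and are then rescaled by dividing by 0.8796 (the standard deviation of that truncated distribution), so that the final standard deviation is \(\sqrt{1/\text{fan}_{\text{in}}}\). The helper `lecun_normal_sketch` is hypothetical, and `fan_in` is taken to be the first dimension of a Flax `Dense` kernel of shape `(in_features, out_features)`.

```python
import numpy as np

# Standard deviation of a standard normal truncated to [-2, 2]
TRUNC_STD = 0.87962566103423978

def lecun_normal_sketch(rng, fan_in, fan_out):
    """Hypothetical helper: LeCun-normal samples for a (fan_in, fan_out) kernel."""
    n = fan_in * fan_out
    kept = np.empty(0)
    # Rejection sampling: keep standard-normal draws that fall inside [-2, 2]
    while kept.size < n:
        z = rng.standard_normal(2 * n)
        kept = np.concatenate([kept, z[np.abs(z) <= 2.0]])
    # Rescale so that the truncated samples have std sqrt(1 / fan_in)
    scale = np.sqrt(1.0 / fan_in) / TRUNC_STD
    return kept[:n].reshape(fan_in, fan_out) * scale

rng = np.random.default_rng(0)
W = lecun_normal_sketch(rng, fan_in=1000, fan_out=500)
print(W.std(), np.sqrt(1 / 1000))
print(np.abs(W).max() <= 2 * np.sqrt(1 / 1000) / TRUNC_STD)
```

The rescaling step is what makes "standard deviation \(\sqrt{1/\text{fan}_{\text{in}}}\)" hold after truncation; without it the samples would be slightly too narrow.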
import tensorflow as tf
By default, Keras initializes weight matrices uniformly by drawing from a range that is computed from the input and output dimensions (the Glorot uniform initializer), and the bias parameters are all set to zero. TensorFlow provides a variety of initialization methods both in the root module and in the keras.initializers module.
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4, activation=tf.nn.relu),
tf.keras.layers.Dense(1),
])
X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])
6.3.1. Built-in Initialization
Let’s begin by calling on built-in initializers. The code below initializes all weight parameters as Gaussian random variables with mean 0 and standard deviation 0.01, while bias parameters are set to zero.
def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([-0.0089, 0.0039, -0.0204, -0.0059]), tensor(0.))
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([ 0.00354961, -0.00614133, 0.0107317 , 0.01830765])
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([ 0.00196291, 0.00416906, -0.00880515, -0.00701235], dtype=float32),
Array(0., dtype=float32))
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1)])
net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_2/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0.00545899, -0.0061681 , 0.01053761, 0.00168952],
[ 0.02327557, -0.00167401, -0.00378414, 0.00909223],
[-0.00641173, -0.01181791, -0.01811306, -0.01030261],
[ 0.01228647, -0.00428819, 0.00622066, -0.01840671]],
dtype=float32)>,
<tf.Variable 'dense_2/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)
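In PyTorch, `nn.Module.apply` calls the given function on every submodule and then on the module itself, which is why `init_normal` checks the module type before touching any parameters. The sketch below uses a wider, hypothetical two-layer network so the sample statistics are stable; it records which modules are visited and checks that the weights end up with a standard deviation close to 0.01 and zero biases.

```python
import torch
from torch import nn

def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 1))

visited = []
net.apply(lambda m: visited.append(type(m).__name__))  # record every visited module
net.apply(init_normal)

print(visited)
print(net[0].weight.std().item(), net[0].bias.abs().max().item())
```

The ReLU and the Sequential container are visited too, but the type check makes the initializer skip them.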
We can also initialize all weight parameters to a given constant value (say, 1) while keeping the biases at zero.
def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1.])
weight_init = nn.initializers.constant(1)
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.Constant(1),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1),
])
net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_4/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)>,
<tf.Variable 'dense_4/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)
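Constant initialization is handy for demonstrating the API but not for training: with identical weights, every hidden unit in a layer computes the same value and receives the same gradient, so gradient descent cannot make the units differ. The PyTorch sketch below checks this on the same layer sizes, using an arbitrary random input batch and the sum of the outputs as a stand-in for a loss.

```python
import torch
from torch import nn

def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
net.apply(init_constant)

X = torch.rand(2, 4)
net(X).sum().backward()  # the sum of outputs stands in for a loss

grad = net[0].weight.grad
# Every hidden unit (row of the weight matrix) gets the same gradient
print(torch.allclose(grad, grad[0].expand_as(grad)))
```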
We can also apply different initializers to different blocks. For example, below we initialize the first layer with the Xavier initializer and set the weights of the output layer to the constant value 42.
def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)
net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([ 0.0542, -0.6922, -0.2629, -0.4876])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
[-0.26102373 0.15249556 -0.19274211 -0.24742058]
[[42. 42. 42. 42. 42. 42. 42. 42.]]
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([-0.12841032, 0.26980272, 0.57240933, 0.2627143 ], dtype=float32),
Array([[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.]], dtype=float32))
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.GlorotUniform()),
tf.keras.layers.Dense(
1, kernel_initializer=tf.keras.initializers.Constant(42)),
])
net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<tf.Variable 'dense_6/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0.5799206 , -0.5282545 , 0.28409976, 0.19763249],
[-0.2691564 , -0.4106831 , 0.18196362, 0.3139177 ],
[-0.84266007, -0.61732936, -0.80762327, 0.31139785],
[-0.16522151, -0.85030633, 0.0580222 , -0.0847258 ]],
dtype=float32)>
<tf.Variable 'dense_7/kernel:0' shape=(4, 1) dtype=float32, numpy=
array([[42.],
[42.],
[42.],
[42.]], dtype=float32)>
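Xavier (Glorot) uniform initialization draws each weight from \(U(-a, a)\) with \(a = \sqrt{6 / (\text{fan}_{\text{in}} + \text{fan}_{\text{out}})}\) (assuming PyTorch's default `gain=1`), which gives each weight variance \(a^2/3 = 2 / (\text{fan}_{\text{in}} + \text{fan}_{\text{out}})\). For a layer with 4 inputs and 8 outputs this is \(a = \sqrt{1/2} \approx 0.707\). The sketch below checks both the bound and the variance on a larger, hypothetical weight matrix stored in PyTorch's `(out_features, in_features)` layout.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)
fan_in, fan_out = 300, 200
w = torch.empty(fan_out, fan_in)  # nn.Linear layout: (out_features, in_features)
nn.init.xavier_uniform_(w)

a = math.sqrt(6 / (fan_in + fan_out))  # Xavier uniform bound with gain = 1
print(a, w.abs().max().item())
print(w.var().item(), 2 / (fan_in + fan_out))
```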
6.3.1.1. Custom Initialization
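Sometimes, the initialization methods we need are not provided by the deep learning framework. In the example below, we define an initializer for any weight parameter \(w\) using the following strange distribution:
\[
w \sim \begin{cases}
U(5, 10) & \text{ with probability } \frac{1}{4} \\
0 & \text{ with probability } \frac{1}{2} \\
U(-10, -5) & \text{ with probability } \frac{1}{4}
\end{cases}
\]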
Again, we implement a my_init function to apply to net.
def my_init(module):
    if type(module) == nn.Linear:
        print("Init", *[(name, param.shape)
                        for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        module.weight.data *= module.weight.data.abs() >= 5
net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[-8.4801, -0.0000, 0.0000, -5.6451],
[ 5.6075, 0.0000, 5.4012, -5.0333]], grad_fn=<SliceBackward0>)
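Zeroing every draw from \(U(-10, 10)\) with \(|w| < 5\) leaves \(U(5, 10)\) and \(U(-10, -5)\) with probability 1/4 each and 0 with probability 1/2. A quick empirical check, applying the same two operations as `my_init` to a large stand-alone tensor (the size and seed are arbitrary choices), should show about half the entries at zero, positive and negative survivors in equal measure, and every nonzero entry in \(5 \le |w| \le 10\).

```python
import torch
from torch import nn

torch.manual_seed(0)
w = torch.empty(1000, 1000)
nn.init.uniform_(w, -10, 10)  # same draw as my_init
w *= w.abs() >= 5             # zero out entries with |w| < 5

nonzero = w[w != 0]
print((w == 0).float().mean().item())       # fraction of zeros
print((nonzero > 0).float().mean().item())  # positive share among nonzeros
print(nonzero.abs().min().item(), nonzero.abs().max().item())
```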
Note that we always have the option of setting parameters directly.
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42.0000, 1.0000, 1.0000, -4.6451])
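Writes through `.data` are not tracked by autograd. An alternative commonly used in current PyTorch code, sketched below on a fresh hypothetical layer, is to modify the parameter in place inside a `torch.no_grad()` block; the tensor stays a trainable `nn.Parameter`.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 8)
before = layer.weight.detach().clone()

with torch.no_grad():  # in-place edits here are not recorded by autograd
    layer.weight.add_(1)
    layer.weight[0, 0] = 42

print(layer.weight[0])
print(layer.weight.requires_grad, isinstance(layer.weight, nn.Parameter))
```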
Here we define a subclass of the Initializer class. Usually, we only need to implement the _init_weight function, which takes a tensor argument (data) and assigns to it the desired initialized values.
class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5
net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
array([[-6.0683527, 8.991421 , -0. , 0. ],
[ 6.4198647, -9.728567 , -8.057975 , 0. ]])
Note that we always have the option of setting parameters directly.
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42. , 9.991421, 1. , 1. ])
JAX initialization functions take as arguments the PRNGKey, shape and dtype. Here we implement the function my_init that returns a desired tensor given the shape and data type.
def my_init(key, shape, dtype=jnp.float32):
    # Draw from U(-10, 10) in the requested dtype, then zero out entries with |w| < 5
    data = jax.random.uniform(key, shape, dtype=dtype, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)
net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[ 0. 0. ]
[ 0. 0. ]
[ 0. 0. ]
[-6.9971204 0. ]]
When initializing parameters in JAX and Flax, the dictionary of parameters returned has the type flax.core.frozen_dict.FrozenDict (newer Flax versions may return a regular dict instead). JAX arrays are immutable, so in the JAX ecosystem parameters are not altered in place. One might use params.unfreeze() to obtain a regular, mutable dictionary and make changes.
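Since JAX arrays are immutable, the JAX counterpart of setting parameters directly is a functional update: `x.at[idx].set(value)` and `x.at[idx].add(value)` return new arrays and leave `x` untouched, and `jax.tree_util.tree_map` applies a function to every array in a nested dictionary. The sketch below uses a stand-alone zero array and a small hypothetical dictionary in place of the Flax parameters from the model above.

```python
import jax
import jax.numpy as jnp

kernel = jnp.zeros((4, 8))              # stand-in for a layer's kernel
new_kernel = kernel.at[:].add(1)        # new array; `kernel` itself is unchanged
new_kernel = new_kernel.at[0, 0].set(42)

# A hypothetical nested parameter dictionary, updated leaf by leaf
params = {'layers_0': {'kernel': kernel, 'bias': jnp.zeros(8)}}
shifted = jax.tree_util.tree_map(lambda x: x + 1, params)

print(kernel[0, :3], new_kernel[0, :3])
print(shifted['layers_0']['bias'][:3])
```

Because every update produces a new array, the original `params` remain valid, which fits the functional style of `net.apply(params, X)`.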
Here we define a subclass of Initializer and implement the __call__ function that returns a desired tensor given the shape and data type.
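class MyInit(tf.keras.initializers.Initializer):
    def __call__(self, shape, dtype=None):
        # Fall back to float32 if the caller does not specify a dtype
        dtype = tf.float32 if dtype is None else dtype
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        # Zero out entries with |w| < 5, casting the mask to the data's dtype
        factor = tf.cast(tf.abs(data) >= 5, data.dtype)
        return data * factor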
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=MyInit()),
tf.keras.layers.Dense(1),
])
net(X)
print(net.layers[1].weights[0])
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0. , -5.465231 , 0. , 8.921467 ],
[ 0. , 0. , -0. , 8.300743 ],
[ 9.480003 , -0. , -0. , -9.877076 ],
[ 0. , -9.207285 , -5.5327415, 9.438238 ]], dtype=float32)>
Note that we always have the option of setting parameters directly.
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[42. , -4.465231 , 1. , 9.921467 ],
[ 1. , 1. , 1. , 9.300743 ],
[10.480003 , 1. , 1. , -8.877076 ],
[ 1. , -8.207285 , -4.5327415, 10.438238 ]], dtype=float32)>
6.3.2. Summary
We can initialize parameters using built-in and custom initializers.
6.3.3. Exercises
Look up the online documentation for more built-in initializers.