6.3. Parameter Initialization

Now that we know how to access the parameters, let’s look at how to initialize them properly. We discussed the need for proper initialization in Section 5.4. The deep learning framework provides default random initializations to its layers. However, we often want to initialize our weights according to various other protocols. The framework provides the most commonly used protocols and also allows us to create custom initializers.

import torch
from torch import nn

By default, PyTorch initializes weight and bias matrices uniformly by drawing from a range that is computed according to the input and output dimensions. PyTorch’s nn.init module provides a variety of preset initialization methods.

net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
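As a quick, hedged sketch of what this default looks like in practice: at the time of writing, nn.Linear draws its weights from a uniform distribution whose bound works out to 1/sqrt(fan_in) (an implementation detail that may differ across versions), so the extremes of the freshly initialized first layer should fall inside that interval.

import math

# Sketch: inspect the default-initialized first layer. The bound below assumes
# the current nn.Linear default; consult the docs of your PyTorch version.
fan_in = net[0].weight.shape[1]
bound = 1 / math.sqrt(fan_in)
net[0].weight.data.min(), net[0].weight.data.max(), bound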
from mxnet import init, np, npx
from mxnet.gluon import nn

npx.set_np()

By default, MXNet initializes weight parameters by randomly drawing from a uniform distribution U(-0.07, 0.07) and clears bias parameters to zero. MXNet’s init module provides a variety of preset initialization methods.

net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)
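As a quick sanity check (a sketch, not part of the original notebook), every entry of the default-initialized first-layer weights should lie inside that (-0.07, 0.07) interval:

# Sketch: the largest absolute weight should not exceed 0.07
w = net[0].weight.data()
np.abs(w).max()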
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

By default, Flax initializes weights using jax.nn.initializers.lecun_normal, i.e., by drawing samples from a truncated normal distribution centered on 0 with the standard deviation set to the square root of 1/fan_in, where fan_in is the number of input units in the weight tensor. The bias parameters are all set to zero. JAX’s nn.initializers module provides a variety of preset initialization methods.

net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
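As a rough check (a sketch; with only a 4×8 kernel the estimate is noisy), we can compare the empirical standard deviation of the default-initialized kernel with the square root of 1/fan_in:

# Sketch: compare the empirical std of the lecun_normal-initialized kernel
# with sqrt(1 / fan_in); expect only rough agreement, given so few entries
kernel = params['params']['layers_0']['kernel']  # shape (fan_in, 8)
fan_in = kernel.shape[0]
kernel.std(), (1 / fan_in) ** 0.5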
import tensorflow as tf

By default, Keras initializes weight matrices uniformly by drawing from a range that is computed according to the input and output dimensions, and the bias parameters are all set to zero. TensorFlow provides a variety of initialization methods both in the root module and in the keras.initializers module.

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])
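To make this concrete, the hedged sketch below checks that the first Dense kernel lies within U(-a, a) with a = sqrt(6 / (fan_in + fan_out)), the Glorot uniform bound Keras uses by default at the time of writing:

import math

# Sketch: verify the Glorot uniform bound for the first Dense kernel
kernel = net.weights[0]  # shape (4, 4)
fan_in, fan_out = kernel.shape
a = math.sqrt(6 / (fan_in + fan_out))
a, bool(tf.reduce_max(tf.abs(kernel)) <= a)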

6.3.1. Built-in Initialization

Let’s begin by calling on built-in initializers. The code below initializes all weight parameters as Gaussian random variables with standard deviation 0.01, while bias parameters are cleared to zero.

def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([-0.0129, -0.0007, -0.0033,  0.0276]), tensor(0.))
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([ 0.00354961, -0.00614133,  0.0107317 ,  0.01830765])
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([ 0.00457076,  0.01890736, -0.0014968 ,  0.00327491], dtype=float32),
 Array(0., dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_2/kernel:0' shape=(4, 4) dtype=float32, numpy=
 array([[-0.02287503, -0.00437018, -0.00140329, -0.00622254],
        [ 0.00495972,  0.00324918, -0.00965284, -0.00612193],
        [-0.00183808, -0.00826601, -0.00676942,  0.00917007],
        [ 0.00847368, -0.00507652, -0.00761351, -0.00762984]],
       dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)

We can also initialize all the parameters to a given constant value (say, 1).

def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1.])
weight_init = nn.initializers.constant(1)

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1),
])

net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_4/kernel:0' shape=(4, 4) dtype=float32, numpy=
 array([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], dtype=float32)>,
 <tf.Variable 'dense_4/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)

We can also apply different initializers to certain blocks. For example, below we initialize the first layer with the Xavier initializer and initialize the second layer to a constant value of 42.

def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([-0.0974,  0.1707,  0.5840, -0.5032])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
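As a side note, the Xavier (Glorot) uniform initializer draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)); the hedged sketch below verifies that the first layer’s weights respect that bound.

import math

# Sketch: check the Xavier uniform bound for the first layer
fan_out, fan_in = net[0].weight.shape  # (8, 4) here
a = math.sqrt(6 / (fan_in + fan_out))
a, bool(net[0].weight.data.abs().max() <= a)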
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
[-0.26102373  0.15249556 -0.19274211 -0.24742058]
[[42. 42. 42. 42. 42. 42. 42. 42.]]
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([ 0.38926104, -0.4023119 , -0.41848803, -0.6341998 ], dtype=float32),
 Array([[42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.]], dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42)),
])

net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<tf.Variable 'dense_6/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0.54234487,  0.2669801 , -0.2516024 ,  0.1076265 ],
       [ 0.30622882,  0.30598146, -0.4484879 ,  0.07192796],
       [ 0.36688513,  0.3838529 ,  0.40699893,  0.577269  ],
       [-0.2649538 ,  0.43839508, -0.3203209 ,  0.29825717]],
      dtype=float32)>
<tf.Variable 'dense_7/kernel:0' shape=(4, 1) dtype=float32, numpy=
array([[42.],
       [42.],
       [42.],
       [42.]], dtype=float32)>

6.3.1.1. Custom Initialization

Sometimes, the initialization methods we need are not provided by the deep learning framework. In the example below, we define an initializer for any weight parameter w using the following strange distribution:

$$
w \sim
\begin{cases}
U(5, 10)   & \text{with probability } \tfrac{1}{4} \\
0          & \text{with probability } \tfrac{1}{2} \\
U(-10, -5) & \text{with probability } \tfrac{1}{4}
\end{cases}
\tag{6.3.1}
$$

Again, we implement a my_init function to apply to net.

def my_init(module):
    if type(module) == nn.Linear:
        print("Init", *[(name, param.shape)
                        for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        module.weight.data *= module.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[ 0.0000, -7.6364, -0.0000, -6.1206],
        [ 9.3516, -0.0000,  5.1208, -8.4003]], grad_fn=<SliceBackward0>)

Note that we always have the option of setting parameters directly.

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42.0000, -6.6364,  1.0000, -5.1206])

Here we define a subclass of the Initializer class. Usually, we only need to implement the _init_weight function which takes a tensor argument (data) and assigns to it the desired initialized values.

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
array([[-6.0683527,  8.991421 , -0.       ,  0.       ],
       [ 6.4198647, -9.728567 , -8.057975 ,  0.       ]])

Note that we always have the option of setting parameters directly.

net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42.      ,  9.991421,  1.      ,  1.      ])

JAX initialization functions take as arguments the PRNGKey, shape, and dtype. Here we implement the function my_init that returns a desired tensor given the shape and data type.

def my_init(key, shape, dtype=jnp.float_):
    data = jax.random.uniform(key, shape, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[ 0.        -5.891962 ]
 [ 0.        -9.597271 ]
 [-5.809202   6.3091564]
 [ 0.         0.       ]]

When initializing parameters in JAX and Flax, the dictionary of parameters returned has a flax.core.frozen_dict.FrozenDict type. It is not advisable in the JAX ecosystem to alter the values of an array directly; hence the data structures are generally immutable. One might use params.unfreeze() to make changes.
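For completeness, here is a minimal sketch of that workflow, assuming the parameters returned by net.init form a FrozenDict (recent Flax releases may already return a plain dict, in which case unfreezing is unnecessary):

from flax.core import freeze

# Sketch: unfreeze, functionally update one entry, and re-freeze
raw = params.unfreeze()  # plain, mutable Python dict
kernel = raw['params']['layers_0']['kernel']
raw['params']['layers_0']['kernel'] = kernel.at[0, 0].set(42.0)
params = freeze(raw)
params['params']['layers_0']['kernel'][0, 0]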

Here we define a subclass of Initializer and implement the __call__ function that returns a desired tensor given the shape and data type.

class MyInit(tf.keras.initializers.Initializer):
    def __call__(self, shape, dtype=None):
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        factor = (tf.abs(data) >= 5)
        factor = tf.cast(factor, tf.float32)
        return data * factor

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1),
])

net(X)
print(net.layers[1].weights[0])
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[-8.454213 , -0.       ,  0.       , -0.       ],
       [-9.362183 ,  0.       ,  0.       ,  0.       ],
       [-0.       , -9.406505 , -0.       , -0.       ],
       [ 6.2464294, -0.       , -0.       , -9.80323  ]], dtype=float32)>

Note that we always have the option of setting parameters directly.

net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[42.       ,  1.       ,  1.       ,  1.       ],
       [-8.362183 ,  1.       ,  1.       ,  1.       ],
       [ 1.       , -8.406505 ,  1.       ,  1.       ],
       [ 7.2464294,  1.       ,  1.       , -8.80323  ]], dtype=float32)>

6.3.2. Summary

We can initialize parameters using built-in and custom initializers.

6.3.3. Exercises

Look up the online documentation for more built-in initializers.