.. _sec_lazy_init:
Lazy Initialization
===================
So far, it might seem that we got away with being sloppy in setting up
our networks. Specifically, we did the following unintuitive things,
which might not seem like they should work:

- We defined the network architectures without specifying the input
  dimensionality.
- We added layers without specifying the output dimension of the
  previous layer.
- We even “initialized” these parameters before providing enough
  information to determine how many parameters our models should
  contain.

You might be surprised that our code runs at all. After all, there is no
way the deep learning framework could tell what the input dimensionality
of a network would be. The trick here is that the framework *defers
initialization*, waiting until the first time we pass data through the
model, to infer the sizes of each layer on the fly.
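
To make the idea concrete, here is a minimal framework-free sketch of deferred initialization. The class name ``LazyDense`` and its attributes are hypothetical and chosen purely for illustration; real frameworks implement the mechanism differently and far more efficiently.

```python
import random


class LazyDense:
    """Hypothetical fully connected layer that defers weight creation."""

    def __init__(self, num_outputs):
        self.num_outputs = num_outputs
        self.weight = None  # Shape unknown until we see data
        self.bias = [0.0] * num_outputs

    def __call__(self, X):
        if self.weight is None:
            # Infer the input dimension from the first batch
            num_inputs = len(X[0])
            self.weight = [[random.gauss(0, 0.01) for _ in range(num_inputs)]
                           for _ in range(self.num_outputs)]
        # Compute X @ W^T + b with plain Python lists
        return [[sum(w * x for w, x in zip(row, example)) + b
                 for row, b in zip(self.weight, self.bias)]
                for example in X]


layer = LazyDense(256)
print(layer.weight)  # None: nothing has been allocated yet
X = [[random.random() for _ in range(20)] for _ in range(2)]
Y = layer(X)
print(len(layer.weight), len(layer.weight[0]))  # 256 20
```

The layer only needs its output size at construction time. The input size is read off the first batch it receives.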
Later on, when working with convolutional neural networks, this
technique will become even more convenient since the input
dimensionality (e.g., the resolution of an image) will affect the
dimensionality of each subsequent layer. Hence, the ability to set
parameters without knowing the dimensionality at the time of writing
the code can greatly simplify the task of specifying and subsequently
modifying our models. Next, we go deeper into the mechanics of
initialization.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
.. raw:: html
.. raw:: html
To begin, let’s instantiate an MLP.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(10)])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(10),
])
.. raw:: html
.. raw:: html
At this point, the network cannot possibly know the dimensions of the
input layer’s weights because the input dimension remains unknown.
.. raw:: html
.. raw:: html
Consequently, the framework has not yet initialized any parameters. We
confirm this by attempting to access the parameters below.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net[0].weight
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
<UninitializedParameter>
.. raw:: html
.. raw:: html
Consequently, the framework has not yet initialized any parameters. We
confirm this by attempting to access the parameters below.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
print(net.collect_params)
print(net.collect_params())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
<bound method Block.collect_params of Sequential(
(0): Dense(-1 -> 256, Activation(relu))
(1): Dense(-1 -> 10, linear)
)>
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
Note that while the parameter objects exist, the input dimension to each
layer is listed as -1. MXNet uses the special value -1 to indicate that
the parameter dimension remains unknown. At this point, attempts to
access ``net[0].weight.data()`` would trigger a runtime error stating
that the network must be initialized before the parameters can be
accessed. Now let’s see what happens when we attempt to initialize
parameters via the ``initialize`` method.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.initialize()
net.collect_params()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
As we can see, nothing has changed. When input dimensions are unknown,
calls to ``initialize`` do not truly initialize the parameters. Instead,
the call registers with MXNet that we wish to initialize the parameters
(and, optionally, according to which distribution).
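
The behavior described above can be mimicked with a small framework-free sketch. The class ``DeferredParameter`` and its methods are hypothetical and are not MXNet's API. The point is only that ``initialize`` records the requested initializer, while the values are created once the unknown dimension has been filled in.

```python
import random


class DeferredParameter:
    """Hypothetical parameter whose unknown dimension is marked by -1."""

    def __init__(self, shape):
        self.shape = tuple(shape)
        self.init_fn = None  # Initializer requested by the user, if any
        self.data = None  # Values, created only once the shape is known

    def initialize(self, init_fn=lambda: random.gauss(0, 0.01)):
        # Record the request; create values only if the shape is complete
        self.init_fn = init_fn
        if -1 not in self.shape:
            self._materialize()

    def set_unknown_dim(self, size):
        # Fill in the missing dimension, then run any pending initializer
        self.shape = tuple(size if d == -1 else d for d in self.shape)
        if self.init_fn is not None:
            self._materialize()

    def _materialize(self):
        rows, cols = self.shape
        self.data = [[self.init_fn() for _ in range(cols)]
                     for _ in range(rows)]


w = DeferredParameter((256, -1))
w.initialize()
print(w.shape, w.data is None)  # (256, -1) True
w.set_unknown_dim(20)
print(w.shape, len(w.data), len(w.data[0]))  # (256, 20) 256 20
```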
.. raw:: html
.. raw:: html
As mentioned in :numref:`subsec_param-access`, parameters and the
network definition are decoupled in JAX and Flax, and the user handles
both manually. Flax models are stateless, hence there is no
``parameters`` attribute.
.. raw:: html
.. raw:: html
Consequently, the framework has not yet initialized any parameters. We
confirm this by attempting to access the parameters below.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
[net.layers[i].get_weights() for i in range(len(net.layers))]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[], []]
Note that the layer objects exist but their weights are empty. Using
``net.get_weights()`` would throw an error since the weights have not
been initialized yet.
.. raw:: html
.. raw:: html
Next, let’s pass data through the network to make the framework finally
initialize parameters.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.rand(2, 20)
net(X)
net[0].weight.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([256, 20])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.random.uniform(size=(2, 20))
net(X)
net.collect_params()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
sequential0_ (
Parameter dense0_weight (shape=(256, 20), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, 256), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
params = net.init(d2l.get_key(), jnp.zeros((2, 20)))
jax.tree_util.tree_map(lambda x: x.shape, params).tree_flatten_with_keys()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(((DictKey(key='params'),
{'layers_0': {'bias': (256,), 'kernel': (20, 256)},
'layers_2': {'bias': (10,), 'kernel': (256, 10)}}),),
('params',))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.random.uniform((2, 20))
net(X)
[w.shape for w in net.get_weights()]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[(20, 256), (256,), (256, 10), (10,)]
.. raw:: html
.. raw:: html
As soon as we know the input dimensionality, 20, the framework can
identify the shape of the first layer’s weight matrix by plugging in the
value of 20. Having recognized the first layer’s shape, the framework
proceeds to the second layer, and so on through the computational graph
until all shapes are known. Note that in this case, only the first layer
requires lazy initialization, but the framework initializes
sequentially. Once all parameter shapes are known, the framework can
finally initialize the parameters.
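
The sequential shape inference described above can be sketched in a few lines of plain Python. The function name ``infer_weight_shapes`` is hypothetical, and the ``(num_outputs, num_inputs)`` layout follows the PyTorch and MXNet outputs shown earlier; TensorFlow and Flax store the transposed ``(num_inputs, num_outputs)`` layout.

```python
def infer_weight_shapes(num_inputs, layer_widths):
    """Return (num_outputs, num_inputs) weight shapes for stacked dense layers."""
    shapes = []
    for num_outputs in layer_widths:
        shapes.append((num_outputs, num_inputs))
        # The output of this layer is the input of the next one
        num_inputs = num_outputs
    return shapes


print(infer_weight_shapes(20, [256, 10]))  # [(256, 20), (10, 256)]
```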
.. raw:: html
.. raw:: html
The following method passes dummy inputs through the network for a dry
run to infer all parameter shapes and subsequently initializes the
parameters. It will be used later when default random initializations
are not desired.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
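@d2l.add_to_class(d2l.Module)  #@save
def apply_init(self, inputs, init=None):
    # Dry run: a forward pass lets the framework infer all parameter shapes
    self.forward(*inputs)
    # Apply the custom initializer, if any, now that the shapes are known
    if init is not None:
        self.net.apply(init)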
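
As a usage illustration, the same dry-run pattern can be written without the ``d2l`` wrapper. This is a sketch under the assumption that PyTorch is installed. The helper name ``init_normal`` and the standard deviation of 0.01 are chosen here for illustration and are not prescribed by the text.

```python
import torch
from torch import nn


def init_normal(module):
    # Applies only to fully connected layers; other modules are left as is
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        nn.init.zeros_(module.bias)


net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
net(torch.zeros(2, 20))  # Dry run: materializes all parameter shapes
net.apply(init_normal)  # Now the custom initializer can fill in values
print(net[0].weight.shape)  # torch.Size([256, 20])
```

The dry run comes first because the custom initializer needs concrete weight tensors to write into.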
.. raw:: html
.. raw:: html
Parameter initialization in Flax is always done manually and handled by
the user. The following method takes a dummy input and a key dictionary
as arguments. The key dictionary holds the RNG keys for initializing the
model parameters, as well as the dropout RNG key used to generate the
dropout mask for models with dropout layers. More about dropout will be
covered later in :numref:`sec_dropout`. Ultimately, the method
initializes the model and returns the parameters. We have been using it
under the hood in the previous sections as well.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
@d2l.add_to_class(d2l.Module)  #@save
def apply_init(self, dummy_input, key):
    params = self.init(key, *dummy_input)  # dummy_input tuple unpacked
    return params
.. raw:: html
.. raw:: html
Summary
-------
Lazy initialization can be convenient, allowing the framework to infer
parameter shapes automatically, making it easy to modify architectures
and eliminating one common source of errors. We can pass data through
the model to make the framework finally initialize parameters.
Exercises
---------
1. What happens if you specify the input dimensions to the first layer
   but not to subsequent layers? Do you get immediate initialization?
2. What happens if you specify mismatching dimensions?
3. What would you need to do if you have input of varying
   dimensionality? Hint: look at parameter tying.
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html