Custom Layers
=============

One factor behind deep learning’s success is the availability of a wide range of layers that can be composed in creative ways to design architectures suitable for a wide variety of tasks. For instance, researchers have invented layers specifically for handling images, text, looping over sequential data, and performing dynamic programming. Sooner or later, you will need a layer that does not exist yet in the deep learning framework. In these cases, you must build a custom layer. In this section, we show you how.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from d2l import torch as d2l

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    from mxnet import np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import jax
    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l
Layers without Parameters
-------------------------

To start, we construct a custom layer that does not have any parameters of its own. This should look familiar if you recall our introduction to modules in :numref:`sec_model_construction`. The following ``CenteredLayer`` class simply subtracts the mean of all elements of its input from each element. To build it, we only need to inherit from the base layer class and implement the forward propagation function.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class CenteredLayer(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, X):
            return X - X.mean()

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class CenteredLayer(nn.Block):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, X):
            return X - X.mean()

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class CenteredLayer(nn.Module):
        def __call__(self, X):
            return X - X.mean()
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # tf.keras.layers.Layer is the Keras base layer class
    class CenteredLayer(tf.keras.layers.Layer):
        def __init__(self):
            super().__init__()

        def call(self, X):
            # Subtract the mean over all elements of X
            return X - tf.reduce_mean(X)
Let’s verify that our layer works as intended by feeding some data through it.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    layer = CenteredLayer()
    layer(torch.tensor([1.0, 2, 3, 4, 5]))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([-2., -1., 0., 1., 2.])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    layer = CenteredLayer()
    layer(np.array([1.0, 2, 3, 4, 5]))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([-2., -1., 0., 1., 2.])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    layer = CenteredLayer()
    layer(jnp.array([1.0, 2, 3, 4, 5]))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Array([-2., -1., 0., 1., 2.], dtype=float32)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    layer = CenteredLayer()
    layer(tf.constant([1.0, 2, 3, 4, 5]))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    <tf.Tensor: shape=(5,), dtype=float32, numpy=array([-2., -1.,  0.,  1.,  2.], dtype=float32)>
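Note that ``X.mean()`` averages over *all* elements of its input, so for a two-dimensional batch ``CenteredLayer`` subtracts a single scalar rather than one mean per feature. The following framework-independent NumPy sketch (the function names are hypothetical, not part of any library) contrasts this with a per-feature variant that averages over the batch axis only.

.. code:: python

    import numpy as np

    def centered_global(X):
        # Mirrors CenteredLayer: subtract one scalar, the mean of all elements
        return X - X.mean()

    def centered_per_feature(X):
        # Hypothetical variant: subtract each column's mean (batch axis 0)
        return X - X.mean(axis=0, keepdims=True)

    X = np.array([[1.0, 10.0],
                  [3.0, 30.0]])
    print(centered_global(X))
    print(centered_per_feature(X))

Here the global version subtracts 11 from every entry, while the per-feature version subtracts 2 from the first column and 20 from the second, so each column of its output has mean 0.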
We can now incorporate our layer as a component in constructing more complex models.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential(nn.LazyLinear(128), CenteredLayer())

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential()
    net.add(nn.Dense(128), CenteredLayer())
    net.initialize()

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential([nn.Dense(128), CenteredLayer()])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])
As an extra sanity check, we can send random data through the network and check that the mean is in fact 0. Because we are dealing with floating-point numbers, we may still see a very small nonzero number due to rounding error.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    Y = net(torch.rand(4, 8))
    Y.mean()

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor(-6.5193e-09, grad_fn=<MeanBackward0>)

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    Y = net(np.random.rand(4, 8))
    Y.mean()

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array(3.783498e-10)
Here we use the ``init_with_output`` method, which returns both the output of the network and its initialized parameters. In this case we only need the output.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    Y, _ = net.init_with_output(d2l.get_key(),
                                jax.random.uniform(d2l.get_key(), (4, 8)))
    Y.mean()

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Array(-3.7252903e-09, dtype=float32)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    Y = net(tf.random.uniform((4, 8)))
    tf.reduce_mean(Y)
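The tiny nonzero means printed above come from floating-point rounding: subtracting the mean and then averaging again cannot be carried out exactly in 32-bit arithmetic. A minimal NumPy sketch, assuming float32 inputs to mimic the frameworks' default precision, reproduces the effect outside any framework.

.. code:: python

    import numpy as np

    rng = np.random.default_rng(0)
    # float32 mimics the default precision used by the frameworks above
    X = rng.random((4, 8), dtype=np.float32)
    Y = X - X.mean()  # what CenteredLayer computes

    residual = float(Y.mean())
    # Exactly 0 in exact arithmetic; in float32 it is tiny, possibly nonzero
    print(residual)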
Layers with Parameters
----------------------

Now that we know how to define simple layers, let’s move on to defining layers with parameters that can be adjusted through training. We can use built-in functions to create parameters, which provide some basic housekeeping functionality. In particular, they govern access, initialization, sharing, saving, and loading of model parameters. This way, among other benefits, we will not need to write custom serialization routines for every custom layer.

Now let’s implement our own version of the fully connected layer. Recall that this layer requires two parameters: one for the weight and one for the bias. In this implementation, we bake the ReLU activation into the layer. The layer is configured by ``in_units`` and ``units``, which denote the number of inputs and outputs, respectively. The argument order differs between the framework versions below, and the TensorFlow version takes only ``units``: it infers the number of inputs from the first input it receives, inside its ``build`` method.
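Before turning to the framework code, it may help to see the computation on its own. The following NumPy sketch (``dense_relu`` is a hypothetical helper, not a library API) computes :math:`\max(\mathbf{X}\mathbf{W} + \mathbf{b}, 0)` for a minibatch ``X`` of shape ``(batch, in_units)``, a weight of shape ``(in_units, units)``, and a bias of shape ``(units,)`` that is broadcast across the rows.

.. code:: python

    import numpy as np

    def dense_relu(X, W, b):
        # X: (batch, in_units), W: (in_units, units), b: (units,)
        # b is broadcast to every row of X @ W; ReLU zeroes negative entries
        return np.maximum(X @ W + b, 0.0)

    # A tiny hand-checkable case: one example, two inputs, two outputs
    X = np.array([[1.0, 2.0]])
    W = np.array([[1.0, -1.0],
                  [1.0,  1.0]])
    b = np.array([0.0, -10.0])
    print(dense_relu(X, W, b))  # pre-activation [3, -9]; after ReLU: [3, 0]

    # Random parameters with the shapes used below: 5 inputs, 3 outputs
    rng = np.random.default_rng(0)
    X_batch = rng.random((2, 5))
    W_rand = rng.standard_normal((5, 3))
    b_rand = rng.standard_normal(3)
    print(dense_relu(X_batch, W_rand, b_rand).shape)  # (2, 3)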
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyLinear(nn.Module):
        def __init__(self, in_units, units):
            super().__init__()
            # Wrapping tensors in nn.Parameter registers them as trainable
            # parameters of the module
            self.weight = nn.Parameter(torch.randn(in_units, units))
            self.bias = nn.Parameter(torch.randn(units,))

        def forward(self, X):
            # Use the parameters directly (not .data) so that autograd can
            # propagate gradients to them during training
            linear = torch.matmul(X, self.weight) + self.bias
            return F.relu(linear)

Next, we instantiate the ``MyLinear`` class and access its model parameters.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    linear = MyLinear(5, 3)
    linear.weight

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Parameter containing:
    tensor([[ 0.4783, 0.4284, -0.0899],
            [-0.6347, 0.2913, -0.0822],
            [-0.4325, -0.1645, -0.3274],
            [ 1.1898, 0.6482, -1.2384],
            [-0.1479, 0.0264, -0.9597]], requires_grad=True)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyDense(nn.Block):
        def __init__(self, units, in_units, **kwargs):
            super().__init__(**kwargs)
            self.weight = self.params.get('weight', shape=(in_units, units))
            self.bias = self.params.get('bias', shape=(units,))

        def forward(self, x):
            # Fetch the parameter values stored on the same device as x
            linear = np.dot(x, self.weight.data(ctx=x.ctx)) + self.bias.data(
                ctx=x.ctx)
            return npx.relu(linear)

Next, we instantiate the ``MyDense`` class and access its model parameters.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense = MyDense(units=3, in_units=5)
    dense.params

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    mydense0\_ (
      Parameter mydense0_weight (shape=(5, 3), dtype=<class 'numpy.float32'>)
      Parameter mydense0_bias (shape=(3,), dtype=<class 'numpy.float32'>)
    )
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyDense(nn.Module):
        in_units: int
        units: int

        def setup(self):
            self.weight = self.param('weight',
                                     nn.initializers.normal(stddev=1),
                                     (self.in_units, self.units))
            # Parameter shapes are tuples, so the bias shape is (units,)
            self.bias = self.param('bias', nn.initializers.zeros,
                                   (self.units,))

        def __call__(self, X):
            linear = jnp.matmul(X, self.weight) + self.bias
            return nn.relu(linear)

Next, we instantiate the ``MyDense`` class and access its model parameters.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense = MyDense(5, 3)
    params = dense.init(d2l.get_key(), jnp.zeros((3, 5)))
    params

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    FrozenDict({
        params: {
            weight: Array([[-0.23823419, -0.70915407, 0.72494346],
                   [ 0.2568525 , -0.20872341, -0.8993567 ],
                   [ 0.80883664, 0.16673394, 0.75610644],
                   [-0.35652584, 0.13841456, -1.0971175 ],
                   [ 0.3117082 , 1.2280334 , -1.0946037 ]], dtype=float32),
            bias: Array([0., 0., 0.], dtype=float32),
        },
    })
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # tf.keras.layers.Layer is the Keras base layer class
    class MyDense(tf.keras.layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.units = units

        def build(self, X_shape):
            # Called once, on the first input: the number of inputs is read
            # from the last dimension of X_shape
            self.weight = self.add_weight(name='weight',
                shape=[X_shape[-1], self.units],
                initializer=tf.random_normal_initializer())
            self.bias = self.add_weight(
                name='bias', shape=[self.units],
                initializer=tf.zeros_initializer())

        def call(self, X):
            linear = tf.matmul(X, self.weight) + self.bias
            return tf.nn.relu(linear)

Next, we instantiate the ``MyDense`` class and access its model parameters. Because ``build`` creates the weights only on the first call, we pass a batch of data through the layer before reading them.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense = MyDense(3)
    dense(tf.random.uniform((2, 5)))
    dense.get_weights()

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [array([[-0.01007051, -0.05935554, 0.03142897],
            [ 0.02453684, -0.01833588, -0.03096254],
            [-0.09680572, -0.01736571, -0.00858052],
            [-0.02245625, 0.02958351, -0.05780673],
            [ 0.03997313, 0.01949595, -0.00150928]], dtype=float32),
     array([0., 0., 0.], dtype=float32)]
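The TensorFlow version above creates its weight lazily: ``build`` runs on the first call, once the input shape is known, which is why ``MyDense`` needs only ``units``. ``nn.LazyLinear``, used earlier in the PyTorch example, defers initialization in a similar way. The following NumPy class is a hypothetical sketch of this deferred-initialization pattern, not a framework API; its initial standard deviation of 0.05 is chosen to mirror the default of ``tf.random_normal_initializer``.

.. code:: python

    import numpy as np

    class LazyDense:
        """Hypothetical sketch of deferred initialization (not a framework API).

        The number of inputs is unknown until the first batch arrives, so the
        parameters are created in ``build`` on the first call, much like a
        Keras layer's ``build`` method.
        """

        def __init__(self, units, seed=0):
            self.units = units
            self.rng = np.random.default_rng(seed)
            self.weight = None  # not created yet: input size unknown
            self.bias = None

        def build(self, input_shape):
            in_units = input_shape[-1]
            # Normal init with std 0.05, mirroring tf.random_normal_initializer()
            self.weight = self.rng.normal(0.0, 0.05, size=(in_units, self.units))
            self.bias = np.zeros(self.units)

        def __call__(self, X):
            if self.weight is None:  # first call: infer shapes, create params
                self.build(X.shape)
            return np.maximum(X @ self.weight + self.bias, 0.0)

    lazy = LazyDense(3)
    print(lazy.weight)        # None
    Y = lazy(np.ones((2, 5)))
    print(lazy.weight.shape)  # (5, 3)

Later calls reuse the parameters created on the first call, so every subsequent batch must have the same number of input features.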
We can directly carry out forward propagation calculations using custom layers.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    linear(torch.rand(2, 5))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[0.0000, 0.9316, 0.0000],
            [0.1808, 1.4208, 0.0000]])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense.initialize()
    dense(np.random.uniform(size=(2, 5)))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[0. , 0.01633355, 0. ],
           [0. , 0.01581812, 0. ]])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense.apply(params, jax.random.uniform(d2l.get_key(), (2, 5)))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Array([[0.3850514 , 0. , 0.49188882],
           [0.46509624, 0.26056105, 0. ]], dtype=float32)

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    dense(tf.random.uniform((2, 5)))
We can also construct models using custom layers. Once defined, they can be used just like the built-in fully connected layer.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
    net(torch.rand(2, 64))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[ 0.0000],
            [13.0800]])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential()
    net.add(MyDense(8, in_units=64), MyDense(1, in_units=8))
    net.initialize()
    net(np.random.uniform(size=(2, 64)))

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[0.06508517],
           [0.0615553 ]])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential([MyDense(64, 8), MyDense(8, 1)])
    Y, _ = net.init_with_output(d2l.get_key(),
                                jax.random.uniform(d2l.get_key(), (2, 64)))
    Y

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Array([[0.],
           [0.]], dtype=float32)

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
    net(tf.random.uniform((2, 64)))
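Because each of these custom layers applies ReLU, the final layer also clips negative pre-activations to exactly 0, which is why some of the outputs above are 0 (and one reason a network's output layer is often left without an activation). The NumPy sketch below, with a hypothetical helper and randomly drawn weights standing in for real ones, stacks two such layers with the same shapes (64 inputs, 8 hidden units, 1 output) and checks that the output equals the clipped pre-activation.

.. code:: python

    import numpy as np

    def dense_relu(X, W, b):
        # Fully connected layer with ReLU baked in, as in MyLinear/MyDense
        return np.maximum(X @ W + b, 0.0)

    rng = np.random.default_rng(0)
    W1, b1 = rng.standard_normal((64, 8)), np.zeros(8)
    W2, b2 = rng.standard_normal((8, 1)), np.zeros(1)

    X = rng.random((2, 64))
    H = dense_relu(X, W1, b1)  # hidden activations, shape (2, 8)
    pre = H @ W2 + b2          # final pre-activation, shape (2, 1)
    Y = dense_relu(H, W2, b2)  # network output
    print(H.shape, Y.shape)    # (2, 8) (2, 1)
    # Any negative entry of pre shows up as exactly 0 in Y
    print(pre.ravel(), Y.ravel())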
Summary
-------

We can design custom layers via the basic layer class. This allows us to define flexible new layers that behave differently from any existing layers in the library. Once defined, custom layers can be invoked in arbitrary contexts and architectures. Layers can have local parameters, which can be created through built-in functions.

Exercises
---------

1. Design a layer that takes an input and computes a tensor reduction, i.e., it returns :math:`y_k = \sum_{i, j} W_{ijk} x_i x_j`.
2. Design a layer that returns the leading half of the Fourier coefficients of the data.
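As a starting point for the exercises, the NumPy sketch below checks the two computations on their own, before they are wrapped in a layer class; the function names are hypothetical. The tensor reduction of Exercise 1 maps directly onto ``np.einsum`` and is verified against an explicit double loop. For Exercise 2, the sketch assumes one possible reading of "leading half": the first :math:`\lfloor n/2 \rfloor` coefficients of the discrete Fourier transform along the last axis, computed with ``np.fft.fft``.

.. code:: python

    import numpy as np

    def tensor_reduction(x, W):
        # Exercise 1: y_k = sum_{i,j} W[i, j, k] * x[i] * x[j]
        return np.einsum('ijk,i,j->k', W, x, x)

    def leading_half_fourier(x):
        # Exercise 2 (one possible reading): first half of the DFT
        # coefficients along the last axis
        coeffs = np.fft.fft(x, axis=-1)
        return coeffs[..., : x.shape[-1] // 2]

    rng = np.random.default_rng(0)
    x = rng.standard_normal(4)
    W = rng.standard_normal((4, 4, 2))
    y = tensor_reduction(x, W)

    # Cross-check einsum against an explicit double loop
    y_loop = np.zeros(2)
    for i in range(4):
        for j in range(4):
            y_loop += W[i, j] * x[i] * x[j]
    print(y, y_loop)

Turning either function into a layer then amounts to calling it from ``forward`` (or ``__call__``/``call``), with ``W`` registered as a parameter in the framework of your choice.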