.. _sec_transposed_conv:

Transposed Convolution
======================
The CNN layers we have seen so far, such as convolutional layers
(:numref:`sec_conv_layer`) and pooling layers
(:numref:`sec_pooling`), typically reduce (downsample) the spatial
dimensions (height and width) of the input or keep them unchanged. In
semantic segmentation, which classifies every pixel, it is convenient
if the spatial dimensions of the input and output are the same. For
example, the channel dimension at one output pixel can then hold the
classification results for the input pixel at the same spatial
position.

To achieve this, especially after CNN layers have reduced the spatial
dimensions, we can use another type of CNN layer that increases
(upsamples) the spatial dimensions of intermediate feature maps. In
this section, we introduce *transposed convolution*, also called
*fractionally-strided convolution* :cite:`Dumoulin.Visin.2016`, which
reverses the spatial downsampling performed by convolution: it restores
the input shape, though in general not the input values.
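As a minimal illustration in one dimension (a NumPy sketch; the helpers
``conv1d`` and ``trans_conv1d`` are hypothetical and assume no
padding), a stride-2 cross-correlation shrinks a length-8 signal to
length 4, and a stride-2 transposed convolution with a kernel of the
same size maps the result back to length 8:

.. code:: python

    import numpy as np

    def conv1d(x, k, stride=1):
        # Hypothetical helper: unpadded 1-D cross-correlation with a stride
        n_out = (len(x) - len(k)) // stride + 1
        return np.array([x[i * stride:i * stride + len(k)] @ k
                         for i in range(n_out)])

    def trans_conv1d(y, k, stride=1):
        # Hypothetical helper: each y[i] adds a scaled copy of k at offset
        # i * stride, so the output length is (len(y) - 1) * stride + len(k)
        out = np.zeros((len(y) - 1) * stride + len(k))
        for i, v in enumerate(y):
            out[i * stride:i * stride + len(k)] += v * k
        return out

    x = np.arange(8.0)
    k = np.array([1.0, 1.0])
    y = conv1d(x, k, stride=2)        # y is [1, 5, 9, 13]
    z = trans_conv1d(y, k, stride=2)  # z is [1, 1, 5, 5, 9, 9, 13, 13]

Only the length is restored: ``z`` repeats each pairwise sum rather
than recovering ``x``. With a kernel of ones and a stride equal to the
kernel size, as here, the transposed convolution amounts to
nearest-neighbor upsampling; in a network the kernel is learned
instead.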
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    import torch
    from torch import nn
    from d2l import torch as d2l
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    from mxnet import init, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l
    
    npx.set_np()
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
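    def trans_conv(X, K):
        h, w = K.shape
        # The output grows by (kernel size - 1) in each spatial dimension
        Y = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                # Input element X[i, j] adds a scaled copy of K whose
                # top-left corner sits at position (i, j) of the output
                Y[i: i + h, j: j + w] += X[i, j] * K
        return Y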
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
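    def trans_conv(X, K):
        h, w = K.shape
        # The output grows by (kernel size - 1) in each spatial dimension
        Y = np.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                # Input element X[i, j] adds a scaled copy of K whose
                # top-left corner sits at position (i, j) of the output
                Y[i: i + h, j: j + w] += X[i, j] * K
        return Y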
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    trans_conv(X, K)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[ 0.,  0.,  1.],
            [ 0.,  4.,  6.],
            [ 4., 12.,  9.]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = np.array([[0.0, 1.0], [2.0, 3.0]])
    K = np.array([[0.0, 1.0], [2.0, 3.0]])
    trans_conv(X, K)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[ 0.,  0.,  1.],
           [ 0.,  4.,  6.],
           [ 4., 12.,  9.]])
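The ``trans_conv`` loop above scatters each input element into the
output. An equivalent gather view, sketched here in NumPy under the
assumption of stride 1 and no padding, pads the input with
``k - 1`` rows and columns of zeros on every side and cross-correlates
it with the kernel rotated by 180 degrees. The helper names
``corr2d_np`` and ``trans_conv_via_corr`` are illustrative and not
part of ``d2l``.

.. code:: python

    import numpy as np

    def corr2d_np(X, K):
        # Plain stride-1 cross-correlation without padding
        h, w = K.shape
        Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
        for i in range(Y.shape[0]):
            for j in range(Y.shape[1]):
                Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
        return Y

    def trans_conv_via_corr(X, K):
        h, w = K.shape
        # Zero-pad by (kernel size - 1) on every side ...
        X_pad = np.pad(X, ((h - 1, h - 1), (w - 1, w - 1)))
        # ... and correlate with the kernel flipped along both axes
        return corr2d_np(X_pad, np.flip(K))

    X = np.array([[0.0, 1.0], [2.0, 3.0]])
    K = np.array([[0.0, 1.0], [2.0, 3.0]])
    trans_conv_via_corr(X, K)  # [[0, 0, 1], [0, 4, 6], [4, 12, 9]], as above

For a stride :math:`s > 1`, the same view additionally inserts
:math:`s - 1` zeros between neighboring input elements before the
correlation, which is where the name *fractionally-strided
convolution* comes from.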
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
    tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
    tconv.weight.data = K
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[[[ 0.,  0.,  1.],
              [ 0.,  4.,  6.],
              [ 4., 12.,  9.]]]], grad_fn=)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
    tconv = nn.Conv2DTranspose(1, kernel_size=2)
    tconv.initialize(init.Constant(K))
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[[[ 0.,  0.,  1.],
             [ 0.,  4.,  6.],
             [ 4., 12.,  9.]]]])
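In these calls the weight has shape (1, 1, 2, 2). For PyTorch's
``nn.ConvTranspose2d`` with the default ``groups=1``, the documented
weight layout is (input channels, output channels, kernel height,
kernel width). A minimal NumPy sketch of the multi-channel case under
that layout, assuming stride 1 and no padding (the function names are
hypothetical), sums single-channel transposed convolutions over the
input channels for each output channel:

.. code:: python

    import numpy as np

    def trans_conv_single(X, K):
        # Single-channel transposed convolution (same loop as trans_conv)
        h, w = K.shape
        Y = np.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                Y[i:i + h, j:j + w] += X[i, j] * K
        return Y

    def trans_conv_multi(X, K):
        # X: (c_in, n_h, n_w); K: (c_in, c_out, k_h, k_w)
        c_in, c_out = K.shape[0], K.shape[1]
        # Output channel o accumulates contributions from every input channel
        return np.stack([
            sum(trans_conv_single(X[c], K[c, o]) for c in range(c_in))
            for o in range(c_out)])

    X = np.arange(8.0).reshape(2, 2, 2)  # 2 input channels
    K = np.ones((2, 3, 2, 2))            # 2 input, 3 output channels
    trans_conv_multi(X, K).shape         # (3, 3, 3)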
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, padding=1, bias=False)
    tconv.weight.data = K
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[[[4.]]]], grad_fn=)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    tconv = nn.Conv2DTranspose(1, kernel_size=2, padding=1)
    tconv.initialize(init.Constant(K))
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[[[4.]]]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
    tconv.weight.data = K
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[[[0., 0., 0., 1.],
              [0., 0., 2., 3.],
              [0., 2., 0., 3.],
              [4., 6., 6., 9.]]]], grad_fn=)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    tconv = nn.Conv2DTranspose(1, kernel_size=2, strides=2)
    tconv.initialize(init.Constant(K))
    tconv(X)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[[[0., 0., 0., 1.],
             [0., 0., 2., 3.],
             [0., 2., 0., 3.],
             [4., 6., 6., 9.]]]])
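The padding and stride results above can be reproduced from scratch.
In this single-channel NumPy sketch (``trans_conv_2d`` is a
hypothetical name), the stride places the scattered copies of the
kernel ``stride`` positions apart, and the padding removes ``padding``
rows and columns from every border of the result, the opposite of how
padding enlarges the input of an ordinary convolution:

.. code:: python

    import numpy as np

    def trans_conv_2d(X, K, stride=1, padding=0):
        # Hypothetical single-channel transposed convolution
        h, w = K.shape
        n_h, n_w = X.shape
        # Uncropped output size: (n - 1) * stride + kernel size
        Y = np.zeros(((n_h - 1) * stride + h, (n_w - 1) * stride + w))
        for i in range(n_h):
            for j in range(n_w):
                # Copies of K start `stride` positions apart
                r, c = i * stride, j * stride
                Y[r:r + h, c:c + w] += X[i, j] * K
        # Padding crops rows and columns from each border (skipped for 0,
        # because Y[0:-0] would be empty)
        if padding > 0:
            Y = Y[padding:-padding, padding:-padding]
        return Y

    X = np.array([[0.0, 1.0], [2.0, 3.0]])
    K = np.array([[0.0, 1.0], [2.0, 3.0]])
    trans_conv_2d(X, K, padding=1)  # [[4.]]
    trans_conv_2d(X, K, stride=2)   # the 4 x 4 result shown above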
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = torch.rand(size=(1, 10, 16, 16))
    conv = nn.Conv2d(10, 20, kernel_size=5, padding=2, stride=3)
    tconv = nn.ConvTranspose2d(20, 10, kernel_size=5, padding=2, stride=3)
    tconv(conv(X)).shape == X.shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    True
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = np.random.uniform(size=(1, 10, 16, 16))
    conv = nn.Conv2D(20, kernel_size=5, padding=2, strides=3)
    tconv = nn.Conv2DTranspose(10, kernel_size=5, padding=2, strides=3)
    conv.initialize()
    tconv.initialize()
    tconv(conv(X)).shape == X.shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    True
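The round trip above restores the 16 x 16 shape because both layers
share the kernel size, padding, and stride. The plain-Python sketch
below spells out the size arithmetic along one spatial dimension,
assuming PyTorch's conventions without dilation: convolution maps
:math:`n` to :math:`\lfloor (n + 2p - k)/s \rfloor + 1`, and transposed
convolution maps :math:`n` to :math:`(n - 1)s - 2p + k`. The function
names are hypothetical.

.. code:: python

    def conv_out(n, k, p=0, s=1):
        # Output size of a convolution along one spatial dimension
        return (n + 2 * p - k) // s + 1

    def trans_conv_out(n, k, p=0, s=1, output_padding=0):
        # Output size of a transposed convolution along one dimension
        return (n - 1) * s - 2 * p + k + output_padding

    m = conv_out(16, k=5, p=2, s=3)   # 6
    trans_conv_out(m, k=5, p=2, s=3)  # 16, matching the check above

    # With n = 17 the floor in conv_out discards a position, so the plain
    # round trip yields 16; an extra output_padding of 1 recovers 17
    m = conv_out(17, k=5, p=2, s=3)                     # 6
    trans_conv_out(m, k=5, p=2, s=3, output_padding=1)  # 17

PyTorch's ``nn.ConvTranspose2d`` exposes this extra adjustment through
its ``output_padding`` argument, so the shape check above only holds
without it when :math:`n + 2p - k` is divisible by :math:`s`.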
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = torch.arange(9.0).reshape(3, 3)
    K = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    Y = d2l.corr2d(X, K)
    Y
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[27., 37.],
            [57., 67.]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    X = np.arange(9.0).reshape(3, 3)
    K = np.array([[1.0, 2.0], [3.0, 4.0]])
    Y = d2l.corr2d(X, K)
    Y
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[27., 37.],
           [57., 67.]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
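    def kernel2matrix(K):
        # Lay out the 2 x 2 kernel as it appears in a flattened 3 x 3 input:
        # the two kernel rows are separated by one zero (input rows are 3 wide)
        k, W = torch.zeros(5), torch.zeros((4, 9))
        k[:2], k[3:5] = K[0, :], K[1, :]
        # Row r of W places k at the flattened offset of output element r
        W[0, :5], W[1, 1:6], W[2, 3:8], W[3, 4:] = k, k, k, k
        return W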
    
    W = kernel2matrix(K)
    W
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[1., 2., 0., 3., 4., 0., 0., 0., 0.],
            [0., 1., 2., 0., 3., 4., 0., 0., 0.],
            [0., 0., 0., 1., 2., 0., 3., 4., 0.],
            [0., 0., 0., 0., 1., 2., 0., 3., 4.]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
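    def kernel2matrix(K):
        # Lay out the 2 x 2 kernel as it appears in a flattened 3 x 3 input:
        # the two kernel rows are separated by one zero (input rows are 3 wide)
        k, W = np.zeros(5), np.zeros((4, 9))
        k[:2], k[3:5] = K[0, :], K[1, :]
        # Row r of W places k at the flattened offset of output element r
        W[0, :5], W[1, 1:6], W[2, 3:8], W[3, 4:] = k, k, k, k
        return W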
    
    W = kernel2matrix(K)
    W
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[1., 2., 0., 3., 4., 0., 0., 0., 0.],
           [0., 1., 2., 0., 3., 4., 0., 0., 0.],
           [0., 0., 0., 1., 2., 0., 3., 4., 0.],
           [0., 0., 0., 0., 1., 2., 0., 3., 4.]])
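``kernel2matrix`` hard-codes a 3 x 3 input and a 2 x 2 kernel. A
hypothetical generalization to any input size (NumPy, stride 1, no
padding) writes the kernel into one zero image per output position and
flattens that image into the corresponding row of ``W``:

.. code:: python

    import numpy as np

    def kernel2matrix_general(K, n_h, n_w):
        # Hypothetical helper: W such that W @ X.reshape(-1) equals the
        # flattened stride-1, unpadded cross-correlation of X (n_h x n_w)
        k_h, k_w = K.shape
        o_h, o_w = n_h - k_h + 1, n_w - k_w + 1
        W = np.zeros((o_h * o_w, n_h * n_w))
        for i in range(o_h):
            for j in range(o_w):
                img = np.zeros((n_h, n_w))
                # Place the kernel where it sits for output element (i, j)
                img[i:i + k_h, j:j + k_w] = K
                W[i * o_w + j] = img.reshape(-1)
        return W

    K = np.array([[1.0, 2.0], [3.0, 4.0]])
    X = np.arange(9.0).reshape(3, 3)
    W = kernel2matrix_general(K, 3, 3)  # same 4 x 9 matrix as above
    (W @ X.reshape(-1)).reshape(2, 2)   # [[27, 37], [57, 67]], i.e. Y

Because the rows are built position by position, the same matrix also
gives the transposed convolution for any input size through ``W.T``.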
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    Y == torch.matmul(W, X.reshape(-1)).reshape(2, 2)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[True, True],
            [True, True]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    Y == np.dot(W, X.reshape(-1)).reshape(2, 2)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[ True,  True],
           [ True,  True]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    Z = trans_conv(Y, K)
    Z == torch.matmul(W.T, Y.reshape(-1)).reshape(3, 3)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    tensor([[True, True, True],
            [True, True, True],
            [True, True, True]])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    Z = trans_conv(Y, K)
    Z == np.dot(W.T, Y.reshape(-1)).reshape(3, 3)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    array([[ True,  True,  True],
           [ True,  True,  True],
           [ True,  True,  True]])
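The name *transposed* reflects this matrix identity: for any input
``x`` and any output-shaped ``y``, the inner product of the convolution
output with ``y`` equals the inner product of ``x`` with the transposed
convolution of ``y``, since both equal :math:`\mathbf{y}^\top
\mathbf{W} \mathbf{x}`. Consequently, the gradient of a convolution
with respect to its input is a transposed convolution with the same
kernel, and frameworks commonly implement one operation in terms of
the other. The NumPy sketch below checks the identity numerically on
random data; the helper names are illustrative.

.. code:: python

    import numpy as np

    def corr2d_np(X, K):
        # Stride-1, unpadded cross-correlation (plays the role of W @ x)
        h, w = K.shape
        Y = np.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1))
        for i in range(Y.shape[0]):
            for j in range(Y.shape[1]):
                Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
        return Y

    def trans_conv_np(Y, K):
        # Scatter-add transposed convolution (plays the role of W.T @ y)
        h, w = K.shape
        Z = np.zeros((Y.shape[0] + h - 1, Y.shape[1] + w - 1))
        for i in range(Y.shape[0]):
            for j in range(Y.shape[1]):
                Z[i:i + h, j:j + w] += Y[i, j] * K
        return Z

    rng = np.random.default_rng(0)
    K = rng.normal(size=(2, 3))
    x = rng.normal(size=(4, 5))
    y = rng.normal(size=(3, 3))          # shape of corr2d_np(x, K)
    lhs = np.sum(corr2d_np(x, K) * y)    # <W x, y>
    rhs = np.sum(x * trans_conv_np(y, K))  # <x, W^T y>
    np.isclose(lhs, rhs)                 # True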