.. _sec_dcgan:
Deep Convolutional Generative Adversarial Networks
==================================================
In :numref:`sec_basic_gan`, we introduced the basic ideas behind how
GANs work. We showed that they can draw samples from some simple,
easy-to-sample distribution, like a uniform or normal distribution, and
transform them into samples that appear to match the distribution of
some dataset. And while our example of matching a 2D Gaussian
distribution got the point across, it is not especially exciting.
In this section, we will demonstrate how you can use GANs to generate
photorealistic images. We will be basing our models on the deep
convolutional GANs (DCGAN) introduced in
:cite:t:`Radford.Metz.Chintala.2015`. We will borrow the convolutional
architectures that have proven so successful for discriminative computer
vision problems and show how, via GANs, they can be leveraged to
generate photorealistic images.
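
The code cells below assume the following imports for the PyTorch
variant (a minimal sketch; the MXNet and TensorFlow variants
analogously rely on ``mxnet`` or ``tensorflow`` and the matching
``d2l`` module):

.. code:: python

    import warnings

    import torch
    import torchvision
    from torch import nn
    from d2l import torch as d2l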
The Pokemon Dataset
-------------------

The dataset we will use is a collection of Pokemon sprites. First,
download, extract, and load this dataset. Throughout this section, each
step is shown in PyTorch, MXNet, and TensorFlow, in that order.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                               'c065c0e2593b8b161a2d7873e42418bf6a21106c')
    
    data_dir = d2l.download_extract('pokemon')
    pokemon = torchvision.datasets.ImageFolder(data_dir)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                               'c065c0e2593b8b161a2d7873e42418bf6a21106c')
    
    data_dir = d2l.download_extract('pokemon')
    pokemon = gluon.data.vision.datasets.ImageFolderDataset(data_dir)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
                               'c065c0e2593b8b161a2d7873e42418bf6a21106c')
    
    data_dir = d2l.download_extract('pokemon')
    batch_size = 256
    pokemon = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir, batch_size=batch_size, image_size=(64, 64))
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
    Found 40597 files belonging to 721 classes.
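
The same image and class counts can be read off the PyTorch
``ImageFolder`` object; a quick check, assuming the dataset was loaded
as above:

.. code:: python

    # One folder per Pokemon, so one class per folder
    print(len(pokemon), len(pokemon.classes))  # expected: 40597 721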
We resize each image into :math:`64\times 64`. The ``ToTensor``
transformation projects the pixel values into :math:`[0, 1]`, while our
generator will use the tanh function to obtain outputs in
:math:`[-1, 1]`. Therefore we normalize the data with mean :math:`0.5`
and standard deviation :math:`0.5` to match the value range.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    batch_size = 256
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.Resize((64, 64)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5)
    ])
    pokemon.transform = transformer
    data_iter = torch.utils.data.DataLoader(
        pokemon, batch_size=batch_size,
        shuffle=True, num_workers=d2l.get_dataloader_workers())
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    batch_size = 256
    transformer = gluon.data.vision.transforms.Compose([
        gluon.data.vision.transforms.Resize(64),
        gluon.data.vision.transforms.ToTensor(),
        gluon.data.vision.transforms.Normalize(0.5, 0.5)
    ])
    data_iter = gluon.data.DataLoader(
        pokemon.transform_first(transformer), batch_size=batch_size,
        shuffle=True, num_workers=d2l.get_dataloader_workers())
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def transform_func(X):
        X = X / 255.
        X = (X - 0.5) / 0.5
        return X
    
    # For TF>=2.4 use `num_parallel_calls = tf.data.AUTOTUNE`
    data_iter = pokemon.map(lambda x, y: (transform_func(x), y),
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
    data_iter = data_iter.cache().shuffle(buffer_size=1000).prefetch(
        buffer_size=tf.data.experimental.AUTOTUNE)
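
Since the generator will end with a tanh activation, the
transformations above map pixel values from :math:`[0, 1]` into
:math:`[-1, 1]`. A quick range check, sketched here for the PyTorch
``data_iter``:

.. code:: python

    # Pixel values should now lie (approximately) in [-1, 1]
    X, y = next(iter(data_iter))
    print(X.min().item(), X.max().item())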
Let's visualize the first 20 images.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    warnings.filterwarnings('ignore')
    d2l.set_figsize((4, 4))
    for X, y in data_iter:
        imgs = X[:20, :, :, :].permute(0, 2, 3, 1) / 2 + 0.5
        d2l.show_images(imgs, num_rows=4, num_cols=5)
        break
.. figure:: output_dcgan_2541de_39_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    d2l.set_figsize((4, 4))
    for X, y in data_iter:
        imgs = X[:20, :, :, :].transpose(0, 2, 3, 1) / 2 + 0.5
        d2l.show_images(imgs, num_rows=4, num_cols=5)
        break
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    [22:43:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [22:43:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [22:43:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [22:43:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    [22:43:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. figure:: output_dcgan_2541de_42_1.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    d2l.set_figsize(figsize=(4, 4))
    for X, y in data_iter.take(1):
        imgs = X[:20, :, :, :] / 2 + 0.5
        d2l.show_images(imgs, num_rows=4, num_cols=5)
.. figure:: output_dcgan_2541de_45_0.svg
The Generator
-------------

The generator needs to map a noise variable
:math:`\mathbf{z}\in\mathbb{R}^d`, a length-:math:`d` vector, to an RGB
image with width and height :math:`64\times 64`. It uses transposed
convolution layers to enlarge the input size. The basic block of the
generator contains a transposed convolution layer followed by batch
normalization and a ReLU activation.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class G_block(nn.Module):
        def __init__(self, out_channels, in_channels=3, kernel_size=4, strides=2,
                     padding=1, **kwargs):
            super(G_block, self).__init__(**kwargs)
            self.conv2d_trans = nn.ConvTranspose2d(in_channels, out_channels,
                                    kernel_size, strides, padding, bias=False)
            self.batch_norm = nn.BatchNorm2d(out_channels)
            self.activation = nn.ReLU()
    
        def forward(self, X):
            return self.activation(self.batch_norm(self.conv2d_trans(X)))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class G_block(nn.Block):
        def __init__(self, channels, kernel_size=4,
                     strides=2, padding=1, **kwargs):
            super(G_block, self).__init__(**kwargs)
            self.conv2d_trans = nn.Conv2DTranspose(
                channels, kernel_size, strides, padding, use_bias=False)
            self.batch_norm = nn.BatchNorm()
            self.activation = nn.Activation('relu')
    
        def forward(self, X):
            return self.activation(self.batch_norm(self.conv2d_trans(X)))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class G_block(tf.keras.layers.Layer):
        def __init__(self, out_channels, kernel_size=4, strides=2, padding="same",
                     **kwargs):
            super().__init__(**kwargs)
            self.conv2d_trans = tf.keras.layers.Conv2DTranspose(
                out_channels, kernel_size, strides, padding, use_bias=False)
            self.batch_norm = tf.keras.layers.BatchNormalization()
            self.activation = tf.keras.layers.ReLU()
    
        def call(self, X):
            return self.activation(self.batch_norm(self.conv2d_trans(X)))
By default, the transposed convolution layer uses a
:math:`k_h = k_w = 4` kernel, :math:`s_h = s_w = 2` strides, and
:math:`p_h = p_w = 1` padding. With an input shape of
:math:`n_h \times n_w = 16 \times 16`, the generator block will double
the input's width and height:

.. math::

   \begin{aligned}
   n_h^{'} \times n_w^{'} &= [n_h k_h - (n_h-1)(k_h-s_h) - 2p_h] \times [n_w k_w - (n_w-1)(k_w-s_w) - 2p_w]\\
     &= [k_h + s_h (n_h-1) - 2p_h] \times [k_w + s_w (n_w-1) - 2p_w]\\
     &= [4 + 2 \times (16-1) - 2] \times [4 + 2 \times (16-1) - 2]\\
     &= 32 \times 32 .
   \end{aligned}
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = torch.zeros((2, 3, 16, 16))
    g_blk = G_block(20)
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    torch.Size([2, 20, 32, 32])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = np.zeros((2, 3, 16, 16))
    g_blk = G_block(20)
    g_blk.initialize()
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (2, 20, 32, 32)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = tf.zeros((2, 16, 16, 3))  # Channel last convention
    g_blk = G_block(20)
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    TensorShape([2, 32, 32, 20])
If we change the transposed convolution layer to a :math:`4\times 4`
kernel, :math:`1\times 1` strides, and zero padding, then with an input
size of :math:`1\times 1` the output will have its width and height
each increased by 3.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = torch.zeros((2, 3, 1, 1))
    g_blk = G_block(20, strides=1, padding=0)
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    torch.Size([2, 20, 4, 4])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = np.zeros((2, 3, 1, 1))
    g_blk = G_block(20, strides=1, padding=0)
    g_blk.initialize()
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (2, 20, 4, 4)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = tf.zeros((2, 1, 1, 3))
    # `padding="valid"` corresponds to no padding
    g_blk = G_block(20, strides=1, padding="valid")
    g_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    TensorShape([2, 4, 4, 20])
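
As a sanity check, both shape calculations can be verified directly
from the transposed convolution output-size formula; a minimal sketch
in plain Python (``trans_conv_out`` is a hypothetical helper):

.. code:: python

    def trans_conv_out(n, k, s, p):
        # n' = (n - 1) * s - 2 * p + k for a transposed convolution
        return (n - 1) * s - 2 * p + k

    print(trans_conv_out(16, k=4, s=2, p=1))  # 32: the default block doubles 16
    print(trans_conv_out(1, k=4, s=1, p=0))   # 4: a 1x1 input grows by k - 1 = 3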
The generator consists of four basic blocks that increase the input's
width and height from 1 to 32. At the same time, it first projects the
latent variable into :math:`64\times 8` channels, and then halves the
channels each time. At last, a transposed convolution layer is used to
generate the output. It further doubles the width and height to match
the desired :math:`64\times 64` shape, and reduces the channel size to
:math:`3`.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_G = 64
    net_G = nn.Sequential(
        G_block(in_channels=100, out_channels=n_G*8,
                strides=1, padding=0),                  # Output: (64 * 8, 4, 4)
        G_block(in_channels=n_G*8, out_channels=n_G*4), # Output: (64 * 4, 8, 8)
        G_block(in_channels=n_G*4, out_channels=n_G*2), # Output: (64 * 2, 16, 16)
        G_block(in_channels=n_G*2, out_channels=n_G),   # Output: (64, 32, 32)
        nn.ConvTranspose2d(in_channels=n_G, out_channels=3,
                           kernel_size=4, stride=2, padding=1, bias=False),
        nn.Tanh())  # Output: (3, 64, 64)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_G = 64
    net_G = nn.Sequential()
    net_G.add(G_block(n_G*8, strides=1, padding=0),  # Output: (64 * 8, 4, 4)
              G_block(n_G*4),  # Output: (64 * 4, 8, 8)
              G_block(n_G*2),  # Output: (64 * 2, 16, 16)
              G_block(n_G),    # Output: (64, 32, 32)
              nn.Conv2DTranspose(
                  3, kernel_size=4, strides=2, padding=1, use_bias=False,
                  activation='tanh'))  # Output: (3, 64, 64)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_G = 64
    net_G = tf.keras.Sequential([
        # Output: (4, 4, 64 * 8)
        G_block(out_channels=n_G*8, strides=1, padding="valid"),
        G_block(out_channels=n_G*4), # Output: (8, 8, 64 * 4)
        G_block(out_channels=n_G*2), # Output: (16, 16, 64 * 2)
        G_block(out_channels=n_G), # Output: (32, 32, 64)
        # Output: (64, 64, 3)
        tf.keras.layers.Conv2DTranspose(
            3, kernel_size=4, strides=2, padding="same", use_bias=False,
            activation="tanh")
    ])
Generate a 100-dimensional latent variable to verify the generator's
output shape.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = torch.zeros((1, 100, 1, 1))
    net_G(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    torch.Size([1, 3, 64, 64])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = np.zeros((1, 100, 1, 1))
    net_G.initialize()
    net_G(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (1, 3, 64, 64)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = tf.zeros((1, 1, 1, 100))
    net_G(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    TensorShape([1, 64, 64, 3])
The Discriminator
-----------------

The discriminator is a normal convolutional network, except that it
uses a leaky ReLU as its activation function. Given
:math:`\alpha \in [0, 1]`, its definition is

.. math::

   \textrm{leaky ReLU}(x) = \begin{cases}x & \textrm{if}\ x > 0\\ \alpha x &\textrm{otherwise}\end{cases}.

As can be seen, it is the normal ReLU if :math:`\alpha=0`, and the
identity function if :math:`\alpha=1`. For :math:`\alpha \in (0, 1)`,
leaky ReLU is a nonlinear function that gives a non-zero output for a
negative input. It aims to fix the "dying ReLU" problem, in which a
neuron might always output a negative value and therefore cannot make
any progress since the gradient of ReLU is 0.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    alphas = [0, .2, .4, .6, .8, 1]
    x = torch.arange(-2, 1, 0.1)
    Y = [nn.LeakyReLU(alpha)(x).detach().numpy() for alpha in alphas]
    d2l.plot(x.detach().numpy(), Y, 'x', 'y', alphas)
.. figure:: output_dcgan_2541de_111_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    alphas = [0, .2, .4, .6, .8, 1]
    x = np.arange(-2, 1, 0.1)
    Y = [nn.LeakyReLU(alpha)(x).asnumpy() for alpha in alphas]
    d2l.plot(x.asnumpy(), Y, 'x', 'y', alphas)
.. figure:: output_dcgan_2541de_114_0.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    alphas = [0, .2, .4, .6, .8, 1]
    x = tf.range(-2, 1, 0.1)
    Y = [tf.keras.layers.LeakyReLU(alpha)(x).numpy() for alpha in alphas]
    d2l.plot(x.numpy(), Y, 'x', 'y', alphas)
.. figure:: output_dcgan_2541de_117_0.svg
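
For instance, with :math:`\alpha = 0.2` a negative input is scaled by
:math:`\alpha` rather than zeroed out, so its gradient stays non-zero;
a one-line check in the PyTorch variant:

.. code:: python

    # leaky_relu(-1.0) = -0.2 when alpha = 0.2; positive inputs pass through
    print(nn.LeakyReLU(0.2)(torch.tensor([-1.0, 1.0])))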
The basic block of the discriminator is a convolution layer followed by
a batch normalization layer and a leaky ReLU activation. The
hyperparameters of the convolution layer are similar to those of the
transposed convolution layer in the generator block.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class D_block(nn.Module):
        def __init__(self, out_channels, in_channels=3, kernel_size=4, strides=2,
                     padding=1, alpha=0.2, **kwargs):
            super(D_block, self).__init__(**kwargs)
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    strides, padding, bias=False)
            self.batch_norm = nn.BatchNorm2d(out_channels)
            self.activation = nn.LeakyReLU(alpha, inplace=True)
    
        def forward(self, X):
            return self.activation(self.batch_norm(self.conv2d(X)))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class D_block(nn.Block):
        def __init__(self, channels, kernel_size=4, strides=2,
                     padding=1, alpha=0.2, **kwargs):
            super(D_block, self).__init__(**kwargs)
            self.conv2d = nn.Conv2D(
                channels, kernel_size, strides, padding, use_bias=False)
            self.batch_norm = nn.BatchNorm()
            self.activation = nn.LeakyReLU(alpha)
    
        def forward(self, X):
            return self.activation(self.batch_norm(self.conv2d(X)))
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    class D_block(tf.keras.layers.Layer):
        def __init__(self, out_channels, kernel_size=4, strides=2, padding="same",
                     alpha=0.2, **kwargs):
            super().__init__(**kwargs)
            self.conv2d = tf.keras.layers.Conv2D(out_channels, kernel_size,
                                                 strides, padding, use_bias=False)
            self.batch_norm = tf.keras.layers.BatchNormalization()
            self.activation = tf.keras.layers.LeakyReLU(alpha)
    
        def call(self, X):
            return self.activation(self.batch_norm(self.conv2d(X)))
A basic block with default settings will halve the width and height of
the inputs. For example, given an input shape :math:`n_h = n_w = 16`,
with a kernel shape :math:`k_h = k_w = 4`, a stride shape
:math:`s_h = s_w = 2`, and a padding shape :math:`p_h = p_w = 1`, the
output shape will be:

.. math::

   \begin{aligned}
   n_h^{'} \times n_w^{'} &= \lfloor(n_h-k_h+2p_h+s_h)/s_h\rfloor \times \lfloor(n_w-k_w+2p_w+s_w)/s_w\rfloor\\
     &= \lfloor(16-4+2+2)/2\rfloor \times \lfloor(16-4+2+2)/2\rfloor\\
     &= 8 \times 8 .
   \end{aligned}
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = torch.zeros((2, 3, 16, 16))
    d_blk = D_block(20)
    d_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    torch.Size([2, 20, 8, 8])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = np.zeros((2, 3, 16, 16))
    d_blk = D_block(20)
    d_blk.initialize()
    d_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (2, 20, 8, 8)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = tf.zeros((2, 16, 16, 3))
    d_blk = D_block(20)
    d_blk(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    TensorShape([2, 8, 8, 20])
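
Again, the arithmetic can be confirmed directly; a minimal sketch of
the convolution output-size formula used above (``conv_out`` is a
hypothetical helper):

.. code:: python

    def conv_out(n, k, s, p):
        # n' = floor((n - k + 2 * p + s) / s) for a strided convolution
        return (n - k + 2 * p + s) // s

    print(conv_out(16, k=4, s=2, p=1))  # 8: the default block halves 16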
The discriminator is a mirror of the generator.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_D = 64
    net_D = nn.Sequential(
        D_block(n_D),  # Output: (64, 32, 32)
        D_block(in_channels=n_D, out_channels=n_D*2),  # Output: (64 * 2, 16, 16)
        D_block(in_channels=n_D*2, out_channels=n_D*4),  # Output: (64 * 4, 8, 8)
        D_block(in_channels=n_D*4, out_channels=n_D*8),  # Output: (64 * 8, 4, 4)
        nn.Conv2d(in_channels=n_D*8, out_channels=1,
                  kernel_size=4, bias=False))  # Output: (1, 1, 1)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_D = 64
    net_D = nn.Sequential()
    net_D.add(D_block(n_D),   # Output: (64, 32, 32)
              D_block(n_D*2),  # Output: (64 * 2, 16, 16)
              D_block(n_D*4),  # Output: (64 * 4, 8, 8)
              D_block(n_D*8),  # Output: (64 * 8, 4, 4)
              nn.Conv2D(1, kernel_size=4, use_bias=False))  # Output: (1, 1, 1)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    n_D = 64
    net_D = tf.keras.Sequential([
        D_block(n_D), # Output: (32, 32, 64)
        D_block(out_channels=n_D*2), # Output: (16, 16, 64 * 2)
        D_block(out_channels=n_D*4), # Output: (8, 8, 64 * 4)
        D_block(out_channels=n_D*8), # Output: (4, 4, 64 * 8)
        # Output: (1, 1, 1)
        tf.keras.layers.Conv2D(1, kernel_size=4, use_bias=False)
    ])
It uses a convolution layer with output channel 1 as the last layer to
obtain a single prediction value.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = torch.zeros((1, 3, 64, 64))
    net_D(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    torch.Size([1, 1, 1, 1])
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = np.zeros((1, 3, 64, 64))
    net_D.initialize()
    net_D(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (1, 1, 1, 1)
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    x = tf.zeros((1, 64, 64, 3))
    net_D(x).shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    TensorShape([1, 1, 1, 1])
Training
--------

Compared to the basic GAN in :numref:`sec_basic_gan`, we use the same
learning rate for both the generator and the discriminator since they
are similar to each other. In addition, we change :math:`\beta_1` in
Adam from 0.9 to 0.5. It decreases the smoothness of the momentum, the
exponentially weighted moving average of past gradients, to take care
of the rapidly changing gradients that arise because the generator and
the discriminator fight with each other. Besides, the randomly
generated noise ``Z`` is a 4-D tensor, and we use a GPU to accelerate
the computation.
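The training function below relies on ``d2l.update_D`` and
``d2l.update_G`` saved in :numref:`sec_basic_gan`. For reference, here
is a PyTorch sketch consistent with how they are used in this section
(the exact saved implementations may differ slightly):

.. code:: python

    def update_D(X, Z, net_D, net_G, loss, trainer_D):
        """Update the discriminator on one batch of real and fake images."""
        batch_size = X.shape[0]
        ones = torch.ones((batch_size,), device=X.device)
        zeros = torch.zeros((batch_size,), device=X.device)
        trainer_D.zero_grad()
        real_Y = net_D(X)
        # Detach the generator output so only `net_D` receives gradients
        fake_X = net_G(Z)
        fake_Y = net_D(fake_X.detach())
        loss_D = (loss(real_Y, ones.reshape(real_Y.shape)) +
                  loss(fake_Y, zeros.reshape(fake_Y.shape))) / 2
        loss_D.backward()
        trainer_D.step()
        return loss_D

    def update_G(Z, net_D, net_G, loss, trainer_G):
        """Update the generator by trying to fool the discriminator."""
        batch_size = Z.shape[0]
        ones = torch.ones((batch_size,), device=Z.device)
        trainer_G.zero_grad()
        fake_X = net_G(Z)
        fake_Y = net_D(fake_X)
        loss_G = loss(fake_Y, ones.reshape(fake_Y.shape))
        loss_G.backward()
        trainer_G.step()
        return loss_G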
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
              device=d2l.try_gpu()):
        loss = nn.BCEWithLogitsLoss(reduction='sum')
        for w in net_D.parameters():
            nn.init.normal_(w, 0, 0.02)
        for w in net_G.parameters():
            nn.init.normal_(w, 0, 0.02)
        net_D, net_G = net_D.to(device), net_G.to(device)
        trainer_hp = {'lr': lr, 'betas': [0.5,0.999]}
        trainer_D = torch.optim.Adam(net_D.parameters(), **trainer_hp)
        trainer_G = torch.optim.Adam(net_G.parameters(), **trainer_hp)
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                                legend=['discriminator', 'generator'])
        animator.fig.subplots_adjust(hspace=0.3)
        for epoch in range(1, num_epochs + 1):
            # Train one epoch
            timer = d2l.Timer()
            metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
            for X, _ in data_iter:
                batch_size = X.shape[0]
                Z = torch.normal(0, 1, size=(batch_size, latent_dim, 1, 1))
                X, Z = X.to(device), Z.to(device)
                metric.add(d2l.update_D(X, Z, net_D, net_G, loss, trainer_D),
                           d2l.update_G(Z, net_D, net_G, loss, trainer_G),
                           batch_size)
            # Show generated examples
            Z = torch.normal(0, 1, size=(21, latent_dim, 1, 1), device=device)
            # Rescale the synthetic images from [-1, 1] to [0, 1] for display
            fake_x = net_G(Z).permute(0, 2, 3, 1) / 2 + 0.5
            imgs = torch.cat(
                [torch.cat([
                    fake_x[i * 7 + j].cpu().detach() for j in range(7)], dim=1)
                 for i in range(len(fake_x)//7)], dim=0)
            animator.axes[1].cla()
            animator.axes[1].imshow(imgs)
            # Show the losses
            loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
            animator.add(epoch, (loss_D, loss_G))
        print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
              f'{metric[2] / timer.stop():.1f} examples/sec on {str(device)}')
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
              device=d2l.try_gpu()):
        loss = gluon.loss.SigmoidBCELoss()
        net_D.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)
        net_G.initialize(init=init.Normal(0.02), force_reinit=True, ctx=device)
        trainer_hp = {'learning_rate': lr, 'beta1': 0.5}
        trainer_D = gluon.Trainer(net_D.collect_params(), 'adam', trainer_hp)
        trainer_G = gluon.Trainer(net_G.collect_params(), 'adam', trainer_hp)
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                                legend=['discriminator', 'generator'])
        animator.fig.subplots_adjust(hspace=0.3)
        for epoch in range(1, num_epochs + 1):
            # Train one epoch
            timer = d2l.Timer()
            metric = d2l.Accumulator(3)  # loss_D, loss_G, num_examples
            for X, _ in data_iter:
                batch_size = X.shape[0]
                Z = np.random.normal(0, 1, size=(batch_size, latent_dim, 1, 1))
                X, Z = X.as_in_ctx(device), Z.as_in_ctx(device)
                metric.add(d2l.update_D(X, Z, net_D, net_G, loss, trainer_D),
                           d2l.update_G(Z, net_D, net_G, loss, trainer_G),
                           batch_size)
            # Show generated examples
            Z = np.random.normal(0, 1, size=(21, latent_dim, 1, 1), ctx=device)
            # Rescale the synthetic images from [-1, 1] to [0, 1] for display
            fake_x = net_G(Z).transpose(0, 2, 3, 1) / 2 + 0.5
            imgs = np.concatenate(
                [np.concatenate([fake_x[i * 7 + j] for j in range(7)], axis=1)
                 for i in range(len(fake_x)//7)], axis=0)
            animator.axes[1].cla()
            animator.axes[1].imshow(imgs.asnumpy())
            # Show the losses
            loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
            animator.add(epoch, (loss_D, loss_G))
        print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
              f'{metric[2] / timer.stop():.1f} examples/sec on {str(device)}')
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    def train(net_D, net_G, data_iter, num_epochs, lr, latent_dim,
              device=d2l.try_gpu()):
        loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    
        for w in net_D.trainable_variables:
            w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
        for w in net_G.trainable_variables:
            w.assign(tf.random.normal(mean=0, stddev=0.02, shape=w.shape))
    
        optimizer_hp = {"lr": lr, "beta_1": 0.5, "beta_2": 0.999}
        optimizer_D = tf.keras.optimizers.Adam(**optimizer_hp)
        optimizer_G = tf.keras.optimizers.Adam(**optimizer_hp)
    
        animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                                xlim=[1, num_epochs], nrows=2, figsize=(5, 5),
                                legend=['discriminator', 'generator'])
        animator.fig.subplots_adjust(hspace=0.3)
    
        for epoch in range(1, num_epochs + 1):
            # Train one epoch
            timer = d2l.Timer()
            metric = d2l.Accumulator(3) # loss_D, loss_G, num_examples
            for X, _ in data_iter:
                batch_size = X.shape[0]
                Z = tf.random.normal(mean=0, stddev=1,
                                     shape=(batch_size, 1, 1, latent_dim))
                metric.add(d2l.update_D(X, Z, net_D, net_G, loss, optimizer_D),
                           d2l.update_G(Z, net_D, net_G, loss, optimizer_G),
                           batch_size)
    
            # Show generated examples
            Z = tf.random.normal(mean=0, stddev=1, shape=(21, 1, 1, latent_dim))
            # Rescale the synthetic images from [-1, 1] to [0, 1] for display
            fake_x = net_G(Z) / 2 + 0.5
            imgs = tf.concat([tf.concat([fake_x[i * 7 + j] for j in range(7)],
                                        axis=1)
                              for i in range(len(fake_x) // 7)], axis=0)
            animator.axes[1].cla()
            animator.axes[1].imshow(imgs)
            # Show the losses
            loss_D, loss_G = metric[0] / metric[2], metric[1] / metric[2]
            animator.add(epoch, (loss_D, loss_G))
        print(f'loss_D {loss_D:.3f}, loss_G {loss_G:.3f}, '
              f'{metric[2] / timer.stop():.1f} examples/sec on {str(device._device_name)}')
We train the model with a small number of epochs just for
demonstration. For better performance, the variable ``num_epochs`` can
be set to a larger number.
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    latent_dim, lr, num_epochs = 100, 0.005, 20
    train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    loss_D 0.023, loss_G 7.359, 2292.7 examples/sec on cuda:0
.. figure:: output_dcgan_2541de_183_1.svg
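
Once training has finished, new images can be sampled by drawing fresh
latent vectors and decoding them with the generator; a minimal sketch
for the PyTorch variant:

.. code:: python

    # Draw new latent vectors and map the tanh output back to [0, 1]
    Z = torch.normal(0, 1, size=(20, latent_dim, 1, 1), device=d2l.try_gpu())
    fake_x = net_G(Z).permute(0, 2, 3, 1).cpu().detach() / 2 + 0.5
    d2l.show_images(fake_x, num_rows=4, num_cols=5)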
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    latent_dim, lr, num_epochs = 100, 0.005, 20
    train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    loss_D 0.035, loss_G 6.190, 2583.0 examples/sec on gpu(0)
.. figure:: output_dcgan_2541de_186_1.svg
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    latent_dim, lr, num_epochs = 100, 0.0005, 40
    train(net_D, net_G, data_iter, num_epochs, lr, latent_dim)
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    loss_D 0.161, loss_G 4.254, 2048.0 examples/sec on /GPU:0
.. figure:: output_dcgan_2541de_189_1.svg
Summary
-------

-  DCGAN architecture has four convolutional layers for the
   discriminator and four "fractionally-strided" convolutional layers
   for the generator.
-  The discriminator is a four-layer strided convolutional network with
   batch normalization and leaky ReLU activations.
-  Leaky ReLU is a nonlinear function that gives a non-zero output for
   a negative input. It aims to fix the "dying ReLU" problem and helps
   the gradients flow more easily through the architecture.