8.8. Designing Convolution Network Architectures
The previous sections have taken us on a tour of modern network design for computer vision. Common to all the work we covered is that it relied heavily on the intuition of scientists. Many of the architectures are heavily informed by human creativity and to a much lesser extent by systematic exploration of the design space that deep networks offer. Nonetheless, this network engineering approach has been tremendously successful.
Ever since AlexNet (Section 8.1) beat conventional computer vision models on ImageNet, it has become popular to construct very deep networks by stacking blocks of convolutions, all designed according to the same pattern. In particular, 3×3 convolutions were popularized by VGG networks, local nonlinearities via 1×1 convolutions and global aggregation in the head by NiN, multi-branch convolutions by GoogLeNet, residual connections by ResNets, and grouped convolutions by ResNeXt (Section 8.6.5).
Up to now we have omitted networks obtained via neural architecture search (NAS) (Liu et al., 2018, Zoph and Le, 2016). We chose to do so since their cost is usually enormous, relying on brute-force search, genetic algorithms, reinforcement learning, or some other form of hyperparameter optimization. Given a fixed search space, NAS uses a search strategy to automatically select an architecture based on the returned performance estimation. The outcome of NAS is a single network instance. EfficientNets are a notable outcome of this search (Tan and Le, 2019).
In the following we discuss an idea that is quite different from the quest for the single best network. It is computationally relatively inexpensive, it leads to scientific insights along the way, and it is quite effective in terms of the quality of outcomes. Let’s review the strategy by Radosavovic et al. (2020) to design network design spaces. The strategy combines the strengths of manual design and NAS. It accomplishes this by operating on distributions of networks and optimizing the distributions so as to obtain good performance for entire families of networks. Its outcomes are RegNets, specifically RegNetX and RegNetY, plus a range of guiding principles for the design of performant CNNs.
8.8.1. The AnyNet Design Space
The description below closely follows the reasoning in Radosavovic et al. (2020) with some abbreviations to make it fit in the scope of the book. To begin, we need a template for the family of networks to explore. One of the commonalities of the designs in this chapter is that the networks consist of a stem, a body and a head. The stem performs initial image processing, often through convolutions with a larger window size. The body consists of multiple blocks, carrying out the bulk of the transformations needed to go from raw images to object representations. Lastly, the head converts this into the desired outputs, such as via a softmax regressor for multiclass classification. The body, in turn, consists of multiple stages, operating on the image at decreasing resolutions. In fact, both the stem and each subsequent stage quarter the spatial resolution. Lastly, each stage consists of one or more blocks. This pattern is common to all networks, from VGG to ResNeXt. Indeed, for the design of generic AnyNet networks, Radosavovic et al. (2020) used the ResNeXt block of Fig. 8.6.5.
Fig. 8.8.1 The AnyNet design space. The numbers along each arrow indicate the number of channels and the spatial resolution at that point in the network.
Let’s review the structure outlined in Fig. 8.8.1 in detail. As mentioned, an AnyNet consists of a stem, body, and head. The stem takes as its input RGB images (3 channels), using a 3×3 convolution with a stride of 2, followed by a batch norm, to halve the resolution from r×r to r/2×r/2. Moreover, it generates c_0 channels that serve as input to the body.

Since the network is designed to work well with ImageNet images of shape 224×224×3, the body serves to reduce the spatial resolution to 7×7 through four stages, each with an eventual stride of 2. Lastly, the head employs an entirely standard design via global average pooling, similar to NiN (Section 8.3), followed by a fully connected layer to emit an n-dimensional vector for n-class classification.

Most of the relevant design decisions are inherent to the body of the network. It proceeds in stages, where each stage is composed of the same type of ResNeXt blocks as we discussed in Section 8.6.5. The design there is again entirely generic: we begin with a block that halves the resolution by using a stride of 2. To match this, the residual branch of the ResNeXt block needs to pass through a 1×1 convolution. This block is followed by a variable number of additional ResNeXt blocks that leave both resolution and the number of channels unchanged. Note that a common design practice is to add a slight bottleneck within convolutional blocks: with bottleneck ratio k_i ≥ 1 we afford only c_i/k_i channels inside each block of stage i. Lastly, since we are dealing with ResNeXt blocks, we also need to pick the group width g_i for the grouped convolutions at stage i.

This seemingly generic design space nonetheless provides us with many parameters: for each stage i we can set the block width (number of channels) c_i, the depth d_i (number of ResNeXt blocks), the bottleneck ratio k_i, and the group width g_i; in addition, we choose the number of channels c_0 produced by the stem. For four stages this adds up to 17 parameters, far too many combinations to explore exhaustively, so we will need an effective way of reducing the design space. But first, let’s implement the generic AnyNet design.
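To make the parameterization concrete, the snippet below sketches one hypothetical configuration; the numbers are arbitrary illustrations rather than values from Radosavovic et al. (2020). Each per-stage tuple follows the (depth, width, group width, bottleneck multiplier) order that the stage method of the implementation below consumes via its arch argument.

# One hypothetical AnyNet configuration (illustrative values only).
# Each stage is (depth d_i, width c_i, group width g_i, bottleneck multiplier),
# matching the argument order of stage() below.
arch = ((2, 32, 16, 1),    # stage 1
        (4, 64, 16, 1),    # stage 2
        (6, 128, 16, 1),   # stage 3
        (2, 256, 16, 1))   # stage 4
stem_channels = 32         # c_0, the width produced by the stem
# Four stages with four parameters each, plus the stem width: 17 choices in total.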
# JAX/Flax implementation
class AnyNet(d2l.Classifier):
    arch: tuple
    stem_channels: int
    lr: float = 0.1
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = self.create_net()

    def stem(self, num_channels):
        # The stem: a 3x3 convolution with stride 2 that halves the
        # resolution, followed by batch normalization and a ReLU.
        return nn.Sequential([
            nn.Conv(num_channels, kernel_size=(3, 3), strides=(2, 2),
                    padding=(1, 1)),
            nn.BatchNorm(not self.training),
            nn.relu
        ])
Each stage consists of depth ResNeXt blocks, where num_channels specifies the block width. Note that the first block halves the height and width of input images.
# PyTorch implementation
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
    blk = []
    for i in range(depth):
        if i == 0:
            # The first block halves the resolution and adjusts the width.
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                                        use_1x1conv=True, strides=2))
        else:
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul))
    return nn.Sequential(*blk)
# MXNet implementation
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
    net = nn.Sequential()
    for i in range(depth):
        if i == 0:
            net.add(d2l.ResNeXtBlock(
                num_channels, groups, bot_mul, use_1x1conv=True, strides=2))
        else:
            net.add(d2l.ResNeXtBlock(num_channels, groups, bot_mul))
    return net
# JAX/Flax implementation
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
    blk = []
    for i in range(depth):
        if i == 0:
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                                        use_1x1conv=True, strides=(2, 2),
                                        training=self.training))
        else:
            blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                                        training=self.training))
    return nn.Sequential(blk)
# TensorFlow implementation
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
    net = tf.keras.models.Sequential()
    for i in range(depth):
        if i == 0:
            net.add(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
                                     use_1x1conv=True, strides=2))
        else:
            net.add(d2l.ResNeXtBlock(num_channels, groups, bot_mul))
    return net
Putting the network stem, body, and head together, we complete the implementation of AnyNet.
# PyTorch implementation
@d2l.add_to_class(AnyNet)
def __init__(self, arch, stem_channels, lr=0.1, num_classes=10):
    super(AnyNet, self).__init__()
    self.save_hyperparameters()
    self.net = nn.Sequential(self.stem(stem_channels))
    for i, s in enumerate(arch):
        self.net.add_module(f'stage{i+1}', self.stage(*s))
    # Head: global average pooling followed by a linear classifier.
    self.net.add_module('head', nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
        nn.LazyLinear(num_classes)))
    self.net.apply(d2l.init_cnn)
# MXNet implementation
@d2l.add_to_class(AnyNet)
def __init__(self, arch, stem_channels, lr=0.1, num_classes=10):
    super(AnyNet, self).__init__()
    self.save_hyperparameters()
    self.net = nn.Sequential()
    self.net.add(self.stem(stem_channels))
    for i, s in enumerate(arch):
        self.net.add(self.stage(*s))
    self.net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
    self.net.initialize(init.Xavier())
# JAX/Flax implementation
@d2l.add_to_class(AnyNet)
def create_net(self):
    net = nn.Sequential([self.stem(self.stem_channels)])
    for i, s in enumerate(self.arch):
        net.layers.extend([self.stage(*s)])
    # Head: global average pooling, flattening, and a dense classifier.
    net.layers.extend([nn.Sequential([
        lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
                              strides=x.shape[1:3], padding='valid'),
        lambda x: x.reshape((x.shape[0], -1)),
        nn.Dense(self.num_classes)])])
    return net
# TensorFlow implementation
@d2l.add_to_class(AnyNet)
def __init__(self, arch, stem_channels, lr=0.1, num_classes=10):
    super(AnyNet, self).__init__()
    self.save_hyperparameters()
    self.net = tf.keras.models.Sequential(self.stem(stem_channels))
    for i, s in enumerate(arch):
        self.net.add(self.stage(*s))
    self.net.add(tf.keras.models.Sequential([
        tf.keras.layers.GlobalAvgPool2D(),
        tf.keras.layers.Dense(units=num_classes)]))
8.8.2. Distributions and Parameters of Design Spaces
As just discussed in Section 8.8.1, parameters of a design space are hyperparameters of networks in that design space. Consider the problem of identifying good parameters in the AnyNet design space. We could try finding the single best parameter choice for a given amount of computation (e.g., FLOPs and compute time). If we allowed for even only two possible choices for each parameter, we would have to explore 2^17 = 131072 combinations to find the best one. This is clearly infeasible owing to its exorbitant cost, and it would tell us little about why a particular combination works. Instead of hunting for the single best needle in the haystack, we will try to improve the haystack itself, that is, the distribution of networks that the design space produces. For this strategy to work, we rely on the following assumptions:
We assume that general design principles actually exist, so that many networks satisfying these requirements should offer good performance. Consequently, identifying a distribution over networks can be a sensible strategy. In other words, we assume that there are many good needles in the haystack.
We need not train networks to convergence before we can assess whether a network is good. Instead, it is sufficient to use the intermediate results as reliable guidance for final accuracy. Using (approximate) proxies to optimize an objective is referred to as multi-fidelity optimization (Forrester et al., 2007). Consequently, design optimization is carried out, based on the accuracy achieved after only a few passes through the dataset, reducing the cost significantly.
Results obtained at a smaller scale (for smaller networks) generalize to larger ones. Consequently, optimization is carried out for networks that are structurally similar, but with a smaller number of blocks, fewer channels, etc. Only in the end will we need to verify that the so-found networks also offer good performance at scale.
Aspects of the design can be approximately factorized so that it is possible to infer their effect on the quality of the outcome somewhat independently. In other words, the optimization problem is moderately easy.
These assumptions allow us to test many networks cheaply. In particular, we can sample uniformly from the space of configurations and evaluate their performance, and we can then assess the quality of a choice of parameters by reviewing the distribution of errors achievable with the resulting networks. Denote by F(e, p) the cumulative distribution function (CDF) of the errors committed by networks drawn from a distribution p over the design space, i.e., F(e, p) = P_{net ~ p}(error(net) ≤ e). Our goal is now to find a distribution p over networks such that most networks have a very low error rate. Since evaluating F exactly is impossible, we draw n networks from p, record their errors e_1, ..., e_n, and work with the empirical CDF F̂(e) = (1/n) Σ_i 1(e_i ≤ e) instead. Whenever the CDF for one set of choices majorizes (or matches) another CDF, it follows that its choice of parameters is superior (or indifferent). Accordingly, Radosavovic et al. (2020) experimented with a shared bottleneck ratio k_i = k for all stages i, which removes three of the four bottleneck parameters, and likewise with a shared group width g_i = g. Neither restriction degrades the empirical error distribution, as Fig. 8.8.2 illustrates, while together they remove six degrees of freedom from the design space.

Fig. 8.8.2 Comparing error empirical distribution functions of design spaces.
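As a rough illustration of this procedure, the sketch below samples configurations uniformly from a toy grid and compares design spaces via their empirical error CDFs. It is not part of the d2l library: the search grid is arbitrary and quick_train is a hypothetical helper standing in for the few-epoch proxy training described above.

import random

def sample_config(num_stages=4):
    """Draw one AnyNet configuration uniformly from a toy search grid."""
    # A real search would also discard invalid combinations, e.g., where
    # the group width does not divide the block width.
    return tuple((random.choice([1, 2, 4, 8]),        # depth d_i
                  random.choice([16, 32, 64, 128]),   # width c_i
                  random.choice([1, 2, 4, 8, 16]),    # group width g_i
                  random.choice([0.25, 0.5, 1.0]))    # bottleneck multiplier
                 for _ in range(num_stages))

def empirical_cdf(errors, e):
    """Fraction of sampled networks whose error is at most e."""
    return sum(err <= e for err in errors) / len(errors)

# Hypothetical usage: train each sampled network for a few epochs only
# (the multi-fidelity proxy) and record its validation error.
# errors = [quick_train(AnyNet(cfg, stem_channels=32)) for cfg in
#           (sample_config() for _ in range(100))]
# The design space whose empirical CDF lies above the other at every error
# level is the better one to sample from.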
Next we look for ways to reduce the multitude of potential choices for width and depth of the stages. It is a reasonable assumption that, as we go deeper, the number of channels should increase, i.e., c_i ≥ c_{i-1}. Likewise, it is equally reasonable to assume that the stages should become deeper as we progress through the network, i.e., d_i ≥ d_{i-1}. Imposing these constraints once again does not degrade the empirical error distribution, while further shrinking the design space (Radosavovic et al., 2020).
8.8.3. RegNet
The resulting constrained design space, from which RegNets are drawn, consists of simple networks following easy-to-interpret design principles:

Share the bottleneck ratio k_i = k for all stages i;
Share the group width g_i = g for all stages i;
Increase the network width across stages: c_i ≤ c_{i+1};
Increase the network depth across stages: d_i ≤ d_{i+1}.
This leaves us with a final set of choices: how to pick the specific values for the above parameters of an eventual network. By studying the best-performing members of the constrained design space, Radosavovic et al. (2020) observe that the block width ought to increase roughly linearly with depth and that a bottleneck ratio of k = 1, i.e., no bottleneck at all, works best. We recommend the interested reader reviews further details in the design of specific networks for different amounts of computation by perusing Radosavovic et al. (2020). For instance, an effective 32-layer RegNetX variant is given by k = 1 (no bottleneck), a group width of g = 16, and two stages with c_1 = 32 and c_2 = 80 channels that are d_1 = 4 and d_2 = 6 blocks deep, respectively, on top of a 32-channel stem.
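The class realizing this variant does not survive in this extract; the sketch below shows how it could be expressed on top of the PyTorch AnyNet implementation above, using exactly the parameters just listed (the class name RegNetX32 follows the book’s naming, but treat the code as a reconstruction rather than the original cell).

class RegNetX32(AnyNet):
    """A 32-layer RegNetX instance: no bottleneck, group width 16, and two
    stages with (depth, width) of (4, 32) and (6, 80), respectively."""
    def __init__(self, lr=0.1, num_classes=10):
        stem_channels, groups, bot_mul = 32, 16, 1
        depths, channels = (4, 6), (32, 80)
        super().__init__(
            ((depths[0], channels[0], groups, bot_mul),
             (depths[1], channels[1], groups, bot_mul)),
            stem_channels, lr, num_classes)

# Hypothetical usage, mirroring the layer summaries shown below:
# RegNetX32().layer_summary((1, 1, 96, 96))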
We can see that each RegNetX stage progressively reduces resolution and increases the number of output channels. The layer summaries below, one per framework implementation and computed for a single 96×96 grayscale input, confirm this.
Sequential output shape: torch.Size([1, 32, 48, 48])
Sequential output shape: torch.Size([1, 32, 24, 24])
Sequential output shape: torch.Size([1, 80, 12, 12])
Sequential output shape: torch.Size([1, 10])
Sequential output shape: (1, 32, 48, 48)
Sequential output shape: (1, 32, 24, 24)
Sequential output shape: (1, 80, 12, 12)
GlobalAvgPool2D output shape: (1, 80, 1, 1)
Dense output shape: (1, 10)
Sequential output shape: (1, 48, 48, 32)
Sequential output shape: (1, 24, 24, 32)
Sequential output shape: (1, 12, 12, 80)
Sequential output shape: (1, 10)
8.8.5. Discussion
With desirable inductive biases (assumptions or preferences) like locality and translation invariance (Section 7.1) for vision, CNNs have been the dominant architectures in this area. This remained the case from LeNet up until Transformers (Section 11.7) (Dosovitskiy et al., 2021, Touvron et al., 2021) started surpassing CNNs in terms of accuracy. While much of the recent progress in terms of vision Transformers can be backported into CNNs (Liu et al., 2022), it is only possible at a higher computational cost. Just as importantly, recent hardware optimizations (NVIDIA Ampere and Hopper) have only widened the gap in favor of Transformers.
It is worth noting that Transformers have a significantly lower degree of inductive bias towards locality and translation invariance than CNNs. The fact that learned structures nonetheless prevailed is due, not least, to the availability of large image collections such as LAION-400m and LAION-5B (Schuhmann et al., 2022), the latter containing up to 5 billion images. Quite surprisingly, some of the more relevant work in this context even includes MLPs (Tolstikhin et al., 2021).
In sum, vision Transformers (Section 11.8) by now lead in terms of state-of-the-art performance in large-scale image classification, showing that scalability trumps inductive biases (Dosovitskiy et al., 2021). This includes pretraining large-scale Transformers (Section 11.9) with multi-head self-attention (Section 11.5). We invite the readers to dive into these chapters for a much more detailed discussion.
8.8.6. Exercises
Increase the number of stages to four. Can you design a deeper RegNetX that performs better?
De-ResNeXt-ify RegNets by replacing the ResNeXt block with the ResNet block. How does your new model perform?
Implement multiple instances of a “VioNet” family by violating the design principles of RegNetX. How do they perform? Which of the per-stage parameters (depth d_i, width c_i, group width g_i, bottleneck ratio k_i) is the most important factor?
Your goal is to design the “perfect” MLP. Can you use the design principles introduced above to find good architectures? Is it possible to extrapolate from small to large networks?