.. _sec_object-detection-dataset:
The Object Detection Dataset
============================
There are no small datasets such as MNIST and Fashion-MNIST in the field
of object detection. In order to quickly demonstrate object detection
models, we collected and labeled a small dataset. First, we took photos
of free bananas from our office and generated 1000 banana images with
different rotations and sizes. Then we placed each banana image at a
random position on some background image. In the end, we labeled
bounding boxes for those bananas on the images.
Downloading the Dataset
-----------------------
The banana detection dataset, with all the images and CSV label files,
can be downloaded directly from the Internet.
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    %matplotlib inline
    import os
    import pandas as pd
    import torch
    import torchvision
    from d2l import torch as d2l

    #@save
    # Register the banana detection archive (URL plus a verification checksum)
    # in the d2l data hub so that `d2l.download_extract` can fetch it later.
    d2l.DATA_HUB['banana-detection'] = (
        d2l.DATA_URL + 'banana-detection.zip',
        '5de26c8fce5ccdea9f91267273464dc968d20d72')
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    %matplotlib inline
    import os
    import pandas as pd
    from mxnet import gluon, image, np, npx
    from d2l import mxnet as d2l

    # Activate MXNet's NumPy-compatible array mode
    npx.set_np()

    #@save
    # Register the banana detection archive (URL plus a verification checksum)
    # in the d2l data hub so that `d2l.download_extract` can fetch it later.
    d2l.DATA_HUB['banana-detection'] = (
        d2l.DATA_URL + 'banana-detection.zip',
        '5de26c8fce5ccdea9f91267273464dc968d20d72')
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    def read_data_bananas(is_train=True):
        """Read the banana detection dataset images and labels."""
        # Download (if needed) and unpack the archive; returns the local root dir
        data_dir = d2l.download_extract('banana-detection')
        csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                                 else 'bananas_val', 'label.csv')
        csv_data = pd.read_csv(csv_fname)
        # Index rows by image file name so iterrows() yields (name, label row)
        csv_data = csv_data.set_index('img_name')
        images, targets = [], []
        for img_name, target in csv_data.iterrows():
            images.append(torchvision.io.read_image(
                os.path.join(data_dir, 'bananas_train' if is_train else
                             'bananas_val', 'images', f'{img_name}')))
            # Here `target` contains (class, upper-left x, upper-left y,
            # lower-right x, lower-right y), where all the images have the same
            # banana class (index 0)
            targets.append(list(target))
        # unsqueeze(1) adds an object dimension (one box per image); dividing by
        # 256 rescales pixel coordinates to [0, 1] since the images are 256x256
        return images, torch.tensor(targets).unsqueeze(1) / 256
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    def read_data_bananas(is_train=True):
        """Read the banana detection dataset images and labels."""
        # Download (if needed) and unpack the archive; returns the local root dir
        data_dir = d2l.download_extract('banana-detection')
        csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                                 else 'bananas_val', 'label.csv')
        csv_data = pd.read_csv(csv_fname)
        # Index rows by image file name so iterrows() yields (name, label row)
        csv_data = csv_data.set_index('img_name')
        images, targets = [], []
        for img_name, target in csv_data.iterrows():
            images.append(image.imread(
                os.path.join(data_dir, 'bananas_train' if is_train else
                             'bananas_val', 'images', f'{img_name}')))
            # Here `target` contains (class, upper-left x, upper-left y,
            # lower-right x, lower-right y), where all the images have the same
            # banana class (index 0)
            targets.append(list(target))
        # expand_dims adds an object dimension (one box per image); dividing by
        # 256 rescales pixel coordinates to [0, 1] since the images are 256x256
        return images, np.expand_dims(np.array(targets), 1) / 256
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    class BananasDataset(torch.utils.data.Dataset):
        """A customized dataset to load the banana detection dataset.

        Reads all images and labels into memory via `read_data_bananas` and
        serves (image, label) pairs to a `DataLoader`.
        """
        def __init__(self, is_train):
            # `features`: list of image tensors; `labels`: (N, 1, 5) tensor
            self.features, self.labels = read_data_bananas(is_train)
            # Plain string literals: the originals were f-strings with no
            # placeholders (useless `f` prefixes); output text is unchanged
            print('read ' + str(len(self.features)) + (' training examples' if
                  is_train else ' validation examples'))

        def __getitem__(self, idx):
            # Cast the image to float for model consumption; label unchanged
            return (self.features[idx].float(), self.labels[idx])

        def __len__(self):
            return len(self.features)
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    class BananasDataset(gluon.data.Dataset):
        """A customized dataset to load the banana detection dataset.

        Reads all images and labels into memory via `read_data_bananas` and
        serves (image, label) pairs to a Gluon `DataLoader`.
        """
        def __init__(self, is_train):
            # `features`: list of image arrays; `labels`: (N, 1, 5) array
            self.features, self.labels = read_data_bananas(is_train)
            # Plain string literals: the originals were f-strings with no
            # placeholders (useless `f` prefixes); output text is unchanged
            print('read ' + str(len(self.features)) + (' training examples' if
                  is_train else ' validation examples'))

        def __getitem__(self, idx):
            # Cast to float32 and move the last axis (channels) to the front
            return (self.features[idx].astype('float32').transpose(2, 0, 1),
                    self.labels[idx])

        def __len__(self):
            return len(self.features)
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    def load_data_bananas(batch_size):
        """Load the banana detection dataset."""
        # Build both in-memory datasets, then wrap them in iterators: the
        # training iterator shuffles, the validation iterator keeps file order.
        train_dataset = BananasDataset(is_train=True)
        val_dataset = BananasDataset(is_train=False)
        train_iter = torch.utils.data.DataLoader(
            train_dataset, batch_size, shuffle=True)
        val_iter = torch.utils.data.DataLoader(val_dataset, batch_size)
        return train_iter, val_iter
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    #@save
    def load_data_bananas(batch_size):
        """Load the banana detection dataset."""
        # Build both in-memory datasets, then wrap them in iterators: the
        # training iterator shuffles, the validation iterator keeps file order.
        train_dataset = BananasDataset(is_train=True)
        val_dataset = BananasDataset(is_train=False)
        train_iter = gluon.data.DataLoader(
            train_dataset, batch_size, shuffle=True)
        val_iter = gluon.data.DataLoader(val_dataset, batch_size)
        return train_iter, val_iter
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    batch_size, edge_size = 32, 256
    # Read one minibatch; per the recorded output, images have shape
    # (batch, 3, 256, 256) and labels have shape (batch, 1, 5)
    train_iter, _ = load_data_bananas(batch_size)
    batch = next(iter(train_iter))
    batch[0].shape, batch[1].shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading ../data/banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
    read 1000 training examples
    read 100 validation examples
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    (torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    batch_size, edge_size = 32, 256
    # Read one minibatch; per the recorded output, images have shape
    # (batch, 3, 256, 256) and labels have shape (batch, 1, 5)
    train_iter, _ = load_data_bananas(batch_size)
    batch = next(iter(train_iter))
    batch[0].shape, batch[1].shape
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    Downloading ../data/banana-detection.zip from http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip...
    [22:09:31] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    read 1000 training examples
    read 100 validation examples
.. raw:: latex
   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output
    ((32, 3, 256, 256), (32, 1, 5))
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    # Move channels last for plotting and scale pixel values into [0, 1]
    imgs = (batch[0][:10].permute(0, 2, 3, 1)) / 255
    axes = d2l.show_images(imgs, 2, 5, scale=2)
    for ax, label in zip(axes, batch[1][:10]):
        # Box coordinates are normalized; multiply by `edge_size` for pixels
        d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
.. figure:: output_object-detection-dataset_641ef0_48_0.png
.. raw:: html
    
.. raw:: html
    
.. raw:: latex
   \diilbookstyleinputcell
.. code:: python
    # Move channels last for plotting and scale pixel values into [0, 1]
    imgs = (batch[0][:10].transpose(0, 2, 3, 1)) / 255
    axes = d2l.show_images(imgs, 2, 5, scale=2)
    for ax, label in zip(axes, batch[1][:10]):
        # Box coordinates are normalized; multiply by `edge_size` for pixels
        d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
.. figure:: output_object-detection-dataset_641ef0_51_0.png
.. raw:: html
    
.. raw:: html
    
.. raw:: html
    
`Discussions `__
.. raw:: html
    
.. raw:: html
    
`Discussions `__
.. raw:: html
    
.. raw:: html