.. _sec_ndarray:
Data Manipulation
=================
In order to get anything done, we need some way to store and manipulate
data. Generally, there are two important things we need to do with data:
(i) acquire them; and (ii) process them once they are inside the
computer. There is no point in acquiring data without some way to store
them, so to start, let’s get our hands dirty with :math:`n`-dimensional
arrays, which we also call *tensors*. If you already know the NumPy
scientific computing package, this will be a breeze. In all modern deep
learning frameworks, the *tensor class* (``ndarray`` in MXNet,
``Tensor`` in PyTorch and TensorFlow, ``Array`` in JAX) resembles NumPy’s ``ndarray``,
with a few killer features added. First, the tensor class supports
automatic differentiation. Second, it leverages GPUs to accelerate
numerical computation, whereas NumPy only runs on CPUs. These properties
make neural networks both easy to code and fast to run.
Getting Started
---------------
.. raw:: html
.. raw:: html
To start, we import the PyTorch library. Note that the package name is
``torch``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
.. raw:: html
.. raw:: html
To start, we import the ``np`` (``numpy``) and ``npx``
(``numpy_extension``) modules from MXNet. Here, the ``np`` module
includes functions supported by NumPy, while the ``npx`` module contains
a set of extensions developed to empower deep learning within a
NumPy-like environment. When using tensors, we almost always invoke the
``set_np`` function, which ensures that the other components of MXNet
process tensors in a NumPy-compatible way.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import jax
from jax import numpy as jnp
.. raw:: html
.. raw:: html
To start, we import ``tensorflow``. For brevity, practitioners often
assign the alias ``tf``.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
.. raw:: html
.. raw:: html
A tensor represents a (possibly multidimensional) array of numerical
values. In the one-dimensional case, i.e., when only one axis is needed
for the data, a tensor is called a *vector*. With two axes, a tensor is
called a *matrix*. With :math:`k > 2` axes, we drop the specialized
names and just refer to the object as a :math:`k^\textrm{th}`-*order
tensor*.
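To make this naming concrete, the following pure-Python sketch counts the axes of a regular nested list and names the object accordingly. The helpers ``num_axes`` and ``describe`` are hypothetical and belong to no framework; the sketch assumes every nesting level is non-empty and equally deep.

.. code:: python

    def num_axes(data):
        """Count the nesting levels (axes) of a regular nested list (hypothetical helper)."""
        k = 0
        while isinstance(data, list):
            k += 1
            data = data[0]  # assumes non-empty, equally deep nesting
        return k

    def describe(data):
        """Name an object by its number of axes, following the convention in the text."""
        k = num_axes(data)
        if k == 1:
            return "vector"
        if k == 2:
            return "matrix"
        return f"order-{k} tensor"

    print(describe([1.0, 2.0, 3.0]))                       # vector
    print(describe([[1, 2], [3, 4]]))                      # matrix
    print(describe([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]))  # order-3 tensor
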
.. raw:: html
.. raw:: html
PyTorch provides a variety of functions for creating new tensors
prepopulated with values. For example, by invoking ``arange(n)``, we can
create a vector of evenly spaced values, starting at 0 (included) and
ending at ``n`` (not included). By default, the step size is
:math:`1`. Unless otherwise specified, new tensors are stored in main
memory and designated for CPU-based computation.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = torch.arange(12, dtype=torch.float32)
x
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
Each of these values is called an *element* of the tensor. The tensor
``x`` contains 12 elements. We can inspect the total number of elements
in a tensor via its ``numel`` method.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.numel()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
12
.. raw:: html
.. raw:: html
MXNet provides a variety of functions for creating new tensors
prepopulated with values. For example, by invoking ``arange(n)``, we can
create a vector of evenly spaced values, starting at 0 (included) and
ending at ``n`` (not included). By default, the step size is
:math:`1`. Unless otherwise specified, new tensors are stored in main
memory and designated for CPU-based computation.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = np.arange(12)
x
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
Each of these values is called an *element* of the tensor. The tensor
``x`` contains 12 elements. We can inspect the total number of elements
in a tensor via its ``size`` attribute.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.size
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
12
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = jnp.arange(12)
x
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype=int32)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.size
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
12
.. raw:: html
.. raw:: html
TensorFlow provides a variety of functions for creating new tensors
prepopulated with values. For example, by invoking ``range(n)``, we can
create a vector of evenly spaced values, starting at 0 (included) and
ending at ``n`` (not included). By default, the step size is
:math:`1`. Unless otherwise specified, new tensors are stored in main
memory and designated for CPU-based computation.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = tf.range(12, dtype=tf.float32)
x
Each of these values is called an *element* of the tensor. The tensor
``x`` contains 12 elements. We can inspect the total number of elements
in a tensor via the ``size`` function.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.size(x)
.. raw:: html
.. raw:: html
We can access a tensor’s *shape* (the length along each axis) by
inspecting its ``shape`` attribute. Because we are dealing with a vector
here, the ``shape`` contains just a single element and is identical to
the size.
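The relationship between shape and size can be sketched in plain Python: the size is the product of the shape's entries, which for a vector is just its single shape component. The helper ``shape_of`` is hypothetical and assumes a regular, non-empty nested list.

.. code:: python

    import math

    def shape_of(data):
        """Return the length along each axis of a regular nested list (hypothetical helper)."""
        shape = []
        while isinstance(data, list):
            shape.append(len(data))
            data = data[0]  # assumes non-empty, equally deep nesting
        return tuple(shape)

    vector = list(range(12))
    matrix = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

    # The size is the product of the lengths along all axes.
    print(shape_of(vector), math.prod(shape_of(vector)))  # (12,) 12
    print(shape_of(matrix), math.prod(shape_of(matrix)))  # (3, 4) 12
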
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([12])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(12,)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(12,)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([12])
.. raw:: html
.. raw:: html
We can change the shape of a tensor without altering its size or values,
by invoking ``reshape``. For example, we can transform our vector ``x``
whose shape is (12,) to a matrix ``X`` with shape (3, 4). This new
tensor retains all elements but reconfigures them into a matrix. Notice
that the elements of our vector are laid out one row at a time and thus
``x[3] == X[0, 3]``.
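The row-at-a-time layout can be checked with a minimal pure-Python sketch. The helper ``reshape_rows`` is hypothetical; unlike a framework's ``reshape``, which may return a view of the original data, it copies the elements into new row lists. In row-major order, element :math:`(i, j)` of a matrix with ``cols`` columns sits at flat position ``i * cols + j``.

.. code:: python

    def reshape_rows(flat, rows, cols):
        """Lay out a flat list one row at a time (row-major order); hypothetical helper."""
        if rows * cols != len(flat):
            raise ValueError("reshape must preserve the number of elements")
        return [flat[i * cols:(i + 1) * cols] for i in range(rows)]

    x = list(range(12))
    X = reshape_rows(x, 3, 4)
    print(X)                # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
    print(x[3] == X[0][3])  # True

    # Every element (i, j) maps back to flat position i * cols + j.
    cols = 4
    print(all(x[i * cols + j] == X[i][j] for i in range(3) for j in range(cols)))  # True
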
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = x.reshape(3, 4)
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = x.reshape(3, 4)
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = x.reshape(3, 4)
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.reshape(x, (3, 4))
X
.. raw:: html
.. raw:: html
Note that passing every shape component to ``reshape`` is redundant.
Because we already know our tensor’s size, we can work out one component
of the shape given the rest. For example, given a tensor of size
:math:`n` and target shape (:math:`h`, :math:`w`), we know that
:math:`w = n/h`. To have one component of the shape inferred
automatically, we can place a ``-1`` in its position. In our case,
instead of calling ``x.reshape(3, 4)``, we could have equivalently
called ``x.reshape(-1, 4)`` or ``x.reshape(3, -1)``.
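The arithmetic behind ``-1`` can be sketched in a few lines of Python. The helper ``infer_shape`` is hypothetical: it divides the size by the product of the known components, and its error handling is our own assumption rather than any framework's exact behavior.

.. code:: python

    import math

    def infer_shape(size, shape):
        """Replace a single -1 in `shape` with the length implied by `size` (hypothetical helper)."""
        if shape.count(-1) > 1:
            raise ValueError("at most one component can be -1")
        known = math.prod(d for d in shape if d != -1)
        if -1 not in shape:
            if known != size:
                raise ValueError("target shape does not match the size")
            return tuple(shape)
        if known == 0 or size % known != 0:
            raise ValueError("size is not divisible by the given components")
        return tuple(size // known if d == -1 else d for d in shape)

    print(infer_shape(12, (-1, 4)))    # (3, 4)
    print(infer_shape(12, (3, -1)))    # (3, 4)
    print(infer_shape(12, (2, 2, -1)))  # (2, 2, 3)

Asking for ``(5, -1)`` raises an error here, because 12 elements cannot be split evenly into 5 rows.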
Practitioners often need to work with tensors initialized to contain all
0s or 1s. We can construct a tensor with all elements set to 0 and a
shape of (2, 3, 4) via the ``zeros`` function.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.zeros((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
np.zeros((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
jnp.zeros((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.zeros((2, 3, 4))
.. raw:: html
.. raw:: html
Similarly, we can create a tensor with all 1s by invoking ``ones``.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.ones((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
np.ones((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
jnp.ones((2, 3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.ones((2, 3, 4))
.. raw:: html
.. raw:: html
We often wish to sample each element randomly (and independently) from a
given probability distribution. For example, the parameters of neural
networks are often initialized randomly. The following snippet creates a
tensor with elements drawn from a standard Gaussian (normal)
distribution with mean 0 and standard deviation 1.
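Random draws become repeatable once the generator is seeded. The sketch below uses Python's standard ``random`` module rather than any deep learning framework, so its numbers will not match the framework outputs; the helper ``standard_normal`` is hypothetical and returns a nested list instead of a tensor.

.. code:: python

    import random

    def standard_normal(seed, rows, cols):
        """Draw a rows x cols nested list of N(0, 1) samples from a seeded generator."""
        rng = random.Random(seed)  # a private generator; leaves the global state untouched
        return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

    A = standard_normal(0, 3, 4)
    B = standard_normal(0, 3, 4)
    C = standard_normal(1, 3, 4)
    print(A == B)  # True: the same seed reproduces the same samples
    print(A == C)  # False: a different seed gives different samples
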
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.randn(3, 4)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[ 0.1351, -0.9099, -0.2028, 2.1937],
[-0.3200, -0.7545, 0.8086, -1.8730],
[ 0.3929, 0.4931, 0.9114, -0.7072]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
np.random.normal(0, 1, size=(3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ 2.2122064 , 1.1630787 , 0.7740038 , 0.4838046 ],
[ 1.0434403 , 0.29956347, 1.1839255 , 0.15302546],
[ 1.8917114 , -1.1688148 , -1.2347414 , 1.5580711 ]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Every call to a random function in JAX requires an explicit key;
# feeding the same key to a random function always produces the
# same sample
jax.random.normal(jax.random.PRNGKey(0), (3, 4))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[ 1.1901639 , -1.0996888 , 0.44367844, 0.5984697 ],
[-0.39189556, 0.69261974, 0.46018356, -2.068578 ],
[-0.21438177, -0.9898306 , -0.6789304 , 0.27362573]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.random.normal(shape=[3, 4])
.. raw:: html
.. raw:: html
Finally, we can construct tensors by supplying the exact values for each
element as (possibly nested) Python lists containing numerical literals.
Here, we construct a matrix with a list of lists, where the outermost
list corresponds to axis 0 and the inner lists correspond to axis 1.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.tensor([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[2, 1, 4, 3],
[1, 2, 3, 4],
[4, 3, 2, 1]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[2., 1., 4., 3.],
[1., 2., 3., 4.],
[4., 3., 2., 1.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
jnp.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[2, 1, 4, 3],
[1, 2, 3, 4],
[4, 3, 2, 1]], dtype=int32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.constant([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
.. raw:: html
.. raw:: html
Indexing and Slicing
--------------------
As with Python lists, we can access tensor elements by indexing
(starting with 0). To access an element based on its position relative
to the end of an axis, we can use negative indexing. We can also
access whole ranges of indices via slicing (e.g., ``X[start:stop]``),
where the returned value includes the first index (``start``) *but not
the last* (``stop``). Finally, when only one index (or slice) is
specified for a :math:`k^\textrm{th}`-order tensor, it is applied along
axis 0. Thus, in the following code, ``[-1]`` selects the last row and
``[1:3]`` selects the second and third rows.
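Because these conventions are borrowed from Python lists, the same selections can be tried on a plain nested list; this sketch only illustrates the half-open slicing rule and does not use any tensor library.

.. code:: python

    X = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

    last_row = X[-1]  # negative index counts from the end of axis 0
    middle = X[1:3]   # start (1) is included, stop (3) is excluded
    print(last_row)   # [8, 9, 10, 11]
    print(middle)     # [[4, 5, 6, 7], [8, 9, 10, 11]]

    # A slice start:stop always contains stop - start entries (when in range).
    print(len(X[0:3]) == 3 - 0)  # True
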
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[-1], X[1:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([ 8., 9., 10., 11.]),
tensor([[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]]))
Beyond reading them, we can also *write* elements of a matrix by
specifying indices.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[1, 2] = 17
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 17., 7.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[-1], X[1:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([ 8., 9., 10., 11.]),
array([[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]]))
Beyond reading them, we can also *write* elements of a matrix by
specifying indices.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[1, 2] = 17
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ 0., 1., 2., 3.],
[ 4., 5., 17., 7.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[-1], X[1:3]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([ 8, 9, 10, 11], dtype=int32),
Array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]], dtype=int32))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# JAX arrays are immutable. jax.numpy.ndarray.at index
# update operators create a new array with the corresponding
# modifications made
X_new_1 = X.at[1, 2].set(17)
X_new_1
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[ 0, 1, 2, 3],
[ 4, 5, 17, 7],
[ 8, 9, 10, 11]], dtype=int32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[-1], X[1:3]
``Tensor`` objects in TensorFlow are immutable and cannot be assigned to.
``Variable`` objects in TensorFlow are mutable containers of state that
support assignment. Keep in mind that gradients in TensorFlow do not flow
backwards through ``Variable`` assignments.
Beyond assigning a value to the entire ``Variable``, we can write
elements of a ``Variable`` by specifying indices.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X_var = tf.Variable(X)
X_var[1, 2].assign(9)
X_var
.. raw:: html
.. raw:: html
If we want to assign multiple elements the same value, we apply the
indexing on the left-hand side of the assignment operation. For
instance, ``[:2, :]`` accesses the first and second rows, where ``:``
takes all the elements along axis 1 (column). While we discussed
indexing for matrices, this also works for vectors and for tensors of
more than two dimensions.
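Plain Python lists do not support multi-axis indexing such as ``X[:2, :]``, so the sketch below mimics that assignment with a loop. It relies on the fact that ``X[:2]`` is a new outer list whose entries are the *same* row objects, so overwriting each row in place updates ``X`` itself.

.. code:: python

    X = [[0, 1, 2, 3], [4, 5, 17, 7], [8, 9, 10, 11]]

    # X[:2] copies only the outer list; each row is still shared with X.
    for row in X[:2]:
        row[:] = [12] * len(row)  # overwrite every element of the row in place

    print(X)  # [[12, 12, 12, 12], [12, 12, 12, 12], [8, 9, 10, 11]]
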
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[:2, :] = 12
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[12., 12., 12., 12.],
[12., 12., 12., 12.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X[:2, :] = 12
X
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[12., 12., 12., 12.],
[12., 12., 12., 12.],
[ 8., 9., 10., 11.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X_new_2 = X_new_1.at[:2, :].set(12)
X_new_2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[12, 12, 12, 12],
[12, 12, 12, 12],
[ 8, 9, 10, 11]], dtype=int32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X_var = tf.Variable(X)
X_var[:2, :].assign(tf.ones(X_var[:2,:].shape, dtype=tf.float32) * 12)
X_var
.. raw:: html
.. raw:: html
Operations
----------
Now that we know how to construct tensors and how to read from and write
to their elements, we can begin to manipulate them with various
mathematical operations. Among the most useful of these are the
*elementwise* operations. These apply a standard scalar operation to
each element of a tensor. For functions that take two tensors as inputs,
elementwise operations apply some standard binary operator on each pair
of corresponding elements. We can create an elementwise function from
any function that maps from a scalar to a scalar.
In mathematical notation, we denote such *unary* scalar operators
(taking one input) by the signature
:math:`f: \mathbb{R} \rightarrow \mathbb{R}`. This just means that the
function maps from any real number onto some other real number. Most
standard operators, including unary ones like :math:`e^x`, can be
applied elementwise.
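Conceptually, an elementwise unary operation just applies a scalar function :math:`f: \mathbb{R} \rightarrow \mathbb{R}` to every entry, so the output has the same shape as the input. A minimal pure-Python sketch for a flat list (the helper ``elementwise`` is hypothetical):

.. code:: python

    import math

    def elementwise(f, xs):
        """Apply a scalar function f: R -> R to every element of a flat list."""
        return [f(v) for v in xs]

    x = [float(v) for v in range(12)]
    y = elementwise(math.exp, x)
    print(len(y) == len(x))      # True: the output has the input's shape
    print(y[0], round(y[1], 4))  # 1.0 2.7183
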
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.exp(x)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
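tensor([162754.7969, 162754.7969, 162754.7969, 162754.7969, 162754.7969,
162754.7969, 162754.7969, 162754.7969, 2980.9580, 8103.0840,
22026.4648, 59874.1406])
Unlike the outputs for the other frameworks, the first eight values
here equal :math:`e^{12}`: in PyTorch, ``x.reshape(3, 4)`` returned a
view that shares memory with ``x``, so the earlier in-place writes to
``X`` (in particular ``X[:2, :] = 12``) also modified ``x``.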
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
np.exp(x)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
5.4598148e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
2.9809580e+03, 8.1030840e+03, 2.2026465e+04, 5.9874141e+04])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
jnp.exp(x)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
5.4598152e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
2.9809580e+03, 8.1030840e+03, 2.2026465e+04, 5.9874141e+04], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.exp(x)
.. raw:: html
.. raw:: html
Likewise, we denote *binary* scalar operators, which map pairs of real
numbers to a (single) real number via the signature
:math:`f: \mathbb{R}, \mathbb{R} \rightarrow \mathbb{R}`. Given any two
vectors :math:`\mathbf{u}` and :math:`\mathbf{v}` *of the same shape*,
and a binary operator :math:`f`, we can produce a vector
:math:`\mathbf{c} = F(\mathbf{u},\mathbf{v})` by setting
:math:`c_i \gets f(u_i, v_i)` for all :math:`i`, where :math:`c_i, u_i`,
and :math:`v_i` are the :math:`i^\textrm{th}` elements of vectors
:math:`\mathbf{c}, \mathbf{u}`, and :math:`\mathbf{v}`. Here, we
produced the vector-valued
:math:`F: \mathbb{R}^d, \mathbb{R}^d \rightarrow \mathbb{R}^d` by
*lifting* the scalar function to an elementwise vector operation. The
common standard arithmetic operators for addition (``+``), subtraction
(``-``), multiplication (``*``), division (``/``), and exponentiation
(``**``) have all been *lifted* to elementwise operations for
identically-shaped tensors of arbitrary shape.
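The *lifting* described above can be written directly as a higher-order function: given a binary scalar operator :math:`f`, it returns a function :math:`F` that applies :math:`f` to each pair :math:`(u_i, v_i)`. The helper ``lift`` is hypothetical and handles only flat lists of equal length; the operators come from Python's standard ``operator`` module.

.. code:: python

    import operator

    def lift(f):
        """Lift a binary scalar function f: R, R -> R to equal-length vectors (hypothetical helper)."""
        def F(u, v):
            if len(u) != len(v):
                raise ValueError("elementwise operations need identically shaped inputs")
            return [f(u_i, v_i) for u_i, v_i in zip(u, v)]
        return F

    x = [1.0, 2, 4, 8]
    y = [2, 2, 2, 2]
    for op in (operator.add, operator.sub, operator.mul, operator.truediv, operator.pow):
        print(op.__name__, lift(op)(x, y))
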
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = torch.tensor([1.0, 2, 4, 8])
y = torch.tensor([2, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([ 3., 4., 6., 10.]),
tensor([-1., 0., 2., 6.]),
tensor([ 2., 4., 8., 16.]),
tensor([0.5000, 1.0000, 2.0000, 4.0000]),
tensor([ 1., 4., 16., 64.]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = np.array([1, 2, 4, 8])
y = np.array([2, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([ 3., 4., 6., 10.]),
array([-1., 0., 2., 6.]),
array([ 2., 4., 8., 16.]),
array([0.5, 1. , 2. , 4. ]),
array([ 1., 4., 16., 64.]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = jnp.array([1.0, 2, 4, 8])
y = jnp.array([2, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([ 3., 4., 6., 10.], dtype=float32),
Array([-1., 0., 2., 6.], dtype=float32),
Array([ 2., 4., 8., 16.], dtype=float32),
Array([0.5, 1. , 2. , 4. ], dtype=float32),
Array([ 1., 4., 16., 64.], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = tf.constant([1.0, 2, 4, 8])
y = tf.constant([2.0, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y
.. raw:: html
.. raw:: html
In addition to elementwise computations, we can also perform linear
algebraic operations, such as dot products and matrix multiplications.
We will elaborate on these in :numref:`sec_linear-algebra`.
We can also *concatenate* multiple tensors, stacking them end-to-end to
form a larger one. We just need to provide a list of tensors and tell
the system along which axis to concatenate. The example below
concatenates two matrices first along rows (axis 0) and then along
columns (axis 1). The first output’s axis-0 length (:math:`6`) is the
sum of the two input tensors’ axis-0 lengths (:math:`3 + 3`), while the
second output’s axis-1 length (:math:`8`) is the sum of the two input
tensors’ axis-1 lengths (:math:`4 + 4`).
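The two cases can be sketched with nested lists: concatenating along axis 0 appends rows, while concatenating along axis 1 joins matching rows side by side. The helpers ``concat_axis0`` and ``concat_axis1`` are hypothetical stand-ins for the framework functions and check only the axis that must agree.

.. code:: python

    X = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
    Y = [[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]]

    def concat_axis0(A, B):
        """Stack the rows of B below the rows of A (hypothetical helper)."""
        if len(A[0]) != len(B[0]):
            raise ValueError("axis-1 lengths must match")
        return A + B

    def concat_axis1(A, B):
        """Join corresponding rows of A and B side by side (hypothetical helper)."""
        if len(A) != len(B):
            raise ValueError("axis-0 lengths must match")
        return [ra + rb for ra, rb in zip(A, B)]

    rows = concat_axis0(X, Y)
    cols = concat_axis1(X, Y)
    print(len(rows), len(rows[0]))  # 6 4
    print(len(cols), len(cols[0]))  # 3 8
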
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.arange(12, dtype=torch.float32).reshape((3,4))
Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]]),
tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.arange(12).reshape(3, 4)
Y = np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
np.concatenate([X, Y], axis=0), np.concatenate([X, Y], axis=1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]]),
array([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = jnp.arange(12, dtype=jnp.float32).reshape((3, 4))
Y = jnp.array([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
jnp.concatenate((X, Y), axis=0), jnp.concatenate((X, Y), axis=1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]], dtype=float32),
Array([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
tf.concat([X, Y], axis=0), tf.concat([X, Y], axis=1)
.. raw:: html
.. raw:: html
Sometimes, we want to construct a binary tensor via *logical
statements*. Take ``X == Y`` as an example. For each position ``i, j``,
if ``X[i, j]`` and ``Y[i, j]`` are equal, then the corresponding entry
in the result is ``True``; otherwise it is ``False``.
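The same comparison can be sketched on nested lists. Because Python booleans behave like 1 and 0 in arithmetic, summing the result counts the positions where the two matrices agree; this is a plain-Python illustration, not framework code.

.. code:: python

    X = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
    Y = [[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]]

    # Compare corresponding entries position by position.
    eq = [[x == y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]
    print(eq[0])  # [False, True, False, True]

    # True counts as 1 and False as 0, so this counts the matches.
    matches = sum(sum(row) for row in eq)
    print(matches)  # 2
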
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[False, True, False, True],
[False, False, False, False],
[False, False, False, False]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[False, True, False, True],
[False, False, False, False],
[False, False, False, False]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[False, True, False, True],
[False, False, False, False],
[False, False, False, False]], dtype=bool)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X == Y
.. raw:: html
.. raw:: html
Summing all the elements in the tensor yields a tensor with only one
element.
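For the :math:`3\times4` matrix built from ``arange(12)``, the total can be checked by hand: the sum :math:`0 + 1 + \cdots + (n-1)` equals :math:`n(n-1)/2`. A short pure-Python check:

.. code:: python

    X = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
    total = sum(v for row in X for v in row)  # flatten, then add everything
    n = 12
    print(total, n * (n - 1) // 2)  # 66 66
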
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X.sum()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor(66.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X.sum()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array(66.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X.sum()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array(66., dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
tf.reduce_sum(X)
.. raw:: html
.. raw:: html
.. _subsec_broadcasting:
Broadcasting
------------
By now, you know how to perform elementwise binary operations on two
tensors of the same shape. Under certain conditions, even when shapes
differ, we can still perform elementwise binary operations by invoking
the *broadcasting mechanism*. Broadcasting works according to the
following two-step procedure: (i) expand one or both arrays by copying
elements along axes with length 1 so that after this transformation, the
two tensors have the same shape; (ii) perform an elementwise operation
on the resulting arrays.
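Whether two shapes can be broadcast, and what the result shape is, can be decided before any data is touched. The sketch below follows the NumPy-style rules that these frameworks adopt: compare axes from the trailing end, treat a missing leading axis as length 1, and require each pair of lengths to be equal or to contain a 1. The helper ``broadcast_shape`` is hypothetical.

.. code:: python

    from itertools import zip_longest

    def broadcast_shape(s1, s2):
        """Result shape of broadcasting two shapes, NumPy-style (hypothetical helper)."""
        result = []
        # Walk the axes from the last one; a missing leading axis acts like length 1.
        for d1, d2 in zip_longest(reversed(s1), reversed(s2), fillvalue=1):
            if d1 != d2 and d1 != 1 and d2 != 1:
                raise ValueError(f"incompatible axis lengths {d1} and {d2}")
            result.append(d1 if d2 == 1 else d2)  # a length-1 axis stretches to the other length
        return tuple(reversed(result))

    print(broadcast_shape((3, 1), (1, 2)))  # (3, 2)
    print(broadcast_shape((3, 4), (4,)))    # (3, 4)

Shapes such as ``(3,)`` and ``(4,)`` are rejected, because neither length is 1.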
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = torch.arange(3).reshape((3, 1))
b = torch.arange(2).reshape((1, 2))
a, b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([[0],
[1],
[2]]),
tensor([[0, 1]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = np.arange(3).reshape(3, 1)
b = np.arange(2).reshape(1, 2)
a, b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([[0.],
[1.],
[2.]]),
array([[0., 1.]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = jnp.arange(3).reshape((3, 1))
b = jnp.arange(2).reshape((1, 2))
a, b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([[0],
[1],
[2]], dtype=int32),
Array([[0, 1]], dtype=int32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = tf.reshape(tf.range(3), (3, 1))
b = tf.reshape(tf.range(2), (1, 2))
a, b
.. raw:: html
.. raw:: html
Since ``a`` and ``b`` are :math:`3\times1` and :math:`1\times2`
matrices, respectively, their shapes do not match up. Broadcasting
produces a larger :math:`3\times2` matrix by replicating matrix ``a``
along the columns and matrix ``b`` along the rows before adding them
elementwise.
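The two steps can be spelled out explicitly with nested lists: replicate ``a`` along the columns and ``b`` along the rows, then add elementwise. Real frameworks generally avoid physically copying the data; the explicit copies in this sketch are only for illustration.

.. code:: python

    a = [[0], [1], [2]]  # shape (3, 1)
    b = [[0, 1]]         # shape (1, 2)

    # Step (i): replicate a along the columns and b along the rows.
    a_big = [row * 2 for row in a]  # [[0, 0], [1, 1], [2, 2]]
    b_big = b * 3                   # [[0, 1], [0, 1], [0, 1]] (rows shared, read-only here)

    # Step (ii): add the now identically shaped matrices elementwise.
    c = [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a_big, b_big)]
    print(c)  # [[0, 1], [1, 2], [2, 3]]
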
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a + b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[0, 1],
[1, 2],
[2, 3]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a + b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[0., 1.],
[1., 2.],
[2., 3.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a + b
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[0, 1],
[1, 2],
[2, 3]], dtype=int32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a + b
.. raw:: html
.. raw:: html
Saving Memory
-------------
Running operations can cause new memory to be allocated to host results.
For example, if we write ``Y = Y + X``, we dereference the tensor that
``Y`` used to point to and instead point ``Y`` at the newly allocated
memory. We can demonstrate this issue with Python’s ``id()`` function,
which returns a unique identifier of the referenced object (in CPython,
its address in memory). Note that after we run ``Y = Y + X``, ``id(Y)``
changes. That is because Python first evaluates ``Y + X``, allocating
new memory for the result, and then points ``Y`` to this new location
in memory.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(Y)
Y = Y + X
id(Y) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
False
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(Y)
Y = Y + X
id(Y) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
False
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(Y)
Y = Y + X
id(Y) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
False
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(Y)
Y = Y + X
id(Y) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
False
.. raw:: html
.. raw:: html
This might be undesirable for two reasons. First, we do not want to run
around allocating memory unnecessarily all the time. In machine
learning, we often have hundreds of megabytes of parameters and update
all of them multiple times per second. Whenever possible, we want to
perform these updates *in place*. Second, we might point at the same
parameters from multiple variables. If we do not update in place, we
must be careful to update all of these references, lest we spring a
memory leak or inadvertently refer to stale parameters.
.. raw:: html
.. raw:: html
Fortunately, performing in-place operations is easy. We can assign the
result of an operation to a previously allocated array ``Y`` by using
slice notation: ``Y[:] = <expression>``. To illustrate this concept, we
first create a tensor ``Z`` with the same shape as ``Y`` using
``zeros_like``, and then overwrite its values.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Z = torch.zeros_like(Y)
print('id(Z):', id(Z))
Z[:] = X + Y
print('id(Z):', id(Z))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
id(Z): 140381179266448
id(Z): 140381179266448
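Note that ``Z[:] = X + Y`` still evaluates ``X + Y`` into a temporary tensor before copying it into ``Z``. To skip that temporary, many array libraries let an operation write directly into an existing buffer: NumPy's ``np.add`` and PyTorch's ``torch.add`` both accept an ``out`` argument. A minimal NumPy sketch, with arbitrary example arrays:

```python
import numpy as np

X = np.arange(12, dtype=np.float32).reshape(3, 4)
Y = np.ones((3, 4), dtype=np.float32)
Z = np.zeros_like(Y)
before = id(Z)

Z[:] = X + Y             # X + Y is computed into a temporary, then copied into Z
assert id(Z) == before

np.add(X, Y, out=Z)      # the sum is written straight into Z's existing buffer
assert id(Z) == before
print(Z)
```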
If the value of ``X`` is not reused in subsequent computations, we can
also use ``X[:] = X + Y`` or ``X += Y`` to reduce the memory overhead of
the operation.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(X)
X += Y
id(X) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
True
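In-place updates come with constraints, because the result must fit into the existing buffer. The NumPy sketch below shows two common failures: the result's dtype cannot be safely cast back to the target's dtype, and broadcasting would produce a shape larger than the target. The exception classes are NumPy's; the frameworks in this section may report these cases differently.

```python
import numpy as np

# Case 1: an integer array cannot hold a float result in place.
X = np.arange(3)
cast_failed = False
try:
    X += 0.5  # X + 0.5 is float; casting it back to int is not allowed
except TypeError as err:  # NumPy's casting error subclasses TypeError
    cast_failed = True
    print("dtype error:", err)

# Case 2: the broadcast result must have the target's shape.
X = np.zeros((1, 2))
shape_failed = False
try:
    X += np.ones((3, 2))  # broadcast shape (3, 2) does not fit into (1, 2)
except ValueError as err:
    shape_failed = True
    print("shape error:", err)
```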
.. raw:: html
.. raw:: html
Fortunately, performing in-place operations is easy. We can assign the
result of an operation to a previously allocated array ``Y`` by using
slice notation: ``Y[:] = <expression>``. To illustrate this concept, we
first create a tensor ``Z`` with the same shape as ``Y`` using
``zeros_like``, and then overwrite its values.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Z = np.zeros_like(Y)
print('id(Z):', id(Z))
Z[:] = X + Y
print('id(Z):', id(Z))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
id(Z): 139767554095872
id(Z): 139767554095872
If the value of ``X`` is not reused in subsequent computations, we can
also use ``X[:] = X + Y`` or ``X += Y`` to reduce the memory overhead of
the operation.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
before = id(X)
X += Y
id(X) == before
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
True
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# JAX arrays do not allow in-place operations
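Because JAX arrays are immutable, item assignment such as ``X[0] = 10.0`` raises an error. JAX instead offers functional updates through the ``.at`` property, which return a new array and leave the original untouched. According to JAX's documentation, inside ``jax.jit``-compiled code the compiler may carry out such an update in place when the original array is not reused. A minimal sketch:

```python
import jax.numpy as jnp

X = jnp.arange(4.0)

assignment_failed = False
try:
    X[0] = 10.0  # JAX arrays do not support item assignment
except TypeError:
    assignment_failed = True

X_new = X.at[0].set(10.0)  # functional update: returns a new array
print(X)      # the original array is unchanged
print(X_new)  # the new array holds the updated value
```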
.. raw:: html
.. raw:: html
``Variable`` objects are mutable containers of state in TensorFlow. They
provide a way to store your model parameters. We can assign the result
of an operation to a ``Variable`` with ``assign``. To illustrate this
concept, we first create a ``Variable`` ``Z`` with the same shape as
``Y`` using ``zeros_like``, and then overwrite its values.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Z = tf.Variable(tf.zeros_like(Y))
print('id(Z):', id(Z))
Z.assign(X + Y)
print('id(Z):', id(Z))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
id(Z): 139652041257360
id(Z): 139652041257360
Even once you store state persistently in a ``Variable``, you may want
to reduce your memory usage further by avoiding excess allocations for
tensors that are not your model parameters. Because TensorFlow
``Tensors`` are immutable and gradients do not flow through ``Variable``
assignments, TensorFlow does not provide an explicit way to run an
individual operation in-place.
However, TensorFlow provides the ``tf.function`` decorator to wrap
computation inside a TensorFlow graph that is compiled and optimized
before running. This allows TensorFlow to prune unused values and to
reuse prior allocations that are no longer needed, which reduces the
memory overhead of TensorFlow computations.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
@tf.function
def computation(X, Y):
    Z = tf.zeros_like(Y)  # This unused value will be pruned out
    A = X + Y  # Allocations will be reused when no longer needed
    B = A + Y
    C = B + Y
    return C + Y
computation(X, Y)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
Conversion to Other Python Objects
----------------------------------
.. raw:: html
.. raw:: html
Converting to a NumPy array (``ndarray``), or vice versa, is easy. For
tensors on the CPU, the torch tensor and the NumPy array share their
underlying memory, so changing one through an in-place operation also
changes the other.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
A = X.numpy()
B = torch.from_numpy(A)
type(A), type(B)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(numpy.ndarray, torch.Tensor)
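To see the shared memory in action, the following sketch modifies the data in place through each object, then uses NumPy's ``copy`` method to break the link. It assumes a CPU tensor: ``.numpy()`` refuses GPU tensors, which must first be moved with ``.cpu()``.

```python
import torch

X = torch.arange(4, dtype=torch.float32)  # a CPU tensor
A = X.numpy()            # NumPy array sharing X's memory
B = torch.from_numpy(A)  # tensor sharing that same memory

A += 1      # in-place update through NumPy is visible in X
print(X)    # tensor([1., 2., 3., 4.])
B.mul_(2)   # in-place update through PyTorch is visible in A
print(A)    # [2. 4. 6. 8.]

C = X.numpy().copy()  # an explicit copy no longer aliases X
C += 100              # X is unaffected
```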
.. raw:: html
.. raw:: html
Converting to a NumPy array (``ndarray``), or vice versa, is easy. The
converted result does not share memory with the original. This may seem
like a minor inconvenience, but it matters: when you perform operations
on the CPU or on GPUs, you do not want to halt computation while waiting
to see whether Python’s NumPy package might be doing something else with
the same chunk of memory.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
A = X.asnumpy()
B = np.array(A)
type(A), type(B)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(numpy.ndarray, mxnet.numpy.ndarray)
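Whether a conversion copies or shares memory is easy to check. As an analogy in plain NumPy (not MXNet's ``np`` module), ``np.array`` copies its input by default, whereas ``np.asarray`` returns the input itself when no conversion is needed. ``np.shares_memory`` tells the two cases apart.

```python
import numpy as np

A = np.arange(4.0)
copied = np.array(A)    # np.array copies by default
viewed = np.asarray(A)  # no conversion needed, so A itself is returned

assert not np.shares_memory(A, copied)
assert viewed is A

A[0] = 42.0
print(copied[0], viewed[0])  # only the non-copy sees the change
```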
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
A = jax.device_get(X)
B = jax.device_put(A)
type(A), type(B)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(numpy.ndarray, jaxlib.xla_extension.ArrayImpl)
.. raw:: html
.. raw:: html
Converting to a NumPy array (``ndarray``), or vice versa, is easy. The
converted result does not share memory with the original. This may seem
like a minor inconvenience, but it matters: when you perform operations
on the CPU or on GPUs, you do not want to halt computation while waiting
to see whether Python’s NumPy package might be doing something else with
the same chunk of memory.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
A = X.numpy()
B = tf.constant(A)
type(A), type(B)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(numpy.ndarray, tensorflow.python.framework.ops.EagerTensor)
.. raw:: html
.. raw:: html
To convert a size-1 tensor to a Python scalar, we can invoke the
``item`` method or Python’s built-in ``float`` and ``int`` functions.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = torch.tensor([3.5])
a, a.item(), float(a), int(a)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([3.5000]), 3.5, 3.5, 3)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = np.array([3.5])
a, a.item(), float(a), int(a)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([3.5]), 3.5, 3.5, 3)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = jnp.array([3.5])
a, a.item(), float(a), int(a)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([3.5], dtype=float32), 3.5, 3.5, 3)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
a = tf.constant([3.5]).numpy()
a, a.item(), float(a), int(a)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([3.5], dtype=float32), 3.5, 3.5, 3)
.. raw:: html
.. raw:: html
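Two details are worth keeping in mind when extracting scalars: ``item`` only works on tensors with exactly one element, and ``int`` truncates toward zero rather than rounding. A minimal NumPy sketch follows. The ``ValueError`` is NumPy's behavior; the frameworks above may raise a different exception type.

```python
import numpy as np

a = np.array([3.5])
x = a.item()  # a plain Python float
assert isinstance(x, float)

# int() truncates toward zero: 3.5 becomes 3 and -3.5 becomes -3
print(int(x), int(-x))

# item() requires exactly one element
item_failed = False
try:
    np.array([1.0, 2.0]).item()
except ValueError as err:
    item_failed = True
    print("ValueError:", err)
```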
Summary
-------
The tensor class is the main interface for storing and manipulating data
in deep learning libraries. Tensors provide a variety of functionalities,
including construction routines; indexing and slicing; basic mathematical
operations; broadcasting; memory-efficient assignment; and conversion to
and from other Python objects.
Exercises
---------
1. Run the code in this section. Change the conditional statement
``X == Y`` to ``X < Y`` or ``X > Y``, and then see what kind of
tensor you get.
2. Replace the two tensors that are operated on elementwise in the
broadcasting mechanism with tensors of other shapes, e.g.,
3-dimensional tensors. Is the result what you expected?
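To check your expectations for the second exercise, the shape rule behind broadcasting can be sketched in plain Python. The function ``broadcast_shape`` is hypothetical, written only for illustration, and follows NumPy-style rules: align the shapes from the right, pad the shorter one with ones, and require each pair of sizes to be equal or to include a 1.

```python
def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: the broadcast shape of two shapes (NumPy-style)."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both have the same number of axes
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} are not broadcastable")
    return tuple(result)


print(broadcast_shape((3, 1), (1, 2)))     # (3, 2)
print(broadcast_shape((2, 3, 1), (3, 4)))  # (2, 3, 4)
```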
.. raw:: html
.. raw:: html
.. raw:: html
.. raw:: html