\( \require{ams} \DeclareMathOperator*{\softmax}{softmax} \DeclareMathOperator{\ind}{ind} \DeclareMathOperator{\rec}{rec} \newcommand{\ensuremath}[1]{#1} \newcommand{\vdotswithin}[1]{\vdots} \)

1 Introduction

Most papers about neural networks use the notation of vectors and matrices from applied linear algebra. This notation is optimized for talking about vector spaces, but becomes cumbersome when talking about neural networks. Consider the following equation (Vaswani et al. 2017): \[\text{Attention}(Q, K, V) = \softmax \left( \frac{QK^\top}{\sqrt{d_k}} \right) V.\] where \(Q\), \(K\), and \(V\) (for query, key, and value, respectively) are sequences of feature vectors, packed into matrices. Does the product \(QK^\top\) sum over the sequence, or over the features? It sums over columns, but there’s not enough information to know what the columns represent. Is the softmax taken over the query sequence or the key sequence? The usual notation doesn’t even offer a way to answer this question. With multiple attention heads or multiple sentences in a minibatch, the notation becomes more difficult still.

Here, we propose mathematical notation for tensors with named axes. The notation has a formal underpinning, but is hopefully intuitive enough that machine learning researchers can understand it without much effort.

In our notation, the above equation becomes \[\begin{aligned} \text{Attention} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times\ensuremath{\mathsf{\vphantom{fg}val}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}val}}} \\ \text{Attention}(Q,K,V) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\mathrm{softmax}}} \left( \frac{Q \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}key}}}}{\vphantom{fg}\odot}} K}{\sqrt{|\ensuremath{\mathsf{\vphantom{fg}key}}|}} \right) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\odot}} V.\end{aligned}\]

The \(\ensuremath{\mathsf{\vphantom{fg}key}}\) axis is for the features of queries and keys, and \(\ensuremath{\mathsf{\vphantom{fg}val}}\) is for the features of values. The \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) axis is for tokens of sequences. This notation makes clear both the type of each input and how each input is acted upon. The tensor \(Q\) is a query, which is a vector over the \(\ensuremath{\mathsf{\vphantom{fg}key}}\) axis. The tensor \(K\) is a sequence of keys, and so has the \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) and \(\ensuremath{\mathsf{\vphantom{fg}key}}\) axes. You can think of it as a matrix, but you don’t need to remember whether \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) corresponds to columns and \(\ensuremath{\mathsf{\vphantom{fg}key}}\) to rows, or vice versa. The tensor \(V\) is a sequence of values, and so has the \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) and \(\ensuremath{\mathsf{\vphantom{fg}val}}\) axes. The axis names written under each operation make it clear which axis it acts upon. For example, to parse the expression \(Q \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}key}}}}{\vphantom{fg}\odot}} K\), it doesn’t matter whether \(\ensuremath{\mathsf{\vphantom{fg}key}}\) corresponds to rows or columns in \(K\), since the dot product is taken over the \(\ensuremath{\mathsf{\vphantom{fg}key}}\) axis shared between \(K\) and \(Q\).

Our notation also makes it easy to “broadcast” or “vectorize” a function by extending it to act on tensors with more axes. For example, if instead of being a vector with the single axis \(\ensuremath{\mathsf{\vphantom{fg}key}}\), \(Q\) is a tensor with the three axes \(\ensuremath{\mathsf{\vphantom{fg}key}}\), \(\ensuremath{\mathsf{\vphantom{fg}seq}}\), and \(\ensuremath{\mathsf{\vphantom{fg}batch}}\) (the latter two corresponding to tokens of a sequence and examples in a minibatch, respectively), then the \(\text{Attention}\) function works as written, acting on each example in the minibatch in parallel. Similarly, we can also add a \(\ensuremath{\mathsf{\vphantom{fg}heads}}\) axis to the inputs for multiple attention heads.

Our notation is inspired by libraries for programming with multidimensional arrays (Harris et al. 2020; Paszke et al. 2019) and extensions that use named axes, like xarray (Hoyer and Hamman 2017), Nexus (Chen 2017), tsalib (Sinha 2018), NamedTensor (Rush 2019), named tensors in PyTorch (Torch Contributors 2019), and Dex (Maclaurin et al. 2019). However, our focus is on mathematical notation rather than code.

The source code for this document can be found at https://github.com/namedtensor/notation/. We invite anyone to make comments on this proposal by submitting issues or pull requests on this repository.

2 Informal Overview

In standard notation, a vector, matrix, or tensor is indexed by an integer or sequence of integers. If \(A \in \mathbb{R}^{3\times3}\), then the order of the two axes matters: \(A_{1,3}\) and \(A_{3,1}\) are not the same element. It’s up to the reader to remember what each axis of each tensor stands for. This problem is exacerbated in modern machine learning, where tensors have multiple axes with different meanings (batches, channels, etc.), and different operations act on different axes. The solution we propose is for each axis to have a name that describes it, so that there is no confusion between, say, the batch axis and the channel axis. Our notation also makes it easy to define unambiguously which operations act on which axes.

2.1 Named tensors

In a named tensor, we give each axis a name. For example, if \(A\) represents an image, we can make it a named tensor like so (writing it two equivalent ways to show that the order of axes does not matter): \[\begin{aligned} A &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}[3] \times \ensuremath{\mathsf{\vphantom{fg}width}}[3]} = \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}width}}[3] \times \ensuremath{\mathsf{\vphantom{fg}height}}[3]} \\ A &= \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \end{bmatrix}\end{array} = \ensuremath{\mathsf{\vphantom{fg}width}}\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}height}}\\\begin{bmatrix} 3 & 1 & 2 \\ 1 & 5 & 6 \\ 4 & 9 & 5 \end{bmatrix}\end{array}.\end{aligned}\]

We access elements of \(A\) using named indices, whose order again does not matter: \(A_{\ensuremath{\mathsf{\vphantom{fg}height}}(1), \ensuremath{\mathsf{\vphantom{fg}width}}(3)} = A_{\ensuremath{\mathsf{\vphantom{fg}width}}(3), \ensuremath{\mathsf{\vphantom{fg}height}}(1)} = 4\). We also allow partial indexing: \[\begin{aligned} A_{\ensuremath{\mathsf{\vphantom{fg}height}}(1)} &= \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3 & 1 & 4 \end{bmatrix}\end{array} & A_{\ensuremath{\mathsf{\vphantom{fg}width}}(3)} &= \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}height}}\\\begin{bmatrix} 4 & 9 & 5 \end{bmatrix}\end{array}.\end{aligned}\]

In many contexts, an axis name is used with only one size. If so, we can simply write \(\ensuremath{\mathsf{\vphantom{fg}height}}\) for the unique axis with name \(\ensuremath{\mathsf{\vphantom{fg}height}}\), as in \(\mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}\times \ensuremath{\mathsf{\vphantom{fg}width}}}\). We can also leave the sizes of axes unspecified at first and specify them later (for example, in a section on experimental details): writing \(|\ensuremath{\mathsf{\vphantom{fg}height}}|=|\ensuremath{\mathsf{\vphantom{fg}width}}|=28\) gives the exact sizes, while writing just \(|\ensuremath{\mathsf{\vphantom{fg}height}}|=|\ensuremath{\mathsf{\vphantom{fg}width}}|\) specifies only that the image is square.

What are good choices for axis names? We recommend meaningful words instead of single letters, and we recommend words that describe a whole rather than its parts. For example, if we wanted \(A\) to have red, green, and blue channels, we’d name the axis \(\ensuremath{\mathsf{\vphantom{fg}chans}}\), and if we wanted to represent a minibatch of images, we’d name the axis \(\ensuremath{\mathsf{\vphantom{fg}batch}}\). Please see §3 for more examples.

2.2 Named tensor operations

Operations on named tensors are defined by taking a function on low-order tensors and extending it to higher-order tensors.

2.2.1 Elementwise operations and broadcasting

Any function from a scalar to a scalar can be applied elementwise to a named tensor, and any function from two scalars to a scalar can be applied to two named tensors with the same shape. For example: \[\frac1{1+\exp(-A)} = \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} \frac 1{1+\exp(-3)} & \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-4)} \\[1ex] \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-5)} & \frac 1{1+\exp(-9)} \\[1ex] \frac 1{1+\exp(-2)} & \frac 1{1+\exp(-6)} & \frac 1{1+\exp(-5)} \end{bmatrix}\end{array}.\]

But if we apply a binary function/operator to tensors with different shapes, they are broadcast against each other (similarly to NumPy and derivatives). Let \[\begin{aligned} x &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}[3]} & y &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}width}}[3]} \\ x &= \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\\\begin{bmatrix} 2 \\ 7 \\ 1 \end{bmatrix}\end{array} & y &= \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 1 & 4 & 1 \end{bmatrix}\end{array}.\end{aligned}\] (We write \(x\) as a column just to make the broadcasting easier to visualize.) Then, to evaluate \(A+x\), we effectively replace \(x\) with a new tensor \(x'\) that contains a copy of \(x\) for every index of axis \(\ensuremath{\mathsf{\vphantom{fg}width}}\). Likewise for \(A+y\): \[\begin{aligned} A + x &= \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3+2 & 1+2 & 4+2 \\ 1+7 & 5+7 & 9+7 \\ 2+1 & 6+1 & 5+1 \end{bmatrix}\end{array} & A + y &= \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3+1 & 1+4 & 4+1 \\ 1+1 & 5+4 & 9+1 \\ 2+1 & 6+4 & 5+1 \end{bmatrix}\end{array}.\end{aligned}\]
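
For readers who prefer code, here is a rough sketch of the same computations in ordinary (positional) NumPy, one of the libraries cited above; we must fix the convention that axis 0 is height and axis 1 is width, which is exactly the bookkeeping that named axes remove.

import numpy as np
A = np.array([[3, 1, 4], [1, 5, 9], [2, 6, 5]])   # axes: (height, width)
x = np.array([2, 7, 1])                           # axis: height
y = np.array([1, 4, 1])                           # axis: width
A_plus_x = A + x[:, None]   # broadcast: copy x across the width axis
A_plus_y = A + y[None, :]   # broadcast: copy y across the height axis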

2.2.2 Reductions

The same broadcasting rules apply to functions from vectors to scalars, called reductions. We always specify which axis a reduction applies to using a subscript (equivalent to the axis argument in NumPy and dim in PyTorch).

See §5.4 for some common reductions. Here we take summation as an example. \[\begin{aligned} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}}} A &= \sum_i A_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} = \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3+1+2 & 1+5+6 & 4+9+5 \end{bmatrix}\end{array} \\ \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}}} A &= \sum_j A_{\ensuremath{\mathsf{\vphantom{fg}width}}(j)} = \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}height}}\\\begin{bmatrix} 3+1+4 & 1+5+9 & 2+6+5 \end{bmatrix}\end{array}.\end{aligned}\]

We can also write multiple names to sum over multiple axes: \[\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}\\ \ensuremath{\mathsf{\vphantom{fg}width}}}} A = \sum_i \sum_j A_{\ensuremath{\mathsf{\vphantom{fg}height}}(i),\ensuremath{\mathsf{\vphantom{fg}width}}(j)} = 3+1+4+1+5+9+2+6+5.\] But a summation with an index variable (like \(i\) or \(j\) above) is a standard summation over values of that variable, and a summation with no subscript is a standard summation over a set.
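
In positional code the reduction axis becomes the axis (NumPy) or dim (PyTorch) argument; continuing the NumPy sketch above:

sum_h    = A.sum(axis=0)         # sum over height -> [6, 12, 18], indexed by width
sum_w    = A.sum(axis=1)         # sum over width  -> [8, 15, 13], indexed by height
sum_both = A.sum(axis=(0, 1))    # sum over both axes -> 36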

The vector dot-product is a function from two vectors to a scalar, which generalizes to named tensors to give the ubiquitous contraction operator. You can think of it as elementwise multiplication, then summation over one axis: \[\begin{aligned} A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}}}{\vphantom{fg}\odot}} y &= \sum_j A_{\ensuremath{\mathsf{\vphantom{fg}width}}(j)} \, y_{\ensuremath{\mathsf{\vphantom{fg}width}}(j)} = \ensuremath{\mathsf{\vphantom{fg}height}}\begin{array}[b]{@{}c@{}}\\\begin{bmatrix} 3\cdot 1 + 1\cdot 4 + 4\cdot 1 \\ 1\cdot 1 + 5\cdot 4 + 9\cdot 1 \\ 2\cdot 1 + 6\cdot 4 + 5\cdot 1 \end{bmatrix}\end{array}.\end{aligned}\]

Again, we can write multiple names to contract multiple axes at once. A \(\odot\) with no axis name under it contracts zero axes and is equivalent to elementwise multiplication, so we use \(\odot\) for elementwise multiplication as well.

The contraction operator can be used for many multiplication-like operations: \[\begin{aligned} x \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}}}{\vphantom{fg}\odot}} x &= \sum_i x_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} \, x_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} && \text{inner product} \\ [x \odot y]_{\ensuremath{\mathsf{\vphantom{fg}height}}(i), \ensuremath{\mathsf{\vphantom{fg}width}}(j)} &= x_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} \, y_{\ensuremath{\mathsf{\vphantom{fg}width}}(j)} && \text{outer product} \\ A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}}}{\vphantom{fg}\odot}} y &= \sum_i A_{\ensuremath{\mathsf{\vphantom{fg}width}}(i)} \, y_{\ensuremath{\mathsf{\vphantom{fg}width}}(i)} && \text{matrix-vector product} \\ x \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}}}{\vphantom{fg}\odot}} A &= \sum_i x_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} \, A_{\ensuremath{\mathsf{\vphantom{fg}height}}(i)} && \text{vector-matrix product} \\ A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}}}{\vphantom{fg}\odot}} B &= \sum_i A_{\ensuremath{\mathsf{\vphantom{fg}width}}(i)} \odot B_{\ensuremath{\mathsf{\vphantom{fg}width}}(i)} && \text{matrix-matrix product}~(B \in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}width}}\times \ensuremath{\mathsf{\vphantom{fg}width}}'})\end{aligned}\]
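
Continuing the NumPy sketch, all of these contractions can be written with einsum, at the cost of re-deriving which positional axis is which (here h stands for height and w for width):

x_dot_x = np.einsum('h,h->', x, x)       # inner product
outer   = np.einsum('h,w->hw', x, y)     # outer product (no axis contracted)
A_dot_y = np.einsum('hw,w->h', A, y)     # matrix-vector product, contracting width
x_dot_A = np.einsum('h,hw->w', x, A)     # vector-matrix product, contracting height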

2.2.3 Renaming and reshaping

It’s often useful to rename an axis (analogous to a transpose operation in standard notation): \[A_{\ensuremath{\mathsf{\vphantom{fg}height}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}height}}'} = \ensuremath{\mathsf{\vphantom{fg}height}}'\begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}width}}\\\begin{bmatrix} 3 & 1 & 4 \\ 1 & 5 & 9 \\ 2 & 6 & 5 \\ \end{bmatrix}\end{array}.\] We can also reshape two or more axes into one axis: \[A_{(\ensuremath{\mathsf{\vphantom{fg}height}},\ensuremath{\mathsf{\vphantom{fg}width}})\rightarrow\ensuremath{\mathsf{\vphantom{fg}layer}}} = \begin{array}[b]{@{}c@{}}\ensuremath{\mathsf{\vphantom{fg}layer}}\\\begin{bmatrix} 3 & 1 & 4 & 1 & 5 & 9 & 2 & 6 & 5 \end{bmatrix}\end{array}\] The order of elements in the new axis is undefined. If you need a particular order, you can write a more specific definition.
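
In positional NumPy, reshaping two axes into one is a flatten, and the element order of the new axis is whatever reshape happens to use, which is why the definition above leaves it unspecified. A small sketch, continuing from above:

A_layer = A.reshape(-1)       # one possible (height, width) -> layer flattening, shape (9,)
A_swap  = np.transpose(A)     # positional code must transpose to reorder axes; with names, order never matters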

3 Examples

In this section we give a series of examples illustrating how to use named tensors in various situations, mostly related to machine learning. Many of these examples use functions that the reader can probably guess the meaning of; if not, please see §5.4 for definitions.

3.1 Building blocks

3.1.1 Feedforward neural networks

A feedforward neural network looks like this: \[\begin{aligned} X^0 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}input}}} \\ X^1 &= \sigma(W^1 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}input}}}}{\vphantom{fg}\odot}} X^0 + b^1) & W^1 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}_1 \times \ensuremath{\mathsf{\vphantom{fg}input}}} & b^1 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}_1} \\ X^2 &= \sigma(W^2 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}hidden}}_1}}{\vphantom{fg}\odot}} X^1 + b^2) & W^2 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}_2 \times \ensuremath{\mathsf{\vphantom{fg}hidden}}_1} & b^2 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}_2} \\ X^3 &= \sigma(W^3 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}hidden}}_2}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}out}}\times \ensuremath{\mathsf{\vphantom{fg}hidden}}_2} & b^3 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}out}}}\end{aligned}\] The layer sizes can be set by writing \(|\ensuremath{\mathsf{\vphantom{fg}input}}| = 100\), etc.

If you don’t like repeating the equations for fully-connected layers, you can put them inside a function: \[\begin{aligned} \text{FullConn}^l(x) &= \sigma\left(W^l \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} x + b^l\right)_{\ensuremath{\mathsf{\vphantom{fg}layer}}'\rightarrow\ensuremath{\mathsf{\vphantom{fg}layer}}}\end{aligned}\] where \[\begin{aligned} W^l &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}'[n_l] \times \ensuremath{\mathsf{\vphantom{fg}layer}}[n_{l-1}]} \\ b^l &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}'[n_l]}.\end{aligned}\] A couple of things are new here. First, \(\text{FullConn}^l\) encapsulates both the equation for layer \(l\) as well as its parameters (analogous to what TensorFlow and PyTorch call modules). Second, we chose to use the same axis name \(\ensuremath{\mathsf{\vphantom{fg}layer}}\) for all the layers (with different sizes \(n_l\)). So \(\text{FullConn}^l\) temporarily computes its output over axis \(\ensuremath{\mathsf{\vphantom{fg}layer}}'\), then renames it back to \(\ensuremath{\mathsf{\vphantom{fg}layer}}\).

Then the network can be defined like this: \[\begin{aligned} X^0 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}[n_0]} \\ X^1 &= \text{FullConn}^1(X^0) \\ X^2 &= \text{FullConn}^2(X^1) \\ X^3 &= \text{FullConn}^3(X^2).\end{aligned}\]
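
A minimal NumPy sketch of the same three-layer network; the layer sizes, the random initialization, and the helper name full_conn are illustrative assumptions, not part of the definition above.

import numpy as np
def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))
def full_conn(W, b, x):
    # contract W with x over the input axis, add the bias, apply the nonlinearity
    return sigma(W @ x + b)
rng = np.random.default_rng(0)
sizes = [4, 8, 8, 3]                                               # n_0 .. n_3 (illustrative)
Ws = [rng.normal(size=(sizes[l], sizes[l - 1])) for l in (1, 2, 3)]
bs = [rng.normal(size=(sizes[l],)) for l in (1, 2, 3)]
X = rng.normal(size=(sizes[0],))                                   # X^0
for W, b in zip(Ws, bs):                                           # X^1, X^2, X^3
    X = full_conn(W, b, X)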

3.1.2 Recurrent neural networks

As a second example, let’s define a simple (Elman) RNN. This is similar to the feedforward network, except that the number of timesteps is variable and they all share parameters. \[\begin{aligned} x^{t} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}input}}} & t &= 1, \ldots, n \\ W^{\text{h}} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}\times \ensuremath{\mathsf{\vphantom{fg}hidden}}'} & |\ensuremath{\mathsf{\vphantom{fg}hidden}}| &= |\ensuremath{\mathsf{\vphantom{fg}hidden}}'| \\ W^{\text{i}} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}input}}\times \ensuremath{\mathsf{\vphantom{fg}hidden}}'} \\ b &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}'} \\ h^{0} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}} \\ h^{t} &= \sigma\left( W^{\text{h}} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}hidden}}}}{\vphantom{fg}\odot}} h^{t-1} + W^{\text{i}} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}input}}}}{\vphantom{fg}\odot}} x^{t} + b \right)_{\ensuremath{\mathsf{\vphantom{fg}hidden}}'\rightarrow\ensuremath{\mathsf{\vphantom{fg}hidden}}} & t &= 1, \ldots, n\end{aligned}\]
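
A sketch of one step of the recurrence in NumPy; the function name rnn_step is ours, and sigma is the logistic sigmoid from §5.4.

import numpy as np
def sigma(z):
    return 1.0 / (1.0 + np.exp(-z))
def rnn_step(h_prev, x_t, Wh, Wi, b):
    # Wh: (hidden', hidden), Wi: (hidden', input), b: (hidden',)
    # contract Wh with h over hidden and Wi with x over input, then treat hidden' as hidden again
    return sigma(Wh @ h_prev + Wi @ x_t + b)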

3.1.3 Attention

In the introduction (§1), we mentioned some difficulties in interpreting the equation for attention as it’s usually written. In our notation, it looks like this: \[\begin{aligned} \text{Attention} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times\ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times\ensuremath{\mathsf{\vphantom{fg}val}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}val}}} \\ \text{Attention}(Q,K,V) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\mathrm{softmax}}} \left( \frac{Q \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}key}}}}{\vphantom{fg}\odot}} K}{\sqrt{|\ensuremath{\mathsf{\vphantom{fg}key}}|}} \right) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\odot}} V.\end{aligned}\]

As written, this function takes a single query \(Q\) and computes a single output value. If we want an output sequence, we can just give \(Q\) a \(\ensuremath{\mathsf{\vphantom{fg}seq}}'\) axis (or some other name), and the function will compute an output sequence. Furthermore, if we give \(Q\), \(K\), and \(V\) a \(\ensuremath{\mathsf{\vphantom{fg}heads}}\) axis for multiple attention heads, then the function will compute multi-head attention.

Sometimes we need to apply a mask to keep from attending to certain positions. \[\begin{aligned} \text{Attention} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times\ensuremath{\mathsf{\vphantom{fg}key}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times\ensuremath{\mathsf{\vphantom{fg}val}}} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}val}}} \\ \text{Attention}(Q, K, V, M) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\mathrm{softmax}}} \left( \frac{Q \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}key}}}}{\vphantom{fg}\odot}} K}{\sqrt{|\ensuremath{\mathsf{\vphantom{fg}key}}|}} + M \right) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\odot}} V.\end{aligned}\]
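
A positional sketch of (masked) attention for a single query, matching the signature above; q has shape (key,), K has shape (seq, key), V has shape (seq, val), and M, if given, has shape (seq,) with entries 0 or -inf.

import numpy as np
def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)
def attention(q, K, V, M=None):
    scores = np.einsum('k,sk->s', q, K) / np.sqrt(K.shape[-1])   # contract over key
    if M is not None:
        scores = scores + M
    return np.einsum('s,sv->v', softmax(scores, axis=-1), V)     # contract over seq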

3.1.4 Convolution

Convolutions can be easily written by unrolling a tensor and then applying a standard dot product. First, we define the unrolling operation: \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}\\ \ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{unroll}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n-|\ensuremath{\mathsf{\vphantom{fg}kernel}}|+1], \ensuremath{\mathsf{\vphantom{fg}kernel}}} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}\\ \ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{unroll}}} X &= Y,\ \text{where} \\ Y_{\ensuremath{\mathsf{\vphantom{fg}seq}}(i), \ensuremath{\mathsf{\vphantom{fg}kernel}}(j)} &= X_{\ensuremath{\mathsf{\vphantom{fg}seq}}(i+j - 1)}.\end{aligned}\]

Then we can define convolutions as: \[\begin{aligned} \text{Conv1d} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}seq}}[n]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n']} \\ \text{Conv1d}(X; W, b) &= W \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}chans}}\\ \ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}\\ \ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{unroll}}} X + b\end{aligned}\] where \[\begin{aligned} W &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}kernel}}} \\ b &\in \mathbb{R}\end{aligned}\] and \(n' = n - |\ensuremath{\mathsf{\vphantom{fg}kernel}}| + 1\). This computes a single output channel, but we can get multiple output channels by giving \(W\) and \(b\) a \(\ensuremath{\mathsf{\vphantom{fg}chans}}'\) axis (or some other name).

The same unrolling operation can be used to define higher-dimensional convolutions: \[\begin{aligned} \text{Conv2d} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}height}}[h] \times \ensuremath{\mathsf{\vphantom{fg}width}}[w]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}[h'] \times \ensuremath{\mathsf{\vphantom{fg}width}}[w']} \\ \text{Conv2d}(X; W, b) &= W \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}chans}}\\ \ensuremath{\mathsf{\vphantom{fg}kh}}, \ensuremath{\mathsf{\vphantom{fg}kw}}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}\\ \ensuremath{\mathsf{\vphantom{fg}kh}}}}{\vphantom{fg}\mathrm{unroll}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}\\\ensuremath{\mathsf{\vphantom{fg}kw}}}}{\vphantom{fg}\mathrm{unroll}}} X + b\end{aligned}\] where \[\begin{aligned} W &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}kh}}\times \ensuremath{\mathsf{\vphantom{fg}kw}}} \\ b &\in \mathbb{R}.\end{aligned}\]
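
A rough NumPy rendering of unroll and Conv1d, using sliding_window_view (available in NumPy 1.20 and later); X has shape (chans, n), W has shape (chans, kernel), and b is a scalar.

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
def conv1d(X, W, b):
    k = W.shape[-1]
    U = sliding_window_view(X, k, axis=-1)     # unroll: shape (chans, n - k + 1, kernel)
    return np.einsum('ck,csk->s', W, U) + b    # contract over chans and kernel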

3.1.5 Max pooling

We first define an operation to reshape an axis: \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}},\ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{pool}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n/|\ensuremath{\mathsf{\vphantom{fg}kernel}}|],\ensuremath{\mathsf{\vphantom{fg}kernel}}} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}},\ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{pool}}} X &= Y,\ \text{where} \\ Y_{\ensuremath{\mathsf{\vphantom{fg}seq}}(i), \ensuremath{\mathsf{\vphantom{fg}kernel}}(j)} &= X_{\ensuremath{\mathsf{\vphantom{fg}seq}}((i-1) \cdot |\ensuremath{\mathsf{\vphantom{fg}kernel}}| + j)}.\end{aligned}\]

Then we can define: \[\begin{aligned} \text{MaxPool1d}_{k} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}[n/k]} \\ \text{MaxPool1d}_{k}(X) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}},\ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\ensuremath{\mathsf{\vphantom{fg}kernel}}| &= k \\ \text{MaxPool2d}_{kh,kw} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}[h] \times \ensuremath{\mathsf{\vphantom{fg}width}}[w]} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}height}}[h/kh] \times \ensuremath{\mathsf{\vphantom{fg}width}}[w/kw]} \\ \text{MaxPool2d}_{kh,kw}(X) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}kh}},\ensuremath{\mathsf{\vphantom{fg}kw}}}}{\vphantom{fg}\mathrm{max}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}height}},\ensuremath{\mathsf{\vphantom{fg}kh}}}}{\vphantom{fg}\mathrm{pool}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}width}},\ensuremath{\mathsf{\vphantom{fg}kw}}}}{\vphantom{fg}\mathrm{pool}}} X \\ |\ensuremath{\mathsf{\vphantom{fg}kh}}| &= kh \\ |\ensuremath{\mathsf{\vphantom{fg}kw}}| &= kw.\end{aligned}\] Other pooling functions could be defined similarly.
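
A one-line NumPy rendering of pool followed by max, assuming n is divisible by k:

import numpy as np
def maxpool1d(X, k):
    # X: (n,); pool reshapes to (n // k, kernel) and max reduces over the kernel axis
    return X.reshape(-1, k).max(axis=-1)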

3.1.6 Normalization layers

Batch, instance, and layer normalization are often informally described using the same equation, but they correspond to very different functions. They differ both in which axes are standardized and in the shapes of their parameters.

We can define a single generic standardization function as: \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{standardize}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{standardize}}}(X) &= \frac{X - \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{mean}}}(X)}{\sqrt{\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{var}}}(X) + \epsilon}}\end{aligned}\] where \(\epsilon > 0\) is a small constant for numerical stability.

Then, we can define the three kinds of normalization layers, all with type \(\mathbb{R}^{{\ensuremath{\mathsf{\vphantom{fg}batch}}\times \ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}}} \rightarrow \mathbb{R}^{{\ensuremath{\mathsf{\vphantom{fg}batch}}\times \ensuremath{\mathsf{\vphantom{fg}chans}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}}}\): \[\begin{aligned} \text{BatchNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}batch}},\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}} \\ \text{InstanceNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}} \\ \text{LayerNorm}(X; \gamma, \beta) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}},\ensuremath{\mathsf{\vphantom{fg}chans}}}}{\vphantom{fg}\mathrm{standardize}}}(X) \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta & \gamma, \beta &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}},\ensuremath{\mathsf{\vphantom{fg}layer}}}\end{aligned}\]

Note that, while superficially similar, these functions differ in their standardized axes and their parameter shape.
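
The same distinction is easy to see in a positional NumPy sketch, where X is assumed to have shape (batch, chans, layer) in that order; only the axis argument and the parameter shapes change.

import numpy as np
def standardize(X, axis, eps=1e-5):
    return (X - X.mean(axis=axis, keepdims=True)) / np.sqrt(X.var(axis=axis, keepdims=True) + eps)
def batch_norm(X, gamma, beta):      # gamma, beta: (chans,)
    return standardize(X, axis=(0, 2)) * gamma[None, :, None] + beta[None, :, None]
def instance_norm(X, gamma, beta):   # gamma, beta: (chans,)
    return standardize(X, axis=2) * gamma[None, :, None] + beta[None, :, None]
def layer_norm(X, gamma, beta):      # gamma, beta: (chans, layer)
    return standardize(X, axis=(1, 2)) * gamma[None, :, :] + beta[None, :, :]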

Other deep learning methods have been proposed which consider different shapes of standardization. For instance, group norm is a popular extension that first pools channels into groups of size \(k\) before standardizing.

\[\begin{aligned} \text{GroupNorm}_k(X; \gamma, \beta) &= \left[ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}kernel}},\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\mathrm{standardize}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}chans}}, \ensuremath{\mathsf{\vphantom{fg}kernel}}}}{\vphantom{fg}\mathrm{pool}}} X \right]_{(\ensuremath{\mathsf{\vphantom{fg}chans}},\ensuremath{\mathsf{\vphantom{fg}kernel}})\rightarrow \ensuremath{\mathsf{\vphantom{fg}chans}}} \mathbin{\underset{\substack{}}{\vphantom{fg}\odot}} \gamma + \beta \\ \end{aligned}\] where \[\begin{aligned} |\ensuremath{\mathsf{\vphantom{fg}kernel}}| &= k\\ \gamma, \beta &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}}.\end{aligned}\]

3.2 Transformer

We define a Transformer used autoregressively as a language model. The input is a sequence of one-hot vectors, from which we compute word embeddings and positional encodings: \[\begin{aligned} I &\in \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}vocab}}} & \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}} I &= 1 \\ W &= (E \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\odot}} I)\sqrt{|\ensuremath{\mathsf{\vphantom{fg}layer}}|} & E &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}vocab}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} \\ P &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} \\ P_{\ensuremath{\mathsf{\vphantom{fg}seq}}(p), \ensuremath{\mathsf{\vphantom{fg}layer}}(i)} &= \begin{cases} \sin((p-1) / 10000^{(i-1) / |\ensuremath{\mathsf{\vphantom{fg}layer}}|}) & \text{$i$ odd} \\ \cos((p-1) / 10000^{(i-2) / |\ensuremath{\mathsf{\vphantom{fg}layer}}|}) & \text{$i$ even.} \end{cases}\end{aligned}\]

Then we use \(L\) layers of self-attention and feed-forward neural networks: \[\begin{aligned} X^0 &= W+P \\ T^1 &= \text{LayerNorm}^1(\text{SelfAtt}^1(X^0)) + X^0\\ X^1 &= \text{LayerNorm}^{1'}(\text{FFN}^1(T^1)) + T^1\\ &\vdotswithin{=} \\ T^{L} &= \text{LayerNorm}^L(\text{SelfAtt}^L(X^{L-1})) + X^{L-1}\\ X^{L} &= \text{LayerNorm}^{L'}(\text{FFN}^L(T^L)) + T^L\\ O &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\mathrm{softmax}}}(E \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X^L)\end{aligned}\] where \(\text{LayerNorm}\), \(\text{SelfAtt}\) and \(\text{FFN}\) are defined below.

Layer normalization (\(l = 1, 1', \ldots, L, L'\)) uses the \(\mathrm{standardize}\) function from §3.1.6: \[\begin{aligned} \text{LayerNorm}^l \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}} \\ \text{LayerNorm}^l(X) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\mathrm{standardize}}}(X) \odot \gamma^l + \beta^l & \gamma^l, \beta^l &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}}.\end{aligned}\]

We defined attention in §3.1.3; the Transformer uses multi-head self-attention, in which queries, keys, and values are all computed from the same sequence. \[\begin{aligned} \text{SelfAtt}^l \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} \\ \text{SelfAtt}^l(X) &= Y\end{aligned}\] where \[\begin{aligned} |\ensuremath{\mathsf{\vphantom{fg}seq}}| &= |\ensuremath{\mathsf{\vphantom{fg}seq}}'| \\ |\ensuremath{\mathsf{\vphantom{fg}key}}| = |\ensuremath{\mathsf{\vphantom{fg}val}}| &= |\ensuremath{\mathsf{\vphantom{fg}layer}}|/|\ensuremath{\mathsf{\vphantom{fg}heads}}| \\ Q &= W^{l,Q} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X_{\ensuremath{\mathsf{\vphantom{fg}seq}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}seq}}'} & W^{l,Q} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}heads}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}\times \ensuremath{\mathsf{\vphantom{fg}key}}} \\ K &= W^{l,K} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X & W^{l,K} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}heads}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}\times \ensuremath{\mathsf{\vphantom{fg}key}}} \\ V &= W^{l,V} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X & W^{l,V} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}heads}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}\times \ensuremath{\mathsf{\vphantom{fg}val}}} \\ M & \in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}seq}}'} \\ M_{\ensuremath{\mathsf{\vphantom{fg}seq}}(i), \ensuremath{\mathsf{\vphantom{fg}seq}}'(j)} &= \begin{cases} 0 & i \leq j\\ -\infty & \text{otherwise} \end{cases} \\ Y &= W^{l,O} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}heads}}\\ \ensuremath{\mathsf{\vphantom{fg}val}}}}{\vphantom{fg}\odot}} \text{Attention}(Q, K, V, M)_{\ensuremath{\mathsf{\vphantom{fg}seq}}'\rightarrow\ensuremath{\mathsf{\vphantom{fg}seq}}} & W^{l,O} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}heads}}\times \ensuremath{\mathsf{\vphantom{fg}val}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}}\end{aligned}\]

Feedforward neural networks: \[\begin{aligned} \text{FFN}^l \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}} \\ \text{FFN}^l(X) &= X^2\end{aligned}\] where \[\begin{aligned} X^1 &= \text{relu}(W^{l,1} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X + b^{l,1}) & W^{l,1} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} & b^{l,1} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}} \\ X^2 &= W^{l,2} \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}hidden}}}}{\vphantom{fg}\odot}} X^1 + b^{l,2} & W^{l,2} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}\times \ensuremath{\mathsf{\vphantom{fg}hidden}}} & b^{l,2} &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}layer}}}.\end{aligned}\]

3.3 LeNet

\[\begin{aligned} X^0 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}batch}}\times \ensuremath{\mathsf{\vphantom{fg}chans}}[c_0] \times \ensuremath{\mathsf{\vphantom{fg}height}}\times \ensuremath{\mathsf{\vphantom{fg}width}}} \\ T^1 &= \text{relu}(\text{Conv}^1(X^0)) \\ X^1 &= \text{MaxPool}^1(T^1) \\ T^2 &= \text{relu}(\text{Conv}^2(X^1)) \\ X^2 &= \text{MaxPool}^2(T^2)_{(\ensuremath{\mathsf{\vphantom{fg}height}},\ensuremath{\mathsf{\vphantom{fg}width}},\ensuremath{\mathsf{\vphantom{fg}chans}})\rightarrow\ensuremath{\mathsf{\vphantom{fg}layer}}} \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}layer}}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}\times \ensuremath{\mathsf{\vphantom{fg}layer}}} & b^3 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}} \\ O &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}classes}}}}{\vphantom{fg}\mathrm{softmax}}} (W^4 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}hidden}}}}{\vphantom{fg}\odot}} X^3 + b^4) & W^4 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}classes}}\times \ensuremath{\mathsf{\vphantom{fg}hidden}}} & b^4 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}classes}}}\end{aligned}\] As an alternative to the flattening operation in the equation for \(X^2\), we could have written \[\begin{aligned} X^2 &= \text{MaxPool}^2(T^2) \\ X^3 &= \text{relu}(W^3 \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}\\ \ensuremath{\mathsf{\vphantom{fg}width}}\\ \ensuremath{\mathsf{\vphantom{fg}chans}}}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}hidden}}\times \ensuremath{\mathsf{\vphantom{fg}height}}\times \ensuremath{\mathsf{\vphantom{fg}width}}\times \ensuremath{\mathsf{\vphantom{fg}chans}}}.\end{aligned}\]

The convolution and pooling operations are defined as follows: \[\begin{aligned} \text{Conv}^l(X) &= \text{Conv2d}(X; W^l, b^l)_{\ensuremath{\mathsf{\vphantom{fg}chans}}'\rightarrow\ensuremath{\mathsf{\vphantom{fg}chans}}}\end{aligned}\] where \[\begin{aligned} W^l & \in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}'[c_l] \times \ensuremath{\mathsf{\vphantom{fg}chans}}[c_{l-1}] \times \ensuremath{\mathsf{\vphantom{fg}kh}}[kh_l] \times \ensuremath{\mathsf{\vphantom{fg}kw}}[kw_l]} \\ b^l &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}chans}}'[c_l]}\end{aligned}\] and \[\begin{aligned} \text{MaxPool}^l(X) &= \text{MaxPool2d}_{ph^l,ph^l}(X).\end{aligned}\]

3.4 Other examples

3.4.1 Discrete random variables

Named axes are very helpful for working with discrete random variables, because each random variable can be represented by an axis with the same name. For instance, if \(\ensuremath{\mathsf{\vphantom{fg}A}}\) and \(\ensuremath{\mathsf{\vphantom{fg}B}}\) are random variables, we can treat \(p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}})\) and \(p(\ensuremath{\mathsf{\vphantom{fg}A}})\) as tensors: \[\begin{aligned} p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) &\in [0, 1]^{\ensuremath{\mathsf{\vphantom{fg}A}} \times \ensuremath{\mathsf{\vphantom{fg}B}}} & \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}B}}}} p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) &= 1 \\ p(\ensuremath{\mathsf{\vphantom{fg}A}}) &\in [0, 1]^{\ensuremath{\mathsf{\vphantom{fg}A}}} & \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}A}}}} p(\ensuremath{\mathsf{\vphantom{fg}A}}) &= 1\end{aligned}\] Then many common operations on probability distributions can be expressed in terms of tensor operations: \[\begin{aligned} p(\ensuremath{\mathsf{\vphantom{fg}A}}, \ensuremath{\mathsf{\vphantom{fg}B}}) &= p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) \odot p(\ensuremath{\mathsf{\vphantom{fg}A}}) && \text{chain rule}\\ p(\ensuremath{\mathsf{\vphantom{fg}B}}) &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}A}}}} p(\ensuremath{\mathsf{\vphantom{fg}A}}, \ensuremath{\mathsf{\vphantom{fg}B}}) = p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}A}}}}{\vphantom{fg}\odot}} p(\ensuremath{\mathsf{\vphantom{fg}A}}) && \text{marginalization} \\ p(\ensuremath{\mathsf{\vphantom{fg}A}} \mid \ensuremath{\mathsf{\vphantom{fg}B}}) &= \frac{p(\ensuremath{\mathsf{\vphantom{fg}A}}, \ensuremath{\mathsf{\vphantom{fg}B}})}{p(\ensuremath{\mathsf{\vphantom{fg}B}})} = \frac{p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) \odot p(\ensuremath{\mathsf{\vphantom{fg}A}})}{p(\ensuremath{\mathsf{\vphantom{fg}B}} \mid \ensuremath{\mathsf{\vphantom{fg}A}}) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}A}}}}{\vphantom{fg}\odot}} p(\ensuremath{\mathsf{\vphantom{fg}A}})}. && \text{Bayes' rule}\end{aligned}\]
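
As a quick positional sketch (with a made-up two-by-two example, A as axis 0 and B as axis 1):

import numpy as np
p_B_given_A = np.array([[0.9, 0.1], [0.3, 0.7]])   # shape (A, B); each row sums to 1
p_A = np.array([0.6, 0.4])                         # shape (A,)
p_AB = p_B_given_A * p_A[:, None]                  # chain rule, shape (A, B)
p_B = p_AB.sum(axis=0)                             # marginalization, shape (B,)
p_A_given_B = p_AB / p_B[None, :]                  # Bayes' rule, shape (A, B)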

3.4.2 Advanced indexing

Contributors: Tongfei Chen and Chu-Cheng Lin

NumPy and its derivatives provide various ways to recombine elements of a tensor to form a new tensor: integer array indexing, and functions like take, index_select, gather, and batch_gather. Using named tensors, we can write nearly all of these operations with a single function: \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{index}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \times [n] &\rightarrow \mathbb{R}\\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{index}}}(A, i) &= A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)}.\end{aligned}\] Suppose we have \[\begin{aligned} E &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}vocab}}[n] \times \ensuremath{\mathsf{\vphantom{fg}emb}}} \\ i &\in [n] \\ I &\in [n]^{\ensuremath{\mathsf{\vphantom{fg}seq}}} \\ P &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}vocab}}[n]}\end{aligned}\] Tensor \(E\) contains word embeddings for all the words in the vocabulary. Integer \(i\) is the numeric identifier of a word, while tensor \(I\) is a sequence of words. Tensor \(P\) contains a sequence of probability distributions over the vocabulary. Then:

  • The expression \(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\mathrm{index}}}(E,i)\) broadcasts \(E\)’s \(\ensuremath{\mathsf{\vphantom{fg}emb}}\) axis, giving the word embedding of word \(i\). This is the same as partial indexing (\(E_{\ensuremath{\mathsf{\vphantom{fg}vocab}}(i)}\)).

  • The expression \(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\mathrm{index}}}(E,I)\) also broadcasts \(I\)’s \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) axis, giving a sequence of word embeddings. This is the same as integer array indexing.

  • The expression \(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\mathrm{index}}}(P,I)\) aligns \(P\)’s and \(I\)’s \(\ensuremath{\mathsf{\vphantom{fg}seq}}\) axes, giving a sequence of probabilities. This is the same as gather.

In NumPy, indexing using two or more integer arrays requires a special definition with some surprising special cases. With named tensors, we simply apply the indexing function twice. For example, if we (for some reason) wanted to get probabilities of words at a subset of positions: \[\begin{aligned} |\ensuremath{\mathsf{\vphantom{fg}seq}}| &= m \\ I_1 &\in [m]^\ensuremath{\mathsf{\vphantom{fg}subseq}}\\ I_2 &\in [n]^\ensuremath{\mathsf{\vphantom{fg}subseq}}\\ S &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\mathrm{index}}}(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}}{\vphantom{fg}\mathrm{index}}}(P, I_1), I_2) \in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}subseq}}} \\ S_{\ensuremath{\mathsf{\vphantom{fg}subseq}}(i)} &= P_{\ensuremath{\mathsf{\vphantom{fg}seq}}([I_1]_{\ensuremath{\mathsf{\vphantom{fg}subseq}}(i)}), \ensuremath{\mathsf{\vphantom{fg}vocab}}([I_2]_{\ensuremath{\mathsf{\vphantom{fg}subseq}}(i)})}.\end{aligned}\]
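
For comparison, the corresponding positional NumPy operations look like this (a sketch with 0-based indices and illustrative sizes):

import numpy as np
rng = np.random.default_rng(0)
n, m, d = 5, 4, 3
E = rng.normal(size=(n, d))          # (vocab, emb)
I = rng.integers(n, size=(m,))       # (seq,) of word ids
P = rng.random(size=(m, n))          # (seq, vocab)
e = E[2]                             # index(E, i): one embedding, shape (emb,)
E_I = E[I]                           # index(E, I): integer array indexing, shape (seq, emb)
p_I = P[np.arange(m), I]             # index(P, I): gather, shape (seq,)
I1 = rng.integers(m, size=(2,))      # (subseq,) positions
I2 = rng.integers(n, size=(2,))      # (subseq,) word ids
S = P[I1, I2]                        # indexing with two integer arrays, shape (subseq,)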

3.4.3 Continuous bag of words

A continuous bag-of-words model classifies by summing up the embeddings of a sequence of words \(X\) (as one-hot vectors) and projecting them to the space of classes.

\[\begin{aligned} \text{CBOW} \colon \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}seq}}\times \ensuremath{\mathsf{\vphantom{fg}vocab}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}classes}}} \\ \text{CBOW}(X; E, W) &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}classes}}}}{\vphantom{fg}\mathrm{softmax}}} \left(\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}seq}}}} W \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}emb}}}}{\vphantom{fg}\odot}} E \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}}{\vphantom{fg}\odot}} X\right)\end{aligned}\] where \[\begin{aligned} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}vocab}}}} X_{\ensuremath{\mathsf{\vphantom{fg}seq}}(i)} &= 1 & i &= 1, \ldots, |\ensuremath{\mathsf{\vphantom{fg}seq}}| \\ E &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}vocab}}\times \ensuremath{\mathsf{\vphantom{fg}emb}}} \\ W &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}classes}}\times \ensuremath{\mathsf{\vphantom{fg}emb}}}.\end{aligned}\]

3.4.4 Sudoku ILP

Sudoku puzzles can be represented as binary tiled tensors. Given a grid, we can check that it is valid by converting it to a grid of grids. Constraints then ensure that each digit appears exactly once in each row, each column, and each sub-box.

\[\begin{aligned} \text{check} \colon \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}height}}[9] \times \ensuremath{\mathsf{\vphantom{fg}width}}[9] \times \ensuremath{\mathsf{\vphantom{fg}assign}}[9]} &\rightarrow \{0, 1\} \\ \text{check}(X) &= \mathbb{I}\left[\begin{aligned} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}assign}}}} X = 1 &\land \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}\\ \ensuremath{\mathsf{\vphantom{fg}width}}}} Y = 1 \land {} \\ \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}height}}}} X = 1 &\land \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}width}}}} X = 1 \end{aligned}\right]\end{aligned}\] where \[\begin{aligned} Y &\in \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}height}}'[3] \times \ensuremath{\mathsf{\vphantom{fg}width}}'[3] \times \ensuremath{\mathsf{\vphantom{fg}height}}[3] \times \ensuremath{\mathsf{\vphantom{fg}width}}[3] \times \ensuremath{\mathsf{\vphantom{fg}assign}}[9]} \\ Y_{\ensuremath{\mathsf{\vphantom{fg}height}}'(h'), \ensuremath{\mathsf{\vphantom{fg}height}}(h), \ensuremath{\mathsf{\vphantom{fg}width}}'(w'), \ensuremath{\mathsf{\vphantom{fg}width}}(w)} &= X_{\ensuremath{\mathsf{\vphantom{fg}height}}(3(h'-1) + h), \ensuremath{\mathsf{\vphantom{fg}width}}(3(w'-1) + w)}.\end{aligned}\]
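
A compact NumPy equivalent, with X stored as a (9, 9, 9) binary array over (height, width, assign) and the reshape playing the role of Y:

import numpy as np
def check(X):
    Y = X.reshape(3, 3, 3, 3, 9)                      # (height', height, width', width, assign)
    return bool((X.sum(axis=2) == 1).all()            # one digit assigned per cell
                and (X.sum(axis=0) == 1).all()        # each digit once per column
                and (X.sum(axis=1) == 1).all()        # each digit once per row
                and (Y.sum(axis=(1, 3)) == 1).all())  # each digit once per 3x3 box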

3.4.5 \(K\)-means clustering

The following equations define one step of \(k\)-means clustering. Given a set of points \(X\) and an initial set of cluster centers \(C\), \[\begin{aligned} X &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}batch}}\times \ensuremath{\mathsf{\vphantom{fg}d}}} \\ C &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}clusters}}\times \ensuremath{\mathsf{\vphantom{fg}d}}}\end{aligned}\] we repeat the following update: compute the cluster assignments \[\begin{aligned} Q &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}clusters}}}}{\vphantom{fg}\mathrm{argmin}}} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}d}}}}{\vphantom{fg}\mathrm{norm}}}(C-X)\end{aligned}\] then recompute the cluster centers: \[C \leftarrow \frac{\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}batch}}}} Q \odot X}{\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}batch}}}} Q}.\]
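
One full update in NumPy, with X of shape (batch, d) and C of shape (clusters, d); Q is the one-hot assignment matrix, and clusters that receive no points would need special handling.

import numpy as np
def kmeans_step(X, C):
    dists = np.linalg.norm(C[None, :, :] - X[:, None, :], axis=-1)   # (batch, clusters)
    Q = np.eye(C.shape[0])[dists.argmin(axis=1)]                     # one-hot over clusters
    return (Q[:, :, None] * X[:, None, :]).sum(axis=0) / Q.sum(axis=0)[:, None]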

3.4.6 Beam search

Beam search is a commonly used approach for approximate discrete search. Here \(H\) is the score of each element in the beam, \(S\) is the state of each element in the beam, and \(f\) is an update function that returns the score of each state transition. \[\begin{aligned} H &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}beam}}} \\ S &\in \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}beam}}\times \ensuremath{\mathsf{\vphantom{fg}state}}} & \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}state}}}} S &= 1 \\ f &\colon \{0, 1\}^{\ensuremath{\mathsf{\vphantom{fg}state}}} \rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}state}}} \\\end{aligned}\] Then we repeat the following update: \[\begin{aligned} H' &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}beam}}}}{\vphantom{fg}\mathrm{max}}} (H \odot f(S)) \\ H &\leftarrow \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}state}},\ensuremath{\mathsf{\vphantom{fg}beam}}}}{\vphantom{fg}\mathrm{maxk}}} H' \\ S &\leftarrow \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}state}},\ensuremath{\mathsf{\vphantom{fg}beam}}}}{\vphantom{fg}\mathrm{argmaxk}}} H'\end{aligned}\] where \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}}{\vphantom{fg}\mathrm{maxk}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}} &\rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}k}}} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}}{\vphantom{fg}\mathrm{argmaxk}}} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}} &\rightarrow \{0,1\}^{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}\end{aligned}\] are defined such that \([\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}}{\vphantom{fg}\mathrm{maxk}}} A]_{\ensuremath{\mathsf{\vphantom{fg}k}}(i)}\) is the \(i\)-th largest value along axis \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) and \(A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} (\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}}{\vphantom{fg}\mathrm{argmaxk}}}{A}) = \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}k}}}}{\vphantom{fg}\mathrm{max}}} A\).

We can add a \(\ensuremath{\mathsf{\vphantom{fg}batch}}\) axis to \(H\) and \(S\) and the above equations will work unchanged.

3.4.7 Multivariate normal distribution

To define a multivariate normal distribution, we need some matrix operations. These have two axis names written under them, for rows and columns, respectively. Determinant and inverse have the following signatures: \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}_1,\ensuremath{\mathsf{\vphantom{fg}ax}}_2}}{\vphantom{fg}\mathrm{det}}} \colon F^{\ensuremath{\mathsf{\vphantom{fg}ax}}_1[n] \times \ensuremath{\mathsf{\vphantom{fg}ax}}_2[n]} &\rightarrow F \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}_1,\ensuremath{\mathsf{\vphantom{fg}ax}}_2}}{\vphantom{fg}\mathrm{inv}}} \colon F^{\ensuremath{\mathsf{\vphantom{fg}ax}}_1[n] \times \ensuremath{\mathsf{\vphantom{fg}ax}}_2[n]} &\rightarrow F^{\ensuremath{\mathsf{\vphantom{fg}ax}}_1[n] \times \ensuremath{\mathsf{\vphantom{fg}ax}}_2[n]}.\end{aligned}\] (We write \(\text{inv}\) instead of \(\cdot^{-1}\) because there’s no way to write axis names under the latter.)

In our notation, the application of a bilinear form is more verbose than the standard notation (\((X-\mu)^\top \Sigma^{-1} (X-\mu)\)), but also makes it look more like a function of two arguments (and would generalize to three or more arguments).

\[\begin{aligned} \mathcal{N} \colon \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}d}}} &\rightarrow \mathbb{R}\\ \mathcal{N}(X; \mu, \Sigma) &= \frac{\exp\left(-\frac{1}{2} \left(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}d}}_1, \ensuremath{\mathsf{\vphantom{fg}d}}_2}}{\vphantom{fg}\mathrm{inv}}} \Sigma\right) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}d}}_1,\ensuremath{\mathsf{\vphantom{fg}d}}_2}}{\vphantom{fg}\odot}} \left([X - \mu]_{\ensuremath{\mathsf{\vphantom{fg}d}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}d}}_1} \odot [X - \mu]_{\ensuremath{\mathsf{\vphantom{fg}d}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}d}}_2} \right) \right)}{\sqrt{(2 \pi)^{|\ensuremath{\mathsf{\vphantom{fg}d}}|} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}d}}_1, \ensuremath{\mathsf{\vphantom{fg}d}}_2}}{\vphantom{fg}\mathrm{det}}} \Sigma}}\end{aligned}\] where \[\begin{aligned} |\ensuremath{\mathsf{\vphantom{fg}d}}| &= |\ensuremath{\mathsf{\vphantom{fg}d}}_1| = |\ensuremath{\mathsf{\vphantom{fg}d}}_2| \\ \mu &\in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}d}}} \\ \Sigma & \in \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}d}}_1 \times \ensuremath{\mathsf{\vphantom{fg}d}}_2}.\end{aligned}\]
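
A direct NumPy transcription of the density (a sketch; it does not check that \(\Sigma\) is symmetric positive definite):

import numpy as np
def mvn_pdf(X, mu, Sigma):
    # X, mu: (d,); Sigma: (d, d)
    diff = X - mu
    quad = diff @ np.linalg.inv(Sigma) @ diff          # bilinear form over d1 and d2
    norm = np.sqrt((2 * np.pi) ** X.shape[0] * np.linalg.det(Sigma))
    return np.exp(-0.5 * quad) / norm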

4 LaTeX Macros

Many of the LaTeX macros used in this document are available in the style file https://namedtensor.github.io/namedtensor.sty. To use it, put

\usepackage{namedtensor}

in the preamble of your LaTeX source file (after \documentclass{article} but before \begin{document}).

We write axis names in sans-serif font. To make this easier, \ndef{\ax}{ax} defines a macro \ax that looks like this: \(\ensuremath{\mathsf{\vphantom{fg}ax}}\).

  • Binary operators

    • Use A \ndot{\ax} B for contraction: \(A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} B\). You can use \\ to stack up several names.

    • In general, you can use \nbin to make a new binary operator with a name under it: A \nbin{\ax}{\star} B gives you \(A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\star}} B\).

  • Functions

    • Use \nsum{\ax} A for summation: \(\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} A\).

    • In general, you can use \nfun to make a function with a name under it: \nfun{\ax}{qux} A gives you \(\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{qux}}} A\).

5 Formal Definitions

5.1 Records and shapes

A named index is a pair, written \(\ensuremath{\mathsf{\vphantom{fg}ax}}(i)\), where \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) is a name and \(i\) is usually a natural number. We write both names and variables ranging over names using sans-serif font.

A record is a set of named indices \(\{\ensuremath{\mathsf{\vphantom{fg}ax}}_1(i_1), \ldots, \ensuremath{\mathsf{\vphantom{fg}ax}}_r(i_r)\}\), where \(\ensuremath{\mathsf{\vphantom{fg}ax}}_1, \ldots \ensuremath{\mathsf{\vphantom{fg}ax}}_r\) are pairwise distinct names.

An axis is a pair, written \(\ensuremath{\mathsf{\vphantom{fg}ax}}[I]\), where \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) is a name and \(I\) is a set of indices. We deal with axes of the form \(\ensuremath{\mathsf{\vphantom{fg}ax}}[[n]]\) (that is, \(\ensuremath{\mathsf{\vphantom{fg}ax}}[\{1, \ldots, n\}]\)) so frequently that we abbreviate this as \(\ensuremath{\mathsf{\vphantom{fg}ax}}[n]\).

In many contexts, there is only one axis with name \(\ensuremath{\mathsf{\vphantom{fg}ax}}\), and so we refer to the axis simply as \(\ensuremath{\mathsf{\vphantom{fg}ax}}\). The context always makes it clear whether \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) is a name or an axis. If \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) is an axis, we write \(\ind(\ensuremath{\mathsf{\vphantom{fg}ax}})\) for its index set, and we write \(|\ensuremath{\mathsf{\vphantom{fg}ax}}|\) as shorthand for \(|\ind(\ensuremath{\mathsf{\vphantom{fg}ax}})|\).

A shape is a set of axes, written \(\ensuremath{\mathsf{\vphantom{fg}ax}}_1[I_1] \times \cdots \times \ensuremath{\mathsf{\vphantom{fg}ax}}_r[I_r]\), where \(\ensuremath{\mathsf{\vphantom{fg}ax}}_1, \ldots \ensuremath{\mathsf{\vphantom{fg}ax}}_r\) are pairwise distinct names. We write \(\emptyset\) for the empty shape. A shape defines a set of records: \[\rec (\ensuremath{\mathsf{\vphantom{fg}ax}}_1[I_1] \times \cdots \times \ensuremath{\mathsf{\vphantom{fg}ax}}_r[I_r]) = \left\{\{\ensuremath{\mathsf{\vphantom{fg}ax}}_1(i_1), \ldots, \ensuremath{\mathsf{\vphantom{fg}ax}}_r(i_r)\} \mid i_1 \in I_1, \ldots, i_r \in I_r\right\}.\]

We say two shapes \(\mathcal{S}\) and \(\mathcal{T}\) are compatible if whenever \(\ensuremath{\mathsf{\vphantom{fg}ax}}[I] \in \mathcal{S}\) and \(\ensuremath{\mathsf{\vphantom{fg}ax}}[J] \in \mathcal{T}\), then \(I = J\). We say that \(\mathcal{S}\) and \(\mathcal{T}\) are orthogonal if there is no \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) such that \(\ensuremath{\mathsf{\vphantom{fg}ax}}[I] \in \mathcal{S}\) and \(\ensuremath{\mathsf{\vphantom{fg}ax}}[J] \in \mathcal{T}\) for any \(I\), \(J\).

If \(t \in \rec \mathcal{T}\) and \(\mathcal{S} \subseteq \mathcal{T}\), then we write \(\mathopen{}\left.t\right|_{\mathcal{S}}\) for the unique record in \(\rec \mathcal{S}\) such that \(\mathopen{}\left.t\right|_{\mathcal{S}} \subseteq t\).

5.2 Named tensors

Let \(F\) be a field and let \(\mathcal{S}\) be a shape. Then a named tensor over \(F\) with shape \(\mathcal{S}\) is a mapping from \(\mathcal{S}\) to \(F\). We write the set of all named tensors with shape \(\mathcal{S}\) as \(F^{\mathcal{S}}\).

We don’t make any distinction between a scalar (an element of \(F\)) and a named tensor with empty shape (an element of \(F^\emptyset\)).

If \(A \in F^{\mathcal{S}}\), then we access an element of \(A\) by applying it to a record \(s \in \rec \mathcal{S}\); but we write this using the usual subscript notation: \(A_s\) rather than \(A(s)\). To avoid clutter, in place of \(A_{\{\ensuremath{\mathsf{\vphantom{fg}ax}}_1(i_1), \ldots, \ensuremath{\mathsf{\vphantom{fg}ax}}_r(i_r)\}}\), we usually write \(A_{\ensuremath{\mathsf{\vphantom{fg}ax}}_1(i_1), \ldots, \ensuremath{\mathsf{\vphantom{fg}ax}}_r(i_r)}\). When a named tensor is an expression like \((A+B)\), we surround it with square brackets like this: \([A+B]_{\ensuremath{\mathsf{\vphantom{fg}ax}}_1(i_1), \ldots, \ensuremath{\mathsf{\vphantom{fg}ax}}_r(i_r)}\).

We also allow partial indexing. If \(A\) is a tensor with shape \(\mathcal{T}\) and \(s \in \rec \mathcal{S}\) where \(\mathcal{S} \subseteq \mathcal{T}\), then we define \(A_s\) to be the named tensor with shape \(\mathcal{T} \setminus \mathcal{S}\) such that, for any \(t \in \rec (\mathcal{T} \setminus \mathcal{S})\), \[\begin{aligned} \left[A_s\right]_t &= A_{s \cup t}.\end{aligned}\] (For the edge case \(\mathcal{T} = \emptyset\), our definitions for indexing and partial indexing coincide: one gives a scalar and the other gives a tensor with empty shape, but we don’t distinguish between the two.)
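Continuing the same sketch (again just an executable illustration), a named tensor with shape \(\mathcal{S}\) is a mapping from \(\rec\mathcal{S}\) to scalars, so a dict keyed by records suffices, and full and partial indexing become set operations:

```python
def index(A, s):
    """A_s: a scalar if s is a full record of A, otherwise a smaller tensor."""
    if s in A:                                        # full indexing
        return A[s]
    return {t - s: A[t] for t in A if s <= t}         # partial indexing

# B with shape seq[2] x key[2]
B = {frozenset({("seq", i), ("key", j)}): 10 * i + j
     for i in (1, 2) for j in (1, 2)}

print(index(B, frozenset({("seq", 1), ("key", 2)})))  # the scalar B_{seq(1),key(2)} = 12
print(index(B, frozenset({("seq", 1)})))              # a tensor over key[2]
```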

5.3 Named tensor operations

In §2, we described several classes of functions that can be extended to named tensors. Here, we define how to do this for general functions.

Let \(f \colon F^{\mathcal{S}} \rightarrow G^{\mathcal{T}}\) be a function from tensors to tensors. For any shape \(\mathcal{S'}\) orthogonal to both \(\mathcal{S}\) and \(\mathcal{T}\), we can extend \(f\) to: \[\begin{aligned} f \colon F^{\mathcal{S} \cup \mathcal{S'}} &\rightarrow G^{\mathcal{T} \cup \mathcal{S'}} \\ [f(A)]_s &= f(A_s) \qquad \text{for all $s \in \rec\mathcal{S'}$.}\end{aligned}\]

If \(f\) is a multary function, we can extend its arguments to larger shapes, and we don’t have to extend all the arguments with the same names. We consider just the case of two arguments; three or more arguments are analogous. Let \(f \colon F^{\mathcal{S}} \times G^{\mathcal{T}} \rightarrow H^{\mathcal{U}}\) be a binary function from tensors to tensors. For any shapes \(\mathcal{S'}\) and \(\mathcal{T'}\) that are compatible with each other, orthogonal to \(\mathcal{S}\) and \(\mathcal{T}\), respectively, and such that \(\mathcal{S'} \cup \mathcal{T'}\) is orthogonal to \(\mathcal{U}\), we can extend \(f\) to: \[\begin{aligned} f \colon F^{\mathcal{S} \cup \mathcal{S'}} \times G^{\mathcal{T} \cup \mathcal{T'}} &\rightarrow H^{\mathcal{U} \cup \mathcal{S'} \cup \mathcal{T'}} \\ [f(A,B)]_s &= f\left(A_{\mathopen{}\left.s\right|_{\mathcal{S'}}},B_{\mathopen{}\left.s\right|_{\mathcal{T'}}}\right) \qquad \text{for all $s \in \rec (\mathcal{S'} \cup \mathcal{T'})$.}\end{aligned}\]
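For the unary case, the extension can be written out concretely in the dict-of-records model above. The helper below (a hypothetical `lift`, ours and not from any library) groups the entries of a tensor by their record over the new axes \(\mathcal{S'}\), applies \(f\) to each slice, and tags the results with that record:

```python
def lift(f, new_names):
    """Extend a unary tensor function f to inputs with the extra axes in new_names."""
    def g(A):
        slices = {}
        for t, v in A.items():                                 # group entries by their
            s = frozenset(p for p in t if p[0] in new_names)   # record over S'
            slices.setdefault(s, {})[t - s] = v
        out = {}
        for s, A_s in slices.items():                          # [f(A)]_s = f(A_s)
            for u, w in f(A_s).items():
                out[u | s] = w
        return out
    return g

def double(A):                                                 # f: doubles any tensor
    return {t: 2 * v for t, v in A.items()}

A = {frozenset({("batch", b), ("ax", i)}): b + i for b in (1, 2) for i in (1, 2)}
print(lift(double, {"batch"})(A))                              # doubled, per batch element
```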

5.4 Common operations

All the tensor operations described in §2.2 can be defined in this way, as can the others listed below.

5.4.0.1 Elementwise operations

(\(\mathbb{R}\rightarrow \mathbb{R}\)) \[\begin{aligned} \sigma(x) &= \frac{1}{1+\exp(-x)} \\ \text{relu}(x) &= \max(0, x)\end{aligned}\]

5.4.0.2 Reductions

(\(\mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \rightarrow \mathbb{R}\)) \[\begin{aligned} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} A &= \sum_{i=1}^n A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{min}}} A &= \min \{A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)} \mid 1 \leq i \leq n\} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{max}}} A &= \max \{A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)} \mid 1 \leq i \leq n\} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{norm}}} A &= \sqrt{\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} A^2} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{mean}}} A &= \frac{1}{n} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} A \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{var}}} A &= \frac{1}{n} \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} (A - \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{mean}}} A)^2\end{aligned}\]

5.4.0.3 Contraction

(\(\mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \times \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \rightarrow \mathbb{R}\)) \[\begin{aligned} A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} B &= \sum_{i=1}^n A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)} B_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)}\end{aligned}\]

5.4.0.4 Vectors to vectors

(\(\mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]}\)) \[\begin{aligned} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{softmax}}} A &= \frac{\exp A}{\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \exp A} \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{argmax}}} A &= \lim_{\alpha \rightarrow \infty} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{softmax}}} \alpha A \\ \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{argmin}}} A &= \lim_{\alpha \rightarrow -\infty} \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{softmax}}} \alpha A\end{aligned}\]

5.4.0.5 Renaming

(\(\mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}[n]} \rightarrow \mathbb{R}^{\ensuremath{\mathsf{\vphantom{fg}ax}}'[n]}\)) \[\begin{aligned} [A_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'}]_{\ensuremath{\mathsf{\vphantom{fg}ax}}'(i)} &= A_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i)}\end{aligned}\]
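In practice these operations are usually carried out on ordinary arrays whose axes are given names on the side. The following numpy sketch (an illustration of §5.4, not a library API) carries a tuple of axis names alongside each array so that reductions, contraction, softmax, and renaming can be requested by name rather than by position:

```python
import numpy as np

def nsum(A, names, ax):                      # reduction over a named axis
    d = names.index(ax)
    return A.sum(axis=d), names[:d] + names[d + 1:]

def ndot(A, A_names, B, B_names, ax):        # contraction over a shared axis
    C = np.tensordot(A, B, axes=(A_names.index(ax), B_names.index(ax)))
    return C, tuple(n for n in A_names + B_names if n != ax)

def nsoftmax(A, names, ax):                  # softmax over a named axis
    d = names.index(ax)
    e = np.exp(A - A.max(axis=d, keepdims=True))
    return e / e.sum(axis=d, keepdims=True), names

def rename(A, names, old, new):              # renaming: only the label changes
    return A, tuple(new if n == old else n for n in names)

Q = np.random.randn(4)                       # shape key[4]
K = np.random.randn(3, 4)                    # shape seq[3] x key[4]
scores, names = ndot(Q, ("key",), K, ("seq", "key"), "key")
probs, names = nsoftmax(scores, names, "seq")
print(names, probs.sum())                    # ('seq',) ~1.0
```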

6 Differentiation

Let \(f\) be a function from order-\(m\) tensors to order-\(n\) tensors and let \(Y = f(X)\). The partial derivatives of \(Y\) with respect to \(X\) form an order-\((m+n)\) tensor: \(m\) “input” axes for the directions in which \(X\) could change and \(n\) “output” axes for the change in \(Y\).

For example, if \(f\) maps from vectors to vectors, then \(\frac{\partial Y}{\partial X}\) is a matrix (the Jacobian). But in matrix notation, there are conflicting conventions about whether the first axis is the input axis (“denominator layout”) or the output axis (“numerator layout”). The derivative of a function from vectors to matrices, or from matrices to vectors, cannot be represented as a matrix at all, so one must resort to flattening the matrices into vectors.

With tensors, taking derivatives of higher-order tensors with respect to higher-order tensors is not difficult (Laue, Mitterreiter, and Giesen 2018). With named tensors, we get the additional advantage of using names to distinguish input and output axes.

6.1 Definition

Let \(f \colon \mathbb{R}^\mathcal{S} \rightarrow \mathbb{R}^\mathcal{T}\), where \(\mathcal{S}\) and \(\mathcal{T}\) are orthogonal, and let \(Y = f(X)\). Then the derivative of \(Y\) at \(X\) is the tensor with shape \(\mathcal{S} \times \mathcal{T}\) such that for all \(s \in \rec\mathcal{S}\) and \(t \in \rec\mathcal{T}\), \[\left[\frac{\partial Y}{\partial X} \right]_{s,t} = \frac{\partial Y_t}{\partial X_s}.\]

If \(X\) and \(Y\)’s shapes are not orthogonal, we take the derivative of \(Y_{\mathcal{T}\rightarrow\mathcal{T'}}\) instead. (It’s also possible to rename \(X\), but we think it’s easier to think about renaming \(Y\), so that’s what we’ll do.) Assume \(\mathcal{T} = \ensuremath{\mathsf{\vphantom{fg}ax}}_1 \times \cdots \times \ensuremath{\mathsf{\vphantom{fg}ax}}_r\). Then for each \(\ensuremath{\mathsf{\vphantom{fg}ax}}_i\), choose a new name \(\ensuremath{\mathsf{\vphantom{fg}ax}}_i'\) not in either \(\mathcal{S}\) or \(\mathcal{T}\), and let \(\mathcal{T'} = \ensuremath{\mathsf{\vphantom{fg}ax}}_1' \times \cdots \times \ensuremath{\mathsf{\vphantom{fg}ax}}_r'\). Then we seek the tensor of partial derivatives \[\left[\frac{\partial Y_{\mathcal{T}\rightarrow\mathcal{T'}}}{\partial X} \right]_{s,t'} = \frac{\partial Y_t}{\partial X_s},\] where \(t \in \rec\mathcal{T}\) is the record obtained from \(t'\) by renaming each \(\ensuremath{\mathsf{\vphantom{fg}ax}}_i'\) back to \(\ensuremath{\mathsf{\vphantom{fg}ax}}_i\).

6.2 Rules

To compute derivatives, we use the method of differentials (Magnus and Neudecker 1985). The differential of an expression \(U\), written \(\partial U\), is a tensor with the same shape as \(U\), computed using rules like the following: \[\begin{aligned} \partial f(U) &= f'(U) \mathbin{\underset{\substack{\mathcal{U}}}{\vphantom{fg}\odot}} \partial U && f \colon \mathbb{R}^\mathcal{U} \rightarrow \mathbb{R}^{\mathcal{V}} \\ \partial (U + V) &= \partial U + \partial V \\ \partial \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} U &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \partial U \\ \partial (U \odot V) &= \partial U \odot V + U \odot \partial V \\ \partial (U \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} V) &= \partial U \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} V + U \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} \partial V \\ \partial \left(\frac{U}{V}\right) &= \frac{\partial U \odot V - U \odot \partial V}{V^2} \\ \partial U_s &= \left[\partial U\right]_s \\ \partial U_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} &= \left[\partial U\right]_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'}\end{aligned}\] If we obtain an equation in the so-called canonical form \[\partial Y = A \mathbin{\underset{\substack{\mathcal{S}}}{\vphantom{fg}\odot}} \partial X + \text{const.}\] where \(\mathcal{S}\) is orthogonal to \(\mathcal{T}\) and “const” stands for terms not depending on \(\partial X\), then we have \[\frac{\partial Y}{\partial X} = A.\]

In order to get equations into canonical form, some tricks are useful. First, contractions can be easier to reason about if rewritten as sums of elementwise products: \[A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} B = \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} A \odot B.\] Second, renaming can be thought of as contraction with an identity matrix: \[\begin{aligned} I_{\ensuremath{\mathsf{\vphantom{fg}ax}}(i),\ensuremath{\mathsf{\vphantom{fg}ax}}'(j)} &= \begin{cases} 1 & i = j \\ 0 & i \neq j \end{cases} \\ A_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} I_{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot A.\end{aligned}\]

6.3 Example

Let’s find the differential of the softmax operator. \[\begin{aligned} Y &= \mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{softmax}}} X \\ \partial Y &= \partial \biggl(\frac{\exp X}{\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \exp X}\biggr) \\ &= \frac{\exp X \odot \partial X \odot \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \exp X - \exp X \odot \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} (\exp X \odot \partial X)}{\bigl(\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \exp X\bigr)^2} \\ &= Y \odot (\partial X - Y \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} \partial X).\end{aligned}\]

Next, use this to find the Jacobian, \(\frac{\partial Y}{\partial X}\). Since \(X\) and \(Y\) have the same shape, we rename \(Y\): \[\begin{aligned} \partial Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} &= [Y \odot (\partial X - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot \partial X)]_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} \\ &= Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot (\partial X_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot \partial X) \\ &= Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot \left(\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} I_{\ensuremath{\mathsf{\vphantom{fg}ax}}',\ensuremath{\mathsf{\vphantom{fg}ax}}} \odot \partial X - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot \partial X\right) \\ &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot (I_{\ensuremath{\mathsf{\vphantom{fg}ax}}',\ensuremath{\mathsf{\vphantom{fg}ax}}} - Y) \odot \partial X \\ \frac{\partial Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'}}{\partial X} &= Y_{\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot (I_{\ensuremath{\mathsf{\vphantom{fg}ax}}',\ensuremath{\mathsf{\vphantom{fg}ax}}} - Y).\end{aligned}\]
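As a sanity check (ours, not part of the derivation), the formula says that in ordinary matrix terms the Jacobian of softmax is \(\operatorname{diag}(y) - y y^\top\): entry \((i', i)\) is \(y_{i'}(\delta_{i'i} - y_i)\). A finite-difference comparison in numpy confirms this:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.random.randn(5)
y = softmax(x)
eps = 1e-6

# finite differences: J[i', i] = d y[i'] / d x[i]
J = np.array([(softmax(x + eps * e) - softmax(x - eps * e)) / (2 * eps)
              for e in np.eye(5)]).T
analytic = np.diag(y) - np.outer(y, y)       # Y_{ax->ax'} (I - Y) in matrix form
print(np.allclose(J, analytic, atol=1e-6))   # True
```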

To derive the rule for backpropagation, we assume a function \(f \colon \mathbb{R}^\ensuremath{\mathsf{\vphantom{fg}ax}}\rightarrow \mathbb{R}\) and differentiate \(f(Y)\). Since \(f\) is scalar-valued, there is no name overlap, so no renaming is needed. \[\begin{aligned} \partial f(Y) &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot \partial Y \\ &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y \odot (\partial X - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot \partial X) \\ &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y \odot \partial X - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y \odot \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot \partial X \\ &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y \odot \partial X - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} \left(\sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y\right) \odot Y \odot \partial X \\ &= \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} Y \odot (f'(Y) - \sum\limits_{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}} f'(Y) \odot Y) \odot \partial X \\ \frac{\partial f(Y)}{\partial X} &= Y \odot (f'(Y) - f'(Y) \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} Y).\end{aligned}\]

6.4 Broadcasting

Let \(f \colon \mathbb{R}^\mathcal{S} \rightarrow \mathbb{R}^\mathcal{T}\), and let \(f'\) be its derivative. If \(X \in \mathbb{R}^{\mathcal{S} \cup \mathcal{U}}\), where \(\mathcal{U}\) is orthogonal to both \(\mathcal{S}\) and \(\mathcal{T}\), recall that \(Y = f(X)\) is defined by: \[\begin{aligned} Y_r &= f(X_r) \qquad \text{for all $r \in \rec\mathcal{U}$.}\end{aligned}\] Finding the differential of \(Y\) is easy: \[\begin{aligned} \partial Y_r &= f'(X_r) \mathbin{\underset{\substack{\mathcal{S}}}{\vphantom{fg}\odot}} \partial X_r \\ \partial Y &= f'(X) \mathbin{\underset{\substack{\mathcal{S}}}{\vphantom{fg}\odot}} \partial X.\end{aligned}\] But although \(f'\) extends to \(X\) using the usual broadcasting rules, it’s not the case that \(\frac{\partial Y}{\partial X} = f'(X)\), which would have the wrong shape. The reason is that the contraction is only over \(\mathcal{S}\), not \(\mathcal{S}\cup\mathcal{U}\). To get this into the canonical form of §6.2: \[\begin{aligned} \partial Y_{\mathcal{U}\rightarrow\mathcal{U'}} &= \sum\limits_{\substack{\mathcal{S}}} [f'(X) \odot \partial X]_{\mathcal{U}\rightarrow\mathcal{U'}} \\ &= \sum\limits_{\substack{\mathcal{S}}} \sum\limits_{\substack{\mathcal{U}}} I_{\mathcal{U},\mathcal{U'}} \odot f'(X) \odot \partial X \\ \frac{\partial Y_{\mathcal{U}\rightarrow\mathcal{U'}}}{\partial X} &= I_{\mathcal{U},\mathcal{U'}} \odot f'(X).\end{aligned}\] In general, then, when we extend a function to new axes, we extend its derivative by multiplying by the identity matrix for those axes.
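As a quick numerical illustration (ours) of the simplest case, take \(\mathcal{S} = \mathcal{T} = \emptyset\) and \(\mathcal{U} = \ensuremath{\mathsf{\vphantom{fg}ax}}[n]\): an elementwise function extended over \(\ensuremath{\mathsf{\vphantom{fg}ax}}\) has Jacobian \(I_{\ensuremath{\mathsf{\vphantom{fg}ax}},\ensuremath{\mathsf{\vphantom{fg}ax}}'} \odot f'(X)\), that is, a diagonal matrix with \(f'\) on the diagonal:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.random.randn(4)
eps = 1e-6

# finite-difference Jacobian of the broadcast sigmoid: J[i, j] = d y[j] / d x[i]
J = np.array([(sigmoid(x + eps * e) - sigmoid(x - eps * e)) / (2 * eps)
              for e in np.eye(4)])
analytic = np.diag(sigmoid(x) * (1 - sigmoid(x)))   # identity times f'(X) on the diagonal
print(np.allclose(J, analytic, atol=1e-6))          # True
```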

7 Alternatives

A very frequently asked question is why we haven’t used index notation as used in physics, and the Einstein summation convention in particular. In this notation, axes are ordered, and every equation is written in terms of tensor components. If an index appears on both sides of an equation, then the equation must hold for each value of the index, and if an index appears twice on one side and not on the other, there is an implicit summation over that index. \[\begin{aligned} \text{Attention} \colon \mathbb{R}^{n' \times d_k} \times \mathbb{R}^{n \times d_k} \times \mathbb{R}^{n \times d_v} &\rightarrow \mathbb{R}^{n' \times d_v} \\ \left[\text{Attention}(Q, K, V)\right]_{i'k} &= \softmax_i \left( \frac{Q_{i'j} K_{ij}}{\sqrt{d_k}} \right) V_{ik}.\end{aligned}\] Because \(i'\) and \(k\) appear on both sides, the equation must hold over all values of these indices. But because \(i\) and \(j\) each occur twice on the right-hand side only, they are both summed over. We’d have to define exactly what the \(i\) under softmax means (\(i\) is bound inside the softmax and free outside it), and since softmax doesn’t distribute over addition, we’d need to clarify that the summation over \(j\) occurs inside the softmax and the summation over \(i\) outside it.

Other than that, this is concise and unambiguous. But it doesn’t really solve the main problem we set out to solve, which is that ordered axes force the author and reader to remember the purpose of each axis. The indices do act as symbolic names for axes (indeed, in abstract index notation, they really are symbols, not variables), but they are temporary names; they could be totally different in the next equation. It would be up to the author to choose to use consistent names, and to do so correctly.

A second issue is that, because it relies on repeated indices, index notation can be somewhat more verbose than our notation, particularly for reductions and contractions: \[\begin{aligned} C &= \max_i A_i & C &=\mathop{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\mathrm{max}}} A \\ C &= A_i B_i & C &= A \mathbin{\underset{\substack{\ensuremath{\mathsf{\vphantom{fg}ax}}}}{\vphantom{fg}\odot}} B.\end{aligned}\]

Finally, index notation requires us to write out all indices explicitly. So if we wanted to extend attention to multiple heads and minibatches, we would write: \[\begin{gathered} \text{Attention} \colon \mathbb{R}^{B \times H \times n' \times d_k} \times \mathbb{R}^{B \times H \times n \times d_k} \times \mathbb{R}^{B \times H \times n \times d_v} \rightarrow \mathbb{R}^{B \times H \times n' \times d_v} \\ \left[\text{Attention}(Q, K, V)\right]_{bhi'k} = \softmax_i \left( \frac{Q_{bhi'j} K_{bhij}}{\sqrt{d_k}} \right) V_{bhik}.\end{gathered}\] We could adopt a convention that extends a function on tensors to tensors that have extra axes to the left, but such conventions tend to lead to messy reordering and squeezing/unsqueezing of axes. Named axes make this unnecessary.
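For comparison with code, the index-notation version above maps almost symbol-for-symbol onto `numpy.einsum` subscripts. A sketch (ours), with `q` standing in for \(i'\) since einsum letters cannot be primed:

```python
import numpy as np

B, H, n, n_, dk, dv = 2, 3, 5, 4, 8, 6
Q = np.random.randn(B, H, n_, dk)
K = np.random.randn(B, H, n, dk)
V = np.random.randn(B, H, n, dv)

scores = np.einsum("bhqj,bhij->bhqi", Q, K) / np.sqrt(dk)   # sum over j
scores -= scores.max(axis=-1, keepdims=True)                # stabilize the exponentials
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)              # softmax over i
out = np.einsum("bhqi,bhik->bhqk", weights, V)              # sum over i
print(out.shape)                                            # (2, 3, 4, 6)
```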

Acknowledgements

Thanks to Ekin Akyürek, Justin Bayer, Colin McDonald, Adam Poliak, Matt Post, Chung-chieh Shan, Nishant Sinha, and Yee Whye Teh for their input to this document (or the ideas in it).

References

Chen, Tongfei. 2017. “Typesafe Abstractions for Tensor Operations.” In Proceedings of the 8th ACM SIGPLAN International Symposium on Scala, 45–50. SCALA 2017. https://doi.org/10.1145/3136000.3136001.

Harris, Charles R., K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, et al. 2020. “Array Programming with NumPy.” Nature 585 (7825): 357–62. https://doi.org/10.1038/s41586-020-2649-2.

Hoyer, Stephan, and Joe Hamman. 2017. “xarray: N-D Labeled Arrays and Datasets in Python.” Journal of Open Research Software 5 (1): 10. https://doi.org/10.5334/jors.148.

Laue, Soeren, Matthias Mitterreiter, and Joachim Giesen. 2018. “Computing Higher Order Derivatives of Matrix and Tensor Expressions.” In Advances in Neural Information Processing Systems, edited by S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, 31:2750–9. Curran Associates, Inc. https://proceedings.neurips.cc/paper/2018/file/0a1bf96b7165e962e90cb14648c9462d-Paper.pdf.

Maclaurin, Dougal, Alexey Radul, Matthew J. Johnson, and Dimitrios Vytiniotis. 2019. “Dex: Array Programming with Typed Indices.” In NeurIPS Workshop on Program Transformations for ML. https://openreview.net/forum?id=rJxd7vsWPS.

Magnus, Jan R., and H. Neudecker. 1985. “Matrix Differential Calculus with Applications to Simple, Hadamard, and Kronecker Products.” Journal of Mathematical Psychology 29 (4): 474–92. https://doi.org/10.1016/0022-2496(85)90006-9.

Paszke, Adam, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, et al. 2019. “PyTorch: An Imperative Style, High-Performance Deep Learning Library.” In Advances in Neural Information Processing Systems 32, edited by H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. Fox, and R. Garnett, 8024–35. Curran Associates, Inc. http://papers.neurips.cc/paper/9015-pytorch-an-imperative-style-high-performance-deep-learning-library.pdf.

Rush, Alexander. 2019. “Named Tensors.” https://github.com/harvardnlp/NamedTensor.

Sinha, Nishant. 2018. “Tensor Shape (Annotation) Library.” https://github.com/ofnote/tsalib.

Torch Contributors. 2019. “Named Tensors.” https://pytorch.org/docs/stable/named_tensor.html.

Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” In Advances in Neural Information Processing Systems, edited by I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, 30:5998–6008. Curran Associates, Inc. https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.