# Notes on Weight Initialization for Deep Neural Networks

# TL;DR

Neural networks involve a long sequence of multiplications, usually between a matrix and a vector, say $a \cdot x$. The result of this sequence of multiplications will either have a huge magnitude or be reduced to 0. We can divide by a number (`scaling_factor`) to scale down the magnitude to the right level. Proper initialization strategies help us find a good `scaling_factor`.

# Outline

The problem of weight initialization is motivated by a simulation. The real cause of the ill-behaved multiplication output is identified, and scaling the weight matrix using the Xavier initialization is presented as a solution. The rest of the write-up then provides experiments and proofs to explain why such an initialization works. The blog originated from a discussion during part 2 of the Fast AI course (Spring 2019 session), and parts of the simulations are taken from the course notebooks.

# Motivation

Training (and inference) of a neural network involves a bunch of operations, and one of the most common of these operations is multiplication. Typically, the multiplication happens between matrices. In the case of *deep* neural networks, we end up with a long sequence of such multiplications.

```
input = x
output = input
for layer in network_layers:
    output = activation(output @ layer.weights + layer.bias)
```

## Investigating the Sequence of Multiplications

To begin our investigation, let’s take a random input vector `x` and a random matrix `a`. Note that the numbers are sampled from a normal distribution with mean 0 and variance 1 or, as it is popularly known, $\mathcal{N}(0, 1)$.

We’ll multiply the vector `x` by the matrix `a` 100 times (as if the network had 100 layers), and see what comes out on the other side. **Note that we don’t use any activation function, for the sake of simplicity.**

```
import torch

x = torch.randn(512)
a = torch.randn(512, 512)
for i in range(100):
    x = a @ x
x.mean(), x.std()
```

output:

```
(tensor(nan), tensor(nan))
```

`x` has a huge magnitude! It seems the multiplication snowballed: the magnitude of `x` increased with each step, finally pushing the values of `a @ x` out of the range of floating-point numbers, hence the `nan`s (note that we keep feeding `x` back into `a @ x`).
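To see how quickly things blow up, we can log the standard deviation of `x` at each step and note where it stops being finite (a small sketch; the exact step at which float32 overflows may vary with the random seed):

```python
import torch

torch.manual_seed(0)
x = torch.randn(512)
a = torch.randn(512, 512)
overflow_at = None
for i in range(100):
    x = a @ x
    if overflow_at is None and not torch.isfinite(x.std()):
        overflow_at = i  # the step where the std first becomes inf/nan
print(overflow_at)
```

Since each multiplication scales the magnitude by roughly $\sqrt{512} \approx 22.6$, the values exceed the float32 range within a few dozen steps.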

## A first intuitive solution

Intuitively, since the product of `a` and `x` is becoming large, we may start by reducing the magnitude of the matrix `a`. The hope is that with a smaller `a`, the product `a @ x` won’t shoot up in magnitude. Thus, we divide our matrix (i.e., each element of the matrix `a`) by a `scaling_factor` of 100, and repeat the process.

```
scaling_factor = 100
x = torch.randn(512)
a = torch.randn(512, 512) / scaling_factor
for i in range(100):
    x = a @ x
x.mean(), x.std()
```

```
(tensor(0.), tensor(0.))
```

So we did solve the problem of magnitude *explosion*, only to create another: the output now *vanishes* to 0.
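We can instrument this loop the same way to find the step where the measured standard deviation hits exactly 0 (a small sketch; the exact step depends on the random seed):

```python
import torch

torch.manual_seed(0)
x = torch.randn(512)
a = torch.randn(512, 512) / 100
vanish_at = None
for i in range(100):
    x = a @ x
    if vanish_at is None and x.std() == 0:
        vanish_at = i  # the step where the std underflows to exactly 0
print(vanish_at)
```

Each step now shrinks the magnitude by roughly $22.6 / 100 \approx 0.23$, so after enough steps even float32's tiniest representable values are exhausted and everything collapses to 0.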

# Xavier Initialization

We saw that using a `scaling_factor` of 100 didn’t quite work: it reduced the product to 0. We had started with the problem of the magnitude exploding to infinity, and the scaling brought it down to 0. Surely, the right solution lies somewhere in the middle. That’s exactly what the Xavier initialization does: it helps us find a scaling factor that gets it right.

The Xavier initialization suggests using a scaling factor of $\sqrt{n_{in}}$, where `n_in` is the number of inputs to the matrix (or the dimension that’s common with the vector the matrix is being multiplied with).
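As a small sketch, here is what that recipe looks like as a reusable helper (the function name `xavier_scaled` is ours, not a standard API):

```python
import math
import torch

def xavier_scaled(n_in: int, n_out: int) -> torch.Tensor:
    """Sample a weight matrix from N(0, 1) and scale it down by sqrt(n_in)."""
    return torch.randn(n_out, n_in) / math.sqrt(n_in)

w = xavier_scaled(512, 512)
print(w.std())  # close to 1 / sqrt(512), about 0.044
```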

In our case, the number of inputs to the matrix is 512. Thus, the scaling factor should be $\sqrt{512}$. In other words, if we divide our matrix `a` by $\sqrt{512}$, we should see neither vanishing nor exploding magnitudes. Let’s see if the Xavier init helps:

```
import math

scaling_factor = math.sqrt(512)
x = torch.randn(512)
a = torch.randn(512, 512) / scaling_factor
for i in range(100):
    x = a @ x
x.mean(), x.std()
```

```
(tensor(0.0429), tensor(0.9888))
```

The magnitude of the product hasn’t exploded or vanished. In fact, the output has a nice mean (close to 0) and standard deviation (close to 1). Recall that the input was actually sampled from such a distribution. In a way, our solution managed to *retain* the distribution of the inputs. That’s a really nice thing, because now we can perform a large number of such multiplications.

Putting things in context, this translates to being allowed to train **really** deep neural networks. Note that the Xavier initialization is sufficient to solve the problem in this case because we did not use any activation function. If we had used, say, a ReLU, the more recent Kaiming initialization would have been more effective. So why did this work? What is so special about $\sqrt{n_{in}}$ as a scaling factor?
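To make the ReLU point concrete, here is a quick comparison (a sketch under our running setup, with fresh random weights per layer; with a ReLU in the loop, Xavier scaling loses roughly half the variance per layer, while Kaiming's extra factor of $\sqrt{2}$ compensates for it):

```python
import math
import torch

torch.manual_seed(0)

def deep_relu_std(scale: float, depth: int = 100, n: int = 512) -> float:
    """Push x through `depth` random ReLU layers whose weights have std `scale`."""
    x = torch.randn(n)
    for _ in range(depth):
        a = torch.randn(n, n) * scale  # fresh random weights each layer
        x = torch.relu(a @ x)
    return x.std().item()

xavier_std = deep_relu_std(1 / math.sqrt(512))   # shrinks towards 0
kaiming_std = deep_relu_std(math.sqrt(2 / 512))  # stays in a sane range
print(xavier_std, kaiming_std)
```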

# Why $\sqrt{n_{in}}$? | Intuition

Before we start, let us look closely at our simulation, particularly the following line:

```
x = a @ x
```

Note that we are not changing `a` at all. Thus, the only element that can cause trouble is `x`, since it’s being updated. It seems that at some point in the multiplication sequence, `x` starts taking on large values, and the subsequent multiplications keep making things worse. To examine this phenomenon closely, let us denote the product of `a` and `x` by `y`.

As in our running example, if `a` is a matrix of size 512 x 512, and `x` is a vector of size 512, then the output `y` is a vector of size 512.

To be more explicit, *one element of `y`* is calculated as follows:

$$y_i = \sum_{k=1}^{512} a_{i,k} \, x_k$$
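Concretely, each element of `y` is just the dot product of one row of `a` with `x`, which we can check directly:

```python
import torch

torch.manual_seed(0)
a = torch.randn(512, 512)
x = torch.randn(512)
y = a @ x

# y[0] equals the sum of elementwise products of row 0 of `a` with `x`
manual_y0 = (a[0] * x).sum()
print(torch.allclose(y[0], manual_y0, atol=1e-3))  # True
```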

As we saw above, *something* goes wrong with the `y` values. That something is the following:

To compute one element of `y`, we add 512 products of one element of `a` with one element of `x`. What are the mean and the variance of such a sum? As we show later, as long as the elements of `a` and the elements of `x` are independent (which they are in this case; one doesn’t affect the other), the mean is 0 and the variance is 512. That is, each element of `y` now takes values *as if* it were picked from $\mathcal{N}(0, 512)$! This can also be seen experimentally, as in the following code snippet. To avoid one-off flukes, we repeat the experiment for 10000 iterations.

```
mean = 0.0
n_iter = 10000
n_dim = 512
ys = []
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim)  # just like one row of a
    y = a @ x
    mean += y.item()
    ys.append(y.item())
mean / n_iter, torch.tensor(ys).var(), torch.tensor(ys).std()
```

```
(-0.13198307995796205, tensor(513.4638), tensor(22.6597))
```

In other words, each element of `y` is now picked from an erratic distribution, and that’s happening because we are adding 512 products of elements, each picked from $\mathcal{N}(0, 1)$. We keep feeding these `y` elements back into the loop as the input, and thus things soon go haywire.

Now, if we scale the weights of the matrix and divide them by $\sqrt{512}$, we will be picking the elements of `a` from a normal distribution with mean 0 and variance $1/512$, or $\mathcal{N}(0, 1/512)$ (see the next section for a proof).

This scaling will, in turn, give us a distribution of `y` in which each element has mean 0 and std 1, thus allowing us to repeat the product as many times as we want. This is **NOT** different from the intuitive solution we discussed earlier. We were right in guessing that scaling one of the participants in the product might help, and it did. Xavier init helped us find the *exact magnitude* of the scaling factor: $\sqrt{512}$, instead of what we had initially used: 100.

```
mean = 0.0
n_iter = 10000
n_dim = 512
ys = []
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim) / math.sqrt(n_dim)  # just like one row of a
    y = a @ x
    mean += y.item()
    ys.append(y.item())
mean / n_iter, torch.tensor(ys).var(), torch.tensor(ys).std()
```

```
(-0.00671036799326539, tensor(1.0186), tensor(1.0092))
```

It works, and each element of `y` (and `y` as a whole) now has mean 0 and variance/std 1. We can thus keep multiplying the output `y` by `a` repeatedly, without worrying about things changing a lot.

# Why $\sqrt{n_{in}}$? | Proofs

We are given that $x$ and $a$ are sampled from a normal distribution with mean = 0 and variance = 1 (or $\mathcal{N}(0, 1)$). That is, to create $x$, we pick 512 random numbers from $\mathcal{N}(0, 1)$. Similarly, to create $a$, we pick $512 \times 512$ random numbers from $\mathcal{N}(0, 1)$. Then, the $i^{th}$ element of $y$ is calculated by multiplying the 512 elements of the $i^{th}$ row of $a$ with the 512 elements of $x$, and summing the products.

## 1. Proof that $Y \sim \mathcal{N}(0, 512)$

Let $X$, $A$ and $Y$ be the random variables from which the elements of $x$, $a$ and $y$ are sampled, respectively. We know that one element of $y$ is created by multiplying 512 elements from $a$ and $x$ with each other. That is, we sample 512 elements from $A$, 512 elements from $X$, multiply them element by element, and add the products. Thus, so far we have:

$$Y = \sum_{k=1}^{512} A_k X_k$$

and

$$A \sim \mathcal{N}(0, 1), \quad X \sim \mathcal{N}(0, 1)$$

Let’s start by calculating the mean of $Y$.

### 1.1 Expectation (Mean) of Y

$$E[Y] = E\left[\sum_{k=1}^{512} A_k X_k\right] = \sum_{k=1}^{512} E[A_k X_k] = \sum_{k=1}^{512} E[A_k] \, E[X_k] = 0$$

(using linearity of expectation, and $E[A_k X_k] = E[A_k] E[X_k]$ since $A_k$ and $X_k$ are independent)

### 1.2 Variance of Y

We know that $Y$ is created by adding 512 terms, each a product of an element of $A$ and an element of $X$. Thus, let’s first calculate the variance of one such product $A_k X_k$. That is, what is the variance if we pick one element randomly from $A$ and one from $X$, and then multiply them?

$$\mathrm{Var}(A_k X_k) = E[A_k^2 X_k^2] - \left(E[A_k X_k]\right)^2 = E[A_k^2] \, E[X_k^2] - 0 = 1 \cdot 1 = 1$$

(using independence, and $E[A_k^2] = \mathrm{Var}(A_k) + E[A_k]^2 = 1$)

We know that $Y$ is formed by summing 512 such independent terms, so:

$$\mathrm{Var}(Y) = \sum_{k=1}^{512} \mathrm{Var}(A_k X_k) = 512$$

In other words, $Y \sim \mathcal{N}(0, 512)$, which is terrible, since $Y$ now varies a lot! The experiment is reproduced below for ready reference. *Each of the `y`s has a large variance!* As they are fed to the subsequent layers, the product can only get worse, as we’ve seen.

```
mean = 0.0
n_iter = 10000
n_dim = 512
ys = []
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim)  # just like one row of a
    y = a @ x
    mean += y.item()
    ys.append(y.item())
mean / n_iter, torch.tensor(ys).var()
```

```
(-0.10872242888212204, tensor(514.2963))
```

## 2. Proof that $Y \sim \mathcal{N}(0, 1)$ when $A \sim \mathcal{N}(0, 1/512)$

If we scale the weights of the matrix and divide them by $\sqrt{512}$, we will be picking the elements of $a$ from a distribution with mean 0 and variance $1/512$, i.e. $\mathcal{N}(0, 1/512)$. This will, in turn, give us a distribution of $y$ in which each element has mean 0 and std 1, thus allowing us to repeat the product as many times as we want (or, in other words, make our network deeper).

We will now prove that dividing `a` by $\sqrt{512}$ leads to $Y$ having this better-behaved distribution. Now we have:

$$A \sim \mathcal{N}(0, 1/512), \quad X \sim \mathcal{N}(0, 1)$$

### 2.1 Expectation (Mean) of Y

As before, independence gives:

$$E[Y] = \sum_{k=1}^{512} E[A_k] \, E[X_k] = 0$$

### 2.2 Variance of Y

As before, let’s first calculate the variance of one product $A_k X_k$. That is, what is the variance if we pick one element randomly from $A$ and one from $X$, and then multiply them?

$$\mathrm{Var}(A_k X_k) = E[A_k^2] \, E[X_k^2] = \frac{1}{512} \cdot 1 = \frac{1}{512}$$

Now, $Y$ is the sum of 512 such independent terms. Thus,

$$\mathrm{Var}(Y) = 512 \cdot \frac{1}{512} = 1$$

In other words, $Y \sim \mathcal{N}(0, 1)$, which is what we wanted! Let’s do an experiment to make sure this holds:

```
mean = 0.0
n_iter = 10000
n_dim = 512
ys = []
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim) / math.sqrt(n_dim)  # just like one row of a
    y = a @ x
    mean += y.item()
    ys.append(y.item())
mean / n_iter, torch.tensor(ys).var(), torch.tensor(ys).std()
```

```
(-0.008042885749042035, tensor(0.9856), tensor(0.9928))
```

Works! Each element of `y` will thus be sampled from a well-behaved distribution. Here is the original simulation with the fix, for quick reference:

```
scaling_factor = math.sqrt(512)
x = torch.randn(512)
a = torch.randn(512, 512) / scaling_factor
for i in range(100):
    x = a @ x
x.mean(), x.std()
```

```
(tensor(0.0121), tensor(1.1693))
```

# Summary

The operation of sequenced matrix multiplication lies at the core of neural networks. We saw that without proper initialization, inputs sampled from a well-behaved distribution will either vanish (over-scaling) or explode (under-scaling). Dividing the weight matrix by $\sqrt{n_{in}}$ ($n_{in} = 512$ in the running example), known as Xavier initialization, ensures that the output of each multiplication is well-behaved, so that the sequence of multiplications yields a reasonable output at each step. While Xavier initialization puts us on the right track, Kaiming initialization provides a better scaling factor when a ReLU is used as the activation (non-linearity) between multiplications in the network.
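In practice, you rarely hand-roll this: PyTorch ships both schemes in `torch.nn.init`. A minimal sketch (note that PyTorch's `xavier_normal_` uses Glorot's full formula with variance $2/(n_{in} + n_{out})$, a slight refinement of the $1/n_{in}$ variance derived above):

```python
import torch
import torch.nn as nn

layer = nn.Linear(512, 512)

# Xavier/Glorot: weight ~ N(0, 2 / (fan_in + fan_out))
nn.init.xavier_normal_(layer.weight)

# Kaiming/He, for ReLU networks: weight ~ N(0, 2 / fan_in)
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
print(layer.weight.std())  # close to sqrt(2/512) = 0.0625
```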