Modern neural networks often consist of multiple layers and a large number of parameters, sometimes reaching millions. The central challenge in training such networks is determining how to efficiently update these parameters so that the model's predictions improve over time. This learning process is achieved through an algorithm known as backpropagation, which systematically computes gradients of the loss function with respect to each parameter in the network.
Build Intuition of Chain Rule of Calculus
Consider we have a composite function $g(f(x))$.
Chain Rule of Calculus (scalar case):
Consider,
$z = f(x) = x^2$
$y = g(z) = e^z = e^{x^2}$
We can calculate $\frac{dy}{dz}$ and $\frac{dz}{dx}$ easily.
$\frac{dy}{dz} = e^z, \quad \frac{dz}{dx} = 2x$
But how would you calculate $\frac{dy}{dx}$ here? We have to apply the chain rule here.
$\frac{dy}{dx} = \frac{dy}{dz} \cdot \frac{dz}{dx} = e^z \cdot 2x = e^{x^2} \cdot 2x$
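We can verify this derivative numerically. The sketch below (the function names are illustrative, not from the article) compares the chain-rule result against a central finite difference of the composite $g(f(x))$:

```python
# Numerical sanity check of the scalar chain rule for y = e^{x^2}.
import math

def f(x):          # z = f(x) = x^2
    return x ** 2

def g(z):          # y = g(z) = e^z
    return math.exp(z)

def dy_dx_chain(x):
    # dy/dz = e^z evaluated at z = x^2, times dz/dx = 2x
    return math.exp(f(x)) * 2 * x

def dy_dx_numeric(x, h=1e-6):
    # central finite difference applied directly to g(f(x))
    return (g(f(x + h)) - g(f(x - h))) / (2 * h)

x = 1.5
print(dy_dx_chain(x))    # analytic: e^{x^2} * 2x
print(dy_dx_numeric(x))  # should agree to several decimal places
```

The two printed values agree, confirming that composing the two easy derivatives gives the derivative of the composite function.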
Chain Rule of Calculus (vector case):
Here each $z_i$ is a function of $\mathbf{x} = [x_1, x_2, \ldots, x_m]^T$, and $y$ is a function of $\mathbf{z} = [z_1, z_2, \ldots, z_n]^T$. So how do we calculate $\frac{\partial y}{\partial x_i}$ for any $x_i$?
To understand it simply: in the scalar case earlier we had only one path from $x$ to $y$ ($x \to z \to y$). In the vector case here we have $n$ paths from $x_i$ to $y$ (three paths, considering $n = 3$). So we have to apply the chain rule along all the paths and take the summation:
$\frac{\partial y}{\partial x_i} = \sum_{j=1}^{n} \frac{\partial y}{\partial z_j} \cdot \frac{\partial z_j}{\partial x_i}$
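The path-summation rule can also be checked numerically. This is a minimal sketch with made-up intermediate functions $z_j = (x_1 + j\,x_2)^2$ for $j = 1, 2, 3$ and $y = z_1 + z_2 + z_3$, chosen only so every $z_j$ depends on every $x_i$:

```python
# Numerical check of the vector chain rule with n = 3 paths from x_i to y.
import numpy as np

def z(x):
    # three intermediate variables, each a function of the whole vector x
    return np.array([(x[0] + j * x[1]) ** 2 for j in (1, 2, 3)])

def y(x):
    return z(x).sum()

def dy_dx_chain(x, i):
    # sum over all paths j: (dy/dz_j) * (dz_j/dx_i)
    dy_dz = np.ones(3)                        # y = z1 + z2 + z3
    dz_dxi = np.array([2 * (x[0] + j * x[1]) * (1 if i == 0 else j)
                       for j in (1, 2, 3)])
    return float(np.dot(dy_dz, dz_dxi))

def dy_dx_numeric(x, i, h=1e-6):
    e = np.zeros_like(x); e[i] = h
    return (y(x + e) - y(x - e)) / (2 * h)

x = np.array([0.5, -1.0])
for i in range(2):
    print(dy_dx_chain(x, i), dy_dx_numeric(x, i))  # pairs should match closely
```

Summing the per-path products reproduces the finite-difference gradient for both components of $\mathbf{x}$.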
We want to train the model; that means we came up with the neural network structure, but we still need to find the optimized weight parameters. Weight parameters are named $w_{ij}^{(l)}$, meaning the weight from node $i$ in the previous layer to node $j$ in the current layer $l$ (counting from the first hidden layer).
How do we find the proper optimized weights? We run the Gradient Descent algorithm, and the weight parameters are updated in each iteration:
$w_{t+1} = w_t - \alpha \frac{\partial E}{\partial w_t}$
$w_t$ contains all the weight parameters $w_{ij}^{(l)}$.
So at each iteration, our target is to find $\frac{\partial E}{\partial w_{ij}^{(l)}}$ for all the weight parameters.
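The update rule can be seen at work on a single parameter. This sketch uses a hypothetical one-dimensional loss $E(w) = (w - 3)^2$ (not from the article), whose minimum is at $w = 3$:

```python
# One-parameter illustration of the update w_{t+1} = w_t - alpha * dE/dw_t.
# Toy loss E(w) = (w - 3)^2, so dE/dw = 2(w - 3) and the minimum is at w = 3.

def dE_dw(w):
    return 2 * (w - 3)

w, alpha = 0.0, 0.1
for t in range(100):
    w = w - alpha * dE_dw(w)

print(round(w, 4))  # converges toward 3.0
```

Each step moves $w$ against the gradient, and after enough iterations it settles at the minimizer. In a real network the same update is applied simultaneously to every $w_{ij}^{(l)}$.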
What is E here?
$E$ is the error or loss computed by comparing the actual output ($y$) with the predicted output ($\hat{y}$).
Loss/Error Functions
There are various types of loss/error functions depending on the task we want to perform. Here are a few examples:
Mean Square Error (MSE)
Used mostly in linear regression problems.
$E = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$
Root Mean Square Error (RMSE)
Just the square root of the Mean Square Error.
$E = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2}$
Mean Absolute Error (MAE)
The mean of the absolute differences between actual and predicted outputs.
$E = \frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|$
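The three regression losses above are one-liners in NumPy. A minimal sketch with illustrative toy arrays:

```python
# MSE, RMSE, and MAE written out directly from their formulas.
import numpy as np

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])

mse  = np.mean((y_true - y_pred) ** 2)   # 0.375
rmse = np.sqrt(mse)
mae  = np.mean(np.abs(y_true - y_pred))  # 0.5

print(mse, rmse, mae)
```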
Binary Cross-Entropy
Used in binary classification.
$E = -y\log(\hat{y}) - (1 - y)\log(1 - \hat{y})$
Categorical Cross-Entropy
Used in multi-class classification when outputs are represented in one-hot encoded format and there are $K$ classes.
$E = -\sum_{k=1}^{K} y_k \log(\hat{y}_k)$
Sparse Categorical Cross-Entropy
Used in multi-class classification when outputs are represented with class indices ($1, 2, \ldots, K$); $y$ is the correct class index and $\hat{y}_y$ is the predicted probability of the correct class $y$.
$E = -\log(\hat{y}_y)$
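The cross-entropy variants can be written out the same way. This is a minimal sketch (a small epsilon is added inside the logs only to guard against $\log 0$; the formulas themselves are as above). Note that categorical and sparse categorical cross-entropy give the same value for one-hot targets, since only the true class's term survives the sum:

```python
# The three cross-entropy losses written out directly from their formulas.
import numpy as np

eps = 1e-12  # numerical guard against log(0)

def binary_cross_entropy(y, y_hat):
    return -y * np.log(y_hat + eps) - (1 - y) * np.log(1 - y_hat + eps)

def categorical_cross_entropy(y_onehot, y_hat):
    return -np.sum(y_onehot * np.log(y_hat + eps))

def sparse_categorical_cross_entropy(class_idx, y_hat):
    return -np.log(y_hat[class_idx] + eps)

probs  = np.array([0.1, 0.7, 0.2])   # predicted class probabilities
onehot = np.array([0, 1, 0])         # true class is index 1

print(binary_cross_entropy(1, 0.9))               # -log(0.9) ~ 0.105
print(categorical_cross_entropy(onehot, probs))   # -log(0.7) ~ 0.357
print(sparse_categorical_cross_entropy(1, probs)) # same value as above
```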
Backward Propagation of Errors
Now we will see how we calculate the gradient of the error (or loss) with respect to all the weights. We start from the output layer and move back towards the input layer.
We are considering binary cross entropy loss for our discussion here.
Here $z_1^{(2)}$ is the linear weighted combination of the previous layer's neurons. Also, biases (like $w_{01}^{(2)}$) are not shown in the diagram, so just assume they are there :)
Here we are considering a different network with two hidden layers, and we will see how the error gradients flow backward from the output layer to the input layer.
Here $z_1^{(3)}$ is a combination of $a_1^{(2)}$ and $a_2^{(2)}$. So $\frac{\partial E}{\partial z_1^{(3)}}$ will flow along both paths, as shown in the image.
$\frac{\partial E}{\partial z_1^{(2)}}$ will flow backward along three paths, corresponding to $a_1^{(1)}$, $a_2^{(1)}$ and $a_3^{(1)}$. Similarly, $\frac{\partial E}{\partial z_2^{(2)}}$ will also flow backward along three paths corresponding to $a_1^{(1)}$, $a_2^{(1)}$ and $a_3^{(1)}$.
How do you calculate $\frac{\partial E}{\partial a_1^{(1)}}$? $a_1^{(1)}$ influences both $z_1^{(2)}$ and $z_2^{(2)}$, so the errors $\frac{\partial E}{\partial z_1^{(2)}}$ and $\frac{\partial E}{\partial z_2^{(2)}}$ are both backpropagating towards $a_1^{(1)}$.
So we have to apply the vector case of the chain rule of calculus:
$\frac{\partial E}{\partial a_1^{(1)}} = \frac{\partial E}{\partial z_1^{(2)}} \cdot \frac{\partial z_1^{(2)}}{\partial a_1^{(1)}} + \frac{\partial E}{\partial z_2^{(2)}} \cdot \frac{\partial z_2^{(2)}}{\partial a_1^{(1)}}$
$\frac{\partial E}{\partial z_1^{(2)}}$ and $\frac{\partial E}{\partial z_2^{(2)}}$ are already backpropagating towards $a_1^{(1)}$; we just need to calculate $\frac{\partial z_1^{(2)}}{\partial a_1^{(1)}}$ and $\frac{\partial z_2^{(2)}}{\partial a_1^{(1)}}$.
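The two-path sum can be demonstrated numerically. This is a minimal sketch with made-up weights and a toy loss $E = z_1^2 + z_2^2$ standing in for the rest of the network (the article uses binary cross-entropy; any differentiable $E$ shows the same path structure). Since $z_j = \sum_i w_{ij} a_i$, each $\frac{\partial z_j}{\partial a_1}$ is simply $w_{1j}$:

```python
# Numeric illustration of the two-path rule for dE/da_1:
# dE/da1 = (dE/dz1)*(dz1/da1) + (dE/dz2)*(dz2/da1), with z_j = sum_i w_ij * a_i.
import numpy as np

a = np.array([0.2, -0.4, 0.6])   # activations a1, a2, a3 of the previous layer
W = np.array([[0.1, 0.5],        # W[i, j] = w_ij: weight from a_i to z_j
              [0.3, -0.2],
              [0.7, 0.4]])

z = a @ W                        # z1, z2 of the next layer
dE_dz = 2 * z                    # gradient of the toy loss E = z1^2 + z2^2

# dz_j/da_1 is simply w_1j, so the chain-rule sum over both paths is:
dE_da1 = dE_dz[0] * W[0, 0] + dE_dz[1] * W[0, 1]
print(dE_da1)

# finite-difference check on a_1
h = 1e-6
def E(a_vec):
    zz = a_vec @ W
    return float(zz @ zz)
a_plus = a.copy();  a_plus[0] += h
a_minus = a.copy(); a_minus[0] -= h
print((E(a_plus) - E(a_minus)) / (2 * h))  # should match dE_da1
```

The summed path contributions match the finite-difference gradient, which is exactly what backpropagation computes layer by layer for every activation and weight.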