In [13]:
import numpy as np
from collections import defaultdict

Build your own Pytorch - 2: Backpropagation¶

In this tutorial, we are going to learn the core of any automatic differentiation software: the backpropagation algorithm. The backpropagation algorithm allows us to compute gradients iteratively by going backwards from a leaf node $x_n$ to its ancestors $x_1,\dots,x_{n-1}$.

The goal is as follows: let $x_1,\dots,x_n$ be the values of nodes $i=1,\dots,n$ in a computation graph. Let's assume that $x_n$ is a leaf node, i.e. $x_1,\dots,x_{n-1}$ are ancestors of $x_n$. Then, we want to compute: $$\frac{d}{dx_i}x_n$$ for all $i=1,\dots,n-1$. A simple derivation that you can find in textbooks is as follows: if $x_{i+1}$ is only a function of $x_{i}$, then using the chain rule, we have: $$\frac{d}{dx_i}x_n = \frac{d}{dx_{i+1}}x_n \frac{d}{dx_i}x_{i+1}$$ So basically, by using the chain rule we can compute $\frac{d}{dx_i}x_{n}$ iteratively by going backwards over $i=n,n-1,n-2,\dots,1$. However, in reality, $x_{i+1}$ can be a complicated function of several nodes (e.g. in a recurrent neural network, nodes are re-used multiple times). In this tutorial, we will see how backpropagation is implemented in such complicated (but practically very important) scenarios.
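
As a concrete instance of this simple chain case: if $x_2=\sin(x_1)$ and $x_3=x_2^2$, then $$\frac{d}{dx_1}x_3 = \frac{d}{dx_2}x_3\,\frac{d}{dx_1}x_2 = 2x_2\cos(x_1) = 2\sin(x_1)\cos(x_1).$$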

Note: we are using the code developed in the previous tutorial that I have put into a package called compgraph.

In [14]:
import compgraph as cg
from compgraph.nodes import *

Optional: mathematical preparation. As we work with tensors of arbitrary shape (not just vectors and matrices), we quickly define common concepts such as Jacobians and dot products for tensors. Let $x,y$ be two tensors with $\text{shape}(x)=(n_1,\dots,n_r)$ and $\text{shape}(y)=(m_1,\dots,m_s)$. Then the Jacobian $\frac{d}{dx}y$ is the tensor of shape $(m_1,\dots,m_s,n_1,\dots,n_r)$ with $$\left[\frac{d}{dx}y\right]_{j_1,\dots,j_s,i_1,\dots,i_r}=\frac{d}{dx_{i_1,\dots,i_r}}y_{j_1,\dots,j_s}$$ If $x,y$ are two tensors with $\text{shape}(x)=(n_1,\dots,n_r,m_1,\dots,m_s)$ and $\text{shape}(y)=(m_1,\dots,m_s,l_1,\dots,l_t)$, then $\text{shape}(x\cdot y)=(n_1,\dots,n_r,l_1,\dots,l_t)$ with: $$\left[x\cdot y\right]_{i_1,\dots,i_r,j_1,\dots,j_t}=\sum\limits_{k_1,\dots,k_s} x_{i_1,\dots,i_r,k_1,\dots,k_s}\,y_{k_1,\dots,k_s,j_1,\dots,j_t}$$
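
To see this tensor dot product in action, numpy's tensordot performs exactly this contraction of the trailing axes of $x$ with the leading axes of $y$. The following snippet is only an illustration and is not part of compgraph:

x = np.random.randn(2, 3, 4, 5)   # shape (n_1, n_2, m_1, m_2)
y = np.random.randn(4, 5, 6)      # shape (m_1, m_2, l_1)
# contract the last two axes of x with the first two axes of y
z = np.tensordot(x, y, axes=2)
print(z.shape)                    # (2, 3, 6), i.e. (n_1, n_2, l_1)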

1. Mathematical Derivation of the backpropagation algorithm on computation graphs¶

I. A few definitions to get started. Let our computation graph $G$ have $n$ nodes $i=1,\dots,n$ and let $x_1,\dots,x_n$ be their computed node values. Let the index $i$ represent their order, i.e. $x_{i+1}$ was computed after $x_i$ (and $x_{i+1}$ is therefore a function of only $x_1,\dots,x_{i}$). For a node $i$, let $P_i$ be the parents of node $i$ and $C_i$ the children of node $i$. Let $A_i$ be the descendants of node $i$ (the nodes reachable from node $i$) and $B_i$ be the nodes that are neither node $i$ itself nor descendants of $i$ (i.e. the complement of $A_i\cup\{i\}$).

II. We know all elementary operations. In a computation graph, we know the elementary operations $f$ that convert two nodes $x,y$ into an operational node $z=f(x,y)$. For example, $f$ could be addition ($f(x,y)=x+y$) or matrix multiplication ($f(x,y)=x\cdot y$). For every operational node $i$, the value $x_i$ is a function of its parent nodes $x_{p_i},x_{q_i}$:

$$x_i = f_{j_i}(x_{p_i},x_{q_i})$$

where $f_{j_i}$ is one of the pre-defined elementary functions.

III. We know the gradient of all elementary operations. For every elementary operation $f$, we know the gradients with respect to each operand, which we call the left gradient and the right gradient: $$\text{lgrad}(f)(x,y)=\frac{d}{dx}f(x,y)$$ $$\text{rgrad}(f)(x,y)=\frac{d}{dy}f(x,y)$$ These definitions lead to the following adjoint functions: $$\text{ladj}(f)(x,y,z)=z \cdot \frac{d}{dx}f(x,y)$$ $$\text{radj}(f)(x,y,z)=z\cdot \frac{d}{dy}f(x,y)$$ where $z$ is a tensor with the same shape as $f(x,y)$. Note that $z\cdot \frac{d}{dx}f(x,y)$ uses the dot product and Jacobian for tensors as defined above.

IV. Deriving the gradient. Let $x_n$ be the leaf node and $x_i$ the ancestor node for which we want to compute the gradient $\frac{d}{dx_i}x_n$. For now, let's consider the values $x_{B_i}$ as fixed (i.e. all nodes that are neither node $i$ itself nor descendants of $i$ are treated as constants). Then, for some (unknown) function $G$:

$$x_n = G(x_{C_i}) = G(x_{c_1},\dots,x_{c_{k}}) = G(f_{j_1}(x_i,x_{j_1}),...,f_{j_k}(x_i,x_{j_k}))$$

for some $j_1<\dots<j_k$ and $c_1<c_2<\dots<c_k$, where we assume for the sake of simplicity that $x_i$ has always been the left operand. In other words, for fixed $x_{B_i}$, the value of $x_n$ only depends on the children $C_i=\{c_1,\dots,c_k\}$ of node $i$. If we take the gradient with respect to $x_i$, we get:

\begin{align*} \frac{d}{dx_i} x_n =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\frac{d}{dx_{i}}x_{c_l} \quad \text{[by the chain rule]}\\ =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\left(\text{lgrad}(f_{j_l})(x_i,x_{j_l})+\text{rgrad}(f_{j_l})(x_i,x_{j_l})\frac{d}{dx_i}x_{j_l}\right)\quad \text{[by the chain rule]}\\ =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\left(\text{lgrad}(f_{j_l})(x_i,x_{j_l})+\text{rgrad}(f_{j_l})(x_i,x_{j_l})\sum\limits_{m=1}^{l-1} \frac{d}{dx_{c_m}}x_{j_l}\frac{d}{dx_{i}}x_{c_m}\right) \quad [\text{because }x_{j_l}\text{ can only depend on } x_{c_1},\dots,x_{c_{l-1}}]\\ =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\left(\text{lgrad}(f_{j_l})(x_i,x_{j_l})+\text{rgrad}(f_{j_l})(x_i,x_{j_l})\sum\limits_{m=1}^{l-1} \frac{d}{dx_{c_m}}x_{j_l}\,\text{lgrad}(f_{j_m})(x_i,x_{j_m})\right) \quad [\text{by the definition of lgrad}(f_{j_m})]\\ =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\,\text{lgrad}(f_{j_l})(x_i,x_{j_l})+\sum\limits_{l = 1}^{k}\sum\limits_{m=1}^{l-1} \frac{d}{dx_{c_l}}G(x_{C_i})\,\text{rgrad}(f_{j_l})(x_i,x_{j_l})\frac{d}{dx_{c_m}}x_{j_l}\,\text{lgrad}(f_{j_m})(x_i,x_{j_m})\quad[\text{matrix algebra}]\\ =&\sum\limits_{l = 1}^{k}\frac{d}{dx_{c_l}}G(x_{C_i})\,\text{lgrad}(f_{j_l})(x_i,x_{j_l})+\sum\limits_{l = 1}^{k}\sum\limits_{m=l+1}^{k}\frac{d}{dx_{c_m}}G(x_{C_i})\,\text{rgrad}(f_{j_m})(x_i,x_{j_m})\frac{d}{dx_{c_l}}x_{j_m}\,\text{lgrad}(f_{j_l})(x_i,x_{j_l})\quad[\text{reindexing}]\\ =&\sum\limits_{l = 1}^{k}\left(\frac{d}{dx_{c_l}}G(x_{C_i})+\sum\limits_{m=l+1}^{k}\frac{d}{dx_{c_m}}G(x_{C_i})\,\text{rgrad}(f_{j_m})(x_i,x_{j_m})\frac{d}{dx_{c_l}}x_{j_m}\right)\text{lgrad}(f_{j_l})(x_i,x_{j_l})\quad[\text{matrix algebra}]\\ =&\sum\limits_{l = 1}^{k}\left(\frac{d}{dx_{c_l}}G(x_{C_i})+\sum\limits_{m=l+1}^{k}\frac{d}{dx_{c_m}}G(x_{C_i})\frac{d}{dx_{c_l}}x_{c_m}\right)\text{lgrad}(f_{j_l})(x_i,x_{j_l})\quad[\text{chain rule}]\\ =&\sum\limits_{l = 1}^{k}\left(\frac{d}{dx_{c_l}}x_n\right)\text{lgrad}(f_{j_l})(x_i,x_{j_l})\quad[\text{chain rule and because }x_{c_m}\text{ can only affect }x_{c_{m+1}},\dots,x_{c_{k}}]\\ =&\sum\limits_{l = 1}^{k}\text{ladj}(f_{j_l})\left(x_i,x_{j_l},\frac{d}{dx_{c_l}}x_n\right)\quad[\text{definition of ladj}] \end{align*}

In other words, the derivative of a leaf value $x_n$ with respect to an ancestor node $x_i$ equals the sum, over the children $C_i$ of $i$, of the adjoints evaluated at the children's gradients. Note the significance of this: to compute the derivative $\frac{d}{dx_i}x_n$, we only need to know the derivatives of the children nodes, $\frac{d}{dx_{c_l}}x_n$, and how to take derivatives of the elementary operations $f_{j_l}$. The only important thing is that we traverse the graph backwards in hierarchical order. As we have learnt in tutorial 1, this basically corresponds to breadth-first search. We only need to modify the breadth-first search such that we not only build the computation graph but also accumulate gradients.
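
As a quick sanity check of this formula, consider a tiny graph with one variable and one constant: $x_2 = x_1 + c$ and $x_3 = x_1 \cdot x_2$, so node $1$ has the two children $C_1=\{2,3\}$. The formula gives $$\frac{d}{dx_1}x_3 = \text{ladj}(\cdot)\left(x_1,x_2,\frac{d}{dx_3}x_3\right) + \text{ladj}(+)\left(x_1,c,\frac{d}{dx_2}x_3\right) = x_2\cdot 1 + 1\cdot x_1 = 2x_1+c,$$ which matches differentiating $x_3 = x_1(x_1+c)$ directly.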

2. Backpropagation with Breadth-First-Search¶

The above derivation leads to the following algorithm to compute gradients:

Input: leaf node $l$.

  1. Initialize NodesQueue with $l$
  2. Initialize gradient dict: grad_dict$(l)=1$ and grad_dict$(j)=0$ for any other node $j$.
  3. Initialize graph G with a single node $l$.
  4. While NodesQueue not empty:

    4.1. Pop the node $j$ in NodesQueue with the highest node_order.

    4.2. If $j$ is an OperationalNode with $x_j=f_j(x_{p_j},x_{q_j})$, then:

      4.2.1. Add the parent nodes $p_j,q_j$ to G.

      4.2.2. Add the edges $p_j\to j$ and $q_j\to j$ to G.

      4.2.3. Add gradient to $p_j$ if $p_j$ is not a ConstantNode: grad_dict$(p_j)$ += $\text{ladj}(f_j)(x_{p_j},x_{q_j},$ grad_dict$(j))$

      4.2.4. Add gradient to $q_j$ if $q_j$ is not a ConstantNode: grad_dict$(q_j)$ += $\text{radj}(f_j)(x_{p_j},x_{q_j},$ grad_dict$(j))$

  5. Return grad_dict

3. Implementing adjoints¶

For every elementary operation $f_j$ that we defined in our computation graph in tutorial 1, we have to define the adjoint function. We list here a few examples:

3.1. $f$ = sum¶

If $f(x,y) = x+y$, then: $$\frac{d}{dx}f(x,y) = \frac{d}{dy}f(x,y) = \mathbb{1}_d$$ $$\Rightarrow \text{ladj}(f)(x,y,z)=z, \text{radj}(f)(x,y,z)=z$$ where $d$ is the dimension of $x$.

In [2]:
def add_grad(prev_adjoint, node):
    return [prev_adjoint, prev_adjoint]

3.2. $f$ = multiply¶

If $f(x,y) = x*y$ (componentwise multiplication), then: $$\frac{d}{dx}f(x,y) = \text{diag}(y)$$ $$\frac{d}{dy}f(x,y) = \text{diag}(x)$$ $$\Rightarrow \text{ladj}(f)(x,y,z)=z*y, \text{radj}(f)(x,y,z)=z*x$$

In [3]:
def mul_grad(prev_adjoint, node):
    return [
        prev_adjoint * node.operand_b,
        prev_adjoint * node.operand_a
    ]

3.3. Matrix multiplication¶

If $f(x,y)=x \cdot y$ for matrices $x\in\mathbb{R}^{k\times l},y\in\mathbb{R}^{l\times m}$ we have that: $$\frac{d}{dx_{i,j}}f_{i',j'}(x,y) =\delta_{i=i'}y_{j,j'} \Rightarrow [z\cdot \frac{d}{dx}f(x,y)]_{ij}=\sum\limits_{j'}z_{ij'}y_{j,j'} \Rightarrow \text{ladj}(f)(x,y,z)=zy^T$$ $$\frac{d}{dy_{i,j}}f_{i',j'}(x,y) =\delta_{j=j'}x_{i',i} \Rightarrow [z\cdot \frac{d}{dy}f(x,y)]_{ij}=\sum\limits_{i'}z_{i'j}x_{i',i} \Rightarrow \text{radj}(f)(x,y,z)=x^Tz$$

In [4]:
def dot_grad(prev_adjoint, node):

    prev_adj = prev_adjoint
    op_a = node.operand_a
    op_b = node.operand_b

    # promote 1-d operands to 2-d column/row vectors so that the matrix
    # formulas ladj = z y^T and radj = x^T z apply directly
    if node.operand_b.ndim == 1:
        prev_adj = cg.reshape(prev_adjoint, (-1, 1))
        op_b = cg.reshape(op_b, (-1, 1))

    if node.operand_a.ndim == 1:
        prev_adj = cg.reshape(prev_adjoint, (1, -1))
        op_a = cg.reshape(op_a, (1, -1))

    adj_op_a = cg.dot(prev_adj, op_b.T)
    adj_op_b = cg.dot(op_a.T, prev_adj)

    # drop the axes we added above so the adjoints match the operand shapes
    if node.operand_a.ndim == 1:
        adj_op_a = adj_op_a.squeeze()

    if node.operand_b.ndim == 1:
        adj_op_b = adj_op_b.squeeze()

    return [adj_op_a, adj_op_b]
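
As a quick numerical sanity check of the formula $\text{ladj}(f)(x,y,z)=zy^T$ derived above, we can build the full Jacobian explicitly and contract it with $z$. This is a standalone numpy snippet, independent of compgraph:

k, l, m = 3, 4, 2
x = np.random.randn(k, l)
y = np.random.randn(l, m)
z = np.random.randn(k, m)                        # same shape as x.dot(y)
# full Jacobian d(xy)/dx of shape (k, m, k, l): J[i', j', i, j] = delta(i', i) * y[j, j']
jac_x = np.einsum('ai,jb->abij', np.eye(k), y)
# the tensor dot product z . J contracts the first two axes of J, as defined above
ladj_x = np.tensordot(z, jac_x, axes=2)
print(np.allclose(ladj_x, z.dot(y.T)))           # True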

3.4. Max operation¶

In neural networks, we also use operations that are not differentiable everywhere. An important example is the max operation, which is used, for example, in max pooling. For the max operation, we approximate the gradient as follows: we count the number of elements in the input tensor that equal the maximum, say there are $d$ such values. Then we set the gradient to $1/d$ for these entries and to $0$ for all others.

In [5]:
def max_grad(prev_adjoint, node):
    # mark the entries of the input that equal the maximum
    doperand_a = cg.where(node.operand_a == node.with_keepdims, 1, 0)
    # if d entries are tied for the maximum, each receives a share of 1/d
    normalizers = cg.sum(doperand_a, axis=node.axis, keepdims=True)
    normalized_doperand_a = doperand_a / normalizers

    if node.axis is not None:
        # re-insert the reduced axis so the adjoint broadcasts over the input shape
        return [np.expand_dims(prev_adjoint, node.axis) * normalized_doperand_a, None]
    else:
        return [prev_adjoint * normalized_doperand_a, None]
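
To make the tie handling concrete, here is what the chosen subgradient looks like for a small vector with a tied maximum, written in plain numpy (this mirrors the counting logic above and does not use the compgraph API):

x = np.array([3., 1., 3.])
mask = (x == x.max()).astype(float)   # entries that attain the maximum
print(mask / mask.sum())              # [0.5 0.  0.5], i.e. 1/d for each of the d maxima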

3.5. Packaging adjoints¶

I have written adjoint functions for all numerical functions that we defined to build computation graphs and put them into the autodiff package. Every function basically uses the techniques defined above. The only technical novelty is a function decorator called @define_grad_func that ensures that all inputs have the correct shape and deals with broadcasting edge cases, e.g. cases where tensors of different shapes are summed via broadcasting, such as $[2]+[4,2,1]=[6,4,3]$.
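
To give a flavour of what the decorator has to handle, here is a minimal sketch of the un-broadcasting step, assuming plain numpy arrays and a hypothetical helper name (the real @define_grad_func in the autodiff package may do this differently):

def _unbroadcast(grad, target_shape):
    # hypothetical helper, for illustration only: sum a gradient over the
    # axes that numpy broadcasting added so it matches the operand's shape
    while grad.ndim > len(target_shape):
        grad = grad.sum(axis=0)                        # drop axes broadcasting prepended
    for axis, size in enumerate(target_shape):
        if size == 1:
            grad = grad.sum(axis=axis, keepdims=True)  # collapse axes broadcast from size 1
    return grad

# example: a (2,1) tensor added to a (4,2,3) tensor is broadcast,
# so its gradient must be summed back down to shape (2,1)
g = np.ones((4, 2, 3))
print(_unbroadcast(g, (2, 1)).shape)   # (2, 1)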

In [6]:
import autodiff.grads as grads

4. Implementing Breadth-First-Search Backpropagation¶

Finally, we can also implement the breadth-first search backpropagation algorithm as defined above: the crucial line is

op_grad = getattr(grads, '{}_grad'.format(current_op))

that selects the right adjoint function for a specific operation from the autodiff.grads package.

In [7]:
def compute_backprop(node, verbose=False):
    """
    computes and returns the gradient of the given node wrt to VariableNodes
    the function implements a breadth-first-search (BFS) to traverse the
    computational graph from the gievn node back to VariableNodes

    Parameters:
    ----------
    node: Node
        the leaf node to compute its gradient
    """

    adjoint = defaultdict(int)
    grad = {} #to be computed for variable nodes:
    queue = NodesHeap()
    
    #Ensure that the node is a scalar:
    assert node.size == 1, "We can only compute gradients of scalar nodes."
    
    # put the given node in the queue and set its adjoint to one
    adjoint[node.name] = ConstantNode.create_using(np.ones(node.shape))
    queue.push(node)

    while len(queue) > 0:
        current_node = queue.pop()
        
        if verbose:
            print("Popped node: ", current_node.name)

        if isinstance(current_node, ConstantNode):
            continue
        if isinstance(current_node, VariableNode):
            grad[current_node.name] = adjoint[current_node.name]
            continue

        current_adjoint = adjoint[current_node.name]
        current_op = current_node.opname
        
        #Get gradient function:
        op_grad = getattr(grads, '{}_grad'.format(current_op))
        
        #Compute next adjoints:
        next_adjoints = op_grad(current_adjoint, current_node)

        #Add next_adjoint to old adjoint:
        adjoint[current_node.operand_a.name] = adjoint[current_node.operand_a.name] + next_adjoints[0]
        
        #Add new node to queue if not already inside the queue:
        if current_node.operand_a not in queue:
            queue.push(current_node.operand_a)

        #Do the same for operand_b if it exists:
        if current_node.operand_b is not None:
            adjoint[current_node.operand_b.name] = adjoint[current_node.operand_b.name] + next_adjoints[1]

            if current_node.operand_b not in queue:
                queue.push(current_node.operand_b)

    return grad

5. Computing a few example gradients¶

Finally, let's compute a few example gradients to see whether the algorithm above leads to correct results. We use a function that computes gradients by numerical approximation:

In [8]:
import itertools
def check_gradient(fx, x_input, suspect):
    """
    checks the correctness of the suspect derivative value against
    the value of the numerical approximation of the derivative

    Parameters:
    ----------
    fx: callable
        The function to check its derivative
    wrt: int
        0-based index of the variable to differntiate with respect to
    args: list
        the values of the function variables at the derivative point
    suspect: float
        the the suspected value of the derivative to check
    """
    h = 1.e-7
    approx_grad = np.zeros_like(x_input)
    fx_input = fx(x_input)
    
    for el in itertools.product(*[range(i) for i in x_input.shape]):
        x_input_shifted = x_input.copy()
        x_input_shifted[el] = x_input_shifted[el] + h
        approx_grad[el] = (fx(x_input_shifted) - fx_input) / h

    return approx_grad
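
Before applying it to compgraph nodes, we can sanity-check the checker itself on a plain numpy function whose gradient we know analytically, e.g. $f(x)=\sum_{ij} x_{ij}^2$ with gradient $2x$:

x0 = np.array([[1.0, -2.0], [0.5, 3.0]])
num_grad = check_gradient(lambda x: np.sum(x**2), x0, 2 * x0)
print(np.allclose(num_grad, 2 * x0, atol=1e-4))   # True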

5.1. Example A¶

In [11]:
cg.reset()
var_node = cg.VariableNode.create_using([[1,2,4],[2.0,4.0,5.0]])
const_node = cg.ConstantNode.create_using([[8,1,3],[4.0,2.0,4.0]])

def function_1(x_input):
    power_x = x_input**const_node
    x_sum = x_input + const_node
    x_cos = cg.cos(x_sum)
    labels = cg.ConstantNode.create_using(np.zeros_like(x_cos))
    labels[:,0] = 1
    softmax = cg.softmax_cross_entropy(x_cos,labels)
    return softmax

output = function_1(var_node)
cg.build_and_visualize_graph(output)
computed_gradient = compute_backprop(output)['_0']
approx_grad = check_gradient(function_1,var_node,computed_gradient)
assert np.allclose(approx_grad,computed_gradient)

5.2. Example B¶

In [12]:
cg.reset()
var_node = cg.VariableNode.create_using([[[1,-4,4],[1.0,4.0,5.0]],[[12,-34,44],[-2,-4,6.0]]])
const_node = cg.ConstantNode.create_using([[[-1,4,4],[2.0,-3.0,5.0]],[[12,-3,4],[-4,-4,2.0]]])
mat_node = cg.ConstantNode.create_using([[4,-1],[2,3]])

def function_2(x_input):
    x_sum = x_input - const_node
    x_cos = cg.cos(x_sum)
    x_sin = cg.sin(x_sum)
    x_reduced_1 = cg.sum(x_cos,axis=2)
    x_reduced_2 = cg.mean(x_sin,axis=2)
    prod_reduced = x_reduced_1*x_reduced_2
    max_prod = cg.max(prod_reduced,axis=1)
    output = cg.dot(mat_node,max_prod)
    return cg.sum(output)

output = function_2(var_node)
cg.build_and_visualize_graph(output)
computed_gradient = compute_backprop(output)['_0']
approx_grad = check_gradient(function_2,var_node,computed_gradient)
assert np.allclose(approx_grad,computed_gradient)