In [1]:
% ignore - internal setup
path('../scripts', path);

Symbolic Differentiation vs Automatic Differentiation

Consider the function below, which is very simple, at least computationally.


In [ ]:
function y = func(x)

    y = x;
    for i = 1:30
        y = sin(x + y);
    end
   
end

We can compute a derivative symbolically, but it is horrendous (see below). Think of how much worse it would be if we chose a function with products, more dimensions, or iterated more than 30 times.


In [10]:
syms x;  % only x needs to be symbolic; y is built from x below

y = x;
for i = 1:30
    y = sin(x + y);
end


dydx = diff(y, x)


 
dydx =
 
cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + 
sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x)))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + 
sin(x + sin(2*x)))))))*(cos(x + sin(x + sin(x + sin(x + sin(x + sin(2*x))))))*(cos(x + sin(x + sin(x + sin(x + sin(2*x)))))*(cos(x + sin(x + sin(x + sin(2*x))))*(cos(x + sin(x + sin(2*x)))*(cos(x + sin(2*x))*(2*cos(2*x) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1) + 1)
 

We can now evaluate the expression. In this case, we evaluate at x=0.1.


In [11]:
xpt = 0.1;
dydx = vpa(subs(dydx, x, xpt), 16)  % 16 significant digits


 
dydx =
 
1.917706760386666
 

Let's now compare with automatic differentiation using operator overloading. I'm using a third-party AD package for MATLAB here. I don't use MATLAB in our research, so I am not familiar with the AD support available for it. In particular, I am not endorsing this package over others; it was just one of the top results on Google.


In [14]:
clear;
xpt = 0.1;

x = ainit(xpt, 1);  % initialize x at the point, and get 1st derivatives
y = func(x);
format long;
dydx = y{1} % pull out the first derivative


dydx =

   1.917706760386666
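
To make the idea behind operator overloading concrete, here is a minimal sketch that carries a [value, derivative] pair through the same loop as `func`. It does not use the AD package and makes no claim about how that package works internally; `dual_add` and `dual_sin` are hypothetical helper names introduced only for this illustration.

```matlab
% Illustrative only: each quantity is stored as [value, derivative].
% dual_add and dual_sin are hypothetical helpers, not part of any AD package.
dual_add = @(a, b) [a(1) + b(1), a(2) + b(2)];   % sum rule
dual_sin = @(a) [sin(a(1)), cos(a(1))*a(2)];     % chain rule for sin

xpt = 0.1;
x = [xpt, 1.0];   % value of x and the seed dx/dx = 1
y = x;            % y starts as a copy of x
for i = 1:30
    y = dual_sin(dual_add(x, y));   % same structure as y = sin(x + y) in func
end

format long;
dydx = y(2)       % derivative component
```

The loop body has the same shape as the one in `func`; only the arithmetic has been swapped for versions that also propagate a derivative. An overloading-based AD tool does this swap automatically, so the derivative here should agree with the values reported above to within round-off.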

Let's also compare to AD using a source code transformation method (I used Tapenade in Fortran to generate this function).


In [ ]:
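function [y, yd] = funcad(x)
    % Forward-mode derivative of func: returns y and yd = dy/dx.

    xd = 1.0;  % seed: dx/dx
    yd = xd;   % y starts as x, so dy/dx starts as dx/dx
    y = x;
    for i = 1:30
        % update the derivative first, while y still holds the previous iterate
        yd = (xd + yd)*cos(x + y);
        y = sin(x + y);
    end

end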
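
The derivative update in `funcad` is just the chain rule applied to one pass of the loop. Writing the iterates as $y_0 = x$ and $y_{k+1} = \sin(x + y_k)$ (the index $k$ is notation introduced here, not part of the code), differentiating with respect to $x$ gives

$$
\frac{dy_{k+1}}{dx} = \left(1 + \frac{dy_k}{dx}\right)\cos(x + y_k), \qquad \frac{dy_0}{dx} = 1 .
$$

In the code, `xd` plays the role of $dx/dx = 1$ and `yd` holds $dy_k/dx$, so the derivative costs one extra multiply, add, and cosine per iteration rather than a nested expression that grows with $k$.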

In [16]:
[~, dydx] = funcad(xpt)


dydx =

   1.917706760386666

For a simple expression like this, the symbolic derivative is long but actually works reasonably well, and symbolic and automatic differentiation both give a numerically exact answer (the three results above agree to all 16 digits shown). But if we increase the loop to 100 or more iterations, or add other complications, the symbolic solver will fail or take much longer, whereas automatic differentiation will continue to work without issue. Furthermore, if we add other dimensions to the problem, symbolic differentiation quickly becomes costly because many computations get repeated, whereas automatic differentiation is able to reuse a lot of calculations.

As a specific example, if I change the number of iterations from 30 to 300, the symbolic differentiation takes 7.0 seconds, the overloaded AD takes 0.7 seconds, and the source code transformation takes 0.001 seconds. The overloaded AD is an order of magnitude faster than symbolic differentiation (and the source code transformation version is blazingly fast, roughly 700 times faster again). In some languages and implementations, the speed gap between overloaded AD and source-code-transformed AD is less dramatic.
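
For anyone who wants to try a comparison like this, here is a rough sketch of how the hand-differentiated loop could be timed at 300 iterations using only `tic` and `toc`. The names `n` and `nrep` are illustrative, the loop is inlined rather than calling `funcad` so the cell stands alone, and the timings depend on the machine and MATLAB version, so they will not necessarily match the figures quoted above.

```matlab
% Rough timing sketch (illustrative): forward-mode loop with n iterations.
n = 300;       % number of loop iterations
nrep = 1000;   % repeat the evaluation to average out timer noise
xpt = 0.1;

tic;
for r = 1:nrep
    xd = 1.0;
    yd = xd;
    y = xpt;
    for i = 1:n
        yd = (xd + yd)*cos(xpt + y);   % derivative update (uses the old y)
        y = sin(xpt + y);              % value update
    end
end
t = toc/nrep;  % average seconds per evaluation

fprintf('n = %d: %.3g s per evaluation, dy/dx = %.15g\n', n, t, yd);
```

The symbolic and overloaded versions can be timed the same way by wrapping the corresponding cells in `tic` and `toc`.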

