Musings on the fundamental nature of the Fast Fourier Transform.
Original author: David Banas capn.freako@gmail.com
Original date: November 23, 2015
Copyright (c) 2015 David Banas, except where copied with permission and noted.
All rights reserved worldwide.
About two years ago, I listened to Conal Elliott give a talk at the Haskell Hackers meetup at the Hacker Dojo, on having achieved new levels of efficiency in the implementation of the parallel scan algorithm. I got very excited about this, because parallel scan is one of those algorithms which benefit greatly from a logarithmic "divide and conquer" breakdown, in which the problem of calculating the scan of some data structure is broken down into finding the scans of a number of subsets of the original data set and then recombining those sub-results to form the final answer. Doing so removes some of the redundancy present in the more straightforward approach, and the total work required to complete the calculation is reduced from, typically, $O(N^2)$ to $O(N \log_2 N)$.
I use another member of this class of algorithm in my daily work: the Fast Fourier Transform (FFT). I, like most, had always assumed that $O(N \log_2 N)$ was the best we could do in calculating the FFT of a data set. After hearing Conal's talk, I wondered if this was really true. I began a Haskell project to start investigating this: TreeViz. When I showed TreeViz to Conal, he was intrigued, but noted that I'd made an unfortunate choice: I'd "hard-wired" my implementation to be dependent upon the List functor as the enclosing data container. (At that time, I hadn't made it very far past List, with regard to my Haskelling.)
Conal suggested that we search for a more generic mode of expressing the FFT computation. This IHaskell notebook attempts to capture the historical record of our journey from the List-centric implementation in TreeViz to a very recent breakthrough of Conal's, which has enabled a completely generic expression of the FFT, valid for a much broader class of functors.
Note: In order to use this notebook and play with the code, you will need to install the following Haskell packages:
For those packages that are more intimately involved with the development of my own code, I have opted to copy the source directly into this notebook. Note that anywhere below where I've done this, I have:
The discrete Fourier transform (DFT) is used to assess the spectral content of a set of samples of a time varying function, taken at uniformly spaced time intervals. It is defined as:
$$ DFT \{x_n\} = \{X_n\} \quad | \quad X_n = \sum_{m=0}^{N-1} x_m \cdot e^{-j\frac{2\pi}{N}nm} \qquad (1) $$

It can be expressed, in Haskell, using a straightforward translation of the algebraic definition:
(The LANGUAGE pragmas aren't needed here, but if they don't come at the top of the notebook, we get funny, hard-to-debug errors later.)
In [2]:
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
import Data.Complex
import Data.Time
putStrLn "Notebook last run:"
getCurrentTime
dft :: RealFloat a => [Complex a] -> [Complex a]
dft xs = [ sum [ x * exp((0.0 :+ (-1.0)) * 2 * pi / lenXs * fromIntegral(k * n))
| (x, n) <- zip xs [0..]
]
| k <- [0..(length xs - 1)]
]
where lenXs = fromIntegral $ length xs
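As a quick sanity check (an addition of mine, not part of the original notebook), the DFT of a constant sequence should be an impulse at DC, i.e. $[N, 0, 0, ...]$, up to floating point rounding:

dft (map (:+ 0) [1, 1, 1, 1 :: Double])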
However, such an implementation fails to capitalize on much redundancy in the calculation of the individual elements of the output sequence, requiring $O(N^2)$ multiply-accumulate (MAC) operations. A more efficient implementation was first discovered by Gauss, in the early 1800s, and formalized for the signal processing world by Cooley and Tukey in 1965.
By splitting the sum above into two parts, taking the even and odd members of the original list respectively, we get:
$$ X_n = \sum_{m=0}^{\frac{N}{2} - 1} x_{2m} \cdot e^{-j\frac{2\pi}{N}n2m} + \sum_{m=0}^{\frac{N}{2} - 1} x_{2m+1} \cdot e^{-j\frac{2\pi}{N}n(2m+1)} $$

$$ = \sum_{m=0}^{\frac{N}{2} - 1} x_{2m} \cdot e^{-j\frac{2\pi}{N}n2m} + e^{-j\frac{2\pi}{N}n}\sum_{m=0}^{\frac{N}{2} - 1} x_{2m+1} \cdot e^{-j\frac{2\pi}{N}n2m} \qquad (2) $$

which can be rewritten (by a simple change of variable):
$$ X_n = \sum_{m=0}^{N' - 1} x^e_m \cdot e^{-j\frac{2\pi}{N'}nm} + e^{-j\frac{2\pi}{N}n}\sum_{m=0}^{N' - 1} x^o_m \cdot e^{-j\frac{2\pi}{N'}nm} \qquad (3) $$

where $N' = \frac{N}{2}$ and $x^e$ and $x^o$ are the sequences formed by taking the even and odd samples, respectively, of the original input sequence.
With the exception of the added phase factor multiplying the second sum, we have reduced the computation of an N-point DFT to a sum of two N/2-point DFTs. In the process, we've reduced the operational order of our computation from $O(N^2)$ to $O(2(\frac{N}{2})^2) = O(N^2/2)$, cutting the work required to complete the computation in half. And, if we continue applying this technique recursively until we can divide no further, we'll find that we've reduced the operational order of the DFT to $O(N \log_2 N)$, as was shown by Cooley and Tukey in their 1965 paper.
Perhaps you noticed that, while $m$ in (3) only sweeps over the range $[0, \frac{N}{2})$, $n$ still sweeps over $[0, N)$. (After all, we still require $N$ elements in our output sequence.) And perhaps you reasoned that this fact makes our claim of cutting the required work in half false, since we apparently need to calculate each half-length DFT twice. However, observe what happens when we ask for element $n + \frac{N}{2}$ of the output sequence:
$$ X_{n + \frac{N}{2}} = X_{n + N'} = \sum_{m=0}^{N' - 1} x^e_m \cdot e^{-j\frac{2\pi}{N'}(n + N')m} + e^{-j\frac{2\pi}{N}(n + N')}\sum_{m=0}^{N' - 1} x^o_m \cdot e^{-j\frac{2\pi}{N'}(n + N')m} $$

$$ = \sum_{m=0}^{N' - 1} x^e_m \cdot e^{-j\frac{2\pi}{N'}nm} \cdot e^{-j 2\pi m} + e^{-j\frac{2\pi}{N}n} e^{-j\pi} \sum_{m=0}^{N' - 1} x^o_m \cdot e^{-j\frac{2\pi}{N'}nm} \cdot e^{-j 2\pi m} $$

$$ = \sum_{m=0}^{N' - 1} x^e_m \cdot e^{-j\frac{2\pi}{N'}nm} - e^{-j\frac{2\pi}{N}n} \sum_{m=0}^{N' - 1} x^o_m \cdot e^{-j\frac{2\pi}{N'}nm} \qquad (4) $$

since $e^{-j 2\pi m} = 1$, for any integer $m$. And, we find:
$$ DFT \{x_n\} = concat[(DFT \{x^e_m\} + \{W_m\} \cdot DFT \{x^o_m\}), (DFT \{x^e_m\} - \{W_m\} \cdot DFT \{x^o_m\})] \qquad (5) $$

where:

$$ W_m = e^{-j\frac{2\pi}{N}m}, \quad m \in [0, \tfrac{N}{2}) $$
Note that the $\{W_m\}$ are a property of the computational structure. They do not change with the incoming data. This means they can be calculated in advance and that their calculation does not contribute to the run time computational load, which is the critical optimization metric. (Of course, the element-wise multiplication of these "twiddle" factors does contribute to the run time computational load.)
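As a small illustration (the twiddleTable helper below is my own, not part of the notebook's code), the twiddle factors for an $N$-point combining step depend only on $N$, so they can be tabulated once, ahead of time:

-- Twiddle factors for one N-point combining step: W_m = exp(-j*2*pi*m/N),
-- for m in [0, N/2). A data-independent table, computable in advance.
twiddleTable :: RealFloat a => Int -> [Complex a]
twiddleTable n = [ cis (-2 * pi * fromIntegral m / fromIntegral n) | m <- [0 .. n `div` 2 - 1] ]

twiddleTable 8 :: [Complex Double]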
Let us now use (5) to calculate our computational savings precisely, by counting the number of multiplications that must occur to complete the computation and comparing that number to $N^2$. (It is commonplace to ignore addition and subtraction when calculating computational complexity, because the hardware costs (i.e. - power consumption, silicon area, effect on $f_{max}$, etc.) of those operations pale in comparison to multiplication.) We find that each of the two subtransforms requires $(\frac{N}{2})^2$ multiplications, while the application of the twiddle factors requires another $\frac{N}{2}$, for a total of $\frac{N^2 + N}{2}$, which is approximately equal to $\frac{N^2}{2}$ for large $N$, yielding the factor of 2 improvement that was our original claim.
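To see the effect of carrying the recursion all the way down, here is a back-of-the-envelope count (a sketch of my own; fftMacs and dftMacs are hypothetical helpers, not part of the notebook's code). It follows the recurrence implied by (5) and counts every twiddle multiplication, including the trivial ones by $W_0 = 1$:

-- One combining step costs N/2 twiddle multiplications, on top of two
-- half-length sub-transforms; a single point needs no multiplications.
fftMacs :: Int -> Int
fftMacs 1 = 0
fftMacs n = 2 * fftMacs (n `div` 2) + n `div` 2

dftMacs :: Int -> Int
dftMacs n = n * n

-- (N, direct DFT MACs, recursive FFT MACs)
[ (n, dftMacs n, fftMacs n) | n <- [8, 64, 1024] ]

The recursive count works out to $\frac{N}{2} \log_2 N$, in agreement with the $O(N \log_2 N)$ claim made above.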
Here, I attempt to show how the implementation of the FFT operation evolved, as our project progressed. The journey started with a very primitive list-based implementation, then moved on to trees, à la Conal's Circat library, and finally arrived at our originally hoped-for destination, with Conal's recent successful implementation of an FFT instance for generic functor compositions.
The first FFT implementation occurred as part of the development of my TreeViz project, using List as the enclosing functor. The code for this implementation is shown below.
In [3]:
import System.Random
import Text.Printf
import Data.Newtypes.PrettyDouble
radix2_DIT :: RealFloat a => [Complex a] -> [Complex a]
radix2_DIT [] = []
radix2_DIT [x] = [x]
radix2_DIT xs = (++) (zipWith (+) xes xos) (zipWith (-) xes xos)
where xes = radix2_DIT (evens xs)
xos = zipWith (*) (radix2_DIT (odds xs)) [wn ** (fromIntegral k) | k <- [0..]]
wn = exp(0.0 :+ (-2.0*pi/n))
n = fromIntegral (length xs)
evens :: [a] -> [a]
evens [] = []
evens (x:xs) = x : (odds xs)
odds :: [a] -> [a]
odds [] = []
odds (x:xs) = evens xs
do
g <- getStdGen
let test_val = map ((:+ 0.0) . PrettyDouble) (take 8 (randoms g :: [Double]))
print $ radix2_DIT test_val == dft test_val
The code worked and was much more succinct and readable than an equivalent implementation in, say, C. However, the explicit dependency upon the List functor, sprinkled throughout the radix2_DIT definition, was problematic. We wanted to be able to express the FFT more generally.
Our next step involved using the data structures and associated functions in Conal's Circat project, to express the FFT in terms of perfect binary trees. Normally, I just import everything I need. However, here I recreate these definitions, so as to make this notebook more self-contained and descriptive, and less fragile.
You can skip over these recreated imports, if you wish.
In [4]:
import Control.Applicative
import Control.Arrow ((&&&))
import Data.Traversable
import Data.Tuple (swap)
-----------------------------------------------------------------------
-- import TypeUnary.Nat
-----------------------------------------------------------------------
data Z
data S n
data Nat :: * -> * where
Zero :: Nat Z
Succ :: IsNat n => Nat n -> Nat (S n)
class IsNat n where
nat :: Nat n
instance IsNat Z where
nat = Zero
instance IsNat n => IsNat (S n) where
nat = Succ nat
natToZ :: Num a => Nat n -> a
natToZ Zero = 0
natToZ (Succ n) = 1 + natToZ n
-- A pretty printing alternative to Complex:
newtype MyComplex a = MyComplex {complex :: Complex a} deriving (Num)
instance (Show a) => Show (MyComplex a) where
show z@(MyComplex (x :+ y)) = show x ++ " + " ++ show y ++ "j"
-- Note: RTree definition and associated code copied with permission.
-----------------------------------------------------------------------
-- import Circat.Pair
-----------------------------------------------------------------------
infixl 1 :#
data Pair a = a :# a deriving (Functor,Eq)
instance (Show a) => Show (Pair a) where
show (x :# y) = " (" ++ show x ++ " :# " ++ show y ++ ") "
instance Applicative Pair where
pure a = a :# a
(f :# g) <*> (a :# b) = (f a :# g b)
instance Foldable Pair where
foldMap f (a :# b) = f a `mappend` f b
instance Traversable Pair where
traverse h (fa :# fb) = liftA2 (:#) (h fa) (h fb)
type Unop a = a -> a
firstP :: Unop a -> Unop (Pair a)
firstP f (a :# b) = f a :# b
secondP :: Unop a -> Unop (Pair a)
secondP f (a :# b) = a :# f b
toP :: (a,a) -> Pair a
toP (a,b) = a :# b
fromP :: Pair a -> (a,a)
fromP (a :# b) = (a,b)
inP :: ((a, a) -> (a, a)) -> Pair a -> Pair a
inP g = toP . g . fromP
-----------------------------------------------------------------------
-- import Circat.Scan
-----------------------------------------------------------------------
-- Generalize the Prelude's 'scanl' on lists
scanlT :: Traversable t => (b -> a -> b) -> b -> t a -> (t b,b)
scanlT op e = swap . mapAccumL (\ a b -> (a `op` b,a)) e
-- Like 'scanlT', but drop the last element.
scanlTEx :: Traversable t => (b -> a -> b) -> b -> t a -> t b
scanlTEx op e = fst . scanlT op e
-----------------------------------------------------------------------
-- import Circat.Misc
-----------------------------------------------------------------------
transpose :: (Traversable t, Applicative f) => t (f a) -> f (t a)
transpose = sequenceA
-----------------------------------------------------------------------
-- import Circat.RTree
-----------------------------------------------------------------------
data RTree :: * -> * -> * where
RL :: a -> RTree Z a
RB :: Pair (RTree n a) -> RTree (S n) a
r_toL :: a -> RTree Z a
r_toL a = RL a
r_unL :: RTree Z a -> a
r_unL (RL a) = a
r_toB :: Pair (RTree n a) -> RTree (S n) a
r_toB p = (RB p)
r_unB :: RTree (S n) a -> Pair (RTree n a)
r_unB (RB p) = p
r_inL :: (a -> b) -> (RTree Z a -> RTree Z b)
r_inL g = r_toL . g . r_unL
r_inB :: (Pair (RTree m a) -> Pair (RTree n b)) -> (RTree (S m) a -> RTree (S n) b)
r_inB g = r_toB . g . r_unB
instance Functor (RTree n) where
fmap f (RL a ) = RL (f a)
fmap f (RB ts) = RB ((fmap.fmap) f ts)
instance IsNat n => Applicative (RTree n) where
pure a = a <$ units nat
(<*>) = ap''
instance Foldable (RTree n) where
foldMap f (RL a ) = f a
foldMap f (RB ts) = (foldMap.foldMap) f ts
instance Traversable (RTree n) where
traverse f (RL a ) = RL <$> f a
traverse f (RB ts) = RB <$> (traverse.traverse) f ts
instance (IsNat n, Num a) => Num (RTree n a) where
negate = fmap negate
(+) = liftA2 (+)
(*) = liftA2 (*)
fromInteger = pure . fromInteger
abs = fmap abs
signum = fmap signum
units :: Nat n -> RTree n ()
units Zero = RL ()
units (Succ n) = RB (pure (units n))
ap'' :: RTree m (a -> b) -> RTree m a -> RTree m b
ap'' (RL f ) = r_inL (\ x -> f x)
ap'' (RB fs) = r_inB (\ xs -> liftA2 ap'' fs xs)
-- Split into evens & odds
bottomSplit :: IsNat n => RTree (S n) a -> Pair (RTree n a)
bottomSplit = split' nat
where
split' :: Nat n -> RTree (S n) a -> Pair (RTree n a)
split' Zero = r_unB
split' (Succ m) = fmap RB . transpose . fmap (split' m) . r_unB
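One of these recreated pieces deserves a quick look before moving on (the check below is mine, not part of the original notebook): bottomSplit performs exactly the even/odd de-interleaving required by (5), turning a depth n + 1 tree into a pair of depth n trees:

import Data.Foldable (toList)

-- The leaves [0,1,2,3], in order, come back as the evens [0,2] and the odds [1,3].
fmap toList (bottomSplit (RB (RB (RL 0 :# RL 1) :# RB (RL 2 :# RL 3)) :: RTree (S (S Z)) Int))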
Okay, that's it for recreated imports; here's the new FFT definition:
In [5]:
-----------------------------------------------------------------------
-- New FFT Definition
-----------------------------------------------------------------------
-- Radix-2, DIT FFT, using RTree:
fft_r2_dit :: (IsNat n, RealFloat a, Enum a) => RTree n (Complex a) -> RTree n (Complex a)
fft_r2_dit = fft_r2_dit' nat
fft_r2_dit' :: (RealFloat a, Enum a) => Nat n -> RTree n (Complex a) -> RTree n (Complex a)
fft_r2_dit' Zero = id
fft_r2_dit' (Succ n) = r_toB -- concatenation of sub-results
. inP (uncurry (+) &&& uncurry (-)) -- sum & difference of sub-transforms
. secondP (liftA2 (*) (phasor n)) -- application of "twiddle" factors to 2nd sub-xform
. fmap (fft_r2_dit' n) -- recursive application of FFT (Note depth reduction.)
. bottomSplit -- Break in half, via de-interleaving (i.e. - DIT).
-- Phasor, as a function of tree depth.
phasor :: (IsNat n, RealFloat a, Enum a) => Nat n -> RTree n (Complex a)
phasor n = scanlTEx (*) 1 (pure phaseDelta)
where phaseDelta = cis ((-pi) / 2 ** natToZ n)
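Here's a quick usage sketch (an addition of mine, not in the original notebook; testTree is a hypothetical value): a depth-2 tree holding the samples [0, 1, 2, 3], run through fft_r2_dit and compared against the reference dft. The two results should agree, up to floating point rounding:

import Data.Foldable (toList)

testTree :: RTree (S (S Z)) (Complex Double)
testTree = RB (RB (RL 0 :# RL 1) :# RB (RL 2 :# RL 3))

toList (fft_r2_dit testTree)   -- tree-based FFT
dft (toList testTree)          -- reference DFT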
Clearly, the level of abstraction of the FFT definition has been raised, relative to the List case, above. However, we still don't have a generic expression for the FFT computation that is valid for any functor. Note, in particular, the use of r_toB and bottomSplit in the code, above, which "lock" the implementation to RTree. What we really want is a completely generic expression of the FFT computation, which works for any functor.
Before presenting Conal's solution to this problem, I want to take a short aside, in order to try and recreate, for you, the state of thought in existence, just before Conal made his breakthrough. Recall (5), which expresses a single step in the logarithmic breakdown of the DFT computation:
$$ DFT \{x_n\} = concat[(DFT \{x^e_m\} + \{W_m\} \cdot DFT \{x^o_m\}), (DFT \{x^e_m\} - \{W_m\} \cdot DFT \{x^o_m\})] \qquad (5) $$

Now, consider the definition of the DFT of a pair of scalars (i.e. - a vector of length 2):
$$ DFT (x_0, x_1) = (x_0 + x_1, x_0 - x_1) \qquad (6) $$

Note that (5) and (6) have the same form: concatenation of the sum and difference of the inputs. The only difference between them is that (5) is summing/differencing vectors (or, some other data structure containing multiple scalars), while (6) is working on scalars. Now, typically, it's considered bad form, when writing Haskell, to separate the degenerate case (i.e. - (6)) from the more general definition (i.e. - (5)), if you don't explicitly need to do so. Let's try to use the discovered similarity of form, above, to clean up our RTree-based FFT implementation:
We'll start by defining a FFT class, so that the actual fft function can be overloaded and called recursively, in a polymorphic fashion. Then, we'll define an instance for Pair and try to use that instance to implement the RTree instance in a more general fashion:
First, we need some more explicit recreations of things that would, normally, be imported:
Skip these.
In [6]:
import Prelude hiding (zip,unzip,zipWith)
import Control.Arrow ((***), first)
import Data.Functor ((<$>))
import Data.Monoid ((<>), Product(..))
infixl 7 :*
type (:*) = (,)
type LScanTy f = forall a. Monoid a => f a -> f a :* a
class Functor f => LScan f where
lscan :: LScanTy f
-- Temporary hack to avoid newtype-like representation.
lscanDummy :: f a
lscanDummy = undefined
instance IsNat n => LScan (RTree n) where
lscan = lscan' nat
lscan' :: Monoid a => Nat n -> RTree n a -> (RTree n a, a)
lscan' Zero = \ (RL a) -> (RL mempty, a)
lscan' (Succ m) = \ (RB ts) -> first RB (lscanComp' lscan (lscan' m) ts) where
lscanComp' :: (Zippable g, Functor g, Functor f, Monoid a) => LScanTy g -> LScanTy f -> g (f a) -> g (f a) :* a
lscanComp' lscanG lscanF gfa = (zipWith adjustl tots' gfa', tot)
where (gfa' ,tots) = unzip (lscanF <$> gfa)
(tots',tot) = lscanG tots
adjustl :: (Monoid a, Functor t) => a -> t a -> t a
adjustl p = fmap (p <>)
instance LScan Pair where
lscan (a :# b) = (mempty :# a, a <> b)
lproducts :: (LScan f, Num b) => f b -> f b :* b
lproducts = (fmap getProduct *** getProduct) . lscan . fmap Product
class Functor f => Zippable f where
zipWith :: (a -> b -> c) -> f a -> f b -> f c
zipWith h as bs = uncurry h <$> zip as bs
zip :: f a -> f b -> f (a,b)
zip = zipWith (,)
{-# MINIMAL zip | zipWith #-}
instance Zippable Pair where
zipWith f (a :# b) (a' :# b') = f a a' :# f b b'
unzip :: Functor f => f (a :* b) -> f a :* f b
unzip ps = (fst <$> ps, snd <$> ps)
Okay, that's it for recreated imports; here's the improved FFT definition:
In [7]:
import Data.Foldable (foldl')
-- FFT, as a class
-- (The LScan constraint comes from the use of 'lproducts', in 'addPhase'.)
class (LScan f) => FFT f a where
fft :: f a -> f a -- Computes the FFT of a functor.
-- Note that this definition of the FFT instance for Pair assumes DIT.
-- How can we eliminate this assumption and make this more general?
instance (RealFloat a, Applicative f, Foldable f, Num (f (Complex a)), FFT f (Complex a))
=> FFT Pair (f (Complex a)) where
fft = inP (uncurry (+) &&& uncurry (-)) . secondP addPhase . fmap fft
instance (IsNat n, RealFloat a) => FFT (RTree n) (Complex a) where
fft = fft' nat
where fft' :: (RealFloat a) => Nat n -> RTree n (Complex a) -> RTree n (Complex a)
fft' Zero = id
fft' (Succ _) = inDIT fft
where inDIT g = r_toB . g . bottomSplit
-- Adds the proper phase adjustments to a functor containing Complex RealFloats,
-- and instancing Num.
addPhase :: (Applicative f, Foldable f, LScan f, RealFloat a, Num (f (Complex a)))
=> f (Complex a) -> f (Complex a)
addPhase = liftA2 (*) id phasor
where phasor f = fst $ lproducts (pure phaseDelta)
where phaseDelta = cis ((-pi) / fromIntegral n)
n = flen f
-- Gives the "length" (i.e. - number of elements in) of a Foldable.
-- (Soon, to be provided by the Foldable class, as "length".)
flen :: (Foldable f) => f a -> Int
flen = foldl' (flip ((+) . const 1)) 0
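Before dissecting the code, a couple of quick checks (additions of mine, not part of the original notebook; toList is used only for display). First, addPhase applied to a depth-1 tree of ones should return the twiddle factors themselves, $(1, -j)$, up to rounding; second, the class-based fft should agree with the reference dft:

import Data.Foldable (toList)

toList (addPhase (pure 1 :: RTree (S Z) (Complex Double)))

toList (fft (RB (RB (RL 0 :# RL 1) :# RB (RL 2 :# RL 3)) :: RTree (S (S Z)) (Complex Double)))
dft (map (:+ 0) [0, 1, 2, 3 :: Double])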
Finally, we're seeing the general pattern of the recursive logarithmic breakdown exposed in the code. Note, for instance, that we now have nothing in between r_toB and bottomSplit, except a recursive call to fft:
fft' (Succ _) = inDIT fft
  where inDIT g = r_toB . g . bottomSplit
This recursive call will, of course, use the FFT instance defined for Pair, since that's what bottomSplit produces. And, the first thing the Pair overload of fft does is apply fft to the two members of the Pair, via fmap:
fft = inP (uncurry (+) &&& uncurry (-)) . secondP addPhase . fmap fft
And so, by removing the explicit definition of the degenerate case (i.e. - Pair) from the definition of the general case (i.e. - RTree), and replacing it with an overloaded definition of the fft function, via the Haskell type class mechanism, we're making the "ping-pongy" nature of the logarithmic breakdown, which transforms a DFT into an FFT, more explicit in our code: we start by asking for the FFT of some data structure (RTree, in this case); then we split that structure into pairs and ask for the FFT of that pair, which in turn brings us back to asking for the FFT of the original structure, half as large. At some point, we'll find ourselves asking for the FFT of a pair of numbers, as opposed to a pair of structures, and the recursion will terminate. At that point, we'll begin unwinding the recursion stack and assembling the final answer.
We've come a long way from our original List-based implementation of the FFT. However, we still have not achieved our goal of a completely general description. Note, for instance, that we are having to explicitly define the inDIT function, which splits the original data structure into two pieces in just the right way, so as to effect an FFT. Note, furthermore, that each new functor for which we'd like to define an FFT instance will need its own unique definition of this function. If we did an instance for LTree, we'd need to find its alternatives to r_toB and bottomSplit, as those functions don't exist for LTree. Even more insidious, if more subtle: note that our FFT instance for Pair presumes that we're using decimation in time (DIT) to perform the logarithmic breakdown. What if we wanted to use the other choice: decimation in frequency (DIF)? In that case, we'd need two things:
This is all beginning to seem very messy. We must not be finished, quite yet.
Recall the definition of the DFT of a pair of numbers (as opposed to structures):
$$ DFT (x_0, x_1) = (x_0 + x_1, x_0 - x_1) \qquad (6) $$

Note, particularly, that it is identical to the definition of the FFT. This very special case (i.e. - pair of scalars) is completely degenerate and makes all distinctions between DFT and FFT, or DIT and DIF, moot. Might there be a way to use this fact in our quest for a completely generic expression for the FFT of any functor?
We bump into a funny problem, when we first try to use the above realization, in a naive straightforward fashion: we find ourselves wanting the fft function in the FFT class definition to have two different type signatures:
fft :: RealFloat a => f (Complex a) -> f (Complex a) -- for a tree
fft :: RealFloat a => f (g (Complex a)) -> f (g (Complex a)) -- for a pair of trees
We recognize that the second alone would suffice, if we set g equal to the Identity functor for the tree instance, but...
Conal's implementation of FFT for a completely general composition of functors is shown, below. As above, I'm explicitly defining the needed preliminaries, which would normally just be imported, separately from the actual FFT implementation, for clarity.
Note: The code, below, is copyright (c) 2015 Conal Elliott, and copied here with permission; please, read and understand Conal's licensing terms, BEFORE copying and/or using this code. Thanks!
In [8]:
import Prelude hiding (id,(.))
import Control.Category (Category(..))
import Control.Compose ((:.)(..),inO,unO)
-- The left associative equivalent of 'RTree'.
data LTree :: * -> * -> * where
LL :: a -> LTree Z a
LB :: LTree n (Pair a) -> LTree (S n) a
instance Functor (LTree n) where
fmap f (LL a ) = LL (f a)
fmap f (LB ts) = LB ((fmap.fmap) f ts)
instance IsNat n => Applicative (LTree n) where
pure = pure' nat
(<*>) = liftA2'' ($)
instance Foldable (LTree n) where
foldMap f (LL a ) = f a
foldMap f (LB ts) = (foldMap.foldMap) f ts
instance Traversable (LTree n) where
traverse f (LL a ) = LL <$> f a
traverse f (LB ts) = LB <$> (traverse.traverse) f ts
pure' :: Nat n -> a -> LTree n a
pure' Zero a = LL a
pure' (Succ n) a = LB (pure' n (pure a))
liftA2'' :: (a -> b -> c) -> LTree m a -> LTree m b -> LTree m c
liftA2'' f (LL a ) = \ (LL b ) -> LL (f a b)
liftA2'' f (LB as) = \ (LB bs) -> LB (liftA2'' (liftA2 f) as bs)
l_unL :: LTree Z a -> a
l_unL (LL a) = a
l_unB :: LTree (S n) a -> LTree n (Pair a)
l_unB (LB p) = p
-- Statically sized functors.
class Sized f where
size :: f () -> Int -- Argument is ignored at runtime
instance Sized Pair where size = const 2
instance (Sized g, Sized f) => Sized (g :. f) where
size = const ((size (undefined :: (g) ())) * (size (undefined :: (f) ())))
instance IsNat n => Sized (LTree n) where
size = const (twoNat (nat :: Nat n))
instance IsNat n => Sized (RTree n) where
size = const (twoNat (nat :: Nat n))
twoNat :: Integral m => Nat n -> m
twoNat n = 2 ^ (natToZ n :: Int)
Okay, that's it for recreated imports; here's the implementation of FFT for functor compositions:
In [9]:
type DFTTy f f' = forall a. RealFloat a => f (Complex a) -> f' (Complex a)
class FFT' f f' | f -> f' where
fft' :: DFTTy f f'
instance ( Applicative f , Traversable f , Traversable g
, Applicative f', Applicative g', Traversable g'
, FFT' f f', FFT' g g', LScan f, LScan g', Sized f, Sized g' ) =>
FFT' (g :. f) (f' :. g') where
-- fft' = inO (transpose . fmap fft' . twiddle . transpose . (fmap fft') . transpose)
-- This is equivalent and lends itself better to my explanation, below:
fft' = inO (transpose . fmap fft' . transpose . twiddle . (fmap fft') . transpose)
type AFS h = (Applicative h, Foldable h, Sized h, LScan h)
twiddle :: (AFS g, AFS f, RealFloat a) => Unop (g (f (Complex a)))
twiddle = (liftA2.liftA2) (*) twiddles
twiddles :: forall g f a. (AFS g, AFS f, RealFloat a) => g (f (Complex a))
twiddles = powers <$> powers (omega (size (undefined :: (g :. f) ())))
omega :: (Integral n, RealFloat a) => n -> Complex a
omega n = exp (- 2 * (0:+1) * pi / fromIntegral n)
-- Powers of x, starting x^0. Uses 'LScan' for log parallel time
powers :: (LScan f, Applicative f, Num a) => a -> f a
powers = fst . lproducts . pure
{--------------------------------------------------------------------
Specialized FFT instances
--------------------------------------------------------------------}
-- Radix 2 butterfly
instance FFT' Pair Pair where
fft' (a :# b) = (a + b) :# (a - b)
-- Handle trees by conversion to functor compositions.
instance IsNat n => FFT' (LTree n) (RTree n) where
fft' = fft'' nat
where
fft'' :: Nat m -> DFTTy (LTree m) (RTree m)
fft'' Zero = RL . l_unL
fft'' (Succ _) = RB . unO . fft' . O . l_unB
instance IsNat n => FFT' (RTree n) (LTree n) where
fft' = fft'' nat
where
fft'' :: Nat m -> DFTTy (RTree m) (LTree m)
fft'' Zero = LL . r_unL
fft'' (Succ _) = LB . unO . fft' . O . r_unB
res = dft $ map (:+ 0.0) [0.0 :: PrettyDouble, 1.0, 2.0, 3.0]
map MyComplex res
-- We define two simple *Show* instances, so that we can view intermediate results:
-- (I'm not using Conal's full blown Show instances, here,
-- only because they are rather lengthy and I don't need their full power.)
instance Show a => Show (LTree n a) where
show (LL a) = "(" ++ show a ++ ")"
show (LB ts) = concat (fmap show ts)
instance Show a => Show (RTree n a) where
show (RL a) = "(" ++ show a ++ ")"
show (RB ts) = concat (fmap show ts)
p1 = MyComplex ((0.0 :: PrettyDouble) :+ 0.0) :# MyComplex ((1.0 :: PrettyDouble) :+ 0.0)
p2 = MyComplex ((2.0 :: PrettyDouble) :+ 0.0) :# MyComplex ((3.0 :: PrettyDouble) :+ 0.0)
-- myLTree = LB (LB (LL ((((0.0 :: PrettyDouble) :+ 0.0) :# (1.0 :+ 0.0)) :# ((2.0 :+ 0.0) :# (3.0 :+ 0.0)))))
myLTree = LB (LB (LL (fmap complex p1 :# fmap complex p2)))
fmap MyComplex $ fft' myLTree
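As one more check (mine, not part of the original notebook), the functor composition instance can be exercised directly on a Pair-of-Pairs holding [0, 1, 2, 3]; read in order, the four outputs should match the dft result shown above:

fmap (fmap MyComplex) (unO (fft' (O ((0 :# 1) :# (2 :# 3)) :: (Pair :. Pair) (Complex PrettyDouble))))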
The first thing that stands out in Conal's new code is the type signatures of both the FFT class definition and its sole member function:
type DFTTy f f' = forall a. RealFloat a => f (Complex a) -> f' (Complex a)
class FFT' f f' | f -> f' where
fft' :: DFTTy f f'
Note that, despite the slight change of names here (i.e. - "fft" => "fft'" and "FFT" => "FFT'"), which was made only to avoid conflict with the previous definition of the FFT class in this notebook, Conal's actual code continues to use "fft" and "FFT". So, I will also use those labels, here.
Whereas, previously, we'd assumed that the fft function should map from a functor to the same functor, here we are allowing the enclosing functor to change. (It was, in fact, this insight, which broke the log jam referred to, above, wherein we were led to believe that we needed two different type signatures for the fft function.)
Next, consider the FFT instance definition for a general composition of functors:
instance ( Applicative f , Traversable f , Traversable g
, Applicative f', Applicative g', Traversable g'
, FFT' f f', FFT' g g', LScan f, LScan g', Sized f, Sized g' ) =>
FFT' (g :. f) (f' :. g') where
fft' = inO (transpose . fmap fft' . transpose . twiddle . fmap fft' . transpose)
Note the transposition of functors in the instance head (i.e. - from (g :. f) to (f' :. g')). Furthermore, note the odd number of calls to the transpose function, in the definition of fft. Finally, recall that the first step in breaking down the DFT computation is a "de-interleaving" process, whereby we separate the even and odd members of the original data set. Note that in completing the DFT => FFT conversion we never undo this de-interleaving. The transpose function, above, provides this de-interleaving (at least, for the DIT case) in a generic way, assuming very little about the nature of the enclosing functor. (It must just be a Traversable.) The fact that we don't want to undo this de-interleaving is consistent with an odd number of calls to transpose, in the definition of fft. However, such an application of an odd number of calls to the transpose function necessarily leaves a functor composition in an "inverted" state, where the outer functor has become the inner and vice versa. And, so, we see that Conal's insight, above (i.e. - giving both the class and its sole member function the freedom to transform the enclosing functor), was absolutely necessary, in order to express the FFT of a functor composition in its most natural and elegant form.
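As a tiny illustration of that point (my own, not part of the original notebook), here is transpose acting as the de-interleaver on a pair of pairs holding [0, 1, 2, 3] in order; the evens come back gathered in one half and the odds in the other:

transpose ((0 :# 1) :# (2 :# 3)) :: Pair (Pair Int)
-- (0 :# 2) :# (1 :# 3)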
Next, consider Conal's FFT instance for Pair:
instance FFT' Pair Pair where
fft' (a :# b) = (a + b) :# (a - b)
Note that he has chosen to express the degenerate form, which completely eliminates any distinction between DIT and DIF decomposition (as well as DFT vs. FFT, for that matter). (Recall our brief aside, above, on Fourier transforming pairs.) In this way, he has given us a completely general purpose building block for constructing the FFT instances for more complicated structures. However, at first glance, he also appears to have ducked the responsibility for properly applying the "phase twiddles" to the second member of the incoming pair, in the case when the pair contains structures, as opposed to scalar values. Or, has he...
Let's consider one of Conal's implementations of the FFT instance for trees, which makes use of his functor composition instance:
instance IsNat n => FFT' (LTree n) (RTree n) where
fft' = fft'' nat
where
fft'' :: Nat m -> DFTTy (LTree m) (RTree m)
fft'' Zero = RL . l_unL
fft'' (Succ _) = RB . unO . fft' . O . l_unB
Note first that, unlike all previous tree instances in this notebook, this instance morphs a LTree into a RTree. We'll see why this is a very natural thing to do, shortly.
(In the discussion, below, I abandon the previous convention of referring to fft' in actual code snippets as fft in the accompanying explanatory verbiage, and use the actual names of functions as they appear in the code, as an aid to the reader trying to precisely correlate my verbal explanation with those code snippets.)
Next, note the indirection provided by fft''. This is just a syntactical convenience, which obviates the need for the user of our library to take care of passing in the correct tree size. Instead, we use the underlying Haskell dictionary mechanism to infer the proper size automatically.
The leaf case (i.e. - fft'' Zero) isn't very interesting. It's essentially just id with the proper functor unwrapping/re-wrapping machinery needed, in order to satisfy the type signature.
The branch case (i.e. - fft'' (Succ _)) is the truly interesting one, and is where Conal's magic is really being performed. Let's take it a step at a time, keeping in mind our goal of finding the apparently missing phase twiddles. We'll do this by constructing an LTree value, carefully sculpted to make tracking the data flow through the computation particularly easy, and applying the fft'' function to it, one step at a time.
First, we'll expand the definition of fft'', above, using the appropriate definition of fft', namely that taken from the functor composition instance (because, that's what O produces). In performing this expansion, I have omitted the triple: inO / O / unO, since their only real function is to force the selection of the proper fft' overload in the code, immediately above:
fft'' (Succ _) = RB . transpose . fmap fft' . transpose . twiddle . (fmap fft') . transpose . l_unB
I define an LTree in which the values are equal to the position indices of an equivalent list/vector, because it makes tracking the data flow through the computation easier:
(The values are only made complex, here, in order to avoid a type matching error, otherwise, further down.)
In [10]:
:t myLTree
fmap MyComplex myLTree
And, finally, apply the first step in the computation:
In [11]:
res = l_unB myLTree
:t res
fmap (fmap MyComplex) res
We see that, as expected, the application of l_unB has transformed a tree of values having depth n + 1 into a tree of pairs having depth n.
Continuing on:
In [12]:
res2 = transpose res
:t res2
fmap (fmap MyComplex) res2
Again as expected, applying transpose has given us the de-interleaving (i.e. - separation of even and odd values) required by (5).
(Note that, in the process, it has converted a tree of pairs into a pair of trees.)
In [13]:
res3 = fmap fft' res2
:t res3
fmap (fmap MyComplex) res3
We see that mapping fft' over the enclosing Pair has had the expected effect, yielding new trees consisting of the sum and difference of the elements of the original tree, in both cases. (Note the inner functor transformation, from LTree to RTree.) (The previous sentence may be confusing, because, until now, we've discussed the DFT/FFT of a pair as being composed of the sum and difference of the original elements. However, it is true that the DFT/FFT of any two element structure is composed of two elements, which are the sum and difference of the original elements, respectively.)
In [14]:
res4 = twiddle res3
:t res4
fmap (fmap MyComplex) res4
The twiddling has had the expected effect: leaving the first element of the tree unchanged and multiplying the second element, in point-by-point fashion, with the pair: $(1 :\# -j)$.
Now, if you're like me, the next step seems out of place, when considering (5) in which there is just one transposition (i.e. - deinterleaving).
In [15]:
res5 = transpose res4
:t res5
fmap (fmap MyComplex) res5
We see that the transposition has had the expected effect, but why was it necessary?
Consider the following:
So, we see the motivation for transposing: it makes the required summing/differencing achievable, via an additional recursive call to fft'. But, doesn't the arithmetic get screwed up, by this swapping of components before adding/subtracting? As it turns out, no; observe:
We want:
$$ RTree ([w, x] + [y, z]) \quad \text{:#} \quad RTree ([w, x] - [y, z]) $$

$$ = RTree \; ([w + y, \; x + z]) \quad \text{:#} \quad RTree \; ([w - y, \; x - z]) $$

We're doing:
$$ transpose \; ( \; fmap \; fft' \; (transpose \; (RTree \; [w, x] \text{ :# } RTree \; [y, z]))) $$

$$ = transpose \; (fmap \; fft' \; RTree \; ((w \text{ :# } y), (x \text{ :# } z))) $$

$$ = transpose \; RTree \; ([w + y, w - y] \text{ :# } [x + z, x - z]) $$

$$ = RTree \; ([w + y, \; x + z]) \quad \text{:#} \quad RTree \; ([w - y, \; x - z]) $$

And we see that we are getting the desired result.
Let's do a sanity check of all this, with the actual code:
(Note that the final application of RB just changes the structure from a pair of trees to a tree of increased depth; it doesn't alter any values or their positions.)
In [16]:
res6 = RB . transpose . fmap fft' $ res5
:t res6
fmap MyComplex res6
We test for the correct answer, using our reference function: dft:
In [17]:
dft $ map (:+ 0) [0 :: PrettyDouble, 1, 2, 3]
So, we're getting the correct answer, despite the code seeming inconsistent with the algebra of (5). But, why not just make use of the Num instance defined for RTree, which does exactly what we want: applies (+) and (-) in an elementwise fashion? Presumably, if we did so, we could eliminate all of this transposing and re-transposing. That is, instead of:
transpose . fmap fft' . transpose
why not just:
fft'
Let's try it, and see if we get the right answer:
In [18]:
res7 = RB . fft' $ res4
:t res7
fmap MyComplex res7
What went wrong? Let's find out:
In [19]:
:t RB
So, RB expects to consume a pair of trees and produce a single tree of increased depth.
fft' is overloaded. So, its type depends upon its argument. Let's see what that is:
In [20]:
:t res4
Okay, it's a Pair and the Pair instance of FFT' defines fft' to have type:
fft' :: Pair a -> Pair a
Now, fft' is expected (by RB) to produce: Pair (RTree n a).
So, it'd better be consuming: Pair (RTree n a), as well (since it produces the same type as it consumes). It's consuming res4, whose type is printed, above. And, it seems we should have type consistency, with:
n = S Z
a = Complex PrettyDouble
So, why did we get a type error? In particular, in the second error block, above, where is the expectation of type: Pair (Complex Double) -> Pair (RTree n a) coming from?
A little more poking reveals the problem:
In [21]:
:t RB . fft'
In [22]:
:t fft'
And, finally, we see that the problem is we're trying to apply fft' directly to a nested functor. Its type signature doesn't allow that! I alluded to this, above, when I mentioned the "funny problem" we bump into, when we try to apply the overloaded fft' function to both pairs of scalars and pairs of structures. Now, you've seen this funny problem in all its gory detail.
So, we must map the fft' function over the enclosing functor, in order to avoid type errors, and since we must map, we must also transpose twice, in order to get the correct answer.
So, there you have it; the power and elegance of functor composition brought to bear on the problem of expressing the FFT in a completely generic way, à la Conal Elliott. If you're reading this, thank you! for your interest in our little project. And, if you think you might be interested in working on it with us, please, let us know; we'd love to hear from you. :)
Lastly, I want to express my gratitude to Conal for having taken me under his tutelage, as I struggle to understand these higher level programming concepts and constructs, and for his loving patience with my rather pedantic mode of learning. Thanks, Conal!
Just before switching gears to Conal's functor composition approach, we mentioned that we might be able to solve our fft type problem, by defining it to operate on a nested pair of functors and simply setting the inner functor to Identity when appropriate. Let's investigate that a bit further and see where it leads. We'll start by following Conal's lead and using the degenerate case for the Pair instance, so as not to force a DIT implementation, and see if we can make that work.
In [28]:
import Data.Functor.Identity (Identity)
class (LScan f) => FFT f g a where
fft :: f (g a) -> f (g a)
-- instance (Num (g a)) => FFT Pair (g a) where
instance (Num (g a)) => FFT Pair g a where
fft = inP (uncurry (+) &&& uncurry (-))
instance (IsNat n, RealFloat a) => FFT (RTree n) Identity (Complex a) where
fft = fft' nat
-- where fft' :: (RealFloat a) => Nat n -> RTree n (Complex a) -> RTree n (Complex a)
where fft' :: (RealFloat a) => Nat n -> RTree n (Identity (Complex a)) -> RTree n (Identity (Complex a))
fft' Zero = id
fft' (Succ _) = inDIT fft
where inDIT g = r_toB . g . bottomSplit
If we want to go any further down this path, we'd have to define a Num instance for (Identity (Complex a)), which seems unnecessarily laborious and most inelegant. In fact, there's a more insidious problem lying in wait: if we ask for the FFT of a Pair of anything other than scalars, we're going to get the wrong answer, because the FFT instance defined for Pair, above, is only going to do the summing and differencing. That's fine for scalars, but incorrect for anything else, which requires a recursive call to fft, followed by "twiddling" of the second result.
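To make that objection concrete, here (as a sketch of my own, not code from the notebook) is roughly the boilerplate that would be required, assuming no such instance is already supplied by the installed version of base:

import Data.Functor.Identity (Identity(..))

-- Elementwise Num, lifted through Identity. (If your version of base already
-- provides this instance, the declaration below would be rejected as a duplicate.)
instance Num a => Num (Identity a) where
  fromInteger = Identity . fromInteger
  negate      = fmap negate
  abs         = fmap abs
  signum      = fmap signum
  (+)         = liftA2 (+)
  (-)         = liftA2 (-)
  (*)         = liftA2 (*)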
So, we quickly find that Conal's Functor Composition approach to this problem is much more than a clever trick to get around a problematic type error. It is, in fact, a fundamentally more powerful approach, which allows us the freedom to define the FFT instance for Pair degenerately, while still yielding the correct answer for any input. And, it is precisely the mechanics of functor composition, which allows us this freedom and convenience.