This IHaskell Jupyter notebook contains my attempts at the exercises posed at the end of "Practical Dependent Types in Haskell: Type-Safe Neural Networks (Part 1)" by Justin Le.
Original author: David Banas <capn.freako@gmail.com>
Original date: January 10, 2018
Copyright © 2018 David Banas; all rights reserved worldwide.
In [1]:
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
import Control.Monad.Random
import Data.Singletons.Prelude
import Data.Singletons.TypeLits
import GHC.TypeLits
import Numeric.LinearAlgebra.Static
data Weights i o = W { wBiases :: !(R o)
                     , wNodes  :: !(L o i)
                     }  -- an "o x i" layer
data Network :: Nat -> [Nat] -> Nat -> * where
    O     :: !(Weights i o)
          -> Network i '[] o
    (:&~) :: KnownNat h
          => !(Weights i h)
          -> !(Network h hs o)
          -> Network i (h ': hs) o
infixr 5 :&~
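randomWeights :: (MonadRandom m, KnownNat i, KnownNat o)
              => m (Weights i o)
randomWeights = do
    -- one independent seed for the bias vector and one for the matrix
    s1 :: Int <- getRandom
    s2 :: Int <- getRandom
    let wB = randomVector  s1 Uniform * 2 - 1  -- rescale [0, 1) samples to [-1, 1)
        wN = uniformSample s2 (-1) 1           -- matrix entries between -1 and 1
    return $ W wB wN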
randomNet :: forall m i hs o. (MonadRandom m, KnownNat i, SingI hs, KnownNat o)
          => m (Network i hs o)
randomNet = go sing
  where
    -- Recurse on the singleton for the hidden-layer list:
    -- SNil ends the network, SCons adds one more hidden layer.
    go :: forall h hs'. KnownNat h
       => Sing hs'
       -> m (Network h hs' o)
    go = \case
        SNil            -> O <$> randomWeights
        SNat `SCons` ss -> (:&~) <$> randomWeights <*> go ss
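randomNet decides how many layers to build by pattern-matching on the singleton for hs that SingI supplies. The sketch below mimics that mechanism using only base, so the structure the types force can be checked directly. Shape, Net, SList, KnownList and buildNet are hypothetical stand-ins for Weights, Network, Sing, SingI and randomNet, and each layer stores only its (inputs, outputs) sizes instead of random weights.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables, FlexibleInstances #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical stand-in for Weights i o: records only the layer's shape.
newtype Shape (i :: Nat) (o :: Nat) = Shape (Integer, Integer)

shape :: forall i o. (KnownNat i, KnownNat o) => Shape i o
shape = Shape (natVal (Proxy :: Proxy i), natVal (Proxy :: Proxy o))

-- Same indexing scheme as the notebook's Network.
data Net :: Nat -> [Nat] -> Nat -> Type where
  Out  :: Shape i o -> Net i '[] o
  (:&) :: KnownNat h => Shape i h -> Net h hs o -> Net i (h ': hs) o
infixr 5 :&

-- Hypothetical hand-rolled singleton for a type-level list of Nats,
-- playing the role of Sing hs.
data SList :: [Nat] -> Type where
  SNil'  :: SList '[]
  SCons' :: KnownNat h => Proxy h -> SList hs -> SList (h ': hs)

-- Supplies the singleton implicitly, playing the role of SingI.
class KnownList (hs :: [Nat]) where
  listSing :: SList hs

instance KnownList '[] where
  listSing = SNil'

instance (KnownNat h, KnownList hs) => KnownList (h ': hs) where
  listSing = SCons' Proxy listSing

-- Analogue of randomNet: the singleton decides how many layers to build.
buildNet :: forall i hs o. (KnownNat i, KnownList hs, KnownNat o) => Net i hs o
buildNet = go listSing
  where
    go :: forall h hs'. KnownNat h => SList hs' -> Net h hs' o
    go SNil'         = Out shape
    go (SCons' _ ss) = shape :& go ss

-- Read each layer's (inputs, outputs) sizes back off the network.
shapes :: Net i hs o -> [(Integer, Integer)]
shapes (Out (Shape s))   = [s]
shapes (Shape s :& rest) = s : shapes rest
```

For example, shapes (buildNet :: Net 5 '[3, 2] 1) evaluates to [(5,3),(3,2),(2,1)]. The single type annotation fixes both the depth and every layer's dimensions.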
In [2]:
pop :: (KnownNat i, KnownNat o, KnownNat h) => Network i (h ': hs) o -> (Weights i h, Network h hs o)
pop (w :&~ n) = (w, n)
Think about what its type would have to be. Could it possibly be called with a network that cannot be popped (that is, one that has only one layer)?
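No. The type signature requires the hidden-layer list to have the form (h ': hs), whereas a single-layer network such as O w has type Network i '[] o. Since '[] does not unify with (h ': hs), GHC rejects such a call at compile time.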
Let's confirm this...
In [3]:
-- Expected to fail: O w2 :: Network 3 '[] 1 has no hidden layer to pop.
r = do
    (w1 :: Weights 5 3) <- randomWeights
    (w2 :: Weights 3 1) <- randomWeights
    let r1 = pop $ O w2
    return r1
Good: the error was caught at compile time, as expected. Now let's make sure a correct case goes through...
In [4]:
r = do
    (w1 :: Weights 5 3) <- randomWeights
    (w2 :: Weights 3 1) <- randomWeights
    let r1 = pop $ w1 :&~ O w2
    return r1
Okay, looks good.
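The same compile-time argument can be reproduced without hmatrix. In this sketch, Layer and Net are hypothetical toy versions of Weights and Network (layers are just labels). pop keeps the notebook's type shape but drops the KnownNat constraints, which the toy version does not need.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}
import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Hypothetical toy layer: a label standing in for Weights i o.
newtype Layer (i :: Nat) (o :: Nat) = Layer String

data Net :: Nat -> [Nat] -> Nat -> Type where
  Out  :: Layer i o -> Net i '[] o
  (:&) :: Layer i h -> Net h hs o -> Net i (h ': hs) o
infixr 5 :&

-- The (h ': hs) index admits only networks with at least one hidden layer,
-- so no equation for the single-layer Out case is needed.
pop :: Net i (h ': hs) o -> (Layer i h, Net h hs o)
pop (w :& rest) = (w, rest)

-- Collect the layer labels, input side first.
labels :: Net i hs o -> [String]
labels (Out (Layer s))   = [s]
labels (Layer s :& rest) = s : labels rest

-- Uncommenting this should be a compile-time error, because
-- Net i '[] o does not unify with Net i (h ': hs) o:
-- bad = pop (Out (Layer "only"))
```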
In [5]:
addW :: (KnownNat i, KnownNat o)
     => Weights i o
     -> Weights i o
     -> Weights i o
addW (W b1 w1) (W b2 w2) = W (b1 + b2) (w1 + w2)
addN :: (KnownNat i, KnownNat o)
     => Network i hs o
     -> Network i hs o
     -> Network i hs o
addN (O w1)       (O w2)       = O (addW w1 w2)
addN (w1 :&~ n1) (w2 :&~ n2) = addW w1 w2 :&~ addN n1 n2
Could this function ever be accidentally called on two networks that have different internal structures?
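I don't think so. Both arguments share not only i and o but also hs, the type-level list of hidden-layer widths, so two networks with different depths or different hidden-layer widths have different types and cannot be passed to addN together. Let's confirm...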
In [6]:
-- Test different network depths.
r = do
    (w1 :: Weights 5 3) <- randomWeights
    (w2 :: Weights 3 1) <- randomWeights
    let r1 = addN (w1 :&~ O w2) (O w1)
    return r1
Good, we expected failure.
In [7]:
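-- Test different network widths: both networks map 5 inputs to 1 output,
-- but their hidden layers have widths 3 and 4, respectively.
r = do
    (w1 :: Weights 5 3) <- randomWeights
    (w2 :: Weights 3 1) <- randomWeights
    (w3 :: Weights 5 4) <- randomWeights
    (w4 :: Weights 4 1) <- randomWeights
    let r1 = addN (w1 :&~ O w2) (w3 :&~ O w4)
    return r1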
Good, we expected failure.
In [8]:
-- Test a correct case.
r = do
    (w1 :: Weights 5 3) <- randomWeights
    (w2 :: Weights 3 1) <- randomWeights
    let r1 = addN (w1 :&~ O w2) (w1 :&~ O w2)
    return r1
Good, we expected success.
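What makes addN safe is that both arguments share the same hs index, not just i and o. The sketch below isolates that in a hypothetical toy model: Layer holds a list of Doubles in place of Weights, and addNet stands in for addN, so it compiles with base alone.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}
import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Hypothetical toy layer: a flat list of numbers standing in for the bias
-- vector and weight matrix of Weights i o. (Unlike R and L, the list length
-- is not tied to i and o here.)
newtype Layer (i :: Nat) (o :: Nat) = Layer [Double]

data Net :: Nat -> [Nat] -> Nat -> Type where
  Out  :: Layer i o -> Net i '[] o
  (:&) :: Layer i h -> Net h hs o -> Net i (h ': hs) o
infixr 5 :&

addL :: Layer i o -> Layer i o -> Layer i o
addL (Layer a) (Layer b) = Layer (zipWith (+) a b)

-- Both arguments share i, hs and o, so Out is never matched against (:&):
-- the two equations cover every well-typed call.
addNet :: Net i hs o -> Net i hs o -> Net i hs o
addNet (Out a)   (Out b)   = Out (addL a b)
addNet (a :& n1) (b :& n2) = addL a b :& addNet n1 n2

-- Flatten a network into its per-layer number lists, input side first.
toLists :: Net i hs o -> [[Double]]
toLists (Out (Layer xs))   = [xs]
toLists (Layer xs :& rest) = xs : toLists rest

-- Rejected at compile time: hidden widths 3 and 4 give '[3] vs '[4].
-- bad = addNet (Layer [] :& Out (Layer []) :: Net 5 '[3] 1)
--              (Layer [] :& Out (Layer []) :: Net 5 '[4] 1)
```

Adding Layer [1, 2] :& Out (Layer [3]) to Layer [10, 20] :& Out (Layer [30]), both typed Net 5 '[3] 1, yields the layers [11, 22] and [33].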
In [12]:
hiddenSing :: (SingI hs) => Network i hs o -> Sing hs
hiddenSing (n :: Network i hs o) = sing :: Sing hs
r = do
    net <- randomNet
    return $ hiddenSing (net :: Network 5 '[3] 1)
:t r
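The hiddenSing above needs a SingI hs constraint, so it only works when hs is already known to the caller, as it is here via the annotation Network 5 '[3] 1. Because each :&~ node stores a KnownNat h dictionary, the singleton can instead be rebuilt by recursing on the network itself. The sketch below does this in a self-contained toy model. Layer, Net and SList are hypothetical stand-ins (not the hmatrix or singletons types), so the block runs with base alone. With the notebook's own definitions, the same recursion would presumably read hiddenSing (O _) = SNil and hiddenSing (_ :&~ n') = SCons SNat (hiddenSing n').

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical toy layer that carries no data, only its shape.
data Layer (i :: Nat) (o :: Nat) = Layer

-- As in the notebook's Network, (:&) stores a KnownNat h dictionary.
data Net :: Nat -> [Nat] -> Nat -> Type where
  Out  :: Layer i o -> Net i '[] o
  (:&) :: KnownNat h => Layer i h -> Net h hs o -> Net i (h ': hs) o
infixr 5 :&

-- Hypothetical hand-rolled stand-in for Sing hs.
data SList :: [Nat] -> Type where
  SNil'  :: SList '[]
  SCons' :: KnownNat h => Proxy h -> SList hs -> SList (h ': hs)

-- No SingI-style constraint: the singleton is rebuilt from the network,
-- using the KnownNat h that each (:&) node carries.
hiddenSing :: Net i hs o -> SList hs
hiddenSing (Out _)     = SNil'
hiddenSing (_ :& rest) = SCons' Proxy (hiddenSing rest)

-- Demote the singleton to an ordinary list of hidden-layer widths.
fromSList :: SList hs -> [Integer]
fromSList SNil'           = []
fromSList (SCons' p rest) = natVal p : fromSList rest
```

Since hiddenSing never asks for the singleton from outside, it also works on a network whose hidden-layer list is hidden behind an existential type, where no SingI instance would be available.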