The basics of elliptic curve cryptography

This is an exercise in modular arithmetic in Haskell. It was inspired by some examples written in Python from a video about the basics of elliptic curve cryptography. Code samples from the slides can be found here. The Haskell implementation is quite raw and needs improvement, which will come as I gradually get to grips with monads and applicatives.

Modular arithmetic

We need a couple of compiler extensions in order to give the congruence class type (M, defined below) a familiar-looking printed form. The first, FlexibleInstances, allows an instance head in which a type constructor is applied to a concrete type, such as Show (Maybe M), rather than to a type variable. The second, OverlappingInstances, tells the compiler not to complain that our more specific Show (Maybe M) instance overlaps the existing, more general Show (Maybe a) instance; the more specific one is then chosen for Maybe M.


In [1]:
{-# LANGUAGE FlexibleInstances, OverlappingInstances #-}

In the examples from the linked video, the presenters used a fixed value of the modulus $n$ in each part of the discussion. This is natural, since the modulus must be standardized for the purposes of creating keys. We account for the more general situation to make the code more generic and reusable, and must therefore safeguard ourselves against undefined operations between congruence classes with different values of $n$.


In [2]:
import Control.Monad
import Control.Applicative

Now we declare our type for congruence classes. We keep the name short, as we will have to type it often. The first argument is the modulus (the divisor) and the second is the integer to which the modulo operation is applied; in other words, M x y means $y \operatorname{mod} x$. Essentially, we are going to implement a generalised version of the F7 class from the lecture:

class F7():
    def __init__(self,x):
        self.int = x % 7
    def __str__(self):
        return str(self.int)
    def __repr__(self):
        return self.__str__()
    def __eq__(self, other):
        return self.int == other.int
    def __add__(self, other):
        return F7(self.int + other.int)
    def __sub__(self, other):
        return F7(self.int - other.int)
    def __mul__(self, other):
        return F7(self.int * other.int)

First, we define how the values will be printed:


In [3]:
data M = M Int Int
instance Show M where show (M a b) = show b ++ " mod " ++ show a
instance Show (Maybe M) where show (Just a) = show a
                              show _ = "undefined"
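Since GHC 7.10 the OverlappingInstances extension is deprecated in favour of per-instance pragmas. A minimal standalone sketch of the same two instances written that way (M is re-declared here only so that the block compiles on its own; only FlexibleInstances is still needed):

```haskell
{-# LANGUAGE FlexibleInstances #-}
-- Standalone sketch: the Show instances above, with a per-instance
-- OVERLAPPING pragma instead of the module-wide OverlappingInstances.

-- M a b stands for b mod a, as in the notebook
data M = M Int Int

instance Show M where
  show (M a b) = show b ++ " mod " ++ show a

-- Preferred over the general Show (Maybe a) instance whenever the
-- element type is exactly M; other Maybe types keep the default output.
instance {-# OVERLAPPING #-} Show (Maybe M) where
  show (Just m) = show m
  show Nothing  = "undefined"
```

The pragma documents the overlap right where it happens, instead of enabling it for every instance in the module.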

We will also need to compare values in order to verify our implementation:


In [4]:
-- Equal when the moduli agree and the representatives are congruent modulo them
instance Eq M where (==) (M a b) (M c d) = a == c && b `mod` a == d `mod` c

In order to implement the equivalent of the F7 class from the lecture (whose instances are the congruence classes $x \operatorname{mod} 7$) as a type of its own, we would need to apply our type constructor to the literal value 7 at the type level. Plain Haskell does not allow this without fairly involved workarounds; the compiler suggests enabling -XDataKinds. We choose not to follow that route and instead define a conversion function, which is simply the data constructor M: partially applying it to a modulus gives a function from integers to congruence classes.


In [5]:
fromInt :: Int -> Int -> M
fromInt = M
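For comparison, a minimal standalone sketch of the DataKinds route, assuming GHC's GHC.TypeLits module. Mod, mkMod, modulus, addMod, subMod and mulMod are hypothetical names, not part of this notebook. The modulus becomes part of the type, so combining classes with different moduli is rejected at compile time rather than producing Nothing at run time:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}
-- Standalone sketch of the DataKinds route (hypothetical names throughout)

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- A congruence class whose modulus n is part of its type, e.g. Mod 7
newtype Mod (n :: Nat) = Mod Integer
  deriving Eq

-- Recover the modulus as an ordinary value from the type
modulus :: forall n. KnownNat n => Mod n -> Integer
modulus _ = natVal (Proxy :: Proxy n)

-- Smart constructor: always store the reduced representative
mkMod :: forall n. KnownNat n => Integer -> Mod n
mkMod x = Mod (x `mod` natVal (Proxy :: Proxy n))

instance KnownNat n => Show (Mod n) where
  show m@(Mod x) = show x ++ " mod " ++ show (modulus m)

-- Both operands share the type Mod n, hence the same modulus
addMod, subMod, mulMod :: KnownNat n => Mod n -> Mod n -> Mod n
addMod (Mod a) (Mod b) = mkMod (a + b)
subMod (Mod a) (Mod b) = mkMod (a - b)
mulMod (Mod a) (Mod b) = mkMod (a * b)

-- The analogue of the lecture's F7 class is just a type synonym
type F7 = Mod 7
```

For example, addMod (mkMod 1 :: Mod 7) (mkMod 3 :: Mod 5) does not type-check at all. The trade-off is the extra type-level machinery, which this notebook deliberately avoids.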

Now we implement the safe operations on congruence classes, which return a result (wrapped in Just) only if the moduli match, and Nothing otherwise. Addition, subtraction and multiplication are easy to do:


In [6]:
type BinaryOp = Int -> Int -> Int
safeOp :: BinaryOp -> M -> M -> Maybe M
safeOp op (M a b) (M c d)
    | a == c    = Just (M a ((b `op` d) `mod` a))
    | otherwise = Nothing

In [7]:
(#+), (#-), (#*) :: M -> M -> Maybe M
(#+) = safeOp (+)
(#-) = safeOp (-)
(#*) = safeOp (*)

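Division is trickier, because the result does not always exist even when the numerator and denominator belong to the same ring $\mathbb{Z}/n\mathbb{Z}$: a class $d \operatorname{mod} n$ has a multiplicative inverse only when $\gcd(d, n) = 1$, which holds for every nonzero class only when $n$ is prime. We use an implementation of the extended Euclidean algorithm to establish whether the modular inverse exists and to calculate it if it does: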


In [8]:
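-- extendedEuclid a b returns (g, t, s) such that s*a + t*b == g,
-- where g = gcd a b for non-negative a and b
extendedEuclid :: Int -> Int -> (Int, Int, Int)
extendedEuclid a b = extendedEuclid' 1 0 0 1 b a

-- Arguments are t oldT s oldS r oldR; every step preserves
-- oldR == oldS*a + oldT*b and r == s*a + t*b
extendedEuclid' :: Int -> Int -> Int -> Int -> Int -> Int -> (Int, Int, Int)
extendedEuclid' _ oldT _ oldS 0 oldR = (oldR, oldT, oldS)
extendedEuclid' t oldT s oldS r oldR = extendedEuclid' t' oldT' s' oldS' r' oldR'
  where quotient = oldR `div` r
        t'       = oldT - quotient * t
        oldT'    = t
        s'       = oldS - quotient * s
        oldS'    = s
        r'       = oldR - quotient * r
        oldR'    = r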

In [9]:
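-- Division: multiply b by the inverse of d modulo a, when that inverse exists
(#/) :: M -> M -> Maybe M
(#/) (M a b) (M c d)
    | a /= c    = Nothing
    | otherwise = case g of
                    1 -> Just (M a (e * b `mod` a))  -- e * d = 1 (mod a)
                    _ -> Nothing                     -- gcd d a /= 1: no inverse
  where (g, _, e) = extendedEuclid (d `mod` a) a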
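In symbols: for a divisor $d$ and modulus $a$, the extended Euclidean algorithm produces $(g, t, s)$ with $s\,d + t\,a = g = \gcd(d, a)$. When $g = 1$, reducing modulo $a$ gives $s\,d \equiv 1$, so $s$ is the inverse of $d$ and $b / d \equiv b\,s$. For the modulus 6 used in the check below:

$$
\begin{aligned}
(-1)\cdot 5 + 1\cdot 6 = 1 \;&\Rightarrow\; 5^{-1} \equiv -1 \equiv 5 \pmod 6, \qquad 4/5 \equiv 4\cdot(-1) \equiv 2 \pmod 6,\\
\gcd(2, 6) = 2 \;&\Rightarrow\; 2^{-1} \text{ does not exist; indeed } 2\cdot 2 \equiv 2\cdot 5 \equiv 4 \pmod 6 .
\end{aligned}
$$

The second line shows why division has to be allowed to fail: $4/2$ has no unique answer modulo 6.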

Check that the results make sense:


In [10]:
x = M 6 5
y = M 6 4
z = M 6 2
y #/ x
x #* z == Just y


2 mod 6
True

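We will also need to perform arithmetic on the results of the "safe" operations defined above, i.e. on values of type Maybe M. For now we simply redefine the arithmetic operators, for the rest of the notebook session, as the safe operations lifted into Maybe with liftM2. Because each safe operation already returns a Maybe M, the lifted operators return a nested Maybe (Maybe M), which is why the code below has to flatten results with join. This is the part that could probably be improved a lot, which in turn would simplify some of the subsequent code. For now, however, it is enough to demonstrate the results we are after in this document.


In [11]:
-- The safe operations lifted to Maybe operands. In the notebook session these
-- definitions shadow the Prelude operators from here on (in a standalone module
-- they would clash with the Prelude ones). Results are nested, hence join later.
(+), (-), (*), (/) :: Maybe M -> Maybe M -> Maybe (Maybe M)
(+) = liftM2 (#+)
(-) = liftM2 (#-)
(*) = liftM2 (#*)
(/) = liftM2 (#/)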

Now we can define $\mathbb{Z}/7\mathbb{Z}$ (the equivalent of the F7 class from the presentation) and $\mathbb{Z}/5\mathbb{Z}$ as follows:


In [12]:
f7 = fromInt 7
f5 = fromInt 5

Next, we instantiate some congruence classes:


In [13]:
a = f7 1
b = f7 6
c = f7 19
d = f5 41
e = f5 38

Finally, we demonstrate our arithmetic and its printed representation at work:


In [14]:
show a ++ " + " ++ show c ++ " = " ++ show (a #+ c)
show b ++ " * " ++ show c ++ " = " ++ show (b #* c)
show d ++ " - " ++ show e ++ " = " ++ show (d #- e)
show a ++ " + " ++ show e ++ " = " ++ show (a #+ e)


"1 mod 7 + 19 mod 7 = 6 mod 7"
"6 mod 7 * 19 mod 7 = 2 mod 7"
"41 mod 5 - 38 mod 5 = 3 mod 5"
"1 mod 7 + 38 mod 5 = undefined"

Curve point arithmetic

Now we can implement addition of points on the "clock", i.e. the circle $x^2 + y^2 = 1$ in $\mathbb{Z}/n\mathbb{Z}\times \mathbb{Z}/n\mathbb{Z}$, and on the Edwards elliptic curve $x^2 + y^2 = 1 + d x^2 y^2$, equivalent to the following Python implementation:

def clockadd(P1,P2):
    x1,y1 = P1
    x2,y2 = P2
    x3 = x1*y2+y1*x2
    y3 = y1*y2-x1*x2
    return x3,y3


def edwardsadd(P1,P2):
    x1,y1 = P1
    x2,y2 = P2
    x3 = (x1*y2+y1*x2)/(one+d*x1*y1*x2*y2)
    y3 = (y1*y2-x1*x2)/(one-d*x1*y1*x2*y2)
    return x3,y3
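The clock formula is the angle-addition rule in disguise. Writing a point of the real circle as $(x, y) = (\sin\alpha, \cos\alpha)$, so that $(0, 1)$ is the neutral element at angle 0:

$$
\begin{aligned}
\sin(\alpha+\beta) &= \sin\alpha\cos\beta + \cos\alpha\sin\beta = x_1 y_2 + y_1 x_2,\\
\cos(\alpha+\beta) &= \cos\alpha\cos\beta - \sin\alpha\sin\beta = y_1 y_2 - x_1 x_2 .
\end{aligned}
$$

Over $\mathbb{Z}/n\mathbb{Z}$ there are no angles, but the identity

$$
(x_1 y_2 + y_1 x_2)^2 + (y_1 y_2 - x_1 x_2)^2 = (x_1^2 + y_1^2)(x_2^2 + y_2^2)
$$

holds in any commutative ring, so the sum of two points with $x^2 + y^2 = 1$ again satisfies $x^2 + y^2 = 1$. Setting $d = 0$ in the Edwards formulas turns both denominators into 1 and recovers exactly the clock formulas.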

In [15]:
type Point = (M, M)
type Rotation = Point -> Point -> Maybe Point

-- Rotation on a circle: (x1*y2 + y1*x2, y1*y2 - x1*x2)
clockAdd :: Rotation
clockAdd (x1, y1) (x2, y2) = (,) <$> a <*> b
  where a = join $ (x1 #* y2) + (y1 #* x2)
        b = join $ (y1 #* y2) - (x1 #* x2)

-- Rotation on an Edwards curve x^2 + y^2 = 1 + d*x^2*y^2
edwardsAdd :: M -> Rotation
edwardsAdd d (x1, y1) (x2, y2) = (,) <$> a <*> b
  where a     = join $ (join $ (x1 #* y2) + (y1 #* x2)) / denP
        b     = join $ (join $ (y1 #* y2) - (x1 #* x2)) / denM
        dxy   = join $ Just d * (join $ (x1 #* x2) * (y1 #* y2))  -- d*x1*x2*y1*y2
        denP  = join $ Just (M n 1) + dxy                          -- 1 + d*x1*x2*y1*y2
        denM  = join $ Just (M n 1) - dxy                          -- 1 - d*x1*x2*y1*y2
        M n _ = x1                                                 -- modulus taken from x1
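Every operator lifted with liftM2 returns a nested Maybe (Maybe M), which is why the code above needs so many joins. A minimal standalone sketch of one way to avoid this, assuming we are free to pick new operator names (.+., .-., .*., ./., liftSafe, inverse, safeDiv, clockAdd' and edwardsAdd' are all hypothetical, not part of the notebook): bind both operands in the Maybe monad before applying the safe operation, so that every lifted operator has type Maybe M -> Maybe M -> Maybe M and can be given the usual precedences.

```haskell
-- Standalone sketch: M is re-declared (with derived Eq and Show) so the block compiles alone.
data M = M Int Int deriving (Eq, Show)

-- Same idea as safeOp: operate only when the moduli agree, reducing the result
safeOp :: (Int -> Int -> Int) -> M -> M -> Maybe M
safeOp op (M a b) (M c d)
  | a == c    = Just (M a ((b `op` d) `mod` a))
  | otherwise = Nothing

-- Inverse of x modulo n by the extended Euclidean algorithm.
-- Invariant: r0 = s0 * x and r1 = s1 * x (mod n); Nothing when gcd x n /= 1.
inverse :: Int -> Int -> Maybe Int
inverse n x = go n (x `mod` n) 0 1
  where
    go r0 r1 s0 s1
      | r1 == 0   = if r0 == 1 then Just (s0 `mod` n) else Nothing
      | otherwise = let q = r0 `div` r1
                    in go r1 (r0 - q * r1) s1 (s0 - q * s1)

safeDiv :: M -> M -> Maybe M
safeDiv (M a b) (M c d)
  | a /= c    = Nothing
  | otherwise = (\e -> M a ((b * e) `mod` a)) <$> inverse a d

-- Lift a partial operation to possibly-undefined operands. The do block is
-- equivalent to join (liftM2 op mx my), so no nested Maybe escapes.
liftSafe :: (M -> M -> Maybe M) -> Maybe M -> Maybe M -> Maybe M
liftSafe op mx my = do
  x <- mx
  y <- my
  op x y

infixl 6 .+., .-.
infixl 7 .*., ./.

(.+.), (.-.), (.*.), (./.) :: Maybe M -> Maybe M -> Maybe M
(.+.) = liftSafe (safeOp (+))
(.-.) = liftSafe (safeOp (-))
(.*.) = liftSafe (safeOp (*))
(./.) = liftSafe safeDiv

-- Clock addition: (x1*y2 + y1*x2, y1*y2 - x1*x2)
clockAdd' :: (M, M) -> (M, M) -> Maybe (M, M)
clockAdd' (x1, y1) (x2, y2) = (,) <$> x3 <*> y3
  where
    x3 = Just x1 .*. Just y2 .+. Just y1 .*. Just x2
    y3 = Just y1 .*. Just y2 .-. Just x1 .*. Just x2

-- Edwards addition on x^2 + y^2 = 1 + d*x^2*y^2
edwardsAdd' :: M -> (M, M) -> (M, M) -> Maybe (M, M)
edwardsAdd' d (x1, y1) (x2, y2) = (,) <$> x3 <*> y3
  where
    M n _ = x1                          -- modulus taken from x1
    one   = Just (M n 1)
    t     = Just d .*. Just x1 .*. Just x2 .*. Just y1 .*. Just y2
    x3    = (Just x1 .*. Just y2 .+. Just y1 .*. Just x2) ./. (one .+. t)
    y3    = (Just y1 .*. Just y2 .-. Just x1 .*. Just x2) ./. (one .-. t)
```

With these operators the formulas read almost like the Python version, and a mismatch of moduli or a missing inverse anywhere in an expression still propagates as Nothing.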

We can now test our implementation on the same values as those used in the lecture:


In [16]:
f1009 = fromInt 1009
d = f1009 (-11)
p1 = (f1009 7 ,f1009 415)
p2 = (f1009 23 ,f1009 487)
edwardsAdd d p2 p1


Just (944 mod 1009,175 mod 1009)
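As an independent check by hand (all congruences modulo 1009, with $d = -11$, $(x_1, y_1) = (23, 487)$, $(x_2, y_2) = (7, 415)$): the denominators are $1 + d\,x_1 x_2 y_1 y_2 \equiv 670$ and $1 - d\,x_1 x_2 y_1 y_2 \equiv 341$, and multiplying the result back recovers the numerators; the result also lies on the curve:

$$
\begin{aligned}
944 \cdot 670 &\equiv 846 \equiv 23\cdot 415 + 487\cdot 7, \qquad 175 \cdot 341 \equiv 144 \equiv 487\cdot 415 - 23\cdot 7,\\
944^2 + 175^2 &\equiv 189 + 355 = 544 \equiv 1 + (-11)\cdot 189\cdot 355 = 1 + d\cdot 944^2\cdot 175^2 .
\end{aligned}
$$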

We also check that the addition of points on the clock has the desired arithmetic properties. Fix the modulus and the starting point:


In [17]:
n = 1000003
fp = fromInt n
p = (fp 1000, fp 2)

Perform some repeated additions:


In [18]:
p2 = clockAdd p p
p2


Just (4000 mod 1000003,7 mod 1000003)
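This agrees with doing the arithmetic by hand: the starting point lies on the clock, and the clock formula applied to $p + p$ gives

$$
\begin{aligned}
1000^2 + 2^2 &= 1000004 \equiv 1 \pmod{1000003},\\
2p &= (1000\cdot 2 + 2\cdot 1000,\; 2\cdot 2 - 1000\cdot 1000) = (4000,\; -999996) \equiv (4000,\; 7) \pmod{1000003}.
\end{aligned}
$$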

In [19]:
p3 = p2 >>= clockAdd p
p3


Just (15000 mod 1000003,26 mod 1000003)

In [20]:
p4 = p3 >>= clockAdd p
p5 = p4 >>= clockAdd p
p6 = p5 >>= clockAdd p
p6


Just (780000 mod 1000003,1351 mod 1000003)

Check that we get to the same point following different paths:


In [21]:
join (clockAdd <$> p3 <*> p3) == p6


True

In [22]:
join (clockAdd <$> p2 <*> p4) == p6


True

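Finally, we implement the scalarMult function, which takes a point p and an integer n and rotates p by adding n copies of it together. It uses the double-and-add recursion: the result for $n$ is obtained by doubling the result for $\lfloor n/2 \rfloor$ and adding one more copy of p when $n$ is odd, while $n = 0$ gives the neutral element (0, 1). The function is equivalent to the following Python implementation, except that we make the addition function a parameter, so that we can easily swap it for an alternative (clockAdd or edwardsAdd in our case).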

def scalarmult(n,P):
    if n == 0: return (Fp(0),Fp(1))
    if n == 1: return P
    Q = scalarmult(n // 2,P)
    Q = clockadd(Q,Q)
    if n % 2: Q = clockadd(P,Q)
    return Q

In [23]:
type Steps = Int
scalarMult :: Rotation -> Steps -> Point -> Maybe Point
-- 0 * p is the neutral element (0, 1), taken modulo the shared modulus of p
scalarMult _ 0 (M a _, M b _)
    | a == b    = Just (M a 0, M a 1)
    | otherwise = Nothing
scalarMult _ 1 p = Just p
-- Double-and-add: n * p = 2 * ((n `div` 2) * p), plus one more p when n is odd
scalarMult addFn n p = do
    r <- scalarMult addFn (n `div` 2) p
    q <- addFn r r
    if odd n then addFn p q else Just q
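The recursion in scalarMult does not depend on the clock at all. A minimal standalone sketch that makes this explicit (scalarMultWith is a hypothetical name, not part of the notebook): the addition function and the neutral element are both parameters, and the element type is polymorphic.

```haskell
-- Standalone sketch: double-and-add over any partial "addition" with an
-- explicit neutral element; the element type a is arbitrary.
scalarMultWith :: (a -> a -> Maybe a) -> a -> Int -> a -> Maybe a
scalarMultWith add zero n p
  | n < 0     = Nothing                             -- negative multiples not handled here
  | n == 0    = Just zero                           -- 0 * p is the neutral element
  | n == 1    = Just p
  | otherwise = do
      r <- scalarMultWith add zero (n `div` 2) p    -- r = (n `div` 2) * p
      q <- add r r                                  -- q = 2 * r
      if odd n then add p q else Just q             -- one more p when n is odd
```

For instance, scalarMultWith clockAdd (fp 0, fp 1) 6 p follows the same recursion as scalarMult clockAdd 6 p. Passing multiplication modulo a prime instead of addition turns the same recursion into square-and-multiply exponentiation, the operation behind classic finite-field Diffie-Hellman.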

Again check that multiplying the given point by a scalar gives results consistent with consecutive addition:


In [24]:
scalarMult clockAdd 6 p


Just (780000 mod 1000003,1351 mod 1000003)

In [25]:
circleMult = scalarMult clockAdd

In [26]:
(p2 >>= circleMult 3) == p6
(p2 >>= circleMult 3) == (p3 >>= circleMult 2)


True
True

Diffie-Hellman protocol

Now we are ready to demonstrate the steps of the Diffie-Hellman protocol, which lets two parties derive a shared secret key while exchanging only public values.

Bob and Alice standardize the prime number n and the starting point p. Alice chooses her secret and computes her public key.


In [27]:
aliceSecret = 397
alicePub = circleMult aliceSecret p
alicePub


Just (662233 mod 1000003,576366 mod 1000003)

Bob chooses his secret and computes his public key.


In [28]:
bobSecret = 479
bobPub = circleMult bobSecret p
bobPub


Just (903916 mod 1000003,250061 mod 1000003)

Alice receives Bob's public key and computes the shared secret:


In [29]:
aliceShared = bobPub >>= circleMult aliceSecret

Bob receives Alice's public key and computes the shared secret:


In [30]:
bobShared = alicePub >>= circleMult bobSecret

The rules of point arithmetic on the circle guarantee that both sides arrive at the same shared key:


In [31]:
aliceShared == bobShared


True
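Written out, with $a$ = aliceSecret, $b$ = bobSecret and $P$ = p, the exchange is:

$$
A = aP, \qquad B = bP, \qquad aB = a(bP) = (ab)P = (ba)P = b(aP) = bA .
$$

The middle equalities hold because $nP$ is just $P$ added to itself $n$ times with an associative and commutative addition. An eavesdropper sees $P$, $A$ and $B$; recovering $a$ from $A = aP$ is an instance of the discrete logarithm problem.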

And just in case you are wondering:


In [32]:
aliceShared


Just (46646 mod 1000003,334110 mod 1000003)