The basics of elliptic curve cryptography
This is an exercise in modular arithmetic in Haskell. It was inspired by some examples written in Python from a video about the basics of elliptic curve cryptography. Code samples from the slides can be found here. The Haskell implementation is quite raw and needs improvement, which will come as I gradually get to grips with monads and applicatives.
Modular arithmetic
We have to set a couple of compiler options in order to implement a familiar-looking printed representation of the congruence class type. The first, FlexibleInstances, allows us to write an instance of Show for a type constructor applied to a concrete type (Maybe M rather than Maybe a). The second, OverlappingInstances, tells the compiler not to complain that our more specific instance of Show for Maybe M overlaps the existing, more general instance for Maybe a.
In [1]:
{-# LANGUAGE FlexibleInstances, OverlappingInstances #-}
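GHC 7.10 and later deprecate the module-wide OverlappingInstances extension in favour of per-instance pragmas. A minimal sketch of the same two Show instances using the {-# OVERLAPPING #-} pragma instead, assuming a stand-alone module with a cut-down copy of the M type declared later in this notebook:

```haskell
{-# LANGUAGE FlexibleInstances #-}
-- Sketch: per-instance OVERLAPPING pragma instead of OverlappingInstances.
-- M is a stand-in copy of the notebook's congruence class type.
data M = M Int Int

instance Show M where
  show (M n x) = show x ++ " mod " ++ show n

-- Overlaps the Prelude's `instance Show a => Show (Maybe a)`; the pragma
-- tells GHC to prefer this more specific instance for Maybe M.
instance {-# OVERLAPPING #-} Show (Maybe M) where
  show (Just m) = show m
  show Nothing  = "undefined"

main :: IO ()
main = do
  print (Just (M 7 6))         -- uses the overlapping instance: 6 mod 7
  print (Nothing :: Maybe M)   -- undefined
```

The pragma only affects the one instance it is attached to, so the rest of the module keeps the default overlap checks.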
In the examples from the linked video, the presenters consider a fixed modulus $n$ in each part of the discussion. This is natural, since the modulus must be standardized for the purposes of creating keys. We are going to handle the more general situation to make the code more generic and reusable, and must therefore guard against undefined operations between congruence classes with different values of $n$.
In [2]:
import Control.Monad
import Control.Applicative
Next we declare our type for congruence classes. We keep the name short for the purposes of this example, as we will have to type it often. The first argument is the modulus and the second is the integer to which the modulo operation is applied; in other words, M x y means $y \operatorname{mod} x$. Essentially, we are going to implement a generalised version of the F7 class from the lecture:
class F7():
    def __init__(self, x):
        self.int = x % 7
    def __str__(self):
        return str(self.int)
    def __repr__(self):
        return self.__str__()
    def __eq__(self, other):
        return self.int == other.int
    def __add__(self, other):
        return F7(self.int + other.int)
    def __sub__(self, other):
        return F7(self.int - other.int)
    def __mul__(self, other):
        return F7(self.int * other.int)
Here is the declaration, together with the way its values will be printed:
In [3]:
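data M = M Int Int   -- M n x represents x mod n

instance Show M where
  show (M a b) = show b ++ " mod " ++ show a

-- Results of the safe operations: Nothing is printed as "undefined"
instance Show (Maybe M) where
  show (Just a) = show a
  show Nothing  = "undefined"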
We also need to compare values in order to verify our implementation. Two congruence classes are equal only if their moduli agree and their representatives are congruent:
In [4]:
instance Eq M where
  (==) (M a b) (M c d) = a == c && b `mod` a == d `mod` a
In order to implement the equivalent of the F7 class from the lecture (whose instances are congruence classes $x \operatorname{mod} 7$), we would ideally make the modulus part of the type, so that congruence classes modulo 7 form a type of their own. It turns out, however, that this is not possible without resorting to fairly involved workarounds. The compiler suggests enabling -XDataKinds, which allows literals to appear at the type level; we choose not to follow that route and instead define a conversion function which, partially applied to a modulus, plays the role of F7.
In [5]:
fromInt :: Int -> Int -> M
fromInt = M
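For comparison, here is a rough sketch of the DataKinds route we chose not to take, assuming GHC's GHC.TypeLits and Data.Proxy modules. The type Mod n and the helper mkMod are hypothetical names, not part of this notebook:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- A congruence class whose modulus n lives in the type
newtype Mod (n :: Nat) = Mod Integer deriving (Eq)

-- Reduce an integer modulo the type-level modulus n
mkMod :: forall n. KnownNat n => Integer -> Mod n
mkMod x = Mod (x `mod` natVal (Proxy :: Proxy n))

instance KnownNat n => Show (Mod n) where
  show (Mod x) = show x ++ " mod " ++ show (natVal (Proxy :: Proxy n))

instance KnownNat n => Num (Mod n) where
  Mod a + Mod b = mkMod (a + b)
  Mod a - Mod b = mkMod (a - b)
  Mod a * Mod b = mkMod (a * b)
  fromInteger   = mkMod
  abs           = id
  signum _      = 1

main :: IO ()
main = do
  let a = 1 :: Mod 7
      c = 19 :: Mod 7
  print (a + c)   -- 6 mod 7
```

With the modulus in the type, an expression such as (1 :: Mod 7) + (1 :: Mod 5) is rejected at compile time instead of evaluating to Nothing, at the cost of extra language extensions and type-level machinery.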
Now we implement safe operations on congruence classes, which return a result only if the moduli match. Addition, subtraction and multiplication (+, - and *) are easy to do:
In [6]:
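-- A binary operation on representatives
type BinaryOp = Int -> Int -> Int

-- Apply op to the representatives when the moduli agree; Nothing otherwise
safeOp :: BinaryOp -> M -> M -> Maybe M
safeOp op (M a b) (M c d)
  | a == c    = Just (M a ((b `op` d) `mod` a))
  | otherwise = Nothing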
In [7]:
(#+) = safeOp (+)
(#-) = safeOp (-)
(#*) = safeOp (*)
Division is trickier, because the result does not always exist even when the numerator and denominator belong to the same ring $\mathbb{Z}/n\mathbb{Z}$ (which is a field only when $n$ is prime): the inverse of $d$ modulo $n$ exists exactly when $\gcd(d, n) = 1$. We use an implementation of the extended Euclidean algorithm to establish whether modular inverses exist and to calculate them if they do:
In [8]:
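-- extendedEuclid a b returns (g, t, s) with g = gcd a b = s*a + t*b
extendedEuclid :: Int -> Int -> (Int, Int, Int)
extendedEuclid a b = extendedEuclid' 1 0 0 1 b a

extendedEuclid' :: Int -> Int -> Int -> Int -> Int -> Int -> (Int, Int, Int)
extendedEuclid' _ oldT _ oldS 0 oldR = (oldR, oldT, oldS)
extendedEuclid' t oldT s oldS r oldR = extendedEuclid' t' oldT' s' oldS' r' oldR'
  where quotient = oldR `div` r
        t'       = oldT - quotient * t
        oldT'    = t
        s'       = oldS - quotient * s
        oldS'    = s
        r'       = oldR - quotient * r
        oldR'    = r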
In [9]:
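-- Division: multiply b by the inverse of d modulo a, if that inverse exists
(#/) :: M -> M -> Maybe M
(#/) (M a b) (M c d)
  | a /= c    = Nothing
  | otherwise = case r of
      1 -> Just (M a (e * b `mod` a))   -- e * d = 1 (mod a)
      _ -> Nothing                      -- gcd d a /= 1: no inverse
  where (r, _, e) = extendedEuclid d a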
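As a worked example, take the values used in the check that follows: dividing 4 by 5 modulo 6. A single step of the Euclidean algorithm already exposes the inverse:

```latex
\begin{aligned}
6 &= 1\cdot 5 + 1 &&\Rightarrow\quad 1 = 6 - 1\cdot 5,\\
5^{-1} &\equiv -1 \equiv 5 \pmod 6 &&(\text{since } 5\cdot 5 = 25 = 4\cdot 6 + 1),\\
4 / 5 &\equiv 4\cdot 5 = 20 \equiv 2 \pmod 6.
\end{aligned}
```

So y #/ x should evaluate to 2 mod 6, and multiplying back gives $5\cdot 2 = 10 \equiv 4 \pmod 6$.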
Check that the results make sense
In [10]:
x = M 6 5
y = M 6 4
z = M 6 2
y #/ x
x #* z == Just y
We will also need to perform arithmetic on the results of the "safe" operations defined above, i.e. on monadic values of type Maybe M. This is the part that could probably be improved a lot, which in turn would simplify some of the subsequent code. For now, however, it is enough to demonstrate the result we are after in this document.
In [11]:
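-- Nothing propagates; join flattens the Maybe returned by the safe operations
instance Num (Maybe M) where
  a + b = join (liftM2 (#+) a b)
  a - b = join (liftM2 (#-) a b)
  a * b = join (liftM2 (#*) a b)
-- (/) belongs to Fractional, not Num; unlisted methods are left undefined
instance Fractional (Maybe M) where
  a / b = join (liftM2 (#/) a b)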
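Since this part is admittedly rough, here is one possible alternative, sketched under the assumption that we are free to drop the Num instance altogether: chaining the safe operations with do-notation in the Maybe monad, which needs no join. M and safeOp are cut-down copies of the definitions above (with a derived Show for brevity); clockAddDo is a hypothetical name:

```haskell
-- Sketch: clock addition without a Num instance on Maybe M
data M = M Int Int deriving (Show, Eq)

safeOp :: (Int -> Int -> Int) -> M -> M -> Maybe M
safeOp op (M a b) (M c d)
  | a == c    = Just (M a ((b `op` d) `mod` a))
  | otherwise = Nothing

-- Each bind short-circuits to Nothing as soon as two moduli disagree
clockAddDo :: (M, M) -> (M, M) -> Maybe (M, M)
clockAddDo (x1, y1) (x2, y2) = do
  a  <- safeOp (*) x1 y2
  b  <- safeOp (*) y1 x2
  x3 <- safeOp (+) a b
  c  <- safeOp (*) y1 y2
  e  <- safeOp (*) x1 x2
  y3 <- safeOp (-) c e
  return (x3, y3)

main :: IO ()
main = do
  let p = (M 1000003 1000, M 1000003 2)
  print (clockAddDo p p)                            -- Just (M 1000003 4000,M 1000003 7)
  print (clockAddDo (M 7 1, M 7 0) (M 5 1, M 5 0))  -- Nothing
```

Each intermediate result gets a name, which is more verbose than operator syntax but keeps the failure handling entirely inside the Maybe monad.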
Now we can define $\mathbb{Z}/7\mathbb{Z}$ (the equivalent of the F7 class from the presentation) and $\mathbb{Z}/5\mathbb{Z}$ as follows:
In [12]:
f7 = fromInt 7
f5 = fromInt 5
We then instantiate some congruence classes:
In [13]:
a = f7 1
b = f7 6
c = f7 19
d = f5 41
e = f5 38
Finally, we demonstrate the arithmetic and its printed representation at work:
In [14]:
show a ++ " + " ++ show c ++ " = " ++ show (a #+ c)
show b ++ " * " ++ show c ++ " = " ++ show (b #* c)
show d ++ " - " ++ show e ++ " = " ++ show (d #- e)
show a ++ " + " ++ show e ++ " = " ++ show (a #+ e)
Curve point arithmetic
Now we can implement addition of points on the "clock", i.e. the circle $x^2 + y^2 = 1$ in $\mathbb{Z}/n\mathbb{Z}\times \mathbb{Z}/n\mathbb{Z}$, and on the Edwards elliptic curve $x^2 + y^2 = 1 + d x^2 y^2$, equivalent to the following Python implementation:
def clockadd(P1, P2):
    x1, y1 = P1
    x2, y2 = P2
    x3 = x1*y2 + y1*x2
    y3 = y1*y2 - x1*x2
    return x3, y3

def edwardsadd(P1, P2):
    x1, y1 = P1
    x2, y2 = P2
    x3 = (x1*y2 + y1*x2)/(one + d*x1*y1*x2*y2)
    y3 = (y1*y2 - x1*x2)/(one - d*x1*y1*x2*y2)
    return x3, y3
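Written as formulas over $\mathbb{Z}/n\mathbb{Z}$ (where division means multiplication by a modular inverse), the two addition laws above are as follows. Setting $d = 0$ turns the Edwards curve into the circle $x^2 + y^2 = 1$ and the Edwards law into the clock law:

```latex
\begin{aligned}
\text{clock:}\quad (x_1, y_1) + (x_2, y_2) &= \left(x_1 y_2 + y_1 x_2,\; y_1 y_2 - x_1 x_2\right),\\
\text{Edwards:}\quad (x_1, y_1) + (x_2, y_2) &= \left(\frac{x_1 y_2 + y_1 x_2}{1 + d\,x_1 x_2 y_1 y_2},\; \frac{y_1 y_2 - x_1 x_2}{1 - d\,x_1 x_2 y_1 y_2}\right).
\end{aligned}
```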
In [15]:
type Point = (M, M)
type Rotation = Point -> Point -> Maybe Point

-- Rotation on a circle: defined only if all coordinates share a modulus
clockAdd :: Rotation
clockAdd (x1, y1) (x2, y2) = (,) <$> x3 <*> y3
  where x3 = (x1 #* y2) + (y1 #* x2)
        y3 = (y1 #* y2) - (x1 #* x2)

-- Rotation on an Edwards curve with parameter d
edwardsAdd :: M -> Rotation
edwardsAdd d (x1, y1) (x2, y2) = (,) <$> x3 <*> y3
  where M n _ = x1                     -- modulus taken from the first point
        one   = Just (M n 1)
        dxy   = Just d * (x1 #* x2) * (y1 #* y2)
        x3    = ((x1 #* y2) + (y1 #* x2)) / (one + dxy)
        y3    = ((y1 #* y2) - (x1 #* x2)) / (one - dxy)
We can now test our implementation on the same values as those used in the lecture:
In [16]:
f1009 = fromInt 1009
d = f1009 (-11)
p1 = (f1009 7 ,f1009 415)
p2 = (f1009 23 ,f1009 487)
edwardsAdd d p2 p1
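A quick independent sanity check, sketched with plain Integer arithmetic rather than the M type: the sum of two curve points must again lie on $x^2 + y^2 = 1 + d x^2 y^2$. Unlike the notebook, this sketch computes inverses with Fermat's little theorem ($a^{p-2} \equiv a^{-1} \pmod p$ for prime $p$) instead of the extended Euclidean algorithm; the names are hypothetical:

```haskell
-- Edwards curve x^2 + y^2 = 1 + d x^2 y^2 over Z/1009Z (values from the lecture)
p :: Integer
p = 1009

d :: Integer
d = -11

-- b^e mod p by repeated squaring
powMod :: Integer -> Integer -> Integer
powMod _ 0 = 1
powMod b e
  | even e    = h * h `mod` p
  | otherwise = b * powMod b (e - 1) `mod` p
  where h = powMod b (e `div` 2)

-- Inverse via Fermat's little theorem (valid because p is prime)
inv :: Integer -> Integer
inv a = powMod (a `mod` p) (p - 2)

onCurve :: (Integer, Integer) -> Bool
onCurve (x, y) = (x*x + y*y - 1 - d*x*x*y*y) `mod` p == 0

edwardsAddI :: (Integer, Integer) -> (Integer, Integer) -> (Integer, Integer)
edwardsAddI (x1, y1) (x2, y2) = (x3 `mod` p, y3 `mod` p)
  where t  = d * x1 * x2 * y1 * y2
        x3 = (x1*y2 + y1*x2) * inv (1 + t)
        y3 = (y1*y2 - x1*x2) * inv (1 - t)

main :: IO ()
main = do
  let p1 = (7, 415)
      p2 = (23, 487)
      s  = edwardsAddI p2 p1
  print (onCurve p1, onCurve p2, onCurve s)   -- (True,True,True)
  print s                                     -- (944,175)
```

Because $-11$ is not a square modulo 1009, the denominators $1 \pm d\,x_1 x_2 y_1 y_2$ never vanish for points on this curve, so the Edwards law is defined for every pair of them.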
We also check that the addition of points on the clock has the expected arithmetic properties. Fix the modulus and the starting point:
In [17]:
n = 1000003
fp = fromInt n
p = (fp 1000, fp 2)
Perform some repeated additions:
In [18]:
p2 = clockAdd p p
p2
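For reference, this first doubling can be checked by hand with the clock addition law:

```latex
\begin{aligned}
2\cdot(1000, 2) &= (1000\cdot 2 + 2\cdot 1000,\; 2\cdot 2 - 1000\cdot 1000) = (4000,\,-999996)\\
&\equiv (4000,\, 7) \pmod{1000003},\\
4000^2 + 7^2 &= 16000049 = 16\cdot 1000003 + 1 \equiv 1 \pmod{1000003}.
\end{aligned}
```

So the output above should correspond to the point $(4000, 7)$, which again lies on the circle.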
In [19]:
p3 = p2 >>= clockAdd p
p3
In [20]:
p4 = p3 >>= clockAdd p
p5 = p4 >>= clockAdd p
p6 = p5 >>= clockAdd p
p6
Check that we get to the same point following different paths:
In [21]:
join (clockAdd <$> p3 <*> p3) == p6
In [22]:
join (clockAdd <$> p2 <*> p4) == p6
Finally, we implement the scalarmult function, which takes a point p and an integer n and rotates p by adding n copies of it. It uses the double-and-add method: n copies are obtained from (n div 2) copies by one doubling plus, for odd n, one extra addition, so only O(log n) additions are needed. The function is equivalent to the following Python implementation, except that we make the addition function a parameter, so that we can easily swap it for an alternative (clockAdd or edwardsAdd in our case).
def scalarmult(n, P):
    if n == 0: return (Fp(0), Fp(1))
    if n == 1: return P
    Q = scalarmult(n // 2, P)
    Q = clockadd(Q, Q)
    if n % 2: Q = clockadd(P, Q)
    return Q
In [23]:
type Steps = Int

-- Double-and-add: n copies of p, using the rotation addFn
scalarMult :: Rotation -> Steps -> Point -> Maybe Point
scalarMult _ 0 (M a _, M b _)
  | a == b    = Just (M a 0, M a 1)   -- neutral element (0, 1)
  | otherwise = Nothing
scalarMult _ 1 p = Just p
scalarMult addFn n p = do
  r <- scalarMult addFn (n `div` 2) p  -- (n `div` 2) copies
  q <- addFn r r                       -- doubled
  if odd n then addFn p q else Just q  -- one more copy for odd n
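The recursion in scalarMult is the double-and-add method, and it works for any associative operation. A minimal sketch, with a hypothetical helper timesOp that takes the operation and its neutral element as parameters, applied to ordinary integer addition and to multiplication modulo 1000003 (where it becomes square-and-multiply exponentiation):

```haskell
-- n copies of x combined with an associative operation op whose neutral
-- element is e, using O(log n) applications of op (double-and-add)
timesOp :: (a -> a -> a) -> a -> Int -> a -> a
timesOp _ e 0 _ = e
timesOp _ _ 1 x = x
timesOp op e n x
  | odd n     = op x doubled
  | otherwise = doubled
  where half    = timesOp op e (n `div` 2) x
        doubled = op half half

-- Multiplication in Z/1000003Z
mulMod :: Integer -> Integer -> Integer
mulMod a b = a * b `mod` 1000003

main :: IO ()
main = do
  print (timesOp (+) 0 6 (5 :: Integer))                 -- 30
  print (timesOp mulMod 1 10 2)                          -- 1024
  print (timesOp mulMod 1 21 2 == 2 ^ 21 `mod` 1000003)  -- True
```

Passing the operation in is the same design choice as passing addFn to scalarMult: the recursion does not care which group it runs in.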
Again, check that multiplying the given point by a scalar gives results consistent with repeated addition:
In [24]:
scalarMult clockAdd 6 p
In [25]:
circleMult = scalarMult clockAdd
In [26]:
(p2 >>= circleMult 3) == p6
(p2 >>= circleMult 3) == (p3 >>= circleMult 2)
Diffie-Hellman protocol
Now we are ready to demonstrate the steps of the Diffie-Hellman protocol, which lets two parties derive a shared secret key while exchanging only public values.
Bob and Alice standardize the prime number n and the starting point p. Alice chooses her secret and computes her public key.
In [27]:
aliceSecret = 397
alicePub = circleMult aliceSecret p
alicePub
Bob chooses his secret and computes his public key.
In [28]:
bobSecret = 479
bobPub = circleMult bobSecret p
bobPub
Alice receives Bob's public key and computes the shared secret
In [29]:
aliceShared = bobPub >>= circleMult aliceSecret
Bob receives Alice's public key and computes the shared secret
In [30]:
bobShared = alicePub >>= circleMult bobSecret
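Both sides arrive at the same shared key: Alice computes aliceSecret·(bobSecret·p), Bob computes bobSecret·(aliceSecret·p), and because point addition on the circle is associative and commutative, both are equal to (aliceSecret·bobSecret) copies of p: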
In [31]:
aliceShared == bobShared
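To see the exchange end to end outside the notebook, here is a self-contained sketch that repeats it with plain Integer pairs instead of Maybe (M, M). The parameters follow the cells above; the helper names are hypothetical:

```haskell
-- Diffie-Hellman on the clock x^2 + y^2 = 1 over Z/1000003Z
type Pt = (Integer, Integer)

modulus :: Integer
modulus = 1000003

clockAddI :: Pt -> Pt -> Pt
clockAddI (x1, y1) (x2, y2) =
  ((x1 * y2 + y1 * x2) `mod` modulus, (y1 * y2 - x1 * x2) `mod` modulus)

-- n copies of pt by double-and-add; zero copies give the neutral element (0, 1)
scalarMultI :: Integer -> Pt -> Pt
scalarMultI 0 _ = (0, 1)
scalarMultI n pt
  | odd n     = clockAddI pt doubled
  | otherwise = doubled
  where half    = scalarMultI (n `div` 2) pt
        doubled = clockAddI half half

base :: Pt
base = (1000, 2)

aliceSecret, bobSecret :: Integer
aliceSecret = 397
bobSecret   = 479

alicePub, bobPub, aliceShared, bobShared :: Pt
alicePub    = scalarMultI aliceSecret base
bobPub      = scalarMultI bobSecret base
aliceShared = scalarMultI aliceSecret bobPub   -- Alice uses Bob's public key
bobShared   = scalarMultI bobSecret alicePub   -- Bob uses Alice's public key

main :: IO ()
main = do
  print alicePub
  print bobPub
  print (aliceShared == bobShared)   -- True
```

Only alicePub and bobPub would travel over the public channel; the secrets never leave their owners.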
And just in case you are wondering:
In [32]:
aliceShared