Given a number of points, we want to draw the straight line that leads to the lowest Mean Squared Error (MSE). This is basically me being silly: rather than looking up the formula or method used to calculate the regression line, I thought I could try to figure it out myself. In the process, I would learn a lot about Haskell, and about math in general. Others do crosswords ;)
My natural inclination would have been to just use generic data types, such as lists, hashes, etc., but Haskell seems to encourage more specific types, and it's a good exercise anyway to learn more about the type system. I'll define a type for 2D points, and one for an infinite straight line, with an intercept and a slope. I'll also define two convenience functions for converting between tuples and points.
In [13]:
import Data.List
data Point = Pt {ptx, pty :: Float} deriving (Show, Eq)
data Line = Line {intercept, slope :: Float} deriving (Show, Eq)
tupleToPoint :: (Float, Float) -> Point
tupleToPoint (x,y) = Pt {ptx = x, pty = y}
pointToTuple :: Point -> (Float, Float)
pointToTuple pt = (ptx pt, pty pt)
tupleList = [(1,2), (2,3), (4,5)] :: [(Float, Float)]
ptList = map tupleToPoint tupleList
ptList
A good place to start searching for the best-fit line would be the center of the points, calculated by taking the average of the xs and the average of the ys. Let's define a function for that.
In [2]:
findCenter :: [Point] -> Maybe Point
findCenter [] = Nothing
findCenter points = Just Pt {ptx = avgXs, pty = avgYs}
  where
    xs = [ptx pt | pt <- points]
    ys = [pty pt | pt <- points]
    avgXs = sum xs / (fromIntegral $ length xs)
    avgYs = sum ys / (fromIntegral $ length ys)
ctr = findCenter ptList
ctr
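For ptList, the same averages can be worked out by hand. A standard result about least-squares fitting, not derived here, is that the best-fit line always passes through this center point. That supports starting the search there.

```latex
\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i = \frac{1 + 2 + 4}{3} = \frac{7}{3} \approx 2.333,
\qquad
\bar{y} = \frac{1}{n}\sum_{i=1}^{n} y_i = \frac{2 + 3 + 5}{3} = \frac{10}{3} \approx 3.333
```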
Given a line and a series of points, can I calculate the MSE? First, I need to calculate where the line is at a given x. Then I need the difference between the line's y and the point's y at that same x.
In [3]:
line1 = Line {intercept=1.0, slope=2.0} -- y=2x+1
lineYatX :: Line -> Float -> Float
lineYatX line x = intercept line + slope line * x
errorLinePoint :: Line -> Point -> Float
errorLinePoint line pt = lineYatX line (ptx pt) - pty pt
errorLinePoint line1 $ tupleToPoint (5,4)
So at x=5 the point is at y=4, but the line is at 2x+1 = 11, so the difference is 7. Note that I'm still playing with types. I chose to use a record for points, which means that ptx and pty are field accessor functions rather than independent constructors. If X and Y were constructors of their own, I could have written a single lineAt function that pattern-matches on either an X or a Y... I am also passing around plain Floats instead of something more specific. Just slowly getting a feel for things.
I'm also writing a bunch of small functions and making things very explicit to get a good sense of how everything fits together.
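A minimal sketch of the constructor-based alternative follows. It assumes a hypothetical Coord type whose X and Y constructors each wrap a Float; neither Coord nor lineAt exists in the notebook as written. The block restates Line so that it stands alone. Given an X, lineAt returns the matching Y on the line, and vice versa:

```haskell
-- Hypothetical alternative to a record-based point: each coordinate is its own
-- constructor, so one function can pattern-match on which one it was given.
data Coord = X Float | Y Float deriving (Show, Eq)

data Line = Line {intercept, slope :: Float} deriving (Show, Eq)

-- Given an x, return the y on the line; given a y, return the x on the line.
-- The Y case divides by the slope, so a horizontal line (slope 0) yields
-- Infinity or NaN rather than raising an exception.
lineAt :: Line -> Coord -> Coord
lineAt line (X x) = Y (intercept line + slope line * x)
lineAt line (Y y) = X ((y - intercept line) / slope line)
```

For example, with line1 (y = 2x + 1), `lineAt line1 (X 5)` gives `Y 11.0` and `lineAt line1 (Y 11)` gives `X 5.0`. The trade-off is that X 5 and Y 5 now share one type, so the compiler can no longer stop you from passing a y where an x was meant.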
Now we should be able to calculate the MSE of a line against a cluster of points very easily.
In [4]:
mSE :: Line -> [Point] -> Float
mSE line points = sum sqerrs / fromIntegral (length points)
  where
    errs = map (errorLinePoint line) points
    sqerrs = map (** 2.0) errs
mSE line1 ptList
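In formula form, take intercept a, slope b and n points (x_i, y_i). Each error is the line's y minus the point's y, matching errorLinePoint, so mSE computes the first expression below. Worked by hand for line1 (a = 1, b = 2) on ptList, the errors are 1, 2 and 4:

```latex
\mathrm{MSE}(a, b) = \frac{1}{n}\sum_{i=1}^{n}\left(a + b\,x_i - y_i\right)^2
\qquad
\mathrm{MSE}(1, 2) = \frac{(3-2)^2 + (5-3)^2 + (9-5)^2}{3} = \frac{1 + 4 + 16}{3} = 7
```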
Now let's calculate the MSE for lines with varying slope, keeping the intercept fixed at 2.0.
In [5]:
lineList = [Line {slope=x, intercept=2.0} | x <- [1.0, 2.0, 3.0, 4.0, 5.0]]
map (`mSE` ptList) lineList
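Rather than reading the lowest value off a list like this by eye, the sweep can pick the best line itself. The sketch below restates the notebook's types and mSE so that it stands alone. It then adds two hypothetical helpers that are not part of the original notebook. bestLine picks the lowest-MSE candidate with minimumBy and comparing. slopeThenIntercept does one round of sweeping slopes at a fixed intercept and then sweeping intercepts at the best slope found.

```haskell
import Data.List (minimumBy)
import Data.Ord (comparing)

data Point = Pt {ptx, pty :: Float} deriving (Show, Eq)
data Line = Line {intercept, slope :: Float} deriving (Show, Eq)

lineYatX :: Line -> Float -> Float
lineYatX line x = intercept line + slope line * x

errorLinePoint :: Line -> Point -> Float
errorLinePoint line pt = lineYatX line (ptx pt) - pty pt

mSE :: Line -> [Point] -> Float
mSE line points = sum sqerrs / fromIntegral (length points)
  where
    sqerrs = map ((** 2.0) . errorLinePoint line) points

-- Hypothetical helper: the candidate line with the lowest MSE on the points.
-- minimumBy fails on an empty list, so at least one candidate is required.
bestLine :: [Point] -> [Line] -> Line
bestLine points = minimumBy (comparing (`mSE` points))

-- Hypothetical helper: sweep the slopes at a fixed starting intercept a0,
-- keep the best slope, then sweep the intercepts at that slope.
slopeThenIntercept :: [Point] -> Float -> [Float] -> [Float] -> Line
slopeThenIntercept points a0 slopes intercepts =
  bestLine points [Line {intercept = a, slope = slope best} | a <- intercepts]
  where
    best = bestLine points [Line {intercept = a0, slope = s} | s <- slopes]
```

With the ptList and lineList defined above, `bestLine ptList lineList` should pick the slope-1 line, whose MSE is 1.0. A second sweep over intercepts such as `[0, 0.5, 1, 1.5, 2, 2.5, 3]` at that slope should then land on intercept 1, because these three points happen to lie exactly on y = x + 1. A grid search like this only ever returns one of the candidates it was given, so its result is only as good as the grid.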
Let's try to find the optimal fit by varying first the slope and then the intercept...
First let me define a function that finds the infinite line passing through two given points.
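For two points (x_1, y_1) and (x_2, y_2) whose x values differ, the slope b is rise over run. The intercept a then follows from requiring the line to pass through the first point. For the example points (1, 10) and (2, 20):

```latex
b = \frac{y_2 - y_1}{x_2 - x_1}, \qquad a = y_1 - b\,x_1
\qquad\Longrightarrow\qquad
b = \frac{20 - 10}{2 - 1} = 10, \quad a = 10 - 10 \cdot 1 = 0
```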
In [6]:
findInterceptLine :: Point -> Point -> Line
findInterceptLine pt1 pt2 = Line {intercept = intercept', slope = slope'}
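  where
    -- slope is rise over run; a vertical pair (dx == 0) makes it infinite or NaN
    slope' = dy / dx
    dx = ptx pt2 - ptx pt1
    dy = pty pt2 - pty pt1
    -- the line must pass through pt1: y1 = intercept + slope * x1
    intercept' = pty pt1 - (ptx pt1 * slope')
line2 = findInterceptLine Pt {ptx=1.0, pty=10.0} Pt {ptx=2.0, pty=20.0} -- should be slope=10, intercept=0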
print line2
getLinePoints :: Line -> [Float] -> [Point]
getLinePoints line xs = zipWith (curry tupleToPoint) xs $ map (lineYatX line) xs
getLinePoints line2 [1.0..10]
In [14]:
getCombinations :: [a] -> [[a]]
getCombinations na = do
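  a <- na
  b <- na
  [[a,b]]
pairs = getCombinations [1,2,3,4,5]
-- two pairs are equal if they hold the same elements in either order;
-- checking membership alone would wrongly treat [1,1] and [1,2] as equal.
-- Self-pairs such as [1,1] are still kept by nubBy.
eqPairs :: (Eq a) => [a] -> [a] -> Bool
eqPairs x y = x == y || x == reverse y
nubBy eqPairs pairs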
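There is an alternative to generating all ordered pairs and de-duplicating them with nubBy. Each unordered pair of distinct list positions can be built exactly once with tails from Data.List, which also leaves out self-pairs. The sketch below combines that idea with findInterceptLine and mSE to try every line through two of the points. It restates the types, and it writes findInterceptLine with dx and dy as differences so that the block stands alone. The names uniquePairs and bestPairLine are hypothetical. The search only considers lines through two data points, so it is not guaranteed to find the lowest-MSE line in general.

```haskell
import Data.List (minimumBy, tails)
import Data.Ord (comparing)

data Point = Pt {ptx, pty :: Float} deriving (Show, Eq)
data Line = Line {intercept, slope :: Float} deriving (Show, Eq)

-- Every unordered pair of list positions, each pair once, no element with itself.
uniquePairs :: [a] -> [(a, a)]
uniquePairs xs = [(a, b) | (a : rest) <- tails xs, b <- rest]

-- Line through two points: slope is rise over run, intercept from pt1.
findInterceptLine :: Point -> Point -> Line
findInterceptLine pt1 pt2 = Line {intercept = intercept', slope = slope'}
  where
    slope' = (pty pt2 - pty pt1) / (ptx pt2 - ptx pt1)
    intercept' = pty pt1 - ptx pt1 * slope'

mSE :: Line -> [Point] -> Float
mSE line points = sum sqerrs / fromIntegral (length points)
  where
    err pt = intercept line + slope line * ptx pt - pty pt
    sqerrs = map ((** 2.0) . err) points

-- Hypothetical search: the lowest-MSE line among lines through two of the points.
-- Pairs that share an x value (vertical lines) are skipped; Nothing if none remain.
bestPairLine :: [Point] -> Maybe Line
bestPairLine points
  | null candidates = Nothing
  | otherwise = Just (minimumBy (comparing (`mSE` points)) candidates)
  where
    candidates = [findInterceptLine p q | (p, q) <- uniquePairs points, ptx p /= ptx q]
```

For example, `uniquePairs [1,2,3]` gives `[(1,2),(1,3),(2,3)]`. Because the three points in ptList are collinear, every pair of them yields the same line, and `bestPairLine ptList` should return it with intercept 1 and slope 1.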
In [9]:
map pointToTuple ptList