# In [36]:
def _drop_index(seq, k):
    '''
    Return a copy of the sequence seq without the element at index k.

    A negative k counts from the end, as in normal Python indexing.
    An index that falls outside the sequence leaves the contents
    unchanged (a copy is still returned).
    '''
    if k < 0:
        k += len(seq)
    if not 0 <= k < len(seq):
        return seq[:]
    # For an in-range k this slice pair drops exactly one element; the
    # normalization above matters for k == -1, where k + 1 would be 0.
    return seq[:k] + seq[k + 1:]


def get_submatrix(m, i=None, j=None):
    '''
    This function returns a submatrix for a given matrix
    which does not contain the row i and/or column j.

    Pass None (the default) for i or j to keep all rows or all columns.
    Negative indices count from the end, as in normal Python indexing.
    The input matrix is never modified.

    >>> get_submatrix([[1,2,3],[4,5,6],[7,8,9]], 1, 1)
    [[1, 3], [7, 9]]
    >>> get_submatrix([[2, 5], [2, 5]], 0, 0)
    [[5]]
    >>> get_submatrix([[1,2,3],[4,5,6],[7,8,9]], i=1)
    [[1, 2, 3], [7, 8, 9]]
    >>> get_submatrix([[1,2,3],[4,5,6],[7,8,9]], j=1)
    [[1, 3], [4, 6], [7, 9]]
    >>> get_submatrix([[1,2,3],[4,5,6],[7,8,9]], -1, -1)
    [[1, 2], [4, 5]]
    '''
    # Remove row i
    if i is not None:
        m = _drop_index(m, i)
    # Remove column j from every remaining row
    if j is not None:
        m = [_drop_index(row, j) for row in m]
    return m
def matrix_det(m):
    '''
    This function returns the determinant of a given square matrix
    (a list of rows), computed by Laplace (cofactor) expansion along
    the first row.

    The recursion visits n minors of size n-1 at every level, so the
    cost grows factorially; use it only for small matrices.
    The empty (0 x 0) matrix has determinant 1.

    >>> matrix_det([[1,2,3],[4,5,6],[7,8,9]])
    0
    >>> matrix_det([[1,3],[7,9]])
    -12
    >>> matrix_det([[1,2,3,4,6],[2,2,3,4,5],[1,3,3,4,5],
    ...             [1,2,4,4,5],[1,2,3,5,5]])
    16
    >>> matrix_det([])
    1
    '''
    n = len(m)
    if n == 0:
        # By convention the determinant of the empty matrix is 1.
        return 1
    # Check every row, not just the first, so ragged input fails clearly.
    assert all(len(row) == n for row in m), 'Matrix is not a square matrix.'
    if n == 1:
        return m[0][0]
    first_row, rest = m[0], m[1:]
    # The minor for column `col` is what remains after deleting
    # row 0 and column `col`.
    return sum(
        first_row[col] * (-1) ** col
        * matrix_det([row[:col] + row[col + 1:] for row in rest])
        for col in range(n)
    )
if __name__ == '__main__':
    # When executed as a script, run the examples embedded in this
    # module's docstrings as tests.
    from doctest import testmod
    testmod()
# In [115]:
def matrix_sub(a, b):
    '''
    This function returns the element-wise difference of matrix a minus
    matrix b.

    Both matrices must have the same shape. Rows are compared pairwise,
    so a length mismatch in any row is reported instead of being
    silently truncated.

    >>> matrix_sub([[1,2],[3,4]],[[1,2],[3,4]])
    [[0, 0], [0, 0]]
    >>> matrix_sub([], [])
    []
    '''
    assert len(a) == len(b) and all(
        len(row_a) == len(row_b) for row_a, row_b in zip(a, b)), \
        'Matrices a and b have to have the same dimensions'
    return [[x - y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)]
def matrix_add(a, b):
    '''
    This function returns the element-wise sum of two matrices a and b.

    Both matrices must have the same shape. Rows are compared pairwise,
    so a length mismatch in any row is reported instead of being
    silently truncated.

    >>> matrix_add([[1,2],[3,4]],[[1,2],[3,4]])
    [[2, 4], [6, 8]]
    >>> matrix_add([], [])
    []
    '''
    assert len(a) == len(b) and all(
        len(row_a) == len(row_b) for row_a, row_b in zip(a, b)), \
        'Matrices a and b have to have the same dimensions'
    return [[x + y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)]
def matrix_get_part(m, i, j, a, b):
    '''
    This function returns the block of matrix m that starts at row i,
    column j and spans a rows and b columns.

    i, j -- row and column index of the block's top-left element.
    a    -- number of rows to take.
    b    -- number of columns to take.

    Ordinary slice semantics apply, so a block that reaches past the
    edge of m is truncated. Passing None for i (or j) keeps all rows
    (or all columns).

    >>> matrix_get_part([[1,2,2,1],
    ...                  [3,4,5,4],
    ...                  [1,3,4,5],
    ...                  [3,4,5,4]], 2, 2, 2, 2)
    [[4, 5], [5, 4]]
    '''
    # Keep rows i .. i+a-1
    if i is not None:
        m = m[i:i + a]
    # Keep columns j .. j+b-1 of each kept row
    if j is not None:
        m = [row[j:j + b] for row in m]
    return m
def matrix_split(m):
    '''
    Cut a square matrix into its four quadrants and return them as a list
    in the order top-left, top-right, bottom-left, bottom-right.

    The split point is len(m) // 2, so for odd sizes the bottom and right
    quadrants receive the extra row / column.

    >>> matrix_split([[1,2],[3,4]])
    [[[1]], [[2]], [[3]], [[4]]]
    '''
    size = len(m)
    half = size // 2
    top_rows, bottom_rows = m[:half], m[half:]
    return [
        [row[:half] for row in top_rows],
        [row[half:half + size] for row in top_rows],
        [row[:half] for row in bottom_rows],
        [row[half:half + size] for row in bottom_rows],
    ]
def matrix_concat(a, b, axis=0):
    '''
    This function concatenates two 2D matrices (lists of rows).

    axis=0 (default) joins them horizontally: row k of the result is row k
    of a followed by row k of b, so both must have the same number of rows.
    axis=1 joins them vertically: the rows of b are appended below a.

    >>> a = [[1,2],[3,4]]
    >>> b = [[1,2],[3,4]]
    >>> matrix_concat(a,b, axis=0)
    [[1, 2, 1, 2], [3, 4, 3, 4]]
    >>> matrix_concat(a,b, axis=1)
    [[1, 2], [3, 4], [1, 2], [3, 4]]
    '''
    assert axis in [0, 1], 'axis must be 0 (horizontal) or 1 (vertical)'
    if axis == 0:
        # zip() would silently drop the extra rows of the taller matrix.
        assert len(a) == len(b), \
            'Matrices a and b have to have the same number of rows'
        return [row_a + row_b for row_a, row_b in zip(a, b)]
    if axis == 1:
        return a[:] + b[:]
def matrix_mul(a, b):
    '''
    This function multiplies two matrices a and b (lists of rows) with
    the schoolbook row-by-column algorithm.

    a must have as many columns as b has rows. An empty a gives [].

    >>> matrix_mul([[3,2],[4,5]], [[1,-1],[-2,-3]])
    [[-1, -9], [-6, -19]]
    >>> matrix_mul([[2,0,-1],[3,4,5]], [[4,1],[5,0],[-2,-1]])
    [[10, 3], [22, -2]]
    >>> matrix_mul([[1,2,4],[2,1,3]], [[1,2,3],[8,3,2],[2,1,4]])
    [[25, 12, 23], [16, 10, 20]]
    '''
    if not a:
        return []
    assert len(a[0]) == len(b), \
        'Number of columns of a must equal the number of rows of b'
    # Build b's columns once instead of rebuilding each column (and
    # copying the row of a) for every single output element.
    cols_b = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols_b]
            for row in a]
def matrix_mul_strassen(a, b):
    '''
    This function multiplies two matrices a and b using the
    Strassen algorithm.

    The Strassen recursion is applied only while a and b are square
    matrices of the same even size. Any other input (odd size, non-square,
    mismatched sizes, empty) is handed to matrix_mul, which also acts as
    the recursion's base case.

    >>> matrix_mul_strassen([[3,2],[4,5]], [[1,-1],[-2,-3]])
    [[-1, -9], [-6, -19]]
    >>> matrix_mul_strassen([[2,0,-1],[3,4,5]], [[4,1],[5,0],[-2,-1]])
    [[10, 3], [22, -2]]
    >>> matrix_mul_strassen([[1,2,4],[2,1,3]], [[1,2,3],[8,3,2],[2,1,4]])
    [[25, 12, 23], [16, 10, 20]]
    >>> matrix_mul_strassen([[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]],
    ...                     [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]])
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
    '''
    n = len(a)
    # Checks are ordered so that a[0] / b[0] are only read when non-empty.
    if (n == 0 or n % 2 != 0 or len(b) != n
            or len(a[0]) != n or len(b[0]) != n):
        return matrix_mul(a, b)
    a11, a12, a21, a22 = matrix_split(a)
    b11, b12, b21, b22 = matrix_split(b)
    # The seven Strassen products.
    m1 = matrix_mul_strassen(matrix_add(a11, a22), matrix_add(b11, b22))
    m2 = matrix_mul_strassen(matrix_add(a21, a22), b11)
    m3 = matrix_mul_strassen(a11, matrix_sub(b12, b22))
    m4 = matrix_mul_strassen(a22, matrix_sub(b21, b11))
    m5 = matrix_mul_strassen(matrix_add(a11, a12), b22)
    m6 = matrix_mul_strassen(matrix_sub(a21, a11), matrix_add(b11, b12))
    m7 = matrix_mul_strassen(matrix_sub(a12, a22), matrix_add(b21, b22))
    # Combine the products into the four quadrants of the result.
    c11 = matrix_add(matrix_sub(matrix_add(m1, m4), m5), m7)
    c12 = matrix_add(m3, m5)
    c21 = matrix_add(m2, m4)
    c22 = matrix_add(matrix_sub(m1, m2), matrix_add(m3, m6))
    # Join the quadrants side by side, then stack top over bottom.
    return matrix_concat(matrix_concat(c11, c12, axis=0),
                         matrix_concat(c21, c22, axis=0),
                         axis=1)
if __name__ == '__main__':
    # When executed as a script, run the examples embedded in this
    # module's docstrings as tests.
    from doctest import testmod
    testmod()