#!/usr/bin/env python
# this file is in the public domain.
"""Minimal pure-Python dense matrices stored as a flat, row-major cell sequence.

Written in the subset of Python shared by Python 2 and Python 3.
"""

import operator

try:  # Python 2
    import StringIO
except ImportError:  # Python 3 moved StringIO.StringIO to io.StringIO
    import io as StringIO


def join(a, b):
    """Return an iterator over ``(a_item, b_item)`` pairs taken in lockstep.

    Like ``zip``, iteration stops at the end of the shorter input.
    """
    # itertools.izip does not exist on Python 3; builtin zip is lazy there,
    # and iter() keeps the return value an iterator on Python 2 as well.
    return iter(zip(a, b))


def inner_product(a, b):
    """Return the sum of pairwise products of ``a`` and ``b`` (0 if empty)."""
    return sum(a_item * b_item for a_item, b_item in join(a, b))


class MatrixShapeError(ValueError, AssertionError):
    """Raised when matrix dimensions do not fit the requested operation.

    Also subclasses AssertionError because these checks used to be ``assert``
    statements: existing ``except AssertionError`` handlers keep working, but
    unlike ``assert`` the checks are not stripped under ``python -O``.
    """


class matrix(object):
    """A ``row_count`` x ``column_count`` matrix.

    ``cells`` holds every entry in row-major order (row 0 first), so entry
    ``(r, c)`` is ``cells[r * column_count + c]``.
    """

    def __init__(self, row_count, column_count, cells):
        """Wrap ``cells`` (kept by reference, not copied).

        Raises:
            MatrixShapeError: if ``len(cells) != row_count * column_count``.
        """
        super(matrix, self).__init__()
        self.cells = cells
        self.row_count = row_count
        self.column_count = column_count
        if len(cells) != row_count * column_count:
            raise MatrixShapeError(
                "a %r x %r matrix needs %r cells, got %r"
                % (row_count, column_count, row_count * column_count,
                   len(cells)))

    def __repr__(self):
        """Return ``matrix(rows, cols, (cell, ...))``.

        When the matrix has more than one row and more than one column, each
        row is written on its own line.
        """
        multiline = self.row_count > 1 and self.column_count > 1
        result = StringIO.StringIO()
        result.write("matrix(%r, %r, (" % (self.row_count, self.column_count))
        if multiline:
            result.write("\n")
        for position, item in enumerate(self.cells, 1):
            # Wrap in a 1-tuple: a bare tuple cell would otherwise be taken
            # as the %-format argument list and raise TypeError.
            result.write("%r, " % (item,))
            if multiline and position % self.column_count == 0:
                result.write("\n")  # end of a row
        result.write("))")
        return result.getvalue()

    def _check_same_shape(self, other):
        """Raise MatrixShapeError unless ``other`` has this matrix's shape."""
        mine = (self.row_count, self.column_count)
        theirs = (other.row_count, other.column_count)
        if mine != theirs:
            raise MatrixShapeError(
                "shape mismatch: %r x %r vs %r x %r" % (mine + theirs))

    def _elementwise(self, other, op):
        """Return a same-shaped matrix of ``op(self_cell, other_cell)``."""
        self._check_same_shape(other)
        return matrix(self.row_count, self.column_count,
                      [op(item, other_item)
                       for item, other_item in join(self, other)])

    def __add__(self, other):
        """Return the cell-wise sum; raises MatrixShapeError on mismatch."""
        return self._elementwise(other, operator.add)

    def __sub__(self, other):
        """Return the cell-wise difference; raises MatrixShapeError on
        mismatch."""
        return self._elementwise(other, operator.sub)

    def rows(self):
        """Yield each row (a slice of ``cells``), top to bottom."""
        for index in range(self.row_count):
            yield self.row(index)

    def columns(self):
        """Yield each column (a new list), left to right."""
        for index in range(self.column_count):
            yield self.column(index)

    def row(self, index):
        """Return row ``index`` as a slice of ``cells`` (same sequence type)."""
        start = self.column_count * index
        return self.cells[start:start + self.column_count]

    def column(self, index):
        """Return column ``index`` as a new list."""
        return [self.cells[index + self.column_count * row_index]
                for row_index in range(self.row_count)]

    def __mul__(self, other):
        """Return ``self * other``.

        For a matrix ``other`` this is the matrix product: an r x n matrix
        times an n x c matrix gives an r x c matrix (e.g. a 3x1 column vector
        times a 1x3 row gives a 3x3 matrix). A 1x1 result is collapsed to its
        single cell by ``simplify``, so a row vector times a column vector
        yields their inner product. Any other ``other`` is treated as a
        scalar that multiplies every cell.

        Raises:
            MatrixShapeError: if ``self.column_count != other.row_count``.
        """
        if not isinstance(other, matrix):
            return matrix(self.row_count, self.column_count,
                          [item * other for item in self])
        if self.column_count != other.row_count:
            raise MatrixShapeError(
                "cannot multiply %r x %r by %r x %r"
                % (self.row_count, self.column_count,
                   other.row_count, other.column_count))
        # Built once instead of once per row of self.
        other_columns = list(other.columns())
        # Cells are generated row by row, so the result has self's row count
        # and other's column count (these used to be swapped).
        return simplify(matrix(
            self.row_count, other.column_count,
            [inner_product(row, column)
             for row in self.rows() for column in other_columns]))

    def __rmul__(self, other):
        """Support ``k * m`` for a scalar ``k`` on the left."""
        return matrix(self.row_count, self.column_count,
                      [other * item for item in self])

    def __iter__(self):
        """Iterate over all cells in row-major order."""
        return iter(self.cells)

    def __eq__(self, other):
        """Matrices are equal when shapes and cells (as lists) match."""
        if not isinstance(other, matrix):
            # Let Python fall back instead of crashing on missing attributes.
            return NotImplemented
        return (self.column_count == other.column_count
                and self.row_count == other.row_count
                and list(self.cells) == list(other.cells))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, equal
        # matrices compared unequal with !=.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Keep hashing consistent with value equality (Python 2 otherwise
        # hashes by identity; Python 3 would make the class unhashable).
        return hash((self.row_count, self.column_count, tuple(self.cells)))


def simplify(item):
    """Collapse a 1x1 matrix to its only cell; return anything else as-is."""
    if (isinstance(item, matrix)
            and item.row_count == 1 and item.column_count == 1):
        return next(iter(item))
    return item


def vector(*values):
    """Return a column vector (``len(values)`` x 1) holding ``values``."""
    return matrix(len(values), 1, values)


if __name__ == "__main__":
    print(matrix(3, 1, (1, 2, 3)) + matrix(3, 1, (4, 5, 6)))
    print(matrix(3, 1, (1, 2, 3)) * 3)
    print(matrix(3, 1, (1, 2, 3)) - matrix(3, 1, (4, 5, 6)))
    print(matrix(3, 1, (1, 2, 3)) * matrix(1, 3, (4, 5, 6)))
    print(vector(1, 2, 3))
    a = matrix(2, 2, (1, 2, 3, 4))
    b = matrix(2, 2, (5, 6, 7, 8))
    print(a * b)
    assert a * b == matrix(2, 2, (19, 22, 43, 50))
    assert (matrix(1, 3, (1, 2, 3)) * matrix(3, 2, (1, 0, 0, 1, 1, 1))
            == matrix(1, 2, (4, 5)))