#!/usr/bin/env python
# -*- coding: utf-8 -*-
# I, Danny Milosavljevic, place this file in the public domain.

# Nary

def simple_repr(value):
	"""Render *value* as a byte string for use inside node representations.

	unicode text is encoded to UTF-8, a str is passed through untouched,
	and every other value is rendered with repr().
	"""
	if isinstance(value, unicode):
		return value.encode("utf-8")
	if isinstance(value, str):
		return value
	return repr(value)

class Node(object):
	"""Base class for n-ary expression nodes.

	Subclasses set a class attribute ``name`` (the operator symbol printed by
	``__repr__``). They may also set ``fallback_operand``, which becomes the
	sole operand when the node is constructed with an empty operand list.
	"""
	def __init__(self, operands):
		# operands: sequence of child expressions, stored as-is when non-empty.
		if len(operands) > 0:
			self.operands = operands
		elif hasattr(self.__class__, "fallback_operand"):
			# An empty operation collapses to its fallback (e.g. Addition's 0).
			self.operands = [self.__class__.fallback_operand] # at the "end".
		else:
			self.operands = [] # ??? no fallback defined: meaning of an empty node is unclear

	def __repr__(self):
		# Subclasses without an operator symbol fall back to their class name
		# instead of raising AttributeError.
		name = getattr(self.__class__, "name", self.__class__.__name__)
		return "(%s %s)" % (name, " ".join(map(simple_repr, self.operands)))

	def __add__(self, other):
		# Build (+ self other). The sum must include self; `other` is an
		# operand, not an operand list.
		return Addition([self, other])

	def __sub__(self, other):
		# a - b is represented as a + (-b), with the negation simplified.
		return self + negate(other)

class UnaryNode(Node):
	"""Node holding exactly one operand.

	The operand may be given bare or wrapped in a one-element list.
	"""
	def __init__(self, operand):
		if not isinstance(operand, list):
			operand = [operand]
		else:
			assert(len(operand) == 1) # a list argument must carry exactly one operand
		Node.__init__(self, operand)

class Power(Node): # well, not exactly... TODO swap base, exponent?
	"""Power node printed with "^"; an empty Power holds the single operand 1."""
	name = "^"
	fallback_operand = 1 # at the end.

class Product(Node):
	"""Multiplication node printed with the dot operator; empty products hold 1."""
	name = "⋅" # U+22C5 DOT OPERATOR (was mis-encoded UTF-8 read as cp1252)
	fallback_operand = 1

class InnerProduct(Node):
	"""Inner-product node printed with "∙"; with no operands it holds 1."""
	fallback_operand = 1
	name = "∙"

class CrossProduct(Node):
	"""Cross-product node printed with "⨯"; with no operands it holds 1."""
	fallback_operand = 1
	name = "⨯"

class Convolution(Node):
	"""Convolution node printed with "*"; with no operands it holds 1."""
	fallback_operand = 1
	name = "*"

class Addition(Node):
	"""N-ary sum printed with "+"; an empty sum holds the single operand 0."""
	fallback_operand = 0
	name = "+"


class Andation(Node):
	"""Conjunction node printed with "&"; an empty conjunction holds True."""
	fallback_operand = True
	name = "&"

class Oration(Node):
	"""Disjunction node printed with "|"; an empty disjunction holds False."""
	fallback_operand = False
	name = "|"

class Relation(Node): # TODO does it even work here?
	"""Common base class of the relation nodes (=, ≤, <, ≥, >, ≠).

	Defines no operator symbol of its own; concrete relations set ``name``.
	"""

class Equation(Relation):
	"""Equality relation, printed with "="."""
	name = "="

class LessOrEqualRelation(Relation):
	"""Less-than-or-equal relation, printed with "≤"."""
	name = "≤"

class LessRelation(Relation):
	"""Strict less-than relation, printed with "<"."""
	name = "<"

class GreaterOrEqualRelation(Relation):
	"""Greater-than-or-equal relation, printed with "≥"."""
	name = "≥"

class GreaterRelation(Relation):
	"""Strict greater-than relation, printed with ">"."""
	name = ">"

class NotEqualRelation(Relation):
	"""Inequality relation, printed with "≠"."""
	name = "≠"

class Notation(UnaryNode):
	"""Logical negation of a single operand, printed with "¬"."""
	name = "¬"

class Negation(UnaryNode):
	"""Negation of a single operand, printed as "(0- x)"."""
	name = "0-"

	def __cmp__(self, other):
		# Python 2 three-way comparison: two negations compare by their operands.
		if isinstance(other, Negation):
			return cmp(self.operands[0], other.operands[0])
		else:
			return -1 # FIXME arbitrary ordering against anything that is not a Negation

	def __hash__(self):
		# Negations that compare equal (equal operands) must hash equally,
		# otherwise sets and dict keys treat them as distinct.
		# Assumes the operand itself is hashable.
		return hash((Negation, self.operands[0]))

class Reciprocation(UnaryNode):
	"""Reciprocal of a single operand, printed as "(1/ x)"."""
	name = "1/"

	def __cmp__(self, other):
		# Python 2 three-way comparison. Only reciprocals are compared by operand.
		# Other values (e.g. the literal 0 that simplify() compares against)
		# have no .operands and must not be equated with unrelated nodes such
		# as Negation.
		if isinstance(other, Reciprocation):
			return cmp(self.operands[0], other.operands[0])
		else:
			return -1 # FIXME arbitrary ordering against anything that is not a Reciprocation

	def __hash__(self):
		# Keep hashing consistent with __cmp__ (equal operands => equal hash).
		# Assumes the operand itself is hashable.
		return hash((Reciprocation, self.operands[0]))

class FunctionApplication(Node): # note: factor()
	"""Application node: operands[0] is printed as the head, the rest as arguments."""
	def __repr__(self):
		head, arguments = self.operands[0], self.operands[1 : ]
		return "(%s %s)" % (simple_repr(head), " ".join(map(simple_repr, arguments)))

class Antiderivative(UnaryNode):
	"""Antiderivative of a single operand, printed with "∫"."""
	name = "∫"

class Gradient(UnaryNode):
	"""Gradient of a single operand, printed with "∇"."""
	name = "∇"

class UnarySuffixNode(UnaryNode):
	"""Marker base class for unary operators meant to be written after the operand.

	Adds no behavior; __repr__ still prints the operator first.
	"""

class Factorial(UnarySuffixNode):
	"""Factorial of a single operand, printed with "!"."""
	name = "!"

"""
Algebraic properties ("?" stands for the binary operator in question):
commutative: 3?2=2?3.
associative: a?(b?c)=(a?b)?c.
distributive: a⋅(b+c)=a⋅b+a⋅c.
"""

def equation_P(expression):
	"""Return True if *expression* is an Equation node."""
	return isinstance(expression, Equation)

def addition_P(expression):
	"""Return True if *expression* is an Addition node."""
	return isinstance(expression, Addition)

def relation_P(expression):
	"""Return True if *expression* is any Relation node (including Equation)."""
	return isinstance(expression, Relation)

def negation_P(expression):
	"""Return True if *expression* is a Negation node."""
	return isinstance(expression, Negation)

def negate(expression):
	"""Return the simplified negation of *expression*."""
	negated = Negation(expression)
	return simplify(negated)

def power_P(expression):
	"""Return True if *expression* is a Power node."""
	return isinstance(expression, Power)

def simplify(expression):
	"""Return a simplified form of *expression*; the input tree is not modified.

	Handled cases:
	  * relations (=, <, ≤, ...): operands are simplified and the concrete
	    relation class is kept;
	  * additions: a single-operand sum is unwrapped, nested sums are flattened;
	  * negations: -0 -> 0, --x -> x, -(a+b) -> (-a)+(-b).
	Any other expression is returned unchanged (not recursed into).
	"""
	# e.g. (+ (+ 5 2) ∙ 2 3)
	if relation_P(expression):
		# Rebuild with the same class so <, ≤, ≠ ... keep their operator
		# (this also covers Equation).
		return expression.__class__([simplify(item) for item in expression.operands])

	if addition_P(expression):
		if len(expression.operands) == 1:
			return simplify(expression.operands[0])

		new_operands = []
		for arg in expression.operands:
			arg = simplify(arg)
			if addition_P(arg):
				# Flatten (+ a (+ b c)) into (+ a b c).
				new_operands.extend(arg.operands)
			else:
				new_operands.append(arg)
		return Addition(new_operands)

	if negation_P(expression):
		arg = simplify(expression.operands[0])
		if arg == 0:
			return arg
		if negation_P(arg):
			return simplify(arg.operands[0])
		if addition_P(arg):
			# Distribute the negation over the sum, e.g. (- (+ (- 1) x)).
			return simplify(Addition([negate(operand) for operand in arg.operands]))
		return Negation(arg)

	return expression

def simplify_all(expression):
	"""Fully simplify *expression*.

	FIXME: currently just delegates to a single simplify() pass.
	"""
	simplified = simplify(expression)
	return simplified

def contains_variable_P(for_variable, expression):
	"""Return whether *for_variable* occurs anywhere in *expression*.

	Anything with an ``operands`` attribute is searched recursively. Any
	other value is a leaf and matches when it compares equal to
	*for_variable*. A node with no matching operand yields False rather
	than None.
	"""
	if hasattr(expression, "operands"):
		return any(contains_variable_P(for_variable, thing) for thing in expression.operands)
	return expression == for_variable # FIXME recognize dot etc

# addition_P is defined once, earlier in this module, next to the other
# predicates used by simplify(); an identical redefinition here was redundant.

def node_P(expression):
	"""Return True if *expression* is any expression Node."""
	return isinstance(expression, Node)

def product_P(expression):
	"""Return True if *expression* is a Product node."""
	return isinstance(expression, Product)

def cross_product_P(expression):
	"""Return True if *expression* is a CrossProduct node."""
	return isinstance(expression, CrossProduct)

def inner_product_P(expression):
	"""Return True if *expression* is an InnerProduct node."""
	return isinstance(expression, InnerProduct)

def convolution_P(expression):
	"""Return True if *expression* is a Convolution node."""
	return isinstance(expression, Convolution)

def function_application_P(expression):
	"""Return True if *expression* is a FunctionApplication node."""
	return isinstance(expression, FunctionApplication)

def gradient_P(expression):
	"""Return True if *expression* is a Gradient node."""
	return isinstance(expression, Gradient)

def reciprocation_P(expression):
	"""Return True if *expression* is a Reciprocation node."""
	return isinstance(expression, Reciprocation)

def antiderivative_P(expression):
	"""Return True if *expression* is an Antiderivative node."""
	return isinstance(expression, Antiderivative)

def factorial_P(expression):
	"""Return True if *expression* is a Factorial node."""
	return isinstance(expression, Factorial)

def unary_suffix_node_P(expression):
	"""Return True if *expression* is a suffix-style unary node (e.g. Factorial)."""
	return isinstance(expression, UnarySuffixNode)

def operation_P(expression):
	"""Return True unless *expression* is an atom (a str, unicode or int).

	Meant to cover any of the node kinds tested by the predicates above.
	NOTE(review): floats and longs are not treated as atoms here and are
	therefore reported as operations; confirm that is intended.
	"""
	# The checks run in the same order as the original `and` chain.
	if isinstance(expression, str):
		return False
	if isinstance(expression, unicode):
		return False
	return not isinstance(expression, int)