import warnings

class storedItem():
	"""A payload pinned to one coordinate inside the octree.

	``coord`` is the coordinate key compared (with ==) during lookups,
	``data`` is the payload handed back by searches, and ``node`` is the
	tree node whose ``children`` list currently contains this item (None
	until the item has been inserted).
	"""
	def __init__(self, coord, data):
		self.coord, self.data = coord, data
		self.node = None  # filled in by the node that ends up storing this item

	def remove(self):
		"""Unlink this item from its owning node, then forget that node."""
		siblings = self.node.children
		siblings.remove(self)
		self.node = None

class Node():
	"""One axis-aligned box of the octree, with inclusive integer bounds.

	A node is either a leaf, whose ``children`` holds up to 8 stored items
	(objects with ``coord``/``data``/``node`` attributes), or an internal
	node, whose ``children`` holds exactly 8 sub-``Node`` boxes that together
	cover this node's box.
	"""
	def __init__(self,parent,lower,upper):
		self.parent = parent        # enclosing Node, or None for the root
		self.children = []          # stored items (leaf) or 8 sub-Nodes (internal)
		self.lowerbound = lower     # inclusive (x, y, z) lower corner
		self.upperbound = upper     # inclusive (x, y, z) upper corner
		self.setVolume()

	def setVolume(self):
		"""Set ``self.volume`` to the product of (upper - lower) over the axes.

		Because bounds are inclusive this is NOT the number of cells (a box
		from 0 to 2 on every axis holds 27 cells but has volume 8). It is kept
		for backward compatibility and is not used for splitting decisions.
		"""
		dx = self.upperbound[0] - self.lowerbound[0]
		dy = self.upperbound[1] - self.lowerbound[1]
		dz = self.upperbound[2] - self.lowerbound[2]
		self.volume = dx*dy*dz

	def inbound(self,coord):
		"""Return True if ``coord`` lies inside this box (both bounds inclusive)."""
		return all(self.lowerbound[i] <= coord[i] <= self.upperbound[i] for i in range(3))

	def returnValueOrChildnode(self,coord):
		"""Take one search step toward ``coord``.

		Returns:
			``self.parent`` if ``coord`` is outside this box; the sub-Node
			containing ``coord``; the stored item whose ``coord`` equals
			``coord``; or None if no such item is held here.
		"""
		if not self.inbound(coord):
			return self.parent
		for child in self.children:
			if isinstance(child, Node):
				if child.inbound(coord):
					return child
			elif child.coord == coord:
				return child
		return None

	def deleteOrReturnChildNode(self,coord):
		"""Take one removal step toward ``coord``.

		Returns:
			``self.parent`` if ``coord`` is outside this box; the sub-Node
			containing ``coord``; True after deleting the item stored at
			``coord`` from this node; or None if no such item is held here.
		"""
		if not self.inbound(coord):
			return self.parent
		for index, child in enumerate(self.children):
			if isinstance(child, Node):
				if child.inbound(coord):
					return child
			elif child.coord == coord:
				del self.children[index]
				child.node = None  # the item no longer belongs to any node
				return True
		return None

	def insertStoredItem(self,item):
		"""Insert ``item`` into the subtree rooted at this node.

		An existing item at the same coordinate is replaced (with a warning).
		A leaf that already holds 8 items is split via breakupIntoChildren
		before the insert is retried.

		Returns:
			True once the item has been stored.

		Raises:
			Exception: if this is an internal node and none of its sub-Nodes
				contains ``item.coord``, or if breakupIntoChildren cannot split
				this leaf.
		"""
		for index, child in enumerate(self.children):
			if isinstance(child, Node):
				if child.inbound(item.coord):
					return child.insertStoredItem(item)
			elif child.coord == item.coord:
				warnings.warn('Already an item at this location, replacing it')
				child.node = None
				self.children[index] = item
				item.node = self
				return True

		if self.children and isinstance(self.children[0], Node):
			# Internal node, yet no sub-Node covers the coordinate.
			raise Exception("Coordinate outside scope of this node")

		if len(self.children) >= 8:
			self.breakupIntoChildren()
			return self.insertStoredItem(item)

		self.children.append(item)
		item.node = self
		return True

	def breakupIntoChildren(self):
		"""Split this leaf into 8 sub-Nodes and move its items into them.

		Each axis is split independently: an axis spanning ``w`` cells is cut
		into ``w // 2`` and ``w - w // 2`` cells. An axis one cell wide yields
		an empty lower half, so every sub-Node is strictly smaller than this
		box as long as at least one axis is two or more cells wide.

		Raises:
			Exception: if this box is a single cell and cannot be split.
		"""
		lower, upper = self.lowerbound, self.upperbound
		widths = [upper[axis] - lower[axis] + 1 for axis in range(3)]
		if max(widths) <= 1:
			raise Exception("Node full. Cannot add to this node")

		# Per axis, the two inclusive (low, high) halves. Integer division keeps
		# split points on integer coordinates ("/" yields floats on Python 3,
		# which would leave uncovered gaps between sibling boxes).
		halves = []
		for axis in range(3):
			mid = lower[axis] + widths[axis] // 2 - 1
			halves.append(((lower[axis], mid), (mid + 1, upper[axis])))

		nodes = []
		for (ylo, yhi) in halves[1]:
			for (zlo, zhi) in halves[2]:
				for (xlo, xhi) in halves[0]:
					nodes.append(Node(self, (xlo, ylo, zlo), (xhi, yhi, zhi)))

		items = self.children
		self.children = nodes
		# Re-insert through self so each item is routed to, and has its
		# ``node`` updated to, the sub-Node that contains it.
		for item in items:
			self.insertStoredItem(item)



class Octree():
	"""Sparse point store over the integer cube (0,0,0)..(size,size,size).

	Bounds are inclusive on both ends (see Node.inbound), so coordinates
	0 through ``size`` are valid on every axis.

	Args:
		size: edge coordinate of the cube; must be even.
		maxsearch: maximum number of descent steps in search/remove, a guard
			against looping forever in a malformed tree.
	"""
	def __init__(self,size,maxsearch=1000):
		if size % 2:
			raise Exception("Size must be multiple of 2")
		self.root = Node(None, (0,0,0),(size,size,size))
		self.size = size
		self.maxsearch=maxsearch

	def _walk(self, coord, stepName):
		"""Apply Node method ``stepName`` from the root until it yields a non-Node.

		Returns:
			The first result of the step method that is not a Node.

		Raises:
			Exception: if more than ``self.maxsearch`` steps would be needed.
		"""
		node = self.root
		for _ in range(self.maxsearch):
			result = getattr(node, stepName)(coord)
			if not isinstance(result, Node):
				return result
			node = result
		raise Exception("Max Search depth limit reached")

	def search(self,coord):
		"""Return the data stored at ``coord``, or None if nothing is stored there.

		Raises:
			Exception: if the descent exceeds ``maxsearch`` steps.
		"""
		result = self._walk(coord, 'returnValueOrChildnode')
		if isinstance(result, storedItem):
			return result.data
		return None

	def insert(self,coord,data):
		"""Wrap ``data`` in a storedItem and insert it at ``coord``.

		Raises:
			Exception: if ``coord`` lies outside the octree's cube.
		"""
		if not self.root.inbound(coord):
			# Put the diagnostic details in the exception rather than printing
			# them to stdout, so callers decide how to report the error.
			raise Exception("Coordinate outside scope of octree: %r not within %r..%r"
				% (coord, self.root.lowerbound, self.root.upperbound))
		self.root.insertStoredItem(storedItem(coord,data))

	def remove(self,coord):
		"""Delete the item stored at ``coord``.

		Returns:
			True if an item was deleted, None if no item was found.

		Raises:
			Exception: if the descent exceeds ``maxsearch`` steps.
		"""
		return self._walk(coord, 'deleteOrReturnChildNode')
	

## ---------------------------------------------------------------------------------------------------##


if __name__ == "__main__":
	# Smoke test / micro-benchmark: insert random objects, then look each one up.
	import random
	import time

	class TestObject:
		"""Dummy payload carrying a name and the position it was stored at."""
		def __init__(self, name, position):
			self.name = name
			self.position = position

		def __repr__(self):
			return "TestObject(%r, %r)" % (self.name, self.position)

	# Create a new octree spanning the test world.
	treesize = 8
	myTree = Octree(treesize)
	mult = 2

	# Number of objects we intend to add.
	NUM_TEST_OBJECTS = treesize * mult

	poses = []

	### Object Insertion Test ###
	Start = time.time()
	for x in range(NUM_TEST_OBJECTS):
		name = "Node__" + str(x)
		pos = (random.randrange(0, treesize), random.randrange(0, treesize), random.randrange(0, treesize))
		poses.append(pos)
		testOb = TestObject(name, pos)
		myTree.insert(pos,testOb)
	End = time.time() - Start

	print(str(NUM_TEST_OBJECTS) + "-Node Tree Generated in " + str(End) + " Seconds")

	### Lookup Tests ###
	# Look up every inserted position again. Printing each result shows that
	# lookups succeed; for large trees this is a long printout.
	Start = time.time()
	for pos in poses:
		result = myTree.search(pos)
		if result is not None:
			print(result)
		else:
			print("none found")
	End = time.time() - Start

	print("lookups in " + str(End) + " Seconds")

