#ifndef __KDTREE_TREE_H_INCLUDED
#define __KDTREE_TREE_H_INCLUDED

#include "RayIntersectionAccelerator.hxx"
#include "Box.hxx"
#include <time.h>


/* Naive Kd-tree implementation.
 * 
 * Splitting location: Middle point of the bounding box. 
 * Splitting plane   : Round robin over the x, y and z axes.
 * Stopping criteria : Maximum depth and minimum number of primitives.
 * Construction time: O(N*log(N))
 * 
 * Miguel Granados
 * Core Lecture in Computer Graphics by Prof. Dr. Philipp Slusallek
 * University of Saarland
 * WS2006/2007 
 */
class Kdtree : public RayIntersectionAccelerator
{
protected: 
	struct SplitPlane {		
		short pk;	// splitting dimension		
		float pe;	// splitting point
		
		SplitPlane(short pk0=-1, float pe0=0) {
			pk = pk0;
			pe = pe0;
		}
		
		bool operator==(const SplitPlane& sp) {
			return(pk == sp.pk && pe == sp.pe);
		}		
	};
		
	struct Node
	{
		virtual void traverse(Ray &ray, float t_min, float t_max) = 0; //Note: the result is stored in ray! (in ray.t and ray.hit)
		virtual void debug() const = 0;
		virtual int depth(int d) const = 0;
		virtual ~Node() {}
	};

	struct InnerNode : public Node
	{
		SplitPlane p;		
		Node *leftChild, *rightChild;
		
		InnerNode(const SplitPlane& p0, const Box& V0, Node* lc, Node* rc) :
			p(p0), leftChild(lc), rightChild(rc) {
		}
		
		virtual void traverse(Ray &ray, float t_min, float t_max)
		{	
			//float t_split = distanceAlongRayToPlane(ray);			
			float t_split = (p.pe - ray.org[p.pk]) * (ray.dir[p.pk] == 0 ? Infinity : 1/ray.dir[p.pk]);
			
			// near is the side containing the origin of the ray
			Node *near, *far;
			if(ray.org[p.pk] < p.pe) {
				near = leftChild;
				far = rightChild;
			} else {
				near = rightChild;
				far = leftChild;
			}
			
			if( t_split > t_max || t_split < 0) {
			    near->traverse(ray, t_min, t_max);
			}
			else if(t_split < t_min) {
			    far->traverse(ray, t_min, t_max);
			}
			else {
			    near->traverse(ray, t_min, t_split);
			    if(ray.t < t_split) 
			    	return;
			    return far->traverse(ray, t_split, t_max);
			}
		}
				
		
		virtual int depth(int d) const {
			return max(leftChild->depth(d+1), rightChild->depth(d+1));			
		}
		
		virtual void debug() const {
			cerr << "(";  
			leftChild->debug(); 
			cerr << ", "; 
			rightChild->debug();
			cerr << ") "; 
		}
		
		virtual ~InnerNode() {};
	};

	struct LeafNode : public Node
	{
		vector<Primitive *> T;
		Box V;

		LeafNode(vector<Primitive *>& T0, const Box& V0) : T(T0) {
			V = V0;
		}

		virtual void traverseBVH(Ray &ray)
		{
			for (vector<Primitive*>::size_type i=0; i<T.size(); i++)
				T[i]->Intersect(ray);
		}

		virtual void traverse(Ray &ray, float t_min, float t_max)
		{
			for (vector<Primitive*>::size_type i=0; i<T.size(); i++)
				T[i]->Intersect(ray);
		}
				
		virtual void debug() const {
			cerr << T.size(); 
		}
		
		virtual int depth(int d) const {
			return d;
		}		

		virtual ~LeafNode() {};
	};

	Node *root;

	typedef vector<Primitive*> Primitives;
	
	Box bbox;
	
private:
	unsigned int KmaxDepth, KtriTarget;	

	bool terminate(const vector<Primitive *>& T, unsigned int depth) {
		return (T.size() <= KtriTarget || depth >= KmaxDepth); 
	}

	SplitPlane findPlane(const vector<Primitive *>& T, const Box& V, int depth) {
		short pk = depth % 3;
		float pe = (V.min[pk] + V.max[pk]) / 2;
		return SplitPlane(pk, pe);
	}
	
	void splitVoxelWithPlane(const Box& V, const SplitPlane& p, Box& VL, Box& VR) {
		VL = VR = V;
		VL.max[p.pk] = VR.min[p.pk] = p.pe;
	}
	
	void splitTrianglesIntoVoxels(const Primitives& T, const SplitPlane& p, Primitives& TL, Primitives& TR) {
		for(Primitives::const_iterator pit = T.begin(); pit != T.end(); pit++) {
			Primitive* t = *pit;
			Box tbox = t->CalcBounds();
			if(tbox.min[p.pk] <= p.pe)
				TL.push_back(t);  
			if(tbox.max[p.pk] >= p.pe)
				TR.push_back(t);  
		} 
	}

	Node *RecBuild(vector<Primitive *> T, Box &V, int depth)
	{
		//cerr << "Recbuild: |T|=" << T.size() << ", Vmin=" << V.min << ", Vmax=" << V.max << "), depth=" << depth << endl;
		if (terminate(T, depth)) {
			//cerr << "Recbuild: new leaf node" << endl;
			return new LeafNode(T, V);
		}
		SplitPlane p = findPlane(T, V, depth);
		Box VL, VR;
		splitVoxelWithPlane(V, p, VL, VR);
		vector<Primitive *> TL, TR;
		splitTrianglesIntoVoxels(T, p, TL, TR);
		return new InnerNode(p, V, RecBuild(TL, VL, depth+1), RecBuild(TR, VR, depth+1));
	}

public:
	Kdtree() {}
	
	Kdtree(Box topBox, vector<Primitive *> primitives)
	{
		bbox = topBox;
		KmaxDepth = 24;
		KtriTarget = 3;
		cerr << "Kdtree: KmaxDepth=" << KmaxDepth << ", KtriTarget=" << KtriTarget << endl;
		clock_t t_before = clock();
		root = RecBuild(primitives, topBox, 0);
		clock_t t_after = clock();
		cerr << "Kdtree: construction time=" << (t_after - t_before) / (float)CLOCKS_PER_SEC << endl;
		// debug(); cerr << endl;		
	}
	
	bool Intersect(Ray &ray) 
	{
		std::pair<float, float> t = bbox.Intersect(ray);
		root->traverse(ray, t.first, t.second);
		return ray.hit != NULL;
	}	
	
	void debug() const {
		root->debug();
	}
};

#endif

