#ifndef __SAhTREE_TREE_H_INCLUDED
#define __SAhTREE_TREE_H_INCLUDED

#include "Kdtree.h"
#include "Box.hxx"
#include <math.h>
#include <list>
#include <time.h>
#include <algorithm>

/* SAH (Surface Area Heuristic) Kd-tree implementation.
 *
 *
 * Splitting location: Minimizing traversal/intersection cost
 * Splitting plane   : Minimizing traversal/intersection cost
 * Stop criterion    : Minimizing traversal/intersection cost
 * Construction time: O(N*log^2(N))
 *
 * Reference:
 * Ingo Wald and Vlastimil Havran. On building fast kd-trees for ray tracing, and on doing that in
 * O(N*log N). In Proceedings of the 2006 IEEE Symposium on Interactive Ray Tracing, 2006. (accepted
 * for publication, minor revision pending).
 *
 * Miguel Granados, Richard Socher
 * Core Lecture in Computer Graphics by Prof. Dr. Philipp Slusallek
 * University of Saarland
 * WS2006/2007
 */
class SAHKdtree : public Kdtree
{
	typedef Kdtree::Primitives Primitives;

	// traversal cost (K_T in the SAH cost model)
	float KT;

	// triangle intersection cost (K_I in the SAH cost model)
	float KI;

	// Surface area of the axis-aligned voxel V.
	inline float SA(const Box& V) const {
		return 2*V.dX()*V.dY() + 2*V.dX()*V.dZ() + 2*V.dY()*V.dZ();
	}

	// Probability of hitting the sub-voxel Vsub given that the voxel V was hit,
	// estimated by the surface-area ratio SA(Vsub) / SA(V).
	inline float P_Vsub_given_V(const Box& Vsub, const Box& V) const {
		float SA_Vsub = SA(Vsub);
		float SA_V = SA(V);
		return(SA_Vsub/SA_V);
	}

	// Bias for the cost function: 0.8 (favouring the split) when one side would
	// be empty (NL == 0 or NR == 0), 1.0 otherwise.
	// NOT IN PAPER: the bonus is withheld when one child has the full surface
	// area of the parent (PL or PR == 1, e.g. a plane on the voxel boundary).
	inline float lambda(int NL, int NR, float PL, float PR) const {
		if((NL == 0 || NR == 0) &&
		    !(PL == 1 || PR == 1) // NOT IN PAPER
		)
			return 0.8f;
		return 1.0f;
	}

	// Estimated cost of splitting a voxel into children that are hit with
	// probabilities PL / PR and contain NL / NR primitives:
	//   lambda * (KT + KI * (PL*NL + PR*NR))
	inline float C(float PL, float PR, int NL, int NR) const {
		return(lambda(NL, NR, PL, PR) * (KT + KI * (PL * NL + PR * NR)));
	}

	// Split voxel V with plane p into the left part VL (max along p.pk set to
	// p.pe) and the right part VR (min along p.pk set to p.pe).
	void splitBox(const Box& V, const SplitPlane& p, Box& VL, Box& VR) const {
		VL = V;
		VR = V;
		VL.setMax(p.pk, p.pe);
		VR.setMin(p.pk, p.pe);
		assert(V.contains(VL));
		assert(V.contains(VR));
	}

	// Side to which primitives lying exactly in the split plane are assigned.
	typedef enum { LEFT=-1, RIGHT=1, UNKNOWN=0 } PlaneSide;

	// SAH cost of splitting V with plane p, given NL primitives on the left,
	// NR on the right and NP lying in the plane. The planar primitives are tried
	// on each side; the cheaper cost is written to CP and its side to pside.
	// CP stays Infinity (and pside untouched) for degenerate splits: a child
	// with zero hit probability, or V having zero extent along the split axis.
	void SAH(const SplitPlane& p, const Box& V, int NL, int NR, int NP, float& CP, PlaneSide& pside) const {
		CP = Infinity;
		Box VL, VR;
		splitBox(V, p, VL, VR);
		float PL = P_Vsub_given_V(VL, V);
		float PR = P_Vsub_given_V(VR, V);
		if(PL == 0 || PR == 0) // NOT IN PAPER
			return;
		if(V.d(p.pk) == 0) // NOT IN PAPER
			return;
		float CPL = C(PL, PR, NL + NP, NR);	// planar primitives go left
		float CPR = C(PL, PR, NL, NP + NR);	// planar primitives go right
		if(CPL < CPR) {
			CP = CPL;
			pside = LEFT;
		} else {
			CP = CPR;
			pside = RIGHT;
		}
	}

	// Stop subdividing when the best split is more expensive than intersecting
	// all N primitives directly.
	inline bool terminate(int N, float minCv) const {
		return(minCv > KI*N);
	}

	// A sweep event: primitive `et` starts, ends, or lies in candidate plane `p`.
	struct Event {
		// The numeric order matters: at equal positions, events sort as
		// ending < lying < starting, which the sweep in findPlane relies on.
		typedef enum { endingOnPlane=0, lyingOnPlane=1, startingOnPlane=2  } EventType;
		Primitive* et;	// triangle
		SplitPlane p;
		EventType type;
		Event(Primitive* et0, int k, float ee0, EventType type0) :
			et(et0), p(k, ee0), type(type0) {
			assert(type == endingOnPlane || type == lyingOnPlane || type == startingOnPlane);
		}
		// Order by plane position, then by event type.
		inline bool operator<(const Event& e) const {
			return((p.pe < e.p.pe) || (p.pe == e.p.pe && type < e.type));
		}
	};

	// Bounds of primitive t (from CalcBounds) clipped to the voxel V.
	Box clipTriangleToBox(Primitive* t, const Box& V) const {
		Box b = t->CalcBounds();
		for(int k=0; k<3; k++) {
			if(V.min[k] > b.min[k])
				b.min[k] = V.min[k];
			if(V.max[k] < b.max[k])
				b.max[k] = V.max[k];
		}
		assert(V.contains(b));
		return b;
	}

	// Best splitting plane for the primitives T inside voxel V, found by a
	// sorted-event sweep along each axis. Outputs the plane (p_est), its
	// SAH cost (C_est, Infinity if no valid plane exists) and the side for the
	// planar primitives (pside_est, only written when a plane is found).
	void findPlane(const vector<Primitive *>& T, const Box& V, int depth,
			SplitPlane& p_est, float& C_est, PlaneSide& pside_est) const {
		C_est = Infinity;

		// The clipped bounds do not depend on the sweep axis: compute them once.
		vector<Box> clipped;
		clipped.reserve(T.size());
		for(Primitives::const_iterator pit = T.begin(); pit != T.end(); ++pit)
			clipped.push_back(clipTriangleToBox(*pit, V));

		vector<Event> events;
		events.reserve(T.size()*2);
		for(int k=0; k<3; ++k) {
			// Generate events along axis k. A primitive is planar for this sweep
			// iff its clipped extent along k (not along some other axis) is zero,
			// matching the test used in splitTrianglesIntoVoxels.
			events.clear();
			for(vector<Box>::size_type i = 0; i < clipped.size(); ++i) {
				Primitive* t = T[i];
				const Box& B = clipped[i];
				if(B.min[k] == B.max[k]) {
					events.push_back(Event(t, k, B.min[k], Event::lyingOnPlane));
				} else {
					events.push_back(Event(t, k, B.min[k], Event::startingOnPlane));
					events.push_back(Event(t, k, B.max[k], Event::endingOnPlane));
				}
			}
			sort(events.begin(), events.end());

			// Sweep over the distinct plane positions. Ei is advanced ONLY by the
			// inner loops; each iteration consumes at least one event (the one p
			// was taken from), so the loop always makes progress.
			int NL = 0, NP = 0, NR = static_cast<int>(T.size());
			vector<Event>::size_type Ei = 0;
			while(Ei < events.size()) {
				const SplitPlane p = events[Ei].p;
				int pLyingOnPlane = 0, pStartingOnPlane = 0, pEndingOnPlane = 0;
				while(Ei < events.size() && events[Ei].p.pe == p.pe && events[Ei].type == Event::endingOnPlane) {
					++pEndingOnPlane;
					++Ei;
				}
				while(Ei < events.size() && events[Ei].p.pe == p.pe && events[Ei].type == Event::lyingOnPlane) {
					++pLyingOnPlane;
					++Ei;
				}
				while(Ei < events.size() && events[Ei].p.pe == p.pe && events[Ei].type == Event::startingOnPlane) {
					++pStartingOnPlane;
					++Ei;
				}
				// Primitives lying in p or ending at p are no longer right of it.
				NP = pLyingOnPlane;
				NR -= pLyingOnPlane;
				NR -= pEndingOnPlane;
				float cost;
				PlaneSide pside = UNKNOWN;
				SAH(p, V, NL, NR, NP, cost, pside);
				if(cost < C_est) {
					C_est = cost;
					p_est = p;
					pside_est = pside;
				}
				// Primitives starting at or lying in p are left of every later plane.
				NL += pStartingOnPlane;
				NL += pLyingOnPlane;
				NP = 0;
			}
		}
	}

	// Sort primitives into the left (TL) and right (TR) children of plane p.
	// Primitives straddling p go to both sides; primitives lying exactly in p
	// go to the side chosen by the SAH (pside).
	void splitTrianglesIntoVoxels(const Primitives& T, const SplitPlane& p, const PlaneSide& pside, Primitives& TL, Primitives& TR) const {
		for(Primitives::const_iterator pit = T.begin(); pit != T.end(); pit++) {
			Primitive* t = *pit;
			Box tbox = t->CalcBounds();
			if(tbox.min[p.pk] == p.pe && tbox.max[p.pk] == p.pe) {
				if(pside == LEFT)
					TL.push_back(t);
				else if(pside == RIGHT)
					TR.push_back(t);
				else
					assert(false); // wrong pside
			} else {
				if(tbox.min[p.pk] < p.pe)
					TL.push_back(t);
				if(tbox.max[p.pk] > p.pe)
					TR.push_back(t);
			}
		}
	}

	int maxdepth; // DEBUG ONLY
	int nnodes; // DEBUG ONLY

	// Recursively build the subtree for primitives T in voxel V. prev_p is the
	// plane that produced V (a default SplitPlane at the root).
	Node *RecBuild(vector<Primitive *> T, Box &V, int depth, const SplitPlane& prev_p)
	{
		// Protection for when the stopping criterion fails: aborts in debug
		// builds; since assert() is compiled out under NDEBUG, release builds
		// fall back to emitting a leaf at this depth instead of recursing on.
		const int maxRecursionDepth = 100;
		assert(depth < maxRecursionDepth);

		++nnodes; // DEBUG ONLY
		if(depth > maxdepth) maxdepth = depth; // DEBUG ONLY

		if(depth >= maxRecursionDepth)
			return new LeafNode(T, V);

		SplitPlane p;
		float Cp;
		PlaneSide pside = UNKNOWN;
		findPlane(T, V, depth, p, Cp, pside);
		if(terminate(T.size(), Cp)
			|| p == prev_p) // NOT IN PAPER: re-splitting at the parent's plane makes no progress
		{
			return new LeafNode(T, V);
		}
		Box VL, VR;
		splitBox(V, p, VL, VR); // TODO: avoid doing this step twice
		vector<Primitive *> TL, TR;
		splitTrianglesIntoVoxels(T, p, pside, TL, TR);
		// T is not needed any more: release its storage before recursing to
		// lower peak memory on deep trees.
		vector<Primitive *>().swap(T);
		// Build the children in a fixed order (left, then right).
		Node* left = RecBuild(TL, VL, depth+1, p);
		vector<Primitive *>().swap(TL);
		Node* right = RecBuild(TR, VR, depth+1, p);
		return new InnerNode(p, V, left, right);
	}

public:
	// Default-constructed tree: uses the default cost constants and zeroed
	// debug counters; no tree is built.
	SAHKdtree() : KT(1.0f), KI(1.5f), maxdepth(0), nnodes(0) {}

	// Build an SAH kd-tree over `primitives` inside `topBox`.
	// traversalCost / intersectionCost are the K_T / K_I constants of the cost
	// model; their defaults (1.0, 1.5) are the values previously hard-coded.
	SAHKdtree(Box topBox, vector<Primitive *> primitives,
			float traversalCost = 1.0f, float intersectionCost = 1.5f)
		: KT(traversalCost), KI(intersectionCost), maxdepth(0), nnodes(0)
	{
		bbox = topBox;
		cerr << "SAHKdtree: KT=" << KT << ", KI=" << KI << endl;
		clock_t t_before = clock();
		root = RecBuild(primitives, topBox, 0, SplitPlane());
		clock_t t_after = clock();
		cerr << "SAHKdtree: nnodes=" << nnodes << ", maxdepth=" << maxdepth << endl;
		cerr << "SAHKdtree: construction time=" << (t_after - t_before) / (float)CLOCKS_PER_SEC << "s" << endl;
	}
};

#endif

