/*
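Multiway trie class for voxel data, by M Phillips (2007).

Template parameters:
	T - item type; the trie stores T* pointers and deletes any items still
	    present when it is cleared or destroyed. Items returned by Insert or
	    Remove are no longer referenced by the trie, so the caller must delete them.
	K - branching factor (number of children per node)
	D - extent of each axis; valid coordinates are 0 .. D-1

Example declaration:
	multiway_trie_3d<Particle, 32, 1000> myTrie;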

This code is provided as is with no warranties or guarantees of 
any kind.

Please send an email to M Phillips (mbp2@i4free.co.nz)
  - if you use this file in a released product, or
  - if you find any bugs, or
  - if you have any suggestions
*/

#ifndef MULTIWAYTRIE3D_H
#define MULTIWAYTRIE3D_H

#include <assert.h>
#include <stddef.h>

// Compile-time ceil(log_b(n)) for n >= 1, b >= 2: the number of base-b digits
// needed to address n distinct values (0 when n == 1).
template <unsigned n, unsigned b>
struct ceilLogb {
	// (n - 1) / b + 1 equals ceil(n / b) for n >= 1 and, unlike the form
	// (n + b - 1) / b, cannot wrap around when n is close to UINT_MAX.
	enum { ret = ceilLogb<(n - 1) / b + 1, b>::ret + 1 };
};
template <unsigned b>
struct ceilLogb<1, b> {
	enum { ret = 0 };
};
// Deliberately left undefined: an axis of extent 0 is meaningless, so using
// ceilLogb<0, b> (e.g. multiway_trie_3d<T, K, 0>) is a compile-time error.
template <unsigned b>
struct ceilLogb<0, b>;

// Grid of T* items addressed by (x, y, z), each coordinate in [0, D).
// Every coordinate is split into DEPTH base-K digits (least significant first);
// the x digits, then the y digits, then the z digits pick one child per level,
// and the final z digit indexes the item array of a leaf node. Nodes are
// allocated on demand and freed as soon as they no longer hold any item.
//
// Ownership: the trie deletes every item still stored when it is cleared or
// destroyed. Items returned by Insert or Remove are no longer referenced by
// the trie and must be deleted by the caller. The trie is not copyable.
template <class T, unsigned K, unsigned D>
class multiway_trie_3d {
public:
	multiway_trie_3d() : head(NULL) {}

	~multiway_trie_3d() {
		delete head;
	}

	// Stores val at (x, y, z) and returns the item previously stored there
	// (NULL if the slot was empty). Passing val == NULL empties the slot,
	// exactly like Remove.
	T* Insert(unsigned x, unsigned y, unsigned z, T* val) {
		assert(x < D && y < D && z < D);
		// A NULL must never be counted as an item; Remove keeps the per-node
		// counts exact and frees any node that becomes empty.
		if (val == NULL)
			return Remove(x, y, z);

		unsigned path[LEVELS];
		MakePath(x, y, z, path);
		node** trail[LEVELS];
		Descend(path, trail, true);

		node* leaf = *trail[LEVELS - 1];
		T* oldData = leaf->data[path[LEVELS - 1]];
		leaf->data[path[LEVELS - 1]] = val;

		// A newly occupied slot adds one item to every node on its path.
		if (oldData == NULL) {
			for (unsigned level = 0; level < LEVELS; ++level)
				++(*trail[level])->count;
		}
		return oldData;
	}

	// Returns the item currently stored at (x, y, z), or NULL if there is none.
	// Ownership stays with the trie.
	T* Peek(unsigned x, unsigned y, unsigned z) const {
		assert(x < D && y < D && z < D);

		unsigned path[LEVELS];
		MakePath(x, y, z, path);

		const node* curr = head;
		for (unsigned level = 0; curr != NULL && level + 1 < LEVELS; ++level)
			curr = curr->next[path[level]];
		if (curr == NULL)
			return NULL;
		return curr->data[path[LEVELS - 1]];
	}

	// Removes the item at (x, y, z) and returns it (NULL if the slot was empty).
	// The caller becomes responsible for deleting the returned item.
	T* Remove(unsigned x, unsigned y, unsigned z) {
		assert(x < D && y < D && z < D);

		unsigned path[LEVELS];
		MakePath(x, y, z, path);
		node** trail[LEVELS];
		if (!Descend(path, trail, false))
			return NULL;

		node* leaf = *trail[LEVELS - 1];
		T* oldData = leaf->data[path[LEVELS - 1]];
		if (oldData == NULL)
			return NULL;
		leaf->data[path[LEVELS - 1]] = NULL;

		// Unwind leaf-first: an emptied child is freed and unlinked before its
		// parent is examined, so deleting the parent never reaches it again.
		for (unsigned level = LEVELS; level-- > 0; ) {
			node** slot = trail[level];
			if (--(*slot)->count == 0) {
				delete *slot;
				*slot = NULL;
			}
		}
		return oldData;
	}

	// Removes and deletes ALL items in the trie.
	void Clear() {
		delete head;
		head = NULL;
	}

	// Number of items currently stored.
	unsigned Size() const {
		return (head == NULL) ? 0 : head->count;
	}

	// Writes the position of the first item found (smallest digit at each level,
	// x digits first, least significant first) to x, y and z and returns true.
	// Returns false, with x = y = z = 0, when the trie is empty.
	bool FindFirst(unsigned &x, unsigned &y, unsigned &z) const {
		x = y = z = 0;
		if (head == NULL || head->count == 0)
			return false;

		unsigned path[LEVELS];
		const node* curr = head;
		for (unsigned level = 0; level + 1 < LEVELS; ++level) {
			// Skip missing children and empty ones (count 0 can only be left
			// behind when an allocation inside Insert threw).
			unsigned i = 0;
			while (i < K && (curr->next[i] == NULL || curr->next[i]->count == 0))
				++i;
			assert(i < K); // a node with count > 0 always has a populated child
			if (i == K)
				return false;
			path[level] = i;
			curr = curr->next[i];
		}

		unsigned item = 0;
		while (item < K && curr->data[item] == NULL)
			++item;
		assert(item < K); // a leaf with count > 0 always holds an item
		if (item == K)
			return false;
		path[LEVELS - 1] = item;

		// Reassemble each coordinate from its base-K digits, most significant first.
		for (unsigned d = DEPTH; d-- > 0; ) {
			x = x * K + path[d];
			y = y * K + path[DEPTH + d];
			z = z * K + path[2 * DEPTH + d];
		}

		assert(x < D && y < D && z < D);
		return true;
	}

private:
	// Copying would make two tries delete the same nodes and items, so copying
	// is disabled (declared but intentionally not defined).
	multiway_trie_3d(const multiway_trie_3d&);
	multiway_trie_3d& operator=(const multiway_trie_3d&);

	// Base-K digits per coordinate. At least 1, so that D == 1 (for which
	// ceilLogb is 0) still gets a leaf level instead of zero-length arrays.
	enum { DEPTH = (ceilLogb<D, K>::ret > 0) ? ceilLogb<D, K>::ret : 1 };
	// Nodes on a root-to-leaf path: DEPTH digits for each of x, y and z.
	enum { LEVELS = 3 * DEPTH };

	struct node {
		// Leaves use data[], internal nodes use next[]; only the member in
		// use is initialised and ever read.
		explicit node(bool leaf) : count(0), lastLevel(leaf) {
			for (unsigned i = 0; i < K; ++i) {
				if (leaf)
					data[i] = NULL;
				else
					next[i] = NULL;
			}
		}
		// Frees the whole subtree (for a leaf: the stored items). There is no
		// count check: an empty node holds only NULLs, so this is still correct,
		// and it also reclaims nodes left behind if an allocation in Insert threw.
		~node() {
			for (unsigned i = 0; i < K; ++i) {
				if (lastLevel)
					delete data[i];
				else
					delete next[i];
			}
		}

		unsigned count;   // number of items stored in this subtree
		bool lastLevel;   // true for leaf nodes, which use data[] rather than next[]
		union {
			node* next[K];  // children (internal nodes)
			T* data[K];     // items (leaf nodes)
		};
	};

	// Splits (x, y, z) into the LEVELS digits that address a voxel:
	// path[0 .. DEPTH) are x's base-K digits (least significant first), followed
	// by y's and then z's. The final digit is the item index inside the leaf.
	static void MakePath(unsigned x, unsigned y, unsigned z, unsigned (&path)[LEVELS]) {
		for (unsigned d = 0; d < DEPTH; ++d) {
			path[d] = x % K;
			path[DEPTH + d] = y % K;
			path[2 * DEPTH + d] = z % K;
			x /= K;
			y /= K;
			z /= K;
		}
	}

	// Walks from head along path and records, root first, the address of every
	// slot visited in trail. When create is true, missing nodes are allocated on
	// the way; otherwise the walk returns false at the first missing node.
	// Returns true once trail holds all LEVELS slots (the last one is the leaf).
	bool Descend(const unsigned (&path)[LEVELS], node** (&trail)[LEVELS], bool create) {
		node** slot = &head;
		for (unsigned level = 0; ; ++level) {
			const bool leaf = (level + 1 == LEVELS);
			if (*slot == NULL) {
				if (!create)
					return false;
				*slot = new node(leaf);
			}
			trail[level] = slot;
			if (leaf)
				return true;
			slot = &(*slot)->next[path[level]];
		}
	}

	node* head;
};

#endif

