#ifndef SplayTreeUtils
#define SplayTreeUtils

/*
Interface summary (TNode must expose `left` and `right` child pointers):

TNode *SplayTree::Find(TNode *&head, const TKey &find);
TNode *SplayTree::FindMin(TNode *&head);
TNode *SplayTree::FindMax(TNode *&head);

void SplayTree::Insert(TNode *&head, TNode *ins);

The three Remove functions below are currently disabled (their definitions
further down in this file are commented out):
TNode *SplayTree::Remove(TNode *&head, TKey &rem);
TNode *SplayTree::RemoveMin(TNode *head);
TNode *SplayTree::RemoveMax(TNode *head);

TNode *SplayTree::Merge(TNode *head1, TNode *head2);
void SplayTree::Partition(TNode *headIn1, const TKey &splitter, TNode *&headOut1, TNode *&headOut2);
*/

namespace SplayTree {

// *** Internal Stuff ***
template <class TNode> void splayRotR(TNode *&h) {
	TNode *x = h->left; h->left = x->right;
	x->right = h; h = x;
}
template <class TNode> void splayRotL(TNode *&h) {
	TNode *x = h->right; h->right = x->left;
	x->left = h; h = x;
}

template <class TNode, class TKey>
void Find_Aux(TNode *&head, const TKey &find) {
	if (head != NULL) {
		if (*find < *head) {
			if (head->left != NULL) {
				if (*find < *(head->left)) {
					Find_Aux(head->left->left, find);
					splayRotR(head);
				} else {
					Find_Aux(head->left->right, find);
					splayRotL(head->left);
				}
				if (head->left != NULL)
					splayRotR(head);
			}
		} else { 
			if (head->right != NULL) {
				if (*find < *(head->right)) {
					Find_Aux(head->right->left, find);
					splayRotR(head->right);
				} else {
					Find_Aux(head->right->right, find);
					splayRotL(head);
				}
				if (head->right != NULL)
					splayRotL(head);
			}
		}
	}
} 

template <class TNode>
TNode *FindMin_Aux(TNode *&head) {
	TNode *temp = head->left;
	if (temp != NULL) {
		FindMin_Aux(temp);
		TNode *temp2 = temp->right;
		head->left = temp2;
		temp->right = head;
		head = temp;
	}
	return head;
}

template <class TNode>
TNode *FindMax_Aux(TNode *&head) {
	TNode *temp = head->right;
	if (temp != NULL) {
		FindMax_Aux(temp);
		TNode *temp2 = temp->left;
		head->right = temp2;
		temp->left = head;
		head = temp;
	}
	return head;
}

// *** Searching ***
// O(n) : Amortized O(log(n))
template <class TNode, class TKey>
TNode *Find(TNode *&head, const TKey &find) {
	if (head == NULL) return NULL;
	Find_Aux(head, find);
	return (*head == find) ? head : NULL;
}

// O(n) : Amortized O(log(n))
template <class TNode>
TNode *FindMin(TNode *&head) {
	if (head == NULL) return NULL;
	return FindMin_Aux(head);
}

// O(n) : Amortized O(log(n))
template <class TNode>
TNode *FindMax(TNode *&head) {
	if (head == NULL) return NULL;
	return FindMax_Aux(head);
}

// *** Utilities ***
// O(n) : Amortized O(log(n)) : Insert caller-allocated item
template <class TNode>
void Insert_Aux(TNode *&head, TNode *ins) {
	if (head == NULL) {
		head = ins;
	} else if (*ins < *head) {
		if (head->left == NULL) {
			head->left = ins;
		} else if (*ins < *(head->left)) {
			Insert_Aux(head->left->left, ins);
			splayRotR(head);
		} else {
			Insert_Aux(head->left->right, ins);
			splayRotL(head->left);
		} 
		splayRotR(head);
	} else {
		if (head->right == NULL) {
			head->right = ins;
		} else if (*ins < *(head->right)) {
			Insert_Aux(head->right->left, ins);
			splayRotR(head->right);
		} else {
			Insert_Aux(head->right->right, ins);
			splayRotL(head);
		}
		splayRotL(head);
	}
}

template <class TNode>
void Insert(TNode *&head, TNode *ins) {
	ins->left = ins->right = NULL;
	Insert_Aux(head, ins);
}
/*
// O(n) : Amortized O(log(n)) : Remove leftmost selected item - caller responsible for deallocation
template <class TNode, class TKey>
TNode *RemoveMin(TNode *head) {
	TNode *found = NULL;
	if (head != NULL) {
		FindMin_Aux(head);
		found = head;
		head = head->right;
	}
	return found;
}

// O(n) : Amortized O(log(n)) : Remove rightmost selected item - caller responsible for deallocation
template <class TNode, class TKey>
TNode *RemoveMax(TNode *head) {
	TNode *found = NULL;
	if (head != NULL) {
		FindMax_Aux(head);
		found = head;
		head = head->left;
	}
	return found;
}

// O(n) : Amortized O(log(n)) : Remove selected item - caller responsible for deallocation
template <class TNode, class TKey>
TNode *Remove(TNode *head, const TKey &rem) {
	TNode *found = NULL;
	if (head != NULL) {
		Find_Aux(head, rem);
		if (*head == find) {
			//if we got here then we have found the item
			found = head;
			if (head->left != NULL) {
				TNode *temp = head->left;
				Find_Aux(temp, rem);
				temp->right = head->right;
				head = temp;
			} else if (head->right != NULL) {
				TNode *temp = head->right;
				Find_Aux(temp, find);
				temp->left = head->left;
				head = temp;
			} else {
				head = NULL;
			}
		}
	}
	return found;
}
*/
// *** Combining ***
// O(nlog(n+m)) fastest if head2 has fewer items
template <class TNode>
TNode *Merge(TNode *head1, TNode *head2) {
	TNode *l, *r, *curr = head2;
	while (curr != NULL) {
		//remember both children
		l = curr->left;
		r = curr->right;
		//insert current node into other tree
		SplayTree::Insert(head1, curr);
		if (r != NULL) {
			//recursively add original left subtree
			head1 = Merge(head1, l);
			//iteratively add original right subtree
			curr = r;
		} else
			//iteratively add original left subtree
			curr = l;
	}
	return head1;
}

// O(nlog(n))
template <class TNode, class TKey>
void Partition(TNode *headIn1, const TKey &splitter, TNode *&headOut1, TNode *&headOut2) {
	TNode *l, *r, *curr = headIn1;
	while (curr != NULL) {
		//remember both children
		l = curr->left;
		r = curr->right;
		//decide which list to insert into
		if (*curr < splitter)
			SplayTree::Insert(headOut1, curr);
		else
			SplayTree::Insert(headOut2, curr);
		if (r != NULL) {
			//recursively continue partitioning left subtree
			SplayTree::Partition(l, splitter, headOut1, headOut2);
			//iteratively continue partitioning right subtree
			curr = r;
		} else
			//iteratively continue partitioning left subtree
			curr = l;
	}
}

} //namespace SplayTree

#endif

