/*
Fast Sudoku solving and generating functions by M Phillips - December 2008.

This code is provided as is with no warranties or guarantees of
any kind.

Please send an email to M Phillips (mbp2@i4free.co.nz)
  - if you use this file in a released product, or
  - if you find any bugs, or
  - if you have any suggestions
*/

#ifndef SUDOKU_SOLVER_H
#define SUDOKU_SOLVER_H

#include <memory.h>
#include <stdlib.h>

// Compile-time type selection (a C++98 stand-in for std::conditional):
// if_<Cond, Then, Else>::ret names Then when Cond is true and Else otherwise.
template <bool Cond, typename Then, typename Else>
struct if_ { typedef Else ret; };

// Specialization picked when the condition holds.
template <typename Then, typename Else>
struct if_<true, Then, Else> { typedef Then ret; };

// Compile-time floor(sqrt(n)) for n >= 0, computed by integer Newton iteration.
//   a : current estimate (starts at n/2, which is >= floor(sqrt(n)) for n >= 2)
//   b : the estimate from the previous recursion step (-1 on the first step)
// Each recursion step applies two Newton steps, x -> (x + n/x) / 2. When
// n == (k+1)^2 - 1 integer Newton oscillates between k and k+1; doing two
// steps at a time turns that 2-cycle into a fixed point, so the recursion
// always reaches the <n, a, a> specialization below. (The ?: does not stop
// the recursive template from being instantiated, which is why that
// terminating specialization is required.)
template <int n, int a = (n>>1), int b = -1>
struct floorSqrtn {
    enum { ret = a*a <= n && n < (a+1)*(a+1) ? a : floorSqrtn<n, ((n/((n/a + a)>>1) + ((n/a + a)>>1))>>1), a>::ret };
};
// Recursion stops when two Newton steps return to the same estimate. That
// estimate is floor(sqrt(n)) or, on the k <-> k+1 oscillation, possibly k+1;
// in the latter case a*a > n and the answer is a-1 (e.g. n = 24 ends at a = 5).
template <int n, int a>
struct floorSqrtn<n, a, a> {
    enum { ret = (a*a <= n) ? a : a - 1 };
};
// n == 1: the initial estimate 1>>1 == 0 would divide by zero.
template <>
struct floorSqrtn<1, 0, -1> {
    enum { ret = 1 };
};
// n == 0: same division-by-zero problem; floor(sqrt(0)) == 0.
template <>
struct floorSqrtn<0, 0, -1> {
    enum { ret = 0 };
};

// Backtracking Sudoku solver / generator for a WIDTH x WIDTH board made of
// BLOCK x BLOCK boxes, where BLOCK = floor(sqrt(WIDTH)). Digits are 1..WIDTH
// and 0 marks an empty cell. Candidate sets are bitmasks in which bit d means
// "digit d is still possible here".
//
// Requirements on WIDTH:
//   - (1 << (WIDTH+1)) must not overflow an int, i.e. WIDTH <= 29.
//   - WIDTH is expected to be a perfect square (4, 9, 16, 25, ...). If BLOCK
//     does not divide WIDTH, the box loop in scratchOut runs past the board.
//
// An instance keeps mutable state (current strategy and a step counter), so
// one instance should not be shared between threads. Randomised operations
// use rand(); call srand() first if varying output is wanted.
template <unsigned int WIDTH>
class SudokuSolver {
	// Candidate bitmask type: only bits 1..WIDTH are used, so 16 bits are
	// enough while WIDTH < 16.
	typedef typename if_<(WIDTH < 16), unsigned short, unsigned int>::ret CellInfo;

	enum misc {
		// side length of one box
		BLOCK = floorSqrtn<WIDTH>::ret,
		// bits 1..WIDTH set; bit 0 is unused because 0 means "empty"
		ALL_POSSIBILITIES = (1<<(WIDTH+1))-2
	};

	// Order in which backtrackSolving tries the candidates of a cell.
	enum solvingTechnique {
		TryLowDigitsFirst,
		TryHighDigitsFirst,
		TryDigitsInRandomOrder
	};

	// Number of set bits in v (SWAR popcount; assumes a 32-bit unsigned int).
	static unsigned int countBits(unsigned int v) {
		v -= ((v >> 1) & 0x55555555);
		v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
		return (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24;
	}

public:
	SudokuSolver() : solvingMethod(TryLowDigitsFirst), solvingSteps(0) {
		// precompute, for each row/column index, the first index of its box
		for (unsigned int i=0; i<WIDTH; ++i)
			square[i] = (i/BLOCK) * BLOCK;
	}

	// Solves `initial` (0 = empty cell, 1..WIDTH = given) into `solved`.
	//
	// Returns false if any given lies outside 0..WIDTH, if the givens break a
	// row/column/box rule, or if the puzzle has no solution. When
	// onlyIfSingleSolution is true it also returns false if the puzzle has
	// more than one solution. The contents of `solved` are unspecified when
	// false is returned. Every call adds to the step counter (getNumSteps).
	bool solve(const int initial[WIDTH][WIDTH], int solved[WIDTH][WIDTH], bool onlyIfSingleSolution = false) {
		// Low-first and high-first walk the same search tree from opposite
		// ends; the uniqueness check below relies on that.
		solvingMethod = onlyIfSingleSolution ? TryLowDigitsFirst : TryHighDigitsFirst;

		// start with every digit possible in every cell
		CellInfo maybes[WIDTH][WIDTH];
		for (unsigned int i=0; i<WIDTH; ++i)
			for (unsigned int j=0; j<WIDTH; ++j)
				maybes[i][j] = ALL_POSSIBILITIES;

		// mask out all numbers that break the rules. Out-of-range givens are
		// rejected first: 1<<val is undefined for negative or oversized val,
		// and digits above WIDTH could otherwise end up in a "solution".
		for (unsigned int i=0; i<WIDTH; ++i) {
			for (unsigned int j=0; j<WIDTH; ++j) {
				const int val = initial[i][j];
				if (val < 0 || val > static_cast<int>(WIDTH))
					return false;
				if (val != 0)
					scratchOut(maybes, i, j, val);
			}
		}

		// a cell with no candidates left means the givens conflict (or an
		// empty cell is already impossible); also copy the givens to `solved`
		for (unsigned int i=0; i<WIDTH; ++i) {
			for (unsigned int j=0; j<WIDTH; ++j) {
				solved[i][j] = initial[i][j];
				if (maybes[i][j] == 0)
					return false;
			}
		}

		// begin the main solving process
		if (!backtrackSolving(solved, maybes))
			return false;

		if (onlyIfSingleSolution) {
			// solve again from the other end of the search space; if both
			// searches land on the same grid, it is the only solution
			int solved2[WIDTH][WIDTH];
			solve(initial, solved2, false); // cannot fail since we already found a solution
			for (unsigned int i=0; i<WIDTH; ++i)
				for (unsigned int j=0; j<WIDTH; ++j)
					if (solved[i][j] != solved2[i][j])
						return false;	// more than one solution
		}
		return true;
	}

	// Fills `board` with a new puzzle (0 = hole, 1..WIDTH = given).
	//
	// Builds a random complete grid, then makes 1000 random attempts to clear
	// a cell. A hole is kept only if the puzzle still has a single solution
	// and the solver's step count does not exceed the largest accepted count
	// so far by more than 6. Finally the digits are relabelled with a random
	// permutation. Uses rand(); call srand() first for varying puzzles.
	void generate(int board[WIDTH][WIDTH]) {
		// set up an empty board with all numbers to try in each spot
		CellInfo maybes[WIDTH][WIDTH];
		for (unsigned int i=0; i<WIDTH; ++i) {
			for (unsigned int j=0; j<WIDTH; ++j) {
				maybes[i][j] = ALL_POSSIBILITIES;
				board[i][j] = 0;
			}
		}

		// generate a random completed board by solving the empty board with
		// the random strategy; cannot fail because the board starts out empty
		solvingMethod = TryDigitsInRandomOrder;
		backtrackSolving(board, maybes);

		// work backwards towards a board with mostly holes by clearing a
		// square and making sure that the resulting board is easily solvable
		int prevSolvingSteps = WIDTH*WIDTH*2;
		for (int k=0; k<1000; ++k) {
			int pos = rand() % (WIDTH*WIDTH);
			int x = pos/WIDTH, y = pos%WIDTH, temp = board[x][y];
			board[x][y] = 0;
			solvingSteps = 0;
			int solved[WIDTH][WIDTH] = {{0}};
			bool single = solve(board, solved, true);
			if (solvingSteps > prevSolvingSteps + 6 || !single)
				board[x][y] = temp;	// too hard or ambiguous: put the digit back
			else if (solvingSteps > prevSolvingSteps)
				prevSolvingSteps = solvingSteps;
		}

		// remove any potential bias towards certain digits being easier to
		// find by substituting numbers afterwards; permute[0] stays 0 so
		// holes remain holes
		int permute[WIDTH+1];
		for (unsigned int i = 0; i<=WIDTH; ++i)
			permute[i] = i;
		// shuffle entries 1..WIDTH
		for (unsigned int i = WIDTH; i>0; --i) {
			int pos = rand() % i + 1;
			int val = permute[pos];
			permute[pos] = permute[i];
			permute[i] = val;
		}
		// now perform the substitution
		for (unsigned int i=0; i<WIDTH; ++i)
			for (unsigned int j=0; j<WIDTH; ++j)
				board[i][j] = permute[board[i][j]];
	}

	// Number of scratchOut calls since the counter was last reset
	// (generate resets it before each trial solve; solve only adds to it).
	int getNumSteps() const { return solvingSteps; }

private:
	// Records that cell (x,y) holds val: removes val from the candidates of
	// every cell in the same box, row and column, then sets (x,y)'s
	// candidates to exactly {val}. Counts one solving step.
	void scratchOut(CellInfo maybes[WIDTH][WIDTH], int x, int y, int val) {
		int mask = ~(1<<val);
		int sqx = square[x], sqy = square[y];
		// mark off all places in this box
		for (int i=sqx; i<sqx+BLOCK; ++i)
			for (int j=sqy; j<sqy+BLOCK; ++j)
				maybes[i][j] &= mask;
		// mark off all spots in this row and column
		for (unsigned int k=0; k<WIDTH; ++k) {
			maybes[x][k] &= mask;
			maybes[k][y] &= mask;
		}
		// ...but not this spot itself
		maybes[x][y] = static_cast<CellInfo>(1<<val);
		++solvingSteps;
	}

	// Depth-first search over the empty cells of `board` (0 = empty), using
	// the candidate masks in `maybes`. Returns true with `board` completely
	// filled, or false with `board` restored to its state on entry.
	bool backtrackSolving(int board[WIDTH][WIDTH], const CellInfo maybes[WIDTH][WIDTH]) {
		// find the empty spot with the fewest possible answers (first one wins
		// ties), because trying it first reduces the amount of backtracking
		unsigned int i = 0, j = 0;
		unsigned int bestBits = WIDTH+1;
		for (unsigned int r=0; r<WIDTH; ++r) {
			for (unsigned int c=0; c<WIDTH; ++c) {
				if (board[r][c] == 0) {
					const unsigned int count = countBits(maybes[r][c]);
					if (count < bestBits) {
						bestBits = count;
						i = r;
						j = c;
					}
				}
			}
		}
		if (bestBits == 0)
			return false;	// we broke a rule - oops backtrack!
		if (bestBits == WIDTH+1)
			return true;	// we filled the whole board - woohoo, done!

		// every digit must still be possible somewhere in row k and in column k
		for (unsigned int k=0; k<WIDTH; ++k) {
			unsigned int colUnion = 0, rowUnion = 0;
			for (unsigned int m=0; m<WIDTH; ++m) {
				colUnion |= maybes[m][k];
				rowUnion |= maybes[k][m];
			}
			if ((colUnion & rowUnion) != static_cast<unsigned int>(ALL_POSSIBILITIES))
				return false;	// we broke a rule - oops backtrack!
		}
		// The same check for boxes is deliberately skipped: the original
		// author found it made solving slower.

		CellInfo newMaybes[WIDTH][WIDTH];	// scratch copy reused by each attempt at this level
		switch (solvingMethod) {
			case TryLowDigitsFirst:
				for (int val=1; val<=static_cast<int>(WIDTH); ++val)
					if (tryValue(board, maybes, newMaybes, i, j, val))
						return true;
				break;
			case TryHighDigitsFirst:
				for (int val=static_cast<int>(WIDTH); val>=1; --val)
					if (tryValue(board, maybes, newMaybes, i, j, val))
						return true;
				break;
			case TryDigitsInRandomOrder:
				{
					// draw digits without replacement from a shrinking pool;
					// char rather than int, to conserve stack in deep recursion.
					// (Note the caller must call srand first.)
					char possibles[WIDTH];
					for (unsigned int k=0; k<WIDTH; ++k)
						possibles[k] = static_cast<char>(k+1);
					for (unsigned int left=WIDTH; left>0; --left) {
						const unsigned int pos = rand() % left;
						const int val = possibles[pos];
						possibles[pos] = possibles[left-1];
						if (tryValue(board, maybes, newMaybes, i, j, val))
							return true;
					}
				}
				break;
		}
		// no digit works in the chosen spot: undo it and let the caller backtrack
		board[i][j] = 0;
		return false;
	}

	// Tries digit val in empty cell (i,j). If val is still a candidate there,
	// writes it to `board`, narrows a copy of `maybes` into `newMaybes`, and
	// recurses. Returns true if that leads to a complete board. If the
	// recursion fails, board[i][j] keeps val; the caller overwrites or clears it.
	bool tryValue(int board[WIDTH][WIDTH], const CellInfo maybes[WIDTH][WIDTH],
	              CellInfo newMaybes[WIDTH][WIDTH], unsigned int i, unsigned int j, int val) {
		if ((maybes[i][j] & (1<<val)) == 0)
			return false;
		board[i][j] = val;
		memcpy(newMaybes, maybes, sizeof(CellInfo) * WIDTH * WIDTH);
		scratchOut(newMaybes, i, j, val);
		return backtrackSolving(board, newMaybes);
	}

	solvingTechnique solvingMethod;	// candidate order used by backtrackSolving
	int solvingSteps;				// incremented once per scratchOut call
	unsigned int square[WIDTH];		// square[i]: first index of the box containing row/column i
};

#endif

