/*
Two's-complement variable-sized bigInt number class by M Phillips - 2008.
Thanks also to Zero Soma Valintine for several bug fixes

This code is provided as is with no warranties or guarantees of
any kind.

Please send an email to M Phillips (mbp2@i4free.co.nz)
  - if you use this file in a released product, or
  - if you find any bugs, or
  - if you have any suggestions
*/

#ifndef VAR_BIG_INT_H
#define VAR_BIG_INT_H

#include <climits>
#include <cmath>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <memory.h>

// Pre-C++11 compile-time assertion. C_ASSERT(e) declares a typedef of a char
// array whose size is 1 when e is true and -1 (ill-formed) when e is false;
// MSVC reports a failure as error C2118 "negative subscript".
// The two-level ASSERT_CONCAT forces __LINE__ to expand before token pasting,
// giving each assertion a distinct typedef name so several can share a scope.
// Nothing is defined if some other header already provides C_ASSERT.
#ifndef C_ASSERT
#define ASSERT_CONCAT_(lhs, rhs) lhs##rhs
#define ASSERT_CONCAT(lhs, rhs) ASSERT_CONCAT_(lhs, rhs)
#define C_ASSERT(e) typedef char ASSERT_CONCAT(assert_line_, __LINE__)[1 - 2 * !(e)]
#endif

// If using a big endian processor, define BIGINT_USE_BIG_ENDIAN
// (I've read there isn't a robust method of endianness detection at compile-time)
//#define BIGINT_USE_BIG_ENDIAN

// To use 64-bit maths internally, define BIGINT_USE_64BIT
//#define BIGINT_USE_64BIT

class varbigint
{
protected:
#if _M_AMD64 || defined(BIGINT_USE_64BIT)
	typedef unsigned long long BIGINT_BASE;
	typedef unsigned int BIGINT_SHORT;
#else
	typedef unsigned int BIGINT_BASE;
	typedef unsigned short BIGINT_SHORT;
#endif

	// Compile-time check that BIGINT_SHORT is exactly half the width of
	// BIGINT_BASE: the array size becomes -1 (ill-formed) if it is not.
	// Spelled out directly so the class does not depend on whichever
	// C_ASSERT definition happens to be in scope.
	typedef char assert_short_is_half_base[(sizeof(BIGINT_SHORT) * 2 == sizeof(BIGINT_BASE)) ? 1 : -1];

	enum constants {
		BASE_BITS = sizeof(BIGINT_BASE) * CHAR_BIT,
		HALF_BASE_BITS = BASE_BITS / 2,
		LOWER_HALF_MASK = (((1U << (HALF_BASE_BITS-1)) - 1) << 1) + 1,
	};

	// Value representation: data[0 .. len-1] are the low-order words of a
	// two's-complement number, least significant word first.  Every word above
	// data[len-1] is implicitly all zeros (negative == 0) or all ones
	// (negative == 1), so len == 0 encodes 0 or -1 depending on negative.
	union {
		unsigned int info;
		struct {
			signed len:31;
			unsigned negative:1;
		};
	};
	int allocLen;       // number of BIGINT_BASE words allocated for data
	BIGINT_BASE *data;  // owned (delete[] in the destructor); NULL when nothing is allocated

	varbigint& resize(int newLen);
	varbigint& trim();

	// Initialises a not-yet-initialised object from the built-in integer q.
	//  - 0 is stored as len == 0, negative == 0.
	//  - -1 (signed T only) is stored as len == 0, negative == 1.
	//  - Otherwise ceil(sizeof(T) / sizeof(BIGINT_BASE)) words are stored,
	//    sign-extended for negative values.
	// The words are produced with shifts rather than memcpy, so the result is
	// independent of byte order and never writes outside the allocated array.
	// For unsigned T an all-ones value is a large positive number, not -1.
	// Fill is kept for source compatibility only: sign extension is now
	// derived from q itself, so the argument is ignored.
	template <typename T> void internalInitialise(T q, const int Fill = 0) {
		(void)Fill;
		const bool isMinusOne = std::numeric_limits<T>::is_signed && q == T(~T(0));
		const bool encodedBySignAlone = (q == T(0)) || isMinusOne;
		len = encodedBySignAlone ? 0 : int((sizeof(q) + sizeof(BIGINT_BASE) - 1) / sizeof(BIGINT_BASE));
		allocLen = len;
		negative = q < T(0);
		if (len == 0) {
			data = NULL;
			return;
		}
		data = new BIGINT_BASE[allocLen];
		// Converting to unsigned long long is modular, so negative values arrive
		// already sign-extended and upper words get the correct all-ones fill.
		unsigned long long bits = static_cast<unsigned long long>(q);
		for (int i = 0; i < allocLen; ++i) {
			data[i] = static_cast<BIGINT_BASE>(bits);
			// Two half-width shifts: well defined even when BASE_BITS == 64.
			bits = (bits >> HALF_BASE_BITS) >> HALF_BASE_BITS;
		}
	}

	// Stores the value into the built-in integer q, modulo 2^(bits in T)
	// (i.e. the low-order bits).  Stored words narrower than T are
	// sign-extended according to 'negative'.
	// Assumes T is no wider than unsigned long long (true for every
	// conversion operator below).
	// Fill is kept for source compatibility only and is ignored.
	template <typename T> void internalExtract(T &q, const int Fill = 0) const {
		(void)Fill;
		const int wordsPerULL = int((sizeof(unsigned long long) + sizeof(BIGINT_BASE) - 1) / sizeof(BIGINT_BASE));
		// Start from the implicit sign words, then shift in the stored words,
		// most significant first; only the low 64 bits can affect q.
		unsigned long long bits = negative ? ~0ULL : 0ULL;
		for (int i = (len < wordsPerULL) ? len : wordsPerULL; i > 0;) {
			--i;
			bits = ((bits << HALF_BASE_BITS) << HALF_BASE_BITS) | data[i];
		}
		q = static_cast<T>(bits);
	}

	// Sets the value to f truncated toward zero.  NaN, infinities and
	// |f| < 1 leave *this untouched.
	// NOTE(review): this presumably expects *this to already hold zero, and
	// resize() to zero-fill new words (the loop below may stop early);
	// confirm in the floating-point constructors and resize().
	template <typename T> void internalInitialiseFromFloating(T f) {
		if (f-f != static_cast<T>(0.0))	// nan or +/-inf
			return;
		bool neg = false;
		if (f < static_cast<T>(0.0)) {
			f = -f;
			neg = true;
		}
		if (f < static_cast<T>(1.0))
			return;
#ifdef _MSC_VER
		static const varbigint::BIGINT_BASE HALF_BASE_MAX = ((std::numeric_limits<varbigint::BIGINT_BASE>::max)() >> 1) + 1;
#endif
		if (f < static_cast<T>((std::numeric_limits<varbigint::BIGINT_BASE>::max)())) {
			// Fits in a single word.
			resize(1);
#ifdef _MSC_VER
			// Workaround for an overflow bug in __ftol2
			if (sizeof(BIGINT_BASE) > 4 && f >= HALF_BASE_MAX) {
				data[0] = HALF_BASE_MAX | static_cast<BIGINT_BASE>(f-(static_cast<T>(HALF_BASE_MAX)));
			} else
#endif
				data[0] = static_cast<BIGINT_BASE>(f);
		} else {
			// f == mant * 2^e with 0.5 <= mant < 1, so f has e integer bits.
			int e;
			T mant = std::frexp(f, &e);
			// The top word receives the leftover 1..BASE_BITS bits; every lower
			// word receives BASE_BITS bits.
			int numBits = 1 + (e + BASE_BITS-1) % BASE_BITS;
			int chunk = (e + BASE_BITS-1) / BASE_BITS;
			resize(chunk);
			// Peel off one word per iteration, most significant first.  Stop
			// early once the remaining fraction is exhausted.
			while (mant > static_cast<T>(0.0) && chunk > 0) {
				mant = std::ldexp(mant, numBits);
				BIGINT_BASE val;
#ifdef _MSC_VER
				// Workaround for an overflow bug in __ftol2.  The test and the
				// conversion must use the value being converted (mant), not f.
				if (sizeof(BIGINT_BASE) > 4 && mant >= HALF_BASE_MAX) {
					val = HALF_BASE_MAX | static_cast<BIGINT_BASE>(mant-(static_cast<T>(HALF_BASE_MAX)));
				} else
#endif
					val = static_cast<BIGINT_BASE>(mant);
				mant -= val;
				numBits = BASE_BITS;
				data[--chunk] = val;
			}
		}
		if (neg)
			*this = -*this;
	}

	// Converts the value to the floating-point type T, most significant word first.
	template <typename T> T internalExtractFromFloating() const {
		static const T multiplier = (static_cast<T>((std::numeric_limits<varbigint::BIGINT_BASE>::max)()) + static_cast<T>(1.0));
		T result = 0;
		if (negative) {
			// Convert the magnitude, then negate.  Iterate over temp's own length:
			// the magnitude can need a different number of words (-1 is stored
			// with len == 0, but +1 cannot be), so using this->len would lose
			// the value or read past temp.data.
			varbigint temp = -*this;
			for (int i = temp.len; i > 0;)
				result = result * multiplier + temp.data[--i];
			return -result;
		} else {
			for (int i = len; i > 0;)
				result = result * multiplier + data[--i];
			return result;
		}
	}

	// Shared implementation of the narrow and wide operator<<.  Writes a in
	// decimal, or in hex/oct when those basefield flags are set.  Honours
	// uppercase, showpos, width, fill and adjustfield.  Negative values print
	// as '-' followed by the magnitude.
	template <typename ostream_t>
	static ostream_t& output(ostream_t &os, varbigint a) {
		typedef typename ostream_t::char_type char_t;
		typedef std::basic_string<char_t, std::char_traits<char_t>, std::allocator<char_t> > tstring;
		const std::ios::fmtflags flags = os.flags();
		const bool neg = a.negative;
		if (neg)
			a.makePositive();

		varbigint::BIGINT_SHORT base = 10;
		char_t hexChar = (char_t)'a' - 0xA;
		if (flags & std::ios::hex) {
			base = 16;
			if (flags & std::ios::uppercase)
				hexChar = (char_t)'A' - 0xA;
		} else if (flags & std::ios::oct)
			base = 8;

		// Digits are produced least significant first and emitted in reverse below.
		tstring s;
		do {
			varbigint::BIGINT_SHORT digit;
			DivMod_short(a, base, a, &digit);
			s.push_back(char_t(digit + ((digit < 0xA) ? (char_t)'0' : hexChar)));
		} while (!!a);

		// width(0) reads and resets the field width, as standard inserters do.
		std::streamsize width = os.width(0) - static_cast<std::streamsize>(s.length());
		if (neg || (flags & std::ios::showpos))
			--width;

		// Padding placement follows the standard num_put rules:
		//  - adjustfield == left pads after the number;
		//  - adjustfield == internal pads between the sign and the digits;
		//  - anything else (right, or no bit set, which is the stream
		//    default) pads before the number.
		const std::ios::fmtflags adjust = flags & std::ios::adjustfield;
		const bool padAfter = (adjust == std::ios::left);
		const bool padInternal = (adjust == std::ios::internal);

		if (!padAfter && !padInternal)
			for (std::streamsize i = width; i > 0; --i)
				os << os.fill();

		if (neg)
			os << (char_t)'-';
		else if (flags & std::ios::showpos)
			os << (char_t)'+';

		if (padInternal)
			for (std::streamsize i = width; i > 0; --i)
				os << os.fill();

		for (typename tstring::size_type i = s.length(); i > 0;)
			os << s[--i];

		if (padAfter)
			for (std::streamsize i = width; i > 0; --i)
				os << os.fill();
		return os;
	}

	// Shared implementation of the narrow and wide operator>>.  Reads an
	// optional sign and digits in the radix selected by the basefield flags.
	// On success the result is stored in val; otherwise val is unchanged and
	// failbit is set.
	// Stream exceptions are disabled while parsing and restored on exit.
	// Restoring them raises any exception that matches the final state.
	template <typename istream_t>
	static istream_t& input(istream_t &is, varbigint &val) {
		typedef typename istream_t::char_type char_t;
		varbigint a = 0;
		bool negative = false, gotValue = false;
		char_t ch;
		// Digits are accumulated in the machine word 'aa'.  They are flushed into
		// 'a' as a = a * mult + aa before mult (= base^digits) could overflow.
		varbigint::BIGINT_BASE aa = 0, mult = 1;
		varbigint::BIGINT_BASE base = 10, limit = ~varbigint::BIGINT_BASE(0U) / 10;
		if (is.flags() & std::ios::hex)
			base = 16, limit = ~varbigint::BIGINT_BASE(0U) / 16;
		else if (is.flags() & std::ios::oct)
			base = 8, limit = ~varbigint::BIGINT_BASE(0U) / 8;

		std::ios::iostate except(is.exceptions());
		is.exceptions(std::ios::goodbit);

		if (is >> ch) {
			if (ch == (char_t)'-' || ch == (char_t)'+') {
				negative = (ch == (char_t)'-');
				is.get(ch);
			}
			for (;;) {
				if (is.eof()) {
					// End of input after at least one digit is success: keep only eofbit.
					if (gotValue)
						is.clear(std::ios::eofbit);
					break;
				} else if (ch >= (char_t)'0' && (ch <= (char_t)'7' || (ch <= (char_t)'9' && base >= 10))) {
					aa *= base;
					aa += ch - (char_t)'0';
				} else if (ch >= (char_t)'A' && ch <= (char_t)'F' && base > 10) {
					aa *= base;
					aa += ch - ((char_t)'A' - 0xA);
				} else if (ch >= (char_t)'a' && ch <= (char_t)'f' && base > 10) {
					aa *= base;
					aa += ch - ((char_t)'a' - 0xA);
				} else {
					is.putback(ch);
					if (!gotValue)
						is.setstate(std::ios::failbit);
					break;
				}
				gotValue = true;
				mult *= base;
				if (mult >= limit) {
					a.Mul_int(mult);
					a += aa;
					mult = 1;
					aa = 0;
				}
				is.get(ch);
			}
			if (!is.fail()) {
				a.Mul_int(mult);
				a += aa;
				if (negative)
					a.makeNegative();
				val.swap(a);
			}
		}
		is.exceptions(except);
		return is;
	}

	friend void DivMod_short(const varbigint &a, BIGINT_SHORT b, varbigint &quotient, BIGINT_SHORT *remainder);
	friend void DivMod(const varbigint &a, const varbigint &b, varbigint &quotient, varbigint *remainder);
	friend void DivMod_Unsigned(const varbigint &a, const varbigint &b, varbigint &quotient, varbigint *remainder);
	static void Multiply(const varbigint &a, const varbigint &b, varbigint &result);
	static void Multiply_Unsigned(const varbigint &a, const varbigint &b, varbigint &result);
	varbigint& Mul_short(const BIGINT_SHORT rhs);
	varbigint& Mul_int(const BIGINT_BASE rhs);

	// Modifying unary functions
	void makeNegative();
	void makePositive();

public:

	// Constructors and conversion operators
	// Default constructor: the value represented is initially zero.
	varbigint() : info(0), allocLen(0), data(NULL) {}
	// Built-in integers are stored exactly (sign-extended when negative).
	varbigint(char x)				{ internalInitialise(x); }
	varbigint(signed char x)		{ internalInitialise(x); }
	varbigint(short x)				{ internalInitialise(x); }
	varbigint(int x)				{ internalInitialise(x); }
	varbigint(long x)				{ internalInitialise(x); }
	varbigint(long long x)			{ internalInitialise(x); }
	varbigint(unsigned char x)		{ internalInitialise(x); }
	varbigint(unsigned short x)		{ internalInitialise(x); }
	varbigint(unsigned int x)		{ internalInitialise(x); }
	varbigint(unsigned long x)		{ internalInitialise(x); }
	varbigint(unsigned long long x)	{ internalInitialise(x); }
	varbigint(const varbigint &other);
	varbigint(const varbigint &other, int len);
	varbigint(const std::string &s);
	varbigint(const std::wstring &ws);
	varbigint(float f);
	varbigint(double d);
	varbigint(long double d);

#if defined(_MSC_VER) && _MSC_VER >= 1600
	varbigint(varbigint &&other);
	varbigint& operator = (varbigint &&rhs);
#endif

	~varbigint() { delete[] data; }

	// Conversions to built-in integers keep the low-order bits of the value
	// (two's-complement truncation), sign-extending narrower stored values.
	operator char() const				{ char x;				internalExtract(x); return x; }
	operator signed char() const		{ signed char x;		internalExtract(x); return x; }
	operator short() const				{ short x;				internalExtract(x); return x; }
	operator int() const				{ int x;				internalExtract(x); return x; }
	operator long() const				{ long x;				internalExtract(x); return x; }
	operator long long() const			{ long long x;			internalExtract(x); return x; }
	operator unsigned char() const		{ unsigned char x;		internalExtract(x); return x; }
	operator unsigned short() const		{ unsigned short x;		internalExtract(x); return x; }
	operator unsigned int() const		{ unsigned int x;		internalExtract(x); return x; }
	operator unsigned long() const		{ unsigned long x;		internalExtract(x); return x; }
	operator unsigned long long() const	{ unsigned long long x;	internalExtract(x); return x; }
	operator float() const;
	operator double() const;
	operator long double() const;
	operator std::string() const;
	operator std::wstring() const;

	// Textual form via the std::string / std::wstring conversion operators.
	std::string toString() const { return *this; }
	std::wstring toWString() const { return *this; }

	bool verify();

	// Swapping and assignment
	void swap(varbigint &other);
	varbigint& operator = (const varbigint &rhs);

	// Stream operators
	friend std::ostream& operator << (std::ostream &os, const varbigint &val);
	friend std::wostream& operator << (std::wostream &wos, const varbigint &val);
	friend std::istream& operator >> (std::istream &is, varbigint &val);
	friend std::wistream& operator >> (std::wistream &wis, varbigint &val);

	void bitClear(int b);
	void bitSet(int b);
	void bitToggle(int b);
	bool bitTest(int b) const;
	bool lowBitSet() const;

	// Modifying binary logical operators
	varbigint& operator &=(const varbigint &rhs);
	varbigint& operator |=(const varbigint &rhs);
	varbigint& operator ^=(const varbigint &rhs);

	// Modifying binary mathematical operators
	/*
	The below does an arithmetic right shift. If you want to do a logical
	right shift, then you must first check if the number to be shifted is
	negative, and if so, xor it with 0 to change the negative flag.
	*/
	varbigint& operator>>=(int shift);
	varbigint& operator<<=(int shift);
	varbigint& operator +=(const varbigint &rhs);
	varbigint& operator -=(const varbigint &rhs);
	varbigint& operator *=(const varbigint &rhs);
	varbigint& operator /=(const varbigint &rhs);
	varbigint& operator %=(const varbigint &rhs);

	// Binary comparison operators
	friend int compare(const varbigint &a, const varbigint &b);
	friend bool operator ==(const varbigint &a, const varbigint &b);
	friend bool operator !=(const varbigint &a, const varbigint &b);
	friend bool operator < (const varbigint &a, const varbigint &b);
	friend bool operator > (const varbigint &a, const varbigint &b);
	friend bool operator <= (const varbigint &a, const varbigint &b);
	friend bool operator >= (const varbigint &a, const varbigint &b);

	// Binary logical operators
	/*
	The below does an arithmetic right shift. If you want to do a logical
	right shift, then you must first check if the number to be shifted is
	negative, and if so, xor it with 0 to change the negative flag.
	*/
	const varbigint operator>> (int shift) const;
	const varbigint operator<< (int shift) const;

	friend const varbigint operator & (varbigint a, const varbigint &b);
	friend const varbigint operator | (varbigint a, const varbigint &b);
	friend const varbigint operator ^ (varbigint a, const varbigint &b);

	// Binary mathematical operators
	friend const varbigint operator + (varbigint a, const varbigint &b);
	friend const varbigint operator - (varbigint a, const varbigint &b);
	friend const varbigint operator * (const varbigint &a, const varbigint &b);
	friend const varbigint operator / (const varbigint &a, const varbigint &b);
	friend const varbigint operator % (const varbigint &a, const varbigint &b);

	// Modifying unary operators
	varbigint& operator++ ();  // Pre Increment operator -- faster than add
	const varbigint operator++ (int);  // Post Increment operator -- faster than add
	varbigint& operator-- ();  // Pre Decrement operator -- faster than subtract
	const varbigint operator-- (int);  // Post Decrement operator -- faster than subtract

	// Unary operators
	bool operator ! () const;	//For comparison against zero
	const varbigint operator ~ () const;
	const varbigint& operator + () const;  // Unary positive
	const varbigint operator - () const;  // Unary Negative

	//Misc
	friend const varbigint sqrt(const varbigint &x);		// returns the square root of x
	friend const varbigint cbrt(const varbigint &x);		// returns the cube root of x
	friend const varbigint abs(const varbigint &x);
	friend const varbigint factorial(const varbigint &x);

	friend bool Fermat(const varbigint &v, unsigned int maxa);
	friend bool MillerRabin(const varbigint &v, unsigned int maxa);
	friend bool isPrime(const varbigint &v);

	friend const varbigint gcd(const varbigint &a, const varbigint &b);
	friend const varbigint lcm(const varbigint &a, const varbigint &b);
	friend const varbigint gcdext(varbigint a, varbigint b, varbigint &s, varbigint &t);
	friend int jacobi(varbigint m, varbigint n);
	friend int findHighestBitSet(const varbigint &x);
	friend int findLowestBitSet(const varbigint &x);
	friend bool isPow2(const varbigint &x);
	friend const varbigint nextPow2(const varbigint &x);
	friend int ceillog2(const varbigint &x);
	friend int popcount(const varbigint &x);
	friend const varbigint pow(const varbigint &a, int b);
	friend const varbigint modpow(varbigint base, const varbigint &exp, const varbigint &mod);

	// Mixed comparisons with int: the int is promoted to varbigint and the
	// general operators are used.  (No shortcut on len: values are not always
	// trimmed, e.g. varbigint(5LL) occupies two 32-bit words yet equals 5.)
	friend bool operator < (const varbigint &q, int a) { return q < varbigint(a); }
	friend bool operator > (const varbigint &q, int a) { return varbigint(a) < q; }
	friend bool operator <=(const varbigint &q, int a) { return !(varbigint(a) < q); }
	friend bool operator >=(const varbigint &q, int a) { return !(q < varbigint(a)); }
	friend bool operator ==(const varbigint &q, int a) { return q == varbigint(a); }
	friend bool operator !=(const varbigint &q, int a) { return q != varbigint(a); }
	friend bool operator < (int a, const varbigint &q) { return varbigint(a) < q; }
	friend bool operator > (int a, const varbigint &q) { return q < varbigint(a); }
	friend bool operator <=(int a, const varbigint &q) { return !(q < varbigint(a)); }
	friend bool operator >=(int a, const varbigint &q) { return !(varbigint(a) < q); }
	friend bool operator ==(int a, const varbigint &q) { return varbigint(a) == q; }
	friend bool operator !=(int a, const varbigint &q) { return varbigint(a) != q; }

	friend const varbigint operator & (const varbigint &q, int a) { return q & varbigint(a); }
	friend const varbigint operator | (const varbigint &q, int a) { return q | varbigint(a); }
	friend const varbigint operator ^ (const varbigint &q, int a) { return q ^ varbigint(a); }
	friend const varbigint operator & (int a, const varbigint &q) { return varbigint(a) & q; }
	friend const varbigint operator | (int a, const varbigint &q) { return varbigint(a) | q; }
	friend const varbigint operator ^ (int a, const varbigint &q) { return varbigint(a) ^ q; }

	friend const varbigint operator + (const varbigint &q, int a) { return q + varbigint(a); }
	friend const varbigint operator - (const varbigint &q, int a) { return q - varbigint(a); }
	friend const varbigint operator * (const varbigint &q, int a) { return q * varbigint(a); }
	friend const varbigint operator / (const varbigint &q, int a) { return q / varbigint(a); }
	friend const varbigint operator % (const varbigint &q, int a) { return q % varbigint(a); }
	friend const varbigint operator + (int a, const varbigint &q) { return varbigint(a) + q; }
	friend const varbigint operator - (int a, const varbigint &q) { return varbigint(a) - q; }
	friend const varbigint operator * (int a, const varbigint &q) { return q * varbigint(a); }
	friend const varbigint operator / (int a, const varbigint &q) { return varbigint(a) / q; }
	friend const varbigint operator % (int a, const varbigint &q) { return varbigint(a) % q; }
};

// Non-member swap found by argument-dependent lookup, so the usual
// `using std::swap; swap(a, b);` idiom (also used by standard algorithms)
// picks the cheap member swap.
// This replaces an explicit specialisation of std::swap declared throw().
// Since C++11, std::swap has a conditional noexcept specification, which is
// false for this class, so that specialisation's exception specification no
// longer matches the primary template.  Specialising std function templates
// is also disallowed from C++20.  An explicit std::swap(a, b) call remains
// correct; it simply falls back to the generic copy-based swap.
inline void swap(varbigint &a, varbigint &b) {
	a.swap(b);
}

#endif

