//==============================================================================
//
// Copyright (c) 2007 Jason Evans <jasone@canonware.com>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to
// deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
// sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
//==============================================================================
//
// Implementation of (n choose k) for 64-bit signed integers.
//
//==============================================================================

#include <stdbool.h>
#include <stdio.h>
#include <stdint.h>
#include <unistd.h>
#include <assert.h>

// Greatest common divisor via the binary GCD algorithm, attributed by Knuth to
// Stein (1961).  See Knuth's TAOCP Vol. 2, 3rd Ed., pg 338 for details.
//
// The computation is carried out on the magnitudes of u and v using unsigned
// arithmetic, so zero and negative arguments are accepted:
//
//   gcd(x, 0) == gcd(0, x) == |x|, and gcd(0, 0) == 0.
//
// The result is non-negative, with one exception: a gcd of 2^63 (possible only
// when both arguments are taken from {INT64_MIN, 0} and at least one of them
// is INT64_MIN) is not representable in int64_t, and the final conversion is
// then implementation-defined.
int64_t
gcd(int64_t u, int64_t v) {
    // Magnitudes.  Negating in uint64_t is well defined even for INT64_MIN.
    uint64_t a = (u < 0) ? (uint64_t)0 - (uint64_t)u : (uint64_t)u;
    uint64_t b = (v < 0) ? (uint64_t)0 - (uint64_t)v : (uint64_t)v;
    unsigned shift = 0;

    // A zero operand must be handled up front: the halving loops below strip
    // trailing zero bits and would never terminate on zero.
    if (a == 0) {
	return (int64_t)b;
    }
    if (b == 0) {
	return (int64_t)a;
    }

    // Find the power of 2 common to both operands.
    while (((a | b) & 1) == 0) {
	shift++;
	a >>= 1;
	b >>= 1;
    }

    // Remaining factors of 2 in a are not shared with b; discard them.
    while ((a & 1) == 0) {
	a >>= 1;
    }

    // Invariant: a is odd.  Each pass makes b odd, orders the pair so that
    // a <= b, and replaces b with the (even) difference b - a.
    do {
	while ((b & 1) == 0) {
	    b >>= 1;
	}

	if (a > b) {
	    uint64_t t = a;
	    a = b;
	    b = t;
	}

	b -= a;
    } while (b != 0);

    // Restore the common power of 2.
    return (int64_t)(a << shift);
}

// Compute (n choose k).
//
// Returns:
//   - 0 when k < 0 or k > n, since no such selection exists (this covers every
//     k >= 0 when n is negative);
//   - -1 when the result does not fit in int64_t;
//   - the binomial coefficient otherwise.
//
// The value is built with the multiplicative recurrence
//
//   C(m+i, i) == C(m+i-1, i-1) * (m+i) / i,    m == n-k,  i == 1..k
//
// so the accumulator holds an exact binomial coefficient after every step.
// Each step multiplies by (m+i)/i > 1, hence no intermediate value exceeds the
// final one, and an overflow at any step means the result itself overflows.
int64_t
choose(int64_t n, int64_t k) {
    int64_t m, r, i;

    // Reject out-of-domain requests before doing any arithmetic on n and k.
    if (k < 0 || k > n) {
	return 0;
    }

    // (n k) and (n (n-k)) are the same; use the smaller k so that the loop
    // below runs as few times as possible.
    if (k > n - k) {
	k = n - k;
    }
    m = n - k;

    r = 1;
    for (i = 1; i <= k; i++) {
	// Cancel as much of the divisor i as possible against r.  gcd() is
	// only ever called with strictly positive arguments here.
	int64_t g = gcd(r, i);
	int64_t d = i / g;
	int64_t f;

	// i divides r * (m + i) and d is coprime to r / g, so the uncancelled
	// part d must divide (m + i) exactly.
	assert((m + i) % d == 0);
	f = (m + i) / d;

	r /= g;
	if (r > INT64_MAX / f) {
	    // r * f == C(m+i, i) is not representable, so neither is the
	    // (larger or equal) final result.
	    return -1;
	}
	r *= f;
    }

    return r;
}

// Driver: print (i choose j) for every 0 <= j <= i <= 100 to stderr, showing
// "?" wherever choose() reports overflow (-1).
int
main(void) {
    int64_t n = 100;
    int64_t i, j, r;

    // Compute and print (n choose k) for a range of inputs.  int64_t is not
    // necessarily long long, so values are cast to match the %lld conversion.
    for (i = 0; i <= n; i++) {
	fprintf(stderr, "==> %lld\n", (long long)i);
	for (j = 0; j <= i; j++) {
	    r = choose(i, j);
	    if (r == -1) {
		fprintf(stderr, "  choose(%lld, %lld) --> ?\n",
		  (long long)i, (long long)j);
	    } else {
		fprintf(stderr, "  choose(%lld, %lld) --> %lld\n",
		  (long long)i, (long long)j, (long long)r);
	    }
	}
    }

    return 0;
}
