#ifndef NPY_EXTINT128_H_
#define NPY_EXTINT128_H_
typedef struct {
signed char sign;
npy_uint64 lo, hi;
} npy_extint128_t;
/* Integer addition with overflow checking */
static NPY_INLINE npy_int64
safe_add(npy_int64 a, npy_int64 b, char *overflow_flag)
{
if (a > 0 && b > NPY_MAX_INT64 - a) {
*overflow_flag = 1;
}
else if (a < 0 && b < NPY_MIN_INT64 - a) {
*overflow_flag = 1;
}
return a + b;
}
/* Integer subtraction with overflow checking */
static NPY_INLINE npy_int64
safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag)
{
if (a >= 0 && b < a - NPY_MAX_INT64) {
*overflow_flag = 1;
}
else if (a < 0 && b > a - NPY_MIN_INT64) {
*overflow_flag = 1;
}
return a - b;
}
/* Integer multiplication with overflow checking */
static NPY_INLINE npy_int64
safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag)
{
if (a > 0) {
if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) {
*overflow_flag = 1;
}
}
else if (a < 0) {
if (b > 0 && a < NPY_MIN_INT64 / b) {
*overflow_flag = 1;
}
else if (b < 0 && a < NPY_MAX_INT64 / b) {
*overflow_flag = 1;
}
}
return a * b;
}
/* Long integer init */
static NPY_INLINE npy_extint128_t
to_128(npy_int64 x)
{
npy_extint128_t result;
result.sign = (x >= 0 ? 1 : -1);
if (x >= 0) {
result.lo = x;
}
else {
result.lo = (npy_uint64)(-(x + 1)) + 1;
}
result.hi = 0;
return result;
}
static NPY_INLINE npy_int64
to_64(npy_extint128_t x, char *overflow)
{
if (x.hi != 0 ||
(x.sign > 0 && x.lo > NPY_MAX_INT64) ||
(x.sign < 0 && x.lo != 0 && x.lo - 1 > -(NPY_MIN_INT64 + 1))) {
*overflow = 1;
}
return x.lo * x.sign;
}
/* Long integer multiply */
static NPY_INLINE npy_extint128_t
mul_64_64(npy_int64 a, npy_int64 b)
{
npy_extint128_t x, y, z;
npy_uint64 x1, x2, y1, y2, r1, r2, prev;
x = to_128(a);
y = to_128(b);
x1 = x.lo & 0xffffffff;
x2 = x.lo >> 32;
y1 = y.lo & 0xffffffff;
y2 = y.lo >> 32;
r1 = x1*y2;
r2 = x2*y1;
z.sign = x.sign * y.sign;
z.hi = x2*y2 + (r1 >> 32) + (r2 >> 32);
z.lo = x1*y1;
/* Add with carry */
prev = z.lo;
z.lo += (r1 << 32);
if (z.lo < prev) {
++z.hi;
}
prev = z.lo;
z.lo += (r2 << 32);
if (z.lo < prev) {
++z.hi;
}
return z;
}
/* Long integer add */
static NPY_INLINE npy_extint128_t
add_128(npy_extint128_t x, npy_extint128_t y, char *overflow)
{
npy_extint128_t z;
if (x.sign == y.sign) {
z.sign = x.sign;
z.hi = x.hi + y.hi;
if (z.hi < x.hi) {
*overflow = 1;
}
z.lo = x.lo + y.lo;
if (z.lo < x.lo) {
if (z.hi == NPY_MAX_UINT64) {
*overflow = 1;
}
++z.hi;
}
}
else if (x.hi > y.hi || (x.hi == y.hi && x.lo >= y.lo)) {
z.sign = x.sign;
z.hi = x.hi - y.hi;
z.lo = x.lo;
z.lo -= y.lo;
if (z.lo > x.lo) {
--z.hi;
}
}
else {
z.sign = y.sign;
z.hi = y.hi - x.hi;
z.lo = y.lo;
z.lo -= x.lo;
if (z.lo > y.lo) {
--z.hi;
}
}
return z;
}
/* Long integer negation */
static NPY_INLINE npy_extint128_t
neg_128(npy_extint128_t x)
{
npy_extint128_t z = x;
z.sign *= -1;
return z;
}
static NPY_INLINE npy_extint128_t
sub_128(npy_extint128_t x, npy_extint128_t y, char *overflow)
{
return add_128(x, neg_128(y), overflow);
}
static NPY_INLINE npy_extint128_t
shl_128(npy_extint128_t v)
{
npy_extint128_t z;
z = v;
z.hi <<= 1;
z.hi |= (z.lo & (((npy_uint64)1) << 63)) >> 63;
z.lo <<= 1;
return z;
}
static NPY_INLINE npy_extint128_t
shr_128(npy_extint128_t v)
{
npy_extint128_t z;
z = v;
z.lo >>= 1;
z.lo |= (z.hi & 0x1) << 63;
z.hi >>= 1;
return z;
}
static NPY_INLINE int
gt_128(npy_extint128_t a, npy_extint128_t b)
{
if (a.sign > 0 && b.sign > 0) {
return (a.hi > b.hi) || (a.hi == b.hi && a.lo > b.lo);
}
else if (a.sign < 0 && b.sign < 0) {
return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo);
}
else if (a.sign > 0 && b.sign < 0) {
return a.hi != 0 || a.lo != 0 || b.hi != 0 || b.lo != 0;
}
else {
return 0;
}
}
/* Long integer divide */
static NPY_INLINE npy_extint128_t
divmod_128_64(npy_extint128_t x, npy_int64 b, npy_int64 *mod)
{
npy_extint128_t remainder, pointer, result, divisor;
char overflow = 0;
assert(b > 0);
if (b <= 1 || x.hi == 0) {
result.sign = x.sign;
result.lo = x.lo / b;
result.hi = x.hi / b;
*mod = x.sign * (x.lo % b);
return result;
}
/* Long division, not the most efficient choice */
remainder = x;
remainder.sign = 1;
divisor.sign = 1;
divisor.hi = 0;
divisor.lo = b;
result.sign = 1;
result.lo = 0;
result.hi = 0;
pointer.sign = 1;
pointer.lo = 1;
pointer.hi = 0;
while ((divisor.hi & (((npy_uint64)1) << 63)) == 0 &&
gt_128(remainder, divisor)) {
divisor = shl_128(divisor);
pointer = shl_128(pointer);
}
while (pointer.lo || pointer.hi) {
if (!gt_128(divisor, remainder)) {
remainder = sub_128(remainder, divisor, &overflow);
result = add_128(result, pointer, &overflow);
}
divisor = shr_128(divisor);
pointer = shr_128(pointer);
}
/* Fix signs and return; cannot overflow */
result.sign = x.sign;
*mod = x.sign * remainder.lo;
return result;
}
/* Divide and round down (positive divisor; no overflows) */
static NPY_INLINE npy_extint128_t
floordiv_128_64(npy_extint128_t a, npy_int64 b)
{
npy_extint128_t result;
npy_int64 remainder;
char overflow = 0;
assert(b > 0);
result = divmod_128_64(a, b, &remainder);
if (a.sign < 0 && remainder != 0) {
result = sub_128(result, to_128(1), &overflow);
}
return result;
}
/* Divide and round up (positive divisor; no overflows) */
static NPY_INLINE npy_extint128_t
ceildiv_128_64(npy_extint128_t a, npy_int64 b)
{
npy_extint128_t result;
npy_int64 remainder;
char overflow = 0;
assert(b > 0);
result = divmod_128_64(a, b, &remainder);
if (a.sign > 0 && remainder != 0) {
result = add_128(result, to_128(1), &overflow);
}
return result;
}
#endif