#ifndef __NPY_SORT_COMMON_H__
#define __NPY_SORT_COMMON_H__
#include <stdlib.h>
#include <string.h>
#include <numpy/ndarraytypes.h>
/*
*****************************************************************************
** SWAP MACROS **
*****************************************************************************
*/
/*
 * Each TYPE_SWAP(a, b) macro exchanges the values of the two lvalues a and b
 * through a temporary of the matching npy_* type.
 *
 * Usage notes:
 *  - The expansion is a bare brace block, not a do { } while (0) statement,
 *    so a trailing semicolon at the call site forms an extra empty statement;
 *    take care when a swap is the unbraced body of an if that has an else.
 *  - a and b are each evaluated twice, so they must be free of side effects.
 *  - The temporary is named tmp; an argument that mentions a variable called
 *    tmp would refer to the macro's temporary instead of the caller's.
 */
#define BOOL_SWAP(a,b) {npy_bool tmp = (b); (b)=(a); (a) = tmp;}
#define BYTE_SWAP(a,b) {npy_byte tmp = (b); (b)=(a); (a) = tmp;}
#define UBYTE_SWAP(a,b) {npy_ubyte tmp = (b); (b)=(a); (a) = tmp;}
#define SHORT_SWAP(a,b) {npy_short tmp = (b); (b)=(a); (a) = tmp;}
#define USHORT_SWAP(a,b) {npy_ushort tmp = (b); (b)=(a); (a) = tmp;}
#define INT_SWAP(a,b) {npy_int tmp = (b); (b)=(a); (a) = tmp;}
#define UINT_SWAP(a,b) {npy_uint tmp = (b); (b)=(a); (a) = tmp;}
#define LONG_SWAP(a,b) {npy_long tmp = (b); (b)=(a); (a) = tmp;}
#define ULONG_SWAP(a,b) {npy_ulong tmp = (b); (b)=(a); (a) = tmp;}
#define LONGLONG_SWAP(a,b) {npy_longlong tmp = (b); (b)=(a); (a) = tmp;}
#define ULONGLONG_SWAP(a,b) {npy_ulonglong tmp = (b); (b)=(a); (a) = tmp;}
#define HALF_SWAP(a,b) {npy_half tmp = (b); (b)=(a); (a) = tmp;}
#define FLOAT_SWAP(a,b) {npy_float tmp = (b); (b)=(a); (a) = tmp;}
#define DOUBLE_SWAP(a,b) {npy_double tmp = (b); (b)=(a); (a) = tmp;}
#define LONGDOUBLE_SWAP(a,b) {npy_longdouble tmp = (b); (b)=(a); (a) = tmp;}
#define CFLOAT_SWAP(a,b) {npy_cfloat tmp = (b); (b)=(a); (a) = tmp;}
#define CDOUBLE_SWAP(a,b) {npy_cdouble tmp = (b); (b)=(a); (a) = tmp;}
#define CLONGDOUBLE_SWAP(a,b) {npy_clongdouble tmp = (b); (b)=(a); (a) = tmp;}
#define DATETIME_SWAP(a,b) {npy_datetime tmp = (b); (b)=(a); (a) = tmp;}
#define TIMEDELTA_SWAP(a,b) {npy_timedelta tmp = (b); (b)=(a); (a) = tmp;}
/* npy_intp swap; needed by the argsort functions */
#define INTP_SWAP(a,b) {npy_intp tmp = (b); (b)=(a); (a) = tmp;}
/*
*****************************************************************************
** COMPARISON FUNCTIONS **
*****************************************************************************
*/
/* Ordering predicate for npy_bool: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
BOOL_LT(npy_bool a, npy_bool b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_byte: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
BYTE_LT(npy_byte a, npy_byte b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_ubyte: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
UBYTE_LT(npy_ubyte a, npy_ubyte b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_short: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
SHORT_LT(npy_short a, npy_short b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_ushort: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
USHORT_LT(npy_ushort a, npy_ushort b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_int: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
INT_LT(npy_int a, npy_int b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_uint: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
UINT_LT(npy_uint a, npy_uint b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_long: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
LONG_LT(npy_long a, npy_long b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_ulong: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
ULONG_LT(npy_ulong a, npy_ulong b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_longlong: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
LONGLONG_LT(npy_longlong a, npy_longlong b)
{
    const int ret = (b > a);

    return ret;
}
/* Ordering predicate for npy_ulonglong: returns 1 if a is strictly less than b, else 0. */
NPY_INLINE static int
ULONGLONG_LT(npy_ulonglong a, npy_ulonglong b)
{
    const int ret = (b > a);

    return ret;
}
/*
 * Ordering predicate for npy_float that places NaN after every other value.
 *
 * Returns 1 when a < b numerically, or when b is NaN and a is not; otherwise
 * returns 0. NaN is detected by a value comparing unequal to itself.
 */
NPY_INLINE static int
FLOAT_LT(npy_float a, npy_float b)
{
    int ret = a < b;

    if (!ret) {
        /* Not numerically smaller: a still precedes b if only b is NaN. */
        ret = (b != b) && (a == a);
    }
    return ret;
}
/*
 * Ordering predicate for npy_double that places NaN after every other value.
 *
 * Returns 1 when a < b numerically, or when b is NaN and a is not; otherwise
 * returns 0. NaN is detected by a value comparing unequal to itself.
 */
NPY_INLINE static int
DOUBLE_LT(npy_double a, npy_double b)
{
    int ret = a < b;

    if (!ret) {
        /* Not numerically smaller: a still precedes b if only b is NaN. */
        ret = (b != b) && (a == a);
    }
    return ret;
}
/*
 * Ordering predicate for npy_longdouble that places NaN after every other
 * value.
 *
 * Returns 1 when a < b numerically, or when b is NaN and a is not; otherwise
 * returns 0. NaN is detected by a value comparing unequal to itself.
 */
NPY_INLINE static int
LONGDOUBLE_LT(npy_longdouble a, npy_longdouble b)
{
    int ret = a < b;

    if (!ret) {
        /* Not numerically smaller: a still precedes b if only b is NaN. */
        ret = (b != b) && (a == a);
    }
    return ret;
}
/*
 * Returns nonzero when the half-precision bit pattern h encodes a NaN: every
 * bit of the 0x7c00 field (the exponent in the IEEE 754 binary16 layout) is
 * set and the 0x03ff field (the mantissa) is nonzero.
 */
NPY_INLINE static int
npy_half_isnan(npy_half h)
{
    const int exponent_all_ones = ((h & 0x7c00u) == 0x7c00u);
    const int mantissa_nonzero = ((h & 0x03ffu) != 0x0000u);

    return exponent_all_ones && mantissa_nonzero;
}
/*
 * Less-than on two half-precision bit patterns.
 *
 * No NaN check is performed here; as the name says, it is meant for operands
 * already known not to be NaN. Bit 0x8000 is treated as the sign and the low
 * 15 bits as an unsigned magnitude. Returns 1 if h1 < h2, otherwise 0.
 */
NPY_INLINE static int
npy_half_lt_nonan(npy_half h1, npy_half h2)
{
    const int h1_negative = (h1 & 0x8000u) != 0;
    const int h2_negative = (h2 & 0x8000u) != 0;
    int ret;

    if (h1_negative && h2_negative) {
        /* Both have the sign bit: the larger magnitude is the smaller value. */
        ret = (h1 & 0x7fffu) > (h2 & 0x7fffu);
    }
    else if (h1_negative) {
        /* Signed zeros are equal, so -0 must not compare less than +0. */
        ret = (h1 != 0x8000u) || (h2 != 0x0000u);
    }
    else if (h2_negative) {
        /* h1 has no sign bit but h2 does: h1 is never the smaller one. */
        ret = 0;
    }
    else {
        /* Neither has the sign bit: compare magnitudes directly. */
        ret = (h1 & 0x7fffu) < (h2 & 0x7fffu);
    }
    return ret;
}
/*
 * Ordering predicate for npy_half that places NaN after every other value.
 *
 * Returns 0 whenever a is NaN. Otherwise returns 1 if b is NaN, and falls
 * back to npy_half_lt_nonan(a, b) when neither operand is NaN.
 */
NPY_INLINE static int
HALF_LT(npy_half a, npy_half b)
{
    const int b_is_nan = npy_half_isnan(b);
    const int a_is_nan = npy_half_isnan(a);
    int ret = 0;

    if (!a_is_nan) {
        ret = b_is_nan || npy_half_lt_nonan(a, b);
    }
    return ret;
}
/*
* For inline functions SUN recommends not using a return in the then part
* of an if statement. It's a SUN compiler thing, so assign the return value
* to a variable instead.
*/
/*
 * Ordering predicate for npy_cfloat: returns nonzero when a orders strictly
 * before b, zero otherwise. A component is recognised as NaN by comparing
 * unequal to itself.
 *
 * Real parts are compared first and imaginary parts break ties; the NaN
 * handling of each case is described on the branch that implements it.
 */
NPY_INLINE static int
CFLOAT_LT(npy_cfloat a, npy_cfloat b)
{
int ret;
if (a.real < b.real) {
/* Smaller real part: a is first unless a.imag is NaN while b.imag is not. */
ret = a.imag == a.imag || b.imag != b.imag;
}
else if (a.real > b.real) {
/* Larger real part: a is still first if b.imag is NaN while a.imag is not. */
ret = b.imag != b.imag && a.imag == a.imag;
}
else if (a.real == b.real || (a.real != a.real && b.real != b.real)) {
/* Real parts equal or both NaN: order by imaginary part, NaN imaginary last. */
ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag);
}
else {
/* Exactly one real part is NaN: a is first only when the NaN is in b. */
ret = b.real != b.real;
}
return ret;
}
/*
 * Ordering predicate for npy_cdouble: returns nonzero when a orders strictly
 * before b, zero otherwise. A component is recognised as NaN by comparing
 * unequal to itself.
 *
 * Real parts are compared first and imaginary parts break ties; the NaN
 * handling of each case is described on the branch that implements it.
 */
NPY_INLINE static int
CDOUBLE_LT(npy_cdouble a, npy_cdouble b)
{
int ret;
if (a.real < b.real) {
/* Smaller real part: a is first unless a.imag is NaN while b.imag is not. */
ret = a.imag == a.imag || b.imag != b.imag;
}
else if (a.real > b.real) {
/* Larger real part: a is still first if b.imag is NaN while a.imag is not. */
ret = b.imag != b.imag && a.imag == a.imag;
}
else if (a.real == b.real || (a.real != a.real && b.real != b.real)) {
/* Real parts equal or both NaN: order by imaginary part, NaN imaginary last. */
ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag);
}
else {
/* Exactly one real part is NaN: a is first only when the NaN is in b. */
ret = b.real != b.real;
}
return ret;
}
/*
 * Ordering predicate for npy_clongdouble: returns nonzero when a orders
 * strictly before b, zero otherwise. A component is recognised as NaN by
 * comparing unequal to itself.
 *
 * Real parts are compared first and imaginary parts break ties; the NaN
 * handling of each case is described on the branch that implements it.
 */
NPY_INLINE static int
CLONGDOUBLE_LT(npy_clongdouble a, npy_clongdouble b)
{
int ret;
if (a.real < b.real) {
/* Smaller real part: a is first unless a.imag is NaN while b.imag is not. */
ret = a.imag == a.imag || b.imag != b.imag;
}
else if (a.real > b.real) {
/* Larger real part: a is still first if b.imag is NaN while a.imag is not. */
ret = b.imag != b.imag && a.imag == a.imag;
}
else if (a.real == b.real || (a.real != a.real && b.real != b.real)) {
/* Real parts equal or both NaN: order by imaginary part, NaN imaginary last. */
ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag);
}
else {
/* Exactly one real part is NaN: a is first only when the NaN is in b. */
ret = b.real != b.real;
}
return ret;
}
/*
 * Copies len bytes from s2 into s1.
 *
 * The copy is performed with memcpy, so the two buffers must not overlap.
 * The source is const-qualified because it is only read; callers that pass a
 * plain char * are unaffected.
 */
NPY_INLINE static void
STRING_COPY(char *s1, const char *s2, size_t len)
{
    memcpy(s1, s2, len);
}
/* Exchanges the first len bytes of s1 and s2, one byte at a time. */
NPY_INLINE static void
STRING_SWAP(char *s1, char *s2, size_t len)
{
    size_t i;

    for (i = 0; i < len; ++i) {
        const char held = s1[i];

        s1[i] = s2[i];
        s2[i] = held;
    }
}
/*
 * Lexicographic less-than over the first len bytes of s1 and s2.
 *
 * Bytes are compared as unsigned char, so the result does not depend on the
 * signedness of plain char. The scan stops at the first differing byte and
 * returns 1 if the byte from s1 is the smaller one, 0 otherwise. Returns 0
 * when all len bytes are equal or len is 0. All len bytes take part; a NUL
 * byte does not end the comparison.
 *
 * Both buffers are only read, hence the const qualification; callers that
 * pass plain char * are unaffected.
 */
NPY_INLINE static int
STRING_LT(const char *s1, const char *s2, size_t len)
{
    const unsigned char *c1 = (const unsigned char *)s1;
    const unsigned char *c2 = (const unsigned char *)s2;
    size_t i;
    int ret = 0;

    for (i = 0; i < len; ++i) {
        if (c1[i] != c2[i]) {
            ret = c1[i] < c2[i];
            break;
        }
    }
    return ret;
}
/*
 * Copies len npy_ucs4 elements from s2 into s1, front to back, one element
 * at a time.
 *
 * The source is const-qualified because it is only read; callers that pass a
 * plain npy_ucs4 * are unaffected.
 */
NPY_INLINE static void
UNICODE_COPY(npy_ucs4 *s1, const npy_ucs4 *s2, size_t len)
{
    while (len--) {
        *s1++ = *s2++;
    }
}
/* Exchanges the first len npy_ucs4 elements of s1 and s2, one element at a time. */
NPY_INLINE static void
UNICODE_SWAP(npy_ucs4 *s1, npy_ucs4 *s2, size_t len)
{
    size_t i;

    for (i = 0; i < len; ++i) {
        const npy_ucs4 held = s1[i];

        s1[i] = s2[i];
        s2[i] = held;
    }
}
/*
 * Lexicographic less-than over the first len npy_ucs4 elements of s1 and s2.
 *
 * The scan stops at the first differing element and returns 1 if the element
 * from s1 is the smaller one, 0 otherwise. Returns 0 when all len elements
 * are equal or len is 0.
 *
 * Both buffers are only read, hence the const qualification; callers that
 * pass plain npy_ucs4 * are unaffected.
 */
NPY_INLINE static int
UNICODE_LT(const npy_ucs4 *s1, const npy_ucs4 *s2, size_t len)
{
    size_t i;
    int ret = 0;

    for (i = 0; i < len; ++i) {
        if (s1[i] != s2[i]) {
            ret = s1[i] < s2[i];
            break;
        }
    }
    return ret;
}
/*
 * Ordering predicate for npy_datetime: returns 1 if a is strictly less than
 * b, else 0. The two values are compared directly, with no special cases.
 */
NPY_INLINE static int
DATETIME_LT(npy_datetime a, npy_datetime b)
{
    const int ret = (b > a);

    return ret;
}
/*
 * Ordering predicate for npy_timedelta: returns 1 if a is strictly less than
 * b, else 0. The two values are compared directly, with no special cases.
 */
NPY_INLINE static int
TIMEDELTA_LT(npy_timedelta a, npy_timedelta b)
{
    const int ret = (b > a);

    return ret;
}
/*
 * Copies len bytes from b into a.
 *
 * The copy is performed with memcpy, so the two buffers must not overlap.
 * The source is const-qualified because it is only read; callers that pass a
 * plain char * are unaffected.
 */
NPY_INLINE static void
GENERIC_COPY(char *a, const char *b, size_t len)
{
    memcpy(a, b, len);
}
/* Exchanges the first len bytes of a and b, one byte at a time. */
NPY_INLINE static void
GENERIC_SWAP(char *a, char *b, size_t len)
{
    size_t i;

    for (i = 0; i < len; ++i) {
        const char held = a[i];

        a[i] = b[i];
        b[i] = held;
    }
}
#endif