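/*
 * Declarations and lookup tables for the binary search routines
 * (direct and indirect/"arg" variants, left and right search sides).
 */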
#ifndef __NPY_BINSEARCH_H__
#define __NPY_BINSEARCH_H__

#include "npy_sort.h"
#include <numpy/npy_common.h>
#include <numpy/ndarraytypes.h>

/*
 * Function type shared by every direct binary search routine declared in
 * this header. Parameter names mirror the prototypes declared further down.
 * The trailing array argument is called `cmp` by the generic npy_binsearch
 * routines and `unused` by the type-specific ones.
 */
typedef void (PyArray_BinSearchFunc)(const char *arr, const char *key,
                                     char *ret,
                                     npy_intp arr_len, npy_intp key_len,
                                     npy_intp arr_str, npy_intp key_str,
                                     npy_intp ret_str,
                                     PyArrayObject *cmp);

/*
 * Function type shared by every indirect ("arg") binary search routine
 * declared in this header. Compared with the direct variant it takes an
 * extra `sort` buffer with its own stride and returns an int.
 * Parameter names mirror the prototypes declared further down; the trailing
 * array argument is called `cmp` by the generic npy_argbinsearch routines
 * and `unused` by the type-specific ones.
 * NOTE(review): the int result is presumably a status code -- confirm
 * against the implementations.
 */
typedef int (PyArray_ArgBinSearchFunc)(const char *arr, const char *key,
                                       const char *sort, char *ret,
                                       npy_intp arr_len, npy_intp key_len,
                                       npy_intp arr_str, npy_intp key_str,
                                       npy_intp sort_str, npy_intp ret_str,
                                       PyArrayObject *cmp);

/*
 * One row of the type-specific lookup table for direct searches: a type
 * number paired with its search routines, one per search side.
 */
struct binsearch_map {
    /* Type number this row applies to; the table is searched on this key. */
    enum NPY_TYPES typenum;
    /* Routines indexed by search side (NPY_NSEARCHSIDES slots). */
    PyArray_BinSearchFunc *binsearch[NPY_NSEARCHSIDES];
};

/*
 * One row of the type-specific lookup table for indirect ("arg") searches:
 * a type number paired with its search routines, one per search side.
 */
struct argbinsearch_map {
    /* Type number this row applies to; the table is searched on this key. */
    enum NPY_TYPES typenum;
    /* Routines indexed by search side (NPY_NSEARCHSIDES slots). */
    PyArray_ArgBinSearchFunc *argbinsearch[NPY_NSEARCHSIDES];
};

/**begin repeat
 *
 * #side = left, right#
 */

/**begin repeat1
 *
 * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
 *         longlong, ulonglong, half, float, double, longdouble,
 *         cfloat, cdouble, clongdouble, datetime, timedelta#
 */

/*
 * Type-specific search routines. The enclosing repeat blocks declare one
 * direct and one indirect ("arg") routine for every side/type combination.
 * Their signatures match PyArray_BinSearchFunc and PyArray_ArgBinSearchFunc
 * respectively; the trailing array argument is named `unused` here.
 * NOTE(review): the *_str arguments are presumably byte strides and `sort`
 * presumably an index array into `arr` -- confirm against the
 * implementation file.
 */
NPY_VISIBILITY_HIDDEN void
binsearch_@side@_@suff@(const char *arr, const char *key, char *ret,
                        npy_intp arr_len, npy_intp key_len,
                        npy_intp arr_str, npy_intp key_str, npy_intp ret_str,
                        PyArrayObject *unused);
NPY_VISIBILITY_HIDDEN int
argbinsearch_@side@_@suff@(const char *arr, const char *key,
                           const char *sort, char *ret,
                           npy_intp arr_len, npy_intp key_len,
                           npy_intp arr_str, npy_intp key_str,
                           npy_intp sort_str, npy_intp ret_str,
                           PyArrayObject *unused);
/**end repeat1**/

/*
 * Generic search routines, declared once per search side. They share the
 * signatures of the type-specific routines but name the trailing array
 * argument `cmp`. The lookup helpers later in this header return these for
 * dtypes that have no type-specific table entry but do provide a compare
 * function.
 */
NPY_VISIBILITY_HIDDEN void
npy_binsearch_@side@(const char *arr, const char *key, char *ret,
                     npy_intp arr_len, npy_intp key_len,
                     npy_intp arr_str, npy_intp key_str,
                     npy_intp ret_str, PyArrayObject *cmp);
NPY_VISIBILITY_HIDDEN int
npy_argbinsearch_@side@(const char *arr, const char *key,
                        const char *sort, char *ret,
                        npy_intp arr_len, npy_intp key_len,
                        npy_intp arr_str, npy_intp key_str,
                        npy_intp sort_str, npy_intp ret_str,
                        PyArrayObject *cmp);
/**end repeat**/

/**begin repeat
 *
 * #arg = , arg#
 * #Arg = , Arg#
 */

/*
 * Table of type-specific routines: one row per supported type number, with
 * the left-side routine first and the right-side routine second. The lookup
 * helper below binary-searches this table on typenum, which is why the rows
 * must stay sorted by type number.
 */
static struct @arg@binsearch_map _@arg@binsearch_map[] = {
    /* If adding new types, make sure to keep them ordered by type num */
    /**begin repeat1
     *
     * #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
     *         LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,
     *         CFLOAT, CDOUBLE, CLONGDOUBLE, DATETIME, TIMEDELTA, HALF#
     * #suff = bool, byte, ubyte, short, ushort, int, uint, long, ulong,
     *         longlong, ulonglong, float, double, longdouble,
     *         cfloat, cdouble, clongdouble, datetime, timedelta, half#
     */
    {NPY_@TYPE@,
        {
            &@arg@binsearch_left_@suff@,
            &@arg@binsearch_right_@suff@,
        },
    },
    /**end repeat1**/
};

/*
 * Generic fallback routines, left side first and right side second.
 * The lookup helper below indexes this array by search side when a dtype
 * has no type-specific row but provides a compare function.
 */
static PyArray_@Arg@BinSearchFunc *gen@arg@binsearch_map[] = {
    &npy_@arg@binsearch_left,
    &npy_@arg@binsearch_right,
};

/*
 * Select the search routine to use for `dtype` and `side`.
 *
 * The type-specific table is searched for dtype->type_num first. When the
 * type has no row there, the generic routine for `side` is returned,
 * provided the dtype supplies a compare function.
 *
 * Returns NULL when `side` is not a valid search side, or when the type has
 * neither a table row nor a compare function.
 */
static NPY_INLINE PyArray_@Arg@BinSearchFunc*
get_@arg@binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side)
{
    /* Compile-time row count; it is never modified, so keep it const. */
    const npy_intp num_funcs = sizeof(_@arg@binsearch_map) /
                               sizeof(_@arg@binsearch_map[0]);
    npy_intp min_idx = 0;
    npy_intp max_idx = num_funcs;
    int type = dtype->type_num;

    /*
     * The signedness of an enum type is implementation-defined, so compare
     * as unsigned: this rejects values past the end and negative values
     * alike before `side` is used as an array index.
     */
    if ((unsigned int)side >= (unsigned int)NPY_NSEARCHSIDES) {
        return NULL;
    }

    /*
     * It seems only fair that a binary search function be searched for
     * using a binary search...
     *
     * Lower-bound search: on exit min_idx is the first row whose typenum
     * is not smaller than `type`, or num_funcs if there is none.
     */
    while (min_idx < max_idx) {
        /* Midpoint written to avoid overflowing min_idx + max_idx. */
        npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);

        if (_@arg@binsearch_map[mid_idx].typenum < type) {
            min_idx = mid_idx + 1;
        }
        else {
            max_idx = mid_idx;
        }
    }

    /* Exact hit: use the type-specific routine. */
    if (min_idx < num_funcs && _@arg@binsearch_map[min_idx].typenum == type) {
        return _@arg@binsearch_map[min_idx].@arg@binsearch[side];
    }

    /* No specialised routine; fall back to the generic one if possible. */
    if (dtype->f->compare) {
        return gen@arg@binsearch_map[side];
    }

    return NULL;
}
/**end repeat**/

#endif