Blob Blame History Raw
/*
 * This file implements the API functions for NumPy's nditer that
 * are specialized using the templating system.
 *
 * Copyright (c) 2010-2011 by Mark Wiebe (mwwiebe@gmail.com)
 * The University of British Columbia
 *
 * See LICENSE.txt for the license.
 */

/* Indicate that this .c file is allowed to include the header */
#define NPY_ITERATOR_IMPLEMENTATION_CODE
#include "nditer_impl.h"

/* SPECIALIZED iternext functions that handle the non-buffering part */

/**begin repeat
 * #const_itflags = 0,
 *                  NPY_ITFLAG_HASINDEX,
 *                  NPY_ITFLAG_EXLOOP,
 *                  NPY_ITFLAG_RANGE,
 *                  NPY_ITFLAG_RANGE|NPY_ITFLAG_HASINDEX#
 * #tag_itflags = 0, IND, NOINN, RNG, RNGuIND#
 */
/**begin repeat1
 * #const_ndim = 1, 2, NPY_MAXDIMS#
 * #tag_ndim = 1, 2, ANY#
 */
/**begin repeat2
 * #const_nop = 1, 2, NPY_MAXDIMS#
 * #tag_nop = 1, 2, ANY#
 */

/*
 * Specialized iternext (@const_itflags@,@tag_ndim@,@tag_nop@)
 *
 * Advances a non-buffered iterator by one step.  One copy of this
 * function is generated for every combination of iterator flags,
 * dimension count (1, 2 or any) and operand count (1, 2 or any) listed
 * in the enclosing template loops; the "any" variants use NPY_MAXDIMS
 * as the template constant and read the real value from the iterator.
 *
 * The axisdata entries are walked starting at NIT_AXISDATA(iter).  An
 * axis is advanced by incrementing its index and adding its strides to
 * its pointers.  If the index is still below the shape, every lower
 * axis gets its index reset to 0 and its pointers copied from the axis
 * that was just advanced, and 1 is returned.  Otherwise the next axis
 * is tried.
 *
 * With NPY_ITFLAG_EXLOOP the first axis is never advanced here, so the
 * 1-dimensional variant always returns 0.
 *
 * With NPY_ITFLAG_RANGE, NIT_ITERINDEX is pre-incremented and 0 is
 * returned as soon as it reaches NIT_ITEREND.
 *
 * Returns 1 if the iterator now sits on another valid position, or 0
 * if the iteration is finished.
 */
static int
npyiter_iternext_itflags@tag_itflags@_dims@tag_ndim@_iters@tag_nop@(
                                                      NpyIter *iter)
{
    /*
     * The EXLOOP 1-dimensional variant only returns 0, so it needs
     * none of the locals below.
     */
#if !(@const_itflags@&NPY_ITFLAG_EXLOOP) || (@const_ndim@ > 1)
    const npy_uint32 itflags = @const_itflags@;
    /* Only the "any" dimension variant loops over axes 3 and up */
#  if @const_ndim@ >= NPY_MAXDIMS
    int idim, ndim = NIT_NDIM(iter);
#  endif
#  if @const_nop@ < NPY_MAXDIMS
    const int nop = @const_nop@;
#  else
    int nop = NIT_NOP(iter);
#  endif

    NpyIter_AxisData *axisdata0;
    /*
     * Number of pointer/stride slots updated per axis.  NAD_NSTRIDES()
     * takes no arguments, so it presumably reads the local itflags and
     * nop -- confirm in nditer_impl.h.
     */
    npy_intp istrides, nstrides = NAD_NSTRIDES();
#endif
#if @const_ndim@ > 1
    NpyIter_AxisData *axisdata1;
    npy_intp sizeof_axisdata;
#endif
#if @const_ndim@ > 2
    NpyIter_AxisData *axisdata2;
#endif

#if (@const_itflags@&NPY_ITFLAG_RANGE)
    /* When ranged iteration is enabled, use the iterindex */
    if (++NIT_ITERINDEX(iter) >= NIT_ITEREND(iter)) {
        return 0;
    }
#endif

#if @const_ndim@ > 1
    /*
     * sizeof_axisdata is not referenced by name again in this function;
     * presumably NIT_INDEX_AXISDATA and NIT_ADVANCE_AXISDATA use it
     * implicitly -- confirm in nditer_impl.h.
     * NOTE(review): ndim is only declared in the "any" dimension
     * variant, so NIT_AXISDATA_SIZEOF presumably does not expand its
     * ndim argument.
     */
    sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, nop);
#endif

#  if !(@const_itflags@&NPY_ITFLAG_EXLOOP) || (@const_ndim@ > 1)
    axisdata0 = NIT_AXISDATA(iter);
#  endif
#  if !(@const_itflags@&NPY_ITFLAG_EXLOOP)
    /* Increment index 0 */
    NAD_INDEX(axisdata0)++;
    /* Increment pointer 0 */
    for (istrides = 0; istrides < nstrides; ++istrides) {
        NAD_PTRS(axisdata0)[istrides] += NAD_STRIDES(axisdata0)[istrides];
    }
#  endif

#if @const_ndim@ == 1

#  if !(@const_itflags@&NPY_ITFLAG_EXLOOP)
    /* Finished when the index equals the shape */
    return NAD_INDEX(axisdata0) < NAD_SHAPE(axisdata0);
#  else
    /* Axis 0 is not advanced in EXLOOP mode and there is no other axis */
    return 0;
#  endif

#else

#  if !(@const_itflags@&NPY_ITFLAG_EXLOOP)
    /* Axis 0 still has elements left, no carry into axis 1 needed */
    if (NAD_INDEX(axisdata0) < NAD_SHAPE(axisdata0)) {
        return 1;
    }
#  endif

    axisdata1 = NIT_INDEX_AXISDATA(axisdata0, 1);
    /* Increment index 1 */
    NAD_INDEX(axisdata1)++;
    /* Increment pointer 1 */
    for (istrides = 0; istrides < nstrides; ++istrides) {
        NAD_PTRS(axisdata1)[istrides] += NAD_STRIDES(axisdata1)[istrides];
    }

    if (NAD_INDEX(axisdata1) < NAD_SHAPE(axisdata1)) {
        /* Reset the 1st index to 0 */
        NAD_INDEX(axisdata0) = 0;
        /* Reset the 1st pointer to the value of the 2nd */
        for (istrides = 0; istrides < nstrides; ++istrides) {
            NAD_PTRS(axisdata0)[istrides] = NAD_PTRS(axisdata1)[istrides];
        }
        return 1;
    }

# if @const_ndim@ == 2
    /* Both axes are exhausted */
    return 0;
# else

    axisdata2 = NIT_INDEX_AXISDATA(axisdata1, 1);
    /* Increment index 2 */
    NAD_INDEX(axisdata2)++;
    /* Increment pointer 2 */
    for (istrides = 0; istrides < nstrides; ++istrides) {
        NAD_PTRS(axisdata2)[istrides] += NAD_STRIDES(axisdata2)[istrides];
    }

    if (NAD_INDEX(axisdata2) < NAD_SHAPE(axisdata2)) {
        /* Reset the 1st and 2nd indices to 0 */
        NAD_INDEX(axisdata0) = 0;
        NAD_INDEX(axisdata1) = 0;
        /* Reset the 1st and 2nd pointers to the value of the 3rd */
        for (istrides = 0; istrides < nstrides; ++istrides) {
            NAD_PTRS(axisdata0)[istrides] = NAD_PTRS(axisdata2)[istrides];
            NAD_PTRS(axisdata1)[istrides] = NAD_PTRS(axisdata2)[istrides];
        }
        return 1;
    }

    /*
     * Generic carry loop for axes 3 and up.  axisdata2 is reused as the
     * cursor for the axis currently being advanced.
     */
    for (idim = 3; idim < ndim; ++idim) {
        NIT_ADVANCE_AXISDATA(axisdata2, 1);
        /* Increment the index */
        NAD_INDEX(axisdata2)++;
        /* Increment the pointer */
        for (istrides = 0; istrides < nstrides; ++istrides) {
            NAD_PTRS(axisdata2)[istrides] += NAD_STRIDES(axisdata2)[istrides];
        }


        if (NAD_INDEX(axisdata2) < NAD_SHAPE(axisdata2)) {
            /* Reset the indices and pointers of all previous axisdatas */
            /* axisdata1 is reused to walk back from the cursor to axis 0 */
            axisdata1 = axisdata2;
            do {
                NIT_ADVANCE_AXISDATA(axisdata1, -1);
                /* Reset the index to 0 */
                NAD_INDEX(axisdata1) = 0;
                /* Reset the pointer to the updated value */
                for (istrides = 0; istrides < nstrides; ++istrides) {
                    NAD_PTRS(axisdata1)[istrides] =
                                        NAD_PTRS(axisdata2)[istrides];
                }
            } while (axisdata1 != axisdata0);

            return 1;
        }
    }

    /* Every axis reached its shape: the iteration is complete */
    return 0;

# endif /* ndim != 2 */

#endif /* ndim != 1 */
}

/**end repeat2**/
/**end repeat1**/
/**end repeat**/


/**begin repeat
 * #const_nop = 1, 2, 3, 4, NPY_MAXDIMS#
 * #tag_nop = 1, 2, 3, 4, ANY#
 */

/*
 * Buffered iternext for reductions, specialized on the operand count.
 *
 * The buffered data is consumed with a double loop: an inner pass over
 * the buffer followed by an outer pass that shifts every operand
 * pointer by its outer reduce stride.  The buffers are written back and
 * refilled only once both loops are exhausted, which avoids frequent
 * re-buffering.
 *
 * Returns 1 while another position is available, 0 once the iteration
 * index has reached the iterator end.
 */
static int
npyiter_buffered_reduce_iternext_iters@tag_nop@(NpyIter *iter)
{
    /*
     * The local names itflags and nop are kept unchanged because the
     * accessor macros from nditer_impl.h may refer to them.
     */
    npy_uint32 itflags = NIT_ITFLAGS(iter);
    /*int ndim = NIT_NDIM(iter);*/
#if @const_nop@ >= NPY_MAXDIMS
    int nop = NIT_NOP(iter);
#else
    const int nop = @const_nop@;
#endif

    NpyIter_BufferData *bufferdata = NIT_BUFFERDATA(iter);
    NpyIter_AxisData *axisdata;
    char **ptrs = NBF_PTRS(bufferdata);
    char *prev_dataptrs[NPY_MAXARGS];
    int iop;

    if (itflags&NPY_ITFLAG_EXLOOP) {
        /* The iteration index jumps over the whole buffer at once */
        NIT_ITERINDEX(iter) += NBF_SIZE(bufferdata);
    }
    else {
        /* The iterator handles the inner loop: step a single element */
        ++NIT_ITERINDEX(iter);
        if (NIT_ITERINDEX(iter) < NBF_BUFITEREND(bufferdata)) {
            npy_intp *strides = NBF_STRIDES(bufferdata);

            for (iop = 0; iop < nop; ++iop) {
                ptrs[iop] += strides[iop];
            }
            return 1;
        }
    }

    NPY_IT_DBG_PRINT1("Iterator: Finished iteration %d of outer reduce loop\n",
                            (int)NBF_REDUCE_POS(bufferdata));

    /* Outer step of the reduce double loop */
    ++NBF_REDUCE_POS(bufferdata);
    if (NBF_REDUCE_POS(bufferdata) < NBF_REDUCE_OUTERSIZE(bufferdata)) {
        npy_intp *outer_strides = NBF_REDUCE_OUTERSTRIDES(bufferdata);
        char **outer_ptrs = NBF_REDUCE_OUTERPTRS(bufferdata);

        for (iop = 0; iop < nop; ++iop) {
            /* Advance the outer pointer and restart the inner one there */
            outer_ptrs[iop] += outer_strides[iop];
            ptrs[iop] = outer_ptrs[iop];
        }
        NBF_BUFITEREND(bufferdata) = NIT_ITERINDEX(iter) + NBF_SIZE(bufferdata);
        return 1;
    }

    /* Remember the data pointers in use before the buffers are flushed */
    axisdata = NIT_AXISDATA(iter);
    memcpy(prev_dataptrs, NAD_PTRS(axisdata), NPY_SIZEOF_INTP*nop);

    /* Flush the buffers back into the arrays */
    npyiter_copy_from_buffers(iter);

    /* Nothing left: mark the buffer as empty and stop */
    if (NIT_ITERINDEX(iter) >= NIT_ITEREND(iter)) {
        NBF_SIZE(bufferdata) = 0;
        return 0;
    }

    /* Move on to the next buffer and fill it (sets iterend/size) */
    npyiter_goto_iterindex(iter, NIT_ITERINDEX(iter));
    npyiter_copy_to_buffers(iter, prev_dataptrs);

    return 1;
}

/**end repeat**/

/*
 * iternext for buffered iterators that are not reductions.
 *
 * Steps inside the current buffer while elements remain.  Once the
 * buffer is used up it is written back to the arrays and, unless the
 * iteration index has reached the iterator end, the iterator is moved
 * to that index and the buffers are refilled.
 *
 * Returns 1 while another position is available, 0 when finished.
 */
static int
npyiter_buffered_iternext(NpyIter *iter)
{
    /*
     * The local names itflags and nop are kept unchanged because the
     * accessor macros from nditer_impl.h may refer to them.
     */
    npy_uint32 itflags = NIT_ITFLAGS(iter);
    /*int ndim = NIT_NDIM(iter);*/
    int nop = NIT_NOP(iter);

    NpyIter_BufferData *bufferdata = NIT_BUFFERDATA(iter);

    if (itflags&NPY_ITFLAG_EXLOOP) {
        /* The iteration index jumps over the whole buffer at once */
        NIT_ITERINDEX(iter) += NBF_SIZE(bufferdata);
    }
    else {
        /* The iterator handles the inner loop: step a single element */
        ++NIT_ITERINDEX(iter);
        if (NIT_ITERINDEX(iter) < NBF_BUFITEREND(bufferdata)) {
            npy_intp *step = NBF_STRIDES(bufferdata);
            char **cursor = NBF_PTRS(bufferdata);
            int i;

            for (i = 0; i < nop; ++i) {
                cursor[i] += step[i];
            }
            return 1;
        }
    }

    /* Flush the buffers back into the arrays */
    npyiter_copy_from_buffers(iter);

    /* Nothing left: mark the buffer as empty and stop */
    if (NIT_ITERINDEX(iter) >= NIT_ITEREND(iter)) {
        NBF_SIZE(bufferdata) = 0;
        return 0;
    }

    /* Move on to the next buffer and fill it (sets iterend/size) */
    npyiter_goto_iterindex(iter, NIT_ITERINDEX(iter));
    npyiter_copy_to_buffers(iter, NULL);

    return 1;
}

/**end repeat2**/
/**end repeat1**/
/**end repeat**/

/*
 * iternext specialization used when the iteration size is 1: there is
 * never a following position, so completion is reported unconditionally.
 */
static int
npyiter_iternext_sizeone(NpyIter *iter)
{
    /* Nothing to advance; the iterator argument is intentionally unused */
    (void)iter;

    return 0;
}

/*NUMPY_API
 * Compute the specialized iteration function for an iterator
 *
 * If errmsg is non-NULL, it should point to a variable which will
 * receive the error message, and no Python exception will be set.
 * This is so that the function can be called from code not holding
 * the GIL.
 *
 * Selection order, as implemented below:
 *   1. NPY_ITFLAG_ONEITERATION set: npyiter_iternext_sizeone.
 *   2. NPY_ITFLAG_BUFFER set: a buffered iternext; when
 *      NPY_ITFLAG_REDUCE is also set, the reduce variant specialized
 *      on nop (1, 2, 3, 4 or any).
 *   3. Otherwise: a non-buffered iternext specialized on the
 *      HASINDEX/EXLOOP/RANGE flags, on ndim (1, 2 or any) and on
 *      nop (1, 2 or any).
 *
 * Returns the selected function, or NULL on error (negative
 * NIT_ITERSIZE, or a flag combination without a specialization).
 * The messages stored in *errmsg are string literals.
 */
NPY_NO_EXPORT NpyIter_IterNextFunc *
NpyIter_GetIterNext(NpyIter *iter, char **errmsg)
{
    npy_uint32 itflags = NIT_ITFLAGS(iter);
    int ndim = NIT_NDIM(iter);
    int nop = NIT_NOP(iter);

    /* A negative iteration size is reported as "iterator is too large" */
    if (NIT_ITERSIZE(iter) < 0) {
        if (errmsg == NULL) {
            PyErr_SetString(PyExc_ValueError, "iterator is too large");
        }
        else {
            *errmsg = "iterator is too large";
        }
        return NULL;
    }

    /*
     * When there is just one iteration and buffering is disabled
     * the iternext function is very simple.
     *
     * NOTE(review): only ONEITERATION is tested here, before the BUFFER
     * check; presumably the flag is never set together with BUFFER --
     * confirm where the flag is assigned.
     */
    if (itflags&NPY_ITFLAG_ONEITERATION) {
        return &npyiter_iternext_sizeone;
    }

    /*
     * If buffering is enabled.
     */
    if (itflags&NPY_ITFLAG_BUFFER) {
        if (itflags&NPY_ITFLAG_REDUCE) {
            /* The reduce variant is specialized on the operand count */
            switch (nop) {
                case 1:
                    return &npyiter_buffered_reduce_iternext_iters1;
                case 2:
                    return &npyiter_buffered_reduce_iternext_iters2;
                case 3:
                    return &npyiter_buffered_reduce_iternext_iters3;
                case 4:
                    return &npyiter_buffered_reduce_iternext_iters4;
                default:
                    return &npyiter_buffered_reduce_iternext_itersANY;
            }
        }
        else {
            return &npyiter_buffered_iternext;
        }
    }

    /*
     * Ignore all the flags that don't affect the iterator memory
     * layout or the iternext function.  Currently only HASINDEX,
     * EXLOOP, and RANGE affect them here.
     */
    itflags &= (NPY_ITFLAG_HASINDEX|NPY_ITFLAG_EXLOOP|NPY_ITFLAG_RANGE);

    /* Switch statements let the compiler optimize this most effectively */
    switch (itflags) {
    /*
     * The combinations HASINDEX|EXLOOP and RANGE|EXLOOP are excluded
     * by the New functions
     */
/**begin repeat
 * #const_itflags = 0,
 *                  NPY_ITFLAG_HASINDEX,
 *                  NPY_ITFLAG_EXLOOP,
 *                  NPY_ITFLAG_RANGE,
 *                  NPY_ITFLAG_RANGE|NPY_ITFLAG_HASINDEX#
 * #tag_itflags = 0, IND, NOINN, RNG, RNGuIND#
 */
        case @const_itflags@:
            switch (ndim) {
/**begin repeat1
 * #const_ndim = 1, 2#
 * #tag_ndim = 1, 2#
 */
                case @const_ndim@:
                    switch (nop) {
/**begin repeat2
 * #const_nop = 1, 2#
 * #tag_nop = 1, 2#
 */
                        case @const_nop@:
                            return &npyiter_iternext_itflags@tag_itflags@_dims@tag_ndim@_iters@tag_nop@;
/**end repeat2**/
                        /* Not specialized on nop */
                        default:
                            return &npyiter_iternext_itflags@tag_itflags@_dims@tag_ndim@_itersANY;
                    }
/**end repeat1**/
                /* Not specialized on ndim */
                default:
                    switch (nop) {
/**begin repeat1
 * #const_nop = 1, 2#
 * #tag_nop = 1, 2#
 */
                        case @const_nop@:
                            return &npyiter_iternext_itflags@tag_itflags@_dimsANY_iters@tag_nop@;
/**end repeat1**/
                        /* Not specialized on nop */
                        default:
                            return &npyiter_iternext_itflags@tag_itflags@_dimsANY_itersANY;
                    }
            }
/**end repeat**/
    }
    /* The switch above should have caught all the possibilities. */
    /* Reached only for a masked itflags value with no case above */
    if (errmsg == NULL) {
        PyErr_Format(PyExc_ValueError,
                "GetIterNext internal iterator error - unexpected "
                "itflags/ndim/nop combination (%04x/%d/%d)",
                (int)itflags, (int)ndim, (int)nop);
    }
    else {
        *errmsg = "GetIterNext internal iterator error - unexpected "
                  "itflags/ndim/nop combination";
    }
    return NULL;
}


/* SPECIALIZED getindex functions */

/**begin repeat
 * #const_itflags = 0,
 *    NPY_ITFLAG_HASINDEX,
 *    NPY_ITFLAG_IDENTPERM,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_IDENTPERM,
 *    NPY_ITFLAG_NEGPERM,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_NEGPERM,
 *    NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_IDENTPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_IDENTPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_NEGPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_NEGPERM|NPY_ITFLAG_BUFFER#
 * #tag_itflags = 0, IND, IDP, INDuIDP, NEGP, INDuNEGP,
 *                BUF, INDuBUF, IDPuBUF, INDuIDPuBUF, NEGPuBUF, INDuNEGPuBUF#
 */
/*
 * Specialized get_multi_index (@const_itflags@)
 *
 * Writes the current multi-index of the iterator into out_multi_index,
 * one entry per axisdata (NIT_NDIM(iter) entries in total).  The
 * axisdata entries are visited in storage order, starting at
 * NIT_AXISDATA(iter), and the output slot is chosen as follows:
 *   - IDENTPERM: axisdata idim goes to slot ndim-1-idim; the
 *     permutation array is not read.
 *   - neither IDENTPERM nor NEGPERM: with p = perm[idim], the index
 *     goes to slot ndim-p-1.
 *   - NEGPERM: as above for p >= 0; for p < 0 the slot is ndim+p and
 *     the stored value is reversed, shape - index - 1.
 */
static void
npyiter_get_multi_index_itflags@tag_itflags@(
                        NpyIter *iter, npy_intp *out_multi_index)
{
    const npy_uint32 itflags = @const_itflags@;
    int idim, ndim = NIT_NDIM(iter);
    int nop = NIT_NOP(iter);

    /*
     * sizeof_axisdata is not referenced by name after its assignment;
     * presumably NIT_ADVANCE_AXISDATA uses it implicitly -- confirm in
     * nditer_impl.h.
     */
    npy_intp sizeof_axisdata;
    NpyIter_AxisData *axisdata;
#if !((@const_itflags@)&NPY_ITFLAG_IDENTPERM)
    npy_int8 *perm = NIT_PERM(iter);
#endif

    axisdata = NIT_AXISDATA(iter);
    sizeof_axisdata = NIT_AXISDATA_SIZEOF(itflags, ndim, nop);
#if ((@const_itflags@)&NPY_ITFLAG_IDENTPERM)
    /* Fill the output from its last slot backwards */
    out_multi_index += ndim-1;
    for(idim = 0; idim < ndim; ++idim, --out_multi_index,
                                    NIT_ADVANCE_AXISDATA(axisdata, 1)) {
        *out_multi_index = NAD_INDEX(axisdata);
    }
#elif !((@const_itflags@)&NPY_ITFLAG_NEGPERM)
    /* Permuted axes, none of them flipped */
    for(idim = 0; idim < ndim; ++idim, NIT_ADVANCE_AXISDATA(axisdata, 1)) {
        npy_int8 p = perm[idim];
        out_multi_index[ndim-p-1] = NAD_INDEX(axisdata);
    }
#else
    /* Permuted axes where a negative perm entry marks a flipped axis */
    for(idim = 0; idim < ndim; ++idim, NIT_ADVANCE_AXISDATA(axisdata, 1)) {
        npy_int8 p = perm[idim];
        if (p < 0) {
            /* If the perm entry is negative, reverse the index */
            out_multi_index[ndim+p] = NAD_SHAPE(axisdata) - NAD_INDEX(axisdata) - 1;
        }
        else {
            out_multi_index[ndim-p-1] = NAD_INDEX(axisdata);
        }
    }
#endif /* not ident perm */
}
/**end repeat**/

/*NUMPY_API
 * Compute a specialized get_multi_index function for the iterator
 *
 * If errmsg is non-NULL, it should point to a variable which will
 * receive the error message, and no Python exception will be set.
 * This is so that the function can be called from code not holding
 * the GIL.
 *
 * The iterator must have NPY_ITFLAG_HASMULTIINDEX set and
 * NPY_ITFLAG_DELAYBUF clear; otherwise an error is reported.  The
 * function is then selected from the HASINDEX, IDENTPERM, NEGPERM and
 * BUFFER flags.
 *
 * Returns the selected function, or NULL on error.  The messages
 * stored in *errmsg are string literals.
 */
NPY_NO_EXPORT NpyIter_GetMultiIndexFunc *
NpyIter_GetGetMultiIndex(NpyIter *iter, char **errmsg)
{
    npy_uint32 itflags = NIT_ITFLAGS(iter);
    /* ndim and nop are only used in the internal error message below */
    int ndim = NIT_NDIM(iter);
    int nop = NIT_NOP(iter);

    /* These flags must be correct */
    if ((itflags&(NPY_ITFLAG_HASMULTIINDEX|NPY_ITFLAG_DELAYBUF)) !=
            NPY_ITFLAG_HASMULTIINDEX) {
        /* Case 1: the iterator does not track a multi-index at all */
        if (!(itflags&NPY_ITFLAG_HASMULTIINDEX)) {
            if (errmsg == NULL) {
                PyErr_SetString(PyExc_ValueError,
                        "Cannot retrieve a GetMultiIndex function for an "
                        "iterator that doesn't track a multi-index.");
            }
            else {
                *errmsg = "Cannot retrieve a GetMultiIndex function for an "
                          "iterator that doesn't track a multi-index.";
            }
            return NULL;
        }
        /* Case 2: HASMULTIINDEX is set, so DELAYBUF must be set too */
        else {
            if (errmsg == NULL) {
                PyErr_SetString(PyExc_ValueError,
                        "Cannot retrieve a GetMultiIndex function for an "
                        "iterator that used DELAY_BUFALLOC before a Reset call");
            }
            else {
                *errmsg = "Cannot retrieve a GetMultiIndex function for an "
                          "iterator that used DELAY_BUFALLOC before a "
                          "Reset call";
            }
            return NULL;
        }
    }

    /*
     * Only these flags affect the iterator memory layout or
     * the get_multi_index behavior. IDENTPERM and NEGPERM are mutually
     * exclusive, so that reduces the number of cases slightly.
     */
    itflags &= (NPY_ITFLAG_HASINDEX |
                NPY_ITFLAG_IDENTPERM |
                NPY_ITFLAG_NEGPERM |
                NPY_ITFLAG_BUFFER);

    /* One case per generated npyiter_get_multi_index_itflags* function */
    switch (itflags) {
/**begin repeat
 * #const_itflags = 0,
 *    NPY_ITFLAG_HASINDEX,
 *    NPY_ITFLAG_IDENTPERM,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_IDENTPERM,
 *    NPY_ITFLAG_NEGPERM,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_NEGPERM,
 *    NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_IDENTPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_IDENTPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_NEGPERM|NPY_ITFLAG_BUFFER,
 *    NPY_ITFLAG_HASINDEX|NPY_ITFLAG_NEGPERM|NPY_ITFLAG_BUFFER#
 * #tag_itflags = 0, IND, IDP, INDuIDP, NEGP, INDuNEGP,
 *                BUF, INDuBUF, IDPuBUF, INDuIDPuBUF, NEGPuBUF, INDuNEGPuBUF#
 */
        case @const_itflags@:
            return npyiter_get_multi_index_itflags@tag_itflags@;
/**end repeat**/
    }
    /* The switch above should have caught all the possibilities. */
    /* Reached only for a masked itflags value with no case above */
    if (errmsg == NULL) {
        PyErr_Format(PyExc_ValueError,
                "GetGetMultiIndex internal iterator error - unexpected "
                "itflags/ndim/nop combination (%04x/%d/%d)",
                (int)itflags, (int)ndim, (int)nop);
    }
    else {
        *errmsg = "GetGetMultiIndex internal iterator error - unexpected "
                  "itflags/ndim/nop combination";
    }
    return NULL;

}

#undef NPY_ITERATOR_IMPLEMENTATION_CODE