/** @file     vl_binsum.def
 ** @author   Andrea Vedaldi
 ** @brief    vl_binsum MEX definition template
 **/

#ifndef ITERATE_1
# define ITERATE_1 1
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 2
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 3
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 4
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 5
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 6
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 7
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 8
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 9
# include "vl_binsum.def"
# undef ITERATE_1
# define ITERATE_1 10
# include "vl_binsum.def"
# undef ITERATE_1
#else
# ifndef ITERATE_2
#  define ITERATE_2 1
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 2
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 3
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 4
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 5
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 6
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 7
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 8
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 9
#  include "vl_binsum.def"
#  undef ITERATE_2
#  define ITERATE_2 10
#  include "vl_binsum.def"
#  undef ITERATE_2
# endif
#endif
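
/* The two nested self-inclusions above instantiate the function below
 * for every (value type, index type) pair, i.e. 10 x 10
 * specializations. The TYPE() macro is expected to be provided by the
 * including MEX source and to map an iteration index to a concrete C
 * type; for instance, assuming a hypothetical mapping along the lines
 * of
 *
 *   #define TYPE_1  vl_int8
 *   #define TYPE(x) VL_XCAT(TYPE_, x)
 *
 * the pass with ITERATE_1 = 1 and ITERATE_2 = 1 would expand
 * VL_XCAT4(_vl_binsum_,__VALUE_T__,_,__INDEX_T__) below into a
 * function named _vl_binsum_vl_int8_vl_int8. */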

#if defined(ITERATE_1) && defined(ITERATE_2)
# define __VALUE_T__ TYPE(ITERATE_1)
# define __INDEX_T__ TYPE(ITERATE_2)

static void
VL_XCAT4(_vl_binsum_,__VALUE_T__,_,__INDEX_T__)
(mxArray * H_array,
 mxArray const * X_array,
 mxArray const * B_array,
 vl_index dim)
{
  vl_size KH, KX, KB ;
  vl_index j ;
  __VALUE_T__ * H_pt ;
  __VALUE_T__ const * X_pt ;
  __INDEX_T__ const * B_pt ;
  __INDEX_T__ const * B_end ;

  KH = mxGetNumberOfElements(H_array) ; /* accumulator */
  KX = mxGetNumberOfElements(X_array) ; /* values */
  KB = mxGetNumberOfElements(B_array) ; /* accumulator indices */

  H_pt = (__VALUE_T__ *) mxGetData(H_array) ;
  X_pt = (__VALUE_T__ const *) mxGetData(X_array) ;
  B_pt = (__INDEX_T__ const *) mxGetData(B_array) ;

  B_end = B_pt + KB ;

  if ((KX != KB) && (KX > 1)) {
    vlmxError(vlmxErrInvalidArgument,
              "VALUES and INDEXES do not have the same number of elements nor VALUES is a scalar.") ;
  }

  /* All dimensions mode ------------------------------------------- */
  if (dim <= 0) {
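
    /* In this mode H, X, and B are treated as plain linear arrays:
     * in effect, for every k the loop below performs
     * H[B[k] - 1] += X[k] (X[0] if VALUES is a scalar), skipping
     * entries with B[k] = 0. */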

    while (B_pt < B_end) {

      /* next bin to accumulate */
      j = (vl_index)(*B_pt) - 1 ;

      /* bin index out of bounds? */
      if(j < -1 || j >= (signed) KH) {
        vlmxError(vlmxErrInconsistentData,
                  "The index INDEXES(%"  VL_FMT_INDEX ") = %" VL_FMT_INDEX " is out of bound.",
                  (vl_index)(B_pt - (__INDEX_T__ const *)mxGetData(B_array) + 1),
                  j + 1) ;
      }

      /* accumulate (but skip null indices) */
      if (j >= 0) {
        H_pt[j] += *X_pt ;
      }

      /* except for the scalar X case, keep X and B synchronized */
      ++ B_pt ;
      if (KX > 1) ++ X_pt ;
    }
  }

  /* One dimension mode -------------------------------------------- */
  else {
    vl_index k ;
    vl_size d  = dim - 1 ;

    mwSize HD = mxGetNumberOfDimensions(H_array) ;
    mwSize XD = mxGetNumberOfDimensions(X_array) ;
    mwSize BD = mxGetNumberOfDimensions(B_array) ;

    mwSize const * Hdims = mxGetDimensions(H_array) ;
    mwSize const * Xdims = mxGetDimensions(X_array) ;
    mwSize const * Bdims = mxGetDimensions(B_array) ;

    __INDEX_T__ const * B_brk ;
    __INDEX_T__ const * B_nbrk ;

    vl_size srd ;

    /* We need to check a few more details about the matrices */
    if (d >= HD) {
      vlmxError(vlmxErrInconsistentData,
                "DIM is smaller than one or larger than the number of dimensions of ACCUMULATOR.") ;
    }

    /*
     Here either B has the same number of dimensions as H, or B has
     exactly one dimension less and DIM = end. The latter is a
     special case due to the fact that MATLAB deletes trailing
     singleton dimensions, so it would otherwise be impossible to
     operate with DIM = end and size(B,end) = 1, which is a logically
     acceptable case.
     */
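
    /*
     For example, with an ACCUMULATOR of size 3x4x5 and DIM = 3,
     INDEXES stored by MATLAB as a 3x4 array (the trailing singleton
     dimension having been deleted) is still accepted by the check
     below.
     */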

    if (HD != BD) {
      if (! (d == HD - 1 && BD == HD - 1)) {
        vlmxError(vlmxErrInconsistentData,
                  "ACCUMULATOR and INDEXES do not have the same number of dimensions.") ;
      }
    }

    if ((BD != XD) && (KX > 1)) {
      vlmxError(vlmxErrInconsistentData,
                "VALUES is not a scalar nor has the same number of dimensions of INDEXES.") ;
    }

    /* This will contain the stride required to advance to the next
     * element along dimension DIM. This is the product of all
     * dimensions < d. */
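
    /* For example, if Bdims = {4, 5, 6} and DIM = 2 (so d = 1), then
     * srd = 4: in MATLAB's column-major layout, consecutive elements
     * along dimension DIM are 4 elements apart in memory. */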

    srd = 1 ;

    for(k = 0 ; k < (signed) XD ; ++k) {
      if (KX > 1 &&  Xdims[k] != Bdims[k]) {
        vlmxError(vlmxErrInconsistentData,
                  "VALUES and INDEXES have incompatible dimensions.") ;
      }
      if (k != (signed) d && (Bdims[k] != Hdims[k])) {
        vlmxError(vlmxErrInconsistentData,
                  "VALUES and ACCUMULATOR have incompatible dimensions.") ;
      }
      if (k < (signed) d) {
        srd = srd * Bdims[k] ;
      }
    }

    /* We sum efficiently by scanning B_pt once in memory order. This,
     * however, makes the algorithm a bit more involved. How the sum
     * is performed is easier to understand by an example. Consider
     * the case d = 2 and BD = 5, so that elements of B are indexed as
     *
     *  B[i0, i1, id, i3, i4]  (note that id = i2)
     *
     * where the indexes (i0, i1, id, i3, i4) are scanned in
     * column-major order. For each such element, we access the
     * element
     *
     *  H[i0, i1, B[i0, i1, id, i3, i4], i3, i4]
     *
     * For greater efficiency, we do not compute the multi-indexes
     * explicitly; instead we advance B_pt at each iteration and keep
     * H_pt properly synchronized.
     *
     * In particular, at each iteration we want H_pt to point to
     *
     *  H[i0, i1, 0, i3, i4]
     *
     * Therefore, whenever advancing B_pt corresponds to advancing
     *
     * a - i0, i1 : we advance H_pt;
     * b - id     : we advance H_pt and then subtract srd to move id
     *              back one position;
     * c - i3, i4 : we do the same as before, but then we add
     *              srd * Hdims[d] elements to move H_pt up one place
     *              past dimension d.
     *
     * We can easily keep track of the three cases without computing
     * the indexes explicitly. In fact, case (a) occurs for srd - 1
     * consecutive steps and case (b) at every srd-th step. Similarly,
     * case (c) occurs every srd * Bdims[d] steps. The pattern repeats
     * regularly.
     */
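
    /* For example, with srd = 4 and Bdims[d] = 5, case (b) occurs
     * after every 4 elements of B (steps 4, 8, 12, ...) and case (c)
     * after every 20 (steps 20, 40, 60, ...). */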

    KH     = Hdims[d] ;
    B_brk  = B_pt + srd ;              /* next time case (b) */
    B_nbrk = B_pt + srd * Bdims [d] ;  /* next time case (c) */

    while (B_pt < B_end) {

      /* next bin to accumulate */
      j = (vl_index)(*B_pt) - 1;

      /* index out of bounds? */
      if(j < -1 || j >= (signed) KH) {
        vlmxError(vlmxErrInconsistentData,
                  "The index INDEXES(%"  VL_FMT_INDEX ") = %" VL_FMT_INDEX " is out of bounds.",
                  (vl_index)(B_pt - (__INDEX_T__ const *)mxGetData(B_array)) + 1,
                  j + 1) ;
      }

      /* accumulate (but skip null indices) */
      if (j >= 0) {
        H_pt [j * srd] += *X_pt ;
      }

      /* next element */
      if (KX > 1) X_pt++ ;
      B_pt ++ ;
      H_pt ++ ;

      if (B_pt == B_brk) {                 /* case (b) occurs */
        B_brk += srd ;                     /* next time case (b) */
        H_pt  -= srd ;
        if (B_pt == B_nbrk) {              /* case (c) occurs */
          B_nbrk += srd * Bdims[d] ;       /* next time case (c) */
          H_pt   += srd * Hdims[d] ;
        }
      }
    }
  }
}

#undef __VALUE_T__
#undef __INDEX_T__
#endif