/* randist/binomial_tpe.c
 *
 * Copyright (C) 1996, 2003, 2007 James Theiler, Brian Gough
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */

#include <config.h>
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_pow_int.h>
#include <gsl/gsl_sf_gamma.h>

/* The binomial distribution has the form,

   f(x) =  n!/(x!(n-x)!) * p^x (1-p)^(n-x) for integer 0 <= x <= n
        =  0                               otherwise

   This implementation follows the public domain ranlib function
   "ignbin", the bulk of which is the BTPE (Binomial Triangle
   Parallelogram Exponential) algorithm introduced in
   Kachitvichyanukul and Schmeiser[1].  It has been translated to use
   modern C coding standards.

   If n is small and/or p is near 0 or near 1 (specifically, if
   n*min(p,1-p) < SMALL_MEAN), then a different algorithm, called
   BINV, is used which has an average runtime that scales linearly
   with n*min(p,1-p).

   But for larger problems, the BTPE algorithm takes the form of two
   functions b(x) and t(x) -- "bottom" and "top" -- for which b(x) <
   f(x)/f(M) < t(x), with M = floor(n*p+p).  b(x) defines a triangular
   region, and t(x) includes a parallelogram and two tails.  Details
   (including a nice drawing) are in the paper.

   [1] Kachitvichyanukul, V. and Schmeiser, B. W.  Binomial Random
   Variate Generation.  Communications of the ACM, 31, 2 (February,
   1988) 216.

   Note, Bruce Schmeiser (personal communication) points out that if
   you want very fast binomial deviates, and you are happy with
   approximate results, and/or n and n*p are both large, then you can
   just use gaussian estimates: mean=n*p, variance=n*p*(1-p).

   This implementation by James Theiler, April 2003, after obtaining
   permission -- and some good advice -- from Drs. Kachitvichyanukul
   and Schmeiser to use their code as a starting point, and then doing
   a little bit of tweaking.

   Additional polishing for GSL coding standards by Brian Gough.  */

/* Tuning constants for gsl_ran_binomial(). */

#define SMALL_MEAN 14           /* If n*p < SMALL_MEAN then use BINV
                                   algorithm. The ranlib
                                   implementation used cutoff=30; but
                                   on my computer 14 works better */

#define BINV_CUTOFF 110         /* In BINV, do not permit ix too large */

#define FAR_FROM_MEAN 20        /* If ix-n*p is larger than this, then
                                   use the "squeeze" algorithm.
                                   Ranlib used 20, and this seems to
                                   be the best choice on my machine as
                                   well */

/* log(x!) -- only referenced by the exact (non-Stirling) acceptance tests */
#define LNFACT(x) gsl_sf_lnfact(x)

/* Truncated Stirling series used as a correction term in the BTPE
 * acceptance test.  Evaluates, in nested (Horner) form,
 *
 *   1/(12 y) - 1/(360 y^3) + 1/(1260 y^5) - 1/(1680 y^7) + 1/(1188 y^9)
 *
 * where every coefficient has been scaled by the common factor 166320
 * (13860/166320 = 1/12, 462/166320 = 1/360, and so on).
 *
 * y1 : point at which to evaluate the series; must be non-zero.
 */
static inline double
Stirling (double y1)
{
  double y2 = y1 * y1;
  double s =
    (13860.0 -
     (462.0 - (132.0 - (99.0 - 140.0 / y2) / y2) / y2) / y2) / y1 / 166320.0;
  return s;
}

unsigned int
Packit 67cb25
gsl_ran_binomial_tpe (const gsl_rng * rng, double p, unsigned int n)
Packit 67cb25
{
Packit 67cb25
  return gsl_ran_binomial (rng, p, n);
Packit 67cb25
}
Packit 67cb25
Packit 67cb25
unsigned int
Packit 67cb25
gsl_ran_binomial (const gsl_rng * rng, double p, unsigned int n)
Packit 67cb25
{
Packit 67cb25
  int ix;                       /* return value */
Packit 67cb25
  int flipped = 0;
Packit 67cb25
  double q, s, np;
Packit 67cb25
Packit 67cb25
  if (n == 0)
Packit 67cb25
    return 0;
Packit 67cb25
Packit 67cb25
  if (p > 0.5)
Packit 67cb25
    {
Packit 67cb25
      p = 1.0 - p;              /* work with small p */
Packit 67cb25
      flipped = 1;
Packit 67cb25
    }
Packit 67cb25
Packit 67cb25
  q = 1 - p;
Packit 67cb25
  s = p / q;
Packit 67cb25
  np = n * p;
Packit 67cb25
Packit 67cb25
  /* Inverse cdf logic for small mean (BINV in K+S) */
Packit 67cb25
Packit 67cb25
  if (np < SMALL_MEAN)
Packit 67cb25
    {
Packit 67cb25
      double f0 = gsl_pow_uint (q, n);   /* f(x), starting with x=0 */
Packit 67cb25
Packit 67cb25
      while (1)
Packit 67cb25
        {
Packit 67cb25
          /* This while(1) loop will almost certainly only loop once; but
Packit 67cb25
           * if u=1 to within a few epsilons of machine precision, then it
Packit 67cb25
           * is possible for roundoff to prevent the main loop over ix to
Packit 67cb25
           * achieve its proper value.  following the ranlib implementation,
Packit 67cb25
           * we introduce a check for that situation, and when it occurs,
Packit 67cb25
           * we just try again.
Packit 67cb25
           */
Packit 67cb25
Packit 67cb25
          double f = f0;
Packit 67cb25
          double u = gsl_rng_uniform (rng);
Packit 67cb25
Packit 67cb25
          for (ix = 0; ix <= BINV_CUTOFF; ++ix)
Packit 67cb25
            {
Packit 67cb25
              if (u < f)
Packit 67cb25
                goto Finish;
Packit 67cb25
              u -= f;
Packit 67cb25
              /* Use recursion f(x+1) = f(x)*[(n-x)/(x+1)]*[p/(1-p)] */
Packit 67cb25
              f *= s * (n - ix) / (ix + 1);
Packit 67cb25
            }
Packit 67cb25
Packit 67cb25
          /* It should be the case that the 'goto Finish' was encountered
Packit 67cb25
           * before this point was ever reached.  But if we have reached
Packit 67cb25
           * this point, then roundoff has prevented u from decreasing
Packit 67cb25
           * all the way to zero.  This can happen only if the initial u
Packit 67cb25
           * was very nearly equal to 1, which is a rare situation.  In
Packit 67cb25
           * that rare situation, we just try again.
Packit 67cb25
           *
Packit 67cb25
           * Note, following the ranlib implementation, we loop ix only to
Packit 67cb25
           * a hardcoded value of SMALL_MEAN_LARGE_N=110; we could have
Packit 67cb25
           * looped to n, and 99.99...% of the time it won't matter.  This
Packit 67cb25
           * choice, I think is a little more robust against the rare
Packit 67cb25
           * roundoff error.  If n>LARGE_N, then it is technically
Packit 67cb25
           * possible for ix>LARGE_N, but it is astronomically rare, and
Packit 67cb25
           * if ix is that large, it is more likely due to roundoff than
Packit 67cb25
           * probability, so better to nip it at LARGE_N than to take a
Packit 67cb25
           * chance that roundoff will somehow conspire to produce an even
Packit 67cb25
           * larger (and more improbable) ix.  If n
Packit 67cb25
           * ix=n, f=0, and the loop will continue until ix=LARGE_N.
Packit 67cb25
           */
Packit 67cb25
        }
Packit 67cb25
    }
Packit 67cb25
  else
Packit 67cb25
    {
Packit 67cb25
      /* For n >= SMALL_MEAN, we invoke the BTPE algorithm */
Packit 67cb25
Packit 67cb25
      int k;
Packit 67cb25
Packit 67cb25
      double ffm = np + p;      /* ffm = n*p+p             */
Packit 67cb25
      int m = (int) ffm;        /* m = int floor[n*p+p]    */
Packit 67cb25
      double fm = m;            /* fm = double m;          */
Packit 67cb25
      double xm = fm + 0.5;     /* xm = half integer mean (tip of triangle)  */
Packit 67cb25
      double npq = np * q;      /* npq = n*p*q            */
Packit 67cb25
Packit 67cb25
      /* Compute cumulative area of tri, para, exp tails */
Packit 67cb25
Packit 67cb25
      /* p1: radius of triangle region; since height=1, also: area of region */
Packit 67cb25
      /* p2: p1 + area of parallelogram region */
Packit 67cb25
      /* p3: p2 + area of left tail */
Packit 67cb25
      /* p4: p3 + area of right tail */
Packit 67cb25
      /* pi/p4: probability of i'th area (i=1,2,3,4) */
Packit 67cb25
Packit 67cb25
      /* Note: magic numbers 2.195, 4.6, 0.134, 20.5, 15.3 */
Packit 67cb25
      /* These magic numbers are not adjustable...at least not easily! */
Packit 67cb25
Packit 67cb25
      double p1 = floor (2.195 * sqrt (npq) - 4.6 * q) + 0.5;
Packit 67cb25
Packit 67cb25
      /* xl, xr: left and right edges of triangle */
Packit 67cb25
      double xl = xm - p1;
Packit 67cb25
      double xr = xm + p1;
Packit 67cb25
Packit 67cb25
      /* Parameter of exponential tails */
Packit 67cb25
      /* Left tail:  t(x) = c*exp(-lambda_l*[xl - (x+0.5)]) */
Packit 67cb25
      /* Right tail: t(x) = c*exp(-lambda_r*[(x+0.5) - xr]) */
Packit 67cb25
Packit 67cb25
      double c = 0.134 + 20.5 / (15.3 + fm);
Packit 67cb25
      double p2 = p1 * (1.0 + c + c);
Packit 67cb25
Packit 67cb25
      double al = (ffm - xl) / (ffm - xl * p);
Packit 67cb25
      double lambda_l = al * (1.0 + 0.5 * al);
Packit 67cb25
      double ar = (xr - ffm) / (xr * q);
Packit 67cb25
      double lambda_r = ar * (1.0 + 0.5 * ar);
Packit 67cb25
      double p3 = p2 + c / lambda_l;
Packit 67cb25
      double p4 = p3 + c / lambda_r;
Packit 67cb25
Packit 67cb25
      double var, accept;
Packit 67cb25
      double u, v;              /* random variates */
Packit 67cb25
Packit 67cb25
    TryAgain:
Packit 67cb25
Packit 67cb25
      /* generate random variates, u specifies which region: Tri, Par, Tail */
Packit 67cb25
      u = gsl_rng_uniform (rng) * p4;
Packit 67cb25
      v = gsl_rng_uniform (rng);
Packit 67cb25
Packit 67cb25
      if (u <= p1)
Packit 67cb25
        {
Packit 67cb25
          /* Triangular region */
Packit 67cb25
          ix = (int) (xm - p1 * v + u);
Packit 67cb25
          goto Finish;
Packit 67cb25
        }
Packit 67cb25
      else if (u <= p2)
Packit 67cb25
        {
Packit 67cb25
          /* Parallelogram region */
Packit 67cb25
          double x = xl + (u - p1) / c;
Packit 67cb25
          v = v * c + 1.0 - fabs (x - xm) / p1;
Packit 67cb25
          if (v > 1.0 || v <= 0.0)
Packit 67cb25
            goto TryAgain;
Packit 67cb25
          ix = (int) x;
Packit 67cb25
        }
Packit 67cb25
      else if (u <= p3)
Packit 67cb25
        {
Packit 67cb25
          /* Left tail */
Packit 67cb25
          ix = (int) (xl + log (v) / lambda_l);
Packit 67cb25
          if (ix < 0)
Packit 67cb25
            goto TryAgain;
Packit 67cb25
          v *= ((u - p2) * lambda_l);
Packit 67cb25
        }
Packit 67cb25
      else
Packit 67cb25
        {
Packit 67cb25
          /* Right tail */
Packit 67cb25
          ix = (int) (xr - log (v) / lambda_r);
Packit 67cb25
          if (ix > (double) n)
Packit 67cb25
            goto TryAgain;
Packit 67cb25
          v *= ((u - p3) * lambda_r);
Packit 67cb25
        }
Packit 67cb25
Packit 67cb25
      /* At this point, the goal is to test whether v <= f(x)/f(m) 
Packit 67cb25
       *
Packit 67cb25
       *  v <= f(x)/f(m) = (m!(n-m)! / (x!(n-x)!)) * (p/q)^{x-m}
Packit 67cb25
       *
Packit 67cb25
       */
Packit 67cb25
Packit 67cb25
      /* Here is a direct test using logarithms.  It is a little
Packit 67cb25
       * slower than the various "squeezing" computations below, but
Packit 67cb25
       * if things are working, it should give exactly the same answer
Packit 67cb25
       * (given the same random number seed).  */
Packit 67cb25
Packit 67cb25
#ifdef DIRECT
Packit 67cb25
      var = log (v);
Packit 67cb25
Packit 67cb25
      accept =
Packit 67cb25
        LNFACT (m) + LNFACT (n - m) - LNFACT (ix) - LNFACT (n - ix)
Packit 67cb25
        + (ix - m) * log (p / q);
Packit 67cb25
Packit 67cb25
#else /* SQUEEZE METHOD */
Packit 67cb25
Packit 67cb25
      /* More efficient determination of whether v < f(x)/f(M) */
Packit 67cb25
Packit 67cb25
      k = abs (ix - m);
Packit 67cb25
Packit 67cb25
      if (k <= FAR_FROM_MEAN)
Packit 67cb25
        {
Packit 67cb25
          /* 
Packit 67cb25
           * If ix near m (ie, |ix-m|
Packit 67cb25
           * explicit evaluation using recursion relation for f(x)
Packit 67cb25
           */
Packit 67cb25
          double g = (n + 1) * s;
Packit 67cb25
          double f = 1.0;
Packit 67cb25
Packit 67cb25
          var = v;
Packit 67cb25
Packit 67cb25
          if (m < ix)
Packit 67cb25
            {
Packit 67cb25
              int i;
Packit 67cb25
              for (i = m + 1; i <= ix; i++)
Packit 67cb25
                {
Packit 67cb25
                  f *= (g / i - s);
Packit 67cb25
                }
Packit 67cb25
            }
Packit 67cb25
          else if (m > ix)
Packit 67cb25
            {
Packit 67cb25
              int i;
Packit 67cb25
              for (i = ix + 1; i <= m; i++)
Packit 67cb25
                {
Packit 67cb25
                  f /= (g / i - s);
Packit 67cb25
                }
Packit 67cb25
            }
Packit 67cb25
Packit 67cb25
          accept = f;
Packit 67cb25
        }
Packit 67cb25
      else
Packit 67cb25
        {
Packit 67cb25
          /* If ix is far from the mean m: k=ABS(ix-m) large */
Packit 67cb25
Packit 67cb25
          var = log (v);
Packit 67cb25
Packit 67cb25
          if (k < npq / 2 - 1)
Packit 67cb25
            {
Packit 67cb25
              /* "Squeeze" using upper and lower bounds on
Packit 67cb25
               * log(f(x)) The squeeze condition was derived
Packit 67cb25
               * under the condition k < npq/2-1 */
Packit 67cb25
              double amaxp =
Packit 67cb25
                k / npq * ((k * (k / 3.0 + 0.625) + (1.0 / 6.0)) / npq + 0.5);
Packit 67cb25
              double ynorm = -(k * k / (2.0 * npq));
Packit 67cb25
              if (var < ynorm - amaxp)
Packit 67cb25
                goto Finish;
Packit 67cb25
              if (var > ynorm + amaxp)
Packit 67cb25
                goto TryAgain;
Packit 67cb25
            }
Packit 67cb25
Packit 67cb25
          /* Now, again: do the test log(v) vs. log f(x)/f(M) */
Packit 67cb25
Packit 67cb25
#if USE_EXACT
Packit 67cb25
          /* This is equivalent to the above, but is a little (~20%) slower */
Packit 67cb25
          /* There are five log's vs three above, maybe that's it? */
Packit 67cb25
Packit 67cb25
          accept = LNFACT (m) + LNFACT (n - m)
Packit 67cb25
            - LNFACT (ix) - LNFACT (n - ix) + (ix - m) * log (p / q);
Packit 67cb25
Packit 67cb25
#else /* USE STIRLING */
Packit 67cb25
          /* The "#define Stirling" above corresponds to the first five
Packit 67cb25
           * terms in asymptoic formula for
Packit 67cb25
           * log Gamma (y) - (y-0.5)log(y) + y - 0.5 log(2*pi);
Packit 67cb25
           * See Abramowitz and Stegun, eq 6.1.40
Packit 67cb25
           */
Packit 67cb25
Packit 67cb25
          /* Note below: two Stirling's are added, and two are
Packit 67cb25
           * subtracted.  In both K+S, and in the ranlib
Packit 67cb25
           * implementation, all four are added.  I (jt) believe that
Packit 67cb25
           * is a mistake -- this has been confirmed by personal
Packit 67cb25
           * correspondence w/ Dr. Kachitvichyanukul.  Note, however,
Packit 67cb25
           * the corrections are so small, that I couldn't find an
Packit 67cb25
           * example where it made a difference that could be
Packit 67cb25
           * observed, let alone tested.  In fact, define'ing Stirling
Packit 67cb25
           * to be zero gave identical results!!  In practice, alv is
Packit 67cb25
           * O(1), ranging 0 to -10 or so, while the Stirling
Packit 67cb25
           * correction is typically O(10^{-5}) ...setting the
Packit 67cb25
           * correction to zero gives about a 2% performance boost;
Packit 67cb25
           * might as well keep it just to be pendantic.  */
Packit 67cb25
Packit 67cb25
          {
Packit 67cb25
            double x1 = ix + 1.0;
Packit 67cb25
            double w1 = n - ix + 1.0;
Packit 67cb25
            double f1 = fm + 1.0;
Packit 67cb25
            double z1 = n + 1.0 - fm;
Packit 67cb25
Packit 67cb25
            accept = xm * log (f1 / x1) + (n - m + 0.5) * log (z1 / w1)
Packit 67cb25
              + (ix - m) * log (w1 * p / (x1 * q))
Packit 67cb25
              + Stirling (f1) + Stirling (z1) - Stirling (x1) - Stirling (w1);
Packit 67cb25
          }
Packit 67cb25
#endif
Packit 67cb25
#endif
Packit 67cb25
        }
Packit 67cb25
Packit 67cb25
Packit 67cb25
      if (var <= accept)
Packit 67cb25
        {
Packit 67cb25
          goto Finish;
Packit 67cb25
        }
Packit 67cb25
      else
Packit 67cb25
        {
Packit 67cb25
          goto TryAgain;
Packit 67cb25
        }
Packit 67cb25
    }
Packit 67cb25
Packit 67cb25
Finish:
Packit 67cb25
Packit 67cb25
  return (flipped) ? (n - ix) : (unsigned int)ix;
Packit 67cb25
}