Blame libcelt/vq.c

Packit 664db3
/* (C) 2007-2008 Jean-Marc Valin, CSIRO
*/
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:
   
   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
   
   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
   
   - Neither the name of the Xiph.org Foundation nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.
   
   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
   CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
Packit 664db3
Packit 664db3
#ifdef HAVE_CONFIG_H
Packit 664db3
#include "config.h"
Packit 664db3
#endif
Packit 664db3
Packit 664db3
#include "mathops.h"
Packit 664db3
#include "cwrs.h"
Packit 664db3
#include "vq.h"
Packit 664db3
#include "arch.h"
Packit 664db3
#include "os_support.h"
Packit 664db3
Packit 664db3
/** Takes the pitch vector and the decoded residual vector, computes the gain
Packit 664db3
    that will give ||p+g*y||=1 and mixes the residual with the pitch. */
Packit 664db3
static void mix_pitch_and_residual(int * restrict iy, celt_norm_t * restrict X, int N, int K, const celt_norm_t * restrict P)
Packit 664db3
{
Packit 664db3
   int i;
Packit 664db3
   celt_word32_t Ryp, Ryy, Rpp;
Packit 664db3
   celt_word16_t ryp, ryy, rpp;
Packit 664db3
   celt_word32_t g;
Packit 664db3
   VARDECL(celt_norm_t, y);
Packit 664db3
#ifdef FIXED_POINT
Packit 664db3
   int yshift;
Packit 664db3
#endif
Packit 664db3
   SAVE_STACK;
Packit 664db3
#ifdef FIXED_POINT
Packit 664db3
   yshift = 13-celt_ilog2(K);
Packit 664db3
#endif
Packit 664db3
   ALLOC(y, N, celt_norm_t);
Packit 664db3
Packit 664db3
   /*for (i=0;i
Packit 664db3
   printf ("%d ", iy[i]);*/
Packit 664db3
   Rpp = 0;
Packit 664db3
   i=0;
Packit 664db3
   do {
Packit 664db3
      Rpp = MAC16_16(Rpp,P[i],P[i]);
Packit 664db3
      y[i] = SHL16(iy[i],yshift);
Packit 664db3
   } while (++i < N);
Packit 664db3
Packit 664db3
   Ryp = 0;
Packit 664db3
   Ryy = 0;
Packit 664db3
   /* If this doesn't generate a dual MAC (on supported archs), fire the compiler guy */
Packit 664db3
   i=0;
Packit 664db3
   do {
Packit 664db3
      Ryp = MAC16_16(Ryp, y[i], P[i]);
Packit 664db3
      Ryy = MAC16_16(Ryy, y[i], y[i]);
Packit 664db3
   } while (++i < N);
Packit 664db3
Packit 664db3
   ryp = ROUND16(Ryp,14);
Packit 664db3
   ryy = ROUND16(Ryy,14);
Packit 664db3
   rpp = ROUND16(Rpp,14);
Packit 664db3
   /* g = (sqrt(Ryp^2 + Ryy - Rpp*Ryy)-Ryp)/Ryy */
Packit 664db3
   g = MULT16_32_Q15(celt_sqrt(MAC16_16(Ryy, ryp,ryp) - MULT16_16(ryy,rpp)) - ryp,
Packit 664db3
                     celt_rcp(SHR32(Ryy,9)));
Packit 664db3
Packit 664db3
   i=0;
Packit 664db3
   do 
Packit 664db3
      X[i] = ADD16(P[i], ROUND16(MULT16_16(y[i], g),11));
Packit 664db3
   while (++i < N);
Packit 664db3
Packit 664db3
   RESTORE_STACK;
Packit 664db3
}
Packit 664db3
Packit 664db3
Packit 664db3
/** Quantises the residual X-P with K pulses, encodes the pulse vector, and
    replaces X with the same reconstruction the decoder will compute.

    Steps:
    1. Subtract P from X and fold every coordinate into the positive orthant.
       The signs are saved in signx, and P is flipped along with X.
    2. If K > N/2, pre-place up to K-1 pulses by projecting X onto the pyramid.
    3. Greedily add pulses, several at once while many remain, maximising
       Rxy^2/Ryy until one pulse is left.
    4. Place the last pulse with a score that accounts for the pitch gain.
    5. Restore the signs of P, X and iy, encode iy, and rebuild X with
       mix_pitch_and_residual().

    @param X    In: target vector (length N). Out: quantised vector.
    @param W    Not used by this function.
    @param N    Vector length (>= 1; the loops are do/while).
    @param K    Number of pulses (>= 1, required by the "pulsesLeft>=1" assertion).
    @param P    Pitch vector (length N). Sign-flipped during the search and
                restored before returning.
    @param enc  Range encoder that receives the pulse vector.
*/
void alg_quant(celt_norm_t *X, celt_mask_t *W, int N, int K, celt_norm_t *P, ec_enc *enc)
{
   VARDECL(celt_norm_t, y);
   VARDECL(int, iy);
   VARDECL(celt_word16_t, signx);
   int j, is;
   celt_word16_t s;
   int pulsesLeft;
   celt_word32_t sum;
   celt_word32_t xy, yy, yp;
   celt_word16_t Rpp;
   int N_1; /* Inverse of N in Q9 (512/N), an integer even in the float build */
#ifdef FIXED_POINT
   int yshift;
#endif
   SAVE_STACK;

#ifdef FIXED_POINT
   yshift = 13-celt_ilog2(K);
#endif

   ALLOC(y, N, celt_norm_t);
   ALLOC(iy, N, int);
   ALLOC(signx, N, celt_word16_t);
   N_1 = 512/N;

   /* Work on the residual X-P with every coordinate made non-negative.
      P is flipped together with X so that <y,P> stays consistent; signx
      remembers the signs so everything can be restored at the end. */
   sum = 0;
   j=0; do {
      X[j] -= P[j];
      if (X[j]>0)
         signx[j]=1;
      else {
         signx[j]=-1;
         X[j]=-X[j];
         P[j]=-P[j];
      }
      iy[j] = 0;
      y[j] = 0;
      sum = MAC16_16(sum, P[j],P[j]);
   } while (++j<N);
   Rpp = ROUND16(sum, NORM_SHIFT);

   celt_assert2(Rpp<=NORM_SCALING, "Rpp should never have a norm greater than unity");

   /* Running correlations of the current pulse vector:
      xy = <X,y>, yy = <y,y>, yp = <y,P> */
   xy = yy = yp = 0;

   pulsesLeft = K;

   /* Do a pre-search by projecting on the pyramid */
   if (K > (N>>1))
   {
      celt_word16_t rcp;
      sum=0;
      j=0; do {
         sum += X[j];
      } while (++j<N);

      /* Guard against a (near-)zero L1 norm before taking its reciprocal.
         In float, 1/sum may be inf and inf*0 gives NaN, whose conversion to
         int is undefined. In fixed point, a small sum makes (K-1)/sum too
         large for the 16-bit rcp. Later CELT/Opus releases use the same
         bounds. Fall back to putting everything on the first coordinate. */
#ifdef FIXED_POINT
      if (sum <= K)
#else
      if (!(sum > EPSILON))
#endif
      {
         X[0] = QCONST16(1.f,14);
         for (j=1;j<N;j++)
            X[j] = 0;
         sum = QCONST16(1.f,14);
      }
      /* Do we have sufficient accuracy here? */
      rcp = EXTRACT16(MULT16_32_Q16(K-1, celt_rcp(sum)));
      j=0; do {
#ifdef FIXED_POINT
         /* It's really important to round *towards zero* here */
         iy[j] = MULT16_16_Q15(X[j],rcp);
#else
         iy[j] = floor(rcp*X[j]);
#endif
         y[j] = SHL16(iy[j],yshift);
         yy = MAC16_16(yy, y[j],y[j]);
         xy = MAC16_16(xy, X[j],y[j]);
         yp = MAC16_16(yp, P[j],y[j]);
         /* Store 2*y[j] so that the searches below get the cross term
            2*s*y[j] of ||y + s*e_j||^2 from a single MAC16_16(yy, s, y[j]) */
         y[j] *= 2;
         pulsesLeft -= iy[j];
      } while (++j<N);
   }
   celt_assert2(pulsesLeft>=1, "Allocated too many pulses in the quick pass");

   /* Greedy search: add pulsesAtOnce pulses at the position that maximises
      Rxy^2/Ryy, until a single pulse is left for the accurate final pass. */
   while (pulsesLeft > 1)
   {
      int pulsesAtOnce=1;
      int best_id;
      celt_word16_t magnitude;
      /* best_den = 0 makes the first candidate with Ryy > 0 win the comparison */
      celt_word32_t best_num = -VERY_LARGE16;
      celt_word16_t best_den = 0;
#ifdef FIXED_POINT
      int rshift;
#endif
      /* Decide on how many pulses to find at once */
      pulsesAtOnce = (pulsesLeft*N_1)>>9; /* pulsesLeft/N */
      if (pulsesAtOnce<1)
         pulsesAtOnce = 1;
#ifdef FIXED_POINT
      /* Down-shift applied before EXTRACT16() below; grows with the number
         of pulses placed so far */
      rshift = yshift+1+celt_ilog2(K-pulsesLeft+pulsesAtOnce);
#endif
      magnitude = SHL16(pulsesAtOnce, yshift);

      best_id = 0;
      /* The squared magnitude term gets added anyway, so we might as well
         add it outside the loop */
      yy = MAC16_16(yy, magnitude,magnitude);
      j=0;
      do {
         celt_word16_t Rxy, Ryy;
         /* All coordinates were folded positive, so pulses are always positive */
         s = magnitude;
         /* Temporary sums of the new pulse(s) */
         Rxy = EXTRACT16(SHR32(MAC16_16(xy, s,X[j]),rshift));
         /* We're multiplying y[j] by two so we don't have to do it here */
         Ryy = EXTRACT16(SHR32(MAC16_16(yy, s,y[j]),rshift));

         /* Approximate score: we maximise Rxy/sqrt(Ryy), i.e. Rxy^2/Ryy
            (Rxy is non-negative because the signs were folded out) */
         Rxy = MULT16_16_Q15(Rxy,Rxy);
         /* Check Rxy/Ryy > best_num/best_den by cross-multiplying, so no
            division is needed */
         /* OPT: Make sure to use conditional moves here */
         if (MULT16_16(best_den, Rxy) > MULT16_16(Ryy, best_num))
         {
            best_den = Ryy;
            best_num = Rxy;
            best_id = j;
         }
      } while (++j<N);

      j = best_id;
      is = pulsesAtOnce;
      s = SHL16(is, yshift);

      /* Updating the sums of the new pulse(s) */
      xy = xy + MULT16_16(s,X[j]);
      /* We're multiplying y[j] by two so we don't have to do it here */
      yy = yy + MULT16_16(s,y[j]);
      yp = yp + MULT16_16(s, P[j]);

      /* Only now that we've made the final choice, update y/iy */
      /* Multiplying y[j] by 2 so we don't have to do it everywhere else */
      y[j] += 2*s;
      iy[j] += is;
      pulsesLeft -= pulsesAtOnce;
   }

   /* Last pulse: score each position by the actual reconstruction error,
      taking into account the gain that will be applied to y */
   if (pulsesLeft > 0)
   {
      celt_word16_t g;
      celt_word16_t best_num = -VERY_LARGE16;
      celt_word16_t best_den = 0;
      int best_id = 0;
      celt_word16_t magnitude = SHL16(1, yshift);

      /* The squared magnitude term gets added anyway, so we might as well
         add it outside the loop */
      yy = MAC16_16(yy, magnitude,magnitude);
      j=0;
      do {
         celt_word16_t Rxy, Ryy, Ryp;
         celt_word16_t num;
         /* All coordinates were folded positive, so the pulse is positive */
         s = magnitude;
         /* Temporary sums of the new pulse(s) */
         Rxy = ROUND16(MAC16_16(xy, s,X[j]), 14);
         /* We're multiplying y[j] by two so we don't have to do it here */
         Ryy = ROUND16(MAC16_16(yy, s,y[j]), 14);
         Ryp = ROUND16(MAC16_16(yp, s,P[j]), 14);

         /* Compute the gain such that ||p + g*y|| = 1
            ...but instead, we compute g*Ryy to avoid dividing */
         g = celt_psqrt(MULT16_16(Ryp,Ryp) + MULT16_16(Ryy,QCONST16(1.f,14)-Rpp)) - Ryp;
         /* Knowing that gain, what's the error: (x-g*y)^2
            (result is negated and we discard x^2 because it's constant) */
         /* score = 2*g*Rxy - g*g*Ryy;*/
#ifdef FIXED_POINT
         /* No need to multiply Rxy by 2 because we did it earlier */
         num = MULT16_16_Q15(ADD16(SUB16(Rxy,g),Rxy),g);
#else
         num = g*(2*Rxy-g);
#endif
         if (MULT16_16(best_den, num) > MULT16_16(Ryy, best_num))
         {
            best_den = Ryy;
            best_num = num;
            best_id = j;
         }
      } while (++j<N);
      iy[best_id] += 1;
   }

   /* Undo the sign folding on P and X and give the pulses their signs */
   j=0;
   do {
      P[j] = MULT16_16(signx[j],P[j]);
      X[j] = MULT16_16(signx[j],X[j]);
      if (signx[j] < 0)
         iy[j] = -iy[j];
   } while (++j<N);
   encode_pulses(iy, N, K, enc);

   /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
      due to the recursive computation used in quantisation. */
   mix_pitch_and_residual(iy, X, N, K, P);
   RESTORE_STACK;
}
Packit 664db3
Packit 664db3
Packit 664db3
/** Decode pulse vector and combine the result with the pitch vector to produce
Packit 664db3
    the final normalised signal in the current band. */
Packit 664db3
void alg_unquant(celt_norm_t *X, int N, int K, celt_norm_t *P, ec_dec *dec)
Packit 664db3
{
Packit 664db3
   VARDECL(int, iy);
Packit 664db3
   SAVE_STACK;
Packit 664db3
   ALLOC(iy, N, int);
Packit 664db3
   decode_pulses(iy, N, K, dec);
Packit 664db3
   mix_pitch_and_residual(iy, X, N, K, P);
Packit 664db3
   RESTORE_STACK;
Packit 664db3
}
Packit 664db3
Packit 664db3
/** Rescales the N strided samples X[0], X[stride], ..., X[(N-1)*stride]
    in place so that their L2 norm becomes (approximately) `value`.

    @param X       First sample of the vector
    @param value   Target norm
    @param N       Number of samples
    @param stride  Distance, in elements, between consecutive samples
*/
void renormalise_vector(celt_norm_t *X, celt_word16_t value, int N, int stride)
{
   int i;
   /* Starting from EPSILON presumably keeps celt_rcp() below away from zero
      for an all-zero vector */
   celt_word32_t E = EPSILON;
   celt_word16_t g;
   celt_norm_t *xptr = X;

   /* Energy of the strided vector */
   for (i=0;i<N;i++)
   {
      E = MAC16_16(E, *xptr, *xptr);
      xptr += stride;
   }

   /* g ~ value/sqrt(E); the shifts (9 here, 8 below) are fixed-point scaling */
   g = MULT16_16_Q15(value,celt_rcp(SHL32(celt_sqrt(E),9)));

   xptr = X;
   for (i=0;i<N;i++)
   {
      *xptr = PSHR32(MULT16_16(g, *xptr),8);
      xptr += stride;
   }
}
Packit 664db3
Packit 664db3
/** Copies C*N coefficients of Y, starting at offset N0 % (C*B), into P.
    C is the channel count of mode m. If that copy would read at or past
    index N0 of Y, P is zeroed instead. */
static void fold(const CELTMode *m, int N, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
{
   int j;
   const int C = CHANNELS(m);
   int id = N0 % (C*B);
   /* The copy reads Y[id .. id+C*N-1]. We assume that this stays below N0,
      i.e. that no band is wider than N0. In the unlikely case it does not,
      we set everything to zero. */
   if (id+C*N>N0)
   {
      for (j=0;j<C*N;j++)
         P[j] = 0;
   } else {
      for (j=0;j<C*N;j++)
         P[j] = Y[id++];
   }
}
Packit 664db3
Packit 664db3
#define KGAIN 6
Packit 664db3
Packit 664db3
void intra_fold(const CELTMode *m, celt_norm_t * restrict x, int N, int K, celt_norm_t *Y, celt_norm_t * restrict P, int N0, int B)
Packit 664db3
{
Packit 664db3
   celt_word16_t pred_gain;
Packit 664db3
   const int C = CHANNELS(m);
Packit 664db3
Packit 664db3
   if (K==0)
Packit 664db3
      pred_gain = Q15ONE;
Packit 664db3
   else
Packit 664db3
      pred_gain = celt_div((celt_word32_t)MULT16_16(Q15_ONE,N),(celt_word32_t)(N+KGAIN*K));
Packit 664db3
Packit 664db3
   fold(m, N, Y, P, N0, B);
Packit 664db3
Packit 664db3
   renormalise_vector(P, pred_gain, C*N, 1);
Packit 664db3
}
Packit 664db3