Blame linalg/invtri.c

Packit 67cb25
/* linalg/invtri.c
Packit 67cb25
 *
Packit 67cb25
 * Copyright (C) 2016 Patrick Alken
Packit 67cb25
 *
Packit 67cb25
 * This is free software; you can redistribute it and/or modify it
Packit 67cb25
 * under the terms of the GNU General Public License as published by the
Packit 67cb25
 * Free Software Foundation; either version 3, or (at your option) any
Packit 67cb25
 * later version.
Packit 67cb25
 *
Packit 67cb25
 * This source is distributed in the hope that it will be useful, but WITHOUT
Packit 67cb25
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
Packit 67cb25
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
Packit 67cb25
 * for more details.
Packit 67cb25
 *
Packit 67cb25
 * You should have received a copy of the GNU General Public License
Packit 67cb25
 * along with this program; if not, write to the Free Software
Packit 67cb25
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
Packit 67cb25
 *
Packit 67cb25
 * This module contains code to invert triangular matrices
Packit 67cb25
 */
Packit 67cb25
Packit 67cb25
#include <config.h>
Packit 67cb25
Packit 67cb25
#include <gsl/gsl_math.h>
Packit 67cb25
#include <gsl/gsl_errno.h>
Packit 67cb25
#include <gsl/gsl_vector.h>
Packit 67cb25
#include <gsl/gsl_matrix.h>
Packit 67cb25
#include <gsl/gsl_blas.h>
Packit 67cb25
#include <gsl/gsl_linalg.h>
Packit 67cb25
Packit 67cb25
/* Forward declaration: the shared inversion routine is defined at the end of
 * this file, after the public wrappers that call it. */
static int triangular_inverse(CBLAS_UPLO_t Uplo, CBLAS_DIAG_t Diag, gsl_matrix * T);
Packit 67cb25
Packit 67cb25
int
Packit 67cb25
gsl_linalg_tri_upper_invert(gsl_matrix * T)
Packit 67cb25
{
Packit 67cb25
  int status = triangular_inverse(CblasUpper, CblasNonUnit, T);
Packit 67cb25
  return status;
Packit 67cb25
}
Packit 67cb25
Packit 67cb25
int
Packit 67cb25
gsl_linalg_tri_lower_invert(gsl_matrix * T)
Packit 67cb25
{
Packit 67cb25
  int status = triangular_inverse(CblasLower, CblasNonUnit, T);
Packit 67cb25
  return status;
Packit 67cb25
}
Packit 67cb25
Packit 67cb25
int
Packit 67cb25
gsl_linalg_tri_upper_unit_invert(gsl_matrix * T)
Packit 67cb25
{
Packit 67cb25
  int status = triangular_inverse(CblasUpper, CblasUnit, T);
Packit 67cb25
  return status;
Packit 67cb25
}
Packit 67cb25
Packit 67cb25
int
Packit 67cb25
gsl_linalg_tri_lower_unit_invert(gsl_matrix * T)
Packit 67cb25
{
Packit 67cb25
  int status = triangular_inverse(CblasLower, CblasUnit, T);
Packit 67cb25
  return status;
Packit 67cb25
}
Packit 67cb25
Packit 67cb25
/*
Packit 67cb25
triangular_inverse()
Packit 67cb25
  Invert a triangular matrix T
Packit 67cb25
Packit 67cb25
Inputs: Uplo - CblasUpper or CblasLower
Packit 67cb25
        Diag - unit triangular?
Packit 67cb25
        T    - on output the upper (or lower) part of T
Packit 67cb25
               is replaced by its inverse
Packit 67cb25
Packit 67cb25
Return: success/error
Packit 67cb25
*/
Packit 67cb25
Packit 67cb25
static int
Packit 67cb25
triangular_inverse(CBLAS_UPLO_t Uplo, CBLAS_DIAG_t Diag, gsl_matrix * T)
Packit 67cb25
{
Packit 67cb25
  const size_t N = T->size1;
Packit 67cb25
Packit 67cb25
  if (N != T->size2)
Packit 67cb25
    {
Packit 67cb25
      GSL_ERROR ("matrix must be square", GSL_ENOTSQR);
Packit 67cb25
    }
Packit 67cb25
  else
Packit 67cb25
    {
Packit 67cb25
      gsl_matrix_view m;
Packit 67cb25
      gsl_vector_view v;
Packit 67cb25
      size_t i;
Packit 67cb25
Packit 67cb25
      if (Uplo == CblasUpper)
Packit 67cb25
        {
Packit 67cb25
          for (i = 0; i < N; ++i)
Packit 67cb25
            {
Packit 67cb25
              double aii;
Packit 67cb25
Packit 67cb25
              if (Diag == CblasNonUnit)
Packit 67cb25
                {
Packit 67cb25
                  double *Tii = gsl_matrix_ptr(T, i, i);
Packit 67cb25
                  *Tii = 1.0 / *Tii;
Packit 67cb25
                  aii = -(*Tii);
Packit 67cb25
                }
Packit 67cb25
              else
Packit 67cb25
                {
Packit 67cb25
                  aii = -1.0;
Packit 67cb25
                }
Packit 67cb25
Packit 67cb25
              if (i > 0)
Packit 67cb25
                {
Packit 67cb25
                  m = gsl_matrix_submatrix(T, 0, 0, i, i);
Packit 67cb25
                  v = gsl_matrix_subcolumn(T, i, 0, i);
Packit 67cb25
Packit 67cb25
                  gsl_blas_dtrmv(CblasUpper, CblasNoTrans, Diag,
Packit 67cb25
                                 &m.matrix, &v.vector);
Packit 67cb25
Packit 67cb25
                  gsl_blas_dscal(aii, &v.vector);
Packit 67cb25
                }
Packit 67cb25
            } /* for (i = 0; i < N; ++i) */
Packit 67cb25
        }
Packit 67cb25
      else
Packit 67cb25
        {
Packit 67cb25
          for (i = 0; i < N; ++i)
Packit 67cb25
            {
Packit 67cb25
              double ajj;
Packit 67cb25
              size_t j = N - i - 1;
Packit 67cb25
Packit 67cb25
              if (Diag == CblasNonUnit)
Packit 67cb25
                {
Packit 67cb25
                  double *Tjj = gsl_matrix_ptr(T, j, j);
Packit 67cb25
                  *Tjj = 1.0 / *Tjj;
Packit 67cb25
                  ajj = -(*Tjj);
Packit 67cb25
                }
Packit 67cb25
              else
Packit 67cb25
                {
Packit 67cb25
                  ajj = -1.0;
Packit 67cb25
                }
Packit 67cb25
Packit 67cb25
              if (j < N - 1)
Packit 67cb25
                {
Packit 67cb25
                  m = gsl_matrix_submatrix(T, j + 1, j + 1,
Packit 67cb25
                                           N - j - 1, N - j - 1);
Packit 67cb25
                  v = gsl_matrix_subcolumn(T, j, j + 1, N - j - 1);
Packit 67cb25
Packit 67cb25
                  gsl_blas_dtrmv(CblasLower, CblasNoTrans, Diag,
Packit 67cb25
                                 &m.matrix, &v.vector);
Packit 67cb25
Packit 67cb25
                  gsl_blas_dscal(ajj, &v.vector);
Packit 67cb25
                }
Packit 67cb25
            } /* for (i = 0; i < N; ++i) */
Packit 67cb25
        }
Packit 67cb25
Packit 67cb25
      return GSL_SUCCESS;
Packit 67cb25
    }
Packit 67cb25
}