/* nag_kalman_sqrt_filt_cov_invar (g13ebc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group
 *
 * Mark 3, 1993
 * Mark 7, revised, 2001.
 * Mark 8 revised, 2004.
 *
 */

#include <nag.h>
#include <stdio.h>
#include <nag_stdlib.h>
#include <nagf03.h>
#include <nagf06.h>
#include <nagf16.h>
#include <nagg13.h>

/* Selects what mat_io() does with the matrix it is handed. */
typedef enum
{
  read,  /* fill the matrix from standard input */
  print  /* write the matrix to standard output */
} ioflag;

/* Example 1: the simple case, run directly on the data as read.
 * Returns 0 on success, 1 on invalid input or a NAG error, and -1 when
 * memory allocation fails.
 */
static int ex1(void);
/* Example 2: the more general case, in which the data is transformed first
 * and the g13eac and g13ebc results are compared.
 * Returns 0 on success, 1 on invalid input or a NAG error, and -1 when
 * memory allocation fails.
 */
static int ex2(void);

/* Program entry point: prints the results banner, discards the first line of
 * the data file and then runs both examples in order.
 * Returns 0 only when both examples return 0, otherwise 1.
 */
int main(void)
{
  Integer  status_first;
  Integer  status_second;

  printf("nag_kalman_sqrt_filt_cov_invar (g13ebc) Example Program "
          "Results\n\n");

  /* Discard the heading line of the data file */
  scanf(" %*[^\n] ");

  /* Both examples are always run; the second does not depend on the
     outcome of the first. */
  status_first = ex1();
  status_second = ex2();

  if (status_first != 0 || status_second != 0)
    return 1;
  return 0;
}

/* Row-major element accessors.  X(I, J) is the element in row I, column J
 * (both zero-based) of the array x, where consecutive rows start tdx
 * elements apart.  The array pointer and its td* stride must be in scope
 * wherever the macro is used.  Both arguments are parenthesised so that an
 * expression such as A(i, j & 1) indexes the intended element.
 */
#define A(I, J) a[(I) * tda + (J)]
#define B(I, J) b[(I) * tdb + (J)]
#define C(I, J) c[(I) * tdc + (J)]
#define K(I, J) k[(I) * tdk + (J)]
#define Q(I, J) q[(I) * tdq + (J)]
#define R(I, J) r[(I) * tdr + (J)]
#define S(I, J) s[(I) * tds + (J)]
#define H(I, J) h[(I) * tdh + (J)]

/* Read a rows-by-cols matrix from standard input into mat, row by row, where
 * consecutive rows of mat start td elements apart.
 * Returns 0 when every element was read and 1 as soon as one read fails.
 */
static int ex1_read_matrix(Integer rows, Integer cols, double mat[],
                           Integer td)
{
  Integer i, j;

  for (i = 0; i < rows; ++i)
    for (j = 0; j < cols; ++j)
      if (scanf("%lf", &mat[i*td + j]) != 1)
        return 1;
  return 0;
}

/* Example 1: reads n, m, p and tol followed by the matrices S, A, B, Q, C
 * and R from standard input, applies three steps of
 * nag_kalman_sqrt_filt_cov_invar (g13ebc) and prints the resulting S and K.
 *
 * Returns 0 on success, 1 on unreadable or invalid input or when a NAG call
 * reports an error, and -1 when memory allocation fails.  Every array that
 * was allocated is released before returning.
 */
static int ex1(void)
{ /* simple example (matrices A and C are supplied in lower observer
     Hessenberg form) */
  Integer  exit_status = 0, i, istep, j, m, n, p, tda, tdb, tdc, tdh, tdk, tdq;
  Integer  tdr, tds;
  NagError fail;
  double   *a = 0, *b = 0, *c = 0, *h = 0, *k = 0, *q = 0, *r = 0, *s = 0, tol;

  INIT_FAIL(fail);

  /* Skip the heading in the data file  */
  scanf(" %*[^\n]");

  printf("Example 1\n");

  /* The dimensions drive every allocation below, so refuse to continue with
     indeterminate values when the line cannot be parsed. */
  if (scanf("%ld%ld%ld%lf", &n, &m, &p, &tol) != 4)
    {
      printf("Unable to read n, m, p and tol.\n");
      exit_status = 1;
      goto END;
    }
  if (n < 1 || m < 1 || p < 1)
    {
      printf("Invalid n or m or p.\n");
      exit_status = 1;
      goto END;
    }

  if (!(a = NAG_ALLOC(n*n, double)) ||
      !(b = NAG_ALLOC(n*m, double)) ||
      !(c = NAG_ALLOC(p*n, double)) ||
      !(k = NAG_ALLOC(n*p, double)) ||
      !(q = NAG_ALLOC(m*m, double)) ||
      !(r = NAG_ALLOC(p*p, double)) ||
      !(s = NAG_ALLOC(n*n, double)) ||
      !(h = NAG_ALLOC(n*p, double)))
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }
  /* Each array is stored row by row with no padding, so the trailing
     dimension equals the number of columns. */
  tda = n;
  tdb = m;
  tdc = n;
  tdk = p;
  tdq = m;
  tdr = p;
  tds = n;
  tdh = p;

  /* Read data in file order: S, A, B, Q (only when q is set), C, R.  The ||
     chain stops at the first matrix that cannot be read. */
  if (ex1_read_matrix(n, n, s, tds) ||
      ex1_read_matrix(n, n, a, tda) ||
      ex1_read_matrix(n, m, b, tdb) ||
      (q && ex1_read_matrix(m, m, q, tdq)) ||
      ex1_read_matrix(p, n, c, tdc) ||
      ex1_read_matrix(p, p, r, tdr))
    {
      printf("Unable to read the input matrices.\n");
      exit_status = 1;
      goto END;
    }

  /* Perform three iterations of the Kalman filter recursion  */
  for (istep = 1; istep <= 3; ++istep)
    {
      /* nag_kalman_sqrt_filt_cov_invar (g13ebc).
       * One iteration step of the time-invariant Kalman filter
       * recursion using the square root covariance implementation
       * with (AC) in lower observer Hessenberg form
       */
      nag_kalman_sqrt_filt_cov_invar(n, m, p, s, tds, a, tda, b, tdb, q, tdq,
                                     c, tdc, r, tdr, k, tdk, h, tdh, tol,
                                     &fail);
      /* Check after every step: a later call would otherwise overwrite the
         error reported by an earlier one. */
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_cov_invar (g13ebc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }
    }

  printf("\nThe square root of the state covariance matrix is\n\n");
  for (i = 0; i < n; ++i)
    {
      for (j = 0; j < n; ++j)
        printf("%8.4f ", S(i, j));
      printf("\n");
    }
  if (k)
    {
      printf("\nThe matrix AK (the product of the Kalman gain\n");
      printf("matrix with the state transition matrix) is\n\n");
      for (i = 0; i < n; ++i)
        {
          for (j = 0; j < p; ++j)
            printf("%8.4f ", K(i, j));
          printf("\n");
        }
    }
 END:
  /* Pointers that were never allocated are still null here. */
  NAG_FREE(a);
  NAG_FREE(b);
  NAG_FREE(c);
  NAG_FREE(k);
  NAG_FREE(q);
  NAG_FREE(r);
  NAG_FREE(s);
  NAG_FREE(h);

  return exit_status;
}

/* Read or print the n-by-m matrix mat, stored row by row with tdmat elements
 * between the starts of consecutive rows.  flag selects the direction;
 * message is printed ahead of the matrix when flag is print.
 */
static void mat_io(Integer n, Integer m, double mat[], Integer tdmat,
                   ioflag flag, const char *message);


/* Row-major element accessors used by ex2.  X(I, J) is the element in row I,
 * column J (both zero-based) of the array x, where consecutive rows start
 * tdx elements apart.  The array pointer and its td* stride must be in scope
 * wherever the macro is used.  Both arguments are parenthesised so that an
 * operator-bearing column argument indexes the intended element.
 */
#define KE(I, J)    ke[(I) * tdke + (J)]
#define KF(I, J)    kf[(I) * tdkf + (J)]
#define UB(I, J)    ub[(I) * tdub + (J)]
#define RWORK(I, J) rwork[(I) * tdrwork + (J)]
#define SF(I, J)    sf[(I) * tdsf + (J)]
#define SE(I, J)    se[(I) * tdse + (J)]
#define PF(I, J)    pf[(I) * tdpf + (J)]
#define PE(I, J)    pe[(I) * tdpe + (J)]
#define UAUT(I, J)  uaut[(I) * tduaut + (J)]
#define CUT(I, J)   cut[(I) * tdcut + (J)]
#define U(I, J)     u[(I) * tdu + (J)]

static int ex2()
{ /* more general example which requires the data to be transformed. The
     results produced by nag_kalman_sqrt_filt_cov_var (g13eac) and
     nag_kalman_sqrt_filt_cov_invar (g13ebc) are compared */
  Integer          dete, exit_status = 0, i, ione = 1, istep, j, m, n, p, tda,
                   tdb;
  Integer          tdc, tdcut, tdh, tdke, tdkf, tdpe, tdpf, tdq, tdr, tdrwork,
                   tdse;
  Integer          tdsf, tdu, tduaut, tdub;
  NagError         fail;
  Nag_ObserverForm reduceto = Nag_LH_Observer;
  double           *a = 0, *b = 0, *c = 0, *cut = 0, detf, *diag = 0, *h = 0;
  double           *ke = 0, *kf = 0, one = 1.0, *pe = 0, *pf = 0, *q = 0;
  double           *r = 0, *rwork = 0, *se = 0, *sf = 0, tol, *u = 0;
  double           *uaut = 0, *ub = 0, zero = 0.0;

  INIT_FAIL(fail);

  printf("\nExample 2\n\n");

  /* skip the heading in the data file */
  scanf(" %*[^\n]");
  scanf("%ld%ld%ld%lf", &n, &m, &p, &tol);
  if (n >= 1 && m >= 1 && p >= 1)
    {
      if (!(a = NAG_ALLOC(n*n, double)) ||
          !(b = NAG_ALLOC(n*m, double)) ||
          !(c = NAG_ALLOC(p*n, double)) ||
          !(ke = NAG_ALLOC(n*p, double)) ||
          !(kf = NAG_ALLOC(n*p, double)) ||
          !(ub = NAG_ALLOC(n*m, double)) ||
          !(q = NAG_ALLOC(m*m, double)) ||
          !(r = NAG_ALLOC(p*p, double)) ||
          !(rwork = NAG_ALLOC(n*n, double)) ||
          !(sf = NAG_ALLOC(n*n, double)) ||
          !(se = NAG_ALLOC(n*n, double)) ||
          !(h = NAG_ALLOC(n*p, double)) ||
          !(pf = NAG_ALLOC(n*n, double)) ||
          !(pe = NAG_ALLOC(n*n, double)) ||
          !(uaut = NAG_ALLOC(n*n, double)) ||
          !(cut = NAG_ALLOC(p*n, double)) ||
          !(u = NAG_ALLOC(n*n, double)) ||
          !(diag = NAG_ALLOC(n, double)))
        {
          printf("Allocation failure\n");
          exit_status = -1;
          goto END;
        }
      tda = n;
      tdb = m;
      tdc = n;
      tdke = p;
      tdkf = p;
      tdub = m;
      tdq = m;
      tdr = p;
      tdrwork = n;
      tdsf = n;
      tdse = n;
      tdh = p;
      tdpf = n;
      tdpe = n;
      tduaut = n;
      tdcut = n;
      tdu = n;
    }
  else
    {
      printf("Invalid n or m or p.\n");
      exit_status = 1;
      return exit_status;
    }
  mat_io(n, n, se, tdse, read, "");
  mat_io(n, n, a, tda, read, "");
  mat_io(n, m, b, tdb, read, "");
  if (q)
    mat_io(m, m, q, tdq, read, "");
  mat_io(p, n, c, tdc, read, "");
  mat_io(p, p, r, tdr, read, "");
  for (i = 0; i < n; ++i)
    {
      for (j = 0; j < n; ++j)
        {
          if (i < p)
            CUT(i, j) = C(i, j);
          SF(i, j) = SE(i, j);
          UAUT(i, j) = A(i, j);
          U(i, j) = zero;
        }
      U(i, i) = one;
    }
  /* Set up the matrix pair (A,C) in the lower observer hessenberg form */
  /* nag_trans_hessenberg_observer (g13ewc).
   * Unitary state-space transformation to reduce (AC) to
   * lower or upper observer Hessenberg form
   */
  nag_trans_hessenberg_observer(n, p, reduceto, uaut, tduaut, cut, tdcut,
                                u, tdu, &fail);
  if (fail.code != NE_NOERROR)
    {
      printf("Error from nag_trans_hessenberg_observer (g13ewc).\n%s\n",
              fail.message);
      exit_status = 1;
      goto END;
    }
  for (j = 0; j < m; ++j)
    for (i = 0; i < n; ++i)
      UB(i, j) = f06eac(n, &U(i, 0), ione, &B(0, j), tdb);

  /* Generate noise covariance matrices PE and PF = U * PE * U' */
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, n, n, n, one, se, tdse,
            se, tdse, zero, pe, tdpe, &fail);
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, n, n, n, one, pe, tdpe,
            u, tdu, zero, rwork, tdrwork, &fail);
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_NoTrans, n, n, n, one, u, tdu,
            rwork, tdrwork, zero, pf, tdpf, &fail);

  /* Now find the lower triangular (left) cholesky factor of PF. */
  /* nag_real_cholesky (f03aec).
   * LL^T factorization and determinant of real symmetric
   * positive-definite matrix
   */
  f03aec(n, pf, tdpf, diag, &detf, &dete, &fail);
  if (fail.code != NE_NOERROR)
    {
      printf("Error from nag_real_cholesky (f03aec).\n%s\n",
              fail.message);
      exit_status = 1;
      goto END;
    }
  for (i = 0; i < n; ++i)
    {
      SF(i, i) = one/diag[i];
      for (j = 0; j < i; ++j)
        SF(i, j) = PF(i, j);
    }
  /* Perform three steps of the Kalman filter recursion */
  for (istep = 1; istep <= 3; ++istep)
    {
      /* nag_kalman_sqrt_filt_cov_var (g13eac).
       * One iteration step of the time-varying Kalman filter
       * recursion using the square root covariance implementation
       */
      nag_kalman_sqrt_filt_cov_var(n, m, p, se, tdse, a, tda, b, tdb, q,
                                   tdq, c, tdc, r, tdr, ke, tdke, h, tdh, tol,
                                   &fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_cov_var (g13eac).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }
      /* nag_kalman_sqrt_filt_cov_invar (g13ebc), see above. */
      nag_kalman_sqrt_filt_cov_invar(n, m, p, sf, tdsf, uaut, tduaut, ub, tdub,
                                     q, tdq, cut, tdcut, r, tdr, kf, tdkf, h,
                                     tdh, tol, &fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_kalman_sqrt_filt_cov_invar (g13ebc).\n%s\n",
                  fail.message);
          exit_status = 1;
          goto END;
        }
    }
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, n, n, n, one, se, tdse,
            se, tdse, zero, pe, tdpe, &fail);
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_Trans, n, n, n, one, sf, tdsf,
            sf, tdsf, zero, pf, tdpf, &fail);
  mat_io(n, n, pe, tdpe, print, "Covariance matrix PE from "
         "nag_kalman_sqrt_filt_cov_var (g13eac) is\n");
  mat_io(n, n, pf, tdpf, print, "Covariance matrix PF from "
         "nag_kalman_sqrt_filt_cov_invar (g13ebc) is\n");

  /* Calculate PF = U' * PF * U */
  nag_dgemm(Nag_RowMajor, Nag_NoTrans, Nag_NoTrans, n, n, n, one, pf, tdpf,
            u, tdu, zero, rwork, tdrwork, &fail);
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, n, n, one, u, tdu,
            rwork, tdrwork, zero, pf, tdpf, &fail);
  mat_io(n, n, pf, tdpf, print, "Matrix U' * PF * U is \n");
  mat_io(n, p, ke, tdke, print,
         "The matrix KE from nag_kalman_sqrt_filt_cov_var (g13eac) is\n");
  mat_io(n, p, kf, tdkf, print,
         "The matrix KF from nag_kalman_sqrt_filt_cov_invar (g13ebc) is\n");

  /* calculate U' * K */
  nag_dgemm(Nag_RowMajor, Nag_Trans, Nag_NoTrans, n, p, n, one, u, tdu,
            kf, tdkf, zero, rwork, tdrwork, &fail);
  mat_io(n, p, rwork, tdrwork, print, "U' * KF is\n");

 END:
  NAG_FREE(a);
  NAG_FREE(b);
  NAG_FREE(c);
  NAG_FREE(ke);
  NAG_FREE(kf);
  NAG_FREE(ub);
  NAG_FREE(q);
  NAG_FREE(r);
  NAG_FREE(rwork);
  NAG_FREE(sf);
  NAG_FREE(se);
  NAG_FREE(h);
  NAG_FREE(pf);
  NAG_FREE(pe);
  NAG_FREE(uaut);
  NAG_FREE(cut);
  NAG_FREE(u);
  NAG_FREE(diag);

  return exit_status;
}


/* Read or print an n-by-m matrix held row by row in mat, with tdmat elements
 * between the starts of consecutive rows.
 *
 * flag == read : each element is read from standard input with "%lf";
 *                message is not used.
 * flag == print: message is written first, then one line per row with each
 *                element in "%8.4f " format, then a blank line.
 * Any other flag value does nothing.
 */
static void mat_io(Integer n, Integer m, double mat[], Integer tdmat,
                   ioflag flag, const char *message)
{
  Integer row, col;

  if (flag == read)
    {
      for (row = 0; row < n; ++row)
        for (col = 0; col < m; ++col)
          scanf("%lf", &mat[row*tdmat + col]);
    }
  else if (flag == print)
    {
      printf("%s \n", message);
      for (row = 0; row < n; ++row)
        {
          for (col = 0; col < m; ++col)
            printf("%8.4f ", mat[row*tdmat + col]);
          printf("\n");
        }
      printf("\n");
    }
} /* mat_io */