/* nag_ode_bvp_coll_nlin_setup (d02tvc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 * 
 * Mark 24, 2013.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd02.h>
#include <nagx01.h>

/* Model constants made available to the user-supplied callbacks via
 * comm->p.
 */
typedef struct {
  double beta0;   /* scale of beta(x) = beta0 * (1 + cos(2*pi*x)) */
  double eta;     /* divisor applied to y[2] in the third equation */
  double lambda;  /* divisor applied to y[1] in equations two and three */
  double mu;      /* constant term of the first equation */
} func_data;

/* Forward declarations of the user-supplied callbacks passed to
 * nag_ode_bvp_coll_nlin_solve (d02tlc).  C linkage is requested when the
 * file is compiled as C++.
 */
#ifdef __cplusplus
extern "C" {
#endif
/* Right-hand side of the differential system. */
static void NAG_CALL ffun(double x, const double y[], Integer neq,
                          const Integer m[], double f[],
                          Nag_Comm *comm);
/* Partial derivatives of the right-hand side. */
static void NAG_CALL fjac(double x, const double y[], Integer neq,
                          const Integer m[], double dfdy[],
                          Nag_Comm *comm);
/* Left-hand boundary conditions. */
static void NAG_CALL gafun(const double ya[], Integer neq,
                           const Integer m[], Integer nlbc,
                           double ga[], Nag_Comm *comm);
/* Right-hand boundary conditions. */
static void NAG_CALL gbfun(const double yb[], Integer neq,
                           const Integer m[], Integer nrbc,
                           double gb[], Nag_Comm *comm);
/* Partial derivatives of the left-hand boundary conditions. */
static void NAG_CALL gajac(const double ya[], Integer neq,
                           const Integer m[], Integer nlbc,
                           double dgady[], Nag_Comm *comm);
/* Partial derivatives of the right-hand boundary conditions. */
static void NAG_CALL gbjac(const double yb[], Integer neq,
                           const Integer m[], Integer nrbc,
                           double dgbdy[], Nag_Comm *comm);
/* Initial approximation to the solution. */
static void NAG_CALL guess(double x, Integer neq, const Integer m[],
                           double y[], double dym[],
                           Nag_Comm *comm);
#ifdef __cplusplus
}
#endif

int main(void)
{
  /* Scalars */
  Integer    exit_status = 0, neq = 6, mmax = 1, nlbc = 3, nrbc = 3;
  double     dx, ermx, beta0, eta, lambda, mu;
  Integer    i, iermx, ijermx, j, licomm, lrcomm, mxmesh, ncol, nmesh;
  /* Arrays */
  static double ruser[7] = {-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0};
  double     *mesh = 0, *rcomm = 0, *tols = 0, *y = 0;
  double     rdum[1];
  Integer    *ipmesh = 0, *icomm = 0, *m = 0;
  Integer    idum[2];
  func_data  fd;
  /* Nag Types */
  Nag_Comm   comm;
  NagError   fail;

  INIT_FAIL(fail);

  printf ("nag_ode_bvp_coll_nlin_setup (d02tvc) Example Program Results\n\n");

  /* For communication with user-supplied functions: */
  comm.user = ruser;

  /* Skip heading in data file*/
  scanf("%*[^\n] ");
  scanf("%"NAG_IFMT "%"NAG_IFMT "%"NAG_IFMT "%*[^\n] ", &ncol, &nmesh, &mxmesh);
  if (!(mesh   = NAG_ALLOC(mxmesh, double)) ||
      !(m      = NAG_ALLOC(neq, Integer)) ||
      !(tols   = NAG_ALLOC(neq, double)) ||
      !(y      = NAG_ALLOC(neq*mmax, double)) ||
      !(ipmesh = NAG_ALLOC(mxmesh, Integer)))
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }

  /* Set orders of equations */
  for (i = 0; i < neq; i++) {
    m[i] = 1;
  }
  scanf("%lf%lf%lf%lf%*[^\n] ", &beta0, &eta, &lambda, &mu);
  for (i = 0; i < neq; i++) {
    scanf("%lf", &tols[i]);
  }
  scanf("%*[^\n] ");
  dx = 1.0/(double) (nmesh - 1);
  mesh[0] = 0.0;
  for (i = 1; i < nmesh - 1; i++) {
    mesh[i] = mesh[i - 1] + dx;
  }
  mesh[nmesh - 1] = 1.0;
  ipmesh[0] = 1;
  for (i = 1; i < nmesh - 1; i++) {
    ipmesh[i] = 2;
  }
  ipmesh[nmesh - 1] = 1;

  /* Set data required for the user-supplied functions */
  fd.beta0 = beta0;
  fd.eta = eta;
  fd.lambda = lambda;
  fd.mu = mu;
  /* Associate the data structure with comm.p */
  comm.p = (Pointer) &fd;

  /* Communication space query to get size of rcomm and icomm 
   * by setting lrcomm=0 in call to
   * nag_ode_bvp_coll_nlin_setup (d02tvc):
   * Ordinary differential equations, general nonlinear boundary value problem,
   * setup for nag_ode_bvp_coll_nlin_solve (d02tlc).
   */
  nag_ode_bvp_coll_nlin_setup(neq, m, nlbc, nrbc, ncol, tols, mxmesh, nmesh,
                              mesh, ipmesh, rdum, 0, idum, 2, &fail);
  if (fail.code == NE_NOERROR) {
    lrcomm = idum[0];
    licomm = idum[1];

    if (!(rcomm = NAG_ALLOC(lrcomm, double)) ||
        !(icomm = NAG_ALLOC(licomm, Integer))) {
      printf("Allocation failure\n");
      exit_status = -2;
      goto END;
    }
    /* Initialize, again using nag_ode_bvp_coll_nlin_setup (d02tvc). */
    nag_ode_bvp_coll_nlin_setup(neq, m, nlbc, nrbc, ncol, tols, mxmesh, nmesh,
                                mesh, ipmesh, rcomm, lrcomm, icomm, licomm,
                                &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_ode_bvp_coll_nlin_setup (d02tvc).\n%s\n",
             fail.message);
      exit_status = 2;
      goto END;
    }
  }
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_ode_bvp_coll_nlin_setup (d02tvc).\n%s\n",
           fail.message);
    exit_status = 1;
    goto END;
  }

  /* Solve*/

  /* nag_ode_bvp_coll_nlin_solve (d02tlc).
   * Ordinary differential equations, general nonlinear boundary value problem,
   * collocation technique.
   */
  nag_ode_bvp_coll_nlin_solve(ffun, fjac, gafun, gbfun, gajac, gbjac, guess,
                              rcomm, icomm, &comm, &fail);
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_ode_bvp_coll_nlin_solve (d02tlc).\n%s\n",
           fail.message);
    exit_status = 3;
    goto END;
  }

  /* Extract mesh.*/

  /* nag_ode_bvp_coll_nlin_diag (d02tzc).
   * Ordinary differential equations, general nonlinear boundary value
   * problem, diagnostics for nag_ode_bvp_coll_nlin_solve (d02tlc).
   */
  nag_ode_bvp_coll_nlin_diag(mxmesh, &nmesh, mesh, ipmesh, &ermx, &iermx,
                             &ijermx, rcomm, icomm, &fail);
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_ode_bvp_coll_nlin_diag (d02tzc).\n%s\n",
           fail.message);
    exit_status = 4;
    goto END;
  }

  /* Print mesh statistics*/
  printf(" Used a mesh of %4ld  points\n", nmesh);
  printf(" Maximum error = %10.2e  in interval %4ld  for component ",
         ermx, iermx);
  printf("%4"NAG_IFMT " \n\n\n", ijermx);
  printf(" Mesh points:\n");
  for (i = 0; i < nmesh; i++) {
    printf("%4ld(%1ld)", i+1, ipmesh[i]);
    printf("%7.4f%s", mesh[i], (i+1)%4?" ":"\n");
  }
  printf("\n");
  /* Print solution on mesh.*/
  printf("\n Computed solution at mesh points\n");
  printf("    x       y1        y2         y3\n");
  for (i = 0; i < nmesh; i++) {

    /* nag_ode_bvp_coll_nlin_interp (d02tyc).
     * Ordinary differential equations, general nonlinear boundary value 
     * problem, interpolation for nag_ode_bvp_coll_nlin_solve (d02tlc).
     */
    nag_ode_bvp_coll_nlin_interp(mesh[i], y, neq, mmax, rcomm, icomm, 
                                 &fail);
    if (fail.code != NE_NOERROR) {
      printf("Error from nag_ode_bvp_coll_nlin_interp (d02tyc).\n%s\n",
             fail.message);
      exit_status = 5;
      goto END;
    }

    printf("%6.3f ", mesh[i]);
    for (j = 0; j < 3; j++) {
      printf("%11.3e", y[j]);
    }
    printf("\n");
  }

 END:
  NAG_FREE(mesh);
  NAG_FREE(m);
  NAG_FREE(tols);
  NAG_FREE(rcomm);
  NAG_FREE(y);
  NAG_FREE(ipmesh);
  NAG_FREE(icomm);
  return exit_status;
}

/* Right-hand side of the differential system.
 *
 * With beta(x) = beta0 * (1 + cos(2*pi*x)) and the constants taken from the
 * func_data structure reached through comm->p:
 *   f[0] = mu - beta*y[0]*y[2]
 *   f[1] = beta*y[0]*y[2] - y[1]/lambda
 *   f[2] = y[1]/lambda - y[2]/eta
 *   f[3] = f[4] = f[5] = 0
 * neq and m are not used.  The first call announces itself on stdout and
 * clears the flag kept in comm->user[0].
 */
static void NAG_CALL ffun(double x, const double y[], Integer neq, 
                          const Integer m[], double f[], Nag_Comm *comm)
{
  func_data *par = (func_data *)comm->p;
  double *first_call = &comm->user[0];
  double beta_x;

  if (*first_call == -1.0) {
    printf("(User-supplied callback ffun, first invocation.)\n");
    *first_call = 0.0;
  }

  /* nag_pi (x01aac). */
  beta_x = par->beta0 * (1.0 + cos(2.0 * nag_pi * x));

  f[0] = par->mu - beta_x * y[0] * y[2];
  f[1] = beta_x * y[0] * y[2] - y[1]/par->lambda;
  f[2] = y[1]/par->lambda - y[2]/par->eta;

  /* The last three equations have a zero right-hand side. */
  f[3] = 0.0;
  f[4] = 0.0;
  f[5] = 0.0;
}

/* Partial derivatives of the right-hand side evaluated by ffun.
 *
 * The entry for d f[i] / d y[j] is stored at dfdy[i + j*neq].  Only the
 * seven entries that are not identically zero are written; every other
 * element of dfdy is left as supplied by the caller.
 * NOTE(review): this relies on the solver handing over a pre-zeroed array -
 * confirm against the d02tlc documentation.
 * m is not used.  The first call announces itself on stdout and clears the
 * flag kept in comm->user[1].
 */
static void NAG_CALL fjac(double x, const double y[], Integer neq, 
                          const Integer m[], double dfdy[], Nag_Comm *comm)
{
  func_data *par = (func_data *)comm->p;
  double *first_call = &comm->user[1];
  /* Columns of the Jacobian, one per differentiated component of y. */
  double *wrt_y0 = dfdy;
  double *wrt_y1 = dfdy + 1*neq;
  double *wrt_y2 = dfdy + 2*neq;
  double beta_x;

  if (*first_call == -1.0) {
    printf("(User-supplied callback fjac, first invocation.)\n");
    *first_call = 0.0;
  }

  /* nag_pi (x01aac). */
  beta_x = par->beta0 * (1.0 + cos(2.0 * nag_pi * x));

  /* Derivatives with respect to y[0]. */
  wrt_y0[0] = -beta_x * y[2];
  wrt_y0[1] =  beta_x * y[2];

  /* Derivatives with respect to y[1]. */
  wrt_y1[1] = -1.0/par->lambda;
  wrt_y1[2] =  1.0/par->lambda;

  /* Derivatives with respect to y[2]. */
  wrt_y2[0] = -beta_x * y[0];
  wrt_y2[1] =  beta_x * y[0];
  wrt_y2[2] = -1.0/par->eta;
}

/* Left-hand boundary conditions: ga[k] = ya[k] - ya[k+3] for k = 0, 1, 2.
 *
 * neq, m and nlbc are not used.  The first call announces itself on stdout
 * and clears the flag kept in comm->user[2].
 */
static void NAG_CALL gafun(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double ga[], Nag_Comm *comm)
{
  Integer k;

  if (comm->user[2] == -1.0) {
    printf("(User-supplied callback gafun, first invocation.)\n");
    comm->user[2] = 0.0;
  }

  /* Each of the first three components is paired with the one three
   * places further on.
   */
  for (k = 0; k < 3; ++k) {
    ga[k] = ya[k] - ya[k + 3];
  }
}

/* Right-hand boundary conditions: gb[k] = yb[k] - yb[k+3] for k = 0, 1, 2.
 *
 * neq, m and nrbc are not used.  The first call announces itself on stdout
 * and clears the flag kept in comm->user[3].
 */
static void NAG_CALL gbfun(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double gb[], Nag_Comm *comm)
{
  Integer k;

  if (comm->user[3] == -1.0) {
    printf("(User-supplied callback gbfun, first invocation.)\n");
    comm->user[3] = 0.0;
  }

  /* Each of the first three components is paired with the one three
   * places further on.
   */
  for (k = 0; k < 3; ++k) {
    gb[k] = yb[k] - yb[k + 3];
  }
}

/* Partial derivatives of the left-hand boundary conditions set in gafun.
 *
 * The entry for d ga[i] / d ya[j] is stored at dgady[i + j*nlbc].  For
 * i = 0, 1, 2 the derivative with respect to ya[i] is 1 and the derivative
 * with respect to ya[i+3] is -1; no other element is written.
 * ya, neq and m are not used.  The first call announces itself on stdout
 * and clears the flag kept in comm->user[4].
 */
static void NAG_CALL gajac(const double ya[], Integer neq, const Integer m[],
                           Integer nlbc, double dgady[], Nag_Comm *comm)
{
  Integer row;

  if (comm->user[4] == -1.0) {
    printf("(User-supplied callback gajac, first invocation.)\n");
    comm->user[4] = 0.0;
  }

  for (row = 0; row < 3; ++row) {
    dgady[row + row*nlbc]       =  1.0;
    dgady[row + (row + 3)*nlbc] = -1.0;
  }
}

/* Partial derivatives of the right-hand boundary conditions set in gbfun.
 *
 * The entry for d gb[i] / d yb[j] is stored at dgbdy[i + j*nrbc].  For
 * i = 0, 1, 2 the derivative with respect to yb[i] is 1 and the derivative
 * with respect to yb[i+3] is -1; no other element is written.
 * yb, neq and m are not used.  The first call announces itself on stdout
 * and clears the flag kept in comm->user[5].
 */
static void NAG_CALL gbjac(const double yb[], Integer neq, const Integer m[],
                           Integer nrbc, double dgbdy[], Nag_Comm *comm)
{
  Integer row;

  if (comm->user[5] == -1.0) {
    printf("(User-supplied callback gbjac, first invocation.)\n");
    comm->user[5] = 0.0;
  }

  for (row = 0; row < 3; ++row) {
    dgbdy[row + row*nrbc]       =  1.0;
    dgbdy[row + (row + 3)*nrbc] = -1.0;
  }
}

/* Initial approximation to the solution: every one of the neq entries of y
 * is set to 1 and every one of the neq entries of dym is set to 0,
 * whatever the value of x.
 *
 * x and m are not used.  The first call announces itself on stdout and
 * clears the flag kept in comm->user[6].
 */
static void NAG_CALL guess(double x, Integer neq, const Integer m[], double y[],
                           double dym[], Nag_Comm *comm)
{
  Integer k = 0;

  if (comm->user[6] == -1.0) {
    printf("(User-supplied callback guess, first invocation.)\n");
    comm->user[6] = 0.0;
  }

  while (k < neq) {
    y[k] = 1.0;
    dym[k] = 0.0;
    ++k;
  }
}