/* nag_pde_parab_1d_keller_ode (d03pkc) Example Program.
 *
 * NAGPRODCODE Version.
 *
 * Copyright 2016 Numerical Algorithms Group.
 *
 * Mark 26, 2016.
 */

#include <stdio.h>
#include <math.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagd03.h>

/* Prototypes for the user-supplied callbacks handed to
 * nag_pde_parab_1d_keller_ode (pdedef: PDE residual, bndary: boundary
 * residual, odedef: coupled ODE residual) and for the local helpers uvinit
 * (initial values) and exact (reference solution).  The extern "C" wrapper
 * keeps C linkage for these functions when the file is compiled as C++. */
#ifdef __cplusplus
extern "C"
{
#endif
  static void NAG_CALL pdedef(Integer, double, double, const double[],
                              const double[], const double[], Integer,
                              const double[], const double[], double[],
                              Integer *, Nag_Comm *);

  static void NAG_CALL bndary(Integer npde, double t, Integer ibnd,
                              Integer nobc, const double u[],
                              const double ut[], Integer ncode,
                              const double v[], const double vdot[],
                              double res[], Integer *ires, Nag_Comm *);

  static void NAG_CALL odedef(Integer, double, Integer, const double[],
                              const double[], Integer, const double[],
                              const double[], const double[], const double[],
                              double[], Integer *, Nag_Comm *);

  static void NAG_CALL uvinit(Integer npde, Integer npts, double *x,
                              double *u, Integer ncode, Integer neqn,
                              double ts);

  static void NAG_CALL exact(double, Integer, Integer, double *, double *);
#ifdef __cplusplus
}
#endif

/* UCP(I, J): 1-based accessor into the flat array ucp, returning component I
 * at coupling point J; the npde components of each coupling point are stored
 * contiguously.  Expands to references to `ucp` and `npde`, so both must be
 * in scope where it is used (as in odedef).  Presumably ucp holds the PDE
 * solution evaluated at the coupling points xi -- see the d03pkc docs. */
#define UCP(I, J) ucp[npde*((J) -1)+(I) -1]

int main(void)
{
  /* Problem size: npde PDE components on npts mesh points, coupled to ncode
   * ODE component(s) through nxi coupling point(s); nleft is the number of
   * boundary conditions imposed at the left end of the mesh. */
  const Integer npde = 2, npts = 21, ncode = 1, nxi = 1, nleft = 1;
  const Integer neqn = npde * npts + ncode;
  /* Integer workspace: d03pkc requires lisave >= neqn + 24 when full matrix
   * algebra is used; a fixed 24 is too small for this problem. */
  const Integer lisave = neqn + 24;
  /* Real workspace sizes for full matrix algebra, as required by d03pkc. */
  const Integer nwkres =
         npde * (npts + 6 * nxi + 3 * npde + 15) + ncode + nxi + 7 * npts + 2;
  const Integer lenode = 11 * neqn + 50, lrsave =
         neqn * neqn + neqn + nwkres + lenode;
  /* comm.user[k] stays -1.0 until callback k (0: odedef, 1: pdedef,
   * 2: bndary) first runs; each callback then prints once and clears it. */
  static double ruser[3] = { -1.0, -1.0, -1.0 };
  double tout, ts;
  Integer exit_status = 0, i, ind, it, itask, itol, itrace;
  Nag_Boolean theta;
  double *algopt = 0, *atol = 0, *exy = 0, *rsave = 0, *rtol = 0;
  double *u = 0, *x = 0, *xi = 0;
  Integer *isave = 0;
  NagError fail;
  Nag_Comm comm;
  Nag_D03_Save saved;

  INIT_FAIL(fail);

  printf("nag_pde_parab_1d_keller_ode (d03pkc) Example Program Results\n\n");

  /* For communication with user-supplied functions: */
  comm.user = ruser;

  /* Allocate memory */

  if (!(algopt = NAG_ALLOC(30, double)) ||
      !(atol = NAG_ALLOC(1, double)) ||
      !(exy = NAG_ALLOC(neqn, double)) ||
      !(rsave = NAG_ALLOC(lrsave, double)) ||
      !(rtol = NAG_ALLOC(1, double)) ||
      !(u = NAG_ALLOC(neqn, double)) ||
      !(x = NAG_ALLOC(npts, double)) ||
      !(xi = NAG_ALLOC(nxi, double)) || !(isave = NAG_ALLOC(lisave, Integer)))
  {
    printf("Allocation failure\n");
    exit_status = 1;
    goto END;
  }

  itrace = 0;
  itol = 1;
  atol[0] = 1e-4;
  rtol[0] = atol[0];

  printf("  Accuracy requirement =%12.3e", atol[0]);
  printf(" Number of points = %3" NAG_IFMT "\n\n", npts);

  /* Set spatial-mesh points: uniform on [0, 1] */

  for (i = 0; i < npts; ++i)
    x[i] = i / (npts - 1.0);

  /* Single ODE coupling point at the right end of the mesh */
  xi[0] = 1.0;
  ind = 0;
  itask = 1;

  /* Set THETA to TRUE if the Theta integrator is required.  algopt[0]
   * selects the integrator (1.0 = BDF, 2.0 = Theta).  It is set exactly once
   * from THETA; overwriting it afterwards would make the switch dead code. */

  theta = Nag_FALSE;
  for (i = 0; i < 30; ++i)
    algopt[i] = 0.0;
  algopt[0] = theta ? 2.0 : 1.0;
  algopt[12] = 0.005;

  /* Loop over output value of t */

  ts = 1e-4;
  printf("  x        %9.3f%9.3f%9.3f%9.3f%9.3f\n\n",
         x[0], x[4], x[8], x[12], x[20]);

  uvinit(npde, npts, x, u, ncode, neqn, ts);

  for (it = 0; it < 5; ++it) {
    tout = 0.1 * pow(2.0, (it + 1.0));
    /* nag_pde_parab_1d_keller_ode (d03pkc).
     * General system of first-order PDEs, coupled DAEs, method
     * of lines, Keller box discretization, one space variable
     */
    nag_pde_parab_1d_keller_ode(npde, &ts, tout, pdedef, bndary, u, npts, x,
                                nleft, ncode, odedef, nxi, xi, neqn, rtol,
                                atol, itol, Nag_TwoNorm, Nag_LinAlgFull,
                                algopt, rsave, lrsave, isave, lisave, itask,
                                itrace, 0, &ind, &comm, &saved, &fail);

    if (fail.code != NE_NOERROR) {
      printf("Error from nag_pde_parab_1d_keller_ode (d03pkc).\n%s\n",
             fail.message);
      exit_status = 1;
      goto END;
    }

    /* Check against the exact solution */

    exact(tout, neqn, npts, x, exy);

    /* u is interleaved (U1, U2) per mesh point, so u[2*k] is U1 at x[k];
     * the ODE component is the last entry, u[neqn - 1]. */
    printf(" t = %6.3f\n", ts);
    printf(" App.  sol.  %7.3f%9.3f%9.3f%9.3f%9.3f",
           u[0], u[8], u[16], u[24], u[40]);
    printf("  ODE sol. =%8.3f\n", u[neqn - 1]);
    printf(" Exact sol.  %7.3f%9.3f%9.3f%9.3f%9.3f",
           exy[0], exy[8], exy[16], exy[24], exy[40]);
    /* The exact ODE solution for this problem is v(t) = t. */
    printf("  ODE sol. =%8.3f\n\n", ts);
  }
  printf(" Number of integration steps in time = %6" NAG_IFMT "\n", isave[0]);
  printf(" Number of function evaluations = %6" NAG_IFMT "\n", isave[1]);
  printf(" Number of Jacobian evaluations =%6" NAG_IFMT "\n", isave[2]);
  printf(" Number of iterations = %6" NAG_IFMT "\n\n", isave[4]);
END:
  NAG_FREE(algopt);
  NAG_FREE(atol);
  NAG_FREE(exy);
  NAG_FREE(rsave);
  NAG_FREE(rtol);
  NAG_FREE(u);
  NAG_FREE(x);
  NAG_FREE(xi);
  NAG_FREE(isave);

  return exit_status;
}

static void NAG_CALL uvinit(Integer npde, Integer npts, double *x,
                            double *u, Integer ncode, Integer neqn, double ts)
{
  /* Initial values at t = ts, taken from the exact solution:
   *   U1 = exp(ts*(1-x)) - 1,  U2 = -ts*exp(ts*(1-x)),  v = ts.
   * The PDE components are stored interleaved per mesh point, (U1, U2),
   * and the ODE component occupies the final slot u[neqn - 1].
   * npde and ncode are unused; the 2-component layout is hard-coded. */
  Integer j;

  for (j = 0; j < npts; ++j) {
    const double growth = exp(ts * (1.0 - x[j]));

    u[2 * j] = growth - 1.0;
    u[2 * j + 1] = -ts * growth;
  }

  u[neqn - 1] = ts;
}

static void NAG_CALL odedef(Integer npde, double t, Integer ncode,
                            const double v[], const double vdot[],
                            Integer nxi, const double xi[],
                            const double ucp[], const double ucpx[],
                            const double ucpt[], double f[], Integer *ires,
                            Nag_Comm *comm)
{
  /* Residual of the coupled ODE
   *   dv/dt - v*U1(xi) - U2(xi) - 1 - t = 0,
   * where U1, U2 are read from ucp at the single coupling point.  When
   * *ires == -1 only the time-derivative part (dv/dt) is returned. */
  if (comm->user[0] == -1.0) {
    printf("(User-supplied callback odedef, first invocation.)\n");
    comm->user[0] = 0.0;
  }

  if (*ires != -1) {
    const double u1_at_xi = UCP(1, 1);
    const double u2_at_xi = UCP(2, 1);

    f[0] = vdot[0] - v[0] * u1_at_xi - u2_at_xi - 1.0 - t;
  }
  else {
    f[0] = vdot[0];
  }
}

static void NAG_CALL pdedef(Integer npde, double t, double x,
                            const double u[], const double ut[],
                            const double ux[], Integer ncode,
                            const double v[], const double vdot[],
                            double res[], Integer *ires, Nag_Comm *comm)
{
  /* Residuals of the first-order PDE system
   *   v^2 * dU1/dt - x * U2 * v * dv/dt - dU2/dx = 0
   *   U2 - dU1/dx = 0
   * When *ires == -1 only the terms containing time derivatives are
   * returned (the second equation has none, so its residual is 0). */
  double time_terms;

  if (comm->user[1] == -1.0) {
    printf("(User-supplied callback pdedef, first invocation.)\n");
    comm->user[1] = 0.0;
  }

  time_terms = v[0] * v[0] * ut[0] - x * u[1] * v[0] * vdot[0];

  if (*ires == -1) {
    res[0] = time_terms;
    res[1] = 0.0;
    return;
  }

  res[0] = time_terms - ux[1];
  res[1] = u[1] - ux[0];
}

static void NAG_CALL bndary(Integer npde, double t, Integer ibnd,
                            Integer nobc, const double u[], const double ut[],
                            Integer ncode, const double v[],
                            const double vdot[], double res[], Integer *ires,
                            Nag_Comm *comm)
{
  /* One boundary residual per end of the mesh:
   *   ibnd == 0 (left):  U2 + v * exp(t) = 0   (no time derivatives)
   *   otherwise (right): U2 + v * dv/dt  = 0
   * When *ires == -1 only the time-derivative terms are returned. */
  const int derivative_terms_only = (*ires == -1);

  if (comm->user[2] == -1.0) {
    printf("(User-supplied callback bndary, first invocation.)\n");
    comm->user[2] = 0.0;
  }

  if (ibnd != 0) {
    res[0] = derivative_terms_only ? v[0] * vdot[0]
                                   : u[1] + v[0] * vdot[0];
    return;
  }

  res[0] = derivative_terms_only ? 0.0 : u[1] + v[0] * exp(t);
}

static void NAG_CALL exact(double time, Integer neqn, Integer npts, double *x,
                           double *u)
{
  /* Exact solution (for comparison purposes), in the same layout that
   * uvinit uses for the initial data:
   *   u[2*i]     = U1(x_i, time) = exp(time*(1-x_i)) - 1
   *   u[2*i + 1] = U2(x_i, time) = -time*exp(time*(1-x_i))
   *   u[neqn-1]  = v(time)       = time
   * Every entry is written, so callers never see uninitialized values in
   * the U2 or ODE slots.  Assumes neqn == 2*npts + 1 (two PDE components
   * plus one ODE component), as set up in main. */
  Integer i;

  for (i = 0; i < npts; ++i) {
    const double growth = exp(time * (1.0 - x[i]));

    u[2 * i] = growth - 1.0;
    u[2 * i + 1] = -time * growth;
  }
  u[neqn - 1] = time;
}