/* nag_opt_check_2nd_deriv (e04hdc) Example Program.
 *
 * NAGPRODCODE Version.
 *
 * Copyright 2016 Numerical Algorithms Group.
 *
 * Mark 26, 2016.
 *
 */

#include <nag.h>
#include <stdio.h>
#include <nag_stdlib.h>
#include <math.h>
#include <nage04.h>

#ifdef __cplusplus
extern "C"
{
#endif
  static void NAG_CALL h(Integer n, const double xc[], double fhesl[],
                         double fhesd[], Nag_Comm *comm);

  static void NAG_CALL funct(Integer n, const double xc[], double *fc,
                             double gc[], Nag_Comm *comm);
#ifdef __cplusplus
}
#endif

/* Print the lower triangle of a symmetric n x n matrix, one row per line.
 * diag[0..n-1] holds the diagonal; lower[] holds the strict lower triangle
 * packed row by row, i.e. (2,1), (3,1), (3,2), (4,1), ...  Each row i
 * (1-based) prints its i-1 off-diagonal entries followed by diag[i-1]. */
static void print_lower_triangle(Integer n, const double diag[],
                                 const double lower[])
{
  Integer row, col, base = 0;

  for (row = 0; row < n; ++row) {
    for (col = 0; col < row; ++col)
      printf("%12.3e", lower[base + col]);
    printf("%12.3e\n", diag[row]);
    /* Row `row` (0-based) consumed `row` packed entries. */
    base += row;
  }
}

int main(void)
{
  /* First-invocation flags for the callbacks, reached via comm.user:
   * funct clears ruser[0] and h clears ruser[1] on their first call. */
  static double ruser[2] = { -1.0, -1.0 };
  /* Arbitrary point at which the derivatives are checked. */
  static const double start[] = { 1.46, -0.82, 0.57, 1.21 };
  Integer status = 0;
  Integer n, col;
  NagError fail;
  Nag_Comm comm;
  double fval;
  double *x = NULL, *grad = NULL, *hdiag = NULL, *hlow = NULL;

  INIT_FAIL(fail);

  printf("nag_opt_check_2nd_deriv (e04hdc) Example Program Results\n\n");

  comm.user = ruser;

  /* Problem size: funct and h are written for exactly four variables. */
  n = 4;

  /* Guard clause: nothing has been allocated yet, so return directly. */
  if (n < 1) {
    printf("Invalid n.\n");
    status = 1;
    return status;
  }

  /* hdiag: n diagonal Hessian entries; hlow: n(n-1)/2 strict lower-triangle
   * entries; grad: gradient; x: test point.  The || chain stops allocating
   * at the first failure; the cleanup label frees whatever was obtained. */
  if (!(hdiag = NAG_ALLOC(n, double)) ||
      !(hlow = NAG_ALLOC(n * (n - 1) / 2, double)) ||
      !(grad = NAG_ALLOC(n, double)) ||
      !(x = NAG_ALLOC(n, double)))
  {
    printf("Allocation failure\n");
    status = -1;
    goto cleanup;
  }

  for (col = 0; col < n; ++col)
    x[col] = start[col];

  printf("The test point is\n");
  for (col = 0; col < n; ++col)
    printf("%9.4f", x[col]);
  printf("\n");

  /* nag_opt_check_deriv (e04hcc): check funct's gradient at x; on return
   * fval and grad hold funct's value and gradient there. */
  nag_opt_check_deriv(n, funct, x, &fval, grad, &comm, &fail);
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_opt_check_deriv (e04hcc).\n%s\n", fail.message);
    status = 1;
    goto cleanup;
  }

  /* nag_opt_check_2nd_deriv (e04hdc): check h's second derivatives against
   * funct's first derivatives; fills hlow and hdiag. */
  nag_opt_check_2nd_deriv(n, funct, h, x, grad, hlow, hdiag, &comm, &fail);
  if (fail.code != NE_NOERROR) {
    printf("Error from nag_opt_check_2nd_deriv (e04hdc).\n%s\n",
           fail.message);
    status = 1;
    goto cleanup;
  }

  printf("\n2nd derivatives are consistent with 1st derivatives.\n\n");
  printf("At the test point, funct gives the function value, %13.4e\n", fval);
  printf("and the 1st derivatives\n");
  /* Four values per line; a newline follows every fourth value. */
  for (col = 0; col < n; ++col)
    printf("%12.3e%s", grad[col], (col + 1) % 4 == 0 ? "\n" : "");

  printf("\nh gives the lower triangle of the Hessian matrix\n");
  print_lower_triangle(n, hdiag, hlow);

cleanup:
  NAG_FREE(hdiag);
  NAG_FREE(hlow);
  NAG_FREE(grad);
  NAG_FREE(x);
  return status;
}

static void NAG_CALL funct(Integer n, const double xc[], double *fc,
                           double gc[], Nag_Comm *comm)
{
  /* Objective function and its gradient:
   *
   *   F(x) = (x1 + 10 x2)^2 + 5 (x3 - x4)^2 + (x2 - 2 x3)^4 + 10 (x1 - x4)^4
   *
   * Writes F(xc) to *fc and dF/dx_i to gc[i-1] for i = 1..4.  Only the first
   * four entries of xc/gc are used; n is not read (written for n == 4).
   * comm->user[0] is a first-call flag (presumably set to -1.0 by the
   * caller): on the first call a message is printed and the flag cleared. */
  const double p = xc[0] + 10.0 * xc[1];   /* x1 + 10 x2 */
  const double q = xc[2] - xc[3];          /* x3 - x4    */
  const double r = xc[1] - 2.0 * xc[2];    /* x2 - 2 x3  */
  const double s = xc[0] - xc[3];          /* x1 - x4    */
  double r3, s3;

  if (comm->user[0] == -1.0) {
    printf("(User-supplied callback funct, first invocation.)\n");
    comm->user[0] = 0.0;
  }

  *fc = pow(p, 2.0) + 5.0 * pow(q, 2.0) + pow(r, 4.0) + 10.0 * pow(s, 4.0);

  /* Cubes appear in two gradient components each; evaluate once. */
  r3 = pow(r, 3.0);
  s3 = pow(s, 3.0);

  gc[0] = 2.0 * p + 40.0 * s3;
  gc[1] = 20.0 * p + 4.0 * r3;
  gc[2] = 10.0 * q - 8.0 * r3;
  /* Written as (x4 - x3) rather than -(x3 - x4) to keep the sign of zero. */
  gc[3] = 10.0 * (xc[3] - xc[2]) - 40.0 * s3;
}

static void NAG_CALL h(Integer n, const double xc[], double fhesl[],
                       double fhesd[], Nag_Comm *comm)
{
  /* Second derivatives of the objective evaluated by funct:
   *
   *   F(x) = (x1 + 10 x2)^2 + 5 (x3 - x4)^2 + (x2 - 2 x3)^4 + 10 (x1 - x4)^4
   *
   * fhesd[0..3] receives the diagonal d2F/dxi^2.  fhesl[0..5] receives the
   * strict lower triangle packed row by row:
   *   (2,1), (3,1), (3,2), (4,1), (4,2), (4,3).
   * n is not read (written for n == 4).  comm->user[1] is a first-call flag
   * (presumably set to -1.0 by the caller): on the first call a message is
   * printed and the flag cleared. */
  const double r2 = pow(xc[1] - 2.0 * xc[2], 2.0);   /* (x2 - 2 x3)^2 */
  const double s2 = pow(xc[0] - xc[3], 2.0);         /* (x1 - x4)^2   */

  if (comm->user[1] == -1.0) {
    printf("(User-supplied callback h, first invocation.)\n");
    comm->user[1] = 0.0;
  }

  /* Diagonal. */
  fhesd[0] = 2.0 + 120.0 * s2;
  fhesd[1] = 200.0 + 12.0 * r2;
  fhesd[2] = 10.0 + 48.0 * r2;
  fhesd[3] = 10.0 + 120.0 * s2;

  /* Strict lower triangle, row by row. */
  fhesl[0] = 20.0;          /* (2,1) */
  fhesl[1] = 0.0;           /* (3,1) */
  fhesl[2] = -24.0 * r2;    /* (3,2) */
  fhesl[3] = -120.0 * s2;   /* (4,1) */
  fhesl[4] = 0.0;           /* (4,2) */
  fhesl[5] = -10.0;         /* (4,3) */
}