/* nag_rand_kfold_xyw (g05pvc) Example Program.
 *
 * Copyright 2014 Numerical Algorithms Group.
 *
 * Mark 25, 2014.
 */
/* Pre-processor includes */
#include <stdio.h>
#include <nag.h>
#include <nag_stdlib.h>
#include <nagg02.h>
#include <nagg05.h>

int main(void)
{
  /* Integer scalar and array declarations */
  Integer fn, fold, fp, i, ip, k, pdx, lstate, m,
    max_nv, n, nn, np, nt, nv, obs_val, pred_val,
    subid, tn, tp, j, pdv, rank, max_iter, print_iter;
  Integer exit_status = 0, lseed = 1;
  Integer *isx = 0, *state = 0;
  Integer seed[1];

  /* NAG structures and types */
  NagError fail;
  Nag_Link link;
  Nag_IncludeMean mean;
  Nag_BaseRNG genid;
  Nag_Distributions errfn;
  Nag_Boolean vfobs;
  Nag_DataByObsOrVar sordx;

  /* Double scalar and array declarations */
  double ex_power, dev, eps, tol, df, scale;
  double *b = 0, *cov = 0, *eta = 0, *pred = 0, *se = 0, *seeta = 0,
    *sepred = 0, *v = 0, *offset = 0, *wt = 0, *x = 0, *y = 0, *t = 0;

  /* Character scalar and array declarations */
  char clink[40], cmean[40], cgenid[40];

  /* Initialise the error structure */
  INIT_FAIL(fail);

  printf("nag_rand_kfold_xyw (g05pvc) Example Program Results\n\n");

  /* Skip heading in data file */
  scanf("%*[^\n] ");

  /* Set variables required by the regression (g02gbc) ... */

  /* Read in the type of link function, whether a mean is required */
  /* and the problem size */
  scanf("%39s%39s%ld%ld%*[^\n] ",clink, cmean, &n, &m);
  link = (Nag_Link) nag_enum_name_to_value(clink);
  mean = (Nag_IncludeMean) nag_enum_name_to_value(cmean);

  /* Set storage order for g05pvc */
  /* (pick the one required by g02gbc and g02gpc) */
  sordx = Nag_DataByObs;

  pdx = m;
  if (!(x = NAG_ALLOC(pdx*n, double)) ||
      !(y = NAG_ALLOC(n,double)) ||
      !(t = NAG_ALLOC(n,double)) ||
      !(isx = NAG_ALLOC(m,Integer)))
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }

  /* This example is not using an offset or weights */
  offset = 0;
  wt = 0;

  /* Read in data */
  for (i = 0; i < n; i++)
    {
      for (j = 0; j < m; j++)
        {
          scanf("%lf", &x[i * pdx + j]);
        }
      scanf("%lf%lf%*[^\n] ", &y[i], &t[i]);
    }

  /* Read in variable inclusion flags */
  for (j = 0; j < m; j++)
    {
      scanf("%ld",&isx[j]);
    }
  scanf("%*[^\n] ");

  /* Read in control parameters for the regression */
  scanf("%ld%lf%lf%ld%*[^\n] ", &print_iter, &eps,
        &tol, &max_iter);

  /* Calculate IP */
  for (ip = 0, i = 0; i < m; i++) ip += (isx[i] > 0);
  if (mean == Nag_MeanInclude) ip++;
  /* ... End of setting variables required by the regression */


  /* Set variables required by data sampling routine (g05pvc) ... */

  /* Read in the base generator information and seed */
  scanf("%39s%ld%ld%*[^\n] ",cgenid, &subid, &seed[0]);
  genid = (Nag_BaseRNG) nag_enum_name_to_value(cgenid);

  /* Initial call to g05kfc to get size of STATE array */
  lstate = 0;
  nag_rand_init_repeatable(genid,subid,seed,lseed,state,&lstate,NAGERR_DEFAULT);

  /* Allocate state array */
  if (!(state = NAG_ALLOC(lstate, Integer)))
    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }

  /* Initialise the generator to a repeatable sequence using g05kfc */
  nag_rand_init_repeatable(genid, subid, seed, lseed, state, &lstate,
                           NAGERR_DEFAULT);

  /* Read in the number of folds */
  scanf("%ld%*[^\n] ",&k);
  /* ... End of setting variables required by data sampling routine */


  /* Set variables required by prediction routine (g02gpc) ... */

  /* Regression is performed using g02gbc so error structure is binomial */
  errfn = Nag_Binomial;

  /* This example does not use the predicted standard errors, so */
  /* it doesn't matter what VFOBS is set to */
  vfobs = Nag_FALSE;
  /* The error and link being used in the linear model don't use scale */
  /* and ex_power so they can be set to anything */
  ex_power = 0.0;
  scale = 0.0;
  /* ... End of setting variables required by prediction routine */


  /* This is the maximum size for a validation dataset */
  max_nv = (Integer) (((double) n / (double) k) + 0.5);

  /* Allocate arrays */
  pdv = n;
  if (!(b = NAG_ALLOC(ip, double)) ||
      !(se = NAG_ALLOC(ip,double)) ||
      !(cov = NAG_ALLOC(ip*(ip+1)/2,double)) ||
      !(v = NAG_ALLOC(n*pdv,double)) ||
      !(eta = NAG_ALLOC(max_nv,double)) ||
      !(seeta = NAG_ALLOC(max_nv,double)) ||
      !(pred = NAG_ALLOC(max_nv,double)) ||
      !(sepred = NAG_ALLOC(max_nv,double)))

    {
      printf("Allocation failure\n");
      exit_status = -1;
      goto END;
    }

  /* Initialise counts */
  tp = tn = fp = fn = 0;

  /* Loop over each fold */
  for (fold = 1; fold <= k; fold++)

    {
      /* Use g05pvc to split the data into training and validation datasets */
      nag_rand_kfold_xyw(k,fold,n,m,sordx,x,pdx,y,t,&nt,state,&fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_rand_kfold_xyw (g05pvc).\n%s\n",
                 fail.message);
          exit_status = 1;
          if (fail.code != NW_POTENTIAL_PROBLEM) goto END;
        }

      /* Calculate the size of the validation dataset */
      nv = n - nt;

      /* Call g02gbc to fit generalized linear model, with Binomial */
      /* errors to training data */
      nag_glm_binomial(link, mean, nt, x, pdx, m, isx, ip, y, t, wt,
                       offset, &dev, &df, b, &rank, se, cov, v, pdv,
                       tol, max_iter, print_iter, "", eps, &fail);
      if (fail.code != NE_NOERROR) {
        printf("Error from nag_glm_binomial (g02gbc).\n%s\n",
               fail.message);
        exit_status = 1;
        goto END;

      }

      /* Call g02gpc to predict the response for the observations in the */
      /* validation dataset */
      /* We want to start passing X and T at the (NT+1)th observation, */
      /* These start at (i,j)=(nt+1,1), hence the (nt*pdx+0)th element */
      /* of X and the nt'th element of T */
      nag_glm_predict(errfn,link,mean,nv,&x[nt*pdx],pdx,m,isx,ip,&t[nt],
                      offset,wt,scale,ex_power,b,cov,vfobs,eta,seeta,pred,
                      sepred,&fail);
      if (fail.code != NE_NOERROR)
        {
          printf("Error from nag_glm_predict (g02gpc).\n%s\n",
                 fail.message);
          exit_status = 1;
          goto END;
        }

      /* Count the true/false positives/negatives */
      for (i = 0; i < nv; i++)
        {
          obs_val = (Integer) y[nt+i];
          pred_val = (pred[i] >= 0.5 ? 1 : 0);
          if (obs_val)
            {
              /* Positive */
              if (pred_val)
                {
                  /* True positive */
                  tp++;
                }
              else
                {
                  /* False Negative */
                  fn++;
                }
            }
          else
            {
              /* Negative */
              if (pred_val)
                {
                  /* False positive */
                  fp++;
                }
              else
                {
                  /* True negative */
                  tn++;
                }
            }
        }
    }

  /* Display results */
  np = tp + fn;
  nn = fp + tn;
  printf("                       Observed\n");
  printf("             --------------------------\n");
  printf(" Predicted | Negative  Positive   Total\n");
  printf(" --------------------------------------\n");
  printf(" Negative  | %5ld     %5ld     %5ld\n",
         tn, fn, tn + fn);
  printf(" Positive  | %5ld     %5ld     %5ld\n",
         fp, tp, fp + tp);
  printf(" Total     | %5ld     %5ld     %5ld\n",
         nn, np, nn + np);
  printf("\n");

  if (np != 0)
    {
      printf(" True Positive Rate (Sensitivity): %4.2f\n",
             (double) tp / (double) np);
    }
  else
    {
      printf(" True Positive Rate (Sensitivity): No positives in data\n");
    }
  if (nn != 0)
    {
      printf(" True Negative Rate (Specificity): %4.2f\n",
             (double) tn / (double) nn);
    }
  else
    {
      printf(" True Negative Rate (Specificity): No negatives in data\n");
    }

 END:

  NAG_FREE(isx);
  NAG_FREE(state);
  NAG_FREE(b);
  NAG_FREE(cov);
  NAG_FREE(eta);
  NAG_FREE(pred);
  NAG_FREE(se);
  NAG_FREE(seeta);
  NAG_FREE(sepred);
  NAG_FREE(v);
  NAG_FREE(offset);
  NAG_FREE(wt);
  NAG_FREE(x);
  NAG_FREE(y);
  NAG_FREE(t);

  return(exit_status);
}