/*
 *  IMAL.c
 *
 *  Created by Wei Sun on Sun Mar 8th 2009.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <string.h>
#include <R.h>
#include <Rmath.h>
#include "utility.h"
#include "IMAL.h"

/*********************************************************************
 *
 * IMALr
 *
 * The Iterative Adaptive Lasso, with X input from R
 *
 *********************************************************************/

void IMALr(double* Ry, double* RX, double* RB,
  double* Rdelta, double* Rtau, int* dims, double* Repsilon, int* nIter, 
  int* b_update_order, double* Rscore, int* RscoNA, double* score2use, 
  double* delta2use, double* tau2use)
{
  /*
   * Entry point for R: the covariate matrix arrives as a plain vector.
   * The three flat vectors are wrapped as row-pointer matrices with reorg()
   * and the work is delegated to IMAL(); every other argument is passed
   * through unchanged.
   */
  const int nObs  = dims[0];   /* length of each response / rows of X  */
  const int nCov  = dims[1];   /* number of covariates (columns of X)   */
  const int nResp = dims[2];   /* number of response variables         */
  double **Ymat;               /* nResp x nObs : one response per row   */
  double **Xmat;               /* nObs  x nCov : one covariate per col  */
  double **Bmat;               /* nCov  x nResp: coefficient matrix     */

  reorg(Ry, &Ymat, nResp, nObs);
  reorg(RX, &Xmat, nObs, nCov);
  reorg(RB, &Bmat, nCov, nResp);

  IMAL(Ymat, Xmat, Bmat, Rdelta, Rtau, dims, Repsilon, nIter,
       b_update_order, Rscore, RscoNA, score2use, delta2use, tau2use);
}


/*********************************************************************
 *
 * IMALc
 *
 * The Iterative Adaptive Lasso, with X input from text file
 *
 *********************************************************************/

void IMALc(double* Ry, char** fname, double* RB, 
  double* Rdelta, double* Rtau, int* dims, double* Repsilon, int* nIter, 
  int* b_update_order, double* Rscore, int* RscoNA, double* score2use, 
  double* delta2use, double* tau2use)
{
  /*
   * Entry point for R when the n x p covariate matrix X is stored in a
   * text file (fname[0]) instead of being passed from R.
   *
   * dims[0]=n, dims[1]=p, dims[2]=h; dims[5], dims[6] and dims[7] are
   * forwarded to readtext() as offsetrow, offsetcol and transposeX.
   * All remaining arguments are passed through to IMAL().
   */
  double **y, **X, **B;
  int i;
  int n = dims[0];
  int p = dims[1];
  int h = dims[2];
  int offsetrow  = dims[5];
  int offsetcol  = dims[6];
  int transposeX = dims[7];
  size_t nX;

  /* 
   * reorganize vector into matrix 
   * NOTE, each row of y is one response variable
   * and each column of X is a covariate
   */
  reorg(Ry, &y, h, n);
  reorg(RB, &B, p, h);

  /*
   * Validate before X is allocated: error() longjmps back to R, so any
   * check that fails after the malloc below would leak X.
   */
  if(n < 1 || p < 1 || h < 1){
    error("invalid dimensions: n=%d, p=%d, h=%d\n", n, p, h);
  }
  if(*b_update_order != 1 && *b_update_order != 2){
    error("invalid b_update_order\n");
  }

  /*
   * allocate memory: n row pointers into one contiguous n x p block.
   * The element count is computed in size_t because n*p in int overflows
   * for large covariate files.
   */
  nX = (size_t)n * (size_t)p;
  X = (double **) malloc((size_t)n * sizeof(double*));
  if(X == NULL){
    error("fail to allocate memory of size %d\n", n);
  }
  X[0] = (double *) calloc(nX, sizeof(double));
  if(X[0] == NULL){
    free(X);  /* error() does not return, so release X first */
    error("fail to allocate memory of size %.0f\n", (double)nX);
  }
  for(i=1; i<n; i++){
    X[i] = X[0] + (size_t)i * (size_t)p;
  }
  
  /*
   * read in data into matrix.
   * NOTE(review): if readtext() or IMAL() raises an R error, X is not
   * freed; allocating it with R_alloc() would let R reclaim it.
   */
  readtext(X, fname[0], n, p, offsetrow, offsetcol, transposeX);

  IMAL(y, X, B, Rdelta, Rtau, dims, Repsilon, nIter, b_update_order, 
  Rscore, RscoNA, score2use, delta2use, tau2use);
  
  free(X[0]);
  free(X);
}


/*********************************************************************
 *
 * IMAL
 *
 * Iterative Multivariate Adaptive Lasso
 *
 *********************************************************************/

static double **imal_matrix_alloc(int nrow, int ncol);
static void imal_matrix_free(double **m);
static void imal_center_columns(double **X, int n, int p, double *Xj2);
static void imal_update_covariate(double **X, double **res, double **B,
  int j, int n, int h, const double *Xj2, const double *sigma2,
  double kappa_j, double *bj_bar, double *b_norm);
static int imal_coef_changed(double **B, double **B1, int p, int h,
  double epsilon);
static double imal_BIC(double **res, const double *sigma2, double **resCov,
  const int *w_b2use, int n_b2use, const double *b_norm,
  const double *kappa, const double *Xj2, int n, int h);

/*
 * Fit the model for every (delta, tau) pair of the grid Rdelta x Rtau and
 * keep the fit with the smallest BIC.
 *
 * y       h x n responses (row s holds response s).
 * X       n x p covariates (column j holds covariate j).
 *         NOTE: every column is mean-centred IN PLACE, so the caller's X
 *         is modified.
 * RB      p x h coefficients. The first grid point starts from the values
 *         passed in. Each later grid point starts from the current contents
 *         of RB, i.e. the best fit found so far once one exists. On return,
 *         RB holds the fit with the smallest BIC; it is unchanged if no grid
 *         point was scored.
 * Rdelta, Rtau   grids of length n_delta and n_tau.
 * dims    dims[0]=n, dims[1]=p, dims[2]=h, dims[3]=L, dims[4]=Nmax,
 *         dims[8]=n_delta, dims[9]=n_tau, dims[10]=p_max
 *         (dims[5..7] are only used by IMALc).
 *         One grid point stops after L consecutive sweeps in which no
 *         coefficient moved by more than *Repsilon, or after Nmax sweeps.
 * nIter   output: sweep counter at exit for the selected fit. A value of
 *         Nmax+1 means the Nmax limit was hit before the convergence rule
 *         was met.
 * b_update_order  1 = update covariates in index order each sweep;
 *                 2 = reorder them with rsample() before every sweep
 *                 (presumably a random permutation; see utility.h).
 * Rscore, RscoNA  n_delta x n_tau. For each scored grid point, the BIC is
 *         stored in Rscore and RscoNA is set to 0. Grid points that select
 *         no covariate, or at least p_max covariates, are skipped and their
 *         entries are left untouched.
 * score2use, delta2use, tau2use  output: the smallest BIC and its delta and
 *         tau. *score2use stays DBL_MAX if no grid point was scored.
 *
 * error() is raised for an invalid b_update_order, for non-positive n/p/h,
 * or when allocation fails. The checks run before X is modified, and no
 * buffer allocated here is leaked on those paths.
 */
void IMAL(double** y, double** X, double** RB, double* Rdelta, double* Rtau, 
  int* dims, double* Repsilon, int* nIter, int* b_update_order, double* Rscore, 
  int* RscoNA, double* score2use, double* delta2use, double* tau2use)
{
  int i, j, s, n, p, h, L, Nmax, n_delta, n_tau, p_max, **scoNA;
  int k1, k2, k, w, jk, n_b2use, changed, alloc_failed;
  int *Jset = NULL, *w_b2use = NULL;
  double *sigma2 = NULL, *kappa = NULL, *Xj2 = NULL, *bj_bar = NULL;
  double *b_norm = NULL;
  double **B = NULL, **B1 = NULL, **res = NULL, **resCov = NULL;
  double **score, *res_s, B_js, b0, epsilon, delta1, delta2, tau1, BIC;

  n       = dims[0];
  p       = dims[1];
  h       = dims[2];
  L       = dims[3];
  Nmax    = dims[4];
  n_delta = dims[8];
  n_tau   = dims[9];
  p_max   = dims[10];
  
  epsilon = *Repsilon;

  /*
   * Validate before anything is allocated or modified. error() longjmps
   * back to R, so raising it later would leak every buffer below and leave
   * X already centred.
   */
  if(*b_update_order != 1 && *b_update_order != 2){
    error("invalid b_update_order\n");
  }
  if(n < 1 || p < 1 || h < 1){
    error("invalid dimensions: n=%d, p=%d, h=%d\n", n, p, h);
  }

  // BIC score
  *score2use = DBL_MAX;

  /* reorganize vector into matrix */
  reorg(Rscore, &score, n_delta, n_tau);
  reorg_int(RscoNA, &scoNA, n_delta, n_tau);

  /* allocate memory; failures are reported only after everything is freed */
  sigma2  = (double *) calloc(h, sizeof(double));
  kappa   = (double *) calloc(p, sizeof(double));
  Xj2     = (double *) calloc(p, sizeof(double));
  bj_bar  = (double *) calloc(h, sizeof(double));
  b_norm  = (double *) calloc(p, sizeof(double));
  Jset    = (int *) calloc(p, sizeof(int));
  /* Every index stored in w_b2use is < p, so p entries always suffice.
     p_max entries would be too few when p_max < 1. */
  w_b2use = (int *) calloc(p, sizeof(int));
  res     = imal_matrix_alloc(h, n);  /* residuals, one row per response */
  B       = imal_matrix_alloc(p, h);  /* current coefficients            */
  B1      = imal_matrix_alloc(p, h);  /* coefficients after last sweep   */
  resCov  = imal_matrix_alloc(h, h);  /* residual covariance (scratch)   */

  alloc_failed = (sigma2 == NULL || kappa == NULL || Xj2 == NULL ||
    bj_bar == NULL || b_norm == NULL || Jset == NULL || w_b2use == NULL ||
    res == NULL || B == NULL || B1 == NULL || resCov == NULL);
  if(alloc_failed){
    goto cleanup;
  }

  /* initialize the order to update coefficients */
  for(j=0; j<p; j++){
    Jset[j] = j;
  }

  /* centre each column of X and record its sum of squares in Xj2 */
  imal_center_columns(X, n, p, Xj2);
  
  GetRNGstate();

  /**********************************************************
   * The EM algorithm
   **********************************************************/

  for(k1=0; k1 < n_delta; k1++){
    delta1 = Rdelta[k1];
    delta2 = delta1 + 1.0;
    
    for(k2=0; k2 < n_tau; k2++){
      tau1 = Rtau[k2];
      
      /**
       * step 1. Initialization
       */
      for(s=0; s<h; s++){
        sigma2[s] = var(y[s], n);
      }

      for(j=0; j<p; j++){
        kappa[j] = tau1/delta2;
      }

      /* start B (and B1) from the current contents of RB */
      for(j=0; j<p; j++){
        for(s=0; s<h; s++){
          B1[j][s] = RB[j][s];
          B[j][s]  = RB[j][s];
        }
      }

      /*
       * The coordinate updates rely on res == y - X B. RB is non-zero once
       * a grid point has been scored (warm start), so subtract the starting
       * fit. Zero coefficients are skipped, which leaves res == y exactly
       * when B starts at 0.
       */
      for(s=0; s<h; s++){
        res_s = res[s];
        for(i=0; i<n; i++){
          res_s[i] = y[s][i];
        }
        for(j=0; j<p; j++){
          B_js = B[j][s];
          if(B_js != 0.0){
            for(i=0; i<n; i++){
              res_s[i] -= X[i][j]*B_js;
            }
          }
        }
      }

      /* k counts consecutive sweeps without a change larger than epsilon */
      k = 0;

      for(w=1; w<=Nmax; w++){
        /**
         * step 3.1 choose the order of updating B[j]
         * (b_update_order was validated above; 1 keeps the current order)
         */
        if(*b_update_order == 2){
          rsample(Jset, p);
        }

        /**
         * step 3.2 Update B[j][s] for every covariate j
         */
        for(jk=0; jk<p; jk++){
          j = Jset[jk];
          imal_update_covariate(X, res, B, j, n, h, Xj2, sigma2, kappa[j],
            bj_bar, b_norm);
        }

        /**
         * step 2. Update b0: with centred X the residual mean should stay 0;
         * warn and re-centre a residual vector that drifted away from it.
         */
        for(s=0; s<h; s++){
          b0 = mean(res[s], n);
          if(fabs(b0) > 1e-5){
            warning("s=%d, intercept=%f, and it is not 0.0\n", s, b0);
            for(i=0; i<n; i++){
              res[s][i] -= b0;
            }
          }
        }
    
        /**
         * step 4 Update sigma^2 (mean squared residual per response)
         */
        for(s=0; s<h; s++){
          res_s = res[s];
          sigma2[s] = 0.0;
          for(i=0; i<n; i++){
            sigma2[s] += res_s[i]*res_s[i];
          }
          sigma2[s] /= n;
        }

        /**
         * step 5 Update kappa
         */
        for(j=0; j<p; j++){
          kappa[j] = (b_norm[j] + tau1)/delta2;
        }
    
        /**
         * check convergence
         */
        changed = imal_coef_changed(B, B1, p, h, epsilon);

        for(j=0; j<p; j++){
          for(s=0; s<h; s++){
            B1[j][s] = B[j][s];
          }
        }

        if(changed){
          k = 0;
        }else{
          k += 1;
          if(k >= L){
            break; // converged :)
          }
        }
      } // end of loop for one set of delta and tau 

      /* record the selected covariates, stopping once p_max are found */
      n_b2use = 0;
      for(j=0; j < p; j++){
        if(b_norm[j] > 1e-10){
          w_b2use[n_b2use] = j;
          n_b2use++;
          if(n_b2use >= p_max){
            break;
          }
        }
      }
      
      /**
       * ignore this combination of delta and tau 
       * if no covariate, or too many covariates, are chosen 
       */
      if(n_b2use == 0 || n_b2use >= p_max){
        continue;
      }
      
      BIC = imal_BIC(res, sigma2, resCov, w_b2use, n_b2use, b_norm, kappa,
        Xj2, n, h);
      
      if(BIC < *score2use){
        *score2use = BIC;
        *delta2use = delta1;
        *tau2use   = tau1;
        *nIter = w;
        for(j=0; j<p; j++){
          for(s=0; s<h; s++){
            RB[j][s] = B[j][s];
          }
        }
      }

      score[k1][k2] = BIC;
      scoNA[k1][k2] = 0;
    }
  }

  /* save R's RNG state, which rsample() may have advanced */
  PutRNGstate();

cleanup:
  free(sigma2);
  free(kappa);
  free(Xj2);
  free(bj_bar);
  free(b_norm);
  free(Jset);
  free(w_b2use);

  imal_matrix_free(res);
  imal_matrix_free(resCov);
  imal_matrix_free(B);
  imal_matrix_free(B1);

  if(alloc_failed){
    error("fail to allocate memory in IMAL\n");
  }
}

/*
 * Allocate a zero-filled nrow x ncol matrix. The data is one contiguous
 * row-major block addressed through nrow row pointers. Returns NULL (with
 * nothing leaked) on failure or non-positive dimensions.
 */
static double **imal_matrix_alloc(int nrow, int ncol)
{
  double **m;
  int r;

  if(nrow < 1 || ncol < 1){
    return NULL;
  }
  m = (double **) calloc((size_t)nrow, sizeof(double *));
  if(m == NULL){
    return NULL;
  }
  m[0] = (double *) calloc((size_t)nrow * (size_t)ncol, sizeof(double));
  if(m[0] == NULL){
    free(m);
    return NULL;
  }
  for(r=1; r<nrow; r++){
    m[r] = m[0] + (size_t)r * (size_t)ncol;
  }
  return m;
}

/* Release a matrix from imal_matrix_alloc(); NULL is accepted. */
static void imal_matrix_free(double **m)
{
  if(m != NULL){
    free(m[0]);
    free(m);
  }
}

/*
 * Subtract each column's mean from the n x p matrix X (in place) and store
 * the sum of squares of the centred column j in Xj2[j].
 */
static void imal_center_columns(double **X, int n, int p, double *Xj2)
{
  int i, j;
  double XjAve, xij;

  for(j=0; j<p; j++){
    XjAve = 0.0;
    for(i=0; i<n; i++){
      XjAve += X[i][j];
    }
    XjAve /= n;

    Xj2[j] = 0.0;
    for(i=0; i<n; i++){
      X[i][j] -= XjAve;
      xij      = X[i][j];
      Xj2[j]  += xij*xij;
    }
  }
}

/*
 * Update row j of B (covariate j's coefficients for all h responses) by a
 * group soft-thresholding step.
 *
 * Steps:
 *   1. Add covariate j's current effect back into res.
 *   2. Compute the per-response least-squares coefficient
 *      bj_bar[s] = <X_j, res_s> / Xj2[j].
 *   3. Compute the sigma2-weighted norm ||bj_bar||.
 *   4. If that norm is at most 1/(kappa_j*Xj2[j]), set B[j] to 0.
 *      Otherwise shrink bj_bar by (1 - 1/(kappa_j*Xj2[j]*||bj_bar||)) and
 *      subtract the new effect from res.
 *
 * On return res == y - X B still holds, and b_norm[j] holds the weighted
 * norm of the new B[j]. bj_bar is caller-provided scratch of length h.
 */
static void imal_update_covariate(double **X, double **res, double **B,
  int j, int n, int h, const double *Xj2, const double *sigma2,
  double kappa_j, double *bj_bar, double *b_norm)
{
  int i, s;
  double B_js, bj_bar_norm, *res_s;

  /* remove the effect of Xj from the residual */
  for(s=0; s<h; s++){
    B_js = B[j][s];
    for(i=0; i<n; i++){
      res[s][i] += X[i][j]*B_js;
    }
  }

  /*
   * A constant covariate has a zero sum of squares after centring, so the
   * formulas below would compute 0/0 = NaN and poison every residual.
   * Such a covariate carries no information: drop it.
   */
  if(Xj2[j] == 0.0){
    b_norm[j] = 0.0;
    for(s=0; s<h; s++){
      B[j][s] = 0.0;
    }
    return;
  }

  bj_bar_norm = 0.0;
  for(s=0; s<h; s++){
    bj_bar[s] = 0.0;
    res_s = res[s];
    for(i=0; i<n; i++){
      bj_bar[s] += X[i][j]*res_s[i];
    }
    bj_bar[s] /= Xj2[j];
    bj_bar_norm += bj_bar[s]*bj_bar[s]/sigma2[s];
  }
  bj_bar_norm = sqrt(bj_bar_norm);

  if(bj_bar_norm <= 1/kappa_j/Xj2[j]){
    b_norm[j] = 0.0;
    for(s=0; s<h; s++){
      B[j][s] = 0.0;
    }
  }else{
    b_norm[j] = bj_bar_norm - 1/kappa_j/Xj2[j];
    for(s=0; s<h; s++){
      B[j][s] = bj_bar[s]*(1 - 1/(kappa_j*Xj2[j]*bj_bar_norm));
    }

    /* add the effect of Xj back into the residual */
    for(s=0; s<h; s++){
      B_js = B[j][s];
      for(i=0; i<n; i++){
        res[s][i] -= X[i][j]*B_js;
      }
    }
  }
}

/* Return 1 if some |B1[j][s] - B[j][s]| exceeds epsilon, otherwise 0. */
static int imal_coef_changed(double **B, double **B1, int p, int h,
  double epsilon)
{
  int j, s;

  for(j=0; j<p; j++){
    for(s=0; s<h; s++){
      if(fabs(B1[j][s] - B[j][s]) > epsilon){
        return 1;
      }
    }
  }
  return 0;
}

/*
 * BIC of the current fit:
 *   n*log(det(resCov)) + dfBIC*log(n*h)
 * where:
 *   - resCov is the h x h residual covariance. Its diagonal is taken from
 *     sigma2; each off-diagonal entry is the mean cross-product of two
 *     residual rows.
 *   - dfBIC = n_b2use
 *             + sum over selected j of
 *               (h-1)*b_norm[j]/(b_norm[j] + 1/(kappa[j]*Xj2[j])).
 * resCov is scratch and is overwritten (also by determinant()).
 */
static double imal_BIC(double **res, const double *sigma2, double **resCov,
  const int *w_b2use, int n_b2use, const double *b_norm,
  const double *kappa, const double *Xj2, int n, int h)
{
  int i, j, jk, s1, s2;
  double tmp, dfBIC, detCov;

  resCov[0][0] = sigma2[0];
  for(s1=1; s1<h; s1++){
    resCov[s1][s1] = sigma2[s1];
    for(s2=0; s2<s1; s2++){
      tmp = 0.0;
      for(i=0; i<n; i++){
        tmp += res[s1][i]*res[s2][i];
      }
      tmp /= n;
      resCov[s1][s2] = tmp;
      resCov[s2][s1] = tmp;
    }
  }

  dfBIC = (double)n_b2use;
  for(j=0; j<n_b2use; j++){
    jk     = w_b2use[j];
    dfBIC += (h-1)*b_norm[jk]/(b_norm[jk] + 1/(kappa[jk]*Xj2[jk]));
  }

  // note resCov will be changed in the function determinant
  determinant(resCov, h, &detCov);
  return n*log(detCov) + dfBIC*log((double)n*h);
}
