###############################################################################
###                                                                         ###
### This file contains the R-functions used in testing the significance of  ###
### a gene pathway using sLDA.  The main functions are                      ###
### SLDA.PathwayTest.BIC1 and SLDA.PathwayTest.  Both functions take the    ###
### following input:                                                        ###
###                                                                         ###
###    X: an n x p matrix of gene expression values where n is the sample   ###
###       size and p is the number of genes in the gene set/pathway         ###
###                                                                         ###
###    y: an n-vector of class labels (0 or 1)                              ###
###                                                                         ###
###    nperms: the number of permutations used to estimate significance     ###
###                                                                         ###
###    fracs: grid of fractions (between 0 and 1) of the sLDA coefficient   ###
###           path at which the loadings are evaluated and compared         ###
###                                                                         ###
###    lambda2: a constant to be added to the diagonal of the estimated     ###
###             covariance matrix                                           ###
###                                                                         ###
###    nfold: (SLDA.PathwayTest only) the number of cross-validation folds  ###
###                                                                         ###
### Both functions return a list with elements:                             ###
###    p.value: a p-value for significance                                  ###
###                                                                         ###
###    stat: the pathway score, i.e. the absolute t-statistic measuring     ###
###          differential expression of the sLDA-weighted combination of    ###
###          the gene expression values (X %*% loadings)                    ###
###                                                                         ###
###    loadings: the weights from the sLDA algorithm                        ###
###                                                                         ###
###    statvec: the vector of the permutation t-statistics                  ###
###              (only useful for finding p.value)                          ###
###                                                                         ###
###                                                                         ###
###############################################################################
###############################################################################
###                                                                         ###
### DATE: last modified 08.10.2009   version 0.0                            ###
###                                                                         ###
###############################################################################

## SLDA.PathwayTest.BIC1
## Permutation test for the significance of a gene set, with the sLDA
## sparsity level chosen by the BIC-type criterion in infunction.BIC().
##
## Args:
##   X:       n x p matrix of expression values (rows = samples).
##   y:       length-n vector of class labels (0 or 1).
##   nperms:  number of label permutations used to build the null.
##   fracs:   grid of path fractions handed to infunction.BIC().
##   lambda2: constant added to the diagonal of the covariance estimate.
##
## Returns a list with
##   stat:     observed pathway score (absolute t-statistic),
##   p.value:  proportion of permutation scores >= the observed score,
##   loadings: sLDA weights fitted on the observed labels,
##   statvec:  the nperms permutation scores.
##
## A permutation whose fit fails is re-drawn; after `max.retries`
## consecutive failures the function stops instead of looping forever.
SLDA.PathwayTest.BIC1 <- function(X, y, nperms = 100, fracs = 0:500/500,
                                  lambda2 = 0.1) {
  # genefilter is still attached for backward compatibility (the second
  # call tries a private library under ~/Rlib).
  try(library(genefilter))
  try(library(genefilter, lib.loc = "~/Rlib"))
  n <- length(y)
  max.retries <- 1000L

  print("Starting Test...")
  observed <- infunction.BIC(X, y, fracs, lambda2)

  print("Start Permuting...")

  statvec <- rep(NA_real_, nperms)
  # seq_len() so that nperms = 0 runs no permutations (seq(0) is c(1, 0)).
  for (perm in seq_len(nperms)) {
    if (perm %% 10 == 0) print(perm)
    tries <- 0L
    repeat {
      y1 <- sample(y, n, replace = FALSE)
      fit <- try(infunction.BIC(X, y1, fracs, lambda2), silent = TRUE)
      tries <- tries + 1L
      if (!inherits(fit, "try-error")) break
      if (tries >= max.retries) {
        stop(sprintf(
          "infunction.BIC() failed on %d consecutive label permutations: %s",
          max.retries, conditionMessage(attr(fit, "condition"))
        ), call. = FALSE)
      }
    }
    statvec[perm] <- fit$pathscore
  }

  list(stat = observed$pathscore,
       p.value = mean(statvec >= observed$pathscore),
       loadings = observed$loadings,
       statvec = statvec)
}


## infunction.BIC
## Fits the sLDA path on (X, y) and picks the point on the `fracs` grid that
## minimises a BIC-type criterion, then scores that projection.
##
## Args:
##   X, y:    expression matrix (rows = samples) and 0/1 labels.
##   fracs:   grid of path fractions passed to SLDA.coefs().
##   lambda2: ridge constant passed to SLDApath().
##
## Returns list(pathscore, loadings): the absolute pooled-variance
## t-statistic of X %*% loadings between the classes, and the loadings.
infunction.BIC <- function(X, y, fracs, lambda2) {
  n <- length(y)
  n1 <- sum(y)
  n2 <- n - n1

  path <- SLDApath(X, y, lambda2)
  w <- SLDA.coefs(path, fracs = fracs, meaningful.range = TRUE,
                  plotpath = FALSE)

  # Pooled within-class scatter.  It is not divided by n: that would only
  # add a constant to log(numerator) and cannot change the argmin.
  S <- (n1 - 1) * cov(X[y == 1, , drop = FALSE]) +
    (n2 - 1) * cov(X[y == 0, , drop = FALSE])
  # Number of non-zero loadings at each grid point.
  ks <- rowSums(w != 0)

  # Row-wise quadratic forms w_i' S w_i, equal to diag(w %*% S %*% t(w))
  # but without forming the length(fracs) x length(fracs) product.
  numerator <- rowSums((w %*% S) * w)
  BIC <- log(numerator / (n - ks - 1)) + ks * log(n) / n

  coefs <- as.vector(w[which.min(BIC), ])

  list(pathscore = abs(t.test((X %*% coefs) ~ y,
                              var.equal = TRUE)$statistic),
       loadings = coefs)
}



## SLDA.PathwayTest
## Permutation test for the significance of a gene set, with the sLDA
## sparsity level chosen by nfold cross-validation in infunction().
##
## Args:
##   X:       n x p matrix of expression values (rows = samples).
##   y:       length-n vector of class labels (0 or 1).
##   nperms:  number of label permutations used to build the null.
##   fracs:   grid of path fractions handed to infunction().
##   lambda2: constant added to the diagonal of the covariance estimate.
##   nfold:   number of cross-validation folds.
##
## Returns a list with
##   stat:     observed pathway score (absolute t-statistic),
##   p.value:  proportion of permutation scores >= the observed score,
##   loadings: sLDA weights fitted on the observed labels,
##   statvec:  the nperms permutation scores.
##
## The observed fit is attempted up to 6 times (the folds are random); if
## all attempts fail the function stops with the last error message.  A
## failing permutation is re-drawn, up to `max.retries` times in a row.
SLDA.PathwayTest <- function(X, y, nperms = 100, fracs = 0:500/500,
                             lambda2 = 0.1, nfold = 4) {
  # genefilter provides colttests(), used by infunction().
  try(library(genefilter))
  try(library(genefilter, lib.loc = "~/Rlib"))
  n <- length(y)
  max.retries <- 1000L

  print("Starting Test...")
  observed <- try(infunction(X, y, fracs, lambda2, nfold), silent = FALSE)
  retry <- 0
  while (inherits(observed, "try-error") && retry < 5) {
    retry <- retry + 1
    print(retry)
    observed <- try(infunction(X, y, fracs, lambda2, nfold), silent = FALSE)
  }
  if (inherits(observed, "try-error")) {
    stop("infunction() failed on the observed labels after ", retry + 1,
         " attempts: ", conditionMessage(attr(observed, "condition")),
         call. = FALSE)
  }
  print("Start Permuting...")

  statvec <- rep(NA_real_, nperms)
  # seq_len() so that nperms = 0 runs no permutations (seq(0) is c(1, 0)).
  for (perm in seq_len(nperms)) {
    if (perm %% 10 == 0) print(perm)
    tries <- 0L
    repeat {
      y1 <- sample(y, n, replace = FALSE)
      fit <- try(infunction(X, y1, fracs, lambda2, nfold), silent = TRUE)
      tries <- tries + 1L
      if (!inherits(fit, "try-error")) break
      if (tries >= max.retries) {
        stop(sprintf(
          "infunction() failed on %d consecutive label permutations: %s",
          max.retries, conditionMessage(attr(fit, "condition"))
        ), call. = FALSE)
      }
    }
    statvec[perm] <- fit$pathscore
  }

  list(stat = observed$pathscore,
       p.value = mean(statvec >= observed$pathscore),
       loadings = observed$loadings,
       statvec = statvec)
}




###################################################
###################################################
##### THE FOLLOWING ARE JUST HELPER FUNCTIONS  ####
###################################################
###################################################

## infunction
## Cross-validated choice of the sLDA path fraction, followed by a refit on
## all samples.
##
## Args:
##   X, y:    expression matrix (rows = samples) and 0/1 labels.
##   fracs:   grid of path fractions passed to SLDA.coefs().
##   lambda2: ridge constant passed to SLDApath().
##   nfold:   number of class-stratified folds.
##
## For each fold the path is fitted on the remaining samples, and every
## candidate fraction is scored by the absolute t-statistic
## (genefilter::colttests) of the held-out projected data.  The fraction
## with the largest mean score is then refitted on all samples.
##
## Returns list(pathscore, loadings): the absolute Welch t-statistic
## (t.test default) of X %*% loadings between the classes, and the loadings.
##
## Fold membership is drawn with sample(), so results depend on the RNG
## state.
infunction <- function(X, y, fracs, lambda2, nfold) {
  case.rows <- which(y == 1)
  ctrl.rows <- which(y == 0)
  n1 <- length(case.rows)
  n2 <- length(ctrl.rows)
  if (min(n1, n2) < nfold) {
    stop("each class needs at least nfold samples for cross-validation",
         call. = FALSE)
  }

  # Shuffle each class's ROW NUMBERS in X and deal them round-robin into
  # nfold groups, so every fold is stratified by class and the groups are
  # disjoint.
  case.folds <- split(case.rows[sample.int(n1)],
                      rep(seq_len(nfold), length.out = n1))
  ctrl.folds <- split(ctrl.rows[sample.int(n2)],
                      rep(seq_len(nfold), length.out = n2))

  # tstats[f, k]: held-out |t| of fraction fracs[k] in fold f.
  tstats <- matrix(NA_real_, nrow = nfold, ncol = length(fracs))
  for (f in seq_len(nfold)) {
    held.out <- c(case.folds[[f]], ctrl.folds[[f]])
    fold.path <- SLDApath(X[-held.out, , drop = FALSE], y[-held.out], lambda2)
    fold.coefs <- SLDA.coefs(fold.path, fracs = fracs,
                             meaningful.range = TRUE, plotpath = FALSE)
    projected <- X[held.out, , drop = FALSE] %*% t(fold.coefs)
    tstats[f, ] <- abs(colttests(projected, as.factor(y[held.out]),
                                 tstatOnly = TRUE)$statistic)
  }
  best.frac <- fracs[which.max(colMeans(tstats))]

  full.path <- SLDApath(X, y, lambda2)
  coefs <- SLDA.coefs(full.path, fracs = best.frac, meaningful.range = TRUE,
                      plotpath = FALSE)

  list(pathscore = abs(t.test((X %*% t(coefs)) ~ y)$statistic),
       loadings = as.vector(coefs))
}



## myLDA
## Ridge-regularised Fisher LDA direction: solves S w = xbar1 - xbar0, where
## S is the pooled within-class covariance divided by n plus lambda2 * I.
##
## Args:
##   X:       n x p data matrix (rows = samples).
##   y:       0/1 class labels.
##   lambda2: ridge constant added to the diagonal of S.
##
## Returns the length-p vector w (not normalised).
myLDA <- function(X, y, lambda2 = 0) {
  n1 <- sum(y)
  n2 <- sum(1 - y)
  n <- n1 + n2
  # drop = FALSE keeps a matrix even when a class has a single row.
  X1 <- X[which(y == 1), , drop = FALSE]
  X0 <- X[which(y == 0), , drop = FALSE]
  g <- colMeans(X1) - colMeans(X0)
  p <- length(g)
  S <- ((n1 - 1) * cov(X1) + (n2 - 1) * cov(X0)) / n + diag(p) * lambda2
  # solve(S, g) solves the system directly instead of forming solve(S).
  as.vector(solve(S, g))
}

SLDApath = function(X,y,lambda2 = 0) {
## Computes the breakpoints of a piecewise-linear path of sparse LDA
## loadings, maintaining an active set A of variables that are added or
## removed one at a time.
## NOTE(review): looks like a LARS-style homotopy for minimising w'Sw
## subject to g'w = 1 with an L1 constraint on w -- TODO confirm.
##
## X = data matrix (n x p)
## y = class vector ( 0s and 1s only)
## lambda2 = ridge constant added to the diagonal of the pooled covariance
##
## Returns a matrix with p columns: a first row of zeros, then the initial
## single-variable solution, then one row of loadings w per step of the
## main loop.  Along the whole path g'w stays equal to 1, where g is the
## difference of class means: it starts at 1 (w[A] = 1/g[A]) and every
## direction gamma satisfies g[A]'gamma[A] = 0.
##
## Stops with an error if the loop runs more than p*p/2 steps.
## NOTE(review): assumes p >= 2 (the initialisation picks two variables);
## when p == 2 the loop is skipped and W holds only the initial row.

n1 = sum(y)
n2 = sum(1-y)
n = n1+n2
# Class means; xbar (overall mean) is computed but not used below.
xbar1 = apply(X[which(y==1),], 2, mean)
xbar2= apply(X[which(y==0),], 2, mean)
xbar = apply(X, 2, mean)
g = xbar1-xbar2
p = length(g)
# Pooled within-class covariance scaled by 1/n, plus a ridge on the diagonal.
S =((n1-1)*cov(X[y==1,])+(n2-1)*cov(X[y==0,]))/n + diag(p)*lambda2

##################
### Initialize ###
##################

# Add first variable
# The variable with the smallest S_jj / g_j^2 (largest standardised mean
# difference); w is scaled so that g'w = 1.
A = which.min(diag(S)/g**2)
w = rep(0,p)
w[A] = 1/g[A]

# Add Second variable
# v has length 2p: entries 1..p are candidate values for a positive-side
# match, p+1..2p for a negative-side match; an index J > p is mapped back
# to variable J - p.  The chosen value is kept as the scalar v.
Saw = as.numeric(S[A,]%*%w)
v = c(2*(Saw - S%*%w)/(g - g[A]), 2*(-S%*%w - Saw)/(g + g[A]))
J = which.max(abs(2*Saw+v*g[A]))
v = v[J]
if (J>p) J = J-p
A = c(A, J)
# W accumulates one row per recorded point on the path.
W = w

# Only walk the path if some variable is still inactive.
keepgoing = (length(A) != p)
cnt = 2
while(keepgoing) {
###################### Get new dir ######################
# xi = -sign(2*S[A,]%*%w + v*g[A]) over the active set
# (NOTE(review): presumably the signs of the active loadings).
xi = -sign(2*S[A,]%*%w + v*g[A])
a = length(A)
# Linear system for (gamma[A], dv, extra multiplier):
#   2*S[A,A] gamma + dv*g[A] + m*xi = 0
#   g[A]' gamma = 0   (keeps g'w = 1)
#   xi' gamma   = 1
# sol[a+2] (the extra multiplier) is not used afterwards.
C = matrix(0, nrow = (a+2), ncol = (a+2))
b = rep(0,(a+2))
b[a+2] = 1

C[(1:a), (1:a)] = 2*S[A,A]
C[(1:a),(a+1)] = g[A]
C[(1:a), (a+2)] = xi
C[(a+1), (1:a)] = g[A]
C[(a+2), (1:a)] = xi

sol = solve(C)%*%b
gamma = rep(0,p)
gamma[A]= sol[1:a]
dv = sol[a+1]
#####################
# Find which to add #
#####################

# For each variable j, dpos / dneg is the step d at which
# 2*S[j,]%*%w + v*g[j] (moving along gamma, dv) becomes equal to / the
# negative of the same quantity for the first active variable a1.
a1 = A[1]
dpos = (2*t(S - S[,a1])%*%w+v*(g-g[a1])) / (2*t(S[,a1]-S)%*%gamma + dv*(g[a1]-g))
dneg = (2*t(S + S[,a1])%*%w +v*(g+g[a1])) / (-2*t(S[,a1]+S)%*%gamma -dv*(g+g[a1]))
D= cbind(dpos, dneg)
# Active variables are excluded from the candidates to add.
D[A,]= -10

# Smallest positive step; %%p maps the column-major index in the p x 2
# matrix D back to a variable index (0 means variable p).
D1 = which(D>0)
addvar = D1[which.min(D[D1])]%%p
if (length(addvar)>0 && addvar==0) addvar=p
if (length(addvar) ==0){
  d1 = Inf
} else {
  d1 = min(D[D1])
}
######################
# Find which to kill #
######################
# Sentinel: if no active loading crosses zero, d2 stays above d1 so a kill
# is not chosen.
# NOTE(review): if d1 is Inf and D3 is empty, d2 is Inf too and the kill
# branch below may run with killvar undefined, which raises an error.
d2 = d1+100
# Step at which each loading w[j] + d*gamma[j] reaches zero.
D2 = -w/gamma
D3 = which(D2>0)
if (length(D3)>0) {
  killvar = D3[which.min(D2[D3])]
  d2 = min(D2[D3])
}

######################
# Distance Till Stop #
######################

# Step at which 2*S[a1,]%*%w + v*g[a1] reaches zero.
d3 = (-2*S[a1,]%*%w - v*g[a1])/(2*S[a1,]%*%gamma + dv*g[a1])

######################
# Choose Kill or Add #
######################

# Take the nearest event: stop (d3), add a variable (d1) or drop one (d2).
if (d3>0 & d3 <= d1 & d3<=d2) {
  d = d3
  keepgoing = F
} else {
  if (d1 <d2) {
    d = d1
    A = c(A, addvar)
  } else {
    d = d2
    A = A[-which(A==killvar)]
  }
}

# Move to the event and record the new breakpoint.
v = v+d*dv
w = w+d*gamma
W = rbind(W,w)
#keepgoing = (length(A)!= p)
 cnt = cnt+1
  # Guard against a path that never terminates.
  if (cnt > p*p/2) {
   stop("Uh oh Infinite Loop in SLDA!!!")

 }
}

# Prepend the all-zero starting point.
return(rbind(0,W))
}


## SLDA.coefs
## Linearly interpolates a loadings path (as returned by SLDApath(), whose
## first row is all zeros) at the L1-norm values implied by `fracs`, and
## optionally plots the path.
##
## Args:
##   W:                path matrix, one row per breakpoint.
##   fracs:            fractions of the L1-norm range at which to evaluate.
##   meaningful.range: if FALSE, fraction 0 maps to L1 norm 0 and fraction 1
##                     to the L1 norm of the last row; if TRUE, fraction 0
##                     maps to the L1 norm of row 2 instead.
##   plotpath:         when TRUE, draw every coefficient path scaled by the
##                     final L1 norm.
##   plotmain:         plot title.
##
## Returns a length(fracs) x ncol(W) matrix.  Row k holds the interpolated
## loadings at the k-th target L1 norm, or NA if that norm falls outside
## every path segment.
SLDA.coefs <- function(W, fracs = 1:50/50, meaningful.range = FALSE,
                       plotpath = FALSE, plotmain = "SLDA Path") {
  p <- ncol(W)
  r <- nrow(W)
  bigsum <- sum(abs(W[r, ]))
  if (!meaningful.range) {
    s <- as.vector(fracs * bigsum)
  } else {
    first <- sum(abs(W[2, ]))
    s <- first + as.vector(fracs * (bigsum - first))
  }
  # L1 norm of every breakpoint.
  sums <- apply(abs(W), 1, sum)

  if (plotpath) {
    labels <- as.character(seq_len(p))
    if (!meaningful.range) {
      xvec <- sums
      yvec <- W / bigsum
    } else {
      xvec <- sums[-1]
      yvec <- W[-1, , drop = FALSE] / bigsum
    }
    ylim <- c(min(yvec), max(yvec))
    plot(xvec, yvec[, 1], ylim = ylim, col = 1, lty = 1, type = "l",
         main = plotmain, xlab = "sum|beta|", ylab = "beta")
    # seq_len(p)[-1] is empty when p == 1 (2:p would be c(2, 1)).
    for (i in seq_len(p)[-1]) {
      lines(xvec, yvec[, i], col = i, lty = i, type = "l")
    }
    xlim <- c(min(xvec), max(xvec))
    points(rep(xvec[nrow(yvec)] + .001 * (xlim[2] - xlim[1]), p),
           yvec[nrow(yvec), ], pch = labels)
  }

  # Segment x runs from breakpoint x to breakpoint x + 1.  drop = FALSE keeps
  # a one-segment path (r == 2) as a 1-row matrix.
  lo <- sums[-r]
  hi <- sums[-1]
  slopes <- (W[-1, , drop = FALSE] - W[-r, , drop = FALSE]) / (hi - lo)

  coefs <- matrix(nrow = length(s), ncol = p)
  # Targets are matched per segment inside the loop, so each segment gets
  # its own index set.  (apply() would collapse equal-sized index sets into
  # a matrix and misalign segments.)  A target on a breakpoint is
  # overwritten by the later segment, which gives the same value there.
  for (x in seq_len(r - 1)) {
    idx <- which(s >= lo[x] & s <= hi[x])
    if (length(idx) > 0) {
      coefs[idx, ] <- rep(W[x, ], each = length(idx)) +
        outer(s[idx] - sums[x], slopes[x, ])
    }
  }
  coefs
}

## SLDA.predict
## Classifies observations with a Gaussian model fitted to each sLDA
## discriminant score.
##
## Args:
##   coefs: matrix of loadings, one row per discriminant (e.g. rows of
##          SLDA.coefs() output); scores are X %*% t(coefs).
##   X, y:  training data and 0/1 labels.  They are used to estimate the
##          class priors, the class means and the pooled variance of every
##          score.
##   Xnew:  observations to classify (defaults to X).
##
## Returns list(class, case.prob), each shaped like Xnew %*% t(coefs):
## the 0/1 prediction (ties go to class 1) and the estimated probability of
## class 1 for every (observation, discriminant) pair.
SLDA.predict <- function(coefs, X, y, Xnew = X) {
  p1 <- sum(y) / length(y)
  p2 <- 1 - p1
  L1 <- as.matrix(X %*% t(coefs))
  L <- as.matrix(Xnew %*% t(coefs))
  # drop = FALSE keeps one column per discriminant even if a class has a
  # single training observation.
  L1.case <- L1[y == 1, , drop = FALSE]
  L1.ctrl <- L1[y == 0, , drop = FALSE]
  m1 <- apply(L1.case, 2, mean)
  m2 <- apply(L1.ctrl, 2, mean)
  var1 <- apply(L1.case, 2, var)
  var2 <- apply(L1.ctrl, 2, var)
  sds <- sqrt(p1 * var1 + p2 * var2)

  # Expand the per-discriminant parameters to L's shape, so that column j
  # of L is evaluated with m1[j], m2[j] and sds[j].  Plain recycling of a
  # length-K vector would pair parameters with L's elements in column-major
  # order, mixing up discriminants whenever K > 1.
  by.col <- function(v) rep(v, each = nrow(L))
  # NOTE(review): 10e-10 equals 1e-9; kept unchanged.
  d1 <- p1 * dnorm(L, by.col(m1), by.col(sds)) + 10e-10
  d2 <- p2 * dnorm(L, by.col(m2), by.col(sds)) + 10e-10
  d <- d1 + d2
  d1 <- d1 / d
  d2 <- 1 - d1

  list(class = ((d1 - d2) >= 0) + 0, case.prob = d1)
}

## SLDA.score
## Absolute Welch t-statistic between the classes for each discriminant
## score Xnew %*% t(coefs).
##
## Args:
##   coefs: matrix of loadings, one row per discriminant.
##   ynew:  0/1 labels for the rows of Xnew.
##   Xnew:  data matrix (rows = samples).
##          NOTE(review): the default `X` is not an argument of this
##          function, so it is looked up in the enclosing (typically
##          global) environment.  Pass Xnew explicitly.
##
## Returns a numeric vector with one |t| per row of coefs.
SLDA.score <- function(coefs, ynew, Xnew = X) {
  Xn <- Xnew %*% t(coefs)
  # drop = FALSE keeps a matrix when coefs has a single row.
  Xn1 <- Xn[ynew == 1, , drop = FALSE]
  Xn2 <- Xn[ynew == 0, , drop = FALSE]
  vapply(seq_len(ncol(Xn)),
         function(j) abs(t.test(Xn1[, j], Xn2[, j])$statistic),
         numeric(1))
}

## criteria.SLDA
## Information criteria for a single sLDA loading vector.  The projected
## score X %*% coef gets a two-class Gaussian model (class-specific means,
## pooled variance, empirical priors).  The conditional log-likelihood of
## the labels is then penalised by the number of non-zero loadings.
##
## Args:
##   X:    data matrix (rows = samples).
##   y:    0/1 class labels.
##   coef: loading vector.
##   eps:  loadings with |coef| <= eps are not counted as active.
##
## Returns list(BIC, AIC).
criteria.SLDA <- function(X, y, coef, eps = 1e-10) {
  n_active <- sum(abs(coef) > eps)

  n_obs <- length(y)
  is_case <- y == 1
  is_ctrl <- y == 0
  prior_case <- sum(y) / n_obs
  prior_ctrl <- sum(1 - y) / n_obs

  score <- X %*% coef
  mu_case <- mean(score[is_case])
  mu_ctrl <- mean(score[is_ctrl])
  pooled_sd <- sqrt(prior_case * var(score[is_case]) +
                      prior_ctrl * var(score[is_ctrl]))

  # Prior-weighted class densities, floored by a small constant so the
  # logs below stay finite.
  dens_case <- prior_case * dnorm(score, mu_case, pooled_sd) + 10e-10
  dens_ctrl <- prior_ctrl * dnorm(score, mu_ctrl, pooled_sd) + 10e-10

  log_lik <- sum(log(dens_case)[is_case]) +
    sum(log(dens_ctrl)[is_ctrl]) -
    sum(log(dens_case + dens_ctrl))

  deviance <- -2 * log_lik
  list(BIC = deviance + log(n_obs) * n_active,
       AIC = deviance + 2 * n_active)
}

