library(GeneralisedCovarianceMeasure)
source("./utils.R")
library(dHSIC)
library(scales)


# Damped oscillation f_a(x) = 2 * exp(-x^2 / 2) * sin(a * x).
#
# x: numeric vector of evaluation points.
# a: frequency multiplier inside the sine; larger values give faster
#    oscillation under the same Gaussian-shaped envelope.
# Returns a numeric vector following the usual R recycling of x and a.
func.complex <- function(x, a){
  envelope <- 2*exp(-x^2/2)
  wave <- sin(a*x)
  envelope * wave
}

# Draw func.complex on a fine grid over [-4, 4] for a = 18 and a = 6, writing
# each curve to its own PDF. Change the condition to 0 to skip the plots.
if(1){
  sample.plots <- list(
    list(file = "experiment1-samplefunction.pdf",  a = 18, ylab = "f_{18}(z)"),
    list(file = "experiment1-samplefunction2.pdf", a = 6,  ylab = "f_{6}(z)")
  )
  xgrid <- seq(-4, 4, length.out = 5000)
  for(spec in sample.plots){
    pdf(spec[["file"]], width = 7, height = 3.5)
    plot(xgrid, func.complex(xgrid, spec[["a"]]),
         cex.lab = 1.3, cex.axis = 1.3, type = "l",
         xlab = "z", ylab = spec[["ylab"]])
    dev.off()
  }
}

# ---- Simulation settings ----
# Only "gcm" actually runs a test below; "kci" is a placeholder, so its slice
# of pvals stays NA.
methodsvec <- c("gcm", "kci")
nrep <- 100                  # repetitions per (a, n) configuration
avec <- c(1,6,12,18,24)      # frequency parameter a passed to func.complex
nvec <- c(100,1000,10000)    # sample sizes n

# pvals[method, repetition, index into avec, index into nvec]
pvals <- array(NA, dim = c(length(methodsvec), nrep, length(avec), length(nvec)))
# Pre-set the loop indices so the loop body can also be run line by line.
i <- 1; j <- 1; k <- 1

# ---- Main loop ----
# For every repetition i, frequency avec[j] and sample size nvec[k]: draw Z,
# build X and Y as func.complex(Z, a) plus independent N(0, 0.5^2) noise, and
# store the p-value of gcm.test(X, Y, Z).
for(i in seq_len(nrep)){
  show(i)
  for(j in seq_along(avec)){
    for(k in seq_along(nvec)){
      # NOTE(review): this seed is not unique per (i, j, k); e.g. for fixed k,
      # (i = 11, j = 1) and (i = 1, j = 2) give the same value, so different
      # configurations can reuse the same random draws. Repetitions within one
      # (j, k) cell do get distinct seeds. The formula is left unchanged so
      # that earlier results stay reproducible.
      set.seed(100*k + 10*j + i)

      a <- avec[j]
      n <- nvec[k]
      Z <- rnorm(n)
      X <- func.complex(Z,a) + .5*rnorm(n)
      Y <- func.complex(Z,a) + .5*rnorm(n)

      # Debugging scatter plots for small samples; never triggered with the
      # nvec above, whose smallest entry is 100.
      if(n < 100){
        par(mfrow = c(3,1));
        plot(X,Y, xlim = c(-3,3), ylim = c(-3,3));
        plot(Z,X, xlim = c(-3,3), ylim = c(-3,3));
        plot(Z,Y, xlim = c(-3,3), ylim = c(-3,3))
      }
      for(l in seq_along(methodsvec)){
        switch(methodsvec[l],
               "gcm" = {
                 pvals[l,i,j,k] <- gcm.test(X, Y, Z, regr.method = "xgboost")$p.value
               },
               "kci" = {
                 # Placeholder: no test is run for this method. Deliberately a
                 # no-op rather than overwriting the frequency variable a.
                 NULL
               })
      }
      # Progress report: number of rejections at level 0.05 so far for the
      # first method ("gcm") in this (a, n) configuration.
      show(paste(avec[j],',',nvec[k],': ', sum(pvals[1,seq_len(i),j,k] < 0.05)))
    }
  }
}



# One page of panels per value of a, one panel per sample size, showing the
# p-values stored for the first method (index 1 in pvals). plot.pvals is
# defined outside this section (presumably in utils.R -- TODO confirm); the
# panel title carries the fraction of p-values below 0.05.
n.panels <- length(nvec)
for(j in seq_along(avec)){
  par(mfrow = c(n.panels, 1))
  for(k in seq_len(n.panels)){
    cell.pvals <- pvals[1, , j, k]
    reject.share <- sum(cell.pvals < 0.05)/nrep
    plot.pvals(cell.pvals, plot.title = paste("tba", reject.share))
  }
}



# plot for paper
# For every frequency a (one group of points along the x-axis) and every
# sample size n (plotting symbol), draw the fraction of the nrep repetitions
# whose p-value for the first method (index 1 in pvals) is below 0.05.
#
# All x positions are derived from avec and nvec, so every simulated
# configuration lies inside the plotting range and gets a tick label.
# Layout: each value of a occupies length(nvec) point slots followed by one
# slot that holds the dashed separator.
pchar <- c(19, 12, 4)                   # plotting symbol per sample size
n.freq <- length(avec)
n.size <- length(nvec)
stopifnot(n.size <= length(pchar))      # one distinct symbol per sample size
group.width <- n.size + 1
x.last <- n.freq*group.width - 1        # slot of the right-most point

pdf("experiment1-results.pdf",width=11,height=5)
fact <- 1.0                             # common scaling of all x positions
par(mfrow = c(1,1))
par(mar=c(4.1,4.1,1.1,8.9))             # wide right margin for the annotation
# 0.95-quantile of a Binomial(nrep, 0.05) count, expressed as a proportion;
# used as the upper edge of the shaded band.
rejection <- qbinom(0.95, prob = 0.05, size = nrep)/nrep
plot(1, type = "n", ylim = c(0,1), xlim = c(1,x.last*fact), cex.lab = 1.2, cex.axis = 1.2,
     ylab = paste0("proportion of rejections at level 0.05 (out of ", nrep, ")"),
     xlab = "complexity of the conditional mean function",
     xaxt = "n"
    )
lines(c(0,(x.last + 1)*fact),c(rejection,rejection))
polygon(c(-0.1,-0.1,(x.last + 1)*fact,(x.last + 1)*fact),
        c(-0.1,rejection,rejection,-0.1), col = alpha("gray", 0.5))
# One tick per value of a, centred under its group of points.
tick.slots <- (seq_len(n.freq) - 1)*group.width + group.width/2
axis(1, at = tick.slots*fact, labels = paste("a =", avec), cex.axis = 1.2)
# Annotation in the right margin (xpd = TRUE allows drawing outside the plot).
arrows((x.last + 2)*fact, 0.2, x.last*fact, 0.03, lwd = 1, xpd = TRUE)
text((x.last + 2.25)*fact,0.285, "acceptance region", cex = 1.2, xpd = TRUE)
text((x.last + 2.25)*fact,0.23, "'size less than 0.05'", cex = 1.2, xpd = TRUE)
for(j in seq_len(n.freq)){
  for(k in seq_len(n.size)){
    x <- ((j - 1)*group.width + k)*fact
    points(x, sum(pvals[1,,j,k] < 0.05)/nrep, type = "p", pch = pchar[k])
  }
}
# Dashed separators between neighbouring groups.
for(j in seq_len(n.freq - 1)){
  lines(rep(j*group.width*fact, 2),c(-1,2), lty=2)
}
legend(x = 1, y = 1, title = "sample size",
       paste0(" n = ", format(nvec, scientific = FALSE, trim = TRUE)),
       pch = pchar[seq_len(n.size)], cex = 1.2)
dev.off()


# Repeat the per-configuration p-value panels for the first method (index 1
# in pvals): one stacked column of panels per value of a, one panel per
# sample size. plot.pvals is defined outside this section (presumably in
# utils.R -- TODO confirm).
for(j in seq_along(avec)){
  par(mfrow = c(length(nvec), 1))
  for(k in seq_along(nvec)){
    p.cell <- pvals[1, , j, k]
    title.text <- paste("tba", sum(p.cell < 0.05)/nrep)
    plot.pvals(p.cell, plot.title = title.text)
  }
}

