diff --git a/R/tetrad_utils.R b/R/tetrad_utils.R index 0129ad5..1eb9882 100644 --- a/R/tetrad_utils.R +++ b/R/tetrad_utils.R @@ -1,42 +1,42 @@ ############################################################################### ugraphToTetradGraph <- function(ugmat, node_list){ - numNodes <- ncol(ugmat) - varnames <- strsplit(gsub("\\[|\\]", "", - node_list$toString()), - split=", ")[[1]] - edgelist <- c() - for (i in 2:numNodes){ - for (j in 1:(i-1)){ - if (ugmat[j,i]==1) edgelist <- c(edgelist, - paste(varnames[j], - "---", - varnames[i])) - } - } - - varstring <- paste(varnames, collapse=" ") - edgestring <- paste(1:length(edgelist),". ", - edgelist, "\n",sep="", collapse="") - graphstring <- paste("\nGraph Nodes:\n", varstring, - " \n\nGraph Edges: \n", - edgestring, "\n", sep="") - - graphfilename <- "impossibly_long_graph_file_name_temporary.txt" - if ("!"(file.exists(graphfilename))){ - write(graphstring, graphfilename) - graphfile <- .jnew("java/io/File", graphfilename) - newug_tetrad <- .jcall("edu/cmu/tetrad/graph/GraphUtils", - "Ledu/cmu/tetrad/graph/Graph;", - "loadGraphTxt", graphfile) - newug_tetrad <- .jcast(newug_tetrad, "edu/cmu/tetrad/graph/Graph", - check=TRUE) - rm(graphfile) - file.remove(graphfilename) - return(newug_tetrad) - } else { - print("Whoops, don't want to overwrite existing file!") - stop() - } + numNodes <- ncol(ugmat) + varnames <- strsplit(gsub("\\[|\\]", "", + node_list$toString()), + split=", ")[[1]] + edgelist <- c() + for (i in 2:numNodes){ + for (j in 1:(i-1)){ + if (ugmat[j,i]==1) edgelist <- c(edgelist, + paste(varnames[j], + "---", + varnames[i])) + } + } + + varstring <- paste(varnames, collapse=" ") + edgestring <- paste(1:length(edgelist),". ", + edgelist, "\n",sep="", collapse="") + graphstring <- paste("\nGraph Nodes:\n", varstring, + "\n\nGraph Edges: \n", + edgestring, "\n", sep="") + + graphfilename <- "impossibly_long_graph_file_name_temporary.txt" + if ("!"(file.exists(graphfilename))){ + write(graphstring, graphfilename) + graphfile <- .jnew("java/io/File", graphfilename) + newug_tetrad <- .jcall("edu/cmu/tetrad/graph/GraphUtils", + "Ledu/cmu/tetrad/graph/Graph;", + "loadGraphTxt", graphfile) + newug_tetrad <- .jcast(newug_tetrad, "edu/cmu/tetrad/graph/Graph", + check=TRUE) + rm(graphfile) + file.remove(graphfilename) + return(newug_tetrad) + } else { + print("Whoops, don't want to overwrite existing file!") + stop() + } } ######################################################## @@ -45,21 +45,19 @@ ugraphToTetradGraph <- function(ugmat, node_list){ # Dataset is discrete. # requires rJava, assumes the JVM is running from the # latest Tetrad jar. -dataFrames2TetradBDeuScoreImages <- function(dfs,structurePrior = 1.0, - samplePrior = 1.0){ - datasets <- .jnew("java/util/ArrayList") - for(i in 1:length(dfs)){ - df <- dfs[[i]] - boxData <- loadDiscreteData(df) - dataModel <- .jcast(boxData, "edu/cmu/tetrad/data/DataModel") - datasets$add(dataModel) - - } - score <- .jnew("edu/cmu/tetrad/search/BdeuScoreImages", datasets) - score$setStructurePrior(as.double(structurePrior)) - score$setSamplePrior(as.double(samplePrior)) - score <- .jcast(score, "edu/cmu/tetrad/search/Score") - return(score) +dataFrames2TetradBDeuScoreImages <- function(dfs,structurePrior = 1.0, samplePrior = 1.0){ + datasets <- .jnew("java/util/ArrayList") + for(i in 1:length(dfs)){ + df <- dfs[[i]] + boxData <- loadDiscreteData(df) + dataModel <- .jcast(boxData, "edu/cmu/tetrad/data/DataModel") + datasets$add(dataModel) + } + score <- .jnew("edu/cmu/tetrad/search/BdeuScoreImages", datasets) + score$setStructurePrior(as.double(structurePrior)) + score$setSamplePrior(as.double(samplePrior)) + score <- .jcast(score, "edu/cmu/tetrad/search/Score") + return(score) } ######################################################## @@ -68,14 +66,13 @@ dataFrames2TetradBDeuScoreImages <- function(dfs,structurePrior = 1.0, # Dataset is discrete. # requires rJava, assumes the JVM is running from the # latest Tetrad jar. -dataFrame2TetradBDeuScore <- function(df,structurePrior = 1.0, - samplePrior = 1.0){ - boxData <- loadDiscreteData(df) - score <- .jnew("edu/cmu/tetrad/search/BDeuScore", boxData) - score$setStructurePrior(as.double(structurePrior)) - score$setSamplePrior(as.double(samplePrior)) - score <- .jcast(score, "edu/cmu/tetrad/search/Score") - return(score) +dataFrame2TetradBDeuScore <- function(df,structurePrior = 1.0, samplePrior = 1.0){ + boxData <- loadDiscreteData(df) + score <- .jnew("edu/cmu/tetrad/search/BDeuScore", boxData) + score$setStructurePrior(as.double(structurePrior)) + score$setSamplePrior(as.double(samplePrior)) + score <- .jcast(score, "edu/cmu/tetrad/search/Score") + return(score) } ######################################################## @@ -85,19 +82,18 @@ dataFrame2TetradBDeuScore <- function(df,structurePrior = 1.0, # requires rJava, assumes the JVM is running from the # latest Tetrad jar. dataFrames2TetradSemBicScoreImages <- function(dfs,penaltydiscount = 4.0){ - datasets <- .jnew("java/util/ArrayList") - for(i in 1:length(dfs)){ - df <- dfs[[i]] - boxData <- loadContinuousData(df) - dataModel <- .jcast(boxData, "edu/cmu/tetrad/data/DataModel") - datasets$add(dataModel) - - } - datasets <- .jcast(datasets, "java/util/List") - score <- .jnew("edu/cmu/tetrad/search/SemBicScoreImages", datasets) - score$setPenaltyDiscount(penaltydiscount) - score <- .jcast(score, "edu/cmu/tetrad/search/Score") - return(score) + datasets <- .jnew("java/util/ArrayList") + for(i in 1:length(dfs)){ + df <- dfs[[i]] + boxData <- loadContinuousData(df) + dataModel <- .jcast(boxData, "edu/cmu/tetrad/data/DataModel") + datasets$add(dataModel) + } + datasets <- .jcast(datasets, "java/util/List") + score <- .jnew("edu/cmu/tetrad/search/SemBicScoreImages", datasets) + score$setPenaltyDiscount(penaltydiscount) + score <- .jcast(score, "edu/cmu/tetrad/search/Score") + return(score) } ######################################################## @@ -123,25 +119,25 @@ dataFrame2TetradSemBicScore <- function(df,penaltydiscount = 4.0){ # requires rJava, assumes the JVM is running from the # latest Tetrad jar. dataFrame2TetradConditionalGaussianScore <- function(df, - numCategoriesToDiscretize = 4, penaltydiscount = 4, structurePrior = 1.0){ - boxData <- loadMixedData(df, numCategoriesToDiscretize) - score <- .jnew("edu/cmu/tetrad/search/ConditionalGaussianScore", - boxData, structurePrior, TRUE) - score$setPenaltyDiscount(penaltydiscount) - score <- .jcast(score, "edu/cmu/tetrad/search/Score") - return(score) + numCategoriesToDiscretize = 4, penaltydiscount = 4, structurePrior = 1.0){ + boxData <- loadMixedData(df, numCategoriesToDiscretize) + score <- .jnew("edu/cmu/tetrad/search/ConditionalGaussianScore", + boxData, structurePrior, TRUE) + score$setPenaltyDiscount(penaltydiscount) + score <- .jcast(score, "edu/cmu/tetrad/search/Score") + return(score) } ######################################################## # converter: R covariance matrix into Tetrad covariance matrix rCovMatrix2TetradCovMatrix <- function(covmat, node_list, sample_size){ - mat <- .jarray(covmat, dispatch=TRUE) - tetmat <- .jnew("edu/cmu/tetrad/util/TetradMatrix", mat) - tetcovmat <- .jnew("edu/cmu/tetrad/data/CovarianceMatrix", node_list, - tetmat, as.integer(sample_size)) - tetcovmat <- .jcast(tetcovmat, "edu/cmu/tetrad/data/ICovarianceMatrix", - check=TRUE) - return(tetcovmat) + mat <- .jarray(covmat, dispatch=TRUE) + tetmat <- .jnew("edu/cmu/tetrad/util/TetradMatrix", mat) + tetcovmat <- .jnew("edu/cmu/tetrad/data/CovarianceMatrix", node_list, + tetmat, as.integer(sample_size)) + tetcovmat <- .jcast(tetcovmat, "edu/cmu/tetrad/data/ICovarianceMatrix", + check=TRUE) + return(tetcovmat) } ######################################################## @@ -149,30 +145,30 @@ rCovMatrix2TetradCovMatrix <- function(covmat, node_list, sample_size){ # requires list of nodes and a set of edges # extract nodes: # tetradPattern2graphNEL <- function(resultGraph, -# verbose = FALSE){ +# verbose = FALSE){ # # V <- extractTetradNodes(resultGraph) -# if(verbose){ -# cat("\nGraph Nodes:\n") -# for(i in 1:length(V)){ -# cat(V[i]," ") -# } -# cat("\n\n") -# } - +# if(verbose){ +# cat("\nGraph Nodes:\n") +# for(i in 1:length(V)){ +# cat(V[i]," ") +# } +# cat("\n\n") +# } + # extract edges # fgs_edges <- extractTetradEdges(resultGraph) # edgemat <- str_split_fixed(fgs_edges, pattern=" ", n=3) -# if(verbose){ -# cat("Graph Edges:\n") -# if(length(fgs_edges) > 0){ -# for(i in 1:length(fgs_edges)){ -# cat(fgs_edges[i],"\n") -# } -# } -# } +# if(verbose){ +# cat("Graph Edges:\n") +# if(length(fgs_edges) > 0){ +# for(i in 1:length(fgs_edges)){ +# cat(fgs_edges[i],"\n") +# } +# } +# } # find undirected edge indices # undir <- which(edgemat[,2]=="---") @@ -275,10 +271,10 @@ priorKnowledge <- function(forbiddirect = NULL, requiredirect = NULL, ############################################################ # Create an IKnowledge object from the knowledge file priorKnowledgeFromFile <- function(knowlegeFile){ - file <- .jnew("java/io/File", knowlegeFile) - reader <- .jnew("edu/cmu/tetrad/data/DataReader") - prior <- .jcall(reader, "Ledu/cmu/tetrad/data/IKnowledge;", "parseKnowledge", file) - return(prior) + file <- .jnew("java/io/File", knowlegeFile) + reader <- .jnew("edu/cmu/tetrad/data/DataReader") + prior <- .jcall(reader, "Ledu/cmu/tetrad/data/IKnowledge;", "parseKnowledge", file) + return(prior) } ############################################################ @@ -293,12 +289,11 @@ loadContinuousData <- function(df){ node_list <- .jcast(node_list, "java/util/List") mt <- as.matrix(df) mat <- .jarray(mt, dispatch=TRUE) - - data <- .jnew("edu/cmu/tetrad/data/DoubleDataBox", mat) - data <- .jcast(data, "edu/cmu/tetrad/data/DataBox") - boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", - data, node_list) - boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") + + data <- .jnew("edu/cmu/tetrad/data/DoubleDataBox", mat) + data <- .jcast(data, "edu/cmu/tetrad/data/DataBox") + boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", data, node_list) + boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") return(boxData) } @@ -308,117 +303,115 @@ loadDiscreteData <- function(df){ node_list <- .jnew("java/util/ArrayList") for (i in 1:length(node_names)){ nodname <- .jnew("java/lang/String", node_names[i]) - cat("node_names: ", node_names[i],"\n") - cate <- unique(df[[node_names[i]]]) - cate <- sort(cate) - cat("value: ") - print(cate) - cat("\n") - cate_list <- .jnew("java/util/ArrayList") - for(j in 1:length(cate)){ - cate_list$add(as.character(cate[j])) - } - cate_list <- .jcast(cate_list, "java/util/List") + cat("node_names: ", node_names[i],"\n") + cate <- unique(df[[node_names[i]]]) + cate <- sort(cate) + cat("value: ") + print(cate) + cat("\n") + cate_list <- .jnew("java/util/ArrayList") + for(j in 1:length(cate)){ + cate_list$add(as.character(cate[j])) + } + cate_list <- .jcast(cate_list, "java/util/List") nodi <- .jnew("edu/cmu/tetrad/data/DiscreteVariable", - nodname, cate_list) + nodname, cate_list) node_list$add(nodi) - - # Substitute a new categorial value - cate <- data.frame(cate) - new_col <- sapply(df[,i],function(x,cate) - as.integer(which(cate[,1] == x)),cate=cate) - new_col = as.integer(new_col - 1) - df[,i] <- (data.frame(new_col))[,1] + + # Substitute a new categorial value + cate <- data.frame(cate) + new_col <- sapply(df[,i],function(x,cate) + as.integer(which(cate[,1] == x)),cate=cate) + new_col = as.integer(new_col - 1) + df[,i] <- (data.frame(new_col))[,1] } node_list <- .jcast(node_list, "java/util/List") mt <- as.matrix(df) mat <- .jarray(t(mt), dispatch=TRUE) data <- .jnew("edu/cmu/tetrad/data/VerticalIntDataBox", mat) - data <- .jcast(data, "edu/cmu/tetrad/data/DataBox") - boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", - data, node_list) - boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") + data <- .jcast(data, "edu/cmu/tetrad/data/DataBox") + boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", data, node_list) + boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") return(boxData) } ############################################################ loadMixedData <- function(df, numCategoriesToDiscretize = 4){ - node_names <- colnames(df) - cont_list <- c() - disc_list <- c() - node_list <- .jnew("java/util/ArrayList") - for (i in 1:length(node_names)){ - nodname <- .jnew("java/lang/String", node_names[i]) - cate <- unique(df[[node_names[i]]]) - cateNumeric <- TRUE - for(j in 1:length(cate)){ - cate_value <- cate[j] - if(!is.numeric(cate_value)){ - cateNumeric <- FALSE - break - } - } - if(length(cate) > numCategoriesToDiscretize && cateNumeric){ - # Continuous variable - nodi <- .jnew("edu/cmu/tetrad/data/ContinuousVariable", nodname) - node_list$add(nodi) - - cont_list <- c(cont_list, i) - }else{ - # Discrete variable - cate <- sort(cate) - cate_list <- .jnew("java/util/ArrayList") - for(j in 1:length(cate)){ - cate_list$add(as.character(cate[j])) - } - cate_list <- .jcast(cate_list, "java/util/List") - nodi <- .jnew("edu/cmu/tetrad/data/DiscreteVariable", - nodname, cate_list) - node_list$add(nodi) - - # Substitute a new categorial value - cate <- data.frame(cate) - new_col <- sapply(df[,i],function(x,cate) - as.integer(which(cate[,1] == x)),cate=cate) - new_col = as.integer(new_col - 1) - df[,i] <- (data.frame(new_col))[,1] + node_names <- colnames(df) + cont_list <- c() + disc_list <- c() + node_list <- .jnew("java/util/ArrayList") + for (i in 1:length(node_names)){ + nodname <- .jnew("java/lang/String", node_names[i]) + cate <- unique(df[[node_names[i]]]) + cateNumeric <- TRUE + for(j in 1:length(cate)){ + cate_value <- cate[j] + if(!is.numeric(cate_value)){ + cateNumeric <- FALSE + break + } + } + if(length(cate) > numCategoriesToDiscretize && cateNumeric){ + # Continuous variable + nodi <- .jnew("edu/cmu/tetrad/data/ContinuousVariable", nodname) + node_list$add(nodi) - disc_list <- c(disc_list, i) - } - } - - node_list <- .jcast(node_list, "java/util/List") - mixedDataBox <- .jnew("edu/cmu/tetrad/data/MixedDataBox", node_list,as.integer(nrow(df))) - - for(row in 1:nrow(df)){ - # print(paste("row:",row,sep=" ")) - if(length(cont_list) > 0){ - for(j in 1:length(cont_list)){ - col <- cont_list[j] - # print(paste("col:",col,sep=" ")) - value <- as.character(df[row,col]) - #print(value) - value <- .jnew("java/lang/Double", value) - value <- .jcast(value, "java/lang/Number") - mixedDataBox$set(as.integer(row-1),as.integer(col-1),value) - } - } - if(length(disc_list) > 0){ - for(j in 1:length(disc_list)){ - col <- disc_list[j] - # print(paste("col:",col,sep=" ")) - value <- as.character(df[row,col]) - # print(value) - value <- .jnew("java/lang/Integer", value) - value <- .jcast(value, "java/lang/Number") - mixedDataBox$set(as.integer(row-1),as.integer(col-1),value) - } - } - } - - data <- .jcast(mixedDataBox, "edu/cmu/tetrad/data/DataBox") - boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", - data, node_list) - boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") - return(boxData) + cont_list <- c(cont_list, i) + }else{ + # Discrete variable + cate <- sort(cate) + cate_list <- .jnew("java/util/ArrayList") + for(j in 1:length(cate)){ + cate_list$add(as.character(cate[j])) + } + cate_list <- .jcast(cate_list, "java/util/List") + nodi <- .jnew("edu/cmu/tetrad/data/DiscreteVariable", + nodname, cate_list) + node_list$add(nodi) + + # Substitute a new categorial value + cate <- data.frame(cate) + new_col <- sapply(df[,i],function(x,cate) + as.integer(which(cate[,1] == x)),cate=cate) + new_col = as.integer(new_col - 1) + df[,i] <- (data.frame(new_col))[,1] + + disc_list <- c(disc_list, i) + } + } + + node_list <- .jcast(node_list, "java/util/List") + mixedDataBox <- .jnew("edu/cmu/tetrad/data/MixedDataBox", node_list,as.integer(nrow(df))) + + for(row in 1:nrow(df)){ + # print(paste("row:",row,sep=" ")) + if(length(cont_list) > 0){ + for(j in 1:length(cont_list)){ + col <- cont_list[j] + # print(paste("col:",col,sep=" ")) + value <- as.character(df[row,col]) + #print(value) + value <- .jnew("java/lang/Double", value) + value <- .jcast(value, "java/lang/Number") + mixedDataBox$set(as.integer(row-1),as.integer(col-1),value) + } + } + if(length(disc_list) > 0){ + for(j in 1:length(disc_list)){ + col <- disc_list[j] + # print(paste("col:",col,sep=" ")) + value <- as.character(df[row,col]) + # print(value) + value <- .jnew("java/lang/Integer", value) + value <- .jcast(value, "java/lang/Number") + mixedDataBox$set(as.integer(row-1),as.integer(col-1),value) + } + } + } + + data <- .jcast(mixedDataBox, "edu/cmu/tetrad/data/DataBox") + boxData <- .jnew("edu/cmu/tetrad/data/BoxDataSet", data, node_list) + boxData <- .jcast(boxData, "edu/cmu/tetrad/data/DataSet") + return(boxData) } diff --git a/R/tetradrunner.R b/R/tetradrunner.R index 9d70066..8b4a960 100644 --- a/R/tetradrunner.R +++ b/R/tetradrunner.R @@ -1,62 +1,62 @@ tetradrunner <- function(algoId, dataType, df = NULL, dfs = NULL, testId = NULL, scoreId = NULL, priorKnowledge = NULL, numCategoriesToDiscretize = 4,java.parameters = NULL,...) { - - arguments <- list(...) - - params <- list() - # result - tetradrunner <- list() - - if(!is.null(java.parameters)){ - options(java.parameters = java.parameters) - params <- c(java.parameters = java.parameters) - } - - algoAnno_instance <- .jcall("edu/cmu/tetrad/annotation/AlgorithmAnnotations", + + arguments <- list(...) + + params <- list() + # result + tetradrunner <- list() + + if(!is.null(java.parameters)){ + options(java.parameters = java.parameters) + params <- c(java.parameters = java.parameters) + } + + algoAnno_instance <- .jcall("edu/cmu/tetrad/annotation/AlgorithmAnnotations", "Ledu/cmu/tetrad/annotation/AlgorithmAnnotations;", "getInstance") - algoClasses <- algoAnno_instance$getAnnotatedClasses() - - algoClass <- .jnull("java/lang/Class") - algoAnno <- NULL - - algoClasses <- algoClasses$toArray() - for(i in 1:algoClasses$length){ - algo <- algoClasses[[i]] - cmd <- algo$getAnnotation()$command() - - if(cmd == algoId){ - algoClass <- algo$getClazz() - algoAnno <- algo$getAnnotation() - break - } + algoClasses <- algoAnno_instance$getAnnotatedClasses() + + algoClass <- .jnull("java/lang/Class") + algoAnno <- NULL + + algoClasses <- algoClasses$toArray() + for(i in 1:algoClasses$length){ + algo <- algoClasses[[i]] + cmd <- algo$getAnnotation()$command() + + if(cmd == algoId){ + algoClass <- algo$getClazz() + algoAnno <- algo$getAnnotation() + break + } } - + if(is.null(algoAnno)){ cat(algoId,' is not found!\n') return } - + tetradProperties <- .jcall("edu/cmu/tetrad/util/TetradProperties", "Ledu/cmu/tetrad/util/TetradProperties;", "getInstance") - - # testId - testClass <- .jnull("java/lang/Class") + + # testId + testClass <- .jnull("java/lang/Class") if(!is.null(testId) || algoAnno_instance$requireIndependenceTest(algoClass)){ testAnno_instance <- .jcall("edu/cmu/tetrad/annotation/TestOfIndependenceAnnotations", "Ledu/cmu/tetrad/annotation/TestOfIndependenceAnnotations;", "getInstance") - testClasses <- testAnno_instance$getAnnotatedClasses() - testClasses <- testClasses$toArray() - - defaultTestClassName <- NULL - - # Default dataType + testClasses <- testAnno_instance$getAnnotatedClasses() + testClasses <- testClasses$toArray() + + defaultTestClassName <- NULL + + # Default dataType continuous <- 'datatype.continuous.test.default' discrete <- 'datatype.discrete.test.default' mixed <- 'datatype.mixed.test.default' - + if(dataType == 'continuous'){ defaultTestClassName <- tetradProperties$getValue(continuous) }else if(dataType == 'discrete'){ @@ -67,37 +67,37 @@ tetradrunner <- function(algoId, dataType, df = NULL, dfs = NULL, testId = NULL, for(i in 1:testClasses$length){ test <- testClasses[[i]] - cmd <- test$getAnnotation()$command() - tClass <- test$getClazz() - name <- tClass$getName() - - if(name == defaultTestClassName){ - testClass <- tClass - } - - if(!is.null(testId) && cmd == testId){ - testClass <- tClass - break - } + cmd <- test$getAnnotation()$command() + tClass <- test$getClazz() + name <- tClass$getName() + + if(name == defaultTestClassName){ + testClass <- tClass + } + + if(!is.null(testId) && cmd == testId){ + testClass <- tClass + break + } } } - + # scoreId scoreClass <- .jnull("java/lang/Class") if(!is.null(scoreId) || algoAnno_instance$requireScore(algoClass)){ scoreAnno_instance <- .jcall("edu/cmu/tetrad/annotation/ScoreAnnotations", "Ledu/cmu/tetrad/annotation/ScoreAnnotations;", "getInstance") - scoreClasses <- scoreAnno_instance$getAnnotatedClasses() - scoreClasses <- scoreClasses$toArray() - - defaultScoreClassName <- NULL - - # Default dataType + scoreClasses <- scoreAnno_instance$getAnnotatedClasses() + scoreClasses <- scoreClasses$toArray() + + defaultScoreClassName <- NULL + + # Default dataType continuous <- 'datatype.continuous.score.default' discrete <- 'datatype.discrete.score.default' mixed <- 'datatype.mixed.score.default' - + if(dataType == 'continuous'){ defaultScoreClassName <- tetradProperties$getValue(continuous) }else if(dataType == 'discrete'){ @@ -105,127 +105,126 @@ tetradrunner <- function(algoId, dataType, df = NULL, dfs = NULL, testId = NULL, }else{ defaultScoreClassName <- tetradProperties$getValue(mixed) } - + for(i in 1:scoreClasses$length){ score <- scoreClasses[[i]] - cmd <- score$getAnnotation()$command() - sClass <- score$getClazz() - name <- sClass$getName() - - if(name == defaultScoreClassName){ - scoreClass <- sClass - } - - if(!is.null(scoreId) && cmd == scoreId){ - scoreClass <- sClass - break - } + cmd <- score$getAnnotation()$command() + sClass <- score$getClazz() + name <- sClass$getName() + + if(name == defaultScoreClassName){ + scoreClass <- sClass + } + + if(!is.null(scoreId) && cmd == scoreId){ + scoreClass <- sClass + break + } } } - # dataset - tetradData <- NULL - if(!is.null(df)){ - - if(dataType == 'continuous'){ - tetradData <- loadContinuousData(df) - }else if(dataType == 'discrete'){ - tetradData <- loadDiscreteData(df) - }else{ - tetradData <- loadMixedData(df, numCategoriesToDiscretize) - } - - tetradData <- .jcast(tetradData, 'edu/cmu/tetrad/data/DataModel') - - }else if(!is.null(dfs)){ - - tetradData <- .jnew("java/util/ArrayList") - for(i in 1:length(dfs)){ - df <- dfs[[i]] - - if(dataType == 'continuous'){ - df <- loadContinuousData(df) - }else if(dataType == 'discrete'){ - df <- loadDiscreteData(df) - }else{ - df <- loadMixedData(df, numCategoriesToDiscretize) - } - - df <- .jcast(df, 'edu/cmu/tetrad/data/DataModel') - - tetradData$add(df) - } - - tetradData <- .jcast(tetradData, "java/util/List") - }else{ - cat("Dataset is required!") - return - } - - algo_instance <- .jcall("edu/cmu/tetrad/algcomparison/algorithm/AlgorithmFactory", - "Ledu/cmu/tetrad/algcomparison/algorithm/Algorithm;", - "create",algoClass, testClass, scoreClass) - - if(!is.null(priorKnowledge)){ - algo_instance$setKnowledge(priorKnowledge) - } - - # Parameters - paramDescs_instance <- .jcall("edu/cmu/tetrad/util/ParamDescriptions", - "Ledu/cmu/tetrad/util/ParamDescriptions;", - "getInstance") - - parameters_instance <- .jnew("edu/cmu/tetrad/util/Parameters") - for(arg in names(arguments)){ - if(!is.null(paramDescs_instance$get(arg))){ - - value <- arguments[[arg]] - parameter_instance <- NULL - obj_value <- NULL - - if(!is.character(value)){ - if(is.logical(value)){ - obj_value <- .jnew("java/lang/Boolean", value) - }else if(value%%1 == 0){ - obj_value <- .jnew("java/lang/Integer", as.integer(value)) - }else{ - obj_value <- .jnew("java/lang/Double", value) - } - - parameter_instance <- .jcast(obj_value, "java/lang/Object") - parameters_instance$set(arg, parameter_instance) - } - - - } - # print(arg) # argument's name - # print(arguments[arg]) # argument's value - } - - # Search - tetrad_graph <- .jcall(algo_instance, "Ledu/cmu/tetrad/graph/Graph;", - "search", tetradData, parameters_instance, check=FALSE) - - if(!is.null(e <- .jgetEx())){ - .jclear() - tetradrunner$nodes <- colnames(df) - tetradrunner$edges <- NULL - # print("Java exception was raised") - # print(e) - }else{ - tetradrunner$graph <- tetrad_graph - - V <- extractTetradNodes(tetrad_graph) - - tetradrunner$nodes <- V - - # extract edges - tetradrunner_edges <- extractTetradEdges(tetrad_graph) - - tetradrunner$edges <- tetradrunner_edges - } - - return(tetradrunner) + # dataset + tetradData <- NULL + if(!is.null(df)){ + + if(dataType == 'continuous'){ + tetradData <- loadContinuousData(df) + }else if(dataType == 'discrete'){ + tetradData <- loadDiscreteData(df) + }else{ + tetradData <- loadMixedData(df, numCategoriesToDiscretize) + } + + tetradData <- .jcast(tetradData, 'edu/cmu/tetrad/data/DataModel') + + }else if(!is.null(dfs)){ + + tetradData <- .jnew("java/util/ArrayList") + for(i in 1:length(dfs)){ + df <- dfs[[i]] + + if(dataType == 'continuous'){ + df <- loadContinuousData(df) + }else if(dataType == 'discrete'){ + df <- loadDiscreteData(df) + }else{ + df <- loadMixedData(df, numCategoriesToDiscretize) + } + + df <- .jcast(df, 'edu/cmu/tetrad/data/DataModel') + + tetradData$add(df) + } + + tetradData <- .jcast(tetradData, "java/util/List") + }else{ + cat("Dataset is required!") + return + } + + algo_instance <- .jcall("edu/cmu/tetrad/algcomparison/algorithm/AlgorithmFactory", + "Ledu/cmu/tetrad/algcomparison/algorithm/Algorithm;", + "create",algoClass, testClass, scoreClass) + + if(!is.null(priorKnowledge)){ + algo_instance$setKnowledge(priorKnowledge) + } + + # Parameters + paramDescs_instance <- .jcall("edu/cmu/tetrad/util/ParamDescriptions", + "Ledu/cmu/tetrad/util/ParamDescriptions;", + "getInstance") + + parameters_instance <- .jnew("edu/cmu/tetrad/util/Parameters") + for(arg in names(arguments)){ + if(!is.null(paramDescs_instance$get(arg))){ + + value <- arguments[[arg]] + parameter_instance <- NULL + obj_value <- NULL + + if(!is.character(value)){ + if(is.logical(value)){ + obj_value <- .jnew("java/lang/Boolean", value) + }else if(value%%1 == 0){ + obj_value <- .jnew("java/lang/Integer", as.integer(value)) + }else{ + obj_value <- .jnew("java/lang/Double", value) + } + + parameter_instance <- .jcast(obj_value, "java/lang/Object") + parameters_instance$set(arg, parameter_instance) + } + + } + # print(arg) # argument's name + # print(arguments[arg]) # argument's value + } + + # Search + tetrad_graph <- .jcall(algo_instance, "Ledu/cmu/tetrad/graph/Graph;", + "search", tetradData, parameters_instance, check=FALSE) + + if(!is.null(e <- .jgetEx())){ + .jclear() + tetradrunner$nodes <- colnames(df) + tetradrunner$edges <- NULL + # print("Java exception was raised") + # print(e) + }else{ + tetradrunner$graph <- tetrad_graph + + V <- extractTetradNodes(tetrad_graph) + + tetradrunner$nodes <- V + + # extract edges + tetradrunner_edges <- extractTetradEdges(tetrad_graph) + + tetradrunner$edges <- tetradrunner_edges + } + + return(tetradrunner) } tetradrunner.tetradGraphToDot <- function(tetrad_graph){ @@ -238,30 +237,30 @@ tetradrunner.listAlgorithms <- function(){ algoAnno_instance <- .jcall("edu/cmu/tetrad/annotation/AlgorithmAnnotations", "Ledu/cmu/tetrad/annotation/AlgorithmAnnotations;", "getInstance") - algoClasses <- algoAnno_instance$getAnnotatedClasses() - - algoClasses <- algoClasses$toArray() - for(i in 1:algoClasses$length){ - algo <- algoClasses[[i]] - algoType <- algo$getAnnotation()$algoType()$toString() - if(algoType != 'orient_pairwise'){ - cmd <- algo$getAnnotation()$command() - cat(cmd,"\n") - } - } + algoClasses <- algoAnno_instance$getAnnotatedClasses() + + algoClasses <- algoClasses$toArray() + for(i in 1:algoClasses$length){ + algo <- algoClasses[[i]] + algoType <- algo$getAnnotation()$algoType()$toString() + if(algoType != 'orient_pairwise'){ + cmd <- algo$getAnnotation()$command() + cat(cmd,"\n") + } + } } tetradrunner.listIndTests <- function(){ testAnno_instance <- .jcall("edu/cmu/tetrad/annotation/TestOfIndependenceAnnotations", "Ledu/cmu/tetrad/annotation/TestOfIndependenceAnnotations;", "getInstance") - testClasses <- testAnno_instance$getAnnotatedClasses() - testClasses <- testClasses$toArray() - + testClasses <- testAnno_instance$getAnnotatedClasses() + testClasses <- testClasses$toArray() + for(i in 1:testClasses$length){ test <- testClasses[[i]] - cmd <- test$getAnnotation()$command() - cat(cmd,"\n") + cmd <- test$getAnnotation()$command() + cat(cmd,"\n") } } @@ -269,13 +268,13 @@ tetradrunner.listScores <- function(){ scoreAnno_instance <- .jcall("edu/cmu/tetrad/annotation/ScoreAnnotations", "Ledu/cmu/tetrad/annotation/ScoreAnnotations;", "getInstance") - scoreClasses <- scoreAnno_instance$getAnnotatedClasses() - scoreClasses <- scoreClasses$toArray() - + scoreClasses <- scoreAnno_instance$getAnnotatedClasses() + scoreClasses <- scoreClasses$toArray() + for(i in 1:scoreClasses$length){ score <- scoreClasses[[i]] - cmd <- score$getAnnotation()$command() - cat(cmd,"\n") + cmd <- score$getAnnotation()$command() + cat(cmd,"\n") } } @@ -283,28 +282,28 @@ tetradrunner.getAlgorithmDescription <- function(algoId){ algoAnno_instance <- .jcall("edu/cmu/tetrad/annotation/AlgorithmAnnotations", "Ledu/cmu/tetrad/annotation/AlgorithmAnnotations;", "getInstance") - algoClasses <- algoAnno_instance$getAnnotatedClasses() - - algoClass <- NULL - algoAnno <- NULL - - algoClasses <- algoClasses$toArray() - for(i in 1:algoClasses$length){ - algo <- algoClasses[[i]] - cmd <- algo$getAnnotation()$command() - - - if(cmd == algoId){ - algoClass <- algo$getClazz() - algoAnno <- algo$getAnnotation() - break - } + algoClasses <- algoAnno_instance$getAnnotatedClasses() + + algoClass <- NULL + algoAnno <- NULL + + algoClasses <- algoClasses$toArray() + for(i in 1:algoClasses$length){ + algo <- algoClasses[[i]] + cmd <- algo$getAnnotation()$command() + + + if(cmd == algoId){ + algoClass <- algo$getClazz() + algoAnno <- algo$getAnnotation() + break + } } - + algoDesc_instance <- .jcall("edu/cmu/tetrad/util/AlgorithmDescriptions", "Ledu/cmu/tetrad/util/AlgorithmDescriptions;", "getInstance") - + cat(algoDesc_instance$get(algoId)) if(algoAnno_instance$requireIndependenceTest(algoClass)){ @@ -322,19 +321,19 @@ tetradrunner.getAlgorithmParameters <- function(algoId, testId = NULL, scoreId = algoAnno_instance <- .jcall("edu/cmu/tetrad/annotation/AlgorithmAnnotations", "Ledu/cmu/tetrad/annotation/AlgorithmAnnotations;", "getInstance") - algoClasses <- algoAnno_instance$getAnnotatedClasses() - - algoClass <- .jnull("java/lang/Class") - - algoClasses <- algoClasses$toArray() - for(i in 1:algoClasses$length){ - algo <- algoClasses[[i]] - cmd <- algo$getAnnotation()$command() - - if(cmd == algoId){ - algoClass <- algo$getClazz() - break - } + algoClasses <- algoAnno_instance$getAnnotatedClasses() + + algoClass <- .jnull("java/lang/Class") + + algoClasses <- algoClasses$toArray() + for(i in 1:algoClasses$length){ + algo <- algoClasses[[i]] + cmd <- algo$getAnnotation()$command() + + if(cmd == algoId){ + algoClass <- algo$getClazz() + break + } } # testId @@ -360,34 +359,34 @@ tetradrunner.getAlgorithmParameters <- function(algoId, testId = NULL, scoreId = # scoreId scoreClass <- .jnull("java/lang/Class") - + if(!is.null(scoreId)){ scoreAnno_instance <- .jcall("edu/cmu/tetrad/annotation/ScoreAnnotations", "Ledu/cmu/tetrad/annotation/ScoreAnnotations;", "getInstance") - scoreClasses <- scoreAnno_instance$getAnnotatedClasses() - scoreClasses <- scoreClasses$toArray() - + scoreClasses <- scoreAnno_instance$getAnnotatedClasses() + scoreClasses <- scoreClasses$toArray() + for(i in 1:scoreClasses$length){ score <- scoreClasses[[i]] - cmd <- score$getAnnotation()$command() - - if(cmd == scoreId){ - scoreClass <- score$getClazz() - break - } + cmd <- score$getAnnotation()$command() + + if(cmd == scoreId){ + scoreClass <- score$getClazz() + break + } } } - + algo_instance <- .jcall("edu/cmu/tetrad/algcomparison/algorithm/AlgorithmFactory", - "Ledu/cmu/tetrad/algcomparison/algorithm/Algorithm;", - "create",algoClass, testClass, scoreClass) + "Ledu/cmu/tetrad/algcomparison/algorithm/Algorithm;", + "create",algoClass, testClass, scoreClass) algoParams <- algo_instance$getParameters() - + paramDescs_instance <- .jcall("edu/cmu/tetrad/util/ParamDescriptions", - "Ledu/cmu/tetrad/util/ParamDescriptions;", - "getInstance") + "Ledu/cmu/tetrad/util/ParamDescriptions;", + "getInstance") for(i in 0:(algoParams$size()-1)){ algoParam <- algoParams$get(i) paramDesc <- paramDescs_instance$get(algoParam) @@ -397,4 +396,4 @@ tetradrunner.getAlgorithmParameters <- function(algoId, testId = NULL, scoreId = cat(algoParam,": ",desc," [default:",defaultValue,"]","\n") } -} \ No newline at end of file +}