-
Notifications
You must be signed in to change notification settings - Fork 12
/
algcomparison_large_fges.py
73 lines (56 loc) · 2.39 KB
/
algcomparison_large_fges.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Note: This is an example of how to write an algcomparison script to do algorithm
# comparison in Tetrad. It may not be the best example yet, but it does make
# clear how the script can be written. JR 2023-02-27
import jpype.imports
import importlib.resources as importlib_resources
# Locate the Tetrad jar bundled with the pytetrad package and start a JVM with it
# on the classpath (only if no JVM is running yet). The `edu.cmu.tetrad` imports
# that follow require a running JVM, so a failure here is fatal: raise a clear,
# chained error instead of printing and letting those imports fail obscurely.
jar_path = str(importlib_resources.files('pytetrad').joinpath('resources', 'tetrad-current.jar'))
if not jpype.isJVMStarted():
    try:
        jpype.startJVM(jpype.getDefaultJVMPath(), classpath=[jar_path])
    except (OSError, ValueError) as err:
        # OSError: the JVM library could not be loaded. ValueError: JPype's
        # JVMNotFoundException (raised by getDefaultJVMPath) subclasses it.
        raise RuntimeError(f"can't load jvm (classpath: {jar_path})") from err
from edu.cmu.tetrad.util import Params, Parameters
from edu.cmu.tetrad.algcomparison import Comparison
from edu.cmu.tetrad.algcomparison.algorithm import Algorithms
from edu.cmu.tetrad.algcomparison.simulation import Simulations
import edu.cmu.tetrad.algcomparison.simulation as sim
import edu.cmu.tetrad.algcomparison.score as score
import edu.cmu.tetrad.algcomparison.graph as graph
import edu.cmu.tetrad.algcomparison.independence as ind
import edu.cmu.tetrad.algcomparison.statistic as stat
import edu.cmu.tetrad.algcomparison.algorithm.oracle.cpdag as cpdag
# Settings for the simulation, the score and the comparison run. Each
# (parameter, value) pair is applied to `params` in the order listed.
# For a larger run, NUM_MEASURES can be raised from 2000 to 20000.
params = Parameters()
for _param_name, _param_value in (
    (Params.PENALTY_DISCOUNT, 8),
    (Params.SAMPLE_SIZE, 500),
    (Params.NUM_MEASURES, 2000),
    (Params.AVG_DEGREE, 6),
    (Params.NUM_LATENTS, 0),
    (Params.RANDOMIZE_COLUMNS, False),
    (Params.COEF_LOW, 0),
    (Params.COEF_HIGH, 1),
    (Params.VAR_LOW, 1),
    (Params.VAR_HIGH, 3),
    (Params.FAITHFULNESS_ASSUMED, True),
    (Params.PARALLELIZED, True),
    (Params.VERBOSE, True),
    (Params.NUM_RUNS, 1),
):
    params.set(_param_name, _param_value)
# Give the SEM BIC score its own name. Assigning it back to `score` would shadow
# the `score` module imported above and break any later `score.<Class>` lookup.
sem_bic_score = score.SemBicScore()
# FisherZ is not used by FGES, which is constructed from a score only; it is not
# referenced elsewhere in this script (presumably kept for trying a test-based
# algorithm — confirm before relying on it).
test = ind.FisherZ()
algorithms = Algorithms()
algorithms.add(cpdag.Fges(sem_bic_score))
# Register a single data source: a LinearFisherModel built on a RandomForward graph.
simulations = Simulations()
_linear_model = sim.LinearFisherModel(graph.RandomForward())
simulations.add(_linear_model)
# Statistics to report, added in this order: the NUM_MEASURES and SAMPLE_SIZE
# parameter values, adjacency and arrowhead precision/recall, then CPU time.
statistics = stat.Statistics()
for _statistic in (
    stat.ParameterColumn(Params.NUM_MEASURES),
    stat.ParameterColumn(Params.SAMPLE_SIZE),
    stat.AdjacencyPrecision(),
    stat.AdjacencyRecall(),
    stat.ArrowheadPrecision(),
    stat.ArrowheadRecall(),
    stat.ElapsedCpuTime(),
):
    statistics.add(_statistic)
# Run the comparison with true_DAG selected as the comparison graph. Results go
# to a relative directory, which presumably resolves against the working
# directory the script is launched from — confirm before scripting around it.
output_dir = "../testFges"
comparison = Comparison()
comparison.setComparisonGraph(Comparison.ComparisonGraph.true_DAG)
comparison.compareFromSimulations(output_dir, simulations, algorithms, statistics, params)