-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest.py
103 lines (87 loc) · 3.82 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
# Quiet TensorFlow's native (C++) logging. '3' is TF's most restrictive level
# (presumably hiding INFO/WARNING/ERROR output). It is set here, before the
# tensorflow import below, so it is already in the environment when TF loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.python.util import deprecation
# Turn off TensorFlow's deprecation messages by flipping a private
# (underscore-prefixed) TF module flag. It may break on TF upgrades.
deprecation._PRINT_DEPRECATION_WARNINGS = False
from models import AS_FedDAG, GS_FedDAG, AS_FedDAG_linear
from datasets.simulation import property_generation
from helpers.evaluation import MetricsDAG
from helpers.tf_utils import set_seed
# NOTE(review): duplicate of the `property_generation` import a few lines up;
# redundant but harmless.
from datasets.simulation import property_generation
from helpers.config_utils import setup_parser, setup_logger
from datasets.data_gen import data_gen
from argparse import ArgumentParser
import warnings
# Ignore every Python warning for the rest of the run.
warnings.filterwarnings('ignore')
from warnings import simplefilter
# NOTE(review): redundant given the blanket 'ignore' filter above; kept as-is.
simplefilter(action='ignore', category=FutureWarning)
# FedDAG variants selectable through the ``fed_type`` command-line option.
_SUPPORTED_FED_TYPES = ('GS', 'AS', 'AS_linear')


def _build_model(args, logger):
    """Instantiate the (unfitted) FedDAG variant selected by ``args.fed_type``.

    Args:
        args: Parsed command-line namespace populated by ``setup_parser``.
        logger: Logger forwarded to the model constructor.

    Returns:
        A ``GS_FedDAG``, ``AS_FedDAG`` or ``AS_FedDAG_linear`` instance.

    Raises:
        ValueError: If ``args.fed_type`` is not in ``_SUPPORTED_FED_TYPES``.
    """
    if args.fed_type == 'GS':
        return GS_FedDAG(d=args.node,
                         num_client=args.num_client,
                         use_gpu=args.use_gpu,
                         seed=args.seed,
                         init_rho=args.init_rho,
                         l1_graph_penalty=args.l1_graph_penalty,
                         rho_multiply=args.rho_multiply,
                         lr=args.lr,
                         max_iter=args.max_iter,
                         iter_step=args.iter_step,
                         it_fl=args.it_fl,
                         init_alpha=args.init_alpha,
                         num_shared_client=args.num_shared_client,
                         logger=logger)
    if args.fed_type == 'AS':
        return AS_FedDAG(n=args.n,
                         d=args.node,
                         use_gpu=args.use_gpu,
                         num_client=args.num_client,
                         seed=args.seed,
                         init_rho=args.init_rho,
                         l1_graph_penalty=args.l1_graph_penalty,
                         rho_multiply=args.rho_multiply,
                         lr=args.lr,
                         max_iter=args.max_iter,
                         iter_step=args.iter_step,
                         it_fl=args.it_fl,
                         init_alpha=args.init_alpha,
                         num_shared_client=args.num_shared_client,
                         logger=logger)
    if args.fed_type == 'AS_linear':
        # Only a subset of the CLI hyper-parameters is forwarded to the
        # linear variant.
        return AS_FedDAG_linear(n=args.n,
                                d=args.node,
                                use_gpu=args.use_gpu,
                                num_client=args.num_client,
                                max_iter=args.max_iter,
                                iter_step=args.iter_step,
                                seed=args.seed,
                                logger=logger)
    # Previously an unknown type left `model` unbound and crashed later with
    # an UnboundLocalError; fail with an explicit message instead.
    raise ValueError("Unsupported fed_type {!r}; expected one of: {}".format(
        args.fed_type, ', '.join(_SUPPORTED_FED_TYPES)))


def main():
    """Run one FedDAG experiment end to end.

    Parses the command line, seeds the RNGs, and simulates heterogeneous
    federated data. It then fits the FedDAG variant chosen by ``fed_type``
    and logs the ``MetricsDAG`` comparison of the learned
    ``causal_matrix`` against the true graph ``B_true``.

    Exits through ``parser.error`` (status 2) if ``fed_type`` is not
    supported.
    """
    # Parse the arguments
    parser = ArgumentParser()
    setup_parser(parser)
    args = parser.parse_args()
    # Reject an unknown model type before the (possibly slow) data generation.
    if args.fed_type not in _SUPPORTED_FED_TYPES:
        parser.error("unsupported fed_type {!r}; expected one of: {}".format(
            args.fed_type, ', '.join(_SUPPORTED_FED_TYPES)))
    # Set up logger and seed
    logger = setup_logger()
    set_seed(args.seed)
    # Generate the properties for heterogeneous data
    dataset_property = property_generation(args.num_client)
    # Generate the data. This happens before the model is built, preserving
    # the original order of RNG consumption after set_seed.
    print('Generating data, maybe slow! You can pre-generate the data.')
    B_true, _, dataset, _ = data_gen(args.graph_type,
                                     args.node,
                                     args.edge,
                                     args.seed,
                                     args.num_client,
                                     args.gen_method,
                                     args.n,
                                     args.sem_type,
                                     dataset_property=dataset_property,
                                     method=args.linearity)
    # Run the FedDAG method
    print('Begin running the method.....')
    model = _build_model(args, logger)
    model.learn(dataset)
    raw_result = MetricsDAG(model.causal_matrix, B_true).metrics
    logger.info("run result:%s", raw_result)


if __name__ == '__main__':
    main()