diff --git a/src/gui.py b/src/gui.py index 29b35f9..21d0972 100644 --- a/src/gui.py +++ b/src/gui.py @@ -151,7 +151,7 @@ def computer_play(net): if use_ai: import torch from mcts import MCTS, StateNode - model = CNN_Net() + model = CNN_Net(use_log=False) use_cuda = torch.cuda.is_available() file_path = os.path.join("../checkpoints", args.checkpoint) if not use_cuda: diff --git a/src/model.py b/src/model.py index 7151beb..277b9fd 100644 --- a/src/model.py +++ b/src/model.py @@ -8,13 +8,14 @@ # class to be used class CNN_Net(nn.Module): - def __init__(self): + def __init__(self, use_log=True): super(CNN_Net, self).__init__() self.conv_block = Conv_block() self.residual_blocks = self.__make_residual_blocks() self.value_head = Value_head() self.policy_head = Policy_head() - self.logger = build_logger("model", config.file2write) + if use_log: + self.logger = build_logger("model", config.file2write) def forward(self, x): @@ -39,9 +40,14 @@ def visualize_model(self): example_input = torch.randn((5, 3, 8, 8)) values, policies = model(example_input) - self.message("Value head output shape: " + str(values.shape)) - self.message("Policy head output shape: " + str(policies.shape)) - self.message(model) + if use_log: + self.message("Value head output shape: " + str(values.shape)) + self.message("Policy head output shape: " + str(policies.shape)) + self.message(model) + else: + print("Value head output shape: " + str(values.shape)) + print("Policy head output shape: " + str(policies.shape)) + print(model) def message(self, mess): self.logger.info(mess)