Skip to content

Commit

Permalink
remove logs of GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
SunHaozhe committed Jan 31, 2019
1 parent debb453 commit dbee6af
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def computer_play(net):
if use_ai:
import torch
from mcts import MCTS, StateNode
model = CNN_Net()
model = CNN_Net(use_log=False)
use_cuda = torch.cuda.is_available()
file_path = os.path.join("../checkpoints", args.checkpoint)
if not use_cuda:
Expand Down
16 changes: 11 additions & 5 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@

# class to be used
class CNN_Net(nn.Module):
def __init__(self):
def __init__(self, use_log=True):
super(CNN_Net, self).__init__()
self.conv_block = Conv_block()
self.residual_blocks = self.__make_residual_blocks()
self.value_head = Value_head()
self.policy_head = Policy_head()
self.logger = build_logger("model", config.file2write)
if use_log:
self.logger = build_logger("model", config.file2write)


def forward(self, x):
Expand All @@ -39,9 +40,14 @@ def visualize_model(self):
example_input = torch.randn((5, 3, 8, 8))
values, policies = model(example_input)

self.message("Value head output shape: " + str(values.shape))
self.message("Policy head output shape: " + str(policies.shape))
self.message(model)
if use_log:
self.message("Value head output shape: " + str(values.shape))
self.message("Policy head output shape: " + str(policies.shape))
self.message(model)
else:
print("Value head output shape: " + str(values.shape))
print("Policy head output shape: " + str(policies.shape))
print(model)

def message(self, mess):
self.logger.info(mess)
Expand Down

0 comments on commit dbee6af

Please sign in to comment.