Skip to content

Commit

Permalink
BatchedMcts validate_prior on CuArray
Browse files Browse the repository at this point in the history
  • Loading branch information
Whojo committed Sep 24, 2022
1 parent bf82e9e commit 90338dd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion redesign/src/BatchedMcts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ function create_tree(mcts, envs)
valid_actions = fill(false, mcts.device, (A, N, B))
valid_actions[:, ROOT, :] = info.valid_actions
policy_prior = zeros(Float32, mcts.device, (A, N, B))
policy_prior[:, ROOT, :] = validate_prior(info.policy_prior, info.valid_actions)
policy_prior[:, ROOT, :] = validate_prior(info.policy_prior, valid_actions[:, ROOT, :])
value_prior = zeros(Float32, mcts.device, (N, B))
value_prior[ROOT, :] = info.value_prior

Expand Down

0 comments on commit 90338dd

Please sign in to comment.