Fix chess to white perspective, fix observation bug, add documentation #1004
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,8 @@ def boards_to_ndarray(boards): | |
bits = np.unpackbits(arr8) | ||
floats = bits.astype(bool) | ||
boardstack = floats.reshape([len(boards), 8, 8]) | ||
boardimage = np.transpose(boardstack, [1, 2, 0]) | ||
# Flip both spatial axes with np.flip() because the board comes out rotated by 180 degrees after the steps above. | ||
boardimage = np.flip(np.transpose(boardstack, [1, 2, 0]), axis=[0, 1]) | ||
return boardimage | ||
|
||
|
||
|
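The flip added in this hunk reverses both spatial axes, which amounts to rotating every 8x8 plane by 180 degrees while leaving the channel axis alone. A minimal check, assuming only NumPy and a made-up stack of two planes (the values in `boardstack` below are illustrative, not real board data):

```python
import numpy as np

# Hypothetical stack of 2 planes, shaped (planes, 8, 8) like `boardstack` in the hunk.
boardstack = np.arange(2 * 8 * 8).reshape(2, 8, 8)

# Same steps as the patched function: move channels last, then flip rows and columns.
transposed = np.transpose(boardstack, [1, 2, 0])
boardimage = np.flip(transposed, axis=(0, 1))

# Flipping both spatial axes equals a 180-degree rotation of every plane.
rotated = np.rot90(transposed, k=2, axes=(0, 1))
assert np.array_equal(boardimage, rotated)

# The cell at (row, col) ends up at (7 - row, 7 - col); the channel index is unchanged.
assert boardimage[0, 0, 1] == transposed[7, 7, 1]
assert boardimage.shape == (8, 8, 2)
```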
@@ -135,7 +136,7 @@ def get_move_plane(move): | |
actions_to_moves = {} | ||
|
||
|
||
def action_to_move(board, action, player): | ||
def action_to_move(board: chess.Board, action, player: int): | ||
base_move = chess.Move.from_uci(actions_to_moves[action]) | ||
|
||
base_coord = square_to_coord(base_move.from_square) | ||
|
@@ -164,7 +165,7 @@ def make_move_mapping(uci_move): | |
actions_to_moves[cur_action] = uci_move | ||
|
||
|
||
def legal_moves(orig_board): | ||
def legal_moves(orig_board: chess.Board): | ||
"""Returns legal moves. | ||
|
||
action space is a 8x8x73 dimensional array | ||
|
@@ -194,7 +195,7 @@ def legal_moves(orig_board): | |
return legal_moves | ||
|
||
|
||
def get_observation(orig_board, player): | ||
def get_observation(orig_board: chess.Board, player: int): | ||
"""Returns observation array. | ||
|
||
Observation is an 8x8x(P + L) dimensional array. | ||
|
@@ -281,8 +282,9 @@ def get_observation(orig_board, player): | |
|
||
""" | ||
base = BASE | ||
OURS = 0 | ||
THEIRS = 1 | ||
# In the module `chess`, the color is represented by 1 for white and 0 for black. | ||
OURS = 1 | ||
THEIRS = 0 | ||
result[base + 0] = board.pieces(chess.PAWN, OURS) | ||
result[base + 1] = board.pieces(chess.KNIGHT, OURS) | ||
result[base + 2] = board.pieces(chess.BISHOP, OURS) | ||
|
@@ -321,17 +323,31 @@ def get_observation(orig_board, player): | |
} | ||
""" | ||
# from 0-63 | ||
square = board.ep_square | ||
# Adjust the row number for the white pawn to the 1st if the en passant flag is set, and vice versa for black pawns. | ||
# For example | ||
# If White advances a pawn two squares, the opponent may answer with a special move called an en passant capture. | ||
And this doesn’t actually explain it. The idea is that if a pawn is next to another piece and moves diagonally so that it is still next to it, then it removes that piece, right? I’d just Google around and find a nice simple explanation.
||
# To show this, we record the white pawn on the 1st row of its column (at `dest_square`) instead of on the square where it actually stands. | ||
I don’t get why we denote it differently with 1, dest square. Is that just because it’s impossible for a pawn to go backwards, so this is the only encoding for a pawn being in en passant position? Is this something done by AlphaZero or others? Seems a bit odd.

Yes, a white pawn never walks to the 1st row. It is just a notation for the chance to make an en passant capture. It is written in the document

I reviewed the AlphaZero paper and the chess notation system called Forsyth-Edwards Notation (FEN). The AlphaZero paper mentions some special moves in chess such as queenside castling, but nothing about en passant capture. They did not explain why, but I guess the reason is that a neural network can retrieve the en passant flag from the board history. On the other hand, FEN is a popular notation in chess. I also see a function for it in the module

Back to our observation space: so far I have revised the en passant section to keep its functionality. However, to my knowledge it is not conventional. In addition, we already provide a board history from which a model can retrieve the en passant flag. Therefore, I would like to comment out this section of code.

Could you check other RL chess implementations such as openspiel or other papers? Just want to make sure it won’t mess things up for people. I think I linked it previously, but RLlib has a LeelaChessZero model; maybe we can check if that would work with this change? I can try to get in contact with the person who did the PR for that if need be. My initial thought is that it makes sense; a first-row pawn seems a very non-standard and roundabout way of handling it.

OK, I will check and inform you if I need help.

I copied this behavior from LC0, but I took a 2nd look at the AlphaZero paper and it definitely doesn't mention this. Presumably AlphaZero is relying on the board history to provide the necessary information for en passant, and the Lc0 developers decided to add this information explicitly.

If the current code behavior is also done by Leela Chess Zero, that seems reasonable to do imo. Having some way of accessing the information is better than none. Do you have thoughts on the FEN system mentioned above @benblack769?

Yeah, FEN's representation is nice because it is concise and unambiguous, but neural networks aren't necessarily the best at learning concise information; it can be better to duplicate information. For a code reference, here is how Lc0 decodes FEN into its internal representation of en passant: https://github.com/LeelaChessZero/lc0/blob/master/src/chess/board.cc#L1114
||
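For readers unfamiliar with the FEN convention discussed in this thread: FEN stores the en passant target as its fourth space-separated field, either a square name such as `e3` or `-`. A minimal sketch of reading that field, assuming the a1 = 0 ... h8 = 63 indexing that the squares in this diff follow (the helper name `fen_en_passant_square` is hypothetical and not part of this PR):

```python
def fen_en_passant_square(fen: str):
    """Return the en passant target square of a FEN string as an index 0-63, or None.

    Assumed index convention: a1 = 0, b1 = 1, ..., h8 = 63.
    """
    fields = fen.split()
    target = fields[3]  # 4th FEN field: en passant target square, or "-"
    if target == "-":
        return None
    file_index = ord(target[0]) - ord("a")  # "a".."h" -> 0..7
    rank_index = int(target[1]) - 1         # "1".."8" -> 0..7
    return 8 * rank_index + file_index


# Position after 1. e4: the pawn passed over e3, so e3 is the en passant target.
after_e4 = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPPPPPP/RNBQKBNR b KQkq e3 0 1"
# Starting position: no en passant target.
start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

print(fen_en_passant_square(after_e4))  # 20
print(fen_en_passant_square(start))     # None
```

The value 20 is below 32, which is the case the `square < 32` test in the patched code treats as a white pawn having just advanced two squares.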
square = board.ep_square # potential en passant target square (int), or None if there is none | ||
if square: | ||
ours = square > 32 | ||
ours = ( | ||
square < 32 | ||
) # A square below 32 is on White's half of the board (a white pawn just advanced two squares); otherwise a black pawn did | ||
row = square % 8 | ||
dest_col_add = 8 * 7 if ours else 0 | ||
dest_col_add = 0 if ours else 8 * 7 | ||
dest_square = dest_col_add + row | ||
if ours: | ||
result[base + 0].remove(square - 8) | ||
result[base + 0].add(dest_square) | ||
result[base + 0].remove( | ||
square + 8 | ||
) # Set the `square + 8` position in channel `base` to 0 | ||
result[base + 0].add( | ||
dest_square | ||
) # Set the `dest_square` position in channel `base` to 1 | ||
else: | ||
result[base + 6].remove(square + 8) | ||
result[base + 6].add(dest_square) | ||
result[base + 6].remove( | ||
square - 8 | ||
) # Set the `square - 8` position in channel `base + 6` to 0 | ||
result[base + 6].add( | ||
dest_square | ||
) # Set the `dest_square` position in channel `base + 6` to 1 | ||
|
||
return boards_to_ndarray(result) |
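The en passant branch above can be hard to follow in diff form. The sketch below mirrors its logic on plain Python sets of square indices (a1 = 0 ... h8 = 63) instead of the square sets used in the PR. The function name `encode_en_passant` and its arguments are hypothetical and chosen for illustration only.

```python
def encode_en_passant(ep_square, white_pawns, black_pawns):
    """Move the pawn that just advanced two squares onto its own back rank.

    Plain sets stand in for the pawn channels `base + 0` (white) and `base + 6` (black).
    """
    white_pawns, black_pawns = set(white_pawns), set(black_pawns)
    if ep_square is None:
        return white_pawns, black_pawns
    col = ep_square % 8
    if ep_square < 32:
        # Rank 3: a white pawn just advanced two squares and stands on ep_square + 8.
        white_pawns.remove(ep_square + 8)
        white_pawns.add(col)  # first rank, same file
    else:
        # Rank 6: a black pawn just advanced two squares and stands on ep_square - 8.
        black_pawns.remove(ep_square - 8)
        black_pawns.add(8 * 7 + col)  # eighth rank, same file
    return white_pawns, black_pawns


# After 1. e4 the en passant square is e3 (20) and the pawn stands on e4 (28).
white, black = encode_en_passant(20, {28}, set())
print(sorted(white))  # [4]
```

A pawn can never legally stand on its own back rank, so a pawn bit there is unambiguous as a marker that an en passant capture may be available on that file.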
Great job on this, super thorough and well written, and good thinking to include the option for if people do want to observe the flipped board for black.
Thank you for reviewing. It is my pleasure.