Skip to content

Commit

Permalink
improve doc with overall architecture #49
Browse files Browse the repository at this point in the history
improve doc with overall architecture
  • Loading branch information
youkaichao authored Aug 26, 2024
2 parents 8c87a5c + d9ef463 commit 02ea32e
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions docs/dev_doc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,24 @@ Developer Documentation

For developers, if you want to understand and contribute to the codebase, this section is for you.

It is recommended to read the following materials before diving into the codebase:
Overall architecture of the library
-----------------------------------

Put it in short, the library is a Python bytecode decompiler with tight integration with PyTorch. It naturally falls into 2 parts:

* The decompiler: It decompiles Python bytecode into Python source code. The decompiler is implemented in the ``depyf/decompiler.py`` file. It can be used as a standalone library to decompile Python bytecode.
* The PyTorch integration: We work together with the Pytorch team to design a bytecode hook mechanism. We use ``torch._dynamo.convert_frame.register_bytecode_hook`` to register a hook to PyTorch. Every time PyTorch compiles a function, the hook will be called, and we can decompile the bytecode and dump the source code to disk. The source code is futher compiled into a new bytecode object, which is functionally equivalent to the original bytecode object but with source code information, making it easier to debug. PyTorch will use the new bytecode object to run the function. The integration logic is implemented in the ``depyf/explain/enable_debugging.py`` file.

Relatively speaking, the PyTorch integration part is easier to understand and contribute. Our main goal for the integration is to make ``depyf`` compatible with all previous versions of PyTorch starting from PyTorch 2.2 . To achieve this goal, the test is run against the nightly build of PyTorch. Whenever we find a compatibility issue, we will fix it in a backward-compatible way. If such a fix is not possible, we will discuss with the PyTorch team to find a solution.

The decompiler part is more challenging. It is complicated and needs to deal with all sorts of random Python implementation details. Fortunately, we only need to deal with official release versions of Python, which makes the task more manageable. The decompiler only needs to be updated once we find a bug or a new Python version is released.

If you want to dive deeper into the decompiler part, please go on reading.

Overview of the decompiler
--------------------------

To become comfortable with reading bytecode, it is recommended to read the following materials first:

- `torchdynamo deepdive <https://www.youtube.com/watch?v=egZB5Uxki0I>`_ : This video explains the motivation and design of torchdynamo. In particular, it mentions how Python bytecode acts like a stack machine, which helps to understand how the bytecode is executed.
- `Python bytecode documentation <https://docs.python.org/3/library/dis.html>`_ : This documentation explains the Python bytecode instructions. Note that Python bytecode does not guarentee any backward compatibility, so the bytecode instructions may change for every Python versions. We should consider all the supported Python versions when implementing the decompiler.
Expand All @@ -27,15 +44,14 @@ It has the following bytecode:
4 BINARY_ADD
6 RETURN_VALUE
When we execute the first bytecode `LOAD_FAST`, instead of loading a variable into the stack, we push the variable name ``"a"`` in the stack, which is a string representation of the variable.
When we execute the first bytecode ``LOAD_FAST``, instead of loading a variable into the stack, we push the variable name ``"a"`` in the stack, which is a string representation of the variable.

When we execute the second bytecode `LOAD_FAST`, likewise, we push the variable name ``"b"`` in the stack.
When we execute the second bytecode ``LOAD_FAST``, likewise, we push the variable name ``"b"`` in the stack.

When we execute the third bytecode `BINARY_ADD`, which intends to add the two variables, we pop the two variables from the stack, and perform the string concatenation ``"a + b"``. The concatenated string is pushed back to the stack.
When we execute the third bytecode ``BINARY_ADD``, which intends to add the two variables, we pop the two variables from the stack, and perform the string concatenation ``"a + b"``. The concatenated string is pushed back to the stack.

Finally, when we execute the fourth bytecode `RETURN_VALUE`, we pop the string from the stack, prefix it with the ``return`` keyword, and then we get the decompiled source code ``"return a + b"``.
Finally, when we execute the fourth bytecode ``RETURN_VALUE``, we pop the string from the stack, prefix it with the ``return`` keyword, and then we get the decompiled source code ``"return a + b"``.

To accurately decompile the bytecode, we need to faithfully respect the semantics of the Python bytecode instructions. It is noteworthy that the `Python bytecode documentation <https://docs.python.org/3/library/dis.html>`_ can be outdated and inaccurate, too. The golden standard is to refer to the CPython source code and the Python interpreter's behavior. The `torchdynamo source code <https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/symbolic_convert.py>`_ is also a good reference to understand how the Python bytecode is generated by PyTorch.

Should you have any further questions, feel free to ask in the `GitHub Issues <https://github.com/thuml/depyf/issues>`_ section.

0 comments on commit 02ea32e

Please sign in to comment.