Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for loading PyTorch .pt (weights/states) files directly to model's record #1085

Merged
merged 84 commits into from
Jan 25, 2024

Conversation

antimora
Copy link
Collaborator

@antimora antimora commented Dec 19, 2023

This PR introduces a new feature for loading PyTorch .pt files using Recorder.

Key Highlights:

  1. Loading .pt Files:

    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("mypytorchweights.pt".into())
        .expect("Failed to decode state");
    
    let model = MyModel::<B>::new_with(record);
  2. Remapping Levels/Keys:
    Aligning the source model's levels/keys with the target model (detailed explanation in the book):

    let load_args = LoadArgs::new("mypytorchweights.pt".into())
        .with_key_remap("conv\\.(.*)", "$1"); // Removes "conv" prefix, e.g., "conv.conv1" -> "conv1"
    
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args)
        .expect("Failed to decode state");
    
    let model = MyModel::<Backend>::new_with(record);
  3. Compatibility with Burn's Modules:
    The loader uses special adapters to match Burn's NN module structure. For example, Linear's weight is stored as [in, out] in Burn, whereas PyTorch stores it as [out, in]. This implementation facilitates ongoing development in Burn without needing to conform to PyTorch's format.

  4. Dynamic Loading and Deployment:
    Dynamic loading from PyTorch files is not necessary for deployments. The pytorch-import example shows how to convert a pt file to Burn's format during build time, thus removing the need to link to the Pickle library (candle-core in our case).

Current Limitations:

  1. Candle's pickle library does not currently function on Windows due to this bug: Candle Issue #1454. Burn issue ticket: Candle's pickle library does not currently function on Windows due to Candle bug #1178
  2. Candle's pickle does not currently unpack boolean tensors. Burn issue ticket: Candle's pickle does not currently unpack boolean tensors #1179.

Pull Request Template

Checklist

  • Confirmed execution of the run-checks all script.
  • Ensuring the book is updated with the changes in this PR. (TODO: I am still working on this)

Changes

  1. Enhanced burn-import to add the PyTorch feature with PyTorchFileRecorder.
  2. Refactored record implementation for [T;N] to avoid conversion into Vec[T] in primitive.rs. Instead, implemented serde for the [T;N] type. This is necessary to set default values for [T;N], e.g., kernel_size in conv2d. It's impractical to set a default value for Vec[T] and convert to [T;N] because the default for a Vec is an empty vector, which cannot be converted to [T;N].
  3. Added a record-serde feature in burn-core that can take a NestedValue object and deserialize it into a RecordItem. This is used by PyTorchFileRecorder in burn-import and can also be utilized by other formats, e.g., Safetensor.

Testing

Added a pytorch-tests sub-crate in burn-import that thoroughly tests all NN models and scenarios using actual PyTorch pt files.

@antimora antimora changed the title [WIP] Add support for import pytorch .pt files [WIP] Add support for importing pytorch .pt files using burn-import Dec 19, 2023
Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job! It's a lot of work, and I'm sure it's going to be very helpful to a lot of users!

Regarding the linear conversion: I'm unsure we should "force" our modules to store the weights the same way as PyTorch. What I'm sure about is that it would be very cool to have a way to load records of different versions and apply migration. I'm going to propose something soon regarding that.

burn-derive/src/record/codegen_struct.rs Outdated Show resolved Hide resolved
@antimora
Copy link
Collaborator Author

@nathanielsimard @louisfd

Regarding the linear conversion: I'm unsure we should "force" our modules to store the weights the same way as PyTorch.

Yes, I agree it's not ideal, as my solution would not scale to other formats and would constrain our design choices going forward.

What I'm sure about is that it would be very cool to have a way to load records of different versions and apply migration. I'm going to propose something soon regarding that.

I agree, that would indeed be very cool. It would be even cooler if we can still keep build conversion.

I have two possible solutions (A and B):

"A" solution:

  1. A .pt file (PyTorch model file) is converted during the build and it does not have to be aware of the target model - only during loading. This design choice will allow for build time translation or using a CLI tool. This is currently accomplished in my current implementation. Pre-converting as opposed to converting on the fly a) allows for doing some work in advance, b) eliminates a .pt runtime dependency, c) allows loading a subset of weights (e.g., load only encoder and not decoder).
  2. During record loading, there is additional (minimal) translation because we can match the name & location of tensors. This is when we have the opportunity to know the target modules. This would mean we have to implement a custom load function. Something like the following:
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
    .load_pytorch_translated(file_path)
    .expect("Failed to decode state");

The caveat of the minimal translation is that there is still run-time conversion (e.g., transposing).

"B" solution:

A model's record is loaded from a .pt file but could be re-saved to Burn's file format. This would allow knowing what parts go to target modules (e.g., Linear). However, I am not sure if this would be achievable using build.rs because of the circular dependency. A CLI tool becomes out of the question because the target model info will be missing.

@antimora
Copy link
Collaborator Author

@nathanielsimard and I had an offline conversation, and here is the revised summary:

  1. We agreed that the primary goal is to effectively integrate PyTorch weights into the Burn framework while maintaining independence from PyTorch's structural constraints. This involves developing mechanisms for importing, patching, and handling weights and module structures in a way that aligns with Burn's unique architecture.

  2. The generated "Record" struct will provide essential information about the target module, including its hierarchical position, name, and module type (e.g., Linear or BatchNorm).

  3. For PyTorch integration, we will use a PyTorchFileRecorder that functions as follows:

    let record: MyModelRecord = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(file_path)
        .expect("Failed to load .pt file");
    let model = MyModel::<Backend>::new_with(record);
  4. MyModelRecord is a record type that can be saved in various Burn formats using the existing recorders, such as NamedMpkFileRecorder, PrettyJsonFileRecorder, BinFileRecorder, etc.

This solution achieves decoupling and offers the following advantages:

  1. It enables dynamic or build-time conversion.
  2. It can be implemented by others in addition to PyTorch.
  3. It enhances accuracy as users are not required to tag module types manually.
  4. It remains flexible, allowing for changes in module names from the source.

Copy link
Collaborator

@Luni-4 Luni-4 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your work @antimora!

Just an advice to:

  • Remove a depencendy
  • Use an API interface more path-based and less str-based

burn-import/src/bin/pytorch2burn.rs Outdated Show resolved Hide resolved
burn-import/src/bin/pytorch2burn.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/converter.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/converter.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/converter.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/remapping.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/converter.rs Outdated Show resolved Hide resolved
burn-import/src/pytorch/remapping.rs Outdated Show resolved Hide resolved
Cargo.toml Outdated Show resolved Hide resolved
@antimora
Copy link
Collaborator Author

Just to update everyone. I have a solution that will accomplish what @nathanielsimard and I discussed. I researched serde extensively and it's possible to achieve only through a custom deserializer. No code change in the core or derived required.

Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@Luni-4 Luni-4 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me too! Thanks a lot for your hard work! 😃

Just a final question, the TODO in the first message will be covered in a next PR or is this an intermediate review?

@nathanielsimard nathanielsimard merged commit 0368409 into tracel-ai:main Jan 25, 2024
13 of 14 checks passed
@nathanielsimard nathanielsimard deleted the import-torch branch January 25, 2024 15:20
@antimora
Copy link
Collaborator Author

Fine for me too! Thanks a lot for your hard work! 😃

Just a final question, the TODO in the first message will be covered in a next PR or is this an intermediate review?

The documentation and filing TODOs will be done next (new PR). The TODO comment that you had found regarding Conv group testing is removed because I am testing similar aspects with kernel_size > 1.

/// [Replacement](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) for the
/// replacement syntax.
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
let regex = Regex::new(&format!("^{}$", pattern)).unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's a good idea to insert ^ and $ in this place, it will mislead the user.
And there are many use cases where ^ and $ need to be removed. For example I tried to rename all keys like conv.0 to conv0 because I used some workaround to implement PyTorch's Sequential in Burn. I want to do this for the whole model, it would be convenient if I could use .with_key_remap(r#"([a-z]+)\.(\d+)"#, "$1$2").

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. If you could submit a quick PR or issue, we will merge/fix it.

Thanks for letting us know.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nikaidou-Shinku I submitted a PR fix: #1196

let mut new_name = name.clone();
for (pattern, replacement) in &key_remap {
if pattern.is_match(&name) {
new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since LoadArgs::with_key_remap inserts ^ and $, replace_all has no effect at all here.

@antimora antimora mentioned this pull request Jan 30, 2024
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants