Add support for loading PyTorch .pt (weights/states) files directly to model's record #1085

Conversation
Loading PyTorch `.pt` files using `burn-import`
Great job! It's a lot of work, and I'm sure it's going to be very helpful to a lot of users!
Regarding the linear conversion: I'm unsure we should "force" our modules to store the weights the same way as PyTorch. What I'm sure about is that it would be very cool to have a way to load records of different versions and apply migration. I'm going to propose something soon regarding that.
Yes, I agree it's not ideal, as my solution would not scale to other formats and would constrain our design choices going forward.
I agree, that would indeed be very cool. It would be even cooler if we could still keep build-time conversion. I have two possible solutions (A and B).

Solution A:

```rust
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
    .load_pytorch_translated(file_path)
    .expect("Failed to decode state");
```

The caveat of the minimal translation is that there is still run-time conversion (e.g., transposing).

Solution B: A model's record is loaded from a …
@nathanielsimard and I had an offline conversation, and here is the revised summary:
This solution achieves decoupling and offers the following advantages:
Thanks a lot for your work @antimora!
Just a couple of suggestions:
- Remove a dependency
- Use an API interface that is more path-based and less str-based
Just to update everyone: I have a solution that will accomplish what @nathanielsimard and I discussed. I researched serde extensively, and this is possible to achieve only through a custom deserializer. No code changes in the core or derive crates are required.
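To make the custom-deserializer idea concrete, here is a minimal, generic serde sketch (not Burn's actual `record-serde` code): it drives the existing derived `Deserialize` impl of a record-like struct from an in-memory map of values, so the target type and the derive stay untouched. The struct and field names are hypothetical.

```rust
use std::collections::HashMap;

use serde::de::value::{Error as DeError, MapDeserializer};
use serde::Deserialize;

// Hypothetical, already-derived record item; its Deserialize impl is not modified.
#[derive(Debug, Deserialize)]
struct LinearRecordItem {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

fn main() -> Result<(), DeError> {
    // A flat stand-in for a nested-value representation of tensor data.
    let mut values: HashMap<String, Vec<f32>> = HashMap::new();
    values.insert("weight".into(), vec![0.1, 0.2, 0.3, 0.4]);
    values.insert("bias".into(), vec![0.0, 0.0]);

    // A custom deserializer (here serde's built-in MapDeserializer) feeds the
    // derived Deserialize impl directly from the in-memory values.
    let deserializer: MapDeserializer<'_, _, DeError> = MapDeserializer::new(values.into_iter());
    let item = LinearRecordItem::deserialize(deserializer)?;
    println!("{item:?}");
    Ok(())
}
```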
This does not build yet, but it has enough progress to make the import work. I am committing it so as not to lose it.
LGTM
Fine for me too! Thanks a lot for your hard work! 😃
Just a final question: will the TODO in the first message be covered in a next PR, or is this an intermediate review?
The documentation and the filing of TODOs will be done next (in a new PR). The TODO comment that you found regarding Conv group testing was removed because I am testing similar aspects with kernel_size > 1.
```rust
/// [Replacement](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace) for the
/// replacement syntax.
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
    let regex = Regex::new(&format!("^{}$", pattern)).unwrap();
```
I don't think it's a good idea to insert `^` and `$` in this place; it will mislead the user. And there are many use cases where `^` and `$` need to be removed. For example, I tried to rename all keys like `conv.0` to `conv0` because I used a workaround to implement PyTorch's `Sequential` in Burn. I want to do this for the whole model, and it would be convenient if I could use `.with_key_remap(r#"([a-z]+)\.(\d+)"#, "$1$2")`.
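A small standalone demonstration of the issue (using the `regex` crate directly, outside of burn-import): once the pattern is wrapped in `^`/`$`, `replace_all` can only rewrite keys that match in their entirety, so partial renames inside longer keys never happen.

```rust
use regex::Regex;

fn main() {
    let key = "encoder.conv.0.weight";

    // Anchored, as `with_key_remap` currently builds it: the whole key would
    // have to equal the pattern, so nothing is replaced here.
    let anchored = Regex::new(r"^([a-z]+)\.(\d+)$").unwrap();
    assert_eq!(anchored.replace_all(key, "$1$2"), "encoder.conv.0.weight");

    // Unanchored: every occurrence is rewritten, which is what the
    // `conv.0` -> `conv0` rename needs.
    let unanchored = Regex::new(r"([a-z]+)\.(\d+)").unwrap();
    assert_eq!(unanchored.replace_all(key, "$1$2"), "encoder.conv0.weight");
}
```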
Sounds good. If you could submit a quick PR or issue, we will merge/fix it.
Thanks for letting us know.
@Nikaidou-Shinku I submitted a PR fix: #1196
```rust
let mut new_name = name.clone();
for (pattern, replacement) in &key_remap {
    if pattern.is_match(&name) {
        new_name = pattern.replace_all(&name, replacement.as_str()).to_string();
```
Since `LoadArgs::with_key_remap` inserts `^` and `$`, `replace_all` has no effect at all here.
This PR introduces a new feature for loading PyTorch `.pt` files using `Recorder`.

Key Highlights:

Loading `.pt` Files:

Remapping Levels/Keys: Aligning the source model's levels/keys with the target model (detailed explanation in the book).
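A minimal usage sketch of the remapping API, assuming `LoadArgs` and `PyTorchFileRecorder` are exposed by `burn-import`'s new `pytorch` feature; `NetRecord` and `MyBackend` are hypothetical stand-ins for a user model's generated record type and backend, and the exact `load` signature may differ from what this PR finally ships:

```rust
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

// Hypothetical: `NetRecord<MyBackend>` is the record type generated for the
// user's model on whichever backend it runs.
let load_args = LoadArgs::new("weights/net.pt".into())
    // Rename source keys such as `fc1.weight` to the target's `linear1.weight`.
    .with_key_remap("fc(\\d+)\\.(.+)", "linear$1.$2");

let record: NetRecord<MyBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args)
    .expect("Failed to load PyTorch weights into the model record");
```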
Compatibility with Burn's Modules: The loader uses special adapters to match Burn's NN module structure. For example, Linear's weight is stored as `[in, out]` in Burn, whereas PyTorch stores it as `[out, in]`. This implementation facilitates ongoing development in Burn without needing to conform to PyTorch's format.
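As a plain-Rust illustration of what such an adapter has to do for Linear (not the actual adapter code), the PyTorch `[out, in]` matrix is transposed into Burn's `[in, out]` layout while loading:

```rust
// Transpose a 2D matrix stored as rows of Vec<f32>.
fn transpose(weight: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (rows, cols) = (weight.len(), weight[0].len());
    (0..cols)
        .map(|c| (0..rows).map(|r| weight[r][c]).collect())
        .collect()
}

fn main() {
    // A PyTorch-style Linear weight with out = 2, in = 3.
    let pytorch_weight = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    // After transposing, the shape is [in = 3, out = 2], as Burn expects.
    let burn_weight = transpose(&pytorch_weight);
    assert_eq!(
        burn_weight,
        vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]
    );
}
```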
Dynamic Loading and Deployment: Dynamic loading from PyTorch files is not necessary for deployments. The `pytorch-import` example shows how to convert a `.pt` file to Burn's format during build time, thus removing the need to link to the Pickle library (candle-core in our case).
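A hedged sketch of the build-time idea (modelled on the description of the `pytorch-import` example above, not copied from it): a `build.rs` reads the `.pt` file once with `PyTorchFileRecorder` and re-saves it with `NamedMpkFileRecorder`, so the shipped binary only ever loads Burn's native format. `NetRecord`, `MyBackend`, the file paths, and the exact recorder signatures are assumptions.

```rust
// build.rs (sketch)
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

fn main() {
    // Read the PyTorch weights once at build time. `NetRecord`/`MyBackend`
    // are hypothetical stand-ins for the user's model record and backend.
    let record: NetRecord<MyBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("weights/net.pt".into())
        .expect("Failed to read the .pt file");

    // Re-save in Burn's named MessagePack format; the runtime binary then
    // loads this file and never links the Pickle reader (candle-core).
    NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .record(record, "out/net".into())
        .expect("Failed to save the Burn record");
}
```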
Current Limitations:

Pull Request Template

Checklist

- Confirmed that the `run-checks all` script has been executed.

Changes
- Modified `burn-import` to add the PyTorch feature with `PyTorchFileRecorder`.
- Used `[T;N]` to avoid conversion into `Vec[T]` in `primitive.rs`. Instead, implemented serde for the `[T;N]` type. This is necessary to set default values for `[T;N]`, e.g., `kernel_size` in `conv2d`. It's impractical to set a default value for `Vec[T]` and convert it to `[T;N]`, because the default for a `Vec` is an empty vector, which cannot be converted to `[T;N]` (a short demonstration follows this list).
- Added a `record-serde` feature in `burn-core` that can take a `NestedValue` object and deserialize it into a `RecordItem`. This is used by `PyTorchFileRecorder` in `burn-import` and can also be utilized by other formats, e.g., Safetensors.
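Regarding the `[T;N]` default issue in the second item above, a small standalone demonstration of why an empty default `Vec` cannot be turned into a fixed-size array, while `[T; N]` can carry a meaningful default directly:

```rust
fn main() {
    // Vec's Default is the empty vector, and converting it to [T; N] fails
    // whenever the lengths differ.
    let empty: Vec<usize> = Vec::default();
    let as_array: Result<[usize; 2], _> = empty.try_into();
    assert!(as_array.is_err());

    // A fixed-size array can simply carry the default value itself,
    // e.g. a kernel_size default of [1, 1].
    let kernel_size: [usize; 2] = [1, 1];
    assert_eq!(kernel_size, [1, 1]);
}
```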
Testing

Added a `pytorch-tests` sub-crate in `burn-import` that thoroughly tests all NN models and scenarios using actual PyTorch `.pt` files.