Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update README and config files with Moirai-1.1-R and fix Moirai-1.0-R model weights #111

Merged
merged 1 commit into from
Aug 22, 2024

Conversation

gorold
Copy link
Contributor

@gorold gorold commented Aug 19, 2024

Made PRs on huggingface to fix the regression of Moirai-1.0-R model weights with the recent code changes
https://huggingface.co/Salesforce/moirai-1.0-R-small/discussions/7
https://huggingface.co/Salesforce/moirai-1.0-R-base/discussions/6
https://huggingface.co/Salesforce/moirai-1.0-R-large/discussions/6

Please help to test, I only managed to run some small tests on my local :)

The changes to model weights were just:

mm = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small")
mm.param_proj.proj.weights_logits.weight[:] = torch.roll(mm.param_proj.proj.weights_logits.weight, 2, dims=0)
mm.param_proj.proj.weights_logits.bias[:] = torch.roll(mm.param_proj.proj.weights_logits.bias, 2, dims=0)
mm.push_to_hub("Salesforce/moirai-1.0-R-small")

@gorold gorold changed the title Update README and config files with Moirai-R-1.1 and fix 1.0 model weights Update README and config files with Moirai-1.1-R and fix Moirai-1.0-R model weights Aug 19, 2024
@liu-jc
Copy link
Contributor

liu-jc commented Aug 20, 2024

Hi @gorold,

Sorry, I don't quite understand. How these changes with torch.roll can help to fix the weights? We still have / dim in our codebase?

torch.eq(out_feat_size, feat_size // self.dim).unsqueeze(-1)

@chenghaoliu89 chenghaoliu89 requested a review from liu-jc August 20, 2024 02:02
@gorold
Copy link
Contributor Author

gorold commented Aug 20, 2024

  1. The only weights affected are those where self.dim > 1, which in this case is just param_proj.proj.weights_logits.
  2. What happened during training of Moirai 1.0 is that we had weights_logits.out_features_ls = (32, 64, 128, 256, 512), and did not have feat_size // self.dim leading to this layer learning the weights for patch size 32 in the first position, 64 in the second position, and so on.
  3. This meant that weights_logits.weight[0:3] contain the correct weights for patch sizes 32, 64, 128, whereas the weights_logits.weight[3:5] are useless.
  4. So we use torch.roll to push the weights into the correct positions for patch sizes 32, 64, 128, and leave 256, 512 in the first two positions as dummies.

Copy link
Contributor

@liu-jc liu-jc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I got it. Do you think we should test all models and see if we can get the similar results on all datasets?

@gorold
Copy link
Contributor Author

gorold commented Aug 21, 2024

Ran it on the PF benchmark ex weather:

Small

with fix:

index dataset test_metrics/MSE[mean] test_metrics/MSE[0.5] test_metrics/MAE[0.5] test_metrics/MASE[0.5] test_metrics/MAPE[0.5] test_metrics/sMAPE[0.5] test_metrics/MSIS test_metrics/RMSE[mean] test_metrics/NRMSE[mean] test_metrics/ND[0.5] test_metrics/mean_weighted_sum_quantile_loss
0 electricity 3846122.0 4047497.75 219.01437377929688 0.9806302189826965 0.1304870843887329 0.13351434469223022 8.014351844787598 1961.1531982421875 0.8221926093101501 0.09181945025920868 0.07224071025848389
1 solar-energy 1229.113525390625 1439.4451904296875 19.23911476135254 1.4672776460647583 2.4206011295318604 1.444728970527649 8.45895767211914 35.0587158203125 1.138300895690918 0.6246635317802429 0.4719245433807373
2 walmart 27342336.0 19882204.0 2114.254150390625 0.9929503202438354 0.24787983298301697 0.17269687354564667 8.747377395629883 5228.990234375 0.29910993576049805 0.12094007432460785 0.09698692709207535
3 istanbul_traffic 135.40859985351562 154.15487670898438 8.99282455444336 1.0580122470855713 0.5734052658081055 0.3754161596298218 5.677515506744385 11.636520385742188 0.31070834398269653 0.24011866748332977 0.17352095246315002
4 turkey_power 755125.1875 758453.9375 358.3751525878906 0.9450287818908691 0.5046748518943787 0.3893551230430603 7.002157211303711 868.9793701171875 0.14972686767578125 0.061748750507831573 0.048235006630420685

without fix:

index dataset test_metrics/MSE[mean] test_metrics/MSE[0.5] test_metrics/MAE[0.5] test_metrics/MASE[0.5] test_metrics/MAPE[0.5] test_metrics/sMAPE[0.5] test_metrics/MSIS test_metrics/RMSE[mean] test_metrics/NRMSE[mean] test_metrics/ND[0.5] test_metrics/mean_weighted_sum_quantile_loss
0 electricity 23829644.0 22212912.0 560.96923828125 2.2790398597717285 0.33599036931991577 0.25199955701828003 12.980779647827148 4881.5615234375 2.0465428829193115 0.23518039286136627 0.17356589436531067
1 solar-energy 1965.3592529296875 2181.220458984375 23.293292999267578 1.7744004726409912 2.6420609951019287 1.4974911212921143 13.509282112121582 44.33237075805664 1.4394018650054932 0.7562963366508484 0.6452235579490662
2 walmart 31941660.0 22301972.0 2262.530517578125 1.0623531341552734 0.2722017765045166 0.1812155842781067 9.110936164855957 5651.6953125 0.3232896327972412 0.12942181527614594 0.10336102545261383
3 istanbul_traffic 202.15194702148438 205.9799346923828 11.197728157043457 1.317776083946228 1.6119800806045532 0.4647684693336487 7.559725761413574 14.21801471710205 0.379637211561203 0.2989920973777771 0.21395711600780487
4 turkey_power 5094314.5 2785156.5 762.2886352539062 2.0671753883361816 2.91452693939209 0.4795369803905487 25.759414672851562 2257.058837890625 0.388895720243454 0.13134384155273438 0.10627298802137375

Feel free to test on other sizes and datasets. You can directly download the model in the PR branch with the revision argument:

mm = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small", revision="pr/7")

Note that the PR number might be different for the other model sizes.

@gorold
Copy link
Contributor Author

gorold commented Aug 21, 2024

moirai-1.0-R-Base

with fix:

index dataset test_metrics/MSE[mean] test_metrics/MSE[0.5] test_metrics/MAE[0.5] test_metrics/MASE[0.5] test_metrics/MAPE[0.5] test_metrics/sMAPE[0.5] test_metrics/MSIS test_metrics/RMSE[mean] test_metrics/NRMSE[mean] test_metrics/ND[0.5] test_metrics/mean_weighted_sum_quantile_loss
0 electricity 1708712.5 1711201.625 164.1307373046875 0.7915405631065369 0.10031454265117645 0.11076250672340393 6.184873580932617 1307.17724609375 0.5480201840400696 0.06881006807088852 0.054687876254320145
1 solar-energy 1011.0943603515625 1108.408935546875 16.981399536132812 1.2911229133605957 2.296311855316162 1.4095485210418701 7.017038345336914 31.797710418701172 1.0324212312698364 0.5513591766357422 0.41874560713768005
2 walmart 26299296.0 19072352.0 2049.69384765625 0.9657745957374573 0.23114198446273804 0.1677016019821167 8.415294647216797 5128.28369140625 0.29334932565689087 0.117247074842453 0.09353204816579819
3 istanbul_traffic 37.16828918457031 40.77793884277344 4.562923431396484 0.5369675755500793 0.2586391270160675 0.2553446292877197 3.8279149532318115 6.0965800285339355 0.1627856343984604 0.12183525413274765 0.09833786636590958
4 turkey_power 473797.125 474377.5625 295.6066589355469 0.8949130177497864 0.16863861680030823 0.37849825620651245 6.532022476196289 688.3292236328125 0.11860048770904541 0.050933610647916794 0.040024567395448685

without fix:

index dataset test_metrics/MSE[mean] test_metrics/MSE[0.5] test_metrics/MAE[0.5] test_metrics/MASE[0.5] test_metrics/MAPE[0.5] test_metrics/sMAPE[0.5] test_metrics/MSIS test_metrics/RMSE[mean] test_metrics/NRMSE[mean] test_metrics/ND[0.5] test_metrics/mean_weighted_sum_quantile_loss
0 electricity 7544522.5 6037268.0 300.3915710449219 1.4005271196365356 0.1897597312927246 0.17114804685115814 8.438410758972168 2746.7294921875 1.1515370607376099 0.12593597173690796 0.09828455746173859
1 solar-energy 6939.76318359375 3234.031982421875 46.51622772216797 3.5624678134918213 15.757192611694336 1.4608030319213867 40.661155700683594 83.30523681640625 2.704789161682129 1.5103082656860352 1.162361741065979
2 walmart 20425108.0 19084788.0 1990.737548828125 0.9465618133544922 0.22026395797729492 0.16590428352355957 8.317343711853027 4519.41455078125 0.25852063298225403 0.1138746440410614 0.09144116938114166
3 istanbul_traffic 122.82603454589844 160.76173400878906 8.947684288024902 1.0514392852783203 1.8842673301696777 0.375012069940567 5.328226566314697 11.082691192626953 0.2959204614162445 0.23891335725784302 0.17149078845977783
4 turkey_power 7537457.0 2658505.75 731.1760864257812 1.9673036336898804 5.040554523468018 0.4480009973049164 29.197540283203125 2745.44287109375 0.4730452299118042 0.1259830892086029 0.11150600761175156

Copy link
Contributor

@liu-jc liu-jc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@liu-jc liu-jc merged commit 2e1e3fd into SalesforceAIResearch:main Aug 22, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants