-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add multi-output and multilabel test * switch to `expect_no_error()` for readability * `momentum` consistently default to 0.02 switch target `y` from vector to array * turn `output_dim` into vector when multi_output manage loss cases for multi_output lint and refactor code * pass `is_multi_outcome` to predict encode output_dim for multi-outcome improve multi-outcome classification loss split predict based on `is_multi_outcome` * working predict_impl_class and predict_numeric switch to hardhat v1.3.0 * refactor predict_impl_ for a clearer case_when() call * improve `check_type` to manage multi-outcome fix tests vqlues add mixed-outcome and multi-outcome with valid test * add consistency checks for outcome types * add multi-output description in `tabnet-fit` move multi-output tests is a dedicated file * improve multi-outcome tests fix multi-outcome classification
- Loading branch information
Showing
33 changed files
with
841 additions
and
512 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,5 @@ | |
^cran-comments\.md$ | ||
^CRAN-RELEASE$ | ||
^.V8* | ||
^doc$ | ||
^Meta$ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,6 @@ inst/doc | |
.venv | ||
activate | ||
.V8history | ||
/doc/ | ||
/Meta/ | ||
revdep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: tabnet | ||
Title: Fit 'TabNet' Models for Classification and Regression | ||
Version: 0.3.0.9000 | ||
Version: 0.4.0 | ||
Authors@R: c( | ||
person(given = "Daniel", family = "Falbel", role = c("aut"), email = "[email protected]"), | ||
person(family = "RStudio", role = c("cph")), | ||
|
@@ -19,7 +19,7 @@ URL: https://github.com/mlverse/tabnet | |
BugReports: https://github.com/mlverse/tabnet/issues | ||
Imports: | ||
torch (>= 0.4.0), | ||
hardhat, | ||
hardhat (>= 1.3.0), | ||
magrittr, | ||
glue, | ||
progress, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.