multi-outcome tabnet-fit and predict (#118)
* add multi-output and multilabel test

* switch to `expect_no_error()` for readability

* `momentum` consistently defaults to 0.02
switch target `y` from vector to array

* turn `output_dim` into vector when multi_output
manage loss cases for multi_output
lint and refactor code

* pass `is_multi_outcome` to predict
encode output_dim for multi-outcome
improve multi-outcome classification loss
split predict based on `is_multi_outcome`

* working predict_impl_class and predict_numeric
switch to hardhat v1.3.0

* refactor predict_impl_ for a clearer case_when() call

* improve `check_type` to manage multi-outcome
fix test values
add mixed-outcome and multi-outcome with valid test

* add consistency checks for outcome types

* add multi-output description in `tabnet-fit`
move multi-output tests to a dedicated file

* improve multi-outcome tests
fix multi-outcome classification
cregouby authored May 9, 2023
1 parent 162134c commit e5c5306
Showing 33 changed files with 841 additions and 512 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
@@ -9,3 +9,5 @@
^cran-comments\.md$
^CRAN-RELEASE$
^.V8*
^doc$
^Meta$
108 changes: 24 additions & 84 deletions .github/workflows/R-CMD-check.yaml
@@ -35,66 +35,33 @@ jobs:
TORCH_TEST: 1

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- uses: r-lib/actions/setup-r@v1
- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}

- uses: r-lib/actions/setup-pandoc@v1

- name: Query dependencies
run: |
install.packages('remotes')
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
shell: Rscript {0}
- uses: r-lib/actions/setup-pandoc@v2

- name: Cache R packages
if: runner.os != 'Windows'
uses: actions/cache@v2
- uses: r-lib/actions/setup-r-dependencies@v2
with:
path: ${{ env.R_LIBS_USER }}
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-

- name: Install system dependencies
if: runner.os == 'Linux'
run: |
while read -r cmd
do
eval sudo $cmd
done < <(Rscript -e 'writeLines(remotes::system_requirements("ubuntu", "20.04"))')
- name: Install macOS dependencies
if: runner.os == 'macOS'
run: brew install --cask xquartz

- name: Install dependencies
run: |
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
shell: Rscript {0}

- name: Check
env:
_R_CHECK_CRAN_INCOMING_REMOTE_: false
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "warning", check_dir = "check")
shell: Rscript {0}

- name: Upload check results
if: failure()
uses: actions/upload-artifact@main
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
with:
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
path: check
error-on: '"error"'
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'

GPU:
runs-on: ['self-hosted', 'gce', 'gpu']
name: 'gpu'

container:
image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04
options: --gpus all
image: 'nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04'
options: '--gpus all --runtime=nvidia'

timeout-minutes: 120

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
@@ -105,49 +72,22 @@ jobs:
DEBIAN_FRONTEND: 'noninteractive'

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- run: |
apt-get update -y
apt-get install -y sudo software-properties-common dialog apt-utils tzdata
apt-get install -y sudo software-properties-common dialog apt-utils tzdata libpng-dev
- uses: r-lib/actions/setup-r@v2
with:
r-version: 'release'

- uses: r-lib/actions/setup-pandoc@v2

- name: Query dependencies
run: |
install.packages('remotes')
saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
shell: Rscript {0}

- name: Cache R packages
if: runner.os != 'Windows'
uses: actions/cache@v2
- uses: r-lib/actions/setup-r-dependencies@v2
with:
path: ${{ env.R_LIBS_USER }}
key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-

- name: Install dependencies
run: |
remotes::install_deps(dependencies = TRUE)
remotes::install_cran("rcmdcheck")
shell: Rscript {0}

- name: Check
env:
_R_CHECK_CRAN_INCOMING_REMOTE_: false
run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran"), error_on = "error", check_dir = "check")
shell: Rscript {0}

- name: Upload check results
if: failure()
uses: actions/upload-artifact@main
with:
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
path: check
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
with:
error-on: '"error"'
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'
55 changes: 27 additions & 28 deletions .github/workflows/test-coverage.yaml
@@ -13,8 +13,10 @@ jobs:
runs-on: ['self-hosted', 'gce', 'gpu']

container:
image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04
options: --gpus all
image: 'nvidia/cuda:11.6.0-cudnn8-devel-ubuntu18.04'
options: '--gpus all --runtime=nvidia'

timeout-minutes: 120

env:
RSPM: https://packagemanager.rstudio.com/cran/__linux__/bionic/latest
@@ -24,42 +26,39 @@
DEBIAN_FRONTEND: 'noninteractive'

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- run: |
apt-get update -y
apt-get install -y sudo software-properties-common dialog apt-utils tzdata
- uses: r-lib/actions/setup-r@v1
id: install-r
- name: Install pak and query dependencies
run: |
install.packages("pak", repos = "https://r-lib.github.io/p/pak/dev/")
saveRDS(pak::pkg_deps("local::.", dependencies = TRUE), ".github/r-depends.rds")
shell: Rscript {0}
- uses: r-lib/actions/setup-r@v2

- name: Restore R package cache
uses: actions/cache@v2
- uses: r-lib/actions/setup-r-dependencies@v2
with:
path: |
${{ env.R_LIBS_USER }}/*
!${{ env.R_LIBS_USER }}/pak
key: ubuntu-18.04-${{ steps.install-r.outputs.installed-r-version }}-1-${{ hashFiles('.github/r-depends.rds') }}
restore-keys: ubuntu-18.04-${{ steps.install-r.outputs.installed-r-version }}-1-
extra-packages: |
any::covr
- name: Install system dependencies
if: runner.os == 'Linux'
- name: Test coverage
run: |
pak::local_system_requirements(execute = TRUE)
pak::pkg_system_requirements("covr", execute = TRUE)
covr::codecov(
quiet = FALSE,
clean = FALSE,
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package")
)
shell: Rscript {0}

- name: Install dependencies
- name: Show testthat output
if: always()
run: |
pak::local_install_dev_deps(upgrade = TRUE)
pak::pkg_install("covr")
shell: Rscript {0}
## --------------------------------------------------------------------
find ${{ runner.temp }}/package -name 'testthat.Rout*' -exec cat '{}' \; || true
shell: bash

- name: Upload test results
if: failure()
uses: actions/upload-artifact@v3
with:
name: coverage-test-failures
path: ${{ runner.temp }}/package

- name: Test coverage
run: covr::codecov()
shell: Rscript {0}
3 changes: 3 additions & 0 deletions .gitignore
@@ -5,3 +5,6 @@ inst/doc
.venv
activate
.V8history
/doc/
/Meta/
revdep
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: tabnet
Title: Fit 'TabNet' Models for Classification and Regression
Version: 0.3.0.9000
Version: 0.4.0
Authors@R: c(
person(given = "Daniel", family = "Falbel", role = c("aut"), email = "[email protected]"),
person(family = "RStudio", role = c("cph")),
@@ -19,7 +19,7 @@ URL: https://github.com/mlverse/tabnet
BugReports: https://github.com/mlverse/tabnet/issues
Imports:
torch (>= 0.4.0),
hardhat,
hardhat (>= 1.3.0),
magrittr,
glue,
progress,
8 changes: 4 additions & 4 deletions NEWS.md
@@ -1,5 +1,3 @@
# tabnet (development version)

# tabnet 0.4.0

## New features
@@ -11,18 +9,20 @@
* Allow missing values in predictors for unsupervised training. (#68)
* Improve performance of `random_obfuscator()` torch_nn module. (#68)
* Add support for early stopping (#69)
* `tabnet_fit()` and `predict()` now allow missing values in predictors. (#76)
* `tabnet_fit()` and `predict()` now allow **missing values** in predictors. (#76)
* `tabnet_config()` now supports a `num_workers=` parameter to control parallel dataloading (#83)
* Add a vignette on missing data (#83)
* `tabnet_config()` now has a flag `skip_importance` to skip calculating feature importance (@egillax, #91)
* Export and document `tabnet_nn`
* Added `min_grid.tabnet` method for `tune` (@cphaarmeyer, #107)
* Added `tabnet_explain()` method for parsnip models (@cphaarmeyer, #108)
* `tabnet_fit()` and `predict()` now allow **multi-outcome** models; outcomes must be all numeric or all factors, not mixed. (#118)

## Bugfixes

* `tabnet_explain()` is now correctly handling missing values in predictors. (#77)
* `dataloader` can now use `num_workers>0` (#83)
* new default values for `batch_size` and `virtual_batch_size` do not limit performance on mid-range devices.
* new default values for `batch_size` and `virtual_batch_size` improve performance on mid-range devices.
* add default `engine="torch"` to tabnet parsnip model (#114)
* fix `autoplot()` warnings turned into errors with {ggplot2} v3.4 (#113)
