-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
up to date with benchmark_model branch
- Loading branch information
1 parent
eac4f8d
commit ba4f271
Showing
4 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#' @rdname zeroshot_DMS_metrics
#'
#' @title Model performance metrics for DMS substitutions in the zero-shot
#'     setting
#'
#' @param metadata Logical, whether only experiment metadata should be
#'     returned. Default behavior (`FALSE`) is to return the processed data
#'     with metadata included.
#'
#' @details `zeroshot_DMS_metrics()` loads the five model performance metrics
#'     ("AUC", "MCC", "NDCG", "Spearman", "Top_recall") calculated on the
#'     DMS substitutions in the zero-shot setting.
#'
#'     Each data.frame contains the columns:
#'     - "DMS_ID": Assay name for each of the 217 DMS studies.
#'     - Columns 2:63: Average performance score of each of the
#'       61 models tested.
#'     - "Number_of_Mutants": Number of protein mutants evaluated.
#'     - "Selection_Type": Protein function grouping.
#'     - "UniProt_ID": UniProt protein entry name identifier.
#'     - "MSA_Neff_L_category": Multiple sequence alignment category.
#'     - "Taxon": Taxon group.
#'
#' @return Returns a [list()] object with five [data.frame()] elements, each
#'     corresponding to a model metric table.
#'
#' @references
#' Notin, P., Kollasch, A., Ritter, D., van Niekerk, L., Paul, S., Spinner, H.,
#' Rollins, N., Shaw, A., Orenbuch, R., Weitzman, R., Frazer, J., Dias, M.,
#' Franceschi, D., Gal, Y., & Marks, D. (2023). ProteinGym: Large-Scale
#' Benchmarks for Protein Fitness Prediction and Design. In A. Oh, T. Neumann,
#' A. Globerson, K. Saenko, M. Hardt, & S. Levine (Eds.), Advances in Neural
#' Information Processing Systems (Vol. 36, pp. 64331-64379).
#' Curran Associates, Inc.
#'
#' @examples
#' data <- zeroshot_DMS_metrics()
#' data_meta <- zeroshot_DMS_metrics(metadata = TRUE)
#'
#' @export
zeroshot_DMS_metrics <- function(metadata = FALSE) {
    ## Query ExperimentHub for the ProteinGym zero-shot DMS metrics resource
    eh <- ExperimentHub::ExperimentHub()
    ehid <- "EH9593"

    ## isTRUE() is robust to NA / non-scalar input, unlike `metadata == TRUE`
    if (isTRUE(metadata)) {
        ## Single-bracket subsetting returns the hub record (metadata only)
        eh[ehid]
    } else {
        ## Double-bracket subsetting downloads and loads the processed data
        eh[[ehid]]
    }
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# This file is part of the standard setup for testthat.
# It is recommended that you do not modify it.
#
# Where should you do additional test configuration?
# Learn more about the roles of various files in:
# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview
# * https://testthat.r-lib.org/articles/special-files.html

# Load the test framework and the package under test, then run every test
# file under tests/testthat/ for the ProteinGymR package.
library(testthat)
library(ProteinGymR)

test_check("ProteinGymR")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
test_that("check_metric_argument() works", {

    ## An unsupported metric name must raise an informative error.
    ## (Single-argument paste() wrappers removed — a literal string suffices.)
    expect_error(
        check_metric_argument(user_metric = "Pearson"),
        "Invalid model\\(s\\) specified: Pearson"
    )

    ## Only one metric may be supplied per comparison
    expect_error(
        check_metric_argument(user_metric = c("AUC", "MCC")),
        "Select only one metric for comparison"
    )
})
|
||
test_that("check_model_argument() works", {

    ## An unknown model name must raise an informative error.
    ## (Single-argument paste() wrappers removed — a literal string suffices.)
    expect_error(
        check_model_argument(models = "Wrong_model"),
        "Invalid model\\(s\\) specified: Wrong_model"
    )

    ## At most five models may be compared at once
    expect_error(
        check_model_argument(
            models = c("Site_Independent", "EVmutation", "ESM_1b",
                "ProtGPT2", "Progen2_Base", "CARP_640M")
        ),
        "Select up to 5 models for comparison"
    )
})
|
||
test_that("benchmark_models() works", {

    ## When no metric is given, the function should fall back to Spearman
    ## and message the user about it.
    res <- evaluate_promise(benchmark_models(model = "GEMME"))
    expect_identical(
        res$messages[1],
        "No metric specified. Using default Spearman correlation\n"
    )

    ## The default y-axis label reflects the Spearman fallback
    expect_identical(
        res$result$labels$y,
        "Spearman score"
    )

    ## At least one model must be supplied
    expect_error(
        benchmark_models(metric = "AUC"),
        "Select at least one model from `available_models\\(\\)`"
    )

    ## Spearman scores for GEMME are all non-negative
    expect_true(all(res$result$data$score >= 0))

    ## MCC ranges over [-1, 1], so some scores should be negative.
    ## Reuse res$result as the plot object instead of recomputing the
    ## same benchmark_models() call a second time.
    res <- evaluate_promise(benchmark_models(metric = "MCC", model = "GEMME"))
    # Range should be -0.019, 0.798
    object <- res$result

    expect_false(all(res$result$data$score >= 0))

    ## pivot_longer should have produced a tibble with one row per DMS
    ## study (217) and the expected long-format columns
    expect_true(tibble::is_tibble(res$result$data))

    expect_equal(
        res$result$data |> NROW(),
        217L
    )

    expect_identical(
        colnames(res$result$data),
        c("model", "score", "model_mean")
    )

    ## The built ggplot should position the single boxplot as expected
    expect_identical(
        ggplot2::ggplot_build(object)$data[[1]]$xmin |> unique(),
        1.12
    )

    expect_identical(
        ggplot2::ggplot_build(object)$data[[1]]$xmax |> unique(),
        1.72
    )
})