From 335fac0133f782b1bde295a2c56e0e2dca8e3e58 Mon Sep 17 00:00:00 2001 From: Chris Norman Date: Tue, 29 Jan 2019 21:44:18 -0500 Subject: [PATCH] Downsample reads with a ReservoirDownsampler in CNNScoreVariants. (#5622) * Downsample reads with ReservoirDownsampler in CNNScoreVariants. * updated test files and reset random number generator in tests --- .../hellbender/tools/walkers/vqsr/CNNScoreVariants.java | 6 +++++- .../walkers/vqsr/CNNScoreVariantsIntegrationTest.java | 7 +++++++ .../large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java index 025a42fd319..70c5046a5fe 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java @@ -18,6 +18,10 @@ import org.broadinstitute.hellbender.engine.*; import org.broadinstitute.barclay.argparser.*; import org.broadinstitute.hellbender.engine.filters.*; +import org.broadinstitute.hellbender.exceptions.GATKException; +import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsamplingIterator; +import org.broadinstitute.hellbender.utils.downsampling.ReservoirDownsampler; +import org.broadinstitute.hellbender.utils.haplotype.HaplotypeBAMWriter; import org.broadinstitute.hellbender.utils.io.IOUtils; import org.broadinstitute.hellbender.utils.io.Resource; import org.broadinstitute.barclay.help.DocumentedFeature; @@ -420,7 +424,7 @@ private void transferReadsToPythonViaFifo(final VariantContext variant, final Re } catch (UnsupportedEncodingException e) { throw new GATKException("Trying to make string from reference, but unsupported encoding UTF-8.", e); } - Iterator readIt = readsContext.iterator(); + Iterator readIt = new ReadsDownsamplingIterator(readsContext.iterator(), new ReservoirDownsampler(readLimit)); if (!readIt.hasNext()) { logger.warn("No reads at contig:" + variant.getContig() + " site:" + String.valueOf(variant.getStart())); } diff --git a/src/test/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariantsIntegrationTest.java b/src/test/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariantsIntegrationTest.java index 247ee190534..db2adce3489 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariantsIntegrationTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariantsIntegrationTest.java @@ -4,6 +4,7 @@ import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; import org.broadinstitute.hellbender.testutils.ArgumentsBuilder; import org.broadinstitute.hellbender.testutils.IntegrationTestSpec; +import org.broadinstitute.hellbender.utils.Utils; import org.testng.annotations.Test; import java.io.IOException; @@ -155,6 +156,9 @@ public void testOnContigEdge() throws IOException { */ @Test(groups = {"python"}) public void testInference2dResourceModel() throws IOException { + // We reset the random number generator at the beginning of each test so that the random down-sampling of reads + // by the reservoir down-sampler does not cause slightly different scores. + Utils.resetRandomGenerator(); TensorType tt = TensorType.read_tensor; final ArgumentsBuilder argsBuilder = new ArgumentsBuilder(); argsBuilder.addArgument(StandardArgumentDefinitions.VARIANT_LONG_NAME, inputVCF) @@ -177,6 +181,7 @@ public void testInference2dResourceModel() throws IOException { */ @Test(groups = {"python"}) public void testInferenceArchitecture2d() throws IOException { + Utils.resetRandomGenerator(); final boolean newExpectations = false; TensorType tt = TensorType.read_tensor; final ArgumentsBuilder argsBuilder = new ArgumentsBuilder(); @@ -202,6 +207,7 @@ public void testInferenceArchitecture2d() throws IOException { @Test(groups = {"python"}) public void testInferenceWeights2d() throws IOException { + Utils.resetRandomGenerator(); TensorType tt = TensorType.read_tensor; final ArgumentsBuilder argsBuilder = new ArgumentsBuilder(); argsBuilder.addArgument(StandardArgumentDefinitions.VARIANT_LONG_NAME, inputVCF) @@ -222,6 +228,7 @@ public void testInferenceWeights2d() throws IOException { @Test(groups = {"python"}) public void testInferenceArchitectureAndWeights2d() throws IOException { + Utils.resetRandomGenerator(); TensorType tt = TensorType.read_tensor; final ArgumentsBuilder argsBuilder = new ArgumentsBuilder(); argsBuilder.addArgument(StandardArgumentDefinitions.VARIANT_LONG_NAME, inputVCF) diff --git a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf index 005e0e37986..5c86a8f518c 100644 --- a/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf +++ b/src/test/resources/large/VQSR/expected/cnn_2d_chr20_subset_expected.vcf @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e12a5ad27856ea6b7bc23c121fde471ac8c9a9fc2fad7a6f73c1bd95feb99d4f +oid sha256:5b635baa9fa1f2af7d5ec5a24773a834acac924b53db715ffab11eabb9bee087 size 149982