Skip to content

Commit

Permalink
add in feature extract code. resulting vcf still has bugs (#6947)
Browse files Browse the repository at this point in the history
* add in feature extract code. resulting vcf still has bugs
  • Loading branch information
ahaessly authored and Marianie-Simeon committed Feb 16, 2021
1 parent 1631bd5 commit 165ca9f
Show file tree
Hide file tree
Showing 19 changed files with 806 additions and 274 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,27 @@ public static String getSampleId(final String sampleName, final File sampleMap)
}
br.close();
if (sampleId == null) {
// sampleName not found
throw new UserException("Sample " + sampleName + " could not be found in sample mapping file");
// sampleId not found
// NOTE(review): BUG — sampleId is guaranteed null inside this branch (that is the
// condition we just tested), so the message below always prints "Sample null ...".
// The message should interpolate sampleName, as the removed line above did.
throw new UserException("Sample " + sampleId + " could not be found in sample mapping file");
}
} catch (final IOException ioe) { // FileNotFoundException e,
// NOTE(review): consider passing ioe as the cause so the original stack trace is preserved.
throw new UserException("Could not find sample mapping file");
}
return sampleId;
}

// public static Path createSampleDirectory(Path parentDirectory, int sampleDirectoryNumber) {
// // If this sample set directory doesn't exist yet -- create it
// final String sampleDirectoryName = String.valueOf(sampleDirectoryNumber);
// final Path sampleDirectoryPath = parentDirectory.resolve(sampleDirectoryName);
// final File sampleDirectory = new File(sampleDirectoryPath.toString());
// if (!sampleDirectory.exists()) {
// sampleDirectory.mkdir();
// }
// return sampleDirectoryPath;
// }

// To determine which directory (and ultimately table) the sample's data will go into
// To determine which table the sample's data will go into
// Since tables have a limited number of samples (default is 4k)
public static int getTableNumber(String sampleId, int sampleMod) { // this is based on sample id
// sample ids 1-4000 will go in directory 001
int sampleIdInt = Integer.valueOf(sampleId); // TODO--should sampleId just get refactored as a long?
long sampleIdInt = Long.valueOf(sampleId);
return getTableNumber(sampleIdInt, sampleMod);
}

public static int getTableNumber(int sampleId, int sampleMod) { // this is based on sample id
public static int getTableNumber(long sampleId, int sampleMod) { // this is based on sample id
// sample ids 1-4000 will go in directory 001
int sampleIdInt = Integer.valueOf(sampleId); // TODO--should sampleId just get refactored as a long?
// subtract 1 from the sample id to make it 1-index (or do we want to 0-index?) and add 1 to the dir
int directoryNumber = Math.floorDiv((sampleIdInt - 1), sampleMod) + 1; // TODO omg write some unit tests
// NOTE(review): new Long(...) is deprecated and .intValue() silently truncates on overflow;
// prefer Math.toIntExact(Math.floorDiv(sampleId - 1, sampleMod) + 1), which fails loudly instead.
int directoryNumber = new Long(Math.floorDiv((sampleId - 1), sampleMod) + 1).intValue(); // TODO omg write some unit tests
return directoryNumber;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package org.broadinstitute.hellbender.tools.variantdb;

import com.google.cloud.bigquery.FieldValueList;
import com.google.cloud.bigquery.TableResult;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.utils.bigquery.BigQueryUtils;
import org.broadinstitute.hellbender.utils.bigquery.TableReference;

public class SampleList {
    static final Logger logger = LogManager.getLogger(SampleList.class);

    // Bidirectional sample_id <-> sample_name lookup, populated by one of the
    // initializeMaps overloads. Allocated at declaration so BOTH overloads can
    // rely on non-null maps: previously only the table-backed overload created
    // them, so the File-backed constructor path threw a NullPointerException.
    private final Map<Long, String> sampleIdMap = new HashMap<>();
    private final Map<String, Long> sampleNameMap = new HashMap<>();

    /**
     * Builds the sample list either from a BigQuery sample table or from a local
     * CSV file of "sample_id,sample_name" lines. Exactly one source must be
     * provided; the table takes precedence when both are non-null.
     *
     * @param sampleTableName fully-qualified BigQuery sample table, or null
     * @param sampleFile local CSV file of id,name pairs, or null
     * @param printDebugInformation if true, log the raw query results
     * @throws IllegalArgumentException if neither source is provided
     */
    public SampleList(String sampleTableName, File sampleFile, boolean printDebugInformation) {
        if (sampleTableName != null) {
            initializeMaps(new TableReference(sampleTableName, SchemaUtils.SAMPLE_FIELDS), printDebugInformation);
        } else if (sampleFile != null) {
            initializeMaps(sampleFile);
        } else {
            throw new IllegalArgumentException("--cohort-sample-names or --cohort-sample-table must be provided.");
        }
    }

    /** Number of samples loaded. */
    public int size() {
        return sampleIdMap.size();
    }

    /** All sample names, in no particular order. */
    public Collection<String> getSampleNames() {
        return sampleIdMap.values();
    }

    /** Returns the sample name for the given id, or null if the id is unknown. */
    public String getSampleName(long id) {
        return sampleIdMap.get(id);
    }

    /**
     * Returns the sample id for the given name.
     *
     * @throws IllegalArgumentException if the name is unknown — an explicit check,
     *         because unboxing a missing map entry would otherwise surface as a
     *         bare NullPointerException with no context.
     */
    public long getSampleId(String name) {
        Long id = sampleNameMap.get(name);
        if (id == null) {
            throw new IllegalArgumentException("Sample " + name + " could not be found in sample list");
        }
        return id;
    }

    /** The id -> name map backing this list (live view, not a copy). */
    public Map<Long, String> getMap() {
        return sampleIdMap;
    }

    /** Loads the maps from a BigQuery sample table (columns: sample_id, sample_name). */
    protected void initializeMaps(TableReference sampleTable, boolean printDebugInformation) {
        TableResult queryResults = querySampleTable(sampleTable.getFQTableName(), "", printDebugInformation);

        // Add our samples to our maps:
        for (final FieldValueList row : queryResults.iterateAll()) {
            long id = row.get(0).getLongValue();
            String name = row.get(1).getStringValue();
            sampleIdMap.put(id, name);
            sampleNameMap.put(name, id);
        }
    }

    /** Loads the maps from a local CSV file whose lines are "sample_id,sample_name". */
    protected void initializeMaps(File cohortSampleFile) {
        try {
            Files.readAllLines(cohortSampleFile.toPath(), StandardCharsets.US_ASCII).stream()
                    .map(s -> s.split(","))
                    .forEach(tokens -> {
                        long id = Long.parseLong(tokens[0]);
                        String name = tokens[1];
                        sampleIdMap.put(id, name);
                        sampleNameMap.put(name, id);
                    });
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not parse --cohort-sample-file", e);
        }
    }

    /** Runs "SELECT sample_id, sample_name FROM <table> <whereClause>" against BigQuery. */
    private TableResult querySampleTable(String fqSampleTableName, String whereClause, boolean printDebugInformation) {
        // Get the query string:
        final String sampleListQueryString =
                "SELECT " + SchemaUtils.SAMPLE_ID_FIELD_NAME + ", " + SchemaUtils.SAMPLE_NAME_FIELD_NAME +
                        " FROM `" + fqSampleTableName + "`" + whereClause;

        // Execute the query:
        final TableResult result = BigQueryUtils.executeQuery(sampleListQueryString);

        // Show our pretty results:
        if (printDebugInformation) {
            logger.info("Sample names returned:");
            final String prettyQueryResults = BigQueryUtils.getResultDataPrettyString(result);
            logger.info("\n" + prettyQueryResults);
        }

        return result;
    }
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.broadinstitute.hellbender.tools.variantdb;

import com.google.common.collect.ImmutableSet;
import org.apache.avro.generic.GenericRecord;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class SchemaUtils {
Expand All @@ -24,8 +26,13 @@ public class SchemaUtils {
public static final String ALT_ALLELE_FIELD_NAME = "alt";

public static final String GENOTYPE_FIELD_PREFIX = "call_";
public static final String AS_FIELD_PREFIX = "AS_";
public static final String MULTIVALUE_FIELD_DELIMITER = ",";

public static final String CALL_GT = GENOTYPE_FIELD_PREFIX + "GT";
public static final String CALL_GQ = GENOTYPE_FIELD_PREFIX + "GQ";
public static final String CALL_RGQ = GENOTYPE_FIELD_PREFIX + "RGQ";

public static final ImmutableSet<String> REQUIRED_FIELDS = ImmutableSet.of(
SAMPLE_NAME_FIELD_NAME,
// SAMPLE_FIELD_NAME,
Expand All @@ -34,18 +41,51 @@ public class SchemaUtils {
ALT_ALLELE_FIELD_NAME
);

public static final List<String> COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, "call_GT", "call_GQ", "call_RGQ");
public static final List<String> ARRAY_COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, "call_GT", "call_GQ");
public static final List<String> COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, CALL_GT, CALL_GQ, CALL_RGQ);
public static final List<String> ARRAY_COHORT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_NAME_FIELD_NAME, STATE_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, CALL_GT, CALL_GQ);

public static final List<String> RAW_ARRAY_COHORT_FIELDS_COMPRESSED =
Arrays.asList(BASIC_ARRAY_DATA_FIELD_NAME, RAW_ARRAY_DATA_FIELD_NAME);

public static final List<String> RAW_ARRAY_COHORT_FIELDS_UNCOMPRESSED =
Arrays.asList(SAMPLE_ID_FIELD_NAME, "probe_id", "GT_encoded","NORMX","NORMY","BAF","LRR");

public static final List<String> SAMPLE_FIELDS = Arrays.asList(SAMPLE_NAME_FIELD_NAME);
// exomes & genomes
public static final String RAW_QUAL = "RAW_QUAL";
public static final String RAW_MQ = "RAW_MQ";
public static final String AS_RAW_MQ = AS_FIELD_PREFIX + RAW_MQ;
public static final String AS_MQRankSum = AS_FIELD_PREFIX + "MQRankSum";
public static final String AS_RAW_MQRankSum = AS_FIELD_PREFIX + "RAW_MQRankSum";
public static final String AS_QUALapprox = AS_FIELD_PREFIX + "QUALapprox";
public static final String AS_RAW_ReadPosRankSum = AS_FIELD_PREFIX + "RAW_ReadPosRankSum";
public static final String AS_ReadPosRankSum = AS_FIELD_PREFIX + "ReadPosRankSum";
public static final String AS_SB_TABLE = AS_FIELD_PREFIX + "SB_TABLE";
public static final String AS_VarDP = AS_FIELD_PREFIX + "VarDP";
public static final String CALL_AD = GENOTYPE_FIELD_PREFIX + "AD";
public static final String RAW_AD = "RAW_AD";
public static final String CALL_PGT = GENOTYPE_FIELD_PREFIX + "PGT";
public static final String CALL_PID = GENOTYPE_FIELD_PREFIX + "PID";
public static final String CALL_PL = GENOTYPE_FIELD_PREFIX + "PL";

public static final List<String> SAMPLE_FIELDS = Arrays.asList(SchemaUtils.SAMPLE_NAME_FIELD_NAME, SchemaUtils.SAMPLE_ID_FIELD_NAME);
public static final List<String> YNG_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME);

public static final List<String> PET_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_ID_FIELD_NAME, STATE_FIELD_NAME);
public static final List<String> VET_FIELDS = Arrays.asList(SAMPLE_ID_FIELD_NAME, LOCATION_FIELD_NAME, REF_ALLELE_FIELD_NAME, ALT_ALLELE_FIELD_NAME, AS_RAW_MQ,
AS_RAW_MQRankSum, AS_QUALapprox, AS_RAW_ReadPosRankSum, AS_SB_TABLE, AS_VarDP, CALL_GT, CALL_AD, CALL_GQ, CALL_PGT, CALL_PID, CALL_PL);
public static final List<String> ALT_ALLELE_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, SAMPLE_ID_FIELD_NAME, REF_ALLELE_FIELD_NAME, "allele", ALT_ALLELE_FIELD_NAME, "allele_pos", CALL_GT, AS_RAW_MQ, RAW_MQ, AS_RAW_MQRankSum, "raw_mqranksum_x_10", AS_QUALapprox, "qual", AS_RAW_ReadPosRankSum, "raw_readposranksum_x_10", AS_SB_TABLE, "SB_REF_PLUS","SB_REF_MINUS","SB_ALT_PLUS","SB_ALT_MINUS", CALL_AD, "ref_ad", "ad");
public static final List<String> FEATURE_EXTRACT_FIELDS = Arrays.asList(LOCATION_FIELD_NAME, REF_ALLELE_FIELD_NAME, "allele", RAW_QUAL, "ref_ad", AS_MQRankSum, "AS_MQRankSum_ft", AS_ReadPosRankSum, "AS_ReadPosRankSum_ft", RAW_MQ, RAW_AD, "RAW_AD_GT_1", "SB_REF_PLUS","SB_REF_MINUS","SB_ALT_PLUS","SB_ALT_MINUS");


// private String getPositionTableForContig(final String contig ) {
// return contigToPositionTableMap.get(contig);
// }
//
// private String getVariantTableForContig(final String contig ) {
// return contigToVariantTableMap.get(contig);
// }
//

// private void validateSchema(final Set<String> columnNames) {
// for ( final String requiredField : REQUIRED_FIELDS ) {
// if ( ! columnNames.contains(requiredField) ) {
Expand All @@ -71,4 +111,14 @@ public static String decodeContig(long location) {
public static int decodePosition(long location) {
return (int)(location % chromAdjustment);
}

// Orders Avro rows by their encoded location value, ascending.
// Comparator.comparingLong delegates to Long.compare, matching the
// hand-written anonymous-class version exactly.
public final static Comparator<GenericRecord> LOCATION_COMPARATOR =
        Comparator.comparingLong(record -> (Long) record.get(SchemaUtils.LOCATION_FIELD_NAME));

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.variantdb.IngestConstants;
import org.broadinstitute.hellbender.tools.variantdb.SchemaUtils;
import org.broadinstitute.hellbender.tools.variantdb.arrays.tables.GenotypeCountsSchema;
import org.broadinstitute.hellbender.utils.GenotypeCounts;
import org.broadinstitute.hellbender.utils.bigquery.StorageAPIAvroReader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
import org.broadinstitute.hellbender.cmdline.programgroups.ShortVariantDiscoveryProgramGroup;
import org.broadinstitute.hellbender.engine.GATKTool;
import org.broadinstitute.hellbender.tools.variantdb.CommonCode;
import org.broadinstitute.hellbender.tools.variantdb.SampleList;
import org.broadinstitute.hellbender.tools.variantdb.arrays.tables.ProbeInfo;
import org.broadinstitute.hellbender.tools.variantdb.arrays.tables.ProbeQcMetrics;
import org.broadinstitute.hellbender.tools.variantdb.arrays.tables.SampleList;
import org.broadinstitute.hellbender.tools.variantdb.nextgen.ExtractCohort;
import org.broadinstitute.hellbender.tools.walkers.annotator.Annotation;
import org.broadinstitute.hellbender.tools.walkers.annotator.StandardAnnotation;
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.tools.walkers.annotator.allelespecific.AS_StandardAnnotation;
import org.broadinstitute.hellbender.utils.bigquery.TableReference;
import org.broadinstitute.hellbender.utils.io.IOUtils;

import java.util.*;
Expand Down Expand Up @@ -199,16 +198,9 @@ protected void onStartup() {

vcfWriter = createVCFWriter(IOUtils.getPath(outputVcfPathString));

Map<Integer, String> sampleIdMap;
if (sampleTableName != null) {
sampleIdMap = SampleList.getSampleIdMap(new TableReference(sampleTableName, SampleList.SAMPLE_LIST_FIELDS), printDebugInformation);
} else if (cohortSampleFile != null) {
sampleIdMap = SampleList.getSampleIdMap(cohortSampleFile);
} else {
throw new IllegalArgumentException("--cohort-sample-names or --cohort-sample-table must be provided.");
}

VCFHeader header = CommonCode.generateRawArrayVcfHeader(new HashSet<>(sampleIdMap.values()), reference.getSequenceDictionary());
SampleList sampleIdMap = new SampleList(sampleTableName, cohortSampleFile, printDebugInformation);
// Map<Integer, String> sampleIdMap;
VCFHeader header = CommonCode.generateRawArrayVcfHeader(new HashSet<>(sampleIdMap.getSampleNames()), reference.getSequenceDictionary());

Map<Long, ProbeInfo> probeIdMap;
if (probeCsvExportFile == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.engine.ProgressMeter;
import org.broadinstitute.hellbender.engine.ReferenceDataSource;
import org.broadinstitute.hellbender.tools.variantdb.SampleList;
import org.broadinstitute.hellbender.tools.variantdb.arrays.BasicArrayData.ArrayGenotype;
import org.broadinstitute.hellbender.tools.variantdb.CommonCode;
import org.broadinstitute.hellbender.tools.variantdb.SchemaUtils;
Expand All @@ -17,6 +18,7 @@
import org.broadinstitute.hellbender.tools.walkers.annotator.VariantAnnotatorEngine;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.bigquery.*;
import org.broadinstitute.hellbender.utils.localsort.AvroSortingCollection;
import org.broadinstitute.hellbender.utils.localsort.SortingCollection;

import java.text.DecimalFormat;
Expand Down Expand Up @@ -46,7 +48,7 @@ public class ArrayExtractCohortEngine {
private final String readProjectId;

/** List of sample names seen in the variant data from BigQuery. */
private final Map<Integer, String> sampleIdMap;
private final SampleList sampleIdMap;
private final Set<String> sampleNames;

private final Map<Long, ProbeInfo> probeIdMap;
Expand All @@ -69,7 +71,7 @@ public ArrayExtractCohortEngine(final String readProjectId,
final VCFHeader vcfHeader,
final VariantAnnotatorEngine annotationEngine,
final ReferenceDataSource refSource,
final Map<Integer, String> sampleIdMap,
final SampleList sampleIdMap,
final Map<Long, ProbeInfo> probeIdMap,
final Map<Long, ProbeQcMetrics> probeQcMetricsMap,
final String cohortTableName,
Expand All @@ -95,13 +97,13 @@ public ArrayExtractCohortEngine(final String readProjectId,
this.vcfWriter = vcfWriter;
this.refSource = refSource;
this.sampleIdMap = sampleIdMap;
this.sampleNames = new HashSet<>(sampleIdMap.values());
this.sampleNames = new HashSet<>(sampleIdMap.getSampleNames());
this.gtDataOnly = gtDataOnly;

this.probeIdMap = probeIdMap;
this.probeQcMetricsMap = probeQcMetricsMap;

this.cohortTableRef = new TableReference(cohortTableName, useCompressedData? SchemaUtils.RAW_ARRAY_COHORT_FIELDS_COMPRESSED:SchemaUtils.RAW_ARRAY_COHORT_FIELDS_UNCOMPRESSED);
this.cohortTableRef = new TableReference(cohortTableName, SchemaUtils.RAW_ARRAY_COHORT_FIELDS_UNCOMPRESSED);
this.minProbeId = minProbeId;
this.maxProbeId = maxProbeId;
// this.useCompressedData = useCompressedData;
Expand Down Expand Up @@ -146,7 +148,7 @@ private void createVariantsFromUngroupedTableResult(final GATKAvroReader avroRea

Comparator<GenericRecord> comparator = UNCOMPRESSED_PROBE_ID_COMPARATOR;

SortingCollection<GenericRecord> sortingCollection = getAvroProbeIdSortingCollection(schema, localSortMaxRecordsInRam, comparator);
SortingCollection<GenericRecord> sortingCollection = AvroSortingCollection.getAvroSortingCollection(schema, localSortMaxRecordsInRam, comparator);
for ( final GenericRecord queryRow : avroReader ) {
sortingCollection.add(queryRow);
}
Expand Down Expand Up @@ -213,7 +215,7 @@ private void processSampleRecordsForLocation(final long probeId, final Iterable<
// }

// TODO: handle missing values
String sampleName = sampleIdMap.get((int) sampleId);
String sampleName = sampleIdMap.getSampleName((int) sampleId);
currentPositionSamplesSeen.add(sampleName);

++numRecordsAtPosition;
Expand Down
Loading

0 comments on commit 165ca9f

Please sign in to comment.