From 3b66062cd7506ad7cf821a0aa0bb262aa8f00f6b Mon Sep 17 00:00:00 2001 From: Vanessa Botha <31652336+VanessaBotha@users.noreply.github.com> Date: Fri, 1 Sep 2023 12:31:50 +0200 Subject: [PATCH] some fixes & add tcga study codes --- .../_database_manager.py} | 322 ++++++++---------- .../tcga_files_to_study_types.2023-08-20.json | 0 .../database_label_data}/tcga_study_codes.txt | 0 .../generate_tcga_database.py | 183 ++++++++++ .../generate_tcga_labels.py | 2 +- 5 files changed, 317 insertions(+), 190 deletions(-) rename tools/{generate_tcga_split.py => generate_database/_database_manager.py} (51%) rename tools/{ => generate_database/database_label_data}/tcga_files_to_study_types.2023-08-20.json (100%) rename tools/{ => generate_database/database_label_data}/tcga_study_codes.txt (100%) create mode 100644 tools/generate_database/generate_tcga_database.py rename tools/{ => generate_database}/generate_tcga_labels.py (96%) diff --git a/tools/generate_tcga_split.py b/tools/generate_database/_database_manager.py similarity index 51% rename from tools/generate_tcga_split.py rename to tools/generate_database/_database_manager.py index 87b00e4..6b03487 100644 --- a/tools/generate_tcga_split.py +++ b/tools/generate_database/_database_manager.py @@ -1,82 +1,105 @@ # encoding: utf-8 -import json +from __future__ import annotations + import random -import sqlite3 -from pathlib import Path from typing import Generator +import numpy as np +import sqlite3 from pydantic import BaseModel -from tqdm import tqdm -BATCH_INSERT_SIZE = 100 +def create_dataset_split( + patient_ids: list[int], + split_percentages: tuple[float, float, float], +) -> tuple[list[int], list[int], list[int]]: + """ + Splits data randomly into train, validate and test set using the given ratios. + + Parameters + ---------- + patient_ids: list[int] + List of patient identifiers to be split. + split_percentages: tuple[int, int, int] + Tuple of ratios (train_ratio, validate_ratio, test_ratio) used to split the data. + + Returns + ------- + tuple[list[int], list[int], list[int]] + List of patient identifiers per data subset; train, validate and test + """ + + if len(split_percentages) != 3: + raise ValueError("Split percentages must contain three values: (train, validate, test)") + + if np.sum(split_percentages) != 100: + raise ValueError("Split percentages must sum to 100.") + + random.shuffle(patient_ids) + + n_patients = len(patient_ids) + fit_percentage, validate_percentage, test_percentage = split_percentages + + train_end_idx = int((n_patients * fit_percentage) / 100) + val_end_idx = train_end_idx + int((n_patients * validate_percentage) / 100) + test_end_idx = n_patients + + train_split = patient_ids[0:train_end_idx] + validate_split = patient_ids[train_end_idx:val_end_idx] + test_split = patient_ids[val_end_idx:test_end_idx] -def find_number_of_tiles(folder_path: Path) -> int: - with open(folder_path / "meta_data_tiles.json") as json_file: - meta_data_tiles = json.load(json_file) - return meta_data_tiles["num_regions_masked"] + return train_split, validate_split, test_split class PatientInfo(BaseModel): - patient_id: int - patient_name: str + id: int + patient_code: str class FolderInfo(BaseModel): path: str num_files: int - patient_id: str - + patient_code: str class SplitDefinition(BaseModel): version: str description: str - class SplitInfo(BaseModel): - folder_id: int + patient_id: int split_definition_id: int category: str # This will hold either "train", "test", or "validate" - -class PatientLabelInfo(BaseModel): - patient_id: str - label_slug: str - label_value: str - - -class PatientLabelAssignment(BaseModel): - patient_id: str - label_slug: str - label_value: str - - class LabelCategoryInfo(BaseModel): description: str slug: str values: list[str] - class LabelValueInfo(BaseModel): - category_slug: str + label_category_slug: str value: str +class PatientLabelAssignment(BaseModel): + patient_code: str + label_category_slug: str + label_value: str -def get_patient_id(path: Path) -> str: - return path.name[:12] +class PatientLabelsInfo(BaseModel): + patient_id: int + label_value_id: int class DatabaseManager: - def __init__(self, db_name="tcga_tiled.sqlite"): + def __init__(self, db_name: str = "tcga_tiled.sqlite"): self.db_name = db_name - def _connect(self): + def _connect(self) -> sqlite3.Connection: connection = sqlite3.connect(self.db_name) # Enabling foreign key constraints connection.execute("PRAGMA foreign_keys = ON;") return connection - def create_tables(self): + def create_tables(self) -> None: with self._connect() as connection: cursor = connection.cursor() @@ -85,7 +108,7 @@ def create_tables(self): """ CREATE TABLE IF NOT EXISTS patients ( id INTEGER PRIMARY KEY, - patient_id TEXT UNIQUE NOT NULL + patient_code TEXT UNIQUE NOT NULL ); """ ) @@ -119,15 +142,16 @@ def create_tables(self): """ CREATE TABLE IF NOT EXISTS split_info ( id INTEGER PRIMARY KEY, - folder_id INTEGER NOT NULL, + patient_id INTEGER NOT NULL, split_definition_id INTEGER NOT NULL, category TEXT NOT NULL, -- This will hold either "train", "test", or "validate" - FOREIGN KEY (folder_id) REFERENCES folder_info(id), + FOREIGN KEY (patient_id) REFERENCES patients(id), FOREIGN KEY (split_definition_id) REFERENCES split_definitions(id) ); """ ) + # label_categories table cursor.execute( """ CREATE TABLE IF NOT EXISTS label_categories ( @@ -150,7 +174,7 @@ def create_tables(self): """ ) - # label mapping table + # patient_labels table, mapping patient to label cursor.execute( """ CREATE TABLE IF NOT EXISTS patient_labels ( @@ -163,29 +187,36 @@ def create_tables(self): """ ) - def insert_folder_info(self, folder_infos: list[FolderInfo]): + def insert_folder_info(self, folder_infos: list[FolderInfo]) -> None: with self._connect() as connection: cursor = connection.cursor() for info in folder_infos: - patient_id = get_patient_id(Path(info.path)) + patient_code = info.patient_code - # Check if patient_id exists - cursor.execute("SELECT id FROM patients WHERE patient_id = ?", (patient_id,)) + # Try to get the patient id from the database that matches the patient code. + cursor.execute("SELECT id FROM patients WHERE patient_code = ?", (patient_code,)) row = cursor.fetchone() - if row: - patient_db_id = row[0] + # Insert the patient code if it does not already exist in the database. + if row is not None: + patient_id = row[0] else: - cursor.execute("INSERT INTO patients (patient_id) VALUES (?)", (patient_id,)) - patient_db_id = cursor.lastrowid + cursor.execute("INSERT INTO patients (patient_code) VALUES (?)", (patient_code,)) + patient_id = cursor.lastrowid cursor.execute( "INSERT INTO folder_info (path, num_files, patient_id) VALUES (?, ?, ?)", - (info.path, info.num_files, patient_db_id), + (info.path, info.num_files, patient_id), ) - def create_random_split(self, version, split_ratios, description="", patients_generator=None): + def create_random_split( + self, + version: str, + split_ratios: tuple[int, int, int], + description: str = "", + patients_generator: Generator | None = None, + ) -> None: with self._connect() as connection: cursor = connection.cursor() @@ -197,52 +228,46 @@ def create_random_split(self, version, split_ratios, description="", patients_ge print(f"Split version '{version}' already exists. Skipping...") return + definition = SplitDefinition(version=version, description=description) + cursor.execute( - "INSERT INTO split_definitions (version, description) VALUES (?, ?)", (version, description) + "INSERT INTO split_definitions (version, description) VALUES (?, ?)", + (definition.version, definition.description), ) split_definition_id = cursor.lastrowid # Fetch patients either from the generator or all distinct patient IDs if the generator is not provided if patients_generator: - patient_ids = [patient_info.patient_id for patient_info in patients_generator] + patient_ids = [patient_info.id for patient_info in patients_generator] else: cursor.execute("SELECT DISTINCT patient_id FROM folder_info") patient_ids = [row[0] for row in cursor.fetchall()] - random.shuffle(patient_ids) - - train_end = int(len(patient_ids) * split_ratios[0]) - test_end = train_end + int(len(patient_ids) * split_ratios[1]) - - train_patient_ids = patient_ids[:train_end] - test_patient_ids = patient_ids[train_end:test_end] - validate_patient_ids = patient_ids[test_end:] - - train_ids = self._get_folder_ids_for_patients(cursor, train_patient_ids) - test_ids = self._get_folder_ids_for_patients(cursor, test_patient_ids) - validate_ids = self._get_folder_ids_for_patients(cursor, validate_patient_ids) + train_patient_ids, validate_patient_ids, test_patient_ids = create_dataset_split( + patient_ids=patient_ids, split_percentages=split_ratios + ) # Insert into the split_info table - self._insert_into_split(cursor, train_ids, "train", split_definition_id) - self._insert_into_split(cursor, test_ids, "test", split_definition_id) - self._insert_into_split(cursor, validate_ids, "validate", split_definition_id) + self._insert_into_split(cursor, train_patient_ids, "train", split_definition_id) + self._insert_into_split(cursor, validate_patient_ids, "validate", split_definition_id) + self._insert_into_split(cursor, test_patient_ids, "test", split_definition_id) @staticmethod - def _get_folder_ids_for_patients(cursor, patient_ids): + def _get_folder_ids_for_patients(cursor: sqlite3.Cursor, patient_ids: list[int]): placeholder = ",".join("?" * len(patient_ids)) cursor.execute(f"SELECT id FROM folder_info WHERE patient_id IN ({placeholder})", tuple(patient_ids)) return [row[0] for row in cursor.fetchall()] @staticmethod - def _insert_into_split(cursor, ids, category, split_definition_id): - for folder_id in ids: - data = SplitInfo(folder_id=folder_id, split_definition_id=split_definition_id, category=category) + def _insert_into_split(cursor: sqlite3.Cursor, patient_ids: list[int], category: str, split_definition_id: int) -> None: + for patient_id in patient_ids: + info = SplitInfo(patient_id=patient_id, split_definition_id=split_definition_id, category=category) cursor.execute( - "INSERT INTO split_info (folder_id, category, split_definition_id) VALUES (?, ?, ?)", - (data.folder_id, data.category, data.split_definition_id), + "INSERT INTO split_info (patient_id, split_definition_id, category) VALUES (?, ?, ?)", + (info.patient_id, info.split_definition_id, info.category), ) - def insert_label_category(self, category_info: LabelCategoryInfo): + def insert_label_category(self, category_info: LabelCategoryInfo) -> None: with self._connect() as connection: cursor = connection.cursor() cursor.execute( @@ -250,15 +275,15 @@ def insert_label_category(self, category_info: LabelCategoryInfo): (category_info.description, category_info.slug), ) - def insert_label_value(self, value_info: LabelValueInfo): + def insert_label_value(self, value_info: LabelValueInfo) -> None: with self._connect() as connection: cursor = connection.cursor() # Fetch the category_id based on the slug - cursor.execute("SELECT id FROM label_categories WHERE slug = ?", (value_info.category_slug,)) + cursor.execute("SELECT id FROM label_categories WHERE slug = ?", (value_info.label_category_slug,)) row = cursor.fetchone() if not row: - print(f"No category found for slug '{value_info.category_slug}'. Skipping...") + print(f"No category found for slug '{value_info.label_category_slug}'. Skipping...") return category_id = row[0] @@ -266,11 +291,12 @@ def insert_label_value(self, value_info: LabelValueInfo): "INSERT INTO label_values (category_id, value) VALUES (?, ?)", (category_id, value_info.value) ) - def assign_label_to_patient(self, patient_label_info: PatientLabelInfo): + def assign_label_to_patient(self, patient_label_assignment: PatientLabelAssignment, strict: bool = False) -> None: """ - Assigns a label to a patient using a PatientLabelInfo model. + Assigns a label to a patient using a PatientLabelAssignment model. - :param patient_label_info: The info about the patient and label to assign. + :param patient_label_assignment: The info about the patient and label to assign. + :param strict: if set to false, the assigment will be skipped if the patient is not in the database. """ with self._connect() as connection: cursor = connection.cursor() @@ -283,15 +309,31 @@ def assign_label_to_patient(self, patient_label_info: PatientLabelInfo): JOIN label_categories lc ON lv.category_id = lc.id WHERE lc.slug = ? AND lv.value = ? """, - (patient_label_info.label_slug, patient_label_info.label_value), + (patient_label_assignment.label_category_slug, patient_label_assignment.label_value), ) - label_value_id = cursor.fetchone() - - if not label_value_id: + _label_value_id = cursor.fetchone() + if _label_value_id is None: raise ValueError( - f"Label '{patient_label_info.label_value}' for slug '{patient_label_info.label_slug}' not found." + f"Label '{patient_label_assignment.label_value}' for slug '{patient_label_assignment.label_category_slug}' not found." ) - label_value_id = label_value_id[0] + + label_value_id = _label_value_id[0] + + # Fetch patient_id using patient_code + cursor.execute("SELECT id FROM patients WHERE patient_code = ?", (patient_label_assignment.patient_code,)) + + _patient_id = cursor.fetchone() + if _patient_id is None: + _message = f"Patient code '{patient_label_assignment.patient_code}' not found in database." + if strict: + raise ValueError(_message) + else: + print(f"{_message} Skipping...") + return + + patient_id = _patient_id[0] + + patient_labels_info = PatientLabelsInfo(patient_id=patient_id, label_value_id=label_value_id) # Insert the relation cursor.execute( @@ -299,26 +341,26 @@ def assign_label_to_patient(self, patient_label_info: PatientLabelInfo): INSERT OR IGNORE INTO patient_labels (patient_id, label_value_id) VALUES (?, ?) """, - (patient_label_info.patient_id, label_value_id), + (patient_labels_info.patient_id, patient_labels_info.label_value_id), ) - def get_patients_by_label_category(self, label_category: LabelCategoryInfo) -> Generator[PatientInfo, None, None]: + def get_patients_by_label_category(self, label_category_slug: str) -> Generator[PatientInfo, None, None]: with self._connect() as connection: cursor = connection.cursor() # First, we need to fetch the ID for the given label category slug - cursor.execute("SELECT id FROM label_categories WHERE slug = ?", (label_category.slug,)) + cursor.execute("SELECT id FROM label_categories WHERE slug = ?", (label_category_slug,)) category_id = cursor.fetchone() if not category_id: - raise ValueError(f"No category found for slug '{label_category.slug}'.") + raise ValueError(f"No category found for slug '{label_category_slug}'.") category_id = category_id[0] # Next, fetch patient IDs that have a label associated with this category cursor.execute( """ - SELECT p.id, p.patient_id + SELECT p.id, p.patient_code FROM patients p JOIN patient_labels pl ON p.id = pl.patient_id JOIN label_values lv ON pl.label_value_id = lv.id @@ -329,102 +371,4 @@ def get_patients_by_label_category(self, label_category: LabelCategoryInfo) -> G # Fetching patients and yielding as generator for row in cursor.fetchall(): - yield PatientInfo(id=row[0], patient_id=row[1]) - - -def populate_with_tcga_tiled(db): - original_path = Path("/projects/tcga_tiled/v1/data") - - all_tcgas = original_path.glob("*/*") - infos_to_insert = [] - - for idx, folder_path in tqdm(enumerate(all_tcgas)): - num_files = find_number_of_tiles(folder_path) - patient_id = get_patient_id(folder_path) - info = FolderInfo(path=str(folder_path.relative_to(original_path)), num_files=num_files, patient_id=patient_id) - infos_to_insert.append(info) - - if len(infos_to_insert) >= BATCH_INSERT_SIZE: - db.insert_folder_info(infos_to_insert) - infos_to_insert = [] - - # Easy for debugging. - # if idx > 100: - # break - - if infos_to_insert: - db.insert_folder_info(infos_to_insert) - - -def populate_label_categories_and_values_with_dummies(db: DatabaseManager): - # Defining categories and their respective values - categories = [ - LabelCategoryInfo(description="Tumor Type", slug="tumor_type", values=["Melanoma", "Carcinoma", "Sarcoma"]), - LabelCategoryInfo(description="Tumor Stage", slug="tumor_stage", values=["I", "II", "III", "IV"]), - LabelCategoryInfo( - description="Treatment Response", slug="treatment_response", values=["Positive", "Negative", "Neutral"] - ), - ] - - # Inserting the categories and their values into the database - for category_data in categories: - db.insert_label_category(category_data) - for value in category_data.values: - value_info = LabelValueInfo(category_slug=category_data.slug, value=value) - db.insert_label_value(value_info) - - -def assign_dummy_labels_to_patients(db: DatabaseManager): - # Dummy label assignments for demonstration purposes - dummy_assignments = ( - [ - PatientLabelAssignment( - patient_id=str(i), - label_slug="tumor_type", - label_value=random.choice(["Melanoma", "Carcinoma", "Sarcoma"]), - ) - for i in range(1, 11) - ] - + [ - PatientLabelAssignment( - patient_id=str(i), label_slug="tumor_stage", label_value=random.choice(["I", "II", "III", "IV"]) - ) - for i in range(1, 11) - ] - + [ - PatientLabelAssignment( - patient_id=str(i), - label_slug="treatment_response", - label_value=random.choice(["Positive", "Negative", "Neutral"]), - ) - for i in range(1, 11) - ] - ) - - # Assign the labels to the patients - for assignment in dummy_assignments: - label_info = PatientLabelInfo( - patient_id=assignment.patient_id, label_slug=assignment.label_slug, label_value=assignment.label_value - ) - db.assign_label_to_patient(label_info) - - -if __name__ == "__main__": - db = DatabaseManager() - db.create_tables() - populate_with_tcga_tiled(db) - - # Assign splits based on different versions and ratios - db.create_random_split("v1", (0.8, 0.1, 0.1), "80/10/10 split") - db.create_random_split("v2", (0.7, 0.2, 0.1), "70/20/10 split") - - populate_label_categories_and_values_with_dummies(db) - assign_dummy_labels_to_patients(db) - - # Example, get all patients with a specific label. - label_category = LabelCategoryInfo( - description="Tumor Type", slug="tumor_type", values=["Melanoma", "Carcinoma", "Sarcoma"] - ) - filtered_patients = db.get_patients_by_label_category(label_category) - # Create split that has the patients in there - db.create_random_split("v3", (0.8, 0.1, 0.1), "label split", filtered_patients) + yield PatientInfo(id=row[0], patient_code=row[1]) diff --git a/tools/tcga_files_to_study_types.2023-08-20.json b/tools/generate_database/database_label_data/tcga_files_to_study_types.2023-08-20.json similarity index 100% rename from tools/tcga_files_to_study_types.2023-08-20.json rename to tools/generate_database/database_label_data/tcga_files_to_study_types.2023-08-20.json diff --git a/tools/tcga_study_codes.txt b/tools/generate_database/database_label_data/tcga_study_codes.txt similarity index 100% rename from tools/tcga_study_codes.txt rename to tools/generate_database/database_label_data/tcga_study_codes.txt diff --git a/tools/generate_database/generate_tcga_database.py b/tools/generate_database/generate_tcga_database.py new file mode 100644 index 0000000..70db9e0 --- /dev/null +++ b/tools/generate_database/generate_tcga_database.py @@ -0,0 +1,183 @@ +# encoding: utf-8 +from __future__ import annotations + +import json +import random +import os +from pathlib import Path +from tqdm import tqdm + +from _database_manager import DatabaseManager +from _database_manager import FolderInfo, LabelCategoryInfo, LabelValueInfo, PatientLabelAssignment + +TILES_ROOT_PATH = "/projects/tcga_tiled/v1/data" +META_ROOT_PATH = "/data/groups/aiforoncology/archive/pathology/TCGA/metadata" +BATCH_INSERT_SIZE = 100 +DEBUG = False + + +def find_number_of_tiles(folder_path: Path) -> int: + with open(folder_path / "meta_data_tiles.json") as json_file: + meta_data_tiles = json.load(json_file) + return meta_data_tiles["num_regions_masked"] + + +def get_patient_code(path: Path) -> str: + return path.name[:12] + + +def populate_with_tcga_tiled(db: DatabaseManager) -> None: + tiles_root_path = Path(TILES_ROOT_PATH) + all_tcgas = tiles_root_path.glob("*/*") + infos_to_insert = [] + + for idx, folder_path in tqdm(enumerate(all_tcgas)): + num_files = find_number_of_tiles(folder_path) + patient_code = get_patient_code(folder_path) + info = FolderInfo( + path=str(folder_path.relative_to(tiles_root_path)), num_files=num_files, patient_code=patient_code + ) + infos_to_insert.append(info) + + if len(infos_to_insert) >= BATCH_INSERT_SIZE: + db.insert_folder_info(infos_to_insert) + infos_to_insert = [] + + if DEBUG: + # Easy for debugging. + if idx > 100: + break + + if infos_to_insert: + db.insert_folder_info(infos_to_insert) + + + +def populate_with_tcga_labels(db: DatabaseManager) -> None: + root_label_data = os.path.join(os.path.dirname(__file__), "database_label_data") + with open(Path(root_label_data) / "tcga_study_codes.txt" , "r", encoding="utf-8") as file: + study_codes = [line.strip().split("\t")[0] for line in file.readlines() if line != ""] + + # Defining categories and their respective values + categories = [ + LabelCategoryInfo(description="TCGA study codes", slug="tcga_study_codes", values=study_codes), + ] + + # Inserting the categories and their values into the database + for category_info in categories: + if len(category_info.values) != len(set(category_info.values)): + print(f"Could not add label category {category_info.slug}. Label values within a category must be unique. Skipping...") + continue + + db.insert_label_category(category_info) + + for value in category_info.values: + value_info = LabelValueInfo(label_category_slug=category_info.slug, value=value) + db.insert_label_value(value_info) + + +def assign_labels_to_patients(db: DatabaseManager) -> None: + meta_root_path = Path(META_ROOT_PATH) + meta_basic_path = meta_root_path / "metadata_basic" + + meta_basic_files = meta_basic_path.glob("**/*diagnostic*.txt") + patient_study_code_mapping = {} + for meta_file_path in meta_basic_files: + with open(meta_file_path, "r") as txt_file: + lines = txt_file.readlines() + for line in lines[1:]: # Skip header + columns = line.split("\t") + patient_code = columns[2] + study_code = columns[3][5:] + + patient_study_code_mapping[patient_code] = study_code + + assignments = [] + for patient_code, study_code in patient_study_code_mapping.items(): + assignments.append(PatientLabelAssignment( + patient_code=patient_code, + label_category_slug="tcga_study_codes", + label_value=study_code, + )) + + # Assign the labels to the patients in the database. + for patient_label_assignment in assignments: + db.assign_label_to_patient(patient_label_assignment, strict=False) + + + +def populate_label_categories_and_values_with_dummies(db: DatabaseManager): + # Defining categories and their respective values + categories = [ + LabelCategoryInfo(description="Tumor Type", slug="tumor_type", values=["Melanoma", "Carcinoma", "Sarcoma"]), + LabelCategoryInfo(description="Tumor Stage", slug="tumor_stage", values=["I", "II", "III", "IV"]), + LabelCategoryInfo( + description="Treatment Response", slug="treatment_response", values=["Positive", "Negative", "Neutral"] + ), + ] + + # Inserting the categories and their values into the database + for category_info in categories: + db.insert_label_category(category_info) + for value in category_info.values: + value_info = LabelValueInfo(label_category_slug=category_info.slug, value=value) + db.insert_label_value(value_info) + + +def assign_dummy_labels_to_patients(db: DatabaseManager) -> None: + # Dummy label assignments for demonstration purposes + dummy_assignments = ( + [ + PatientLabelAssignment( + patient_code=patient_code, + label_category_slug="tumor_type", + label_value=random.choice(["Melanoma", "Carcinoma", "Sarcoma"]), + ) + for patient_code in ["TCGA-V4-A9EC", "TCGA-06-0148", "TCGA-AO-A12F"] + ] + + [ + PatientLabelAssignment( + patient_code=patient_code, + label_category_slug="tumor_stage", + label_value=random.choice(["I", "II", "III", "IV"]) + ) + for patient_code in ["TCGA-V4-A9EC", "TCGA-06-0148", "TCGA-AO-A12F"] + ] + + [ + PatientLabelAssignment( + patient_code=patient_code, + label_category_slug="treatment_response", + label_value=random.choice(["Positive", "Negative", "Neutral"]), + ) + for patient_code in ["TCGA-V4-A9EC", "TCGA-06-0148", "TCGA-AO-A12F"] + ] + ) + + # Assign the labels to the patients + for assignment in dummy_assignments: + db.assign_label_to_patient(assignment) + + +if __name__ == "__main__": + db = DatabaseManager() + db.create_tables() + populate_with_tcga_tiled(db) + + # Assign splits based on different versions and ratios + db.create_random_split("v1", (80, 10, 10), "80/10/10 split") + db.create_random_split("v2", (70, 20, 10), "70/20/10 split") + + populate_with_tcga_labels(db) + assign_labels_to_patients(db) + + + # populate_label_categories_and_values_with_dummies(db) + # assign_dummy_labels_to_patients(db) + + # Example, get all patients with a specific label. + label_category_slug = "tcga_study_codes" + + filtered_patients = db.get_patients_by_label_category(label_category_slug) + + # Create split that has the patients in there + db.create_random_split("v3", (80, 10, 10), "label split", filtered_patients) \ No newline at end of file diff --git a/tools/generate_tcga_labels.py b/tools/generate_database/generate_tcga_labels.py similarity index 96% rename from tools/generate_tcga_labels.py rename to tools/generate_database/generate_tcga_labels.py index 529cb59..c9fd0c8 100644 --- a/tools/generate_tcga_labels.py +++ b/tools/generate_database/generate_tcga_labels.py @@ -99,7 +99,7 @@ def write_to_database(conn, chunks_gen): tcga_study_codes = get_tcga_study_code_mapping() # Now we need to map identifer/tcga-code to the label. - with open("tcga_files_to_study_types.2023-08-20.json", "r", encoding="utf-8") as json_file: + with open("database_label_data/tcga_files_to_study_types.2023-08-20.json", "r", encoding="utf-8") as json_file: tcga_files_to_study_types = json.load(json_file) generate_labels(tcga_files_to_study_types, tcga_study_codes)