Skip to content

Commit

Permalink
Merge pull request #4 from NxNiki/dev
Browse files Browse the repository at this point in the history
[refactor] bug fix on config
  • Loading branch information
NxNiki authored Oct 3, 2024
2 parents 6598189 + 6350dda commit bfba0bf
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 132 deletions.
8 changes: 4 additions & 4 deletions .run/main.run.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
<module name="movie_decoding" />
<configuration default="false" name="main" type="PythonConfigurationType" factoryName="Python">
<module name="brain_decoding" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
Expand All @@ -9,11 +9,11 @@
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="movie_decoding" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/movie_decoding/main.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/src/brain_decoding/main.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
Expand Down
25 changes: 25 additions & 0 deletions .run/save_config.run.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="save_config" type="PythonConfigurationType" factoryName="Python">
<module name="brain_decoding" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="movie_decoding" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/scripts/save_config.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
</component>
2 changes: 1 addition & 1 deletion scripts/save_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
config.model.weight_decay = 1e-4
config.model.epochs = 5
config.model.lr_drop = 50
config.model.validation_step = 25
config.model.validation_step = 2
config.model.early_stop = 75
config.model.num_labels = 8
config.model.merge_label = True
Expand Down
70 changes: 35 additions & 35 deletions src/brain_decoding/dataloader/free_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,49 @@ def __init__(self, config):
self.lfp_channel_by_region = {}

spikes_data = None
if self.config["use_spike"]:
if self.config["use_sleep"]:
config["spike_data_mode_inference"] = ""
spikes_data = self.read_recording_data("spike_path", "time_sleep", "")
if self.config.experiment["use_spike"]:
data_path = "spike_path"
if self.config.experiment["use_sleep"]:
config.experiment["spike_data_mode_inference"] = ""
spikes_data = self.read_recording_data(data_path, "time_sleep", "")
else:
if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
if self.config["patient"] == "i728":
phases = ["FR1a", "FR1b"]
else:
# phases = ["FR1", "FR2"]
phases = ["FR1"]
if (
isinstance(self.config.experiment["free_recall_phase"], str)
and "all" in self.config.experiment["free_recall_phase"]
):
phases = ["FR1"]
for phase in phases:
spikes_data = self.read_recording_data("spike_path", "time_recall", phase)
spikes_data = self.read_recording_data(data_path, "time_recall", phase)
elif (
isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
isinstance(self.config.experiment["free_recall_phase"], str)
and "control" in self.config.experiment["free_recall_phase"]
):
spikes_data = self.read_recording_data(data_path, "time", None)
elif (
isinstance(self.config.experiment["free_recall_phase"], str)
and "movie" in self.config.experiment["free_recall_phase"]
):
spikes_data = self.read_recording_data("spike_path", "time", None)
elif isinstance(self.config["free_recall_phase"], str) and "movie" in self.config["free_recall_phase"]:
spikes_data = self.read_recording_data("spike_path", "time", None)
spikes_data = self.read_recording_data(data_path, "time", None)
else:
spikes_data = self.read_recording_data("spike_path", "time_recall", None)
spikes_data = self.read_recording_data(data_path, "time_recall", None)

lfp_data = None
if self.config["use_lfp"]:
if self.use_sleep:
if self.config.experiment["use_lfp"]:
data_path = "lfp_path"
if self.config.experiment.use_sleep:
config["spike_data_mode_inference"] = ""
lfp_data = self.read_recording_data("lfp_path", "spectrogram_sleep", "")
lfp_data = self.read_recording_data(data_path, "spectrogram_sleep", "")
else:
if isinstance(self.config["free_recall_phase"], str) and "all" in self.config["free_recall_phase"]:
if self.config["patient"] == "i728":
phases = [1, 3]
else:
phases = [1, 2]
phases = [1, 2]
for phase in phases:
lfp_data = self.read_recording_data("lfp_path", "spectrogram_recall", phase)
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", phase)
elif (
isinstance(self.config["free_recall_phase"], str) and "control" in self.config["free_recall_phase"]
):
lfp_data = self.read_recording_data("lfp_path", "spectrogram", None)
lfp_data = self.read_recording_data(data_path, "spectrogram", None)
else:
lfp_data = self.read_recording_data("lfp_path", "spectrogram_recall", None)
lfp_data = self.read_recording_data(data_path, "spectrogram_recall", None)
# self.lfp_data = {key: np.concatenate(value_list, axis=0) for key, value_list in self.lfp_data.items()}

self.data = {"clusterless": spikes_data, "lfp": lfp_data}
Expand All @@ -85,14 +87,12 @@ def read_recording_data(self, root_path: str, file_path_prefix: str, phase: Opti
if phase == "":
exp_file_path = file_path_prefix
else:
if phase is None:
phase = self.config["free_recall_phase"]
exp_file_path = f"{file_path_prefix}_{phase}"

recording_file_path = os.path.join(
self.config[root_path],
self.config["patient"],
self.config["spike_data_mode_inference"],
self.config.data[root_path],
str(self.config.experiment["patient"]),
self.config.data["spike_data_mode_inference"],
exp_file_path,
)
recording_files = glob.glob(os.path.join(recording_file_path, "*.npz"))
Expand Down Expand Up @@ -146,7 +146,7 @@ def load_clustless(self, files) -> np.ndarray[float]:
# spike[spike < self.spike_data_sd] = 0
# vmax, vmin = self.channel_max(spike)
# normalized_spike = 2 * (spike - vmin[None, None, :, None]) / (vmax[None, None, :, None] - vmin[None, None, :, None]) - 1
spike[spike < self.config["spike_data_sd_inference"]] = 0
spike[spike < self.config.data["spike_data_sd_inference"]] = 0
# spike[spike > 500] = 0
vmax = np.max(spike)
normalized_spike = spike / vmax
Expand Down Expand Up @@ -270,7 +270,7 @@ def load_pickle(self, fn):
return lookup

def preprocess_data(self):
if self.config["use_combined"]:
if self.config.experiment["use_combined"]:
assert self.data["clusterless"].shape[0] == self.data["lfp"].shape[0]

# self.label = np.array(self.ml_label).transpose()[:length, :].astype(np.float32)
Expand Down Expand Up @@ -352,8 +352,8 @@ def create_inference_combined_loaders(
# np.random.seed(seed)
np.random.shuffle(all_indices)

spike_inference = dataset.data["clusterless"][all_indices] if config["use_spike"] else None
lfp_inference = dataset.data["lfp"][all_indices] if config["use_lfp"] else None
spike_inference = dataset.data["clusterless"][all_indices] if config.experiment["use_spike"] else None
lfp_inference = dataset.data["lfp"][all_indices] if config.experiment["use_lfp"] else None

# label_inference = dataset.smoothed_label[all_indices]
label_inference = None
Expand Down
2 changes: 1 addition & 1 deletion src/brain_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def pipeline(config: PipelineConfig) -> Trainer:

if __name__ == "__main__":
patient = 562
config_file = CONFIG_FILE_PATH / "config_test-None-None_2024-10-02-13:10:10.yaml"
config_file = CONFIG_FILE_PATH / "config_test-None-None_2024-10-02-17:31:47.yaml"

config = set_config(
config_file,
Expand Down
96 changes: 33 additions & 63 deletions src/brain_decoding/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def train(self, epochs, fold):
best_f1 = -1
self.model.train()
os.makedirs(self.config.data["train_save_path"], exist_ok=True)
os.makedirs(self.config.data["train_save_path"], exist_ok=True)
for epoch in tqdm(range(epochs)):
meter = Meter(fold)

Expand Down Expand Up @@ -131,7 +132,7 @@ def train(self, epochs, fold):
)

model_save_path = os.path.join(
self.config["train_save_path"],
self.config.data["train_save_path"],
"best_weights_fold{}.tar".format(fold + 1),
)
torch.save(
Expand Down Expand Up @@ -356,15 +357,11 @@ def permutation_p(label, activation):
df.to_csv(os.path.join(self.config["test_save_path"], "p_values.csv"))

def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
torch.manual_seed(self.config["seed"])
np.random.seed(self.config["seed"])
random.seed(self.config["seed"])
self.config["free_recall_phase"] = phase
if self.config["patient"] == "i728" and "1" in phase:
self.config["free_recall_phase"] = "free_recall1a"
dataloaders = initialize_inference_dataloaders(self.config)
else:
dataloaders = initialize_inference_dataloaders(self.config)
torch.manual_seed(self.config.experiment["seed"])
np.random.seed(self.config.experiment["seed"])
random.seed(self.config.experiment["seed"])
self.config.experiment["free_recall_phase"] = phase
dataloaders = initialize_inference_dataloaders(self.config)
model = initialize_model(self.config)
# model = torch.compile(model)
model = model.to(device_name)
Expand All @@ -375,79 +372,52 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):

# load the model with best F1-score
# model_dir = os.path.join(self.config['train_save_path'], 'best_weights_fold{}.tar'.format(fold + 1))
model_dir = os.path.join(self.config["train_save_path"], "model_weights_epoch{}.tar".format(epoch))
model_dir = os.path.join(self.config.data["train_save_path"], "model_weights_epoch{}.tar".format(epoch))
model.load_state_dict(torch.load(model_dir)["model_state_dict"])
# print('Resume model: %s' % model_dir)
model.eval()

predictions_all = np.empty((0, self.config["num_labels"]))
predictions_all = np.empty((0, self.config.model["num_labels"]))
predictions_length = {}
with torch.no_grad():
if self.config["patient"] == "i728" and "1" in phase:
# load the best epoch number from the saved "model_results" structure
for ph in ["FR1a", "FR1b"]:
predictions = np.empty((0, self.config["num_labels"]))
self.config["free_recall_phase"] = ph
dataloaders = initialize_inference_dataloaders(self.config)
# y_true = np.empty((0, self.config['num_labels']))
for i, (feature, index) in enumerate(dataloaders["inference"]):
# target = target.to(self.device)
spike, lfp = self.extract_feature(feature)
# forward pass

# start_time = time.time()
spike_emb, lfp_emb, output = model(lfp, spike)
# end_time = time.time()
# print('inference time: ', end_time - start_time)
output = torch.sigmoid(output)
pred = output.cpu().detach().numpy()
predictions = np.concatenate([predictions, pred], axis=0)

if self.config["use_overlap"]:
fake_activation = np.mean(predictions, axis=0)
predictions = np.vstack((fake_activation, predictions, fake_activation))

predictions_all = np.concatenate([predictions_all, predictions], axis=0)
predictions_length[phase] = len(predictions_all)
else:
self.config["free_recall_phase"] = phase
dataloaders = initialize_inference_dataloaders(self.config)
predictions = np.empty((0, self.config["num_labels"]))
# y_true = np.empty((0, self.config['num_labels']))
for i, (feature, index) in enumerate(dataloaders["inference"]):
# target = target.to(self.device)
spike, lfp = self.extract_feature(feature)
# forward pass
self.config.experiment["free_recall_phase"] = phase
dataloaders = initialize_inference_dataloaders(self.config)
predictions = np.empty((0, self.config.model["num_labels"]))
# y_true = np.empty((0, self.config['num_labels']))
for i, (feature, index) in enumerate(dataloaders["inference"]):
# target = target.to(self.device)
spike, lfp = self.extract_feature(feature)
# forward pass

# start_time = time.time()
spike_emb, lfp_emb, output = model(lfp, spike)
# end_time = time.time()
# print('inference time: ', end_time - start_time)
output = torch.sigmoid(output)
pred = output.cpu().detach().numpy()
predictions = np.concatenate([predictions, pred], axis=0)
# start_time = time.time()
spike_emb, lfp_emb, output = model(lfp, spike)
# end_time = time.time()
# print('inference time: ', end_time - start_time)
output = torch.sigmoid(output)
pred = output.cpu().detach().numpy()
predictions = np.concatenate([predictions, pred], axis=0)

if self.config["use_overlap"]:
fake_activation = np.mean(predictions, axis=0)
predictions = np.vstack((fake_activation, predictions, fake_activation))
if self.config.experiment["use_overlap"]:
fake_activation = np.mean(predictions, axis=0)
predictions = np.vstack((fake_activation, predictions, fake_activation))

predictions_length[phase] = len(predictions)
predictions_all = np.concatenate([predictions_all, predictions], axis=0)
predictions_length[phase] = len(predictions)
predictions_all = np.concatenate([predictions_all, predictions], axis=0)

# np.save(os.path.join(self.config['memory_save_path'], 'free_recall_{}_results.npy'.format(phase)), predictions)
save_path = os.path.join(self.config["memory_save_path"], "prediction")
save_path = os.path.join(self.config.data["memory_save_path"], "prediction")
os.makedirs(save_path, exist_ok=True)
np.save(
os.path.join(save_path, "epoch{}_free_recall_{}_results.npy".format(epoch, phase)),
predictions_all,
)

for ph in alongwith:
self.config["free_recall_phase"] = ph
self.config.experiment["free_recall_phase"] = ph
dataloaders = initialize_inference_dataloaders(self.config)
with torch.no_grad():
# load the best epoch number from the saved "model_results" structure
predictions = np.empty((0, self.config["num_labels"]))
predictions = np.empty((0, self.config.model["num_labels"]))
# y_true = np.empty((0, self.config['num_labels']))
for i, (feature, index) in enumerate(dataloaders["inference"]):
# target = target.to(self.device)
Expand All @@ -462,7 +432,7 @@ def memory(self, epoch=-1, phase: str = "free_recall1", alongwith=[]):
pred = output.cpu().detach().numpy()
predictions = np.concatenate([predictions, pred], axis=0)

if self.config["use_overlap"]:
if self.config.experiment["use_overlap"]:
fake_activation = np.mean(predictions, axis=0)
predictions = np.vstack((fake_activation, predictions, fake_activation))

Expand Down
24 changes: 12 additions & 12 deletions src/brain_decoding/utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,23 @@ def initialize_configs(architecture) -> Dict:
return args


def initialize_inference_dataloaders(config):
if config["use_sleep"]:
def initialize_inference_dataloaders(config: PipelineConfig):
if config.experiment["use_sleep"]:
dataset = InferenceDataset(
config["data_path"],
config["patient"],
config["use_lfp"],
config["use_spike"],
config["use_bipolar"],
config["use_sleep"],
config["free_recall_phase"],
config["hour"],
config.data["data_path"],
config.experiment["patient"],
config.experiment["use_lfp"],
config.experiment["use_spike"],
config.experiment["use_bipolar"],
config.experiment["use_sleep"],
config.experiment["free_recall_phase"],
config.experiment["hour"],
)
else:
dataset = InferenceDataset(config)

LFP_CHANNEL[config["patient"]] = dataset.lfp_channel_by_region
test_loader = create_inference_combined_loaders(dataset, config, batch_size=config["batch_size"])
LFP_CHANNEL[config.experiment["patient"]] = dataset.lfp_channel_by_region
test_loader = create_inference_combined_loaders(dataset, config, batch_size=config.model["batch_size"])

dataloaders = {"train": None, "valid": None, "inference": test_loader}
return dataloaders
Expand Down
Loading

0 comments on commit bfba0bf

Please sign in to comment.