Skip to content

Commit

Permalink
Allow dlc pipeline to run without prior position tracking (LorenFrank…
Browse files Browse the repository at this point in the history
…Lab#970)

* fix dlc pose estimation populate if no raw position data

* allow dlc pipeline to run without raw spatial data

* update changelog

* string formatting

* fix analysis nwb create time
  • Loading branch information
samuelbray32 authored May 13, 2024
1 parent a6e2ea6 commit 113ce9a
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 37 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
- Add long-distance restrictions via `<<` and `>>` operators. #943, #969
- Fix relative pathing for `mkdocstring-python=>1.9.1`. #967, #968

### Pipelines

- DLC
- Allow dlc without pre-existing tracking data #950

## [0.5.2] (April 22, 2024)

### Infrastructure
Expand Down
19 changes: 14 additions & 5 deletions src/spyglass/position/v1/position_dlc_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,26 @@ def make(self, key):
)
position = pynwb.behavior.Position()
velocity = pynwb.behavior.BehavioralTimeSeries()
spatial_series = (RawPosition() & key).fetch_nwb()[0][
"raw_position"
]
if (
RawPosition & key
): # if spatial series exists, get metadata from there
spatial_series = (RawPosition() & key).fetch_nwb()[0][
"raw_position"
]
reference_frame = spatial_series.reference_frame
comments = spatial_series.comments
else:
reference_frame = ""
comments = "no comments"

METERS_PER_CM = 0.01
position.create_spatial_series(
name="position",
timestamps=final_df.index.to_numpy(),
conversion=METERS_PER_CM,
data=final_df.loc[:, idx[("x", "y")]].to_numpy(),
reference_frame=spatial_series.reference_frame,
comments=spatial_series.comments,
reference_frame=reference_frame,
comments=comments,
description="x_position, y_position",
)
velocity.create_timeseries(
Expand Down
25 changes: 19 additions & 6 deletions src/spyglass/position/v1/position_dlc_orient.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from time import time

import datajoint as dj
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -85,9 +87,7 @@ class DLCOrientation(SpyglassMixin, dj.Computed):

def make(self, key):
# Get labels to smooth from Parameters table
key["analysis_file_name"] = AnalysisNwbfile().create( # logged
key["nwb_file_name"]
)
AnalysisNwbfile()._creation_times["pre_create_time"] = time()
cohort_entries = DLCSmoothInterpCohort.BodyPart & key
pos_df = pd.concat(
{
Expand Down Expand Up @@ -133,15 +133,28 @@ def make(self, key):
final_df = pd.DataFrame(
orientation, columns=["orientation"], index=pos_df.index
)
spatial_series = (RawPosition() & key).fetch_nwb()[0]["raw_position"]
key["analysis_file_name"] = AnalysisNwbfile().create( # logged
key["nwb_file_name"]
)
if (
RawPosition & key
): # if spatial series exists, get metadata from there
spatial_series = (RawPosition() & key).fetch_nwb()[0][
"raw_position"
]
reference_frame = spatial_series.reference_frame
comments = spatial_series.comments
else:
reference_frame = ""
comments = "no comments"
orientation = pynwb.behavior.CompassDirection()
orientation.create_spatial_series(
name="orientation",
timestamps=final_df.index.to_numpy(),
conversion=1.0,
data=final_df["orientation"].to_numpy(),
reference_frame=spatial_series.reference_frame,
comments=spatial_series.comments,
reference_frame=reference_frame,
comments=comments,
description="orientation",
)
nwb_analysis_file = AnalysisNwbfile()
Expand Down
53 changes: 33 additions & 20 deletions src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,25 +232,38 @@ def make(self, key):
dlc_result.creation_time
).strftime("%Y-%m-%d %H:%M:%S")

logger.logger.info("getting raw position")
interval_list_name = (
convert_epoch_interval_name_to_position_interval_name(
{
"nwb_file_name": key["nwb_file_name"],
"epoch": key["epoch"],
},
populate_missing=False,
# get video information
_, _, meters_per_pixel, video_time = get_video_path(key)
# check if a position interval exists for this epoch
try:
interval_list_name = (
convert_epoch_interval_name_to_position_interval_name(
{
"nwb_file_name": key["nwb_file_name"],
"epoch": key["epoch"],
},
populate_missing=False,
)
)
)
spatial_series = (
RawPosition()
& {**key, "interval_list_name": interval_list_name}
).fetch_nwb()[0]["raw_position"]
_, _, _, video_time = get_video_path(key)
pos_time = spatial_series.timestamps
# TODO: should get timestamps from VideoFile, but need the video_frame_ind from RawPosition,
# which also has timestamps
key["meters_per_pixel"] = spatial_series.conversion
raw_position = True
except KeyError:
raw_position = False

if raw_position:
logger.logger.info("Getting raw position")
spatial_series = (
RawPosition()
& {**key, "interval_list_name": interval_list_name}
).fetch_nwb()[0]["raw_position"]
pos_time = spatial_series.timestamps
reference_frame = spatial_series.reference_frame
comments = spatial_series.comments
else:
pos_time = video_time
reference_frame = ""
comments = "no comments"

key["meters_per_pixel"] = meters_per_pixel

# Insert entry into DLCPoseEstimation
logger.logger.info(
Expand Down Expand Up @@ -292,8 +305,8 @@ def make(self, key):
timestamps=part_df.time.to_numpy(),
conversion=METERS_PER_CM,
data=part_df.loc[:, idx[("x", "y")]].to_numpy(),
reference_frame=spatial_series.reference_frame,
comments=spatial_series.comments,
reference_frame=reference_frame,
comments=comments,
description="x_position, y_position",
)
likelihood.create_timeseries(
Expand Down
9 changes: 6 additions & 3 deletions src/spyglass/position/v1/position_dlc_position.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from time import time

import datajoint as dj
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -167,9 +169,7 @@ def make(self, key):
path=f"{output_dir.as_posix()}/log.log",
print_console=False,
) as logger:
key["analysis_file_name"] = AnalysisNwbfile().create( # logged
key["nwb_file_name"]
)
AnalysisNwbfile()._creation_times["pre_create_time"] = time()
logger.logger.info("-----------------------")
idx = pd.IndexSlice
# Get labels to smooth from Parameters table
Expand Down Expand Up @@ -227,6 +227,9 @@ def make(self, key):
.fetch_nwb()[0]["dlc_pose_estimation_position"]
.get_spatial_series()
)
key["analysis_file_name"] = AnalysisNwbfile().create( # logged
key["nwb_file_name"]
)
# Add dataframe to AnalysisNwbfile
nwb_analysis_file = AnalysisNwbfile()
position = pynwb.behavior.Position()
Expand Down
8 changes: 5 additions & 3 deletions src/spyglass/position/v1/position_dlc_selection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from pathlib import Path
from time import time

import datajoint as dj
import numpy as np
Expand Down Expand Up @@ -58,9 +59,7 @@ class DLCPosV1(SpyglassMixin, dj.Computed):
def make(self, key):
orig_key = copy.deepcopy(key)
# Add to Analysis NWB file
key["analysis_file_name"] = AnalysisNwbfile().create( # logged
key["nwb_file_name"]
)
AnalysisNwbfile()._creation_times["pre_create_time"] = time()
key["pose_eval_result"] = self.evaluate_pose_estimation(key)

pos_nwb = (DLCCentroid & key).fetch_nwb()[0]
Expand Down Expand Up @@ -114,6 +113,9 @@ def make(self, key):
comments=vid_frame_obj.comments,
)

key["analysis_file_name"] = AnalysisNwbfile().create(
key["nwb_file_name"]
)
nwb_analysis_file = AnalysisNwbfile()
key["orientation_object_id"] = nwb_analysis_file.add_nwb_object(
key["analysis_file_name"], orientation
Expand Down

0 comments on commit 113ce9a

Please sign in to comment.