Merge pull request #4 from MilagrosMarin/main
Fix deprecated library to run KPMS + add `task_mode` + complete functionality of the tutorial in codespaces
ttngu207 authored Aug 30, 2024
2 parents 2e87c49 + f30df2f commit 62de906
Showing 11 changed files with 1,758 additions and 1,723 deletions.
6 changes: 4 additions & 2 deletions .devcontainer/Dockerfile
@@ -31,7 +31,9 @@ COPY ./ /tmp/element-moseq/

 RUN \
     # pipeline dependencies
-    apt-get install gcc g++ ffmpeg libsm6 libxext6 libgl1 libegl1 -y && \
+    apt-get update && \
+    apt-get install -y gcc ffmpeg graphviz && \
+    pip install ipywidgets && \
     pip install --no-cache-dir -e /tmp/element-moseq[elements,tests] && \
     # clean up
     rm -rf /tmp/element-moseq/ && \
@@ -52,4 +54,4 @@ ENV DATABASE_PREFIX neuro_
 USER vscode
 CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd"

-ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib"
\ No newline at end of file
+ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib"
3 changes: 2 additions & 1 deletion .devcontainer/docker-compose.yaml
@@ -6,7 +6,7 @@ services:
     build:
       context: ..
       dockerfile: ./.devcontainer/Dockerfile
-    # image: datajoint/element_moseq:latest
+    #image: datajoint/element_moseq:latest
     extra_hosts:
       - fakeservices.datajoint.io:127.0.0.1
     environment:
@@ -23,3 +23,4 @@
     privileged: true # only because of dind
 volumes:
   docker_data:
+
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,14 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

+## [0.2.0] - 2024-08-16
+
++ Add - `load` functions and new secondary attributes for tutorial purposes
++ Add - `outbox` results in the public s3 bucket to be mounted in Codespaces
++ Update - tutorial content
++ Fix - `scipy.linalg` deprecation in latest release by adjusting version in `setup.py`
++ Update - `pre_kappa` and `full_kappa` to integer to simplify equality comparisons
++ Update - `images` of the pipeline

 ## [0.1.1] - 2024-03-21

 + Update - Schemas and tables renaming
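The `scipy.linalg` fix listed above lives in this commit's `setup.py` diff, which is not expanded on this page. As an illustration only (the exact constraint is not shown here), a pin of this kind would address the breakage, since SciPy 1.13 removed the long-deprecated `scipy.linalg` aliases such as `tril`/`triu`:

```python
# Hypothetical setup.py excerpt -- "scipy<1.13" is an illustrative pin, not
# necessarily the exact constraint used in this commit's setup.py.
from setuptools import setup, find_packages

setup(
    name="element-moseq",
    packages=find_packages(),
    install_requires=[
        "scipy<1.13",  # assumption: cap SciPy before the scipy.linalg removals
    ],
)
```
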
154 changes: 100 additions & 54 deletions element_moseq/moseq_infer.py
@@ -58,7 +58,7 @@ def activate(
     )


-# -------------- Functions required by the element-moseq ---------------
+# -------------- Functions required by element-moseq ---------------


 def get_kpms_root_data_dir() -> list:
@@ -87,7 +87,7 @@ def get_kpms_processed_data_dir() -> Optional[str]:
     Method in parent namespace should provide a string to a directory where KPMS output
     files will be stored. If unspecified, output files will be stored in the
-    session directory 'videos' folder, per DeepLabCut default.
+    session directory 'videos' folder, per Keypoint-MoSeq default.
     """
     if hasattr(_linking_module, "get_kpms_processed_data_dir"):
         return _linking_module.get_kpms_processed_data_dir()
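
For reference, `get_kpms_root_data_dir` and `get_kpms_processed_data_dir` are resolved through the user's linking module, as the code above shows. A minimal sketch of such a module (module name and paths are placeholders, not part of this commit):

```python
# Hypothetical linking module passed to moseq_infer.activate(); only the
# function names are prescribed by element-moseq -- the paths are examples.
def get_kpms_root_data_dir() -> list:
    """Root directories searched for keypoint data."""
    return ["/data/kpms_root"]


def get_kpms_processed_data_dir() -> str:
    """Directory where Keypoint-MoSeq output files are stored."""
    return "/data/kpms_processed"
```
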
@@ -197,14 +197,15 @@ class InferenceTask(dj.Manual):
"""

definition = """
-> VideoRecording # `VideoRecording` key
-> Model # `Model` key
-> VideoRecording # `VideoRecording` key
-> Model # `Model` key
---
-> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
-> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
task_mode='load' : enum('load', 'trigger') # Task mode for the inference task
"""


@@ -305,12 +306,14 @@ def make(self, key):
             num_iterations,
             model_id,
             pose_estimation_method,
+            task_mode,
         ) = (InferenceTask & key).fetch1(
             "keypointset_dir",
             "inference_output_dir",
             "num_iterations",
             "model_id",
             "pose_estimation_method",
+            "task_mode",
         )

         kpms_root = get_kpms_root_data_dir()
@@ -322,7 +325,7 @@
         )
         keypointset_dir = find_full_path(kpms_root, keypointset_dir)

-        inference_output_dir = model_dir / inference_output_dir
+        inference_output_dir = os.path.join(model_dir, inference_output_dir)

         if not os.path.exists(inference_output_dir):
             os.makedirs(model_dir / inference_output_dir)
@@ -366,55 +369,98 @@
f"No valid `kpms_dj_config` found in the parent model directory {model_dir.parent}"
)

start_time = datetime.utcnow()
results = apply_model(
model=model,
data=data,
metadata=metadata,
pca=pca,
project_dir=model_dir.parent.as_posix(),
model_name=Path(model_dir).name,
results_path=(inference_output_dir / "results.h5").as_posix(),
return_model=False,
num_iters=num_iterations
or 50.0, # default internal value in the keypoint-moseq function
**kpms_dj_config,
)
end_time = datetime.utcnow()
if task_mode == "trigger":
start_time = datetime.utcnow()
results = apply_model(
model=model,
data=data,
metadata=metadata,
pca=pca,
project_dir=model_dir.parent.as_posix(),
model_name=Path(model_dir).name,
results_path=(inference_output_dir / "results.h5").as_posix(),
return_model=False,
num_iters=num_iterations
or 50, # default internal value in the keypoint-moseq function
**kpms_dj_config,
)
end_time = datetime.utcnow()

duration_seconds = (end_time - start_time).total_seconds()
duration_seconds = (end_time - start_time).total_seconds()

save_results_as_csv(
results=results,
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
)
save_results_as_csv(
results=results,
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
)

fig, _ = plot_syllable_frequencies(
results=results, path=inference_output_dir.as_posix()
)
fig.savefig(inference_output_dir / "syllable_frequencies.png")
plt.close(fig)

generate_trajectory_plots(
coordinates=coordinates,
results=results,
output_dir=(inference_output_dir / "trajectory_plots").as_posix(),
**kpms_dj_config,
)
fig, _ = plot_syllable_frequencies(
results=results, path=inference_output_dir.as_posix()
)
fig.savefig(inference_output_dir / "syllable_frequencies.png")
plt.close(fig)

generate_trajectory_plots(
coordinates=coordinates,
results=results,
output_dir=(inference_output_dir / "trajectory_plots").as_posix(),
**kpms_dj_config,
)

sampled_instances = generate_grid_movies(
coordinates=coordinates,
results=results,
output_dir=(inference_output_dir / "grid_movies").as_posix(),
**kpms_dj_config,
)
sampled_instances = generate_grid_movies(
coordinates=coordinates,
results=results,
output_dir=(inference_output_dir / "grid_movies").as_posix(),
**kpms_dj_config,
)

plot_similarity_dendrogram(
coordinates=coordinates,
results=results,
save_path=(inference_output_dir / "similarity_dendogram").as_posix(),
**kpms_dj_config,
)
plot_similarity_dendrogram(
coordinates=coordinates,
results=results,
save_path=(inference_output_dir / "similarity_dendogram").as_posix(),
**kpms_dj_config,
)

else:
from keypoint_moseq import (
load_results,
filter_centroids_headings,
get_syllable_instances,
sample_instances,
)

# load results
results = load_results(
project_dir=Path(inference_output_dir).parent,
model_name=Path(inference_output_dir).parts[-1],
)

# extract sampled_instances
## extract syllables from results
syllables = {k: v["syllable"] for k, v in results.items()}

## extract and smooth centroids and headings
centroids = {k: v["centroid"] for k, v in results.items()}
headings = {k: v["heading"] for k, v in results.items()}

filter_size = 9 # default value
centroids, headings = filter_centroids_headings(
centroids, headings, filter_size=filter_size
)

# sample instances for each syllable
syllable_instances = get_syllable_instances(
syllables, min_duration=3, min_frequency=0.005
)

sampled_instances = sample_instances(
syllable_instances=syllable_instances,
num_samples=4 * 6, # minimum rows * cols
coordinates=coordinates,
centroids=centroids,
headings=headings,
)

duration_seconds = None

self.insert1({**key, "inference_duration": duration_seconds})

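With a task registered, populating the downstream computed table runs the `make()` shown above; `task_mode` then decides whether `apply_model` is triggered or previously saved results are loaded (in the latter case `inference_duration` is stored as NULL, since `duration_seconds` is `None`). A minimal sketch, assuming the computed table is `moseq_infer.Inference`:

```python
# Sketch of driving the inference step; assumes the standard element-moseq
# workflow is activated and an InferenceTask row exists.
from element_moseq import moseq_infer

moseq_infer.Inference.populate(display_progress=True)
```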
Diffs for the remaining 7 changed files are not shown.
