Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Nov 13, 2024
1 parent 64b0417 commit 6bc732a
Show file tree
Hide file tree
Showing 4 changed files with 434 additions and 11 deletions.
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,25 @@ flash = [
[project.urls]
repository = "https://github.com/jkobject/scPRINT"

[project.scripts]
scprint = "scprint.__main__:main"

[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "I"]
ignore = ["E501", "E203", "E266", "E265", "F401", "F403", "E722", "E741", "E731", "E721"]

[tool.hatch.build.targets.wheel]
only-include = [
"/scprint",
"/slurm",
"/config",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project.scripts]
scprint = "scprint.__main__:main"

4 changes: 1 addition & 3 deletions scprint/tasks/grn.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def aggregate(self, attn, genes):
return attn.detach().cpu().numpy()
badloc = torch.isnan(attn.sum((0, 2, 3, 4)))
attn = attn[:, ~badloc, :, :, :]
badloc = badloc.detach().cpu().numpy()
self.curr_genes = (
np.array(self.curr_genes)[~badloc[self.add_emb_in_model :]]
if self.how == "random expr"
Expand Down Expand Up @@ -399,9 +400,6 @@ def filter(self, adj, gt=None):
return adj

def save(self, grn, subadata, loc=""):
import pdb

pdb.set_trace()
grn = GRNAnnData(
subadata[:, subadata.var.index.isin(self.curr_genes)].copy(), grn=grn
)
Expand Down
1 change: 0 additions & 1 deletion scprint/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ def get_free_gpu():
import sys
from io import StringIO

import pandas as pd

gpu_stats = subprocess.check_output(
[
Expand Down
Loading

0 comments on commit 6bc732a

Please sign in to comment.