Skip to content

Commit

Permalink
debugged the cuda issue for grn inf
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Nov 13, 2024
1 parent 55bbe0d commit bbc89bc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ release: ## Create a new tag for release.
@git commit -m "release: version $${TAG} 🚀"
@echo "creating git tag : $${TAG}"
@git tag $${TAG}
@git push -u origin HEAD --tags
@git push -u cantini HEAD --tags
@echo "Github Actions will detect the new tag and release the new version."
@mkdocs gh-deploy
@echo "Documentation deployed to https://jkobject.github.io/scPRINT/"
Expand Down
10 changes: 5 additions & 5 deletions scprint/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,10 @@ def add_attn(
if self.data is None:
self.data = torch.zeros(
[self.gene_dim, self.gene_dim, len(x) * x[0].shape[3]],
device="cuda",
device=pos.device,
dtype=torch.float32,
)
self.div = torch.zeros(1, device="cuda", dtype=torch.float32)
self.div = torch.zeros(1, device=pos.device, dtype=torch.float32)

for i, elem in enumerate(x):
batch, seq_len, _, heads, _ = elem.shape
Expand Down Expand Up @@ -545,11 +545,11 @@ def add_qk(
"""
if self.data is None:
self.data = torch.zeros(
[len(x), self.gene_dim] + list(x[0].shape[2:]), device="cuda"
[len(x), self.gene_dim] + list(x[0].shape[2:]), device=pos.device
)
self.div = torch.zeros(self.gene_dim, device="cuda")
self.div = torch.zeros(self.gene_dim, device=pos.device)
for i in range(x[0].shape[0]):
loc = torch.cat([torch.arange(8, device="cuda"), pos[i] + 8]).int()
loc = torch.cat([torch.arange(8, device=pos.device), pos[i] + 8]).int()
for j in range(len(x)):
self.data[j, loc, :, :, :] += x[j][i]
self.div[loc] += 1
Expand Down

0 comments on commit bbc89bc

Please sign in to comment.