Skip to content

Commit

Permalink
[BugFix] Fix error in state_dict tests (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Sep 20, 2023
1 parent a721766 commit 785d11f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
1 change: 0 additions & 1 deletion tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,7 +3456,6 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
if key in keys and (
not is_tensor_collection(out.get(key)) or not out.get(key).is_empty()
):
print(out.get(key))
raise KeyError(
"Unflattening key(s) in tensordict will override existing unflattened key"
)
Expand Down
3 changes: 1 addition & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,7 +3172,6 @@ def test_sd_params(self, detach):
k: v if not isinstance(v, torch.Tensor) else v * 0
for k, v in sd.items()
}
print(sd)
# do some op to create a graph
td.apply(lambda x: x + 1)
# load the data
Expand All @@ -3181,7 +3180,7 @@ def test_sd_params(self, detach):
assert (td == 0).all()

def test_sd_module(self):
td = TensorDict({"1": 1, "2": 2, "3": {"3": 3}}, [])
td = TensorDict({"1": 1.0, "2": 2.0, "3": {"3": 3.0}}, [])
td = TensorDictParams(td)
module = nn.Linear(3, 4)
module.td = td
Expand Down

1 comment on commit 785d11f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 785d11f Previous: a721766 Ratio
benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_op 188.9328079745529 iter/sec (stddev: 0.000495686102918768) 378.22226191768425 iter/sec (stddev: 0.00001613256021337471) 2.00

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.