-
Notifications
You must be signed in to change notification settings - Fork 0
split
Okerew edited this page Aug 2, 2024
·
3 revisions
You can use split like here
from okrolearn.okrolearn import Tensor, np
def test_split():
# Create a simple tensor
data = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
tensor = Tensor(data)
# Test split into equal sections
split_tensors = tensor.split(2, axis=1)
assert len(split_tensors) == 2
assert np.array_equal(split_tensors[0].data, np.array([[1, 2], [5, 6], [9, 10]]))
assert np.array_equal(split_tensors[1].data, np.array([[3, 4], [7, 8], [11, 12]]))
# Test backward pass for equal sections
split_tensors[0].backward()
split_tensors[1].backward()
expected_grad = np.ones_like(data)
assert np.array_equal(tensor.grad, expected_grad)
# Reset grad for next test
tensor.grad = None
# Test split at specific indices
split_tensors = tensor.split([1, 3], axis=1)
assert len(split_tensors) == 3
assert np.array_equal(split_tensors[0].data, np.array([[1], [5], [9]]))
assert np.array_equal(split_tensors[1].data, np.array([[2, 3], [6, 7], [10, 11]]))
assert np.array_equal(split_tensors[2].data, np.array([[4], [8], [12]]))
# Test backward pass for specific indices
split_tensors[0].backward()
split_tensors[1].backward()
split_tensors[2].backward()
assert np.array_equal(tensor.grad, expected_grad)
print("All tests passed!")
# Run the test
test_split()