Skip to content

Commit

Permalink
Update KANunittest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mattral authored May 5, 2024
1 parent b462708 commit a3288d2
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions KANunittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,20 @@ def test_forward_pass(self):
class TestBSplineFunctions(unittest.TestCase):
def test_extend_grid(self):
"""Test if the grid is extended correctly on both sides."""
grid = tf.constant([[0.0, 1.0, 2.0]])
grid = tf.constant([0.0, 1.0, 2.0]) # ensure it's 1D as expected
extended_grid = extend_grid_tf(grid, 1)
expected_output = [[-1.0, 0.0, 1.0, 2.0, 3.0]]
expected_output = [-1.0, 0.0, 1.0, 2.0, 3.0]
np.testing.assert_array_almost_equal(extended_grid.numpy(), expected_output)

def test_b_spline_basis(self):
"""Test B-spline basis computation for known inputs and grid."""
x = tf.constant([[0.5], [1.5], [2.5]])
grid = tf.constant([[0.0, 1.0, 2.0, 3.0]])
b_spline_values = B_batch_tf(x, grid, k=2, extend=False)
grid = tf.constant([0.0, 1.0, 2.0, 3.0]) # changed shape
b_spline_values = B_batch_tf(x, tf.expand_dims(grid, 0), k=2, extend=False) # ensure grid dimensions are expanded
expected_shape = (1, 3, 3) # (num_splines, num_samples, num_grid_points + k - 1)
self.assertEqual(b_spline_values.shape, expected_shape)


class TestKANModel(unittest.TestCase):
def test_model_construction(self):
"""Test the construction of the KAN model."""
Expand Down

0 comments on commit a3288d2

Please sign in to comment.