Skip to content

Commit

Permalink
Fix tests and add tests for temperature, compliance, and displacement…
Browse files Browse the repository at this point in the history
… grads
  • Loading branch information
yaugenst committed Nov 20, 2023
1 parent 5a7cf20 commit 08f42b0
Showing 1 changed file with 61 additions and 4 deletions.
65 changes: 61 additions & 4 deletions tests/test_fea2d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
import numpy as np
import pytest
from autograd.test_util import check_grads

from tofea.fea2d import FEA2D_K, FEA2D_T


@pytest.fixture()
def rng():
seed = 36523523
return np.random.default_rng(seed)


class TestFEA2DK:
@pytest.fixture()
def fea2d_k_instance(self):
fixed = np.zeros((5, 5, 2), dtype=bool)
fixed[0] = 1
return FEA2D_K(fixed)

@pytest.fixture()
def x_and_b(self, fea2d_k_instance, rng):
x = rng.random(fea2d_k_instance.shape)
b = rng.random(fea2d_k_instance.fixed.shape)
return x, b

def test_shape(self, fea2d_k_instance):
assert fea2d_k_instance.shape == (4, 4)

Expand All @@ -22,20 +36,50 @@ def test_dofs(self, fea2d_k_instance):
def test_fixdofs(self, fea2d_k_instance):
fixdofs = fea2d_k_instance.fixdofs
assert isinstance(fixdofs, np.ndarray)
assert fixdofs.size == 0
assert fixdofs.size == fea2d_k_instance.fixed[0].size

def test_freedofs(self, fea2d_k_instance):
freedofs = fea2d_k_instance.freedofs
assert isinstance(freedofs, np.ndarray)
assert np.all(freedofs == np.arange(50))
assert freedofs.size == 50 - fea2d_k_instance.fixdofs.size

def test_displacement_grads(self, fea2d_k_instance, x_and_b):
x, b = x_and_b
check_grads(
lambda x_: fea2d_k_instance.displacement(x_, b),
modes=["fwd", "rev"],
order=1,
)(x)
check_grads(
lambda b_: fea2d_k_instance.displacement(x, b_),
modes=["fwd", "rev"],
order=1,
)(b)

def test_compliance_grads(self, fea2d_k_instance, x_and_b, rng):
x, _ = x_and_b
d = rng.random(fea2d_k_instance.dofs.shape)
check_grads(
lambda x_: fea2d_k_instance.compliance(x_, d), modes=["fwd", "rev"], order=1
)(x)
check_grads(
lambda d_: fea2d_k_instance.compliance(x, d_), modes=["fwd", "rev"], order=1
)(d)


class TestFEA2DT:
@pytest.fixture()
def fea2d_t_instance(self):
fixed = np.zeros((5, 5), dtype=bool)
fixed[0, 0] = 1
return FEA2D_T(fixed)

@pytest.fixture()
def x_and_b(self, fea2d_t_instance, rng):
x = rng.random(fea2d_t_instance.shape)
b = rng.random(fea2d_t_instance.fixed.shape)
return x, b

def test_shape(self, fea2d_t_instance):
assert fea2d_t_instance.shape == (4, 4)

Expand All @@ -48,9 +92,22 @@ def test_dofs(self, fea2d_t_instance):
def test_fixdofs(self, fea2d_t_instance):
fixdofs = fea2d_t_instance.fixdofs
assert isinstance(fixdofs, np.ndarray)
assert fixdofs.size == 0
assert fixdofs.size == 1

def test_freedofs(self, fea2d_t_instance):
freedofs = fea2d_t_instance.freedofs
assert isinstance(freedofs, np.ndarray)
assert np.all(freedofs == np.arange(25))
assert freedofs.size == 25 - fea2d_t_instance.fixdofs.size

def test_temperature_grads(self, fea2d_t_instance, x_and_b):
x, b = x_and_b
check_grads(
lambda x_: fea2d_t_instance.temperature(x_, b),
modes=["fwd", "rev"],
order=1,
)(x)
check_grads(
lambda b_: fea2d_t_instance.temperature(x, b_),
modes=["fwd", "rev"],
order=1,
)(b)

0 comments on commit 08f42b0

Please sign in to comment.