Skip to content

Commit

Permalink
added history matching and test
Browse files Browse the repository at this point in the history
  • Loading branch information
marjanfamili committed Feb 12, 2025
1 parent 005b3c0 commit 3bfcaf8
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 1,231 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ notebooks/
__pycache__/
.pytest_cache/
dist/
my_tests/

# Ignore Sphinx build artifacts
docs/build/
Expand Down
33 changes: 33 additions & 0 deletions autoemulate/history_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np


def history_matching(obs, expectations, threshold=3.0, discrepancy=0.0, rank=1):
"""
Perform history matching to compute implausibility and identify NROY and RO points.
Parameters:
obs (tuple): Observations as (mean, variance).
expectations (tuple): Predicted (mean, variance).
threshold (float): Implausibility threshold for NROY classification.
discrepancy (float or ndarray): Discrepancy value(s).
rank (int): Rank for implausibility calculation.
Returns:
dict: Contains implausibility (I), NROY indices, and RO indices.
"""
obs_mean, obs_var = np.atleast_1d(obs[0]), np.atleast_1d(obs[1])
pred_mean, pred_var = np.atleast_1d(expectations[0]), np.atleast_1d(expectations[1])

discrepancy = np.atleast_1d(discrepancy)
n_obs = len(obs_mean)
rank = min(max(rank, 0), n_obs - 1)
if discrepancy.size == 1:
discrepancy = np.full(n_obs, discrepancy)

Vs = pred_var + discrepancy + obs_var
I = np.abs(obs_mean - pred_mean) / np.sqrt(Vs)

NROY = np.where(I <= threshold)[0]
RO = np.where(I > threshold)[0]

return {"I": I, "NROY": list(NROY), "RO": list(RO)}
1,306 changes: 75 additions & 1,231 deletions docs/tutorials/01_start.ipynb

Large diffs are not rendered by default.

56 changes: 56 additions & 0 deletions tests/test_history_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import pytest

from autoemulate.history_matching import history_matching


@pytest.fixture
def sample_data_2d():
pred_mean = np.array([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1], [5.0, 5.1]])
pred_std = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4], [0.5, 0.5]])
pred_var = np.square(pred_std)
expectations = (pred_mean, pred_var)
obs = [(1.5, 0.1), (2.5, 0.2)]
return expectations, obs


@pytest.fixture
def sample_data_1d():
pred_mean = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
pred_std = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
pred_var = np.square(pred_std)
expectations = (pred_mean, pred_var)
obs = [1.5, 10]
return expectations, obs


def test_history_matching_1d(sample_data_1d):
expectations, obs = sample_data_1d
result = history_matching(expectations=expectations, obs=obs, threshold=1.0)
assert "NROY" in result # Ensure the key exists in the result
assert isinstance(result["NROY"], list) # Validate that NROY is a list
assert len(result["NROY"]) > 0 # Ensure the list is not empty


def test_history_matching_threshold_1d(sample_data_1d):
expectations, obs = sample_data_1d
result = history_matching(expectations=expectations, obs=obs, threshold=0.5)
assert "NROY" in result
assert isinstance(result["NROY"], list)
assert len(result["NROY"]) <= len(expectations[0])


def test_history_matching_2d(sample_data_2d):
expectations, obs = sample_data_2d
result = history_matching(expectations=expectations, obs=obs, threshold=1.0)
assert "NROY" in result # Ensure the key exists in the result
assert isinstance(result["NROY"], list) # Validate that NROY is a list
assert len(result["NROY"]) > 0 # Ensure the list is not empty


def test_history_matching_threshold_2d(sample_data_2d):
expectations, obs = sample_data_2d
result = history_matching(expectations=expectations, obs=obs, threshold=0.5)
assert "NROY" in result
assert isinstance(result["NROY"], list)
assert len(result["NROY"]) <= len(expectations[0])

0 comments on commit 3bfcaf8

Please sign in to comment.