import jax
import jax.numpy as jnp

# Compute RMSE
def compute_rmse(y, y_est):
    """Return the root-mean-square error of ``y_est`` relative to ``y``.

    The squared residuals are summed over *all* elements, and the total is
    divided by ``len(y)`` (the length of the leading axis of ``y``) before
    taking the square root. For 1-D inputs this is the usual RMSE. For
    multi-dimensional inputs it is not the per-element mean.

    Args:
        y: Reference values.
        y_est: Estimated values; must be subtractable from ``y``.

    Returns:
        The RMSE as a JAX scalar array.
    """
    squared_residuals = (y - y_est) ** 2
    sse = jnp.sum(squared_residuals)
    # Normalise by the leading-axis length, matching the original definition.
    normalized_sse = sse / len(y)
    return jnp.sqrt(normalized_sse)

# Compute RMSE of estimate and print comparison with
# standard deviation of measurement noise
def compute_and_print_rmse_comparison(y, y_est, R, est_type=""):
    """Print the RMSE of an estimate next to the measurement-noise std.

    Args:
        y: Reference values, passed to ``compute_rmse``.
        y_est: Estimated values, passed to ``compute_rmse``.
        R: Measurement-noise variance; ``sqrt(R)`` is printed as the noise
            standard deviation.
        est_type: Optional label inserted into the RMSE message. When empty,
            the label is omitted.
    """
    rmse_est = compute_rmse(y, y_est)
    # Insert the label (and its separating space) only when one is given, so
    # the default ``est_type=""`` does not produce a double space.
    label = f"{est_type} estimate" if est_type else "estimate"
    rmse_msg = f"The RMSE of the {label} is"
    noise_msg = "The std of measurement noise is"
    # Left-align both descriptions in a 40-character column so the values line up.
    print(f"{rmse_msg:<40}: {rmse_est:.2f}")
    print(f"{noise_msg:<40}: {jnp.sqrt(R):.2f}")