Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close the session before downloading the data #8

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cdsobs/cli/_catalogue_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def list_catalogue(
# with pagination (50 per page)
results = list_catalogue_(session, filters, page)

if len(results) == 0:
raise RuntimeError("No catalogue entries found for these parameters.")
if len(results) == 0:
raise RuntimeError("No catalogue entries found for these parameters.")

print_db_results(results, print_format)
print_db_results(results, print_format)


def list_catalogue_(
Expand Down
16 changes: 7 additions & 9 deletions cdsobs/cli/_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from cdsobs.cli._utils import CliException, ConfigNotFound, config_yml_typer
from cdsobs.config import validate_config
from cdsobs.observation_catalogue.database import get_session
from cdsobs.retrieve.api import retrieve_observations
from cdsobs.retrieve.models import RetrieveArgs
from cdsobs.storage import S3Client
Expand Down Expand Up @@ -61,13 +60,12 @@ def retrieve(
raise ConfigNotFound()
config = validate_config(cdsobs_config_yml)
s3_client = S3Client.from_config(config.s3config)
with get_session(config.catalogue_db) as session:
output_file = retrieve_observations(
session,
s3_client.public_url_base,
retrieve_args,
output_dir,
size_limit,
)
output_file = retrieve_observations(
config.catalogue_db.get_url(),
s3_client.public_url_base,
retrieve_args,
output_dir,
size_limit,
)
console = Console()
console.print(f"[green] Successfully downloaded {output_file} [/green]")
6 changes: 1 addition & 5 deletions cdsobs/forms_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ def get_variables_json(dataset: str, output_path: Path) -> Path:


def get_constraints_json(session, output_path: Path, dataset) -> Path:
"""
JSON file with the constraints in compressed form.

Beware this in the need of some optimization (may be resource heavy).
"""
"""JSON file with the constraints in compressed form."""
# This is probably slow, can it be improved?
catalogue_entries = get_catalogue_entries_stream(session, dataset)
merged_constraints = merged_constraints_table(catalogue_entries)
Expand Down
21 changes: 11 additions & 10 deletions cdsobs/retrieve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pandas
import xarray
from fsspec.implementations.http import HTTPFileSystem
from sqlalchemy.orm import Session

from cdsobs.cdm.lite import cdm_lite_variables
from cdsobs.constants import TIME_UNITS_REFERENCE_DATE
Expand All @@ -27,14 +26,14 @@
from cdsobs.retrieve.retrieve_services import estimate_data_size, ezclump
from cdsobs.service_definition.api import get_service_definition
from cdsobs.utils.logutils import SizeError, get_logger
from cdsobs.utils.utils import get_code_mapping
from cdsobs.utils.utils import get_code_mapping, get_database_session

logger = get_logger(__name__)
MAX_NUMBER_OF_GROUPS = 10


def retrieve_observations(
session: Session,
catalogue_url: str,
storage_url: str,
retrieve_args: RetrieveArgs,
output_dir: Path,
Expand All @@ -45,8 +44,9 @@ def retrieve_observations(

Parameters
----------
session:
Session in the catalogue database
catalogue_url:
URL of the catalogue database including credentials, in the form of
"postgresql+psycopg2://someuser:somepass@hostname:port/catalogue"
storage_url:
Storage URL
retrieve_args :
Expand All @@ -58,11 +58,12 @@ def retrieve_observations(
"""
logger.info("Starting retrieve pipeline.")
# Query the storage to get the URLS of the files that contain the data requested
catalogue_repository = CatalogueRepository(session)
entries = _get_catalogue_entries(catalogue_repository, retrieve_args)
object_urls = _get_urls_and_check_size(
entries, retrieve_args, size_limit, storage_url
)
with get_database_session(catalogue_url) as session:
catalogue_repository = CatalogueRepository(session)
entries = _get_catalogue_entries(catalogue_repository, retrieve_args)
object_urls = _get_urls_and_check_size(
entries, retrieve_args, size_limit, storage_url
)
# Get the path of the output file
output_path_netcdf = _get_output_path(output_dir, retrieve_args.dataset, "netCDF")
# First we always write the netCDF-lite file
Expand Down
2 changes: 1 addition & 1 deletion cdsobs/sanity_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _sanity_check_dataset(
check_if_missing_in_object_storage(catalogue_repo, s3_client, dataset_name)
# Retrieve and check output
output_path = retrieve_observations(
session,
config.catalogue_db.get_url(),
s3_client.public_url_base,
retrieve_args,
Path(tmpdir),
Expand Down
4 changes: 1 addition & 3 deletions tests/cli/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_cli_make_production(verbose):
assert result.exit_code == 0


@pytest.mark.skip(reason="this test does not reset db after running")
# @pytest.mark.skip(reason="this test does not reset db after running")
def test_cli_retrieve(tmp_path, test_repository):
runner = CliRunner()
test_json_str = """[
Expand Down Expand Up @@ -61,8 +61,6 @@ def test_cli_retrieve(tmp_path, test_repository):
CONFIG_YML,
"--output-dir",
str(tmp_path),
"--np",
"2",
]
result = runner.invoke(
app,
Expand Down
11 changes: 7 additions & 4 deletions tests/cli/test_catalogue_explorer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from typer.testing import CliRunner

from cdsobs.cli._catalogue_explorer import list_catalogue_
Expand All @@ -8,10 +9,11 @@
runner = CliRunner()


def test_list_catalogue(test_session, test_repository):
@pytest.mark.parametrize("print_format", ["table", "json"])
def test_list_catalogue(test_session, test_repository, print_format):
result = runner.invoke(
app,
["list-catalogue", "-c", CONFIG_YML],
["list-catalogue", "-c", CONFIG_YML, "--print-format", print_format],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand All @@ -26,10 +28,11 @@ def test_catalogue_dataset_info(test_session, test_repository):
assert result.exit_code == 0


def test_list_datasets():
@pytest.mark.parametrize("print_format", ["table", "json"])
def test_list_datasets(print_format):
result = runner.invoke(
app,
["list-datasets", "-c", CONFIG_YML, "--print-format", "json"],
["list-datasets", "-c", CONFIG_YML, "--print-format", print_format],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand Down
10 changes: 5 additions & 5 deletions tests/retrieve/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from cdsobs.config import CDSObsConfig
from cdsobs.constants import CONFIG_YML
from cdsobs.observation_catalogue.database import get_session
from cdsobs.retrieve.api import retrieve_observations
from cdsobs.retrieve.models import RetrieveArgs
from cdsobs.storage import S3Client
Expand All @@ -23,7 +22,9 @@


@pytest.mark.parametrize("oformat,dataset_source,time_coverage", PARAMETRIZE_VALUES)
def test_retrieve(test_repository, tmp_path, oformat, dataset_source, time_coverage):
def test_retrieve(
test_repository, test_config, tmp_path, oformat, dataset_source, time_coverage
):
dataset_name = "insitu-observations-woudc-ozone-total-column-and-profiles"
start_year, end_year = get_test_years(dataset_source)
if dataset_source == "OzoneSonde":
Expand Down Expand Up @@ -52,7 +53,7 @@ def test_retrieve(test_repository, tmp_path, oformat, dataset_source, time_cover
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start = datetime.now()
output_file = retrieve_observations(
test_repository.catalogue_repository.session,
test_config.catalogue_db.get_url(),
test_repository.s3_client.base,
retrieve_args,
tmp_path,
Expand Down Expand Up @@ -86,10 +87,9 @@ def test_retrieve_cuon():
],
}
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
session = get_session(test_config.catalogue_db)
s3_client = S3Client.from_config(test_config.s3config)
output_file = retrieve_observations(
session,
test_config.catalogue_db.get_url(),
s3_client.base,
retrieve_args,
Path("/tmp"),
Expand Down
7 changes: 4 additions & 3 deletions tests/system/1_year_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
s3_client = S3Client.from_config(config.s3config)
start_time = time.perf_counter()
catalogue_url = config.catalogue_db.get_url()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand Down Expand Up @@ -127,7 +128,7 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start_time = time.perf_counter()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand All @@ -151,7 +152,7 @@ def main():
retrieve_args = RetrieveArgs(dataset=dataset_name, params=params)
start_time = time.perf_counter()
retrieve_funct(
session,
catalogue_url,
s3_client.public_url_base,
retrieve_args,
tmpdir,
Expand Down
Loading