diff --git a/terracotta/scripts/cli.py b/terracotta/scripts/cli.py index 1bd062b4..1f4604ea 100644 --- a/terracotta/scripts/cli.py +++ b/terracotta/scripts/cli.py @@ -8,7 +8,7 @@ import click -from terracotta.scripts.click_utils import TOMLFile +from terracotta.scripts.click_types import TOMLFile from terracotta import get_settings, update_settings, logs, __version__ @@ -42,7 +42,7 @@ def cli(ctx: click.Context, def entrypoint() -> None: try: cli(obj={}) - except Exception as exc: + except Exception: import logging logger = logging.getLogger(__name__) logger.exception('Uncaught exception!', exc_info=True) @@ -52,8 +52,8 @@ def entrypoint() -> None: from terracotta.scripts.connect import connect cli.add_command(connect) -from terracotta.scripts.create_database import create_database -cli.add_command(create_database) +from terracotta.scripts.ingest import ingest +cli.add_command(ingest) from terracotta.scripts.optimize_rasters import optimize_rasters cli.add_command(optimize_rasters) diff --git a/terracotta/scripts/click_utils.py b/terracotta/scripts/click_types.py similarity index 63% rename from terracotta/scripts/click_utils.py rename to terracotta/scripts/click_types.py index bc69398d..0944de95 100644 --- a/terracotta/scripts/click_utils.py +++ b/terracotta/scripts/click_types.py @@ -1,4 +1,4 @@ -"""scripts/click_utils.py +"""scripts/click_types.py Custom click parameter types and utilities. """ @@ -7,7 +7,6 @@ import pathlib import glob import re -import os import string import click @@ -31,37 +30,61 @@ def convert(self, *args: Any) -> pathlib.Path: # type: ignore RasterPatternType = Tuple[List[str], Dict[Tuple[str, ...], str]] +def _parse_raster_pattern(raster_pattern: str) -> Tuple[List[str], str, str]: + """Parse a raster pattern string using Python format syntax. + + Extracts names of unique placeholders, a glob pattern + and a regular expression to retrieve files matching the given pattern. + + Example: + + >>> _parse_raster_pattern('{key1}/{key2}_{}.tif') + (['key1', 'key2'], '*/*_*.tif', '(?P[^\\W_]+)/(?P[^\\W_]+)_.*?\\.tif') + + """ + + # raises ValueError on invalid patterns + parsed_value = string.Formatter().parse(raster_pattern) + + keys: List[str] = [] + glob_pattern: List[str] = [] + regex_pattern: List[str] = [] + + for before_field, field_name, _, _ in parsed_value: + glob_pattern += before_field + regex_pattern += re.escape(before_field) + + if field_name is None: + # no placeholder + continue + + glob_pattern.append('*') + + if field_name == '': + # unnamed placeholder + regex_pattern.append('.*?') + elif field_name in keys: + # duplicate placeholder + key_group_number = keys.index(field_name) + 1 + regex_pattern.append(rf'\{key_group_number}') + else: + # new placeholder + keys.append(field_name) + regex_pattern += rf'(?P<{field_name}>[^\W_]+)' + + return keys, ''.join(glob_pattern), ''.join(regex_pattern) + + class RasterPattern(click.ParamType): """Expands a pattern following the Python format specification to matching files""" name = 'raster-pattern' def convert(self, value: str, *args: Any) -> RasterPatternType: - value = os.path.realpath(value) - try: - parsed_value = list(string.Formatter().parse(value)) + keys, glob_pattern, regex_pattern = _parse_raster_pattern(value) except ValueError as exc: self.fail(f'Invalid pattern: {exc!s}') - # extract keys from format string and assemble glob and regex patterns matching it - keys = [] - glob_pattern = '' - regex_pattern = '' - for before_field, field_name, _, _ in parsed_value: - glob_pattern += before_field - regex_pattern += re.escape(before_field) - if field_name is None: # no placeholder - continue - glob_pattern += '*' - if field_name == '': # unnamed placeholder - regex_pattern += '.*?' - elif field_name in keys: # duplicate placeholder - key_group_number = keys.index(field_name) + 1 - regex_pattern += f'\\{key_group_number}' - else: # new placeholder - keys.append(field_name) - regex_pattern += f'(?P<{field_name}>[^\\W_]+)' - if not keys: self.fail('Pattern must contain at least one placeholder') @@ -69,7 +92,7 @@ def convert(self, value: str, *args: Any) -> RasterPatternType: self.fail('Key names must be alphanumeric') # use glob to find candidates, regex to extract placeholder values - candidates = map(os.path.realpath, glob.glob(glob_pattern)) + candidates = glob.glob(glob_pattern) matched_candidates = [re.match(regex_pattern, candidate) for candidate in candidates] if not any(matched_candidates): diff --git a/terracotta/scripts/connect.py b/terracotta/scripts/connect.py index a9a27d3d..1875e9e2 100644 --- a/terracotta/scripts/connect.py +++ b/terracotta/scripts/connect.py @@ -10,7 +10,7 @@ import click -from terracotta.scripts.click_utils import Hostname +from terracotta.scripts.click_types import Hostname from terracotta.scripts.http_utils import find_open_port diff --git a/terracotta/scripts/create_database.py b/terracotta/scripts/ingest.py similarity index 68% rename from terracotta/scripts/create_database.py rename to terracotta/scripts/ingest.py index 492f3067..7d258faf 100644 --- a/terracotta/scripts/create_database.py +++ b/terracotta/scripts/ingest.py @@ -1,4 +1,4 @@ -"""scripts/create_database.py +"""scripts/ingest.py A convenience tool to create a Terracotta database from some raster files. """ @@ -10,32 +10,29 @@ import click import tqdm -from terracotta.scripts.click_utils import RasterPattern, RasterPatternType, PathlibPath +from terracotta.scripts.click_types import RasterPattern, RasterPatternType, PathlibPath logger = logging.getLogger(__name__) -@click.command('create-database', - short_help='Create a new SQLite raster database from a collection of raster files.') +@click.command('ingest', + short_help='Ingest a collection of raster files into a SQLite database.') @click.argument('raster-pattern', type=RasterPattern(), required=True) @click.option('-o', '--output-file', required=True, help='Path to output file', type=PathlibPath(dir_okay=False, writable=True)) -@click.option('--overwrite', is_flag=True, default=False, - help='Always overwrite existing database without asking') @click.option('--skip-metadata', is_flag=True, default=False, - help='Speed up ingestion by not pre-computing metadata ' + help='Speed up ingestion by skipping computation of metadata ' '(will be computed on first request instead)') @click.option('--rgb-key', default=None, help='Key to use for RGB compositing [default: last key in pattern]') @click.option('-q', '--quiet', is_flag=True, default=False, show_default=True, help='Suppress all output to stdout') -def create_database(raster_pattern: RasterPatternType, - output_file: Path, - overwrite: bool = False, - skip_metadata: bool = False, - rgb_key: str = None, - quiet: bool = False) -> None: - """Create a new SQLite raster database from a collection of raster files. +def ingest(raster_pattern: RasterPatternType, + output_file: Path, + skip_metadata: bool = False, + rgb_key: str = None, + quiet: bool = False) -> None: + """Ingest a collection of raster files into a SQLite database. First argument is a format pattern defining paths and keys of all raster files. @@ -45,16 +42,13 @@ def create_database(raster_pattern: RasterPatternType, The empty group {} is replaced by a wildcard matching anything (similar to * in glob patterns). + Existing datasets are silently overwritten. + This command only supports the creation of a simple SQLite database without any additional metadata. For more sophisticated use cases use the Terracotta Python API. """ from terracotta import get_driver - if output_file.is_file() and not overwrite: - click.confirm(f'Existing output file {output_file} will be overwritten. Continue?', - abort=True) - output_file.unlink() - keys, raster_files = raster_pattern if rgb_key is not None: @@ -71,7 +65,16 @@ def push_to_last(seq: Sequence[Any], index: int) -> Tuple[Any, ...]: raster_files = {push_to_last(k, rgb_idx): v for k, v in raster_files.items()} driver = get_driver(output_file) - driver.create(keys) + + if not output_file.is_file(): + driver.create(keys) + + if tuple(keys) != driver.key_names: + click.echo( + f'Database file {output_file!s} has incompatible key names {driver.key_names}', + err=True + ) + click.Abort() with driver.connect(): progress = tqdm.tqdm(raster_files.items(), desc='Ingesting raster files', disable=quiet) diff --git a/terracotta/scripts/optimize_rasters.py b/terracotta/scripts/optimize_rasters.py index 5e99ae0a..1fb9a025 100644 --- a/terracotta/scripts/optimize_rasters.py +++ b/terracotta/scripts/optimize_rasters.py @@ -3,7 +3,7 @@ Convert some raster files to cloud-optimized GeoTiff for use with Terracotta. """ -from typing import Sequence, Dict, Any +from typing import Sequence, Iterator, Union import os import math import itertools @@ -17,7 +17,7 @@ from rasterio.vrt import WarpedVRT from rasterio.enums import Resampling -from terracotta.scripts.click_utils import GlobbityGlob, PathlibPath +from terracotta.scripts.click_types import GlobbityGlob, PathlibPath logger = logging.getLogger(__name__) @@ -39,9 +39,9 @@ 'tiled': True, 'blockxsize': 256, 'blockysize': 256, - #'compress': 'DEFLATE', - 'ZLEVEL': 1, 'photometric': 'MINISBLACK', + 'ZLEVEL': 1, + 'ZSTD_LEVEL': 9, 'BIGTIFF': 'IF_SAFER' } @@ -53,8 +53,13 @@ } -def _prefered_compression_method() -> Dict[str, Any]: - import rasterio +def _prefered_compression_method() -> str: + from rasterio.env import GDALVersion + + if GDALVersion.runtime() < GDALVersion.parse('2.3'): + return 'ZSTD' + + return 'DEFLATE' def _get_vrt(src: DatasetReader, rs_method: int) -> WarpedVRT: @@ -71,10 +76,10 @@ def _get_vrt(src: DatasetReader, rs_method: int) -> WarpedVRT: @contextlib.contextmanager -def named_tempfile(basedir: str = None) -> str: +def _named_tempfile(basedir: Union[str, Path] = None) -> Iterator[str]: if basedir is None: basedir = tempfile.gettempdir() - fileobj = tempfile.NamedTemporaryFile(dir=basedir, suffix='.tif') + fileobj = tempfile.NamedTemporaryFile(dir=str(basedir), suffix='.tif') fileobj.close() try: yield fileobj.name @@ -82,26 +87,43 @@ def named_tempfile(basedir: str = None) -> str: os.remove(fileobj.name) -TemporaryRasterFile = named_tempfile +TemporaryRasterFile = _named_tempfile -@click.command('optimize-rasters', - short_help='Optimize a collection of raster files for use with Terracotta.') +@click.command( + 'optimize-rasters', + short_help='Optimize a collection of raster files for use with Terracotta.' +) @click.argument('raster-files', nargs=-1, type=GlobbityGlob(), required=True) -@click.option('-o', '--output-folder', required=True, - type=PathlibPath(file_okay=False, writable=True), - help='Output folder for cloud-optimized rasters. Subdirectories will be flattened.') -@click.option('--overwrite', is_flag=True, default=False, help='Force overwrite of existing files') -@click.option('--resampling-method', type=click.Choice(RESAMPLING_METHODS.keys()), - default='average', help='Resampling method for overviews', show_default=True) -@click.option('--reproject', is_flag=True, default=False, show_default=True, - help='Reproject raster file to Web Mercator for faster access') -@click.option('--in-memory/--no-in-memory', default=None, - help='Force processing raster in memory / not in memory [default: process in memory ' - f'if smaller than {IN_MEMORY_THRESHOLD // 1e6:.0f} million pixels]') -@click.option('--compression', default='auto', type=click.Choice(['auto', 'deflate', 'lzw', 'zstd', 'none'])) -@click.option('-q', '--quiet', is_flag=True, default=False, show_default=True, - help='Suppress all output to stdout') +@click.option( + '-o', '--output-folder', required=True, + type=PathlibPath(file_okay=False, writable=True), + help='Output folder for cloud-optimized rasters. Subdirectories will be flattened.' +) +@click.option( + '--overwrite', is_flag=True, default=False, help='Force overwrite of existing files' +) +@click.option( + '--resampling-method', type=click.Choice(RESAMPLING_METHODS.keys()), + default='average', help='Resampling method for overviews', show_default=True +) +@click.option( + '--reproject', is_flag=True, default=False, show_default=True, + help='Reproject raster file to Web Mercator for faster access' +) +@click.option( + '--in-memory/--no-in-memory', default=None, + help='Force processing raster in memory / not in memory [default: process in memory ' + f'if smaller than {IN_MEMORY_THRESHOLD // 1e6:.0f} million pixels]' +) +@click.option( + '--compression', default='auto', type=click.Choice(['auto', 'deflate', 'lzw', 'zstd', 'none']), + help='Compression algorithm to use [default: auto (ZSTD if available, DEFLATE otherwise)' +) +@click.option( + '-q', '--quiet', is_flag=True, default=False, show_default=True, + help='Suppress all output to stdout' +) def optimize_rasters(raster_files: Sequence[Sequence[Path]], output_folder: Path, overwrite: bool = False, @@ -126,18 +148,27 @@ def optimize_rasters(raster_files: Sequence[Sequence[Path]], from rasterio.shutil import copy raster_files_flat = sorted(set(itertools.chain.from_iterable(raster_files))) - rs_method = RESAMPLING_METHODS[resampling_method] if not raster_files_flat: click.echo('No files given') return + rs_method = RESAMPLING_METHODS[resampling_method] + + if compression == 'auto': + compression = _prefered_compression_method() + total_pixels = 0 for f in raster_files_flat: if not f.is_file(): raise click.BadParameter(f'Input raster {f!s} is not a file') with rasterio.open(str(f), 'r') as src: + if src.count > 1 and not quiet: + click.echo( + f'Warning: raster file {f!s} has more than one band. ' + 'Only the first one will be used.', err=True + ) total_pixels += src.height * src.width output_folder.mkdir(exist_ok=True) @@ -146,7 +177,14 @@ def optimize_rasters(raster_files: Sequence[Sequence[Path]], # insert newline for nicer progress bar style click.echo('') - with tqdm.tqdm(total=total_pixels, smoothing=0, unit_scale=True, disable=quiet, desc='Optimizing rasters') as pbar, rasterio.Env(**GDAL_CONFIG): + with contextlib.ExitStack() as outer_env: + pbar = outer_env.enter_context(tqdm.tqdm( + total=total_pixels, smoothing=0, disable=quiet, + bar_format='{l_bar}{bar}| [{elapsed}<{remaining}{postfix}]', + desc='Optimizing rasters' + )) + outer_env.enter_context(rasterio.Env(**GDAL_CONFIG)) + for input_file in raster_files_flat: if len(input_file.name) > 30: short_name = input_file.name[:13] + '...' + input_file.name[-13:] @@ -165,12 +203,6 @@ def optimize_rasters(raster_files: Sequence[Sequence[Path]], with contextlib.ExitStack() as es: src = es.enter_context(rasterio.open(str(input_file))) - if src.count > 1: - click.echo( - f'Warning: raster file {input_file!s} has more than one band. ' - 'Only the first one will be used.', err=True - ) - if reproject: vrt = es.enter_context(_get_vrt(src, rs_method=rs_method)) else: @@ -192,7 +224,7 @@ def optimize_rasters(raster_files: Sequence[Sequence[Path]], # iterate over blocks windows = list(dst.block_windows(1)) - for _, w in tqdm.tqdm(windows, desc='Reading'): + for _, w in tqdm.tqdm(windows, desc='Reading', leave=False): block_data = vrt.read(window=w, indexes=[1]) dst.write(block_data, window=w) block_mask = vrt.dataset_mask(window=w) @@ -205,13 +237,16 @@ def optimize_rasters(raster_files: Sequence[Sequence[Path]], ))) overviews = [2 ** j for j in range(1, max_overview_level + 1)] - for overview in tqdm.tqdm(overviews, desc='Creating overviews'): + for overview in tqdm.tqdm(overviews, desc='Creating overviews', leave=False): dst.build_overviews([overview], rs_method) - dst.update_tags(ns='tc_overview', resampling=rs_method.value) + + dst.update_tags(ns='rio_overview', resampling=rs_method.value) # copy to destination (this is necessary to push overviews to start of file) - with tqdm.tqdm(desc='Compressing') as compbar: - copy(dst, str(output_file), copy_src_overviews=True, **COG_PROFILE, compress='deflate') - compbar.update(1) + with tqdm.tqdm(desc='Compressing', leave=False): + copy( + dst, str(output_file), copy_src_overviews=True, + compress=compression, **COG_PROFILE + ) pbar.update(dst.height * dst.width) diff --git a/terracotta/scripts/serve.py b/terracotta/scripts/serve.py index ad573002..c28c7adc 100644 --- a/terracotta/scripts/serve.py +++ b/terracotta/scripts/serve.py @@ -10,7 +10,7 @@ import click import tqdm -from terracotta.scripts.click_utils import RasterPattern, RasterPatternType +from terracotta.scripts.click_types import RasterPattern, RasterPatternType from terracotta.scripts.http_utils import find_open_port logger = logging.getLogger(__name__) @@ -41,11 +41,13 @@ def serve(database: str = None, rgb_key: str = None) -> None: """Serve rasters through a local Flask development server. - Either --database or --raster-pattern must be given. + Either -d/--database or -r/--raster-pattern must be given. Example: - terracotta serve -r /path/to/rasters/{name}/{date}_{band}.tif + terracotta serve -r /path/to/rasters/{name}/{date}_{band}_{}.tif + + The empty group {} is replaced by a wildcard matching anything (similar to * in glob patterns). This command is a data exploration tool and not meant for production use. Deploy Terracotta as a WSGI or serverless app instead. diff --git a/tests/handlers/test_colormap.py b/tests/handlers/test_colormap.py index 30dbd34f..4d765271 100644 --- a/tests/handlers/test_colormap.py +++ b/tests/handlers/test_colormap.py @@ -58,9 +58,9 @@ def test_colormap_consistency(use_read_only_database, read_only_database, raster # test values inside stretch_range values_to_test = np.unique(tile_data) - values_to_test = values_to_test[(values_to_test >= stretch_range[0]) & - (values_to_test <= stretch_range[1]) & - (values_to_test != nodata)] + values_to_test = values_to_test[(values_to_test >= stretch_range[0]) + & (values_to_test <= stretch_range[1]) + & (values_to_test != nodata)] for val in values_to_test: rgb = cmap[val] diff --git a/tests/scripts/test_create_database.py b/tests/scripts/test_ingest.py similarity index 72% rename from tests/scripts/test_create_database.py rename to tests/scripts/test_ingest.py index 9aa72fdd..724be36d 100644 --- a/tests/scripts/test_create_database.py +++ b/tests/scripts/test_ingest.py @@ -110,14 +110,14 @@ def tmpworkdir(tmpdir): os.chdir(orig_dir) -def test_create_database(raster_file, tmpdir): +def test_ingest(raster_file, tmpdir): from terracotta.scripts import cli outfile = tmpdir / 'out.sqlite' input_pattern = str(raster_file.dirpath('{name}.tif')) runner = CliRunner() - result = runner.invoke(cli.cli, ['create-database', input_pattern, '-o', str(outfile)]) + result = runner.invoke(cli.cli, ['ingest', input_pattern, '-o', str(outfile)]) assert result.exit_code == 0 assert outfile.check() @@ -127,9 +127,54 @@ def test_create_database(raster_file, tmpdir): assert driver.get_datasets() == {('img',): str(raster_file)} +def test_ingest_append(raster_file, tmpworkdir): + from terracotta.scripts import cli + + for infile in ('dir1/img1.tif', 'dir2/img2.tif'): + temp_infile = tmpworkdir / infile + os.makedirs(temp_infile.dirpath(), exist_ok=True) + shutil.copy(raster_file, temp_infile) + + outfile = tmpworkdir / 'out.sqlite' + + runner = CliRunner() + result = runner.invoke(cli.cli, ['ingest', 'dir1/{name}.tif', '-o', str(outfile)]) + assert result.exit_code == 0 + assert outfile.check() + + result = runner.invoke(cli.cli, ['ingest', 'dir2/{name}.tif', '-o', str(outfile)]) + assert result.exit_code == 0 + assert outfile.check() + + from terracotta import get_driver + driver = get_driver(str(outfile), provider='sqlite') + assert driver.key_names == ('name',) + assert all((ds,) in driver.get_datasets() for ds in ('img1', 'img2')) + + +def test_ingest_append_invalid(raster_file, tmpworkdir): + from terracotta.scripts import cli + + for infile in ('dir1/img1.tif', 'dir2/img2.tif'): + temp_infile = tmpworkdir / infile + os.makedirs(temp_infile.dirpath(), exist_ok=True) + shutil.copy(raster_file, temp_infile) + + outfile = tmpworkdir / 'out.sqlite' + + runner = CliRunner() + result = runner.invoke(cli.cli, ['ingest', 'dir1/{name}.tif', '-o', str(outfile)]) + assert result.exit_code == 0 + assert outfile.check() + + result = runner.invoke(cli.cli, ['ingest', '{dir}/{name}.tif', '-o', str(outfile)]) + assert result.exit_code != 0 + assert 'incompatible key names' in result.output + + @pytest.mark.parametrize('case', TEST_CASES) @pytest.mark.parametrize('abspath', [True, False]) -def test_create_database_pattern(case, abspath, raster_file, tmpworkdir): +def test_ingest_pattern(case, abspath, raster_file, tmpworkdir): from terracotta.scripts import cli for infile in case['filenames']: @@ -145,7 +190,7 @@ def test_create_database_pattern(case, abspath, raster_file, tmpworkdir): input_pattern = case['input_pattern'] runner = CliRunner() - result = runner.invoke(cli.cli, ['create-database', input_pattern, '-o', str(outfile)]) + result = runner.invoke(cli.cli, ['ingest', input_pattern, '-o', str(outfile)]) assert result.exit_code == 0, result.output assert outfile.check() @@ -156,7 +201,7 @@ def test_create_database_pattern(case, abspath, raster_file, tmpworkdir): @pytest.mark.parametrize('case', INVALID_TEST_CASES) -def test_create_database_invalid_pattern(case, raster_file, tmpworkdir): +def test_ingest_invalid_pattern(case, raster_file, tmpworkdir): from terracotta.scripts import cli for infile in case['filenames']: @@ -168,12 +213,12 @@ def test_create_database_invalid_pattern(case, raster_file, tmpworkdir): input_pattern = case['input_pattern'] runner = CliRunner() - result = runner.invoke(cli.cli, ['create-database', input_pattern, '-o', str(outfile)]) + result = runner.invoke(cli.cli, ['ingest', input_pattern, '-o', str(outfile)]) assert result.exit_code != 0 assert case['error_contains'].lower() in result.output.lower() -def test_create_database_rgb_key(raster_file, tmpdir): +def test_ingest_rgb_key(raster_file, tmpdir): from terracotta.scripts import cli outfile = tmpdir / 'out.sqlite' @@ -181,7 +226,7 @@ def test_create_database_rgb_key(raster_file, tmpdir): runner = CliRunner() result = runner.invoke( - cli.cli, ['create-database', input_pattern, '-o', str(outfile), '--rgb-key', 'rgb'] + cli.cli, ['ingest', input_pattern, '-o', str(outfile), '--rgb-key', 'rgb'] ) assert result.exit_code == 0 assert outfile.check() @@ -192,7 +237,7 @@ def test_create_database_rgb_key(raster_file, tmpdir): assert driver.get_datasets() == {('g', 'i'): str(raster_file)} -def test_create_database_invalid_rgb_key(raster_file, tmpdir): +def test_ingest_invalid_rgb_key(raster_file, tmpdir): from terracotta.scripts import cli outfile = tmpdir / 'out.sqlite' @@ -200,7 +245,7 @@ def test_create_database_invalid_rgb_key(raster_file, tmpdir): runner = CliRunner() result = runner.invoke( - cli.cli, ['create-database', input_pattern, '-o', str(outfile), '--rgb-key', 'bar'] + cli.cli, ['ingest', input_pattern, '-o', str(outfile), '--rgb-key', 'bar'] ) assert result.exit_code != 0 assert not outfile.check() diff --git a/tests/scripts/test_optimize_rasters.py b/tests/scripts/test_optimize_rasters.py index 5b7528f0..486d7818 100644 --- a/tests/scripts/test_optimize_rasters.py +++ b/tests/scripts/test_optimize_rasters.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize('in_memory', [True, None, False]) @pytest.mark.parametrize('reproject', [True, False]) -def test_optimize_rasters(unoptimized_raster_file, tmpdir, in_memory, reproject): +@pytest.mark.parametrize('compression', ['auto', 'lzw', 'none']) +def test_optimize_rasters(unoptimized_raster_file, tmpdir, in_memory, reproject, compression): from terracotta.cog import validate from terracotta.scripts import cli @@ -19,7 +20,7 @@ def test_optimize_rasters(unoptimized_raster_file, tmpdir, in_memory, reproject) runner = CliRunner() - flags = [] + flags = ['--compression', compression] if in_memory is not None: flags.append('--in-memory' if in_memory else '--no-in-memory') @@ -65,3 +66,34 @@ def test_optimize_rasters_invalid(tmpdir): assert result.exit_code != 0 assert 'not a file' in result.output + + +def test_optimize_rasters_multiband(tmpdir, unoptimized_raster_file): + from terracotta.scripts import cli + import rasterio + + with rasterio.open(str(unoptimized_raster_file)) as src: + profile = src.profile.copy() + data = src.read(1) + + profile['count'] = 3 + + multiband_file = tmpdir.join(unoptimized_raster_file.basename) + with rasterio.open(str(multiband_file), 'w', **profile) as dst: + dst.write(data, 1) + dst.write(data, 2) + dst.write(data, 3) + + input_pattern = str(multiband_file.dirpath('*.tif')) + outfile = tmpdir / 'co' / unoptimized_raster_file.basename + + runner = CliRunner() + result = runner.invoke(cli.cli, ['optimize-rasters', input_pattern, '-o', str(tmpdir / 'co')]) + + assert result.exit_code == 0 + assert 'has more than one band' in result.output + + with rasterio.open(str(unoptimized_raster_file)) as src1, rasterio.open(str(outfile)) as src2: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'invalid value encountered.*') + np.testing.assert_array_equal(src1.read(), src2.read()) diff --git a/tests/server/test_flask_api.py b/tests/server/test_flask_api.py index 827b3b5e..4f84c4db 100644 --- a/tests/server/test_flask_api.py +++ b/tests/server/test_flask_api.py @@ -48,7 +48,7 @@ def test_get_datasets(client, use_read_only_database): rv = client.get('/datasets') assert rv.status_code == 200 datasets = json.loads(rv.data, object_pairs_hook=OrderedDict)['datasets'] - assert len(datasets) == 3 + assert len(datasets) == 4 assert OrderedDict([('key1', 'val11'), ('akey', 'x'), ('key2', 'val12')]) in datasets