diff --git a/conf/env.py b/conf/env.py index 90b9ec02..b09ab38f 100644 --- a/conf/env.py +++ b/conf/env.py @@ -218,6 +218,9 @@ class BaseSettings(PydanticBaseSettings): eyb_salary_s3_prefix: str = '' eyb_rent_s3_prefix: str = '' postcode_from_s3_prefix: str = '' + nomis_uk_business_employee_counts_from_s3_prefix: str = '' + ref_sic_codes_mapping_from_s3_prefix: str = '' + sector_reference_dataset_from_s3_prefix: str = '' class CIEnvironment(BaseSettings): diff --git a/conf/settings.py b/conf/settings.py index f3419241..1bf7d0c3 100644 --- a/conf/settings.py +++ b/conf/settings.py @@ -613,3 +613,6 @@ EYB_SALARY_S3_PREFIX = env.eyb_salary_s3_prefix EYB_RENT_S3_PREFIX = env.eyb_rent_s3_prefix POSTCODE_FROM_S3_PREFIX = env.postcode_from_s3_prefix +NOMIS_UK_BUSINESS_EMPLOYEE_COUNTS_FROM_S3_PREFIX = env.nomis_uk_business_employee_counts_from_s3_prefix +REF_SIC_CODES_MAPPING_FROM_S3_PREFIX = env.ref_sic_codes_mapping_from_s3_prefix +SECTOR_REFERENCE_DATASET_FROM_S3_PREFIX = env.sector_reference_dataset_from_s3_prefix diff --git a/dataservices/management/commands/import_eyb_business_cluster_information.py b/dataservices/management/commands/import_eyb_business_cluster_information.py index 89b8fcdc..bed1b58f 100644 --- a/dataservices/management/commands/import_eyb_business_cluster_information.py +++ b/dataservices/management/commands/import_eyb_business_cluster_information.py @@ -1,15 +1,214 @@ +import json +import logging + import pandas as pd import sqlalchemy as sa +from django.conf import settings +from django.core.management.base import BaseCommand +from sqlalchemy.ext.declarative import declarative_base -from dataservices.models import EYBBusinessClusterInformation +from dataservices.core.mixins import S3DownloadMixin +from dataservices.management.commands.helpers import ingest_data -from .helpers import BaseDataWorkspaceIngestionCommand +logger = logging.getLogger(__name__) -class Command(BaseDataWorkspaceIngestionCommand): - help = 'Import ONS total UK business and employee counts per region and section, 2 and 5 digit Standard Industrial Classification' # noqa:E501 +def get_uk_business_employee_counts_tmp_batch(data, data_table): + + def get_table_data(): + + for uk_business_employee_count in data: + json_data = json.loads(uk_business_employee_count) + + yield ( + ( + data_table, + ( + json_data['geo_description'], + json_data['geo_code'], + json_data['sic_code'], + json_data['sic_description'], + json_data['total_business_count'], + json_data['business_count_release_year'], + # missing employee data represented as np.nan which results in error saving django model + # columns are int in dataframe so cannot store None resulting in below conditional assignment + ( + json_data['total_employee_count'] + if json_data['total_employee_count'] and json_data['total_employee_count'] > 0 + else None + ), + ( + json_data['employee_count_release_year'] + if json_data['employee_count_release_year'] and json_data['employee_count_release_year'] > 0 + else None + ), + ), + ) + ) + + return ( + None, + None, + get_table_data(), + ) + + +def get_uk_business_employee_counts_batch(data, data_table): + + def get_table_data(): + + for json_data in data: + + if json_data['geo_code'] == 'K02000001': + continue + + yield ( + ( + data_table, + ( + json_data['geo_description'], + json_data['geo_code'], + json_data['sic_code'], + json_data['sic_description'], + json_data['total_business_count'], + json_data['business_count_release_year'], + # missing employee data represented as np.nan which results in error saving 
django model + # columns are int in dataframe so cannot store None resulting in below conditional assignment + ( + json_data['total_employee_count'] + if json_data['total_employee_count'] and json_data['total_employee_count'] > 0 + else None + ), + ( + json_data['employee_count_release_year'] + if json_data['employee_count_release_year'] and json_data['employee_count_release_year'] > 0 + else None + ), + json_data['dbt_full_sector_name'], + json_data['dbt_sector_name'], + ), + ) + ) + + return ( + None, + None, + get_table_data(), + ) + + +def get_uk_business_employee_counts_postgres_tmp_table(metadata, table_name): + return sa.Table( + table_name, + metadata, + sa.Column("geo_description", sa.TEXT, nullable=False), + sa.Column("geo_code", sa.TEXT, nullable=False), + sa.Column("sic_code", sa.TEXT, nullable=False), + sa.Column("sic_description", sa.TEXT, nullable=False), + sa.Column("total_business_count", sa.INTEGER, nullable=True), + sa.Column("business_count_release_year", sa.SMALLINT, nullable=True), + sa.Column("total_employee_count", sa.INTEGER, nullable=True), + sa.Column("employee_count_release_year", sa.SMALLINT, nullable=True), + sa.Index(None, "sic_code"), + schema="public", + ) + + +def get_uk_business_employee_counts_postgres_table(metadata, table_name): + return sa.Table( + table_name, + metadata, + sa.Column("geo_description", sa.TEXT, nullable=False), + sa.Column("geo_code", sa.TEXT, nullable=False), + sa.Column("sic_code", sa.TEXT, nullable=False), + sa.Column("sic_description", sa.TEXT, nullable=False), + sa.Column("total_business_count", sa.INTEGER, nullable=True), + sa.Column("business_count_release_year", sa.SMALLINT, nullable=True), + sa.Column("total_employee_count", sa.INTEGER, nullable=True), + sa.Column("employee_count_release_year", sa.SMALLINT, nullable=True), + sa.Column("dbt_full_sector_name", sa.TEXT, nullable=True), + sa.Column("dbt_sector_name", sa.TEXT, nullable=True), + schema="public", + ) + + +def get_ref_sic_codes_mapping_batch(data, data_table): + + def get_table_data(): + + for ref_sic_codes_mapping in data: + + json_data = json.loads(ref_sic_codes_mapping) + + yield ( + ( + data_table, + ( + json_data['sic_code'], + json_data['dit_sector_list_id'], + ), + ) + ) + + return ( + None, + None, + get_table_data(), + ) + + +def get_sector_reference_dataset_batch(data, data_table): + + def get_table_data(): + + for sector_reference_dataset in data: + + json_data = json.loads(sector_reference_dataset) + + yield ( + ( + data_table, + ( + json_data['id'], + json_data['field_04'], + json_data['full_sector_name'], + ), + ) + ) + + return ( + None, + None, + get_table_data(), + ) + + +def get_ref_sic_codes_mapping_postgres_table(metadata, table_name): + return sa.Table( + table_name, + metadata, + sa.Column("sic_code", sa.INTEGER, nullable=False), + sa.Column("dit_sector_list_id", sa.INTEGER, nullable=True), + sa.Index(None, "dit_sector_list_id"), + schema="public", + ) + + +def get_sector_reference_dataset_postgres_table(metadata, table_name): + return sa.Table( + table_name, + metadata, + sa.Column("id", sa.INTEGER, nullable=False), + sa.Column("field_04", sa.TEXT, nullable=True), + sa.Column("full_sector_name", sa.TEXT, nullable=True), + sa.Index(None, "id"), + schema="public", + ) - sql = ''' + +def save_uk_business_employee_counts_data(): + + sql = """ SELECT nubec.geo_description, nubec.geo_code, @@ -21,42 +220,151 @@ class Command(BaseDataWorkspaceIngestionCommand): nubec.employee_count_release_year, sector_mapping.dbt_full_sector_name, 
sector_mapping.dbt_sector_name - FROM ons.nomis__uk_business_employee_counts nubec - LEFT JOIN ( - SELECT - scmds."DIT full sector name" as dbt_full_sector_name, - scmds."DIT sector" as dbt_sector_name, + FROM public.dataservices_tmp_eybbusinessclusterinformation nubec + LEFT JOIN ( + SELECT + dataservices_tmp_sector_reference.full_sector_name as dbt_full_sector_name, + dataservices_tmp_sector_reference.field_04 as dbt_sector_name, -- necessary because sic codes are stored as integer in source table meaning leading 0 was dropped - substring(((scmds."SIC code" + 100000)::varchar) from 2 for 5) as five_digit_sic - from public.ref_sic_codes_dit_sector_mapping scmds - ) AS sector_mapping + substring(((dataservices_tmp_ref_sic_codes_mapping.sic_code + 100000)::varchar) from 2 for 5) as five_digit_sic -- # noqa:E501 + FROM public.dataservices_tmp_ref_sic_codes_mapping + INNER JOIN public.dataservices_tmp_sector_reference ON public.dataservices_tmp_ref_sic_codes_mapping.dit_sector_list_id = public.dataservices_tmp_sector_reference.id + ) as sector_mapping ON nubec.sic_code = sector_mapping.five_digit_sic - WHERE nubec.geo_code <> 'K02000001' - ''' + """ + + engine = sa.create_engine(settings.DATABASE_URL, future=True) - def load_data(self): - data = [] - chunks = pd.read_sql(sa.text(self.sql), self.engine, chunksize=5000) + data = [] + + with engine.connect() as connection: + chunks = pd.read_sql_query(sa.text(sql), connection, chunksize=5000) for chunk in chunks: - for _idx, row in chunk.iterrows(): + for _, row in chunk.iterrows(): data.append( - EYBBusinessClusterInformation( - geo_description=row.geo_description, - geo_code=row.geo_code, - sic_code=row.sic_code, - sic_description=row.sic_description, - total_business_count=row.total_business_count, - business_count_release_year=row.business_count_release_year, - # missing employee data represented as np.nan which results in error saving django model - # columns are int in dataframe so cannot store None resulting in below conditional assignment - total_employee_count=row.total_employee_count if row.total_employee_count > 0 else None, - employee_count_release_year=( - row.employee_count_release_year if row.employee_count_release_year > 0 else None - ), - dbt_full_sector_name=row.dbt_full_sector_name, - dbt_sector_name=row.dbt_sector_name, - ) + { + 'geo_description': row.geo_description, + 'geo_code': row.geo_code, + 'sic_code': row.sic_code, + 'sic_description': row.sic_description, + 'total_business_count': row.total_business_count, + 'business_count_release_year': row.business_count_release_year, + 'total_employee_count': row.total_employee_count, + 'employee_count_release_year': row.employee_count_release_year, + 'dbt_full_sector_name': row.dbt_full_sector_name, + 'dbt_sector_name': row.dbt_sector_name, + } ) - return data + metadata = sa.MetaData() + + data_table = get_uk_business_employee_counts_postgres_table(metadata, 'dataservices_eybbusinessclusterinformation') + + def on_before_visible(conn, ingest_table, batch_metadata): + pass + + def batches(_): + yield get_uk_business_employee_counts_batch(data, data_table) + + ingest_data(engine, metadata, on_before_visible, batches) + + +def save_uk_business_employee_counts_tmp_data(data): + + table_name = 'dataservices_tmp_eybbusinessclusterinformation' + + engine = sa.create_engine(settings.DATABASE_URL, future=True) + + metadata = sa.MetaData() + + data_table = get_uk_business_employee_counts_postgres_tmp_table(metadata, table_name) + + def on_before_visible(conn, ingest_table, 
batch_metadata): + pass + + def batches(_): + yield get_uk_business_employee_counts_tmp_batch(data, data_table) + + ingest_data(engine, metadata, on_before_visible, batches) + + +def save_ref_sic_codes_mapping_data(data): + + table_name = 'dataservices_tmp_ref_sic_codes_mapping' + + engine = sa.create_engine(settings.DATABASE_URL, future=True) + + metadata = sa.MetaData() + + data_table = get_ref_sic_codes_mapping_postgres_table(metadata, table_name) + + def on_before_visible(conn, ingest_table, batch_metadata): + pass + + def batches(_): + yield get_ref_sic_codes_mapping_batch(data, data_table) + + ingest_data(engine, metadata, on_before_visible, batches) + + +def save_sector_reference_dataset_data(data): + + table_name = 'dataservices_tmp_sector_reference' + + engine = sa.create_engine(settings.DATABASE_URL, future=True) + + metadata = sa.MetaData() + + data_table = get_sector_reference_dataset_postgres_table(metadata, table_name) + + def on_before_visible(conn, ingest_table, batch_metadata): + pass + + def batches(_): + yield get_sector_reference_dataset_batch(data, data_table) + + ingest_data(engine, metadata, on_before_visible, batches) + + +def delete_temp_tables(table_names): + Base = declarative_base() + metadata = sa.MetaData() + engine = sa.create_engine(settings.DATABASE_URL, future=True) + metadata.reflect(bind=engine) + for name in table_names: + table = metadata.tables.get(name, None) + if table is not None: + Base.metadata.drop_all(engine, [table], checkfirst=True) + + +class Command(BaseCommand, S3DownloadMixin): + + help = 'Import ONS total UK business and employee counts per region and section, 2 and 5 digit Standard Industrial Classification' # noqa:E501 + + def handle(self, *args, **options): + + try: + self.do_handle( + prefix=settings.NOMIS_UK_BUSINESS_EMPLOYEE_COUNTS_FROM_S3_PREFIX, + save_func=save_uk_business_employee_counts_tmp_data, + ) + self.do_handle( + prefix=settings.REF_SIC_CODES_MAPPING_FROM_S3_PREFIX, + save_func=save_ref_sic_codes_mapping_data, + ) + self.do_handle( + prefix=settings.SECTOR_REFERENCE_DATASET_FROM_S3_PREFIX, + save_func=save_sector_reference_dataset_data, + ) + save_uk_business_employee_counts_data() + except Exception: + logger.exception("import_eyb_business_cluster_information failed to ingest data from s3") + finally: + delete_temp_tables( + [ + 'dataservices_tmp_eybbusinessclusterinformation', + 'dataservices_tmp_ref_sic_codes_mapping', + 'dataservices_tmp_sector_reference', + ] + ) diff --git a/dataservices/management/commands/import_postcodes_from_s3.py b/dataservices/management/commands/import_postcodes_from_s3.py index cc99bf46..cb0be390 100644 --- a/dataservices/management/commands/import_postcodes_from_s3.py +++ b/dataservices/management/commands/import_postcodes_from_s3.py @@ -31,32 +31,29 @@ def map_eer_to_european_reqion(eer_code: str) -> str: def get_postcode_table_batch(data, data_table): - table_data = ( - ( - data_table, - ( - json.loads(postcode)['id'], - ( - json.loads(postcode)['pcd'].replace(' ', '') - if json.loads(postcode)['pcd'] - else json.loads(postcode)['pcd'] - ), + + def get_table_data(): + for postcode in data: + json_data = json.loads(postcode) + + yield ( ( - json.loads(postcode)['region_name'].strip() - if json.loads(postcode)['region_name'] - else json.loads(postcode)['region_name'] - ), - map_eer_to_european_reqion(json.loads(postcode)['eer']), - datetime.now(), - datetime.now(), - ), - ) - for postcode in data - ) + data_table, + ( + json_data['id'], + (json_data['pcd'].replace(' ', '') if 
json_data['pcd'] else json_data['pcd']), + (json_data['region_name'].strip() if json_data['region_name'] else json_data['region_name']), + map_eer_to_european_reqion(json_data['eer']), + datetime.now(), + datetime.now(), + ), + ) + ) + return ( None, None, - table_data, + get_table_data(), ) diff --git a/dataservices/management/commands/tests/test_import_data.py b/dataservices/management/commands/tests/test_import_data.py index 6464cf3b..d935225b 100644 --- a/dataservices/management/commands/tests/test_import_data.py +++ b/dataservices/management/commands/tests/test_import_data.py @@ -4,7 +4,6 @@ from itertools import cycle, islice from unittest import mock -import numpy as np import pandas as pd import pytest import sqlalchemy @@ -623,44 +622,6 @@ def test_helper_get_dataflow_metadata(): assert result.loc[:, 'source_data_modified_utc'][0] == expected -@pytest.mark.django_db -@mock.patch('pandas.read_sql') -@override_settings(DATA_WORKSPACE_DATASETS_URL='postgresql://') -def test_import_eyb_business_cluster_information(read_sql_mock): - data = { - 'geo_code': ['E92000001', 'N92000002', 'E12000003'], - 'geo_description': ['England', 'Northern Ireland', 'Yorkshire and The Humber'], - 'sic_code': ['42', '01110', '10130'], - 'sic_description': [ - 'Civil Engineering', - 'Growing of cereals (except rice), leguminous crops and oil seeds', - 'Production of meat and poultry meat products', - ], - 'total_business_count': [19070, 170, 55], - 'business_count_release_year': [2023, 2023, 2023], - 'total_employee_count': [159000, np.nan, 8000], - 'employee_count_release_year': [2022, np.nan, 2022], - 'dbt_full_sector_name': [ - None, - 'Agriculture, horticulture, fisheries and pets : Arable crops', - 'Food and drink : Meat products', - ], - 'dbt_sector_name': [None, 'Agriculture, horticulture, fisheries and pets', 'Food and drink'], - } - - read_sql_mock.return_value = [pd.DataFrame(data)] - - assert len(models.EYBBusinessClusterInformation.objects.all()) == 0 - - # dry run - management.call_command('import_eyb_business_cluster_information') - assert len(models.EYBBusinessClusterInformation.objects.all()) == 0 - - # write - management.call_command('import_eyb_business_cluster_information', '--write') - assert len(models.EYBBusinessClusterInformation.objects.all()) == 3 - - @pytest.mark.django_db def test_import_markets_countries_territories(capsys): management.call_command('import_markets_countries_territories', '--write') diff --git a/dataservices/tasks.py b/dataservices/tasks.py index b4fa8c80..86f9bf91 100644 --- a/dataservices/tasks.py +++ b/dataservices/tasks.py @@ -58,7 +58,7 @@ def run_import_dbt_sectors(): @app.task() def run_import_eyb_business_cluster_information(): - call_command('import_eyb_business_cluster_information', '--write') + call_command('import_eyb_business_cluster_information') @app.task() diff --git a/dataservices/tests/conftest.py b/dataservices/tests/conftest.py index 79db7b95..2eab3115 100644 --- a/dataservices/tests/conftest.py +++ b/dataservices/tests/conftest.py @@ -130,7 +130,7 @@ def total_trade_records(countries): @pytest.fixture() def trade_in_services_records(countries): records = [ - {'code': '0', 'name': 'none value', 'exports': None, 'imports': None}, + {'code': '0', 'name': 'null value', 'exports': None, 'imports': None}, {'code': '1', 'name': 'first', 'exports': 6, 'imports': 1}, {'code': '2', 'name': 'second', 'exports': 5, 'imports': 1}, {'code': '3', 'name': 'third', 'exports': 4, 'imports': 1}, @@ -170,7 +170,7 @@ def trade_in_services_records(countries): def 
trade_in_goods_records(countries): for idx, iso2 in enumerate(['DE', 'FR', 'CN']): records = [ - {'code': '0', 'name': 'none value', 'exports': None, 'imports': None}, + {'code': '0', 'name': 'null value', 'exports': None, 'imports': None}, {'code': '1', 'name': 'first', 'exports': 6, 'imports': 1}, {'code': '2', 'name': 'second', 'exports': 5, 'imports': 1}, {'code': '3', 'name': 'third', 'exports': 4, 'imports': 1}, @@ -324,6 +324,76 @@ def business_cluster_information_data(): models.EYBBusinessClusterInformation.objects.all().delete() +@pytest.fixture +def sector_reference_dataset_data(): + yield [ + '{"id": 3, "field_04": "Advanced engineering", "full_sector_name": "Advanced engineering : Metallurgical process plant"}\n', # noqa: E501 + '{"id": 4, "field_04": "Advanced engineering", "full_sector_name": "Advanced engineering : Metals, minerals and materials"}\n', # noqa: E501 + '{"id": 38, "field_04": "Automotive", "full_sector_name": "Automotive"}\n', # noqa: E501 + ] + + +@pytest.fixture +def ref_sic_codes_mapping_data(): + yield [ + '{"id": 1, "sic_code": 1110, "mapping_id": "SIC-SEC-106", "updated_date": "2021-08-19T10:05:34.680837+00:00", "sic_description": "Growing of cereals (except rice), leguminous crops and oil seeds", "dit_sector_list_id": 21}\n', # noqa: E501 + '{"id": 2, "sic_code": 1120, "mapping_id": "SIC-SEC-107", "updated_date": "2021-08-19T10:05:34.689149+00:00", "sic_description": "Growing of rice", "dit_sector_list_id": 21}\n', # noqa: E501 + '{"id": 3, "sic_code": 1130, "mapping_id": "SIC-SEC-129", "updated_date": "2021-08-19T10:05:34.696666+00:00", "sic_description": "Growing of vegetables and melons, roots and tubers", "dit_sector_list_id": 31}\n', # noqa: E501 + ] + + +@pytest.fixture +def uk_business_employee_counts_data(): + yield [ + { + "geo_code": "K02000002", + "sic_code": "01", + "geo_description": "United Kingdom", + "sic_description": "Crop and animal production, hunting and related service activities", + "total_business_count": 132540, + "total_employee_count": None, + "business_count_release_year": 2023, + "employee_count_release_year": None, + "dbt_full_sector_name": "Metallurgical process plant", + "dbt_sector_name": "Advanced engineering", + }, # noqa: E501 + { + "geo_code": "K02000003", + "sic_code": "03", + "geo_description": "United Kingdom", + "sic_description": "Fishing and aquaculture", + "total_business_count": 4070, + "total_employee_count": None, + "business_count_release_year": 2023, + "employee_count_release_year": None, + "dbt_full_sector_name": "Metallurgical process plant", + "dbt_sector_name": "Advanced engineering", + }, # noqa: E501 + { + "geo_code": "K02000004", + "sic_code": "03", + "geo_description": "United Kingdom", + "sic_description": "Fishing and aquaculture", + "total_business_count": 4070, + "total_employee_count": None, + "business_count_release_year": 2023, + "employee_count_release_year": None, + "dbt_full_sector_name": "Automotive", + "dbt_sector_name": "Automotive", + }, # noqa: E501 + ] + + +@pytest.fixture +def uk_business_employee_counts_str_data(uk_business_employee_counts_data): + + data = [] + for line in uk_business_employee_counts_data: + line = json.dumps(line) + data.append(line) + yield data + + @pytest.fixture def eyb_salary_s3_data(): yield [ diff --git a/dataservices/tests/test_ingestion_pattern.py b/dataservices/tests/test_ingestion_pattern.py index 2308f059..6fcd144a 100644 --- a/dataservices/tests/test_ingestion_pattern.py +++ b/dataservices/tests/test_ingestion_pattern.py @@ -48,6 +48,20 @@ 
get_sectors_gva_value_bands_table, save_sectors_gva_value_bands_data, ) +from dataservices.management.commands.import_eyb_business_cluster_information import ( + get_ref_sic_codes_mapping_batch, + get_ref_sic_codes_mapping_postgres_table, + get_sector_reference_dataset_batch, + get_sector_reference_dataset_postgres_table, + get_uk_business_employee_counts_batch, + get_uk_business_employee_counts_postgres_table, + get_uk_business_employee_counts_postgres_tmp_table, + get_uk_business_employee_counts_tmp_batch, + save_ref_sic_codes_mapping_data, + save_sector_reference_dataset_data, + save_uk_business_employee_counts_tmp_data, +) + dbsector_data = [ { @@ -253,6 +267,54 @@ def test_import_postcode_data_set_from_s3( assert mock_save_postcode_data.call_count == 1 +uk_business_employee_counts = [ + { + "geo_code": "K02000002", + "sic_code": "01", + "geo_description": "United Kingdom", + "sic_description": "Crop and animal production, hunting and related service activities", + "total_business_count": 132540, + "total_employee_count": None, + "business_count_release_year": 2023, + "employee_count_release_year": None, + "dbt_full_sector_name": "Metallurgical process plant", + "dbt_sector_name": "Advanced engineering", + }, +] + + +@pytest.mark.django_db +@pytest.mark.parametrize("get_s3_file_data", [uk_business_employee_counts[0]], indirect=True) +@mock.patch( + 'dataservices.management.commands.import_eyb_business_cluster_information.save_uk_business_employee_counts_data' +) # noqa:E501 +@mock.patch( + 'dataservices.management.commands.import_eyb_business_cluster_information.save_uk_business_employee_counts_tmp_data' +) # noqa:E501 +@mock.patch( + 'dataservices.management.commands.import_eyb_business_cluster_information.save_ref_sic_codes_mapping_data' +) # noqa:E501 +@mock.patch( + 'dataservices.management.commands.import_eyb_business_cluster_information.save_sector_reference_dataset_data' +) # noqa:E501 +@mock.patch('dataservices.core.mixins.get_s3_file') +@mock.patch('dataservices.core.mixins.get_s3_paginator') +def test_import_eyb_business_cluster_information_from_s3( + mock_get_s3_paginator, + mock_get_s3_file, + mock_save_sector_reference_dataset_data, + mock_save_ref_sic_codes_mapping_data, + mock_save_uk_business_employee_counts_tmp_data, + mock_save_uk_business_employee_counts_data, + get_s3_file_data, + get_s3_data_transfer_data, +): + mock_get_s3_file.return_value = get_s3_file_data + mock_get_s3_paginator.return_value = get_s3_data_transfer_data + management.call_command('import_eyb_business_cluster_information') + assert mock_save_uk_business_employee_counts_data.call_count == 1 + + @pytest.mark.django_db @override_settings(DATABASE_URL='postgresql://') @mock.patch.object(pg_bulk_ingest, 'ingest', return_value=None) @@ -357,6 +419,73 @@ def test_get_postcode_batch(postcode_data): assert next(ret[2]) is not None +@pytest.mark.django_db +@override_settings(DATABASE_URL='postgresql://') +@mock.patch.object(pg_bulk_ingest, 'ingest', return_value=None) +@mock.patch.object(Engine, 'connect') +def test_save_get_uk_business_employee_counts_tmp(mock_connection, mock_ingest, uk_business_employee_counts_str_data): + mock_connection.return_value.__enter__.return_value = mock.MagicMock() + save_uk_business_employee_counts_tmp_data(data=uk_business_employee_counts_str_data) + assert mock_ingest.call_count == 1 + + +@pytest.mark.django_db +def test_get_uk_business_employee_counts_tmp_batch(uk_business_employee_counts_str_data): + metadata = sa.MetaData() + ret = 
get_uk_business_employee_counts_tmp_batch(
+        uk_business_employee_counts_str_data,
+        get_uk_business_employee_counts_postgres_tmp_table(metadata, 'tmp_nomis_table'),
+    )
+    assert next(ret[2]) is not None
+
+
+@pytest.mark.django_db
+def test_get_uk_business_employee_counts_batch(uk_business_employee_counts_data):
+    metadata = sa.MetaData()
+    ret = get_uk_business_employee_counts_batch(
+        uk_business_employee_counts_data, get_uk_business_employee_counts_postgres_table(metadata, 'nomis_table')
+    )
+    assert next(ret[2]) is not None
+
+
+@pytest.mark.django_db
+@override_settings(DATABASE_URL='postgresql://')
+@mock.patch.object(pg_bulk_ingest, 'ingest', return_value=None)
+@mock.patch.object(Engine, 'connect')
+def test_save_ref_sic_codes_mapping(mock_connection, mock_ingest, ref_sic_codes_mapping_data):
+    mock_connection.return_value.__enter__.return_value = mock.MagicMock()
+    save_ref_sic_codes_mapping_data(data=ref_sic_codes_mapping_data)
+    assert mock_ingest.call_count == 1
+
+
+@pytest.mark.django_db
+def test_get_ref_sic_codes_mapping_batch(ref_sic_codes_mapping_data):
+    metadata = sa.MetaData()
+    ret = get_ref_sic_codes_mapping_batch(
+        ref_sic_codes_mapping_data, get_ref_sic_codes_mapping_postgres_table(metadata, 'tmp_sic_codes_mapping_ref')
+    )
+    assert next(ret[2]) is not None
+
+
+@pytest.mark.django_db
+@override_settings(DATABASE_URL='postgresql://')
+@mock.patch.object(pg_bulk_ingest, 'ingest', return_value=None)
+@mock.patch.object(Engine, 'connect')
+def test_save_sector_reference_dataset(mock_connection, mock_ingest, sector_reference_dataset_data):
+    mock_connection.return_value.__enter__.return_value = mock.MagicMock()
+    save_sector_reference_dataset_data(data=sector_reference_dataset_data)
+    assert mock_ingest.call_count == 1
+
+
+@pytest.mark.django_db
+def test_get_sector_reference_dataset_batch(sector_reference_dataset_data):
+    metadata = sa.MetaData()
+    ret = get_sector_reference_dataset_batch(
+        sector_reference_dataset_data, get_sector_reference_dataset_postgres_table(metadata, 'tmp_sector_ref')
+    )
+    assert next(ret[2]) is not None
+
+
 @pytest.mark.django_db
 @patch.object(Paginator, 'paginate')
 def test_get_s3_paginator(mock_paginate, get_s3_data_transfer_data):