
Commit

drop tables in init_client
steinitzu committed Apr 8, 2024
1 parent 2da9c01 commit f1df8de
Showing 5 changed files with 35 additions and 176 deletions.
17 changes: 5 additions & 12 deletions dlt/extract/extract.py
@@ -394,22 +394,15 @@ def extract(
                 state_paths="*" if self.refresh == "drop_dataset" else [],
             )
             _state.update(new_state)
-            drop_schema = source.schema.clone()
             if drop_info["tables"]:
-                drop_tables = {
-                    key: table
-                    for key, table in source.schema.tables.items()
+                drop_tables = [
+                    table
+                    for table in source.schema.tables.values()
                     if table["name"] in drop_info["tables"]
-                    or table["name"] in drop_schema.dlt_table_names()
-                }
-
-                drop_schema.tables.clear()
-                drop_schema.tables.update(drop_tables)
-                load_package.state["drop_schema"] = drop_schema.to_dict()
+                ]
+                load_package.state["dropped_tables"] = drop_tables
             source.schema.tables.clear()
             source.schema.tables.update(new_schema.tables)
-            # dropped_tables = load_package.state.setdefault("dropped_tables", [])
-            # dropped_tables.extend(drop_info["tables"])

         # reset resource states, the `extracted` list contains all the explicit resources and all their parents
         for resource in source.resources.extracted.values():
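Net effect of the extract.py change: instead of cloning the source schema into a separate "drop_schema", the extract step now stores the full table schemas selected for dropping under the `dropped_tables` key of the load package state. A minimal standalone sketch of the resulting state shape, using made-up tables (real `TTableSchema` dicts carry more keys, e.g. columns and write disposition):

# Hypothetical, simplified table schemas keyed by name.
schema_tables = {
    "users": {"name": "users"},
    "orders": {"name": "orders"},
    "_dlt_loads": {"name": "_dlt_loads"},
}
drop_info = {"tables": ["orders"]}

# Same filtering as above: keep whole table schemas, not just their names.
dropped_tables = [t for t in schema_tables.values() if t["name"] in drop_info["tables"]]
load_package_state = {"dropped_tables": dropped_tables}
print(load_package_state)  # {'dropped_tables': [{'name': 'orders'}]}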
68 changes: 7 additions & 61 deletions dlt/load/load.py
@@ -345,75 +345,28 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
             f"All jobs completed, archiving package {load_id} with aborted set to {aborted}"
         )

-    # def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
-    #     """When using refresh mode, drop tables if possible.
-    #     Returns a set of tables for main destination and staging destination
-    #     that could not be dropped and should be truncated instead
-    #     """
-    #     # Exclude tables already dropped in the same load
-    #     drop_tables = set(dropped_tables) - self._refreshed_tables
-    #     if not drop_tables:
-    #         return set(), set()
-    #     # Clone schema and remove tables from it
-    #     dropped_schema = deepcopy(schema)
-    #     for table_name in drop_tables:
-    #         # pop not del: The table may not actually be in the schema if it's not being loaded again
-    #         dropped_schema.tables.pop(table_name, None)
-    #     dropped_schema._bump_version()
-    #     trunc_dest: Set[str] = set()
-    #     trunc_staging: Set[str] = set()
-    #     # Drop from destination and replace stored schema so tables will be re-created before load
-    #     with self.get_destination_client(dropped_schema) as job_client:
-    #         # TODO: SupportsSql mixin
-    #         if hasattr(job_client, "drop_tables"):
-    #             job_client.drop_tables(*drop_tables, replace_schema=True)
-    #         else:
-    #             # Tables need to be truncated instead of dropped
-    #             trunc_dest = drop_tables
-
-    #     if self.staging_destination:
-    #         with self.get_staging_destination_client(dropped_schema) as staging_client:
-    #             if hasattr(staging_client, "drop_tables"):
-    #                 staging_client.drop_tables(*drop_tables, replace_schema=True)
-    #             else:
-    #                 trunc_staging = drop_tables
-    #     self._refreshed_tables.update(drop_tables)  # Don't drop table again in same load
-    #     return trunc_dest, trunc_staging
-
     def load_single_package(self, load_id: str, schema: Schema) -> None:
         new_jobs = self.get_new_jobs_info(load_id)

-        # dropped_tables = current_load_package()["state"].get("dropped_tables", [])
-        # Drop tables before loading if refresh mode is set
-        # truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
-        drop_schema_dict = current_load_package()["state"].get("drop_schema")
-        drop_schema = Schema.from_dict(drop_schema_dict) if drop_schema_dict else None
-        init_schema = drop_schema if drop_schema else schema
+        dropped_tables = current_load_package()["state"].get("dropped_tables", [])

         # initialize analytical storage ie. create dataset required by passed schema
-        with self.get_destination_client(init_schema) as job_client:
+        with self.get_destination_client(schema) as job_client:
             if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
                 # init job client
-                # def should_truncate(table: TTableSchema) -> bool:
-                #     # When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
-                #     return (
-                #         job_client.should_truncate_table_before_load(table)
-                #         or table["name"] in truncate_dest
-                #     )

                 applied_update = init_client(
                     job_client,
-                    init_schema,
+                    schema,
                     new_jobs,
                     expected_update,
                     job_client.should_truncate_table_before_load,
-                    # should_truncate,
                     (
                         job_client.should_load_data_to_staging_dataset
                         if isinstance(job_client, WithStagingDataset)
                         else None
                     ),
                     refresh=self.refresh,
+                    drop_tables=dropped_tables,
                 )

                 # init staging client
@@ -423,24 +376,17 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
                         " implement SupportsStagingDestination"
                     )

-                    # def should_truncate_staging(table: TTableSchema) -> bool:
-                    #     return (
-                    #         job_client.should_truncate_table_before_load_on_staging_destination(
-                    #             table
-                    #         )
-                    #         or table["name"] in truncate_staging
-                    #     )
-
-                    with self.get_staging_destination_client(init_schema) as staging_client:
+                    with self.get_staging_destination_client(schema) as staging_client:
                         init_client(
                             staging_client,
-                            init_schema,
+                            schema,
                             new_jobs,
                             expected_update,
                             job_client.should_truncate_table_before_load_on_staging_destination,
-                            # should_truncate_staging,
                             job_client.should_load_data_to_staging_dataset_on_staging_destination,
                             refresh=self.refresh,
+                            drop_tables=dropped_tables,
                         )

                 self.load_storage.commit_schema_update(load_id, applied_update)
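With the drop logic moved into `init_client`, `load_single_package` no longer builds an alternate `init_schema`: it reads the dropped table list back from the load package state and forwards it via the new `drop_tables` argument. A tiny sketch of that read side with an in-memory stand-in for the package state (in the real code the state comes from `current_load_package()`):

# Stand-in for the load package state written by the extract step.
package_state = {"dropped_tables": [{"name": "orders"}]}

# Defaults to an empty list when no refresh/drop was requested,
# in which case the drop handling in init_client is a no-op.
dropped_tables = package_state.get("dropped_tables", [])
print([t["name"] for t in dropped_tables])  # ['orders']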
33 changes: 22 additions & 11 deletions dlt/load/utils.py
@@ -68,6 +68,7 @@ def init_client(
     truncate_filter: Callable[[TTableSchema], bool],
     load_staging_filter: Callable[[TTableSchema], bool],
     refresh: Optional[TRefreshMode] = None,
+    drop_tables: Optional[List[TTableSchema]] = None,
 ) -> TSchemaTables:
     """Initializes destination storage including staging dataset if supported
@@ -96,16 +97,18 @@
     tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data

     # get tables to truncate by extending tables with jobs with all their child tables

+    if refresh == "drop_data":
+        truncate_filter = lambda t: True
+
     truncate_tables = set(
         _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter)
     )

-    if refresh in ("drop_dataset", "drop_tables"):
-        drop_tables = all_tables - dlt_tables - tables_no_data
-    else:
-        drop_tables = set()
+    # if refresh in ("drop_dataset", "drop_tables"):
+    #     drop_tables = all_tables - dlt_tables - tables_no_data
+    # else:
+    #     drop_tables = set()

     applied_update = _init_dataset_and_update_schema(
         job_client,
@@ -143,13 +146,26 @@ def _init_dataset_and_update_schema(
     update_tables: Iterable[str],
     truncate_tables: Iterable[str] = None,
     staging_info: bool = False,
-    drop_tables: Optional[Iterable[str]] = None,
+    drop_tables: Optional[List[TTableSchema]] = None,
 ) -> TSchemaTables:
     staging_text = "for staging dataset" if staging_info else ""
     logger.info(
         f"Client for {job_client.config.destination_type} will start initialize storage"
         f" {staging_text}"
     )
+    if drop_tables:
+        old_schema = job_client.schema
+        new_schema = job_client.schema.clone()
+        job_client.schema = new_schema
+        for table in drop_tables:
+            new_schema.tables.pop(table["name"], None)
+        new_schema._bump_version()
+        if hasattr(job_client, "drop_tables"):
+            logger.info(
+                f"Client for {job_client.config.destination_type} will drop tables {staging_text}"
+            )
+            job_client.drop_tables(*[table["name"] for table in drop_tables], replace_schema=True)
+        job_client.schema = old_schema
     job_client.initialize_storage()
     logger.info(
         f"Client for {job_client.config.destination_type} will update schema to package schema"
@@ -161,13 +177,8 @@
     logger.info(
         f"Client for {job_client.config.destination_type} will truncate tables {staging_text}"
     )

     job_client.initialize_storage(truncate_tables=truncate_tables)
-    if drop_tables:
-        if hasattr(job_client, "drop_tables"):
-            logger.info(
-                f"Client for {job_client.config.destination_type} will drop tables {staging_text}"
-            )
-            job_client.drop_tables(*drop_tables)
     return applied_update


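The core of the commit is the new block in `_init_dataset_and_update_schema`: when `drop_tables` is passed, the job client's schema is temporarily swapped for a clone with those tables removed, `drop_tables(..., replace_schema=True)` is issued if the destination supports it, and the original schema is restored before the usual storage initialization (the dropped tables are then re-created from the package schema). A self-contained sketch of that pattern with stand-in classes — everything named Fake* is hypothetical; the real `Schema` and job client types come from dlt:

from copy import deepcopy
from typing import Any, Dict, List


class FakeSchema:
    def __init__(self, tables: Dict[str, Dict[str, Any]]) -> None:
        self.tables = tables
        self.version = 1

    def clone(self) -> "FakeSchema":
        return FakeSchema(deepcopy(self.tables))

    def _bump_version(self) -> None:
        self.version += 1


class FakeJobClient:
    def __init__(self, schema: FakeSchema) -> None:
        self.schema = schema

    def drop_tables(self, *table_names: str, replace_schema: bool = False) -> None:
        # A SQL destination would issue DROP TABLE here and, with replace_schema,
        # overwrite the schema it has stored in the destination.
        print(f"dropping {table_names}, replace_schema={replace_schema}")

    def initialize_storage(self) -> None:
        print("initializing storage; schema tables:", list(self.schema.tables))


def init_with_drops(job_client: FakeJobClient, drop_tables: List[Dict[str, Any]]) -> None:
    if drop_tables:
        # Swap in a clone without the dropped tables just for the drop call,
        # then restore the original schema (mirrors the diff above).
        old_schema = job_client.schema
        new_schema = job_client.schema.clone()
        job_client.schema = new_schema
        for table in drop_tables:
            new_schema.tables.pop(table["name"], None)
        new_schema._bump_version()
        if hasattr(job_client, "drop_tables"):
            job_client.drop_tables(*[t["name"] for t in drop_tables], replace_schema=True)
        job_client.schema = old_schema
    job_client.initialize_storage()


client = FakeJobClient(FakeSchema({"users": {"name": "users"}, "orders": {"name": "orders"}}))
init_with_drops(client, [{"name": "orders"}])

In `drop_data` refresh mode, `init_client` instead forces the truncate filter to match every table with jobs, so data is truncated rather than tables dropped.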
1 change: 1 addition & 0 deletions dlt/pipeline/drop.py
@@ -142,5 +142,6 @@ def drop_resources(

     for tbl in tables_to_drop:
         del schema.tables[tbl["name"]]
+    schema._bump_version()  # TODO: needed?

     return schema, new_state, info
92 changes: 0 additions & 92 deletions dlt/pipeline/helpers.py
@@ -80,18 +80,6 @@ def _retry_load(ex: BaseException) -> bool:
     return _retry_load


-# class _DropInfo(TypedDict):
-#     tables: List[str]
-#     resource_states: List[str]
-#     resource_names: List[str]
-#     state_paths: List[str]
-#     schema_name: str
-#     dataset_name: str
-#     drop_all: bool
-#     resource_pattern: Optional[REPattern]
-#     warnings: List[str]
-
-
 class DropCommand:
     def __init__(
         self,
@@ -111,65 +99,13 @@ def __init__(
             drop_all: Drop all resources and tables in the schema (supersedes `resources` list)
             state_only: Drop only state, not tables
         """
-        # self.extract_only = extract_only
-        # self.pipeline = pipeline
-        # if isinstance(resources, str):
-        #     resources = [resources]
-        # if isinstance(state_paths, str):
-        #     state_paths = [state_paths]
         self.pipeline = pipeline

         if not pipeline.default_schema_name:
             raise PipelineNeverRan(pipeline.pipeline_name, pipeline.pipelines_dir)
         self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name]

         self.drop_tables = not state_only
-        # self.state_paths_to_drop = compile_paths(state_paths)
-
-        # resources = set(resources)
-        # resource_names = []
-        # if drop_all:
-        #     self.resource_pattern = compile_simple_regex(TSimpleRegex("re:.*"))  # Match everything
-        # elif resources:
-        #     self.resource_pattern = compile_simple_regexes(TSimpleRegex(r) for r in resources)
-        # else:
-        #     self.resource_pattern = None
-
-        # if self.resource_pattern:
-        #     data_tables = {
-        #         t["name"]: t for t in self.schema.data_tables()
-        #     }  # Don't remove _dlt tables
-        #     resource_tables = group_tables_by_resource(data_tables, pattern=self.resource_pattern)
-        #     if self.drop_tables:
-        #         self.tables_to_drop = list(chain.from_iterable(resource_tables.values()))
-        #         self.tables_to_drop.reverse()
-        #     else:
-        #         self.tables_to_drop = []
-        #     resource_names = list(resource_tables.keys())
-        # else:
-        #     self.tables_to_drop = []
-        #     self.drop_tables = False  # No tables to drop
-        # self.drop_state = not not self.state_paths_to_drop  # obtain truth value
-
-        # self.drop_all = drop_all
-        # self.info: _DropInfo = dict(
-        #     tables=[t["name"] for t in self.tables_to_drop],
-        #     resource_states=[],
-        #     state_paths=[],
-        #     resource_names=resource_names,
-        #     schema_name=self.schema.name,
-        #     dataset_name=self.pipeline.dataset_name,
-        #     drop_all=drop_all,
-        #     resource_pattern=self.resource_pattern,
-        #     warnings=[],
-        # )
-        # if self.resource_pattern and not resource_tables:
-        #     self.info["warnings"].append(
-        #         f"Specified resource(s) {str(resources)} did not select any table(s) in schema"
-        #         f" {self.schema.name}. Possible resources are:"
-        #         f" {list(group_tables_by_resource(data_tables).keys())}"
-        #     )
-        # self._new_state = self._create_modified_state()

         self._drop_schema, self._new_state, self.info = drop_resources(
             self.schema,
@@ -212,34 +148,6 @@ def _delete_schema_tables(self) -> None:
         # bump schema, we'll save later
         self.schema._bump_version()

-    # def _list_state_paths(self, source_state: Dict[str, Any]) -> List[str]:
-    #     return resolve_paths(self.state_paths_to_drop, source_state)
-
-    # def _create_modified_state(self) -> Dict[str, Any]:
-    #     state = self.pipeline.state
-    #     if not self.drop_state:
-    #         return state  # type: ignore[return-value]
-    #     source_states = _sources_state(state).items()
-    #     for source_name, source_state in source_states:
-    #         # drop table states
-    #         if self.drop_state and self.resource_pattern:
-    #             for key in _get_matching_resources(self.resource_pattern, source_state):
-    #                 self.info["resource_states"].append(key)
-    #                 reset_resource_state(key, source_state)
-    #         # drop additional state paths
-    #         # Don't drop 'resources' key if jsonpath is wildcard
-    #         resolved_paths = [
-    #             p for p in resolve_paths(self.state_paths_to_drop, source_state) if p != "resources"
-    #         ]
-    #         if self.state_paths_to_drop and not resolved_paths:
-    #             self.info["warnings"].append(
-    #                 f"State paths {self.state_paths_to_drop} did not select any paths in source"
-    #                 f" {source_name}"
-    #             )
-    #         _delete_source_state_keys(resolved_paths, source_state)
-    #         self.info["state_paths"].extend(f"{source_name}.{p}" for p in resolved_paths)
-    #     return state  # type: ignore[return-value]

     def _extract_state(self) -> None:
         state: Dict[str, Any]
         with self.pipeline.managed_state(extract_state=True) as state:  # type: ignore[assignment]

