feat(ingest/stateful): remove platform_instance_id from state urn (da…
hsheth2 authored and cccs-Dustin committed Feb 1, 2023
1 parent 53f700d commit c894dd4
Showing 9 changed files with 74 additions and 68 deletions.
40 changes: 14 additions & 26 deletions metadata-ingestion/src/datahub/cli/state_cli.py
@@ -1,14 +1,11 @@
import json
import logging
from datetime import datetime
from typing import Optional

import click
from click_default_group import DefaultGroup

from datahub.cli.cli_utils import get_url_and_token
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
IngestionCheckpointingProviderBase,
)
from datahub.ingestion.graph.client import DataHubGraph, DataHubGraphConfig
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.entity_removal_state import GenericCheckpointState
@@ -18,7 +15,6 @@
from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import (
DatahubIngestionCheckpointingProvider,
)
from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass
from datahub.telemetry import telemetry
from datahub.upgrade import upgrade

@@ -34,10 +30,12 @@ def state() -> None:
@state.command()
@click.option("--pipeline-name", required=True, type=str)
@click.option("--platform", required=True, type=str)
@click.option("--platform-instance", required=True, type=str)
@click.option("--platform-instance", required=False, type=str)
@upgrade.check_upgrade
@telemetry.with_telemetry
def inspect(pipeline_name: str, platform: str, platform_instance: str) -> None:
def inspect(
pipeline_name: str, platform: str, platform_instance: Optional[str]
) -> None:
"""
Get the latest stateful ingestion state for a given pipeline.
Only works for stale entity removal for now.
@@ -48,23 +46,18 @@ def inspect(pipeline_name: str, platform: str, platform_instance: str) -> None:

(url, token) = get_url_and_token()
datahub_graph = DataHubGraph(DataHubGraphConfig(server=url, token=token))
checkpoint_provider = DatahubIngestionCheckpointingProvider(datahub_graph, "cli")

job_name = StaleEntityRemovalHandler.compute_job_id(platform)

data_job_urn = IngestionCheckpointingProviderBase.get_data_job_urn(
DatahubIngestionCheckpointingProvider.orchestrator_name,
pipeline_name,
job_name,
platform_instance,
)
raw_checkpoint = datahub_graph.get_latest_timeseries_value(
entity_urn=data_job_urn,
filter_criteria_map={
"pipelineName": pipeline_name,
"platformInstanceId": platform_instance,
},
aspect_type=DatahubIngestionCheckpointClass,
)
raw_checkpoint = checkpoint_provider.get_latest_checkpoint(pipeline_name, job_name)
if raw_checkpoint is None and platform_instance is not None:
logger.info(
"Failed to fetch state, but trying legacy URN format because platform_instance is provided."
)
raw_checkpoint = checkpoint_provider.get_latest_checkpoint(
pipeline_name, job_name, platform_instance_id=platform_instance
)

if not raw_checkpoint:
click.secho("No ingestion state found.", fg="red")
@@ -77,9 +70,4 @@ def inspect(pipeline_name: str, platform: str, platform_instance: str) -> None:
)
assert checkpoint

ts = datetime.utcfromtimestamp(raw_checkpoint.timestampMillis / 1000)
logger.info(
f"Found checkpoint with runId {checkpoint.run_id} and timestamp {ts.isoformat()}"
)

click.echo(json.dumps(checkpoint.state.urns, indent=2))
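
With this change, inspecting state for the common case needs no instance id. Assuming the "datahub state inspect" entry point this command group defines, an invocation might look like the following (pipeline, platform, and instance names are placeholders):

datahub state inspect --pipeline-name my_pipeline --platform snowflake

If that returns nothing and the state was written by an older version, passing --platform-instance makes the CLI retry with the legacy URN format:

datahub state inspect --pipeline-name my_pipeline --platform snowflake --platform-instance my_instance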
@@ -48,11 +48,19 @@ def get_data_job_urn(
orchestrator: str,
pipeline_name: str,
job_name: JobId,
platform_instance_id: str,
) -> str:
"""
Standardizes datajob urn minting for all ingestion job state providers.
"""
return builder.make_data_job_urn(
return builder.make_data_job_urn(orchestrator, pipeline_name, job_name)

@staticmethod
def get_data_job_legacy_urn(
orchestrator: str,
pipeline_name: str,
job_name: JobId,
platform_instance_id: str,
) -> str:
return IngestionCheckpointingProviderBase.get_data_job_urn(
orchestrator, f"{pipeline_name}_{platform_instance_id}", job_name
)
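
To make the two URN formats concrete, here is a small illustrative sketch. The printed URN shapes are assumptions based on DataHub's standard dataJob URN layout with the default "prod" cluster, and "datahub" is assumed as the orchestrator name used by the checkpointing provider:

import datahub.emitter.mce_builder as builder

# New format: the platform instance no longer affects the URN.
builder.make_data_job_urn("datahub", "my_pipeline", "my_job")
# -> urn:li:dataJob:(urn:li:dataFlow:(datahub,my_pipeline,prod),my_job)

# Legacy format: the instance id was folded into the flow id.
builder.make_data_job_urn("datahub", "my_pipeline_my_instance", "my_job")
# -> urn:li:dataJob:(urn:li:dataFlow:(datahub,my_pipeline_my_instance,prod),my_job)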
@@ -97,7 +97,6 @@ class Checkpoint(Generic[StateType]):

job_name: str
pipeline_name: str
platform_instance_id: str
run_id: str
state: StateType

@@ -140,12 +139,12 @@ def create_from_checkpoint_aspect(
checkpoint = cls(
job_name=job_name,
pipeline_name=checkpoint_aspect.pipelineName,
platform_instance_id=checkpoint_aspect.platformInstanceId,
run_id=checkpoint_aspect.runId,
state=state_obj,
)
logger.info(
f"Successfully constructed last checkpoint state for job {job_name}"
f"Successfully constructed last checkpoint state for job {job_name} "
f"with timestamp {datetime.utcfromtimestamp(checkpoint_aspect.timestampMillis/1000)}"
)
return checkpoint
return None
@@ -216,7 +215,7 @@ def to_checkpoint_aspect(
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=self.pipeline_name,
platformInstanceId=self.platform_instance_id,
platformInstanceId="",
runId=self.run_id,
config="",
state=checkpoint_state,
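Since the DatahubIngestionCheckpointClass schema still requires platformInstanceId, the field is now written as an empty string rather than removed. A sketch of the aspect a checkpoint serializes to under the new behavior (values are placeholders):

from datahub.metadata.schema_classes import DatahubIngestionCheckpointClass

checkpoint_aspect = DatahubIngestionCheckpointClass(
    timestampMillis=1675209600000,
    pipelineName="my_pipeline",
    platformInstanceId="",  # retained for schema compatibility; ignored on read
    runId="my_run",
    config="",
    state=checkpoint_state,  # an IngestionCheckpointStateClass built from the state object
)
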
@@ -100,7 +100,6 @@ def create_checkpoint(self) -> Optional[Checkpoint[BaseUsageCheckpointState]]:
return Checkpoint(
job_name=self.job_id,
pipeline_name=self.pipeline_name,
platform_instance_id=self.source.get_platform_instance_id(),
run_id=self.run_id,
state=BaseUsageCheckpointState(
begin_timestamp_millis=self.INVALID_TIMESTAMP_VALUE,
@@ -215,7 +215,6 @@ def create_checkpoint(self) -> Optional[Checkpoint]:
return Checkpoint(
job_name=self.job_id,
pipeline_name=self.pipeline_name,
platform_instance_id=self.source.get_platform_instance_id(),
run_id=self.run_id,
state=self.state_type_class(),
)
@@ -1,5 +1,4 @@
import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast

@@ -224,10 +223,11 @@ def is_checkpointing_enabled(self, job_id: JobId) -> bool:
raise ValueError(f"No use-case handler for job_id{job_id}")
return self._usecase_handlers[job_id].is_checkpointing_enabled()

# Methods that sub-classes must implement
@abstractmethod
def get_platform_instance_id(self) -> str:
raise NotImplementedError("Sub-classes must implement this method.")
# This method is retained for backwards compatibility, but it is not
# required that new sources implement it. We mainly need it for the
# fallback logic in _get_last_checkpoint.
raise NotImplementedError("no platform_instance_id configured")

def _get_last_checkpoint(
self, job_id: JobId, checkpoint_state_class: Type[StateType]
@@ -237,12 +237,29 @@ def _get_last_checkpoint(
"""
last_checkpoint: Optional[Checkpoint] = None
if self.is_stateful_ingestion_configured():
# TRICKY: We currently don't include the platform_instance_id in the
# checkpoint urn, but we previously did. As such, we need to fallback
# and try the old urn format if the new format doesn't return anything.

# Obtain the latest checkpoint from GMS for this job.
assert self.ctx.pipeline_name
last_checkpoint_aspect = self.ingestion_checkpointing_state_provider.get_latest_checkpoint( # type: ignore
pipeline_name=self.ctx.pipeline_name, # type: ignore
platform_instance_id=self.get_platform_instance_id(),
pipeline_name=self.ctx.pipeline_name,
job_name=job_id,
)
if last_checkpoint_aspect is None:
# Try again with the platform_instance_id, if implemented.
try:
platform_instance_id = self.get_platform_instance_id()
except NotImplementedError:
pass
else:
last_checkpoint_aspect = self.ingestion_checkpointing_state_provider.get_latest_checkpoint( # type: ignore
pipeline_name=self.ctx.pipeline_name,
job_name=job_id,
platform_instance_id=platform_instance_id,
)

# Convert it to a first-class Checkpoint object.
last_checkpoint = Checkpoint[StateType].create_from_checkpoint_aspect(
job_name=job_id,
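Distilled, the lookup order introduced here is: new-format URN first, legacy URN only if the source still exposes an instance id. A minimal sketch of the pattern, with names mirroring the diff (treat it as illustrative, not authoritative):

last_checkpoint_aspect = provider.get_latest_checkpoint(
    pipeline_name=pipeline_name, job_name=job_id
)
if last_checkpoint_aspect is None:
    try:
        instance_id = source.get_platform_instance_id()  # may raise NotImplementedError
    except NotImplementedError:
        pass  # nothing to fall back to
    else:
        # Retry against the pre-change URN, which embedded the instance id.
        last_checkpoint_aspect = provider.get_latest_checkpoint(
            pipeline_name=pipeline_name,
            job_name=job_id,
            platform_instance_id=instance_id,
        )
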
@@ -1,5 +1,5 @@
import logging
from datetime import datetime, timezone
from datetime import datetime
from typing import Any, Dict, Optional

from datahub.configuration.common import ConfigurationError
@@ -63,37 +63,41 @@ def _is_server_stateful_ingestion_capable(self) -> bool:
def get_latest_checkpoint(
self,
pipeline_name: str,
platform_instance_id: str,
job_name: JobId,
platform_instance_id: Optional[str] = None,
) -> Optional[DatahubIngestionCheckpointClass]:
logger.info(
logger.debug(
f"Querying for the latest ingestion checkpoint for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}'"
)

data_job_urn = self.get_data_job_urn(
self.orchestrator_name, pipeline_name, job_name, platform_instance_id
)
if platform_instance_id is None:
data_job_urn = self.get_data_job_urn(
self.orchestrator_name, pipeline_name, job_name
)
else:
data_job_urn = self.get_data_job_legacy_urn(
self.orchestrator_name, pipeline_name, job_name, platform_instance_id
)

latest_checkpoint: Optional[
DatahubIngestionCheckpointClass
] = self.graph.get_latest_timeseries_value(
entity_urn=data_job_urn,
aspect_type=DatahubIngestionCheckpointClass,
filter_criteria_map={
"pipelineName": pipeline_name,
"platformInstanceId": platform_instance_id,
},
)
if latest_checkpoint:
logger.info(
logger.debug(
f"The last committed ingestion checkpoint for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found with start_time:"
f" {datetime.fromtimestamp(latest_checkpoint.timestampMillis/1000, tz=timezone.utc)} and a"
f" bucket duration of {latest_checkpoint.eventGranularity}."
f" {datetime.utcfromtimestamp(latest_checkpoint.timestampMillis/1000)}"
)
return latest_checkpoint
else:
logger.info(
logger.debug(
f"No committed ingestion checkpoint for pipelineName:'{pipeline_name}',"
f" platformInstanceId:'{platform_instance_id}', job_name:'{job_name}' found"
)
@@ -108,8 +112,8 @@ def commit(self) -> None:
for job_name, checkpoint in self.state_to_commit.items():
# Emit the ingestion state for each job
logger.info(
f"Committing ingestion checkpoint for pipeline:'{checkpoint.pipelineName}',"
f"instance:'{checkpoint.platformInstanceId}', job:'{job_name}'"
f"Committing ingestion checkpoint for pipeline:'{checkpoint.pipelineName}', "
f"job:'{job_name}'"
)

self.committed = False
@@ -118,7 +122,6 @@ def commit(self) -> None:
self.orchestrator_name,
checkpoint.pipelineName,
job_name,
checkpoint.platformInstanceId,
)

self.graph.emit_mcp(
@@ -133,7 +136,7 @@

self.committed = True

logger.info(
f"Committed ingestion checkpoint for pipeline:'{checkpoint.pipelineName}',"
f"instance:'{checkpoint.platformInstanceId}', job:'{job_name}'"
logger.debug(
f"Committed ingestion checkpoint for pipeline:'{checkpoint.pipelineName}', "
f"job:'{job_name}'"
)
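
Callers now pass only the pipeline and job names; the keyword argument exists solely for legacy lookups. For example (placeholder values, with JobId imported as in the provider module):

latest = provider.get_latest_checkpoint(
    pipeline_name="my_pipeline", job_name=JobId("my_job")
)
# Only needed for state written before this change:
legacy = provider.get_latest_checkpoint(
    pipeline_name="my_pipeline",
    job_name=JobId("my_job"),
    platform_instance_id="my_instance",
)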
@@ -25,7 +25,6 @@
class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
# Static members for the tests
pipeline_name: str = "test_pipeline"
platform_instance_id: str = "test_platform_instance_1"
job_names: List[JobId] = [JobId("job1"), JobId("job2")]
run_id: str = "test_run"

@@ -95,7 +94,6 @@ def monkey_patch_get_latest_timeseries_value(
filter_criteria_map,
{
"pipelineName": self.pipeline_name,
"platformInstanceId": self.platform_instance_id,
},
)
# Retrieve the cached mcpw and return its aspect value.
@@ -111,7 +109,6 @@ def test_provider(self):
job1_checkpoint = Checkpoint(
job_name=self.job_names[0],
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
state=job1_state_obj,
)
@@ -122,7 +119,6 @@ def test_provider(self):
job2_checkpoint = Checkpoint(
job_name=self.job_names[1],
pipeline_name=self.pipeline_name,
platform_instance_id=self.platform_instance_id,
run_id=self.run_id,
state=job2_state_obj,
)
@@ -146,10 +142,10 @@ def test_provider(self):
# 4. Get last committed state. This must match what has been committed earlier.
# NOTE: This will retrieve from in-memory self.mcps_emitted because of the monkey-patching.
job1_last_state = self.provider.get_latest_checkpoint(
self.pipeline_name, self.platform_instance_id, self.job_names[0]
self.pipeline_name, self.job_names[0]
)
job2_last_state = self.provider.get_latest_checkpoint(
self.pipeline_name, self.platform_instance_id, self.job_names[1]
self.pipeline_name, self.job_names[1]
)

# 5. Validate individual job checkpoint state values that have been committed and retrieved
@@ -17,7 +17,6 @@

# 1. Setup common test param values.
test_pipeline_name: str = "test_pipeline"
test_platform_instance_id: str = "test_platform_instance_1"
test_job_name: str = "test_job_1"
test_run_id: str = "test_run_1"

@@ -30,8 +29,8 @@ def _assert_checkpoint_deserialization(
checkpoint_aspect = DatahubIngestionCheckpointClass(
timestampMillis=int(datetime.utcnow().timestamp() * 1000),
pipelineName=test_pipeline_name,
platformInstanceId=test_platform_instance_id,
config="",
platformInstanceId="this-can-be-anything-and-will-be-ignored",
config="this-is-also-ignored",
state=serialized_checkpoint_state,
runId=test_run_id,
)
@@ -46,7 +45,6 @@ def _assert_checkpoint_deserialization(
expected_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
state=expected_checkpoint_state,
)
@@ -120,7 +118,6 @@ def test_serde_idempotence(state_obj):
orig_checkpoint_obj = Checkpoint(
job_name=test_job_name,
pipeline_name=test_pipeline_name,
platform_instance_id=test_platform_instance_id,
run_id=test_run_id,
state=state_obj,
)
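The round trip these tests exercise, sketched from the names visible in this diff; the max_allowed_state_size argument and the exact create_from_checkpoint_aspect signature are assumptions:

aspect = orig_checkpoint_obj.to_checkpoint_aspect(max_allowed_state_size=2**22)
restored = Checkpoint.create_from_checkpoint_aspect(
    job_name=test_job_name,
    checkpoint_aspect=aspect,
    state_class=type(state_obj),
)
assert restored == orig_checkpoint_obj  # platformInstanceId plays no part in equality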
