-
Notifications
You must be signed in to change notification settings - Fork 73
/
Copy path_state.py
304 lines (248 loc) · 10.2 KB
/
_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""Helper functions for state and bookmark management."""
from __future__ import annotations
import logging
import typing as t
from singer_sdk.exceptions import InvalidStreamSortException
from singer_sdk.helpers._typing import to_json_compatible
if t.TYPE_CHECKING:
import datetime
from singer_sdk.helpers import types
_T = t.TypeVar("_T", datetime.datetime, str, int, float)
# Key under which the bookmark of an unsorted stream is kept inside a stream
# or partition state until it is finalized.
PROGRESS_MARKERS = "progress_markers"
# Key of the human-readable note stored inside the progress markers dict.
PROGRESS_MARKER_NOTE = "Note"
# Key under which the replication key signpost value is stored in the state.
SIGNPOST_MARKER = "replication_key_signpost"
# Key under which the initial replication value is stored in the state.
STARTING_MARKER = "starting_replication_value"
# Logger shared by the state helper functions in this module.
logger = logging.getLogger("singer_sdk")
def get_state_if_exists(
    tap_state: dict,
    tap_stream_id: str,
    state_partition_context: dict | None = None,
    key: str | None = None,
) -> t.Any | None:  # noqa: ANN401
    """Look up existing stream or partition state without modifying anything.

    Args:
        tap_state: the existing state dict which contains all streams under
            its "bookmarks" entry.
        tap_stream_id: the id of the stream to look up.
        state_partition_context: keys which identify the partition context;
            when empty or None the stream-level state is used.
        key: name of a single entry to read from the located state; when empty
            or None the entire located state is returned.

    Returns:
        The located state (or the value stored under ``key``), or None when the
        bookmarks, the stream, the partitions list, the partition or the key
        is not present.

    Raises:
        ValueError: Raised if more than one partition entry matches the context.
    """
    if "bookmarks" not in tap_state or tap_stream_id not in tap_state["bookmarks"]:
        return None

    state = tap_state["bookmarks"][tap_stream_id]

    if state_partition_context:
        if "partitions" not in state:
            # No partitions defined
            return None

        state = _find_in_partitions_list(
            state["partitions"],
            state_partition_context,
        )
        if state is None:
            # Partition definition not present
            return None

    if key:
        return state.get(key, None)
    return state
def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
    """Return the "partitions" entry of the stream's state, or None if absent.

    None is also returned when the stream has no (or an empty) state.
    """
    stream_state = get_state_if_exists(tap_state, tap_stream_id)
    if not stream_state:
        return None
    return stream_state.get("partitions", None)  # type: ignore[no-any-return]
def _find_in_partitions_list(
partitions: list[dict],
state_partition_context: types.Context,
) -> dict | None:
found = [
partition_state
for partition_state in partitions
if partition_state["context"] == state_partition_context
]
if len(found) > 1:
msg = (
"State file contains duplicate entries for partition: "
f"{state_partition_context}.\nMatching state values were: {found!s}"
)
raise ValueError(msg)
return found[0] if found else None
def _create_in_partitions_list(
partitions: list[dict],
state_partition_context: types.Context,
) -> dict:
# Existing partition not found. Creating new state entry in partitions list...
new_partition_state = {"context": state_partition_context}
partitions.append(new_partition_state)
return new_partition_state
def get_writeable_state_dict(
    tap_state: dict,
    tap_stream_id: str,
    state_partition_context: types.Context | None = None,
) -> dict:
    """Return the stream or partition state, creating a new one if it does not exist.

    Missing "bookmarks", stream, "partitions" and partition entries are created
    in ``tap_state`` on the way, so the returned dict is always part of it.

    Args:
        tap_state: the existing state dict which contains all streams.
        tap_stream_id: the id of the stream
        state_partition_context: keys which identify the partition context,
            by default None (not partitioned)

    Returns:
        Returns a writeable dict at the stream or partition level.

    Raises:
        ValueError: Raised if ``tap_state`` is None or if duplicate partition
            entries are found.
    """
    if tap_state is None:
        msg = "Cannot write state to missing state dictionary."  # type: ignore[unreachable]
        raise ValueError(msg)

    if "bookmarks" not in tap_state:
        tap_state["bookmarks"] = {}
    bookmarks = tap_state["bookmarks"]

    if tap_stream_id not in bookmarks:
        bookmarks[tap_stream_id] = {}
    stream_state = t.cast(dict, bookmarks[tap_stream_id])

    if not state_partition_context:
        # Not partitioned: the stream-level dict is the writeable state.
        return stream_state

    if "partitions" not in stream_state:
        stream_state["partitions"] = []
    partition_states: list[dict] = stream_state["partitions"]

    existing = _find_in_partitions_list(partition_states, state_partition_context)
    if existing:
        return existing
    return _create_in_partitions_list(partition_states, state_partition_context)
def write_stream_state(
    tap_state: dict,
    tap_stream_id: str,
    key: str,
    val: t.Any,  # noqa: ANN401
    *,
    state_partition_context: dict | None = None,
) -> None:
    """Store ``val`` under ``key`` in the stream or partition state.

    Args:
        tap_state: the existing state dict which contains all streams.
        tap_stream_id: the id of the stream
        key: name of the state entry to write.
        val: value to store under ``key``.
        state_partition_context: keys which identify the partition context,
            by default None (not partitioned)
    """
    # The writeable dict is created inside ``tap_state`` if it is missing.
    get_writeable_state_dict(
        tap_state,
        tap_stream_id,
        state_partition_context=state_partition_context,
    )[key] = val
def reset_state_progress_markers(stream_or_partition_state: dict) -> dict | None:
    """Remove the progress markers from the state once sync is complete.

    For logging purposes, the removed 'progress_markers' object is returned
    without its auto-generated note; None is returned if nothing else remains
    or if the state had no progress markers.
    """
    leftover = stream_or_partition_state.pop(PROGRESS_MARKERS, {})
    # The human-readable note is not a marker, so it is dropped:
    leftover.pop(PROGRESS_MARKER_NOTE, None)
    if leftover:
        return leftover
    return None
def write_replication_key_signpost(
    stream_or_partition_state: dict,
    new_signpost_value: t.Any,  # noqa: ANN401
) -> None:
    """Store the signpost value in the state.

    The value is passed through ``to_json_compatible`` before being written.
    """
    signpost = to_json_compatible(new_signpost_value)
    stream_or_partition_state[SIGNPOST_MARKER] = signpost
def write_starting_replication_value(
    stream_or_partition_state: dict,
    initial_value: t.Any,  # noqa: ANN401
) -> None:
    """Store the initial replication value in the state.

    The value is passed through ``to_json_compatible`` before being written.
    """
    starting_value = to_json_compatible(initial_value)
    stream_or_partition_state[STARTING_MARKER] = starting_value
def get_starting_replication_value(stream_or_partition_state: dict) -> t.Any | None:  # noqa: ANN401
    """Read the initial replication marker value from the state.

    Returns None when the state is empty or missing, or has no starting marker.
    """
    if stream_or_partition_state:
        return stream_or_partition_state.get(STARTING_MARKER)
    return None
def increment_state(
    stream_or_partition_state: dict,
    *,
    latest_record: dict,
    replication_key: str,
    is_sorted: bool,
    check_sorted: bool,
) -> None:
    """Update the state using data from the latest record.

    For a sorted stream the bookmark is written directly on the given state.
    For an unsorted stream it is written under the progress markers entry,
    which is created (and a warning logged) if it does not exist yet.

    Args:
        stream_or_partition_state: the state dict to update in place.
        latest_record: the record whose replication key value is read.
        replication_key: name of the replication key in ``latest_record``.
        is_sorted: whether the stream is declared as sorted.
        check_sorted: whether the new value is compared with the previous one
            before it is written.

    Raises:
        InvalidStreamSortException: Raised if is_sorted=True, check_sorted=True
            and unsorted data is detected in the stream.
    """
    if is_sorted:
        bookmark = stream_or_partition_state
    else:
        if PROGRESS_MARKERS not in stream_or_partition_state:
            stream_or_partition_state[PROGRESS_MARKERS] = {
                PROGRESS_MARKER_NOTE: "Progress is not resumable if interrupted.",
            }
            logger.warning(
                "Stream is assumed to be unsorted, progress is not resumable if "
                "interrupted",
                extra={"replication_key": replication_key},
            )
        bookmark = stream_or_partition_state[PROGRESS_MARKERS]

    previous_value = to_json_compatible(bookmark.get("replication_key_value"))
    latest_value = to_json_compatible(latest_record[replication_key])

    if latest_value is None:
        logger.warning("New replication value is null")
        return

    # The comparison only happens when there is a previous value and sort
    # checking is enabled; otherwise the latest value is always accepted.
    went_backwards = (
        previous_value is not None
        and check_sorted
        and not latest_value >= previous_value
    )

    if not went_backwards:
        bookmark["replication_key"] = replication_key
        bookmark["replication_key_value"] = latest_value
        return

    if is_sorted:
        msg = (
            f"Unsorted data detected in stream. Latest value '{latest_value}' is "
            f"smaller than previous max '{previous_value}'."
        )
        raise InvalidStreamSortException(msg)
def _greater_than_signpost(
signpost: _T,
new_value: _T,
) -> bool:
"""Compare and return True if new_value is greater than signpost."""
# fails if signpost and bookmark are incompatible types
return new_value > signpost
def is_state_non_resumable(stream_or_partition_state: dict) -> bool:
    """Tell whether the given state is non-resumable.

    A state counts as non-resumable when it contains the progress markers entry.
    """
    has_progress_markers = PROGRESS_MARKERS in stream_or_partition_state
    return has_progress_markers
def finalize_state_progress_markers(stream_or_partition_state: dict) -> dict | None:
    """Promote or wipe progress markers once sync is complete.

    The signpost and starting markers are removed from the state. If the
    progress markers hold a replication key, the key and its value are moved
    up to the state itself; a value greater than a (truthy) signpost is
    replaced by the signpost. Whatever remains of the progress markers is
    then wiped.

    Returns:
        The progress markers that were not promoted, or None if there were none.
    """
    signpost = stream_or_partition_state.pop(SIGNPOST_MARKER, None)
    stream_or_partition_state.pop(STARTING_MARKER, None)

    has_promotable_bookmark = (
        is_state_non_resumable(stream_or_partition_state)
        and "replication_key" in stream_or_partition_state[PROGRESS_MARKERS]
    )
    if has_promotable_bookmark:
        # Replication keys valid (only) after sync is complete
        markers = stream_or_partition_state[PROGRESS_MARKERS]
        stream_or_partition_state["replication_key"] = markers.pop("replication_key")

        promoted_value = markers.pop("replication_key_value")
        if signpost and _greater_than_signpost(signpost, promoted_value):
            # Never bookmark past the signpost.
            promoted_value = signpost
        stream_or_partition_state["replication_key_value"] = promoted_value

    # Wipe and return any markers that have not been promoted
    return reset_state_progress_markers(stream_or_partition_state)
def log_sort_error(
    *,
    ex: Exception,
    log_fn: t.Callable,
    stream_name: str,
    current_context: types.Context | None,
    state_partition_context: types.Context | None,
    record_count: int,
    partition_record_count: int,
) -> None:
    """Build a description of a sort error and pass it to ``log_fn``.

    Args:
        ex: the exception whose text ends the message.
        log_fn: callable that receives the finished message.
        stream_name: name of the stream in which the error was detected.
        current_context: context added to the message when it is non-empty.
        state_partition_context: partition context, reported together with the
            partition record number.
        record_count: number of the offending record.
        partition_record_count: number of the offending record within its
            partition; only reported when it differs from ``record_count``.
    """
    parts = [
        f"Sorting error detected in '{stream_name}' on record #{record_count}. ",
    ]
    if partition_record_count != record_count:
        parts.append(
            f"Record was partition record "
            f"#{partition_record_count} with"
            f" state partition context {state_partition_context}. "
        )
    if current_context:
        parts.append(f"Context was {current_context!s}. ")
    parts.append(str(ex))
    log_fn("".join(parts))