[Fetch Migration] Added total document count to report (#261)
This change allows the Index Configuration Tool report to print the total number of documents that will be migrated, based on the indices identified for creation. The change includes a new API call in index_operations.py and updated unit tests. This commit also includes some other minor changes:
* The output YAML file is now optional, allowing users to print a report without producing a YAML file. However, either --report or an output YAML file path is required; omitting both results in a ValueError
* A new EndpointInfo dataclass has been introduced to encapsulate endpoint information (URL, auth, and SSL verification flag) for the source and target clusters
* The default output of the tool has been changed to print the total document count followed by the list of indices to be created. -r / --report should be specified to produce human-readable output instead. All other print statements have been removed.
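As an illustration of the resulting flow, here is a minimal, hypothetical sketch (the cluster URL and usage are invented for this example and are not part of the commit):

from endpoint_info import EndpointInfo

import index_operations

# Hypothetical source cluster; auth defaults to None, SSL verification to True
source = EndpointInfo("http://localhost:9200/")
indices = index_operations.fetch_all_indices(source)
# Total documents across the indices that would be migrated
total_docs = index_operations.doc_count(set(indices.keys()), source)
print(total_docs)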

---------

Signed-off-by: Kartik Ganesh <[email protected]>
kartg authored Aug 16, 2023
1 parent edfb16a commit 2c8e3ab
Showing 5 changed files with 121 additions and 59 deletions.
8 changes: 8 additions & 0 deletions FetchMigration/index_configuration_tool/endpoint_info.py
@@ -0,0 +1,8 @@
from dataclasses import dataclass


@dataclass
class EndpointInfo:
    url: str
    auth: tuple = None
    verify_ssl: bool = True
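As a quick illustration (hypothetical URLs), only the URL is required; auth and SSL verification fall back to the defaults above:

info = EndpointInfo("https://search.example.com:9200/")
# info.auth is None; info.verify_ssl is True

# An auth tuple and SSL flag can be supplied explicitly
secured = EndpointInfo("https://search.example.com:9200/", ("admin", "admin"), False)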
28 changes: 18 additions & 10 deletions FetchMigration/index_configuration_tool/index_operations.py
@@ -1,19 +1,20 @@
-import sys
-from typing import Optional
-
import requests

+from endpoint_info import EndpointInfo
+
# Constants
SETTINGS_KEY = "settings"
MAPPINGS_KEY = "mappings"
+COUNT_KEY = "count"
__INDEX_KEY = "index"
__ALL_INDICES_ENDPOINT = "*"
+__COUNT_ENDPOINT = "/_count"
__INTERNAL_SETTINGS_KEYS = ["creation_date", "uuid", "provided_name", "version", "store"]


-def fetch_all_indices(endpoint: str, optional_auth: Optional[tuple] = None, verify: bool = True) -> dict:
-    actual_endpoint = endpoint + __ALL_INDICES_ENDPOINT
-    resp = requests.get(actual_endpoint, auth=optional_auth, verify=verify)
+def fetch_all_indices(endpoint: EndpointInfo) -> dict:
+    actual_endpoint = endpoint.url + __ALL_INDICES_ENDPOINT
+    resp = requests.get(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl)
    # Remove internal settings
    result = dict(resp.json())
    for index in result:
@@ -24,14 +25,21 @@ def fetch_all_indices(endpoint: str, optional_auth: Optional[tuple] = None, verify: bool = True) -> dict:
    return result


-def create_indices(indices_data: dict, endpoint: str, auth_tuple: Optional[tuple]):
+def create_indices(indices_data: dict, endpoint: EndpointInfo):
    for index in indices_data:
-        actual_endpoint = endpoint + index
+        actual_endpoint = endpoint.url + index
        data_dict = dict()
        data_dict[SETTINGS_KEY] = indices_data[index][SETTINGS_KEY]
        data_dict[MAPPINGS_KEY] = indices_data[index][MAPPINGS_KEY]
        try:
-            resp = requests.put(actual_endpoint, auth=auth_tuple, json=data_dict)
+            resp = requests.put(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl, json=data_dict)
            resp.raise_for_status()
        except requests.exceptions.RequestException as e:
-            print(f"Failed to create index [{index}] - {e!s}", file=sys.stderr)
+            raise RuntimeError(f"Failed to create index [{index}] - {e!s}")


+def doc_count(indices: set, endpoint: EndpointInfo) -> int:
+    actual_endpoint = endpoint.url + ','.join(indices) + __COUNT_ENDPOINT
+    resp = requests.get(actual_endpoint, auth=endpoint.auth, verify=endpoint.verify_ssl)
+    result = dict(resp.json())
+    return int(result[COUNT_KEY])
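Under the hood, doc_count targets the cluster's _count API with a comma-joined index list. As a rough equivalent of the request it issues (hypothetical endpoint URL and index names, for illustration only):

import requests

# Approximately what doc_count({"movies", "books"}, EndpointInfo("http://localhost:9200/")) sends.
# Since Python sets are unordered, the indices may be joined in either order.
resp = requests.get("http://localhost:9200/movies,books/_count", auth=None, verify=True)
print(resp.json()["count"])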
72 changes: 44 additions & 28 deletions FetchMigration/index_configuration_tool/main.py
@@ -6,6 +6,8 @@
import utils

# Constants
+from endpoint_info import EndpointInfo
+
SUPPORTED_ENDPOINTS = ["opensearch", "elasticsearch"]
SOURCE_KEY = "source"
SINK_KEY = "sink"
@@ -36,19 +38,14 @@ def get_auth(input_data: dict) -> Optional[tuple]:
    return input_data[USER_KEY], input_data[PWD_KEY]


-def get_endpoint_info(plugin_config: dict) -> tuple:
+def get_endpoint_info(plugin_config: dict) -> EndpointInfo:
    # "hosts" can be a simple string, or an array of hosts for Logstash to hit.
    # This tool needs one accessible host, so we pick the first entry in the latter case.
-    endpoint = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY]
-    endpoint += "/"
-    return endpoint, get_auth(plugin_config)
-
-
-def fetch_all_indices_by_plugin(plugin_config: dict) -> dict:
-    endpoint, auth_tuple = get_endpoint_info(plugin_config)
+    url = plugin_config[HOSTS_KEY][0] if type(plugin_config[HOSTS_KEY]) is list else plugin_config[HOSTS_KEY]
+    url += "/"
    # verify boolean will be the inverse of the insecure SSL key, if present
    should_verify = not is_insecure(plugin_config)
-    return index_operations.fetch_all_indices(endpoint, auth_tuple, should_verify)
+    return EndpointInfo(url, get_auth(plugin_config), should_verify)


def check_supported_endpoint(config: dict) -> Optional[tuple]:
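To make get_endpoint_info's new return type concrete, here is a minimal sketch mirroring the unit tests' base case (hypothetical host value; with no auth or insecure keys in the plugin config, auth resolves to None and SSL verification stays enabled):

plugin_config = {"hosts": ["http://localhost:9200"]}
info = get_endpoint_info(plugin_config)
# The first host entry is used and a trailing slash is appended:
# info.url == "http://localhost:9200/"; info.auth is None; info.verify_ssl is True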
@@ -112,7 +109,6 @@ def write_output(yaml_data: dict, new_indices: set, output_path: str):
    source_config[INDICES_KEY] = source_indices
    with open(output_path, 'w') as out_file:
        yaml.dump(yaml_data, out_file)
-    print("Wrote output YAML pipeline to: " + output_path)


# Computes differences in indices between source and target.
Expand All @@ -138,44 +134,64 @@ def get_index_differences(source: dict, target: dict) -> tuple[set, set, set]:

# The order of data in the tuple is:
# (indices to create), (identical indices), (indices with conflicts)
-def print_report(index_differences: tuple[set, set, set]): # pragma no cover
+def print_report(index_differences: tuple[set, set, set], count: int): # pragma no cover
    print("Identical indices in the target cluster (no changes will be made): " +
          utils.string_from_set(index_differences[1]))
    print("Indices in target cluster with conflicting settings/mappings: " +
          utils.string_from_set(index_differences[2]))
    print("Indices to create: " + utils.string_from_set(index_differences[0]))
+    print("Total documents to be moved: " + str(count))


+def dump_count_and_indices(count: int, indices: set): # pragma no cover
+    print(count)
+    for index_name in indices:
+        print(index_name)
+
+
+def compute_endpoint_and_fetch_indices(config: dict, key: str) -> tuple[EndpointInfo, dict]:
+    endpoint = get_supported_endpoint(config, key)
+    # Endpoint is a tuple of (type, config)
+    endpoint_info = get_endpoint_info(endpoint[1])
+    indices = index_operations.fetch_all_indices(endpoint_info)
+    return endpoint_info, indices


def run(args: argparse.Namespace) -> None:
+    # Sanity check
+    if not args.report and len(args.output_file) == 0:
+        raise ValueError("No output file specified")
    # Parse and validate pipelines YAML file
    with open(args.config_file_path, 'r') as pipeline_file:
        dp_config = yaml.safe_load(pipeline_file)
    # We expect the Data Prepper pipeline to only have a single top-level value
    pipeline_config = next(iter(dp_config.values()))
    validate_pipeline_config(pipeline_config)
-    # Endpoint is a tuple of (type, config)
-    endpoint = get_supported_endpoint(pipeline_config, SOURCE_KEY)
-    # Fetch all indices from source cluster
-    source_indices = fetch_all_indices_by_plugin(endpoint[1])
-    # Fetch all indices from target cluster
-    # TODO Refactor this to avoid duplication with fetch_all_indices_by_plugin
-    endpoint = get_supported_endpoint(pipeline_config, SINK_KEY)
-    target_endpoint, target_auth = get_endpoint_info(endpoint[1])
-    target_indices = index_operations.fetch_all_indices(target_endpoint, target_auth)
+    # Fetch EndpointInfo and indices
+    source_endpoint_info, source_indices = compute_endpoint_and_fetch_indices(pipeline_config, SOURCE_KEY)
+    target_endpoint_info, target_indices = compute_endpoint_and_fetch_indices(pipeline_config, SINK_KEY)
    # Compute index differences and print report
    diff = get_index_differences(source_indices, target_indices)
-    if args.report:
-        print_report(diff)
    # The first element in the tuple is the set of indices to create
    indices_to_create = diff[0]
+    doc_count = 0
+    if indices_to_create:
+        doc_count = index_operations.doc_count(indices_to_create, source_endpoint_info)
+    if args.report:
+        print_report(diff, doc_count)
    if indices_to_create:
+        if not args.report:
+            dump_count_and_indices(doc_count, indices_to_create)
        # Write output YAML
-        write_output(dp_config, indices_to_create, args.output_file)
+        if len(args.output_file) > 0:
+            write_output(dp_config, indices_to_create, args.output_file)
+            if args.report:
+                print("Wrote output YAML pipeline to: " + args.output_file)
        if not args.dryrun:
            index_data = dict()
            for index_name in indices_to_create:
                index_data[index_name] = source_indices[index_name]
-            index_operations.create_indices(index_data, target_endpoint, target_auth)
+            index_operations.create_indices(index_data, target_endpoint_info)


if __name__ == '__main__': # pragma no cover
@@ -191,20 +207,20 @@ def run(args: argparse.Namespace) -> None:
"along with indices that are identical or have conflicting settings/mappings.",
formatter_class=argparse.RawTextHelpFormatter
)
# Positional, required arguments
# Required positional argument
arg_parser.add_argument(
"config_file_path",
help="Path to the Data Prepper pipeline YAML file to parse for source and target endpoint information"
)
# Optional positional argument
arg_parser.add_argument(
"output_file",
nargs='?', default="",
help="Output path for the Data Prepper pipeline YAML file that will be generated"
)
# Optional arguments
# Flags
arg_parser.add_argument("--report", "-r", action="store_true",
help="Print a report of the index differences")
arg_parser.add_argument("--dryrun", action="store_true",
help="Skips the actual creation of indices on the target cluster")
print("\n##### Starting index configuration tool... #####\n")
run(arg_parser.parse_args())
print("\n##### Index configuration tool has completed! #####\n")
18 changes: 15 additions & 3 deletions FetchMigration/index_configuration_tool/tests/test_index_operations.py
@@ -6,6 +6,7 @@
from responses import matchers

import index_operations
+from endpoint_info import EndpointInfo
from tests import test_constants


@@ -15,7 +16,7 @@ def test_fetch_all_indices(self):
        # Set up GET response
        responses.get(test_constants.SOURCE_ENDPOINT + "*", json=test_constants.BASE_INDICES_DATA)
        # Now send request
-        index_data = index_operations.fetch_all_indices(test_constants.SOURCE_ENDPOINT)
+        index_data = index_operations.fetch_all_indices(EndpointInfo(test_constants.SOURCE_ENDPOINT))
        self.assertEqual(3, len(index_data.keys()))
        # Test that internal data has been filtered, but non-internal data is retained
        index_settings = index_data[test_constants.INDEX1_NAME][test_constants.SETTINGS_KEY]
@@ -33,7 +34,7 @@ def test_create_indices(self):
                      match=[matchers.json_params_matcher(test_data[test_constants.INDEX2_NAME])])
        responses.put(test_constants.TARGET_ENDPOINT + test_constants.INDEX3_NAME,
                      match=[matchers.json_params_matcher(test_data[test_constants.INDEX3_NAME])])
-        index_operations.create_indices(test_data, test_constants.TARGET_ENDPOINT, None)
+        index_operations.create_indices(test_data, EndpointInfo(test_constants.TARGET_ENDPOINT))

    @responses.activate
    def test_create_indices_exception(self):
@@ -43,7 +44,18 @@ def test_create_indices_exception(self):
        del test_data[test_constants.INDEX3_NAME]
        responses.put(test_constants.TARGET_ENDPOINT + test_constants.INDEX1_NAME,
                      body=requests.Timeout())
-        index_operations.create_indices(test_data, test_constants.TARGET_ENDPOINT, None)
+        self.assertRaises(RuntimeError, index_operations.create_indices, test_data,
+                          EndpointInfo(test_constants.TARGET_ENDPOINT))
+
+    @responses.activate
+    def test_doc_count(self):
+        test_indices = {test_constants.INDEX1_NAME, test_constants.INDEX2_NAME}
+        expected_count_endpoint = test_constants.SOURCE_ENDPOINT + ",".join(test_indices) + "/_count"
+        mock_count_response = {"count": "10"}
+        responses.get(expected_count_endpoint, json=mock_count_response)
+        # Now send request
+        count_value = index_operations.doc_count(test_indices, EndpointInfo(test_constants.SOURCE_ENDPOINT))
+        self.assertEqual(10, count_value)


if __name__ == '__main__':
54 changes: 36 additions & 18 deletions FetchMigration/index_configuration_tool/tests/test_main.py
@@ -88,25 +88,26 @@ def test_get_endpoint_info(self):
        # Simple base case
        test_config = create_plugin_config([host_input])
        result = main.get_endpoint_info(test_config)
-        self.assertEqual(expected_endpoint, result[0])
-        self.assertIsNone(result[1])
+        self.assertEqual(expected_endpoint, result.url)
+        self.assertIsNone(result.auth)
+        self.assertTrue(result.verify_ssl)
        # Invalid auth config
        test_config = create_plugin_config([host_input], test_user)
        result = main.get_endpoint_info(test_config)
-        self.assertEqual(expected_endpoint, result[0])
-        self.assertIsNone(result[1])
+        self.assertEqual(expected_endpoint, result.url)
+        self.assertIsNone(result.auth)
        # Valid auth config
        test_config = create_plugin_config([host_input], user=test_user, password=test_password)
        result = main.get_endpoint_info(test_config)
-        self.assertEqual(expected_endpoint, result[0])
-        self.assertEqual(test_user, result[1][0])
-        self.assertEqual(test_password, result[1][1])
+        self.assertEqual(expected_endpoint, result.url)
+        self.assertEqual(test_user, result.auth[0])
+        self.assertEqual(test_password, result.auth[1])
        # Array of hosts uses the first entry
        test_config = create_plugin_config([host_input, "other_host"], test_user, test_password)
        result = main.get_endpoint_info(test_config)
-        self.assertEqual(expected_endpoint, result[0])
-        self.assertEqual(test_user, result[1][0])
-        self.assertEqual(test_password, result[1][1])
+        self.assertEqual(expected_endpoint, result.url)
+        self.assertEqual(test_user, result.auth[0])
+        self.assertEqual(test_password, result.auth[1])

    def test_get_index_differences_empty(self):
        # Base case should return an empty list
@@ -225,17 +226,18 @@ def test_validate_pipeline_config_happy_case(self):
        test_config = next(iter(self.loaded_pipeline_config.values()))
        main.validate_pipeline_config(test_config)

+    @patch('index_operations.doc_count')
    @patch('main.write_output')
    @patch('main.print_report')
    @patch('index_operations.create_indices')
    @patch('index_operations.fetch_all_indices')
    # Note that mock objects are passed bottom-up from the patch order above
    def test_run_report(self, mock_fetch_indices: MagicMock, mock_create_indices: MagicMock,
-                        mock_print_report: MagicMock, mock_write_output: MagicMock):
+                        mock_print_report: MagicMock, mock_write_output: MagicMock, mock_doc_count: MagicMock):
+        mock_doc_count.return_value = 1
        index_to_create = test_constants.INDEX3_NAME
        index_with_conflict = test_constants.INDEX2_NAME
        index_exact_match = test_constants.INDEX1_NAME
-        expected_output_path = "dummy"
        # Set up expected arguments to mocks so we can verify
        expected_create_payload = {index_to_create: test_constants.BASE_INDICES_DATA[index_to_create]}
        # Print report accepts a tuple. The elements of the tuple
@@ -252,21 +254,26 @@ def test_run_report(self, mock_fetch_indices: MagicMock, mock_create_indices: MagicMock,
        # Set up test input
        test_input = argparse.Namespace()
        test_input.config_file_path = test_constants.PIPELINE_CONFIG_RAW_FILE_PATH
-        test_input.output_file = expected_output_path
+        # Default value for missing output file
+        test_input.output_file = ""
        test_input.report = True
        test_input.dryrun = False
        main.run(test_input)
-        mock_create_indices.assert_called_once_with(expected_create_payload, test_constants.TARGET_ENDPOINT, ANY)
-        mock_print_report.assert_called_once_with(expected_diff)
-        mock_write_output.assert_called_once_with(self.loaded_pipeline_config, {index_to_create}, expected_output_path)
+        mock_create_indices.assert_called_once_with(expected_create_payload, ANY)
+        mock_doc_count.assert_called()
+        mock_print_report.assert_called_once_with(expected_diff, 1)
+        mock_write_output.assert_not_called()

+    @patch('index_operations.doc_count')
+    @patch('main.dump_count_and_indices')
    @patch('main.print_report')
    @patch('main.write_output')
    @patch('index_operations.fetch_all_indices')
    # Note that mock objects are passed bottom-up from the patch order above
    def test_run_dryrun(self, mock_fetch_indices: MagicMock, mock_write_output: MagicMock,
-                        mock_print_report: MagicMock):
+                        mock_print_report: MagicMock, mock_dump: MagicMock, mock_doc_count: MagicMock):
        index_to_create = test_constants.INDEX1_NAME
+        mock_doc_count.return_value = 1
        expected_output_path = "dummy"
        # Create mock data for indices on target
        target_indices_data = copy.deepcopy(test_constants.BASE_INDICES_DATA)
@@ -281,8 +288,10 @@ def test_run_dryrun(self, mock_fetch_indices: MagicMock, mock_write_output: MagicMock,
        test_input.report = False
        main.run(test_input)
        mock_write_output.assert_called_once_with(self.loaded_pipeline_config, {index_to_create}, expected_output_path)
-        # Report should not be printed
+        mock_doc_count.assert_called()
+        # Report should not be printed, but dump should be invoked
        mock_print_report.assert_not_called()
+        mock_dump.assert_called_once_with(mock_doc_count.return_value, {index_to_create})

    @patch('yaml.dump')
    def test_write_output(self, mock_dump: MagicMock):
@@ -311,6 +320,15 @@ def test_write_output(self, mock_dump: MagicMock):
        mock_open.assert_called_once_with(expected_output_path, 'w')
        mock_dump.assert_called_once_with(expected_output_data, ANY)

+    def test_missing_output_file_non_report(self):
+        # Set up test input
+        test_input = argparse.Namespace()
+        test_input.config_file_path = test_constants.PIPELINE_CONFIG_RAW_FILE_PATH
+        # Default value for missing output file
+        test_input.output_file = ""
+        test_input.report = False
+        self.assertRaises(ValueError, main.run, test_input)


if __name__ == '__main__':
    unittest.main()
