Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to include external directory of schemas when running generator #494

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ generate: template legacy_use_cases codegen generator
# Run the new generator
.PHONY: generator
generator:
$(PYTHON) scripts/generator.py
$(PYTHON) scripts/generator.py --include "${INCLUDE}"

# Generate Go code from the schema.
.PHONY: gocodegen
Expand Down
25 changes: 21 additions & 4 deletions scripts/generator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import argparse

import glob
import os
import schema_reader
from generators import intermediate_files
from generators import csv_generator
from generators import es_template
from generators import beats
from generators import asciidoc_fields
from generators import ecs_helpers


def main():
args = argument_parser()

ecs_version = read_version()
print "Running generator. ECS version " + ecs_version
print 'Running generator. ECS version ' + ecs_version

# Load the default schemas
print 'Loading default schemas'
(ecs_nested, ecs_flat) = schema_reader.load_ecs(sorted(glob.glob('schemas/*.yml')))

# Maybe load user specified directory of schemas
if args.include:
include_glob = os.path.join(args.include + '/*.yml')

(ecs_nested, ecs_flat) = schema_reader.load_ecs()
print 'Loading user defined schemas: {0}'.format(include_glob)

(user_ecs_nested, user_ecs_flat) = schema_reader.load_ecs(sorted(glob.glob(include_glob)))

# Merge without allowing user schemas to overwrite default schemas
ecs_nested = ecs_helpers.safe_merge_dicts(ecs_nested, user_ecs_nested)
ecs_flat = ecs_helpers.safe_merge_dicts(ecs_flat, user_ecs_flat)

intermediate_files.generate(ecs_nested, ecs_flat)
if args.intermediate_only:
Expand All @@ -30,7 +46,8 @@ def argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--intermediate-only', action='store_true',
help='generate intermediary files only')

parser.add_argument('--include', action='store',
help='include user specified directory of (custom) ecs schemas')
return parser.parse_args()


Expand Down
12 changes: 12 additions & 0 deletions scripts/generators/ecs_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import yaml

from collections import OrderedDict
from copy import deepcopy

# Dictionary helpers

Expand Down Expand Up @@ -35,6 +36,17 @@ def dict_sorted_by_keys(dict, sort_keys):
return list(map(lambda t: t[-1], sorted(tuples)))


def safe_merge_dicts(a, b):
    """Merge two dictionaries into a new one without overwriting any key.

    Neither `a` nor `b` is modified. The values taken from both inputs are
    deep-copied, so later changes to the merged dictionary (or to its nested
    values) never leak back into either input.

    Returns the merged dictionary, which has the same type as `a`.
    Raises a ValueError as soon as a key of `b` is already present in `a`.
    """
    merged = deepcopy(a)
    for key in b:
        if key in merged:
            raise ValueError('Duplicate key found when merging dictionaries: {0}'.format(key))
        # Copy the value as well: `a` was deep-copied above, and sharing
        # nested objects with `b` would make the result only half independent.
        merged[key] = deepcopy(b[key])
    return merged


def yaml_ordereddict(dumper, data):
# YAML representation of an OrderedDict will be like a dictionary, but
# respecting the order of the dictionary.
Expand Down
12 changes: 4 additions & 8 deletions scripts/schema_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
# File loading stuff


def schema_files():
"""Return the schema file list to load"""
return sorted(glob.glob("schemas/*.yml"))


def read_schema_file(file):
"""Read a raw schema yml into a map, removing the wrapping array in each file"""
with open(file) as f:
Expand All @@ -19,7 +14,7 @@ def read_schema_file(file):
return fields


def load_schema_files(files=schema_files()):
def load_schema_files(files):
fields_nested = {}
for f in files:
new_fields = read_schema_file(f)
Expand Down Expand Up @@ -169,8 +164,9 @@ def finalize_schemas(fields_nested, fields_flat):
duplicate_reusable_fieldsets(schema, fields_flat, fields_nested)


def load_ecs():
fields_nested = load_schema_files()
def load_ecs(files):
"""Loads the given list of files"""
fields_nested = load_schema_files(files)
fields_flat = {}
finalize_schemas(fields_nested, fields_flat)
return (fields_nested, fields_flat)
23 changes: 23 additions & 0 deletions scripts/tests/test_ecs_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,29 @@ def test_sorted_by_multiple_keys(self):
result = ecs_helpers.dict_sorted_by_keys(dict, ['group', 'name'])
self.assertEqual(result, expected)

def test_merge_dicts(self):
    """Merging dictionaries with disjoint keys keeps every entry of both."""
    a = {
        'cloud': {'group': 2, 'name': 'cloud'},
        'agent': {'group': 2, 'name': 'agent'},
    }
    b = {'base': {'group': 1, 'name': 'base'}}

    result = ecs_helpers.safe_merge_dicts(a, b)

    # assertEqual, not the deprecated assertEquals alias, to match the
    # other tests in this suite.
    self.assertEqual(result,
                     {
                         'cloud': {'group': 2, 'name': 'cloud'},
                         'agent': {'group': 2, 'name': 'agent'},
                         'base': {'group': 1, 'name': 'base'}
                     })
mcpate marked this conversation as resolved.
Show resolved Hide resolved

def test_merge_dicts_raises_if_duplicate_key_added(self):
    """A key present in both dictionaries must make the merge fail."""
    existing = {'cloud': {'group': 2, 'name': 'cloud'}}
    clashing = {'cloud': {'group': 9, 'name': 'bazbar'}}

    self.assertRaises(ValueError, ecs_helpers.safe_merge_dicts, existing, clashing)


if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion scripts/tests/test_ecs_spec.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import sys
import glob
import unittest

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from scripts import schema_reader


(nested, flat) = schema_reader.load_ecs()
(nested, flat) = schema_reader.load_ecs(sorted(glob.glob('schemas/*.yml')))


class TestEcsSpec(unittest.TestCase):
Expand Down
4 changes: 4 additions & 0 deletions scripts/tests/test_schema_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def test_field_set_multi_field_defaults_missing_name(self):
}
self.assertEqual(field, expected)

def test_load_ecs_with_empty_list_loads_nothing(self):
    """Loading an empty list of files yields empty nested and flat schemas."""
    nothing_loaded = ({}, {})
    loaded = schema_reader.load_ecs([])
    self.assertEqual(loaded, nothing_loaded)


if __name__ == '__main__':
unittest.main()