Skip to content

Commit

Permalink
Add ability to include external directory of schemas when running gen…
Browse files Browse the repository at this point in the history
…erator (#494)
  • Loading branch information
mcpate authored and Mathieu Martin committed Jul 8, 2019
1 parent a35a903 commit 48f0cb8
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ generate: template legacy_use_cases codegen generator
# Run the new generator
.PHONY: generator
generator:
$(PYTHON) scripts/generator.py
$(PYTHON) scripts/generator.py --include "${INCLUDE}"

# Generate Go code from the schema.
.PHONY: gocodegen
Expand Down
25 changes: 21 additions & 4 deletions scripts/generator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import argparse

import glob
import os
import schema_reader
from generators import intermediate_files
from generators import csv_generator
from generators import es_template
from generators import beats
from generators import asciidoc_fields
from generators import ecs_helpers


def main():
args = argument_parser()

ecs_version = read_version()
print "Running generator. ECS version " + ecs_version
print 'Running generator. ECS version ' + ecs_version

# Load the default schemas
print 'Loading default schemas'
(ecs_nested, ecs_flat) = schema_reader.load_ecs(sorted(glob.glob('schemas/*.yml')))

# Maybe load user specified directory of schemas
if args.include:
include_glob = os.path.join(args.include + '/*.yml')

(ecs_nested, ecs_flat) = schema_reader.load_ecs()
print 'Loading user defined schemas: {0}'.format(include_glob)

(user_ecs_nested, user_ecs_flat) = schema_reader.load_ecs(sorted(glob.glob(include_glob)))

# Merge without allowing user schemas to overwrite default schemas
ecs_nested = ecs_helpers.safe_merge_dicts(ecs_nested, user_ecs_nested)
ecs_flat = ecs_helpers.safe_merge_dicts(ecs_flat, user_ecs_flat)

intermediate_files.generate(ecs_nested, ecs_flat)
if args.intermediate_only:
Expand All @@ -30,7 +46,8 @@ def argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--intermediate-only', action='store_true',
help='generate intermediary files only')

parser.add_argument('--include', action='store',
help='include user specified directory of (custom) ecs schemas')
return parser.parse_args()


Expand Down
12 changes: 12 additions & 0 deletions scripts/generators/ecs_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import yaml

from collections import OrderedDict
from copy import deepcopy

# Dictionary helpers

Expand Down Expand Up @@ -35,6 +36,17 @@ def dict_sorted_by_keys(dict, sort_keys):
return list(map(lambda t: t[-1], sorted(tuples)))


def safe_merge_dicts(a, b):
"""Merges two dictionaries into one. If duplicate keys are detected a ValueError is raised."""
c = deepcopy(a)
for key in b:
if key not in c:
c[key] = b[key]
else:
raise ValueError('Duplicate key found when merging dictionaries: {0}'.format(key))
return c


def yaml_ordereddict(dumper, data):
# YAML representation of an OrderedDict will be like a dictionary, but
# respecting the order of the dictionary.
Expand Down
12 changes: 4 additions & 8 deletions scripts/schema_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
# File loading stuff


def schema_files():
"""Return the schema file list to load"""
return sorted(glob.glob("schemas/*.yml"))


def read_schema_file(file):
"""Read a raw schema yml into a map, removing the wrapping array in each file"""
with open(file) as f:
Expand All @@ -19,7 +14,7 @@ def read_schema_file(file):
return fields


def load_schema_files(files=schema_files()):
def load_schema_files(files):
fields_nested = {}
for f in files:
new_fields = read_schema_file(f)
Expand Down Expand Up @@ -169,8 +164,9 @@ def finalize_schemas(fields_nested, fields_flat):
duplicate_reusable_fieldsets(schema, fields_flat, fields_nested)


def load_ecs():
fields_nested = load_schema_files()
def load_ecs(files):
"""Loads the given list of files"""
fields_nested = load_schema_files(files)
fields_flat = {}
finalize_schemas(fields_nested, fields_flat)
return (fields_nested, fields_flat)
23 changes: 23 additions & 0 deletions scripts/tests/test_ecs_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,29 @@ def test_sorted_by_multiple_keys(self):
result = ecs_helpers.dict_sorted_by_keys(dict, ['group', 'name'])
self.assertEqual(result, expected)

def test_merge_dicts(self):
a = {
'cloud': {'group': 2, 'name': 'cloud'},
'agent': {'group': 2, 'name': 'agent'},
}
b = {'base': {'group': 1, 'name': 'base'}}

result = ecs_helpers.safe_merge_dicts(a, b)

self.assertEquals(result,
{
'cloud': {'group': 2, 'name': 'cloud'},
'agent': {'group': 2, 'name': 'agent'},
'base': {'group': 1, 'name': 'base'}
})

def test_merge_dicts_raises_if_duplicate_key_added(self):
a = {'cloud': {'group': 2, 'name': 'cloud'}}
b = {'cloud': {'group': 9, 'name': 'bazbar'}}

with self.assertRaises(ValueError):
ecs_helpers.safe_merge_dicts(a, b)


if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion scripts/tests/test_ecs_spec.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import sys
import glob
import unittest

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from scripts import schema_reader


(nested, flat) = schema_reader.load_ecs()
(nested, flat) = schema_reader.load_ecs(sorted(glob.glob('schemas/*.yml')))


class TestEcsSpec(unittest.TestCase):
Expand Down
4 changes: 4 additions & 0 deletions scripts/tests/test_schema_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def test_field_set_multi_field_defaults_missing_name(self):
}
self.assertEqual(field, expected)

def test_load_ecs_with_empty_list_loads_nothing(self):
result = schema_reader.load_ecs([])
self.assertEqual(result, ({}, {}))


if __name__ == '__main__':
unittest.main()

0 comments on commit 48f0cb8

Please sign in to comment.