forked from elastic/detection-rules
-
Notifications
You must be signed in to change notification settings - Fork 1
/
mixins.py
251 lines (196 loc) · 9.85 KB
/
mixins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.
"""Generic mixin classes."""
import dataclasses
from pathlib import Path
from typing import Any, Optional, TypeVar, Type, Literal
import json
import marshmallow_dataclass
import marshmallow_dataclass.union_field
import marshmallow_jsonschema
import marshmallow_union
import marshmallow
from marshmallow import Schema, ValidationError, validates_schema, fields as marshmallow_fields
from .misc import load_current_package_version
from .schemas import definitions
from .schemas.stack_compat import get_incompatible_fields
from semver import Version
from .utils import cached, dict_hash
T = TypeVar('T')
ClassT = TypeVar('ClassT') # bound=dataclass?
UNKNOWN_VALUES = Literal['raise', 'exclude', 'include']
def _strip_none_from_dict(obj: T) -> T:
"""Strip none values from a dict recursively."""
if isinstance(obj, dict):
return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None}
if isinstance(obj, list):
return [_strip_none_from_dict(o) for o in obj]
if isinstance(obj, tuple):
return tuple(_strip_none_from_dict(list(obj)))
return obj
def patch_jsonschema(obj: dict) -> dict:
"""Patch marshmallow-jsonschema output to look more like JSL."""
def dive(child: dict) -> dict:
if "$ref" in child:
name = child["$ref"].split("/")[-1]
definition = obj["definitions"][name]
return dive(definition)
child = child.copy()
if "default" in child and child["default"] is None:
child.pop("default")
child.pop("title", None)
if "anyOf" in child:
child["anyOf"] = [dive(c) for c in child["anyOf"]]
elif isinstance(child["type"], list):
if 'null' in child["type"]:
child["type"] = [t for t in child["type"] if t != 'null']
if len(child["type"]) == 1:
child["type"] = child["type"][0]
if "items" in child:
child["items"] = dive(child["items"])
if "properties" in child:
# .rstrip("_") is workaround for `from_` -> from
# https://github.com/fuhrysteve/marshmallow-jsonschema/issues/107
child["properties"] = {k.rstrip("_"): dive(v) for k, v in child["properties"].items()}
if isinstance(child.get("additionalProperties"), dict):
# .rstrip("_") is workaround for `from_` -> from
# https://github.com/fuhrysteve/marshmallow-jsonschema/issues/107
child["additionalProperties"] = dive(child["additionalProperties"])
return child
patched = {"$schema": "http://json-schema.org/draft-04/schema#"}
patched.update(dive(obj))
return patched
class BaseSchema(Schema):
"""Base schema for marshmallow dataclasses with unknown."""
class Meta:
"""Meta class for marshmallow schema."""
def exclude_class_schema(
clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs
) -> type[Schema]:
"""Get a marshmallow schema for a dataclass with unknown=EXCLUDE."""
base_schema.Meta.unknown = unknown
return marshmallow_dataclass.class_schema(clazz, base_schema=base_schema, **kwargs)
def recursive_class_schema(
clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs
) -> type[Schema]:
"""Recursively apply the unknown parameter for nested schemas."""
schema = exclude_class_schema(clazz, base_schema=base_schema, unknown=unknown, **kwargs)
for field in dataclasses.fields(clazz):
if dataclasses.is_dataclass(field.type):
nested_cls = field.type
nested_schema = recursive_class_schema(nested_cls, base_schema=base_schema, **kwargs)
setattr(schema, field.name, nested_schema)
return schema
class MarshmallowDataclassMixin:
"""Mixin class for marshmallow serialization."""
@classmethod
@cached
def __schema(cls: ClassT, unknown: Optional[UNKNOWN_VALUES] = None) -> Schema:
"""Get the marshmallow schema for the data class"""
if unknown:
return recursive_class_schema(cls, unknown=unknown)()
else:
return marshmallow_dataclass.class_schema(cls)()
def get(self, key: str, default: Optional[Any] = None):
"""Get a key from the query data without raising attribute errors."""
return getattr(self, key, default)
@classmethod
@cached
def jsonschema(cls):
"""Get the jsonschema representation for this class."""
jsonschema = PatchedJSONSchema().dump(cls.__schema())
jsonschema = patch_jsonschema(jsonschema)
return jsonschema
@classmethod
def from_dict(cls: Type[ClassT], obj: dict, unknown: Optional[UNKNOWN_VALUES] = None) -> ClassT:
"""Deserialize and validate a dataclass from a dict using marshmallow."""
schema = cls.__schema(unknown=unknown)
return schema.load(obj)
def to_dict(self, strip_none_values=True) -> dict:
"""Serialize a dataclass to a dictionary using marshmallow."""
schema = self.__schema()
serialized: dict = schema.dump(self)
if strip_none_values:
serialized = _strip_none_from_dict(serialized)
return serialized
class LockDataclassMixin:
"""Mixin class for version and deprecated rules lock files."""
@classmethod
@cached
def __schema(cls: ClassT) -> Schema:
"""Get the marshmallow schema for the data class"""
return marshmallow_dataclass.class_schema(cls)()
def get(self, key: str, default: Optional[Any] = None):
"""Get a key from the query data without raising attribute errors."""
return getattr(self, key, default)
@classmethod
def from_dict(cls: Type[ClassT], obj: dict) -> ClassT:
"""Deserialize and validate a dataclass from a dict using marshmallow."""
schema = cls.__schema()
try:
loaded = schema.load(obj)
except ValidationError as e:
err_msg = json.dumps(e.messages, indent=2)
raise ValidationError(f'Validation error loading: {cls.__name__}\n{err_msg}') from None
return loaded
def to_dict(self, strip_none_values=True) -> dict:
"""Serialize a dataclass to a dictionary using marshmallow."""
schema = self.__schema()
serialized: dict = schema.dump(self)
if strip_none_values:
serialized = _strip_none_from_dict(serialized)
return serialized['data']
@classmethod
def load_from_file(cls: Type[ClassT], lock_file: Optional[Path] = None) -> ClassT:
"""Load and validate a version lock file."""
path: Path = getattr(cls, 'file_path', lock_file)
contents = json.loads(path.read_text())
loaded = cls.from_dict(dict(data=contents))
return loaded
def sha256(self) -> definitions.Sha256:
"""Get the sha256 hash of the version lock contents."""
contents = self.to_dict()
return dict_hash(contents)
def save_to_file(self, lock_file: Optional[Path] = None):
"""Save and validate a version lock file."""
path: Path = lock_file or getattr(self, 'file_path', None)
assert path, 'No path passed or set'
contents = self.to_dict()
path.write_text(json.dumps(contents, indent=2, sort_keys=True))
class StackCompatMixin:
"""Mixin to restrict schema compatibility to defined stack versions."""
@validates_schema
def validate_field_compatibility(self, data: dict, **kwargs):
"""Verify stack-specific fields are properly applied to schema."""
package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
schema_fields = getattr(self, 'fields', {})
incompatible = get_incompatible_fields(list(schema_fields.values()), package_version)
if not incompatible:
return
package_version = load_current_package_version()
for field, bounds in incompatible.items():
min_compat, max_compat = bounds
if data.get(field) is not None:
raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, '
f'min compatibility: {min_compat}, max compatibility: {max_compat}')
class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema):
# Patch marshmallow-jsonschema to support marshmallow-dataclass[union]
def _get_schema_for_field(self, obj, field):
"""Patch marshmallow_jsonschema.base.JSONSchema to support marshmallow-dataclass[union]."""
if isinstance(field, marshmallow_fields.Raw) and field.allow_none and not field.validate:
# raw fields shouldn't be type string but type any. bug in marshmallow_dataclass:__init__.py:
# if typ is Any:
# metadata.setdefault("allow_none", True)
# return marshmallow.fields.Raw(**metadata)
return {"type": ["string", "number", "object", "array", "boolean", "null"]}
if isinstance(field, marshmallow_dataclass.union_field.Union):
# convert to marshmallow_union.Union
field = marshmallow_union.Union([subfield for _, subfield in field.union_fields],
metadata=field.metadata,
required=field.required, name=field.name,
parent=field.parent, root=field.root, error_messages=field.error_messages,
default_error_messages=field.default_error_messages, default=field.default,
allow_none=field.allow_none)
return super()._get_schema_for_field(obj, field)