diff --git a/detection_rules/kbwrap.py b/detection_rules/kbwrap.py index dec19acf4aa..d3b9c593f20 100644 --- a/detection_rules/kbwrap.py +++ b/detection_rules/kbwrap.py @@ -52,9 +52,7 @@ def kibana_group(ctx: click.Context, **kibana_kwargs): @click.pass_context def upload_rule(ctx, toml_files): """Upload a list of rule .toml files to Kibana.""" - from uuid import uuid4 from .packaging import manage_versions - from .schemas import downgrade kibana = ctx.obj['kibana'] file_lookup = load_rule_files(paths=toml_files) @@ -68,11 +66,8 @@ def upload_rule(ctx, toml_files): api_payloads = [] for rule in rules: - payload = rule.contents.copy() - meta = payload.setdefault("meta", {}) - meta["original"] = dict(id=rule.id, **rule.metadata) - payload["rule_id"] = str(uuid4()) - payload = downgrade(payload, kibana.version) + payload = rule.get_payload(include_version=True, replace_id=True, embed_metadata=True, + target_version=kibana.version) rule = RuleResource(payload) api_payloads.append(rule) diff --git a/detection_rules/main.py b/detection_rules/main.py index 06c83a8882f..5f281407ad2 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -150,7 +150,7 @@ def view_rule(ctx, rule_id, rule_file, api_format): client_error('Unknown format!') click.echo(toml_write(rule.rule_format()) if not api_format else - json.dumps(rule.contents, indent=2, sort_keys=True)) + json.dumps(rule.get_payload(), indent=2, sort_keys=True)) return rule diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index 064c962a49e..5a025c7127d 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -14,7 +14,7 @@ import click from . import rule_loader -from .misc import JS_LICENSE +from .misc import JS_LICENSE, cached from .rule import Rule # noqa: F401 from .utils import get_path, get_etc_path, load_etc_dump, save_etc_dump @@ -49,14 +49,19 @@ def filter_rule(rule: Rule, config_filter: dict, exclude_fields: dict) -> bool: return True +@cached +def load_versions(current_versions: dict = None): + """Load the versions file.""" + return current_versions or load_etc_dump('version.lock.json') + + def manage_versions(rules: list, deprecated_rules: list = None, current_versions: dict = None, exclude_version_update=False, add_new=True, save_changes=False, verbose=True) -> (list, list, list): """Update the contents of the version.lock file and optionally save changes.""" new_rules = {} changed_rules = [] - if current_versions is None: - current_versions = load_etc_dump('version.lock.json') + current_versions = load_versions(current_versions) for rule in rules: # it is a new rule, so add it if specified, and add an initial version to the rule @@ -210,7 +215,7 @@ def get_consolidated(self, as_api=True): """Get a consolidated package of the rules in a single file.""" full_package = [] for rule in self.rules: - full_package.append(rule.contents if as_api else rule.rule_format()) + full_package.append(rule.get_payload() if as_api else rule.rule_format()) return json.dumps(full_package, sort_keys=True) diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 000e8ef9465..c7bb5c207ef 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -164,6 +164,31 @@ def set_metadata(self, contents): defaults.update(metadata) return defaults + @staticmethod + def _add_empty_attack_technique(contents: dict = None): + """Add empty array to ATT&CK technique threat mapping.""" + threat = contents.get('threat', []) + + if threat: + new_threat = [] + + for entry in contents.get('threat', []): + if 'technique' not in entry: + new_entry = entry.copy() + new_entry['technique'] = [] + new_threat.append(new_entry) + else: + new_threat.append(entry) + + contents['threat'] = new_threat + + return contents + + def _run_build_time_transforms(self, contents): + """Apply changes to rules at build time for rule payload.""" + self._add_empty_attack_technique(contents) + return contents + def rule_format(self, formatted_query=True): """Get the contents in rule format.""" contents = self.contents.copy() @@ -299,7 +324,7 @@ def save(self, new_path=None, as_rule=False, verbose=False): toml_write(self.rule_format(), path) else: with open(path, 'w', newline='\n') as f: - json.dump(self.contents, f, sort_keys=True, indent=2) + json.dump(self.get_payload(), f, sort_keys=True, indent=2) f.write('\n') if verbose: @@ -316,7 +341,42 @@ def dict_hash(cls, contents, versioned=True): def get_hash(self): """Get a standardized hash of a rule to consistently check for changes.""" - return self.dict_hash(self.contents) + return self.dict_hash(self.get_payload()) + + def get_version(self): + """Get the version of the rule.""" + from .packaging import load_versions + + rules_versions = load_versions + + if self.id in rules_versions: + version_info = rules_versions[self.id] + version = version_info['version'] + return version + 1 if self.get_hash() != version_info['sha256'] else version + else: + return 1 + + def get_payload(self, include_version=False, replace_id=False, embed_metadata=False, target_version=None): + """Get rule as uploadable/API-compatible payload.""" + from uuid import uuid4 + from .schemas import downgrade + + payload = self._run_build_time_transforms(self.contents.copy()) + + if include_version: + payload['version'] = self.get_version() + + if embed_metadata: + meta = payload.setdefault("meta", {}) + meta["original"] = dict(id=self.id, **self.metadata) + + if replace_id: + payload["rule_id"] = str(uuid4()) + + if target_version: + payload = downgrade(payload, target_version) + + return payload @classmethod def build(cls, path=None, rule_type=None, required_only=True, save=True, verbose=False, **kwargs): diff --git a/detection_rules/schemas/v7_8.py b/detection_rules/schemas/v7_8.py index e1614981f0a..db1534e27d8 100644 --- a/detection_rules/schemas/v7_8.py +++ b/detection_rules/schemas/v7_8.py @@ -72,7 +72,7 @@ class ThreatTactic(jsl.Document): class ThreatTechnique(jsl.Document): id = jsl.StringField(enum=technique_id_list, required=True) name = jsl.StringField(required=True) - reference = jsl.StringField(MITRE_URL_PATTERN.format(type='techniques')) + reference = jsl.StringField(MITRE_URL_PATTERN.format(type='techniques'), required=True) framework = jsl.StringField(default='MITRE ATT&CK', required=True) tactic = jsl.DocumentField(ThreatTactic, required=True)