diff --git a/bin/release b/bin/release index 0edce3b489f7..1759c65fe3d4 100755 --- a/bin/release +++ b/bin/release @@ -5,3 +5,6 @@ set -eo pipefail # Migrate our database to the latest revision. python -m warehouse db upgrade head + +# Insert/upgrade malware checks. +python -m warehouse malware sync-checks diff --git a/tests/unit/cli/test_malware.py b/tests/unit/cli/test_malware.py new file mode 100644 index 000000000000..69613bf4ace1 --- /dev/null +++ b/tests/unit/cli/test_malware.py @@ -0,0 +1,36 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend + +from warehouse.cli.malware import sync_checks +from warehouse.malware.tasks import sync_checks as _sync_checks + + +class TestCLIMalware: + def test_sync_checks(self, cli): + request = pretend.stub() + task = pretend.stub( + get_request=pretend.call_recorder(lambda *a, **kw: request), + run=pretend.call_recorder(lambda *a, **kw: None), + ) + config = pretend.stub(task=pretend.call_recorder(lambda *a, **kw: task)) + + result = cli.invoke(sync_checks, obj=config) + + assert result.exit_code == 0 + assert config.task.calls == [ + pretend.call(_sync_checks), + pretend.call(_sync_checks), + ] + assert task.get_request.calls == [pretend.call()] + assert task.run.calls == [pretend.call(request)] diff --git a/tests/unit/malware/test_checks.py b/tests/unit/malware/test_checks.py new file mode 100644 index 000000000000..2ce63c624965 --- /dev/null +++ b/tests/unit/malware/test_checks.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import warehouse.malware.checks as checks + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.utils import get_check_fields + + +def test_checks_subclass_base(): + checks_from_module = inspect.getmembers(checks, inspect.isclass) + + subclasses_of_malware_base = { + cls.__name__: cls for cls in MalwareCheckBase.__subclasses__() + } + + assert len(checks_from_module) == len(subclasses_of_malware_base) + + for check_name, check in checks_from_module: + assert subclasses_of_malware_base[check_name] == check + + +def test_checks_fields(): + checks_from_module = inspect.getmembers(checks, inspect.isclass) + + for check_name, check in checks_from_module: + elems = inspect.getmembers(check, lambda a: not (inspect.isroutine(a))) + inspection_fields = {"name": check_name} + for elem_name, value in elems: + if not elem_name.startswith("__"): + inspection_fields[elem_name] = value + fields = get_check_fields(check) + + assert inspection_fields == fields diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 38f79b201ba2..1057af6855a5 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -18,68 +18,240 @@ import warehouse.malware.checks as checks -from warehouse.malware.models import MalwareVerdict -from warehouse.malware.tasks import run_check +from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict +from warehouse.malware.tasks import run_check, sync_checks from ...common.db.malware import MalwareCheckFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory -def test_run_check(monkeypatch, db_request): - project = ProjectFactory.create(name="foo") - release = ReleaseFactory.create(project=project) - file0 = FileFactory.create(release=release, filename="foo.bar") - MalwareCheckFactory.create(name="ExampleCheck", state="enabled") +class TestRunCheck: + def test_success(self, monkeypatch, db_request): + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) - task = pretend.stub() - run_check(task, db_request, "ExampleCheck", file0.id) - assert db_request.db.query(MalwareVerdict).one() + task = pretend.stub() + run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() + def test_missing_check_id(self, monkeypatch, db_session): + exc = NoResultFound("No row was found for one()") -def test_run_check_missing_check_id(monkeypatch, db_session): - exc = NoResultFound("No row was found for one()") + class FakeMalwareCheck: + def __init__(self, db): + raise exc - class FakeMalwareCheck: - def __init__(self, db): - raise exc + checks.FakeMalwareCheck = FakeMalwareCheck - class Task: - @staticmethod - @pretend.call_recorder - def retry(exc): - raise celery.exceptions.Retry + class Task: + @staticmethod + @pretend.call_recorder + def retry(exc): + raise celery.exceptions.Retry - task = Task() + task = Task() - checks.FakeMalwareCheck = FakeMalwareCheck + request = pretend.stub( + db=db_session, + log=pretend.stub( + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), + ) + + with pytest.raises(celery.exceptions.Retry): + run_check( + task, + request, + "FakeMalwareCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) + + assert request.log.error.calls == [ + pretend.call( + "Error executing check %s: %s", + "FakeMalwareCheck", + "No row was found for one()", + ) + ] + + assert task.retry.calls == [pretend.call(exc=exc)] + + del checks.FakeMalwareCheck + + def test_missing_check(self, db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) + + +class TestSyncChecks: + def test_no_updates(self, db_session): + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.disabled + ) + + task = pretend.stub() + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + pretend.call("ExampleCheck is unmodified."), + ] - request = pretend.stub( - db=db_session, - log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),), + @pytest.mark.parametrize( + ("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled] ) + def test_upgrade_check(self, monkeypatch, db_session, final_state): + MalwareCheckFactory.create(name="ExampleCheck", state=final_state) + + class ExampleCheck: + version = 2 + short_description = "This is a short description." + long_description = "This is a longer description." + check_type = "scheduled" + + monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck) + + task = pretend.stub() + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + pretend.call("Updating existing ExampleCheck."), + ] + db_checks = ( + db_session.query(MalwareCheck) + .filter(MalwareCheck.name == "ExampleCheck") + .all() + ) + + assert len(db_checks) == 2 - with pytest.raises(celery.exceptions.Retry): - run_check( - task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + if final_state == MalwareCheckState.disabled: + assert ( + db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled + ) + + else: + for c in db_checks: + if c.state == final_state: + assert c.version == 2 + else: + assert c.version == 1 + + def test_one_new_check(self, db_session): + task = pretend.stub() + + class FakeMalwareCheck: + version = 1 + short_description = "This is a short description." + long_description = "This is a longer description." + check_type = "scheduled" + + checks.FakeMalwareCheck = FakeMalwareCheck + + request = pretend.stub( + db=db_session, + log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.evaluation + ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("2 malware checks found in codebase."), + pretend.call("ExampleCheck is unmodified."), + pretend.call("Adding new FakeMalwareCheck to the database."), + ] + assert db_session.query(MalwareCheck).count() == 2 + + new_check = ( + db_session.query(MalwareCheck) + .filter(MalwareCheck.name == "FakeMalwareCheck") + .one() ) - assert request.log.error.calls == [ - pretend.call( - "Error executing check %s: %s", - "FakeMalwareCheck", - "No row was found for one()", + assert new_check.state == MalwareCheckState.disabled + + del checks.FakeMalwareCheck + + def test_too_many_db_checks(self, db_session): + task = pretend.stub() + + MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="AnotherCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="AnotherCheck", state=MalwareCheckState.evaluation, version=2 + ) + + request = pretend.stub( + db=db_session, + log=pretend.stub( + info=pretend.call_recorder(lambda *args, **kwargs: None), + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), ) - ] - assert task.retry.calls == [pretend.call(exc=exc)] + with pytest.raises(Exception): + sync_checks(task, request) + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + ] -def test_run_check_missing_check(db_request): - task = pretend.stub() - with pytest.raises(AttributeError): - run_check( - task, - db_request, - "DoesNotExistCheck", - "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + assert request.log.error.calls == [ + pretend.call( + "Found 2 active checks in the db, but only 1 checks in code. Please \ +manually move superfluous checks to the wiped_out state in the check admin: \ +AnotherCheck" + ), + ] + + def test_only_wiped_out(self, db_session): + task = pretend.stub() + MalwareCheckFactory.create( + name="ExampleCheck", state=MalwareCheckState.wiped_out + ) + request = pretend.stub( + db=db_session, + log=pretend.stub( + info=pretend.call_recorder(lambda *args, **kwargs: None), + error=pretend.call_recorder(lambda *args, **kwargs: None), + ), ) + + sync_checks(task, request) + + assert request.log.info.calls == [ + pretend.call("1 malware checks found in codebase."), + ] + + assert request.log.error.calls == [ + pretend.call( + "ExampleCheck is wiped_out and cannot be synced. Please remove check \ +from codebase." + ), + ] diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py index b995147f6872..0f6ea532debf 100644 --- a/tests/unit/malware/test_utils.py +++ b/tests/unit/malware/test_utils.py @@ -12,36 +12,64 @@ from collections import defaultdict +import pytest + from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.utils import get_enabled_checks +from warehouse.malware.utils import get_check_fields, get_enabled_checks from ...common.db.malware import MalwareCheckFactory -def test_get_enabled_checks(db_session): - check = MalwareCheckFactory.create( - state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook - ) - result = defaultdict(list) - result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) - assert checks == result +class TestGetEnabledChecks: + def test_one(self, db_session): + check = MalwareCheckFactory.create( + state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook + ) + result = defaultdict(list) + result[check.hooked_object.value].append(check.name) + checks = get_enabled_checks(db_session) + assert checks == result + + def test_many(self, db_session): + result = defaultdict(list) + for i in range(10): + check = MalwareCheckFactory.create() + if ( + check.state == MalwareCheckState.enabled + and check.check_type == MalwareCheckType.event_hook + ): + result[check.hooked_object.value].append(check.name) + + checks = get_enabled_checks(db_session) + assert checks == result + + def test_none(self, db_session): + checks = get_enabled_checks(db_session) + assert checks == defaultdict(list) -def test_get_enabled_checks_many(db_session): - result = defaultdict(list) - for i in range(10): - check = MalwareCheckFactory.create() - if ( - check.state == MalwareCheckState.enabled - and check.check_type == MalwareCheckType.event_hook - ): - result[check.hooked_object.value].append(check.name) +class TestGetCheckFields: + def test_success(self): + class MySampleCheck: + version = 6 + foo = "bar" + short_description = "This is the description" + long_description = "This is the description" + check_type = "scheduled" - checks = get_enabled_checks(db_session) - assert checks == result + result = get_check_fields(MySampleCheck) + assert result == { + "name": "MySampleCheck", + "version": 6, + "short_description": "This is the description", + "long_description": "This is the description", + "check_type": "scheduled", + } + def test_failure(self): + class MySampleCheck: + version = 1 + status = True -def test_get_enabled_checks_none(db_session): - checks = get_enabled_checks(db_session) - assert checks == defaultdict(list) + with pytest.raises(AttributeError): + get_check_fields(MySampleCheck) diff --git a/warehouse/cli/malware.py b/warehouse/cli/malware.py new file mode 100644 index 000000000000..ad08f557ebaf --- /dev/null +++ b/warehouse/cli/malware.py @@ -0,0 +1,34 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import click + +from warehouse.cli import warehouse +from warehouse.malware.tasks import sync_checks as _sync_checks + + +@warehouse.group() # pragma: no branch +def malware(): + """ + Manage the Warehouse Malware Checks. + """ + + +@malware.command() +@click.pass_obj +def sync_checks(config): + """ + Sync the Warehouse database with the malware checks in malware/checks. + """ + + request = config.task(_sync_checks).get_request() + config.task(_sync_checks).run(request) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py index c55748cdaf44..519edecfd4df 100644 --- a/warehouse/malware/checks/example.py +++ b/warehouse/malware/checks/example.py @@ -17,17 +17,15 @@ VerdictConfidence, ) -VERSION = 1 -SHORT_DESCRIPTION = "An example hook-based check" -LONG_DESCRIPTION = """The purpose of this check is to demonstrate the implementation \ -of a hook-based check. This check will generate verdicts if enabled.""" - class ExampleCheck(MalwareCheckBase): - version = VERSION - short_description = SHORT_DESCRIPTION - long_description = LONG_DESCRIPTION + version = 1 + short_description = "An example hook-based check" + long_description = """The purpose of this check is to demonstrate the \ +implementation of a hook-based check. This check will generate verdicts if enabled.""" + check_type = "event_hook" + hooked_object = "File" def __init__(self, db): super().__init__(db) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py index ade9f7a9ae73..1548d28e66d7 100644 --- a/warehouse/malware/tasks.py +++ b/warehouse/malware/tasks.py @@ -10,9 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import warehouse.malware.checks as checks +from warehouse.malware.models import MalwareCheck, MalwareCheckState +from warehouse.malware.utils import get_check_fields from warehouse.tasks import task @@ -24,3 +27,62 @@ def run_check(task, request, check_name, obj_id): except Exception as exc: request.log.error("Error executing check %s: %s", check_name, str(exc)) raise task.retry(exc=exc) + + +@task(bind=True, ignore_result=True, acks_late=True) +def sync_checks(task, request): + code_checks = inspect.getmembers(checks, inspect.isclass) + request.log.info("%d malware checks found in codebase." % len(code_checks)) + + all_checks = request.db.query(MalwareCheck).all() + active_checks = {} + wiped_out_checks = {} + for check in all_checks: + if not check.is_stale: + if check.state == MalwareCheckState.wiped_out: + wiped_out_checks[check.name] = check + else: + active_checks[check.name] = check + + if len(active_checks) > len(code_checks): + code_check_names = set([name for name, cls in code_checks]) + missing = ", ".join(set(active_checks.keys()) - code_check_names) + request.log.error( + "Found %d active checks in the db, but only %d checks in \ +code. Please manually move superfluous checks to the wiped_out state \ +in the check admin: %s" + % (len(active_checks), len(code_checks), missing) + ) + raise Exception("Mismatch between number of db checks and code checks.") + + for check_name, check_class in code_checks: + check = getattr(checks, check_name) + + if wiped_out_checks.get(check_name): + request.log.error( + "%s is wiped_out and cannot be synced. Please remove check from \ +codebase." + % check_name + ) + continue + + db_check = active_checks.get(check_name) + if db_check: + if check.version == db_check.version: + request.log.info("%s is unmodified." % check_name) + continue + + request.log.info("Updating existing %s." % check_name) + fields = get_check_fields(check) + + # Migrate the check state to the newest check. + # Then mark the old check state as disabled. + if db_check.state != MalwareCheckState.disabled: + fields["state"] = db_check.state.value + db_check.state = MalwareCheckState.disabled + + request.db.add(MalwareCheck(**fields)) + else: + request.log.info("Adding new %s to the database." % check_name) + fields = get_check_fields(check) + request.db.add(MalwareCheck(**fields)) diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index b0a5bf5b49e7..23af5b0a578e 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -27,6 +27,18 @@ def valid_check_types(): return set([t.value for t in MalwareCheckObjectType]) +def get_check_fields(check): + result = {"name": check.__name__} + required_fields = ["short_description", "long_description", "version", "check_type"] + for field in required_fields: + result[field] = getattr(check, field) + + if result["check_type"] == "event_hook": + result["hooked_object"] = check.hooked_object + + return result + + def get_enabled_checks(session): checks = ( session.query(MalwareCheck.name, MalwareCheck.hooked_object)