diff --git a/tests/common/checks/__init__.py b/tests/common/checks/__init__.py new file mode 100644 index 000000000000..dfd77b961075 --- /dev/null +++ b/tests/common/checks/__init__.py @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hooked import ExampleHookedCheck # noqa +from .scheduled import ExampleScheduledCheck # noqa diff --git a/warehouse/malware/checks/example.py b/tests/common/checks/hooked.py similarity index 87% rename from warehouse/malware/checks/example.py rename to tests/common/checks/hooked.py index 22b91906ffa3..6bd6f2e512e4 100644 --- a/warehouse/malware/checks/example.py +++ b/tests/common/checks/hooked.py @@ -14,19 +14,19 @@ from warehouse.malware.models import VerdictClassification, VerdictConfidence -class ExampleCheck(MalwareCheckBase): +class ExampleHookedCheck(MalwareCheckBase): version = 1 short_description = "An example hook-based check" - long_description = """The purpose of this check is to demonstrate the \ -implementation of a hook-based check. This check will generate verdicts if enabled.""" + long_description = "The purpose of this check is to test the \ +implementation of a hook-based check. This check will generate verdicts if enabled." check_type = "event_hook" hooked_object = "File" def __init__(self, db): super().__init__(db) - def scan(self, file_id): + def scan(self, file_id=None): self.add_verdict( file_id=file_id, classification=VerdictClassification.benign, diff --git a/tests/common/checks/scheduled.py b/tests/common/checks/scheduled.py new file mode 100644 index 000000000000..128ce102a83b --- /dev/null +++ b/tests/common/checks/scheduled.py @@ -0,0 +1,37 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.models import VerdictClassification, VerdictConfidence +from warehouse.packaging.models import Project + + +class ExampleScheduledCheck(MalwareCheckBase): + + version = 1 + short_description = "An example scheduled check" + long_description = "The purpose of this check is to test the \ +implementation of a scheduled check. This check will generate verdicts if enabled." + check_type = "scheduled" + schedule = {"minute": "0", "hour": "*/8"} + + def __init__(self, db): + super().__init__(db) + + def scan(self): + project = self.db.query(Project).first() + self.add_verdict( + project_id=project.id, + classification=VerdictClassification.benign, + confidence=VerdictConfidence.High, + message="Nothing to see here!", + ) diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index b6e1bf387b90..4e41a0c23865 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -39,6 +39,7 @@ class Meta: long_description = factory.fuzzy.FuzzyText(length=300) check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType)) hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType)) + schedule = {"minute": "*/10"} state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState)) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) diff --git a/tests/unit/malware/test_checks.py b/tests/unit/malware/test_checks.py index 2ce63c624965..9427972ef5e8 100644 --- a/tests/unit/malware/test_checks.py +++ b/tests/unit/malware/test_checks.py @@ -12,26 +12,35 @@ import inspect -import warehouse.malware.checks as checks +import pytest + +import warehouse.malware.checks as prod_checks from warehouse.malware.checks.base import MalwareCheckBase from warehouse.malware.utils import get_check_fields +from ...common import checks as test_checks + def test_checks_subclass_base(): - checks_from_module = inspect.getmembers(checks, inspect.isclass) + prod_checks_from_module = inspect.getmembers(prod_checks, inspect.isclass) + test_checks_from_module = inspect.getmembers(test_checks, inspect.isclass) + all_checks = prod_checks_from_module + test_checks_from_module subclasses_of_malware_base = { cls.__name__: cls for cls in MalwareCheckBase.__subclasses__() } - assert len(checks_from_module) == len(subclasses_of_malware_base) + assert len(all_checks) == len(subclasses_of_malware_base) - for check_name, check in checks_from_module: + for check_name, check in all_checks: assert subclasses_of_malware_base[check_name] == check -def test_checks_fields(): +@pytest.mark.parametrize( + ("checks"), [prod_checks, test_checks], +) +def test_checks_fields(checks): checks_from_module = inspect.getmembers(checks, inspect.isclass) for check_name, check in checks_from_module: diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py index 5db87f2c396e..4d2888aef6a1 100644 --- a/tests/unit/malware/test_init.py +++ b/tests/unit/malware/test_init.py @@ -18,15 +18,16 @@ from warehouse.malware import utils from warehouse.malware.interfaces import IMalwareCheckService +from ...common import checks as test_checks from ...common.db.accounts import UserFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory def test_determine_malware_checks_no_checks(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): return defaultdict(list) - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -39,13 +40,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_nothing_new(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -58,13 +59,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_unsupported_object(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) user = UserFactory.create() @@ -75,13 +76,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_file_only(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -95,13 +96,13 @@ def get_enabled_checks(session): def test_determine_malware_checks_file_and_release(monkeypatch, db_request): - def get_enabled_checks(session): + def get_enabled_hooked_checks(session): result = defaultdict(list) result["File"] = ["Check1", "Check2"] result["Release"] = ["Check3"] return result - monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + monkeypatch.setattr(utils, "get_enabled_hooked_checks", get_enabled_hooked_checks) project = ProjectFactory.create(name="foo") release = ReleaseFactory.create(project=project) @@ -149,7 +150,9 @@ def test_enqueue_malware_checks_no_checks(app_config): assert "warehouse.malware.checks" not in session.info -def test_includeme(): +def test_includeme(monkeypatch): + monkeypatch.setattr(malware, "checks", test_checks) + malware_check_class = pretend.stub( create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub()) ) diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py index 91ce94966438..8fc427e35010 100644 --- a/tests/unit/malware/test_tasks.py +++ b/tests/unit/malware/test_tasks.py @@ -16,42 +16,47 @@ from sqlalchemy.orm.exc import NoResultFound -import warehouse.malware.checks as checks - +from warehouse.malware import tasks from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict -from warehouse.malware.tasks import backfill, remove_verdicts, run_check, sync_checks +from ...common import checks as test_checks from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory class TestRunCheck: - def test_success(self, db_request): + def test_success(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) file0 = FileFactory.create() - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) task = pretend.stub() - run_check(task, db_request, "ExampleCheck", file0.id) + tasks.run_check(task, db_request, "ExampleHookedCheck", file0.id) assert db_request.db.query(MalwareVerdict).one() - def test_disabled_check(self, db_request): + def test_disabled_check(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.disabled ) + task = pretend.stub() with pytest.raises(NoResultFound): - run_check( + tasks.run_check( task, db_request, - "ExampleCheck", + "ExampleHookedCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7", ) - def test_missing_check(self, db_request): + def test_missing_check(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() with pytest.raises(AttributeError): - run_check( + tasks.run_check( task, db_request, "DoesNotExistCheck", @@ -59,16 +64,17 @@ def test_missing_check(self, db_request): ) def test_retry(self, db_session, monkeypatch): - MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.evaluation - ) - exc = Exception("Scan failed") def scan(self, file_id): raise exc - monkeypatch.setattr(checks.ExampleCheck, "scan", scan) + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "scan", scan) + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.evaluation + ) task = pretend.stub( retry=pretend.call_recorder(pretend.raiser(celery.exceptions.Retry)), @@ -79,32 +85,40 @@ def scan(self, file_id): ) with pytest.raises(celery.exceptions.Retry): - run_check( - task, request, "ExampleCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + tasks.run_check( + task, + request, + "ExampleHookedCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", ) assert request.log.error.calls == [ - pretend.call("Error executing check ExampleCheck: Scan failed") + pretend.call("Error executing check ExampleHookedCheck: Scan failed") ] assert task.retry.calls == [pretend.call(exc=exc)] class TestBackfill: - def test_invalid_check_name(self, db_request): + def test_invalid_check_name(self, db_request, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) task = pretend.stub() with pytest.raises(AttributeError): - backfill(task, db_request, "DoesNotExist", 1) + tasks.backfill(task, db_request, "DoesNotExist", 1) @pytest.mark.parametrize( ("num_objects", "num_runs"), [(11, 1), (11, 11), (101, 90)], ) - def test_run(self, db_session, num_objects, num_runs): + def test_run(self, db_session, num_objects, num_runs, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) files = [] for i in range(num_objects): files.append(FileFactory.create()) - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) + enqueue_recorder = pretend.stub( delay=pretend.call_recorder(lambda *a, **kw: None) ) @@ -116,21 +130,29 @@ def test_run(self, db_session, num_objects, num_runs): task=task, ) - backfill(task, request, "ExampleCheck", num_runs) + tasks.backfill(task, request, "ExampleHookedCheck", num_runs) assert request.log.info.calls == [ pretend.call("Running backfill on %d Files." % num_runs), ] assert enqueue_recorder.delay.calls == [ - pretend.call("ExampleCheck", files[i].id) for i in range(num_runs) + pretend.call("ExampleHookedCheck", files[i].id) for i in range(num_runs) ] class TestSyncChecks: - def test_no_updates(self, db_session): + def test_no_updates(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleScheduledCheck, "version", 2) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.disabled + ) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.disabled + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.enabled, version=2 ) task = pretend.stub() @@ -140,26 +162,25 @@ def test_no_updates(self, db_session): log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), - pretend.call("ExampleCheck is unmodified."), + pretend.call("2 malware checks found in codebase."), + pretend.call("ExampleHookedCheck is unmodified."), + pretend.call("ExampleScheduledCheck is unmodified."), ] @pytest.mark.parametrize( ("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled] ) def test_upgrade_check(self, monkeypatch, db_session, final_state): - MalwareCheckFactory.create(name="ExampleCheck", state=final_state) + monkeypatch.setattr(tasks, "checks", test_checks) + monkeypatch.setattr(tasks.checks.ExampleHookedCheck, "version", 2) - class ExampleCheck: - version = 2 - short_description = "This is a short description." - long_description = "This is a longer description." - check_type = "scheduled" - - monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck) + MalwareCheckFactory.create(name="ExampleHookedCheck", state=final_state) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) task = pretend.stub() request = pretend.stub( @@ -167,15 +188,16 @@ class ExampleCheck: log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), - pretend.call("Updating existing ExampleCheck."), + pretend.call("2 malware checks found in codebase."), + pretend.call("Updating existing ExampleHookedCheck."), + pretend.call("ExampleScheduledCheck is unmodified."), ] db_checks = ( db_session.query(MalwareCheck) - .filter(MalwareCheck.name == "ExampleCheck") + .filter(MalwareCheck.name == "ExampleHookedCheck") .all() ) @@ -193,7 +215,16 @@ class ExampleCheck: else: assert c.version == 1 - def test_one_new_check(self, db_session): + def test_one_new_check(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.disabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.disabled + ) + task = pretend.stub() class FakeMalwareCheck: @@ -201,26 +232,24 @@ class FakeMalwareCheck: short_description = "This is a short description." long_description = "This is a longer description." check_type = "scheduled" + schedule = {"minute": "0", "hour": "*/8"} - checks.FakeMalwareCheck = FakeMalwareCheck + tasks.checks.FakeMalwareCheck = FakeMalwareCheck request = pretend.stub( db=db_session, log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) - MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.evaluation - ) - - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("2 malware checks found in codebase."), - pretend.call("ExampleCheck is unmodified."), + pretend.call("3 malware checks found in codebase."), + pretend.call("ExampleHookedCheck is unmodified."), + pretend.call("ExampleScheduledCheck is unmodified."), pretend.call("Adding new FakeMalwareCheck to the database."), ] - assert db_session.query(MalwareCheck).count() == 2 + assert db_session.query(MalwareCheck).count() == 3 new_check = ( db_session.query(MalwareCheck) @@ -230,19 +259,23 @@ class FakeMalwareCheck: assert new_check.state == MalwareCheckState.disabled - del checks.FakeMalwareCheck + del tasks.checks.FakeMalwareCheck - def test_too_many_db_checks(self, db_session): - task = pretend.stub() + def test_too_many_db_checks(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) - MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled) MalwareCheckFactory.create( - name="AnotherCheck", state=MalwareCheckState.disabled + name="ExampleHookedCheck", state=MalwareCheckState.enabled + ) + MalwareCheckFactory.create( + name="ExampleScheduledCheck", state=MalwareCheckState.enabled ) MalwareCheckFactory.create( name="AnotherCheck", state=MalwareCheckState.evaluation, version=2 ) + task = pretend.stub() + request = pretend.stub( db=db_session, log=pretend.stub( @@ -252,25 +285,30 @@ def test_too_many_db_checks(self, db_session): ) with pytest.raises(Exception): - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase."), ] assert request.log.error.calls == [ pretend.call( - "Found 2 active checks in the db, but only 1 checks in code. Please \ + "Found 3 active checks in the db, but only 2 checks in code. Please \ manually move superfluous checks to the wiped_out state in the check admin: \ AnotherCheck" ), ] - def test_only_wiped_out(self, db_session): - task = pretend.stub() + def test_only_wiped_out(self, db_session, monkeypatch): + monkeypatch.setattr(tasks, "checks", test_checks) + MalwareCheckFactory.create( + name="ExampleHookedCheck", state=MalwareCheckState.wiped_out + ) MalwareCheckFactory.create( - name="ExampleCheck", state=MalwareCheckState.wiped_out + name="ExampleScheduledCheck", state=MalwareCheckState.wiped_out ) + + task = pretend.stub() request = pretend.stub( db=db_session, log=pretend.stub( @@ -279,15 +317,19 @@ def test_only_wiped_out(self, db_session): ), ) - sync_checks(task, request) + tasks.sync_checks(task, request) assert request.log.info.calls == [ - pretend.call("1 malware checks found in codebase."), + pretend.call("2 malware checks found in codebase."), ] assert request.log.error.calls == [ pretend.call( - "ExampleCheck is wiped_out and cannot be synced. Please remove check \ + "ExampleHookedCheck is wiped_out and cannot be synced. Please remove check \ +from codebase." + ), + pretend.call( + "ExampleScheduledCheck is wiped_out and cannot be synced. Please remove check \ from codebase." ), ] @@ -302,7 +344,7 @@ def test_no_verdicts(self, db_session): log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),), ) task = pretend.stub() - removed = remove_verdicts(task, request, check.name) + removed = tasks.remove_verdicts(task, request, check.name) assert request.log.info.calls == [ pretend.call( @@ -338,7 +380,7 @@ def test_many_verdicts(self, db_session, check_with_verdicts): wiped_out_check = check0 num_verdicts = 0 - removed = remove_verdicts(task, request, wiped_out_check.name) + removed = tasks.remove_verdicts(task, request, wiped_out_check.name) assert request.log.info.calls == [ pretend.call( diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py index 0f6ea532debf..c3cc7093ec6a 100644 --- a/tests/unit/malware/test_utils.py +++ b/tests/unit/malware/test_utils.py @@ -15,8 +15,9 @@ import pytest from warehouse.malware.models import MalwareCheckState, MalwareCheckType -from warehouse.malware.utils import get_check_fields, get_enabled_checks +from warehouse.malware.utils import get_check_fields, get_enabled_hooked_checks +from ...common.checks import ExampleHookedCheck, ExampleScheduledCheck from ...common.db.malware import MalwareCheckFactory @@ -27,7 +28,7 @@ def test_one(self, db_session): ) result = defaultdict(list) result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == result def test_many(self, db_session): @@ -40,36 +41,49 @@ def test_many(self, db_session): ): result[check.hooked_object.value].append(check.name) - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == result def test_none(self, db_session): - checks = get_enabled_checks(db_session) + checks = get_enabled_hooked_checks(db_session) assert checks == defaultdict(list) class TestGetCheckFields: - def test_success(self): - class MySampleCheck: - version = 6 - foo = "bar" - short_description = "This is the description" - long_description = "This is the description" - check_type = "scheduled" + @pytest.mark.parametrize( + ("check", "result"), + [ + ( + ExampleHookedCheck, + { + "name": "ExampleHookedCheck", + "version": 1, + "short_description": "An example hook-based check", + "long_description": "The purpose of this check is to test the \ +implementation of a hook-based check. This check will generate verdicts if enabled.", + "check_type": "event_hook", + "hooked_object": "File", + }, + ), + ( + ExampleScheduledCheck, + { + "name": "ExampleScheduledCheck", + "version": 1, + "short_description": "An example scheduled check", + "long_description": "The purpose of this check is to test the \ +implementation of a scheduled check. This check will generate verdicts if enabled.", + "check_type": "scheduled", + "schedule": {"minute": "0", "hour": "*/8"}, + }, + ), + ], + ) + def test_success(self, check, result): + assert get_check_fields(check) == result - result = get_check_fields(MySampleCheck) - assert result == { - "name": "MySampleCheck", - "version": 6, - "short_description": "This is the description", - "long_description": "This is the description", - "check_type": "scheduled", - } - - def test_failure(self): - class MySampleCheck: - version = 1 - status = True + def test_failure(self, monkeypatch): + monkeypatch.delattr(ExampleScheduledCheck, "schedule") with pytest.raises(AttributeError): - get_check_fields(MySampleCheck) + get_check_fields(ExampleScheduledCheck) diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index ee0e36b808aa..f54a9e89b4f5 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -30,7 +30,7 @@ def determine_malware_checks(config, session, flush_context): return malware_checks = session.info.setdefault("warehouse.malware.checks", set()) - enabled_checks = utils.get_enabled_checks(session) + enabled_checks = utils.get_enabled_hooked_checks(session) for obj in session.new: for check_name in enabled_checks.get(obj.__class__.__name__, []): malware_checks.update([f"{check_name}:{obj.id}"]) diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py index a627b7d18159..164f68b09175 100644 --- a/warehouse/malware/checks/__init__.py +++ b/warehouse/malware/checks/__init__.py @@ -9,5 +9,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .example import ExampleCheck # noqa diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 3e9aa388a701..0464ba0d47ce 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -81,12 +81,14 @@ class MalwareCheck(db.Model): Enum(MalwareCheckType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - # This field contains the object name that check operates on, e.g. + # The object name that hooked-based checks operate on, e.g. # Project, File, Release hooked_object = Column( Enum(MalwareCheckObjectType, values_callable=lambda x: [e.value for e in x]), nullable=True, ) + # The run schedule for schedule-based checks. + schedule = Column(JSONB, nullable=True) state = Column( Enum(MalwareCheckState, values_callable=lambda x: [e.value for e in x]), nullable=False, diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py index 23af5b0a578e..6139c3e248a7 100644 --- a/warehouse/malware/utils.py +++ b/warehouse/malware/utils.py @@ -29,6 +29,7 @@ def valid_check_types(): def get_check_fields(check): result = {"name": check.__name__} + required_fields = ["short_description", "long_description", "version", "check_type"] for field in required_fields: result[field] = getattr(check, field) @@ -36,10 +37,13 @@ def get_check_fields(check): if result["check_type"] == "event_hook": result["hooked_object"] = check.hooked_object + if result["check_type"] == "scheduled": + result["schedule"] = check.schedule + return result -def get_enabled_checks(session): +def get_enabled_hooked_checks(session): checks = ( session.query(MalwareCheck.name, MalwareCheck.hooked_object) .filter(MalwareCheck.check_type == MalwareCheckType.event_hook) diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 6e23aeb243f8..622660fd042f 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -56,6 +56,7 @@ def upgrade(): sa.Column("long_description", sa.Text(), nullable=False), sa.Column("check_type", MalwareCheckTypes, nullable=False), sa.Column("hooked_object", MalwareCheckObjectTypes, nullable=True), + sa.Column("schedule", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column( "state", MalwareCheckStates, server_default="disabled", nullable=False, ),