diff --git a/dev/environment b/dev/environment index e7ae3673787b..ec7eeae6d2f3 100644 --- a/dev/environment +++ b/dev/environment @@ -29,6 +29,9 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService +#TODO: change this to PrinterMalwareCheckService before deploy +MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService + METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog STATUSPAGE_URL=https://2p66nmmycsj3.statuspage.io diff --git a/tests/common/db/malware.py b/tests/common/db/malware.py index 263fa82fa684..7b01dc4723d6 100644 --- a/tests/common/db/malware.py +++ b/tests/common/db/malware.py @@ -15,7 +15,12 @@ import factory import factory.fuzzy -from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType +from warehouse.malware.models import ( + MalwareCheck, + MalwareCheckObjectType, + MalwareCheckState, + MalwareCheckType, +) from .base import WarehouseFactory @@ -29,11 +34,7 @@ class Meta: short_description = factory.fuzzy.FuzzyText(length=80) long_description = factory.fuzzy.FuzzyText(length=300) check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType]) - hook_name = ( - "project:release:file:upload" - if check_type == MalwareCheckType.event_hook - else None - ) + hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType]) state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState]) created = factory.fuzzy.FuzzyNaiveDateTime( datetime.datetime.utcnow() - datetime.timedelta(days=7) diff --git a/tests/conftest.py b/tests/conftest.py index 3623ec94998c..1b3a9b8b93e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -174,6 +174,9 @@ def app_config(database): "files.backend": "warehouse.packaging.services.LocalFileStorage", "docs.backend": "warehouse.packaging.services.LocalFileStorage", "mail.backend": "warehouse.email.services.SMTPEmailSender", + 
"malware_check.backend": ( + "warehouse.malware.services.PrinterMalwareCheckService" + ), "files.url": "http://localhost:7000/", "sessions.secret": "123456", "sessions.url": "redis://localhost:0/", diff --git a/tests/unit/malware/__init__.py b/tests/unit/malware/__init__.py new file mode 100644 index 000000000000..164f68b09175 --- /dev/null +++ b/tests/unit/malware/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/malware/test_init.py b/tests/unit/malware/test_init.py new file mode 100644 index 000000000000..5db87f2c396e --- /dev/null +++ b/tests/unit/malware/test_init.py @@ -0,0 +1,171 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict + +import pretend + +from warehouse import malware +from warehouse.malware import utils +from warehouse.malware.interfaces import IMalwareCheckService + +from ...common.db.accounts import UserFactory +from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory + + +def test_determine_malware_checks_no_checks(monkeypatch, db_request): + def get_enabled_checks(session): + return defaultdict(list) + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={file0, release, project}, dirty={}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info["warehouse.malware.checks"] == set() + + +def test_determine_malware_checks_nothing_new(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={}, dirty={file0, release}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info.get("warehouse.malware.checks") is None + + +def test_determine_malware_checks_unsupported_object(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + user = UserFactory.create() + + session = pretend.stub(info={}, new={user}, 
dirty={}, deleted={}) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info.get("warehouse.malware.checks") is None + + +def test_determine_malware_checks_file_only(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + + session = pretend.stub(info={}, new={file0}, dirty={}, deleted={}) + + checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)]) + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + assert session.info["warehouse.malware.checks"] == checks + + +def test_determine_malware_checks_file_and_release(monkeypatch, db_request): + def get_enabled_checks(session): + result = defaultdict(list) + result["File"] = ["Check1", "Check2"] + result["Release"] = ["Check3"] + return result + + monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks) + + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + file1 = FileFactory.create(release=release, filename="foo.baz") + + session = pretend.stub( + info={}, new={project, release, file0, file1}, dirty={}, deleted={} + ) + + checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)]) + checks.update(["Check%d:%s" % (x, file1.id) for x in range(1, 3)]) + checks.add("Check3:%s" % release.id) + + malware.determine_malware_checks(pretend.stub(), session, pretend.stub()) + + assert session.info["warehouse.malware.checks"] == checks + + +def test_enqueue_malware_checks(app_config): + malware_check = pretend.stub( + run_checks=pretend.call_recorder(lambda 
malware_checks: None) + ) + factory = pretend.call_recorder(lambda ctx, config: malware_check) + app_config.register_service_factory(factory, IMalwareCheckService) + app_config.commit() + session = pretend.stub( + info={ + "warehouse.malware.checks": {"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"} + } + ) + + malware.queue_malware_checks(app_config, session) + + assert factory.calls == [pretend.call(None, app_config)] + assert malware_check.run_checks.calls == [ + pretend.call({"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"}) + ] + assert "warehouse.malware.checks" not in session.info + + +def test_enqueue_malware_checks_no_checks(app_config): + session = pretend.stub(info={}) + malware.queue_malware_checks(app_config, session) + assert "warehouse.malware.checks" not in session.info + + +def test_includeme(): + malware_check_class = pretend.stub( + create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub()) + ) + + config = pretend.stub( + maybe_dotted=lambda dotted: malware_check_class, + register_service_factory=pretend.call_recorder( + lambda factory, iface, name=None: None + ), + registry=pretend.stub( + settings={"malware_check.backend": "TestMalwareCheckService"} + ), + ) + + malware.includeme(config) + + assert config.register_service_factory.calls == [ + pretend.call(malware_check_class.create_service, IMalwareCheckService), + ] diff --git a/tests/unit/malware/test_services.py b/tests/unit/malware/test_services.py new file mode 100644 index 000000000000..7a9cb636f720 --- /dev/null +++ b/tests/unit/malware/test_services.py @@ -0,0 +1,61 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend + +from zope.interface.verify import verifyClass + +from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.services import ( + DatabaseMalwareCheckService, + PrinterMalwareCheckService, +) +from warehouse.malware.tasks import run_check + + +class TestPrinterMalwareCheckService: + def test_verify_service(self): + assert verifyClass(IMalwareCheckService, PrinterMalwareCheckService) + + def test_create_service(self): + request = pretend.stub() + service = PrinterMalwareCheckService.create_service(None, request) + assert service.executor == print + + def test_run_checks(self, capfd): + request = pretend.stub() + service = PrinterMalwareCheckService.create_service(None, request) + checks = ["one", "two", "three"] + service.run_checks(checks) + out, err = capfd.readouterr() + assert out == "one\ntwo\nthree\n" + + +class TestDatabaseMalwareService: + def test_verify_service(self): + assert verifyClass(IMalwareCheckService, DatabaseMalwareCheckService) + + def test_create_service(self, db_request): + _delay = pretend.call_recorder(lambda *args: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + assert service.executor == db_request.task(run_check).delay + + def test_run_checks(self, db_request): + _delay = pretend.call_recorder(lambda *args: None) + db_request.task = lambda x: pretend.stub(delay=_delay) + service = DatabaseMalwareCheckService.create_service(None, db_request) + checks = 
["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"] + service.run_checks(checks) + assert _delay.calls == [ + pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187") + ] diff --git a/tests/unit/malware/test_tasks.py b/tests/unit/malware/test_tasks.py new file mode 100644 index 000000000000..38f79b201ba2 --- /dev/null +++ b/tests/unit/malware/test_tasks.py @@ -0,0 +1,85 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import celery +import pretend +import pytest + +from sqlalchemy.orm.exc import NoResultFound + +import warehouse.malware.checks as checks + +from warehouse.malware.models import MalwareVerdict +from warehouse.malware.tasks import run_check + +from ...common.db.malware import MalwareCheckFactory +from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory + + +def test_run_check(monkeypatch, db_request): + project = ProjectFactory.create(name="foo") + release = ReleaseFactory.create(project=project) + file0 = FileFactory.create(release=release, filename="foo.bar") + MalwareCheckFactory.create(name="ExampleCheck", state="enabled") + + task = pretend.stub() + run_check(task, db_request, "ExampleCheck", file0.id) + assert db_request.db.query(MalwareVerdict).one() + + +def test_run_check_missing_check_id(monkeypatch, db_session): + exc = NoResultFound("No row was found for one()") + + class FakeMalwareCheck: + def __init__(self, db): + raise exc + + class Task: + @staticmethod + @pretend.call_recorder + def retry(exc): + raise 
celery.exceptions.Retry + + task = Task() + + checks.FakeMalwareCheck = FakeMalwareCheck + + request = pretend.stub( + db=db_session, + log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),), + ) + + with pytest.raises(celery.exceptions.Retry): + run_check( + task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7" + ) + + assert request.log.error.calls == [ + pretend.call( + "Error executing check %s: %s", + "FakeMalwareCheck", + "No row was found for one()", + ) + ] + + assert task.retry.calls == [pretend.call(exc=exc)] + + +def test_run_check_missing_check(db_request): + task = pretend.stub() + with pytest.raises(AttributeError): + run_check( + task, + db_request, + "DoesNotExistCheck", + "d03d75d1-2511-4a8b-9759-62294a6fe3a7", + ) diff --git a/tests/unit/malware/test_utils.py b/tests/unit/malware/test_utils.py new file mode 100644 index 000000000000..b995147f6872 --- /dev/null +++ b/tests/unit/malware/test_utils.py @@ -0,0 +1,47 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict + +from warehouse.malware.models import MalwareCheckState, MalwareCheckType +from warehouse.malware.utils import get_enabled_checks + +from ...common.db.malware import MalwareCheckFactory + + +def test_get_enabled_checks(db_session): + check = MalwareCheckFactory.create( + state=MalwareCheckState.enabled, check_type=MalwareCheckType.event_hook + ) + result = defaultdict(list) + result[check.hooked_object.value].append(check.name) + checks = get_enabled_checks(db_session) + assert checks == result + + +def test_get_enabled_checks_many(db_session): + result = defaultdict(list) + for i in range(10): + check = MalwareCheckFactory.create() + if ( + check.state == MalwareCheckState.enabled + and check.check_type == MalwareCheckType.event_hook + ): + result[check.hooked_object.value].append(check.name) + + checks = get_enabled_checks(db_session) + assert checks == result + + +def test_get_enabled_checks_none(db_session): + checks = get_enabled_checks(db_session) + assert checks == defaultdict(list) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 53fab7d8235a..d65a976bc91b 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -317,6 +317,7 @@ def __init__(self): pretend.call(".email"), pretend.call(".accounts"), pretend.call(".macaroons"), + pretend.call(".malware"), pretend.call(".manage"), pretend.call(".packaging"), pretend.call(".redirects"), diff --git a/warehouse/config.py b/warehouse/config.py index 84a49e19d814..fb7efcbe7cd9 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -203,6 +203,7 @@ def configure(settings=None): maybe_set_compound(settings, "mail", "backend", "MAIL_BACKEND") maybe_set_compound(settings, "metrics", "backend", "METRICS_BACKEND") maybe_set_compound(settings, "breached_passwords", "backend", "BREACHED_PASSWORDS") + maybe_set_compound(settings, "malware_check", "backend", "MALWARE_CHECK_BACKEND") # Add the settings we use when the environment is set 
to development. if settings["warehouse.env"] == Environment.development: @@ -389,6 +390,9 @@ def configure(settings=None): # Register support for Macaroon based authentication config.include(".macaroons") + # Register support for malware checks + config.include(".malware") + # Register logged-in views config.include(".manage") diff --git a/warehouse/malware/__init__.py b/warehouse/malware/__init__.py index 164f68b09175..ee0e36b808aa 100644 --- a/warehouse/malware/__init__.py +++ b/warehouse/malware/__init__.py @@ -9,3 +9,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from warehouse import db +from warehouse.malware import utils +from warehouse.malware.interfaces import IMalwareCheckService + + +@db.listens_for(db.Session, "after_flush") +def determine_malware_checks(config, session, flush_context): + if not session.new: + return + + if not any( + [ + obj.__class__.__name__ + for obj in session.new + if obj.__class__.__name__ in utils.valid_check_types() + ] + ): + return + + malware_checks = session.info.setdefault("warehouse.malware.checks", set()) + enabled_checks = utils.get_enabled_checks(session) + for obj in session.new: + for check_name in enabled_checks.get(obj.__class__.__name__, []): + malware_checks.update([f"{check_name}:{obj.id}"]) + + +@db.listens_for(db.Session, "after_commit") +def queue_malware_checks(config, session): + + malware_checks = session.info.pop("warehouse.malware.checks", set()) + if not malware_checks: + return + + malware_check_factory = config.find_service_factory(IMalwareCheckService) + + malware_check = malware_check_factory(None, config) + malware_check.run_checks(malware_checks) + + +def includeme(config): + malware_check_class = config.maybe_dotted( + config.registry.settings["malware_check.backend"] + ) + # Register the malware check service + config.register_service_factory( + 
malware_check_class.create_service, IMalwareCheckService + ) diff --git a/warehouse/malware/checks/__init__.py b/warehouse/malware/checks/__init__.py new file mode 100644 index 000000000000..a627b7d18159 --- /dev/null +++ b/warehouse/malware/checks/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .example import ExampleCheck # noqa diff --git a/warehouse/malware/checks/base.py b/warehouse/malware/checks/base.py new file mode 100644 index 000000000000..b5102bfe6f71 --- /dev/null +++ b/warehouse/malware/checks/base.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.models import MalwareCheck, MalwareCheckState + + +class MalwareCheckBase: + def __init__(self, db): + self.db = db + self._name = self.__class__.__name__ + self._load_check() + + def run(self, obj_id): + """ + Executes the check. + """ + + def backfill(self, sample=1): + """ + Runs the check across all historical data in PyPI. 
The sample value represents + the percentage of files to run the check against. By default, it will run the + backfill on the entire corpus. + """ + + def update(self): + """ + Update the check definition in the database. + """ + + def _load_check(self): + self.id = ( + self.db.query(MalwareCheck.id) + .filter(MalwareCheck.name == self._name) + .filter(MalwareCheck.state == MalwareCheckState.enabled) + .one() + ) diff --git a/warehouse/malware/checks/example.py b/warehouse/malware/checks/example.py new file mode 100644 index 000000000000..c55748cdaf44 --- /dev/null +++ b/warehouse/malware/checks/example.py @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from warehouse.malware.checks.base import MalwareCheckBase +from warehouse.malware.models import ( + MalwareVerdict, + VerdictClassification, + VerdictConfidence, +) + +VERSION = 1 +SHORT_DESCRIPTION = "An example hook-based check" +LONG_DESCRIPTION = """The purpose of this check is to demonstrate the implementation \ +of a hook-based check. 
This check will generate verdicts if enabled.""" + + +class ExampleCheck(MalwareCheckBase): + + version = VERSION + short_description = SHORT_DESCRIPTION + long_description = LONG_DESCRIPTION + + def __init__(self, db): + super().__init__(db) + + def run(self, file_id): + verdict = MalwareVerdict( + check_id=self.id, + file_id=file_id, + classification=VerdictClassification.benign, + confidence=VerdictConfidence.High, + message="Nothing to see here!", + ) + self.db.add(verdict) diff --git a/warehouse/malware/interfaces.py b/warehouse/malware/interfaces.py new file mode 100644 index 000000000000..f179aa374d55 --- /dev/null +++ b/warehouse/malware/interfaces.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from zope.interface import Interface + + +class IMalwareCheckService(Interface): + def create_service(context, request): + """ + Create the service, given the context and request for which it is being + created. 
+ """ + + def run_checks(checks): + """ + Run a given set of Checks + """ diff --git a/warehouse/malware/models.py b/warehouse/malware/models.py index 0c51e3006991..257e7bfa2bd5 100644 --- a/warehouse/malware/models.py +++ b/warehouse/malware/models.py @@ -46,6 +46,13 @@ class MalwareCheckState(enum.Enum): wiped_out = "wiped_out" +class MalwareCheckObjectType(enum.Enum): + + File = "File" + Release = "Release" + Project = "Project" + + class VerdictClassification(enum.Enum): threat = "threat" @@ -74,9 +81,12 @@ class MalwareCheck(db.Model): Enum(MalwareCheckType, values_callable=lambda x: [e.value for e in x]), nullable=False, ) - # This field contains the same content as the ProjectEvent and UserEvent "tag" - # fields. - hook_name = Column(String, nullable=True) + # This field contains the object name that check operates on, e.g. + # Project, File, Release + hooked_object = Column( + Enum(MalwareCheckObjectType, values_callable=lambda x: [e.value for e in x]), + nullable=True, + ) state = Column( Enum(MalwareCheckState, values_callable=lambda x: [e.value for e in x]), nullable=False, diff --git a/warehouse/malware/services.py b/warehouse/malware/services.py new file mode 100644 index 000000000000..f2f454b964e2 --- /dev/null +++ b/warehouse/malware/services.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from zope.interface import implementer + +from warehouse.malware.interfaces import IMalwareCheckService +from warehouse.malware.tasks import run_check + + +@implementer(IMalwareCheckService) +class PrinterMalwareCheckService: + def __init__(self, executor): + self.executor = executor + + @classmethod + def create_service(cls, context, request): + return cls(print) + + def run_checks(self, checks): + for check in checks: + self.executor(check) + + +@implementer(IMalwareCheckService) +class DatabaseMalwareCheckService: + def __init__(self, executor): + self.executor = executor + + @classmethod + def create_service(cls, context, request): + return cls(request.task(run_check).delay) + + def run_checks(self, checks): + for check_info in checks: + check_name, obj_id = check_info.split(":") + self.executor(check_name, obj_id) diff --git a/warehouse/malware/tasks.py b/warehouse/malware/tasks.py new file mode 100644 index 000000000000..ade9f7a9ae73 --- /dev/null +++ b/warehouse/malware/tasks.py @@ -0,0 +1,26 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import warehouse.malware.checks as checks + +from warehouse.tasks import task + + +@task(bind=True, ignore_result=True, acks_late=True) +def run_check(task, request, check_name, obj_id): + try: + check = getattr(checks, check_name)(request.db) + check.run(obj_id) + except Exception as exc: + request.log.error("Error executing check %s: %s", check_name, str(exc)) + raise task.retry(exc=exc) diff --git a/warehouse/malware/utils.py b/warehouse/malware/utils.py new file mode 100644 index 000000000000..b0a5bf5b49e7 --- /dev/null +++ b/warehouse/malware/utils.py @@ -0,0 +1,42 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools + +from collections import defaultdict + +from warehouse.malware.models import ( + MalwareCheck, + MalwareCheckObjectType, + MalwareCheckState, + MalwareCheckType, +) + + +@functools.lru_cache() +def valid_check_types(): + return set([t.value for t in MalwareCheckObjectType]) + + +def get_enabled_checks(session): + checks = ( + session.query(MalwareCheck.name, MalwareCheck.hooked_object) + .filter(MalwareCheck.check_type == MalwareCheckType.event_hook) + .filter(MalwareCheck.state == MalwareCheckState.enabled) + .all() + ) + results = defaultdict(list) + + for check_name, object_type in checks: + results[object_type.value].append(check_name) + + return results diff --git a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py index 569cc0f100b1..e74a9ddabe94 100644 --- a/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py +++ b/warehouse/migrations/versions/061ff3d24c22_add_malware_detection_tables.py @@ -31,6 +31,10 @@ "enabled", "evaluation", "disabled", "wiped_out", name="malwarecheckstate" ) +MalwareCheckObjectTypes = sa.Enum( + "File", "Release", "Project", name="malwarecheckobjecttype" +) + VerdictClassifications = sa.Enum( "threat", "indeterminate", "benign", name="verdictclassification" ) @@ -51,7 +55,7 @@ def upgrade(): sa.Column("short_description", sa.String(length=128), nullable=False), sa.Column("long_description", sa.Text(), nullable=False), sa.Column("check_type", MalwareCheckTypes, nullable=False), - sa.Column("hook_name", sa.String(), nullable=True), + sa.Column("hooked_object", MalwareCheckObjectTypes, nullable=True), sa.Column( "state", MalwareCheckStates, server_default="disabled", nullable=False, ), @@ -106,5 +110,6 @@ def downgrade(): op.drop_table("malware_checks") MalwareCheckTypes.drop(op.get_bind()) MalwareCheckStates.drop(op.get_bind()) + MalwareCheckObjectTypes.drop(op.get_bind()) 
VerdictClassifications.drop(op.get_bind()) VerdictConfidences.drop(op.get_bind())