Skip to content

Commit

Permalink
Add initial hook-based check execution mechanism (#7160)
Browse files Browse the repository at this point in the history
* Add initial hook-based check execution mechanism

* scratch/poc

* Add initial hook-based check execution mechanism

* Use sqlalchemy event hooks for malware checks

* Fix unit tests

* Add enum for MalwareCheckObjectType

* Add unit tests for init.

* Add tests for tasks, services, and utils.

Also includes some small bug fixes in MalwareCheckFactory and the
get_enabled_checks method.

* Fix spurious task test.

* Add missing drop enum to downgrade function.

* Added TODO to dev/environment

* Be more explicit in check lookup

Co-authored-by: Ernest W. Durbin III <[email protected]>
  • Loading branch information
xmunoz and ewdurbin committed Jan 8, 2020
1 parent dc50d93 commit 466acdc
Show file tree
Hide file tree
Showing 20 changed files with 700 additions and 10 deletions.
3 changes: 3 additions & 0 deletions dev/environment
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa

BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService

# TODO: change this to PrinterMalwareCheckService before deploy
MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService

METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog

STATUSPAGE_URL=https://2p66nmmycsj3.statuspage.io
Expand Down
13 changes: 7 additions & 6 deletions tests/common/db/malware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import factory
import factory.fuzzy

from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType
from warehouse.malware.models import (
MalwareCheck,
MalwareCheckObjectType,
MalwareCheckState,
MalwareCheckType,
)

from .base import WarehouseFactory

Expand All @@ -29,11 +34,7 @@ class Meta:
short_description = factory.fuzzy.FuzzyText(length=80)
long_description = factory.fuzzy.FuzzyText(length=300)
check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType])
hook_name = (
"project:release:file:upload"
if check_type == MalwareCheckType.event_hook
else None
)
hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType])
state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState])
created = factory.fuzzy.FuzzyNaiveDateTime(
datetime.datetime.utcnow() - datetime.timedelta(days=7)
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def app_config(database):
"files.backend": "warehouse.packaging.services.LocalFileStorage",
"docs.backend": "warehouse.packaging.services.LocalFileStorage",
"mail.backend": "warehouse.email.services.SMTPEmailSender",
"malware_check.backend": (
"warehouse.malware.services.PrinterMalwareCheckService"
),
"files.url": "http://localhost:7000/",
"sessions.secret": "123456",
"sessions.url": "redis://localhost:0/",
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/malware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
171 changes: 171 additions & 0 deletions tests/unit/malware/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import pretend

from warehouse import malware
from warehouse.malware import utils
from warehouse.malware.interfaces import IMalwareCheckService

from ...common.db.accounts import UserFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_determine_malware_checks_no_checks(monkeypatch, db_request):
    """With no enabled checks, new objects leave an empty set of checks."""
    # Stub the lookup so that no check is enabled for any object type.
    monkeypatch.setattr(
        utils, "get_enabled_checks", lambda session: defaultdict(list)
    )

    foo_project = ProjectFactory.create(name="foo")
    foo_release = ReleaseFactory.create(project=foo_project)
    foo_file = FileFactory.create(release=foo_release, filename="foo.bar")

    fake_session = pretend.stub(
        info={},
        new={foo_file, foo_release, foo_project},
        dirty={},
        deleted={},
    )

    malware.determine_malware_checks(pretend.stub(), fake_session, pretend.stub())

    assert fake_session.info["warehouse.malware.checks"] == set()


def test_determine_malware_checks_nothing_new(monkeypatch, db_request):
    """Objects that are only ``dirty`` leave no checks key in session.info."""

    def get_enabled_checks(session):
        # Checks are enabled for both File and Release objects.
        return defaultdict(
            list, {"File": ["Check1", "Check2"], "Release": ["Check3"]}
        )

    monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)

    foo_project = ProjectFactory.create(name="foo")
    foo_release = ReleaseFactory.create(project=foo_project)
    foo_file = FileFactory.create(release=foo_release, filename="foo.bar")

    # Nothing is in ``new``; the file and release appear only as dirty.
    fake_session = pretend.stub(
        info={}, new={}, dirty={foo_file, foo_release}, deleted={}
    )

    malware.determine_malware_checks(pretend.stub(), fake_session, pretend.stub())

    assert fake_session.info.get("warehouse.malware.checks") is None


def test_determine_malware_checks_unsupported_object(monkeypatch, db_request):
    """A new object with no enabled checks for its type schedules nothing."""

    def get_enabled_checks(session):
        # Only File and Release have checks; the object under test is a user.
        return defaultdict(
            list, {"File": ["Check1", "Check2"], "Release": ["Check3"]}
        )

    monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)

    new_user = UserFactory.create()
    fake_session = pretend.stub(info={}, new={new_user}, dirty={}, deleted={})

    malware.determine_malware_checks(pretend.stub(), fake_session, pretend.stub())

    assert fake_session.info.get("warehouse.malware.checks") is None


def test_determine_malware_checks_file_only(monkeypatch, db_request):
    """A single new file is paired with each check enabled for ``File``."""

    def get_enabled_checks(session):
        return defaultdict(
            list, {"File": ["Check1", "Check2"], "Release": ["Check3"]}
        )

    monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)

    foo_project = ProjectFactory.create(name="foo")
    foo_release = ReleaseFactory.create(project=foo_project)
    foo_file = FileFactory.create(release=foo_release, filename="foo.bar")

    fake_session = pretend.stub(info={}, new={foo_file}, dirty={}, deleted={})

    # Expected entries have the form "<check name>:<object id>".
    expected = {"Check%d:%s" % (number, foo_file.id) for number in (1, 2)}

    malware.determine_malware_checks(pretend.stub(), fake_session, pretend.stub())

    assert fake_session.info["warehouse.malware.checks"] == expected


def test_determine_malware_checks_file_and_release(monkeypatch, db_request):
    """New files and a new release each get their own enabled checks."""

    def get_enabled_checks(session):
        return defaultdict(
            list, {"File": ["Check1", "Check2"], "Release": ["Check3"]}
        )

    monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)

    foo_project = ProjectFactory.create(name="foo")
    foo_release = ReleaseFactory.create(project=foo_project)
    first_file = FileFactory.create(release=foo_release, filename="foo.bar")
    second_file = FileFactory.create(release=foo_release, filename="foo.baz")

    fake_session = pretend.stub(
        info={},
        new={foo_project, foo_release, first_file, second_file},
        dirty={},
        deleted={},
    )

    # Both File checks for each file, plus the single Release check.
    expected = {
        "Check%d:%s" % (number, new_file.id)
        for new_file in (first_file, second_file)
        for number in (1, 2)
    }
    expected.add("Check3:%s" % foo_release.id)

    malware.determine_malware_checks(pretend.stub(), fake_session, pretend.stub())

    assert fake_session.info["warehouse.malware.checks"] == expected


def test_enqueue_malware_checks(app_config):
    """Pending checks are passed to the service and removed from session.info."""
    pending_check = "Check1:ba70267f-fabf-496f-9ac2-d237a983b187"

    run_checks = pretend.call_recorder(lambda malware_checks: None)
    service = pretend.stub(run_checks=run_checks)
    service_factory = pretend.call_recorder(lambda ctx, config: service)
    app_config.register_service_factory(service_factory, IMalwareCheckService)
    app_config.commit()

    fake_session = pretend.stub(info={"warehouse.malware.checks": {pending_check}})

    malware.queue_malware_checks(app_config, fake_session)

    assert service_factory.calls == [pretend.call(None, app_config)]
    assert service.run_checks.calls == [pretend.call({pending_check})]
    assert "warehouse.malware.checks" not in fake_session.info


def test_enqueue_malware_checks_no_checks(app_config):
    """A session with empty ``info`` still has no checks key after queueing."""
    empty_session = pretend.stub(info={})

    malware.queue_malware_checks(app_config, empty_session)

    assert "warehouse.malware.checks" not in empty_session.info


def test_includeme():
    """``includeme`` registers the backend's ``create_service`` as the factory."""
    create_service = pretend.call_recorder(lambda *a, **kw: pretend.stub())
    malware_check_class = pretend.stub(create_service=create_service)

    register_service_factory = pretend.call_recorder(
        lambda factory, iface, name=None: None
    )
    settings = {"malware_check.backend": "TestMalwareCheckService"}
    config = pretend.stub(
        # Any dotted name resolves to the stub backend class.
        maybe_dotted=lambda dotted: malware_check_class,
        register_service_factory=register_service_factory,
        registry=pretend.stub(settings=settings),
    )

    malware.includeme(config)

    expected_call = pretend.call(
        malware_check_class.create_service, IMalwareCheckService
    )
    assert config.register_service_factory.calls == [expected_call]
61 changes: 61 additions & 0 deletions tests/unit/malware/test_services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend

from zope.interface.verify import verifyClass

from warehouse.malware.interfaces import IMalwareCheckService
from warehouse.malware.services import (
DatabaseMalwareCheckService,
PrinterMalwareCheckService,
)
from warehouse.malware.tasks import run_check


class TestPrinterMalwareCheckService:
    """Tests for the malware check service whose executor is ``print``."""

    def test_verify_service(self):
        # The class must satisfy the IMalwareCheckService interface.
        assert verifyClass(IMalwareCheckService, PrinterMalwareCheckService)

    def test_create_service(self):
        printer = PrinterMalwareCheckService.create_service(None, pretend.stub())
        assert printer.executor == print

    def test_run_checks(self, capfd):
        printer = PrinterMalwareCheckService.create_service(None, pretend.stub())

        printer.run_checks(["one", "two", "three"])

        # Each check is expected on its own line of captured stdout.
        captured_out, _captured_err = capfd.readouterr()
        assert captured_out == "one\ntwo\nthree\n"


class TestDatabaseMalwareService:
    """Tests for the malware check service that dispatches ``run_check`` tasks."""

    def test_verify_service(self):
        # The class must satisfy the IMalwareCheckService interface.
        assert verifyClass(IMalwareCheckService, DatabaseMalwareCheckService)

    def test_create_service(self, db_request):
        recorded_delay = pretend.call_recorder(lambda *args: None)
        # Every task lookup returns a stub exposing the same ``delay`` recorder.
        db_request.task = lambda x: pretend.stub(delay=recorded_delay)

        service = DatabaseMalwareCheckService.create_service(None, db_request)

        assert service.executor == db_request.task(run_check).delay

    def test_run_checks(self, db_request):
        recorded_delay = pretend.call_recorder(lambda *args: None)
        db_request.task = lambda x: pretend.stub(delay=recorded_delay)
        service = DatabaseMalwareCheckService.create_service(None, db_request)

        check_name = "MyTestCheck"
        object_id = "ba70267f-fabf-496f-9ac2-d237a983b187"
        service.run_checks([":".join((check_name, object_id))])

        # The "<name>:<id>" entry reaches ``delay`` as two separate arguments.
        assert recorded_delay.calls == [pretend.call(check_name, object_id)]
85 changes: 85 additions & 0 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import celery
import pretend
import pytest

from sqlalchemy.orm.exc import NoResultFound

import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareVerdict
from warehouse.malware.tasks import run_check

from ...common.db.malware import MalwareCheckFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_run_check(monkeypatch, db_request):
    """Running an enabled check against a file leaves exactly one verdict row."""
    target_file = FileFactory.create(
        release=ReleaseFactory.create(project=ProjectFactory.create(name="foo")),
        filename="foo.bar",
    )
    MalwareCheckFactory.create(name="ExampleCheck", state="enabled")

    run_check(pretend.stub(), db_request, "ExampleCheck", target_file.id)

    # ``one()`` raises unless exactly one MalwareVerdict exists.
    assert db_request.db.query(MalwareVerdict).one()


def test_run_check_missing_check_id(monkeypatch, db_session):
    """A check raising ``NoResultFound`` is logged and passed to ``task.retry``.

    A fake check class whose constructor raises ``NoResultFound`` is installed
    on ``warehouse.malware.checks``. The test asserts that ``run_check`` logs
    the error, calls ``task.retry`` with the original exception, and lets the
    resulting ``celery.exceptions.Retry`` propagate.
    """
    exc = NoResultFound("No row was found for one()")

    class FakeMalwareCheck:
        def __init__(self, db):
            raise exc

    class Task:
        @staticmethod
        @pretend.call_recorder
        def retry(exc):
            raise celery.exceptions.Retry

    task = Task()

    # Install the fake through monkeypatch so it is removed again at teardown
    # rather than leaking into the checks module for later tests.
    # raising=False permits setting a name the module does not already define.
    monkeypatch.setattr(checks, "FakeMalwareCheck", FakeMalwareCheck, raising=False)

    request = pretend.stub(
        db=db_session,
        log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None)),
    )

    with pytest.raises(celery.exceptions.Retry):
        run_check(
            task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
        )

    assert request.log.error.calls == [
        pretend.call(
            "Error executing check %s: %s",
            "FakeMalwareCheck",
            "No row was found for one()",
        )
    ]

    assert task.retry.calls == [pretend.call(exc=exc)]


def test_run_check_missing_check(db_request):
    """Asking for a check name that does not exist raises ``AttributeError``."""
    unknown_check = "DoesNotExistCheck"
    object_id = "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
    task = pretend.stub()

    with pytest.raises(AttributeError):
        run_check(task, db_request, unknown_check, object_id)
Loading

0 comments on commit 466acdc

Please sign in to comment.