Skip to content

Commit

Permalink
Add malware check syncing mechanism (#7190)
Browse files Browse the repository at this point in the history
* Add malware check syncing mechanism

* Code review changes.
  • Loading branch information
xmunoz authored and ewdurbin committed Jan 7, 2020
1 parent decdea9 commit b16196b
Show file tree
Hide file tree
Showing 9 changed files with 464 additions and 74 deletions.
3 changes: 3 additions & 0 deletions bin/release
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ set -eo pipefail

# Migrate our database to the latest revision.
python -m warehouse db upgrade head

# Insert/upgrade malware checks.
python -m warehouse malware sync-checks
36 changes: 36 additions & 0 deletions tests/unit/cli/test_malware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend

from warehouse.cli.malware import sync_checks
from warehouse.malware.tasks import sync_checks as _sync_checks


class TestCLIMalware:
def test_sync_checks(self, cli):
request = pretend.stub()
task = pretend.stub(
get_request=pretend.call_recorder(lambda *a, **kw: request),
run=pretend.call_recorder(lambda *a, **kw: None),
)
config = pretend.stub(task=pretend.call_recorder(lambda *a, **kw: task))

result = cli.invoke(sync_checks, obj=config)

assert result.exit_code == 0
assert config.task.calls == [
pretend.call(_sync_checks),
pretend.call(_sync_checks),
]
assert task.get_request.calls == [pretend.call()]
assert task.run.calls == [pretend.call(request)]
45 changes: 45 additions & 0 deletions tests/unit/malware/test_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import warehouse.malware.checks as checks

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.utils import get_check_fields


def test_checks_subclass_base():
checks_from_module = inspect.getmembers(checks, inspect.isclass)

subclasses_of_malware_base = {
cls.__name__: cls for cls in MalwareCheckBase.__subclasses__()
}

assert len(checks_from_module) == len(subclasses_of_malware_base)

for check_name, check in checks_from_module:
assert subclasses_of_malware_base[check_name] == check


def test_checks_fields():
checks_from_module = inspect.getmembers(checks, inspect.isclass)

for check_name, check in checks_from_module:
elems = inspect.getmembers(check, lambda a: not (inspect.isroutine(a)))
inspection_fields = {"name": check_name}
for elem_name, value in elems:
if not elem_name.startswith("__"):
inspection_fields[elem_name] = value
fields = get_check_fields(check)

assert inspection_fields == fields
258 changes: 215 additions & 43 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,68 +18,240 @@

import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareVerdict
from warehouse.malware.tasks import run_check
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.tasks import run_check, sync_checks

from ...common.db.malware import MalwareCheckFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_run_check(monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state="enabled")
class TestRunCheck:
def test_success(self, monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)

task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()
task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()

def test_missing_check_id(self, monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")

def test_run_check_missing_check_id(monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")
class FakeMalwareCheck:
def __init__(self, db):
raise exc

class FakeMalwareCheck:
def __init__(self, db):
raise exc
checks.FakeMalwareCheck = FakeMalwareCheck

class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry
class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry

task = Task()
task = Task()

checks.FakeMalwareCheck = FakeMalwareCheck
request = pretend.stub(
db=db_session,
log=pretend.stub(
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)

with pytest.raises(celery.exceptions.Retry):
run_check(
task,
request,
"FakeMalwareCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
)
]

assert task.retry.calls == [pretend.call(exc=exc)]

del checks.FakeMalwareCheck

def test_missing_check(self, db_request):
task = pretend.stub()
with pytest.raises(AttributeError):
run_check(
task,
db_request,
"DoesNotExistCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
)


class TestSyncChecks:
def test_no_updates(self, db_session):
MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.disabled
)

task = pretend.stub()

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
pretend.call("ExampleCheck is unmodified."),
]

request = pretend.stub(
db=db_session,
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),),
@pytest.mark.parametrize(
("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled]
)
def test_upgrade_check(self, monkeypatch, db_session, final_state):
MalwareCheckFactory.create(name="ExampleCheck", state=final_state)

class ExampleCheck:
version = 2
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck)

task = pretend.stub()
request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
pretend.call("Updating existing ExampleCheck."),
]
db_checks = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "ExampleCheck")
.all()
)

assert len(db_checks) == 2

with pytest.raises(celery.exceptions.Retry):
run_check(
task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
if final_state == MalwareCheckState.disabled:
assert (
db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled
)

else:
for c in db_checks:
if c.state == final_state:
assert c.version == 2
else:
assert c.version == 1

def test_one_new_check(self, db_session):
task = pretend.stub()

class FakeMalwareCheck:
version = 1
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

checks.FakeMalwareCheck = FakeMalwareCheck

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.evaluation
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("2 malware checks found in codebase."),
pretend.call("ExampleCheck is unmodified."),
pretend.call("Adding new FakeMalwareCheck to the database."),
]
assert db_session.query(MalwareCheck).count() == 2

new_check = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "FakeMalwareCheck")
.one()
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
assert new_check.state == MalwareCheckState.disabled

del checks.FakeMalwareCheck

def test_too_many_db_checks(self, db_session):
task = pretend.stub()

MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.disabled
)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.evaluation, version=2
)

request = pretend.stub(
db=db_session,
log=pretend.stub(
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)
]

assert task.retry.calls == [pretend.call(exc=exc)]
with pytest.raises(Exception):
sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
]

def test_run_check_missing_check(db_request):
task = pretend.stub()
with pytest.raises(AttributeError):
run_check(
task,
db_request,
"DoesNotExistCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
assert request.log.error.calls == [
pretend.call(
"Found 2 active checks in the db, but only 1 checks in code. Please \
manually move superfluous checks to the wiped_out state in the check admin: \
AnotherCheck"
),
]

def test_only_wiped_out(self, db_session):
task = pretend.stub()
MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.wiped_out
)
request = pretend.stub(
db=db_session,
log=pretend.stub(
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
]

assert request.log.error.calls == [
pretend.call(
"ExampleCheck is wiped_out and cannot be synced. Please remove check \
from codebase."
),
]
Loading

0 comments on commit b16196b

Please sign in to comment.