Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add malware check syncing mechanism #7190

Merged
merged 2 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bin/release
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ set -eo pipefail

# Migrate our database to the latest revision.
python -m warehouse db upgrade head

# Insert/upgrade malware checks.
python -m warehouse malware sync-checks
36 changes: 36 additions & 0 deletions tests/unit/cli/test_malware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend

from warehouse.cli.malware import sync_checks
from warehouse.malware.tasks import sync_checks as _sync_checks


class TestCLIMalware:
    """Tests for the ``malware`` CLI commands."""

    def test_sync_checks(self, cli):
        """Invoking ``sync_checks`` resolves the task and runs it with its request."""
        fake_request = pretend.stub()

        def _get_request(*args, **kwargs):
            return fake_request

        def _run(*args, **kwargs):
            return None

        fake_task = pretend.stub(
            get_request=pretend.call_recorder(_get_request),
            run=pretend.call_recorder(_run),
        )

        def _lookup_task(*args, **kwargs):
            return fake_task

        fake_config = pretend.stub(task=pretend.call_recorder(_lookup_task))

        outcome = cli.invoke(sync_checks, obj=fake_config)

        assert outcome.exit_code == 0

        # The command resolves the task from the config twice: once to obtain
        # the request and once to run the task.
        assert fake_config.task.calls == [
            pretend.call(_sync_checks),
            pretend.call(_sync_checks),
        ]
        assert fake_task.get_request.calls == [pretend.call()]
        assert fake_task.run.calls == [pretend.call(fake_request)]
45 changes: 45 additions & 0 deletions tests/unit/malware/test_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import warehouse.malware.checks as checks

from warehouse.malware.checks.base import MalwareCheckBase
from warehouse.malware.utils import get_check_fields


def test_checks_subclass_base():
    """Classes exposed by the checks package match MalwareCheckBase's subclasses."""
    module_classes = inspect.getmembers(checks, inspect.isclass)

    # Index the direct subclasses of the base class by class name.
    base_subclasses = {}
    for subclass in MalwareCheckBase.__subclasses__():
        base_subclasses[subclass.__name__] = subclass

    # Same count on both sides, and every exported class is the registered one.
    assert len(module_classes) == len(base_subclasses)

    for name, module_class in module_classes:
        assert base_subclasses[name] == module_class


def test_checks_fields():
    """get_check_fields matches the fields found by inspecting each check class."""
    for check_name, check in inspect.getmembers(checks, inspect.isclass):
        # Collect every non-routine member of the check class.
        non_routine_members = inspect.getmembers(
            check, lambda member: not inspect.isroutine(member)
        )

        # Start from the class name, then add each non-dunder attribute.
        expected_fields = {"name": check_name}
        expected_fields.update(
            (attr, value)
            for attr, value in non_routine_members
            if not attr.startswith("__")
        )
        actual_fields = get_check_fields(check)

        assert expected_fields == actual_fields
258 changes: 215 additions & 43 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,68 +18,240 @@

import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareVerdict
from warehouse.malware.tasks import run_check
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.tasks import run_check, sync_checks

from ...common.db.malware import MalwareCheckFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_run_check(monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state="enabled")
class TestRunCheck:
def test_success(self, monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)

task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()
task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()

def test_missing_check_id(self, monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")

def test_run_check_missing_check_id(monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")
class FakeMalwareCheck:
def __init__(self, db):
raise exc

class FakeMalwareCheck:
def __init__(self, db):
raise exc
checks.FakeMalwareCheck = FakeMalwareCheck

class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry
class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry

task = Task()
task = Task()

checks.FakeMalwareCheck = FakeMalwareCheck
request = pretend.stub(
db=db_session,
log=pretend.stub(
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)

with pytest.raises(celery.exceptions.Retry):
run_check(
task,
request,
"FakeMalwareCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
)
]

assert task.retry.calls == [pretend.call(exc=exc)]

del checks.FakeMalwareCheck

def test_missing_check(self, db_request):
task = pretend.stub()
with pytest.raises(AttributeError):
run_check(
task,
db_request,
"DoesNotExistCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
)


class TestSyncChecks:
def test_no_updates(self, db_session):
MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.disabled
)

task = pretend.stub()

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
pretend.call("ExampleCheck is unmodified."),
]

request = pretend.stub(
db=db_session,
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),),
@pytest.mark.parametrize(
("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled]
)
def test_upgrade_check(self, monkeypatch, db_session, final_state):
MalwareCheckFactory.create(name="ExampleCheck", state=final_state)

class ExampleCheck:
version = 2
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck)

task = pretend.stub()
request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
pretend.call("Updating existing ExampleCheck."),
]
db_checks = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "ExampleCheck")
.all()
)

assert len(db_checks) == 2

with pytest.raises(celery.exceptions.Retry):
run_check(
task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
if final_state == MalwareCheckState.disabled:
assert (
db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled
)

else:
for c in db_checks:
if c.state == final_state:
assert c.version == 2
else:
assert c.version == 1

def test_one_new_check(self, db_session):
task = pretend.stub()

class FakeMalwareCheck:
version = 1
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

checks.FakeMalwareCheck = FakeMalwareCheck

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.evaluation
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("2 malware checks found in codebase."),
pretend.call("ExampleCheck is unmodified."),
pretend.call("Adding new FakeMalwareCheck to the database."),
]
assert db_session.query(MalwareCheck).count() == 2

new_check = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "FakeMalwareCheck")
.one()
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
assert new_check.state == MalwareCheckState.disabled

del checks.FakeMalwareCheck

def test_too_many_db_checks(self, db_session):
task = pretend.stub()

MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.disabled
)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.evaluation, version=2
)

request = pretend.stub(
db=db_session,
log=pretend.stub(
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)
]

assert task.retry.calls == [pretend.call(exc=exc)]
with pytest.raises(Exception):
sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
]

def test_run_check_missing_check(db_request):
task = pretend.stub()
with pytest.raises(AttributeError):
run_check(
task,
db_request,
"DoesNotExistCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
assert request.log.error.calls == [
pretend.call(
"Found 2 active checks in the db, but only 1 checks in code. Please \
manually move superfluous checks to the wiped_out state in the check admin: \
AnotherCheck"
),
]

def test_only_wiped_out(self, db_session):
task = pretend.stub()
MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.wiped_out
)
request = pretend.stub(
db=db_session,
log=pretend.stub(
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
]

assert request.log.error.calls == [
pretend.call(
"ExampleCheck is wiped_out and cannot be synced. Please remove check \
from codebase."
),
]
Loading