Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SAFETY_DB_DIR env var to the scan command #523

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion safety/auth/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def update_token(tokens, **kwargs):
try:
openid_config = client_session.get(url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT).json()
except Exception as e:
LOG.exception('Unable to load the openID config: %s', e)
LOG.debug('Unable to load the openID config: %s', e)
openid_config = {}

client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint",
Expand Down
4 changes: 4 additions & 0 deletions safety/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import os
from typing import Any, Optional

from authlib.integrations.base_client import BaseOAuth
Expand Down Expand Up @@ -26,6 +27,9 @@ class Auth:
email_verified: bool = False

def is_valid(self) -> bool:
if os.getenv("SAFETY_DB_DIR"):
return True

if not self.client:
return False

Expand Down
11 changes: 10 additions & 1 deletion safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,20 @@

LOG = logging.getLogger(__name__)


def configure_logger(ctx, param, debug):
level = logging.CRITICAL

if debug:
level = logging.DEBUG

logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level)

@click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG)
@auth_options()
@proxy_options
@click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP)
@click.option('--debug', default=False, help=CLI_DEBUG_HELP)
@click.option('--debug', default=False, help=CLI_DEBUG_HELP, callback=configure_logger)
@click.version_option(version=get_safety_version())
@click.pass_context
@inject_session
Expand Down
27 changes: 20 additions & 7 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,21 @@ def post_results(session, safety_json, policy_file):
return {}


def fetch_database_file(path, db_name, ecosystem: Ecosystem = Ecosystem.PYTHON):
full_path = os.path.join(path, db_name)
if not os.path.exists(full_path):
def fetch_database_file(path: str, db_name: str, cached = 0,
ecosystem: Optional[Ecosystem] = None):
full_path = (Path(path) / (ecosystem.value if ecosystem else '') / db_name).expanduser().resolve()

if not full_path.exists():
raise DatabaseFileNotFoundError(db=path)

with open(full_path) as f:
return json.loads(f.read())
data = json.loads(f.read())

if cached:
LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
write_to_cache(db_name, data)

return data


def is_valid_database(db) -> bool:
Expand All @@ -218,7 +227,8 @@ def is_valid_database(db) -> bool:


def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
ecosystem: Optional[Ecosystem] = None, from_cache=True):

if session.is_using_auth_credentials():
mirrors = API_MIRRORS
elif db:
Expand All @@ -230,10 +240,13 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
for mirror in mirrors:
# mirror can either be a local path or a URL
if is_a_remote_mirror(mirror):
if ecosystem is None:
ecosystem = Ecosystem.PYTHON
data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
else:
data = fetch_database_file(mirror, db_name=db_name, ecosystem=ecosystem)
data = fetch_database_file(mirror, db_name=db_name, cached=cached,
ecosystem=ecosystem)
if data:
if is_valid_database(data):
return data
Expand Down Expand Up @@ -1000,7 +1013,7 @@ def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):
licenses = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
telemetry=telemetry)
else:
licenses = fetch_database_file(mirror, db_name=db_name)
licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None)
if licenses:
return licenses
raise DatabaseFetchError()
Expand Down
13 changes: 9 additions & 4 deletions safety/scan/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import wraps
import logging
import os
from pathlib import Path
from random import randint
import sys
Expand Down Expand Up @@ -135,11 +136,15 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path,
if ctx.obj.auth.client.get_authentication_type() == "api_key":
details = {"Account": f"API key used"}
else:
content = ctx.obj.auth.email
if ctx.obj.auth.name != ctx.obj.auth.email:
content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}"

details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"}
if ctx.obj.auth.client.get_authentication_type() == "token":
content = ctx.obj.auth.email
if ctx.obj.auth.name != ctx.obj.auth.email:
content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}"

details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"}
else:
details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"}

if ctx.obj.project.id:
details["Project"] = ctx.obj.project.id
Expand Down
10 changes: 8 additions & 2 deletions safety/scan/finder/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
from types import MappingProxyType
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -49,12 +50,17 @@ def __init__(self) -> None:

def download_required_assets(self, session):
from safety.safety import fetch_database

SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR")

db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR


fetch_database(session=session, full=False, db=False, cached=True,
fetch_database(session=session, full=False, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)

fetch_database(session=session, full=True, db=False, cached=True,
fetch_database(session=session, full=True, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)

Expand Down
4 changes: 4 additions & 0 deletions safety/scan/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import os
from pathlib import Path
from typing import Optional, Tuple
import typer
Expand Down Expand Up @@ -42,6 +43,9 @@ def fail_if_not_allowed_stage(ctx: typer.Context):
stage = ctx.obj.auth.stage
auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type()

if os.getenv("SAFETY_DB_DIR"):
return

if not auth_type.is_allowed_in(stage):
raise typer.BadParameter(f"'{auth_type.value}' auth type isn't allowed with " \
f"the '{stage}' stage.")
Expand Down
29 changes: 29 additions & 0 deletions tests/scan/test_file_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import pytest
from unittest.mock import Mock, patch
from safety.scan.finder.handlers import PythonFileHandler

@patch('safety.safety.fetch_database')
def test_download_required_assets(mock_fetch_database):
handler = PythonFileHandler()
session = Mock()

os.environ["SAFETY_DB_DIR"] = "/path/to/db"
handler.download_required_assets(session)

_, kwargs = mock_fetch_database.call_args

assert kwargs['db'] == "/path/to/db"

@patch('safety.safety.fetch_database')
def test_download_required_assets_no_db_dir(mock_fetch_database):
handler = PythonFileHandler()
session = Mock()

if "SAFETY_DB_DIR" in os.environ:
del os.environ["SAFETY_DB_DIR"]
handler.download_required_assets(session)

_, kwargs = mock_fetch_database.call_args

assert kwargs['db'] == False
Loading