diff --git a/pshtt/pshtt.py b/pshtt/pshtt.py index 0ca47de9..cee25776 100644 --- a/pshtt/pshtt.py +++ b/pshtt/pshtt.py @@ -26,7 +26,7 @@ except ImportError: from urllib2 import URLError -import sslyze +from pathlib import Path # Python3 from sslyze import ( Scanner, ServerConnectivityTester, @@ -35,6 +35,7 @@ ) from sslyze.errors import ConnectionToServerFailed from sslyze.plugins.scan_commands import ScanCommand +from sslyze.plugins.certificate_info.implementation import CertificateInfoExtraArguments # We're going to be making requests with certificate validation # disabled. Commented next line due to pylint warning that urllib3 is @@ -604,9 +605,22 @@ def https_check(endpoint): try: cert_plugin_result = None - command = ScanCommand.CERTIFICATE_INFO scanner = Scanner() - scan_request = ServerScanRequest(server_info=server_info, scan_commands=[command]) + command = ScanCommand.CERTIFICATE_INFO + if CA_FILE is not None: + command_extra_args = { + ScanCommand.CERTIFICATE_INFO: CertificateInfoExtraArguments(custom_ca_file=Path(CA_FILE)) + } + scan_request = ServerScanRequest( + server_info=server_info, + scan_commands_extra_arguments=command_extra_args, + scan_commands=[command] + ) + else: + scan_request = ServerScanRequest( + server_info=server_info, + scan_commands=[command] + ) scanner.queue_scan(scan_request) # Retrieve results from generator object scan_result = [x for x in scanner.get_results()][0] @@ -728,9 +742,21 @@ def https_check(endpoint): if(PT_INT_CA_FILE is not None): try: cert_plugin_result = None - command = sslyze.plugins.certificate_info_plugin.CertificateInfoScanCommand(ca_file=PT_INT_CA_FILE) - cert_plugin_result = scanner.run_scan_command(server_info, command) - if(cert_plugin_result.verified_certificate_chain is not None): + scanner = Scanner() + command = ScanCommand.CERTIFICATE_INFO + command_extra_args = { + ScanCommand.CERTIFICATE_INFO: CertificateInfoExtraArguments(custom_ca_file=Path(PT_INT_CA_FILE)) + } + scan_request = ServerScanRequest( + server_info=server_info, + scan_commands_extra_arguments=command_extra_args, + scan_commands=[command] + ) + scanner.queue_scan(scan_request) + # Retrieve results from generator object + scan_result = [x for x in scanner.get_results()][0] + cert_plugin_result = scan_result.scan_commands_results[ScanCommand.CERTIFICATE_INFO] + if(cert_plugin_result.certificate_deployments[0].verified_certificate_chain is not None): public_trust = True endpoint.https_public_trusted = public_trust logging.warning("{}: Trusted by special public trust store with intermediate certificates.".format(endpoint.url))