diff --git a/package.json b/package.json index 41dde86..87d4c16 100644 --- a/package.json +++ b/package.json @@ -1,4 +1,8 @@ { + "scripts": { + "lint": "python3 -m mypy streamer/", + "test": "python3 run_end_to_end_tests.py" + }, "devDependencies": { "karma": "^6.3.16", "karma-chrome-launcher": "^3.1.0", diff --git a/shaka-streamer b/shaka-streamer index 9481a7c..70587de 100755 --- a/shaka-streamer +++ b/shaka-streamer @@ -61,16 +61,15 @@ def main(): 'bitrates and resolutions for transcoding. ' + '(optional, see example in ' + 'config_files/bitrate_config.yaml)') + parser.add_argument('-c', '--cloud-url', + default=None, + help='The Google Cloud Storage or Amazon S3 URL to ' + + 'upload to. (Starts with gs:// or s3://)') parser.add_argument('-o', '--output', default='output_files', help='The output folder to write files to, or an HTTP ' + 'or HTTPS URL where files will be PUT.' + 'Used even if uploading to cloud storage.') - parser.add_argument('-H', '--add-header', - action='append', - metavar='key=value', - help='Headers to include when sending PUT requests ' + - 'to the specified output location.') parser.add_argument('--skip-deps-check', action='store_true', help='Skip checks for dependencies and their versions. ' + @@ -97,15 +96,9 @@ def main(): with open(args.bitrate_config) as f: bitrate_config_dict = yaml.safe_load(f) - extra_headers = {} - if args.add_header: - for header in args.add_header: - key, value = tuple(header.split('=', 1)) - extra_headers[key] = value - try: with controller.start(args.output, input_config_dict, pipeline_config_dict, - bitrate_config_dict, extra_headers, + bitrate_config_dict, args.cloud_url, not args.skip_deps_check, not args.use_system_binaries): # Sleep so long as the pipeline is still running. diff --git a/streamer/controller_node.py b/streamer/controller_node.py index 25cf090..9a04166 100644 --- a/streamer/controller_node.py +++ b/streamer/controller_node.py @@ -32,7 +32,6 @@ from streamer import __version__ from streamer import autodetect from streamer import min_versions -from streamer import proxy_node from streamer.bitrate_configuration import BitrateConfig, AudioChannelLayout, VideoResolution from streamer.external_command_node import ExternalCommandNode from streamer.input_configuration import InputConfig, InputType, MediaType, Input @@ -42,7 +41,7 @@ from streamer.pipeline_configuration import ManifestFormat, PipelineConfig, StreamingMode from streamer.transcoder_node import TranscoderNode from streamer.periodconcat_node import PeriodConcatNode -from streamer.proxy_node import HTTPUpload +from streamer.proxy_node import ProxyNode import streamer.subprocessWindowsPatch # side-effects only from streamer.util import is_url from streamer.pipe import Pipe @@ -61,7 +60,6 @@ def __init__(self) -> None: dir=global_temp_dir, prefix='shaka-live-', suffix='') self._nodes: List[NodeBase] = [] - self.upload_proxy: Optional[HTTPUpload] = None def __del__(self) -> None: # Clean up named pipes by removing the temp directory we placed them in. @@ -77,7 +75,7 @@ def start(self, output_location: str, input_config_dict: Dict[str, Any], pipeline_config_dict: Dict[str, Any], bitrate_config_dict: Dict[Any, Any] = {}, - extra_headers: Dict[str, str] = {}, + bucket_url: Union[str, None] = None, check_deps: bool = True, use_hermetic: bool = True) -> 'ControllerNode': """Create and start all other nodes. 
@@ -168,15 +166,22 @@ def next_short_version(version: str) -> str: self._input_config = InputConfig(input_config_dict) self._pipeline_config = PipelineConfig(pipeline_config_dict) - # Note that we remove the trailing slash from the output location, because - # otherwise GCS would create a subdirectory whose name is "". - output_location = output_location.rstrip('/') - if (proxy_node.is_supported_protocol(output_location) - and self._pipeline_config.use_local_proxy): - self.upload_proxy = self.get_upload_node(output_location, extra_headers) + if bucket_url is not None: + if not ProxyNode.is_understood(bucket_url): + url_prefixes = [ + protocol + '://' for protocol in ProxyNode.ALL_SUPPORTED_PROTOCOLS] + raise RuntimeError( + 'Invalid cloud URL! Only these are supported: ' + + ', '.join(url_prefixes)) + + if not ProxyNode.is_supported(bucket_url): + raise RuntimeError('Missing libraries for cloud URL: ' + bucket_url) + + upload_proxy = ProxyNode.create(bucket_url) + # All the outputs now should be sent to the proxy server instead. - output_location = self.upload_proxy.server_location - self._nodes.append(self.upload_proxy) + output_location = upload_proxy.server_location + self._nodes.append(upload_proxy) if not is_url(output_location): # Check if the directory for outputted Packager files exists, and if it @@ -208,6 +213,10 @@ def next_short_version(version: str) -> str: output_location) else: # InputConfig contains multiperiod_inputs_list only. + if bucket_url: + raise RuntimeError( + 'Direct cloud upload is incompatible with multiperiod support.') + # Create one Transcoder node and one Packager node for each period. for i, singleperiod in enumerate(self._input_config.multiperiod_inputs_list): sub_dir_name = 'period_' + str(i + 1) @@ -220,8 +229,7 @@ def next_short_version(version: str) -> str: self._nodes.append(PeriodConcatNode( self._pipeline_config, packager_nodes, - output_location, - self.upload_proxy)) + output_location)) for node in self._nodes: node.start() @@ -309,8 +317,7 @@ def _append_nodes_for_inputs_list(self, inputs: List[Input], # and put that period in it. if period_dir: output_location = os.path.join(output_location, period_dir) - if not is_url(output_location): - os.mkdir(output_location) + os.mkdir(output_location) self._nodes.append(PackagerNode(self._pipeline_config, output_location, @@ -324,15 +331,11 @@ def check_status(self) -> ProcessStatus: If one node is errored, this returns Errored; otherwise if one node is running, this returns Running; this only returns Finished if all nodes are finished. If there are no nodes, this returns Finished. - - :rtype: ProcessStatus """ if not self._nodes: return ProcessStatus.Finished - value = max(node.check_status().value for node in self._nodes - # We don't check the the upload node. - if node != self.upload_proxy) + value = max(node.check_status().value for node in self._nodes) return ProcessStatus(value) def stop(self) -> None: @@ -358,30 +361,6 @@ def is_low_latency_dash_mode(self) -> bool: return self._pipeline_config.low_latency_dash_mode - def get_upload_node(self, upload_location: str, - extra_headers: Dict[str, str]) -> HTTPUpload: - """ - Args: - upload_location (str): The location where media content will be uploaded. - extra_headers (Dict[str, str]): Extra headers to be added when - sending the PUT request to `upload_location`. - - :rtype: HTTPUpload - :raises: `RuntimeError` if the protocol used in `upload_location` was not - recognized. 
- """ - - # We need to pass a temporary direcotry when working with multi-period input - # and using HTTP PUT for uploading. This is so that the HTTPUpload - # keeps a copy of the manifests in the temporary directory so we can use them - # later to assemble the multi-period manifests. - upload_temp_dir = None - if self._input_config.multiperiod_inputs_list: - upload_temp_dir = os.path.join(self._temp_dir, 'multiperiod_manifests') - os.mkdir(upload_temp_dir) - return proxy_node.get_upload_node(upload_location, extra_headers, - upload_temp_dir) - class VersionError(Exception): """A version error for one of Shaka Streamer's external dependencies. diff --git a/streamer/periodconcat_node.py b/streamer/periodconcat_node.py index 6fd43f6..2b11f23 100644 --- a/streamer/periodconcat_node.py +++ b/streamer/periodconcat_node.py @@ -17,17 +17,14 @@ import os import re import time -from typing import List, Optional +from typing import List from xml.etree import ElementTree -from http.client import HTTPConnection, CREATED from streamer import __version__ from streamer.node_base import ProcessStatus, ThreadedNodeBase from streamer.packager_node import PackagerNode from streamer.pipeline_configuration import PipelineConfig, ManifestFormat from streamer.output_stream import AudioOutputStream, VideoOutputStream from streamer.m3u8_concater import HLSConcater -from streamer.proxy_node import HTTPUpload -from streamer.util import is_url class PeriodConcatNode(ThreadedNodeBase): @@ -38,20 +35,17 @@ class PeriodConcatNode(ThreadedNodeBase): def __init__(self, pipeline_config: PipelineConfig, packager_nodes: List[PackagerNode], - output_location: str, - upload_proxy: Optional[HTTPUpload]) -> None: + output_location: str) -> None: """Stores all relevant information needed for the period concatenation.""" - super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=1) + super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=3) self._pipeline_config = pipeline_config self._output_location = output_location self._packager_nodes: List[PackagerNode] = packager_nodes - self._proxy_node = upload_proxy - self._concat_will_fail = self._check_failed_concatenation() - - def _check_failed_concatenation(self) -> bool: + self._concat_will_fail = False + # know whether the first period has video and audio or not. 
fp_has_vid, fp_has_aud = False, False - for output_stream in self._packager_nodes[0].output_streams: + for output_stream in packager_nodes[0].output_streams: if isinstance(output_stream, VideoOutputStream): fp_has_vid = True elif isinstance(output_stream, AudioOutputStream): @@ -65,6 +59,7 @@ def _check_failed_concatenation(self) -> bool: elif isinstance(output_stream, AudioOutputStream): has_aud = True if has_vid != fp_has_vid or has_aud != fp_has_aud: + self._concat_will_fail = True print("\nWARNING: Stopping period concatenation.") print("Period#{} has {}video and has {}audio while Period#1 " "has {}video and has {}audio.".format(i + 1, @@ -77,21 +72,8 @@ def _check_failed_concatenation(self) -> bool: "\tperiods with other periods that have video.\n" "\tThis is necessary for the concatenation to be performed successfully.\n") time.sleep(5) - return True - - if self._proxy_node is None and is_url(self._output_location): - print("\nWARNING: Stopping period concatenation.") - print("Shaka Packager is using HTTP PUT but not using" - " Shaka Streamer's upload proxy.") - print("\nHINT:\n\tShaka Streamer's upload proxy stores the manifest files\n" - "\ttemporarily in the local filesystem to use them for period concatenation.\n" - "\tSet use_local_proxy to True in the pipeline config to enable the" - " upload proxy.\n") - time.sleep(5) - return True - # Otherwise, we don't have a reason to fail. - return False - + break + def _thread_single_pass(self) -> None: """Watches all the PackagerNode(s), if at least one of them is running it skips this _thread_single_pass, if all of them are finished, it starts period concatenation, if one of @@ -108,41 +90,13 @@ def _thread_single_pass(self) -> None: 'to an error in PackagerNode#{}.'.format(i + 1)) if self._concat_will_fail: - raise RuntimeError('Unable to concatenate the inputs') - - # If the packager was pushing HTTP requests to the stream's proxy server, - # the proxy server should have stored the manifest files in a temporary - # directory in the filesystem. - if self._proxy_node is not None: - assert self._proxy_node.temp_dir, ('There should be a proxy temp direcotry' - ' when processing multi-period input') - self._output_location = self._proxy_node.temp_dir - # As the period concatenator node is the last to run, changing the - # output location at run time won't disturb any other node. - for packager_node in self._packager_nodes: - packager_node.output_location = packager_node.output_location.replace( - self._proxy_node.server_location, - self._proxy_node.temp_dir, 1) + raise RuntimeError('Unable to concatenate the inputs.') if ManifestFormat.DASH in self._pipeline_config.manifest_format: self._dash_concat() if ManifestFormat.HLS in self._pipeline_config.manifest_format: self._hls_concat() - - # Push the concatenated manifests if a proxy is used. - if self._proxy_node is not None: - conn = HTTPConnection(self._proxy_node.server.server_name, - self._proxy_node.server.server_port) - # The concatenated manifest files where written in `self._output_location`. 
-      for manifest_file_name in os.listdir(self._output_location):
-        if manifest_file_name.endswith(('.mpd', '.m3u8')):
-          manifest_file_path = os.path.join(self._output_location, manifest_file_name)
-          conn.request('PUT', '/' + manifest_file_name, open(manifest_file_path, 'r'))
-          res = conn.getresponse()
-          if res.status != CREATED:
-            print("Got unexpected status code: {}, Msg: {!r}".format(res.status,
-                                                                     res.read()))

     self._status = ProcessStatus.Finished
diff --git a/streamer/pipe.py b/streamer/pipe.py
index 4f1bb90..555f2c7 100644
--- a/streamer/pipe.py
+++ b/streamer/pipe.py
@@ -24,7 +24,7 @@ class Pipe:
   """A class that represents a pipe."""

-  def __init__(self):
+  def __init__(self) -> None:
     """Initializes a non-functioning pipe."""

     self._read_pipe_name = ''
diff --git a/streamer/pipeline_configuration.py b/streamer/pipeline_configuration.py
index d898877..8ce8aba 100644
--- a/streamer/pipeline_configuration.py
+++ b/streamer/pipeline_configuration.py
@@ -320,11 +320,6 @@ class PipelineConfig(configuration.Base):
       default=EncryptionConfig({})).cast()
   """Encryption settings."""

-  use_local_proxy = configuration.Field(bool, default=True).cast()
-  """Whether to use shaka streamer's local proxy when uploading to a remote
-  storage. This must be set to True when uploading to GCS or amazon S3 buckets.
-  """
-
   # TODO: Generalize this to low_latency_mode once LL-HLS is supported by Packager
   low_latency_dash_mode = configuration.Field(bool, default=False).cast()
   """If true, stream in low latency mode for DASH."""
diff --git a/streamer/proxy_node.py b/streamer/proxy_node.py
index a0e3e9c..8a68913 100644
--- a/streamer/proxy_node.py
+++ b/streamer/proxy_node.py
@@ -12,198 +12,172 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""A module that implements a simple proxy server for uploading packaged
-content to cloud storage providers (GCS and S3) and also any server that
-accepts PUT requests.
-"""
+"""A simple proxy server to upload to cloud storage providers."""

-import os
-import json
+import abc
 import threading
+import traceback
 import urllib.parse

-from typing import Optional, Union, Dict
+from typing import IO, Optional
 from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
-from http.client import HTTPConnection, HTTPSConnection, CREATED, OK
 from streamer.node_base import ProcessStatus, ThreadedNodeBase
-from streamer.util import RequestBodyAsFileIO

-# Protocols we have classes for to handle them.
-SUPPORTED_PROTOCOLS = ['http', 'https', 'gs', 's3']
+HTTP_STATUS_CREATED = 201
+HTTP_STATUS_FAILED = 500

+MAX_CHUNK_SIZE = (1 << 20)  # 1 MB

-class RequestHandler(BaseHTTPRequestHandler):
+# Supported protocols.  Built based on which optional modules are available for
+# cloud storage providers.
+SUPPORTED_PROTOCOLS: list[str] = []
+
+# All supported protocols.  Used to provide more useful error messages.
+ALL_SUPPORTED_PROTOCOLS: list[str] = ['gs', 's3']
+
+
+try:
+  import google.cloud.storage as gcs  # type: ignore
+  SUPPORTED_PROTOCOLS.append('gs')
+except:
+  pass
+
+try:
+  import boto3 as aws  # type: ignore
+  SUPPORTED_PROTOCOLS.append('s3')
+except:
+  pass
+
+
+class RequestHandlerBase(BaseHTTPRequestHandler):
   """A request handler that processes the PUT requests coming from
-  shaka packager and pushes them to the destination.
+  Shaka Packager and pushes them to the destination.
""" + def do_PUT(self) -> None: + """Handle the PUT requests coming from Shaka Packager.""" + try: + if self.headers.get('Transfer-Encoding', '').lower() == 'chunked': + self.start_chunked(self.path) + + while True: + # Parse the chunk size + chunk_size_line = self.rfile.readline().strip() + chunk_size = int(chunk_size_line, 16) + + # Read the chunk and process it + if chunk_size != 0: + self.handle_chunk(self.rfile.read(chunk_size)) + self.rfile.readline() # Read the trailer + + if chunk_size == 0: + break # EOF + + self.end_chunked() + else: + content_length = int(self.headers['Content-Length']) + if content_length != 0: + self.handle_non_chunked(self.path, content_length, self.rfile) + + self.rfile.close() + self.send_response(HTTP_STATUS_CREATED) + except Exception as ex: + print('Upload failure: ' + str(ex)) + traceback.print_exc() + self.send_response(HTTP_STATUS_FAILED) + self.end_headers() + + @abc.abstractmethod + def handle_non_chunked(self, path: str, length: int, file: IO) -> None: + """Write the non-chunked data stream from |file| to the destination.""" + pass + + @abc.abstractmethod + def start_chunked(self, path: str) -> None: + """Set up for a chunked transfer to the destination.""" + pass + + @abc.abstractmethod + def handle_chunk(self, data: bytes) -> None: + """Handle a single chunk of data.""" + pass - def __init__(self, conn: Union[HTTPConnection, HTTPSConnection], - extra_headers: Dict[str, str], base_path: str, param_query: str, - temp_dir: Optional[str], *args, **kwargs): + @abc.abstractmethod + def end_chunked(self) -> None: + """End the chunked transfer.""" + pass - self._conn = conn - # Extra headers to add when sending the request to the host - # using `self._conn`. - self._extra_headers = extra_headers - # The base path that will be prepended to the path of the handled request - # before forwarding the request to the host using `self._conn`. + +class GCSHandler(RequestHandlerBase): + def __init__(self, bucket: gcs.Bucket, base_path: str, + *args, **kwargs) -> None: + self._bucket = bucket self._base_path = base_path - # Parameters and query string to send in the url with each forwarded request. - self._params_and_querystring = param_query - self._temp_dir = temp_dir - # Call `super().__init__()` last because this call is what handles the - # actual request and we need the variables defined above to handle - # this request. + + # The HTTP server passes *args and *kwargs that we need to pass along, but + # don't otherwise care about. super().__init__(*args, **kwargs) - def do_PUT(self): - """do_PUT will handle the PUT requests coming from shaka packager.""" - - headers = {} - # Use the same headers for requesting. - for k, v in self.headers.items(): - if k.lower() != 'host': - headers[k] = v - # Add the extra headers, this might contain an access token for instance. - for k, v in self._extra_headers.items(): - headers[k] = v - - # Don't chunk by default, as the request body is already chunked - # or we have a content-length header which means we are not using - # chunked transfer-encoding. - encode_chunked = False - content_length = self.headers['Content-Length'] - # Store the manifest files locally in `self._temp_dir`. - if self._temp_dir is not None and self.path.endswith(('.mpd', '.m3u8')): - body = self._write_body_and_get_file_io() - if content_length == None: - # We need to re-chunk it again as we wrote it as a whole - # to the filesystem. 
- encode_chunked = True - else: - content_length = content_length and int(content_length) - body = RequestBodyAsFileIO(self.rfile, content_length) - - # The url will be the result of joining the base path we should - # send the request to with the path this request came to. - url = self._base_path + self.path - # Also include any parameters and the query string. - url += self._params_and_querystring - - self._conn.request('PUT', url, body, headers, encode_chunked=encode_chunked) - res = self._conn.getresponse() - # Disable response logging. - self.log_request = lambda _: None - # Respond to Shaka Packager with the response we got. - self.send_response(res.status) - self.end_headers() - # self.wfile.write(res.read()) - # The destination should send (201/CREATED), but some do also send (200/OK). - if res.status != CREATED and res.status != OK: - print('Unexpected status for the PUT request:' - ' {}, ErrMsg: {!r}'.format(res.status, res.read())) - - def _write_body_and_get_file_io(self): - """A method that writes a request body to the filesystem - and returns a file io object opened for reading. - """ - - # Store the request body in `self._temp_dir`. - # Ignore the first '/' `self.path` as posixpath will think - # it points to the root direcotry. - path = os.path.join(self._temp_dir, self.path[1:]) - # With `exist_ok=True`, any intermidiate direcotries are created if needed. - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'wb') as request_body_file: - if self.headers['Content-Length'] is not None: - content_length = int(self.headers['Content-Length']) - request_body_file.write(self.rfile.read(content_length)) - else: - while True: - bytes_chunk_size = self.rfile.readline() - int_chunk_size = int(bytes_chunk_size.strip(), base=16) - request_body_file.write(self.rfile.read(int_chunk_size)) - # An empty newline that we have to consume. - self.rfile.readline() - # Chunk of size zero indicates that we have reached the end. - if int_chunk_size == 0: - break - return open(path, 'rb') - - -class RequestHandlersFactory(): - """A request handlers' factory that produces a RequestHandler whenever - its __call__ method is called. It stores all the relevant data that the - instantiated request handler will need when sending a request to the host. - """ + def handle_non_chunked(self, path: str, length: int, file: IO) -> None: + full_path = self._base_path + path + blob = self._bucket.blob(full_path) + blob.upload_from_file(file, size=length, retries=3) - def __init__(self, upload_location: str, initial_headers: Dict[str, str] = {}, - temp_dir: Optional[str] = None, max_conns: int = 50): + def start_chunked(self, path: str) -> None: + full_path = self._base_path + path + blob = self._bucket.blob(full_path) + self._chunk_file = blob.open('wb') - url = urllib.parse.urlparse(upload_location) - if url.scheme not in ['http', 'https']: - # We can only instantiate HTTP/HTTPS connections. - raise RuntimeError("Unsupported scheme: {}", url.scheme) - self._ConnectionFactory = HTTPConnection if url.scheme == 'http' \ - else HTTPSConnection - self._destination_host = url.netloc - # Store the url path to prepend it to the path of each handled - # request before forwarding the request to `self._destination_host`. - self._base_path = url.path - # Store the parameters and the query string to send them in - # any request going to `self._destination_host`. - self._params_query = ';' + url.params if url.params else '' - self._params_query += '?' 
+ url.query if url.query else '' - # These headers are going to be sent to `self._destination_host` - # with each request along with the headers that the request handler - # receives. Note that these extra headers can possibely overwrite - # the original request headers that the request handler received. - self._extra_headers = initial_headers - self._temp_dir = temp_dir - self._max_conns = max_conns - - def __call__(self, *args, **kwargs) -> RequestHandler: - """This magical method makes a RequestHandlersFactory instance - callable and returns a RequestHandler when called. - """ - - connection = self._ConnectionFactory(self._destination_host) - return RequestHandler(connection, self._extra_headers, - self._base_path, self._params_query, self._temp_dir, - *args, **kwargs) - - def update_headers(self, **kwargs): - self._extra_headers.update(**kwargs) - - -class HTTPUpload(ThreadedNodeBase): - """A ThreadedNodeBase subclass that launches a local threaded - HTTP server running at `self.server_location` and connected to - the host of `upload_location`. The requests sent to this server - will be sent to the `upload_location` after adding `extra_headers` - to its headers. if `temp_dir` argument was not None, DASH and HLS - manifests will be stored in it before sending them to `upload_location`. + def handle_chunk(self, data: bytes) -> None: + self._chunk_file.write(data) + + def end_chunked(self) -> None: + self._chunk_file.close() + + +class S3Handler(RequestHandlerBase): + def __init__(self, upload_location: str, *args, **kwargs) -> None: + self._upload_location = upload_location + + # The HTTP server passes *args and *kwargs that we need to pass along, but + # don't otherwise care about. + super().__init__(*args, **kwargs) + + def handle_non_chunked(self, path: str, length: int, file: IO) -> None: + # FIXME: S3 upload + pass + + def start_chunked(self, path: str) -> None: + # FIXME: S3 upload + pass + + def handle_chunk(self, data: bytes) -> None: + # FIXME: S3 upload + pass + + def end_chunked(self) -> None: + # FIXME: S3 upload + pass + + +class HTTPUploadBase(ThreadedNodeBase): + """Runs an HTTP server at `self.server_location` to upload to cloud. + + Subclasses handle upload to specific cloud storage providers. The local HTTP server at `self.server_location` can only ingest PUT requests. """ - def __init__(self, upload_location: str, extra_headers: Dict[str, str], - temp_dir: Optional[str], - periodic_job_wait_time: float = 3600 * 24 * 365.25): - + def __init__(self) -> None: super().__init__(thread_name=self.__class__.__name__, continue_on_exception=True, - sleep_time=periodic_job_wait_time) + sleep_time=3) - self.temp_dir = temp_dir - self.RequestHandlersFactory = RequestHandlersFactory(upload_location, - extra_headers, - self.temp_dir) + handler_factory = ( + lambda *args, **kwargs: self.create_handler(*args, **kwargs)) # By specifying port 0, a random unused port will be chosen for the server. 
- self.server = ThreadingHTTPServer(('localhost', 0), - self.RequestHandlersFactory) + self.server = ThreadingHTTPServer(('localhost', 0), handler_factory) self.server_location = 'http://' + self.server.server_name + \ ':' + str(self.server.server_port) @@ -211,155 +185,80 @@ def __init__(self, upload_location: str, extra_headers: Dict[str, str], self.server_thread = threading.Thread(name=self.server_location, target=self.server.serve_forever) - def stop(self, status: Optional[ProcessStatus]): + @abc.abstractmethod + def create_handler(self, *args, **kwargs) -> BaseHTTPRequestHandler: + """Returns a cloud-provider-specific request handler to upload to cloud.""" + pass + + def stop(self, status: Optional[ProcessStatus]) -> None: self.server.shutdown() self.server_thread.join() return super().stop(status) - def start(self): + def start(self) -> None: self.server_thread.start() return super().start() - def _thread_single_pass(self): - return self.periodic_job() + def check_status(self) -> ProcessStatus: + # This makes sure this node will never prevent the shutdown of the whole + # system. It will be stopped explicitly when ControllerNode tears down. + return ProcessStatus.Finished - def periodic_job(self) -> None: - # Ideally, we will have nothing to do periodically after the wait time - # which is a very long time by default. However, this can be overridden - # by subclasses and populated with calls to all the functions that need - # to be executed periodically. + def _thread_single_pass(self) -> None: + # Nothing to do here. return -class GCSUpload(HTTPUpload): - """The upload node used when PUT requesting to a GCS bucket. +class GCSUpload(HTTPUploadBase): + """Upload to Google Cloud Storage.""" - It will parse the `upload_location` argument with `gs://` protocol - and use the GCP REST API that uses HTTPS protocol instead. - """ + def __init__(self, upload_location: str) -> None: + url = urllib.parse.urlparse(upload_location) + self._client = gcs.Client() + self._bucket = self._client.bucket(url.netloc) + # Strip both left and right slashes. Otherwise, we get a blank folder name. + self._base_path = url.path.strip('/') + super().__init__() - def __init__(self, upload_location: str, extra_headers: Dict[str, str], - temp_dir: Optional[str]): - upload_location = 'https://storage.googleapis.com/' + upload_location[5:] - - # Normalize the extra headers dictionary. - for key in list(extra_headers.copy()): - extra_headers[key.lower()] = extra_headers.pop(key) - - # We don't have to get a refresh token. Maybe there is an access token - # provided and we won't outlive it anyway, but that's the user's responsibility. - self.refresh_token = extra_headers.pop('refresh-token', None) - self.client_id = extra_headers.pop('client-id', None) - self.client_secret = extra_headers.pop('client-secret', None) - # The access token expires after 3600s in GCS. - refresh_period = int(extra_headers.pop('refresh-every', None) or 3300) - - super().__init__(upload_location, extra_headers, temp_dir, refresh_period) - - # We yet don't have an access token, so we need to get a one. 
- self._refresh_access_token() - - def _refresh_access_token(self): - if (self.refresh_token is not None - and self.client_id is not None - and self.client_secret is not None): - conn = HTTPSConnection('oauth2.googleapis.com') - req_body = { - 'grant_type': 'refresh_token', - 'refresh_token': self.refresh_token, - 'client_id': self.client_id, - 'client_secret': self.client_secret, - } - conn.request('POST', '/token', json.dumps(req_body)) - res = conn.getresponse() - if res.status == OK: - res_body = json.loads(res.read()) - # Update the Authorization header that the request factory has. - auth = res_body['token_type'] + ' ' + res_body['access_token'] - self.RequestHandlersFactory.update_headers(Authorization=auth) - else: - print("Couldn't refresh access token. ErrCode: {}, ErrMst: {!r}".format( - res.status, res.read())) - else: - print("Non sufficient info provided to refresh the access token.") - print("To refresh access token periodically, 'refresh-token', 'client-id'" - " and 'client-secret' headers must be provided.") - print("After the current access token expires, the upload will fail.") + def create_handler(self, *args, **kwargs) -> BaseHTTPRequestHandler: + """Returns a cloud-provider-specific request handler to upload to cloud.""" + return GCSHandler(self._bucket, self._base_path, *args, **kwargs) - def periodic_job(self) -> None: - self._refresh_access_token() +class S3Upload(HTTPUploadBase): + """Upload to Amazon S3.""" -class S3Upload(HTTPUpload): - """The upload node used when PUT requesting to a S3 bucket. + def __init__(self, upload_location: str) -> None: + self._upload_location = upload_location + super().__init__() - It will parse the `upload_location` argument with `s3://` protocol - and use the AWS REST API that uses HTTPS protocol instead. - """ + def create_handler(self, *args, **kwargs) -> BaseHTTPRequestHandler: + """Returns a cloud-provider-specific request handler to upload to cloud.""" + return S3Handler(self._upload_location, *args, **kwargs) - def __init__(self, upload_location: str, extra_headers: Dict[str, str], - temp_dir: Optional[str]): - raise NotImplementedError("S3 uploads aren't working yet.") - url_parts = upload_location[5:].split('/', 1) - bucket = url_parts[0] - path = '/' + url_parts[1] if len(url_parts) > 1 else '' - upload_location = 'https://' + bucket + '.s3.amazonaws.com' + path - - # We don't have to get a refresh token. Maybe there is an access token - # provided and we won't outlive it anyway, but that's the user's responsibility. - self.refresh_token = extra_headers.pop('refresh-token', None) - self.client_id = extra_headers.pop('client-id', None) - # The access token expires after 3600s in S3. - refresh_period = int(extra_headers.pop('refresh-every', None) or 3300) - - super().__init__(upload_location, extra_headers, temp_dir, refresh_period) - - # We yet don't have an access token, so we need to get a one. - self._refresh_access_token() - - def _refresh_access_token(self): - if (self.refresh_token is not None and self.client_id is not None): - conn = HTTPSConnection('api.amazon.com') - req_body = { - 'grant_type': 'refresh_token', - 'refresh_token': self.refresh_token, - 'client_id': self.client_id, - } - conn.request('POST', '/auth/o2/token', json.dumps(req_body)) - res = conn.getresponse() - if res.status == OK: - res_body = json.loads(res.read()) - # Update the Authorization header that the request factory has. 
- auth = res_body['token_type'] + ' ' + res_body['access_token'] - self.RequestHandlersFactory.update_headers(Authorization=auth) - else: - print("Couldn't refresh access token. ErrCode: {}, ErrMst: {!r}".format( - res.status, res.read())) - else: - print("Non sufficient info provided to refresh the access token.") - print("To refresh access token periodically, 'refresh-token'" - " and 'client-id' headers must be provided.") - print("After the current access token expires, the upload will fail.") - def periodic_job(self) -> None: - self._refresh_access_token() +class ProxyNode(object): + SUPPORTED_PROTOCOLS = SUPPORTED_PROTOCOLS + ALL_SUPPORTED_PROTOCOLS = ALL_SUPPORTED_PROTOCOLS + @staticmethod + def create(upload_location: str) -> HTTPUploadBase: + """Creates an upload node based on the protocol used in |upload_location|.""" + if upload_location.startswith("gs://"): + return GCSUpload(upload_location) + elif upload_location.startswith("s3://"): + return S3Upload(upload_location) + else: + raise RuntimeError("Protocol of {} isn't supported".format(upload_location)) -def get_upload_node(upload_location: str, extra_headers: Dict[str, str], - temp_dir: Optional[str] = None) -> HTTPUpload: - """Instantiates an appropriate HTTPUpload node based on the protocol - used in `upload_location` url. - """ + @staticmethod + def is_understood(upload_location: str) -> bool: + """Is the URL understood, independent of libraries available?""" + url = urllib.parse.urlparse(upload_location) + return url.scheme in ALL_SUPPORTED_PROTOCOLS - if upload_location.startswith(("http://", "https://")): - return HTTPUpload(upload_location, extra_headers, temp_dir) - elif upload_location.startswith("gs://"): - return GCSUpload(upload_location, extra_headers, temp_dir) - elif upload_location.startswith("s3://"): - return S3Upload(upload_location, extra_headers, temp_dir) - else: - raise RuntimeError("Protocol of {} isn't supported".format(upload_location)) - -def is_supported_protocol(upload_location: str) -> bool: - return bool([upload_location.startswith(protocol + '://') for - protocol in SUPPORTED_PROTOCOLS].count(True)) + @staticmethod + def is_supported(upload_location: str) -> bool: + """Is the URL supported with the libraries available?""" + url = urllib.parse.urlparse(upload_location) + return url.scheme in SUPPORTED_PROTOCOLS diff --git a/streamer/util.py b/streamer/util.py index 19690ce..7250df9 100644 --- a/streamer/util.py +++ b/streamer/util.py @@ -14,99 +14,7 @@ """Utility functions used by multiple modules.""" -import io -from typing import Optional - - def is_url(output_location: str) -> bool: """Returns True if the output location is a URL.""" - return output_location.startswith(('http://', - 'https://')) - - -class RequestBodyAsFileIO(io.BufferedIOBase): - """A class that provides a layer of access to an HTTP request body. It provides - an interface to treat a request body (of type `io.BufferedIOBase`) as a file. - Since a request body does not have an `EOF`, this class will encapsulate the - logic of using Content-Length or chunk size to provide an emulated `EOF`. - - This implementation is much faster than storing the request body - in the filesystem then reading it with an `EOF` included. - """ - - def __init__(self, rfile: io.BufferedIOBase, content_length: Optional[int]): - super().__init__() - self._body = rfile - # Decide whether this is a chunked request or not based on content length. 
- if content_length is not None: - self._is_chunked = False - self._left_to_read = content_length - else: - self._is_chunked = True - self._last_chunk_read = False - self._buffer = b'' - - def read(self, blocksize: Optional[int] = None) -> bytes: - """This method reads `self.body` incrementally with each call. - This is done because if we try to use `read()` on `self._body` it will wait - forever for an `EOF` which is not present and will never be. - - This method -like the original `read()`- will read up to (but not more than) - `blocksize` if it is a non-negative integer, and will read till `EOF` if - blocksize is None, a negative integer, or not passed. - """ - - if self._is_chunked: - return self._read_chunked(blocksize) - else: - return self._read_not_chunked(blocksize) - - def _read_chunked(self, blocksize: Optional[int] = None) -> bytes: - """This method provides the read functionality from a request - body with chunked Transfer-Encoding. - """ - - # For non-negative blocksize values. - if blocksize and blocksize >= 0: - # Keep buffering until we can fulfil the blocksize or there - # are no chunks left to buffer. - while blocksize > len(self._buffer) and not self._last_chunk_read: - byte_chunk_size = self._body.readline() - self._buffer += byte_chunk_size - int_chunk_size = int(byte_chunk_size.strip(), base=16) - self._buffer += self._body.read(int_chunk_size) - # Consume the CLRF after each chunk. - self._buffer += self._body.readline() - if int_chunk_size == 0: - # A zero sized chunk indicates that no more chunks left. - self._last_chunk_read = True - bytes_read, self._buffer = self._buffer[:blocksize], self._buffer[blocksize:] - return bytes_read - # When blocksize is a negative integer or None. - else: - bytes_read = b'' - while True: - chunk = self._read_chunked(64 * 1024) - bytes_read += chunk - if chunk == b'': - return bytes_read - - def _read_not_chunked(self, blocksize: Optional[int] = None) -> bytes: - """This method provides the read functionality from a request - body of a known Content-Length. - """ - - # Don't try to read if there is nothing to read. - if self._left_to_read == 0: - # This indicates `EOF` for the caller. - return b'' - # For non-negative blocksize values. - if blocksize and blocksize >= 0: - size_to_read = min(blocksize, self._left_to_read) - self._left_to_read -= size_to_read - return self._body.read(size_to_read) - # When blocksize is a negative integer or None. - else: - size_to_read, self._left_to_read = self._left_to_read, 0 - return self._body.read(size_to_read) - + return (output_location.startswith('http:') or + output_location.startswith('https:'))
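
Note on the S3 path: S3Handler above is stubbed out with FIXME comments. As a point of reference only, here is a minimal sketch of how those four handler methods might be filled in with boto3 (which the patch already imports as `aws`). The class name, the 5 MB part-size handling, and the way the bucket name and base path are parsed from the s3:// URL are illustrative assumptions rather than part of this patch, and the sketch is written as a standalone class instead of a RequestHandlerBase subclass so it stays self-contained.

import urllib.parse
from typing import IO, Any, Dict, List

import boto3  # type: ignore

# S3 requires every multipart part except the last to be at least 5 MB.
MIN_S3_PART_SIZE = 5 * (1 << 20)


class S3HandlerSketch:
  """Hypothetical stand-in for the four abstract upload methods."""

  def __init__(self, upload_location: str) -> None:
    url = urllib.parse.urlparse(upload_location)
    self._client = boto3.client('s3')
    self._bucket_name = url.netloc
    # Strip slashes so we don't create a blank folder name, mirroring GCSUpload.
    self._base_path = url.path.strip('/')

  def handle_non_chunked(self, path: str, length: int, file: IO) -> None:
    # Read exactly |length| bytes; the socket stream never reports EOF.
    self._client.put_object(Bucket=self._bucket_name,
                            Key=self._base_path + path,
                            Body=file.read(length))

  def start_chunked(self, path: str) -> None:
    # S3 has no streaming equivalent of blob.open('wb'), so buffer incoming
    # chunks and push them through a multipart upload.
    self._key = self._base_path + path
    self._upload_id = self._client.create_multipart_upload(
        Bucket=self._bucket_name, Key=self._key)['UploadId']
    self._parts: List[Dict[str, Any]] = []
    self._buffer = b''

  def handle_chunk(self, data: bytes) -> None:
    self._buffer += data
    if len(self._buffer) >= MIN_S3_PART_SIZE:
      self._flush_part()

  def end_chunked(self) -> None:
    # Always upload at least one part so the multipart upload can complete.
    if self._buffer or not self._parts:
      self._flush_part()
    self._client.complete_multipart_upload(
        Bucket=self._bucket_name, Key=self._key, UploadId=self._upload_id,
        MultipartUpload={'Parts': self._parts})

  def _flush_part(self) -> None:
    part_number = len(self._parts) + 1
    response = self._client.upload_part(
        Bucket=self._bucket_name, Key=self._key, PartNumber=part_number,
        UploadId=self._upload_id, Body=self._buffer)
    self._parts.append({'ETag': response['ETag'], 'PartNumber': part_number})
    self._buffer = b''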