From e6d4db8b242ed9929d0337a0d152b7ae67ab3218 Mon Sep 17 00:00:00 2001
From: meryacine
Date: Sat, 23 Oct 2021 07:43:24 +0200
Subject: [PATCH] feat(cloud): Upload through HTTP proxy node

---
 shaka-streamer                     |  23 +-
 streamer/cloud_node.py             | 190 ---------------
 streamer/controller_node.py        |  91 +++----
 streamer/node_base.py              |   5 +-
 streamer/packager_node.py          |   4 +-
 streamer/periodconcat_node.py      |  66 +++++-
 streamer/pipeline_configuration.py |   5 +
 streamer/proxy_node.py             | 365 +++++++++++++++++++++++++++++
 streamer/util.py                   |  96 +++++++-
 9 files changed, 585 insertions(+), 260 deletions(-)
 delete mode 100644 streamer/cloud_node.py
 create mode 100644 streamer/proxy_node.py

diff --git a/shaka-streamer b/shaka-streamer
index f406b3f..9481a7c 100755
--- a/shaka-streamer
+++ b/shaka-streamer
@@ -61,15 +61,16 @@ def main():
                            'bitrates and resolutions for transcoding. ' +
                            '(optional, see example in ' +
                            'config_files/bitrate_config.yaml)')
-  parser.add_argument('-c', '--cloud-url',
-                      default=None,
-                      help='The Google Cloud Storage or Amazon S3 URL to ' +
-                           'upload to. (Starts with gs:// or s3://)')
   parser.add_argument('-o', '--output',
                       default='output_files',
                       help='The output folder to write files to, or an HTTP ' +
                            'or HTTPS URL where files will be PUT.' +
                            'Used even if uploading to cloud storage.')
+  parser.add_argument('-H', '--add-header',
+                      action='append',
+                      metavar='key=value',
+                      help='Headers to include when sending PUT requests ' +
+                           'to the specified output location.')
   parser.add_argument('--skip-deps-check',
                       action='store_true',
                       help='Skip checks for dependencies and their versions. ' +
@@ -96,14 +97,15 @@ def main():
     with open(args.bitrate_config) as f:
       bitrate_config_dict = yaml.safe_load(f)
 
-  if args.cloud_url:
-    if (not args.cloud_url.startswith('gs://') and
-        not args.cloud_url.startswith('s3://')):
-      parser.error('Invalid cloud URL! Only gs:// and s3:// URLs are supported')
+  extra_headers = {}
+  if args.add_header:
+    for header in args.add_header:
+      key, value = tuple(header.split('=', 1))
+      extra_headers[key] = value
 
   try:
     with controller.start(args.output, input_config_dict, pipeline_config_dict,
-                          bitrate_config_dict, args.cloud_url,
+                          bitrate_config_dict, extra_headers,
                           not args.skip_deps_check,
                           not args.use_system_binaries):
       # Sleep so long as the pipeline is still running.
@@ -114,8 +116,7 @@ def main():
         time.sleep(1)
 
   except (streamer.controller_node.VersionError,
-          streamer.configuration.ConfigError,
-          streamer.cloud_node.CloudAccessError) as e:
+          streamer.configuration.ConfigError) as e:
     # These are common errors meant to give the user specific, helpful
     # information. Format these errors in a relatively friendly way, with no
     # backtrace or other Python-specific information.
diff --git a/streamer/cloud_node.py b/streamer/cloud_node.py
deleted file mode 100644
index d44d675..0000000
--- a/streamer/cloud_node.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2019 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Pushes output from packager to cloud.""" - -import glob -import os -from streamer.packager_node import PackagerNode -import subprocess -import time - -from streamer.node_base import ProcessStatus, ThreadedNodeBase -from typing import Optional, List - -# This is the HTTP header "Cache-Control" which will be attached to the Cloud -# Storage blobs uploaded by this tool. When the browser requests a file from -# Cloud Storage, the server will use this as the "Cache-Control" header it -# returns. -# -# Here "no-store" means that the response must not be stored in a cache, and -# "no-transform" means that the response must not be manipulated in any way -# (including Chrome's data saver features which might want to re-encode -# content). -# -# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control -CACHE_CONTROL_HEADER = 'Cache-Control: no-store, no-transform' - -COMMON_GSUTIL_ARGS = [ - 'gsutil', - '-q', # quiet mode: report errors, but not progress - '-h', CACHE_CONTROL_HEADER, # set the appropriate cache header on uploads - '-m', # parllelize the operation - 'rsync', # operation to perform - '-C', # still try to push other files if one fails - '-r', # recurse into folders -] - -class CloudAccessError(Exception): - """Raised when the cloud URL cannot be written to by the user.""" - pass - -class CloudNode(ThreadedNodeBase): - def __init__(self, - input_dir: str, - bucket_url: str, - temp_dir: str, - packager_nodes: List[PackagerNode], - is_vod: bool): - super().__init__(thread_name='cloud', continue_on_exception=True, sleep_time=1) - self._input_dir: str = input_dir - self._bucket_url: str = bucket_url - self._temp_dir: str = temp_dir - self._packager_nodes: List[PackagerNode] = packager_nodes - self._is_vod: bool = is_vod - - @staticmethod - def check_access(bucket_url: str) -> None: - """Called early to test that the user can write to the destination bucket. - - Writes an empty file called ".shaka-streamer-access-check" to the - destination. Raises CloudAccessError if the destination cannot be written - to. - """ - - # Note that we make sure there are not two slashes in a row here, which - # would create a subdirectory whose name is "". - destination = bucket_url.rstrip('/') + '/.shaka-streamer-access-check' - # Note that this can't be "gsutil ls" on the destination, because the user - # might have read-only access. In fact, some buckets grant read-only - # access to anonymous (non-logged-in) users. So writing to the bucket is - # the only way to check. - args = ['gsutil', 'cp', '-', destination] - status = subprocess.run(args, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.PIPE, - universal_newlines=True) - # If the command failed, raise an error. - if status.returncode != 0: - message = """Unable to write to cloud storage URL: {} - -Please double-check that the URL is correct, that you are signed into the -Google Cloud SDK or Amazon AWS CLI, and that you have access to the -destination bucket. - -Additional output from gsutil: - {}""".format(bucket_url, status.stderr) - raise CloudAccessError(message) - - def _thread_single_pass(self) -> None: - - # Sync the files with the cloud storage. - self._upload() - - for packager_node in self._packager_nodes: - status = packager_node.check_status() - if status == ProcessStatus.Running: - return - - # Do one last sync to be sure that the latest versions of the files are uploaded. 
- self._upload() - self._status = ProcessStatus.Finished - - def _upload(self) -> None: - # With recursive=True, glob's ** will also match the base dir. - manifest_files = ( - glob.glob(self._input_dir + '/**/*.mpd', recursive=True) + - glob.glob(self._input_dir + '/**/*.m3u8', recursive=True)) - - # The manifest at any moment will reference existing segment files. - # We must be careful not to upload a manifest that references segments that - # haven't been uploaded yet. So first we will capture manifest contents, - # then upload current segments, then upload the manifest contents we - # captured. - - for manifest_path in manifest_files: - # The path within the input dir. - subdir_path = os.path.relpath(manifest_path, self._input_dir) - - # Capture manifest contents, and retry until the file is non-empty or - # until the thread is killed. - with open(manifest_path, 'rb') as f: - contents = f.read() - - while (not contents and - self.check_status() == ProcessStatus.Running): - time.sleep(0.1) - - with open(manifest_path, 'rb') as f: - contents = f.read() - - # Now that we have manifest contents, put them into a temp file so that - # the manifests can be pushed en masse later. - temp_file_path = os.path.join(self._temp_dir, subdir_path) - # Create any necessary intermediate folders. - temp_file_dir_path = os.path.dirname(temp_file_path) - os.makedirs(temp_file_dir_path, exist_ok=True) - # Write the temp file. - with open(temp_file_path, 'wb') as f: - f.write(contents) - - # Sync all files except manifest files. - args = COMMON_GSUTIL_ARGS + [ - '-d', # delete remote files that are no longer needed - '-x', '.*m3u8', # skip m3u8 files, which we'll push separately later - '-x', '.*mpd', # skip mpd files, which we'll push separately later - self._input_dir, # local input folder to sync - self._bucket_url, # destination in cloud storage - ] - # NOTE: The -d option above will not result in the files ignored by -x - # being deleted from the remote storage location. - subprocess.check_call(args) - - compression_args = [] - if self._bucket_url.startswith('gs:'): - # This arg seems to fail on S3, but still works for GCS. - compression_args = [ - '-J', # compress all files in transit, since they are text - ] - - # Sync the temporary copies of the manifest files. - args = COMMON_GSUTIL_ARGS + compression_args + [ - self._temp_dir, # local input folder to sync - self._bucket_url, # destination in cloud storage - ] - subprocess.check_call(args) - - def stop(self, - status: Optional[ProcessStatus]) -> None: - super().stop(status) - - # A fix for issue #30: - if self._is_vod: - # After processing the stop, run _one more_ pass. This is how we ensure - # that the final version of a VOD asset gets uploaded to cloud storage. - # Otherwise, we might not have the final manifest or every single segment - # uploaded. 
- self._thread_single_pass() - diff --git a/streamer/controller_node.py b/streamer/controller_node.py index d5430d0..25cf090 100644 --- a/streamer/controller_node.py +++ b/streamer/controller_node.py @@ -32,7 +32,7 @@ from streamer import __version__ from streamer import autodetect from streamer import min_versions -from streamer.cloud_node import CloudNode +from streamer import proxy_node from streamer.bitrate_configuration import BitrateConfig, AudioChannelLayout, VideoResolution from streamer.external_command_node import ExternalCommandNode from streamer.input_configuration import InputConfig, InputType, MediaType, Input @@ -42,6 +42,7 @@ from streamer.pipeline_configuration import ManifestFormat, PipelineConfig, StreamingMode from streamer.transcoder_node import TranscoderNode from streamer.periodconcat_node import PeriodConcatNode +from streamer.proxy_node import HTTPUpload import streamer.subprocessWindowsPatch # side-effects only from streamer.util import is_url from streamer.pipe import Pipe @@ -60,6 +61,7 @@ def __init__(self) -> None: dir=global_temp_dir, prefix='shaka-live-', suffix='') self._nodes: List[NodeBase] = [] + self.upload_proxy: Optional[HTTPUpload] = None def __del__(self) -> None: # Clean up named pipes by removing the temp directory we placed them in. @@ -75,7 +77,7 @@ def start(self, output_location: str, input_config_dict: Dict[str, Any], pipeline_config_dict: Dict[str, Any], bitrate_config_dict: Dict[Any, Any] = {}, - bucket_url: Union[str, None] = None, + extra_headers: Dict[str, str] = {}, check_deps: bool = True, use_hermetic: bool = True) -> 'ControllerNode': """Create and start all other nodes. @@ -147,21 +149,6 @@ def next_short_version(version: str) -> str: _check_command_version('Shaka Packager', ['packager', '-version'], min_versions.PACKAGER) - if bucket_url: - # Check that the Google Cloud SDK is at least v212, which introduced - # gsutil 4.33 with an important rsync bug fix. - # https://cloud.google.com/sdk/docs/release-notes - # https://github.com/GoogleCloudPlatform/gsutil/blob/master/CHANGES.md - # This is only required if the user asked for upload to cloud storage. - _check_command_version('Google Cloud SDK', ['gcloud', '--version'], - (212, 0, 0)) - - - if bucket_url: - # If using cloud storage, make sure the user is logged in and can access - # the destination, independent of the version check above. - CloudNode.check_access(bucket_url) - self.hermetic_ffmpeg: Optional[str] = None self.hermetic_packager: Optional[str] = None if use_hermetic: @@ -181,6 +168,16 @@ def next_short_version(version: str) -> str: self._input_config = InputConfig(input_config_dict) self._pipeline_config = PipelineConfig(pipeline_config_dict) + # Note that we remove the trailing slash from the output location, because + # otherwise GCS would create a subdirectory whose name is "". + output_location = output_location.rstrip('/') + if (proxy_node.is_supported_protocol(output_location) + and self._pipeline_config.use_local_proxy): + self.upload_proxy = self.get_upload_node(output_location, extra_headers) + # All the outputs now should be sent to the proxy server instead. + output_location = self.upload_proxy.server_location + self._nodes.append(self.upload_proxy) + if not is_url(output_location): # Check if the directory for outputted Packager files exists, and if it # does, delete it and remake a new one. 
@@ -194,15 +191,6 @@ def next_short_version(version: str) -> str: 'For HTTP PUT uploads, the pipeline segment_per_file setting ' + 'must be set to True!') - if bucket_url: - raise RuntimeError( - 'Cloud bucket upload is incompatible with HTTP PUT support.') - - if self._input_config.multiperiod_inputs_list: - # TODO: Edit Multiperiod input list implementation to support HTTP outputs - raise RuntimeError( - 'Multiperiod input list support is incompatible with HTTP outputs.') - if self._pipeline_config.low_latency_dash_mode: # Check some restrictions on LL-DASH packaging. if ManifestFormat.DASH not in self._pipeline_config.manifest_format: @@ -214,10 +202,6 @@ def next_short_version(version: str) -> str: raise RuntimeError( 'For low_latency_dash_mode, the utc_timings must be set.') - # Note that we remove the trailing slash from the output location, because - # otherwise GCS would create a subdirectory whose name is "". - output_location = output_location.rstrip('/') - if self._input_config.inputs: # InputConfig contains inputs only. self._append_nodes_for_inputs_list(self._input_config.inputs, @@ -236,18 +220,8 @@ def next_short_version(version: str) -> str: self._nodes.append(PeriodConcatNode( self._pipeline_config, packager_nodes, - output_location)) - - if bucket_url: - cloud_temp_dir = os.path.join(self._temp_dir, 'cloud') - os.mkdir(cloud_temp_dir) - - packager_nodes = [node for node in self._nodes if isinstance(node, PackagerNode)] - self._nodes.append(CloudNode(output_location, - bucket_url, - cloud_temp_dir, - packager_nodes, - self.is_vod())) + output_location, + self.upload_proxy)) for node in self._nodes: node.start() @@ -335,7 +309,8 @@ def _append_nodes_for_inputs_list(self, inputs: List[Input], # and put that period in it. if period_dir: output_location = os.path.join(output_location, period_dir) - os.mkdir(output_location) + if not is_url(output_location): + os.mkdir(output_location) self._nodes.append(PackagerNode(self._pipeline_config, output_location, @@ -349,11 +324,15 @@ def check_status(self) -> ProcessStatus: If one node is errored, this returns Errored; otherwise if one node is running, this returns Running; this only returns Finished if all nodes are finished. If there are no nodes, this returns Finished. + + :rtype: ProcessStatus """ if not self._nodes: return ProcessStatus.Finished - value = max(node.check_status().value for node in self._nodes) + value = max(node.check_status().value for node in self._nodes + # We don't check the the upload node. + if node != self.upload_proxy) return ProcessStatus(value) def stop(self) -> None: @@ -379,6 +358,30 @@ def is_low_latency_dash_mode(self) -> bool: return self._pipeline_config.low_latency_dash_mode + def get_upload_node(self, upload_location: str, + extra_headers: Dict[str, str]) -> HTTPUpload: + """ + Args: + upload_location (str): The location where media content will be uploaded. + extra_headers (Dict[str, str]): Extra headers to be added when + sending the PUT request to `upload_location`. + + :rtype: HTTPUpload + :raises: `RuntimeError` if the protocol used in `upload_location` was not + recognized. + """ + + # We need to pass a temporary direcotry when working with multi-period input + # and using HTTP PUT for uploading. This is so that the HTTPUpload + # keeps a copy of the manifests in the temporary directory so we can use them + # later to assemble the multi-period manifests. 
+ upload_temp_dir = None + if self._input_config.multiperiod_inputs_list: + upload_temp_dir = os.path.join(self._temp_dir, 'multiperiod_manifests') + os.mkdir(upload_temp_dir) + return proxy_node.get_upload_node(upload_location, extra_headers, + upload_temp_dir) + class VersionError(Exception): """A version error for one of Shaka Streamer's external dependencies. diff --git a/streamer/node_base.py b/streamer/node_base.py index f7222f4..5092c79 100644 --- a/streamer/node_base.py +++ b/streamer/node_base.py @@ -166,6 +166,7 @@ def __init__(self, thread_name: str, continue_on_exception: bool, sleep_time: fl self._thread_name = thread_name self._continue_on_exception = continue_on_exception self._sleep_time = sleep_time + self._sleep_waker_event = threading.Event() self._thread = threading.Thread(target=self._thread_main, name=thread_name) def _thread_main(self) -> None: @@ -183,7 +184,7 @@ def _thread_main(self) -> None: return # Wait a little bit before performing the next pass. - time.sleep(self._sleep_time) + self._sleep_waker_event.wait(self._sleep_time) @abc.abstractmethod def _thread_single_pass(self) -> None: @@ -204,6 +205,8 @@ def start(self) -> None: def stop(self, status: Optional[ProcessStatus]) -> None: self._status = ProcessStatus.Finished + # If the thread was sleeping, wake it up. + self._sleep_waker_event.set() self._thread.join() def check_status(self) -> ProcessStatus: diff --git a/streamer/packager_node.py b/streamer/packager_node.py index ae9ccdd..e6e20ca 100644 --- a/streamer/packager_node.py +++ b/streamer/packager_node.py @@ -183,7 +183,7 @@ def _setup_manifest_format(self) -> List[str]: args += [ # Generate DASH manifest file. '--mpd_output', - os.path.join(self.output_location, self._pipeline_config.dash_output), + build_path(self.output_location, self._pipeline_config.dash_output), ] if ManifestFormat.HLS in self._pipeline_config.manifest_format: if self._pipeline_config.streaming_mode == StreamingMode.LIVE: @@ -197,7 +197,7 @@ def _setup_manifest_format(self) -> List[str]: args += [ # Generate HLS playlist file(s). 
'--hls_master_playlist_output', - os.path.join(self.output_location, self._pipeline_config.hls_output), + build_path(self.output_location, self._pipeline_config.hls_output), ] return args diff --git a/streamer/periodconcat_node.py b/streamer/periodconcat_node.py index 2b11f23..6fd43f6 100644 --- a/streamer/periodconcat_node.py +++ b/streamer/periodconcat_node.py @@ -17,14 +17,17 @@ import os import re import time -from typing import List +from typing import List, Optional from xml.etree import ElementTree +from http.client import HTTPConnection, CREATED from streamer import __version__ from streamer.node_base import ProcessStatus, ThreadedNodeBase from streamer.packager_node import PackagerNode from streamer.pipeline_configuration import PipelineConfig, ManifestFormat from streamer.output_stream import AudioOutputStream, VideoOutputStream from streamer.m3u8_concater import HLSConcater +from streamer.proxy_node import HTTPUpload +from streamer.util import is_url class PeriodConcatNode(ThreadedNodeBase): @@ -35,17 +38,20 @@ class PeriodConcatNode(ThreadedNodeBase): def __init__(self, pipeline_config: PipelineConfig, packager_nodes: List[PackagerNode], - output_location: str) -> None: + output_location: str, + upload_proxy: Optional[HTTPUpload]) -> None: """Stores all relevant information needed for the period concatenation.""" - super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=3) + super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=1) self._pipeline_config = pipeline_config self._output_location = output_location self._packager_nodes: List[PackagerNode] = packager_nodes - self._concat_will_fail = False - + self._proxy_node = upload_proxy + self._concat_will_fail = self._check_failed_concatenation() + + def _check_failed_concatenation(self) -> bool: # know whether the first period has video and audio or not. fp_has_vid, fp_has_aud = False, False - for output_stream in packager_nodes[0].output_streams: + for output_stream in self._packager_nodes[0].output_streams: if isinstance(output_stream, VideoOutputStream): fp_has_vid = True elif isinstance(output_stream, AudioOutputStream): @@ -59,7 +65,6 @@ def __init__(self, elif isinstance(output_stream, AudioOutputStream): has_aud = True if has_vid != fp_has_vid or has_aud != fp_has_aud: - self._concat_will_fail = True print("\nWARNING: Stopping period concatenation.") print("Period#{} has {}video and has {}audio while Period#1 " "has {}video and has {}audio.".format(i + 1, @@ -72,8 +77,21 @@ def __init__(self, "\tperiods with other periods that have video.\n" "\tThis is necessary for the concatenation to be performed successfully.\n") time.sleep(5) - break - + return True + + if self._proxy_node is None and is_url(self._output_location): + print("\nWARNING: Stopping period concatenation.") + print("Shaka Packager is using HTTP PUT but not using" + " Shaka Streamer's upload proxy.") + print("\nHINT:\n\tShaka Streamer's upload proxy stores the manifest files\n" + "\ttemporarily in the local filesystem to use them for period concatenation.\n" + "\tSet use_local_proxy to True in the pipeline config to enable the" + " upload proxy.\n") + time.sleep(5) + return True + # Otherwise, we don't have a reason to fail. 
+ return False + def _thread_single_pass(self) -> None: """Watches all the PackagerNode(s), if at least one of them is running it skips this _thread_single_pass, if all of them are finished, it starts period concatenation, if one of @@ -90,13 +108,41 @@ def _thread_single_pass(self) -> None: 'to an error in PackagerNode#{}.'.format(i + 1)) if self._concat_will_fail: - raise RuntimeError('Unable to concatenate the inputs.') + raise RuntimeError('Unable to concatenate the inputs') + + # If the packager was pushing HTTP requests to the stream's proxy server, + # the proxy server should have stored the manifest files in a temporary + # directory in the filesystem. + if self._proxy_node is not None: + assert self._proxy_node.temp_dir, ('There should be a proxy temp direcotry' + ' when processing multi-period input') + self._output_location = self._proxy_node.temp_dir + # As the period concatenator node is the last to run, changing the + # output location at run time won't disturb any other node. + for packager_node in self._packager_nodes: + packager_node.output_location = packager_node.output_location.replace( + self._proxy_node.server_location, + self._proxy_node.temp_dir, 1) if ManifestFormat.DASH in self._pipeline_config.manifest_format: self._dash_concat() if ManifestFormat.HLS in self._pipeline_config.manifest_format: self._hls_concat() + + # Push the concatenated manifests if a proxy is used. + if self._proxy_node is not None: + conn = HTTPConnection(self._proxy_node.server.server_name, + self._proxy_node.server.server_port) + # The concatenated manifest files where written in `self._output_location`. + for manifest_file_name in os.listdir(self._output_location): + if manifest_file_name.endswith(('.mpd', '.m3u8')): + manifest_file_path = os.path.join(self._output_location, manifest_file_name) + conn.request('PUT', '/' + manifest_file_name, open(manifest_file_path, 'r')) + res = conn.getresponse() + if res.status != CREATED: + print("Got unexpected status code: {}, Msg: {!r}".format(res.status, + res.read())) self._status = ProcessStatus.Finished diff --git a/streamer/pipeline_configuration.py b/streamer/pipeline_configuration.py index 8ce8aba..d898877 100644 --- a/streamer/pipeline_configuration.py +++ b/streamer/pipeline_configuration.py @@ -320,6 +320,11 @@ class PipelineConfig(configuration.Base): default=EncryptionConfig({})).cast() """Encryption settings.""" + use_local_proxy = configuration.Field(bool, default=True).cast() + """Whether to use shaka streamer's local proxy when uploading to a remote + storage. This must be set to True when uploading to GCS or amazon S3 buckets. + """ + # TODO: Generalize this to low_latency_mode once LL-HLS is supported by Packager low_latency_dash_mode = configuration.Field(bool, default=False).cast() """If true, stream in low latency mode for DASH.""" diff --git a/streamer/proxy_node.py b/streamer/proxy_node.py new file mode 100644 index 0000000..a0e3e9c --- /dev/null +++ b/streamer/proxy_node.py @@ -0,0 +1,365 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""A module that implements a simple proxy server for uploading packaged +content to cloud storage providers (GCS and S3) and also any server that +accepts PUT requests. +""" + +import os +import json +import threading +import urllib.parse +from typing import Optional, Union, Dict +from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler +from http.client import HTTPConnection, HTTPSConnection, CREATED, OK + +from streamer.node_base import ProcessStatus, ThreadedNodeBase +from streamer.util import RequestBodyAsFileIO + + +# Protocols we have classes for to handle them. +SUPPORTED_PROTOCOLS = ['http', 'https', 'gs', 's3'] + + +class RequestHandler(BaseHTTPRequestHandler): + """A request handler that processes the PUT requests coming from + shaka packager and pushes them to the destination. + """ + + def __init__(self, conn: Union[HTTPConnection, HTTPSConnection], + extra_headers: Dict[str, str], base_path: str, param_query: str, + temp_dir: Optional[str], *args, **kwargs): + + self._conn = conn + # Extra headers to add when sending the request to the host + # using `self._conn`. + self._extra_headers = extra_headers + # The base path that will be prepended to the path of the handled request + # before forwarding the request to the host using `self._conn`. + self._base_path = base_path + # Parameters and query string to send in the url with each forwarded request. + self._params_and_querystring = param_query + self._temp_dir = temp_dir + # Call `super().__init__()` last because this call is what handles the + # actual request and we need the variables defined above to handle + # this request. + super().__init__(*args, **kwargs) + + def do_PUT(self): + """do_PUT will handle the PUT requests coming from shaka packager.""" + + headers = {} + # Use the same headers for requesting. + for k, v in self.headers.items(): + if k.lower() != 'host': + headers[k] = v + # Add the extra headers, this might contain an access token for instance. + for k, v in self._extra_headers.items(): + headers[k] = v + + # Don't chunk by default, as the request body is already chunked + # or we have a content-length header which means we are not using + # chunked transfer-encoding. + encode_chunked = False + content_length = self.headers['Content-Length'] + # Store the manifest files locally in `self._temp_dir`. + if self._temp_dir is not None and self.path.endswith(('.mpd', '.m3u8')): + body = self._write_body_and_get_file_io() + if content_length == None: + # We need to re-chunk it again as we wrote it as a whole + # to the filesystem. + encode_chunked = True + else: + content_length = content_length and int(content_length) + body = RequestBodyAsFileIO(self.rfile, content_length) + + # The url will be the result of joining the base path we should + # send the request to with the path this request came to. + url = self._base_path + self.path + # Also include any parameters and the query string. + url += self._params_and_querystring + + self._conn.request('PUT', url, body, headers, encode_chunked=encode_chunked) + res = self._conn.getresponse() + # Disable response logging. + self.log_request = lambda _: None + # Respond to Shaka Packager with the response we got. + self.send_response(res.status) + self.end_headers() + # self.wfile.write(res.read()) + # The destination should send (201/CREATED), but some do also send (200/OK). 
+ if res.status != CREATED and res.status != OK: + print('Unexpected status for the PUT request:' + ' {}, ErrMsg: {!r}'.format(res.status, res.read())) + + def _write_body_and_get_file_io(self): + """A method that writes a request body to the filesystem + and returns a file io object opened for reading. + """ + + # Store the request body in `self._temp_dir`. + # Ignore the first '/' `self.path` as posixpath will think + # it points to the root direcotry. + path = os.path.join(self._temp_dir, self.path[1:]) + # With `exist_ok=True`, any intermidiate direcotries are created if needed. + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, 'wb') as request_body_file: + if self.headers['Content-Length'] is not None: + content_length = int(self.headers['Content-Length']) + request_body_file.write(self.rfile.read(content_length)) + else: + while True: + bytes_chunk_size = self.rfile.readline() + int_chunk_size = int(bytes_chunk_size.strip(), base=16) + request_body_file.write(self.rfile.read(int_chunk_size)) + # An empty newline that we have to consume. + self.rfile.readline() + # Chunk of size zero indicates that we have reached the end. + if int_chunk_size == 0: + break + return open(path, 'rb') + + +class RequestHandlersFactory(): + """A request handlers' factory that produces a RequestHandler whenever + its __call__ method is called. It stores all the relevant data that the + instantiated request handler will need when sending a request to the host. + """ + + def __init__(self, upload_location: str, initial_headers: Dict[str, str] = {}, + temp_dir: Optional[str] = None, max_conns: int = 50): + + url = urllib.parse.urlparse(upload_location) + if url.scheme not in ['http', 'https']: + # We can only instantiate HTTP/HTTPS connections. + raise RuntimeError("Unsupported scheme: {}", url.scheme) + self._ConnectionFactory = HTTPConnection if url.scheme == 'http' \ + else HTTPSConnection + self._destination_host = url.netloc + # Store the url path to prepend it to the path of each handled + # request before forwarding the request to `self._destination_host`. + self._base_path = url.path + # Store the parameters and the query string to send them in + # any request going to `self._destination_host`. + self._params_query = ';' + url.params if url.params else '' + self._params_query += '?' + url.query if url.query else '' + # These headers are going to be sent to `self._destination_host` + # with each request along with the headers that the request handler + # receives. Note that these extra headers can possibely overwrite + # the original request headers that the request handler received. + self._extra_headers = initial_headers + self._temp_dir = temp_dir + self._max_conns = max_conns + + def __call__(self, *args, **kwargs) -> RequestHandler: + """This magical method makes a RequestHandlersFactory instance + callable and returns a RequestHandler when called. + """ + + connection = self._ConnectionFactory(self._destination_host) + return RequestHandler(connection, self._extra_headers, + self._base_path, self._params_query, self._temp_dir, + *args, **kwargs) + + def update_headers(self, **kwargs): + self._extra_headers.update(**kwargs) + + +class HTTPUpload(ThreadedNodeBase): + """A ThreadedNodeBase subclass that launches a local threaded + HTTP server running at `self.server_location` and connected to + the host of `upload_location`. The requests sent to this server + will be sent to the `upload_location` after adding `extra_headers` + to its headers. 
if `temp_dir` argument was not None, DASH and HLS + manifests will be stored in it before sending them to `upload_location`. + + The local HTTP server at `self.server_location` can only ingest PUT requests. + """ + + def __init__(self, upload_location: str, extra_headers: Dict[str, str], + temp_dir: Optional[str], + periodic_job_wait_time: float = 3600 * 24 * 365.25): + + super().__init__(thread_name=self.__class__.__name__, + continue_on_exception=True, + sleep_time=periodic_job_wait_time) + + self.temp_dir = temp_dir + self.RequestHandlersFactory = RequestHandlersFactory(upload_location, + extra_headers, + self.temp_dir) + + # By specifying port 0, a random unused port will be chosen for the server. + self.server = ThreadingHTTPServer(('localhost', 0), + self.RequestHandlersFactory) + + self.server_location = 'http://' + self.server.server_name + \ + ':' + str(self.server.server_port) + + self.server_thread = threading.Thread(name=self.server_location, + target=self.server.serve_forever) + + def stop(self, status: Optional[ProcessStatus]): + self.server.shutdown() + self.server_thread.join() + return super().stop(status) + + def start(self): + self.server_thread.start() + return super().start() + + def _thread_single_pass(self): + return self.periodic_job() + + def periodic_job(self) -> None: + # Ideally, we will have nothing to do periodically after the wait time + # which is a very long time by default. However, this can be overridden + # by subclasses and populated with calls to all the functions that need + # to be executed periodically. + return + + +class GCSUpload(HTTPUpload): + """The upload node used when PUT requesting to a GCS bucket. + + It will parse the `upload_location` argument with `gs://` protocol + and use the GCP REST API that uses HTTPS protocol instead. + """ + + def __init__(self, upload_location: str, extra_headers: Dict[str, str], + temp_dir: Optional[str]): + upload_location = 'https://storage.googleapis.com/' + upload_location[5:] + + # Normalize the extra headers dictionary. + for key in list(extra_headers.copy()): + extra_headers[key.lower()] = extra_headers.pop(key) + + # We don't have to get a refresh token. Maybe there is an access token + # provided and we won't outlive it anyway, but that's the user's responsibility. + self.refresh_token = extra_headers.pop('refresh-token', None) + self.client_id = extra_headers.pop('client-id', None) + self.client_secret = extra_headers.pop('client-secret', None) + # The access token expires after 3600s in GCS. + refresh_period = int(extra_headers.pop('refresh-every', None) or 3300) + + super().__init__(upload_location, extra_headers, temp_dir, refresh_period) + + # We yet don't have an access token, so we need to get a one. + self._refresh_access_token() + + def _refresh_access_token(self): + if (self.refresh_token is not None + and self.client_id is not None + and self.client_secret is not None): + conn = HTTPSConnection('oauth2.googleapis.com') + req_body = { + 'grant_type': 'refresh_token', + 'refresh_token': self.refresh_token, + 'client_id': self.client_id, + 'client_secret': self.client_secret, + } + conn.request('POST', '/token', json.dumps(req_body)) + res = conn.getresponse() + if res.status == OK: + res_body = json.loads(res.read()) + # Update the Authorization header that the request factory has. + auth = res_body['token_type'] + ' ' + res_body['access_token'] + self.RequestHandlersFactory.update_headers(Authorization=auth) + else: + print("Couldn't refresh access token. 
ErrCode: {}, ErrMsg: {!r}".format(
+            res.status, res.read()))
+    else:
+      print("Insufficient info provided to refresh the access token.")
+      print("To refresh the access token periodically, 'refresh-token', 'client-id'"
+            " and 'client-secret' headers must be provided.")
+      print("After the current access token expires, the upload will fail.")
+
+  def periodic_job(self) -> None:
+    self._refresh_access_token()
+
+
+class S3Upload(HTTPUpload):
+  """The upload node used when PUT requesting to an S3 bucket.
+
+  It will parse the `upload_location` argument with `s3://` protocol
+  and use the AWS REST API that uses HTTPS protocol instead.
+  """
+
+  def __init__(self, upload_location: str, extra_headers: Dict[str, str],
+               temp_dir: Optional[str]):
+    raise NotImplementedError("S3 uploads aren't working yet.")
+    url_parts = upload_location[5:].split('/', 1)
+    bucket = url_parts[0]
+    path = '/' + url_parts[1] if len(url_parts) > 1 else ''
+    upload_location = 'https://' + bucket + '.s3.amazonaws.com' + path
+
+    # We don't have to get a refresh token. Maybe there is an access token
+    # provided and we won't outlive it anyway, but that's the user's responsibility.
+    self.refresh_token = extra_headers.pop('refresh-token', None)
+    self.client_id = extra_headers.pop('client-id', None)
+    # The access token expires after 3600s in S3.
+    refresh_period = int(extra_headers.pop('refresh-every', None) or 3300)
+
+    super().__init__(upload_location, extra_headers, temp_dir, refresh_period)
+
+    # We don't have an access token yet, so we need to get one.
+    self._refresh_access_token()
+
+  def _refresh_access_token(self):
+    if (self.refresh_token is not None and self.client_id is not None):
+      conn = HTTPSConnection('api.amazon.com')
+      req_body = {
+          'grant_type': 'refresh_token',
+          'refresh_token': self.refresh_token,
+          'client_id': self.client_id,
+      }
+      conn.request('POST', '/auth/o2/token', json.dumps(req_body))
+      res = conn.getresponse()
+      if res.status == OK:
+        res_body = json.loads(res.read())
+        # Update the Authorization header that the request factory has.
+        auth = res_body['token_type'] + ' ' + res_body['access_token']
+        self.RequestHandlersFactory.update_headers(Authorization=auth)
+      else:
+        print("Couldn't refresh access token. ErrCode: {}, ErrMsg: {!r}".format(
+            res.status, res.read()))
+    else:
+      print("Insufficient info provided to refresh the access token.")
+      print("To refresh the access token periodically, 'refresh-token'"
+            " and 'client-id' headers must be provided.")
+      print("After the current access token expires, the upload will fail.")
+
+  def periodic_job(self) -> None:
+    self._refresh_access_token()
+
+
+def get_upload_node(upload_location: str, extra_headers: Dict[str, str],
+                    temp_dir: Optional[str] = None) -> HTTPUpload:
+  """Instantiates an appropriate HTTPUpload node based on the protocol
+  used in the `upload_location` url.
+  """
+
+  if upload_location.startswith(("http://", "https://")):
+    return HTTPUpload(upload_location, extra_headers, temp_dir)
+  elif upload_location.startswith("gs://"):
+    return GCSUpload(upload_location, extra_headers, temp_dir)
+  elif upload_location.startswith("s3://"):
+    return S3Upload(upload_location, extra_headers, temp_dir)
+  else:
+    raise RuntimeError("Protocol of {} isn't supported".format(upload_location))
+
+def is_supported_protocol(upload_location: str) -> bool:
+  return any(upload_location.startswith(protocol + '://')
+             for protocol in SUPPORTED_PROTOCOLS)
diff --git a/streamer/util.py b/streamer/util.py
index 7250df9..19690ce 100644
--- a/streamer/util.py
+++ b/streamer/util.py
@@ -14,7 +14,99 @@
 
 """Utility functions used by multiple modules."""
 
+import io
+from typing import Optional
+
+
 def is_url(output_location: str) -> bool:
   """Returns True if the output location is a URL."""
-  return (output_location.startswith('http:') or
-          output_location.startswith('https:'))
+  return output_location.startswith(('http://',
+                                     'https://'))
+
+
+class RequestBodyAsFileIO(io.BufferedIOBase):
+  """A class that provides a layer of access to an HTTP request body. It provides
+  an interface to treat a request body (of type `io.BufferedIOBase`) as a file.
+  Since a request body does not have an `EOF`, this class encapsulates the logic
+  of using the Content-Length or the chunk sizes to provide an emulated `EOF`.
+
+  This implementation is much faster than storing the request body
+  in the filesystem and then reading it back with an `EOF` included.
+  """
+
+  def __init__(self, rfile: io.BufferedIOBase, content_length: Optional[int]):
+    super().__init__()
+    self._body = rfile
+    # Decide whether this is a chunked request or not based on content length.
+    if content_length is not None:
+      self._is_chunked = False
+      self._left_to_read = content_length
+    else:
+      self._is_chunked = True
+      self._last_chunk_read = False
+      self._buffer = b''
+
+  def read(self, blocksize: Optional[int] = None) -> bytes:
+    """This method reads `self._body` incrementally with each call.
+    This is done because calling `read()` directly on `self._body` would wait
+    forever for an `EOF` that will never arrive.
+
+    Like the original `read()`, this method will read up to (but not more than)
+    `blocksize` if it is a non-negative integer, and will read till `EOF` if
+    blocksize is None, a negative integer, or not passed.
+    """
+
+    if self._is_chunked:
+      return self._read_chunked(blocksize)
+    else:
+      return self._read_not_chunked(blocksize)
+
+  def _read_chunked(self, blocksize: Optional[int] = None) -> bytes:
+    """This method provides the read functionality from a request
+    body with chunked Transfer-Encoding.
+    """
+
+    # For non-negative blocksize values.
+    if blocksize and blocksize >= 0:
+      # Keep buffering until we can fulfil the blocksize or there
+      # are no chunks left to buffer.
+      while blocksize > len(self._buffer) and not self._last_chunk_read:
+        byte_chunk_size = self._body.readline()
+        self._buffer += byte_chunk_size
+        int_chunk_size = int(byte_chunk_size.strip(), base=16)
+        self._buffer += self._body.read(int_chunk_size)
+        # Consume the CRLF after each chunk.
+        self._buffer += self._body.readline()
+        if int_chunk_size == 0:
+          # A zero-sized chunk indicates that no more chunks are left.
+          self._last_chunk_read = True
+      bytes_read, self._buffer = self._buffer[:blocksize], self._buffer[blocksize:]
+      return bytes_read
+    # When blocksize is a negative integer or None.
+    else:
+      bytes_read = b''
+      while True:
+        chunk = self._read_chunked(64 * 1024)
+        bytes_read += chunk
+        if chunk == b'':
+          return bytes_read
+
+  def _read_not_chunked(self, blocksize: Optional[int] = None) -> bytes:
+    """This method provides the read functionality from a request
+    body of a known Content-Length.
+    """
+
+    # Don't try to read if there is nothing to read.
+    if self._left_to_read == 0:
+      # This indicates `EOF` for the caller.
+      return b''
+    # For non-negative blocksize values.
+    if blocksize and blocksize >= 0:
+      size_to_read = min(blocksize, self._left_to_read)
+      self._left_to_read -= size_to_read
+      return self._body.read(size_to_read)
+    # When blocksize is a negative integer or None.
+    else:
+      size_to_read, self._left_to_read = self._left_to_read, 0
+      return self._body.read(size_to_read)
+
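A minimal sketch of driving the new proxy node directly, assuming the patched `streamer.proxy_node` module; the destination URL and Authorization value are placeholders, and in practice the controller performs this wiring itself:

# Hypothetical usage sketch (not part of the patch); placeholders throughout.
from streamer import proxy_node

# Extra headers are attached to every forwarded PUT, mirroring the new -H flag.
upload = proxy_node.get_upload_node('http://example.com/live',
                                    {'Authorization': 'Bearer <token>'})
upload.start()

# Downstream nodes write to this local address; each PUT they issue is
# replayed against the destination with the extra headers added.
print(upload.server_location)  # e.g. http://localhost:41231

# ... packaging runs ...

upload.stop(None)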
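For `gs://` outputs, the `-H key=value` flags are how the OAuth refresh material reaches GCSUpload; a sketch of that mapping, with placeholder token values:

# Hypothetical values; the parsing mirrors shaka-streamer's --add-header handling.
add_header_args = [
    'refresh-token=<oauth-refresh-token>',
    'client-id=<oauth-client-id>',
    'client-secret=<oauth-client-secret>',
    'refresh-every=3300',  # optional, seconds between token refreshes
]

extra_headers = {}
for header in add_header_args:
  key, value = header.split('=', 1)
  extra_headers[key] = value

# GCSUpload pops these keys out of the dict (they are never sent as literal HTTP
# headers) and uses them to mint and periodically refresh the Authorization header.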