feat(cloud): Upload through HTTP proxy node
mariocynicys authored and joeyparrish committed Oct 22, 2024

1 parent 0c4b529 commit e6d4db8
Showing 9 changed files with 585 additions and 260 deletions.
23 changes: 12 additions & 11 deletions shaka-streamer
@@ -61,15 +61,16 @@ def main():
'bitrates and resolutions for transcoding. ' +
'(optional, see example in ' +
'config_files/bitrate_config.yaml)')
parser.add_argument('-c', '--cloud-url',
default=None,
help='The Google Cloud Storage or Amazon S3 URL to ' +
'upload to. (Starts with gs:// or s3://)')
parser.add_argument('-o', '--output',
default='output_files',
help='The output folder to write files to, or an HTTP ' +
'or HTTPS URL where files will be PUT. ' +
'Used even if uploading to cloud storage.')
parser.add_argument('-H', '--add-header',
action='append',
metavar='key=value',
help='Headers to include when sending PUT requests ' +
'to the specified output location.')
parser.add_argument('--skip-deps-check',
action='store_true',
help='Skip checks for dependencies and their versions. ' +
@@ -96,14 +97,15 @@ def main():
with open(args.bitrate_config) as f:
bitrate_config_dict = yaml.safe_load(f)

if args.cloud_url:
if (not args.cloud_url.startswith('gs://') and
not args.cloud_url.startswith('s3://')):
parser.error('Invalid cloud URL! Only gs:// and s3:// URLs are supported')
extra_headers = {}
if args.add_header:
for header in args.add_header:
key, value = tuple(header.split('=', 1))
extra_headers[key] = value

try:
with controller.start(args.output, input_config_dict, pipeline_config_dict,
bitrate_config_dict, args.cloud_url,
bitrate_config_dict, extra_headers,
not args.skip_deps_check,
not args.use_system_binaries):
# Sleep so long as the pipeline is still running.
@@ -114,8 +116,7 @@ def main():

time.sleep(1)
except (streamer.controller_node.VersionError,
streamer.configuration.ConfigError,
streamer.cloud_node.CloudAccessError) as e:
streamer.configuration.ConfigError) as e:
# These are common errors meant to give the user specific, helpful
# information. Format these errors in a relatively friendly way, with no
# backtrace or other Python-specific information.
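For reference, a minimal sketch of the new header parsing above (the header strings are hypothetical); splitting on the first `=` only means header values may themselves contain `=`:

```python
# Hedged sketch of the --add-header parsing above; the header values are hypothetical.
add_header = ['Authorization=Bearer abc=123', 'client-id=example-client']

extra_headers = {}
for header in add_header:
    # Split on the first '=' only, so the value may itself contain '='.
    key, value = header.split('=', 1)
    extra_headers[key] = value

print(extra_headers)
# {'Authorization': 'Bearer abc=123', 'client-id': 'example-client'}
```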
190 changes: 0 additions & 190 deletions streamer/cloud_node.py

This file was deleted.

91 changes: 47 additions & 44 deletions streamer/controller_node.py
@@ -32,7 +32,7 @@
from streamer import __version__
from streamer import autodetect
from streamer import min_versions
from streamer.cloud_node import CloudNode
from streamer import proxy_node
from streamer.bitrate_configuration import BitrateConfig, AudioChannelLayout, VideoResolution
from streamer.external_command_node import ExternalCommandNode
from streamer.input_configuration import InputConfig, InputType, MediaType, Input
@@ -42,6 +42,7 @@
from streamer.pipeline_configuration import ManifestFormat, PipelineConfig, StreamingMode
from streamer.transcoder_node import TranscoderNode
from streamer.periodconcat_node import PeriodConcatNode
from streamer.proxy_node import HTTPUpload
import streamer.subprocessWindowsPatch # side-effects only
from streamer.util import is_url
from streamer.pipe import Pipe
@@ -60,6 +61,7 @@ def __init__(self) -> None:
dir=global_temp_dir, prefix='shaka-live-', suffix='')

self._nodes: List[NodeBase] = []
self.upload_proxy: Optional[HTTPUpload] = None

def __del__(self) -> None:
# Clean up named pipes by removing the temp directory we placed them in.
@@ -75,7 +77,7 @@ def start(self, output_location: str,
input_config_dict: Dict[str, Any],
pipeline_config_dict: Dict[str, Any],
bitrate_config_dict: Dict[Any, Any] = {},
bucket_url: Union[str, None] = None,
extra_headers: Dict[str, str] = {},
check_deps: bool = True,
use_hermetic: bool = True) -> 'ControllerNode':
"""Create and start all other nodes.
@@ -147,21 +149,6 @@ def next_short_version(version: str) -> str:
_check_command_version('Shaka Packager', ['packager', '-version'],
min_versions.PACKAGER)

if bucket_url:
# Check that the Google Cloud SDK is at least v212, which introduced
# gsutil 4.33 with an important rsync bug fix.
# https://cloud.google.com/sdk/docs/release-notes
# https://github.com/GoogleCloudPlatform/gsutil/blob/master/CHANGES.md
# This is only required if the user asked for upload to cloud storage.
_check_command_version('Google Cloud SDK', ['gcloud', '--version'],
(212, 0, 0))


if bucket_url:
# If using cloud storage, make sure the user is logged in and can access
# the destination, independent of the version check above.
CloudNode.check_access(bucket_url)

self.hermetic_ffmpeg: Optional[str] = None
self.hermetic_packager: Optional[str] = None
if use_hermetic:
@@ -181,6 +168,16 @@ def next_short_version(version: str) -> str:
self._input_config = InputConfig(input_config_dict)
self._pipeline_config = PipelineConfig(pipeline_config_dict)

# Note that we remove the trailing slash from the output location, because
# otherwise GCS would create a subdirectory whose name is "".
output_location = output_location.rstrip('/')
if (proxy_node.is_supported_protocol(output_location)
and self._pipeline_config.use_local_proxy):
self.upload_proxy = self.get_upload_node(output_location, extra_headers)
# All outputs should now be sent to the proxy server instead.
output_location = self.upload_proxy.server_location
self._nodes.append(self.upload_proxy)

if not is_url(output_location):
# Check if the directory for outputted Packager files exists, and if it
# does, delete it and remake a new one.
@@ -194,15 +191,6 @@ def next_short_version(version: str) -> str:
'For HTTP PUT uploads, the pipeline segment_per_file setting ' +
'must be set to True!')

if bucket_url:
raise RuntimeError(
'Cloud bucket upload is incompatible with HTTP PUT support.')

if self._input_config.multiperiod_inputs_list:
# TODO: Edit Multiperiod input list implementation to support HTTP outputs
raise RuntimeError(
'Multiperiod input list support is incompatible with HTTP outputs.')

if self._pipeline_config.low_latency_dash_mode:
# Check some restrictions on LL-DASH packaging.
if ManifestFormat.DASH not in self._pipeline_config.manifest_format:
@@ -214,10 +202,6 @@ def next_short_version(version: str) -> str:
raise RuntimeError(
'For low_latency_dash_mode, the utc_timings must be set.')

# Note that we remove the trailing slash from the output location, because
# otherwise GCS would create a subdirectory whose name is "".
output_location = output_location.rstrip('/')

if self._input_config.inputs:
# InputConfig contains inputs only.
self._append_nodes_for_inputs_list(self._input_config.inputs,
@@ -236,18 +220,8 @@ def next_short_version(version: str) -> str:
self._nodes.append(PeriodConcatNode(
self._pipeline_config,
packager_nodes,
output_location))

if bucket_url:
cloud_temp_dir = os.path.join(self._temp_dir, 'cloud')
os.mkdir(cloud_temp_dir)

packager_nodes = [node for node in self._nodes if isinstance(node, PackagerNode)]
self._nodes.append(CloudNode(output_location,
bucket_url,
cloud_temp_dir,
packager_nodes,
self.is_vod()))
output_location,
self.upload_proxy))

for node in self._nodes:
node.start()
@@ -335,7 +309,8 @@ def _append_nodes_for_inputs_list(self, inputs: List[Input],
# and put that period in it.
if period_dir:
output_location = os.path.join(output_location, period_dir)
os.mkdir(output_location)
if not is_url(output_location):
os.mkdir(output_location)

self._nodes.append(PackagerNode(self._pipeline_config,
output_location,
@@ -349,11 +324,15 @@ def check_status(self) -> ProcessStatus:
If one node is errored, this returns Errored; otherwise if one node is running,
this returns Running; this only returns Finished if all nodes are finished.
If there are no nodes, this returns Finished.
:rtype: ProcessStatus
"""
if not self._nodes:
return ProcessStatus.Finished

value = max(node.check_status().value for node in self._nodes)
value = max(node.check_status().value for node in self._nodes
# We don't check the status of the upload proxy node.
if node != self.upload_proxy)
return ProcessStatus(value)

def stop(self) -> None:
@@ -379,6 +358,30 @@ def is_low_latency_dash_mode(self) -> bool:

return self._pipeline_config.low_latency_dash_mode

def get_upload_node(self, upload_location: str,
extra_headers: Dict[str, str]) -> HTTPUpload:
"""
Args:
upload_location (str): The location where media content will be uploaded.
extra_headers (Dict[str, str]): Extra headers to be added when
sending the PUT request to `upload_location`.
:rtype: HTTPUpload
:raises: `RuntimeError` if the protocol used in `upload_location` was not
recognized.
"""

# We need to pass a temporary directory when working with multi-period input
# and using HTTP PUT for uploading, so that the HTTPUpload node keeps a copy
# of the manifests there for assembling the multi-period manifests later.
upload_temp_dir = None
if self._input_config.multiperiod_inputs_list:
upload_temp_dir = os.path.join(self._temp_dir, 'multiperiod_manifests')
os.mkdir(upload_temp_dir)
return proxy_node.get_upload_node(upload_location, extra_headers,
upload_temp_dir)

class VersionError(Exception):
"""A version error for one of Shaka Streamer's external dependencies.
5 changes: 4 additions & 1 deletion streamer/node_base.py
@@ -166,6 +166,7 @@ def __init__(self, thread_name: str, continue_on_exception: bool, sleep_time: fl
self._thread_name = thread_name
self._continue_on_exception = continue_on_exception
self._sleep_time = sleep_time
self._sleep_waker_event = threading.Event()
self._thread = threading.Thread(target=self._thread_main, name=thread_name)

def _thread_main(self) -> None:
@@ -183,7 +184,7 @@ def _thread_main(self) -> None:
return

# Wait a little bit before performing the next pass.
time.sleep(self._sleep_time)
self._sleep_waker_event.wait(self._sleep_time)

@abc.abstractmethod
def _thread_single_pass(self) -> None:
@@ -204,6 +205,8 @@ def start(self) -> None:

def stop(self, status: Optional[ProcessStatus]) -> None:
self._status = ProcessStatus.Finished
# If the thread was sleeping, wake it up.
self._sleep_waker_event.set()
self._thread.join()

def check_status(self) -> ProcessStatus:
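The `node_base.py` change swaps `time.sleep()` for `threading.Event.wait()` so that `stop()` can interrupt a sleeping worker immediately instead of waiting out the full interval. A standalone illustration of the pattern:

```python
import threading
import time

waker = threading.Event()

def worker():
    while not waker.is_set():
        print('pass at', time.strftime('%H:%M:%S'))
        # Returns as soon as waker.set() is called, unlike time.sleep(3600).
        waker.wait(3600)

t = threading.Thread(target=worker)
t.start()
time.sleep(1)
waker.set()   # Wakes the worker immediately, so join() returns without delay.
t.join()
```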
4 changes: 2 additions & 2 deletions streamer/packager_node.py
@@ -183,7 +183,7 @@ def _setup_manifest_format(self) -> List[str]:
args += [
# Generate DASH manifest file.
'--mpd_output',
os.path.join(self.output_location, self._pipeline_config.dash_output),
build_path(self.output_location, self._pipeline_config.dash_output),
]
if ManifestFormat.HLS in self._pipeline_config.manifest_format:
if self._pipeline_config.streaming_mode == StreamingMode.LIVE:
@@ -197,7 +197,7 @@ def _setup_manifest_format(self) -> List[str]:
args += [
# Generate HLS playlist file(s).
'--hls_master_playlist_output',
os.path.join(self.output_location, self._pipeline_config.hls_output),
build_path(self.output_location, self._pipeline_config.hls_output),
]
return args
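`build_path` itself is not shown in this diff; presumably it joins with `/` for URL output locations (where `os.path.join` would insert `\` on Windows) and falls back to `os.path.join` for local paths. A hedged sketch of what such a helper might look like:

```python
import os
import posixpath

from streamer.util import is_url  # shown later in this diff

def build_path(output_location: str, *components: str) -> str:
    """Hypothetical sketch: join with '/' for URLs, with os.sep for local paths."""
    if is_url(output_location):
        return posixpath.join(output_location, *components)
    return os.path.join(output_location, *components)
```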

66 changes: 56 additions & 10 deletions streamer/periodconcat_node.py
@@ -17,14 +17,17 @@
import os
import re
import time
from typing import List
from typing import List, Optional
from xml.etree import ElementTree
from http.client import HTTPConnection, CREATED
from streamer import __version__
from streamer.node_base import ProcessStatus, ThreadedNodeBase
from streamer.packager_node import PackagerNode
from streamer.pipeline_configuration import PipelineConfig, ManifestFormat
from streamer.output_stream import AudioOutputStream, VideoOutputStream
from streamer.m3u8_concater import HLSConcater
from streamer.proxy_node import HTTPUpload
from streamer.util import is_url


class PeriodConcatNode(ThreadedNodeBase):
@@ -35,17 +38,20 @@ class PeriodConcatNode(ThreadedNodeBase):
def __init__(self,
pipeline_config: PipelineConfig,
packager_nodes: List[PackagerNode],
output_location: str) -> None:
output_location: str,
upload_proxy: Optional[HTTPUpload]) -> None:
"""Stores all relevant information needed for the period concatenation."""
super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=3)
super().__init__(thread_name='periodconcat', continue_on_exception=False, sleep_time=1)
self._pipeline_config = pipeline_config
self._output_location = output_location
self._packager_nodes: List[PackagerNode] = packager_nodes
self._concat_will_fail = False

self._proxy_node = upload_proxy
self._concat_will_fail = self._check_failed_concatenation()

def _check_failed_concatenation(self) -> bool:
# Determine whether the first period has video and audio or not.
fp_has_vid, fp_has_aud = False, False
for output_stream in packager_nodes[0].output_streams:
for output_stream in self._packager_nodes[0].output_streams:
if isinstance(output_stream, VideoOutputStream):
fp_has_vid = True
elif isinstance(output_stream, AudioOutputStream):
@@ -59,7 +65,6 @@ def __init__(self,
elif isinstance(output_stream, AudioOutputStream):
has_aud = True
if has_vid != fp_has_vid or has_aud != fp_has_aud:
self._concat_will_fail = True
print("\nWARNING: Stopping period concatenation.")
print("Period#{} has {}video and has {}audio while Period#1 "
"has {}video and has {}audio.".format(i + 1,
@@ -72,8 +77,21 @@ def __init__(self,
"\tperiods with other periods that have video.\n"
"\tThis is necessary for the concatenation to be performed successfully.\n")
time.sleep(5)
break

return True

if self._proxy_node is None and is_url(self._output_location):
print("\nWARNING: Stopping period concatenation.")
print("Shaka Packager is using HTTP PUT but not using"
" Shaka Streamer's upload proxy.")
print("\nHINT:\n\tShaka Streamer's upload proxy stores the manifest files\n"
"\ttemporarily in the local filesystem to use them for period concatenation.\n"
"\tSet use_local_proxy to True in the pipeline config to enable the"
" upload proxy.\n")
time.sleep(5)
return True
# Otherwise, we don't have a reason to fail.
return False

def _thread_single_pass(self) -> None:
"""Watches all the PackagerNode(s), if at least one of them is running it skips this
_thread_single_pass, if all of them are finished, it starts period concatenation, if one of
@@ -90,13 +108,41 @@ def _thread_single_pass(self) -> None:
'to an error in PackagerNode#{}.'.format(i + 1))

if self._concat_will_fail:
raise RuntimeError('Unable to concatenate the inputs.')
raise RuntimeError('Unable to concatenate the inputs')

# If the packager was pushing HTTP requests to the Streamer's proxy server,
# the proxy server should have stored the manifest files in a temporary
# directory in the filesystem.
if self._proxy_node is not None:
assert self._proxy_node.temp_dir, ('There should be a proxy temp directory'
' when processing multi-period input')
self._output_location = self._proxy_node.temp_dir
# As the period concatenator node is the last to run, changing the
# output location at run time won't disturb any other node.
for packager_node in self._packager_nodes:
packager_node.output_location = packager_node.output_location.replace(
self._proxy_node.server_location,
self._proxy_node.temp_dir, 1)

if ManifestFormat.DASH in self._pipeline_config.manifest_format:
self._dash_concat()

if ManifestFormat.HLS in self._pipeline_config.manifest_format:
self._hls_concat()

# Push the concatenated manifests if a proxy is used.
if self._proxy_node is not None:
conn = HTTPConnection(self._proxy_node.server.server_name,
self._proxy_node.server.server_port)
# The concatenated manifest files were written to `self._output_location`.
for manifest_file_name in os.listdir(self._output_location):
if manifest_file_name.endswith(('.mpd', '.m3u8')):
manifest_file_path = os.path.join(self._output_location, manifest_file_name)
conn.request('PUT', '/' + manifest_file_name, open(manifest_file_path, 'r'))
res = conn.getresponse()
if res.status != CREATED:
print("Got unexpected status code: {}, Msg: {!r}".format(res.status,
res.read()))

self._status = ProcessStatus.Finished

5 changes: 5 additions & 0 deletions streamer/pipeline_configuration.py
@@ -320,6 +320,11 @@ class PipelineConfig(configuration.Base):
default=EncryptionConfig({})).cast()
"""Encryption settings."""

use_local_proxy = configuration.Field(bool, default=True).cast()
"""Whether to use shaka streamer's local proxy when uploading to a remote
storage. This must be set to True when uploading to GCS or amazon S3 buckets.
"""

# TODO: Generalize this to low_latency_mode once LL-HLS is supported by Packager
low_latency_dash_mode = configuration.Field(bool, default=False).cast()
"""If true, stream in low latency mode for DASH."""
365 changes: 365 additions & 0 deletions streamer/proxy_node.py
@@ -0,0 +1,365 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A module that implements a simple proxy server for uploading packaged
content to cloud storage providers (GCS and S3), as well as any server that
accepts PUT requests.
"""

import os
import json
import threading
import urllib.parse
from typing import Optional, Union, Dict
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from http.client import HTTPConnection, HTTPSConnection, CREATED, OK

from streamer.node_base import ProcessStatus, ThreadedNodeBase
from streamer.util import RequestBodyAsFileIO


# Protocols for which we have handler classes.
SUPPORTED_PROTOCOLS = ['http', 'https', 'gs', 's3']


class RequestHandler(BaseHTTPRequestHandler):
"""A request handler that processes the PUT requests coming from
Shaka Packager and pushes them to the destination.
"""

def __init__(self, conn: Union[HTTPConnection, HTTPSConnection],
extra_headers: Dict[str, str], base_path: str, param_query: str,
temp_dir: Optional[str], *args, **kwargs):

self._conn = conn
# Extra headers to add when sending the request to the host
# using `self._conn`.
self._extra_headers = extra_headers
# The base path that will be prepended to the path of the handled request
# before forwarding the request to the host using `self._conn`.
self._base_path = base_path
# Parameters and query string to send in the url with each forwarded request.
self._params_and_querystring = param_query
self._temp_dir = temp_dir
# Call `super().__init__()` last because this call is what handles the
# actual request and we need the variables defined above to handle
# this request.
super().__init__(*args, **kwargs)

def do_PUT(self):
"""do_PUT will handle the PUT requests coming from shaka packager."""

headers = {}
# Use the same headers for requesting.
for k, v in self.headers.items():
if k.lower() != 'host':
headers[k] = v
# Add the extra headers; these might contain an access token, for instance.
for k, v in self._extra_headers.items():
headers[k] = v

# Don't chunk by default: the request body is either already chunked, or it
# has a Content-Length header, which means chunked transfer-encoding is not
# being used.
encode_chunked = False
content_length = self.headers['Content-Length']
# Store the manifest files locally in `self._temp_dir`.
if self._temp_dir is not None and self.path.endswith(('.mpd', '.m3u8')):
body = self._write_body_and_get_file_io()
if content_length is None:
# We need to re-chunk it again as we wrote it as a whole
# to the filesystem.
encode_chunked = True
else:
content_length = content_length and int(content_length)
body = RequestBodyAsFileIO(self.rfile, content_length)

# The URL is the base path we should send the request to, joined with
# the path this request came in on.
url = self._base_path + self.path
# Also include any parameters and the query string.
url += self._params_and_querystring

self._conn.request('PUT', url, body, headers, encode_chunked=encode_chunked)
res = self._conn.getresponse()
# Disable response logging.
self.log_request = lambda _: None
# Respond to Shaka Packager with the response we got.
self.send_response(res.status)
self.end_headers()
# self.wfile.write(res.read())
# The destination should send (201/CREATED), but some also send (200/OK).
if res.status != CREATED and res.status != OK:
print('Unexpected status for the PUT request:'
' {}, ErrMsg: {!r}'.format(res.status, res.read()))

def _write_body_and_get_file_io(self):
"""A method that writes a request body to the filesystem
and returns a file io object opened for reading.
"""

# Store the request body in `self._temp_dir`.
# Drop the leading '/' from `self.path`, as os.path.join would otherwise
# treat it as an absolute path and discard `self._temp_dir`.
path = os.path.join(self._temp_dir, self.path[1:])
# With `exist_ok=True`, any intermediate directories are created if needed.
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'wb') as request_body_file:
if self.headers['Content-Length'] is not None:
content_length = int(self.headers['Content-Length'])
request_body_file.write(self.rfile.read(content_length))
else:
while True:
bytes_chunk_size = self.rfile.readline()
int_chunk_size = int(bytes_chunk_size.strip(), base=16)
request_body_file.write(self.rfile.read(int_chunk_size))
# Consume the CRLF that follows each chunk.
self.rfile.readline()
# Chunk of size zero indicates that we have reached the end.
if int_chunk_size == 0:
break
return open(path, 'rb')
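To make the chunked-transfer loop above concrete, here is a standalone illustration of the wire format it decodes: each chunk is a hexadecimal size line, the chunk bytes, and a trailing CRLF, terminated by a zero-size chunk.

```python
import io

# Example chunked request body: 'Hello, ' (7 bytes) + 'world!' (6 bytes) + end marker.
rfile = io.BytesIO(b'7\r\nHello, \r\n6\r\nworld!\r\n0\r\n\r\n')

decoded = b''
while True:
    int_chunk_size = int(rfile.readline().strip(), base=16)
    decoded += rfile.read(int_chunk_size)
    rfile.readline()              # consume the CRLF that follows each chunk
    if int_chunk_size == 0:       # a zero-size chunk marks the end of the body
        break

print(decoded)  # b'Hello, world!'
```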


class RequestHandlersFactory():
"""A request handlers' factory that produces a RequestHandler whenever
its __call__ method is called. It stores all the relevant data that the
instantiated request handler will need when sending a request to the host.
"""

def __init__(self, upload_location: str, initial_headers: Dict[str, str] = {},
temp_dir: Optional[str] = None, max_conns: int = 50):

url = urllib.parse.urlparse(upload_location)
if url.scheme not in ['http', 'https']:
# We can only instantiate HTTP/HTTPS connections.
raise RuntimeError("Unsupported scheme: {}", url.scheme)
self._ConnectionFactory = HTTPConnection if url.scheme == 'http' \
else HTTPSConnection
self._destination_host = url.netloc
# Store the url path to prepend it to the path of each handled
# request before forwarding the request to `self._destination_host`.
self._base_path = url.path
# Store the parameters and the query string to send them in
# any request going to `self._destination_host`.
self._params_query = ';' + url.params if url.params else ''
self._params_query += '?' + url.query if url.query else ''
# These headers are going to be sent to `self._destination_host`
# with each request along with the headers that the request handler
# receives. Note that these extra headers can possibly overwrite
# the original request headers that the request handler received.
self._extra_headers = initial_headers
self._temp_dir = temp_dir
self._max_conns = max_conns

def __call__(self, *args, **kwargs) -> RequestHandler:
"""This magical method makes a RequestHandlersFactory instance
callable and returns a RequestHandler when called.
"""

connection = self._ConnectionFactory(self._destination_host)
return RequestHandler(connection, self._extra_headers,
self._base_path, self._params_query, self._temp_dir,
*args, **kwargs)

def update_headers(self, **kwargs):
self._extra_headers.update(**kwargs)
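`ThreadingHTTPServer` constructs a new handler for every incoming request by calling whatever callable it was given, so per-server state has to be bound into that callable ahead of time. The factory class above does this explicitly; below is a hedged toy sketch of the same pattern using `functools.partial` (the handler is illustrative, not part of this commit):

```python
import functools
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

class EchoHandler(BaseHTTPRequestHandler):
    # Toy handler used only to illustrate the pattern; not part of the diff.
    def __init__(self, label: str, *args, **kwargs):
        self._label = label
        super().__init__(*args, **kwargs)  # this call handles the request, so bind state first

    def do_PUT(self):
        self.send_response(201)
        self.end_headers()

# The server calls the factory once per request; partial pre-binds the extra state.
server = ThreadingHTTPServer(('localhost', 0), functools.partial(EchoHandler, 'proxy'))
```

One reason to prefer the explicit factory class here is that its `__call__` opens a fresh `HTTPConnection` per request, so concurrently running handler threads never share a connection.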


class HTTPUpload(ThreadedNodeBase):
"""A ThreadedNodeBase subclass that launches a local threaded
HTTP server running at `self.server_location` and connected to
the host of `upload_location`. The requests sent to this server
will be forwarded to `upload_location` after adding `extra_headers`
to their headers. If the `temp_dir` argument is not None, DASH and HLS
manifests are also stored there before being sent to `upload_location`.
The local HTTP server at `self.server_location` can only ingest PUT requests.
"""

def __init__(self, upload_location: str, extra_headers: Dict[str, str],
temp_dir: Optional[str],
periodic_job_wait_time: float = 3600 * 24 * 365.25):

super().__init__(thread_name=self.__class__.__name__,
continue_on_exception=True,
sleep_time=periodic_job_wait_time)

self.temp_dir = temp_dir
self.RequestHandlersFactory = RequestHandlersFactory(upload_location,
extra_headers,
self.temp_dir)

# By specifying port 0, a random unused port will be chosen for the server.
self.server = ThreadingHTTPServer(('localhost', 0),
self.RequestHandlersFactory)

self.server_location = 'http://' + self.server.server_name + \
':' + str(self.server.server_port)

self.server_thread = threading.Thread(name=self.server_location,
target=self.server.serve_forever)

def stop(self, status: Optional[ProcessStatus]):
self.server.shutdown()
self.server_thread.join()
return super().stop(status)

def start(self):
self.server_thread.start()
return super().start()

def _thread_single_pass(self):
return self.periodic_job()

def periodic_job(self) -> None:
# Ideally, we will have nothing to do periodically after the wait time
# which is a very long time by default. However, this can be overridden
# by subclasses and populated with calls to all the functions that need
# to be executed periodically.
return
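A hedged end-to-end sketch of the plain HTTP case: start an `HTTPUpload` pointed at a destination, PUT a file at the local proxy address, and stop it. The destination URL and the uploaded path are hypothetical.

```python
from http.client import HTTPConnection
from urllib.parse import urlparse

from streamer.proxy_node import HTTPUpload

# Hypothetical destination that accepts PUT requests.
proxy = HTTPUpload('http://example.com/ingest', extra_headers={}, temp_dir=None)
proxy.start()

# Anything PUT to proxy.server_location is forwarded to the destination.
url = urlparse(proxy.server_location)
conn = HTTPConnection(url.hostname, url.port)
conn.request('PUT', '/video/segment_0001.m4s', b'<segment bytes>')
print(conn.getresponse().status)   # 201 if the destination accepted the upload

proxy.stop(None)
```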


class GCSUpload(HTTPUpload):
"""The upload node used when PUT requesting to a GCS bucket.
It will parse the `upload_location` argument with `gs://` protocol
and use the GCP REST API that uses HTTPS protocol instead.
"""

def __init__(self, upload_location: str, extra_headers: Dict[str, str],
temp_dir: Optional[str]):
upload_location = 'https://storage.googleapis.com/' + upload_location[5:]

# Normalize the extra headers dictionary.
for key in list(extra_headers.copy()):
extra_headers[key.lower()] = extra_headers.pop(key)

# We don't require a refresh token. Maybe an access token was provided
# and we won't outlive it anyway, but that's the user's responsibility.
self.refresh_token = extra_headers.pop('refresh-token', None)
self.client_id = extra_headers.pop('client-id', None)
self.client_secret = extra_headers.pop('client-secret', None)
# The access token expires after 3600s in GCS.
refresh_period = int(extra_headers.pop('refresh-every', None) or 3300)

super().__init__(upload_location, extra_headers, temp_dir, refresh_period)

# We don't have an access token yet, so we need to get one.
self._refresh_access_token()

def _refresh_access_token(self):
if (self.refresh_token is not None
and self.client_id is not None
and self.client_secret is not None):
conn = HTTPSConnection('oauth2.googleapis.com')
req_body = {
'grant_type': 'refresh_token',
'refresh_token': self.refresh_token,
'client_id': self.client_id,
'client_secret': self.client_secret,
}
conn.request('POST', '/token', json.dumps(req_body))
res = conn.getresponse()
if res.status == OK:
res_body = json.loads(res.read())
# Update the Authorization header that the request factory has.
auth = res_body['token_type'] + ' ' + res_body['access_token']
self.RequestHandlersFactory.update_headers(Authorization=auth)
else:
print("Couldn't refresh access token. ErrCode: {}, ErrMst: {!r}".format(
res.status, res.read()))
else:
print("Non sufficient info provided to refresh the access token.")
print("To refresh access token periodically, 'refresh-token', 'client-id'"
" and 'client-secret' headers must be provided.")
print("After the current access token expires, the upload will fail.")

def periodic_job(self) -> None:
self._refresh_access_token()
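In practice the OAuth2 material reaches `GCSUpload` through the same `-H`/`--add-header` flag as any other header; the node pops the credential keys out before forwarding requests. A hedged sketch with placeholder credentials:

```python
from streamer.proxy_node import get_upload_node

# Placeholder OAuth2 credentials; real values come from a Google Cloud OAuth client.
extra_headers = {
    'client-id': '<client-id>.apps.googleusercontent.com',
    'client-secret': '<client-secret>',
    'refresh-token': '<refresh-token>',
    'refresh-every': '3300',   # seconds between token refreshes (defaults to 3300)
}

# These keys are consumed by GCSUpload itself; everything else is forwarded
# verbatim with each PUT to storage.googleapis.com.
proxy = get_upload_node('gs://my-bucket/output', extra_headers)
```

On the command line, these correspond to repeated `-H key=value` flags.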


class S3Upload(HTTPUpload):
"""The upload node used when PUT requesting to a S3 bucket.
It will parse the `upload_location` argument with `s3://` protocol
and use the AWS REST API that uses HTTPS protocol instead.
"""

def __init__(self, upload_location: str, extra_headers: Dict[str, str],
temp_dir: Optional[str]):
raise NotImplementedError("S3 uploads aren't working yet.")
url_parts = upload_location[5:].split('/', 1)
bucket = url_parts[0]
path = '/' + url_parts[1] if len(url_parts) > 1 else ''
upload_location = 'https://' + bucket + '.s3.amazonaws.com' + path

# We don't require a refresh token. Maybe an access token was provided
# and we won't outlive it anyway, but that's the user's responsibility.
self.refresh_token = extra_headers.pop('refresh-token', None)
self.client_id = extra_headers.pop('client-id', None)
# The access token expires after 3600s in S3.
refresh_period = int(extra_headers.pop('refresh-every', None) or 3300)

super().__init__(upload_location, extra_headers, temp_dir, refresh_period)

# We don't have an access token yet, so we need to get one.
self._refresh_access_token()

def _refresh_access_token(self):
if (self.refresh_token is not None and self.client_id is not None):
conn = HTTPSConnection('api.amazon.com')
req_body = {
'grant_type': 'refresh_token',
'refresh_token': self.refresh_token,
'client_id': self.client_id,
}
conn.request('POST', '/auth/o2/token', json.dumps(req_body))
res = conn.getresponse()
if res.status == OK:
res_body = json.loads(res.read())
# Update the Authorization header that the request factory has.
auth = res_body['token_type'] + ' ' + res_body['access_token']
self.RequestHandlersFactory.update_headers(Authorization=auth)
else:
print("Couldn't refresh access token. ErrCode: {}, ErrMst: {!r}".format(
res.status, res.read()))
else:
print("Non sufficient info provided to refresh the access token.")
print("To refresh access token periodically, 'refresh-token'"
" and 'client-id' headers must be provided.")
print("After the current access token expires, the upload will fail.")

def periodic_job(self) -> None:
self._refresh_access_token()


def get_upload_node(upload_location: str, extra_headers: Dict[str, str],
temp_dir: Optional[str] = None) -> HTTPUpload:
"""Instantiates an appropriate HTTPUpload node based on the protocol
used in the `upload_location` URL.
"""

if upload_location.startswith(("http://", "https://")):
return HTTPUpload(upload_location, extra_headers, temp_dir)
elif upload_location.startswith("gs://"):
return GCSUpload(upload_location, extra_headers, temp_dir)
elif upload_location.startswith("s3://"):
return S3Upload(upload_location, extra_headers, temp_dir)
else:
raise RuntimeError("Protocol of {} isn't supported".format(upload_location))

def is_supported_protocol(upload_location: str) -> bool:
return any(upload_location.startswith(protocol + '://')
for protocol in SUPPORTED_PROTOCOLS)
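A small usage sketch of the dispatch above (the locations are hypothetical):

```python
from streamer import proxy_node

for location in ('https://cdn.example.com/live', 'gs://my-bucket/vod', 'ftp://host/path'):
    print(location, proxy_node.is_supported_protocol(location))
# https://cdn.example.com/live True
# gs://my-bucket/vod True
# ftp://host/path False

# get_upload_node() would return HTTPUpload, GCSUpload, or S3Upload respectively,
# and raise RuntimeError for the unsupported ftp:// location.
```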
96 changes: 94 additions & 2 deletions streamer/util.py
@@ -14,7 +14,99 @@

"""Utility functions used by multiple modules."""

import io
from typing import Optional


def is_url(output_location: str) -> bool:
"""Returns True if the output location is a URL."""
return (output_location.startswith('http:') or
output_location.startswith('https:'))
return output_location.startswith(('http://',
'https://'))


class RequestBodyAsFileIO(io.BufferedIOBase):
"""A class that provides a layer of access to an HTTP request body. It provides
an interface to treat a request body (of type `io.BufferedIOBase`) as a file.
Since a request body does not have an `EOF`, this class encapsulates the
logic of using Content-Length or chunk size to provide an emulated `EOF`.
This implementation is much faster than storing the request body
in the filesystem and then reading it back with an `EOF` included.
"""

def __init__(self, rfile: io.BufferedIOBase, content_length: Optional[int]):
super().__init__()
self._body = rfile
# Decide whether this is a chunked request or not based on content length.
if content_length is not None:
self._is_chunked = False
self._left_to_read = content_length
else:
self._is_chunked = True
self._last_chunk_read = False
self._buffer = b''

def read(self, blocksize: Optional[int] = None) -> bytes:
"""This method reads `self.body` incrementally with each call.
This is done because if we try to use `read()` on `self._body` it will wait
forever for an `EOF` which is not present and will never be.
This method -like the original `read()`- will read up to (but not more than)
`blocksize` if it is a non-negative integer, and will read till `EOF` if
blocksize is None, a negative integer, or not passed.
"""

if self._is_chunked:
return self._read_chunked(blocksize)
else:
return self._read_not_chunked(blocksize)

def _read_chunked(self, blocksize: Optional[int] = None) -> bytes:
"""This method provides the read functionality from a request
body with chunked Transfer-Encoding.
"""

# For non-negative blocksize values.
if blocksize and blocksize >= 0:
# Keep buffering until we can fulfil the blocksize or there
# are no chunks left to buffer.
while blocksize > len(self._buffer) and not self._last_chunk_read:
byte_chunk_size = self._body.readline()
self._buffer += byte_chunk_size
int_chunk_size = int(byte_chunk_size.strip(), base=16)
self._buffer += self._body.read(int_chunk_size)
# Consume the CRLF after each chunk.
self._buffer += self._body.readline()
if int_chunk_size == 0:
# A zero-sized chunk indicates that no more chunks are left.
self._last_chunk_read = True
bytes_read, self._buffer = self._buffer[:blocksize], self._buffer[blocksize:]
return bytes_read
# When blocksize is a negative integer or None.
else:
bytes_read = b''
while True:
chunk = self._read_chunked(64 * 1024)
bytes_read += chunk
if chunk == b'':
return bytes_read

def _read_not_chunked(self, blocksize: Optional[int] = None) -> bytes:
"""This method provides the read functionality from a request
body of a known Content-Length.
"""

# Don't try to read if there is nothing to read.
if self._left_to_read == 0:
# This indicates `EOF` for the caller.
return b''
# For non-negative blocksize values.
if blocksize and blocksize >= 0:
size_to_read = min(blocksize, self._left_to_read)
self._left_to_read -= size_to_read
return self._body.read(size_to_read)
# When blocksize is a negative integer or None.
else:
size_to_read, self._left_to_read = self._left_to_read, 0
return self._body.read(size_to_read)
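A minimal illustration of `RequestBodyAsFileIO` on a fixed-length body, using an in-memory buffer as a stand-in for a request's `rfile`: the wrapper reports EOF once `Content-Length` bytes have been consumed, leaving any following bytes untouched.

```python
import io

from streamer.util import RequestBodyAsFileIO

# Stand-in for an HTTP request's rfile: a 10-byte body followed by bytes that
# belong to the next pipelined request and must not be consumed.
rfile = io.BytesIO(b'0123456789NEXT-REQUEST')

body = RequestBodyAsFileIO(rfile, content_length=10)
print(body.read(4))   # b'0123'
print(body.read())    # b'456789' -- the rest of the body
print(body.read())    # b'' -- emulated EOF; the next request's bytes remain in rfile
```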
