Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pg ssh normalization squashed #5897

Merged
merged 7 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Developing an SSH Source
# Developing an SSH Connector

## Goal
Easy development of any source that needs the ability to connect to a resource via SSH Tunnel.
Easy development of any connector that needs the ability to connect to a resource via SSH Tunnel.

## Overview
Our SSH connector support is designed to be easy to plug into any existing connector. There are a few major pieces to consider:
1. Add SSH Configuration to the Spec - for SSH, we need to take in additional configuration, so we need to inject extra fields into the connector configuration.
2. Add SSH Logic to the Connector - before the connector code begins to execute we need to start an SSH tunnel. This library provides logic to create that tunnel (and clean it up).
3. Acceptance Testing - it is a good practice to include acceptance testing for the SSH version of a connector for at least one of the SSH types (password or ssh key). While unit testing for the SSH functionality exists in this package (coming soon), high-level acceptance testing to make sure this feature works with the individual connector belongs in the connector.
4. Normalization Support for Destinations - if the connector is a destination and supports normalization, there's a small change required in the normalization code to update the config so that dbt uses the right credentials for the SSH tunnel.

## How To

Expand All @@ -21,6 +22,15 @@ Our SSH connector support is designed to be easy to plug into any existing conne
### Acceptance Testing
1. The only difference between existing acceptance testing and acceptance testing with SSH is that the configuration that is used for testing needs to contain additional fields. You can see the `Postgres Source ssh key creds` in lastpass to see an example of what that might look like. Those credentials leverage an existing bastion host in our test infrastructure. (As future work, we want to get rid of the need to use a static bastion server and instead do it in docker so we can run it all locally.)

### Normalization Support for Destinations
1. The core functionality for ssh tunnelling with normalization is already in place but you'll need to add a small tweak to `transform_config/transform.py` in the normalization module. Find the function `transform_{connector}()` and add at the start:
```
if TransformConfig.is_ssh_tunnelling(config):
config = TransformConfig.get_ssh_altered_config(config, port_key="port", host_key="host")
```
Replace port_key and host_key as necessary. Look at `transform_postgres()` to see an example.
2. If your `host_key="host"` and `port_key="port"` then step 1 should be sufficient. However if the key names differ for your connector, you will also need to add some logic into `sshtunneling.sh` (within airbyte-workers) to handle this, as currently it assumes that the keys are exactly `host` and `port`.

## Misc

### How to wrap the protocol in an SSH Tunnel
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
*
!Dockerfile
!entrypoint.sh
!build/sshtunneling.sh
!setup.py
!normalization
!dbt-project-template
3 changes: 3 additions & 0 deletions airbyte-integrations/bases/base-normalization/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ RUN pip install cx_Oracle

COPY --from=airbyte/base-airbyte-protocol-python:0.1.1 /airbyte /airbyte

RUN apt-get update && apt-get install -y jq sshpass

WORKDIR /airbyte
COPY entrypoint.sh .
COPY build/sshtunneling.sh .

WORKDIR /airbyte/normalization_code
COPY normalization ./normalization
Expand Down
19 changes: 19 additions & 0 deletions airbyte-integrations/bases/base-normalization/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,28 @@ airbytePython {
}

dependencies {
implementation project(':airbyte-workers')
implementation files(project(':airbyte-integrations:bases:airbyte-protocol').airbyteDocker.outputs)
}

// we need to access the sshtunneling script from airbyte-workers for ssh support
task copySshScript(type: Copy, dependsOn: [project(':airbyte-workers').processResources]) {
from "${project(':airbyte-workers').buildDir}/resources/main"
into "${buildDir}"
include "sshtunneling.sh"
}

// make sure the copy task above worked (if it fails, it fails silently annoyingly)
task checkSshScriptCopy(type: Task, dependsOn: copySshScript) {
doFirst {
assert file("${buildDir}/sshtunneling.sh").exists() :
"Copy of sshtunneling.sh failed, check that it is present in airbyte-workers."
}
}

test.dependsOn checkSshScriptCopy
assemble.dependsOn checkSshScriptCopy

installReqs.dependsOn(":airbyte-integrations:bases:airbyte-protocol:installReqs")
integrationTest.dependsOn(build)

Expand Down
4 changes: 4 additions & 0 deletions airbyte-integrations/bases/base-normalization/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,12 @@ function main() {
case "$CMD" in
run)
configuredbt
. /airbyte/sshtunneling.sh
openssh $CONFIG_FILE "${PROJECT_DIR}/localsshport.json"
trap 'closessh' EXIT
# Run dbt to compile and execute the generated normalization models
dbt run --profiles-dir "${PROJECT_DIR}" --project-dir "${PROJECT_DIR}"
closessh
;;
configure-dbt)
configuredbt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import json
import os
import pkgutil
import socket
from enum import Enum
from typing import Any, Dict

Expand All @@ -53,6 +54,8 @@ def run(self, args):

transformed_config = self.transform(integration_type, original_config)
self.write_yaml_config(inputs["output_path"], transformed_config, "profiles.yml")
if self.is_ssh_tunnelling(original_config):
self.write_ssh_port(inputs["output_path"], self.pick_a_port())

@staticmethod
def parse(args):
Expand Down Expand Up @@ -104,6 +107,55 @@ def transform(self, integration_type: DestinationType, config: Dict[str, Any]):

return base_profile

@staticmethod
def is_ssh_tunnelling(config: Dict[str, Any]) -> bool:
tunnel_methods = ["SSH_KEY_AUTH", "SSH_PASSWORD_AUTH"]
if (
"tunnel_method" in config.keys()
and "tunnel_method" in config["tunnel_method"]
and config["tunnel_method"]["tunnel_method"].upper() in tunnel_methods
):
return True
else:
return False

@staticmethod
def is_port_free(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("localhost", port))
except Exception as e:
print(f"port {port} unsuitable: {e}")
return False
else:
print(f"port {port} is free")
return True

@staticmethod
def pick_a_port() -> int:
"""
This function finds a free port, starting with 50001 and adding 1 until we find an open port.
"""
port_to_check = 50001 # just past start of dynamic port range (49152:65535)
while not TransformConfig.is_port_free(port_to_check):
port_to_check += 1
# error if we somehow hit end of port range
if port_to_check > 65535:
raise RuntimeError("Couldn't find a free port to use.")
return port_to_check

@staticmethod
def get_ssh_altered_config(config: Dict[str, Any], port_key: str = "port", host_key: str = "host") -> Dict[str, Any]:
"""
This should be called only if ssh tunneling is on.
It will return config with appropriately altered port and host values
"""
# make a copy of config rather than mutate in place
ssh_ready_config = {k: v for k, v in config.items()}
ssh_ready_config[port_key] = TransformConfig.pick_a_port()
ssh_ready_config[host_key] = "localhost"
return ssh_ready_config

@staticmethod
def transform_bigquery(config: Dict[str, Any]):
print("transform_bigquery")
Expand All @@ -126,6 +178,10 @@ def transform_bigquery(config: Dict[str, Any]):
@staticmethod
def transform_postgres(config: Dict[str, Any]):
print("transform_postgres")

if TransformConfig.is_ssh_tunnelling(config):
config = TransformConfig.get_ssh_altered_config(config, port_key="port", host_key="host")

# https://docs.getdbt.com/reference/warehouse-profiles/postgres-profile
dbt_config = {
"type": "postgres",
Expand Down Expand Up @@ -225,6 +281,19 @@ def write_yaml_config(output_path: str, config: Dict[str, Any], filename: str):
with open(os.path.join(output_path, filename), "w") as fh:
fh.write(yaml.dump(config))

@staticmethod
def write_ssh_port(output_path: str, port: int):
"""
This function writes a small json file with content like {"port":xyz}
This is being used only when ssh tunneling.
We do this because we need to decide on and save this port number into our dbt config
and then use that same port in sshtunneling.sh when opening the tunnel.
"""
if not os.path.exists(output_path):
os.makedirs(output_path)
with open(os.path.join(output_path, "localsshport.json"), "w") as fh:
json.dump({"port": port}, fh)


def main(args=None):
TransformConfig().run(args)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


import os
import socket
import time

import pytest
from normalization.transform_catalog.transform import extract_schema
Expand All @@ -47,6 +49,104 @@ def before_all_tests(self, request):
yield
os.chdir(request.config.invocation_dir)

def test_is_ssh_tunnelling(self):
def single_test(config, expected_output):
assert TransformConfig.is_ssh_tunnelling(config) == expected_output

inputs = [
({}, False),
(
{
"type": "postgres",
"dbname": "my_db",
"host": "airbyte.io",
"pass": "password123",
"port": 5432,
"schema": "public",
"threads": 32,
"user": "a user",
},
False,
),
(
{
"type": "postgres",
"dbname": "my_db",
"host": "airbyte.io",
"pass": "password123",
"port": 5432,
"schema": "public",
"threads": 32,
"user": "a user",
"tunnel_method": {
"tunnel_host": "1.2.3.4",
"tunnel_method": "SSH_PASSWORD_AUTH",
"tunnel_port": 22,
"tunnel_user": "user",
"tunnel_user_password": "pass",
},
},
True,
),
(
{
"type": "postgres",
"dbname": "my_db",
"host": "airbyte.io",
"pass": "password123",
"port": 5432,
"schema": "public",
"threads": 32,
"user": "a user",
"tunnel_method": {
"tunnel_method": "SSH_KEY_AUTH",
},
},
True,
),
(
{
"type": "postgres",
"dbname": "my_db",
"host": "airbyte.io",
"pass": "password123",
"port": 5432,
"schema": "public",
"threads": 32,
"user": "a user",
"tunnel_method": {
"nothing": "nothing",
},
},
False,
),
]
for input_tuple in inputs:
single_test(input_tuple[0], input_tuple[1])

def test_is_port_free(self):
# to test that this accurately identifies 'free' ports, we'll find a 'free' port and then try to use it
test_port = 13055
while not TransformConfig.is_port_free(test_port):
test_port += 1
if test_port > 65535:
raise RuntimeError("couldn't find a free port...")

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", test_port))
# if we haven't failed then we accurately identified a 'free' port.
# now we can test for accurate identification of 'in-use' port since we're using it
assert TransformConfig.is_port_free(test_port) is False

# and just for good measure now that our context manager is closed (and port open again)
time.sleep(1)
assert TransformConfig.is_port_free(test_port) is True

def test_pick_a_port(self):
supposedly_open_port = TransformConfig.pick_a_port()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", supposedly_open_port))

def test_transform_bigquery(self):
input = {"project_id": "my_project_id", "dataset_id": "my_dataset_id", "credentials_json": '{ "type": "service_account-json" }'}

Expand Down Expand Up @@ -108,6 +208,39 @@ def test_transform_postgres(self):
assert expected == actual
assert extract_schema(actual) == "public"

def test_transform_postgres_ssh(self):
input = {
"host": "airbyte.io",
"port": 5432,
"username": "a user",
"password": "password123",
"database": "my_db",
"schema": "public",
"tunnel_method": {
"tunnel_host": "1.2.3.4",
"tunnel_method": "SSH_PASSWORD_AUTH",
"tunnel_port": 22,
"tunnel_user": "user",
"tunnel_user_password": "pass",
},
}
port = TransformConfig.pick_a_port()

actual = TransformConfig().transform_postgres(input)
expected = {
"type": "postgres",
"dbname": "my_db",
"host": "localhost",
"pass": "password123",
"port": port,
"schema": "public",
"threads": 32,
"user": "a user",
}

assert expected == actual
assert extract_schema(actual) == "public"

def test_transform_snowflake(self):
input = {
"host": "http://123abc.us-east-7.aws.snowflakecomputing.com",
Expand Down
Loading