Skip to content

Commit

Permalink
[BUG]: URL Parsing And Validation (#1118)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Added additional validations to URLs - URLs like api-gw.aws.com/dev
will now trigger an error asking the user to correctly specify the URL
with http or https
- When the full URL (http(s)://example.com) is provided by the user, the
port parameter is ignored (debug message is logged). An assumption is
made that the URL is entirely defined, thus not requiring additional
alterations such as injecting the port.
    - Added negative test cases for invalid URLs

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
TBD
  • Loading branch information
tazarov authored Sep 11, 2023
1 parent ea73f05 commit 9c1979c
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 9 deletions.
37 changes: 32 additions & 5 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
from typing import Optional, cast
from typing import Sequence
from uuid import UUID
Expand Down Expand Up @@ -32,28 +33,54 @@
from chromadb.telemetry import Telemetry
from urllib.parse import urlparse, urlunparse, quote

logger = logging.getLogger(__name__)


class FastAPI(API):
_settings: Settings

@staticmethod
def _validate_host(host: str) -> None:
parsed = urlparse(host)
if "/" in host and parsed.scheme not in {"http", "https"}:
raise ValueError(
"Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
)
if "/" in host and (not host.startswith("http")):
raise ValueError(
"Invalid URL. "
"Seems that you are trying to pass URL as a host but without specifying the protocol. "
"Please add http:// or https:// to the host."
)

@staticmethod
def resolve_url(
chroma_server_host: str,
chroma_server_ssl_enabled: Optional[bool] = False,
default_api_path: Optional[str] = "",
chroma_server_http_port: int = 8000,
chroma_server_http_port: Optional[int] = 8000,
) -> str:
parsed = urlparse(chroma_server_host)
_skip_port = False
_chroma_server_host = chroma_server_host
FastAPI._validate_host(_chroma_server_host)
if _chroma_server_host.startswith("http"):
logger.debug("Skipping port as the user is passing a full URL")
_skip_port = True
parsed = urlparse(_chroma_server_host)

scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
port = parsed.port or chroma_server_http_port
port = (
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
)
path = parsed.path or default_api_path

if not path or path == net_loc or not path.endswith(default_api_path or ""):
if not path or path == net_loc:
path = default_api_path if default_api_path else ""
if not path.endswith(default_api_path or ""):
path = path + default_api_path if default_api_path else ""
full_url = urlunparse(
(scheme, f"{net_loc}:{port}", quote(path.replace("//", "/")), "", "", "")
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
)

return full_url
Expand Down
48 changes: 44 additions & 4 deletions chromadb/test/property/test_client_url.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from urllib.parse import urlparse

import pytest
from hypothesis import given, strategies as st

from chromadb.api.fastapi import FastAPI
Expand Down Expand Up @@ -28,7 +29,7 @@ def domain_strategy() -> st.SearchStrategy[str]:
return st.tuples(label, tld).map(".".join)


port_strategy = st.integers(min_value=1, max_value=65535)
port_strategy = st.one_of(st.integers(min_value=1, max_value=65535), st.none())

ssl_enabled_strategy = st.booleans()

Expand Down Expand Up @@ -56,8 +57,21 @@ def is_valid_url(url: str) -> bool:

def generate_valid_domain_url() -> st.SearchStrategy[str]:
return st.builds(
lambda url_scheme, hostname, url_path: f"{url_scheme}://{hostname}{url_path}",
url_scheme=st.sampled_from(["http", "https"]),
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
url_scheme=st.sampled_from(["http://", "https://"]),
hostname=domain_strategy(),
url_path=url_path_strategy(),
)


def generate_invalid_domain_url() -> st.SearchStrategy[str]:
return st.builds(
lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}",
url_scheme=st.builds(
lambda scheme, suffix: f"{scheme}{suffix}",
scheme=st.text(max_size=10),
suffix=st.sampled_from(["://", ":///", ":////", ""]),
),
hostname=domain_strategy(),
url_path=url_path_strategy(),
)
Expand All @@ -76,7 +90,7 @@ def generate_valid_domain_url() -> st.SearchStrategy[str]:
)
def test_url_resolve(
hostname: str,
port: int,
port: Optional[int],
ssl_enabled: bool,
default_api_path: Optional[str],
) -> None:
Expand All @@ -90,5 +104,31 @@ def test_url_resolve(
assert (
_url.startswith("https") if ssl_enabled else _url.startswith("http")
), f"Invalid URL: {_url} - SSL Enabled: {ssl_enabled}"
if hostname.startswith("http"):
assert ":" + str(port) not in _url, f"Port in URL not expected: {_url}"
else:
assert ":" + str(port) in _url, f"Port in URL expected: {_url}"
if default_api_path:
assert _url.endswith(default_api_path), f"Invalid URL: {_url}"


@given(
hostname=generate_invalid_domain_url(),
port=port_strategy,
ssl_enabled=ssl_enabled_strategy,
default_api_path=st.sampled_from(["/api/v1", "/api/v2", None]),
)
def test_resolve_invalid(
hostname: str,
port: Optional[int],
ssl_enabled: bool,
default_api_path: Optional[str],
) -> None:
with pytest.raises(ValueError) as e:
FastAPI.resolve_url(
chroma_server_host=hostname,
chroma_server_http_port=port,
chroma_server_ssl_enabled=ssl_enabled,
default_api_path=default_api_path,
)
assert "Invalid URL" in str(e.value)

0 comments on commit 9c1979c

Please sign in to comment.