From 9c1979c9311acb8662fc22a7deac17907a9b128b Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Mon, 11 Sep 2023 17:58:20 +0300 Subject: [PATCH] [BUG]: URL Parsing And Validation (#1118) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Added additional validations to URLs - URLs like api-gw.aws.com/dev will now trigger an error asking the user to correctly specify the URL with http or https - When the full URL (http(s)://example.com) is provided by the user, the port parameter is ignored (debug message is logged). An assumption is made that the URL is entirely defined, thus not requiring additional alterations such as injecting the port. - Added negative test cases for invalid URLs ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes TBD --- chromadb/api/fastapi.py | 37 ++++++++++++++--- chromadb/test/property/test_client_url.py | 48 +++++++++++++++++++++-- 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 8498f9ec110..c08458a2fcb 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -1,4 +1,5 @@ import json +import logging from typing import Optional, cast from typing import Sequence from uuid import UUID @@ -32,28 +33,54 @@ from chromadb.telemetry import Telemetry from urllib.parse import urlparse, urlunparse, quote +logger = logging.getLogger(__name__) + class FastAPI(API): _settings: Settings + @staticmethod + def _validate_host(host: str) -> None: + parsed = urlparse(host) + if "/" in host and parsed.scheme not in {"http", "https"}: + raise ValueError( + "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}." + ) + if "/" in host and (not host.startswith("http")): + raise ValueError( + "Invalid URL. " + "Seems that you are trying to pass URL as a host but without specifying the protocol. " + "Please add http:// or https:// to the host." + ) + @staticmethod def resolve_url( chroma_server_host: str, chroma_server_ssl_enabled: Optional[bool] = False, default_api_path: Optional[str] = "", - chroma_server_http_port: int = 8000, + chroma_server_http_port: Optional[int] = 8000, ) -> str: - parsed = urlparse(chroma_server_host) + _skip_port = False + _chroma_server_host = chroma_server_host + FastAPI._validate_host(_chroma_server_host) + if _chroma_server_host.startswith("http"): + logger.debug("Skipping port as the user is passing a full URL") + _skip_port = True + parsed = urlparse(_chroma_server_host) scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http" net_loc = parsed.netloc or parsed.hostname or chroma_server_host - port = parsed.port or chroma_server_http_port + port = ( + ":" + str(parsed.port or chroma_server_http_port) if not _skip_port else "" + ) path = parsed.path or default_api_path - if not path or path == net_loc or not path.endswith(default_api_path or ""): + if not path or path == net_loc: path = default_api_path if default_api_path else "" + if not path.endswith(default_api_path or ""): + path = path + default_api_path if default_api_path else "" full_url = urlunparse( - (scheme, f"{net_loc}:{port}", quote(path.replace("//", "/")), "", "", "") + (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "") ) return full_url diff --git a/chromadb/test/property/test_client_url.py b/chromadb/test/property/test_client_url.py index 992af981399..cc5df1e0514 100644 --- a/chromadb/test/property/test_client_url.py +++ b/chromadb/test/property/test_client_url.py @@ -1,6 +1,7 @@ from typing import Optional from urllib.parse import urlparse +import pytest from hypothesis import given, strategies as st from chromadb.api.fastapi import FastAPI @@ -28,7 +29,7 @@ def domain_strategy() -> st.SearchStrategy[str]: return st.tuples(label, tld).map(".".join) -port_strategy = st.integers(min_value=1, max_value=65535) +port_strategy = st.one_of(st.integers(min_value=1, max_value=65535), st.none()) ssl_enabled_strategy = st.booleans() @@ -56,8 +57,21 @@ def is_valid_url(url: str) -> bool: def generate_valid_domain_url() -> st.SearchStrategy[str]: return st.builds( - lambda url_scheme, hostname, url_path: f"{url_scheme}://{hostname}{url_path}", - url_scheme=st.sampled_from(["http", "https"]), + lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}", + url_scheme=st.sampled_from(["http://", "https://"]), + hostname=domain_strategy(), + url_path=url_path_strategy(), + ) + + +def generate_invalid_domain_url() -> st.SearchStrategy[str]: + return st.builds( + lambda url_scheme, hostname, url_path: f"{url_scheme}{hostname}{url_path}", + url_scheme=st.builds( + lambda scheme, suffix: f"{scheme}{suffix}", + scheme=st.text(max_size=10), + suffix=st.sampled_from(["://", ":///", ":////", ""]), + ), hostname=domain_strategy(), url_path=url_path_strategy(), ) @@ -76,7 +90,7 @@ def generate_valid_domain_url() -> st.SearchStrategy[str]: ) def test_url_resolve( hostname: str, - port: int, + port: Optional[int], ssl_enabled: bool, default_api_path: Optional[str], ) -> None: @@ -90,5 +104,31 @@ def test_url_resolve( assert ( _url.startswith("https") if ssl_enabled else _url.startswith("http") ), f"Invalid URL: {_url} - SSL Enabled: {ssl_enabled}" + if hostname.startswith("http"): + assert ":" + str(port) not in _url, f"Port in URL not expected: {_url}" + else: + assert ":" + str(port) in _url, f"Port in URL expected: {_url}" if default_api_path: assert _url.endswith(default_api_path), f"Invalid URL: {_url}" + + +@given( + hostname=generate_invalid_domain_url(), + port=port_strategy, + ssl_enabled=ssl_enabled_strategy, + default_api_path=st.sampled_from(["/api/v1", "/api/v2", None]), +) +def test_resolve_invalid( + hostname: str, + port: Optional[int], + ssl_enabled: bool, + default_api_path: Optional[str], +) -> None: + with pytest.raises(ValueError) as e: + FastAPI.resolve_url( + chroma_server_host=hostname, + chroma_server_http_port=port, + chroma_server_ssl_enabled=ssl_enabled, + default_api_path=default_api_path, + ) + assert "Invalid URL" in str(e.value)