Skip to content

Commit

Permalink
[BLD] Python & misc lint fixes (#2046)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Apr 24, 2024
1 parent df65e5a commit 413174d
Show file tree
Hide file tree
Showing 87 changed files with 4,488 additions and 3,372 deletions.
14 changes: 12 additions & 2 deletions .github/workflows/chroma-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,15 @@ jobs:
- name: Install pre-commit
run: python -m pip install -r requirements_dev.txt
- name: Run pre-commit
# todo: remove || true once lint issues are resolved
run: pre-commit run --all-files || true
run: |
pre-commit run --all-files trailing-whitespace
pre-commit run --all-files mixed-line-ending
pre-commit run --all-files end-of-file-fixer
pre-commit run --all-files requirements-txt-fixer
pre-commit run --all-files check-xml
pre-commit run --all-files check-merge-conflict
pre-commit run --all-files check-case-conflict
pre-commit run --all-files check-docstring-first
pre-commit run --all-files black
pre-commit run --all-files flake8
pre-commit run --all-files prettier
2 changes: 1 addition & 1 deletion .github/workflows/release-helm-chart.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
workflow_id: 'copy-oss-helm.yaml',
ref: 'main'
})
console.log(result)
console.log(result)
26 changes: 22 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
exclude: 'chromadb/proto/(chroma_pb2|coordinator_pb2)\.(py|pyi|py_grpc\.py)' # Generated files
exclude: 'chromadb/proto/(chroma_pb2|coordinator_pb2|logservice_pb2)\.(py|pyi|py_grpc\.py)' # Generated files
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: mixed-line-ending
- id: end-of-file-fixer
exclude: "go/migrations"
- id: requirements-txt-fixer
- id: check-yaml
args: ["--allow-multiple-documents"]
Expand All @@ -32,9 +33,26 @@ repos:
rev: "v1.2.0"
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --follow-imports=silent, --disable-error-code=type-abstract, --config-file=./pyproject.toml]
additional_dependencies: ["types-requests", "pydantic", "overrides", "hypothesis", "pytest", "pypika", "numpy", "types-protobuf", "kubernetes"]

args:
[
--strict,
--ignore-missing-imports,
--follow-imports=silent,
--disable-error-code=type-abstract,
--config-file=./pyproject.toml,
]
additional_dependencies:
[
"types-requests",
"pydantic",
"overrides",
"hypothesis",
"pytest",
"pypika",
"numpy",
"types-protobuf",
"kubernetes",
]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: "v3.1.0"
Expand Down
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@
"unordered_set": "cpp",
"algorithm": "cpp"
},
}
}
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ ENV CHROMA_TIMEOUT_KEEP_ALIVE 30
EXPOSE 8000

ENTRYPOINT ["/docker_entrypoint.sh"]
CMD [ "--workers ${CHROMA_WORKERS} --host ${CHROMA_HOST_ADDR} --port ${CHROMA_HOST_PORT} --proxy-headers --log-config ${CHROMA_LOG_CONFIG} --timeout-keep-alive ${CHROMA_TIMEOUT_KEEP_ALIVE}"]
CMD [ "--workers ${CHROMA_WORKERS} --host ${CHROMA_HOST_ADDR} --port ${CHROMA_HOST_PORT} --proxy-headers --log-config ${CHROMA_LOG_CONFIG} --timeout-keep-alive ${CHROMA_TIMEOUT_KEEP_ALIVE}"]
4 changes: 3 additions & 1 deletion chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def CloudClient(
# Always use SSL for cloud
settings.chroma_server_ssl_enabled = enable_ssl

settings.chroma_client_auth_provider = "chromadb.auth.token_authn.TokenAuthClientProvider"
settings.chroma_client_auth_provider = (
"chromadb.auth.token_authn.TokenAuthClientProvider"
)
settings.chroma_client_auth_credentials = api_key
settings.chroma_auth_token_transport_header = (
TokenTransportHeader.X_CHROMA_TOKEN.name
Expand Down
3 changes: 2 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def modify(
validate_metadata(metadata)
if "hnsw:space" in metadata:
raise ValueError(
"Changing the distance function of a collection once it is created is not supported currently.")
"Changing the distance function of a collection once it is created is not supported currently."
)

self._client._modify(id=self.id, new_name=name, new_metadata=metadata)
if name:
Expand Down
50 changes: 21 additions & 29 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ClientAuthProvider(Component):
client requests. Client implementations (in our case, just the FastAPI
client) must inject these headers into their requests.
"""

def __init__(self, system: System) -> None:
super().__init__(system)

Expand All @@ -52,6 +53,7 @@ class UserIdentity:
_all_ information known about the user, and the AuthorizationProvider is
responsible for making decisions based on that information.
"""

user_id: str
tenant: Optional[str] = None
databases: Optional[List[str]] = None
Expand All @@ -71,20 +73,18 @@ class ServerAuthenticationProvider(Component):
The ServerAuthenticationProvider should return a UserIdentity object if the
request is authenticated for use by the ServerAuthorizationProvider.
"""

def __init__(self, system: System) -> None:
super().__init__(system)
self._ignore_auth_paths: Dict[
str, List[str]
] = system.settings.chroma_server_auth_ignore_paths
self.overwrite_singleton_tenant_database_access_from_auth = (
system.settings.
chroma_overwrite_singleton_tenant_database_access_from_auth
system.settings.chroma_overwrite_singleton_tenant_database_access_from_auth
)

@abstractmethod
def authenticate_or_raise(
self, headers: Headers
) -> UserIdentity:
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
pass

def ignore_operation(self, verb: str, path: str) -> bool:
Expand All @@ -100,13 +100,11 @@ def read_creds_or_creds_file(self) -> List[str]:
_creds = None

if self._system.settings.chroma_server_authn_credentials_file:
_creds_file = str(self._system.settings[
"chroma_server_authn_credentials_file"
])
_creds_file = str(
self._system.settings["chroma_server_authn_credentials_file"]
)
if self._system.settings.chroma_server_authn_credentials:
_creds = str(self._system.settings[
"chroma_server_authn_credentials"
])
_creds = str(self._system.settings["chroma_server_authn_credentials"])
if not _creds_file and not _creds:
raise ValueError(
"No credentials file or credentials found in "
Expand All @@ -122,9 +120,7 @@ def read_creds_or_creds_file(self) -> List[str]:
elif _creds_file:
with open(_creds_file, "r") as f:
return f.readlines()
raise ValueError(
"Should never happen"
)
raise ValueError("Should never happen")

def singleton_tenant_database_if_applicable(
self, user: Optional[UserIdentity]
Expand All @@ -144,15 +140,13 @@ def singleton_tenant_database_if_applicable(
- If the user has access to multiple tenants and/or databases this
function will return None for the corresponding value(s).
"""
if (not self.overwrite_singleton_tenant_database_access_from_auth or
not user):
if not self.overwrite_singleton_tenant_database_access_from_auth or not user:
return None, None
tenant = None
database = None
if user.tenant and user.tenant != "*":
tenant = user.tenant
if (user.databases and len(user.databases) == 1 and
user.databases[0] != "*"):
if user.databases and len(user.databases) == 1 and user.databases[0] != "*":
database = user.databases[0]
return tenant, database

Expand All @@ -161,6 +155,7 @@ class AuthzAction(str, Enum):
"""
The set of actions that can be authorized by the authorization provider.
"""

RESET = "system:reset"
CREATE_TENANT = "tenant:create_tenant"
GET_TENANT = "tenant:get_tenant"
Expand All @@ -187,6 +182,7 @@ class AuthzResource:
"""
The resource being accessed in an authorization request.
"""

tenant: Optional[str]
database: Optional[str]
collection: Optional[str]
Expand All @@ -202,23 +198,21 @@ class ServerAuthorizationProvider(Component):
ServerAuthorizationProvider should raise an exception if the request is not
authorized.
"""

def __init__(self, system: System) -> None:
super().__init__(system)

@abstractmethod
def authorize_or_raise(self,
user: UserIdentity,
action: AuthzAction,
resource: AuthzResource) -> None:
def authorize_or_raise(
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
) -> None:
pass

def read_config_or_config_file(self) -> List[str]:
_config_file = None
_config = None
if self._system.settings.chroma_server_authz_config_file:
_config_file = self._system.settings[
"chroma_server_authz_config_file"
]
_config_file = self._system.settings["chroma_server_authz_config_file"]
if self._system.settings.chroma_server_authz_config:
_config = str(self._system.settings["chroma_server_authz_config"])
if not _config_file and not _config:
Expand All @@ -231,10 +225,8 @@ def read_config_or_config_file(self) -> List[str]:
"Please provide only one."
)
if _config:
return [c for c in _config.split('\n') if c]
return [c for c in _config.split("\n") if c]
elif _config_file:
with open(_config_file, "r") as f:
return f.readlines()
raise ValueError(
"Should never happen"
)
raise ValueError("Should never happen")
35 changes: 16 additions & 19 deletions chromadb/auth/basic_authn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import bcrypt
import importlib
import logging

from fastapi import HTTPException
Expand Down Expand Up @@ -31,25 +30,20 @@ class BasicAuthClientProvider(ClientAuthProvider):
Client auth provider for basic auth. The credentials are passed as a
base64-encoded string in the Authorization header prepended with "Basic ".
"""

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
system.settings.require("chroma_client_auth_credentials")
self._creds = SecretStr(
str(system.settings.chroma_client_auth_credentials)
)
self._creds = SecretStr(str(system.settings.chroma_client_auth_credentials))

@override
def authenticate(self) -> ClientAuthHeaders:
encoded = base64.b64encode(
f"{self._creds.get_secret_value()}".encode("utf-8")
).decode(
"utf-8"
)
).decode("utf-8")
return {
"Authorization": SecretStr(
f"Basic {encoded}"
),
"Authorization": SecretStr(f"Basic {encoded}"),
}


Expand All @@ -62,6 +56,7 @@ class BasicAuthenticationServerProvider(ServerAuthenticationProvider):
Expects tokens to be passed as a base64-encoded string in the Authorization
header prepended with "Basic".
"""

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
Expand All @@ -73,8 +68,12 @@ def __init__(self, system: System) -> None:
if not line.strip():
continue
_raw_creds = [v for v in line.strip().split(":")]
if (_raw_creds and _raw_creds[0] and
len(_raw_creds) != 2 or not all(_raw_creds)):
if (
_raw_creds
and _raw_creds[0]
and len(_raw_creds) != 2
or not all(_raw_creds)
):
raise ValueError(
f"Invalid htpasswd credentials found: {_raw_creds}. "
"Lines must be exactly <username>:<bcrypt passwd>."
Expand All @@ -89,12 +88,11 @@ def __init__(self, system: System) -> None:
)
self._creds[username] = SecretStr(password)

@trace_method("BasicAuthenticationServerProvider.authenticate",
OpenTelemetryGranularity.ALL)
@trace_method(
"BasicAuthenticationServerProvider.authenticate", OpenTelemetryGranularity.ALL
)
@override
def authenticate_or_raise(
self, headers: Headers
) -> UserIdentity:
def authenticate_or_raise(self, headers: Headers) -> UserIdentity:
try:
_auth_header = headers["Authorization"]
_auth_header = _auth_header.replace("Basic ", "")
Expand All @@ -115,7 +113,6 @@ def authenticate_or_raise(

except Exception as e:
logger.error(
"BasicAuthenticationServerProvider.authenticate "
f"failed: {repr(e)}"
"BasicAuthenticationServerProvider.authenticate " f"failed: {repr(e)}"
)
raise HTTPException(status_code=403, detail="Forbidden")
23 changes: 11 additions & 12 deletions chromadb/auth/simple_rbac_authz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

from hypothesis import Phase, settings

settings.register_profile("ci", phases=[Phase.generate, Phase.target])


Expand All @@ -33,12 +34,11 @@ class SimpleRBACAuthorizationProvider(ServerAuthorizationProvider):
For an example of an RBAC configuration file, see
examples/basic_functionality/authz/authz.yaml.
"""

def __init__(self, system: System) -> None:
super().__init__(system)
self._settings = system.settings
self._config = yaml.safe_load(
'\n'.join(self.read_config_or_config_file())
)
self._config = yaml.safe_load("\n".join(self.read_config_or_config_file()))

# We favor preprocessing here to avoid having to parse the config file
# on every request. This AuthorizationProvider does not support
Expand All @@ -51,23 +51,22 @@ def __init__(self, system: System) -> None:
_actions = self._config["roles_mapping"][user["role"]]["actions"]
self._permissions[user["id"]] = set(_actions)
logger.info(
"Authorization Provider SimpleRBACAuthorizationProvider "
"initialized"
"Authorization Provider SimpleRBACAuthorizationProvider " "initialized"
)

@trace_method(
"SimpleRBACAuthorizationProvider.authorize",
OpenTelemetryGranularity.ALL,
)
@override
def authorize_or_raise(self,
user: UserIdentity,
action: AuthzAction,
resource: AuthzResource) -> None:

def authorize_or_raise(
self, user: UserIdentity, action: AuthzAction, resource: AuthzResource
) -> None:
policy_decision = False
if (user.user_id in self._permissions and
action in self._permissions[user.user_id]):
if (
user.user_id in self._permissions
and action in self._permissions[user.user_id]
):
policy_decision = True

logger.debug(
Expand Down
Loading

0 comments on commit 413174d

Please sign in to comment.