Skip to content

Commit

Permalink
Backport Secret type to async branch (#1706)
Browse files Browse the repository at this point in the history
* Define Secret type (#1546)

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix linter errors (#1691)

* ruff --fix

Signed-off-by: Mattt Zmuda <[email protected]>

* Update ruff pyproject settings

Signed-off-by: Mattt Zmuda <[email protected]>

* Update ruff lint command in Makefile

Signed-off-by: Mattt Zmuda <[email protected]>

---------

Signed-off-by: Mattt Zmuda <[email protected]>

* Run ruff --fix python

Signed-off-by: Mattt Zmuda <[email protected]>

---------

Signed-off-by: Mattt Zmuda <[email protected]>
  • Loading branch information
mattt authored May 31, 2024
1 parent e07dc07 commit 8dc4890
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 26 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ reportUnusedExpression = "warning"
package-dir = { "" = "python" }

[tool.ruff]
select = [
lint.select = [
"E", # pycodestyle error
"F", # Pyflakes
"I", # isort
Expand All @@ -80,7 +80,7 @@ select = [
"B", # flake8-bugbear
"ANN", # flake8-annotations
]
ignore = [
lint.ignore = [
"E501", # Line too long
"S101", # Use of `assert` detected"
"S113", # Probable use of requests call without timeout
Expand All @@ -97,7 +97,7 @@ extend-exclude = [
"test-integration/test_integration/fixtures/*",
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"python/cog/server/http.py" = [
"S104", # Possible binding to all interfaces
]
Expand Down
10 changes: 9 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from .predictor import BasePredictor
from .server.worker import emit_metric
from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path
from .types import (
AsyncConcatenateIterator,
ConcatenateIterator,
File,
Input,
Path,
Secret,
)

try:
from ._version import __version__
Expand All @@ -20,4 +27,5 @@
"Input",
"Path",
"emit_metric",
"Secret",
]
18 changes: 9 additions & 9 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def predict(
format = {"format": "uri"} if name in ("Path", "File") else {}
return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format}
# it must be a custom object
schema: "JSONDict" = {name: parse_class(find(tree, name))}
schema: JSONDict = {name: parse_class(find(tree, name))}
return schema, {
"title": "Output",
"$ref": f"#/components/schemas/{name}",
Expand All @@ -524,10 +524,10 @@ def predict(
def extract_info(code: str) -> "JSONDict":
"""Parse the schemas from a file with a predict function"""
tree = ast.parse(code)
properties: "JSONDict" = {}
inputs: "JSONDict" = {"title": "Input", "type": "object", "properties": properties}
required: "list[str]" = []
schemas: "JSONDict" = {}
properties: JSONDict = {}
inputs: JSONDict = {"title": "Input", "type": "object", "properties": properties}
required: list[str] = []
schemas: JSONDict = {}
for arg, default in parse_args(tree):
if arg.arg == "self":
continue
Expand All @@ -544,7 +544,7 @@ def extract_info(code: str) -> "JSONDict":
kws = {}
else:
raise ValueError("Unexpected default value", default)
input: "JSONDict" = {"x-order": len(properties)}
input: JSONDict = {"x-order": len(properties)}
# need to handle other types?
arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string")
if get_annotation(arg.annotation) in ("Path", "File"):
Expand All @@ -571,15 +571,15 @@ def extract_info(code: str) -> "JSONDict":
inputs["required"] = list(required)
# List[Path], list[Path], str, Iterator[str], MyOutput, Output
return_schema, output = parse_return_annotation(tree, "predict")
schema: "JSONDict" = json.loads(BASE_SCHEMA)
components: "JSONDict" = {
schema: JSONDict = json.loads(BASE_SCHEMA)
components: JSONDict = {
"Input": inputs,
"Output": output,
**schemas,
**return_schema,
}
# trust me, typechecker, I know BASE_SCHEMA
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x: JSONDict = schema["components"]["schemas"] # type: ignore
x.update(components)
return schema

Expand Down
13 changes: 11 additions & 2 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,17 @@
from .types import (
Path as CogPath,
)

ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath]
from .types import Secret as CogSecret

ALLOWED_INPUT_TYPES: List[Type[Any]] = [
str,
int,
float,
bool,
CogFile,
CogPath,
CogSecret,
]


class BasePredictor(ABC):
Expand Down
8 changes: 4 additions & 4 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SetupResult:

class TimeShareTracker:
def __init__(self) -> None:
self._time_shares_per_prediction: "dict[str, float]" = {}
self._time_shares_per_prediction: dict[str, float] = {}
self._last_updated_time_shares = 0.0

def update_time_shares(self) -> None:
Expand Down Expand Up @@ -122,8 +122,8 @@ def __init__(
self._shutdown_event = shutdown_event # __main__ waits for this event

self._upload_url = upload_url
self._predictions: "dict[str, tuple[schema.PredictionResponse, PredictionTask]]" = {}
self._predictions_in_flight: "set[str]" = set()
self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {}
self._predictions_in_flight: set[str] = set()
# it would be lovely to merge these but it's not fully clear how best to handle it
# since idempotent requests can kinda come whenever?
# p: dict[str, PredictionTask]
Expand All @@ -142,7 +142,7 @@ def __init__(
# A pipe with which to communicate with the child worker.
events, child_events = _spawn.Pipe()
self._child = _ChildWorker(predictor_ref, child_events, tee_output)
self._events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
self._events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection(
events
)
# shutdown requested
Expand Down
8 changes: 4 additions & 4 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class WorkerState(Enum):

class Mux:
def __init__(self, terminating: asyncio.Event) -> None:
self.outs: "defaultdict[str, asyncio.Queue[PublicEventType]]" = defaultdict(
self.outs: defaultdict[str, asyncio.Queue[PublicEventType]] = defaultdict(
asyncio.Queue
)
self.terminating = terminating
self.fatal: "Optional[FatalWorkerException]" = None
self.fatal: Optional[FatalWorkerException] = None

async def write(self, id: str, item: PublicEventType) -> None:
await self.outs[id].put(item)
Expand Down Expand Up @@ -199,11 +199,11 @@ def _loop_sync(self) -> None:
print(f"Got unexpected event: {ev}", file=sys.stderr)

async def _loop_async(self) -> None:
events: "AsyncConnection[tuple[str, PublicEventType]]" = AsyncConnection(
events: AsyncConnection[tuple[str, PublicEventType]] = AsyncConnection(
self._events
)
with events:
tasks: "dict[str, asyncio.Task[None]]" = {}
tasks: dict[str, asyncio.Task[None]] = {}
while True:
try:
ev = await events.recv()
Expand Down
15 changes: 14 additions & 1 deletion python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import httpx
import requests
from pydantic import Field
from pydantic import Field, SecretStr

FILENAME_ILLEGAL_CHARS = set("\u0000/")

Expand Down Expand Up @@ -44,6 +44,19 @@ def Input(
)


class Secret(SecretStr):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
"""Defines what this type should be in openapi.json"""
field_schema.update(
{
"type": "string",
"format": "password",
"x-cog-secret": True,
}
)


class File(io.IOBase):
validate_always = True

Expand Down
6 changes: 6 additions & 0 deletions python/tests/server/fixtures/input_secret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from cog import BasePredictor, Secret


class Predictor(BasePredictor):
def predict(self, secret: Secret) -> str:
return secret.get_secret_value()
10 changes: 10 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ def test_union_integers(client):
assert resp.status_code == 422


@uses_predictor("input_secret")
def test_secret_str(client, match):
resp = client.post("/predictions", json={"input": {"secret": "foo"}})
assert resp.status_code == 200
assert resp.json() == match({"output": "foo", "status": "succeeded"})

resp = client.post("/predictions", json={"input": {"secret": {}}})
assert resp.status_code == 422


def test_untyped_inputs():
config = {"predict": _fixture_path("input_untyped")}
app = create_app(
Expand Down
1 change: 0 additions & 1 deletion python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest
import pytest_asyncio

from cog.schema import PredictionRequest, PredictionResponse, Status, WebhookEvent
from cog.server.clients import ClientManager
from cog.server.eventtypes import (
Expand Down
10 changes: 9 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
import responses
from cog.types import URLFile, get_filename_from_url, get_filename_from_urlopen
from cog.types import Secret, URLFile, get_filename_from_url, get_filename_from_urlopen


@responses.activate
Expand Down Expand Up @@ -123,3 +123,11 @@ def test_get_filename(url, filename):
def test_get_filename_from_urlopen(url, filename):
resp = urllib.request.urlopen(url) # noqa: S310
assert get_filename_from_urlopen(resp) == filename


def test_secret_type():
secret_value = "sw0rdf1$h" # noqa: S105
secret = Secret(secret_value)

assert secret.get_secret_value() == secret_value
assert str(secret) == "**********"

0 comments on commit 8dc4890

Please sign in to comment.