Skip to content

Commit

Permalink
Improve Info (#3418)
Browse files Browse the repository at this point in the history
* Re-export info from strawberry

* Add defaults to info

* Fix tests

* Add release notes

* Update all references to strawberry.Info

* Fix test

* Fix another test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove just

* Fix test

* Fix one more test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
patrick91 and pre-commit-ci[bot] authored Mar 21, 2024
1 parent 22838db commit f7be08b
Show file tree
Hide file tree
Showing 67 changed files with 496 additions and 311 deletions.
32 changes: 19 additions & 13 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def custom_context_getter(request: Request):
@strawberry.type
class Query:
@strawberry.field
def hello(self, info: Info[object, None]) -> str:
def hello(self, info: strawberry.Info[object, None]) -> str:
return info.context["custom"]


Expand Down Expand Up @@ -1192,7 +1192,7 @@ class MyDataType:
class Subscription:
@strawberry.subscription
async def my_data_subscription(
self, info: Info, groups: list[str]
self, info: strawberry.Info, groups: list[str]
) -> AsyncGenerator[MyDataType | None, None]:
yield None
async for message in info.context["ws"].channel_listen(
Expand All @@ -1207,7 +1207,7 @@ class Subscription:
class Subscription:
@strawberry.subscription
async def my_data_subscription(
self, info: Info, groups: list[str]
self, info: strawberry.Info, groups: list[str]
) -> AsyncGenerator[MyDataType | None, None]:
async with info.context["ws"].listen_to_channel("my_data", groups=groups) as cm:
yield None
Expand Down Expand Up @@ -1240,7 +1240,7 @@ class Query:
@strawberry.field
def get_testing(
self,
info: Info[None, None],
info: strawberry.Info,
id_: Annotated[uuid.UUID, strawberry.argument(name="id")],
) -> str | None:
return None
Expand Down Expand Up @@ -1694,7 +1694,7 @@ class Mutation:
@strawberry.mutation(extensions=[InputMutationExtension()])
def update_fruit_weight(
self,
info: Info,
info: strawberry.Info,
id: strawberry.ID,
weight: Annotated[
float,
Expand Down Expand Up @@ -2188,7 +2188,9 @@ class MyInput:


class MyFieldExtension(FieldExtension):
def resolve(self, next_: Callable[..., Any], source: Any, info: Info, **kwargs):
def resolve(
self, next_: Callable[..., Any], source: Any, info: strawberry.Info, **kwargs
):
# kwargs["my_input"] is instance of MyInput
...

Expand Down Expand Up @@ -2477,7 +2479,7 @@ def custom_context_getter(request: Request):
@strawberry.type
class Query:
@strawberry.field
def hello(self, info: Info[object, None]) -> str:
def hello(self, info: strawberry.Info[object, None]) -> str:
return info.context["custom"]


Expand Down Expand Up @@ -2680,7 +2682,11 @@ from strawberry.extensions import FieldExtension

class UpperCaseExtension(FieldExtension):
async def resolve_async(
self, next: Callable[..., Awaitable[Any]], source: Any, info: Info, **kwargs
self,
next: Callable[..., Awaitable[Any]],
source: Any,
info: strawberry.Info,
**kwargs
):
result = await next(source, info, **kwargs)
return str(result).upper()
Expand Down Expand Up @@ -4973,7 +4979,7 @@ and here's an example of how the new syntax works:
from strawberry.types import Info
def some_resolver(info: Info) -> str:
def some_resolver(info: strawberry.Info) -> str:
return info.context.get("some_key", "default")
Expand Down Expand Up @@ -5010,7 +5016,7 @@ class Query:
locations=[DirectiveLocation.FIELD],
description="Add frosting with ``value`` to a cake.",
)
def add_frosting(value: str, v: DirectiveValue[Cake], my_info: Info):
def add_frosting(value: str, v: DirectiveValue[Cake], my_info: strawberry.Info):
# Arbitrary argument name when using `DirectiveValue` is supported!
assert isinstance(v, Cake)
if (
Expand Down Expand Up @@ -5803,7 +5809,7 @@ Added the response object to `get_context` on the `flask` view. This means that
```python
@strawberry.field
def response_check(self, info: Info) -> bool:
def response_check(self, info: strawberry.Info) -> bool:
response: Response = info.context["response"]
response.status_code = 401
Expand Down Expand Up @@ -7459,7 +7465,7 @@ from starlette.background import BackgroundTask
@strawberry.mutation
def create_flavour(self, info: Info) -> str:
def create_flavour(self, info: strawberry.Info) -> str:
info.context["response"].background = BackgroundTask(...)
```
Expand Down Expand Up @@ -8551,7 +8557,7 @@ This release updates get_context in the django integration to also receive a tem
@strawberry.type
class Query:
@strawberry.field
def abc(self, info: Info) -> str:
def abc(self, info: strawberry.Info) -> str:
info.context.response.status_code = 418
return "ABC"
Expand Down
41 changes: 41 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
Release type: minor

This release improves the `Info` type, by adding support for default TypeVars
and by exporting it from the main module. This makes it easier to use `Info` in
your own code, without having to import it from `strawberry.types.info`.

### New export

By exporting `Info` from the main module, now you can do the follwing:

```python
import strawberry


@strawberry.type
class Query:
@strawberry.field
def info(self, info: strawberry.Info) -> str:
# do something with info
return "hello"
```

### Default TypeVars

The `Info` type now has default TypeVars, so you can use it without having to
specify the type arguments, like we did in the example above. Make sure to use
the latest version of Mypy or Pyright for this. It also means that you can only
pass one value to it if you only care about the context type:

```python
import strawberry

from .context import Context


@strawberry.type
class Query:
@strawberry.field
def info(self, info: strawberry.Info[Context]) -> str:
return info.context.user_id
```
3 changes: 1 addition & 2 deletions docs/extensions/_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ schema = strawberry.Schema(
## API reference:

```python
class ExtensionName(an_argument=None):
...
class ExtensionName(an_argument=None): ...
```

#### `an_argument: Optional[str] = None`
Expand Down
6 changes: 2 additions & 4 deletions docs/extensions/add-validation-rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ from strawberry.extensions import AddValidationRules
from graphql import ValidationRule


class MyCustomRule(ValidationRule):
...
class MyCustomRule(ValidationRule): ...


schema = strawberry.Schema(
Expand All @@ -35,8 +34,7 @@ schema = strawberry.Schema(
## API reference:

```python
class AddValidationRules(validation_rules):
...
class AddValidationRules(validation_rules): ...
```

#### `validation_rules: List[Type[ASTValidationRule]]`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/mask-errors.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ schema = strawberry.Schema(
```python
class MaskErrors(
should_mask_error=default_should_mask_error, error_message="Unexpected error."
):
...
): ...
```

#### `should_mask_error: Callable[[GraphQLError], bool] = default_should_mask_error`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/max-aliases-limiter.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ schema = strawberry.Schema(
## API reference:

```python
class MaxAliasesLimiter(max_alias_count):
...
class MaxAliasesLimiter(max_alias_count): ...
```

#### `max_alias_count: int`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/max-tokens-limiter.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ tokens, the server will respond with an error message.
## API reference:

```python
class MaxTokensLimiter(max_token_count):
...
class MaxTokensLimiter(max_token_count): ...
```

#### `max_token_count: int`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/opentelemetry.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ schema = strawberry.Schema(
## API reference:

```python
class OpenTelemetryExtension(arg_filter=None):
...
class OpenTelemetryExtension(arg_filter=None): ...
```

#### `arg_filter: Optional[ArgFilter]`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/parser-cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ schema = strawberry.Schema(
## API reference:

```python
class ParserCache(maxsize=None):
...
class ParserCache(maxsize=None): ...
```

#### `maxsize: Optional[int] = None`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/query-depth-limiter.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ schema = strawberry.Schema(
## API reference:

```python
class QueryDepthLimiter(max_depth, callback=None, should_ignore=None):
...
class QueryDepthLimiter(max_depth, callback=None, should_ignore=None): ...
```

#### `max_depth: int`
Expand Down
3 changes: 1 addition & 2 deletions docs/extensions/validation-cache.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ schema = strawberry.Schema(
## API reference:

```python
class ValidationCache(maxsize=None):
...
class ValidationCache(maxsize=None): ...
```

#### `maxsize: Optional[int] = None`
Expand Down
2 changes: 1 addition & 1 deletion docs/general/mutations.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class Mutation:
@strawberry.mutation(extensions=[InputMutationExtension()])
def update_fruit_weight(
self,
info: Info,
info: strawberry.Info,
id: strawberry.ID,
weight: Annotated[
float,
Expand Down
5 changes: 3 additions & 2 deletions docs/general/subscriptions.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ import asyncio
from typing import AsyncGenerator

import strawberry
from strawberry.types import Info

from .auth import authenticate_token

Expand All @@ -128,7 +127,9 @@ class Query:
@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self, info: Info, target: int = 100) -> AsyncGenerator[int, None]:
async def count(
self, info: strawberry.Info, target: int = 100
) -> AsyncGenerator[int, None]:
connection_params: dict = info.context.get("connection_params")
token: str = connection_params.get(
"authToken"
Expand Down
7 changes: 1 addition & 6 deletions docs/guides/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ from functools import cached_property
import strawberry
from fastapi import FastAPI
from strawberry.fastapi import BaseContext, GraphQLRouter
from strawberry.types import Info as _Info
from strawberry.types.info import RootValueType


@strawberry.type
Expand All @@ -85,13 +83,10 @@ class Context(BaseContext):
return authorization_service.authorize(authorization)


Info = _Info[Context, RootValueType]


@strawberry.type
class Query:
@strawberry.field
def get_authenticated_user(self, info: Info) -> User | None:
def get_authenticated_user(self, info: strawberry.Info[Context]) -> User | None:
return info.context.user


Expand Down
3 changes: 1 addition & 2 deletions docs/guides/custom-extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@ check out [field extensions](field-extensions.md).
Note that `resolve` can also be implemented asynchronously.

```python
from strawberry.types import Info
from strawberry.extensions import SchemaExtension


class MyExtension(SchemaExtension):
def resolve(self, _next, root, info: Info, *args, **kwargs):
def resolve(self, _next, root, info: strawberry.Info, *args, **kwargs):
return _next(root, info, *args, **kwargs)
```

Expand Down
6 changes: 2 additions & 4 deletions docs/guides/dataloaders.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ provided.
from typing import List, Union, Any, Optional

import strawberry
from strawberry.types import Info
from strawberry.asgi import GraphQL
from strawberry.dataloader import DataLoader, AbstractCache

Expand Down Expand Up @@ -296,7 +295,7 @@ class MyGraphQL(GraphQL):
@strawberry.type
class Query:
@strawberry.field
async def get_user(self, info: Info, id: strawberry.ID) -> User:
async def get_user(self, info: strawberry.Info, id: strawberry.ID) -> User:
return await info.context["user_loader"].load(id)


Expand Down Expand Up @@ -383,7 +382,6 @@ example of this using our ASGI view:
from typing import List, Union, Any, Optional

import strawberry
from strawberry.types import Info
from strawberry.asgi import GraphQL
from strawberry.dataloader import DataLoader

Expand Down Expand Up @@ -411,7 +409,7 @@ class MyGraphQL(GraphQL):
@strawberry.type
class Query:
@strawberry.field
async def get_user(self, info: Info, id: strawberry.ID) -> User:
async def get_user(self, info: strawberry.Info, id: strawberry.ID) -> User:
return await info.context["user_loader"].load(id)


Expand Down
Loading

0 comments on commit f7be08b

Please sign in to comment.