Skip to content

Commit

Permalink
refactor: Move arguments node-parsing logic into its own module (used…
Browse files Browse the repository at this point in the history
… by visitor and lambda expressions)
  • Loading branch information
pawamoy committed Feb 12, 2024
1 parent 3091660 commit ad68e65
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 75 deletions.
3 changes: 2 additions & 1 deletion src/griffe/agents/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from griffe.agents.nodes._docstrings import get_docstring
from griffe.agents.nodes._imports import relative_to_absolute
from griffe.agents.nodes._names import get_instance_names, get_name, get_names
from griffe.agents.nodes._parameters import get_call_keyword_arguments
from griffe.agents.nodes._parameters import get_call_keyword_arguments, get_parameters
from griffe.agents.nodes._runtime import ObjectNode
from griffe.agents.nodes._values import get_value, safe_get_value
from griffe.enumerations import ObjectKind
Expand Down Expand Up @@ -52,6 +52,7 @@
"get_instance_names",
"get_name",
"get_names",
"get_parameters",
"get_value",
"ObjectKind",
"ObjectNode",
Expand Down
73 changes: 71 additions & 2 deletions src/griffe/agents/nodes/_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Iterable

from griffe.expressions import safe_get_expression
from griffe.enumerations import ParameterKind
from griffe.logger import get_logger

if TYPE_CHECKING:
Expand All @@ -16,6 +17,74 @@
logger = get_logger(__name__)


def get_parameters(node: ast.arguments) -> list[tuple[str, ast.AST, ParameterKind, str | ast.AST]]:
parameters = []

# TODO: probably some optimizations to do here
args_kinds_defaults: Iterable = reversed(
(
*zip_longest(
reversed(
(
*zip_longest(
node.posonlyargs,
[],
fillvalue=ParameterKind.positional_only,
),
*zip_longest(node.args, [], fillvalue=ParameterKind.positional_or_keyword),
),
),
reversed(node.defaults),
fillvalue=None,
),
),
)
arg: ast.arg
kind: ParameterKind
arg_default: ast.AST | None
for (arg, kind), arg_default in args_kinds_defaults:
parameters.append((arg.arg, arg.annotation, kind, arg_default))

if node.vararg:
parameters.append(
(
node.vararg.arg,
node.vararg.annotation,
ParameterKind.var_positional,
"()",
),
)

# TODO: probably some optimizations to do here
kwargs_defaults: Iterable = reversed(
(
*zip_longest(
reversed(node.kwonlyargs),
reversed(node.kw_defaults),
fillvalue=None,
),
),
)
kwarg: ast.arg
kwarg_default: ast.AST | None
for kwarg, kwarg_default in kwargs_defaults:
parameters.append(
(kwarg.arg, kwarg.annotation, ParameterKind.keyword_only, kwarg_default),
)

if node.kwarg:
parameters.append(
(
node.kwarg.arg,
node.kwarg.annotation,
ParameterKind.var_keyword,
"{}",
),
)

return parameters


def get_call_keyword_arguments(node: ast.Call, parent: Module | Class) -> dict[str, Any]:
"""Get the list of keyword argument names and values from a Call node.
Expand Down
85 changes: 13 additions & 72 deletions src/griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import ast
from contextlib import suppress
from itertools import zip_longest
from typing import TYPE_CHECKING, Any, Iterable
from typing import TYPE_CHECKING, Any

from griffe.agents.nodes import (
ast_children,
Expand All @@ -20,6 +19,7 @@
get_docstring,
get_instance_names,
get_names,
get_parameters,
relative_to_absolute,
safe_get__all__,
)
Expand Down Expand Up @@ -361,78 +361,19 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
return

# handle parameters
parameters = Parameters()
annotation: str | Expr | None

posonlyargs = node.args.posonlyargs

# TODO: probably some optimizations to do here
args_kinds_defaults: Iterable = reversed(
(
*zip_longest(
reversed(
(
*zip_longest(
posonlyargs,
[],
fillvalue=ParameterKind.positional_only,
),
*zip_longest(node.args.args, [], fillvalue=ParameterKind.positional_or_keyword),
),
),
reversed(node.args.defaults),
fillvalue=None,
),
),
)
arg: ast.arg
kind: ParameterKind
arg_default: ast.AST | None
for (arg, kind), arg_default in args_kinds_defaults:
annotation = safe_get_annotation(arg.annotation, parent=self.current)
default = safe_get_expression(arg_default, parent=self.current, parse_strings=False)
parameters.add(Parameter(arg.arg, annotation=annotation, kind=kind, default=default))

if node.args.vararg:
annotation = safe_get_annotation(node.args.vararg.annotation, parent=self.current)
parameters.add(
parameters = Parameters(
*[
Parameter(
node.args.vararg.arg,
annotation=annotation,
kind=ParameterKind.var_positional,
default="()",
),
)

# TODO: probably some optimizations to do here
kwargs_defaults: Iterable = reversed(
(
*zip_longest(
reversed(node.args.kwonlyargs),
reversed(node.args.kw_defaults),
fillvalue=None,
),
),
name,
kind=kind,
annotation=safe_get_annotation(annotation, parent=self.current),
default=default
if isinstance(default, str)
else safe_get_expression(default, parent=self.current, parse_strings=False),
)
for name, annotation, kind, default in get_parameters(node.args)
],
)
kwarg: ast.arg
kwarg_default: ast.AST | None
for kwarg, kwarg_default in kwargs_defaults:
annotation = safe_get_annotation(kwarg.annotation, parent=self.current)
default = safe_get_expression(kwarg_default, parent=self.current, parse_strings=False)
parameters.add(
Parameter(kwarg.arg, annotation=annotation, kind=ParameterKind.keyword_only, default=default),
)

if node.args.kwarg:
annotation = safe_get_annotation(node.args.kwarg.annotation, parent=self.current)
parameters.add(
Parameter(
node.args.kwarg.arg,
annotation=annotation,
kind=ParameterKind.var_keyword,
default="{}",
),
)

function = Function(
name=node.name,
Expand Down

0 comments on commit ad68e65

Please sign in to comment.