Skip to content

Commit

Permalink
refactor parse_fn_type
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Oct 2, 2023
1 parent 84020dc commit 6c3c505
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ pub enum FnType {
}

impl FnType {
pub fn skip_first_rust_argument_in_python_signature(&self) -> bool {
match self {
FnType::Getter(_)
| FnType::Setter(_)
| FnType::Fn(_)
| FnType::FnClass
| FnType::FnNewClass
| FnType::FnModule => true,
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => false,
}
}

pub fn self_arg(&self, cls: Option<&syn::Type>, error_mode: ExtractErrorMode) -> TokenStream {
match self {
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) => {
Expand Down Expand Up @@ -264,35 +276,35 @@ impl<'a> FnSpec<'a> {

let mut python_name = name.map(|name| name.value.0);

let (fn_type, skip_first_arg, fixed_convention) =
Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
let fn_type = Self::parse_fn_type(sig, meth_attrs, &mut python_name)?;
ensure_signatures_on_valid_method(&fn_type, signature.as_ref(), text_signature.as_ref())?;

let name = &sig.ident;
let ty = get_return_info(&sig.output);
let python_name = python_name.as_ref().unwrap_or(name).unraw();

let arguments: Vec<_> = if skip_first_arg {
sig.inputs
.iter_mut()
.skip(1)
.map(FnArg::parse)
.collect::<Result<_>>()?
} else {
sig.inputs
.iter_mut()
.map(FnArg::parse)
.collect::<Result<_>>()?
};
let arguments: Vec<_> = sig
.inputs
.iter_mut()
.skip(if fn_type.skip_first_rust_argument_in_python_signature() {
1
} else {
0
})
.map(FnArg::parse)
.collect::<Result<_>>()?;

let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
} else {
FunctionSignature::from_arguments(arguments)?
};

let convention =
fixed_convention.unwrap_or_else(|| CallingConvention::from_signature(&signature));
let convention = if matches!(fn_type, FnType::FnNew | FnType::FnNewClass) {
CallingConvention::TpNew
} else {
CallingConvention::from_signature(&signature)
};

Ok(FnSpec {
tp: fn_type,
Expand All @@ -314,7 +326,7 @@ impl<'a> FnSpec<'a> {
sig: &syn::Signature,
meth_attrs: &mut Vec<syn::Attribute>,
python_name: &mut Option<syn::Ident>,
) -> Result<(FnType, bool, Option<CallingConvention>)> {
) -> Result<FnType> {
let mut method_attributes = parse_method_attributes(meth_attrs)?;

let name = &sig.ident;
Expand All @@ -334,16 +346,12 @@ impl<'a> FnSpec<'a> {
.map(|stripped| syn::Ident::new(stripped, name.span()))
};

let (fn_type, skip_first_arg, fixed_convention) = match method_attributes.as_mut_slice() {
[] => (
FnType::Fn(parse_receiver(
"static method needs #[staticmethod] attribute",
)?),
true,
None,
),
[MethodTypeAttribute::StaticMethod(_)] => (FnType::FnStatic, false, None),
[MethodTypeAttribute::ClassAttribute(_)] => (FnType::ClassAttribute, false, None),
let fn_type = match method_attributes.as_mut_slice() {
[] => FnType::Fn(parse_receiver(
"static method needs #[staticmethod] attribute",
)?),
[MethodTypeAttribute::StaticMethod(_)] => FnType::FnStatic,
[MethodTypeAttribute::ClassAttribute(_)] => FnType::ClassAttribute,
[MethodTypeAttribute::New(_)]
| [MethodTypeAttribute::New(_), MethodTypeAttribute::ClassMethod(_)]
| [MethodTypeAttribute::ClassMethod(_), MethodTypeAttribute::New(_)] => {
Expand All @@ -352,12 +360,12 @@ impl<'a> FnSpec<'a> {
}
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
if matches!(method_attributes.as_slice(), [MethodTypeAttribute::New(_)]) {
(FnType::FnNew, false, Some(CallingConvention::TpNew))
FnType::FnNew
} else {
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
FnType::FnNewClass
}
}
[MethodTypeAttribute::ClassMethod(_)] => (FnType::FnClass, true, None),
[MethodTypeAttribute::ClassMethod(_)] => FnType::FnClass,
[MethodTypeAttribute::Getter(_, name)] => {
if let Some(name) = name.take() {
ensure_spanned!(
Expand All @@ -369,11 +377,7 @@ impl<'a> FnSpec<'a> {
*python_name = strip_fn_name("get_");
}

(
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?),
true,
None,
)
FnType::Getter(parse_receiver("expected receiver for `#[getter]`")?)
}
[MethodTypeAttribute::Setter(_, name)] => {
if let Some(name) = name.take() {
Expand All @@ -386,11 +390,7 @@ impl<'a> FnSpec<'a> {
*python_name = strip_fn_name("set_");
}

(
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?),
true,
None,
)
FnType::Setter(parse_receiver("expected receiver for `#[setter]`")?)
}
[first, rest @ .., last] => {
// Join as many of the spans together as possible
Expand All @@ -416,7 +416,7 @@ impl<'a> FnSpec<'a> {
bail_spanned!(span => msg)
}
};
Ok((fn_type, skip_first_arg, fixed_convention))
Ok(fn_type)
}

/// Return a C wrapper function for this signature.
Expand Down

0 comments on commit 6c3c505

Please sign in to comment.