diff --git a/crates/rune-core/src/protocol.rs b/crates/rune-core/src/protocol.rs index b3304609a..06e7a5ae2 100644 --- a/crates/rune-core/src/protocol.rs +++ b/crates/rune-core/src/protocol.rs @@ -414,5 +414,19 @@ impl Protocol { repr: None, doc: ["Test if the provided argument is a variant."], }; + + /// Function used for the question mark operation. + /// + /// Signature: `fn(self) -> Result`. + /// + /// Note that it uses the `Result` like [`std::ops::Try`] uses + /// [`ControlFlow`](std::ops::ControlFlow) i.e., for `Result::` + /// it should return `Result>` + pub const TRY: Protocol = Protocol { + name: "try", + hash: 0x5da1a80787003354, + repr: Some("value?"), + doc: ["Allows the `?` operator to apply to values of this type."], + }; } } diff --git a/crates/rune/src/runtime/vm.rs b/crates/rune/src/runtime/vm.rs index 5d0ba0ff9..c517673b4 100644 --- a/crates/rune/src/runtime/vm.rs +++ b/crates/rune/src/runtime/vm.rs @@ -2373,27 +2373,41 @@ impl Vm { fn op_try(&mut self, address: InstAddress, clean: usize, preserve: bool) -> VmResult { let return_value = vm_try!(self.stack.address(address)); - let unwrapped_value = match &return_value { - Value::Result(result) => match &*vm_try!(result.borrow_ref()) { - Result::Ok(value) => Some(value.clone()), - Result::Err(..) => None, - }, - Value::Option(option) => (*vm_try!(option.borrow_ref())).clone(), - other => { - return err(VmErrorKind::UnsupportedTryOperand { - actual: vm_try!(other.type_info()), - }); + let result = match &return_value { + Value::Result(result) => { + let ok = match &*vm_try!(result.borrow_ref()) { + Result::Ok(value) => Some(value.clone()), + Result::Err(..) => None, + }; + ok.ok_or(return_value) + } + Value::Option(option) => { + let some = (*vm_try!(option.borrow_ref())).clone(); + some.ok_or(return_value) + } + _ => { + if let CallResult::Unsupported(target) = + vm_try!(self.call_instance_fn(return_value, Protocol::TRY, ())) + { + return err(VmErrorKind::UnsupportedTryOperand { + actual: vm_try!(target.type_info()), + }); + } + vm_try!(>::from_value(vm_try!(self + .stack + .pop()))) } }; - if let Some(value) = unwrapped_value { - if preserve { - self.stack.push(value); - } + match result { + Ok(value) => { + if preserve { + self.stack.push(value); + } - VmResult::Ok(false) - } else { - VmResult::Ok(vm_try!(self.op_return_internal(return_value, clean))) + VmResult::Ok(false) + } + Err(err) => VmResult::Ok(vm_try!(self.op_return_internal(err, clean))), } } diff --git a/crates/rune/src/tests/vm_try.rs b/crates/rune/src/tests/vm_try.rs index a4e66fb05..0fb6e17e7 100644 --- a/crates/rune/src/tests/vm_try.rs +++ b/crates/rune/src/tests/vm_try.rs @@ -42,3 +42,34 @@ fn test_unwrap() { }; assert_eq!(out, Err(3)); } + +#[test] +fn custom_try() -> Result<()> { + #[derive(Any)] + struct CustomResult(bool); + let mut module = Module::new(); + module.ty::()?; + module.associated_function(Protocol::TRY, |r: CustomResult| { + r.0.then_some(42).ok_or(Err::<(), _>(0)) + })?; + + assert_eq!( + 42, + rune_n! { + &module, + (CustomResult(true),), + i64 => pub fn main(r) { r? } + } + ); + + assert_eq!( + Err(0), + rune_n! { + &module, + (CustomResult(false),), + Result<(), i64> => pub fn main(r) { r? } + } + ); + + Ok(()) +}