Skip to content

Commit

Permalink
[embind] Support annotating return values as nonnull.
Browse files Browse the repository at this point in the history
Embind's subclass `implement` methods were generated as returning
`Class | null` after the changes to pointer types in emscripten-core#22184. This could
be considered a regression as the implement method would never return
null.

Previously, we had special handling so constructors were marked as
nonnull so in the TS definitions we didn't add `| null`. I've generalized
this approach to work for all function bindings so they can now use
a `nonnull<ret_val>` policy too avoid the `| null`.
  • Loading branch information
brendandahl committed Sep 13, 2024
1 parent d0d9966 commit e651147
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 24 deletions.
6 changes: 4 additions & 2 deletions src/embind/embind.js
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,8 @@ var LibraryEmbind = {
rawInvoker,
context,
isPureVirtual,
isAsync) => {
isAsync,
isNonnullReturn) => {
var rawArgTypes = heap32VectorToArray(argCount, rawArgTypesAddr);
methodName = readLatin1String(methodName);
methodName = getFunctionName(methodName);
Expand Down Expand Up @@ -2077,7 +2078,8 @@ var LibraryEmbind = {
invokerSignature,
rawInvoker,
fn,
isAsync) => {
isAsync,
isNonnullReturn) => {
var rawArgTypes = heap32VectorToArray(argCount, rawArgTypesAddr);
methodName = readLatin1String(methodName);
methodName = getFunctionName(methodName);
Expand Down
24 changes: 13 additions & 11 deletions src/embind/embind_gen.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ var LibraryEmbind = {
},
$FunctionDefinition__deps: ['$createJsInvoker', '$createJsInvokerSignature', '$emittedFunctions'],
$FunctionDefinition: class {
constructor(name, returnType, argumentTypes, functionIndex, thisType = null, isConstructor = false, isAsync = false) {
constructor(name, returnType, argumentTypes, functionIndex, thisType = null, isNonnullReturn = false, isAsync = false) {
this.name = name;
this.returnType = returnType;
this.argumentTypes = argumentTypes;
this.functionIndex = functionIndex;
this.thisType = thisType;
this.isConstructor = isConstructor;
this.isNonnullReturn = isNonnullReturn;
this.isAsync = isAsync;
}

Expand Down Expand Up @@ -80,7 +80,7 @@ var LibraryEmbind = {
// Constructors can return a pointer, but it will be a non-null pointer.
// Change the return type to the class type so the TS output doesn't
// have `| null`.
if (this.isConstructor && this.returnType instanceof PointerDefinition) {
if (this.isNonnullReturn && this.returnType instanceof PointerDefinition) {
returnType = this.returnType.classType;
}
out.push(`): ${nameMap(returnType, true)}`);
Expand Down Expand Up @@ -463,7 +463,7 @@ var LibraryEmbind = {
registerType(id, new IntegerType(id));
},
$createFunctionDefinition__deps: ['$FunctionDefinition', '$heap32VectorToArray', '$readLatin1String', '$Argument', '$whenDependentTypesAreResolved', '$getFunctionName', '$getFunctionArgsName', '$PointerDefinition', '$ClassDefinition'],
$createFunctionDefinition: (name, argCount, rawArgTypesAddr, functionIndex, hasThis, isConstructor, isAsync, cb) => {
$createFunctionDefinition: (name, argCount, rawArgTypesAddr, functionIndex, hasThis, isNonnullReturn, isAsync, cb) => {
const argTypes = heap32VectorToArray(argCount, rawArgTypesAddr);
name = typeof name === 'string' ? name : readLatin1String(name);

Expand Down Expand Up @@ -493,7 +493,7 @@ var LibraryEmbind = {
args.push(new Argument(`_${i - argStart}`, argTypes[i]));
}
}
const funcDef = new FunctionDefinition(name, returnType, args, functionIndex, thisType, isConstructor, isAsync);
const funcDef = new FunctionDefinition(name, returnType, args, functionIndex, thisType, isNonnullReturn, isAsync);
cb(funcDef);
return [];
});
Expand Down Expand Up @@ -544,8 +544,8 @@ var LibraryEmbind = {
// TODO
},
_embind_register_function__deps: ['$moduleDefinitions', '$createFunctionDefinition'],
_embind_register_function: (name, argCount, rawArgTypesAddr, signature, rawInvoker, fn, isAsync) => {
createFunctionDefinition(name, argCount, rawArgTypesAddr, fn, false, false, isAsync, (funcDef) => {
_embind_register_function: (name, argCount, rawArgTypesAddr, signature, rawInvoker, fn, isAsync, isNonnullReturn) => {
createFunctionDefinition(name, argCount, rawArgTypesAddr, fn, false, isNonnullReturn, isAsync, (funcDef) => {
moduleDefinitions.push(funcDef);
});
},
Expand Down Expand Up @@ -605,8 +605,9 @@ var LibraryEmbind = {
rawInvoker,
context,
isPureVirtual,
isAsync) {
createFunctionDefinition(methodName, argCount, rawArgTypesAddr, context, true, false, isAsync, (funcDef) => {
isAsync,
isNonnullReturn) {
createFunctionDefinition(methodName, argCount, rawArgTypesAddr, context, true, isNonnullReturn, isAsync, (funcDef) => {
const classDef = funcDef.thisType;
classDef.methods.push(funcDef);
});
Expand Down Expand Up @@ -644,10 +645,11 @@ var LibraryEmbind = {
invokerSignature,
rawInvoker,
fn,
isAsync) {
isAsync,
isNonnullReturn) {
whenDependentTypesAreResolved([], [rawClassType], function(classType) {
classType = classType[0];
createFunctionDefinition(methodName, argCount, rawArgTypesAddr, fn, false, false, isAsync, (funcDef) => {
createFunctionDefinition(methodName, argCount, rawArgTypesAddr, fn, false, isNonnullReturn, isAsync, (funcDef) => {
classType.staticMethods.push(funcDef);
});
return [];
Expand Down
64 changes: 53 additions & 11 deletions system/include/emscripten/bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ void _embind_register_function(
const char* signature,
GenericFunction invoker,
GenericFunction function,
bool isAsync);
bool isAsync,
bool isNonnullReturn);

void _embind_register_value_array(
TYPEID tupleType,
Expand Down Expand Up @@ -182,7 +183,8 @@ void _embind_register_class_function(
GenericFunction invoker,
void* context,
unsigned isPureVirtual,
bool isAsync);
bool isAsync,
bool isNonnullReturn);

void _embind_register_class_property(
TYPEID classType,
Expand All @@ -204,7 +206,8 @@ void _embind_register_class_class_function(
const char* invokerSignature,
GenericFunction invoker,
GenericFunction method,
bool isAsync);
bool isAsync,
bool isNonnullReturn);

void _embind_register_class_class_property(
TYPEID classType,
Expand Down Expand Up @@ -338,6 +341,15 @@ struct pure_virtual {
};
};

template<typename Slot>
struct nonnull {
static_assert(std::is_same<Slot, ret_val>::value, "Only nonnull return values are currently supported.");
template<typename InputType, int Index>
struct Transform {
typedef InputType type;
};
};

namespace return_value_policy {

struct take_ownership : public allow_raw_pointers {};
Expand Down Expand Up @@ -380,6 +392,11 @@ struct isPolicy<emscripten::pure_virtual, Rest...> {
static constexpr bool value = true;
};

template<typename T, typename... Rest>
struct isPolicy<emscripten::nonnull<T>, Rest...> {
static constexpr bool value = true;
};

template<typename T, typename... Rest>
struct isPolicy<T, Rest...> {
static constexpr bool value = isPolicy<Rest...>::value;
Expand Down Expand Up @@ -428,6 +445,24 @@ struct isAsync<> {
static constexpr bool value = false;
};

template<typename... Policies>
struct isNonnullReturn;

template<typename... Rest>
struct isNonnullReturn<nonnull<ret_val>, Rest...> {
static constexpr bool value = true;
};

template<typename T, typename... Rest>
struct isNonnullReturn<T, Rest...> {
static constexpr bool value = isNonnullReturn<Rest...>::value;
};

template<>
struct isNonnullReturn<> {
static constexpr bool value = false;
};

}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -640,7 +675,8 @@ void function(const char* name, ReturnType (*fn)(Args...), Policies...) {
getSignature(invoke),
reinterpret_cast<GenericFunction>(invoke),
reinterpret_cast<GenericFunction>(fn),
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}

namespace internal {
Expand Down Expand Up @@ -1516,7 +1552,8 @@ struct RegisterClassMethod<ReturnType (ClassType::*)(Args...)> {
reinterpret_cast<GenericFunction>(invoke),
getContext(memberFunction),
isPureVirtual<Policies...>::value,
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}
};

Expand Down Expand Up @@ -1545,7 +1582,8 @@ struct RegisterClassMethod<ReturnType (ClassType::*)(Args...) const> {
reinterpret_cast<GenericFunction>(invoke),
getContext(memberFunction),
isPureVirtual<Policies...>::value,
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}
};

Expand Down Expand Up @@ -1573,7 +1611,8 @@ struct RegisterClassMethod<ReturnType (*)(ThisType, Args...)> {
reinterpret_cast<GenericFunction>(invoke),
getContext(function),
false,
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}
};

Expand Down Expand Up @@ -1601,7 +1640,8 @@ struct RegisterClassMethod<std::function<ReturnType (ThisType, Args...)>> {
reinterpret_cast<GenericFunction>(invoke),
getContext(function),
false,
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}
};

Expand All @@ -1623,7 +1663,8 @@ struct RegisterClassMethod<ReturnType (ThisType, Args...)> {
reinterpret_cast<GenericFunction>(invoke),
getContext(callable),
false,
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
}
};

Expand Down Expand Up @@ -1752,7 +1793,7 @@ class class_ {
class_function(
"implement",
&wrapped_new<WrapperType*, WrapperType, val, ConstructorArgs...>,
allow_raw_pointer<ret_val>())
allow_raw_pointer<ret_val>(), nonnull<ret_val>())
.class_function(
"extend",
&wrapped_extend<WrapperType>)
Expand Down Expand Up @@ -1940,7 +1981,8 @@ class class_ {
getSignature(invoke),
reinterpret_cast<GenericFunction>(invoke),
reinterpret_cast<GenericFunction>(classMethod),
isAsync<Policies...>::value);
isAsync<Policies...>::value,
isNonnullReturn<Policies...>::value);
return *this;
}

Expand Down
19 changes: 19 additions & 0 deletions test/other/embind_tsgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ int smart_ptr_function(std::shared_ptr<ClassWithSmartPtrConstructor>) {

struct Obj {};
Obj* get_pointer(Obj* ptr) { return ptr; }
Obj* get_nonnull_pointer() { return new Obj(); }

int function_with_callback_param(CallbackType ct) {
ct(val("hello"));
Expand Down Expand Up @@ -127,6 +128,18 @@ class DerivedClass : public BaseClass {
int fn2(int x) { return 2; }
};

struct Interface {
virtual void invoke(const std::string& str) = 0;
virtual ~Interface() {}
};

struct InterfaceWrapper : public wrapper<Interface> {
EMSCRIPTEN_WRAPPER(InterfaceWrapper);
void invoke(const std::string& str) {
return call<void>("invoke", str);
}
};

EMSCRIPTEN_BINDINGS(Test) {
class_<Test>("Test")
.function("functionOne", &Test::function_one)
Expand All @@ -151,6 +164,7 @@ EMSCRIPTEN_BINDINGS(Test) {
&class_unique_ptr_returning_fn);
class_<Obj>("Obj");
function("getPointer", &get_pointer, allow_raw_pointers());
function("getNonnullPointer", &get_nonnull_pointer, allow_raw_pointers(), nonnull<ret_val>());

constant("an_int", 5);
constant("a_bool", false);
Expand Down Expand Up @@ -225,6 +239,11 @@ EMSCRIPTEN_BINDINGS(Test) {

class_<DerivedClass, base<BaseClass>>("DerivedClass")
.function("fn2", &DerivedClass::fn2);

class_<Interface>("Interface")
.function("invoke", &Interface::invoke, pure_virtual())
.allow_subclass<InterfaceWrapper>("InterfaceWrapper")
;
}

int Test::static_property = 42;
Expand Down
16 changes: 16 additions & 0 deletions test/other/embind_tsgen.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ export interface DerivedClass extends BaseClass {
delete(): void;
}

export interface Interface {
invoke(_0: EmbindString): void;
delete(): void;
}

export interface InterfaceWrapper extends Interface {
notifyOnDestruction(): void;
delete(): void;
}

export type ValArr = [ number, number, number ];

export type ValObj = {
Expand All @@ -117,6 +127,7 @@ interface EmbindModule {
class_unique_ptr_returning_fn(): Test;
Obj: {};
getPointer(_0: Obj | null): Obj | null;
getNonnullPointer(): Obj;
a_class_instance: Test;
an_enum: Bar;
Bar: {valueOne: BarValue<0>, valueTwo: BarValue<1>, valueThree: BarValue<2>};
Expand All @@ -141,6 +152,11 @@ interface EmbindModule {
};
BaseClass: {};
DerivedClass: {};
Interface: {
implement(_0: any): InterfaceWrapper;
extend(_0: EmbindString, _1: any): any;
};
InterfaceWrapper: {};
a_bool: boolean;
an_int: number;
optional_test(_0?: Foo): number | undefined;
Expand Down
16 changes: 16 additions & 0 deletions test/other/embind_tsgen_ignore_1.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ export interface DerivedClass extends BaseClass {
delete(): void;
}

export interface Interface {
invoke(_0: EmbindString): void;
delete(): void;
}

export interface InterfaceWrapper extends Interface {
notifyOnDestruction(): void;
delete(): void;
}

export type ValArr = [ number, number, number ];

export type ValObj = {
Expand All @@ -126,6 +136,7 @@ interface EmbindModule {
class_unique_ptr_returning_fn(): Test;
Obj: {};
getPointer(_0: Obj | null): Obj | null;
getNonnullPointer(): Obj;
a_class_instance: Test;
an_enum: Bar;
Bar: {valueOne: BarValue<0>, valueTwo: BarValue<1>, valueThree: BarValue<2>};
Expand All @@ -150,6 +161,11 @@ interface EmbindModule {
};
BaseClass: {};
DerivedClass: {};
Interface: {
implement(_0: any): InterfaceWrapper;
extend(_0: EmbindString, _1: any): any;
};
InterfaceWrapper: {};
a_bool: boolean;
an_int: number;
optional_test(_0?: Foo): number | undefined;
Expand Down
Loading

0 comments on commit e651147

Please sign in to comment.