Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow-ups from recent type system changes #713

Merged
merged 1 commit into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 25 additions & 41 deletions legate/core/_lib/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector

import cython
Expand Down Expand Up @@ -133,53 +134,51 @@ cdef extern from "core/type/type_info.h" namespace "legate" nogil:
unique_ptr[Type] element_type, unsigned int N
) except+

cdef unique_ptr[Type] struct_type_raw_ptrs(
vector[Type*] field_types, bool
cdef unique_ptr[Type] struct_type(
vector[unique_ptr[Type]] field_types, bool
) except+


cdef Dtype from_ptr(Type* ty):
if <int> ty.code == FIXED_ARRAY:
return FixedArrayDtype.from_ptr(ty)
elif <int> ty.code == STRUCT:
return StructDtype.from_ptr(ty)
cdef Dtype from_ptr(unique_ptr[Type] ty):
cdef Dtype dtype
if <int> ty.get().code == FIXED_ARRAY:
dtype = FixedArrayDtype.__new__(FixedArrayDtype)
elif <int> ty.get().code == STRUCT:
dtype = StructDtype.__new__(StructDtype)
else:
return Dtype.from_ptr(ty)
dtype = Dtype.__new__(Dtype)
dtype._type = move(ty)
return dtype


cdef class Dtype:
cdef unique_ptr[Type] _type

@staticmethod
cdef Dtype from_ptr(Type* ty):
cdef Dtype dtype = Dtype.__new__(Dtype)
dtype._type.reset(ty)
return dtype

@staticmethod
def primitive_type(int code) -> Dtype:
return Dtype.from_ptr(primitive_type(<Type.Code> code).release())
return from_ptr(move(primitive_type(<Type.Code> code)))

@staticmethod
def string_type() -> Dtype:
return Dtype.from_ptr(string_type().release())
return from_ptr(move(string_type()))

@staticmethod
def fixed_array_type(
Dtype element_type, unsigned N
) -> FixedArrayDtype:
return FixedArrayDtype.from_ptr(
fixed_array_type(element_type._type.get().clone(), N).release()
return <FixedArrayDtype> from_ptr(
move(fixed_array_type(element_type._type.get().clone(), N))
)

@staticmethod
def struct_type(list field_types, bool align) -> StructDtype:
cdef vector[Type*] types
cdef vector[unique_ptr[Type]] types
for field_type in field_types:
types.push_back(
cython.cast(Dtype, field_type)._type.get().clone().release()
move((<Dtype> field_type)._type.get().clone())
)
return StructDtype.from_ptr(
struct_type_raw_ptrs(types, align).release()
return <StructDtype> from_ptr(
move(struct_type(move(types), align))
)

@property
Expand Down Expand Up @@ -220,21 +219,13 @@ cdef class Dtype:


cdef class FixedArrayDtype(Dtype):
@staticmethod
cdef FixedArrayDtype from_ptr(Type* ty):
cdef FixedArrayDtype dtype = FixedArrayDtype.__new__(
FixedArrayDtype
)
dtype._type.reset(ty)
return dtype

def num_elements(self) -> int:
cdef FixedArrayType* ty = <FixedArrayType*> self._type.get()
return ty.num_elements()

def element_type(self) -> Dtype:
cdef FixedArrayType* ty = <FixedArrayType*> self._type.get()
return from_ptr(ty.element_type().clone().release())
return from_ptr(move(ty.element_type().clone()))

def to_numpy_dtype(self):
arr_type = (
Expand All @@ -247,25 +238,18 @@ cdef class FixedArrayDtype(Dtype):
def serialize(self, buf) -> None:
buf.pack_32bit_int(self.code)
buf.pack_32bit_int(self.uid)
buf.pack_32bit_int(self.num_elements())
buf.pack_32bit_uint(self.num_elements())
self.element_type().serialize(buf)


cdef class StructDtype(Dtype):
@staticmethod
cdef StructDtype from_ptr(Type* ty):
cdef StructDtype dtype = StructDtype.__new__(StructDtype)
dtype._type.reset(ty)
return dtype

def num_fields(self) -> int:
cdef StructType* ty = <StructType*> self._type.get()
return ty.num_fields()

def field_type(self, int field_idx) -> Dtype:
cdef StructType* ty = <StructType*> self._type.get()
field_type = ty.field_type(field_idx).clone().release()
return from_ptr(field_type)
return from_ptr(move(ty.field_type(field_idx).clone()))

def aligned(self) -> bool:
cdef StructType* ty = <StructType*> self._type.get()
Expand All @@ -288,6 +272,6 @@ cdef class StructDtype(Dtype):
buf.pack_32bit_int(self.code)
num_fields = self.num_fields()
buf.pack_32bit_int(self.uid)
buf.pack_32bit_int(num_fields)
buf.pack_32bit_uint(num_fields)
for field_idx in range(num_fields):
self.field_type(field_idx).serialize(buf)
4 changes: 2 additions & 2 deletions src/core/data/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Scalar {
* @brief Creates a shared `Scalar` with an existing allocation. The caller is responsible
* for passing in a sufficiently big allocation.
*
* @param code Type code of the scalar(s)
* @param type Type of the scalar(s)
* @param data Allocation containing the data.
*/
Scalar(std::unique_ptr<Type> type, const void* data);
Expand Down Expand Up @@ -86,7 +86,7 @@ class Scalar {
*
* @return Data type
*/
const Type* type() const { return type_.get(); }
const Type& type() const { return *type_; }
/**
* @brief Returns the size of allocation for the `Scalar`.
*
Expand Down
9 changes: 0 additions & 9 deletions src/core/type/type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,6 @@ std::unique_ptr<Type> struct_type(std::vector<std::unique_ptr<Type>>&& field_typ
align);
}

std::unique_ptr<Type> struct_type_raw_ptrs(std::vector<Type*> _field_types,
bool align) noexcept(false)
{
std::vector<std::unique_ptr<Type>> field_types;
for (auto field_type : _field_types) field_types.emplace_back(field_type);
return std::make_unique<StructType>(
Runtime::get_runtime()->get_type_uid(), std::move(field_types), align);
}

std::ostream& operator<<(std::ostream& ostream, const Type::Code& code)
{
ostream << static_cast<int32_t>(code);
Expand Down
4 changes: 0 additions & 4 deletions src/core/type/type_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,6 @@ std::unique_ptr<Type> fixed_array_type(std::unique_ptr<Type> element_type,
std::unique_ptr<Type> struct_type(std::vector<std::unique_ptr<Type>>&& field_types,
bool align = false) noexcept(false);

// The caller transfers ownership of the Type objects
std::unique_ptr<Type> struct_type_raw_ptrs(std::vector<Type*> field_types,
bool align = false) noexcept(false);

std::ostream& operator<<(std::ostream&, const Type::Code&);

std::ostream& operator<<(std::ostream&, const Type&);
Expand Down