Skip to content

Commit

Permalink
Follow-ups from recent type system changes (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
manopapad authored May 4, 2023
1 parent 8ac8075 commit 28853ba
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 56 deletions.
66 changes: 25 additions & 41 deletions legate/core/_lib/types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector

import cython
Expand Down Expand Up @@ -133,53 +134,51 @@ cdef extern from "core/type/type_info.h" namespace "legate" nogil:
unique_ptr[Type] element_type, unsigned int N
) except+

cdef unique_ptr[Type] struct_type_raw_ptrs(
vector[Type*] field_types, bool
cdef unique_ptr[Type] struct_type(
vector[unique_ptr[Type]] field_types, bool
) except+


cdef Dtype from_ptr(Type* ty):
if <int> ty.code == FIXED_ARRAY:
return FixedArrayDtype.from_ptr(ty)
elif <int> ty.code == STRUCT:
return StructDtype.from_ptr(ty)
cdef Dtype from_ptr(unique_ptr[Type] ty):
cdef Dtype dtype
if <int> ty.get().code == FIXED_ARRAY:
dtype = FixedArrayDtype.__new__(FixedArrayDtype)
elif <int> ty.get().code == STRUCT:
dtype = StructDtype.__new__(StructDtype)
else:
return Dtype.from_ptr(ty)
dtype = Dtype.__new__(Dtype)
dtype._type = move(ty)
return dtype


cdef class Dtype:
cdef unique_ptr[Type] _type

@staticmethod
cdef Dtype from_ptr(Type* ty):
cdef Dtype dtype = Dtype.__new__(Dtype)
dtype._type.reset(ty)
return dtype

@staticmethod
def primitive_type(int code) -> Dtype:
return Dtype.from_ptr(primitive_type(<Type.Code> code).release())
return from_ptr(move(primitive_type(<Type.Code> code)))

@staticmethod
def string_type() -> Dtype:
return Dtype.from_ptr(string_type().release())
return from_ptr(move(string_type()))

@staticmethod
def fixed_array_type(
Dtype element_type, unsigned N
) -> FixedArrayDtype:
return FixedArrayDtype.from_ptr(
fixed_array_type(element_type._type.get().clone(), N).release()
return <FixedArrayDtype> from_ptr(
move(fixed_array_type(element_type._type.get().clone(), N))
)

@staticmethod
def struct_type(list field_types, bool align) -> StructDtype:
cdef vector[Type*] types
cdef vector[unique_ptr[Type]] types
for field_type in field_types:
types.push_back(
cython.cast(Dtype, field_type)._type.get().clone().release()
move((<Dtype> field_type)._type.get().clone())
)
return StructDtype.from_ptr(
struct_type_raw_ptrs(types, align).release()
return <StructDtype> from_ptr(
move(struct_type(move(types), align))
)

@property
Expand Down Expand Up @@ -220,21 +219,13 @@ cdef class Dtype:


cdef class FixedArrayDtype(Dtype):
@staticmethod
cdef FixedArrayDtype from_ptr(Type* ty):
cdef FixedArrayDtype dtype = FixedArrayDtype.__new__(
FixedArrayDtype
)
dtype._type.reset(ty)
return dtype

def num_elements(self) -> int:
cdef FixedArrayType* ty = <FixedArrayType*> self._type.get()
return ty.num_elements()

def element_type(self) -> Dtype:
cdef FixedArrayType* ty = <FixedArrayType*> self._type.get()
return from_ptr(ty.element_type().clone().release())
return from_ptr(move(ty.element_type().clone()))

def to_numpy_dtype(self):
arr_type = (
Expand All @@ -247,25 +238,18 @@ cdef class FixedArrayDtype(Dtype):
def serialize(self, buf) -> None:
buf.pack_32bit_int(self.code)
buf.pack_32bit_int(self.uid)
buf.pack_32bit_int(self.num_elements())
buf.pack_32bit_uint(self.num_elements())
self.element_type().serialize(buf)


cdef class StructDtype(Dtype):
@staticmethod
cdef StructDtype from_ptr(Type* ty):
cdef StructDtype dtype = StructDtype.__new__(StructDtype)
dtype._type.reset(ty)
return dtype

def num_fields(self) -> int:
cdef StructType* ty = <StructType*> self._type.get()
return ty.num_fields()

def field_type(self, int field_idx) -> Dtype:
cdef StructType* ty = <StructType*> self._type.get()
field_type = ty.field_type(field_idx).clone().release()
return from_ptr(field_type)
return from_ptr(move(ty.field_type(field_idx).clone()))

def aligned(self) -> bool:
cdef StructType* ty = <StructType*> self._type.get()
Expand All @@ -288,6 +272,6 @@ cdef class StructDtype(Dtype):
buf.pack_32bit_int(self.code)
num_fields = self.num_fields()
buf.pack_32bit_int(self.uid)
buf.pack_32bit_int(num_fields)
buf.pack_32bit_uint(num_fields)
for field_idx in range(num_fields):
self.field_type(field_idx).serialize(buf)
4 changes: 2 additions & 2 deletions src/core/data/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Scalar {
* @brief Creates a shared `Scalar` with an existing allocation. The caller is responsible
* for passing in a sufficiently big allocation.
*
* @param code Type code of the scalar(s)
* @param type Type of the scalar(s)
* @param data Allocation containing the data.
*/
Scalar(std::unique_ptr<Type> type, const void* data);
Expand Down Expand Up @@ -86,7 +86,7 @@ class Scalar {
*
* @return Data type
*/
const Type* type() const { return type_.get(); }
const Type& type() const { return *type_; }
/**
* @brief Returns the size of allocation for the `Scalar`.
*
Expand Down
9 changes: 0 additions & 9 deletions src/core/type/type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,6 @@ std::unique_ptr<Type> struct_type(std::vector<std::unique_ptr<Type>>&& field_typ
align);
}

std::unique_ptr<Type> struct_type_raw_ptrs(std::vector<Type*> _field_types,
bool align) noexcept(false)
{
std::vector<std::unique_ptr<Type>> field_types;
for (auto field_type : _field_types) field_types.emplace_back(field_type);
return std::make_unique<StructType>(
Runtime::get_runtime()->get_type_uid(), std::move(field_types), align);
}

std::ostream& operator<<(std::ostream& ostream, const Type::Code& code)
{
ostream << static_cast<int32_t>(code);
Expand Down
4 changes: 0 additions & 4 deletions src/core/type/type_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,6 @@ std::unique_ptr<Type> fixed_array_type(std::unique_ptr<Type> element_type,
std::unique_ptr<Type> struct_type(std::vector<std::unique_ptr<Type>>&& field_types,
bool align = false) noexcept(false);

// The caller transfers ownership of the Type objects
std::unique_ptr<Type> struct_type_raw_ptrs(std::vector<Type*> field_types,
bool align = false) noexcept(false);

std::ostream& operator<<(std::ostream&, const Type::Code&);

std::ostream& operator<<(std::ostream&, const Type&);
Expand Down

0 comments on commit 28853ba

Please sign in to comment.