Skip to content

Commit

Permalink
update R to changes from ARROW-3144 apache#4316
Browse files Browse the repository at this point in the history
  • Loading branch information
romainfrancois committed May 29, 2019
1 parent 49714fb commit 99f2b6b
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 65 deletions.
12 changes: 6 additions & 6 deletions r/R/RcppExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions r/R/dictionary.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@

active = list(
index_type = function() `arrow::DataType`$dispatch(DictionaryType__index_type(self)),
dictionary = function() shared_ptr(`arrow::Array`, DictionaryType__dictionary(self)),
value_type = function() `arrow::DataType`$dispatch(DictionaryType__value_type(self)),
name = function() DictionaryType__name(self),
ordered = function() DictionaryType__ordered(self)
)
)

#' dictionary type factory
#'
#' @param type indices type, e.g. [int32()]
#' @param values values array, typically an arrow array of strings
#' @param ordered Is this an ordered dictionary
#' @param index_type index type, e.g. [int32()]
#' @param value_type value type, probably [utf8()]
#' @param ordered Is this an ordered dictionary ?
#'
#' @return a [arrow::DictionaryType][arrow__DictionaryType]
#'
#' @export
dictionary <- function(type, values, ordered = FALSE) {
dictionary <- function(index_type, value_type, ordered = FALSE) {
assert_that(
inherits(type, "arrow::DataType"),
inherits(values, "arrow::Array")
# both the index type and the value type must be arrow::DataType objects
inherits(index_type, "arrow::DataType"),
inherits(value_type, "arrow::DataType")
)
shared_ptr(`arrow::DictionaryType`, DictionaryType__initialize(type, values, ordered))
shared_ptr(`arrow::DictionaryType`, DictionaryType__initialize(index_type, value_type, ordered))
}
28 changes: 14 additions & 14 deletions r/src/RcppExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 11 additions & 24 deletions r/src/array_from_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ std::shared_ptr<Array> MakeFactorArrayImpl(Rcpp::IntegerVector_ factor,
ArrayData::Make(std::make_shared<Type>(), n, std::move(buffers), null_count, 0);
auto array_indices = MakeArray(array_indices_data);

SEXP levels = Rf_getAttrib(factor, R_LevelsSymbol);
auto dict = MakeStringArray(levels);

std::shared_ptr<Array> out;
STOP_IF_NOT_OK(DictionaryArray::FromArrays(type, array_indices, &out));
STOP_IF_NOT_OK(DictionaryArray::FromArrays(type, array_indices, dict, &out));
return out;
}

Expand Down Expand Up @@ -741,22 +744,20 @@ Status GetConverter(const std::shared_ptr<DataType>& type,
}

// Builds the Arrow dictionary type used for an R factor, with indices of the
// template parameter `Type`. The `bool ordered` form describes the type only by
// its index type, a utf8 value type and the `ordered` flag.
template <typename Type>
std::shared_ptr<arrow::DataType> GetFactorTypeImpl(Rcpp::IntegerVector_ factor) {
auto dict_values = MakeStringArray(Rf_getAttrib(factor, R_LevelsSymbol));
auto dict_type =
dictionary(std::make_shared<Type>(), dict_values, Rf_inherits(factor, "ordered"));
return dict_type;
std::shared_ptr<arrow::DataType> GetFactorTypeImpl(bool ordered) {
return dictionary(std::make_shared<Type>(), arrow::utf8(), ordered);
}

// Chooses the dictionary type for an R factor. The index type is picked from
// the number of levels: fewer than 128 -> Int8, fewer than 32768 -> Int16,
// otherwise Int32. The type is ordered when the factor inherits from "ordered".
std::shared_ptr<arrow::DataType> GetFactorType(SEXP factor) {
SEXP levels = Rf_getAttrib(factor, R_LevelsSymbol);
bool is_ordered = Rf_inherits(factor, "ordered");
int n = Rf_length(levels);
if (n < 128) {
return GetFactorTypeImpl<arrow::Int8Type>(factor);
return GetFactorTypeImpl<arrow::Int8Type>(is_ordered);
} else if (n < 32768) {
return GetFactorTypeImpl<arrow::Int16Type>(factor);
return GetFactorTypeImpl<arrow::Int16Type>(is_ordered);
} else {
return GetFactorTypeImpl<arrow::Int32Type>(factor);
return GetFactorTypeImpl<arrow::Int32Type>(is_ordered);
}
}

Expand Down Expand Up @@ -909,21 +910,7 @@ bool CheckCompatibleFactor(SEXP obj, const std::shared_ptr<arrow::DataType>& typ

arrow::DictionaryType* dict_type =
arrow::checked_cast<arrow::DictionaryType*>(type.get());
auto dictionary = dict_type->dictionary();
if (dictionary->type() != utf8()) return false;

// then compare levels
auto typed_dict = checked_cast<arrow::StringArray*>(dictionary.get());
SEXP levels = Rf_getAttrib(obj, R_LevelsSymbol);

R_xlen_t n = XLENGTH(levels);
if (n != typed_dict->length()) return false;

for (R_xlen_t i = 0; i < n; i++) {
if (typed_dict->GetString(i) != CHAR(STRING_ELT(levels, i))) return false;
}

return true;
return dict_type->value_type() == utf8();
}

std::shared_ptr<arrow::Array> Array__from_vector(
Expand Down
16 changes: 8 additions & 8 deletions r/src/datatype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ arrow::TimeUnit::type TimestampType__unit(

// Creates a dictionary data type by forwarding its arguments unchanged to
// arrow::dictionary().
// [[Rcpp::export]]
std::shared_ptr<arrow::DataType> DictionaryType__initialize(
const std::shared_ptr<arrow::DataType>& type,
const std::shared_ptr<arrow::Array>& array, bool ordered) {
return arrow::dictionary(type, array, ordered);
const std::shared_ptr<arrow::DataType>& index_type,
const std::shared_ptr<arrow::DataType>& value_type, bool ordered) {
return arrow::dictionary(index_type, value_type, ordered);
}

// [[Rcpp::export]]
Expand All @@ -262,14 +262,14 @@ std::shared_ptr<arrow::DataType> DictionaryType__index_type(
}

// [[Rcpp::export]]
std::string DictionaryType__name(const std::shared_ptr<arrow::DictionaryType>& type) {
return type->name();
std::shared_ptr<arrow::DataType> DictionaryType__value_type(
const std::shared_ptr<arrow::DictionaryType>& type) {
return type->value_type();
}

// [[Rcpp::export]]
std::shared_ptr<arrow::Array> DictionaryType__dictionary(
const std::shared_ptr<arrow::DictionaryType>& type) {
return type->dictionary();
std::string DictionaryType__name(const std::shared_ptr<arrow::DictionaryType>& type) {
return type->name();
}

// [[Rcpp::export]]
Expand Down
9 changes: 7 additions & 2 deletions r/src/message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,20 @@ std::shared_ptr<arrow::RecordBatch> ipc___ReadRecordBatch__Message__Schema(
const std::unique_ptr<arrow::ipc::Message>& message,
const std::shared_ptr<arrow::Schema>& schema) {
std::shared_ptr<arrow::RecordBatch> batch;
STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(*message, schema, &batch));

// TODO: perhaps this should come from the R side
arrow::ipc::DictionaryMemo memo;
STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(*message, schema, &memo, &batch));
return batch;
}

// Reads a Schema from `stream` and returns it; the Status of the read is
// checked with STOP_IF_NOT_OK. The DictionaryMemo passed to ReadSchema is a
// local of this function and is not handed back to the caller.
// [[Rcpp::export]]
std::shared_ptr<arrow::Schema> ipc___ReadSchema_InputStream(
const std::shared_ptr<arrow::io::InputStream>& stream) {
std::shared_ptr<arrow::Schema> schema;
STOP_IF_NOT_OK(arrow::ipc::ReadSchema(stream.get(), &schema));
// TODO: promote to function argument
arrow::ipc::DictionaryMemo memo;
STOP_IF_NOT_OK(arrow::ipc::ReadSchema(stream.get(), &memo, &schema));
return schema;
}

Expand Down
4 changes: 3 additions & 1 deletion r/src/recordbatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ std::shared_ptr<arrow::RecordBatch> ipc___ReadRecordBatch__InputStream__Schema(
const std::shared_ptr<arrow::io::InputStream>& stream,
const std::shared_ptr<arrow::Schema>& schema) {
std::shared_ptr<arrow::RecordBatch> batch;
STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(schema, stream.get(), &batch));
// TODO: promote to function arg
arrow::ipc::DictionaryMemo memo;
STOP_IF_NOT_OK(arrow::ipc::ReadRecordBatch(schema, &memo, stream.get(), &batch));
return batch;
}
4 changes: 2 additions & 2 deletions r/tests/testthat/test-DataType.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,13 @@ test_that("struct type works as expected", {
})

# Exercises a dictionary type: self-equality, inequality with int32(), type id,
# bit width, string representation, and the index/value type accessors.
test_that("DictionaryType works as expected (ARROW-3355)", {
d <- dictionary(int32(), array(c("foo", "bar", "baz")))
d <- dictionary(int32(), utf8())
expect_equal(d, d)
expect_true(d == d)
expect_false(d == int32())
expect_equal(d$id, Type$DICTIONARY)
expect_equal(d$bit_width, 32L)
expect_equal(d$ToString(), "dictionary<values=string, indices=int32, ordered=0>")
expect_equal(d$index_type, int32())
expect_equal(d$dictionary, array(c("foo", "bar", "baz")))
expect_equal(d$value_type, utf8())
})

0 comments on commit 99f2b6b

Please sign in to comment.