Skip to content

Commit

Permalink
Merge master into multiparty (#134)
Browse files Browse the repository at this point in the history
* Interval (#116)

* add date_add, interval sql still running into issues

* Add Interval SQL support

* uncomment out the other tests

* resolve comments

* change interval equality

Co-authored-by: Eric Feng <[email protected]>

* Fix NULL handling for aggregation (#130)

* Modify COUNT and SUM to correctly handle NULL values

* Change average to support NULL values

* Fix

* Changing operator matching from logical to physical (#129)

* WIP

* Fix

* Unapply change

* Aggregation rewrite (#132)

Co-authored-by: Eric Feng <[email protected]>
Co-authored-by: Wenting Zheng <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2021
1 parent 9c87e8e commit 8cfa2a1
Show file tree
Hide file tree
Showing 17 changed files with 687 additions and 468 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ concurrentRestrictions in Global := Seq(
fork in Test := true
fork in run := true

testOptions in Test += Tests.Argument("-oF")
javaOptions in Test ++= Seq("-Xmx2048m", "-XX:ReservedCodeCacheSize=384m")
javaOptions in run ++= Seq(
"-Xmx2048m", "-XX:ReservedCodeCacheSize=384m", "-Dspark.master=local[1]")
Expand Down
103 changes: 10 additions & 93 deletions src/enclave/App/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,8 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin(
}

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1(
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows) {
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) {
(void)obj;

jboolean if_copy;
Expand All @@ -611,98 +611,21 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1
uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);

uint8_t *first_row = nullptr;
size_t first_row_length = 0;

uint8_t *last_group = nullptr;
size_t last_group_length = 0;

uint8_t *last_row = nullptr;
size_t last_row_length = 0;

if (input_rows_ptr == nullptr) {
ocall_throw("NonObliviousAggregateStep1: JNI failed to get input byte array.");
} else {
oe_check_and_time("Non-Oblivious Aggregate Step 1",
ecall_non_oblivious_aggregate_step1(
(oe_enclave_t*)eid,
agg_op_ptr, agg_op_length,
input_rows_ptr, input_rows_length,
&first_row, &first_row_length,
&last_group, &last_group_length,
&last_row, &last_row_length));
}

jbyteArray first_row_array = env->NewByteArray(first_row_length);
env->SetByteArrayRegion(first_row_array, 0, first_row_length, (jbyte *) first_row);
free(first_row);

jbyteArray last_group_array = env->NewByteArray(last_group_length);
env->SetByteArrayRegion(last_group_array, 0, last_group_length, (jbyte *) last_group);
free(last_group);

jbyteArray last_row_array = env->NewByteArray(last_row_length);
env->SetByteArrayRegion(last_row_array, 0, last_row_length, (jbyte *) last_row);
free(last_row);

env->ReleaseByteArrayElements(agg_op, (jbyte *) agg_op_ptr, 0);
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);

jclass tuple3_class = env->FindClass("scala/Tuple3");
jobject ret = env->NewObject(
tuple3_class,
env->GetMethodID(tuple3_class, "<init>",
"(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)V"),
first_row_array, last_group_array, last_row_array);

return ret;
}

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2(
JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows,
jbyteArray next_partition_first_row, jbyteArray prev_partition_last_group,
jbyteArray prev_partition_last_row) {
(void)obj;

jboolean if_copy;

uint32_t agg_op_length = (uint32_t) env->GetArrayLength(agg_op);
uint8_t *agg_op_ptr = (uint8_t *) env->GetByteArrayElements(agg_op, &if_copy);

uint32_t input_rows_length = (uint32_t) env->GetArrayLength(input_rows);
uint8_t *input_rows_ptr = (uint8_t *) env->GetByteArrayElements(input_rows, &if_copy);

uint32_t next_partition_first_row_length =
(uint32_t) env->GetArrayLength(next_partition_first_row);
uint8_t *next_partition_first_row_ptr =
(uint8_t *) env->GetByteArrayElements(next_partition_first_row, &if_copy);

uint32_t prev_partition_last_group_length =
(uint32_t) env->GetArrayLength(prev_partition_last_group);
uint8_t *prev_partition_last_group_ptr =
(uint8_t *) env->GetByteArrayElements(prev_partition_last_group, &if_copy);

uint32_t prev_partition_last_row_length =
(uint32_t) env->GetArrayLength(prev_partition_last_row);
uint8_t *prev_partition_last_row_ptr =
(uint8_t *) env->GetByteArrayElements(prev_partition_last_row, &if_copy);

uint8_t *output_rows = nullptr;
size_t output_rows_length = 0;

bool is_partial = (bool) isPartial;

if (input_rows_ptr == nullptr) {
ocall_throw("NonObliviousAggregateStep2: JNI failed to get input byte array.");
ocall_throw("NonObliviousAggregateStep: JNI failed to get input byte array.");
} else {
oe_check_and_time("Non-Oblivious Aggregate Step 2",
ecall_non_oblivious_aggregate_step2(
oe_check_and_time("Non-Oblivious Aggregate",
ecall_non_oblivious_aggregate(
(oe_enclave_t*)eid,
agg_op_ptr, agg_op_length,
input_rows_ptr, input_rows_length,
next_partition_first_row_ptr, next_partition_first_row_length,
prev_partition_last_group_ptr, prev_partition_last_group_length,
prev_partition_last_row_ptr, prev_partition_last_row_length,
&output_rows, &output_rows_length));
&output_rows, &output_rows_length,
is_partial));
}

jbyteArray ret = env->NewByteArray(output_rows_length);
Expand All @@ -711,13 +634,7 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2

env->ReleaseByteArrayElements(agg_op, (jbyte *) agg_op_ptr, 0);
env->ReleaseByteArrayElements(input_rows, (jbyte *) input_rows_ptr, 0);
env->ReleaseByteArrayElements(
next_partition_first_row, (jbyte *) next_partition_first_row_ptr, 0);
env->ReleaseByteArrayElements(
prev_partition_last_group, (jbyte *) prev_partition_last_group_ptr, 0);
env->ReleaseByteArrayElements(
prev_partition_last_row, (jbyte *) prev_partition_last_row_ptr, 0);


return ret;
}

Expand Down
8 changes: 2 additions & 6 deletions src/enclave/App/SGXEnclave.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@ extern "C" {
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray);

JNIEXPORT jobject JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep1(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray);

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregateStep2(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray, jbyteArray, jbyteArray);
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate(
JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean);

JNIEXPORT jbyteArray JNICALL
Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_CountRowsPerPartition(
Expand Down
108 changes: 15 additions & 93 deletions src/enclave/Enclave/Aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,116 +5,38 @@
#include "FlatbuffersWriters.h"
#include "common.h"

void non_oblivious_aggregate_step1(
void non_oblivious_aggregate(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **first_row, size_t *first_row_length,
uint8_t **last_group, size_t *last_group_length,
uint8_t **last_row, size_t *last_row_length) {
uint8_t **output_rows, size_t *output_rows_length,
bool is_partial) {

FlatbuffersAggOpEvaluator agg_op_eval(agg_op, agg_op_length);
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
RowWriter first_row_writer;
RowWriter last_group_writer;
RowWriter last_row_writer;
RowWriter w;

FlatbuffersTemporaryRow prev, cur;
size_t count = 0;

while (r.has_next()) {
prev.set(cur.get());
cur.set(r.next());

if (prev.get() == nullptr) {
first_row_writer.append(cur.get());
}

if (!r.has_next()) {
last_row_writer.append(cur.get());
}


if (prev.get() != nullptr && !agg_op_eval.is_same_group(prev.get(), cur.get())) {
w.append(agg_op_eval.evaluate());
agg_op_eval.reset_group();
}
agg_op_eval.aggregate(cur.get());
count += 1;
}
last_group_writer.append(agg_op_eval.get_partial_agg());

first_row_writer.output_buffer(first_row, first_row_length);
last_group_writer.output_buffer(last_group, last_group_length);
last_row_writer.output_buffer(last_row, last_row_length);
}

void non_oblivious_aggregate_step2(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
uint8_t **output_rows, size_t *output_rows_length) {

FlatbuffersAggOpEvaluator agg_op_eval(agg_op, agg_op_length);
RowReader r(BufferRefView<tuix::EncryptedBlocks>(input_rows, input_rows_length));
RowReader next_partition_first_row_reader(
BufferRefView<tuix::EncryptedBlocks>(
next_partition_first_row, next_partition_first_row_length));
RowReader prev_partition_last_group_reader(
BufferRefView<tuix::EncryptedBlocks>(
prev_partition_last_group, prev_partition_last_group_length));
RowReader prev_partition_last_row_reader(
BufferRefView<tuix::EncryptedBlocks>(
prev_partition_last_row, prev_partition_last_row_length));
RowWriter w;

if (next_partition_first_row_reader.num_rows() > 1) {
throw std::runtime_error(
std::string("Incorrect number of starting rows from next partition passed: expected 0 or 1, got ")
+ std::to_string(next_partition_first_row_reader.num_rows()));
}
if (prev_partition_last_group_reader.num_rows() > 1) {
throw std::runtime_error(
std::string("Incorrect number of ending groups from prev partition passed: expected 0 or 1, got ")
+ std::to_string(prev_partition_last_group_reader.num_rows()));
}
if (prev_partition_last_row_reader.num_rows() > 1) {
throw std::runtime_error(
std::string("Incorrect number of ending rows from prev partition passed: expected 0 or 1, got ")
+ std::to_string(prev_partition_last_row_reader.num_rows()));
}

const tuix::Row *next_partition_first_row_ptr =
next_partition_first_row_reader.has_next() ? next_partition_first_row_reader.next() : nullptr;
agg_op_eval.set(prev_partition_last_group_reader.has_next() ?
prev_partition_last_group_reader.next() : nullptr);
const tuix::Row *prev_partition_last_row_ptr =
prev_partition_last_row_reader.has_next() ? prev_partition_last_row_reader.next() : nullptr;

FlatbuffersTemporaryRow prev, cur(prev_partition_last_row_ptr), next;
bool stop = false;
if (r.has_next()) {
next.set(r.next());
} else {
stop = true;
}
while (!stop) {
// Populate prev, cur, next to enable lookbehind and lookahead
prev.set(cur.get());
cur.set(next.get());
if (r.has_next()) {
next.set(r.next());
} else {
next.set(next_partition_first_row_ptr);
stop = true;
}

if (prev.get() != nullptr && !agg_op_eval.is_same_group(prev.get(), cur.get())) {
agg_op_eval.reset_group();
}
agg_op_eval.aggregate(cur.get());

// Output the current aggregate if it is the last aggregate for its run
if (next.get() == nullptr || !agg_op_eval.is_same_group(cur.get(), next.get())) {
w.append(agg_op_eval.evaluate());
}
// Skip outputting the final row if the number of input rows is 0 AND
// 1. It's a grouping aggregation, OR
// 2. It's a global aggregation, the mode is final
if (!(count == 0 && (agg_op_eval.get_num_grouping_keys() > 0 || (agg_op_eval.get_num_grouping_keys() == 0 && !is_partial)))) {
w.append(agg_op_eval.evaluate());
}

w.output_buffer(output_rows, output_rows_length);
}

15 changes: 3 additions & 12 deletions src/enclave/Enclave/Aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,10 @@
#ifndef AGGREGATE_H
#define AGGREGATE_H

void non_oblivious_aggregate_step1(
void non_oblivious_aggregate(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **first_row, size_t *first_row_length,
uint8_t **last_group, size_t *last_group_length,
uint8_t **last_row, size_t *last_row_length);

void non_oblivious_aggregate_step2(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
uint8_t **output_rows, size_t *output_rows_length);
uint8_t **output_rows, size_t *output_rows_length,
bool is_partial);

#endif // AGGREGATE_H
45 changes: 8 additions & 37 deletions src/enclave/Enclave/Enclave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,50 +190,21 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le
}
}

void ecall_non_oblivious_aggregate_step1(
void ecall_non_oblivious_aggregate(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t **first_row, size_t *first_row_length,
uint8_t **last_group, size_t *last_group_length,
uint8_t **last_row, size_t *last_row_length) {
uint8_t **output_rows, size_t *output_rows_length,
bool is_partial) {
// Guard against operating on arbitrary enclave memory
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
__builtin_ia32_lfence();

try {
non_oblivious_aggregate_step1(
agg_op, agg_op_length,
input_rows, input_rows_length,
first_row, first_row_length,
last_group, last_group_length,
last_row, last_row_length);
} catch (const std::runtime_error &e) {
ocall_throw(e.what());
}
}

void ecall_non_oblivious_aggregate_step2(
uint8_t *agg_op, size_t agg_op_length,
uint8_t *input_rows, size_t input_rows_length,
uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
uint8_t **output_rows, size_t *output_rows_length) {
// Guard against operating on arbitrary enclave memory
assert(oe_is_outside_enclave(input_rows, input_rows_length) == 1);
assert(oe_is_outside_enclave(next_partition_first_row, next_partition_first_row_length) == 1);
assert(oe_is_outside_enclave(prev_partition_last_group, prev_partition_last_group_length) == 1);
assert(oe_is_outside_enclave(prev_partition_last_row, prev_partition_last_row_length) == 1);
__builtin_ia32_lfence();

try {
non_oblivious_aggregate_step2(
agg_op, agg_op_length,
input_rows, input_rows_length,
next_partition_first_row, next_partition_first_row_length,
prev_partition_last_group, prev_partition_last_group_length,
prev_partition_last_row, prev_partition_last_row_length,
output_rows, output_rows_length);
non_oblivious_aggregate(agg_op, agg_op_length,
input_rows, input_rows_length,
output_rows, output_rows_length,
is_partial);

} catch (const std::runtime_error &e) {
ocall_throw(e.what());
}
Expand Down
15 changes: 3 additions & 12 deletions src/enclave/Enclave/Enclave.edl
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,11 @@ enclave {
[user_check] uint8_t *join_row, size_t join_row_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);

public void ecall_non_oblivious_aggregate_step1(
public void ecall_non_oblivious_aggregate(
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
[out] uint8_t **first_row, [out] size_t *first_row_length,
[out] uint8_t **last_group, [out] size_t *last_group_length,
[out] uint8_t **last_row, [out] size_t *last_row_length);

public void ecall_non_oblivious_aggregate_step2(
[in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length,
[user_check] uint8_t *input_rows, size_t input_rows_length,
[user_check] uint8_t *next_partition_first_row, size_t next_partition_first_row_length,
[user_check] uint8_t *prev_partition_last_group, size_t prev_partition_last_group_length,
[user_check] uint8_t *prev_partition_last_row, size_t prev_partition_last_row_length,
[out] uint8_t **output_rows, [out] size_t *output_rows_length);
[out] uint8_t **output_rows, [out] size_t *output_rows_length,
bool is_partial);

public void ecall_count_rows_per_partition(
[user_check] uint8_t *input_rows, size_t input_rows_length,
Expand Down
Loading

0 comments on commit 8cfa2a1

Please sign in to comment.