diff --git a/include/faabric/util/bytes.h b/include/faabric/util/bytes.h index 8650262b1..609952cad 100644 --- a/include/faabric/util/bytes.h +++ b/include/faabric/util/bytes.h @@ -5,6 +5,8 @@ #include #include +#include + namespace faabric::util { std::vector stringToBytes(const std::string& str); @@ -42,4 +44,10 @@ size_t readBytesOf(const std::vector& container, std::copy_n(container.data() + offset, sizeof(T), outStart); return offset + sizeof(T); } + +template +std::vector valueToBytes(T val) +{ + return std::vector(BYTES(&val), BYTES(&val) + sizeof(T)); +} } diff --git a/include/faabric/util/snapshot.h b/include/faabric/util/snapshot.h index 2889c14e9..703b30410 100644 --- a/include/faabric/util/snapshot.h +++ b/include/faabric/util/snapshot.h @@ -1,15 +1,47 @@ #pragma once +#include +#include +#include #include #include namespace faabric::util { -struct SnapshotDiff +enum SnapshotDataType { + Raw, + Int +}; + +enum SnapshotMergeOperation +{ + Overwrite, + Sum, + Product, + Subtract, + Max, + Min +}; + +struct SnapshotMergeRegion +{ + uint32_t offset = 0; + size_t length = 0; + SnapshotDataType dataType = SnapshotDataType::Raw; + SnapshotMergeOperation operation = SnapshotMergeOperation::Overwrite; +}; + +class SnapshotDiff +{ + public: uint32_t offset = 0; size_t size = 0; const uint8_t* data = nullptr; + SnapshotDataType dataType = SnapshotDataType::Raw; + SnapshotMergeOperation operation = SnapshotMergeOperation::Overwrite; + + SnapshotDiff() = default; SnapshotDiff(uint32_t offsetIn, const uint8_t* dataIn, size_t sizeIn) { @@ -26,11 +58,21 @@ class SnapshotData uint8_t* data = nullptr; int fd = 0; + SnapshotData() = default; + std::vector getDirtyPages(); std::vector getChangeDiffs(const uint8_t* updated, size_t updatedSize); - void applyDiff(size_t diffOffset, const uint8_t* diffData, size_t diffLen); + void addMergeRegion(uint32_t offset, + size_t length, + SnapshotDataType dataType, + SnapshotMergeOperation operation); + + private: + // Note - we care about the order of this map, as we iterate through it in + // order of offsets + std::map mergeRegions; }; } diff --git a/src/flat/faabric.fbs b/src/flat/faabric.fbs index 9a9509f88..725e5bffd 100644 --- a/src/flat/faabric.fbs +++ b/src/flat/faabric.fbs @@ -9,6 +9,8 @@ table SnapshotDeleteRequest { table SnapshotDiffChunk { offset:int; + dataType:int; + mergeOp:int; data:[ubyte]; } diff --git a/src/snapshot/SnapshotClient.cpp b/src/snapshot/SnapshotClient.cpp index c16176cd6..90fcbe4a8 100644 --- a/src/snapshot/SnapshotClient.cpp +++ b/src/snapshot/SnapshotClient.cpp @@ -113,7 +113,9 @@ void SnapshotClient::pushSnapshotDiffs( std::vector> diffsFbVector; for (const auto& d : diffs) { auto dataOffset = mb.CreateVector(d.data, d.size); - auto chunk = CreateSnapshotDiffChunk(mb, d.offset, dataOffset); + + auto chunk = CreateSnapshotDiffChunk( + mb, d.offset, d.dataType, d.operation, dataOffset); diffsFbVector.push_back(chunk); } diff --git a/src/snapshot/SnapshotServer.cpp b/src/snapshot/SnapshotServer.cpp index c396cf9a3..adcc9922c 100644 --- a/src/snapshot/SnapshotServer.cpp +++ b/src/snapshot/SnapshotServer.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -117,9 +118,64 @@ SnapshotServer::recvPushSnapshotDiffs(const uint8_t* buffer, size_t bufferSize) faabric::snapshot::getSnapshotRegistry(); faabric::util::SnapshotData& snap = reg.getSnapshot(r->key()->str()); - // Copy diffs to snapshot + // Apply diffs to snapshot for (const auto* r : *r->chunks()) { - snap.applyDiff(r->offset(), r->data()->data(), r->data()->size()); + uint8_t* dest = snap.data + r->offset(); + switch (r->dataType()) { + case (faabric::util::SnapshotDataType::Raw): { + switch (r->mergeOp()) { + case (faabric::util::SnapshotMergeOperation::Overwrite): { + std::memcpy(dest, r->data()->data(), r->data()->size()); + break; + } + default: { + SPDLOG_ERROR("Unsupported raw merge operation: {}", + r->mergeOp()); + throw std::runtime_error( + "Unsupported raw merge operation"); + } + } + break; + } + case (faabric::util::SnapshotDataType::Int): { + const auto* value = + reinterpret_cast(r->data()->data()); + auto* destValue = reinterpret_cast(dest); + switch (r->mergeOp()) { + case (faabric::util::SnapshotMergeOperation::Sum): { + *destValue += *value; + break; + } + case (faabric::util::SnapshotMergeOperation::Subtract): { + *destValue -= *value; + break; + } + case (faabric::util::SnapshotMergeOperation::Product): { + *destValue *= *value; + break; + } + case (faabric::util::SnapshotMergeOperation::Min): { + *destValue = std::min(*destValue, *value); + break; + } + case (faabric::util::SnapshotMergeOperation::Max): { + *destValue = std::max(*destValue, *value); + break; + } + default: { + SPDLOG_ERROR("Unsupported int merge operation: {}", + r->mergeOp()); + throw std::runtime_error( + "Unsupported int merge operation"); + } + } + break; + } + default: { + SPDLOG_ERROR("Unsupported data type: {}", r->dataType()); + throw std::runtime_error("Unsupported merge data type"); + } + } } // Send response diff --git a/src/util/snapshot.cpp b/src/util/snapshot.cpp index 3eaa0c173..5f05c5d63 100644 --- a/src/util/snapshot.cpp +++ b/src/util/snapshot.cpp @@ -1,9 +1,15 @@ +#include #include +#include #include #include namespace faabric::util { +// TODO - this would be better as an instance variable on the SnapshotData +// class, but it can't be copy-constructed. +static std::mutex snapMx; + std::vector SnapshotData::getDirtyPages() { if (data == nullptr || size == 0) { @@ -31,27 +37,126 @@ std::vector SnapshotData::getDirtyPages() std::vector SnapshotData::getChangeDiffs(const uint8_t* updated, size_t updatedSize) { - // Work out which pages have changed in the comparison + // Work out which pages have changed size_t nThisPages = getRequiredHostPages(size); std::vector dirtyPageNumbers = getDirtyPageNumbers(updated, nThisPages); + // Get iterator over merge regions + std::map::iterator mergeIt = + mergeRegions.begin(); + // Get byte-wise diffs _within_ the dirty pages + // // NOTE - this will cause diffs to be split across pages if they hit a page // boundary, but we can be relatively confident that variables will be // page-aligned so this shouldn't be a problem + // + // For each byte we encounter have the following possible scenarios: + // + // 1. the byte is dirty, and is the start of a new diff + // 2. the byte is dirty, but the byte before was also dirty, so we + // are inside a diff + // 3. the byte is not dirty but the previous one was, so we've reached the + // end of a diff + // 4. the last byte of the page is dirty, so we've also come to the end of + // a diff + // 5. the byte is dirty, but is within a special merge region, in which + // case we need to add a diff for that whole region, then skip + // to the next byte after that region std::vector diffs; for (int i : dirtyPageNumbers) { int pageOffset = i * HOST_PAGE_SIZE; - // Iterate through each byte of the page bool diffInProgress = false; int diffStart = 0; int offset = pageOffset; for (int b = 0; b < HOST_PAGE_SIZE; b++) { offset = pageOffset + b; bool isDirtyByte = *(data + offset) != *(updated + offset); - if (isDirtyByte && !diffInProgress) { + + bool isInMergeRegion = + mergeIt != mergeRegions.end() && + offset >= mergeIt->second.offset && + offset < (mergeIt->second.offset + mergeIt->second.length); + + if (isDirtyByte && isInMergeRegion) { + SnapshotMergeRegion region = mergeIt->second; + + // Set up the diff + const uint8_t* updatedValue = updated + region.offset; + const uint8_t* originalValue = data + region.offset; + + SnapshotDiff diff(region.offset, updatedValue, region.length); + diff.dataType = region.dataType; + diff.operation = region.operation; + + // Modify diff data for certain operations + switch (region.dataType) { + case (SnapshotDataType::Int): { + int originalInt = + *(reinterpret_cast(originalValue)); + int updatedInt = + *(reinterpret_cast(updatedValue)); + + switch (region.operation) { + case (SnapshotMergeOperation::Sum): { + // Sums must send the value to be _added_, and + // not the final result + updatedInt -= originalInt; + break; + } + case (SnapshotMergeOperation::Subtract): { + // Subtractions must send the value to be + // subtracted, not the result + updatedInt = originalInt - updatedInt; + break; + } + case (SnapshotMergeOperation::Product): { + // Products must send the value to be + // multiplied, not the result + updatedInt /= originalInt; + break; + } + case (SnapshotMergeOperation::Max): + case (SnapshotMergeOperation::Min): + // Min and max don't need to change + break; + default: { + SPDLOG_ERROR( + "Unhandled integer merge operation: {}", + region.operation); + throw std::runtime_error( + "Unhandled integer merge operation"); + } + } + + // TODO - somehow avoid casting away the const here? + // Modify the memory in-place here + std::memcpy((uint8_t*)updatedValue, + BYTES(&updatedInt), + sizeof(int32_t)); + + break; + } + default: { + SPDLOG_ERROR("Merge region for unhandled data type: {}", + region.dataType); + throw std::runtime_error( + "Merge region for unhandled data type"); + } + } + + // Add the diff to the list + diffs.emplace_back(diff); + + // Bump the loop variable to the end of this region (note that + // the loop itself will increment onto the next) + b = (region.offset - pageOffset) + (region.length - 1); + + // Move onto the next merge region + ++mergeIt; + } else if (isDirtyByte && !diffInProgress) { // Diff starts here if it's different and diff not in progress diffInProgress = true; diffStart = offset; @@ -81,12 +186,17 @@ std::vector SnapshotData::getChangeDiffs(const uint8_t* updated, return diffs; } -void SnapshotData::applyDiff(size_t diffOffset, - const uint8_t* diffData, - size_t diffLen) +void SnapshotData::addMergeRegion(uint32_t offset, + size_t length, + SnapshotDataType dataType, + SnapshotMergeOperation operation) { - uint8_t* dest = data + diffOffset; - std::memcpy(dest, diffData, diffLen); + SnapshotMergeRegion region{ .offset = offset, + .length = length, + .dataType = dataType, + .operation = operation }; + // Locking as this may be called in bursts by multiple threads + faabric::util::UniqueLock lock(snapMx); + mergeRegions[offset] = region; } - } diff --git a/tests/test/snapshot/test_snapshot_client_server.cpp b/tests/test/snapshot/test_snapshot_client_server.cpp index 11de20898..60abe9ec4 100644 --- a/tests/test/snapshot/test_snapshot_client_server.cpp +++ b/tests/test/snapshot/test_snapshot_client_server.cpp @@ -7,11 +7,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include namespace tests { @@ -124,6 +126,169 @@ TEST_CASE_METHOD(SnapshotClientServerFixture, deallocatePages(snap.data, 5); } +TEST_CASE_METHOD(SnapshotClientServerFixture, + "Test detailed snapshot diffs with merge ops", + "[snapshot]") +{ + // Set up a snapshot + std::string snapKey = std::to_string(faabric::util::generateGid()); + faabric::util::SnapshotData snap = takeSnapshot(snapKey, 5, false); + + // Set up a couple of ints in the snapshot + int offsetA1 = 5; + int offsetA2 = 2 * faabric::util::HOST_PAGE_SIZE; + int baseA1 = 25; + int baseA2 = 60; + + int* basePtrA1 = (int*)(snap.data + offsetA1); + int* basePtrA2 = (int*)(snap.data + offsetA2); + *basePtrA1 = baseA1; + *basePtrA2 = baseA2; + + // Set up some diffs with different merge operations + int diffIntA1 = 123; + int diffIntA2 = 345; + + std::vector intDataA1 = + faabric::util::valueToBytes(diffIntA1); + std::vector intDataA2 = + faabric::util::valueToBytes(diffIntA2); + + std::vector diffs; + + faabric::util::SnapshotDiff diffA1( + offsetA1, intDataA1.data(), intDataA1.size()); + diffA1.operation = faabric::util::SnapshotMergeOperation::Sum; + diffA1.dataType = faabric::util::SnapshotDataType::Int; + + faabric::util::SnapshotDiff diffA2( + offsetA2, intDataA2.data(), intDataA2.size()); + diffA2.operation = faabric::util::SnapshotMergeOperation::Sum; + diffA2.dataType = faabric::util::SnapshotDataType::Int; + + diffs = { diffA1, diffA2 }; + cli.pushSnapshotDiffs(snapKey, diffs); + + // Check diffs have been applied according to the merge operations + REQUIRE(*basePtrA1 == baseA1 + diffIntA1); + REQUIRE(*basePtrA2 == baseA2 + diffIntA2); + + deallocatePages(snap.data, 5); +} + +TEST_CASE_METHOD(SnapshotClientServerFixture, + "Test snapshot diffs with merge ops", + "[snapshot]") +{ + // Set up a snapshot + std::string snapKey = std::to_string(faabric::util::generateGid()); + faabric::util::SnapshotData snap = takeSnapshot(snapKey, 5, false); + + int offset = 5; + std::vector originalData; + std::vector diffData; + std::vector expectedData; + + faabric::util::SnapshotMergeOperation operation = + faabric::util::SnapshotMergeOperation::Overwrite; + faabric::util::SnapshotDataType dataType = + faabric::util::SnapshotDataType::Raw; + + SECTION("Integer") + { + dataType = faabric::util::SnapshotDataType::Int; + int original = 0; + int diff = 0; + int expected = 0; + + SECTION("Sum") + { + original = 100; + diff = 10; + expected = 110; + + operation = faabric::util::SnapshotMergeOperation::Sum; + } + + SECTION("Subtract") + { + original = 100; + diff = 10; + expected = 90; + + operation = faabric::util::SnapshotMergeOperation::Subtract; + } + + SECTION("Product") + { + original = 10; + diff = 20; + expected = 200; + + operation = faabric::util::SnapshotMergeOperation::Product; + } + + SECTION("Min") + { + SECTION("With change") + { + original = 1000; + diff = 100; + expected = 100; + } + + SECTION("No change") + { + original = 10; + diff = 20; + expected = 10; + } + + operation = faabric::util::SnapshotMergeOperation::Min; + } + + SECTION("Max") + { + SECTION("With change") + { + original = 100; + diff = 1000; + expected = 1000; + } + + SECTION("No change") + { + original = 20; + diff = 10; + expected = 20; + } + + operation = faabric::util::SnapshotMergeOperation::Max; + } + + originalData = faabric::util::valueToBytes(original); + diffData = faabric::util::valueToBytes(diff); + expectedData = faabric::util::valueToBytes(expected); + } + + // Put original data in place + std::memcpy(snap.data + offset, originalData.data(), originalData.size()); + + faabric::util::SnapshotDiff diff(offset, diffData.data(), diffData.size()); + diff.operation = operation; + diff.dataType = dataType; + + std::vector diffs = { diff }; + cli.pushSnapshotDiffs(snapKey, diffs); + + // Check data is as expected + std::vector actualData(snap.data + offset, + snap.data + offset + expectedData.size()); + REQUIRE(actualData == expectedData); + + deallocatePages(snap.data, 5); +} + TEST_CASE_METHOD(SnapshotClientServerFixture, "Test set thread result", "[snapshot]") diff --git a/tests/test/util/test_snapshot.cpp b/tests/test/util/test_snapshot.cpp new file mode 100644 index 000000000..5ac92938f --- /dev/null +++ b/tests/test/util/test_snapshot.cpp @@ -0,0 +1,413 @@ +#include + +#include "faabric_utils.h" +#include "fixtures.h" + +#include +#include +#include +#include + +using namespace faabric::util; + +namespace tests { + +TEST_CASE_METHOD(SnapshotTestFixture, + "Detailed test snapshot merge regions with ints", + "[util]") +{ + std::string snapKey = "foobar123"; + int snapPages = 5; + + int originalValueA = 100; + int finalValueA = 150; + int sumValueA = 50; + + int originalValueB = 300; + int finalValueB = 425; + int sumValueB = 125; + + faabric::util::SnapshotData snap; + snap.size = snapPages * faabric::util::HOST_PAGE_SIZE; + snap.data = allocatePages(snapPages); + + // Set up some integers in the snapshot + int intAOffset = HOST_PAGE_SIZE + (10 * sizeof(int32_t)); + int intBOffset = (2 * HOST_PAGE_SIZE) + (20 * sizeof(int32_t)); + int* intAOriginal = (int*)(snap.data + intAOffset); + int* intBOriginal = (int*)(snap.data + intBOffset); + + // Set the original values + *intAOriginal = originalValueA; + *intBOriginal = originalValueB; + + // Take the snapshot + reg.takeSnapshot(snapKey, snap, true); + + // Map the snapshot to some memory + size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; + uint8_t* sharedMem = allocatePages(snapPages); + + reg.mapSnapshot(snapKey, sharedMem); + + // Check mapping works + int* intA = (int*)(sharedMem + intAOffset); + int* intB = (int*)(sharedMem + intBOffset); + + REQUIRE(*intA == originalValueA); + REQUIRE(*intB == originalValueB); + + // Reset dirty tracking to get a clean start + faabric::util::resetDirtyTracking(); + + // Set up the merge regions, deliberately do the one at higher offsets first + // to check the ordering + snap.addMergeRegion(intBOffset, + sizeof(int), + SnapshotDataType::Int, + SnapshotMergeOperation::Sum); + + snap.addMergeRegion(intAOffset, + sizeof(int), + SnapshotDataType::Int, + SnapshotMergeOperation::Sum); + + // Modify both values and some other data + *intA = finalValueA; + *intB = finalValueB; + + std::vector otherData(100, 5); + int otherOffset = (3 * HOST_PAGE_SIZE) + 5; + std::memcpy(sharedMem + otherOffset, otherData.data(), otherData.size()); + + // Get the snapshot diffs + std::vector actualDiffs = + snap.getChangeDiffs(sharedMem, sharedMemSize); + + // Check original hasn't changed + REQUIRE(*intAOriginal == originalValueA); + REQUIRE(*intBOriginal == originalValueB); + + // Check diffs themselves + REQUIRE(actualDiffs.size() == 3); + + SnapshotDiff diffA = actualDiffs.at(0); + SnapshotDiff diffB = actualDiffs.at(1); + SnapshotDiff diffOther = actualDiffs.at(2); + + REQUIRE(diffA.offset == intAOffset); + REQUIRE(diffB.offset == intBOffset); + REQUIRE(diffOther.offset == otherOffset); + + REQUIRE(diffA.operation == SnapshotMergeOperation::Sum); + REQUIRE(diffB.operation == SnapshotMergeOperation::Sum); + REQUIRE(diffOther.operation == SnapshotMergeOperation::Overwrite); + + REQUIRE(diffA.dataType == SnapshotDataType::Int); + REQUIRE(diffB.dataType == SnapshotDataType::Int); + REQUIRE(diffOther.dataType == SnapshotDataType::Raw); + + REQUIRE(diffA.size == sizeof(int32_t)); + REQUIRE(diffB.size == sizeof(int32_t)); + REQUIRE(diffOther.size == otherData.size()); + + // Check that original values have been subtracted from final values for + // sums + REQUIRE(*(int*)diffA.data == sumValueA); + REQUIRE(*(int*)diffB.data == sumValueB); + + std::vector actualOtherData(diffOther.data, + diffOther.data + diffOther.size); + REQUIRE(actualOtherData == otherData); + + deallocatePages(snap.data, snapPages); +} + +TEST_CASE_METHOD(SnapshotTestFixture, + "Test edge-cases of snapshot merge regions", + "[util]") +{ + // Region edge cases: + // - start + // - adjacent + // - finish + + std::string snapKey = "foobar123"; + int snapPages = 5; + int snapSize = snapPages * faabric::util::HOST_PAGE_SIZE; + + int originalA = 50; + int finalA = 25; + int subA = 25; + int offsetA = 0; + + int originalB = 100; + int finalB = 200; + int sumB = 100; + int offsetB = HOST_PAGE_SIZE + (2 * sizeof(int32_t)); + + int originalC = 200; + int finalC = 150; + int subC = 50; + int offsetC = offsetB + sizeof(int32_t); + + int originalD = 100; + int finalD = 150; + int sumD = 50; + int offsetD = snapSize - sizeof(int32_t); + + faabric::util::SnapshotData snap; + snap.size = snapSize; + snap.data = allocatePages(snapPages); + + // Set up original values + *(int*)(snap.data + offsetA) = originalA; + *(int*)(snap.data + offsetB) = originalB; + *(int*)(snap.data + offsetC) = originalC; + *(int*)(snap.data + offsetD) = originalD; + + // Take the snapshot + reg.takeSnapshot(snapKey, snap, true); + + // Map the snapshot to some memory + size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; + uint8_t* sharedMem = allocatePages(snapPages); + + reg.mapSnapshot(snapKey, sharedMem); + + // Reset dirty tracking + faabric::util::resetDirtyTracking(); + + // Set up the merge regions + snap.addMergeRegion(offsetA, + sizeof(int), + SnapshotDataType::Int, + SnapshotMergeOperation::Subtract); + + snap.addMergeRegion( + offsetB, sizeof(int), SnapshotDataType::Int, SnapshotMergeOperation::Sum); + + snap.addMergeRegion(offsetC, + sizeof(int), + SnapshotDataType::Int, + SnapshotMergeOperation::Subtract); + + snap.addMergeRegion( + offsetD, sizeof(int), SnapshotDataType::Int, SnapshotMergeOperation::Sum); + + // Set final values + *(int*)(sharedMem + offsetA) = finalA; + *(int*)(sharedMem + offsetB) = finalB; + *(int*)(sharedMem + offsetC) = finalC; + *(int*)(sharedMem + offsetD) = finalD; + + // Check the diffs + std::vector actualDiffs = + snap.getChangeDiffs(sharedMem, sharedMemSize); + REQUIRE(actualDiffs.size() == 4); + + SnapshotDiff diffA = actualDiffs.at(0); + SnapshotDiff diffB = actualDiffs.at(1); + SnapshotDiff diffC = actualDiffs.at(2); + SnapshotDiff diffD = actualDiffs.at(3); + + REQUIRE(diffA.offset == offsetA); + REQUIRE(diffB.offset == offsetB); + REQUIRE(diffC.offset == offsetC); + REQUIRE(diffD.offset == offsetD); + + REQUIRE(diffA.operation == SnapshotMergeOperation::Subtract); + REQUIRE(diffB.operation == SnapshotMergeOperation::Sum); + REQUIRE(diffC.operation == SnapshotMergeOperation::Subtract); + REQUIRE(diffD.operation == SnapshotMergeOperation::Sum); + + REQUIRE(*(int*)diffA.data == subA); + REQUIRE(*(int*)diffB.data == sumB); + REQUIRE(*(int*)diffC.data == subC); + REQUIRE(*(int*)diffD.data == sumD); +} + +TEST_CASE_METHOD(SnapshotTestFixture, "Test snapshot merge regions", "[util]") +{ + std::string snapKey = "foobar123"; + int snapPages = 5; + + int offset = HOST_PAGE_SIZE + (10 * sizeof(int32_t)); + + faabric::util::SnapshotData snap; + snap.size = snapPages * faabric::util::HOST_PAGE_SIZE; + snap.data = allocatePages(snapPages); + + std::vector originalData; + std::vector updatedData; + std::vector expectedData; + + faabric::util::SnapshotDataType dataType = + faabric::util::SnapshotDataType::Raw; + faabric::util::SnapshotMergeOperation operation = + faabric::util::SnapshotMergeOperation::Overwrite; + size_t dataLength = 0; + + SECTION("Integer") + { + int originalValue = 0; + int finalValue = 0; + int diffValue = 0; + + dataType = faabric::util::SnapshotDataType::Int; + dataLength = sizeof(int32_t); + + SECTION("Integer sum") + { + originalValue = 100; + finalValue = 150; + diffValue = 50; + + operation = faabric::util::SnapshotMergeOperation::Sum; + } + + SECTION("Integer subtract") + { + originalValue = 150; + finalValue = 100; + diffValue = 50; + + operation = faabric::util::SnapshotMergeOperation::Subtract; + } + + SECTION("Integer product") + { + originalValue = 3; + finalValue = 150; + diffValue = 50; + + operation = faabric::util::SnapshotMergeOperation::Product; + } + + SECTION("Integer max") + { + originalValue = 10; + finalValue = 200; + diffValue = 200; + + operation = faabric::util::SnapshotMergeOperation::Max; + } + + SECTION("Integer min") + { + originalValue = 30; + finalValue = 10; + diffValue = 10; + + operation = faabric::util::SnapshotMergeOperation::Max; + } + + originalData = faabric::util::valueToBytes(originalValue); + updatedData = faabric::util::valueToBytes(finalValue); + expectedData = faabric::util::valueToBytes(diffValue); + } + + // Write the original data into place + std::memcpy(snap.data + offset, originalData.data(), originalData.size()); + + // Take the snapshot + reg.takeSnapshot(snapKey, snap, true); + + // Map the snapshot to some memory + size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; + uint8_t* sharedMem = allocatePages(snapPages); + + reg.mapSnapshot(snapKey, sharedMem); + + // Reset dirty tracking + faabric::util::resetDirtyTracking(); + + // Set up the merge region + snap.addMergeRegion(offset, dataLength, dataType, operation); + + // Modify the value + std::memcpy(sharedMem + offset, updatedData.data(), updatedData.size()); + + // Get the snapshot diffs + std::vector actualDiffs = + snap.getChangeDiffs(sharedMem, sharedMemSize); + + // Check number of diffs + REQUIRE(actualDiffs.size() == 1); + + SnapshotDiff diff = actualDiffs.at(0); + REQUIRE(diff.offset == offset); + REQUIRE(diff.operation == operation); + REQUIRE(diff.dataType == dataType); + REQUIRE(diff.size == dataLength); + + // Check actual and expected + std::vector actualData(diff.data, diff.data + dataLength); + REQUIRE(actualData == expectedData); + + deallocatePages(snap.data, snapPages); +} + +TEST_CASE_METHOD(SnapshotTestFixture, "Test invalid snapshot merges", "[util]") +{ + std::string snapKey = "foobar123"; + int snapPages = 3; + int offset = HOST_PAGE_SIZE + (2 * sizeof(int32_t)); + + faabric::util::SnapshotData snap; + snap.size = snapPages * faabric::util::HOST_PAGE_SIZE; + snap.data = allocatePages(snapPages); + + faabric::util::SnapshotDataType dataType = + faabric::util::SnapshotDataType::Raw; + faabric::util::SnapshotMergeOperation operation = + faabric::util::SnapshotMergeOperation::Overwrite; + size_t dataLength = 0; + + std::string expectedMsg; + + SECTION("Integer overwrite") + { + dataType = faabric::util::SnapshotDataType::Int; + dataLength = sizeof(int32_t); + expectedMsg = "Unhandled integer merge operation"; + } + + SECTION("Raw sum") + { + dataType = faabric::util::SnapshotDataType::Raw; + dataLength = 123; + expectedMsg = "Merge region for unhandled data type"; + } + + // Take the snapshot + reg.takeSnapshot(snapKey, snap, true); + + // Map the snapshot + size_t sharedMemSize = snapPages * HOST_PAGE_SIZE; + uint8_t* sharedMem = allocatePages(snapPages); + reg.mapSnapshot(snapKey, sharedMem); + + // Reset dirty tracking + faabric::util::resetDirtyTracking(); + + // Set up the merge region + snap.addMergeRegion(offset, dataLength, dataType, operation); + + // Modify the value + std::vector bytes(dataLength, 3); + std::memcpy(sharedMem + offset, bytes.data(), bytes.size()); + + // Check getting diffs throws an exception + bool failed = false; + try { + snap.getChangeDiffs(sharedMem, sharedMemSize); + } catch (std::runtime_error& ex) { + failed = true; + REQUIRE(ex.what() == expectedMsg); + } + + REQUIRE(failed); + deallocatePages(snap.data, snapPages); +} +}