Skip to content

Commit

Permalink
memory context
Browse files Browse the repository at this point in the history
  • Loading branch information
wdeconinck committed Oct 14, 2024
1 parent 95c89f4 commit e9d6e5d
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 19 deletions.
66 changes: 59 additions & 7 deletions src/atlas/runtime/Memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,23 @@
#include "atlas/library/Library.h"
#include "atlas/library/config.h"
#include "atlas/runtime/Log.h"
#include "atlas/runtime/Exception.h"
#include "eckit/log/Bytes.h"

namespace atlas {

static bool unified_ = false;

struct MemoryScope {
MemoryScope() {
pluto::scope::push();
previous_unified_ = unified_;
}
~MemoryScope() {
unified_ = previous_unified_;
pluto::scope::pop();
}
MemoryScope(const MemoryScope& previous) {
device_memory_mapped_ = previous.device_memory_mapped_;
}
bool device_memory_mapped_ = false;
bool previous_unified_;
};

static std::stack<MemoryScope>& scope_stack() {
Expand All @@ -43,19 +45,69 @@ static std::stack<MemoryScope>& scope_stack() {
}

void memory::set_unified(bool value) {
scope_stack().top().device_memory_mapped_ = value;
unified_ = value;
}
bool memory::get_unified() {
return scope_stack().top().device_memory_mapped_;
return unified_;
}

void memory::scope::push() {
scope_stack().emplace(scope_stack().top());
scope_stack().emplace();
}
void memory::scope::pop() {
scope_stack().pop();
}


namespace memory {
context::context() {
reset();
}

void context::reset() {
unified_ = get_unified();
host_memory_resource_ = pluto::host::get_default_resource();
device_memory_resource_ = pluto::device::get_default_resource();
};

static std::map<std::string, std::unique_ptr<memory::context>> context_registry_;

bool context_exists(std::string_view name) {
if (context_registry_.find(std::string(name)) != context_registry_.end()) {
return true;
}
return false;
}

void register_context(std::string_view name) {
std::string _name{name};
ATLAS_ASSERT( !context_exists(name) );
context_registry_.emplace(_name, new context());
}

void unregister_context(std::string_view name) {
ATLAS_ASSERT( context_exists(name) );
context_registry_.erase(std::string(name));
}

context* get_context(std::string_view name) {
ATLAS_ASSERT( context_exists(name) );
auto& ctx = context_registry_.at(std::string(name));
return ctx.get();
}

void set_context(context* ctx) {
pluto::host::set_default_resource(ctx->host_memory_resource());
pluto::device::set_default_resource(ctx->device_memory_resource());
set_unified(ctx->unified());
}

void set_context(std::string_view name) {
set_context(get_context(name));
}

}

Memory::Memory(std::string_view name) :
name_(name) {
}
Expand Down
46 changes: 34 additions & 12 deletions src/atlas/runtime/Memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,25 @@ class Memory {
namespace memory {

namespace host {
inline void set_default_resource(pluto::memory_resource* mr) {
pluto::host::set_default_resource(mr);
}
// inline void set_default_resource(pluto::memory_resource* mr) {
// pluto::host::set_default_resource(mr);
// }

inline void set_default_resource(std::string_view name) {
pluto::host::set_default_resource(name);
}
// inline void set_default_resource(std::string_view name) {
// pluto::host::set_default_resource(name);
// }

std::unique_ptr<pluto::memory_resource> traced_resource(pluto::memory_resource* upstream = nullptr);
}

namespace device {
inline void set_default_resource(pluto::memory_resource* mr) {
pluto::device::set_default_resource(mr);
}
// inline void set_default_resource(pluto::memory_resource* mr) {
// pluto::device::set_default_resource(mr);
// }

inline void set_default_resource(std::string_view name) {
pluto::device::set_default_resource(name);
}
// inline void set_default_resource(std::string_view name) {
// pluto::device::set_default_resource(name);
// }

std::unique_ptr<pluto::memory_resource> traced_resource(pluto::memory_resource* upstream = nullptr);
}
Expand All @@ -114,6 +114,28 @@ struct scope {
static void push();
static void pop();
};

class context {
public:
context();
pluto::memory_resource* host_memory_resource() { return host_memory_resource_; }
pluto::memory_resource* device_memory_resource() { return device_memory_resource_; }
bool unified() { return unified_; }
void reset();
private:
bool unified_;
pluto::memory_resource* host_memory_resource_;
pluto::memory_resource* device_memory_resource_;
};

void register_context(std::string_view name);
void unregister_context(std::string_view name);
bool context_exists(std::string_view name);
void set_context(std::string_view name);
void set_context(context* ctx);
context* get_context(std::string_view name);


}

} // namespace atlas
7 changes: 7 additions & 0 deletions src/tests/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ foreach( test test_library test_library_noargs test_library_init_nofinal test_li
)
endforeach()

ecbuild_add_test( TARGET atlas_test_memory
SOURCES test_memory.cc
LIBS atlas
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ATLAS_TRACE_REPORT=1
)


if( HAVE_FCTEST )

add_fctest( TARGET atlas_fctest_trace
Expand Down
157 changes: 157 additions & 0 deletions src/tests/runtime/test_memory.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* (C) Copyright 2013 ECMWF.
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
* In applying this licence, ECMWF does not waive the privileges and immunities
* granted to it by virtue of its status as an intergovernmental organisation
* nor does it submit to any jurisdiction.
*/

#include "atlas/runtime/Memory.h"
#include "tests/AtlasTestEnvironment.h"


namespace atlas {
namespace test {

struct CustomMemoryResource : public pluto::memory_resource {
void* do_allocate(std::size_t size, std::size_t alignment) override {
std::cout << " + custom allocate" << std::endl;
return pluto::new_delete_resource()->allocate(size, alignment);
}
void do_deallocate(void* ptr, std::size_t size, std::size_t alignment) override {
std::cout << " - custom deallocate" << std::endl;
pluto::new_delete_resource()->deallocate(ptr, size, alignment);
}
bool do_is_equal(const pluto::memory_resource& other) const override {
return true;
}
};

void run_allocator(pluto::allocator<double>&& allocator) {
std::size_t size = 10;
double* data = allocator.allocate(size);
allocator.deallocate(data, size);
}

void run_default() {
run_allocator(pluto::allocator<double>());
}

void run_resource(pluto::memory_resource* mr) {
run_allocator(pluto::host::allocator<double>(mr));
}

void run_registered(std::string_view resource) {
run_resource(pluto::get_registered_resource(resource));
}

void run_scoped_default(std::string_view resource) {
atlas::memory::scope mem_scope;
pluto::host::set_default_resource(resource);
run_default();
}

CASE("test scope") {
atlas::memory::set_unified(true);
EXPECT_EQ(atlas::memory::get_unified(),true);
atlas::memory::set_unified(false);
EXPECT_EQ(atlas::memory::get_unified(),false);
atlas::memory::scope::push();
{
EXPECT_EQ(atlas::memory::get_unified(),false);
atlas::memory::set_unified(true);
EXPECT_EQ(atlas::memory::get_unified(),true);
}
atlas::memory::scope::pop();
EXPECT_EQ(atlas::memory::get_unified(),false);

// Now nested scope
atlas::memory::scope::push();
{
atlas::memory::set_unified(true);
atlas::memory::scope::push();
{
EXPECT_EQ(atlas::memory::get_unified(),true);
atlas::memory::set_unified(false);
atlas::memory::scope::push();
{
EXPECT_EQ(atlas::memory::get_unified(),false);
atlas::memory::set_unified(true);
EXPECT_EQ(atlas::memory::get_unified(),true);
}
atlas::memory::scope::pop();
EXPECT_EQ(atlas::memory::get_unified(),false);
}
atlas::memory::scope::pop();
EXPECT_EQ(atlas::memory::get_unified(),true);
}
atlas::memory::scope::pop();
EXPECT_EQ(atlas::memory::get_unified(),false);
}

CASE("test scope alloc") {
run_default();
run_scoped_default("pluto::pinned_resource");
run_scoped_default("pluto::new_delete_resource");
run_scoped_default("pluto::pinned_pool_resource");
pluto::pinned_pool_resource()->release();
}

// --------------------------------------------------------------------------


CASE("test extension") {
CustomMemoryResource mr;
pluto::register_resource("custom_resource", &mr);

run_scoped_default("custom_resource");

pluto::unregister_resource("custom_resource");
}

// --------------------------------------------------------------------------

CASE("test context") {
CustomMemoryResource mr;
pluto::Register mr_register("custom_resource", &mr);

memory::scope::push();
pluto::host::set_default_resource("custom_resource");
memory::set_unified(false);
memory::register_context("custom");
memory::scope::pop();

run_default();

memory::scope::push();
pluto::host::set_default_resource("pluto::managed_resource");
memory::set_unified(true);
memory::register_context("unified");
memory::scope::pop();

EXPECT(memory::context_exists("custom"));
EXPECT(memory::context_exists("unified"));

run_default();

memory::scope::push();
memory::set_context("custom");
run_default();

memory::set_context("unified");
run_default();
memory::scope::pop();

memory::unregister_context("custom");
memory::unregister_context("unified");

}

} // namespace test
} // namespace atlas

int main(int argc, char** argv) {
return atlas::test::run(argc, argv);
}

0 comments on commit e9d6e5d

Please sign in to comment.