diff --git a/xla/service/cpu/simple_orc_jit.cc b/xla/service/cpu/simple_orc_jit.cc index e4a2af7c7bfdf..eaeee88cbeb34 100644 --- a/xla/service/cpu/simple_orc_jit.cc +++ b/xla/service/cpu/simple_orc_jit.cc @@ -18,18 +18,24 @@ limitations under the License. #include #include +#include #include #include #include +#include // NOLINT #include #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/Operator.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/CodeGen.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/Process.h" #include "llvm/TargetParser/Host.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project #include "xla/service/cpu/cpu_runtime.h" @@ -54,6 +60,7 @@ limitations under the License. #include "xla/service/cpu/windows_compatibility.h" #include "xla/service/custom_call_target_registry.h" #include "xla/types.h" +#include "xla/util.h" #include "tsl/platform/logging.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) @@ -82,6 +89,200 @@ llvm::SmallVector DetectMachineAttributes() { return result; } +class DefaultMemoryMapper final + : public llvm::SectionMemoryManager::MemoryMapper { + public: + llvm::sys::MemoryBlock allocateMappedMemory( + llvm::SectionMemoryManager::AllocationPurpose purpose, size_t num_bytes, + const llvm::sys::MemoryBlock* const near_block, unsigned flags, + std::error_code& error_code) override { + return llvm::sys::Memory::allocateMappedMemory(num_bytes, near_block, flags, + error_code); + } + + std::error_code protectMappedMemory(const llvm::sys::MemoryBlock& block, + unsigned flags) override { + return llvm::sys::Memory::protectMappedMemory(block, flags); + } + + std::error_code releaseMappedMemory(llvm::sys::MemoryBlock& m) override { + return llvm::sys::Memory::releaseMappedMemory(m); + } +}; + +// On Windows, LLVM may emit IMAGE_REL_AMD64_ADDR32NB COFF relocations when +// referring to read-only data, however IMAGE_REL_AMD64_ADDR32NB requires that +// the read-only data section follow within 2GB of the code. Oddly enough, +// the LLVM SectionMemoryManager does nothing to enforce this +// (https://github.com/llvm/llvm-project/issues/55386), leading to crashes on +// Windows when the sections end up in the wrong order. Since none +// of the memory managers in the LLVM tree obey the necessary ordering +// constraints, we need to roll our own. +// +// ContiguousSectionMemoryManager is an alternative to SectionMemoryManager +// that maps one large block of memory and suballocates it +// for each section, in the correct order. This is easy enough to do because of +// the llvm::RuntimeDyld::MemoryManager::reserveAllocationSpace() hook, which +// ensures that LLVM will tell us ahead of time the total sizes of all the +// relevant sections. We also know that XLA isn't going to do any more +// complicated memory management: we will allocate the sections once and we are +// done. +class ContiguousSectionMemoryManager : public llvm::RTDyldMemoryManager { + public: + explicit ContiguousSectionMemoryManager( + llvm::SectionMemoryManager::MemoryMapper* mmapper) + : mmapper_(mmapper), mmapper_is_owned_(false) { + if (mmapper_ == nullptr) { + mmapper_ = new DefaultMemoryMapper(); + mmapper_is_owned_ = true; + } + } + + ~ContiguousSectionMemoryManager() override; + + bool needsToReserveAllocationSpace() override { return true; } + void reserveAllocationSpace(uintptr_t code_size, llvm::Align code_align, + uintptr_t ro_data_size, llvm::Align ro_data_align, + uintptr_t rw_data_size, + llvm::Align rw_data_align) override; + + uint8_t* allocateDataSection(uintptr_t size, unsigned alignment, + unsigned section_id, + llvm::StringRef section_name, + bool is_read_only) override; + + uint8_t* allocateCodeSection(uintptr_t size, unsigned alignment, + unsigned section_id, + llvm::StringRef section_name) override; + + bool finalizeMemory(std::string* err_msg) override; + + private: + llvm::SectionMemoryManager::MemoryMapper* mmapper_; + bool mmapper_is_owned_; + + llvm::sys::MemoryBlock allocation_; + + // Sections must be in the order code < rodata < rwdata. + llvm::sys::MemoryBlock code_block_; + llvm::sys::MemoryBlock ro_data_block_; + llvm::sys::MemoryBlock rw_data_block_; + + llvm::sys::MemoryBlock code_free_; + llvm::sys::MemoryBlock ro_data_free_; + llvm::sys::MemoryBlock rw_data_free_; + + uint8_t* Allocate(llvm::sys::MemoryBlock& free_block, std::uintptr_t size, + unsigned alignment); +}; + +ContiguousSectionMemoryManager::~ContiguousSectionMemoryManager() { + if (allocation_.allocatedSize() != 0) { + auto ec = mmapper_->releaseMappedMemory(allocation_); + if (ec) { + LOG(ERROR) << "releaseMappedMemory failed with error: " << ec.message(); + } + } + if (mmapper_is_owned_) { + delete mmapper_; + } +} + +void ContiguousSectionMemoryManager::reserveAllocationSpace( + uintptr_t code_size, llvm::Align code_align, uintptr_t ro_data_size, + llvm::Align ro_data_align, uintptr_t rw_data_size, + llvm::Align rw_data_align) { + CHECK_EQ(allocation_.allocatedSize(), 0); + + static const size_t page_size = llvm::sys::Process::getPageSizeEstimate(); + CHECK_LE(code_align.value(), page_size); + CHECK_LE(ro_data_align.value(), page_size); + CHECK_LE(rw_data_align.value(), page_size); + code_size = RoundUpTo(code_size + code_align.value(), page_size); + ro_data_size = + RoundUpTo(ro_data_size + ro_data_align.value(), page_size); + rw_data_size = + RoundUpTo(rw_data_size + rw_data_align.value(), page_size); + uintptr_t total_size = + code_size + ro_data_size + rw_data_size + page_size * 3; + + std::error_code ec; + allocation_ = mmapper_->allocateMappedMemory( + llvm::SectionMemoryManager::AllocationPurpose::Code, total_size, nullptr, + llvm::sys::Memory::MF_READ | llvm::sys::Memory::MF_WRITE, ec); + if (ec) { + LOG(ERROR) << "allocateMappedMemory failed with error: " << ec.message(); + return; + } + + auto base = reinterpret_cast(allocation_.base()); + code_block_ = code_free_ = + llvm::sys::MemoryBlock(reinterpret_cast(base), code_size); + base += code_size; + ro_data_block_ = ro_data_free_ = + llvm::sys::MemoryBlock(reinterpret_cast(base), ro_data_size); + base += ro_data_size; + rw_data_block_ = rw_data_free_ = + llvm::sys::MemoryBlock(reinterpret_cast(base), rw_data_size); +} + +uint8_t* ContiguousSectionMemoryManager::allocateDataSection( + uintptr_t size, unsigned alignment, unsigned section_id, + llvm::StringRef section_name, bool is_read_only) { + if (is_read_only) { + return Allocate(ro_data_free_, size, alignment); + } else { + return Allocate(rw_data_free_, size, alignment); + } +} + +uint8_t* ContiguousSectionMemoryManager::allocateCodeSection( + uintptr_t size, unsigned alignment, unsigned section_id, + llvm::StringRef section_name) { + return Allocate(code_free_, size, alignment); +} + +uint8_t* ContiguousSectionMemoryManager::Allocate( + llvm::sys::MemoryBlock& free_block, std::uintptr_t size, + unsigned alignment) { + auto base = reinterpret_cast(free_block.base()); + auto start = RoundUpTo(base, alignment); + uintptr_t padded_size = (start - base) + size; + if (padded_size > free_block.allocatedSize()) { + LOG(ERROR) << "Failed to satisfy suballocation request for " << size; + return nullptr; + } + free_block = + llvm::sys::MemoryBlock(reinterpret_cast(base + padded_size), + free_block.allocatedSize() - padded_size); + return reinterpret_cast(start); +} + +bool ContiguousSectionMemoryManager::finalizeMemory(std::string* err_msg) { + std::error_code ec; + + ec = mmapper_->protectMappedMemory( + code_block_, llvm::sys::Memory::MF_READ | llvm::sys::Memory::MF_EXEC); + if (ec) { + if (err_msg) { + *err_msg = ec.message(); + } + return true; + } + ec = + mmapper_->protectMappedMemory(ro_data_block_, llvm::sys::Memory::MF_READ); + if (ec) { + if (err_msg) { + *err_msg = ec.message(); + } + return true; + } + + llvm::sys::Memory::InvalidateInstructionCache(code_block_.base(), + code_block_.allocatedSize()); + return false; +} + } // namespace /*static*/ std::unique_ptr @@ -117,7 +318,7 @@ SimpleOrcJIT::SimpleOrcJIT( execution_session_(std::move(execution_session)), object_layer_(*execution_session_, []() { - return std::make_unique( + return std::make_unique( orc_jit_memory_mapper::GetInstance()); }), compile_layer_(