From 1a55b84484b5e645fa5ace2e96fefcfe2f0b820c Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Tue, 22 Oct 2024 10:32:30 +0100 Subject: [PATCH] SPIRVReader: Add OpCopyMemory support Add support for translating `OpCopyMemory` into `llvm.memcpy`. Fixes https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/2769 --- lib/SPIRV/SPIRVReader.cpp | 42 ++++++++++++++++++++++++++++----------- test/OpCopyMemory.spvasm | 41 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 12 deletions(-) create mode 100644 test/OpCopyMemory.spvasm diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index 48fa29200..cb2f69403 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -1814,21 +1814,39 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, return mapValue(BV, LI); } + case OpCopyMemory: case OpCopyMemorySized: { - SPIRVCopyMemorySized *BC = static_cast(BV); - CallInst *CI = nullptr; - llvm::Value *Dst = transValue(BC->getTarget(), F, BB); - MaybeAlign Align(BC->getAlignment()); + llvm::Value *Src = nullptr; + llvm::Value *Dst = nullptr; + llvm::Value *Size = nullptr; + SPIRVMemoryAccess *MA = nullptr; + if (OC == OpCopyMemory) { + auto *BC = static_cast(BV); + Src = transValue(BC->getSource(), F, BB); + Dst = transValue(BC->getTarget(), F, BB); + Type *EltTy = transType(BC->getSource()->getType()->getPointerElementType()); + Size = ConstantExpr::getSizeOf(EltTy); + MA = static_cast(BC); + } else { + assert(OC == OpCopyMemorySized); + auto *BC = static_cast(BV); + Src = transValue(BC->getSource(), F, BB); + Dst = transValue(BC->getTarget(), F, BB); + Size = transValue(BC->getSize(), F, BB); + MA = static_cast(BC); + } + assert(Src); + assert(Dst); + assert(Size); + assert(MA); + MaybeAlign Align(MA->getAlignment()); MaybeAlign SrcAlign = - BC->getSrcAlignment() ? MaybeAlign(BC->getSrcAlignment()) : Align; - llvm::Value *Size = transValue(BC->getSize(), F, BB); - bool IsVolatile = BC->SPIRVMemoryAccess::isVolatile(); - IRBuilder<> Builder(BB); + MA->getSrcAlignment() ? MaybeAlign(MA->getSrcAlignment()) : Align; + bool IsVolatile = MA->SPIRVMemoryAccess::isVolatile(); - if (!CI) { - llvm::Value *Src = transValue(BC->getSource(), F, BB); - CI = Builder.CreateMemCpy(Dst, Align, Src, SrcAlign, Size, IsVolatile); - } + IRBuilder<> Builder(BB); + CallInst *CI = + Builder.CreateMemCpy(Dst, Align, Src, SrcAlign, Size, IsVolatile); if (isFuncNoUnwind()) CI->getFunction()->addFnAttr(Attribute::NoUnwind); return mapValue(BV, CI); diff --git a/test/OpCopyMemory.spvasm b/test/OpCopyMemory.spvasm new file mode 100644 index 000000000..755b82c41 --- /dev/null +++ b/test/OpCopyMemory.spvasm @@ -0,0 +1,41 @@ +; Check SPIRVReader support for OpCopyMemory. + +; REQUIRES: spirv-as +; RUN: spirv-as --target-env spv1.0 -o %t.spv %s +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s + + OpCapability Addresses + OpCapability Int16 + OpCapability Kernel + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %kernel "copymemory" + OpName %pShort "pShort" + OpName %dstShort "dstShort" + OpName %pInt "pInt" + OpName %dstInt "dstInt" + %ushort = OpTypeInt 16 0 + %uint = OpTypeInt 32 0 + %void = OpTypeVoid + %gptr_short = OpTypePointer CrossWorkgroup %ushort + %pptr_short = OpTypePointer Function %ushort + %gptr_int = OpTypePointer CrossWorkgroup %uint + %pptr_int = OpTypePointer Function %uint + %kernel_sig = OpTypeFunction %void %gptr_short %gptr_int + %ushort_42 = OpConstant %ushort 42 + %uint_4242 = OpConstant %uint 4242 + %kernel = OpFunction %void None %kernel_sig + %dstShort = OpFunctionParameter %gptr_short + %dstInt = OpFunctionParameter %gptr_int + %entry = OpLabel + %pShort = OpVariable %pptr_short Function %ushort_42 + %pInt = OpVariable %pptr_int Function %uint_4242 + OpCopyMemory %dstShort %pShort + OpCopyMemory %dstInt %pInt + OpReturn + OpFunctionEnd + +; CHECK-LABEL: define spir_kernel void @copymemory(ptr addrspace(1) %dstShort, ptr addrspace(1) %dstInt) +; CHECK: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %dstShort, ptr @pShort, i64 ptrtoint (ptr getelementptr (i16, ptr null, i32 1) to i64), i1 false) +; CHECK: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %dstInt, ptr @pInt, i64 ptrtoint (ptr getelementptr (i32, ptr null, i32 1) to i64), i1 false)