From 9b37b3ea4467b3aab9110e0d259d0cf27478697d Mon Sep 17 00:00:00 2001
From: Chester Liu <4710575+skyline75489@users.noreply.github.com>
Date: Mon, 23 Sep 2024 17:19:30 +0800
Subject: [PATCH 01/13] Specify the paths of system tools when building Apple framework (#22056)

### Description
Specify the paths of `ar`, `ld` and `libtool` when building the Apple framework.

### Motivation and Context
Sometimes non-system executables come before the system-provided ones on the search path. This PR intends to prevent that from happening.

---
 cmake/onnxruntime.cmake | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index 7e992fb33077c..f2be742458313 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -352,12 +352,12 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # make both maccatalyst and other builds do the same thing. set(CUR_TARGET_CMAKE_SOURCE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/${_LIB}.dir) add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ar -t $ | grep "\.o$" > ${_LIB}.object_file_list.txt + COMMAND /usr/bin/ar -t $ | grep "\.o$" > ${_LIB}.object_file_list.txt COMMAND ${CMAKE_COMMAND} -E env python3 ${CMAKE_CURRENT_SOURCE_DIR}/maccatalyst_prepare_objects_for_prelink.py ${CUR_TARGET_CMAKE_SOURCE_LIB_DIR} ${CUR_STATIC_LIB_OBJ_DIR} ${CUR_STATIC_LIB_OBJ_DIR}/${_LIB}.object_file_list.txt WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) else() add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ar ARGS -x $ + COMMAND /usr/bin/ar ARGS -x $ WORKING_DIRECTORY ${CUR_STATIC_LIB_OBJ_DIR}) endif() endif()
@@ -378,12 +378,12 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # do the pre-link with `ld -r` to create a single relocatable object with correct symbol visibility add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a + COMMAND /usr/bin/ld ARGS -r -o ${STATIC_LIB_DIR}/prelinked_objects.o */*.o ../*.a WORKING_DIRECTORY ${STATIC_LIB_TEMP_DIR}) # create the static library add_custom_command(TARGET onnxruntime POST_BUILD - COMMAND libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o + COMMAND /usr/bin/libtool -static -o ${STATIC_FRAMEWORK_DIR}/onnxruntime prelinked_objects.o WORKING_DIRECTORY ${STATIC_LIB_DIR}) # Assemble the other pieces of the static framework

From 80e9df826e7ab544d153ce5032e55626c7bfdee9 Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Tue, 24 Sep 2024 02:32:09 +0800
Subject: [PATCH 02/13] [js/webgpu] Optimize InstanceNormalization (#21995)

### Description
InstanceNormalization computes `y = scale * (x - mean) / sqrt(variance + epsilon) + B`, where mean and variance are computed per instance per channel. Calculating the mean and variance per channel is a reduction, which is NCHW-layout friendly since it lets adjacent threads access contiguous data in GPU memory.

This PR optimizes both NHWC and NCHW InstanceNormalization. To calculate the mean and variance efficiently, the input must be NCHW rather than NHWC. Shared memory is then used for the reduction that produces `channel_scale` and `channel_shift`, where `channel_scale = scale / sqrt(variance + epsilon)` and `channel_shift = B - mean * channel_scale`, so each output element only needs `y = x * channel_scale + channel_shift`.

With this PR, computing `channel_scale` and `channel_shift` is identical for NHWC and NCHW InstanceNormalization, and the overall performance of the two layouts is now very close. The data below comes from SD Turbo profiling results.
Before (InstanceNormalization overall time: 140.84 ms) InstanceNormalization\|InstanceNormComputeMean | 129.70 -- | -- InstanceNormalization\|InstanceNormalizationNHWC | 10.55 InstanceNormalization\|InstanceNormComputeChannelScaleShift | 0.59 After (InstanceNormalization overall time: 59.44 ms) InstanceNormalization\|InstanceNormComputeChannelScaleShift | 28.57 -- | -- InstanceNormalization\|TransposeShared | 20.19 InstanceNormalization\|InstanceNormalizationNHWC | 10.68 --- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 369 ++++++++---------- 1 file changed, 154 insertions(+), 215 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 7b6140f3b1185..859bd850862aa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,18 +4,17 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; import { ShapeUtil } from '../../util'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { createTransposeProgramInfo } from './transpose'; import { createTensorShapeVariables, - fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, - UniformsArrayType, } from './common'; export interface InstanceNormAttributes { @@ -23,117 +22,7 @@ export interface InstanceNormAttributes { format: 'NHWC' | 'NCHW'; } -const createInstanceNormProgramInfo = ( - inputs: readonly TensorView[], - attributes: InstanceNormAttributes, -): ProgramInfo => { - const xShape = inputs[0].dims; - const outputShape = xShape; - const axis = 2; - const normCount = ShapeUtil.sizeToDimension(xShape, axis); - const normSize = ShapeUtil.sizeFromDimension(xShape, axis); - const components = getMaxComponents(normSize); - const normPackedSize = normSize / components; - const inputShape = [xShape[0], xShape[1], normPackedSize]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: normSize }, - { type: DataType.uint32, data: normPackedSize }, - ]; - programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 
'f32' : `vec${components}`; - const workgroupSize = 64; - - const uniforms: UniformsArrayType = [ - { name: 'normSize', type: 'u32' }, - { name: 'normPackedSize', type: 'u32' }, - ]; - return ` - var meanShared : f32; - var squaredNormShared : f32; - var workgroupShared : array<${f32Type}, ${workgroupSize}>; - const workgroupSize = ${workgroupSize}u; - ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} - ${shaderHelper.mainStart(workgroupSize)} - let norm = global_idx / workgroupSize; - let batch = norm / uniforms.x_shape[1]; - let channel = norm % uniforms.x_shape[1]; - let localIndex = local_id.x; - - // initialize workgroup memory - var initial = ${f32Type}(0); - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); - } - workgroupShared[localIndex] = initial; - workgroupBarrier(); - - // Calculate the mean of current channel data. - for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) { - if (localIndex < currSize) { - workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize]; - } - workgroupBarrier(); - } - if (localIndex == 0) { - meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize); - } - workgroupBarrier(); - - // reinitialize workgroup memory. - initial = ${f32Type}(0); - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); - initial = initial + deviation * deviation; - } - workgroupShared[localIndex] = initial; - workgroupBarrier(); - - // Calculate the sum of square of deviation of current channel data. - for (var currSize = workgroupSize >> 1; currSize > 0; currSize = currSize >> 1) { - if (localIndex < currSize) { - workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize]; - } - workgroupBarrier(); - } - if (localIndex == 0) { - squaredNormShared = ${sumVector('workgroupShared[0]', components)}; - } - workgroupBarrier(); - - let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); - let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); - let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; - for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { - let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type - }(channelShift)); - ${output.set('batch', 'channel', 'h', 'value')}; - } - }`; - }; - return { - ...{ name: 'InstanceNormalization' }, - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. - shaderCache: { hint: `${attributes.epsilon};${components}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], - dispatchGroup: { x: normCount }, - programUniforms, - }), - getShaderSource, - }; -}; - -const computeMean = ( +const computeChannelScaleShift = ( context: ComputeContext, input: TensorView, scale: TensorView, @@ -143,121 +32,140 @@ const computeMean = ( c: number, epsilon: number, ) => { - const components = getMaxComponents(c); - const WG = 64; - // we will store channel scale and channel shift in [2, components] matrix - // or in vec2 when components == 1 - const outputType = components === 1 ? 
'vec2f' : `mat2x${components}f`; - const sumCastType = components === 1 ? 'f32' : `vec${components}f`; - const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`; - const unitsOfWork = (n * c) / components; - const wgSize = Math.ceil(h / WG); + const components = getMaxComponents(h); + const f32Type = components === 1 ? 'f32' : `vec${components}f`; + const wgType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const unitsOfWork = n * c; - const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; - const meanProgramUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: wgSize }, - { type: DataType.uint32, data: h }, - { type: DataType.uint32, data: Math.floor(c / components) }, - { type: DataType.uint32, data: Math.floor((h * c) / components) }, - ]; + const inputShape = [n, c, h / components]; + const outputShape = [n, c, 2]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = []; + programUniforms.push(...createTensorShapeVariables(inputShape, outputShape)); - const getMeanShaderSource = (shaderHelper: ShaderHelper) => { - const inputHelper = inputVariable('input', input.dataType, input.dims, components); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', input.dataType, 3, components); + const s = inputVariable('scale', scale.dataType, scale.dims); + const b = inputVariable('bias', bias.dataType, bias.dims); + const output = outputVariable('output', DataType.float, 3, 2); + const variables = [x, s, b, output]; + const workgroupSize = 64; return ` - ${shaderHelper.declareVariables(inputHelper)} - @group(0) @binding(1) var output : array<${outputType}>; - struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; - @group(0) @binding(2) var uniforms: Uniforms; + var workgroup_shared : array<${wgType}, ${workgroupSize}>; + const workgroup_size = ${workgroupSize}u; + ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.mainStart(workgroupSize)} + let batch = workgroup_index / uniforms.x_shape[1]; + let channel = workgroup_index % uniforms.x_shape[1]; + let hight = uniforms.x_shape[2]; + // initialize workgroup memory + var sum = ${f32Type}(0); + var squared_sum = ${f32Type}(0); + for (var h = local_idx; h < hight; h += workgroup_size) { + let value = ${f32Type}(${x.get('batch', 'channel', 'h')}); + sum += value; + squared_sum += value * value; + } + workgroup_shared[local_idx] = ${wgType}(sum, squared_sum); + workgroupBarrier(); - ${shaderHelper.mainStart(WG)} - let currentImageNumber = global_idx / ${WG} / uniforms.C; - let currentChannelNumber = (global_idx / ${WG}) % uniforms.C; - let wgOffset = local_id.x * uniforms.wg_size; - if (wgOffset >= uniforms.H) { - return; + for (var currSize = workgroup_size >> 1; currSize > 0; currSize = currSize >> 1) { + if (local_idx < currSize) { + workgroup_shared[local_idx] = workgroup_shared[local_idx] + workgroup_shared[local_idx + currSize]; + } + workgroupBarrier(); } - let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H); + if (local_idx == 0) { + let sum_final = ${sumVector('workgroup_shared[0][0]', components)} / f32(hight * ${components}); + let squared_sum_final = ${sumVector('workgroup_shared[0][1]', components)} / f32(hight * ${components}); - let offset = currentImageNumber * uniforms.image_size + currentChannelNumber; - var sum = ${fillVector('f32', components)}; - var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 
wgOffset; i < wgMax; i++) { - let value = ${sumCastType}(input[offset + i * uniforms.C]); - sum += value; - squaredSum += value * value; + let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(${epsilon})); + let channel_scale = inv_std_dev * f32(scale[channel]); + let channel_shift = f32(bias[channel]) - sum_final * channel_scale; + output[workgroup_index] = vec2f(channel_scale, channel_shift); } - output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; }; - const meanValues = context.compute( + return context.compute( { - name: 'InstanceNormComputeMean', - shaderCache: { hint: `${components}`, inputDependencies: meanInputDependencies }, + name: 'InstanceNormComputeChannelScaleShift', + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, getRunData: () => ({ - outputs: [{ dims: [n, c, WG, 2], dataType: DataType.float }], - dispatchGroup: { x: (n * c) / components }, - programUniforms: meanProgramUniforms, + outputs: [{ dims: outputShape, dataType: DataType.float }], + dispatchGroup: { x: unitsOfWork }, + programUniforms, }), - getShaderSource: getMeanShaderSource, + getShaderSource, }, - { inputs: [input], outputs: [-1] }, + { inputs: [input, scale, bias], outputs: [-1] }, )[0]; +}; + +const createInstanceNormProgramInfo = ( + context: ComputeContext, + inputs: readonly TensorView[], + attributes: InstanceNormAttributes, +) => { + const xShape = inputs[0].dims; + const outputShape = xShape; + const axis = 2; + const N = xShape[0]; + const C = xShape[1]; + const H = ShapeUtil.sizeFromDimension(xShape, axis); + const components = getMaxComponents(H); + const outputSize = ShapeUtil.size(outputShape) / components; + // compute channel scale and channel shift. 
+ const channelScaleShift = computeChannelScaleShift( + context, + inputs[0], + inputs[1], + inputs[2], + N, + H, + C, + attributes.epsilon, + ); + + const inputShape = [N, C, H / components]; + const scaleShape = [N, C]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'none']; - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: unitsOfWork }, - { type: DataType.uint32, data: h }, - { type: DataType.uint32, data: Math.floor(c / components) }, - { type: DataType.uint32, data: Math.floor((WG * c) / components) }, - ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale_shift', DataType.float, scaleShape.length, 2); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, output]; return ` - @group(0) @binding(0) var input : array<${outputType}>; - @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; - @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; - @group(0) @binding(3) var output : array<${outputType}>; - struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32}; - @group(0) @binding(4) var uniforms: Uniforms; - + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')} - let currentImageNumber = global_idx / uniforms.C; - let currentChannelNumber = global_idx % uniforms.C; - - let offset = currentImageNumber * uniforms.image_size; - var sum = ${fillVector('f32', components)}; - var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) { - let value = input[offset + i + currentChannelNumber * ${WG}]; - sum += value[0]; - squaredSum += value[1]; - } - sum = sum / f32(uniforms.H); - squaredSum = squaredSum / f32(uniforms.H); - let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon})); - let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); - let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; - - output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let outputIndices = ${output.offsetToIndices('global_idx')}; + let batch = outputIndices[0]; + let channel = outputIndices[1]; + let scale_shift = ${scale.getByIndices('vec2(batch, channel)')}; + let value = ${x.getByOffset('global_idx')} * ${output.type.value}(scale_shift.x) + ${output.type.value}(scale_shift.y); + ${output.setByOffset('global_idx', 'value')}; }`; }; - return context.compute( + + context.compute( { - name: 'InstanceNormComputeChannelScaleShift', - // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. 
- shaderCache: { hint: `${components};${epsilon}`, inputDependencies }, + name: 'InstanceNormalization', + shaderCache: { hint: `${components}`, inputDependencies }, getRunData: () => ({ - outputs: [{ dims: [n, c, 2], dataType: DataType.float }], - dispatchGroup: { x: Math.ceil(unitsOfWork / 64 /* workgroup size */) }, - programUniforms, + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms: [ + { type: DataType.uint32, data: outputSize }, + ...createTensorShapeVariables(inputShape, scaleShape, inputShape), + ], }), getShaderSource, }, - { inputs: [meanValues, scale, bias], outputs: [-1] }, - )[0]; + { inputs: [inputs[0], channelScaleShift] }, + ); }; const createInstanceNormNHWCProgramInfo = ( @@ -277,30 +185,61 @@ const createInstanceNormNHWCProgramInfo = ( { type: DataType.uint32, data: Math.floor(C / components) }, ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - // first compute mean - const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + + // 1. transpose x from NHWC to NCHW + const transposedXPerm = [0, xShape.length - 1]; + for (let i = 0; i < xShape.length - 2; i++) { + transposedXPerm.push(i + 1); + } + const transposedX = context.compute(createTransposeProgramInfo(context.inputs[0], transposedXPerm), { + inputs: [context.inputs[0]], + outputs: [-1], + })[0]; + // 2. compute channel scale and channel shift. + const channelScaleShift = computeChannelScaleShift( + context, + transposedX, + inputs[1], + inputs[2], + N, + H, + C, + attributes.epsilon, + ); const getShaderSource = (shaderHelper: ShaderHelper) => { const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; - + const scaleType = components === 1 ? 'vec2f' : `mat${components}x2f`; + const scaleData = (num: number) => { + const index = num === 0 ? 'x' : 'y'; + const f32Type = components === 1 ? 
'f32' : `vec${components}f`; + switch (components) { + case 1: + return `${dataType}(${f32Type}(scale.${index}))`; + case 2: + return `vec2<${dataType}>(${f32Type}(scale[0].${index}, scale[1].${index}))`; + case 4: + return `vec4<${dataType}>(${f32Type}(scale[0].${index}, scale[1].${index}, scale[2].${index}, scale[3].${index}))`; + default: + throw new Error(`Not supported compoents ${components}`); + } + }; const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; - @group(0) @binding(1) var scaleInput : array<${scaleType}>; + @group(0) @binding(1) var scale_input : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; struct Uniforms {H: u32, C : u32}; @group(0) @binding(3) var uniforms: Uniforms; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / (uniforms.C * uniforms.H); - let currentChannelNumber = global_idx % uniforms.C; + let current_image_number = global_idx / (uniforms.C * uniforms.H); + let current_channel_number = global_idx % uniforms.C; - let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber; - let scale = scaleInput[scaleOffset]; - output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); + let scale_offset = current_image_number * uniforms.C + current_channel_number; + let scale = scale_input[scale_offset]; + output[global_idx] = fma(input[global_idx], ${scaleData(0)}, ${scaleData(1)}); }`; }; context.compute( @@ -322,6 +261,6 @@ export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAt if (attributes.format === 'NHWC') { createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); } else { - context.compute(createInstanceNormProgramInfo(context.inputs, attributes)); + createInstanceNormProgramInfo(context, context.inputs, attributes); } }; From 1a84f53c35049192b1d380cc374a9be9f6cf8f0a Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 23 Sep 2024 22:02:29 +0200 Subject: [PATCH 03/13] Make argmin/armax support identical data types and add int64 support (#21641) --- docs/OperatorKernels.md | 12 +-- .../providers/cpu/cpu_execution_provider.cc | 42 ++++++++++ .../providers/cpu/reduction/reduction_ops.cc | 14 ++++ .../cpu/reduction/reduction_ops_test.cc | 77 +++++++++++++++++++ 4 files changed, 139 insertions(+), 6 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 121240e6e18f9..407e08c96a891 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -27,12 +27,12 @@ Do not modify directly.* |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|||[1, 10]|**T** = tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
-|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32)|
-|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)|
-|||[1, 10]|**T** = tensor(float), tensor(int32)|
+|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|Asin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| |Asinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)| |Atan|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7ed776f1358a5..7b1b136eb091e 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -227,10 +227,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, uint8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, GRU); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 13, LSTM); @@ -408,9 +414,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); @@ -636,9 +646,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); @@ -1443,16 +1457,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, ReduceSumSquare)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1725,12 +1751,20 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2065,11 +2099,19 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ArgMax)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index 5aac1d9387f57..24fbfbe8d525b 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -288,22 +288,36 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceSumSquare, 18); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 1, 10) 
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMax, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMax, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMax, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMax, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMax, 13); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 1, 10) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 1, 10) REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ArgMin, 11, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_DOUBLE_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT8_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ArgMin, 11, 12) +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_UINT8_ONLY(ArgMin, 11, 12) REGISTER_UNARY_ELEMENTWISE_KERNEL(ArgMin, 13); REGISTER_UNARY_ELEMENTWISE_KERNEL_DOUBLE_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT8_ONLY(ArgMin, 13); +REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ArgMin, 13); FastReduceKind operator|(FastReduceKind a, FastReduceKind b) { return static_cast(static_cast(a) | static_cast(b)); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 0697187a777d6..0968bc32e0de4 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3246,6 +3246,26 @@ TEST(ReductionOpTest, ArgMax_do_not_keepdims_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: node1: at least 2 dimensions are required for input } +TEST(ReductionOpTest, ArgMax_int64) { + OpTester test("ArgMax", 13); + test.AddAttribute("axis", (int64_t)1); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {3, 1, 2}, + {1, 1, + 1, 1, + 1, 1}); + test.Run(); +} + TEST(ReductionOpTest, ArgMax_int32) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)1); @@ -3511,6 +3531,63 @@ TEST(ReductionOpTest, ArgMin_do_not_keepdims_2_select_last) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(ReductionOpTest, ArgMin_uint8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + +TEST(ReductionOpTest, ArgMin_int8) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, 
+ 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(ReductionOpTest, ArgMin_int64) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", (int64_t)0); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {2, 2}, + {0, 0, + 0, 0}); + test.Run(); +} + TEST(ReductionOpTest, ArgMin_int32) { OpTester test("ArgMin"); test.AddAttribute("axis", (int64_t)0); From 7a782b72139412b94a667ea6f036d3677a1cb4f0 Mon Sep 17 00:00:00 2001 From: Hann Wang Date: Tue, 24 Sep 2024 05:01:54 +0800 Subject: [PATCH 04/13] [ROCm] fix rocm-6.2 build issues (#21993) Composable Kernel build fails under ROCm 6.2. This PR patches Composable Kernel the same way as https://github.com/ROCm/composable_kernel/pull/1346 * fix buffer resource to match "s" constraint * add missing memory clobber --- .../composable_kernel/Fix_Clang_Build.patch | 62 +++++++++++++++---- 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index 73ece647d82c7..d63da63445fde 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -3,22 +3,22 @@ index c23746e7f..bc326c8b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,10 +23,10 @@ endif() - + set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) - + -find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) +find_package(Python3 COMPONENTS Interpreter REQUIRED) - + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") - + @@ -227,27 +227,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") - + -## OpenMP -if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - # workaround issue hipcc in rocm3.5 cannot find openmp @@ -53,11 +53,11 @@ index c23746e7f..bc326c8b5 100644 -else() - add_compile_definitions(__HIP_PLATFORM_HCC__=1) -endif() - + ## tidy include(EnableCompilerWarnings) @@ -541,11 +514,3 @@ rocm_install(FILES - + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") - @@ -88,7 +88,7 @@ index c0894f1d7..559481fee 100644 @@ -6,19 +6,7 @@ #include #include - + -// To be removed, which really does not tell the location of failed HIP functional call -inline void hip_check_error(hipError_t x) -{ @@ -121,9 +121,9 @@ index a164c3f94..293ead89a 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -11,6 +11,9 @@ - + namespace ck_tile { - + +template +constexpr bool always_false = false; + @@ -139,7 +139,7 @@ index a164c3f94..293ead89a 100644 } } }; - + + } // namespace ck_tile + @@ -150,7 +150,7 @@ index 3acdb4d87..cc26e184f 100644 @@ -8,20 +8,7 @@ #include #include - + -namespace ck_tile { -// To be removed, which really does not tell the location of failed HIP functional call -CK_TILE_HOST void hip_check_error(hipError_t x) @@ -198,3 +198,41 @@ index c035e7e56..8c5f36d2e 100644 set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(${INSTANCE_NAME}) set(result 0) +--- ./include/ck/utility/amd_buffer_addressing.hpp 2024-09-05 10:12:33.343091000 +0800 
++++ ./include/ck/utility/amd_buffer_addressing_new.hpp 2024-09-05 10:12:20.276686000 +0800 +@@ -991,7 +991,8 @@ + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), +- "s"(src_resource)); ++ "s"(src_resource) ++ : "memory"); + #else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = +--- ./include/ck_tile/core/arch/amd_buffer_addressing.hpp 2024-09-05 10:18:28.884031000 +0800 ++++ ./include/ck_tile/core/arch/amd_buffer_addressing_new.hpp 2024-09-05 10:17:29.434931000 +0800 +@@ -26,7 +26,12 @@ + CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) + { + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; +- return __builtin_bit_cast(int32x4_t, res); ++ int32x4_t r = __builtin_bit_cast(int32x4_t, res); ++ r.x = __builtin_amdgcn_readfirstlane(r.x); ++ r.y = __builtin_amdgcn_readfirstlane(r.y); ++ r.z = __builtin_amdgcn_readfirstlane(r.z); ++ r.w = __builtin_amdgcn_readfirstlane(r.w); ++ return r; + } + + // TODO: glc/slc/... +@@ -2016,7 +2021,8 @@ + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), +- "s"(src_resource)); ++ "s"(src_resource) ++ : "memory"); + #else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = From df25006d1b297b409b968194b43f641b3901b077 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:39:32 -0700 Subject: [PATCH 05/13] upgrade micromatch to v4.0.8 (#22174) ### Description Upgrade `micromatch` to v4.0.8 https://github.com/advisories/GHSA-952p-6rrq-rcjv --- js/package-lock.json | 16 ++++++++-------- js/web/package-lock.json | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index d3684dfdf9117..58a13a9112116 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3282,12 +3282,12 @@ } }, "node_modules/micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -7059,12 +7059,12 @@ "dev": true }, "micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "requires": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" } }, diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 9db48f74a94a4..6e723a76e8fd8 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -2390,12 +2390,12 @@ } }, "node_modules/micromatch": { - "version": 
"4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "dependencies": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" }, "engines": { @@ -5514,12 +5514,12 @@ "dev": true }, "micromatch": { - "version": "4.0.5", - "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", - "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "requires": { - "braces": "^3.0.2", + "braces": "^3.0.3", "picomatch": "^2.3.1" } }, From a7c9f27d2d6f6c514447b4588c3c14dea65d6936 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 23 Sep 2024 18:15:41 -0700 Subject: [PATCH 06/13] Remove training pipelines from Win CPI CI as redundant (#22190) --- .../azure-pipelines/win-ci-pipeline.yml | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 3e2ade7eacd25..94c2d35a563b6 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -196,42 +196,6 @@ stages: WITH_CACHE: false MachinePool: 'onnxruntime-Win-CPU-2022' -- stage: training_x64_debug - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'Debug' - buildArch: x64 - additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker - msbuildPlatform: x64 - isX86: false - job_name_suffix: training_x64_debug - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - isTraining: true - ORT_EP_NAME: CPU - GenerateDocumentation: false - WITH_CACHE: false - MachinePool: 'onnxruntime-Win2022-CPU-training-AMD' - -- stage: training_x64_release - dependsOn: [] - jobs: - - template: templates/jobs/win-ci-vs-2022-job.yml - parameters: - BuildConfig: 'RelWithDebInfo' - buildArch: x64 - additionalBuildFlags: --enable_training --build_wheel - msbuildPlatform: x64 - isX86: false - job_name_suffix: training_x64_release - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - isTraining: true - ORT_EP_NAME: CPU - GenerateDocumentation: false - WITH_CACHE: false - MachinePool: 'onnxruntime-Win2022-CPU-training-AMD' - - stage: ort_training_apis_x64_release dependsOn: [] jobs: From 0806879ad40a6f2fb2a28c30a7ad672bec76b646 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 23 Sep 2024 18:27:16 -0700 Subject: [PATCH 07/13] Update lintrunner requirements (#22185) ### Description * Add lintrunner to requirements-lintrunner.txt * Lock lintrunner and lintrunner-adapter version * Update documentation ### Motivation and Context The document is not up to date. 
--- docs/Coding_Conventions_and_Standards.md | 2 +- requirements-lintrunner.txt | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/Coding_Conventions_and_Standards.md b/docs/Coding_Conventions_and_Standards.md index e8e1e7dc9ccd8..f18f1036efee8 100644 --- a/docs/Coding_Conventions_and_Standards.md +++ b/docs/Coding_Conventions_and_Standards.md @@ -155,7 +155,7 @@ Using `Show Code Coverage Coloring` will allow you to visually inspect which lin This project uses [lintrunner](https://github.com/suo/lintrunner) for linting. It provides a consistent linting experience locally and in CI. You can install the dependencies and initialize with ```sh -pip install lintrunner lintrunner-adapters +pip install -r requirements-lintrunner.txt lintrunner init ``` diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 7d384f7b1df67..72d9ce72ea7cb 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,5 +1,7 @@ # This file is auto updated by dependabot -lintrunner-adapters>=0.12.4 +# When any package below is changed, you shall run "lintrunner init" again. +lintrunner==0.12.5 +lintrunner-adapters==0.12.4 # RUFF ruff==0.5.4 # BLACK-ISORT From ae66d0e7cf6774dc1b6435e122d3589251e6fbc8 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 24 Sep 2024 11:58:48 +1000 Subject: [PATCH 08/13] Update ROCm reduction to match recent CUDA change (#22192) ### Description Add handling of a missing optional axes input to the ROCm reduction ops. Matches CUDA EP change from #22149 ### Motivation and Context Fix pipeline. --- onnxruntime/core/providers/rocm/reduction/reduction_ops.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 11073ab3584eb..a1f5eba9a24c8 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -731,10 +731,9 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR std::vector axes; size_t num_inputs = ctx->InputCount(); - if (num_inputs == 2) { + const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. + if (axes_tensor != nullptr) { // override the attribute value with the input value for reduction_axes - const Tensor* axes_tensor = ctx->Input(1); - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); From cfa45df6b5060af6327a98a625eb9fe74580f56c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 24 Sep 2024 01:36:52 -0400 Subject: [PATCH 09/13] [java] Migrate OnnxTensors created from arrays over to a backing Java buffer (#18556) ### Description Following from #16578 and #16835 this migrates over `OnnxTensor.createTensor()` to first instantiate a `java.nio.Buffer` and then copy the array into that buffer in Java before creating the tensor. It also changes the `OnnxTensor.getValue()` method which returns a multidimensional array so it does the array construction and value copy in Java. This allows the removal of some unpleasant recursive C code which repeatedly calls into the JVM to traverse Java's arrays. The equivalent Java code is still unpleasant and recursive, but it's easier to reason about and memory safe. 
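For illustration, here is a minimal sketch of how the change is visible through the existing public Java API (`createTensor`, `ownsBuffer`, `getBufferRef` and `getValue` are existing `OnnxTensor` methods; the specific behavior shown is an assumption based on this description, not an excerpt from the change):

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.Buffer;
import java.nio.FloatBuffer;
import java.util.Optional;

public final class ArrayBackedTensorSketch {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    float[][] data = {{1f, 2f}, {3f, 4f}};
    try (OnnxTensor tensor = OnnxTensor.createTensor(env, data)) {
      // The array is copied into a Java-allocated direct buffer instead of being parsed in JNI,
      // so the tensor now reports that it owns a backing buffer.
      System.out.println("owns buffer: " + tensor.ownsBuffer());

      // The backing buffer is reachable, so the tensor can be updated in place between runs.
      Optional<Buffer> ref = tensor.getBufferRef();
      if (ref.isPresent() && ref.get() instanceof FloatBuffer) {
        ((FloatBuffer) ref.get()).put(0, 42f); // assumed to be reflected in the native OrtValue
      }

      // getValue() rebuilds the multidimensional array in Java rather than via recursive JNI calls.
      float[][] roundTrip = (float[][]) tensor.getValue();
      System.out.println(roundTrip[0][0]); // expected 42.0 if the in-place write took effect
    }
  }
}
```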
As a bonus, more `OnnxTensor`s are now backed by buffers which allow users to pin memory and reduce allocations by reusing them for same sized inputs. Some of the JNI code which parses Java arrays still exists as it's used by `OnnxMap`, removing that will be the target of a future refactor. Strings are still processed in JNI as it is easier to work with String tensors and UTF-8 arrays in C. ### Motivation and Context Minimizing the amount of JNI code makes it easier to maintain and using buffers in preference to arrays allows for fewer allocations. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 214 ++++++++++-- .../src/main/java/ai/onnxruntime/OrtUtil.java | 308 +++++++++++++++++- .../main/java/ai/onnxruntime/TensorInfo.java | 5 + java/src/main/native/OrtJniUtil.c | 157 --------- java/src/main/native/OrtJniUtil.h | 8 - .../main/native/ai_onnxruntime_OnnxTensor.c | 83 +---- .../java/ai/onnxruntime/InferenceTest.java | 63 +++- .../java/ai/onnxruntime/OnnxTensorTest.java | 155 +++++++-- 8 files changed, 682 insertions(+), 311 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index e1ee2c14fd9d1..3f276a3670156 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -54,7 +54,7 @@ public class OnnxTensor extends OnnxTensorLike { * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is - * a copy of a user buffer.) + * a copy of a user buffer or array.) */ public boolean ownsBuffer() { return this.ownsBuffer; @@ -62,8 +62,8 @@ public boolean ownsBuffer() { /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not - * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by - * ORT) this method returns an empty {@link Optional}. + * backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty + * {@link Optional}. * *

Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be * used to repeatedly update a single tensor for multiple different inferences without allocating @@ -77,7 +77,116 @@ public boolean ownsBuffer() { * @return A reference to the buffer. */ public Optional getBufferRef() { - return Optional.ofNullable(buffer); + return Optional.ofNullable(duplicate(buffer)); + } + + /** + * Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent + * writes will modify the underlying memory in a racy way, don't do that. + * + *

Can be replaced to a call to buf.duplicate() in Java 9+. + * + * @param buf The buffer to duplicate. + * @return A copy of the buffer which refers to the same underlying memory, but has an independent + * position, limit and mark. + */ + private static Buffer duplicate(Buffer buf) { + if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder()); + } else if (buf instanceof ShortBuffer) { + return ((ShortBuffer) buf).duplicate(); + } else if (buf instanceof IntBuffer) { + return ((IntBuffer) buf).duplicate(); + } else if (buf instanceof LongBuffer) { + return ((LongBuffer) buf).duplicate(); + } else if (buf instanceof FloatBuffer) { + return ((FloatBuffer) buf).duplicate(); + } else if (buf instanceof DoubleBuffer) { + return ((DoubleBuffer) buf).duplicate(); + } else { + throw new IllegalStateException("Unknown buffer type " + buf.getClass()); + } + } + + /** + * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link + * ByteBuffer} then convert it to the right type. If it's not convertible it throws {@link + * IllegalStateException}. + * + *

Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve + * compatibility with existing {@link #getValue} calls. + * + * @param buf The buffer to convert. + * @return The buffer with the expected type. + */ + private Buffer castBuffer(Buffer buf) { + switch (info.type) { + case FLOAT: + if (buf instanceof FloatBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asFloatBuffer(); + } + break; + case DOUBLE: + if (buf instanceof DoubleBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asDoubleBuffer(); + } + break; + case BOOL: + case INT8: + case UINT8: + if (buf instanceof ByteBuffer) { + return buf; + } + break; + case BFLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer bf16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } + break; + case FLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer fp16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } + break; + case INT16: + if (buf instanceof ShortBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asShortBuffer(); + } + break; + case INT32: + if (buf instanceof IntBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asIntBuffer(); + } + break; + case INT64: + if (buf instanceof LongBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asLongBuffer(); + } + break; + } + throw new IllegalStateException( + "Invalid buffer type for cast operation, found " + + buf.getClass() + + " expected something convertible to " + + info.type); } @Override @@ -133,15 +242,26 @@ public Object getValue() throws OrtException { Object carrier = info.makeCarrier(); if (info.getNumElements() > 0) { // If the tensor has values copy them out - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); - } - if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { - // We read the strings out from native code in a flat array and then reshape - // to the desired output shape. - return OrtUtil.reshape((String[]) carrier, info.shape); - } else { - return carrier; + if (info.type == OnnxJavaType.STRING) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape if necessary. 
+ getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier); + if (info.shape.length != 1) { + carrier = OrtUtil.reshape((String[]) carrier, info.shape); + } + } else { + // Wrap ORT owned memory in buffer, otherwise use our reference + Buffer buf; + if (buffer == null) { + buf = castBuffer(getBuffer()); + } else { + buf = castBuffer(duplicate(buffer)); + } + // Copy out buffer into arrays + OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier); + } } + return carrier; } } @@ -175,8 +295,8 @@ public synchronized void close() { public ByteBuffer getByteBuffer() { checkClosed(); if (info.type != OnnxJavaType.STRING) { - ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); - ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); + ByteBuffer buffer = getBuffer(); + ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder()); output.put(buffer); output.rewind(); return output; @@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() { output.rewind(); return output; } else if (info.type == OnnxJavaType.FLOAT16) { - // if it's fp16 we need to copy it out by hand. + // if it's fp16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { - // if it's bf16 we need to copy it out by hand. + // if it's bf16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); @@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray(long apiHandle, long nativeHandle, Object carrier) + private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier) throws OrtException; private native void close(long apiHandle, long nativeHandle); @@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec info); } } else { + Buffer buf; if (info.shape.length == 0) { - data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data); - if (data == null) { + buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data); + if (buf == null) { throw new OrtException( "Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + data); } + } else { + buf = OrtUtil.convertArrayToBuffer(info, data); } return new OnnxTensor( - createTensor( - OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value), + createTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + buf, + 0, + info.type.size * info.numElements, + info.shape, + info.onnxType.value), allocator.handle, - info); + info, + buf, + true); } } else { throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator."); @@ -627,7 +758,26 @@ static OnnxTensor createTensor( */ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape) throws OrtException { - return createTensor(env, env.defaultAllocator, data, shape); + return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16); + } + + /** + * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. + * + *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime + * of the tensor. Uses the default allocator. + * + * @param env The current OrtEnvironment. + * @param data The tensor data. + * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. + * @return An OnnxTensor of the required shape. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + public static OnnxTensor createTensor( + OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { + return createTensor(env, env.defaultAllocator, data, shape, type); } /** @@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( - OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) + OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { if (!allocator.isClosed()) { - OnnxJavaType type = OnnxJavaType.INT16; - return createTensor(type, allocator, data, shape); + if ((type == OnnxJavaType.BFLOAT16) + || (type == OnnxJavaType.FLOAT16) + || (type == OnnxJavaType.INT16)) { + return createTensor(type, allocator, data, shape); + } else { + throw new IllegalArgumentException( + "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer."); + } } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -768,10 +926,6 @@ private static OnnxTensor createTensor( tuple.isCopy); } - private static native long createTensor( - long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType) - throws OrtException; - private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 4f3dee3c00b91..2f44236e4ef67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -26,10 +26,10 @@ public final class OrtUtil { private OrtUtil() {} /** - * Converts an long shape into a int shape. + * Converts a long shape into an int shape. * - *

Validates that the shape has more than 1 elements, less than 9 elements, each element is - * less than {@link Integer#MAX_VALUE} and that each entry is non-negative. + *

Validates that the shape has more than 1 element, less than 9 elements, each element is less + * than {@link Integer#MAX_VALUE} and that each entry is non-negative. * * @param shape The long shape. * @return The int shape. @@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) { } } + /** + * Stores a boxed primitive in a single element buffer of the unboxed type. + * + *

If it's not a boxed primitive then it returns null. + * + * @param javaType The type of the boxed primitive. + * @param data The boxed primitive. + * @return The primitive in a direct buffer. + */ + static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) { + switch (javaType) { + case FLOAT: + { + FloatBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + buf.put(0, (Float) data); + return buf; + } + case DOUBLE: + { + DoubleBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + buf.put(0, (Double) data); + return buf; + } + case BOOL: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0); + return buf; + } + case UINT8: + case INT8: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, (Byte) data); + return buf; + } + case FLOAT16: + case BFLOAT16: + case INT16: + { + ShortBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asShortBuffer(); + buf.put(0, (Short) data); + return buf; + } + case INT32: + { + IntBuffer buf = + ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer(); + buf.put(0, (Integer) data); + return buf; + } + case INT64: + { + LongBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + buf.put(0, (Long) data); + return buf; + } + case STRING: + case UNKNOWN: + default: + return null; + } + } + + /** + * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}. + * + *

Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or + * if the array is ragged. + * + * @param info The tensor info object containing the types and shape of the array. + * @param array The array object. + * @return A direct buffer containing all the elements. + */ + static Buffer convertArrayToBuffer(TensorInfo info, Object array) { + ByteBuffer byteBuffer = + ByteBuffer.allocateDirect((int) info.numElements * info.type.size) + .order(ByteOrder.nativeOrder()); + + Buffer buffer; + switch (info.type) { + case FLOAT: + buffer = byteBuffer.asFloatBuffer(); + break; + case DOUBLE: + buffer = byteBuffer.asDoubleBuffer(); + break; + case BOOL: + case INT8: + case UINT8: + // no-op, it's already a bytebuffer + buffer = byteBuffer; + break; + case BFLOAT16: + case FLOAT16: + case INT16: + buffer = byteBuffer.asShortBuffer(); + break; + case INT32: + buffer = byteBuffer.asIntBuffer(); + break; + case INT64: + buffer = byteBuffer.asLongBuffer(); + break; + case STRING: + case UNKNOWN: + default: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + + fillBufferFromArray(info, array, 0, buffer); + + if (buffer.remaining() != 0) { + throw new IllegalArgumentException( + "Failed to copy all elements into the buffer, expected to copy " + + info.numElements + + " into a buffer of capacity " + + buffer.capacity() + + " but had " + + buffer.remaining() + + " values left over."); + } + buffer.rewind(); + + return buffer; + } + + /** + * Fills the provided buffer with the values from the array, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param array The array object to read from. + * @param curDim The current dimension we're processing. + * @param buffer The buffer to write to. + */ + private static void fillBufferFromArray( + TensorInfo info, Object array, int curDim, Buffer buffer) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.put(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.put(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.put(bArr); + break; + case FLOAT16: + case BFLOAT16: + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.put(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.put(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.put(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + boolBuf.put(boolArr[i] ? 
(byte) 1 : (byte) 0); + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer); + } + } + } + } + + /** + * Fills the provided array with the values from the buffer, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param buffer The buffer to read from. + * @param curDim The current dimension we're processing. + * @param array The array object to write to. + */ + static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT16: + case BFLOAT16: + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.get(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.get(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.get(bArr); + break; + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.get(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.get(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.get(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false. + boolArr[i] = boolBuf.get() != 0; + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i)); + } + } + } + } + /** * Returns expected JDK map capacity for a given size, this factors in the default JDK load factor * diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 1c21387b50455..f3e9f21ef408d 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -323,6 +323,9 @@ public long getNumElements() { * all elements as that's the expected format of the native code. It can be reshaped to the * correct shape using {@link OrtUtil#reshape(String[],long[])}. * + *

For fp16 and bf16 tensors the output carrier type is float, and so this method produces + * multidimensional float arrays. + * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). @@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException { + Arrays.toString(shape)); } switch (type) { + case BFLOAT16: + case FLOAT16: case FLOAT: return OrtUtil.newFloatArray(shape); case DOUBLE: diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 7b26291581395..6a3c279073860 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque return sequenceInfo; } -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) { - int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray); - int64_t consumedSize = inputLength * onnxTypeSize(onnxType); - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray)inputArray; - (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray)inputArray; - (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray)inputArray; - (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray)inputArray; - (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported."); - return -1; - /* - float *floatArr = malloc(sizeof(float) * inputLength); - uint16_t *halfArr = (uint16_t *) outputTensor; - for (uint32_t i = 0; i < inputLength; i++) { - floatArr[i] = convertHalfToFloat(halfArr[i]); - } - jfloatArray typedArr = (jfloatArray) inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr); - free(floatArr); - return consumedSize; - */ - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray)inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray)inputArray; - (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor); - return consumedSize; - } - case 
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported."); - return -1; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray)inputArray; - (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - default: { - throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type."); - return -1; - } - } -} - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray inputObjArr = (jobjectArray)inputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i); - int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. - if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) { int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray); if (outputLength == 0) return 0; @@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT } } -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, - size_t dimensionsRemaining, jarray outputArray) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray outputObjArr = (jobjectArray)outputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i); - int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. 
- if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - jobject tempString = NULL; - // Get the buffer size needed - size_t totalStringLength = 0; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); - if (code != ORT_OK) { - return NULL; - } - - // Create the character and offset buffers, character is one larger to allow zero termination. - char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1)); - if (characterBuffer == NULL) { - throwOrtException(jniEnv, 1, "OOM error"); - } else { - size_t * offsets = malloc(sizeof(size_t)); - if (offsets != NULL) { - // Get a view on the String data - code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1)); - - if (code == ORT_OK) { - size_t curSize = (offsets[0]) + 1; - characterBuffer[curSize-1] = '\0'; - tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer); - } - - free((void*)characterBuffer); - free((void*)offsets); - } - } - - return tempString; -} - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { size_t bufferSize = 16; char * tempBuffer = malloc(bufferSize); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 023bc0c739583..7f41e06371f2a 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor); - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor); - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray); -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index b694f57357bb5..d757bd6281499 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -8,72 +8,6 @@ #include "OrtJniUtil.h" #include "ai_onnxruntime_OnnxTensor.h" -/* - * Class: ai_onnxruntime_OnnxTensor - * Method: createTensor - * Signature: (JJLjava/lang/Object;[JI)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, - jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
- const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Convert type to ONNX C enum - ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); - - // Extract the shape information - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); - - // Create the OrtValue - OrtValue* ortValue = NULL; - OrtErrorCode code = checkOrtStatus(jniEnv, api, - api->CreateTensorAsOrtValue( - allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue - ) - ); - (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - - int failed = 0; - if (code == ORT_OK) { - // Get a reference to the OrtValue's data - uint8_t* tensorData = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData)); - if (code == ORT_OK) { - // Check if we're copying a scalar or not - if (shapeLen == 0) { - // Scalars are passed in as a single element array - int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - // Extract the tensor shape information - JavaTensorTypeShape typeShape; - code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - - if (code == ORT_OK) { - // Copy the java array into the tensor - int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount, - typeShape.dimensions, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - failed = 1; - } - } - } else { - failed = 1; - } - } - - if (failed) { - api->ReleaseValue(ortValue); - ortValue = NULL; - } - - // Return the pointer to the OrtValue - return (jlong) ortValue; -} - /* * Class: ai_onnxruntime_OnnxTensor * Method: createTensorFromBuffer @@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer size_t sizeBytes = typeShape.elementCount * typeSize; uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr)); if (code == ORT_OK) { return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes); @@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool /* * Class: ai_onnxruntime_OnnxTensor - * Method: getArray - * Signature: (JJLjava/lang/Object;)V + * Method: getStringArray + * Signature: (JJ[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { - uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr)); - if (code == ORT_OK) { - copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount, - typeShape.dimensions, (jarray)carrier); - } + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here."); } } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 11141a3a65a3e..7cb6305923279 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException { container.put("wrong_name", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect name."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unknown input name")); + } finally { + OnnxValue.close(container.values()); } } } @@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException { container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect type."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected input data type")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongSizeInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + float[] wrongSizeData = Arrays.copyOf(inputData, 2 * 224 * 224); + Object tensor = OrtUtil.reshape(wrongSizeData, new long[] {1, 2, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Got invalid dimensions for input")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongRankInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + Object tensor = OrtUtil.reshape(inputData, new long[] {1, 1, 3, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Invalid rank for input")); + } finally { + OnnxValue.close(container.values()); } } } @@ 
-550,12 +595,12 @@ public void throwExtraInputs() throws OrtException { container.put("extra", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for too many inputs."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected number of inputs")); + } finally { + OnnxValue.close(container.values()); } } } @@ -565,12 +610,11 @@ public void testMultiThreads() throws OrtException, InterruptedException { int numThreads = 10; int loop = 10; SqueezeNetTuple tuple = openSessionSqueezeNet(); + Map container = new HashMap<>(); try (OrtSession session = tuple.session) { - float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - Map container = new HashMap<>(); long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; Object tensor = OrtUtil.reshape(inputData, inputShape); container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); @@ -592,8 +636,9 @@ public void testMultiThreads() throws OrtException, InterruptedException { } executor.shutdown(); executor.awaitTermination(1, TimeUnit.MINUTES); - OnnxValue.close(container.values()); assertTrue(executor.isTerminated()); + } finally { + OnnxValue.close(container.values()); } } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index ea210d96c1507..064f14f3b51ff 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -12,8 +12,11 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.IntBuffer; import java.nio.ShortBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.SplittableRandom; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -93,30 +96,108 @@ public void testScalarCreation() throws OrtException { } @Test - public void testBufferCreation() throws OrtException { + public void testArrayCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Test creating a value from an array - // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + // Test creating a value from a single dimensional array float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { - // array creation isn't backed by buffers - assertFalse(t.ownsBuffer()); - assertFalse(t.getBufferRef().isPresent()); - FloatBuffer buf = t.getFloatBuffer(); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); float[] output = new float[arrValues.length]; buf.get(output); Assertions.assertArrayEquals(arrValues, output); - // Can't modify the tensor through this buffer. + // Can modify the tensor through this buffer. 
buf.put(0, 25); - Assertions.assertArrayEquals(arrValues, output); + Assertions.assertArrayEquals(new float[] {25, 1, 2, 3, 4}, (float[]) t.getValue()); } + // Test creating a value from a multidimensional float array + float[][][] arr3dValues = + new float[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, arr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + float[][][] output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + + // Can modify the tensor through the buffer. + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); + buf.put(0, 25); + buf.put(12, 32); + buf.put(13, 33); + buf.put(23, 35); + arr3dValues[0][0][0] = 25; + arr3dValues[2][0][0] = 32; + arr3dValues[2][0][1] = 33; + arr3dValues[3][1][2] = 35; + output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + } + + // Test creating a value from a multidimensional int array + int[][][] iArr3dValues = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, iArr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + int[][][] output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + + // Can modify the tensor through the buffer. + IntBuffer buf = (IntBuffer) t.getBufferRef().get(); + buf.put(0, 25); + iArr3dValues[0][0][0] = 25; + output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + } + + // Test creating a value from a ragged array throws + int[][][] ragged = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}}, + {{12, 13}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, ragged)) { + Assertions.fail("Can't create tensors from ragged arrays"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("ragged")); + } + + // Test creating a value from a non-array, non-primitive type throws. 
+ List list = new ArrayList<>(5); + list.add(5); + try (OnnxTensor t = OnnxTensor.createTensor(env, list)) { + Assertions.fail("Can't create tensors from lists"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("Cannot convert")); + } + } + + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Test creating a value from a non-direct byte buffer // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap - // direct byte buffers - // which can be directly passed to ORT + // direct byte buffers which can be directly passed to ORT + float[] arrValues = new float[] {0, 1, 2, 3, 4}; FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); nonDirectBuffer.put(arrValues); nonDirectBuffer.rewind(); @@ -335,10 +416,12 @@ public void testFp32ToFp16() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -347,6 +430,8 @@ public void testFp32ToFp16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToFp16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.fp16ToFloat(Fp16Conversions.floatToFp16(input[i][j])); } } floatBuf.rewind(); @@ -354,25 +439,31 @@ public void testFp32ToFp16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound fp16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } @@ -382,10 +473,12 @@ public void testFp32ToBf16() throws OrtException { String 
modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -394,6 +487,8 @@ public void testFp32ToBf16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToBf16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.bf16ToFloat(Fp16Conversions.floatToBf16(input[i][j])); } } floatBuf.rewind(); @@ -401,25 +496,31 @@ public void testFp32ToBf16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound bf16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } From ce13f651d86952335a126f04e741d68bc41323fa Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 25 Sep 2024 03:25:20 +1200 Subject: [PATCH 10/13] Fix NaN propagation for float16 min and max operators (#22161) This makes min and max with NaN for either operand always return NaN for float16 data, matching the behaviour of float and double. The behaviour for floats and doubles was previously fixed for the CPU provider in #21492 and the CUDA provider in #19984, but these PRs didn't fix the behaviour for float16 due to tests causing asan errors. The memory access violations with float16 data have now been fixed in #22135, so this PR is a follow up to make float16 min and max behave the same as float and double for both the CPU and CUDA providers now that we can add tests for this. 
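
For illustration, the target semantics (a NaN in either operand always propagates to the output) can be sketched in a few lines of plain C++ using `float`. This is only a sketch with hypothetical helper names (`MinPropagateNaN`/`MaxPropagateNaN`), not the actual kernels, which use Eigen element-wise min/max in the CPU provider and `__hmin_nan`/`__hmax_nan` (or explicit NaN checks on older architectures/toolkits) in the CUDA provider:

```cpp
// Standalone sketch of the NaN-propagating Min/Max semantics (illustration only,
// not ONNX Runtime code). Helper names are hypothetical.
#include <cassert>
#include <cmath>
#include <limits>

// If either operand is NaN the result is NaN, otherwise the smaller value.
float MinPropagateNaN(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? std::numeric_limits<float>::quiet_NaN()
                                          : (a < b ? a : b);
}

// If either operand is NaN the result is NaN, otherwise the larger value.
float MaxPropagateNaN(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? std::numeric_limits<float>::quiet_NaN()
                                          : (a > b ? a : b);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  assert(std::isnan(MinPropagateNaN(nan, 1.0f)));  // NaN in either operand propagates
  assert(std::isnan(MaxPropagateNaN(0.5f, nan)));  // regardless of operand order
  assert(MinPropagateNaN(-1.0f, 0.25f) == -1.0f);  // ordinary comparison otherwise
  assert(MaxPropagateNaN(-1.0f, 0.25f) == 0.25f);
  return 0;
}
```

Note that this intentionally differs from `std::fmin`/`std::fmax`, which return the non-NaN operand when only one input is NaN; with this change the float16/bfloat16 kernels match the NaN-propagating behaviour already used for float and double.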
### Motivation and Context Relevant previous issues (not float16 specific): * #21455 * https://github.com/onnx/onnx/issues/6003 --- .../providers/cpu/math/element_wise_ops.cc | 16 +- .../core/providers/cuda/cu_inc/common.cuh | 61 ++++- .../cpu/math/element_wise_ops_test.cc | 209 +++++++++++------- 3 files changed, 191 insertions(+), 95 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 91717486b77cb..a78ff69e5c894 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -757,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.min(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template min( + static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.max(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template max( + static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -772,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template min( + static_cast(per_iter_bh.ScalarInput1())); } else { - output_vec_map = input_0_vec_map.max(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template max( + static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -790,9 +794,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(input_1_vec_map); + output_vec_map = input_0_vec_map.template min(input_1_vec_map); } else { - output_vec_map = input_0_vec_map.max(input_1_vec_map); + output_vec_map = input_0_vec_map.template max(input_1_vec_map); } }}; diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index db36754319309..55935a9eae86d 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -10,13 +10,10 @@ #include #include #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/shared_inc/cuda_call.h" -#if CUDA_VERSION >= 11000 -#include -#endif - namespace onnxruntime { namespace cuda { @@ -347,6 +344,21 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } +#define ISNAN_HALF(v__) static_cast(*reinterpret_cast(&v__) & ~MLFloat16::kSignMask) \ + > MLFloat16::kPositiveInfinityBits + +#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ + > BFloat16::kPositiveInfinityBits + +// CUDART_NAN_BF16 and CUDART_NAN_FP16 constants were only added in CUDA 12.2, +// so define our own equivalent constants to support older versions. 
+// Note that there is no consistent canonical NaN for FP16 and BF16; +// CUDA uses 0x7FFF for both, but ONNX Runtime uses 0x7E00 and 0x7FC1 +// for FP16 and BF16 respectively +// (see Float16Impl::kPositiveQNaNBits and BFloat16Impl::kPositiveQNaNBits). +#define NAN_HALF __ushort_as_half((unsigned short)0x7FFFU) +#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) + template __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } @@ -360,6 +372,24 @@ __device__ __inline__ double _Min(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); } +template <> +__device__ __inline__ half _Min(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a < b ? a : b); +#else + return __hmin_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); +#else + return BFloat16(__hmin_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + template __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } @@ -373,6 +403,29 @@ __device__ __inline__ double _Max(double a, double b) { return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); } +template <> +__device__ __inline__ half _Max(half a, half b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_HALF(a) || ISNAN_HALF(b)) ? NAN_HALF : (a > b ? a : b); +#else + return __hmax_nan(a, b); +#endif +} + +template <> +__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) && ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); +#else + return BFloat16(__hmax_nan((__nv_bfloat16)a, (__nv_bfloat16)b)); +#endif +} + +#undef ISNAN_HALF +#undef ISNAN_BFLOAT16 +#undef NAN_HALF +#undef NAN_BFLOAT16 + template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? 
a : -a; } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb914646942fe..507ed8e91a728 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1787,54 +1787,90 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddOutput("min", {3, 3}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, - -1.0f, -1.0f, -2.0f, - 0.5f, 0.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { +void TestFloat16MinMax( + const char* op_name, + const std::vector& lhs_dim, + const std::initializer_list& lhs_values, + const std::vector& rhs_dim, + const std::initializer_list& rhs_values, + const std::vector& out_dim, + const std::initializer_list& out_values) { + { std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); + if (nullptr != DefaultCpuExecutionProvider()) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (nullptr != DefaultCudaExecutionProvider()) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeMLFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeMLFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeMLFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} -TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) { - OpTester test("Min", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 4}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f, - -0.5f, 0.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 2.0f, 1.5f})); - test.AddOutput("min", {3, 4}, - MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f, - -1.0f, -1.0f, -2.0f, -1.25f, - 0.5f, 0.0f, 1.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + OpTester test(op_name, 13); + test.AddInput("data_0", lhs_dim, MakeBFloat16(lhs_values)); + test.AddInput("data_1", rhs_dim, MakeBFloat16(rhs_values)); + test.AddOutput("output", out_dim, MakeBFloat16(out_values)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } +TEST(MathOpTest, Min_13_Float16_MatrixVector) { + TestFloat16MinMax("Min", + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + {0.0f, 0.0f, 0.0f, + -1.0f, -1.0f, -2.0f, + 0.5f, 0.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_VectorMatrix) 
{ + TestFloat16MinMax("Min", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 4}, + {1.0f, 1.0f, 1.0f, -1.0f, + -0.5f, 0.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 2.0f, 1.5f}, + {3, 4}, + {0.0f, 0.0f, 0.0f, -1.0f, + -1.0f, -1.0f, -2.0f, -1.25f, + 0.5f, 0.0f, 1.0f, 1.0f}); +} + +TEST(MathOpTest, Min_13_Float16_Nan) { + TestFloat16MinMax("Min", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Min_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Min", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 0.25f}); +} + +TEST(MathOpTest, Min_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Min", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); +} TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -2185,54 +2221,57 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) { - OpTester test("Max", 12); - test.AddInput("data_0", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.0f, 0.5f, 0.75f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {4, 1}, - MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f})); - test.AddOutput("max", {4, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 0.5f, 0.5f, 0.75f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) { - OpTester test("Max", 12); - test.AddInput("data_0", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddInput("data_1", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddOutput("max", {3, 3}, - MakeMLFloat16({1.0f, 1.0f, 1.0f, - -0.5f, 0.0f, -1.0f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - if (nullptr != DefaultCudaExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } +TEST(MathOpTest, Max_13_Float16_MatrixVector) { + TestFloat16MinMax("Max", + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.0f, 0.5f, 0.75f, + 0.5f, 0.0f, 2.0f}, + {4, 1}, {0.0f, -1.0f, 0.5f, 1.0f}, + {4, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 0.5f, 0.5f, 0.75f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_VectorMatrix) { + TestFloat16MinMax("Max", + {3, 1}, {0.0f, -1.0f, 1.0f}, + {3, 3}, + 
{1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f}, + {3, 3}, + {1.0f, 1.0f, 1.0f, + -0.5f, 0.0f, -1.0f, + 1.0f, 1.0f, 2.0f}); +} + +TEST(MathOpTest, Max_13_Float16_Nan) { + TestFloat16MinMax("Max", + {4, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f, 0.5f}, + {4, 1}, {0.5f, 1.0f, 0.25f, std::numeric_limits::quiet_NaN()}, + {4, 1}, + {0.5f, std::numeric_limits::quiet_NaN(), 1.0f, std::numeric_limits::quiet_NaN()}); +} + +TEST(MathOpTest, Max_13_Float16_Nan_with_scalar) { + TestFloat16MinMax("Max", + {3, 1}, {-1.0f, std::numeric_limits::quiet_NaN(), 1.0f}, + {1}, {0.25f}, + {3, 1}, {0.25f, std::numeric_limits::quiet_NaN(), 1.0f}); +} + +TEST(MathOpTest, Max_13_Float16_with_scalar_Nan) { + TestFloat16MinMax("Max", + {3, 1}, {-0.5f, 1.0f, 1.5f}, + {1}, {std::numeric_limits::quiet_NaN()}, + {3, 1}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); } TEST(MathOpTest, Not) { From 209ff86d5238df63e5b242e046aa5ca22a6303ff Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 24 Sep 2024 08:33:03 -0700 Subject: [PATCH 11/13] Get build working on Xcode 16 (#22168) --- .github/workflows/mac.yml | 389 +++++++++--------- cmake/CMakeLists.txt | 6 +- .../external/onnxruntime_external_deps.cmake | 1 + cmake/onnxruntime_config.h.in | 1 + cmake/onnxruntime_unittests.cmake | 2 - cmake/patches/nsync/nsync_1.26.0.patch | 14 + .../core/common/eigen_common_wrapper.h | 6 + .../contrib_ops/cpu/bert/embed_layer_norm.cc | 7 +- onnxruntime/core/session/inference_session.cc | 3 +- onnxruntime/test/optimizer/optimizer_test.cc | 2 - 10 files changed, 231 insertions(+), 200 deletions(-) create mode 100644 cmake/patches/nsync/nsync_1.26.0.patch diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 6efa8a5592337..aecc05c91d736 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -1,192 +1,197 @@ -name: Mac_CI - -on: - push: - branches: - - main - - rel-* - pull_request: - branches: - - main - - rel-* - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - python_version: 3.11 - xcode_version: 15.2 - -jobs: - ARM64: - runs-on: macos-14 - - timeout-minutes: 60 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Verify ARM64 machine - shell: python - run: | - import platform - assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." 
- - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Build and test - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --update \ - --build --parallel \ - --test \ - --build_shared_lib \ - --build_objc \ - --use_coreml \ - --use_xnnpack \ - --use_binskim_compliant_compile_flags - - Vcpkg: - runs-on: macos-13 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: "Run vcpkg(x64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" - env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "x64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run compile_schema.py" - run: | - # Runner's host triplet should be x64-osx or arm64-osx - export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" - export PATH="$FLATC_DIR:$PATH" - flatc --version - python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" - - - name: "Detect protoc" - id: protoc-detect - run: | - export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" - export PATH="$PROTOC_DIR:$PATH" - protoc --version - echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" - - - name: "Run build.py(x64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/x64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch x86_64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - - name: "Run vcpkg(arm64-osx)" - uses: lukka/run-vcpkg@v11 - with: - vcpkgDirectory: "${{ runner.temp }}/vcpkg" - vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 - runVcpkgInstall: true - vcpkgJsonGlob: "cmake/vcpkg.json" - vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" - env: - VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" - VCPKG_DEFAULT_TRIPLET: "arm64-osx" - # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching - - - name: "Run build.py(arm64-osx)" - run: | - python ./tools/ci_build/build.py \ - --build_dir "build/arm64-osx" \ - --skip_submodule_sync \ - --skip_tests \ - --compile_no_warning_as_error \ - --parallel \ - --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ - --osx_arch arm64 \ - --use_vcpkg \ - --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ - --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ - --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ - --cmake_extra_defines 
"VCPKG_INSTALL_OPTIONS=--x-feature=tests" - shell: bash - - Objective-C-StaticAnalysis: - runs-on: macos-14 - - timeout-minutes: 30 - - steps: - - uses: actions/setup-python@v5 - with: - python-version: ${{ env.python_version }} - - - name: Use Xcode ${{ env.xcode_version }} - shell: bash - run: | - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - - uses: actions/checkout@v4 - - - name: Generate compile_commands.json and ONNX protobuf files - shell: bash - run: | - python ./tools/ci_build/build.py \ - --build_dir ./build \ - --cmake_generator "Unix Makefiles" \ - --config Debug \ - --build_shared_lib \ - --use_coreml \ - --build_objc \ - --enable_training_apis \ - --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ - --use_binskim_compliant_compile_flags \ - --update \ - --build --parallel \ - --target onnx_proto - - - name: Analyze Objective-C/C++ source code - shell: bash - run: | - CLANG_TIDY_CHECKS="-*,clang-analyzer-*" - - "$(brew --prefix llvm@15)/bin/clang-tidy" \ - -p=./build/Debug \ - --checks="${CLANG_TIDY_CHECKS}" \ - --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ - --header-filter="objectivec/include|objectivec|onnxruntime/core" \ - ./objectivec/*.mm \ - ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ - ./onnxruntime/core/providers/coreml/model/*.mm +name: Mac_CI + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +env: + python_version: 3.11 + +jobs: + ARM64: + runs-on: macos-14 + + env: + xcode_version: 16 + + timeout-minutes: 60 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Verify ARM64 machine + shell: python + run: | + import platform + assert platform.machine() == "arm64", "This job expects to be run on an ARM64 machine." 
+ + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Build and test + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --update \ + --build --parallel \ + --test \ + --build_shared_lib \ + --build_objc \ + --use_coreml \ + --use_xnnpack \ + --use_binskim_compliant_compile_flags + + Vcpkg: + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: "Run vcpkg(x64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-osx or arm64-osx + export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" + export PATH="$FLATC_DIR:$PATH" + flatc --version + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" + + - name: "Detect protoc" + id: protoc-detect + run: | + export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" + export PATH="$PROTOC_DIR:$PATH" + protoc --version + echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" + + - name: "Run build.py(x64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/x64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch x86_64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + - name: "Run vcpkg(arm64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/arm64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch arm64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines 
"VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + Objective-C-StaticAnalysis: + runs-on: macos-14 + + env: + xcode_version: 15.2 + + timeout-minutes: 30 + + steps: + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: Use Xcode ${{ env.xcode_version }} + shell: bash + run: | + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.xcode_version }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + + - uses: actions/checkout@v4 + + - name: Generate compile_commands.json and ONNX protobuf files + shell: bash + run: | + python ./tools/ci_build/build.py \ + --build_dir ./build \ + --cmake_generator "Unix Makefiles" \ + --config Debug \ + --build_shared_lib \ + --use_coreml \ + --build_objc \ + --enable_training_apis \ + --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ + --use_binskim_compliant_compile_flags \ + --update \ + --build --parallel \ + --target onnx_proto + + - name: Analyze Objective-C/C++ source code + shell: bash + run: | + CLANG_TIDY_CHECKS="-*,clang-analyzer-*" + + "$(brew --prefix llvm@15)/bin/clang-tidy" \ + -p=./build/Debug \ + --checks="${CLANG_TIDY_CHECKS}" \ + --warnings-as-errors="${CLANG_TIDY_CHECKS}" \ + --header-filter="objectivec/include|objectivec|onnxruntime/core" \ + ./objectivec/*.mm \ + ./onnxruntime/core/platform/apple/logging/apple_log_sink.mm \ + ./onnxruntime/core/providers/coreml/model/*.mm diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 246675b72f4e6..7168a99fe1f93 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -642,10 +642,12 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) check_cxx_compiler_flag(-Wdeprecated-copy HAS_DEPRECATED_COPY) check_cxx_compiler_flag(-Wdeprecated-declarations HAS_DEPRECATED_DECLARATIONS) + check_cxx_compiler_flag(-Wdeprecated-this-capture HAS_DEPRECATED_THIS_CAPTURE) check_cxx_compiler_flag(-Wenum-constexpr-conversion HAS_ENUM_CONSTEXPR_CONVERSION) check_cxx_compiler_flag(-Wformat-truncation HAS_FORMAT_TRUNCATION) check_cxx_compiler_flag(-Wignored-attributes HAS_IGNORED_ATTRIBUTES) @@ -656,15 +658,15 @@ else() check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32) check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING) check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE) check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE) check_cxx_compiler_flag(-Wunused-but-set-parameter HAS_UNUSED_BUT_SET_PARAMETER) check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) - check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) + if(onnxruntime_ENABLE_TRAINING_APIS) - check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) if(HAS_DANGLING_REFERENCE) list(APPEND ORT_WARNING_FLAGS -Wno-dangling-reference) endif() diff --git 
a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 43f18abbe9522..cb737ee53639f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -91,6 +91,7 @@ if (NOT WIN32) google_nsync URL ${DEP_URL_google_nsync} URL_HASH SHA1=${DEP_SHA1_google_nsync} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/nsync/nsync_1.26.0.patch FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync ) #nsync tests failed on Mac Build diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index e3ea767401ddc..bbddefe531cb8 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -9,6 +9,7 @@ #cmakedefine HAS_CLASS_MEMACCESS #cmakedefine HAS_DEPRECATED_COPY #cmakedefine HAS_DEPRECATED_DECLARATIONS +#cmakedefine HAS_DEPRECATED_THIS_CAPTURE #cmakedefine HAS_FORMAT_TRUNCATION #cmakedefine HAS_IGNORED_ATTRIBUTES #cmakedefine HAS_MAYBE_UNINITIALIZED diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4b880c4437dfd..a4ba85e868896 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -893,8 +893,6 @@ if (MSVC) set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" APPEND PROPERTY COMPILE_OPTIONS "/bigobj") - set_property(SOURCE "${TEST_SRC_DIR}/optimizer/qdq_transformer_test.cc" - APPEND PROPERTY COMPILE_OPTIONS "/bigobj") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() diff --git a/cmake/patches/nsync/nsync_1.26.0.patch b/cmake/patches/nsync/nsync_1.26.0.patch new file mode 100644 index 0000000000000..78ef2b3cb20d4 --- /dev/null +++ b/cmake/patches/nsync/nsync_1.26.0.patch @@ -0,0 +1,14 @@ +diff --git a/public/nsync_atomic.h b/public/nsync_atomic.h +index aebe4f7..466a262 100644 +--- a/public/nsync_atomic.h ++++ b/public/nsync_atomic.h +@@ -45,7 +45,8 @@ NSYNC_CPP_END_ + NSYNC_CPP_START_ + typedef std::atomic nsync_atomic_uint32_; + NSYNC_CPP_END_ +-#define NSYNC_ATOMIC_UINT32_INIT_ ATOMIC_VAR_INIT (0) ++// Replace deprecated ATOMIC_VAR_INIT with std::atomic brace initialization ++#define NSYNC_ATOMIC_UINT32_INIT_ { 0 } + #define NSYNC_ATOMIC_UINT32_LOAD_(p) (std::atomic_load (p)) + #define NSYNC_ATOMIC_UINT32_STORE_(p,v) (std::atomic_store ((p), (uint32_t) (v))) + #define NSYNC_ATOMIC_UINT32_PTR_(p) (p) diff --git a/include/onnxruntime/core/common/eigen_common_wrapper.h b/include/onnxruntime/core/common/eigen_common_wrapper.h index 57599e04037dc..19efa7bcff107 100644 --- a/include/onnxruntime/core/common/eigen_common_wrapper.h +++ b/include/onnxruntime/core/common/eigen_common_wrapper.h @@ -49,6 +49,12 @@ #pragma GCC diagnostic ignored "-Wshorten-64-to-32" #endif +// eigen-src/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h:215:9: +// error: implicit capture of 'this' with a capture default of '=' is deprecated [-Werror,-Wdeprecated-this-capture] +#ifdef HAS_DEPRECATED_THIS_CAPTURE +#pragma GCC diagnostic ignored "-Wdeprecated-this-capture" +#endif + #elif defined(_MSC_VER) // build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): // warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence diff --git a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc index 570f4108c3f62..72adfa025da57 100644 --- 
a/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc @@ -86,6 +86,11 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { std::atomic_bool failed{false}; int n = batch_size * sequence_length; + + // Put epsilon into local variable here to avoid the need to capture 'this' in the TryBatchParallelFor() lambda. + // Using the copy capture default (=) to implicitly capture 'this' is deprecated. + const float epsilon_value = epsilon(); + concurrency::ThreadPool::TryBatchParallelFor( context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) { int word_col_index = input_ids_data[index]; @@ -136,7 +141,7 @@ Status EmbedLayerNorm::Compute(OpKernelContext* context) const { y[i] = a; sum += a * a; } - T e = sqrt(sum / hidden_size + static_cast(epsilon())); + T e = sqrt(sum / hidden_size + static_cast(epsilon_value)); for (int i = 0; i < hidden_size; i++) { y[i] = y[i] / e * gamma_data[i] + beta_data[i]; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9e017df5baa3..83e7596d2f6b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2770,7 +2770,8 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options, if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync"); } - std::function run_fn = [=]() { + std::function run_fn = [run_options, feed_names, feeds, fetch_names, fetches, num_fetches, + callback, user_data, this]() { Status status = Status::OK(); ORT_TRY { if (run_options) { diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 79704f2cc79e3..81c1a4ace1e33 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -24,8 +24,6 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -static const std::string MODEL_FOLDER = "testdata/transform/"; - TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); From 5fa4505d1b035731ed76c4b3440f411766040447 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Tue, 24 Sep 2024 09:37:53 -0700 Subject: [PATCH 12/13] Set enable_htp_fp16_precision default to true (#22186) ### Description Set enable_htp_fp16_precision default to true for HTP backend. 
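Note for users who still want the previous fp32 behavior: the option remains configurable per session through the QNN provider options. The snippet below is a minimal sketch, not part of this change; it assumes the standard Python API for passing execution-provider options, and "model.onnx" is a placeholder float32 model.

```python
import onnxruntime as ort

# Sketch only: "model.onnx" is a placeholder float32 model path.
qnn_options = {
    "backend_path": "QnnHtp.dll",        # use "libQnnHtp.so" on Linux/Android
    "enable_htp_fp16_precision": "0",    # opt back into fp32 now that "1" is the default
}

session = ort.InferenceSession(
    "model.onnx",
    providers=[("QNNExecutionProvider", qnn_options)],
)
```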
--- .../onnxruntime/core/session/onnxruntime_c_api.h | 6 +++--- .../core/providers/qnn/qnn_execution_provider.h | 2 +- onnxruntime/test/onnx/main.cc | 2 +- onnxruntime/test/perftest/command_args_parser.cc | 2 +- onnxruntime/test/providers/qnn/cast_test.cc | 15 +++++++++++---- onnxruntime/test/providers/qnn/clip_op_test.cc | 13 +++++++++++-- onnxruntime/test/providers/qnn/matmul_test.cpp | 12 ++++++++++-- .../test/providers/qnn/simple_op_htp_test.cc | 2 ++ .../test/providers/qnn/transpose_htp_test.cc | 11 +++++++++-- .../test/qnn_ctx_gen/command_args_parser.cc | 2 +- 10 files changed, 50 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a4ec66761c4ba..3aa98bb020452 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3650,10 +3650,10 @@ struct OrtApi { * - "73" * - "75" * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). - "enable_htp_fp16_precision": Only used for float32 model. + "enable_htp_fp16_precision": Used for float32 model for HTP backend. Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - - "0": Default. With fp32 precision. - - "1": With fp16 precision. + - "0": With fp32 precision. + - "1": Default. With fp16 precision. "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context. - "0": Default. Disabled. - "1": Enabled. diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 9cd73edbff0e0..ac9098f907975 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -142,7 +142,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t device_id_ = 0; qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; uint32_t default_rpc_control_latency_ = 0; - bool enable_HTP_FP16_precision_ = false; + bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 924616f49ab25..e8c948ade1068 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -73,7 +73,7 @@ void usage() { "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. 
\n" diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index c1c48d4945a4d..6e811f4596eab 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -98,7 +98,7 @@ namespace perftest { "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index f03782c33c30a..9b83dd281a56d 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -49,7 +49,8 @@ static GetTestModelFn BuildCastTestCase(const std::vector& shape, template static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::TensorProto_DataType dst_type, ExpectedEPNodeAssignment expected_ep_assignment, - bool use_htp) { + bool use_htp, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; @@ -57,6 +58,12 @@ static void RunCastOpTest(const std::vector& shape, ONNX_NAMESPACE::Ten provider_options["backend_path"] = use_htp ? 
"libQnnHtp.so" : "libQnnCpu.so"; #endif + if (use_htp && enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildCastTestCase(shape, dst_type), provider_options, 13, // opset @@ -93,19 +100,19 @@ TEST_F(QnnCPUBackendTests, TestCastFloatToInt32) { // Cast int32_t to float on HTP TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast uint8_t to float on HTP TEST_F(QnnHTPBackendTests, TestCastUInt8ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast float to int32_t on HTP TEST_F(QnnHTPBackendTests, TestCastFloatToInt32HTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, ExpectedEPNodeAssignment::All, - true); + true, false); } // Cast int64_t to int32_t on HTP diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc index c3a75fd7446e2..cfa77a46210b3 100644 --- a/onnxruntime/test/providers/qnn/clip_op_test.cc +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -21,7 +21,8 @@ static void RunClipTest(const TestInputDef& input_def, const std::vector>& min_max_defs, ExpectedEPNodeAssignment expected_ep_assignment, bool on_cpu_backend = true, - int opset = 13) { + int opset = 13, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) @@ -30,6 +31,12 @@ static void RunClipTest(const TestInputDef& input_def, provider_options["backend_path"] = on_cpu_backend ? 
"libQnnCpu.so" : "libQnnHtp.so"; #endif + if (!on_cpu_backend && enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), provider_options, opset, @@ -80,7 +87,9 @@ TEST_F(QnnHTPBackendTests, Clip_f32) { {TestInputDef({}, true, {-5.0f}), TestInputDef({}, true, {5.0f})}, ExpectedEPNodeAssignment::All, - on_cpu_backend); + on_cpu_backend, + 13, + false); } // Test Clip with int32 on HTP diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index d8c34d6a6c6ed..708aac03ceb2e 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -117,7 +117,8 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, ExpectedEPNodeAssignment expected_ep_assignment, int opset = 21, bool use_contrib_qdq = false, - QDQTolerance tolerance = QDQTolerance()) { + QDQTolerance tolerance = QDQTolerance(), + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -125,6 +126,12 @@ static void RunQDQPerChannelMatMulOpOpTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + TestQDQModelAccuracy(BuildMatMulOpTestCase(input_def, weights_def), BuildQDQPerChannelMatMulTestCase(input_def, weights_def, @@ -275,7 +282,8 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_AS8_WeightInt4) { ExpectedEPNodeAssignment::All, 21, false, - QDQTolerance(0.007f)); + QDQTolerance(0.007f), + false); } // Test QDQ per-channel MatMul with 16-bit act, int8 weights (static) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2ebc2c6251b44..83899ec6ef17b 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -157,6 +157,8 @@ static void RunOpTest(const std::string& op_type, if (enable_htp_fp16_precision) { provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; // enabled in QNN EP by default } // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. 
diff --git a/onnxruntime/test/providers/qnn/transpose_htp_test.cc b/onnxruntime/test/providers/qnn/transpose_htp_test.cc index 119b8301f36ed..63746e22d214d 100644 --- a/onnxruntime/test/providers/qnn/transpose_htp_test.cc +++ b/onnxruntime/test/providers/qnn/transpose_htp_test.cc @@ -90,7 +90,8 @@ static void RunTransposeQDQTest(const TestInputDef& input_def, template static void RunTransposeNonQDQOnHTP(const TestInputDef& input_def, const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + bool enable_fp16_precision = true) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -98,6 +99,12 @@ static void RunTransposeNonQDQOnHTP(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif + if (enable_fp16_precision) { + provider_options["enable_htp_fp16_precision"] = "1"; + } else { + provider_options["enable_htp_fp16_precision"] = "0"; + } + RunQnnModelTest(BuildTransposeTestCase(input_def, attrs), provider_options, 13, @@ -123,7 +130,7 @@ TEST_F(QnnHTPBackendTests, TransposeInt32OnHTP) { TEST_F(QnnHTPBackendTests, TransposeFloatOnHTP) { RunTransposeNonQDQOnHTP(TestInputDef({1, 3, 224, 128}, false, 0, 10.0f), {utils::MakeAttribute("perm", std::vector{0, 2, 3, 1})}, - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, false); } #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc index 509f56664e572..102846e08ac5f 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc @@ -46,7 +46,7 @@ namespace qnnctxgen { "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" - "\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n" + "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. 
Defaults to '1' (enabled).\n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" From 6cc06ad06967db4611138b4028df9ffacc6eb860 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 24 Sep 2024 09:51:59 -0700 Subject: [PATCH 13/13] GQA MLFloat16 cpu (#22102) ### Description ### Motivation and Context --------- Co-authored-by: Your Name --- docs/OperatorKernels.md | 4 +- .../contrib_ops/cpu/bert/attention_utils.cc | 20 ++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 96 +++++++--- .../cpu/bert/group_query_attention.cc | 24 +-- .../contrib_ops/cpu/bert/rotary_embedding.cc | 46 +++-- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 + .../test/python/transformers/test_gqa_cpu.py | 170 +++++++++--------- 7 files changed, 229 insertions(+), 135 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 407e08c96a891..734506681ab60 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -482,7 +482,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br/> *in* max_length:**I**<br/> *in* min_length:**I**<br/> *in* repetition_penalty:**T**<br/> *in* vocab_mask:**I**<br/> *in* prefix_vocab_mask:**I**<br/> *in* attention_mask:**I**<br/> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br/> *in* Grid:**T1**<br/> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br/> *in* key:**T**<br/> *in* value:**T**<br/> *in* past_key:**T**<br/> *in* past_value:**T**<br/> *in* seqlens_k:**M**<br/> *in* total_sequence_length:**M**<br/> *in* cos_cache:**T**<br/> *in* sin_cache:**T**<br/> *out* output:**T**<br/> *out* present_key:**T**<br/> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br/> *in* key:**T**<br/> *in* value:**T**<br/> *in* past_key:**T**<br/> *in* past_value:**T**<br/> *in* seqlens_k:**M**<br/> *in* total_sequence_length:**M**<br/> *in* cos_cache:**T**<br/> *in* sin_cache:**T**<br/> *out* output:**T**<br/> *out* present_key:**T**<br/> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br/> *in* B:**T2**<br/> *in* absmax:**T1**<br/> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br/> *in* B:**T2**<br/> *in* B_shape:**T3**<br/> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -508,7 +508,7 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br/> *in* y_scale:**T1**<br/> *in* y_zero_point:**T2**<br/> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br/> *in* limit:**T**<br/> *in* delta:**T**<br/> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
-|RotaryEmbedding|*in* input:**T**<br/> *in* position_ids:**M**<br/> *in* cos_cache:**T**<br/> *in* sin_cache:**T**<br/> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
+|RotaryEmbedding|*in* input:**T**<br/> *in* position_ids:**M**<br/> *in* cos_cache:**T**<br/> *in* sin_cache:**T**<br/> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SampleOp|*in* X:**T**<br/> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br/> *in* max_length:**I**<br/> *in* min_length:**I**<br/> *in* repetition_penalty:**T**<br/> *in* vocab_mask:**I**<br/> *in* prefix_vocab_mask:**I**<br/> *in* attention_mask:**I**<br/> *in* presence_mask:**I**<br/> *in* seed:**I**<br/> *out* sequences:**I**<br/> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br/> *in* skip:**T**<br/> *in* gamma:**T**<br/> *in* beta:**T**<br/> *in* bias:**T**<br/> *out* output:**T**<br/> *out* mean:**U**<br/> *out* inv_std_var:**U**<br/>
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc index 7b84971585f9f..c8fe9c77d8ff8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_utils.cc @@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); }}; // For element-wise add // Allocate space for output of Q(BS, D) + bias(D) @@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is constexpr size_t element_size = sizeof(T); ProcessBroadcastSpanFuncs add_funcs{ [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); + per_iter_bh.OutputEigen() = per_iter_bh.ScalarInput0() + per_iter_bh.EigenInput1().array(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0().array() + per_iter_bh.ScalarInput1(); }, [](BroadcastHelper& per_iter_bh) { - per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); + per_iter_bh.OutputEigen() = per_iter_bh.EigenInput0() + per_iter_bh.EigenInput1(); }}; // For element-wise add // Get Q's bias from combined bias @@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); +template Status MaybeTransposeToBNSHAndAddBias(OpKernelContext* context, AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out); + template Status MaybeTransposeToBNSH(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, @@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH(AllocatorPtr allocator, int batch_size, int num_heads, int sequence_length, int head_size, const Tensor* in, OrtValue& out); +template Status MaybeTransposeToBNSH(AllocatorPtr allocator, + int batch_size, int num_heads, int sequence_length, int head_size, + const Tensor* in, OrtValue& out); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index bfec9aef56727..ccaeb6654e286 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -75,7 +75,7 @@ class GQAAttentionBase { int seqlen_present_kv_cache = static_cast(present_key->Shape().GetDims()[2]); 
// Compute the attention score. - size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T); + size_t bytes = SafeInt(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float); auto attention_probs = allocator->Alloc(bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); @@ -87,16 +87,17 @@ class GQAAttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, + seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - is_prompt, tp); + is_prompt, tp, allocator); return Status::OK(); } @@ -106,7 +107,7 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT const T* Q, // Q data. Its size is BxNxSxH const T* K, // k data. Its size is BxNxLxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor @@ -120,7 +121,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - ThreadPool* tp) const { // thread pool + ThreadPool* tp, // thread pool + AllocatorPtr allocator) const { // allocator for temporary buffer const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -131,7 +133,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_key, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -164,7 +168,7 @@ class GQAAttentionBase { const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; - T* output = attention_probs + output_offset; + float* output = attention_probs + output_offset; const T* k; if (packed_qkv) { @@ -190,12 +194,28 @@ class GQAAttentionBase { q = Q + q_input_chunk_length * i; } - math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, - static_cast(present_buffer_sequence_length), nullptr); + if constexpr (std::is_same::value) { + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, + output, static_cast(present_buffer_sequence_length), nullptr); + } else { + size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float); + auto q_k_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator)); + + float* q_fp32 = static_cast(q_k_fp32); + MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length); + + float* k_fp32 = q_fp32 + head_size * sequence_length; + MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen); + + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32, + static_cast(head_size), k_fp32, static_cast(head_size), 0.0f /*bata*/, + output, static_cast(present_buffer_sequence_length), nullptr); + } // compute Softmax - T* output_softmax = output; + float* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { size_t seq_causal_length = past_seqlen + seq + 1; if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { @@ -237,7 +257,7 @@ class GQAAttentionBase { template void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT + const float* attention_probs, // Attention probs with size BxNxSxT const T* V, // V value with size BxN_kvxSxH const int32_t* seqlens_k, // total - 1 sequence lengths tensor const size_t batch_size, // batch size @@ -251,7 +271,8 @@ class GQAAttentionBase { const bool past_present_share_buffer, // whether present key and value share the same buffer const bool packed_qkv, // whether Q, K, V are packed const bool is_prompt, // whether it is prompt - ThreadPool* tp) const { + ThreadPool* tp, + AllocatorPtr allocator) const { const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); @@ -261,7 +282,9 @@ class GQAAttentionBase { const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { - memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); + memset((void*)present_value, + 0, + batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } const size_t loop_len = batch_size * num_heads_; @@ -285,6 +308,13 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; + size_t output_fp32_bytes = 0; + if constexpr (std::is_same::value) { + output_fp32_bytes = SafeInt(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float); + } + auto output_fp32 = allocator->Alloc(output_fp32_bytes); + BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator)); + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const size_t batch_index = i / num_heads_; @@ -305,15 +335,39 @@ class GQAAttentionBase { i / kv_num_heads_factor); } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ - attention_probs + attention_probs_offset, - static_cast(present_buffer_sequence_length), v, static_cast(head_size), - 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + if constexpr (std::is_same::value) { + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, + static_cast(head_size), 0.0f /*beta*/, output_current, + static_cast(hidden_size), nullptr); + } else { + size_t bytes = head_size * total_seqlen * sizeof(float); + auto v_fp32 = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator)); + + float* v_fp32_ptr = static_cast(v_fp32); + MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen); + + float* output_fp32_current = static_cast(output_fp32) + + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v_fp32_ptr, + static_cast(head_size), 0.0f /*beta*/, output_fp32_current, + static_cast(hidden_size), nullptr); + } } }); + + if constexpr (std::is_same::value) { + MlasConvertFloatToHalfBuffer(static_cast(output_fp32), + output, + SafeInt(sequence_length) * batch_size * num_heads_ * head_size); + } } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 2a38e4a1ac636..a1ed35e54b008 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -22,16 +22,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx 
-ONNX_OPERATOR_TYPED_KERNEL_EX( - GroupQueryAttention, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - GroupQueryAttention); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + GroupQueryAttention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 6732f8b96cce2..cbfd2f0949363 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -13,16 +13,20 @@ namespace onnxruntime { namespace contrib { // These ops are internal-only, so register outside of onnx -ONNX_OPERATOR_TYPED_KERNEL_EX( - RotaryEmbedding, - kMSDomain, - 1, - float, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("M", DataTypeImpl::GetTensorType()), - RotaryEmbedding); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { @@ -75,19 +79,27 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const T* sin_data = sin_cache + cache_offset; int cache_idx = 0; - T sign = 0; + bool sign = false; int j = 0; for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { cache_idx = (i / 2) % half_rotary_emb_dim; - sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); - j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + sign = i & 1; + j = sign ? i - 1 : i + 1; // i - sign } else { cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? 
static_cast(-1) : static_cast(1); + sign = (i >= half_rotary_emb_dim); j = (i + half_rotary_emb_dim) % rotary_emb_dim; } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); } for (int i = rotary_emb_dim; i < head_size; i++) { output_data[i] = input_data[i]; @@ -102,6 +114,10 @@ template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryPar const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output, bool interleaved); +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input, + const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache, + MLFloat16* output, bool interleaved); + template Status RotaryEmbedding::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index dcd1f5ec22b52..e75d485830ca5 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -22,8 +22,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GroupQueryAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -288,8 +290,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index dc21d4e4a5890..08ec5de328b9d 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -29,6 +29,12 @@ GREEN = "\033[32m" RESET = "\033[0m" +ORT_TYPE = TensorProto.FLOAT +TORCH_TYPE = torch.float16 if ORT_TYPE == TensorProto.FLOAT16 else torch.float32 +NUMPY_TYPE = numpy.float16 if ORT_TYPE == TensorProto.FLOAT16 else numpy.float32 +RTOL = 3e-2 if 
ORT_TYPE == TensorProto.FLOAT16 else 1e-3 +ATOL = RTOL + class Formats: BSNH = 0 @@ -186,7 +192,7 @@ def create_group_query_attention_graph_prompt( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.q_sequence_length, @@ -212,7 +218,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -221,7 +227,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length, @@ -233,7 +239,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -243,7 +249,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -256,7 +262,7 @@ def create_group_query_attention_graph_prompt( graph_input += [ helper.make_tensor_value_info( "cos_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -264,7 +270,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "sin_cache", - TensorProto.FLOAT, + ORT_TYPE, [ config.buffer_sequence_length if share_buffer else config.kv_sequence_length, (math.floor(config.head_size / 16) * 16) // 2, @@ -275,12 +281,12 @@ def create_group_query_attention_graph_prompt( graph_output = [ helper.make_tensor_value_info( "output", - TensorProto.FLOAT, + ORT_TYPE, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -290,7 +296,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -300,7 +306,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -310,7 +316,7 @@ def create_group_query_attention_graph_prompt( ), helper.make_tensor_value_info( "present_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -378,7 +384,7 @@ def create_group_query_attention_graph_past( graph_input = [ helper.make_tensor_value_info( "query", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, config.sequence_length, @@ -391,7 +397,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_key", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, @@ -401,7 +407,7 @@ def create_group_query_attention_graph_past( ), helper.make_tensor_value_info( "past_value", - TensorProto.FLOAT, + ORT_TYPE, [ config.batch_size, past_kv_seqlen if 
past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -424,7 +430,7 @@ def create_group_query_attention_graph_past(
         graph_input += [
             helper.make_tensor_value_info(
                 "key",
-                TensorProto.FLOAT,
+                ORT_TYPE,
                 [
                     config.batch_size,
                     config.sequence_length,
@@ -433,7 +439,7 @@ def create_group_query_attention_graph_past(
             ),
             helper.make_tensor_value_info(
                 "value",
-                TensorProto.FLOAT,
+                ORT_TYPE,
                 [
                     config.batch_size,
                     config.sequence_length,
@@ -445,7 +451,7 @@ def create_group_query_attention_graph_past(
         graph_input += [
             helper.make_tensor_value_info(
                 "cos_cache",
-                TensorProto.FLOAT,
+                ORT_TYPE,
                 [
                     config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
                     (math.floor(config.head_size / 16) * 16) // 2,
@@ -453,7 +459,7 @@ def create_group_query_attention_graph_past(
             ),
             helper.make_tensor_value_info(
                 "sin_cache",
-                TensorProto.FLOAT,
+                ORT_TYPE,
                 [
                     config.kv_sequence_length + (0 if share_buffer else config.sequence_length),
                     (math.floor(config.head_size / 16) * 16) // 2,
@@ -464,12 +470,12 @@ def create_group_query_attention_graph_past(
     graph_output = [
         helper.make_tensor_value_info(
             "output",
-            TensorProto.FLOAT,
+            ORT_TYPE,
             [config.batch_size, config.sequence_length, config.num_heads * config.head_size],
         ),
         helper.make_tensor_value_info(
             "present_key",
-            TensorProto.FLOAT,
+            ORT_TYPE,
             [
                 config.batch_size,
                 present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -479,7 +485,7 @@ def create_group_query_attention_graph_past(
         ),
         helper.make_tensor_value_info(
             "present_value",
-            TensorProto.FLOAT,
+            ORT_TYPE,
             [
                 config.batch_size,
                 present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads,
@@ -641,7 +647,7 @@ def create_inputs(config: Config, kv_packed=False, qkv_packed=True):
         config.num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     key_padding_mask = generate_random_padding_mask(
@@ -722,13 +728,13 @@ def gqa_prompt_func(
             io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
         io_binding.bind_cpu_input("query", ort_inputs["query"])
         io_binding.bind_input(
-            "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
+            "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
         )
         io_binding.bind_input(
             "past_value",
             "cpu",
             0,
-            numpy.float32,
+            NUMPY_TYPE,
             ort_inputs["past_value"].shape(),
             ort_inputs["past_value"].data_ptr(),
         )
@@ -835,13 +841,13 @@ def gqa_past_func(
             io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"])
         io_binding.bind_cpu_input("query", ort_inputs["query"])
         io_binding.bind_input(
-            "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
+            "past_key", "cpu", 0, NUMPY_TYPE, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr()
        )
         io_binding.bind_input(
             "past_value",
             "cpu",
             0,
-            numpy.float32,
+            NUMPY_TYPE,
             ort_inputs["past_value"].shape(),
             ort_inputs["past_value"].data_ptr(),
         )
@@ -1017,9 +1023,11 @@ def attention_ref(
         attention_drop = attention.masked_fill(~dropout_mask, 0.0)
     else:
         attention_drop = attention
+
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
 
 
@@ -1058,8 +1066,8 @@ def parity_check_gqa_prompt(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    rtol=1e-3,
-    atol=1e-3,
+    rtol=RTOL,
+    atol=ATOL,
 ):
     q = torch.randn(
         config.batch_size,
@@ -1067,7 +1075,7 @@ def parity_check_gqa_prompt(
         config.num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     k = torch.randn(
@@ -1076,7 +1084,7 @@ def parity_check_gqa_prompt(
         config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     v = torch.randn(
@@ -1085,7 +1093,7 @@ def parity_check_gqa_prompt(
         config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_k = torch.randn(
@@ -1094,7 +1102,7 @@ def parity_check_gqa_prompt(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_v = torch.randn(
@@ -1103,7 +1111,7 @@ def parity_check_gqa_prompt(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
 
@@ -1129,8 +1137,8 @@ def parity_check_gqa_prompt(
         rotary_fraction = 1.0
         rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
         angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
-        cos = torch.cos(angle).to(dtype=torch.float32)
-        sin = torch.sin(angle).to(dtype=torch.float32)
+        cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+        sin = torch.sin(angle).to(dtype=TORCH_TYPE)
         rot = LlamaMSRotaryEmbedding()
         q_ro = rot(
             q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved
@@ -1152,8 +1160,8 @@ def parity_check_gqa_prompt(
     kv_seqlens = torch.tensor([config.kv_sequence_length], device="cpu").repeat(config.batch_size)
     kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1")
     update_mask = arange < kv_seqlens_expanded
-    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
-    v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
+    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE)
+    v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(dtype=TORCH_TYPE)
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded
@@ -1218,11 +1226,11 @@ def parity_check_gqa_prompt(
     out = out.detach().cpu().numpy()
 
     # Make sure past-present buffer updating correctly
-    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
-    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
 
     # Compare results
-    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+    all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
     correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
     print(
         "KV-buffer",
@@ -1271,8 +1279,8 @@ def parity_check_gqa_prompt_no_buff(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    rtol=1e-3,
-    atol=1e-3,
+    rtol=RTOL,
+    atol=ATOL,
 ):
     q = torch.randn(
         config.batch_size,
@@ -1280,7 +1288,7 @@ def parity_check_gqa_prompt_no_buff(
         config.num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_k = torch.randn(
@@ -1289,7 +1297,7 @@ def parity_check_gqa_prompt_no_buff(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
    )
     new_v = torch.randn(
@@ -1298,7 +1306,7 @@ def parity_check_gqa_prompt_no_buff(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
 
@@ -1321,8 +1329,8 @@ def parity_check_gqa_prompt_no_buff(
         rotary_fraction = 1.0
         rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
         angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
-        cos = torch.cos(angle).to(dtype=torch.float32)
-        sin = torch.sin(angle).to(dtype=torch.float32)
+        cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+        sin = torch.sin(angle).to(dtype=TORCH_TYPE)
         rot = LlamaMSRotaryEmbedding()
         q_ro = rot(
             q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), rotary_seqlens, rotary_interleaved
@@ -1405,11 +1413,11 @@ def parity_check_gqa_prompt_no_buff(
     out = out.detach().cpu().numpy()
 
     # Make sure past-present buffer updating correctly
-    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
-    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
 
     # Compare results
-    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+    all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
     correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
     print(
         "No buff",
@@ -1458,8 +1466,8 @@ def parity_check_gqa_past(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    rtol=1e-3,
-    atol=1e-3,
+    rtol=RTOL,
+    atol=ATOL,
 ):
     q = torch.randn(
         config.batch_size,
@@ -1467,7 +1475,7 @@ def parity_check_gqa_past(
         config.num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     k = torch.randn(
@@ -1476,7 +1484,7 @@ def parity_check_gqa_past(
         config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     v = torch.randn(
@@ -1485,7 +1493,7 @@ def parity_check_gqa_past(
         config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_k = torch.randn(
@@ -1494,7 +1502,7 @@ def parity_check_gqa_past(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_v = torch.randn(
@@ -1503,7 +1511,7 @@ def parity_check_gqa_past(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
 
@@ -1534,8 +1542,8 @@ def parity_check_gqa_past(
         rotary_fraction = 1.0
         rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
         angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
-        cos = torch.cos(angle).to(dtype=torch.float32)
-        sin = torch.sin(angle).to(dtype=torch.float32)
+        cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+        sin = torch.sin(angle).to(dtype=TORCH_TYPE)
         rot = LlamaMSRotaryEmbedding()
         q_ro = rot(
             q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved
@@ -1624,11 +1632,11 @@ def parity_check_gqa_past(
     out = out.detach().cpu().numpy()
 
     # Make sure past-present buffer updating correctly
-    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
-    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
+    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
+    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=RTOL, atol=ATOL, equal_nan=True)
 
     # Compare results
-    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+    all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
     correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
     print(
         "KV-buffer",
@@ -1677,8 +1685,8 @@ def parity_check_gqa_past_no_buff(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    rtol=1e-3,
-    atol=1e-3,
+    rtol=RTOL,
+    atol=ATOL,
 ):
     torch.manual_seed(69)
     q = torch.randn(
@@ -1687,7 +1695,7 @@ def parity_check_gqa_past_no_buff(
         config.num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     k = torch.randn(
@@ -1696,7 +1704,7 @@ def parity_check_gqa_past_no_buff(
         config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     v = torch.randn(
@@ -1705,7 +1713,7 @@ def parity_check_gqa_past_no_buff(
         config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_k = torch.randn(
@@ -1714,7 +1722,7 @@ def parity_check_gqa_past_no_buff(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
     new_v = torch.randn(
@@ -1723,7 +1731,7 @@ def parity_check_gqa_past_no_buff(
         config.kv_num_heads,
         config.head_size,
         device="cpu",
-        dtype=torch.float32,
+        dtype=TORCH_TYPE,
         requires_grad=False,
     )
 
@@ -1759,8 +1767,8 @@ def parity_check_gqa_past_no_buff(
         angle = (
             torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
         )
-        cos = torch.cos(angle).to(dtype=torch.float32)
-        sin = torch.sin(angle).to(dtype=torch.float32)
+        cos = torch.cos(angle).to(dtype=TORCH_TYPE)
+        sin = torch.sin(angle).to(dtype=TORCH_TYPE)
         rot = LlamaMSRotaryEmbedding()
         q_ro = rot(
             q.clone(), cos.unsqueeze(0).unsqueeze(2), sin.unsqueeze(0).unsqueeze(2), cache_seqlens, rotary_interleaved
@@ -1849,7 +1857,7 @@ def parity_check_gqa_past_no_buff(
     out = out.detach().cpu().numpy()
 
     # Compare results
-    all_close = numpy.allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True)
+    all_close = numpy.allclose(out, out_ref, rtol=RTOL, atol=ATOL, equal_nan=True)
     correct = GREEN + "True" + RESET if all_close else RED + "False" + RESET
     print(
         "NO buff",
@@ -1983,8 +1991,8 @@ def test_gqa_past(self):
                                     config,
                                     local=local,
                                     past_format=past_kv_format,
-                                    rtol=1e-3,
-                                    atol=1e-3,
+                                    rtol=RTOL,
+                                    atol=ATOL,
                                     rotary=rotary,
                                     rotary_interleaved=rotary_interleaved,
                                     packed=packed,
@@ -1996,8 +2004,8 @@ def test_gqa_past(self):
                                     config,
                                     local=local,
                                     past_format=past_kv_format,
-                                    rtol=1e-3,
-                                    atol=1e-3,
+                                    rtol=RTOL,
+                                    atol=ATOL,
                                     rotary=rotary,
                                     rotary_interleaved=rotary_interleaved,
                                     packed=packed,
@@ -2042,8 +2050,8 @@ def test_gqa_interactive_one_batch(self):
                                     config,
                                     local=local,
                                     past_format=past_kv_format,
-                                    rtol=1e-3,
-                                    atol=1e-3,
+                                    rtol=RTOL,
+                                    atol=ATOL,
                                     rotary=rotary,
                                     rotary_interleaved=rotary_interleaved,
                                     packed=packed,
@@ -2053,8 +2061,8 @@ def test_gqa_interactive_one_batch(self):
                                     config,
                                     local=local,
                                     past_format=past_kv_format,
-                                    rtol=1e-3,
-                                    atol=1e-3,
+                                    rtol=RTOL,
+                                    atol=ATOL,
                                     rotary=rotary,
                                     rotary_interleaved=rotary_interleaved,
                                     packed=packed,