Merge branch 'microsoft:main' into main
luoyu-intel authored Jan 16, 2024
2 parents 42f6a9c + 8d4369b commit af23b8e
Showing 30 changed files with 766 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -293,6 +293,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_MEMORY_OPT_LEVEL=0
```
### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, memory-efficient gradient management is turned off. When it is enabled, each gradient triggers the corresponding parameter's backward function through the `PythonOpGrad` operator as soon as it is computed in ONNX Runtime. This releases the gradient buffer managed in ONNX Runtime earlier, instead of only after all backward computation finishes.

```bash
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
63 changes: 55 additions & 8 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -7,6 +7,7 @@
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;
import java.util.stream.Collectors;

/** Describes an {@link OnnxTensor}, including its size, shape and element type. */
public class TensorInfo implements ValueInfo {
@@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The shape of the tensor. */
final long[] shape;

/** The names of the unbound dimensions. */
final String[] dimensionNames;

/** True if any dimension has a non-empty name. */
private final boolean hasNames;

/** The Java type of this tensor. */
public final OnnxJavaType type;

@@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
*/
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
this.shape = shape;
this.dimensionNames = new String[shape.length];
Arrays.fill(dimensionNames, "");
this.hasNames = false;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
@@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param names The dimension names.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
this.dimensionNames = names;
boolean hasNames = false;
for (String s : names) {
if (!s.isEmpty()) {
hasNames = true;
break;
}
}
this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
@@ -206,15 +226,42 @@ public long[] getShape() {
return Arrays.copyOf(shape, shape.length);
}

/**
* Get a copy of the tensor's named dimensions.
*
* @return A copy of the tensor's named dimensions.
*/
public String[] getDimensionNames() {
return Arrays.copyOf(dimensionNames, dimensionNames.length);
}

@Override
public String toString() {
return "TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape)
+ ")";
String output =
"TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape);
if (hasNames) {
output =
output
+ ",dimNames=["
+ Arrays.stream(dimensionNames)
.map(
a -> {
if (a.isEmpty()) {
return "\"\"";
} else {
return a;
}
})
.collect(Collectors.joining(","))
+ "]";
}
output = output + ")";
return output;
}

/**
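For reference, a minimal sketch of how the new `getDimensionNames()` accessor could be consumed from application code once this change ships. The model path is a placeholder; the surrounding calls (`OrtEnvironment.getEnvironment`, `createSession`, `getInputInfo`) are the existing `ai.onnxruntime` API, and dimensions without a symbolic name come back as empty strings.

```java
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.util.Arrays;

public class DimensionNamesExample {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    // "model.onnx" is a placeholder; any model with symbolic (unbound) input dimensions works.
    try (OrtSession session = env.createSession("model.onnx")) {
      for (NodeInfo node : session.getInputInfo().values()) {
        if (node.getInfo() instanceof TensorInfo) {
          TensorInfo info = (TensorInfo) node.getInfo();
          // Dimensions without a symbolic name are reported as empty strings.
          System.out.println(node.getName()
              + " shape=" + Arrays.toString(info.getShape())
              + " dimNames=" + Arrays.toString(info.getDimensionNames()));
        }
      }
    }
  }
}
```

The same names also surface in `TensorInfo.toString()` via the new `dimNames=[...]` segment whenever any name is non-empty.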
26 changes: 22 additions & 4 deletions java/src/main/native/OrtJniUtil.c
@@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
//printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
Expand All @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;

// Create the string array for the names.
const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
if (dimensionNames == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
return NULL;
}
code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
if (code != ORT_OK) {
// extraction failed, exception has been thrown, return to Java.
free(dimensionNames);
return NULL;
}
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
for (size_t i = 0; i < numDim; i++) {
jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
}
free(dimensionNames);

// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V");
//printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V");
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
return tensorInfo;
}

6 changes: 6 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -590,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
assertEquals(2, aInfo.dimensionNames.length);
assertEquals("n", aInfo.dimensionNames[0]);
assertEquals("", aInfo.dimensionNames[1]);
TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
assertEquals(1, bInfo.dimensionNames.length);
assertEquals("m", bInfo.dimensionNames[0]);
}
}
// Check that when the options are assigned it overrides the symbolic dimension
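Following on from the test above, here is a hedged sketch of how the reported dimension names could be paired with the free-dimension override in the Java bindings. It assumes `SessionOptions.setSymbolicDimensionValue` as the override API and a placeholder model whose input `A` has shape `[n, 2]`, mirroring the test model.

```java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.util.Arrays;

public class SymbolicOverrideExample {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
      // Pin the free dimension named "n" to a concrete size before the model is loaded.
      opts.setSymbolicDimensionValue("n", 4);
      // "model.onnx" is a placeholder for a model whose input "A" has shape [n, 2].
      try (OrtSession session = env.createSession("model.onnx", opts)) {
        TensorInfo aInfo = (TensorInfo) session.getInputInfo().get("A").getInfo();
        // With the override applied, the previously symbolic dimension reports a fixed size.
        System.out.println(Arrays.toString(aInfo.getShape())); // expected: [4, 2]
      }
    }
  }
}
```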
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
@@ -100,8 +100,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${calculateAlpha}
${(() => {
if (c != null) {
return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${
c.getByOffset('cOffset')};`;
return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${
dataType}(uniforms.beta) * ${c.getByOffset('cOffset')};`;
}
return '';
})()}
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_frame.cc
@@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span<const int> feed_mlvalue_idxs, gsl::span<con
const std::unordered_map<int, OrtValue>& initializers,
const std::function<bool(const std::string& name)>& is_initializer_sparse_func,
gsl::span<const OrtValue> fetches) {
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Get feed size: ", feeds.size(), " but expected feed size: ",
feed_mlvalue_idxs.size());
ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());

// Need this for sparse conversions in host memory
1 change: 0 additions & 1 deletion onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -329,7 +329,6 @@ static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& prov
static const std::string Any = "any";
static const std::string Gpu = "gpu";
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
static const std::string Any = "any";
static const std::string Npu = "npu";
#endif

5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.cc
@@ -166,11 +166,14 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
// TODO: Remove legacy "type" once all browsers implement the new "dataType".
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
desc.set("dataType", emscripten::val("uint8"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
desc.set("type", emscripten::val("int8"));
desc.set("dataType", emscripten::val("int8"));
return true;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
desc.set("type", emscripten::val("float16"));
desc.set("dataType", emscripten::val("float16"));
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
@@ -101,10 +101,12 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
}
switch (tensor.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
scalar = emscripten::val{*reinterpret_cast<uint8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
scalar = emscripten::val{*reinterpret_cast<int8_t*>(unpacked_tensor.data())};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
scalar = emscripten::val{MLFloat16::FromBits(*reinterpret_cast<uint16_t*>(unpacked_tensor.data())).ToFloat()};
break;
@@ -39,10 +39,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::string operand_type;
switch (to_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
operand_type = "uint8";
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
operand_type = "int8";
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
operand_type = "float16";
break;
@@ -184,10 +184,12 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder,
size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
element_size = sizeof(int8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
element_size = sizeof(uint16_t);
break;
18 changes: 14 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -33,11 +33,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint16_t*>(tensor.buffer))};
@@ -90,11 +93,14 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
emscripten::val view = emscripten::val::undefined();
switch (tensor.tensor_info.data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const int8_t*>(tensor.buffer))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
static_cast<const uint16_t*>(tensor.buffer))};
@@ -168,10 +174,12 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = input_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_inputs_.set(input, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
wnn_inputs_.set(input, emscripten::val::global("Int8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
wnn_inputs_.set(input, emscripten::val::global("Uint16Array").new_(num_elements));
break;
@@ -201,10 +209,12 @@ void Model::AllocateInputOutputBuffers() {
const auto data_type = output_info.data_type;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
wnn_outputs_.set(output, emscripten::val::global("Uint8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
wnn_outputs_.set(output, emscripten::val::global("Int8Array").new_(num_elements));
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
wnn_outputs_.set(output, emscripten::val::global("Uint16Array").new_(num_elements));
break;
11 changes: 9 additions & 2 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -160,12 +160,16 @@ Status ModelBuilder::RegisterInitializers() {
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
desc.set("type", emscripten::val("uint8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
desc.set("type", emscripten::val("int8"));
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
@@ -318,11 +322,14 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t),
reinterpret_cast<const uint8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t),
reinterpret_cast<const int8_t*>(dest))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t),
reinterpret_cast<const uint16_t*>(dest))};
(Diffs for the remaining changed files are not shown.)
