From 36b5f245779ff14eb53bc927b2e1b1dd2e4038a9 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 22 Oct 2019 17:39:15 -0700 Subject: [PATCH 1/6] C Api for simplebind, fix comment for trigoops, add atol to assert --- include/mxnet/c_api.h | 38 +++++++++++ python/mxnet/symbol/symbol.py | 110 ++++++++++++++++++++---------- src/c_api/c_api_executor.cc | 105 ++++++++++++++++++++++++++++ tests/nightly/test_large_array.py | 6 +- 4 files changed, 220 insertions(+), 39 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 177ec5d40146..dbe7e02eb80f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2255,6 +2255,44 @@ MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out); + + +MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int64_t* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); + + /*! * \brief DEPRECATED. Use MXExecutorReshapeEx instead. * Return a new executor with the same symbol and shared memory, diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index b8e8db57188c..a251f62e69e0 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1695,42 +1695,80 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, aux_state_handles = ctypes.POINTER(NDArrayHandle)() try: - check_call(_LIB.MXExecutorSimpleBindEx(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_int, - array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('i', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + if sys.version_info[0] > 2 and _int64_enabled(): + check_call(_LIB.MXExecutorSimpleBindEx64(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int64, + array('q', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + else: + check_call(_LIB.MXExecutorSimpleBindEx(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int, + array('I', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) except MXNetError as e: error_msg = "simple_bind error. Arguments:\n" for k, v in kwargs.items(): diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ff85b4fd62fa..25ac903a4dcd 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -586,6 +586,111 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out) { + return SimpleBindExMaster(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out) +} + + +int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int64_t* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + return SimpleBindExMaster(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out) +} + + +template +int SimpleBindExMaster(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const DType* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Symbol *sym = static_cast(symbol_handle); diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 9c2fbd61a9ac..9a90920cb044 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1295,17 +1295,17 @@ def check_trunc(): def create_input_for_trigonometric_ops(vals): - # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using tile operator + # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using broadcast_to operator inp = nd.array(vals).reshape(1, 5) inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10)) return inp -def assert_correctness_of_trigonometric_ops(output, expected_vals): +def assert_correctness_of_trigonometric_ops(output, expected_vals, atol=1e-3): # checks verifies 5 values at positions(0, 1, -3, -2, -1) of the input vector output_idx_to_inspect = [0, 1, -3, -2, -1] for i in range(len(output_idx_to_inspect)): - assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= 1e-3 + assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= atol def test_trigonometric_ops(): From ff72b205861aeef9860db20bd3e9509ef2993ed4 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 23 Oct 2019 10:33:20 -0700 Subject: [PATCH 2/6] fix build issues --- src/c_api/c_api_executor.cc | 284 ++++++++++++++++++------------------ 1 file changed, 145 insertions(+), 139 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 25ac903a4dcd..27b981cd0b0e 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -515,146 +515,8 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, API_END(); } -/*! - * \brief - * \param symbol_handle symbol handle - * \param dev_type default device type - * \param dev_id default device id - * \param num_g2c_keys number of group2ctx keys - * \param g2c_keys key list of group2ctx - * \param g2c_dev_types device type list of group2ctx - * \param g2c_dev_ids id list of group2ctx - * \param provided_grad_req_list_len grad_req length provided by users in front-end - * \param provided_grad_req_names grad_req names provided by users in front-end - * \param provided_grad_req_types req types provided by users in front-end - * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes - * \param provided_arg_shape_names name list of provided shapes - * \param provided_arg_shape_data provided shape data - * \param provided_arg_shape_idx provided shape data index - * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes - * \param provided_arg_dtype_names argument name list of provided dtypes - * \param provided_arg_dtypes data of provided dtypes - * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types - * \param provided_arg_stype_names argument name list of provided storage types - * \param provided_arg_stypes data of provided storage types - * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec - * \param shared_arg_name_list parameter name list passed from _bind_ith_exec - * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec - * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec - * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec - * \param updated_shared_buffer_name_list updated shared data array names after binding - * \param updated_shared_buffer_handle_list updated shared data arrays after binding - * \param num_in_args number of input arguments of this sym - * \param in_args list_arguments associated with the current executor - * \param arg_grads list of gradients of in_args associated with the current executor - * \param num_aux_states number of aux states of this sym - * \param aux_states list_auxiliary_states associated with the current executor - * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec - * \param out the handle of the executor to be created - */ -int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const uint32_t num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - const uint32_t provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - const uint32_t num_provided_arg_shapes, - const char** provided_arg_shape_names, - const int* provided_arg_shape_data, - const uint32_t* provided_arg_shape_idx, - const uint32_t num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - const uint32_t num_provided_arg_stypes, - const char** provided_arg_stype_names, - const int* provided_arg_stypes, - const uint32_t num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - uint32_t* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - uint32_t* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out) { - return SimpleBindExMaster(symbol_handle, - dev_type, dev_id, - num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, - provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, - num_provided_arg_shapes, provided_arg_shape_names, - provided_arg_shape_data, provided_arg_shape_idx, - num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, - num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, - num_shared_arg_names, shared_arg_name_list, - shared_buffer_len, shared_buffer_name_list, - shared_buffer_handle_list, updated_shared_buffer_name_list, - updated_shared_buffer_handle_list, - num_in_args, in_args, arg_grads, - num_aux_states, aux_states, - shared_exec_handle, out) -} - - -int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const uint32_t num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - const uint32_t provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - const uint32_t num_provided_arg_shapes, - const char** provided_arg_shape_names, - const int64_t* provided_arg_shape_data, - const uint32_t* provided_arg_shape_idx, - const uint32_t num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - const uint32_t num_provided_arg_stypes, - const char** provided_arg_stype_names, - const int* provided_arg_stypes, - const uint32_t num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - uint32_t* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - uint32_t* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out) { - return SimpleBindExMaster(symbol_handle, - dev_type, dev_id, - num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, - provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, - num_provided_arg_shapes, provided_arg_shape_names, - provided_arg_shape_data, provided_arg_shape_idx, - num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, - num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, - num_shared_arg_names, shared_arg_name_list, - shared_buffer_len, shared_buffer_name_list, - shared_buffer_handle_list, updated_shared_buffer_name_list, - updated_shared_buffer_handle_list, - num_in_args, in_args, arg_grads, - num_aux_states, aux_states, - shared_exec_handle, out) -} +namespace mxnet { template int SimpleBindExMaster(SymbolHandle symbol_handle, @@ -954,6 +816,150 @@ int SimpleBindExMaster(SymbolHandle symbol_handle, API_END(); } +} // namespace mxnet + + +/*! + * \brief + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + return mxnet::SimpleBindExMaster(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out); +} + + +int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int64_t* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + return mxnet::SimpleBindExMaster(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out); +} + + int MXExecutorReshape(int partial_shaping, int allow_up_sizing, int dev_type, From 5ee7971316e106b17ae9dbbb5ce0996374b69f46 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 23 Oct 2019 11:02:49 -0700 Subject: [PATCH 3/6] fix lint and add regression test --- src/c_api/c_api_executor.cc | 8 +++++--- tests/nightly/test_large_vector.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 27b981cd0b0e..2c911ee2168c 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -816,7 +816,7 @@ int SimpleBindExMaster(SymbolHandle symbol_handle, API_END(); } -} // namespace mxnet +} // namespace mxnet /*! @@ -893,7 +893,8 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, return mxnet::SimpleBindExMaster(symbol_handle, dev_type, dev_id, num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, - provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + provided_grad_req_list_len, provided_grad_req_names, + provided_grad_req_types, num_provided_arg_shapes, provided_arg_shape_names, provided_arg_shape_data, provided_arg_shape_idx, num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, @@ -945,7 +946,8 @@ int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, return mxnet::SimpleBindExMaster(symbol_handle, dev_type, dev_id, num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, - provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types, + provided_grad_req_list_len, provided_grad_req_names, + provided_grad_req_types, num_provided_arg_shapes, provided_arg_shape_names, provided_arg_shape_data, provided_arg_shape_idx, num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 23f4b8e4f310..4981c4207505 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -710,6 +710,39 @@ def test_full(): assert a[-1] == 3 +def test_regression(): + shape = (LARGE_X, ) + + def check_regression(symbol, forward, shape): + # init executor + data_s = mx.symbol.Variable('data') + label_s = mx.symbol.Variable('label') + out_s = symbol(data=data_s, label=label_s) + exe = out_s.simple_bind(ctx=mx.cpu(0), data=shape, label=shape) + + arg_map = dict(zip(out_s.list_arguments(), exe.arg_arrays)) + + # init data + data = mx.random.uniform(-1, -1, shape) + arg_map["data"][:] = data + atol = 1e-5 + density = 0.5 + stype = 'default' + label = arg_map["label"] + label[:] = rand_ndarray(shape, stype, density=density) + exe.forward(is_train=True) + exe.backward() + np_out = forward(data.asnumpy()) + assert_almost_equal(exe.outputs[0].asnumpy(), np_out, atol=atol) + + check_regression(mx.symbol.LogisticRegressionOutput, + lambda x: 1.0 / (1.0 + np.exp(-x)), + shape) + check_regression(mx.symbol.LinearRegressionOutput, + lambda x: x, + shape) + + def test_astype(): x = create_vector(size=LARGE_X//4) x = nd.tile(x, 4) From 5e43d7e06f469fab17dce39e48a1e82a3ca92e18 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 23 Oct 2019 12:15:34 -0700 Subject: [PATCH 4/6] fix indent --- python/mxnet/symbol/symbol.py | 140 +++++++++++++++++----------------- 1 file changed, 70 insertions(+), 70 deletions(-) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index a251f62e69e0..6146ab9dc50e 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1697,78 +1697,78 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, try: if sys.version_info[0] > 2 and _int64_enabled(): check_call(_LIB.MXExecutorSimpleBindEx64(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_int64, - array('q', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('i', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int64, + array('q', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) else: check_call(_LIB.MXExecutorSimpleBindEx(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_int, - array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('i', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int, + array('I', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) except MXNetError as e: error_msg = "simple_bind error. Arguments:\n" for k, v in kwargs.items(): From 6c390ec0444a8e2f3ee7d0374ae1f0a68a6bfc14 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 23 Oct 2019 15:57:03 -0700 Subject: [PATCH 5/6] api doc and function name change --- src/c_api/c_api_executor.cc | 48 +++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 2c911ee2168c..afc64f73de7c 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -519,7 +519,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, namespace mxnet { template -int SimpleBindExMaster(SymbolHandle symbol_handle, +int _SimpleBindImpl(SymbolHandle symbol_handle, int dev_type, int dev_id, const uint32_t num_g2c_keys, @@ -820,7 +820,8 @@ int SimpleBindExMaster(SymbolHandle symbol_handle, /*! - * \brief + * \brief Executor for simple_bind + * when INT64_TENSOR_SIZE = OFF * \param symbol_handle symbol handle * \param dev_type default device type * \param dev_id default device id @@ -890,7 +891,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out) { - return mxnet::SimpleBindExMaster(symbol_handle, + return mxnet::_SimpleBindImpl(symbol_handle, dev_type, dev_id, num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, provided_grad_req_list_len, provided_grad_req_names, @@ -909,6 +910,45 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, } +// TODO(ChaiBapchya): add API doc for rest of C APIs for int64 +/*! + * \brief Large tensor specific implementation for simple_bind executor + * when USE_INT64_TENSOR_SIZE = ON + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, int dev_type, int dev_id, @@ -943,7 +983,7 @@ int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out) { - return mxnet::SimpleBindExMaster(symbol_handle, + return mxnet::_SimpleBindImpl(symbol_handle, dev_type, dev_id, num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, provided_grad_req_list_len, provided_grad_req_names, From e76aec4ca2eabca780d463f79d0be1b57146297d Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Thu, 24 Oct 2019 13:13:07 -0700 Subject: [PATCH 6/6] fix lint and add infer shape test --- tests/nightly/test_large_vector.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 4a1731ac1346..c6a99a5d0826 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -64,7 +64,7 @@ def test_ndarray_random_randint(): high = 2**34 a = nd.random.randint(low, high, dtype=np.int64, shape=LARGE_X).asnumpy() assert a.shape == (LARGE_X,) - assert (a >= low).all() and (a < high).all() + assert (a >= low).all() and (a < high).all() def test_ndarray_empty(): @@ -1011,11 +1011,11 @@ def test_add_n(): def test_modulo(): x = mx.nd.ones(LARGE_X)*6 y = mx.nd.ones(LARGE_X)*4 - z = (x%y) + z = (x % y) assert z[0] == 2 assert z[-1] == 2 x = mx.nd.ones(LARGE_X)*5 - z = nd.modulo(x,y) + z = nd.modulo(x, y) assert z[0] == 1 assert z[-1] == 1 @@ -1055,6 +1055,16 @@ def test_gather(): assert np.sum(arr[idx] == 2) == 10 +def test_infer_shape(): + data_1 = mx.symbol.Variable('data_1') + data_2 = mx.symbol.Variable('data_2') + add = data_1+data_2 + # > add.infer_shape(data_1=(LARGE_X,), data_2=(LARGE_X,)) + # OUTPUT - arg_shapes, out_shapes, aux_shapes + _, out_shapes, _ = add.infer_shape(data_1=(LARGE_X,), data_2=(LARGE_X,)) + assert out_shapes == [(LARGE_X,)] + + if __name__ == '__main__': import nose nose.runmodule()