From 29dbe18290dc5a4f344de6d03ec4139484ece116 Mon Sep 17 00:00:00 2001 From: nnshah1 Date: Sun, 6 Oct 2024 09:21:25 -0700 Subject: [PATCH 1/4] trial --- python/tritonserver/_api/_tensor.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/python/tritonserver/_api/_tensor.py b/python/tritonserver/_api/_tensor.py index ee21abd59..b81c56358 100644 --- a/python/tritonserver/_api/_tensor.py +++ b/python/tritonserver/_api/_tensor.py @@ -58,6 +58,8 @@ tuple[MemoryType, int] | MemoryType | tuple[DLDeviceType, int] | str ) +import sys + try: import cupy except ImportError: @@ -214,25 +216,20 @@ def __dlpack__(self, *, stream=None): Any A DLPack-compatible object representing the tensor. """ - self._sync_on_requested_stream(stream) - dl_managed_tensor = Tensor._create_managed_tensor() dl_managed_tensor.dl_tensor.data = self.data_ptr dl_managed_tensor.dl_tensor.device = DLDevice( TRITON_MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE[self.memory_type], self.memory_type_id, ) - dl_managed_tensor.dl_tensor.dtype = TRITON_TO_DLPACK_DTYPE[self.data_type] dl_managed_tensor.dl_tensor.ndim = len(self.shape) - dl_managed_tensor.dl_tensor.shape = (ctypes.c_int64 * len(self.shape))( - *self.shape - ) + self._ctypes_shape = (ctypes.c_int64 * len(self.shape))(*self.shape) + dl_managed_tensor.dl_tensor.shape = self._ctypes_shape dl_managed_tensor.dl_tensor.strides = ctypes.POINTER(ctypes.c_int64)() dl_managed_tensor.dl_tensor.byte_offset = 0 dl_managed_tensor.deleter = Tensor._managed_tensor_deleter - self._set_dlpack_manager_ctx(dl_managed_tensor) pycapsule = ctypes.pythonapi.PyCapsule_New( ctypes.byref(dl_managed_tensor), @@ -618,8 +615,6 @@ def _managed_tensor_deleter(handle: int) -> None: ) tensor_obj = tensor_obj_ptr.contents ctypes.pythonapi.Py_DecRef(tensor_obj) - shape_obj = ctypes.py_object(dl_managed_tensor.dl_tensor.shape) - ctypes.pythonapi.Py_DecRef(shape_obj) ctypes.pythonapi.PyMem_RawFree(handle) @staticmethod @@ -643,9 +638,7 @@ def _set_dlpack_manager_ctx(self, dl_managed_tensor): tensor_obj = ctypes.py_object(self) tensor_obj_ptr = ctypes.pointer(tensor_obj) dl_managed_tensor.manager_ctx = ctypes.cast(tensor_obj_ptr, ctypes.c_void_p) - shape_obj = ctypes.py_object(dl_managed_tensor.dl_tensor.shape) ctypes.pythonapi.Py_IncRef(tensor_obj) - ctypes.pythonapi.Py_IncRef(shape_obj) _from_converters: ClassVar[dict[type, Callable[[Any], Tensor]]] = dict( {numpy.ndarray: _from_numpy, numpy.generic: _from_numpy, list: _from_list}, From 69884679e81d17ff05aa2042da28a1334f0a2cb5 Mon Sep 17 00:00:00 2001 From: nnshah1 Date: Tue, 17 Dec 2024 22:48:08 -0800 Subject: [PATCH 2/4] updating for testing --- python/test/test_api.py | 67 +++++++++++++++++++++++++++++ python/tritonserver/_api/_tensor.py | 48 ++++++++++++++++++++- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index c15847aab..10c60d95b 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -37,6 +37,9 @@ import pytest import tritonserver +# import objgraph + + try: import cupy except ImportError: @@ -272,6 +275,10 @@ def test_allocate_on_gpu_and_reshape(self): self.assertEqual(torch_fp32_tensor.nbytes, 200) +import gc +from collections import Counter + + class TensorTests(unittest.TestCase): @pytest.mark.skipif(cupy is None, reason="Skipping gpu memory, cupy not installed") def test_cpu_to_gpu(self): @@ -315,6 +322,66 @@ def test_tensor_from_numpy(self): numpy.testing.assert_array_equal(torch_tensor.numpy(), cpu_array) self.assertEqual(torch_tensor.data_ptr(), cpu_array.ctypes.data) + def test_cpu_memory_leak(self): + gc.collect() + objects_before = gc.get_objects() + for index in range(20): + tensor = numpy.ones(2**27) + dl_pack_tensor = tritonserver.Tensor.from_dlpack(tensor) + array = numpy.from_dlpack(dl_pack_tensor) + # print(index, index*torch.numel(tensor)*tensor.element_size()) + del array + del tensor + del dl_pack_tensor + print(index) + # NOTE: if gc collect is called here + # no tensors are leaked - indicating a circular reference + # gc.collect() + gc.collect() + objects_after = gc.get_objects() + print(len(objects_after) - len(objects_before)) + new_objects = [type(x) for x in objects_after[len(objects_before) :]] + tensor_objects = [ + x for x in objects_after if isinstance(x, tritonserver.Tensor) + ] + if tensor_objects: + print("Tensor objects") + print(len(tensor_objects)) + print(type(tensor_objects[-1].memory_buffer.owner)) + + # chain = objgraph.find_backref_chain( + # tensor_objects[-1], objgraph.is_proper_module + # ) + # print(len(chain)) + # print(chain) + print(Counter(new_objects)) + + def test_gpu_memory_leak(self): + gc.collect() + objects_before = gc.get_objects() + for index in range(1000): + tensor = cupy.ones(2**27) + dl_pack_tensor = tritonserver.Tensor.from_dlpack(tensor) + array = cupy.from_dlpack(dl_pack_tensor) + # print(index, index*torch.numel(tensor)*tensor.element_size()) + del array + del tensor + del dl_pack_tensor + print(index) + # gc.collect() + objects_after = gc.get_objects() + print(len(objects_after) - len(objects_before)) + new_objects = [type(x) for x in objects_after[len(objects_before) :]] + tensor_objects = [ + x for x in objects_after if isinstance(x, tritonserver.Tensor) + ] + if tensor_objects: + print(type(tensor_objects[-1].memory_buffer.owner)) + + print(Counter(new_objects)) + + assert len(tensor_objects) == 0, "Leaked Objects" + class ServerTests(unittest.TestCase): def setup_method(self, method): diff --git a/python/tritonserver/_api/_tensor.py b/python/tritonserver/_api/_tensor.py index b81c56358..20775c736 100644 --- a/python/tritonserver/_api/_tensor.py +++ b/python/tritonserver/_api/_tensor.py @@ -54,6 +54,9 @@ from tritonserver._c.triton_bindings import TRITONSERVER_MemoryType as MemoryType from tritonserver._c.triton_bindings import UnsupportedError +# import objgraph + + DeviceOrMemoryType = ( tuple[MemoryType, int] | MemoryType | tuple[DLDeviceType, int] | str ) @@ -217,7 +220,10 @@ def __dlpack__(self, *, stream=None): A DLPack-compatible object representing the tensor. """ self._sync_on_requested_stream(stream) + + ## Debug Note: creates managed tensor with malloc dl_managed_tensor = Tensor._create_managed_tensor() + dl_managed_tensor.dl_tensor.data = self.data_ptr dl_managed_tensor.dl_tensor.device = DLDevice( TRITON_MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE[self.memory_type], @@ -225,12 +231,25 @@ def __dlpack__(self, *, stream=None): ) dl_managed_tensor.dl_tensor.dtype = TRITON_TO_DLPACK_DTYPE[self.data_type] dl_managed_tensor.dl_tensor.ndim = len(self.shape) - self._ctypes_shape = (ctypes.c_int64 * len(self.shape))(*self.shape) - dl_managed_tensor.dl_tensor.shape = self._ctypes_shape + + ## Original issue was that the shape was created here but never unreferenced + ## self._ctypes_shape = (ctypes.c_int64 * len(self.shape))(*self.shape) + ## now we create the shape array using malloc + dl_managed_tensor.dl_tensor.shape = (ctypes.c_int64 * len(self.shape))( + *self.shape + ) + # dl_managed_tensor.dl_tensor.shape = Tensor._create_shape_array(self.shape) + + ## NOTE for debug: this is a null ptr dl_managed_tensor.dl_tensor.strides = ctypes.POINTER(ctypes.c_int64)() dl_managed_tensor.dl_tensor.byte_offset = 0 dl_managed_tensor.deleter = Tensor._managed_tensor_deleter + + ## Note for debug: this method sets the context to point to + ## this Tensor instance after increasing the reference count + self._set_dlpack_manager_ctx(dl_managed_tensor) + pycapsule = ctypes.pythonapi.PyCapsule_New( ctypes.byref(dl_managed_tensor), c_str_dltensor, @@ -600,6 +619,16 @@ def _from_numpy(obj: numpy.ndarray | numpy.generic) -> Tensor: return Tensor(data_type, shape, memory_buffer) + @staticmethod + def _create_shape_array(shape): + array_type = ctypes.c_int64 * len(shape) + size = ctypes.c_size_t(ctypes.sizeof(array_type)) + address = ctypes.pythonapi.PyMem_RawMalloc(size) + array = array_type.from_address(address) + for index in range(len(shape)): + array[index] = shape[index] + return array + @staticmethod def _create_managed_tensor(): size = ctypes.c_size_t(ctypes.sizeof(DLManagedTensor)) @@ -609,18 +638,33 @@ def _create_managed_tensor(): @staticmethod @ctypes.CFUNCTYPE(None, ctypes.c_void_p) def _managed_tensor_deleter(handle: int) -> None: + # DEBUG print("managed tensor deleter!",flush=True) + dl_managed_tensor = DLManagedTensor.from_address(handle) tensor_obj_ptr = ctypes.cast( dl_managed_tensor.manager_ctx, ctypes.POINTER(ctypes.py_object) ) tensor_obj = tensor_obj_ptr.contents + + # DEBUG Note: free the shape array + # ctypes.pythonapi.PyMem_RawFree(dl_managed_tensor.dl_tensor.shape) + + # DEBUG Note: decrement reference to original tensor object ctypes.pythonapi.Py_DecRef(tensor_obj) + + # DEBUG Note: free the managed tensor + ctypes.pythonapi.PyMem_RawFree(handle) + # DEBUG chain = objgraph.find_backref_chain(tensor_obj, objgraph.is_proper_module) + # DEBUG print(len(chain)) + # DEBUG print([type(x) for x in chain]) + @staticmethod @ctypes.CFUNCTYPE(None, ctypes.c_void_p) def _pycapsule_deleter(handle: ctypes.c_void_p) -> None: try: + # DEBUG print("capsule deleter!",flush=True) pycapsule: ctypes.py_object = ctypes.cast(handle, ctypes.py_object) if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, c_str_dltensor): dl_managed_tensor = ctypes.pythonapi.PyCapsule_GetPointer( From f6c5398e16b4e5f9673854fca5667fc733e16df9 Mon Sep 17 00:00:00 2001 From: nnshah1 Date: Tue, 17 Dec 2024 23:47:53 -0800 Subject: [PATCH 3/4] updated with descriptions on before / after state and test failure condition --- python/test/test_api.py | 28 ++++++++++++++++++++++--- python/tritonserver/_api/_tensor.py | 32 ++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index 10c60d95b..8c98b4166 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -325,7 +325,7 @@ def test_tensor_from_numpy(self): def test_cpu_memory_leak(self): gc.collect() objects_before = gc.get_objects() - for index in range(20): + for index in range(50): tensor = numpy.ones(2**27) dl_pack_tensor = tritonserver.Tensor.from_dlpack(tensor) array = numpy.from_dlpack(dl_pack_tensor) @@ -334,10 +334,18 @@ def test_cpu_memory_leak(self): del tensor del dl_pack_tensor print(index) + # NOTE: if gc collect is called here # no tensors are leaked - indicating a circular reference # gc.collect() - gc.collect() + + # Note: + # Originally gc.collect() had no effect on memory reclaiming + # with the changes in the PR - uncommenting this line + # forces all tensors to be reclaimed and test passes + # This shouldn't be needed + + # gc.collect() objects_after = gc.get_objects() print(len(objects_after) - len(objects_before)) new_objects = [type(x) for x in objects_after[len(objects_before) :]] @@ -356,10 +364,12 @@ def test_cpu_memory_leak(self): # print(chain) print(Counter(new_objects)) + assert len(tensor_objects) == 0, "Leaked Objects" + def test_gpu_memory_leak(self): gc.collect() objects_before = gc.get_objects() - for index in range(1000): + for index in range(50): tensor = cupy.ones(2**27) dl_pack_tensor = tritonserver.Tensor.from_dlpack(tensor) array = cupy.from_dlpack(dl_pack_tensor) @@ -368,6 +378,18 @@ def test_gpu_memory_leak(self): del tensor del dl_pack_tensor print(index) + + # NOTE: if gc collect is called here + # no tensors are leaked - indicating a circular reference + # gc.collect() + + # Note: + # Originally gc.collect() had no effect on memory reclaiming + # with the changes in the PR - uncommenting this line + # forces all tensors to be reclaimed and test passes + # This shouldn't be needed + + # gc.collect() # gc.collect() objects_after = gc.get_objects() print(len(objects_after) - len(objects_before)) diff --git a/python/tritonserver/_api/_tensor.py b/python/tritonserver/_api/_tensor.py index 20775c736..2245ad199 100644 --- a/python/tritonserver/_api/_tensor.py +++ b/python/tritonserver/_api/_tensor.py @@ -231,14 +231,17 @@ def __dlpack__(self, *, stream=None): ) dl_managed_tensor.dl_tensor.dtype = TRITON_TO_DLPACK_DTYPE[self.data_type] dl_managed_tensor.dl_tensor.ndim = len(self.shape) + print("storing shape", self.shape) + + ## Original issue was that the shape was created here + ## But could not be freed correctly + ## + ## dl_managed_tensor.dl_tensor.shape = (ctypes.c_int64 * len(self.shape))( + ## *self.shape + ## ) - ## Original issue was that the shape was created here but never unreferenced - ## self._ctypes_shape = (ctypes.c_int64 * len(self.shape))(*self.shape) ## now we create the shape array using malloc - dl_managed_tensor.dl_tensor.shape = (ctypes.c_int64 * len(self.shape))( - *self.shape - ) - # dl_managed_tensor.dl_tensor.shape = Tensor._create_shape_array(self.shape) + dl_managed_tensor.dl_tensor.shape = Tensor._create_shape_array(self.shape) ## NOTE for debug: this is a null ptr dl_managed_tensor.dl_tensor.strides = ctypes.POINTER(ctypes.c_int64)() @@ -646,8 +649,14 @@ def _managed_tensor_deleter(handle: int) -> None: ) tensor_obj = tensor_obj_ptr.contents + print(dl_managed_tensor.dl_tensor.shape[0]) + # DEBUG Note: free the shape array - # ctypes.pythonapi.PyMem_RawFree(dl_managed_tensor.dl_tensor.shape) + ctypes.pythonapi.PyMem_RawFree(dl_managed_tensor.dl_tensor.shape) + + ## Original - caused memory leak + ## shape_obj = ctypes.py_object(dl_managed_tensor.dl_tensor.shape) + ## ctypes.pythonapi.Py_DecRef(shape_obj) # DEBUG Note: decrement reference to original tensor object ctypes.pythonapi.Py_DecRef(tensor_obj) @@ -684,6 +693,15 @@ def _set_dlpack_manager_ctx(self, dl_managed_tensor): dl_managed_tensor.manager_ctx = ctypes.cast(tensor_obj_ptr, ctypes.c_void_p) ctypes.pythonapi.Py_IncRef(tensor_obj) + ## Original Issue + ## this caused the tensor object to never be garbage collected + ## + ## Removing the IncRef caused the shape to be corrupted + ## Current solution uses malloc + + ## shape_obj = ctypes.py_object(dl_managed_tensor.dl_tensor.shape) + ## ctypes.pythonapi.Py_IncRef(shape_obj) + _from_converters: ClassVar[dict[type, Callable[[Any], Tensor]]] = dict( {numpy.ndarray: _from_numpy, numpy.generic: _from_numpy, list: _from_list}, ) From 7d7831ee3d15ced97f5bebf31fd265a01ed455bb Mon Sep 17 00:00:00 2001 From: nnshah1 Date: Wed, 18 Dec 2024 00:07:55 -0800 Subject: [PATCH 4/4] updated order of deletion --- python/test/test_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/test/test_api.py b/python/test/test_api.py index 8c98b4166..30f7ccd4b 100644 --- a/python/test/test_api.py +++ b/python/test/test_api.py @@ -325,14 +325,14 @@ def test_tensor_from_numpy(self): def test_cpu_memory_leak(self): gc.collect() objects_before = gc.get_objects() - for index in range(50): + for index in range(30): tensor = numpy.ones(2**27) dl_pack_tensor = tritonserver.Tensor.from_dlpack(tensor) array = numpy.from_dlpack(dl_pack_tensor) # print(index, index*torch.numel(tensor)*tensor.element_size()) del array - del tensor del dl_pack_tensor + del tensor print(index) # NOTE: if gc collect is called here @@ -375,8 +375,8 @@ def test_gpu_memory_leak(self): array = cupy.from_dlpack(dl_pack_tensor) # print(index, index*torch.numel(tensor)*tensor.element_size()) del array - del tensor del dl_pack_tensor + del tensor print(index) # NOTE: if gc collect is called here