diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 18f7882e4..a47669c37 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -209,7 +209,7 @@ def __delitem__(self, key): del self.store[key] self.proxies_tally.del_key(key) - def evict(self, proxy): + def evict(self, proxy: ProxyObject): proxy._obj_pxy_serialize(serializers=("dask", "pickle")) def maybe_evict(self, extra_dev_mem=0): @@ -219,13 +219,14 @@ def maybe_evict(self, extra_dev_mem=0): ): return - total_dev_mem_usage, dev_buf_access = self.get_access_info() - total_dev_mem_usage += extra_dev_mem - if total_dev_mem_usage > self.device_memory_limit: - dev_buf_access.sort(key=lambda x: (x[0], -x[1])) - for _, size, proxies in dev_buf_access: - for p in proxies: - self.evict(p) - total_dev_mem_usage -= size - if total_dev_mem_usage <= self.device_memory_limit: - break + with self.lock: + total_dev_mem_usage, dev_buf_access = self.get_access_info() + total_dev_mem_usage += extra_dev_mem + if total_dev_mem_usage > self.device_memory_limit: + dev_buf_access.sort(key=lambda x: (x[0], -x[1])) + for _, size, proxies in dev_buf_access: + for p in proxies: + self.evict(p) + total_dev_mem_usage -= size + if total_dev_mem_usage <= self.device_memory_limit: + break diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 13092e7c0..0232a6d9e 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -107,13 +107,12 @@ def _obj_pxy_cache_wrapper(attr_name): def wrapper1(func): @functools.wraps(func) def wrapper2(self: "ProxyObject"): - with self._obj_pxy_lock: - try: - return self._obj_pxy_cache[attr_name] - except KeyError: - ret = func(self) - self._obj_pxy_cache[attr_name] = ret - return ret + try: + return self._obj_pxy_cache[attr_name] + except KeyError: + ret = func(self) + self._obj_pxy_cache[attr_name] = ret + return ret return wrapper2 @@ -299,7 +298,14 @@ def _obj_pxy_deserialize(self): # to evict because of the increased device memory usage. if "cuda" not in self._obj_pxy["serializers"]: if hostfile is not None: - hostfile.maybe_evict(self.__sizeof__()) + # In order to avoid a potential deadlock, we skip the + # `maybe_evict()` call if another thread is also accessing + # the hostfile. + if hostfile.lock.acquire(blocking=False): + try: + hostfile.maybe_evict(self.__sizeof__()) + finally: + hostfile.lock.release() header, frames = self._obj_pxy["obj"] self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)