diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 95953f8d9d..e34df5308a 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -218,7 +218,10 @@ def transform(d): elif isinstance(item, bool) or isinstance(item, str): return item elif np.isscalar(item): - return np.array(item) + if dtype is None: + return np.array(item) + else: + return np.array(item, dtype=dtype) elif item is None: return None else: