diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index 8a68ad9d54baf..b2db00296bf95 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -26,31 +26,25 @@ from .framework import _cpu_num, _cuda_ids
 
 __all__ = ['DataFeeder']
 
+_PADDLE_DTYPE_2_NUMPY_DTYPE = {
+    core.VarDesc.VarType.BOOL: 'bool',
+    core.VarDesc.VarType.FP16: 'float16',
+    core.VarDesc.VarType.FP32: 'float32',
+    core.VarDesc.VarType.FP64: 'float64',
+    core.VarDesc.VarType.INT8: 'int8',
+    core.VarDesc.VarType.INT16: 'int16',
+    core.VarDesc.VarType.INT32: 'int32',
+    core.VarDesc.VarType.INT64: 'int64',
+    core.VarDesc.VarType.UINT8: 'uint8',
+    core.VarDesc.VarType.COMPLEX64: 'complex64',
+    core.VarDesc.VarType.COMPLEX128: 'complex128',
+}
+
 
 def convert_dtype(dtype):
     if isinstance(dtype, core.VarDesc.VarType):
-        if dtype == core.VarDesc.VarType.BOOL:
-            return 'bool'
-        elif dtype == core.VarDesc.VarType.FP16:
-            return 'float16'
-        elif dtype == core.VarDesc.VarType.FP32:
-            return 'float32'
-        elif dtype == core.VarDesc.VarType.FP64:
-            return 'float64'
-        elif dtype == core.VarDesc.VarType.INT8:
-            return 'int8'
-        elif dtype == core.VarDesc.VarType.INT16:
-            return 'int16'
-        elif dtype == core.VarDesc.VarType.INT32:
-            return 'int32'
-        elif dtype == core.VarDesc.VarType.INT64:
-            return 'int64'
-        elif dtype == core.VarDesc.VarType.UINT8:
-            return 'uint8'
-        elif dtype == core.VarDesc.VarType.COMPLEX64:
-            return 'complex64'
-        elif dtype == core.VarDesc.VarType.COMPLEX128:
-            return 'complex128'
+        if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
+            return _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
     elif isinstance(dtype, type):
         if dtype in [
                 np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 6238f0961781c..d3cf4d7bf3a37 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -23,7 +23,7 @@
 from .base import switch_to_static_graph
 from .math_op_patch import monkey_patch_math_varbase
 from .parallel import scale_loss
-from paddle.fluid.data_feeder import convert_dtype
+from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
 
 
 def monkey_patch_varbase():
@@ -320,13 +320,19 @@ def __bool__(self):
                 ("__name__", "Tensor")):
         setattr(core.VarBase, method_name, method)
 
-    def dtype_str(dtype):
-        prefix = 'paddle.'
-        return prefix + convert_dtype(dtype)
-
     # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
-    # So, we need to overwrite it to custom one.
+    # So, we need to overwrite it to a more readable one.
     # See details in https://github.com/pybind/pybind11/issues/2537.
+    origin = getattr(core.VarDesc.VarType, "__repr__")
+
+    def dtype_str(dtype):
+        if dtype in _PADDLE_DTYPE_2_NUMPY_DTYPE:
+            prefix = 'paddle.'
+            return prefix + _PADDLE_DTYPE_2_NUMPY_DTYPE[dtype]
+        else:
+            # for example, paddle.fluid.core.VarDesc.VarType.LOD_TENSOR
+            return origin(dtype)
+
     setattr(core.VarDesc.VarType, "__repr__", dtype_str)
 
     # patch math methods for varbase
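
For context, a minimal sketch (not part of the patch) of the behavior this diff aims for. It assumes a Paddle build where `import paddle` has already run `monkey_patch_varbase()`, as in a normal install; the printed repr strings are the expected outputs per the patch, not verified against a specific release.

```python
import paddle
from paddle.fluid import core
from paddle.fluid.data_feeder import convert_dtype

# convert_dtype now resolves VarType enums through the new dict lookup
# instead of the old if/elif chain; the returned strings are unchanged.
assert convert_dtype(core.VarDesc.VarType.FP32) == 'float32'

# The patched __repr__ prints numpy-mappable dtypes as 'paddle.<dtype>'.
t = paddle.to_tensor([1.0], dtype='float32')
print(repr(t.dtype))  # expected: paddle.float32

# Enum values with no numpy equivalent fall back to the original pybind11
# repr that was saved in `origin` (the patch's own comment gives
# paddle.fluid.core.VarDesc.VarType.LOD_TENSOR as an example).
print(repr(core.VarDesc.VarType.LOD_TENSOR))
```

Keeping a reference to the original `__repr__` before overwriting it means non-tensor enum values (LOD_TENSOR, and so on) still produce a sensible repr instead of raising a KeyError, which the old `convert_dtype`-based `dtype_str` would have hit.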