diff --git a/paddlemix/appflow/text2speech_synthesize.py b/paddlemix/appflow/text2speech_synthesize.py index 5d33e23b82878..8d348df0e005d 100644 --- a/paddlemix/appflow/text2speech_synthesize.py +++ b/paddlemix/appflow/text2speech_synthesize.py @@ -16,6 +16,26 @@ from .apptask import AppTask import paddle from paddlemix.utils.log import logger +import paddle.nn as nn, paddle +def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype: + try: + return next(parameter.named_parameters())[1].dtype + except StopIteration: + try: + return next(parameter.named_buffers())[1].dtype + except StopIteration: + return parameter._dtype +@property +def dtype_getter(self): + if hasattr(self, "__dtype"): + return self.__dtype + return get_parameter_dtype(self) +nn.Layer.dtype = dtype_getter + +@nn.Layer.dtype.setter +def dtype_setter(self, value): + self.__dtype = value +nn.Layer.dtype = dtype_setter class AudioTTSTask(AppTask): @@ -24,6 +44,7 @@ def __init__(self, task, model, **kwargs): # Default to static mode self._static_mode = False + self._construct_model() def _construct_model(self):