diff --git a/mne_features/feature_extraction.py b/mne_features/feature_extraction.py index f65cb10..76315fc 100644 --- a/mne_features/feature_extraction.py +++ b/mne_features/feature_extraction.py @@ -116,6 +116,10 @@ def get_feature_names(self): else: return np.arange(self.output_shape_).astype(str) + def get_feature_names_out(self, input_features=None): + """Mapping of the feature indices to feature names.""" + return self.get_feature_names() + def get_params(self, deep=True): """Get the parameters (if any) of the given feature function. @@ -219,7 +223,10 @@ def _apply_extractor(extractor, X, ch_names, return_as_df): X = extractor.fit_transform(X) feature_names = None if return_as_df: - feature_names = extractor.get_feature_names() + if hasattr(extractor, 'get_feature_names_out'): + feature_names = extractor.get_feature_names_out() + else: + feature_names = extractor.get_feature_names() if ch_names is not None: # rename channels mapping = {'ch%s' % i: ch_name for i, ch_name in enumerate(ch_names)}