diff --git a/python/cuml/feature_extraction/_tfidf_vectorizer.py b/python/cuml/feature_extraction/_tfidf_vectorizer.py index e2cc4158b6..fbeeba7fc2 100644 --- a/python/cuml/feature_extraction/_tfidf_vectorizer.py +++ b/python/cuml/feature_extraction/_tfidf_vectorizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -256,3 +256,13 @@ def transform(self, raw_documents): """ X = super().transform(raw_documents) return self._tfidf.transform(X, copy=False) + + def get_feature_names(self): + """ + Array mapping from feature integer indices to feature name. + Returns + ------- + feature_names : Series + A list of feature names. + """ + return super().get_feature_names() diff --git a/python/cuml/test/test_text_feature_extraction.py b/python/cuml/test/test_text_feature_extraction.py index 1ff751d316..bf09e320be 100644 --- a/python/cuml/test/test_text_feature_extraction.py +++ b/python/cuml/test/test_text_feature_extraction.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2020, NVIDIA CORPORATION. +# Copyright (c) 2019-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -374,6 +374,20 @@ def test_tfidf_vectorizer(norm, use_idf, smooth_idf, sublinear_tf): cp.testing.assert_array_almost_equal(tfidf_mat.todense(), ref.toarray()) +def test_tfidf_vectorizer_get_feature_names(): + corpus = [ + 'This is the first document.', + 'This document is the second document.', + 'And this is the third one.', + 'Is this the first document?', + ] + vectorizer = TfidfVectorizer() + vectorizer.fit_transform(Series(corpus)) + output = ['and', 'document', 'first', 'is', + 'one', 'second', 'the', 'third', 'this'] + assert vectorizer.get_feature_names().to_arrow().to_pylist() == output + + # ---------------------------------------------------------------- # HashingVectorizer tests # ----------------------------------------------------------------