From 0828ee6532be9bb1876fdaf3259f1acfaa9e3c92 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Sat, 7 Sep 2024 14:35:01 +0200 Subject: [PATCH 1/7] Enhancements --- tests/test_mca.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_mca.py b/tests/test_mca.py index ca1afc23..87fb377b 100644 --- a/tests/test_mca.py +++ b/tests/test_mca.py @@ -135,6 +135,34 @@ def test_issue_131(): """ +def test_issue_171(): + """ + + >>> from sklearn import impute + >>> from sklearn import pipeline + + >>> test_data = pd.DataFrame(data=np.random.random((10, 5))) + >>> test = pipeline.Pipeline(steps=[ + ... ('impute', impute.SimpleImputer()), # would break the pipeline since it returns an ndarray + ... ('mca', prince.MCA()), + ... ]) + >>> _ = test[0].set_output(transform='pandas') + >>> test.fit_transform(test_data) + 0 1 + 0 -2.384233e-16 1.432250e-16 + 1 -1.296231e+00 -1.146678e+00 + 2 -7.612724e-01 -5.776135e-01 + 3 1.468565e+00 5.043704e-01 + 4 1.618681e+00 -6.557499e-02 + 5 -1.045958e+00 2.636816e+00 + 6 1.127237e+00 -2.579320e-01 + 7 -4.013652e-01 2.714293e-01 + 8 -4.999467e-02 -4.974244e-01 + 9 -6.596616e-01 -8.673923e-01 + + """ + + def test_type_doesnt_matter(): """ From 81b7a1329408ce579a6c2b143fc7bb576fb9309f Mon Sep 17 00:00:00 2001 From: Max Halford Date: Sat, 7 Sep 2024 15:43:24 +0200 Subject: [PATCH 2/7] Add get_feature_names_out --- docs/config.toml | 2 +- docs/content/faq.ipynb | 136 +++++++++++++++++++++++++++++++++++++++++ prince/mca.py | 3 + prince/pca.py | 3 + tests/test_mca.py | 2 + 5 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 docs/content/faq.ipynb diff --git a/docs/config.toml b/docs/config.toml index cab2e968..fdf03b34 100644 --- a/docs/config.toml +++ b/docs/config.toml @@ -8,7 +8,7 @@ theme = 'hugo-bearblog' # Basic metadata configuration for your blog. title = "Prince" author = "Max Halford" -copyright = "Copyright © 2023, Max Halford." +copyright = "Copyright © 2024, Max Halford." languageCode = "en-US" # Generate a nice robots.txt for SEO diff --git a/docs/content/faq.ipynb b/docs/content/faq.ipynb new file mode 100644 index 00000000..c5641611 --- /dev/null +++ b/docs/content/faq.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "+++\n", + "title = \"Frequently Asked Questions\"\n", + "menu = \"main\"\n", + "weight = 7\n", + "toc = true\n", + "aliases = [\"faq\"]\n", + "+++" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**How to use Prince with sklearn pipelines?**\n", + "\n", + "Prince estimators consume and produce pandas DataFrames. If you want to use them in a sklearn pipeline, you can [sklearn's `set_output` API](https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_set_output.html). This way, you can tell sklearn that the pipeline should exchange DataFrames instead of numpy arrays between the steps." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
component01
0-2.2647030.480027
1-2.080961-0.674134
2-2.364229-0.341908
3-2.299384-0.597395
4-2.3898420.646835
\n", + "
" + ], + "text/plain": [ + "component 0 1\n", + "0 -2.264703 0.480027\n", + "1 -2.080961 -0.674134\n", + "2 -2.364229 -0.341908\n", + "3 -2.299384 -0.597395\n", + "4 -2.389842 0.646835" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import prince\n", + "from sklearn import datasets\n", + "from sklearn import impute\n", + "from sklearn import pipeline\n", + "\n", + "pipe = pipeline.make_pipeline(\n", + " impute.SimpleImputer(),\n", + " prince.PCA()\n", + ")\n", + "pipe.set_output(transform='pandas')\n", + "dataset = datasets.load_iris()\n", + "pipe.fit_transform(dataset.data).head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "prince-NQ1O93Uh-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/prince/mca.py b/prince/mca.py index 67924c3c..e49305e3 100644 --- a/prince/mca.py +++ b/prince/mca.py @@ -37,6 +37,9 @@ def _prepare(self, X): X = pd.get_dummies(X, columns=X.columns) return X + def get_feature_names_out(self, input_features=None): + return np.arange(self.n_components_) + @utils.check_is_dataframe_input def fit(self, X, y=None): """Fit the MCA for the dataframe X. diff --git a/prince/pca.py b/prince/pca.py index 41458212..e7603122 100755 --- a/prince/pca.py +++ b/prince/pca.py @@ -67,6 +67,9 @@ def _check_input(self, X): if self.check_input: sklearn.utils.check_array(X) + def get_feature_names_out(self, input_features=None): + return np.arange(self.n_components_) + @utils.check_is_dataframe_input def fit(self, X, y=None, supplementary_columns=None): self._check_input(X) diff --git a/tests/test_mca.py b/tests/test_mca.py index 87fb377b..60518b8c 100644 --- a/tests/test_mca.py +++ b/tests/test_mca.py @@ -138,6 +138,8 @@ def test_issue_131(): def test_issue_171(): """ + https://github.com/MaxHalford/prince/issues/171 + >>> from sklearn import impute >>> from sklearn import pipeline From 4cfacf5690736e999c3f1ca486885d33d5cc9de9 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Sat, 7 Sep 2024 15:48:40 +0200 Subject: [PATCH 3/7] Update test_mca.py --- tests/test_mca.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/test_mca.py b/tests/test_mca.py index 60518b8c..3f77dbb0 100644 --- a/tests/test_mca.py +++ b/tests/test_mca.py @@ -143,24 +143,25 @@ def test_issue_171(): >>> from sklearn import impute >>> from sklearn import pipeline - >>> test_data = pd.DataFrame(data=np.random.random((10, 5))) + >>> rng = np.random.RandomState(0) + >>> test_data = pd.DataFrame(data=rng.random((10, 5))) >>> test = pipeline.Pipeline(steps=[ ... ('impute', impute.SimpleImputer()), # would break the pipeline since it returns an ndarray - ... ('mca', prince.MCA()), + ... ('mca', prince.PCA()), ... ]) >>> _ = test[0].set_output(transform='pandas') >>> test.fit_transform(test_data) - 0 1 - 0 -2.384233e-16 1.432250e-16 - 1 -1.296231e+00 -1.146678e+00 - 2 -7.612724e-01 -5.776135e-01 - 3 1.468565e+00 5.043704e-01 - 4 1.618681e+00 -6.557499e-02 - 5 -1.045958e+00 2.636816e+00 - 6 1.127237e+00 -2.579320e-01 - 7 -4.013652e-01 2.714293e-01 - 8 -4.999467e-02 -4.974244e-01 - 9 -6.596616e-01 -8.673923e-01 + component 0 1 + 0 -0.392617 0.296831 + 1 0.119661 -1.660653 + 2 -1.541581 -0.826863 + 3 3.105498 -0.538801 + 4 -2.439259 -0.343292 + 5 1.129341 -0.533576 + 6 -1.077436 0.899673 + 7 0.020571 -0.941029 + 8 1.498005 1.566376 + 9 -0.422184 2.081334 """ From 2018ef908eb86a9d9c46de306424fa414a579530 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Sat, 7 Sep 2024 17:06:52 +0200 Subject: [PATCH 4/7] Support new categories in MCA --- docs/content/mca.ipynb | 586 ++++++++++++++++++++++++++++++++++------- prince/mca.py | 16 +- 2 files changed, 505 insertions(+), 97 deletions(-) diff --git a/docs/content/mca.ipynb b/docs/content/mca.ipynb index ab86ae23..6d098c01 100644 --- a/docs/content/mca.ipynb +++ b/docs/content/mca.ipynb @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:02.082588Z", @@ -136,7 +136,7 @@ "4 YELLOW LARGE STRETCH ADULT T" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:02.663561Z", @@ -168,7 +168,55 @@ "shell.execute_reply": "2023-10-11T22:33:03.033990Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + } + ], "source": [ "import prince\n", "\n", @@ -192,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.039813Z", @@ -219,7 +267,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.080736Z", @@ -292,7 +340,7 @@ "2 0.186 18.56% 79.84%" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -311,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.111432Z", @@ -321,6 +369,53 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + }, { "data": { "text/html": [ @@ -351,25 +446,25 @@ " \n", " 0\n", " 0.705387\n", - " 8.460396e-15\n", + " 9.509676e-15\n", " 0.758639\n", " \n", " \n", " 1\n", " -0.386586\n", - " 8.514287e-15\n", + " 7.593937e-15\n", " 0.626063\n", " \n", " \n", " 2\n", " -0.386586\n", - " 6.249235e-15\n", + " 6.106546e-15\n", " 0.626063\n", " \n", " \n", " 3\n", " -0.852014\n", - " 6.872889e-15\n", + " 5.547435e-15\n", " 0.562447\n", " \n", " \n", @@ -384,14 +479,14 @@ ], "text/plain": [ " 0 1 2\n", - "0 0.705387 8.460396e-15 0.758639\n", - "1 -0.386586 8.514287e-15 0.626063\n", - "2 -0.386586 6.249235e-15 0.626063\n", - "3 -0.852014 6.872889e-15 0.562447\n", + "0 0.705387 9.509676e-15 0.758639\n", + "1 -0.386586 7.593937e-15 0.626063\n", + "2 -0.386586 6.106546e-15 0.626063\n", + "3 -0.852014 5.547435e-15 0.562447\n", "4 0.783539 -6.333333e-01 0.130201" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -402,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.136728Z", @@ -412,6 +507,53 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + }, { "data": { "text/html": [ @@ -466,7 +608,7 @@ " \n", " Action_DIP\n", " -0.853864\n", - " -1.953058e-15\n", + " -2.712409e-15\n", " -0.079340\n", " \n", " \n", @@ -479,10 +621,10 @@ "Color_YELLOW -0.130342 -7.657805e-01 0.712523\n", "Size_LARGE 0.117308 -6.892024e-01 -0.641270\n", "Size_SMALL -0.130342 7.657805e-01 0.712523\n", - "Action_DIP -0.853864 -1.953058e-15 -0.079340" + "Action_DIP -0.853864 -2.712409e-15 -0.079340" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -501,7 +643,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.165704Z", @@ -511,17 +653,105 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n", + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + }, { "data": { "text/html": [ "\n", - "
\n", + "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -601,7 +831,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.264339Z", @@ -616,54 +846,54 @@ "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 012012
07%0%16%07%0%16%
12%0%11%12%0%11%
22%0%11%22%0%11%
310%0%9%310%0%9%
48%10%0%48%10%0%
\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -674,7 +904,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.464586Z", @@ -689,54 +919,54 @@ "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 012012
Color_PURPLE0%24%23%Color_PURPLE0%24%23%
Color_YELLOW0%26%26%Color_YELLOW0%26%26%
Size_LARGE0%24%23%Size_LARGE0%24%23%
Size_SMALL0%26%26%Size_SMALL0%26%26%
Action_DIP15%0%0%Action_DIP15%0%0%
\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -755,7 +985,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.478001Z", @@ -765,6 +995,53 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + }, { "data": { "text/html": [ @@ -795,25 +1072,25 @@ " \n", " 0\n", " 0.461478\n", - " 6.638620e-29\n", + " 8.387409e-29\n", " 0.533786\n", " \n", " \n", " 1\n", " 0.152256\n", - " 7.385455e-29\n", + " 5.875091e-29\n", " 0.399316\n", " \n", " \n", " 2\n", " 0.152256\n", - " 3.978637e-29\n", + " 3.799023e-29\n", " 0.399316\n", " \n", " \n", " 3\n", " 0.653335\n", - " 4.251294e-29\n", + " 2.769663e-29\n", " 0.284712\n", " \n", " \n", @@ -828,14 +1105,14 @@ ], "text/plain": [ " 0 1 2\n", - "0 0.461478 6.638620e-29 0.533786\n", - "1 0.152256 7.385455e-29 0.399316\n", - "2 0.152256 3.978637e-29 0.399316\n", - "3 0.653335 4.251294e-29 0.284712\n", + "0 0.461478 8.387409e-29 0.533786\n", + "1 0.152256 5.875091e-29 0.399316\n", + "2 0.152256 3.799023e-29 0.399316\n", + "3 0.653335 2.769663e-29 0.284712\n", "4 0.592606 3.871772e-01 0.016363" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -846,7 +1123,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-10-11T22:33:03.494732Z", @@ -856,6 +1133,53 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Color_PURPLE Color_YELLOW Size_LARGE Size_SMALL Action_DIP \n", + "0 0.0 1.0 0.0 1.0 0.0 \\\n", + "1 0.0 1.0 0.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 1.0 \n", + "3 0.0 1.0 0.0 1.0 1.0 \n", + "4 0.0 1.0 1.0 0.0 0.0 \n", + "5 0.0 1.0 1.0 0.0 0.0 \n", + "6 0.0 1.0 1.0 0.0 0.0 \n", + "7 0.0 1.0 1.0 0.0 1.0 \n", + "8 0.0 1.0 1.0 0.0 1.0 \n", + "9 1.0 0.0 0.0 1.0 0.0 \n", + "10 1.0 0.0 0.0 1.0 0.0 \n", + "11 1.0 0.0 0.0 1.0 0.0 \n", + "12 1.0 0.0 0.0 1.0 1.0 \n", + "13 1.0 0.0 0.0 1.0 1.0 \n", + "14 1.0 0.0 1.0 0.0 0.0 \n", + "15 1.0 0.0 1.0 0.0 0.0 \n", + "16 1.0 0.0 1.0 0.0 0.0 \n", + "17 1.0 0.0 1.0 0.0 1.0 \n", + "18 1.0 0.0 1.0 0.0 1.0 \n", + "\n", + " Action_STRETCH Age_ADULT Age_CHILD Inflated_F Inflated_T \n", + "0 1.0 1.0 0.0 0.0 1.0 \n", + "1 1.0 0.0 1.0 1.0 0.0 \n", + "2 0.0 1.0 0.0 1.0 0.0 \n", + "3 0.0 0.0 1.0 1.0 0.0 \n", + "4 1.0 1.0 0.0 0.0 1.0 \n", + "5 1.0 1.0 0.0 0.0 1.0 \n", + "6 1.0 0.0 1.0 1.0 0.0 \n", + "7 0.0 1.0 0.0 1.0 0.0 \n", + "8 0.0 0.0 1.0 1.0 0.0 \n", + "9 1.0 1.0 0.0 0.0 1.0 \n", + "10 1.0 1.0 0.0 0.0 1.0 \n", + "11 1.0 0.0 1.0 1.0 0.0 \n", + "12 0.0 1.0 0.0 1.0 0.0 \n", + "13 0.0 0.0 1.0 1.0 0.0 \n", + "14 1.0 1.0 0.0 0.0 1.0 \n", + "15 1.0 1.0 0.0 0.0 1.0 \n", + "16 1.0 0.0 1.0 1.0 0.0 \n", + "17 0.0 1.0 0.0 1.0 0.0 \n", + "18 0.0 0.0 1.0 1.0 0.0 \n" + ] + }, { "data": { "text/html": [ @@ -910,7 +1234,7 @@ " \n", " Action_DIP\n", " 0.530243\n", - " 2.774134e-30\n", + " 5.350665e-30\n", " 0.004578\n", " \n", " \n", @@ -923,10 +1247,10 @@ "Color_YELLOW 0.015290 5.277778e-01 0.456920\n", "Size_LARGE 0.015290 5.277778e-01 0.456920\n", "Size_SMALL 0.015290 5.277778e-01 0.456920\n", - "Action_DIP 0.530243 2.774134e-30 0.004578" + "Action_DIP 0.530243 5.350665e-30 0.004578" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -934,6 +1258,76 @@ "source": [ "mca.column_cosine_similarities(dataset).head()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Handling unknown categories\n", + "\n", + "The MCA implementation in Prince implements sklearn's fit/transfrom API. This means that you can use the `transform` method to transform new data. The latter might differ from the training data, in that it may contain categories that were not present in the training data. By default, the MCA implementation will raise an error if it encounters such a category. You can change this behavior by setting the `handle_unknown` parameter to `'ignore'`. In this case, the unknown categories will be ignored." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
01
31.4142143.326586e-16
\n", + "
" + ], + "text/plain": [ + " 0 1\n", + "3 1.414214 3.326586e-16" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = pd.DataFrame({\n", + " 'var1': ['c', 'a', 'b', 'c'],\n", + " 'var2': ['x', 'y', 'y', 'z']\n", + "})\n", + "\n", + "mca = prince.MCA(n_components=2, random_state=42, handle_unknown='ignore')\n", + "mca.fit(dataset[:3])\n", + "mca.transform(dataset[-1:])" + ] } ], "metadata": { @@ -952,7 +1346,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.11.4" }, "vscode": { "interpreter": { diff --git a/prince/mca.py b/prince/mca.py index e49305e3..e6e23a58 100644 --- a/prince/mca.py +++ b/prince/mca.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import sklearn.base +import sklearn.preprocessing import sklearn.utils from prince import utils @@ -21,6 +22,7 @@ def __init__( random_state=None, engine="sklearn", one_hot=True, + handle_unknown="error", ): super().__init__( n_components=n_components, @@ -31,10 +33,22 @@ def __init__( engine=engine, ) self.one_hot = one_hot + self.handle_unknown = handle_unknown def _prepare(self, X): if self.one_hot: - X = pd.get_dummies(X, columns=X.columns) + # Create the one-hot encoder if it doesn't exist (usually because we're in the fit method) + if not hasattr(self, "one_hot_encoder_"): + self.one_hot_encoder_ = ( + ( + sklearn.preprocessing.OneHotEncoder(handle_unknown=self.handle_unknown, sparse_output=False) + .set_output(transform="pandas") + .fit(X) + ) + if not hasattr(self, "one_hot_encoder_") + else self.one_hot_encoder_ + ) + X = self.one_hot_encoder_.transform(X) return X def get_feature_names_out(self, input_features=None): From bbe1a293e0e4108b432fba7b825e9832feb749d4 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Sat, 7 Sep 2024 17:11:17 +0200 Subject: [PATCH 5/7] run docs --- docs/content/ca.ipynb | 202 ++++++------- docs/content/famd.ipynb | 154 +++++----- docs/content/faq.ipynb | 13 +- docs/content/gpa.ipynb | 70 ++--- docs/content/mca.ipynb | 573 +++++++++---------------------------- docs/content/mfa.ipynb | 80 +++--- docs/content/pca.ipynb | 612 ++++++++++++++++++++-------------------- prince/mca.py | 5 +- 8 files changed, 700 insertions(+), 1009 deletions(-) diff --git a/docs/content/ca.ipynb b/docs/content/ca.ipynb index 2c30f9df..78aac6d2 100644 --- a/docs/content/ca.ipynb +++ b/docs/content/ca.ipynb @@ -41,10 +41,10 @@ "execution_count": 1, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.002617Z", - "iopub.status.busy": "2023-10-11T22:32:55.002348Z", - "iopub.status.idle": "2023-10-11T22:32:55.412294Z", - "shell.execute_reply": "2023-10-11T22:32:55.411835Z" + "iopub.execute_input": "2024-09-07T15:09:11.292492Z", + "iopub.status.busy": "2024-09-07T15:09:11.292202Z", + "iopub.status.idle": "2024-09-07T15:09:11.862852Z", + "shell.execute_reply": "2024-09-07T15:09:11.862470Z" } }, "outputs": [ @@ -165,10 +165,10 @@ "execution_count": 2, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.427218Z", - "iopub.status.busy": "2023-10-11T22:32:55.427073Z", - "iopub.status.idle": "2023-10-11T22:32:55.440692Z", - "shell.execute_reply": "2023-10-11T22:32:55.439758Z" + "iopub.execute_input": "2024-09-07T15:09:11.878246Z", + "iopub.status.busy": "2024-09-07T15:09:11.878113Z", + "iopub.status.idle": "2024-09-07T15:09:11.900160Z", + "shell.execute_reply": "2024-09-07T15:09:11.899416Z" } }, "outputs": [], @@ -197,10 +197,10 @@ "execution_count": 3, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.445385Z", - "iopub.status.busy": "2023-10-11T22:32:55.444977Z", - "iopub.status.idle": "2023-10-11T22:32:55.479167Z", - "shell.execute_reply": "2023-10-11T22:32:55.478398Z" + "iopub.execute_input": "2024-09-07T15:09:11.907587Z", + "iopub.status.busy": "2024-09-07T15:09:11.907247Z", + "iopub.status.idle": "2024-09-07T15:09:11.934963Z", + "shell.execute_reply": "2024-09-07T15:09:11.932677Z" } }, "outputs": [ @@ -289,10 +289,10 @@ "execution_count": 4, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.483618Z", - "iopub.status.busy": "2023-10-11T22:32:55.482903Z", - "iopub.status.idle": "2023-10-11T22:32:55.505250Z", - "shell.execute_reply": "2023-10-11T22:32:55.504272Z" + "iopub.execute_input": "2024-09-07T15:09:11.947085Z", + "iopub.status.busy": "2024-09-07T15:09:11.946714Z", + "iopub.status.idle": "2024-09-07T15:09:11.974884Z", + "shell.execute_reply": "2024-09-07T15:09:11.973959Z" } }, "outputs": [ @@ -387,10 +387,10 @@ "execution_count": 5, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.509208Z", - "iopub.status.busy": "2023-10-11T22:32:55.508562Z", - "iopub.status.idle": "2023-10-11T22:32:55.525166Z", - "shell.execute_reply": "2023-10-11T22:32:55.524328Z" + "iopub.execute_input": "2024-09-07T15:09:11.983068Z", + "iopub.status.busy": "2024-09-07T15:09:11.982742Z", + "iopub.status.idle": "2024-09-07T15:09:12.021638Z", + "shell.execute_reply": "2024-09-07T15:09:12.020841Z" } }, "outputs": [ @@ -493,10 +493,10 @@ "execution_count": 6, "metadata": { "execution": { - "iopub.execute_input": "2023-10-11T22:32:55.529489Z", - "iopub.status.busy": "2023-10-11T22:32:55.528716Z", - "iopub.status.idle": "2023-10-11T22:32:55.625872Z", - "shell.execute_reply": "2023-10-11T22:32:55.625020Z" + "iopub.execute_input": "2024-09-07T15:09:12.032702Z", + "iopub.status.busy": "2024-09-07T15:09:12.032397Z", + "iopub.status.idle": "2024-09-07T15:09:12.123383Z", + "shell.execute_reply": "2024-09-07T15:09:12.122568Z" } }, "outputs": [ @@ -504,13 +504,13 @@ "data": { "text/html": [ "\n", - "
\n", + "
\n", "