.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/_general/plot_moon_clustering.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples__general_plot_moon_clustering.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples__general_plot_moon_clustering.py:


================================================================================================
Drawing a decision boundary between two interlacing moons
================================================================================================

This example revisits the experiment from the original GEMINI paper, in which the goal is to
recover the true clusters formed by two facing moons. The trick is to supply a custom distance
through the "precomputed" option, which guides the clustering algorithm to the desired solution.

Note that we use :class:`gemclus.mlp.MLPWasserstein` because a linear model cannot find the
optimal boundary.

.. GENERATED FROM PYTHON SOURCE LINES 13-21

.. code-block:: Python


    import numpy as np
    from matplotlib import pyplot as plt
    from scipy.sparse import csgraph
    from sklearn import datasets, metrics

    from gemclus.mlp import MLPWasserstein


.. GENERATED FROM PYTHON SOURCE LINES 22-24

Generate two interlacing moons
--------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 26-28

.. code-block:: Python


    X, y = datasets.make_moons(n_samples=200, noise=0.05, random_state=2023)


.. GENERATED FROM PYTHON SOURCE LINES 29-31

Pre-compute a specific metric between samples
--------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 33-46

.. code-block:: Python


    # Create an adjacency graph where two samples are connected by an edge if their
    # distance is less than the 5% quantile of the Euclidean distances
    distances = metrics.pairwise_distances(X, metric="euclidean")
    threshold = np.quantile(distances, 0.05)
    adjacency = distances < threshold

    # Compute the all-pairs shortest paths (in number of edges) in this graph
    distances = csgraph.floyd_warshall(adjacency, directed=False, unweighted=True)

    # Replace np.inf (unreachable pairs) with 2 times the size of the matrix
    distances[np.isinf(distances)] = 2 * distances.shape[0]


.. GENERATED FROM PYTHON SOURCE LINES 47-50

Train the model
--------------------------------------------------------------
Note that we set ``metric="precomputed"`` and pass our distance matrix alongside ``X`` when
fitting (here through ``fit_predict``).

.. GENERATED FROM PYTHON SOURCE LINES 52-55

.. code-block:: Python


    model = MLPWasserstein(n_clusters=2, metric="precomputed", random_state=2023, learning_rate=1e-2)
    y_pred = model.fit_predict(X, distances)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/.local/lib/python3.10/site-packages/sklearn/base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function becomes public and is part of the scikit-learn developer API.
      warnings.warn(


.. GENERATED FROM PYTHON SOURCE LINES 56-58

Final Clustering
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 60-71

.. code-block:: Python


    # Evaluate the model on a 50 x 50 grid covering the data to draw the decision regions
    x_vals = np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, num=50)
    y_vals = np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, num=50)
    xx, yy = np.meshgrid(x_vals, y_vals)
    grid_inputs = np.c_[xx.ravel(), yy.ravel()]
    zz = model.predict(grid_inputs).reshape((50, 50))

    # Overlay the samples, colored by predicted cluster, on the decision regions
    plt.contourf(xx, yy, zz, alpha=0.3, cmap=plt.cm.Spectral)
    plt.scatter(X[:, 0], X[:, 1], c=y_pred)

    plt.axis("off")
    plt.show()


.. image-sg:: /auto_examples/_general/images/sphx_glr_plot_moon_clustering_001.png
   :alt: plot moon clustering
   :srcset: /auto_examples/_general/images/sphx_glr_plot_moon_clustering_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 7.955 seconds)


.. _sphx_glr_download_auto_examples__general_plot_moon_clustering.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_moon_clustering.ipynb <plot_moon_clustering.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_moon_clustering.py <plot_moon_clustering.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_moon_clustering.zip <plot_moon_clustering.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
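
The graph-based distance pre-computed in this example can be checked by hand on a toy dataset.
The following is a minimal sketch, not part of the original script: the array ``X_toy`` and the
threshold ``1.5`` are hypothetical values chosen for illustration, and the Euclidean distances
are computed with plain NumPy rather than ``metrics.pairwise_distances``.

.. code-block:: Python

    import numpy as np
    from scipy.sparse import csgraph

    # Hypothetical toy data: two well-separated groups of three points on a line
    X_toy = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]])

    # Pairwise Euclidean distances via broadcasting
    euclidean = np.linalg.norm(X_toy[:, None, :] - X_toy[None, :, :], axis=-1)

    # Connect samples that are closer than a hand-picked threshold (an assumption
    # for this toy data; the example above uses the 5% quantile instead)
    adjacency = euclidean < 1.5

    # Number of edges on the shortest path between every pair of samples
    hops = csgraph.floyd_warshall(adjacency, directed=False, unweighted=True)

    # Pairs with no connecting path get a large finite value, as in the example
    hops[np.isinf(hops)] = 2 * hops.shape[0]

    print(hops)

With this toy data, neighbouring samples inside a group should be one hop apart, samples 0 and 2
two hops apart, and samples from different groups should receive the fallback value 12. The same
effect lets the clustering model treat each moon as a connected component, even though the moons
are close in Euclidean terms.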