# TDA: TP 2

# Statistics, Unsupervised, and Supervised Machine Learning on 3D shapes with Topological Data Analysis

In this practical session, we will use the various TDA tools presented in class to run data science tasks (inference, clustering, classification) on a data set of 3D shapes. As in the first practical session, we will use [`Gudhi`](https://gudhi.inria.fr/) (see the first practical session for installation instructions). The different sections of this notebook can be run independently (except Section 0, which is mandatory), so feel free to start with the project that sounds the most interesting to you :-)

Note also that if you switch from one section to another, you should clear all variables first (and run Section 0 again), since some variable names are shared between sections.

In [None]:
import gudhi as gd
print(gd.__version__)

In [None]:
import gudhi.clustering.tomato as gdt
import gudhi.representations as gdr

The `TensorFlow` module of `Gudhi` is only required in Section 4.

In [None]:
import gudhi.tensorflow as gdtf

Other than that, you are free to use whatever other Python package you feel comfortable with :-) We make some suggestions below (these dependencies are also required to run our solutions to the exercises). 

In [None]:
import os
import sys

We will use three standard Python libraries: `NumPy`, `SciPy` and `Matplotlib`.

In [None]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

In [None]:
%matplotlib notebook

In order to visualize 3D shapes, we will use [`meshplot`](https://skoch9.github.io/meshplot/tutorial/).

In [None]:
import meshplot as mp

Finally, some dependencies are section-specific: we list those below.

In Section 1, when running bootstrap tests on ToMATo, we will use the `statistics` Python module. These tests will be based on the Laplace-Beltrami operator, which can be computed with [`robust_laplacian`](https://pypi.org/project/robust-laplacian/).

In [None]:
import statistics
import robust_laplacian as rlap

In Section 2, we will use the [`networkx`](https://networkx.org/) package to visualize and run computations on Mapper graphs.

In [None]:
import networkx as nx

In Sections 3 and 4, when computing vectorizations and performing supervised machine learning and deep learning tasks, we will use various modules of [`Scikit-Learn`](https://scikit-learn.org/stable/index.html) and [`TensorFlow`](https://www.tensorflow.org/). 

In [None]:
import sklearn.preprocessing as skp
import sklearn.neighbors as skn
import sklearn.model_selection as skm
import sklearn.decomposition as skd
import sklearn.manifold as skf
import sklearn.pipeline as skl
import sklearn.svm as sks
import sklearn.ensemble as ske

In [None]:
import itertools
import tensorflow as tf

# Section 0: Data set manipulation

We are good to go! First things first, we have to download the data set. It can be obtained [here](https://people.cs.umass.edu/~kalo/papers/LabelMeshes/labeledDb.7z). Extract it, and save its path in the `dataset_path` variable.

In [None]:
dataset_path = './3dshapes/'

As you can see, the data set is split into several categories (`Airplane`, `Human`, `Teddy`, etc), each category having its own folder. Inside each folder, some 3D shapes (i.e., 3D triangulations) are provided in [`.off`](https://en.wikipedia.org/wiki/OFF_(file_format)) format, and face (i.e., triangle) labels are provided in text files (extension `.txt`). 

Every data science project begins with some preprocessing ;-) 

Write a function `off2numpy` that reads information from an `.off` file and stores it in two `NumPy` arrays, called `vertices` (type float and shape number_of_vertices x 3: the 3D coordinates of the vertices) and `faces` (type integer and shape number_of_faces x 3: the IDs of the vertices that form each face). Also write a function `get_labels` that stores the face labels of a given 3D shape in a `NumPy` array (type string or integer and shape [number_of_faces]).

In [None]:
def off2numpy(shape_name):
    with open(shape_name, 'r') as S:
        S.readline()  # skip the 'OFF' header line
        # Second line: number of vertices, faces and edges
        num_vertices, num_faces, _ = [int(n) for n in S.readline().split()]
        info = S.readlines()
    # One vertex per line: x y z
    vertices = np.array([[float(coord) for coord in l.split()] for l in info[0:num_vertices]])
    # One face per line: vertex count followed by the vertex IDs
    faces = np.array([[int(coord) for coord in l.split()[1:]] for l in info[num_vertices:num_vertices + num_faces]])
    return vertices, faces

In [None]:
def get_labels(label_name, num_faces):
    L = np.empty([num_faces], dtype='|S100')
    with open(label_name, 'r') as S:
        info = S.readlines()
    # Lines alternate: a label name, then the (1-based) IDs of the faces carrying it
    labels, face_indices = info[0::2], info[1::2]
    for lab, line in zip(labels, face_indices):
        indices = [int(f) - 1 for f in line.split()]
        L[np.array(indices, dtype=int)] = lab.strip()
    return L

You can now apply your code and use `meshplot` to visualize a given 3D shape, say `61.off` in `Airplane`, and the labels on its faces.

In [None]:
vertices, faces = off2numpy(dataset_path + 'Airplane/61.off')
label_faces = get_labels(dataset_path + 'Airplane/61_labels.txt', len(faces))

In [None]:
mp.plot(vertices, faces, c=skp.LabelEncoder().fit_transform(label_faces))

# Section 1: 3D robust segmentation with ToMATo

In this section, our goal is to use the ToMATo algorithm to compute segmentations of 3D shapes, i.e., to assign labels to 3D shape vertices in an unsupervised way, that is, without training on known labels. This task was initially explored in [this article](https://www.lix.polytechnique.fr/~maks/papers/pers_seg.pdf). 

Overall, the main idea is to run ToMATo on the neighborhood graph given by the triangulation, with the so-called Heat Kernel Signature (HKS) as the filter. This is motivated by the fact that the HKS function typically takes higher values on the parts of the 3D shape that are very curved (such as, e.g., the tips of fingers in human hand shapes). 

The HKS was defined in [this article](https://onlinelibrary.wiley.com/doi/epdf/10.1111/j.1467-8659.2009.01515.x). It is related to the heat equation on a given 3D shape $S$:

$$\Delta_S f = -\frac{\partial f}{\partial t}.$$

More formally, the HKS function with parameter $t >0$ on a vertex $v\in S$, and denoted by ${\rm HKS}_t(v)$, is computed as:

$${\rm HKS}_t(v) = \sum_{i=0}^{+\infty} {\rm exp}(-\lambda_i\cdot t)\cdot \phi_i^2(v),$$

where $\{\lambda_i, \phi_i\}_i$ are the eigenvalues and eigenvectors of $\Delta_S$.
Intuitively, ${\rm HKS}_t(v)$ is the amount of heat remaining on $v$ at time $t$, after unit sources of heat have been placed on each vertex at time $t=0$.

Let's first pick a 3D shape. For instance, use `Hand/181.off` (or any other one you would like to try).

In [None]:
vertices, faces = off2numpy(dataset_path + 'Hand/181.off')

Now, use `robust_laplacian` to compute the first 200 eigenvalues and eigenvectors of its Laplacian (you can use the `eigsh` function of `SciPy` for diagonalizing the Laplacian).

Write a function `HKS` that uses these eigenvalues and eigenvectors, as well as a time parameter, to compute the HKS value on a given vertex.
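As a rough guide, the truncated sum in the HKS formula can be written in a few lines of `NumPy`. This is only a sketch: the function name `hks`, its argument layout (eigenvectors stored as columns), and the toy path-graph Laplacian used to keep the example self-contained are assumptions, not part of the exercise. On a real mesh, the eigenpairs would come from the Laplacian (and mass matrix) of the shape.

```python
import numpy as np

def hks(eigenvalues, eigenvectors, t):
    """Truncated Heat Kernel Signature at every vertex for time t.

    eigenvalues: array of shape (k,)
    eigenvectors: array of shape (n_vertices, k), one eigenvector per column
    """
    # sum_i exp(-lambda_i * t) * phi_i(v)^2, evaluated for all vertices at once
    return (np.exp(-eigenvalues * t)[None, :] * eigenvectors ** 2).sum(axis=1)

# Toy example (assumption): combinatorial Laplacian of a path graph with 4 vertices
A = np.diag(np.ones(3), 1) + np.diag(np.ones(3), -1)
L = np.diag(A.sum(axis=1)) - A
evals, evecs = np.linalg.eigh(L)

hks_values = hks(evals, evecs, t=0.5)
print(hks_values)
```

The value on a single vertex `v` is then simply `hks(evals, evecs, t)[v]`.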

Visualize the function values with `meshplot` for different time parameters.

Recall that ToMATo requires, in addition to the filter, a neighborhood graph built on top of the data. Fortunately, we can use the triangulations of our 3D shapes as input graphs! Write a function `get_neighborhood_graph_from_faces` that computes a neighborhood graph (in the format required by ToMATo) from the faces of a triangulation. 
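One possible shape for this function is sketched below. It assumes that ToMATo, when used with a manually specified graph, accepts one list of neighbor indices per point; the extra `num_vertices` argument is a choice made here so that isolated vertices still get an (empty) entry.

```python
def get_neighborhood_graph_from_faces(num_vertices, faces):
    """Return, for each vertex, the sorted list of vertices sharing a triangle edge with it."""
    neighbors = [set() for _ in range(num_vertices)]
    for a, b, c in faces:
        # Each triangle contributes its three edges
        neighbors[a].update((b, c))
        neighbors[b].update((a, c))
        neighbors[c].update((a, b))
    return [sorted(s) for s in neighbors]

# Toy example: the four faces of a tetrahedron
tetra_faces = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
graph = get_neighborhood_graph_from_faces(4, tetra_faces)
print(graph)
```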

Finally, apply ToMATo (with no prior on the number of clusters or merging threshold) on the neighborhood graph and the HKS function associated to a given time parameter.

Visualize the persistence diagram produced by ToMATo.

How many points do you see standing out from the diagonal? Use this number to re-cluster.

Visualize the 3D shape with the ToMATo labels.

Does our segmentation make sense? Can you interpret the boundaries between labels?

Since the boundaries are driven by the elder rule, they can seem a bit shaggy. In order to fix this, we can use bootstrap-like smoothing. The idea is to first save the current ToMATo clustering obtained with filter $f$ (let's call it the initial clustering), then perturb $f$ a little bit into another function $\tilde f$, and finally recompute the clustering with ToMATo using $\tilde f$. Since clusters are now created from the maxima of $\tilde f$ (which will in general be different from those of $f$), we can use the initial clustering to relate the clusters of $\tilde f$ to those of $f$, by simply looking at which (initial) clusters the maxima of $\tilde f$ belong to. If we repeat this procedure $N$ times, we will end up with a distribution (of size $N$) of candidate clusters for each vertex $v$. It suffices to pick the most frequent one for each vertex to get a smooth segmentation of the 3D shape. See also Section 6 in [the article](https://www.lix.polytechnique.fr/~maks/papers/pers_seg.pdf).

In order to implement this, write first a function `get_indices_of_maxima` which computes the indices of the maxima associated to a set of ToMATo clusters.
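A minimal sketch of such a function is given below, assuming integer cluster labels (one per vertex) and the filter values as a second argument. The exact signature is an assumption.

```python
import numpy as np

def get_indices_of_maxima(labels, filter_values):
    """Map each cluster label to the index of its vertex with the largest filter value."""
    labels = np.asarray(labels)
    filter_values = np.asarray(filter_values)
    maxima = {}
    for lab in np.unique(labels):
        members = np.flatnonzero(labels == lab)
        maxima[int(lab)] = int(members[np.argmax(filter_values[members])])
    return maxima

# Toy example: two clusters on five vertices
toy_maxima = get_indices_of_maxima([0, 0, 0, 1, 1], [0.1, 0.5, 0.9, 0.3, 0.2])
print(toy_maxima)
```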

Compute and plot these maxima on the 3D shape to make sure your code is working.

Now, use this function to write another function `bootstrap_tomato` that performs a bootstrap smoothing of a set of ToMATo labels. This function will also take as arguments a number $N$ of bootstrap iterations, and a parameter $\epsilon$ controlling the amplitude of the uniform noise used to perturb the filter.
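The voting logic of the bootstrap can be sketched independently of ToMATo. In the sketch below, `cluster_fn` is a hypothetical callable that maps perturbed filter values to cluster labels (in practice it would wrap a ToMATo run on the neighborhood graph), and `bootstrap_labels` is a name chosen here for illustration. Votes are counted in an array rather than with the `statistics` module.

```python
import numpy as np

def bootstrap_labels(initial_labels, filter_values, cluster_fn, N=20, eps=0.01, seed=0):
    """Smooth a clustering by majority vote over N perturbations of the filter."""
    rng = np.random.default_rng(seed)
    initial_labels = np.asarray(initial_labels)
    filter_values = np.asarray(filter_values, dtype=float)
    classes = np.unique(initial_labels)
    votes = np.zeros((len(filter_values), len(classes)), dtype=int)
    for _ in range(N):
        # Perturb the filter with uniform noise of amplitude eps
        noisy = filter_values + rng.uniform(-eps, eps, size=filter_values.shape)
        new_labels = np.asarray(cluster_fn(noisy))
        for lab in np.unique(new_labels):
            members = np.flatnonzero(new_labels == lab)
            # The maximum of the new cluster decides which initial cluster it maps to
            peak = members[np.argmax(noisy[members])]
            col = np.searchsorted(classes, initial_labels[peak])
            votes[members, col] += 1
    # Most frequent candidate for each vertex
    return classes[np.argmax(votes, axis=1)]

# Toy usage with a stand-in clustering that always splits the points into two halves
toy_filter = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])
toy_initial = np.array([5, 5, 5, 7, 7, 7])
toy_cluster_fn = lambda f: (np.arange(len(f)) >= 3).astype(int)
smoothed = bootstrap_labels(toy_initial, toy_filter, toy_cluster_fn, N=10, eps=0.01)
print(smoothed)
```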

Apply the bootstrap smoothing and visualize the segmentation.

Is the segmentation any better? How does the result depend on the noise amplitude?

# Section 2: 3D shape skeletonization with Mapper

In this section, our goal is to use Mapper to produce 1-skeletons (i.e., graphs) of 3D shapes. We will also see how to partition this graph into different parts and run statistical tests to decide whether these parts should be considered as numerical artifacts or true signal.

Let's first pick a 3D shape. For instance, use `Human/4.off` (or any other one you would like to try).

In [None]:
vertices, faces = off2numpy(dataset_path + 'Human/4.off')

In [None]:
mp.plot(vertices, faces, c=vertices[:,1])

In `Gudhi`, Mapper is implemented as a specific case of Graph Induced Complex (GIC), see Definition 2.1 in [this article](https://web.cse.ohio-state.edu/~dey.8/paper/GIC/GIC.pdf). Indeed, given a fixed vertex cover, Mapper computed with hierarchical clustering with parameter $\delta > 0$ is (roughly) the same as GIC computed with neighborhood graph with parameter $\delta$.

Initiate a `CoverComplex` from `Gudhi`, and set its type to `"GIC"`.

Define the point cloud on which Mapper is computed with the array `vertices`, and the filter function as the height coordinate.

Define the node color function (used only for visualization) as the height coordinate as well.

Define the clustering algorithm by automatically tuning the $\delta$ parameter. This can be done by setting the neighborhood graph automatically with the `set_graph_from_automatic_rips` function. 

Finally, define the cover parameters: 20 intervals for the range of the filter (this parameter is called resolution), and 30% overlap (this one is called gain). Then, compute the cover using preimages of the intervals.
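To get a feeling for what resolution and gain mean, here is a small `NumPy` sketch that builds an interval cover by hand. It uses one common convention (equal-length intervals whose consecutive overlap is a fraction `gain` of their length); this is an assumption, and the intervals built internally by `Gudhi` may differ in their details. The name `interval_cover` is hypothetical.

```python
import numpy as np

def interval_cover(fmin, fmax, resolution, gain):
    """Cover [fmin, fmax] with `resolution` equal intervals overlapping by a fraction `gain`."""
    length = (fmax - fmin) / (resolution - (resolution - 1) * gain)
    step = length * (1.0 - gain)
    starts = fmin + step * np.arange(resolution)
    return np.column_stack([starts, starts + length])

cover = interval_cover(0.0, 1.0, resolution=20, gain=0.3)

# Preimages: indices of the points whose filter value falls in each interval
toy_filter = np.linspace(0.0, 1.0, 50)
preimages = [np.flatnonzero((toy_filter >= a) & (toy_filter <= b)) for a, b in cover]
print(cover[:3])
```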

We can now compute Mapper!

During computations, the pairwise distances are saved in a binary file `matrix_dist` in order to save time for further computations. Hence, if you want to apply Mapper again on a different shape, make sure to remove this file!

The simplicial complex produced by Mapper can be obtained with the `create_simplex_tree` function. However, its vertices are given integer IDs associated to the cover used to compute Mapper. For convenience, rename the vertices from 0 to number_of_vertices in increasing order of the initial IDs. 

`Gudhi` also computes the mean of the midpoint of the interval associated to each Mapper vertex, and stores it as a filtration value. Check that you have correct filtration values in your simplex tree (at least by eye ;-)).

With the `write_info` function, `Gudhi` can produce a `.txt` file containing information about the 1-skeleton of the Mapper, which can be processed by a utility function, available [here](https://github.com/GUDHI/gudhi-devel/blob/master/src/Nerve_GIC/utilities/KeplerMapperVisuFromTxtFile.py). Download and apply this utility function. This will produce an `.html` file that you can visualize in your browser.

Another (more convenient) way to visualize our complex is to plot its 1-skeleton in a Python figure with `networkx`. Using the `networkx` documentation, write a function `get_networkx` that turns a simplicial complex into a `networkx` graph corresponding to the 1-skeleton. Make it so that the `networkx` graph has two attributes, `"color"` and `"size"`, that contain the mean of the filter values of the points associated to the Mapper vertices, and the number of these points (i.e., the size of the preimages) respectively. For this, you can use the `subpopulation` method of Mapper, which returns the point IDs corresponding to every Mapper vertex.

Apply your function and plot your graph with `networkx.draw`. 

As seen in class, we can now compute a bag-of-features descriptor for Mapper, defined as the extended persistence diagram of the Mapper complex associated to the filter. Compute and visualize this descriptor with the `compute_PD` function.

Can you guess the parts of the 3D shape that are associated to each persistence diagram point?

In order to understand the parts that are relevant, we can use the (empirical) bootstrap to generate a distribution of bottleneck distances (computed as the distances between our current persistence diagram and a distribution of persistence diagrams obtained from bootstrapped Mapper complexes), and use this distribution to derive confidence regions. Compute first such a distribution with the `compute_distribution` function.

Now, fix a confidence threshold, say 90%, and retrieve the bottleneck distance value $d_b^*$ such that 90% of distances are below this value. You can use the `compute_distance_from_confidence_level` function for that.

Finally, retrieve the points of our current persistence diagram whose distance to the diagonal is larger than $d_b^*$.
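The last two steps (quantile of the bootstrap distances, then thresholding) can be sketched in `NumPy` as follows. The sketch assumes that the diagram is given as an array of (birth, death) pairs and that the distance to the diagonal is measured in the sup norm, i.e., $|d-b|/2$, as in the bottleneck distance; other conventions differ by a constant factor. The function name `robust_points` is hypothetical.

```python
import numpy as np

def robust_points(diagram, distances, confidence=0.9):
    """Keep the diagram points lying farther from the diagonal than the confidence quantile."""
    diagram = np.asarray(diagram, dtype=float)
    threshold = np.quantile(distances, confidence)
    # abs() because extended persistence points may lie below the diagonal
    dist_to_diagonal = np.abs(diagram[:, 1] - diagram[:, 0]) / 2.0
    return diagram[dist_to_diagonal > threshold], threshold

# Toy example with made-up bootstrap distances
toy_diagram = [[0.0, 1.0], [0.0, 0.1], [2.0, 0.5]]
toy_distances = np.linspace(0.0, 0.2, 11)
kept, threshold = robust_points(toy_diagram, toy_distances, confidence=0.9)
print(threshold, kept)
```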

Some points were assessed as non robust, can you guess why?

Finally, one might ask whether there is a direct map from the points of the persistence diagrams to the parts of the 3D shape. It is actually a non-trivial question for Mappers of dimension greater than 2, but for Mappers in dimension 1, it is easier. Indeed, connected components and loops (corresponding to persistence diagram points in ${\rm Ext}_0$ and ${\rm Ext}_1$ respectively---see class) are standard graph features. 

Compute and visualize the connected components with the `connected_components` function of `networkx`.

Compute and visualize the loops with the `cycle_basis` function of `networkx`.

Now, concerning branches, i.e., points in ${\rm Ord}_0$ and ${\rm Rel}_1$, the question is a bit more tricky, but fortunately one can use ToMATo as an approximate solution. This is because ToMATo keeps track of the points forming connected components that are merged later on, which correspond to branches! Hence, one can apply ToMATo with the filter (resp. the opposite of the filter) to obtain the upward (resp. downward) branches.

Since ToMATo requires neighborhood graphs as inputs, write a function `get_neighborhood_graph_from_simplex_tree` that computes the neighborhood graph associated to the 1-skeleton of a simplex tree, in a format that is acceptable for ToMATo.

Now, compute this neighborhood graph and apply ToMATo using both the filter function and its opposite (with no prior on the number of clusters or merging threshold). 

Finally, visualize the ToMATo labels on the graph.

The branches should be detected and highlighted with different colors!

# Section 3: 3D shape statistics with persistence diagrams

In this section, our goal is to compute confidence regions associated to the persistence diagram of a given 3D shape. We will study such regions for both the persistence diagram and one of its representations, the persistence landscape. 

Let's first pick a 3D shape. For instance, use `Vase/361.off` (or any other one you would like to try).

In [None]:
vertices, faces = off2numpy(dataset_path + 'Vase/361.off')

In [None]:
mp.plot(vertices, faces, c=vertices[:,1])

The first standard way of obtaining confidence regions for (geometric) persistence diagrams is through the stability theorem (see class):

$$\mathbb{P}(d_b(D_{\rm Rips}(X),D_{\rm Rips}(\hat X_n)) \geq \delta)\leq \mathbb{P}(d_H(X,\hat X_n)\geq \delta/2),$$
$$\mathbb{P}(d_b(D_{\rm Cech}(X),D_{\rm Cech}(\hat X_n)) \geq \delta)\leq \mathbb{P}(d_H(X,\hat X_n)\geq \delta),$$

where $d_H(\cdot,\cdot)$ is the Hausdorff distance, defined, for any two compact spaces $X,Y\subset \mathbb{R}^d$, as 

$$d_H(X,Y)={\rm max}\{{\rm max}_{x\in X}{\rm min}_{y\in Y}\|x-y\|, {\rm max}_{y\in Y}{\rm min}_{x\in X}\|y-x\|\}.$$

Hence, it suffices to estimate $\mathbb{P}(d_H(X,\hat X_n)\geq \delta)$ in order to derive confidence regions for persistence diagrams. There exists an upper bound for this probability when $\hat X_n$ is drawn from an $(a,b)$-standard probability measure; however, this bound depends on $a$ and $b$. In the following, we will rather use the subsampling method, which estimates this probability solely by subsampling $\hat X_n$ with $s(n) =o\left(\frac{n}{{\rm log}(n)}\right)$ points and computing $d_H(\hat X_n, \hat X_{s(n)})$. The exact procedure is described in Section 4.1 in [this article](https://doi.org/10.1214/14-AOS1252).

Write a function `hausdorff_distance` that computes the Hausdorff distance between the vertices of our 3D shape and a subset of these vertices.
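Since the subsample is contained in the full point set, the Hausdorff distance reduces to the largest distance from a point to its nearest neighbor in the subsample. A possible sketch, assuming the subset is passed as an array of indices and using a k-d tree from `SciPy` for the nearest-neighbor queries:

```python
import numpy as np
from scipy.spatial import cKDTree

def hausdorff_distance(points, subset_indices):
    """Hausdorff distance between `points` and the subset `points[subset_indices]`."""
    points = np.asarray(points, dtype=float)
    tree = cKDTree(points[subset_indices])
    # Distance from every point to its nearest neighbor in the subset
    nearest, _ = tree.query(points)
    return float(nearest.max())

# Toy example: three points on a line
toy_points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
print(hausdorff_distance(toy_points, [0]))
```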

Now, write a function `hausdorff_interval` that computes this Hausdorff distance many times and uses the corresponding distribution of Hausdorff distances in order to output the bottleneck distance value associated to a given confidence level (by computing the quantile---corresponding to this confidence level---of the distribution).
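A self-contained sketch of the subsampling loop is given below. The argument names, the default number of iterations, and the choice of returning the raw quantile (leaving any multiplicative factor coming from the stability theorem to the caller) are assumptions made for illustration.

```python
import numpy as np
from scipy.spatial import cKDTree

def hausdorff_interval(points, subsample_size, confidence=0.9, n_iterations=100, seed=0):
    """Quantile of the Hausdorff distances between `points` and random subsamples."""
    rng = np.random.default_rng(seed)
    points = np.asarray(points, dtype=float)
    distances = np.empty(n_iterations)
    for i in range(n_iterations):
        # Draw a subsample without replacement
        idx = rng.choice(len(points), size=subsample_size, replace=False)
        # The subsample is a subset, so only one directed distance is non-zero
        nearest, _ = cKDTree(points[idx]).query(points)
        distances[i] = nearest.max()
    return float(np.quantile(distances, confidence)), distances

# Toy example on a random point cloud
toy_cloud = np.random.default_rng(1).random((200, 3))
toy_quantile, toy_distances = hausdorff_interval(toy_cloud, subsample_size=50)
print(toy_quantile)
```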

Apply your code to obtain a bottleneck distance associated to, say, 90% confidence.

All right, now let's see which points of the persistence diagram we are going to label as non-significant and discard. Compute the Rips and Alpha persistence diagrams of the points. 

Now, visualize the persistence diagrams with a band whose width is the previously computed bottleneck distance multiplied by 2 (for the Alpha filtration) or 4 (for the Rips filtration).

Are you discarding many points? If yes, this could be because the confidence region is computed only from the stability property of persistence diagrams: subsampling the Hausdorff distance can sometimes be quite conservative. It would be more efficient to bootstrap the persistence diagrams themselves---this is the approach advocated in Section 6 of [this article](https://www.jmlr.org/papers/volume18/15-484/15-484.pdf). However, this method was only proved for persistence diagrams obtained through the sublevel sets of kernel density estimators... But let's try it anyway! ;-)

As before, write `bottleneck_distance_bootstrap` and `bottleneck_interval` functions that compute the bottleneck distances between our current persistence diagram (in homology dimension 1) and the persistence diagrams of many bootstrap iterates.

Compute the bottleneck distance associated to a confidence level and visualize it.

Are you discarding fewer points in the persistence diagram now?

Another approach with more theoretical guarantees is to use the persistence landscapes associated to the persistence diagram. Indeed, valid confidence regions can be easily obtained using, e.g., algorithm 1 in [this article](https://geometrica.saclay.inria.fr/team/Fred.Chazal/papers/cflrw-scpls-14/cflrw-scpls-14.pdf). In the following, we will fix a subsample size $s(n)$, and estimate $\mathbb{E}[\Lambda_{s(n)}]$, where $\Lambda_{s(n)}$ is the landscape of a random subsample of size $s(n)$ (i.e., drawn from a probability measure $\mu$ such as, e.g., the empirical measure). 

Let's first make sure that we can compute landscapes ;-) Use `Gudhi` to compute and plot the first six persistence landscapes associated to the persistence diagram computed above in homology dimension 1. Landscapes (and other vectorizations) are implemented with the API of `Scikit-Learn` estimators, which means that you have to call the `fit_transform` method on a list of persistence diagrams in order to get their landscapes. 

Write a function `landscape_interval` that implements the landscape bootstrap procedure, that is, drawing many subsamples of size $s(n)$, computing their Alpha persistence diagrams and landscapes, computing the distribution of distances between each single landscape and their mean (multiplied by a random normal variable), and finally using the quantiles of this distribution in order to obtain confidence regions for the mean landscape.
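The multiplier bootstrap part of this procedure can be sketched on its own, once the landscapes of the subsamples are available. The sketch below assumes they are stored in an array `landscapes` with one discretized landscape per row (a hypothetical input; computing it requires the Alpha persistence diagrams of the subsamples). It follows the multiplier bootstrap as we understand it from the article, so treat it as an illustration rather than a reference implementation.

```python
import numpy as np

def landscape_band(landscapes, confidence=0.9, n_bootstrap=200, seed=0):
    """Uniform confidence band around the mean landscape via a multiplier bootstrap."""
    rng = np.random.default_rng(seed)
    landscapes = np.asarray(landscapes, dtype=float)
    n = landscapes.shape[0]
    mean = landscapes.mean(axis=0)
    centered = landscapes - mean
    # One row of independent standard normal multipliers per bootstrap iteration
    xi = rng.standard_normal((n_bootstrap, n))
    # Sup norm of the randomly weighted, centered sum of landscapes
    theta = np.abs(xi @ centered).max(axis=1) / np.sqrt(n)
    half_width = np.quantile(theta, confidence) / np.sqrt(n)
    return mean - half_width, mean + half_width

# Toy example with random non-negative "landscapes"
toy_landscapes = np.random.default_rng(2).random((30, 50))
lower, upper = landscape_band(toy_landscapes)
print(upper[0] - lower[0])
```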

Apply and visualize the confidence interval around the different landscapes.

The confidence regions are much better now!

Another interesting property of mean landscapes is their robustness to noise:

$$\|\mathbb{E}[\Lambda_{s(n)}^X]-\mathbb{E}[\Lambda_{s(n)}^Y]\|_\infty\leq 2 \cdot s(n) \cdot d_{GW}(\mu,\nu),$$

where $d_{GW}$ is the 1-Gromov-Wasserstein distance between probability measures. See Remark 6 in [this article](https://geometrica.saclay.inria.fr/team/Fred.Chazal/papers/cflmrw-smph-15/ICMLFinal.pdf). We will now confirm this by adding outlier noise to the 3D shape and looking at the resulting mean landscape. 

Create a noisy version of `vertices` with some outlier noise.
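One simple way to do this, among many, is to replace a small fraction of the vertices with points drawn uniformly in the bounding box of the shape. The function name `add_outliers` and the default fraction are arbitrary choices made for this sketch.

```python
import numpy as np

def add_outliers(vertices, fraction=0.05, seed=0):
    """Replace a fraction of the points by uniform samples in the bounding box."""
    rng = np.random.default_rng(seed)
    vertices = np.asarray(vertices, dtype=float)
    noisy = vertices.copy()
    n_outliers = int(round(fraction * len(vertices)))
    idx = rng.choice(len(vertices), size=n_outliers, replace=False)
    lo, hi = vertices.min(axis=0), vertices.max(axis=0)
    noisy[idx] = rng.uniform(lo, hi, size=(n_outliers, vertices.shape[1]))
    return noisy, idx

# Toy example on a random point cloud
toy_vertices = np.random.default_rng(3).random((100, 3))
toy_noisy, toy_idx = add_outliers(toy_vertices, fraction=0.1)
print(len(toy_idx))
```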

Let's first compare the persistence landscapes of the two sets of vertices. Compute and visualize these landscapes on the same plot.

As one can see, they are quite different. By contrast, computing the mean landscape with subsamples is much more robust, as we will now see.

Compute the mean persistence landscape of the noisy point cloud, and visualize it next to the mean persistence landscape of the clean point cloud.

Now, these landscapes look much more in agreement!

# Section 4: 3D shape classification with persistence diagrams

In this section, our goal is to use persistence diagrams for classifying and segmenting 3D shapes with supervised machine learning. 

Let's start with classification. We will compute persistence diagrams for all shapes in different categories, and train a classifier from `Scikit-Learn` to predict the category from the persistence diagrams. Since `Gudhi` requires simplex trees for the persistence diagram computations, write a `get_simplex_tree_from_faces` function that builds a simplex tree from the faces of a given 3D shape triangulation.

Now, compute all the persistence diagrams (in homology dimension 0) associated to the sublevel sets of the third coordinate from a few categories, and retrieve their corresponding labels.

As discussed in class, it is not very convenient to use persistence diagrams directly for machine learning purposes (except for a few methods such as $K$-nearest neighbors). What we need is to define a vectorization, that is, a map $\Phi:\mathcal{D}\rightarrow\mathcal{H}$ sending persistence diagrams into a Hilbert space, or equivalently, a symmetric kernel function $k:\mathcal{D}\times \mathcal{D} \rightarrow \mathbb{R}$ such that $k(D,D')=\langle \Phi(D),\Phi(D')\rangle$. Fortunately, there are already a bunch of such maps and kernels in `Gudhi` :-)

In the following we will compute and visualize the most popular kernels on some persistence diagrams. Pick first a specific persistence diagram and use `DiagramSelector` to remove its points with infinite coordinates.

Now, let's see what `Gudhi` has to offer to vectorize persistence diagrams with `Scikit-Learn` estimator-like classes, that is, with classes that have `fit`, `transform`, and `fit_transform` methods (see [this article](https://arxiv.org/pdf/1309.0238.pdf) for more details). For each vectorization mentioned below, we recommend that you play with its parameters and observe their influence on the output in order to get some intuition. 

The first vectorization method that was introduced historically is the persistence landscape. A persistence landscape is basically obtained by rotating the persistence diagram by $-\pi/4$ (so that the diagonal becomes the $x$-axis), and then putting tent functions on each point. The $k$th landscape is then defined as the $k$th largest value among all these tent functions. It is eventually turned into a vector by evaluating it on a bunch of uniformly sampled points on the $x$-axis.
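To make the definition concrete, here is a hand-written `NumPy` sketch of the first few landscapes of a diagram. It uses tent functions of the form $\max(0, \min(x-b, d-x))$ in the original coordinates, which matches the rotated picture up to a constant scaling factor; the landscapes computed by `Gudhi` may therefore differ by such a factor. The function name `landscapes_from_diagram` is hypothetical.

```python
import numpy as np

def landscapes_from_diagram(diagram, grid, num_landscapes=1):
    """Evaluate the first `num_landscapes` landscapes of a diagram on `grid`."""
    diagram = np.asarray(diagram, dtype=float)
    grid = np.asarray(grid, dtype=float)
    birth, death = diagram[:, :1], diagram[:, 1:2]
    # One tent function per diagram point, evaluated on the whole grid
    tents = np.maximum(0.0, np.minimum(grid[None, :] - birth, death - grid[None, :]))
    # Sort the tent values in decreasing order at every grid location
    tents = -np.sort(-tents, axis=0)
    out = np.zeros((num_landscapes, len(grid)))
    k = min(num_landscapes, tents.shape[0])
    out[:k] = tents[:k]
    return out

# Toy example: two nested intervals
toy_landscapes = landscapes_from_diagram([[0.0, 4.0], [1.0, 3.0]], np.arange(5.0), num_landscapes=2)
print(toy_landscapes)
```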

Compute and visualize the first landscape of the persistence diagram for various parameters.

A variation, called the silhouette, takes a weighted average of these tent functions instead. Here, we weight each tent function by the distance of the corresponding point to the diagonal.
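A hand-written sketch of this weighted average is shown below. It weights each tent by $d-b$, which is proportional to the distance to the diagonal, so the normalized average is unchanged by the proportionality constant. As with any such sketch, the scaling may differ from what `Gudhi` returns, and the name `silhouette_from_diagram` is hypothetical.

```python
import numpy as np

def silhouette_from_diagram(diagram, grid):
    """Persistence-weighted average of the tent functions of a diagram."""
    diagram = np.asarray(diagram, dtype=float)
    grid = np.asarray(grid, dtype=float)
    birth, death = diagram[:, :1], diagram[:, 1:2]
    tents = np.maximum(0.0, np.minimum(grid[None, :] - birth, death - grid[None, :]))
    weights = death - birth  # proportional to the distance to the diagonal
    return (weights * tents).sum(axis=0) / weights.sum()

# Toy example: two nested intervals
toy_silhouette = silhouette_from_diagram([[0.0, 4.0], [1.0, 3.0]], np.arange(5.0))
print(toy_silhouette)
```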

The second method is the persistence image. A persistence image is obtained by rotating by $-\pi/4$, centering Gaussian functions on all diagram points (usually weighted by a parameter function, such as, e.g., the squared distance to the diagonal) and summing all these Gaussians. This gives a 2D function, that is pixelized into an image.
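The construction can be sketched by hand as follows. The sketch rotates each point $(b,d)$ to $((b+d)/\sqrt{2}, (d-b)/\sqrt{2})$, so that the second coordinate is exactly the distance to the diagonal, and weights each Gaussian by the square of that distance. The grid bounds, bandwidth and function name are arbitrary choices, and the conventions used by `Gudhi` may differ.

```python
import numpy as np

def persistence_image(diagram, resolution=(20, 20), bandwidth=0.1, bounds=(0.0, 1.0, 0.0, 1.0)):
    """Sum of weighted Gaussians centered on the rotated diagram points, sampled on a grid."""
    diagram = np.asarray(diagram, dtype=float)
    # Rotation by -pi/4: the diagonal becomes the x-axis
    x = (diagram[:, 0] + diagram[:, 1]) / np.sqrt(2.0)
    y = (diagram[:, 1] - diagram[:, 0]) / np.sqrt(2.0)
    xs = np.linspace(bounds[0], bounds[1], resolution[0])
    ys = np.linspace(bounds[2], bounds[3], resolution[1])
    X, Y = np.meshgrid(xs, ys)
    sq_dist = (X[None] - x[:, None, None]) ** 2 + (Y[None] - y[:, None, None]) ** 2
    gaussians = np.exp(-sq_dist / (2.0 * bandwidth ** 2))
    weights = y ** 2  # squared distance to the diagonal
    return (weights[:, None, None] * gaussians).sum(axis=0)

# Toy example: a single point (0, 1)
toy_image = persistence_image([[0.0, 1.0]], resolution=(21, 21))
print(toy_image.shape)
```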

`Gudhi` also has a variety of metrics and kernels, which sometimes perform better than explicit vectorizations such as the ones described above. Pick another persistence diagram, and get familiar with the bottleneck and the Wasserstein distances between them. Note that you can call them in different ways in `Gudhi`, there are `bottleneck_distance` and `wasserstein_distance` functions for instance, but there are also wrappers of these functions into estimator classes `BottleneckDistance` and `WassersteinDistance` (with `fit` and `transform` methods). These classes are especially useful when doing model selection with `Scikit-Learn`, see below.

`Gudhi` also allows you to use standard kernels such as, among others, the persistence scale space kernel, the persistence Fisher kernel, the sliced Wasserstein kernel, etc. Try computing the kernel values for your pair of diagrams.

Before trying to classify the persistence diagrams, let's do a quick dimension reduction with PCA. Apply `PCA`, `KernelPCA` or `MDS` (available in `Scikit-Learn`) on the explicit maps (landscapes, images, etc), kernel matrices (Fisher, sliced Wasserstein, etc) and distance matrices (bottleneck, Wasserstein, etc) respectively.

Is there any method that looks better in separating the categories, at least by eye?

All right, let's try classification now! Shuffle the data, and create a random 80/20 train/test split.

Here is the best thing about having estimator-like classes: they can be integrated flawlessly in a `Pipeline` of `Scikit-Learn` for model selection and cross-validation! A `Pipeline` is itself an estimator, and is initialized with a list of estimators. It will just sequentially apply the `fit_transform` methods of the estimators in the list.

Define a `Pipeline` with four estimators: one for selecting the finite persistence diagram points, one for scaling (or not) the persistence diagrams (with `DiagramScaler`), one for vectorizing persistence diagrams, and one for performing the final prediction. See the [documentation](https://scikit-learn.org/stable/modules/compose.html#combining-estimators).

Now, define a grid of parameters that will be used in cross-validation.

Define and train the model.

Check the parameters that were chosen during model selection, and evaluate your model on the test set.

How is your score? If it is bad, you can always enlarge the parameter and/or classifier search, but this can quickly become quite cumbersome. Moreover, a potential source of error comes from the fact that the third coordinate does not necessarily correspond to the height (i.e., the 3D shapes are not embedded in a consistent way). This is where persistence differentiation can come to the rescue! Indeed, instead of picking a specific coordinate, we can try to optimize a linear combination of the coordinates:

$$f_\alpha: x\mapsto \sum_{i=1}^d \alpha_i x_i,$$

such that the persistence diagrams of the same category are close, while persistence diagrams from different categories are far away from each other. This means minimizing a loss of the form:

$$\alpha^* = {\rm argmin}_\alpha \sum_l \frac{\sum_{y_i=y_j=l}d(D_{f_\alpha}(x_i),D_{f_\alpha}(x_j))}{\sum_{y_i=l,y_j}d(D_{f_\alpha}(x_i),D_{f_\alpha}(x_j))},$$

where $d$ is any (pseudo)-distance between persistence diagrams, that can be differentiated through a deep learning library (such as `TensorFlow` or `PyTorch`). For instance, the sliced Wasserstein distance is quite easy to compute with standard deep learning libraries since it only involves projecting points onto lines. See [this article](http://proceedings.mlr.press/v70/carriere17a/carriere17a.pdf).
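To fix ideas, the loss can be evaluated in `NumPy` once a matrix of pairwise distances between diagrams is available. The sketch below reads the denominator as the sum over all pairs $(i,j)$ with $y_i=l$ and $y_j$ arbitrary; this reading, as well as the function name `class_separation_loss`, is an assumption. A version meant for optimization would perform the same sums with differentiable operations.

```python
import numpy as np

def class_separation_loss(distances, labels):
    """Sum over classes of (intra-class distances) / (distances from the class to everything)."""
    distances = np.asarray(distances, dtype=float)
    labels = np.asarray(labels)
    loss = 0.0
    for lab in np.unique(labels):
        in_class = labels == lab
        intra = distances[np.ix_(in_class, in_class)].sum()
        total = distances[in_class].sum()
        loss += intra / total
    return loss

# Toy example: two classes, intra-class distance 1 and inter-class distance 3
toy_D = np.array([[0, 1, 3, 3],
                  [1, 0, 3, 3],
                  [3, 3, 0, 1],
                  [3, 3, 1, 0]], dtype=float)
toy_loss = class_separation_loss(toy_D, [0, 0, 1, 1])
print(toy_loss)
```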

Write a `deep_swd` function that computes the sliced Wasserstein distance between persistence diagrams with `TensorFlow` or `PyTorch` operations.
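Before moving to a deep learning library, it can help to see the computation in plain `NumPy`. The sketch below is not the requested `deep_swd` function: it is a non-differentiable illustration, under the assumption that the distance is the average over directions of the $L^1$ distance between sorted projections, with each diagram augmented by the diagonal projections of the other one. Normalization conventions vary, and the name `sliced_wasserstein_np` is hypothetical. Every operation used here (matrix product, sort, absolute value, mean) has a direct counterpart in `TensorFlow` and `PyTorch`.

```python
import numpy as np

def sliced_wasserstein_np(D1, D2, thetas):
    """Sliced Wasserstein distance between two diagrams, for the given projection angles."""
    D1 = np.asarray(D1, dtype=float).reshape(-1, 2)
    D2 = np.asarray(D2, dtype=float).reshape(-1, 2)
    # Orthogonal projections of the points onto the diagonal
    proj1 = np.repeat(D1.mean(axis=1, keepdims=True), 2, axis=1)
    proj2 = np.repeat(D2.mean(axis=1, keepdims=True), 2, axis=1)
    # Augment each diagram so that both have the same number of points
    A = np.vstack([D1, proj2])
    B = np.vstack([D2, proj1])
    directions = np.stack([np.cos(thetas), np.sin(thetas)])  # shape (2, n_directions)
    # Project onto each line and sort the projected values
    PA = np.sort(A @ directions, axis=0)
    PB = np.sort(B @ directions, axis=0)
    return float(np.abs(PA - PB).sum(axis=0).mean())

# Toy example with 10 evenly spaced angles
toy_thetas = np.linspace(-np.pi / 2, np.pi / 2, 10, endpoint=False)
toy_D1 = [[0.0, 1.0], [0.2, 0.5]]
toy_D2 = [[0.0, 1.2]]
print(sliced_wasserstein_np(toy_D1, toy_D2, toy_thetas))
```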

Now, just as before, split the data into train/test, but this time, collect the vertices, simplex trees and labels (it is useless to compute persistence diagrams since they will be recomputed after each gradient descent iteration and update of $\alpha$).

Initialize the alpha values, as well as the angles used for computing the sliced Wasserstein distances (and make sure these angles are not optimized during training). Define also the iteration number, batch size, learning rate and optimizer.

Run gradient descent! For this, you can use the `LowerStarSimplexTreeLayer` class from `Gudhi`, which computes persistence diagrams from simplex trees in a differentiable way with `TensorFlow` operations. Make sure to save the loss value at each iteration.

Visualize the losses. Are they decreasing? What are the final alpha values?

We can now use these values to train a model again with this new filtration, and check whether the accuracy is better now!

Yay! That's definitely better!

If you managed to go that far, congrats, you are basically a TDA expert now ;-) Do not hesitate to reuse these pieces of code for your own research, and let us know if you have any comment/question/suggestion!