How to implement t-SNE in a model?

I split my data into train/test sets. When I use PCA, it is straightforward:

from sklearn.decomposition import PCA
pca = PCA()
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)

From here I can use X_train_pca and X_test_pca in the next step, and so on.

But when I use t-SNE:

from sklearn.manifold import TSNE
X_train_tsne = TSNE(n_components=2, random_state=0).fit_transform(X_train)

I can't find a way to transform the test set as well, so I can't use the t-SNE output in the next step (e.g. an SVM).

Any help?

2 answers

  • answered 2018-10-17 08:16 Gabriel M

    I believe that what you're trying to do is impossible.

    t-SNE computes an embedding that tries to preserve the pairwise similarities (neighborhoods) between the samples you fit; it does not learn a mapping that can be applied to other points. So you cannot use a fitted t-SNE model to project new data without refitting.
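    A quick way to see this in scikit-learn: `TSNE` offers `fit` and `fit_transform` but no `transform` method, and the learned coordinates exist only (in the `embedding_` attribute) for the rows that were fitted. The sketch below uses the Iris dataset purely as stand-in data, an assumption rather than the asker's data.

```python
from sklearn.datasets import load_iris
from sklearn.manifold import TSNE

# Iris is only stand-in data for this illustration.
X, _ = load_iris(return_X_y=True)

tsne = TSNE(n_components=2, random_state=0)
X_emb = tsne.fit_transform(X)

# The embedding exists only for the rows that were fitted.
print(X_emb.shape)             # (150, 2)
print(tsne.embedding_.shape)   # (150, 2)

# Unlike PCA, there is no method for mapping unseen rows.
print(hasattr(tsne, "transform"))  # False
```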

    On the other hand, I would not feed the output of t-SNE to a classifier, mainly because t-SNE is highly non-linear and somewhat random: you can get very different outputs from different runs and with different values of perplexity.
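    A small sketch of that run-to-run variation, again assuming Iris as stand-in data. `init="random"` is set explicitly here so the seed controls the starting layout; the perplexity values are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.manifold import TSNE

X, _ = load_iris(return_X_y=True)

# Same data, only the random seed differs.
emb_a = TSNE(n_components=2, init="random", random_state=0).fit_transform(X)
emb_b = TSNE(n_components=2, init="random", random_state=1).fit_transform(X)

# Same data and seed, only the perplexity differs (default is 30).
emb_c = TSNE(n_components=2, init="random", perplexity=5,
             random_state=0).fit_transform(X)

# The layouts differ, so a classifier trained on one set of coordinates
# has no reason to work on points embedded by another run.
print(np.allclose(emb_a, emb_b))  # False
print(np.allclose(emb_a, emb_c))  # False
```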

    See this explanation of t-SNE.

    However, if you really wish to use t-SNE for this purpose, you'll have to fit your t-SNE model on the whole data set and, once it is fitted, split the embedding into your train and test parts.

    import numpy as np
    from sklearn.manifold import TSNE

    # Fit t-SNE on train and test rows together, then split the embedding
    size_train = X_train.shape[0]
    X = np.vstack((X_train, X_test))
    X_tsne = TSNE(n_components=2, random_state=0).fit_transform(X)
    X_train_tsne = X_tsne[:size_train, :]
    X_test_tsne = X_tsne[size_train:, :]
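    A self-contained sketch of this approach carried through to the SVM step from the question. Iris, `train_test_split` and a default `SVC` are assumptions for illustration. Note that the test rows take part in the embedding (only their features, not their labels), and any genuinely new sample would require refitting t-SNE on everything.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Stand-in data; replace with your own X_train / X_test split.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fit t-SNE once on train and test rows together (no labels are used).
size_train = X_train.shape[0]
X_all = np.vstack((X_train, X_test))
X_tsne = TSNE(n_components=2, random_state=0).fit_transform(X_all)

# Split the embedding back in the same row order it was stacked.
X_train_tsne = X_tsne[:size_train]
X_test_tsne = X_tsne[size_train:]

# Train and evaluate an SVM on the 2-D coordinates.
clf = SVC().fit(X_train_tsne, y_train)
accuracy = clf.score(X_test_tsne, y_test)
print(X_train_tsne.shape, X_test_tsne.shape, accuracy)
```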

  • answered 2018-10-17 08:25 Danylo Baibak

    According to the documentation, TSNE is a tool to visualize high-dimensional data. A bit further down in the description we find: "It is highly recommended to use another dimensionality reduction method (e.g. PCA for dense data or TruncatedSVD for sparse data) to reduce the number of dimensions to a reasonable amount (e.g. 50) if the number of features is very high."
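    A minimal sketch of that recommendation for the visualization use case: PCA down to 50 dimensions, then t-SNE down to 2. The digits dataset (64 features) and the 500-row subset are assumptions chosen only to keep the example quick; for sparse input, TruncatedSVD would take PCA's place.

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Digits (64 features) as stand-in data; first 500 rows keep it quick.
X, y = load_digits(return_X_y=True)
X, y = X[:500], y[:500]

# Reduce to 50 dimensions with PCA before running t-SNE.
X_pca50 = PCA(n_components=50, random_state=0).fit_transform(X)

# 2-D t-SNE embedding, intended for plotting only (e.g. colored by y).
X_2d = TSNE(n_components=2, random_state=0).fit_transform(X_pca50)
print(X_2d.shape)  # (500, 2)
```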

    My suggestion would be to use TSNE for visualisation and PCA or TruncatedSVD as part of the machine learning model.
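    One way to put PCA "inside" the model is a scikit-learn pipeline, sketched below under assumptions: the digits dataset as stand-in data, 30 PCA components as an arbitrary choice, and a default `SVC`. PCA is fitted on the training rows only, and at prediction time the pipeline calls PCA's `transform` on new rows, which is exactly the step t-SNE cannot provide.

```python
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# fit() runs PCA.fit_transform on X_train, then trains the SVC on the result.
model = make_pipeline(PCA(n_components=30, random_state=0), SVC())
model.fit(X_train, y_train)

# score() applies the already-fitted PCA.transform to X_test before predicting.
score = model.score(X_test, y_test)
print(score)
```

    TSNE can still be run separately on the same data purely to look at it.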