This notebook defines a simple scikit-learn feature selector (built on `SelectorMixin`) which removes features that are exact duplicates of another feature.


In [117]:
import numpy as np
import pandas as pd

from sklearn.base import BaseEstimator
from sklearn.datasets import make_regression, make_classification
from sklearn.decomposition import PCA
from sklearn.feature_selection.base import SelectorMixin
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.pipeline import Pipeline, make_pipeline, FeatureUnion
from sklearn.utils.validation import check_is_fitted

In [180]:
class RepeatedRemover(BaseEstimator, SelectorMixin):
    """
    Feature selector that drops columns which exactly duplicate an earlier column.

    For every group of identical columns the first (left-most) one is kept
    and the later copies are removed.  Equality is tested with
    ``np.array_equal``, so two columns must match element for element.

    Attributes
    ----------
    indices_ : ndarray of bool, shape (n_features,)
        Set by ``fit``.  True marks a column that is kept.
    """
    def fit(self, X, y=None):
        """
        Work out which columns of ``X`` are duplicates of an earlier column.

        Parameters
        ----------
        X : 2-D array-like, shape (n_samples, n_features)
            Data to scan.  Converted with ``np.asarray`` so that objects which
            do not support positional ``X[:, i]`` indexing (e.g. a pandas
            DataFrame) are accepted as well.
        y : ignored
            Present only so the selector can be used inside a pipeline.

        Returns
        -------
        self
        """
        X = np.asarray(X)
        n_features = X.shape[1]

        # Use the builtin ``bool``: the ``np.bool`` alias was removed from NumPy.
        keep = np.ones(n_features, dtype=bool)

        for i in range(n_features):
            # A column already flagged as a duplicate never acts as a
            # reference; its twin further left has covered it.
            if not keep[i]:
                continue
            for j in range(i + 1, n_features):
                if keep[j] and np.array_equal(X[:, i], X[:, j]):
                    keep[j] = False

        self.indices_ = keep
        return self

    def _get_support_mask(self):
        """Return the boolean keep-mask computed by ``fit``."""
        check_is_fitted(self, 'indices_')
        return self.indices_

In [181]:
class ColumnSelector(BaseEstimator, SelectorMixin):
    """
    Feature selector that keeps (or excludes) DataFrame columns by name.

    Parameters
    ----------
    columns : list, single column label, or None
        Column labels to select.  A value that is not a list is treated as a
        single label.  ``None`` (the default) selects no column.
    exclude : bool, default False
        If True the mask is inverted: the named columns are dropped and all
        other columns are kept.

    Attributes
    ----------
    indices_ : ndarray of bool, shape (n_features,)
        Set by ``fit``.  True marks a column that is kept.
    """
    def __init__(self, columns=None, exclude=False):
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance created without arguments.
        self.columns = columns
        self.exclude = exclude

    def fit(self, X, y=None):
        """
        Build the boolean column mask from the column labels of ``X``.

        Parameters
        ----------
        X : object with a ``columns`` attribute supporting ``isin``
            (e.g. a pandas DataFrame).  Labels are read from ``X.columns``.
        y : ignored
            Present only so the selector can be used inside a pipeline.

        Returns
        -------
        self

        Raises
        ------
        TypeError
            If ``X`` has no ``columns`` attribute to match names against.
        """
        if not hasattr(X, 'columns'):
            raise TypeError(
                "ColumnSelector selects by column name and needs an input "
                "with a `columns` attribute (e.g. a pandas DataFrame)."
            )

        # Normalise into a local list; the constructor parameter itself is
        # left untouched so the estimator's parameters survive fitting.
        if self.columns is None:
            wanted = []
        elif isinstance(self.columns, list):
            wanted = self.columns
        else:
            wanted = [self.columns]

        # Read the labels from the data passed in, not from a notebook global.
        mask = np.asarray(X.columns.isin(wanted), dtype=bool)

        self.indices_ = ~mask if self.exclude else mask
        return self

    def _get_support_mask(self):
        """Return the boolean keep-mask computed by ``fit``."""
        check_is_fitted(self, 'indices_')
        return self.indices_

In [182]:
# Random 5x3 matrix placed next to a copy of itself: columns 3-5 repeat 0-2.
X = np.random.normal(size=(5, 3))
X_repeat = np.concatenate([X, X], axis=1)

# The shape shows how many columns remain after RepeatedRemover.
rr = make_pipeline(RepeatedRemover())
rr.fit_transform(X_repeat).shape


Out[182]:
(5, 3)

In [183]:
# Wrap the random matrix in a DataFrame with string labels col0, col1, ...
X_df = pd.DataFrame(
    X,
    columns=["col{}".format(idx) for idx in range(X.shape[1])],
)

In [184]:
# Keep only the column named 'col1'.
cs = ColumnSelector(columns=['col1'])
cs.fit_transform(X_df).shape


Out[184]:
(5, 1)

In [185]:
# Invert the selection: keep every column except 'col1'.
cs = ColumnSelector(columns=['col1'], exclude=True)
cs.fit_transform(X_df).shape


Out[185]:
(5, 2)

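We can force a column to be selected regardless of what another selector chooses by combining both selectors in a `FeatureUnion`. Because the union simply concatenates the outputs of its selectors, a column chosen by both appears twice; following the union with `RepeatedRemover` ensures that the resulting dataset does not contain duplicate columns.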


In [186]:
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest, chi2

In [187]:
# Load the iris data: feature matrix X and class labels y.
iris = load_iris()
X = iris.data
y = iris.target

In [188]:
# DataFrame view of the iris features, labelled with the dataset's feature names.
X_df = pd.DataFrame(X, columns=iris.feature_names)

In [189]:
# Score every feature with chi2 and keep the two best.
sk = SelectKBest(score_func=chi2, k=2)
sk.fit(X_df, y)


Out[189]:
SelectKBest(k=2, score_func=<function chi2 at 0x0000026E16329C80>)

In [190]:
# Boolean mask over the input features; True marks a feature kept by `sk`.
sk.get_support()


Out[190]:
array([False, False,  True,  True], dtype=bool)

In [191]:
# Translate the support mask into the names of the selected columns.
X_df.columns[sk.get_support()]


Out[191]:
Index(['petal length (cm)', 'petal width (cm)'], dtype='object')

In [192]:
# Run SelectKBest side by side with a ColumnSelector that always keeps
# `sepal width (cm)` and `petal length (cm)`; the union concatenates the
# outputs of both selectors.
feat_sel = FeatureUnion([
    ('sk', SelectKBest(chi2, k=2)),
    ('force_sel', ColumnSelector(['sepal width (cm)', 'petal length (cm)'])),
])
feat_sel.fit(X_df, y)

X_check = feat_sel.transform(X_df)

In [193]:
# The union output contains `petal length (cm)` twice: once from SelectKBest
# (first column) and once from ColumnSelector (last column), so those two
# columns are identical.
X_check


Out[193]:
array([[ 1.4,  0.2,  3.5,  1.4],
       [ 1.4,  0.2,  3. ,  1.4],
       [ 1.3,  0.2,  3.2,  1.3],
       [ 1.5,  0.2,  3.1,  1.5],
       [ 1.4,  0.2,  3.6,  1.4],
       [ 1.7,  0.4,  3.9,  1.7],
       [ 1.4,  0.3,  3.4,  1.4],
       [ 1.5,  0.2,  3.4,  1.5],
       [ 1.4,  0.2,  2.9,  1.4],
       [ 1.5,  0.1,  3.1,  1.5],
       [ 1.5,  0.2,  3.7,  1.5],
       [ 1.6,  0.2,  3.4,  1.6],
       [ 1.4,  0.1,  3. ,  1.4],
       [ 1.1,  0.1,  3. ,  1.1],
       [ 1.2,  0.2,  4. ,  1.2],
       [ 1.5,  0.4,  4.4,  1.5],
       [ 1.3,  0.4,  3.9,  1.3],
       [ 1.4,  0.3,  3.5,  1.4],
       [ 1.7,  0.3,  3.8,  1.7],
       [ 1.5,  0.3,  3.8,  1.5],
       [ 1.7,  0.2,  3.4,  1.7],
       [ 1.5,  0.4,  3.7,  1.5],
       [ 1. ,  0.2,  3.6,  1. ],
       [ 1.7,  0.5,  3.3,  1.7],
       [ 1.9,  0.2,  3.4,  1.9],
       [ 1.6,  0.2,  3. ,  1.6],
       [ 1.6,  0.4,  3.4,  1.6],
       [ 1.5,  0.2,  3.5,  1.5],
       [ 1.4,  0.2,  3.4,  1.4],
       [ 1.6,  0.2,  3.2,  1.6],
       [ 1.6,  0.2,  3.1,  1.6],
       [ 1.5,  0.4,  3.4,  1.5],
       [ 1.5,  0.1,  4.1,  1.5],
       [ 1.4,  0.2,  4.2,  1.4],
       [ 1.5,  0.1,  3.1,  1.5],
       [ 1.2,  0.2,  3.2,  1.2],
       [ 1.3,  0.2,  3.5,  1.3],
       [ 1.5,  0.1,  3.1,  1.5],
       [ 1.3,  0.2,  3. ,  1.3],
       [ 1.5,  0.2,  3.4,  1.5],
       [ 1.3,  0.3,  3.5,  1.3],
       [ 1.3,  0.3,  2.3,  1.3],
       [ 1.3,  0.2,  3.2,  1.3],
       [ 1.6,  0.6,  3.5,  1.6],
       [ 1.9,  0.4,  3.8,  1.9],
       [ 1.4,  0.3,  3. ,  1.4],
       [ 1.6,  0.2,  3.8,  1.6],
       [ 1.4,  0.2,  3.2,  1.4],
       [ 1.5,  0.2,  3.7,  1.5],
       [ 1.4,  0.2,  3.3,  1.4],
       [ 4.7,  1.4,  3.2,  4.7],
       [ 4.5,  1.5,  3.2,  4.5],
       [ 4.9,  1.5,  3.1,  4.9],
       [ 4. ,  1.3,  2.3,  4. ],
       [ 4.6,  1.5,  2.8,  4.6],
       [ 4.5,  1.3,  2.8,  4.5],
       [ 4.7,  1.6,  3.3,  4.7],
       [ 3.3,  1. ,  2.4,  3.3],
       [ 4.6,  1.3,  2.9,  4.6],
       [ 3.9,  1.4,  2.7,  3.9],
       [ 3.5,  1. ,  2. ,  3.5],
       [ 4.2,  1.5,  3. ,  4.2],
       [ 4. ,  1. ,  2.2,  4. ],
       [ 4.7,  1.4,  2.9,  4.7],
       [ 3.6,  1.3,  2.9,  3.6],
       [ 4.4,  1.4,  3.1,  4.4],
       [ 4.5,  1.5,  3. ,  4.5],
       [ 4.1,  1. ,  2.7,  4.1],
       [ 4.5,  1.5,  2.2,  4.5],
       [ 3.9,  1.1,  2.5,  3.9],
       [ 4.8,  1.8,  3.2,  4.8],
       [ 4. ,  1.3,  2.8,  4. ],
       [ 4.9,  1.5,  2.5,  4.9],
       [ 4.7,  1.2,  2.8,  4.7],
       [ 4.3,  1.3,  2.9,  4.3],
       [ 4.4,  1.4,  3. ,  4.4],
       [ 4.8,  1.4,  2.8,  4.8],
       [ 5. ,  1.7,  3. ,  5. ],
       [ 4.5,  1.5,  2.9,  4.5],
       [ 3.5,  1. ,  2.6,  3.5],
       [ 3.8,  1.1,  2.4,  3.8],
       [ 3.7,  1. ,  2.4,  3.7],
       [ 3.9,  1.2,  2.7,  3.9],
       [ 5.1,  1.6,  2.7,  5.1],
       [ 4.5,  1.5,  3. ,  4.5],
       [ 4.5,  1.6,  3.4,  4.5],
       [ 4.7,  1.5,  3.1,  4.7],
       [ 4.4,  1.3,  2.3,  4.4],
       [ 4.1,  1.3,  3. ,  4.1],
       [ 4. ,  1.3,  2.5,  4. ],
       [ 4.4,  1.2,  2.6,  4.4],
       [ 4.6,  1.4,  3. ,  4.6],
       [ 4. ,  1.2,  2.6,  4. ],
       [ 3.3,  1. ,  2.3,  3.3],
       [ 4.2,  1.3,  2.7,  4.2],
       [ 4.2,  1.2,  3. ,  4.2],
       [ 4.2,  1.3,  2.9,  4.2],
       [ 4.3,  1.3,  2.9,  4.3],
       [ 3. ,  1.1,  2.5,  3. ],
       [ 4.1,  1.3,  2.8,  4.1],
       [ 6. ,  2.5,  3.3,  6. ],
       [ 5.1,  1.9,  2.7,  5.1],
       [ 5.9,  2.1,  3. ,  5.9],
       [ 5.6,  1.8,  2.9,  5.6],
       [ 5.8,  2.2,  3. ,  5.8],
       [ 6.6,  2.1,  3. ,  6.6],
       [ 4.5,  1.7,  2.5,  4.5],
       [ 6.3,  1.8,  2.9,  6.3],
       [ 5.8,  1.8,  2.5,  5.8],
       [ 6.1,  2.5,  3.6,  6.1],
       [ 5.1,  2. ,  3.2,  5.1],
       [ 5.3,  1.9,  2.7,  5.3],
       [ 5.5,  2.1,  3. ,  5.5],
       [ 5. ,  2. ,  2.5,  5. ],
       [ 5.1,  2.4,  2.8,  5.1],
       [ 5.3,  2.3,  3.2,  5.3],
       [ 5.5,  1.8,  3. ,  5.5],
       [ 6.7,  2.2,  3.8,  6.7],
       [ 6.9,  2.3,  2.6,  6.9],
       [ 5. ,  1.5,  2.2,  5. ],
       [ 5.7,  2.3,  3.2,  5.7],
       [ 4.9,  2. ,  2.8,  4.9],
       [ 6.7,  2. ,  2.8,  6.7],
       [ 4.9,  1.8,  2.7,  4.9],
       [ 5.7,  2.1,  3.3,  5.7],
       [ 6. ,  1.8,  3.2,  6. ],
       [ 4.8,  1.8,  2.8,  4.8],
       [ 4.9,  1.8,  3. ,  4.9],
       [ 5.6,  2.1,  2.8,  5.6],
       [ 5.8,  1.6,  3. ,  5.8],
       [ 6.1,  1.9,  2.8,  6.1],
       [ 6.4,  2. ,  3.8,  6.4],
       [ 5.6,  2.2,  2.8,  5.6],
       [ 5.1,  1.5,  2.8,  5.1],
       [ 5.6,  1.4,  2.6,  5.6],
       [ 6.1,  2.3,  3. ,  6.1],
       [ 5.6,  2.4,  3.4,  5.6],
       [ 5.5,  1.8,  3.1,  5.5],
       [ 4.8,  1.8,  3. ,  4.8],
       [ 5.4,  2.1,  3.1,  5.4],
       [ 5.6,  2.4,  3.1,  5.6],
       [ 5.1,  2.3,  3.1,  5.1],
       [ 5.1,  1.9,  2.7,  5.1],
       [ 5.9,  2.3,  3.2,  5.9],
       [ 5.7,  2.5,  3.3,  5.7],
       [ 5.2,  2.3,  3. ,  5.2],
       [ 5. ,  1.9,  2.5,  5. ],
       [ 5.2,  2. ,  3. ,  5.2],
       [ 5.4,  2.3,  3.4,  5.4],
       [ 5.1,  1.8,  3. ,  5.1]])

In [194]:
# Same union as above, followed by RepeatedRemover so that the column
# selected by both branches appears only once in the result.
feat_sel = Pipeline([
    ('feature_sel', FeatureUnion([
        ('sk', SelectKBest(chi2, k=2)),
        ('force_sel', ColumnSelector(['sepal width (cm)', 'petal length (cm)'])),
    ])),
    ('remover', RepeatedRemover()),
])
feat_sel.fit(X_df, y)
X_check = feat_sel.transform(X_df)
X_check.shape  # expect 3 columns: the repeated petal-length column is dropped


Out[194]:
(150, 3)