In [10]:
import sklearn

In [11]:
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

import pandas as pd
import numpy as np

In [12]:
X, y = make_classification()
pdf = pd.DataFrame(X)
pdf.columns = ['c{}'.format(x) for x in range(X.shape[1])]

In [13]:
X.shape


Out[13]:
(100, 20)

In [14]:
X1 = pdf[['c{}'.format(x) for x in range(10, 20)]]
X2 = pdf[['c{}'.format(x) for x in range(10)]]

In [19]:
class GraftingClassifier(SGDClassifier):
    """
    Currently only supports logistic regression.    
    """
    def __init__(self, loss="log", penalty='l2', alpha=0.0001, l1_ratio=0.15,
                 fit_intercept=True, max_iter=None, tol=None, shuffle=True,
                 verbose=0, epsilon=0.1, n_jobs=1,
                 random_state=None, learning_rate="optimal", eta0=0.0,
                 power_t=0.5, class_weight=None, warm_start=False,
                 average=False, n_iter=None, reg_penalty=None):
        super(GraftingClassifier, self).__init__(
            loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio,
            fit_intercept=fit_intercept, max_iter=max_iter, tol=tol,
            shuffle=shuffle, verbose=verbose, epsilon=epsilon, n_jobs=n_jobs,
            random_state=random_state, learning_rate=learning_rate, eta0=eta0,
            power_t=power_t, class_weight=class_weight, warm_start=warm_start,
            average=average, n_iter=n_iter)
        self.filter_cols = []
        self.base_shape = None
        self.reg_penalty = reg_penalty if reg_penalty is not None else l1_ratio
    
    def _fit_columns(self, X, return_x=True):
        """
        Method filter through "unselected" columns. The goal of this 
        method is to filter any uninformative columns.
        
        This will be selected based on index only?
        
        If return_x is false, it will only return the boolean mask.
        """
        import pandas
        bool_mask = np.ones((X.shape[1],), dtype=np.bool)
        if len(self.filter_cols) == 0:
            if return_x:
                return X
            else:
                return bool_mask
        # otherwise...
        bool_mask[self.filter_cols] = False
        if not return_x:
            return bool_mask
        if type(X) is pandas.core.frame.DataFrame:
            return X[X.columns[bool_mask]]
        else:
            return X[:, bool_mask]
    
    def _reg_penalty(self, tot_new_feats, base_size):
        remove_cols = np.argwhere(np.abs(self.coef_.flatten()[-tot_new_feats:]) < self.reg_penalty)
        add_cols = np.argwhere(np.abs(self.coef_.flatten()[-tot_new_feats:]) >= self.reg_penalty)
        base_coef = self.coef_.flatten()[:-tot_new_feats].tolist()
        # adding new coefs
        base_coef = base_coef + self.coef_.flatten()[-tot_new_feats:][add_cols].flatten().tolist()
        self.coef_ = np.array(base_coef).reshape(1, -1)
        remove_cols_offset = [base_size + x for x in remove_cols]
        self.filter_cols.append(remove_cols_offset)
        
        
    
    def _partial_grafting_fit(self, X_, y):
        """
        Partial fit grafting method to expand the coefficient listing
        to taking into account new coefficients
        """
        # require to know the base shape to determine/
        # check for irrelevant columns in the future.
        self.base_shape = self.coef_.shape[0]
        
        X = self._fit_columns(X_)
        n_samples, n_features = X.shape
        coef_list = np.zeros(n_features, dtype=np.float64, order="C")
        coef_list[:self.coef_.flatten().shape[0]] = self.coef_.flatten()
        self.coef_ = coef_list.reshape(1,-1)
        
    def partial_fit(self, X, y, sample_weight=None):
        base_size = len(self.filter_cols) + self.coef_.flatten().shape[0]
        tot_new_feats = X.shape[1] - base_size
        self._partial_grafting_fit(X, y)
        super(GraftingClassifier, self).partial_fit(X, y, sample_weight=None)  
        
        # update parameters based on weight of regularizer penalty
        self._reg_penalty(tot_new_feats, base_size)
        return self
    
    def predict(self, X):
        X = self._fit_columns(X)
        return super(GraftingClassifier, self).predict(X)
    
    def predict_proba(self, X):
        X = self._fit_columns(X)
        return super(GraftingClassifier, self).predict_proba(X)

In [20]:
model = GraftingClassifier(max_iter=1000)
model.fit(X1, y)


Out[20]:
GraftingClassifier(alpha=0.0001, average=False, class_weight=None,
          epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15,
          learning_rate='optimal', loss='log', max_iter=1000, n_iter=None,
          n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
          reg_penalty=0.15, shuffle=True, tol=None, verbose=0,
          warm_start=False)

In [21]:
model.coef_.shape


Out[21]:
(1, 10)

In [22]:
model.partial_fit(pdf, y)


Out[22]:
GraftingClassifier(alpha=0.0001, average=False, class_weight=None,
          epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15,
          learning_rate='optimal', loss='log', max_iter=1000, n_iter=None,
          n_jobs=1, penalty='l2', power_t=0.5, random_state=None,
          reg_penalty=0.15, shuffle=True, tol=None, verbose=0,
          warm_start=False)

In [23]:
model.coef_.shape


Out[23]:
(1, 9)

In [24]:
model.predict_proba(pdf)


Out[24]:
array([[ 0.00619828,  0.99380172],
       [ 0.92969369,  0.07030631],
       [ 0.68502224,  0.31497776],
       [ 0.11524396,  0.88475604],
       [ 0.03516324,  0.96483676],
       [ 0.35131229,  0.64868771],
       [ 0.72360682,  0.27639318],
       [ 0.7947508 ,  0.2052492 ],
       [ 0.18609653,  0.81390347],
       [ 0.17042461,  0.82957539],
       [ 0.94364955,  0.05635045],
       [ 0.72174869,  0.27825131],
       [ 0.98252103,  0.01747897],
       [ 0.60455504,  0.39544496],
       [ 0.77980623,  0.22019377],
       [ 0.69483036,  0.30516964],
       [ 0.5413798 ,  0.4586202 ],
       [ 0.3501247 ,  0.6498753 ],
       [ 0.36370671,  0.63629329],
       [ 0.5962762 ,  0.4037238 ],
       [ 0.419771  ,  0.580229  ],
       [ 0.89638916,  0.10361084],
       [ 0.75839285,  0.24160715],
       [ 0.82523315,  0.17476685],
       [ 0.44215415,  0.55784585],
       [ 0.05100563,  0.94899437],
       [ 0.20269502,  0.79730498],
       [ 0.62765282,  0.37234718],
       [ 0.72584456,  0.27415544],
       [ 0.84605548,  0.15394452],
       [ 0.75882693,  0.24117307],
       [ 0.14363298,  0.85636702],
       [ 0.15704122,  0.84295878],
       [ 0.91739279,  0.08260721],
       [ 0.94488603,  0.05511397],
       [ 0.7719733 ,  0.2280267 ],
       [ 0.96367466,  0.03632534],
       [ 0.68333343,  0.31666657],
       [ 0.93804905,  0.06195095],
       [ 0.36384688,  0.63615312],
       [ 0.7702775 ,  0.2297225 ],
       [ 0.97813235,  0.02186765],
       [ 0.62574577,  0.37425423],
       [ 0.55332863,  0.44667137],
       [ 0.91327073,  0.08672927],
       [ 0.78662022,  0.21337978],
       [ 0.3903689 ,  0.6096311 ],
       [ 0.28463065,  0.71536935],
       [ 0.06154552,  0.93845448],
       [ 0.58562644,  0.41437356],
       [ 0.4934146 ,  0.5065854 ],
       [ 0.85122699,  0.14877301],
       [ 0.74105576,  0.25894424],
       [ 0.83707032,  0.16292968],
       [ 0.69330963,  0.30669037],
       [ 0.04736971,  0.95263029],
       [ 0.07786129,  0.92213871],
       [ 0.57960624,  0.42039376],
       [ 0.84517802,  0.15482198],
       [ 0.91530769,  0.08469231],
       [ 0.23628402,  0.76371598],
       [ 0.22313465,  0.77686535],
       [ 0.11976533,  0.88023467],
       [ 0.66270921,  0.33729079],
       [ 0.94216506,  0.05783494],
       [ 0.97302726,  0.02697274],
       [ 0.96673777,  0.03326223],
       [ 0.99727629,  0.00272371],
       [ 0.62624011,  0.37375989],
       [ 0.67197917,  0.32802083],
       [ 0.37715775,  0.62284225],
       [ 0.90704241,  0.09295759],
       [ 0.13518674,  0.86481326],
       [ 0.35017109,  0.64982891],
       [ 0.80306754,  0.19693246],
       [ 0.18100188,  0.81899812],
       [ 0.29705418,  0.70294582],
       [ 0.18055756,  0.81944244],
       [ 0.21488882,  0.78511118],
       [ 0.27758515,  0.72241485],
       [ 0.15424098,  0.84575902],
       [ 0.34117787,  0.65882213],
       [ 0.47389542,  0.52610458],
       [ 0.08184066,  0.91815934],
       [ 0.24141037,  0.75858963],
       [ 0.37854512,  0.62145488],
       [ 0.04068648,  0.95931352],
       [ 0.80504507,  0.19495493],
       [ 0.97835053,  0.02164947],
       [ 0.28286714,  0.71713286],
       [ 0.86638237,  0.13361763],
       [ 0.04401544,  0.95598456],
       [ 0.78790736,  0.21209264],
       [ 0.37256829,  0.62743171],
       [ 0.01175228,  0.98824772],
       [ 0.21025498,  0.78974502],
       [ 0.60456134,  0.39543866],
       [ 0.09650353,  0.90349647],
       [ 0.74493591,  0.25506409],
       [ 0.68760358,  0.31239642]])