Semi-supervised learning (still under development!)

The first set of methods covers the principles from the following summary: http://sci2s.ugr.es/ssl. A batch-generative method, consisting of k-means and logistic regression, is implemented as a naive approach. This experiment is compared against a baseline which consists of logistic regression only.
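The batch-generative step itself lives in the `semisupervised` package and its internals are not shown in this notebook. As a rough illustration of the idea only (not the actual implementation), a minimal sketch might look like the following, assuming a `used_label` column in which unknown labels are NaN, as in the dataset created below:

from pyspark.sql import Window, functions as F
from pyspark.ml import clustering, classification, feature

def batch_generative_sketch(df, feature_cols, k=2):
    """Sketch: cluster all points, pseudo-label each cluster by majority
    vote over its few known labels, then fit a supervised classifier."""
    assembler = feature.VectorAssembler(inputCols=feature_cols, outputCol='features')
    assembled = assembler.transform(df)

    # Step 1: cluster labeled and unlabeled points together.
    kmeans = clustering.KMeans(featuresCol='features', predictionCol='cluster', k=k)
    clustered = kmeans.fit(assembled).transform(assembled)

    # Step 2: per cluster, take the majority vote among the known labels.
    w = Window.partitionBy('cluster').orderBy(F.desc('count'))
    mapping = (clustered
               .filter(~F.isnan('used_label'))
               .groupBy('cluster', 'used_label').count()
               .withColumn('rank', F.row_number().over(w))
               .filter(F.col('rank') == 1)
               .select('cluster', F.col('used_label').alias('pseudo_label')))

    # Step 3: train on the pseudo-labels and predict for every point.
    pseudo = clustered.join(mapping, on='cluster')
    lr = classification.LogisticRegression(featuresCol='features', labelCol='pseudo_label')
    return lr.fit(pseudo).transform(pseudo)

The key design choice is that clustering uses all points, so the unlabeled data shapes the decision boundary; the few known labels are only needed to name the clusters.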


In [2]:
%run -i initilization.py

In [3]:
from pyspark.sql import functions as F
from pyspark.ml import clustering
from pyspark.ml import feature
from pyspark.sql import DataFrame
from pyspark.sql import Window
from pyspark.ml import Pipeline
from pyspark.ml import classification
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator   
from shared import Plot2DGraphs, create_dummy_data
from semisupervised import batch_generative_model
Define some parameters for generating a dataset

In [4]:
# Means, standard deviations and sample counts for two Gaussian clusters
mean_1 = [3.0, 3.0]
std_1 = [2.0, 2.0]
mean_2 = [-3.0, -3.0]
std_2 = [1.0, 1.0]
n_1 = 300
n_2 = 300
n = [n_1, n_2]
mean = [mean_1, mean_2]
std = [std_1, std_2]

An initial approach to semi-supervised learning

The following cells create the data and make a first attempt at a generative model for semi-supervised learning.


In [5]:
def compute_error_rate(data_frame, truth_label='real_label', found_label='prediction'):
    """
    Print a confusion summary and the error rate, i.e. the fraction of
    rows where the predicted label differs from the true label.
    """
    df_stats = (data_frame
                .groupBy([truth_label, found_label])
                .agg(F.count(found_label).alias('Prediction Count'))
               )
    n = (df_stats
         .select(F.sum(F.col('Prediction Count')).alias('n'))
         .collect()[0]['n']
        )
    wrong_guess = (df_stats
                   .filter(F.col(truth_label) != F.col(found_label))
                   .select(F.sum(F.col('Prediction Count')).alias('errors'))
                   .collect()[0]['errors']
                  ) or 0  # sum() returns None when nothing was misclassified
    df_stats.show()
    print(n)
    print(wrong_guess)
    print('Error-rate: {}'.format(wrong_guess / n))
Create the labeled dataset, keeping 1% of the labels; the rest are set to NaN.

In [6]:
tester = create_dummy_data.create_labeled_data_with_clusters(n, mean, std, 0.01)
df_tester = spark.createDataFrame(tester)
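`create_labeled_data_with_clusters` comes from the project's `shared.create_dummy_data` module and is not shown here. Conceptually it can be thought of as the following sketch (the real helper's signature and column names may differ):

import numpy as np
import pandas as pd

def make_clustered_data(n, mean, std, keep_fraction, seed=None):
    """Sketch: draw one 2-D Gaussian blob per class and hide most labels."""
    rng = np.random.RandomState(seed)
    frames = []
    for label, (n_i, mu, sigma) in enumerate(zip(n, mean, std)):
        points = rng.normal(loc=mu, scale=sigma, size=(n_i, 2))
        frames.append(pd.DataFrame(
            {'x': points[:, 0], 'y': points[:, 1], 'real_label': float(label)}))
    df = pd.concat(frames, ignore_index=True)
    # Keep roughly keep_fraction of the labels; mask the rest with NaN.
    keep = rng.rand(len(df)) < keep_fraction
    df['used_label'] = np.where(keep, df['real_label'], np.nan)
    return df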

Plot the dataset, showing both the true labels and the labels that remain available


In [7]:
Plot2DGraphs.plot_known_and_unknown_data(tester)


A first attempt at classifying the data, using logistic regression

In [8]:
# NaN comparisons are fragile; use isnan to split labeled from unlabeled rows.
df_train = df_tester.filter(~F.isnan(F.col('used_label')))
df_test = df_tester.filter(F.isnan(F.col('used_label')))

vec_assembler = feature.VectorAssembler(
    inputCols=['x','y'],
    outputCol='features')

lg = classification.LogisticRegression(
    featuresCol=vec_assembler.getOutputCol(),
    labelCol='used_label')

pipeline = Pipeline(stages=[vec_assembler, lg])
# Build the cross-validation parameter grid
param_grid = (ParamGridBuilder()
              .addGrid(lg.regParam, [0.1, 0.01])
              .build()
             )

evaluator = BinaryClassificationEvaluator(
    rawPredictionCol=lg.getRawPredictionCol(),
    labelCol=lg.getLabelCol())
    
cross_validator = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=3)

cross_validator_model = cross_validator.fit(df_train)
df_without_semisupervised = cross_validator_model.transform(df_test)

Plot2DGraphs.plot_known_and_unknown_data(
    df_without_semisupervised.toPandas(),
    labelCol='prediction')



In [9]:
compute_error_rate(df_without_semisupervised)


+----------+----------+----------------+
|real_label|prediction|Prediction Count|
+----------+----------+----------------+
|       1.0|       1.0|             297|
|       0.0|       1.0|               4|
|       0.0|       0.0|             293|
+----------+----------+----------------+

594
4
Error-rate: 0.006734006734006734
Let's take a look at the semi-supervised approach

This simplified version uses k-means and logistic regression. In the future, the obvious next step is either to build an active-learning system that queries the user, or to use an ensemble approach (sketched after the next cell).


In [10]:
df_output = batch_generative_model.semi_supervised_batch_single_classifier_generate_approach(
    df_tester, ['x', 'y'])


2
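One of the extensions mentioned above, the ensemble approach, could roughly amount to fitting several base classifiers on the pseudo-labeled data and only trusting the points on which they agree. A minimal sketch, with illustrative column names:

from pyspark.ml import classification
from pyspark.sql import functions as F

def ensemble_agreement(df_pseudo, features_col='features', label_col='pseudo_label'):
    """Sketch: keep only pseudo-labeled rows where two different
    classifiers make the same prediction."""
    lr = classification.LogisticRegression(
        featuresCol=features_col, labelCol=label_col,
        predictionCol='lr_pred', rawPredictionCol='lr_raw', probabilityCol='lr_prob')
    rf = classification.RandomForestClassifier(
        featuresCol=features_col, labelCol=label_col,
        predictionCol='rf_pred', rawPredictionCol='rf_raw', probabilityCol='rf_prob')
    scored = rf.fit(df_pseudo).transform(lr.fit(df_pseudo).transform(df_pseudo))
    # Agreement filter: trust a pseudo-label only when both models concur.
    return scored.filter(F.col('lr_pred') == F.col('rf_pred'))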

In [13]:
df_output.limit(5).toPandas()


Out[13]:
   x         y          real_label  used_label  features                          Kmeans_prediction  rawPrediction                        probability                       prediction
0  2.437446   4.045735  0.0         NaN         [2.43744643669, 4.04573472692]    0                  [2.07021821888, -2.07021821888]      [0.887974670747, 0.112025329253]  0.0
1  2.642079  -2.873269  0.0         NaN         [2.64207862673, -2.87326915083]   1                  [-0.0419769511091, 0.0419769511091]  [0.489507302911, 0.510492697089]  1.0
2  0.202908   1.499698  0.0         NaN         [0.202908422758, 1.49969825274]   0                  [0.550324123549, -0.550324123549]    [0.634210786882, 0.365789213118]  0.0
3  1.110012   0.024623  0.0         NaN         [1.11001160053, 0.0246226493257]  0                  [0.377644577678, -0.377644577678]    [0.593304878333, 0.406695121667]  0.0
4  6.783741   2.017823  0.0         NaN         [6.78374094559, 2.01782253429]    0                  [2.82927630291, -2.82927630291]      [0.944237509471, 0.055762490529]  0.0

In [14]:
compute_error_rate(df_output)
Plot2DGraphs.plot_known_and_unknown_data(df_output.toPandas(), labelCol='prediction')


+----------+----------+----------------+
|real_label|prediction|Prediction Count|
+----------+----------+----------------+
|       1.0|       1.0|             300|
|       0.0|       1.0|               8|
|       0.0|       0.0|             292|
+----------+----------+----------------+

600
8
Error-rate: 0.013333333333333334
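Note that in this run the naive semi-supervised pipeline does slightly worse than the supervised baseline (1.33% vs. 0.67% error). The numbers are also not directly comparable: the baseline was evaluated on the 594 unlabeled points only, while the semi-supervised output covers all 600.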


In [1]:
# Convert the double-helix dataset from parquet to csv, and inspect it.
df = spark.read.parquet('/home/svanhmic/workspace/data/DABAI/sparkdata/parquet/double_helix.parquet/')
df.show()
df.write.csv('/home/svanhmic/workspace/data/DABAI/sparkdata/csv/double_helix.csv/')


+-----+----------+--------------------+--------------------+-------------------+---+
|label|true_label|                   x|                   y|                  z| id|
+-----+----------+--------------------+--------------------+-------------------+---+
|  NaN|         1|-0.48894691832462583|   0.104550997416819|  6.072528743102649| 71|
|  NaN|         1| 0.41373998587691363|  0.2807476163506492|  3.737784319598607| 32|
|  NaN|         1| -0.4945318360471719|-0.07374458038264992|0.14802918714272217|160|
|  NaN|         1| -0.4272826987238599|  0.2596718994640259|  5.737102412763009|  4|
|  NaN|         1|-0.42788563654997924|  0.2586771772581011|  5.739428789258191| 36|
|  NaN|         1|-0.03259036911496597| -0.4989367373134096| 1.5055693463972868|180|
|  NaN|         1|-0.49925218997349363|-0.02733588862046773|0.05469904962607518| 89|
|  NaN|         1| -0.4997414986777315|-0.01607589808797...|0.03215733817660402|154|
|  NaN|         1|-0.24850367890777134|  0.4338731629973249|  5.232535583864838|178|
|  NaN|         1|-0.34167143318071613| -0.3650488073507917| 0.8184647885226864|167|
|  NaN|         1|-0.18612879963635826|-0.46406472603067817| 1.1893560705623958|129|
|  NaN|         1|-0.23323686522650874|  0.4422675261640983| 5.1976889685174035| 31|
|  NaN|         1| -0.1101824803363212| -0.4877087461046154| 1.3486077136766805|152|
|  NaN|         1| 0.43789294668424733| -0.2413498855276027| 2.6378578668309016| 72|
|  NaN|         1| 0.09667238984195503| -0.4905654380836924| 1.7653664420905333|110|
|  NaN|         1| -0.3483029577566882|  0.3587269847919761| 5.4830448093723865| 38|
|  NaN|         1|-0.33139367768142586| -0.3744038333040665| 0.8462612779019653| 40|
|  NaN|         1| 0.12396772255565584| 0.48438827789735367| 4.4618404033430386| 33|
|  NaN|         1| -0.3662915742197265|  0.3403387763033101|  5.534498187938466| 97|
|  NaN|         1| 0.12829919652677896|-0.48325905699798627| 1.8302974320667307|133|
+-----+----------+--------------------+--------------------+-------------------+---+
only showing top 20 rows

