In [ ]:
import pyspark
from pyspark.sql import SparkSession

# Create a SparkContext and wrap it in a SparkSession for the DataFrame/SQL API
sc = pyspark.SparkContext(appName="sparkSQL")
ss = SparkSession(sc)

In [ ]:
# Load the KDD Cup 10% dataset as an RDD of text lines and cache it in memory
data = "file:///path/to/recitation4/problems/kddcup.data_10_percent"
raw = sc.textFile(data).cache()

DataFrame

A DataFrame is a Dataset organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as structured data files, tables in Hive, external databases, or existing RDDs.
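
For example (a minimal sketch with made-up values, separate from the recitation data), a DataFrame can be built directly from an in-memory list of Rows:

In [ ]:
from pyspark.sql import Row
# A tiny illustrative DataFrame built from in-memory Rows
toy_df = ss.createDataFrame([Row(name="alice", age=23), Row(name="bob", age=31)])
toy_df.show()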

We want to convert our raw data into a table. First we have to parse it, splitting each line into fields and assigning the desired columns and types, much as we would for a CSV file.


In [ ]:
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
from pyspark.sql import Row

In [ ]:
# Split each line on commas, then build a Row with named, typed columns
csv_data = raw.map(lambda l: l.split(","))
row_data = csv_data.map(lambda p: Row(
    duration=int(p[0]), 
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5])
    )
)

Once we have an RDD of Row objects, Spark can infer a schema for it. We can then register the resulting DataFrame as a table and operate on it with SQL queries.


In [ ]:
kdd_df = sqlContext.createDataFrame(row_data)
# Register the DataFrame as a temporary table so we can query it with SQL
kdd_df.registerTempTable("KDDdata")
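
Spark infers the column names and types from the Row fields; we can verify them by printing the schema:

In [ ]:
# Inspect the schema inferred from the Row fields
kdd_df.printSchema()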

In [ ]:
# Select tcp interactions with a duration longer than 2000 seconds and no data transferred from the destination
tcp_interactions = sqlContext.sql("SELECT duration, dst_bytes FROM KDDdata WHERE protocol_type = 'tcp' AND duration > 2000 AND dst_bytes = 0")
tcp_interactions.show(10)
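
The same query can also be expressed with the DataFrame API instead of a SQL string; a rough equivalent:

In [ ]:
# Equivalent filter expressed with DataFrame operations
kdd_df.filter((kdd_df.protocol_type == 'tcp') &
              (kdd_df.duration > 2000) &
              (kdd_df.dst_bytes == 0)) \
      .select("duration", "dst_bytes") \
      .show(10)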

In [ ]:
# Complete the query to filter data with duration > 2000, dst_bytes = 0. 
# Then group the filtered elements by protocol_type and show the total count in each group.
# Refer - https://spark.apache.org/docs/latest/sql-programming-guide.html#dataframegroupby-retains-grouping-columns

kdd_df.select("protocol_type", "duration", "dst_bytes").filter(kdd_df.duration>2000)#.more query...
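
One possible completion (a sketch; keep or adapt the column selection as you prefer):

In [ ]:
# Possible completion: filter duration > 2000 and dst_bytes == 0,
# then count interactions per protocol_type
kdd_df.select("protocol_type", "duration", "dst_bytes") \
      .filter((kdd_df.duration > 2000) & (kdd_df.dst_bytes == 0)) \
      .groupBy("protocol_type") \
      .count() \
      .show()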

In [ ]:
def transform_label(label):
    '''
    Create a function to parse input label
    such that if input label is not normal 
    then it is an attack
    '''
    


# Build the same rows as before, adding a label column parsed from field 41
row_labeled_data = csv_data.map(lambda p: Row(
    duration=int(p[0]), 
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5]),
    label=transform_label(p[41])
    )
)
kdd_labeled = sqlContext.createDataFrame(row_labeled_data)

'''
Write a query to select label, 
group it and then count total elements
in that group
'''
# query
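
A possible solution sketch for both blanks, assuming the raw labels are strings such as 'normal.' for normal traffic (as in the KDD Cup file); define transform_label before running the cell above that builds row_labeled_data:

In [ ]:
# Possible sketch: anything whose label does not start with 'normal' is treated as an attack
def transform_label(label):
    return 'normal' if label.startswith('normal') else 'attack'

# Group by label and count the total elements in each group
kdd_labeled.select("label").groupBy("label").count().show()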

We can also use other DataFrame operations, such as grouping on a boolean condition, to filter and summarize our data efficiently.


In [ ]:
# Count interactions per label and protocol_type, split by whether dst_bytes is zero
kdd_labeled.select("label", "protocol_type", "dst_bytes").groupBy("label", "protocol_type", kdd_labeled.dst_bytes == 0).count().show()

From this output we can infer that the number of tcp attacks with zero data transfer (110583) is large compared to normal tcp interactions (9313).

This type of analysis is known as exploratory data analysis.