In [ ]:
import pyspark
from pyspark.sql import SparkSession

# Create a SparkContext and wrap it in a SparkSession for the DataFrame/SQL API
sc = pyspark.SparkContext(appName="sparkSQL")
ss = SparkSession(sc)
In [ ]:
data = "file:////path/to/recitation4/problems/kddcup.data_10_percent"
raw = sc.textFile(data).cache()
A DataFrame is a Dataset organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as structured data files, tables in Hive, external databases, or existing RDDs.
We want to convert our raw data into a table, but first we have to parse it and assign the desired rows and headers, much like a CSV file.
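As an aside, with Spark 2.0 or later the comma-separated file could also be read straight into a DataFrame through the DataFrameReader. The sketch below assumes the ss SparkSession created above and leaves the columns with Spark's default _c0, _c1, ... names; in this recitation we instead build the DataFrame by hand from the RDD, which is what the following cells walk through.
In [ ]:
# Sketch only (assumes Spark >= 2.0): read the comma-separated file directly.
# Columns are auto-named _c0, _c1, ... and types are guessed via inferSchema.
kdd_csv_df = ss.read.csv(data, inferSchema=True)
kdd_csv_df.printSchema()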
In [ ]:
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)
from pyspark.sql import Row
In [ ]:
# Split each line on commas, then build Row objects from the fields we need
csv_data = raw.map(lambda l: l.split(","))
row_data = csv_data.map(lambda p: Row(
    duration=int(p[0]),
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5])
))
Once we have an RDD of Row objects, Spark can infer a schema and build a DataFrame from it. After registering the DataFrame as a temporary table, we can operate on it with SQL queries.
In [ ]:
# Infer the schema from the Row RDD and register the DataFrame as a temporary SQL table
kdd_df = sqlContext.createDataFrame(row_data)
kdd_df.registerTempTable("KDDdata")
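As a quick sanity check (optional, not part of the original exercise), the inferred schema and a few rows can be inspected directly:
In [ ]:
# Optional sanity check: look at the schema Spark inferred and the first few rows
kdd_df.printSchema()
kdd_df.show(5)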
In [ ]:
# Select tcp network interactions with a duration greater than 2000 seconds and no data transfer from the destination
tcp_interactions = sqlContext.sql("SELECT duration, dst_bytes FROM KDDdata WHERE protocol_type = 'tcp' AND duration > 2000 AND dst_bytes = 0")
tcp_interactions.show(10)
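The same result can be obtained without a SQL string by chaining DataFrame operations. Here is an equivalent sketch using the DataFrame API, with the same columns and conditions as the query above:
In [ ]:
# Sketch: DataFrame-API equivalent of the SQL query above
tcp_interactions_df = kdd_df.filter((kdd_df.protocol_type == "tcp") &
                                    (kdd_df.duration > 2000) &
                                    (kdd_df.dst_bytes == 0)) \
                            .select("duration", "dst_bytes")
tcp_interactions_df.show(10)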
In [ ]:
# Complete the query to filter data with duration > 2000, dst_bytes = 0.
# Then group the filtered elements by protocol_type and show the total count in each group.
# Refer - https://spark.apache.org/docs/latest/sql-programming-guide.html#dataframegroupby-retains-grouping-columns
kdd_df.select("protocol_type", "duration", "dst_bytes").filter(kdd_df.duration>2000)#.more query...
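One possible completion of this exercise (a sketch, not the only valid answer) filters on both conditions and then groups by protocol_type:
In [ ]:
# Sketch of one possible solution: filter, group by protocol_type, count each group
kdd_df.select("protocol_type", "duration", "dst_bytes") \
    .filter((kdd_df.duration > 2000) & (kdd_df.dst_bytes == 0)) \
    .groupBy("protocol_type") \
    .count() \
    .show()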
In [ ]:
def transform_label(label):
    '''
    Create a function to parse the input label
    such that if the input label is not normal
    then it is an attack
    '''

row_labeled_data = csv_data.map(lambda p: Row(
    duration=int(p[0]),
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5]),
    label=transform_label(p[41])
))
kdd_labeled = sqlContext.createDataFrame(row_labeled_data)

'''
Write a query to select label,
group it and then count the total elements
in that group
'''
# query
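For reference, one possible solution is sketched below. It assumes the raw KDD labels carry a trailing period (for example 'normal.', 'smurf.'), so anything other than 'normal.' is treated as an attack; it then rebuilds kdd_labeled with the parsed label, registers it as a table, and runs the grouped count as a SQL query.
In [ ]:
# Sketch of one possible solution to the two exercises above
def transform_label(label):
    # Assumes labels like 'normal.', 'smurf.', ...; everything non-normal is an attack
    return 'normal' if label == 'normal.' else 'attack'

row_labeled_data = csv_data.map(lambda p: Row(
    duration=int(p[0]),
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5]),
    label=transform_label(p[41])
))
kdd_labeled = sqlContext.createDataFrame(row_labeled_data)
kdd_labeled.registerTempTable("KDDlabeled")

# Select the label, group by it, and count the elements in each group
sqlContext.sql("SELECT label, COUNT(*) AS total FROM KDDlabeled GROUP BY label").show()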
We can also combine DataFrame operations to slice the labeled data efficiently, for example grouping on the label, the protocol type, and a derived boolean column that marks zero data transfer from the destination.
In [ ]:
# Count interactions per label and protocol type, split by whether the destination sent back any data
kdd_labeled.select("label", "protocol_type", "dst_bytes") \
    .groupBy("label", "protocol_type", kdd_labeled.dst_bytes == 0) \
    .count() \
    .show()
From the output we can infer that tcp attacks with zero data transfer (110583) far outnumber normal tcp interactions (9313).
This type of analysis is known as exploratory data analysis.
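As a final readability tweak (a sketch, not part of the original recitation), the derived boolean column can be given a name with alias and the output sorted by count:
In [ ]:
# Sketch: alias the derived boolean column and sort the groups by their counts
kdd_labeled.select("label", "protocol_type", "dst_bytes") \
    .groupBy("label", "protocol_type", (kdd_labeled.dst_bytes == 0).alias("no_dst_transfer")) \
    .count() \
    .orderBy("count", ascending=False) \
    .show()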