In [ ]:
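Initialize the workflow: load libraries and derive the data, context and configuration paths from the current working directory.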

In [ ]:
import struct, socket
import shutil
import numpy as np
import pandas as pd
import linecache, bisect
import csv, json
import operator
import os, time, subprocess 
from collections import OrderedDict

try:
    import ipywidgets as widgets # For jupyter/ipython >= 1.4
except ImportError:
    from IPython.html import widgets

from IPython.display import display, Javascript, clear_output

# Derive every file location from the current working directory. The last two
# path components are taken as the date and the data source; each 'ipynb'
# component is swapped for 'data', 'context' or 'oa' to reach sibling trees.
path = os.getcwd().split("/")
date = path[-1]
dsource = path[-2]
dpath = '/'.join(['data' if var == 'ipynb' else var for var in path]) + '/'
cpath = '/'.join(['context' if var == 'ipynb' else var for var in path[:-2]]) + '/'
opath = '/'.join(['oa' if var == 'ipynb' else var for var in path[:-1]]) + '/'
sconnect = dpath + 'flow_scores.csv'
sconnectbu = dpath + 'flow_scores_bu.csv'
score_fbk = dpath + 'flow_scores_fb.csv'
tmpconnect = sconnect + '.tmp'
stemp = sconnect + '.new'
file_schemas = opath + dsource + '_conf.json'
# Feedback column names (in file order) come from the '<dsource>_conf.json' file.
# The handle is closed right after parsing instead of being leaked.
with open(file_schemas) as schema_file:
    feedback_cols = json.load(schema_file, object_pairs_hook=OrderedDict)['flow_feedback_fields']
# Maximum number of rows scanned when building the selection lists.
coff = 250
nwloc = cpath + 'networkcontext.csv'
srcdict, srclist = {}, []
dstdict, dstlist = {}, []
sportdict, sportlist = {}, []
dportdict, dportlist = {}, []

In [ ]:
def apply_css_to_select(select):
    """Set the ``_css`` styling tuple on a select-list widget.

    The widget itself is sized to 90% of its container; the inner <select>
    element gets full width, horizontal scrolling when needed, and no margin.
    Returns None; the widget is modified in place.
    """
    outer_rules = [(None, 'height', '90%'), (None, 'width', '90%')]
    inner_rules = [
        ('select', 'overflow-x', 'auto'),
        ('select', 'width', '100%'),
        ('select', 'margin', 0),
    ]
    select._css = tuple(outer_rules + inner_rules)


def _scan_select_options(limit):
    """Collect pick-list values for the scoring form.

    Reads ``sconnect`` and, from rows whose ``sev`` is '0', keeps the distinct
    srcIP, dstIP, sport and dport values in first-seen order. At most ``limit``
    rows are read (a non-positive limit reads the whole file). Each returned
    list starts with the '- Select -' placeholder.

    Returns [srclist, dstlist, sportlist, dportlist].
    """
    columns = ('srcIP', 'dstIP', 'sport', 'dport')
    options = dict((col, ['- Select -']) for col in columns)
    seen = dict((col, set()) for col in columns)
    with open(sconnect, 'r') as f:
        reader = csv.DictReader(f, delimiter=',')
        for rowct, row in enumerate(reader, 1):
            if row['sev'] == '0':
                for col in columns:
                    value = row[col]
                    if value not in seen[col]:
                        seen[col].add(value)
                        options[col].append(value)
            if rowct == limit:
                break
    return [options[col] for col in columns]


def _labeled_select(label, options, box_width):
    """Build a Box holding an HTML caption above a styled Select; return (box, select)."""
    caption = widgets.HTML(value=label, height='10%', width='100%')
    select = widgets.Select(options=options, height='90%')
    apply_css_to_select(select)
    box = widgets.Box(width=box_width, height='100%')
    box.children = (caption, select)
    return box, select


def _row_matches(row, criteria):
    """True when every non-empty criterion equals the row's value for its field.

    ``criteria`` is a list of (field, value) pairs; '' means "any value".
    A criteria list with no non-empty value matches nothing.
    """
    active = [(field, value) for field, value in criteria if value != '']
    return bool(active) and all(row[field] == value for field, value in active)


def displaythis():
    """Render the interactive flow-scoring form and wire up its buttons.

    Removes previously rendered widgets, re-applies the rule-based scores
    (set_rules), then builds source/destination IP and port pick lists from the
    first ``coff`` rows of ``sconnect`` that are still unscored (sev == '0').

    * Score: sets ``sev`` to the chosen rating on every flow matching the
      selections (or the Quick IP text as source or destination) and appends
      those flows to the feedback file.
    * Save: sorts the scores file by (sev, score), calls the page's
      reloadParentData() JS hook, rebuilds this form and runs ml_feedback().
    """
    display(Javascript("$('.widget-area > .widget-subarea > *').remove();"))
    set_rules()
    srclist, dstlist, sportlist, dportlist = _scan_select_options(coff)

    srcIpBox, srcselect = _labeled_select("Source IP:", srclist, '25%')
    dstIpBox, dstselect = _labeled_select("Dest IP:", dstlist, '25%')
    srcPortBox, sportselect = _labeled_select("Src Port:", sportlist, '20%')
    dstPortBox, dportselect = _labeled_select("Dst Port:", dportlist, '20%')

    # Quick Search and Actions Box
    emptyLabel = widgets.HTML(value=" ")
    srctext = widgets.Text(value='', width='100%', placeholder='Quick IP scoring')
    srctext._css = (
        (None, 'width', '100%'),
    )
    ratingbut = widgets.RadioButtons(description='Rating:', options=['1', '2', '3'], width='100%')
    assignbut = widgets.Button(description='Score', width='45%')
    assignbut.button_style = 'primary'
    updatebut = widgets.Button(description='Save', width='45%')
    updatebut.button_style = 'primary'
    updatebut._css = (
        (None, 'margin-left', '10%'),
    )
    actionsBox = widgets.Box(width='20%', height='100%')
    actionsBox.children = (emptyLabel, srctext, ratingbut, assignbut, updatebut)

    # Container Box
    bigBox = widgets.HBox(width='90%', height=250)
    bigBox.children = (srcIpBox, dstIpBox, srcPortBox, dstPortBox, actionsBox)

    display(bigBox)

    def update_sconnects(b):
        """Score every flow matching the current selections with the chosen rating."""
        clear_output()
        time.sleep(.25)
        fields = ['srcIP', 'dstIP', 'sport', 'dport']
        quick_ip = srctext.value
        ports = [sportselect.value, dportselect.value]
        if quick_ip != '':
            # The typed IP is tried as the source (with the selected destination)
            # and, separately, as the destination (with the selected source).
            candidates = [[quick_ip, dstselect.value] + ports,
                          [srcselect.value, quick_ip] + ports]
        else:
            candidates = [[srcselect.value, dstselect.value] + ports]
        criteria_sets = [
            [(field, '' if value == '- Select -' else value)
             for field, value in zip(fields, values)]
            for values in candidates
        ]
        risk = ratingbut.value
        scored_threats = []
        rowct = 0

        if any(value != '' for criteria in criteria_sets for _, value in criteria):
            with open(tmpconnect, 'w') as g, open(sconnect, 'r') as f:
                reader = csv.DictReader(f, delimiter=',')
                riter = csv.DictWriter(g, delimiter=',', fieldnames=reader.fieldnames)
                riter.writeheader()
                for row in reader:
                    # Score/count a row once even if it matches both the
                    # source-side and destination-side criteria.
                    if any(_row_matches(row, criteria) for criteria in criteria_sets):
                        row['sev'] = risk
                        scored_threats.append({col: row[col] for col in feedback_cols.keys()})
                        rowct += 1
                    riter.writerow(row)

            create_feedback_file(scored_threats)
            shutil.copyfile(tmpconnect, sconnect)

        print("{0} matching connections scored".format(rowct))

    def savesort(b):
        """Persist the scores sorted by (sev, score), then rebuild the form."""
        clear_output()
        # Both handles are closed before sconnect is overwritten below.
        with open(stemp, 'w') as g, open(sconnect) as src:
            reader = csv.DictReader(src, delimiter=",")
            riter = csv.DictWriter(g, fieldnames=reader.fieldnames, delimiter=',')
            srtlist = sorted(reader, key=lambda x: (int(x["sev"]), float(x["score"])))
            riter.writeheader()
            riter.writerows(srtlist)

        shutil.copyfile(stemp, sconnect)
        print("Suspicious connects successfully updated")
        display(Javascript('reloadParentData();'))
        bigBox.close()
        # Rebuild widgets form
        displaythis()
        ml_feedback()

    assignbut.on_click(update_sconnects)
    updatebut.on_click(savesort)

    
def create_feedback_file(scored_rows, path=None, fieldnames=None):
    """Append scored flows to the tab-separated feedback file.

    The header row is written only when the file does not exist yet.

    scored_rows: iterable of dicts keyed by the feedback column names.
    path: feedback file to append to; defaults to ``score_fbk``.
    fieldnames: column names in output order; defaults to ``feedback_cols``.
    """
    if path is None:
        path = score_fbk
    if fieldnames is None:
        fieldnames = feedback_cols
    write_header = not os.path.exists(path)
    # One handle, closed (and therefore flushed) before returning, so the rows
    # are on disk before anything else reads or copies the file.
    with open(path, 'a') as feedback:
        wr = csv.DictWriter(feedback, fieldnames=fieldnames, delimiter='\t',
                            quoting=csv.QUOTE_NONE)
        if write_header:
            wr.writeheader()
        wr.writerows(scored_rows)


# Built-in heuristic scoring rules used by set_rules(). Each entry is
# (rops, rvals, risk) in the form apply_rules() takes: rops/rvals are
# positional over srcIP, dstIP, sport, dport, ipkt, ibyt, and an empty
# rval disables the check for that position.
_DEFAULT_RULES = (
    # dport <= 1024 and ibyt <= 54
    (['leq', 'leq', 'leq', 'leq', 'leq', 'leq'], ['', '', '', 1024, '', 54], 2),
    # dport <= 1024, ipkt == 3 and ibyt == 152
    (['leq', 'leq', 'leq', 'leq', 'eq', 'eq'], ['', '', '', 1024, 3, 152], 2),
    # dport <= 1024, ipkt == 2 and ibyt == 104
    # (previously assigned but never applied)
    (['leq', 'leq', 'leq', 'leq', 'eq', 'eq'], ['', '', '', 1024, 2, 104], 2),
    # sport == 0 and dport <= 1023
    (['leq', 'leq', 'eq', 'leq', 'leq', 'leq'], ['', '', 0, 1023, '', ''], 2),
)


def set_rules(rules=None, apply_fn=None):
    """Apply every heuristic scoring rule, in order.

    rules: iterable of (rops, rvals, risk) triples; defaults to _DEFAULT_RULES.
    apply_fn: callable(rops, rvals, risk) invoked per rule; defaults to apply_rules.
    """
    if rules is None:
        rules = _DEFAULT_RULES
    if apply_fn is None:
        apply_fn = apply_rules
    for rops, rvals, risk in rules:
        # Pass copies so a callee cannot mutate the shared rule table.
        apply_fn(list(rops), list(rvals), risk)

    
    
def apply_rules(rops,rvals,risk):
    """Set ``sev`` to ``risk`` on every flow in the scores file that matches a rule.

    rops/rvals are positional over srcIP, dstIP, sport, dport, ipkt, ibyt.
    Positions whose rvals entry is '' are ignored. For every other position,
    the row's value converted with int() must be <= ('leq') or == ('eq') the
    rule value converted with int(); any other operator never matches.
    Matching rows are appended to the feedback file via create_feedback_file,
    and the rewritten table is copied over ``sconnect``.
    """
    fields = ['srcIP', 'dstIP', 'sport', 'dport', 'ipkt', 'ibyt']
    comparators = {'leq': operator.le, 'eq': operator.eq}
    active = [n for n in range(len(rvals)) if rvals[n] != '']
    matched = []

    with open(sconnect, 'r') as src:
        with open(tmpconnect, 'w') as dst:
            reader = csv.DictReader(src, delimiter=',')
            writer = csv.DictWriter(dst, fieldnames=reader.fieldnames, delimiter=',')
            writer.writeheader()
            for row in reader:
                # Evaluate every active check (no short-circuit) before deciding.
                checks = []
                for n in active:
                    compare = comparators.get(rops[n])
                    checks.append(compare is not None and
                                  compare(int(row[fields[n]]), int(rvals[n])))
                if all(checks):
                    row['sev'] = risk
                    matched.append({col: row[col] for col in feedback_cols.keys()})
                writer.writerow(row)

    create_feedback_file(matched)
    shutil.copyfile(tmpconnect, sconnect)
    
    
def attack_heuristics():
    """Print source IPs that appear in more than 20 flows of the scores file.

    Side effect: every srcIP/dstIP/sport/dport value read from ``sconnect`` is
    recorded in the module-level srcdict/dstdict/sportdict/dportdict, which
    accumulate across calls.
    """
    with open(sconnect, 'r') as f:
        # DictReader consumes the header line itself; the previous extra
        # reader.next() silently dropped the first data row.
        for row in csv.DictReader(f, delimiter=','):
            srcdict.setdefault(row['srcIP'], row['srcIP'])
            dstdict.setdefault(row['dstIP'], row['dstIP'])
            sportdict.setdefault(row['sport'], row['sport'])
            dportdict.setdefault(row['dport'], row['dport'])

    # Flows per source IP, computed once instead of calling get_group twice.
    connects_per_src = pd.read_csv(sconnect).groupby(u'srcIP').size()

    for srcip in srcdict:
        if srcip not in connects_per_src.index:
            # Explicit check instead of a bare except. Presumably hit by stale
            # entries from earlier calls or values pandas did not group.
            print("Key Error for ip: " + srcip)
        elif connects_per_src[srcip] > 20:
            print("{0} connects: {1}".format(srcip, connects_per_src[srcip]))
               
            
def ml_feedback():
    """Copy the feedback file to the ML node with scp.

    Loads /etc/spot.conf for LUSER, MLNODE and LPATH, then copies ``score_fbk``
    to ``$usr@$mlnode:$lpath/<basename of sconnect>``.

    Returns the shell command's exit status (0 on success).
    """
    try:
        from shlex import quote  # Python 3
    except ImportError:
        from pipes import quote  # Python 2
    dst_name = os.path.basename(sconnect)
    # Values interpolated from Python come from the working-directory path, so
    # they are shell-quoted. '.' is used instead of the bash-only 'source'
    # because shell=True runs /bin/sh.
    str_fb = ("DSOURCE={0} && "
              "FDATE={1} && "
              ". /etc/spot.conf && "
              "usr=$(echo $LUSER | cut -f3 -d'/') && "
              "mlnode=$MLNODE && "
              "lpath=$LPATH && "
              "scp {2} $usr@$mlnode:$lpath/{3}").format(
        quote(dsource), quote(date), quote(score_fbk), quote(dst_name))

    return subprocess.call(str_fb, shell=True)

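Run attack heuristics. The next two cells are commented out: uncomment them to apply the built-in scoring rules (`set_rules`) or to list source IPs with more than 20 connections (`attack_heuristics`). The cell after them displays the interactive scoring form.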


In [ ]:
# set_rules()

In [ ]:
# attack_heuristics()

In [ ]:
# Build and display the interactive flow-scoring form (also re-applies set_rules()).
displaythis()

In [ ]:
# !cp $sconnectbu $sconnect