ART1 demo

Adaptive Resonance Theory Neural Networks by Aman Ahuja | github.com/amanahuja | twitter: @amanqa

Overview

Reminders:

  • ART1 accepts binary inputs only.

In this example:

  • We'll use 5x5 ASCII blocks (25-character records) to demonstrate

[Load data]


In [ ]:
import os

In [ ]:
# Make sure we're running from the project root: notebooks live in an
# 'ipynb' subdirectory, so step up one level if we started there.
pwd = os.getcwd()
if pwd.endswith('ipynb'):
    os.chdir(os.pardir)

In [ ]:
# Directory holding the demo input files (relative to the project root).
data_dir = 'data'
# Show what's available so the reader can pick a file (Python 2 print).
print os.listdir(data_dir)

# ASCII data file — presumably blank-line-separated character blocks;
# see the parsing cell below, which splits on '\n\n'.
data_file = 'ASCII_01.txt'

In [ ]:
# Read the entire raw ASCII data file into a single string.
data_path = os.path.join(data_dir, data_file)
with open(data_path, 'r') as infile:
    raw_data = infile.read()

In [ ]:
# print out raw_data to see what it looks like
# print raw_data

In [ ]:
# Get data into a usable form here.
# Records in the raw file are separated by blank lines.
data = [d.strip() for d in raw_data.split('\n\n')]
# Drop empty records. The original used "d is not ''" — an identity
# comparison that only worked by accident of CPython string interning;
# use a proper equality test instead.
data = [d for d in data if d != '']

# Flatten each record to one string (row structure is restored later
# for display).
data = [d.replace('\n', '') for d in data]

# print the data
data

[Cleaning and preprocessing]


In [ ]:
import numpy as np

In [ ]:
from collections import Counter
import numpy as np

def preprocess_data(data):
    """
    Convert a list of equal-length character strings into a binary
    (0/1) numpy integer array.

    Parameters
    ----------
    data : list of str
        One string per record; all records must have the same length
        and use at most two distinct characters.

    Returns
    -------
    numpy.ndarray, shape (len(data), record_length), dtype int
        1 where a character equals the "foreground" char, else 0.

    Raises
    ------
    Exception
        If a scanned row uses more than two distinct characters, or if
        no row contains both characters (the 0/1 mapping would then be
        undefined).
    AssertionError
        If row lengths are inconsistent.
    """
    if not data:
        # Degenerate input: nothing to convert.
        return np.zeros((0, 0), dtype=int)

    # All rows must share the length of the first one.
    idat_size = len(data[0])

    # Determine the two characters in use. A single row may contain
    # only one of them, so scan rows until one shows both.
    # (The original "while not chars" loop could never iterate more
    # than once — (False, False) is a truthy tuple — so a single-char
    # first row silently produced an all-zero result.)
    char1 = None
    for irow in data:
        chars = get_unique_chars(irow, reverse=True)
        if chars != (False, False):
            char1, char2 = chars
            break
    if char1 is None:
        raise Exception("Could not determine the two data characters")

    # Convert each row to 0s and 1s using the chars identified.
    outdata = []
    for irow in data:
        assert len(irow) == idat_size, "data row lengths not consistent"
        outdata.append([int(x == char1) for x in irow])

    return np.array(outdata, dtype=int)

def get_unique_chars(irow, reverse=False):
    """
    Get the two unique characters in one data row.
    Helper function.
    ----
    reverse : bool
        Reverses order of the two chars returned (order is the chars'
        first-appearance order in the row).

    Returns (char1, char2) on success, or (False, False) when the row
    contains fewer than two distinct characters (caller should try
    another row). Raises Exception for more than two distinct chars.
    """
    chars = Counter(irow)
    if len(chars) > 2:
        raise Exception("Data is not binary")
    elif len(chars) < 2:
        # Row doesn't contain both chars.
        return False, False

    char1, char2 = chars.keys()
    if reverse:
        char1, char2 = char2, char1

    return char1, char2

In [ ]:
# preprocess data: convert the ASCII records into a 2-D 0/1 numpy array
data_cleaned = preprocess_data(data)

In [ ]:
def display_ASCII(raw, width=5, height=5):
    """
    Format a flattened ASCII record as a width x height text block.

    Parameters
    ----------
    raw : str
        Flattened record of at least width*height characters.
    width : int
        Characters per line (default 5, matching the demo data).
    height : int
        Number of lines (default 5).

    Returns
    -------
    str
        The rows of the block joined by newlines.
    """
    # Generalized from the original hard-coded 5x5 slicing; defaults
    # preserve the previous behavior exactly.
    rows = [raw[i * width:(i + 1) * width] for i in range(height)]
    return "\n".join(rows)

In [ ]:
## Simplified ART1

class ART1:
    """
    Simplified ART1 (Adaptive Resonance Theory) network.
    Accepts binary (0/1) input vectors only.
    modified Aman Ahuja

    Usage example:
    --------------
    # Create an ART1 network with input of size 5 and 10 internal units
    >>> network = ART1(5, 10, 0.5)
    """

    def __init__(self, n=5, m=10, rho=.5):
        '''
        Create network with specified shape

        For Input array I of size n, we need n input nodes in F1.


        Parameters:
        -----------
        n : int
            feature dimension of input; number of nodes in F1
        m : int
            Number of neurons in F2 competition layer
            max number of categories
            compare to n_class
        rho : float
            Vigilance parameter
            larger rho: less inclusive prototypes
            smaller rho: more generalization

        internal parameters
        ----------
        F1: array of size (n)
            array of F1 neurons
        F2: array of size (m)
            array of F2 neurons
        Wf: array of shape (m x n)
            Feed-Forward weights
            These are Tk
        Wb: array of shape (n x m)
            Feed-back weights
        n_cats : int
            Number of F2 neurons that are active
            (at any given time, number of category templates)

        '''
        # Comparison layer
        self.F1 = np.ones(n)

        # Recognition layer
        self.F2 = np.ones(m)

        # Feed-forward weights
        self.Wf = np.random.random((m, n))

        # Feed-back weights (columns act as category templates)
        self.Wb = np.random.random((n, m))

        # Vigilance parameter
        self.rho = rho

        # Number of active units in F2
        self.n_cats = 0

    def reset(self):
        """Reset whole network to start conditions.

        Bug fix: the original referenced the undefined names ``n`` and
        ``m`` (a guaranteed NameError); recover the sizes from the
        existing layers instead.
        """
        n = self.F1.size
        m = self.F2.size
        self.F1 = np.ones(n)
        self.F2 = np.ones(m)
        self.Wf = np.random.random((m, n))
        self.Wb = np.random.random((n, m))
        self.n_cats = 0

    def learn(self, X):
        """Learn one input vector X.

        use i as index over inputs or F1
        use k as index over categories or F2

        Returns (Wb_column, category_index) for the category that
        learned X, or (None, None) when every F2 unit is committed and
        none matched within the vigilance level.
        """
        # Compute F2 output using feed forward weights
        self.F2[...] = np.dot(self.Wf, X)

        # Committed nodes sorted by activation, strongest first
        C = np.argsort(self.F2[:self.n_cats].ravel())[::-1]

        for k in C:
            # Fraction of X's active bits matched by template k
            d = (self.Wb[:, k] * X).sum() / X.sum()

            # Resonance: template k is close enough to X
            if d >= self.rho:
                return self._learn_data(k, X)

        # No match found within vigilance level.
        # If there's room, commit a new F2 unit and let it learn X.
        if self.n_cats < self.F2.size:
            k = self.n_cats  # index of first uncommitted category
            ww = self._learn_data(k, X)
            self.n_cats += 1
            return ww
        else:
            return None, None

    def _learn_data(self, node, dat):
        """Update the weights of F2 unit ``node`` toward input ``dat``.

        node : k : F2 node index
        dat  : X : input data (binary vector of length n)
        """
        self._validate_data(dat)

        # Learn data: the template keeps only bits it shares with dat;
        # feed-forward weights are the normalized template.
        self.Wb[:, node] *= dat
        self.Wf[node, :] = self.Wb[:, node] / (0.5 + self.Wb[:, node].sum())
        return self.Wb[:, node], node

    def predict(self, X):
        """Return the index of the best-matching committed category for
        X, or None when no committed category responds at all.
        """
        C = np.dot(self.Wf[:self.n_cats], X)

        # return active F2 node, unless none are active
        if np.all(C == 0):
            return None

        return np.argmax(C)

    def _validate_data(self, dat):
        """
        dat is a single input record.
        Checks: dimensions must match F1; data must be 1s and 0s.
        Returns True on success, raises Exception otherwise.
        """
        pass_checks = True
        msg = ""

        # Dimensions must match
        if dat.shape[0] != len(self.F1):
            pass_checks = False
            msg = "Input dimensions mismatch."  # typo fixed ("dimensins")

        # Data must be 1s or 0s
        if not np.all((dat == 1) | (dat == 0)):
            pass_checks = False
            msg = "Input must be binary."

        if pass_checks:
            return True
        else:
            raise Exception("Data does not validate: {}".format(msg))

[Demo]


In [ ]:
from collections import defaultdict

# create network

input_row_size = 25
max_categories = 10
rho = 0.20

network = ART1(n=input_row_size, m=max_categories, rho=rho)

# preprocess data
data_cleaned = preprocess_data(data)

# shuffle data? 
np.random.seed(1221)
np.random.shuffle(data_cleaned)

# learn data array, row by row
for row in data_cleaned:
    network.learn(row)

print
print "n rows of data:         ", len(data_cleaned)
print "max categories allowed: ", max_categories
print "rho:                    ", rho

print "n categories used:      ", network.n_cats
print


# output results, row by row
output_dict = defaultdict(list)

for row, row_cleaned in zip (data, data_cleaned): 
    pred = network.predict(row_cleaned)
    output_dict[pred].append(row)

for k,v in output_dict.iteritems():
    print "category: {}, ({} members)".format(k, len(v))
    print '-'*20
    for row in v: 
        print display_ASCII(row)
        print
    print 
#   \  print "'{}':{}".format(
#         row, 
#         network.predict(row_cleaned))