In [1]:
import gensim
import os
import collections
import smart_open
import random
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import itertools
import plotly.plotly as py
import plotly.graph_objs as go
import plotly.offline as offline
import time
import seaborn as sns
import numpy as np
import multiprocessing
import pandas as pd

def flatten(l):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [2]:
"""
Load basic ingredients and compounds data

"""

# All input tables live under the relative 'data' directory.
path = 'data'
ingr_info = os.path.join(path, 'ingr_info.tsv')   # ingredient id -> name, category
comp_info = os.path.join(path, 'comp_info.tsv')   # compound id -> name, CAS number
ingr_comp = os.path.join(path, 'ingr_comp.tsv')   # ingredient id -> compound id pairs


def _iter_tsv_rows(path):
    """Yield the tab-separated fields of each data line in the file at ``path``.

    Lines whose first character is '#' (header/comment lines) are skipped, and
    so are blank lines. Trailing whitespace, including the newline, is stripped
    before splitting on tabs.
    """
    with open(path, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            stripped = line.rstrip()
            # Without this check a blank line (e.g. an extra newline at EOF)
            # would yield [''] and create a bogus '' entry with no fields.
            if not stripped:
                continue
            yield stripped.split('\t')


# {ingredient_id: [ingredient_name, ingredient_category]}
def load_ingredients(path):
    """Load the ingredient table from a TSV file.

    Args:
        path: TSV file whose data rows are ``id<TAB>name<TAB>category``.

    Returns:
        dict mapping the first column (a string) to the list of the remaining
        columns. If an id repeats, the last row wins.
    """
    return {fields[0]: fields[1:] for fields in _iter_tsv_rows(path)}


# {compound_id: [compound_name, CAS_number]}
def load_compounds(path):
    """Load the compound table from a TSV file.

    Args:
        path: TSV file whose data rows are ``id<TAB>name<TAB>CAS_number``.

    Returns:
        dict mapping the first column (a string) to the list of the remaining
        columns. If an id repeats, the last row wins.
    """
    return {fields[0]: fields[1:] for fields in _iter_tsv_rows(path)}


# {ingredient_id: [compound_id1, compound_id2, ...] }
def load_relations(path):
    """Load the ingredient -> compound relation table from a TSV file.

    Args:
        path: TSV file whose data rows are ``ingredient_id<TAB>compound_id``.
            Any further columns are ignored.

    Returns:
        dict mapping each ingredient id to the list of its compound ids, in
        file order. All ids are strings.
    """
    relations = {}
    for fields in _iter_tsv_rows(path):
        relations.setdefault(fields[0], []).append(fields[1])
    return relations

# Parse the three TSV tables into lookup dicts keyed by id.
ingredients, compounds, relations = (
    load_ingredients(ingr_info),
    load_compounds(comp_info),
    load_relations(ingr_comp),
)

def ingredient_to_category(tag, ingredients):
    """Return the category of the first ingredient whose name equals ``tag``.

    Args:
        tag: Ingredient name to search for (compared against entry[0]).
        ingredients: Mapping id -> [name, category, ...], as built by
            load_ingredients.

    Returns:
        entry[1] of the first matching entry in the mapping's iteration order,
        or None when no ingredient has that name.
    """
    matches = (entry[1] for entry in ingredients.values() if entry[0] == tag)
    return next(matches, None)

# Sanity check: look up the category of one known ingredient.
# print() with a single argument behaves the same on Python 2 and Python 3
# (the bare `print x` statement is a SyntaxError on Python 3).
print(ingredient_to_category('copaiba', ingredients))


plant derivative

In [3]:
"""
Load functions for plotting a graph
"""

def pretty_food(s):
    """Turn an ingredient id such as 'olive_oil' into display text ('Olive oil').

    Underscores become spaces and leading whitespace is dropped; the result is
    then capitalised. Stripping happens before capitalising, so a leading
    underscore does not leave the first real letter in lower case.
    """
    return ' '.join(s.split('_')).lstrip().capitalize()


def pretty_category(s):
    """Insert a space before each upper-case letter: 'NorthAmerican' -> 'North American'.

    Only upper-case letters get a preceding space. Labels that already contain
    spaces or punctuation (e.g. 'plant derivative', 'nut/seed/pulse') therefore
    pass through unchanged, instead of gaining doubled spaces or ' /'.
    Leading whitespace is stripped from the result.
    """
    return ''.join(' ' + ch if ch.isupper() else ch for ch in s).lstrip()

def make_plot_simple(name, points, labels, publish):
    """Scatter-plot 2-D points as a single black Plotly trace.

    Args:
        name: Base filename; '.html' is appended.
        points: Array indexed as points[:, 0] (x) and points[:, 1] (y).
        labels: Hover text, one entry per point.
        publish: If true, render with py.iplot; otherwise write the figure
            with offline.plot.
    """
    scatter = go.Scattergl(
        x=points[:, 0],
        y=points[:, 1],
        mode='markers',
        marker=dict(color=sns.xkcd_rgb["black"], size=8, opacity=0.6),
        text=labels,
        hoverinfo='text',
    )

    # Both axes are bare: no grid, zero line, axis line, ticks or tick labels.
    axis_style = dict(
        autorange=True,
        showgrid=False,
        zeroline=False,
        showline=False,
        autotick=True,
        ticks='',
        showticklabels=False,
    )
    layout = go.Layout(xaxis=dict(axis_style), yaxis=dict(axis_style))

    fig = go.Figure(data=[scatter], layout=layout)
    plotter = py.iplot if publish else offline.plot
    plotter(fig, filename=name + '.html')

def make_plot(name, points, labels, legend_labels, legend_order, legend_label_to_color, pretty_legend_label, publish):
    """Scatter-plot 2-D points as one coloured Plotly trace per legend label.

    Points are grouped by their legend label. Each group becomes a Scattergl
    trace coloured via ``legend_label_to_color`` and named
    ``pretty_legend_label(label)``.

    Args:
        name: Base filename; '.html' is appended.
        points: Sequence of per-point coordinate rows; x is column 0, y is column 1.
        labels: Hover text for each point (parallel to ``points``).
        legend_labels: Group label for each point (parallel to ``points``).
        legend_order: Labels in the order their traces should appear in the
            legend. Labels present in the data but absent from this list are
            appended afterwards, in sorted order, rather than dropped.
        legend_label_to_color: Mapping label -> marker colour. Every label that
            occurs in ``legend_labels`` needs an entry (KeyError otherwise).
        pretty_legend_label: Callable turning a raw label into display text.
        publish: If true, render with py.iplot; otherwise write the figure
            with offline.plot.
    """
    rows = sorted(zip(points, labels, legend_labels), key=lambda row: row[2])

    traces_by_label = {}
    labels_in_data = []  # groupby over sorted rows -> labels in sorted order
    for legend_label, group in itertools.groupby(rows, key=lambda row: row[2]):
        members = list(group)
        group_points = np.stack([point for point, _, _ in members])
        pretty = pretty_legend_label(legend_label)
        traces_by_label[legend_label] = go.Scattergl(
            x=group_points[:, 0],
            y=group_points[:, 1],
            mode='markers',
            marker=dict(
                color=legend_label_to_color[legend_label],
                size=8,
                opacity=0.6,
            ),
            text=['{} ({})'.format(label, pretty) for _, label, _ in members],
            hoverinfo='text',
            name=pretty,
        )
        labels_in_data.append(legend_label)

    # Listed labels come first, in legend_order and each at most once. Any label
    # present in the data but missing from legend_order goes last instead of
    # silently disappearing from the figure.
    ordered_labels = []
    for lab in legend_order:
        if lab in traces_by_label and lab not in ordered_labels:
            ordered_labels.append(lab)
    ordered_labels += [lab for lab in labels_in_data if lab not in ordered_labels]
    traces_ordered = [traces_by_label[lab] for lab in ordered_labels]

    axis_style = dict(
        autorange=True,
        showgrid=False,
        zeroline=True,
        showline=True,
        autotick=True,
        ticks='',
        showticklabels=False,
    )
    layout = go.Layout(xaxis=dict(axis_style), yaxis=dict(axis_style))

    fig = go.Figure(data=traces_ordered, layout=layout)
    plotter = py.iplot if publish else offline.plot
    plotter(fig, filename=name + '.html')

In [109]:
import csv

# Ingredient vectors exported as CSV rows of (ingredient_id, vector_string).
path = 'data'
fname = os.path.join(path, 'dict.csv')

# The vector string is whitespace-separated. Token 0 is skipped and tokens
# 1..522 are kept (at most 522 values per row).
# NOTE(review): token 0 is presumably an opening bracket or similar artefact of
# how dict.csv was written -- confirm, and update this slice if the vector
# size changes.
VECTOR_TOKENS = slice(1, 523)

index_list = []   # ids from column 0, in file order
vector_list = []  # one float32 vector per id, parallel to index_list

with open(fname) as csvfile:
    for row in csv.reader(csvfile):
        if not row:
            # csv.reader yields [] for blank lines; skip instead of failing to unpack.
            continue
        index, vector_str = row
        tokens = vector_str.split()[VECTOR_TOKENS]
        index_list.append(index)
        # Parse the token strings straight into a float32 array.
        vector_list.append(np.asarray(tokens).astype(np.float32))

X = np.array(vector_list)

In [110]:
"""
RUN TSNE - MAKE N-dimensional array to 2 dimmensional array
"""

tsne = TSNE(n_components=2)
X_tsne = tsne.fit_transform(X)

In [112]:
"""
Plotting
"""

# Hover label and legend category for every vector, read directly from the
# ingredient record with the same id. Looking the category up again by name
# would rescan all ingredients per label, and would return the first
# ingredient with that name if two ids share a name.
labels = [ingredients[ingr_id][0] for ingr_id in index_list]
categories = [ingredients[ingr_id][1] for ingr_id in index_list]

# (category, xkcd colour name) in legend order. The colour map and the legend
# order are both derived from this one list so they cannot drift apart.
# Every category has a distinct colour so the groups stay distinguishable.
_category_palette = [
    ('plant', 'purple'),
    ('flower', 'forest green'),
    ('meat', 'light pink'),
    ('nut/seed/pulse', 'mustard yellow'),
    ('herb', 'orange'),
    ('alcoholic beverage', 'magenta'),
    ('plant derivative', 'brown'),
    ('fruit', 'blue'),
    ('dairy', 'deep blue'),
    ('cereal/crop', 'sky blue'),
    ('vegetable', 'olive'),
    ('animal product', 'red'),
    ('fish/seafood', 'yellow'),
    ('spice', 'black'),
]

category2color = {category: sns.xkcd_rgb[colour] for category, colour in _category_palette}
category_order = [category for category, _ in _category_palette]

In [114]:
# Single-colour scatter of the 2-D t-SNE layout, written as yongqyu_test.html.
make_plot_simple(
    name='yongqyu_test',
    points=X_tsne,
    labels=labels,
    publish=False,
)

In [113]:
# Same layout, one coloured trace per ingredient category, written as
# yongqyu_test2.html.
make_plot(
    name='yongqyu_test2',
    points=X_tsne,
    labels=labels,
    legend_labels=categories,
    legend_order=category_order,
    legend_label_to_color=category2color,
    pretty_legend_label=pretty_category,
    publish=False,
)

In [ ]: