In [71]:
import itertools
import random
import sys
import time

import ipywidgets as widgets
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display

import verdict

# Configure verdict's logging with the 'error' level before opening the connection.
verdict.set_loglevel('error')
# Module-level connection object; sql() and sql_stream() below issue their queries through it.
# NOTE(review): the 'presto' argument is presumably the Presto host name — confirm against the verdict API.
v = verdict.presto('presto')

def hsv_to_rgb(h, s, v):
    """Convert a hue / saturation / lightness colour to an 8-bit RGB triple.

    Despite the ``hsv`` in the name, the arithmetic is the HSL conversion:
    chroma is ``(1 - |2v - 1|) * s`` and the offset is ``v - chroma / 2``,
    so ``v`` plays the role of *lightness* rather than HSV "value".

    Args:
        h: Hue in degrees. Any real number is accepted; it is wrapped into
            [0, 360), so 360 behaves like 0 and -120 like 240.
        s: Saturation in [0, 1].
        v: Lightness in [0, 1].

    Returns:
        Tuple ``(r, g, b)`` of ints. Each channel is truncated (not rounded)
        after scaling by 255.
    """
    # Wrap the hue so values outside [0, 360) no longer select a missing sector.
    h = h % 360.0
    c = (1 - abs(2*v - 1)) * s
    x = c * (1 - abs((h/60) % 2 - 1))
    m = v - c/2
    # "% 6" guards the case where float rounding in the wrap above yields exactly 360.0.
    sector = int(h / 60) % 6
    r, g, b = (
        (c, x, 0),
        (x, c, 0),
        (0, c, x),
        (0, x, c),
        (x, 0, c),
        (c, 0, x),
    )[sector]
    return (int((r+m)*255), int((g+m)*255), int((b+m)*255))

def rbg_to_str(rgb):
    """Format a three-element (r, g, b) sequence as a CSS-style ``rgb(r, g, b)`` string."""
    # Unpacking enforces that exactly three channels were supplied.
    (r, g, b) = rgb
    css_color = f"rgb({r}, {g}, {b})"
    return css_color

def gen_colors(num):
    """Build ``num`` CSS ``rgb(...)`` strings forming a hue/lightness gradient.

    The hue sweeps from 200 degrees upward, wrapping past 360 and ending at
    39 degrees, while lightness rises from 0.3 to 0.55. Saturation is 0.3 at
    hue 260 and grows linearly with the distance from 260, capped at 1.0.

    Args:
        num: Number of colours to generate. ``0`` yields an empty list and
            ``1`` yields just the first colour of the gradient.

    Returns:
        List of ``num`` strings produced by ``rbg_to_str``.
    """
    h_begin = 200.0
    h_end = 39.0 + 360.0  # written past 360 so the sweep crosses 0 degrees
    l_begin = 0.3
    l_end = 0.55
    # A single colour has no intervals; without this floor the division below
    # would be by zero when num == 1.
    steps = float(max(num - 1, 1))
    colors = []
    for i in range(num):
        h = (h_begin + (h_end - h_begin)*i/steps) % 360
        l = l_begin + (l_end - l_begin)*i/steps
        s = min(abs(h-260)/60.0*0.7+0.3, 1.0)
        colors.append(rbg_to_str(hsv_to_rgb(h, s, l)))
    return colors

def new_figure():
    """Create a FigureWidget holding one empty bar trace, using the 'none' layout template."""
    empty_bar = go.Bar()
    figure = go.FigureWidget(data=empty_bar)
    figure.update_layout(template='none')
    return figure

def updateFigure(fig, df):
    """Refresh the first trace of ``fig`` from ``df``.

    A one-column frame is drawn as a single bar labelled 'value'; otherwise
    the frame's index supplies the x values and its last column the y values.
    Bars are coloured with one ``gen_colors`` entry per row.
    """
    bar = fig.data[0]
    columns = df.columns
    if len(columns) == 1:
        bar.x = ['value']
    else:
        bar.x = df.index
    # With exactly one column, the last column is also the first one.
    bar.y = df[columns[-1]]
    bar.marker.color = gen_colors(len(df.index))

def sql(query):
    """Run ``query`` through the shared verdict connection and show the result.

    An empty figure is displayed first. A single-column result is displayed
    as-is; any other result is drawn into the figure via ``updateFigure``.
    Finally the elapsed time, covering both the query and the rendering
    step, is printed.

    Args:
        query: SQL text passed unchanged to ``v.sql``.
    """
    fig = new_figure()
    display(fig)
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    result = v.sql(query)
    if len(result.columns) == 1:
        display(result)
    else:
        updateFigure(fig, result)
    print(f"elapsed time: {time.perf_counter() - start} seconds.")
    
def presto_sql(query):
    """Run ``query`` via ``sql`` with the literal prefix "bypass " prepended.

    NOTE(review): the prefix presumably makes verdict hand the query straight
    to Presto without rewriting it — confirm against the verdict documentation.
    """
    bypass_query = "bypass " + query
    sql(bypass_query)

def sql_stream(query, max_updates=5):
    """Run ``query`` in streaming mode and render its first few results.

    An empty figure is displayed up front. Each streamed result either is
    displayed directly (single column) or redraws the figure in place,
    followed by a short pause.

    Args:
        query: SQL text passed unchanged to ``v.sql_stream``.
        max_updates: Maximum number of streamed results to consume
            (default 5, matching the previous hard-coded limit). ``None``
            consumes the stream until it is exhausted.
    """
    fig = new_figure()
    display(fig)
    result_itr = v.sql_stream(query)
    # islice stops after max_updates items without asking the stream for one
    # more result, which the former enumerate/break loop fetched and discarded.
    for result in itertools.islice(result_itr, max_updates):
        if len(result.columns) == 1:
            display(result)
        else:
            updateFigure(fig, result)
            time.sleep(0.1)

In [72]:
# Count lineitem rows per (ship_year, ship_month) in hive.tpch_sf100.lineitem,
# streaming the result through sql_stream defined in the cell above.
sql_stream("""\
SELECT ship_year, ship_month, count(*)
FROM (
    SELECT year(l_shipdate) AS ship_year, month(l_shipdate) AS ship_month
    FROM hive.tpch_sf100.lineitem
)
GROUP BY ship_year, ship_month
ORDER BY ship_year, ship_month
""")



In [ ]:


In [ ]: