import zmq
import time
import random
import itertools
from functools import partial
from multiprocessing import Process
from bokeh.plotting import figure
from bokeh.layouts import gridplot
from bokeh.palettes import Category10
from bokeh.server.server import Server
from bokeh.application import Application
from bokeh.models.ranges import DataRange1d
from bokeh.models import ColumnDataSource
from bokeh.application.handlers.function import FunctionHandler


def _utc_to_local_ms(timestamps):
    """Convert UTC epoch seconds to local wall-clock epoch milliseconds.

    Each timestamp is shifted by the UTC offset that the local timezone has
    at that very instant (``tm_gmtoff``, seconds east of UTC), so the result
    is correct both inside and outside daylight-saving time.  The previous
    code subtracted ``time.altzone`` unconditionally, which is the DST offset
    and is therefore wrong whenever DST is not in effect.
    """
    return [(t + time.localtime(t).tm_gmtoff) * 1000 for t in timestamps]


def make_document(context, doc):
    """Fill a Bokeh document with live plots of per-host metrics.

    A ZeroMQ SUB socket is created from ``context``, connected to
    ``tcp://localhost:5559`` and subscribed to every message.  Four datetime
    figures (one per metric column) are laid out in a 2x2 grid and share a
    single x range that follows the most recent five minutes of data.  A
    periodic callback (every 1000 ms) drains all pending messages and streams
    them into one ``ColumnDataSource`` per hostname; the first message from a
    new hostname adds one line per figure, all in the same colour.

    Each message is expected to be a JSON two-element sequence
    ``[hostname, metrics]`` where ``metrics`` maps ``'time'`` and each metric
    column name to a list of values (``'time'`` in UTC epoch seconds).

    Args:
        context: ZeroMQ context used to create the subscriber socket.
        doc: Bokeh document to which the layout and callback are added.
    """
    print('make document')
    socket = context.socket(zmq.SUB)
    socket.connect('tcp://localhost:5559')
    socket.setsockopt(zmq.SUBSCRIBE, b'')
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    # NOTE(review): the socket is never closed by this function; if documents
    # are created repeatedly this may accumulate sockets -- confirm.

    sources = {}  # hostname -> ColumnDataSource
    color_cycle = itertools.cycle(Category10[10])

    # follow_interval is in axis units, i.e. milliseconds: 5 minutes.
    x_range = DataRange1d(follow='end', follow_interval=5*60*1000, range_padding=0)
    columns = ['event_rate', 'data_rate', 'buffer_queue', 'output_queue']
    figures = [
        figure(x_axis_type='datetime', x_range=x_range,
               plot_width=600, plot_height=350)
        for _ in columns
    ]

    layout = gridplot([[figures[0], figures[1]], [figures[2], figures[3]]])
    doc.add_root(layout)

    def update():
        """Drain every queued message without blocking and stream it."""
        # poll(timeout=0) returns an empty list once nothing is pending.
        while poller.poll(timeout=0):
            hostname, metrics = socket.recv_json()
            if hostname not in sources:
                source = ColumnDataSource(data={'time': [], 'event_rate': [], 'data_rate': [],
                                                'buffer_queue': [], 'output_queue': []})
                color = next(color_cycle)
                # One line per figure, each plotting its own metric column.
                for fig, column in zip(figures, columns):
                    fig.line(x='time', y=column, source=source,
                             line_width=2, color=color)
                sources[hostname] = source
                print('new host')
            # shift timestamp from UTC to current timezone and convert to milliseconds
            metrics['time'] = _utc_to_local_ms(metrics['time'])
            # NOTE(review): no rollover is passed, so each source keeps every
            # sample it has ever received.
            sources[hostname].stream(metrics)
            print(time.ctime(), 'streamed data')

    doc.add_periodic_callback(update, 1000)


def sender(context, hostname):
    """Publish fake random metrics for ``hostname`` once per second, forever.

    A PUB socket is bound to ``tcp://*:5559`` and a JSON message
    ``[hostname, metrics]`` is sent every second, where ``metrics`` holds a
    single-element list for ``'time'`` (``time.time()``) and for each of the
    four metric columns (``random.random()``).  This function never returns.

    Args:
        context: Ignored; a new ``zmq.Context`` is created inside this
            function instead (presumably because it runs in its own process).
        hostname: Name sent as the first element of every message.
    """
    metric_names = ('event_rate', 'data_rate', 'buffer_queue', 'output_queue')

    own_context = zmq.Context()
    publisher = own_context.socket(zmq.PUB)
    publisher.bind('tcp://*:5559')

    while True:
        sample = {'time': [time.time()]}
        for name in metric_names:
            sample[name] = [random.random()]
        publisher.send_json([hostname, sample])
        time.sleep(1)


def main():
    """Start the fake metrics publisher and serve the plots on port 5000.

    Launches ``sender`` for host ``'host1'`` in a child process, then starts
    a Bokeh server with ``make_document`` mounted at ``/``, opens that page
    once the IO loop is running, and blocks in the IO loop.
    """
    context = zmq.Context()
    # NOTE(review): sender creates its own context and ignores this argument;
    # under the "spawn" start method the arguments must be picklable --
    # confirm a zmq.Context is, or pass None instead.
    Process(target=sender, args=(context, 'host1')).start()

    apps = {'/': Application(FunctionHandler(partial(make_document, context)))}

    server = Server(apps, port=5000)
    server.start()
    server.io_loop.add_callback(server.show, '/')
    server.io_loop.start()


# Guard the entry point: with multiprocessing's "spawn" start method the child
# process re-imports this module, and unguarded launch code would run again.
if __name__ == '__main__':
    main()
