from __future__ import print_function

import socket
import datetime
from collections import OrderedDict
from math import ceil, pi

import numpy as np

try:
    import ipywidgets as widgets
    from IPython.display import HTML, display
    from bokeh.application import Application
    from bokeh.application.handlers import FunctionHandler
    from bokeh.embed import autoload_server
    from bokeh.io import output_notebook, reset_output
    from bokeh.layouts import column, layout, gridplot
    from bokeh.models import ColumnDataSource, CustomJS, DatetimeTickFormatter
    from bokeh.models import widgets as bkhwidgets
    from bokeh.models.glyphs import Rect
    from bokeh.models.mappers import LinearColorMapper
    from bokeh.models.ranges import Range1d
    from bokeh.models.tools import BoxSelectTool, HoverTool, CrosshairTool
    from bokeh.models.tools import ResetTool, PanTool, BoxZoomTool, ResizeTool
    from bokeh.models.tools import WheelZoomTool, SaveTool
    from bokeh.palettes import Plasma256
    from bokeh.plotting import figure, Figure
    from bokeh.resources import INLINE
    from bokeh.server.server import Server
    from six import iteritems
    from tornado.ioloop import IOLoop
except:
    pass

try:
    from fs.client.jupyter.tools import *
except:
    from tools import *
    
# -- globals --------------------
# set to True by DataStreamer.__start_bokeh_server right after output_notebook() has
# redirected Bokeh output to the notebook - ensures the redirection is done only once
# for this module
bokeh_output_redirected = False
        
# ------------------------------------------------------------------------------
class Children(OrderedDict):
    """Ordered, name-keyed collection of children owned by a given object.

    Only instances of `obj_class` are accepted; each child is stored under its
    `name` attribute. Callbacks registered through `register_add_callback` are
    called, in registration order, each time a child is added.
    """

    def __init__(self, owner, obj_class):
        OrderedDict.__init__(self)
        # the object owning this collection (it can't be added to itself)
        self._owner = owner
        # the class children must be instances of
        self._obj_class = obj_class
        # add-callbacks, keyed by registration rank (1, 2, ...)
        self._add_callbacks = OrderedDict()
        # child added while the collection was empty (None until then)
        self._master_child = None

    def register_add_callback(self, cb):
        """registers `cb` to be called with each newly added child (non-callables are ignored)"""
        if cb and hasattr(cb, '__call__'):
            self._add_callbacks[len(self._add_callbacks) + 1] = cb

    def call_add_callbacks(self, child):
        """calls every registered add-callback with `child`

        Best effort: a failing callback doesn't prevent the others from being called.
        """
        for cb in self._add_callbacks.values():
            try:
                cb(child)
            except Exception:
                pass

    def add(self, children):
        """adds a single child or an iterable collection of children (None entries are ignored)

        Raises:
            ValueError: a child is the owner itself or isn't an instance of the expected class
        """
        # check for a single instance first: an obj_class instance that happens to be
        # iterable must be added as a child, not iterated over
        if isinstance(children, self._obj_class) or not hasattr(children, '__iter__'):
            self.__add_child(children)
        else:
            for c in children:
                self.__add_child(c)

    def __add_child(self, child):
        """adds a single child, keyed by its name, then notifies the add-callbacks"""
        if child is None:
            return
        if child is self._owner:
            raise ValueError("invalid argument: can't add 'self' to children")
        if not isinstance(child, self._obj_class):
            err = "invalid argument: expected an iterable collection or a single instance of {} - got {}"
            raise ValueError(err.format(self._obj_class.__name__, child.__class__))
        if not len(self):
            self._master_child = child
        self[child.name] = child
        self.call_add_callbacks(child)


# ------------------------------------------------------------------------------
class DataStreamEvent(object):
    """Data stream event

    Carries an event type (one of DataStreamEvent.Type), the object that emitted
    it and, for errors, the error text and the associated exception.
    """

    # NOTE(review): `enum` comes from the `tools` module (star-imported) - presumably
    # it exposes each name as an attribute (Type.ERROR, ...); confirm there
    Type = enum(
        'ERROR',
        'RECOVER',
        'EOS',
        'UNKNOWN'
    )

    def __init__(self, event_type=Type.UNKNOWN, emitter=None, error=None, exception=None):
        # evt. type
        self.type = event_type
        # evt. emitter (e.g. the DataStreamEventHandler passing itself in emit_error/emit_recover)
        self.emitter = emitter
        # error text
        self.error = error
        # exception
        self.exception = exception


# ------------------------------------------------------------------------------
class DataStreamEventHandler(object):
    """Data stream event handler

    Base class of the objects taking part in the event chain. In this module, the
    DataStream registers on its Channels, the DataStreamer on its DataStreams and
    the DataStreamerController on its DataStreamer. emit() hands an event to every
    handler registered for its type. The receiving handler runs its
    handle_stream_event() and then re-emits the event towards its own registered
    handlers.
    """

    supported_events = [
        DataStreamEvent.Type.ERROR,
        DataStreamEvent.Type.RECOVER,
        DataStreamEvent.Type.EOS
    ]

    def __init__(self, name):
        self._name = name
        # callbacks: event type -> list of registered DataStreamEventHandler
        self._callbacks = dict()
        for event_type in self.supported_events:
            self._callbacks[event_type] = list()

    @property
    def name(self):
        """the handler name"""
        return self._name

    def register_event_handler(self, event_handler, events):
        """registers `event_handler` for each of `events` (a list or tuple) - unsupported types are ignored"""
        assert(isinstance(events, (list, tuple)))
        assert(isinstance(event_handler, DataStreamEventHandler))
        for event in events:
            if event in self.supported_events:
                self._callbacks[event].append(event_handler)

    def emit(self, event):
        """hands `event` to every handler registered for its type

        A failing handler doesn't stop the delivery to the others.
        """
        assert(isinstance(event, DataStreamEvent))
        if event.type in self.supported_events:
            for event_handler in self._callbacks[event.type]:
                try:
                    event_handler.__handle_stream_event(event)
                except Exception as e:
                    print(e) #TODO

    def __handle_stream_event(self, event):
        """runs the (overridable) handle_stream_event then always propagates the event further"""
        try:
            self.handle_stream_event(event)
        except Exception:
            # best effort: a failing handler must not break the propagation chain
            # (the default handle_stream_event implementation raises)
            pass
        finally:
            self.__propagate(event)

    def __propagate(self, event):
        """re-emits `event` towards the handlers registered on self"""
        assert(isinstance(event, DataStreamEvent))
        self.emit(event)

    def emit_error(self, sd):
        """emits an ERROR event built from `sd`

        sd is presumably a ChannelData, or anything exposing .error and .exception.
        """
        evt = DataStreamEvent(DataStreamEvent.Type.ERROR, self, sd.error, sd.exception)
        self.emit(evt)

    def emit_recover(self):
        """emits a RECOVER event"""
        evt = DataStreamEvent(DataStreamEvent.Type.RECOVER, self)
        self.emit(evt)

    def handle_stream_event(self, event):
        """to be overridden - the default implementation raises"""
        raise Exception("DataStreamEventHandler.handle_stream_event: default implementation called")


# ------------------------------------------------------------------------------
class ChannelData(object):
    """channel data

    Holds a data buffer (numpy ndarray), an optional time buffer, a format tag
    (one of ChannelData.Format) and an error state. set_data() installs new
    buffers and clears the error state. set_error() flags the data as failed and
    drops the buffers.
    """

    Format = enum(
        'SCALAR',
        'SPECTRUM',
        'IMAGE',
        'UNKNOWN'
    )

    def __init__(self, name='anonymous'):
        # name
        self._name = name
        # format
        self._format = ChannelData.Format.UNKNOWN
        # data buffer (numpy ndarray) - None after set_error()
        self._buffer = np.zeros((0,0))
        # time buffer (numpy ndarray)
        self._time_buffer = None
        # update failed - data is invalid
        self._has_failed = False
        # has new data (updated since last read)
        self.has_been_updated = False
        # error txt
        self._error = "no error"
        # exception caught
        self._exception = None

    @property
    def name(self):
        return self._name

    @property
    def format(self):
        return self._format

    @property
    def has_failed(self):
        return self._has_failed

    @property
    def error(self):
        return self._error

    @property
    def exception(self):
        return self._exception

    @property
    def is_valid(self):
        """True if no error is pending and a buffer is available"""
        return not self._has_failed and self._buffer is not None

    @property
    def dim_x(self):
        """size of the buffer's last axis - 0 for a 0-d buffer or when no buffer is set"""
        # getattr: the buffer is None after set_error()
        shape = getattr(self._buffer, 'shape', ())
        return shape[-1] if len(shape) >= 1 else 0

    @property
    def dim_y(self):
        """size of the buffer's second-to-last axis - 0 for a buffer with fewer than 2 axes or when no buffer is set"""
        shape = getattr(self._buffer, 'shape', ())
        return shape[-2] if len(shape) >= 2 else 0

    @property
    def buffer(self):
        return self._buffer

    @buffer.setter
    def buffer(self, data_buffer):
        self.set_data(data_buffer)

    @property
    def time_buffer(self):
        return self._time_buffer

    def set_data(self, data_buffer, time_buffer=None, format=None):
        """installs new data, marks it as updated and clears any pending error

        NOTE(review): `format` is stored as given, so omitting it (e.g. through the
        `buffer` setter) resets the format to None - confirm this is intended
        """
        assert (isinstance(data_buffer, np.ndarray))
        self._buffer = data_buffer
        self._time_buffer = time_buffer
        self._format = format
        self.has_been_updated = True
        self.reset_error()

    def reset_error(self):
        """clears the error state (the buffers are left untouched)"""
        self._has_failed = False
        self._error = "no error"
        self._exception = None

    def set_error(self, err, exc):
        """flags the data as failed and drops the buffers - only the first error is recorded until reset"""
        if not self._has_failed:
            self._has_failed = True
            self._error = "unknown error" if not err else err
            self._exception = Exception("unknown error") if not exc else exc
            self.__reset_data()

    def __reset_data(self):
        """drops the buffers and clears the public `has_been_updated` flag"""
        self._buffer = None
        self._time_buffer = None
        self.has_been_updated = False


# ------------------------------------------------------------------------------
class DataSource(object):
    """Base class of the objects a Channel pulls its data from.

    Subclasses are expected to override pull_data() and, when they hold
    resources, cleanup().
    """

    def __init__(self, name):
        self._name = name

    # read-only name (no setter: assigning raises AttributeError)
    name = property(lambda self: self._name, doc="the data source name")

    def pull_data(self):
        """default implementation: returns a fresh, empty ChannelData"""
        empty = ChannelData()
        return empty

    def cleanup(self):
        """default implementation: there is nothing to release"""
        return None

    
# ------------------------------------------------------------------------------
class Channel(CellChild, DataStreamEventHandler):
    """single data stream channel

    Owns a set of DataSource instances keyed by name, plus an optional dict of
    model properties. Subclasses provide the Bokeh model through setup_model(),
    get_model() and update().
    """

    def __init__(self, name, data_sources=None, model_properties=None, notebook_cell=None):
        CellChild.__init__(self, name, notebook_cell)
        DataStreamEventHandler.__init__(self, name)
        # data sources
        self._bad_source_cnt = 0
        self._data_sources = Children(self, DataSource)
        self.add_data_sources(data_sources)
        # model properties
        self._model_props = model_properties

    def handle_stream_event(self, event):
        """dispatches ERROR/RECOVER events to the (no-op) private handlers"""
        assert (isinstance(event, DataStreamEvent))
        if event.type == DataStreamEvent.Type.ERROR:
            self.__on_stream_error(event)
        elif event.type == DataStreamEvent.Type.RECOVER:
            self.__on_stream_recover(event)

    def __on_stream_error(self, event):
        pass

    def __on_stream_recover(self, event):
        pass

    @property
    def data_source(self):
        """returns the first registered data source (None if there's none)"""
        return next(iter(self._data_sources.values()), None)

    @property
    def data_sources(self):
        """returns the dict of data sources (keyed by name)"""
        return self._data_sources

    def set_data_source(self, ds):
        """makes `ds` the channel's unique data source (None is ignored)"""
        if ds is not None:
            assert(isinstance(ds, DataSource))
            self._data_sources.clear()
            self.add_data_source(ds)

    def add_data_source(self, ds):
        """adds the specified data source to the channel (None is ignored)"""
        if ds is not None:
            assert(isinstance(ds, DataSource))
            self._data_sources[ds.name] = ds

    def add_data_sources(self, ds):
        """adds the specified data sources (a list or tuple - a single DataSource is also accepted)"""
        if ds is None:
            return
        if isinstance(ds, DataSource):
            ds = [ds]
        assert(isinstance(ds, (list, tuple)))
        for s in ds:
            self.add_data_source(s)

    def get_data(self):
        """returns a dict containing the data of each data source (keyed by data source name)"""
        return {dsn: dsi.pull_data() for dsn, dsi in iteritems(self._data_sources)}

    def cleanup(self):
        """cleanup default do nothing implementation"""
        pass

    @property
    def model_properties(self):
        """returns the dict of model properties"""
        return self._model_props

    @model_properties.setter
    def model_properties(self, mp):
        """set the dict of model properties"""
        self._model_props = mp

    @staticmethod
    def _merge_properties(mp1, mp2, overwrite=False):
        """merges mp2 into mp1 and returns the result

        NOTE: when both dicts are given, mp1 itself is updated in place and returned.
        When one of them is None, the other one is returned as is (an empty dict if
        both are None). Keys already in mp1 are only replaced when `overwrite` is True.
        """
        if mp1 is None:
            return mp2 if mp2 is not None else dict()
        if mp2 is not None:
            for k, v in iteritems(mp2):
                if overwrite or k not in mp1:
                    mp1[k] = v
        return mp1

    def setup_model(self, **kwargs):
        """asks the channel to setup then return its Bokeh associated model - returns None if no model"""
        return None

    def get_model(self):
        """returns the Bokeh model (figure, layout, ...) associated with the Channel or None if no model"""
        return None

    def update(self):
        """gives the Channel a chance to update itself"""
        pass


# ------------------------------------------------------------------------------
class DataStream(CellChild, DataStreamEventHandler):
    """data stream interface

    A named group of Channels. The stream registers itself as the ERROR/RECOVER
    event handler of each of its channels and forwards update/cleanup requests to
    them.
    """

    def __init__(self, name, channels=None, cell=None):
        CellChild.__init__(self, name, cell)
        DataStreamEventHandler.__init__(self, name)
        # channels (keyed by name) - _on_add_channel is called for each added channel
        self._channels = Children(self, Channel)
        self._channels.register_add_callback(self._on_add_channel)
        self.add(channels)

    def add(self, channels):
        """adds the specified channel(s) - a single Channel or an iterable of Channels"""
        self._channels.add(channels)

    def _on_add_channel(self, channel):
        """called when a new channel is added to the data stream

        Routes the channel output to this stream's parent and subscribes to the
        channel's ERROR/RECOVER events.
        """
        channel.parent = self.parent
        events = [DataStreamEvent.Type.ERROR, DataStreamEvent.Type.RECOVER]
        channel.register_event_handler(self, events)

    def handle_stream_event(self, event):
        """dispatches ERROR/RECOVER events to the (no-op) private handlers"""
        assert (isinstance(event, DataStreamEvent))
        if event.type == DataStreamEvent.Type.ERROR:
            self.__on_stream_error(event)
        elif event.type == DataStreamEvent.Type.RECOVER:
            self.__on_stream_recover(event)

    def __on_stream_error(self, event):
        pass

    def __on_stream_recover(self, event):
        pass

    def get_models(self):
        """returns the list of Bokeh models (figure, layout, ...) of the channels - None for channels without model"""
        return [channel.get_model() for channel in self._channels.values()]

    def setup_models(self):
        """asks each channel to setup its Bokeh model then returns the list of models - None for channels without model"""
        return [channel.setup_model() for channel in self._channels.values()]

    def update(self):
        """gives each Channel a chance to update itself (e.g. to update its ColumnDataSources)

        A failing channel is reported through self.exception and doesn't prevent
        the others from being updated.
        """
        for channel in self._channels.values():
            try:
                channel.update()
            except Exception as e:
                self.exception(e)

    def cleanup(self):
        """asks each Channel to cleanup itself (e.g. release resources) - failures are reported, not raised"""
        for channel in self._channels.values():
            try:
                channel.cleanup()
            except Exception as e:
                self.exception(e)


# ------------------------------------------------------------------------------
class DataStreamer(CellChild, DataStreamEventHandler):
    """a data stream manager embedding a bokeh server

    Owns a list of DataStreams, serves their Bokeh models through an embedded Bokeh
    server displayed in the notebook, and periodically asks each stream to update.
    """

    def __init__(self, name, data_streams, update_period=1., parent_cell=None, ip_addr=None):
        # route output to current cell
        CellChild.__init__(self, name, parent_cell)
        DataStreamEventHandler.__init__(self, name)
        # ip addr on which the server will be started
        self._ip_addr = ip_addr
        # embedded bokeh server
        self._srv = None
        # bokeh server session (reset when the server is stopped)
        self._srv_session = None
        # bokeh document (set when a client connects - see __entry_point)
        self._doc = None
        # ipython html context in which the datastream is displayed
        self._html_display = None
        # callback period - stored in milliseconds (update_period is given in seconds)
        self._update_period = 1000. * update_period
        # True while the periodic update callback is installed on self._doc
        self._periodic_cb_installed = False
        # the data streams
        self._data_streams = list()
        self.add(data_streams)

    def add(self, ds):
        """adds a DataStream or a list/tuple of DataStreams

        Raises:
            ValueError: `ds` (or one of its items) isn't a DataStream
        """
        err = "invalid argument: expected a list, a tuple or a single instance of DataStream"
        if isinstance(ds, DataStream):
            streams = [ds]
        elif isinstance(ds, (list, tuple)):
            streams = ds
        else:
            raise ValueError(err)
        for s in streams:
            if not isinstance(s, DataStream):
                raise ValueError(err)
            s.parent = self.parent
            self.__register_event_handler(s)
            self._data_streams.append(s)

    def __register_event_handler(self, ds):
        """subscribes self to the ERROR/RECOVER events of the specified DataStream"""
        assert(isinstance(ds, DataStream))
        events = [DataStreamEvent.Type.ERROR, DataStreamEvent.Type.RECOVER]
        ds.register_event_handler(self, events)

    def handle_stream_event(self, event):
        """dispatches ERROR/RECOVER events to the (no-op) private handlers"""
        assert (isinstance(event, DataStreamEvent))
        if event.type == DataStreamEvent.Type.ERROR:
            self.__on_stream_error(event)
        elif event.type == DataStreamEvent.Type.RECOVER:
            self.__on_stream_recover(event)

    def __on_stream_error(self, event):
        pass

    def __on_stream_recover(self, event):
        pass

    @tracer
    def start(self):
        """starts attached data streams (starts the server or re-installs the periodic callback)"""
        self.__start_bokeh_server()

    @tracer
    def stop(self):
        """stops attached data streams (uninstalls the periodic callback - the server keeps running)"""
        self.__uninstall_periodic_callbacks()

    @tracer
    def close(self):
        """stops attached data streams then cleans up"""
        self.cleanup()

    @tracer
    def cleanup(self):
        """best effort cleanup: each step is run even if a previous one failed"""
        # TODO: use with_silent_catch
        steps = (
            self.__uninstall_periodic_callbacks,
            self.__cleanup_data_streams,
            self.__clear_models,
            self.__stop_bokeh_server,
        )
        for step in steps:
            try:
                step()
            except Exception:
                pass

    @property
    def update_period(self):
        """returns the update period (in seconds)"""
        return self._update_period / 1000.

    @update_period.setter
    def update_period(self, update_period):
        """set the update period (in seconds)

        The periodic callback is re-installed with the new period only if it was
        installed, so a stopped (frozen) streamer stays stopped.
        """
        self._update_period = 1000. * update_period
        if self._periodic_cb_installed:
            self.__uninstall_periodic_callbacks()
            self.__install_periodic_callbacks()

    @tracer
    def __start_bokeh_server(self, clear_output=False):
        """starts the underlying bokeh server (if not already running)"""
        global bokeh_output_redirected
        if bokeh_output_redirected:
            self.debug("Bokeh output already redirected to Jupyter notebook")
        else:
            self.debug("redirecting Bokeh output to Jupyter notebook...")
            output_notebook(resources=INLINE, hide_banner=True)
            bokeh_output_redirected = True
            self.debug("Bokeh output successfully redirected")
        if self._srv:
            self.debug("Bokeh server already running")
            if clear_output:
                self.__clear_all_for_restart()
            else:
                self.__install_periodic_callbacks()
        else:
            self.debug("starting Bokeh server...")
            self._srv = Server(
                {'/': Application(FunctionHandler(self.__entry_point))},
                io_loop=IOLoop.current(),
                port=0,
                host='*',
                allow_websocket_origin=['*']
            )
            self._srv.start()
            if not self._ip_addr:
                self._ip_addr = socket.gethostbyname(socket.gethostname())
            script = autoload_server(model=None, url='http://{}:{}'.format(self._ip_addr, self._srv.port))
            self._html_display = HTML(script)
            display(self._html_display)
            self.debug("Bokeh server successfully started")

    @tracer
    def __stop_bokeh_server(self):
        """stops the underlying bokeh server then releases the associated resources"""
        if self._srv:
            self.debug("stopping Bokeh server...")
            # destroying the session and stopping the server are independent steps:
            # a missing/failing session must not prevent the server from being stopped
            try:
                session = self.__get_session()
                if session is not None:
                    session.destroy()
            except Exception as e:
                self.error(e)
            try:
                self._srv.stop()
            except Exception as e:
                self.error(e)
            finally:
                self._srv_session = None
                self._srv = None
        self.__uninstall_periodic_callbacks()
        self.__clear_models()
        self._doc = None
        self.debug("Bokeh server stopped & cleanup done")

    def __entry_point(self, doc):
        """the bokeh server entry point (called with the document of a connecting client)"""
        try:
            self._doc = doc
            self.__setup_models()
            self.__periodic_callback()
            self.__install_periodic_callbacks()
        except Exception as e:
            self.error(e)

    @tracer
    def __clear_all_for_restart(self):
        """does the necessary cleanup to restart the data stream"""
        try:
            self.__clear_models()
            self._parent.clear_output()
            display(self._html_display)
        except Exception as e:
            self.error(e)

    def __get_session(self):
        """returns the server's first session for '/' (None if unavailable)"""
        try:
            return self._srv.get_sessions('/')[0]
        except Exception:
            return None

    @tracer
    def __install_periodic_callbacks(self):
        """installs the periodic callbacks - notably the one used to trigger stream updates"""
        if self._doc is None:
            # no client connected yet: __entry_point installs the callback on connection
            self.debug("no Bokeh document yet - periodic callback installation deferred")
            return
        try:
            self._doc.add_periodic_callback(self.__periodic_callback, self._update_period)
            self._periodic_cb_installed = True
        except Exception as e:
            self.error(e)

    @tracer
    def __uninstall_periodic_callbacks(self):
        """uninstalls the periodic callbacks"""
        try:
            if self._doc:
                self._doc.remove_periodic_callback(self.__periodic_callback)
        except ValueError:
            # already removed
            pass
        except Exception as e:
            self.error(e)
        finally:
            self._periodic_cb_installed = False

    def __periodic_callback(self):
        """the periodic callback: asks each data stream to update itself"""
        for ds in self._data_streams:
            try:
                ds.update()
            except Exception:
                # intentionally silent: this runs every period, and DataStream.update
                # already reports per-channel failures
                pass

    @tracer
    def __setup_models(self):
        """adds the data stream models to the bokeh document (channels without model are skipped)"""
        session = self.__get_session()
        for ds in self._data_streams:
            for m in ds.setup_models():
                if m is None:
                    # Channel.setup_model returns None when there is no model
                    continue
                try:
                    self._doc.add_root(m, setter=session)
                except Exception as e:
                    self.exception(e)

    @tracer
    def __clear_models(self):
        """removes the data stream models from the bokeh document then resets the bokeh output"""
        session = self.__get_session()
        for ds in self._data_streams:
            for m in ds.get_models():
                if m is None:
                    continue
                try:
                    self._doc.remove_root(m, setter=session)
                except Exception as e:
                    self.exception(e)
        try:
            reset_output()
        except Exception as e:
            self.exception(e)

    def __cleanup_data_streams(self):
        """asks each data stream to cleanup itself - failures are reported, not raised"""
        for ds in self._data_streams:
            try:
                ds.cleanup()
            except Exception as e:
                self.exception(e)


# ------------------------------------------------------------------------------
class DataStreamerController(CellChild, DataStreamEventHandler):
    """a DataStreamer controller

    Displays an optional refresh-period slider plus Freeze/Unfreeze and Close
    buttons driving the attached DataStreamer, and reflects stream errors in the UI.

    Recognized kwargs: parent_cell, auto_start (default True), up_slider_enabled
    (default True).
    """

    def __init__(self, name, data_streamer, **kwargs):
        # check input parameters
        assert (isinstance(data_streamer, DataStreamer))
        # route output to current cell
        CellChild.__init__(self, name, kwargs.get('parent_cell', None))
        DataStreamEventHandler.__init__(self, name)
        # running state - toggled by __on_freeze_unfreeze_clicked
        self._running = False
        # functions called when the close button is clicked
        self._close_callbacks = list()
        # data streamer
        self.data_streamer = data_streamer
        # start/stop/close button
        self.__setup_controls(**kwargs)
        # auto-start
        if kwargs.get('auto_start', True):
            self.__on_freeze_unfreeze_clicked()

    @staticmethod
    def l01a(width='auto', *args, **kwargs):
        """ipywidgets Layout: flex '0 1 auto'"""
        return widgets.Layout(flex='0 1 auto', width=width, *args, **kwargs)

    @staticmethod
    def l11a(width='auto', *args, **kwargs):
        """ipywidgets Layout: flex '1 1 auto'"""
        return widgets.Layout(flex='1 1 auto', width=width, *args, **kwargs)

    def __setup_update_period_slider(self):
        """returns the refresh period slider (in seconds), initialized from the data streamer"""
        return widgets.FloatSlider(
            value=self.data_streamer.update_period,
            min=0.25,
            max=5.0,
            step=0.25,
            description='Refresh Period (s)',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
        )

    def __setup_controls(self, **kwargs):
        """creates then displays the controls (slider, freeze/unfreeze and close buttons)"""
        self._error_area = None
        self._error_layout = None
        self._up_slider = None
        if kwargs.get('up_slider_enabled', True):
            self._up_slider = self.__setup_update_period_slider()
            self._up_slider.observe(self.__on_refresh_period_changed, names='value')
        bd = "Freeze" if kwargs.get('auto_start', True) else "Unfreeze"
        self._freeze_unfreeze_button = widgets.Button(description=bd, layout=self.l01a(width="100px"))
        self._freeze_unfreeze_button.on_click(self.__on_freeze_unfreeze_clicked)
        self._close_button = widgets.Button(description="Close",layout=self.l01a(width="100px"))
        self._close_button.on_click(self.__on_close_clicked)
        self._switch_buttons_to_valid_state()
        controls_list = list()
        if self._up_slider:
            controls_list.append(self._up_slider)
        controls_list.extend([self._freeze_unfreeze_button, self._close_button])
        self._controls = widgets.HBox(controls_list, layout=self.l01a())
        display(self._controls)

    def __on_refresh_period_changed(self, event):
        """slider observer: forwards the new period (in seconds) to the data streamer"""
        try:
            self.data_streamer.update_period = event['new']
        except Exception as e:
            self.exception(e)

    def __on_freeze_unfreeze_clicked(self, b=None):
        """toggles the data streamer between running and stopped (frozen)"""
        if self._running:
            self._data_streamer.stop()
            self._freeze_unfreeze_button.description = "Unfreeze"
        else:
            self._data_streamer.start()
            self._freeze_unfreeze_button.description = "Freeze"
        self._running = not self._running
        if self._running and self._up_slider is not None:
            self._up_slider.value = self.data_streamer.update_period

    def __on_close_clicked(self, b=None):
        self.close()

    def close(self):
        """closes the data streamer and the controls, clears the cell output then calls the close callbacks"""
        try:
            self._data_streamer.close()
        except Exception as e:
            self.exception(e)
        self._controls.close()
        if self._error_area:
            self._error_area.close()
        self._parent.clear_output()
        self.__call_close_callbacks()

    def register_close_callback(self, cb):
        """registers a callable (no argument) to be called when the controller is closed"""
        assert(hasattr(cb, '__call__'))
        self._close_callbacks.append(cb)

    def __call_close_callbacks(self):
        """calls each close callback - best effort: a failing one doesn't stop the others"""
        for cb in self._close_callbacks:
            try:
                cb()
            except Exception:
                pass

    def handle_stream_event(self, event):
        """dispatches ERROR/RECOVER/EOS events"""
        assert(isinstance(event, DataStreamEvent))
        if event.type == DataStreamEvent.Type.ERROR:
            self.__on_stream_error(event)
        elif event.type == DataStreamEvent.Type.RECOVER:
            self.__on_stream_recover(event)
        elif event.type == DataStreamEvent.Type.EOS:
            self.__on_end_of_stream(event)

    def __on_stream_error(self, event):
        self._switch_buttons_to_invalid_state()
        self._show_error(event.error)

    def __on_stream_recover(self, event):
        self._switch_buttons_to_valid_state()
        self._hide_error()

    def __on_end_of_stream(self, event):
        """end of stream: freezes the data streamer - never unfreezes an already frozen one"""
        if self._running:
            self.__on_freeze_unfreeze_clicked()

    def _switch_buttons_to_valid_state(self):
        self._close_button.style.button_color = '#00FF00'
        self._freeze_unfreeze_button.style.button_color = '#00FF00'

    def _switch_buttons_to_invalid_state(self):
        self._close_button.style.button_color = '#FF0000'
        self._freeze_unfreeze_button.style.button_color = '#FF0000'

    def _show_error(self, err_desc):
        """shows (or updates) the error text area"""
        err = "Oops, the following error occurred:\n"
        # err_desc may be None (DataStreamEvent's default) or a non-string
        err += "unknown error" if err_desc is None else str(err_desc)
        if not self._error_area:
            self._error_area = widgets.Textarea(value=err, layout=self.l11a())
            self._error_area.rows = 3
            self.display(self._error_area)
        else:
            self._error_area.value = err

    def _hide_error(self):
        """closes the error text area (if any)"""
        if self._error_area is not None:
            try:
                self._error_area.close()
            except Exception:
                pass
        self._error_area = None

    @property
    def data_streamer(self):
        return self._data_streamer

    @data_streamer.setter
    def data_streamer(self, ds):
        """attaches a DataStreamer: routes its output to this cell and subscribes to its events"""
        # check input parameter
        assert (isinstance(ds, DataStreamer))
        # data streamer
        self._data_streamer = ds
        # route data streamer output to current cell
        self._data_streamer.parent = self.parent
        # register event handler
        events = [DataStreamEvent.Type.ERROR, DataStreamEvent.Type.RECOVER, DataStreamEvent.Type.EOS]
        self._data_streamer.register_event_handler(self, events)


# ------------------------------------------------------------------------------
class BoxSelectionManager(object):
    """BoxSelectTool manager

    Draws the current box selection on registered figures and forwards
    selection/reset events from the browser to python callbacks. Each instance is
    stored in the class-level `repository` under a uuid4-based id. The JS
    callbacks find the instance through that id (held in the selection
    ColumnDataSource tags) and call on_selection_change / on_selection_reset via
    the notebook kernel.
    """

    # uid -> instance (strong references: instances stay alive while registered)
    repository = dict()

    def __init__(self, selection_callback=None, reset_callback=None):
        # local import: uuid4 is not among this module's explicit imports
        from uuid import uuid4
        self._uid = uuid4().int
        BoxSelectionManager.repository[self._uid] = self
        # called with the selection range dict (see __selection_range)
        self._selection_callback = selection_callback
        # called without argument on reset
        self._reset_callback = reset_callback
        self._selection_cds = self.__setup_selection_data_source()

    def __del__(self):
        # NOTE: the repository holds a strong reference, so this only runs once the
        # entry has been removed; pop() tolerates a missing entry or uid (e.g.
        # __init__ failed early) and type(self) survives module teardown
        type(self).repository.pop(getattr(self, '_uid', None), None)

    def __setup_selection_data_source(self):
        """returns the single-row ColumnDataSource holding the selection rectangle, tagged with the uid"""
        cds = ColumnDataSource(data=dict(x0=[0], y0=[0], width=[0], height=[0]))
        cds.tags = [str(self._uid)]
        return cds

    @property
    def selection_callback(self):
        return self._selection_callback

    @selection_callback.setter
    def selection_callback(self, scb):
        self._selection_callback = scb

    @property
    def reset_callback(self):
        return self._reset_callback

    @reset_callback.setter
    def reset_callback(self, rcb):
        self._reset_callback = rcb

    def __selection_glyph(self):
        """returns the Rect glyph used to draw the selection (x0/y0 are the rect center)"""
        kwargs = dict()
        kwargs['x'] = 'x0'
        kwargs['y'] = 'y0'
        kwargs['width'] = 'width'
        kwargs['height'] = 'height'
        kwargs['fill_alpha'] = 0.1
        kwargs['fill_color'] = '#009933'
        kwargs['line_color'] = 'white'
        kwargs['line_dash'] = 'dotdash'
        kwargs['line_width'] = 2
        return Rect(**kwargs)

    def register_figure(self, fig):
        """hooks the figure's BoxSelectTool and ResetTool then adds the selection glyph

        Silently returns if the figure has no BoxSelectTool or no ResetTool, in which
        case the selection glyph isn't added.
        """
        try:
            bst = fig.select(BoxSelectTool)[0]
            bst.callback = self.__box_selection_callback()
        except Exception:
            return
        try:
            rst = fig.select(ResetTool)[0]
            rst.js_on_change('do', self.__reset_callback())
        except Exception:
            return
        rect = self.__selection_glyph()
        fig.add_glyph(self._selection_cds, glyph=rect, selection_glyph=rect, nonselection_glyph=rect)

    # NOTE(review): both JS callbacks import BoxSelectionManager from a module named
    # 'DataStreamingModel' - confirm it matches this module's actual import path
    def __box_selection_callback(self):
        """JS callback: stores the box geometry (center/size) then calls on_selection_change in the kernel"""
        return CustomJS(args=dict(cds=self._selection_cds), code="""
            var data = cds.data
            var geometry = cb_data['geometry']
            var width = geometry['x1'] - geometry['x0']
            var height = geometry['y1'] - geometry['y0']
            var x0 = geometry['x0'] + width / 2
            var y0 = geometry['y0'] + height / 2
            cds.data['x0'][0] = x0
            cds.data['y0'][0] = y0
            cds.data['width'][0] = width
            cds.data['height'][0] = height
            cds.trigger('change')
            var imp = "from DataStreamingModel import BoxSelectionManager;"
            var pfx = "BoxSelectionManager.repository[".concat(cds.tags[0], "].on_selection_change(")
            var arg = JSON.stringify({'x0':[x0], 'y0':[y0], 'width':[width], 'height':[height]})
            var sfx = ")"
            var cmd  = imp.concat(pfx, arg, sfx)
            console.log(cmd)
            var kernel = IPython.notebook.kernel
            kernel.execute(cmd)
        """)

    def __reset_callback(self):
        """JS callback: zeroes the selection then calls on_selection_reset in the kernel"""
        return CustomJS(args=dict(cds=self._selection_cds), code="""
            cds.data['x0'][0] = 0
            cds.data['y0'][0] = 0
            cds.data['width'][0] = 0
            cds.data['height'][0] = 0
            cds.trigger('change')
            var imp = "from DataStreamingModel import BoxSelectionManager;"
            var rst = "BoxSelectionManager.repository[".concat(cds.tags[0],"].on_selection_reset()")
            var cmd  = imp.concat(rst)
            console.log(cmd)
            var kernel = IPython.notebook.kernel
            kernel.execute(cmd)
        """)

    def __selection_range(self):
        """returns the selection bounds (x0, x1, y0, y1) and size (width, height) from its center/size"""
        w = self._selection_cds.data['width'][0]
        h = self._selection_cds.data['height'][0]
        x0 = self._selection_cds.data['x0'][0] - w / 2.
        x1 = self._selection_cds.data['x0'][0] + w / 2.
        y0 = self._selection_cds.data['y0'][0] - h / 2.
        y1 = self._selection_cds.data['y0'][0] + h / 2.
        return {'x0': x0, 'x1': x1, 'y0': y0, 'y1': y1, 'width': w, 'height': h}

    def on_selection_change(self, selection):
        """kernel-side handler of the JS selection callback

        Stores the selection, then calls the selection callback with the selection
        range. Callback errors are printed, not raised.
        """
        self._selection_cds.data.update(selection)
        try:
            if self._selection_callback:
                self._selection_callback(self.__selection_range())
        except Exception as e:
            print(e)

    def on_selection_reset(self):
        """kernel-side handler of the JS reset callback

        Clears the python-side selection, then calls the reset callback. Callback
        errors are printed, not raised.
        """
        # keep the python-side copy in sync with the JS side, which zeroed its own copy
        self._selection_cds.data.update(dict(x0=[0], y0=[0], width=[0], height=[0]))
        try:
            if self._reset_callback:
                self._reset_callback()
        except Exception as e:
            print(e)

            
# ------------------------------------------------------------------------------
class Scale(object):
    """an axis scale: an optional [start, end] range sampled on num_points points, plus display metadata

    start, end, num_points, step, channel and array are read-only once constructed;
    label and unit can be changed.
    """
    
    def __init__(self, **kwargs):
        self._start = kwargs.get('start', None)
        self._end = kwargs.get('end', None)
        self._num_points = kwargs.get('num_points', None)
        self._label = kwargs.get('label', None)
        self._unit = kwargs.get('unit', None)
        self._channel = kwargs.get('channel', None)
        # evenly spaced points over [start, end] - empty array and 0. step if the scale is incomplete
        self._array, self._step = self.__compute_linear_space()
        
    @property
    def start(self):
        return self._start
    
    @start.setter
    def start(self, s):
        raise Exception("Scale.start is immutable - can't change its value")
        
    @property
    def end(self):
        return self._end
    
    @end.setter
    def end(self, e):
        raise Exception("Scale.end is immutable - can't change its value")
        
    @property
    def num_points(self):
        return self._num_points
    
    @num_points.setter
    def num_points(self, n):
        raise Exception("Scale.num_points is immutable - can't change its value")

    @property
    def step(self):
        return self._step

    @step.setter
    def step(self, s):
        raise Exception("Scale.step is immutable - can't change its value")
       
    @property
    def label(self):
        return self._label

    @label.setter
    def label(self, label):
        self._label = label
        
    @property
    def unit(self):
        return self._unit

    @unit.setter
    def unit(self, unit):
        self._unit = unit
        
    @property
    def channel(self):
        return self._channel

    @channel.setter
    def channel(self, c):
        raise Exception("Scale.channel is immutable - can't change its value")
      
    @property
    def array(self):
        return self._array

    @array.setter
    def array(self, a):
        raise Exception("Scale.array is immutable - can't change its value")

    @property
    def shared_range(self):
        """a plain Scale has no shared range (see SharedScale)"""
        return None
    
    @shared_range.setter
    def shared_range(self, r):
        raise Exception("Scale.shared_range is immutable - can't change its value")
        
    def validate(self):
        """raises ValueError if the range is empty or num_points is < 1 (checked only when start/end are set)"""
        self.__validate_range()
        self.__validate_num_points()
        
    def __validate_range(self):
        if self._start is not None and self._end is not None and self._start == self._end:
            raise ValueError("invalid axis scale: the specified 'range' is empty")
    
    def __validate_num_points(self):
        if self._start is not None and self._end is not None and self._num_points is not None and self._num_points < 1:
            raise ValueError("invalid axis scale: the specified 'num_points' is invalid")
            
    def __compute_linear_space(self):
        """returns (array, step) from np.linspace over [start, end] with num_points points (endpoint included)

        returns (empty array, 0.) when start/end/num_points are missing or unusable
        """
        try:
            return np.linspace(float(self._start),
                               float(self._end),
                               int(self._num_points),
                               endpoint=True,
                               retstep=True)
        except (TypeError, ValueError, OverflowError):
            # None or non-numeric bounds/num_points, negative num_points, infinite num_points
            return np.zeros((0,)), 0.
    
    def has_valid_range(self):
        """a plain Scale never owns a (shared) range"""
        return False
        
    def has_valid_scale(self):
        """True when start != end are both set and num_points >= 1"""
        valid_range = self._start is not None and self._end is not None and self._start != self._end
        return valid_range and self._num_points is not None and self._num_points >= 1
    
    def axis_label(self):
        """returns 'label [unit]', 'label' or 'unit' depending on which are set, or None if neither is"""
        # None and '' are both treated as "not set" - len(None) used to raise here
        label = self._label or ''
        unit = self._unit or ''
        if label and unit:
            # concatenation (rather than str.format) keeps python 2 unicode labels working
            return label + ' [' + unit + ']'
        return label or unit or None
        
        
# ------------------------------------------------------------------------------
class SharedScale(Scale):
    """a scale whose start and end are taken from a Bokeh Range1d passed as 'shared_range'"""
    
    def __init__(self, **kwargs):
        rng = kwargs.get('shared_range', None)
        if rng is None or not isinstance(rng, Range1d):
            raise ValueError("invalid share scale: no valid 'shared_range' specified")
        # start/end always come from the shared range, overriding any explicit value
        kwargs.update(start=rng.start, end=rng.end)
        super(SharedScale, self).__init__(**kwargs)
        self._range = rng
     
    @property
    def shared_range(self):
        """the Range1d this scale was built from (read-only)"""
        return self._range
    
    @shared_range.setter
    def shared_range(self, r):
        raise Exception("Scale.shared_range is immutable - can't change its value")

    def has_valid_range(self):
        """True when a shared range is attached"""
        return not (self._range is None)
         
        
# ------------------------------------------------------------------------------
class ModelHelper(object):
    """static helpers choosing glyph colors and marker styles from a glyph index"""

    # palette cycled through by line_color - keys must stay the contiguous range 0..len-1
    line_colors = {
        0: 'darkblue',
        1: 'crimson',
        2: 'darkgreen',
        3: 'black',
        4: 'darkorchid',
        5: 'darkorange',
        6: 'deepskyblue',
        7: 'slategrey',
        8: 'gold',
        9: 'magenta'
    }

    # names of the Figure marker methods cycled through by plot_style
    marker_styles = ('circle', 'square', 'diamond')

    @staticmethod
    def line_color(index):
        """returns the color for index, cycling through line_colors (negative indices wrap too)"""
        return ModelHelper.line_colors[index % len(ModelHelper.line_colors)]

    @staticmethod
    def plot_style(instance, index):
        """returns the bound marker method (circle, square, diamond, ...) of instance for index, cyclically

        instance: the Bokeh Figure the marker method belongs to
        """
        assert (isinstance(instance, Figure))
        styles = ModelHelper.marker_styles
        return getattr(instance, styles[index % len(styles)])


# ------------------------------------------------------------------------------
class ScalarChannel(Channel):
    """displays the channel's scalar data sources as time series in a single Bokeh figure

    this is not supposed to be instanciated directly
    """

    def __init__(self, name, data_sources=None, model_properties=None):
        Channel.__init__(self, name, data_sources=data_sources, model_properties=model_properties)
        self.__reinitialize()

    def __reinitialize(self):
        """resets the Bokeh related state of the channel"""
        self._cds = None  # column data source
        self._mdl = None  # model
        self._ngl = 0  # num of glyphs in figure

    def get_model(self):
        """returns the Bokeh model (figure, layout, ...) associated with the Channel or None if no model"""
        return self._mdl

    @staticmethod
    def _tail(array, n):
        """returns the last n entries of array (all of them if n exceeds its length)

        unlike array[-n:] - which returns the WHOLE array for n == 0 - this returns an empty slice for n == 0
        """
        return array[max(len(array) - n, 0):]

    def __instanciate_data_source(self):
        """returns an empty ColumnDataSource: one '_@time@_' column plus one column per data source"""
        columns = OrderedDict()
        # add an entry for timestamp
        columns['_@time@_'] = np.zeros(0)
        # add an entry for each data source
        for cn in self.data_sources:
            columns[cn] = np.zeros(0)
        return ColumnDataSource(data=columns)

    def __setup_toolbar(self, fig):
        """adds the interactive tools to fig, with no tool active by default"""
        htt = [
            ("index", "$index"),
            ("(x,y)", "($x, $y)")
        ]
        fig.add_tools(PanTool())
        fig.add_tools(BoxZoomTool())
        fig.add_tools(WheelZoomTool())
        fig.add_tools(ResizeTool())
        fig.add_tools(ResetTool())
        fig.add_tools(SaveTool())
        fig.add_tools(HoverTool(tooltips=htt))
        fig.add_tools(CrosshairTool())
        fig.toolbar.logo = None
        fig.toolbar.active_drag = None
        fig.toolbar.active_scroll = None
        fig.toolbar.active_tap = None

    def __setup_figure(self, **kwargs):
        """returns a new datetime-x-axis figure sized from kwargs ('width', 'height')

        kwargs must contain 'show_title'; the title is never shown when kwargs['layout'] is 'tabs'
        """
        fkwargs = dict()
        fkwargs['webgl'] = True
        fkwargs['plot_width'] = kwargs.get('width', 950)
        fkwargs['plot_height'] = kwargs.get('height', 250)
        fkwargs['toolbar_location'] = 'above'
        fkwargs['tools'] = ''
        fkwargs['x_axis_type'] = 'datetime'
        f = figure(**fkwargs)
        dtf = DatetimeTickFormatter()
        dtf.milliseconds = "%H:%M:%S:%3N"
        dtf.seconds = "%H:%M:%S:%3N"
        dtf.minutes = "%H:%M:%S:%3N"
        dtf.hours = "%H:%M:%S:%3N"
        f.xaxis.formatter = dtf
        f.xaxis.major_label_orientation = pi / 4
        layout_kind = kwargs.get('layout', 'column')
        if kwargs['show_title'] and layout_kind != 'tabs':
            f.title.text = self.name
        return f

    def __setup_glyph(self, fig, data_source, show_legend=True):
        """adds a line plus a circle glyph plotting column data_source against time to fig"""
        kwargs = dict()
        kwargs['x'] = '_@time@_'
        kwargs['y'] = data_source
        kwargs['source'] = self._cds
        kwargs['line_color'] = ModelHelper.line_color(self._ngl)
        fig.line(**kwargs)
        # the circles reuse the line kwargs with a size, a legend and the next palette color
        kwargs['size'] = 3
        kwargs['line_color'] = ModelHelper.line_color(self._ngl + 1)
        kwargs['legend'] = data_source + ' ' if show_legend else None
        fig.circle(**kwargs)
        self._ngl += 1

    def setup_model(self, **kwargs):
        """asks the channel to setup then return its Bokeh associated model - returns None if no model"""
        self._mdl = None
        # a new figure is built: restart the glyph (hence color) numbering
        self._ngl = 0
        props = self._merge_properties(self.model_properties, kwargs)
        # instanciate the ColumnDataSource
        self._cds = self.__instanciate_data_source()
        # by default a single source is named by the figure title, several sources by a legend
        single_source = len(self.data_sources) == 1
        props['show_title'] = props.get('show_title', single_source)
        f = self.__setup_figure(**props)
        # setup glyphs
        show_legend = props.get('show_legend', not single_source)
        for data_source in self.data_sources:
            self.__setup_glyph(f, data_source, show_legend)
        # setup the toolbar
        self.__setup_toolbar(f)
        # store figure
        self._mdl = f
        return self._mdl

    def update(self):
        """gives each Channel a chance to update itself (e.g. to update the ColumDataSources)

        pulls data from every source and pushes the last min_len samples of each into the
        ColumnDataSource, min_len being the shortest buffer length (0 as soon as a source failed
        or has no buffer) so that all columns keep the same length. Errors are printed.
        """
        try:
            # get data from each source
            min_len = 2 ** 32 - 1
            data = dict()
            # NOTE(review): _bad_source_cnt is presumably initialized by Channel - confirm
            previous_bad_source_cnt = self._bad_source_cnt
            self._bad_source_cnt = 0
            for sn, si in iteritems(self.data_sources):
                data[sn] = sd = si.pull_data()
                if sd.has_failed or sd.buffer is None:
                    min_len = 0
                    self._bad_source_cnt += 1
                    self.emit_error(sd)
                else:
                    min_len = min(min_len, sd.buffer.shape[0])
            if not self._bad_source_cnt and previous_bad_source_cnt:
                self.emit_recover()
            updated_data = dict()
            time_scale_set = False
            for cn in self.data_sources:
                try:
                    if not time_scale_set:
                        updated_data['_@time@_'] = self._tail(data[cn].time_buffer, min_len)
                        time_scale_set = True
                    updated_data[cn] = self._tail(data[cn].buffer, min_len)
                except Exception:
                    # keep a valid time column already taken from a previous source
                    if not time_scale_set:
                        updated_data['_@time@_'] = np.zeros((min_len,), dtype=datetime.datetime)
                    # builtin float: the np.float alias is gone from recent numpy releases
                    updated_data[cn] = np.zeros((min_len,), dtype=float)
            # nothing to push into before setup_model has been called
            if self._cds is not None:
                self._cds.data.update(updated_data)
        except Exception as e:
            print(e)

    def cleanup(self):
        """drops the Bokeh related state then delegates to Channel.cleanup"""
        self.__reinitialize()
        super(ScalarChannel, self).cleanup()
