"""
BinaryTree inter-engine communication class

Used from bintree_script.py.

Provides parallel reduce and allreduce functionality over a binary tree of engines.

"""
from __future__ import print_function

import cPickle as pickle
import functools
import re
import socket
import uuid

import zmq

from IPython.parallel.util import disambiguate_url


#----------------------------------------------------------------------------
# bintree-related construction/printing helpers
#----------------------------------------------------------------------------

def bintree(ids, parent=None):
    """construct {child:parent} dict representation of a binary tree

    keys are the nodes in the tree, and values are the parent of each node.

    The root node is ``ids[0]`` and has parent `parent`, default: None.
    The remaining ids are split in half: the first half (rounded down)
    forms the left subtree and the second half the right subtree, each
    rooted at its own first element.

    `ids` must support ``len`` and slicing (e.g. a list or range).

    >>> tree = bintree(range(7))
    >>> tree
    {0: None, 1: 0, 2: 1, 3: 1, 4: 0, 5: 4, 6: 4}
    >>> print_bintree(tree)
    0
      1
        2
        3
      4
        5
        6
    """
    parents = {}
    if len(ids) == 0:
        return parents
    root = ids[0]
    parents[root] = parent
    rest = ids[1:]
    if rest:
        # floor division keeps the slice index an int regardless of
        # Python version / `from __future__ import division`
        half = len(rest) // 2
        parents.update(bintree(rest[:half], parent=root))
        parents.update(bintree(rest[half:], parent=root))
    return parents

def reverse_bintree(parents):
    """construct {parent:[children]} dict from {child:parent}

    keys are the nodes in the tree, and values are the lists of children
    of that node in the tree (in the iteration order of `parents`).

    If some node's parent is None, reverse_tree[None] is that node (the
    root) itself, not a list.

    >>> tree = bintree(range(7))
    >>> reverse_bintree(tree) == {None: 0, 0: [1, 4], 1: [2, 3], 4: [5, 6]}
    True
    """
    children = {}
    # items() works on both Python 2 and 3 (iteritems is Python 2 only)
    for child, parent in parents.items():
        if parent is None:
            children[None] = child
        else:
            children.setdefault(parent, []).append(child)
    return children

def depth(n, tree):
    """Return the number of ancestors of node `n` in a {child: parent} `tree`.

    Walks parent links until reaching a node whose parent is None; that
    root node has depth 0.  Raises KeyError if a parent is not a key of
    `tree`.
    """
    hops = 0
    node = n
    while tree[node] is not None:
        node = tree[node]
        hops += 1
    return hops

def print_bintree(tree, indent='  '):
    """Print every node of a {child: parent} `tree`, one per line.

    Nodes are visited in sorted order, and each line is prefixed with
    `indent` repeated once per ancestor of the node.
    """
    for node in sorted(tree):
        prefix = indent * depth(node, tree)
        print("%s%s" % (prefix, node))

#----------------------------------------------------------------------------
# Communicator class for a binary-tree map
#----------------------------------------------------------------------------

# matches a dotted-quad IPv4 literal such as '10.0.0.1'; any other location
# string is treated as a DNS name to resolve
ip_pat = re.compile(r'^\d+\.\d+\.\d+\.\d+$')

def disambiguate_dns_url(url, location):
    """Return `url` disambiguated against the IP address of `location`.

    `location` may be a dotted-quad IP address or a DNS name; a name is
    resolved with ``socket.gethostbyname`` before being passed, with `url`,
    to IPython's ``disambiguate_url``.
    """
    if ip_pat.match(location):
        ip = location
    else:
        ip = socket.gethostbyname(location)
    return disambiguate_url(url, ip)

class BinaryTreeCommunicator(object):
    """Messaging endpoint for one engine in a binary-tree reduction.

    Each instance binds a PULL socket (`downstream`) on which its children
    push partial results, and owns a PUSH socket (`upstream`) that `connect`
    attaches to its parent's PULL socket.  The root binds a PUB socket for
    broadcasting final results; every other node subscribes to it with a
    SUB socket.
    """

    id          = None
    pub         = None
    sub         = None
    downstream  = None
    upstream    = None
    pub_url     = None
    tree_url    = None
    # assigned in __init__; the class-level default lets __del__ cope with
    # an instance whose __init__ did not complete
    _ctx        = None

    def __init__(self, id, interface='tcp://*', root=False):
        """Create the sockets for node `id` and bind them on `interface`.

        `interface` is a zmq endpoint without a port (e.g. 'tcp://*'); ports
        are chosen at random.  If `root` is true the node also binds a PUB
        socket whose url is stored in `pub_url`.  The PULL socket's url is
        stored in `tree_url`.
        """
        self.id = id
        self.root = root

        # create context and sockets
        self._ctx = zmq.Context()
        if root:
            self.pub = self._ctx.socket(zmq.PUB)
        else:
            self.sub = self._ctx.socket(zmq.SUB)
            # empty prefix: receive everything the root publishes
            self.sub.setsockopt(zmq.SUBSCRIBE, b'')
        self.downstream = self._ctx.socket(zmq.PULL)
        self.upstream = self._ctx.socket(zmq.PUSH)

        # bind to ports
        interface_f = interface + ":%i"
        if self.root:
            pub_port = self.pub.bind_to_random_port(interface)
            self.pub_url = interface_f % pub_port

        tree_port = self.downstream.bind_to_random_port(interface)
        self.tree_url = interface_f % tree_port
        self.downstream_poller = zmq.Poller()
        self.downstream_poller.register(self.downstream, zmq.POLLIN)

        # guess a reachable IP: the first address the local hostname resolves
        # to (NOTE(review): may be a loopback address on some systems)
        self.location = socket.gethostbyname_ex(socket.gethostname())[-1][0]

    def __del__(self):
        # Close whichever sockets were created (pub only exists on the root,
        # sub only on non-root nodes), then terminate the context.  Checking
        # for None avoids an AttributeError when __init__ failed part-way.
        for sock in (self.downstream, self.upstream, self.pub, self.sub):
            if sock is not None:
                sock.close()
        if self._ctx is not None:
            self._ctx.term()

    @property
    def info(self):
        """return the connection info for this object's sockets.

        A ``(tree_url, location)`` pair: the url of this node's PULL socket
        and the IP address peers should use to reach it.
        """
        return (self.tree_url, self.location)

    def connect(self, peers, btree, pub_url, root_id=0):
        """connect to peers.

        `peers` is a dict keyed by node id whose values are the
        ``(tree_url, location)`` pairs returned by each node's `info`.
        `btree` is the {child: parent} dict describing the tree (presumably
        built with `bintree`).  `pub_url` is the url of the root's PUB
        socket and `root_id` is the root's key in `peers`.

        Every node records its number of children in `nchildren`.  The root
        only binds, so it returns there; other nodes connect their SUB
        socket to the root's publisher and their PUSH socket to their
        parent's PULL socket.
        """
        # count the number of children we have; list() because on Python 3
        # dict.values() is a view without .count()
        self.nchildren = list(btree.values()).count(self.id)

        if self.root:
            return # root only binds

        root_location = peers[root_id][-1]
        self.sub.connect(disambiguate_dns_url(pub_url, root_location))

        parent = btree[self.id]

        tree_url, location = peers[parent]
        self.upstream.connect(disambiguate_dns_url(tree_url, location))

    def serialize(self, obj):
        """serialize objects.

        Must return list of sendable buffers.

        Can be extended for more efficient/noncopying serialization of numpy arrays, etc.
        """
        return [pickle.dumps(obj)]

    def unserialize(self, msg):
        """inverse of serialize"""
        # SECURITY NOTE(review): pickle.loads executes arbitrary code embedded
        # in the payload; only use this on a network of trusted engines.
        return pickle.loads(msg[0])

    def publish(self, value):
        """Broadcast `value` to all subscribers (root only)."""
        assert self.root
        self.pub.send_multipart(self.serialize(value))

    def consume(self):
        """Block until the root publishes a value and return it (non-root only)."""
        assert not self.root
        return self.unserialize(self.sub.recv_multipart())

    def send_upstream(self, value, flags=0):
        """Push `value` to this node's parent (non-root only).

        NOBLOCK is always added to `flags`, so the send never blocks; pyzmq
        raises zmq.Again if the message cannot be queued.
        """
        assert not self.root
        self.upstream.send_multipart(self.serialize(value), flags=flags|zmq.NOBLOCK)

    def recv_downstream(self, flags=0, timeout=2000.):
        """Receive one value pushed by a child.

        Waits up to `timeout` milliseconds for a message, then receives
        without blocking; if nothing arrived in time pyzmq raises zmq.Again
        instead of hanging forever.
        """
        # wait for a message, so we won't block if there was a bug
        self.downstream_poller.poll(timeout)

        msg = self.downstream.recv_multipart(zmq.NOBLOCK|flags)
        return self.unserialize(msg)

    def reduce(self, f, value, flat=True, all=False):
        """parallel reduce on binary tree

        if flat:
            value is an entry in the sequence
        else:
            value is a list of entries in the sequence

        if all:
            broadcast final result to all nodes
        else:
            only root gets final result

        Each node folds in the partial results of its `nchildren` children
        (set by `connect`) in arrival order, so `f` should be associative
        and commutative.  Non-root nodes then push their partial result to
        their parent.  Returns the root's final result on the root; with
        `all`, the broadcast result on every node; otherwise a non-root
        node's own partial result.
        """
        if not flat:
            # functools.reduce: the builtin reduce does not exist on Python 3
            value = functools.reduce(f, value)

        for _ in range(self.nchildren):
            value = f(value, self.recv_downstream())

        if not self.root:
            self.send_upstream(value)

        if all:
            if self.root:
                self.publish(value)
            else:
                value = self.consume()
        return value

    def allreduce(self, f, value, flat=True):
        """parallel reduce followed by broadcast of the result"""
        return self.reduce(f, value, flat=flat, all=True)