"""base class for parallel client tests

Authors:

* Min RK
"""

#-------------------------------------------------------------------------------
#  Copyright (C) 2011  The IPython Development Team
#
#  Distributed under the terms of the BSD License.  The full license is in
#  the file COPYING, distributed as part of this software.
#-------------------------------------------------------------------------------
from __future__ import print_function

import sys
import tempfile
import time
from StringIO import StringIO

from nose import SkipTest

import zmq
from zmq.tests import BaseZMQTestCase

from IPython.external.decorator import decorator

from IPython.parallel import error
from IPython.parallel import Client

from IPython.parallel.tests import launchers, add_engines

# simple tasks for use in apply tests

def segfault():
    """crash the calling process on purpose

    Asks ctypes to write a single byte at address -1, which is
    expected to bring the interpreter down with a segfault.
    """
    from ctypes import memset
    memset(-1, 0, 1)

def crash():
    """crash the interpreter by executing a bogus code object

    The recipe comes from the stdlib crashers in the test suite.
    """
    import types

    on_windows = sys.platform.startswith('win')
    if on_windows:
        import ctypes
        # 0x0002 is SEM_NOGPFAULTERRORBOX in the Windows API
        ctypes.windll.kernel32.SetErrorMode(0x0002)

    # Python3 adds 'kwonlyargcount' as the second argument to Code
    kwonly = [0] if sys.version_info[0] >= 3 else []
    code_args = [0] + kwonly + [
        0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'',
    ]

    exec(types.CodeType(*code_args))

def wait(n):
    """block for n seconds, then hand n back to the caller"""
    # imported locally, as in the rest of these task functions
    from time import sleep
    sleep(n)
    return n

def raiser(eclass):
    """instantiate eclass with no arguments and raise the result"""
    exc = eclass()
    raise exc

def generate_output():
    """emit a known sequence of outputs, for output-capture tests

    Twice in a row: print one line to stdout, print one line to
    stderr, and display one rich object (HTML first, then Math).
    The return value is a further Math object.
    """
    import sys
    from IPython.core.display import display, HTML, Math

    # (stdout text, stderr text, rich type, rich source), in emission order
    rounds = (
        ("stdout", "stderr", HTML, "<b>HTML</b>"),
        ("stdout2", "stderr2", Math, r"\alpha=\beta"),
    )
    for out_text, err_text, rich_type, rich_source in rounds:
        print(out_text)
        print(err_text, file=sys.stderr)
        display(rich_type(rich_source))

    return Math("42")

# test decorator for skipping tests when libraries are unavailable
def skip_without(*names):
    """skip a test if some names are not importable

    Parameters
    ----------
    *names : str
        module names that must be importable for the test to run.

    Returns
    -------
    A decorator.  Each time the decorated test is called, every name is
    imported in order; the first ImportError raises SkipTest (naming the
    missing module), otherwise the test runs and its result is returned.
    """
    @decorator
    def skip_without_names(f, *args, **kwargs):
        """call f(*args, **kwargs), or skip if a required module is missing"""
        for name in names:
            try:
                __import__(name)
            except ImportError:
                # give the reason, so the report shows which dependency is absent
                raise SkipTest("could not import %r" % name)
        return f(*args, **kwargs)
    return skip_without_names

#-------------------------------------------------------------------------------
# Classes
#-------------------------------------------------------------------------------


class ClusterTestCase(BaseZMQTestCase):
    """base TestCase for tests that need a client connected to the cluster

    setUp connects a Client (``self.client``) for the 'iptest' profile on
    the test's zmq Context and clears the engine namespaces; tearDown
    flushes and closes that client.  Engines started by a test are
    tracked in ``self.engines``.
    """

    def add_engines(self, n=1, block=True):
        """add multiple engines to our cluster

        n is passed to the module-level add_engines(); if block is true,
        wait for the new engines to appear in ``self.client.ids``.
        """
        # inside a method the bare name is the module-level add_engines
        # import, not this method
        self.engines.extend(add_engines(n))
        if block:
            self.wait_on_engines()

    def minimum_engines(self, n=1, block=True):
        """add engines until there are at least n connected

        Same as add_engines, but calls the module-level add_engines()
        with total=True.
        """
        self.engines.extend(add_engines(n, total=True))
        if block:
            self.wait_on_engines()

    def wait_on_engines(self, timeout=5):
        """wait for our engines to connect.

        Polls ``self.client.ids`` every 0.1s until it holds at least
        ``len(self.engines) + self.base_engine_count`` entries, and fails
        the test if that has not happened after `timeout` seconds.
        """
        n = len(self.engines) + self.base_engine_count
        tic = time.time()
        while time.time() - tic < timeout and len(self.client.ids) < n:
            time.sleep(0.1)

        # self.fail rather than a bare assert, so the check survives `python -O`
        if len(self.client.ids) < n:
            self.fail("waiting for engines timed out")

    def connect_client(self):
        """connect a client with my Context, and track its sockets for cleanup

        Every attribute of the client whose name ends in 'socket' gets
        LINGER set to 0 and is appended to ``self.sockets``.
        Returns the new Client.
        """
        c = Client(profile='iptest', context=self.context)
        for name in [attr for attr in dir(c) if attr.endswith('socket')]:
            s = getattr(c, name)
            s.setsockopt(zmq.LINGER, 0)
            # NOTE(review): self.sockets is presumably provided (and closed)
            # by BaseZMQTestCase -- it is not defined in this class
            self.sockets.append(s)
        return c

    def assertRaisesRemote(self, etype, f, *args, **kwargs):
        """assert that f(*args, **kwargs) raises a RemoteError wrapping etype

        A CompositeError is unpacked with its raise_exception() first.
        The comparison is by class name: ``etype.__name__`` against the
        RemoteError's ``ename``.  Fails if no RemoteError is raised.
        """
        try:
            try:
                f(*args, **kwargs)
            except error.CompositeError as e:
                e.raise_exception()
        except error.RemoteError as e:
            self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
        else:
            self.fail("should have raised a RemoteError")

    def _wait_for(self, f, timeout=10):
        """wait for a condition

        Calls f() repeatedly for up to `timeout` seconds, sleeping 0.1s and
        spinning the client between attempts.  Returns as soon as f() is
        true; if it never is, only a warning is printed (no failure).
        """
        tic = time.time()
        while time.time() <= tic + timeout:
            if f():
                return
            time.sleep(0.1)
            self.client.spin()
        if not f():
            print("Warning: Awaited condition never arrived")

    def setUp(self):
        """connect the client and record the starting engine count"""
        BaseZMQTestCase.setUp(self)
        self.client = self.connect_client()
        # start every test with clean engine namespaces:
        self.client.clear(block=True)
        # engines that were already connected before this test added any
        self.base_engine_count = len(self.client.ids)
        self.engines = []

    def tearDown(self):
        """drop finished launchers, then flush and close the client"""
        # forget launchers whose poll() reports a result (presumably: the
        # process has exited).  Snapshot them first: removing from
        # `launchers` while a lazy filter() is still iterating over it
        # (Python 3) skips entries.
        finished = [e for e in launchers if e.poll() is not None]
        for e in finished:
            launchers.remove(e)

        # allow flushing of incoming messages to prevent crash on socket close
        self.client.wait(timeout=2)
        self.client.spin()
        self.client.close()
        BaseZMQTestCase.tearDown(self)
        # this will be redundant when pyzmq merges PR #88
        # self.context.term()