"""Tests for dependency.py

Authors:

* Min RK
"""

__docformat__ = "restructuredtext en"

#-------------------------------------------------------------------------------
#  Copyright (C) 2011  The IPython Development Team
#
#  Distributed under the terms of the BSD License.  The full license is in
#  the file COPYING, distributed as part of this software.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

# import
import os

from IPython.utils.pickleutil import can, uncan

import IPython.parallel as pmod
from IPython.parallel.util import interactive

from IPython.parallel.tests import add_engines
from .clienttest import ClusterTestCase

def setup():
    """Module-level fixture: ask the shared test cluster for one engine.

    NOTE(review): ``total=True`` presumably means "ensure one engine in
    total" rather than "start one more" -- confirm against ``add_engines``.
    """
    engine_count = 1
    add_engines(engine_count, total=True)

@pmod.require('time')
def wait(n):
    """Sleep for ``n`` seconds, then return ``n`` unchanged.

    ``time`` is not imported by this module; the ``@pmod.require('time')``
    decorator is relied upon to make it available where the function runs.
    """
    delay = n
    time.sleep(delay)
    return delay

@pmod.interactive
def func(x):
    """Return ``x`` multiplied by itself."""
    product = x * x
    return product

# Id strings used to build dependencies in the tests below:
#   mixed     -> "0".."9"   (every id)
#   completed -> even ids only
#   failed    -> odd ids only
mixed = [str(i) for i in range(10)]
completed = [str(i) for i in range(0, 10, 2)]
failed = [str(i) for i in range(1, 10, 2)]

class DependencyTest(ClusterTestCase):
    """Tests for ``pmod.Dependency`` evaluation and the ``pmod.require`` decorator.

    ``setUp`` builds two disjoint sets of id strings -- even numbers count as
    succeeded, odd numbers as failed -- and the ``assert*`` helpers evaluate a
    dependency against exactly those two sets.
    """

    def setUp(self):
        """Create the views and the succeeded/failed id sets used by the tests."""
        ClusterTestCase.setUp(self)
        # namespace that functions are uncanned into by ``cancan``
        self.user_ns = {'__builtins__' : __builtins__}
        self.view = self.client.load_balanced_view()
        self.dview = self.client[-1]
        # even ids "0".."24" succeeded, odd ids "1".."23" failed
        self.succeeded = set(str(i) for i in range(0, 25, 2))
        self.failed = set(str(i) for i in range(1, 25, 2))

    def assertMet(self, dep):
        """Assert that ``dep.check`` is truthy for the fixture id sets."""
        is_met = dep.check(self.succeeded, self.failed)
        self.assertTrue(is_met, "Dependency should be met")

    def assertUnmet(self, dep):
        """Assert that ``dep.check`` is falsy for the fixture id sets."""
        is_met = dep.check(self.succeeded, self.failed)
        self.assertFalse(is_met, "Dependency should not be met")

    def assertUnreachable(self, dep):
        """Assert that ``dep.unreachable`` is truthy for the fixture id sets."""
        is_unreachable = dep.unreachable(self.succeeded, self.failed)
        self.assertTrue(is_unreachable, "Dependency should be unreachable")

    def assertReachable(self, dep):
        """Assert that ``dep.unreachable`` is falsy for the fixture id sets."""
        is_unreachable = dep.unreachable(self.succeeded, self.failed)
        self.assertFalse(is_unreachable, "Dependency should be reachable")

    def cancan(self, f):
        """Decorator: can ``f``, then uncan the result into ``self.user_ns``."""
        canned = can(f)
        return uncan(canned, self.user_ns)

    def test_require_imports(self):
        """@require should make the named module usable inside the function"""
        # the function has to go through canning so that its namespace is
        # properly connected (see ``cancan``)
        @self.cancan
        @pmod.require('base64')
        @interactive
        def encode(arg):
            return base64.b64encode(arg)

        encoded = encode(b'foo')
        self.assertEqual(encoded, b'Zm9v')

    def test_success_only(self):
        """success-only dependency over mixed ids and over all-successful ids"""
        # ids that include failures: unmet and unreachable as constructed,
        # met and reachable once ``all`` is switched off
        on_mixed = pmod.Dependency(mixed, success=True, failure=False)
        self.assertUnmet(on_mixed)
        self.assertUnreachable(on_mixed)
        on_mixed.all = False
        self.assertMet(on_mixed)
        self.assertReachable(on_mixed)

        # ids that all succeeded: met and reachable either way
        on_completed = pmod.Dependency(completed, success=True, failure=False)
        self.assertMet(on_completed)
        self.assertReachable(on_completed)
        on_completed.all = False
        self.assertMet(on_completed)
        self.assertReachable(on_completed)

    def test_failure_only(self):
        """failure-only dependency over mixed ids and over all-successful ids"""
        # ids that include successes: unmet and unreachable as constructed,
        # met and reachable once ``all`` is switched off
        on_mixed = pmod.Dependency(mixed, success=False, failure=True)
        self.assertUnmet(on_mixed)
        self.assertUnreachable(on_mixed)
        on_mixed.all = False
        self.assertMet(on_mixed)
        self.assertReachable(on_mixed)

        # ids that all succeeded: never met, never reachable
        on_completed = pmod.Dependency(completed, success=False, failure=True)
        self.assertUnmet(on_completed)
        self.assertUnreachable(on_completed)
        on_completed.all = False
        self.assertUnmet(on_completed)
        self.assertUnreachable(on_completed)

    def test_require_function(self):
        """@require(func) should make ``func`` callable by name remotely"""

        @pmod.interactive
        def bar(a):
            return func(a)

        @pmod.require(func)
        @pmod.interactive
        def bar2(a):
            return func(a)

        # NOTE(review): presumably clears engine state so that a leftover
        # ``func`` cannot mask the expected NameError -- confirm
        every_engine = self.client[:]
        every_engine.clear()

        # without @require the remote call cannot resolve ``func``
        self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)

        # with @require(func) the remote result matches the local one
        ar = self.view.apply_async(bar2, 5)
        remote_result = ar.get(5)
        self.assertEqual(remote_result, func(5))

    def test_require_object(self):
        """@require(foo=func) should expose ``func`` remotely under the name ``foo``"""

        @pmod.require(foo=func)
        @pmod.interactive
        def bar(a):
            return foo(a)

        ar = self.view.apply_async(bar, 5)
        remote_result = ar.get(5)
        self.assertEqual(remote_result, func(5))