"""Tests for dependency.py Authors: * Min RK """ __docformat__ = "restructuredtext en" #------------------------------------------------------------------------------- # Copyright (C) 2011 The IPython Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- # Imports #------------------------------------------------------------------------------- # import import os from IPython.utils.pickleutil import can, uncan import IPython.parallel as pmod from IPython.parallel.util import interactive from IPython.parallel.tests import add_engines from .clienttest import ClusterTestCase def setup(): add_engines(1, total=True) @pmod.require('time') def wait(n): time.sleep(n) return n @pmod.interactive def func(x): return x*x mixed = list(map(str, range(10))) completed = list(map(str, range(0,10,2))) failed = list(map(str, range(1,10,2))) class DependencyTest(ClusterTestCase): def setUp(self): ClusterTestCase.setUp(self) self.user_ns = {'__builtins__' : __builtins__} self.view = self.client.load_balanced_view() self.dview = self.client[-1] self.succeeded = set(map(str, range(0,25,2))) self.failed = set(map(str, range(1,25,2))) def assertMet(self, dep): self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met") def assertUnmet(self, dep): self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met") def assertUnreachable(self, dep): self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable") def assertReachable(self, dep): self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable") def cancan(self, f): """decorator to pass through canning into self.user_ns""" return uncan(can(f), self.user_ns) def test_require_imports(self): """test that @require imports names""" @self.cancan @pmod.require('base64') @interactive def encode(arg): return base64.b64encode(arg) # must pass through canning to properly connect namespaces self.assertEqual(encode(b'foo'), b'Zm9v') def test_success_only(self): dep = pmod.Dependency(mixed, success=True, failure=False) self.assertUnmet(dep) self.assertUnreachable(dep) dep.all=False self.assertMet(dep) self.assertReachable(dep) dep = pmod.Dependency(completed, success=True, failure=False) self.assertMet(dep) self.assertReachable(dep) dep.all=False self.assertMet(dep) self.assertReachable(dep) def test_failure_only(self): dep = pmod.Dependency(mixed, success=False, failure=True) self.assertUnmet(dep) self.assertUnreachable(dep) dep.all=False self.assertMet(dep) self.assertReachable(dep) dep = pmod.Dependency(completed, success=False, failure=True) self.assertUnmet(dep) self.assertUnreachable(dep) dep.all=False self.assertUnmet(dep) self.assertUnreachable(dep) def test_require_function(self): @pmod.interactive def bar(a): return func(a) @pmod.require(func) @pmod.interactive def bar2(a): return func(a) self.client[:].clear() self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5) ar = self.view.apply_async(bar2, 5) self.assertEqual(ar.get(5), func(5)) def test_require_object(self): @pmod.require(foo=func) @pmod.interactive def bar(a): return foo(a) ar = self.view.apply_async(bar, 5) self.assertEqual(ar.get(5), func(5))