##// END OF EJS Templates
some initial tests for newparallel
MinRK -
Show More
@@ -1,35 +1,30 b''
1 """toplevel setup/teardown for prallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2 import time
3 import time
4 from subprocess import Popen, PIPE
3
5
4 from IPython.zmq.parallel.ipcluster import launch_process
6 from IPython.zmq.parallel.ipcluster import launch_process
5 from IPython.zmq.parallel.entry_point import select_random_ports
7 from IPython.zmq.parallel.entry_point import select_random_ports
6 # from multiprocessing import Process
7
8
8 cluster_logs = dict(
9 processes = []
9 regport=0,
10
10 processes = [],
11 # nose setup/teardown
11 )
12
12
13 def setup():
13 def setup():
14 p = select_random_ports(1)[0]
14 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=PIPE, stdin=PIPE, stderr=PIPE)
15 cluster_logs['regport']=p
15 processes.append(cp)
16 cp = launch_process('controller',('--scheduler lru --ping 100 --regport %i'%p).split())
16 time.sleep(.5)
17 # cp.start()
17 add_engine()
18 cluster_logs['processes'].append(cp)
18 time.sleep(3)
19 add_engine(p)
20 time.sleep(2)
21
19
22 def add_engine(port=None):
20 def add_engine(profile='iptest'):
23 if port is None:
21 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
24 port = cluster_logs['regport']
25 ep = launch_process('engine', ['--regport',str(port)])
26 # ep.start()
22 # ep.start()
27 cluster_logs['processes'].append(ep)
23 processes.append(ep)
28 return ep
24 return ep
29
25
30 def teardown():
26 def teardown():
31 time.sleep(1)
27 time.sleep(1)
32 processes = cluster_logs['processes']
33 while processes:
28 while processes:
34 p = processes.pop()
29 p = processes.pop()
35 if p.poll() is None:
30 if p.poll() is None:
@@ -48,4 +43,3 b' def teardown():'
48 except:
43 except:
49 print "couldn't shutdown process: ",p
44 print "couldn't shutdown process: ",p
50
45
51
@@ -2,40 +2,70 b' import time'
2 from signal import SIGINT
2 from signal import SIGINT
3 from multiprocessing import Process
3 from multiprocessing import Process
4
4
5 from nose import SkipTest
6
5 from zmq.tests import BaseZMQTestCase
7 from zmq.tests import BaseZMQTestCase
6
8
9 from IPython.external.decorator import decorator
10
7 from IPython.zmq.parallel.ipcluster import launch_process
11 from IPython.zmq.parallel.ipcluster import launch_process
8 from IPython.zmq.parallel.entry_point import select_random_ports
12 from IPython.zmq.parallel.entry_point import select_random_ports
9 from IPython.zmq.parallel.client import Client
13 from IPython.zmq.parallel.client import Client
10 from IPython.zmq.parallel.tests import cluster_logs,add_engine
14 from IPython.zmq.parallel.tests import processes,add_engine
15
16 # simple tasks for use in apply tests
17
18 def segfault():
19 """"""
20 import ctypes
21 ctypes.memset(-1,0,1)
22
23 def wait(n):
24 """sleep for a time"""
25 import time
26 time.sleep(n)
27 return n
28
29 def raiser(eclass):
30 """raise an exception"""
31 raise eclass()
32
33 # test decorator for skipping tests when libraries are unavailable
34 def skip_without(*names):
35 """skip a test if some names are not importable"""
36 @decorator
37 def skip_without_names(f, *args, **kwargs):
38 """decorator to skip tests in the absence of numpy."""
39 for name in names:
40 try:
41 __import__(name)
42 except ImportError:
43 raise SkipTest
44 return f(*args, **kwargs)
45 return skip_without_names
11
46
12
47
13 class ClusterTestCase(BaseZMQTestCase):
48 class ClusterTestCase(BaseZMQTestCase):
14
49
15 def add_engines(self, n=1):
50 def add_engines(self, n=1, block=True):
16 """add multiple engines to our cluster"""
51 """add multiple engines to our cluster"""
17 for i in range(n):
52 for i in range(n):
18 self.engines.append(add_engine())
53 self.engines.append(add_engine())
54 if block:
55 self.wait_on_engines()
19
56
20 def wait_on_engines(self):
57 def wait_on_engines(self, timeout=5):
21 """wait for our engines to connect."""
58 """wait for our engines to connect."""
22 while len(self.client.ids) < len(self.engines)+self.base_engine_count:
59 n = len(self.engines)+self.base_engine_count
60 tic = time.time()
61 while time.time()-tic < timeout and len(self.client.ids) < n:
23 time.sleep(0.1)
62 time.sleep(0.1)
63
64 assert not self.client.ids < n, "waiting for engines timed out"
24
65
25 def start_cluster(self, n=1):
66 def connect_client(self):
26 """start a cluster"""
27 raise NotImplementedError("Don't use this anymore")
28 rport = select_random_ports(1)[0]
29 args = [ '--regport', str(rport), '--ip', '127.0.0.1' ]
30 cp = launch_process('controller', args)
31 eps = [ launch_process('engine', args+['--ident', 'engine-%i'%i]) for i in range(n) ]
32 return rport, args, cp, eps
33
34 def connect_client(self, port=None):
35 """connect a client with my Context, and track its sockets for cleanup"""
67 """connect a client with my Context, and track its sockets for cleanup"""
36 if port is None:
68 c = Client(profile='iptest',context=self.context)
37 port = cluster_logs['regport']
38 c = Client('tcp://127.0.0.1:%i'%port,context=self.context)
39 for name in filter(lambda n:n.endswith('socket'), dir(c)):
69 for name in filter(lambda n:n.endswith('socket'), dir(c)):
40 self.sockets.append(getattr(c, name))
70 self.sockets.append(getattr(c, name))
41 return c
71 return c
@@ -1,42 +1,117 b''
1 import time
1 import time
2
2
3 import nose.tools as nt
4
5 from IPython.zmq.parallel.asyncresult import AsyncResult
3 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
6 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
4
7
5 from clienttest import ClusterTestCase
8 from clienttest import ClusterTestCase, segfault
6
9
7 class TestClient(ClusterTestCase):
10 class TestClient(ClusterTestCase):
8
11
9 def test_ids(self):
12 def test_ids(self):
10 self.assertEquals(len(self.client.ids), 1)
13 self.assertEquals(len(self.client.ids), 1)
11 self.add_engines(3)
14 self.add_engines(3)
12 self.wait_on_engines()
15 self.assertEquals(len(self.client.ids), 4)
13 self.assertEquals(self.client.ids, set(range(4)))
14
16
15 def test_segfault(self):
17 def test_segfault(self):
16 def segfault():
18 self.add_engines(1)
17 import ctypes
19 eid = self.client.ids[-1]
18 ctypes.memset(-1,0,1)
20 self.client[eid].apply(segfault)
19 self.client[0].apply(segfault)
21 while eid in self.client.ids:
20 while 0 in self.client.ids:
21 time.sleep(.01)
22 time.sleep(.01)
22 self.client.spin()
23 self.client.spin()
23
24
24 def test_view_indexing(self):
25 def test_view_indexing(self):
25 self.add_engines(7)
26 self.add_engines(4)
26 self.wait_on_engines()
27 targets = self.client._build_targets('all')[-1]
27 targets = self.client._build_targets('all')[-1]
28 v = self.client[:]
28 v = self.client[:]
29 self.assertEquals(v.targets, targets)
29 self.assertEquals(v.targets, targets)
30 v =self.client[2]
30 t = self.client.ids[2]
31 self.assertEquals(v.targets, 2)
31 v = self.client[t]
32 v =self.client[1,2]
32 self.assert_(isinstance(v, DirectView))
33 self.assertEquals(v.targets, [1,2])
33 self.assertEquals(v.targets, t)
34 v =self.client[::2]
34 t = self.client.ids[2:4]
35 v = self.client[t]
36 self.assert_(isinstance(v, DirectView))
37 self.assertEquals(v.targets, t)
38 v = self.client[::2]
39 self.assert_(isinstance(v, DirectView))
35 self.assertEquals(v.targets, targets[::2])
40 self.assertEquals(v.targets, targets[::2])
36 v =self.client[1::3]
41 v = self.client[1::3]
42 self.assert_(isinstance(v, DirectView))
37 self.assertEquals(v.targets, targets[1::3])
43 self.assertEquals(v.targets, targets[1::3])
38 v =self.client[:-3]
44 v = self.client[:-3]
45 self.assert_(isinstance(v, DirectView))
39 self.assertEquals(v.targets, targets[:-3])
46 self.assertEquals(v.targets, targets[:-3])
40 v =self.client[None]
47 nt.assert_raises(TypeError, lambda : self.client[None])
41 self.assert_(isinstance(v, LoadBalancedView))
48
49 def test_view_cache(self):
50 """test blocking and non-blocking behavior"""
51 v = self.client[:2]
52 v2 =self.client[:2]
53 self.assertTrue(v is v2)
54 v = self.client.view()
55 v2 = self.client.view(balanced=True)
56 self.assertTrue(v is v2)
57
58 def test_targets(self):
59 """test various valid targets arguments"""
60 pass
61
62 def test_clear(self):
63 """test clear behavior"""
64 # self.add_engines(4)
65 # self.client.push()
66
67 def test_push_pull(self):
68 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 self.add_engines(4)
70 push = self.client.push
71 pull = self.client.pull
72 self.client.block=True
73 nengines = len(self.client)
74 push({'data':data}, targets=0)
75 d = pull('data', targets=0)
76 self.assertEquals(d, data)
77 push({'data':data})
78 d = pull('data')
79 self.assertEquals(d, nengines*[data])
80 ar = push({'data':data}, block=False)
81 self.assertTrue(isinstance(ar, AsyncResult))
82 r = ar.get()
83 ar = pull('data', block=False)
84 self.assertTrue(isinstance(ar, AsyncResult))
85 r = ar.get()
86 self.assertEquals(r, nengines*[data])
87 push(dict(a=10,b=20))
88 r = pull(('a','b'))
89 self.assertEquals(r, nengines*[[10,20]])
90
91 def test_push_pull_function(self):
92 def testf(x):
93 return 2.0*x
94
95 self.add_engines(4)
96 self.client.block=True
97 push = self.client.push
98 pull = self.client.pull
99 execute = self.client.execute
100 push({'testf':testf}, targets=0)
101 r = pull('testf', targets=0)
102 self.assertEqual(r(1.0), testf(1.0))
103 execute('r = testf(10)', targets=0)
104 r = pull('r', targets=0)
105 self.assertEquals(r, testf(10))
106 ar = push({'testf':testf}, block=False)
107 ar.get()
108 ar = pull('testf', block=False)
109 rlist = ar.get()
110 for r in rlist:
111 self.assertEqual(r(1.0), testf(1.0))
112 execute("def g(x): return x*x", targets=0)
113 r = pull(('testf','g'),targets=0)
114 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
115
116
42 No newline at end of file
117
@@ -1,4 +1,89 b''
1 """test serialization with newserialized"""
1
2
2 from unittest import TestCase
3 from unittest import TestCase
3 # from zmq.tests import BaseZMQTest
4
4
5 import nose.tools as nt
6
7 from IPython.testing.parametric import parametric
8 from IPython.utils import newserialized as ns
9 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
10 from IPython.zmq.parallel.tests.clienttest import skip_without
11
12
13 class CanningTestCase(TestCase):
14 def test_canning(self):
15 d = dict(a=5,b=6)
16 cd = can(d)
17 nt.assert_true(isinstance(cd, dict))
18
19 def test_canned_function(self):
20 f = lambda : 7
21 cf = can(f)
22 nt.assert_true(isinstance(cf, CannedFunction))
23
24 @parametric
25 def test_can_roundtrip(cls):
26 objs = [
27 dict(),
28 set(),
29 list(),
30 ['a',1,['a',1],u'e'],
31 ]
32 return map(cls.run_roundtrip, objs)
33
34 @classmethod
35 def run_roundtrip(cls, obj):
36 o = uncan(can(obj))
37 nt.assert_equals(obj, o)
38
39 def test_serialized_interfaces(self):
40
41 us = {'a':10, 'b':range(10)}
42 s = ns.serialize(us)
43 uus = ns.unserialize(s)
44 nt.assert_true(isinstance(s, ns.SerializeIt))
45 nt.assert_equals(uus, us)
46
47 def test_pickle_serialized(self):
48 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
49 original = ns.UnSerialized(obj)
50 originalSer = ns.SerializeIt(original)
51 firstData = originalSer.getData()
52 firstTD = originalSer.getTypeDescriptor()
53 firstMD = originalSer.getMetadata()
54 nt.assert_equals(firstTD, 'pickle')
55 nt.assert_equals(firstMD, {})
56 unSerialized = ns.UnSerializeIt(originalSer)
57 secondObj = unSerialized.getObject()
58 for k, v in secondObj.iteritems():
59 nt.assert_equals(obj[k], v)
60 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
61 nt.assert_equals(firstData, secondSer.getData())
62 nt.assert_equals(firstTD, secondSer.getTypeDescriptor() )
63 nt.assert_equals(firstMD, secondSer.getMetadata())
64
65 @skip_without('numpy')
66 def test_ndarray_serialized(self):
67 import numpy
68 a = numpy.linspace(0.0, 1.0, 1000)
69 unSer1 = ns.UnSerialized(a)
70 ser1 = ns.SerializeIt(unSer1)
71 td = ser1.getTypeDescriptor()
72 nt.assert_equals(td, 'ndarray')
73 md = ser1.getMetadata()
74 nt.assert_equals(md['shape'], a.shape)
75 nt.assert_equals(md['dtype'], a.dtype.str)
76 buff = ser1.getData()
77 nt.assert_equals(buff, numpy.getbuffer(a))
78 s = ns.Serialized(buff, td, md)
79 final = ns.unserialize(s)
80 nt.assert_equals(numpy.getbuffer(a), numpy.getbuffer(final))
81 nt.assert_true((a==final).all())
82 nt.assert_equals(a.dtype.str, final.dtype.str)
83 nt.assert_equals(a.shape, final.shape)
84 # test non-copying:
85 a[2] = 1e9
86 nt.assert_true((a==final).all())
87
88
89 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now