SGE test related fixes... (MinRK)
@@ -1,1035 +1,1090 @@
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29 29
30 30 from .entry_point import select_random_ports
31 31 from .factory import RegistrationFactory, LoggingFactory
32 32
33 33 from . import error
34 34 from .heartmonitor import HeartMonitor
35 35 from .util import validate_url_container, ISO8601
36 36
37 37 #-----------------------------------------------------------------------------
38 38 # Code
39 39 #-----------------------------------------------------------------------------
40 40
41 41 def _passer(*args, **kwargs):
42 42 return
43 43
44 44 def _printer(*args, **kwargs):
45 45 print (args)
46 46 print (kwargs)
47 47
48 def empty_record():
49 """Return an empty dict with all record keys."""
50 return {
51 'msg_id' : None,
52 'header' : None,
53 'content': None,
54 'buffers': None,
55 'submitted': None,
56 'client_uuid' : None,
57 'engine_uuid' : None,
58 'started': None,
59 'completed': None,
60 'resubmitted': None,
61 'result_header' : None,
62 'result_content' : None,
63 'result_buffers' : None,
64 'queue' : None,
65 'pyin' : None,
66 'pyout': None,
67 'pyerr': None,
68 'stdout': '',
69 'stderr': '',
70 }
71
48 72 def init_record(msg):
49 73 """Initialize a TaskRecord based on a request."""
50 74 header = msg['header']
51 75 return {
52 76 'msg_id' : header['msg_id'],
53 77 'header' : header,
54 78 'content': msg['content'],
55 79 'buffers': msg['buffers'],
56 80 'submitted': datetime.strptime(header['date'], ISO8601),
57 81 'client_uuid' : None,
58 82 'engine_uuid' : None,
59 83 'started': None,
60 84 'completed': None,
61 85 'resubmitted': None,
62 86 'result_header' : None,
63 87 'result_content' : None,
64 88 'result_buffers' : None,
65 89 'queue' : None,
66 90 'pyin' : None,
67 91 'pyout': None,
68 92 'pyerr': None,
69 93 'stdout': '',
70 94 'stderr': '',
71 95 }
72 96
73 97
74 98 class EngineConnector(HasTraits):
75 99 """A simple object for accessing the various zmq connections of an object.
76 100 Attributes are:
77 101 id (int): engine ID
78 102 uuid (str): uuid (unused?)
79 103 queue (str): identity of queue's XREQ socket
80 104 registration (str): identity of registration XREQ socket
81 105 heartbeat (str): identity of heartbeat XREQ socket
82 106 """
83 107 id=Int(0)
84 108 queue=Str()
85 109 control=Str()
86 110 registration=Str()
87 111 heartbeat=Str()
88 112 pending=Set()
89 113
90 114 class HubFactory(RegistrationFactory):
91 115 """The Configurable for setting up a Hub."""
92 116
93 117 # name of a scheduler scheme
94 118 scheme = Str('leastload', config=True)
95 119
96 120 # port-pairs for monitoredqueues:
97 121 hb = Instance(list, config=True)
98 122 def _hb_default(self):
99 123 return select_random_ports(2)
100 124
101 125 mux = Instance(list, config=True)
102 126 def _mux_default(self):
103 127 return select_random_ports(2)
104 128
105 129 task = Instance(list, config=True)
106 130 def _task_default(self):
107 131 return select_random_ports(2)
108 132
109 133 control = Instance(list, config=True)
110 134 def _control_default(self):
111 135 return select_random_ports(2)
112 136
113 137 iopub = Instance(list, config=True)
114 138 def _iopub_default(self):
115 139 return select_random_ports(2)
116 140
117 141 # single ports:
118 142 mon_port = Instance(int, config=True)
119 143 def _mon_port_default(self):
120 144 return select_random_ports(1)[0]
121 145
122 146 notifier_port = Instance(int, config=True)
123 147 def _notifier_port_default(self):
124 148 return select_random_ports(1)[0]
125 149
126 150 ping = Int(1000, config=True) # ping frequency
127 151
128 152 engine_ip = CStr('127.0.0.1', config=True)
129 153 engine_transport = CStr('tcp', config=True)
130 154
131 155 client_ip = CStr('127.0.0.1', config=True)
132 156 client_transport = CStr('tcp', config=True)
133 157
134 158 monitor_ip = CStr('127.0.0.1', config=True)
135 159 monitor_transport = CStr('tcp', config=True)
136 160
137 161 monitor_url = CStr('')
138 162
139 163 db_class = CStr('IPython.parallel.dictdb.DictDB', config=True)
140 164
141 165 # not configurable
142 166 db = Instance('IPython.parallel.dictdb.BaseDB')
143 167 heartmonitor = Instance('IPython.parallel.heartmonitor.HeartMonitor')
144 168 subconstructors = List()
145 169 _constructed = Bool(False)
146 170
147 171 def _ip_changed(self, name, old, new):
148 172 self.engine_ip = new
149 173 self.client_ip = new
150 174 self.monitor_ip = new
151 175 self._update_monitor_url()
152 176
153 177 def _update_monitor_url(self):
154 178 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
155 179
156 180 def _transport_changed(self, name, old, new):
157 181 self.engine_transport = new
158 182 self.client_transport = new
159 183 self.monitor_transport = new
160 184 self._update_monitor_url()
161 185
162 186 def __init__(self, **kwargs):
163 187 super(HubFactory, self).__init__(**kwargs)
164 188 self._update_monitor_url()
165 189 # self.on_trait_change(self._sync_ips, 'ip')
166 190 # self.on_trait_change(self._sync_transports, 'transport')
167 191 self.subconstructors.append(self.construct_hub)
168 192
169 193
170 194 def construct(self):
171 195 assert not self._constructed, "already constructed!"
172 196
173 197 for subc in self.subconstructors:
174 198 subc()
175 199
176 200 self._constructed = True
177 201
178 202
179 203 def start(self):
180 204 assert self._constructed, "must be constructed by self.construct() first!"
181 205 self.heartmonitor.start()
182 206 self.log.info("Heartmonitor started")
183 207
184 208 def construct_hub(self):
185 209 """construct"""
186 210 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
187 211 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
188 212
189 213 ctx = self.context
190 214 loop = self.loop
191 215
192 216 # Registrar socket
193 217 q = ZMQStream(ctx.socket(zmq.XREP), loop)
194 218 q.bind(client_iface % self.regport)
195 219 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
196 220 if self.client_ip != self.engine_ip:
197 221 q.bind(engine_iface % self.regport)
198 222 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
199 223
200 224 ### Engine connections ###
201 225
202 226 # heartbeat
203 227 hpub = ctx.socket(zmq.PUB)
204 228 hpub.bind(engine_iface % self.hb[0])
205 229 hrep = ctx.socket(zmq.XREP)
206 230 hrep.bind(engine_iface % self.hb[1])
207 231 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
208 232 period=self.ping, logname=self.log.name)
209 233
210 234 ### Client connections ###
211 235 # Notifier socket
212 236 n = ZMQStream(ctx.socket(zmq.PUB), loop)
213 237 n.bind(client_iface%self.notifier_port)
214 238
215 239 ### build and launch the queues ###
216 240
217 241 # monitor socket
218 242 sub = ctx.socket(zmq.SUB)
219 243 sub.setsockopt(zmq.SUBSCRIBE, "")
220 244 sub.bind(self.monitor_url)
221 245 sub.bind('inproc://monitor')
222 246 sub = ZMQStream(sub, loop)
223 247
224 248 # connect the db
225 249 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
226 250 # cdir = self.config.Global.cluster_dir
227 251 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
228 252 time.sleep(.25)
229 253
230 254 # build connection dicts
231 255 self.engine_info = {
232 256 'control' : engine_iface%self.control[1],
233 257 'mux': engine_iface%self.mux[1],
234 258 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
235 259 'task' : engine_iface%self.task[1],
236 260 'iopub' : engine_iface%self.iopub[1],
237 261 # 'monitor' : engine_iface%self.mon_port,
238 262 }
239 263
240 264 self.client_info = {
241 265 'control' : client_iface%self.control[0],
242 266 'mux': client_iface%self.mux[0],
243 267 'task' : (self.scheme, client_iface%self.task[0]),
244 268 'iopub' : client_iface%self.iopub[0],
245 269 'notification': client_iface%self.notifier_port
246 270 }
247 271 self.log.debug("Hub engine addrs: %s"%self.engine_info)
248 272 self.log.debug("Hub client addrs: %s"%self.client_info)
249 273 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
250 274 query=q, notifier=n, db=self.db,
251 275 engine_info=self.engine_info, client_info=self.client_info,
252 276 logname=self.log.name)
253 277
254 278
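# A minimal sketch of how the HubFactory above is meant to be driven, assuming
# a Config object `config` accepted by the RegistrationFactory base class and a
# zmq IOLoop that the caller starts afterwards; both are assumptions, and the
# helper name is illustrative, not part of this module.
def make_hub(config):
    factory = HubFactory(config=config)
    factory.construct()    # runs the registered subconstructors, including construct_hub()
    factory.start()        # starts the HeartMonitor; requires construct() first
    return factory.hub     # the Hub instance built in construct_hub()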
255 279 class Hub(LoggingFactory):
256 280 """The IPython Controller Hub with 0MQ connections
257 281
258 282 Parameters
259 283 ==========
260 284 loop: zmq IOLoop instance
261 285 session: StreamSession object
262 286 <removed> context: zmq context for creating new connections (?)
263 287 queue: ZMQStream for monitoring the command queue (SUB)
264 288 query: ZMQStream for engine registration and client queries requests (XREP)
265 289 heartbeat: HeartMonitor object checking the pulse of the engines
266 290 notifier: ZMQStream for broadcasting engine registration changes (PUB)
267 291 db: connection to db for out of memory logging of commands
268 292 NotImplemented
269 293 engine_info: dict of zmq connection information for engines to connect
270 294 to the queues.
271 295 client_info: dict of zmq connection information for clients to connect
272 296 to the queues.
273 297 """
274 298 # internal data structures:
275 299 ids=Set() # engine IDs
276 300 keytable=Dict()
277 301 by_ident=Dict()
278 302 engines=Dict()
279 303 clients=Dict()
280 304 hearts=Dict()
281 305 pending=Set()
282 306 queues=Dict() # pending msg_ids keyed by engine_id
283 307 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
284 308 completed=Dict() # completed msg_ids keyed by engine_id
285 309 all_completed=Set() # completed msg_ids keyed by engine_id
310 dead_engines=Set() # completed msg_ids keyed by engine_id
286 311 # mia=None
287 312 incoming_registrations=Dict()
288 313 registration_timeout=Int()
289 314 _idcounter=Int(0)
290 315
291 316 # objects from constructor:
292 317 loop=Instance(ioloop.IOLoop)
293 318 query=Instance(ZMQStream)
294 319 monitor=Instance(ZMQStream)
295 320 heartmonitor=Instance(HeartMonitor)
296 321 notifier=Instance(ZMQStream)
297 322 db=Instance(object)
298 323 client_info=Dict()
299 324 engine_info=Dict()
300 325
301 326
302 327 def __init__(self, **kwargs):
303 328 """
304 329 # universal:
305 330 loop: IOLoop for creating future connections
306 331 session: streamsession for sending serialized data
307 332 # engine:
308 333 queue: ZMQStream for monitoring queue messages
309 334 query: ZMQStream for engine+client registration and client requests
310 335 heartbeat: HeartMonitor object for tracking engines
311 336 # extra:
312 337 db: ZMQStream for db connection (NotImplemented)
313 338 engine_info: zmq address/protocol dict for engine connections
314 339 client_info: zmq address/protocol dict for client connections
315 340 """
316 341
317 342 super(Hub, self).__init__(**kwargs)
318 343 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
319 344
320 345 # validate connection dicts:
321 346 for k,v in self.client_info.iteritems():
322 347 if k == 'task':
323 348 validate_url_container(v[1])
324 349 else:
325 350 validate_url_container(v)
326 351 # validate_url_container(self.client_info)
327 352 validate_url_container(self.engine_info)
328 353
329 354 # register our callbacks
330 355 self.query.on_recv(self.dispatch_query)
331 356 self.monitor.on_recv(self.dispatch_monitor_traffic)
332 357
333 358 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
334 359 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
335 360
336 361 self.monitor_handlers = { 'in' : self.save_queue_request,
337 362 'out': self.save_queue_result,
338 363 'intask': self.save_task_request,
339 364 'outtask': self.save_task_result,
340 365 'tracktask': self.save_task_destination,
341 366 'incontrol': _passer,
342 367 'outcontrol': _passer,
343 368 'iopub': self.save_iopub_message,
344 369 }
345 370
346 371 self.query_handlers = {'queue_request': self.queue_status,
347 372 'result_request': self.get_results,
348 373 'purge_request': self.purge_results,
349 374 'load_request': self.check_load,
350 375 'resubmit_request': self.resubmit_task,
351 376 'shutdown_request': self.shutdown_request,
352 377 'registration_request' : self.register_engine,
353 378 'unregistration_request' : self.unregister_engine,
354 379 'connection_request': self.connection_request,
355 380 }
356 381
357 382 self.log.info("hub::created hub")
358 383
359 384 @property
360 385 def _next_id(self):
361 386 """gemerate a new ID.
362 387
363 388 No longer reuse old ids, just count from 0."""
364 389 newid = self._idcounter
365 390 self._idcounter += 1
366 391 return newid
367 392 # newid = 0
368 393 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
369 394 # # print newid, self.ids, self.incoming_registrations
370 395 # while newid in self.ids or newid in incoming:
371 396 # newid += 1
372 397 # return newid
373 398
374 399 #-----------------------------------------------------------------------------
375 400 # message validation
376 401 #-----------------------------------------------------------------------------
377 402
378 403 def _validate_targets(self, targets):
379 404 """turn any valid targets argument into a list of integer ids"""
380 405 if targets is None:
381 406 # default to all
382 407 targets = self.ids
383 408
384 409 if isinstance(targets, (int,str,unicode)):
385 410 # only one target specified
386 411 targets = [targets]
387 412 _targets = []
388 413 for t in targets:
389 414 # map raw identities to ids
390 415 if isinstance(t, (str,unicode)):
391 416 t = self.by_ident.get(t, t)
392 417 _targets.append(t)
393 418 targets = _targets
394 419 bad_targets = [ t for t in targets if t not in self.ids ]
395 420 if bad_targets:
396 421 raise IndexError("No Such Engine: %r"%bad_targets)
397 422 if not targets:
398 423 raise IndexError("No Engines Registered")
399 424 return targets
400 425
401 426 #-----------------------------------------------------------------------------
402 427 # dispatch methods (1 per stream)
403 428 #-----------------------------------------------------------------------------
404 429
405 430 # def dispatch_registration_request(self, msg):
406 431 # """"""
407 432 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
408 433 # idents,msg = self.session.feed_identities(msg)
409 434 # if not idents:
410 435 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
411 436 # return
412 437 # try:
413 438 # msg = self.session.unpack_message(msg,content=True)
414 439 # except:
415 440 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
416 441 # return
417 442 #
418 443 # msg_type = msg['msg_type']
419 444 # content = msg['content']
420 445 #
421 446 # handler = self.query_handlers.get(msg_type, None)
422 447 # if handler is None:
423 448 # self.log.error("registration::got bad registration message: %s"%msg)
424 449 # else:
425 450 # handler(idents, msg)
426 451
427 452 def dispatch_monitor_traffic(self, msg):
428 453 """all ME and Task queue messages come through here, as well as
429 454 IOPub traffic."""
430 455 self.log.debug("monitor traffic: %s"%msg[:2])
431 456 switch = msg[0]
432 457 idents, msg = self.session.feed_identities(msg[1:])
433 458 if not idents:
434 459 self.log.error("Bad Monitor Message: %s"%msg)
435 460 return
436 461 handler = self.monitor_handlers.get(switch, None)
437 462 if handler is not None:
438 463 handler(idents, msg)
439 464 else:
440 465 self.log.error("Invalid monitor topic: %s"%switch)
441 466
442 467
443 468 def dispatch_query(self, msg):
444 469 """Route registration requests and queries from clients."""
445 470 idents, msg = self.session.feed_identities(msg)
446 471 if not idents:
447 472 self.log.error("Bad Query Message: %s"%msg)
448 473 return
449 474 client_id = idents[0]
450 475 try:
451 476 msg = self.session.unpack_message(msg, content=True)
452 477 except:
453 478 content = error.wrap_exception()
454 479 self.log.error("Bad Query Message: %s"%msg, exc_info=True)
455 480 self.session.send(self.query, "hub_error", ident=client_id,
456 481 content=content)
457 482 return
458 483
459 484 # print client_id, header, parent, content
460 485 #switch on message type:
461 486 msg_type = msg['msg_type']
462 487 self.log.info("client::client %s requested %s"%(client_id, msg_type))
463 488 handler = self.query_handlers.get(msg_type, None)
464 489 try:
465 490 assert handler is not None, "Bad Message Type: %s"%msg_type
466 491 except:
467 492 content = error.wrap_exception()
468 493 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
469 494 self.session.send(self.query, "hub_error", ident=client_id,
470 495 content=content)
471 496 return
472 497 else:
473 498 handler(idents, msg)
474 499
475 500 def dispatch_db(self, msg):
476 501 """"""
477 502 raise NotImplementedError
478 503
479 504 #---------------------------------------------------------------------------
480 505 # handler methods (1 per event)
481 506 #---------------------------------------------------------------------------
482 507
483 508 #----------------------- Heartbeat --------------------------------------
484 509
485 510 def handle_new_heart(self, heart):
486 511 """handler to attach to heartbeater.
487 512 Called when a new heart starts to beat.
488 513 Triggers completion of registration."""
489 514 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
490 515 if heart not in self.incoming_registrations:
491 516 self.log.info("heartbeat::ignoring new heart: %r"%heart)
492 517 else:
493 518 self.finish_registration(heart)
494 519
495 520
496 521 def handle_heart_failure(self, heart):
497 522 """handler to attach to heartbeater.
498 523 called when a previously registered heart fails to respond to beat request.
499 524 triggers unregistration"""
500 525 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
501 526 eid = self.hearts.get(heart, None)
502 527 if eid is None:
503 528 self.log.info("heartbeat::ignoring heart failure %r"%heart)
504 529 else:
505 530 queue = self.engines[eid].queue
506 531 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
507 532
508 533 #----------------------- MUX Queue Traffic ------------------------------
509 534
510 535 def save_queue_request(self, idents, msg):
511 536 if len(idents) < 2:
512 537 self.log.error("invalid identity prefix: %s"%idents)
513 538 return
514 539 queue_id, client_id = idents[:2]
515 540 try:
516 541 msg = self.session.unpack_message(msg, content=False)
517 542 except:
518 543 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
519 544 return
520 545
521 546 eid = self.by_ident.get(queue_id, None)
522 547 if eid is None:
523 548 self.log.error("queue::target %r not registered"%queue_id)
524 549 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
525 550 return
526 551
527 552 header = msg['header']
528 553 msg_id = header['msg_id']
529 554 record = init_record(msg)
530 555 record['engine_uuid'] = queue_id
531 556 record['client_uuid'] = client_id
532 557 record['queue'] = 'mux'
533 558
559 try:
560 # it's possible iopub arrived first:
561 existing = self.db.get_record(msg_id)
562 for key,evalue in existing.iteritems():
563 rvalue = record[key]
564 if evalue and rvalue and evalue != rvalue:
565 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
566 elif evalue and not rvalue:
567 record[key] = evalue
568 self.db.update_record(msg_id, record)
569 except KeyError:
570 self.db.add_record(msg_id, record)
571
534 572 self.pending.add(msg_id)
535 573 self.queues[eid].append(msg_id)
536 self.db.add_record(msg_id, record)
537 574
538 575 def save_queue_result(self, idents, msg):
539 576 if len(idents) < 2:
540 577 self.log.error("invalid identity prefix: %s"%idents)
541 578 return
542 579
543 580 client_id, queue_id = idents[:2]
544 581 try:
545 582 msg = self.session.unpack_message(msg, content=False)
546 583 except:
547 584 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
548 585 queue_id,client_id, msg), exc_info=True)
549 586 return
550 587
551 588 eid = self.by_ident.get(queue_id, None)
552 589 if eid is None:
553 590 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
554 self.log.debug("queue:: %s"%msg[2:])
591 # self.log.debug("queue:: %s"%msg[2:])
555 592 return
556 593
557 594 parent = msg['parent_header']
558 595 if not parent:
559 596 return
560 597 msg_id = parent['msg_id']
561 598 if msg_id in self.pending:
562 599 self.pending.remove(msg_id)
563 600 self.all_completed.add(msg_id)
564 601 self.queues[eid].remove(msg_id)
565 602 self.completed[eid].append(msg_id)
566 603 elif msg_id not in self.all_completed:
567 604 # it could be a result from a dead engine that died before delivering the
568 605 # result
569 606 self.log.warn("queue:: unknown msg finished %s"%msg_id)
570 607 return
571 608 # update record anyway, because the unregistration could have been premature
572 609 rheader = msg['header']
573 610 completed = datetime.strptime(rheader['date'], ISO8601)
574 611 started = rheader.get('started', None)
575 612 if started is not None:
576 613 started = datetime.strptime(started, ISO8601)
577 614 result = {
578 615 'result_header' : rheader,
579 616 'result_content': msg['content'],
580 617 'started' : started,
581 618 'completed' : completed
582 619 }
583 620
584 621 result['result_buffers'] = msg['buffers']
585 622 self.db.update_record(msg_id, result)
586 623
587 624
588 625 #--------------------- Task Queue Traffic ------------------------------
589 626
590 627 def save_task_request(self, idents, msg):
591 628 """Save the submission of a task."""
592 629 client_id = idents[0]
593 630
594 631 try:
595 632 msg = self.session.unpack_message(msg, content=False)
596 633 except:
597 634 self.log.error("task::client %r sent invalid task message: %s"%(
598 635 client_id, msg), exc_info=True)
599 636 return
600 637 record = init_record(msg)
601 638
602 639 record['client_uuid'] = client_id
603 640 record['queue'] = 'task'
604 641 header = msg['header']
605 642 msg_id = header['msg_id']
606 643 self.pending.add(msg_id)
607 self.db.add_record(msg_id, record)
644 try:
645 # it's possible iopub arrived first:
646 existing = self.db.get_record(msg_id)
647 for key,evalue in existing.iteritems():
648 rvalue = record[key]
649 if evalue and rvalue and evalue != rvalue:
650 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
651 elif evalue and not rvalue:
652 record[key] = evalue
653 self.db.update_record(msg_id, record)
654 except KeyError:
655 self.db.add_record(msg_id, record)
608 656
609 657 def save_task_result(self, idents, msg):
610 658 """save the result of a completed task."""
611 659 client_id = idents[0]
612 660 try:
613 661 msg = self.session.unpack_message(msg, content=False)
614 662 except:
615 663 self.log.error("task::invalid task result message send to %r: %s"%(
616 664 client_id, msg), exc_info=True)
617 665 raise
618 666 return
619 667
620 668 parent = msg['parent_header']
621 669 if not parent:
622 670 # print msg
623 671 self.log.warn("Task %r had no parent!"%msg)
624 672 return
625 673 msg_id = parent['msg_id']
626 674
627 675 header = msg['header']
628 676 engine_uuid = header.get('engine', None)
629 677 eid = self.by_ident.get(engine_uuid, None)
630 678
631 679 if msg_id in self.pending:
632 680 self.pending.remove(msg_id)
633 681 self.all_completed.add(msg_id)
634 682 if eid is not None:
635 683 self.completed[eid].append(msg_id)
636 684 if msg_id in self.tasks[eid]:
637 685 self.tasks[eid].remove(msg_id)
638 686 completed = datetime.strptime(header['date'], ISO8601)
639 687 started = header.get('started', None)
640 688 if started is not None:
641 689 started = datetime.strptime(started, ISO8601)
642 690 result = {
643 691 'result_header' : header,
644 692 'result_content': msg['content'],
645 693 'started' : started,
646 694 'completed' : completed,
647 695 'engine_uuid': engine_uuid
648 696 }
649 697
650 698 result['result_buffers'] = msg['buffers']
651 699 self.db.update_record(msg_id, result)
652 700
653 701 else:
654 702 self.log.debug("task::unknown task %s finished"%msg_id)
655 703
656 704 def save_task_destination(self, idents, msg):
657 705 try:
658 706 msg = self.session.unpack_message(msg, content=True)
659 707 except:
660 708 self.log.error("task::invalid task tracking message", exc_info=True)
661 709 return
662 710 content = msg['content']
663 711 # print (content)
664 712 msg_id = content['msg_id']
665 713 engine_uuid = content['engine_id']
666 714 eid = self.by_ident[engine_uuid]
667 715
668 716 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
669 717 # if msg_id in self.mia:
670 718 # self.mia.remove(msg_id)
671 719 # else:
672 720 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
673 721
674 722 self.tasks[eid].append(msg_id)
675 723 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
676 724 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
677 725
678 726 def mia_task_request(self, idents, msg):
679 727 raise NotImplementedError
680 728 client_id = idents[0]
681 729 # content = dict(mia=self.mia,status='ok')
682 730 # self.session.send('mia_reply', content=content, idents=client_id)
683 731
684 732
685 733 #--------------------- IOPub Traffic ------------------------------
686 734
687 735 def save_iopub_message(self, topics, msg):
688 736 """save an iopub message into the db"""
689 737 # print (topics)
690 738 try:
691 739 msg = self.session.unpack_message(msg, content=True)
692 740 except:
693 741 self.log.error("iopub::invalid IOPub message", exc_info=True)
694 742 return
695 743
696 744 parent = msg['parent_header']
697 745 if not parent:
698 746 self.log.error("iopub::invalid IOPub message: %s"%msg)
699 747 return
700 748 msg_id = parent['msg_id']
701 749 msg_type = msg['msg_type']
702 750 content = msg['content']
703 751
704 752 # ensure msg_id is in db
705 753 try:
706 754 rec = self.db.get_record(msg_id)
707 except:
708 self.log.error("iopub::IOPub message has invalid parent", exc_info=True)
709 return
755 except KeyError:
756 rec = empty_record()
757 rec['msg_id'] = msg_id
758 self.db.add_record(msg_id, rec)
710 759 # stream
711 760 d = {}
712 761 if msg_type == 'stream':
713 762 name = content['name']
714 763 s = rec[name] or ''
715 764 d[name] = s + content['data']
716 765
717 766 elif msg_type == 'pyerr':
718 767 d['pyerr'] = content
719 768 else:
720 769 d[msg_type] = content['data']
721 770
722 771 self.db.update_record(msg_id, d)
723 772
724 773
725 774
726 775 #-------------------------------------------------------------------------
727 776 # Registration requests
728 777 #-------------------------------------------------------------------------
729 778
730 779 def connection_request(self, client_id, msg):
731 780 """Reply with connection addresses for clients."""
732 781 self.log.info("client::client %s connected"%client_id)
733 782 content = dict(status='ok')
734 783 content.update(self.client_info)
735 784 jsonable = {}
736 785 for k,v in self.keytable.iteritems():
737 jsonable[str(k)] = v
786 if v not in self.dead_engines:
787 jsonable[str(k)] = v
738 788 content['engines'] = jsonable
739 789 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
740 790
741 791 def register_engine(self, reg, msg):
742 792 """Register a new engine."""
743 793 content = msg['content']
744 794 try:
745 795 queue = content['queue']
746 796 except KeyError:
747 797 self.log.error("registration::queue not specified", exc_info=True)
748 798 return
749 799 heart = content.get('heartbeat', None)
750 800 """register a new engine, and create the socket(s) necessary"""
751 801 eid = self._next_id
752 802 # print (eid, queue, reg, heart)
753 803
754 804 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
755 805
756 806 content = dict(id=eid,status='ok')
757 807 content.update(self.engine_info)
758 808 # check if requesting available IDs:
759 809 if queue in self.by_ident:
760 810 try:
761 811 raise KeyError("queue_id %r in use"%queue)
762 812 except:
763 813 content = error.wrap_exception()
764 814 self.log.error("queue_id %r in use"%queue, exc_info=True)
765 815 elif heart in self.hearts: # need to check unique hearts?
766 816 try:
767 817 raise KeyError("heart_id %r in use"%heart)
768 818 except:
769 819 self.log.error("heart_id %r in use"%heart, exc_info=True)
770 820 content = error.wrap_exception()
771 821 else:
772 822 for h, pack in self.incoming_registrations.iteritems():
773 823 if heart == h:
774 824 try:
775 825 raise KeyError("heart_id %r in use"%heart)
776 826 except:
777 827 self.log.error("heart_id %r in use"%heart, exc_info=True)
778 828 content = error.wrap_exception()
779 829 break
780 830 elif queue == pack[1]:
781 831 try:
782 832 raise KeyError("queue_id %r in use"%queue)
783 833 except:
784 834 self.log.error("queue_id %r in use"%queue, exc_info=True)
785 835 content = error.wrap_exception()
786 836 break
787 837
788 838 msg = self.session.send(self.query, "registration_reply",
789 839 content=content,
790 840 ident=reg)
791 841
792 842 if content['status'] == 'ok':
793 843 if heart in self.heartmonitor.hearts:
794 844 # already beating
795 845 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
796 846 self.finish_registration(heart)
797 847 else:
798 848 purge = lambda : self._purge_stalled_registration(heart)
799 849 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
800 850 dc.start()
801 851 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
802 852 else:
803 853 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
804 854 return eid
805 855
806 856 def unregister_engine(self, ident, msg):
807 857 """Unregister an engine that explicitly requested to leave."""
808 858 try:
809 859 eid = msg['content']['id']
810 860 except:
811 861 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
812 862 return
813 863 self.log.info("registration::unregister_engine(%s)"%eid)
814 864 # print (eid)
815 content=dict(id=eid, queue=self.engines[eid].queue)
816 self.ids.remove(eid)
817 uuid = self.keytable.pop(eid)
818 ec = self.engines.pop(eid)
819 self.hearts.pop(ec.heartbeat)
820 self.by_ident.pop(ec.queue)
821 self.completed.pop(eid)
822 self._handle_stranded_msgs(eid, uuid)
823 ############## TODO: HANDLE IT ################
865 uuid = self.keytable[eid]
866 content=dict(id=eid, queue=uuid)
867 self.dead_engines.add(uuid)
868 # self.ids.remove(eid)
869 # uuid = self.keytable.pop(eid)
870 #
871 # ec = self.engines.pop(eid)
872 # self.hearts.pop(ec.heartbeat)
873 # self.by_ident.pop(ec.queue)
874 # self.completed.pop(eid)
875 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
876 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
877 dc.start()
878 ############## TODO: HANDLE IT ################
824 879
825 880 if self.notifier:
826 881 self.session.send(self.notifier, "unregistration_notification", content=content)
827 882
828 883 def _handle_stranded_msgs(self, eid, uuid):
829 884 """Handle messages known to be on an engine when the engine unregisters.
830 885
831 886 It is possible that this will fire prematurely - that is, an engine will
832 887 go down after completing a result, and the client will be notified
833 888 that the result failed and later receive the actual result.
834 889 """
835 890
836 outstanding = self.queues.pop(eid)
891 outstanding = self.queues[eid]
837 892
838 893 for msg_id in outstanding:
839 894 self.pending.remove(msg_id)
840 895 self.all_completed.add(msg_id)
841 896 try:
842 897 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
843 898 except:
844 899 content = error.wrap_exception()
845 900 # build a fake header:
846 901 header = {}
847 902 header['engine'] = uuid
848 903 header['date'] = datetime.now().strftime(ISO8601)
849 904 rec = dict(result_content=content, result_header=header, result_buffers=[])
850 905 rec['completed'] = header['date']
851 906 rec['engine_uuid'] = uuid
852 907 self.db.update_record(msg_id, rec)
853 908
854 909 def finish_registration(self, heart):
855 910 """Second half of engine registration, called after our HeartMonitor
856 911 has received a beat from the Engine's Heart."""
857 912 try:
858 913 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
859 914 except KeyError:
860 915 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
861 916 return
862 917 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
863 918 if purge is not None:
864 919 purge.stop()
865 920 control = queue
866 921 self.ids.add(eid)
867 922 self.keytable[eid] = queue
868 923 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
869 924 control=control, heartbeat=heart)
870 925 self.by_ident[queue] = eid
871 926 self.queues[eid] = list()
872 927 self.tasks[eid] = list()
873 928 self.completed[eid] = list()
874 929 self.hearts[heart] = eid
875 930 content = dict(id=eid, queue=self.engines[eid].queue)
876 931 if self.notifier:
877 932 self.session.send(self.notifier, "registration_notification", content=content)
878 933 self.log.info("engine::Engine Connected: %i"%eid)
879 934
880 935 def _purge_stalled_registration(self, heart):
881 936 if heart in self.incoming_registrations:
882 937 eid = self.incoming_registrations.pop(heart)[0]
883 938 self.log.info("registration::purging stalled registration: %i"%eid)
884 939 else:
885 940 pass
886 941
887 942 #-------------------------------------------------------------------------
888 943 # Client Requests
889 944 #-------------------------------------------------------------------------
890 945
891 946 def shutdown_request(self, client_id, msg):
892 947 """handle shutdown request."""
893 948 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
894 949 # also notify other clients of shutdown
895 950 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
896 951 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
897 952 dc.start()
898 953
899 954 def _shutdown(self):
900 955 self.log.info("hub::hub shutting down.")
901 956 time.sleep(0.1)
902 957 sys.exit(0)
903 958
904 959
905 960 def check_load(self, client_id, msg):
906 961 content = msg['content']
907 962 try:
908 963 targets = content['targets']
909 964 targets = self._validate_targets(targets)
910 965 except:
911 966 content = error.wrap_exception()
912 967 self.session.send(self.query, "hub_error",
913 968 content=content, ident=client_id)
914 969 return
915 970
916 971 content = dict(status='ok')
917 972 # loads = {}
918 973 for t in targets:
919 974 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
920 975 self.session.send(self.query, "load_reply", content=content, ident=client_id)
921 976
922 977
923 978 def queue_status(self, client_id, msg):
924 979 """Return the Queue status of one or more targets.
925 980 if verbose: return the msg_ids
926 981 else: return len of each type.
927 982 keys: queue (pending MUX jobs)
928 983 tasks (pending Task jobs)
929 984 completed (finished jobs from both queues)"""
930 985 content = msg['content']
931 986 targets = content['targets']
932 987 try:
933 988 targets = self._validate_targets(targets)
934 989 except:
935 990 content = error.wrap_exception()
936 991 self.session.send(self.query, "hub_error",
937 992 content=content, ident=client_id)
938 993 return
939 994 verbose = content.get('verbose', False)
940 995 content = dict(status='ok')
941 996 for t in targets:
942 997 queue = self.queues[t]
943 998 completed = self.completed[t]
944 999 tasks = self.tasks[t]
945 1000 if not verbose:
946 1001 queue = len(queue)
947 1002 completed = len(completed)
948 1003 tasks = len(tasks)
949 1004 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
950 1005 # pending
951 1006 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
952 1007
953 1008 def purge_results(self, client_id, msg):
954 1009 """Purge results from memory. This method is more valuable before we move
955 1010 to a DB based message storage mechanism."""
956 1011 content = msg['content']
957 1012 msg_ids = content.get('msg_ids', [])
958 1013 reply = dict(status='ok')
959 1014 if msg_ids == 'all':
960 1015 self.db.drop_matching_records(dict(completed={'$ne':None}))
961 1016 else:
962 1017 for msg_id in msg_ids:
963 1018 if msg_id in self.all_completed:
964 1019 self.db.drop_record(msg_id)
965 1020 else:
966 1021 if msg_id in self.pending:
967 1022 try:
968 1023 raise IndexError("msg pending: %r"%msg_id)
969 1024 except:
970 1025 reply = error.wrap_exception()
971 1026 else:
972 1027 try:
973 1028 raise IndexError("No such msg: %r"%msg_id)
974 1029 except:
975 1030 reply = error.wrap_exception()
976 1031 break
977 1032 eids = content.get('engine_ids', [])
978 1033 for eid in eids:
979 1034 if eid not in self.engines:
980 1035 try:
981 1036 raise IndexError("No such engine: %i"%eid)
982 1037 except:
983 1038 reply = error.wrap_exception()
984 1039 break
985 1040 msg_ids = self.completed.pop(eid)
986 1041 uid = self.engines[eid].queue
987 1042 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
988 1043
989 1044 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
990 1045
991 1046 def resubmit_task(self, client_id, msg, buffers):
992 1047 """Resubmit a task."""
993 1048 raise NotImplementedError
994 1049
995 1050 def get_results(self, client_id, msg):
996 1051 """Get the result of 1 or more messages."""
997 1052 content = msg['content']
998 1053 msg_ids = sorted(set(content['msg_ids']))
999 1054 statusonly = content.get('status_only', False)
1000 1055 pending = []
1001 1056 completed = []
1002 1057 content = dict(status='ok')
1003 1058 content['pending'] = pending
1004 1059 content['completed'] = completed
1005 1060 buffers = []
1006 1061 if not statusonly:
1007 1062 content['results'] = {}
1008 1063 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1009 1064 for msg_id in msg_ids:
1010 1065 if msg_id in self.pending:
1011 1066 pending.append(msg_id)
1012 1067 elif msg_id in self.all_completed:
1013 1068 completed.append(msg_id)
1014 1069 if not statusonly:
1015 1070 rec = records[msg_id]
1016 1071 io_dict = {}
1017 1072 for key in 'pyin pyout pyerr stdout stderr'.split():
1018 1073 io_dict[key] = rec[key]
1019 1074 content[msg_id] = { 'result_content': rec['result_content'],
1020 1075 'header': rec['header'],
1021 1076 'result_header' : rec['result_header'],
1022 1077 'io' : io_dict,
1023 1078 }
1024 1079 if rec['result_buffers']:
1025 1080 buffers.extend(map(str, rec['result_buffers']))
1026 1081 else:
1027 1082 try:
1028 1083 raise KeyError('No such message: '+msg_id)
1029 1084 except:
1030 1085 content = error.wrap_exception()
1031 1086 break
1032 1087 self.session.send(self.query, "result_reply", content=content,
1033 1088 parent=msg, ident=client_id,
1034 1089 buffers=buffers)
1035 1090
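A notable change in the hub above is that save_queue_request and save_task_request no longer call db.add_record unconditionally: an IOPub message for the same msg_id may have arrived first (creating a record via empty_record), so the new request is merged into any existing record instead. Below is a minimal standalone sketch of that merge-or-insert pattern, assuming a db object with the get_record/add_record/update_record methods and KeyError-on-miss behaviour used above; the helper name save_request_record is illustrative, not part of the module.

def save_request_record(db, msg_id, record, log):
    """Insert record, or merge it into a record that iopub created first."""
    try:
        existing = db.get_record(msg_id)       # raises KeyError if msg_id is unknown
    except KeyError:
        db.add_record(msg_id, record)          # first sighting of this msg_id
        return
    for key, evalue in existing.iteritems():
        rvalue = record[key]
        if evalue and rvalue and evalue != rvalue:
            log.error("conflicting initial state for record: %s:%s <> %s" % (msg_id, rvalue, evalue))
        elif evalue and not rvalue:
            record[key] = evalue               # keep whichever value arrived first
    db.update_record(msg_id, record)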
@@ -1,971 +1,971 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 Facilities for launching IPython processes asynchronously.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23
24 24 from signal import SIGINT, SIGTERM
25 25 try:
26 26 from signal import SIGKILL
27 27 except ImportError:
28 28 SIGKILL=SIGTERM
29 29
30 30 from subprocess import Popen, PIPE, STDOUT
31 31 try:
32 32 from subprocess import check_output
33 33 except ImportError:
34 34 # pre-2.7, define check_output with Popen
35 35 def check_output(*args, **kwargs):
36 36 kwargs.update(dict(stdout=PIPE))
37 37 p = Popen(*args, **kwargs)
38 38 out,err = p.communicate()
39 39 return out
40 40
41 41 from zmq.eventloop import ioloop
42 42
43 43 from IPython.external import Itpl
44 44 # from IPython.config.configurable import Configurable
45 45 from IPython.utils.traitlets import Any, Str, Int, List, Unicode, Dict, Instance, CUnicode
46 46 from IPython.utils.path import get_ipython_module_path
47 47 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
48 48
49 49 from .factory import LoggingFactory
50 50
51 51 # load winhpcjob only on Windows
52 52 try:
53 53 from .winhpcjob import (
54 54 IPControllerTask, IPEngineTask,
55 55 IPControllerJob, IPEngineSetJob
56 56 )
57 57 except ImportError:
58 58 pass
59 59
60 60
61 61 #-----------------------------------------------------------------------------
62 62 # Paths to the kernel apps
63 63 #-----------------------------------------------------------------------------
64 64
65 65
66 66 ipclusterz_cmd_argv = pycmd2argv(get_ipython_module_path(
67 67 'IPython.parallel.ipclusterapp'
68 68 ))
69 69
70 70 ipenginez_cmd_argv = pycmd2argv(get_ipython_module_path(
71 71 'IPython.parallel.ipengineapp'
72 72 ))
73 73
74 74 ipcontrollerz_cmd_argv = pycmd2argv(get_ipython_module_path(
75 75 'IPython.parallel.ipcontrollerapp'
76 76 ))
77 77
78 78 #-----------------------------------------------------------------------------
79 79 # Base launchers and errors
80 80 #-----------------------------------------------------------------------------
81 81
82 82
83 83 class LauncherError(Exception):
84 84 pass
85 85
86 86
87 87 class ProcessStateError(LauncherError):
88 88 pass
89 89
90 90
91 91 class UnknownStatus(LauncherError):
92 92 pass
93 93
94 94
95 95 class BaseLauncher(LoggingFactory):
96 96 """An asbtraction for starting, stopping and signaling a process."""
97 97
98 98 # In all of the launchers, the work_dir is where child processes will be
99 99 # run. This will usually be the cluster_dir, but may not be. Any work_dir
100 100 # passed into the __init__ method will override the config value.
101 101 # This should not be used to set the work_dir for the actual engine
102 102 # and controller. Instead, use their own config files or the
103 103 # controller_args, engine_args attributes of the launchers to add
104 104 # the --work-dir option.
105 105 work_dir = Unicode(u'.')
106 106 loop = Instance('zmq.eventloop.ioloop.IOLoop')
107 107
108 108 start_data = Any()
109 109 stop_data = Any()
110 110
111 111 def _loop_default(self):
112 112 return ioloop.IOLoop.instance()
113 113
114 114 def __init__(self, work_dir=u'.', config=None, **kwargs):
115 115 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
116 116 self.state = 'before' # can be before, running, after
117 117 self.stop_callbacks = []
118 118 self.start_data = None
119 119 self.stop_data = None
120 120
121 121 @property
122 122 def args(self):
123 123 """A list of cmd and args that will be used to start the process.
124 124
125 125 This is what is passed to :func:`spawnProcess` and the first element
126 126 will be the process name.
127 127 """
128 128 return self.find_args()
129 129
130 130 def find_args(self):
131 131 """The ``.args`` property calls this to find the args list.
132 132
133 133 Subclasses should implement this to construct the cmd and args.
134 134 """
135 135 raise NotImplementedError('find_args must be implemented in a subclass')
136 136
137 137 @property
138 138 def arg_str(self):
139 139 """The string form of the program arguments."""
140 140 return ' '.join(self.args)
141 141
142 142 @property
143 143 def running(self):
144 144 """Am I running."""
145 145 if self.state == 'running':
146 146 return True
147 147 else:
148 148 return False
149 149
150 150 def start(self):
151 151 """Start the process.
152 152
153 153 This must return a deferred that fires with information about the
154 154 process starting (like a pid, job id, etc.).
155 155 """
156 156 raise NotImplementedError('start must be implemented in a subclass')
157 157
158 158 def stop(self):
159 159 """Stop the process and notify observers of stopping.
160 160
161 161 This must return a deferred that fires with information about the
162 162 processing stopping, like errors that occur while the process is
163 163 attempting to be shut down. This deferred won't fire when the process
164 164 actually stops. To observe the actual process stopping, see
165 165 :func:`observe_stop`.
166 166 """
167 167 raise NotImplementedError('stop must be implemented in a subclass')
168 168
169 169 def on_stop(self, f):
170 170 """Get a deferred that will fire when the process stops.
171 171
172 172 The deferred will fire with data that contains information about
173 173 the exit status of the process.
174 174 """
175 175 if self.state=='after':
176 176 return f(self.stop_data)
177 177 else:
178 178 self.stop_callbacks.append(f)
179 179
180 180 def notify_start(self, data):
181 181 """Call this to trigger startup actions.
182 182
183 183 This logs the process startup and sets the state to 'running'. It is
184 184 a pass-through so it can be used as a callback.
185 185 """
186 186
187 187 self.log.info('Process %r started: %r' % (self.args[0], data))
188 188 self.start_data = data
189 189 self.state = 'running'
190 190 return data
191 191
192 192 def notify_stop(self, data):
193 193 """Call this to trigger process stop actions.
194 194
195 195 This logs the process stopping and sets the state to 'after'. Call
196 196 this to trigger all the deferreds from :func:`observe_stop`."""
197 197
198 198 self.log.info('Process %r stopped: %r' % (self.args[0], data))
199 199 self.stop_data = data
200 200 self.state = 'after'
201 201 for i in range(len(self.stop_callbacks)):
202 202 d = self.stop_callbacks.pop()
203 203 d(data)
204 204 return data
205 205
206 206 def signal(self, sig):
207 207 """Signal the process.
208 208
209 209 Return a semi-meaningless deferred after signaling the process.
210 210
211 211 Parameters
212 212 ----------
213 213 sig : str or int
214 214 'KILL', 'INT', etc., or any signal number
215 215 """
216 216 raise NotImplementedError('signal must be implemented in a subclass')
217 217
218 218
219 219 #-----------------------------------------------------------------------------
220 220 # Local process launchers
221 221 #-----------------------------------------------------------------------------
222 222
223 223
224 224 class LocalProcessLauncher(BaseLauncher):
225 225 """Start and stop an external process in an asynchronous manner.
226 226
227 227 This will launch the external process with a working directory of
228 228 ``self.work_dir``.
229 229 """
230 230
231 231 # This is used to construct self.args, which is passed to
232 232 # spawnProcess.
233 233 cmd_and_args = List([])
234 234 poll_frequency = Int(100) # in ms
235 235
236 236 def __init__(self, work_dir=u'.', config=None, **kwargs):
237 237 super(LocalProcessLauncher, self).__init__(
238 238 work_dir=work_dir, config=config, **kwargs
239 239 )
240 240 self.process = None
241 241 self.start_deferred = None
242 242 self.poller = None
243 243
244 244 def find_args(self):
245 245 return self.cmd_and_args
246 246
247 247 def start(self):
248 248 if self.state == 'before':
249 249 self.process = Popen(self.args,
250 250 stdout=PIPE,stderr=PIPE,stdin=PIPE,
251 251 env=os.environ,
252 252 cwd=self.work_dir
253 253 )
254 254
255 255 self.loop.add_handler(self.process.stdout.fileno(), self.handle_stdout, self.loop.READ)
256 256 self.loop.add_handler(self.process.stderr.fileno(), self.handle_stderr, self.loop.READ)
257 257 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
258 258 self.poller.start()
259 259 self.notify_start(self.process.pid)
260 260 else:
261 261 s = 'The process was already started and has state: %r' % self.state
262 262 raise ProcessStateError(s)
263 263
264 264 def stop(self):
265 265 return self.interrupt_then_kill()
266 266
267 267 def signal(self, sig):
268 268 if self.state == 'running':
269 269 self.process.send_signal(sig)
270 270
271 271 def interrupt_then_kill(self, delay=2.0):
272 272 """Send INT, wait a delay and then send KILL."""
273 273 self.signal(SIGINT)
274 274 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
275 275 self.killer.start()
276 276
277 277 # callbacks, etc:
278 278
279 279 def handle_stdout(self, fd, events):
280 280 line = self.process.stdout.readline()
281 281 # a stopped process will be readable but return empty strings
282 282 if line:
283 283 self.log.info(line[:-1])
284 284 else:
285 285 self.poll()
286 286
287 287 def handle_stderr(self, fd, events):
288 288 line = self.process.stderr.readline()
289 289 # a stopped process will be readable but return empty strings
290 290 if line:
291 291 self.log.error(line[:-1])
292 292 else:
293 293 self.poll()
294 294
295 295 def poll(self):
296 296 status = self.process.poll()
297 297 if status is not None:
298 298 self.poller.stop()
299 299 self.loop.remove_handler(self.process.stdout.fileno())
300 300 self.loop.remove_handler(self.process.stderr.fileno())
301 301 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
302 302 return status
303 303
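# A minimal usage sketch for the LocalProcessLauncher above, assuming an IOLoop
# is already running in this process; the ['sleep', '10'] command and the
# report_exit callback are illustrative, not part of this module.
def example_launch():
    launcher = LocalProcessLauncher(work_dir=u'.')
    launcher.cmd_and_args = ['sleep', '10']   # find_args() simply returns this list
    def report_exit(data):
        # data is the dict passed to notify_stop(), e.g. dict(exit_code=0, pid=1234)
        logging.info("process exited: %r", data)
    launcher.on_stop(report_exit)    # fires once poll() sees the process exit
    launcher.start()                 # spawns the process and begins watching its output
    return launcher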
304 304 class LocalControllerLauncher(LocalProcessLauncher):
305 305 """Launch a controller as a regular external process."""
306 306
307 307 controller_cmd = List(ipcontrollerz_cmd_argv, config=True)
308 308 # Command line arguments to ipcontroller.
309 309 controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)
310 310
311 311 def find_args(self):
312 312 return self.controller_cmd + self.controller_args
313 313
314 314 def start(self, cluster_dir):
315 315 """Start the controller by cluster_dir."""
316 316 self.controller_args.extend(['--cluster-dir', cluster_dir])
317 317 self.cluster_dir = unicode(cluster_dir)
318 318 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
319 319 return super(LocalControllerLauncher, self).start()
320 320
321 321
322 322 class LocalEngineLauncher(LocalProcessLauncher):
323 323 """Launch a single engine as a regular externall process."""
324 324
325 325 engine_cmd = List(ipenginez_cmd_argv, config=True)
326 326 # Command line arguments for ipengine.
327 327 engine_args = List(
328 328 ['--log-to-file','--log-level', str(logging.INFO)], config=True
329 329 )
330 330
331 331 def find_args(self):
332 332 return self.engine_cmd + self.engine_args
333 333
334 334 def start(self, cluster_dir):
335 335 """Start the engine by cluster_dir."""
336 336 self.engine_args.extend(['--cluster-dir', cluster_dir])
337 337 self.cluster_dir = unicode(cluster_dir)
338 338 return super(LocalEngineLauncher, self).start()
339 339
340 340
341 341 class LocalEngineSetLauncher(BaseLauncher):
342 342 """Launch a set of engines as regular external processes."""
343 343
344 344 # Command line arguments for ipengine.
345 345 engine_args = List(
346 346 ['--log-to-file','--log-level', str(logging.INFO)], config=True
347 347 )
348 348 # launcher class
349 349 launcher_class = LocalEngineLauncher
350 350
351 351 launchers = Dict()
352 352 stop_data = Dict()
353 353
354 354 def __init__(self, work_dir=u'.', config=None, **kwargs):
355 355 super(LocalEngineSetLauncher, self).__init__(
356 356 work_dir=work_dir, config=config, **kwargs
357 357 )
358 358 self.stop_data = {}
359 359
360 360 def start(self, n, cluster_dir):
361 361 """Start n engines by profile or cluster_dir."""
362 362 self.cluster_dir = unicode(cluster_dir)
363 363 dlist = []
364 364 for i in range(n):
365 365 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
366 366 # Copy the engine args over to each engine launcher.
367 367 el.engine_args = copy.deepcopy(self.engine_args)
368 368 el.on_stop(self._notice_engine_stopped)
369 369 d = el.start(cluster_dir)
370 370 if i==0:
371 371 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
372 372 self.launchers[i] = el
373 373 dlist.append(d)
374 374 self.notify_start(dlist)
375 375 # The consumeErrors here could be dangerous
376 376 # dfinal = gatherBoth(dlist, consumeErrors=True)
377 377 # dfinal.addCallback(self.notify_start)
378 378 return dlist
379 379
380 380 def find_args(self):
381 381 return ['engine set']
382 382
383 383 def signal(self, sig):
384 384 dlist = []
385 385 for el in self.launchers.itervalues():
386 386 d = el.signal(sig)
387 387 dlist.append(d)
388 388 # dfinal = gatherBoth(dlist, consumeErrors=True)
389 389 return dlist
390 390
391 391 def interrupt_then_kill(self, delay=1.0):
392 392 dlist = []
393 393 for el in self.launchers.itervalues():
394 394 d = el.interrupt_then_kill(delay)
395 395 dlist.append(d)
396 396 # dfinal = gatherBoth(dlist, consumeErrors=True)
397 397 return dlist
398 398
399 399 def stop(self):
400 400 return self.interrupt_then_kill()
401 401
402 402 def _notice_engine_stopped(self, data):
403 403 pid = data['pid']
404 404 for idx,el in self.launchers.iteritems():
405 405 if el.process.pid == pid:
406 406 break
407 407 self.launchers.pop(idx)
408 408 self.stop_data[idx] = data
409 409 if not self.launchers:
410 410 self.notify_stop(self.stop_data)
411 411
412 412
413 413 #-----------------------------------------------------------------------------
414 414 # MPIExec launchers
415 415 #-----------------------------------------------------------------------------
416 416
417 417
418 418 class MPIExecLauncher(LocalProcessLauncher):
419 419 """Launch an external process using mpiexec."""
420 420
421 421 # The mpiexec command to use in starting the process.
422 422 mpi_cmd = List(['mpiexec'], config=True)
423 423 # The command line arguments to pass to mpiexec.
424 424 mpi_args = List([], config=True)
425 425 # The program to start using mpiexec.
426 426 program = List(['date'], config=True)
427 427 # The command line argument to the program.
428 428 program_args = List([], config=True)
429 429 # The number of instances of the program to start.
430 430 n = Int(1, config=True)
431 431
432 432 def find_args(self):
433 433 """Build self.args using all the fields."""
434 434 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
435 435 self.program + self.program_args
436 436
437 437 def start(self, n):
438 438 """Start n instances of the program using mpiexec."""
439 439 self.n = n
440 440 return super(MPIExecLauncher, self).start()
441 441
442 442
443 443 class MPIExecControllerLauncher(MPIExecLauncher):
444 444 """Launch a controller using mpiexec."""
445 445
446 446 controller_cmd = List(ipcontrollerz_cmd_argv, config=True)
447 447 # Command line arguments to ipcontroller.
448 448 controller_args = List(['--log-to-file','--log-level', str(logging.INFO)], config=True)
449 449 n = Int(1, config=False)
450 450
451 451 def start(self, cluster_dir):
452 452 """Start the controller by cluster_dir."""
453 453 self.controller_args.extend(['--cluster-dir', cluster_dir])
454 454 self.cluster_dir = unicode(cluster_dir)
455 455 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
456 456 return super(MPIExecControllerLauncher, self).start(1)
457 457
458 458 def find_args(self):
459 459 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
460 460 self.controller_cmd + self.controller_args
461 461
462 462
463 463 class MPIExecEngineSetLauncher(MPIExecLauncher):
464 464
465 465 program = List(ipenginez_cmd_argv, config=True)
466 466 # Command line arguments for ipengine.
467 467 program_args = List(
468 468 ['--log-to-file','--log-level', str(logging.INFO)], config=True
469 469 )
470 470 n = Int(1, config=True)
471 471
472 472 def start(self, n, cluster_dir):
473 473 """Start n engines by profile or cluster_dir."""
474 474 self.program_args.extend(['--cluster-dir', cluster_dir])
475 475 self.cluster_dir = unicode(cluster_dir)
476 476 self.n = n
477 477 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
478 478 return super(MPIExecEngineSetLauncher, self).start(n)
479 479
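
For illustration, a rough sketch of configuring these MPIExec launchers through the normal config system; the config file name and the mpiexec flag below are assumptions, and the settings only take effect once these classes are selected as the cluster's launchers::

    # in the cluster's ipclusterz config file (file name assumed)
    c = get_config()
    c.MPIExecEngineSetLauncher.n = 8                                  # engines per mpiexec call
    c.MPIExecEngineSetLauncher.mpi_args = ['--prefix', '/usr/local']  # illustrative mpiexec flag
    c.MPIExecControllerLauncher.controller_args = ['--log-level', '10']
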
480 480 #-----------------------------------------------------------------------------
481 481 # SSH launchers
482 482 #-----------------------------------------------------------------------------
483 483
484 484 # TODO: Get SSH Launcher working again.
485 485
486 486 class SSHLauncher(LocalProcessLauncher):
487 487 """A minimal launcher for ssh.
488 488
489 489 To be useful this will probably have to be extended to use the ``sshx``
490 490 idea for environment variables. There could be other things this needs
491 491 as well.
492 492 """
493 493
494 494 ssh_cmd = List(['ssh'], config=True)
495 495 ssh_args = List(['-tt'], config=True)
496 496 program = List(['date'], config=True)
497 497 program_args = List([], config=True)
498 498 hostname = CUnicode('', config=True)
499 499 user = CUnicode('', config=True)
500 500 location = CUnicode('')
501 501
502 502 def _hostname_changed(self, name, old, new):
503 503 if self.user:
504 504 self.location = u'%s@%s' % (self.user, new)
505 505 else:
506 506 self.location = new
507 507
508 508 def _user_changed(self, name, old, new):
509 509 self.location = u'%s@%s' % (new, self.hostname)
510 510
511 511 def find_args(self):
512 512 return self.ssh_cmd + self.ssh_args + [self.location] + \
513 513 self.program + self.program_args
514 514
515 515 def start(self, cluster_dir, hostname=None, user=None):
516 516 self.cluster_dir = unicode(cluster_dir)
517 517 if hostname is not None:
518 518 self.hostname = hostname
519 519 if user is not None:
520 520 self.user = user
521 521
522 522 return super(SSHLauncher, self).start()
523 523
524 524 def signal(self, sig):
525 525 if self.state == 'running':
526 526 # send escaped ssh connection-closer
527 527 self.process.stdin.write('~.')
528 528 self.process.stdin.flush()
529 529
530 530
531 531
532 532 class SSHControllerLauncher(SSHLauncher):
533 533
534 534 program = List(ipcontrollerz_cmd_argv, config=True)
535 535 # Command line arguments to ipcontroller.
536 536 program_args = List(['-r', '--log-to-file','--log-level', str(logging.INFO)], config=True)
537 537
538 538
539 539 class SSHEngineLauncher(SSHLauncher):
540 540 program = List(ipenginez_cmd_argv, config=True)
541 541 # Command line arguments for ipengine.
542 542 program_args = List(
543 543 ['--log-to-file','--log-level', str(logging.INFO)], config=True
544 544 )
545 545
546 546 class SSHEngineSetLauncher(LocalEngineSetLauncher):
547 547 launcher_class = SSHEngineLauncher
548 548 engines = Dict(config=True)
549 549
550 550 def start(self, n, cluster_dir):
551 551 """Start engines by profile or cluster_dir.
552 552 `n` is ignored, and the `engines` config property is used instead.
553 553 """
554 554
555 555 self.cluster_dir = unicode(cluster_dir)
556 556 dlist = []
557 557 for host, n in self.engines.iteritems():
558 558 if isinstance(n, (tuple, list)):
559 559 n, args = n
560 560 else:
561 561 args = copy.deepcopy(self.engine_args)
562 562
563 563 if '@' in host:
564 564 user,host = host.split('@',1)
565 565 else:
566 566 user=None
567 567 for i in range(n):
568 568 el = self.launcher_class(work_dir=self.work_dir, config=self.config, logname=self.log.name)
569 569
570 570 # Copy the engine args over to each engine launcher.
571 571
572 572 el.program_args = args
573 573 el.on_stop(self._notice_engine_stopped)
574 574 d = el.start(cluster_dir, user=user, hostname=host)
575 575 if i==0:
576 576 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
577 577 self.launchers[host+str(i)] = el
578 578 dlist.append(d)
579 579 self.notify_start(dlist)
580 580 return dlist
581 581
582 582
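
For reference, a sketch of the `engines` dict this launcher expects; the host names and extra argument list are made up, and in practice the dict is set through the config system (e.g. on ``c.SSHEngineSetLauncher.engines``)::

    # value is either a plain engine count, or a (count, extra-args) pair
    engines = {
        'node1.example.com'     : 2,
        'bob@node2.example.com' : (4, ['--log-level', '10']),
    }
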
583 583
584 584 #-----------------------------------------------------------------------------
585 585 # Windows HPC Server 2008 scheduler launchers
586 586 #-----------------------------------------------------------------------------
587 587
588 588
589 589 # This is only used on Windows.
590 590 def find_job_cmd():
591 591 if os.name=='nt':
592 592 try:
593 593 return find_cmd('job')
594 594 except FindCmdError:
595 595 return 'job'
596 596 else:
597 597 return 'job'
598 598
599 599
600 600 class WindowsHPCLauncher(BaseLauncher):
601 601
602 602 # A regular expression used to get the job id from the output of the
603 603 # submit_command.
604 604 job_id_regexp = Str(r'\d+', config=True)
605 605 # The filename of the instantiated job script.
606 606 job_file_name = CUnicode(u'ipython_job.xml', config=True)
607 607 # The full path to the instantiated job script. This gets made dynamically
608 608 # by combining the work_dir with the job_file_name.
609 609 job_file = CUnicode(u'')
610 610 # The hostname of the scheduler to submit the job to
611 611 scheduler = CUnicode('', config=True)
612 612 job_cmd = CUnicode(find_job_cmd(), config=True)
613 613
614 614 def __init__(self, work_dir=u'.', config=None, **kwargs):
615 615 super(WindowsHPCLauncher, self).__init__(
616 616 work_dir=work_dir, config=config, **kwargs
617 617 )
618 618
619 619 @property
620 620 def job_file(self):
621 621 return os.path.join(self.work_dir, self.job_file_name)
622 622
623 623 def write_job_file(self, n):
624 624 raise NotImplementedError("Implement write_job_file in a subclass.")
625 625
626 626 def find_args(self):
627 627 return [u'job.exe']
628 628
629 629 def parse_job_id(self, output):
630 630 """Take the output of the submit command and return the job id."""
631 631 m = re.search(self.job_id_regexp, output)
632 632 if m is not None:
633 633 job_id = m.group()
634 634 else:
635 635 raise LauncherError("Job id couldn't be determined: %s" % output)
636 636 self.job_id = job_id
637 637 self.log.info('Job started with job id: %r' % job_id)
638 638 return job_id
639 639
640 640 def start(self, n):
641 641 """Start n copies of the process using the Win HPC job scheduler."""
642 642 self.write_job_file(n)
643 643 args = [
644 644 'submit',
645 645 '/jobfile:%s' % self.job_file,
646 646 '/scheduler:%s' % self.scheduler
647 647 ]
648 648 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
649 649 # Twisted will raise DeprecationWarnings if we try to pass unicode to this
650 650 output = check_output([self.job_cmd]+args,
651 651 env=os.environ,
652 652 cwd=self.work_dir,
653 653 stderr=STDOUT
654 654 )
655 655 job_id = self.parse_job_id(output)
656 656 self.notify_start(job_id)
657 657 return job_id
658 658
659 659 def stop(self):
660 660 args = [
661 661 'cancel',
662 662 self.job_id,
663 663 '/scheduler:%s' % self.scheduler
664 664 ]
665 665 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
666 666 try:
667 667 output = check_output([self.job_cmd]+args,
668 668 env=os.environ,
669 669 cwd=self.work_dir,
670 670 stderr=STDOUT
671 671 )
672 672 except:
672 672 output = 'The job already appears to be stopped: %r' % self.job_id
674 674 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
675 675 return output
676 676
677 677
678 678 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
679 679
680 680 job_file_name = CUnicode(u'ipcontroller_job.xml', config=True)
681 681 extra_args = List([], config=False)
682 682
683 683 def write_job_file(self, n):
684 684 job = IPControllerJob(config=self.config)
685 685
686 686 t = IPControllerTask(config=self.config)
687 687 # The task's work directory is *not* the actual work directory of
688 688 # the controller. It is used as the base path for the stdout/stderr
689 689 # files that the scheduler redirects to.
690 690 t.work_directory = self.cluster_dir
691 691 # Add the --cluster-dir argument from self.start().
692 692 t.controller_args.extend(self.extra_args)
693 693 job.add_task(t)
694 694
695 695 self.log.info("Writing job description file: %s" % self.job_file)
696 696 job.write(self.job_file)
697 697
698 698 @property
699 699 def job_file(self):
700 700 return os.path.join(self.cluster_dir, self.job_file_name)
701 701
702 702 def start(self, cluster_dir):
703 703 """Start the controller by cluster_dir."""
704 704 self.extra_args = ['--cluster-dir', cluster_dir]
705 705 self.cluster_dir = unicode(cluster_dir)
706 706 return super(WindowsHPCControllerLauncher, self).start(1)
707 707
708 708
709 709 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
710 710
711 711 job_file_name = CUnicode(u'ipengineset_job.xml', config=True)
712 712 extra_args = List([], config=False)
713 713
714 714 def write_job_file(self, n):
715 715 job = IPEngineSetJob(config=self.config)
716 716
717 717 for i in range(n):
718 718 t = IPEngineTask(config=self.config)
719 719 # The task's work directory is *not* the actual work directory of
720 720 # the engine. It is used as the base path for the stdout/stderr
721 721 # files that the scheduler redirects to.
722 722 t.work_directory = self.cluster_dir
723 723 # Add the --cluster-dir argument from self.start().
724 724 t.engine_args.extend(self.extra_args)
725 725 job.add_task(t)
726 726
727 727 self.log.info("Writing job description file: %s" % self.job_file)
728 728 job.write(self.job_file)
729 729
730 730 @property
731 731 def job_file(self):
732 732 return os.path.join(self.cluster_dir, self.job_file_name)
733 733
734 734 def start(self, n, cluster_dir):
735 735 """Start the controller by cluster_dir."""
736 736 self.extra_args = ['--cluster-dir', cluster_dir]
737 737 self.cluster_dir = unicode(cluster_dir)
738 738 return super(WindowsHPCEngineSetLauncher, self).start(n)
739 739
740 740
741 741 #-----------------------------------------------------------------------------
742 742 # Batch (PBS) system launchers
743 743 #-----------------------------------------------------------------------------
744 744
745 745 class BatchSystemLauncher(BaseLauncher):
746 746 """Launch an external process using a batch system.
747 747
748 748 This class is designed to work with UNIX batch systems like PBS, LSF,
749 749 GridEngine, etc. The overall model is that there are different commands
750 750 like qsub, qdel, etc. that handle the starting and stopping of the process.
751 751
752 752 This class also has the notion of a batch script. The ``batch_template``
753 753 attribute can be set to a string that is a template for the batch script.
754 754 This template is instantiated using Itpl. Thus the template can use
755 755 ${n} for the number of instances. Subclasses can add additional variables
756 756 to the template dict.
757 757 """
758 758
759 759 # Subclasses must fill these in. See PBSEngineSet
760 760 # The name of the command line program used to submit jobs.
761 761 submit_command = List([''], config=True)
762 762 # The name of the command line program used to delete jobs.
763 763 delete_command = List([''], config=True)
764 764 # A regular expression used to get the job id from the output of the
765 765 # submit_command.
766 766 job_id_regexp = CUnicode('', config=True)
767 767 # The string that is the batch script template itself.
768 768 batch_template = CUnicode('', config=True)
769 769 # The file that contains the batch template
770 770 batch_template_file = CUnicode(u'', config=True)
771 771 # The filename of the instantiated batch script.
772 772 batch_file_name = CUnicode(u'batch_script', config=True)
773 773 # The PBS Queue
774 774 queue = CUnicode(u'', config=True)
775 775
776 776 # not configurable, override in subclasses
777 777 # PBS Job Array regex
778 778 job_array_regexp = CUnicode('')
779 779 job_array_template = CUnicode('')
780 780 # PBS Queue regex
781 781 queue_regexp = CUnicode('')
782 782 queue_template = CUnicode('')
783 783 # The default batch template, override in subclasses
784 784 default_template = CUnicode('')
785 785 # The full path to the instantiated batch script.
786 786 batch_file = CUnicode(u'')
787 787 # the format dict used with batch_template:
788 788 context = Dict()
789 789
790 790
791 791 def find_args(self):
792 792 return self.submit_command + [self.batch_file]
793 793
794 794 def __init__(self, work_dir=u'.', config=None, **kwargs):
795 795 super(BatchSystemLauncher, self).__init__(
796 796 work_dir=work_dir, config=config, **kwargs
797 797 )
798 798 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
799 799
800 800 def parse_job_id(self, output):
801 801 """Take the output of the submit command and return the job id."""
802 802 m = re.search(self.job_id_regexp, output)
803 803 if m is not None:
804 804 job_id = m.group()
805 805 else:
806 806 raise LauncherError("Job id couldn't be determined: %s" % output)
807 807 self.job_id = job_id
808 808 self.log.info('Job submitted with job id: %r' % job_id)
809 809 return job_id
810 810
811 811 def write_batch_script(self, n):
812 812 """Instantiate and write the batch script to the work_dir."""
813 813 self.context['n'] = n
814 814 self.context['queue'] = self.queue
815 815 # print self.context
816 816 # first priority is batch_template if set
817 817 if self.batch_template_file and not self.batch_template:
818 818 # second priority is batch_template_file
819 819 with open(self.batch_template_file) as f:
820 820 self.batch_template = f.read()
821 821 if not self.batch_template:
822 822 # third (last) priority is default_template
823 823 self.batch_template = self.default_template
824 824
825 825 regex = re.compile(self.job_array_regexp)
826 826 # print regex.search(self.batch_template)
827 827 if not regex.search(self.batch_template):
828 828 self.log.info("adding job array settings to batch script")
829 829 firstline, rest = self.batch_template.split('\n',1)
830 830 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
831 831
832 832 regex = re.compile(self.queue_regexp)
833 833 # print regex.search(self.batch_template)
834 834 if self.queue and not regex.search(self.batch_template):
835 835 self.log.info("adding PBS queue settings to batch script")
836 836 firstline, rest = self.batch_template.split('\n',1)
837 837 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
838 838
839 839 script_as_string = Itpl.itplns(self.batch_template, self.context)
840 840 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
841 841
842 842 with open(self.batch_file, 'w') as f:
843 843 f.write(script_as_string)
844 844 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
845 845
846 846 def start(self, n, cluster_dir):
847 847 """Start n copies of the process using a batch system."""
848 848 # Here we save cluster_dir in the context so it can be used
849 849 # in the batch script template as ${cluster_dir} (n and queue
850 850 # are filled in by write_batch_script).
851 851 self.context['cluster_dir'] = cluster_dir
852 852 self.cluster_dir = unicode(cluster_dir)
853 853 self.write_batch_script(n)
854 854 output = check_output(self.args, env=os.environ)
855 855
856 856 job_id = self.parse_job_id(output)
857 857 self.notify_start(job_id)
858 858 return job_id
859 859
860 860 def stop(self):
861 861 output = check_output(self.delete_command+[self.job_id], env=os.environ)
862 862 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
863 863 return output
864 864
865 865
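
A minimal sketch of the templating step used by BatchSystemLauncher above; the template text and context values are made up, and the import is written the way this module uses Itpl (exact import path assumed)::

    from IPython.external import Itpl   # interpolation helper used by write_batch_script
    template = u"#PBS -q $queue\n#PBS -t 1-$n\nipenginez --cluster-dir $cluster_dir\n"
    context = dict(n=4, queue=u'short.q', cluster_dir=u'/tmp/demo_cluster')  # made-up values
    print Itpl.itplns(template, context)
    # prints the script with 'short.q', '1-4', and the demo cluster dir substituted
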
866 866 class PBSLauncher(BatchSystemLauncher):
867 867 """A BatchSystemLauncher subclass for PBS."""
868 868
869 869 submit_command = List(['qsub'], config=True)
870 870 delete_command = List(['qdel'], config=True)
871 871 job_id_regexp = CUnicode(r'\d+', config=True)
872 872
873 873 batch_file = CUnicode(u'')
874 874 job_array_regexp = CUnicode('#PBS\W+-t\W+[\w\d\-\$]+')
875 875 job_array_template = CUnicode('#PBS -t 1-$n')
876 876 queue_regexp = CUnicode('#PBS\W+-q\W+\$?\w+')
877 877 queue_template = CUnicode('#PBS -q $queue')
878 878
879 879
880 880 class PBSControllerLauncher(PBSLauncher):
881 881 """Launch a controller using PBS."""
882 882
883 883 batch_file_name = CUnicode(u'pbs_controller', config=True)
884 884 default_template= CUnicode("""#!/bin/sh
885 885 #PBS -V
886 886 #PBS -N ipcontrollerz
887 887 %s --log-to-file --cluster-dir $cluster_dir
888 888 """%(' '.join(ipcontrollerz_cmd_argv)))
889 889
890 890 def start(self, cluster_dir):
891 891 """Start the controller by profile or cluster_dir."""
892 892 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
893 893 return super(PBSControllerLauncher, self).start(1, cluster_dir)
894 894
895 895
896 896 class PBSEngineSetLauncher(PBSLauncher):
897 897 """Launch Engines using PBS"""
898 898 batch_file_name = CUnicode(u'pbs_engines', config=True)
899 899 default_template= CUnicode(u"""#!/bin/sh
900 900 #PBS -V
901 901 #PBS -N ipenginez
902 902 %s --cluster-dir $cluster_dir
903 903 """%(' '.join(ipenginez_cmd_argv)))
904 904
905 905 def start(self, n, cluster_dir):
906 906 """Start n engines by profile or cluster_dir."""
907 self.log.info('Starting %n engines with PBSEngineSetLauncher: %r' % (n, self.args))
907 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
908 908 return super(PBSEngineSetLauncher, self).start(n, cluster_dir)
909 909
910 910 #SGE is very similar to PBS
911 911
912 912 class SGELauncher(PBSLauncher):
913 913 """Sun GridEngine is a PBS clone with slightly different syntax"""
914 914 job_array_regexp = CUnicode('#$$\W+-t\W+[\w\d\-\$]+')
915 915 job_array_template = CUnicode('#$$ -t 1-$n')
916 916 queue_regexp = CUnicode('#$$\W+-q\W+\$?\w+')
917 917 queue_template = CUnicode('#$$ -q $queue')
918 918
919 919 class SGEControllerLauncher(SGELauncher):
920 920 """Launch a controller using SGE."""
921 921
922 922 batch_file_name = CUnicode(u'sge_controller', config=True)
923 923 default_template= CUnicode(u"""#$$ -V
924 924 #$$ -S /bin/sh
925 925 #$$ -N ipcontrollerz
926 926 %s --log-to-file --cluster-dir $cluster_dir
927 927 """%(' '.join(ipcontrollerz_cmd_argv)))
928 928
929 929 def start(self, cluster_dir):
930 930 """Start the controller by profile or cluster_dir."""
931 931 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
932 932 return super(PBSControllerLauncher, self).start(1, cluster_dir)
933 933
934 934 class SGEEngineSetLauncher(SGELauncher):
935 935 """Launch Engines with SGE"""
936 936 batch_file_name = CUnicode(u'sge_engines', config=True)
937 937 default_template = CUnicode("""#$$ -V
938 938 #$$ -S /bin/sh
939 939 #$$ -N ipenginez
940 940 %s --cluster-dir $cluster_dir
941 941 """%(' '.join(ipenginez_cmd_argv)))
942 942
943 943 def start(self, n, cluster_dir):
944 944 """Start n engines by profile or cluster_dir."""
945 self.log.info('Starting %n engines with SGEEngineSetLauncher: %r' % (n, self.args))
945 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
946 946 return super(SGEEngineSetLauncher, self).start(n, cluster_dir)
947 947
948 948
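
As a sketch of overriding the SGE defaults above through the config system; the config file name, queue name, and template body are assumptions (including the bare ``ipenginez`` command, which assumes the script is on $PATH)::

    # in the cluster's ipclusterz config file (file name assumed)
    c = get_config()
    c.SGEEngineSetLauncher.queue = u'all.q'        # illustrative queue name
    # '$$' is Itpl's escape for a literal '$', as in the default templates above
    c.SGEEngineSetLauncher.batch_template = (
        u"#$$ -V\n"
        u"#$$ -S /bin/sh\n"
        u"#$$ -t 1-$n\n"
        u"ipenginez --cluster-dir $cluster_dir\n"
    )
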
949 949 #-----------------------------------------------------------------------------
950 950 # A launcher for ipcluster itself!
951 951 #-----------------------------------------------------------------------------
952 952
953 953
954 954 class IPClusterLauncher(LocalProcessLauncher):
955 955 """Launch the ipcluster program in an external process."""
956 956
957 957 ipcluster_cmd = List(ipclusterz_cmd_argv, config=True)
958 958 # Command line arguments to pass to ipcluster.
959 959 ipcluster_args = List(
960 960 ['--clean-logs', '--log-to-file', '--log-level', str(logging.INFO)], config=True)
961 961 ipcluster_subcommand = Str('start')
962 962 ipcluster_n = Int(2)
963 963
964 964 def find_args(self):
965 965 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
966 966 ['-n', repr(self.ipcluster_n)] + self.ipcluster_args
967 967
968 968 def start(self):
969 969 self.log.info("Starting ipcluster: %r" % self.args)
970 970 return super(IPClusterLauncher, self).start()
971 971
@@ -1,273 +1,284 b''
1 1 """A TaskRecord backend using sqlite3"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 import json
10 10 import os
11 11 import cPickle as pickle
12 12 from datetime import datetime
13 13
14 14 import sqlite3
15 15
16 from zmq.eventloop import ioloop
17
16 18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
17 19 from .dictdb import BaseDB
18 20 from .util import ISO8601
19 21
20 22 #-----------------------------------------------------------------------------
21 23 # SQLite operators, adapters, and converters
22 24 #-----------------------------------------------------------------------------
23 25
24 26 operators = {
25 27 '$lt' : "<",
26 28 '$gt' : ">",
27 29 # null is handled weird with ==,!=
28 30 '$eq' : "IS",
29 31 '$ne' : "IS NOT",
30 32 '$lte': "<=",
31 33 '$gte': ">=",
32 34 '$in' : ('IS', ' OR '),
33 35 '$nin': ('IS NOT', ' AND '),
34 36 # '$all': None,
35 37 # '$mod': None,
36 38 # '$exists' : None
37 39 }
38 40
39 41 def _adapt_datetime(dt):
40 42 return dt.strftime(ISO8601)
41 43
42 44 def _convert_datetime(ds):
43 45 if ds is None:
44 46 return ds
45 47 else:
46 48 return datetime.strptime(ds, ISO8601)
47 49
48 50 def _adapt_dict(d):
49 51 return json.dumps(d)
50 52
51 53 def _convert_dict(ds):
52 54 if ds is None:
53 55 return ds
54 56 else:
55 57 return json.loads(ds)
56 58
57 59 def _adapt_bufs(bufs):
58 60 # this is *horrible*
59 61 # copy buffers into single list and pickle it:
60 62 if bufs and isinstance(bufs[0], (bytes, buffer)):
61 63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
62 64 elif bufs:
63 65 return bufs
64 66 else:
65 67 return None
66 68
67 69 def _convert_bufs(bs):
68 70 if bs is None:
69 71 return []
70 72 else:
71 73 return pickle.loads(bytes(bs))
72 74
73 75 #-----------------------------------------------------------------------------
74 76 # SQLiteDB class
75 77 #-----------------------------------------------------------------------------
76 78
77 79 class SQLiteDB(BaseDB):
78 80 """SQLite3 TaskRecord backend."""
79 81
80 82 filename = CUnicode('tasks.db', config=True)
81 83 location = CUnicode('', config=True)
82 84 table = CUnicode("", config=True)
83 85
84 86 _db = Instance('sqlite3.Connection')
85 87 _keys = List(['msg_id' ,
86 88 'header' ,
87 89 'content',
88 90 'buffers',
89 91 'submitted',
90 92 'client_uuid' ,
91 93 'engine_uuid' ,
92 94 'started',
93 95 'completed',
94 96 'resubmitted',
95 97 'result_header' ,
96 98 'result_content' ,
97 99 'result_buffers' ,
98 100 'queue' ,
99 101 'pyin' ,
100 102 'pyout',
101 103 'pyerr',
102 104 'stdout',
103 105 'stderr',
104 106 ])
105 107
106 108 def __init__(self, **kwargs):
107 109 super(SQLiteDB, self).__init__(**kwargs)
108 110 if not self.table:
109 111 # use session, and prefix _, since starting with # is illegal
110 112 self.table = '_'+self.session.replace('-','_')
111 113 if not self.location:
112 114 if hasattr(self.config.Global, 'cluster_dir'):
113 115 self.location = self.config.Global.cluster_dir
114 116 else:
115 117 self.location = '.'
116 118 self._init_db()
119
120 # register db commit as a 2s periodic callback,
121 # so we don't pay for a disk commit on every operation
122 # (assumes we are being run in a zmq ioloop app)
123 loop = ioloop.IOLoop.instance()
124 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
125 pc.start()
117 126
118 127 def _defaults(self):
119 128 """create an empty record"""
120 129 d = {}
121 130 for key in self._keys:
122 131 d[key] = None
123 132 return d
124 133
125 134 def _init_db(self):
126 135 """Connect to the database and get new session number."""
127 136 # register adapters
128 137 sqlite3.register_adapter(datetime, _adapt_datetime)
129 138 sqlite3.register_converter('datetime', _convert_datetime)
130 139 sqlite3.register_adapter(dict, _adapt_dict)
131 140 sqlite3.register_converter('dict', _convert_dict)
132 141 sqlite3.register_adapter(list, _adapt_bufs)
133 142 sqlite3.register_converter('bufs', _convert_bufs)
134 143 # connect to the db
135 144 dbfile = os.path.join(self.location, self.filename)
136 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES, cached_statements=16)
145 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
146 # isolation_level = None)#,
147 cached_statements=64)
137 148 # print dir(self._db)
138 149
139 150 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
140 151 (msg_id text PRIMARY KEY,
141 152 header dict text,
142 153 content dict text,
143 154 buffers bufs blob,
144 155 submitted datetime text,
145 156 client_uuid text,
146 157 engine_uuid text,
147 158 started datetime text,
148 159 completed datetime text,
149 160 resubmitted datetime text,
150 161 result_header dict text,
151 162 result_content dict text,
152 163 result_buffers bufs blob,
153 164 queue text,
154 165 pyin text,
155 166 pyout text,
156 167 pyerr text,
157 168 stdout text,
158 169 stderr text)
159 170 """%self.table)
160 171 # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
161 172 # (msg_id text, result integer, buffer blob)
162 173 # """%self.table)
163 174 self._db.commit()
164 175
165 176 def _dict_to_list(self, d):
166 177 """turn a mongodb-style record dict into a list."""
167 178
168 179 return [ d[key] for key in self._keys ]
169 180
170 181 def _list_to_dict(self, line):
171 182 """Inverse of dict_to_list"""
172 183 d = self._defaults()
173 184 for key,value in zip(self._keys, line):
174 185 d[key] = value
175 186
176 187 return d
177 188
178 189 def _render_expression(self, check):
179 190 """Turn a mongodb-style search dict into an SQL query."""
180 191 expressions = []
181 192 args = []
182 193
183 194 skeys = set(check.keys())
184 195 skeys.difference_update(set(self._keys))
185 196 skeys.difference_update(set(['buffers', 'result_buffers']))
186 197 if skeys:
187 198 raise KeyError("Illegal testing key(s): %s"%skeys)
188 199
189 200 for name,sub_check in check.iteritems():
190 201 if isinstance(sub_check, dict):
191 202 for test,value in sub_check.iteritems():
192 203 try:
193 204 op = operators[test]
194 205 except KeyError:
195 206 raise KeyError("Unsupported operator: %r"%test)
196 207 if isinstance(op, tuple):
197 208 op, join = op
198 209 expr = "%s %s ?"%(name, op)
199 210 if isinstance(value, (tuple,list)):
200 211 expr = '( %s )'%( join.join([expr]*len(value)) )
201 212 args.extend(value)
202 213 else:
203 214 args.append(value)
204 215 expressions.append(expr)
205 216 else:
206 217 # it's an equality check
207 218 expressions.append("%s IS ?"%name)
208 219 args.append(sub_check)
209 220
210 221 expr = " AND ".join(expressions)
211 222 return expr, args
212 223
213 224 def add_record(self, msg_id, rec):
214 225 """Add a new Task Record, by msg_id."""
215 226 d = self._defaults()
216 227 d.update(rec)
217 228 d['msg_id'] = msg_id
218 229 line = self._dict_to_list(d)
219 230 tups = '(%s)'%(','.join(['?']*len(line)))
220 231 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
221 self._db.commit()
232 # self._db.commit()
222 233
223 234 def get_record(self, msg_id):
224 235 """Get a specific Task Record, by msg_id."""
225 236 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
226 237 line = cursor.fetchone()
227 238 if line is None:
228 239 raise KeyError("No such msg: %r"%msg_id)
229 240 return self._list_to_dict(line)
230 241
231 242 def update_record(self, msg_id, rec):
232 243 """Update the data in an existing record."""
233 244 query = "UPDATE %s SET "%self.table
234 245 sets = []
235 246 keys = sorted(rec.keys())
236 247 values = []
237 248 for key in keys:
238 249 sets.append('%s = ?'%key)
239 250 values.append(rec[key])
240 251 query += ', '.join(sets)
241 252 query += ' WHERE msg_id == %r'%msg_id
242 253 self._db.execute(query, values)
243 self._db.commit()
254 # self._db.commit()
244 255
245 256 def drop_record(self, msg_id):
246 257 """Remove a record from the DB."""
247 258 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
248 self._db.commit()
259 # self._db.commit()
249 260
250 261 def drop_matching_records(self, check):
251 262 """Remove a record from the DB."""
252 263 expr,args = self._render_expression(check)
253 264 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
254 265 self._db.execute(query,args)
255 self._db.commit()
266 # self._db.commit()
256 267
257 268 def find_records(self, check, id_only=False):
258 269 """Find records matching a query dict."""
259 270 req = 'msg_id' if id_only else '*'
260 271 expr,args = self._render_expression(check)
261 272 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
262 273 cursor = self._db.execute(query, args)
263 274 matches = cursor.fetchall()
264 275 if id_only:
265 276 return [ m[0] for m in matches ]
266 277 else:
267 278 records = {}
268 279 for line in matches:
269 280 rec = self._list_to_dict(line)
270 281 records[rec['msg_id']] = rec
271 282 return records
272 283
273 284 __all__ = ['SQLiteDB'] No newline at end of file
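
A rough usage sketch of the query-dict interface above; the import path, database filename, table name, and record values are all assumptions::

    from datetime import datetime
    from IPython.parallel.sqlitedb import SQLiteDB   # import path assumed; adjust to this tree
    db = SQLiteDB(location=u'.', filename=u'demo_tasks.db', table=u'demo')
    db.add_record('abc123', dict(engine_uuid='engine-0', completed=datetime.now()))
    # mongodb-style query: all records whose 'completed' field is not null
    finished = db.find_records({'completed' : {'$ne' : None}}, id_only=True)
    # finished -> ['abc123']
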
@@ -1,205 +1,205 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A simple python program of solving a 2D wave equation in parallel.
4 4 Domain partitioning and inter-processor communication
5 5 are done by an object of class MPIRectPartitioner2D
6 6 (which is a subclass of RectPartitioner2D and uses MPI via mpi4py)
7 7
8 8 An example of running the program is (8 processors, 4x2 partition,
9 9 400x100 grid cells)::
10 10
11 11 $ ipclusterz start --profile mpi -n 8 # start 8 engines (assuming mpi profile has been configured)
12 12 $ ./parallelwave-mpi.py --grid 400 100 --partition 4 2 --profile mpi
13 13
14 14 See also parallelwave, which runs the same program, but uses 0MQ
15 15 (via pyzmq) for the inter-engine communication.
16 16
17 17 Authors
18 18 -------
19 19
20 20 * Xing Cai
21 21 * Min Ragan-Kelley
22 22
23 23 """
24 24
25 25 import sys
26 26 import time
27 27
28 28 from numpy import exp, zeros, newaxis, sqrt
29 29
30 30 from IPython.external import argparse
31 from IPython.parallel.client import Client, Reference
31 from IPython.parallel import Client, Reference
32 32
33 33 def setup_partitioner(index, num_procs, gnum_cells, parts):
34 34 """create a partitioner in the engine namespace"""
35 35 global partitioner
36 36 p = MPIRectPartitioner2D(my_id=index, num_procs=num_procs)
37 37 p.redim(global_num_cells=gnum_cells, num_parts=parts)
38 38 p.prepare_communication()
39 39 # put the partitioner into the global namespace:
40 40 partitioner=p
41 41
42 42 def setup_solver(*args, **kwargs):
43 43 """create a WaveSolver in the engine namespace"""
44 44 global solver
45 45 solver = WaveSolver(*args, **kwargs)
46 46
47 47 def wave_saver(u, x, y, t):
48 48 """save the wave log"""
49 49 global u_hist
50 50 global t_hist
51 51 t_hist.append(t)
52 52 u_hist.append(1.0*u)
53 53
54 54
55 55 # main program:
56 56 if __name__ == '__main__':
57 57
58 58 parser = argparse.ArgumentParser()
59 59 paa = parser.add_argument
60 60 paa('--grid', '-g',
61 61 type=int, nargs=2, default=[100,100], dest='grid',
62 62 help="Cells in the grid, e.g. --grid 100 200")
63 63 paa('--partition', '-p',
64 64 type=int, nargs=2, default=None,
65 65 help="Process partition grid, e.g. --partition 4 2 for 4x2")
66 66 paa('-c',
67 67 type=float, default=1.,
68 68 help="Wave speed (I think)")
69 69 paa('-Ly',
70 70 type=float, default=1.,
71 71 help="system size (in y)")
72 72 paa('-Lx',
73 73 type=float, default=1.,
74 74 help="system size (in x)")
75 75 paa('-t', '--tstop',
76 76 type=float, default=1.,
77 77 help="Time units to run")
78 78 paa('--profile',
79 79 type=unicode, default=u'default',
80 80 help="Specify the ipcluster profile for the client to connect to.")
81 81 paa('--save',
82 82 action='store_true',
83 83 help="Add this flag to save the time/wave history during the run.")
84 84 paa('--scalar',
85 85 action='store_true',
86 86 help="Also run with scalar interior implementation, to see vector speedup.")
87 87
88 88 ns = parser.parse_args()
89 89 # set up arguments
90 90 grid = ns.grid
91 91 partition = ns.partition
92 92 Lx = ns.Lx
93 93 Ly = ns.Ly
94 94 c = ns.c
95 95 tstop = ns.tstop
96 96 if ns.save:
97 97 user_action = wave_saver
98 98 else:
99 99 user_action = None
100 100
101 101 num_cells = 1.0*(grid[0]-1)*(grid[1]-1)
102 102 final_test = True
103 103
104 104 # create the Client
105 105 rc = Client(profile=ns.profile)
106 106 num_procs = len(rc.ids)
107 107
108 108 if partition is None:
109 109 partition = [1,num_procs]
110 110
111 111 assert partition[0]*partition[1] == num_procs, "can't map partition %s to %i engines"%(partition, num_procs)
112 112
113 113 view = rc[:]
114 114 print "Running %s system on %s processes until %f"%(grid, partition, tstop)
115 115
116 116 # functions defining initial/boundary/source conditions
117 117 def I(x,y):
118 118 from numpy import exp
119 119 return 1.5*exp(-100*((x-0.5)**2+(y-0.5)**2))
120 120 def f(x,y,t):
121 121 return 0.0
122 122 # from numpy import exp,sin
123 123 # return 10*exp(-(x - sin(100*t))**2)
124 124 def bc(x,y,t):
125 125 return 0.0
126 126
127 127 # initial imports, setup rank
128 128 view.execute('\n'.join([
129 129 "from mpi4py import MPI",
130 130 "import numpy",
131 131 "mpi = MPI.COMM_WORLD",
132 132 "my_id = MPI.COMM_WORLD.Get_rank()"]), block=True)
133 133
134 134 # initialize t_hist/u_hist for saving the state at each step (optional)
135 135 view['t_hist'] = []
136 136 view['u_hist'] = []
137 137
138 138 # set vector/scalar implementation details
139 139 impl = {}
140 140 impl['ic'] = 'vectorized'
141 141 impl['inner'] = 'scalar'
142 142 impl['bc'] = 'vectorized'
143 143
144 144 # execute some files so that the classes we need will be defined on the engines:
145 145 view.run('RectPartitioner.py')
146 146 view.run('wavesolver.py')
147 147
148 148 # setup remote partitioner
149 149 # note that Reference means that the argument passed to setup_partitioner will be the
150 150 # object named 'my_id' in the engine's namespace
151 151 view.apply_sync(setup_partitioner, Reference('my_id'), num_procs, grid, partition)
152 152 # wait for initial communication to complete
153 153 view.execute('mpi.barrier()')
154 154 # setup remote solvers
155 155 view.apply_sync(setup_solver, I,f,c,bc,Lx,Ly,partitioner=Reference('partitioner'), dt=0,implementation=impl)
156 156
157 157 # lambda for calling solver.solve:
158 158 _solve = lambda *args, **kwargs: solver.solve(*args, **kwargs)
159 159
160 160 if ns.scalar:
161 161 impl['inner'] = 'scalar'
162 162 # run first with element-wise Python operations for each cell
163 163 t0 = time.time()
164 164 ar = view.apply_async(_solve, tstop, dt=0, verbose=True, final_test=final_test, user_action=user_action)
165 165 if final_test:
166 166 # this sum is performed element-wise as results finish
167 167 s = sum(ar)
168 168 # the L2 norm (RMS) of the result:
169 169 norm = sqrt(s/num_cells)
170 170 else:
171 171 norm = -1
172 172 t1 = time.time()
173 173 print 'scalar inner-version, Wtime=%g, norm=%g'%(t1-t0, norm)
174 174
175 175 impl['inner'] = 'vectorized'
176 176 # setup new solvers
177 177 view.apply_sync(setup_solver, I,f,c,bc,Lx,Ly,partitioner=Reference('partitioner'), dt=0,implementation=impl)
178 178 view.execute('mpi.barrier()')
179 179
180 180 # run again with numpy vectorized inner-implementation
181 181 t0 = time.time()
182 182 ar = view.apply_async(_solve, tstop, dt=0, verbose=True, final_test=final_test)#, user_action=wave_saver)
183 183 if final_test:
184 184 # this sum is performed element-wise as results finish
185 185 s = sum(ar)
186 186 # the L2 norm (RMS) of the result:
187 187 norm = sqrt(s/num_cells)
188 188 else:
189 189 norm = -1
190 190 t1 = time.time()
191 191 print 'vector inner-version, Wtime=%g, norm=%g'%(t1-t0, norm)
192 192
193 193 # if ns.save is True, then u_hist stores the history of u as a list
194 194 # If the partition scheme is Nx1, then u can be reconstructed via 'gather':
195 195 if ns.save and partition[-1] == 1:
196 196 import pylab
197 197 view.execute('u_last=u_hist[-1]')
198 198 # map mpi IDs to IPython IDs, which may not match
199 199 ranks = view['my_id']
200 200 targets = range(len(ranks))
201 201 for idx in range(len(ranks)):
202 202 targets[idx] = ranks.index(idx)
203 203 u_last = rc[targets].gather('u_last', block=True)
204 204 pylab.pcolor(u_last)
205 205 pylab.show()
@@ -1,209 +1,209 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A simple python program of solving a 2D wave equation in parallel.
4 4 Domain partitioning and inter-processor communication
5 5 are done by an object of class ZMQRectPartitioner2D
6 6 (which is a subclass of RectPartitioner2D and uses 0MQ via pyzmq)
7 7
8 8 An example of running the program is (8 processors, 4x2 partition,
9 9 200x200 grid cells)::
10 10
11 11 $ ipclusterz start -n 8 # start 8 engines
12 12 $ ./parallelwave.py --grid 200 200 --partition 4 2
13 13
14 14 See also parallelwave-mpi, which runs the same program, but uses MPI
15 15 (via mpi4py) for the inter-engine communication.
16 16
17 17 Authors
18 18 -------
19 19
20 20 * Xing Cai
21 21 * Min Ragan-Kelley
22 22
23 23 """
24 24 #
25 25 import sys
26 26 import time
27 27
28 28 from numpy import exp, zeros, newaxis, sqrt
29 29
30 30 from IPython.external import argparse
31 from IPython.parallel.client import Client, Reference
31 from IPython.parallel import Client, Reference
32 32
33 33 def setup_partitioner(comm, addrs, index, num_procs, gnum_cells, parts):
34 34 """create a partitioner in the engine namespace"""
35 35 global partitioner
36 36 p = ZMQRectPartitioner2D(comm, addrs, my_id=index, num_procs=num_procs)
37 37 p.redim(global_num_cells=gnum_cells, num_parts=parts)
38 38 p.prepare_communication()
39 39 # put the partitioner into the global namespace:
40 40 partitioner=p
41 41
42 42 def setup_solver(*args, **kwargs):
43 43 """create a WaveSolver in the engine namespace."""
44 44 global solver
45 45 solver = WaveSolver(*args, **kwargs)
46 46
47 47 def wave_saver(u, x, y, t):
48 48 """save the wave state for each timestep."""
49 49 global u_hist
50 50 global t_hist
51 51 t_hist.append(t)
52 52 u_hist.append(1.0*u)
53 53
54 54
55 55 # main program:
56 56 if __name__ == '__main__':
57 57
58 58 parser = argparse.ArgumentParser()
59 59 paa = parser.add_argument
60 60 paa('--grid', '-g',
61 61 type=int, nargs=2, default=[100,100], dest='grid',
62 62 help="Cells in the grid, e.g. --grid 100 200")
63 63 paa('--partition', '-p',
64 64 type=int, nargs=2, default=None,
65 65 help="Process partition grid, e.g. --partition 4 2 for 4x2")
66 66 paa('-c',
67 67 type=float, default=1.,
68 68 help="Wave speed (I think)")
69 69 paa('-Ly',
70 70 type=float, default=1.,
71 71 help="system size (in y)")
72 72 paa('-Lx',
73 73 type=float, default=1.,
74 74 help="system size (in x)")
75 75 paa('-t', '--tstop',
76 76 type=float, default=1.,
77 77 help="Time units to run")
78 78 paa('--profile',
79 79 type=unicode, default=u'default',
80 80 help="Specify the ipcluster profile for the client to connect to.")
81 81 paa('--save',
82 82 action='store_true',
83 83 help="Add this flag to save the time/wave history during the run.")
84 84 paa('--scalar',
85 85 action='store_true',
86 86 help="Also run with scalar interior implementation, to see vector speedup.")
87 87
88 88 ns = parser.parse_args()
89 89 # set up arguments
90 90 grid = ns.grid
91 91 partition = ns.partition
92 92 Lx = ns.Lx
93 93 Ly = ns.Ly
94 94 c = ns.c
95 95 tstop = ns.tstop
96 96 if ns.save:
97 97 user_action = wave_saver
98 98 else:
99 99 user_action = None
100 100
101 101 num_cells = 1.0*(grid[0]-1)*(grid[1]-1)
102 102 final_test = True
103 103
104 104 # create the Client
105 105 rc = Client(profile=ns.profile)
106 106 num_procs = len(rc.ids)
107 107
108 108 if partition is None:
109 109 partition = [num_procs,1]
110 110 else:
111 111 num_procs = min(num_procs, partition[0]*partition[1])
112 112
113 113 assert partition[0]*partition[1] == num_procs, "can't map partition %s to %i engines"%(partition, num_procs)
114 114
115 115 # construct the View:
116 116 view = rc[:num_procs]
117 117 print "Running %s system on %s processes until %f"%(grid, partition, tstop)
118 118
119 119 # functions defining initial/boundary/source conditions
120 120 def I(x,y):
121 121 from numpy import exp
122 122 return 1.5*exp(-100*((x-0.5)**2+(y-0.5)**2))
123 123 def f(x,y,t):
124 124 return 0.0
125 125 # from numpy import exp,sin
126 126 # return 10*exp(-(x - sin(100*t))**2)
127 127 def bc(x,y,t):
128 128 return 0.0
129 129
130 130 # initialize t_hist/u_hist for saving the state at each step (optional)
131 131 view['t_hist'] = []
132 132 view['u_hist'] = []
133 133
134 134 # set vector/scalar implementation details
135 135 impl = {}
136 136 impl['ic'] = 'vectorized'
137 137 impl['inner'] = 'scalar'
138 138 impl['bc'] = 'vectorized'
139 139
140 140 # execute some files so that the classes we need will be defined on the engines:
141 141 view.execute('import numpy')
142 142 view.run('communicator.py')
143 143 view.run('RectPartitioner.py')
144 144 view.run('wavesolver.py')
145 145
146 146 # scatter engine IDs
147 147 view.scatter('my_id', range(num_procs), flatten=True)
148 148
149 149 # create the engine connectors
150 150 view.execute('com = EngineCommunicator()')
151 151
152 152 # gather the connection information into a single dict
153 153 ar = view.apply_async(lambda : com.info)
154 154 peers = ar.get_dict()
155 155 # print peers
156 156 # this is a dict, keyed by engine ID, of the connection info for the EngineCommunicators
157 157
158 158 # setup remote partitioner
159 159 # note that Reference means that the argument passed to setup_partitioner will be the
160 160 # object named 'com' in the engine's namespace
161 161 view.apply_sync(setup_partitioner, Reference('com'), peers, Reference('my_id'), num_procs, grid, partition)
162 162 time.sleep(1)
163 163 # convenience lambda to call solver.solve:
164 164 _solve = lambda *args, **kwargs: solver.solve(*args, **kwargs)
165 165
166 166 if ns.scalar:
167 167 impl['inner'] = 'scalar'
168 168 # setup remote solvers
169 169 view.apply_sync(setup_solver, I,f,c,bc,Lx,Ly, partitioner=Reference('partitioner'), dt=0,implementation=impl)
170 170
171 171 # run first with element-wise Python operations for each cell
172 172 t0 = time.time()
173 173 ar = view.apply_async(_solve, tstop, dt=0, verbose=True, final_test=final_test, user_action=user_action)
174 174 if final_test:
175 175 # this sum is performed element-wise as results finish
176 176 s = sum(ar)
177 177 # the L2 norm (RMS) of the result:
178 178 norm = sqrt(s/num_cells)
179 179 else:
180 180 norm = -1
181 181 t1 = time.time()
182 182 print 'scalar inner-version, Wtime=%g, norm=%g'%(t1-t0, norm)
183 183
184 184 # run again with faster numpy-vectorized inner implementation:
185 185 impl['inner'] = 'vectorized'
186 186 # setup remote solvers
187 187 view.apply_sync(setup_solver, I,f,c,bc,Lx,Ly,partitioner=Reference('partitioner'), dt=0,implementation=impl)
188 188
189 189 t0 = time.time()
190 190
191 191 ar = view.apply_async(_solve, tstop, dt=0, verbose=True, final_test=final_test)#, user_action=wave_saver)
192 192 if final_test:
193 193 # this sum is performed element-wise as results finish
194 194 s = sum(ar)
195 195 # the L2 norm (RMS) of the result:
196 196 norm = sqrt(s/num_cells)
197 197 else:
198 198 norm = -1
199 199 t1 = time.time()
200 200 print 'vector inner-version, Wtime=%g, norm=%g'%(t1-t0, norm)
201 201
202 202 # if ns.save is True, then u_hist stores the history of u as a list
203 203 # If the partition scheme is Nx1, then u can be reconstructed via 'gather':
204 204 if ns.save and partition[-1] == 1:
205 205 import pylab
206 206 view.execute('u_last=u_hist[-1]')
207 207 u_last = view.gather('u_last', block=True)
208 208 pylab.pcolor(u_last)
209 209 pylab.show() No newline at end of file