#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

import time
import numpy

import IPython.kernel.magic
from IPython.kernel import client
from IPython.kernel.error import *

# Client object that every statement below uses to reach the engines.
# NOTE(review): it is created with no arguments, so presumably it connects to
# a default controller that must already be running -- confirm.
mec = client.MultiEngineClient()

#-------------------------------------------------------------------------------
# Setup
#-------------------------------------------------------------------------------

# NOTE(review): reset() presumably clears the engines' state so the script
# starts clean -- confirm against the IPython.kernel documentation.
mec.reset()
# NOTE(review): activate() presumably ties the %px / %autopx / %result magics
# used further down to this client -- confirm.
mec.activate()
# Blocking mode for the first half of the script; it is switched off again at
# the start of the "Non-Blocking execution" section.
mec.block = True
# The returned id list is discarded here.
mec.get_ids()
# The script needs at least four engines: engine ids 0-3 are addressed
# explicitly through ``mec.targets`` later on.
n = len(mec)
assert n >= 4, "Not Enough Engines: %i, 4 needed for this script"%n

# Sample data of several kinds: an int, a float, a range, a tuple and a dict.
# Nothing in the visible part of this script reads ``values``, ``keys`` or
# ``sequences``; they are kept as module-level names.
values = [10, 1.0, range(100), ('asdf', 1000), dict(a=10, b=20)]

# One single-letter name per entry of ``values``.
keys = list('abcde')

# The same 100 integers as a plain range and as a numpy array.
sequences = [range(100), numpy.arange(100)]

#-------------------------------------------------------------------------------
# Blocking execution
#-------------------------------------------------------------------------------

# Execute

# Send a few statements to the engines as source strings.  No ``targets``
# argument is passed, so whatever targets the client currently has are used.
mec.execute('import math')
mec.execute('a = 2.0*math.pi')
mec.execute('print a')

# Address the engines one at a time and set ``b`` on each to that engine's
# own id.  The loop variable is named ``engine_id`` rather than ``id`` so the
# builtin ``id`` is not shadowed at module level.
for engine_id in mec.get_ids():
    mec.execute('b=%d' % engine_id, targets=engine_id)


# Print the per-engine values that were just assigned.
mec.execute('print b')

# Error paths.  Each call below is expected to raise; if it does not, nothing
# is printed and the script simply carries on.

# -1 is passed as a target that is expected to be rejected as an engine id.
try:
    mec.execute('b = 10',targets=-1)
except InvalidEngineID:
    print "Caught invalid engine ID OK."

# The executed code divides by zero; the failure is expected to reach the
# client as a CompositeError.
try:
    mec.execute('a=5; 1/0')
except CompositeError:
    print "Caught 1/0 correctly."



# The same kind of checks, driven through IPython magics instead of method
# calls.
# NOTE(review): %px presumably runs the rest of its line on the engines of
# the activated client -- confirm.
%px print a, b
try:
    %px 1/0
except CompositeError:
    print "Caught 1/0 correctly."


# The two %autopx lines bracket the statements between them.
# NOTE(review): %autopx is presumably a toggle, so the bracketed statements
# are run on the engines rather than locally -- confirm.  No comments are
# added inside the bracketed region so its input stays exactly as written.
%autopx

import numpy
a = numpy.random.rand(4,4)
a = a+a.transpose()
print numpy.linalg.eigvals(a)

%autopx


# Change the client's targets between %px calls so that ``a`` is assigned 5
# with targets [0,2] and 10 with targets [1,3]; then set targets back to
# 'all' and print ``a``.
mec.targets = [0,2]
%px a = 5
mec.targets = [1,3]
%px a = 10
mec.targets = 'all'
%px print a


# Push/Pull

# Push a dict of name -> value pairs, then read ``a`` and ``b`` back in two
# ways.  The return values are discarded.
mec.push(dict(a=10, b=30, c={'f':range(10)}))
mec.pull(('a', 'b'))
mec.zip_pull(('a', 'b'))

# Push a different ``a`` (the engine's own id) to each engine, one target at
# a time.  The loop variable is named ``engine_id`` rather than ``id`` so the
# builtin ``id`` is not shadowed at module level.
for engine_id in mec.get_ids():
    mec.push(dict(a=engine_id), targets=engine_id)


# Pull ``a`` back from each engine individually ...
for engine_id in mec.get_ids():
    mec.pull('a', targets=engine_id)


# ... and then once more without an explicit ``targets`` argument.
mec.pull('a')


# Item assignment and item access on the client, using the name ``a``.
# NOTE(review): presumably shorthand for push/pull -- confirm.
mec['a'] = 100
mec['a']

# get_result/reset/keys

# Ask for a result through the method and through the %result magic, and
# call keys() before and after a reset().  Return values are discarded.
mec.get_result()
%result
mec.keys()
mec.reset()
mec.keys()

# Asking for a result right after reset() is expected to fail.
# NOTE(review): the message names IndexError while the handler catches
# CompositeError -- presumably the IndexError arrives wrapped; confirm.
try:
    %result
except CompositeError:
    print "Caught IndexError ok."


# Run one new command, then ask for a result by number and call keys() again.
# NOTE(review): the argument 1 is presumably a command index -- confirm.
%px a = 5
mec.get_result(1)
mec.keys()

# Queue management methods

# ``time`` is imported through %px because the code strings queued below call
# time.sleep.
%px import time
# Submit five 2-second sleeps with block=False, keeping the five returned
# objects in ``prs``.
prs = [mec.execute('time.sleep(2.0)',block=False) for x in range(5)]


# Call queue_status() three times: right after submitting, after a 3-second
# local sleep followed by clear_queue(), and again after a further 2-second
# local sleep.  Return values are discarded.
mec.queue_status()
time.sleep(3.0)
mec.clear_queue()
mec.queue_status()
time.sleep(2.0)
mec.queue_status()

# NOTE(review): barrier() presumably waits until all of ``prs`` have
# finished or failed -- confirm.
mec.barrier(prs)

# Read ``.r`` on every submitted command.  A CompositeError is reported as a
# cleared-queue entry; commands whose ``.r`` does not raise print nothing.
for pr in prs:
    try:
        pr.r
    except CompositeError:
        print "Caught QueueCleared OK."


# scatter/gather

# Scatter a ten-element range under the name ``a`` and a ten-element numpy
# array under the name ``b``, gathering each back straight afterwards.  The
# gathered values are discarded.
# NOTE(review): scatter presumably splits the sequence across the engines and
# gather reassembles it -- confirm.
mec.scatter('a', range(10))
mec.gather('a')
mec.scatter('b', numpy.arange(10))
mec.gather('b')

#-------------------------------------------------------------------------------
# Non-Blocking execution
#-------------------------------------------------------------------------------

# From here on the client's ``block`` attribute is False; the return values
# of execute/push/pull are kept and read through ``.r`` or ``get_result()``.
mec.block = False

# execute

# Submit two commands, then call barrier() on the pair.
pr1 = mec.execute('a=5')
pr2 = mec.execute('import sets')

mec.barrier((pr1, pr2))

# A failing command submitted next to a good one.  The division by zero is
# expected to show up as a CompositeError when ``pr1.r`` is read after the
# barrier.
pr1 = mec.execute('1/0')
pr2 = mec.execute('c = sets.Set()')

mec.barrier((pr1, pr2))
try:
    pr1.r
except CompositeError:
    print "Caught ZeroDivisionError OK."

# Read ``.r`` directly, with no barrier() call first.
pr = mec.execute("print 'hi'")
pr.r

pr = mec.execute('1/0')
try:
    pr.r
except CompositeError:
    print "Caught ZeroDivisionError OK."

# Make sure we can reraise it!
# (``.r`` is read a second time on the same failed ``pr`` object.)
try:
    pr.r
except CompositeError:
    print "Caught ZeroDivisionError OK."  

# push/pull

# Push and pull with ``block`` still False: the push is read through
# get_result() and the pull through ``.r``.
pr1 = mec.push(dict(a=10))
pr1.get_result()
pr2 = mec.pull('a')
pr2.r

# flush

# Submit two commands between two flush() calls, then try to read both.
# NOTE(review): flush() presumably discards the pending results, so each
# get_result() below is expected to raise InvalidDeferredID -- confirm.  If
# no exception is raised nothing is printed.
mec.flush()
pd1 = mec.execute('a=30')
pd2 = mec.pull('a')
mec.flush()

try:
    pd1.get_result()
except InvalidDeferredID:
    print "PendingResult object was cleared OK."


try:
    pd2.get_result()
except InvalidDeferredID:
    print "PendingResult object was cleared OK."



# Final statement of the script: a plain command placed at the very end of
# the file.  Seeing this message in the output shows that every section above
# ran through to the end.

print "The tests are done!"