#!/usr/bin/python3

"""NEST Server with MPI support.

Usage:
  nest-server-mpi --help
  mpirun -np N nest-server-mpi [--host HOST] [--port PORT]

Options:
  -h --help     display usage information and exit
  --host HOST   use hostname/IP address HOST for server [default: 127.0.0.1]
  --port PORT   use port PORT for opening the socket [default: 52425]

"""

from docopt import docopt
from mpi4py import MPI

if __name__ == "__main__":
    # Parse the command line against the usage text in the module docstring.
    # NOTE(review): this runs ahead of the remaining imports, presumably so
    # that `--help` can exit before NEST is loaded -- confirm before moving.
    opt = docopt(__doc__)
else:
    # Imported rather than executed: there is no command line to parse.
    # Keep `opt` bound so that the later `opt.get(...)` look-ups fall back
    # to their defaults instead of raising NameError.
    opt = {}

import os
import sys
import time

import nest
import nest.server

# Server address settings read from the environment, each with a built-in
# default. Both values are strings (the port is not converted to int here).
HOST = os.environ.get("NEST_SERVER_HOST", "127.0.0.1")
PORT = os.environ.get("NEST_SERVER_PORT", "52425")

# Communicate over a clone of the world communicator rather than
# COMM_WORLD itself.
comm = MPI.COMM_WORLD.Clone()
# Index of this process within `comm`; it selects the master (0) or worker
# role further down.
rank = comm.rank


def log(call_name, msg):
    """Print one progress line tagged with this rank and the current time.

    The line has the form ``==> WORKER <rank>/<unix time> (<call_name>): <msg>``
    and is flushed immediately.
    """
    timestamp = time.time()
    print(f"==> WORKER {rank}/{timestamp:.7f} ({call_name}): {msg}", flush=True)


if rank == 0:
    # Rank 0 is the master: it registers the communicator and then hands
    # control to the server application.
    print("==> Starting NEST Server Master on rank 0", flush=True)
    nest.server.set_mpi_comm(comm)
    # NOTE(review): according to the docopt documentation, options declared
    # with `[default: ...]` in the usage text are always present in the
    # parsed result, so the HOST/PORT fallbacks (NEST_SERVER_HOST and
    # NEST_SERVER_PORT) only apply when a key is missing from `opt`.
    # Confirm whether the environment variables are meant to take effect
    # when the option is not given on the command line.
    nest.server.run_mpi_app(host=opt.get("--host", HOST), port=opt.get("--port", PORT))

else:
    # Every other rank is a worker: it repeatedly receives a call from
    # rank 0, executes it locally and contributes its result to a gather
    # rooted at rank 0.
    print(f"==> Starting NEST Server Worker on rank {rank}", flush=True)
    nest.server.set_mpi_comm(comm)
    while True:
        # First broadcast from rank 0: the name of the call to execute.
        log("spinwait", "waiting for call bcast")
        call_name = comm.bcast(None, root=0)

        # Second broadcast from rank 0: the (args, kwargs) pair for the call.
        log(call_name, "received call bcast, waiting for data bcast")
        data = comm.bcast(None, root=0)

        log(call_name, f"received data bcast, data={data}")
        args, kwargs = data

        if call_name == "exec":
            response = nest.server.do_exec(args, kwargs)
        else:
            call, args, kwargs = nest.server.nestify(call_name, args, kwargs)
            log(call_name, f"local call, args={args}, kwargs={kwargs}")

            # The following exception handler is useful if an error
            # occurs simultaneously on all processes. If only a
            # subset of processes raises an exception, a deadlock due
            # to mismatching MPI communication calls is inevitable on
            # the next call.
            try:
                response = call(*args, **kwargs)
            except Exception as err:
                # Record the failure instead of dropping it silently, then
                # skip the gather for this call and wait for the next one.
                log(call_name, f"local call raised {err!r}, skipping response gather")
                continue

        # Contribute this rank's (serializable) result to the gather on rank 0.
        log(call_name, f"sending reponse gather, data={response}")
        comm.gather(nest.serializable(response), root=0)
