import numpy as np
import random
from mpi4py import MPI

# Run as:
# mpirun -n 4 python3 lec20.py

################################################################################
# Determine our parallel environment details. We will need the total number of 
# ranks (nprocs) and the local rank ID (procid)

comm = MPI.COMM_WORLD
nprocs = comm.Get_size()
procid = comm.Get_rank()

print("Hello Miners. Rank: ", procid, " / ", nprocs)

################################################################################
# Each rank creates an array of random ints

my_array = list()
for i in range(0, 100):
  my_array.append(random.randrange(2**32))  # uniform value in [0, 4294967296)
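
# A sketch of the same data built with NumPy instead (assumption: seeding the
# generator with the rank ID is acceptable; it keeps runs reproducible and
# gives every rank a distinct stream). alt_array is only for comparison and is
# not used below.
rng = np.random.default_rng(seed=procid)
alt_array = rng.integers(0, 2**32, size=100)  # 100 values in [0, 4294967296)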

################################################################################
# Determine the local and global sums

local_sum = sum(my_array)
global_sum = 0

print("Rank: ", procid, "Local Sum: ", local_sum)

# Here are our communication buffers (64-bit ints keep the sums exact)
snd = np.zeros(1, dtype=np.int64)
rcv = np.zeros(1, dtype=np.int64)

# do an all-reduce to get the global sum
snd[0] = local_sum
comm.Allreduce(snd, rcv, op=MPI.SUM)
global_sum = int(rcv[0])

print("Rank: ", procid, "Global Sum: ", global_sum)


################################################################################
# Now compute the global sum locally on each rank by sharing all of the arrays

# setup the receive buffers: one slot per rank
rcv = list()
for i in range(0, nprocs):
  rcv.append(list())

# do an all-to-all style exchange of matched blocking sends and recvs:
# for each ordered pair (i, j) with i != j, rank i sends its array to rank j,
# and each rank simply keeps its own copy for the i == j case
for i in range(0, nprocs):
  for j in range(0, nprocs):
    if i == j:
      if i == procid:
        rcv[i] = my_array
      continue

    if i == procid:
      comm.send(my_array, dest=j)

    if j == procid:
      rcv[i] = comm.recv(source=i)
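
# A non-blocking variant of the exchange above, shown as a sketch (assumption:
# the pickled messages stay under mpi4py's default irecv buffer size). It fills
# a separate list, rcv_nb, so the results of the blocking exchange are untouched.
rcv_nb = [None] * nprocs
rcv_nb[procid] = my_array
send_reqs = list()
recv_reqs = dict()
for i in range(0, nprocs):
  if i == procid:
    continue
  send_reqs.append(comm.isend(my_array, dest=i))
  recv_reqs[i] = comm.irecv(source=i)

# wait for completion; an irecv request hands back the received object
for i in recv_reqs:
  rcv_nb[i] = recv_reqs[i].wait()
for req in send_reqs:
  req.wait()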

# now every rank performs the global summation locally as a check
global_sum = 0
for i in range(0, nprocs):
  global_sum += sum(rcv[i])

print("Rank: ", procid, "Local Global Sum: ", global_sum)
