import networkx as nx
import numpy as np
import scipy as sp
import scipy.cluster.vq as vq
import matplotlib.pyplot as plt
import math
import random
import operator

################################################################################
# draw graph and color communities

def draw_comm_graph(G, comms):
  """Show G in a Kamada-Kawai layout with each vertex colored by its community.

  Args:
    G: networkx graph to draw.
    comms: mapping from every vertex of G to its community label; the labels
      are handed straight to networkx as the node colors.
  """
  node_colors = list(map(comms.__getitem__, G.nodes()))
  nx.draw_kamada_kawai(G, node_color=node_colors, with_labels=True)
  plt.show()


################################################################################
# evaluate edge cut
def edge_cut(G, comms):
  """Count the edges of G whose two endpoints are in different communities.

  Args:
    G: networkx graph.
    comms: mapping from every vertex of G to its community label.

  Returns:
    The number of cut edges. Each edge is visited exactly once through
    G.edges(); walking every vertex's neighbor list instead would see each
    undirected edge from both of its ends and report twice the true cut size.
  """
  return sum(1 for u, v in G.edges() if comms[u] != comms[v])

################################################################################
# Read in our good friend Mr. Karate

G = nx.read_edgelist("karate.data", comments="%")

# Ground-truth community of each vertex: one "<node> <community>" pair per line.
# read_edgelist is called without nodetype, so node labels are strings; the
# keys are kept as strings to match them.
comms = {}
with open("karate.gt.data") as f:
  for line in f:
    fields = line.split()
    if not fields:
      continue  # tolerate empty lines (e.g. an extra blank line at the end)
    (key, val) = fields
    comms[key] = int(val)

draw_comm_graph(G, comms)

################################################################################
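# Spectral bi-clustering. Here, we'll separate our graph into two clusters by
# using the Fiedler vector of the unnormalized graph Laplacian L = D - A (the
# eigenvector of its second-smallest eigenvalue) and splitting the vertices by
# the sign of their entry. We can observe that the cut produced closely
# resembles our ground truth.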

# nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns the
# same adjacency matrix as a plain ndarray, rows/columns in G.nodes() order.
A = nx.to_numpy_array(G)
D = np.diag(A.sum(axis=1))
L = D - A
# eigh returns eigenvalues in ascending order, so column 1 of U is the Fiedler
# vector (eigenvector of the second-smallest Laplacian eigenvalue).
(V, U) = np.linalg.eigh(L)
x = U[:,1]

# Entry i of x belongs to the i-th vertex of G.nodes(). Vertices with a
# positive entry go to community 0, all others to community 1.
comms = {v: 0 if x_v > 0 else 1 for v, x_v in zip(G.nodes(), x)}

draw_comm_graph(G, comms)
print(edge_cut(G, comms))

# To do a balanced partitioning, threshold the Fiedler vector at its median
# instead of at 0, so the vertices split into two (roughly) equal halves;
# entries equal to the median fall into community 1.
med = np.median(x.flat)
comms = {v: (0 if x_v > med else 1) for v, x_v in zip(G.nodes(), x.flat)}

draw_comm_graph(G, comms)
print(edge_cut(G, comms))

################################################################################
# Spectral k-clustering. We'll now look at separating our graph into some 
# arbitrary number of clusters. We'll use the k-means approach as discussed.

# Number of clusters to find.
k = 3

# Spectral embedding: row i of X holds the entries of the i-th vertex (in
# G.nodes() order) in eigenvectors 1..k of the Laplacian. Eigenvector 0 is
# skipped; for a connected graph it is constant and carries no information.
X = U[:, 1:k + 1]

# kmeans2 returns the centroids and the cluster label of each row of X.
means, labels = vq.kmeans2(X, k)

# Give every vertex the k-means label of its row, then visualize.
comms = {v: labels[row] for row, v in enumerate(G.nodes())}

draw_comm_graph(G, comms)

################################################################################
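# Spectral Visualization. Here we have a basic visualization algorithm that will
# map an input graph onto 2D Euclidean space using the eigenvectors of its
# Laplacian: each vertex's x,y coordinates are its entries in the eigenvectors
# of the second- and third-smallest eigenvalues.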

def draw_graph_from_spectrum(G, comms):
  """Plot G using Laplacian eigenvectors as vertex coordinates.

  Vertex v is placed at (U[v, 1], U[v, 2]), where the columns of U are the
  eigenvectors of L = D - A in ascending eigenvalue order (column 0 is
  skipped). Figure 1 draws edges as blue lines, colors each vertex by its
  community and labels it with its name. Figure 2 shows the Kamada-Kawai
  layout with the same coloring for comparison.

  Args:
    G: networkx graph to draw.
    comms: mapping from each vertex of G to an int-convertible community
      label. Labels index a fixed 5-color palette and wrap around past it.
  """
  # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns the
  # adjacency matrix as an ndarray with rows/columns in G.nodes() order.
  A = nx.to_numpy_array(G)
  D = np.diag(A.sum(axis=1))
  L = D - A
  _, U = np.linalg.eigh(L)
  n_vecs = U.shape[1]

  # x and y map each vertex to its coordinates. Row i of U belongs to the i-th
  # vertex of G.nodes(). Graphs with fewer than three vertices have no
  # eigenvector 2 (or 1), so the missing coordinate is 0 instead of raising
  # IndexError. The coarsening loop always reaches such graphs.
  x = {}
  y = {}
  for row, v in enumerate(G.nodes()):
    x[v] = U[row, 1] if n_vecs > 1 else 0.0
    y[v] = U[row, 2] if n_vecs > 2 else 0.0

  # Select figure 1 first and only then clear it. Calling clf() before
  # figure(1) wipes whichever figure happened to be current and lets figure 1
  # accumulate plots from earlier calls.
  plt.figure(1)
  plt.clf()
  # Each edge is a line between the x,y positions of its endpoints.
  for u, v in G.edges():
    plt.plot([x[v], x[u]], [y[v], y[u]], marker='o', color='b')

  # Overlay the vertices colored by community; the modulo keeps labels beyond
  # the palette size from raising IndexError.
  cmap = ['k','b','y','g','r']
  for v in G.nodes():
    plt.plot(x[v], y[v], marker='o', color=cmap[int(comms[v]) % len(cmap)])

  for v in G.nodes():
    plt.text(x[v], y[v], v)

  # and we'll compare to the original visualization
  plt.figure(2)
  plt.clf()
  colors = [comms[v] for v in G.nodes()]
  nx.draw_kamada_kawai(G, with_labels=True, node_color=colors)

  plt.show()

################################################################################
# Spectral Coarsening. Here, we'll iteratively coarsen our graph and visualize
# it at each level of coarsening. To do so, we'll compute the Laplacian
# eigenvectors, select the pair of vertices whose spectral coordinates (entries
# in eigenvectors 1 and 2) are closest, and then merge them.

# nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns the
# adjacency matrix as an ndarray with rows/columns in G.nodes() order.
while G.order() > 1:
  A = nx.to_numpy_array(G)
  D = np.diag(A.sum(axis=1))
  L = D - A
  (V, U) = np.linalg.eigh(L)

  # Spectral coordinates of each vertex: its entries in eigenvectors 1 and 2
  # of L. A 2-vertex graph has no eigenvector 2, so y falls back to 0 there
  # instead of raising IndexError.
  nodes = list(G.nodes())
  has_y = U.shape[1] > 2
  x = {v: U[row, 1] for row, v in enumerate(nodes)}
  y = {v: (U[row, 2] if has_y else 0.0) for row, v in enumerate(nodes)}

  # Closest pair of distinct vertices in that plane. Each unordered pair is
  # checked once (v before u in node order). Starting from math.inf, unlike a
  # magic bound such as 999, means any finite distance is accepted, so
  # min_pair is always set when G has at least two vertices.
  min_dist = math.inf
  min_pair = None
  for i, v in enumerate(nodes):
    for u in nodes[i + 1:]:
      dist = math.hypot(x[v] - x[u], y[v] - y[u])
      if dist < min_dist:
        min_dist = dist
        min_pair = (u, v)

  # Merge the pair into min_pair[0] (the node contracted_nodes keeps) and
  # redraw. Surviving vertices keep their original labels, so comms still has
  # an entry for each of them.
  G = nx.contracted_nodes(G, min_pair[0], min_pair[1])
  draw_graph_from_spectrum(G, comms)


