import networkx as nx
import numpy as np
import scipy as sp
import scipy.cluster.vq as vq
import matplotlib.pyplot as plt
import math
import random
import operator

################################################################################
# draw graph and color communities

def draw_comm_graph(G, comms):
	"""Display G in a Kamada-Kawai layout, coloured by community.

	Parameters
	----------
	G : networkx.Graph
		Graph to draw; node names are shown as labels.
	comms : dict
		Mapping node -> community label, used as the node colour value.
	"""
	node_colors = list(map(comms.__getitem__, G.nodes()))
	nx.draw_kamada_kawai(G, node_color=node_colors, with_labels=True)
	plt.show()

################################################################################
# ground truth communities

def ground_truth(path="karate.gt"):
	"""Load ground-truth community labels from a whitespace-separated file.

	Each non-blank line must hold exactly two fields: a node name and an
	integer community label (e.g. ``12 1``). Node names are kept as strings;
	the script's ``nx.read_edgelist`` call passes no ``nodetype``, so the
	graph's nodes are strings as well.

	Parameters
	----------
	path : str
		File to read; defaults to the karate-club ground truth.

	Returns
	-------
	dict
		Mapping node name (str) -> community label (int).

	Raises
	------
	ValueError
		If a non-blank line does not have exactly two fields, or its label is
		not an integer.
	"""
	comms = {}
	with open(path, encoding="utf-8") as f:
		for lineno, line in enumerate(f, start=1):
			fields = line.split()
			# Blank lines (e.g. a trailing empty line) carry no data; without this
			# check they would fail to unpack into (node, label).
			if not fields:
				continue
			if len(fields) != 2:
				raise ValueError("{}:{}: expected 'node label', got {!r}".format(path, lineno, line.strip()))
			node, label = fields
			comms[node] = int(label)
	return comms

################################################################################
# spectral (bi-)clustering

def spectral(G):
	"""Bisect G by the sign pattern of its Fiedler vector.

	Forms the Laplacian L = D - A (A: adjacency matrix, D: diagonal matrix of
	A's row sums) and takes the eigenvector of the second-smallest eigenvalue.
	Nodes with a positive entry get community 0; all others get community 1.

	Parameters
	----------
	G : networkx.Graph
		Graph to bisect.

	Returns
	-------
	dict
		Mapping node -> 0 or 1. A graph with fewer than two nodes has no
		second eigenvector, so all of its nodes are placed in community 0.
	"""
	nodes = list(G.nodes())
	if len(nodes) < 2:
		return {n: 0 for n in nodes}
	# nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array returns a
	# plain 2-D ndarray. Passing nodelist ties row i of A to nodes[i].
	A = nx.to_numpy_array(G, nodelist=nodes)
	L = np.diag(A.sum(axis=1)) - A
	# eigh returns eigenvalues in ascending order, so column 1 of the
	# eigenvector matrix belongs to the second-smallest eigenvalue.
	_, V = np.linalg.eigh(L)
	fiedler = V[:, 1]
	return {n: 0 if fiedler[i] > 0 else 1 for i, n in enumerate(nodes)}

################################################################################
# modularity maximization

def mod_max(G):
	"""Assign communities by greedy modularity maximization.

	Uses nx.community.greedy_modularity_communities. Each community's
	position in the returned sequence becomes its integer label.

	Parameters
	----------
	G : networkx.Graph
		Graph to partition.

	Returns
	-------
	dict
		Mapping node -> community index (0, 1, 2, ...).
	"""
	partition = nx.community.greedy_modularity_communities(G)
	return {
		node: index
		for index, members in enumerate(partition)
		for node in members
	}

################################################################################
# label propagation

def label_prop(G):
	"""Detect communities by asynchronous label propagation.

	Every node starts with its own unique integer label. In each sweep the
	nodes are visited in a freshly shuffled order. A node whose label is not
	among the most frequent labels of its neighbours adopts one of those
	labels, with ties broken uniformly at random. Sweeps repeat until no node
	changes.

	All randomness comes from the ``random`` module, so ``random.seed(...)``
	makes a run reproducible.

	Parameters
	----------
	G : networkx.Graph
		Graph to partition. Node names may be any hashable values.

	Returns
	-------
	dict
		Mapping node -> integer community label.
	"""
	nodes = list(G.nodes())
	# Seed labels with node positions so node names need not be integer-like.
	comms = {v: i for i, v in enumerate(nodes)}
	changed = True
	while changed:
		changed = False
		random.shuffle(nodes)
		for v in nodes:
			counts = {}
			for u in G.neighbors(v):
				counts[comms[u]] = counts.get(comms[u], 0) + 1
			# An isolated node has no neighbours to learn from; keep its label.
			if not counts:
				continue
			top = max(counts.values())
			best = [label for label, k in counts.items() if k == top]
			# Keep the current label when it is already a best choice. Re-drawing
			# among tied labels keeps flipping nodes and can stall convergence.
			if comms[v] in best:
				continue
			comms[v] = random.choice(best)
			changed = True
	return comms

################################################################################
# evaluate edge cut

def edge_cut(G, comms):
	"""Return the number of edges whose endpoints lie in different communities.

	Each edge yielded by ``G.edges()`` is examined once. Walking every node's
	neighbour list instead would see each undirected edge from both endpoints
	and report twice the cut size.

	Parameters
	----------
	G : networkx.Graph
		Graph whose edges are checked.
	comms : dict
		Mapping node -> community label; every edge endpoint must have an entry.

	Returns
	-------
	int
		Number of cut edges.
	"""
	return sum(1 for u, v in G.edges() if comms[u] != comms[v])

################################################################################
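# conductance
# conductance(C) = cut(C) / min(degree_sum(C), degree_sum(complement(C)))
# - roughly, the probability that one step of a random walk started inside the
#   community (on the side with the smaller degree sum) leaves it
# - the function below reports the average conductance over all communities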

def conductance(G, comms):
	"""Return the average conductance of the communities in ``comms``.

	For a community C::

		conductance(C) = cut(C) / min(vol(C), vol(G) - vol(C))

	cut(C) counts, for every node v in C, the neighbours of v outside C.
	vol() is a sum of node degrees, and vol(G) is taken as 2 * G.size().

	Parameters
	----------
	G : networkx.Graph
		Graph the partition refers to.
	comms : dict
		Mapping node -> community label; every node of G must have an entry.

	Returns
	-------
	float
		Mean conductance over the distinct labels in ``comms``. A community
		whose denominator is 0 contributes 0.0. If ``comms`` is empty, the
		result is 0.0.
	"""
	labels = set(comms.values())
	if not labels:
		return 0.0
	vol = {}
	cut = {}
	# One pass over the nodes gathers every community's volume and cut, rather
	# than rescanning the whole graph once per community.
	for v in G.nodes():
		c = comms[v]
		vol[c] = vol.get(c, 0) + G.degree(v)
		cut[c] = cut.get(c, 0) + sum(1 for u in G.neighbors(v) if comms[u] != c)
	total = G.size() * 2
	score = 0.0
	for c in labels:
		c_vol = vol.get(c, 0)
		denom = min(c_vol, total - c_vol)
		# A zero denominator means C holds all of the degree volume or none of
		# it. Then no edge leaves C and cut(C) is 0 as well, so count it as 0.0
		# instead of raising ZeroDivisionError.
		if denom > 0:
			score += cut.get(c, 0) / denom
	return score / len(labels)

################################################################################
# evaluate modularity

def modularity(G, comms):
	"""Return the modularity of the partition encoded by ``comms``.

	Converts the node -> label mapping into the list-of-node-sets form that
	nx.community.modularity expects, with one set per distinct label.

	Parameters
	----------
	G : networkx.Graph
		Graph the partition refers to.
	comms : dict
		Mapping node -> community label; every node of G must have an entry.

	Returns
	-------
	float
		Modularity score as computed by NetworkX.
	"""
	groups = {label: set() for label in set(comms.values())}
	for node in G.nodes():
		groups[comms[node]].add(node)
	return nx.community.modularity(G, list(groups.values()))
	
################################################################################
# evaluate NMI (normalized mutual information) between two partitions

def nmi(G, comms, gt):
	"""Return the normalized mutual information between two partitions of G.

	NMI(U, V) = I(U; V) / ((H(U) + H(V)) / 2), using natural logarithms, where
	U is the partition given by ``comms`` and V the one given by ``gt``.
	Communities are taken from the actual label values, so labels may be any
	hashable values and need not be 0..n-1.

	Parameters
	----------
	G : networkx.Graph
		Graph whose nodes are compared; only these nodes are counted.
	comms : dict
		Mapping node -> label for the partition under evaluation.
	gt : dict
		Mapping node -> label for the reference partition.

	Returns
	-------
	float
		NMI in [0, 1]. When both partitions have zero entropy (a single
		community each, or no nodes at all), the ratio is 0/0. The partitions
		are then identical, so 1.0 is returned.
	"""
	nodes = list(G.nodes())
	n = len(nodes)
	size_u = {}
	size_v = {}
	joint = {}
	# Contingency counts: community sizes on each side and their overlaps.
	for v in nodes:
		a, b = comms[v], gt[v]
		size_u[a] = size_u.get(a, 0) + 1
		size_v[b] = size_v.get(b, 0) + 1
		joint[(a, b)] = joint.get((a, b), 0) + 1

	# Mutual information; only non-empty overlaps are stored, so log() never
	# receives 0.
	mi = 0.0
	for (a, b), n_ab in joint.items():
		p_ab = n_ab / n
		mi += p_ab * math.log(p_ab / ((size_u[a] / n) * (size_v[b] / n)))

	h_u = -sum((k / n) * math.log(k / n) for k in size_u.values())
	h_v = -sum((k / n) * math.log(k / n) for k in size_v.values())

	denom = 0.5 * (h_u + h_v)
	if denom == 0.0:
		return 1.0
	return mi / denom

################################################################################
# run our analysis
G = nx.read_edgelist("karate.data", comments="%")

# community assignments: the ground truth plus the three detection methods,
# computed in this order
Cg = ground_truth()
Cs = spectral(G)
Cm = mod_max(G)
Cl = label_prop(G)

# (name, assignment) pairs in the order results are reported
results = [
	("ground truth", Cg),
	("spectral", Cs),
	("modularity", Cm),
	("label prop", Cl),
]

# edge cut, average conductance and modularity for every assignment
metrics = [
	("Edge cut", edge_cut),
	("Conductance", conductance),
	("Modularity", modularity),
]
for metric_name, metric in metrics:
	for label, assignment in results:
		print(metric_name + " (" + label + "):", metric(G, assignment))

# normalized mutual information against the ground truth; comparing the
# ground truth with itself is a sanity check
print("NMI (test case):", nmi(G, Cg, Cg))
for label, assignment in results[1:]:
	print("NMI (" + label + "):", nmi(G, assignment, Cg))

for _, assignment in results:
	draw_comm_graph(G, assignment)
