import networkx as nx
import numpy as np
import scipy as sp
import scipy.sparse         # ensures sp.sparse is available on all scipy versions
import scipy.sparse.linalg  # used by the direct-solve sketch below
import math
import random


################################################################################
# read in data and construct graph
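# The input files are assumed to follow the KONECT layout: each ent.*.name
# file lists one entity name per line, where the (1-indexed) line number is
# the entity id used in the corresponding out.* edge list, and edge-list
# lines starting with '%' are comments.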

def read_names(path):
	# map 1-indexed line number -> entity name (one name per line)
	names = dict()
	with open(path) as f:
		for counter, line in enumerate(f, start=1):
			names[counter] = line.rstrip()
	return names

def read_edges(path, G, left_names, right_names):
	# add one edge per non-comment line of a bipartite edge list
	with open(path) as f:
		for line in f:
			if line.startswith('%'):
				continue
			(v1, v2) = line.split()
			G.add_edge(left_names[int(v1)], right_names[int(v2)])

def construct_graph():
	label_lgenre = read_names("discogs_lgenre/ent.discogs_lgenre_lgenre.label.name")
	genre_lgenre = read_names("discogs_lgenre/ent.discogs_lgenre_lgenre.genre.name")
	label_lstyle = read_names("discogs_lstyle/ent.discogs_lstyle_lstyle.label.name")
	style_lstyle = read_names("discogs_lstyle/ent.discogs_lstyle_lstyle.style.name")
	
	G_lgenre = nx.Graph()   # bipartite: record label -- genre
	G_lstyle = nx.Graph()   # bipartite: record label -- style
	read_edges("discogs_lgenre/out.discogs_lgenre_lgenre", G_lgenre, label_lgenre, genre_lgenre)
	read_edges("discogs_lstyle/out.discogs_lstyle_lstyle", G_lstyle, label_lstyle, style_lstyle)
	
	# Project the label-style graph onto record labels: two labels can become
	# adjacent if they share a style. The full projection would be extremely
	# dense, so only ratio * |Nv|*(|Nv|-1)/2 random pairs are sampled per style.
	G_actual = nx.Graph()
	ratio = 0.001
	for v in style_lstyle.values():
		if v not in G_lstyle:
			continue
		Nv = list(G_lstyle.neighbors(v))
		num_edges = int(ratio * len(Nv) * (len(Nv) - 1) * 0.5)
		for i in range(0, num_edges):
			x = random.choice(Nv)
			y = random.choice(Nv)
			if x != y:  # skip self-loops from sampling the same label twice
				G_actual.add_edge(x, y)
	
	# Keep only nodes that are record labels, and give each one its most
	# frequent genre (ties broken uniformly at random) as its ground truth;
	# 'Electronic' and 'Rock' are skipped, presumably to drop two dominant
	# catch-all classes.
	vals = set(label_lgenre.values())
	ground_truth = {}
	for v in list(G_actual.nodes()):
		if v not in vals or v not in G_lgenre:
			G_actual.remove_node(v)
		else:
			counts = {}
			for u in G_lgenre.neighbors(v):
				if u == 'Electronic' or u == 'Rock':
					continue
				counts[u] = counts.get(u, 0) + 1
			if len(counts) == 0:
				G_actual.remove_node(v)
			else:
				c = np.random.choice([k for k in counts.keys() if counts[k] == max(counts.values())])
				ground_truth[v] = c
	
	# restrict to the largest connected component
	G = G_actual.subgraph(max(nx.connected_components(G_actual), key=len))
	return G, ground_truth

################################################################################
# method 1 - simple label propagation
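# Each unlabeled node repeatedly adopts the majority label among its already
# labeled neighbors, with ties broken uniformly at random. Sweeps visit the
# unlabeled nodes in random order and stop once a sweep changes few labels.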

def label_prop(G, labels, unlabeled):
	updates = len(unlabeled)
	# iterate until a full sweep changes fewer than 1% of the unlabeled nodes
	while updates > len(unlabeled)*0.01:
		updates = 0
		# visit the unlabeled nodes in a fresh random order each sweep
		for v in random.sample(unlabeled, len(unlabeled)):
			counts = {}
			for u in G.neighbors(v):
				if labels[u] == "NA":
					continue
				counts[labels[u]] = counts.get(labels[u], 0) + 1
			if len(counts) > 0:
				# majority label among labeled neighbors, ties broken at random
				c = np.random.choice([k for k in counts.keys() if counts[k]==max(counts.values())])
				if c != labels[v]:
					labels[v] = c
					updates += 1
		print(updates, eval_precision(G, gt, labels, unlabeled))  # gt: module-level ground truth
	return labels
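
# (networkx also ships a reference implementation of a related scheme,
# harmonic_function in its node_classification module; the exact import path
# varies by networkx version. It could serve as a baseline but is not used here.)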


################################################################################
# method 2 - iterative naive bayes
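# Iterative classification with a Gaussian naive Bayes model: each node's
# feature vector is the label distribution over its neighborhood. Per-class
# feature means and standard deviations are fit once on the labeled nodes;
# the unlabeled nodes are then repeatedly relabeled (recomputing their
# features from their neighbors' current labels) until few labels change.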

def init_features(G, labels):
	# one feature per actual class; 'NA' marks an unlabeled node, not a class
	feat_idx = {}
	counter = 0
	for l in labels.values():
		if l != "NA" and l not in feat_idx:
			feat_idx[l] = counter
			counter += 1
	features = {}
	for v in G.nodes():
		features[v] = [0.0]*len(feat_idx)
		update_features(G, v, features, feat_idx, labels)
	return features, feat_idx

def update_features(G, v, features, feat_idx, labels):
	# recompute v's feature vector: the normalized distribution of labels
	# among v's currently labeled neighbors
	Nv_size = 0.0
	for i in range(0, len(feat_idx)):
		features[v][i] = 0.0
	for u in G.neighbors(v):
		if labels[u] != "NA":
			features[v][feat_idx[labels[u]]] += 1.0
			Nv_size += 1.0
	if Nv_size > 0.0:
		for i in range(0, len(feat_idx)):
			features[v][i] /= Nv_size

def init_bayes(G, features, feat_idx, labels):
	# class priors: C[c] = fraction of labeled nodes carrying label c
	C = [0.0]*len(feat_idx)
	count = 0
	for v in G.nodes():
		if labels[v] != "NA":
			C[feat_idx[labels[v]]] += 1
			count += 1
	for i in range(0, len(feat_idx)):
		C[i] /= count
	
	# per-class mean and standard deviation of every feature (Gaussian model)
	avgs = dict()
	stds = dict()
	counts = dict()
	for c in feat_idx.keys():
		avgs[c] = [0.0]*len(feat_idx)
		stds[c] = [0.0]*len(feat_idx)
		counts[c] = [0.0]*len(feat_idx)
	
	for v in G.nodes():
		if labels[v] != "NA":
			c = labels[v]
			for i in range(0, len(feat_idx)):
				avgs[c][i] += features[v][i]
				counts[c][i] += 1.0
	for c in feat_idx.keys():
		for i in range(0, len(feat_idx)):
			if counts[c][i] > 0.0:
				avgs[c][i] /= counts[c][i]
			else:
				avgs[c][i] = 0.0
	
	for v in G.nodes():
		if labels[v] != "NA":
			c = labels[v]
			for i in range(0, len(feat_idx)):
				stds[c][i] += (features[v][i] - avgs[c][i])**2
	for c in feat_idx.keys():
		for i in range(0, len(feat_idx)):
			if counts[c][i] > 0.0:
				stds[c][i] /= counts[c][i]
				stds[c][i] = math.sqrt(stds[c][i])  # sqrt of the variance, not of the mean
			else:
				stds[c][i] = 0.0
	
	X = (avgs, stds)
	return C, X

def calc_prob(val, avg, std):
	# Gaussian density N(val; avg, std); the caller guards against std == 0
	exponent = math.exp(-((val - avg)**2 / (2 * std**2)))
	return (1 / (math.sqrt(2 * math.pi) * std)) * exponent

def update_label(G, v, features, feat_idx, C, X):
	# naive Bayes argmax: class prior times the product of per-feature
	# Gaussian likelihoods; features with zero mean or std are skipped
	max_c = "NA"
	max_prob = 0.0
	for c in feat_idx.keys():
		prob = C[feat_idx[c]]
		for i in range(0, len(feat_idx)):
			avg = X[0][c][i]
			std = X[1][c][i]
			if avg > 0.0 and std > 0.0:
				prob *= calc_prob(features[v][i], avg, std)
		if prob > max_prob:
			max_prob = prob
			max_c = c
	return max_c

def naive_bayes(G, labels, unlabeled):
	features, feat_idx = init_features(G, labels)
	C, X = init_bayes(G, features, feat_idx, labels)
	updates = len(unlabeled)
	# iterate until a sweep changes fewer than 1% of the unlabeled nodes
	while updates > len(unlabeled)*0.01:
		updates = 0
		for v in unlabeled:
			update_features(G, v, features, feat_idx, labels)
			new_label = update_label(G, v, features, feat_idx, C, X)
			if new_label != labels[v]:
				labels[v] = new_label
				updates += 1
		print(updates, eval_precision(G, gt, labels, unlabeled))
	return labels

################################################################################
# random walk approach
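# Semi-supervised label propagation as an absorbing random walk: with the
# row-normalized transition blocks P_ul (unlabeled -> labeled) and P_uu
# (unlabeled -> unlabeled), iterate Y_u <- P_ul * Y_l + P_uu * Y_u. The
# iteration converges to the harmonic solution (I - P_uu)^-1 P_ul Y_l.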

def construct_matrices(G, labels):
	# labeled and unlabeled nodes get independent id ranges, since they
	# index rows/columns of different matrix blocks
	ids = {}
	counter = 0
	for v in G.nodes():
		if labels[v] != "NA":
			ids[v] = counter
			counter += 1
	
	num_labeled = counter
	
	counter = 0
	for v in G.nodes():
		if labels[v] == "NA":
			ids[v] = counter
			counter += 1
	
	num_unlabeled = counter
	
	# P_ul: random-walk transition probabilities from unlabeled to labeled nodes
	rows_P_ul = []
	cols_P_ul = []
	vals_P_ul = []
	for v in G.nodes():
		if labels[v] != "NA":
			continue
		for u in G.neighbors(v):
			if labels[u] != "NA":
				rows_P_ul.append(ids[v])
				cols_P_ul.append(ids[u])
				vals_P_ul.append(1.0 / G.degree(v))
	
	# P_uu: transition probabilities among the unlabeled nodes
	rows_P_uu = []
	cols_P_uu = []
	vals_P_uu = []
	for v in G.nodes():
		if labels[v] != "NA":
			continue
		for u in G.neighbors(v):
			if labels[u] == "NA":
				rows_P_uu.append(ids[v])
				cols_P_uu.append(ids[u])
				vals_P_uu.append(1.0 / G.degree(v))
	
	# one column per actual class; 'NA' marks an unlabeled node, not a class
	feat_idx = {}
	counter = 0
	for l in labels.values():
		if l != "NA" and l not in feat_idx:
			feat_idx[l] = counter
			counter += 1
	
	num_feat = len(feat_idx)
	
	# Y_l: one-hot ground-truth label distribution for the labeled nodes
	rows_Y_l = []
	cols_Y_l = []
	vals_Y_l = []
	for v in G.nodes():
		if labels[v] != "NA":
			rows_Y_l.append(ids[v])
			cols_Y_l.append(feat_idx[labels[v]])
			vals_Y_l.append(1.0)
	
	# Y_u: unlabeled nodes start from the uniform distribution over classes
	rows_Y_u = []
	cols_Y_u = []
	vals_Y_u = []
	for v in G.nodes():
		if labels[v] == "NA":
			for i in range(0, num_feat):
				rows_Y_u.append(ids[v])
				cols_Y_u.append(i)
				vals_Y_u.append(1.0 / num_feat)
	
	P_ul = sp.sparse.csr_matrix((vals_P_ul, (rows_P_ul, cols_P_ul)), shape=(num_unlabeled, num_labeled))
	P_uu = sp.sparse.csr_matrix((vals_P_uu, (rows_P_uu, cols_P_uu)), shape=(num_unlabeled, num_unlabeled))
	Y_l = sp.sparse.csr_matrix((vals_Y_l, (rows_Y_l, cols_Y_l)), shape=(num_labeled, num_feat))
	Y_u = sp.sparse.csr_matrix((vals_Y_u, (rows_Y_u, cols_Y_u)), shape=(num_unlabeled, num_feat))
	
	return P_ul, P_uu, Y_l, Y_u, ids, feat_idx

def determine_classes(G, labels, unlabeled, feat_idx, ids, Y_u):
	# each unlabeled node takes the class with the largest propagated mass
	for v in unlabeled:
		max_c = "NA"
		max_Y = 0.0
		for c in feat_idx.keys():
			if Y_u[ids[v], feat_idx[c]] > max_Y:
				max_Y = Y_u[ids[v], feat_idx[c]]
				max_c = c
		labels[v] = max_c
	return labels

def random_walk(G, labels, unlabeled):
	P_ul, P_uu, Y_l, Y_u, ids, feat_idx = construct_matrices(G, labels)
	for t in range(0, 10):  # a fixed number of propagation steps
		Y_u = P_ul @ Y_l + P_uu @ Y_u
		labels = determine_classes(G, labels, unlabeled, feat_idx, ids, Y_u)
		print(t, eval_precision(G, gt, labels, unlabeled))
	return labels
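
# The fixed point of the iteration above solves the sparse linear system
# (I - P_uu) Y_u = P_ul Y_l, so the propagation can also be done in closed
# form. A minimal sketch under that assumption; 'random_walk_exact' is a
# hypothetical helper, not part of the experiments run below:
def random_walk_exact(G, labels, unlabeled):
	P_ul, P_uu, Y_l, Y_u, ids, feat_idx = construct_matrices(G, labels)
	I = sp.sparse.identity(P_uu.shape[0], format="csc")
	# solve (I - P_uu) Y_u = P_ul Y_l for all classes at once
	Y_u = sp.sparse.linalg.spsolve(I - P_uu.tocsc(), (P_ul @ Y_l).tocsc())
	return determine_classes(G, labels, unlabeled, feat_idx, ids, Y_u)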



################################################################################
# initialize labels

def init_labels(G, gt, ratio):
	# hide the ground-truth label of roughly a `ratio` fraction of the nodes
	# (marked "NA"); the rest keep their true label and serve as training data
	labels = {}
	unlabeled = []
	for v in G.nodes():
		if random.random() < ratio:
			labels[v] = "NA"
			unlabeled.append(v)
		else:
			labels[v] = gt[v]
	return labels, unlabeled


################################################################################
# evaluate precision

def eval_precision(G, gt, labels, unlabeled):
	# fraction of held-out nodes whose inferred label matches the ground
	# truth, i.e. accuracy on the unlabeled (test) set
	true_positives = 0
	for v in unlabeled:
		if gt[v] == labels[v]:
			true_positives += 1
	return true_positives / len(unlabeled)

################################################################################
# run tests
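# Results vary from run to run: graph sampling, label hiding, and tie
# breaking all draw from random/np.random; seed both (random.seed,
# np.random.seed) for reproducible numbers.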

# construct graph and get ground_truth
G, gt = construct_graph()

# fraction of nodes whose ground-truth label is hidden (the unlabeled/test
# set); the remaining 1 - test fraction keeps its label as training data
test = 0.1

# run label propagation
labels, unlabeled = init_labels(G, gt, test)
labels = label_prop(G, labels, unlabeled)

# run naive bayes
labels, unlabeled = init_labels(G, gt, test)
labels = naive_bayes(G, labels, unlabeled)

# run random walk
labels, unlabeled = init_labels(G, gt, test)
labels = random_walk(G, labels, unlabeled)

