import networkx as nx
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math
import random

################################################################################
# Below are our functions for optimizing our matrix factorization
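#
# To make the gradient updates below easier to follow, the objective we read
# them as minimizing is the weighted squared reconstruction error of the model
# A ~= U V U^T plus Frobenius-norm regularization, roughly:
#
#   J(U, V) = sum_{i,j} w_ij * (A_ij - U[i,:] . (V U[j,:]))^2
#             + (beta / 2) * (||U||_F^2 + ||V||_F^2)
#
# where w_ij = 1 for nonzero entries of A and w_ij = weighting for zeros.
# (Stated here for reference only; nothing below evaluates J in this exact
# weighted form.)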

# Here we perform one epoch of gradient descent over every entry of A,
# down-weighting the zero entries by `weighting` (nonzeros get weight 1.0)
def train(A, NZ, U, V, alpha, beta, weighting):
	for i in range(A.shape[0]):
		for j in range(A.shape[1]):
			if A[i,j] == 0:
				w = weighting
			else:
				w = 1.0
			
			# current error
			eij = A[i,j] - np.dot(U[i,:], np.dot(V, U[j,:]))
			
			# our gradients
			gradI = np.dot(V, U[j,:])
			gradJ = np.dot(U[i,:], V)
			gradV = np.outer(U[i,:], U[j,:])
			
			# Gradient descent updates. Note that V must be updated in place
			# (+=) so the new values are visible to the caller; plain
			# reassignment (V = V + ...) would only rebind the local name.
			U[i,:] += w * alpha * (2 * eij * gradI - beta * U[i,:])
			U[j,:] += w * alpha * (2 * eij * gradJ - beta * U[j,:])
			V += w * alpha * (2 * eij * gradV - beta * V)
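
# The element-wise loop above gets slow once the matrix grows (see the note on
# the 420x420 DNC matrix further down). Below is a fully vectorized epoch,
# included for reference only: it is our own rough translation of the same
# updates and is not called anywhere in this script, and the function name and
# the dense copy of A are our choices rather than part of the method above.
def train_vectorized(A, U, V, alpha, beta, weighting):
	# Dense copy of A (fine for a few hundred nodes); handles both numpy
	# arrays and scipy sparse matrices
	Ad = A.toarray() if hasattr(A, "toarray") else np.asarray(A, dtype=float)
	# Weight matrix: 1 for observed (nonzero) entries, `weighting` for zeros
	W = np.where(Ad != 0, 1.0, weighting)
	# Weighted residuals of the current reconstruction U V U^T
	E = W * (Ad - U @ V @ U.T)
	# Gradient steps on U and V (U appears twice in U V U^T, hence two terms)
	U += alpha * (2 * (E @ U @ V.T + E.T @ U @ V) - beta * U)
	V += alpha * (2 * (U.T @ E @ U) - beta * V)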

# This will compute our error: the squared reconstruction error over all
# entries plus squared-Frobenius-norm regularization (squaring the norms keeps
# this consistent with the beta * U and beta * V terms in the updates above).
def compute_error(A, NZ, U, V, beta):
	error = 0
	for i in range(A.shape[0]):
		for j in range(A.shape[1]):
			error = error + pow(A[i,j] - np.dot(U[i,:], np.dot(V, U[j,:])), 2)
	
	error = error + beta / 2 * np.linalg.norm(U, 'fro') ** 2
	error = error + beta / 2 * np.linalg.norm(V, 'fro') ** 2
	return error

# This will perform some number of iterations of gradient descent.
# Note that we use a fixed number of iterations; in practice we'd probably
# want a more dynamic stopping criterion (see the sketch after this function).
def factorize_matrix(A, U, V, iterations, alpha, beta, weighting):
	NZ = A.nonzero()
	for i in range(iterations):
		train(A, NZ, U, V, alpha, beta, weighting)
		error = compute_error(A, NZ, U, V, beta)
		print(error)
	return U, V
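
# A more dynamic stopping criterion, sketched as a minimal illustration of what
# we mean above (this variant, its name, and the tolerance are our own choices
# and are not used anywhere below): stop once the relative improvement in
# error between epochs falls below some tolerance.
def factorize_until_converged(A, U, V, max_iterations, alpha, beta, weighting, tol=1e-6):
	NZ = A.nonzero()
	prev_error = np.inf
	for i in range(max_iterations):
		train(A, NZ, U, V, alpha, beta, weighting)
		error = compute_error(A, NZ, U, V, beta)
		if prev_error - error < tol * prev_error:
			break
		prev_error = error
	return U, V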

################################################################################
# Toy test example that we can easily visualize
A = [
		 [1,1,1,1,0],
		 [1,0,0,1,1],
		 [1,1,0,1,0],
		 [1,0,0,1,1],
		 [1,1,0,1,0]
	 ]
A = np.array(A)

# n is the size of our matrix
# k is the number of latent features we want to use
n = A.shape[0]
k = 5

# initialize U and V to be random
U = np.random.rand(n, k)
V = np.random.rand(k, k)

# Perform our training
# The variables below were selected arbitrarily
iterations = 10000
alpha = 0.001
beta = 0.02
weighting = np.count_nonzero(A) / (n*n)
(U, V) = factorize_matrix(A, U, V, iterations, alpha, beta, weighting)
P = np.dot(U, np.dot(V, U.T))
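
# Since the toy example is small, one easy way to visualize how well the
# factorization reproduces A is to plot the two matrices side by side. This is
# just a quick sketch; the figure layout here is our own choice.
fig, (ax0, ax1) = plt.subplots(1, 2)
ax0.matshow(A, cmap="Greys")
ax0.set_title("A (observed)")
ax1.matshow(P, cmap="Greys")
ax1.set_title("P = U V U^T")
plt.show()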

################################################################################
# Now let's take a look at the ol' DNC temporal graph
G = nx.read_weighted_edgelist("out.dnc-temporalGraph.data", create_using=nx.MultiGraph(), comments="%")
edges = sorted(G.edges(data=True), key=lambda t: t[2].get('weight', 1))

# As before, we'll take the first 5% of edges
G1 = nx.MultiGraph()
t_0 = G.size() / 20
counter = 0
for e in edges:
	G1.add_edge(e[0], e[1])
	counter += 1
	if counter > t_0:
		break

# Construct an adjacency matrix from our graph
A = nx.adjacency_matrix(G1)

# Set up our parameters and do some descending of that gradient.
# Note that our relatively naive approach of looping over all i,j indices is
# becoming quite slow, even for our relatively modest 420x420 matrix (the
# vectorized sketch after train() above is one possible way around this).
n = A.shape[0]
k = 10
U = np.random.rand(n, k)
V = np.random.rand(k, k)
iterations = 10000
alpha = 0.0001
beta = 0.02
weighting = A.nnz / (n*n)
(U, V) = factorize_matrix(A, U, V, iterations, alpha, beta, weighting)

# Here we're doing our prediction. We'll take the largest entries of the
# prediction matrix that correspond to node pairs with no edge in the training graph.
P = np.dot(U, np.dot(V, U.T))
num_preds = 100
Preds = []
while len(Preds) < num_preds:
	idx = np.unravel_index(np.argmax(P, axis=None), P.shape)
	if A[idx] == 0 and idx[0] != idx[1]:
		Preds.append(idx)
	# Mask out this pair (in both orientations) so it can't be picked again,
	# even if the remaining entries of P are negative
	P[idx[0], idx[1]] = -np.inf
	P[idx[1], idx[0]] = -np.inf

# We'll then calculate the precision of our predictions. The rows/columns of
# the adjacency matrix follow the ordering of G1.nodes(), so we can use that
# list to map a matrix index back to its node label.
count = 0
nodes = list(G1.nodes())
for idx in Preds:
	(u, v) = (nodes[idx[0]], nodes[idx[1]])
	if G.has_edge(u, v):
		count += 1

precision = count / num_preds
print(precision)
