import networkx as nx
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math
import random

################################################################################

def train(A, NZ, U, V, alpha, beta):
	"""Run one SGD epoch for the tri-factorization model A ~ U V U^T.

	For each observed entry (i, j) listed in NZ, take one gradient step on
	(A[i, j] - U[i] . V . U[j])^2, with an L2 decay of strength beta applied
	to the two touched rows of U and to V.

	Parameters
	----------
	A : dense ndarray or scipy sparse matrix supporting scalar A[i, j]
	NZ : (rows, cols) index sequences of equal length, e.g. A.nonzero()
	U : (n, k) float ndarray -- updated in place
	V : (k, k) float ndarray -- updated in place
	alpha : float, learning rate
	beta : float, L2 regularization strength

	Returns
	-------
	None; the trained factors are the caller's U and V arrays.
	"""
	rows, cols = NZ
	for i, j in zip(rows, cols):
		# Snapshot both rows so every gradient below is evaluated at the
		# same (pre-step) parameters, even when i == j.
		ui = U[i, :].copy()
		uj = U[j, :].copy()
		eij = A[i, j] - np.dot(ui, np.dot(V, uj))
		gradI = np.dot(V, uj)       # d(ui . V . uj) / d ui
		gradJ = np.dot(ui, V)       # d(ui . V . uj) / d uj
		gradV = np.outer(ui, uj)    # d(ui . V . uj) / d V

		U[i, :] += alpha * (2 * eij * gradI - beta * ui)
		U[j, :] += alpha * (2 * eij * gradJ - beta * uj)
		# V must be mutated in place: `V = V + ...` would only rebind the
		# local name, and the caller's V would never be trained.
		V += alpha * (2 * eij * gradV - beta * V)

def compute_error(A, NZ, U, V, beta):
	"""Return the regularized training loss of the model A ~ U V U^T.

	loss = sum over observed (i, j) of (A[i, j] - U[i] . V . U[j])^2
	       + (beta / 2) * (||U||_F^2 + ||V||_F^2)

	The penalty uses the *squared* Frobenius norm, which is the term whose
	gradient yields the `beta * U` / `beta * V` decay applied by the SGD
	updates; the plain norm would be inconsistent with them.

	Parameters
	----------
	A : dense ndarray or scipy sparse matrix/array supporting A[rows, cols]
	NZ : (rows, cols) index arrays of the observed entries, e.g. A.nonzero()
	U : (n, k) ndarray
	V : (k, k) ndarray
	beta : float, L2 regularization strength

	Returns
	-------
	float
	"""
	rows = np.asarray(NZ[0], dtype=np.intp)
	cols = np.asarray(NZ[1], dtype=np.intp)
	if rows.size:
		# A[rows, cols] is a (1, m) np.matrix for scipy sparse matrices and a
		# 1-D array for ndarrays / sparse arrays; normalize to 1-D floats.
		observed = np.asarray(A[rows, cols], dtype=float).ravel()
		# predicted[m] = U[rows[m]] . V . U[cols[m]] for all entries at once.
		predicted = np.einsum("mk,kl,ml->m", U[rows], V, U[cols])
		squared_error = float(np.sum((observed - predicted) ** 2))
	else:
		squared_error = 0.0

	penalty = beta / 2 * (np.linalg.norm(U, "fro") ** 2 + np.linalg.norm(V, "fro") ** 2)
	return squared_error + penalty

def factorize_matrix(A, U, V, iterations, alpha, beta, verbose=True):
	"""Fit A ~ U V U^T to the nonzero entries of A with repeated SGD epochs.

	Parameters
	----------
	A : dense ndarray or scipy sparse matrix; only its nonzero entries are fitted
	U : (n, k) float ndarray, initial factor -- its rows are updated in place by `train`
	V : (k, k) float ndarray, initial core matrix, handed to `train` every epoch
	iterations : int, number of SGD epochs
	alpha : float, learning rate
	beta : float, L2 regularization strength
	verbose : bool, default True -- print the regularized loss after every
		epoch. When False the loss is not computed at all, saving one full
		pass over the observed entries per epoch.

	Returns
	-------
	(U, V) such that U @ V @ U.T is the fitted score matrix.
	"""
	NZ = A.nonzero()  # observed entries; computed once and reused by every epoch
	for _ in range(iterations):
		train(A, NZ, U, V, alpha, beta)
		if verbose:
			print(compute_error(A, NZ, U, V, beta))
	# Return V itself, not V.T: `train` scores entry (i, j) as U[i] . V . U[j],
	# so the matching score matrix is U @ V @ U.T, which is what callers compute.
	return U, V

################################################################################

# A = [
# 		 [1,1,1,1,0],
# 		 [1,0,0,1,1],
# 		 [1,1,0,1,0],
# 		 [1,0,0,1,1],
# 		 [1,1,0,1,0]
# 	 ]
# A = np.array(A)
# n = A.shape[0]


G = nx.read_weighted_edgelist("out.dnc-temporalGraph", create_using=nx.MultiGraph(), comments="%")
# Order edges by their 'weight' attribute, i.e. the third column of the edge
# list (NOTE(review): presumably a timestamp for this temporal dataset --
# confirm). Edges without the attribute sort as if their weight were 1.
edges = sorted(G.edges(data=True), key=lambda t: t[2].get('weight', 1))

# Training graph: the first 10% of the ordered edges, parallel edges kept.
# The previous counter loop stopped only once counter > t_0, i.e. after
# floor(t_0) + 1 edges -- one too many whenever G.size() is a multiple of 10.
G1 = nx.MultiGraph()
t_0 = G.size() / 10
for u, v, _ in edges[:math.ceil(t_0)]:
	G1.add_edge(u, v)

# Adjacency matrix of the training graph; networkx orders its rows/columns by
# G1.nodes() and sums parallel edges into a single entry.
A = nx.adjacency_matrix(G1)
n = A.shape[0]

# Model hyperparameters.
k = 10              # number of latent features: U is (n, k), V is (k, k)
iterations = 10000  # SGD epochs over the observed entries of A
alpha = 0.0001      # learning rate
beta = 0.02         # L2 regularization strength

# Uniform [0, 1) random initialization of both factors, then the fit.
U = np.random.rand(n, k)
V = np.random.rand(k, k)
U, V = factorize_matrix(A, U, V, iterations, alpha, beta)

# Score matrix of the fitted model: P[i, j] = U[i] . V . U[j].
P = np.dot(U, np.dot(V, U.T))
num_preds = 100

# Rank candidate links in one pass. Each unordered pair {i, j} with i < j is
# scored by the larger of P[i, j] and P[j, i]. Self-pairs and pairs already
# linked in the training adjacency A are excluded.
# This replaces a repeated-argmax loop that zeroed chosen entries: once the
# remaining scores were <= 0, argmax landed on an already-zeroed entry, so
# the loop either re-appended an earlier prediction or never terminated.
A_dense = A.toarray() if hasattr(A, "toarray") else np.asarray(A)
linked = (A_dense != 0) | (A_dense.T != 0)
rows, cols = np.triu_indices(P.shape[0], k=1)
is_candidate = ~linked[rows, cols]
rows, cols = rows[is_candidate], cols[is_candidate]
pair_scores = np.maximum(P[rows, cols], P[cols, rows])
top = np.argsort(-pair_scores, kind="stable")[:num_preds]
# At most num_preds (row, col) index pairs, best-scored first.
Preds = [(int(i), int(j)) for i, j in zip(rows[top], cols[top])]
	

# Precision of the link predictions against the full graph G: a predicted
# pair is a hit when G has an edge between the two nodes. Matrix indices map
# back to node labels through G1.nodes(), the order nx.adjacency_matrix(G1)
# used for A. Build that list once instead of twice per prediction.
node_labels = list(G1.nodes())
count = sum(1 for i, j in Preds if G.has_edge(node_labels[i], node_labels[j]))

# Divide by the number of predictions actually made (normally num_preds),
# guarding against an empty prediction list.
precision = count / len(Preds) if Preds else 0.0
print(precision)
