import networkx as nx
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math
import random

################################################################################


def train(A, NZ, U, V, alpha, beta):
	"""Run one SGD epoch over the observed entries of ``A``, updating U and V in place.

	For each observed entry (i, j) the per-entry loss
	    (A[i, j] - U[i, :] . V[:, j])**2 + beta/2 * (|U[i, :]|**2 + |V[:, j]|**2)
	is decreased by one gradient step of size ``alpha`` on both factors.

	Args:
		A: matrix indexable as ``A[i, j]`` (dense array or scipy sparse matrix).
		NZ: ``(rows, cols)`` index arrays of the entries to train on, e.g. ``A.nonzero()``.
		U: factor whose row ``i`` corresponds to row ``i`` of ``A``; modified in place.
		V: factor whose column ``j`` corresponds to column ``j`` of ``A``
			(i.e. already transposed); modified in place.
		alpha: learning rate.
		beta: L2 regularization strength.
	"""
	for i, j in zip(NZ[0], NZ[1]):
		# Snapshot the row: both factor updates must use the pre-update values,
		# otherwise V's step would be computed from the already-moved U row.
		u_row = U[i, :].copy()
		v_col = V[:, j]
		eij = A[i, j] - np.dot(u_row, v_col)
		U[i, :] = u_row + alpha * (2 * eij * v_col - beta * u_row)
		V[:, j] = v_col + alpha * (2 * eij * u_row - beta * v_col)

def compute_error(A, NZ, U, V, beta):
	"""Return the regularized squared-error objective of the factorization.

	The value is the sum over the observed entries ``(i, j)`` in ``NZ`` of
	``(A[i, j] - U[i, :] . V[:, j])**2``, plus
	``beta/2 * (||U||_F**2 + ||V||_F**2)``.

	Args:
		A: matrix indexable as ``A[i, j]`` (dense array or scipy sparse matrix).
		NZ: ``(rows, cols)`` index arrays of the observed entries.
		U: left factor; row ``i`` pairs with row ``i`` of ``A``.
		V: right factor, already transposed; column ``j`` pairs with column ``j`` of ``A``.
		beta: L2 regularization strength.

	Returns:
		The objective value (float).
	"""
	error = 0.0
	for i, j in zip(NZ[0], NZ[1]):
		error += (A[i, j] - np.dot(U[i, :], V[:, j])) ** 2
	# The norms must be squared: beta/2 * ||U||_F**2 has gradient beta * U,
	# which is the weight-decay term the SGD update in train() applies.
	error += beta / 2 * np.linalg.norm(U, 'fro') ** 2
	error += beta / 2 * np.linalg.norm(V, 'fro') ** 2
	return error

def factorize_matrix(A, U, V, iterations, alpha, beta, *, verbose=True):
	"""Factorize ``A`` as ``U @ V`` by SGD over its non-zero entries.

	Zero entries of ``A`` are treated as missing and are not trained on.

	Args:
		A: matrix to factorize (dense array or scipy sparse matrix).
		U: initial left factor; row ``i`` pairs with row ``i`` of ``A``.
			Updated in place.
		V: initial right factor, passed ALREADY TRANSPOSED. Column ``j`` pairs
			with column ``j`` of ``A``. Updated in place.
		iterations: number of SGD epochs over the observed entries.
		alpha: learning rate.
		beta: L2 regularization strength.
		verbose: if true (the default), print the objective after each epoch.

	Returns:
		``(U, V.T)``. The second factor is returned un-transposed, so the
		reconstruction is ``U @ V.T`` where ``V`` is the returned factor.
	"""
	NZ = A.nonzero()
	for _ in range(iterations):
		train(A, NZ, U, V, alpha, beta)
		# The objective is only used for reporting, so skip the extra pass
		# over the data when nothing will be printed.
		if verbose:
			print(compute_error(A, NZ, U, V, beta))
	return U, V.T

################################################################################
# Toy example: factorize a small 5x4 ratings matrix in which 0 marks a missing rating.
A = np.array([
	[5, 3, 0, 1],
	[4, 0, 0, 1],
	[1, 1, 0, 5],
	[1, 0, 0, 4],
	[0, 1, 5, 4],
])
n, m = A.shape
k = 5

# Random initial factors (U is drawn before V, as before).
U = np.random.rand(n, k)
V = np.random.rand(m, k)

iterations, alpha, beta = 10000, 0.0002, 0.02
(U, V) = factorize_matrix(A, U, V.T, iterations, alpha, beta)
# Dense reconstruction: one predicted rating for every cell, including the missing ones.
P = np.dot(U, V.T)

################################################################################
# Rating prediction: hide `test_size` observed ratings, train the factorization
# on the rest, and compare its mean absolute error with a naive baseline that
# averages the user's mean rating and the video's mean rating.
B = nx.read_edgelist("amazon_video.data", create_using=nx.Graph(), comments="%", data=(("rating", float),("time",int)))

# Compute the bipartition once.
# NOTE(review): nx.bipartite.sets raises AmbiguousSolution on a disconnected
# graph; confirm the full graph is connected, or restrict it as the
# link-prediction section does.
users, videos = nx.bipartite.sets(B)
A = nx.bipartite.biadjacency_matrix(B, row_order=users, column_order=videos, weight='rating')

NZ = A.nonzero()
test_size = 100
test = random.sample(range(len(NZ[0])), test_size)

# Hide the held-out ratings from training.
test_truths = {}
for t in test:
	i, j = NZ[0][t], NZ[1][t]
	test_truths[t] = A[i, j]
	A[i, j] = 0
# Assigning 0 leaves explicit zeros stored in the sparse matrix. Drop them so
# that `.data` below holds only the ratings still visible for training.
A.eliminate_zeros()

# Naive baseline: the average of the user's mean rating and the video's mean
# rating. Use whichever side still has ratings; if neither does, fall back to
# the global mean instead of producing NaN.
global_mean = A.data.mean()
p_naive = {}
for t in test:
	user_ratings = A[NZ[0][t], :].data
	video_ratings = A[:, NZ[1][t]].data
	means = [r.mean() for r in (user_ratings, video_ratings) if r.size > 0]
	p_naive[t] = sum(means) / len(means) if means else global_mean

n, m = A.shape
k = 40
U = np.random.rand(n, k)
V = np.random.rand(m, k)

iterations = 10
alpha = 0.01
beta = 0.02
(U, V) = factorize_matrix(A, U, V.T, iterations, alpha, beta)

P = np.dot(U, V.T)
# Mean absolute error over the held-out ratings: model vs. naive baseline.
test_error = 0
test_error_naive = 0
for t in test:
	i, j = NZ[0][t], NZ[1][t]
	print("Test:", test_truths[t], P[i, j])
	test_error += abs(P[i, j] - test_truths[t])
	test_error_naive += abs(p_naive[t] - test_truths[t])

print("Test Error:", test_error / test_size)
print("Test Error Naive:", test_error_naive / test_size)

################################################################################
# Link prediction: train on the chronologically first half of the edges, rank
# the user-video pairs not seen in training by predicted score, and measure how
# many of the top `num_preds` pairs are edges of the full graph.
B = nx.read_edgelist("amazon_video.data", create_using=nx.Graph(), comments="%", data=(("rating", float),("time",int)))
edges = sorted(B.edges(data=True), key=lambda t: t[2].get('time', 1))

# Training graph: exactly the earliest half of the edges.
B1 = nx.Graph()
for u, v, attrs in edges[:len(edges) // 2]:
	B1.add_edge(u, v, weight=attrs['rating'])

# Keep the largest connected component so that the bipartition is unique.
B1 = B1.subgraph(max(nx.connected_components(B1), key=len))
user_set, video_set = nx.bipartite.sets(B1)
# Fix one ordering per side and reuse it for the matrix AND for mapping indices
# back to nodes: matrix row r is users[r], column c is videos[c].
# NOTE(review): bipartite.sets does not say which side is the users; the
# names are nominal.
users, videos = list(user_set), list(video_set)
A = nx.bipartite.biadjacency_matrix(B1, row_order=users, column_order=videos, weight='weight')

n, m = A.shape
k = 10
U = np.random.rand(n, k)
V = np.random.rand(m, k)

iterations = 100
alpha = 0.0002
beta = 0.02
(U, V) = factorize_matrix(A, U, V.T, iterations, alpha, beta)

P = np.dot(U, V.T)

# Exclude pairs already in the training graph by giving them -inf, then take
# the highest-scoring remaining pairs. A stable sort keeps the first index
# among ties, matching np.argmax.
num_preds = 1000
P[A.nonzero()] = -np.inf
num_candidates = int(np.isfinite(P).sum())
top = np.argsort(-P, axis=None, kind='stable')[:min(num_preds, num_candidates)]
Preds = [np.unravel_index(flat, P.shape) for flat in top]
for idx in Preds:
	print(idx, P[idx])

# A prediction is a hit if the pair is an edge of the full (all-time) graph.
count = sum(1 for r, c in Preds if B.has_edge(users[r], videos[c]))

precision = count / len(Preds) if Preds else 0.0
print("Precision:", precision)

