import networkx as nx
import numpy as np
import scipy as sp
import scipy.cluster.vq as vq
import matplotlib.pyplot as plt
import math
import random
import operator
import itertools

################################################################################
# Source: https://towardsdatascience.com/graph-neural-networks-in-python-c310c7c18c83
################################################################################
import torch
import torch_geometric

################################################################################
# We'll load the Karate Club dataset. Its vertices are split into 4 classes,
# which roughly correspond to communities. Our goal is to classify 30 vertices
# using only 4 labelled 'seed' vertices, in a semi-supervised approach.
dataset = torch_geometric.datasets.KarateClub()
# Summarise the dataset as a whole.
for label, value in (
    ("Dataset:", dataset),
    ("# Graphs:", len(dataset)),
    ("# Features:", dataset.num_features),
    ("# Classes:", dataset.num_classes),
):
  print(label, value)

# Inspect the (single) graph and count how many vertices carry training labels.
data = dataset[0]
print(data)
print("Training nodes:", data.train_mask.sum().item())
print("Is directed:", data.is_directed())

# Draw the graph as undirected, colouring each vertex by its ground-truth class.
G = torch_geometric.utils.to_networkx(data, to_undirected=True)
nx.draw(G, node_color=data.y, node_size=150)
plt.show()

################################################################################
# We define the graph convolutional network here. Three GCN layers reduce each
# vertex to a two-dimensional embedding, and a linear layer maps that embedding
# to scores for each of the four classes. Because every layer aggregates over a
# vertex's immediate neighbours, each vertex's output depends on feature data
# from vertices up to 3 hops away.
class GCN(torch.nn.Module):
  """Three-layer graph convolutional network followed by a linear classifier.

  Each GCNConv layer is followed by a tanh non-linearity. The output of the
  third layer is a small per-vertex embedding (2-D by default), which the
  linear classifier maps to class logits.

  Args:
    num_features: size of each input vertex feature vector; defaults to
      ``dataset.num_features`` (the module-level dataset).
    num_classes: number of output classes; defaults to
      ``dataset.num_classes``.
    hidden_dim: output width of the first two convolution layers.
    embedding_dim: output width of the third convolution layer, i.e. the
      embedding fed to the classifier.
    seed: value passed to ``torch.manual_seed`` before the layers are built
      so the initial weights are reproducible; ``None`` skips seeding.
  """

  def __init__(self, num_features=None, num_classes=None, hidden_dim=4,
               embedding_dim=2, seed=42):
    super().__init__()
    if num_features is None:
      num_features = dataset.num_features
    if num_classes is None:
      num_classes = dataset.num_classes
    if seed is not None:
      # Seeds torch's global RNG (a side effect) so layer initialisation below
      # is deterministic.
      torch.manual_seed(seed)
    self.conv1 = torch_geometric.nn.GCNConv(num_features, hidden_dim)
    self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, hidden_dim)
    self.conv3 = torch_geometric.nn.GCNConv(hidden_dim, embedding_dim)
    self.classifier = torch.nn.Linear(embedding_dim, num_classes)

  def forward(self, x, edge_index):
    """Run the network on a graph.

    Args:
      x: vertex feature matrix.
      edge_index: graph connectivity as consumed by ``GCNConv``.

    Returns:
      ``(out, h)``: the classifier logits and the tanh-activated output of
      the third convolution layer (the embedding), one row per vertex.
    """
    h = x
    for conv in (self.conv1, self.conv2, self.conv3):
      h = conv(h, edge_index).tanh()
    out = self.classifier(h)
    return out, h

model = GCN()
print(model)

################################################################################
# Here we set up the loss and optimizer and train the model defined above. We
# train for 300 epochs, recording the output of the third convolution layer at
# each step; that output is a 2D embedding we can visualize.
# Cross-entropy over the class logits, optimised with Adam (learning rate 0.01).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def train(data, *, net=None, opt=None, loss_fn=None):
  """Run one optimisation step using only the labelled (training) vertices.

  Args:
    data: graph object exposing ``x``, ``edge_index``, ``y`` and
      ``train_mask``.
    net: module to train; defaults to the module-level ``model``.
    opt: optimizer stepping ``net``'s parameters; defaults to ``optimizer``.
    loss_fn: loss applied to the masked logits; defaults to ``criterion``.

  Returns:
    ``(loss, h)``: the scalar training loss and the embedding returned by the
    forward pass. Both are detached from the autograd graph, so callers can
    keep them across epochs without retaining every step's history.
  """
  net = model if net is None else net
  opt = optimizer if opt is None else opt
  loss_fn = criterion if loss_fn is None else loss_fn

  net.train()
  opt.zero_grad()
  out, h = net(data.x, data.edge_index)
  # Only the seed vertices contribute to the loss (semi-supervised setting).
  mask = data.train_mask
  loss = loss_fn(out[mask], data.y[mask])
  loss.backward()
  opt.step()
  return loss.detach(), h.detach()

# Train for a fixed number of epochs, recording the loss and the 2-D embedding
# after every step so the training dynamics can be animated afterwards.
epochs = range(300)
losses = []
embeddings = []

for epoch in epochs:
  loss, h = train(data)
  # Detach before storing: appending tensors that are still attached to the
  # autograd graph would keep every epoch's graph alive as long as these
  # lists exist.
  losses.append(loss.detach())
  embeddings.append(h.detach())
  print(f"Epoch: {epoch}\tLoss: {loss.item():.4f}")

################################################################################
# Using the embeddings recorded across all 300 epochs, we can watch how our GNN
# evolves during training. We would expect the vertices of each ground-truth
# class to cluster together in 2D space. Note the similarity between these
# initial outputs and some of the embeddings we produced via spectral methods
# on the same network.

import matplotlib.animation as animation

def animate(i):
  """Draw frame ``i``: the embedding recorded after training step ``i``.

  Clears the shared axes, scatters every vertex coloured by its ground-truth
  class, titles the frame with its epoch and loss, and fixes both axes to
  [-1.1, 1.1]. The embedding comes out of a tanh, so it stays within [-1, 1].
  """
  ax.clear()
  points = embeddings[i].detach().numpy()
  ax.scatter(points[:, 0], points[:, 1], c=data.y, s=100)
  ax.set_title(f'Epoch: {epochs[i]}, Loss: {losses[i].item():.4f}')
  bounds = [-1.1, 1.1]
  ax.set_xlim(bounds)
  ax.set_ylim(bounds)

fig = plt.figure(figsize=(6, 6))
ax = plt.axes()
# ``animate`` treats its frame argument as an index into ``embeddings`` and
# ``losses``, so frames are 0..len(embeddings)-1 rather than the epoch labels.
# Keep ``anim`` referenced: if the animation object is garbage-collected, the
# animation stops.
anim = animation.FuncAnimation(fig, animate, frames=len(embeddings))
plt.show()
