import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import random
import operator
################################################################################
# Weak and Strong ties
# We've talked about the strength of ties and triadic closure in class. We've
# noted that diffusive processes, or information/disease/etc. spread, in a
# network can jump to distant parts of a network across weak ties. This relates
# to our notions of connectivity -- e.g., weak ties are more likely to be cut
# edges, or local cut edges, than strong ties. We'll eventually talk more about
# 'centrality', which is the notion of 'importance' within the network. Here, we
# might think of vertices that have weak ties to each other as important for our
# diffusive processes.

# Read in a sheep-based dataset in which every link carries a weight measuring
# the importance or strength of the tie. Sheep with stronger links spend more
# time around each other doing various sheep things (presumably). Lines
# starting with '%' in the data file are treated as comments.
G = nx.read_weighted_edgelist(
  "out.moreno_sheep_sheep.data",
  comments="%",
  create_using=nx.Graph(),
)

# Does the strength of local ties correlate with the number of common neighbors?
# As in the previous class, we look at the neighborhood overlap of each edge:
# count the common neighbors of its two endpoints and divide that count by the
# sum of the two endpoints' degrees. The result maps each edge (u, v), keyed in
# the orientation G.edges() yields it, to that ratio.
overlaps = {
  (u, v): sum(1 for _ in nx.common_neighbors(G, u, v)) / (G.degree(u) + G.degree(v))
  for u, v in G.edges()
}

# Next, we'll look up the actual weight of each edge and compare it to the
# strength of the tie in terms of neighborhood overlap. We sort the edges by
# weight and, for each distinct weight, average the overlap ratios computed
# above. If the hypothesis holds, these averages should increase as the edge
# weight increases.
tie_strengths = sorted(G.edges(data=True), key=lambda t: t[2].get('weight', 1))

overlap_sums = {}
overlap_counts = {}
for u, v, attrs in tie_strengths:
  # Use the same default weight as the sort key above, so an edge without an
  # explicit 'weight' attribute is grouped with weight 1 instead of raising.
  w = attrs.get('weight', 1)
  overlap_sums[w] = overlap_sums.get(w, 0.0) + overlaps[(u, v)]
  overlap_counts[w] = overlap_counts.get(w, 0) + 1

# `data` is a list of (weight, average overlap ratio) pairs. Dicts preserve
# insertion order, so it stays sorted by ascending weight. An edgeless graph
# yields an empty list rather than dividing by zero.
data = [(w, overlap_sums[w] / overlap_counts[w]) for w in overlap_sums]

# Plot whether that appears to be the case: the x axis is the edge weight and
# the y axis is the average neighborhood-overlap ratio for that weight.
weights = [w for (w, _) in data]
mean_overlaps = [m for (_, m) in data]
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(weights, mean_overlaps)
plt.show()


################################################################################
# Weak and strong ties relative to connectivity
# Suppose we want to answer the question posed above: does removing ties from
# weakest to strongest disconnect the graph more quickly than removing them
# from strongest to weakest? Let's observe this empirically.
def _tie_weight(edge):
  # `edge` is an (u, v, attrs) triple from G.edges(data=True); edges without
  # a 'weight' attribute count as weight 1.
  return edge[2].get('weight', 1)

# Both orders are stable sorts of the same edge list, so edges of equal weight
# keep G.edges() order in each.
tie_strengths_weak_to_strong = sorted(G.edges(data=True), key=_tie_weight)
tie_strengths_strong_to_weak = sorted(G.edges(data=True), key=_tie_weight, reverse=True)

# Remove edges from a copy of `graph`, in the order given by `edge_order`
# (a sequence of (u, v, attrs) triples), until the copy becomes disconnected.
# After every `draw_every`-th removal, if the copy is still connected, draw and
# show it. Returns (removed, pruned): the number of edges removed, including
# the one that disconnected the graph, and the pruned copy itself. `graph` is
# left untouched.
def _removals_until_disconnected(graph, edge_order, draw_every):
  pruned = graph.copy()
  removed = 0
  for u, v, _ in edge_order:
    pruned.remove_edge(u, v)
    removed += 1
    if not nx.is_connected(pruned):
      break
    if removed % draw_every == 0:
      nx.draw(pruned)
      plt.show()
  return removed, pruned

# First, let's see from strong to weak, drawing the graph every 10 removals.
counter, G1 = _removals_until_disconnected(G, tie_strengths_strong_to_weak, draw_every=10)
print(counter)

# Next, let's go from weak to strong to observe any difference, drawing the
# graph after every single removal.
counter, G1 = _removals_until_disconnected(G, tie_strengths_weak_to_strong, draw_every=1)
print(counter)


################################################################################
# Diffusion
# Let's run a diffusive algorithm while removing ties as above. We should
# probably expect that the more weak ties we remove, the longer the algorithm
# takes to converge; removing strong ties will likely have little effect.

# Draw G with a Kamada-Kawai layout, coloring each vertex by its state value
# in S through the 'viridis' colormap; vertices missing from S get the
# fallback value 0.25. Vertex labels are drawn in white. Blocks on plt.show().
def print_G(G, S):
  node_colors = [S.get(vertex, 0.25) for vertex in G.nodes()]
  nx.draw_kamada_kawai(
    G,
    cmap=plt.get_cmap('viridis'),
    node_color=node_colors,
    with_labels=True,
    font_color='white',
  )
  plt.show()

# Return the key of `d` whose value is largest. Ties go to the first such key
# in the dict's iteration order. An empty dict yields the sentinel -1.
def dict_argmax(d):
  if not d:
    return -1
  # max() returns the first maximal element it encounters, which gives the
  # "first key in iteration order" tie-break in a single pass.
  return max(d, key=d.get)

# State initialization: every vertex gets its own distinct integer label,
# numbered 0, 1, 2, ... in the order G.nodes() yields the vertices.
def init_state(G):
  return {vertex: label for label, vertex in enumerate(G.nodes())}

# Run an iterative label-propagation diffusion on G.
# - Every vertex starts with its own label (init_state).
# - On each pass, vertices are visited in G.nodes() order and adopt the label
#   most common among their neighbors. Updates take effect immediately within
#   the pass.
# - The graph is drawn before every pass.
# - When a full pass changes no vertex, the number of passes is printed.
def run_defusion(G):
  S = init_state(G)
  updates = 1
  iterations = 0
  while updates > 0:
    print_G(G, S)
    updates = 0
    for v in G.nodes():
      counts = {}
      for u in G.neighbors(v):
        counts[S[u]] = counts.get(S[u], 0) + 1
      if not counts:
        # An isolated vertex has no neighbor label to adopt. Keep its own
        # label rather than overwriting it with dict_argmax's -1 sentinel,
        # which would lump every isolated vertex into one shared label.
        continue
      if counts.get(S[v], 0) == max(counts.values()):
        # The current label is already among the most common ones, so keep
        # it. Switching between tied labels based only on neighbor order can
        # keep vertices flipping, so `updates` never reaches 0.
        continue
      S[v] = dict_argmax(counts)
      updates += 1
    iterations += 1
  print("Iterations:", iterations)
  return

# Run the diffusion after every 50 edge removals, removing ties from strongest
# to weakest. Set removal_order to tie_strengths_weak_to_strong to remove the
# weak ties first instead. What would we expect to see here?
removal_order = tie_strengths_strong_to_weak
G1 = G.copy()
counter = 0
for counter, (u, v, _) in enumerate(removal_order, start=1):
  G1.remove_edge(u, v)
  if counter % 50 == 0:
    run_defusion(G1)
