From adab197463c2fb95a5cfdf5bbfc96366d7b2242f Mon Sep 17 00:00:00 2001
From: Victor Zimmermann <zimmermann@cl.uni-heidelberg.de>
Date: Wed, 21 Mar 2018 18:17:24 +0100
Subject: [PATCH] Colouring algorithm now converges.

---
 src/absinth.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/absinth.py b/src/absinth.py
index 294e55c..d6d22ae 100644
--- a/src/absinth.py
+++ b/src/absinth.py
@@ -561,14 +561,14 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
     
     for node in graph.nodes:
         
-        graph.node[node]['dist'] = [0] * len(root_hub_list)
+        graph.node[node]['dist'] = [[0] * len(root_hub_list)]
             
         if node in root_hub_list:
             
             root_idx = root_hub_list.index(node)
             
             graph.node[node]['sense'] = root_idx
-            graph.node[node]['dist'][root_idx] = 1
+            graph.node[node]['dist'][0][root_idx] = 1
             
         else:
             
@@ -600,11 +600,10 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
             
             if any(neighbor_weight_list):
                 
-                graph.node[node]['dist'] = np.mean([graph.node[node]['dist'],
-                                                    neighbor_weight_list], axis=0)
+                graph.node[node]['dist'].append(neighbor_weight_list)
                 
                 old_colour = graph_copy.node[node]['sense']
-                new_colour = np.argmax(graph.node[node]['dist'])
+                new_colour = np.argmax(np.mean(graph.node[node]['dist'], axis=0))
                 
                 if old_colour != new_colour:
                     stable = False
@@ -617,6 +616,10 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
                 
                 pass
     
+    for node in graph.nodes:
+        
+        graph.node[node]['dist'] = np.mean(graph.node[node]['dist'], axis=0)
+    
     return graph
     
     
@@ -825,6 +828,10 @@ def print_stats(stat_dict: dict) -> None:
     
     print('\n[A] '+'\n[A] '.join(stat_string)+'\n')
     
+    with open('statistics.txt', 'a') as stat_file:
+        
+        stat_file.write('\n '.join(stat_string)+'\n\n')
+    
     write_header = not os.path.exists('.statistics.tsv')
     
     with open('.statistics.tsv', 'a') as stat_file:
@@ -834,6 +841,9 @@ def print_stats(stat_dict: dict) -> None:
             stat_file.write('\t'.join(key_list)+'\n')
             
         stat_file.write('\t'.join([str(stat_dict[key]) for key in key_list])+'\n')
+        
+    
+        
             
         
 
-- 
GitLab