diff --git a/src/absinth.py b/src/absinth.py index 294e55c4bc47be478f21cc7ff596ef1f52b525b7..d6d22ae0e0e731131270b1a239e7805141ac0f35 100644 --- a/src/absinth.py +++ b/src/absinth.py @@ -561,14 +561,14 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: for node in graph.nodes: - graph.node[node]['dist'] = [0] * len(root_hub_list) + graph.node[node]['dist'] = [[0] * len(root_hub_list)] if node in root_hub_list: root_idx = root_hub_list.index(node) graph.node[node]['sense'] = root_idx - graph.node[node]['dist'][root_idx] = 1 + graph.node[node]['dist'][0][root_idx] = 1 else: @@ -600,11 +600,10 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: if any(neighbor_weight_list): - graph.node[node]['dist'] = np.mean([graph.node[node]['dist'], - neighbor_weight_list], axis=0) + graph.node[node]['dist'].append(neighbor_weight_list) old_colour = graph_copy.node[node]['sense'] - new_colour = np.argmax(graph.node[node]['dist']) + new_colour = np.argmax(np.mean(graph.node[node]['dist'], axis=0)) if old_colour != new_colour: stable = False @@ -617,6 +616,10 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: pass + for node in graph.nodes: + + graph.node[node]['dist'] = np.mean(graph.node[node]['dist'], axis=0) + return graph @@ -825,6 +828,10 @@ def print_stats(stat_dict: dict) -> None: print('\n[A] '+'\n[A] '.join(stat_string)+'\n') + with open('statistics.txt', 'a') as stat_file: + + stat_file.write('\n '.join(stat_string)+'\n\n') + write_header = not os.path.exists('.statistics.tsv') with open('.statistics.tsv', 'a') as stat_file: @@ -834,6 +841,9 @@ def print_stats(stat_dict: dict) -> None: stat_file.write('\t'.join(key_list)+'\n') stat_file.write('\t'.join([str(stat_dict[key]) for key in key_list])+'\n') + + +