From 72463d07f7d28e0fcfa0742c5921a2b8e3f5937e Mon Sep 17 00:00:00 2001
From: Victor Zimmermann <zimmermann@cl.uni-heidelberg.de>
Date: Wed, 21 Mar 2018 12:47:08 +0100
Subject: [PATCH] Changed binary colouring to colour distribution.

---
 src/absinth.py | 68 ++++++++++++++++++++++++++++++--------------------
 src/config.py  | 13 +++++++++-
 2 files changed, 53 insertions(+), 28 deletions(-)

diff --git a/src/absinth.py b/src/absinth.py
index b1a11f1..94f3fd6 100644
--- a/src/absinth.py
+++ b/src/absinth.py
@@ -564,15 +564,25 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
     
     
     for node in graph.nodes:
+        
+        graph.node[node]['dist'] = [0] * len(root_hub_list)
+            
         if node in root_hub_list:
-            graph.node[node]['sense'] = root_hub_list.index(node)
+            
+            root_idx = root_hub_list.index(node)
+            
+            graph.node[node]['sense'] = root_idx
+            graph.node[node]['dist'][root_idx] = 1
+            
         else:
+            
             graph.node[node]['sense'] = None
     
     max_iteration_count = config.max_colour_iteration_count
     
     iteration_count = 0
     stable = False
+    
     while stable == False and iteration_count <= max_iteration_count:
         
         graph_copy = deepcopy(graph)
@@ -591,10 +601,14 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
                     neighbor_weight_list[graph_copy.node[neighbor]['sense']] \
                         += 1 - graph_copy[node][neighbor]['weight']
             
+            
             if any(neighbor_weight_list):
                 
+                graph.node[node]['dist'] = np.mean([graph.node[node]['dist'],
+                                                    neighbor_weight_list], axis=0)
+                
                 old_colour = graph_copy.node[node]['sense']
-                new_colour = np.argmax(neighbor_weight_list)
+                new_colour = np.argmax(graph.node[node]['dist'])
                 
                 if old_colour != new_colour:
                     stable = False
@@ -651,36 +665,37 @@ def disambiguate_colour(graph: nx.Graph, root_hub_list: list, context_list: list
             
             if text in coloured_graph.nodes:
                 
-                text_colour = coloured_graph.node[text]['sense']
+                text_colour_dist = coloured_graph.node[text]['dist']
                 
-                if text_colour == None:
+                if not any(text_colour_dist):
                     
                     pass
                 
                 else:
-                
-                    text_root = root_hub_list[text_colour]
                     
-                    if nx.has_path(coloured_graph , text, text_root):
+                    for root_hub in root_hub_list:
                         
+                        root_hub_idx = root_hub_list.index(root_hub)
                     
-                        shortest_path = nx.shortest_path(coloured_graph ,
-                                                        text,
-                                                        root_hub_list[text_colour],
-                                                        'weight')
-                        total_weight = 0
-
-                        # Add weights of every sub-path.
-                        for i in range(1, len(shortest_path)):
-                            sub_from, sub_to = shortest_path[i-1], shortest_path[i]
-                            total_weight += \
-                                coloured_graph [sub_from][sub_to]['weight']
-
-            
-                        score[text_colour] += 1/(1+total_weight)
+                        if nx.has_path(coloured_graph , text, root_hub):
+                            
+                            shortest_path = nx.shortest_path(coloured_graph ,
+                                                            text,
+                                                            root_hub,
+                                                            'weight')
+                            total_weight = 0
+
+                            # Add weights of every sub-path.
+                            for i in range(1, len(shortest_path)):
+                                sub_from, sub_to = shortest_path[i-1], shortest_path[i]
+                                total_weight += \
+                                    coloured_graph[sub_from][sub_to]['weight']
                 
-                    else:
-                        pass
+                            score[root_hub_idx] += (1/(1+total_weight)) \
+                             * colour_graph.node[text]['dist'][root_hub_idx]
+                    
+                        else:
+                            pass
             
             else:
                 pass
@@ -768,7 +783,7 @@ def disambiguate_mst(graph: nx.Graph, root_hub_list: list,
                 pass
         
         # If disambiguator does not detect a sense, return singleton.
-        if any(score_array):
+        if not any(score_array):
             
             pass
             
@@ -801,8 +816,7 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None:
     """
     
     print('[a]', 'Inducing word senses for {}.'.format(topic_name))
-    graph, root_hub_list, stat_dict = induce(topic_name,
-                                                              result_dict[topic_id])
+    graph, root_hub_list, stat_dict = induce(topic_name,result_dict[topic_id])
     
     colour_rank = config.colour_rank
     mst_rank = config.mst_rank
@@ -851,7 +865,7 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None:
                     
                 else:
                     merged_mapping_dict[topic] = [result]
-    
+                    
     stat_dict['merge_gain'] = merged_entry_count
     
     #collect statistics from result.
diff --git a/src/config.py b/src/config.py
index d40c5f2..5f6471e 100644
--- a/src/config.py
+++ b/src/config.py
@@ -11,6 +11,18 @@ dataset = "../WSI-Evaluator/datasets/dataset/"
 test = "../WSI-Evaluator/datasets/trial/"
 output = "../output/"
 
+'''
+Disambiguation Pipeline
+There are multiple disambiguation methods implemented. Specify the order in
+which they should be merged and whether or not conflicts should be resolved.
+If conflicts are not resolved, the first method with a positive result is used.
+Methods labeled with 0 are ignored.
+At least one method must be given a value != 0.
+'''
+resolve_conflicts = False #not yet implemented
+mst_rank = 0
+colour_rank = 1
+
 '''
 Choose stop words and allowed pos-tags.
 - Stop words will not be considered for nodes.
@@ -54,5 +66,4 @@ lemma = False
 '''
 colouring options
 '''
-use_colouring = True
 max_colour_iteration_count = 50
-- 
GitLab