diff --git a/src/absinth.py b/src/absinth.py index b1a11f187180761ffa024e74d7e0460abb3a07dd..94f3fd67acfa5950315309f1b37bc11b91a9a02e 100644 --- a/src/absinth.py +++ b/src/absinth.py @@ -564,15 +564,25 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: for node in graph.nodes: + + graph.node[node]['dist'] = [0] * len(root_hub_list) + if node in root_hub_list: - graph.node[node]['sense'] = root_hub_list.index(node) + + root_idx = root_hub_list.index(node) + + graph.node[node]['sense'] = root_idx + graph.node[node]['dist'][root_idx] = 1 + else: + graph.node[node]['sense'] = None max_iteration_count = config.max_colour_iteration_count iteration_count = 0 stable = False + while stable == False and iteration_count <= max_iteration_count: graph_copy = deepcopy(graph) @@ -591,10 +601,14 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: neighbor_weight_list[graph_copy.node[neighbor]['sense']] \ += 1 - graph_copy[node][neighbor]['weight'] + if any(neighbor_weight_list): + graph.node[node]['dist'] = np.mean([graph.node[node]['dist'], + neighbor_weight_list], axis=0) + old_colour = graph_copy.node[node]['sense'] - new_colour = np.argmax(neighbor_weight_list) + new_colour = np.argmax(graph.node[node]['dist']) if old_colour != new_colour: stable = False @@ -651,36 +665,37 @@ def disambiguate_colour(graph: nx.Graph, root_hub_list: list, context_list: list if text in coloured_graph.nodes: - text_colour = coloured_graph.node[text]['sense'] + text_colour_dist = coloured_graph.node[text]['dist'] - if text_colour == None: + if not any(text_colour_dist): pass else: - - text_root = root_hub_list[text_colour] - if nx.has_path(coloured_graph , text, text_root): + for root_hub in root_hub_list: + root_hub_idx = root_hub_list.index(root_hub) - shortest_path = nx.shortest_path(coloured_graph , - text, - root_hub_list[text_colour], - 'weight') - total_weight = 0 - - # Add weights of every sub-path. - for i in range(1, len(shortest_path)): - sub_from, sub_to = shortest_path[i-1], shortest_path[i] - total_weight += \ - coloured_graph [sub_from][sub_to]['weight'] - - - score[text_colour] += 1/(1+total_weight) + if nx.has_path(coloured_graph , text, root_hub): + + shortest_path = nx.shortest_path(coloured_graph , + text, + root_hub, + 'weight') + total_weight = 0 + + # Add weights of every sub-path. + for i in range(1, len(shortest_path)): + sub_from, sub_to = shortest_path[i-1], shortest_path[i] + total_weight += \ + coloured_graph[sub_from][sub_to]['weight'] - else: - pass + score[root_hub_idx] += (1/(1+total_weight)) \ + * colour_graph.node[text]['dist'][root_hub_idx] + + else: + pass else: pass @@ -768,7 +783,7 @@ def disambiguate_mst(graph: nx.Graph, root_hub_list: list, pass # If disambiguator does not detect a sense, return singleton. - if any(score_array): + if not any(score_array): pass @@ -801,8 +816,7 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None: """ print('[a]', 'Inducing word senses for {}.'.format(topic_name)) - graph, root_hub_list, stat_dict = induce(topic_name, - result_dict[topic_id]) + graph, root_hub_list, stat_dict = induce(topic_name,result_dict[topic_id]) colour_rank = config.colour_rank mst_rank = config.mst_rank @@ -851,7 +865,7 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None: else: merged_mapping_dict[topic] = [result] - + stat_dict['merge_gain'] = merged_entry_count #collect statistics from result. diff --git a/src/config.py b/src/config.py index d40c5f2612036a91c03a835a5a512db029ea53ee..5f6471e67c947e7e9269ed0f765486059662ad82 100644 --- a/src/config.py +++ b/src/config.py @@ -11,6 +11,18 @@ dataset = "../WSI-Evaluator/datasets/dataset/" test = "../WSI-Evaluator/datasets/trial/" output = "../output/" +''' +Disambiguation Pipeline +There are multiple disambiguation methods implemented. Specify the order in +which they should be merged and whether or not conflicts should be resolved. +If conflicts are not resolved, the first method with a positive result is used. +Methods labeled with 0 are ignored. +At least one method must be given a value != 0. +''' +resolve_conflicts = False #not yet implemented +mst_rank = 0 +colour_rank = 1 + ''' Choose stop words and allowed pos-tags. - Stop words will not be considered for nodes. @@ -54,5 +66,4 @@ lemma = False ''' colouring options ''' -use_colouring = True max_colour_iteration_count = 50