From 677a5f023c31dd9f86e1b77a448f76e374388205 Mon Sep 17 00:00:00 2001 From: Victor Zimmermann <zimmermann@cl.uni-heidelberg.de> Date: Wed, 21 Mar 2018 11:20:54 +0100 Subject: [PATCH] Minor bugfixes. --- src/absinth.py | 64 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/src/absinth.py b/src/absinth.py index 13aaa3d..b1a11f1 100644 --- a/src/absinth.py +++ b/src/absinth.py @@ -30,7 +30,6 @@ import networkx as nx # for visualisation import numpy as np import os # for reading files import pprint -import random import re import spacy # for nlp @@ -548,10 +547,12 @@ def induce(topic_name: str, result_list: list) -> (nx.Graph, list, dict): return graph, root_hub_list, stat_dict + def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: """Colours graph accoring to root hubs. - Evolving network that colours neighboring nodes iterative. + Evolving network that colours neighboring nodes iterative. See sentiment + propagation. Args: graph: Weighted undirected graph. @@ -608,6 +609,7 @@ def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph: return graph + def disambiguate_colour(graph: nx.Graph, root_hub_list: list, context_list: list) -> dict: """Clusters senses to root hubs using a coloured graph. @@ -802,23 +804,61 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None: graph, root_hub_list, stat_dict = induce(topic_name, result_dict[topic_id]) + colour_rank = config.colour_rank + mst_rank = config.mst_rank + + #Merges Mappings according to pipeline + mapping_dict = dict() + #matches senses to clusters - print('[a]', 'Disambiguating result_list.\t('+topic_name+')') - if config.use_colouring == True: + print('[a]', 'Disambiguating results.\t('+topic_name+')') + if colour_rank != 0: + print('[a]', 'Colouring graph.\t('+topic_name+')') - mapping_dict = disambiguate_colour(graph, root_hub_list, - result_dict[topic_id]) - else: + mapping_dict[colour_rank] = disambiguate_colour(graph, root_hub_list, + result_dict[topic_id]) + + if mst_rank != 0: print('[a]', 'Building minimum spanning tree.\t('+topic_name+')') - mapping_dict = disambiguate_mst(graph, root_hub_list, - result_dict[topic_id], topic_name) + mapping_dict[mst_rank] = disambiguate_mst(graph, + root_hub_list, + result_dict[topic_id], + topic_name) + + mapping_list = [item[1] for item in sorted(mapping_dict.items())] + mapping_count = len(mapping_list) + + merged_mapping_dict = mapping_list[0] + merged_entry_count = 0 + + for i in range(1,mapping_count): + + result_list = [result for result_list in merged_mapping_dict.values() + for result in result_list] + #individual mappings + relation_list = [(topic,result) for topic in mapping_list[i].keys() + for result in mapping_list[i][topic]] + + for topic, result in relation_list: + + if result not in result_list: + + merged_entry_count += 1 + + if topic in merged_mapping_dict: + merged_mapping_dict[topic].append(result) + + else: + merged_mapping_dict[topic] = [result] + + stat_dict['merge_gain'] = merged_entry_count #collect statistics from result. cluster_count = 0 cluster_length_list = list() - for cluster,result_list in mapping_dict.items(): + for cluster,result_list in merged_mapping_dict.items(): cluster_length = len(result_list) @@ -839,7 +879,7 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None: output_file.write('subTopicID\tresultID\n') - for cluster_id,result_list in mapping_dict.items(): + for cluster_id,result_list in merged_mapping_dict.items(): for result_id in result_list: output_line = '{}.{}\t{}.{}\n'.format(topic_id, cluster_id, topic_id, result_id) @@ -874,4 +914,4 @@ if __name__ == '__main__': with Pool(process_count) as pool: parameter_list = [(topic_id, topic_name, result_dict) for topic_id,topic_name in topic_dict.items()] - pool.starmap(main, parameter_list) + pool.starmap(main, sorted(parameter_list)) #determineate function -- GitLab