Commit d01631b5 authored by Victor Zimmermann

Merge branch 'colour_dist' into 'master'

Better Statistics, bugfixes.

See merge request zimmermann/absinth!2
parents 42d66c63 6ee3d178
@@ -32,9 +32,11 @@

import os # for reading files
import pprint
import re
import spacy # for nlp
import time
from copy import deepcopy
from multiprocessing import Pool
from scipy import stats

nlp = spacy.load('en') # standard english nlp
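The new `scipy.stats` import supplies the harmonic mean used by the new cluster statistics (see `print_stats` below). A minimal sketch of how the two averages differ, with invented cluster sizes:

```python
import numpy as np
from scipy import stats

cluster_length_list = [2, 3, 50]  # hypothetical cluster sizes

print(np.mean(cluster_length_list))    # arithmetic mean: ~18.33
print(stats.hmean(cluster_length_list))  # harmonic mean: ~3.52, pulled toward the small clusters
```

The harmonic mean is dominated by the smallest clusters, so reporting both makes skewed cluster-size distributions visible.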
@@ -499,53 +501,46 @@ def induce(topic_name: str, result_list: list) -> (nx.Graph, list, dict):

    stat_dict = dict()

    if topic_name in [output_file_name.replace('.absinth', '')
                      for output_file_name in os.listdir(config.output)]:
        return None

    stat_dict['target'] = topic_name

    #in topics longer than two words, the leading 'the' can generally be removed without changing the sense
    if topic_name[:4] == 'the_' and topic_name.count('_') > 1:
        target_string = topic_name[4:]
    else:
        target_string = topic_name

    print('[a]', 'Counting nodes and edges.\t('+topic_name+')')
    node_freq_dict, edge_freq_dict = frequencies(target_string, result_list)

    #builds graph from these dictionaries, also applies multiple filters
    print('[a]', 'Building graph.\t('+topic_name+')')
    graph = build_graph(node_freq_dict, edge_freq_dict)

    stat_dict['node count'] = len(graph.nodes)
    stat_dict['edge count'] = len(graph.edges)

    #finds root hubs (senses) within the graph + more filters for these
    print('[a]', 'Collecting root hubs.\t('+topic_name+')')
    root_hub_list = root_hubs(graph, edge_freq_dict)

    #adds sense inventory to buffer with some common neighbors for context
    stat_dict['hubs'] = dict()
    for root_hub in root_hub_list:

        by_frequency = lambda node: edge_freq_dict[root_hub, node] \
                                    if root_hub < node \
                                    else edge_freq_dict[node, root_hub]

        most_frequent_neighbor_list = sorted(graph.adj[root_hub],
                                             key=by_frequency, reverse=True)

        stat_dict['hubs'][root_hub] = most_frequent_neighbor_list[:6]

    return graph, root_hub_list, stat_dict
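The `root_hub < node` test in `by_frequency` exists because each undirected edge is stored under one canonical key: the alphabetically ordered word pair. A self-contained sketch of that convention, with a toy `edge_freq_dict` (hypothetical counts; in absinth the dictionary comes from `frequencies`):

```python
# Hypothetical co-occurrence counts; each undirected edge is stored once,
# under the alphabetically ordered pair.
edge_freq_dict = {('bank', 'money'): 25, ('bank', 'river'): 12}

def edge_freq(word_a: str, word_b: str) -> int:
    # Look the pair up in its canonical (sorted) order.
    if word_a < word_b:
        return edge_freq_dict[word_a, word_b]
    return edge_freq_dict[word_b, word_a]

assert edge_freq('river', 'bank') == edge_freq('bank', 'river') == 12
```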
def colour_graph(graph: nx.Graph, root_hub_list: list) -> nx.Graph:
@@ -692,9 +687,10 @@ def disambiguate_colour(graph: nx.Graph, root_hub_list: list, context_list: list

                        coloured_graph[sub_from][sub_to]['weight']

                score[root_hub_idx] += (1/(1+total_weight)) \
                                       * coloured_graph.node[text]['dist'][root_hub_idx]

            else:
                pass

        else:
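The one-word rename in this hunk fixes a real bug: `colour_graph` is the function defined above, while `coloured_graph` is the local variable holding the graph it returns, so the old line would have raised an `AttributeError` at runtime. The surrounding score update discounts a sense's colour distribution by the weight of the path to the context word. A toy illustration of just the discount factor (values invented):

```python
# Hypothetical path weights: a larger total_weight means weaker evidence,
# so the colour distribution contributes less to the score.
for total_weight in (0.0, 1.0, 9.0):
    discount = 1 / (1 + total_weight)
    print(total_weight, discount)  # -> 1.0, 0.5, 0.1
```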
@@ -799,6 +795,34 @@ def disambiguate_mst(graph: nx.Graph, root_hub_list: list,

    return mapping_dict


def print_stats(stat_dict: dict) -> None:
    """Prints various statistics and logs them to file.

    Args:
        stat_dict: Dictionary with various statistics.

    """
    stat_string = []

    ts = time.gmtime()

    stat_string.append('[A] Topic:\t{}.'.format(stat_dict['target']))
    stat_string.append('[A] Processed {} at {}.'.format(
        time.strftime("%Y-%m-%d", ts), time.strftime("%H:%M:%S", ts)))
    stat_string.append('[A] Nodes: {}\tEdges: {}.'.format(
        stat_dict['node count'], stat_dict['edge count']))
    stat_string.append('[A] Mean cluster length (harmonic):\t{}.'.format(
        stat_dict['hmean_cluster_length']))
    stat_string.append('[A] Mean cluster length (arithmetic):\t{}.'.format(
        stat_dict['mean_cluster_length']))
    stat_string.append('[A] Number of clusters: {}.'.format(
        stat_dict['cluster_count']))
    stat_string.append('[A] Tuples gained through merging: {}.'.format(
        stat_dict['merge_gain']))
    stat_string.append('[A] Sense inventory:')
    for hub in stat_dict['hubs'].keys():
        stat_string.append('[A]  {}:\t{}.'.format(
            hub, ", ".join(stat_dict['hubs'][hub])))

    with open('stats.txt', 'a') as stat_file:
        stat_file.write('\n'.join(stat_string)+'\n\n')

    print('\n'+'\n'.join(stat_string)+'\n')
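A usage sketch for the new `print_stats`, with a hand-built `stat_dict` (all values invented; in the pipeline the dictionary is assembled by `induce` and `main`):

```python
stat_dict = {
    'target': 'cold_war',          # invented example topic
    'node count': 2500,
    'edge count': 11000,
    'hmean_cluster_length': 3.5,
    'mean_cluster_length': 18.3,
    'cluster_count': 5,
    'merge_gain': 12,
    'hubs': {'war': ['soviet', 'nuclear', 'policy'],
             'weather': ['winter', 'ice', 'temperature']},
}

print_stats(stat_dict)  # prints the report and appends it to stats.txt
```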
def main(topic_id: int, topic_name: str, result_dict: dict) -> None:
"""Calls induction and disambiguation functions, performs main task.
@@ -815,92 +839,101 @@ def main(topic_id: int, topic_name: str, result_dict: dict) -> None:

    """
    if topic_name in [output_file_name.replace('.absinth', '')
                      for output_file_name in os.listdir(config.output)]:
        return None

    print('[a]', 'Inducing word senses for {}.'.format(topic_name))
    graph, root_hub_list, stat_dict = induce(topic_name, result_dict[topic_id])

    colour_rank = config.colour_rank
    mst_rank = config.mst_rank

    #Merges Mappings according to pipeline
    mapping_dict = dict()

    #matches senses to clusters
    print('[a]', 'Disambiguating results.\t('+topic_name+')')

    if colour_rank != 0:
        print('[a]', 'Colouring graph.\t('+topic_name+')')
        mapping_dict[colour_rank] = disambiguate_colour(graph, root_hub_list,
                                                        result_dict[topic_id])

    if mst_rank != 0:
        print('[a]', 'Building minimum spanning tree.\t('+topic_name+')')
        mapping_dict[mst_rank] = disambiguate_mst(graph,
                                                  root_hub_list,
                                                  result_dict[topic_id],
                                                  topic_name)

    mapping_list = [item[1] for item in sorted(mapping_dict.items())]
    mapping_count = len(mapping_list)

    merged_mapping_dict = mapping_list[0]
    merged_entry_count = 0

    for i in range(1, mapping_count):

        result_list = [result for result_list in merged_mapping_dict.values()
                       for result in result_list]

        #individual mappings
        relation_list = [(topic, result) for topic in mapping_list[i].keys()
                         for result in mapping_list[i][topic]]

        for topic, result in relation_list:

            if result not in result_list:

                merged_entry_count += 1

                if topic in merged_mapping_dict:
                    merged_mapping_dict[topic].append(result)
                else:
                    merged_mapping_dict[topic] = [result]

    stat_dict['merge_gain'] = merged_entry_count

    #collect statistics from result.
    cluster_count = 0
    cluster_length_list = list()

    for cluster, result_list in merged_mapping_dict.items():

        cluster_length = len(result_list)

        if cluster_length != 0:
            cluster_count += 1
            cluster_length_list.append(cluster_length)

    stat_dict['mean_cluster_length'] = np.mean(cluster_length_list)
    stat_dict['hmean_cluster_length'] = stats.hmean(cluster_length_list)
    stat_dict['cluster_count'] = cluster_count

    print('[a]', 'Writing to file.\t('+topic_name+')')

    output_path = config.output
    output_file_name = output_path+topic_name+'.absinth'

    with open(output_file_name, 'w') as output_file:

        output_file.write('subTopicID\tresultID\n')

        for cluster_id, result_list in merged_mapping_dict.items():
            for result_id in result_list:
                output_line = '{}.{}\t{}.{}\n'.format(topic_id, cluster_id,
                                                      topic_id, result_id)
                output_file.write(output_line)

    print_stats(stat_dict)
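The merge step gives the lowest-ranked mapping priority: mappings with higher ranks can only contribute result IDs that no cluster has claimed yet, and `merge_gain` counts those additions. A self-contained sketch of the same logic with two invented mappings:

```python
# Two hypothetical sense-to-results mappings, ordered by rank (1 = primary).
primary = {0: [1, 2], 1: [3]}     # e.g. the colour mapping (colour_rank == 1)
secondary = {0: [2, 4], 2: [5]}   # e.g. the MST mapping (mst_rank == 2)

merged = {topic: list(results) for topic, results in primary.items()}
merge_gain = 0

claimed = {r for results in merged.values() for r in results}
for topic, results in secondary.items():
    for result in results:
        if result not in claimed:  # only unclaimed results are added
            merge_gain += 1
            merged.setdefault(topic, []).append(result)

print(merged)      # {0: [1, 2, 4], 1: [3], 2: [5]}
print(merge_gain)  # 2
```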
if __name__ == '__main__':
@@ -921,11 +954,17 @@ if __name__ == '__main__':

    # Enables manual setting of process count.
    if '-p' in sys.argv:

        process_count = int(sys.argv[sys.argv.index('-p') + 1])

        with Pool(process_count) as pool:
            parameter_list = [(topic_id, topic_name, result_dict)
                              for topic_id, topic_name in topic_dict.items()]
            pool.starmap(main, sorted(parameter_list)) # deterministic order

    else:

        process_count = 1

        for topic_id, topic_name in sorted(topic_dict.items()):
            main(topic_id, topic_name, result_dict)
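`Pool.starmap` unpacks each parameter tuple into the positional arguments of `main`, and sorting the parameter list keeps topics in a stable order across runs. A minimal standalone sketch of that call pattern (function and data invented):

```python
from multiprocessing import Pool

def work(topic_id, topic_name):
    return '{}: {}'.format(topic_id, topic_name)

if __name__ == '__main__':
    params = [(2, 'java'), (1, 'python')]
    with Pool(2) as pool:
        # starmap(f, [(a, b), ...]) calls f(a, b) for each tuple.
        print(pool.starmap(work, sorted(params)))
        # ['1: python', '2: java']
```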
@@ -20,7 +20,7 @@

Methods labeled with 0 are ignored.
At least one method must be given a value != 0.
'''
resolve_conflicts = False # not yet implemented

mst_rank = 2
colour_rank = 1

'''
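With `colour_rank = 1` and `mst_rank = 2`, the colour-based mapping becomes the primary source and the MST mapping only fills in unclaimed results during the merge in `main`; the ranks are the keys of `mapping_dict` and are processed in sorted order. A tiny sketch of that ordering (placeholder values):

```python
# Sketch: how the rank settings order the pipeline in main().
mapping_dict = {2: 'mst mapping', 1: 'colour mapping'}  # rank -> mapping
print([item[1] for item in sorted(mapping_dict.items())])
# ['colour mapping', 'mst mapping'] -> colour is primary, MST fills gaps
```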