Commit 755bc6f9 authored by Victor Zimmermann

Add visualisation and corpus reform.

parent c5af1a97
@@ -9,28 +9,21 @@ import config
import spacy # for nlp
from multiprocessing import Pool
import random
import matplotlib.pyplot as plt
nlp = spacy.load('en') # standard english nlp
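Side note, not part of this commit: pyplot is imported here for the new draw_graph below; on a machine without a display, some matplotlib setups need a non-interactive backend selected before pyplot is imported. A minimal sketch, assuming the Agg backend is acceptable:

import matplotlib
matplotlib.use('Agg')            #select a non-interactive backend before pyplot (assumption: headless run)
import matplotlib.pyplot as plt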
#counts occurrences of nodes and co-occurrences
def frequencies(corpus_path, target):
def frequencies(corpus_path, target, results):
random.seed(1)
stop_words = set(stopwords.words('english') + config.stop_words)
allowed_tags = config.allowed_tags
min_context_size = config.min_context_size
max_nodes = config.max_nodes
max_edges = config.max_edges
node_freq = dict() #counts (potential) nodes
edge_freq = dict() #counts (potential) edges
results = [r.replace('<b>', '').replace('</b>', '').replace(r'\\', '').strip() for r in results]
node_freq, edge_freq = process_file(results, target) #initialises frequencies with counts from results
s_target = target.replace('_', ' ') #target word with spaces
files = [corpus_path + f for f in os.listdir(corpus_path)] #file names of corpus files
random.shuffle(files)
i = 0 #for update print statements
for f in files:
@@ -49,67 +42,17 @@ def frequencies(corpus_path, target):
#checks maximum node values
if len(node_freq) > max_nodes:
print('[a] 100%\tNodes: {}\tEdges: {}.'.format(len(node_freq), len(edge_freq))+'\t('+target+')')
return node_freq, edge_freq
#checks maximum edge values
if len(edge_freq) > max_edges:
print('[a] 100%\tNodes: {}\tEdges: {}.'.format(len(node_freq), len(edge_freq))+'\t('+target+')')
return node_freq, edge_freq
with open(f, 'r') as lines: #parses single file
try:
for line in lines: #parses single paragraph
line = line.lower()
if s_target in line: #greedy pre-selection, not perfect
tokens = set() #set of node candidates
doc = nlp(line.replace(s_target, target)) #nlp processing
if target in [t.text for t in doc]: #better selection
for tok in doc:
text = tok.text #string value
tag = tok.tag_ #pos tag
#doesn't add target word to nodes
if text == target:
pass
#doesn't add stop words to nodes
elif text in stop_words:
pass
#only adds tokens with allowed tags to nodes
elif tag in allowed_tags:
tokens.add(tok.text)
#if there are enough (good) tokens in paragraph
if len(tokens) >= min_context_size:
for token in tokens:
#updates counts for nodes
if token in node_freq:
node_freq[token] += 1
else:
node_freq[token] = 1
for edge in {(x,y) for x in tokens for y in tokens if x < y}:
#updates counts for edges
if edge in edge_freq:
edge_freq[edge] += 1
else:
edge_freq[edge] = 1
#if a file is corrupted (can't always be caught with if-else)
except UnicodeDecodeError:
pass
#print('Failed to decode:', f)
node_freq, edge_freq = process_file(lines, target, node_freq, edge_freq)
i += 1
@@ -118,6 +61,74 @@ def frequencies(corpus_path, target):
return node_freq, edge_freq
def process_file(lines, target, node_freq=None, edge_freq=None):
    if node_freq is None:
        node_freq = dict()
    if edge_freq is None:
        edge_freq = dict()
    s_target = target.replace('_', ' ') #target word with spaces
    stop_words = set(stopwords.words('english') + config.stop_words)
    allowed_tags = config.allowed_tags
    min_context_size = config.min_context_size
    try:
        for line in lines: #parses single paragraph
            line = line.lower()
            if s_target in line: #greedy pre-selection, not perfect
                tokens = set() #set of node candidates
                doc = nlp(line.replace(s_target, target)) #nlp processing
                if target in [t.text for t in doc]: #better selection
                    for tok in doc:
                        text = tok.text #string value
                        tag = tok.tag_ #pos tag
                        #doesn't add target word to nodes
                        if text == target:
                            pass
                        #doesn't add stop words to nodes
                        elif text in stop_words:
                            pass
                        #only adds tokens with allowed tags to nodes
                        elif tag in allowed_tags:
                            tokens.add(tok.text)
                    #if there are enough (good) tokens in the paragraph
                    if len(tokens) >= min_context_size:
                        for token in tokens:
                            #updates counts for nodes
                            if token in node_freq:
                                node_freq[token] += 1
                            else:
                                node_freq[token] = 1
                        for edge in {(x,y) for x in tokens for y in tokens if x < y}:
                            #updates counts for edges
                            if edge in edge_freq:
                                edge_freq[edge] += 1
                            else:
                                edge_freq[edge] = 1
    #if a file is corrupted (can't always be caught with if-else)
    except UnicodeDecodeError:
        pass
        #print('Failed to decode:', f)
    return node_freq, edge_freq
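Aside, not part of this commit: the if/else count updates in process_file could also be expressed with collections.Counter and itertools.combinations. A minimal sketch of that alternative; count_paragraph and the example token set are made up for illustration:

from collections import Counter
from itertools import combinations

def count_paragraph(tokens, node_freq, edge_freq):
    #node_freq / edge_freq are Counters: missing keys default to 0, so no if/else is needed
    node_freq.update(tokens)                            #one count per candidate token
    edge_freq.update(combinations(sorted(tokens), 2))   #same pairs as {(x, y) ... if x < y}
    return node_freq, edge_freq

node_freq, edge_freq = Counter(), Counter()
count_paragraph({'bank', 'river', 'water'}, node_freq, edge_freq)   #illustrative token set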
#build graph from frequency dictionaries
def build_graph(node_freq, edge_freq):
@@ -235,7 +246,7 @@ def score(graph, from_node, to_node):
# Basically Word Sense Disambiguation, matches context to sense
def disambiguate(mst, hubs, contexts, target=""):
def disambiguate(mst, hubs, contexts, target):
target = target.replace('_', ' ')
T = mst #minimum spanning tree
@@ -250,9 +261,11 @@ def disambiguate(mst, hubs, contexts, target=""):
return {0:[i for i in range(1, len(C)+1)]}
idx = 0
for c in C:
idx = C.index(c) + 1 #index based on position in list
idx += 1 #index based on position in list
doc = nlp(c) #parsed context
texts = [tok.text for tok in doc] #tokens
@@ -283,7 +296,7 @@ def disambiguate(mst, hubs, contexts, target=""):
pass
#if the disambiguator could not detect a sense, it should return a singleton, i.e. nothing
if np.max(scores) == 0:
pass
@@ -299,6 +312,10 @@ def disambiguate(mst, hubs, contexts, target=""):
return mapping_dict
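Aside on the idx change above, not part of this commit: C.index(c) returned the position of the first matching context, so identical contexts all received the same index, whereas the running counter gives every context its own 1-based index. enumerate expresses the same thing without a separate variable; a minimal, self-contained sketch with made-up contexts:

contexts = ['the bank of the river', 'the bank approved the loan']   #illustrative contexts
for idx, c in enumerate(contexts, start=1):   #1-based index, equivalent to the idx counter above
    print(idx, c)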
#draws the graph with a spring layout and saves it to the figures directory
def draw_graph(G, name):
    nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=True, node_size=40, font_size=9, node_color='#2D98DA')
    plt.savefig('../figures/'+name+'.png', dpi=200, bbox_inches='tight')
    plt.clf()
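Aside, not part of this commit: nx.spring_layout is randomised, so the saved figures can differ between runs. If reproducible figures are wanted and the installed networkx version accepts a seed argument, a variant like the following sketch would pin the layout (function name, seed value and output directory are illustrative; it assumes the same nx/plt imports as this module):

def draw_graph_seeded(G, name, out_dir='../figures/', seed=1):
    #same as draw_graph above, but with a fixed layout seed for reproducible figures
    pos = nx.spring_layout(G, seed=seed)
    nx.draw_networkx(G, pos=pos, with_labels=True, node_size=40, font_size=9, node_color='#2D98DA')
    plt.savefig(out_dir + name + '.png', dpi=200, bbox_inches='tight')
    plt.clf()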
# our main function; the main steps for word sense induction are called here
def WSI(topic_id, topic_name, results):
@@ -333,11 +350,12 @@ def WSI(topic_id, topic_name, results):
#counts occurrences of single words, as well as co-occurrences, and saves them in dictionaries
print('[a]', 'Counting nodes and edges.\t('+old_target+')')
node_freq, edge_freq = frequencies(corpus_path, target)
node_freq, edge_freq = frequencies(corpus_path, target, results[topic_id])
#builds graph from these dictionaries, also applies multiple filters
print('[a]', 'Building graph.\t('+old_target+')')
G = build_graph(node_freq, edge_freq)
draw_graph(G, topic_name.strip()+'_g')
out_buffer += '[A] Nodes: {}\tEdges: {}\n'.format(str(len(G.nodes)), str(len(G.edges)))
#finds root hubs (senses) within the graph + more filters for these
@@ -356,6 +374,7 @@ def WSI(topic_id, topic_name, results):
#performs minimum_spanning_tree algorithm on graph
print('[a]', 'Building minimum spanning tree.\t('+old_target+')')
T = components(G, H, target)
draw_graph(T, topic_name.strip()+'_t')
#matches senses to clusters
print('[a]', 'Disambiguating results.\t('+old_target+')')
@@ -412,7 +431,7 @@ if __name__ == '__main__':
topics[l[0]] = l[1]
# multiprocessing
with Pool(4) as pool:
with Pool(5) as pool:
# calls WSI() for several topics at a time
pool.starmap(WSI, [(key, value, results) for key,value in topics.items()])
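For reference, not part of this commit: starmap unpacks each (topic_id, topic_name, results) tuple into the positional arguments of WSI, and the Pool size caps how many topics run concurrently. A small, self-contained sketch of the same pattern with a placeholder worker:

from multiprocessing import Pool

def worker(topic_id, topic_name, results):
    #stands in for WSI(); just echoes its arguments
    return topic_id, topic_name, len(results)

if __name__ == '__main__':
    topics = {'1': 'cat', '2': 'java', '3': 'bank'}   #illustrative topic ids and names
    results = {}                                      #stands in for the shared results dict
    with Pool(5) as pool:
        #each tuple becomes worker(topic_id, topic_name, results); up to 5 run in parallel
        print(pool.starmap(worker, [(k, v, results) for k, v in topics.items()]))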