From 755bc6f902be59e27352786ee6038b3893384217 Mon Sep 17 00:00:00 2001
From: Victor Zimmermann <zimmermann@cl.uni-heidelberg.de>
Date: Wed, 14 Mar 2018 19:26:34 +0100
Subject: [PATCH] Add graph visualisation and rework corpus processing.

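Move the per-paragraph counting loop out of frequencies() into a reusable
process_file(), seed the node and edge counts with the cleaned search-result
snippets, and add draw_graph() to save spring-layout renderings of the
co-occurrence graph and its minimum spanning tree.
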
---
 src/absinth.py | 163 +++++++++++++++++++++++++++----------------------
 1 file changed, 93 insertions(+), 70 deletions(-)

diff --git a/src/absinth.py b/src/absinth.py
index 9715e07..4c1b106 100644
--- a/src/absinth.py
+++ b/src/absinth.py
@@ -9,28 +9,23 @@ import config
 import spacy # for nlp
 from multiprocessing import Pool
 import random
+import matplotlib
+matplotlib.use('Agg') #non-interactive backend; avoids display issues in multiprocessing workers
+import matplotlib.pyplot as plt
 
 nlp = spacy.load('en') # standard english nlp
 
 #counts occurrences of nodes and co-occurrences
-def frequencies(corpus_path, target):
+def frequencies(corpus_path, target, results):
     
-    random.seed(1)
-    
-    stop_words = set(stopwords.words('english') + config.stop_words)
-    allowed_tags = config.allowed_tags
-    min_context_size = config.min_context_size
     max_nodes = config.max_nodes
     max_edges = config.max_edges
     
-    node_freq = dict() #counts (potential) nodes
-    edge_freq = dict() #counts (potential) edges
+    results = [r.replace('<b>', '').replace('</b>', '').replace(r'\\', '').strip() for r in results] #strips html tags and escape sequences from search-result snippets
+    node_freq, edge_freq = process_file(results, target) #initialises frequencies with counts from results
     
-    s_target = target.replace('_', ' ') #target word with spaces
     files = [corpus_path + f for f in os.listdir(corpus_path)] #file names of corpus files
     
-    random.shuffle(files)
-    
     i = 0 #counter for progress print statements
     for f in files:
         
@@ -49,67 +42,17 @@ def frequencies(corpus_path, target):
         
         #checks maximum node values
         if len(node_freq) > max_nodes:
+            print('[a] 100%\tNodes: {}\tEdges: {}.\t({})'.format(len(node_freq), len(edge_freq), target))
             return node_freq, edge_freq
         
         #checks maximum edge values
         if len(edge_freq) > max_edges:
+            print('[a] 100%\tNodes: {}\tEdges: {}.\t({})'.format(len(node_freq), len(edge_freq), target))
             return node_freq, edge_freq
         
         with open(f, 'r') as lines: #parses single file
             
-            try:
-                
-                for line in lines: #parses single paragraph
-                    
-                    line = line.lower()
-                    
-                    if s_target in line: #greedy pre selection, not perfect
-                        
-                        tokens = set() #set of node candidates
-                        doc = nlp(line.replace(s_target, target)) #nlp processing
-                        
-                        if target in [t.text for t in doc]: #better selection
-                            
-                            for tok in doc:
-                                
-                                text = tok.text #string value
-                                tag = tok.tag_ #pos tag
-                                
-                                #doesn't add target word to nodes
-                                if text == target:
-                                    pass
-                                
-                                #doesn't add stop words to nodes
-                                elif text in stop_words:
-                                    pass
-                                
-                                #only adds tokens with allowed tags to nodes
-                                elif tag in allowed_tags:
-                                    tokens.add(tok.text)
-                                    
-                            #if there are enough (good) tokens in paragraph
-                            if len(tokens) >= min_context_size:
-                                for token in tokens:
-                                    
-                                    #updates counts for nodes
-                                    if token in node_freq:
-                                        node_freq[token] += 1
-                                    else:
-                                        node_freq[token] = 1
-                                
-                                for edge in {(x,y) for x in tokens for y in tokens if x < y}:
-                                    
-                                    #updates counts for edges
-                                    if edge in edge_freq:
-                                        edge_freq[edge] += 1
-                                    else:
-                                        edge_freq[edge] = 1
-            
-            #if a file is corrupted (can't always be catched with if-else)
-            except UnicodeDecodeError:
-                
-                pass
-                #print('Failed to decode:', f)              
+            node_freq, edge_freq = process_file(lines, target, node_freq, edge_freq)
         
         i += 1
     
@@ -118,6 +61,75 @@ def frequencies(corpus_path, target):
     
     return node_freq, edge_freq
 
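+#counts nodes and edges in an iterable of lines (corpus file or search-result snippets)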
+def process_file(lines, target, node_freq=None, edge_freq=None):
+
+    if node_freq is None:
+        node_freq = dict()
+    if edge_freq is None:
+        edge_freq = dict()
+    
+    s_target = target.replace('_', ' ') #target word with spaces
+    
+    stop_words = set(stopwords.words('english') + config.stop_words)
+    allowed_tags = config.allowed_tags
+    min_context_size = config.min_context_size
+        
+    try:
+        
+        for line in lines: #parses single paragraph
+            
+            line = line.lower()
+            
+            if s_target in line: #greedy pre-selection, not perfect
+                
+                tokens = set() #set of node candidates
+                doc = nlp(line.replace(s_target, target)) #nlp processing
+                
+                if target in [t.text for t in doc]: #better selection
+                    
+                    for tok in doc:
+                        
+                        text = tok.text #string value
+                        tag = tok.tag_ #pos tag
+                        
+                        #doesn't add target word to nodes
+                        if text == target:
+                            pass
+                        
+                        #doesn't add stop words to nodes
+                        elif text in stop_words:
+                            pass
+                        
+                        #only adds tokens with allowed tags to nodes
+                        elif tag in allowed_tags:
+                            tokens.add(tok.text)
+                            
+                    #if there are enough valid tokens in the paragraph
+                    if len(tokens) >= min_context_size:
+                        for token in tokens:
+                            
+                            #updates counts for nodes
+                            if token in node_freq:
+                                node_freq[token] += 1
+                            else:
+                                node_freq[token] = 1
+                        
+                        for edge in {(x,y) for x in tokens for y in tokens if x < y}:
+                            
+                            #updates counts for edges
+                            if edge in edge_freq:
+                                edge_freq[edge] += 1
+                            else:
+                                edge_freq[edge] = 1
+    
+    #if a file is corrupted (can't always be caught with if-else)
+    except UnicodeDecodeError:
+        
+        pass
+        #print('Failed to decode:', f)              
+    
+    return node_freq, edge_freq
 
 #build graph from frequency dictionaries
 def build_graph(node_freq, edge_freq):
@@ -235,7 +246,7 @@ def score(graph, from_node, to_node):
 
 
 # Basically Word Sense Disambiguation, matches context to sense
-def disambiguate(mst, hubs, contexts, target=""):
+def disambiguate(mst, hubs, contexts, target):
     
     target = target.replace('_', ' ')
     T = mst #minimum spanning tree
@@ -250,9 +261,11 @@ def disambiguate(mst, hubs, contexts, target=""):
             
         return {0:[i for i in range(1, len(C)+1)]}
     
+    idx = 0
+    
     for c in C:
         
-        idx = C.index(c) + 1 #index based on position in list
+        idx += 1 #index based on position in list (safe with duplicate contexts)
     
         doc = nlp(c) #parsed context
         texts = [tok.text for tok in doc] #tokens
@@ -283,7 +296,7 @@ def disambiguate(mst, hubs, contexts, target=""):
             
                 pass
         
-            #if the disambiguator could not detect a sense, it should return a singleton, ie. nothing
+        #if the disambiguator could not detect a sense, it should return a singleton, i.e. nothing
         if np.max(scores) == 0:
             
             pass
@@ -299,6 +312,11 @@
 
     return mapping_dict
 
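+#draws a graph with a spring layout and saves it as a PNG under ../figures/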
+def draw_graph(G, name):
+    nx.draw_networkx(G, pos=nx.spring_layout(G), with_labels=True, node_size=40, font_size=9, node_color='#2D98DA')
+    plt.savefig('../figures/'+name+'.png', dpi=200, bbox_inches='tight')
+    plt.clf()
 
 # our main function; here the main steps for word sense induction are called
 def WSI(topic_id, topic_name, results):
@@ -333,11 +350,12 @@ def WSI(topic_id, topic_name, results):
     
     #counts occurrences of single words as well as co-occurrences, and saves them in dictionaries
     print('[a]', 'Counting nodes and edges.\t('+old_target+')')
-    node_freq, edge_freq = frequencies(corpus_path, target)
+    node_freq, edge_freq = frequencies(corpus_path, target, results[topic_id])
     
     #builds graph from these dictionaries, also applies multiple filters
     print('[a]', 'Building graph.\t('+old_target+')')
     G = build_graph(node_freq, edge_freq)
+    draw_graph(G, topic_name.strip()+'_g')
     out_buffer += '[A] Nodes: {}\tEdges: {}\n'.format(str(len(G.nodes)), str(len(G.edges)))
     
     #finds root hubs (senses) within the graph + more filters for these
@@ -356,6 +374,7 @@ def WSI(topic_id, topic_name, results):
     #performs minimum_spanning_tree algorithm on graph
     print('[a]', 'Building minimum spanning tree.\t('+old_target+')')
     T = components(G, H, target)
+    draw_graph(T, topic_name.strip()+'_t')
 
     #matches senses to clusters
     print('[a]', 'Disambiguating results.\t('+old_target+')')
@@ -412,7 +431,7 @@ if __name__ == '__main__':
             topics[l[0]] = l[1]
     
     # multiprocessing
-    with Pool(4) as pool:
+    with Pool(5) as pool:
-        # calls WSI() for for topics at a time
+        # calls WSI() on five topics at a time
         pool.starmap(WSI, [(key, value, results) for key,value in topics.items()])
         
-- 
GitLab