Commit feb5b438 authored by axtimhaus's avatar axtimhaus
Browse files

Add extraction from conll files

parent 1764b5be
Loading
Loading
Loading
Loading
+25 −13
Original line number Diff line number Diff line
@@ -20,14 +20,23 @@ print("Done. ")


def next_mention(mentions, anaphora):
    """Return the mention closest before the anaphora (recency baseline).

    Each mention that ends before the anaphora starts is scored by the
    inverse token distance 1/(anaphora_start - mention_end); mentions at or
    after the anaphora score 0.  The highest-scoring (i.e. nearest
    preceding) mention wins; ties go to the first by np.argmax.
    """
    sims = [1/(anaphora.span[0] - mention.span[1]) if mention.span[1] < anaphora.span[0] else 0 for mention in mentions]
    return mentions[np.argmax(sims)]

def head_match(mentions, anaphora):
    """Return the preceding mention with the greatest token overlap.

    For every mention ending before the anaphora starts, the overlap is
    scored with the Dice coefficient 2*|shared tokens| / (|mention| +
    |anaphora|); non-preceding mentions score 0.  The best-scoring mention
    is returned (first one on ties, via np.argmax).
    """
    match_list = list()
    for mention in mentions:
        if mention.span[1] < anaphora.span[0]:
            # tokens the candidate mention shares with the anaphora
            overlap = len([m for m in mention.tokens if m in anaphora.tokens])
            # Dice coefficient over the two token bags
            overlap_ratio = 2*overlap/(len(mention.tokens)+len(anaphora.tokens))
            match_list.append(overlap_ratio)
        else:
            # candidate does not precede the anaphora: effectively excluded
            match_list.append(0)
    return mentions[np.argmax(match_list)]
        
def best_embedding(mentions, anaphora):
    """Return the preceding mention most similar to the anaphora under
    Word Mover's Distance.

    Similarity is 1/(wmdistance + 1); the +1 guards against division by
    zero when the token bags are identical (distance 0).  Mentions at or
    after the anaphora score 0.

    NOTE(review): relies on a module-level `model` (gensim word vectors)
    not visible in this block — confirm it is loaded before calling.
    """
    sims = [1/(model.wv.wmdistance(mention.tokens, anaphora.tokens)+1) if mention.span[1] < anaphora.span[0] else 0 for mention in mentions]
    return mentions[np.argmax(sims)]

def most_salient_entity(entities, anaphora):
@@ -39,25 +48,24 @@ def largest_entity_span(entities, anaphora):
    return entities[np.argmax(spans)]
    
if __name__ == "__main__":
    false = {"next_mention": 0, "best_embedding": 0, "mse": 0, "les": 0}
    true = {"next_mention": 0, "best_embedding": 0, "mse": 0, "les": 0}
    false = {"next_mention": 0, "best_embedding": 0, "head":0, "mse": 0, "les": 0}
    true = {"next_mention": 0, "best_embedding": 0, "head":0, "mse": 0, "les": 0}
    
    root = config.DATA_PATH
    files = os.listdir(root)
    
    for f in files[:10]:
    for f in files[:]:
        
        print("Fetching test data...")
        mentions = extraction.get_mentions(root+f)
        
        comparatives = [m for m in mentions if m.comp_from]
        print(comparatives)
        
        entities = []
        entity_dict = {}
        for mention in mentions:
            if not mention.coref_set:
                entities.append(mention)
                entities.append([mention])
            else:
                if mention.coref_set in entity_dict:
                    entity_dict[mention.coref_set].append(mention)
@@ -79,24 +87,28 @@ if __name__ == "__main__":
                true["best_embedding"] += 1
            else:
                false["best_embedding"] += 1
            if head_match(mentions, comp).mention_id == comp.comp_from:
                true["head"] += 1
            else:
                false["head"] += 1
            
            mse_result = most_salient_entity(entities, comp)
            les_result = largest_entity_span(entities, comp)
            
            if closest(mse_result).mention_id == comp.comp_from:
            if closest(mse_result, comp).mention_id == comp.comp_from:
                true["mse"] += 1
            else:
                false["mse"] += 1
            if closest(les_result).mention_id == comp.comp_from:
            if closest(les_result, comp).mention_id == comp.comp_from:
                true["les"] += 1
            else:
                false["les"] += 1
        
        recall = {"next_mention": true["next_mention"]/(true["next_mention"]+false["next_mention"]), "best_embedding": true["best_embedding"]/(true["best_embedding"]+false["best_embedding"]), "mse": true["mse"]/(true["mse"]+false["mse"]), "les": true["les"]/(true["les"]+false["les"])}
        recall = {"next_mention": true["next_mention"]/(true["next_mention"]+false["next_mention"]), "best_embedding": true["best_embedding"]/(true["best_embedding"]+false["best_embedding"]), "head": true["head"]/(true["head"]+false["head"]), "mse": true["mse"]/(true["mse"]+false["mse"]), "les": true["les"]/(true["les"]+false["les"])}
        
        print("Recall after {}:".format(f))
        print("next_mention:\t{}\t(True: {}, False:{})".format(recall["next_mention"], true["next_mention"], false["next_mention"]))
        print("best_embedding:\t{}\t(True: {}, False:{})".format(recall["best_embedding"], true["best_embedding"], false["best_embedding"]))
        
        print("head_match:\t{}\t(True: {}, False:{})".format(recall["head"], true["head"], false["head"]))
        print("most_salient_entity:\t{}\t(True: {}, False:{})".format(recall["mse"], true["mse"], false["mse"]))
        print("largest_entity_span:\t{}\t(True: {}, False:{})".format(recall["les"], true["les"], false["les"]))
+3 −0
Original line number Diff line number Diff line
# Pre-trained GloVe Twitter word embeddings, 100-dimensional.
EMBEDDINGS_PATH = "/softpro/ss18/kernseife/kernseife/data/en/embeddings/glove.twitter.27B.100d.txt"
# Root directory of the ISNotes corpus files.
DATA_PATH = "/proj/zimmermann/isnotes/ISClean/"
# Root of the CoNLL-2012 / OntoNotes 5.0 English training data (v4 gold files).
CONLL_PATH = "/resources/corpora/multilingual/ontonotes-5.0-conll-2012/conll-2012/v4/data/train/data/english/"

# Lexical cues treated as markers of a comparative anaphor when they start a mention.
COMPARATIVES = ["other", "similar", "comparable", "different", "additional", "extra"]
+129 −10
Original line number Diff line number Diff line
@@ -2,10 +2,12 @@ from bs4 import BeautifulSoup
import os
import config
import re
import numpy as np

from model import Mention
from model import Mention, Entity

def get_mentions(path):

def from_isnotes(path):
    
    split_path = path.split(r"/")
    root, file_name = "/".join(split_path[:-1]), split_path[-1]
@@ -33,7 +35,6 @@ def get_mentions(path):
            if entity["information_status"] == "mediated":
                try:
                    if entity["mediated_type"] == "comparative":
                        print("Hi!")
                        comp_from = entity["comparative_type"]
                        if comp_from == "withintext":
                            comp_from = entity["comp_from"]
@@ -64,12 +65,130 @@ def get_mentions(path):
    
    return mentions

def get_entities():
    pass        
def from_conll(path):
    """Extract noun-phrase mentions from a CoNLL-2012 gold-annotation file.

    Walks the file token by token, tracking open syntactic constituents and
    coreference spans from the parse-tree and coref columns.  Every closing
    NP-like constituent (NP/NML/NX/NAC) becomes a Mention; simple lexical
    heuristics flag comparative-anaphor candidates.

    Returns:
        (ment_count, comp_count, mentions): number of NP-like mentions,
        number of comparative candidates among them, and the Mention list.

    NOTE(review): returns at the first "#end document" line, so only the
    first document of a multi-document file is processed — confirm intent.
    """
    mentions = list()

    text = list()   # running token sequence over the document
    poss = list()   # POS tag per token, parallel to `text`

    word_id = 0     # global token index
    mention_id = 0  # NOTE(review): incremented once per closing paren, not
                    # once per mention, so ids are not consecutive — confirm

    comp_count = 0  # comparative-anaphor candidates found
    ment_count = 0  # NP-like constituents found

    with open(path, 'r') as conll:
        for line in conll:
            if line.startswith("#begin document"):
                stack = list()        # open constituents: (start word_id, label, first token, first POS)
                coref_stack = dict()  # coref chain id -> word_id where its span opened
            elif line.startswith("#end document"):
                if comp_count > 0:
                    for m in mentions:
                        print(m)
                return (ment_count, comp_count, mentions)
            elif line == "\n":
                # sentence boundary: constituents never span sentences
                stack = list()
            else:
                split_line = line.strip().split()
                # Important indices:
                # 1:doc_id, 2:token_id,     3:token,
                # 4:POS,    5:syntax_tree, -1:Coref

                # record where each coref span opens on this token, e.g. "(42"
                coref_stack = {**coref_stack, **{int(c): word_id for c in re.findall(r"\(([0-9]+)", split_line[-1])}}

                # resolve coref spans that close on this token, e.g. "42)"
                coref_spans = dict()
                for c in re.findall(r"([0-9]+)\)", split_line[-1]):
                    if int(c) not in coref_stack:
                        raise Exception
                    coref_spans[(coref_stack[int(c)], word_id)] = int(c)

                syntax = split_line[5].split("*")

                # constituents opening at this token, e.g. "(NP"
                opening = [str(f) for f in re.findall(r"\(([A-Z]+)", syntax[0])]
                for paren in opening:
                    stack.append((word_id, paren, split_line[3], split_line[4]))

                text.append(split_line[3])
                poss.append(split_line[4])

                # constituents closing at this token (only the count matters)
                closing = [str(f) for f in re.finditer(r"\)", syntax[1])]

                for paren in closing:
                    start = stack[-1]
                    if start[1] == "NP" or start[1] == "NML" or start[1] == "NX" or start[1] == "NAC":
                        ment_count += 1

                        # heuristics for a comparative anaphor at the phrase start
                        rules = list()
                        rules.append(start[3] == "JJR")          # comparative adjective, e.g. "larger"
                        rules.append(text[start[0]].lower() == "more"
                                     and len(poss) > start[0]+1
                                     and poss[start[0]+1] == "JJ")   # "more" + adjective
                        rules.append(text[start[0]].lower() in config.COMPARATIVES)
                        rules.append(text[start[0]].lower() == "others")

                        # explicit "... than ..." comparisons are not anaphoric
                        if any(rules) and not "than" in text[start[0]:word_id+1]:
                            comp_count += 1
                            comp_from = "undefined"
                        else:
                            comp_from = None

                        # coref chain id if this exact span closed a coref span
                        coref_set = coref_spans.get((start[0], word_id))

                        mentions.append(Mention(text[start[0]:word_id+1], [start[0], word_id], mention_id, coref_set, comp_from))

                    stack = stack[:-1]
                    mention_id += 1

                word_id += 1

    return (ment_count, comp_count, mentions)
                
    
if __name__ == "__main__":
    # ISnotes extraction (kept disabled for reference):
    #root = config.DATA_PATH
    #files = os.listdir(root)
    #for f in files:
        #from_isnotes(root+f)

    # conll/OntoNotes: collect every gold-annotation file under the corpus root.
    root = config.CONLL_PATH

    files = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            print(name)
            if name.endswith(".v4_gold_conll"):
                files.append(dirpath + "/" + name)

    mcount = 0  # total mentions over files that contain comparatives
    ccount = 0  # total comparative-anaphor candidates
    # dist[k] counts files with exactly k comparatives
    # (assumes k < 26 per file — TODO confirm; would IndexError otherwise)
    dist = np.zeros(26)
    for f in files:
        plus_m, plus_c, mentions = from_conll(f)
        if plus_c > 0:
            dist[plus_c] += 1
            mcount += plus_m
            ccount += plus_c
    print(mcount, ccount)
    print(dist[1:])
    
+0 −0

File moved.

+0 −0

File moved.

Loading