Commit fcf51e56 authored by axtimhaus's avatar axtimhaus
Browse files

Updates baselines.

parent 8c25f2d7
Loading
Loading
Loading
Loading
+117 −45
Original line number Diff line number Diff line
@@ -9,15 +9,25 @@ import numpy as np
from gensim.models import word2vec
from gensim.models import KeyedVectors as kv
    

def load_nlp():
    """Load and return the small English spaCy pipeline.

    Prints a transient progress message (carriage-return overwrite)
    while the model loads.
    """
    print("Loading nlp model.", end="\r")
    pipeline = en_core_web_sm.load()
    print("Done.             ")
    return pipeline

def load_embeddings():
    """Load word embeddings from ``config.EMBEDDINGS_PATH``.

    Tries the native gensim ``KeyedVectors`` format first and falls
    back to the plain word2vec text format.

    Returns:
        The loaded gensim ``KeyedVectors`` model.
    """
    fname = config.EMBEDDINGS_PATH
    print("Loading embeddings.", end="\r")
    try:
        model = kv.load(fname)
    except Exception:
        # Deliberate best-effort fallback: a bare `except:` would also
        # swallow KeyboardInterrupt/SystemExit, so narrow it to Exception.
        model = kv.load_word2vec_format(fname, binary=False)
    print("Done.              ")
    return model

#model = load_embeddings()
#nlp = load_nlp()

def closest_preceding_mention(toks, ents, ana):
    best_ment = (0,0)
@@ -25,7 +35,7 @@ def closest_preceding_mention(toks, ents, ana):
        for ment in ent:
            if ana[0] - ment[1] <  ana[0] - best_ment[1] and ment[1] < ana[0]:
                best_ment = ment
    return {best_ment}
    return best_ment

def head_match(toks, ents, ana, window=20):
    best_ment = (0,0)
@@ -37,8 +47,8 @@ def head_match(toks, ents, ana, window=20):
    for name, proc in nlp.pipeline:
        doc=proc(doc)
        
    ana_head = [token for token in doc if token.head == token][0].lemma_
    
    ana_head = [token for token in doc if token.head == token][0]
    print(ana_head, ana_toks)
    for ent in ents:
        for ment in ent:
            ment_toks = toks[ment[0]:ment[1]+1]
@@ -49,18 +59,18 @@ def head_match(toks, ents, ana, window=20):
                doc=proc(doc)
                
            try:
                ment_head = [token for token in doc if token.head == token][0].lemma_
                ment_head = [token for token in doc if token.head == token][0]
            except IndexError:
                continue
            
            if ana_head == ment_head:
            if ana_head.lemma_ == ment_head.lemma_:
                if ana[0] - ment[1] <  ana[0] - best_ment[1] and ment[1] < ana[0] and ana[0] - ment[1] < window:
                    best_ment = ment
    
    if best_ment == (0,0):
        return closest_preceding_mention(toks, ents, ana)
    else:
        return {best_ment}
        return best_ment
                
def highest_embedding_similarity(toks, ents, ana):
    best_ment = (0,0)
@@ -73,7 +83,7 @@ def highest_embedding_similarity(toks, ents, ana):
    for name, proc in nlp.pipeline:
        doc=proc(doc)
        
    ana_head = [token for token in doc if token.head == token][0].text
    ana_head = [token for token in doc if token.head == token][0].lemma_
    
    for ent in ents:
        for ment in ent:
@@ -85,34 +95,33 @@ def highest_embedding_similarity(toks, ents, ana):
                doc=proc(doc)
                
            try:
                ment_head = [token for token in doc if token.head == token][0].text
                ment_head = [token for token in doc if token.head == token][0].lemma_
            except IndexError:
                continue
            
            if ment[1] < ana[0]:
                sim = 1/(model.wv.wmdistance(ment_head, ana_head)+1)
                
            if best_sim <= sim and ment[1] < ana[0]:
                print(ment_head, sim, ana_head)
                if best_sim <= sim:
                    best_ment = ment
                    best_sim = sim
    
    if best_ment == (0,0):
    if best_sim < 0.5:
        return closest_preceding_mention(toks, ents, ana)
    else:
        return {best_ment}
        return best_ment
    

def most_salient_entity(toks, ents, ana):
    """Resolve *ana* to a mention of the most salient (largest) entity.

    Args:
        toks: document tokens (unused; kept for a uniform baseline API).
        ents: iterable of entities, each a collection of (start, end)
            token-index pairs.
        ana: (start, end) pair of the anaphor.

    Returns:
        One (start, end) mention of the entity with the most mentions
        that has at least one mention preceding the anaphor, or the
        sentinel (0, 0) when no entity qualifies.
    """
    best_ent = (0, 0)
    max_ent = 0
    for ent in ents:
        # Keep the largest entity that is mentioned before the anaphor;
        # >= means a later tie replaces an earlier one, as before.
        if len(ent) >= max_ent and any(ment[1] < ana[0] for ment in ent):
            max_ent = len(ent)
            best_ent = ent
    if best_ent == (0, 0):
        # Bug fix: without this guard, list((0, 0))[0] would return the
        # int 0 instead of a mention pair when nothing qualifies.
        return best_ent
    return list(best_ent)[0]

def largest_entity_span(toks, ents, ana):
    best_ent = {(0,0)}
    best_ent = (0,0)
    max_span = 0
    for ent in ents:
        if ent:
@@ -123,42 +132,84 @@ def largest_entity_span(toks, ents, ana):
                max_span = span
                best_ent = ent
                
    if best_ment == {(0,0)}:
    if best_ent == (0,0):
        return closest_preceding_mention(toks, ents, ana)
    else:
        return {best_ment}
        return list(best_ent)[0]

def import_doc(path):
    """Parse a corpus file into tokens, anaphora and entities.

    The file has three sections separated by blank lines:
    token lines ("index token"), anaphor lines and coreference lines
    (each a whitespace-separated sequence of integer index pairs).

    Args:
        path: filesystem path of the corpus file.

    Returns:
        (tokens, anaphora, entities): tokens is a list of token
        strings; anaphora maps an anaphor (start, end) pair to a set
        of antecedent pairs (empty set if none); entities is a list of
        mention-pair lists.
    """
    # Bug fix: use a context manager so the file handle is always closed.
    with open(path, 'r') as fh:
        doc = fh.read()

    tokens, comps, coref = doc.split("\n\n")

    # A "-1" anywhere presumably marks a 0-based (or sentinel-using)
    # file; shift all indices up by one in that case — TODO confirm
    # against the corpus format.
    if "-1" in comps or "-1" in coref:
        add_one = 1
    else:
        add_one = 0

    tokens = [line.strip().split()[1] for line in tokens.split("\n")]

    anaphora = dict()
    for line in comps.split("\n"):
        line = [int(l) + add_one for l in line.strip().split()]
        if len(line) > 2:
            # First pair is the anaphor, remaining pairs its antecedents.
            anaphora[(line[0], line[1])] = {(line[i], line[i + 1]) for i in range(2, len(line), 2)}
        else:
            anaphora[(line[0], line[1])] = set()

    entities = list()
    for line in coref.split("\n"):
        line = [int(l) + add_one for l in line.strip().split()]
        entities.append([(line[i], line[i + 1]) for i in range(0, len(line), 2)])

    return tokens, anaphora, entities


def strict_match(system_mention, gold_mention, exact=True):
    """Decide whether a system mention matches the gold mention.

    With exact=True the (start, end) pairs must be identical.
    Otherwise the combined boundary offset may be at most one token,
    and an absent gold mention matches only an absent (None) system
    mention.
    """
    if exact:
        return system_mention == gold_mention
    if gold_mention:
        drift = abs(system_mention[0] - gold_mention[0]) \
              + abs(system_mention[1] - gold_mention[1])
        return drift < 2
    # No gold mention: only a "no antecedent" prediction counts.
    return system_mention is None
        
def lenient_match(system_mention, gold_entity, exact=True):
    """Decide whether a system mention hits any mention of the gold entity.

    With exact=True the system mention must literally occur in the
    gold entity. Otherwise it may deviate from some gold mention by at
    most one token in total, and an empty gold entity matches only an
    absent (None) system mention.
    """
    if exact:
        return system_mention in gold_entity
    if gold_entity:
        return any(
            abs(system_mention[0] - gm[0]) + abs(system_mention[1] - gm[1]) < 2
            for gm in gold_entity
        )
    # Empty gold entity: only a "no antecedent" prediction counts.
    return system_mention is None

if __name__ == "__main__":
    
    root = "/home/students/zimmermann/Projects/ncr/corpus/"
    root = "/home/students/zimmermann/Projects/ncr/new_corpus/"
    files = os.listdir(root)
    
    baseline = head_match
    baseline = closest_preceding_mention
    
    correct = 0
    incorrect = 0
    strict_exact = strict_approximate = 0
    lenient_exact = lenient_approximate = 0
    total = 0
    
    for f in files:
        path = root+f
@@ -166,16 +217,37 @@ if __name__ == "__main__":
        toks, anas, ents = import_doc(path)
        
        for ana in anas.keys():
            system_mention = baseline(toks, ents, ana)
            if anas[ana]:
                #ents = set.union(*anas.values(), *ents)
                system = baseline(toks, ents, ana)
                #if anas[ana] == system:
                if anas[ana] & set(system) != set():
                    correct += 1
                for ana_value in anas[ana]:
                    gold_mention = ana_value
                    
                    exact_gold_entity = approximate_gold_entity = {ana_value}
                    for entity in ents:
                        if ana_value in entity and len(entity) > 1:
                            exact_gold_entity = entity
                            approximate_gold_entity = entity
                            break
                        if any([abs(ana_value[0] - mention[0]) 
                                + abs(ana_value[1]- mention[1]) 
                                < 2 for mention in entity]):
                            approximate_gold_entity = entity
                        
                    total += 1
                    strict_exact += 1 if strict_match(system_mention, gold_mention, True) else 0
                    strict_approximate += 1 if strict_match(system_mention, gold_mention, False) else 0
                    lenient_exact += 1 if lenient_match(system_mention, exact_gold_entity, True) else 0
                    lenient_approximate += 1 if lenient_match(system_mention, approximate_gold_entity, False) else 0
            else:
                    incorrect += 1
                gold_mention = None
                gold_entity = {}
                
        print(strict_exact, strict_exact/total)
        print(strict_approximate, strict_approximate/total)
        print(lenient_exact, lenient_exact/total)
        print(lenient_approximate, lenient_approximate/total)
                
            
        print(correct, incorrect, correct/(correct+incorrect))