Commit 1764b5be authored by axtimhaus's avatar axtimhaus
Browse files

Add baselines and extraction from ISnotes

parent d9d5edeb
Loading
Loading
Loading
Loading

source/__init__.py

0 → 100644
+0 −0

Empty file added.

source/baselines.py

0 → 100644
+102 −0
Original line number Diff line number Diff line
import numpy as np
import logging
import os
import argparse
import gensim
from gensim.models import word2vec
from gensim.models import KeyedVectors as kv

import extraction
import evaluation
import config

# Load the word-embedding model once at import time.  The embeddings file may
# be a native gensim KeyedVectors dump or a plain word2vec text file, so fall
# back to the word2vec-format loader when the native load fails.
fname = config.EMBEDDINGS_PATH
print("Loading model.", end="\r")
try:
    model = kv.load(fname)
except Exception:  # FIX: bare `except:` would also swallow KeyboardInterrupt/SystemExit
    model = kv.load_word2vec_format(fname, binary=False)
print("Done.         ")


def next_mention(mentions, anaphora):
    """Baseline: pick the mention closest before the anaphora.

    Each candidate is scored by the inverse distance between the end of its
    span and the start of the anaphora's span; candidates that do not end
    before the anaphora starts score 0.

    :param mentions: sequence of Mention-like objects with a ``span`` pair
    :param anaphora: the anaphoric mention whose antecedent is sought
    :return: the mention with the highest proximity score
    """
    # BUG FIX: the comprehension referenced an undefined name `candidates`
    # (NameError at call time); it must iterate the `mentions` parameter.
    sims = [1 / (anaphora.span[0] - mention.span[1]) if mention.span[1] < anaphora.span[0] else 0
            for mention in mentions]
    return mentions[np.argmax(sims)]

def head_match(mentions, anaphora):
    """Baseline: resolve the anaphora by head-noun matching.

    Not implemented yet — placeholder stub.
    """
    pass

def best_embedding(mentions, anaphora):
    """Baseline: pick the preceding mention most similar in embedding space.

    Similarity is the inverse Word Mover's Distance (module-level ``model``)
    between the candidate's and the anaphora's token lists; mentions that do
    not end before the anaphora starts score 0.

    :param mentions: sequence of Mention-like objects with ``span``/``tokens``
    :param anaphora: the anaphoric mention whose antecedent is sought
    :return: the mention with the highest similarity score
    """
    # BUG FIX: the comprehension mixed two loop-variable names
    # (`candidate.tokens` vs `for mention in ...`) — only `mention` is bound,
    # so the original raised NameError on the first scored candidate.
    sims = [1 / model.wv.wmdistance(mention.tokens, anaphora.tokens)
            if mention.span[1] < anaphora.span[0] else 0
            for mention in mentions]
    return mentions[np.argmax(sims)]

def most_salient_entity(entities, anaphora):
    """Baseline: choose the entity with the most mentions (highest salience).

    An entity only competes when its first mention ends before the anaphora
    begins; all other entities receive a salience of 0.
    """
    anaphora_start = anaphora.span[0]
    salience = []
    for entity in entities:
        if entity[0].span[1] < anaphora_start:
            salience.append(len(entity))
        else:
            salience.append(0)
    return entities[np.argmax(salience)]

def largest_entity_span(entities, anaphora):
    """Baseline: choose the entity covering the widest token range.

    An entity's score is the distance from the start of its first mention to
    the end of its last one; entities whose first mention does not end before
    the anaphora starts score 0.
    """
    anaphora_start = anaphora.span[0]
    widths = [
        (entity[-1].span[1] - entity[0].span[0]) if entity[0].span[1] < anaphora_start else 0
        for entity in entities
    ]
    return entities[np.argmax(widths)]
    
if __name__ == "__main__":
    # Running tallies of correct/incorrect antecedent predictions per baseline.
    false = {"next_mention": 0, "best_embedding": 0, "mse": 0, "les": 0}
    true = {"next_mention": 0, "best_embedding": 0, "mse": 0, "les": 0}

    root = config.DATA_PATH
    files = os.listdir(root)

    for f in files[:10]:

        print("Fetching test data...")
        mentions = extraction.get_mentions(root+f)

        # Anaphora of interest: mentions marked comparative (comp_from set).
        comparatives = [m for m in mentions if m.comp_from]
        print(comparatives)

        # Group mentions into entities: mentions without a coref_set stay as
        # singletons, mentions sharing a coref_set form one entity list.
        entities = []
        entity_dict = {}
        for mention in mentions:
            if not mention.coref_set:
                entities.append(mention)
            else:
                if mention.coref_set in entity_dict:
                    entity_dict[mention.coref_set].append(mention)
                else:
                    entity_dict[mention.coref_set] = [mention]

        entities += list(entity_dict.values())

        # Of an entity's mentions E, pick the one closest before anaphora c.
        closest = lambda E, c: E[np.argmax([1/(c.span[0] - e.span[1]) if c.span[0] > e.span[1] else 0 for e in E])]

        for comp in comparatives:

            if next_mention(mentions, comp).mention_id == comp.comp_from:
                true["next_mention"] += 1
            else:
                false["next_mention"] += 1
            if best_embedding(mentions, comp).mention_id == comp.comp_from:
                true["best_embedding"] += 1
            else:
                false["best_embedding"] += 1

            mse_result = most_salient_entity(entities, comp)
            les_result = largest_entity_span(entities, comp)

            # BUG FIX: `closest` takes (entity, anaphora) but was called with
            # a single argument, raising TypeError on every comparative.
            # NOTE(review): mse_result/les_result may be a bare Mention when
            # the winning entity was a singleton — confirm Mention supports
            # indexing, or wrap singletons in one-element lists above.
            if closest(mse_result, comp).mention_id == comp.comp_from:
                true["mse"] += 1
            else:
                false["mse"] += 1
            if closest(les_result, comp).mention_id == comp.comp_from:
                true["les"] += 1
            else:
                false["les"] += 1

        # FIX: guard against ZeroDivisionError when no comparatives have been
        # seen yet for a given baseline.
        recall = {
            key: (true[key] / (true[key] + false[key]) if true[key] + false[key] else 0.0)
            for key in ("next_mention", "best_embedding", "mse", "les")
        }

        print("Recall after {}:".format(f))
        print("next_mention:\t{}\t(True: {}, False:{})".format(recall["next_mention"], true["next_mention"], false["next_mention"]))
        print("best_embedding:\t{}\t(True: {}, False:{})".format(recall["best_embedding"], true["best_embedding"], false["best_embedding"]))

        print("most_salient_entity:\t{}\t(True: {}, False:{})".format(recall["mse"], true["mse"], false["mse"]))
        print("largest_entity_span:\t{}\t(True: {}, False:{})".format(recall["les"], true["les"], false["les"]))

source/config.py

0 → 100644
+2 −0
Original line number Diff line number Diff line
# Path to the pretrained GloVe Twitter embeddings (word2vec text format).
EMBEDDINGS_PATH = "/softpro/ss18/kernseife/kernseife/data/en/embeddings/glove.twitter.27B.100d.txt"
# Root directory containing the ISNotes/MMAX corpus files to evaluate on.
DATA_PATH = "/proj/zimmermann/isnotes/ISClean/"

source/evaluation.py

0 → 100644
+0 −0

Empty file added.

source/extraction.py

0 → 100644
+75 −0
Original line number Diff line number Diff line
from bs4 import BeautifulSoup
import os
import config
import re

from model import Mention

def get_mentions(path):
    """Parse one MMAX document and return its Mention objects.

    :param path: path to a ``.mmax`` file; the ``Basedata/`` and
        ``markables/`` sibling directories are resolved relative to it.
    :return: list of Mention objects; empty if ``path`` is not a .mmax file.
    """
    split_path = path.split(r"/")
    root, file_name = "/".join(split_path[:-1]), split_path[-1]

    mentions = []

    if file_name.endswith(".mmax"):

        print(file_name)

        file_root = file_name[:-5]  # strip the ".mmax" extension

        # FIX: use context managers so the XML files are closed — the
        # original leaked three file descriptors per document.
        with open(root+"/Basedata/"+file_root+"_words.xml", "r") as fh:
            basedata = BeautifulSoup(fh.read(), features="lxml")
        with open(root+"/markables/"+file_root+"_entity_level.xml", "r") as fh:
            entities = BeautifulSoup(fh.read(), features="lxml")
        with open(root+"/markables/"+file_root+"_coref_level.xml", "r") as fh:
            coref = BeautifulSoup(fh.read(), features="lxml")

        for entity in entities.contents[2].contents[0].contents[0].find_all():

            span = entity["span"]
            tokens = []
            mention_id = entity["id"]

            # comparative: either "outsidetext" or antecedent id
            comp_from = None
            if entity["information_status"] == "mediated":
                # FIX: catch only the missing-attribute error (bs4 tags raise
                # KeyError) instead of a bare except that hides real bugs;
                # also dropped a leftover debug print ("Hi!").
                try:
                    if entity["mediated_type"] == "comparative":
                        comp_from = entity["comparative_type"]
                        if comp_from == "withintext":
                            comp_from = entity["comp_from"]
                except KeyError:
                    pass

            # FIX: narrowed bare except — find() returns None when there is
            # no match (TypeError on subscript), and the attribute may be
            # absent (KeyError).
            try:
                coref_set = coref.find("markable", {"span": span})["coref_set"]
            except (TypeError, KeyError):
                coref_set = None

            if ".." in span:
                # Multi-word span, e.g. "word_3..word_7".
                span_ids = [int(re.sub("word_", "", word)) for word in span.split("..")]
                for i in range(span_ids[0],span_ids[1]+1):
                    token = basedata.find("word", {"id": "word_"+str(i)}).contents[0]
                    tokens.append(token)

                mentions.append(Mention( tokens, [span_ids[0],span_ids[1]], mention_id, coref_set, comp_from ))

            else:
                # Single-word span, e.g. "word_12".
                token = basedata.find("word", {"id": span}).contents[0]
                tokens.append(token)

                word_id = int(span.split("_")[-1])
                mentions.append(Mention( tokens, [word_id, word_id], mention_id, coref_set, comp_from ))

    return mentions

def get_entities():
    """Extract coreference entities from a document.

    Not implemented yet — placeholder stub.
    """
    pass        
        
if __name__ == "__main__":
    # Smoke test: run mention extraction over every file in the corpus dir.
    root = config.DATA_PATH
    for corpus_file in os.listdir(root):
        get_mentions(root + corpus_file)
Loading