From fa97eff1b943bfa1e1783791d3261ccce76972b7 Mon Sep 17 00:00:00 2001
From: vvye <ekaiser.hellwege@gmail.com>
Date: Wed, 29 Sep 2021 01:49:12 +0200
Subject: [PATCH] Implement dataset selection via argument

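Add a required --dataset argument to run.py that selects one of the
three datasets ('timeline17', 'crisis', or 'entities') and dispatches
to the matching loader in dataset.py. Add a cache-only loader for the
ENTITIES dataset, plus keyword lists for its topics in misc.py.
timeline_generation.py switches date selection from
rank_dates_by_wilson to rank_dates_by_mention_count and drops an
unused import.

Example invocation (hypothetical; assumes run.py is executed directly):

    python run.py --dataset entities --print_timelines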
---
 dataset.py             | 18 ++++++++++++++++
 misc.py                | 49 +++++++++++++++++++++++++++++++++++++++++-
 run.py                 |  9 +++++++-
 timeline_generation.py |  4 +---
 4 files changed, 75 insertions(+), 5 deletions(-)
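
Note (below the '---' cut line, so not part of the commit message): a
minimal standalone sketch of the dataset-dispatch pattern this patch
adds to run.py; the stub loaders are hypothetical stand-ins for the
real functions in dataset.py:

    import argparse

    # hypothetical stub loaders; the real ones live in dataset.py
    def get_timeline17_dataset():
        return {'example_topic': {'articles': []}}

    def get_crisis_dataset():
        return {'example_topic': {'articles': []}}

    def get_entities_dataset():
        return {'example_topic': {'articles': []}}

    def main(args):
        # look up the loader for the requested dataset, then call it;
        # argparse's `choices` guarantees args.dataset is a valid key
        data = {
            'timeline17': get_timeline17_dataset,
            'crisis': get_crisis_dataset,
            'entities': get_entities_dataset,
        }[args.dataset]()
        print(sorted(data.keys()))

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--dataset', type=str,
                            choices=['timeline17', 'crisis', 'entities'],
                            required=True, help='the dataset to use')
        main(parser.parse_args())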

diff --git a/dataset.py b/dataset.py
index d2e0f3c..0f1d9bd 100644
--- a/dataset.py
+++ b/dataset.py
@@ -129,6 +129,10 @@ def get_crisis_dataset():
             for article_filename in util.files(date_path, extension='.cont'):
                 article_file_path = date_path / article_filename
 
+                # TODO: split into sentences and tokenize (the input files are neither tokenized
+                # nor one-sentence-per-line). Probably best to write a temporary file in the
+                # expected format, run HeidelTime on it, and handle the first-line skipping there as well.
+
                 # this one breaks heideltime due to character encoding shenanigans
                 if topic_name == 'egypt' and '2429.htm.cont' in str(article_file_path):
                     continue
@@ -176,5 +180,19 @@ def get_crisis_dataset():
     return data
 
 
+def get_entities_dataset():
+    """
+    Returns the ENTITIES dataset as a dictionary.
+    For now, it is only loaded from the cache file data/in/entities/entities.pkl;
+    building it from scratch from the raw input files is not implemented yet.
+
+    :return: A dictionary containing the dataset.
+    """
+
+    # TODO actually build the dataset from the raw input files instead of only reading the cache?
+    cache_filename = Path('data/in/entities/entities.pkl')
+    return pickle.loads(cache_filename.read_bytes())
+
+
 def filter_articles_by_date(articles, start_date, end_date):
     return [a for a in articles if start_date <= a['pub_date'] <= end_date]
diff --git a/misc.py b/misc.py
index 6107c7b..ea5901d 100644
--- a/misc.py
+++ b/misc.py
@@ -16,4 +16,51 @@ keywords_by_topic = {
     'mj': ['michael', 'jackson'],
     'syria': ['syria', 'syrian'],
     'yemen': ['yemen'],
-}
\ No newline at end of file
+    'Al_Gore': ['al gore', 'gore'],
+    'Angela_Merkel': ['angela merkel', 'merkel'],
+    'Ariel_Sharon': ['ariel sharon', 'sharon'],
+    'Arnold_Schwarzenegger': ['arnold schwarzenegger', 'schwarzenegger'],
+    'Bashar_al-Assad': ['al-assad', 'bashar al-assad', 'assad'],
+    'Bill_Clinton': ['bill clinton', 'clinton'],
+    'Charles_Taylor': ['charles taylor', 'taylor'],
+    'Chris_Brown': ['brown', 'chris brown'],
+    'David_Beckham': ['beckham', 'david beckham'],
+    'David_Bowie': ['bowie', 'david bowie'],
+    'Dilma_Rousseff': ['dilma rousseff', 'rousseff'],
+    'Dmitry_Medvedev': ['dmitry medvedev', 'medvedev'],
+    'Dominique_Strauss-Kahn': ['dominique strauss-kahn', 'strauss-kahn', 'kahn'],
+    'Edward_Snowden': ['edward snowden', 'snowden'],
+    'Ehud_Olmert': ['ehud olmert', 'olmert'],
+    'Enron': ['enron'],
+    'Hamid_Karzai': ['hamid karzai', 'karzai'],
+    'Hassan_Rouhani': ['hassan rouhani', 'rouhani'],
+    'Hu_Jintao': ['hu jintao', 'jintao'],
+    'Jacob_Zuma': ['jacob zuma', 'zuma'],
+    'John_Boehner': ['boehner', 'john boehner'],
+    'John_Kerry': ['john kerry', 'kerry'],
+    'Julian_Assange': ['assange', 'julian assange'],
+    'Lance_Armstrong': ['armstrong', 'lance armstrong'],
+    'Mahmoud_Ahmadinejad': ['ahmadinejad', 'mahmoud ahmadinejad'],
+    'Marco_Rubio': ['marco rubio', 'rubio'],
+    'Margaret_Thatcher': ['margaret thatcher', 'thatcher'],
+    'Michael_Jackson': ['jackson', 'michael jackson'],
+    'Michelle_Obama': ['michelle obama', 'obama'],
+    'Mitt_Romney': ['mitt romney', 'romney'],
+    'Morgan_Tsvangirai': ['morgan tsvangirai', 'tsvangirai'],
+    'Nawaz_Sharif': ['nawaz sharif', 'sharif'],
+    'Nelson_Mandela': ['mandela', 'nelson mandela'],
+    'Osama_bin_Laden': ['laden', 'osama bin laden'],
+    'Oscar_Pistorius': ['oscar pistorius', 'pistorius'],
+    'Phil_Spector': ['phil spector', 'spector'],
+    'Prince_William': ['prince william', 'william'],
+    'Robert_Mugabe': ['mugabe', 'robert mugabe'],
+    'Rupert_Murdoch': ['murdoch', 'rupert murdoch'],
+    'Saddam_Hussein': ['hussein', 'saddam hussein'],
+    'Sarah_Palin': ['palin', 'sarah palin'],
+    'Silvio_Berlusconi': ['berlusconi', 'silvio berlusconi'],
+    'Steve_Jobs': ['jobs', 'steve jobs'],
+    'Taliban': ['taliban'],
+    'Ted_Cruz': ['cruz', 'ted cruz'],
+    'Tiger_Woods': ['tiger woods', 'woods'],
+    'WikiLeaks': ['wikileaks']
+}
diff --git a/run.py b/run.py
index ec758aa..6515792 100644
--- a/run.py
+++ b/run.py
@@ -10,7 +10,12 @@ import timeline_generation
 def main(args):
     eval_results = evaluation.ResultLogger()
 
-    data = dataset.get_crisis_dataset()
+    data = {
+        'timeline17': dataset.get_timeline17_dataset,
+        'crisis': dataset.get_crisis_dataset,
+        'entities': dataset.get_entities_dataset
+    }[args.dataset]()
+
     for topic in data.keys():
 
         articles = data[topic]['articles']
@@ -44,6 +49,8 @@ def main(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, choices=['timeline17', 'crisis', 'entities'],
+                        required=True, help='the dataset to use')
     parser.add_argument('--print_timelines',
                         action='store_true',
                         help='whether to print the timelines to the console after generating them')
diff --git a/timeline_generation.py b/timeline_generation.py
index 9e6919a..a74fbfd 100644
--- a/timeline_generation.py
+++ b/timeline_generation.py
@@ -1,5 +1,3 @@
-from pathlib import Path
-
 from sklearn.feature_extraction.text import TfidfVectorizer
 
 import dataset
@@ -23,7 +21,7 @@ def make_timeline(articles, gold_timeline, keywords):
     # articles = dataset.filter_articles_by_keywords(articles, keywords)
 
     # select dates
-    ranked_dates = date_selection.rank_dates_by_wilson(articles, start_date, end_date, num_dates)
+    ranked_dates = date_selection.rank_dates_by_mention_count(articles, start_date, end_date, num_dates)
 
     # train TFIDF vectorizer on all sentences (not just the ones for this date)
     all_sentences = [sentence['text'] for article in articles for sentence in article['sentences']]
-- 
GitLab