From 579d22eb9c660dc79fb24057678ec2cc5b6d54ad Mon Sep 17 00:00:00 2001 From: Simon Will <will@cl.uni-heidelberg.de> Date: Fri, 28 Sep 2018 22:08:57 +0200 Subject: [PATCH] Fix bugs in create_feature_vectors.py --- allzweckmesser/features.py | 7 ++++--- allzweckmesser/meters.py | 2 +- scripts/create_feature_vectors.py | 8 ++++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/allzweckmesser/features.py b/allzweckmesser/features.py index 8e765d8..3c7fb8b 100644 --- a/allzweckmesser/features.py +++ b/allzweckmesser/features.py @@ -29,16 +29,17 @@ class CombinedFeatures(Enum): def combine_features(reading_features, reading_meter_features): features = [0 for _ in CombinedFeatures] - for rf, val in reading_features: + for rf, val in reading_features.items(): if hasattr(CombinedFeatures, rf.name): features[CombinedFeatures[rf.name].value] = val meter_rules_violated = 0 - for rmf, val in reading_meter_features: + for rmf, val in reading_meter_features.items(): if hasattr(CombinedFeatures, rmf.name): features[CombinedFeatures[rmf.name].value] = val elif 'VIOLATED' in rmf.name.upper(): meter_rules_violated += 1 - features[CombinedFeatures.METER_RULES_VIOLATED] = meter_rules_violated + features[ + CombinedFeatures.METER_RULES_VIOLATED.value] = meter_rules_violated return features diff --git a/allzweckmesser/meters.py b/allzweckmesser/meters.py index ac2f7a8..67435fd 100644 --- a/allzweckmesser/meters.py +++ b/allzweckmesser/meters.py @@ -12,7 +12,7 @@ def bridge(position_spec, feature): def get_feature(meter: Meter, reading: Reading): position = Position.after(position_spec[0], reading, meter, position_spec[1]) - if position.word_boundary: + if position and position.word_boundary: return None else: return feature diff --git a/scripts/create_feature_vectors.py b/scripts/create_feature_vectors.py index 43661d9..3d5528f 100644 --- a/scripts/create_feature_vectors.py +++ b/scripts/create_feature_vectors.py @@ -17,11 +17,14 @@ def main(meter_reference_verses, outfile, meters=['hexameter']): for meter in meters if meter in azm.meters.ALL_METERS ] - scanner = azm.scanner.Scanner(meters=meters) + scanner = azm.scanner.Scanner() out = [] - for ref_meter, ref_verse, correct in meter_reference_verses: + total_instances = len(meter_reference_verses) + for i, (ref_meter, ref_verse, correct) in enumerate(meter_reference_verses, 1): if correct: + print('Processing verse {} ({}/{})' + .format(ref_verse.text, i, total_instances)) instances = [] ref_reading = ref_verse.readings[0] ref_schema = ref_reading.get_schema() @@ -31,6 +34,7 @@ def main(meter_reference_verses, outfile, meters=['hexameter']): print('ERROR when scanning verse {!r}'.format(ref_verse), file=sys.stderr) traceback.print_exc() + continue reading_meter_combinations = ( azm.meters.get_reading_meter_combinations( analysis.readings, meters -- GitLab