diff --git a/allzweckmesser/features.py b/allzweckmesser/features.py index 8e765d8c2d88562e71d85d60594d50a1c10e5072..3c7fb8b71d072fe84cc20710b7a4d9ad5141b7f5 100644 --- a/allzweckmesser/features.py +++ b/allzweckmesser/features.py @@ -29,16 +29,17 @@ class CombinedFeatures(Enum): def combine_features(reading_features, reading_meter_features): features = [0 for _ in CombinedFeatures] - for rf, val in reading_features: + for rf, val in reading_features.items(): if hasattr(CombinedFeatures, rf.name): features[CombinedFeatures[rf.name].value] = val meter_rules_violated = 0 - for rmf, val in reading_meter_features: + for rmf, val in reading_meter_features.items(): if hasattr(CombinedFeatures, rmf.name): features[CombinedFeatures[rmf.name].value] = val elif 'VIOLATED' in rmf.name.upper(): meter_rules_violated += 1 - features[CombinedFeatures.METER_RULES_VIOLATED] = meter_rules_violated + features[ + CombinedFeatures.METER_RULES_VIOLATED.value] = meter_rules_violated return features diff --git a/allzweckmesser/meters.py b/allzweckmesser/meters.py index 31b1298b9732d8c171321dde04ebb3ca0c0f57ff..92e756aa4a725ccb83f839161c9fc48e9db65323 100644 --- a/allzweckmesser/meters.py +++ b/allzweckmesser/meters.py @@ -12,7 +12,7 @@ def bridge(position_spec, feature): def get_feature(meter: Meter, reading: Reading): position = Position.after(position_spec[0], reading, meter, position_spec[1]) - if position.word_boundary: + if position and position.word_boundary: return None else: return feature diff --git a/scripts/create_feature_vectors.py b/scripts/create_feature_vectors.py index 43661d9a4f427ff0b0143d94b86e1ef66d1cf1d5..3d5528fdc5f29680a00941a0482bf7650178416c 100644 --- a/scripts/create_feature_vectors.py +++ b/scripts/create_feature_vectors.py @@ -17,11 +17,14 @@ def main(meter_reference_verses, outfile, meters=['hexameter']): for meter in meters if meter in azm.meters.ALL_METERS ] - scanner = azm.scanner.Scanner(meters=meters) + scanner = azm.scanner.Scanner() out = [] - for ref_meter, ref_verse, correct in meter_reference_verses: + total_instances = len(meter_reference_verses) + for i, (ref_meter, ref_verse, correct) in enumerate(meter_reference_verses, 1): if correct: + print('Processing verse {} ({}/{})' + .format(ref_verse.text, i, total_instances)) instances = [] ref_reading = ref_verse.readings[0] ref_schema = ref_reading.get_schema() @@ -31,6 +34,7 @@ def main(meter_reference_verses, outfile, meters=['hexameter']): print('ERROR when scanning verse {!r}'.format(ref_verse), file=sys.stderr) traceback.print_exc() + continue reading_meter_combinations = ( azm.meters.get_reading_meter_combinations( analysis.readings, meters