Commit 579d22eb authored by Simon Will's avatar Simon Will
Browse files

Fix bugs in create_feature_vectors.py

parent 0de9bd44
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -29,16 +29,17 @@ class CombinedFeatures(Enum):

def combine_features(reading_features, reading_meter_features):
    features = [0 for _ in CombinedFeatures]
    for rf, val in reading_features:
    for rf, val in reading_features.items():
        if hasattr(CombinedFeatures, rf.name):
            features[CombinedFeatures[rf.name].value] = val

    meter_rules_violated = 0
    for rmf, val in reading_meter_features:
    for rmf, val in reading_meter_features.items():
        if hasattr(CombinedFeatures, rmf.name):
            features[CombinedFeatures[rmf.name].value] = val
        elif 'VIOLATED' in rmf.name.upper():
            meter_rules_violated += 1
    features[CombinedFeatures.METER_RULES_VIOLATED] = meter_rules_violated
    features[
        CombinedFeatures.METER_RULES_VIOLATED.value] = meter_rules_violated

    return features
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ def bridge(position_spec, feature):
    def get_feature(meter: Meter, reading: Reading):
        position = Position.after(position_spec[0], reading, meter,
                                  position_spec[1])
        if position.word_boundary:
        if position and position.word_boundary:
            return None
        else:
            return feature
+6 −2
Original line number Diff line number Diff line
@@ -17,11 +17,14 @@ def main(meter_reference_verses, outfile, meters=['hexameter']):
        for meter in meters
        if meter in azm.meters.ALL_METERS
    ]
    scanner = azm.scanner.Scanner(meters=meters)
    scanner = azm.scanner.Scanner()
    out = []

    for ref_meter, ref_verse, correct in meter_reference_verses:
    total_instances = len(meter_reference_verses)
    for i, (ref_meter, ref_verse, correct) in enumerate(meter_reference_verses, 1):
        if correct:
            print('Processing verse {} ({}/{})'
                  .format(ref_verse.text, i, total_instances))
            instances = []
            ref_reading = ref_verse.readings[0]
            ref_schema = ref_reading.get_schema()
@@ -31,6 +34,7 @@ def main(meter_reference_verses, outfile, meters=['hexameter']):
                print('ERROR when scanning verse {!r}'.format(ref_verse),
                      file=sys.stderr)
                traceback.print_exc()
                continue
            reading_meter_combinations = (
                azm.meters.get_reading_meter_combinations(
                    analysis.readings, meters