Commit 0de9bd44 authored by Simon Will
Browse files

Add code for creating feature vectors (with bugs)

parent 9b734a57
Loading
Loading
Loading
Loading
+21 −5
Original line number Diff line number Diff line
@@ -19,10 +19,26 @@ class ReadingMeterFeatures(Enum):


class CombinedFeatures(Enum):
    """Slots of the combined feature vector built by ``combine_features``.

    Each member's value is used as a list index into a vector that has
    exactly one entry per member, so every name may be defined only once
    (``Enum`` raises ``TypeError`` on a reused name) and the values must
    form the range ``0 .. len(CombinedFeatures) - 1``.
    """
    # The values have to start at 0 and be contiguous.
    MCL_TRIGGERS_PL = 0
    SYNIZESIS = 1
    DOES_NOT_FIT_METER = 2
    NO_USUAL_BREAK_PRESENT = 3
    # Not copied from a single input feature: combine_features stores the
    # number of "...VIOLATED" reading-meter features in this slot.
    METER_RULES_VIOLATED = 4


def combine_features(reading_features, reading_meter_features):
    """Merge reading and reading-meter features into one flat vector.

    Args:
        reading_features: Iterable of ``(feature, value)`` pairs; each
            ``feature`` must expose a ``name`` attribute.
        reading_meter_features: Iterable of ``(feature, value)`` pairs of
            the same shape.

    Returns:
        A list with one entry per ``CombinedFeatures`` member, indexed by
        the member's value. Entries default to 0. A feature whose name
        matches a member is copied into that member's slot. Reading-meter
        features without a matching member whose name contains
        ``VIOLATED`` are counted, and the count is stored in the
        ``METER_RULES_VIOLATED`` slot.
    """
    # Look names up in __members__ rather than with hasattr(): hasattr() is
    # also true for non-member attributes of the enum class, for which the
    # subsequent CombinedFeatures[name] lookup would raise KeyError.
    members = CombinedFeatures.__members__
    features = [0] * len(CombinedFeatures)

    for rf, val in reading_features:
        if rf.name in members:
            features[members[rf.name].value] = val

    meter_rules_violated = 0
    for rmf, val in reading_meter_features:
        if rmf.name in members:
            features[members[rmf.name].value] = val
        elif 'VIOLATED' in rmf.name.upper():
            # NOTE(review): every such feature is counted regardless of
            # `val`; confirm that a listed feature always means the rule
            # was actually violated.
            meter_rules_violated += 1

    # A list cannot be indexed with the enum member itself (TypeError);
    # the integer value is the slot index.
    violated_slot = CombinedFeatures.METER_RULES_VIOLATED.value
    features[violated_slot] = meter_rules_violated

    return features
+2 −1
Original line number Diff line number Diff line
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import itertools
import re

from .features import ReadingMeterFeatures
+82 −0
Original line number Diff line number Diff line
#!/usr/bin/python3
# -*- coding: utf-8 -*-

import argparse
import json
import sys
import traceback

from unidecode import unidecode

import allzweckmesser as azm


def main(meter_reference_verses, outfile, meters=None):
    """Write labelled feature vectors for the given reference verses.

    Every verse marked as correct is scanned, and for each reading-meter
    combination of the resulting analysis a feature vector is built and
    labelled 1 if both meter and schema agree with the reference, else 0.

    Args:
        meter_reference_verses: Iterable of ``(ref_meter, ref_verse,
            correct)`` triples. ``ref_verse`` must provide ``text`` and a
            non-empty ``readings`` sequence whose first element has
            ``get_schema()``.
        outfile: Path of the JSON file to write. The output is a list of
            ``[verse_text, ref_meter, instances]`` entries, where
            ``instances`` is a list of ``[features, label]`` pairs.
        meters: Names of the meters to scan for. Names that are not keys
            of ``azm.meters.ALL_METERS`` are ignored. ``None`` (the
            default) means ``['hexameter']``.
    """
    # None instead of a list literal as default: avoids a shared mutable
    # default and accepts the None argparse passes for an omitted option.
    if meters is None:
        meters = ['hexameter']
    meters = [
        azm.meters.ALL_METERS[meter]
        for meter in meters
        if meter in azm.meters.ALL_METERS
    ]
    scanner = azm.scanner.Scanner(meters=meters)
    out = []

    for ref_meter, ref_verse, correct in meter_reference_verses:
        if not correct:
            continue

        ref_reading = ref_verse.readings[0]
        ref_schema = ref_reading.get_schema()
        try:
            analysis = scanner.scan_verses([unidecode(ref_verse.text)])[0]
        except Exception:
            print('ERROR when scanning verse {!r}'.format(ref_verse),
                  file=sys.stderr)
            traceback.print_exc()
            # Without an analysis there is nothing to extract. Falling
            # through would hit an unbound `analysis` or silently reuse
            # the previous verse's analysis, so skip this verse.
            continue

        reading_meter_combinations = (
            azm.meters.get_reading_meter_combinations(
                analysis.readings, meters
            )
        )
        instances = []
        for reading, meter, rmfeatures in reading_meter_combinations:
            features = azm.features.combine_features(
                reading.features, rmfeatures)
            # A feature vector gets a correct label if the schema matches
            # the reference reading's schema and the meter matches the
            # reference meter.
            # NOTE(review): confirm that `meter` (from the combinations)
            # and `ref_meter` (from the input) are of comparable types;
            # otherwise this label is always 0.
            reading_is_correct = int(
                meter == ref_meter
                and reading.get_schema() == ref_schema
            )
            instances.append((features, reading_is_correct))
        out.append((ref_verse.text, ref_meter, instances))

    with open(outfile, 'w') as f:
        json.dump(out, f, indent=2)


def read_infile(infile):
    """Load the reference verses from a JSON file.

    Args:
        infile: Path of a JSON file whose top-level value is a sequence of
            ``[meter, verse_dict, correct]`` triples.

    Returns:
        A list of ``(meter, verse, correct)`` tuples in file order, where
        ``verse`` is the result of ``azm.model.Verse.from_json`` applied
        to ``verse_dict``.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale encoding, so non-ASCII verse text is read the same everywhere.
    with open(infile, encoding='utf-8') as f:
        entries = json.load(f)
    return [
        (meter, azm.model.Verse.from_json(verse_dict), correct)
        for meter, verse_dict, correct in entries
    ]


def parse_args_and_main():
    """Parse the command line, load the reference verses and run ``main``.

    Positional arguments are the input JSON file with the reference verses
    and the output JSON file. ``--meters``/``-m`` takes one or more meter
    names and falls back to hexameter when omitted.
    """
    d = 'Generate feature vectors for reading-meter combinations'
    parser = argparse.ArgumentParser(description=d)
    # Explicit default: argparse would otherwise store None for an omitted
    # option, which would then be passed on to main() as the meter list.
    parser.add_argument('--meters', '-m', nargs='+', default=['hexameter'],
                        help='Meters to consider when scanning.')
    parser.add_argument('infile',
                        help='JSON file containing the reference verses')
    parser.add_argument('outfile',
                        help='JSON file for the output')
    args = parser.parse_args()
    main(
        meter_reference_verses=read_infile(args.infile),
        outfile=args.outfile,
        meters=args.meters,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    parse_args_and_main()