Commit 635d80fb authored by Victor Zimmermann's avatar Victor Zimmermann
Browse files

Add proper evaluation tools

parent 03567c33
Loading
Loading
Loading
Loading

scripts/eval.py

0 → 100644
+68 −0
Original line number Diff line number Diff line
from sklearn import tree, svm
from sklearn.externals import joblib
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
from sklearn.calibration import CalibratedClassifierCV
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import graphviz
import json
import numpy as np

def compute_recall(data_set, trained_clf):

    correct = 0
    total = 0

    gold_dict = dict()

    for verse in data_set:
        vectors = [reading[0] for reading in verse[2]]

        probs = trained_clf.predict_proba(vectors)
        sort_probs = sorted([(probs[i], verse[2][i][1]) for i in range(len(probs))], key=lambda x: x[0][0])

        gold = [prob[1] for prob in sort_probs]

        if max(gold) != 0:

            gold_idx = np.argmax(gold)

            if gold_idx in gold_dict:
                gold_dict[gold_idx] += 1
            else:
                gold_dict[gold_idx] = 1

        total += 1

    for i in range(len(gold_dict.items())):
        if i in gold_dict:
            correct += gold_dict[i]
        print('{}\t{}/{}\t{}'.format(i, correct, total, correct/total))

#load data
train_file = open('../train0-9.json', 'r')
dev_file = open('../dev.json')
test_file = open('../test.json')

train = json.load(train_file)
dev = json.load(dev_file)
test = json.load(test_file)

#build model
tree = joblib.load('tree_classifier.joblib')
svm = joblib.load('svm_classifier.joblib')
forest = joblib.load('forest_classifier.joblib')

clfs = [('DecisionTree', tree), ('SupportVectorMachine', svm), ('RandomForest', forest)]

print("DEVELOPMENT")
for name,clf in clfs:
    print(name)
    compute_recall(dev, clf)

print("\nTEST")
for name,clf in clfs:
    print(name)
    compute_recall(test, clf)