Skip to content
Snippets Groups Projects
random_forest.py 1.30 KiB
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib
import graphviz # doctest: +SKIP
import json
import numpy as np

# Load the training and development sets. The code below reads each verse's
# readings from verse[2]; each reading is used as reading[0] = feature vector,
# reading[1] = label (label 1 is treated as the gold reading during evaluation).
# The files are closed as soon as they have been parsed.
with open('../train0-9.json', 'r') as train_file:
    train = json.load(train_file)
with open('../dev.json') as dev_file:
    dev = json.load(dev_file)

# Flatten every reading of every training verse into parallel
# feature (X) / label (Y) lists for the classifier.
X = [reading[0] for verse in train for reading in verse[2]]
Y = [reading[1] for verse in train for reading in verse[2]]


# Build a random forest with the library's default hyper-parameters and train
# it on the flattened readings in one step (fit() hands back the estimator).
clf = RandomForestClassifier().fit(X, Y)

# Evaluate on the dev set: for each verse, pick the reading the model considers
# most likely to be gold (label 1) and count a hit when that reading really is
# gold. The reported figure is recall@1 (top-1 accuracy) over verses.

# Column of predict_proba() that holds P(label == 1). Looked up from classes_
# rather than hard-coding a column index, which would silently rank by the
# wrong probability if the label set or its order ever differs.
labels = list(clf.classes_)
if 1 not in labels:
    raise ValueError("training data has no reading labelled 1; cannot rank dev readings")
positive_col = labels.index(1)

correct = 0
total = 0

for verse in dev:
    readings = verse[2]
    if not readings:
        # Nothing to rank in an empty reading list; skip it rather than pass
        # an empty batch to predict_proba().
        continue

    probs = clf.predict_proba([reading[0] for reading in readings])
    # Highest P(gold) wins; on ties np.argmax returns the earliest reading.
    best = int(np.argmax(probs[:, positive_col]))

    if readings[best][1] == 1:
        correct += 1
    total += 1

# Report NaN instead of crashing with ZeroDivisionError on an empty dev set.
ratio = correct / total if total else float('nan')
print("Recall: {}/{} ({})".format(correct, total, ratio))

# NOTE(review): sklearn.externals.joblib was removed in scikit-learn 0.23;
# newer versions need the standalone `joblib` package imported instead.
joblib.dump(clf, 'forest_classifier.joblib')

#dot_data = tree.export_graphviz(clf, out_file=None) # doctest: +SKIP
#graph = graphviz.Source(dot_data) # doctest: +SKIP
#graph.render("latin_tree") # doctest: +SKIP