Skip to content
Snippets Groups Projects
Commit 3fd00781 authored by blunck's avatar blunck
Browse files

F1 score output & adaptation to renamed pos_feature

parent bfff9805
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import numpy as np
from sklearn import svm
from sklearn import tree
from sklearn.model_selection import cross_val_score
import postagger
import pos_feature
def create_vector(corpus_instance, vocabulary=None, pos_vocabulary=None):
......@@ -17,12 +17,9 @@ def create_vector(corpus_instance, vocabulary=None, pos_vocabulary=None):
Example for corpus instance: OrderedDict([('LABEL', '0'), ('FILENAME', '36_19_RPRRQDRSHDV6J.txt'), ('STARS', '5.0'), ('TITLE', etc.
"""
f1 = ngram_feature.extract(corpus_instance, vocabulary)
f2 = postagger.extract(corpus_instance, pos_vocabulary)
f2 = pos_feature.extract(corpus_instance, pos_vocabulary)
f4 = sent_rating_feature.extract(corpus_instance)
print(f2)
print(len(f2))
return np.concatenate((f1, f2, f4))
......@@ -44,8 +41,7 @@ if __name__ == '__main__':
bigram_vocab = ngram_feature.get_vocabulary(train_set, 2)
# pos_bags
pos_bigram_vocab = postagger.get_pos_vocabulary(train_set)
#print(pos_bigram_vocab) #already lookin' good
pos_bigram_vocab = pos_feature.get_pos_vocabulary(train_set)
# inputs:
train_inputs = [create_vector(el, unigram_vocab, pos_bigram_vocab)
......@@ -75,12 +71,15 @@ if __name__ == '__main__':
train_multiple([svm_clf, tree_clf], train_inputs, train_labels)
# validation
svm_score = cross_val_score(svm_clf, train_inputs, train_labels, cv=5).mean()#, scoring='f1')
tree_score = cross_val_score(tree_clf, train_inputs, train_labels, cv=5).mean()#, scoring='f1')
svm_acc = cross_val_score(svm_clf, train_inputs, train_labels, cv=5, scoring='accuracy').mean()
tree_acc = cross_val_score(tree_clf, train_inputs, train_labels, cv=5, scoring='accuracy').mean()
svm_f1 = cross_val_score(svm_clf, train_inputs, train_labels, cv=5, scoring='f1').mean()
tree_f1 = cross_val_score(tree_clf, train_inputs, train_labels, cv=5, scoring='f1').mean()
print("\n--Cross Validation Scores-- ")
print("\nSVM: {}".format(svm_score))
print("\nTree: {}".format(tree_score))
print("\nSVM: Accuracy: {}, F1-Score: {}".format(svm_acc, svm_f1))
print("\nTree: Accuracy: {}, F1-Score: {}".format(tree_acc, tree_f1))
# testing
# print("\nSVM: Score on test Data:")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment