Skip to content
Snippets Groups Projects
Commit 0ac08cf4 authored by opitz
Browse files

fixes

parent cadc85dc
No related branches found
No related tags found
No related merge requests found
import numpy as np
import re
from operator import itemgetter
import bs4
......@@ -108,7 +109,8 @@ class HasDescriptionNode(dict):
vecs=[]
ncs = []
doc = nlp(self.d["text"])
if not list(doc.noun_chunks):
doc = nlp("This is a dummy document.")
def not_valid(n,d):
if n.text in ["thence","they","them","her","him","it"]:
return True
......@@ -132,6 +134,11 @@ class HasDescriptionNode(dict):
continue
vecs.append(nc.vector/nc.vector_norm)
ncs.append(nc)
#print(vecs[0].shape)
"""
if not ncs:
return [np.zeros(96)],["empty"]
"""
"""
for nc in [t for t in doc if t.pos_ == "VERB"]:
vecs.append(nc.vector/nc.vector_norm)
......@@ -146,6 +153,8 @@ class HasDescriptionNode(dict):
"""
vecs,ncs = self.get_noun_chunk_vectors()
newcopy = HasDescriptionNode(None,"None")
if not classifier:
return newcopy
if not ncs:
return newcopy
maxlen=max([len(x) for x in ncs])
......
......@@ -37,7 +37,7 @@ def simplify_text_description_nodes(G,node_index_dict,mode="None",min_freq=1):
trialnodes=[n for n in G.nodes(data=True) if isinstance(n[1]["nodeobj"],dh.TrialNode)]
descr_nodes=[]
mask=[]
related_cat2=[]
# we iterate over all trials
for i,tn in enumerate(trialnodes):
#get corresponding cat node
......@@ -58,6 +58,7 @@ def simplify_text_description_nodes(G,node_index_dict,mode="None",min_freq=1):
descr_nodes.append(G.nodes[nb]["nodeobj"])
tid=(tn[0],nb)
Xid.append(tid)
related_cat2.append(category)
for dv in descr_vectors:
#put noun chunk vector into training data
Xvector.append(dv)
......@@ -67,8 +68,7 @@ def simplify_text_description_nodes(G,node_index_dict,mode="None",min_freq=1):
if mode=="classifier":
#fit a classifier to learn a mapping between noun chunks and labels
clf=LogisticRegression()
clf.fit(Xvector,related_cat)
clf.fit(Xvector,related_cat)
# now we can remove the text description nodes and insert their simplified forms
for i,idx in enumerate(Xid):
......@@ -78,7 +78,11 @@ def simplify_text_description_nodes(G,node_index_dict,mode="None",min_freq=1):
#node_index_dict.pop(descr_nodes[i])
if mode == "classifier":
simp_descr=descr_nodes[i].simplify(clf,list(clf.classes_).index(related_cat[i]))
#print(clf.classes_,related_cat2[i],"damage" in related_cat,"damage" in related_cat2)
if related_cat2[i] in clf.classes_:
simp_descr=descr_nodes[i].simplify(clf,list(clf.classes_).index(related_cat2[i]))
else:
simp_descr=descr_nodes[i].simplify(None,None)
elif mode == "spacy_direct_object":
simp_descr=descr_nodes[i].simplify_to_direct_object()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment