Skip to content
Snippets Groups Projects
Commit 1ec85395 authored by Lili Gasser's avatar Lili Gasser
Browse files

WIP: create file with tested training data

parent 67da49e8
No related branches found
No related tags found
No related merge requests found
The source diff could not be displayed because it is too large. To view it, open the blob instead.
......@@ -18,7 +18,7 @@ import spacy
from spacy.util import minibatch, compounding
import sys
sys.path.append("./src/python")
from utils_ner import read_from_txt, transform_to_training_format
from utils_ner import read_from_txt, write_to_txt, transform_to_training_format, transform_to_reading_format
......@@ -34,11 +34,12 @@ def main(model=None, output_dir=None, n_iter=100, train_data=None, print_output=
if train_data is not None:
dict_onedoc = read_from_txt(train_data)
TRAIN_DATA = transform_to_training_format(dict_onedoc)[:50] # TODO: get rid of [:50]
TRAIN_DATA_orig = TRAIN_DATA
print(TRAIN_DATA[:10])
# TODO: format checks
else:
sys.exit("no training data")
if model is not None:
......@@ -58,7 +59,7 @@ def main(model=None, output_dir=None, n_iter=100, train_data=None, print_output=
ner = nlp.get_pipe("ner")
# add labels
for _, annotations in TRAIN_DATA:
for _, annotations, _ in TRAIN_DATA:
for ent in annotations.get("entities"):
ner.add_label(ent[2])
......@@ -75,7 +76,7 @@ def main(model=None, output_dir=None, n_iter=100, train_data=None, print_output=
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
texts, annotations, _ = zip(*batch)
nlp.update(
texts, # batch of texts
annotations, # batch of annotations
......@@ -85,19 +86,30 @@ def main(model=None, output_dir=None, n_iter=100, train_data=None, print_output=
print("Losses", losses)
# test the trained model
for text, dict_ents_train in TRAIN_DATA:
TRAIN_DATA_tested = []
for text, dict_ents_train, title in TRAIN_DATA_orig:
print(title)
list_ents_train = dict_ents_train['entities']
doc = nlp(text)
list_ents_test = [(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents]
list_ents_test = [(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents]
dict_ents_test = {}
dict_ents_test['entities'] = list_ents_test
tpl = (text, dict_ents_test, title)
TRAIN_DATA_tested.append(tpl)
# print('train', list_ents_train)
# print('test', list_ents_test)
# print(set(list_ents_train) == set(list_ents_test))
# if print_output
if not set(list_ents_train) == set(list_ents_test):
print(text)
print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
#if not set(list_ents_train) == set(list_ents_test):
#print(text)
#print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
#print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])
alldicts_tested = transform_to_reading_format(TRAIN_DATA_tested)
filename_tested = str(train_data)[:-4] + '_tested.txt'
write_to_txt(alldicts_tested, filename_tested)
# save model to output directory
if output_dir is not None:
output_dir = Path(output_dir)
......@@ -109,7 +121,7 @@ def main(model=None, output_dir=None, n_iter=100, train_data=None, print_output=
# test the saved model
print("Loading from", output_dir)
nlp2 = spacy.load(output_dir)
for text, _ in TRAIN_DATA:
for text, _, _ in TRAIN_DATA:
doc = nlp2(text)
if print_output:
print(text)
......
#!/usr/bin/env python3
import datetime
from spacy import displacy
colors = {'ORG': '#73c6b6', 'LOC': '#bb8fce', 'PER': '#e59866', 'MISC': '#a6acaf'}
options = {'ents': ['ORG', 'LOC', 'PER', 'MISC'], 'colors': colors}
......@@ -130,11 +130,48 @@ def transform_to_training_format(alldicts):
ents_as_list = somedict['ents']
ents_in_dict = {}
ents_in_dict['entities'] = get_entitities_in_training_format(ents_as_list)
tpl = (text, ents_in_dict)
title = somedict['title']
tpl = (text, ents_in_dict, title)
train_data.append(tpl)
return train_data
def transform_to_reading_format(train_data):
def get_entitities_in_reading_format(dict_ents):
list_ents_read = []
list_ents_train = dict_ents['entities']
for tpl_ent in list_ents_train:
dict_ent = {}
dict_ent['start'] = tpl_ent[0]
dict_ent['end'] = tpl_ent[1]
label = tpl_ent[2]
if label == 'PER':
label = 'PERSON'
if label == 'ORG':
label = 'ORGANIZATION'
if label == 'LOC':
label = 'LOCATION'
dict_ent['label'] = label
list_ents_read.append(dict_ent)
return list_ents_read
alldicts = {}
for text, dict_ents, title in train_data:
key = title.split('/')[-1]
somedict = {}
somedict['text'] = text
somedict['ents'] = get_entitities_in_reading_format(dict_ents)
somedict['title'] = title
alldicts[key] = somedict
return alldicts
def get_language(filepath):
language = 'french' if filepath.endswith('_french.txt') else 'german'
return language
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.