#!/usr/bin/env python3


from spacy import displacy
# Highlight colours for the short entity labels (ORG / LOC / PER / MISC).
colors = {
    'ORG': '#73c6b6',
    'LOC': '#bb8fce',
    'PER': '#e59866',
    'MISC': '#a6acaf',
}
# displaCy options for the short labels; the entity list follows the
# insertion order of the colour map above.
options = {'ents': list(colors), 'colors': colors}

# Same palette for the long entity labels (ORGANIZATION / LOCATION / PERSON / MISC).
colors_sner = {
    'ORGANIZATION': '#73c6b6',
    'LOCATION': '#bb8fce',
    'PERSON': '#e59866',
    'MISC': '#a6acaf',
}
# displaCy options for the long labels (passed to displacy.render by render_dict).
options_sner = {'ents': list(colors_sner), 'colors': colors_sner}




def get_spacyformat_for_sner_output(text, *, ner_tagger=None):
    """Convert NER tagger output for *text* into displaCy's manual 'ent' format.

    Args:
        text: Either the raw text as a ``str`` or a ``(title, text)`` tuple.
        ner_tagger: Optional object exposing ``get_entities(text)``. When
            omitted, the module-level ``tagger`` is used (it must be defined
            by the time this function is called).

    Returns:
        A dict with the keys ``'text'``, ``'ents'`` and, only when the given
        title is a ``str``, ``'title'``. ``'ents'`` is a list of
        ``{'start': int, 'end': int, 'label': str}`` dicts, one per tagged
        token whose label is not ``'O'``, with character offsets into the text.

    Raises:
        TypeError: If *text* is neither a ``str`` nor a ``tuple``.
    """
    if isinstance(text, tuple):
        title, sometext = text[0], text[1]
    elif isinstance(text, str):
        title, sometext = None, text
    else:
        raise TypeError(
            'text must be a str or a (title, text) tuple, got '
            + type(text).__name__
        )

    # Each item returned by the tagger is indexed as item[0] (token text) and
    # item[1] (label) below.
    active_tagger = tagger if ner_tagger is None else ner_tagger
    tagged_tokens = active_tagger.get_entities(sometext)

    somedict = {'text': sometext}
    if isinstance(title, str):
        somedict['title'] = title

    # Locate every token in the original text, scanning left to right.
    # Searching from the end of the previous match (rather than assuming a
    # single separator character after each token) keeps the offsets correct
    # when tokens are adjacent ("Berlin" + ",") or separated by several
    # whitespace characters.
    list_ents = []
    cursor = 0
    for ent in tagged_tokens:
        token, label = ent[0], ent[1]
        start = sometext.find(token, cursor)
        if start == -1:
            # The token does not occur verbatim in the remaining text (for
            # example because the tagger normalised it); skip it instead of
            # aborting the whole conversion.
            continue
        end = start + len(token)
        cursor = end

        if label != 'O':
            list_ents.append({'start': start, 'end': end, 'label': label})

    somedict['ents'] = list_ents
    return somedict

def write_to_txt(alldicts, filename):
    """Write the string form of every value in *alldicts* to *filename*.

    Each value is written via ``str(value)`` and followed by the separator
    ``'\\n\\n.\\n\\n'``. A final ``'generated: <timestamp>'`` line with the
    current local time is appended after the last record.

    Args:
        alldicts: Mapping whose values are written; the keys are not stored.
        filename: Path of the file to create or overwrite.
    """
    # Imported locally: datetime is not imported at the top of this file
    # (as far as visible), so the original call raised NameError.
    import datetime

    # The context manager closes the file even if a write fails.
    with open(filename, 'w') as f:
        for value in alldicts.values():
            f.write(str(value))
            f.write('\n\n.\n\n')
        f.write('generated: ' + str(datetime.datetime.now()))

def read_from_txt(filename):
    """Load records from a file whose chunks are separated by ``'\\n\\n.\\n\\n'``.

    Every chunk that parses as a Python literal is stored under the last
    ``'/'``-separated component of its ``'title'`` entry. Chunks that do not
    parse (such as a trailing ``'generated: ...'`` line or an empty chunk)
    are skipped.

    Args:
        filename: Path of the file to read.

    Returns:
        dict mapping the last path component of each record's title to the
        parsed record.

    Raises:
        KeyError: If a parsed record has no ``'title'`` entry.
    """
    # Imported locally so this function does not depend on a top-level import.
    import ast

    with open(filename, 'r') as f:
        data = f.read()

    alldicts = {}
    for sent in data.split('\n\n.\n\n'):
        try:
            # NOTE(security): the original used eval(), which executes
            # arbitrary code found in the file. literal_eval only accepts
            # Python literals (dicts, lists, strings, numbers, ...), which is
            # all that str(dict) of plain data produces.
            somedict = ast.literal_eval(sent)
        except (SyntaxError, ValueError):
            # Not a literal: skip the chunk.
            continue
        key = somedict['title'].split('/')[-1]
        alldicts[key] = somedict
    return alldicts

def render_dict(alldicts):
    """Render every value of *alldicts* with displaCy's entity visualizer.

    Each value is passed to ``displacy.render`` in manual mode with Jupyter
    output enabled and the ``options_sner`` label/colour settings. The keys
    of *alldicts* are not used.
    """
    for entry in alldicts.values():
        displacy.render(
            entry,
            style='ent',
            jupyter=True,
            manual=True,
            options=options_sner,
        )