import argparse
import json

from model.Hier_BiLSTM_CRF import *
from prepare_data import *
from train import *


def main():
    parser = argparse.ArgumentParser()

    # Use store_true rather than type=bool: argparse's bool() conversion treats any
    # non-empty string (including "False") as True, so a flag is the correct idiom.
    parser.add_argument('--pretrained', action='store_true', help='Whether the model uses pretrained sentence embeddings or not')
    parser.add_argument('--data_path', default='data/text/', type=str, help='Folder containing the annotated text files')
    parser.add_argument('--save_path', default='saved/', type=str, help='Folder where predictions and models will be saved')
    parser.add_argument('--cat_path', default='categories.txt', type=str, help='Path to the file containing category details')
    parser.add_argument('--dataset_size', default=50, type=int, help='Total no. of docs')
    parser.add_argument('--num_folds', default=5, type=int, help='No. of folds to divide the dataset into')
    parser.add_argument('--device', default='cuda', type=str, help='cuda / cpu')
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--print_every', default=10, type=int, help='Epoch interval after which validation macro F1 and loss are printed')
    parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
    parser.add_argument('--reg', default=0, type=float, help='L2 regularization coefficient')
    parser.add_argument('--emb_dim', default=200, type=int, help='Sentence embedding dimension')
    parser.add_argument('--word_emb_dim', default=100, type=int, help='Word embedding dimension, applicable only if --pretrained is not set')
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--val_fold', default='cross', type=str, help='Fold number to be used as validation; use "cross" for num_folds-fold cross-validation')

    args = parser.parse_args()

    # Build the fold split and numericalize the corpus
    print('\nPreparing data ...', end=' ')
    idx_order = prepare_folds(args)
    x, y, word2idx, tag2idx = prepare_data(idx_order, args)
    print('Done')

    print('Vocabulary size:', len(word2idx))
    print('#Tags:', len(tag2idx))

    # Dump word2idx and tag2idx so the same mappings can be reused at inference time
    with open(args.save_path + 'word2idx.json', 'w') as fp:
        json.dump(word2idx, fp)
    with open(args.save_path + 'tag2idx.json', 'w') as fp:
        json.dump(tag2idx, fp)

    if args.val_fold == 'cross':
        # Full cross-validation: train a fresh model with each fold held out once
        print('\nCross-validation\n')
        for f in range(args.num_folds):
            print('\nInitializing model ...', end=' ')
            model = Hier_LSTM_CRF_Classifier(len(tag2idx), args.emb_dim, tag2idx['<start>'], tag2idx['<end>'], tag2idx['<pad>'],
                                             vocab_size=len(word2idx), word_emb_dim=args.word_emb_dim,
                                             pretrained=args.pretrained, device=args.device).to(args.device)
            print('Done')

            print('\nEvaluating on fold', f, '...')
            learn(model, x, y, tag2idx, f, args)
    else:
        # Single run with the specified fold held out for validation
        print('\nInitializing model ...', end=' ')
        model = Hier_LSTM_CRF_Classifier(len(tag2idx), args.emb_dim, tag2idx['<start>'], tag2idx['<end>'], tag2idx['<pad>'],
                                         vocab_size=len(word2idx), word_emb_dim=args.word_emb_dim,
                                         pretrained=args.pretrained, device=args.device).to(args.device)
        print('Done')

        print('\nEvaluating on fold', args.val_fold, '...')
        learn(model, x, y, tag2idx, int(args.val_fold), args)


if __name__ == '__main__':
    main()
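
# Example invocations (a minimal sketch; assumes this script is saved as main.py --
# the filename is not given in the source, and the paths shown are just the defaults):
#
#   # run full num_folds cross-validation on GPU
#   python main.py --data_path data/text/ --save_path saved/ --device cuda --val_fold cross
#
#   # evaluate with a single fold (e.g. fold 3) held out for validation
#   python main.py --val_fold 3
#
#   # use pretrained sentence embeddings instead of learning word embeddings
#   python main.py --pretrained --emb_dim 200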