Add encoding cache and lazy-train mechanism (#50)
* Add new config for knowledge distillation for query binary classifier

* remove inference results in knowledge distillation for query binary classifier

* Add AUC.py in tools folder

* Add test_data_path into conf_kdqbc_bilstmattn_cnn.json

* Modify AUC.py

* Rename AUC.py into calculate_AUC.py

* Modify test & calculate-AUC commands for Knowledge Distillation for Query Binary Classifier

* Add cpu_thread_num parameter in conf.training_params

* Rename cpu_thread_num into cpu_num_workers

* update comments in ModelConf.py

* Add cpu_num_workers in model_zoo/advanced/conf.json

* Add the description of cpu_num_workers in Tutorial.md

* Update inference speed of compressed model

* Add ProcessorsScheduler Class

* Add license in ProcessorScheduler.py

* use lazy loading instead of one-off loading

* Remove Debug Info in problem.py

* use open instead of codecs.open

* update the build-dictionary inference for classification

* add md5 function in common_utils.py

* add merge_encode_* function

* fix typo

* fix typo

* reorganize the logical flow in train.py

* remove dummy comments in problem.py

* add encoding cache mechanism

* add lazy-load mechanism for training phase

* enumerate problem types in problem.py

* remove data_encoding.py

* add lazy-load training logic

* Modify comment and remove debug code

* Check whether test_path exists

* fix missing parameter when using char embedding

* merge master

* add file_column_num in problem.py

* merge add_encoding_cache branch

* add SST-2 in .gitignore

* merge master

* use steps_per_validation instead of valid_times_per_epoch

* Fix Learning Rate decay logic bug

* add log of calculating md5 of training data

* fix multi-gpu char_emb OOM problem & add char-level fix_lengths

* Modify batch_num_to_show_results in multi-gpu

* Modify batch_num_to_show_results

* delete deepcopy in get_batches

* add new parameters chunk_size and max_building_lines in conf and update tutorials
chengfx authored and ljshou committed Aug 2, 2019
1 parent db26940 commit 58ad563
Showing 15 changed files with 590 additions and 362 deletions.
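Taken together, these commits replace one-off data loading with chunked, lazy loading plus an encoding cache keyed by the MD5 of the training data. Below is a minimal sketch of the chunked-reading side of that idea; `read_in_chunks` and `encode_chunk` are illustrative names, not the actual NeuronBlocks API.

```python
# Minimal sketch (illustrative only): read a file chunk_size lines at a time
# instead of loading everything up front, so encoding can proceed lazily.
def read_in_chunks(path, chunk_size=1000 * 1000):
    chunk = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            chunk.append(line)
            if len(chunk) == chunk_size:
                yield chunk
                chunk = []
    if chunk:          # final partial chunk
        yield chunk

# Hypothetical usage: encode and cache one chunk at a time.
# for chunk in read_in_chunks('train.tsv', chunk_size=100000):
#     encoded = encode_chunk(chunk)   # stand-in for the real encoding step
```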
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@
*.vs*
dataset/GloVe/
dataset/20_newsgroups/
models/
dataset/SST-2/
models/
438 changes: 220 additions & 218 deletions LearningMachine.py

Large diffs are not rendered by default.

35 changes: 32 additions & 3 deletions ModelConf.py
@@ -14,8 +14,8 @@

from losses.BaseLossConf import BaseLossConf
#import traceback
from settings import LanguageTypes, ProblemTypes, TaggingSchemes, SupportedMetrics, PredictionTypes, DefaultPredictionFields
from utils.common_utils import log_set, prepare_dir
from settings import LanguageTypes, ProblemTypes, TaggingSchemes, SupportedMetrics, PredictionTypes, DefaultPredictionFields, ConstantStatic
from utils.common_utils import log_set, prepare_dir, md5
from utils.exceptions import ConfigurationError
import numpy as np

@@ -219,6 +219,10 @@ def load_from_file(self, conf_path):
# vocabulary setting
self.max_vocabulary = self.get_item(['training_params', 'vocabulary', 'max_vocabulary'], default=800000, use_default=True)
self.min_word_frequency = self.get_item(['training_params', 'vocabulary', 'min_word_frequency'], default=3, use_default=True)
self.max_building_lines = self.get_item(['training_params', 'vocabulary', 'max_building_lines'], default=1000 * 1000, use_default=True)

# chunk_size
self.chunk_size = self.get_item(['training_params', 'chunk_size'], default=1000 * 1000, use_default=True)

# file column header setting
self.file_with_col_header = self.get_item(['inputs', 'file_with_col_header'], default=False, use_default=True)
@@ -280,6 +284,9 @@ def load_from_file(self, conf_path):
tmp_problem_path = os.path.join(self.save_base_dir, '.necessary_cache', 'problem.pkl')
self.problem_path = tmp_problem_path if os.path.isfile(tmp_problem_path) else os.path.join(self.save_base_dir, 'necessary_cache', 'problem.pkl')

# cache configuration
self._load_cache_config_from_conf()

# training params
self.training_params = self.get_item(['training_params'])

@@ -303,7 +310,9 @@ def load_from_file(self, conf_path):
self.max_epoch = self.params.max_epoch
else:
self.max_epoch = self.get_item(['training_params', 'max_epoch'], default=float('inf'))
self.valid_times_per_epoch = self.get_item(['training_params', 'valid_times_per_epoch'], default=1)
if 'valid_times_per_epoch' in self.conf['training_params']:
logging.info("configuration[training_params][valid_times_per_epoch] is deprecated, please use configuration[training_params][steps_per_validation] instead")
self.steps_per_validation = self.get_item(['training_params', 'steps_per_validation'], default=10)
self.batch_num_to_show_results = self.get_item(['training_params', 'batch_num_to_show_results'], default=10)
self.max_lengths = self.get_item(['training_params', 'max_lengths'], default=None, use_default=True)
self.fixed_lengths = self.get_item(['training_params', 'fixed_lengths'], default=None, use_default=True)
@@ -529,3 +538,23 @@ def back_up(self, params):
shutil.copy(params.conf_path, self.save_base_dir)
logging.info('Configuration file is backed up to %s' % (self.save_base_dir))

def _load_cache_config_from_conf(self):
# training data
self.train_data_md5 = None
if self.phase == 'train' and self.train_data_path:
logging.info("Calculating the md5 of traing data ...")

self.train_data_md5 = md5([self.train_data_path])
logging.info("the md5 of traing data is %s"%(self.train_data_md5))

# problem
self.problem_md5 = None

# encoding
self.encoding_cache_dir = None
self.encoding_cache_index_file_path = None
self.encoding_cache_index_file_md5_path = None
self.encoding_file_index = None
self.encoding_cache_legal_line_cnt = 0
self.encoding_cache_illegal_line_cnt = 0
self.load_encoding_cache_generator = None
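The `md5` helper used in `_load_cache_config_from_conf` lives in `utils/common_utils.py`, which is not part of this excerpt. A plausible sketch, assuming it hashes the raw bytes of each file path it receives (the real implementation may differ):

```python
# Plausible sketch of utils.common_utils.md5 (assumption: it hashes file
# contents, given a list of paths); shown only to clarify how train_data_md5
# is derived and reused as an encoding-cache key.
import hashlib

def md5(paths):
    digest = hashlib.md5()
    for path in paths:
        with open(path, 'rb') as f:
            for block in iter(lambda: f.read(1 << 20), b''):
                digest.update(block)
    return digest.hexdigest()
```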

6 changes: 5 additions & 1 deletion Tutorial.md
@@ -147,10 +147,12 @@ The architecture of the configuration file is:
CUDA_VISIBLE_DEVICES= python train.py
```
- ***cpu_num_workers***. [default: -1] Define the number of processes used to preprocess the dataset. If the value is negative or 0, the number of processes equals the number of logical CPU cores; otherwise it equals *cpu_num_workers*.
- ***chunk_size***. [default: 1000000] Define how many lines NB reads from a file at a time; chunked reading avoids out-of-memory errors and enables the lazy-loading mechanism.
- ***batch_size***. Define the batch size here. If there are multiple GPUs, *batch_size* is the batch size of each GPU.
- ***batch_num_to_show_results***. [necessary for training] During the training process, show the results every batch_num_to_show_results batches.
- ***max_epoch***. [necessary for training] The maximum number of epochs to train.
- ***valid_times_per_epoch***. [optional for training, default: 1] Define how many times to conduct validation per epoch. Usually, we conduct validation after each epoch, but for a very large corpus, we'd better validate multiple times in case to miss the best state of our model. The default value is 1.
- ~~***valid_times_per_epoch***~~. [**deprecated**] Please use steps_per_validation instead.
- ***steps_per_validation***. [default: 10] Define how many training steps elapse between validations; validation runs every *steps_per_validation* steps.
- ***tokenizer***. [optional] Define tokenizer here. Currently, we support 'nltk' and 'jieba'. By default, 'nltk' for English and 'jieba' for Chinese.
- **architecture**. Define the model architecture. The node is a list of layers (blocks) in block_zoo to represent a model. The supported layers of this toolkit are given in [block_zoo overview](https://microsoft.github.io/NeuronBlocks).
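Since validation is now driven by a global step counter rather than by epoch fractions, the control flow looks roughly like the sketch below; `train_one_batch` and `evaluate` are dummy stand-ins, not the actual train.py / LearningMachine.py code.

```python
# Rough sketch of step-based validation; the stand-in helpers only illustrate
# when validation fires relative to steps_per_validation.
def train_one_batch(batch):
    return 0.0                      # pretend loss

def evaluate():
    return {"accuracy": 0.0}        # pretend metrics

max_epoch, steps_per_validation = 3, 10
train_batches = range(25)           # pretend 25 batches per epoch
global_step = 0

for epoch in range(max_epoch):
    for batch in train_batches:
        loss = train_one_batch(batch)
        global_step += 1
        if global_step % steps_per_validation == 0:
            print("step %d, validation: %s" % (global_step, evaluate()))
```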
@@ -729,5 +731,7 @@ To solve the above problems, NeuronBlocks supports *fixing embedding weight* (em
***training_params/vocabulary/max_vocabulary***. [int, optional for training, default: 800,000] The maximum size of the corpus vocabulary. If the corpus vocabulary is larger than *max_vocabulary*, it is truncated according to word frequency.
***training_params/vocabulary/max_building_lines***. [int, optional for training, default: 1,000,000] The maximum number of lines NB reads from each file when building the vocabulary.
## <span id="faq">Frequently Asked Questions</span>
6 changes: 5 additions & 1 deletion Tutorial_zh_CN.md
@@ -137,10 +137,12 @@ python predict.py --conf_path=model_zoo/demo/conf.json
CUDA_VISIBLE_DEVICES= python train.py
```
- ***cpu_num_workers***. [default: -1] Define the number of processes used to preprocess the dataset. If the value is negative or 0, the number of processes equals the number of logical CPU cores; otherwise it equals *cpu_num_workers*.
- ***chunk_size***. [default: 1000000] Define how many lines NB reads from a file at a time; chunked reading avoids out-of-memory errors and enables the lazy-loading mechanism.
- ***batch_size***. Define the batch size here. If there are multiple GPUs, *batch_size* is the batch size of each GPU.
- ***batch_num_to_show_results***. [necessary for training] During the training process, show the results every batch_num_to_show_results batches.
- ***max_epoch***. [necessary for training] The maximum number of epochs to train.
- ***valid_times_per_epoch***. [optional for training, default: 1] Define how many times to conduct validation per epoch. Usually, we conduct validation after each epoch, but for a very large corpus, we'd better validate multiple times in case to miss the best state of our model. The default value is 1.
- ~~***valid_times_per_epoch***~~. [**deprecated**] Please use steps_per_validation instead.
- ***steps_per_validation***. [default: 10] Define how many training steps elapse between validations; validation runs every *steps_per_validation* steps.
- ***tokenizer***. [optional] Define tokenizer here. Currently, we support 'nltk' and 'jieba'. By default, 'nltk' for English and 'jieba' for Chinese.
- **architecture**. Define the model architecture. The node is a list of layers (blocks) in block_zoo to represent a model. The supported layers of this toolkit are given in [block_zoo overview](https://microsoft.github.io/NeuronBlocks).
@@ -719,4 +721,6 @@ To solve the above problems, NeuronBlocks supports *fixing embedding weight* (em
***training_params/vocabulary/max_vocabulary***. [int, optional for training, default: 800,000] The maximum size of the corpus vocabulary. If the corpus vocabulary is larger than *max_vocabulary*, it is truncated according to word frequency.
***training_params/vocabulary/max_building_lines***. [int, optional for training, default: 1,000,000] The maximum number of lines NB reads from each file when building the vocabulary.
## <span id="faq">常见问题与答案</span>
21 changes: 12 additions & 9 deletions block_zoo/Embedding.py
@@ -66,7 +66,10 @@ def inference(self):
for emb_type in self.conf:
if emb_type == 'position':
continue
self.output_dim[2] += self.conf[emb_type]['dim']
if isinstance(self.conf[emb_type]['dim'], list):
self.output_dim[2] += sum(self.conf[emb_type]['dim'])
else:
self.output_dim[2] += self.conf[emb_type]['dim']

super(EmbeddingConf, self).inference()

@@ -113,6 +116,7 @@ def __init__(self, layer_conf):
self.layer_conf = layer_conf

self.embeddings = nn.ModuleDict() if layer_conf.weight_on_gpu else dict()
self.char_embeddings = nn.ModuleDict()
for input_cluster in layer_conf.conf:
if 'type' in layer_conf.conf[input_cluster]:
# char embedding
@@ -122,7 +126,7 @@
char_emb_conf = eval(layer_conf.conf[input_cluster]['type'] + "Conf")(** char_emb_conf_dict)
char_emb_conf.inference()
char_emb_conf.verify()
self.embeddings[input_cluster] = eval(layer_conf.conf[input_cluster]['type'])(char_emb_conf)
self.char_embeddings[input_cluster] = eval(layer_conf.conf[input_cluster]['type'])(char_emb_conf)
else:
# word embedding, postag embedding, and so on
self.embeddings[input_cluster] = nn.Embedding(layer_conf.conf[input_cluster]['vocab_size'], layer_conf.conf[input_cluster]['dim'], padding_idx=0)
@@ -155,14 +159,13 @@ def forward(self, inputs, use_gpu=False):
if 'extra' in input_cluster:
continue
input = inputs[input_cluster]
# if 'type' in self.layer_conf.conf[input_cluster]:
# emb = self.embeddings[input_cluster](input, lengths[input]).float()
# else:
# emb = self.embeddings[input_cluster](input).float()
if list(self.embeddings[input_cluster].parameters())[0].device.type == 'cpu':
emb = self.embeddings[input_cluster](input.cpu()).float()
if input_cluster == 'char':
emb = self.char_embeddings[input_cluster](input).float()
else:
emb = self.embeddings[input_cluster](input).float()
if list(self.embeddings[input_cluster].parameters())[0].device.type == 'cpu':
emb = self.embeddings[input_cluster](input.cpu()).float()
else:
emb = self.embeddings[input_cluster](input).float()
if use_gpu is True:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
emb = emb.to(device)
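For reference, the `inference()` change above lets a single sub-embedding declare a list of dims (the multi-kernel char CNN); the output dimension is then the sum over all sub-embeddings. A small illustration with made-up numbers:

```python
# Made-up configuration to show the dimension bookkeeping in EmbeddingConf.inference().
conf = {
    'word': {'dim': 300},
    'char': {'dim': [30, 20, 100]},   # multi-kernel CNN char embedding
}
output_dim = 0
for emb_type, sub_conf in conf.items():
    if emb_type == 'position':
        continue
    dim = sub_conf['dim']
    output_dim += sum(dim) if isinstance(dim, list) else dim
print(output_dim)   # 450
```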
7 changes: 5 additions & 2 deletions block_zoo/Pooling2D.py
@@ -29,7 +29,7 @@ def default(self):
self.pool_type = 'max' # Supported: ['max', 'mean']
self.stride = 1
self.padding = 0
self.window_size = 3
# self.window_size = [self.input_dims[0][1], self.input_dims[0][2]]

@DocInherit
def declare(self):
@@ -38,7 +38,7 @@ def declare(self):

def check_size(self, value, attr):
res = value
if isinstance(value,int):
if isinstance(value, int):
res = [value, value]
elif (isinstance(self.window_size, tuple) or isinstance(self.window_size, list)) and len(value)==2:
res = list(value)
@@ -48,6 +48,9 @@

@DocInherit
def inference(self):

if not hasattr(self, "window_size"):
self.window_size = [self.input_dims[0][1], self.input_dims[0][2]]

self.window_size = self.check_size(self.window_size, "window_size")
self.stride = self.check_size(self.stride, "stride")
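With the hard-coded `window_size = 3` default removed, `inference()` now falls back to a window spanning the whole incoming feature map when the config omits `window_size`. A tiny illustration with made-up shapes (assuming `input_dims[0]` carries `[batch, height, width]`):

```python
# Made-up shapes to show the new window_size fallback in Pooling2D.inference():
# when window_size is absent, pooling covers the entire 2-D feature map.
input_dims = [[32, 30, 128]]                         # hypothetical [batch, height, width]
window_size = [input_dims[0][1], input_dims[0][2]]   # -> [30, 128]
print(window_size)
```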
52 changes: 36 additions & 16 deletions block_zoo/embedding/CNNCharEmbedding.py
@@ -28,11 +28,11 @@ def __init__(self, **kwargs):

@DocInherit
def default(self):
self.dim = 30 # cnn's output channel dim
self.dim = [30] # cnn's output channel dim
self.embedding_matrix_dim = 30 #
self.stride = 1
self.stride = [1]
self.padding = 0
self.window_size = 3
self.window_size = [3]
self.activation = 'ReLU'

@DocInherit
@@ -41,8 +41,14 @@ def declare(self):
self.num_of_inputs = 1
self.input_ranks = [3]

def change_to_list(self, attribute):
for single in attribute:
if not isinstance(getattr(self, single), list):
setattr(self, single, [getattr(self, single)])

@DocInherit
def inference(self):
self.change_to_list(['dim', 'stride', 'window_size'])
self.output_channel_num = self.dim
self.output_rank = 3

@@ -65,20 +71,24 @@ def __init__(self, layer_conf):
super(CNNCharEmbedding, self).__init__(layer_conf)
self.layer_conf = layer_conf

assert len(layer_conf.dim) == len(layer_conf.window_size) == len(layer_conf.stride), "The attribute dim/window_size/stride must have the same length."

self.char_embeddings = nn.Embedding(layer_conf.vocab_size, layer_conf.embedding_matrix_dim, padding_idx=self.layer_conf.padding)
nn.init.uniform_(self.char_embeddings.weight, -0.001, 0.001)

self.char_cnn = nn.Conv2d(1, layer_conf.output_channel_num, (layer_conf.window_size, layer_conf.embedding_matrix_dim),
stride=self.layer_conf.stride, padding=self.layer_conf.padding)
self.char_cnn = nn.ModuleList()
for i in range(len(layer_conf.output_channel_num)):
self.char_cnn.append(nn.Conv2d(1, layer_conf.output_channel_num[i], (layer_conf.window_size[i], layer_conf.embedding_matrix_dim),
stride=self.layer_conf.stride[i], padding=self.layer_conf.padding))
if layer_conf.activation:
self.activation = eval("nn." + self.layer_conf.activation)()
else:
self.activation = None
if self.is_cuda():
self.char_embeddings = self.char_embeddings.cuda()
self.char_cnn = self.char_cnn.cuda()
if self.activation and hasattr(self.activation, 'weight'):
self.activation.weight = torch.nn.Parameter(self.activation.weight.cuda())
# if self.is_cuda():
# self.char_embeddings = self.char_embeddings.cuda()
# self.char_cnn = self.char_cnn.cuda()
# if self.activation and hasattr(self.activation, 'weight'):
# self.activation.weight = torch.nn.Parameter(self.activation.weight.cuda())

def forward(self, string):
"""
@@ -102,14 +112,24 @@ def forward(self, string):
char_embs_lookup = char_embs_lookup.view(-1, string.size()[2], self.layer_conf.embedding_matrix_dim) #[batch_size * seq_len, char num in words, embedding_dim]

string_input = torch.unsqueeze(char_embs_lookup, 1) # [batch_size * seq_len, input_channel_num=1, char num in words, embedding_dim]
string_conv = self.char_cnn(string_input).squeeze()
if self.activation:
string_conv = self.activation(string_conv)

string_maxpooling = F.max_pool1d(string_conv, string_conv.size(2)).squeeze()
string_out = string_maxpooling.view(string.size()[0], -1, self.layer_conf.output_channel_num)
outputs = []
for index, single_cnn in enumerate(self.char_cnn):
string_conv = single_cnn(string_input).squeeze(3)
if self.activation:
string_conv = self.activation(string_conv)

string_maxpooling = F.max_pool1d(string_conv, string_conv.size(2)).squeeze()
string_out = string_maxpooling.view(string.size()[0], -1, self.layer_conf.output_channel_num[index])

outputs.append(string_out)

if len(outputs) > 1:
string_output = torch.cat(outputs, 2)
else:
string_output = outputs[0]

return string_out
return string_output


if __name__ == '__main__':
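A self-contained sketch of what the multi-kernel path above computes, with made-up shapes; the real layer also handles padding, activation choice, and device placement:

```python
# Made-up shapes: 8 words (batch_size * seq_len), 16 chars per word, 50-dim char embeddings.
import torch
import torch.nn as nn
import torch.nn.functional as F

n_words, char_num, emb_dim = 8, 16, 50
dims, windows, strides = [30, 20, 100], [3, 3, 5], [1, 2, 3]

x = torch.randn(n_words, 1, char_num, emb_dim)            # [N, in_channels=1, chars, emb_dim]
outputs = []
for dim, window, stride in zip(dims, windows, strides):
    conv = nn.Conv2d(1, dim, (window, emb_dim), stride=stride)
    feat = conv(x).squeeze(3)                              # [N, dim, L]
    feat = F.max_pool1d(feat, feat.size(2)).squeeze(2)     # max over char positions -> [N, dim]
    outputs.append(feat)

char_repr = torch.cat(outputs, dim=1)                      # [N, sum(dims)] == [8, 150]
print(char_repr.shape)
```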
14 changes: 8 additions & 6 deletions model_zoo/advanced/conf.json
@@ -49,14 +49,16 @@
"training_params": {
"vocabulary": {
"min_word_frequency": 1,
"max_vocabulary": 100000
"max_vocabulary": 100000,
"max_building_lines": 1000000
},
"optimizer": {
"name": "Adam",
"params": {
"lr": 0.001
}
},
"chunk_size": 1000000,
"lr_decay": 0.95,
"minimum_lr": 0.0001,
"epoch_start_lr_decay": 1,
@@ -65,7 +67,7 @@
"batch_size": 30,
"batch_num_to_show_results": 10,
"max_epoch": 3,
"valid_times_per_epoch": 1,
"steps_per_validation": 10,
"text_preprocessing": ["DBC2SBC"],
"max_lengths":{
"question": 30,
@@ -90,10 +92,10 @@
"cols": ["question_char", "answer_char"],
"type": "CNNCharEmbedding",
"dropout": 0.2,
"dim": 30,
"embedding_matrix_dim": 8,
"stride":1,
"window_size": 5,
"dim": [30, 20, 100],
"embedding_matrix_dim": 50,
"stride":[1, 2, 3],
"window_size": [3,3,5],
"activation": "ReLU"
}
}
@@ -53,7 +53,7 @@
"batch_size": 256,
"batch_num_to_show_results": 10,
"max_epoch": 30,
"valid_times_per_epoch": 10,
"steps_per_validation": 10,
"fixed_lengths":{
"query": 30
}