This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[LLM Runtime] Control printing information using NEURAL_SPEED_VERBOSE #1054

Closed · wants to merge 11 commits
2 changes: 1 addition & 1 deletion .github/workflows/script/models/cpp_graph_inference.sh
@@ -153,7 +153,7 @@ function main() {
## run inference
export LANG=en_US.UTF-8
export LC_ALL=en_US.UTF-8
-OMP_NUM_THREADS=$(($cores_per_instance * 1)) numactl -m 0 -C 0-$(($cores_per_instance * 1 - 1)) \
+NEURAL_SPEED_VERBOSE=1 OMP_NUM_THREADS=$(($cores_per_instance * 1)) numactl -m 0 -C 0-$(($cores_per_instance * 1 - 1)) \
$infer_cmd --seed 1234 -t $cores_per_instance -b 2047 -c ${ctx} -n ${output} -m ${model}-${precision}.bin -p "$prompt" 2>&1 | tee ${WORKING_DIR}/${logs_file} || true &
minitor

@@ -77,10 +77,7 @@ if(NE_BUILD_TESTS)
add_compile_definitions(NE_BUILD_TESTS)
endif()

-option(NE_PROFILING "neural_engine: use Profiling" OFF)
-if (NE_PROFILING)
-add_compile_definitions(NE_PERF)
-endif()
+add_compile_definitions(NE_PERF)
option(NE_BEAM_SEARCH_VERBOSE "neural_engine: print beam search processing log" OFF)
if (NE_BEAM_SEARCH_VERBOSE)
add_compile_definitions(NE_BEAM_SEARCH_VERBOSE_ON)
10 changes: 10 additions & 0 deletions intel_extension_for_transformers/llm/runtime/graph/README.md
@@ -586,3 +586,13 @@ outputs = model.generate(inputs, streamer=streamer, stopping_criteria=stopping_c

### 6. Perplexity (measuring model quality)
You can use the [scripts/perplexity.py](./scripts/perplexity.py) script to compute perplexity over a given (subset of) dataset. Run `python scripts/perplexity.py --help` for detailed usage. For more information on the perplexity metric, see https://huggingface.co/docs/transformers/perplexity.


### 7. Verbose Mode

Enable verbose mode and control tracing information using the `NEURAL_SPEED_VERBOSE` environment variable.

Available modes:
- 0: Print all tracing information. The most comprehensive output, including evaluation time and operator profiling.
- 1: Print evaluation time only, i.e. the time taken for each model evaluation.
- 2: Profile individual operators only, to identify performance bottlenecks within the model.
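For context (not part of this diff), here is a minimal sketch of enabling verbose mode from Python with the transformers-like API shown earlier in this README. The model id is only a placeholder, and the quantization settings follow the existing examples.

```python
import os

# Set before the runtime loads; the value is read via getenv at run time.
# "1" prints evaluation time after each generate() call.
os.environ["NEURAL_SPEED_VERBOSE"] = "1"

from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "Intel/neural-chat-7b-v3-1"  # placeholder, use your own model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer("What is the meaning of life?", return_tensors="pt").input_ids

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
outputs = model.generate(inputs, max_new_tokens=32)  # timing report is printed when generation finishes
```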
Comment on lines +596 to +598

Contributor: Why is 0 the most comprehensive level? That is quite rare.

Contributor: It's the same as a log level, where 0 means debug and larger values print less info; not rare.

Contributor: How about 1, 2, 3? With VERBOSE=0 we usually assume verbose output is disabled.

Contributor: I got your point, agreed~ @zhenwei-intel could you help change this so that 0 disables verbose output and 1 prints everything?
15 changes: 11 additions & 4 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -76,6 +76,7 @@ def get_model_type(model_config):
return model_type

def init(self, model_name, use_quant=True, use_gptq=False, **quant_kwargs):
"""initialize cpp model using model name"""
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
self.model_type = Model.get_model_type(self.config)
@@ -127,6 +128,7 @@ def init(self, model_name, use_quant=True, use_gptq=False, **quant_kwargs):
os.remove(fp32_bin)

def init_from_bin(self, model_type, model_path, **generate_kwargs):
"""initialize cpp model from bin file"""
self.__import_package(model_type)
self.model = self.module.Model()
if "threads" not in generate_kwargs:
@@ -138,11 +140,13 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.model.init_model(model_path, **generate_kwargs)

def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
"""quantize model from fp32 bin"""
self.__import_package(model_type)
self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None,
**generate_kwargs):
"""transformer-like generate"""
max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
self.batch_size = input_ids.shape[0]
if self.model is None:
@@ -190,24 +194,26 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
elif (max_new_tokens != -1 and out_count >= max_new_tokens):
break
else:
-all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
+all_done = [(r[-1] in [self.__eos_token_id(), self.__pad_token_id()]) for r in ret]
if False not in all_done:
break
if streamer:
streamer.end()

self.generate_round += 1
+if os.getenv("NEURAL_SPEED_VERBOSE") and os.getenv("NEURAL_SPEED_VERBOSE") in ["1", "0"]:
+self.model.print_time()
return ret

-def is_token_end(self):
+def __is_token_end(self):
return self.model.is_token_end()

-def eos_token_id(self):
+def __eos_token_id(self):
if self.model_type == 'qwen':
return self.tokenizer.special_tokens['<|endoftext|>']
return self.tokenizer.eos_token_id

-def pad_token_id(self):
+def __pad_token_id(self):
if self.tokenizer.pad_token_id == None:
if self.batch_size == 1:
return None
@@ -217,6 +223,7 @@ def pad_token_id(self):
return self.tokenizer.pad_token_id

def __call__(self, input_ids, reinit=False, **kwargs):
"""forward function"""
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, **kwargs)
self.generate_round = 0
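For illustration (not part of the diff), a sketch of driving the `Model` wrapper defined in this `__init__.py` directly, with profiling enabled. The model id is a placeholder; the behavior follows the hunks above, where `generate()` calls `print_time()` when `NEURAL_SPEED_VERBOSE` is "0" or "1".

```python
import os

# "0" enables both evaluation timing and per-operator profiling (see README section 7).
os.environ["NEURAL_SPEED_VERBOSE"] = "0"

from transformers import AutoTokenizer
from intel_extension_for_transformers.llm.runtime.graph import Model

model_name = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
input_ids = tokenizer("What is the meaning of life?", return_tensors="pt").input_ids

model = Model()
model.init(model_name, use_quant=True)  # converts and quantizes; weights load on the first generate()
out = model.generate(input_ids, max_new_tokens=32)
# generate() invokes self.model.print_time() automatically because NEURAL_SPEED_VERBOSE is "0";
# per-operator profiling is printed by the C++ graph code.
```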
@@ -91,6 +91,8 @@ class Model {
generate_count = 0;
}

void print_time() { model_print_timings(ctx); }

static size_t np_jblas_qpack(py::array_t<int8_t> src_w, py::array_t<float> src_scales, py::array_t<int8_t> src_zeros,
py::array_t<int32_t> g_idx, py::array_t<int8_t> dst, const std::string& weight_dtype,
const std::string& alg, int group_size, const std::string& scale_dtype,
@@ -688,5 +690,6 @@ PYBIND11_MODULE(qwen_cpp, m)
.def_static("np_jblas_quantize", &Model::np_jblas_quantize, "Quantize tensor to jblas format", py::arg("src_w"),
py::arg("dst"), py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32,
py::arg("scale_dtype") = "fp32", py::arg("compute_dtype") = "int8", py::arg("threads") = 8)
.def("print_time", &Model::print_time)
.def("reinit", &Model::reinit);
}
@@ -727,8 +727,10 @@ int main(int argc, char** argv) { // NOLINT
fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
model_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
+if (ns_log_level() == 0 || ns_log_level() == 1) {
+model_print_timings(ctx);
+}

-model_print_timings(ctx);
model_free(ctx);

return 0;
@@ -282,12 +282,9 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -260,12 +260,9 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -314,12 +314,9 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_cached;
@@ -322,12 +322,9 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -481,12 +481,9 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_cached;
@@ -363,12 +363,9 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -394,12 +394,9 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_cached;
@@ -53,6 +53,12 @@
#include "models/model_utils/util.h"
#include "models/models.h"

int64_t ns_log_level() {
const char* log_level = getenv("NEURAL_SPEED_VERBOSE");
if (log_level == nullptr) return -1;
return std::stoi(log_level);
}

//
// kv cache
//
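Not part of the change itself, but as a reading aid: the gating that the new `ns_log_level()` helper enables across this PR can be restated in Python roughly as follows (the helper and variable names here are illustrative only).

```python
import os

def ns_log_level() -> int:
    """Python restatement of the C++ helper: -1 when NEURAL_SPEED_VERBOSE is unset."""
    value = os.getenv("NEURAL_SPEED_VERBOSE")
    return -1 if value is None else int(value)

level = ns_log_level()
print_timings = level in (0, 1)      # gate around model_print_timings() in main_run.cpp
profile_operators = level in (0, 2)  # gate around ne_graph_profiling() in each *_model_eval_internal()
```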
@@ -50,6 +50,8 @@
#define MODEL_SESSION_MAGIC MODEL_FILE_MAGIC_GGSN
#define MODEL_SESSION_VERSION 1

int64_t ns_log_level();

void model_load_internal(const std::string& fname, model_archs arch, model_context* ctx, int n_gpu_layers,
bool use_mmap, bool use_mlock, bool vocab_only, model_progress_callback progress_callback,
void* progress_callback_user_data);
@@ -305,12 +305,9 @@ static bool mpt_model_eval_internal(model_context* ctx, const model_input* input
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -322,12 +322,9 @@ static bool opt_model_eval_internal(model_context* ctx, const model_input* input
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -339,12 +339,9 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -393,12 +393,9 @@ static bool starcoder_model_eval_internal(model_context* ctx, const model_input*
ne_build_forward_expand(&gf, inpL);
ne_graph_compute(ctx0, &gf);

-#ifdef NE_PERF
-bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-if (engine_profiling_) {
+if (ns_log_level() == 0 || ns_log_level() == 2) {
ne_graph_profiling(&gf);
}
-#endif

// update kv token count
lctx.model.kv_self.n = n_past + N;
@@ -36,10 +36,10 @@ def cmpData(numa, numb):
args = parser.parse_args()

woq_configs = {
"fp32": WeightOnlyQuantConfig(use_cache=True, use_quant=False),
"ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True, use_ggml=True),
"jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_cache=True),
"jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8", use_cache=True),
"fp32": WeightOnlyQuantConfig(use_quant=False),
"ggml_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4", use_ggml=True),
"jblas_int4": WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4"),
"jblas_int8": WeightOnlyQuantConfig(compute_dtype="bf16", weight_dtype="int8"),
}
prompt = "What is the meaning of life?"
