diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index edcd84d63..8d6757d61 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -103,6 +103,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, CTL_INF("model: " << model << ", model_id: " << model_id); // Send request download model to server + CLI_LOG("Validating download items, please wait..") Json::Value json_data; json_data["model"] = model; auto data_str = json_data.toStyledString(); diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index ba1aa5a5c..e77e43beb 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -23,13 +23,13 @@ bool DownloadProgress::Connect(const std::string& host, int port) { bool DownloadProgress::Handle(const std::string& id) { assert(!!ws_); - uint64_t total = std::numeric_limits::max(); + std::unordered_map totals; status_ = DownloadStatus::DownloadStarted; std::unique_ptr> bars; std::vector> items; indicators::show_console_cursor(false); - auto handle_message = [this, &bars, &items, &total, + auto handle_message = [this, &bars, &items, &totals, id](const std::string& message) { CTL_INF(message); @@ -72,23 +72,26 @@ bool DownloadProgress::Handle(const std::string& id) { for (int i = 0; i < ev.download_task_.items.size(); i++) { auto& it = ev.download_task_.items[i]; uint64_t downloaded = it.downloadedBytes.value_or(0); - if (total == 0 || total == std::numeric_limits::max()) { - total = it.bytes.value_or(std::numeric_limits::max()); - CTL_INF("Updated - total: " << total); + if (totals.find(it.id) == totals.end()) { + totals[it.id] = it.bytes.value_or(std::numeric_limits::max()); + CTL_INF("Updated " << it.id << " - total: " << totals[it.id]); } - if (ev.type_ == DownloadStatus::DownloadUpdated) { + + if (ev.type_ == DownloadStatus::DownloadStarted || + ev.type_ == DownloadStatus::DownloadUpdated) { (*bars)[i].set_option(indicators::option::PrefixText{ pad_string(it.id) + - std::to_string(int(static_cast(downloaded) / total * 100)) + + std::to_string( + int(static_cast(downloaded) / totals[it.id] * 100)) + '%'}); (*bars)[i].set_progress( - int(static_cast(downloaded) / total * 100)); + int(static_cast(downloaded) / totals[it.id] * 100)); (*bars)[i].set_option(indicators::option::PostfixText{ format_utils::BytesToHumanReadable(downloaded) + "/" + - format_utils::BytesToHumanReadable(total)}); + format_utils::BytesToHumanReadable(totals[it.id])}); } else if (ev.type_ == DownloadStatus::DownloadSuccess) { (*bars)[i].set_progress(100); - auto total_str = format_utils::BytesToHumanReadable(total); + auto total_str = format_utils::BytesToHumanReadable(totals[it.id]); (*bars)[i].set_option( indicators::option::PostfixText{total_str + "/" + total_str}); (*bars)[i].set_option(