From 3d4c1960cd1797f7b3ccbb285bdd0f0f6ccb637e Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Mon, 24 Apr 2023 04:42:46 +0200 Subject: [PATCH] Feat: Add `ElasticConfig` message type for torch elastic training (#394) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add elastic config args to pytorch proto Signed-off-by: Fabio Graetz * Add elastic config message type for torchrun training Signed-off-by: Fabio Graetz --------- Signed-off-by: Fabio Graetz Co-authored-by: Fabio Grätz Co-authored-by: Ketan Umare Signed-off-by: Eduardo Apolinario --- .../gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc | 610 ++++++++- .../gen/pb-cpp/flyteidl/plugins/pytorch.pb.h | 332 ++++- .../gen/pb-go/flyteidl/plugins/pytorch.pb.go | 125 +- .../flyteidl/plugins/pytorch.pb.validate.go | 85 ++ .../gen/pb-java/flyteidl/plugins/Pytorch.java | 1103 ++++++++++++++++- .../pb_python/flyteidl/plugins/pytorch_pb2.py | 8 +- .../flyteidl/plugins/pytorch_pb2.pyi | 22 +- flyteidl/gen/pb_rust/flyteidl.plugins.rs | 20 + .../protos/flyteidl/plugins/pytorch.proto | 14 + 9 files changed, 2278 insertions(+), 41 deletions(-) diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc index 9b1951b3898..20bc8269673 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.cc @@ -16,14 +16,33 @@ // @@protoc_insertion_point(includes) #include +extern PROTOBUF_INTERNAL_EXPORT_flyteidl_2fplugins_2fpytorch_2eproto ::google::protobuf::internal::SCCInfo<0> scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto; namespace flyteidl { namespace plugins { +class ElasticConfigDefaultTypeInternal { + public: + ::google::protobuf::internal::ExplicitlyConstructed _instance; +} _ElasticConfig_default_instance_; class DistributedPyTorchTrainingTaskDefaultTypeInternal { public: ::google::protobuf::internal::ExplicitlyConstructed _instance; } _DistributedPyTorchTrainingTask_default_instance_; } // namespace plugins } // namespace flyteidl +static void InitDefaultsElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::flyteidl::plugins::_ElasticConfig_default_instance_; + new (ptr) ::flyteidl::plugins::ElasticConfig(); + ::google::protobuf::internal::OnShutdownDestroyMessage(ptr); + } + ::flyteidl::plugins::ElasticConfig::InitAsDefaultInstance(); +} + +::google::protobuf::internal::SCCInfo<0> scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto}, {}}; + static void InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto() { GOOGLE_PROTOBUF_VERIFY_VERSION; @@ -35,50 +54,69 @@ static void InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpyto ::flyteidl::plugins::DistributedPyTorchTrainingTask::InitAsDefaultInstance(); } -::google::protobuf::internal::SCCInfo<0> scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto = - {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto}, {}}; +::google::protobuf::internal::SCCInfo<1> scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 1, InitDefaultsDistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto}, { + &scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base,}}; void InitDefaults_flyteidl_2fplugins_2fpytorch_2eproto() { + ::google::protobuf::internal::InitSCC(&scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); } -::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[1]; +::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[2]; constexpr ::google::protobuf::EnumDescriptor const** file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto = nullptr; constexpr ::google::protobuf::ServiceDescriptor const** file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto = nullptr; const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fpytorch_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, rdzv_backend_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, min_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, max_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, nproc_per_node_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::ElasticConfig, max_restarts_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, _internal_metadata_), ~0u, // no _extensions_ ~0u, // no _oneof_case_ ~0u, // no _weak_field_map_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, workers_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedPyTorchTrainingTask, elastic_config_), }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { - { 0, -1, sizeof(::flyteidl::plugins::DistributedPyTorchTrainingTask)}, + { 0, -1, sizeof(::flyteidl::plugins::ElasticConfig)}, + { 10, -1, sizeof(::flyteidl::plugins::DistributedPyTorchTrainingTask)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { + reinterpret_cast(&::flyteidl::plugins::_ElasticConfig_default_instance_), reinterpret_cast(&::flyteidl::plugins::_DistributedPyTorchTrainingTask_default_instance_), }; ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto = { {}, AddDescriptors_flyteidl_2fplugins_2fpytorch_2eproto, "flyteidl/plugins/pytorch.proto", schemas, file_default_instances, TableStruct_flyteidl_2fplugins_2fpytorch_2eproto::offsets, - file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto, 1, file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, + file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto, 2, file_level_enum_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fpytorch_2eproto, }; const char descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto[] = "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" - "dl.plugins\"1\n\036DistributedPyTorchTraining" - "Task\022\017\n\007workers\030\001 \001(\005B9Z7github.com/flyt" - "eorg/flyteidl/gen/pb-go/flyteidl/plugins" - "b\006proto3" + "dl.plugins\"\177\n\rElasticConfig\022\024\n\014rdzv_back" + "end\030\001 \001(\t\022\024\n\014min_replicas\030\002 \001(\005\022\024\n\014max_r" + "eplicas\030\003 \001(\005\022\026\n\016nproc_per_node\030\004 \001(\005\022\024\n" + "\014max_restarts\030\005 \001(\005\"j\n\036DistributedPyTorc" + "hTrainingTask\022\017\n\007workers\030\001 \001(\005\0227\n\016elasti" + "c_config\030\002 \001(\0132\037.flyteidl.plugins.Elasti" + "cConfigB9Z7github.com/flyteorg/flyteidl/" + "gen/pb-go/flyteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fpytorch_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fpytorch_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fpytorch_2eproto, - "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 168, + "flyteidl/plugins/pytorch.proto", &assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto, 354, }; void AddDescriptors_flyteidl_2fplugins_2fpytorch_2eproto() { @@ -93,16 +131,498 @@ static bool dynamic_init_dummy_flyteidl_2fplugins_2fpytorch_2eproto = []() { Add namespace flyteidl { namespace plugins { +// =================================================================== + +void ElasticConfig::InitAsDefaultInstance() { +} +class ElasticConfig::HasBitSetters { + public: +}; + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int ElasticConfig::kRdzvBackendFieldNumber; +const int ElasticConfig::kMinReplicasFieldNumber; +const int ElasticConfig::kMaxReplicasFieldNumber; +const int ElasticConfig::kNprocPerNodeFieldNumber; +const int ElasticConfig::kMaxRestartsFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +ElasticConfig::ElasticConfig() + : ::google::protobuf::Message(), _internal_metadata_(nullptr) { + SharedCtor(); + // @@protoc_insertion_point(constructor:flyteidl.plugins.ElasticConfig) +} +ElasticConfig::ElasticConfig(const ElasticConfig& from) + : ::google::protobuf::Message(), + _internal_metadata_(nullptr) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + rdzv_backend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.rdzv_backend().size() > 0) { + rdzv_backend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzv_backend_); + } + ::memcpy(&min_replicas_, &from.min_replicas_, + static_cast(reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); + // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.ElasticConfig) +} + +void ElasticConfig::SharedCtor() { + ::google::protobuf::internal::InitSCC( + &scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); + rdzv_backend_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&min_replicas_, 0, static_cast( + reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); +} + +ElasticConfig::~ElasticConfig() { + // @@protoc_insertion_point(destructor:flyteidl.plugins.ElasticConfig) + SharedDtor(); +} + +void ElasticConfig::SharedDtor() { + rdzv_backend_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} + +void ElasticConfig::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ElasticConfig& ElasticConfig::default_instance() { + ::google::protobuf::internal::InitSCC(&::scc_info_ElasticConfig_flyteidl_2fplugins_2fpytorch_2eproto.base); + return *internal_default_instance(); +} + + +void ElasticConfig::Clear() { +// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.ElasticConfig) + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + rdzv_backend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + ::memset(&min_replicas_, 0, static_cast( + reinterpret_cast(&max_restarts_) - + reinterpret_cast(&min_replicas_)) + sizeof(max_restarts_)); + _internal_metadata_.Clear(); +} + +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* ElasticConfig::_InternalParse(const char* begin, const char* end, void* object, + ::google::protobuf::internal::ParseContext* ctx) { + auto msg = static_cast(object); + ::google::protobuf::int32 size; (void)size; + int depth; (void)depth; + ::google::protobuf::uint32 tag; + ::google::protobuf::internal::ParseFunc parser_till_end; (void)parser_till_end; + auto ptr = begin; + while (ptr < end) { + ptr = ::google::protobuf::io::Parse32(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + switch (tag >> 3) { + // string rdzv_backend = 1; + case 1: { + if (static_cast<::google::protobuf::uint8>(tag) != 10) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ctx->extra_parse_data().SetFieldName("flyteidl.plugins.ElasticConfig.rdzv_backend"); + object = msg->mutable_rdzv_backend(); + if (size > end - ptr + ::google::protobuf::internal::ParseContext::kSlopBytes) { + parser_till_end = ::google::protobuf::internal::GreedyStringParserUTF8; + goto string_till_end; + } + GOOGLE_PROTOBUF_PARSER_ASSERT(::google::protobuf::internal::StringCheckUTF8(ptr, size, ctx)); + ::google::protobuf::internal::InlineGreedyStringParser(object, ptr, size, ctx); + ptr += size; + break; + } + // int32 min_replicas = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 16) goto handle_unusual; + msg->set_min_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 max_replicas = 3; + case 3: { + if (static_cast<::google::protobuf::uint8>(tag) != 24) goto handle_unusual; + msg->set_max_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 nproc_per_node = 4; + case 4: { + if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; + msg->set_nproc_per_node(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + // int32 max_restarts = 5; + case 5: { + if (static_cast<::google::protobuf::uint8>(tag) != 40) goto handle_unusual; + msg->set_max_restarts(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->EndGroup(tag); + return ptr; + } + auto res = UnknownFieldParse(tag, {_InternalParse, msg}, + ptr, end, msg->_internal_metadata_.mutable_unknown_fields(), ctx); + ptr = res.first; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + if (res.second) return ptr; + } + } // switch + } // while + return ptr; +string_till_end: + static_cast<::std::string*>(object)->clear(); + static_cast<::std::string*>(object)->reserve(size); + goto len_delim_till_end; +len_delim_till_end: + return ctx->StoreAndTailCall(ptr, end, {_InternalParse, msg}, + {parser_till_end, object}, size); +} +#else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool ElasticConfig::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!PROTOBUF_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:flyteidl.plugins.ElasticConfig) + for (;;) { + ::std::pair<::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // string rdzv_backend = 1; + case 1: { + if (static_cast< ::google::protobuf::uint8>(tag) == (10 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_rdzv_backend())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), + ::google::protobuf::internal::WireFormatLite::PARSE, + "flyteidl.plugins.ElasticConfig.rdzv_backend")); + } else { + goto handle_unusual; + } + break; + } + + // int32 min_replicas = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == (16 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &min_replicas_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 max_replicas = 3; + case 3: { + if (static_cast< ::google::protobuf::uint8>(tag) == (24 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &max_replicas_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 nproc_per_node = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &nproc_per_node_))); + } else { + goto handle_unusual; + } + break; + } + + // int32 max_restarts = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == (40 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &max_restarts_))); + } else { + goto handle_unusual; + } + break; + } + + default: { + handle_unusual: + if (tag == 0) { + goto success; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, _internal_metadata_.mutable_unknown_fields())); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:flyteidl.plugins.ElasticConfig) + return true; +failure: + // @@protoc_insertion_point(parse_failure:flyteidl.plugins.ElasticConfig) + return false; +#undef DO_ +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + +void ElasticConfig::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:flyteidl.plugins.ElasticConfig) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.plugins.ElasticConfig.rdzv_backend"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 1, this->rdzv_backend(), output); + } + + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(2, this->min_replicas(), output); + } + + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->max_replicas(), output); + } + + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->nproc_per_node(), output); + } + + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(5, this->max_restarts(), output); + } + + if (_internal_metadata_.have_unknown_fields()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + _internal_metadata_.unknown_fields(), output); + } + // @@protoc_insertion_point(serialize_end:flyteidl.plugins.ElasticConfig) +} + +::google::protobuf::uint8* ElasticConfig::InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.ElasticConfig) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->rdzv_backend().data(), static_cast(this->rdzv_backend().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.plugins.ElasticConfig.rdzv_backend"); + target = + ::google::protobuf::internal::WireFormatLite::WriteStringToArray( + 1, this->rdzv_backend(), target); + } + + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(2, this->min_replicas(), target); + } + + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->max_replicas(), target); + } + + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->nproc_per_node(), target); + } + + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(5, this->max_restarts(), target); + } + + if (_internal_metadata_.have_unknown_fields()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields(), target); + } + // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.ElasticConfig) + return target; +} + +size_t ElasticConfig::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.ElasticConfig) + size_t total_size = 0; + + if (_internal_metadata_.have_unknown_fields()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + _internal_metadata_.unknown_fields()); + } + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // string rdzv_backend = 1; + if (this->rdzv_backend().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->rdzv_backend()); + } + + // int32 min_replicas = 2; + if (this->min_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->min_replicas()); + } + + // int32 max_replicas = 3; + if (this->max_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->max_replicas()); + } + + // int32 nproc_per_node = 4; + if (this->nproc_per_node() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->nproc_per_node()); + } + + // int32 max_restarts = 5; + if (this->max_restarts() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->max_restarts()); + } + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ElasticConfig::MergeFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:flyteidl.plugins.ElasticConfig) + GOOGLE_DCHECK_NE(&from, this); + const ElasticConfig* source = + ::google::protobuf::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:flyteidl.plugins.ElasticConfig) + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:flyteidl.plugins.ElasticConfig) + MergeFrom(*source); + } +} + +void ElasticConfig::MergeFrom(const ElasticConfig& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:flyteidl.plugins.ElasticConfig) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from.rdzv_backend().size() > 0) { + + rdzv_backend_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.rdzv_backend_); + } + if (from.min_replicas() != 0) { + set_min_replicas(from.min_replicas()); + } + if (from.max_replicas() != 0) { + set_max_replicas(from.max_replicas()); + } + if (from.nproc_per_node() != 0) { + set_nproc_per_node(from.nproc_per_node()); + } + if (from.max_restarts() != 0) { + set_max_restarts(from.max_restarts()); + } +} + +void ElasticConfig::CopyFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:flyteidl.plugins.ElasticConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ElasticConfig::CopyFrom(const ElasticConfig& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:flyteidl.plugins.ElasticConfig) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ElasticConfig::IsInitialized() const { + return true; +} + +void ElasticConfig::Swap(ElasticConfig* other) { + if (other == this) return; + InternalSwap(other); +} +void ElasticConfig::InternalSwap(ElasticConfig* other) { + using std::swap; + _internal_metadata_.Swap(&other->_internal_metadata_); + rdzv_backend_.Swap(&other->rdzv_backend_, &::google::protobuf::internal::GetEmptyStringAlreadyInited(), + GetArenaNoVirtual()); + swap(min_replicas_, other->min_replicas_); + swap(max_replicas_, other->max_replicas_); + swap(nproc_per_node_, other->nproc_per_node_); + swap(max_restarts_, other->max_restarts_); +} + +::google::protobuf::Metadata ElasticConfig::GetMetadata() const { + ::google::protobuf::internal::AssignDescriptors(&::assign_descriptors_table_flyteidl_2fplugins_2fpytorch_2eproto); + return ::file_level_metadata_flyteidl_2fplugins_2fpytorch_2eproto[kIndexInFileMessages]; +} + + // =================================================================== void DistributedPyTorchTrainingTask::InitAsDefaultInstance() { + ::flyteidl::plugins::_DistributedPyTorchTrainingTask_default_instance_._instance.get_mutable()->elastic_config_ = const_cast< ::flyteidl::plugins::ElasticConfig*>( + ::flyteidl::plugins::ElasticConfig::internal_default_instance()); } class DistributedPyTorchTrainingTask::HasBitSetters { public: + static const ::flyteidl::plugins::ElasticConfig& elastic_config(const DistributedPyTorchTrainingTask* msg); }; +const ::flyteidl::plugins::ElasticConfig& +DistributedPyTorchTrainingTask::HasBitSetters::elastic_config(const DistributedPyTorchTrainingTask* msg) { + return *msg->elastic_config_; +} #if !defined(_MSC_VER) || _MSC_VER >= 1900 const int DistributedPyTorchTrainingTask::kWorkersFieldNumber; +const int DistributedPyTorchTrainingTask::kElasticConfigFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask() @@ -114,12 +634,21 @@ DistributedPyTorchTrainingTask::DistributedPyTorchTrainingTask(const Distributed : ::google::protobuf::Message(), _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); + if (from.has_elastic_config()) { + elastic_config_ = new ::flyteidl::plugins::ElasticConfig(*from.elastic_config_); + } else { + elastic_config_ = nullptr; + } workers_ = from.workers_; // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedPyTorchTrainingTask) } void DistributedPyTorchTrainingTask::SharedCtor() { - workers_ = 0; + ::google::protobuf::internal::InitSCC( + &scc_info_DistributedPyTorchTrainingTask_flyteidl_2fplugins_2fpytorch_2eproto.base); + ::memset(&elastic_config_, 0, static_cast( + reinterpret_cast(&workers_) - + reinterpret_cast(&elastic_config_)) + sizeof(workers_)); } DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { @@ -128,6 +657,7 @@ DistributedPyTorchTrainingTask::~DistributedPyTorchTrainingTask() { } void DistributedPyTorchTrainingTask::SharedDtor() { + if (this != internal_default_instance()) delete elastic_config_; } void DistributedPyTorchTrainingTask::SetCachedSize(int size) const { @@ -145,6 +675,10 @@ void DistributedPyTorchTrainingTask::Clear() { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + if (GetArenaNoVirtual() == nullptr && elastic_config_ != nullptr) { + delete elastic_config_; + } + elastic_config_ = nullptr; workers_ = 0; _internal_metadata_.Clear(); } @@ -169,6 +703,19 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + case 2: { + if (static_cast<::google::protobuf::uint8>(tag) != 18) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + parser_till_end = ::flyteidl::plugins::ElasticConfig::_InternalParse; + object = msg->mutable_elastic_config(); + if (size > end - ptr) goto len_delim_till_end; + ptr += size; + GOOGLE_PROTOBUF_PARSER_ASSERT(ctx->ParseExactRange( + {parser_till_end, object}, ptr - size, ptr)); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -184,6 +731,9 @@ const char* DistributedPyTorchTrainingTask::_InternalParse(const char* begin, co } // switch } // while return ptr; +len_delim_till_end: + return ctx->StoreAndTailCall(ptr, end, {_InternalParse, msg}, + {parser_till_end, object}, size); } #else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( @@ -209,6 +759,17 @@ bool DistributedPyTorchTrainingTask::MergePartialFromCodedStream( break; } + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + case 2: { + if (static_cast< ::google::protobuf::uint8>(tag) == (18 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessage( + input, mutable_elastic_config())); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -241,6 +802,12 @@ void DistributedPyTorchTrainingTask::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteInt32(1, this->workers(), output); } + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 2, HasBitSetters::elastic_config(this), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -259,6 +826,13 @@ ::google::protobuf::uint8* DistributedPyTorchTrainingTask::InternalSerializeWith target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(1, this->workers(), target); } + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + target = ::google::protobuf::internal::WireFormatLite:: + InternalWriteMessageToArray( + 2, HasBitSetters::elastic_config(this), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -280,6 +854,13 @@ size_t DistributedPyTorchTrainingTask::ByteSizeLong() const { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + if (this->has_elastic_config()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSize( + *elastic_config_); + } + // int32 workers = 1; if (this->workers() != 0) { total_size += 1 + @@ -314,6 +895,9 @@ void DistributedPyTorchTrainingTask::MergeFrom(const DistributedPyTorchTrainingT ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; + if (from.has_elastic_config()) { + mutable_elastic_config()->::flyteidl::plugins::ElasticConfig::MergeFrom(from.elastic_config()); + } if (from.workers() != 0) { set_workers(from.workers()); } @@ -344,6 +928,7 @@ void DistributedPyTorchTrainingTask::Swap(DistributedPyTorchTrainingTask* other) void DistributedPyTorchTrainingTask::InternalSwap(DistributedPyTorchTrainingTask* other) { using std::swap; _internal_metadata_.Swap(&other->_internal_metadata_); + swap(elastic_config_, other->elastic_config_); swap(workers_, other->workers_); } @@ -358,6 +943,9 @@ ::google::protobuf::Metadata DistributedPyTorchTrainingTask::GetMetadata() const } // namespace flyteidl namespace google { namespace protobuf { +template<> PROTOBUF_NOINLINE ::flyteidl::plugins::ElasticConfig* Arena::CreateMaybeMessage< ::flyteidl::plugins::ElasticConfig >(Arena* arena) { + return Arena::CreateInternal< ::flyteidl::plugins::ElasticConfig >(arena); +} template<> PROTOBUF_NOINLINE ::flyteidl::plugins::DistributedPyTorchTrainingTask* Arena::CreateMaybeMessage< ::flyteidl::plugins::DistributedPyTorchTrainingTask >(Arena* arena) { return Arena::CreateInternal< ::flyteidl::plugins::DistributedPyTorchTrainingTask >(arena); } diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h b/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h index c546d0f4805..b0d9b4a2763 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/pytorch.pb.h @@ -41,7 +41,7 @@ struct TableStruct_flyteidl_2fplugins_2fpytorch_2eproto { PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::AuxillaryParseTableField aux[] PROTOBUF_SECTION_VARIABLE(protodesc_cold); - static const ::google::protobuf::internal::ParseTable schema[1] + static const ::google::protobuf::internal::ParseTable schema[2] PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::FieldMetadata field_metadata[]; static const ::google::protobuf::internal::SerializationTable serialization_table[]; @@ -53,11 +53,15 @@ namespace plugins { class DistributedPyTorchTrainingTask; class DistributedPyTorchTrainingTaskDefaultTypeInternal; extern DistributedPyTorchTrainingTaskDefaultTypeInternal _DistributedPyTorchTrainingTask_default_instance_; +class ElasticConfig; +class ElasticConfigDefaultTypeInternal; +extern ElasticConfigDefaultTypeInternal _ElasticConfig_default_instance_; } // namespace plugins } // namespace flyteidl namespace google { namespace protobuf { template<> ::flyteidl::plugins::DistributedPyTorchTrainingTask* Arena::CreateMaybeMessage<::flyteidl::plugins::DistributedPyTorchTrainingTask>(Arena*); +template<> ::flyteidl::plugins::ElasticConfig* Arena::CreateMaybeMessage<::flyteidl::plugins::ElasticConfig>(Arena*); } // namespace protobuf } // namespace google namespace flyteidl { @@ -65,6 +69,154 @@ namespace plugins { // =================================================================== +class ElasticConfig final : + public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.ElasticConfig) */ { + public: + ElasticConfig(); + virtual ~ElasticConfig(); + + ElasticConfig(const ElasticConfig& from); + + inline ElasticConfig& operator=(const ElasticConfig& from) { + CopyFrom(from); + return *this; + } + #if LANG_CXX11 + ElasticConfig(ElasticConfig&& from) noexcept + : ElasticConfig() { + *this = ::std::move(from); + } + + inline ElasticConfig& operator=(ElasticConfig&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + #endif + static const ::google::protobuf::Descriptor* descriptor() { + return default_instance().GetDescriptor(); + } + static const ElasticConfig& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ElasticConfig* internal_default_instance() { + return reinterpret_cast( + &_ElasticConfig_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + void Swap(ElasticConfig* other); + friend void swap(ElasticConfig& a, ElasticConfig& b) { + a.Swap(&b); + } + + // implements Message ---------------------------------------------- + + inline ElasticConfig* New() const final { + return CreateMaybeMessage(nullptr); + } + + ElasticConfig* New(::google::protobuf::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::google::protobuf::Message& from) final; + void MergeFrom(const ::google::protobuf::Message& from) final; + void CopyFrom(const ElasticConfig& from); + void MergeFrom(const ElasticConfig& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + static const char* _InternalParse(const char* begin, const char* end, void* object, ::google::protobuf::internal::ParseContext* ctx); + ::google::protobuf::internal::ParseFunc _ParseFunc() const final { return _InternalParse; } + #else + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) final; + #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const final; + ::google::protobuf::uint8* InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ElasticConfig* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::google::protobuf::Metadata GetMetadata() const final; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // string rdzv_backend = 1; + void clear_rdzv_backend(); + static const int kRdzvBackendFieldNumber = 1; + const ::std::string& rdzv_backend() const; + void set_rdzv_backend(const ::std::string& value); + #if LANG_CXX11 + void set_rdzv_backend(::std::string&& value); + #endif + void set_rdzv_backend(const char* value); + void set_rdzv_backend(const char* value, size_t size); + ::std::string* mutable_rdzv_backend(); + ::std::string* release_rdzv_backend(); + void set_allocated_rdzv_backend(::std::string* rdzv_backend); + + // int32 min_replicas = 2; + void clear_min_replicas(); + static const int kMinReplicasFieldNumber = 2; + ::google::protobuf::int32 min_replicas() const; + void set_min_replicas(::google::protobuf::int32 value); + + // int32 max_replicas = 3; + void clear_max_replicas(); + static const int kMaxReplicasFieldNumber = 3; + ::google::protobuf::int32 max_replicas() const; + void set_max_replicas(::google::protobuf::int32 value); + + // int32 nproc_per_node = 4; + void clear_nproc_per_node(); + static const int kNprocPerNodeFieldNumber = 4; + ::google::protobuf::int32 nproc_per_node() const; + void set_nproc_per_node(::google::protobuf::int32 value); + + // int32 max_restarts = 5; + void clear_max_restarts(); + static const int kMaxRestartsFieldNumber = 5; + ::google::protobuf::int32 max_restarts() const; + void set_max_restarts(::google::protobuf::int32 value); + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.ElasticConfig) + private: + class HasBitSetters; + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::google::protobuf::internal::ArenaStringPtr rdzv_backend_; + ::google::protobuf::int32 min_replicas_; + ::google::protobuf::int32 max_replicas_; + ::google::protobuf::int32 nproc_per_node_; + ::google::protobuf::int32 max_restarts_; + mutable ::google::protobuf::internal::CachedSize _cached_size_; + friend struct ::TableStruct_flyteidl_2fplugins_2fpytorch_2eproto; +}; +// ------------------------------------------------------------------- + class DistributedPyTorchTrainingTask final : public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.DistributedPyTorchTrainingTask) */ { public: @@ -103,7 +255,7 @@ class DistributedPyTorchTrainingTask final : &_DistributedPyTorchTrainingTask_default_instance_); } static constexpr int kIndexInFileMessages = - 0; + 1; void Swap(DistributedPyTorchTrainingTask* other); friend void swap(DistributedPyTorchTrainingTask& a, DistributedPyTorchTrainingTask& b) { @@ -160,6 +312,15 @@ class DistributedPyTorchTrainingTask final : // accessors ------------------------------------------------------- + // .flyteidl.plugins.ElasticConfig elastic_config = 2; + bool has_elastic_config() const; + void clear_elastic_config(); + static const int kElasticConfigFieldNumber = 2; + const ::flyteidl::plugins::ElasticConfig& elastic_config() const; + ::flyteidl::plugins::ElasticConfig* release_elastic_config(); + ::flyteidl::plugins::ElasticConfig* mutable_elastic_config(); + void set_allocated_elastic_config(::flyteidl::plugins::ElasticConfig* elastic_config); + // int32 workers = 1; void clear_workers(); static const int kWorkersFieldNumber = 1; @@ -171,6 +332,7 @@ class DistributedPyTorchTrainingTask final : class HasBitSetters; ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::flyteidl::plugins::ElasticConfig* elastic_config_; ::google::protobuf::int32 workers_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fpytorch_2eproto; @@ -184,6 +346,119 @@ class DistributedPyTorchTrainingTask final : #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif // __GNUC__ +// ElasticConfig + +// string rdzv_backend = 1; +inline void ElasticConfig::clear_rdzv_backend() { + rdzv_backend_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& ElasticConfig::rdzv_backend() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.rdzv_backend) + return rdzv_backend_.GetNoArena(); +} +inline void ElasticConfig::set_rdzv_backend(const ::std::string& value) { + + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.rdzv_backend) +} +#if LANG_CXX11 +inline void ElasticConfig::set_rdzv_backend(::std::string&& value) { + + rdzv_backend_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:flyteidl.plugins.ElasticConfig.rdzv_backend) +} +#endif +inline void ElasticConfig::set_rdzv_backend(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:flyteidl.plugins.ElasticConfig.rdzv_backend) +} +inline void ElasticConfig::set_rdzv_backend(const char* value, size_t size) { + + rdzv_backend_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:flyteidl.plugins.ElasticConfig.rdzv_backend) +} +inline ::std::string* ElasticConfig::mutable_rdzv_backend() { + + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.ElasticConfig.rdzv_backend) + return rdzv_backend_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* ElasticConfig::release_rdzv_backend() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.ElasticConfig.rdzv_backend) + + return rdzv_backend_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void ElasticConfig::set_allocated_rdzv_backend(::std::string* rdzv_backend) { + if (rdzv_backend != nullptr) { + + } else { + + } + rdzv_backend_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), rdzv_backend); + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.ElasticConfig.rdzv_backend) +} + +// int32 min_replicas = 2; +inline void ElasticConfig::clear_min_replicas() { + min_replicas_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::min_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.min_replicas) + return min_replicas_; +} +inline void ElasticConfig::set_min_replicas(::google::protobuf::int32 value) { + + min_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.min_replicas) +} + +// int32 max_replicas = 3; +inline void ElasticConfig::clear_max_replicas() { + max_replicas_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::max_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.max_replicas) + return max_replicas_; +} +inline void ElasticConfig::set_max_replicas(::google::protobuf::int32 value) { + + max_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.max_replicas) +} + +// int32 nproc_per_node = 4; +inline void ElasticConfig::clear_nproc_per_node() { + nproc_per_node_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::nproc_per_node() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.nproc_per_node) + return nproc_per_node_; +} +inline void ElasticConfig::set_nproc_per_node(::google::protobuf::int32 value) { + + nproc_per_node_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.nproc_per_node) +} + +// int32 max_restarts = 5; +inline void ElasticConfig::clear_max_restarts() { + max_restarts_ = 0; +} +inline ::google::protobuf::int32 ElasticConfig::max_restarts() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.ElasticConfig.max_restarts) + return max_restarts_; +} +inline void ElasticConfig::set_max_restarts(::google::protobuf::int32 value) { + + max_restarts_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.ElasticConfig.max_restarts) +} + +// ------------------------------------------------------------------- + // DistributedPyTorchTrainingTask // int32 workers = 1; @@ -200,9 +475,62 @@ inline void DistributedPyTorchTrainingTask::set_workers(::google::protobuf::int3 // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedPyTorchTrainingTask.workers) } +// .flyteidl.plugins.ElasticConfig elastic_config = 2; +inline bool DistributedPyTorchTrainingTask::has_elastic_config() const { + return this != internal_default_instance() && elastic_config_ != nullptr; +} +inline void DistributedPyTorchTrainingTask::clear_elastic_config() { + if (GetArenaNoVirtual() == nullptr && elastic_config_ != nullptr) { + delete elastic_config_; + } + elastic_config_ = nullptr; +} +inline const ::flyteidl::plugins::ElasticConfig& DistributedPyTorchTrainingTask::elastic_config() const { + const ::flyteidl::plugins::ElasticConfig* p = elastic_config_; + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + return p != nullptr ? *p : *reinterpret_cast( + &::flyteidl::plugins::_ElasticConfig_default_instance_); +} +inline ::flyteidl::plugins::ElasticConfig* DistributedPyTorchTrainingTask::release_elastic_config() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + + ::flyteidl::plugins::ElasticConfig* temp = elastic_config_; + elastic_config_ = nullptr; + return temp; +} +inline ::flyteidl::plugins::ElasticConfig* DistributedPyTorchTrainingTask::mutable_elastic_config() { + + if (elastic_config_ == nullptr) { + auto* p = CreateMaybeMessage<::flyteidl::plugins::ElasticConfig>(GetArenaNoVirtual()); + elastic_config_ = p; + } + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) + return elastic_config_; +} +inline void DistributedPyTorchTrainingTask::set_allocated_elastic_config(::flyteidl::plugins::ElasticConfig* elastic_config) { + ::google::protobuf::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete elastic_config_; + } + if (elastic_config) { + ::google::protobuf::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + elastic_config = ::google::protobuf::internal::GetOwnedMessage( + message_arena, elastic_config, submessage_arena); + } + + } else { + + } + elastic_config_ = elastic_config; + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.DistributedPyTorchTrainingTask.elastic_config) +} + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // __GNUC__ +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.go index 79138e56833..f75649fa0f1 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.go @@ -20,20 +20,96 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package -// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator -type DistributedPyTorchTrainingTask struct { - // number of worker replicas spawned in the cluster for this job - Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` +// Custom proto for torch elastic config for distributed training using +// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go +type ElasticConfig struct { + RdzvBackend string `protobuf:"bytes,1,opt,name=rdzv_backend,json=rdzvBackend,proto3" json:"rdzv_backend,omitempty"` + MinReplicas int32 `protobuf:"varint,2,opt,name=min_replicas,json=minReplicas,proto3" json:"min_replicas,omitempty"` + MaxReplicas int32 `protobuf:"varint,3,opt,name=max_replicas,json=maxReplicas,proto3" json:"max_replicas,omitempty"` + NprocPerNode int32 `protobuf:"varint,4,opt,name=nproc_per_node,json=nprocPerNode,proto3" json:"nproc_per_node,omitempty"` + MaxRestarts int32 `protobuf:"varint,5,opt,name=max_restarts,json=maxRestarts,proto3" json:"max_restarts,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } +func (m *ElasticConfig) Reset() { *m = ElasticConfig{} } +func (m *ElasticConfig) String() string { return proto.CompactTextString(m) } +func (*ElasticConfig) ProtoMessage() {} +func (*ElasticConfig) Descriptor() ([]byte, []int) { + return fileDescriptor_4df8a9374b28b766, []int{0} +} + +func (m *ElasticConfig) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ElasticConfig.Unmarshal(m, b) +} +func (m *ElasticConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ElasticConfig.Marshal(b, m, deterministic) +} +func (m *ElasticConfig) XXX_Merge(src proto.Message) { + xxx_messageInfo_ElasticConfig.Merge(m, src) +} +func (m *ElasticConfig) XXX_Size() int { + return xxx_messageInfo_ElasticConfig.Size(m) +} +func (m *ElasticConfig) XXX_DiscardUnknown() { + xxx_messageInfo_ElasticConfig.DiscardUnknown(m) +} + +var xxx_messageInfo_ElasticConfig proto.InternalMessageInfo + +func (m *ElasticConfig) GetRdzvBackend() string { + if m != nil { + return m.RdzvBackend + } + return "" +} + +func (m *ElasticConfig) GetMinReplicas() int32 { + if m != nil { + return m.MinReplicas + } + return 0 +} + +func (m *ElasticConfig) GetMaxReplicas() int32 { + if m != nil { + return m.MaxReplicas + } + return 0 +} + +func (m *ElasticConfig) GetNprocPerNode() int32 { + if m != nil { + return m.NprocPerNode + } + return 0 +} + +func (m *ElasticConfig) GetMaxRestarts() int32 { + if m != nil { + return m.MaxRestarts + } + return 0 +} + +// Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator +type DistributedPyTorchTrainingTask struct { + // number of worker replicas spawned in the cluster for this job + Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` + // config for an elastic pytorch job + // + ElasticConfig *ElasticConfig `protobuf:"bytes,2,opt,name=elastic_config,json=elasticConfig,proto3" json:"elastic_config,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + func (m *DistributedPyTorchTrainingTask) Reset() { *m = DistributedPyTorchTrainingTask{} } func (m *DistributedPyTorchTrainingTask) String() string { return proto.CompactTextString(m) } func (*DistributedPyTorchTrainingTask) ProtoMessage() {} func (*DistributedPyTorchTrainingTask) Descriptor() ([]byte, []int) { - return fileDescriptor_4df8a9374b28b766, []int{0} + return fileDescriptor_4df8a9374b28b766, []int{1} } func (m *DistributedPyTorchTrainingTask) XXX_Unmarshal(b []byte) error { @@ -61,22 +137,39 @@ func (m *DistributedPyTorchTrainingTask) GetWorkers() int32 { return 0 } +func (m *DistributedPyTorchTrainingTask) GetElasticConfig() *ElasticConfig { + if m != nil { + return m.ElasticConfig + } + return nil +} + func init() { + proto.RegisterType((*ElasticConfig)(nil), "flyteidl.plugins.ElasticConfig") proto.RegisterType((*DistributedPyTorchTrainingTask)(nil), "flyteidl.plugins.DistributedPyTorchTrainingTask") } func init() { proto.RegisterFile("flyteidl/plugins/pytorch.proto", fileDescriptor_4df8a9374b28b766) } var fileDescriptor_4df8a9374b28b766 = []byte{ - // 156 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x92, 0x4b, 0xcb, 0xa9, 0x2c, - 0x49, 0xcd, 0x4c, 0xc9, 0xd1, 0x2f, 0xc8, 0x29, 0x4d, 0xcf, 0xcc, 0x2b, 0xd6, 0x2f, 0xa8, 0x2c, - 0xc9, 0x2f, 0x4a, 0xce, 0xd0, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x80, 0xc9, 0xeb, 0x41, - 0xe5, 0x95, 0xac, 0xb8, 0xe4, 0x5c, 0x32, 0x8b, 0x4b, 0x8a, 0x32, 0x93, 0x4a, 0x4b, 0x52, 0x53, - 0x02, 0x2a, 0x43, 0x40, 0xaa, 0x43, 0x8a, 0x12, 0x33, 0xf3, 0x32, 0xf3, 0xd2, 0x43, 0x12, 0x8b, - 0xb3, 0x85, 0x24, 0xb8, 0xd8, 0xcb, 0xf3, 0x8b, 0xb2, 0x53, 0x8b, 0x8a, 0x25, 0x18, 0x15, 0x18, - 0x35, 0x58, 0x83, 0x60, 0x5c, 0x27, 0xcb, 0x28, 0xf3, 0xf4, 0xcc, 0x92, 0x8c, 0xd2, 0x24, 0xbd, - 0xe4, 0xfc, 0x5c, 0x7d, 0xb0, 0xd1, 0xf9, 0x45, 0xe9, 0xfa, 0x70, 0x37, 0xa4, 0xa7, 0xe6, 0xe9, - 0x17, 0x24, 0xe9, 0xa6, 0xe7, 0xeb, 0xa3, 0x3b, 0x2b, 0x89, 0x0d, 0xec, 0x1e, 0x63, 0x40, 0x00, - 0x00, 0x00, 0xff, 0xff, 0x91, 0x53, 0x3a, 0xa1, 0xb1, 0x00, 0x00, 0x00, + // 299 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x91, 0xbd, 0x4f, 0xc3, 0x30, + 0x10, 0xc5, 0x15, 0xa0, 0x20, 0xdc, 0x0f, 0xa1, 0x4c, 0x99, 0x4a, 0xa9, 0x18, 0xba, 0x90, 0x48, + 0x30, 0x20, 0xd6, 0xf2, 0x31, 0xa2, 0x2a, 0xea, 0xc4, 0x12, 0x39, 0xf6, 0xd5, 0x3d, 0x35, 0xb5, + 0xad, 0xb3, 0x0b, 0x2d, 0x23, 0xff, 0x19, 0xff, 0x19, 0xaa, 0x9b, 0x7e, 0xd0, 0xf1, 0xde, 0xfd, + 0xee, 0xa4, 0xf7, 0x1e, 0xeb, 0x4e, 0xaa, 0x95, 0x07, 0x94, 0x55, 0x66, 0xab, 0x85, 0x42, 0xed, + 0x32, 0xbb, 0xf2, 0x86, 0xc4, 0x34, 0xb5, 0x64, 0xbc, 0x89, 0xaf, 0xb6, 0xfb, 0xb4, 0xde, 0xf7, + 0x7f, 0x23, 0xd6, 0x7e, 0xad, 0xb8, 0xf3, 0x28, 0x9e, 0x8d, 0x9e, 0xa0, 0x8a, 0x6f, 0x58, 0x8b, + 0xe4, 0xf7, 0x67, 0x51, 0x72, 0x31, 0x03, 0x2d, 0x93, 0xa8, 0x17, 0x0d, 0x2e, 0xf3, 0xe6, 0x5a, + 0x1b, 0x6e, 0xa4, 0x35, 0x32, 0x47, 0x5d, 0x10, 0xd8, 0x0a, 0x05, 0x77, 0xc9, 0x49, 0x2f, 0x1a, + 0x34, 0xf2, 0xe6, 0x1c, 0x75, 0x5e, 0x4b, 0x01, 0xe1, 0xcb, 0x3d, 0x72, 0x5a, 0x23, 0x7c, 0xb9, + 0x43, 0x6e, 0x59, 0x47, 0x5b, 0x32, 0xa2, 0xb0, 0x40, 0x85, 0x36, 0x12, 0x92, 0xb3, 0x00, 0xb5, + 0x82, 0x3a, 0x02, 0x7a, 0x37, 0x12, 0xf6, 0x8f, 0x9c, 0xe7, 0xe4, 0x5d, 0xd2, 0x38, 0x78, 0xb4, + 0x91, 0xfa, 0x3f, 0x11, 0xeb, 0xbe, 0xa0, 0xf3, 0x84, 0xe5, 0xc2, 0x83, 0x1c, 0xad, 0xc6, 0x6b, + 0xcb, 0x63, 0xe2, 0xa8, 0x51, 0xab, 0x31, 0x77, 0xb3, 0x38, 0x61, 0x17, 0x5f, 0x86, 0x66, 0x40, + 0x2e, 0xf8, 0x69, 0xe4, 0xdb, 0x31, 0x7e, 0x63, 0x1d, 0xd8, 0xf8, 0x2f, 0x44, 0x08, 0x20, 0xb8, + 0x69, 0xde, 0x5f, 0xa7, 0xc7, 0x59, 0xa5, 0xff, 0x72, 0xca, 0xdb, 0x70, 0x38, 0x0e, 0x9f, 0x3e, + 0x1e, 0x15, 0xfa, 0xe9, 0xa2, 0x4c, 0x85, 0x99, 0x67, 0xe1, 0xd6, 0x90, 0xca, 0x76, 0x85, 0x28, + 0xd0, 0x99, 0x2d, 0xef, 0x94, 0xc9, 0x8e, 0x3b, 0x2a, 0xcf, 0x43, 0x39, 0x0f, 0x7f, 0x01, 0x00, + 0x00, 0xff, 0xff, 0x6f, 0x80, 0x2c, 0x15, 0xbe, 0x01, 0x00, 0x00, } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go index 8e6af985285..17b90db7271 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/pytorch.pb.validate.go @@ -36,6 +36,81 @@ var ( // define the regex for a UUID once up-front var _pytorch_uuidPattern = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") +// Validate checks the field values on ElasticConfig with the rules defined in +// the proto definition for this message. If any rules are violated, an error +// is returned. +func (m *ElasticConfig) Validate() error { + if m == nil { + return nil + } + + // no validation rules for RdzvBackend + + // no validation rules for MinReplicas + + // no validation rules for MaxReplicas + + // no validation rules for NprocPerNode + + // no validation rules for MaxRestarts + + return nil +} + +// ElasticConfigValidationError is the validation error returned by +// ElasticConfig.Validate if the designated constraints aren't met. +type ElasticConfigValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e ElasticConfigValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e ElasticConfigValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e ElasticConfigValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e ElasticConfigValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e ElasticConfigValidationError) ErrorName() string { return "ElasticConfigValidationError" } + +// Error satisfies the builtin error interface +func (e ElasticConfigValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sElasticConfig.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = ElasticConfigValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = ElasticConfigValidationError{} + // Validate checks the field values on DistributedPyTorchTrainingTask with the // rules defined in the proto definition for this message. If any rules are // violated, an error is returned. @@ -46,6 +121,16 @@ func (m *DistributedPyTorchTrainingTask) Validate() error { // no validation rules for Workers + if v, ok := interface{}(m.GetElasticConfig()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return DistributedPyTorchTrainingTaskValidationError{ + field: "ElasticConfig", + reason: "embedded message failed validation", + cause: err, + } + } + } + return nil } diff --git a/flyteidl/gen/pb-java/flyteidl/plugins/Pytorch.java b/flyteidl/gen/pb-java/flyteidl/plugins/Pytorch.java index a7709263fda..1df7f243b57 100644 --- a/flyteidl/gen/pb-java/flyteidl/plugins/Pytorch.java +++ b/flyteidl/gen/pb-java/flyteidl/plugins/Pytorch.java @@ -14,6 +14,813 @@ public static void registerAllExtensions( registerAllExtensions( (com.google.protobuf.ExtensionRegistryLite) registry); } + public interface ElasticConfigOrBuilder extends + // @@protoc_insertion_point(interface_extends:flyteidl.plugins.ElasticConfig) + com.google.protobuf.MessageOrBuilder { + + /** + * string rdzv_backend = 1; + */ + java.lang.String getRdzvBackend(); + /** + * string rdzv_backend = 1; + */ + com.google.protobuf.ByteString + getRdzvBackendBytes(); + + /** + * int32 min_replicas = 2; + */ + int getMinReplicas(); + + /** + * int32 max_replicas = 3; + */ + int getMaxReplicas(); + + /** + * int32 nproc_per_node = 4; + */ + int getNprocPerNode(); + + /** + * int32 max_restarts = 5; + */ + int getMaxRestarts(); + } + /** + *
+   * Custom proto for torch elastic config for distributed training using 
+   * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+   * 
+ * + * Protobuf type {@code flyteidl.plugins.ElasticConfig} + */ + public static final class ElasticConfig extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:flyteidl.plugins.ElasticConfig) + ElasticConfigOrBuilder { + private static final long serialVersionUID = 0L; + // Use ElasticConfig.newBuilder() to construct. + private ElasticConfig(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ElasticConfig() { + rdzvBackend_ = ""; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private ElasticConfig( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + java.lang.String s = input.readStringRequireUtf8(); + + rdzvBackend_ = s; + break; + } + case 16: { + + minReplicas_ = input.readInt32(); + break; + } + case 24: { + + maxReplicas_ = input.readInt32(); + break; + } + case 32: { + + nprocPerNode_ = input.readInt32(); + break; + } + case 40: { + + maxRestarts_ = input.readInt32(); + break; + } + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Pytorch.ElasticConfig.class, flyteidl.plugins.Pytorch.ElasticConfig.Builder.class); + } + + public static final int RDZV_BACKEND_FIELD_NUMBER = 1; + private volatile java.lang.Object rdzvBackend_; + /** + * string rdzv_backend = 1; + */ + public java.lang.String getRdzvBackend() { + java.lang.Object ref = rdzvBackend_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + rdzvBackend_ = s; + return s; + } + } + /** + * string rdzv_backend = 1; + */ + public com.google.protobuf.ByteString + getRdzvBackendBytes() { + java.lang.Object ref = rdzvBackend_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + rdzvBackend_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int MIN_REPLICAS_FIELD_NUMBER = 2; + private int minReplicas_; + /** + * int32 min_replicas = 2; + */ + public int getMinReplicas() { + return minReplicas_; + } + + public static final int MAX_REPLICAS_FIELD_NUMBER = 3; + private int maxReplicas_; + /** + * int32 max_replicas = 3; + */ + public int getMaxReplicas() { + return maxReplicas_; + } + + public static final int NPROC_PER_NODE_FIELD_NUMBER = 4; + private int nprocPerNode_; + /** + * int32 nproc_per_node = 4; + */ + public int getNprocPerNode() { + return nprocPerNode_; + } + + public static final int MAX_RESTARTS_FIELD_NUMBER = 5; + private int maxRestarts_; + /** + * int32 max_restarts = 5; + */ + public int getMaxRestarts() { + return maxRestarts_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!getRdzvBackendBytes().isEmpty()) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, rdzvBackend_); + } + if (minReplicas_ != 0) { + output.writeInt32(2, minReplicas_); + } + if (maxReplicas_ != 0) { + output.writeInt32(3, maxReplicas_); + } + if (nprocPerNode_ != 0) { + output.writeInt32(4, nprocPerNode_); + } + if (maxRestarts_ != 0) { + output.writeInt32(5, maxRestarts_); + } + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!getRdzvBackendBytes().isEmpty()) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, rdzvBackend_); + } + if (minReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(2, minReplicas_); + } + if (maxReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(3, maxReplicas_); + } + if (nprocPerNode_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(4, nprocPerNode_); + } + if (maxRestarts_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(5, maxRestarts_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof flyteidl.plugins.Pytorch.ElasticConfig)) { + return super.equals(obj); + } + flyteidl.plugins.Pytorch.ElasticConfig other = (flyteidl.plugins.Pytorch.ElasticConfig) obj; + + if (!getRdzvBackend() + .equals(other.getRdzvBackend())) return false; + if (getMinReplicas() + != other.getMinReplicas()) return false; + if (getMaxReplicas() + != other.getMaxReplicas()) return false; + if (getNprocPerNode() + != other.getNprocPerNode()) return false; + if (getMaxRestarts() + != other.getMaxRestarts()) return false; + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + RDZV_BACKEND_FIELD_NUMBER; + hash = (53 * hash) + getRdzvBackend().hashCode(); + hash = (37 * hash) + MIN_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getMinReplicas(); + hash = (37 * hash) + MAX_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getMaxReplicas(); + hash = (37 * hash) + NPROC_PER_NODE_FIELD_NUMBER; + hash = (53 * hash) + getNprocPerNode(); + hash = (37 * hash) + MAX_RESTARTS_FIELD_NUMBER; + hash = (53 * hash) + getMaxRestarts(); + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.Pytorch.ElasticConfig parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(flyteidl.plugins.Pytorch.ElasticConfig prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+     * Custom proto for torch elastic config for distributed training using 
+     * https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go
+     * 
+ * + * Protobuf type {@code flyteidl.plugins.ElasticConfig} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:flyteidl.plugins.ElasticConfig) + flyteidl.plugins.Pytorch.ElasticConfigOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.Pytorch.ElasticConfig.class, flyteidl.plugins.Pytorch.ElasticConfig.Builder.class); + } + + // Construct using flyteidl.plugins.Pytorch.ElasticConfig.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + rdzvBackend_ = ""; + + minReplicas_ = 0; + + maxReplicas_ = 0; + + nprocPerNode_ = 0; + + maxRestarts_ = 0; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return flyteidl.plugins.Pytorch.internal_static_flyteidl_plugins_ElasticConfig_descriptor; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstanceForType() { + return flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance(); + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.ElasticConfig build() { + flyteidl.plugins.Pytorch.ElasticConfig result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.ElasticConfig buildPartial() { + flyteidl.plugins.Pytorch.ElasticConfig result = new flyteidl.plugins.Pytorch.ElasticConfig(this); + result.rdzvBackend_ = rdzvBackend_; + result.minReplicas_ = minReplicas_; + result.maxReplicas_ = maxReplicas_; + result.nprocPerNode_ = nprocPerNode_; + result.maxRestarts_ = maxRestarts_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof flyteidl.plugins.Pytorch.ElasticConfig) { + return mergeFrom((flyteidl.plugins.Pytorch.ElasticConfig)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(flyteidl.plugins.Pytorch.ElasticConfig other) { + if (other == flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance()) return this; + if (!other.getRdzvBackend().isEmpty()) { + rdzvBackend_ = other.rdzvBackend_; + onChanged(); + } + if (other.getMinReplicas() != 0) { + setMinReplicas(other.getMinReplicas()); + } + if (other.getMaxReplicas() != 0) { + setMaxReplicas(other.getMaxReplicas()); + } + if (other.getNprocPerNode() != 0) { + setNprocPerNode(other.getNprocPerNode()); + } + if (other.getMaxRestarts() != 0) { + setMaxRestarts(other.getMaxRestarts()); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + flyteidl.plugins.Pytorch.ElasticConfig parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (flyteidl.plugins.Pytorch.ElasticConfig) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private java.lang.Object rdzvBackend_ = ""; + /** + * string rdzv_backend = 1; + */ + public java.lang.String getRdzvBackend() { + java.lang.Object ref = rdzvBackend_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + rdzvBackend_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string rdzv_backend = 1; + */ + public com.google.protobuf.ByteString + getRdzvBackendBytes() { + java.lang.Object ref = rdzvBackend_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + rdzvBackend_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string rdzv_backend = 1; + */ + public Builder setRdzvBackend( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + rdzvBackend_ = value; + onChanged(); + return this; + } + /** + * string rdzv_backend = 1; + */ + public Builder clearRdzvBackend() { + + rdzvBackend_ = getDefaultInstance().getRdzvBackend(); + onChanged(); + return this; + } + /** + * string rdzv_backend = 1; + */ + public Builder setRdzvBackendBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + rdzvBackend_ = value; + onChanged(); + return this; + } + + private int minReplicas_ ; + /** + * int32 min_replicas = 2; + */ + public int getMinReplicas() { + return minReplicas_; + } + /** + * int32 min_replicas = 2; + */ + public Builder setMinReplicas(int value) { + + minReplicas_ = value; + onChanged(); + return this; + } + /** + * int32 min_replicas = 2; + */ + public Builder clearMinReplicas() { + + minReplicas_ = 0; + onChanged(); + return this; + } + + private int maxReplicas_ ; + /** + * int32 max_replicas = 3; + */ + public int getMaxReplicas() { + return maxReplicas_; + } + /** + * int32 max_replicas = 3; + */ + public Builder setMaxReplicas(int value) { + + maxReplicas_ = value; + onChanged(); + return this; + } + /** + * int32 max_replicas = 3; + */ + public Builder clearMaxReplicas() { + + maxReplicas_ = 0; + onChanged(); + return this; + } + + private int nprocPerNode_ ; + /** + * int32 nproc_per_node = 4; + */ + public int getNprocPerNode() { + return nprocPerNode_; + } + /** + * int32 nproc_per_node = 4; + */ + public Builder setNprocPerNode(int value) { + + nprocPerNode_ = value; + onChanged(); + return this; + } + /** + * int32 nproc_per_node = 4; + */ + public Builder clearNprocPerNode() { + + nprocPerNode_ = 0; + onChanged(); + return this; + } + + private int maxRestarts_ ; + /** + * int32 max_restarts = 5; + */ + public int getMaxRestarts() { + return maxRestarts_; + } + /** + * int32 max_restarts = 5; + */ + public Builder setMaxRestarts(int value) { + + maxRestarts_ = value; + onChanged(); + return this; + } + /** + * int32 max_restarts = 5; + */ + public Builder clearMaxRestarts() { + + maxRestarts_ = 0; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:flyteidl.plugins.ElasticConfig) + } + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.ElasticConfig) + private static final flyteidl.plugins.Pytorch.ElasticConfig DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new flyteidl.plugins.Pytorch.ElasticConfig(); + } + + public static flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ElasticConfig parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new ElasticConfig(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public flyteidl.plugins.Pytorch.ElasticConfig getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + public interface DistributedPyTorchTrainingTaskOrBuilder extends // @@protoc_insertion_point(interface_extends:flyteidl.plugins.DistributedPyTorchTrainingTask) com.google.protobuf.MessageOrBuilder { @@ -26,6 +833,34 @@ public interface DistributedPyTorchTrainingTaskOrBuilder extends * int32 workers = 1; */ int getWorkers(); + + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + boolean hasElasticConfig(); + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig(); + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder(); } /** *
@@ -75,6 +910,19 @@ private DistributedPyTorchTrainingTask(
               workers_ = input.readInt32();
               break;
             }
+            case 18: {
+              flyteidl.plugins.Pytorch.ElasticConfig.Builder subBuilder = null;
+              if (elasticConfig_ != null) {
+                subBuilder = elasticConfig_.toBuilder();
+              }
+              elasticConfig_ = input.readMessage(flyteidl.plugins.Pytorch.ElasticConfig.parser(), extensionRegistry);
+              if (subBuilder != null) {
+                subBuilder.mergeFrom(elasticConfig_);
+                elasticConfig_ = subBuilder.buildPartial();
+              }
+
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -120,6 +968,42 @@ public int getWorkers() {
       return workers_;
     }
 
+    public static final int ELASTIC_CONFIG_FIELD_NUMBER = 2;
+    private flyteidl.plugins.Pytorch.ElasticConfig elasticConfig_;
+    /**
+     * 
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public boolean hasElasticConfig() { + return elasticConfig_ != null; + } + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig() { + return elasticConfig_ == null ? flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } + /** + *
+     * config for an elastic pytorch job
+     * 
+     * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder() { + return getElasticConfig(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -137,6 +1021,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (workers_ != 0) { output.writeInt32(1, workers_); } + if (elasticConfig_ != null) { + output.writeMessage(2, getElasticConfig()); + } unknownFields.writeTo(output); } @@ -150,6 +1037,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeInt32Size(1, workers_); } + if (elasticConfig_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, getElasticConfig()); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -167,6 +1058,11 @@ public boolean equals(final java.lang.Object obj) { if (getWorkers() != other.getWorkers()) return false; + if (hasElasticConfig() != other.hasElasticConfig()) return false; + if (hasElasticConfig()) { + if (!getElasticConfig() + .equals(other.getElasticConfig())) return false; + } if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -180,6 +1076,10 @@ public int hashCode() { hash = (19 * hash) + getDescriptor().hashCode(); hash = (37 * hash) + WORKERS_FIELD_NUMBER; hash = (53 * hash) + getWorkers(); + if (hasElasticConfig()) { + hash = (37 * hash) + ELASTIC_CONFIG_FIELD_NUMBER; + hash = (53 * hash) + getElasticConfig().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -319,6 +1219,12 @@ public Builder clear() { super.clear(); workers_ = 0; + if (elasticConfigBuilder_ == null) { + elasticConfig_ = null; + } else { + elasticConfig_ = null; + elasticConfigBuilder_ = null; + } return this; } @@ -346,6 +1252,11 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask build() { public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask buildPartial() { flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask result = new flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask(this); result.workers_ = workers_; + if (elasticConfigBuilder_ == null) { + result.elasticConfig_ = elasticConfig_; + } else { + result.elasticConfig_ = elasticConfigBuilder_.build(); + } onBuilt(); return result; } @@ -397,6 +1308,9 @@ public Builder mergeFrom(flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask if (other.getWorkers() != 0) { setWorkers(other.getWorkers()); } + if (other.hasElasticConfig()) { + mergeElasticConfig(other.getElasticConfig()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -463,6 +1377,168 @@ public Builder clearWorkers() { onChanged(); return this; } + + private flyteidl.plugins.Pytorch.ElasticConfig elasticConfig_; + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder> elasticConfigBuilder_; + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public boolean hasElasticConfig() { + return elasticConfigBuilder_ != null || elasticConfig_ != null; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig getElasticConfig() { + if (elasticConfigBuilder_ == null) { + return elasticConfig_ == null ? flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } else { + return elasticConfigBuilder_.getMessage(); + } + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder setElasticConfig(flyteidl.plugins.Pytorch.ElasticConfig value) { + if (elasticConfigBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + elasticConfig_ = value; + onChanged(); + } else { + elasticConfigBuilder_.setMessage(value); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder setElasticConfig( + flyteidl.plugins.Pytorch.ElasticConfig.Builder builderForValue) { + if (elasticConfigBuilder_ == null) { + elasticConfig_ = builderForValue.build(); + onChanged(); + } else { + elasticConfigBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder mergeElasticConfig(flyteidl.plugins.Pytorch.ElasticConfig value) { + if (elasticConfigBuilder_ == null) { + if (elasticConfig_ != null) { + elasticConfig_ = + flyteidl.plugins.Pytorch.ElasticConfig.newBuilder(elasticConfig_).mergeFrom(value).buildPartial(); + } else { + elasticConfig_ = value; + } + onChanged(); + } else { + elasticConfigBuilder_.mergeFrom(value); + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public Builder clearElasticConfig() { + if (elasticConfigBuilder_ == null) { + elasticConfig_ = null; + onChanged(); + } else { + elasticConfig_ = null; + elasticConfigBuilder_ = null; + } + + return this; + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfig.Builder getElasticConfigBuilder() { + + onChanged(); + return getElasticConfigFieldBuilder().getBuilder(); + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + public flyteidl.plugins.Pytorch.ElasticConfigOrBuilder getElasticConfigOrBuilder() { + if (elasticConfigBuilder_ != null) { + return elasticConfigBuilder_.getMessageOrBuilder(); + } else { + return elasticConfig_ == null ? + flyteidl.plugins.Pytorch.ElasticConfig.getDefaultInstance() : elasticConfig_; + } + } + /** + *
+       * config for an elastic pytorch job
+       * 
+       * 
+ * + * .flyteidl.plugins.ElasticConfig elastic_config = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder> + getElasticConfigFieldBuilder() { + if (elasticConfigBuilder_ == null) { + elasticConfigBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.Pytorch.ElasticConfig, flyteidl.plugins.Pytorch.ElasticConfig.Builder, flyteidl.plugins.Pytorch.ElasticConfigOrBuilder>( + getElasticConfig(), + getParentForChildren(), + isClean()); + elasticConfig_ = null; + } + return elasticConfigBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -516,6 +1592,11 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanc } + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_flyteidl_plugins_ElasticConfig_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor; private static final @@ -531,10 +1612,14 @@ public flyteidl.plugins.Pytorch.DistributedPyTorchTrainingTask getDefaultInstanc static { java.lang.String[] descriptorData = { "\n\036flyteidl/plugins/pytorch.proto\022\020flytei" + - "dl.plugins\"1\n\036DistributedPyTorchTraining" + - "Task\022\017\n\007workers\030\001 \001(\005B9Z7github.com/flyt" + - "eorg/flyteidl/gen/pb-go/flyteidl/plugins" + - "b\006proto3" + "dl.plugins\"\177\n\rElasticConfig\022\024\n\014rdzv_back" + + "end\030\001 \001(\t\022\024\n\014min_replicas\030\002 \001(\005\022\024\n\014max_r" + + "eplicas\030\003 \001(\005\022\026\n\016nproc_per_node\030\004 \001(\005\022\024\n" + + "\014max_restarts\030\005 \001(\005\"j\n\036DistributedPyTorc" + + "hTrainingTask\022\017\n\007workers\030\001 \001(\005\0227\n\016elasti" + + "c_config\030\002 \001(\0132\037.flyteidl.plugins.Elasti" + + "cConfigB9Z7github.com/flyteorg/flyteidl/" + + "gen/pb-go/flyteidl/pluginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -548,12 +1633,18 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( .internalBuildGeneratedFileFrom(descriptorData, new com.google.protobuf.Descriptors.FileDescriptor[] { }, assigner); - internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor = + internal_static_flyteidl_plugins_ElasticConfig_descriptor = getDescriptor().getMessageTypes().get(0); + internal_static_flyteidl_plugins_ElasticConfig_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_flyteidl_plugins_ElasticConfig_descriptor, + new java.lang.String[] { "RdzvBackend", "MinReplicas", "MaxReplicas", "NprocPerNode", "MaxRestarts", }); + internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor = + getDescriptor().getMessageTypes().get(1); internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_DistributedPyTorchTrainingTask_descriptor, - new java.lang.String[] { "Workers", }); + new java.lang.String[] { "Workers", "ElasticConfig", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.py index 000e9a0f8c1..bc3cddc1962 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\":\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workersB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x66lyteidl/plugins/pytorch.proto\x12\x10\x66lyteidl.plugins\"\xc1\x01\n\rElasticConfig\x12!\n\x0crdzv_backend\x18\x01 \x01(\tR\x0brdzvBackend\x12!\n\x0cmin_replicas\x18\x02 \x01(\x05R\x0bminReplicas\x12!\n\x0cmax_replicas\x18\x03 \x01(\x05R\x0bmaxReplicas\x12$\n\x0enproc_per_node\x18\x04 \x01(\x05R\x0cnprocPerNode\x12!\n\x0cmax_restarts\x18\x05 \x01(\x05R\x0bmaxRestarts\"\x82\x01\n\x1e\x44istributedPyTorchTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x46\n\x0e\x65lastic_config\x18\x02 \x01(\x0b\x32\x1f.flyteidl.plugins.ElasticConfigR\relasticConfigB\xbe\x01\n\x14\x63om.flyteidl.pluginsB\x0cPytorchProtoP\x01Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -22,6 +22,8 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\014PytorchProtoP\001Z7github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' - _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_start=52 - _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_end=110 + _globals['_ELASTICCONFIG']._serialized_start=53 + _globals['_ELASTICCONFIG']._serialized_end=246 + _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_start=249 + _globals['_DISTRIBUTEDPYTORCHTRAININGTASK']._serialized_end=379 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi index c7f5cef6622..882c38d2deb 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/pytorch_pb2.pyi @@ -1,11 +1,27 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Optional as _Optional +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor +class ElasticConfig(_message.Message): + __slots__ = ["rdzv_backend", "min_replicas", "max_replicas", "nproc_per_node", "max_restarts"] + RDZV_BACKEND_FIELD_NUMBER: _ClassVar[int] + MIN_REPLICAS_FIELD_NUMBER: _ClassVar[int] + MAX_REPLICAS_FIELD_NUMBER: _ClassVar[int] + NPROC_PER_NODE_FIELD_NUMBER: _ClassVar[int] + MAX_RESTARTS_FIELD_NUMBER: _ClassVar[int] + rdzv_backend: str + min_replicas: int + max_replicas: int + nproc_per_node: int + max_restarts: int + def __init__(self, rdzv_backend: _Optional[str] = ..., min_replicas: _Optional[int] = ..., max_replicas: _Optional[int] = ..., nproc_per_node: _Optional[int] = ..., max_restarts: _Optional[int] = ...) -> None: ... + class DistributedPyTorchTrainingTask(_message.Message): - __slots__ = ["workers"] + __slots__ = ["workers", "elastic_config"] WORKERS_FIELD_NUMBER: _ClassVar[int] + ELASTIC_CONFIG_FIELD_NUMBER: _ClassVar[int] workers: int - def __init__(self, workers: _Optional[int] = ...) -> None: ... + elastic_config: ElasticConfig + def __init__(self, workers: _Optional[int] = ..., elastic_config: _Optional[_Union[ElasticConfig, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index 14251aa64a5..2bc8b08c229 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -103,6 +103,22 @@ pub struct PrestoQuery { #[prost(string, tag="4")] pub statement: ::prost::alloc::string::String, } +/// Custom proto for torch elastic config for distributed training using +/// +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ElasticConfig { + #[prost(string, tag="1")] + pub rdzv_backend: ::prost::alloc::string::String, + #[prost(int32, tag="2")] + pub min_replicas: i32, + #[prost(int32, tag="3")] + pub max_replicas: i32, + #[prost(int32, tag="4")] + pub nproc_per_node: i32, + #[prost(int32, tag="5")] + pub max_restarts: i32, +} /// Custom proto for plugin that enables distributed training using #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -110,6 +126,10 @@ pub struct DistributedPyTorchTrainingTask { /// number of worker replicas spawned in the cluster for this job #[prost(int32, tag="1")] pub workers: i32, + /// config for an elastic pytorch job + /// + #[prost(message, optional, tag="2")] + pub elastic_config: ::core::option::Option, } /// Defines a query to execute on a hive cluster. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/protos/flyteidl/plugins/pytorch.proto b/flyteidl/protos/flyteidl/plugins/pytorch.proto index 603de00c324..2e219d82bc3 100644 --- a/flyteidl/protos/flyteidl/plugins/pytorch.proto +++ b/flyteidl/protos/flyteidl/plugins/pytorch.proto @@ -4,8 +4,22 @@ package flyteidl.plugins; option go_package = "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"; +// Custom proto for torch elastic config for distributed training using +// https://github.com/kubeflow/training-operator/blob/master/pkg/apis/kubeflow.org/v1/pytorch_types.go +message ElasticConfig { + string rdzv_backend = 1; + int32 min_replicas = 2; + int32 max_replicas = 3; + int32 nproc_per_node = 4; + int32 max_restarts = 5; +} + // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/pytorch-operator message DistributedPyTorchTrainingTask { // number of worker replicas spawned in the cluster for this job int32 workers = 1; + + // config for an elastic pytorch job + // + ElasticConfig elastic_config = 2; }