Skip to content

Commit

Permalink
add synchronize op in collective library (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
msftsw authored May 30, 2021
1 parent c8c30d7 commit 35ae06b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 20 deletions.
3 changes: 3 additions & 0 deletions docker/Dockerfile.c-mcpu_avx512
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ ADD ./engine /antares/engine
RUN /antares/engine/install_antares_host.sh && rm -rf /var/lib/apt/lists/* ~/.cache
RUN ln -s clang++-10 /usr/bin/clang++ || true
RUN python3 -m pip install mpi4py
# Replace mpiexec with a wrapper that always forwards --allow-run-as-root to the
# real binary (renamed to mpiexec.real).
# The wrapper is written with an explicit "#!/bin/sh" first line: a script that
# lacks one can only be started through a shell's ENOEXEC fallback and fails
# when a program launches it directly with execve().
RUN mv /usr/bin/mpiexec /usr/bin/mpiexec.real && \
    printf '#!/bin/sh\nexec mpiexec.real --allow-run-as-root "$@"\n' > /usr/bin/mpiexec && \
    chmod a+x /usr/bin/mpiexec
5 changes: 5 additions & 0 deletions frameworks/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ def metric(data):
results = communicate_library.metric(data)
return results

def synchronize(data):
    """Run the library's ``synchronize`` op on ``data`` and return its result.

    The library handle is obtained from ``init_library()`` on every call.
    """
    return init_library().synchronize(data)

def communicate(comm_type, data, names=[]):
rank, size, local_rank = init_communicate_config()
out = communicate_library.collective(data, op_type=comm_type)
Expand Down
58 changes: 38 additions & 20 deletions frameworks/tensorflow/communicate_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,42 @@ REGISTER_OP("Collective")
});

/////////////////////////////////////////////////////////////////////////////////////
// Kernel behind the "Synchronize" op.
//
// On builds that define ANTARES_CUDA or ANTARES_ROCM, ComputeAsync calls
// cudaStreamSynchronize on the stream taken from the op's device context and
// CHECK-fails if that call does not return cudaSuccess. On other builds no
// work is done. In every case `done` is invoked before ComputeAsync returns.
//
// The `Device` template parameter is not referenced inside the class; it only
// distinguishes the instantiations named at kernel registration.
template <typename Device>
class SynchronizeOpKernel : public AsyncOpKernel {
 public:
  explicit SynchronizeOpKernel(OpKernelConstruction* construction)
      : AsyncOpKernel(construction) {}

  ~SynchronizeOpKernel() = default;

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    WaitForOpStream(c);
    done();
  }

 private:
  // Synchronizes with the stream of `c`'s device context; compiles to an empty
  // function when neither GPU macro is defined.
  static void WaitForOpStream(OpKernelContext* c) {
#if defined(ANTARES_CUDA) || defined(ANTARES_ROCM)
    // The CHECK arguments are kept on one line: the macros stringify them into
    // the failure message.
    cudaStream_t cu_stream = *CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(c->op_device_context()->stream()->implementation()->GpuStreamMemberHack()));
    CHECK_EQ(cudaSuccess, cudaStreamSynchronize(cu_stream));
#endif
  }

  TF_DISALLOW_COPY_AND_ASSIGN(SynchronizeOpKernel);
};

// Bind the "Synchronize" op to its kernel class. Builds that define
// ANTARES_CUDA or ANTARES_ROCM register it for DEVICE_GPU; every other build
// registers it for DEVICE_CPU. Exactly one of the two registrations is compiled.
#if defined(ANTARES_CUDA) || defined(ANTARES_ROCM)
REGISTER_KERNEL_BUILDER(Name("Synchronize").Device(DEVICE_GPU), SynchronizeOpKernel<Eigen::GpuDevice>);
#else
REGISTER_KERNEL_BUILDER(Name("Synchronize").Device(DEVICE_CPU), SynchronizeOpKernel<Eigen::ThreadPoolDevice>);
#endif

// Interface of the "Synchronize" op: a single list input `tensor` holding N
// (N >= 1) tensors that share one element type T, where T is one of the types
// listed below. No outputs are declared.
// NOTE(review): SetIsStateful() is presumably there so graph optimizations do
// not prune or deduplicate this output-less op -- confirm against the
// TensorFlow op-registration documentation.
REGISTER_OP("Synchronize")
.Input("tensor: N * T")
.Attr("T: {float64, float32, float16, int32, int16, int8}")
.Attr("N: int >= 1")
.SetIsStateful();


/////////////////////////////////////////////////////////////////////////////////////
template <typename Device>
class MetricOpKernel: public AsyncOpKernel {
public:
Expand All @@ -269,6 +302,7 @@ class MetricOpKernel: public AsyncOpKernel {
}

void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
#if defined(ANTARES_CUDA) || defined(ANTARES_ROCM)
cudaStream_t cu_stream = *CHECK_NOTNULL(reinterpret_cast<const cudaStream_t*>(c->op_device_context()->stream()->implementation()->GpuStreamMemberHack()));

static cudaEvent_t lastMetricEvent = NULL;
Expand Down Expand Up @@ -310,26 +344,7 @@ class MetricOpKernel: public AsyncOpKernel {
lastMetricEvent = currMetricEvent;
}
pthread_mutex_unlock(&__g_lock);

done();
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(MetricOpKernel);
};

REGISTER_KERNEL_BUILDER(Name("Metric").Device(DEVICE_GPU), MetricOpKernel<Eigen::GpuDevice>);

#else

template <typename Device>
class MetricOpKernel: public AsyncOpKernel {
public:
explicit MetricOpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c) {
}

void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
static std::chrono::time_point<std::chrono::system_clock> lastMetricEvent;
static bool hasLastEvent = false;

Expand Down Expand Up @@ -361,14 +376,17 @@ class MetricOpKernel: public AsyncOpKernel {
hasLastEvent = true;
}
pthread_mutex_unlock(&__g_lock);

#endif
done();
}

private:
TF_DISALLOW_COPY_AND_ASSIGN(MetricOpKernel);
};

// Bind the "Metric" op to MetricOpKernel. Builds that define ANTARES_CUDA or
// ANTARES_ROCM register it for DEVICE_GPU; every other build registers it for
// DEVICE_CPU. Exactly one of the two registrations is compiled.
#if defined(ANTARES_CUDA) || defined(ANTARES_ROCM)
REGISTER_KERNEL_BUILDER(Name("Metric").Device(DEVICE_GPU), MetricOpKernel<Eigen::GpuDevice>);
#else
REGISTER_KERNEL_BUILDER(Name("Metric").Device(DEVICE_CPU), MetricOpKernel<Eigen::ThreadPoolDevice>);
#endif

Expand Down

0 comments on commit 35ae06b

Please sign in to comment.