Skip to content

Commit

Permalink
Merge pull request #2468 from helinwang/master_dispatch
Browse files Browse the repository at this point in the history
Implement master client for reading training tasks
  • Loading branch information
helinwang authored Jun 19, 2017
2 parents 0f6bca2 + 4b6243c commit 1a12720
Show file tree
Hide file tree
Showing 18 changed files with 456 additions and 118 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ endif(WITH_GPU)

add_subdirectory(proto)
add_subdirectory(paddle)
add_subdirectory(go/master/c)
add_subdirectory(python)
add_subdirectory(go/pserver/cclient)

Expand Down
12 changes: 4 additions & 8 deletions go/cmake/golang.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,23 @@ function(GO_LIBRARY NAME BUILD_TYPE)

# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)

# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(copyPaddle
add_custom_target(${NAME}_copyPaddle
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(goGet copyPaddle)
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)

add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-gcflags=-shared -asmflags=-shared -installsuffix=_shared -a
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet)
add_dependencies(${NAME} ${NAME}_goGet)

if(NOT BUILD_TYPE STREQUAL "STATIC")
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
endif()
endfunction(GO_LIBRARY)
5 changes: 3 additions & 2 deletions go/connection/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package connection

import (
"errors"
"log"
"net/rpc"
"sync"

log "github.com/sirupsen/logrus"
)

// TODO(helin): add TCP re-connect logic
Expand Down Expand Up @@ -65,7 +66,7 @@ func (c *Conn) Connect(addr string) error {
} else {
err := client.Close()
if err != nil {
log.Println(err)
log.Errorln(err)
}

return errors.New("client already set from a concurrent goroutine")
Expand Down
21 changes: 21 additions & 0 deletions go/master/c/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.0)

get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")

project(cxx_go C Go)

include(golang)
include(flags)

set(MASTER_LIB_NAME "paddle_master")
go_library(${MASTER_LIB_NAME} SHARED)

if(PROJ_ROOT)
add_custom_command(OUTPUT ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so
COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.h
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.so ${PROJ_ROOT}/python/paddle/v2/master/
DEPENDS ${MASTER_LIB_NAME})
add_custom_target(paddle_master_shared ALL DEPENDS ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so)
endif(PROJ_ROOT)
110 changes: 110 additions & 0 deletions go/master/c/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package main

/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
typedef int paddle_master_client;
*/
import "C"

import (
"sync"
"unsafe"

"github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus"
)

var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client

func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}

func get(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}

func remove(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}

type addresser string

func (a addresser) Address() string {
return string(a)
}

//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize)
return add(c)
}

//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}

//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}

return C.PADDLE_MASTER_OK
}

//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r := c.NextRecord()
if len(r) == 0 {
*record = (*C.uchar)(nullPtr)
return 0
}

size := C.size_t(len(r))
*record = (*C.uchar)(C.malloc(size))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
return C.int(size)
}

//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C.free(p)
}

func main() {}
69 changes: 62 additions & 7 deletions go/master/client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package master

import (
"log"
"os"
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus"
)

// Addresser provide the address of the master server.
Expand All @@ -15,16 +17,61 @@ type Addresser interface {
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan []byte
}

// NewClient creates a new Client.
func NewClient(addr Addresser) *Client {
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr)
go c.getRecords()
return c
}

func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// TODO(helin): wait before move on with next
// getTask call.
log.Errorln(err)
continue
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
continue
}

s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() {
c.ch <- s.Record()
}

if s.Err() != nil {
log.Errorln(err, chunk.Path)
}

err = f.Close()
if err != nil {
log.Errorln(err)
}
}

// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
}
}

func (c *Client) monitorMaster(addr Addresser) {
lastMaster := ""
monitor := func() {
Expand All @@ -35,12 +82,12 @@ func (c *Client) monitorMaster(addr Addresser) {
if curMaster == "" {
err := c.conn.Close()
if err != nil {
log.Println(err)
log.Errorln(err)
}
} else {
err := c.conn.Connect(curMaster)
if err != nil {
log.Println(err)
log.Errorln(err)

// connect to addr failed, set
// to last known addr in order
Expand Down Expand Up @@ -69,14 +116,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
}

// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
return t, err
}

// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
}

// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() []byte {
return <-c.ch
}
Loading

0 comments on commit 1a12720

Please sign in to comment.