New ParameterUpdater: use the Paddle pserver cclient of Go #2413
Changes from 1 commit
@@ -42,7 +42,6 @@ import (
 	"strings"
 	"sync"
 	"unsafe"
-	"fmt"

 	"github.com/PaddlePaddle/Paddle/go/pserver"
 )
@@ -0,0 +1,114 @@
#include <stdio.h>
#include <stdlib.h>

#include "libpaddle_pserver_cclient.h"

typedef float real;

void fail() {
  // TODO(helin): fix: gtest using cmake is not working, using this
  // hacky way for now.
  printf("test failed.\n");
  exit(-1);
}

void print_parameter(paddle_gradient* param) {
  if (param == NULL) {
    printf("param is NULL!!\n");
  } else {
    printf("==== parameter ====\n");
    printf("name: %s\n", param->name);
    printf("content_len: %d\n", param->content_len);
    printf("content_type: %d\n", param->element_type);
    for (int i = 0; i < param->content_len / sizeof(real); ++i) {
      printf("%f ", ((float*)param->content)[i]);
    }
    printf("\n\n");
  }
}

int main() {
  char addr[] = "localhost:3000";
  client c = paddle_new_pserver_client(addr, 1);

  char* names[] = {"param_a", "param_b"};

  // Restart the whole init sequence from here if any init step fails.
retry:
  if (paddle_begin_init_params(c)) {
    paddle_parameter param;
    real param_content1[] = {0.1, 0.2, 0.3};
    param.element_type = PADDLE_ELEMENT_TYPE_FLOAT32;
    param.name = names[0];
    param.content = (unsigned char*)param_content1;
    param.content_len = 3 * sizeof(real);
    if (paddle_init_param(c, param, NULL, 0) != 0) {
      goto retry;
    }

    real param_content2[] = {0.4, 0.5, 0.6};
    param.element_type = PADDLE_ELEMENT_TYPE_INT32;
    param.name = names[1];
    param.content = (unsigned char*)param_content2;
    param.content_len = 3 * sizeof(real);
    if (paddle_init_param(c, param, NULL, 0) != 0) {
      goto retry;
    }

    if (paddle_finish_init_params(c) != 0) {
      goto retry;
    }
  } else {
    fail();
  }

  printf("get initialized parameters from pserver:\n");
  paddle_parameter* param_ptrs[2] = {NULL, NULL};
  if (paddle_get_params(c, names, param_ptrs, 2) != 0) {
    fail();
  }
  print_parameter(param_ptrs[0]);
  print_parameter(param_ptrs[1]);

  printf("send gradient to pserver:\n");
  real gradient_content1[] = {0.01, 0.02, 0.03};
  real gradient_content2[] = {0.04, 0.05, 0.06};

  paddle_gradient** grads =
      (paddle_gradient**)malloc(sizeof(paddle_gradient*) * 2);
  grads[0] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[0]->name = names[0];
  grads[0]->content = (unsigned char*)gradient_content1;
  grads[0]->content_len = 3 * sizeof(real);
  grads[0]->element_type = PADDLE_ELEMENT_TYPE_FLOAT32;

  grads[1] = (paddle_gradient*)malloc(sizeof(paddle_gradient));
  grads[1]->name = names[1];
  grads[1]->content = (unsigned char*)gradient_content2;
  grads[1]->content_len = 3 * sizeof(real);
  grads[1]->element_type = PADDLE_ELEMENT_TYPE_INT32;

  print_parameter(grads[0]);
  print_parameter(grads[1]);

  if (paddle_send_grads(c, grads, 2) != 0) {
    fail();
  }

  printf("get updated parameters from pserver:\n");
  // Get the parameters again, reusing the allocated parameter buffers.
  if (paddle_get_params(c, names, param_ptrs, 2) != 0) {
    fail();
  }

  print_parameter(param_ptrs[0]);
  print_parameter(param_ptrs[1]);

  paddle_release_param(param_ptrs[0]);
  paddle_release_param(param_ptrs[1]);

  if (paddle_save_model(c, "/tmp/") != 0) {
    fail();
  }

  printf("test success!\n");
Review comment: Let's follow the Unix way, do not print anything (with return code 0) when everything went well. If we print success on every test, the output would be very messy.

Reply: done, removed.
  return 0;
}
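Per that thread, the success message was dropped in a follow-up commit (the follow-up diff itself is not shown on this page). A minimal sketch of what that change would look like, leaving the test silent on success so that the exit status alone reports the result:

-  printf("test success!\n");
   return 0;
 }

A test harness then only needs to inspect the exit code, e.g. ./test_cclient && echo pass.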
@@ -32,7 +32,13 @@ int update_SGD(void* optimizer,
                const void* gradient,
                int num_bytes) {
   SGD_optimizer* o = (SGD_optimizer*)optimizer;
-  // TODO
+  // TODO(a simple SGD implement)
Review comment: A simple SGD implementation is already done, maybe remove this comment?

Reply: done, removed.
+  float* parameter = (float*)buffer;
+  float* grad = (float*)gradient;
+
+  for (int i = 0; i < num_bytes / sizeof(float); ++i) {
+    parameter[i] -= o->learning_rate * grad[i];
+  }
   return 0;
 }
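The hunk above is plain elementwise SGD: parameter[i] -= learning_rate * grad[i] over the float contents of the buffer. Below is a small self-contained sketch of the same update rule. Note that the SGD_optimizer definition here is a hypothetical stand-in (the diff only shows that the struct carries a learning_rate field), and sgd_step is an illustrative helper, not part of the Paddle sources.

#include <stdio.h>

/* Hypothetical stand-in: the real struct lives in the optimizer sources;
 * the diff only tells us it has a learning_rate member. */
typedef struct {
  double learning_rate;
} SGD_optimizer;

/* Same elementwise rule as update_SGD in the diff:
 * parameter[i] -= learning_rate * grad[i]. Taking the size in bytes
 * mirrors the num_bytes parameter of update_SGD. */
static void sgd_step(SGD_optimizer* o, float* parameter, const float* grad,
                     int num_bytes) {
  for (int i = 0; i < num_bytes / (int)sizeof(float); ++i) {
    parameter[i] -= (float)(o->learning_rate * grad[i]);
  }
}

int main() {
  SGD_optimizer opt = {0.1};
  float param[] = {0.1f, 0.2f, 0.3f};
  float grad[] = {0.01f, 0.02f, 0.03f};

  sgd_step(&opt, param, grad, (int)sizeof(param));

  for (int i = 0; i < 3; ++i) {
    printf("%f ", param[i]); /* expected: 0.099000 0.198000 0.297000 */
  }
  printf("\n");
  return 0;
}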