-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
parameter client library: stub and cgo part with functional test. #2172
Changes from 1 commit
41780ae
6011b2e
a5091ef
e21d56e
d08c8ea
1c908df
61a49e4
4db7c54
c17cef9
4759b37
7ee280e
f71453b
8774473
478ebc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package pserver | ||
|
||
// ElementType is the type of elements of a Parameter. | ||
type ElementType int | ||
|
||
// Supported element types | ||
const ( | ||
Int32 ElementType = iota | ||
UInt32 | ||
Int64 | ||
UInt64 | ||
Float32 | ||
Float64 | ||
) | ||
|
||
type Parameter struct { | ||
Name string | ||
ElementType ElementType | ||
Content []byte | ||
} | ||
|
||
type ParameterWithConfig struct { | ||
Param Parameter | ||
Config []byte | ||
} | ||
|
||
type Gradient Parameter | ||
|
||
type Client struct { | ||
} | ||
|
||
func NewClient(addr string) *Client { | ||
return &Client{} | ||
} | ||
|
||
func (c *Client) BeginInitParams(pserverConfigProto []byte) (bool, error) { | ||
return true, nil | ||
} | ||
|
||
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { | ||
return nil | ||
} | ||
|
||
func (c *Client) FinishInitParams() error { | ||
return nil | ||
} | ||
|
||
func (c *Client) SendGrads(grads []Gradient) error { | ||
return nil | ||
} | ||
|
||
func (c *Client) GetParams(names []string) ([]Parameter, error) { | ||
return nil, nil | ||
} | ||
|
||
func (c *Client) SaveModel(path string) error { | ||
return nil | ||
} | ||
|
||
func (c *Client) Cleanup() { | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
cmake_minimum_required(VERSION 3.0) | ||
|
||
if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) | ||
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})") | ||
else() | ||
# find #include <majel/xx.h> | ||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) | ||
include_directories(${PARENT_DIR}) | ||
|
||
# find cmake directory modules | ||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) | ||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) | ||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) | ||
|
||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") | ||
|
||
# enable c++11 | ||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | ||
|
||
# enable gtest | ||
set(THIRD_PARTY_PATH ./third_party) | ||
set(WITH_TESTING ON) | ||
include(external/gtest) | ||
endif() | ||
|
||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") | ||
|
||
project(cxx_go CXX C Go) | ||
|
||
include(cmake/golang.cmake) | ||
include(cmake/flags.cmake) | ||
|
||
add_subdirectory(client) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
client.h | ||
client.a |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
cmake_minimum_required(VERSION 3.0) | ||
|
||
ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver) | ||
add_go_library(client STATIC pserver) | ||
add_subdirectory(test) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
package main | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking this need to be integrated into Paddle C++ code (a lot of pointers are used, not sure how much does CPython supports pointer), maybe I am wrong. I think we will know better when we begin to integrate. |
||
|
||
/* | ||
#include <stdlib.h> | ||
#include <string.h> | ||
typedef enum { | ||
PADDLE_ELEMENT_TYPE_INT32 = 0, | ||
PADDLE_ELEMENT_TYPE_UINT32 = 1, | ||
PADDLE_ELEMENT_TYPE_INT64 = 2, | ||
PADDLE_ELEMENT_TYPE_UINT64 = 3, | ||
PADDLE_ELEMENT_TYPE_FLOAT32 = 4, | ||
PADDLE_ELEMENT_TYPE_FLOAT64 = 5, | ||
} paddle_element_type; | ||
|
||
typedef struct { | ||
char* name; | ||
paddle_element_type element_type; | ||
void* content; | ||
int content_len; | ||
} paddle_parameter, paddle_gradient; | ||
|
||
typedef int client; | ||
*/ | ||
import "C" | ||
|
||
import ( | ||
"log" | ||
"sync" | ||
"unsafe" | ||
|
||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" | ||
) | ||
|
||
const ( | ||
ptrSize = unsafe.Sizeof(uintptr(0)) | ||
) | ||
|
||
var mu sync.Mutex | ||
var handleMap = make(map[C.client]*pserver.Client) | ||
var curHandle C.client | ||
|
||
func add(c *pserver.Client) C.client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
client := curHandle | ||
curHandle++ | ||
handleMap[client] = c | ||
return client | ||
} | ||
|
||
func get(client C.client) *pserver.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
return handleMap[client] | ||
} | ||
|
||
func remove(client C.client) *pserver.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
h := handleMap[client] | ||
delete(handleMap, client) | ||
return h | ||
} | ||
|
||
func cArrayToSlice(p unsafe.Pointer, len int) []byte { | ||
// create a Go clice backed by a C array, | ||
// reference: /~https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices | ||
return (*[1 << 30]byte)(p)[:len:len] | ||
} | ||
|
||
//export paddle_new_pserver_client | ||
func paddle_new_pserver_client(addr *C.char) C.client { | ||
c := pserver.NewClient(C.GoString(addr)) | ||
return add(c) | ||
} | ||
|
||
//export paddle_pserver_client_release | ||
func paddle_pserver_client_release(client C.client) { | ||
c := remove(client) | ||
c.Cleanup() | ||
} | ||
|
||
//export paddle_begin_init_params | ||
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int { | ||
c := get(client) | ||
b := cArrayToSlice(pserver_config, int(config_len)) | ||
selected, err := c.BeginInitParams(b) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
if selected { | ||
return 1 | ||
} | ||
return 0 | ||
} | ||
|
||
//export paddle_init_param | ||
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { | ||
et := pserver.ElementType(param.element_type) | ||
name := C.GoString(param.name) | ||
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) | ||
pc := pserver.ParameterWithConfig{ | ||
Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, | ||
Config: cArrayToSlice(param_config, int(config_len)), | ||
} | ||
c := get(client) | ||
err := c.InitParam(pc) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_finish_init_params | ||
func paddle_finish_init_params(client C.client) C.int { | ||
c := get(client) | ||
err := c.FinishInitParams() | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_send_grads | ||
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { | ||
var gs []pserver.Gradient | ||
for i := 0; i < int(total); i++ { | ||
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*ptrSize))) | ||
et := pserver.ElementType(grad.element_type) | ||
name := C.GoString(grad.name) | ||
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) | ||
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content}) | ||
} | ||
|
||
c := get(client) | ||
err := c.SendGrads(gs) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_get_params | ||
func paddle_get_params(client C.client, names **C.char, dst *C.paddle_parameter, total C.int) C.int { | ||
var ns []string | ||
for i := 0; i < int(total); i++ { | ||
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*ptrSize))) | ||
ns = append(ns, C.GoString(name)) | ||
} | ||
c := get(client) | ||
ps, err := c.GetParams(ns) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
for i := 0; i < int(total); i++ { | ||
if i >= len(ps) { | ||
break | ||
} | ||
|
||
p := ps[i] | ||
name := C.CString(p.Name) | ||
param := (*C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*ptrSize))) | ||
|
||
if unsafe.Pointer(param.name) != unsafe.Pointer(uintptr(0)) { | ||
C.free(unsafe.Pointer(param.name)) | ||
} | ||
param.name = name | ||
|
||
memReady := false | ||
if param.content != unsafe.Pointer(uintptr(0)) { | ||
if int(param.content_len) == len(p.Content) { | ||
memReady = true | ||
} else { | ||
C.free(param.content) | ||
} | ||
} | ||
|
||
if !memReady { | ||
param.content = C.malloc(C.size_t(len(p.Content))) | ||
} | ||
C.memcpy(param.content, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) | ||
param.content_len = C.int(len(p.Content)) | ||
param.element_type = C.paddle_element_type(p.ElementType) | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_save_model | ||
func paddle_save_model(client C.client, path *C.char) C.int { | ||
p := C.GoString(path) | ||
c := get(client) | ||
err := c.SaveModel(p) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
func main() {} // Required but ignored |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
cmake_minimum_required(VERSION 3.0) | ||
|
||
include_directories(/env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/pserver/lib/build/client) | ||
|
||
add_executable(main main.c) | ||
add_dependencies(main client) | ||
set (CMAKE_EXE_LINKER_FLAGS "-pthread") | ||
target_link_libraries(main /env/gopath/src/github.com/PaddlePaddle/Paddle/paddle/go/pserver/lib/build/client/libclient.a ${GTEST_LIBRARIES}) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#include "libclient.h" | ||
|
||
#include "gtest/gtest.h" | ||
|
||
int main() { | ||
client c = paddle_new_pserver_client(NULL); | ||
return 0; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
if(NOT CMAKE_Go_COMPILER) | ||
if(NOT $ENV{GO_COMPILER} STREQUAL "") | ||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT) | ||
|
||
if(CMAKE_Go_FLAGS_ENV_INIT) | ||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler") | ||
endif() | ||
|
||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT}) | ||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.") | ||
endif() | ||
|
||
endif() | ||
|
||
set(Go_BIN_PATH | ||
$ENV{GOPATH} | ||
$ENV{GOROOT} | ||
$ENV{GOROOT}/../bin | ||
$ENV{GO_COMPILER} | ||
/usr/bin | ||
/usr/local/bin | ||
) | ||
|
||
if(CMAKE_Go_COMPILER_INIT) | ||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler") | ||
else() | ||
find_program(CMAKE_Go_COMPILER | ||
NAMES go | ||
PATHS ${Go_BIN_PATH} | ||
) | ||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION) | ||
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}") | ||
message("-- The Golang compiler identification is ${VERSION}") | ||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}") | ||
endif() | ||
|
||
endif() | ||
|
||
mark_as_advanced(CMAKE_Go_COMPILER) | ||
|
||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in | ||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) | ||
|
||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@") | ||
set(CMAKE_Go_COMPILER_LOADED 1) | ||
|
||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go) | ||
set(CMAKE_Go_LINKER_PREFERENCE 40) | ||
set(CMAKE_Go_OUTPUT_EXTENSION .o) | ||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1) | ||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
if(NOT CMAKE_Go_COMPILE_OBJECT) | ||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ") | ||
endif() | ||
|
||
if(NOT CMAKE_Go_LINK_EXECUTABLE) | ||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ") | ||
endif() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to define the ParameterBlock type(partitioned Parameter ) and put these things in a Constant header file? These things would be a shared concept in many files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite follow what is "ParameterBlock type", do you mean ElementType?