-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2172 from helinwang/cgo
parameter client library: stub and cgo part with functional test.
- Loading branch information
Showing
11 changed files
with
584 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
cmake_minimum_required(VERSION 3.0) | ||
|
||
if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) | ||
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})") | ||
else() | ||
# find #include <majel/xx.h> | ||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) | ||
include_directories(${PARENT_DIR}) | ||
|
||
# find cmake directory modules | ||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) | ||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) | ||
|
||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") | ||
|
||
# enable c++11 | ||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | ||
|
||
# enable gtest | ||
set(THIRD_PARTY_PATH ./third_party) | ||
set(WITH_TESTING ON) | ||
include(external/gtest) | ||
endif() | ||
|
||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") | ||
|
||
project(cxx_go CXX C Go) | ||
|
||
include(cmake/golang.cmake) | ||
include(cmake/flags.cmake) | ||
|
||
ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver) | ||
add_go_library(client STATIC pserver) | ||
add_subdirectory(test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
package main | ||
|
||
/* | ||
#include <stdlib.h> | ||
#include <string.h> | ||
typedef enum { | ||
PADDLE_ELEMENT_TYPE_INT32 = 0, | ||
PADDLE_ELEMENT_TYPE_UINT32 = 1, | ||
PADDLE_ELEMENT_TYPE_INT64 = 2, | ||
PADDLE_ELEMENT_TYPE_UINT64 = 3, | ||
PADDLE_ELEMENT_TYPE_FLOAT32 = 4, | ||
PADDLE_ELEMENT_TYPE_FLOAT64 = 5, | ||
} paddle_element_type; | ||
typedef struct { | ||
char* name; | ||
paddle_element_type element_type; | ||
unsigned char* content; | ||
int content_len; | ||
} paddle_parameter, paddle_gradient; | ||
static inline void paddle_release_param(paddle_parameter* param) { | ||
if (param != NULL) { | ||
if (param->name != NULL) { | ||
free(param->name); | ||
} | ||
if (param->content != NULL) { | ||
free(param->content); | ||
} | ||
free(param); | ||
} | ||
} | ||
typedef int client; | ||
*/ | ||
import "C" | ||
|
||
import ( | ||
"log" | ||
"sync" | ||
"unsafe" | ||
|
||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" | ||
) | ||
|
||
var nullPtr = unsafe.Pointer(uintptr(0)) | ||
var mu sync.Mutex | ||
var handleMap = make(map[C.client]*pserver.Client) | ||
var curHandle C.client | ||
|
||
func add(c *pserver.Client) C.client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
client := curHandle | ||
curHandle++ | ||
handleMap[client] = c | ||
return client | ||
} | ||
|
||
func get(client C.client) *pserver.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
return handleMap[client] | ||
} | ||
|
||
func remove(client C.client) *pserver.Client { | ||
mu.Lock() | ||
defer mu.Unlock() | ||
h := handleMap[client] | ||
delete(handleMap, client) | ||
return h | ||
} | ||
|
||
func cArrayToSlice(p unsafe.Pointer, len int) []byte { | ||
if p == nullPtr { | ||
return nil | ||
} | ||
|
||
// create a Go clice backed by a C array, | ||
// reference: /~https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices | ||
return (*[1 << 30]byte)(p)[:len:len] | ||
} | ||
|
||
//export paddle_new_pserver_client | ||
func paddle_new_pserver_client(addr *C.char) C.client { | ||
c := pserver.NewClient(C.GoString(addr)) | ||
return add(c) | ||
} | ||
|
||
//export paddle_pserver_client_release | ||
func paddle_pserver_client_release(client C.client) { | ||
c := remove(client) | ||
c.Cleanup() | ||
} | ||
|
||
//export paddle_begin_init_params | ||
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int { | ||
c := get(client) | ||
b := cArrayToSlice(pserver_config, int(config_len)) | ||
selected, err := c.BeginInitParams(b) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
if selected { | ||
return 1 | ||
} | ||
return 0 | ||
} | ||
|
||
//export paddle_init_param | ||
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int { | ||
et := pserver.ElementType(param.element_type) | ||
name := C.GoString(param.name) | ||
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len)) | ||
pc := pserver.ParameterWithConfig{ | ||
Param: pserver.Parameter{Name: name, ElementType: et, Content: content}, | ||
Config: cArrayToSlice(param_config, int(config_len)), | ||
} | ||
c := get(client) | ||
err := c.InitParam(pc) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_finish_init_params | ||
func paddle_finish_init_params(client C.client) C.int { | ||
c := get(client) | ||
err := c.FinishInitParams() | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_send_grads | ||
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int { | ||
var gs []pserver.Gradient | ||
for i := 0; i < int(total); i++ { | ||
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads)))) | ||
et := pserver.ElementType(grad.element_type) | ||
name := C.GoString(grad.name) | ||
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len)) | ||
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content}) | ||
} | ||
|
||
c := get(client) | ||
err := c.SendGrads(gs) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_get_params | ||
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int { | ||
var ns []string | ||
for i := 0; i < int(total); i++ { | ||
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names)))) | ||
ns = append(ns, C.GoString(name)) | ||
} | ||
c := get(client) | ||
ps, err := c.GetParams(ns) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
for i := 0; i < int(total); i++ { | ||
if i >= len(ps) { | ||
break | ||
} | ||
|
||
p := ps[i] | ||
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) | ||
nameReady := false | ||
contentAllocated := false | ||
|
||
if unsafe.Pointer(param) == nullPtr { | ||
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param)))) | ||
} else { | ||
if unsafe.Pointer(param.name) != nullPtr { | ||
if n := C.GoString(param.name); n != p.Name { | ||
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name) | ||
C.free(unsafe.Pointer(param.name)) | ||
} else { | ||
nameReady = true | ||
} | ||
} | ||
|
||
if unsafe.Pointer(param.content) != nullPtr { | ||
if int(param.content_len) == len(p.Content) { | ||
contentAllocated = true | ||
} else { | ||
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content)) | ||
C.free(unsafe.Pointer(param.content)) | ||
} | ||
} | ||
} | ||
|
||
if !nameReady { | ||
param.name = C.CString(p.Name) | ||
} | ||
if !contentAllocated { | ||
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content)))) | ||
} | ||
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) | ||
param.content_len = C.int(len(p.Content)) | ||
param.element_type = C.paddle_element_type(p.ElementType) | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
//export paddle_save_model | ||
func paddle_save_model(client C.client, path *C.char) C.int { | ||
p := C.GoString(path) | ||
c := get(client) | ||
err := c.SaveModel(p) | ||
if err != nil { | ||
log.Println(err) | ||
return -1 | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
func main() {} // Required but ignored |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
if(NOT CMAKE_Go_COMPILER) | ||
if(NOT $ENV{GO_COMPILER} STREQUAL "") | ||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT) | ||
|
||
if(CMAKE_Go_FLAGS_ENV_INIT) | ||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler") | ||
endif() | ||
|
||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT}) | ||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.") | ||
endif() | ||
|
||
endif() | ||
|
||
set(Go_BIN_PATH | ||
$ENV{GOPATH} | ||
$ENV{GOROOT} | ||
$ENV{GOROOT}/../bin | ||
$ENV{GO_COMPILER} | ||
/usr/bin | ||
/usr/local/bin | ||
) | ||
|
||
if(CMAKE_Go_COMPILER_INIT) | ||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler") | ||
else() | ||
find_program(CMAKE_Go_COMPILER | ||
NAMES go | ||
PATHS ${Go_BIN_PATH} | ||
) | ||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION) | ||
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}") | ||
message("-- The Golang compiler identification is ${VERSION}") | ||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}") | ||
endif() | ||
|
||
endif() | ||
|
||
mark_as_advanced(CMAKE_Go_COMPILER) | ||
|
||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in | ||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) | ||
|
||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@") | ||
set(CMAKE_Go_COMPILER_LOADED 1) | ||
|
||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go) | ||
set(CMAKE_Go_LINKER_PREFERENCE 40) | ||
set(CMAKE_Go_OUTPUT_EXTENSION .o) | ||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1) | ||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
if(NOT CMAKE_Go_COMPILE_OBJECT) | ||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ") | ||
endif() | ||
|
||
if(NOT CMAKE_Go_LINK_EXECUTABLE) | ||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ") | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Setting Paddle Compile Flags | ||
include(CheckCXXCompilerFlag) | ||
include(CheckCCompilerFlag) | ||
include(CheckCXXSymbolExists) | ||
include(CheckTypeSize) | ||
|
||
function(CheckCompilerCXX11Flag) | ||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") | ||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) | ||
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") | ||
endif() | ||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") | ||
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" | ||
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. | ||
# https://gist.github.com/yamaya/2924292 | ||
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X | ||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1) | ||
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.") | ||
endif() | ||
else() | ||
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) | ||
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") | ||
endif() | ||
endif() | ||
endif() | ||
endfunction() | ||
|
||
CheckCompilerCXX11Flag() | ||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") | ||
|
||
# Common gpu architectures: Kepler, Maxwell | ||
foreach(capability 30 35 50) | ||
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}") | ||
endforeach() | ||
|
||
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0") | ||
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52") | ||
endif() | ||
|
||
# Modern gpu architectures: Pascal | ||
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") | ||
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") | ||
endif() | ||
|
||
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) |
Oops, something went wrong.