Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parameter client library: stub and cgo part with functional test. #2172

Merged
merged 14 commits into from
May 18, 2017
34 changes: 34 additions & 0 deletions paddle/go/cclient/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.0)

if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})")
else()
# find #include <majel/xx.h>
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
include_directories(${PARENT_DIR})

# find cmake directory modules
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")

# enable c++11
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# enable gtest
set(THIRD_PARTY_PATH ./third_party)
set(WITH_TESTING ON)
include(external/gtest)
endif()

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

project(cxx_go CXX C Go)

include(cmake/golang.cmake)
include(cmake/flags.cmake)

ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver)
add_go_library(client STATIC pserver)
add_subdirectory(test)
239 changes: 239 additions & 0 deletions paddle/go/cclient/cclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package main

/*
#include <stdlib.h>
#include <string.h>
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;

typedef struct {
char* name;
paddle_element_type element_type;
unsigned char* content;
int content_len;
} paddle_parameter, paddle_gradient;

static inline void paddle_release_param(paddle_parameter* param) {
if (param != NULL) {
if (param->name != NULL) {
free(param->name);
}

if (param->content != NULL) {
free(param->content);
}

free(param);
}
}

typedef int client;
*/
import "C"

import (
"log"
"sync"
"unsafe"

"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client)
var curHandle C.client

func add(c *pserver.Client) C.client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}

func get(client C.client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}

func remove(client C.client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}

func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}

// create a Go clice backed by a C array,
// reference: /~https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
return (*[1 << 30]byte)(p)[:len:len]
}

//export paddle_new_pserver_client
func paddle_new_pserver_client(addr *C.char) C.client {
c := pserver.NewClient(C.GoString(addr))
return add(c)
}

//export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) {
c := remove(client)
c.Cleanup()
}

//export paddle_begin_init_params
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int {
c := get(client)
b := cArrayToSlice(pserver_config, int(config_len))
selected, err := c.BeginInitParams(b)
if err != nil {
log.Println(err)
return -1
}

if selected {
return 1
}
return 0
}

//export paddle_init_param
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)),
}
c := get(client)
err := c.InitParam(pc)
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_finish_init_params
func paddle_finish_init_params(client C.client) C.int {
c := get(client)
err := c.FinishInitParams()
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
}

c := get(client)
err := c.SendGrads(gs)
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_get_params
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
var ns []string
for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
ns = append(ns, C.GoString(name))
}
c := get(client)
ps, err := c.GetParams(ns)
if err != nil {
log.Println(err)
return -1
}

for i := 0; i < int(total); i++ {
if i >= len(ps) {
break
}

p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
nameReady := false
contentAllocated := false

if unsafe.Pointer(param) == nullPtr {
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
} else {
if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
} else {
nameReady = true
}
}

if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) == len(p.Content) {
contentAllocated = true
} else {
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(unsafe.Pointer(param.content))
}
}
}

if !nameReady {
param.name = C.CString(p.Name)
}
if !contentAllocated {
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType)
}

return 0
}

//export paddle_save_model
func paddle_save_model(client C.client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.SaveModel(p)
if err != nil {
log.Println(err)
return -1
}

return 0
}

func main() {} // Required but ignored
44 changes: 44 additions & 0 deletions paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
if(NOT CMAKE_Go_COMPILER)
if(NOT $ENV{GO_COMPILER} STREQUAL "")
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)

if(CMAKE_Go_FLAGS_ENV_INIT)
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
endif()

if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
endif()

endif()

set(Go_BIN_PATH
$ENV{GOPATH}
$ENV{GOROOT}
$ENV{GOROOT}/../bin
$ENV{GO_COMPILER}
/usr/bin
/usr/local/bin
)

if(CMAKE_Go_COMPILER_INIT)
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
else()
find_program(CMAKE_Go_COMPILER
NAMES go
PATHS ${Go_BIN_PATH}
)
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
message("-- The Golang compiler identification is ${VERSION}")
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
endif()

endif()

mark_as_advanced(CMAKE_Go_COMPILER)

configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)

set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
8 changes: 8 additions & 0 deletions paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
set(CMAKE_Go_COMPILER_LOADED 1)

set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
set(CMAKE_Go_LINKER_PREFERENCE 40)
set(CMAKE_Go_OUTPUT_EXTENSION .o)
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
7 changes: 7 additions & 0 deletions paddle/go/cclient/cmake/CMakeGoInformation.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
if(NOT CMAKE_Go_COMPILE_OBJECT)
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
endif()

if(NOT CMAKE_Go_LINK_EXECUTABLE)
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
endif()
1 change: 1 addition & 0 deletions paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
45 changes: 45 additions & 0 deletions paddle/go/cclient/cmake/flags.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Setting Paddle Compile Flags
include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckCXXSymbolExists)
include(CheckTypeSize)

function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
# https://gist.github.com/yamaya/2924292
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif()
endif()
endif()
endfunction()

CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# Common gpu architectures: Kepler, Maxwell
foreach(capability 30 35 50)
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()

if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
endif()

# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
endif()

set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
Loading