Skip to content

Commit

Permalink
Merge pull request #2172 from helinwang/cgo
Browse files Browse the repository at this point in the history
parameter client library: stub and cgo part with functional test.
  • Loading branch information
helinwang authored May 18, 2017
2 parents 7377c59 + 478ebc4 commit 501b59a
Show file tree
Hide file tree
Showing 11 changed files with 584 additions and 0 deletions.
34 changes: 34 additions & 0 deletions paddle/go/cclient/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.0)

if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
message("-- Found gtest (include: ${GTEST_INCLUDE_DIR}, library: ${GTEST_LIBRARIES})")
else()
# find #include <majel/xx.h>
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
include_directories(${PARENT_DIR})

# find cmake directory modules
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")

# enable c++11
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# enable gtest
set(THIRD_PARTY_PATH ./third_party)
set(WITH_TESTING ON)
include(external/gtest)
endif()

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

project(cxx_go CXX C Go)

include(cmake/golang.cmake)
include(cmake/flags.cmake)

ExternalGoProject_Add(pserver github.com/PaddlePaddle/Paddle/paddle/go/pserver)
add_go_library(client STATIC pserver)
add_subdirectory(test)
239 changes: 239 additions & 0 deletions paddle/go/cclient/cclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package main

/*
#include <stdlib.h>
#include <string.h>
typedef enum {
PADDLE_ELEMENT_TYPE_INT32 = 0,
PADDLE_ELEMENT_TYPE_UINT32 = 1,
PADDLE_ELEMENT_TYPE_INT64 = 2,
PADDLE_ELEMENT_TYPE_UINT64 = 3,
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
} paddle_element_type;
typedef struct {
char* name;
paddle_element_type element_type;
unsigned char* content;
int content_len;
} paddle_parameter, paddle_gradient;
static inline void paddle_release_param(paddle_parameter* param) {
if (param != NULL) {
if (param->name != NULL) {
free(param->name);
}
if (param->content != NULL) {
free(param->content);
}
free(param);
}
}
typedef int client;
*/
import "C"

import (
"log"
"sync"
"unsafe"

"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)

var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.client]*pserver.Client)
var curHandle C.client

func add(c *pserver.Client) C.client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}

func get(client C.client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}

func remove(client C.client) *pserver.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}

func cArrayToSlice(p unsafe.Pointer, len int) []byte {
if p == nullPtr {
return nil
}

// create a Go clice backed by a C array,
// reference: /~https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
return (*[1 << 30]byte)(p)[:len:len]
}

//export paddle_new_pserver_client
func paddle_new_pserver_client(addr *C.char) C.client {
c := pserver.NewClient(C.GoString(addr))
return add(c)
}

//export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) {
c := remove(client)
c.Cleanup()
}

//export paddle_begin_init_params
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int {
c := get(client)
b := cArrayToSlice(pserver_config, int(config_len))
selected, err := c.BeginInitParams(b)
if err != nil {
log.Println(err)
return -1
}

if selected {
return 1
}
return 0
}

//export paddle_init_param
func paddle_init_param(client C.client, param C.paddle_parameter, param_config unsafe.Pointer, config_len C.int) C.int {
et := pserver.ElementType(param.element_type)
name := C.GoString(param.name)
content := cArrayToSlice(unsafe.Pointer(param.content), int(param.content_len))
pc := pserver.ParameterWithConfig{
Param: pserver.Parameter{Name: name, ElementType: et, Content: content},
Config: cArrayToSlice(param_config, int(config_len)),
}
c := get(client)
err := c.InitParam(pc)
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_finish_init_params
func paddle_finish_init_params(client C.client) C.int {
c := get(client)
err := c.FinishInitParams()
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_send_grads
func paddle_send_grads(client C.client, grads *C.paddle_gradient, total C.int) C.int {
var gs []pserver.Gradient
for i := 0; i < int(total); i++ {
grad := (*C.paddle_gradient)(unsafe.Pointer((uintptr(unsafe.Pointer(grads)) + uintptr(i)*unsafe.Sizeof(*grads))))
et := pserver.ElementType(grad.element_type)
name := C.GoString(grad.name)
content := cArrayToSlice(unsafe.Pointer(grad.content), int(grad.content_len))
gs = append(gs, pserver.Gradient{Name: name, ElementType: et, Content: content})
}

c := get(client)
err := c.SendGrads(gs)
if err != nil {
log.Println(err)
return -1
}

return 0
}

//export paddle_get_params
func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter, total C.int) C.int {
var ns []string
for i := 0; i < int(total); i++ {
name := *(**C.char)(unsafe.Pointer((uintptr(unsafe.Pointer(names)) + uintptr(i)*unsafe.Sizeof(*names))))
ns = append(ns, C.GoString(name))
}
c := get(client)
ps, err := c.GetParams(ns)
if err != nil {
log.Println(err)
return -1
}

for i := 0; i < int(total); i++ {
if i >= len(ps) {
break
}

p := ps[i]
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
nameReady := false
contentAllocated := false

if unsafe.Pointer(param) == nullPtr {
param = (*C.paddle_parameter)(C.calloc(1, C.size_t(unsafe.Sizeof(*param))))
} else {
if unsafe.Pointer(param.name) != nullPtr {
if n := C.GoString(param.name); n != p.Name {
log.Println("Warning: the pre-allocated parameter name does not match the parameter name, it will be freed.", n, p.Name)
C.free(unsafe.Pointer(param.name))
} else {
nameReady = true
}
}

if unsafe.Pointer(param.content) != nullPtr {
if int(param.content_len) == len(p.Content) {
contentAllocated = true
} else {
log.Println("Warning: the pre-allocated content len does not match parameter content len, the pre-allocated content will be freed.", param.content_len, len(p.Content))
C.free(unsafe.Pointer(param.content))
}
}
}

if !nameReady {
param.name = C.CString(p.Name)
}
if !contentAllocated {
param.content = (*C.uchar)(C.malloc(C.size_t(len(p.Content))))
}
C.memcpy(unsafe.Pointer(param.content), unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
param.content_len = C.int(len(p.Content))
param.element_type = C.paddle_element_type(p.ElementType)
}

return 0
}

//export paddle_save_model
func paddle_save_model(client C.client, path *C.char) C.int {
p := C.GoString(path)
c := get(client)
err := c.SaveModel(p)
if err != nil {
log.Println(err)
return -1
}

return 0
}

func main() {} // Required but ignored
44 changes: 44 additions & 0 deletions paddle/go/cclient/cmake/CMakeDetermineGoCompiler.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
if(NOT CMAKE_Go_COMPILER)
if(NOT $ENV{GO_COMPILER} STREQUAL "")
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)

if(CMAKE_Go_FLAGS_ENV_INIT)
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
endif()

if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
endif()

endif()

set(Go_BIN_PATH
$ENV{GOPATH}
$ENV{GOROOT}
$ENV{GOROOT}/../bin
$ENV{GO_COMPILER}
/usr/bin
/usr/local/bin
)

if(CMAKE_Go_COMPILER_INIT)
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
else()
find_program(CMAKE_Go_COMPILER
NAMES go
PATHS ${Go_BIN_PATH}
)
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
message("-- The Golang compiler identification is ${VERSION}")
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
endif()

endif()

mark_as_advanced(CMAKE_Go_COMPILER)

configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)

set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
8 changes: 8 additions & 0 deletions paddle/go/cclient/cmake/CMakeGoCompiler.cmake.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
set(CMAKE_Go_COMPILER_LOADED 1)

set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
set(CMAKE_Go_LINKER_PREFERENCE 40)
set(CMAKE_Go_OUTPUT_EXTENSION .o)
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
7 changes: 7 additions & 0 deletions paddle/go/cclient/cmake/CMakeGoInformation.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
if(NOT CMAKE_Go_COMPILE_OBJECT)
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
endif()

if(NOT CMAKE_Go_LINK_EXECUTABLE)
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
endif()
1 change: 1 addition & 0 deletions paddle/go/cclient/cmake/CMakeTestGoCompiler.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
45 changes: 45 additions & 0 deletions paddle/go/cclient/cmake/flags.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Setting Paddle Compile Flags
include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckCXXSymbolExists)
include(CheckTypeSize)

function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
# https://gist.github.com/yamaya/2924292
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif()
endif()
endif()
endfunction()

CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")

# Common gpu architectures: Kepler, Maxwell
foreach(capability 30 35 50)
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()

if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
endif()

# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
endif()

set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
Loading

0 comments on commit 501b59a

Please sign in to comment.