Skip to content

Commit

Permalink
Support for user functions. Fixes TryGhost#140.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbyoung committed Jun 30, 2015
1 parent 1127c27 commit 19d7375
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 0 deletions.
153 changes: 153 additions & 0 deletions src/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "database.h"
#include "statement.h"

#ifndef SQLITE_DETERMINISTIC
#define SQLITE_DETERMINISTIC 0x800
#endif

using namespace node_sqlite3;

Persistent<FunctionTemplate> Database::constructor_template;
Expand All @@ -24,6 +28,7 @@ void Database::Init(Handle<Object> target) {
NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize);
NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize);
NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure);
NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction);

NODE_SET_GETTER(t, "open", OpenGetter);

Expand Down Expand Up @@ -356,6 +361,154 @@ NAN_METHOD(Database::Configure) {
NanReturnValue(args.This());
}

NAN_METHOD(Database::RegisterFunction) {
NanScope();
Database* db = ObjectWrap::Unwrap<Database>(args.This());

REQUIRE_ARGUMENTS(2);
REQUIRE_ARGUMENT_STRING(0, functionName);
REQUIRE_ARGUMENT_FUNCTION(1, callback);

FunctionBaton *baton = new FunctionBaton(db, *functionName, callback);
sqlite3_create_function(
db->_handle,
*functionName,
-1, // arbitrary number of args
SQLITE_UTF8 | SQLITE_DETERMINISTIC,
baton,
FunctionEnqueue,
NULL,
NULL);

uv_mutex_init(&baton->mutex);
uv_cond_init(&baton->condition);
uv_async_init(uv_default_loop(), &baton->async, (uv_async_cb)Database::AsyncFunctionProcessQueue);

NanReturnValue(args.This());
}

void Database::FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv) {
// the JS function can only be safely executed on the main thread
// (uv_default_loop), so setup an invocation w/ the relevant information,
// enqueue it and signal the main thread to process the invocation queue.
// sqlite3 requires the result to be set before this function returns, so
// wait for the invocation to be completed.
FunctionBaton *baton = (FunctionBaton *)sqlite3_user_data(context);
FunctionInvocation invocation = {};
invocation.context = context;
invocation.argc = argc;
invocation.argv = argv;

uv_async_send(&baton->async);
uv_mutex_lock(&baton->mutex);
baton->queue.push(&invocation);
while (!invocation.complete) {
uv_cond_wait(&baton->condition, &baton->mutex);
}
uv_mutex_unlock(&baton->mutex);
}

void Database::AsyncFunctionProcessQueue(uv_async_t *async) {
FunctionBaton *baton = (FunctionBaton *)async->data;

for (;;) {
FunctionInvocation *invocation = NULL;

uv_mutex_lock(&baton->mutex);
if (!baton->queue.empty()) {
invocation = baton->queue.front();
baton->queue.pop();
}
uv_mutex_unlock(&baton->mutex);

if (!invocation) { break; }

Database::FunctionExecute(baton, invocation);

uv_mutex_lock(&baton->mutex);
invocation->complete = true;
uv_cond_signal(&baton->condition); // allow paused thread to complete
uv_mutex_unlock(&baton->mutex);
}
}

void Database::FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation) {
NanScope();

Database *db = baton->db;
Local<Function> cb = NanNew(baton->callback);
sqlite3_context *context = invocation->context;
sqlite3_value **values = invocation->argv;
int argc = invocation->argc;

if (!cb.IsEmpty() && cb->IsFunction()) {

// build the argument list for the function call
typedef Local<Value> LocalValue;
std::vector<LocalValue> argv;
for (int i = 0; i < argc; i++) {
sqlite3_value *value = values[i];
int type = sqlite3_value_type(value);
Local<Value> arg;
switch(type) {
case SQLITE_INTEGER: {
arg = NanNew<Number>(sqlite3_value_int64(value));
} break;
case SQLITE_FLOAT: {
arg = NanNew<Number>(sqlite3_value_double(value));
} break;
case SQLITE_TEXT: {
const char* text = (const char*)sqlite3_value_text(value);
int length = sqlite3_value_bytes(value);
arg = NanNew<String>(text, length);
} break;
case SQLITE_BLOB: {
const void *blob = sqlite3_value_blob(value);
int length = sqlite3_value_bytes(value);
arg = NanNew(NanNewBufferHandle((char *)blob, length));
} break;
case SQLITE_NULL: {
arg = NanNew(NanNull());
} break;
}

argv.push_back(arg);
}

Local<Value> result = TRY_CATCH_CALL(NanObjectWrapHandle(db), cb, argc, argv.data());

// process the result
if (result->IsString() || result->IsRegExp()) {
String::Utf8Value value(result->ToString());
sqlite3_result_text(context, *value, value.length(), SQLITE_TRANSIENT);
}
else if (result->IsInt32()) {
sqlite3_result_int(context, result->Int32Value());
}
else if (result->IsNumber() || result->IsDate()) {
sqlite3_result_double(context, result->NumberValue());
}
else if (result->IsBoolean()) {
sqlite3_result_int(context, result->BooleanValue());
}
else if (result->IsNull() || result->IsUndefined()) {
sqlite3_result_null(context);
}
else if (Buffer::HasInstance(result)) {
Local<Object> buffer = result->ToObject();
sqlite3_result_blob(context,
Buffer::Data(buffer),
Buffer::Length(buffer),
SQLITE_TRANSIENT);
}
else {
std::string message("invalid return type in user function");
message = message + " " + baton->name;
sqlite3_result_error(context, message.c_str(), message.length());
}
}
}

void Database::SetBusyTimeout(Baton* baton) {
assert(baton->db->open);
assert(baton->db->_handle);
Expand Down
31 changes: 31 additions & 0 deletions src/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@ class Database : public ObjectWrap {
Baton(db_, cb_), filename(filename_) {}
};

struct FunctionInvocation {
sqlite3_context *context;
sqlite3_value **argv;
int argc;
bool complete;
};

struct FunctionBaton {
Database* db;
std::string name;
Persistent<Function> callback;
uv_async_t async;
uv_mutex_t mutex;
uv_cond_t condition;
std::queue<FunctionInvocation*> queue;

FunctionBaton(Database* db_, const char* name_, Handle<Function> cb_) :
db(db_), name(name_) {
async.data = this;
NanAssignPersistent(callback, cb_);
}
virtual ~FunctionBaton() {
NanDisposePersistent(callback);
}
};

typedef void (*Work_Callback)(Baton* baton);

struct Call {
Expand Down Expand Up @@ -152,6 +178,11 @@ class Database : public ObjectWrap {

static NAN_METHOD(Configure);

static NAN_METHOD(RegisterFunction);
static void FunctionEnqueue(sqlite3_context *context, int argc, sqlite3_value **argv);
static void FunctionExecute(FunctionBaton *baton, FunctionInvocation *invocation);
static void AsyncFunctionProcessQueue(uv_async_t *async);

static void SetBusyTimeout(Baton* baton);

static void RegisterTraceCallback(Baton* baton);
Expand Down
107 changes: 107 additions & 0 deletions test/user_functions.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
var sqlite3 = require('..');
var assert = require('assert');

describe('user functions', function() {
var db;
before(function(done) { db = new sqlite3.Database(':memory:', done); });

it('should allow registration of user functions', function() {
db.registerFunction('MY_UPPERCASE', function(value) {
return value.toUpperCase();
});
db.registerFunction('MY_STRING_JOIN', function(value1, value2) {
return [value1, value2].join(' ');
});
db.registerFunction('MY_Add', function(value1, value2) {
return value1 + value2;
});
db.registerFunction('MY_REGEX', function(regex, value) {
return !!value.match(new RegExp(regex));
});
db.registerFunction('MY_REGEX_VALUE', function(regex, value) {
return /match things/i;
});
db.registerFunction('MY_ERROR', function(value) {
throw new Error('This function always throws');
});
db.registerFunction('MY_UNHANDLED_TYPE', function(value) {
return {};
});
db.registerFunction('MY_NOTHING', function(value) {

});
});

it('should process user functions with one arg', function(done) {
db.all('SELECT MY_UPPERCASE("hello") AS txt', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].txt, 'HELLO')
done();
});
});

it('should process user functions with two args', function(done) {
db.all('SELECT MY_STRING_JOIN("hello", "world") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 'hello world');
done();
});
});

it('should process user functions with number args', function(done) {
db.all('SELECT MY_ADD(1, 2) AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, 3);
done();
});
});

it('allows writing of a regex function', function(done) {
db.all('SELECT MY_REGEX("colou?r", "color") AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(Boolean(rows[0].val), true);
done();
});
});

it('converts returned regex instances to strings', function(done) {
db.all('SELECT MY_REGEX_VALUE() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, '/match things/i');
done();
});
});

it.skip('reports errors thrown in functions', function(done) {
db.all('SELECT MY_ERROR() AS val', function(err, rows) {
assert.equal(err.message, 'This function always throws');
assert.equal(rows, undefined);
done();
});
});

it('reports errors for unhandled types', function(done) {
db.all('SELECT MY_UNHANDLED_TYPE() AS val', function(err, rows) {
assert.equal(err.message, 'SQLITE_ERROR: invalid return type in ' +
'user function MY_UNHANDLED_TYPE');
assert.equal(rows, undefined);
done();
});
});

it('allows no return value from functions', function(done) {
db.all('SELECT MY_NOTHING() AS val', function(err, rows) {
if (err) throw err;
assert.equal(rows.length, 1);
assert.equal(rows[0].val, undefined);
done();
});
});

after(function(done) { db.close(done); });
});

0 comments on commit 19d7375

Please sign in to comment.