diff --git a/src/database.cc b/src/database.cc index bcd093e13..d0e28e7d2 100644 --- a/src/database.cc +++ b/src/database.cc @@ -28,7 +28,7 @@ void Database::Init(Handle target) { NODE_SET_PROTOTYPE_METHOD(t, "serialize", Serialize); NODE_SET_PROTOTYPE_METHOD(t, "parallelize", Parallelize); NODE_SET_PROTOTYPE_METHOD(t, "configure", Configure); - NODE_SET_PROTOTYPE_METHOD(t, "registerFunction", RegisterFunction); + NODE_SET_PROTOTYPE_METHOD(t, "loadEnvironment", RegisterFunctions); NODE_SET_GETTER(t, "open", OpenGetter); @@ -361,15 +361,12 @@ NAN_METHOD(Database::Configure) { NanReturnValue(args.This()); } -NAN_METHOD(Database::RegisterFunction) { +NAN_METHOD(Database::RegisterFunctions) { NanScope(); Database* db = ObjectWrap::Unwrap(args.This()); - REQUIRE_ARGUMENTS(2); - REQUIRE_ARGUMENT_STRING(0, functionName); - REQUIRE_ARGUMENT_FUNCTION(1, callback); - - std::string str = "(" + std::string(*String::Utf8Value(callback->ToString())) + ")"; + REQUIRE_ARGUMENTS(1); + REQUIRE_ARGUMENT_STRING(0, module); Isolate *isolate = v8::Isolate::New(); isolate->Enter(); @@ -379,25 +376,32 @@ NAN_METHOD(Database::RegisterFunction) { HandleScope handle_scope(isolate); Local context = Context::New(isolate); Context::Scope context_scope(context); + Environment *env = CreateEnvironment(isolate, uv_default_loop(), context, + 2, (const char *[]){ "node", *module }, + 0, (const char *[]){}); + LoadEnvironment(env); Local global = NanGetCurrentContext()->Global(); - Local eval = Local::Cast(global->Get(NanNew("eval"))); - - // Local str = String::Concat(String::Concat(NanNew("("), callback->ToString()), NanNew(")")); - Local argv[] = { NanNew(str.c_str(), str.length()) }; - // Local function = Local::Cast(TRY_CATCH_CALL(global, eval, 1, argv)); - Local function = Local::Cast(eval->Call(global, 1, argv)); - - FunctionEnvironment *fn = new FunctionEnvironment(isolate, *functionName, function); - sqlite3_create_function( - db->_handle, - *functionName, - -1, // arbitrary number of args - SQLITE_UTF8 | SQLITE_DETERMINISTIC, - fn, - FunctionIsolate, - NULL, - NULL); + Local process = Local::Cast(global->Get(NanNew("process"))); + Local mainModule = Local::Cast(process->Get(NanNew("mainModule"))); + Local exports = Local::Cast(mainModule->Get(NanNew("exports"))); + Local keys = exports->GetOwnPropertyNames(); + int length = keys->Length(); + for (int i = 0; i < length; i++) { + Local function = Local::Cast(exports->Get(keys->Get(i))); + String::Utf8Value functionName(keys->Get(i)->ToString()); + + FunctionEnvironment *fn = new FunctionEnvironment(isolate, *functionName, function); + sqlite3_create_function( + db->_handle, + *functionName, + -1, // arbitrary number of args + SQLITE_UTF8 | SQLITE_DETERMINISTIC, + fn, + FunctionIsolate, + NULL, + NULL); + } } isolate->Exit(); @@ -461,7 +465,7 @@ void Database::FunctionExecute(FunctionEnvironment *fn, sqlite3_context *context } TryCatch trycatch; - Local result = cb->Call(NanGetCurrentContext()->Global(), argc, argv.data()); + Local result = cb->Call(NanNew(NanUndefined()), argc, argv.data()); // process the result if (trycatch.HasCaught()) { diff --git a/src/database.h b/src/database.h index 9111aab33..f952dbb13 100644 --- a/src/database.h +++ b/src/database.h @@ -166,7 +166,7 @@ class Database : public ObjectWrap { static NAN_METHOD(Configure); - static NAN_METHOD(RegisterFunction); + static NAN_METHOD(RegisterFunctions); static void FunctionIsolate(sqlite3_context *context, int argc, sqlite3_value **argv); static void FunctionExecute(FunctionEnvironment *baton, sqlite3_context *context, int argc, sqlite3_value **argv); diff --git a/test/support/user_functions.js b/test/support/user_functions.js new file mode 100644 index 000000000..516e807ec --- /dev/null +++ b/test/support/user_functions.js @@ -0,0 +1,39 @@ +exports.MY_UPPERCASE = function(value) { + return value.toUpperCase(); +}; + +exports.MY_STRING_JOIN = function(value1, value2) { + return [value1, value2].join(' '); +}; + +exports.MY_Add = function(value1, value2) { + return value1 + value2; +}; + +exports.MY_REGEX = function(regex, value) { + return !!value.match(new RegExp(regex)); +}; + +exports.MY_REGEX_VALUE = function(regex, value) { + return /match things/i; +}; + +exports.MY_ERROR = function(value) { + throw new Error('This function always throws'); +}; + +exports.MY_UNHANDLED_TYPE = function(value) { + return {}; +}; + +exports.MY_NOTHING = function(value) { + +}; + +exports.MY_INVALID_SCOPING = function(value) { + return db; // not accessible +}; + +exports.MY_REQUIRE = function(value) { + require('./helper'); +}; diff --git a/test/user_functions.test.js b/test/user_functions.test.js index a19eec66c..5060b8941 100644 --- a/test/user_functions.test.js +++ b/test/user_functions.test.js @@ -1,35 +1,13 @@ var sqlite3 = require('..'); var assert = require('assert'); +var path = require('path'); describe('user functions', function() { var db; before(function(done) { db = new sqlite3.Database(':memory:', done); }); it('should allow registration of user functions', function() { - db.registerFunction('MY_UPPERCASE', function(value) { - return value.toUpperCase(); - }); - db.registerFunction('MY_STRING_JOIN', function(value1, value2) { - return [value1, value2].join(' '); - }); - db.registerFunction('MY_Add', function(value1, value2) { - return value1 + value2; - }); - db.registerFunction('MY_REGEX', function(regex, value) { - return !!value.match(new RegExp(regex)); - }); - db.registerFunction('MY_REGEX_VALUE', function(regex, value) { - return /match things/i; - }); - db.registerFunction('MY_ERROR', function(value) { - throw new Error('This function always throws'); - }); - db.registerFunction('MY_UNHANDLED_TYPE', function(value) { - return {}; - }); - db.registerFunction('MY_NOTHING', function(value) { - - }); + db.loadEnvironment(path.join(__dirname, 'support/user_functions.js')); }); it('should process user functions with one arg', function(done) { @@ -103,5 +81,22 @@ describe('user functions', function() { }); }); + it('does not allow access to external scope', function(done) { + db.all('SELECT MY_INVALID_SCOPING() AS val', function(err, rows) { + assert.equal(err.message, 'SQLITE_ERROR: Uncaught ReferenceError: db is not defined'); + assert.equal(rows, undefined); + done(); + }); + }); + + it('allows use of require', function(done) { + db.all('SELECT MY_REQUIRE() AS val', function(err, rows) { + if (err) throw err; + assert.equal(rows.length, 1); + assert.equal(rows[0].val, undefined); + done(); + }); + }); + after(function(done) { db.close(done); }); });