diff --git a/.github/workflows/binary-gems.yml b/.github/workflows/binary-gems.yml index 175394262..4d2c6800c 100644 --- a/.github/workflows/binary-gems.yml +++ b/.github/workflows/binary-gems.yml @@ -101,6 +101,12 @@ jobs: env: PGVERSION: ${{ matrix.PGVERSION }} steps: + # Workaround for broken ubuntu-latest image. + # See https://github.com/Shopify/ruby-lsp/issues/3942 + - name: Remove pre-installed Ruby 4.0 + if: matrix.os == 'ubuntu-latest' && matrix.ruby == '4.0' + run: rm -rf /opt/hostedtoolcache/Ruby/4.0* + - uses: actions/checkout@v4 - name: Set up Ruby if: matrix.platform != 'x86-mingw32' @@ -111,6 +117,11 @@ jobs: brew: "postgresql" # macOS mingw: "postgresql" # Windows mingw / mswin /ucrt + - name: Install postgresql server headers Ubuntu + if: startsWith(matrix.os, 'ubuntu-') + run: | + sudo apt-get -y --allow-downgrades install '^postgresql-server-dev-[0-9]+$' libkrb5-dev + - name: Set up 32 bit x86 Ruby if: matrix.platform == 'x86-mingw32' run: | diff --git a/.github/workflows/source-gem.yml b/.github/workflows/source-gem.yml index 7c1c3deae..4e5eafb91 100644 --- a/.github/workflows/source-gem.yml +++ b/.github/workflows/source-gem.yml @@ -127,9 +127,14 @@ jobs: echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVER" | sudo tee -a /etc/apt/sources.list.d/pgdg.list wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - sudo apt-get -y update - sudo apt-get -y --allow-downgrades install postgresql-$PGVER libpq5=$PGVER* libpq-dev=$PGVER* + sudo apt-get -y --allow-downgrades install postgresql-$PGVER libpq5=$PGVER* libpq-dev=$PGVER* postgresql-server-dev-$PGVER libkrb5-dev echo /usr/lib/postgresql/$PGVER/bin >> $GITHUB_PATH + - name: Download OAuth support Ubuntu + if: matrix.os == 'ubuntu' && matrix.PGVER >= 18 + run: | + sudo apt-get -y --allow-downgrades install libpq-oauth=$PGVER* + - name: Download PostgreSQL Macos if: matrix.os == 'macos' run: | diff --git a/.gitignore b/.gitignore index 1de9c8ee9..b82fc544c 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,9 @@ /lib/2.?/ /lib/3.?/ /pkg/ +/spec/oauth/*.bc +/spec/oauth/*.o +/spec/oauth/*.so /tmp/ /tmp_test_*/ /vendor/ diff --git a/Gemfile b/Gemfile index c22c0988b..8cd0309ad 100644 --- a/Gemfile +++ b/Gemfile @@ -15,6 +15,7 @@ group :test do gem "rake-compiler", "~> 1.0" gem "rake-compiler-dock", "~> 1.11.0" #, git: "https://github.com/rake-compiler/rake-compiler-dock" gem "rspec", "~> 3.5" + gem "webrick", "~> 1.8" # "bigdecimal" is a gem on ruby-3.4+ and it's optional for ruby-pg. # Specs should succeed without it, but 4 examples are then excluded. # With bigdecimal commented out here, corresponding tests are omitted on ruby-3.4+ but are executed on ruby < 3.4. diff --git a/ext/gvl_wrappers.c b/ext/gvl_wrappers.c index 8e2f0ad86..3cd36b40d 100644 --- a/ext/gvl_wrappers.c +++ b/ext/gvl_wrappers.c @@ -19,6 +19,9 @@ PostgresPollingStatusType PQcancelPoll(PGcancelConn *cancelConn){return PGRES_PO #ifndef LIBPQ_HAS_PIPELINING int PQpipelineSync(PGconn *conn){return 0;} #endif +#ifndef LIBPQ_HAS_PROMPT_OAUTH_DEVICE +int auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data){return 0;} +#endif #ifdef ENABLE_GVL_UNLOCK FOR_EACH_BLOCKING_FUNCTION( DEFINE_GVL_WRAPPER_STRUCT ); diff --git a/ext/gvl_wrappers.h b/ext/gvl_wrappers.h index f048d7055..52ade3eb3 100644 --- a/ext/gvl_wrappers.h +++ b/ext/gvl_wrappers.h @@ -24,6 +24,9 @@ #ifndef LIBPQ_HAS_CHUNK_MODE typedef struct pg_cancel_conn PGcancelConn; #endif +#ifndef LIBPQ_HAS_PROMPT_OAUTH_DEVICE +typedef enum { DUMMY_TYPE } PGauthData; +#endif #define DEFINE_PARAM_LIST1(type, name) \ name, @@ -281,6 +284,10 @@ FOR_EACH_BLOCKING_FUNCTION( DEFINE_GVL_STUB_DECL ); * Definitions of callback functions and their parameters */ +#define FOR_EACH_PARAM_OF_auth_data_hook_proxy(param) \ + param(PGauthData, type) \ + param(PGconn *, conn) + #define FOR_EACH_PARAM_OF_notice_processor_proxy(param) \ param(void *, arg) @@ -289,9 +296,11 @@ FOR_EACH_BLOCKING_FUNCTION( DEFINE_GVL_STUB_DECL ); /* function( name, void_or_nonvoid, returntype, lastparamtype, lastparamname ) */ #define FOR_EACH_CALLBACK_FUNCTION(function) \ + function(auth_data_hook_proxy, GVL_TYPE_NONVOID, int, void *, data) \ function(notice_processor_proxy, GVL_TYPE_VOID, void, const char *, message) \ function(notice_receiver_proxy, GVL_TYPE_VOID, void, const PGresult *, result) \ FOR_EACH_CALLBACK_FUNCTION( DEFINE_GVL_STUB_DECL ); + #endif /* end __gvl_wrappers_h */ diff --git a/ext/pg.c b/ext/pg.c index 67969b1cd..aa1395711 100644 --- a/ext/pg.c +++ b/ext/pg.c @@ -682,6 +682,7 @@ Init_pg_ext(void) /* Initialize the main extension classes */ init_pg_connection(); + init_pg_auth_hooks(); init_pg_result(); init_pg_errors(); init_pg_type_map(); diff --git a/ext/pg.h b/ext/pg.h index 58fa630d2..8b1686556 100644 --- a/ext/pg.h +++ b/ext/pg.h @@ -21,6 +21,7 @@ #include "ruby.h" #include "ruby/st.h" #include "ruby/encoding.h" +#include "ruby/thread_native.h" #define PG_ENCODING_SET_NOCHECK(obj,i) \ do { \ @@ -113,6 +114,10 @@ typedef struct { VALUE encoder_for_put_copy_data; /* Kind of PG::Coder object for casting COPY rows to ruby values */ VALUE decoder_for_get_copy_data; +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + /* Callback for retrieval of OAuth token */ + VALUE auth_data_hook; +#endif /* Ruby encoding index of the client/internal encoding */ int enc_idx : PG_ENC_IDX_BITS; /* flags controlling Symbol/String field names */ @@ -288,6 +293,7 @@ extern VALUE pg_typemap_all_strings; void Init_pg_ext _(( void )); void init_pg_connection _(( void )); +void init_pg_auth_hooks _(( void )); void init_pg_result _(( void )); void init_pg_errors _(( void )); void init_pg_type_map _(( void )); @@ -374,6 +380,12 @@ rb_encoding * pg_get_pg_encname_as_rb_encoding _(( const char * )); const char * pg_get_rb_encoding_as_pg_encoding _(( rb_encoding * )); rb_encoding *pg_conn_enc_get _(( PGconn * )); +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE +int auth_data_hook_proxy(PGauthData type, PGconn *conn, void *data); +int pgconn_lookup(PGconn *pgconn, VALUE *rb_conn); +void pgconn_insert(PGconn *pgconn, VALUE rb_conn); +void pgconn_delete(PGconn *pgconn); +#endif void notice_receiver_proxy(void *arg, const PGresult *result); void notice_processor_proxy(void *arg, const char *message); diff --git a/ext/pg_auth_hooks.c b/ext/pg_auth_hooks.c new file mode 100644 index 000000000..b3da2ee90 --- /dev/null +++ b/ext/pg_auth_hooks.c @@ -0,0 +1,367 @@ +/* + * pg_auth_hooks.c - Auth hooks for PG module + * $Id$ + * + */ + +#include "pg.h" + +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + +/* + * We store the pgconn pointers in a register to retrieve the PG::Connection VALUE in the oauth hook. + */ +struct st_table *pgconn2value; +rb_nativethread_lock_t pgconn2value_lock; + +static VALUE rb_cPromptOAuthDevice; +static VALUE rb_cOAuthBearerRequest; + +int pgconn_lookup(PGconn *pgconn, VALUE *rb_conn){ + int res; + rb_nativethread_lock_lock(&pgconn2value_lock); + res = st_lookup(pgconn2value, (st_data_t)pgconn, (st_data_t*)rb_conn); + rb_nativethread_lock_unlock(&pgconn2value_lock); + return res; +} + +void pgconn_insert(PGconn *pgconn, VALUE rb_conn) { + rb_nativethread_lock_lock(&pgconn2value_lock); + st_insert( pgconn2value, (st_data_t)pgconn, (st_data_t)rb_conn ); + rb_nativethread_lock_unlock(&pgconn2value_lock); +} + +void pgconn_delete(PGconn *pgconn) { + rb_nativethread_lock_lock(&pgconn2value_lock); + st_delete( pgconn2value, (st_data_t*)&pgconn, NULL ); + rb_nativethread_lock_unlock(&pgconn2value_lock); +} + + +/* + * Document-class: PG::PromptOAuthDevice + */ + +typedef struct { + PGpromptOAuthDevice *prompt; +} t_pg_prompt_oauth_device; + +static size_t +pg_prompt_oauth_device_memsize(const void *_this) +{ + return sizeof(t_pg_prompt_oauth_device); +} + +static const rb_data_type_t pg_prompt_oauth_device_type = { + "PG::PromptOAuthDevice", + { + NULL, + RUBY_TYPED_DEFAULT_FREE, + pg_prompt_oauth_device_memsize, + NULL, + }, + 0, + 0, + RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_FREE_IMMEDIATELY, +}; + +static t_pg_prompt_oauth_device * +pg_get_prompt_oauth_device_safe(VALUE self) +{ + t_pg_prompt_oauth_device *this; + + TypedData_Get_Struct(self, t_pg_prompt_oauth_device, &pg_prompt_oauth_device_type, this); + + if (!this->prompt) + rb_raise(rb_ePGerror, "data cannot be accessed after callback has completed"); + + return this; +} + +/* + * call-seq: + * prompt.verification_uri -> String + */ +static VALUE +pg_prompt_oauth_device_verification_uri(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + if (!this->prompt->verification_uri) + rb_raise(rb_ePGerror, "internal error: verification_uri is missing"); + + return rb_str_new_cstr(this->prompt->verification_uri); +} + +/* + * call-seq: + * prompt.user_code -> String + */ +static VALUE +pg_prompt_oauth_device_user_code(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + if (!this->prompt->user_code) + rb_raise(rb_ePGerror, "internal error: user_code is missing"); + + return rb_str_new_cstr(this->prompt->user_code); +} + +/* + * call-seq: + * prompt.verification_uri_complete -> String | nil + */ +static VALUE +pg_prompt_oauth_device_verification_uri_complete(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + return this->prompt->verification_uri_complete ? rb_str_new_cstr(this->prompt->verification_uri_complete) : Qnil; +} + +/* + * call-seq: + * prompt.expires_in -> Integer + */ +static VALUE +pg_prompt_oauth_device_expires_in(VALUE self) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + return INT2FIX(this->prompt->expires_in); +} + +/* + * Document-class: PG::OAuthBearerRequest + */ + +typedef struct { + PGoauthBearerRequest *request; +} t_pg_oauth_bearer_request; + +static size_t +pg_oauth_bearer_request_memsize(const void *_this) +{ + return sizeof(t_pg_oauth_bearer_request); +} + +static const rb_data_type_t pg_oauth_bearer_request_type = { + "PG::OAuthBearerRequest", + { + NULL, + RUBY_TYPED_DEFAULT_FREE, + pg_oauth_bearer_request_memsize, + NULL, + }, + 0, + 0, + RUBY_TYPED_WB_PROTECTED | RUBY_TYPED_FREE_IMMEDIATELY, +}; + +static t_pg_oauth_bearer_request * +pg_get_oauth_bearer_request_safe(VALUE self) +{ + t_pg_oauth_bearer_request *this; + + TypedData_Get_Struct(self, t_pg_oauth_bearer_request, &pg_oauth_bearer_request_type, this); + + if (!this->request) + rb_raise(rb_ePGerror, "data cannot be accessed after callback has completed"); + + return this; +} + +/* + * call-seq: + * prompt.openid_configuration -> String + */ +static VALUE +pg_oauth_bearer_request_openid_configuration(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + if (!this->request->openid_configuration) + rb_raise(rb_ePGerror, "internal error: openid_configuration is missing"); + + return rb_str_new_cstr(this->request->openid_configuration); +} + +/* + * call-seq: + * request.scope -> String | nil + */ +static VALUE +pg_oauth_bearer_request_scope(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + return this->request->scope ? rb_str_new_cstr(this->request->scope) : Qnil; +} + +/* + * call-seq: + * request.token = token + * + * See also #token + */ +static VALUE +pg_oauth_bearer_request_token_set(VALUE self, VALUE token) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + /* This can throw an exception so needs to be done before free() */ + char *token_cstr = NIL_P(token) ? NULL : strdup(StringValueCStr(token)); + + if (this->request->token) + free(this->request->token); + + this->request->token = token_cstr; + + return token; +} + +/* + * call-seq: + * request.token -> String | nil + * + * See also #token= + */ +static VALUE +pg_oauth_bearer_request_token_get(VALUE self) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + return this->request->token ? rb_str_new_cstr(this->request->token) : Qnil; +} + +static void +oauth_bearer_request_cleanup(PGconn *_conn, struct PGoauthBearerRequest *request) +{ + if (request->token) + free(request->token); +} + +static VALUE +call_auth_data_hook(VALUE args) +{ + VALUE proc = ((VALUE*)args)[0]; + VALUE conn_num = ((VALUE*)args)[1]; + VALUE v_data = ((VALUE*)args)[2]; + + return rb_funcall(proc, rb_intern("call"), 2, conn_num, v_data); +} + +static VALUE +prompt_oauth_device_hook_cleanup(VALUE self, VALUE ex) +{ + t_pg_prompt_oauth_device *this = pg_get_prompt_oauth_device_safe(self); + + this->prompt = NULL; + + rb_exc_raise(ex); +} + +static VALUE +oauth_bearer_request_hook_cleanup(VALUE self, VALUE ex) +{ + t_pg_oauth_bearer_request *this = pg_get_oauth_bearer_request_safe(self); + + if (this->request->token) + free(this->request->token); + this->request->token = NULL; + + this->request = NULL; + + rb_exc_raise(ex); +} + +/* + * Auth data proxy function -- delegate the callback to the + * currently-registered Ruby auth_data_hook object. + */ +int +auth_data_hook_proxy(PGauthData type, PGconn *pgconn, void *data) +{ + VALUE rb_conn = Qnil; + VALUE ret = Qnil; + + if ( st_lookup(pgconn2value, (st_data_t)pgconn, (st_data_t*)&rb_conn) ) { + t_pg_connection *this = pg_get_connection( rb_conn ); + VALUE proc = this->auth_data_hook; + + if (type == PQAUTHDATA_PROMPT_OAUTH_DEVICE) { + t_pg_prompt_oauth_device *prompt; + + VALUE v_prompt = TypedData_Make_Struct(rb_cPromptOAuthDevice, t_pg_prompt_oauth_device, &pg_prompt_oauth_device_type, prompt); + VALUE args[] = { proc, rb_conn, v_prompt }; + + prompt->prompt = data; + + ret = rb_rescue(call_auth_data_hook, (VALUE)&args, prompt_oauth_device_hook_cleanup, v_prompt); + + prompt->prompt = NULL; + } else if (type == PQAUTHDATA_OAUTH_BEARER_TOKEN) { + t_pg_oauth_bearer_request *request; + + VALUE v_request = TypedData_Make_Struct(rb_cOAuthBearerRequest, t_pg_oauth_bearer_request, &pg_oauth_bearer_request_type, request); + VALUE args[] = { proc, rb_conn, v_request }; + + request->request = data; + request->request->cleanup = oauth_bearer_request_cleanup; + + ret = rb_rescue(call_auth_data_hook, (VALUE)&args, oauth_bearer_request_hook_cleanup, v_request); + + request->request = NULL; + } + } + + /* TODO: a hook can return 1, 0 or -1 */ + return RTEST(ret); +} + +/* + * call-seq: + * PG.pgconn2value_size -> Integer + */ +static VALUE +pg_oauth_pgconn2value_size_get(VALUE self) +{ + return SIZET2NUM(rb_st_table_size(pgconn2value)); +} + + +void +init_pg_auth_hooks(void) +{ + + pgconn2value = st_init_numtable(); + rb_nativethread_lock_initialize(&pgconn2value_lock); + + PQsetAuthDataHook(gvl_auth_data_hook_proxy); // TODO: Add some safeguards? + + /* rb_mPG = rb_define_module("PG") */ + rb_define_private_method(rb_singleton_class(rb_mPG), "pgconn2value_size", pg_oauth_pgconn2value_size_get, 0); + + rb_cPromptOAuthDevice = rb_define_class_under(rb_mPG, "PromptOAuthDevice", rb_cObject); + rb_undef_alloc_func(rb_cPromptOAuthDevice); + + rb_define_method(rb_cPromptOAuthDevice, "verification_uri", pg_prompt_oauth_device_verification_uri, 0); + rb_define_method(rb_cPromptOAuthDevice, "user_code", pg_prompt_oauth_device_user_code, 0); + rb_define_method(rb_cPromptOAuthDevice, "verification_uri_complete", pg_prompt_oauth_device_verification_uri_complete, 0); + rb_define_method(rb_cPromptOAuthDevice, "expires_in", pg_prompt_oauth_device_expires_in, 0); + + rb_cOAuthBearerRequest = rb_define_class_under(rb_mPG, "OAuthBearerRequest", rb_cObject); + rb_undef_alloc_func(rb_cOAuthBearerRequest); + + rb_define_method(rb_cOAuthBearerRequest, "openid_configuration", pg_oauth_bearer_request_openid_configuration, 0); + rb_define_method(rb_cOAuthBearerRequest, "scope", pg_oauth_bearer_request_scope, 0); + rb_define_method(rb_cOAuthBearerRequest, "token=", pg_oauth_bearer_request_token_set, 1); + rb_define_method(rb_cOAuthBearerRequest, "token", pg_oauth_bearer_request_token_get, 0); +} + +#else + +void init_pg_auth_hooks(void) {} + +#endif diff --git a/ext/pg_connection.c b/ext/pg_connection.c index c8b71d67b..97aa7358c 100644 --- a/ext/pg_connection.c +++ b/ext/pg_connection.c @@ -179,11 +179,17 @@ pgconn_gc_mark( void *_this ) rb_gc_mark_movable( this->trace_stream ); rb_gc_mark_movable( this->encoder_for_put_copy_data ); rb_gc_mark_movable( this->decoder_for_get_copy_data ); +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + rb_gc_mark_movable( this->auth_data_hook ); +#endif } static void pgconn_gc_compact( void *_this ) { +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + VALUE old_rb_conn; +#endif t_pg_connection *this = (t_pg_connection *)_this; pg_gc_location( this->socket_io ); pg_gc_location( this->notice_receiver ); @@ -193,6 +199,15 @@ pgconn_gc_compact( void *_this ) pg_gc_location( this->trace_stream ); pg_gc_location( this->encoder_for_put_copy_data ); pg_gc_location( this->decoder_for_get_copy_data ); + +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + pg_gc_location( this->auth_data_hook ); + /* update the PG::Connection object which is maybe stored in pgconn2value */ + if ( pgconn_lookup(this->pgconn, &old_rb_conn) ) { + VALUE new_rb_conn = rb_gc_location(old_rb_conn); + pgconn_insert( this->pgconn, new_rb_conn ); + } +#endif } @@ -210,9 +225,15 @@ pgconn_gc_free( void *_this ) } } #endif - if (this->pgconn != NULL) + if (this->pgconn != NULL) { PQfinish( this->pgconn ); +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + /* Remove from auth hook callback table */ + pgconn_delete(this->pgconn ); +#endif + } + xfree(this); } @@ -264,6 +285,9 @@ pgconn_s_allocate( VALUE klass ) RB_OBJ_WRITE(self, &this->type_map_for_results, pg_typemap_all_strings); RB_OBJ_WRITE(self, &this->encoder_for_put_copy_data, Qnil); RB_OBJ_WRITE(self, &this->decoder_for_get_copy_data, Qnil); +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + RB_OBJ_WRITE(self, &this->auth_data_hook, Qnil); +#endif RB_OBJ_WRITE(self, &this->trace_stream, Qnil); rb_ivar_set(self, rb_intern("@calls_to_put_copy_data"), INT2FIX(0)); rb_ivar_set(self, rb_intern("@iopts_for_reset"), Qnil); @@ -623,6 +647,45 @@ pgconn_reset_poll(VALUE self) return INT2FIX((int)status); } +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + +/* + * call-seq: + * conn.auth_data_hook(&block) + * + * Set a auth data hook. + */ +static VALUE +pgconn_auth_data_hook_set(VALUE self, VALUE proc) +{ + t_pg_connection *this = pg_get_connection( self ); + + if (rb_obj_is_proc(proc)) { + /* set proc */ + pgconn_insert( this->pgconn, self ); + } else if (NIL_P(proc)) { + /* if nil is given, set back to default */ + pgconn_delete( this->pgconn ); + } else { + rb_raise(rb_eArgError, "Proc object expected"); + } + RB_OBJ_WRITE(self, &this->auth_data_hook, proc); + return proc; +} + +/* + * call-seq: + * conn.auth_data_hook() + * + * Returns the defined auth data hook. + */ +static VALUE +pgconn_auth_data_hook_get(VALUE self) +{ + return pg_get_connection(self)->auth_data_hook; +} + +#endif /* * call-seq: @@ -965,6 +1028,19 @@ pgconn_socket(VALUE self) return INT2NUM(sd); } +#ifdef _WIN32 +#define is_socket(fd) rb_w32_is_socket(fd) +#else +static int +is_socket(int fd) +{ + struct stat sbuf; + + if (fstat(fd, &sbuf) < 0) + rb_sys_fail("fstat(2)"); + return S_ISSOCK(sbuf.st_mode); +} +#endif VALUE pg_wrap_socket_io(int sd, VALUE self, VALUE *p_socket_io, int *p_ruby_sd) @@ -983,7 +1059,7 @@ pg_wrap_socket_io(int sd, VALUE self, VALUE *p_socket_io, int *p_ruby_sd) *p_ruby_sd = ruby_sd = sd; #endif - cSocket = rb_const_get(rb_cObject, rb_intern("BasicSocket")); + cSocket = rb_const_get(rb_cObject, rb_intern(is_socket(ruby_sd) ? "BasicSocket" : "IO")); socket_io = rb_funcall( cSocket, rb_intern("for_fd"), 1, INT2NUM(ruby_sd)); /* Disable autoclose feature */ @@ -4737,6 +4813,10 @@ init_pg_connection(void) rb_define_private_method(rb_cPGconn, "reset_start2", pgconn_reset_start2, 1); rb_define_method(rb_cPGconn, "reset_poll", pgconn_reset_poll, 0); rb_define_alias(rb_cPGconn, "close", "finish"); +#ifdef LIBPQ_HAS_PROMPT_OAUTH_DEVICE + rb_define_method(rb_cPGconn, "auth_data_hook=", pgconn_auth_data_hook_set, 1); + rb_define_method(rb_cPGconn, "auth_data_hook", pgconn_auth_data_hook_get, 0); +#endif /****** PG::Connection INSTANCE METHODS: Connection Status ******/ rb_define_method(rb_cPGconn, "db", pgconn_db, 0); diff --git a/lib/pg.rb b/lib/pg.rb index b996362da..73246a9fe 100644 --- a/lib/pg.rb +++ b/lib/pg.rb @@ -84,8 +84,8 @@ def self.version_string( include_buildnum=nil ) ### Convenience alias for PG::Connection.new. - def self.connect( *args, &block ) - Connection.new( *args, &block ) + def self.connect( *args, **kwargs, &block ) + Connection.new( *args, **kwargs, &block ) end if defined?(Ractor.make_shareable) diff --git a/lib/pg/connection.rb b/lib/pg/connection.rb index eb607ab82..40513a454 100644 --- a/lib/pg/connection.rb +++ b/lib/pg/connection.rb @@ -867,8 +867,8 @@ class << self # It's still possible to do load balancing with +load_balance_hosts+ set to +random+ and to increase the number of connections a node gets, when the hostname is provided multiple times in the host string. # This is because in non-timeout cases the host is tried multiple times. # - def new(*args) - conn = connect_to_hosts(*args) + def new(*args, **kwargs) + conn = connect_to_hosts(*args, **kwargs) if block_given? begin @@ -919,8 +919,8 @@ def new(*args) port: dests.map{|d| d[2] }.join(",")) end - private def connect_to_hosts(*args) - option_string = parse_connect_args(*args) + private def connect_to_hosts(*args, set_auth_data_hook: nil, **kwargs) + option_string = parse_connect_args(*args, **kwargs) iopts = PG::Connection.conninfo_parse(option_string).each_with_object({}){|h, o| o[h[:keyword].to_sym] = h[:val] if h[:val] } iopts = PG::Connection.conndefaults.each_with_object({}){|h, o| o[h[:keyword].to_sym] = h[:val] if h[:val] }.merge(iopts) @@ -944,6 +944,12 @@ def new(*args) conn = self.connect_start(iopts) or raise(PG::Error, "Unable to create a new connection") + if conn.respond_to?(:auth_data_hook) + conn.auth_data_hook = set_auth_data_hook + elsif set_auth_data_hook + raise ArgumentError, "invalid option set_auth_data_hook" + end + raise PG::ConnectionBad, conn.error_message if conn.status == PG::CONNECTION_BAD # save the connection options for conn.reset diff --git a/spec/helpers.rb b/spec/helpers.rb index e974def5e..1a35798de 100644 --- a/spec/helpers.rb +++ b/spec/helpers.rb @@ -194,6 +194,7 @@ class PostgresServer attr_reader :port attr_reader :conninfo attr_reader :unix_socket + attr_reader :version ### Set up a PostgreSQL database instance for testing. def initialize(name, port: 23456, postgresql_conf: '') @@ -205,6 +206,7 @@ def initialize(name, port: 23456, postgresql_conf: '') @pgdata = @test_dir + 'data' @logfile = @test_dir + 'setup.log' @pg_bindir = pg_bindir + @version = pg_version @unix_socket = @test_dir.to_s @conninfo = "host=localhost port=#{@port} dbname=test sslrootcert=#{@pgdata + 'ruby-pg-ca-cert'} sslcert=#{@pgdata + 'ruby-pg-client-cert'} sslkey=#{@pgdata + 'ruby-pg-client-key'}" @@ -267,8 +269,13 @@ def setup_cluster(postgresql_conf) ssl_cert_file = 'ruby-pg-server-cert' ssl_key_file = 'ruby-pg-server-key' fsync = off - #{postgresql_conf} EOT + if @version >= 18 + fd.puts <<~EOT + oauth_validator_libraries = '#{TEST_DIRECTORY}/spec/oauth/dummy_validator' + EOT + end + fd.puts postgresql_conf end # Enable MD5 authentication in hba config @@ -278,6 +285,12 @@ def setup_cluster(postgresql_conf) # TYPE DATABASE USER ADDRESS METHOD host all testusermd5 ::1/128 md5 EOT + if @version >= 18 + fd.puts <<~EOT + host all testuseroauth 127.0.0.1/32 oauth scope=test issuer="http://localhost:#{@port + 3}" + host all testuseroauth ::1/32 oauth scope=test issuer="http://localhost:#{@port + 3}" + EOT + end fd.puts hba_content end @@ -340,6 +353,10 @@ def pg_bindir rescue nil end + + def pg_version + `#{pg_bin_path("pg_ctl")} --version`[/pg_ctl \(PostgreSQL\) (\d+)/, 1]&.to_i + end end class CertGenerator @@ -656,6 +673,38 @@ def with_env_vars(**kwargs) def set_etc_hosts(hostaddr, hostname) system "sudo --non-interactive sed -i '/.* #{hostname}$/{h;s/.*/#{hostaddr} #{hostname}/};${x;/^$/{s//#{hostaddr} #{hostname}/;H};x}' /etc/hosts" or skip("unable to change /etc/hosts file") end + + def build_oauth_validator + skip "requires a PostgreSQL 18 cluster" unless $pg_server.version >= 18 + + system "make", "-s", "-C", (TEST_DIRECTORY + "spec/oauth").to_s + raise "Building OAuth validator library failed!" unless $?.success? + + require 'webrick' + + PG.connect(@conninfo) do |conn| + conn.exec("DROP USER IF EXISTS testuseroauth") + conn.exec("CREATE USER testuseroauth") + end + end + + def start_fake_oauth(port) + server = WEBrick::HTTPServer.new(Port: port, Logger: WEBrick::Log.new(nil, WEBrick::BasicLog::WARN)) + server.mount_proc("/.well-known/openid-configuration") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"issuer":"http://localhost:#{port}","token_endpoint":"http://localhost:#{port}/token","device_authorization_endpoint":"http://localhost:#{@port + 3}/devauth"}! + end + server.mount_proc("/devauth") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"device_code":"42","user_code":"666","verification_uri":"http://localhost:#{port}/verify","expires_in":60}! + end + server.mount_proc("/token") do |req, res| + res["Content-Type"] = "application/json" + res.body = %!{"access_token":"yes","token_type":""}! + end + Thread.new { server.start } + server + end end RSpec.configure do |config| diff --git a/spec/oauth/Makefile b/spec/oauth/Makefile new file mode 100644 index 000000000..508292cea --- /dev/null +++ b/spec/oauth/Makefile @@ -0,0 +1,8 @@ +MODULES = dummy_validator +PGFILEDESC = "dummy_validator - dummy OAuth validator" + +OBJS = $(WIN32RES) + +PG_CONFIG = pg_config +PGXS := $(shell $(PG_CONFIG) --pgxs) +include $(PGXS) diff --git a/spec/oauth/dummy_validator.c b/spec/oauth/dummy_validator.c new file mode 100644 index 000000000..33018392a --- /dev/null +++ b/spec/oauth/dummy_validator.c @@ -0,0 +1,29 @@ +#include "postgres.h" +#include "fmgr.h" +#include "libpq/oauth.h" + +PG_MODULE_MAGIC; + +static bool +validate_token(const ValidatorModuleState *state, + const char *token, const char *role, + ValidatorModuleResult *res) +{ + if (strcmp(token, "yes") == 0) + { + res->authorized = true; + res->authn_id = pstrdup(role); + } + return true; +} + +static const OAuthValidatorCallbacks validator_callbacks = { + PG_OAUTH_VALIDATOR_MAGIC, + .validate_cb = validate_token +}; + +const OAuthValidatorCallbacks * +_PG_oauth_validator_module_init(void) +{ + return &validator_callbacks; +} diff --git a/spec/pg/connection_async_spec.rb b/spec/pg/connection_async_spec.rb index da7174e67..d61104972 100644 --- a/spec/pg/connection_async_spec.rb +++ b/spec/pg/connection_async_spec.rb @@ -156,4 +156,188 @@ def interrupt_thread(exc=nil) expect( conn.hostaddr ).to eq( "::1" ) expect( conn.port ).to eq( @port ) end + + describe "option set_auth_data_hook", :postgresql_18 do + before :all do + build_oauth_validator + end + + before :each do + @old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE" + end + + it "should call prompt oauth device hook" do + oauth_server = start_fake_oauth(@port + 3) + + verification_uri, user_code, verification_uri_complete, expires_in = nil, nil, nil, nil + conn1, conn2 = nil, nil + + hook = proc do |conn, data| + case data + when PG::PromptOAuthDevice + conn1 = conn + verification_uri = data.verification_uri + user_code = data.user_code + verification_uri_complete = data.verification_uri_complete + expires_in = data.expires_in + true + end + end + + begin + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo", set_auth_data_hook: hook) do |conn| + conn.exec("SELECT 1") + conn2 = conn + end + rescue PG::ConnectionBad => e + if e.message =~ /no OAuth flows are available/ + skip "requires libpq-oauth to be installed" + end + raise + ensure + oauth_server.shutdown + end + + expect(conn1).to eq(conn2) + expect(verification_uri).to eq("http://localhost:#{@port + 3}/verify") + expect(user_code).to eq("666") + expect(verification_uri_complete).to eq(nil) + expect(expires_in).to eq(60) + end + + it "should call oauth bearer request hook" do + openid_configuration, scope = nil, nil + conn1, conn2 = nil, nil + + hook = proc do |conn, data| + case data + when PG::OAuthBearerRequest + conn1 = conn + openid_configuration = data.openid_configuration + scope = data.scope + data.token = "yes" + true + end + end + + PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook) do |conn| + conn.exec("SELECT 1") + conn2 = conn + end + + expect(conn1).to eq(conn2) + expect(openid_configuration).to eq("http://localhost:#{@port + 3}/.well-known/openid-configuration") + expect(scope).to eq("test") + end + + it "shouldn't garbage collect PG::Connection in use" do + conn1 = nil + hook = proc do |conn, data| + case data + when PG::OAuthBearerRequest + data.token = "yes" + conn1 = conn + true + end + end + + GC.stress = true + begin + conn = PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook) + ensure + GC.stress = false + end + conn.exec("SELECT 1") + + expect(conn1).to eq(conn) + end + + it "should garbage collect PG::Connection after use" do + hook = proc do |conn, data| + case data + when PG::OAuthBearerRequest + openid_configuration = data.openid_configuration + scope = data.scope + data.token = "yes" + true + end + end + + before = PG.send(:pgconn2value_size) + 20.times do + conn = PG.connect(host: "localhost", port: @port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{@port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook) + conn.exec("SELECT 1") + end + + GC.start + after = PG.send(:pgconn2value_size) + + # Number of GC'ed objects + expect(before + 20 - after).to be_between(1, 50) + end + + it "should be usable with Ractor", :ractor do + ractors = 20.times.map do |idx1| + Ractor.new(@conninfo, @port, idx1) do |conninfo, port, idx2| + hook = proc do |conn, data| + case data + when PG::OAuthBearerRequest + openid_configuration = data.openid_configuration + scope = data.scope + data.token = "yes" + true + end + end + + conn = PG.connect(host: "localhost", port: port, dbname: "test", user: "testuseroauth", oauth_issuer: "http://localhost:#{port + 3}", oauth_client_id: "foo", set_auth_data_hook: hook) + conn.exec("SELECT #{idx2}").values + ensure + conn&.finish + end + end + + vals = ractors.map(&:value) + + expect( vals ).to eq( 20.times.map { |i| [[i.to_s]] } ) + end + + # TODO: Is resetting the global hook still useful, when the hook is per connection? + # it "should reset the hook when called without block" do + # oauth_server = start_fake_oauth(@port + 3) + # + # PG.set_auth_data_hook do |conn_num, data| + # raise "broken hook" + # end + # + # expect do + # PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") {} + # end.to raise_error("broken hook") + # + # PG.set_auth_data_hook + # + # begin + # PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + # conn.exec("SELECT 1") + # end + # rescue PG::ConnectionBad => e + # if e.message =~ /no OAuth flows are available/ + # skip "requires libpq-oauth to be installed" + # end + # raise + # ensure + # oauth_server.shutdown + # end + # end + + # around :example do |ex| + # GC.stress = true + # ex.run + # GC.stress = false + # end + + after :each do + # PG.set_auth_data_hook + ENV["PGOAUTHDEBUG"] = @old_env + end + end end diff --git a/spec/pg/connection_spec.rb b/spec/pg/connection_spec.rb index 7763493a2..cc09542d9 100644 --- a/spec/pg/connection_spec.rb +++ b/spec/pg/connection_spec.rb @@ -3010,4 +3010,36 @@ def wait_check_socket(conn) .to raise_error(TypeError) end end + + describe "option set_auth_data_hook", :postgresql_18 do + before :all do + build_oauth_validator + end + + before :each do + @old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE" + end + + it "should work with no hook" do + oauth_server = start_fake_oauth(@port + 3) + + begin + PG.connect("host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo") do |conn| + conn.exec("SELECT 1") + end + rescue PG::ConnectionBad => e + if e.message =~ /no OAuth flows are available/ + skip "requires libpq-oauth to be installed" + end + raise + ensure + oauth_server.shutdown + end + end + + after :each do + # PG.set_auth_data_hook + ENV["PGOAUTHDEBUG"] = @old_env + end + end end diff --git a/spec/pg/gc_compact_spec.rb b/spec/pg/gc_compact_spec.rb index 97b2d55f9..fcbac8fa5 100644 --- a/spec/pg/gc_compact_spec.rb +++ b/spec/pg/gc_compact_spec.rb @@ -24,6 +24,8 @@ require_relative '../helpers' describe "GC.compact", if: GC.respond_to?(:compact) do + hook_called = false + before :all do TM1 = Class.new(PG::TypeMapByClass) do def conv_array(value) @@ -57,6 +59,20 @@ def conv_array(value) CANCON.socket_io end + if PG::Connection.instance_methods.include?(:auth_data_hook) + build_oauth_validator + @old_env, ENV["PGOAUTHDEBUG"] = ENV["PGOAUTHDEBUG"], "UNSAFE" + HOOKED_CONN = PG::Connection.connect_start( "host=localhost port=#{@port} dbname=test user=testuseroauth oauth_issuer=http://localhost:#{@port + 3} oauth_client_id=foo" ) + HOOKED_CONN.auth_data_hook = proc do |conn, data| + case data + when PG::OAuthBearerRequest + data.token = "yes" + hook_called = true + true + end + end + end + begin # Use GC.verify_compaction_references instead of GC.compact . # This has the advantage that all movable objects are actually moved. @@ -111,6 +127,14 @@ def conv_array(value) expect( CANCON.socket_io ).to be_kind_of( IO ) end + it "should compact PG::Connection in pgconn2value", :postgresql_18 do + wait_for_polling_ok(HOOKED_CONN) + expect( HOOKED_CONN.error_message ).to eq("") + HOOKED_CONN.finish + expect( hook_called ).to be_truthy + ENV["PGOAUTHDEBUG"] = @old_env + end + after :all do CONN2.close end