diff --git a/pkg/context/token.go b/pkg/context/token.go index beddb02b2..3cb3707da 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -12,10 +12,8 @@ type tokenCtx string var tokenCtxKey tokenCtx = "tokenctx" type TokenInfo struct { - Token string - TokenType utils.TokenType - ScopesFetched bool - Scopes []string + Token string + TokenType utils.TokenType } // WithTokenInfo adds TokenInfo to the context @@ -30,3 +28,20 @@ func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { } return nil, false } + +type TokenScopesKey tokenCtx + +var tokenScopesKey TokenScopesKey = "tokenscopesctx" + +// WithTokenScopes adds token scopes to the context +func WithTokenScopes(ctx context.Context, scopes []string) context.Context { + return context.WithValue(ctx, tokenScopesKey, scopes) +} + +// GetTokenScopes retrieves token scopes from the context +func GetTokenScopes(ctx context.Context) ([]string, bool) { + if scopes, ok := ctx.Value(tokenScopesKey).([]string); ok { + return scopes, true + } + return nil, false +} diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 875d54bbb..753c62fa8 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -278,8 +278,10 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { - if tokenInfo.ScopesFetched { - return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes)) + // Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them. + existingScopes, ok := ghcontext.GetTokenScopes(ctx) + if ok { + return b.WithFilter(github.CreateToolScopeFilter(existingScopes)) } scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token) diff --git a/pkg/http/middleware/pat_scope.go b/pkg/http/middleware/pat_scope.go index 8b77b3d32..bb1efdc01 100644 --- a/pkg/http/middleware/pat_scope.go +++ b/pkg/http/middleware/pat_scope.go @@ -26,6 +26,13 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { + existingScopes, ok := ghcontext.GetTokenScopes(ctx) + if ok { + logger.Debug("using existing scopes from context", "scopes", existingScopes) + next.ServeHTTP(w, r) + return + } + scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) if err != nil { logger.Warn("failed to fetch PAT scopes", "error", err) @@ -33,11 +40,8 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu return } - tokenInfo.Scopes = scopesList - tokenInfo.ScopesFetched = true - // Store fetched scopes in context for downstream use - ctx := ghcontext.WithTokenInfo(ctx, tokenInfo) + ctx = ghcontext.WithTokenScopes(ctx, scopesList) next.ServeHTTP(w, r.WithContext(ctx)) return diff --git a/pkg/http/middleware/pat_scope_test.go b/pkg/http/middleware/pat_scope_test.go index eb472bcf1..0607b8cf2 100644 --- a/pkg/http/middleware/pat_scope_test.go +++ b/pkg/http/middleware/pat_scope_test.go @@ -111,12 +111,13 @@ func TestWithPATScopes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var capturedTokenInfo *ghcontext.TokenInfo + var capturedScopes []string + var scopesFound bool var nextHandlerCalled bool nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextHandlerCalled = true - capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context()) w.WriteHeader(http.StatusOK) }) @@ -141,10 +142,9 @@ func TestWithPATScopes(t *testing.T) { assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch") - if tt.expectNextHandlerCalled && tt.tokenInfo != nil { - require.NotNil(t, capturedTokenInfo, "expected token info in context") - assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched) - assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes) + if tt.expectNextHandlerCalled { + assert.Equal(t, tt.expectScopesFetched, scopesFound, "scopes found mismatch") + assert.Equal(t, tt.expectedScopes, capturedScopes) } }) } @@ -154,9 +154,12 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { logger := slog.Default() var capturedTokenInfo *ghcontext.TokenInfo + var capturedScopes []string + var scopesFound bool nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context()) + capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context()) w.WriteHeader(http.StatusOK) }) @@ -182,6 +185,6 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { require.NotNil(t, capturedTokenInfo) assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token) assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType) - assert.True(t, capturedTokenInfo.ScopesFetched) - assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes) + assert.True(t, scopesFound) + assert.Equal(t, []string{"repo", "user"}, capturedScopes) } diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 526797241..1a86bf93c 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -94,17 +94,19 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter return } - // Get OAuth scopes from GitHub API - activeScopes, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) - if err != nil { - next.ServeHTTP(w, r) - return + // Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present. + // This allows Remote Server to pass scope info to avoid redundant GitHub API calls. + activeScopes, ok := ghcontext.GetTokenScopes(ctx) + if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") { + activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) + if err != nil { + next.ServeHTTP(w, r) + return + } } // Store active scopes in context for downstream use - tokenInfo.Scopes = activeScopes - tokenInfo.ScopesFetched = true - ctx = ghcontext.WithTokenInfo(ctx, tokenInfo) + ctx = ghcontext.WithTokenScopes(ctx, activeScopes) r = r.WithContext(ctx) // Check if user has the required scopes diff --git a/pkg/http/middleware/token.go b/pkg/http/middleware/token.go index c362ea201..012bbabef 100644 --- a/pkg/http/middleware/token.go +++ b/pkg/http/middleware/token.go @@ -13,6 +13,16 @@ import ( func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Check if token info already exists in context, if it does, skip extraction. + // In remote setup, we may have already extracted token info earlier. + if _, ok := ghcontext.GetTokenInfo(ctx); ok { + // Token info already exists in context, skip extraction + next.ServeHTTP(w, r) + return + } + tokenType, token, err := utils.ParseAuthorizationHeader(r) if err != nil { // For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec @@ -25,7 +35,6 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl return } - ctx := r.Context() ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ Token: token, TokenType: tokenType,