diff --git a/cmd/github-mcp-server/generate_docs.go b/cmd/github-mcp-server/generate_docs.go index 4d03c446e..478b35bd4 100644 --- a/cmd/github-mcp-server/generate_docs.go +++ b/cmd/github-mcp-server/generate_docs.go @@ -64,7 +64,7 @@ func generateReadmeDocs(readmePath string) error { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000) + tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}) // Generate toolsets documentation toolsetsDoc := generateToolsetsDoc(tsg) @@ -302,7 +302,7 @@ func generateRemoteToolsetsDoc() string { t, _ := translations.TranslationHelper() // Create toolset group with mock clients - tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000) + tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}) // Generate table header buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n") diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 647ec1d19..125cd5a8d 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -61,6 +61,7 @@ var ( EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), ContentWindowSize: viper.GetInt("content-window-size"), + LockdownMode: viper.GetBool("lockdown-mode"), } return ghmcp.RunStdioServer(stdioServerConfig) }, @@ -82,6 +83,7 @@ func init() { rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file") rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)") rootCmd.PersistentFlags().Int("content-window-size", 5000, "Specify the content window size") + rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode") // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) @@ -92,6 +94,7 @@ func init() { _ = viper.BindPFlag("export-translations", rootCmd.PersistentFlags().Lookup("export-translations")) _ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host")) _ = viper.BindPFlag("content-window-size", rootCmd.PersistentFlags().Lookup("content-window-size")) + _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) // Add subcommands rootCmd.AddCommand(stdioCmd) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 1e66f1eb3..0e338cfd9 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -51,6 +51,9 @@ type MCPServerConfig struct { // Content window size ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool } const stdioServerLogPrefix = "stdioserver" @@ -154,7 +157,15 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { } // Create default toolsets - tsg := github.DefaultToolsetGroup(cfg.ReadOnly, getClient, getGQLClient, getRawClient, cfg.Translator, cfg.ContentWindowSize) + tsg := github.DefaultToolsetGroup( + cfg.ReadOnly, + getClient, + getGQLClient, + getRawClient, + cfg.Translator, + cfg.ContentWindowSize, + github.FeatureFlags{LockdownMode: cfg.LockdownMode}, + ) err = tsg.EnableToolsets(enabledToolsets, nil) if err != nil { @@ -205,6 +216,9 @@ type StdioServerConfig struct { // Content window size ContentWindowSize int + + // LockdownMode indicates if we should enable lockdown mode + LockdownMode bool } // RunStdioServer is not concurrent safe. @@ -224,6 +238,7 @@ func RunStdioServer(cfg StdioServerConfig) error { ReadOnly: cfg.ReadOnly, Translator: t, ContentWindowSize: cfg.ContentWindowSize, + LockdownMode: cfg.LockdownMode, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) @@ -245,7 +260,7 @@ func RunStdioServer(cfg StdioServerConfig) error { slogHandler = slog.NewTextHandler(logOutput, &slog.HandlerOptions{Level: slog.LevelInfo}) } logger := slog.New(slogHandler) - logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly) + logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) stdLogger := log.New(logOutput, stdioServerLogPrefix, 0) stdioServer.SetErrorLogger(stdLogger) diff --git a/pkg/github/feature_flags.go b/pkg/github/feature_flags.go new file mode 100644 index 000000000..047042e44 --- /dev/null +++ b/pkg/github/feature_flags.go @@ -0,0 +1,6 @@ +package github + +// FeatureFlags defines runtime feature toggles that adjust tool behavior. +type FeatureFlags struct { + LockdownMode bool +} diff --git a/pkg/github/issues.go b/pkg/github/issues.go index c83aac8ff..1032d4d04 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -10,6 +10,7 @@ import ( "time" ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/lockdown" "github.com/github/github-mcp-server/pkg/sanitize" "github.com/github/github-mcp-server/pkg/translations" "github.com/go-viper/mapstructure/v2" @@ -227,7 +228,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue { } // GetIssue creates a tool to get details of a specific issue in a GitHub repository. -func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("issue_read", mcp.WithDescription(t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -296,20 +297,20 @@ Options are: switch method { case "get": - return GetIssue(ctx, client, owner, repo, issueNumber) + return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination) + return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags) case "get_sub_issues": - return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination) + return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags) case "get_labels": - return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber) + return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags) default: return mcp.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil } } } -func GetIssue(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) { +func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) { issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -324,6 +325,18 @@ func GetIssue(ctx context.Context, client *github.Client, owner string, repo str return mcp.NewToolResultError(fmt.Sprintf("failed to get issue: %s", string(body))), nil } + if flags.LockdownMode { + if issue.User != nil { + shouldRemoveContent, err := lockdown.ShouldRemoveContent(ctx, gqlClient, *issue.User.Login, owner, repo) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil + } + if shouldRemoveContent { + return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil + } + } + } + // Sanitize title/body on response if issue != nil { if issue.Title != nil { @@ -342,7 +355,7 @@ func GetIssue(ctx context.Context, client *github.Client, owner string, repo str return mcp.NewToolResultText(string(r)), nil } -func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { +func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListCommentsOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -372,7 +385,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string, return mcp.NewToolResultText(string(r)), nil } -func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { +func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) { opts := &github.IssueListOptions{ ListOptions: github.ListOptions{ Page: pagination.Page, @@ -407,7 +420,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo return mcp.NewToolResultText(string(r)), nil } -func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) { +func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int, _ FeatureFlags) (*mcp.CallToolResult, error) { // Get current labels on the issue using GraphQL var query struct { Repository struct { diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index ddd9104b3..d13b93e4b 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -22,8 +22,8 @@ import ( func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - mockGQLClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) + defaultGQLClient := githubv4.NewClient(nil) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -44,15 +44,24 @@ func Test_GetIssue(t *testing.T) { User: &github.User{ Login: github.Ptr("testuser"), }, + Repository: &github.Repository{ + Name: github.Ptr("repo"), + Owner: &github.User{ + Login: github.Ptr("owner"), + }, + }, } tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedIssue *github.Issue - expectedErrMsg string + name string + mockedClient *http.Client + gqlHTTPClient *http.Client + requestArgs map[string]interface{} + expectHandlerError bool + expectResultError bool + expectedIssue *github.Issue + expectedErrMsg string + lockdownEnabled bool }{ { name: "successful issue retrieval", @@ -68,7 +77,6 @@ func Test_GetIssue(t *testing.T) { "repo": "repo", "issue_number": float64(42), }, - expectError: false, expectedIssue: mockIssue, }, { @@ -85,34 +93,147 @@ func Test_GetIssue(t *testing.T) { "repo": "repo", "issue_number": float64(999), }, - expectError: true, - expectedErrMsg: "failed to get issue", + expectHandlerError: true, + expectedErrMsg: "failed to get issue", + }, + { + name: "lockdown enabled - private repository", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockIssue, + ), + ), + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "name": githubv4.String("repo"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": true, + "collaborators": map[string]any{ + "edges": []any{}, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "method": "get", + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectedIssue: mockIssue, + lockdownEnabled: true, + }, + { + name: "lockdown enabled - user lacks push access", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.GetReposIssuesByOwnerByRepoByIssueNumber, + mockIssue, + ), + ), + gqlHTTPClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "name": githubv4.String("repo"), + "username": githubv4.String("testuser"), + }, + githubv4mock.DataResponse(map[string]any{ + "repository": map[string]any{ + "isPrivate": false, + "collaborators": map[string]any{ + "edges": []any{ + map[string]any{ + "permission": "READ", + "node": map[string]any{ + "login": "testuser", + }, + }, + }, + }, + }, + }), + ), + ), + requestArgs: map[string]interface{}{ + "method": "get", + "owner": "owner", + "repo": "repo", + "issue_number": float64(42), + }, + expectResultError: true, + expectedErrMsg: "access to issue details is restricted by lockdown mode", + lockdownEnabled: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(mockGQLClient), translations.NullTranslationHelper) - // Create call request - request := createMCPRequest(tc.requestArgs) + var gqlClient *githubv4.Client + if tc.gqlHTTPClient != nil { + gqlClient = githubv4.NewClient(tc.gqlHTTPClient) + } else { + gqlClient = defaultGQLClient + } - // Call handler + flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags) + + request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) - // Verify results - if tc.expectError { + if tc.expectHandlerError { require.Error(t, err) assert.Contains(t, err.Error(), tc.expectedErrMsg) return } require.NoError(t, err) + require.NotNil(t, result) + + if tc.expectResultError { + errorContent := getErrorResult(t, result) + assert.Contains(t, errorContent.Text, tc.expectedErrMsg) + return + } + textContent := getTextResult(t, result) - // Unmarshal and verify the result var returnedIssue github.Issue err = json.Unmarshal([]byte(textContent.Text), &returnedIssue) require.NoError(t, err) @@ -1589,7 +1710,7 @@ func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1695,7 +1816,7 @@ func Test_GetIssueComments(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1732,7 +1853,7 @@ func Test_GetIssueLabels(t *testing.T) { // Verify tool definition mockGQClient := githubv4.NewClient(nil) mockClient := github.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), translations.NullTranslationHelper) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -1807,7 +1928,7 @@ func Test_GetIssueLabels(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gqlClient := githubv4.NewClient(tc.mockedClient) client := github.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) request := createMCPRequest(tc.requestArgs) result, err := handler(context.Background(), request) @@ -2498,7 +2619,7 @@ func Test_GetSubIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) gqlClient := githubv4.NewClient(nil) - tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "issue_read", tool.Name) @@ -2695,7 +2816,7 @@ func Test_GetSubIssues(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) gqlClient := githubv4.NewClient(nil) - _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper) + _, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index e08324544..117f92ecf 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -19,7 +19,7 @@ import ( ) // GetPullRequest creates a tool to get details of a specific pull request. -func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { +func PullRequestRead(getClient GetClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("pull_request_read", mcp.WithDescription(t("TOOL_PULL_REQUEST_READ_DESCRIPTION", "Get information on a specific pull request in GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -98,7 +98,7 @@ Possible options: case "get_reviews": return GetPullRequestReviews(ctx, client, owner, repo, pullNumber) case "get_comments": - return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination) + return GetIssueComments(ctx, client, owner, repo, pullNumber, pagination, flags) default: return nil, fmt.Errorf("unknown method: %s", method) } diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 1a7635afb..4cc4480e9 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -21,7 +21,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -102,7 +102,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1133,7 +1133,7 @@ func Test_SearchPullRequests(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1236,7 +1236,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1277,7 +1277,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1404,7 +1404,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1566,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1658,7 +1658,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1700,7 +1700,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -1788,7 +1788,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) @@ -2789,7 +2789,7 @@ func TestGetPullRequestDiff(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper) + tool, _ := PullRequestRead(stubGetClientFn(mockClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) require.NoError(t, toolsnaps.Test(tool.Name, tool)) assert.Equal(t, "pull_request_read", tool.Name) @@ -2847,7 +2847,7 @@ index 5d6e7b2..8a4f5c3 100644 // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper) + _, handler := PullRequestRead(stubGetClientFn(client), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false})) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 0550655a5..77752d090 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -38,6 +38,12 @@ func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { } } +func stubFeatureFlags(enabledFlags map[string]bool) FeatureFlags { + return FeatureFlags{ + LockdownMode: enabledFlags["lockdown-mode"], + } +} + func stubGetRawClientFn(client *raw.Client) raw.GetRawClientFn { return func(_ context.Context) (*raw.Client, error) { return client, nil diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 837880bf7..36c22e7a8 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -159,7 +159,7 @@ func GetDefaultToolsetIDs() []string { } } -func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int) *toolsets.ToolsetGroup { +func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, getRawClient raw.GetRawClientFn, t translations.TranslationHelperFunc, contentWindowSize int, flags FeatureFlags) *toolsets.ToolsetGroup { tsg := toolsets.NewToolsetGroup(readOnly) // Define all available features with their default state (disabled) @@ -199,7 +199,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) issues := toolsets.NewToolset(ToolsetMetadataIssues.ID, ToolsetMetadataIssues.Description). AddReadTools( - toolsets.NewServerTool(IssueRead(getClient, getGQLClient, t)), + toolsets.NewServerTool(IssueRead(getClient, getGQLClient, t, flags)), toolsets.NewServerTool(SearchIssues(getClient, t)), toolsets.NewServerTool(ListIssues(getGQLClient, t)), toolsets.NewServerTool(ListIssueTypes(getClient, t)), @@ -224,7 +224,7 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG ) pullRequests := toolsets.NewToolset(ToolsetMetadataPullRequests.ID, ToolsetMetadataPullRequests.Description). AddReadTools( - toolsets.NewServerTool(PullRequestRead(getClient, t)), + toolsets.NewServerTool(PullRequestRead(getClient, t, flags)), toolsets.NewServerTool(ListPullRequests(getClient, t)), toolsets.NewServerTool(SearchPullRequests(getClient, t)), ). diff --git a/pkg/lockdown/lockdown.go b/pkg/lockdown/lockdown.go new file mode 100644 index 000000000..5a474f73c --- /dev/null +++ b/pkg/lockdown/lockdown.go @@ -0,0 +1,71 @@ +package lockdown + +import ( + "context" + "fmt" + "strings" + + "github.com/shurcooL/githubv4" +) + +// ShouldRemoveContent determines if content should be removed based on +// lockdown mode rules. It checks if the repository is private and if the user +// has push access to the repository. +func ShouldRemoveContent(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, error) { + isPrivate, hasPushAccess, err := repoAccessInfo(ctx, client, username, owner, repo) + if err != nil { + return false, err + } + + // Do not filter content for private repositories + if isPrivate { + return false, nil + } + + return !hasPushAccess, nil +} + +func repoAccessInfo(ctx context.Context, client *githubv4.Client, username, owner, repo string) (bool, bool, error) { + if client == nil { + return false, false, fmt.Errorf("nil GraphQL client") + } + + var query struct { + Repository struct { + IsPrivate githubv4.Boolean + Collaborators struct { + Edges []struct { + Permission githubv4.String + Node struct { + Login githubv4.String + } + } + } `graphql:"collaborators(query: $username, first: 1)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(owner), + "name": githubv4.String(repo), + "username": githubv4.String(username), + } + + err := client.Query(ctx, &query, variables) + if err != nil { + return false, false, fmt.Errorf("failed to query repository access info: %w", err) + } + + // Check if the user has push access + hasPush := false + for _, edge := range query.Repository.Collaborators.Edges { + login := string(edge.Node.Login) + if strings.EqualFold(login, username) { + permission := string(edge.Permission) + // WRITE, ADMIN, and MAINTAIN permissions have push access + hasPush = permission == "WRITE" || permission == "ADMIN" || permission == "MAINTAIN" + break + } + } + + return bool(query.Repository.IsPrivate), hasPush, nil +}