diff --git a/models/git/protected_branch.go b/models/git/protected_branch.go index 13e1ced0e1849..1085c14cae663 100644 --- a/models/git/protected_branch.go +++ b/models/git/protected_branch.go @@ -466,11 +466,13 @@ func updateApprovalWhitelist(ctx context.Context, repo *repo_model.Repository, c return currentWhitelist, nil } + prUserIDs, err := access_model.GetUserIDsWithUnitAccess(ctx, repo, perm.AccessModeRead, unit.TypePullRequests) + if err != nil { + return nil, err + } whitelist = make([]int64, 0, len(newWhitelist)) for _, userID := range newWhitelist { - if reader, err := access_model.IsRepoReader(ctx, repo, userID); err != nil { - return nil, err - } else if !reader { + if !prUserIDs.Contains(userID) { continue } whitelist = append(whitelist, userID) diff --git a/models/organization/team_repo.go b/models/organization/team_repo.go index b3e266dbc7651..2652b34c6f7d8 100644 --- a/models/organization/team_repo.go +++ b/models/organization/team_repo.go @@ -53,24 +53,45 @@ func RemoveTeamRepo(ctx context.Context, teamID, repoID int64) error { // GetTeamsWithAccessToAnyRepoUnit returns all teams in an organization that have given access level to the repository special unit. // This function is only used for finding some teams that can be used as branch protection allowlist or reviewers, it isn't really used for access control. // FIXME: TEAM-UNIT-PERMISSION this logic is not complete, search the fixme keyword to see more details -func GetTeamsWithAccessToAnyRepoUnit(ctx context.Context, orgID, repoID int64, mode perm.AccessMode, unitType unit.Type, unitTypesMore ...unit.Type) ([]*Team, error) { - teams := make([]*Team, 0, 5) +func GetTeamsWithAccessToAnyRepoUnit(ctx context.Context, orgID, repoID int64, mode perm.AccessMode, unitType unit.Type, unitTypesMore ...unit.Type) (teams []*Team, err error) { + teamIDs, err := getTeamIDsWithAccessToAnyRepoUnit(ctx, orgID, repoID, mode, unitType, unitTypesMore...) + if err != nil { + return nil, err + } + if len(teamIDs) == 0 { + return teams, nil + } + err = db.GetEngine(ctx).Where(builder.In("id", teamIDs)).OrderBy("team.name").Find(&teams) + return teams, err +} +func getTeamIDsWithAccessToAnyRepoUnit(ctx context.Context, orgID, repoID int64, mode perm.AccessMode, unitType unit.Type, unitTypesMore ...unit.Type) (teamIDs []int64, err error) { sub := builder.Select("team_id").From("team_unit"). Where(builder.Expr("team_unit.team_id = team.id")). And(builder.In("team_unit.type", append([]unit.Type{unitType}, unitTypesMore...))). And(builder.Expr("team_unit.access_mode >= ?", mode)) - err := db.GetEngine(ctx). + err = db.GetEngine(ctx). + Select("team.id"). + Table("team"). Join("INNER", "team_repo", "team_repo.team_id = team.id"). - And("team_repo.org_id = ?", orgID). - And("team_repo.repo_id = ?", repoID). + And("team_repo.org_id = ? AND team_repo.repo_id = ?", orgID, repoID). And(builder.Or( builder.Expr("team.authorize >= ?", mode), builder.In("team.id", sub), )). - OrderBy("name"). - Find(&teams) + Find(&teamIDs) + return teamIDs, err +} - return teams, err +func GetTeamUserIDsWithAccessToAnyRepoUnit(ctx context.Context, orgID, repoID int64, mode perm.AccessMode, unitType unit.Type, unitTypesMore ...unit.Type) (userIDs []int64, err error) { + teamIDs, err := getTeamIDsWithAccessToAnyRepoUnit(ctx, orgID, repoID, mode, unitType, unitTypesMore...) + if err != nil { + return nil, err + } + if len(teamIDs) == 0 { + return userIDs, nil + } + err = db.GetEngine(ctx).Table("team_user").Select("uid").Where(builder.In("team_id", teamIDs)).Find(&userIDs) + return userIDs, err } diff --git a/models/perm/access/repo_permission.go b/models/perm/access/repo_permission.go index 034557db33a8d..15526cb1e6f1f 100644 --- a/models/perm/access/repo_permission.go +++ b/models/perm/access/repo_permission.go @@ -16,6 +16,7 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unit" user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/container" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" @@ -498,54 +499,44 @@ func HasAnyUnitAccess(ctx context.Context, userID int64, repo *repo_model.Reposi return perm.HasAnyUnitAccess(), nil } -// getUsersWithAccessMode returns users that have at least given access mode to the repository. -func getUsersWithAccessMode(ctx context.Context, repo *repo_model.Repository, mode perm_model.AccessMode) (_ []*user_model.User, err error) { - if err = repo.LoadOwner(ctx); err != nil { +func GetUsersWithUnitAccess(ctx context.Context, repo *repo_model.Repository, mode perm_model.AccessMode, unitType unit.Type) (users []*user_model.User, err error) { + userIDs, err := GetUserIDsWithUnitAccess(ctx, repo, mode, unitType) + if err != nil { return nil, err } + if len(userIDs) == 0 { + return users, nil + } + if err = db.GetEngine(ctx).In("id", userIDs.Values()).OrderBy("`name`").Find(&users); err != nil { + return nil, err + } + return users, nil +} +func GetUserIDsWithUnitAccess(ctx context.Context, repo *repo_model.Repository, mode perm_model.AccessMode, unitType unit.Type) (container.Set[int64], error) { + userIDs := container.Set[int64]{} e := db.GetEngine(ctx) accesses := make([]*Access, 0, 10) - if err = e.Where("repo_id = ? AND mode >= ?", repo.ID, mode).Find(&accesses); err != nil { + if err := e.Where("repo_id = ? AND mode >= ?", repo.ID, mode).Find(&accesses); err != nil { return nil, err } + for _, a := range accesses { + userIDs.Add(a.UserID) + } - // Leave a seat for owner itself to append later, but if owner is an organization - // and just waste 1 unit is cheaper than re-allocate memory once. - users := make([]*user_model.User, 0, len(accesses)+1) - if len(accesses) > 0 { - userIDs := make([]int64, len(accesses)) - for i := 0; i < len(accesses); i++ { - userIDs[i] = accesses[i].UserID - } - - if err = e.In("id", userIDs).Find(&users); err != nil { - return nil, err - } + if err := repo.LoadOwner(ctx); err != nil { + return nil, err } if !repo.Owner.IsOrganization() { - users = append(users, repo.Owner) - } - - return users, nil -} - -// GetRepoReaders returns all users that have explicit read access or higher to the repository. -func GetRepoReaders(ctx context.Context, repo *repo_model.Repository) (_ []*user_model.User, err error) { - return getUsersWithAccessMode(ctx, repo, perm_model.AccessModeRead) -} - -// GetRepoWriters returns all users that have write access to the repository. -func GetRepoWriters(ctx context.Context, repo *repo_model.Repository) (_ []*user_model.User, err error) { - return getUsersWithAccessMode(ctx, repo, perm_model.AccessModeWrite) -} - -// IsRepoReader returns true if user has explicit read access or higher to the repository. -func IsRepoReader(ctx context.Context, repo *repo_model.Repository, userID int64) (bool, error) { - if repo.OwnerID == userID { - return true, nil + userIDs.Add(repo.Owner.ID) + } else { + teamUserIDs, err := organization.GetTeamUserIDsWithAccessToAnyRepoUnit(ctx, repo.OwnerID, repo.ID, mode, unitType) + if err != nil { + return nil, err + } + userIDs.AddMultiple(teamUserIDs...) } - return db.GetEngine(ctx).Where("repo_id = ? AND user_id = ? AND mode >= ?", repo.ID, userID, perm_model.AccessModeRead).Get(&Access{}) + return userIDs, nil } // CheckRepoUnitUser check whether user could visit the unit of this repository diff --git a/models/perm/access/repo_permission_test.go b/models/perm/access/repo_permission_test.go index d81dfba288e2c..a36be213ece1a 100644 --- a/models/perm/access/repo_permission_test.go +++ b/models/perm/access/repo_permission_test.go @@ -169,9 +169,9 @@ func TestGetUserRepoPermission(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 4}) team := &organization.Team{OrgID: org.ID, LowerName: "test_team"} require.NoError(t, db.Insert(ctx, team)) + require.NoError(t, db.Insert(ctx, &organization.TeamUser{OrgID: org.ID, TeamID: team.ID, UID: user.ID})) t.Run("DoerInTeamWithNoRepo", func(t *testing.T) { - require.NoError(t, db.Insert(ctx, &organization.TeamUser{OrgID: org.ID, TeamID: team.ID, UID: user.ID})) perm, err := GetUserRepoPermission(ctx, repo32, user) require.NoError(t, err) assert.Equal(t, perm_model.AccessModeRead, perm.AccessMode) @@ -219,6 +219,15 @@ func TestGetUserRepoPermission(t *testing.T) { assert.Equal(t, perm_model.AccessModeNone, perm.AccessMode) assert.Equal(t, perm_model.AccessModeNone, perm.unitsMode[unit.TypeCode]) assert.Equal(t, perm_model.AccessModeRead, perm.unitsMode[unit.TypeIssues]) + + users, err := GetUsersWithUnitAccess(ctx, repo3, perm_model.AccessModeRead, unit.TypeIssues) + require.NoError(t, err) + require.Len(t, users, 1) + assert.Equal(t, user.ID, users[0].ID) + + users, err = GetUsersWithUnitAccess(ctx, repo3, perm_model.AccessModeWrite, unit.TypeIssues) + require.NoError(t, err) + require.Empty(t, users) }) require.NoError(t, db.Insert(ctx, repo_model.Collaboration{RepoID: repo3.ID, UserID: user.ID, Mode: perm_model.AccessModeWrite})) @@ -229,5 +238,10 @@ func TestGetUserRepoPermission(t *testing.T) { assert.Equal(t, perm_model.AccessModeWrite, perm.AccessMode) assert.Equal(t, perm_model.AccessModeWrite, perm.unitsMode[unit.TypeCode]) assert.Equal(t, perm_model.AccessModeWrite, perm.unitsMode[unit.TypeIssues]) + + users, err := GetUsersWithUnitAccess(ctx, repo3, perm_model.AccessModeWrite, unit.TypeIssues) + require.NoError(t, err) + require.Len(t, users, 1) + assert.Equal(t, user.ID, users[0].ID) }) } diff --git a/routers/web/repo/setting/protected_branch.go b/routers/web/repo/setting/protected_branch.go index c7bde087a2d99..4374e95340a1c 100644 --- a/routers/web/repo/setting/protected_branch.go +++ b/routers/web/repo/setting/protected_branch.go @@ -73,10 +73,9 @@ func SettingsProtectedBranch(c *context.Context) { c.Data["PageIsSettingsBranches"] = true c.Data["Title"] = c.Locale.TrString("repo.settings.protected_branch") + " - " + rule.RuleName - - users, err := access_model.GetRepoReaders(c, c.Repo.Repository) + users, err := access_model.GetUsersWithUnitAccess(c, c.Repo.Repository, perm.AccessModeRead, unit.TypePullRequests) if err != nil { - c.ServerError("Repo.Repository.GetReaders", err) + c.ServerError("GetUsersWithUnitAccess", err) return } c.Data["Users"] = users diff --git a/routers/web/repo/setting/protected_tag.go b/routers/web/repo/setting/protected_tag.go index 50f5a28c4c7f2..4b560e6f22451 100644 --- a/routers/web/repo/setting/protected_tag.go +++ b/routers/web/repo/setting/protected_tag.go @@ -149,9 +149,9 @@ func setTagsContext(ctx *context.Context) error { } ctx.Data["ProtectedTags"] = protectedTags - users, err := access_model.GetRepoReaders(ctx, ctx.Repo.Repository) + users, err := access_model.GetUsersWithUnitAccess(ctx, ctx.Repo.Repository, perm.AccessModeRead, unit.TypePullRequests) if err != nil { - ctx.ServerError("Repo.Repository.GetReaders", err) + ctx.ServerError("GetUsersWithUnitAccess", err) return err } ctx.Data["Users"] = users diff --git a/services/convert/convert.go b/services/convert/convert.go index 9f8fff970ccc0..8b10d9364024c 100644 --- a/services/convert/convert.go +++ b/services/convert/convert.go @@ -139,7 +139,7 @@ func getWhitelistEntities[T *user_model.User | *organization.Team](entities []T, // ToBranchProtection convert a ProtectedBranch to api.BranchProtection func ToBranchProtection(ctx context.Context, bp *git_model.ProtectedBranch, repo *repo_model.Repository) *api.BranchProtection { - readers, err := access_model.GetRepoReaders(ctx, repo) + readers, err := access_model.GetUsersWithUnitAccess(ctx, repo, perm.AccessModeRead, unit.TypePullRequests) if err != nil { log.Error("GetRepoReaders: %v", err) } @@ -720,7 +720,7 @@ func ToAnnotatedTagObject(repo *repo_model.Repository, commit *git.Commit) *api. // ToTagProtection convert a git.ProtectedTag to an api.TagProtection func ToTagProtection(ctx context.Context, pt *git_model.ProtectedTag, repo *repo_model.Repository) *api.TagProtection { - readers, err := access_model.GetRepoReaders(ctx, repo) + readers, err := access_model.GetUsersWithUnitAccess(ctx, repo, perm.AccessModeRead, unit.TypePullRequests) if err != nil { log.Error("GetRepoReaders: %v", err) }