Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] 571877c327 Bump github.com/spf13/cobra from 1.8.0 to 1.8.1
Bumps [github.com/spf13/cobra](https://github.com/spf13/cobra) from 1.8.0 to 1.8.1.
- [Release notes](https://github.com/spf13/cobra/releases)
- [Commits](https://github.com/spf13/cobra/compare/v1.8.0...v1.8.1)

---
updated-dependencies:
- dependency-name: github.com/spf13/cobra
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-06-17 20:19:22 +00:00
21 changed files with 151 additions and 700 deletions
+1 -1
View File
@@ -1,2 +1,2 @@
* @actions/actions-oss-maintainers @actions/actions-sync-maintainers
* @actions/actions-delivery-nexus @actions/actions-oss-maintainers
+3 -3
View File
@@ -9,8 +9,8 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Lint
run: docker compose run --rm lint
run: docker-compose run --rm lint
- name: Test
run: docker compose run --rm test
run: docker-compose run --rm test
- name: E2E
run: docker compose run --rm test-build
run: docker-compose run --rm test-build
+1 -1
View File
@@ -35,7 +35,7 @@ jobs:
- run: go mod vendor
# Ruby is required for licensed
- uses: ruby/setup-ruby@90be1154f987f4dc0fe0dd0feedac9e473aa4ba8 # v1
- uses: ruby/setup-ruby@6bd3d993c602f6b675728ebaecb2b569ff86e99b
with:
ruby-version: "3.2"
+4 -22
View File
@@ -14,24 +14,6 @@ It is designed to work when:
* The GitHub Enterprise instance is separate from the rest of the internet.
* The GitHub Enterprise instance is connected to the rest of the internet.
### Note
Thank you for your interest in this GitHub action, however, right now we are not taking contributions.
We continue to focus our resources on strategic areas that help our customers be successful while making developers' lives easier. While GitHub Actions remains a key part of this vision, we are allocating resources towards other areas of Actions and are not taking contributions to this repository at this time. The GitHub public roadmap is the best place to follow along for any updates on features were working on and what stage theyre in.
We are taking the following steps to better direct requests related to GitHub Actions, including:
1. We will be directing questions and support requests to our [Community Discussions area](https://github.com/orgs/community/discussions/categories/actions)
2. High Priority bugs can be reported through Community Discussions or you can report these to our support team https://support.github.com/contact/bug-report.
3. Security Issues should be handled as per our [security.md](security.md)
We will still provide security updates for this project and fix major breaking changes during this time.
You are welcome to still raise bugs in this repo.
## Connected instances
When there are machines which have access to both the public internet and the GHES instance run `actions-sync sync`.
@@ -56,8 +38,6 @@ When there are machines which have access to both the public internet and the GH
A path to a file containing a newline separated list of repositories to be synced. Each entry follows the format of `repo-name`.
- `actions-admin-user` _(optional)_
The name of the Actions admin user, which will be used for updating the chosen action. To use the default user, pass `actions-admin`. If not set, the impersonation is disabled. Note that `site_admin` scope is required in the token for the impersonation to work.
- `batch-size` _(optional)_
Number of refs to push in each batch. Default is 0 (no batching). Use a value like 100 if pushing fails for large repositories with many branches and tags.
**Example Usage:**
@@ -116,8 +96,6 @@ When no machine has access to both the public internet and the GHES instance:
Limit push to specific repositories in the cache directory.
- `actions-admin-user` _(optional)_
The name of the Actions admin user, which will be used for updating the chosen action. To use the default user, pass `actions-admin`. If not set, the impersonation is disabled. Note that `site_admin` scope is required in the token for the impersonation to work.
- `batch-size` _(optional)_
Number of refs to push in each batch. Default is 0 (no batching). Use a value like 100 if pushing fails for large repositories with many branches and tags.
**Example Usage:**
@@ -132,3 +110,7 @@ When no machine has access to both the public internet and the GHES instance:
When creating a personal access token include the `repo` and `workflow` scopes. Include the `site_admin` scope (optional) if you want organizations to be created as necessary or you want to use the impersonation logic for the `push` or `sync` commands.
## Contributing
If you would like to contribute your work back to the project, please see
[`CONTRIBUTING.md`](CONTRIBUTING.md).
+1 -1
View File
@@ -7,7 +7,7 @@ require (
github.com/google/go-github/v43 v43.0.0
github.com/gorilla/mux v1.8.1
github.com/pkg/errors v0.9.1
github.com/spf13/cobra v1.8.0
github.com/spf13/cobra v1.8.1
github.com/stretchr/testify v1.9.0
golang.org/x/oauth2 v0.19.0
)
+3 -3
View File
@@ -13,7 +13,7 @@ github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7N
github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA=
github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU=
github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg=
github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -73,8 +73,8 @@ github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skeema/knownhosts v1.2.2 h1:Iug2P4fLmDw9f41PB6thxUkNUkJzB5i+1/exaj40L3A=
github.com/skeema/knownhosts v1.2.2/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+4 -4
View File
@@ -18,12 +18,12 @@ if [ ! -f go.mod ]; then
go mod init tools
fi
go get golang.org/x/tools/go/packages@v0.16.0
go get golang.org/x/tools/go/packages@master
if [ ! -f "${GOBIN}/mockgen" ]; then
echo "mockgen was not found, installing..."
go get github.com/golang/mock/gomock@v1.6.0
go get github.com/golang/mock/mockgen@v1.6.0
go get github.com/golang/mock/gomock@master
go get github.com/golang/mock/mockgen@master
fi
if [ ! -f "${GOBIN}/golangci-lint" ]; then
@@ -33,5 +33,5 @@ fi
if [ ! -f "${GOBIN}/goimports" ]; then
echo "goimports was not found, installing..."
go get golang.org/x/tools/cmd/goimports@v0.16.0
go get golang.org/x/tools/cmd/goimports@master
fi
-6
View File
@@ -5,7 +5,6 @@ import (
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/config"
"github.com/go-git/go-git/v5/plumbing/storer"
)
// A really thin Git wrapper so we can stub it out in our tests
@@ -20,7 +19,6 @@ type GitRepository interface {
DeleteRemote(string) error
CreateRemote(*config.RemoteConfig) (GitRemote, error)
FetchContext(context.Context, *git.FetchOptions) error
References() (storer.ReferenceIter, error)
}
type GitRemote interface {
@@ -67,7 +65,3 @@ func (r *gitRepository) CreateRemote(c *config.RemoteConfig) (GitRemote, error)
func (r *gitRepository) FetchContext(ctx context.Context, o *git.FetchOptions) error {
return r.inner.FetchContext(ctx, o)
}
func (r *gitRepository) References() (storer.ReferenceIter, error) {
return r.inner.References()
}
-75
View File
@@ -1,75 +0,0 @@
package src
import (
"context"
"testing"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/config"
"github.com/go-git/go-git/v5/plumbing/storer"
"github.com/stretchr/testify/assert"
)
// Tests for GitRepository interface and implementations
func TestGitRepositoryInterface(t *testing.T) {
// This test verifies that our mock implements the GitRepository interface
var _ GitRepository = &mockGitRepository{}
}
func TestGitRemoteInterface(t *testing.T) {
// This test verifies that our mock implements the GitRemote interface
var _ GitRemote = &mockGitRemote{}
}
// Ensure the mockGitRepository implements all methods of GitRepository
func TestMockGitRepository_DeleteRemote(t *testing.T) {
repo := &mockGitRepository{}
err := repo.DeleteRemote("origin")
assert.NoError(t, err)
}
func TestMockGitRepository_CreateRemote(t *testing.T) {
repo := &mockGitRepository{}
remote, err := repo.CreateRemote(&config.RemoteConfig{Name: "test"})
assert.NoError(t, err)
assert.Nil(t, remote)
}
func TestMockGitRepository_FetchContext(t *testing.T) {
repo := &mockGitRepository{}
err := repo.FetchContext(context.Background(), &git.FetchOptions{})
assert.NoError(t, err)
}
func TestMockGitRepository_References(t *testing.T) {
repo := &mockGitRepository{}
refs, err := repo.References()
assert.NoError(t, err)
assert.NotNil(t, refs)
// Verify it returns a valid iterator
_, ok := refs.(storer.ReferenceIter)
assert.True(t, ok)
}
// Ensure the mockGitRemote implements all methods of GitRemote
func TestMockGitRemote_PushContext(t *testing.T) {
remote := &mockGitRemote{}
err := remote.PushContext(context.Background(), &git.PushOptions{})
assert.NoError(t, err)
}
func TestMockGitRemote_Config(t *testing.T) {
remote := &mockGitRemote{}
cfg := remote.Config()
assert.NotNil(t, cfg)
assert.Equal(t, "test-remote", cfg.Name)
// Test with custom config
customRemote := &mockGitRemote{
remoteConfig: &config.RemoteConfig{Name: "custom-remote"},
}
cfg = customRemote.Config()
assert.Equal(t, "custom-remote", cfg.Name)
}
+10 -92
View File
@@ -9,7 +9,6 @@ import (
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/config"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/transport"
"github.com/go-git/go-git/v5/plumbing/transport/http"
"github.com/google/go-github/v43/github"
@@ -23,16 +22,9 @@ const enterpriseAPIPath = "/api/v3"
const enterpriseVersionHeaderKey = "X-GitHub-Enterprise-Version"
const xOAuthScopesHeader = "X-OAuth-Scopes"
// DefaultBatchSize of 0 means no batching (push all refs at once, original behavior)
const DefaultBatchSize = 0
// MinBatchSize is the minimum allowed batch size when batching is enabled
const MinBatchSize = 10
type PushOnlyFlags struct {
BaseURL, Token, ActionsAdminUser string
DisableGitAuth bool
BatchSize int
}
type PushFlags struct {
@@ -50,7 +42,6 @@ func (f *PushOnlyFlags) Init(cmd *cobra.Command) {
cmd.Flags().StringVar(&f.ActionsAdminUser, "actions-admin-user", "", "A user to impersonate for the push requests. To use the default name, pass 'actions-admin'. Note that the site_admin scope in the token is required for the impersonation to work.")
cmd.Flags().StringVar(&f.Token, "destination-token", "", "Token to access API on GHES instance")
cmd.Flags().BoolVar(&f.DisableGitAuth, "disable-push-git-auth", false, "Disables git authentication whilst pushing")
cmd.Flags().IntVar(&f.BatchSize, "batch-size", DefaultBatchSize, "Number of refs to push in each batch (0 = no batching). Use a value like 100 if pushing fails for large repositories.")
}
func (f *PushFlags) Validate() Validations {
@@ -65,9 +56,6 @@ func (f *PushOnlyFlags) Validate() Validations {
if f.Token == "" {
validations = append(validations, "--destination-token must be set")
}
if f.BatchSize != 0 && f.BatchSize < MinBatchSize {
validations = append(validations, fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize))
}
return validations
}
@@ -294,86 +282,16 @@ func syncWithCachedRepository(ctx context.Context, flags *PushFlags, ghRepo *git
Password: flags.Token,
}
}
// If batch size is 0 or negative, use original wildcard approach (no batching)
if flags.BatchSize <= 0 {
err = remote.PushContext(ctx, &git.PushOptions{
RemoteName: remote.Config().Name,
RefSpecs: []config.RefSpec{
"+refs/heads/*:refs/heads/*",
"+refs/tags/*:refs/tags/*",
},
Auth: auth,
})
if errors.Cause(err) == git.NoErrAlreadyUpToDate {
return nil
}
return errors.Wrapf(err, "failed to push to repo: %s", ghRepo.GetCloneURL())
}
// Batching requested - collect all refs and push in batches
refs, err := collectRefs(gitRepo)
if err != nil {
return errors.Wrap(err, "error collecting refs")
}
return pushRefsInBatches(ctx, remote, refs, flags.BatchSize, auth, ghRepo.GetCloneURL())
}
// collectRefs gathers all branch and tag refs from the repository
func collectRefs(gitRepo GitRepository) ([]plumbing.ReferenceName, error) {
refIter, err := gitRepo.References()
if err != nil {
return nil, err
}
var refs []plumbing.ReferenceName
err = refIter.ForEach(func(ref *plumbing.Reference) error {
name := ref.Name()
// Only include branches and tags
if name.IsBranch() || name.IsTag() {
refs = append(refs, name)
}
return nil
err = remote.PushContext(ctx, &git.PushOptions{
RemoteName: remote.Config().Name,
RefSpecs: []config.RefSpec{
"+refs/heads/*:refs/heads/*",
"+refs/tags/*:refs/tags/*",
},
Auth: auth,
})
if err != nil {
return nil, err
if errors.Cause(err) == git.NoErrAlreadyUpToDate {
return nil
}
return refs, nil
}
// pushRefsInBatches pushes refs in smaller batches to avoid server-side limits
func pushRefsInBatches(ctx context.Context, remote GitRemote, refs []plumbing.ReferenceName, batchSize int, auth transport.AuthMethod, cloneURL string) error {
totalRefs := len(refs)
for i := 0; i < totalRefs; i += batchSize {
end := i + batchSize
if end > totalRefs {
end = totalRefs
}
batch := refs[i:end]
refSpecs := make([]config.RefSpec, len(batch))
for j, ref := range batch {
// Create a refspec like "+refs/heads/main:refs/heads/main"
refSpecs[j] = config.RefSpec("+" + ref.String() + ":" + ref.String())
}
err := remote.PushContext(ctx, &git.PushOptions{
RemoteName: remote.Config().Name,
RefSpecs: refSpecs,
Auth: auth,
})
if err != nil {
if errors.Cause(err) == git.NoErrAlreadyUpToDate {
// This batch was already up to date, continue to next batch
continue
}
return errors.Wrapf(err, "failed to push batch %d-%d of %d refs to repo: %s", i+1, end, totalRefs, cloneURL)
}
}
return nil
return errors.Wrapf(err, "failed to push to repo: %s", ghRepo.GetCloneURL())
}
-400
View File
@@ -1,400 +0,0 @@
package src
import (
"context"
"fmt"
"testing"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/config"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/storer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock implementations for testing
type mockReferenceIter struct {
refs []*plumbing.Reference
index int
}
func (m *mockReferenceIter) Next() (*plumbing.Reference, error) {
if m.index >= len(m.refs) {
return nil, storer.ErrStop
}
ref := m.refs[m.index]
m.index++
return ref, nil
}
func (m *mockReferenceIter) ForEach(fn func(*plumbing.Reference) error) error {
for _, ref := range m.refs {
if err := fn(ref); err != nil {
if err == storer.ErrStop {
return nil
}
return err
}
}
return nil
}
func (m *mockReferenceIter) Close() {}
type mockGitRepository struct {
refs []*plumbing.Reference
err error
}
func (m *mockGitRepository) DeleteRemote(name string) error {
return nil
}
func (m *mockGitRepository) CreateRemote(c *config.RemoteConfig) (GitRemote, error) {
return nil, nil
}
func (m *mockGitRepository) FetchContext(ctx context.Context, o *git.FetchOptions) error {
return nil
}
func (m *mockGitRepository) References() (storer.ReferenceIter, error) {
if m.err != nil {
return nil, m.err
}
return &mockReferenceIter{refs: m.refs, index: 0}, nil
}
type mockGitRemote struct {
pushCalls [][]config.RefSpec
pushError error
alreadyUpToDate bool
remoteConfig *config.RemoteConfig
}
func (m *mockGitRemote) PushContext(ctx context.Context, o *git.PushOptions) error {
m.pushCalls = append(m.pushCalls, o.RefSpecs)
if m.alreadyUpToDate {
return git.NoErrAlreadyUpToDate
}
return m.pushError
}
func (m *mockGitRemote) Config() *config.RemoteConfig {
if m.remoteConfig != nil {
return m.remoteConfig
}
return &config.RemoteConfig{Name: "test-remote"}
}
// Tests for PushOnlyFlags.Validate batch size validation
func TestPushOnlyFlags_Validate_BatchSize(t *testing.T) {
tests := []struct {
name string
batchSize int
expectErr bool
errMessage string
}{
{
name: "batch size 0 (no batching) is valid",
batchSize: 0,
expectErr: false,
},
{
name: "batch size at minimum (10) is valid",
batchSize: MinBatchSize,
expectErr: false,
},
{
name: "batch size above minimum is valid",
batchSize: 100,
expectErr: false,
},
{
name: "batch size below minimum is invalid",
batchSize: 5,
expectErr: true,
errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize),
},
{
name: "batch size of 1 is invalid",
batchSize: 1,
expectErr: true,
errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize),
},
{
name: "batch size of 9 is invalid",
batchSize: 9,
expectErr: true,
errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
flags := PushOnlyFlags{
BaseURL: "https://example.com",
Token: "test-token",
BatchSize: tt.batchSize,
}
validations := flags.Validate()
if tt.expectErr {
require.NotEmpty(t, validations, "expected validation error")
found := false
for _, v := range validations {
if v == tt.errMessage {
found = true
break
}
}
assert.True(t, found, "expected error message not found: %s", tt.errMessage)
} else {
// Check that batch size validation didn't add an error
for _, v := range validations {
assert.NotContains(t, v, "batch-size", "unexpected batch-size validation error")
}
}
})
}
}
// Tests for collectRefs function
func TestCollectRefs(t *testing.T) {
tests := []struct {
name string
refs []*plumbing.Reference
expectedLen int
expectedRefs []plumbing.ReferenceName
expectErr bool
}{
{
name: "empty repository",
refs: []*plumbing.Reference{},
expectedLen: 0,
},
{
name: "branches only",
refs: []*plumbing.Reference{
plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("abc123")),
plumbing.NewHashReference(plumbing.NewBranchReferenceName("feature"), plumbing.NewHash("def456")),
},
expectedLen: 2,
expectedRefs: []plumbing.ReferenceName{
plumbing.NewBranchReferenceName("main"),
plumbing.NewBranchReferenceName("feature"),
},
},
{
name: "tags only",
refs: []*plumbing.Reference{
plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), plumbing.NewHash("abc123")),
plumbing.NewHashReference(plumbing.NewTagReferenceName("v2.0.0"), plumbing.NewHash("def456")),
},
expectedLen: 2,
expectedRefs: []plumbing.ReferenceName{
plumbing.NewTagReferenceName("v1.0.0"),
plumbing.NewTagReferenceName("v2.0.0"),
},
},
{
name: "mixed branches and tags",
refs: []*plumbing.Reference{
plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("abc123")),
plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), plumbing.NewHash("def456")),
plumbing.NewHashReference(plumbing.NewBranchReferenceName("develop"), plumbing.NewHash("ghi789")),
},
expectedLen: 3,
},
{
name: "filters out HEAD and other refs",
refs: []*plumbing.Reference{
plumbing.NewHashReference(plumbing.HEAD, plumbing.NewHash("abc123")),
plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("def456")),
plumbing.NewHashReference(plumbing.NewRemoteReferenceName("origin", "main"), plumbing.NewHash("ghi789")),
plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), plumbing.NewHash("jkl012")),
},
expectedLen: 2, // Only main branch and v1.0.0 tag
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &mockGitRepository{refs: tt.refs}
refs, err := collectRefs(repo)
if tt.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Len(t, refs, tt.expectedLen)
if tt.expectedRefs != nil {
for i, expected := range tt.expectedRefs {
assert.Equal(t, expected, refs[i])
}
}
})
}
}
func TestCollectRefs_Error(t *testing.T) {
repo := &mockGitRepository{err: fmt.Errorf("failed to get references")}
refs, err := collectRefs(repo)
require.Error(t, err)
assert.Nil(t, refs)
assert.Contains(t, err.Error(), "failed to get references")
}
// Tests for pushRefsInBatches function
func TestPushRefsInBatches(t *testing.T) {
tests := []struct {
name string
refs []plumbing.ReferenceName
batchSize int
expectedBatches int
alreadyUpToDate bool
pushError error
expectErr bool
expectedErrSubstr string
}{
{
name: "single batch - fewer refs than batch size",
refs: []plumbing.ReferenceName{
plumbing.NewBranchReferenceName("main"),
plumbing.NewBranchReferenceName("feature"),
},
batchSize: 10,
expectedBatches: 1,
},
{
name: "single batch - exact batch size",
refs: createNRefs(10),
batchSize: 10,
expectedBatches: 1,
},
{
name: "multiple batches - exactly divisible",
refs: createNRefs(30),
batchSize: 10,
expectedBatches: 3,
},
{
name: "multiple batches - not exactly divisible",
refs: createNRefs(25),
batchSize: 10,
expectedBatches: 3, // 10 + 10 + 5
},
{
name: "empty refs",
refs: []plumbing.ReferenceName{},
batchSize: 10,
expectedBatches: 0,
},
{
name: "all batches already up to date",
refs: []plumbing.ReferenceName{
plumbing.NewBranchReferenceName("main"),
},
batchSize: 10,
expectedBatches: 1,
alreadyUpToDate: true,
},
{
name: "push error",
refs: []plumbing.ReferenceName{
plumbing.NewBranchReferenceName("main"),
},
batchSize: 10,
pushError: fmt.Errorf("network error"),
expectErr: true,
expectedErrSubstr: "failed to push batch",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
remote := &mockGitRemote{
alreadyUpToDate: tt.alreadyUpToDate,
pushError: tt.pushError,
}
err := pushRefsInBatches(context.Background(), remote, tt.refs, tt.batchSize, nil, "https://example.com/repo.git")
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedErrSubstr)
return
}
require.NoError(t, err)
assert.Len(t, remote.pushCalls, tt.expectedBatches)
})
}
}
func TestPushRefsInBatches_RefSpecFormat(t *testing.T) {
refs := []plumbing.ReferenceName{
plumbing.NewBranchReferenceName("main"),
plumbing.NewTagReferenceName("v1.0.0"),
}
remote := &mockGitRemote{}
err := pushRefsInBatches(context.Background(), remote, refs, 10, nil, "https://example.com/repo.git")
require.NoError(t, err)
require.Len(t, remote.pushCalls, 1)
require.Len(t, remote.pushCalls[0], 2)
// Check refspec format: should be "+refs/heads/main:refs/heads/main"
assert.Equal(t, config.RefSpec("+refs/heads/main:refs/heads/main"), remote.pushCalls[0][0])
assert.Equal(t, config.RefSpec("+refs/tags/v1.0.0:refs/tags/v1.0.0"), remote.pushCalls[0][1])
}
func TestPushRefsInBatches_BatchSizes(t *testing.T) {
// Create 25 refs
refs := createNRefs(25)
batchSize := 10
remote := &mockGitRemote{}
err := pushRefsInBatches(context.Background(), remote, refs, batchSize, nil, "https://example.com/repo.git")
require.NoError(t, err)
require.Len(t, remote.pushCalls, 3)
// First batch should have 10 refs
assert.Len(t, remote.pushCalls[0], 10)
// Second batch should have 10 refs
assert.Len(t, remote.pushCalls[1], 10)
// Third batch should have 5 refs (remainder)
assert.Len(t, remote.pushCalls[2], 5)
}
// Tests for constants
func TestConstants(t *testing.T) {
assert.Equal(t, 0, DefaultBatchSize, "DefaultBatchSize should be 0 for backward compatibility")
assert.Equal(t, 10, MinBatchSize, "MinBatchSize should be 10")
}
// Helper function to create N test refs
func createNRefs(n int) []plumbing.ReferenceName {
refs := make([]plumbing.ReferenceName, n)
for i := 0; i < n; i++ {
refs[i] = plumbing.NewBranchReferenceName(fmt.Sprintf("branch-%d", i))
}
return refs
}
+8 -13
View File
@@ -26,33 +26,28 @@ linters:
- errcheck
#- exhaustive
#- funlen
- gas
#- gochecknoinits
- goconst
#- gocritic
- gocritic
#- gocyclo
#- gofmt
- gofmt
- goimports
- golint
#- gomnd
#- goprintffuncname
#- gosec
#- gosimple
- gosec
- gosimple
- govet
- ineffassign
- interfacer
#- lll
- maligned
- megacheck
#- misspell
- misspell
#- nakedret
#- noctx
#- nolintlint
- nolintlint
#- rowserrcheck
#- scopelint
#- staticcheck
- staticcheck
#- structcheck ! deprecated since v1.49.0; replaced by 'unused'
#- stylecheck
- stylecheck
#- typecheck
- unconvert
#- unparam
+3 -10
View File
@@ -17,21 +17,17 @@ package cobra
import (
"fmt"
"os"
"regexp"
"strings"
)
const (
activeHelpMarker = "_activeHelp_ "
// The below values should not be changed: programs will be using them explicitly
// in their user documentation, and users will be using them explicitly.
activeHelpEnvVarSuffix = "_ACTIVE_HELP"
activeHelpGlobalEnvVar = "COBRA_ACTIVE_HELP"
activeHelpEnvVarSuffix = "ACTIVE_HELP"
activeHelpGlobalEnvVar = configEnvVarGlobalPrefix + "_" + activeHelpEnvVarSuffix
activeHelpGlobalDisable = "0"
)
var activeHelpEnvVarPrefixSubstRegexp = regexp.MustCompile(`[^A-Z0-9_]`)
// AppendActiveHelp adds the specified string to the specified array to be used as ActiveHelp.
// Such strings will be processed by the completion script and will be shown as ActiveHelp
// to the user.
@@ -60,8 +56,5 @@ func GetActiveHelpConfig(cmd *Command) string {
// variable. It has the format <PROGRAM>_ACTIVE_HELP where <PROGRAM> is the name of the
// root command in upper case, with all non-ASCII-alphanumeric characters replaced by `_`.
func activeHelpEnvVar(name string) string {
// This format should not be changed: users will be using it explicitly.
activeHelpEnvVar := strings.ToUpper(fmt.Sprintf("%s%s", name, activeHelpEnvVarSuffix))
activeHelpEnvVar = activeHelpEnvVarPrefixSubstRegexp.ReplaceAllString(activeHelpEnvVar, "_")
return activeHelpEnvVar
return configEnvVar(name, activeHelpEnvVarSuffix)
}
+2 -2
View File
@@ -52,9 +52,9 @@ func OnlyValidArgs(cmd *Command, args []string) error {
if len(cmd.ValidArgs) > 0 {
// Remove any description that may be included in ValidArgs.
// A description is following a tab character.
var validArgs []string
validArgs := make([]string, 0, len(cmd.ValidArgs))
for _, v := range cmd.ValidArgs {
validArgs = append(validArgs, strings.Split(v, "\t")[0])
validArgs = append(validArgs, strings.SplitN(v, "\t", 2)[0])
}
for _, v := range args {
if !stringInSlice(v, validArgs) {
+10 -13
View File
@@ -597,19 +597,16 @@ func writeRequiredFlag(buf io.StringWriter, cmd *Command) {
if nonCompletableFlag(flag) {
return
}
for key := range flag.Annotations {
switch key {
case BashCompOneRequiredFlag:
format := " must_have_one_flag+=(\"--%s"
if flag.Value.Type() != "bool" {
format += "="
}
format += cbn
WriteStringAndCheck(buf, fmt.Sprintf(format, flag.Name))
if _, ok := flag.Annotations[BashCompOneRequiredFlag]; ok {
format := " must_have_one_flag+=(\"--%s"
if flag.Value.Type() != "bool" {
format += "="
}
format += cbn
WriteStringAndCheck(buf, fmt.Sprintf(format, flag.Name))
if len(flag.Shorthand) > 0 {
WriteStringAndCheck(buf, fmt.Sprintf(" must_have_one_flag+=(\"-%s"+cbn, flag.Shorthand))
}
if len(flag.Shorthand) > 0 {
WriteStringAndCheck(buf, fmt.Sprintf(" must_have_one_flag+=(\"-%s"+cbn, flag.Shorthand))
}
}
})
@@ -621,7 +618,7 @@ func writeRequiredNouns(buf io.StringWriter, cmd *Command) {
for _, value := range cmd.ValidArgs {
// Remove any description that may be included following a tab character.
// Descriptions are not supported by bash completion.
value = strings.Split(value, "\t")[0]
value = strings.SplitN(value, "\t", 2)[0]
WriteStringAndCheck(buf, fmt.Sprintf(" must_have_one_noun+=(%q)\n", value))
}
if cmd.ValidArgsFunction != nil {
-2
View File
@@ -193,8 +193,6 @@ func ld(s, t string, ignoreCase bool) int {
d := make([][]int, len(s)+1)
for i := range d {
d[i] = make([]int, len(t)+1)
}
for i := range d {
d[i][0] = i
}
for j := range d[0] {
+31 -20
View File
@@ -154,8 +154,10 @@ type Command struct {
// pflags contains persistent flags.
pflags *flag.FlagSet
// lflags contains local flags.
// This field does not represent internal state, it's used as a cache to optimise LocalFlags function call
lflags *flag.FlagSet
// iflags contains inherited flags.
// This field does not represent internal state, it's used as a cache to optimise InheritedFlags function call
iflags *flag.FlagSet
// parentsPflags is all persistent flags of cmd's parents.
parentsPflags *flag.FlagSet
@@ -706,7 +708,7 @@ Loop:
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
// return the args, excluding the one at this position.
if s == x {
ret := []string{}
ret := make([]string, 0, len(args)-1)
ret = append(ret, args[:pos]...)
ret = append(ret, args[pos+1:]...)
return ret
@@ -754,14 +756,14 @@ func (c *Command) findSuggestions(arg string) string {
if c.SuggestionsMinimumDistance <= 0 {
c.SuggestionsMinimumDistance = 2
}
suggestionsString := ""
var sb strings.Builder
if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 {
suggestionsString += "\n\nDid you mean this?\n"
sb.WriteString("\n\nDid you mean this?\n")
for _, s := range suggestions {
suggestionsString += fmt.Sprintf("\t%v\n", s)
_, _ = fmt.Fprintf(&sb, "\t%v\n", s)
}
}
return suggestionsString
return sb.String()
}
func (c *Command) findNext(next string) *Command {
@@ -873,7 +875,7 @@ func (c *Command) ArgsLenAtDash() int {
func (c *Command) execute(a []string) (err error) {
if c == nil {
return fmt.Errorf("Called Execute() on a nil Command")
return fmt.Errorf("called Execute() on a nil Command")
}
if len(c.Deprecated) > 0 {
@@ -1187,10 +1189,11 @@ func (c *Command) InitDefaultHelpFlag() {
c.mergePersistentFlags()
if c.Flags().Lookup("help") == nil {
usage := "help for "
if c.Name() == "" {
name := c.displayName()
if name == "" {
usage += "this command"
} else {
usage += c.Name()
usage += name
}
c.Flags().BoolP("help", "h", false, usage)
_ = c.Flags().SetAnnotation("help", FlagSetByCobraAnnotation, []string{"true"})
@@ -1236,7 +1239,7 @@ func (c *Command) InitDefaultHelpCmd() {
Use: "help [command]",
Short: "Help about any command",
Long: `Help provides help for any command in the application.
Simply type ` + c.Name() + ` help [path to command] for full details.`,
Simply type ` + c.displayName() + ` help [path to command] for full details.`,
ValidArgsFunction: func(c *Command, args []string, toComplete string) ([]string, ShellCompDirective) {
var completions []string
cmd, _, e := c.Root().Find(args)
@@ -1427,6 +1430,10 @@ func (c *Command) CommandPath() string {
if c.HasParent() {
return c.Parent().CommandPath() + " " + c.Name()
}
return c.displayName()
}
func (c *Command) displayName() string {
if displayName, ok := c.Annotations[CommandDisplayNameAnnotation]; ok {
return displayName
}
@@ -1436,10 +1443,11 @@ func (c *Command) CommandPath() string {
// UseLine puts out the full usage for a given command (including parents).
func (c *Command) UseLine() string {
var useline string
use := strings.Replace(c.Use, c.Name(), c.displayName(), 1)
if c.HasParent() {
useline = c.parent.CommandPath() + " " + c.Use
useline = c.parent.CommandPath() + " " + use
} else {
useline = c.Use
useline = use
}
if c.DisableFlagsInUseLine {
return useline
@@ -1452,7 +1460,6 @@ func (c *Command) UseLine() string {
// DebugFlags used to determine which flags have been assigned to which commands
// and which persist.
// nolint:goconst
func (c *Command) DebugFlags() {
c.Println("DebugFlags called on", c.Name())
var debugflags func(*Command)
@@ -1642,7 +1649,7 @@ func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) f
// to this command (local and persistent declared here and by all parents).
func (c *Command) Flags() *flag.FlagSet {
if c.flags == nil {
c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.flags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
@@ -1653,10 +1660,11 @@ func (c *Command) Flags() *flag.FlagSet {
}
// LocalNonPersistentFlags are flags specific to this command which will NOT persist to subcommands.
// This function does not modify the flags of the current command, it's purpose is to return the current state.
func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
persistentFlags := c.PersistentFlags()
out := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
out := flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
c.LocalFlags().VisitAll(func(f *flag.Flag) {
if persistentFlags.Lookup(f.Name) == nil {
out.AddFlag(f)
@@ -1666,11 +1674,12 @@ func (c *Command) LocalNonPersistentFlags() *flag.FlagSet {
}
// LocalFlags returns the local FlagSet specifically set in the current command.
// This function does not modify the flags of the current command, it's purpose is to return the current state.
func (c *Command) LocalFlags() *flag.FlagSet {
c.mergePersistentFlags()
if c.lflags == nil {
c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.lflags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
@@ -1693,11 +1702,12 @@ func (c *Command) LocalFlags() *flag.FlagSet {
}
// InheritedFlags returns all flags which were inherited from parent commands.
// This function does not modify the flags of the current command, it's purpose is to return the current state.
func (c *Command) InheritedFlags() *flag.FlagSet {
c.mergePersistentFlags()
if c.iflags == nil {
c.iflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.iflags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
@@ -1718,6 +1728,7 @@ func (c *Command) InheritedFlags() *flag.FlagSet {
}
// NonInheritedFlags returns all flags which were not inherited from parent commands.
// This function does not modify the flags of the current command, it's purpose is to return the current state.
func (c *Command) NonInheritedFlags() *flag.FlagSet {
return c.LocalFlags()
}
@@ -1725,7 +1736,7 @@ func (c *Command) NonInheritedFlags() *flag.FlagSet {
// PersistentFlags returns the persistent FlagSet specifically set in the current command.
func (c *Command) PersistentFlags() *flag.FlagSet {
if c.pflags == nil {
c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.pflags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
@@ -1738,9 +1749,9 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
func (c *Command) ResetFlags() {
c.flagErrorBuf = new(bytes.Buffer)
c.flagErrorBuf.Reset()
c.flags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.flags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
c.flags.SetOutput(c.flagErrorBuf)
c.pflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.pflags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
c.pflags.SetOutput(c.flagErrorBuf)
c.lflags = nil
@@ -1857,7 +1868,7 @@ func (c *Command) mergePersistentFlags() {
// If c.parentsPflags == nil, it makes new.
func (c *Command) updateParentsPflags() {
if c.parentsPflags == nil {
c.parentsPflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
c.parentsPflags = flag.NewFlagSet(c.displayName(), flag.ContinueOnError)
c.parentsPflags.SetOutput(c.flagErrorBuf)
c.parentsPflags.SortFlags = false
}
+50 -12
View File
@@ -17,6 +17,8 @@ package cobra
import (
"fmt"
"os"
"regexp"
"strconv"
"strings"
"sync"
@@ -211,24 +213,29 @@ func (c *Command) initCompleteCmd(args []string) {
// 2- Even without completions, we need to print the directive
}
noDescriptions := (cmd.CalledAs() == ShellCompNoDescRequestCmd)
noDescriptions := cmd.CalledAs() == ShellCompNoDescRequestCmd
if !noDescriptions {
if doDescriptions, err := strconv.ParseBool(getEnvConfig(cmd, configEnvVarSuffixDescriptions)); err == nil {
noDescriptions = !doDescriptions
}
}
noActiveHelp := GetActiveHelpConfig(finalCmd) == activeHelpGlobalDisable
out := finalCmd.OutOrStdout()
for _, comp := range completions {
if GetActiveHelpConfig(finalCmd) == activeHelpGlobalDisable {
// Remove all activeHelp entries in this case
if strings.HasPrefix(comp, activeHelpMarker) {
continue
}
if noActiveHelp && strings.HasPrefix(comp, activeHelpMarker) {
// Remove all activeHelp entries if it's disabled.
continue
}
if noDescriptions {
// Remove any description that may be included following a tab character.
comp = strings.Split(comp, "\t")[0]
comp = strings.SplitN(comp, "\t", 2)[0]
}
// Make sure we only write the first line to the output.
// This is needed if a description contains a linebreak.
// Otherwise the shell scripts will interpret the other lines as new flags
// and could therefore provide a wrong completion.
comp = strings.Split(comp, "\n")[0]
comp = strings.SplitN(comp, "\n", 2)[0]
// Finally trim the completion. This is especially important to get rid
// of a trailing tab when there are no description following it.
@@ -237,14 +244,14 @@ func (c *Command) initCompleteCmd(args []string) {
// although there is no description).
comp = strings.TrimSpace(comp)
// Print each possible completion to stdout for the completion script to consume.
fmt.Fprintln(finalCmd.OutOrStdout(), comp)
// Print each possible completion to the output for the completion script to consume.
fmt.Fprintln(out, comp)
}
// As the last printout, print the completion directive for the completion script to parse.
// The directive integer must be that last character following a single colon (:).
// The completion script expects :<directive>
fmt.Fprintf(finalCmd.OutOrStdout(), ":%d\n", directive)
fmt.Fprintf(out, ":%d\n", directive)
// Print some helpful info to stderr for the user to understand.
// Output from stderr must be ignored by the completion script.
@@ -291,7 +298,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi
}
if err != nil {
// Unable to find the real command. E.g., <program> someInvalidCmd <TAB>
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs)
return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("unable to find a command for arguments: %v", trimmedArgs)
}
finalCmd.ctx = c.ctx
@@ -899,3 +906,34 @@ func CompError(msg string) {
func CompErrorln(msg string) {
CompError(fmt.Sprintf("%s\n", msg))
}
// These values should not be changed: users will be using them explicitly.
const (
configEnvVarGlobalPrefix = "COBRA"
configEnvVarSuffixDescriptions = "COMPLETION_DESCRIPTIONS"
)
var configEnvVarPrefixSubstRegexp = regexp.MustCompile(`[^A-Z0-9_]`)
// configEnvVar returns the name of the program-specific configuration environment
// variable. It has the format <PROGRAM>_<SUFFIX> where <PROGRAM> is the name of the
// root command in upper case, with all non-ASCII-alphanumeric characters replaced by `_`.
func configEnvVar(name, suffix string) string {
// This format should not be changed: users will be using it explicitly.
v := strings.ToUpper(fmt.Sprintf("%s_%s", name, suffix))
v = configEnvVarPrefixSubstRegexp.ReplaceAllString(v, "_")
return v
}
// getEnvConfig returns the value of the configuration environment variable
// <PROGRAM>_<SUFFIX> where <PROGRAM> is the name of the root command in upper
// case, with all non-ASCII-alphanumeric characters replaced by `_`.
// If the value is empty or not set, the value of the environment variable
// COBRA_<SUFFIX> is returned instead.
func getEnvConfig(cmd *Command, suffix string) string {
v := os.Getenv(configEnvVar(cmd.Root().Name(), suffix))
if v == "" {
v = os.Getenv(configEnvVar(configEnvVarGlobalPrefix, suffix))
}
return v
}
+17 -17
View File
@@ -23,9 +23,9 @@ import (
)
const (
requiredAsGroup = "cobra_annotation_required_if_others_set"
oneRequired = "cobra_annotation_one_required"
mutuallyExclusive = "cobra_annotation_mutually_exclusive"
requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"
oneRequiredAnnotation = "cobra_annotation_one_required"
mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
)
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
@@ -37,7 +37,7 @@ func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
}
if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
@@ -53,7 +53,7 @@ func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
if f == nil {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
}
if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
// Only errs if the flag isn't found.
panic(err)
}
@@ -70,7 +70,7 @@ func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
}
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
panic(err)
}
}
@@ -91,9 +91,9 @@ func (c *Command) ValidateFlagGroups() error {
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
flags.VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
if err := validateRequiredFlagGroups(groupStatus); err != nil {
@@ -130,7 +130,7 @@ func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annota
continue
}
groupStatus[group] = map[string]bool{}
groupStatus[group] = make(map[string]bool, len(flagnames))
for _, name := range flagnames {
groupStatus[group][name] = false
}
@@ -232,9 +232,9 @@ func (c *Command) enforceFlagGroupsForCompletion() {
oneRequiredGroupStatus := map[string]map[string]bool{}
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
c.Flags().VisitAll(func(pflag *flag.Flag) {
processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
})
// If a flag that is part of a group is present, we make all the other flags
@@ -253,17 +253,17 @@ func (c *Command) enforceFlagGroupsForCompletion() {
// If none of the flags of a one-required group are present, we make all the flags
// of that group required so that the shell completion suggests them automatically
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
set := 0
isSet := false
for _, isSet := range flagnameAndStatus {
for _, isSet = range flagnameAndStatus {
if isSet {
set++
break
}
}
// None of the flags of the group are set, mark all flags in the group
// as required
if set == 0 {
if !isSet {
for _, fName := range strings.Split(flagList, " ") {
_ = c.MarkFlagRequired(fName)
}
+2 -2
View File
@@ -28,8 +28,8 @@ import (
func genPowerShellComp(buf io.StringWriter, name string, includeDesc bool) {
// Variables should not contain a '-' or ':' character
nameForVar := name
nameForVar = strings.Replace(nameForVar, "-", "_", -1)
nameForVar = strings.Replace(nameForVar, ":", "_", -1)
nameForVar = strings.ReplaceAll(nameForVar, "-", "_")
nameForVar = strings.ReplaceAll(nameForVar, ":", "_")
compCmd := ShellCompRequestCmd
if !includeDesc {
Generated Vendored
+1 -1
View File
@@ -156,7 +156,7 @@ github.com/sergi/go-diff/diffmatchpatch
# github.com/skeema/knownhosts v1.2.2
## explicit; go 1.17
github.com/skeema/knownhosts
# github.com/spf13/cobra v1.8.0
# github.com/spf13/cobra v1.8.1
## explicit; go 1.15
github.com/spf13/cobra
# github.com/spf13/pflag v1.0.5