297 lines
8.8 KiB
Go
297 lines
8.8 KiB
Go
package scaleset
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/actions/scaleset/internal/testserver"
|
|
"github.com/hashicorp/go-retryablehttp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/http/httpproxy"
|
|
)
|
|
|
|
func defaultHTTPClientOption() httpClientOption {
|
|
var opt httpClientOption
|
|
opt.defaults()
|
|
return opt
|
|
}
|
|
|
|
func TestClient_Do(t *testing.T) {
|
|
t.Run("trims byte order mark from response if present", func(t *testing.T) {
|
|
t.Run("when there is no body", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newCommonClient(
|
|
testSystemInfo,
|
|
defaultHTTPClientOption(),
|
|
)
|
|
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.do(req)
|
|
require.NoError(t, err)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
assert.Empty(t, string(body))
|
|
})
|
|
|
|
responses := []string{
|
|
"\xef\xbb\xbf{\"foo\":\"bar\"}",
|
|
"{\"foo\":\"bar\"}",
|
|
}
|
|
|
|
for _, response := range responses {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newCommonClient(
|
|
testSystemInfo,
|
|
defaultHTTPClientOption(),
|
|
)
|
|
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.do(req)
|
|
require.NoError(t, err)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "{\"foo\":\"bar\"}", string(body))
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestClientProxy(t *testing.T) {
|
|
serverCalled := false
|
|
|
|
proxy := testserver.New(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
serverCalled = true
|
|
}))
|
|
|
|
proxyConfig := &httpproxy.Config{
|
|
HTTPProxy: proxy.URL,
|
|
}
|
|
proxyFunc := func(req *http.Request) (*url.URL, error) {
|
|
return proxyConfig.ProxyFunc()(req.URL)
|
|
}
|
|
|
|
opts := defaultHTTPClientOption()
|
|
WithProxy(proxyFunc)(&opts)
|
|
|
|
client := newCommonClient(
|
|
testSystemInfo,
|
|
opts,
|
|
)
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = client.do(req)
|
|
require.NoError(t, err)
|
|
|
|
assert.True(t, serverCalled)
|
|
}
|
|
|
|
func TestUserAgent(t *testing.T) {
|
|
version, sha := detectModuleVersionAndCommit()
|
|
userAgentInfo := SystemInfo{
|
|
System: "actions-runner-controller",
|
|
Version: "0.1.0",
|
|
CommitSHA: "1234567890abcdef",
|
|
ScaleSetID: 10,
|
|
Subsystem: "test",
|
|
}
|
|
|
|
client := newCommonClient(
|
|
testSystemInfo,
|
|
defaultHTTPClientOption(),
|
|
)
|
|
|
|
got := client.userAgent
|
|
wantInfo := userAgent{
|
|
SystemInfo: testSystemInfo,
|
|
BuildCommitSHA: sha,
|
|
BuildVersion: version,
|
|
Kind: "scaleset",
|
|
}
|
|
b, err := json.Marshal(wantInfo)
|
|
require.NoError(t, err, "failed to marshal expected user agent")
|
|
want := string(b)
|
|
|
|
assert.Equal(t, want, got)
|
|
|
|
client.setSystemInfo(SystemInfo{
|
|
System: "actions-runner-controller",
|
|
Version: "0.1.0",
|
|
CommitSHA: "1234567890abcdef",
|
|
ScaleSetID: 10,
|
|
Subsystem: "test",
|
|
})
|
|
|
|
got = client.userAgent
|
|
wantInfo = userAgent{
|
|
SystemInfo: userAgentInfo,
|
|
BuildCommitSHA: sha,
|
|
BuildVersion: version,
|
|
Kind: "scaleset",
|
|
}
|
|
b, err = json.Marshal(wantInfo)
|
|
require.NoError(t, err, "failed to marshal expected user agent after SetSystemInfo")
|
|
want = string(b)
|
|
|
|
assert.Equal(t, want, got)
|
|
}
|
|
|
|
func TestWithLogger(t *testing.T) {
|
|
newJSONHandler := func() slog.Handler {
|
|
return slog.NewJSONHandler(
|
|
io.Discard,
|
|
&slog.HandlerOptions{
|
|
AddSource: true,
|
|
Level: slog.LevelError,
|
|
},
|
|
)
|
|
}
|
|
t.Run("WithLogger(nil) sets a discard logger on raw httpClientOption", func(t *testing.T) {
|
|
opts := httpClientOption{}
|
|
WithLogger(nil)(&opts)
|
|
require.NotNil(t, opts.logger, "WithLogger(nil) should set a discard logger, not leave it nil")
|
|
assert.Equal(t, slog.DiscardHandler, opts.logger.Handler(), "WithLogger(nil) should set a discard logger handler")
|
|
})
|
|
|
|
t.Run("WithLogger(customLogger) assigns the provided logger", func(t *testing.T) {
|
|
handler := newJSONHandler()
|
|
customLogger := slog.New(handler)
|
|
opts := httpClientOption{}
|
|
WithLogger(customLogger)(&opts)
|
|
require.Equal(t, customLogger, opts.logger, "WithLogger should assign the provided logger")
|
|
assert.Equal(t, handler, opts.logger.Handler(), "WithLogger should set the provided logger handler")
|
|
})
|
|
|
|
t.Run("WithLogger(nil) propagates discard logger to retryable HTTP client", func(t *testing.T) {
|
|
opts := httpClientOption{}
|
|
WithLogger(nil)(&opts)
|
|
client, err := opts.newRetryableHTTPClient()
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, client.Logger, "retryable client should have logger set from WithLogger(nil)")
|
|
|
|
logger, ok := client.Logger.(*slog.Logger)
|
|
require.True(t, ok, "retryable client logger should be a *slog.Logger")
|
|
assert.Same(t, opts.logger, logger, "retryable client logger should be the same logger set by WithLogger(nil)")
|
|
assert.Equal(t, slog.DiscardHandler, logger.Handler(), "retryable client logger should be a discard logger from WithLogger(nil)")
|
|
})
|
|
|
|
t.Run("WithLogger(customLogger) propagates custom logger to retryable HTTP client", func(t *testing.T) {
|
|
handler := newJSONHandler()
|
|
customLogger := slog.New(handler)
|
|
opts := httpClientOption{}
|
|
WithLogger(customLogger)(&opts)
|
|
client, err := opts.newRetryableHTTPClient()
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, client.Logger, "retryable client should have logger set")
|
|
logger, ok := client.Logger.(*slog.Logger)
|
|
require.True(t, ok, "retryable client logger should be a *slog.Logger")
|
|
assert.Equal(t, handler, logger.Handler(), "retryable client logger should be the custom logger from WithLogger")
|
|
})
|
|
}
|
|
|
|
// TestWithRetryableHTTPClient verifies that a custom retryable HTTP client
|
|
// provided via WithRetryableHTTPClient is actually used instead of the built-in one
|
|
func TestWithRetryableHTTPClient(t *testing.T) {
|
|
t.Run("uses custom retryable client instead of built-in", func(t *testing.T) {
|
|
attemptCount := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attemptCount++
|
|
if attemptCount == 1 {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"result": "success"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create a custom retryable HTTP client with specific retry configuration
|
|
customRetryClient := retryablehttp.NewClient()
|
|
customRetryClient.RetryMax = 3
|
|
customRetryClient.RetryWaitMax = 10 * time.Millisecond
|
|
|
|
// Create options with the custom retryable client
|
|
opts := defaultHTTPClientOption()
|
|
WithRetryableHTTPClint(customRetryClient)(&opts)
|
|
|
|
// Verify that the custom client is set in options
|
|
assert.NotNil(t, opts.retryableHTTPClient)
|
|
assert.Equal(t, customRetryClient, opts.retryableHTTPClient)
|
|
|
|
// Create the common client with custom retryable client
|
|
client := newCommonClient(testSystemInfo, opts)
|
|
|
|
// Make a request that will trigger a retry
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.do(req)
|
|
require.NoError(t, err)
|
|
|
|
// Should succeed after retry
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, 2, attemptCount)
|
|
|
|
// Verify that the client used is the custom one by checking newRetryableHTTPClient
|
|
retrievedRetryClient, err := client.newRetryableHTTPClient()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, customRetryClient, retrievedRetryClient, "should return the custom retryable client")
|
|
})
|
|
|
|
t.Run("respects custom client's retry configuration over built-in defaults", func(t *testing.T) {
|
|
attemptCount := 0
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
attemptCount++
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Create custom client with limited retries
|
|
customRetryClient := retryablehttp.NewClient()
|
|
customRetryClient.RetryMax = 1 // Only 1 retry (2 total attempts)
|
|
customRetryClient.RetryWaitMax = 5 * time.Millisecond
|
|
|
|
opts := defaultHTTPClientOption()
|
|
WithRetryableHTTPClint(customRetryClient)(&opts)
|
|
|
|
client := newCommonClient(testSystemInfo, opts)
|
|
|
|
req, err := http.NewRequest("GET", server.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := client.do(req)
|
|
// When all retries are exhausted with a retryable error, the client gives up
|
|
// and an error is returned
|
|
if err != nil {
|
|
// Expected: request failed after exhausting retries
|
|
assert.Contains(t, err.Error(), "giving up after 2 attempt(s)")
|
|
} else {
|
|
// Or the final response is returned
|
|
assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
|
|
}
|
|
// Should have tried 1 initial + 1 retry = 2 times total
|
|
assert.Equal(t, 2, attemptCount)
|
|
})
|
|
}
|