Restore job acquisition flow (#90)
This commit is contained in:
@@ -510,6 +510,14 @@ func parseRunnerScaleSetMessageResponse(respBody io.Reader) (*RunnerScaleSetMess
|
||||
}
|
||||
|
||||
switch messageType.MessageType {
|
||||
case MessageTypeJobAvailable:
|
||||
var jobAvailable JobAvailable
|
||||
if err := json.Unmarshal(msg, &jobAvailable); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode job available: %w", err)
|
||||
}
|
||||
|
||||
message.JobAvailableMessages = append(message.JobAvailableMessages, &jobAvailable)
|
||||
|
||||
case MessageTypeJobAssigned:
|
||||
var jobAssigned JobAssigned
|
||||
if err := json.Unmarshal(msg, &jobAssigned); err != nil {
|
||||
|
||||
@@ -49,6 +49,7 @@ func (c *Config) Validate() error {
|
||||
type Client interface {
|
||||
GetMessage(ctx context.Context, lastMessageID, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error)
|
||||
DeleteMessage(ctx context.Context, messageID int) error
|
||||
AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error)
|
||||
Session() scaleset.RunnerScaleSetSession
|
||||
}
|
||||
|
||||
@@ -210,6 +211,12 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale
|
||||
return fmt.Errorf("failed to delete message: %w", err)
|
||||
}
|
||||
|
||||
if len(msg.JobAvailableMessages) > 0 {
|
||||
if err := l.acquireAvailableJobs(ctx, msg.JobAvailableMessages); err != nil {
|
||||
return fmt.Errorf("failed to acquire available jobs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, jobStarted := range msg.JobStartedMessages {
|
||||
l.metricsRecorder.RecordJobStarted(jobStarted)
|
||||
if err := handler.HandleJobStarted(ctx, jobStarted); err != nil {
|
||||
@@ -232,6 +239,23 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) acquireAvailableJobs(ctx context.Context, jobsAvailable []*scaleset.JobAvailable) error {
|
||||
ids := make([]int64, 0, len(jobsAvailable))
|
||||
for _, job := range jobsAvailable {
|
||||
ids = append(ids, job.RunnerRequestID)
|
||||
}
|
||||
|
||||
l.logger.Info("Acquiring jobs", slog.Int("count", len(ids)))
|
||||
|
||||
acquired, err := l.client.AcquireJobs(ctx, ids)
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquiring jobs: %w", err)
|
||||
}
|
||||
|
||||
l.logger.Info("Jobs acquired", slog.Int("count", len(acquired)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) handleStatistics(ctx context.Context, msg *scaleset.RunnerScaleSetStatistic) {
|
||||
l.latestStatistics = msg
|
||||
l.metricsRecorder.RecordStatistics(msg)
|
||||
|
||||
@@ -38,6 +38,74 @@ func (_m *MockClient) EXPECT() *MockClient_Expecter {
|
||||
return &MockClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AcquireJobs provides a mock function for the type MockClient
|
||||
func (_mock *MockClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
|
||||
ret := _mock.Called(ctx, requestIDs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AcquireJobs")
|
||||
}
|
||||
|
||||
var r0 []int64
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) ([]int64, error)); ok {
|
||||
return returnFunc(ctx, requestIDs)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) []int64); ok {
|
||||
r0 = returnFunc(ctx, requestIDs)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]int64)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, []int64) error); ok {
|
||||
r1 = returnFunc(ctx, requestIDs)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockClient_AcquireJobs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AcquireJobs'
|
||||
type MockClient_AcquireJobs_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AcquireJobs is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - requestIDs []int64
|
||||
func (_e *MockClient_Expecter) AcquireJobs(ctx interface{}, requestIDs interface{}) *MockClient_AcquireJobs_Call {
|
||||
return &MockClient_AcquireJobs_Call{Call: _e.mock.On("AcquireJobs", ctx, requestIDs)}
|
||||
}
|
||||
|
||||
func (_c *MockClient_AcquireJobs_Call) Run(run func(ctx context.Context, requestIDs []int64)) *MockClient_AcquireJobs_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 []int64
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].([]int64)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_AcquireJobs_Call) Return(int64s []int64, err error) *MockClient_AcquireJobs_Call {
|
||||
_c.Call.Return(int64s, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockClient_AcquireJobs_Call) RunAndReturn(run func(ctx context.Context, requestIDs []int64) ([]int64, error)) *MockClient_AcquireJobs_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// DeleteMessage provides a mock function for the type MockClient
|
||||
func (_mock *MockClient) DeleteMessage(ctx context.Context, messageID int) error {
|
||||
ret := _mock.Called(ctx, messageID)
|
||||
|
||||
@@ -234,6 +234,67 @@ func (c *MessageSessionClient) Session() RunnerScaleSetSession {
|
||||
return *c.session
|
||||
}
|
||||
|
||||
// AcquireJobs acquires the given job request IDs from the runner scale set.
|
||||
// If the current session token is expired, it refreshes the session and tries one more time.
|
||||
func (c *MessageSessionClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
ids, err := c.acquireJobs(ctx, requestIDs)
|
||||
if err == nil {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, MessageQueueTokenExpiredError) {
|
||||
return nil, fmt.Errorf("failed to acquire jobs: %w", err)
|
||||
}
|
||||
|
||||
if err := c.refreshMessageSession(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh message session: %w", err)
|
||||
}
|
||||
|
||||
return c.acquireJobs(ctx, requestIDs)
|
||||
}
|
||||
|
||||
func (c *MessageSessionClient) acquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
|
||||
body, err := json.Marshal(requestIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request ids: %w", err)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/%s/%d/acquirejobs", scaleSetEndpoint, c.scaleSetID)
|
||||
|
||||
c.innerClient.mu.Lock()
|
||||
req, err := c.innerClient.newActionsServiceRequest(ctx, http.MethodPost, path, bytes.NewBuffer(body))
|
||||
c.innerClient.mu.Unlock()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create acquire jobs request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.session.MessageQueueAccessToken))
|
||||
|
||||
resp, err := c.commonClient.do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to issue acquire jobs request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
return nil, newRequestResponseError(req, resp, MessageQueueTokenExpiredError)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, newRequestResponseError(req, resp, fmt.Errorf("unexpected status code %s", resp.Status))
|
||||
}
|
||||
|
||||
var result acquireJobsResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, newRequestResponseError(req, resp, fmt.Errorf("failed to decode acquire jobs response: %w", err))
|
||||
}
|
||||
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
func (c *MessageSessionClient) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error {
|
||||
c.innerClient.mu.Lock()
|
||||
defer c.innerClient.mu.Unlock()
|
||||
|
||||
@@ -649,3 +649,231 @@ func TestDeleteMessage(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "unexpected status code")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAcquireJobs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
auth := actionsAuth{
|
||||
token: "token",
|
||||
}
|
||||
|
||||
t.Run("Acquire jobs successfully", func(t *testing.T) {
|
||||
requestIDs := []int64{1, 2}
|
||||
want := []int64{1, 2}
|
||||
response := []byte(`{"count":2,"value":[1,2]}`)
|
||||
|
||||
var handleSessionRequest http.HandlerFunc
|
||||
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "acquirejobs") {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Contains(t, r.URL.Path, "acquirejobs")
|
||||
assert.True(t, strings.HasPrefix(r.Header.Get("Authorization"), "Bearer"), "expected Bearer authorization header")
|
||||
|
||||
var gotIDs []int64
|
||||
err := json.NewDecoder(r.Body).Decode(&gotIDs)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, requestIDs, gotIDs)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(response)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(r.URL.Path, "sessions") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/sessions/") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
}))
|
||||
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
|
||||
|
||||
client, err := newClient(
|
||||
testSystemInfo,
|
||||
server.configURLForOrg("my-org"),
|
||||
auth,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sessionClient.AcquireJobs(ctx, requestIDs)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
|
||||
t.Run("Message token expired", func(t *testing.T) {
|
||||
var handleSessionRequest http.HandlerFunc
|
||||
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "acquirejobs") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// create session
|
||||
if strings.HasSuffix(r.URL.Path, "sessions") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
// refresh
|
||||
if strings.Contains(r.URL.Path, "/sessions/") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
|
||||
|
||||
client, err := newClient(
|
||||
testSystemInfo,
|
||||
server.configURLForOrg("my-org"),
|
||||
auth,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sessionClient.AcquireJobs(ctx, []int64{1})
|
||||
assert.Nil(t, got)
|
||||
assert.ErrorIs(t, err, MessageQueueTokenExpiredError, "expected error to be MessageQueueTokenExpiredError but got: %v", err)
|
||||
})
|
||||
|
||||
t.Run("Message token refreshed", func(t *testing.T) {
|
||||
want := []int64{1, 2}
|
||||
afterRefreshResponse := []byte(`{"count":2,"value":[1,2]}`)
|
||||
|
||||
var handleSessionRequest http.HandlerFunc
|
||||
type state int
|
||||
const (
|
||||
createSession state = iota
|
||||
firstAcquire
|
||||
refreshToken
|
||||
secondAcquire
|
||||
)
|
||||
currentState := createSession
|
||||
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "acquirejobs") {
|
||||
if currentState == firstAcquire {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
currentState = refreshToken
|
||||
return
|
||||
}
|
||||
require.Equal(t, secondAcquire, currentState)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(afterRefreshResponse)
|
||||
return
|
||||
}
|
||||
// create session
|
||||
if strings.HasSuffix(r.URL.Path, "sessions") {
|
||||
require.Equal(t, createSession, currentState)
|
||||
handleSessionRequest(w, r)
|
||||
currentState = firstAcquire
|
||||
return
|
||||
}
|
||||
// refresh
|
||||
if strings.Contains(r.URL.Path, "/sessions/") {
|
||||
require.Equal(t, refreshToken, currentState)
|
||||
handleSessionRequest(w, r)
|
||||
currentState = secondAcquire
|
||||
return
|
||||
}
|
||||
}))
|
||||
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
|
||||
|
||||
client, err := newClient(
|
||||
testSystemInfo,
|
||||
server.configURLForOrg("my-org"),
|
||||
auth,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sessionClient.AcquireJobs(ctx, []int64{1, 2})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
|
||||
t.Run("Server error", func(t *testing.T) {
|
||||
var handleSessionRequest http.HandlerFunc
|
||||
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "acquirejobs") {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(r.URL.Path, "sessions") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/sessions/") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
}))
|
||||
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
|
||||
|
||||
retryMax := 1
|
||||
client, err := newClient(
|
||||
testSystemInfo,
|
||||
server.configURLForOrg("my-org"),
|
||||
auth,
|
||||
WithRetryMax(retryMax),
|
||||
WithRetryWaitMax(1*time.Nanosecond),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionClient, err := client.MessageSessionClient(
|
||||
ctx,
|
||||
1,
|
||||
"my-org",
|
||||
WithRetryMax(retryMax),
|
||||
WithRetryWaitMax(1*time.Nanosecond),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sessionClient.AcquireJobs(ctx, []int64{1})
|
||||
assert.Nil(t, got)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
t.Run("Empty request IDs", func(t *testing.T) {
|
||||
response := []byte(`{"count":0,"value":[]}`)
|
||||
|
||||
var handleSessionRequest http.HandlerFunc
|
||||
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "acquirejobs") {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(response)
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(r.URL.Path, "sessions") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/sessions/") {
|
||||
handleSessionRequest(w, r)
|
||||
return
|
||||
}
|
||||
}))
|
||||
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
|
||||
|
||||
client, err := newClient(
|
||||
testSystemInfo,
|
||||
server.configURLForOrg("my-org"),
|
||||
auth,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sessionClient.AcquireJobs(ctx, []int64{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,11 +12,17 @@ type MessageType string
|
||||
|
||||
// message types
|
||||
const (
|
||||
MessageTypeJobAvailable MessageType = "JobAvailable"
|
||||
MessageTypeJobAssigned MessageType = "JobAssigned"
|
||||
MessageTypeJobStarted MessageType = "JobStarted"
|
||||
MessageTypeJobCompleted MessageType = "JobCompleted"
|
||||
)
|
||||
|
||||
type JobAvailable struct {
|
||||
AcquireJobURL string `json:"acquireJobUrl"`
|
||||
JobMessageBase
|
||||
}
|
||||
|
||||
type JobAssigned struct {
|
||||
JobMessageBase
|
||||
}
|
||||
@@ -99,6 +105,7 @@ type runnerScaleSetMessageResponse struct {
|
||||
type RunnerScaleSetMessage struct {
|
||||
MessageID int
|
||||
Statistics *RunnerScaleSetStatistic
|
||||
JobAvailableMessages []*JobAvailable
|
||||
JobAssignedMessages []*JobAssigned
|
||||
JobStartedMessages []*JobStarted
|
||||
JobCompletedMessages []*JobCompleted
|
||||
@@ -109,6 +116,11 @@ type runnerScaleSetsResponse struct {
|
||||
RunnerScaleSets []RunnerScaleSet `json:"value"`
|
||||
}
|
||||
|
||||
// acquireJobsResponse is the JSON body decoded from a successful acquirejobs
// request, e.g. {"count":2,"value":[1,2]}.
type acquireJobsResponse struct {
	// Count is read from the "count" JSON field.
	Count int `json:"count"`
	// Value is read from the "value" JSON field; it holds the int64 IDs that
	// are handed back to the caller of acquireJobs.
	Value []int64 `json:"value"`
}
|
||||
|
||||
type RunnerScaleSetSession struct {
|
||||
SessionID uuid.UUID `json:"sessionId,omitempty"`
|
||||
OwnerName string `json:"ownerName,omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user