From 22bae1217e2048c4187352e0be5481edd758f863 Mon Sep 17 00:00:00 2001 From: Francesco Renzi Date: Tue, 14 Apr 2026 09:27:42 +0100 Subject: [PATCH] Restore job acquisition flow (#90) --- client.go | 8 ++ listener/listener.go | 24 +++++ listener/mocks_test.go | 68 ++++++++++++ session_client.go | 61 +++++++++++ session_client_test.go | 228 +++++++++++++++++++++++++++++++++++++++++ types.go | 12 +++ 6 files changed, 401 insertions(+) diff --git a/client.go b/client.go index a88e12b..7775443 100644 --- a/client.go +++ b/client.go @@ -510,6 +510,14 @@ func parseRunnerScaleSetMessageResponse(respBody io.Reader) (*RunnerScaleSetMess } switch messageType.MessageType { + case MessageTypeJobAvailable: + var jobAvailable JobAvailable + if err := json.Unmarshal(msg, &jobAvailable); err != nil { + return nil, fmt.Errorf("failed to decode job available: %w", err) + } + + message.JobAvailableMessages = append(message.JobAvailableMessages, &jobAvailable) + case MessageTypeJobAssigned: var jobAssigned JobAssigned if err := json.Unmarshal(msg, &jobAssigned); err != nil { diff --git a/listener/listener.go b/listener/listener.go index 0db0776..77802ab 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -49,6 +49,7 @@ func (c *Config) Validate() error { type Client interface { GetMessage(ctx context.Context, lastMessageID, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error) DeleteMessage(ctx context.Context, messageID int) error + AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) Session() scaleset.RunnerScaleSetSession } @@ -210,6 +211,12 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale return fmt.Errorf("failed to delete message: %w", err) } + if len(msg.JobAvailableMessages) > 0 { + if err := l.acquireAvailableJobs(ctx, msg.JobAvailableMessages); err != nil { + return fmt.Errorf("failed to acquire available jobs: %w", err) + } + } + for _, jobStarted := range msg.JobStartedMessages { l.metricsRecorder.RecordJobStarted(jobStarted) if err := handler.HandleJobStarted(ctx, jobStarted); err != nil { @@ -232,6 +239,23 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale return nil } +func (l *Listener) acquireAvailableJobs(ctx context.Context, jobsAvailable []*scaleset.JobAvailable) error { + ids := make([]int64, 0, len(jobsAvailable)) + for _, job := range jobsAvailable { + ids = append(ids, job.RunnerRequestID) + } + + l.logger.Info("Acquiring jobs", slog.Int("count", len(ids))) + + acquired, err := l.client.AcquireJobs(ctx, ids) + if err != nil { + return fmt.Errorf("acquiring jobs: %w", err) + } + + l.logger.Info("Jobs acquired", slog.Int("count", len(acquired))) + return nil +} + func (l *Listener) handleStatistics(ctx context.Context, msg *scaleset.RunnerScaleSetStatistic) { l.latestStatistics = msg l.metricsRecorder.RecordStatistics(msg) diff --git a/listener/mocks_test.go b/listener/mocks_test.go index f045a06..1585100 100644 --- a/listener/mocks_test.go +++ b/listener/mocks_test.go @@ -38,6 +38,74 @@ func (_m *MockClient) EXPECT() *MockClient_Expecter { return &MockClient_Expecter{mock: &_m.Mock} } +// AcquireJobs provides a mock function for the type MockClient +func (_mock *MockClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) { + ret := _mock.Called(ctx, requestIDs) + + if len(ret) == 0 { + panic("no return value specified for AcquireJobs") + } + + var r0 []int64 + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) ([]int64, error)); ok { + return returnFunc(ctx, requestIDs) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) []int64); ok { + r0 = returnFunc(ctx, requestIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = returnFunc(ctx, requestIDs) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockClient_AcquireJobs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AcquireJobs' +type MockClient_AcquireJobs_Call struct { + *mock.Call +} + +// AcquireJobs is a helper method to define mock.On call +// - ctx context.Context +// - requestIDs []int64 +func (_e *MockClient_Expecter) AcquireJobs(ctx interface{}, requestIDs interface{}) *MockClient_AcquireJobs_Call { + return &MockClient_AcquireJobs_Call{Call: _e.mock.On("AcquireJobs", ctx, requestIDs)} +} + +func (_c *MockClient_AcquireJobs_Call) Run(run func(ctx context.Context, requestIDs []int64)) *MockClient_AcquireJobs_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []int64 + if args[1] != nil { + arg1 = args[1].([]int64) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockClient_AcquireJobs_Call) Return(int64s []int64, err error) *MockClient_AcquireJobs_Call { + _c.Call.Return(int64s, err) + return _c +} + +func (_c *MockClient_AcquireJobs_Call) RunAndReturn(run func(ctx context.Context, requestIDs []int64) ([]int64, error)) *MockClient_AcquireJobs_Call { + _c.Call.Return(run) + return _c +} + // DeleteMessage provides a mock function for the type MockClient func (_mock *MockClient) DeleteMessage(ctx context.Context, messageID int) error { ret := _mock.Called(ctx, messageID) diff --git a/session_client.go b/session_client.go index c80556a..85877b3 100644 --- a/session_client.go +++ b/session_client.go @@ -234,6 +234,67 @@ func (c *MessageSessionClient) Session() RunnerScaleSetSession { return *c.session } +// AcquireJobs acquires the given job request IDs from the runner scale set. +// If the current session token is expired, it refreshes the session and tries one more time. +func (c *MessageSessionClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + + ids, err := c.acquireJobs(ctx, requestIDs) + if err == nil { + return ids, nil + } + + if !errors.Is(err, MessageQueueTokenExpiredError) { + return nil, fmt.Errorf("failed to acquire jobs: %w", err) + } + + if err := c.refreshMessageSession(ctx); err != nil { + return nil, fmt.Errorf("failed to refresh message session: %w", err) + } + + return c.acquireJobs(ctx, requestIDs) +} + +func (c *MessageSessionClient) acquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) { + body, err := json.Marshal(requestIDs) + if err != nil { + return nil, fmt.Errorf("failed to marshal request ids: %w", err) + } + + path := fmt.Sprintf("/%s/%d/acquirejobs", scaleSetEndpoint, c.scaleSetID) + + c.innerClient.mu.Lock() + req, err := c.innerClient.newActionsServiceRequest(ctx, http.MethodPost, path, bytes.NewBuffer(body)) + c.innerClient.mu.Unlock() + if err != nil { + return nil, fmt.Errorf("failed to create acquire jobs request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.session.MessageQueueAccessToken)) + + resp, err := c.commonClient.do(req) + if err != nil { + return nil, fmt.Errorf("failed to issue acquire jobs request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusUnauthorized { + return nil, newRequestResponseError(req, resp, MessageQueueTokenExpiredError) + } + + if resp.StatusCode != http.StatusOK { + return nil, newRequestResponseError(req, resp, fmt.Errorf("unexpected status code %s", resp.Status)) + } + + var result acquireJobsResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, newRequestResponseError(req, resp, fmt.Errorf("failed to decode acquire jobs response: %w", err)) + } + + return result.Value, nil +} + func (c *MessageSessionClient) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error { c.innerClient.mu.Lock() defer c.innerClient.mu.Unlock() diff --git a/session_client_test.go b/session_client_test.go index 79c0e69..7e35ac0 100644 --- a/session_client_test.go +++ b/session_client_test.go @@ -649,3 +649,231 @@ func TestDeleteMessage(t *testing.T) { assert.Contains(t, err.Error(), "unexpected status code") }) } + +func TestAcquireJobs(t *testing.T) { + ctx := context.Background() + auth := actionsAuth{ + token: "token", + } + + t.Run("Acquire jobs successfully", func(t *testing.T) { + requestIDs := []int64{1, 2} + want := []int64{1, 2} + response := []byte(`{"count":2,"value":[1,2]}`) + + var handleSessionRequest http.HandlerFunc + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "acquirejobs") { + assert.Equal(t, http.MethodPost, r.Method) + assert.Contains(t, r.URL.Path, "acquirejobs") + assert.True(t, strings.HasPrefix(r.Header.Get("Authorization"), "Bearer"), "expected Bearer authorization header") + + var gotIDs []int64 + err := json.NewDecoder(r.Body).Decode(&gotIDs) + require.NoError(t, err) + assert.Equal(t, requestIDs, gotIDs) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(response) + return + } + if strings.HasSuffix(r.URL.Path, "sessions") { + handleSessionRequest(w, r) + return + } + if strings.Contains(r.URL.Path, "/sessions/") { + handleSessionRequest(w, r) + return + } + })) + handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession()) + + client, err := newClient( + testSystemInfo, + server.configURLForOrg("my-org"), + auth, + ) + require.NoError(t, err) + + sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org") + require.NoError(t, err) + + got, err := sessionClient.AcquireJobs(ctx, requestIDs) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("Message token expired", func(t *testing.T) { + var handleSessionRequest http.HandlerFunc + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "acquirejobs") { + w.WriteHeader(http.StatusUnauthorized) + return + } + // create session + if strings.HasSuffix(r.URL.Path, "sessions") { + handleSessionRequest(w, r) + return + } + // refresh + if strings.Contains(r.URL.Path, "/sessions/") { + handleSessionRequest(w, r) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) + handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession()) + + client, err := newClient( + testSystemInfo, + server.configURLForOrg("my-org"), + auth, + ) + require.NoError(t, err) + + sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org") + require.NoError(t, err) + + got, err := sessionClient.AcquireJobs(ctx, []int64{1}) + assert.Nil(t, got) + assert.ErrorIs(t, err, MessageQueueTokenExpiredError, "expected error to be MessageQueueTokenExpiredError but got: %v", err) + }) + + t.Run("Message token refreshed", func(t *testing.T) { + want := []int64{1, 2} + afterRefreshResponse := []byte(`{"count":2,"value":[1,2]}`) + + var handleSessionRequest http.HandlerFunc + type state int + const ( + createSession state = iota + firstAcquire + refreshToken + secondAcquire + ) + currentState := createSession + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "acquirejobs") { + if currentState == firstAcquire { + w.WriteHeader(http.StatusUnauthorized) + currentState = refreshToken + return + } + require.Equal(t, secondAcquire, currentState) + w.Header().Set("Content-Type", "application/json") + w.Write(afterRefreshResponse) + return + } + // create session + if strings.HasSuffix(r.URL.Path, "sessions") { + require.Equal(t, createSession, currentState) + handleSessionRequest(w, r) + currentState = firstAcquire + return + } + // refresh + if strings.Contains(r.URL.Path, "/sessions/") { + require.Equal(t, refreshToken, currentState) + handleSessionRequest(w, r) + currentState = secondAcquire + return + } + })) + handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession()) + + client, err := newClient( + testSystemInfo, + server.configURLForOrg("my-org"), + auth, + ) + require.NoError(t, err) + + sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org") + require.NoError(t, err) + + got, err := sessionClient.AcquireJobs(ctx, []int64{1, 2}) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("Server error", func(t *testing.T) { + var handleSessionRequest http.HandlerFunc + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "acquirejobs") { + w.WriteHeader(http.StatusInternalServerError) + return + } + if strings.HasSuffix(r.URL.Path, "sessions") { + handleSessionRequest(w, r) + return + } + if strings.Contains(r.URL.Path, "/sessions/") { + handleSessionRequest(w, r) + return + } + })) + handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession()) + + retryMax := 1 + client, err := newClient( + testSystemInfo, + server.configURLForOrg("my-org"), + auth, + WithRetryMax(retryMax), + WithRetryWaitMax(1*time.Nanosecond), + ) + require.NoError(t, err) + + sessionClient, err := client.MessageSessionClient( + ctx, + 1, + "my-org", + WithRetryMax(retryMax), + WithRetryWaitMax(1*time.Nanosecond), + ) + require.NoError(t, err) + + got, err := sessionClient.AcquireJobs(ctx, []int64{1}) + assert.Nil(t, got) + assert.NotNil(t, err) + }) + + t.Run("Empty request IDs", func(t *testing.T) { + response := []byte(`{"count":0,"value":[]}`) + + var handleSessionRequest http.HandlerFunc + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "acquirejobs") { + assert.Equal(t, http.MethodPost, r.Method) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(response) + return + } + if strings.HasSuffix(r.URL.Path, "sessions") { + handleSessionRequest(w, r) + return + } + if strings.Contains(r.URL.Path, "/sessions/") { + handleSessionRequest(w, r) + return + } + })) + handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession()) + + client, err := newClient( + testSystemInfo, + server.configURLForOrg("my-org"), + auth, + ) + require.NoError(t, err) + + sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org") + require.NoError(t, err) + + got, err := sessionClient.AcquireJobs(ctx, []int64{}) + require.NoError(t, err) + assert.Empty(t, got) + }) +} diff --git a/types.go b/types.go index 851039f..228f5d6 100644 --- a/types.go +++ b/types.go @@ -12,11 +12,17 @@ type MessageType string // message types const ( + MessageTypeJobAvailable MessageType = "JobAvailable" MessageTypeJobAssigned MessageType = "JobAssigned" MessageTypeJobStarted MessageType = "JobStarted" MessageTypeJobCompleted MessageType = "JobCompleted" ) +type JobAvailable struct { + AcquireJobURL string `json:"acquireJobUrl"` + JobMessageBase +} + type JobAssigned struct { JobMessageBase } @@ -99,6 +105,7 @@ type runnerScaleSetMessageResponse struct { type RunnerScaleSetMessage struct { MessageID int Statistics *RunnerScaleSetStatistic + JobAvailableMessages []*JobAvailable JobAssignedMessages []*JobAssigned JobStartedMessages []*JobStarted JobCompletedMessages []*JobCompleted @@ -109,6 +116,11 @@ type runnerScaleSetsResponse struct { RunnerScaleSets []RunnerScaleSet `json:"value"` } +type acquireJobsResponse struct { + Count int `json:"count"` + Value []int64 `json:"value"` +} + type RunnerScaleSetSession struct { SessionID uuid.UUID `json:"sessionId,omitempty"` OwnerName string `json:"ownerName,omitempty"`