Restore job acquisition flow (#90)

This commit is contained in:
Francesco Renzi
2026-04-14 09:27:42 +01:00
committed by GitHub
parent a0708d5ea7
commit 22bae1217e
6 changed files with 401 additions and 0 deletions
+8
View File
@@ -510,6 +510,14 @@ func parseRunnerScaleSetMessageResponse(respBody io.Reader) (*RunnerScaleSetMess
}
switch messageType.MessageType {
case MessageTypeJobAvailable:
var jobAvailable JobAvailable
if err := json.Unmarshal(msg, &jobAvailable); err != nil {
return nil, fmt.Errorf("failed to decode job available: %w", err)
}
message.JobAvailableMessages = append(message.JobAvailableMessages, &jobAvailable)
case MessageTypeJobAssigned:
var jobAssigned JobAssigned
if err := json.Unmarshal(msg, &jobAssigned); err != nil {
+24
View File
@@ -49,6 +49,7 @@ func (c *Config) Validate() error {
type Client interface {
GetMessage(ctx context.Context, lastMessageID, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error)
DeleteMessage(ctx context.Context, messageID int) error
AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error)
Session() scaleset.RunnerScaleSetSession
}
@@ -210,6 +211,12 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale
return fmt.Errorf("failed to delete message: %w", err)
}
if len(msg.JobAvailableMessages) > 0 {
if err := l.acquireAvailableJobs(ctx, msg.JobAvailableMessages); err != nil {
return fmt.Errorf("failed to acquire available jobs: %w", err)
}
}
for _, jobStarted := range msg.JobStartedMessages {
l.metricsRecorder.RecordJobStarted(jobStarted)
if err := handler.HandleJobStarted(ctx, jobStarted); err != nil {
@@ -232,6 +239,23 @@ func (l *Listener) handleMessage(ctx context.Context, handler Scaler, msg *scale
return nil
}
func (l *Listener) acquireAvailableJobs(ctx context.Context, jobsAvailable []*scaleset.JobAvailable) error {
ids := make([]int64, 0, len(jobsAvailable))
for _, job := range jobsAvailable {
ids = append(ids, job.RunnerRequestID)
}
l.logger.Info("Acquiring jobs", slog.Int("count", len(ids)))
acquired, err := l.client.AcquireJobs(ctx, ids)
if err != nil {
return fmt.Errorf("acquiring jobs: %w", err)
}
l.logger.Info("Jobs acquired", slog.Int("count", len(acquired)))
return nil
}
func (l *Listener) handleStatistics(ctx context.Context, msg *scaleset.RunnerScaleSetStatistic) {
l.latestStatistics = msg
l.metricsRecorder.RecordStatistics(msg)
+68
View File
@@ -38,6 +38,74 @@ func (_m *MockClient) EXPECT() *MockClient_Expecter {
return &MockClient_Expecter{mock: &_m.Mock}
}
// AcquireJobs provides a mock function for the type MockClient
func (_mock *MockClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
ret := _mock.Called(ctx, requestIDs)
if len(ret) == 0 {
panic("no return value specified for AcquireJobs")
}
var r0 []int64
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) ([]int64, error)); ok {
return returnFunc(ctx, requestIDs)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, []int64) []int64); ok {
r0 = returnFunc(ctx, requestIDs)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]int64)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, []int64) error); ok {
r1 = returnFunc(ctx, requestIDs)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockClient_AcquireJobs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AcquireJobs'
type MockClient_AcquireJobs_Call struct {
*mock.Call
}
// AcquireJobs is a helper method to define mock.On call
// - ctx context.Context
// - requestIDs []int64
func (_e *MockClient_Expecter) AcquireJobs(ctx interface{}, requestIDs interface{}) *MockClient_AcquireJobs_Call {
return &MockClient_AcquireJobs_Call{Call: _e.mock.On("AcquireJobs", ctx, requestIDs)}
}
func (_c *MockClient_AcquireJobs_Call) Run(run func(ctx context.Context, requestIDs []int64)) *MockClient_AcquireJobs_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 []int64
if args[1] != nil {
arg1 = args[1].([]int64)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *MockClient_AcquireJobs_Call) Return(int64s []int64, err error) *MockClient_AcquireJobs_Call {
_c.Call.Return(int64s, err)
return _c
}
func (_c *MockClient_AcquireJobs_Call) RunAndReturn(run func(ctx context.Context, requestIDs []int64) ([]int64, error)) *MockClient_AcquireJobs_Call {
_c.Call.Return(run)
return _c
}
// DeleteMessage provides a mock function for the type MockClient
func (_mock *MockClient) DeleteMessage(ctx context.Context, messageID int) error {
ret := _mock.Called(ctx, messageID)
+61
View File
@@ -234,6 +234,67 @@ func (c *MessageSessionClient) Session() RunnerScaleSetSession {
return *c.session
}
// AcquireJobs acquires the given job request IDs from the runner scale set.
// If the current session token is expired, it refreshes the session and tries one more time.
func (c *MessageSessionClient) AcquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
ids, err := c.acquireJobs(ctx, requestIDs)
if err == nil {
return ids, nil
}
if !errors.Is(err, MessageQueueTokenExpiredError) {
return nil, fmt.Errorf("failed to acquire jobs: %w", err)
}
if err := c.refreshMessageSession(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh message session: %w", err)
}
return c.acquireJobs(ctx, requestIDs)
}
func (c *MessageSessionClient) acquireJobs(ctx context.Context, requestIDs []int64) ([]int64, error) {
body, err := json.Marshal(requestIDs)
if err != nil {
return nil, fmt.Errorf("failed to marshal request ids: %w", err)
}
path := fmt.Sprintf("/%s/%d/acquirejobs", scaleSetEndpoint, c.scaleSetID)
c.innerClient.mu.Lock()
req, err := c.innerClient.newActionsServiceRequest(ctx, http.MethodPost, path, bytes.NewBuffer(body))
c.innerClient.mu.Unlock()
if err != nil {
return nil, fmt.Errorf("failed to create acquire jobs request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.session.MessageQueueAccessToken))
resp, err := c.commonClient.do(req)
if err != nil {
return nil, fmt.Errorf("failed to issue acquire jobs request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized {
return nil, newRequestResponseError(req, resp, MessageQueueTokenExpiredError)
}
if resp.StatusCode != http.StatusOK {
return nil, newRequestResponseError(req, resp, fmt.Errorf("unexpected status code %s", resp.Status))
}
var result acquireJobsResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, newRequestResponseError(req, resp, fmt.Errorf("failed to decode acquire jobs response: %w", err))
}
return result.Value, nil
}
func (c *MessageSessionClient) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error {
c.innerClient.mu.Lock()
defer c.innerClient.mu.Unlock()
+228
View File
@@ -649,3 +649,231 @@ func TestDeleteMessage(t *testing.T) {
assert.Contains(t, err.Error(), "unexpected status code")
})
}
func TestAcquireJobs(t *testing.T) {
ctx := context.Background()
auth := actionsAuth{
token: "token",
}
t.Run("Acquire jobs successfully", func(t *testing.T) {
requestIDs := []int64{1, 2}
want := []int64{1, 2}
response := []byte(`{"count":2,"value":[1,2]}`)
var handleSessionRequest http.HandlerFunc
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "acquirejobs") {
assert.Equal(t, http.MethodPost, r.Method)
assert.Contains(t, r.URL.Path, "acquirejobs")
assert.True(t, strings.HasPrefix(r.Header.Get("Authorization"), "Bearer"), "expected Bearer authorization header")
var gotIDs []int64
err := json.NewDecoder(r.Body).Decode(&gotIDs)
require.NoError(t, err)
assert.Equal(t, requestIDs, gotIDs)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(response)
return
}
if strings.HasSuffix(r.URL.Path, "sessions") {
handleSessionRequest(w, r)
return
}
if strings.Contains(r.URL.Path, "/sessions/") {
handleSessionRequest(w, r)
return
}
}))
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
client, err := newClient(
testSystemInfo,
server.configURLForOrg("my-org"),
auth,
)
require.NoError(t, err)
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
require.NoError(t, err)
got, err := sessionClient.AcquireJobs(ctx, requestIDs)
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("Message token expired", func(t *testing.T) {
var handleSessionRequest http.HandlerFunc
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "acquirejobs") {
w.WriteHeader(http.StatusUnauthorized)
return
}
// create session
if strings.HasSuffix(r.URL.Path, "sessions") {
handleSessionRequest(w, r)
return
}
// refresh
if strings.Contains(r.URL.Path, "/sessions/") {
handleSessionRequest(w, r)
return
}
w.WriteHeader(http.StatusUnauthorized)
}))
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
client, err := newClient(
testSystemInfo,
server.configURLForOrg("my-org"),
auth,
)
require.NoError(t, err)
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
require.NoError(t, err)
got, err := sessionClient.AcquireJobs(ctx, []int64{1})
assert.Nil(t, got)
assert.ErrorIs(t, err, MessageQueueTokenExpiredError, "expected error to be MessageQueueTokenExpiredError but got: %v", err)
})
t.Run("Message token refreshed", func(t *testing.T) {
want := []int64{1, 2}
afterRefreshResponse := []byte(`{"count":2,"value":[1,2]}`)
var handleSessionRequest http.HandlerFunc
type state int
const (
createSession state = iota
firstAcquire
refreshToken
secondAcquire
)
currentState := createSession
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "acquirejobs") {
if currentState == firstAcquire {
w.WriteHeader(http.StatusUnauthorized)
currentState = refreshToken
return
}
require.Equal(t, secondAcquire, currentState)
w.Header().Set("Content-Type", "application/json")
w.Write(afterRefreshResponse)
return
}
// create session
if strings.HasSuffix(r.URL.Path, "sessions") {
require.Equal(t, createSession, currentState)
handleSessionRequest(w, r)
currentState = firstAcquire
return
}
// refresh
if strings.Contains(r.URL.Path, "/sessions/") {
require.Equal(t, refreshToken, currentState)
handleSessionRequest(w, r)
currentState = secondAcquire
return
}
}))
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
client, err := newClient(
testSystemInfo,
server.configURLForOrg("my-org"),
auth,
)
require.NoError(t, err)
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
require.NoError(t, err)
got, err := sessionClient.AcquireJobs(ctx, []int64{1, 2})
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("Server error", func(t *testing.T) {
var handleSessionRequest http.HandlerFunc
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "acquirejobs") {
w.WriteHeader(http.StatusInternalServerError)
return
}
if strings.HasSuffix(r.URL.Path, "sessions") {
handleSessionRequest(w, r)
return
}
if strings.Contains(r.URL.Path, "/sessions/") {
handleSessionRequest(w, r)
return
}
}))
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
retryMax := 1
client, err := newClient(
testSystemInfo,
server.configURLForOrg("my-org"),
auth,
WithRetryMax(retryMax),
WithRetryWaitMax(1*time.Nanosecond),
)
require.NoError(t, err)
sessionClient, err := client.MessageSessionClient(
ctx,
1,
"my-org",
WithRetryMax(retryMax),
WithRetryWaitMax(1*time.Nanosecond),
)
require.NoError(t, err)
got, err := sessionClient.AcquireJobs(ctx, []int64{1})
assert.Nil(t, got)
assert.NotNil(t, err)
})
t.Run("Empty request IDs", func(t *testing.T) {
response := []byte(`{"count":0,"value":[]}`)
var handleSessionRequest http.HandlerFunc
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "acquirejobs") {
assert.Equal(t, http.MethodPost, r.Method)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(response)
return
}
if strings.HasSuffix(r.URL.Path, "sessions") {
handleSessionRequest(w, r)
return
}
if strings.Contains(r.URL.Path, "/sessions/") {
handleSessionRequest(w, r)
return
}
}))
handleSessionRequest = newTestSessionRequestHandler(t, server.testRunnerScaleSetSession())
client, err := newClient(
testSystemInfo,
server.configURLForOrg("my-org"),
auth,
)
require.NoError(t, err)
sessionClient, err := client.MessageSessionClient(ctx, 1, "my-org")
require.NoError(t, err)
got, err := sessionClient.AcquireJobs(ctx, []int64{})
require.NoError(t, err)
assert.Empty(t, got)
})
}
+12
View File
@@ -12,11 +12,17 @@ type MessageType string
// message types
const (
MessageTypeJobAvailable MessageType = "JobAvailable"
MessageTypeJobAssigned MessageType = "JobAssigned"
MessageTypeJobStarted MessageType = "JobStarted"
MessageTypeJobCompleted MessageType = "JobCompleted"
)
type JobAvailable struct {
AcquireJobURL string `json:"acquireJobUrl"`
JobMessageBase
}
type JobAssigned struct {
JobMessageBase
}
@@ -99,6 +105,7 @@ type runnerScaleSetMessageResponse struct {
type RunnerScaleSetMessage struct {
MessageID int
Statistics *RunnerScaleSetStatistic
JobAvailableMessages []*JobAvailable
JobAssignedMessages []*JobAssigned
JobStartedMessages []*JobStarted
JobCompletedMessages []*JobCompleted
@@ -109,6 +116,11 @@ type runnerScaleSetsResponse struct {
RunnerScaleSets []RunnerScaleSet `json:"value"`
}
type acquireJobsResponse struct {
Count int `json:"count"`
Value []int64 `json:"value"`
}
type RunnerScaleSetSession struct {
SessionID uuid.UUID `json:"sessionId,omitempty"`
OwnerName string `json:"ownerName,omitempty"`