chore: use github's shared prettier-config

This commit is contained in:
Marais Rossouw
2025-07-24 19:11:15 +10:00
parent a2235c5511
commit 7e2aa19f3b
26 changed files with 351 additions and 559 deletions
+1 -2
View File
@@ -83,8 +83,7 @@ jobs:
run: echo "hello" > prompt.txt run: echo "hello" > prompt.txt
- name: Create System Prompt File - name: Create System Prompt File
run: run: echo "You are a helpful AI assistant for testing." > system-prompt.txt
echo "You are a helpful AI assistant for testing." > system-prompt.txt
- name: Test Local Action with Prompt File - name: Test Local Action with Prompt File
id: test-action-prompt-file id: test-action-prompt-file
@@ -11,8 +11,7 @@ permissions:
jobs: jobs:
update_tag: update_tag:
name: name: Update the major tag to include the ${{ github.event.release.tag_name }}
Update the major tag to include the ${{ github.event.release.tag_name }}
changes changes
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
-16
View File
@@ -1,16 +0,0 @@
# See: https://prettier.io/docs/en/configuration
printWidth: 80
tabWidth: 2
useTabs: false
semi: false
singleQuote: true
quoteProps: as-needed
jsxSingleQuote: false
trailingComma: none
bracketSpacing: true
bracketSameLine: true
arrowParens: always
proseWrap: always
htmlWhitespaceSensitivity: css
endOfLine: lf
+1 -2
View File
@@ -83,8 +83,7 @@ model: openai/gpt-4o
```yaml ```yaml
messages: messages:
- role: system - role: system
content: content: You are a helpful assistant that describes animals using JSON format
You are a helpful assistant that describes animals using JSON format
- role: user - role: user
content: |- content: |-
Describe a {{animal}} Describe a {{animal}}
+1 -1
View File
@@ -1,5 +1,5 @@
import type * as core from '@actions/core' import type * as core from '@actions/core'
import { vi } from 'vitest' import {vi} from 'vitest'
export const debug = vi.fn<typeof core.debug>() export const debug = vi.fn<typeof core.debug>()
export const error = vi.fn<typeof core.error>() export const error = vi.fn<typeof core.error>()
+1 -2
View File
@@ -1,7 +1,6 @@
messages: messages:
- role: system - role: system
content: content: You are a helpful assistant that describes animals using JSON format
You are a helpful assistant that describes animals using JSON format
- role: user - role: user
content: |- content: |-
Describe a {{animal}} Describe a {{animal}}
+1 -1
View File
@@ -1,3 +1,3 @@
import { vi } from 'vitest' import {vi} from 'vitest'
export const wait = vi.fn<typeof import('../src/wait.js').wait>() export const wait = vi.fn<typeof import('../src/wait.js').wait>()
+34 -40
View File
@@ -1,41 +1,37 @@
import { describe, it, expect } from 'vitest' import {describe, it, expect} from 'vitest'
import { import {buildMessages, buildResponseFormat, buildInferenceRequest} from '../src/helpers'
buildMessages, import {PromptConfig} from '../src/prompt'
buildResponseFormat,
buildInferenceRequest
} from '../src/helpers'
import { PromptConfig } from '../src/prompt'
describe('helpers.ts - inference request building', () => { describe('helpers.ts - inference request building', () => {
describe('buildMessages', () => { describe('buildMessages', () => {
it('should build messages from prompt config', () => { it('should build messages from prompt config', () => {
const promptConfig: PromptConfig = { const promptConfig: PromptConfig = {
messages: [ messages: [
{ role: 'system', content: 'System message' }, {role: 'system', content: 'System message'},
{ role: 'user', content: 'User message' } {role: 'user', content: 'User message'},
] ],
} }
const result = buildMessages(promptConfig) const result = buildMessages(promptConfig)
expect(result).toEqual([ expect(result).toEqual([
{ role: 'system', content: 'System message' }, {role: 'system', content: 'System message'},
{ role: 'user', content: 'User message' } {role: 'user', content: 'User message'},
]) ])
}) })
it('should build messages from legacy format', () => { it('should build messages from legacy format', () => {
const result = buildMessages(undefined, 'System prompt', 'User prompt') const result = buildMessages(undefined, 'System prompt', 'User prompt')
expect(result).toEqual([ expect(result).toEqual([
{ role: 'system', content: 'System prompt' }, {role: 'system', content: 'System prompt'},
{ role: 'user', content: 'User prompt' } {role: 'user', content: 'User prompt'},
]) ])
}) })
it('should use default system prompt when none provided', () => { it('should use default system prompt when none provided', () => {
const result = buildMessages(undefined, undefined, 'User prompt') const result = buildMessages(undefined, undefined, 'User prompt')
expect(result).toEqual([ expect(result).toEqual([
{ role: 'system', content: 'You are a helpful assistant' }, {role: 'system', content: 'You are a helpful assistant'},
{ role: 'user', content: 'User prompt' } {role: 'user', content: 'User prompt'},
]) ])
}) })
}) })
@@ -47,8 +43,8 @@ describe('helpers.ts - inference request building', () => {
responseFormat: 'json_schema', responseFormat: 'json_schema',
jsonSchema: JSON.stringify({ jsonSchema: JSON.stringify({
name: 'test_schema', name: 'test_schema',
schema: { type: 'object' } schema: {type: 'object'},
}) }),
} }
const result = buildResponseFormat(promptConfig) const result = buildResponseFormat(promptConfig)
@@ -56,15 +52,15 @@ describe('helpers.ts - inference request building', () => {
type: 'json_schema', type: 'json_schema',
json_schema: { json_schema: {
name: 'test_schema', name: 'test_schema',
schema: { type: 'object' } schema: {type: 'object'},
} },
}) })
}) })
it('should return undefined for text format', () => { it('should return undefined for text format', () => {
const promptConfig: PromptConfig = { const promptConfig: PromptConfig = {
messages: [], messages: [],
responseFormat: 'text' responseFormat: 'text',
} }
const result = buildResponseFormat(promptConfig) const result = buildResponseFormat(promptConfig)
@@ -73,7 +69,7 @@ describe('helpers.ts - inference request building', () => {
it('should return undefined when no response format specified', () => { it('should return undefined when no response format specified', () => {
const promptConfig: PromptConfig = { const promptConfig: PromptConfig = {
messages: [] messages: [],
} }
const result = buildResponseFormat(promptConfig) const result = buildResponseFormat(promptConfig)
@@ -84,12 +80,10 @@ describe('helpers.ts - inference request building', () => {
const promptConfig: PromptConfig = { const promptConfig: PromptConfig = {
messages: [], messages: [],
responseFormat: 'json_schema', responseFormat: 'json_schema',
jsonSchema: 'invalid json' jsonSchema: 'invalid json',
} }
expect(() => buildResponseFormat(promptConfig)).toThrow( expect(() => buildResponseFormat(promptConfig)).toThrow('Invalid JSON schema')
'Invalid JSON schema'
)
}) })
}) })
@@ -97,14 +91,14 @@ describe('helpers.ts - inference request building', () => {
it('should build complete inference request from prompt config', () => { it('should build complete inference request from prompt config', () => {
const promptConfig: PromptConfig = { const promptConfig: PromptConfig = {
messages: [ messages: [
{ role: 'system', content: 'System message' }, {role: 'system', content: 'System message'},
{ role: 'user', content: 'User message' } {role: 'user', content: 'User message'},
], ],
responseFormat: 'json_schema', responseFormat: 'json_schema',
jsonSchema: JSON.stringify({ jsonSchema: JSON.stringify({
name: 'test_schema', name: 'test_schema',
schema: { type: 'object' } schema: {type: 'object'},
}) }),
} }
const result = buildInferenceRequest( const result = buildInferenceRequest(
@@ -114,13 +108,13 @@ describe('helpers.ts - inference request building', () => {
'gpt-4', 'gpt-4',
100, 100,
'https://api.test.com', 'https://api.test.com',
'test-token' 'test-token',
) )
expect(result).toEqual({ expect(result).toEqual({
messages: [ messages: [
{ role: 'system', content: 'System message' }, {role: 'system', content: 'System message'},
{ role: 'user', content: 'User message' } {role: 'user', content: 'User message'},
], ],
modelName: 'gpt-4', modelName: 'gpt-4',
maxTokens: 100, maxTokens: 100,
@@ -130,9 +124,9 @@ describe('helpers.ts - inference request building', () => {
type: 'json_schema', type: 'json_schema',
json_schema: { json_schema: {
name: 'test_schema', name: 'test_schema',
schema: { type: 'object' } schema: {type: 'object'},
} },
} },
}) })
}) })
@@ -144,19 +138,19 @@ describe('helpers.ts - inference request building', () => {
'gpt-4', 'gpt-4',
100, 100,
'https://api.test.com', 'https://api.test.com',
'test-token' 'test-token',
) )
expect(result).toEqual({ expect(result).toEqual({
messages: [ messages: [
{ role: 'system', content: 'System prompt' }, {role: 'system', content: 'System prompt'},
{ role: 'user', content: 'User prompt' } {role: 'user', content: 'User prompt'},
], ],
modelName: 'gpt-4', modelName: 'gpt-4',
maxTokens: 100, maxTokens: 100,
endpoint: 'https://api.test.com', endpoint: 'https://api.test.com',
token: 'test-token', token: 'test-token',
responseFormat: undefined responseFormat: undefined,
}) })
}) })
}) })
+5 -13
View File
@@ -1,4 +1,4 @@
import { vi, it, expect, beforeEach, describe } from 'vitest' import {vi, it, expect, beforeEach, describe} from 'vitest'
import * as core from '../__fixtures__/core.js' import * as core from '../__fixtures__/core.js'
const mockExistsSync = vi.fn() const mockExistsSync = vi.fn()
@@ -6,12 +6,12 @@ const mockReadFileSync = vi.fn()
vi.mock('fs', () => ({ vi.mock('fs', () => ({
existsSync: mockExistsSync, existsSync: mockExistsSync,
readFileSync: mockReadFileSync readFileSync: mockReadFileSync,
})) }))
vi.mock('@actions/core', () => core) vi.mock('@actions/core', () => core)
const { loadContentFromFileOrInput } = await import('../src/helpers.js') const {loadContentFromFileOrInput} = await import('../src/helpers.js')
describe('helpers.ts', () => { describe('helpers.ts', () => {
beforeEach(() => { beforeEach(() => {
@@ -103,11 +103,7 @@ describe('helpers.ts', () => {
core.getInput.mockImplementation(() => '') core.getInput.mockImplementation(() => '')
const result = loadContentFromFileOrInput( const result = loadContentFromFileOrInput('file-input', 'content-input', defaultValue)
'file-input',
'content-input',
defaultValue
)
expect(result).toBe(defaultValue) expect(result).toBe(defaultValue)
expect(mockExistsSync).not.toHaveBeenCalled() expect(mockExistsSync).not.toHaveBeenCalled()
@@ -131,11 +127,7 @@ describe('helpers.ts', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
core.getInput.mockImplementation(() => undefined as any) core.getInput.mockImplementation(() => undefined as any)
const result = loadContentFromFileOrInput( const result = loadContentFromFileOrInput('file-input', 'content-input', defaultValue)
'file-input',
'content-input',
defaultValue
)
expect(result).toBe(defaultValue) expect(result).toBe(defaultValue)
}) })
+87 -113
View File
@@ -1,48 +1,41 @@
import { import {vi, type MockedFunction, beforeEach, expect, describe, it} from 'vitest'
vi,
type MockedFunction,
beforeEach,
expect,
describe,
it
} from 'vitest'
import * as core from '../__fixtures__/core.js' import * as core from '../__fixtures__/core.js'
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
const mockPost = vi.fn() as MockedFunction<any> const mockPost = vi.fn() as MockedFunction<any>
const mockPath = vi.fn(() => ({ post: mockPost })) const mockPath = vi.fn(() => ({post: mockPost}))
const mockClient = vi.fn(() => ({ path: mockPath })) const mockClient = vi.fn(() => ({path: mockPath}))
vi.mock('@azure-rest/ai-inference', () => ({ vi.mock('@azure-rest/ai-inference', () => ({
default: mockClient, default: mockClient,
isUnexpected: vi.fn(() => false) isUnexpected: vi.fn(() => false),
})) }))
vi.mock('@azure/core-auth', () => ({ vi.mock('@azure/core-auth', () => ({
AzureKeyCredential: vi.fn() AzureKeyCredential: vi.fn(),
})) }))
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
const mockExecuteToolCalls = vi.fn() as MockedFunction<any> const mockExecuteToolCalls = vi.fn() as MockedFunction<any>
vi.mock('../src/mcp.js', () => ({ vi.mock('../src/mcp.js', () => ({
executeToolCalls: mockExecuteToolCalls executeToolCalls: mockExecuteToolCalls,
})) }))
vi.mock('@actions/core', () => core) vi.mock('@actions/core', () => core)
// Import the module being tested // Import the module being tested
const { simpleInference, mcpInference } = await import('../src/inference.js') const {simpleInference, mcpInference} = await import('../src/inference.js')
describe('inference.ts', () => { describe('inference.ts', () => {
const mockRequest = { const mockRequest = {
messages: [ messages: [
{ role: 'system', content: 'You are a test assistant' }, {role: 'system', content: 'You are a test assistant'},
{ role: 'user', content: 'Hello, AI!' } {role: 'user', content: 'Hello, AI!'},
], ],
modelName: 'gpt-4', modelName: 'gpt-4',
maxTokens: 100, maxTokens: 100,
endpoint: 'https://api.test.com', endpoint: 'https://api.test.com',
token: 'test-token' token: 'test-token',
} }
beforeEach(() => { beforeEach(() => {
@@ -56,11 +49,11 @@ describe('inference.ts', () => {
choices: [ choices: [
{ {
message: { message: {
content: 'Hello, user!' content: 'Hello, user!',
} },
} },
] ],
} },
} }
mockPost.mockResolvedValue(mockResponse) mockPost.mockResolvedValue(mockResponse)
@@ -68,9 +61,7 @@ describe('inference.ts', () => {
const result = await simpleInference(mockRequest) const result = await simpleInference(mockRequest)
expect(result).toBe('Hello, user!') expect(result).toBe('Hello, user!')
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Running simple inference without tools')
'Running simple inference without tools'
)
expect(core.info).toHaveBeenCalledWith('Model response: Hello, user!') expect(core.info).toHaveBeenCalledWith('Model response: Hello, user!')
// Verify the request structure // Verify the request structure
@@ -79,16 +70,16 @@ describe('inference.ts', () => {
messages: [ messages: [
{ {
role: 'system', role: 'system',
content: 'You are a test assistant' content: 'You are a test assistant',
}, },
{ {
role: 'user', role: 'user',
content: 'Hello, AI!' content: 'Hello, AI!',
} },
], ],
max_tokens: 100, max_tokens: 100,
model: 'gpt-4' model: 'gpt-4',
} },
}) })
}) })
@@ -98,11 +89,11 @@ describe('inference.ts', () => {
choices: [ choices: [
{ {
message: { message: {
content: null content: null,
} },
} },
] ],
} },
} }
mockPost.mockResolvedValue(mockResponse) mockPost.mockResolvedValue(mockResponse)
@@ -110,9 +101,7 @@ describe('inference.ts', () => {
const result = await simpleInference(mockRequest) const result = await simpleInference(mockRequest)
expect(result).toBeNull() expect(result).toBeNull()
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Model response: No response content')
'Model response: No response content'
)
}) })
}) })
@@ -126,10 +115,10 @@ describe('inference.ts', () => {
function: { function: {
name: 'test-tool', name: 'test-tool',
description: 'A test tool', description: 'A test tool',
parameters: { type: 'object' } parameters: {type: 'object'},
} },
} },
] ],
} }
it('performs inference without tool calls', async () => { it('performs inference without tool calls', async () => {
@@ -139,11 +128,11 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'Hello, user!', content: 'Hello, user!',
tool_calls: null tool_calls: null,
} },
} },
] ],
} },
} }
mockPost.mockResolvedValue(mockResponse) mockPost.mockResolvedValue(mockResponse)
@@ -151,13 +140,9 @@ describe('inference.ts', () => {
const result = await mcpInference(mockRequest, mockMcpClient) const result = await mcpInference(mockRequest, mockMcpClient)
expect(result).toBe('Hello, user!') expect(result).toBe('Hello, user!')
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Running GitHub MCP inference with tools')
'Running GitHub MCP inference with tools'
)
expect(core.info).toHaveBeenCalledWith('MCP inference iteration 1') expect(core.info).toHaveBeenCalledWith('MCP inference iteration 1')
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('No tool calls requested, ending GitHub MCP inference loop')
'No tool calls requested, ending GitHub MCP inference loop'
)
// The MCP inference loop will always add the assistant message, even when there are no tool calls // The MCP inference loop will always add the assistant message, even when there are no tool calls
// So we don't check the exact messages, just that tools were included // So we don't check the exact messages, just that tools were included
@@ -175,9 +160,9 @@ describe('inference.ts', () => {
id: 'call-123', id: 'call-123',
function: { function: {
name: 'test-tool', name: 'test-tool',
arguments: '{"param": "value"}' arguments: '{"param": "value"}',
} },
} },
] ]
const toolResults = [ const toolResults = [
@@ -185,8 +170,8 @@ describe('inference.ts', () => {
tool_call_id: 'call-123', tool_call_id: 'call-123',
role: 'tool', role: 'tool',
name: 'test-tool', name: 'test-tool',
content: 'Tool result' content: 'Tool result',
} },
] ]
// First response with tool calls // First response with tool calls
@@ -196,11 +181,11 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'I need to use a tool.', content: 'I need to use a tool.',
tool_calls: toolCalls tool_calls: toolCalls,
} },
} },
] ],
} },
} }
// Second response after tool execution // Second response after tool execution
@@ -210,26 +195,21 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'Here is the final answer.', content: 'Here is the final answer.',
tool_calls: null tool_calls: null,
} },
} },
] ],
} },
} }
mockPost mockPost.mockResolvedValueOnce(firstResponse).mockResolvedValueOnce(secondResponse)
.mockResolvedValueOnce(firstResponse)
.mockResolvedValueOnce(secondResponse)
mockExecuteToolCalls.mockResolvedValue(toolResults) mockExecuteToolCalls.mockResolvedValue(toolResults)
const result = await mcpInference(mockRequest, mockMcpClient) const result = await mcpInference(mockRequest, mockMcpClient)
expect(result).toBe('Here is the final answer.') expect(result).toBe('Here is the final answer.')
expect(mockExecuteToolCalls).toHaveBeenCalledWith( expect(mockExecuteToolCalls).toHaveBeenCalledWith(mockMcpClient.client, toolCalls)
mockMcpClient.client,
toolCalls
)
expect(mockPost).toHaveBeenCalledTimes(2) expect(mockPost).toHaveBeenCalledTimes(2)
// Verify the second call includes the conversation history // Verify the second call includes the conversation history
@@ -247,9 +227,9 @@ describe('inference.ts', () => {
id: 'call-123', id: 'call-123',
function: { function: {
name: 'test-tool', name: 'test-tool',
arguments: '{}' arguments: '{}',
} },
} },
] ]
const toolResults = [ const toolResults = [
@@ -257,8 +237,8 @@ describe('inference.ts', () => {
tool_call_id: 'call-123', tool_call_id: 'call-123',
role: 'tool', role: 'tool',
name: 'test-tool', name: 'test-tool',
content: 'Tool result' content: 'Tool result',
} },
] ]
// Always respond with tool calls to trigger infinite loop // Always respond with tool calls to trigger infinite loop
@@ -268,11 +248,11 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'Using tool again.', content: 'Using tool again.',
tool_calls: toolCalls tool_calls: toolCalls,
} },
} },
] ],
} },
} }
mockPost.mockResolvedValue(responseWithToolCalls) mockPost.mockResolvedValue(responseWithToolCalls)
@@ -281,9 +261,7 @@ describe('inference.ts', () => {
const result = await mcpInference(mockRequest, mockMcpClient) const result = await mcpInference(mockRequest, mockMcpClient)
expect(mockPost).toHaveBeenCalledTimes(5) // Max iterations reached expect(mockPost).toHaveBeenCalledTimes(5) // Max iterations reached
expect(core.warning).toHaveBeenCalledWith( expect(core.warning).toHaveBeenCalledWith('GitHub MCP inference loop exceeded maximum iterations (5)')
'GitHub MCP inference loop exceeded maximum iterations (5)'
)
expect(result).toBe('Using tool again.') // Last assistant message expect(result).toBe('Using tool again.') // Last assistant message
}) })
@@ -294,11 +272,11 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'Hello, user!', content: 'Hello, user!',
tool_calls: [] tool_calls: [],
} },
} },
] ],
} },
} }
mockPost.mockResolvedValue(mockResponse) mockPost.mockResolvedValue(mockResponse)
@@ -306,9 +284,7 @@ describe('inference.ts', () => {
const result = await mcpInference(mockRequest, mockMcpClient) const result = await mcpInference(mockRequest, mockMcpClient)
expect(result).toBe('Hello, user!') expect(result).toBe('Hello, user!')
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('No tool calls requested, ending GitHub MCP inference loop')
'No tool calls requested, ending GitHub MCP inference loop'
)
expect(mockExecuteToolCalls).not.toHaveBeenCalled() expect(mockExecuteToolCalls).not.toHaveBeenCalled()
}) })
@@ -316,8 +292,8 @@ describe('inference.ts', () => {
const toolCalls = [ const toolCalls = [
{ {
id: 'call-123', id: 'call-123',
function: { name: 'test-tool', arguments: '{}' } function: {name: 'test-tool', arguments: '{}'},
} },
] ]
const firstResponse = { const firstResponse = {
@@ -326,11 +302,11 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'First message', content: 'First message',
tool_calls: toolCalls tool_calls: toolCalls,
} },
} },
] ],
} },
} }
const secondResponse = { const secondResponse = {
@@ -339,24 +315,22 @@ describe('inference.ts', () => {
{ {
message: { message: {
content: 'Second message', content: 'Second message',
tool_calls: toolCalls tool_calls: toolCalls,
} },
} },
] ],
} },
} }
mockPost mockPost.mockResolvedValueOnce(firstResponse).mockResolvedValue(secondResponse)
.mockResolvedValueOnce(firstResponse)
.mockResolvedValue(secondResponse)
mockExecuteToolCalls.mockResolvedValue([ mockExecuteToolCalls.mockResolvedValue([
{ {
tool_call_id: 'call-123', tool_call_id: 'call-123',
role: 'tool', role: 'tool',
name: 'test-tool', name: 'test-tool',
content: 'result' content: 'result',
} },
]) ])
const result = await mcpInference(mockRequest, mockMcpClient) const result = await mcpInference(mockRequest, mockMcpClient)
+17 -31
View File
@@ -1,12 +1,4 @@
import { import {describe, it, expect, beforeEach, vi, type MockedFunction, type Mock} from 'vitest'
describe,
it,
expect,
beforeEach,
vi,
type MockedFunction,
type Mock
} from 'vitest'
import * as core from '../__fixtures__/core.js' import * as core from '../__fixtures__/core.js'
// Create fs mocks // Create fs mocks
@@ -26,25 +18,25 @@ const mockConnectToGitHubMCP = vi.fn()
vi.mock('fs', () => ({ vi.mock('fs', () => ({
existsSync: mockExistsSync, existsSync: mockExistsSync,
readFileSync: mockReadFileSync, readFileSync: mockReadFileSync,
writeFileSync: mockWriteFileSync writeFileSync: mockWriteFileSync,
})) }))
// Mock the inference functions // Mock the inference functions
vi.mock('../src/inference.js', () => ({ vi.mock('../src/inference.js', () => ({
simpleInference: mockSimpleInference, simpleInference: mockSimpleInference,
mcpInference: mockMcpInference mcpInference: mockMcpInference,
})) }))
// Mock the MCP connection // Mock the MCP connection
vi.mock('../src/mcp.js', () => ({ vi.mock('../src/mcp.js', () => ({
connectToGitHubMCP: mockConnectToGitHubMCP connectToGitHubMCP: mockConnectToGitHubMCP,
})) }))
vi.mock('@actions/core', () => core) vi.mock('@actions/core', () => core)
// The module being tested should be imported dynamically. This ensures that the // The module being tested should be imported dynamically. This ensures that the
// mocks are used in place of any actual dependencies. // mocks are used in place of any actual dependencies.
const { run } = await import('../src/main.js') const {run} = await import('../src/main.js')
describe('main.ts - prompt.yml integration', () => { describe('main.ts - prompt.yml integration', () => {
beforeEach(() => { beforeEach(() => {
@@ -119,29 +111,23 @@ model: openai/gpt-4o
messages: [ messages: [
{ {
role: 'system', role: 'system',
content: 'Be as concise as possible' content: 'Be as concise as possible',
}, },
{ {
role: 'user', role: 'user',
content: 'Compare cats and dogs, please' content: 'Compare cats and dogs, please',
} },
], ],
modelName: 'openai/gpt-4o', modelName: 'openai/gpt-4o',
maxTokens: 200, maxTokens: 200,
endpoint: 'https://models.github.ai/inference', endpoint: 'https://models.github.ai/inference',
token: 'test-token' token: 'test-token',
}) }),
) )
// Verify outputs were set // Verify outputs were set
expect(core.setOutput).toHaveBeenCalledWith( expect(core.setOutput).toHaveBeenCalledWith('response', 'Mocked AI response')
'response', expect(core.setOutput).toHaveBeenCalledWith('response-file', expect.any(String))
'Mocked AI response'
)
expect(core.setOutput).toHaveBeenCalledWith(
'response-file',
expect.any(String)
)
}) })
it('should fall back to legacy format when not using prompt YAML', async () => { it('should fall back to legacy format when not using prompt YAML', async () => {
@@ -173,18 +159,18 @@ model: openai/gpt-4o
messages: [ messages: [
{ {
role: 'system', role: 'system',
content: 'You are helpful' content: 'You are helpful',
}, },
{ {
role: 'user', role: 'user',
content: 'Hello, world!' content: 'Hello, world!',
} },
], ],
modelName: 'openai/gpt-4o', modelName: 'openai/gpt-4o',
maxTokens: 200, maxTokens: 200,
endpoint: 'https://models.github.ai/inference', endpoint: 'https://models.github.ai/inference',
token: 'test-token' token: 'test-token',
}) }),
) )
}) })
}) })
+33 -58
View File
@@ -1,23 +1,12 @@
import { import {vi, describe, expect, it, beforeEach, type MockedFunction} from 'vitest'
vi,
describe,
expect,
it,
beforeEach,
type MockedFunction
} from 'vitest'
import * as core from '../__fixtures__/core.js' import * as core from '../__fixtures__/core.js'
// Default to throwing errors to catch unexpected calls // Default to throwing errors to catch unexpected calls
const mockExistsSync = vi.fn().mockImplementation(() => { const mockExistsSync = vi.fn().mockImplementation(() => {
throw new Error( throw new Error('Unexpected call to existsSync - test should override this implementation')
'Unexpected call to existsSync - test should override this implementation'
)
}) })
const mockReadFileSync = vi.fn().mockImplementation(() => { const mockReadFileSync = vi.fn().mockImplementation(() => {
throw new Error( throw new Error('Unexpected call to readFileSync - test should override this implementation')
'Unexpected call to readFileSync - test should override this implementation'
)
}) })
const mockWriteFileSync = vi.fn() const mockWriteFileSync = vi.fn()
@@ -26,10 +15,7 @@ const mockWriteFileSync = vi.fn()
* @param fileContents - Object mapping file paths to their contents * @param fileContents - Object mapping file paths to their contents
* @param nonExistentFiles - Array of file paths that should be treated as non-existent * @param nonExistentFiles - Array of file paths that should be treated as non-existent
*/ */
function mockFileContent( function mockFileContent(fileContents: Record<string, string> = {}, nonExistentFiles: string[] = []): void {
fileContents: Record<string, string> = {},
nonExistentFiles: string[] = []
): void {
// Mock existsSync to return true for files that exist, false for those that don't // Mock existsSync to return true for files that exist, false for those that don't
mockExistsSync.mockImplementation((...args: unknown[]): boolean => { mockExistsSync.mockImplementation((...args: unknown[]): boolean => {
const [path] = args as [string] const [path] = args as [string]
@@ -59,11 +45,11 @@ function mockInputs(inputs: Record<string, string> = {}): void {
token: 'fake-token', token: 'fake-token',
model: 'gpt-4', model: 'gpt-4',
'max-tokens': '100', 'max-tokens': '100',
endpoint: 'https://api.test.com' endpoint: 'https://api.test.com',
} }
// Combine defaults with user-provided inputs // Combine defaults with user-provided inputs
const allInputs: Record<string, string> = { ...defaultInputs, ...inputs } const allInputs: Record<string, string> = {...defaultInputs, ...inputs}
core.getInput.mockImplementation((name: string) => { core.getInput.mockImplementation((name: string) => {
return allInputs[name] || '' return allInputs[name] || ''
@@ -80,17 +66,13 @@ function mockInputs(inputs: Record<string, string> = {}): void {
*/ */
function verifyStandardResponse(): void { function verifyStandardResponse(): void {
expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!') expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!')
expect(core.setOutput).toHaveBeenNthCalledWith( expect(core.setOutput).toHaveBeenNthCalledWith(2, 'response-file', expect.stringContaining('modelResponse.txt'))
2,
'response-file',
expect.stringContaining('modelResponse.txt')
)
} }
vi.mock('fs', () => ({ vi.mock('fs', () => ({
existsSync: mockExistsSync, existsSync: mockExistsSync,
readFileSync: mockReadFileSync, readFileSync: mockReadFileSync,
writeFileSync: mockWriteFileSync writeFileSync: mockWriteFileSync,
})) }))
// Mock MCP and inference modules // Mock MCP and inference modules
@@ -102,19 +84,19 @@ const mockSimpleInference = vi.fn() as MockedFunction<any>
const mockMcpInference = vi.fn() as MockedFunction<any> const mockMcpInference = vi.fn() as MockedFunction<any>
vi.mock('../src/mcp.js', () => ({ vi.mock('../src/mcp.js', () => ({
connectToGitHubMCP: mockConnectToGitHubMCP connectToGitHubMCP: mockConnectToGitHubMCP,
})) }))
vi.mock('../src/inference.js', () => ({ vi.mock('../src/inference.js', () => ({
simpleInference: mockSimpleInference, simpleInference: mockSimpleInference,
mcpInference: mockMcpInference mcpInference: mockMcpInference,
})) }))
vi.mock('@actions/core', () => core) vi.mock('@actions/core', () => core)
// The module being tested should be imported dynamically. This ensures that the // The module being tested should be imported dynamically. This ensures that the
// mocks are used in place of any actual dependencies. // mocks are used in place of any actual dependencies.
const { run } = await import('../src/main.js') const {run} = await import('../src/main.js')
describe('main.ts', () => { describe('main.ts', () => {
// Reset all mocks before each test // Reset all mocks before each test
@@ -132,7 +114,7 @@ describe('main.ts', () => {
it('Sets the response output', async () => { it('Sets the response output', async () => {
mockInputs({ mockInputs({
prompt: 'Hello, AI!', prompt: 'Hello, AI!',
'system-prompt': 'You are a test assistant.' 'system-prompt': 'You are a test assistant.',
}) })
await run() await run()
@@ -144,36 +126,33 @@ describe('main.ts', () => {
it('Sets a failed status when no prompt is set', async () => { it('Sets a failed status when no prompt is set', async () => {
mockInputs({ mockInputs({
prompt: '', prompt: '',
'prompt-file': '' 'prompt-file': '',
}) })
await run() await run()
expect(core.setFailed).toHaveBeenNthCalledWith( expect(core.setFailed).toHaveBeenNthCalledWith(1, 'Neither prompt-file nor prompt was set')
1,
'Neither prompt-file nor prompt was set'
)
}) })
it('uses simple inference when MCP is disabled', async () => { it('uses simple inference when MCP is disabled', async () => {
mockInputs({ mockInputs({
prompt: 'Hello, AI!', prompt: 'Hello, AI!',
'system-prompt': 'You are a test assistant.', 'system-prompt': 'You are a test assistant.',
'enable-github-mcp': 'false' 'enable-github-mcp': 'false',
}) })
await run() await run()
expect(mockSimpleInference).toHaveBeenCalledWith({ expect(mockSimpleInference).toHaveBeenCalledWith({
messages: [ messages: [
{ role: 'system', content: 'You are a test assistant.' }, {role: 'system', content: 'You are a test assistant.'},
{ role: 'user', content: 'Hello, AI!' } {role: 'user', content: 'Hello, AI!'},
], ],
modelName: 'gpt-4', modelName: 'gpt-4',
maxTokens: 100, maxTokens: 100,
endpoint: 'https://api.test.com', endpoint: 'https://api.test.com',
token: 'fake-token', token: 'fake-token',
responseFormat: undefined responseFormat: undefined,
}) })
expect(mockConnectToGitHubMCP).not.toHaveBeenCalled() expect(mockConnectToGitHubMCP).not.toHaveBeenCalled()
expect(mockMcpInference).not.toHaveBeenCalled() expect(mockMcpInference).not.toHaveBeenCalled()
@@ -184,13 +163,13 @@ describe('main.ts', () => {
const mockMcpClient = { const mockMcpClient = {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
client: {} as any, client: {} as any,
tools: [{ type: 'function', function: { name: 'test-tool' } }] tools: [{type: 'function', function: {name: 'test-tool'}}],
} }
mockInputs({ mockInputs({
prompt: 'Hello, AI!', prompt: 'Hello, AI!',
'system-prompt': 'You are a test assistant.', 'system-prompt': 'You are a test assistant.',
'enable-github-mcp': 'true' 'enable-github-mcp': 'true',
}) })
mockConnectToGitHubMCP.mockResolvedValue(mockMcpClient) mockConnectToGitHubMCP.mockResolvedValue(mockMcpClient)
@@ -201,12 +180,12 @@ describe('main.ts', () => {
expect(mockMcpInference).toHaveBeenCalledWith( expect(mockMcpInference).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
messages: [ messages: [
{ role: 'system', content: 'You are a test assistant.' }, {role: 'system', content: 'You are a test assistant.'},
{ role: 'user', content: 'Hello, AI!' } {role: 'user', content: 'Hello, AI!'},
], ],
token: 'fake-token' token: 'fake-token',
}), }),
mockMcpClient mockMcpClient,
) )
expect(mockSimpleInference).not.toHaveBeenCalled() expect(mockSimpleInference).not.toHaveBeenCalled()
verifyStandardResponse() verifyStandardResponse()
@@ -216,7 +195,7 @@ describe('main.ts', () => {
mockInputs({ mockInputs({
prompt: 'Hello, AI!', prompt: 'Hello, AI!',
'system-prompt': 'You are a test assistant.', 'system-prompt': 'You are a test assistant.',
'enable-github-mcp': 'true' 'enable-github-mcp': 'true',
}) })
mockConnectToGitHubMCP.mockResolvedValue(null) mockConnectToGitHubMCP.mockResolvedValue(null)
@@ -226,9 +205,7 @@ describe('main.ts', () => {
expect(mockConnectToGitHubMCP).toHaveBeenCalledWith('fake-token') expect(mockConnectToGitHubMCP).toHaveBeenCalledWith('fake-token')
expect(mockSimpleInference).toHaveBeenCalled() expect(mockSimpleInference).toHaveBeenCalled()
expect(mockMcpInference).not.toHaveBeenCalled() expect(mockMcpInference).not.toHaveBeenCalled()
expect(core.warning).toHaveBeenCalledWith( expect(core.warning).toHaveBeenCalledWith('MCP connection failed, falling back to simple inference')
'MCP connection failed, falling back to simple inference'
)
verifyStandardResponse() verifyStandardResponse()
}) })
@@ -240,27 +217,27 @@ describe('main.ts', () => {
mockFileContent({ mockFileContent({
[promptFile]: promptContent, [promptFile]: promptContent,
[systemPromptFile]: systemPromptContent [systemPromptFile]: systemPromptContent,
}) })
mockInputs({ mockInputs({
'prompt-file': promptFile, 'prompt-file': promptFile,
'system-prompt-file': systemPromptFile, 'system-prompt-file': systemPromptFile,
'enable-github-mcp': 'false' 'enable-github-mcp': 'false',
}) })
await run() await run()
expect(mockSimpleInference).toHaveBeenCalledWith({ expect(mockSimpleInference).toHaveBeenCalledWith({
messages: [ messages: [
{ role: 'system', content: systemPromptContent }, {role: 'system', content: systemPromptContent},
{ role: 'user', content: promptContent } {role: 'user', content: promptContent},
], ],
modelName: 'gpt-4', modelName: 'gpt-4',
maxTokens: 100, maxTokens: 100,
endpoint: 'https://api.test.com', endpoint: 'https://api.test.com',
token: 'fake-token', token: 'fake-token',
responseFormat: undefined responseFormat: undefined,
}) })
verifyStandardResponse() verifyStandardResponse()
}) })
@@ -271,13 +248,11 @@ describe('main.ts', () => {
mockFileContent({}, [promptFile]) mockFileContent({}, [promptFile])
mockInputs({ mockInputs({
'prompt-file': promptFile 'prompt-file': promptFile,
}) })
await run() await run()
expect(core.setFailed).toHaveBeenCalledWith( expect(core.setFailed).toHaveBeenCalledWith(`File for prompt-file was not found: ${promptFile}`)
`File for prompt-file was not found: ${promptFile}`
)
}) })
}) })
+44 -75
View File
@@ -1,11 +1,4 @@
import { import {vi, type MockedFunction, describe, it, expect, beforeEach} from 'vitest'
vi,
type MockedFunction,
describe,
it,
expect,
beforeEach
} from 'vitest'
import * as core from '../__fixtures__/core.js' import * as core from '../__fixtures__/core.js'
// Mock MCP SDK // Mock MCP SDK
@@ -19,24 +12,22 @@ const mockCallTool = vi.fn() as MockedFunction<any>
const mockClient = { const mockClient = {
connect: mockConnect, connect: mockConnect,
listTools: mockListTools, listTools: mockListTools,
callTool: mockCallTool callTool: mockCallTool,
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any } as any
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
Client: vi.fn(() => mockClient) Client: vi.fn(() => mockClient),
})) }))
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
StreamableHTTPClientTransport: vi.fn() StreamableHTTPClientTransport: vi.fn(),
})) }))
vi.mock('@actions/core', () => core) vi.mock('@actions/core', () => core)
// Import the module being tested // Import the module being tested
const { connectToGitHubMCP, executeToolCall, executeToolCalls } = await import( const {connectToGitHubMCP, executeToolCall, executeToolCalls} = await import('../src/mcp.js')
'../src/mcp.js'
)
describe('mcp.ts', () => { describe('mcp.ts', () => {
beforeEach(() => { beforeEach(() => {
@@ -50,20 +41,20 @@ describe('mcp.ts', () => {
{ {
name: 'test-tool-1', name: 'test-tool-1',
description: 'Test tool 1', description: 'Test tool 1',
inputSchema: { type: 'object', properties: {} } inputSchema: {type: 'object', properties: {}},
}, },
{ {
name: 'test-tool-2', name: 'test-tool-2',
description: 'Test tool 2', description: 'Test tool 2',
inputSchema: { inputSchema: {
type: 'object', type: 'object',
properties: { param: { type: 'string' } } properties: {param: {type: 'string'}},
} },
} },
] ]
mockConnect.mockResolvedValue(undefined) mockConnect.mockResolvedValue(undefined)
mockListTools.mockResolvedValue({ tools: mockTools }) mockListTools.mockResolvedValue({tools: mockTools})
const result = await connectToGitHubMCP(token) const result = await connectToGitHubMCP(token)
@@ -75,21 +66,13 @@ describe('mcp.ts', () => {
function: { function: {
name: 'test-tool-1', name: 'test-tool-1',
description: 'Test tool 1', description: 'Test tool 1',
parameters: { type: 'object', properties: {} } parameters: {type: 'object', properties: {}},
} },
}) })
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Connecting to GitHub MCP server...')
'Connecting to GitHub MCP server...' expect(core.info).toHaveBeenCalledWith('Successfully connected to GitHub MCP server')
) expect(core.info).toHaveBeenCalledWith('Retrieved 2 tools from GitHub MCP server')
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Mapped 2 GitHub MCP tools for Azure AI Inference')
'Successfully connected to GitHub MCP server'
)
expect(core.info).toHaveBeenCalledWith(
'Retrieved 2 tools from GitHub MCP server'
)
expect(core.info).toHaveBeenCalledWith(
'Mapped 2 GitHub MCP tools for Azure AI Inference'
)
}) })
it('returns null when connection fails', async () => { it('returns null when connection fails', async () => {
@@ -101,27 +84,21 @@ describe('mcp.ts', () => {
const result = await connectToGitHubMCP(token) const result = await connectToGitHubMCP(token)
expect(result).toBeNull() expect(result).toBeNull()
expect(core.warning).toHaveBeenCalledWith( expect(core.warning).toHaveBeenCalledWith('Failed to connect to GitHub MCP server: Error: Connection failed')
'Failed to connect to GitHub MCP server: Error: Connection failed'
)
}) })
it('handles empty tools list', async () => { it('handles empty tools list', async () => {
const token = 'test-token' const token = 'test-token'
mockConnect.mockResolvedValue(undefined) mockConnect.mockResolvedValue(undefined)
mockListTools.mockResolvedValue({ tools: [] }) mockListTools.mockResolvedValue({tools: []})
const result = await connectToGitHubMCP(token) const result = await connectToGitHubMCP(token)
expect(result).not.toBeNull() expect(result).not.toBeNull()
expect(result?.tools).toHaveLength(0) expect(result?.tools).toHaveLength(0)
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Retrieved 0 tools from GitHub MCP server')
'Retrieved 0 tools from GitHub MCP server' expect(core.info).toHaveBeenCalledWith('Mapped 0 GitHub MCP tools for Azure AI Inference')
)
expect(core.info).toHaveBeenCalledWith(
'Mapped 0 GitHub MCP tools for Azure AI Inference'
)
}) })
it('handles undefined tools list', async () => { it('handles undefined tools list', async () => {
@@ -134,9 +111,7 @@ describe('mcp.ts', () => {
expect(result).not.toBeNull() expect(result).not.toBeNull()
expect(result?.tools).toHaveLength(0) expect(result?.tools).toHaveLength(0)
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Retrieved 0 tools from GitHub MCP server')
'Retrieved 0 tools from GitHub MCP server'
)
}) })
}) })
@@ -147,11 +122,11 @@ describe('mcp.ts', () => {
type: 'function', type: 'function',
function: { function: {
name: 'test-tool', name: 'test-tool',
arguments: '{"param": "value"}' arguments: '{"param": "value"}',
} },
} }
const toolResult = { const toolResult = {
content: [{ type: 'text', text: 'Tool execution result' }] content: [{type: 'text', text: 'Tool execution result'}],
} }
mockCallTool.mockResolvedValue(toolResult) mockCallTool.mockResolvedValue(toolResult)
@@ -160,20 +135,16 @@ describe('mcp.ts', () => {
expect(mockCallTool).toHaveBeenCalledWith({ expect(mockCallTool).toHaveBeenCalledWith({
name: 'test-tool', name: 'test-tool',
arguments: { param: 'value' } arguments: {param: 'value'},
}) })
expect(result).toEqual({ expect(result).toEqual({
tool_call_id: 'call-123', tool_call_id: 'call-123',
role: 'tool', role: 'tool',
name: 'test-tool', name: 'test-tool',
content: JSON.stringify(toolResult.content) content: JSON.stringify(toolResult.content),
}) })
expect(core.info).toHaveBeenCalledWith( expect(core.info).toHaveBeenCalledWith('Executing GitHub MCP tool: test-tool with args: {"param": "value"}')
'Executing GitHub MCP tool: test-tool with args: {"param": "value"}' expect(core.info).toHaveBeenCalledWith('GitHub MCP tool test-tool executed successfully')
)
expect(core.info).toHaveBeenCalledWith(
'GitHub MCP tool test-tool executed successfully'
)
}) })
it('handles tool execution errors gracefully', async () => { it('handles tool execution errors gracefully', async () => {
@@ -182,8 +153,8 @@ describe('mcp.ts', () => {
type: 'function', type: 'function',
function: { function: {
name: 'failing-tool', name: 'failing-tool',
arguments: '{"param": "value"}' arguments: '{"param": "value"}',
} },
} }
const toolError = new Error('Tool execution failed') const toolError = new Error('Tool execution failed')
@@ -195,10 +166,10 @@ describe('mcp.ts', () => {
tool_call_id: 'call-456', tool_call_id: 'call-456',
role: 'tool', role: 'tool',
name: 'failing-tool', name: 'failing-tool',
content: 'Error: Error: Tool execution failed' content: 'Error: Error: Tool execution failed',
}) })
expect(core.warning).toHaveBeenCalledWith( expect(core.warning).toHaveBeenCalledWith(
'Failed to execute GitHub MCP tool failing-tool: Error: Tool execution failed' 'Failed to execute GitHub MCP tool failing-tool: Error: Tool execution failed',
) )
}) })
@@ -208,8 +179,8 @@ describe('mcp.ts', () => {
type: 'function', type: 'function',
function: { function: {
name: 'test-tool', name: 'test-tool',
arguments: 'invalid-json' arguments: 'invalid-json',
} },
} }
const result = await executeToolCall(mockClient, toolCall) const result = await executeToolCall(mockClient, toolCall)
@@ -218,9 +189,7 @@ describe('mcp.ts', () => {
expect(result.role).toBe('tool') expect(result.role).toBe('tool')
expect(result.name).toBe('test-tool') expect(result.name).toBe('test-tool')
expect(result.content).toContain('Error:') expect(result.content).toContain('Error:')
expect(core.warning).toHaveBeenCalledWith( expect(core.warning).toHaveBeenCalledWith(expect.stringContaining('Failed to execute GitHub MCP tool test-tool:'))
expect.stringContaining('Failed to execute GitHub MCP tool test-tool:')
)
}) })
}) })
@@ -230,21 +199,21 @@ describe('mcp.ts', () => {
{ {
id: 'call-1', id: 'call-1',
type: 'function', type: 'function',
function: { name: 'tool-1', arguments: '{}' } function: {name: 'tool-1', arguments: '{}'},
}, },
{ {
id: 'call-2', id: 'call-2',
type: 'function', type: 'function',
function: { name: 'tool-2', arguments: '{"param": "value"}' } function: {name: 'tool-2', arguments: '{"param": "value"}'},
} },
] ]
mockCallTool mockCallTool
.mockResolvedValueOnce({ .mockResolvedValueOnce({
content: [{ type: 'text', text: 'Result 1' }] content: [{type: 'text', text: 'Result 1'}],
}) })
.mockResolvedValueOnce({ .mockResolvedValueOnce({
content: [{ type: 'text', text: 'Result 2' }] content: [{type: 'text', text: 'Result 2'}],
}) })
const results = await executeToolCalls(mockClient, toolCalls) const results = await executeToolCalls(mockClient, toolCalls)
@@ -267,18 +236,18 @@ describe('mcp.ts', () => {
{ {
id: 'call-1', id: 'call-1',
type: 'function', type: 'function',
function: { name: 'tool-1', arguments: '{}' } function: {name: 'tool-1', arguments: '{}'},
}, },
{ {
id: 'call-2', id: 'call-2',
type: 'function', type: 'function',
function: { name: 'tool-2', arguments: '{}' } function: {name: 'tool-2', arguments: '{}'},
} },
] ]
mockCallTool mockCallTool
.mockResolvedValueOnce({ .mockResolvedValueOnce({
content: [{ type: 'text', text: 'Result 1' }] content: [{type: 'text', text: 'Result 1'}],
}) })
.mockRejectedValueOnce(new Error('Tool 2 failed')) .mockRejectedValueOnce(new Error('Tool 2 failed'))
+13 -26
View File
@@ -1,12 +1,7 @@
import { describe, it, expect } from 'vitest' import {describe, it, expect} from 'vitest'
import * as path from 'path' import * as path from 'path'
import { fileURLToPath } from 'url' import {fileURLToPath} from 'url'
import { import {parseTemplateVariables, replaceTemplateVariables, loadPromptFile, isPromptYamlFile} from '../src/prompt'
parseTemplateVariables,
replaceTemplateVariables,
loadPromptFile,
isPromptYamlFile
} from '../src/prompt'
const __filename = fileURLToPath(import.meta.url) const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename) const __dirname = path.dirname(__filename)
@@ -19,7 +14,7 @@ a: hello
b: world b: world
` `
const result = parseTemplateVariables(input) const result = parseTemplateVariables(input)
expect(result).toEqual({ a: 'hello', b: 'world' }) expect(result).toEqual({a: 'hello', b: 'world'})
}) })
it('should parse multiline variables', () => { it('should parse multiline variables', () => {
@@ -49,14 +44,14 @@ var2: |
describe('replaceTemplateVariables', () => { describe('replaceTemplateVariables', () => {
it('should replace simple variables', () => { it('should replace simple variables', () => {
const text = 'Hello {{name}}, welcome to {{place}}!' const text = 'Hello {{name}}, welcome to {{place}}!'
const variables = { name: 'John', place: 'GitHub' } const variables = {name: 'John', place: 'GitHub'}
const result = replaceTemplateVariables(text, variables) const result = replaceTemplateVariables(text, variables)
expect(result).toBe('Hello John, welcome to GitHub!') expect(result).toBe('Hello John, welcome to GitHub!')
}) })
it('should leave unreplaced variables as is', () => { it('should leave unreplaced variables as is', () => {
const text = 'Hello {{name}}, welcome to {{unknown}}!' const text = 'Hello {{name}}, welcome to {{unknown}}!'
const variables = { name: 'John' } const variables = {name: 'John'}
const result = replaceTemplateVariables(text, variables) const result = replaceTemplateVariables(text, variables)
expect(result).toBe('Hello John, welcome to {{unknown}}!') expect(result).toBe('Hello John, welcome to {{unknown}}!')
}) })
@@ -90,31 +85,25 @@ var2: |
describe('loadPromptFile', () => { describe('loadPromptFile', () => {
it('should load simple prompt file', () => { it('should load simple prompt file', () => {
const filePath = path.join( const filePath = path.join(__dirname, '../__fixtures__/prompts/simple.prompt.yml')
__dirname, const variables = {a: 'cats', b: 'dogs'}
'../__fixtures__/prompts/simple.prompt.yml'
)
const variables = { a: 'cats', b: 'dogs' }
const result = loadPromptFile(filePath, variables) const result = loadPromptFile(filePath, variables)
expect(result.messages).toHaveLength(2) expect(result.messages).toHaveLength(2)
expect(result.messages[0]).toEqual({ expect(result.messages[0]).toEqual({
role: 'system', role: 'system',
content: 'Be as concise as possible' content: 'Be as concise as possible',
}) })
expect(result.messages[1]).toEqual({ expect(result.messages[1]).toEqual({
role: 'user', role: 'user',
content: 'Compare cats and dogs, please' content: 'Compare cats and dogs, please',
}) })
expect(result.model).toBe('openai/gpt-4o') expect(result.model).toBe('openai/gpt-4o')
}) })
it('should load JSON schema prompt file', () => { it('should load JSON schema prompt file', () => {
const filePath = path.join( const filePath = path.join(__dirname, '../__fixtures__/prompts/json-schema.prompt.yml')
__dirname, const variables = {animal: 'dog'}
'../__fixtures__/prompts/json-schema.prompt.yml'
)
const variables = { animal: 'dog' }
const result = loadPromptFile(filePath, variables) const result = loadPromptFile(filePath, variables)
expect(result.messages).toHaveLength(2) expect(result.messages).toHaveLength(2)
@@ -125,9 +114,7 @@ var2: |
}) })
it('should throw error for non-existent file', () => { it('should throw error for non-existent file', () => {
expect(() => loadPromptFile('non-existent.prompt.yml')).toThrow( expect(() => loadPromptFile('non-existent.prompt.yml')).toThrow('Prompt file not found')
'Prompt file not found'
)
}) })
}) })
}) })
+1 -2
View File
@@ -14,8 +14,7 @@ inputs:
required: false required: false
default: '' default: ''
prompt-file: prompt-file:
description: description: Path to a file containing the prompt (supports .txt and .prompt.yml
Path to a file containing the prompt (supports .txt and .prompt.yml
formats) formats)
required: false required: false
default: '' default: ''
+15 -15
View File
@@ -1,43 +1,43 @@
// See: https://eslint.org/docs/latest/use/configure/configuration-files // See: https://eslint.org/docs/latest/use/configure/configuration-files
import { FlatCompat } from '@eslint/eslintrc' import {FlatCompat} from '@eslint/eslintrc'
import js from '@eslint/js' import js from '@eslint/js'
import typescriptEslint from '@typescript-eslint/eslint-plugin' import typescriptEslint from '@typescript-eslint/eslint-plugin'
import tsParser from '@typescript-eslint/parser' import tsParser from '@typescript-eslint/parser'
import prettier from 'eslint-plugin-prettier' import prettier from 'eslint-plugin-prettier'
import globals from 'globals' import globals from 'globals'
import path from 'node:path' import path from 'node:path'
import { fileURLToPath } from 'node:url' import {fileURLToPath} from 'node:url'
const __filename = fileURLToPath(import.meta.url) const __filename = fileURLToPath(import.meta.url)
const __dirname = path.dirname(__filename) const __dirname = path.dirname(__filename)
const compat = new FlatCompat({ const compat = new FlatCompat({
baseDirectory: __dirname, baseDirectory: __dirname,
recommendedConfig: js.configs.recommended, recommendedConfig: js.configs.recommended,
allConfig: js.configs.all allConfig: js.configs.all,
}) })
export default [ export default [
{ {
ignores: ['**/coverage', '**/dist', '**/linter', '**/node_modules'] ignores: ['**/coverage', '**/dist', '**/linter', '**/node_modules'],
}, },
...compat.extends( ...compat.extends(
'eslint:recommended', 'eslint:recommended',
'plugin:@typescript-eslint/eslint-recommended', 'plugin:@typescript-eslint/eslint-recommended',
'plugin:@typescript-eslint/recommended', 'plugin:@typescript-eslint/recommended',
'plugin:prettier/recommended' 'plugin:prettier/recommended',
), ),
{ {
plugins: { plugins: {
prettier, prettier,
'@typescript-eslint': typescriptEslint '@typescript-eslint': typescriptEslint,
}, },
languageOptions: { languageOptions: {
globals: { globals: {
...globals.node, ...globals.node,
Atomics: 'readonly', Atomics: 'readonly',
SharedArrayBuffer: 'readonly' SharedArrayBuffer: 'readonly',
}, },
parser: tsParser, parser: tsParser,
@@ -46,17 +46,17 @@ export default [
parserOptions: { parserOptions: {
project: ['tsconfig.eslint.json'], project: ['tsconfig.eslint.json'],
tsconfigRootDir: '.' tsconfigRootDir: '.',
} },
}, },
settings: { settings: {
'import/resolver': { 'import/resolver': {
typescript: { typescript: {
alwaysTryTypes: true, alwaysTryTypes: true,
project: 'tsconfig.eslint.json' project: 'tsconfig.eslint.json',
} },
} },
}, },
rules: { rules: {
@@ -68,7 +68,7 @@ export default [
'no-console': 'off', 'no-console': 'off',
'no-shadow': 'off', 'no-shadow': 'off',
'no-unused-vars': 'off', 'no-unused-vars': 'off',
'prettier/prettier': 'error' 'prettier/prettier': 'error',
} },
} },
] ]
+8
View File
@@ -20,6 +20,7 @@
"@azure/core-sse": "latest", "@azure/core-sse": "latest",
"@eslint/compat": "^1.3.0", "@eslint/compat": "^1.3.0",
"@github/local-action": "^5.1.0", "@github/local-action": "^5.1.0",
"@github/prettier-config": "^0.0.6",
"@rollup/plugin-commonjs": "^28.0.5", "@rollup/plugin-commonjs": "^28.0.5",
"@rollup/plugin-json": "^6.1.0", "@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^16.0.1", "@rollup/plugin-node-resolve": "^16.0.1",
@@ -1481,6 +1482,13 @@
"node": ">=20.18.1" "node": ">=20.18.1"
} }
}, },
"node_modules/@github/prettier-config": {
"version": "0.0.6",
"resolved": "https://registry.npmjs.org/@github/prettier-config/-/prettier-config-0.0.6.tgz",
"integrity": "sha512-Sdb089z+QbGnFF2NivbDeaJ62ooPlD31wE6Fkb/ESjAOXSjNJo+gjqzYYhlM7G3ERJmKFZRUJYMlsqB7Tym8lQ==",
"dev": true,
"license": "MIT"
},
"node_modules/@humanfs/core": { "node_modules/@humanfs/core": {
"version": "0.19.1", "version": "0.19.1",
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
+5 -3
View File
@@ -21,6 +21,7 @@
"all": "npm run format:write && npm run lint && npm run test && npm run package" "all": "npm run format:write && npm run lint && npm run test && npm run package"
}, },
"license": "MIT", "license": "MIT",
"prettier": "@github/prettier-config",
"dependencies": { "dependencies": {
"@actions/core": "^1.11.1", "@actions/core": "^1.11.1",
"@modelcontextprotocol/sdk": "^1.15.1", "@modelcontextprotocol/sdk": "^1.15.1",
@@ -33,9 +34,12 @@
"@azure/core-sse": "latest", "@azure/core-sse": "latest",
"@eslint/compat": "^1.3.0", "@eslint/compat": "^1.3.0",
"@github/local-action": "^5.1.0", "@github/local-action": "^5.1.0",
"@github/prettier-config": "^0.0.6",
"@rollup/plugin-commonjs": "^28.0.5", "@rollup/plugin-commonjs": "^28.0.5",
"@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^16.0.1", "@rollup/plugin-node-resolve": "^16.0.1",
"@rollup/plugin-typescript": "^12.1.2", "@rollup/plugin-typescript": "^12.1.2",
"@types/js-yaml": "^4.0.9",
"@types/node": "^22.15.31", "@types/node": "^22.15.31",
"@typescript-eslint/eslint-plugin": "^8.34.0", "@typescript-eslint/eslint-plugin": "^8.34.0",
"@typescript-eslint/parser": "^8.32.1", "@typescript-eslint/parser": "^8.32.1",
@@ -48,9 +52,7 @@
"prettier-eslint": "^16.4.2", "prettier-eslint": "^16.4.2",
"rollup": "^4.43.0", "rollup": "^4.43.0",
"typescript": "^5.8.3", "typescript": "^5.8.3",
"vitest": "^3", "vitest": "^3"
"@rollup/plugin-json": "^6.1.0",
"@types/js-yaml": "^4.0.9"
}, },
"optionalDependencies": { "optionalDependencies": {
"@rollup/rollup-linux-x64-gnu": "*" "@rollup/rollup-linux-x64-gnu": "*"
+6 -6
View File
@@ -1,5 +1,5 @@
// See: https://rollupjs.org/introduction/ // See: https://rollupjs.org/introduction/
import { builtinModules } from 'node:module' import {builtinModules} from 'node:module'
import commonjs from '@rollup/plugin-commonjs' import commonjs from '@rollup/plugin-commonjs'
import nodeResolve from '@rollup/plugin-node-resolve' import nodeResolve from '@rollup/plugin-node-resolve'
import typescript from '@rollup/plugin-typescript' import typescript from '@rollup/plugin-typescript'
@@ -11,7 +11,7 @@ const config = {
esModule: true, esModule: true,
file: 'dist/index.js', file: 'dist/index.js',
format: 'es', format: 'es',
sourcemap: true sourcemap: true,
}, },
external: [...builtinModules, /^node:/], external: [...builtinModules, /^node:/],
plugins: [ plugins: [
@@ -19,13 +19,13 @@ const config = {
nodeResolve({ nodeResolve({
preferBuiltins: true, preferBuiltins: true,
browser: false, browser: false,
exportConditions: ['node'] exportConditions: ['node'],
}), }),
commonjs({ commonjs({
include: /node_modules/ include: /node_modules/,
}), }),
json() json(),
] ],
} }
export default config export default config
+20 -33
View File
@@ -1,8 +1,8 @@
import * as core from '@actions/core' import * as core from '@actions/core'
import { GetChatCompletionsDefaultResponse } from '@azure-rest/ai-inference' import {GetChatCompletionsDefaultResponse} from '@azure-rest/ai-inference'
import * as fs from 'fs' import * as fs from 'fs'
import { PromptConfig } from './prompt.js' import {PromptConfig} from './prompt.js'
import { InferenceRequest } from './inference.js' import {InferenceRequest} from './inference.js'
/** /**
* Helper function to load content from a file or use fallback input * Helper function to load content from a file or use fallback input
@@ -11,11 +11,7 @@ import { InferenceRequest } from './inference.js'
* @param defaultValue - Default value to use if neither file nor content is provided * @param defaultValue - Default value to use if neither file nor content is provided
* @returns The loaded content * @returns The loaded content
*/ */
export function loadContentFromFileOrInput( export function loadContentFromFileOrInput(filePathInput: string, contentInput: string, defaultValue?: string): string {
filePathInput: string,
contentInput: string,
defaultValue?: string
): string {
const filePath = core.getInput(filePathInput) const filePath = core.getInput(filePathInput)
const contentString = core.getInput(contentInput) const contentString = core.getInput(contentInput)
@@ -38,9 +34,7 @@ export function loadContentFromFileOrInput(
* @param response - The response object from the AI service * @param response - The response object from the AI service
* @throws Error with appropriate error message based on response content * @throws Error with appropriate error message based on response content
*/ */
export function handleUnexpectedResponse( export function handleUnexpectedResponse(response: GetChatCompletionsDefaultResponse): never {
response: GetChatCompletionsDefaultResponse
): never {
// Extract x-ms-error-code from headers if available // Extract x-ms-error-code from headers if available
const errorCode = response.headers['x-ms-error-code'] const errorCode = response.headers['x-ms-error-code']
const errorCodeMsg = errorCode ? ` (error code: ${errorCode})` : '' const errorCodeMsg = errorCode ? ` (error code: ${errorCode})` : ''
@@ -54,16 +48,14 @@ export function handleUnexpectedResponse(
if (!response.body) { if (!response.body) {
throw new Error( throw new Error(
`Failed to get response from AI service (status: ${response.status})${errorCodeMsg}. ` + `Failed to get response from AI service (status: ${response.status})${errorCodeMsg}. ` +
'Please check network connection and endpoint configuration.' 'Please check network connection and endpoint configuration.',
) )
} }
// Handle other error cases // Handle other error cases
throw new Error( throw new Error(
`AI service returned error response (status: ${response.status})${errorCodeMsg}: ` + `AI service returned error response (status: ${response.status})${errorCodeMsg}: ` +
(typeof response.body === 'string' (typeof response.body === 'string' ? response.body : JSON.stringify(response.body)),
? response.body
: JSON.stringify(response.body))
) )
} }
@@ -73,22 +65,22 @@ export function handleUnexpectedResponse(
export function buildMessages( export function buildMessages(
promptConfig?: PromptConfig, promptConfig?: PromptConfig,
systemPrompt?: string, systemPrompt?: string,
prompt?: string prompt?: string,
): Array<{ role: string; content: string }> { ): Array<{role: string; content: string}> {
if (promptConfig?.messages && promptConfig.messages.length > 0) { if (promptConfig?.messages && promptConfig.messages.length > 0) {
// Use new message format // Use new message format
return promptConfig.messages.map((msg) => ({ return promptConfig.messages.map(msg => ({
role: msg.role, role: msg.role,
content: msg.content content: msg.content,
})) }))
} else { } else {
// Use legacy format // Use legacy format
return [ return [
{ {
role: 'system', role: 'system',
content: systemPrompt || 'You are a helpful assistant' content: systemPrompt || 'You are a helpful assistant',
}, },
{ role: 'user', content: prompt || '' } {role: 'user', content: prompt || ''},
] ]
} }
} }
@@ -97,22 +89,17 @@ export function buildMessages(
* Build response format object for API from prompt config * Build response format object for API from prompt config
*/ */
export function buildResponseFormat( export function buildResponseFormat(
promptConfig?: PromptConfig promptConfig?: PromptConfig,
): { type: 'json_schema'; json_schema: unknown } | undefined { ): {type: 'json_schema'; json_schema: unknown} | undefined {
if ( if (promptConfig?.responseFormat === 'json_schema' && promptConfig.jsonSchema) {
promptConfig?.responseFormat === 'json_schema' &&
promptConfig.jsonSchema
) {
try { try {
const schema = JSON.parse(promptConfig.jsonSchema) const schema = JSON.parse(promptConfig.jsonSchema)
return { return {
type: 'json_schema', type: 'json_schema',
json_schema: schema json_schema: schema,
} }
} catch (error) { } catch (error) {
throw new Error( throw new Error(`Invalid JSON schema: ${error instanceof Error ? error.message : 'Unknown error'}`)
`Invalid JSON schema: ${error instanceof Error ? error.message : 'Unknown error'}`
)
} }
} }
return undefined return undefined
@@ -128,7 +115,7 @@ export function buildInferenceRequest(
modelName: string, modelName: string,
maxTokens: number, maxTokens: number,
endpoint: string, endpoint: string,
token: string token: string,
): InferenceRequest { ): InferenceRequest {
const messages = buildMessages(promptConfig, systemPrompt, prompt) const messages = buildMessages(promptConfig, systemPrompt, prompt)
const responseFormat = buildResponseFormat(promptConfig) const responseFormat = buildResponseFormat(promptConfig)
@@ -139,6 +126,6 @@ export function buildInferenceRequest(
maxTokens, maxTokens,
endpoint, endpoint,
token, token,
responseFormat responseFormat,
} }
} }
+1 -1
View File
@@ -2,7 +2,7 @@
* The entrypoint for the action. This file simply imports and runs the action's * The entrypoint for the action. This file simply imports and runs the action's
* main logic. * main logic.
*/ */
import { run } from './main.js' import {run} from './main.js'
/* istanbul ignore next */ /* istanbul ignore next */
run() run()
+23 -38
View File
@@ -1,8 +1,8 @@
import * as core from '@actions/core' import * as core from '@actions/core'
import ModelClient, { isUnexpected } from '@azure-rest/ai-inference' import ModelClient, {isUnexpected} from '@azure-rest/ai-inference'
import { AzureKeyCredential } from '@azure/core-auth' import {AzureKeyCredential} from '@azure/core-auth'
import { GitHubMCPClient, executeToolCalls, MCPTool, ToolCall } from './mcp.js' import {GitHubMCPClient, executeToolCalls, MCPTool, ToolCall} from './mcp.js'
import { handleUnexpectedResponse } from './helpers.js' import {handleUnexpectedResponse} from './helpers.js'
interface ChatMessage { interface ChatMessage {
role: string role: string
@@ -14,17 +14,17 @@ interface ChatCompletionsRequestBody {
messages: ChatMessage[] messages: ChatMessage[]
max_tokens: number max_tokens: number
model: string model: string
response_format?: { type: 'json_schema'; json_schema: unknown } response_format?: {type: 'json_schema'; json_schema: unknown}
tools?: MCPTool[] tools?: MCPTool[]
} }
export interface InferenceRequest { export interface InferenceRequest {
messages: Array<{ role: string; content: string }> messages: Array<{role: string; content: string}>
modelName: string modelName: string
maxTokens: number maxTokens: number
endpoint: string endpoint: string
token: string token: string
responseFormat?: { type: 'json_schema'; json_schema: unknown } // Processed response format for the API responseFormat?: {type: 'json_schema'; json_schema: unknown} // Processed response format for the API
} }
export interface InferenceResponse { export interface InferenceResponse {
@@ -42,23 +42,17 @@ export interface InferenceResponse {
/** /**
* Simple one-shot inference without tools * Simple one-shot inference without tools
*/ */
export async function simpleInference( export async function simpleInference(request: InferenceRequest): Promise<string | null> {
request: InferenceRequest
): Promise<string | null> {
core.info('Running simple inference without tools') core.info('Running simple inference without tools')
const client = ModelClient( const client = ModelClient(request.endpoint, new AzureKeyCredential(request.token), {
request.endpoint, userAgentOptions: {userAgentPrefix: 'github-actions-ai-inference'},
new AzureKeyCredential(request.token), })
{
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
}
)
const requestBody: ChatCompletionsRequestBody = { const requestBody: ChatCompletionsRequestBody = {
messages: request.messages, messages: request.messages,
max_tokens: request.maxTokens, max_tokens: request.maxTokens,
model: request.modelName model: request.modelName,
} }
// Add response format if specified // Add response format if specified
@@ -67,7 +61,7 @@ export async function simpleInference(
} }
const response = await client.path('/chat/completions').post({ const response = await client.path('/chat/completions').post({
body: requestBody body: requestBody,
}) })
if (isUnexpected(response)) { if (isUnexpected(response)) {
@@ -85,17 +79,13 @@ export async function simpleInference(
*/ */
export async function mcpInference( export async function mcpInference(
request: InferenceRequest, request: InferenceRequest,
githubMcpClient: GitHubMCPClient githubMcpClient: GitHubMCPClient,
): Promise<string | null> { ): Promise<string | null> {
core.info('Running GitHub MCP inference with tools') core.info('Running GitHub MCP inference with tools')
const client = ModelClient( const client = ModelClient(request.endpoint, new AzureKeyCredential(request.token), {
request.endpoint, userAgentOptions: {userAgentPrefix: 'github-actions-ai-inference'},
new AzureKeyCredential(request.token), })
{
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
}
)
// Start with the pre-processed messages // Start with the pre-processed messages
const messages: ChatMessage[] = [...request.messages] const messages: ChatMessage[] = [...request.messages]
@@ -111,7 +101,7 @@ export async function mcpInference(
messages: messages, messages: messages,
max_tokens: request.maxTokens, max_tokens: request.maxTokens,
model: request.modelName, model: request.modelName,
tools: githubMcpClient.tools tools: githubMcpClient.tools,
} }
// Add response format if specified (only on first iteration to avoid conflicts) // Add response format if specified (only on first iteration to avoid conflicts)
@@ -120,7 +110,7 @@ export async function mcpInference(
} }
const response = await client.path('/chat/completions').post({ const response = await client.path('/chat/completions').post({
body: requestBody body: requestBody,
}) })
if (isUnexpected(response)) { if (isUnexpected(response)) {
@@ -136,7 +126,7 @@ export async function mcpInference(
messages.push({ messages.push({
role: 'assistant', role: 'assistant',
content: modelResponse || '', content: modelResponse || '',
...(toolCalls && { tool_calls: toolCalls }) ...(toolCalls && {tool_calls: toolCalls}),
}) })
if (!toolCalls || toolCalls.length === 0) { if (!toolCalls || toolCalls.length === 0) {
@@ -147,10 +137,7 @@ export async function mcpInference(
core.info(`Model requested ${toolCalls.length} tool calls`) core.info(`Model requested ${toolCalls.length} tool calls`)
// Execute all tool calls via GitHub MCP // Execute all tool calls via GitHub MCP
const toolResults = await executeToolCalls( const toolResults = await executeToolCalls(githubMcpClient.client, toolCalls)
githubMcpClient.client,
toolCalls
)
// Add tool results to the conversation // Add tool results to the conversation
messages.push(...toolResults) messages.push(...toolResults)
@@ -158,15 +145,13 @@ export async function mcpInference(
core.info('Tool results added, continuing conversation...') core.info('Tool results added, continuing conversation...')
} }
core.warning( core.warning(`GitHub MCP inference loop exceeded maximum iterations (${maxIterations})`)
`GitHub MCP inference loop exceeded maximum iterations (${maxIterations})`
)
// Return the last assistant message content // Return the last assistant message content
const lastAssistantMessage = messages const lastAssistantMessage = messages
.slice() .slice()
.reverse() .reverse()
.find((msg) => msg.role === 'assistant') .find(msg => msg.role === 'assistant')
return lastAssistantMessage?.content || null return lastAssistantMessage?.content || null
} }
+6 -15
View File
@@ -2,15 +2,10 @@ import * as core from '@actions/core'
import * as fs from 'fs' import * as fs from 'fs'
import * as os from 'os' import * as os from 'os'
import * as path from 'path' import * as path from 'path'
import { connectToGitHubMCP } from './mcp.js' import {connectToGitHubMCP} from './mcp.js'
import { simpleInference, mcpInference } from './inference.js' import {simpleInference, mcpInference} from './inference.js'
import { loadContentFromFileOrInput, buildInferenceRequest } from './helpers.js' import {loadContentFromFileOrInput, buildInferenceRequest} from './helpers.js'
import { import {loadPromptFile, parseTemplateVariables, isPromptYamlFile, PromptConfig} from './prompt.js'
loadPromptFile,
parseTemplateVariables,
isPromptYamlFile,
PromptConfig
} from './prompt.js'
const RESPONSE_FILE = 'modelResponse.txt' const RESPONSE_FILE = 'modelResponse.txt'
@@ -42,11 +37,7 @@ export async function run(): Promise<void> {
core.info('Using legacy prompt format') core.info('Using legacy prompt format')
prompt = loadContentFromFileOrInput('prompt-file', 'prompt') prompt = loadContentFromFileOrInput('prompt-file', 'prompt')
systemPrompt = loadContentFromFileOrInput( systemPrompt = loadContentFromFileOrInput('system-prompt-file', 'system-prompt', 'You are a helpful assistant')
'system-prompt-file',
'system-prompt',
'You are a helpful assistant'
)
} }
// Get common parameters // Get common parameters
@@ -68,7 +59,7 @@ export async function run(): Promise<void> {
modelName, modelName,
maxTokens, maxTokens,
endpoint, endpoint,
token token,
) )
const enableMcp = core.getBooleanInput('enable-github-mcp') || false const enableMcp = core.getBooleanInput('enable-github-mcp') || false
+19 -33
View File
@@ -1,6 +1,6 @@
import * as core from '@actions/core' import * as core from '@actions/core'
import { Client } from '@modelcontextprotocol/sdk/client/index.js' import {Client} from '@modelcontextprotocol/sdk/client/index.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import {StreamableHTTPClientTransport} from '@modelcontextprotocol/sdk/client/streamableHttp.js'
export interface ToolResult { export interface ToolResult {
tool_call_id: string tool_call_id: string
@@ -35,9 +35,7 @@ export interface GitHubMCPClient {
/** /**
* Connect to the GitHub MCP server and retrieve available tools * Connect to the GitHub MCP server and retrieve available tools
*/ */
export async function connectToGitHubMCP( export async function connectToGitHubMCP(token: string): Promise<GitHubMCPClient | null> {
token: string
): Promise<GitHubMCPClient | null> {
const githubMcpUrl = 'https://api.githubcopilot.com/mcp/' const githubMcpUrl = 'https://api.githubcopilot.com/mcp/'
core.info('Connecting to GitHub MCP server...') core.info('Connecting to GitHub MCP server...')
@@ -46,15 +44,15 @@ export async function connectToGitHubMCP(
requestInit: { requestInit: {
headers: { headers: {
Authorization: `Bearer ${token}`, Authorization: `Bearer ${token}`,
'X-MCP-Readonly': 'true' 'X-MCP-Readonly': 'true',
} },
} },
}) })
const client = new Client({ const client = new Client({
name: 'ai-inference-action', name: 'ai-inference-action',
version: '1.0.0', version: '1.0.0',
transport transport,
}) })
try { try {
@@ -67,42 +65,35 @@ export async function connectToGitHubMCP(
core.info('Successfully connected to GitHub MCP server') core.info('Successfully connected to GitHub MCP server')
const toolsResponse = await client.listTools() const toolsResponse = await client.listTools()
core.info( core.info(`Retrieved ${toolsResponse.tools?.length || 0} tools from GitHub MCP server`)
`Retrieved ${toolsResponse.tools?.length || 0} tools from GitHub MCP server`
)
// Map GitHub MCP tools → Azure AI Inference tool definitions // Map GitHub MCP tools → Azure AI Inference tool definitions
const tools = (toolsResponse.tools || []).map((t) => ({ const tools = (toolsResponse.tools || []).map(t => ({
type: 'function' as const, type: 'function' as const,
function: { function: {
name: t.name, name: t.name,
description: t.description, description: t.description,
parameters: t.inputSchema parameters: t.inputSchema,
} },
})) }))
core.info(`Mapped ${tools.length} GitHub MCP tools for Azure AI Inference`) core.info(`Mapped ${tools.length} GitHub MCP tools for Azure AI Inference`)
return { client, tools } return {client, tools}
} }
/** /**
* Execute a single tool call via GitHub MCP * Execute a single tool call via GitHub MCP
*/ */
export async function executeToolCall( export async function executeToolCall(githubMcpClient: Client, toolCall: ToolCall): Promise<ToolResult> {
githubMcpClient: Client, core.info(`Executing GitHub MCP tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`)
toolCall: ToolCall
): Promise<ToolResult> {
core.info(
`Executing GitHub MCP tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`
)
try { try {
const args = JSON.parse(toolCall.function.arguments) const args = JSON.parse(toolCall.function.arguments)
const result = await githubMcpClient.callTool({ const result = await githubMcpClient.callTool({
name: toolCall.function.name, name: toolCall.function.name,
arguments: args arguments: args,
}) })
core.info(`GitHub MCP tool ${toolCall.function.name} executed successfully`) core.info(`GitHub MCP tool ${toolCall.function.name} executed successfully`)
@@ -111,18 +102,16 @@ export async function executeToolCall(
tool_call_id: toolCall.id, tool_call_id: toolCall.id,
role: 'tool', role: 'tool',
name: toolCall.function.name, name: toolCall.function.name,
content: JSON.stringify(result.content) content: JSON.stringify(result.content),
} }
} catch (toolError) { } catch (toolError) {
core.warning( core.warning(`Failed to execute GitHub MCP tool ${toolCall.function.name}: ${toolError}`)
`Failed to execute GitHub MCP tool ${toolCall.function.name}: ${toolError}`
)
return { return {
tool_call_id: toolCall.id, tool_call_id: toolCall.id,
role: 'tool', role: 'tool',
name: toolCall.function.name, name: toolCall.function.name,
content: `Error: ${toolError}` content: `Error: ${toolError}`,
} }
} }
} }
@@ -130,10 +119,7 @@ export async function executeToolCall(
/** /**
* Execute all tool calls from a response via GitHub MCP * Execute all tool calls from a response via GitHub MCP
*/ */
export async function executeToolCalls( export async function executeToolCalls(githubMcpClient: Client, toolCalls: ToolCall[]): Promise<ToolResult[]> {
githubMcpClient: Client,
toolCalls: ToolCall[]
): Promise<ToolResult[]> {
const toolResults: ToolResult[] = [] const toolResults: ToolResult[] = []
for (const toolCall of toolCalls) { for (const toolCall of toolCalls) {
+7 -24
View File
@@ -33,26 +33,19 @@ export function parseTemplateVariables(input: string): TemplateVariables {
} }
return parsed return parsed
} catch (error) { } catch (error) {
throw new Error( throw new Error(`Failed to parse template variables: ${error instanceof Error ? error.message : 'Unknown error'}`)
`Failed to parse template variables: ${error instanceof Error ? error.message : 'Unknown error'}`
)
} }
} }
/** /**
* Replace template variables in text using {{variable}} syntax * Replace template variables in text using {{variable}} syntax
*/ */
export function replaceTemplateVariables( export function replaceTemplateVariables(text: string, variables: TemplateVariables): string {
text: string,
variables: TemplateVariables
): string {
return text.replace(/\{\{([\w.-]+)\}\}/g, (match, variableName) => { return text.replace(/\{\{([\w.-]+)\}\}/g, (match, variableName) => {
if (variableName in variables) { if (variableName in variables) {
return variables[variableName] return variables[variableName]
} }
core.warning( core.warning(`Template variable '${variableName}' not found in input variables`)
`Template variable '${variableName}' not found in input variables`
)
return match // Return the original placeholder if variable not found return match // Return the original placeholder if variable not found
}) })
} }
@@ -60,10 +53,7 @@ export function replaceTemplateVariables(
/** /**
* Load and parse a prompt YAML file with template variable substitution * Load and parse a prompt YAML file with template variable substitution
*/ */
export function loadPromptFile( export function loadPromptFile(filePath: string, templateVariables: TemplateVariables = {}): PromptConfig {
filePath: string,
templateVariables: TemplateVariables = {}
): PromptConfig {
if (!fs.existsSync(filePath)) { if (!fs.existsSync(filePath)) {
throw new Error(`Prompt file not found: ${filePath}`) throw new Error(`Prompt file not found: ${filePath}`)
} }
@@ -71,10 +61,7 @@ export function loadPromptFile(
const fileContent = fs.readFileSync(filePath, 'utf-8') const fileContent = fs.readFileSync(filePath, 'utf-8')
// Apply template variable substitution // Apply template variable substitution
const processedContent = replaceTemplateVariables( const processedContent = replaceTemplateVariables(fileContent, templateVariables)
fileContent,
templateVariables
)
try { try {
const config = yaml.load(processedContent) as PromptConfig const config = yaml.load(processedContent) as PromptConfig
@@ -86,9 +73,7 @@ export function loadPromptFile(
// Validate messages // Validate messages
for (const message of config.messages) { for (const message of config.messages) {
if (!message.role || !message.content) { if (!message.role || !message.content) {
throw new Error( throw new Error('Each message must have "role" and "content" properties')
'Each message must have "role" and "content" properties'
)
} }
if (!['system', 'user', 'assistant'].includes(message.role)) { if (!['system', 'user', 'assistant'].includes(message.role)) {
throw new Error(`Invalid message role: ${message.role}`) throw new Error(`Invalid message role: ${message.role}`)
@@ -97,9 +82,7 @@ export function loadPromptFile(
return config return config
} catch (error) { } catch (error) {
throw new Error( throw new Error(`Failed to parse prompt file: ${error instanceof Error ? error.message : 'Unknown error'}`)
`Failed to parse prompt file: ${error instanceof Error ? error.message : 'Unknown error'}`
)
} }
} }
+1 -7
View File
@@ -6,11 +6,5 @@
"noEmit": true "noEmit": true
}, },
"exclude": ["dist", "node_modules"], "exclude": ["dist", "node_modules"],
"include": [ "include": ["__fixtures__", "__tests__", "src", "eslint.config.mjs", "rollup.config.ts"]
"__fixtures__",
"__tests__",
"src",
"eslint.config.mjs",
"rollup.config.ts"
]
} }