Add tests
This commit is contained in:
@@ -5,6 +5,7 @@ export const debug = jest.fn<typeof core.debug>()
|
||||
export const error = jest.fn<typeof core.error>()
|
||||
export const info = jest.fn<typeof core.info>()
|
||||
export const getInput = jest.fn<typeof core.getInput>()
|
||||
export const getBooleanInput = jest.fn<typeof core.getBooleanInput>()
|
||||
export const setOutput = jest.fn<typeof core.setOutput>()
|
||||
export const setFailed = jest.fn<typeof core.setFailed>()
|
||||
export const warning = jest.fn<typeof core.warning>()
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
/**
|
||||
* Unit tests for the helpers module, src/helpers.ts
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
|
||||
// Mock fs module
|
||||
const mockExistsSync = jest.fn()
|
||||
const mockReadFileSync = jest.fn()
|
||||
|
||||
jest.unstable_mockModule('fs', () => ({
|
||||
existsSync: mockExistsSync,
|
||||
readFileSync: mockReadFileSync
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
|
||||
// Import the module being tested
|
||||
const { loadContentFromFileOrInput } = await import('../src/helpers.js')
|
||||
|
||||
describe('helpers.ts', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('loadContentFromFileOrInput', () => {
|
||||
it('loads content from file when file path is provided', () => {
|
||||
const filePath = '/path/to/file.txt'
|
||||
const fileContent = 'File content here'
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
if (name === 'file-input') return filePath
|
||||
if (name === 'content-input') return ''
|
||||
return ''
|
||||
})
|
||||
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
mockReadFileSync.mockReturnValue(fileContent)
|
||||
|
||||
const result = loadContentFromFileOrInput('file-input', 'content-input')
|
||||
|
||||
expect(core.getInput).toHaveBeenCalledWith('file-input')
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(filePath)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(filePath, 'utf-8')
|
||||
expect(result).toBe(fileContent)
|
||||
})
|
||||
|
||||
it('throws error when file path is provided but file does not exist', () => {
|
||||
const filePath = '/path/to/nonexistent.txt'
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
if (name === 'file-input') return filePath
|
||||
if (name === 'content-input') return ''
|
||||
return ''
|
||||
})
|
||||
|
||||
mockExistsSync.mockReturnValue(false)
|
||||
|
||||
expect(() => {
|
||||
loadContentFromFileOrInput('file-input', 'content-input')
|
||||
}).toThrow('File for file-input was not found: /path/to/nonexistent.txt')
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(filePath)
|
||||
expect(mockReadFileSync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses content input when file path is empty', () => {
|
||||
const contentInput = 'Direct content input'
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
if (name === 'file-input') return ''
|
||||
if (name === 'content-input') return contentInput
|
||||
return ''
|
||||
})
|
||||
|
||||
const result = loadContentFromFileOrInput('file-input', 'content-input')
|
||||
|
||||
expect(core.getInput).toHaveBeenCalledWith('file-input')
|
||||
expect(core.getInput).toHaveBeenCalledWith('content-input')
|
||||
expect(mockExistsSync).not.toHaveBeenCalled()
|
||||
expect(mockReadFileSync).not.toHaveBeenCalled()
|
||||
expect(result).toBe(contentInput)
|
||||
})
|
||||
|
||||
it('prefers file path over content input when both are provided', () => {
|
||||
const filePath = '/path/to/file.txt'
|
||||
const fileContent = 'File content'
|
||||
const contentInput = 'Direct content input'
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
if (name === 'file-input') return filePath
|
||||
if (name === 'content-input') return contentInput
|
||||
return ''
|
||||
})
|
||||
|
||||
mockExistsSync.mockReturnValue(true)
|
||||
mockReadFileSync.mockReturnValue(fileContent)
|
||||
|
||||
const result = loadContentFromFileOrInput('file-input', 'content-input')
|
||||
|
||||
expect(result).toBe(fileContent)
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(filePath)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(filePath, 'utf-8')
|
||||
})
|
||||
|
||||
it('uses default value when neither file nor content is provided', () => {
|
||||
const defaultValue = 'Default content'
|
||||
|
||||
core.getInput.mockImplementation(() => '')
|
||||
|
||||
const result = loadContentFromFileOrInput(
|
||||
'file-input',
|
||||
'content-input',
|
||||
defaultValue
|
||||
)
|
||||
|
||||
expect(result).toBe(defaultValue)
|
||||
expect(mockExistsSync).not.toHaveBeenCalled()
|
||||
expect(mockReadFileSync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('throws error when neither file nor content is provided and no default', () => {
|
||||
core.getInput.mockImplementation(() => '')
|
||||
|
||||
expect(() => {
|
||||
loadContentFromFileOrInput('file-input', 'content-input')
|
||||
}).toThrow('Neither file-input nor content-input was set')
|
||||
|
||||
expect(mockExistsSync).not.toHaveBeenCalled()
|
||||
expect(mockReadFileSync).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('handles undefined inputs correctly', () => {
|
||||
const defaultValue = 'Default content'
|
||||
|
||||
core.getInput.mockImplementation(() => undefined as any)
|
||||
|
||||
const result = loadContentFromFileOrInput(
|
||||
'file-input',
|
||||
'content-input',
|
||||
defaultValue
|
||||
)
|
||||
|
||||
expect(result).toBe(defaultValue)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,356 @@
|
||||
/**
|
||||
* Unit tests for the inference module, src/inference.ts
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
|
||||
// Mock Azure AI Inference
|
||||
const mockPost = jest.fn() as jest.MockedFunction<any>
|
||||
const mockPath = jest.fn(() => ({ post: mockPost }))
|
||||
const mockClient = jest.fn(() => ({ path: mockPath }))
|
||||
|
||||
jest.unstable_mockModule('@azure-rest/ai-inference', () => ({
|
||||
default: mockClient,
|
||||
isUnexpected: jest.fn(() => false)
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@azure/core-auth', () => ({
|
||||
AzureKeyCredential: jest.fn()
|
||||
}))
|
||||
|
||||
// Mock MCP functions
|
||||
const mockExecuteToolCalls = jest.fn() as jest.MockedFunction<any>
|
||||
jest.unstable_mockModule('../src/mcp.js', () => ({
|
||||
executeToolCalls: mockExecuteToolCalls
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
|
||||
// Import the module being tested
|
||||
const { simpleInference, mcpInference } = await import('../src/inference.js')
|
||||
|
||||
describe('inference.ts', () => {
|
||||
const mockRequest = {
|
||||
systemPrompt: 'You are a test assistant',
|
||||
prompt: 'Hello, AI!',
|
||||
modelName: 'gpt-4',
|
||||
maxTokens: 100,
|
||||
endpoint: 'https://api.test.com',
|
||||
token: 'test-token'
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('simpleInference', () => {
|
||||
it('performs simple inference without tools', async () => {
|
||||
const mockResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Hello, user!'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost.mockResolvedValue(mockResponse)
|
||||
|
||||
const result = await simpleInference(mockRequest)
|
||||
|
||||
expect(result).toBe('Hello, user!')
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Running simple inference without tools'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith('Model response: Hello, user!')
|
||||
|
||||
// Verify the request structure
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: 'You are a test assistant'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'Hello, AI!'
|
||||
}
|
||||
],
|
||||
max_tokens: 100,
|
||||
model: 'gpt-4'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('handles null response content', async () => {
|
||||
const mockResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost.mockResolvedValue(mockResponse)
|
||||
|
||||
const result = await simpleInference(mockRequest)
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Model response: No response content'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('mcpInference', () => {
|
||||
const mockMcpClient = {
|
||||
client: {} as any,
|
||||
tools: [
|
||||
{
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'test-tool',
|
||||
description: 'A test tool',
|
||||
parameters: { type: 'object' }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
it('performs inference without tool calls', async () => {
|
||||
const mockResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Hello, user!',
|
||||
tool_calls: null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost.mockResolvedValue(mockResponse)
|
||||
|
||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||
|
||||
expect(result).toBe('Hello, user!')
|
||||
expect(core.info).toHaveBeenCalledWith('Running MCP inference with tools')
|
||||
expect(core.info).toHaveBeenCalledWith('MCP inference iteration 1')
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'No tool calls requested, ending MCP inference loop'
|
||||
)
|
||||
|
||||
// The MCP inference loop will always add the assistant message, even when there are no tool calls
|
||||
// So we don't check the exact messages, just that tools were included
|
||||
expect(mockPost).toHaveBeenCalledTimes(1)
|
||||
const callArgs = mockPost.mock.calls[0][0] as any
|
||||
expect(callArgs.body.tools).toEqual(mockMcpClient.tools)
|
||||
expect(callArgs.body.model).toBe('gpt-4')
|
||||
expect(callArgs.body.max_tokens).toBe(100)
|
||||
})
|
||||
|
||||
it('executes tool calls and continues conversation', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
id: 'call-123',
|
||||
function: {
|
||||
name: 'test-tool',
|
||||
arguments: '{"param": "value"}'
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const toolResults = [
|
||||
{
|
||||
tool_call_id: 'call-123',
|
||||
role: 'tool',
|
||||
name: 'test-tool',
|
||||
content: 'Tool result'
|
||||
}
|
||||
]
|
||||
|
||||
// First response with tool calls
|
||||
const firstResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'I need to use a tool.',
|
||||
tool_calls: toolCalls
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
// Second response after tool execution
|
||||
const secondResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Here is the final answer.',
|
||||
tool_calls: null
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost
|
||||
.mockResolvedValueOnce(firstResponse)
|
||||
.mockResolvedValueOnce(secondResponse)
|
||||
|
||||
mockExecuteToolCalls.mockResolvedValue(toolResults)
|
||||
|
||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||
|
||||
expect(result).toBe('Here is the final answer.')
|
||||
expect(mockExecuteToolCalls).toHaveBeenCalledWith(
|
||||
mockMcpClient.client,
|
||||
toolCalls
|
||||
)
|
||||
expect(mockPost).toHaveBeenCalledTimes(2)
|
||||
|
||||
// Verify the second call includes the conversation history
|
||||
const secondCall = mockPost.mock.calls[1][0] as any
|
||||
expect(secondCall.body.messages).toHaveLength(5) // system, user, assistant, tool, assistant
|
||||
expect(secondCall.body.messages[2].role).toBe('assistant')
|
||||
expect(secondCall.body.messages[2].tool_calls).toEqual(toolCalls)
|
||||
expect(secondCall.body.messages[3]).toEqual(toolResults[0])
|
||||
})
|
||||
|
||||
it('handles maximum iteration limit', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
id: 'call-123',
|
||||
function: {
|
||||
name: 'test-tool',
|
||||
arguments: '{}'
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const toolResults = [
|
||||
{
|
||||
tool_call_id: 'call-123',
|
||||
role: 'tool',
|
||||
name: 'test-tool',
|
||||
content: 'Tool result'
|
||||
}
|
||||
]
|
||||
|
||||
// Always respond with tool calls to trigger infinite loop
|
||||
const responseWithToolCalls = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Using tool again.',
|
||||
tool_calls: toolCalls
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost.mockResolvedValue(responseWithToolCalls)
|
||||
mockExecuteToolCalls.mockResolvedValue(toolResults)
|
||||
|
||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||
|
||||
expect(mockPost).toHaveBeenCalledTimes(5) // Max iterations reached
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
'MCP inference loop exceeded maximum iterations (5)'
|
||||
)
|
||||
expect(result).toBe('Using tool again.') // Last assistant message
|
||||
})
|
||||
|
||||
it('handles empty tool calls array', async () => {
|
||||
const mockResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Hello, user!',
|
||||
tool_calls: []
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost.mockResolvedValue(mockResponse)
|
||||
|
||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||
|
||||
expect(result).toBe('Hello, user!')
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'No tool calls requested, ending MCP inference loop'
|
||||
)
|
||||
expect(mockExecuteToolCalls).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns last assistant message when no content in final iteration', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
id: 'call-123',
|
||||
function: { name: 'test-tool', arguments: '{}' }
|
||||
}
|
||||
]
|
||||
|
||||
const firstResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'First message',
|
||||
tool_calls: toolCalls
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
const secondResponse = {
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Second message',
|
||||
tool_calls: toolCalls
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
mockPost
|
||||
.mockResolvedValueOnce(firstResponse)
|
||||
.mockResolvedValue(secondResponse)
|
||||
|
||||
mockExecuteToolCalls.mockResolvedValue([
|
||||
{
|
||||
tool_call_id: 'call-123',
|
||||
role: 'tool',
|
||||
name: 'test-tool',
|
||||
content: 'result'
|
||||
}
|
||||
])
|
||||
|
||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||
|
||||
expect(result).toBe('Second message')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,264 @@
|
||||
/**
|
||||
* Unit tests for the action's main functionality, src/main.ts
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
|
||||
// Default to throwing errors to catch unexpected calls
|
||||
const mockExistsSync = jest.fn().mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Unexpected call to existsSync - test should override this implementation'
|
||||
)
|
||||
})
|
||||
const mockReadFileSync = jest.fn().mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Unexpected call to readFileSync - test should override this implementation'
|
||||
)
|
||||
})
|
||||
const mockWriteFileSync = jest.fn()
|
||||
|
||||
/**
|
||||
* Helper function to mock file system operations for one or more files
|
||||
* @param fileContents - Object mapping file paths to their contents
|
||||
* @param nonExistentFiles - Array of file paths that should be treated as non-existent
|
||||
*/
|
||||
function mockFileContent(
|
||||
fileContents: Record<string, string> = {},
|
||||
nonExistentFiles: string[] = []
|
||||
): void {
|
||||
// Mock existsSync to return true for files that exist, false for those that don't
|
||||
mockExistsSync.mockImplementation((...args: unknown[]): boolean => {
|
||||
const [path] = args as [string]
|
||||
if (nonExistentFiles.includes(path)) {
|
||||
return false
|
||||
}
|
||||
return path in fileContents || true
|
||||
})
|
||||
|
||||
// Mock readFileSync to return the content for known files
|
||||
mockReadFileSync.mockImplementation((...args: unknown[]): string => {
|
||||
const [path, options] = args as [string, BufferEncoding]
|
||||
if (options === 'utf-8' && path in fileContents) {
|
||||
return fileContents[path]
|
||||
}
|
||||
throw new Error(`Unexpected file read: ${path}`)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to mock action inputs
|
||||
* @param inputs - Object mapping input names to their values
|
||||
*/
|
||||
function mockInputs(inputs: Record<string, string> = {}): void {
|
||||
// Default values that are applied unless overridden
|
||||
const defaultInputs: Record<string, string> = {
|
||||
token: 'fake-token',
|
||||
model: 'gpt-4',
|
||||
'max-tokens': '100',
|
||||
endpoint: 'https://api.test.com'
|
||||
}
|
||||
|
||||
// Combine defaults with user-provided inputs
|
||||
const allInputs: Record<string, string> = { ...defaultInputs, ...inputs }
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
return allInputs[name] || ''
|
||||
})
|
||||
|
||||
core.getBooleanInput.mockImplementation((name: string) => {
|
||||
const value = allInputs[name]
|
||||
return value === 'true'
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to verify common response assertions
|
||||
*/
|
||||
function verifyStandardResponse(): void {
|
||||
expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!')
|
||||
expect(core.setOutput).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'response-file',
|
||||
expect.stringContaining('modelResponse.txt')
|
||||
)
|
||||
}
|
||||
|
||||
jest.unstable_mockModule('fs', () => ({
|
||||
existsSync: mockExistsSync,
|
||||
readFileSync: mockReadFileSync,
|
||||
writeFileSync: mockWriteFileSync
|
||||
}))
|
||||
|
||||
// Mock MCP and inference modules
|
||||
const mockConnectToMCP = jest.fn() as jest.MockedFunction<any>
|
||||
const mockSimpleInference = jest.fn() as jest.MockedFunction<any>
|
||||
const mockMcpInference = jest.fn() as jest.MockedFunction<any>
|
||||
|
||||
jest.unstable_mockModule('../src/mcp.js', () => ({
|
||||
connectToMCP: mockConnectToMCP
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('../src/inference.js', () => ({
|
||||
simpleInference: mockSimpleInference,
|
||||
mcpInference: mockMcpInference
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
|
||||
// The module being tested should be imported dynamically. This ensures that the
|
||||
// mocks are used in place of any actual dependencies.
|
||||
const { run } = await import('../src/main.js')
|
||||
|
||||
describe('main.ts', () => {
|
||||
// Reset all mocks before each test
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Set up default mock responses
|
||||
mockSimpleInference.mockResolvedValue('Hello, user!')
|
||||
mockMcpInference.mockResolvedValue('Hello, user!')
|
||||
})
|
||||
|
||||
it('Sets the response output', async () => {
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(core.setOutput).toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('Sets a failed status when no prompt is set', async () => {
|
||||
mockInputs({
|
||||
prompt: '',
|
||||
'prompt-file': ''
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(core.setFailed).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'Neither prompt-file nor prompt was set'
|
||||
)
|
||||
})
|
||||
|
||||
it('uses simple inference when MCP is disabled', async () => {
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'false'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||
systemPrompt: 'You are a test assistant.',
|
||||
prompt: 'Hello, AI!',
|
||||
modelName: 'gpt-4',
|
||||
maxTokens: 100,
|
||||
endpoint: 'https://api.test.com',
|
||||
token: 'fake-token'
|
||||
})
|
||||
expect(mockConnectToMCP).not.toHaveBeenCalled()
|
||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses MCP inference when enabled and connection succeeds', async () => {
|
||||
const mockMcpClient = {
|
||||
client: {} as any,
|
||||
tools: [{ type: 'function', function: { name: 'test-tool' } }]
|
||||
}
|
||||
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'true'
|
||||
})
|
||||
|
||||
mockConnectToMCP.mockResolvedValue(mockMcpClient)
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockConnectToMCP).toHaveBeenCalledWith('fake-token')
|
||||
expect(mockMcpInference).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
systemPrompt: 'You are a test assistant.',
|
||||
prompt: 'Hello, AI!',
|
||||
token: 'fake-token'
|
||||
}),
|
||||
mockMcpClient
|
||||
)
|
||||
expect(mockSimpleInference).not.toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('falls back to simple inference when MCP connection fails', async () => {
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'true'
|
||||
})
|
||||
|
||||
mockConnectToMCP.mockResolvedValue(null)
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockConnectToMCP).toHaveBeenCalledWith('fake-token')
|
||||
expect(mockSimpleInference).toHaveBeenCalled()
|
||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
'MCP connection failed, falling back to simple inference'
|
||||
)
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('properly integrates with loadContentFromFileOrInput', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const promptContent = 'File-based prompt'
|
||||
const systemPromptContent = 'File-based system prompt'
|
||||
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent,
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt-file': systemPromptFile,
|
||||
'enable-mcp': 'false'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||
systemPrompt: systemPromptContent,
|
||||
prompt: promptContent,
|
||||
modelName: 'gpt-4',
|
||||
maxTokens: 100,
|
||||
endpoint: 'https://api.test.com',
|
||||
token: 'fake-token'
|
||||
})
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('handles non-existent prompt-file with an error', async () => {
|
||||
const promptFile = 'non-existent-prompt.txt'
|
||||
|
||||
mockFileContent({}, [promptFile])
|
||||
|
||||
mockInputs({
|
||||
'prompt-file': promptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(core.setFailed).toHaveBeenCalledWith(
|
||||
`File for prompt-file was not found: ${promptFile}`
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,383 @@
|
||||
/**
|
||||
* Unit tests for the action's main functionality, src/main.ts
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
|
||||
// Default to throwing errors to catch unexpected calls
|
||||
const mockExistsSync = jest.fn().mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Unexpected call to existsSync - test should override this implementation'
|
||||
)
|
||||
})
|
||||
const mockReadFileSync = jest.fn().mockImplementation(() => {
|
||||
throw new Error(
|
||||
'Unexpected call to readFileSync - test should override this implementation'
|
||||
)
|
||||
})
|
||||
const mockWriteFileSync = jest.fn()
|
||||
|
||||
/**
|
||||
* Helper function to mock file system operations for one or more files
|
||||
* @param fileContents - Object mapping file paths to their contents
|
||||
* @param nonExistentFiles - Array of file paths that should be treated as non-existent
|
||||
*/
|
||||
function mockFileContent(
|
||||
fileContents: Record<string, string> = {},
|
||||
nonExistentFiles: string[] = []
|
||||
): void {
|
||||
// Mock existsSync to return true for files that exist, false for those that don't
|
||||
mockExistsSync.mockImplementation((...args: unknown[]): boolean => {
|
||||
const [path] = args as [string]
|
||||
if (nonExistentFiles.includes(path)) {
|
||||
return false
|
||||
}
|
||||
return path in fileContents || true
|
||||
})
|
||||
|
||||
// Mock readFileSync to return the content for known files
|
||||
mockReadFileSync.mockImplementation((...args: unknown[]): string => {
|
||||
const [path, options] = args as [string, BufferEncoding]
|
||||
if (options === 'utf-8' && path in fileContents) {
|
||||
return fileContents[path]
|
||||
}
|
||||
throw new Error(`Unexpected file read: ${path}`)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to mock action inputs
|
||||
* @param inputs - Object mapping input names to their values
|
||||
*/
|
||||
function mockInputs(inputs: Record<string, string> = {}): void {
|
||||
// Default values that are applied unless overridden
|
||||
const defaultInputs: Record<string, string> = {
|
||||
token: 'fake-token',
|
||||
model: 'gpt-4',
|
||||
'max-tokens': '100',
|
||||
endpoint: 'https://api.test.com'
|
||||
}
|
||||
|
||||
// Combine defaults with user-provided inputs
|
||||
const allInputs: Record<string, string> = { ...defaultInputs, ...inputs }
|
||||
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
return allInputs[name] || ''
|
||||
})
|
||||
|
||||
core.getBooleanInput.mockImplementation((name: string) => {
|
||||
const value = allInputs[name]
|
||||
return value === 'true'
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to verify common response assertions
|
||||
*/
|
||||
function verifyStandardResponse(): void {
|
||||
expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!')
|
||||
expect(core.setOutput).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'response-file',
|
||||
expect.stringContaining('modelResponse.txt')
|
||||
)
|
||||
}
|
||||
|
||||
jest.unstable_mockModule('fs', () => ({
|
||||
existsSync: mockExistsSync,
|
||||
readFileSync: mockReadFileSync,
|
||||
writeFileSync: mockWriteFileSync
|
||||
}))
|
||||
|
||||
// Mock MCP and inference modules
|
||||
const mockConnectToMCP = jest.fn() as jest.MockedFunction<any>
|
||||
const mockSimpleInference = jest.fn() as jest.MockedFunction<any>
|
||||
const mockMcpInference = jest.fn() as jest.MockedFunction<any>
|
||||
|
||||
jest.unstable_mockModule('../src/mcp.js', () => ({
|
||||
connectToMCP: mockConnectToMCP
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('../src/inference.js', () => ({
|
||||
simpleInference: mockSimpleInference,
|
||||
mcpInference: mockMcpInference
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
|
||||
// The module being tested should be imported dynamically. This ensures that the
|
||||
// mocks are used in place of any actual dependencies.
|
||||
const { run } = await import('../src/main.js')
|
||||
|
||||
describe('main.ts', () => {
|
||||
// Reset all mocks before each test
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Set up default mock responses
|
||||
mockSimpleInference.mockResolvedValue('Hello, user!')
|
||||
mockMcpInference.mockResolvedValue('Hello, user!')
|
||||
})
|
||||
|
||||
it('Sets the response output', async () => {
|
||||
// Set the action's inputs as return values from core.getInput().
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(core.setOutput).toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('Sets a failed status when no prompt is set', async () => {
|
||||
// Clear the getInput mock and simulate no prompt or prompt-file input
|
||||
mockInputs({
|
||||
prompt: '',
|
||||
'prompt-file': ''
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the action was marked as failed.
|
||||
expect(core.setFailed).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'Neither prompt-file nor prompt was set'
|
||||
)
|
||||
})
|
||||
|
||||
it('uses prompt-file', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptContent = 'This is a prompt from a file'
|
||||
|
||||
// Set up mock to return specific content for the prompt file
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('handles non-existent prompt-file with an error', async () => {
|
||||
const promptFile = 'non-existent-prompt.txt'
|
||||
|
||||
// Mock the file not existing
|
||||
mockFileContent({}, [promptFile])
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
'prompt-file': promptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the error was correctly reported
|
||||
expect(core.setFailed).toHaveBeenCalledWith(
|
||||
`File for prompt-file was not found: ${promptFile}`
|
||||
)
|
||||
})
|
||||
|
||||
it('prefers prompt-file over prompt when both are provided', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptFileContent = 'This is a prompt from a file that should be used'
|
||||
const promptString = 'This is a direct prompt that should be ignored'
|
||||
|
||||
// Set up mock to return specific content for the prompt file
|
||||
mockFileContent({
|
||||
[promptFile]: promptFileContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: promptString,
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with the prompt from the file, not the input parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: expect.any(String)
|
||||
},
|
||||
{ role: 'user', content: promptFileContent } // Should use the file content, not the string input
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses system-prompt-file', async () => {
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptContent =
|
||||
'You are a specialized system assistant for testing'
|
||||
|
||||
// Set up mock to return specific content for the system prompt file
|
||||
mockFileContent({
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('handles non-existent system-prompt-file with an error', async () => {
|
||||
const systemPromptFile = 'non-existent-system-prompt.txt'
|
||||
|
||||
// Mock the file not existing
|
||||
mockFileContent({}, [systemPromptFile])
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the error was correctly reported
|
||||
expect(core.setFailed).toHaveBeenCalledWith(
|
||||
`File for system-prompt-file was not found: ${systemPromptFile}`
|
||||
)
|
||||
})
|
||||
|
||||
it('prefers system-prompt-file over system-prompt when both are provided', async () => {
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptFileContent =
|
||||
'You are a specialized system assistant from file'
|
||||
const systemPromptString =
|
||||
'You are a basic system assistant from input parameter'
|
||||
|
||||
// Set up mock to return specific content for the system prompt file
|
||||
mockFileContent({
|
||||
[systemPromptFile]: systemPromptFileContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile,
|
||||
'system-prompt': systemPromptString
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with the system prompt from the file, not the input parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPromptFileContent // Should use the file content, not the string input
|
||||
},
|
||||
{ role: 'user', content: 'Hello, AI!' }
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses both prompt-file and system-prompt-file together', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptContent = 'This is a prompt from a file'
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptContent =
|
||||
'You are a specialized system assistant from file'
|
||||
|
||||
// Set up mock to return specific content for both files
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent,
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with both the prompt and system prompt from files
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPromptContent
|
||||
},
|
||||
{ role: 'user', content: promptContent }
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('passes custom max-tokens parameter to the model', async () => {
|
||||
const customMaxTokens = 500
|
||||
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'max-tokens': customMaxTokens.toString()
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Check that the post call was made with the correct max_tokens parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: expect.any(Array),
|
||||
max_tokens: customMaxTokens,
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
})
|
||||
+126
-238
@@ -1,32 +1,8 @@
|
||||
/**
|
||||
* Unit tests for the action's main functionality, src/main.ts
|
||||
*
|
||||
* To mock dependencies in ESM, you can create fixtures that export mock
|
||||
* functions and objects. For example, the core module is mocked in this test,
|
||||
* so that the actual '@actions/core' module is not imported.
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
const mockPost = jest.fn().mockImplementation(() => ({
|
||||
body: {
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
content: 'Hello, user!'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@azure-rest/ai-inference', () => ({
|
||||
default: jest.fn(() => ({
|
||||
path: jest.fn(() => ({
|
||||
post: mockPost
|
||||
}))
|
||||
})),
|
||||
isUnexpected: jest.fn(() => false)
|
||||
}))
|
||||
|
||||
// Default to throwing errors to catch unexpected calls
|
||||
const mockExistsSync = jest.fn().mockImplementation(() => {
|
||||
@@ -39,6 +15,7 @@ const mockReadFileSync = jest.fn().mockImplementation(() => {
|
||||
'Unexpected call to readFileSync - test should override this implementation'
|
||||
)
|
||||
})
|
||||
const mockWriteFileSync = jest.fn()
|
||||
|
||||
/**
|
||||
* Helper function to mock file system operations for one or more files
|
||||
@@ -75,7 +52,10 @@ function mockFileContent(
|
||||
function mockInputs(inputs: Record<string, string> = {}): void {
|
||||
// Default values that are applied unless overridden
|
||||
const defaultInputs: Record<string, string> = {
|
||||
token: 'fake-token'
|
||||
token: 'fake-token',
|
||||
model: 'gpt-4',
|
||||
'max-tokens': '100',
|
||||
endpoint: 'https://api.test.com'
|
||||
}
|
||||
|
||||
// Combine defaults with user-provided inputs
|
||||
@@ -84,6 +64,11 @@ function mockInputs(inputs: Record<string, string> = {}): void {
|
||||
core.getInput.mockImplementation((name: string) => {
|
||||
return allInputs[name] || ''
|
||||
})
|
||||
|
||||
core.getBooleanInput.mockImplementation((name: string) => {
|
||||
const value = allInputs[name]
|
||||
return value === 'true'
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -100,7 +85,22 @@ function verifyStandardResponse(): void {
|
||||
|
||||
jest.unstable_mockModule('fs', () => ({
|
||||
existsSync: mockExistsSync,
|
||||
readFileSync: mockReadFileSync
|
||||
readFileSync: mockReadFileSync,
|
||||
writeFileSync: mockWriteFileSync
|
||||
}))
|
||||
|
||||
// Mock MCP and inference modules
|
||||
const mockConnectToMCP = jest.fn() as jest.MockedFunction<any>
|
||||
const mockSimpleInference = jest.fn() as jest.MockedFunction<any>
|
||||
const mockMcpInference = jest.fn() as jest.MockedFunction<any>
|
||||
|
||||
jest.unstable_mockModule('../src/mcp.js', () => ({
|
||||
connectToMCP: mockConnectToMCP
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('../src/inference.js', () => ({
|
||||
simpleInference: mockSimpleInference,
|
||||
mcpInference: mockMcpInference
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
@@ -113,10 +113,16 @@ describe('main.ts', () => {
|
||||
// Reset all mocks before each test
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
|
||||
// Remove any existing GITHUB_TOKEN
|
||||
delete process.env.GITHUB_TOKEN
|
||||
|
||||
// Set up default mock responses
|
||||
mockSimpleInference.mockResolvedValue('Hello, user!')
|
||||
mockMcpInference.mockResolvedValue('Hello, user!')
|
||||
})
|
||||
|
||||
it('Sets the response output', async () => {
|
||||
// Set the action's inputs as return values from core.getInput().
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
@@ -129,7 +135,6 @@ describe('main.ts', () => {
|
||||
})
|
||||
|
||||
it('Sets a failed status when no prompt is set', async () => {
|
||||
// Clear the getInput mock and simulate no prompt or prompt-file input
|
||||
mockInputs({
|
||||
prompt: '',
|
||||
'prompt-file': ''
|
||||
@@ -137,243 +142,126 @@ describe('main.ts', () => {
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the action was marked as failed.
|
||||
expect(core.setFailed).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'Neither prompt-file nor prompt was set'
|
||||
)
|
||||
})
|
||||
|
||||
it('uses prompt-file', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptContent = 'This is a prompt from a file'
|
||||
|
||||
// Set up mock to return specific content for the prompt file
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
it('uses simple inference when MCP is disabled', async () => {
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'false'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||
systemPrompt: 'You are a test assistant.',
|
||||
prompt: 'Hello, AI!',
|
||||
modelName: 'gpt-4',
|
||||
maxTokens: 100,
|
||||
endpoint: 'https://api.test.com',
|
||||
token: 'fake-token'
|
||||
})
|
||||
expect(mockConnectToMCP).not.toHaveBeenCalled()
|
||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses MCP inference when enabled and connection succeeds', async () => {
|
||||
const mockMcpClient = {
|
||||
client: {} as any,
|
||||
tools: [{ type: 'function', function: { name: 'test-tool' } }]
|
||||
}
|
||||
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'true'
|
||||
})
|
||||
|
||||
mockConnectToMCP.mockResolvedValue(mockMcpClient)
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockConnectToMCP).toHaveBeenCalledWith('fake-token')
|
||||
expect(mockMcpInference).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
systemPrompt: 'You are a test assistant.',
|
||||
prompt: 'Hello, AI!',
|
||||
token: 'fake-token'
|
||||
}),
|
||||
mockMcpClient
|
||||
)
|
||||
expect(mockSimpleInference).not.toHaveBeenCalled()
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('falls back to simple inference when MCP connection fails', async () => {
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'enable-mcp': 'true'
|
||||
})
|
||||
|
||||
mockConnectToMCP.mockResolvedValue(null)
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockConnectToMCP).toHaveBeenCalledWith('fake-token')
|
||||
expect(mockSimpleInference).toHaveBeenCalled()
|
||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
'MCP connection failed, falling back to simple inference'
|
||||
)
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('properly integrates with loadContentFromFileOrInput', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const promptContent = 'File-based prompt'
|
||||
const systemPromptContent = 'File-based system prompt'
|
||||
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent,
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt-file': systemPromptFile,
|
||||
'enable-mcp': 'false'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||
systemPrompt: systemPromptContent,
|
||||
prompt: promptContent,
|
||||
modelName: 'gpt-4',
|
||||
maxTokens: 100,
|
||||
endpoint: 'https://api.test.com',
|
||||
token: 'fake-token'
|
||||
})
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('handles non-existent prompt-file with an error', async () => {
|
||||
const promptFile = 'non-existent-prompt.txt'
|
||||
|
||||
// Mock the file not existing
|
||||
mockFileContent({}, [promptFile])
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
'prompt-file': promptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the error was correctly reported
|
||||
expect(core.setFailed).toHaveBeenCalledWith(
|
||||
`File for prompt-file was not found: ${promptFile}`
|
||||
)
|
||||
})
|
||||
|
||||
it('prefers prompt-file over prompt when both are provided', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptFileContent = 'This is a prompt from a file that should be used'
|
||||
const promptString = 'This is a direct prompt that should be ignored'
|
||||
|
||||
// Set up mock to return specific content for the prompt file
|
||||
mockFileContent({
|
||||
[promptFile]: promptFileContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: promptString,
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt': 'You are a test assistant.'
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with the prompt from the file, not the input parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: expect.any(String)
|
||||
},
|
||||
{ role: 'user', content: promptFileContent } // Should use the file content, not the string input
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses system-prompt-file', async () => {
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptContent =
|
||||
'You are a specialized system assistant for testing'
|
||||
|
||||
// Set up mock to return specific content for the system prompt file
|
||||
mockFileContent({
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('handles non-existent system-prompt-file with an error', async () => {
|
||||
const systemPromptFile = 'non-existent-system-prompt.txt'
|
||||
|
||||
// Mock the file not existing
|
||||
mockFileContent({}, [systemPromptFile])
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Verify that the error was correctly reported
|
||||
expect(core.setFailed).toHaveBeenCalledWith(
|
||||
`File for system-prompt-file was not found: ${systemPromptFile}`
|
||||
)
|
||||
})
|
||||
|
||||
it('prefers system-prompt-file over system-prompt when both are provided', async () => {
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptFileContent =
|
||||
'You are a specialized system assistant from file'
|
||||
const systemPromptString =
|
||||
'You are a basic system assistant from input parameter'
|
||||
|
||||
// Set up mock to return specific content for the system prompt file
|
||||
mockFileContent({
|
||||
[systemPromptFile]: systemPromptFileContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt-file': systemPromptFile,
|
||||
'system-prompt': systemPromptString
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with the system prompt from the file, not the input parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPromptFileContent // Should use the file content, not the string input
|
||||
},
|
||||
{ role: 'user', content: 'Hello, AI!' }
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('uses both prompt-file and system-prompt-file together', async () => {
|
||||
const promptFile = 'prompt.txt'
|
||||
const promptContent = 'This is a prompt from a file'
|
||||
const systemPromptFile = 'system-prompt.txt'
|
||||
const systemPromptContent =
|
||||
'You are a specialized system assistant from file'
|
||||
|
||||
// Set up mock to return specific content for both files
|
||||
mockFileContent({
|
||||
[promptFile]: promptContent,
|
||||
[systemPromptFile]: systemPromptContent
|
||||
})
|
||||
|
||||
// Set up input mocks
|
||||
mockInputs({
|
||||
'prompt-file': promptFile,
|
||||
'system-prompt-file': systemPromptFile
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(promptFile)
|
||||
expect(mockExistsSync).toHaveBeenCalledWith(systemPromptFile)
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(promptFile, 'utf-8')
|
||||
expect(mockReadFileSync).toHaveBeenCalledWith(systemPromptFile, 'utf-8')
|
||||
|
||||
// Check that the post call was made with both the prompt and system prompt from files
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPromptContent
|
||||
},
|
||||
{ role: 'user', content: promptContent }
|
||||
],
|
||||
max_tokens: expect.any(Number),
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
|
||||
it('passes custom max-tokens parameter to the model', async () => {
|
||||
const customMaxTokens = 500
|
||||
|
||||
mockInputs({
|
||||
prompt: 'Hello, AI!',
|
||||
'system-prompt': 'You are a test assistant.',
|
||||
'max-tokens': customMaxTokens.toString()
|
||||
})
|
||||
|
||||
await run()
|
||||
|
||||
// Check that the post call was made with the correct max_tokens parameter
|
||||
expect(mockPost).toHaveBeenCalledWith({
|
||||
body: {
|
||||
messages: expect.any(Array),
|
||||
max_tokens: customMaxTokens,
|
||||
model: expect.any(String)
|
||||
}
|
||||
})
|
||||
|
||||
verifyStandardResponse()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
/**
|
||||
* Unit tests for the MCP module, src/mcp.ts
|
||||
*/
|
||||
import { jest } from '@jest/globals'
|
||||
import * as core from '../__fixtures__/core.js'
|
||||
|
||||
// Mock MCP SDK
|
||||
const mockConnect = jest.fn() as jest.MockedFunction<any>
|
||||
const mockListTools = jest.fn() as jest.MockedFunction<any>
|
||||
const mockCallTool = jest.fn() as jest.MockedFunction<any>
|
||||
|
||||
const mockClient = {
|
||||
connect: mockConnect,
|
||||
listTools: mockListTools,
|
||||
callTool: mockCallTool
|
||||
} as any
|
||||
|
||||
jest.unstable_mockModule('@modelcontextprotocol/sdk/client/index.js', () => ({
|
||||
Client: jest.fn(() => mockClient)
|
||||
}))
|
||||
|
||||
jest.unstable_mockModule(
|
||||
'@modelcontextprotocol/sdk/client/streamableHttp.js',
|
||||
() => ({
|
||||
StreamableHTTPClientTransport: jest.fn()
|
||||
})
|
||||
)
|
||||
|
||||
jest.unstable_mockModule('@actions/core', () => core)
|
||||
|
||||
// Import the module being tested
|
||||
const { connectToMCP, executeToolCall, executeToolCalls } = await import(
|
||||
'../src/mcp.js'
|
||||
)
|
||||
|
||||
describe('mcp.ts', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('connectToMCP', () => {
|
||||
it('successfully connects to MCP server and retrieves tools', async () => {
|
||||
const token = 'test-token'
|
||||
const mockTools = [
|
||||
{
|
||||
name: 'test-tool-1',
|
||||
description: 'Test tool 1',
|
||||
inputSchema: { type: 'object', properties: {} }
|
||||
},
|
||||
{
|
||||
name: 'test-tool-2',
|
||||
description: 'Test tool 2',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: { param: { type: 'string' } }
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
mockConnect.mockResolvedValue(undefined)
|
||||
mockListTools.mockResolvedValue({ tools: mockTools })
|
||||
|
||||
const result = await connectToMCP(token)
|
||||
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.client).toBe(mockClient)
|
||||
expect(result?.tools).toHaveLength(2)
|
||||
expect(result?.tools[0]).toEqual({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'test-tool-1',
|
||||
description: 'Test tool 1',
|
||||
parameters: { type: 'object', properties: {} }
|
||||
}
|
||||
})
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Connecting to GitHub MCP server...'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Successfully connected to MCP server'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Retrieved 2 tools from MCP server'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Mapped 2 tools for Azure AI Inference'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns null when connection fails', async () => {
|
||||
const token = 'test-token'
|
||||
const connectionError = new Error('Connection failed')
|
||||
|
||||
mockConnect.mockRejectedValue(connectionError)
|
||||
|
||||
const result = await connectToMCP(token)
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
'Failed to connect to MCP server: Error: Connection failed'
|
||||
)
|
||||
})
|
||||
|
||||
it('handles empty tools list', async () => {
|
||||
const token = 'test-token'
|
||||
|
||||
mockConnect.mockResolvedValue(undefined)
|
||||
mockListTools.mockResolvedValue({ tools: [] })
|
||||
|
||||
const result = await connectToMCP(token)
|
||||
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.tools).toHaveLength(0)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Retrieved 0 tools from MCP server'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Mapped 0 tools for Azure AI Inference'
|
||||
)
|
||||
})
|
||||
|
||||
it('handles undefined tools list', async () => {
|
||||
const token = 'test-token'
|
||||
|
||||
mockConnect.mockResolvedValue(undefined)
|
||||
mockListTools.mockResolvedValue({})
|
||||
|
||||
const result = await connectToMCP(token)
|
||||
|
||||
expect(result).not.toBeNull()
|
||||
expect(result?.tools).toHaveLength(0)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Retrieved 0 tools from MCP server'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('executeToolCall', () => {
|
||||
it('successfully executes a tool call', async () => {
|
||||
const toolCall = {
|
||||
id: 'call-123',
|
||||
function: {
|
||||
name: 'test-tool',
|
||||
arguments: '{"param": "value"}'
|
||||
}
|
||||
}
|
||||
const toolResult = {
|
||||
content: [{ type: 'text', text: 'Tool execution result' }]
|
||||
}
|
||||
|
||||
mockCallTool.mockResolvedValue(toolResult)
|
||||
|
||||
const result = await executeToolCall(mockClient, toolCall)
|
||||
|
||||
expect(mockCallTool).toHaveBeenCalledWith({
|
||||
name: 'test-tool',
|
||||
arguments: { param: 'value' }
|
||||
})
|
||||
expect(result).toEqual({
|
||||
tool_call_id: 'call-123',
|
||||
role: 'tool',
|
||||
name: 'test-tool',
|
||||
content: JSON.stringify(toolResult.content)
|
||||
})
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Executing tool: test-tool with args: {"param": "value"}'
|
||||
)
|
||||
expect(core.info).toHaveBeenCalledWith(
|
||||
'Tool test-tool executed successfully'
|
||||
)
|
||||
})
|
||||
|
||||
it('handles tool execution errors gracefully', async () => {
|
||||
const toolCall = {
|
||||
id: 'call-456',
|
||||
function: {
|
||||
name: 'failing-tool',
|
||||
arguments: '{"param": "value"}'
|
||||
}
|
||||
}
|
||||
const toolError = new Error('Tool execution failed')
|
||||
|
||||
mockCallTool.mockRejectedValue(toolError)
|
||||
|
||||
const result = await executeToolCall(mockClient, toolCall)
|
||||
|
||||
expect(result).toEqual({
|
||||
tool_call_id: 'call-456',
|
||||
role: 'tool',
|
||||
name: 'failing-tool',
|
||||
content: 'Error: Error: Tool execution failed'
|
||||
})
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
'Failed to execute tool failing-tool: Error: Tool execution failed'
|
||||
)
|
||||
})
|
||||
|
||||
it('handles invalid JSON arguments', async () => {
|
||||
const toolCall = {
|
||||
id: 'call-789',
|
||||
function: {
|
||||
name: 'test-tool',
|
||||
arguments: 'invalid-json'
|
||||
}
|
||||
}
|
||||
|
||||
const result = await executeToolCall(mockClient, toolCall)
|
||||
|
||||
expect(result.tool_call_id).toBe('call-789')
|
||||
expect(result.role).toBe('tool')
|
||||
expect(result.name).toBe('test-tool')
|
||||
expect(result.content).toContain('Error:')
|
||||
expect(core.warning).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to execute tool test-tool:')
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('executeToolCalls', () => {
|
||||
it('executes multiple tool calls successfully', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
id: 'call-1',
|
||||
function: { name: 'tool-1', arguments: '{}' }
|
||||
},
|
||||
{
|
||||
id: 'call-2',
|
||||
function: { name: 'tool-2', arguments: '{"param": "value"}' }
|
||||
}
|
||||
]
|
||||
|
||||
mockCallTool
|
||||
.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Result 1' }]
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Result 2' }]
|
||||
})
|
||||
|
||||
const results = await executeToolCalls(mockClient, toolCalls)
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0].tool_call_id).toBe('call-1')
|
||||
expect(results[1].tool_call_id).toBe('call-2')
|
||||
expect(mockCallTool).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
it('handles empty tool calls array', async () => {
|
||||
const results = await executeToolCalls(mockClient, [])
|
||||
|
||||
expect(results).toHaveLength(0)
|
||||
expect(mockCallTool).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('continues execution even if some tools fail', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
id: 'call-1',
|
||||
function: { name: 'tool-1', arguments: '{}' }
|
||||
},
|
||||
{
|
||||
id: 'call-2',
|
||||
function: { name: 'tool-2', arguments: '{}' }
|
||||
}
|
||||
]
|
||||
|
||||
mockCallTool
|
||||
.mockResolvedValueOnce({
|
||||
content: [{ type: 'text', text: 'Result 1' }]
|
||||
})
|
||||
.mockRejectedValueOnce(new Error('Tool 2 failed'))
|
||||
|
||||
const results = await executeToolCalls(mockClient, toolCalls)
|
||||
|
||||
expect(results).toHaveLength(2)
|
||||
expect(results[0].content).toContain('Result 1')
|
||||
expect(results[1].content).toContain('Error:')
|
||||
})
|
||||
})
|
||||
})
|
||||
+7298
-7236
File diff suppressed because it is too large
Load Diff
+1
-1
File diff suppressed because one or more lines are too long
@@ -0,0 +1,31 @@
|
||||
import * as core from '@actions/core'
|
||||
import * as fs from 'fs'
|
||||
|
||||
/**
|
||||
* Helper function to load content from a file or use fallback input
|
||||
* @param filePathInput - Input name for the file path
|
||||
* @param contentInput - Input name for the direct content
|
||||
* @param defaultValue - Default value to use if neither file nor content is provided
|
||||
* @returns The loaded content
|
||||
*/
|
||||
export function loadContentFromFileOrInput(
|
||||
filePathInput: string,
|
||||
contentInput: string,
|
||||
defaultValue?: string
|
||||
): string {
|
||||
const filePath = core.getInput(filePathInput)
|
||||
const contentString = core.getInput(contentInput)
|
||||
|
||||
if (filePath !== undefined && filePath !== '') {
|
||||
if (!fs.existsSync(filePath)) {
|
||||
throw new Error(`File for ${filePathInput} was not found: ${filePath}`)
|
||||
}
|
||||
return fs.readFileSync(filePath, 'utf-8')
|
||||
} else if (contentString !== undefined && contentString !== '') {
|
||||
return contentString
|
||||
} else if (defaultValue !== undefined) {
|
||||
return defaultValue
|
||||
} else {
|
||||
throw new Error(`Neither ${filePathInput} nor ${contentInput} was set`)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
import * as core from '@actions/core'
|
||||
import ModelClient, { isUnexpected } from '@azure-rest/ai-inference'
|
||||
import { AzureKeyCredential } from '@azure/core-auth'
|
||||
import { MCPClient, executeToolCalls } from './mcp.js'
|
||||
|
||||
export interface InferenceRequest {
|
||||
systemPrompt: string
|
||||
prompt: string
|
||||
modelName: string
|
||||
maxTokens: number
|
||||
endpoint: string
|
||||
token: string
|
||||
}
|
||||
|
||||
export interface InferenceResponse {
|
||||
content: string | null
|
||||
toolCalls?: any[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple one-shot inference without tools
|
||||
*/
|
||||
export async function simpleInference(
|
||||
request: InferenceRequest
|
||||
): Promise<string | null> {
|
||||
core.info('Running simple inference without tools')
|
||||
|
||||
const client = ModelClient(
|
||||
request.endpoint,
|
||||
new AzureKeyCredential(request.token),
|
||||
{
|
||||
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
|
||||
}
|
||||
)
|
||||
|
||||
const requestBody = {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: request.systemPrompt
|
||||
},
|
||||
{ role: 'user', content: request.prompt }
|
||||
],
|
||||
max_tokens: request.maxTokens,
|
||||
model: request.modelName
|
||||
}
|
||||
|
||||
const response = await client.path('/chat/completions').post({
|
||||
body: requestBody
|
||||
})
|
||||
|
||||
if (isUnexpected(response)) {
|
||||
throw new Error(
|
||||
'An error occurred while fetching the response (' +
|
||||
response.status +
|
||||
'): ' +
|
||||
response.body
|
||||
)
|
||||
}
|
||||
|
||||
const modelResponse = response.body.choices[0].message.content
|
||||
core.info(`Model response: ${modelResponse || 'No response content'}`)
|
||||
|
||||
return modelResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* MCP-enabled inference with tool execution loop
|
||||
*/
|
||||
export async function mcpInference(
|
||||
request: InferenceRequest,
|
||||
mcpClient: MCPClient
|
||||
): Promise<string | null> {
|
||||
core.info('Running MCP inference with tools')
|
||||
|
||||
const client = ModelClient(
|
||||
request.endpoint,
|
||||
new AzureKeyCredential(request.token),
|
||||
{
|
||||
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
|
||||
}
|
||||
)
|
||||
|
||||
// Start with the initial conversation
|
||||
let messages: any[] = [
|
||||
{
|
||||
role: 'system',
|
||||
content: request.systemPrompt
|
||||
},
|
||||
{ role: 'user', content: request.prompt }
|
||||
]
|
||||
|
||||
let iterationCount = 0
|
||||
const maxIterations = 5 // Prevent infinite loops
|
||||
|
||||
while (iterationCount < maxIterations) {
|
||||
iterationCount++
|
||||
core.info(`MCP inference iteration ${iterationCount}`)
|
||||
|
||||
const requestBody = {
|
||||
messages: messages,
|
||||
max_tokens: request.maxTokens,
|
||||
model: request.modelName,
|
||||
tools: mcpClient.tools
|
||||
}
|
||||
|
||||
const response = await client.path('/chat/completions').post({
|
||||
body: requestBody
|
||||
})
|
||||
|
||||
if (isUnexpected(response)) {
|
||||
throw new Error(
|
||||
'An error occurred while fetching the response (' +
|
||||
response.status +
|
||||
'): ' +
|
||||
response.body
|
||||
)
|
||||
}
|
||||
|
||||
const assistantMessage = response.body.choices[0].message
|
||||
const modelResponse = assistantMessage.content
|
||||
const toolCalls = assistantMessage.tool_calls
|
||||
|
||||
core.info(`Model response: ${modelResponse || 'No response content'}`)
|
||||
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: modelResponse,
|
||||
...(toolCalls && { tool_calls: toolCalls })
|
||||
})
|
||||
|
||||
if (!toolCalls || toolCalls.length === 0) {
|
||||
core.info('No tool calls requested, ending MCP inference loop')
|
||||
return modelResponse
|
||||
}
|
||||
|
||||
core.info(`Model requested ${toolCalls.length} tool calls`)
|
||||
|
||||
const toolResults = await executeToolCalls(mcpClient.client, toolCalls)
|
||||
messages.push(...toolResults)
|
||||
|
||||
core.info('Tool results added, continuing conversation...')
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`MCP inference loop exceeded maximum iterations (${maxIterations})`
|
||||
)
|
||||
|
||||
// Return the last assistant message content
|
||||
const lastAssistantMessage = messages
|
||||
.slice()
|
||||
.reverse()
|
||||
.find((msg) => msg.role === 'assistant')
|
||||
|
||||
return lastAssistantMessage?.content || null
|
||||
}
|
||||
+21
-220
@@ -1,43 +1,13 @@
|
||||
import * as core from '@actions/core'
|
||||
import ModelClient, { isUnexpected } from '@azure-rest/ai-inference'
|
||||
import { AzureKeyCredential } from '@azure/core-auth'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import * as fs from 'fs'
|
||||
import * as os from 'os'
|
||||
import * as path from 'path'
|
||||
import { connectToMCP } from './mcp.js'
|
||||
import { simpleInference, mcpInference, InferenceRequest } from './inference.js'
|
||||
import { loadContentFromFileOrInput } from './helpers.js'
|
||||
|
||||
const RESPONSE_FILE = 'modelResponse.txt'
|
||||
|
||||
/**
|
||||
* Helper function to load content from a file or use fallback input
|
||||
* @param filePathInput - Input name for the file path
|
||||
* @param contentInput - Input name for the direct content
|
||||
* @param defaultValue - Default value to use if neither file nor content is provided
|
||||
* @returns The loaded content
|
||||
*/
|
||||
function loadContentFromFileOrInput(
|
||||
filePathInput: string,
|
||||
contentInput: string,
|
||||
defaultValue?: string
|
||||
): string {
|
||||
const filePath = core.getInput(filePathInput)
|
||||
const contentString = core.getInput(contentInput)
|
||||
|
||||
if (filePath !== undefined && filePath !== '') {
|
||||
if (!fs.existsSync(filePath)) {
|
||||
throw new Error(`File for ${filePathInput} was not found: ${filePath}`)
|
||||
}
|
||||
return fs.readFileSync(filePath, 'utf-8')
|
||||
} else if (contentString !== undefined && contentString !== '') {
|
||||
return contentString
|
||||
} else if (defaultValue !== undefined) {
|
||||
return defaultValue
|
||||
} else {
|
||||
throw new Error(`Neither ${filePathInput} nor ${contentInput} was set`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The main function for the action.
|
||||
*
|
||||
@@ -45,10 +15,8 @@ function loadContentFromFileOrInput(
|
||||
*/
|
||||
export async function run(): Promise<void> {
|
||||
try {
|
||||
// Load prompt content - required
|
||||
const prompt = loadContentFromFileOrInput('prompt-file', 'prompt')
|
||||
|
||||
// Load system prompt with default value
|
||||
const systemPrompt = loadContentFromFileOrInput(
|
||||
'system-prompt-file',
|
||||
'system-prompt',
|
||||
@@ -64,200 +32,34 @@ export async function run(): Promise<void> {
|
||||
}
|
||||
|
||||
const endpoint = core.getInput('endpoint')
|
||||
|
||||
// Get MCP server configuration
|
||||
const mcpServerUrl = 'https://api.githubcopilot.com/mcp/'
|
||||
const enableMcp = core.getBooleanInput('enable-mcp') || false
|
||||
|
||||
let azureTools: any[] = []
|
||||
let mcp: Client | null = null
|
||||
const inferenceRequest: InferenceRequest = {
|
||||
systemPrompt,
|
||||
prompt,
|
||||
modelName,
|
||||
maxTokens,
|
||||
endpoint,
|
||||
token
|
||||
}
|
||||
|
||||
// Connect to MCP server if enabled
|
||||
if (enableMcp || true) {
|
||||
core.info('Connecting to GitHub MCP server...' + token)
|
||||
let modelResponse: string | null = null
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerUrl),
|
||||
{
|
||||
requestInit: {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
if (enableMcp) {
|
||||
const mcpClient = await connectToMCP(token)
|
||||
|
||||
mcp = new Client({
|
||||
name: 'ai-inference-action',
|
||||
version: '1.0.0',
|
||||
transport
|
||||
})
|
||||
|
||||
try {
|
||||
await mcp.connect(transport)
|
||||
} catch (mcpError) {
|
||||
core.warning(`Failed to connect to MCP server: ${mcpError}`)
|
||||
// Continue without tools if MCP connection fails
|
||||
return
|
||||
if (mcpClient) {
|
||||
modelResponse = await mcpInference(inferenceRequest, mcpClient)
|
||||
} else {
|
||||
core.warning('MCP connection failed, falling back to simple inference')
|
||||
modelResponse = await simpleInference(inferenceRequest)
|
||||
}
|
||||
|
||||
core.info('Successfully connected to MCP server')
|
||||
|
||||
// Pull tool metadata
|
||||
const tools = await mcp.listTools()
|
||||
core.info(`Retrieved ${tools.tools?.length || 0} tools from MCP server`)
|
||||
|
||||
// Map MCP → Azure tool definitions
|
||||
azureTools = (tools.tools || []).map((t) => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.inputSchema
|
||||
}
|
||||
}))
|
||||
|
||||
core.info(`Mapped ${azureTools.length} tools for Azure AI Inference`)
|
||||
} else {
|
||||
modelResponse = await simpleInference(inferenceRequest)
|
||||
}
|
||||
|
||||
const client = ModelClient(endpoint, new AzureKeyCredential(token), {
|
||||
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
|
||||
})
|
||||
|
||||
const requestBody: any = {
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPrompt
|
||||
},
|
||||
{ role: 'user', content: prompt }
|
||||
],
|
||||
max_tokens: maxTokens,
|
||||
model: modelName
|
||||
}
|
||||
|
||||
// Add tools if available
|
||||
if (azureTools.length > 0) {
|
||||
requestBody.tools = azureTools
|
||||
}
|
||||
|
||||
const response = await client.path('/chat/completions').post({
|
||||
body: requestBody
|
||||
})
|
||||
|
||||
if (isUnexpected(response)) {
|
||||
throw new Error(
|
||||
'An error occurred while fetching the response (' +
|
||||
response.status +
|
||||
'): ' +
|
||||
response.body
|
||||
)
|
||||
}
|
||||
|
||||
let modelResponse: string | null =
|
||||
response.body.choices[0].message.content
|
||||
|
||||
core.info(`Model response: ${response || 'No response content'}`)
|
||||
|
||||
// Handle tool calls if present
|
||||
const toolCalls = response.body.choices[0].message.tool_calls
|
||||
if (toolCalls && toolCalls.length > 0 && mcp) {
|
||||
core.info(`Model requested ${toolCalls.length} tool calls`)
|
||||
|
||||
// Execute tool calls via MCP and continue the conversation
|
||||
const toolResults: any[] = []
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
core.info(
|
||||
`Executing tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`
|
||||
)
|
||||
|
||||
try {
|
||||
// Parse the arguments from JSON string
|
||||
const args = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Call the tool via MCP
|
||||
const result = await mcp.callTool({
|
||||
name: toolCall.function.name,
|
||||
arguments: args
|
||||
})
|
||||
|
||||
core.info(`Tool ${toolCall.function.name} executed successfully`)
|
||||
|
||||
// Store the result for the follow-up conversation
|
||||
toolResults.push({
|
||||
tool_call_id: toolCall.id,
|
||||
role: 'tool',
|
||||
name: toolCall.function.name,
|
||||
content: JSON.stringify(result.content)
|
||||
})
|
||||
|
||||
} catch (toolError) {
|
||||
core.warning(`Failed to execute tool ${toolCall.function.name}: ${toolError}`)
|
||||
|
||||
// Add error result to continue conversation
|
||||
toolResults.push({
|
||||
tool_call_id: toolCall.id,
|
||||
role: 'tool',
|
||||
name: toolCall.function.name,
|
||||
content: `Error: ${toolError}`
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// If we have tool results, continue the conversation
|
||||
if (toolResults.length > 0) {
|
||||
core.info('Continuing conversation with tool results...')
|
||||
|
||||
// Build the follow-up request with the original conversation + tool results
|
||||
const followUpMessages = [
|
||||
{
|
||||
role: 'system',
|
||||
content: systemPrompt
|
||||
},
|
||||
{ role: 'user', content: prompt },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: modelResponse,
|
||||
tool_calls: toolCalls
|
||||
},
|
||||
...toolResults
|
||||
]
|
||||
|
||||
const followUpRequest: any = {
|
||||
messages: followUpMessages,
|
||||
max_tokens: maxTokens,
|
||||
model: modelName
|
||||
}
|
||||
|
||||
// Add tools again for potential follow-up tool calls
|
||||
if (azureTools.length > 0) {
|
||||
followUpRequest.tools = azureTools
|
||||
}
|
||||
|
||||
const followUpResponse = await client.path('/chat/completions').post({
|
||||
body: followUpRequest
|
||||
})
|
||||
|
||||
if (isUnexpected(followUpResponse)) {
|
||||
core.warning(
|
||||
'Failed to get follow-up response after tool execution: ' +
|
||||
followUpResponse.status + ': ' + followUpResponse.body
|
||||
)
|
||||
} else {
|
||||
const finalResponse = followUpResponse.body.choices[0].message.content
|
||||
core.info(`Final response after tool execution: ${finalResponse}`)
|
||||
|
||||
// Update the model response to the final one
|
||||
modelResponse = finalResponse || modelResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set outputs for other workflow steps to use
|
||||
core.setOutput('response', modelResponse || '')
|
||||
|
||||
// Save the response to a file in case the response overflow the output limit
|
||||
const responseFilePath = path.join(tempDir(), RESPONSE_FILE)
|
||||
core.setOutput('response-file', responseFilePath)
|
||||
|
||||
@@ -265,7 +67,6 @@ export async function run(): Promise<void> {
|
||||
fs.writeFileSync(responseFilePath, modelResponse, 'utf-8')
|
||||
}
|
||||
} catch (error) {
|
||||
// Fail the workflow run if an error occurs
|
||||
if (error instanceof Error) {
|
||||
core.setFailed(error.message)
|
||||
} else {
|
||||
|
||||
+129
@@ -0,0 +1,129 @@
|
||||
import * as core from '@actions/core'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
|
||||
export interface ToolResult {
|
||||
tool_call_id: string
|
||||
role: 'tool'
|
||||
name: string
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface MCPClient {
|
||||
client: Client
|
||||
tools: any[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to the MCP server and retrieve available tools
|
||||
*/
|
||||
export async function connectToMCP(token: string): Promise<MCPClient | null> {
|
||||
const mcpServerUrl = 'https://api.githubcopilot.com/mcp/'
|
||||
|
||||
core.info('Connecting to GitHub MCP server...')
|
||||
|
||||
const transport = new StreamableHTTPClientTransport(new URL(mcpServerUrl), {
|
||||
requestInit: {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
const client = new Client({
|
||||
name: 'ai-inference-action',
|
||||
version: '1.0.0',
|
||||
transport
|
||||
})
|
||||
|
||||
try {
|
||||
await client.connect(transport)
|
||||
} catch (mcpError) {
|
||||
core.warning(`Failed to connect to MCP server: ${mcpError}`)
|
||||
return null
|
||||
}
|
||||
|
||||
core.info('Successfully connected to MCP server')
|
||||
|
||||
// Pull tool metadata
|
||||
const toolsResponse = await client.listTools()
|
||||
core.info(
|
||||
`Retrieved ${toolsResponse.tools?.length || 0} tools from MCP server`
|
||||
)
|
||||
|
||||
// Map MCP → Azure tool definitions
|
||||
const tools = (toolsResponse.tools || []).map((t) => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.inputSchema
|
||||
}
|
||||
}))
|
||||
|
||||
core.info(`Mapped ${tools.length} tools for Azure AI Inference`)
|
||||
|
||||
return { client, tools }
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a single tool call via MCP
|
||||
*/
|
||||
export async function executeToolCall(
|
||||
mcpClient: Client,
|
||||
toolCall: any
|
||||
): Promise<ToolResult> {
|
||||
core.info(
|
||||
`Executing tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`
|
||||
)
|
||||
|
||||
try {
|
||||
// Parse the arguments from JSON string
|
||||
const args = JSON.parse(toolCall.function.arguments)
|
||||
|
||||
// Call the tool via MCP
|
||||
const result = await mcpClient.callTool({
|
||||
name: toolCall.function.name,
|
||||
arguments: args
|
||||
})
|
||||
|
||||
core.info(`Tool ${toolCall.function.name} executed successfully`)
|
||||
|
||||
// Return the result formatted for the conversation
|
||||
return {
|
||||
tool_call_id: toolCall.id,
|
||||
role: 'tool',
|
||||
name: toolCall.function.name,
|
||||
content: JSON.stringify(result.content)
|
||||
}
|
||||
} catch (toolError) {
|
||||
core.warning(
|
||||
`Failed to execute tool ${toolCall.function.name}: ${toolError}`
|
||||
)
|
||||
|
||||
// Return error result to continue conversation
|
||||
return {
|
||||
tool_call_id: toolCall.id,
|
||||
role: 'tool',
|
||||
name: toolCall.function.name,
|
||||
content: `Error: ${toolError}`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute all tool calls from a response
|
||||
*/
|
||||
export async function executeToolCalls(
|
||||
mcpClient: Client,
|
||||
toolCalls: any[]
|
||||
): Promise<ToolResult[]> {
|
||||
const toolResults: ToolResult[] = []
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
const result = await executeToolCall(mcpClient, toolCall)
|
||||
toolResults.push(result)
|
||||
}
|
||||
|
||||
return toolResults
|
||||
}
|
||||
Reference in New Issue
Block a user