chore: use github's shared prettier-config
This commit is contained in:
@@ -83,8 +83,7 @@ jobs:
|
|||||||
run: echo "hello" > prompt.txt
|
run: echo "hello" > prompt.txt
|
||||||
|
|
||||||
- name: Create System Prompt File
|
- name: Create System Prompt File
|
||||||
run:
|
run: echo "You are a helpful AI assistant for testing." > system-prompt.txt
|
||||||
echo "You are a helpful AI assistant for testing." > system-prompt.txt
|
|
||||||
|
|
||||||
- name: Test Local Action with Prompt File
|
- name: Test Local Action with Prompt File
|
||||||
id: test-action-prompt-file
|
id: test-action-prompt-file
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ permissions:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
update_tag:
|
update_tag:
|
||||||
name:
|
name: Update the major tag to include the ${{ github.event.release.tag_name }}
|
||||||
Update the major tag to include the ${{ github.event.release.tag_name }}
|
|
||||||
changes
|
changes
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
# See: https://prettier.io/docs/en/configuration
|
|
||||||
|
|
||||||
printWidth: 80
|
|
||||||
tabWidth: 2
|
|
||||||
useTabs: false
|
|
||||||
semi: false
|
|
||||||
singleQuote: true
|
|
||||||
quoteProps: as-needed
|
|
||||||
jsxSingleQuote: false
|
|
||||||
trailingComma: none
|
|
||||||
bracketSpacing: true
|
|
||||||
bracketSameLine: true
|
|
||||||
arrowParens: always
|
|
||||||
proseWrap: always
|
|
||||||
htmlWhitespaceSensitivity: css
|
|
||||||
endOfLine: lf
|
|
||||||
@@ -83,8 +83,7 @@ model: openai/gpt-4o
|
|||||||
```yaml
|
```yaml
|
||||||
messages:
|
messages:
|
||||||
- role: system
|
- role: system
|
||||||
content:
|
content: You are a helpful assistant that describes animals using JSON format
|
||||||
You are a helpful assistant that describes animals using JSON format
|
|
||||||
- role: user
|
- role: user
|
||||||
content: |-
|
content: |-
|
||||||
Describe a {{animal}}
|
Describe a {{animal}}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import type * as core from '@actions/core'
|
import type * as core from '@actions/core'
|
||||||
import { vi } from 'vitest'
|
import {vi} from 'vitest'
|
||||||
|
|
||||||
export const debug = vi.fn<typeof core.debug>()
|
export const debug = vi.fn<typeof core.debug>()
|
||||||
export const error = vi.fn<typeof core.error>()
|
export const error = vi.fn<typeof core.error>()
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
messages:
|
messages:
|
||||||
- role: system
|
- role: system
|
||||||
content:
|
content: You are a helpful assistant that describes animals using JSON format
|
||||||
You are a helpful assistant that describes animals using JSON format
|
|
||||||
- role: user
|
- role: user
|
||||||
content: |-
|
content: |-
|
||||||
Describe a {{animal}}
|
Describe a {{animal}}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
import { vi } from 'vitest'
|
import {vi} from 'vitest'
|
||||||
|
|
||||||
export const wait = vi.fn<typeof import('../src/wait.js').wait>()
|
export const wait = vi.fn<typeof import('../src/wait.js').wait>()
|
||||||
|
|||||||
@@ -1,41 +1,37 @@
|
|||||||
import { describe, it, expect } from 'vitest'
|
import {describe, it, expect} from 'vitest'
|
||||||
import {
|
import {buildMessages, buildResponseFormat, buildInferenceRequest} from '../src/helpers'
|
||||||
buildMessages,
|
import {PromptConfig} from '../src/prompt'
|
||||||
buildResponseFormat,
|
|
||||||
buildInferenceRequest
|
|
||||||
} from '../src/helpers'
|
|
||||||
import { PromptConfig } from '../src/prompt'
|
|
||||||
|
|
||||||
describe('helpers.ts - inference request building', () => {
|
describe('helpers.ts - inference request building', () => {
|
||||||
describe('buildMessages', () => {
|
describe('buildMessages', () => {
|
||||||
it('should build messages from prompt config', () => {
|
it('should build messages from prompt config', () => {
|
||||||
const promptConfig: PromptConfig = {
|
const promptConfig: PromptConfig = {
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'System message' },
|
{role: 'system', content: 'System message'},
|
||||||
{ role: 'user', content: 'User message' }
|
{role: 'user', content: 'User message'},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = buildMessages(promptConfig)
|
const result = buildMessages(promptConfig)
|
||||||
expect(result).toEqual([
|
expect(result).toEqual([
|
||||||
{ role: 'system', content: 'System message' },
|
{role: 'system', content: 'System message'},
|
||||||
{ role: 'user', content: 'User message' }
|
{role: 'user', content: 'User message'},
|
||||||
])
|
])
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should build messages from legacy format', () => {
|
it('should build messages from legacy format', () => {
|
||||||
const result = buildMessages(undefined, 'System prompt', 'User prompt')
|
const result = buildMessages(undefined, 'System prompt', 'User prompt')
|
||||||
expect(result).toEqual([
|
expect(result).toEqual([
|
||||||
{ role: 'system', content: 'System prompt' },
|
{role: 'system', content: 'System prompt'},
|
||||||
{ role: 'user', content: 'User prompt' }
|
{role: 'user', content: 'User prompt'},
|
||||||
])
|
])
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should use default system prompt when none provided', () => {
|
it('should use default system prompt when none provided', () => {
|
||||||
const result = buildMessages(undefined, undefined, 'User prompt')
|
const result = buildMessages(undefined, undefined, 'User prompt')
|
||||||
expect(result).toEqual([
|
expect(result).toEqual([
|
||||||
{ role: 'system', content: 'You are a helpful assistant' },
|
{role: 'system', content: 'You are a helpful assistant'},
|
||||||
{ role: 'user', content: 'User prompt' }
|
{role: 'user', content: 'User prompt'},
|
||||||
])
|
])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -47,8 +43,8 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
responseFormat: 'json_schema',
|
responseFormat: 'json_schema',
|
||||||
jsonSchema: JSON.stringify({
|
jsonSchema: JSON.stringify({
|
||||||
name: 'test_schema',
|
name: 'test_schema',
|
||||||
schema: { type: 'object' }
|
schema: {type: 'object'},
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = buildResponseFormat(promptConfig)
|
const result = buildResponseFormat(promptConfig)
|
||||||
@@ -56,15 +52,15 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
type: 'json_schema',
|
type: 'json_schema',
|
||||||
json_schema: {
|
json_schema: {
|
||||||
name: 'test_schema',
|
name: 'test_schema',
|
||||||
schema: { type: 'object' }
|
schema: {type: 'object'},
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should return undefined for text format', () => {
|
it('should return undefined for text format', () => {
|
||||||
const promptConfig: PromptConfig = {
|
const promptConfig: PromptConfig = {
|
||||||
messages: [],
|
messages: [],
|
||||||
responseFormat: 'text'
|
responseFormat: 'text',
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = buildResponseFormat(promptConfig)
|
const result = buildResponseFormat(promptConfig)
|
||||||
@@ -73,7 +69,7 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
|
|
||||||
it('should return undefined when no response format specified', () => {
|
it('should return undefined when no response format specified', () => {
|
||||||
const promptConfig: PromptConfig = {
|
const promptConfig: PromptConfig = {
|
||||||
messages: []
|
messages: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = buildResponseFormat(promptConfig)
|
const result = buildResponseFormat(promptConfig)
|
||||||
@@ -84,12 +80,10 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
const promptConfig: PromptConfig = {
|
const promptConfig: PromptConfig = {
|
||||||
messages: [],
|
messages: [],
|
||||||
responseFormat: 'json_schema',
|
responseFormat: 'json_schema',
|
||||||
jsonSchema: 'invalid json'
|
jsonSchema: 'invalid json',
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(() => buildResponseFormat(promptConfig)).toThrow(
|
expect(() => buildResponseFormat(promptConfig)).toThrow('Invalid JSON schema')
|
||||||
'Invalid JSON schema'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -97,14 +91,14 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
it('should build complete inference request from prompt config', () => {
|
it('should build complete inference request from prompt config', () => {
|
||||||
const promptConfig: PromptConfig = {
|
const promptConfig: PromptConfig = {
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'System message' },
|
{role: 'system', content: 'System message'},
|
||||||
{ role: 'user', content: 'User message' }
|
{role: 'user', content: 'User message'},
|
||||||
],
|
],
|
||||||
responseFormat: 'json_schema',
|
responseFormat: 'json_schema',
|
||||||
jsonSchema: JSON.stringify({
|
jsonSchema: JSON.stringify({
|
||||||
name: 'test_schema',
|
name: 'test_schema',
|
||||||
schema: { type: 'object' }
|
schema: {type: 'object'},
|
||||||
})
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = buildInferenceRequest(
|
const result = buildInferenceRequest(
|
||||||
@@ -114,13 +108,13 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
'gpt-4',
|
'gpt-4',
|
||||||
100,
|
100,
|
||||||
'https://api.test.com',
|
'https://api.test.com',
|
||||||
'test-token'
|
'test-token',
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'System message' },
|
{role: 'system', content: 'System message'},
|
||||||
{ role: 'user', content: 'User message' }
|
{role: 'user', content: 'User message'},
|
||||||
],
|
],
|
||||||
modelName: 'gpt-4',
|
modelName: 'gpt-4',
|
||||||
maxTokens: 100,
|
maxTokens: 100,
|
||||||
@@ -130,9 +124,9 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
type: 'json_schema',
|
type: 'json_schema',
|
||||||
json_schema: {
|
json_schema: {
|
||||||
name: 'test_schema',
|
name: 'test_schema',
|
||||||
schema: { type: 'object' }
|
schema: {type: 'object'},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -144,19 +138,19 @@ describe('helpers.ts - inference request building', () => {
|
|||||||
'gpt-4',
|
'gpt-4',
|
||||||
100,
|
100,
|
||||||
'https://api.test.com',
|
'https://api.test.com',
|
||||||
'test-token'
|
'test-token',
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'System prompt' },
|
{role: 'system', content: 'System prompt'},
|
||||||
{ role: 'user', content: 'User prompt' }
|
{role: 'user', content: 'User prompt'},
|
||||||
],
|
],
|
||||||
modelName: 'gpt-4',
|
modelName: 'gpt-4',
|
||||||
maxTokens: 100,
|
maxTokens: 100,
|
||||||
endpoint: 'https://api.test.com',
|
endpoint: 'https://api.test.com',
|
||||||
token: 'test-token',
|
token: 'test-token',
|
||||||
responseFormat: undefined
|
responseFormat: undefined,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { vi, it, expect, beforeEach, describe } from 'vitest'
|
import {vi, it, expect, beforeEach, describe} from 'vitest'
|
||||||
import * as core from '../__fixtures__/core.js'
|
import * as core from '../__fixtures__/core.js'
|
||||||
|
|
||||||
const mockExistsSync = vi.fn()
|
const mockExistsSync = vi.fn()
|
||||||
@@ -6,12 +6,12 @@ const mockReadFileSync = vi.fn()
|
|||||||
|
|
||||||
vi.mock('fs', () => ({
|
vi.mock('fs', () => ({
|
||||||
existsSync: mockExistsSync,
|
existsSync: mockExistsSync,
|
||||||
readFileSync: mockReadFileSync
|
readFileSync: mockReadFileSync,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@actions/core', () => core)
|
vi.mock('@actions/core', () => core)
|
||||||
|
|
||||||
const { loadContentFromFileOrInput } = await import('../src/helpers.js')
|
const {loadContentFromFileOrInput} = await import('../src/helpers.js')
|
||||||
|
|
||||||
describe('helpers.ts', () => {
|
describe('helpers.ts', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -103,11 +103,7 @@ describe('helpers.ts', () => {
|
|||||||
|
|
||||||
core.getInput.mockImplementation(() => '')
|
core.getInput.mockImplementation(() => '')
|
||||||
|
|
||||||
const result = loadContentFromFileOrInput(
|
const result = loadContentFromFileOrInput('file-input', 'content-input', defaultValue)
|
||||||
'file-input',
|
|
||||||
'content-input',
|
|
||||||
defaultValue
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(result).toBe(defaultValue)
|
expect(result).toBe(defaultValue)
|
||||||
expect(mockExistsSync).not.toHaveBeenCalled()
|
expect(mockExistsSync).not.toHaveBeenCalled()
|
||||||
@@ -131,11 +127,7 @@ describe('helpers.ts', () => {
|
|||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
core.getInput.mockImplementation(() => undefined as any)
|
core.getInput.mockImplementation(() => undefined as any)
|
||||||
|
|
||||||
const result = loadContentFromFileOrInput(
|
const result = loadContentFromFileOrInput('file-input', 'content-input', defaultValue)
|
||||||
'file-input',
|
|
||||||
'content-input',
|
|
||||||
defaultValue
|
|
||||||
)
|
|
||||||
|
|
||||||
expect(result).toBe(defaultValue)
|
expect(result).toBe(defaultValue)
|
||||||
})
|
})
|
||||||
|
|||||||
+87
-113
@@ -1,48 +1,41 @@
|
|||||||
import {
|
import {vi, type MockedFunction, beforeEach, expect, describe, it} from 'vitest'
|
||||||
vi,
|
|
||||||
type MockedFunction,
|
|
||||||
beforeEach,
|
|
||||||
expect,
|
|
||||||
describe,
|
|
||||||
it
|
|
||||||
} from 'vitest'
|
|
||||||
import * as core from '../__fixtures__/core.js'
|
import * as core from '../__fixtures__/core.js'
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
const mockPost = vi.fn() as MockedFunction<any>
|
const mockPost = vi.fn() as MockedFunction<any>
|
||||||
const mockPath = vi.fn(() => ({ post: mockPost }))
|
const mockPath = vi.fn(() => ({post: mockPost}))
|
||||||
const mockClient = vi.fn(() => ({ path: mockPath }))
|
const mockClient = vi.fn(() => ({path: mockPath}))
|
||||||
|
|
||||||
vi.mock('@azure-rest/ai-inference', () => ({
|
vi.mock('@azure-rest/ai-inference', () => ({
|
||||||
default: mockClient,
|
default: mockClient,
|
||||||
isUnexpected: vi.fn(() => false)
|
isUnexpected: vi.fn(() => false),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@azure/core-auth', () => ({
|
vi.mock('@azure/core-auth', () => ({
|
||||||
AzureKeyCredential: vi.fn()
|
AzureKeyCredential: vi.fn(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
const mockExecuteToolCalls = vi.fn() as MockedFunction<any>
|
const mockExecuteToolCalls = vi.fn() as MockedFunction<any>
|
||||||
vi.mock('../src/mcp.js', () => ({
|
vi.mock('../src/mcp.js', () => ({
|
||||||
executeToolCalls: mockExecuteToolCalls
|
executeToolCalls: mockExecuteToolCalls,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@actions/core', () => core)
|
vi.mock('@actions/core', () => core)
|
||||||
|
|
||||||
// Import the module being tested
|
// Import the module being tested
|
||||||
const { simpleInference, mcpInference } = await import('../src/inference.js')
|
const {simpleInference, mcpInference} = await import('../src/inference.js')
|
||||||
|
|
||||||
describe('inference.ts', () => {
|
describe('inference.ts', () => {
|
||||||
const mockRequest = {
|
const mockRequest = {
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'You are a test assistant' },
|
{role: 'system', content: 'You are a test assistant'},
|
||||||
{ role: 'user', content: 'Hello, AI!' }
|
{role: 'user', content: 'Hello, AI!'},
|
||||||
],
|
],
|
||||||
modelName: 'gpt-4',
|
modelName: 'gpt-4',
|
||||||
maxTokens: 100,
|
maxTokens: 100,
|
||||||
endpoint: 'https://api.test.com',
|
endpoint: 'https://api.test.com',
|
||||||
token: 'test-token'
|
token: 'test-token',
|
||||||
}
|
}
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -56,11 +49,11 @@ describe('inference.ts', () => {
|
|||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Hello, user!'
|
content: 'Hello, user!',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost.mockResolvedValue(mockResponse)
|
mockPost.mockResolvedValue(mockResponse)
|
||||||
@@ -68,9 +61,7 @@ describe('inference.ts', () => {
|
|||||||
const result = await simpleInference(mockRequest)
|
const result = await simpleInference(mockRequest)
|
||||||
|
|
||||||
expect(result).toBe('Hello, user!')
|
expect(result).toBe('Hello, user!')
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Running simple inference without tools')
|
||||||
'Running simple inference without tools'
|
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith('Model response: Hello, user!')
|
expect(core.info).toHaveBeenCalledWith('Model response: Hello, user!')
|
||||||
|
|
||||||
// Verify the request structure
|
// Verify the request structure
|
||||||
@@ -79,16 +70,16 @@ describe('inference.ts', () => {
|
|||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: 'You are a test assistant'
|
content: 'You are a test assistant',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Hello, AI!'
|
content: 'Hello, AI!',
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
max_tokens: 100,
|
max_tokens: 100,
|
||||||
model: 'gpt-4'
|
model: 'gpt-4',
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -98,11 +89,11 @@ describe('inference.ts', () => {
|
|||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: null
|
content: null,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost.mockResolvedValue(mockResponse)
|
mockPost.mockResolvedValue(mockResponse)
|
||||||
@@ -110,9 +101,7 @@ describe('inference.ts', () => {
|
|||||||
const result = await simpleInference(mockRequest)
|
const result = await simpleInference(mockRequest)
|
||||||
|
|
||||||
expect(result).toBeNull()
|
expect(result).toBeNull()
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Model response: No response content')
|
||||||
'Model response: No response content'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -126,10 +115,10 @@ describe('inference.ts', () => {
|
|||||||
function: {
|
function: {
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
description: 'A test tool',
|
description: 'A test tool',
|
||||||
parameters: { type: 'object' }
|
parameters: {type: 'object'},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
it('performs inference without tool calls', async () => {
|
it('performs inference without tool calls', async () => {
|
||||||
@@ -139,11 +128,11 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Hello, user!',
|
content: 'Hello, user!',
|
||||||
tool_calls: null
|
tool_calls: null,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost.mockResolvedValue(mockResponse)
|
mockPost.mockResolvedValue(mockResponse)
|
||||||
@@ -151,13 +140,9 @@ describe('inference.ts', () => {
|
|||||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||||
|
|
||||||
expect(result).toBe('Hello, user!')
|
expect(result).toBe('Hello, user!')
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Running GitHub MCP inference with tools')
|
||||||
'Running GitHub MCP inference with tools'
|
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith('MCP inference iteration 1')
|
expect(core.info).toHaveBeenCalledWith('MCP inference iteration 1')
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('No tool calls requested, ending GitHub MCP inference loop')
|
||||||
'No tool calls requested, ending GitHub MCP inference loop'
|
|
||||||
)
|
|
||||||
|
|
||||||
// The MCP inference loop will always add the assistant message, even when there are no tool calls
|
// The MCP inference loop will always add the assistant message, even when there are no tool calls
|
||||||
// So we don't check the exact messages, just that tools were included
|
// So we don't check the exact messages, just that tools were included
|
||||||
@@ -175,9 +160,9 @@ describe('inference.ts', () => {
|
|||||||
id: 'call-123',
|
id: 'call-123',
|
||||||
function: {
|
function: {
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
arguments: '{"param": "value"}'
|
arguments: '{"param": "value"}',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const toolResults = [
|
const toolResults = [
|
||||||
@@ -185,8 +170,8 @@ describe('inference.ts', () => {
|
|||||||
tool_call_id: 'call-123',
|
tool_call_id: 'call-123',
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
content: 'Tool result'
|
content: 'Tool result',
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
// First response with tool calls
|
// First response with tool calls
|
||||||
@@ -196,11 +181,11 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'I need to use a tool.',
|
content: 'I need to use a tool.',
|
||||||
tool_calls: toolCalls
|
tool_calls: toolCalls,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second response after tool execution
|
// Second response after tool execution
|
||||||
@@ -210,26 +195,21 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Here is the final answer.',
|
content: 'Here is the final answer.',
|
||||||
tool_calls: null
|
tool_calls: null,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost
|
mockPost.mockResolvedValueOnce(firstResponse).mockResolvedValueOnce(secondResponse)
|
||||||
.mockResolvedValueOnce(firstResponse)
|
|
||||||
.mockResolvedValueOnce(secondResponse)
|
|
||||||
|
|
||||||
mockExecuteToolCalls.mockResolvedValue(toolResults)
|
mockExecuteToolCalls.mockResolvedValue(toolResults)
|
||||||
|
|
||||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||||
|
|
||||||
expect(result).toBe('Here is the final answer.')
|
expect(result).toBe('Here is the final answer.')
|
||||||
expect(mockExecuteToolCalls).toHaveBeenCalledWith(
|
expect(mockExecuteToolCalls).toHaveBeenCalledWith(mockMcpClient.client, toolCalls)
|
||||||
mockMcpClient.client,
|
|
||||||
toolCalls
|
|
||||||
)
|
|
||||||
expect(mockPost).toHaveBeenCalledTimes(2)
|
expect(mockPost).toHaveBeenCalledTimes(2)
|
||||||
|
|
||||||
// Verify the second call includes the conversation history
|
// Verify the second call includes the conversation history
|
||||||
@@ -247,9 +227,9 @@ describe('inference.ts', () => {
|
|||||||
id: 'call-123',
|
id: 'call-123',
|
||||||
function: {
|
function: {
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
arguments: '{}'
|
arguments: '{}',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const toolResults = [
|
const toolResults = [
|
||||||
@@ -257,8 +237,8 @@ describe('inference.ts', () => {
|
|||||||
tool_call_id: 'call-123',
|
tool_call_id: 'call-123',
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
content: 'Tool result'
|
content: 'Tool result',
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
// Always respond with tool calls to trigger infinite loop
|
// Always respond with tool calls to trigger infinite loop
|
||||||
@@ -268,11 +248,11 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Using tool again.',
|
content: 'Using tool again.',
|
||||||
tool_calls: toolCalls
|
tool_calls: toolCalls,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost.mockResolvedValue(responseWithToolCalls)
|
mockPost.mockResolvedValue(responseWithToolCalls)
|
||||||
@@ -281,9 +261,7 @@ describe('inference.ts', () => {
|
|||||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||||
|
|
||||||
expect(mockPost).toHaveBeenCalledTimes(5) // Max iterations reached
|
expect(mockPost).toHaveBeenCalledTimes(5) // Max iterations reached
|
||||||
expect(core.warning).toHaveBeenCalledWith(
|
expect(core.warning).toHaveBeenCalledWith('GitHub MCP inference loop exceeded maximum iterations (5)')
|
||||||
'GitHub MCP inference loop exceeded maximum iterations (5)'
|
|
||||||
)
|
|
||||||
expect(result).toBe('Using tool again.') // Last assistant message
|
expect(result).toBe('Using tool again.') // Last assistant message
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -294,11 +272,11 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Hello, user!',
|
content: 'Hello, user!',
|
||||||
tool_calls: []
|
tool_calls: [],
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost.mockResolvedValue(mockResponse)
|
mockPost.mockResolvedValue(mockResponse)
|
||||||
@@ -306,9 +284,7 @@ describe('inference.ts', () => {
|
|||||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||||
|
|
||||||
expect(result).toBe('Hello, user!')
|
expect(result).toBe('Hello, user!')
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('No tool calls requested, ending GitHub MCP inference loop')
|
||||||
'No tool calls requested, ending GitHub MCP inference loop'
|
|
||||||
)
|
|
||||||
expect(mockExecuteToolCalls).not.toHaveBeenCalled()
|
expect(mockExecuteToolCalls).not.toHaveBeenCalled()
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -316,8 +292,8 @@ describe('inference.ts', () => {
|
|||||||
const toolCalls = [
|
const toolCalls = [
|
||||||
{
|
{
|
||||||
id: 'call-123',
|
id: 'call-123',
|
||||||
function: { name: 'test-tool', arguments: '{}' }
|
function: {name: 'test-tool', arguments: '{}'},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const firstResponse = {
|
const firstResponse = {
|
||||||
@@ -326,11 +302,11 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'First message',
|
content: 'First message',
|
||||||
tool_calls: toolCalls
|
tool_calls: toolCalls,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
const secondResponse = {
|
const secondResponse = {
|
||||||
@@ -339,24 +315,22 @@ describe('inference.ts', () => {
|
|||||||
{
|
{
|
||||||
message: {
|
message: {
|
||||||
content: 'Second message',
|
content: 'Second message',
|
||||||
tool_calls: toolCalls
|
tool_calls: toolCalls,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockPost
|
mockPost.mockResolvedValueOnce(firstResponse).mockResolvedValue(secondResponse)
|
||||||
.mockResolvedValueOnce(firstResponse)
|
|
||||||
.mockResolvedValue(secondResponse)
|
|
||||||
|
|
||||||
mockExecuteToolCalls.mockResolvedValue([
|
mockExecuteToolCalls.mockResolvedValue([
|
||||||
{
|
{
|
||||||
tool_call_id: 'call-123',
|
tool_call_id: 'call-123',
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
content: 'result'
|
content: 'result',
|
||||||
}
|
},
|
||||||
])
|
])
|
||||||
|
|
||||||
const result = await mcpInference(mockRequest, mockMcpClient)
|
const result = await mcpInference(mockRequest, mockMcpClient)
|
||||||
|
|||||||
@@ -1,12 +1,4 @@
|
|||||||
import {
|
import {describe, it, expect, beforeEach, vi, type MockedFunction, type Mock} from 'vitest'
|
||||||
describe,
|
|
||||||
it,
|
|
||||||
expect,
|
|
||||||
beforeEach,
|
|
||||||
vi,
|
|
||||||
type MockedFunction,
|
|
||||||
type Mock
|
|
||||||
} from 'vitest'
|
|
||||||
import * as core from '../__fixtures__/core.js'
|
import * as core from '../__fixtures__/core.js'
|
||||||
|
|
||||||
// Create fs mocks
|
// Create fs mocks
|
||||||
@@ -26,25 +18,25 @@ const mockConnectToGitHubMCP = vi.fn()
|
|||||||
vi.mock('fs', () => ({
|
vi.mock('fs', () => ({
|
||||||
existsSync: mockExistsSync,
|
existsSync: mockExistsSync,
|
||||||
readFileSync: mockReadFileSync,
|
readFileSync: mockReadFileSync,
|
||||||
writeFileSync: mockWriteFileSync
|
writeFileSync: mockWriteFileSync,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock the inference functions
|
// Mock the inference functions
|
||||||
vi.mock('../src/inference.js', () => ({
|
vi.mock('../src/inference.js', () => ({
|
||||||
simpleInference: mockSimpleInference,
|
simpleInference: mockSimpleInference,
|
||||||
mcpInference: mockMcpInference
|
mcpInference: mockMcpInference,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock the MCP connection
|
// Mock the MCP connection
|
||||||
vi.mock('../src/mcp.js', () => ({
|
vi.mock('../src/mcp.js', () => ({
|
||||||
connectToGitHubMCP: mockConnectToGitHubMCP
|
connectToGitHubMCP: mockConnectToGitHubMCP,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@actions/core', () => core)
|
vi.mock('@actions/core', () => core)
|
||||||
|
|
||||||
// The module being tested should be imported dynamically. This ensures that the
|
// The module being tested should be imported dynamically. This ensures that the
|
||||||
// mocks are used in place of any actual dependencies.
|
// mocks are used in place of any actual dependencies.
|
||||||
const { run } = await import('../src/main.js')
|
const {run} = await import('../src/main.js')
|
||||||
|
|
||||||
describe('main.ts - prompt.yml integration', () => {
|
describe('main.ts - prompt.yml integration', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -119,29 +111,23 @@ model: openai/gpt-4o
|
|||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: 'Be as concise as possible'
|
content: 'Be as concise as possible',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Compare cats and dogs, please'
|
content: 'Compare cats and dogs, please',
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
modelName: 'openai/gpt-4o',
|
modelName: 'openai/gpt-4o',
|
||||||
maxTokens: 200,
|
maxTokens: 200,
|
||||||
endpoint: 'https://models.github.ai/inference',
|
endpoint: 'https://models.github.ai/inference',
|
||||||
token: 'test-token'
|
token: 'test-token',
|
||||||
})
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Verify outputs were set
|
// Verify outputs were set
|
||||||
expect(core.setOutput).toHaveBeenCalledWith(
|
expect(core.setOutput).toHaveBeenCalledWith('response', 'Mocked AI response')
|
||||||
'response',
|
expect(core.setOutput).toHaveBeenCalledWith('response-file', expect.any(String))
|
||||||
'Mocked AI response'
|
|
||||||
)
|
|
||||||
expect(core.setOutput).toHaveBeenCalledWith(
|
|
||||||
'response-file',
|
|
||||||
expect.any(String)
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should fall back to legacy format when not using prompt YAML', async () => {
|
it('should fall back to legacy format when not using prompt YAML', async () => {
|
||||||
@@ -173,18 +159,18 @@ model: openai/gpt-4o
|
|||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: 'You are helpful'
|
content: 'You are helpful',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Hello, world!'
|
content: 'Hello, world!',
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
modelName: 'openai/gpt-4o',
|
modelName: 'openai/gpt-4o',
|
||||||
maxTokens: 200,
|
maxTokens: 200,
|
||||||
endpoint: 'https://models.github.ai/inference',
|
endpoint: 'https://models.github.ai/inference',
|
||||||
token: 'test-token'
|
token: 'test-token',
|
||||||
})
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
+33
-58
@@ -1,23 +1,12 @@
|
|||||||
import {
|
import {vi, describe, expect, it, beforeEach, type MockedFunction} from 'vitest'
|
||||||
vi,
|
|
||||||
describe,
|
|
||||||
expect,
|
|
||||||
it,
|
|
||||||
beforeEach,
|
|
||||||
type MockedFunction
|
|
||||||
} from 'vitest'
|
|
||||||
import * as core from '../__fixtures__/core.js'
|
import * as core from '../__fixtures__/core.js'
|
||||||
|
|
||||||
// Default to throwing errors to catch unexpected calls
|
// Default to throwing errors to catch unexpected calls
|
||||||
const mockExistsSync = vi.fn().mockImplementation(() => {
|
const mockExistsSync = vi.fn().mockImplementation(() => {
|
||||||
throw new Error(
|
throw new Error('Unexpected call to existsSync - test should override this implementation')
|
||||||
'Unexpected call to existsSync - test should override this implementation'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
const mockReadFileSync = vi.fn().mockImplementation(() => {
|
const mockReadFileSync = vi.fn().mockImplementation(() => {
|
||||||
throw new Error(
|
throw new Error('Unexpected call to readFileSync - test should override this implementation')
|
||||||
'Unexpected call to readFileSync - test should override this implementation'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
const mockWriteFileSync = vi.fn()
|
const mockWriteFileSync = vi.fn()
|
||||||
|
|
||||||
@@ -26,10 +15,7 @@ const mockWriteFileSync = vi.fn()
|
|||||||
* @param fileContents - Object mapping file paths to their contents
|
* @param fileContents - Object mapping file paths to their contents
|
||||||
* @param nonExistentFiles - Array of file paths that should be treated as non-existent
|
* @param nonExistentFiles - Array of file paths that should be treated as non-existent
|
||||||
*/
|
*/
|
||||||
function mockFileContent(
|
function mockFileContent(fileContents: Record<string, string> = {}, nonExistentFiles: string[] = []): void {
|
||||||
fileContents: Record<string, string> = {},
|
|
||||||
nonExistentFiles: string[] = []
|
|
||||||
): void {
|
|
||||||
// Mock existsSync to return true for files that exist, false for those that don't
|
// Mock existsSync to return true for files that exist, false for those that don't
|
||||||
mockExistsSync.mockImplementation((...args: unknown[]): boolean => {
|
mockExistsSync.mockImplementation((...args: unknown[]): boolean => {
|
||||||
const [path] = args as [string]
|
const [path] = args as [string]
|
||||||
@@ -59,11 +45,11 @@ function mockInputs(inputs: Record<string, string> = {}): void {
|
|||||||
token: 'fake-token',
|
token: 'fake-token',
|
||||||
model: 'gpt-4',
|
model: 'gpt-4',
|
||||||
'max-tokens': '100',
|
'max-tokens': '100',
|
||||||
endpoint: 'https://api.test.com'
|
endpoint: 'https://api.test.com',
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine defaults with user-provided inputs
|
// Combine defaults with user-provided inputs
|
||||||
const allInputs: Record<string, string> = { ...defaultInputs, ...inputs }
|
const allInputs: Record<string, string> = {...defaultInputs, ...inputs}
|
||||||
|
|
||||||
core.getInput.mockImplementation((name: string) => {
|
core.getInput.mockImplementation((name: string) => {
|
||||||
return allInputs[name] || ''
|
return allInputs[name] || ''
|
||||||
@@ -80,17 +66,13 @@ function mockInputs(inputs: Record<string, string> = {}): void {
|
|||||||
*/
|
*/
|
||||||
function verifyStandardResponse(): void {
|
function verifyStandardResponse(): void {
|
||||||
expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!')
|
expect(core.setOutput).toHaveBeenNthCalledWith(1, 'response', 'Hello, user!')
|
||||||
expect(core.setOutput).toHaveBeenNthCalledWith(
|
expect(core.setOutput).toHaveBeenNthCalledWith(2, 'response-file', expect.stringContaining('modelResponse.txt'))
|
||||||
2,
|
|
||||||
'response-file',
|
|
||||||
expect.stringContaining('modelResponse.txt')
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.mock('fs', () => ({
|
vi.mock('fs', () => ({
|
||||||
existsSync: mockExistsSync,
|
existsSync: mockExistsSync,
|
||||||
readFileSync: mockReadFileSync,
|
readFileSync: mockReadFileSync,
|
||||||
writeFileSync: mockWriteFileSync
|
writeFileSync: mockWriteFileSync,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock MCP and inference modules
|
// Mock MCP and inference modules
|
||||||
@@ -102,19 +84,19 @@ const mockSimpleInference = vi.fn() as MockedFunction<any>
|
|||||||
const mockMcpInference = vi.fn() as MockedFunction<any>
|
const mockMcpInference = vi.fn() as MockedFunction<any>
|
||||||
|
|
||||||
vi.mock('../src/mcp.js', () => ({
|
vi.mock('../src/mcp.js', () => ({
|
||||||
connectToGitHubMCP: mockConnectToGitHubMCP
|
connectToGitHubMCP: mockConnectToGitHubMCP,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../src/inference.js', () => ({
|
vi.mock('../src/inference.js', () => ({
|
||||||
simpleInference: mockSimpleInference,
|
simpleInference: mockSimpleInference,
|
||||||
mcpInference: mockMcpInference
|
mcpInference: mockMcpInference,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@actions/core', () => core)
|
vi.mock('@actions/core', () => core)
|
||||||
|
|
||||||
// The module being tested should be imported dynamically. This ensures that the
|
// The module being tested should be imported dynamically. This ensures that the
|
||||||
// mocks are used in place of any actual dependencies.
|
// mocks are used in place of any actual dependencies.
|
||||||
const { run } = await import('../src/main.js')
|
const {run} = await import('../src/main.js')
|
||||||
|
|
||||||
describe('main.ts', () => {
|
describe('main.ts', () => {
|
||||||
// Reset all mocks before each test
|
// Reset all mocks before each test
|
||||||
@@ -132,7 +114,7 @@ describe('main.ts', () => {
|
|||||||
it('Sets the response output', async () => {
|
it('Sets the response output', async () => {
|
||||||
mockInputs({
|
mockInputs({
|
||||||
prompt: 'Hello, AI!',
|
prompt: 'Hello, AI!',
|
||||||
'system-prompt': 'You are a test assistant.'
|
'system-prompt': 'You are a test assistant.',
|
||||||
})
|
})
|
||||||
|
|
||||||
await run()
|
await run()
|
||||||
@@ -144,36 +126,33 @@ describe('main.ts', () => {
|
|||||||
it('Sets a failed status when no prompt is set', async () => {
|
it('Sets a failed status when no prompt is set', async () => {
|
||||||
mockInputs({
|
mockInputs({
|
||||||
prompt: '',
|
prompt: '',
|
||||||
'prompt-file': ''
|
'prompt-file': '',
|
||||||
})
|
})
|
||||||
|
|
||||||
await run()
|
await run()
|
||||||
|
|
||||||
expect(core.setFailed).toHaveBeenNthCalledWith(
|
expect(core.setFailed).toHaveBeenNthCalledWith(1, 'Neither prompt-file nor prompt was set')
|
||||||
1,
|
|
||||||
'Neither prompt-file nor prompt was set'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('uses simple inference when MCP is disabled', async () => {
|
it('uses simple inference when MCP is disabled', async () => {
|
||||||
mockInputs({
|
mockInputs({
|
||||||
prompt: 'Hello, AI!',
|
prompt: 'Hello, AI!',
|
||||||
'system-prompt': 'You are a test assistant.',
|
'system-prompt': 'You are a test assistant.',
|
||||||
'enable-github-mcp': 'false'
|
'enable-github-mcp': 'false',
|
||||||
})
|
})
|
||||||
|
|
||||||
await run()
|
await run()
|
||||||
|
|
||||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'You are a test assistant.' },
|
{role: 'system', content: 'You are a test assistant.'},
|
||||||
{ role: 'user', content: 'Hello, AI!' }
|
{role: 'user', content: 'Hello, AI!'},
|
||||||
],
|
],
|
||||||
modelName: 'gpt-4',
|
modelName: 'gpt-4',
|
||||||
maxTokens: 100,
|
maxTokens: 100,
|
||||||
endpoint: 'https://api.test.com',
|
endpoint: 'https://api.test.com',
|
||||||
token: 'fake-token',
|
token: 'fake-token',
|
||||||
responseFormat: undefined
|
responseFormat: undefined,
|
||||||
})
|
})
|
||||||
expect(mockConnectToGitHubMCP).not.toHaveBeenCalled()
|
expect(mockConnectToGitHubMCP).not.toHaveBeenCalled()
|
||||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||||
@@ -184,13 +163,13 @@ describe('main.ts', () => {
|
|||||||
const mockMcpClient = {
|
const mockMcpClient = {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
client: {} as any,
|
client: {} as any,
|
||||||
tools: [{ type: 'function', function: { name: 'test-tool' } }]
|
tools: [{type: 'function', function: {name: 'test-tool'}}],
|
||||||
}
|
}
|
||||||
|
|
||||||
mockInputs({
|
mockInputs({
|
||||||
prompt: 'Hello, AI!',
|
prompt: 'Hello, AI!',
|
||||||
'system-prompt': 'You are a test assistant.',
|
'system-prompt': 'You are a test assistant.',
|
||||||
'enable-github-mcp': 'true'
|
'enable-github-mcp': 'true',
|
||||||
})
|
})
|
||||||
|
|
||||||
mockConnectToGitHubMCP.mockResolvedValue(mockMcpClient)
|
mockConnectToGitHubMCP.mockResolvedValue(mockMcpClient)
|
||||||
@@ -201,12 +180,12 @@ describe('main.ts', () => {
|
|||||||
expect(mockMcpInference).toHaveBeenCalledWith(
|
expect(mockMcpInference).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: 'You are a test assistant.' },
|
{role: 'system', content: 'You are a test assistant.'},
|
||||||
{ role: 'user', content: 'Hello, AI!' }
|
{role: 'user', content: 'Hello, AI!'},
|
||||||
],
|
],
|
||||||
token: 'fake-token'
|
token: 'fake-token',
|
||||||
}),
|
}),
|
||||||
mockMcpClient
|
mockMcpClient,
|
||||||
)
|
)
|
||||||
expect(mockSimpleInference).not.toHaveBeenCalled()
|
expect(mockSimpleInference).not.toHaveBeenCalled()
|
||||||
verifyStandardResponse()
|
verifyStandardResponse()
|
||||||
@@ -216,7 +195,7 @@ describe('main.ts', () => {
|
|||||||
mockInputs({
|
mockInputs({
|
||||||
prompt: 'Hello, AI!',
|
prompt: 'Hello, AI!',
|
||||||
'system-prompt': 'You are a test assistant.',
|
'system-prompt': 'You are a test assistant.',
|
||||||
'enable-github-mcp': 'true'
|
'enable-github-mcp': 'true',
|
||||||
})
|
})
|
||||||
|
|
||||||
mockConnectToGitHubMCP.mockResolvedValue(null)
|
mockConnectToGitHubMCP.mockResolvedValue(null)
|
||||||
@@ -226,9 +205,7 @@ describe('main.ts', () => {
|
|||||||
expect(mockConnectToGitHubMCP).toHaveBeenCalledWith('fake-token')
|
expect(mockConnectToGitHubMCP).toHaveBeenCalledWith('fake-token')
|
||||||
expect(mockSimpleInference).toHaveBeenCalled()
|
expect(mockSimpleInference).toHaveBeenCalled()
|
||||||
expect(mockMcpInference).not.toHaveBeenCalled()
|
expect(mockMcpInference).not.toHaveBeenCalled()
|
||||||
expect(core.warning).toHaveBeenCalledWith(
|
expect(core.warning).toHaveBeenCalledWith('MCP connection failed, falling back to simple inference')
|
||||||
'MCP connection failed, falling back to simple inference'
|
|
||||||
)
|
|
||||||
verifyStandardResponse()
|
verifyStandardResponse()
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -240,27 +217,27 @@ describe('main.ts', () => {
|
|||||||
|
|
||||||
mockFileContent({
|
mockFileContent({
|
||||||
[promptFile]: promptContent,
|
[promptFile]: promptContent,
|
||||||
[systemPromptFile]: systemPromptContent
|
[systemPromptFile]: systemPromptContent,
|
||||||
})
|
})
|
||||||
|
|
||||||
mockInputs({
|
mockInputs({
|
||||||
'prompt-file': promptFile,
|
'prompt-file': promptFile,
|
||||||
'system-prompt-file': systemPromptFile,
|
'system-prompt-file': systemPromptFile,
|
||||||
'enable-github-mcp': 'false'
|
'enable-github-mcp': 'false',
|
||||||
})
|
})
|
||||||
|
|
||||||
await run()
|
await run()
|
||||||
|
|
||||||
expect(mockSimpleInference).toHaveBeenCalledWith({
|
expect(mockSimpleInference).toHaveBeenCalledWith({
|
||||||
messages: [
|
messages: [
|
||||||
{ role: 'system', content: systemPromptContent },
|
{role: 'system', content: systemPromptContent},
|
||||||
{ role: 'user', content: promptContent }
|
{role: 'user', content: promptContent},
|
||||||
],
|
],
|
||||||
modelName: 'gpt-4',
|
modelName: 'gpt-4',
|
||||||
maxTokens: 100,
|
maxTokens: 100,
|
||||||
endpoint: 'https://api.test.com',
|
endpoint: 'https://api.test.com',
|
||||||
token: 'fake-token',
|
token: 'fake-token',
|
||||||
responseFormat: undefined
|
responseFormat: undefined,
|
||||||
})
|
})
|
||||||
verifyStandardResponse()
|
verifyStandardResponse()
|
||||||
})
|
})
|
||||||
@@ -271,13 +248,11 @@ describe('main.ts', () => {
|
|||||||
mockFileContent({}, [promptFile])
|
mockFileContent({}, [promptFile])
|
||||||
|
|
||||||
mockInputs({
|
mockInputs({
|
||||||
'prompt-file': promptFile
|
'prompt-file': promptFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
await run()
|
await run()
|
||||||
|
|
||||||
expect(core.setFailed).toHaveBeenCalledWith(
|
expect(core.setFailed).toHaveBeenCalledWith(`File for prompt-file was not found: ${promptFile}`)
|
||||||
`File for prompt-file was not found: ${promptFile}`
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
+44
-75
@@ -1,11 +1,4 @@
|
|||||||
import {
|
import {vi, type MockedFunction, describe, it, expect, beforeEach} from 'vitest'
|
||||||
vi,
|
|
||||||
type MockedFunction,
|
|
||||||
describe,
|
|
||||||
it,
|
|
||||||
expect,
|
|
||||||
beforeEach
|
|
||||||
} from 'vitest'
|
|
||||||
import * as core from '../__fixtures__/core.js'
|
import * as core from '../__fixtures__/core.js'
|
||||||
|
|
||||||
// Mock MCP SDK
|
// Mock MCP SDK
|
||||||
@@ -19,24 +12,22 @@ const mockCallTool = vi.fn() as MockedFunction<any>
|
|||||||
const mockClient = {
|
const mockClient = {
|
||||||
connect: mockConnect,
|
connect: mockConnect,
|
||||||
listTools: mockListTools,
|
listTools: mockListTools,
|
||||||
callTool: mockCallTool
|
callTool: mockCallTool,
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
} as any
|
} as any
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
|
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => ({
|
||||||
Client: vi.fn(() => mockClient)
|
Client: vi.fn(() => mockClient),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
|
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({
|
||||||
StreamableHTTPClientTransport: vi.fn()
|
StreamableHTTPClientTransport: vi.fn(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@actions/core', () => core)
|
vi.mock('@actions/core', () => core)
|
||||||
|
|
||||||
// Import the module being tested
|
// Import the module being tested
|
||||||
const { connectToGitHubMCP, executeToolCall, executeToolCalls } = await import(
|
const {connectToGitHubMCP, executeToolCall, executeToolCalls} = await import('../src/mcp.js')
|
||||||
'../src/mcp.js'
|
|
||||||
)
|
|
||||||
|
|
||||||
describe('mcp.ts', () => {
|
describe('mcp.ts', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -50,20 +41,20 @@ describe('mcp.ts', () => {
|
|||||||
{
|
{
|
||||||
name: 'test-tool-1',
|
name: 'test-tool-1',
|
||||||
description: 'Test tool 1',
|
description: 'Test tool 1',
|
||||||
inputSchema: { type: 'object', properties: {} }
|
inputSchema: {type: 'object', properties: {}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: 'test-tool-2',
|
name: 'test-tool-2',
|
||||||
description: 'Test tool 2',
|
description: 'Test tool 2',
|
||||||
inputSchema: {
|
inputSchema: {
|
||||||
type: 'object',
|
type: 'object',
|
||||||
properties: { param: { type: 'string' } }
|
properties: {param: {type: 'string'}},
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
mockConnect.mockResolvedValue(undefined)
|
mockConnect.mockResolvedValue(undefined)
|
||||||
mockListTools.mockResolvedValue({ tools: mockTools })
|
mockListTools.mockResolvedValue({tools: mockTools})
|
||||||
|
|
||||||
const result = await connectToGitHubMCP(token)
|
const result = await connectToGitHubMCP(token)
|
||||||
|
|
||||||
@@ -75,21 +66,13 @@ describe('mcp.ts', () => {
|
|||||||
function: {
|
function: {
|
||||||
name: 'test-tool-1',
|
name: 'test-tool-1',
|
||||||
description: 'Test tool 1',
|
description: 'Test tool 1',
|
||||||
parameters: { type: 'object', properties: {} }
|
parameters: {type: 'object', properties: {}},
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Connecting to GitHub MCP server...')
|
||||||
'Connecting to GitHub MCP server...'
|
expect(core.info).toHaveBeenCalledWith('Successfully connected to GitHub MCP server')
|
||||||
)
|
expect(core.info).toHaveBeenCalledWith('Retrieved 2 tools from GitHub MCP server')
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Mapped 2 GitHub MCP tools for Azure AI Inference')
|
||||||
'Successfully connected to GitHub MCP server'
|
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
|
||||||
'Retrieved 2 tools from GitHub MCP server'
|
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
|
||||||
'Mapped 2 GitHub MCP tools for Azure AI Inference'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('returns null when connection fails', async () => {
|
it('returns null when connection fails', async () => {
|
||||||
@@ -101,27 +84,21 @@ describe('mcp.ts', () => {
|
|||||||
const result = await connectToGitHubMCP(token)
|
const result = await connectToGitHubMCP(token)
|
||||||
|
|
||||||
expect(result).toBeNull()
|
expect(result).toBeNull()
|
||||||
expect(core.warning).toHaveBeenCalledWith(
|
expect(core.warning).toHaveBeenCalledWith('Failed to connect to GitHub MCP server: Error: Connection failed')
|
||||||
'Failed to connect to GitHub MCP server: Error: Connection failed'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles empty tools list', async () => {
|
it('handles empty tools list', async () => {
|
||||||
const token = 'test-token'
|
const token = 'test-token'
|
||||||
|
|
||||||
mockConnect.mockResolvedValue(undefined)
|
mockConnect.mockResolvedValue(undefined)
|
||||||
mockListTools.mockResolvedValue({ tools: [] })
|
mockListTools.mockResolvedValue({tools: []})
|
||||||
|
|
||||||
const result = await connectToGitHubMCP(token)
|
const result = await connectToGitHubMCP(token)
|
||||||
|
|
||||||
expect(result).not.toBeNull()
|
expect(result).not.toBeNull()
|
||||||
expect(result?.tools).toHaveLength(0)
|
expect(result?.tools).toHaveLength(0)
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Retrieved 0 tools from GitHub MCP server')
|
||||||
'Retrieved 0 tools from GitHub MCP server'
|
expect(core.info).toHaveBeenCalledWith('Mapped 0 GitHub MCP tools for Azure AI Inference')
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
|
||||||
'Mapped 0 GitHub MCP tools for Azure AI Inference'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles undefined tools list', async () => {
|
it('handles undefined tools list', async () => {
|
||||||
@@ -134,9 +111,7 @@ describe('mcp.ts', () => {
|
|||||||
|
|
||||||
expect(result).not.toBeNull()
|
expect(result).not.toBeNull()
|
||||||
expect(result?.tools).toHaveLength(0)
|
expect(result?.tools).toHaveLength(0)
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Retrieved 0 tools from GitHub MCP server')
|
||||||
'Retrieved 0 tools from GitHub MCP server'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -147,11 +122,11 @@ describe('mcp.ts', () => {
|
|||||||
type: 'function',
|
type: 'function',
|
||||||
function: {
|
function: {
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
arguments: '{"param": "value"}'
|
arguments: '{"param": "value"}',
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
const toolResult = {
|
const toolResult = {
|
||||||
content: [{ type: 'text', text: 'Tool execution result' }]
|
content: [{type: 'text', text: 'Tool execution result'}],
|
||||||
}
|
}
|
||||||
|
|
||||||
mockCallTool.mockResolvedValue(toolResult)
|
mockCallTool.mockResolvedValue(toolResult)
|
||||||
@@ -160,20 +135,16 @@ describe('mcp.ts', () => {
|
|||||||
|
|
||||||
expect(mockCallTool).toHaveBeenCalledWith({
|
expect(mockCallTool).toHaveBeenCalledWith({
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
arguments: { param: 'value' }
|
arguments: {param: 'value'},
|
||||||
})
|
})
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
tool_call_id: 'call-123',
|
tool_call_id: 'call-123',
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
content: JSON.stringify(toolResult.content)
|
content: JSON.stringify(toolResult.content),
|
||||||
})
|
})
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
expect(core.info).toHaveBeenCalledWith('Executing GitHub MCP tool: test-tool with args: {"param": "value"}')
|
||||||
'Executing GitHub MCP tool: test-tool with args: {"param": "value"}'
|
expect(core.info).toHaveBeenCalledWith('GitHub MCP tool test-tool executed successfully')
|
||||||
)
|
|
||||||
expect(core.info).toHaveBeenCalledWith(
|
|
||||||
'GitHub MCP tool test-tool executed successfully'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles tool execution errors gracefully', async () => {
|
it('handles tool execution errors gracefully', async () => {
|
||||||
@@ -182,8 +153,8 @@ describe('mcp.ts', () => {
|
|||||||
type: 'function',
|
type: 'function',
|
||||||
function: {
|
function: {
|
||||||
name: 'failing-tool',
|
name: 'failing-tool',
|
||||||
arguments: '{"param": "value"}'
|
arguments: '{"param": "value"}',
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
const toolError = new Error('Tool execution failed')
|
const toolError = new Error('Tool execution failed')
|
||||||
|
|
||||||
@@ -195,10 +166,10 @@ describe('mcp.ts', () => {
|
|||||||
tool_call_id: 'call-456',
|
tool_call_id: 'call-456',
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: 'failing-tool',
|
name: 'failing-tool',
|
||||||
content: 'Error: Error: Tool execution failed'
|
content: 'Error: Error: Tool execution failed',
|
||||||
})
|
})
|
||||||
expect(core.warning).toHaveBeenCalledWith(
|
expect(core.warning).toHaveBeenCalledWith(
|
||||||
'Failed to execute GitHub MCP tool failing-tool: Error: Tool execution failed'
|
'Failed to execute GitHub MCP tool failing-tool: Error: Tool execution failed',
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -208,8 +179,8 @@ describe('mcp.ts', () => {
|
|||||||
type: 'function',
|
type: 'function',
|
||||||
function: {
|
function: {
|
||||||
name: 'test-tool',
|
name: 'test-tool',
|
||||||
arguments: 'invalid-json'
|
arguments: 'invalid-json',
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await executeToolCall(mockClient, toolCall)
|
const result = await executeToolCall(mockClient, toolCall)
|
||||||
@@ -218,9 +189,7 @@ describe('mcp.ts', () => {
|
|||||||
expect(result.role).toBe('tool')
|
expect(result.role).toBe('tool')
|
||||||
expect(result.name).toBe('test-tool')
|
expect(result.name).toBe('test-tool')
|
||||||
expect(result.content).toContain('Error:')
|
expect(result.content).toContain('Error:')
|
||||||
expect(core.warning).toHaveBeenCalledWith(
|
expect(core.warning).toHaveBeenCalledWith(expect.stringContaining('Failed to execute GitHub MCP tool test-tool:'))
|
||||||
expect.stringContaining('Failed to execute GitHub MCP tool test-tool:')
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -230,21 +199,21 @@ describe('mcp.ts', () => {
|
|||||||
{
|
{
|
||||||
id: 'call-1',
|
id: 'call-1',
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: { name: 'tool-1', arguments: '{}' }
|
function: {name: 'tool-1', arguments: '{}'},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'call-2',
|
id: 'call-2',
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: { name: 'tool-2', arguments: '{"param": "value"}' }
|
function: {name: 'tool-2', arguments: '{"param": "value"}'},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
mockCallTool
|
mockCallTool
|
||||||
.mockResolvedValueOnce({
|
.mockResolvedValueOnce({
|
||||||
content: [{ type: 'text', text: 'Result 1' }]
|
content: [{type: 'text', text: 'Result 1'}],
|
||||||
})
|
})
|
||||||
.mockResolvedValueOnce({
|
.mockResolvedValueOnce({
|
||||||
content: [{ type: 'text', text: 'Result 2' }]
|
content: [{type: 'text', text: 'Result 2'}],
|
||||||
})
|
})
|
||||||
|
|
||||||
const results = await executeToolCalls(mockClient, toolCalls)
|
const results = await executeToolCalls(mockClient, toolCalls)
|
||||||
@@ -267,18 +236,18 @@ describe('mcp.ts', () => {
|
|||||||
{
|
{
|
||||||
id: 'call-1',
|
id: 'call-1',
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: { name: 'tool-1', arguments: '{}' }
|
function: {name: 'tool-1', arguments: '{}'},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'call-2',
|
id: 'call-2',
|
||||||
type: 'function',
|
type: 'function',
|
||||||
function: { name: 'tool-2', arguments: '{}' }
|
function: {name: 'tool-2', arguments: '{}'},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
mockCallTool
|
mockCallTool
|
||||||
.mockResolvedValueOnce({
|
.mockResolvedValueOnce({
|
||||||
content: [{ type: 'text', text: 'Result 1' }]
|
content: [{type: 'text', text: 'Result 1'}],
|
||||||
})
|
})
|
||||||
.mockRejectedValueOnce(new Error('Tool 2 failed'))
|
.mockRejectedValueOnce(new Error('Tool 2 failed'))
|
||||||
|
|
||||||
|
|||||||
+13
-26
@@ -1,12 +1,7 @@
|
|||||||
import { describe, it, expect } from 'vitest'
|
import {describe, it, expect} from 'vitest'
|
||||||
import * as path from 'path'
|
import * as path from 'path'
|
||||||
import { fileURLToPath } from 'url'
|
import {fileURLToPath} from 'url'
|
||||||
import {
|
import {parseTemplateVariables, replaceTemplateVariables, loadPromptFile, isPromptYamlFile} from '../src/prompt'
|
||||||
parseTemplateVariables,
|
|
||||||
replaceTemplateVariables,
|
|
||||||
loadPromptFile,
|
|
||||||
isPromptYamlFile
|
|
||||||
} from '../src/prompt'
|
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url)
|
const __filename = fileURLToPath(import.meta.url)
|
||||||
const __dirname = path.dirname(__filename)
|
const __dirname = path.dirname(__filename)
|
||||||
@@ -19,7 +14,7 @@ a: hello
|
|||||||
b: world
|
b: world
|
||||||
`
|
`
|
||||||
const result = parseTemplateVariables(input)
|
const result = parseTemplateVariables(input)
|
||||||
expect(result).toEqual({ a: 'hello', b: 'world' })
|
expect(result).toEqual({a: 'hello', b: 'world'})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should parse multiline variables', () => {
|
it('should parse multiline variables', () => {
|
||||||
@@ -49,14 +44,14 @@ var2: |
|
|||||||
describe('replaceTemplateVariables', () => {
|
describe('replaceTemplateVariables', () => {
|
||||||
it('should replace simple variables', () => {
|
it('should replace simple variables', () => {
|
||||||
const text = 'Hello {{name}}, welcome to {{place}}!'
|
const text = 'Hello {{name}}, welcome to {{place}}!'
|
||||||
const variables = { name: 'John', place: 'GitHub' }
|
const variables = {name: 'John', place: 'GitHub'}
|
||||||
const result = replaceTemplateVariables(text, variables)
|
const result = replaceTemplateVariables(text, variables)
|
||||||
expect(result).toBe('Hello John, welcome to GitHub!')
|
expect(result).toBe('Hello John, welcome to GitHub!')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should leave unreplaced variables as is', () => {
|
it('should leave unreplaced variables as is', () => {
|
||||||
const text = 'Hello {{name}}, welcome to {{unknown}}!'
|
const text = 'Hello {{name}}, welcome to {{unknown}}!'
|
||||||
const variables = { name: 'John' }
|
const variables = {name: 'John'}
|
||||||
const result = replaceTemplateVariables(text, variables)
|
const result = replaceTemplateVariables(text, variables)
|
||||||
expect(result).toBe('Hello John, welcome to {{unknown}}!')
|
expect(result).toBe('Hello John, welcome to {{unknown}}!')
|
||||||
})
|
})
|
||||||
@@ -90,31 +85,25 @@ var2: |
|
|||||||
|
|
||||||
describe('loadPromptFile', () => {
|
describe('loadPromptFile', () => {
|
||||||
it('should load simple prompt file', () => {
|
it('should load simple prompt file', () => {
|
||||||
const filePath = path.join(
|
const filePath = path.join(__dirname, '../__fixtures__/prompts/simple.prompt.yml')
|
||||||
__dirname,
|
const variables = {a: 'cats', b: 'dogs'}
|
||||||
'../__fixtures__/prompts/simple.prompt.yml'
|
|
||||||
)
|
|
||||||
const variables = { a: 'cats', b: 'dogs' }
|
|
||||||
const result = loadPromptFile(filePath, variables)
|
const result = loadPromptFile(filePath, variables)
|
||||||
|
|
||||||
expect(result.messages).toHaveLength(2)
|
expect(result.messages).toHaveLength(2)
|
||||||
expect(result.messages[0]).toEqual({
|
expect(result.messages[0]).toEqual({
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: 'Be as concise as possible'
|
content: 'Be as concise as possible',
|
||||||
})
|
})
|
||||||
expect(result.messages[1]).toEqual({
|
expect(result.messages[1]).toEqual({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: 'Compare cats and dogs, please'
|
content: 'Compare cats and dogs, please',
|
||||||
})
|
})
|
||||||
expect(result.model).toBe('openai/gpt-4o')
|
expect(result.model).toBe('openai/gpt-4o')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should load JSON schema prompt file', () => {
|
it('should load JSON schema prompt file', () => {
|
||||||
const filePath = path.join(
|
const filePath = path.join(__dirname, '../__fixtures__/prompts/json-schema.prompt.yml')
|
||||||
__dirname,
|
const variables = {animal: 'dog'}
|
||||||
'../__fixtures__/prompts/json-schema.prompt.yml'
|
|
||||||
)
|
|
||||||
const variables = { animal: 'dog' }
|
|
||||||
const result = loadPromptFile(filePath, variables)
|
const result = loadPromptFile(filePath, variables)
|
||||||
|
|
||||||
expect(result.messages).toHaveLength(2)
|
expect(result.messages).toHaveLength(2)
|
||||||
@@ -125,9 +114,7 @@ var2: |
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('should throw error for non-existent file', () => {
|
it('should throw error for non-existent file', () => {
|
||||||
expect(() => loadPromptFile('non-existent.prompt.yml')).toThrow(
|
expect(() => loadPromptFile('non-existent.prompt.yml')).toThrow('Prompt file not found')
|
||||||
'Prompt file not found'
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
+1
-2
@@ -14,8 +14,7 @@ inputs:
|
|||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
prompt-file:
|
prompt-file:
|
||||||
description:
|
description: Path to a file containing the prompt (supports .txt and .prompt.yml
|
||||||
Path to a file containing the prompt (supports .txt and .prompt.yml
|
|
||||||
formats)
|
formats)
|
||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
|
|||||||
+15
-15
@@ -1,43 +1,43 @@
|
|||||||
// See: https://eslint.org/docs/latest/use/configure/configuration-files
|
// See: https://eslint.org/docs/latest/use/configure/configuration-files
|
||||||
|
|
||||||
import { FlatCompat } from '@eslint/eslintrc'
|
import {FlatCompat} from '@eslint/eslintrc'
|
||||||
import js from '@eslint/js'
|
import js from '@eslint/js'
|
||||||
import typescriptEslint from '@typescript-eslint/eslint-plugin'
|
import typescriptEslint from '@typescript-eslint/eslint-plugin'
|
||||||
import tsParser from '@typescript-eslint/parser'
|
import tsParser from '@typescript-eslint/parser'
|
||||||
import prettier from 'eslint-plugin-prettier'
|
import prettier from 'eslint-plugin-prettier'
|
||||||
import globals from 'globals'
|
import globals from 'globals'
|
||||||
import path from 'node:path'
|
import path from 'node:path'
|
||||||
import { fileURLToPath } from 'node:url'
|
import {fileURLToPath} from 'node:url'
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url)
|
const __filename = fileURLToPath(import.meta.url)
|
||||||
const __dirname = path.dirname(__filename)
|
const __dirname = path.dirname(__filename)
|
||||||
const compat = new FlatCompat({
|
const compat = new FlatCompat({
|
||||||
baseDirectory: __dirname,
|
baseDirectory: __dirname,
|
||||||
recommendedConfig: js.configs.recommended,
|
recommendedConfig: js.configs.recommended,
|
||||||
allConfig: js.configs.all
|
allConfig: js.configs.all,
|
||||||
})
|
})
|
||||||
|
|
||||||
export default [
|
export default [
|
||||||
{
|
{
|
||||||
ignores: ['**/coverage', '**/dist', '**/linter', '**/node_modules']
|
ignores: ['**/coverage', '**/dist', '**/linter', '**/node_modules'],
|
||||||
},
|
},
|
||||||
...compat.extends(
|
...compat.extends(
|
||||||
'eslint:recommended',
|
'eslint:recommended',
|
||||||
'plugin:@typescript-eslint/eslint-recommended',
|
'plugin:@typescript-eslint/eslint-recommended',
|
||||||
'plugin:@typescript-eslint/recommended',
|
'plugin:@typescript-eslint/recommended',
|
||||||
'plugin:prettier/recommended'
|
'plugin:prettier/recommended',
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
plugins: {
|
plugins: {
|
||||||
prettier,
|
prettier,
|
||||||
'@typescript-eslint': typescriptEslint
|
'@typescript-eslint': typescriptEslint,
|
||||||
},
|
},
|
||||||
|
|
||||||
languageOptions: {
|
languageOptions: {
|
||||||
globals: {
|
globals: {
|
||||||
...globals.node,
|
...globals.node,
|
||||||
Atomics: 'readonly',
|
Atomics: 'readonly',
|
||||||
SharedArrayBuffer: 'readonly'
|
SharedArrayBuffer: 'readonly',
|
||||||
},
|
},
|
||||||
|
|
||||||
parser: tsParser,
|
parser: tsParser,
|
||||||
@@ -46,17 +46,17 @@ export default [
|
|||||||
|
|
||||||
parserOptions: {
|
parserOptions: {
|
||||||
project: ['tsconfig.eslint.json'],
|
project: ['tsconfig.eslint.json'],
|
||||||
tsconfigRootDir: '.'
|
tsconfigRootDir: '.',
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
settings: {
|
settings: {
|
||||||
'import/resolver': {
|
'import/resolver': {
|
||||||
typescript: {
|
typescript: {
|
||||||
alwaysTryTypes: true,
|
alwaysTryTypes: true,
|
||||||
project: 'tsconfig.eslint.json'
|
project: 'tsconfig.eslint.json',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
rules: {
|
rules: {
|
||||||
@@ -68,7 +68,7 @@ export default [
|
|||||||
'no-console': 'off',
|
'no-console': 'off',
|
||||||
'no-shadow': 'off',
|
'no-shadow': 'off',
|
||||||
'no-unused-vars': 'off',
|
'no-unused-vars': 'off',
|
||||||
'prettier/prettier': 'error'
|
'prettier/prettier': 'error',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
Generated
+8
@@ -20,6 +20,7 @@
|
|||||||
"@azure/core-sse": "latest",
|
"@azure/core-sse": "latest",
|
||||||
"@eslint/compat": "^1.3.0",
|
"@eslint/compat": "^1.3.0",
|
||||||
"@github/local-action": "^5.1.0",
|
"@github/local-action": "^5.1.0",
|
||||||
|
"@github/prettier-config": "^0.0.6",
|
||||||
"@rollup/plugin-commonjs": "^28.0.5",
|
"@rollup/plugin-commonjs": "^28.0.5",
|
||||||
"@rollup/plugin-json": "^6.1.0",
|
"@rollup/plugin-json": "^6.1.0",
|
||||||
"@rollup/plugin-node-resolve": "^16.0.1",
|
"@rollup/plugin-node-resolve": "^16.0.1",
|
||||||
@@ -1481,6 +1482,13 @@
|
|||||||
"node": ">=20.18.1"
|
"node": ">=20.18.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@github/prettier-config": {
|
||||||
|
"version": "0.0.6",
|
||||||
|
"resolved": "https://registry.npmjs.org/@github/prettier-config/-/prettier-config-0.0.6.tgz",
|
||||||
|
"integrity": "sha512-Sdb089z+QbGnFF2NivbDeaJ62ooPlD31wE6Fkb/ESjAOXSjNJo+gjqzYYhlM7G3ERJmKFZRUJYMlsqB7Tym8lQ==",
|
||||||
|
"dev": true,
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/@humanfs/core": {
|
"node_modules/@humanfs/core": {
|
||||||
"version": "0.19.1",
|
"version": "0.19.1",
|
||||||
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
|
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",
|
||||||
|
|||||||
+5
-3
@@ -21,6 +21,7 @@
|
|||||||
"all": "npm run format:write && npm run lint && npm run test && npm run package"
|
"all": "npm run format:write && npm run lint && npm run test && npm run package"
|
||||||
},
|
},
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
|
"prettier": "@github/prettier-config",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@actions/core": "^1.11.1",
|
"@actions/core": "^1.11.1",
|
||||||
"@modelcontextprotocol/sdk": "^1.15.1",
|
"@modelcontextprotocol/sdk": "^1.15.1",
|
||||||
@@ -33,9 +34,12 @@
|
|||||||
"@azure/core-sse": "latest",
|
"@azure/core-sse": "latest",
|
||||||
"@eslint/compat": "^1.3.0",
|
"@eslint/compat": "^1.3.0",
|
||||||
"@github/local-action": "^5.1.0",
|
"@github/local-action": "^5.1.0",
|
||||||
|
"@github/prettier-config": "^0.0.6",
|
||||||
"@rollup/plugin-commonjs": "^28.0.5",
|
"@rollup/plugin-commonjs": "^28.0.5",
|
||||||
|
"@rollup/plugin-json": "^6.1.0",
|
||||||
"@rollup/plugin-node-resolve": "^16.0.1",
|
"@rollup/plugin-node-resolve": "^16.0.1",
|
||||||
"@rollup/plugin-typescript": "^12.1.2",
|
"@rollup/plugin-typescript": "^12.1.2",
|
||||||
|
"@types/js-yaml": "^4.0.9",
|
||||||
"@types/node": "^22.15.31",
|
"@types/node": "^22.15.31",
|
||||||
"@typescript-eslint/eslint-plugin": "^8.34.0",
|
"@typescript-eslint/eslint-plugin": "^8.34.0",
|
||||||
"@typescript-eslint/parser": "^8.32.1",
|
"@typescript-eslint/parser": "^8.32.1",
|
||||||
@@ -48,9 +52,7 @@
|
|||||||
"prettier-eslint": "^16.4.2",
|
"prettier-eslint": "^16.4.2",
|
||||||
"rollup": "^4.43.0",
|
"rollup": "^4.43.0",
|
||||||
"typescript": "^5.8.3",
|
"typescript": "^5.8.3",
|
||||||
"vitest": "^3",
|
"vitest": "^3"
|
||||||
"@rollup/plugin-json": "^6.1.0",
|
|
||||||
"@types/js-yaml": "^4.0.9"
|
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@rollup/rollup-linux-x64-gnu": "*"
|
"@rollup/rollup-linux-x64-gnu": "*"
|
||||||
|
|||||||
+6
-6
@@ -1,5 +1,5 @@
|
|||||||
// See: https://rollupjs.org/introduction/
|
// See: https://rollupjs.org/introduction/
|
||||||
import { builtinModules } from 'node:module'
|
import {builtinModules} from 'node:module'
|
||||||
import commonjs from '@rollup/plugin-commonjs'
|
import commonjs from '@rollup/plugin-commonjs'
|
||||||
import nodeResolve from '@rollup/plugin-node-resolve'
|
import nodeResolve from '@rollup/plugin-node-resolve'
|
||||||
import typescript from '@rollup/plugin-typescript'
|
import typescript from '@rollup/plugin-typescript'
|
||||||
@@ -11,7 +11,7 @@ const config = {
|
|||||||
esModule: true,
|
esModule: true,
|
||||||
file: 'dist/index.js',
|
file: 'dist/index.js',
|
||||||
format: 'es',
|
format: 'es',
|
||||||
sourcemap: true
|
sourcemap: true,
|
||||||
},
|
},
|
||||||
external: [...builtinModules, /^node:/],
|
external: [...builtinModules, /^node:/],
|
||||||
plugins: [
|
plugins: [
|
||||||
@@ -19,13 +19,13 @@ const config = {
|
|||||||
nodeResolve({
|
nodeResolve({
|
||||||
preferBuiltins: true,
|
preferBuiltins: true,
|
||||||
browser: false,
|
browser: false,
|
||||||
exportConditions: ['node']
|
exportConditions: ['node'],
|
||||||
}),
|
}),
|
||||||
commonjs({
|
commonjs({
|
||||||
include: /node_modules/
|
include: /node_modules/,
|
||||||
}),
|
}),
|
||||||
json()
|
json(),
|
||||||
]
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
export default config
|
export default config
|
||||||
|
|||||||
+20
-33
@@ -1,8 +1,8 @@
|
|||||||
import * as core from '@actions/core'
|
import * as core from '@actions/core'
|
||||||
import { GetChatCompletionsDefaultResponse } from '@azure-rest/ai-inference'
|
import {GetChatCompletionsDefaultResponse} from '@azure-rest/ai-inference'
|
||||||
import * as fs from 'fs'
|
import * as fs from 'fs'
|
||||||
import { PromptConfig } from './prompt.js'
|
import {PromptConfig} from './prompt.js'
|
||||||
import { InferenceRequest } from './inference.js'
|
import {InferenceRequest} from './inference.js'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper function to load content from a file or use fallback input
|
* Helper function to load content from a file or use fallback input
|
||||||
@@ -11,11 +11,7 @@ import { InferenceRequest } from './inference.js'
|
|||||||
* @param defaultValue - Default value to use if neither file nor content is provided
|
* @param defaultValue - Default value to use if neither file nor content is provided
|
||||||
* @returns The loaded content
|
* @returns The loaded content
|
||||||
*/
|
*/
|
||||||
export function loadContentFromFileOrInput(
|
export function loadContentFromFileOrInput(filePathInput: string, contentInput: string, defaultValue?: string): string {
|
||||||
filePathInput: string,
|
|
||||||
contentInput: string,
|
|
||||||
defaultValue?: string
|
|
||||||
): string {
|
|
||||||
const filePath = core.getInput(filePathInput)
|
const filePath = core.getInput(filePathInput)
|
||||||
const contentString = core.getInput(contentInput)
|
const contentString = core.getInput(contentInput)
|
||||||
|
|
||||||
@@ -38,9 +34,7 @@ export function loadContentFromFileOrInput(
|
|||||||
* @param response - The response object from the AI service
|
* @param response - The response object from the AI service
|
||||||
* @throws Error with appropriate error message based on response content
|
* @throws Error with appropriate error message based on response content
|
||||||
*/
|
*/
|
||||||
export function handleUnexpectedResponse(
|
export function handleUnexpectedResponse(response: GetChatCompletionsDefaultResponse): never {
|
||||||
response: GetChatCompletionsDefaultResponse
|
|
||||||
): never {
|
|
||||||
// Extract x-ms-error-code from headers if available
|
// Extract x-ms-error-code from headers if available
|
||||||
const errorCode = response.headers['x-ms-error-code']
|
const errorCode = response.headers['x-ms-error-code']
|
||||||
const errorCodeMsg = errorCode ? ` (error code: ${errorCode})` : ''
|
const errorCodeMsg = errorCode ? ` (error code: ${errorCode})` : ''
|
||||||
@@ -54,16 +48,14 @@ export function handleUnexpectedResponse(
|
|||||||
if (!response.body) {
|
if (!response.body) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Failed to get response from AI service (status: ${response.status})${errorCodeMsg}. ` +
|
`Failed to get response from AI service (status: ${response.status})${errorCodeMsg}. ` +
|
||||||
'Please check network connection and endpoint configuration.'
|
'Please check network connection and endpoint configuration.',
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle other error cases
|
// Handle other error cases
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`AI service returned error response (status: ${response.status})${errorCodeMsg}: ` +
|
`AI service returned error response (status: ${response.status})${errorCodeMsg}: ` +
|
||||||
(typeof response.body === 'string'
|
(typeof response.body === 'string' ? response.body : JSON.stringify(response.body)),
|
||||||
? response.body
|
|
||||||
: JSON.stringify(response.body))
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,22 +65,22 @@ export function handleUnexpectedResponse(
|
|||||||
export function buildMessages(
|
export function buildMessages(
|
||||||
promptConfig?: PromptConfig,
|
promptConfig?: PromptConfig,
|
||||||
systemPrompt?: string,
|
systemPrompt?: string,
|
||||||
prompt?: string
|
prompt?: string,
|
||||||
): Array<{ role: string; content: string }> {
|
): Array<{role: string; content: string}> {
|
||||||
if (promptConfig?.messages && promptConfig.messages.length > 0) {
|
if (promptConfig?.messages && promptConfig.messages.length > 0) {
|
||||||
// Use new message format
|
// Use new message format
|
||||||
return promptConfig.messages.map((msg) => ({
|
return promptConfig.messages.map(msg => ({
|
||||||
role: msg.role,
|
role: msg.role,
|
||||||
content: msg.content
|
content: msg.content,
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
// Use legacy format
|
// Use legacy format
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: systemPrompt || 'You are a helpful assistant'
|
content: systemPrompt || 'You are a helpful assistant',
|
||||||
},
|
},
|
||||||
{ role: 'user', content: prompt || '' }
|
{role: 'user', content: prompt || ''},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -97,22 +89,17 @@ export function buildMessages(
|
|||||||
* Build response format object for API from prompt config
|
* Build response format object for API from prompt config
|
||||||
*/
|
*/
|
||||||
export function buildResponseFormat(
|
export function buildResponseFormat(
|
||||||
promptConfig?: PromptConfig
|
promptConfig?: PromptConfig,
|
||||||
): { type: 'json_schema'; json_schema: unknown } | undefined {
|
): {type: 'json_schema'; json_schema: unknown} | undefined {
|
||||||
if (
|
if (promptConfig?.responseFormat === 'json_schema' && promptConfig.jsonSchema) {
|
||||||
promptConfig?.responseFormat === 'json_schema' &&
|
|
||||||
promptConfig.jsonSchema
|
|
||||||
) {
|
|
||||||
try {
|
try {
|
||||||
const schema = JSON.parse(promptConfig.jsonSchema)
|
const schema = JSON.parse(promptConfig.jsonSchema)
|
||||||
return {
|
return {
|
||||||
type: 'json_schema',
|
type: 'json_schema',
|
||||||
json_schema: schema
|
json_schema: schema,
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(
|
throw new Error(`Invalid JSON schema: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||||
`Invalid JSON schema: ${error instanceof Error ? error.message : 'Unknown error'}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return undefined
|
return undefined
|
||||||
@@ -128,7 +115,7 @@ export function buildInferenceRequest(
|
|||||||
modelName: string,
|
modelName: string,
|
||||||
maxTokens: number,
|
maxTokens: number,
|
||||||
endpoint: string,
|
endpoint: string,
|
||||||
token: string
|
token: string,
|
||||||
): InferenceRequest {
|
): InferenceRequest {
|
||||||
const messages = buildMessages(promptConfig, systemPrompt, prompt)
|
const messages = buildMessages(promptConfig, systemPrompt, prompt)
|
||||||
const responseFormat = buildResponseFormat(promptConfig)
|
const responseFormat = buildResponseFormat(promptConfig)
|
||||||
@@ -139,6 +126,6 @@ export function buildInferenceRequest(
|
|||||||
maxTokens,
|
maxTokens,
|
||||||
endpoint,
|
endpoint,
|
||||||
token,
|
token,
|
||||||
responseFormat
|
responseFormat,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
* The entrypoint for the action. This file simply imports and runs the action's
|
* The entrypoint for the action. This file simply imports and runs the action's
|
||||||
* main logic.
|
* main logic.
|
||||||
*/
|
*/
|
||||||
import { run } from './main.js'
|
import {run} from './main.js'
|
||||||
|
|
||||||
/* istanbul ignore next */
|
/* istanbul ignore next */
|
||||||
run()
|
run()
|
||||||
|
|||||||
+23
-38
@@ -1,8 +1,8 @@
|
|||||||
import * as core from '@actions/core'
|
import * as core from '@actions/core'
|
||||||
import ModelClient, { isUnexpected } from '@azure-rest/ai-inference'
|
import ModelClient, {isUnexpected} from '@azure-rest/ai-inference'
|
||||||
import { AzureKeyCredential } from '@azure/core-auth'
|
import {AzureKeyCredential} from '@azure/core-auth'
|
||||||
import { GitHubMCPClient, executeToolCalls, MCPTool, ToolCall } from './mcp.js'
|
import {GitHubMCPClient, executeToolCalls, MCPTool, ToolCall} from './mcp.js'
|
||||||
import { handleUnexpectedResponse } from './helpers.js'
|
import {handleUnexpectedResponse} from './helpers.js'
|
||||||
|
|
||||||
interface ChatMessage {
|
interface ChatMessage {
|
||||||
role: string
|
role: string
|
||||||
@@ -14,17 +14,17 @@ interface ChatCompletionsRequestBody {
|
|||||||
messages: ChatMessage[]
|
messages: ChatMessage[]
|
||||||
max_tokens: number
|
max_tokens: number
|
||||||
model: string
|
model: string
|
||||||
response_format?: { type: 'json_schema'; json_schema: unknown }
|
response_format?: {type: 'json_schema'; json_schema: unknown}
|
||||||
tools?: MCPTool[]
|
tools?: MCPTool[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface InferenceRequest {
|
export interface InferenceRequest {
|
||||||
messages: Array<{ role: string; content: string }>
|
messages: Array<{role: string; content: string}>
|
||||||
modelName: string
|
modelName: string
|
||||||
maxTokens: number
|
maxTokens: number
|
||||||
endpoint: string
|
endpoint: string
|
||||||
token: string
|
token: string
|
||||||
responseFormat?: { type: 'json_schema'; json_schema: unknown } // Processed response format for the API
|
responseFormat?: {type: 'json_schema'; json_schema: unknown} // Processed response format for the API
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface InferenceResponse {
|
export interface InferenceResponse {
|
||||||
@@ -42,23 +42,17 @@ export interface InferenceResponse {
|
|||||||
/**
|
/**
|
||||||
* Simple one-shot inference without tools
|
* Simple one-shot inference without tools
|
||||||
*/
|
*/
|
||||||
export async function simpleInference(
|
export async function simpleInference(request: InferenceRequest): Promise<string | null> {
|
||||||
request: InferenceRequest
|
|
||||||
): Promise<string | null> {
|
|
||||||
core.info('Running simple inference without tools')
|
core.info('Running simple inference without tools')
|
||||||
|
|
||||||
const client = ModelClient(
|
const client = ModelClient(request.endpoint, new AzureKeyCredential(request.token), {
|
||||||
request.endpoint,
|
userAgentOptions: {userAgentPrefix: 'github-actions-ai-inference'},
|
||||||
new AzureKeyCredential(request.token),
|
})
|
||||||
{
|
|
||||||
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
const requestBody: ChatCompletionsRequestBody = {
|
const requestBody: ChatCompletionsRequestBody = {
|
||||||
messages: request.messages,
|
messages: request.messages,
|
||||||
max_tokens: request.maxTokens,
|
max_tokens: request.maxTokens,
|
||||||
model: request.modelName
|
model: request.modelName,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add response format if specified
|
// Add response format if specified
|
||||||
@@ -67,7 +61,7 @@ export async function simpleInference(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const response = await client.path('/chat/completions').post({
|
const response = await client.path('/chat/completions').post({
|
||||||
body: requestBody
|
body: requestBody,
|
||||||
})
|
})
|
||||||
|
|
||||||
if (isUnexpected(response)) {
|
if (isUnexpected(response)) {
|
||||||
@@ -85,17 +79,13 @@ export async function simpleInference(
|
|||||||
*/
|
*/
|
||||||
export async function mcpInference(
|
export async function mcpInference(
|
||||||
request: InferenceRequest,
|
request: InferenceRequest,
|
||||||
githubMcpClient: GitHubMCPClient
|
githubMcpClient: GitHubMCPClient,
|
||||||
): Promise<string | null> {
|
): Promise<string | null> {
|
||||||
core.info('Running GitHub MCP inference with tools')
|
core.info('Running GitHub MCP inference with tools')
|
||||||
|
|
||||||
const client = ModelClient(
|
const client = ModelClient(request.endpoint, new AzureKeyCredential(request.token), {
|
||||||
request.endpoint,
|
userAgentOptions: {userAgentPrefix: 'github-actions-ai-inference'},
|
||||||
new AzureKeyCredential(request.token),
|
})
|
||||||
{
|
|
||||||
userAgentOptions: { userAgentPrefix: 'github-actions-ai-inference' }
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Start with the pre-processed messages
|
// Start with the pre-processed messages
|
||||||
const messages: ChatMessage[] = [...request.messages]
|
const messages: ChatMessage[] = [...request.messages]
|
||||||
@@ -111,7 +101,7 @@ export async function mcpInference(
|
|||||||
messages: messages,
|
messages: messages,
|
||||||
max_tokens: request.maxTokens,
|
max_tokens: request.maxTokens,
|
||||||
model: request.modelName,
|
model: request.modelName,
|
||||||
tools: githubMcpClient.tools
|
tools: githubMcpClient.tools,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add response format if specified (only on first iteration to avoid conflicts)
|
// Add response format if specified (only on first iteration to avoid conflicts)
|
||||||
@@ -120,7 +110,7 @@ export async function mcpInference(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const response = await client.path('/chat/completions').post({
|
const response = await client.path('/chat/completions').post({
|
||||||
body: requestBody
|
body: requestBody,
|
||||||
})
|
})
|
||||||
|
|
||||||
if (isUnexpected(response)) {
|
if (isUnexpected(response)) {
|
||||||
@@ -136,7 +126,7 @@ export async function mcpInference(
|
|||||||
messages.push({
|
messages.push({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: modelResponse || '',
|
content: modelResponse || '',
|
||||||
...(toolCalls && { tool_calls: toolCalls })
|
...(toolCalls && {tool_calls: toolCalls}),
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!toolCalls || toolCalls.length === 0) {
|
if (!toolCalls || toolCalls.length === 0) {
|
||||||
@@ -147,10 +137,7 @@ export async function mcpInference(
|
|||||||
core.info(`Model requested ${toolCalls.length} tool calls`)
|
core.info(`Model requested ${toolCalls.length} tool calls`)
|
||||||
|
|
||||||
// Execute all tool calls via GitHub MCP
|
// Execute all tool calls via GitHub MCP
|
||||||
const toolResults = await executeToolCalls(
|
const toolResults = await executeToolCalls(githubMcpClient.client, toolCalls)
|
||||||
githubMcpClient.client,
|
|
||||||
toolCalls
|
|
||||||
)
|
|
||||||
|
|
||||||
// Add tool results to the conversation
|
// Add tool results to the conversation
|
||||||
messages.push(...toolResults)
|
messages.push(...toolResults)
|
||||||
@@ -158,15 +145,13 @@ export async function mcpInference(
|
|||||||
core.info('Tool results added, continuing conversation...')
|
core.info('Tool results added, continuing conversation...')
|
||||||
}
|
}
|
||||||
|
|
||||||
core.warning(
|
core.warning(`GitHub MCP inference loop exceeded maximum iterations (${maxIterations})`)
|
||||||
`GitHub MCP inference loop exceeded maximum iterations (${maxIterations})`
|
|
||||||
)
|
|
||||||
|
|
||||||
// Return the last assistant message content
|
// Return the last assistant message content
|
||||||
const lastAssistantMessage = messages
|
const lastAssistantMessage = messages
|
||||||
.slice()
|
.slice()
|
||||||
.reverse()
|
.reverse()
|
||||||
.find((msg) => msg.role === 'assistant')
|
.find(msg => msg.role === 'assistant')
|
||||||
|
|
||||||
return lastAssistantMessage?.content || null
|
return lastAssistantMessage?.content || null
|
||||||
}
|
}
|
||||||
|
|||||||
+6
-15
@@ -2,15 +2,10 @@ import * as core from '@actions/core'
|
|||||||
import * as fs from 'fs'
|
import * as fs from 'fs'
|
||||||
import * as os from 'os'
|
import * as os from 'os'
|
||||||
import * as path from 'path'
|
import * as path from 'path'
|
||||||
import { connectToGitHubMCP } from './mcp.js'
|
import {connectToGitHubMCP} from './mcp.js'
|
||||||
import { simpleInference, mcpInference } from './inference.js'
|
import {simpleInference, mcpInference} from './inference.js'
|
||||||
import { loadContentFromFileOrInput, buildInferenceRequest } from './helpers.js'
|
import {loadContentFromFileOrInput, buildInferenceRequest} from './helpers.js'
|
||||||
import {
|
import {loadPromptFile, parseTemplateVariables, isPromptYamlFile, PromptConfig} from './prompt.js'
|
||||||
loadPromptFile,
|
|
||||||
parseTemplateVariables,
|
|
||||||
isPromptYamlFile,
|
|
||||||
PromptConfig
|
|
||||||
} from './prompt.js'
|
|
||||||
|
|
||||||
const RESPONSE_FILE = 'modelResponse.txt'
|
const RESPONSE_FILE = 'modelResponse.txt'
|
||||||
|
|
||||||
@@ -42,11 +37,7 @@ export async function run(): Promise<void> {
|
|||||||
core.info('Using legacy prompt format')
|
core.info('Using legacy prompt format')
|
||||||
|
|
||||||
prompt = loadContentFromFileOrInput('prompt-file', 'prompt')
|
prompt = loadContentFromFileOrInput('prompt-file', 'prompt')
|
||||||
systemPrompt = loadContentFromFileOrInput(
|
systemPrompt = loadContentFromFileOrInput('system-prompt-file', 'system-prompt', 'You are a helpful assistant')
|
||||||
'system-prompt-file',
|
|
||||||
'system-prompt',
|
|
||||||
'You are a helpful assistant'
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get common parameters
|
// Get common parameters
|
||||||
@@ -68,7 +59,7 @@ export async function run(): Promise<void> {
|
|||||||
modelName,
|
modelName,
|
||||||
maxTokens,
|
maxTokens,
|
||||||
endpoint,
|
endpoint,
|
||||||
token
|
token,
|
||||||
)
|
)
|
||||||
|
|
||||||
const enableMcp = core.getBooleanInput('enable-github-mcp') || false
|
const enableMcp = core.getBooleanInput('enable-github-mcp') || false
|
||||||
|
|||||||
+19
-33
@@ -1,6 +1,6 @@
|
|||||||
import * as core from '@actions/core'
|
import * as core from '@actions/core'
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
import {Client} from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
import {StreamableHTTPClientTransport} from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||||
|
|
||||||
export interface ToolResult {
|
export interface ToolResult {
|
||||||
tool_call_id: string
|
tool_call_id: string
|
||||||
@@ -35,9 +35,7 @@ export interface GitHubMCPClient {
|
|||||||
/**
|
/**
|
||||||
* Connect to the GitHub MCP server and retrieve available tools
|
* Connect to the GitHub MCP server and retrieve available tools
|
||||||
*/
|
*/
|
||||||
export async function connectToGitHubMCP(
|
export async function connectToGitHubMCP(token: string): Promise<GitHubMCPClient | null> {
|
||||||
token: string
|
|
||||||
): Promise<GitHubMCPClient | null> {
|
|
||||||
const githubMcpUrl = 'https://api.githubcopilot.com/mcp/'
|
const githubMcpUrl = 'https://api.githubcopilot.com/mcp/'
|
||||||
|
|
||||||
core.info('Connecting to GitHub MCP server...')
|
core.info('Connecting to GitHub MCP server...')
|
||||||
@@ -46,15 +44,15 @@ export async function connectToGitHubMCP(
|
|||||||
requestInit: {
|
requestInit: {
|
||||||
headers: {
|
headers: {
|
||||||
Authorization: `Bearer ${token}`,
|
Authorization: `Bearer ${token}`,
|
||||||
'X-MCP-Readonly': 'true'
|
'X-MCP-Readonly': 'true',
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
const client = new Client({
|
const client = new Client({
|
||||||
name: 'ai-inference-action',
|
name: 'ai-inference-action',
|
||||||
version: '1.0.0',
|
version: '1.0.0',
|
||||||
transport
|
transport,
|
||||||
})
|
})
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -67,42 +65,35 @@ export async function connectToGitHubMCP(
|
|||||||
core.info('Successfully connected to GitHub MCP server')
|
core.info('Successfully connected to GitHub MCP server')
|
||||||
|
|
||||||
const toolsResponse = await client.listTools()
|
const toolsResponse = await client.listTools()
|
||||||
core.info(
|
core.info(`Retrieved ${toolsResponse.tools?.length || 0} tools from GitHub MCP server`)
|
||||||
`Retrieved ${toolsResponse.tools?.length || 0} tools from GitHub MCP server`
|
|
||||||
)
|
|
||||||
|
|
||||||
// Map GitHub MCP tools → Azure AI Inference tool definitions
|
// Map GitHub MCP tools → Azure AI Inference tool definitions
|
||||||
const tools = (toolsResponse.tools || []).map((t) => ({
|
const tools = (toolsResponse.tools || []).map(t => ({
|
||||||
type: 'function' as const,
|
type: 'function' as const,
|
||||||
function: {
|
function: {
|
||||||
name: t.name,
|
name: t.name,
|
||||||
description: t.description,
|
description: t.description,
|
||||||
parameters: t.inputSchema
|
parameters: t.inputSchema,
|
||||||
}
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
core.info(`Mapped ${tools.length} GitHub MCP tools for Azure AI Inference`)
|
core.info(`Mapped ${tools.length} GitHub MCP tools for Azure AI Inference`)
|
||||||
|
|
||||||
return { client, tools }
|
return {client, tools}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Execute a single tool call via GitHub MCP
|
* Execute a single tool call via GitHub MCP
|
||||||
*/
|
*/
|
||||||
export async function executeToolCall(
|
export async function executeToolCall(githubMcpClient: Client, toolCall: ToolCall): Promise<ToolResult> {
|
||||||
githubMcpClient: Client,
|
core.info(`Executing GitHub MCP tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`)
|
||||||
toolCall: ToolCall
|
|
||||||
): Promise<ToolResult> {
|
|
||||||
core.info(
|
|
||||||
`Executing GitHub MCP tool: ${toolCall.function.name} with args: ${toolCall.function.arguments}`
|
|
||||||
)
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const args = JSON.parse(toolCall.function.arguments)
|
const args = JSON.parse(toolCall.function.arguments)
|
||||||
|
|
||||||
const result = await githubMcpClient.callTool({
|
const result = await githubMcpClient.callTool({
|
||||||
name: toolCall.function.name,
|
name: toolCall.function.name,
|
||||||
arguments: args
|
arguments: args,
|
||||||
})
|
})
|
||||||
|
|
||||||
core.info(`GitHub MCP tool ${toolCall.function.name} executed successfully`)
|
core.info(`GitHub MCP tool ${toolCall.function.name} executed successfully`)
|
||||||
@@ -111,18 +102,16 @@ export async function executeToolCall(
|
|||||||
tool_call_id: toolCall.id,
|
tool_call_id: toolCall.id,
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: toolCall.function.name,
|
name: toolCall.function.name,
|
||||||
content: JSON.stringify(result.content)
|
content: JSON.stringify(result.content),
|
||||||
}
|
}
|
||||||
} catch (toolError) {
|
} catch (toolError) {
|
||||||
core.warning(
|
core.warning(`Failed to execute GitHub MCP tool ${toolCall.function.name}: ${toolError}`)
|
||||||
`Failed to execute GitHub MCP tool ${toolCall.function.name}: ${toolError}`
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
tool_call_id: toolCall.id,
|
tool_call_id: toolCall.id,
|
||||||
role: 'tool',
|
role: 'tool',
|
||||||
name: toolCall.function.name,
|
name: toolCall.function.name,
|
||||||
content: `Error: ${toolError}`
|
content: `Error: ${toolError}`,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -130,10 +119,7 @@ export async function executeToolCall(
|
|||||||
/**
|
/**
|
||||||
* Execute all tool calls from a response via GitHub MCP
|
* Execute all tool calls from a response via GitHub MCP
|
||||||
*/
|
*/
|
||||||
export async function executeToolCalls(
|
export async function executeToolCalls(githubMcpClient: Client, toolCalls: ToolCall[]): Promise<ToolResult[]> {
|
||||||
githubMcpClient: Client,
|
|
||||||
toolCalls: ToolCall[]
|
|
||||||
): Promise<ToolResult[]> {
|
|
||||||
const toolResults: ToolResult[] = []
|
const toolResults: ToolResult[] = []
|
||||||
|
|
||||||
for (const toolCall of toolCalls) {
|
for (const toolCall of toolCalls) {
|
||||||
|
|||||||
+7
-24
@@ -33,26 +33,19 @@ export function parseTemplateVariables(input: string): TemplateVariables {
|
|||||||
}
|
}
|
||||||
return parsed
|
return parsed
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(
|
throw new Error(`Failed to parse template variables: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||||
`Failed to parse template variables: ${error instanceof Error ? error.message : 'Unknown error'}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Replace template variables in text using {{variable}} syntax
|
* Replace template variables in text using {{variable}} syntax
|
||||||
*/
|
*/
|
||||||
export function replaceTemplateVariables(
|
export function replaceTemplateVariables(text: string, variables: TemplateVariables): string {
|
||||||
text: string,
|
|
||||||
variables: TemplateVariables
|
|
||||||
): string {
|
|
||||||
return text.replace(/\{\{([\w.-]+)\}\}/g, (match, variableName) => {
|
return text.replace(/\{\{([\w.-]+)\}\}/g, (match, variableName) => {
|
||||||
if (variableName in variables) {
|
if (variableName in variables) {
|
||||||
return variables[variableName]
|
return variables[variableName]
|
||||||
}
|
}
|
||||||
core.warning(
|
core.warning(`Template variable '${variableName}' not found in input variables`)
|
||||||
`Template variable '${variableName}' not found in input variables`
|
|
||||||
)
|
|
||||||
return match // Return the original placeholder if variable not found
|
return match // Return the original placeholder if variable not found
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -60,10 +53,7 @@ export function replaceTemplateVariables(
|
|||||||
/**
|
/**
|
||||||
* Load and parse a prompt YAML file with template variable substitution
|
* Load and parse a prompt YAML file with template variable substitution
|
||||||
*/
|
*/
|
||||||
export function loadPromptFile(
|
export function loadPromptFile(filePath: string, templateVariables: TemplateVariables = {}): PromptConfig {
|
||||||
filePath: string,
|
|
||||||
templateVariables: TemplateVariables = {}
|
|
||||||
): PromptConfig {
|
|
||||||
if (!fs.existsSync(filePath)) {
|
if (!fs.existsSync(filePath)) {
|
||||||
throw new Error(`Prompt file not found: ${filePath}`)
|
throw new Error(`Prompt file not found: ${filePath}`)
|
||||||
}
|
}
|
||||||
@@ -71,10 +61,7 @@ export function loadPromptFile(
|
|||||||
const fileContent = fs.readFileSync(filePath, 'utf-8')
|
const fileContent = fs.readFileSync(filePath, 'utf-8')
|
||||||
|
|
||||||
// Apply template variable substitution
|
// Apply template variable substitution
|
||||||
const processedContent = replaceTemplateVariables(
|
const processedContent = replaceTemplateVariables(fileContent, templateVariables)
|
||||||
fileContent,
|
|
||||||
templateVariables
|
|
||||||
)
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const config = yaml.load(processedContent) as PromptConfig
|
const config = yaml.load(processedContent) as PromptConfig
|
||||||
@@ -86,9 +73,7 @@ export function loadPromptFile(
|
|||||||
// Validate messages
|
// Validate messages
|
||||||
for (const message of config.messages) {
|
for (const message of config.messages) {
|
||||||
if (!message.role || !message.content) {
|
if (!message.role || !message.content) {
|
||||||
throw new Error(
|
throw new Error('Each message must have "role" and "content" properties')
|
||||||
'Each message must have "role" and "content" properties'
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
if (!['system', 'user', 'assistant'].includes(message.role)) {
|
if (!['system', 'user', 'assistant'].includes(message.role)) {
|
||||||
throw new Error(`Invalid message role: ${message.role}`)
|
throw new Error(`Invalid message role: ${message.role}`)
|
||||||
@@ -97,9 +82,7 @@ export function loadPromptFile(
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(
|
throw new Error(`Failed to parse prompt file: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||||
`Failed to parse prompt file: ${error instanceof Error ? error.message : 'Unknown error'}`
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,5 @@
|
|||||||
"noEmit": true
|
"noEmit": true
|
||||||
},
|
},
|
||||||
"exclude": ["dist", "node_modules"],
|
"exclude": ["dist", "node_modules"],
|
||||||
"include": [
|
"include": ["__fixtures__", "__tests__", "src", "eslint.config.mjs", "rollup.config.ts"]
|
||||||
"__fixtures__",
|
|
||||||
"__tests__",
|
|
||||||
"src",
|
|
||||||
"eslint.config.mjs",
|
|
||||||
"rollup.config.ts"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user