diff --git a/__tests__/helpers-inference.test.ts b/__tests__/helpers-inference.test.ts index 6f46b3e..358d36e 100644 --- a/__tests__/helpers-inference.test.ts +++ b/__tests__/helpers-inference.test.ts @@ -109,6 +109,7 @@ describe('helpers.ts - inference request building', () => { undefined, undefined, 100, + undefined, 'https://api.test.com', 'test-token', ) @@ -122,6 +123,7 @@ describe('helpers.ts - inference request building', () => { temperature: undefined, topP: undefined, maxTokens: 100, + maxCompletionTokens: undefined, endpoint: 'https://api.test.com', token: 'test-token', responseFormat: { @@ -143,6 +145,7 @@ describe('helpers.ts - inference request building', () => { undefined, undefined, 100, + undefined, 'https://api.test.com', 'test-token', ) @@ -156,6 +159,7 @@ describe('helpers.ts - inference request building', () => { temperature: undefined, topP: undefined, maxTokens: 100, + maxCompletionTokens: undefined, endpoint: 'https://api.test.com', token: 'test-token', responseFormat: undefined, diff --git a/__tests__/inference.test.ts b/__tests__/inference.test.ts index 5c10fc6..4775ee7 100644 --- a/__tests__/inference.test.ts +++ b/__tests__/inference.test.ts @@ -31,7 +31,7 @@ describe('inference.ts', () => { {role: 'user' as const, content: 'Hello, AI!'}, ], modelName: 'gpt-4', - maxTokens: 100, + maxCompletionTokens: 100, endpoint: 'https://api.test.com', token: 'test-token', } @@ -633,4 +633,35 @@ describe('inference.ts', () => { expect(result).toBe('{"immediate": "result"}') }) }) + + describe('token param routing', () => { + it('sends max_tokens when only maxTokens is set', async () => { + const requestWithMaxTokens = { + ...mockRequest, + maxCompletionTokens: undefined, + maxTokens: 100, + } + + const mockResponse = { + choices: [ + { + message: { + content: 'Direct max_tokens response', + }, + }, + ], + } + + mockCreate.mockResolvedValueOnce(mockResponse) + + const result = await simpleInference(requestWithMaxTokens) + + expect(result).toBe('Direct max_tokens response') + expect(mockCreate).toHaveBeenCalledTimes(1) + + // Should have sent max_tokens directly + expect(mockCreate.mock.calls[0][0]).toHaveProperty('max_tokens', 100) + expect(mockCreate.mock.calls[0][0]).not.toHaveProperty('max_completion_tokens') + }) + }) }) diff --git a/__tests__/main.test.ts b/__tests__/main.test.ts index 81e29ea..578bf98 100644 --- a/__tests__/main.test.ts +++ b/__tests__/main.test.ts @@ -168,6 +168,7 @@ describe('main.ts', () => { ], modelName: 'gpt-4', maxTokens: 100, + maxCompletionTokens: undefined, endpoint: 'https://api.test.com', token: 'fake-token', responseFormat: undefined, @@ -259,6 +260,7 @@ describe('main.ts', () => { ], modelName: 'gpt-4', maxTokens: 100, + maxCompletionTokens: undefined, endpoint: 'https://api.test.com', token: 'fake-token', responseFormat: undefined, diff --git a/action.yml b/action.yml index ba19576..4fc7863 100644 --- a/action.yml +++ b/action.yml @@ -43,9 +43,13 @@ inputs: required: false default: '' max-tokens: - description: The maximum number of tokens to generate + description: Maximum tokens to generate (deprecated) required: false default: '200' + max-completion-tokens: + description: Maximum tokens to generate + required: false + default: '' temperature: description: The sampling temperature to use (0-1) required: false diff --git a/src/helpers.ts b/src/helpers.ts index ff79c0e..33bd10b 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -162,7 +162,8 @@ export function buildInferenceRequest( modelName: string, temperature: number | undefined, topP: number | undefined, - maxTokens: number, + maxTokens: number | undefined, // Deprecated + maxCompletionTokens: number | undefined, endpoint: string, token: string, customHeaders?: Record, @@ -175,7 +176,8 @@ export function buildInferenceRequest( modelName, temperature, topP, - maxTokens, + maxTokens, // Deprecated + maxCompletionTokens, endpoint, token, responseFormat, diff --git a/src/inference.ts b/src/inference.ts index c7880d7..e8d5136 100644 --- a/src/inference.ts +++ b/src/inference.ts @@ -12,7 +12,8 @@ interface ChatMessage { export interface InferenceRequest { messages: Array<{role: 'system' | 'user' | 'assistant' | 'tool'; content: string}> modelName: string - maxTokens: number + maxTokens?: number // Deprecated + maxCompletionTokens?: number endpoint: string token: string temperature?: number @@ -33,6 +34,20 @@ export interface InferenceResponse { }> } +/** + * Build the token limit params for a chat completion request. + * Only one of max_tokens or max_completion_tokens will be set. + */ +function buildMaxTokensParam(request: InferenceRequest): {max_tokens?: number; max_completion_tokens?: number} { + if (request.maxCompletionTokens != null) { + return {max_completion_tokens: request.maxCompletionTokens} + } + if (request.maxTokens != null) { + return {max_tokens: request.maxTokens} + } + return {} +} + /** * Simple one-shot inference without tools */ @@ -47,10 +62,10 @@ export async function simpleInference(request: InferenceRequest): Promise { // Get common parameters const modelName = promptConfig?.model || core.getInput('model') - let maxTokens = promptConfig?.modelParameters?.maxTokens ?? core.getInput('max-tokens') - if (typeof maxTokens === 'string') { - maxTokens = parseInt(maxTokens, 10) - } + // Parse token limit inputs + const maxCompletionTokensInput = + promptConfig?.modelParameters?.maxCompletionTokens ?? core.getInput('max-completion-tokens') + const maxCompletionTokens = maxCompletionTokensInput ? Number(maxCompletionTokensInput) : undefined + + const maxTokensInput = promptConfig?.modelParameters?.maxTokens ?? core.getInput('max-tokens') + const maxTokens = maxCompletionTokens != null ? undefined : maxTokensInput ? Number(maxTokensInput) : undefined const token = process.env['GITHUB_TOKEN'] || core.getInput('token') if (token === undefined) { @@ -85,6 +88,7 @@ export async function run(): Promise { temperature, topP, maxTokens, + maxCompletionTokens, endpoint, token, customHeaders, diff --git a/src/prompt.ts b/src/prompt.ts index 7a34b7c..57dcd5c 100644 --- a/src/prompt.ts +++ b/src/prompt.ts @@ -8,7 +8,8 @@ export interface PromptMessage { } export interface ModelParameters { - maxTokens?: number + maxTokens?: number // Deprecated + maxCompletionTokens?: number temperature?: number topP?: number }