From ac83ac20c444656f0f7d5bbad5b62da389395439 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 9 Mar 2026 16:24:57 -0700 Subject: [PATCH] anthropic: fix KV cache reuse degraded by tool call argument reordering Use typed structs for tool call arguments instead of map[string]any to preserve JSON key order, which Go maps do not guarantee. --- anthropic/anthropic.go | 373 ++++++++++++++++------------------- anthropic/anthropic_test.go | 169 ++++++++-------- middleware/anthropic.go | 13 +- middleware/anthropic_test.go | 18 +- model/parsers/qwen3.go | 10 +- 5 files changed, 275 insertions(+), 308 deletions(-) diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index cde6360e5..d764b2928 100755 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -68,7 +68,7 @@ type MessagesRequest struct { Model string `json:"model"` MaxTokens int `json:"max_tokens"` Messages []MessageParam `json:"messages"` - System any `json:"system,omitempty"` // string or []ContentBlock + System any `json:"system,omitempty"` // string or []map[string]any (JSON-decoded ContentBlock) Stream bool `json:"stream,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` @@ -82,8 +82,27 @@ type MessagesRequest struct { // MessageParam represents a message in the request type MessageParam struct { - Role string `json:"role"` // "user" or "assistant" - Content any `json:"content"` // string or []ContentBlock + Role string `json:"role"` // "user" or "assistant" + Content []ContentBlock `json:"content"` // always []ContentBlock; plain strings are normalized on unmarshal +} + +func (m *MessageParam) UnmarshalJSON(data []byte) error { + var raw struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + m.Role = raw.Role + + var s string + if err := json.Unmarshal(raw.Content, &s); err == nil { + m.Content = []ContentBlock{{Type: "text", Text: &s}} + return nil + } + + return json.Unmarshal(raw.Content, &m.Content) } // ContentBlock represents a content block in a message. @@ -102,9 +121,9 @@ type ContentBlock struct { Source *ImageSource `json:"source,omitempty"` // For tool_use and server_tool_use blocks - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input any `json:"input,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input api.ToolCallFunctionArguments `json:"input,omitempty"` // For tool_result and web_search_tool_result blocks ToolUseID string `json:"tool_use_id,omitempty"` @@ -377,178 +396,145 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { var messages []api.Message role := strings.ToLower(msg.Role) - switch content := msg.Content.(type) { - case string: - messages = append(messages, api.Message{Role: role, Content: content}) + var textContent strings.Builder + var images []api.ImageData + var toolCalls []api.ToolCall + var thinking string + var toolResults []api.Message + textBlocks := 0 + imageBlocks := 0 + toolUseBlocks := 0 + toolResultBlocks := 0 + serverToolUseBlocks := 0 + webSearchToolResultBlocks := 0 + thinkingBlocks := 0 + unknownBlocks := 0 - case []any: - var textContent strings.Builder - var images []api.ImageData - var toolCalls []api.ToolCall - var thinking string - var toolResults []api.Message - textBlocks := 0 - imageBlocks := 0 - toolUseBlocks := 0 - toolResultBlocks := 0 - serverToolUseBlocks := 0 - webSearchToolResultBlocks := 0 - thinkingBlocks := 0 - unknownBlocks := 0 - - for _, block := range content { - blockMap, ok := block.(map[string]any) - if !ok { - logutil.Trace("anthropic: invalid content block format", "role", role) - return nil, errors.New("invalid content block format") + for _, block := range msg.Content { + switch block.Type { + case "text": + textBlocks++ + if block.Text != nil { + textContent.WriteString(*block.Text) } - blockType, _ := blockMap["type"].(string) + case "image": + imageBlocks++ + if block.Source == nil { + logutil.Trace("anthropic: invalid image source", "role", role) + return nil, errors.New("invalid image source") + } - switch blockType { - case "text": - textBlocks++ - if text, ok := blockMap["text"].(string); ok { - textContent.WriteString(text) + if block.Source.Type == "base64" { + decoded, err := base64.StdEncoding.DecodeString(block.Source.Data) + if err != nil { + logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err) + return nil, fmt.Errorf("invalid base64 image data: %w", err) } + images = append(images, decoded) + } else { + logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type) + return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type) + } - case "image": - imageBlocks++ - source, ok := blockMap["source"].(map[string]any) - if !ok { - logutil.Trace("anthropic: invalid image source", "role", role) - return nil, errors.New("invalid image source") - } + case "tool_use": + toolUseBlocks++ + if block.ID == "" { + logutil.Trace("anthropic: tool_use block missing id", "role", role) + return nil, errors.New("tool_use block missing required 'id' field") + } + if block.Name == "" { + logutil.Trace("anthropic: tool_use block missing name", "role", role) + return nil, errors.New("tool_use block missing required 'name' field") + } + toolCalls = append(toolCalls, api.ToolCall{ + ID: block.ID, + Function: api.ToolCallFunction{ + Name: block.Name, + Arguments: block.Input, + }, + }) - sourceType, _ := source["type"].(string) - if sourceType == "base64" { - data, _ := source["data"].(string) - decoded, err := base64.StdEncoding.DecodeString(data) - if err != nil { - logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err) - return nil, fmt.Errorf("invalid base64 image data: %w", err) - } - images = append(images, decoded) - } else { - logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", sourceType) - return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", sourceType) - } - // URL images would need to be fetched - skip for now + case "tool_result": + toolResultBlocks++ + var resultContent string - case "tool_use": - toolUseBlocks++ - id, ok := blockMap["id"].(string) - if !ok { - logutil.Trace("anthropic: tool_use block missing id", "role", role) - return nil, errors.New("tool_use block missing required 'id' field") - } - name, ok := blockMap["name"].(string) - if !ok { - logutil.Trace("anthropic: tool_use block missing name", "role", role) - return nil, errors.New("tool_use block missing required 'name' field") - } - tc := api.ToolCall{ - ID: id, - Function: api.ToolCallFunction{ - Name: name, - }, - } - if input, ok := blockMap["input"].(map[string]any); ok { - tc.Function.Arguments = mapToArgs(input) - } - toolCalls = append(toolCalls, tc) - - case "tool_result": - toolResultBlocks++ - toolUseID, _ := blockMap["tool_use_id"].(string) - var resultContent string - - switch c := blockMap["content"].(type) { - case string: - resultContent = c - case []any: - for _, cb := range c { - if cbMap, ok := cb.(map[string]any); ok { - if cbMap["type"] == "text" { - if text, ok := cbMap["text"].(string); ok { - resultContent += text - } + switch c := block.Content.(type) { + case string: + resultContent = c + case []any: + for _, cb := range c { + if cbMap, ok := cb.(map[string]any); ok { + if cbMap["type"] == "text" { + if text, ok := cbMap["text"].(string); ok { + resultContent += text } } } } - - toolResults = append(toolResults, api.Message{ - Role: "tool", - Content: resultContent, - ToolCallID: toolUseID, - }) - - case "thinking": - thinkingBlocks++ - if t, ok := blockMap["thinking"].(string); ok { - thinking = t - } - - case "server_tool_use": - serverToolUseBlocks++ - id, _ := blockMap["id"].(string) - name, _ := blockMap["name"].(string) - tc := api.ToolCall{ - ID: id, - Function: api.ToolCallFunction{ - Name: name, - }, - } - if input, ok := blockMap["input"].(map[string]any); ok { - tc.Function.Arguments = mapToArgs(input) - } - toolCalls = append(toolCalls, tc) - - case "web_search_tool_result": - webSearchToolResultBlocks++ - toolUseID, _ := blockMap["tool_use_id"].(string) - toolResults = append(toolResults, api.Message{ - Role: "tool", - Content: formatWebSearchToolResultContent(blockMap["content"]), - ToolCallID: toolUseID, - }) - default: - unknownBlocks++ } - } - if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { - m := api.Message{ - Role: role, - Content: textContent.String(), - Images: images, - ToolCalls: toolCalls, - Thinking: thinking, + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: resultContent, + ToolCallID: block.ToolUseID, + }) + + case "thinking": + thinkingBlocks++ + if block.Thinking != nil { + thinking = *block.Thinking } - messages = append(messages, m) + + case "server_tool_use": + serverToolUseBlocks++ + toolCalls = append(toolCalls, api.ToolCall{ + ID: block.ID, + Function: api.ToolCallFunction{ + Name: block.Name, + Arguments: block.Input, + }, + }) + + case "web_search_tool_result": + webSearchToolResultBlocks++ + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: formatWebSearchToolResultContent(block.Content), + ToolCallID: block.ToolUseID, + }) + default: + unknownBlocks++ } - - // Add tool results as separate messages - messages = append(messages, toolResults...) - logutil.Trace("anthropic: converted block message", - "role", role, - "blocks", len(content), - "text", textBlocks, - "image", imageBlocks, - "tool_use", toolUseBlocks, - "tool_result", toolResultBlocks, - "server_tool_use", serverToolUseBlocks, - "web_search_result", webSearchToolResultBlocks, - "thinking", thinkingBlocks, - "unknown", unknownBlocks, - "messages", TraceAPIMessages(messages), - ) - - default: - return nil, fmt.Errorf("invalid message content type: %T", content) } + if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" { + m := api.Message{ + Role: role, + Content: textContent.String(), + Images: images, + ToolCalls: toolCalls, + Thinking: thinking, + } + messages = append(messages, m) + } + + // Add tool results as separate messages + messages = append(messages, toolResults...) + logutil.Trace("anthropic: converted block message", + "role", role, + "blocks", len(msg.Content), + "text", textBlocks, + "image", imageBlocks, + "tool_use", toolUseBlocks, + "tool_result", toolResultBlocks, + "server_tool_use", serverToolUseBlocks, + "web_search_result", webSearchToolResultBlocks, + "thinking", thinkingBlocks, + "unknown", unknownBlocks, + "messages", TraceAPIMessages(messages), + ) + return messages, nil } @@ -892,7 +878,7 @@ func (c *StreamConverter) Process(r api.ChatResponse) []StreamEvent { Type: "tool_use", ID: tc.ID, Name: tc.Function.Name, - Input: map[string]any{}, + Input: api.ToolCallFunctionArguments{}, }, }, }) @@ -989,15 +975,6 @@ func ptr(s string) *string { return &s } -// mapToArgs converts a map to ToolCallFunctionArguments -func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { - args := api.NewToolCallFunctionArguments() - for k, v := range m { - args.Set(k, v) - } - return args -} - // CountTokensRequest represents an Anthropic count_tokens request type CountTokensRequest struct { Model string `json:"model"` @@ -1030,17 +1007,13 @@ func estimateTokens(req CountTokensRequest) int { var totalLen int // Count system prompt - if req.System != nil { - totalLen += countAnyContent(req.System) - } + totalLen += countAnyContent(req.System) - // Count messages for _, msg := range req.Messages { // Count role (always present) totalLen += len(msg.Role) // Count content - contentLen := countAnyContent(msg.Content) - totalLen += contentLen + totalLen += countAnyContent(msg.Content) } for _, tool := range req.Tools { @@ -1063,12 +1036,25 @@ func countAnyContent(content any) int { switch c := content.(type) { case string: return len(c) - case []any: + case []ContentBlock: total := 0 for _, block := range c { total += countContentBlock(block) } return total + case []any: + total := 0 + for _, item := range c { + data, err := json.Marshal(item) + if err != nil { + continue + } + var block ContentBlock + if err := json.Unmarshal(data, &block); err == nil { + total += countContentBlock(block) + } + } + return total default: if data, err := json.Marshal(content); err == nil { return len(data) @@ -1077,38 +1063,19 @@ func countAnyContent(content any) int { } } -func countContentBlock(block any) int { - blockMap, ok := block.(map[string]any) - if !ok { - if s, ok := block.(string); ok { - return len(s) - } - return 0 - } - +func countContentBlock(block ContentBlock) int { total := 0 - blockType, _ := blockMap["type"].(string) - - if text, ok := blockMap["text"].(string); ok { - total += len(text) + if block.Text != nil { + total += len(*block.Text) } - - if thinking, ok := blockMap["thinking"].(string); ok { - total += len(thinking) + if block.Thinking != nil { + total += len(*block.Thinking) } - - if blockType == "tool_use" { - if data, err := json.Marshal(blockMap); err == nil { + if block.Type == "tool_use" || block.Type == "tool_result" { + if data, err := json.Marshal(block); err == nil { total += len(data) } } - - if blockType == "tool_result" { - if data, err := json.Marshal(blockMap); err == nil { - total += len(data) - } - } - return total } diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index faa98a2ef..7b450d47b 100755 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -15,11 +15,16 @@ const ( testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) -func testArgs(m map[string]any) api.ToolCallFunctionArguments { +// textContent is a convenience for constructing []ContentBlock with a single text block in tests. +func textContent(s string) []ContentBlock { + return []ContentBlock{{Type: "text", Text: &s}} +} + +// makeArgs creates ToolCallFunctionArguments from key-value pairs (convenience function for tests) +func makeArgs(kvs ...any) api.ToolCallFunctionArguments { args := api.NewToolCallFunctionArguments() - for k, v := range m { - args.Set(k, v) + for i := 0; i < len(kvs)-1; i += 2 { + args.Set(kvs[i].(string), kvs[i+1]) } return args } @@ -29,7 +34,7 @@ func TestFromMessagesRequest_Basic(t *testing.T) { Model: "test-model", MaxTokens: 1024, Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, }, } @@ -61,7 +66,7 @@ func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) { MaxTokens: 1024, System: "You are a helpful assistant.", Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, }, } @@ -88,7 +93,7 @@ func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) { map[string]any{"type": "text", "text": " Be concise."}, }, Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, }, } @@ -113,7 +118,7 @@ func TestFromMessagesRequest_WithOptions(t *testing.T) { req := MessagesRequest{ Model: "test-model", MaxTokens: 2048, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Temperature: &temp, TopP: &topP, TopK: &topK, @@ -148,14 +153,14 @@ func TestFromMessagesRequest_WithImage(t *testing.T) { Messages: []MessageParam{ { Role: "user", - Content: []any{ - map[string]any{"type": "text", "text": "What's in this image?"}, - map[string]any{ - "type": "image", - "source": map[string]any{ - "type": "base64", - "media_type": "image/png", - "data": testImage, + Content: []ContentBlock{ + {Type: "text", Text: ptr("What's in this image?")}, + { + Type: "image", + Source: &ImageSource{ + Type: "base64", + MediaType: "image/png", + Data: testImage, }, }, }, @@ -190,15 +195,15 @@ func TestFromMessagesRequest_WithToolUse(t *testing.T) { Model: "test-model", MaxTokens: 1024, Messages: []MessageParam{ - {Role: "user", Content: "What's the weather in Paris?"}, + {Role: "user", Content: textContent("What's the weather in Paris?")}, { Role: "assistant", - Content: []any{ - map[string]any{ - "type": "tool_use", - "id": "call_123", - "name": "get_weather", - "input": map[string]any{"location": "Paris"}, + Content: []ContentBlock{ + { + Type: "tool_use", + ID: "call_123", + Name: "get_weather", + Input: makeArgs("location", "Paris"), }, }, }, @@ -234,11 +239,11 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) { Messages: []MessageParam{ { Role: "user", - Content: []any{ - map[string]any{ - "type": "tool_result", - "tool_use_id": "call_123", - "content": "The weather in Paris is sunny, 22°C", + Content: []ContentBlock{ + { + Type: "tool_result", + ToolUseID: "call_123", + Content: "The weather in Paris is sunny, 22°C", }, }, }, @@ -270,7 +275,7 @@ func TestFromMessagesRequest_WithTools(t *testing.T) { req := MessagesRequest{ Model: "test-model", MaxTokens: 1024, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Tools: []Tool{ { Name: "get_weather", @@ -305,7 +310,7 @@ func TestFromMessagesRequest_DropsCustomWebSearchWhenBuiltinPresent(t *testing.T req := MessagesRequest{ Model: "test-model", MaxTokens: 1024, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Tools: []Tool{ { Type: "web_search_20250305", @@ -346,7 +351,7 @@ func TestFromMessagesRequest_KeepsCustomWebSearchWhenBuiltinAbsent(t *testing.T) req := MessagesRequest{ Model: "test-model", MaxTokens: 1024, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Tools: []Tool{ { Type: "custom", @@ -377,7 +382,7 @@ func TestFromMessagesRequest_WithThinking(t *testing.T) { req := MessagesRequest{ Model: "test-model", MaxTokens: 1024, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, } @@ -399,13 +404,13 @@ func TestFromMessagesRequest_ThinkingOnlyBlock(t *testing.T) { Model: "test-model", MaxTokens: 1024, Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, { Role: "assistant", - Content: []any{ - map[string]any{ - "type": "thinking", - "thinking": "Let me think about this...", + Content: []ContentBlock{ + { + Type: "thinking", + Thinking: ptr("Let me think about this..."), }, }, }, @@ -434,10 +439,10 @@ func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) { Messages: []MessageParam{ { Role: "assistant", - Content: []any{ - map[string]any{ - "type": "tool_use", - "name": "get_weather", + Content: []ContentBlock{ + { + Type: "tool_use", + Name: "get_weather", }, }, }, @@ -460,10 +465,10 @@ func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) { Messages: []MessageParam{ { Role: "assistant", - Content: []any{ - map[string]any{ - "type": "tool_use", - "id": "call_123", + Content: []ContentBlock{ + { + Type: "tool_use", + ID: "call_123", }, }, }, @@ -483,7 +488,7 @@ func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) { req := MessagesRequest{ Model: "test-model", MaxTokens: 1024, - Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Messages: []MessageParam{{Role: "user", Content: textContent("Hello")}}, Tools: []Tool{ { Name: "bad_tool", @@ -548,7 +553,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) { ID: "call_123", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: testArgs(map[string]any{"location": "Paris"}), + Arguments: makeArgs("location", "Paris"), }, }, }, @@ -760,7 +765,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) { ID: "call_123", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: testArgs(map[string]any{"location": "Paris"}), + Arguments: makeArgs("location", "Paris"), }, }, }, @@ -843,7 +848,7 @@ func TestStreamConverter_ThinkingDirectlyFollowedByToolCall(t *testing.T) { ID: "call_abc", Function: api.ToolCallFunction{ Name: "ask_user", - Arguments: testArgs(map[string]any{"question": "cats or dogs?"}), + Arguments: makeArgs("question", "cats or dogs?"), }, }, }, @@ -965,7 +970,7 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { ID: "call_good", Function: api.ToolCallFunction{ Name: "good_function", - Arguments: testArgs(map[string]any{"location": "Paris"}), + Arguments: makeArgs("location", "Paris"), }, }, { @@ -1140,7 +1145,7 @@ func TestEstimateTokens_SimpleMessage(t *testing.T) { req := CountTokensRequest{ Model: "test-model", Messages: []MessageParam{ - {Role: "user", Content: "Hello, world!"}, + {Role: "user", Content: textContent("Hello, world!")}, }, } @@ -1161,7 +1166,7 @@ func TestEstimateTokens_WithSystemPrompt(t *testing.T) { Model: "test-model", System: "You are a helpful assistant.", Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, }, } @@ -1177,7 +1182,7 @@ func TestEstimateTokens_WithTools(t *testing.T) { req := CountTokensRequest{ Model: "test-model", Messages: []MessageParam{ - {Role: "user", Content: "What's the weather?"}, + {Role: "user", Content: textContent("What's the weather?")}, }, Tools: []Tool{ { @@ -1200,17 +1205,17 @@ func TestEstimateTokens_WithThinking(t *testing.T) { req := CountTokensRequest{ Model: "test-model", Messages: []MessageParam{ - {Role: "user", Content: "Hello"}, + {Role: "user", Content: textContent("Hello")}, { Role: "assistant", - Content: []any{ - map[string]any{ - "type": "thinking", - "thinking": "Let me think about this carefully...", + Content: []ContentBlock{ + { + Type: "thinking", + Thinking: ptr("Let me think about this carefully..."), }, - map[string]any{ - "type": "text", - "text": "Here is my response.", + { + Type: "text", + Text: ptr("Here is my response."), }, }, }, @@ -1308,12 +1313,12 @@ func TestConvertTool_RegularTool(t *testing.T) { func TestConvertMessage_ServerToolUse(t *testing.T) { msg := MessageParam{ Role: "assistant", - Content: []any{ - map[string]any{ - "type": "server_tool_use", - "id": "srvtoolu_123", - "name": "web_search", - "input": map[string]any{"query": "test query"}, + Content: []ContentBlock{ + { + Type: "server_tool_use", + ID: "srvtoolu_123", + Name: "web_search", + Input: makeArgs("query", "test query"), }, }, } @@ -1344,11 +1349,11 @@ func TestConvertMessage_ServerToolUse(t *testing.T) { func TestConvertMessage_WebSearchToolResult(t *testing.T) { msg := MessageParam{ Role: "user", - Content: []any{ - map[string]any{ - "type": "web_search_tool_result", - "tool_use_id": "srvtoolu_123", - "content": []any{ + Content: []ContentBlock{ + { + Type: "web_search_tool_result", + ToolUseID: "srvtoolu_123", + Content: []any{ map[string]any{ "type": "web_search_result", "title": "Test Result", @@ -1385,11 +1390,11 @@ func TestConvertMessage_WebSearchToolResult(t *testing.T) { func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testing.T) { msg := MessageParam{ Role: "user", - Content: []any{ - map[string]any{ - "type": "web_search_tool_result", - "tool_use_id": "srvtoolu_empty", - "content": []any{}, + Content: []ContentBlock{ + { + Type: "web_search_tool_result", + ToolUseID: "srvtoolu_empty", + Content: []any{}, }, }, } @@ -1416,11 +1421,11 @@ func TestConvertMessage_WebSearchToolResultEmptyStillCreatesToolMessage(t *testi func TestConvertMessage_WebSearchToolResultErrorStillCreatesToolMessage(t *testing.T) { msg := MessageParam{ Role: "user", - Content: []any{ - map[string]any{ - "type": "web_search_tool_result", - "tool_use_id": "srvtoolu_error", - "content": map[string]any{ + Content: []ContentBlock{ + { + Type: "web_search_tool_result", + ToolUseID: "srvtoolu_error", + Content: map[string]any{ "type": "web_search_tool_result_error", "error_code": "max_uses_exceeded", }, diff --git a/middleware/anthropic.go b/middleware/anthropic.go index 6e1f66ea8..4293034b7 100644 --- a/middleware/anthropic.go +++ b/middleware/anthropic.go @@ -283,7 +283,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial Type: "server_tool_use", ID: toolUseID, Name: "web_search", - Input: map[string]any{"query": query}, + Input: queryArgs(query), }, anthropic.ContentBlock{ Type: "web_search_tool_result", @@ -348,7 +348,7 @@ func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initial Type: "server_tool_use", ID: maxLoopToolUseID, Name: "web_search", - Input: map[string]any{"query": maxLoopQuery}, + Input: queryArgs(maxLoopQuery), }, anthropic.ContentBlock{ Type: "web_search_tool_result", @@ -786,7 +786,7 @@ func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query strin Type: "server_tool_use", ID: toolUseID, Name: "web_search", - Input: map[string]any{"query": query}, + Input: queryArgs(query), }, { Type: "web_search_tool_result", @@ -942,6 +942,13 @@ func writeSSE(w http.ResponseWriter, eventType string, data any) error { return nil } +// queryArgs creates a ToolCallFunctionArguments with a single "query" key. +func queryArgs(query string) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + args.Set("query", query) + return args +} + // serverToolUseID derives a server tool use ID from a message ID func serverToolUseID(messageID string) string { return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_") diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go index bacda4b8c..87e427460 100644 --- a/middleware/anthropic_test.go +++ b/middleware/anthropic_test.go @@ -1208,7 +1208,7 @@ func TestWebSearchStreamResponse(t *testing.T) { Type: "server_tool_use", ID: "srvtoolu_test123", Name: "web_search", - Input: map[string]any{"query": "test query"}, + Input: queryArgs("test query"), }, { Type: "web_search_tool_result", @@ -1413,12 +1413,8 @@ func TestWebSearchSendError_NonStreaming(t *testing.T) { t.Errorf("expected name 'web_search', got %q", result.Content[0].Name) } // Verify input contains the query - inputMap, ok := result.Content[0].Input.(map[string]any) - if !ok { - t.Fatalf("expected Input to be map, got %T", result.Content[0].Input) - } - if inputMap["query"] != "test query" { - t.Errorf("expected query 'test query', got %v", inputMap["query"]) + if q, ok := result.Content[0].Input.Get("query"); !ok || q != "test query" { + t.Errorf("expected query 'test query', got %v", q) } // Block 1: web_search_tool_result with error @@ -1561,12 +1557,8 @@ func TestWebSearchSendError_EmptyQuery(t *testing.T) { } // Verify the input has empty query - inputMap, ok := result.Content[0].Input.(map[string]any) - if !ok { - t.Fatalf("expected Input to be map, got %T", result.Content[0].Input) - } - if inputMap["query"] != "" { - t.Errorf("expected empty query, got %v", inputMap["query"]) + if q, ok := result.Content[0].Input.Get("query"); !ok || q != "" { + t.Errorf("expected empty query, got %v", q) } } diff --git a/model/parsers/qwen3.go b/model/parsers/qwen3.go index 7e503b232..29427bbcf 100644 --- a/model/parsers/qwen3.go +++ b/model/parsers/qwen3.go @@ -328,8 +328,8 @@ func (p *Qwen3Parser) eat() ([]qwen3Event, bool) { func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) { var parsed struct { - Name string `json:"name"` - Arguments map[string]any `json:"arguments"` + Name string `json:"name"` + Arguments api.ToolCallFunctionArguments `json:"arguments"` } if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil { @@ -345,13 +345,9 @@ func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCa toolCall := api.ToolCall{ Function: api.ToolCallFunction{ Name: parsed.Name, - Arguments: api.NewToolCallFunctionArguments(), + Arguments: parsed.Arguments, }, } - for key, value := range parsed.Arguments { - toolCall.Function.Arguments.Set(key, value) - } - return toolCall, nil }