Improve token count accuracy using Anthropic's API (#44943)

Closes #38533

<img width="807" height="425" alt="Screenshot 2025-12-16 at 2 32 21 PM"
src="https://github.com/user-attachments/assets/6ebb915c-91d3-4158-a2b9-9fe17d301dd6"
/>


Release Notes:

- Use up-to-date token counts from LLM responses when reporting tokens
used per thread

---------

Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
Richard Feldman
2025-12-16 14:32:41 -05:00
committed by GitHub
parent 0c91f061c3
commit d16619a654
5 changed files with 506 additions and 59 deletions

View File

@@ -2809,3 +2809,181 @@ fn setup_context_server(
cx.run_until_parked();
mcp_tool_calls_rx
}
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// First message
let message_1_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(message_1_id.clone(), ["First message"], cx)
})
.unwrap();
cx.run_until_parked();
// Before any response, tokens_before_message should return None for first message
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.tokens_before_message(&message_1_id),
None,
"First message should have no tokens before it"
);
});
// Complete first message with usage
fake_model.send_last_completion_stream_text_chunk("Response 1");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// First message still has no tokens before it
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.tokens_before_message(&message_1_id),
None,
"First message should still have no tokens before it after response"
);
});
// Second message
let message_2_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(message_2_id.clone(), ["Second message"], cx)
})
.unwrap();
cx.run_until_parked();
// Second message should have first message's input tokens before it
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.tokens_before_message(&message_2_id),
Some(100),
"Second message should have 100 tokens before it (from first request)"
);
});
// Complete second message
fake_model.send_last_completion_stream_text_chunk("Response 2");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 250, // Total for this request (includes previous context)
output_tokens: 75,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Third message
let message_3_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(message_3_id.clone(), ["Third message"], cx)
})
.unwrap();
cx.run_until_parked();
// Third message should have second message's input tokens (250) before it
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.tokens_before_message(&message_3_id),
Some(250),
"Third message should have 250 tokens before it (from second request)"
);
// Second message should still have 100
assert_eq!(
thread.tokens_before_message(&message_2_id),
Some(100),
"Second message should still have 100 tokens before it"
);
// First message still has none
assert_eq!(
thread.tokens_before_message(&message_1_id),
None,
"First message should still have no tokens before it"
);
});
}
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
// Set up three messages with responses
let message_1_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(message_1_id.clone(), ["Message 1"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Response 1");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 100,
output_tokens: 50,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
let message_2_id = UserMessageId::new();
thread
.update(cx, |thread, cx| {
thread.send(message_2_id.clone(), ["Message 2"], cx)
})
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Response 2");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
language_model::TokenUsage {
input_tokens: 250,
output_tokens: 75,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
));
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Verify initial state
thread.read_with(cx, |thread, _| {
assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
});
// Truncate at message 2 (removes message 2 and everything after)
thread
.update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
.unwrap();
cx.run_until_parked();
// After truncation, message_2_id no longer exists, so lookup should return None
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.tokens_before_message(&message_2_id),
None,
"After truncation, message 2 no longer exists"
);
// Message 1 still exists but has no tokens before it
assert_eq!(
thread.tokens_before_message(&message_1_id),
None,
"First message still has no tokens before it"
);
});
}

View File

@@ -1095,6 +1095,28 @@ impl Thread {
})
}
/// Get the total input token count as of the message before the given message.
///
/// Returns `None` if:
/// - `target_id` is the first message (no previous message)
/// - The previous message hasn't received a response yet (no usage data)
/// - `target_id` is not found in the messages
pub fn tokens_before_message(&self, target_id: &UserMessageId) -> Option<u64> {
let mut previous_user_message_id: Option<&UserMessageId> = None;
for message in &self.messages {
if let Message::User(user_msg) = message {
if &user_msg.id == target_id {
let prev_id = previous_user_message_id?;
let usage = self.request_token_usage.get(prev_id)?;
return Some(usage.input_tokens);
}
previous_user_message_id = Some(&user_msg.id);
}
}
None
}
/// Look up the active profile and resolve its preferred model if one is configured.
fn resolve_profile_model(
profile_id: &AgentProfileId,

View File

@@ -1052,6 +1052,71 @@ pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
.ok()
}
/// Request body for the token counting API.
/// Similar to `Request` but without `max_tokens` since it's not needed for counting.
#[derive(Debug, Serialize)]
pub struct CountTokensRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub system: Option<StringOrContents>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thinking: Option<Thinking>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
/// Response from the token counting API.
#[derive(Debug, Deserialize)]
pub struct CountTokensResponse {
pub input_tokens: u64,
}
/// Count the number of tokens in a message without creating it.
pub async fn count_tokens(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CountTokensRequest,
) -> Result<CountTokensResponse, AnthropicError> {
let uri = format!("{api_url}/v1/messages/count_tokens");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("X-Api-Key", api_key.trim())
.header("Content-Type", "application/json");
let serialized_request =
serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
let http_request = request_builder
.body(AsyncBody::from(serialized_request))
.map_err(AnthropicError::BuildRequestBody)?;
let mut response = client
.send(http_request)
.await
.map_err(AnthropicError::HttpSend)?;
let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
let mut body = String::new();
response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(AnthropicError::ReadResponse)?;
serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
} else {
Err(handle_error_response(response, rate_limits).await)
}
}
#[test]
fn test_match_window_exceeded() {
let error = ApiError {

View File

@@ -1,6 +1,6 @@
use anthropic::{
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
ToolResultContent, ToolResultPart, Usage,
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, CountTokensRequest, Event,
ResponseContent, ToolResultContent, ToolResultPart, Usage,
};
use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
@@ -219,68 +219,215 @@ pub struct AnthropicModel {
request_limiter: RateLimiter,
}
pub fn count_anthropic_tokens(
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
pub fn into_anthropic_count_tokens_request(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request.messages;
let mut tokens_from_images = 0;
let mut string_messages = Vec::with_capacity(messages.len());
model: String,
mode: AnthropicModelMode,
) -> CountTokensRequest {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in messages {
use language_model::MessageContent;
for message in request.messages {
if message.contents_empty() {
continue;
}
let mut string_contents = String::new();
match message.role {
Role::User | Role::Assistant => {
let anthropic_message_content: Vec<anthropic::RequestContent> = message
.content
.into_iter()
.filter_map(|content| match content {
MessageContent::Text(text) => {
let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
text.trim_end().to_string()
} else {
text
};
if !text.is_empty() {
Some(anthropic::RequestContent::Text {
text,
cache_control: None,
})
} else {
None
}
}
MessageContent::Thinking {
text: thinking,
signature,
} => {
if !thinking.is_empty() {
Some(anthropic::RequestContent::Thinking {
thinking,
signature: signature.unwrap_or_default(),
cache_control: None,
})
} else {
None
}
}
MessageContent::RedactedThinking(data) => {
if !data.is_empty() {
Some(anthropic::RequestContent::RedactedThinking { data })
} else {
None
}
}
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
cache_control: None,
}),
MessageContent::ToolUse(tool_use) => {
Some(anthropic::RequestContent::ToolUse {
id: tool_use.id.to_string(),
name: tool_use.name.to_string(),
input: tool_use.input,
cache_control: None,
})
}
MessageContent::ToolResult(tool_result) => {
Some(anthropic::RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id.to_string(),
is_error: tool_result.is_error,
content: match tool_result.content {
LanguageModelToolResultContent::Text(text) => {
ToolResultContent::Plain(text.to_string())
}
LanguageModelToolResultContent::Image(image) => {
ToolResultContent::Multipart(vec![ToolResultPart::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
}])
}
},
cache_control: None,
})
}
})
.collect();
let anthropic_role = match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("System role should never occur here"),
};
if let Some(last_message) = new_messages.last_mut()
&& last_message.role == anthropic_role
{
last_message.content.extend(anthropic_message_content);
continue;
}
for content in message.content {
match content {
MessageContent::Text(text) => {
string_contents.push_str(&text);
new_messages.push(anthropic::Message {
role: anthropic_role,
content: anthropic_message_content,
});
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.string_contents());
}
}
}
CountTokensRequest {
model,
messages: new_messages,
system: if system_message.is_empty() {
None
} else {
Some(anthropic::StringOrContents::String(system_message))
},
thinking: if request.thinking_allowed
&& let AnthropicModelMode::Thinking { budget_tokens } = mode
{
Some(anthropic::Thinking::Enabled { budget_tokens })
} else {
None
},
tools: request
.tools
.into_iter()
.map(|tool| anthropic::Tool {
name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
})
.collect(),
tool_choice: request.tool_choice.map(|choice| match choice {
LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
LanguageModelToolChoice::None => anthropic::ToolChoice::None,
}),
}
}
/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
let messages = request.messages;
let mut tokens_from_images = 0;
let mut string_messages = Vec::with_capacity(messages.len());
for message in messages {
let mut string_contents = String::new();
for content in message.content {
match content {
MessageContent::Text(text) => {
string_contents.push_str(&text);
}
MessageContent::Thinking { .. } => {
// Thinking blocks are not included in the input token count.
}
MessageContent::RedactedThinking(_) => {
// Thinking blocks are not included in the input token count.
}
MessageContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
MessageContent::ToolUse(_tool_use) => {
// TODO: Estimate token usage from tool uses.
}
MessageContent::ToolResult(tool_result) => match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
string_contents.push_str(text);
}
MessageContent::Thinking { .. } => {
// Thinking blocks are not included in the input token count.
}
MessageContent::RedactedThinking(_) => {
// Thinking blocks are not included in the input token count.
}
MessageContent::Image(image) => {
LanguageModelToolResultContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
MessageContent::ToolUse(_tool_use) => {
// TODO: Estimate token usage from tool uses.
}
MessageContent::ToolResult(tool_result) => match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
string_contents.push_str(text);
}
LanguageModelToolResultContent::Image(image) => {
tokens_from_images += image.estimate_tokens();
}
},
}
}
if !string_contents.is_empty() {
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(string_contents),
name: None,
function_call: None,
});
},
}
}
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
.map(|tokens| (tokens + tokens_from_images) as u64)
})
.boxed()
if !string_contents.is_empty() {
string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(string_contents),
name: None,
function_call: None,
});
}
}
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
.map(|tokens| (tokens + tokens_from_images) as u64)
}
impl AnthropicModel {
@@ -386,7 +533,40 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
count_anthropic_tokens(request, cx)
let http_client = self.http_client.clone();
let model_id = self.model.request_id().to_string();
let mode = self.model.mode();
let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
let api_url = AnthropicLanguageModelProvider::api_url(cx);
(
state.api_key_state.key(&api_url).map(|k| k.to_string()),
api_url.to_string(),
)
});
async move {
// If no API key, fall back to tiktoken estimation
let Some(api_key) = api_key else {
return count_anthropic_tokens_with_tiktoken(request);
};
let count_request =
into_anthropic_count_tokens_request(request.clone(), model_id, mode);
match anthropic::count_tokens(http_client.as_ref(), &api_url, &api_key, count_request)
.await
{
Ok(response) => Ok(response.input_tokens),
Err(err) => {
log::error!(
"Anthropic count_tokens API failed, falling back to tiktoken: {err:?}"
);
count_anthropic_tokens_with_tiktoken(request)
}
}
}
.boxed()
}
fn stream_completion(

View File

@@ -42,7 +42,9 @@ use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::anthropic::{
AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
use crate::provider::x_ai::count_xai_tokens;
@@ -667,9 +669,9 @@ impl LanguageModel for CloudLanguageModel {
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic => {
count_anthropic_tokens(request, cx)
}
cloud_llm_client::LanguageModelProvider::Anthropic => cx
.background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
.boxed(),
cloud_llm_client::LanguageModelProvider::OpenAi => {
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,