Compare commits
20 Commits
load_diffs
...
simplify-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b74ddf020 | ||
|
|
54f20ae5d5 | ||
|
|
811efa45d0 | ||
|
|
74501e0936 | ||
|
|
2ad8bd00ce | ||
|
|
7c0663b825 | ||
|
|
94a43dc73a | ||
|
|
e8e0707256 | ||
|
|
d7c340c739 | ||
|
|
16b24e892e | ||
|
|
917148c5ce | ||
|
|
951132fc13 | ||
|
|
bf0dd4057c | ||
|
|
3c4ca3f372 | ||
|
|
03132921c7 | ||
|
|
c0fadae881 | ||
|
|
1c66c3991d | ||
|
|
7e591a7e9a | ||
|
|
c44d93745a | ||
|
|
b4e4e0d3ac |
@@ -16,7 +16,9 @@ rustflags = ["-D", "warnings"]
|
||||
debug = "limited"
|
||||
|
||||
# Use Mold on Linux, because it's faster than GNU ld and LLD.
|
||||
# We dont use wild in CI as its not production ready.
|
||||
#
|
||||
# We no longer set this in the default `config.toml` so that developers can opt in to Wild, which
|
||||
# is faster than Mold, in their own ~/.cargo/config.toml.
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
linker = "clang"
|
||||
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
|
||||
|
||||
@@ -8,14 +8,6 @@ perf-test = ["test", "--profile", "release-fast", "--lib", "--bins", "--tests",
|
||||
# Keep similar flags here to share some ccache
|
||||
perf-compare = ["run", "--profile", "release-fast", "-p", "perf", "--config", "target.'cfg(true)'.rustflags=[\"--cfg\", \"perf_enabled\"]", "--", "compare"]
|
||||
|
||||
# [target.x86_64-unknown-linux-gnu]
|
||||
# linker = "clang"
|
||||
# rustflags = ["-C", "link-arg=-fuse-ld=mold"]
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "clang"
|
||||
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
|
||||
|
||||
[target.'cfg(target_os = "windows")']
|
||||
rustflags = [
|
||||
"--cfg",
|
||||
|
||||
19
Cargo.lock
generated
19
Cargo.lock
generated
@@ -2617,23 +2617,26 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "calloop"
|
||||
version = "0.14.3"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec"
|
||||
dependencies = [
|
||||
"bitflags 2.9.4",
|
||||
"log",
|
||||
"polling",
|
||||
"rustix 1.1.2",
|
||||
"rustix 0.38.44",
|
||||
"slab",
|
||||
"tracing",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "calloop-wayland-source"
|
||||
version = "0.4.1"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "138efcf0940a02ebf0cc8d1eff41a1682a46b431630f4c52450d6265876021fa"
|
||||
checksum = "95a66a987056935f7efce4ab5668920b5d0dac4a7c99991a67395f13702ddd20"
|
||||
dependencies = [
|
||||
"calloop",
|
||||
"rustix 1.1.2",
|
||||
"rustix 0.38.44",
|
||||
"wayland-backend",
|
||||
"wayland-client",
|
||||
]
|
||||
@@ -3208,6 +3211,7 @@ dependencies = [
|
||||
"rustc-hash 2.1.1",
|
||||
"schemars 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.2",
|
||||
]
|
||||
|
||||
@@ -3687,6 +3691,7 @@ dependencies = [
|
||||
"collections",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"log",
|
||||
"net",
|
||||
"parking_lot",
|
||||
@@ -8725,7 +8730,6 @@ dependencies = [
|
||||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"vim",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
]
|
||||
@@ -10027,7 +10031,6 @@ name = "miniprofiler_ui"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui",
|
||||
"log",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"util",
|
||||
|
||||
@@ -784,7 +784,6 @@ features = [
|
||||
notify = { git = "https://github.com/zed-industries/notify.git", rev = "b4588b2e5aee68f4c0e100f140e808cbce7b1419" }
|
||||
notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "b4588b2e5aee68f4c0e100f140e808cbce7b1419" }
|
||||
windows-capture = { git = "https://github.com/zed-industries/windows-capture.git", rev = "f0d6c1b6691db75461b732f6d5ff56eed002eeb9" }
|
||||
calloop = { path = "/home/davidsk/tmp/calloop" }
|
||||
|
||||
[profile.dev]
|
||||
split-debuginfo = "unpacked"
|
||||
@@ -861,7 +860,7 @@ ui_input = { codegen-units = 1 }
|
||||
zed_actions = { codegen-units = 1 }
|
||||
|
||||
[profile.release]
|
||||
debug = "full"
|
||||
debug = "limited"
|
||||
lto = "thin"
|
||||
codegen-units = 1
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 9.3 KiB After Width: | Height: | Size: 14 KiB |
@@ -150,6 +150,7 @@ impl DbThread {
|
||||
.unwrap_or_default(),
|
||||
input: tool_use.input,
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
@@ -1108,6 +1108,7 @@ fn tool_use(
|
||||
raw_input: serde_json::to_string_pretty(&input).unwrap(),
|
||||
input: serde_json::to_value(input).unwrap(),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ use gpui::{
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
|
||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{
|
||||
@@ -274,6 +274,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
@@ -461,6 +462,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
@@ -470,6 +472,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -520,6 +523,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -554,6 +558,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -592,6 +597,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -621,6 +627,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
@@ -657,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
// Simulate reaching tool use limit.
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
|
||||
fake_model.end_last_completion_stream();
|
||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||
assert!(
|
||||
@@ -731,6 +736,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
let tool_result = LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_1".into(),
|
||||
@@ -741,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
|
||||
fake_model.end_last_completion_stream();
|
||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||
assert!(
|
||||
@@ -1037,6 +1041,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -1080,6 +1085,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "mcp"}).to_string(),
|
||||
input: json!({"text": "mcp"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
@@ -1089,6 +1095,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "native"}).to_string(),
|
||||
input: json!({"text": "native"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -1522,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||
});
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||
language_model::TokenUsage {
|
||||
input_tokens: 32_000,
|
||||
output_tokens: 16_000,
|
||||
@@ -1580,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||
language_model::TokenUsage {
|
||||
input_tokens: 40_000,
|
||||
output_tokens: 20_000,
|
||||
@@ -1625,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||
language_model::TokenUsage {
|
||||
input_tokens: 32_000,
|
||||
output_tokens: 16_000,
|
||||
@@ -1672,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
||||
cx.run_until_parked();
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||
language_model::TokenUsage {
|
||||
input_tokens: 40_000,
|
||||
output_tokens: 20_000,
|
||||
@@ -1788,6 +1795,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
let echo_tool_use = LanguageModelToolUse {
|
||||
id: "tool_id_2".into(),
|
||||
@@ -1795,6 +1803,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_text_chunk("Hi!");
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
@@ -2000,6 +2009,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: false,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
|
||||
@@ -2012,6 +2022,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -2144,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey,");
|
||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
});
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -2214,12 +2224,12 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
tool_use_1.clone(),
|
||||
));
|
||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
});
|
||||
fake_model.end_last_completion_stream();
|
||||
@@ -2286,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
||||
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
||||
fake_model.send_last_completion_stream_error(
|
||||
LanguageModelCompletionError::ServerOverloaded {
|
||||
provider: LanguageModelProviderName::new("Anthropic"),
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
},
|
||||
);
|
||||
|
||||
@@ -15,7 +15,7 @@ use agent_settings::{
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage, UserStore};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
|
||||
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::stream;
|
||||
@@ -30,11 +30,11 @@ use gpui::{
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
||||
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||
SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
@@ -1295,9 +1295,10 @@ impl Thread {
|
||||
|
||||
if let Some(error) = error {
|
||||
attempt += 1;
|
||||
let provider = model.upstream_provider_name();
|
||||
let retry = this.update(cx, |this, cx| {
|
||||
let user_store = this.user_store.read(cx);
|
||||
this.handle_completion_error(error, attempt, user_store.plan())
|
||||
this.handle_completion_error(provider, error, attempt, user_store.plan())
|
||||
})??;
|
||||
let timer = cx.background_executor().timer(retry.duration);
|
||||
event_stream.send_retry(retry);
|
||||
@@ -1323,6 +1324,7 @@ impl Thread {
|
||||
|
||||
fn handle_completion_error(
|
||||
&mut self,
|
||||
provider: LanguageModelProviderName,
|
||||
error: LanguageModelCompletionError,
|
||||
attempt: u8,
|
||||
plan: Option<Plan>,
|
||||
@@ -1389,7 +1391,7 @@ impl Thread {
|
||||
use LanguageModelCompletionEvent::*;
|
||||
|
||||
match event {
|
||||
StartMessage { .. } => {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
self.flush_pending_message(cx);
|
||||
self.pending_message = Some(AgentMessage::default());
|
||||
}
|
||||
@@ -1416,7 +1418,7 @@ impl Thread {
|
||||
),
|
||||
)));
|
||||
}
|
||||
UsageUpdate(usage) => {
|
||||
TokenUsage(usage) => {
|
||||
telemetry::event!(
|
||||
"Agent Thread Completion Usage Updated",
|
||||
thread_id = self.id.to_string(),
|
||||
@@ -1430,20 +1432,16 @@ impl Thread {
|
||||
);
|
||||
self.update_token_usage(usage, cx);
|
||||
}
|
||||
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
|
||||
RequestUsage { amount, limit } => {
|
||||
self.update_model_request_usage(amount, limit, cx);
|
||||
}
|
||||
StatusUpdate(
|
||||
CompletionRequestStatus::Started
|
||||
| CompletionRequestStatus::Queued { .. }
|
||||
| CompletionRequestStatus::Failed { .. },
|
||||
) => {}
|
||||
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
|
||||
ToolUseLimitReached => {
|
||||
self.tool_use_limit_reached = true;
|
||||
}
|
||||
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
||||
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
||||
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
||||
Started | Queued { .. } => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
@@ -1687,9 +1685,7 @@ impl Thread {
|
||||
let event = event.log_err()?;
|
||||
let text = match event {
|
||||
LanguageModelCompletionEvent::Text(text) => text,
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||
) => {
|
||||
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||
this.update(cx, |thread, cx| {
|
||||
thread.update_model_request_usage(amount, limit, cx);
|
||||
})
|
||||
@@ -1753,9 +1749,7 @@ impl Thread {
|
||||
let event = event?;
|
||||
let text = match event {
|
||||
LanguageModelCompletionEvent::Text(text) => text,
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
||||
) => {
|
||||
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||
this.update(cx, |thread, cx| {
|
||||
thread.update_model_request_usage(amount, limit, cx);
|
||||
})?;
|
||||
|
||||
@@ -247,37 +247,58 @@ impl AgentConnection for AcpConnection {
|
||||
let default_mode = self.default_mode.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
let context_server_store = project.read(cx).context_server_store().read(cx);
|
||||
let mcp_servers = if project.read(cx).is_local() {
|
||||
context_server_store
|
||||
.configured_server_ids()
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
let configuration = context_server_store.configuration_for_server(id)?;
|
||||
let command = configuration.command();
|
||||
Some(acp::McpServer::Stdio {
|
||||
name: id.0.to_string(),
|
||||
command: command.path.clone(),
|
||||
args: command.args.clone(),
|
||||
env: if let Some(env) = command.env.as_ref() {
|
||||
env.iter()
|
||||
.map(|(name, value)| acp::EnvVariable {
|
||||
name: name.clone(),
|
||||
value: value.clone(),
|
||||
meta: None,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
},
|
||||
let mcp_servers =
|
||||
if project.read(cx).is_local() {
|
||||
context_server_store
|
||||
.configured_server_ids()
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
let configuration = context_server_store.configuration_for_server(id)?;
|
||||
match &*configuration {
|
||||
project::context_server_store::ContextServerConfiguration::Custom {
|
||||
command,
|
||||
..
|
||||
}
|
||||
| project::context_server_store::ContextServerConfiguration::Extension {
|
||||
command,
|
||||
..
|
||||
} => Some(acp::McpServer::Stdio {
|
||||
name: id.0.to_string(),
|
||||
command: command.path.clone(),
|
||||
args: command.args.clone(),
|
||||
env: if let Some(env) = command.env.as_ref() {
|
||||
env.iter()
|
||||
.map(|(name, value)| acp::EnvVariable {
|
||||
name: name.clone(),
|
||||
value: value.clone(),
|
||||
meta: None,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
},
|
||||
}),
|
||||
project::context_server_store::ContextServerConfiguration::Http {
|
||||
url,
|
||||
headers,
|
||||
} => Some(acp::McpServer::Http {
|
||||
name: id.0.to_string(),
|
||||
url: url.to_string(),
|
||||
headers: headers.iter().map(|(name, value)| acp::HttpHeader {
|
||||
name: name.clone(),
|
||||
value: value.clone(),
|
||||
meta: None,
|
||||
}).collect(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
// In SSH projects, the external agent is running on the remote
|
||||
// machine, and currently we only run MCP servers on the local
|
||||
// machine. So don't pass any MCP servers to the agent in that case.
|
||||
Vec::new()
|
||||
};
|
||||
.collect()
|
||||
} else {
|
||||
// In SSH projects, the external agent is running on the remote
|
||||
// machine, and currently we only run MCP servers on the local
|
||||
// machine. So don't pass any MCP servers to the agent in that case.
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let response = conn
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
mod add_llm_provider_modal;
|
||||
mod configure_context_server_modal;
|
||||
pub mod configure_context_server_modal;
|
||||
mod configure_context_server_tools_modal;
|
||||
mod manage_profiles_modal;
|
||||
mod tool_picker;
|
||||
@@ -46,9 +46,8 @@ pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
|
||||
pub(crate) use configure_context_server_tools_modal::ConfigureContextServerToolsModal;
|
||||
pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::{
|
||||
AddContextServer,
|
||||
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
|
||||
use crate::agent_configuration::add_llm_provider_modal::{
|
||||
AddLlmProviderModal, LlmCompatibleProvider,
|
||||
};
|
||||
|
||||
pub struct AgentConfiguration {
|
||||
@@ -553,7 +552,9 @@ impl AgentConfiguration {
|
||||
move |window, cx| {
|
||||
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
|
||||
menu.entry("Add Custom Server", None, {
|
||||
|window, cx| window.dispatch_action(AddContextServer.boxed_clone(), cx)
|
||||
|window, cx| {
|
||||
window.dispatch_action(crate::AddContextServer.boxed_clone(), cx)
|
||||
}
|
||||
})
|
||||
.entry("Install from Extensions", None, {
|
||||
|window, cx| {
|
||||
@@ -651,7 +652,7 @@ impl AgentConfiguration {
|
||||
let is_running = matches!(server_status, ContextServerStatus::Running);
|
||||
let item_id = SharedString::from(context_server_id.0.clone());
|
||||
// Servers without a configuration can only be provided by extensions.
|
||||
let provided_by_extension = server_configuration.is_none_or(|config| {
|
||||
let provided_by_extension = server_configuration.as_ref().is_none_or(|config| {
|
||||
matches!(
|
||||
config.as_ref(),
|
||||
ContextServerConfiguration::Extension { .. }
|
||||
@@ -707,7 +708,10 @@ impl AgentConfiguration {
|
||||
"Server is stopped.",
|
||||
),
|
||||
};
|
||||
|
||||
let is_remote = server_configuration
|
||||
.as_ref()
|
||||
.map(|config| matches!(config.as_ref(), ContextServerConfiguration::Http { .. }))
|
||||
.unwrap_or(false);
|
||||
let context_server_configuration_menu = PopoverMenu::new("context-server-config-menu")
|
||||
.trigger_with_tooltip(
|
||||
IconButton::new("context-server-config-menu", IconName::Settings)
|
||||
@@ -730,14 +734,25 @@ impl AgentConfiguration {
|
||||
let language_registry = language_registry.clone();
|
||||
let workspace = workspace.clone();
|
||||
move |window, cx| {
|
||||
ConfigureContextServerModal::show_modal_for_existing_server(
|
||||
context_server_id.clone(),
|
||||
language_registry.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
if is_remote {
|
||||
crate::agent_configuration::configure_context_server_modal::ConfigureContextServerModal::show_modal_for_existing_server(
|
||||
context_server_id.clone(),
|
||||
language_registry.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
} else {
|
||||
ConfigureContextServerModal::show_modal_for_existing_server(
|
||||
context_server_id.clone(),
|
||||
language_registry.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}).when(tool_count > 0, |this| this.entry("View Tools", None, {
|
||||
let context_server_id = context_server_id.clone();
|
||||
|
||||
@@ -3,16 +3,42 @@ use std::sync::Arc;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use fs::Fs;
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
|
||||
use gpui::{
|
||||
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
|
||||
};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
|
||||
use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
|
||||
use ui::{
|
||||
Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState, prelude::*,
|
||||
Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
|
||||
WithScrollbar, prelude::*,
|
||||
};
|
||||
use ui_input::InputField;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
fn single_line_input(
|
||||
label: impl Into<SharedString>,
|
||||
placeholder: impl Into<SharedString>,
|
||||
text: Option<&str>,
|
||||
tab_index: isize,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<InputField> {
|
||||
cx.new(|cx| {
|
||||
let input = InputField::new(window, cx, placeholder)
|
||||
.label(label)
|
||||
.tab_index(tab_index)
|
||||
.tab_stop(true);
|
||||
|
||||
if let Some(text) = text {
|
||||
input
|
||||
.editor()
|
||||
.update(cx, |editor, cx| editor.set_text(text, window, cx));
|
||||
}
|
||||
input
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum LlmCompatibleProvider {
|
||||
OpenAi,
|
||||
@@ -41,12 +67,14 @@ struct AddLlmProviderInput {
|
||||
|
||||
impl AddLlmProviderInput {
|
||||
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
|
||||
let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
|
||||
let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
|
||||
let provider_name =
|
||||
single_line_input("Provider Name", provider.name(), None, 1, window, cx);
|
||||
let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
|
||||
let api_key = single_line_input(
|
||||
"API Key",
|
||||
"000000000000000000000000000000000000000000000000",
|
||||
None,
|
||||
3,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -55,12 +83,13 @@ impl AddLlmProviderInput {
|
||||
provider_name,
|
||||
api_url,
|
||||
api_key,
|
||||
models: vec![ModelInput::new(window, cx)],
|
||||
models: vec![ModelInput::new(0, window, cx)],
|
||||
}
|
||||
}
|
||||
|
||||
fn add_model(&mut self, window: &mut Window, cx: &mut App) {
|
||||
self.models.push(ModelInput::new(window, cx));
|
||||
let model_index = self.models.len();
|
||||
self.models.push(ModelInput::new(model_index, window, cx));
|
||||
}
|
||||
|
||||
fn remove_model(&mut self, index: usize) {
|
||||
@@ -84,11 +113,14 @@ struct ModelInput {
|
||||
}
|
||||
|
||||
impl ModelInput {
|
||||
fn new(window: &mut Window, cx: &mut App) -> Self {
|
||||
fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
|
||||
let base_tab_index = (3 + (model_index * 4)) as isize;
|
||||
|
||||
let model_name = single_line_input(
|
||||
"Model Name",
|
||||
"e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
|
||||
None,
|
||||
base_tab_index + 1,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -96,6 +128,7 @@ impl ModelInput {
|
||||
"Max Completion Tokens",
|
||||
"200000",
|
||||
Some("200000"),
|
||||
base_tab_index + 2,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -103,16 +136,26 @@ impl ModelInput {
|
||||
"Max Output Tokens",
|
||||
"Max Output Tokens",
|
||||
Some("32000"),
|
||||
base_tab_index + 3,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
|
||||
let max_tokens = single_line_input(
|
||||
"Max Tokens",
|
||||
"Max Tokens",
|
||||
Some("200000"),
|
||||
base_tab_index + 4,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let ModelCapabilities {
|
||||
tools,
|
||||
images,
|
||||
parallel_tool_calls,
|
||||
prompt_cache_key,
|
||||
} = ModelCapabilities::default();
|
||||
|
||||
Self {
|
||||
name: model_name,
|
||||
max_completion_tokens,
|
||||
@@ -165,24 +208,6 @@ impl ModelInput {
|
||||
}
|
||||
}
|
||||
|
||||
fn single_line_input(
|
||||
label: impl Into<SharedString>,
|
||||
placeholder: impl Into<SharedString>,
|
||||
text: Option<&str>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<InputField> {
|
||||
cx.new(|cx| {
|
||||
let input = InputField::new(window, cx, placeholder).label(label);
|
||||
if let Some(text) = text {
|
||||
input
|
||||
.editor()
|
||||
.update(cx, |editor, cx| editor.set_text(text, window, cx));
|
||||
}
|
||||
input
|
||||
})
|
||||
}
|
||||
|
||||
fn save_provider_to_settings(
|
||||
input: &AddLlmProviderInput,
|
||||
cx: &mut App,
|
||||
@@ -258,6 +283,7 @@ fn save_provider_to_settings(
|
||||
pub struct AddLlmProviderModal {
|
||||
provider: LlmCompatibleProvider,
|
||||
input: AddLlmProviderInput,
|
||||
scroll_handle: ScrollHandle,
|
||||
focus_handle: FocusHandle,
|
||||
last_error: Option<SharedString>,
|
||||
}
|
||||
@@ -278,6 +304,7 @@ impl AddLlmProviderModal {
|
||||
provider,
|
||||
last_error: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
scroll_handle: ScrollHandle::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,6 +445,19 @@ impl AddLlmProviderModal {
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, _: &mut Context<Self>) {
|
||||
window.focus_next();
|
||||
}
|
||||
|
||||
fn on_tab_prev(
|
||||
&mut self,
|
||||
_: &menu::SelectPrevious,
|
||||
window: &mut Window,
|
||||
_: &mut Context<Self>,
|
||||
) {
|
||||
window.focus_prev();
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
|
||||
@@ -431,15 +471,27 @@ impl Focusable for AddLlmProviderModal {
|
||||
impl ModalView for AddLlmProviderModal {}
|
||||
|
||||
impl Render for AddLlmProviderModal {
|
||||
fn render(&mut self, _window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle(cx);
|
||||
|
||||
div()
|
||||
let window_size = window.viewport_size();
|
||||
let rem_size = window.rem_size();
|
||||
let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
|
||||
|
||||
let modal_max_height = if is_large_window {
|
||||
rems_from_px(450.)
|
||||
} else {
|
||||
rems_from_px(200.)
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.id("add-llm-provider-modal")
|
||||
.key_context("AddLlmProviderModal")
|
||||
.w(rems(34.))
|
||||
.elevation_3(cx)
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(Self::on_tab))
|
||||
.on_action(cx.listener(Self::on_tab_prev))
|
||||
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
|
||||
this.focus_handle(cx).focus(window);
|
||||
}))
|
||||
@@ -462,17 +514,25 @@ impl Render for AddLlmProviderModal {
|
||||
)
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.id("modal_content")
|
||||
div()
|
||||
.size_full()
|
||||
.max_h_128()
|
||||
.overflow_y_scroll()
|
||||
.px(DynamicSpacing::Base12.rems(cx))
|
||||
.gap(DynamicSpacing::Base04.rems(cx))
|
||||
.child(self.input.provider_name.clone())
|
||||
.child(self.input.api_url.clone())
|
||||
.child(self.input.api_key.clone())
|
||||
.child(self.render_model_section(cx)),
|
||||
.vertical_scrollbar_for(self.scroll_handle.clone(), window, cx)
|
||||
.child(
|
||||
v_flex()
|
||||
.id("modal_content")
|
||||
.size_full()
|
||||
.tab_group()
|
||||
.max_h(modal_max_height)
|
||||
.pl_3()
|
||||
.pr_4()
|
||||
.gap_2()
|
||||
.overflow_y_scroll()
|
||||
.track_scroll(&self.scroll_handle)
|
||||
.child(self.input.provider_name.clone())
|
||||
.child(self.input.api_url.clone())
|
||||
.child(self.input.api_key.clone())
|
||||
.child(self.render_model_section(cx)),
|
||||
),
|
||||
)
|
||||
.footer(
|
||||
ModalFooter::new().end_slot(
|
||||
@@ -642,7 +702,7 @@ mod tests {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|window, cx| {
|
||||
let model_input = ModelInput::new(window, cx);
|
||||
let model_input = ModelInput::new(0, window, cx);
|
||||
model_input.name.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text("somemodel", window, cx);
|
||||
@@ -678,7 +738,7 @@ mod tests {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|window, cx| {
|
||||
let mut model_input = ModelInput::new(window, cx);
|
||||
let mut model_input = ModelInput::new(0, window, cx);
|
||||
model_input.name.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text("somemodel", window, cx);
|
||||
@@ -703,7 +763,7 @@ mod tests {
|
||||
let cx = setup_test(cx).await;
|
||||
|
||||
cx.update(|window, cx| {
|
||||
let mut model_input = ModelInput::new(window, cx);
|
||||
let mut model_input = ModelInput::new(0, window, cx);
|
||||
model_input.name.update(cx, |input, cx| {
|
||||
input.editor().update(cx, |editor, cx| {
|
||||
editor.set_text("somemodel", window, cx);
|
||||
@@ -767,7 +827,7 @@ mod tests {
|
||||
models.iter().enumerate()
|
||||
{
|
||||
if i >= input.models.len() {
|
||||
input.models.push(ModelInput::new(window, cx));
|
||||
input.models.push(ModelInput::new(i, window, cx));
|
||||
}
|
||||
let model = &mut input.models[i];
|
||||
set_text(&model.name, name, window, cx);
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::{ContextServerCommand, ContextServerId};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use gpui::{
|
||||
@@ -20,6 +21,7 @@ use project::{
|
||||
project_settings::{ContextServerSettings, ProjectSettings},
|
||||
worktree_store::WorktreeStore,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
@@ -37,6 +39,11 @@ enum ConfigurationTarget {
|
||||
id: ContextServerId,
|
||||
command: ContextServerCommand,
|
||||
},
|
||||
ExistingHttp {
|
||||
id: ContextServerId,
|
||||
url: String,
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
Extension {
|
||||
id: ContextServerId,
|
||||
repository_url: Option<SharedString>,
|
||||
@@ -47,9 +54,11 @@ enum ConfigurationTarget {
|
||||
enum ConfigurationSource {
|
||||
New {
|
||||
editor: Entity<Editor>,
|
||||
is_http: bool,
|
||||
},
|
||||
Existing {
|
||||
editor: Entity<Editor>,
|
||||
is_http: bool,
|
||||
},
|
||||
Extension {
|
||||
id: ContextServerId,
|
||||
@@ -97,6 +106,7 @@ impl ConfigurationSource {
|
||||
match target {
|
||||
ConfigurationTarget::New => ConfigurationSource::New {
|
||||
editor: create_editor(context_server_input(None), jsonc_language, window, cx),
|
||||
is_http: false,
|
||||
},
|
||||
ConfigurationTarget::Existing { id, command } => ConfigurationSource::Existing {
|
||||
editor: create_editor(
|
||||
@@ -105,6 +115,20 @@ impl ConfigurationSource {
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
is_http: false,
|
||||
},
|
||||
ConfigurationTarget::ExistingHttp {
|
||||
id,
|
||||
url,
|
||||
headers: auth,
|
||||
} => ConfigurationSource::Existing {
|
||||
editor: create_editor(
|
||||
context_server_http_input(Some((id, url, auth))),
|
||||
jsonc_language,
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
is_http: true,
|
||||
},
|
||||
ConfigurationTarget::Extension {
|
||||
id,
|
||||
@@ -141,16 +165,30 @@ impl ConfigurationSource {
|
||||
|
||||
fn output(&self, cx: &mut App) -> Result<(ContextServerId, ContextServerSettings)> {
|
||||
match self {
|
||||
ConfigurationSource::New { editor } | ConfigurationSource::Existing { editor } => {
|
||||
parse_input(&editor.read(cx).text(cx)).map(|(id, command)| {
|
||||
(
|
||||
id,
|
||||
ContextServerSettings::Custom {
|
||||
enabled: true,
|
||||
command,
|
||||
},
|
||||
)
|
||||
})
|
||||
ConfigurationSource::New { editor, is_http }
|
||||
| ConfigurationSource::Existing { editor, is_http } => {
|
||||
if *is_http {
|
||||
parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| {
|
||||
(
|
||||
id,
|
||||
ContextServerSettings::Http {
|
||||
enabled: true,
|
||||
url,
|
||||
headers: auth,
|
||||
},
|
||||
)
|
||||
})
|
||||
} else {
|
||||
parse_input(&editor.read(cx).text(cx)).map(|(id, command)| {
|
||||
(
|
||||
id,
|
||||
ContextServerSettings::Custom {
|
||||
enabled: true,
|
||||
command,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
ConfigurationSource::Extension {
|
||||
id,
|
||||
@@ -212,6 +250,66 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
|
||||
)
|
||||
}
|
||||
|
||||
fn context_server_http_input(
|
||||
existing: Option<(ContextServerId, String, HashMap<String, String>)>,
|
||||
) -> String {
|
||||
let (name, url, headers) = match existing {
|
||||
Some((id, url, headers)) => {
|
||||
let header = if headers.is_empty() {
|
||||
r#"// "Authorization": "Bearer <token>"#.to_string()
|
||||
} else {
|
||||
let json = serde_json::to_string_pretty(&headers).unwrap();
|
||||
let mut lines = json.split("\n").collect::<Vec<_>>();
|
||||
if lines.len() > 1 {
|
||||
lines.remove(0);
|
||||
lines.pop();
|
||||
}
|
||||
lines
|
||||
.into_iter()
|
||||
.map(|line| format!(" {}", line))
|
||||
.collect::<String>()
|
||||
};
|
||||
(id.0.to_string(), url, header)
|
||||
}
|
||||
None => (
|
||||
"some-remote-server".to_string(),
|
||||
"https://example.com/mcp".to_string(),
|
||||
r#"// "Authorization": "Bearer <token>"#.to_string(),
|
||||
),
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"{{
|
||||
/// The name of your remote MCP server
|
||||
"{name}": {{
|
||||
/// The URL of the remote MCP server
|
||||
"url": "{url}",
|
||||
"headers": {{
|
||||
/// Any headers to send along
|
||||
{headers}
|
||||
}}
|
||||
}}
|
||||
}}"#
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap<String, String>)> {
|
||||
#[derive(Deserialize)]
|
||||
struct Temp {
|
||||
url: String,
|
||||
#[serde(default)]
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
let value: HashMap<String, Temp> = serde_json_lenient::from_str(text)?;
|
||||
if value.len() != 1 {
|
||||
anyhow::bail!("Expected exactly one context server configuration");
|
||||
}
|
||||
|
||||
let (key, value) = value.into_iter().next().unwrap();
|
||||
|
||||
Ok((ContextServerId(key.into()), value.url, value.headers))
|
||||
}
|
||||
|
||||
fn resolve_context_server_extension(
|
||||
id: ContextServerId,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
@@ -312,6 +410,15 @@ impl ConfigureContextServerModal {
|
||||
id: server_id,
|
||||
command,
|
||||
}),
|
||||
ContextServerSettings::Http {
|
||||
enabled: _,
|
||||
url,
|
||||
headers,
|
||||
} => Some(ConfigurationTarget::ExistingHttp {
|
||||
id: server_id,
|
||||
url,
|
||||
headers,
|
||||
}),
|
||||
ContextServerSettings::Extension { .. } => {
|
||||
match workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
@@ -353,6 +460,7 @@ impl ConfigureContextServerModal {
|
||||
state: State::Idle,
|
||||
original_server_id: match &target {
|
||||
ConfigurationTarget::Existing { id, .. } => Some(id.clone()),
|
||||
ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()),
|
||||
ConfigurationTarget::Extension { id, .. } => Some(id.clone()),
|
||||
ConfigurationTarget::New => None,
|
||||
},
|
||||
@@ -481,7 +589,7 @@ impl ModalView for ConfigureContextServerModal {}
|
||||
impl Focusable for ConfigureContextServerModal {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
match &self.source {
|
||||
ConfigurationSource::New { editor } => editor.focus_handle(cx),
|
||||
ConfigurationSource::New { editor, .. } => editor.focus_handle(cx),
|
||||
ConfigurationSource::Existing { editor, .. } => editor.focus_handle(cx),
|
||||
ConfigurationSource::Extension { editor, .. } => editor
|
||||
.as_ref()
|
||||
@@ -527,9 +635,10 @@ impl ConfigureContextServerModal {
|
||||
}
|
||||
|
||||
fn render_modal_content(&self, cx: &App) -> AnyElement {
|
||||
// All variants now use single editor approach
|
||||
let editor = match &self.source {
|
||||
ConfigurationSource::New { editor } => editor,
|
||||
ConfigurationSource::Existing { editor } => editor,
|
||||
ConfigurationSource::New { editor, .. } => editor,
|
||||
ConfigurationSource::Existing { editor, .. } => editor,
|
||||
ConfigurationSource::Extension { editor, .. } => {
|
||||
let Some(editor) = editor else {
|
||||
return div().into_any_element();
|
||||
@@ -601,6 +710,36 @@ impl ConfigureContextServerModal {
|
||||
move |_, _, cx| cx.open_url(&repository_url)
|
||||
}),
|
||||
)
|
||||
} else if let ConfigurationSource::New { is_http, .. } = &self.source {
|
||||
let label = if *is_http {
|
||||
"Run command"
|
||||
} else {
|
||||
"Connect via HTTP"
|
||||
};
|
||||
let tooltip = if *is_http {
|
||||
"Configure an MCP serevr that runs on stdin/stdout."
|
||||
} else {
|
||||
"Configure an MCP server that you connect to over HTTP"
|
||||
};
|
||||
|
||||
Some(
|
||||
Button::new("toggle-kind", label)
|
||||
.tooltip(Tooltip::text(tooltip))
|
||||
.on_click(cx.listener(|this, _, window, cx| match &mut this.source {
|
||||
ConfigurationSource::New { editor, is_http } => {
|
||||
*is_http = !*is_http;
|
||||
let new_text = if *is_http {
|
||||
context_server_http_input(None)
|
||||
} else {
|
||||
context_server_input(None)
|
||||
};
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(new_text, window, cx);
|
||||
})
|
||||
}
|
||||
_ => {}
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
|
||||
@@ -7,9 +7,10 @@ use assistant_slash_command::{
|
||||
use assistant_slash_commands::FileCommandMetadata;
|
||||
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use cloud_llm_client::{CompletionIntent, UsageLimit};
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::{Fs, RenameOptions};
|
||||
|
||||
use futures::{FutureExt, StreamExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
|
||||
@@ -2073,14 +2074,15 @@ impl TextThread {
|
||||
});
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
||||
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
|
||||
this.update_model_request_usage(
|
||||
amount as u32,
|
||||
limit,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::Started |
|
||||
LanguageModelCompletionEvent::Queued {..} |
|
||||
LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
|
||||
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||
this.update_model_request_usage(
|
||||
amount as u32,
|
||||
limit,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
@@ -2142,7 +2144,7 @@ impl TextThread {
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(_) |
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
|
||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
||||
LanguageModelCompletionEvent::TokenUsage(_) => {}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -19,4 +19,5 @@ ordered-float.workspace = true
|
||||
rustc-hash.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
||||
@@ -40,9 +40,47 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
|
||||
pub struct SearchToolInput {
|
||||
/// An array of queries to run for gathering context relevant to the next prediction
|
||||
#[schemars(length(max = 3))]
|
||||
#[serde(deserialize_with = "deserialize_queries")]
|
||||
pub queries: Box<[SearchToolQuery]>,
|
||||
}
|
||||
|
||||
fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
use serde::de::Error;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum QueryCollection {
|
||||
Array(Box<[SearchToolQuery]>),
|
||||
DoubleArray(Box<[Box<[SearchToolQuery]>]>),
|
||||
Single(SearchToolQuery),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum MaybeDoubleEncoded {
|
||||
SingleEncoded(QueryCollection),
|
||||
DoubleEncoded(String),
|
||||
}
|
||||
|
||||
let result = MaybeDoubleEncoded::deserialize(deserializer)?;
|
||||
|
||||
let normalized = match result {
|
||||
MaybeDoubleEncoded::SingleEncoded(value) => value,
|
||||
MaybeDoubleEncoded::DoubleEncoded(value) => {
|
||||
serde_json::from_str(&value).map_err(D::Error::custom)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(match normalized {
|
||||
QueryCollection::Array(items) => items,
|
||||
QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
|
||||
QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Search for relevant code by path, syntax hierarchy, and content.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
|
||||
pub struct SearchToolQuery {
|
||||
@@ -92,3 +130,115 @@ const TOOL_USE_REMINDER: &str = indoc! {"
|
||||
--
|
||||
Analyze the user's intent in one to two sentences, then call the `search` tool.
|
||||
"};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_queries() {
|
||||
let single_query_json = indoc! {r#"{
|
||||
"queries": {
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 1);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
|
||||
let flat_json = indoc! {r#"{
|
||||
"queries": [
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
}"#};
|
||||
|
||||
let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
|
||||
assert_eq!(flat_input.queries.len(), 2);
|
||||
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(flat_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(flat_input.queries[1].content, None);
|
||||
|
||||
let nested_json = indoc! {r#"{
|
||||
"queries": [
|
||||
[
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
]
|
||||
]
|
||||
}"#};
|
||||
|
||||
let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
|
||||
|
||||
assert_eq!(nested_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(nested_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
|
||||
assert_eq!(nested_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(nested_input.queries[1].content, None);
|
||||
|
||||
let double_encoded_queries = serde_json::to_string(&json!({
|
||||
"queries": serde_json::to_string(&json!([
|
||||
{
|
||||
"glob": "**/*.rs",
|
||||
"syntax_node": ["fn test"],
|
||||
"content": "assert"
|
||||
},
|
||||
{
|
||||
"glob": "**/*.ts",
|
||||
"syntax_node": [],
|
||||
"content": null
|
||||
}
|
||||
])).unwrap()
|
||||
}))
|
||||
.unwrap();
|
||||
|
||||
let double_encoded_input: SearchToolInput =
|
||||
serde_json::from_str(&double_encoded_queries).unwrap();
|
||||
|
||||
assert_eq!(double_encoded_input.queries.len(), 2);
|
||||
|
||||
assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
|
||||
assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
|
||||
assert_eq!(
|
||||
double_encoded_input.queries[0].content,
|
||||
Some("assert".to_string())
|
||||
);
|
||||
assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
|
||||
assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
|
||||
assert_eq!(double_encoded_input.queries[1].content, None);
|
||||
|
||||
// ### ERROR Switching from var declarations to lexical declarations [RUN 073]
|
||||
// invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ workspace = true
|
||||
path = "src/context_server.rs"
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
test-support = ["gpui/test-support"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
@@ -20,6 +20,7 @@ async-trait.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client = { workspace = true, features = ["test-support"] }
|
||||
log.workspace = true
|
||||
net.workspace = true
|
||||
parking_lot.workspace = true
|
||||
@@ -32,3 +33,6 @@ smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -6,6 +6,8 @@ pub mod test;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
use collections::HashMap;
|
||||
use http_client::HttpClient;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{fmt::Display, path::PathBuf};
|
||||
@@ -15,6 +17,9 @@ use client::Client;
|
||||
use gpui::AsyncApp;
|
||||
use parking_lot::RwLock;
|
||||
pub use settings::ContextServerCommand;
|
||||
use url::Url;
|
||||
|
||||
use crate::transport::HttpTransport;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ContextServerId(pub Arc<str>);
|
||||
@@ -52,6 +57,25 @@ impl ContextServer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn http(
|
||||
id: ContextServerId,
|
||||
endpoint: &Url,
|
||||
headers: HashMap<String, String>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
executor: gpui::BackgroundExecutor,
|
||||
) -> Result<Self> {
|
||||
let transport = match endpoint.scheme() {
|
||||
"http" | "https" => {
|
||||
log::info!("Using HTTP transport for {}", endpoint);
|
||||
let transport =
|
||||
HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
|
||||
Arc::new(transport) as _
|
||||
}
|
||||
_ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
|
||||
};
|
||||
Ok(Self::new(id, transport))
|
||||
}
|
||||
|
||||
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
pub mod http;
|
||||
mod stdio_transport;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
pub use http::*;
|
||||
pub use stdio_transport::*;
|
||||
|
||||
#[async_trait]
|
||||
|
||||
259
crates/context_server/src/transport/http.rs
Normal file
259
crates/context_server/src/transport/http.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
use futures::{Stream, StreamExt};
|
||||
use gpui::BackgroundExecutor;
|
||||
use http_client::{AsyncBody, HttpClient, Request, Response, http::Method};
|
||||
use parking_lot::Mutex as SyncMutex;
|
||||
use smol::channel;
|
||||
use std::{pin::Pin, sync::Arc};
|
||||
|
||||
use crate::transport::Transport;
|
||||
|
||||
// Constants from MCP spec
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
|
||||
const JSON_MIME_TYPE: &str = "application/json";
|
||||
|
||||
/// HTTP Transport with session management and SSE support
|
||||
pub struct HttpTransport {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
endpoint: String,
|
||||
session_id: Arc<SyncMutex<Option<String>>>,
|
||||
executor: BackgroundExecutor,
|
||||
response_tx: channel::Sender<String>,
|
||||
response_rx: channel::Receiver<String>,
|
||||
error_tx: channel::Sender<String>,
|
||||
error_rx: channel::Receiver<String>,
|
||||
// Authentication headers to include in requests
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl HttpTransport {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
endpoint: String,
|
||||
headers: HashMap<String, String>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (response_tx, response_rx) = channel::unbounded();
|
||||
let (error_tx, error_rx) = channel::unbounded();
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
executor,
|
||||
endpoint,
|
||||
session_id: Arc::new(SyncMutex::new(None)),
|
||||
response_tx,
|
||||
response_rx,
|
||||
error_tx,
|
||||
error_rx,
|
||||
headers,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message and handle the response based on content type
|
||||
async fn send_message(&self, message: String) -> Result<()> {
|
||||
let is_notification =
|
||||
!message.contains("\"id\":") || message.contains("notifications/initialized");
|
||||
|
||||
let mut request_builder = Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(&self.endpoint)
|
||||
.header("Content-Type", JSON_MIME_TYPE)
|
||||
.header(
|
||||
"Accept",
|
||||
format!("{}, {}", JSON_MIME_TYPE, EVENT_STREAM_MIME_TYPE),
|
||||
);
|
||||
|
||||
for (key, value) in &self.headers {
|
||||
request_builder = request_builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
// Add session ID if we have one (except for initialize)
|
||||
if let Some(ref session_id) = *self.session_id.lock() {
|
||||
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
|
||||
}
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
// Handle different response types based on status and content-type
|
||||
match response.status() {
|
||||
status if status.is_success() => {
|
||||
// Check content type
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
// Extract session ID from response headers if present
|
||||
if let Some(session_id) = response
|
||||
.headers()
|
||||
.get(HEADER_SESSION_ID)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
*self.session_id.lock() = Some(session_id.to_string());
|
||||
log::debug!("Session ID set: {}", session_id);
|
||||
}
|
||||
|
||||
match content_type {
|
||||
Some(ct) if ct.starts_with(JSON_MIME_TYPE) => {
|
||||
// JSON response - read and forward immediately
|
||||
let mut body = String::new();
|
||||
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
|
||||
.await?;
|
||||
|
||||
// Only send non-empty responses
|
||||
if !body.is_empty() {
|
||||
self.response_tx
|
||||
.send(body)
|
||||
.await
|
||||
.map_err(|_| anyhow!("Failed to send JSON response"))?;
|
||||
}
|
||||
}
|
||||
Some(ct) if ct.starts_with(EVENT_STREAM_MIME_TYPE) => {
|
||||
// SSE stream - set up streaming
|
||||
self.setup_sse_stream(response).await?;
|
||||
}
|
||||
_ => {
|
||||
// For notifications, 202 Accepted with no content type is ok
|
||||
if is_notification && status.as_u16() == 202 {
|
||||
log::debug!("Notification accepted");
|
||||
} else {
|
||||
return Err(anyhow!("Unexpected content type: {:?}", content_type));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status if status.as_u16() == 202 => {
|
||||
// Accepted - notification acknowledged, no response needed
|
||||
log::debug!("Notification accepted");
|
||||
}
|
||||
_ => {
|
||||
let mut error_body = String::new();
|
||||
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;
|
||||
|
||||
self.error_tx
|
||||
.send(format!("HTTP {}: {}", response.status(), error_body))
|
||||
.await
|
||||
.map_err(|_| anyhow!("Failed to send error"))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set up SSE streaming from the response
|
||||
async fn setup_sse_stream(&self, mut response: Response<AsyncBody>) -> Result<()> {
|
||||
let response_tx = self.response_tx.clone();
|
||||
let error_tx = self.error_tx.clone();
|
||||
|
||||
// Spawn a task to handle the SSE stream
|
||||
smol::spawn(async move {
|
||||
let reader = futures::io::BufReader::new(response.body_mut());
|
||||
let mut lines = futures::AsyncBufReadExt::lines(reader);
|
||||
|
||||
let mut data_buffer = Vec::new();
|
||||
let mut in_message = false;
|
||||
|
||||
while let Some(line_result) = lines.next().await {
|
||||
match line_result {
|
||||
Ok(line) => {
|
||||
if line.is_empty() {
|
||||
// Empty line signals end of event
|
||||
if !data_buffer.is_empty() {
|
||||
let message = data_buffer.join("\n");
|
||||
|
||||
// Filter out ping messages and empty data
|
||||
if !message.trim().is_empty() && message != "ping" {
|
||||
if let Err(e) = response_tx.send(message).await {
|
||||
log::error!("Failed to send SSE message: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
data_buffer.clear();
|
||||
}
|
||||
in_message = false;
|
||||
} else if let Some(data) = line.strip_prefix("data: ") {
|
||||
// Handle data lines
|
||||
let data = data.trim();
|
||||
if !data.is_empty() {
|
||||
// Check if this is a ping message
|
||||
if data == "ping" {
|
||||
log::trace!("Received SSE ping");
|
||||
continue;
|
||||
}
|
||||
data_buffer.push(data.to_string());
|
||||
in_message = true;
|
||||
}
|
||||
} else if line.starts_with("event:")
|
||||
|| line.starts_with("id:")
|
||||
|| line.starts_with("retry:")
|
||||
{
|
||||
// Ignore other SSE fields
|
||||
continue;
|
||||
} else if in_message {
|
||||
// Continuation of data
|
||||
data_buffer.push(line);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = error_tx.send(format!("SSE stream error: {}", e)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Transport for HttpTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
self.send_message(message).await
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(self.response_rx.clone())
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(self.error_rx.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HttpTransport {
|
||||
fn drop(&mut self) {
|
||||
// Try to cleanup session on drop
|
||||
let http_client = self.http_client.clone();
|
||||
let endpoint = self.endpoint.clone();
|
||||
let session_id = self.session_id.lock().clone();
|
||||
let headers = self.headers.clone();
|
||||
|
||||
if let Some(session_id) = session_id {
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let mut request_builder = Request::builder()
|
||||
.method(Method::DELETE)
|
||||
.uri(&endpoint)
|
||||
.header(HEADER_SESSION_ID, &session_id);
|
||||
|
||||
// Add authentication headers if present
|
||||
for (key, value) in headers {
|
||||
request_builder = request_builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let request = request_builder.body(AsyncBody::empty());
|
||||
|
||||
if let Ok(request) = request {
|
||||
let _ = http_client.send(request).await;
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -289,15 +289,11 @@ impl minidumper::ServerHandler for CrashServer {
|
||||
pub fn panic_hook(info: &PanicHookInfo) {
|
||||
// Don't handle a panic on threads that are not relevant to the main execution.
|
||||
if extension_host::wasm_host::IS_WASM_THREAD.with(|v| v.load(Ordering::Acquire)) {
|
||||
log::error!("wasm thread panicked!");
|
||||
return;
|
||||
}
|
||||
|
||||
let message = info
|
||||
.payload()
|
||||
.downcast_ref::<&str>()
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| info.payload().downcast_ref::<String>().cloned())
|
||||
.unwrap_or_else(|| "Box<Any>".to_string());
|
||||
let message = info.payload_as_str().unwrap_or("Box<Any>").to_owned();
|
||||
|
||||
let span = info
|
||||
.location()
|
||||
|
||||
@@ -3291,8 +3291,8 @@ impl Editor {
|
||||
self.refresh_document_highlights(cx);
|
||||
refresh_linked_ranges(self, window, cx);
|
||||
|
||||
// self.refresh_selected_text_highlights(false, window, cx);
|
||||
// self.refresh_matching_bracket_highlights(window, cx);
|
||||
self.refresh_selected_text_highlights(false, window, cx);
|
||||
self.refresh_matching_bracket_highlights(window, cx);
|
||||
self.update_visible_edit_prediction(window, cx);
|
||||
self.edit_prediction_requires_modifier_in_indent_conflict = true;
|
||||
self.inline_blame_popover.take();
|
||||
@@ -21248,9 +21248,9 @@ impl Editor {
|
||||
self.active_indent_guides_state.dirty = true;
|
||||
self.refresh_active_diagnostics(cx);
|
||||
self.refresh_code_actions(window, cx);
|
||||
// self.refresh_selected_text_highlights(true, window, cx);
|
||||
self.refresh_selected_text_highlights(true, window, cx);
|
||||
self.refresh_single_line_folds(window, cx);
|
||||
// self.refresh_matching_bracket_highlights(window, cx);
|
||||
self.refresh_matching_bracket_highlights(window, cx);
|
||||
if self.has_active_edit_prediction() {
|
||||
self.update_visible_edit_prediction(window, cx);
|
||||
}
|
||||
@@ -21345,7 +21345,6 @@ impl Editor {
|
||||
}
|
||||
multi_buffer::Event::Reparsed(buffer_id) => {
|
||||
self.tasks_update_task = Some(self.refresh_runnables(window, cx));
|
||||
// self.refresh_selected_text_highlights(true, window, cx);
|
||||
jsx_tag_auto_close::refresh_enabled_in_any_buffer(self, multibuffer, cx);
|
||||
|
||||
cx.emit(EditorEvent::Reparsed(*buffer_id));
|
||||
@@ -23908,10 +23907,6 @@ impl EditorSnapshot {
|
||||
self.scroll_anchor.scroll_position(&self.display_snapshot)
|
||||
}
|
||||
|
||||
pub fn scroll_near_end(&self) -> bool {
|
||||
self.scroll_anchor.near_end(&self.display_snapshot)
|
||||
}
|
||||
|
||||
fn gutter_dimensions(
|
||||
&self,
|
||||
font_id: FontId,
|
||||
|
||||
@@ -9055,9 +9055,6 @@ impl Element for EditorElement {
|
||||
)
|
||||
});
|
||||
|
||||
if snapshot.scroll_near_end() {
|
||||
dbg!("near end!");
|
||||
}
|
||||
let mut scroll_position = snapshot.scroll_position();
|
||||
// The scroll position is a fractional point, the whole number of which represents
|
||||
// the top of the window in terms of display rows.
|
||||
|
||||
@@ -46,20 +46,12 @@ impl ScrollAnchor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn near_end(&self, snapshot: &DisplaySnapshot) -> bool {
|
||||
let editor_length = snapshot.max_point().row().as_f64();
|
||||
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f64();
|
||||
(scroll_top - editor_length).abs() < 300.0
|
||||
}
|
||||
|
||||
pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point<ScrollOffset> {
|
||||
self.offset.apply_along(Axis::Vertical, |offset| {
|
||||
if self.anchor == Anchor::min() {
|
||||
0.
|
||||
} else {
|
||||
dbg!(snapshot.max_point().row().as_f64());
|
||||
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f64();
|
||||
dbg!(scroll_top, offset);
|
||||
(offset + scroll_top).max(0.)
|
||||
}
|
||||
})
|
||||
@@ -251,11 +243,6 @@ impl ScrollManager {
|
||||
}
|
||||
}
|
||||
};
|
||||
let near_end = self.anchor.near_end(map);
|
||||
// // TODO call load more here
|
||||
// if near_end {
|
||||
// cx.read();
|
||||
// }
|
||||
|
||||
let scroll_top_row = DisplayRow(scroll_top as u32);
|
||||
let scroll_top_buffer_point = map
|
||||
|
||||
@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
|
||||
));
|
||||
}
|
||||
Ok(
|
||||
LanguageModelCompletionEvent::UsageUpdate(_)
|
||||
LanguageModelCompletionEvent::TokenUsage(_)
|
||||
| LanguageModelCompletionEvent::ToolUseLimitReached
|
||||
| LanguageModelCompletionEvent::StartMessage { .. }
|
||||
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
||||
| LanguageModelCompletionEvent::RequestUsage { .. }
|
||||
| LanguageModelCompletionEvent::Queued { .. }
|
||||
| LanguageModelCompletionEvent::Started,
|
||||
) => {}
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
json_parse_error, ..
|
||||
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
|
||||
}
|
||||
|
||||
// Skip these
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
|
||||
Ok(LanguageModelCompletionEvent::TokenUsage(_))
|
||||
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
|
||||
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
|
||||
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
||||
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
||||
| Ok(LanguageModelCompletionEvent::Stop(_))
|
||||
| Ok(LanguageModelCompletionEvent::Queued { .. })
|
||||
| Ok(LanguageModelCompletionEvent::Started)
|
||||
| Ok(LanguageModelCompletionEvent::RequestUsage { .. })
|
||||
| Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
|
||||
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
json_parse_error,
|
||||
|
||||
@@ -537,7 +537,6 @@ fn wasm_engine(executor: &BackgroundExecutor) -> wasmtime::Engine {
|
||||
let engine_ref = engine.weak();
|
||||
executor
|
||||
.spawn(async move {
|
||||
IS_WASM_THREAD.with(|v| v.store(true, Ordering::Release));
|
||||
// Somewhat arbitrary interval, as it isn't a guaranteed interval.
|
||||
// But this is a rough upper bound for how long the extension execution can block on
|
||||
// `Future::poll`.
|
||||
@@ -643,6 +642,12 @@ impl WasmHost {
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded::<ExtensionCall>();
|
||||
let extension_task = async move {
|
||||
// note: Setting the thread local here will slowly "poison" all tokio threads
|
||||
// causing us to not record their panics any longer.
|
||||
//
|
||||
// This is fine though, the main zed binary only uses tokio for livekit and wasm extensions.
|
||||
// Livekit seldom (if ever) panics 🤞 so the likelihood of us missing a panic in sentry is very low.
|
||||
IS_WASM_THREAD.with(|v| v.store(true, Ordering::Release));
|
||||
while let Some(call) = rx.next().await {
|
||||
(call)(&mut extension, &mut store).await;
|
||||
}
|
||||
@@ -659,8 +664,8 @@ impl WasmHost {
|
||||
cx.spawn(async move |cx| {
|
||||
let (extension_task, manifest, work_dir, tx, zed_api_version) =
|
||||
cx.background_executor().spawn(load_extension_task).await?;
|
||||
// we need to run run the task in an extension context as wasmtime_wasi may
|
||||
// call into tokio, accessing its runtime handle
|
||||
// we need to run run the task in a tokio context as wasmtime_wasi may
|
||||
// call into tokio, accessing its runtime handle when we trigger the `engine.increment_epoch()` above.
|
||||
let task = Arc::new(gpui_tokio::Tokio::spawn(cx, extension_task)?);
|
||||
|
||||
Ok(WasmExtension {
|
||||
|
||||
@@ -990,6 +990,9 @@ impl ExtensionImports for WasmState {
|
||||
command: None,
|
||||
settings: Some(settings),
|
||||
})?),
|
||||
project::project_settings::ContextServerSettings::Http { .. } => {
|
||||
bail!("remote context server settings not supported in 0.6.0")
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
||||
@@ -60,27 +60,27 @@ pub fn register_editor(editor: &mut Editor, buffer: Entity<MultiBuffer>, cx: &mu
|
||||
buffer_added(editor, buffer, cx);
|
||||
}
|
||||
|
||||
// cx.subscribe(&cx.entity(), |editor, _, event, cx| match event {
|
||||
// EditorEvent::ExcerptsAdded { buffer, .. } => buffer_added(editor, buffer.clone(), cx),
|
||||
// EditorEvent::ExcerptsExpanded { ids } => {
|
||||
// let multibuffer = editor.buffer().read(cx).snapshot(cx);
|
||||
// for excerpt_id in ids {
|
||||
// let Some(buffer) = multibuffer.buffer_for_excerpt(*excerpt_id) else {
|
||||
// continue;
|
||||
// };
|
||||
// let addon = editor.addon::<ConflictAddon>().unwrap();
|
||||
// let Some(conflict_set) = addon.conflict_set(buffer.remote_id()).clone() else {
|
||||
// return;
|
||||
// };
|
||||
// excerpt_for_buffer_updated(editor, conflict_set, cx);
|
||||
// }
|
||||
// }
|
||||
// EditorEvent::ExcerptsRemoved {
|
||||
// removed_buffer_ids, ..
|
||||
// } => buffers_removed(editor, removed_buffer_ids, cx),
|
||||
// _ => {}
|
||||
// })
|
||||
// .detach();
|
||||
cx.subscribe(&cx.entity(), |editor, _, event, cx| match event {
|
||||
EditorEvent::ExcerptsAdded { buffer, .. } => buffer_added(editor, buffer.clone(), cx),
|
||||
EditorEvent::ExcerptsExpanded { ids } => {
|
||||
let multibuffer = editor.buffer().read(cx).snapshot(cx);
|
||||
for excerpt_id in ids {
|
||||
let Some(buffer) = multibuffer.buffer_for_excerpt(*excerpt_id) else {
|
||||
continue;
|
||||
};
|
||||
let addon = editor.addon::<ConflictAddon>().unwrap();
|
||||
let Some(conflict_set) = addon.conflict_set(buffer.remote_id()).clone() else {
|
||||
return;
|
||||
};
|
||||
excerpt_for_buffer_updated(editor, conflict_set, cx);
|
||||
}
|
||||
}
|
||||
EditorEvent::ExcerptsRemoved {
|
||||
removed_buffer_ids, ..
|
||||
} => buffers_removed(editor, removed_buffer_ids, cx),
|
||||
_ => {}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn excerpt_for_buffer_updated(
|
||||
|
||||
@@ -30,11 +30,10 @@ use git::{
|
||||
TrashUntrackedFiles, UnstageAll,
|
||||
};
|
||||
use gpui::{
|
||||
Action, AppContext, AsyncApp, AsyncWindowContext, ClickEvent, Corner, DismissEvent, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior,
|
||||
ListSizingBehavior, MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy,
|
||||
Subscription, Task, UniformListScrollHandle, WeakEntity, actions, anchored, deferred,
|
||||
uniform_list,
|
||||
Action, AsyncApp, AsyncWindowContext, ClickEvent, Corner, DismissEvent, Entity, EventEmitter,
|
||||
FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior, ListSizingBehavior,
|
||||
MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy, Subscription, Task,
|
||||
UniformListScrollHandle, WeakEntity, actions, anchored, deferred, uniform_list,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use language::{Buffer, File};
|
||||
@@ -312,9 +311,6 @@ pub struct GitPanel {
|
||||
bulk_staging: Option<BulkStaging>,
|
||||
stash_entries: GitStash,
|
||||
_settings_subscription: Subscription,
|
||||
/// On clicking an entry in a the git_panel this will
|
||||
/// trigger loading it
|
||||
open_diff_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
@@ -475,7 +471,6 @@ impl GitPanel {
|
||||
bulk_staging: None,
|
||||
stash_entries: Default::default(),
|
||||
_settings_subscription,
|
||||
open_diff_task: None,
|
||||
};
|
||||
|
||||
this.schedule_update(window, cx);
|
||||
@@ -755,23 +750,11 @@ impl GitPanel {
|
||||
|
||||
fn open_diff(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
maybe!({
|
||||
let entry = self
|
||||
.entries
|
||||
.get(self.selected_entry?)?
|
||||
.status_entry()?
|
||||
.clone();
|
||||
let entry = self.entries.get(self.selected_entry?)?.status_entry()?;
|
||||
let workspace = self.workspace.upgrade()?;
|
||||
let git_repo = self.active_repository.as_ref()?.clone();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
let git_repo = self.active_repository.as_ref()?;
|
||||
|
||||
// let panel = panel.upgrade().unwrap(); // TODO FIXME
|
||||
// cx.read_entity(&panel, |panel, cx| {
|
||||
// panel
|
||||
// })
|
||||
// .unwrap(); // TODO FIXME
|
||||
|
||||
let project_diff = if let Some(project_diff) =
|
||||
workspace.read(cx).active_item_as::<ProjectDiff>(cx)
|
||||
if let Some(project_diff) = workspace.read(cx).active_item_as::<ProjectDiff>(cx)
|
||||
&& let Some(project_path) = project_diff.read(cx).active_path(cx)
|
||||
&& Some(&entry.repo_path)
|
||||
== git_repo
|
||||
@@ -781,21 +764,16 @@ impl GitPanel {
|
||||
{
|
||||
project_diff.focus_handle(cx).focus(window);
|
||||
project_diff.update(cx, |project_diff, cx| project_diff.autoscroll(cx));
|
||||
project_diff
|
||||
} else {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
ProjectDiff::deploy_at(workspace, Some(entry.clone()), window, cx)
|
||||
})
|
||||
return None;
|
||||
};
|
||||
focus_handle.focus(window); // TODO: should we focus before the file is loaded or wait for that?
|
||||
|
||||
let project_diff = project_diff.downgrade();
|
||||
// TODO use the fancy new thing
|
||||
self.open_diff_task = Some(cx.spawn_in(window, async move |_, cx| {
|
||||
ProjectDiff::refresh_one(project_diff, entry.repo_path, entry.status, cx)
|
||||
.await
|
||||
.unwrap(); // TODO FIXME
|
||||
}));
|
||||
self.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
ProjectDiff::deploy_at(workspace, Some(entry.clone()), window, cx);
|
||||
})
|
||||
.ok();
|
||||
self.focus_handle.focus(window);
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
@@ -3882,7 +3860,6 @@ impl GitPanel {
|
||||
})
|
||||
}
|
||||
|
||||
// context menu
|
||||
fn deploy_entry_context_menu(
|
||||
&mut self,
|
||||
position: Point<Pixels>,
|
||||
@@ -4108,7 +4085,6 @@ impl GitPanel {
|
||||
this.selected_entry = Some(ix);
|
||||
cx.notify();
|
||||
if event.modifiers().secondary() {
|
||||
// the click handler
|
||||
this.open_file(&Default::default(), window, cx)
|
||||
} else {
|
||||
this.open_diff(&Default::default(), window, cx);
|
||||
|
||||
@@ -7,21 +7,19 @@ use crate::{
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use buffer_diff::{BufferDiff, DiffHunkSecondaryStatus};
|
||||
use collections::{HashMap, HashSet};
|
||||
use db::smol::stream::StreamExt;
|
||||
use editor::{
|
||||
Addon, Editor, EditorEvent, SelectionEffects,
|
||||
actions::{GoToHunk, GoToPreviousHunk},
|
||||
multibuffer_context_lines,
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use futures::stream::FuturesUnordered;
|
||||
use git::{
|
||||
Commit, StageAll, StageAndNext, ToggleStaged, UnstageAll, UnstageAndNext,
|
||||
repository::{Branch, RepoPath, Upstream, UpstreamTracking, UpstreamTrackingStatus},
|
||||
status::FileStatus,
|
||||
};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, AppContext, AsyncWindowContext, Entity, EventEmitter,
|
||||
Action, AnyElement, AnyView, App, AppContext as _, AsyncWindowContext, Entity, EventEmitter,
|
||||
FocusHandle, Focusable, Render, Subscription, Task, WeakEntity, actions,
|
||||
};
|
||||
use language::{Anchor, Buffer, Capability, OffsetRangeExt};
|
||||
@@ -29,21 +27,17 @@ use multi_buffer::{MultiBuffer, PathKey};
|
||||
use project::{
|
||||
Project, ProjectPath,
|
||||
git_store::{
|
||||
self, Repository, StatusEntry,
|
||||
Repository,
|
||||
branch_diff::{self, BranchDiffEvent, DiffBase},
|
||||
},
|
||||
};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{
|
||||
any::{Any, TypeId},
|
||||
collections::VecDeque,
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{ops::Range, time::Instant};
|
||||
|
||||
use std::any::{Any, TypeId};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use theme::ActiveTheme;
|
||||
use ui::{KeyBinding, Tooltip, prelude::*, vertical_divider};
|
||||
use util::{ResultExt, rel_path::RelPath};
|
||||
use util::{ResultExt as _, rel_path::RelPath};
|
||||
use workspace::{
|
||||
CloseActiveItem, ItemNavHistory, SerializableItem, ToolbarItemEvent, ToolbarItemLocation,
|
||||
ToolbarItemView, Workspace,
|
||||
@@ -52,8 +46,6 @@ use workspace::{
|
||||
searchable::SearchableItemHandle,
|
||||
};
|
||||
|
||||
mod diff_loader;
|
||||
|
||||
actions!(
|
||||
git,
|
||||
[
|
||||
@@ -100,7 +92,7 @@ impl ProjectDiff {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
Self::deploy_at(workspace, None, window, cx);
|
||||
Self::deploy_at(workspace, None, window, cx)
|
||||
}
|
||||
|
||||
fn deploy_branch_diff(
|
||||
@@ -142,7 +134,7 @@ impl ProjectDiff {
|
||||
entry: Option<GitStatusEntry>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) -> Entity<ProjectDiff> {
|
||||
) {
|
||||
telemetry::event!(
|
||||
"Git Diff Opened",
|
||||
source = if entry.is_some() {
|
||||
@@ -174,8 +166,7 @@ impl ProjectDiff {
|
||||
project_diff.update(cx, |project_diff, cx| {
|
||||
project_diff.move_to_entry(entry, window, cx);
|
||||
})
|
||||
};
|
||||
project_diff
|
||||
}
|
||||
}
|
||||
|
||||
pub fn autoscroll(&self, cx: &mut Context<Self>) {
|
||||
@@ -276,23 +267,15 @@ impl ProjectDiff {
|
||||
cx.subscribe_in(&editor, window, Self::handle_editor_event)
|
||||
.detach();
|
||||
|
||||
let loader = diff_loader::start_loader(cx.entity(), window, cx);
|
||||
|
||||
let branch_diff_subscription = cx.subscribe_in(
|
||||
&branch_diff,
|
||||
window,
|
||||
move |this, _git_store, event, window, cx| match event {
|
||||
BranchDiffEvent::FileListChanged => {
|
||||
// TODO this does not account for size of paths
|
||||
// maybe a quick fs metadata could get us info on that?
|
||||
// would make number of paths async but thats fine here
|
||||
// let entries = this.first_n_entries(cx, 100);
|
||||
loader.update_file_list();
|
||||
// let
|
||||
// this._task = window.spawn(cx, {
|
||||
// let this = cx.weak_entity();
|
||||
// async |cx| Self::refresh(this, entries, cx).await
|
||||
// })
|
||||
this._task = window.spawn(cx, {
|
||||
let this = cx.weak_entity();
|
||||
async |cx| Self::refresh(this, cx).await
|
||||
})
|
||||
}
|
||||
},
|
||||
);
|
||||
@@ -307,32 +290,22 @@ impl ProjectDiff {
|
||||
if is_sort_by_path != was_sort_by_path
|
||||
|| is_collapse_untracked_diff != was_collapse_untracked_diff
|
||||
{
|
||||
// no idea why we need to do anything here
|
||||
// probably should sort the multibuffer instead of reparsing
|
||||
// everything though!!!
|
||||
todo!("resort multibuffer entries");
|
||||
todo!("assert the entries in the list did not change")
|
||||
// this._task = {
|
||||
// window.spawn(cx, {
|
||||
// let this = cx.weak_entity();
|
||||
// async |cx| Self::refresh(this, cx).await
|
||||
// })
|
||||
// }
|
||||
this._task = {
|
||||
window.spawn(cx, {
|
||||
let this = cx.weak_entity();
|
||||
async |cx| Self::refresh(this, cx).await
|
||||
})
|
||||
}
|
||||
}
|
||||
was_sort_by_path = is_sort_by_path;
|
||||
was_collapse_untracked_diff = is_collapse_untracked_diff;
|
||||
})
|
||||
.detach();
|
||||
|
||||
// let task = window.spawn(cx, {
|
||||
// let this = cx.weak_entity();
|
||||
// async |cx| {
|
||||
// let entries = this
|
||||
// .read_with(cx, |project_diff, cx| project_diff.first_n_entries(cx, 100))
|
||||
// .unwrap();
|
||||
// Self::refresh(this, entries, cx).await
|
||||
// }
|
||||
// });
|
||||
let task = window.spawn(cx, {
|
||||
let this = cx.weak_entity();
|
||||
async |cx| Self::refresh(this, cx).await
|
||||
});
|
||||
|
||||
Self {
|
||||
project,
|
||||
@@ -498,11 +471,10 @@ impl ProjectDiff {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let subscription = cx.subscribe_in(&diff, window, move |this, _, _, window, cx| {
|
||||
// TODO fix this
|
||||
// this._task = window.spawn(cx, {
|
||||
// let this = cx.weak_entity();
|
||||
// async |cx| Self::refresh(this, cx).await
|
||||
// })
|
||||
this._task = window.spawn(cx, {
|
||||
let this = cx.weak_entity();
|
||||
async |cx| Self::refresh(this, cx).await
|
||||
})
|
||||
});
|
||||
self.buffer_diff_subscriptions
|
||||
.insert(path_key.path.clone(), (diff.clone(), subscription));
|
||||
@@ -578,221 +550,51 @@ impl ProjectDiff {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn all_entries(&self, cx: &App) -> Vec<StatusEntry> {
|
||||
let Some(ref repo) = self.branch_diff.read(cx).repo else {
|
||||
return Vec::new();
|
||||
};
|
||||
repo.read(cx).cached_status().collect()
|
||||
}
|
||||
|
||||
pub fn entries(&self, cx: &App) -> Option<impl Iterator<Item = StatusEntry>> {
|
||||
Some(
|
||||
self.branch_diff
|
||||
.read(cx)
|
||||
.repo
|
||||
.as_ref()?
|
||||
.read(cx)
|
||||
.cached_status(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn first_n_entries(&self, cx: &App, n: usize) -> VecDeque<StatusEntry> {
|
||||
let Some(ref repo) = self.branch_diff.read(cx).repo else {
|
||||
return VecDeque::new();
|
||||
};
|
||||
repo.read(cx).cached_status().take(n).collect()
|
||||
}
|
||||
|
||||
pub async fn refresh_one(
|
||||
this: WeakEntity<Self>,
|
||||
repo_path: RepoPath,
|
||||
status: FileStatus,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Result<()> {
|
||||
use git_store::branch_diff::BranchDiff;
|
||||
|
||||
let Some(this) = this.upgrade() else {
|
||||
return Ok(());
|
||||
};
|
||||
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
|
||||
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
|
||||
|
||||
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
|
||||
return Ok(());
|
||||
};
|
||||
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
|
||||
|
||||
let mut previous_paths =
|
||||
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
|
||||
|
||||
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
|
||||
branch_diff
|
||||
.tree_diff
|
||||
.as_ref()
|
||||
.and_then(|t| t.entries.get(&repo_path))
|
||||
.cloned()
|
||||
})?;
|
||||
|
||||
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
|
||||
bd.merge_statuses(Some(status), tree_diff_status.as_ref())
|
||||
})?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
if !status.has_changes() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
|
||||
repo.repo_path_to_project_path(&repo_path, cx)
|
||||
})?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let sort_prefix =
|
||||
cx.read_entity(&repo, |repo, cx| sort_prefix(repo, &repo_path, status, cx))?;
|
||||
|
||||
let path_key = PathKey::with_sort_prefix(sort_prefix, repo_path.into_arc());
|
||||
previous_paths.remove(&path_key);
|
||||
|
||||
let repo = repo.clone();
|
||||
let Some((buffer, diff)) = BranchDiff::load_buffer(
|
||||
tree_diff_status,
|
||||
project_path,
|
||||
repo,
|
||||
project.downgrade(),
|
||||
&mut cx.to_app(),
|
||||
)
|
||||
.await
|
||||
.log_err() else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
cx.update(|window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.register_buffer(path_key, status, buffer, diff, window, cx)
|
||||
pub async fn refresh(this: WeakEntity<Self>, cx: &mut AsyncWindowContext) -> Result<()> {
|
||||
let mut path_keys = Vec::new();
|
||||
let buffers_to_load = this.update(cx, |this, cx| {
|
||||
let (repo, buffers_to_load) = this.branch_diff.update(cx, |branch_diff, cx| {
|
||||
let load_buffers = branch_diff.load_buffers(cx);
|
||||
(branch_diff.repo().cloned(), load_buffers)
|
||||
});
|
||||
})?;
|
||||
let mut previous_paths = this.multibuffer.read(cx).paths().collect::<HashSet<_>>();
|
||||
|
||||
// TODO LL clear multibuff on open?
|
||||
// // remove anything not part of the diff in the multibuffer
|
||||
// this.update(cx, |this, cx| {
|
||||
// multibuffer.update(cx, |multibuffer, cx| {
|
||||
// for path in previous_paths {
|
||||
// this.buffer_diff_subscriptions.remove(&path.path);
|
||||
// multibuffer.remove_excerpts_for_path(path, cx);
|
||||
// }
|
||||
// });
|
||||
// })?;
|
||||
if let Some(repo) = repo {
|
||||
let repo = repo.read(cx);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh(
|
||||
this: WeakEntity<Self>,
|
||||
cached_status: Vec<StatusEntry>,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Result<()> {
|
||||
dbg!("refreshing all");
|
||||
use git_store::branch_diff::BranchDiff;
|
||||
let Some(this) = this.upgrade() else {
|
||||
return Ok(());
|
||||
};
|
||||
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
|
||||
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
|
||||
|
||||
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
|
||||
return Ok(());
|
||||
};
|
||||
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
|
||||
|
||||
let mut previous_paths =
|
||||
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
|
||||
|
||||
// Idea: on click in git panel prioritize task for that file in some way ...
|
||||
// could have a hashmap of futures here
|
||||
// - needs to prioritize *some* background tasks over others
|
||||
// -
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
let mut seen = HashSet::default();
|
||||
for entry in cached_status {
|
||||
seen.insert(entry.repo_path.clone());
|
||||
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
|
||||
branch_diff
|
||||
.tree_diff
|
||||
.as_ref()
|
||||
.and_then(|t| t.entries.get(&entry.repo_path))
|
||||
.cloned()
|
||||
})?;
|
||||
|
||||
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
|
||||
bd.merge_statuses(Some(entry.status), tree_diff_status.as_ref())
|
||||
})?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
if !status.has_changes() {
|
||||
continue;
|
||||
path_keys = Vec::with_capacity(buffers_to_load.len());
|
||||
for entry in buffers_to_load.iter() {
|
||||
let sort_prefix = sort_prefix(&repo, &entry.repo_path, entry.file_status, cx);
|
||||
let path_key =
|
||||
PathKey::with_sort_prefix(sort_prefix, entry.repo_path.as_ref().clone());
|
||||
previous_paths.remove(&path_key);
|
||||
path_keys.push(path_key)
|
||||
}
|
||||
}
|
||||
|
||||
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
|
||||
repo.repo_path_to_project_path(&entry.repo_path, cx)
|
||||
})?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let sort_prefix = cx.read_entity(&repo, |repo, cx| {
|
||||
sort_prefix(repo, &entry.repo_path, entry.status, cx)
|
||||
})?;
|
||||
|
||||
let path_key = PathKey::with_sort_prefix(sort_prefix, entry.repo_path.into_arc());
|
||||
previous_paths.remove(&path_key);
|
||||
|
||||
let repo = repo.clone();
|
||||
let project = project.downgrade();
|
||||
let task = cx.spawn(async move |cx| {
|
||||
let res = BranchDiff::load_buffer(
|
||||
tree_diff_status,
|
||||
project_path,
|
||||
repo,
|
||||
project,
|
||||
&mut cx.to_app(),
|
||||
)
|
||||
.await;
|
||||
(res, path_key, entry.status)
|
||||
});
|
||||
|
||||
tasks.push(task)
|
||||
}
|
||||
|
||||
// remove anything not part of the diff in the multibuffer
|
||||
this.update(cx, |this, cx| {
|
||||
multibuffer.update(cx, |multibuffer, cx| {
|
||||
this.multibuffer.update(cx, |multibuffer, cx| {
|
||||
for path in previous_paths {
|
||||
this.buffer_diff_subscriptions.remove(&path.path);
|
||||
multibuffer.remove_excerpts_for_path(path, cx);
|
||||
}
|
||||
});
|
||||
buffers_to_load
|
||||
})?;
|
||||
|
||||
// add the new buffers as they are parsed
|
||||
let mut last_notify = Instant::now();
|
||||
while let Some((res, path_key, file_status)) = tasks.next().await {
|
||||
if let Some((buffer, diff)) = res.log_err() {
|
||||
for (entry, path_key) in buffers_to_load.into_iter().zip(path_keys.into_iter()) {
|
||||
if let Some((buffer, diff)) = entry.load.await.log_err() {
|
||||
cx.update(|window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.register_buffer(path_key, file_status, buffer, diff, window, cx)
|
||||
});
|
||||
this.register_buffer(path_key, entry.file_status, buffer, diff, window, cx)
|
||||
})
|
||||
.ok();
|
||||
})?;
|
||||
}
|
||||
|
||||
if last_notify.elapsed().as_millis() > 100 {
|
||||
cx.update_entity(&this, |_, cx| cx.notify())?;
|
||||
last_notify = Instant::now();
|
||||
}
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.pending_scroll.take();
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
//! Task which updates the project diff multibuffer without putting too much
|
||||
//! pressure on the frontend executor. It prioritizes loading the area around the user
|
||||
|
||||
use collections::HashSet;
|
||||
use db::smol::stream::StreamExt;
|
||||
use futures::channel::mpsc;
|
||||
use gpui::{AppContext, AsyncWindowContext, Entity, Task, WeakEntity};
|
||||
use project::git_store::StatusEntry;
|
||||
use ui::{App, Window};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{git_panel::GitStatusEntry, project_diff::ProjectDiff};
|
||||
|
||||
enum Update {
|
||||
Position(usize),
|
||||
NewFile(StatusEntry),
|
||||
ListChanged,
|
||||
// should not need to handle re-ordering (sorting) here.
|
||||
// something to handle scroll? or should that live in the project diff?
|
||||
}
|
||||
|
||||
struct LoaderHandle {
|
||||
task: Task<Option<()>>,
|
||||
sender: mpsc::UnboundedSender<Update>,
|
||||
}
|
||||
|
||||
impl LoaderHandle {
|
||||
pub fn update_file_list(&self) {
|
||||
let _ = self
|
||||
.sender
|
||||
.unbounded_send(Update::ListChanged)
|
||||
.log_err();
|
||||
|
||||
}
|
||||
pub fn update_pos(&self, pos: usize) {
|
||||
let _ = self
|
||||
.sender
|
||||
.unbounded_send(Update::Position((pos)))
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_loader(project_diff: Entity<ProjectDiff>, window: &Window, cx: &App) -> LoaderHandle {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
|
||||
let task = window.spawn(cx, async move |cx| {
|
||||
load(rx, project_diff.downgrade(), cx).await
|
||||
});
|
||||
LoaderHandle { task, sender: tx }
|
||||
}
|
||||
|
||||
enum DiffEntry {
|
||||
Loading(GitStatusEntry),
|
||||
Loaded(GitStatusEntry),
|
||||
Queued(GitStatusEntry),
|
||||
}
|
||||
|
||||
impl DiffEntry {
|
||||
fn queued(&self) -> bool {
|
||||
matches!(self, DiffEntry::Queued(_))
|
||||
}
|
||||
}
|
||||
|
||||
async fn load(
|
||||
rx: mpsc::UnboundedReceiver<Update>,
|
||||
project_diff: WeakEntity<ProjectDiff>,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Option<()> {
|
||||
// let initial_entries = cx.read_entity(&cx.entity(), |project_diff, cx| project_diff.first_n_entries(cx, 100));
|
||||
// let loading = to_load.drain(..100).map(|| refresh_one)
|
||||
let mut existing = Vec::new();
|
||||
|
||||
loop {
|
||||
let update = rx.next().await?;
|
||||
match update {
|
||||
Update::Position(pos) => {
|
||||
if existing.get(pos).is_some_and(|diff| diff.queued()) {
|
||||
todo!("append to future unordered, also load in the bit
|
||||
around (maybe with a short sleep ahead so we get some sense
|
||||
of 'priority'")
|
||||
}
|
||||
// drop whatever is loading so we get to the new bit earlier
|
||||
}
|
||||
Update::NewFile(status_entry) => todo!(),
|
||||
Update::ListChanged => {
|
||||
let (added, removed) = project_diff
|
||||
.upgrade()?
|
||||
.read_with(cx, |diff, cx| diff_current_list(&existing, diff, cx))
|
||||
.ok()?;
|
||||
}
|
||||
}
|
||||
|
||||
// wait for Update OR Load done
|
||||
// -> Immediately spawn update
|
||||
// OR
|
||||
// -> spawn next group
|
||||
}
|
||||
}
|
||||
|
||||
// could be new list
|
||||
fn diff_current_list(
|
||||
existing_entries: &[GitStatusEntry],
|
||||
project_diff: &ProjectDiff,
|
||||
cx: &App,
|
||||
) -> (Vec<(usize, GitStatusEntry)>, Vec<usize>) {
|
||||
let Some(new_entries) = project_diff.entries(cx) else {
|
||||
return (Vec::new(), Vec::new());
|
||||
};
|
||||
|
||||
let existing_entries = existing_entries.iter().enumerate();
|
||||
for entry in new_entries {
|
||||
let Some((idx, existing)) = existing_entries.next() else {
|
||||
todo!();
|
||||
};
|
||||
|
||||
if existing == entry {
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
// let initial_entries = cx.read_entity(&cx.entity(), |project_diff, cx| project_diff.first_n_entries(cx, 100));
|
||||
// let loading = to_load.drain(..100).map(|| refresh_one)
|
||||
}
|
||||
|
||||
// // remove anything not part of the diff in the multibuffer
|
||||
// fn remove_anything_not_being_loaded() {
|
||||
// this.update(cx, |this, cx| {
|
||||
// multibuffer.update(cx, |multibuffer, cx| {
|
||||
// for path in previous_paths {
|
||||
// this.buffer_diff_subscriptions.remove(&path.path);
|
||||
// multibuffer.remove_excerpts_for_path(path, cx);
|
||||
// }
|
||||
// });
|
||||
// })?;
|
||||
// }
|
||||
|
||||
pub async fn refresh_group(
|
||||
this: WeakEntity<ProjectDiff>,
|
||||
cached_status: Vec<StatusEntry>,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> anyhow::Result<()> {
|
||||
dbg!("refreshing all");
|
||||
use project::git_store::branch_diff::BranchDiff;
|
||||
let Some(this) = this.upgrade() else {
|
||||
return Ok(());
|
||||
};
|
||||
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
|
||||
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
|
||||
|
||||
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
|
||||
return Ok(());
|
||||
};
|
||||
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
|
||||
|
||||
let mut previous_paths =
|
||||
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
|
||||
|
||||
// Idea: on click in git panel prioritize task for that file in some way ...
|
||||
// could have a hashmap of futures here
|
||||
// - needs to prioritize *some* background tasks over others
|
||||
// -
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
let mut seen = HashSet::default();
|
||||
for entry in cached_status {
|
||||
seen.insert(entry.repo_path.clone());
|
||||
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
|
||||
branch_diff
|
||||
.tree_diff
|
||||
.as_ref()
|
||||
.and_then(|t| t.entries.get(&entry.repo_path))
|
||||
.cloned()
|
||||
})?;
|
||||
|
||||
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
|
||||
bd.merge_statuses(Some(entry.status), tree_diff_status.as_ref())
|
||||
})?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
if !status.has_changes() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
|
||||
repo.repo_path_to_project_path(&entry.repo_path, cx)
|
||||
})?
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let sort_prefix = cx.read_entity(&repo, |repo, cx| {
|
||||
sort_prefix(repo, &entry.repo_path, entry.status, cx)
|
||||
})?;
|
||||
|
||||
let path_key = PathKey::with_sort_prefix(sort_prefix, entry.repo_path.into_arc());
|
||||
previous_paths.remove(&path_key);
|
||||
|
||||
let repo = repo.clone();
|
||||
let project = project.downgrade();
|
||||
let task = cx.spawn(async move |cx| {
|
||||
let res = BranchDiff::load_buffer(
|
||||
tree_diff_status,
|
||||
project_path,
|
||||
repo,
|
||||
project,
|
||||
&mut cx.to_app(),
|
||||
)
|
||||
.await;
|
||||
(res, path_key, entry.status)
|
||||
});
|
||||
|
||||
tasks.push(task)
|
||||
}
|
||||
|
||||
// remove anything not part of the diff in the multibuffer
|
||||
this.update(cx, |this, cx| {
|
||||
multibuffer.update(cx, |multibuffer, cx| {
|
||||
for path in previous_paths {
|
||||
this.buffer_diff_subscriptions.remove(&path.path);
|
||||
multibuffer.remove_excerpts_for_path(path, cx);
|
||||
}
|
||||
});
|
||||
})?;
|
||||
|
||||
// add the new buffers as they are parsed
|
||||
let mut last_notify = Instant::now();
|
||||
while let Some((res, path_key, file_status)) = tasks.next().await {
|
||||
if let Some((buffer, diff)) = res.log_err() {
|
||||
cx.update(|window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.register_buffer(path_key, file_status, buffer, diff, window, cx)
|
||||
});
|
||||
})?;
|
||||
}
|
||||
|
||||
if last_notify.elapsed().as_millis() > 100 {
|
||||
cx.update_entity(&this, |_, cx| cx.notify())?;
|
||||
last_notify = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn sort_or_collapse_changed() {
|
||||
todo!()
|
||||
}
|
||||
@@ -229,6 +229,10 @@ pub struct GenerativeContentBlob {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionCallPart {
|
||||
pub function_call: FunctionCall,
|
||||
/// Thought signature returned by the model for function calls.
|
||||
/// Only present on the first function call in parallel call scenarios.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thought_signature: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -517,6 +521,8 @@ pub enum Model {
|
||||
alias = "gemini-2.5-pro-preview-06-05"
|
||||
)]
|
||||
Gemini25Pro,
|
||||
#[serde(rename = "gemini-3-pro-preview")]
|
||||
Gemini3ProPreview,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
@@ -543,6 +549,7 @@ impl Model {
|
||||
Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
|
||||
Self::Gemini25Flash => "gemini-2.5-flash",
|
||||
Self::Gemini25Pro => "gemini-2.5-pro",
|
||||
Self::Gemini3ProPreview => "gemini-3-pro-preview",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
@@ -556,6 +563,7 @@ impl Model {
|
||||
Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
|
||||
Self::Gemini25Flash => "gemini-2.5-flash",
|
||||
Self::Gemini25Pro => "gemini-2.5-pro",
|
||||
Self::Gemini3ProPreview => "gemini-3-pro-preview",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
@@ -570,6 +578,7 @@ impl Model {
|
||||
Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
|
||||
Self::Gemini25Flash => "Gemini 2.5 Flash",
|
||||
Self::Gemini25Pro => "Gemini 2.5 Pro",
|
||||
Self::Gemini3ProPreview => "Gemini 3 Pro Preview",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
@@ -586,6 +595,7 @@ impl Model {
|
||||
Self::Gemini25FlashLitePreview => 1_000_000,
|
||||
Self::Gemini25Flash => 1_048_576,
|
||||
Self::Gemini25Pro => 1_048_576,
|
||||
Self::Gemini3ProPreview => 1_048_576,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
@@ -600,6 +610,7 @@ impl Model {
|
||||
Model::Gemini25FlashLitePreview => Some(64_000),
|
||||
Model::Gemini25Flash => Some(65_536),
|
||||
Model::Gemini25Pro => Some(65_536),
|
||||
Model::Gemini3ProPreview => Some(65_536),
|
||||
Model::Custom { .. } => None,
|
||||
}
|
||||
}
|
||||
@@ -619,7 +630,10 @@ impl Model {
|
||||
| Self::Gemini15Flash
|
||||
| Self::Gemini20FlashLite
|
||||
| Self::Gemini20Flash => GoogleModelMode::Default,
|
||||
Self::Gemini25FlashLitePreview | Self::Gemini25Flash | Self::Gemini25Pro => {
|
||||
Self::Gemini25FlashLitePreview
|
||||
| Self::Gemini25Flash
|
||||
| Self::Gemini25Pro
|
||||
| Self::Gemini3ProPreview => {
|
||||
GoogleModelMode::Thinking {
|
||||
// By default these models are set to "auto", so we preserve that behavior
|
||||
// but indicate they are capable of thinking mode
|
||||
@@ -636,3 +650,109 @@ impl std::fmt::Display for Model {
|
||||
write!(f, "{}", self.id())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_with_signature_serializes_correctly() {
|
||||
let part = FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some("test_signature".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&part).unwrap();
|
||||
|
||||
assert_eq!(serialized["functionCall"]["name"], "test_function");
|
||||
assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
|
||||
assert_eq!(serialized["thoughtSignature"], "test_signature");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_without_signature_omits_field() {
|
||||
let part = FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&part).unwrap();
|
||||
|
||||
assert_eq!(serialized["functionCall"]["name"], "test_function");
|
||||
assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
|
||||
// thoughtSignature field should not be present when None
|
||||
assert!(serialized.get("thoughtSignature").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_deserializes_with_signature() {
|
||||
let json = json!({
|
||||
"functionCall": {
|
||||
"name": "test_function",
|
||||
"args": {"arg": "value"}
|
||||
},
|
||||
"thoughtSignature": "test_signature"
|
||||
});
|
||||
|
||||
let part: FunctionCallPart = serde_json::from_value(json).unwrap();
|
||||
|
||||
assert_eq!(part.function_call.name, "test_function");
|
||||
assert_eq!(part.thought_signature, Some("test_signature".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_deserializes_without_signature() {
|
||||
let json = json!({
|
||||
"functionCall": {
|
||||
"name": "test_function",
|
||||
"args": {"arg": "value"}
|
||||
}
|
||||
});
|
||||
|
||||
let part: FunctionCallPart = serde_json::from_value(json).unwrap();
|
||||
|
||||
assert_eq!(part.function_call.name, "test_function");
|
||||
assert_eq!(part.thought_signature, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_round_trip() {
|
||||
let original = FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value", "nested": {"key": "val"}}),
|
||||
},
|
||||
thought_signature: Some("round_trip_signature".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&original).unwrap();
|
||||
let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.function_call.name, original.function_call.name);
|
||||
assert_eq!(deserialized.function_call.args, original.function_call.args);
|
||||
assert_eq!(deserialized.thought_signature, original.thought_signature);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_part_with_empty_signature_serializes() {
|
||||
let part = FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some("".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&part).unwrap();
|
||||
|
||||
// Empty string should still be serialized (normalization happens at a higher level)
|
||||
assert_eq!(serialized["thoughtSignature"], "");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,13 +187,12 @@ font-kit = { git = "https://github.com/zed-industries/font-kit", rev = "11052312
|
||||
"source-fontconfig-dlopen",
|
||||
], optional = true }
|
||||
|
||||
calloop = { version = "0.14.3" }
|
||||
calloop = { version = "0.13.0" }
|
||||
filedescriptor = { version = "0.8.2", optional = true }
|
||||
open = { version = "5.2.0", optional = true }
|
||||
|
||||
|
||||
# Wayland
|
||||
calloop-wayland-source = { version = "0.4.1", optional = true }
|
||||
calloop-wayland-source = { version = "0.3.0", optional = true }
|
||||
wayland-backend = { version = "0.3.3", features = [
|
||||
"client_system",
|
||||
"dlopen",
|
||||
@@ -266,6 +265,7 @@ naga.workspace = true
|
||||
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.build-dependencies]
|
||||
naga.workspace = true
|
||||
|
||||
|
||||
[[example]]
|
||||
name = "hello_world"
|
||||
path = "examples/hello_world.rs"
|
||||
|
||||
@@ -310,11 +310,6 @@ impl AsyncWindowContext {
|
||||
.update(self, |_, window, cx| read(cx.global(), window, cx))
|
||||
}
|
||||
|
||||
/// Returns an `AsyncApp` by cloning the one used by Self
|
||||
pub fn to_app(&self) -> AsyncApp {
|
||||
self.app.clone()
|
||||
}
|
||||
|
||||
/// A convenience method for [`App::update_global`](BorrowAppContext::update_global).
|
||||
/// for updating the global state of the specified type.
|
||||
pub fn update_global<G, R>(
|
||||
|
||||
@@ -233,9 +233,6 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
/// Spawn the future returned by the given function.
|
||||
/// The function is provided a weak handle to the entity owned by this context and a context that can be held across await points.
|
||||
/// The returned task must be held or detached.
|
||||
///
|
||||
/// # Example
|
||||
/// `cx.spawn(async move |some_weak_entity, cx| ...)`
|
||||
#[track_caller]
|
||||
pub fn spawn<AsyncFn, R>(&self, f: AsyncFn) -> Task<R>
|
||||
where
|
||||
|
||||
@@ -1389,6 +1389,10 @@ pub enum WindowBackgroundAppearance {
|
||||
///
|
||||
/// Not always supported.
|
||||
Blurred,
|
||||
/// The Mica backdrop material, supported on Windows 11.
|
||||
MicaBackdrop,
|
||||
/// The Mica Alt backdrop material, supported on Windows 11.
|
||||
MicaAltBackdrop,
|
||||
}
|
||||
|
||||
/// The options that can be configured for a file dialog prompt
|
||||
|
||||
@@ -487,15 +487,12 @@ impl WaylandClient {
|
||||
|
||||
let (common, main_receiver) = LinuxCommon::new(event_loop.get_signal());
|
||||
|
||||
let handle = event_loop.handle(); // CHECK that wayland sources get higher prio
|
||||
let handle = event_loop.handle();
|
||||
handle
|
||||
// these are all tasks spawned on the foreground executor.
|
||||
// There is no concept of priority, they are all equal.
|
||||
.insert_source(main_receiver, {
|
||||
let handle = handle.clone();
|
||||
move |event, _, _: &mut WaylandClientStatePtr| {
|
||||
if let calloop::channel::Event::Msg(runnable) = event {
|
||||
// will only be called when the event loop has finished processing all pending events from the sources
|
||||
handle.insert_idle(|_| {
|
||||
let start = Instant::now();
|
||||
let mut timing = match runnable {
|
||||
@@ -653,7 +650,6 @@ impl WaylandClient {
|
||||
event_loop: Some(event_loop),
|
||||
}));
|
||||
|
||||
// MAGIC HERE IT IS
|
||||
WaylandSource::new(conn, event_queue)
|
||||
.insert(handle)
|
||||
.unwrap();
|
||||
@@ -1578,7 +1574,6 @@ fn linux_button_to_gpui(button: u32) -> Option<MouseButton> {
|
||||
})
|
||||
}
|
||||
|
||||
// how is this being called inside calloop
|
||||
impl Dispatch<wl_pointer::WlPointer, ()> for WaylandClientStatePtr {
|
||||
fn event(
|
||||
this: &mut Self,
|
||||
@@ -1669,7 +1664,7 @@ impl Dispatch<wl_pointer::WlPointer, ()> for WaylandClientStatePtr {
|
||||
modifiers: state.modifiers,
|
||||
});
|
||||
drop(state);
|
||||
window.handle_input(input); // How does this get into the event loop?
|
||||
window.handle_input(input);
|
||||
}
|
||||
}
|
||||
wl_pointer::Event::Button {
|
||||
|
||||
@@ -18,6 +18,7 @@ use smallvec::SmallVec;
|
||||
use windows::{
|
||||
Win32::{
|
||||
Foundation::*,
|
||||
Graphics::Dwm::*,
|
||||
Graphics::Gdi::*,
|
||||
System::{Com::*, LibraryLoader::*, Ole::*, SystemServices::*},
|
||||
UI::{Controls::*, HiDpi::*, Input::KeyboardAndMouse::*, Shell::*, WindowsAndMessaging::*},
|
||||
@@ -773,20 +774,26 @@ impl PlatformWindow for WindowsWindow {
|
||||
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
|
||||
let hwnd = self.0.hwnd;
|
||||
|
||||
// using Dwm APIs for Mica and MicaAlt backdrops.
|
||||
// others follow the set_window_composition_attribute approach
|
||||
match background_appearance {
|
||||
WindowBackgroundAppearance::Opaque => {
|
||||
// ACCENT_DISABLED
|
||||
set_window_composition_attribute(hwnd, None, 0);
|
||||
}
|
||||
WindowBackgroundAppearance::Transparent => {
|
||||
// Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background
|
||||
set_window_composition_attribute(hwnd, None, 2);
|
||||
}
|
||||
WindowBackgroundAppearance::Blurred => {
|
||||
// Enable acrylic blur
|
||||
// ACCENT_ENABLE_ACRYLICBLURBEHIND
|
||||
set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4);
|
||||
}
|
||||
WindowBackgroundAppearance::MicaBackdrop => {
|
||||
// DWMSBT_MAINWINDOW => MicaBase
|
||||
dwm_set_window_composition_attribute(hwnd, 2);
|
||||
}
|
||||
WindowBackgroundAppearance::MicaAltBackdrop => {
|
||||
// DWMSBT_TABBEDWINDOW => MicaAlt
|
||||
dwm_set_window_composition_attribute(hwnd, 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1330,9 +1337,34 @@ fn retrieve_window_placement(
|
||||
Ok(placement)
|
||||
}
|
||||
|
||||
fn dwm_set_window_composition_attribute(hwnd: HWND, backdrop_type: u32) {
|
||||
let mut version = unsafe { std::mem::zeroed() };
|
||||
let status = unsafe { windows::Wdk::System::SystemServices::RtlGetVersion(&mut version) };
|
||||
|
||||
// DWMWA_SYSTEMBACKDROP_TYPE is available only on version 22621 or later
|
||||
// using SetWindowCompositionAttributeType as a fallback
|
||||
if !status.is_ok() || version.dwBuildNumber < 22621 {
|
||||
return;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let result = DwmSetWindowAttribute(
|
||||
hwnd,
|
||||
DWMWA_SYSTEMBACKDROP_TYPE,
|
||||
&backdrop_type as *const _ as *const _,
|
||||
std::mem::size_of_val(&backdrop_type) as u32,
|
||||
);
|
||||
|
||||
if !result.is_ok() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32) {
|
||||
let mut version = unsafe { std::mem::zeroed() };
|
||||
let status = unsafe { windows::Wdk::System::SystemServices::RtlGetVersion(&mut version) };
|
||||
|
||||
if !status.is_ok() || version.dwBuildNumber < 17763 {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ tree-sitter-rust.workspace = true
|
||||
ui_input.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
vim.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
|
||||
@@ -1769,7 +1769,7 @@ impl Render for KeymapEditor {
|
||||
)
|
||||
.action(
|
||||
"Vim Bindings",
|
||||
vim::OpenDefaultKeymap.boxed_clone(),
|
||||
zed_actions::vim::OpenDefaultKeymap.boxed_clone(),
|
||||
)
|
||||
}))
|
||||
})
|
||||
|
||||
@@ -12,7 +12,7 @@ pub mod fake_provider;
|
||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
|
||||
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
|
||||
/// A completion event from a language model.
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub enum LanguageModelCompletionEvent {
|
||||
StatusUpdate(CompletionRequestStatus),
|
||||
Queued {
|
||||
position: usize,
|
||||
},
|
||||
Started,
|
||||
RequestUsage {
|
||||
amount: usize,
|
||||
limit: UsageLimit,
|
||||
},
|
||||
ToolUseLimitReached,
|
||||
Stop(StopReason),
|
||||
Text(String),
|
||||
Thinking {
|
||||
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
},
|
||||
UsageUpdate(TokenUsage),
|
||||
TokenUsage(TokenUsage),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionEvent {
|
||||
pub fn from_completion_request_status(
|
||||
status: CompletionRequestStatus,
|
||||
) -> Result<Self, LanguageModelCompletionError> {
|
||||
match status {
|
||||
CompletionRequestStatus::Queued { position } => {
|
||||
Ok(LanguageModelCompletionEvent::Queued { position })
|
||||
}
|
||||
CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
|
||||
CompletionRequestStatus::UsageUpdated { amount, limit } => {
|
||||
Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
|
||||
}
|
||||
CompletionRequestStatus::ToolUseLimitReached => {
|
||||
Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
|
||||
}
|
||||
CompletionRequestStatus::Failed {
|
||||
code,
|
||||
message,
|
||||
request_id: _,
|
||||
retry_after,
|
||||
} => Err(LanguageModelCompletionError::from_cloud_failure(
|
||||
code,
|
||||
message,
|
||||
retry_after.map(Duration::from_secs_f64),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("missing {provider} API key")]
|
||||
NoApiKey { provider: LanguageModelProviderName },
|
||||
#[error("{provider}'s API rate limit exceeded")]
|
||||
RateLimitExceeded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API servers are overloaded right now")]
|
||||
ServerOverloaded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("missing API key")]
|
||||
NoApiKey,
|
||||
#[error("API rate limit exceeded")]
|
||||
RateLimitExceeded { retry_after: Option<Duration> },
|
||||
#[error("API servers are overloaded right now")]
|
||||
ServerOverloaded { retry_after: Option<Duration> },
|
||||
#[error("API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError { message: String },
|
||||
#[error("{message}")]
|
||||
UpstreamProviderError {
|
||||
message: String,
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
#[error("HTTP response error from API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
// Client errors
|
||||
#[error("invalid request format to {provider}'s API: {message}")]
|
||||
BadRequestFormat {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("authentication error with {provider}'s API: {message}")]
|
||||
AuthenticationError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("Permission error with {provider}'s API: {message}")]
|
||||
PermissionError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("invalid request format to API: {message}")]
|
||||
BadRequestFormat { message: String },
|
||||
#[error("authentication error with API: {message}")]
|
||||
AuthenticationError { message: String },
|
||||
#[error("Permission error with API: {message}")]
|
||||
PermissionError { message: String },
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||
#[error("I/O error reading response from {provider}'s API")]
|
||||
ApiEndpointNotFound,
|
||||
#[error("I/O error reading response from API")]
|
||||
ApiReadResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: io::Error,
|
||||
},
|
||||
#[error("error serializing request to {provider} API")]
|
||||
#[error("error serializing request to API")]
|
||||
SerializeRequest {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("error building request body to {provider} API")]
|
||||
#[error("error building request body to API")]
|
||||
BuildRequestBody {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: http::Error,
|
||||
},
|
||||
#[error("error sending HTTP request to {provider} API")]
|
||||
#[error("error sending HTTP request to API")]
|
||||
HttpSend {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: anyhow::Error,
|
||||
},
|
||||
#[error("error deserializing {provider} API response")]
|
||||
#[error("error deserializing API response")]
|
||||
DeserializeResponse {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
fn display_format(&self, provider: LanguageModelProviderName) {
|
||||
// match self {
|
||||
// #[error("prompt too large for context window")]
|
||||
// PromptTooLarge { tokens: Option<u64> },
|
||||
// #[error("missing API key")]
|
||||
// NoApiKey,
|
||||
// #[error("API rate limit exceeded")]
|
||||
// RateLimitExceeded { retry_after: Option<Duration> },
|
||||
// #[error("API servers are overloaded right now")]
|
||||
// ServerOverloaded { retry_after: Option<Duration> },
|
||||
// #[error("API server reported an internal server error: {message}")]
|
||||
// ApiInternalServerError { message: String },
|
||||
// #[error("{message}")]
|
||||
// UpstreamProviderError {
|
||||
// message: String,
|
||||
// status: StatusCode,
|
||||
// retry_after: Option<Duration>,
|
||||
// },
|
||||
// #[error("HTTP response error from API: status {status_code} - {message:?}")]
|
||||
// HttpResponseError {
|
||||
// status_code: StatusCode,
|
||||
// message: String,
|
||||
// },
|
||||
|
||||
// // Client errors
|
||||
// #[error("invalid request format to API: {message}")]
|
||||
// BadRequestFormat { message: String },
|
||||
// #[error("authentication error with API: {message}")]
|
||||
// AuthenticationError { message: String },
|
||||
// #[error("Permission error with API: {message}")]
|
||||
// PermissionError { message: String },
|
||||
// #[error("language model provider API endpoint not found")]
|
||||
// ApiEndpointNotFound,
|
||||
// #[error("I/O error reading response from API")]
|
||||
// ApiReadResponseError {
|
||||
// #[source]
|
||||
// error: io::Error,
|
||||
// },
|
||||
// #[error("error serializing request to API")]
|
||||
// SerializeRequest {
|
||||
// #[source]
|
||||
// error: serde_json::Error,
|
||||
// },
|
||||
// #[error("error building request body to API")]
|
||||
// BuildRequestBody {
|
||||
// #[source]
|
||||
// error: http::Error,
|
||||
// },
|
||||
// #[error("error sending HTTP request to API")]
|
||||
// HttpSend {
|
||||
// #[source]
|
||||
// error: anyhow::Error,
|
||||
// },
|
||||
// #[error("error deserializing API response")]
|
||||
// DeserializeResponse {
|
||||
// #[source]
|
||||
// error: serde_json::Error,
|
||||
// },
|
||||
|
||||
// // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||
// #[error(transparent)]
|
||||
// Other(#[from] anyhow::Error),
|
||||
|
||||
// }
|
||||
}
|
||||
|
||||
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
||||
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
||||
let upstream_status = error_json
|
||||
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
|
||||
}
|
||||
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
|
||||
if let Some((upstream_status, inner_message)) =
|
||||
Self::parse_upstream_error_json(&message)
|
||||
{
|
||||
return Self::from_http_status(
|
||||
upstream_provider,
|
||||
upstream_status,
|
||||
inner_message,
|
||||
retry_after,
|
||||
);
|
||||
return Self::from_http_status(upstream_status, inner_message, retry_after);
|
||||
}
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
Self::Other(anyhow!(
|
||||
"completion request failed, code: {code}, message: {message}"
|
||||
))
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||
Self::from_http_status(status_code, message, retry_after)
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||
Self::from_http_status(status_code, message, retry_after)
|
||||
} else {
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
Self::Other(anyhow!(
|
||||
"completion request failed, code: {code}, message: {message}"
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_http_status(
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
match status_code {
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
|
||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&message),
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
|
||||
_ => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {
|
||||
|
||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
AnthropicError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { error },
|
||||
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
|
||||
AnthropicError::HttpResponseError {
|
||||
status_code,
|
||||
message,
|
||||
} => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => {
|
||||
Self::ServerOverloaded { retry_after }
|
||||
}
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: anthropic::ApiError) -> Self {
|
||||
use anthropic::ApiErrorCode::*;
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||
NotFoundError => Self::ApiEndpointNotFound,
|
||||
RequestTooLarge => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded { retry_after: None },
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded { retry_after: None },
|
||||
},
|
||||
None => Self::Other(error.into()),
|
||||
}
|
||||
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
||||
fn from(error: open_ai::RequestError) -> Self {
|
||||
match error {
|
||||
open_ai::RequestError::HttpResponseError {
|
||||
provider,
|
||||
provider: _,
|
||||
status_code,
|
||||
body,
|
||||
headers,
|
||||
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
||||
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
|
||||
.map(Duration::from_secs);
|
||||
|
||||
Self::from_http_status(provider.into(), status_code, body, retry_after)
|
||||
Self::from_http_status(status_code, body, retry_after)
|
||||
}
|
||||
open_ai::RequestError::Other(e) => Self::Other(e),
|
||||
}
|
||||
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
||||
|
||||
impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||
fn from(error: OpenRouterError) -> Self {
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error {
|
||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
OpenRouterError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
|
||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
|
||||
OpenRouterError::HttpSend(error) => Self::HttpSend { error },
|
||||
OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
|
||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
|
||||
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
OpenRouterError::ServerOverloaded { retry_after } => {
|
||||
Self::ServerOverloaded { retry_after }
|
||||
}
|
||||
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: open_router::ApiError) -> Self {
|
||||
use open_router::ApiErrorCode::*;
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error.code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PaymentRequiredError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: format!("Payment required: {}", error.message),
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
RequestTimedOut => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code: StatusCode::REQUEST_TIMEOUT,
|
||||
message: error.message,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded { retry_after: None },
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded { retry_after: None },
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -515,6 +546,9 @@ pub struct LanguageModelToolUse {
|
||||
pub raw_input: String,
|
||||
pub input: serde_json::Value,
|
||||
pub is_input_complete: bool,
|
||||
/// Thought signature the model sent us. Some models require that this
|
||||
/// signature be preserved and sent back in conversation history for validation.
|
||||
pub thought_signature: Option<String>,
|
||||
}
|
||||
|
||||
pub struct LanguageModelTextStream {
|
||||
@@ -630,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
|
||||
let last_token_usage = last_token_usage.clone();
|
||||
async move {
|
||||
match result {
|
||||
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Started) => None,
|
||||
Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
||||
@@ -640,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
..
|
||||
}) => None,
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
}
|
||||
@@ -829,16 +866,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_from_cloud_failure_with_upstream_http_error() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||
_ => panic!(
|
||||
"Expected ServerOverloaded error for 503 status, got: {:?}",
|
||||
error
|
||||
@@ -846,15 +880,13 @@ mod tests {
|
||||
}
|
||||
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||
assert_eq!(message, "Internal server error");
|
||||
}
|
||||
_ => panic!(
|
||||
@@ -867,16 +899,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_from_cloud_failure_with_standard_format() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_503".to_string(),
|
||||
"Service unavailable".to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
|
||||
}
|
||||
}
|
||||
@@ -884,16 +913,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_upstream_http_error_connection_timeout() {
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
}
|
||||
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||
_ => panic!(
|
||||
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
|
||||
error
|
||||
@@ -901,15 +927,13 @@ mod tests {
|
||||
}
|
||||
|
||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||
String::from("anthropic").into(),
|
||||
"upstream_http_error".to_string(),
|
||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
match error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider.0, "anthropic");
|
||||
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||
assert_eq!(
|
||||
message,
|
||||
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
|
||||
@@ -921,4 +945,85 @@ mod tests {
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_use_serializes_with_signature() {
|
||||
use serde_json::json;
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("test_id"),
|
||||
name: "test_tool".into(),
|
||||
raw_input: json!({"arg": "value"}).to_string(),
|
||||
input: json!({"arg": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: Some("test_signature".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&tool_use).unwrap();
|
||||
|
||||
assert_eq!(serialized["id"], "test_id");
|
||||
assert_eq!(serialized["name"], "test_tool");
|
||||
assert_eq!(serialized["thought_signature"], "test_signature");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_use_deserializes_with_missing_signature() {
|
||||
use serde_json::json;
|
||||
|
||||
let json = json!({
|
||||
"id": "test_id",
|
||||
"name": "test_tool",
|
||||
"raw_input": "{\"arg\":\"value\"}",
|
||||
"input": {"arg": "value"},
|
||||
"is_input_complete": true
|
||||
});
|
||||
|
||||
let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
|
||||
|
||||
assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
|
||||
assert_eq!(tool_use.name.as_ref(), "test_tool");
|
||||
assert_eq!(tool_use.thought_signature, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_use_round_trip_with_signature() {
|
||||
use serde_json::json;
|
||||
|
||||
let original = LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("round_trip_id"),
|
||||
name: "round_trip_tool".into(),
|
||||
raw_input: json!({"key": "value"}).to_string(),
|
||||
input: json!({"key": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: Some("round_trip_sig".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&original).unwrap();
|
||||
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.id, original.id);
|
||||
assert_eq!(deserialized.name, original.name);
|
||||
assert_eq!(deserialized.thought_signature, original.thought_signature);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_model_tool_use_round_trip_without_signature() {
|
||||
use serde_json::json;
|
||||
|
||||
let original = LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("no_sig_id"),
|
||||
name: "no_sig_tool".into(),
|
||||
raw_input: json!({"key": "value"}).to_string(),
|
||||
input: json!({"key": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_value(&original).unwrap();
|
||||
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.id, original.id);
|
||||
assert_eq!(deserialized.name, original.name);
|
||||
assert_eq!(deserialized.thought_signature, None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,9 +320,7 @@ impl AnthropicModel {
|
||||
|
||||
async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request = anthropic::stream_completion(
|
||||
http_client.as_ref(),
|
||||
@@ -711,6 +709,7 @@ impl AnthropicEventMapper {
|
||||
is_input_complete: false,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
input,
|
||||
thought_signature: None,
|
||||
},
|
||||
))];
|
||||
}
|
||||
@@ -734,6 +733,7 @@ impl AnthropicEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => {
|
||||
@@ -754,7 +754,7 @@ impl AnthropicEventMapper {
|
||||
Event::MessageStart { message } => {
|
||||
update_usage(&mut self.usage, &message.usage);
|
||||
vec![
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
||||
Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||
&self.usage,
|
||||
))),
|
||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||
@@ -776,9 +776,9 @@ impl AnthropicEventMapper {
|
||||
}
|
||||
};
|
||||
}
|
||||
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
))]
|
||||
vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||
&self.usage,
|
||||
)))]
|
||||
}
|
||||
Event::MessageStop => {
|
||||
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
||||
|
||||
@@ -970,11 +970,12 @@ pub fn map_to_language_model_completion_events(
|
||||
is_input_complete: true,
|
||||
raw_input: tool_use.input_json,
|
||||
input,
|
||||
thought_signature: None,
|
||||
},
|
||||
))
|
||||
}),
|
||||
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: metadata.input_tokens as u64,
|
||||
output_tokens: metadata.output_tokens as u64,
|
||||
cache_creation_input_tokens: metadata
|
||||
|
||||
@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
|
||||
}
|
||||
|
||||
return LanguageModelCompletionError::from_http_status(
|
||||
PROVIDER_NAME,
|
||||
error.status,
|
||||
cloud_error.message,
|
||||
None,
|
||||
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
|
||||
}
|
||||
|
||||
let retry_after = None;
|
||||
LanguageModelCompletionError::from_http_status(
|
||||
PROVIDER_NAME,
|
||||
error.status,
|
||||
error.body,
|
||||
retry_after,
|
||||
)
|
||||
LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -961,7 +955,7 @@ where
|
||||
vec![Err(LanguageModelCompletionError::from(error))]
|
||||
}
|
||||
Ok(CompletionEvent::Status(event)) => {
|
||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
||||
vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
|
||||
}
|
||||
Ok(CompletionEvent::Event(event)) => map_callback(event),
|
||||
})
|
||||
@@ -1313,8 +1307,7 @@ mod tests {
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
||||
assert_eq!(provider, PROVIDER_NAME);
|
||||
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||
assert_eq!(message, "Regular internal server error");
|
||||
}
|
||||
_ => panic!(
|
||||
@@ -1362,9 +1355,7 @@ mod tests {
|
||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||
|
||||
match completion_error {
|
||||
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
|
||||
assert_eq!(provider, PROVIDER_NAME);
|
||||
}
|
||||
LanguageModelCompletionError::ApiInternalServerError { .. } => {}
|
||||
_ => panic!(
|
||||
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
|
||||
completion_error
|
||||
|
||||
@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
},
|
||||
)));
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
})));
|
||||
}
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
@@ -458,6 +456,7 @@ pub fn map_to_language_model_completion_events(
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments,
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(
|
||||
@@ -560,6 +559,7 @@ impl CopilotResponsesEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: arguments.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
))),
|
||||
Err(error) => {
|
||||
@@ -608,7 +608,7 @@ impl CopilotResponsesEventMapper {
|
||||
copilot::copilot_responses::StreamEvent::Completed { response } => {
|
||||
let mut events = Vec::new();
|
||||
if let Some(usage) = response.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -641,7 +641,7 @@ impl CopilotResponsesEventMapper {
|
||||
|
||||
let mut events = Vec::new();
|
||||
if let Some(usage) = response.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -653,7 +653,6 @@ impl CopilotResponsesEventMapper {
|
||||
}
|
||||
|
||||
copilot::copilot_responses::StreamEvent::Failed { response } => {
|
||||
let provider = PROVIDER_NAME;
|
||||
let (status_code, message) = match response.error {
|
||||
Some(error) => {
|
||||
let status_code = StatusCode::from_str(&error.code)
|
||||
@@ -666,7 +665,6 @@ impl CopilotResponsesEventMapper {
|
||||
),
|
||||
};
|
||||
vec![Err(LanguageModelCompletionError::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
})]
|
||||
@@ -1097,7 +1095,7 @@ mod tests {
|
||||
));
|
||||
assert!(matches!(
|
||||
mapped[2],
|
||||
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: 5,
|
||||
output_tokens: 3,
|
||||
..
|
||||
@@ -1205,7 +1203,7 @@ mod tests {
|
||||
let mapped = map_events(events);
|
||||
assert!(matches!(
|
||||
mapped[0],
|
||||
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: 10,
|
||||
output_tokens: 0,
|
||||
..
|
||||
|
||||
@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request =
|
||||
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -501,6 +499,7 @@ impl DeepSeekEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
|
||||
@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
|
||||
async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
}
|
||||
.into());
|
||||
return Err(LanguageModelCompletionError::NoApiKey.into());
|
||||
};
|
||||
let response = google_ai::count_tokens(
|
||||
http_client.as_ref(),
|
||||
@@ -439,11 +436,15 @@ pub fn into_google(
|
||||
})]
|
||||
}
|
||||
language_model::MessageContent::ToolUse(tool_use) => {
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
|
||||
|
||||
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
|
||||
function_call: google_ai::FunctionCall {
|
||||
name: tool_use.name.to_string(),
|
||||
args: tool_use.input,
|
||||
},
|
||||
thought_signature,
|
||||
})]
|
||||
}
|
||||
language_model::MessageContent::ToolResult(tool_result) => {
|
||||
@@ -604,9 +605,9 @@ impl GoogleEventMapper {
|
||||
let mut wants_to_use_tool = false;
|
||||
if let Some(usage_metadata) = event.usage_metadata {
|
||||
update_usage(&mut self.usage, &usage_metadata);
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
||||
convert_usage(&self.usage),
|
||||
)))
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||
&self.usage,
|
||||
))))
|
||||
}
|
||||
|
||||
if let Some(prompt_feedback) = event.prompt_feedback
|
||||
@@ -655,6 +656,11 @@ impl GoogleEventMapper {
|
||||
let id: LanguageModelToolUseId =
|
||||
format!("{}-{}", name, next_tool_id).into();
|
||||
|
||||
// Normalize empty string signatures to None
|
||||
let thought_signature = function_call_part
|
||||
.thought_signature
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id,
|
||||
@@ -662,6 +668,7 @@ impl GoogleEventMapper {
|
||||
is_input_complete: true,
|
||||
raw_input: function_call_part.function_call.args.to_string(),
|
||||
input: function_call_part.function_call.args,
|
||||
thought_signature,
|
||||
},
|
||||
)));
|
||||
}
|
||||
@@ -891,3 +898,424 @@ impl Render for ConfigurationView {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use google_ai::{
|
||||
Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
|
||||
Part, Role as GoogleRole, TextPart,
|
||||
};
|
||||
use language_model::{LanguageModelToolUseId, MessageContent, Role};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_function_call_with_signature_creates_tool_use_with_signature() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some("test_signature_123".to_string()),
|
||||
})],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
assert_eq!(events.len(), 2); // ToolUse event + Stop event
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
assert_eq!(tool_use.name.as_ref(), "test_function");
|
||||
assert_eq!(
|
||||
tool_use.thought_signature.as_deref(),
|
||||
Some("test_signature_123")
|
||||
);
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_function_call_without_signature_has_none() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: None,
|
||||
})],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
assert_eq!(tool_use.thought_signature, None);
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_string_signature_normalized_to_none() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some("".to_string()),
|
||||
})],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
assert_eq!(tool_use.thought_signature, None);
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_function_calls_preserve_signatures() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![
|
||||
Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "function_1".to_string(),
|
||||
args: json!({"arg": "value1"}),
|
||||
},
|
||||
thought_signature: Some("signature_1".to_string()),
|
||||
}),
|
||||
Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "function_2".to_string(),
|
||||
args: json!({"arg": "value2"}),
|
||||
},
|
||||
thought_signature: None,
|
||||
}),
|
||||
],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
assert_eq!(tool_use.name.as_ref(), "function_1");
|
||||
assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
|
||||
} else {
|
||||
panic!("Expected ToolUse event for function_1");
|
||||
}
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
|
||||
assert_eq!(tool_use.name.as_ref(), "function_2");
|
||||
assert_eq!(tool_use.thought_signature, None);
|
||||
} else {
|
||||
panic!("Expected ToolUse event for function_2");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_use_with_signature_converts_to_function_call_part() {
|
||||
let tool_use = language_model::LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("test_id"),
|
||||
name: "test_function".into(),
|
||||
raw_input: json!({"arg": "value"}).to_string(),
|
||||
input: json!({"arg": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: Some("test_signature_456".to_string()),
|
||||
};
|
||||
|
||||
let request = super::into_google(
|
||||
LanguageModelRequest {
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
"gemini-2.5-flash".to_string(),
|
||||
GoogleModelMode::Default,
|
||||
);
|
||||
|
||||
assert_eq!(request.contents[0].parts.len(), 1);
|
||||
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
|
||||
assert_eq!(fc_part.function_call.name, "test_function");
|
||||
assert_eq!(
|
||||
fc_part.thought_signature.as_deref(),
|
||||
Some("test_signature_456")
|
||||
);
|
||||
} else {
|
||||
panic!("Expected FunctionCallPart");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_use_without_signature_omits_field() {
|
||||
let tool_use = language_model::LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("test_id"),
|
||||
name: "test_function".into(),
|
||||
raw_input: json!({"arg": "value"}).to_string(),
|
||||
input: json!({"arg": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
|
||||
let request = super::into_google(
|
||||
LanguageModelRequest {
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
"gemini-2.5-flash".to_string(),
|
||||
GoogleModelMode::Default,
|
||||
);
|
||||
|
||||
assert_eq!(request.contents[0].parts.len(), 1);
|
||||
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
|
||||
assert_eq!(fc_part.thought_signature, None);
|
||||
} else {
|
||||
panic!("Expected FunctionCallPart");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_signature_in_tool_use_normalized_to_none() {
|
||||
let tool_use = language_model::LanguageModelToolUse {
|
||||
id: LanguageModelToolUseId::from("test_id"),
|
||||
name: "test_function".into(),
|
||||
raw_input: json!({"arg": "value"}).to_string(),
|
||||
input: json!({"arg": "value"}),
|
||||
is_input_complete: true,
|
||||
thought_signature: Some("".to_string()),
|
||||
};
|
||||
|
||||
let request = super::into_google(
|
||||
LanguageModelRequest {
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
"gemini-2.5-flash".to_string(),
|
||||
GoogleModelMode::Default,
|
||||
);
|
||||
|
||||
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
|
||||
assert_eq!(fc_part.thought_signature, None);
|
||||
} else {
|
||||
panic!("Expected FunctionCallPart");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_trip_preserves_signature() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
// Simulate receiving a response from Google with a signature
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some("round_trip_sig".to_string()),
|
||||
})],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
tool_use.clone()
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
};
|
||||
|
||||
// Convert back to Google format
|
||||
let request = super::into_google(
|
||||
LanguageModelRequest {
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
"gemini-2.5-flash".to_string(),
|
||||
GoogleModelMode::Default,
|
||||
);
|
||||
|
||||
// Verify signature is preserved
|
||||
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
|
||||
assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
|
||||
} else {
|
||||
panic!("Expected FunctionCallPart");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_text_and_function_call_with_signature() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![
|
||||
Part::TextPart(TextPart {
|
||||
text: "I'll help with that.".to_string(),
|
||||
}),
|
||||
Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "helper_function".to_string(),
|
||||
args: json!({"query": "help"}),
|
||||
},
|
||||
thought_signature: Some("mixed_sig".to_string()),
|
||||
}),
|
||||
],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
|
||||
assert_eq!(text, "I'll help with that.");
|
||||
} else {
|
||||
panic!("Expected Text event");
|
||||
}
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
|
||||
assert_eq!(tool_use.name.as_ref(), "helper_function");
|
||||
assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_special_characters_in_signature_preserved() {
|
||||
let mut mapper = GoogleEventMapper::new();
|
||||
|
||||
let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
|
||||
|
||||
let response = GenerateContentResponse {
|
||||
candidates: Some(vec![GenerateContentCandidate {
|
||||
index: Some(0),
|
||||
content: Content {
|
||||
parts: vec![Part::FunctionCallPart(FunctionCallPart {
|
||||
function_call: FunctionCall {
|
||||
name: "test_function".to_string(),
|
||||
args: json!({"arg": "value"}),
|
||||
},
|
||||
thought_signature: Some(signature_with_special_chars.clone()),
|
||||
})],
|
||||
role: GoogleRole::Model,
|
||||
},
|
||||
finish_reason: None,
|
||||
finish_message: None,
|
||||
safety_ratings: None,
|
||||
citation_metadata: None,
|
||||
}]),
|
||||
prompt_feedback: None,
|
||||
usage_metadata: None,
|
||||
};
|
||||
|
||||
let events = mapper.map_event(response);
|
||||
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
|
||||
assert_eq!(
|
||||
tool_use.thought_signature.as_deref(),
|
||||
Some(signature_with_special_chars.as_str())
|
||||
);
|
||||
} else {
|
||||
panic!("Expected ToolUse event");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -569,6 +569,7 @@ impl LmStudioEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments,
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
|
||||
@@ -291,9 +291,7 @@ impl MistralLanguageModel {
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request =
|
||||
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
@@ -672,7 +670,7 @@ impl MistralEventMapper {
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -720,6 +718,7 @@ impl MistralEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments,
|
||||
thought_signature: None,
|
||||
},
|
||||
))),
|
||||
Err(error) => {
|
||||
|
||||
@@ -592,6 +592,7 @@ fn map_to_language_model_completion_events(
|
||||
raw_input: function.arguments.to_string(),
|
||||
input: function.arguments,
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
});
|
||||
events.push(Ok(event));
|
||||
state.used_tools = true;
|
||||
@@ -602,7 +603,7 @@ fn map_to_language_model_completion_events(
|
||||
};
|
||||
|
||||
if delta.done {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: delta.prompt_eval_count.unwrap_or(0),
|
||||
output_tokens: delta.eval_count.unwrap_or(0),
|
||||
cache_creation_input_tokens: 0,
|
||||
|
||||
@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let provider = PROVIDER_NAME;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let mut events = Vec::new();
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -586,6 +586,7 @@ impl OpenAiEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
|
||||
@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
|
||||
let provider = self.provider_name.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
|
||||
@@ -84,9 +84,7 @@ impl State {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
let Some(api_key) = self.api_key_state.key(&api_url) else {
|
||||
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
}));
|
||||
return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
let models = list_models(http_client.as_ref(), &api_url, &api_key)
|
||||
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {
|
||||
|
||||
async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request =
|
||||
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
@@ -635,6 +631,7 @@ impl OpenRouterEventMapper {
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
thought_signature: None,
|
||||
},
|
||||
)),
|
||||
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
|
||||
@@ -222,7 +222,7 @@ impl VercelLanguageModel {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let provider = PROVIDER_NAME;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request = open_ai::stream_completion(
|
||||
http_client.as_ref(),
|
||||
|
||||
@@ -230,7 +230,7 @@ impl XAiLanguageModel {
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let provider = PROVIDER_NAME;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
||||
return Err(LanguageModelCompletionError::NoApiKey);
|
||||
};
|
||||
let request = open_ai::stream_completion(
|
||||
http_client.as_ref(),
|
||||
|
||||
@@ -18,7 +18,6 @@ workspace.workspace = true
|
||||
util.workspace = true
|
||||
serde_json.workspace = true
|
||||
smol.workspace = true
|
||||
log.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -5,7 +5,10 @@ use std::{
|
||||
};
|
||||
|
||||
use gpui::{
|
||||
App, AppContext, Context, Entity, Hsla, InteractiveElement, IntoElement, ParentElement, Render, ScrollHandle, SerializedThreadTaskTimings, StatefulInteractiveElement, Styled, Task, TaskTiming, ThreadTaskTimings, TitlebarOptions, WindowBounds, WindowHandle, WindowOptions, div, prelude::FluentBuilder, px, relative, size
|
||||
App, AppContext, Context, Entity, Hsla, InteractiveElement, IntoElement, ParentElement, Render,
|
||||
ScrollHandle, SerializedTaskTiming, StatefulInteractiveElement, Styled, Task, TaskTiming,
|
||||
TitlebarOptions, WindowBounds, WindowHandle, WindowOptions, div, prelude::FluentBuilder, px,
|
||||
relative, size,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::{
|
||||
@@ -284,13 +287,8 @@ impl Render for ProfilerWindow {
|
||||
let Some(data) = this.get_timings() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let timings = ThreadTaskTimings {
|
||||
thread_name: Some("main".to_string()),
|
||||
thread_id: std::thread::current().id(),
|
||||
timings: data.clone()
|
||||
};
|
||||
let timings = Vec::from([SerializedThreadTaskTimings::convert(this.startup_time, timings)]);
|
||||
let timings =
|
||||
SerializedTaskTiming::convert(this.startup_time, &data);
|
||||
|
||||
let active_path = workspace
|
||||
.read_with(cx, |workspace, cx| {
|
||||
@@ -307,17 +305,12 @@ impl Render for ProfilerWindow {
|
||||
);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let path = match path.await.log_err() {
|
||||
Some(Ok(Some(path))) => path,
|
||||
Some(e @ Err(_)) => {
|
||||
e.log_err();
|
||||
log::warn!("Saving miniprof in workingdir");
|
||||
std::path::Path::new(
|
||||
"performance_profile.miniprof",
|
||||
)
|
||||
.to_path_buf()
|
||||
}
|
||||
Some(Ok(None)) | None => return,
|
||||
let path = path.await;
|
||||
let path =
|
||||
path.log_err().and_then(|p| p.log_err()).flatten();
|
||||
|
||||
let Some(path) = path else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(timings) =
|
||||
|
||||
@@ -43,7 +43,6 @@ text.workspace = true
|
||||
theme.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -76,8 +76,6 @@ impl MultiBuffer {
|
||||
context_line_count: u32,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Vec<Range<Anchor>>, bool) {
|
||||
let _timer =
|
||||
zlog::time!("set_excerpts_for_path").warn_if_gt(std::time::Duration::from_millis(100));
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot);
|
||||
|
||||
|
||||
@@ -449,7 +449,7 @@ pub async fn handle_import_vscode_settings(
|
||||
match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await {
|
||||
Ok(vscode_settings) => vscode_settings,
|
||||
Err(err) => {
|
||||
zlog::error!("{err}");
|
||||
zlog::error!("{err:?}");
|
||||
let _ = cx.prompt(
|
||||
gpui::PromptLevel::Info,
|
||||
&format!("Could not find or load a {source} settings file"),
|
||||
|
||||
@@ -99,13 +99,18 @@ pub enum ContextServerConfiguration {
|
||||
command: ContextServerCommand,
|
||||
settings: serde_json::Value,
|
||||
},
|
||||
Http {
|
||||
url: url::Url,
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ContextServerConfiguration {
|
||||
pub fn command(&self) -> &ContextServerCommand {
|
||||
pub fn command(&self) -> Option<&ContextServerCommand> {
|
||||
match self {
|
||||
ContextServerConfiguration::Custom { command } => command,
|
||||
ContextServerConfiguration::Extension { command, .. } => command,
|
||||
ContextServerConfiguration::Custom { command } => Some(command),
|
||||
ContextServerConfiguration::Extension { command, .. } => Some(command),
|
||||
ContextServerConfiguration::Http { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,6 +147,14 @@ impl ContextServerConfiguration {
|
||||
}
|
||||
}
|
||||
}
|
||||
ContextServerSettings::Http {
|
||||
enabled: _,
|
||||
url,
|
||||
headers: auth,
|
||||
} => {
|
||||
let url = url::Url::parse(&url).log_err()?;
|
||||
Some(ContextServerConfiguration::Http { url, headers: auth })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -207,7 +220,7 @@ impl ContextServerStore {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test_maintain_server_loop(
|
||||
context_server_factory: ContextServerFactory,
|
||||
context_server_factory: Option<ContextServerFactory>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
worktree_store: Entity<WorktreeStore>,
|
||||
weak_project: WeakEntity<Project>,
|
||||
@@ -215,7 +228,7 @@ impl ContextServerStore {
|
||||
) -> Self {
|
||||
Self::new_internal(
|
||||
true,
|
||||
Some(context_server_factory),
|
||||
context_server_factory,
|
||||
registry,
|
||||
worktree_store,
|
||||
weak_project,
|
||||
@@ -385,17 +398,6 @@ impl ContextServerStore {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
|
||||
if let Some(state) = self.servers.get(id) {
|
||||
let configuration = state.configuration();
|
||||
|
||||
self.stop_server(&state.server().id(), cx)?;
|
||||
let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
|
||||
self.run_server(new_server, configuration, cx);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_server(
|
||||
&mut self,
|
||||
server: Arc<ContextServer>,
|
||||
@@ -479,33 +481,42 @@ impl ContextServerStore {
|
||||
id: ContextServerId,
|
||||
configuration: Arc<ContextServerConfiguration>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Arc<ContextServer> {
|
||||
let project = self.project.upgrade();
|
||||
let mut root_path = None;
|
||||
if let Some(project) = project {
|
||||
let project = project.read(cx);
|
||||
if project.is_local() {
|
||||
if let Some(path) = project.active_project_directory(cx) {
|
||||
root_path = Some(path);
|
||||
} else {
|
||||
for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
|
||||
if let Some(path) = worktree.read(cx).root_dir() {
|
||||
root_path = Some(path);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
) -> Result<Arc<ContextServer>> {
|
||||
if let Some(factory) = self.context_server_factory.as_ref() {
|
||||
factory(id, configuration)
|
||||
} else {
|
||||
Arc::new(ContextServer::stdio(
|
||||
return Ok(factory(id, configuration));
|
||||
}
|
||||
|
||||
match configuration.as_ref() {
|
||||
ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
|
||||
id,
|
||||
configuration.command().clone(),
|
||||
root_path,
|
||||
))
|
||||
url,
|
||||
headers.clone(),
|
||||
cx.http_client(),
|
||||
cx.background_executor().clone(),
|
||||
)?)),
|
||||
_ => {
|
||||
let root_path = self
|
||||
.project
|
||||
.read_with(cx, |project, cx| project.active_project_directory(cx))
|
||||
.ok()
|
||||
.flatten()
|
||||
.or_else(|| {
|
||||
self.worktree_store.read_with(cx, |store, cx| {
|
||||
store.visible_worktrees(cx).fold(None, |acc, item| {
|
||||
if acc.is_none() {
|
||||
item.read(cx).root_dir()
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
})
|
||||
})
|
||||
});
|
||||
Ok(Arc::new(ContextServer::stdio(
|
||||
id,
|
||||
configuration.command().unwrap().clone(),
|
||||
root_path,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,14 +632,16 @@ impl ContextServerStore {
|
||||
let existing_config = state.as_ref().map(|state| state.configuration());
|
||||
if existing_config.as_deref() != Some(&config) || is_stopped {
|
||||
let config = Arc::new(config);
|
||||
let server = this.create_context_server(id.clone(), config.clone(), cx);
|
||||
let server = this.create_context_server(id.clone(), config.clone(), cx)?;
|
||||
servers_to_start.push((server, config));
|
||||
if this.servers.contains_key(&id) {
|
||||
servers_to_stop.insert(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
for id in servers_to_stop {
|
||||
@@ -654,6 +667,7 @@ mod tests {
|
||||
};
|
||||
use context_server::test::create_fake_transport;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use http_client::{FakeHttpClient, Response};
|
||||
use serde_json::json;
|
||||
use std::{cell::RefCell, path::PathBuf, rc::Rc};
|
||||
use util::path;
|
||||
@@ -894,12 +908,12 @@ mod tests {
|
||||
});
|
||||
let store = cx.new(|cx| {
|
||||
ContextServerStore::test_maintain_server_loop(
|
||||
Box::new(move |id, _| {
|
||||
Some(Box::new(move |id, _| {
|
||||
Arc::new(ContextServer::new(
|
||||
id.clone(),
|
||||
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
|
||||
))
|
||||
}),
|
||||
})),
|
||||
registry.clone(),
|
||||
project.read(cx).worktree_store(),
|
||||
project.downgrade(),
|
||||
@@ -1130,12 +1144,12 @@ mod tests {
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let store = cx.new(|cx| {
|
||||
ContextServerStore::test_maintain_server_loop(
|
||||
Box::new(move |id, _| {
|
||||
Some(Box::new(move |id, _| {
|
||||
Arc::new(ContextServer::new(
|
||||
id.clone(),
|
||||
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
|
||||
))
|
||||
}),
|
||||
})),
|
||||
registry.clone(),
|
||||
project.read(cx).worktree_store(),
|
||||
project.downgrade(),
|
||||
@@ -1228,6 +1242,73 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_remote_context_server(cx: &mut TestAppContext) {
|
||||
const SERVER_ID: &str = "remote-server";
|
||||
let server_id = ContextServerId(SERVER_ID.into());
|
||||
let server_url = "http://example.com/api";
|
||||
|
||||
let (_fs, project) = setup_context_server_test(
|
||||
cx,
|
||||
json!({ "code.rs": "" }),
|
||||
vec![(
|
||||
SERVER_ID.into(),
|
||||
ContextServerSettings::Http {
|
||||
enabled: true,
|
||||
url: server_url.to_string(),
|
||||
headers: Default::default(),
|
||||
},
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = FakeHttpClient::create(|_| async move {
|
||||
use http_client::AsyncBody;
|
||||
|
||||
let response = Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(AsyncBody::from(
|
||||
serde_json::to_string(&json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 0,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"serverInfo": {
|
||||
"name": "test-server",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}))
|
||||
.unwrap(),
|
||||
))
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
});
|
||||
cx.update(|cx| cx.set_http_client(client));
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let store = cx.new(|cx| {
|
||||
ContextServerStore::test_maintain_server_loop(
|
||||
None,
|
||||
registry.clone(),
|
||||
project.read(cx).worktree_store(),
|
||||
project.downgrade(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let _server_events = assert_server_events(
|
||||
&store,
|
||||
vec![
|
||||
(server_id.clone(), ContextServerStatus::Starting),
|
||||
(server_id.clone(), ContextServerStatus::Running),
|
||||
],
|
||||
cx,
|
||||
);
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
struct ServerEvents {
|
||||
received_event_count: Rc<RefCell<usize>>,
|
||||
expected_event_count: usize,
|
||||
|
||||
@@ -34,11 +34,11 @@ impl DiffBase {
|
||||
|
||||
pub struct BranchDiff {
|
||||
diff_base: DiffBase,
|
||||
pub repo: Option<Entity<Repository>>,
|
||||
pub project: Entity<Project>,
|
||||
repo: Option<Entity<Repository>>,
|
||||
project: Entity<Project>,
|
||||
base_commit: Option<SharedString>,
|
||||
head_commit: Option<SharedString>,
|
||||
pub tree_diff: Option<TreeDiff>,
|
||||
tree_diff: Option<TreeDiff>,
|
||||
_subscription: Subscription,
|
||||
update_needed: postage::watch::Sender<()>,
|
||||
_task: Task<()>,
|
||||
@@ -283,11 +283,7 @@ impl BranchDiff {
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let repo = repo.clone();
|
||||
let task = cx.spawn(async move |project, cx| {
|
||||
Self::load_buffer(branch_diff, project_path, repo.clone(), project, cx).await
|
||||
});
|
||||
let task = Self::load_buffer(branch_diff, project_path, repo.clone(), cx);
|
||||
|
||||
output.push(DiffBuffer {
|
||||
repo_path: item.repo_path.clone(),
|
||||
@@ -307,11 +303,8 @@ impl BranchDiff {
|
||||
let Some(project_path) = repo.read(cx).repo_path_to_project_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
let repo = repo.clone();
|
||||
let branch_diff2 = Some(branch_diff.clone());
|
||||
let task = cx.spawn(async move |project, cx| {
|
||||
Self::load_buffer(branch_diff2, project_path, repo, project, cx).await
|
||||
});
|
||||
let task =
|
||||
Self::load_buffer(Some(branch_diff.clone()), project_path, repo.clone(), cx);
|
||||
|
||||
let file_status = diff_status_to_file_status(branch_diff);
|
||||
|
||||
@@ -325,40 +318,42 @@ impl BranchDiff {
|
||||
output
|
||||
}
|
||||
|
||||
pub async fn load_buffer(
|
||||
fn load_buffer(
|
||||
branch_diff: Option<git::status::TreeDiffStatus>,
|
||||
project_path: crate::ProjectPath,
|
||||
repo: Entity<Repository>,
|
||||
project: WeakEntity<Project>,
|
||||
cx: &mut gpui::AsyncApp, // making this generic over AppContext hangs the compiler
|
||||
) -> Result<(Entity<Buffer>, Entity<BufferDiff>)> {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
cx: &Context<'_, Project>,
|
||||
) -> Task<Result<(Entity<Buffer>, Entity<BufferDiff>)>> {
|
||||
let task = cx.spawn(async move |project, cx| {
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
|
||||
let languages = project.update(cx, |project, _cx| project.languages().clone())?;
|
||||
let languages = project.update(cx, |project, _cx| project.languages().clone())?;
|
||||
|
||||
let changes = if let Some(entry) = branch_diff {
|
||||
let oid = match entry {
|
||||
git::status::TreeDiffStatus::Added { .. } => None,
|
||||
git::status::TreeDiffStatus::Modified { old, .. }
|
||||
| git::status::TreeDiffStatus::Deleted { old } => Some(old),
|
||||
let changes = if let Some(entry) = branch_diff {
|
||||
let oid = match entry {
|
||||
git::status::TreeDiffStatus::Added { .. } => None,
|
||||
git::status::TreeDiffStatus::Modified { old, .. }
|
||||
| git::status::TreeDiffStatus::Deleted { old } => Some(old),
|
||||
};
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.git_store().update(cx, |git_store, cx| {
|
||||
git_store.open_diff_since(oid, buffer.clone(), repo, languages, cx)
|
||||
})
|
||||
})?
|
||||
.await?
|
||||
} else {
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_uncommitted_diff(buffer.clone(), cx)
|
||||
})?
|
||||
.await?
|
||||
};
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.git_store().update(cx, |git_store, cx| {
|
||||
git_store.open_diff_since(oid, buffer.clone(), repo, languages, cx)
|
||||
})
|
||||
})?
|
||||
.await?
|
||||
} else {
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_uncommitted_diff(buffer.clone(), cx)
|
||||
})?
|
||||
.await?
|
||||
};
|
||||
Ok((buffer, changes))
|
||||
Ok((buffer, changes))
|
||||
});
|
||||
task
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -135,6 +135,16 @@ pub enum ContextServerSettings {
|
||||
/// are supported.
|
||||
settings: serde_json::Value,
|
||||
},
|
||||
Http {
|
||||
/// Whether the context server is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
enabled: bool,
|
||||
/// The URL of the remote context server.
|
||||
url: String,
|
||||
/// Optional authentication configuration for the remote server.
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<settings::ContextServerSettingsContent> for ContextServerSettings {
|
||||
@@ -146,6 +156,15 @@ impl From<settings::ContextServerSettingsContent> for ContextServerSettings {
|
||||
settings::ContextServerSettingsContent::Extension { enabled, settings } => {
|
||||
ContextServerSettings::Extension { enabled, settings }
|
||||
}
|
||||
settings::ContextServerSettingsContent::Http {
|
||||
enabled,
|
||||
url,
|
||||
headers,
|
||||
} => ContextServerSettings::Http {
|
||||
enabled,
|
||||
url,
|
||||
headers,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,6 +177,15 @@ impl Into<settings::ContextServerSettingsContent> for ContextServerSettings {
|
||||
ContextServerSettings::Extension { enabled, settings } => {
|
||||
settings::ContextServerSettingsContent::Extension { enabled, settings }
|
||||
}
|
||||
ContextServerSettings::Http {
|
||||
enabled,
|
||||
url,
|
||||
headers,
|
||||
} => settings::ContextServerSettingsContent::Http {
|
||||
enabled,
|
||||
url,
|
||||
headers,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,6 +202,7 @@ impl ContextServerSettings {
|
||||
match self {
|
||||
ContextServerSettings::Custom { enabled, .. } => *enabled,
|
||||
ContextServerSettings::Extension { enabled, .. } => *enabled,
|
||||
ContextServerSettings::Http { enabled, .. } => *enabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,6 +210,7 @@ impl ContextServerSettings {
|
||||
match self {
|
||||
ContextServerSettings::Custom { enabled: e, .. } => *e = enabled,
|
||||
ContextServerSettings::Extension { enabled: e, .. } => *e = enabled,
|
||||
ContextServerSettings::Http { enabled: e, .. } => *e = enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ use workspace::{
|
||||
notifications::{DetachAndPromptErr, NotifyResultExt, NotifyTaskExt},
|
||||
};
|
||||
use worktree::CreatedEntry;
|
||||
use zed_actions::workspace::OpenWithSystem;
|
||||
use zed_actions::{project_panel::ToggleFocus, workspace::OpenWithSystem};
|
||||
|
||||
const PROJECT_PANEL_KEY: &str = "ProjectPanel";
|
||||
const NEW_ENTRY_ID: ProjectEntryId = ProjectEntryId::MAX;
|
||||
@@ -306,8 +306,6 @@ actions!(
|
||||
OpenSplitVertical,
|
||||
/// Opens the selected file in a horizontal split.
|
||||
OpenSplitHorizontal,
|
||||
/// Toggles focus on the project panel.
|
||||
ToggleFocus,
|
||||
/// Toggles visibility of git-ignored files.
|
||||
ToggleHideGitIgnore,
|
||||
/// Toggles visibility of hidden files.
|
||||
|
||||
@@ -489,7 +489,7 @@ impl SshRemoteConnection {
|
||||
let ssh_shell = socket.shell().await;
|
||||
log::info!("Remote shell discovered: {}", ssh_shell);
|
||||
let ssh_platform = socket.platform(ShellKind::new(&ssh_shell, false)).await?;
|
||||
log::info!("Remote platform discovered: {}", ssh_shell);
|
||||
log::info!("Remote platform discovered: {:?}", ssh_platform);
|
||||
let ssh_path_style = match ssh_platform.os {
|
||||
"windows" => PathStyle::Windows,
|
||||
_ => PathStyle::Posix,
|
||||
|
||||
@@ -92,7 +92,7 @@ impl WslRemoteConnection {
|
||||
.detect_platform()
|
||||
.await
|
||||
.context("failed detecting platform")?;
|
||||
log::info!("Remote platform discovered: {}", this.shell);
|
||||
log::info!("Remote platform discovered: {:?}", this.platform);
|
||||
this.remote_binary_path = Some(
|
||||
this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
|
||||
.await
|
||||
|
||||
@@ -13,32 +13,6 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
// https://docs.rs/tokio/latest/src/tokio/task/yield_now.rs.html#39-64
|
||||
pub async fn yield_now() {
|
||||
/// Yield implementation
|
||||
struct YieldNow {
|
||||
yielded: bool,
|
||||
}
|
||||
|
||||
impl Future for YieldNow {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
// use core::task::ready;
|
||||
// ready!(crate::trace::trace_leaf(cx));
|
||||
if self.yielded {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
self.yielded = true;
|
||||
// context::defer(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
YieldNow { yielded: false }.await;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ForegroundExecutor {
|
||||
session_id: SessionId,
|
||||
|
||||
@@ -1039,218 +1039,3 @@ impl std::fmt::Display for DelayMs {
|
||||
write!(f, "{}ms", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper type that distinguishes between an explicitly set value (including null) and an unset value.
|
||||
///
|
||||
/// This is useful for configuration where you need to differentiate between:
|
||||
/// - A field that is not present in the configuration file (`Maybe::Unset`)
|
||||
/// - A field that is explicitly set to `null` (`Maybe::Set(None)`)
|
||||
/// - A field that is explicitly set to a value (`Maybe::Set(Some(value))`)
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// In JSON:
|
||||
/// - `{}` (field missing) deserializes to `Maybe::Unset`
|
||||
/// - `{"field": null}` deserializes to `Maybe::Set(None)`
|
||||
/// - `{"field": "value"}` deserializes to `Maybe::Set(Some("value"))`
|
||||
///
|
||||
/// WARN: This type should not be wrapped in an option inside of settings, otherwise the default `serde_json` behavior
|
||||
/// of treating `null` and missing as the `Option::None` will be used
|
||||
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, Default)]
|
||||
#[strum_discriminants(derive(strum::VariantArray, strum::VariantNames, strum::FromRepr))]
|
||||
pub enum Maybe<T> {
|
||||
/// An explicitly set value, which may be `None` (representing JSON `null`) or `Some(value)`.
|
||||
Set(Option<T>),
|
||||
/// A value that was not present in the configuration.
|
||||
#[default]
|
||||
Unset,
|
||||
}
|
||||
|
||||
impl<T: Clone> merge_from::MergeFrom for Maybe<T> {
|
||||
fn merge_from(&mut self, other: &Self) {
|
||||
if self.is_unset() {
|
||||
*self = other.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Option<Option<T>>> for Maybe<T> {
|
||||
fn from(value: Option<Option<T>>) -> Self {
|
||||
match value {
|
||||
Some(value) => Maybe::Set(value),
|
||||
None => Maybe::Unset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Maybe<T> {
|
||||
pub fn is_set(&self) -> bool {
|
||||
matches!(self, Maybe::Set(_))
|
||||
}
|
||||
|
||||
pub fn is_unset(&self) -> bool {
|
||||
matches!(self, Maybe::Unset)
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Option<T> {
|
||||
match self {
|
||||
Maybe::Set(value) => value,
|
||||
Maybe::Unset => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ref(&self) -> Option<&Option<T>> {
|
||||
match self {
|
||||
Maybe::Set(value) => Some(value),
|
||||
Maybe::Unset => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: serde::Serialize> serde::Serialize for Maybe<T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
Maybe::Set(value) => value.serialize(serializer),
|
||||
Maybe::Unset => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T: serde::Deserialize<'de>> serde::Deserialize<'de> for Maybe<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Option::<T>::deserialize(deserializer).map(Maybe::Set)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: JsonSchema> JsonSchema for Maybe<T> {
|
||||
fn schema_name() -> std::borrow::Cow<'static, str> {
|
||||
format!("Nullable<{}>", T::schema_name()).into()
|
||||
}
|
||||
|
||||
fn json_schema(generator: &mut schemars::generate::SchemaGenerator) -> schemars::Schema {
|
||||
let mut schema = generator.subschema_for::<Option<T>>();
|
||||
// Add description explaining that null is an explicit value
|
||||
let description = if let Some(existing_desc) =
|
||||
schema.get("description").and_then(|desc| desc.as_str())
|
||||
{
|
||||
format!(
|
||||
"{}. Note: `null` is treated as an explicit value, different from omitting the field entirely.",
|
||||
existing_desc
|
||||
)
|
||||
} else {
|
||||
"This field supports explicit `null` values. Omitting the field is different from setting it to `null`.".to_string()
|
||||
};
|
||||
|
||||
schema.insert("description".to_string(), description.into());
|
||||
|
||||
schema
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json;
|
||||
|
||||
#[test]
|
||||
fn test_maybe() {
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
struct TestStruct {
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Maybe::is_unset")]
|
||||
field: Maybe<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
struct NumericTest {
|
||||
#[serde(default)]
|
||||
value: Maybe<i32>,
|
||||
}
|
||||
|
||||
let json = "{}";
|
||||
let result: TestStruct = serde_json::from_str(json).unwrap();
|
||||
assert!(result.field.is_unset());
|
||||
assert_eq!(result.field, Maybe::Unset);
|
||||
|
||||
let json = r#"{"field": null}"#;
|
||||
let result: TestStruct = serde_json::from_str(json).unwrap();
|
||||
assert!(result.field.is_set());
|
||||
assert_eq!(result.field, Maybe::Set(None));
|
||||
|
||||
let json = r#"{"field": "hello"}"#;
|
||||
let result: TestStruct = serde_json::from_str(json).unwrap();
|
||||
assert!(result.field.is_set());
|
||||
assert_eq!(result.field, Maybe::Set(Some("hello".to_string())));
|
||||
|
||||
let test = TestStruct {
|
||||
field: Maybe::Unset,
|
||||
};
|
||||
let json = serde_json::to_string(&test).unwrap();
|
||||
assert_eq!(json, "{}");
|
||||
|
||||
let test = TestStruct {
|
||||
field: Maybe::Set(None),
|
||||
};
|
||||
let json = serde_json::to_string(&test).unwrap();
|
||||
assert_eq!(json, r#"{"field":null}"#);
|
||||
|
||||
let test = TestStruct {
|
||||
field: Maybe::Set(Some("world".to_string())),
|
||||
};
|
||||
let json = serde_json::to_string(&test).unwrap();
|
||||
assert_eq!(json, r#"{"field":"world"}"#);
|
||||
|
||||
let default_maybe: Maybe<i32> = Maybe::default();
|
||||
assert!(default_maybe.is_unset());
|
||||
|
||||
let unset: Maybe<String> = Maybe::Unset;
|
||||
assert!(unset.is_unset());
|
||||
assert!(!unset.is_set());
|
||||
|
||||
let set_none: Maybe<String> = Maybe::Set(None);
|
||||
assert!(set_none.is_set());
|
||||
assert!(!set_none.is_unset());
|
||||
|
||||
let set_some: Maybe<String> = Maybe::Set(Some("value".to_string()));
|
||||
assert!(set_some.is_set());
|
||||
assert!(!set_some.is_unset());
|
||||
|
||||
let original = TestStruct {
|
||||
field: Maybe::Set(Some("test".to_string())),
|
||||
};
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: TestStruct = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
|
||||
let json = r#"{"value": 42}"#;
|
||||
let result: NumericTest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.value, Maybe::Set(Some(42)));
|
||||
|
||||
let json = r#"{"value": null}"#;
|
||||
let result: NumericTest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.value, Maybe::Set(None));
|
||||
|
||||
let json = "{}";
|
||||
let result: NumericTest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.value, Maybe::Unset);
|
||||
|
||||
// Test JsonSchema implementation
|
||||
use schemars::schema_for;
|
||||
let schema = schema_for!(Maybe<String>);
|
||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||
|
||||
// Verify the description mentions that null is an explicit value
|
||||
let description = schema_json["description"].as_str().unwrap();
|
||||
assert!(
|
||||
description.contains("null") && description.contains("explicit"),
|
||||
"Schema description should mention that null is an explicit value. Got: {}",
|
||||
description
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use settings_macros::MergeFrom;
|
||||
use util::serde::default_true;
|
||||
|
||||
use crate::{
|
||||
AllLanguageSettingsContent, DelayMs, ExtendingVec, Maybe, ProjectTerminalSettingsContent,
|
||||
AllLanguageSettingsContent, DelayMs, ExtendingVec, ProjectTerminalSettingsContent,
|
||||
SlashCommandSettings,
|
||||
};
|
||||
|
||||
@@ -61,8 +61,8 @@ pub struct WorktreeSettingsContent {
|
||||
///
|
||||
/// Default: null
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Maybe::is_unset")]
|
||||
pub project_name: Maybe<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub project_name: Option<String>,
|
||||
|
||||
/// Whether to prevent this project from being shared in public channels.
|
||||
///
|
||||
@@ -196,7 +196,7 @@ pub struct SessionSettingsContent {
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom, Debug)]
|
||||
#[serde(tag = "source", rename_all = "snake_case")]
|
||||
#[serde(untagged, rename_all = "snake_case")]
|
||||
pub enum ContextServerSettingsContent {
|
||||
Custom {
|
||||
/// Whether the context server is enabled.
|
||||
@@ -206,6 +206,16 @@ pub enum ContextServerSettingsContent {
|
||||
#[serde(flatten)]
|
||||
command: ContextServerCommand,
|
||||
},
|
||||
Http {
|
||||
/// Whether the context server is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
enabled: bool,
|
||||
/// The URL of the remote context server.
|
||||
url: String,
|
||||
/// Optional headers to send.
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
|
||||
headers: HashMap<String, String>,
|
||||
},
|
||||
Extension {
|
||||
/// Whether the context server is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
@@ -217,19 +227,24 @@ pub enum ContextServerSettingsContent {
|
||||
settings: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
impl ContextServerSettingsContent {
|
||||
pub fn set_enabled(&mut self, enabled: bool) {
|
||||
match self {
|
||||
ContextServerSettingsContent::Custom {
|
||||
enabled: custom_enabled,
|
||||
command: _,
|
||||
..
|
||||
} => {
|
||||
*custom_enabled = enabled;
|
||||
}
|
||||
ContextServerSettingsContent::Extension {
|
||||
enabled: ext_enabled,
|
||||
settings: _,
|
||||
..
|
||||
} => *ext_enabled = enabled,
|
||||
ContextServerSettingsContent::Http {
|
||||
enabled: remote_enabled,
|
||||
..
|
||||
} => *remote_enabled = enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -870,7 +870,7 @@ impl VsCodeSettings {
|
||||
|
||||
fn worktree_settings_content(&self) -> WorktreeSettingsContent {
|
||||
WorktreeSettingsContent {
|
||||
project_name: crate::Maybe::Unset,
|
||||
project_name: None,
|
||||
prevent_sharing_in_public_channels: false,
|
||||
file_scan_exclusions: self
|
||||
.read_value("files.watcherExclude")
|
||||
|
||||
@@ -33,10 +33,10 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
|
||||
SettingField {
|
||||
json_path: Some("project_name"),
|
||||
pick: |settings_content| {
|
||||
settings_content.project.worktree.project_name.as_ref()?.as_ref().or(DEFAULT_EMPTY_STRING)
|
||||
settings_content.project.worktree.project_name.as_ref().or(DEFAULT_EMPTY_STRING)
|
||||
},
|
||||
write: |settings_content, value| {
|
||||
settings_content.project.worktree.project_name = settings::Maybe::Set(value.filter(|name| !name.is_empty()));
|
||||
settings_content.project.worktree.project_name = value.filter(|name| !name.is_empty());
|
||||
},
|
||||
}
|
||||
),
|
||||
|
||||
@@ -507,7 +507,6 @@ fn init_renderers(cx: &mut App) {
|
||||
.add_basic_renderer::<settings::BufferLineHeightDiscriminants>(render_dropdown)
|
||||
.add_basic_renderer::<settings::AutosaveSettingDiscriminants>(render_dropdown)
|
||||
.add_basic_renderer::<settings::WorkingDirectoryDiscriminants>(render_dropdown)
|
||||
.add_basic_renderer::<settings::MaybeDiscriminants>(render_dropdown)
|
||||
.add_basic_renderer::<settings::IncludeIgnoredContent>(render_dropdown)
|
||||
.add_basic_renderer::<settings::ShowIndentGuides>(render_dropdown)
|
||||
.add_basic_renderer::<settings::ShellDiscriminants>(render_dropdown)
|
||||
|
||||
@@ -37,10 +37,6 @@ pub struct SweepFeatureFlag;
|
||||
|
||||
impl FeatureFlag for SweepFeatureFlag {
|
||||
const NAME: &str = "sweep-ai";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -77,6 +77,7 @@ impl RenderOnce for Modal {
|
||||
.w_full()
|
||||
.flex_1()
|
||||
.gap(DynamicSpacing::Base08.rems(cx))
|
||||
.when(self.footer.is_some(), |this| this.pb_4())
|
||||
.when_some(
|
||||
self.container_scroll_handler,
|
||||
|this, container_scroll_handle| {
|
||||
@@ -276,7 +277,6 @@ impl RenderOnce for ModalFooter {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.mt_4()
|
||||
.p(DynamicSpacing::Base08.rems(cx))
|
||||
.flex_none()
|
||||
.justify_between()
|
||||
|
||||
@@ -35,7 +35,6 @@ multi_buffer.workspace = true
|
||||
nvim-rs = { git = "https://github.com/KillTheMule/nvim-rs", rev = "764dd270c642f77f10f3e19d05cc178a6cbe69f3", features = ["use_tokio"], optional = true }
|
||||
picker.workspace = true
|
||||
project.workspace = true
|
||||
project_panel.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
search.workspace = true
|
||||
|
||||
@@ -183,8 +183,6 @@ actions!(
|
||||
InnerObject,
|
||||
/// Maximizes the current pane.
|
||||
MaximizePane,
|
||||
/// Opens the default keymap file.
|
||||
OpenDefaultKeymap,
|
||||
/// Resets all pane sizes to default.
|
||||
ResetPaneSizes,
|
||||
/// Resizes the pane to the right.
|
||||
@@ -314,7 +312,7 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
workspace.register_action(|_, _: &ToggleProjectPanelFocus, window, cx| {
|
||||
if Vim::take_count(cx).is_none() {
|
||||
window.dispatch_action(project_panel::ToggleFocus.boxed_clone(), cx);
|
||||
window.dispatch_action(zed_actions::project_panel::ToggleFocus.boxed_clone(), cx);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -343,7 +341,7 @@ pub fn init(cx: &mut App) {
|
||||
};
|
||||
});
|
||||
|
||||
workspace.register_action(|_, _: &OpenDefaultKeymap, _, cx| {
|
||||
workspace.register_action(|_, _: &zed_actions::vim::OpenDefaultKeymap, _, cx| {
|
||||
cx.emit(workspace::Event::OpenBundledFile {
|
||||
text: settings::vim_keymap(),
|
||||
title: "Default Vim Bindings",
|
||||
|
||||
@@ -66,7 +66,7 @@ impl Settings for WorktreeSettings {
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
project_name: worktree.project_name.into_inner(),
|
||||
project_name: worktree.project_name,
|
||||
prevent_sharing_in_public_channels: worktree.prevent_sharing_in_public_channels,
|
||||
file_scan_exclusions: path_matchers(file_scan_exclusions, "file_scan_exclusions")
|
||||
.log_err()
|
||||
|
||||
@@ -1002,7 +1002,7 @@ fn register_actions(
|
||||
.register_action(open_project_debug_tasks_file)
|
||||
.register_action(
|
||||
|workspace: &mut Workspace,
|
||||
_: &project_panel::ToggleFocus,
|
||||
_: &zed_actions::project_panel::ToggleFocus,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>| {
|
||||
workspace.toggle_panel_focus::<ProjectPanel>(window, cx);
|
||||
@@ -4657,133 +4657,6 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
/// Checks that action namespaces are the expected set. The purpose of this is to prevent typos
|
||||
/// and let you know when introducing a new namespace.
|
||||
#[gpui::test]
|
||||
async fn test_action_namespaces(cx: &mut gpui::TestAppContext) {
|
||||
use itertools::Itertools;
|
||||
|
||||
init_keymap_test(cx);
|
||||
cx.update(|cx| {
|
||||
let all_actions = cx.all_action_names();
|
||||
|
||||
let mut actions_without_namespace = Vec::new();
|
||||
let all_namespaces = all_actions
|
||||
.iter()
|
||||
.filter_map(|action_name| {
|
||||
let namespace = action_name
|
||||
.split("::")
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.skip(1)
|
||||
.rev()
|
||||
.join("::");
|
||||
if namespace.is_empty() {
|
||||
actions_without_namespace.push(*action_name);
|
||||
}
|
||||
if &namespace == "test_only" || &namespace == "stories" {
|
||||
None
|
||||
} else {
|
||||
Some(namespace)
|
||||
}
|
||||
})
|
||||
.sorted()
|
||||
.dedup()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(actions_without_namespace, Vec::<&str>::new());
|
||||
|
||||
let expected_namespaces = vec![
|
||||
"action",
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
"app_menu",
|
||||
"assistant",
|
||||
"assistant2",
|
||||
"auto_update",
|
||||
"bedrock",
|
||||
"branches",
|
||||
"buffer_search",
|
||||
"channel_modal",
|
||||
"cli",
|
||||
"client",
|
||||
"collab",
|
||||
"collab_panel",
|
||||
"command_palette",
|
||||
"console",
|
||||
"context_server",
|
||||
"copilot",
|
||||
"debug_panel",
|
||||
"debugger",
|
||||
"dev",
|
||||
"diagnostics",
|
||||
"edit_prediction",
|
||||
"editor",
|
||||
"feedback",
|
||||
"file_finder",
|
||||
"git",
|
||||
"git_onboarding",
|
||||
"git_panel",
|
||||
"go_to_line",
|
||||
"icon_theme_selector",
|
||||
"journal",
|
||||
"keymap_editor",
|
||||
"keystroke_input",
|
||||
"language_selector",
|
||||
"line_ending_selector",
|
||||
"lsp_tool",
|
||||
"markdown",
|
||||
"menu",
|
||||
"notebook",
|
||||
"notification_panel",
|
||||
"onboarding",
|
||||
"outline",
|
||||
"outline_panel",
|
||||
"pane",
|
||||
"panel",
|
||||
"picker",
|
||||
"project_panel",
|
||||
"project_search",
|
||||
"project_symbols",
|
||||
"projects",
|
||||
"repl",
|
||||
"rules_library",
|
||||
"search",
|
||||
"settings_editor",
|
||||
"settings_profile_selector",
|
||||
"snippets",
|
||||
"stash_picker",
|
||||
"supermaven",
|
||||
"svg",
|
||||
"syntax_tree_view",
|
||||
"tab_switcher",
|
||||
"task",
|
||||
"terminal",
|
||||
"terminal_panel",
|
||||
"theme_selector",
|
||||
"toast",
|
||||
"toolchain",
|
||||
"variable_list",
|
||||
"vim",
|
||||
"window",
|
||||
"workspace",
|
||||
"zed",
|
||||
"zed_actions",
|
||||
"zed_predict_onboarding",
|
||||
"zeta",
|
||||
];
|
||||
assert_eq!(
|
||||
all_namespaces,
|
||||
expected_namespaces
|
||||
.into_iter()
|
||||
.map(|namespace| namespace.to_string())
|
||||
.sorted()
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_bundled_settings_and_themes(cx: &mut App) {
|
||||
cx.text_system()
|
||||
|
||||
@@ -39,7 +39,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
|
||||
],
|
||||
}),
|
||||
MenuItem::separator(),
|
||||
MenuItem::action("Project Panel", project_panel::ToggleFocus),
|
||||
MenuItem::action("Project Panel", zed_actions::project_panel::ToggleFocus),
|
||||
MenuItem::action("Outline Panel", outline_panel::ToggleFocus),
|
||||
MenuItem::action("Collab Panel", collab_panel::ToggleFocus),
|
||||
MenuItem::action("Terminal Panel", terminal_panel::ToggleFocus),
|
||||
|
||||
@@ -250,6 +250,17 @@ pub mod command_palette {
|
||||
);
|
||||
}
|
||||
|
||||
pub mod project_panel {
|
||||
use gpui::actions;
|
||||
|
||||
actions!(
|
||||
project_panel,
|
||||
[
|
||||
/// Toggles focus on the project panel.
|
||||
ToggleFocus
|
||||
]
|
||||
);
|
||||
}
|
||||
pub mod feedback {
|
||||
use gpui::actions;
|
||||
|
||||
@@ -532,6 +543,18 @@ actions!(
|
||||
]
|
||||
);
|
||||
|
||||
pub mod vim {
|
||||
use gpui::actions;
|
||||
|
||||
actions!(
|
||||
vim,
|
||||
[
|
||||
/// Opens the default keymap file.
|
||||
OpenDefaultKeymap
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct WslConnectionOptions {
|
||||
pub distro_name: String,
|
||||
|
||||
@@ -2,6 +2,11 @@ use anyhow::{Context as _, Result};
|
||||
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::{cmp, ops::Range, path::Path, sync::Arc};
|
||||
|
||||
const EDITS_TAG_NAME: &'static str = "edits";
|
||||
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
|
||||
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
|
||||
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
|
||||
|
||||
pub async fn parse_xml_edits<'a>(
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
@@ -12,38 +17,22 @@ pub async fn parse_xml_edits<'a>(
|
||||
}
|
||||
|
||||
async fn parse_xml_edits_inner<'a>(
|
||||
mut input: &'a str,
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
let edits_tag = parse_tag(&mut input, "edits")?.context("No edits tag")?;
|
||||
let xml_edits = extract_xml_replacements(input)?;
|
||||
|
||||
input = edits_tag.body;
|
||||
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
|
||||
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
|
||||
|
||||
let file_path = edits_tag
|
||||
.attributes
|
||||
.trim_start()
|
||||
.strip_prefix("path")
|
||||
.context("no file attribute on edits tag")?
|
||||
.trim_end()
|
||||
.strip_prefix('=')
|
||||
.context("no value for path attribute")?
|
||||
.trim()
|
||||
.trim_start_matches('"')
|
||||
.trim_end_matches('"');
|
||||
|
||||
let (buffer, context_ranges) = get_buffer(file_path.as_ref())
|
||||
.with_context(|| format!("no buffer for file {file_path}"))?;
|
||||
|
||||
let mut edits = vec![];
|
||||
while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? {
|
||||
let new_text_tag =
|
||||
parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?;
|
||||
let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?;
|
||||
let old_text = buffer
|
||||
let mut all_edits = vec![];
|
||||
for (old_text, new_text) in xml_edits.replacements {
|
||||
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
|
||||
let matched_old_text = buffer
|
||||
.text_for_range(match_range.clone())
|
||||
.collect::<String>();
|
||||
let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body);
|
||||
edits.extend(
|
||||
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
|
||||
all_edits.extend(
|
||||
edits_within_hunk
|
||||
.into_iter()
|
||||
.map(move |(inner_range, inner_text)| {
|
||||
@@ -56,7 +45,7 @@ async fn parse_xml_edits_inner<'a>(
|
||||
);
|
||||
}
|
||||
|
||||
Ok((buffer, edits))
|
||||
Ok((buffer, all_edits))
|
||||
}
|
||||
|
||||
fn fuzzy_match_in_ranges(
|
||||
@@ -110,32 +99,128 @@ fn fuzzy_match_in_ranges(
|
||||
);
|
||||
}
|
||||
|
||||
struct ParsedTag<'a> {
|
||||
attributes: &'a str,
|
||||
body: &'a str,
|
||||
#[derive(Debug)]
|
||||
struct XmlEdits<'a> {
|
||||
file_path: &'a str,
|
||||
/// Vec of (old_text, new_text) pairs
|
||||
replacements: Vec<(&'a str, &'a str)>,
|
||||
}
|
||||
|
||||
fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result<Option<ParsedTag<'a>>> {
|
||||
let open_tag = format!("<{}", tag);
|
||||
let close_tag = format!("</{}>", tag);
|
||||
let Some(start_ix) = input.find(&open_tag) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let start_ix = start_ix + open_tag.len();
|
||||
let closing_bracket_ix = start_ix
|
||||
+ input[start_ix..]
|
||||
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
|
||||
let mut cursor = 0;
|
||||
|
||||
let (edits_body_start, edits_attrs) =
|
||||
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
|
||||
|
||||
let file_path = edits_attrs
|
||||
.trim_start()
|
||||
.strip_prefix("path")
|
||||
.context("no path attribute on edits tag")?
|
||||
.trim_end()
|
||||
.strip_prefix('=')
|
||||
.context("no value for path attribute")?
|
||||
.trim()
|
||||
.trim_start_matches('"')
|
||||
.trim_end_matches('"');
|
||||
|
||||
cursor = edits_body_start;
|
||||
let mut edits_list = Vec::new();
|
||||
|
||||
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
|
||||
let old_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
|
||||
|
||||
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
|
||||
.context("no new_text tag following old_text")?;
|
||||
let new_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
|
||||
|
||||
edits_list.push((old_text, new_text));
|
||||
}
|
||||
|
||||
Ok(XmlEdits {
|
||||
file_path,
|
||||
replacements: edits_list,
|
||||
})
|
||||
}
|
||||
|
||||
/// Trims a single leading and trailing newline
|
||||
fn trim_surrounding_newlines(input: &str) -> &str {
|
||||
let start = input.strip_prefix('\n').unwrap_or(input);
|
||||
let end = start.strip_suffix('\n').unwrap_or(start);
|
||||
end
|
||||
}
|
||||
|
||||
fn find_tag_open<'a>(
|
||||
input: &'a str,
|
||||
cursor: &mut usize,
|
||||
expected_tag: &str,
|
||||
) -> Result<Option<(usize, &'a str)>> {
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() {
|
||||
let Some(tag_start) = input[search_pos..].find("<") else {
|
||||
break;
|
||||
};
|
||||
let tag_start = search_pos + tag_start;
|
||||
if !input[tag_start + 1..].starts_with(expected_tag) {
|
||||
search_pos = search_pos + tag_start + 1;
|
||||
continue;
|
||||
};
|
||||
|
||||
let after_tag_name = tag_start + expected_tag.len() + 1;
|
||||
let close_bracket = input[after_tag_name..]
|
||||
.find('>')
|
||||
.with_context(|| format!("missing > after {tag}"))?;
|
||||
let attributes = &input[start_ix..closing_bracket_ix].trim();
|
||||
let end_ix = closing_bracket_ix
|
||||
+ input[closing_bracket_ix..]
|
||||
.find(&close_tag)
|
||||
.with_context(|| format!("no `{close_tag}` tag"))?;
|
||||
let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix];
|
||||
let body = body.strip_prefix('\n').unwrap_or(body);
|
||||
let body = body.strip_suffix('\n').unwrap_or(body);
|
||||
*input = &input[end_ix + close_tag.len()..];
|
||||
Ok(Some(ParsedTag { attributes, body }))
|
||||
.with_context(|| format!("missing > after <{}", expected_tag))?;
|
||||
let attrs_end = after_tag_name + close_bracket;
|
||||
let body_start = attrs_end + 1;
|
||||
|
||||
let attributes = input[after_tag_name..attrs_end].trim();
|
||||
*cursor = body_start;
|
||||
|
||||
return Ok(Some((body_start, attributes)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
|
||||
let mut depth = 1;
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() && depth > 0 {
|
||||
let Some(bracket_offset) = input[search_pos..].find('<') else {
|
||||
break;
|
||||
};
|
||||
let bracket_pos = search_pos + bracket_offset;
|
||||
|
||||
if input[bracket_pos..].starts_with("</")
|
||||
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
|
||||
{
|
||||
let close_start = bracket_pos + 2;
|
||||
let tag_name = input[close_start..close_start + close_end].trim();
|
||||
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
*cursor = close_start + close_end + 1;
|
||||
return Ok(bracket_pos);
|
||||
}
|
||||
}
|
||||
search_pos = close_start + close_end + 1;
|
||||
continue;
|
||||
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
|
||||
let close_bracket_pos = bracket_pos + close_bracket_offset;
|
||||
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth += 1;
|
||||
}
|
||||
}
|
||||
|
||||
search_pos = bracket_pos + 1;
|
||||
}
|
||||
|
||||
anyhow::bail!("no closing tag found")
|
||||
}
|
||||
|
||||
const REPLACEMENT_COST: u32 = 1;
|
||||
@@ -357,17 +442,128 @@ mod tests {
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_parse_tags() {
|
||||
let mut input = indoc! {r#"
|
||||
Prelude
|
||||
<tag attr="foo">
|
||||
tag value
|
||||
</tag>
|
||||
"# };
|
||||
let parsed = parse_tag(&mut input, "tag").unwrap().unwrap();
|
||||
assert_eq!(parsed.attributes, "attr=\"foo\"");
|
||||
assert_eq!(parsed.body, "tag value");
|
||||
assert_eq!(input, "\n");
|
||||
fn test_extract_xml_edits() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</old_text>
|
||||
<new_text>
|
||||
new content
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_wrong_closing_tags() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</new_text>
|
||||
<new_text>
|
||||
new content
|
||||
</old_text>
|
||||
</ edits >
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_xml_like_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<foo><bar></bar></foo>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<foo><bar><baz></baz></bar></foo>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
|
||||
assert_eq!(
|
||||
result.replacements[0].1,
|
||||
"<foo><bar><baz></baz></bar></foo>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_conflicting_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<new_text></new_text>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<old_text></old_text>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
|
||||
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_multiple_pairs() {
|
||||
let input = indoc! {r#"
|
||||
Some reasoning before edits. Lots of thinking going on here
|
||||
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</old_text>
|
||||
<new_text>
|
||||
first new
|
||||
</new_text>
|
||||
<old_text>
|
||||
second old
|
||||
</edits>
|
||||
<new_text>
|
||||
second new
|
||||
</old_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 2);
|
||||
assert_eq!(result.replacements[0].0, "first old");
|
||||
assert_eq!(result.replacements[0].1, "first new");
|
||||
assert_eq!(result.replacements[1].0, "second old");
|
||||
assert_eq!(result.replacements[1].1, "second new");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_unexpected_eof() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</
|
||||
"#};
|
||||
|
||||
extract_xml_replacements(input).expect_err("Unexpected end of file");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -183,7 +183,7 @@ macro_rules! time {
|
||||
$crate::Timer::new($logger, $name)
|
||||
};
|
||||
($name:expr) => {
|
||||
$crate::time!($crate::default_logger!() => $name)
|
||||
time!($crate::default_logger!() => $name)
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -126,6 +126,7 @@
|
||||
- [Markdown](./languages/markdown.md)
|
||||
- [Nim](./languages/nim.md)
|
||||
- [OCaml](./languages/ocaml.md)
|
||||
- [OpenTofu](./languages/opentofu.md)
|
||||
- [PHP](./languages/php.md)
|
||||
- [PowerShell](./languages/powershell.md)
|
||||
- [Prisma](./languages/prisma.md)
|
||||
|
||||
@@ -40,11 +40,14 @@ You can connect them by adding their commands directly to your `settings.json`,
|
||||
```json [settings]
|
||||
{
|
||||
"context_servers": {
|
||||
"your-mcp-server": {
|
||||
"source": "custom",
|
||||
"run-command": {
|
||||
"command": "some-command",
|
||||
"args": ["arg-1", "arg-2"],
|
||||
"env": {}
|
||||
},
|
||||
"over-http": {
|
||||
"url": "custom",
|
||||
"headers": { "Authorization": "Bearer <token>" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
20
docs/src/languages/opentofu.md
Normal file
20
docs/src/languages/opentofu.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# OpenTofu
|
||||
|
||||
OpenTofu support is available through the [OpenTofu extension](https://github.com/ashpool37/zed-extension-opentofu).
|
||||
|
||||
- Tree-sitter: [MichaHoffmann/tree-sitter-hcl](https://github.com/MichaHoffmann/tree-sitter-hcl)
|
||||
- Language Server: [opentofu/tofu-ls](https://github.com/opentofu/tofu-ls)
|
||||
|
||||
## Configuration
|
||||
|
||||
In order to automatically use the OpenTofu extension and language server when editing .tf and .tfvars files,
|
||||
either uninstall the Terraform extension or add this to your settings.json:
|
||||
|
||||
```json
|
||||
"file_types": {
|
||||
"OpenTofu": ["tf"],
|
||||
"OpenTofu Vars": ["tfvars"]
|
||||
},
|
||||
```
|
||||
|
||||
See the [full list of server settings here](https://github.com/opentofu/tofu-ls/blob/main/docs/SETTINGS.md).
|
||||
Reference in New Issue
Block a user