Compare commits


20 Commits

Author SHA1 Message Date
Mikayla Maki
5b74ddf020 WIP: Remove provider name 2025-11-18 23:56:05 -08:00
Mikayla Maki
54f20ae5d5 Remove provider name 2025-11-18 22:13:04 -08:00
Mikayla Maki
811efa45d0 Disambiguate similar completion events 2025-11-18 21:04:35 -08:00
Mikayla Maki
74501e0936 Simplify LanguageModelCompletionEvent enum, remove
`StatusUpdate::Failed` state by turning it into an error early, and then
flatten `StatusUpdate` into LanguageModelCompletionEvent
2025-11-18 20:56:11 -08:00
Michael Benfield
2ad8bd00ce Simplifying errors
Co-authored-by: Mikayla <mikayla@zed.dev>
2025-11-18 20:45:22 -08:00
Martin Bergo
7c0663b825 google_ai: Add gemini-3-pro-preview model (#43015)
Release Notes:

- Added the newly released Gemini 3 Pro Preview Model


https://docs.cloud.google.com/vertex-ai/generative-ai/docs/models/gemini/3-pro
2025-11-18 23:51:32 +00:00
Lukas Wirth
94a43dc73a extension_host: Fix IS_WASM_THREAD being set for wrong threads (#43005)
https://github.com/zed-industries/zed/pull/40883 implemented this
incorrectly. It was marking a random background thread as a wasm thread
(whatever thread picked up the wasm epoch timer background task),
instead of marking the threads that actually run the wasm extension.

This has two implications:
1. It didn't prevent extension panics from tearing down Zed as planned.
2. Worse, it actually hid legitimate panics in Sentry for one of our
background workers.

Point 2 still technically applies to all tokio threads after this change,
but in the main Zed binary we essentially only use those threads for wasm
extensions.
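The general pattern behind the fix (a sketch, not Zed's actual implementation; the names below are illustrative) is a thread-local flag that must be set from inside each thread that actually executes wasm, and that a panic hook can consult:

```rust
use std::cell::Cell;

thread_local! {
    // Illustrative stand-in for the IS_WASM_THREAD marker.
    static IS_WASM_THREAD: Cell<bool> = Cell::new(false);
}

/// Mark the *current* thread as a wasm thread. The bug was, in effect,
/// calling this from whichever background thread happened to poll the
/// wasm epoch-timer task, instead of from the threads that actually run
/// extension code.
pub fn mark_current_thread_as_wasm() {
    IS_WASM_THREAD.with(|flag| flag.set(true));
}

/// A panic hook can check this to decide whether to contain the panic
/// (extension code) or report it (a legitimate app panic).
pub fn current_thread_is_wasm() -> bool {
    IS_WASM_THREAD.with(|flag| flag.get())
}
```

Because the flag is thread-local, marking the wrong thread both fails to protect the wasm threads and silences panics from an unrelated worker, which is exactly the failure mode described above.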

Release Notes:

- Fixed extension panics crashing Zed on Linux
2025-11-18 23:49:22 +00:00
Ben Kunkle
e8e0707256 zeta2: Improve queries parsing (#43012)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Agus <agus@zed.dev>
Co-authored-by: Max <max@zed.dev>
2025-11-18 23:46:29 +00:00
Tom Zaspel
d7c340c739 docs: Add documentation for OpenTofu support (#42448)
Closes -

Release Notes:

- N/A

Signed-off-by: Tom Zaspel <40226087+tzabbi@users.noreply.github.com>
2025-11-18 18:40:09 -05:00
Julia Ryan
16b24e892e Increase error verbosity (#43013)
Closes #42288

This will actually print the parsing error that prevented the VS Code
settings file from being loaded, which should make it easier for users to
help themselves when they have an invalid config.
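A minimal sketch of the idea (with hypothetical types, not the actual PR code): rather than logging a generic "failed to load" message, render the error together with its whole `source()` chain so the underlying parse failure reaches the user:

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical wrapper error: "settings failed to load" plus the
/// underlying cause (e.g. a JSON parse error).
#[derive(Debug)]
struct SettingsLoadError {
    source: Box<dyn Error + Send + Sync>,
}

impl fmt::Display for SettingsLoadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to load VS Code settings")
    }
}

impl Error for SettingsLoadError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(self.source.as_ref())
    }
}

/// Render an error with its full cause chain, so details like
/// line/column of the parse failure are not swallowed.
fn verbose_message(err: &dyn Error) -> String {
    let mut message = err.to_string();
    let mut cause = err.source();
    while let Some(inner) = cause {
        message.push_str(": ");
        message.push_str(&inner.to_string());
        cause = inner.source();
    }
    message
}
```

With this, an invalid `settings.json` produces something like "failed to load VS Code settings: expected `,` at line 3 column 9" instead of an opaque failure.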

Release Notes:

- N/A
2025-11-18 23:25:12 +00:00
Barani S
917148c5ce gpui: Use DWM API for backdrop effects and add Mica/Mica Alt support (#41842)
This PR updates window background rendering to use the **official DWM
backdrop API** (`DwmSetWindowAttribute`) instead of the legacy
`SetWindowCompositionAttribute`.
It also adds **Mica** and **Mica Alt** options to
`WindowBackgroundAppearance` for native Windows 11 effects.

### Motivation

Enables modern, stable, and GPU-accelerated backdrops consistent with
Windows 11’s Fluent Design.
Removes reliance on undocumented APIs while maintaining backward
compatibility with older Windows versions.

### Changes

* Added `MicaBackdrop` and `MicaAltBackdrop` variants.
* Switched to DWM API for applying backdrop effects.
* Verified fallback behavior on Windows 10.
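Under the hood, `DwmSetWindowAttribute` with `DWMWA_SYSTEMBACKDROP_TYPE` takes one of the `DWM_SYSTEMBACKDROP_TYPE` values from the Windows SDK. A sketch of the mapping (the backdrop constants are from the SDK; how `Opaque`/`Transparent`/`Blurred` map onto them here is an assumption, not the PR's exact code):

```rust
/// gpui's window backdrop options, per this PR.
#[allow(dead_code)]
enum WindowBackgroundAppearance {
    Opaque,
    Transparent,
    Blurred,
    MicaBackdrop,
    MicaAltBackdrop,
}

// DWM_SYSTEMBACKDROP_TYPE values (Windows SDK), passed to
// DwmSetWindowAttribute(hwnd, DWMWA_SYSTEMBACKDROP_TYPE, ...).
const DWMSBT_NONE: u32 = 1;
const DWMSBT_MAINWINDOW: u32 = 2; // Mica
const DWMSBT_TRANSIENTWINDOW: u32 = 3; // Acrylic
const DWMSBT_TABBEDWINDOW: u32 = 4; // Mica Alt

/// Hypothetical mapping from appearance to a DWM backdrop type.
fn backdrop_type(appearance: &WindowBackgroundAppearance) -> u32 {
    match appearance {
        WindowBackgroundAppearance::Opaque
        | WindowBackgroundAppearance::Transparent => DWMSBT_NONE,
        WindowBackgroundAppearance::Blurred => DWMSBT_TRANSIENTWINDOW,
        WindowBackgroundAppearance::MicaBackdrop => DWMSBT_MAINWINDOW,
        WindowBackgroundAppearance::MicaAltBackdrop => DWMSBT_TABBEDWINDOW,
    }
}
```

On Windows 10, which predates `DWMWA_SYSTEMBACKDROP_TYPE`, the attribute call fails harmlessly and the window keeps its plain background, which is the fallback behavior verified above.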

### Release Notes:

- Added `WindowBackgroundAppearance::MicaBackdrop` and
`WindowBackgroundAppearance::MicaAltBackdrop` for Windows 11 Mica and
Mica Alt window backdrops.

### Screenshots

- `WindowBackgroundAppearance::Blurred`
<img width="553" height="354" alt="image"
src="https://github.com/user-attachments/assets/57c9c25d-9412-4141-94b5-00000cc0b1ec"
/>

- `WindowBackgroundAppearance::MicaBackdrop`
<img width="553" height="354" alt="image"
src="https://github.com/user-attachments/assets/019f541c-3335-4c9e-b026-71f5a1786534"
/>

- `WindowBackgroundAppearance::MicaAltBackdrop`
<img width="553" height="354" alt="image"
src="https://github.com/user-attachments/assets/5128d600-c94d-4c89-b81a-8b842fe1337a"
/>

---------

Co-authored-by: John Tur <john-tur@outlook.com>
2025-11-18 18:20:32 -05:00
Piotr Osiewicz
951132fc13 chore: Fix build graph - again (#42999)
11.3s -> 10.0s for silly stuff like extracting actions from crates.
The project panel still depends on git_ui, though.

Release Notes:

- N/A
2025-11-18 19:20:34 +01:00
Ben Kunkle
bf0dd4057c zeta2: Make new_text/old_text parsing more robust (#42997)
Closes #ISSUE

The model often uses the wrong closing tag, or puts spaces around the
closing tag name. This PR makes it so that opening tags are treated as
authoritative, and any closing tag named `new_text`, `old_text`, or
`edits` is accepted based on depth. This has the additional benefit that
the parsing is more robust for contents that themselves contain
`new_text`, `old_text`, or `edits`. For example, the following test passes

```rust
    #[test]
    fn test_extract_xml_edits_with_conflicting_content() {
        let input = indoc! {r#"
            <edits path="component.tsx">
            <old_text>
            <new_text></new_text>
            </old_text>
            <new_text>
            <old_text></old_text>
            </new_text>
            </edits>
        "#};

        let result = extract_xml_replacements(input).unwrap();
        assert_eq!(result.file_path, "component.tsx");
        assert_eq!(result.replacements.len(), 1);
        assert_eq!(result.replacements[0].0, "<new_text></new_text>");
        assert_eq!(result.replacements[0].1, "<old_text></old_text>");
    }
```

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-11-18 12:36:37 -05:00
Conrad Irwin
3c4ca3f372 Remove settings::Maybe (#42933)
It's unclear how this would ever be useful

cc @probably-neb

Release Notes:

- N/A
2025-11-18 10:23:16 -07:00
Artur Shirokov
03132921c7 Add HTTP transport support for MCP servers (#39021)
### What this solves

This PR adds support for HTTP and SSE (Server-Sent Events) transports to
Zed's context server implementation, enabling communication with remote
MCP servers. Currently, Zed only supports local MCP servers via stdio
transport. This limitation prevents users from:

- Connecting to cloud-hosted MCP servers
- Using MCP servers running in containers or on remote machines
- Leveraging MCP servers that are designed to work over HTTP/SSE

### Why it's important

The MCP (Model Context Protocol) specification includes HTTP/SSE as
standard transport options, and many MCP server implementations are
being built with these transports in mind. Without this support, Zed
users are limited to a subset of the MCP ecosystem. This is particularly
important for:

- Enterprise users who need to connect to centralized MCP services
- Developers working with MCP servers that require network isolation
- Users wanting to leverage cloud-based context providers (e.g.,
knowledge bases, API integrations)

### Implementation approach

The implementation follows Zed's existing architectural patterns:

- **Transports**: Added `HttpTransport` and `SseTransport` to the
`context_server` crate, built on top of the existing `http_client` crate
- **Async handling**: Uses `gpui::spawn` for network operations instead
of introducing a new Tokio runtime
- **Settings**: Extended `ContextServerSettings` enum with a `Remote`
variant to support URL-based configuration
- **UI**: Updated the agent configuration UI with an "Add Remote Server"
option and dedicated modal for remote server management

### Changes included

- [x] HTTP transport implementation with request/response handling
- [x] SSE transport for server-sent events streaming
- [x] `build_transport` function to construct appropriate transport
based on URL scheme
- [x] Settings system updates to support remote server configuration
- [x] UI updates for adding/editing remote servers
- [x] Unit tests using `FakeHttpClient` for both transports
- [x] Integration tests (WIP)
- [x] Documentation updates (WIP)
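The scheme-based dispatch in `build_transport` can be sketched as follows (simplified: the real function constructs concrete `HttpTransport`/`SseTransport` values, and the enum and function names here are illustrative):

```rust
/// Which transport a context-server config resolves to.
#[derive(Debug, PartialEq)]
enum TransportKind {
    /// Local process speaking JSON-RPC over stdin/stdout.
    Stdio,
    /// Remote server reached over HTTP (with SSE for streaming).
    Http,
}

/// Sketch of the idea behind `build_transport`: URL-based configs get
/// an HTTP/SSE transport, everything else falls back to stdio, and
/// unknown schemes are rejected early.
fn transport_for(config_url: Option<&str>) -> Result<TransportKind, String> {
    match config_url {
        None => Ok(TransportKind::Stdio),
        Some(url) if url.starts_with("http://") || url.starts_with("https://") => {
            Ok(TransportKind::Http)
        }
        Some(url) => Err(format!("unsupported URL scheme in {url}")),
    }
}
```

Deciding the transport from the settings shape keeps the `Remote` settings variant minimal: a URL is enough, with no separate transport field for users to get wrong.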

### Testing

- Unit tests for both `HttpTransport` and `SseTransport` using mocked
HTTP client
- Manual testing with example MCP servers over HTTP/SSE
- Settings validation and UI interaction testing

### Screenshots/Recordings

[TODO: Add screenshots of the new "Add Remote Server" UI and
configuration modal]

### Example configuration

Users can now configure remote MCP servers in their `settings.json`:

```json
{
  "context_servers": {
    "my-remote-server": {
      "enabled": true,
      "url": "http://localhost:3000/mcp"
    }
  }
}
```

### AI assistance disclosure

I used AI to help with:

- Understanding the MCP protocol specification and how HTTP/SSE
transports should work
- Reviewing Zed's existing patterns for async operations and suggesting
consistent approaches
- Generating boilerplate for test cases
- Debugging SSE streaming issues

All code has been manually reviewed, tested, and adapted to fit Zed's
architecture. The core logic, architectural decisions, and integration
with Zed's systems were done with human understanding of the codebase.
AI was primarily used as a reference tool and for getting unstuck on
specific technical issues.

Release Notes:
* You can now configure MCP Servers that connect over HTTP in your
settings file. These are not yet available in the extensions API.
  ```
  {
    "context_servers": {
      "my-remote-server": {
        "enabled": true,
        "url": "http://localhost:3000/mcp"
      }
    }
  }
  ```

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-11-18 16:39:08 +00:00
Richard Feldman
c0fadae881 Thought signatures (#42915)
Implement Gemini API's [thought
signatures](https://ai.google.dev/gemini-api/docs/thinking#signatures)

Release Notes:

- Added thought signatures for Gemini tool calls
2025-11-18 10:41:19 -05:00
Agus Zubiaga
1c66c3991d Enable sweep flag for staff (#42987)
Release Notes:

- N/A
2025-11-18 15:39:27 +00:00
Agus Zubiaga
7e591a7e9a Fix sweep icon spacing (#42986)
Release Notes:

- N/A
2025-11-18 15:33:03 +00:00
Danilo Leal
c44d93745a agent_ui: Improve the modal to add LLM providers (#42983)
Closes https://github.com/zed-industries/zed/issues/42807

This PR makes the modal to add LLM providers a bit better to interact
with:

1. Added a scrollbar
2. Made the inputs navigable with tab
3. Added some responsiveness to ensure it resizes on shorter windows


https://github.com/user-attachments/assets/758ea5f0-6bcc-4a2b-87ea-114982f37caf

Release Notes:

- agent: Improved the modal to add LLM providers by making it responsive
and keyboard navigable.
2025-11-18 12:28:14 -03:00
Lukas Wirth
b4e4e0d3ac remote: Fix up incorrect logs (#42979)
Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-11-18 15:14:52 +00:00
85 changed files with 2453 additions and 1589 deletions

View File

@@ -16,7 +16,9 @@ rustflags = ["-D", "warnings"]
debug = "limited"
# Use Mold on Linux, because it's faster than GNU ld and LLD.
# We dont use wild in CI as its not production ready.
#
# We no longer set this in the default `config.toml` so that developers can opt in to Wild, which
# is faster than Mold, in their own ~/.cargo/config.toml.
[target.x86_64-unknown-linux-gnu]
linker = "clang"
rustflags = ["-C", "link-arg=-fuse-ld=mold"]

View File

@@ -8,14 +8,6 @@ perf-test = ["test", "--profile", "release-fast", "--lib", "--bins", "--tests",
# Keep similar flags here to share some ccache
perf-compare = ["run", "--profile", "release-fast", "-p", "perf", "--config", "target.'cfg(true)'.rustflags=[\"--cfg\", \"perf_enabled\"]", "--", "compare"]
# [target.x86_64-unknown-linux-gnu]
# linker = "clang"
# rustflags = ["-C", "link-arg=-fuse-ld=mold"]
[target.aarch64-unknown-linux-gnu]
linker = "clang"
rustflags = ["-C", "link-arg=-fuse-ld=mold"]
[target.'cfg(target_os = "windows")']
rustflags = [
"--cfg",

Cargo.lock generated
View File

@@ -2617,23 +2617,26 @@ dependencies = [
[[package]]
name = "calloop"
version = "0.14.3"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec"
dependencies = [
"bitflags 2.9.4",
"log",
"polling",
"rustix 1.1.2",
"rustix 0.38.44",
"slab",
"tracing",
"thiserror 1.0.69",
]
[[package]]
name = "calloop-wayland-source"
version = "0.4.1"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "138efcf0940a02ebf0cc8d1eff41a1682a46b431630f4c52450d6265876021fa"
checksum = "95a66a987056935f7efce4ab5668920b5d0dac4a7c99991a67395f13702ddd20"
dependencies = [
"calloop",
"rustix 1.1.2",
"rustix 0.38.44",
"wayland-backend",
"wayland-client",
]
@@ -3208,6 +3211,7 @@ dependencies = [
"rustc-hash 2.1.1",
"schemars 1.0.4",
"serde",
"serde_json",
"strum 0.27.2",
]
@@ -3687,6 +3691,7 @@ dependencies = [
"collections",
"futures 0.3.31",
"gpui",
"http_client",
"log",
"net",
"parking_lot",
@@ -8725,7 +8730,6 @@ dependencies = [
"ui",
"ui_input",
"util",
"vim",
"workspace",
"zed_actions",
]
@@ -10027,7 +10031,6 @@ name = "miniprofiler_ui"
version = "0.1.0"
dependencies = [
"gpui",
"log",
"serde_json",
"smol",
"util",

View File

@@ -784,7 +784,6 @@ features = [
notify = { git = "https://github.com/zed-industries/notify.git", rev = "b4588b2e5aee68f4c0e100f140e808cbce7b1419" }
notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "b4588b2e5aee68f4c0e100f140e808cbce7b1419" }
windows-capture = { git = "https://github.com/zed-industries/windows-capture.git", rev = "f0d6c1b6691db75461b732f6d5ff56eed002eeb9" }
calloop = { path = "/home/davidsk/tmp/calloop" }
[profile.dev]
split-debuginfo = "unpacked"
@@ -861,7 +860,7 @@ ui_input = { codegen-units = 1 }
zed_actions = { codegen-units = 1 }
[profile.release]
debug = "full"
debug = "limited"
lto = "thin"
codegen-units = 1

File diff suppressed because one or more lines are too long

(Binary image changed: 9.3 KiB before, 14 KiB after)

View File

@@ -150,6 +150,7 @@ impl DbThread {
.unwrap_or_default(),
input: tool_use.input,
is_input_complete: true,
thought_signature: None,
},
));
}

View File

@@ -1108,6 +1108,7 @@ fn tool_use(
raw_input: serde_json::to_string_pretty(&input).unwrap(),
input: serde_json::to_value(input).unwrap(),
is_input_complete: true,
thought_signature: None,
})
}

View File

@@ -21,9 +21,9 @@ use gpui::{
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::{
@@ -274,6 +274,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
thought_signature: None,
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
@@ -461,6 +462,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -470,6 +472,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -520,6 +523,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -554,6 +558,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -592,6 +597,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -621,6 +627,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
thought_signature: None,
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
@@ -657,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
);
// Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -731,6 +736,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
is_input_complete: true,
thought_signature: None,
};
let tool_result = LanguageModelToolResult {
tool_use_id: "tool_id_1".into(),
@@ -741,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -1037,6 +1041,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -1080,6 +1085,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
raw_input: json!({"text": "mcp"}).to_string(),
input: json!({"text": "mcp"}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -1089,6 +1095,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
raw_input: json!({"text": "native"}).to_string(),
input: json!({"text": "native"}),
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -1522,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
@@ -1580,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
@@ -1625,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
@@ -1672,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
@@ -1788,6 +1795,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
raw_input: "{}".into(),
input: json!({}),
is_input_complete: true,
thought_signature: None,
};
let echo_tool_use = LanguageModelToolUse {
id: "tool_id_2".into(),
@@ -1795,6 +1803,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
thought_signature: None,
};
fake_model.send_last_completion_stream_text_chunk("Hi!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -2000,6 +2009,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
raw_input: input.to_string(),
input,
is_input_complete: false,
thought_signature: None,
},
));
@@ -2012,6 +2022,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
raw_input: input.to_string(),
input,
is_input_complete: true,
thought_signature: None,
},
));
fake_model.end_last_completion_stream();
@@ -2144,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
fake_model.send_last_completion_stream_text_chunk("Hey,");
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
@@ -2214,12 +2224,12 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
raw_input: json!({"text": "test"}).to_string(),
input: json!({"text": "test"}),
is_input_complete: true,
thought_signature: None,
};
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
tool_use_1.clone(),
));
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
@@ -2286,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
fake_model.send_last_completion_stream_error(
LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
},
);

View File

@@ -15,7 +15,7 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage, UserStore};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::stream;
@@ -30,11 +30,11 @@ use gpui::{
};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
ZED_CLOUD_PROVIDER_ID,
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -1295,9 +1295,10 @@ impl Thread {
if let Some(error) = error {
attempt += 1;
let provider = model.upstream_provider_name();
let retry = this.update(cx, |this, cx| {
let user_store = this.user_store.read(cx);
this.handle_completion_error(error, attempt, user_store.plan())
this.handle_completion_error(provider, error, attempt, user_store.plan())
})??;
let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry);
@@ -1323,6 +1324,7 @@ impl Thread {
fn handle_completion_error(
&mut self,
provider: LanguageModelProviderName,
error: LanguageModelCompletionError,
attempt: u8,
plan: Option<Plan>,
@@ -1389,7 +1391,7 @@ impl Thread {
use LanguageModelCompletionEvent::*;
match event {
StartMessage { .. } => {
LanguageModelCompletionEvent::StartMessage { .. } => {
self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default());
}
@@ -1416,7 +1418,7 @@ impl Thread {
),
)));
}
UsageUpdate(usage) => {
TokenUsage(usage) => {
telemetry::event!(
"Agent Thread Completion Usage Updated",
thread_id = self.id.to_string(),
@@ -1430,20 +1432,16 @@ impl Thread {
);
self.update_token_usage(usage, cx);
}
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
RequestUsage { amount, limit } => {
self.update_model_request_usage(amount, limit, cx);
}
StatusUpdate(
CompletionRequestStatus::Started
| CompletionRequestStatus::Queued { .. }
| CompletionRequestStatus::Failed { .. },
) => {}
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
ToolUseLimitReached => {
self.tool_use_limit_reached = true;
}
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
Started | Queued { .. } => {}
}
Ok(None)
@@ -1687,9 +1685,7 @@ impl Thread {
let event = event.log_err()?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})
@@ -1753,9 +1749,7 @@ impl Thread {
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})?;

View File

@@ -247,37 +247,58 @@ impl AgentConnection for AcpConnection {
let default_mode = self.default_mode.clone();
let cwd = cwd.to_path_buf();
let context_server_store = project.read(cx).context_server_store().read(cx);
let mcp_servers = if project.read(cx).is_local() {
context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
let command = configuration.command();
Some(acp::McpServer::Stdio {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
meta: None,
})
.collect()
} else {
vec![]
},
let mcp_servers =
if project.read(cx).is_local() {
context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
match &*configuration {
project::context_server_store::ContextServerConfiguration::Custom {
command,
..
}
| project::context_server_store::ContextServerConfiguration::Extension {
command,
..
} => Some(acp::McpServer::Stdio {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
meta: None,
})
.collect()
} else {
vec![]
},
}),
project::context_server_store::ContextServerConfiguration::Http {
url,
headers,
} => Some(acp::McpServer::Http {
name: id.0.to_string(),
url: url.to_string(),
headers: headers.iter().map(|(name, value)| acp::HttpHeader {
name: name.clone(),
value: value.clone(),
meta: None,
}).collect(),
}),
}
})
})
.collect()
} else {
// In SSH projects, the external agent is running on the remote
// machine, and currently we only run MCP servers on the local
// machine. So don't pass any MCP servers to the agent in that case.
Vec::new()
};
.collect()
} else {
// In SSH projects, the external agent is running on the remote
// machine, and currently we only run MCP servers on the local
// machine. So don't pass any MCP servers to the agent in that case.
Vec::new()
};
cx.spawn(async move |cx| {
let response = conn

View File

@@ -1,5 +1,5 @@
mod add_llm_provider_modal;
mod configure_context_server_modal;
pub mod configure_context_server_modal;
mod configure_context_server_tools_modal;
mod manage_profiles_modal;
mod tool_picker;
@@ -46,9 +46,8 @@ pub(crate) use configure_context_server_modal::ConfigureContextServerModal;
pub(crate) use configure_context_server_tools_modal::ConfigureContextServerToolsModal;
pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::{
AddContextServer,
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
use crate::agent_configuration::add_llm_provider_modal::{
AddLlmProviderModal, LlmCompatibleProvider,
};
pub struct AgentConfiguration {
@@ -553,7 +552,9 @@ impl AgentConfiguration {
move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
menu.entry("Add Custom Server", None, {
|window, cx| window.dispatch_action(AddContextServer.boxed_clone(), cx)
|window, cx| {
window.dispatch_action(crate::AddContextServer.boxed_clone(), cx)
}
})
.entry("Install from Extensions", None, {
|window, cx| {
@@ -651,7 +652,7 @@ impl AgentConfiguration {
let is_running = matches!(server_status, ContextServerStatus::Running);
let item_id = SharedString::from(context_server_id.0.clone());
// Servers without a configuration can only be provided by extensions.
let provided_by_extension = server_configuration.is_none_or(|config| {
let provided_by_extension = server_configuration.as_ref().is_none_or(|config| {
matches!(
config.as_ref(),
ContextServerConfiguration::Extension { .. }
@@ -707,7 +708,10 @@ impl AgentConfiguration {
"Server is stopped.",
),
};
let is_remote = server_configuration
.as_ref()
.map(|config| matches!(config.as_ref(), ContextServerConfiguration::Http { .. }))
.unwrap_or(false);
let context_server_configuration_menu = PopoverMenu::new("context-server-config-menu")
.trigger_with_tooltip(
IconButton::new("context-server-config-menu", IconName::Settings)
@@ -730,14 +734,25 @@ impl AgentConfiguration {
let language_registry = language_registry.clone();
let workspace = workspace.clone();
move |window, cx| {
ConfigureContextServerModal::show_modal_for_existing_server(
context_server_id.clone(),
language_registry.clone(),
workspace.clone(),
window,
cx,
)
.detach_and_log_err(cx);
if is_remote {
crate::agent_configuration::configure_context_server_modal::ConfigureContextServerModal::show_modal_for_existing_server(
context_server_id.clone(),
language_registry.clone(),
workspace.clone(),
window,
cx,
)
.detach();
} else {
ConfigureContextServerModal::show_modal_for_existing_server(
context_server_id.clone(),
language_registry.clone(),
workspace.clone(),
window,
cx,
)
.detach();
}
}
}).when(tool_count > 0, |this| this.entry("View Tools", None, {
let context_server_id = context_server_id.clone();

View File

@@ -3,16 +3,42 @@ use std::sync::Arc;
use anyhow::Result;
use collections::HashSet;
use fs::Fs;
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, Task};
use gpui::{
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, ScrollHandle, Task,
};
use language_model::LanguageModelRegistry;
use language_models::provider::open_ai_compatible::{AvailableModel, ModelCapabilities};
use settings::{OpenAiCompatibleSettingsContent, update_settings_file};
use ui::{
Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState, prelude::*,
Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState,
WithScrollbar, prelude::*,
};
use ui_input::InputField;
use workspace::{ModalView, Workspace};
fn single_line_input(
label: impl Into<SharedString>,
placeholder: impl Into<SharedString>,
text: Option<&str>,
tab_index: isize,
window: &mut Window,
cx: &mut App,
) -> Entity<InputField> {
cx.new(|cx| {
let input = InputField::new(window, cx, placeholder)
.label(label)
.tab_index(tab_index)
.tab_stop(true);
if let Some(text) = text {
input
.editor()
.update(cx, |editor, cx| editor.set_text(text, window, cx));
}
input
})
}
#[derive(Clone, Copy)]
pub enum LlmCompatibleProvider {
OpenAi,
@@ -41,12 +67,14 @@ struct AddLlmProviderInput {
impl AddLlmProviderInput {
fn new(provider: LlmCompatibleProvider, window: &mut Window, cx: &mut App) -> Self {
let provider_name = single_line_input("Provider Name", provider.name(), None, window, cx);
let api_url = single_line_input("API URL", provider.api_url(), None, window, cx);
let provider_name =
single_line_input("Provider Name", provider.name(), None, 1, window, cx);
let api_url = single_line_input("API URL", provider.api_url(), None, 2, window, cx);
let api_key = single_line_input(
"API Key",
"000000000000000000000000000000000000000000000000",
None,
3,
window,
cx,
);
@@ -55,12 +83,13 @@ impl AddLlmProviderInput {
provider_name,
api_url,
api_key,
models: vec![ModelInput::new(window, cx)],
models: vec![ModelInput::new(0, window, cx)],
}
}
fn add_model(&mut self, window: &mut Window, cx: &mut App) {
self.models.push(ModelInput::new(window, cx));
let model_index = self.models.len();
self.models.push(ModelInput::new(model_index, window, cx));
}
fn remove_model(&mut self, index: usize) {
@@ -84,11 +113,14 @@ struct ModelInput {
}
impl ModelInput {
fn new(window: &mut Window, cx: &mut App) -> Self {
fn new(model_index: usize, window: &mut Window, cx: &mut App) -> Self {
let base_tab_index = (3 + (model_index * 4)) as isize;
let model_name = single_line_input(
"Model Name",
"e.g. gpt-4o, claude-opus-4, gemini-2.5-pro",
None,
base_tab_index + 1,
window,
cx,
);
@@ -96,6 +128,7 @@ impl ModelInput {
"Max Completion Tokens",
"200000",
Some("200000"),
base_tab_index + 2,
window,
cx,
);
@@ -103,16 +136,26 @@ impl ModelInput {
"Max Output Tokens",
"Max Output Tokens",
Some("32000"),
base_tab_index + 3,
window,
cx,
);
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
let max_tokens = single_line_input(
"Max Tokens",
"Max Tokens",
Some("200000"),
base_tab_index + 4,
window,
cx,
);
let ModelCapabilities {
tools,
images,
parallel_tool_calls,
prompt_cache_key,
} = ModelCapabilities::default();
Self {
name: model_name,
max_completion_tokens,
@@ -165,24 +208,6 @@ impl ModelInput {
}
}
fn single_line_input(
label: impl Into<SharedString>,
placeholder: impl Into<SharedString>,
text: Option<&str>,
window: &mut Window,
cx: &mut App,
) -> Entity<InputField> {
cx.new(|cx| {
let input = InputField::new(window, cx, placeholder).label(label);
if let Some(text) = text {
input
.editor()
.update(cx, |editor, cx| editor.set_text(text, window, cx));
}
input
})
}
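The tab-order arithmetic in `ModelInput::new` above can be sketched as a small standalone helper (illustrative only; `model_field_tab_index` is not part of the diff): the provider name, API URL, and API key fields take tab indices 1 through 3, and each model's four inputs follow in consecutive blocks.

```rust
// Sketch of the tab-index scheme above: provider-level fields occupy
// indices 1..=3; each model contributes four fields after them.
fn model_field_tab_index(model_index: usize, field: usize) -> isize {
    // `field` is 1-based within a model: name, max completion tokens,
    // max output tokens, max tokens.
    let base_tab_index = (3 + model_index * 4) as isize;
    base_tab_index + field as isize
}

fn main() {
    // The first model's "Model Name" lands at 4, right after the API key (3).
    assert_eq!(model_field_tab_index(0, 1), 4);
    // The second model's fields start at 8, so no indices collide.
    assert_eq!(model_field_tab_index(1, 1), 8);
    println!("ok");
}
```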
fn save_provider_to_settings(
input: &AddLlmProviderInput,
cx: &mut App,
@@ -258,6 +283,7 @@ fn save_provider_to_settings(
pub struct AddLlmProviderModal {
provider: LlmCompatibleProvider,
input: AddLlmProviderInput,
scroll_handle: ScrollHandle,
focus_handle: FocusHandle,
last_error: Option<SharedString>,
}
@@ -278,6 +304,7 @@ impl AddLlmProviderModal {
provider,
last_error: None,
focus_handle: cx.focus_handle(),
scroll_handle: ScrollHandle::new(),
}
}
@@ -418,6 +445,19 @@ impl AddLlmProviderModal {
)
})
}
fn on_tab(&mut self, _: &menu::SelectNext, window: &mut Window, _: &mut Context<Self>) {
window.focus_next();
}
fn on_tab_prev(
&mut self,
_: &menu::SelectPrevious,
window: &mut Window,
_: &mut Context<Self>,
) {
window.focus_prev();
}
}
impl EventEmitter<DismissEvent> for AddLlmProviderModal {}
@@ -431,15 +471,27 @@ impl Focusable for AddLlmProviderModal {
impl ModalView for AddLlmProviderModal {}
impl Render for AddLlmProviderModal {
fn render(&mut self, _window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut ui::Window, cx: &mut ui::Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle(cx);
div()
let window_size = window.viewport_size();
let rem_size = window.rem_size();
let is_large_window = window_size.height / rem_size > rems_from_px(600.).0;
let modal_max_height = if is_large_window {
rems_from_px(450.)
} else {
rems_from_px(200.)
};
v_flex()
.id("add-llm-provider-modal")
.key_context("AddLlmProviderModal")
.w(rems(34.))
.elevation_3(cx)
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::on_tab))
.on_action(cx.listener(Self::on_tab_prev))
.capture_any_mouse_down(cx.listener(|this, _, window, cx| {
this.focus_handle(cx).focus(window);
}))
@@ -462,17 +514,25 @@ impl Render for AddLlmProviderModal {
)
})
.child(
v_flex()
.id("modal_content")
div()
.size_full()
.max_h_128()
.overflow_y_scroll()
.px(DynamicSpacing::Base12.rems(cx))
.gap(DynamicSpacing::Base04.rems(cx))
.child(self.input.provider_name.clone())
.child(self.input.api_url.clone())
.child(self.input.api_key.clone())
.child(self.render_model_section(cx)),
.vertical_scrollbar_for(self.scroll_handle.clone(), window, cx)
.child(
v_flex()
.id("modal_content")
.size_full()
.tab_group()
.max_h(modal_max_height)
.pl_3()
.pr_4()
.gap_2()
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.child(self.input.provider_name.clone())
.child(self.input.api_url.clone())
.child(self.input.api_key.clone())
.child(self.render_model_section(cx)),
),
)
.footer(
ModalFooter::new().end_slot(
@@ -642,7 +702,7 @@ mod tests {
let cx = setup_test(cx).await;
cx.update(|window, cx| {
let model_input = ModelInput::new(window, cx);
let model_input = ModelInput::new(0, window, cx);
model_input.name.update(cx, |input, cx| {
input.editor().update(cx, |editor, cx| {
editor.set_text("somemodel", window, cx);
@@ -678,7 +738,7 @@ mod tests {
let cx = setup_test(cx).await;
cx.update(|window, cx| {
let mut model_input = ModelInput::new(window, cx);
let mut model_input = ModelInput::new(0, window, cx);
model_input.name.update(cx, |input, cx| {
input.editor().update(cx, |editor, cx| {
editor.set_text("somemodel", window, cx);
@@ -703,7 +763,7 @@ mod tests {
let cx = setup_test(cx).await;
cx.update(|window, cx| {
let mut model_input = ModelInput::new(window, cx);
let mut model_input = ModelInput::new(0, window, cx);
model_input.name.update(cx, |input, cx| {
input.editor().update(cx, |editor, cx| {
editor.set_text("somemodel", window, cx);
@@ -767,7 +827,7 @@ mod tests {
models.iter().enumerate()
{
if i >= input.models.len() {
input.models.push(ModelInput::new(window, cx));
input.models.push(ModelInput::new(i, window, cx));
}
let model = &mut input.models[i];
set_text(&model.name, name, window, cx);

View File

@@ -4,6 +4,7 @@ use std::{
};
use anyhow::{Context as _, Result};
use collections::HashMap;
use context_server::{ContextServerCommand, ContextServerId};
use editor::{Editor, EditorElement, EditorStyle};
use gpui::{
@@ -20,6 +21,7 @@ use project::{
project_settings::{ContextServerSettings, ProjectSettings},
worktree_store::WorktreeStore,
};
use serde::Deserialize;
use settings::{Settings as _, update_settings_file};
use theme::ThemeSettings;
use ui::{
@@ -37,6 +39,11 @@ enum ConfigurationTarget {
id: ContextServerId,
command: ContextServerCommand,
},
ExistingHttp {
id: ContextServerId,
url: String,
headers: HashMap<String, String>,
},
Extension {
id: ContextServerId,
repository_url: Option<SharedString>,
@@ -47,9 +54,11 @@ enum ConfigurationTarget {
enum ConfigurationSource {
New {
editor: Entity<Editor>,
is_http: bool,
},
Existing {
editor: Entity<Editor>,
is_http: bool,
},
Extension {
id: ContextServerId,
@@ -97,6 +106,7 @@ impl ConfigurationSource {
match target {
ConfigurationTarget::New => ConfigurationSource::New {
editor: create_editor(context_server_input(None), jsonc_language, window, cx),
is_http: false,
},
ConfigurationTarget::Existing { id, command } => ConfigurationSource::Existing {
editor: create_editor(
@@ -105,6 +115,20 @@ impl ConfigurationSource {
window,
cx,
),
is_http: false,
},
ConfigurationTarget::ExistingHttp {
id,
url,
headers: auth,
} => ConfigurationSource::Existing {
editor: create_editor(
context_server_http_input(Some((id, url, auth))),
jsonc_language,
window,
cx,
),
is_http: true,
},
ConfigurationTarget::Extension {
id,
@@ -141,16 +165,30 @@ impl ConfigurationSource {
fn output(&self, cx: &mut App) -> Result<(ContextServerId, ContextServerSettings)> {
match self {
ConfigurationSource::New { editor } | ConfigurationSource::Existing { editor } => {
parse_input(&editor.read(cx).text(cx)).map(|(id, command)| {
(
id,
ContextServerSettings::Custom {
enabled: true,
command,
},
)
})
ConfigurationSource::New { editor, is_http }
| ConfigurationSource::Existing { editor, is_http } => {
if *is_http {
parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| {
(
id,
ContextServerSettings::Http {
enabled: true,
url,
headers: auth,
},
)
})
} else {
parse_input(&editor.read(cx).text(cx)).map(|(id, command)| {
(
id,
ContextServerSettings::Custom {
enabled: true,
command,
},
)
})
}
}
ConfigurationSource::Extension {
id,
@@ -212,6 +250,66 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
)
}
fn context_server_http_input(
existing: Option<(ContextServerId, String, HashMap<String, String>)>,
) -> String {
let (name, url, headers) = match existing {
Some((id, url, headers)) => {
let header = if headers.is_empty() {
r#"// "Authorization": "Bearer <token>""#.to_string()
} else {
let json = serde_json::to_string_pretty(&headers).unwrap();
let mut lines = json.split("\n").collect::<Vec<_>>();
if lines.len() > 1 {
lines.remove(0);
lines.pop();
}
lines
.into_iter()
.map(|line| format!(" {}", line))
.collect::<Vec<_>>()
.join("\n")
};
(id.0.to_string(), url, header)
}
None => (
"some-remote-server".to_string(),
"https://example.com/mcp".to_string(),
r#"// "Authorization": "Bearer <token>""#.to_string(),
),
};
format!(
r#"{{
/// The name of your remote MCP server
"{name}": {{
/// The URL of the remote MCP server
"url": "{url}",
"headers": {{
/// Any headers to send along
{headers}
}}
}}
}}"#
)
}
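For reference, with `existing = None` the helper above renders a JSONC template roughly like the following (whitespace approximate):

```jsonc
{
  /// The name of your remote MCP server
  "some-remote-server": {
    /// The URL of the remote MCP server
    "url": "https://example.com/mcp",
    "headers": {
      /// Any headers to send along
      // "Authorization": "Bearer <token>"
    }
  }
}
```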
fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap<String, String>)> {
#[derive(Deserialize)]
struct Temp {
url: String,
#[serde(default)]
headers: HashMap<String, String>,
}
let value: HashMap<String, Temp> = serde_json_lenient::from_str(text)?;
if value.len() != 1 {
anyhow::bail!("Expected exactly one context server configuration");
}
let (key, value) = value.into_iter().next().unwrap();
Ok((ContextServerId(key.into()), value.url, value.headers))
}
fn resolve_context_server_extension(
id: ContextServerId,
worktree_store: Entity<WorktreeStore>,
@@ -312,6 +410,15 @@ impl ConfigureContextServerModal {
id: server_id,
command,
}),
ContextServerSettings::Http {
enabled: _,
url,
headers,
} => Some(ConfigurationTarget::ExistingHttp {
id: server_id,
url,
headers,
}),
ContextServerSettings::Extension { .. } => {
match workspace
.update(cx, |workspace, cx| {
@@ -353,6 +460,7 @@ impl ConfigureContextServerModal {
state: State::Idle,
original_server_id: match &target {
ConfigurationTarget::Existing { id, .. } => Some(id.clone()),
ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()),
ConfigurationTarget::Extension { id, .. } => Some(id.clone()),
ConfigurationTarget::New => None,
},
@@ -481,7 +589,7 @@ impl ModalView for ConfigureContextServerModal {}
impl Focusable for ConfigureContextServerModal {
fn focus_handle(&self, cx: &App) -> FocusHandle {
match &self.source {
ConfigurationSource::New { editor } => editor.focus_handle(cx),
ConfigurationSource::New { editor, .. } => editor.focus_handle(cx),
ConfigurationSource::Existing { editor, .. } => editor.focus_handle(cx),
ConfigurationSource::Extension { editor, .. } => editor
.as_ref()
@@ -527,9 +635,10 @@ impl ConfigureContextServerModal {
}
fn render_modal_content(&self, cx: &App) -> AnyElement {
// All variants now use single editor approach
let editor = match &self.source {
ConfigurationSource::New { editor } => editor,
ConfigurationSource::Existing { editor } => editor,
ConfigurationSource::New { editor, .. } => editor,
ConfigurationSource::Existing { editor, .. } => editor,
ConfigurationSource::Extension { editor, .. } => {
let Some(editor) = editor else {
return div().into_any_element();
@@ -601,6 +710,36 @@ impl ConfigureContextServerModal {
move |_, _, cx| cx.open_url(&repository_url)
}),
)
} else if let ConfigurationSource::New { is_http, .. } = &self.source {
let label = if *is_http {
"Run command"
} else {
"Connect via HTTP"
};
let tooltip = if *is_http {
"Configure an MCP server that runs on stdin/stdout."
} else {
"Configure an MCP server that you connect to over HTTP."
};
Some(
Button::new("toggle-kind", label)
.tooltip(Tooltip::text(tooltip))
.on_click(cx.listener(|this, _, window, cx| match &mut this.source {
ConfigurationSource::New { editor, is_http } => {
*is_http = !*is_http;
let new_text = if *is_http {
context_server_http_input(None)
} else {
context_server_input(None)
};
editor.update(cx, |editor, cx| {
editor.set_text(new_text, window, cx);
})
}
_ => {}
})),
)
} else {
None
},

View File

@@ -7,9 +7,10 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata;
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
use clock::ReplicaId;
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use cloud_llm_client::{CompletionIntent, UsageLimit};
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
});
match event {
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
LanguageModelCompletionEvent::Started |
LanguageModelCompletionEvent::Queued {..} |
LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
@@ -2142,7 +2144,7 @@ impl TextThread {
}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
LanguageModelCompletionEvent::TokenUsage(_) => {}
}
});

View File

@@ -19,4 +19,5 @@ ordered-float.workspace = true
rustc-hash.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
strum.workspace = true

View File

@@ -40,9 +40,47 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
pub struct SearchToolInput {
/// An array of queries to run for gathering context relevant to the next prediction
#[schemars(length(max = 3))]
#[serde(deserialize_with = "deserialize_queries")]
pub queries: Box<[SearchToolQuery]>,
}
fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
#[derive(Deserialize)]
#[serde(untagged)]
enum QueryCollection {
Array(Box<[SearchToolQuery]>),
DoubleArray(Box<[Box<[SearchToolQuery]>]>),
Single(SearchToolQuery),
}
#[derive(Deserialize)]
#[serde(untagged)]
enum MaybeDoubleEncoded {
SingleEncoded(QueryCollection),
DoubleEncoded(String),
}
let result = MaybeDoubleEncoded::deserialize(deserializer)?;
let normalized = match result {
MaybeDoubleEncoded::SingleEncoded(value) => value,
MaybeDoubleEncoded::DoubleEncoded(value) => {
serde_json::from_str(&value).map_err(D::Error::custom)?
}
};
Ok(match normalized {
QueryCollection::Array(items) => items,
QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
})
}
/// Search for relevant code by path, syntax hierarchy, and content.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct SearchToolQuery {
@@ -92,3 +130,115 @@ const TOOL_USE_REMINDER: &str = indoc! {"
--
Analyze the user's intent in one to two sentences, then call the `search` tool.
"};
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn test_deserialize_queries() {
let single_query_json = indoc! {r#"{
"queries": {
"glob": "**/*.rs",
"syntax_node": ["fn test"],
"content": "assert"
}
}"#};
let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
assert_eq!(flat_input.queries.len(), 1);
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
let flat_json = indoc! {r#"{
"queries": [
{
"glob": "**/*.rs",
"syntax_node": ["fn test"],
"content": "assert"
},
{
"glob": "**/*.ts",
"syntax_node": [],
"content": null
}
]
}"#};
let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
assert_eq!(flat_input.queries.len(), 2);
assert_eq!(flat_input.queries[0].glob, "**/*.rs");
assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
assert_eq!(flat_input.queries[1].glob, "**/*.ts");
assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
assert_eq!(flat_input.queries[1].content, None);
let nested_json = indoc! {r#"{
"queries": [
[
{
"glob": "**/*.rs",
"syntax_node": ["fn test"],
"content": "assert"
}
],
[
{
"glob": "**/*.ts",
"syntax_node": [],
"content": null
}
]
]
}"#};
let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
assert_eq!(nested_input.queries.len(), 2);
assert_eq!(nested_input.queries[0].glob, "**/*.rs");
assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
assert_eq!(nested_input.queries[1].glob, "**/*.ts");
assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
assert_eq!(nested_input.queries[1].content, None);
let double_encoded_queries = serde_json::to_string(&json!({
"queries": serde_json::to_string(&json!([
{
"glob": "**/*.rs",
"syntax_node": ["fn test"],
"content": "assert"
},
{
"glob": "**/*.ts",
"syntax_node": [],
"content": null
}
])).unwrap()
}))
.unwrap();
let double_encoded_input: SearchToolInput =
serde_json::from_str(&double_encoded_queries).unwrap();
assert_eq!(double_encoded_input.queries.len(), 2);
assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
assert_eq!(
double_encoded_input.queries[0].content,
Some("assert".to_string())
);
assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
assert_eq!(double_encoded_input.queries[1].content, None);
// Known-unhandled shape (from run 073, "switching from var declarations to lexical declarations"):
// invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
}
}

View File

@@ -12,7 +12,7 @@ workspace = true
path = "src/context_server.rs"
[features]
test-support = []
test-support = ["gpui/test-support"]
[dependencies]
anyhow.workspace = true
@@ -20,6 +20,7 @@ async-trait.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
http_client = { workspace = true, features = ["test-support"] }
log.workspace = true
net.workspace = true
parking_lot.workspace = true
@@ -32,3 +33,6 @@ smol.workspace = true
tempfile.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -6,6 +6,8 @@ pub mod test;
pub mod transport;
pub mod types;
use collections::HashMap;
use http_client::HttpClient;
use std::path::Path;
use std::sync::Arc;
use std::{fmt::Display, path::PathBuf};
@@ -15,6 +17,9 @@ use client::Client;
use gpui::AsyncApp;
use parking_lot::RwLock;
pub use settings::ContextServerCommand;
use url::Url;
use crate::transport::HttpTransport;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);
@@ -52,6 +57,25 @@ impl ContextServer {
}
}
pub fn http(
id: ContextServerId,
endpoint: &Url,
headers: HashMap<String, String>,
http_client: Arc<dyn HttpClient>,
executor: gpui::BackgroundExecutor,
) -> Result<Self> {
let transport = match endpoint.scheme() {
"http" | "https" => {
log::info!("Using HTTP transport for {}", endpoint);
let transport =
HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
Arc::new(transport) as _
}
_ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
};
Ok(Self::new(id, transport))
}
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
Self {
id,

View File

@@ -1,11 +1,12 @@
pub mod http;
mod stdio_transport;
use std::pin::Pin;
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
pub use http::*;
pub use stdio_transport::*;
#[async_trait]

View File

@@ -0,0 +1,259 @@
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use collections::HashMap;
use futures::{Stream, StreamExt};
use gpui::BackgroundExecutor;
use http_client::{AsyncBody, HttpClient, Request, Response, http::Method};
use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
use crate::transport::Transport;
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
const JSON_MIME_TYPE: &str = "application/json";
/// HTTP Transport with session management and SSE support
pub struct HttpTransport {
http_client: Arc<dyn HttpClient>,
endpoint: String,
session_id: Arc<SyncMutex<Option<String>>>,
executor: BackgroundExecutor,
response_tx: channel::Sender<String>,
response_rx: channel::Receiver<String>,
error_tx: channel::Sender<String>,
error_rx: channel::Receiver<String>,
// Authentication headers to include in requests
headers: HashMap<String, String>,
}
impl HttpTransport {
pub fn new(
http_client: Arc<dyn HttpClient>,
endpoint: String,
headers: HashMap<String, String>,
executor: BackgroundExecutor,
) -> Self {
let (response_tx, response_rx) = channel::unbounded();
let (error_tx, error_rx) = channel::unbounded();
Self {
http_client,
executor,
endpoint,
session_id: Arc::new(SyncMutex::new(None)),
response_tx,
response_rx,
error_tx,
error_rx,
headers,
}
}
/// Send a message and handle the response based on content type
async fn send_message(&self, message: String) -> Result<()> {
let is_notification =
!message.contains("\"id\":") || message.contains("notifications/initialized");
let mut request_builder = Request::builder()
.method(Method::POST)
.uri(&self.endpoint)
.header("Content-Type", JSON_MIME_TYPE)
.header(
"Accept",
format!("{}, {}", JSON_MIME_TYPE, EVENT_STREAM_MIME_TYPE),
);
for (key, value) in &self.headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
// Add session ID if we have one (except for initialize)
if let Some(ref session_id) = *self.session_id.lock() {
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
}
let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
let mut response = self.http_client.send(request).await?;
// Handle different response types based on status and content-type
match response.status() {
status if status.is_success() => {
// Check content type
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok());
// Extract session ID from response headers if present
if let Some(session_id) = response
.headers()
.get(HEADER_SESSION_ID)
.and_then(|v| v.to_str().ok())
{
*self.session_id.lock() = Some(session_id.to_string());
log::debug!("Session ID set: {}", session_id);
}
match content_type {
Some(ct) if ct.starts_with(JSON_MIME_TYPE) => {
// JSON response - read and forward immediately
let mut body = String::new();
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
.await?;
// Only send non-empty responses
if !body.is_empty() {
self.response_tx
.send(body)
.await
.map_err(|_| anyhow!("Failed to send JSON response"))?;
}
}
Some(ct) if ct.starts_with(EVENT_STREAM_MIME_TYPE) => {
// SSE stream - set up streaming
self.setup_sse_stream(response).await?;
}
_ => {
// For notifications, 202 Accepted with no content type is ok
if is_notification && status.as_u16() == 202 {
log::debug!("Notification accepted");
} else {
return Err(anyhow!("Unexpected content type: {:?}", content_type));
}
}
}
}
status if status.as_u16() == 202 => {
// Accepted - notification acknowledged, no response needed
log::debug!("Notification accepted");
}
_ => {
let mut error_body = String::new();
futures::AsyncReadExt::read_to_string(response.body_mut(), &mut error_body).await?;
self.error_tx
.send(format!("HTTP {}: {}", response.status(), error_body))
.await
.map_err(|_| anyhow!("Failed to send error"))?;
}
}
Ok(())
}
/// Set up SSE streaming from the response
async fn setup_sse_stream(&self, mut response: Response<AsyncBody>) -> Result<()> {
let response_tx = self.response_tx.clone();
let error_tx = self.error_tx.clone();
// Spawn a task to handle the SSE stream
smol::spawn(async move {
let reader = futures::io::BufReader::new(response.body_mut());
let mut lines = futures::AsyncBufReadExt::lines(reader);
let mut data_buffer = Vec::new();
let mut in_message = false;
while let Some(line_result) = lines.next().await {
match line_result {
Ok(line) => {
if line.is_empty() {
// Empty line signals end of event
if !data_buffer.is_empty() {
let message = data_buffer.join("\n");
// Filter out ping messages and empty data
if !message.trim().is_empty() && message != "ping" {
if let Err(e) = response_tx.send(message).await {
log::error!("Failed to send SSE message: {}", e);
break;
}
}
data_buffer.clear();
}
in_message = false;
} else if let Some(data) = line.strip_prefix("data: ") {
// Handle data lines
let data = data.trim();
if !data.is_empty() {
// Check if this is a ping message
if data == "ping" {
log::trace!("Received SSE ping");
continue;
}
data_buffer.push(data.to_string());
in_message = true;
}
} else if line.starts_with("event:")
|| line.starts_with("id:")
|| line.starts_with("retry:")
{
// Ignore other SSE fields
continue;
} else if in_message {
// Continuation of data
data_buffer.push(line);
}
}
Err(e) => {
let _ = error_tx.send(format!("SSE stream error: {}", e)).await;
break;
}
}
}
})
.detach();
Ok(())
}
}
#[async_trait]
impl Transport for HttpTransport {
async fn send(&self, message: String) -> Result<()> {
self.send_message(message).await
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.response_rx.clone())
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(self.error_rx.clone())
}
}
impl Drop for HttpTransport {
fn drop(&mut self) {
// Try to cleanup session on drop
let http_client = self.http_client.clone();
let endpoint = self.endpoint.clone();
let session_id = self.session_id.lock().clone();
let headers = self.headers.clone();
if let Some(session_id) = session_id {
self.executor
.spawn(async move {
let mut request_builder = Request::builder()
.method(Method::DELETE)
.uri(&endpoint)
.header(HEADER_SESSION_ID, &session_id);
// Add authentication headers if present
for (key, value) in headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
let request = request_builder.body(AsyncBody::empty());
if let Ok(request) = request {
let _ = http_client.send(request).await;
}
})
.detach();
}
}
}
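The SSE framing in `setup_sse_stream` above can be condensed into a stdlib-only sketch (names here are illustrative, not from the diff, and continuation-line handling is simplified): data lines buffer until a blank line terminates the event, ping payloads are dropped, and other SSE fields are ignored.

```rust
// Stdlib-only sketch of the SSE event framing used by the HTTP transport above.
fn parse_sse_events(lines: &[&str]) -> Vec<String> {
    let mut events = Vec::new();
    let mut data_buffer: Vec<String> = Vec::new();
    for line in lines {
        if line.is_empty() {
            // A blank line closes the current event.
            if !data_buffer.is_empty() {
                let message = data_buffer.join("\n");
                if !message.trim().is_empty() && message != "ping" {
                    events.push(message);
                }
                data_buffer.clear();
            }
        } else if let Some(data) = line.strip_prefix("data: ") {
            let data = data.trim();
            if !data.is_empty() && data != "ping" {
                data_buffer.push(data.to_string());
            }
        }
        // `event:`, `id:`, and `retry:` fields fall through and are ignored.
    }
    events
}

fn main() {
    let lines = [
        "event: message",
        "data: {\"jsonrpc\":\"2.0\",\"id\":1}",
        "",
        "data: ping",
        "",
    ];
    assert_eq!(parse_sse_events(&lines), vec!["{\"jsonrpc\":\"2.0\",\"id\":1}"]);
    println!("ok");
}
```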

View File

@@ -289,15 +289,11 @@ impl minidumper::ServerHandler for CrashServer {
pub fn panic_hook(info: &PanicHookInfo) {
// Don't handle a panic on threads that are not relevant to the main execution.
if extension_host::wasm_host::IS_WASM_THREAD.with(|v| v.load(Ordering::Acquire)) {
log::error!("wasm thread panicked!");
return;
}
let message = info
.payload()
.downcast_ref::<&str>()
.map(|s| s.to_string())
.or_else(|| info.payload().downcast_ref::<String>().cloned())
.unwrap_or_else(|| "Box<Any>".to_string());
let message = info.payload_as_str().unwrap_or("Box<Any>").to_owned();
let span = info
.location()

View File

@@ -3291,8 +3291,8 @@ impl Editor {
self.refresh_document_highlights(cx);
refresh_linked_ranges(self, window, cx);
// self.refresh_selected_text_highlights(false, window, cx);
// self.refresh_matching_bracket_highlights(window, cx);
self.refresh_selected_text_highlights(false, window, cx);
self.refresh_matching_bracket_highlights(window, cx);
self.update_visible_edit_prediction(window, cx);
self.edit_prediction_requires_modifier_in_indent_conflict = true;
self.inline_blame_popover.take();
@@ -21248,9 +21248,9 @@ impl Editor {
self.active_indent_guides_state.dirty = true;
self.refresh_active_diagnostics(cx);
self.refresh_code_actions(window, cx);
// self.refresh_selected_text_highlights(true, window, cx);
self.refresh_selected_text_highlights(true, window, cx);
self.refresh_single_line_folds(window, cx);
// self.refresh_matching_bracket_highlights(window, cx);
self.refresh_matching_bracket_highlights(window, cx);
if self.has_active_edit_prediction() {
self.update_visible_edit_prediction(window, cx);
}
@@ -21345,7 +21345,6 @@ impl Editor {
}
multi_buffer::Event::Reparsed(buffer_id) => {
self.tasks_update_task = Some(self.refresh_runnables(window, cx));
// self.refresh_selected_text_highlights(true, window, cx);
jsx_tag_auto_close::refresh_enabled_in_any_buffer(self, multibuffer, cx);
cx.emit(EditorEvent::Reparsed(*buffer_id));
@@ -23908,10 +23907,6 @@ impl EditorSnapshot {
self.scroll_anchor.scroll_position(&self.display_snapshot)
}
pub fn scroll_near_end(&self) -> bool {
self.scroll_anchor.near_end(&self.display_snapshot)
}
fn gutter_dimensions(
&self,
font_id: FontId,

View File

@@ -9055,9 +9055,6 @@ impl Element for EditorElement {
)
});
if snapshot.scroll_near_end() {
dbg!("near end!");
}
let mut scroll_position = snapshot.scroll_position();
// The scroll position is a fractional point, the whole number of which represents
// the top of the window in terms of display rows.

View File

@@ -46,20 +46,12 @@ impl ScrollAnchor {
}
}
pub fn near_end(&self, snapshot: &DisplaySnapshot) -> bool {
let editor_length = snapshot.max_point().row().as_f64();
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f64();
(scroll_top - editor_length).abs() < 300.0
}
pub fn scroll_position(&self, snapshot: &DisplaySnapshot) -> gpui::Point<ScrollOffset> {
self.offset.apply_along(Axis::Vertical, |offset| {
if self.anchor == Anchor::min() {
0.
} else {
dbg!(snapshot.max_point().row().as_f64());
let scroll_top = self.anchor.to_display_point(snapshot).row().as_f64();
dbg!(scroll_top, offset);
(offset + scroll_top).max(0.)
}
})
@@ -251,11 +243,6 @@ impl ScrollManager {
}
}
};
let near_end = self.anchor.near_end(map);
// // TODO call load more here
// if near_end {
// cx.read();
// }
let scroll_top_row = DisplayRow(scroll_top as u32);
let scroll_top_buffer_point = map

View File

@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
));
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
LanguageModelCompletionEvent::TokenUsage(_)
| LanguageModelCompletionEvent::ToolUseLimitReached
| LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. },
| LanguageModelCompletionEvent::RequestUsage { .. }
| LanguageModelCompletionEvent::Queued { .. }
| LanguageModelCompletionEvent::Started,
) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, ..
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
}
// Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
Ok(LanguageModelCompletionEvent::TokenUsage(_))
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
| Ok(LanguageModelCompletionEvent::Stop(_))
| Ok(LanguageModelCompletionEvent::Queued { .. })
| Ok(LanguageModelCompletionEvent::Started)
| Ok(LanguageModelCompletionEvent::RequestUsage { .. })
| Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error,

View File

@@ -537,7 +537,6 @@ fn wasm_engine(executor: &BackgroundExecutor) -> wasmtime::Engine {
let engine_ref = engine.weak();
executor
.spawn(async move {
IS_WASM_THREAD.with(|v| v.store(true, Ordering::Release));
// Somewhat arbitrary interval, as it isn't a guaranteed interval.
// But this is a rough upper bound for how long the extension execution can block on
// `Future::poll`.
@@ -643,6 +642,12 @@ impl WasmHost {
let (tx, mut rx) = mpsc::unbounded::<ExtensionCall>();
let extension_task = async move {
// note: Setting the thread local here will slowly "poison" all tokio threads
// causing us to not record their panics any longer.
//
// This is fine though, the main zed binary only uses tokio for livekit and wasm extensions.
// Livekit seldom (if ever) panics 🤞 so the likelihood of us missing a panic in sentry is very low.
IS_WASM_THREAD.with(|v| v.store(true, Ordering::Release));
while let Some(call) = rx.next().await {
(call)(&mut extension, &mut store).await;
}
@@ -659,8 +664,8 @@ impl WasmHost {
cx.spawn(async move |cx| {
let (extension_task, manifest, work_dir, tx, zed_api_version) =
cx.background_executor().spawn(load_extension_task).await?;
// we need to run the task in an extension context as wasmtime_wasi may
// call into tokio, accessing its runtime handle
// we need to run the task in a tokio context as wasmtime_wasi may
// call into tokio, accessing its runtime handle when we trigger the `engine.increment_epoch()` above.
let task = Arc::new(gpui_tokio::Tokio::spawn(cx, extension_task)?);
Ok(WasmExtension {

View File
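The fix above moves the `IS_WASM_THREAD` marker so it is set only on the threads that actually execute wasm extension calls. A minimal sketch of the underlying thread-local-flag pattern (using `Cell` here rather than the atomic the real code uses; names are illustrative):

```rust
use std::cell::Cell;

thread_local! {
    // Sketch of the IS_WASM_THREAD pattern: a per-thread flag set only on
    // threads running wasm extension code, so a panic hook can decide
    // whether to swallow or report a panic.
    static IS_WASM_THREAD: Cell<bool> = const { Cell::new(false) };
}

fn mark_wasm_thread() {
    IS_WASM_THREAD.with(|flag| flag.set(true));
}

fn is_wasm_thread() -> bool {
    IS_WASM_THREAD.with(|flag| flag.get())
}

fn main() {
    assert!(!is_wasm_thread());
    mark_wasm_thread();
    assert!(is_wasm_thread());
    // A freshly spawned thread starts with the flag unset.
    std::thread::spawn(|| assert!(!is_wasm_thread()))
        .join()
        .unwrap();
}
```

The original bug was exactly a violation of this per-thread property: the flag was set on whichever thread happened to pick up the epoch-timer task, not on the wasm threads themselves.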

@@ -990,6 +990,9 @@ impl ExtensionImports for WasmState {
command: None,
settings: Some(settings),
})?),
project::project_settings::ContextServerSettings::Http { .. } => {
bail!("remote context server settings not supported in 0.6.0")
}
}
}
_ => {

View File

@@ -60,27 +60,27 @@ pub fn register_editor(editor: &mut Editor, buffer: Entity<MultiBuffer>, cx: &mu
buffer_added(editor, buffer, cx);
}
// cx.subscribe(&cx.entity(), |editor, _, event, cx| match event {
// EditorEvent::ExcerptsAdded { buffer, .. } => buffer_added(editor, buffer.clone(), cx),
// EditorEvent::ExcerptsExpanded { ids } => {
// let multibuffer = editor.buffer().read(cx).snapshot(cx);
// for excerpt_id in ids {
// let Some(buffer) = multibuffer.buffer_for_excerpt(*excerpt_id) else {
// continue;
// };
// let addon = editor.addon::<ConflictAddon>().unwrap();
// let Some(conflict_set) = addon.conflict_set(buffer.remote_id()).clone() else {
// return;
// };
// excerpt_for_buffer_updated(editor, conflict_set, cx);
// }
// }
// EditorEvent::ExcerptsRemoved {
// removed_buffer_ids, ..
// } => buffers_removed(editor, removed_buffer_ids, cx),
// _ => {}
// })
// .detach();
cx.subscribe(&cx.entity(), |editor, _, event, cx| match event {
EditorEvent::ExcerptsAdded { buffer, .. } => buffer_added(editor, buffer.clone(), cx),
EditorEvent::ExcerptsExpanded { ids } => {
let multibuffer = editor.buffer().read(cx).snapshot(cx);
for excerpt_id in ids {
let Some(buffer) = multibuffer.buffer_for_excerpt(*excerpt_id) else {
continue;
};
let addon = editor.addon::<ConflictAddon>().unwrap();
let Some(conflict_set) = addon.conflict_set(buffer.remote_id()).clone() else {
return;
};
excerpt_for_buffer_updated(editor, conflict_set, cx);
}
}
EditorEvent::ExcerptsRemoved {
removed_buffer_ids, ..
} => buffers_removed(editor, removed_buffer_ids, cx),
_ => {}
})
.detach();
}
fn excerpt_for_buffer_updated(

View File

@@ -30,11 +30,10 @@ use git::{
TrashUntrackedFiles, UnstageAll,
};
use gpui::{
Action, AppContext, AsyncApp, AsyncWindowContext, ClickEvent, Corner, DismissEvent, Entity,
EventEmitter, FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior,
ListSizingBehavior, MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy,
Subscription, Task, UniformListScrollHandle, WeakEntity, actions, anchored, deferred,
uniform_list,
Action, AsyncApp, AsyncWindowContext, ClickEvent, Corner, DismissEvent, Entity, EventEmitter,
FocusHandle, Focusable, KeyContext, ListHorizontalSizingBehavior, ListSizingBehavior,
MouseButton, MouseDownEvent, Point, PromptLevel, ScrollStrategy, Subscription, Task,
UniformListScrollHandle, WeakEntity, actions, anchored, deferred, uniform_list,
};
use itertools::Itertools;
use language::{Buffer, File};
@@ -312,9 +311,6 @@ pub struct GitPanel {
bulk_staging: Option<BulkStaging>,
stash_entries: GitStash,
_settings_subscription: Subscription,
/// On clicking an entry in the git panel, this will
/// trigger loading it
open_diff_task: Option<Task<()>>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -475,7 +471,6 @@ impl GitPanel {
bulk_staging: None,
stash_entries: Default::default(),
_settings_subscription,
open_diff_task: None,
};
this.schedule_update(window, cx);
@@ -755,23 +750,11 @@ impl GitPanel {
fn open_diff(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
maybe!({
let entry = self
.entries
.get(self.selected_entry?)?
.status_entry()?
.clone();
let entry = self.entries.get(self.selected_entry?)?.status_entry()?;
let workspace = self.workspace.upgrade()?;
let git_repo = self.active_repository.as_ref()?.clone();
let focus_handle = self.focus_handle.clone();
let git_repo = self.active_repository.as_ref()?;
// let panel = panel.upgrade().unwrap(); // TODO FIXME
// cx.read_entity(&panel, |panel, cx| {
// panel
// })
// .unwrap(); // TODO FIXME
let project_diff = if let Some(project_diff) =
workspace.read(cx).active_item_as::<ProjectDiff>(cx)
if let Some(project_diff) = workspace.read(cx).active_item_as::<ProjectDiff>(cx)
&& let Some(project_path) = project_diff.read(cx).active_path(cx)
&& Some(&entry.repo_path)
== git_repo
@@ -781,21 +764,16 @@ impl GitPanel {
{
project_diff.focus_handle(cx).focus(window);
project_diff.update(cx, |project_diff, cx| project_diff.autoscroll(cx));
project_diff
} else {
workspace.update(cx, |workspace, cx| {
ProjectDiff::deploy_at(workspace, Some(entry.clone()), window, cx)
})
return None;
};
focus_handle.focus(window); // TODO: should we focus before the file is loaded or wait for that?
let project_diff = project_diff.downgrade();
// TODO use the fancy new thing
self.open_diff_task = Some(cx.spawn_in(window, async move |_, cx| {
ProjectDiff::refresh_one(project_diff, entry.repo_path, entry.status, cx)
.await
.unwrap(); // TODO FIXME
}));
self.workspace
.update(cx, |workspace, cx| {
ProjectDiff::deploy_at(workspace, Some(entry.clone()), window, cx);
})
.ok();
self.focus_handle.focus(window);
Some(())
});
}
@@ -3882,7 +3860,6 @@ impl GitPanel {
})
}
// context menu
fn deploy_entry_context_menu(
&mut self,
position: Point<Pixels>,
@@ -4108,7 +4085,6 @@ impl GitPanel {
this.selected_entry = Some(ix);
cx.notify();
if event.modifiers().secondary() {
// the click handler
this.open_file(&Default::default(), window, cx)
} else {
this.open_diff(&Default::default(), window, cx);

View File

@@ -7,21 +7,19 @@ use crate::{
use anyhow::{Context as _, Result, anyhow};
use buffer_diff::{BufferDiff, DiffHunkSecondaryStatus};
use collections::{HashMap, HashSet};
use db::smol::stream::StreamExt;
use editor::{
Addon, Editor, EditorEvent, SelectionEffects,
actions::{GoToHunk, GoToPreviousHunk},
multibuffer_context_lines,
scroll::Autoscroll,
};
use futures::stream::FuturesUnordered;
use git::{
Commit, StageAll, StageAndNext, ToggleStaged, UnstageAll, UnstageAndNext,
repository::{Branch, RepoPath, Upstream, UpstreamTracking, UpstreamTrackingStatus},
status::FileStatus,
};
use gpui::{
Action, AnyElement, AnyView, App, AppContext, AsyncWindowContext, Entity, EventEmitter,
Action, AnyElement, AnyView, App, AppContext as _, AsyncWindowContext, Entity, EventEmitter,
FocusHandle, Focusable, Render, Subscription, Task, WeakEntity, actions,
};
use language::{Anchor, Buffer, Capability, OffsetRangeExt};
@@ -29,21 +27,17 @@ use multi_buffer::{MultiBuffer, PathKey};
use project::{
Project, ProjectPath,
git_store::{
self, Repository, StatusEntry,
Repository,
branch_diff::{self, BranchDiffEvent, DiffBase},
},
};
use settings::{Settings, SettingsStore};
use std::{
any::{Any, TypeId},
collections::VecDeque,
sync::Arc,
};
use std::{ops::Range, time::Instant};
use std::any::{Any, TypeId};
use std::ops::Range;
use std::sync::Arc;
use theme::ActiveTheme;
use ui::{KeyBinding, Tooltip, prelude::*, vertical_divider};
use util::{ResultExt, rel_path::RelPath};
use util::{ResultExt as _, rel_path::RelPath};
use workspace::{
CloseActiveItem, ItemNavHistory, SerializableItem, ToolbarItemEvent, ToolbarItemLocation,
ToolbarItemView, Workspace,
@@ -52,8 +46,6 @@ use workspace::{
searchable::SearchableItemHandle,
};
mod diff_loader;
actions!(
git,
[
@@ -100,7 +92,7 @@ impl ProjectDiff {
window: &mut Window,
cx: &mut Context<Workspace>,
) {
Self::deploy_at(workspace, None, window, cx);
Self::deploy_at(workspace, None, window, cx)
}
fn deploy_branch_diff(
@@ -142,7 +134,7 @@ impl ProjectDiff {
entry: Option<GitStatusEntry>,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Entity<ProjectDiff> {
) {
telemetry::event!(
"Git Diff Opened",
source = if entry.is_some() {
@@ -174,8 +166,7 @@ impl ProjectDiff {
project_diff.update(cx, |project_diff, cx| {
project_diff.move_to_entry(entry, window, cx);
})
};
project_diff
}
}
pub fn autoscroll(&self, cx: &mut Context<Self>) {
@@ -276,23 +267,15 @@ impl ProjectDiff {
cx.subscribe_in(&editor, window, Self::handle_editor_event)
.detach();
let loader = diff_loader::start_loader(cx.entity(), window, cx);
let branch_diff_subscription = cx.subscribe_in(
&branch_diff,
window,
move |this, _git_store, event, window, cx| match event {
BranchDiffEvent::FileListChanged => {
// TODO this does not account for size of paths
// maybe a quick fs metadata could get us info on that?
// would make number of paths async but thats fine here
// let entries = this.first_n_entries(cx, 100);
loader.update_file_list();
// let
// this._task = window.spawn(cx, {
// let this = cx.weak_entity();
// async |cx| Self::refresh(this, entries, cx).await
// })
this._task = window.spawn(cx, {
let this = cx.weak_entity();
async |cx| Self::refresh(this, cx).await
})
}
},
);
@@ -307,32 +290,22 @@ impl ProjectDiff {
if is_sort_by_path != was_sort_by_path
|| is_collapse_untracked_diff != was_collapse_untracked_diff
{
// no idea why we need to do anything here
// probably should sort the multibuffer instead of reparsing
// everything though!!!
todo!("resort multibuffer entries");
todo!("assert the entries in the list did not change")
// this._task = {
// window.spawn(cx, {
// let this = cx.weak_entity();
// async |cx| Self::refresh(this, cx).await
// })
// }
this._task = {
window.spawn(cx, {
let this = cx.weak_entity();
async |cx| Self::refresh(this, cx).await
})
}
}
was_sort_by_path = is_sort_by_path;
was_collapse_untracked_diff = is_collapse_untracked_diff;
})
.detach();
// let task = window.spawn(cx, {
// let this = cx.weak_entity();
// async |cx| {
// let entries = this
// .read_with(cx, |project_diff, cx| project_diff.first_n_entries(cx, 100))
// .unwrap();
// Self::refresh(this, entries, cx).await
// }
// });
let task = window.spawn(cx, {
let this = cx.weak_entity();
async |cx| Self::refresh(this, cx).await
});
Self {
project,
@@ -498,11 +471,10 @@ impl ProjectDiff {
cx: &mut Context<Self>,
) {
let subscription = cx.subscribe_in(&diff, window, move |this, _, _, window, cx| {
// TODO fix this
// this._task = window.spawn(cx, {
// let this = cx.weak_entity();
// async |cx| Self::refresh(this, cx).await
// })
this._task = window.spawn(cx, {
let this = cx.weak_entity();
async |cx| Self::refresh(this, cx).await
})
});
self.buffer_diff_subscriptions
.insert(path_key.path.clone(), (diff.clone(), subscription));
@@ -578,221 +550,51 @@ impl ProjectDiff {
}
}
pub fn all_entries(&self, cx: &App) -> Vec<StatusEntry> {
let Some(ref repo) = self.branch_diff.read(cx).repo else {
return Vec::new();
};
repo.read(cx).cached_status().collect()
}
pub fn entries(&self, cx: &App) -> Option<impl Iterator<Item = StatusEntry>> {
Some(
self.branch_diff
.read(cx)
.repo
.as_ref()?
.read(cx)
.cached_status(),
)
}
pub fn first_n_entries(&self, cx: &App, n: usize) -> VecDeque<StatusEntry> {
let Some(ref repo) = self.branch_diff.read(cx).repo else {
return VecDeque::new();
};
repo.read(cx).cached_status().take(n).collect()
}
pub async fn refresh_one(
this: WeakEntity<Self>,
repo_path: RepoPath,
status: FileStatus,
cx: &mut AsyncWindowContext,
) -> Result<()> {
use git_store::branch_diff::BranchDiff;
let Some(this) = this.upgrade() else {
return Ok(());
};
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
return Ok(());
};
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
let mut previous_paths =
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
branch_diff
.tree_diff
.as_ref()
.and_then(|t| t.entries.get(&repo_path))
.cloned()
})?;
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
bd.merge_statuses(Some(status), tree_diff_status.as_ref())
})?
else {
return Ok(());
};
if !status.has_changes() {
return Ok(());
}
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
repo.repo_path_to_project_path(&repo_path, cx)
})?
else {
return Ok(());
};
let sort_prefix =
cx.read_entity(&repo, |repo, cx| sort_prefix(repo, &repo_path, status, cx))?;
let path_key = PathKey::with_sort_prefix(sort_prefix, repo_path.into_arc());
previous_paths.remove(&path_key);
let repo = repo.clone();
let Some((buffer, diff)) = BranchDiff::load_buffer(
tree_diff_status,
project_path,
repo,
project.downgrade(),
&mut cx.to_app(),
)
.await
.log_err() else {
return Ok(());
};
cx.update(|window, cx| {
this.update(cx, |this, cx| {
this.register_buffer(path_key, status, buffer, diff, window, cx)
pub async fn refresh(this: WeakEntity<Self>, cx: &mut AsyncWindowContext) -> Result<()> {
let mut path_keys = Vec::new();
let buffers_to_load = this.update(cx, |this, cx| {
let (repo, buffers_to_load) = this.branch_diff.update(cx, |branch_diff, cx| {
let load_buffers = branch_diff.load_buffers(cx);
(branch_diff.repo().cloned(), load_buffers)
});
})?;
let mut previous_paths = this.multibuffer.read(cx).paths().collect::<HashSet<_>>();
// TODO LL clear multibuff on open?
// // remove anything not part of the diff in the multibuffer
// this.update(cx, |this, cx| {
// multibuffer.update(cx, |multibuffer, cx| {
// for path in previous_paths {
// this.buffer_diff_subscriptions.remove(&path.path);
// multibuffer.remove_excerpts_for_path(path, cx);
// }
// });
// })?;
if let Some(repo) = repo {
let repo = repo.read(cx);
Ok(())
}
pub async fn refresh(
this: WeakEntity<Self>,
cached_status: Vec<StatusEntry>,
cx: &mut AsyncWindowContext,
) -> Result<()> {
dbg!("refreshing all");
use git_store::branch_diff::BranchDiff;
let Some(this) = this.upgrade() else {
return Ok(());
};
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
return Ok(());
};
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
let mut previous_paths =
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
// Idea: on click in git panel prioritize task for that file in some way ...
// could have a hashmap of futures here
// - needs to prioritize *some* background tasks over others
// -
let mut tasks = FuturesUnordered::new();
let mut seen = HashSet::default();
for entry in cached_status {
seen.insert(entry.repo_path.clone());
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
branch_diff
.tree_diff
.as_ref()
.and_then(|t| t.entries.get(&entry.repo_path))
.cloned()
})?;
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
bd.merge_statuses(Some(entry.status), tree_diff_status.as_ref())
})?
else {
continue;
};
if !status.has_changes() {
continue;
path_keys = Vec::with_capacity(buffers_to_load.len());
for entry in buffers_to_load.iter() {
let sort_prefix = sort_prefix(&repo, &entry.repo_path, entry.file_status, cx);
let path_key =
PathKey::with_sort_prefix(sort_prefix, entry.repo_path.as_ref().clone());
previous_paths.remove(&path_key);
path_keys.push(path_key)
}
}
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
repo.repo_path_to_project_path(&entry.repo_path, cx)
})?
else {
continue;
};
let sort_prefix = cx.read_entity(&repo, |repo, cx| {
sort_prefix(repo, &entry.repo_path, entry.status, cx)
})?;
let path_key = PathKey::with_sort_prefix(sort_prefix, entry.repo_path.into_arc());
previous_paths.remove(&path_key);
let repo = repo.clone();
let project = project.downgrade();
let task = cx.spawn(async move |cx| {
let res = BranchDiff::load_buffer(
tree_diff_status,
project_path,
repo,
project,
&mut cx.to_app(),
)
.await;
(res, path_key, entry.status)
});
tasks.push(task)
}
// remove anything not part of the diff in the multibuffer
this.update(cx, |this, cx| {
multibuffer.update(cx, |multibuffer, cx| {
this.multibuffer.update(cx, |multibuffer, cx| {
for path in previous_paths {
this.buffer_diff_subscriptions.remove(&path.path);
multibuffer.remove_excerpts_for_path(path, cx);
}
});
buffers_to_load
})?;
// add the new buffers as they are parsed
let mut last_notify = Instant::now();
while let Some((res, path_key, file_status)) = tasks.next().await {
if let Some((buffer, diff)) = res.log_err() {
for (entry, path_key) in buffers_to_load.into_iter().zip(path_keys.into_iter()) {
if let Some((buffer, diff)) = entry.load.await.log_err() {
cx.update(|window, cx| {
this.update(cx, |this, cx| {
this.register_buffer(path_key, file_status, buffer, diff, window, cx)
});
this.register_buffer(path_key, entry.file_status, buffer, diff, window, cx)
})
.ok();
})?;
}
if last_notify.elapsed().as_millis() > 100 {
cx.update_entity(&this, |_, cx| cx.notify())?;
last_notify = Instant::now();
}
}
this.update(cx, |this, cx| {
this.pending_scroll.take();
cx.notify();
})?;
Ok(())
}

View File
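The refresh loop above registers buffers eagerly but only calls `cx.notify()` when more than 100ms have elapsed since the last notification. That throttling pattern can be sketched on its own (hypothetical helper type, not part of the diff):

```rust
use std::time::{Duration, Instant};

// Sketch of the throttled-notify pattern: do per-item work eagerly, but only
// emit an expensive notification when at least `min_interval` has passed
// since the last one.
struct Throttle {
    last: Instant,
    min_interval: Duration,
}

impl Throttle {
    fn new(min_interval: Duration) -> Self {
        Self { last: Instant::now(), min_interval }
    }

    /// Returns true (and resets the timer) if enough time has elapsed.
    fn ready(&mut self) -> bool {
        if self.last.elapsed() >= self.min_interval {
            self.last = Instant::now();
            true
        } else {
            false
        }
    }
}

fn main() {
    let mut throttle = Throttle::new(Duration::from_millis(100));
    // Immediately after construction, the interval has not elapsed yet.
    assert!(!throttle.ready());
    std::thread::sleep(Duration::from_millis(120));
    assert!(throttle.ready());
}
```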

@@ -1,249 +0,0 @@
//! Task which updates the project diff multibuffer without putting too much
//! pressure on the frontend executor. It prioritizes loading the area around the user
use collections::HashSet;
use db::smol::stream::StreamExt;
use futures::channel::mpsc;
use gpui::{AppContext, AsyncWindowContext, Entity, Task, WeakEntity};
use project::git_store::StatusEntry;
use ui::{App, Window};
use util::ResultExt;
use crate::{git_panel::GitStatusEntry, project_diff::ProjectDiff};
enum Update {
Position(usize),
NewFile(StatusEntry),
ListChanged,
// should not need to handle re-ordering (sorting) here.
// something to handle scroll? or should that live in the project diff?
}
struct LoaderHandle {
task: Task<Option<()>>,
sender: mpsc::UnboundedSender<Update>,
}
impl LoaderHandle {
pub fn update_file_list(&self) {
let _ = self
.sender
.unbounded_send(Update::ListChanged)
.log_err();
}
pub fn update_pos(&self, pos: usize) {
let _ = self
.sender
.unbounded_send(Update::Position(pos))
.log_err();
}
}
pub fn start_loader(project_diff: Entity<ProjectDiff>, window: &Window, cx: &App) -> LoaderHandle {
let (tx, rx) = mpsc::unbounded();
let task = window.spawn(cx, async move |cx| {
load(rx, project_diff.downgrade(), cx).await
});
LoaderHandle { task, sender: tx }
}
enum DiffEntry {
Loading(GitStatusEntry),
Loaded(GitStatusEntry),
Queued(GitStatusEntry),
}
impl DiffEntry {
fn queued(&self) -> bool {
matches!(self, DiffEntry::Queued(_))
}
}
async fn load(
rx: mpsc::UnboundedReceiver<Update>,
project_diff: WeakEntity<ProjectDiff>,
cx: &mut AsyncWindowContext,
) -> Option<()> {
// let initial_entries = cx.read_entity(&cx.entity(), |project_diff, cx| project_diff.first_n_entries(cx, 100));
// let loading = to_load.drain(..100).map(|| refresh_one)
let mut existing = Vec::new();
loop {
let update = rx.next().await?;
match update {
Update::Position(pos) => {
if existing.get(pos).is_some_and(|diff| diff.queued()) {
todo!("append to FuturesUnordered, also load in the bit
around (maybe with a short sleep ahead so we get some sense
of 'priority')")
}
// drop whatever is loading so we get to the new bit earlier
}
Update::NewFile(status_entry) => todo!(),
Update::ListChanged => {
let (added, removed) = project_diff
.upgrade()?
.read_with(cx, |diff, cx| diff_current_list(&existing, diff, cx))
.ok()?;
}
}
// wait for Update OR Load done
// -> Immediately spawn update
// OR
// -> spawn next group
}
}
// could be new list
fn diff_current_list(
existing_entries: &[GitStatusEntry],
project_diff: &ProjectDiff,
cx: &App,
) -> (Vec<(usize, GitStatusEntry)>, Vec<usize>) {
let Some(new_entries) = project_diff.entries(cx) else {
return (Vec::new(), Vec::new());
};
let existing_entries = existing_entries.iter().enumerate();
for entry in new_entries {
let Some((idx, existing)) = existing_entries.next() else {
todo!();
};
if existing == entry {
}
}
// let initial_entries = cx.read_entity(&cx.entity(), |project_diff, cx| project_diff.first_n_entries(cx, 100));
// let loading = to_load.drain(..100).map(|| refresh_one)
}
// // remove anything not part of the diff in the multibuffer
// fn remove_anything_not_being_loaded() {
// this.update(cx, |this, cx| {
// multibuffer.update(cx, |multibuffer, cx| {
// for path in previous_paths {
// this.buffer_diff_subscriptions.remove(&path.path);
// multibuffer.remove_excerpts_for_path(path, cx);
// }
// });
// })?;
// }
pub async fn refresh_group(
this: WeakEntity<ProjectDiff>,
cached_status: Vec<StatusEntry>,
cx: &mut AsyncWindowContext,
) -> anyhow::Result<()> {
dbg!("refreshing all");
use project::git_store::branch_diff::BranchDiff;
let Some(this) = this.upgrade() else {
return Ok(());
};
let multibuffer = cx.read_entity(&this, |this, _| this.multibuffer.clone())?;
let branch_diff = cx.read_entity(&this, |pd, _| pd.branch_diff.clone())?;
let Some(repo) = cx.read_entity(&branch_diff, |bd, _| bd.repo.clone())? else {
return Ok(());
};
let project = cx.read_entity(&branch_diff, |bd, _| bd.project.clone())?;
let mut previous_paths =
cx.read_entity(&multibuffer, |mb, _| mb.paths().collect::<HashSet<_>>())?;
// Idea: on click in git panel prioritize task for that file in some way ...
// could have a hashmap of futures here
// - needs to prioritize *some* background tasks over others
// -
let mut tasks = FuturesUnordered::new();
let mut seen = HashSet::default();
for entry in cached_status {
seen.insert(entry.repo_path.clone());
let tree_diff_status = cx.read_entity(&branch_diff, |branch_diff, _| {
branch_diff
.tree_diff
.as_ref()
.and_then(|t| t.entries.get(&entry.repo_path))
.cloned()
})?;
let Some(status) = cx.read_entity(&branch_diff, |bd, _| {
bd.merge_statuses(Some(entry.status), tree_diff_status.as_ref())
})?
else {
continue;
};
if !status.has_changes() {
continue;
}
let Some(project_path) = cx.read_entity(&repo, |repo, cx| {
repo.repo_path_to_project_path(&entry.repo_path, cx)
})?
else {
continue;
};
let sort_prefix = cx.read_entity(&repo, |repo, cx| {
sort_prefix(repo, &entry.repo_path, entry.status, cx)
})?;
let path_key = PathKey::with_sort_prefix(sort_prefix, entry.repo_path.into_arc());
previous_paths.remove(&path_key);
let repo = repo.clone();
let project = project.downgrade();
let task = cx.spawn(async move |cx| {
let res = BranchDiff::load_buffer(
tree_diff_status,
project_path,
repo,
project,
&mut cx.to_app(),
)
.await;
(res, path_key, entry.status)
});
tasks.push(task)
}
// remove anything not part of the diff in the multibuffer
this.update(cx, |this, cx| {
multibuffer.update(cx, |multibuffer, cx| {
for path in previous_paths {
this.buffer_diff_subscriptions.remove(&path.path);
multibuffer.remove_excerpts_for_path(path, cx);
}
});
})?;
// add the new buffers as they are parsed
let mut last_notify = Instant::now();
while let Some((res, path_key, file_status)) = tasks.next().await {
if let Some((buffer, diff)) = res.log_err() {
cx.update(|window, cx| {
this.update(cx, |this, cx| {
this.register_buffer(path_key, file_status, buffer, diff, window, cx)
});
})?;
}
if last_notify.elapsed().as_millis() > 100 {
cx.update_entity(&this, |_, cx| cx.notify())?;
last_notify = Instant::now();
}
}
Ok(())
}
pub(crate) fn sort_or_collapse_changed() {
todo!()
}

View File

@@ -229,6 +229,10 @@ pub struct GenerativeContentBlob {
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
pub function_call: FunctionCall,
/// Thought signature returned by the model for function calls.
/// Only present on the first function call in parallel call scenarios.
#[serde(skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -517,6 +521,8 @@ pub enum Model {
alias = "gemini-2.5-pro-preview-06-05"
)]
Gemini25Pro,
#[serde(rename = "gemini-3-pro-preview")]
Gemini3ProPreview,
#[serde(rename = "custom")]
Custom {
name: String,
@@ -543,6 +549,7 @@ impl Model {
Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Gemini3ProPreview => "gemini-3-pro-preview",
Self::Custom { name, .. } => name,
}
}
@@ -556,6 +563,7 @@ impl Model {
Self::Gemini25FlashLitePreview => "gemini-2.5-flash-lite-preview-06-17",
Self::Gemini25Flash => "gemini-2.5-flash",
Self::Gemini25Pro => "gemini-2.5-pro",
Self::Gemini3ProPreview => "gemini-3-pro-preview",
Self::Custom { name, .. } => name,
}
}
@@ -570,6 +578,7 @@ impl Model {
Self::Gemini25FlashLitePreview => "Gemini 2.5 Flash-Lite Preview",
Self::Gemini25Flash => "Gemini 2.5 Flash",
Self::Gemini25Pro => "Gemini 2.5 Pro",
Self::Gemini3ProPreview => "Gemini 3 Pro Preview",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
@@ -586,6 +595,7 @@ impl Model {
Self::Gemini25FlashLitePreview => 1_000_000,
Self::Gemini25Flash => 1_048_576,
Self::Gemini25Pro => 1_048_576,
Self::Gemini3ProPreview => 1_048_576,
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
@@ -600,6 +610,7 @@ impl Model {
Model::Gemini25FlashLitePreview => Some(64_000),
Model::Gemini25Flash => Some(65_536),
Model::Gemini25Pro => Some(65_536),
Model::Gemini3ProPreview => Some(65_536),
Model::Custom { .. } => None,
}
}
@@ -619,7 +630,10 @@ impl Model {
| Self::Gemini15Flash
| Self::Gemini20FlashLite
| Self::Gemini20Flash => GoogleModelMode::Default,
Self::Gemini25FlashLitePreview | Self::Gemini25Flash | Self::Gemini25Pro => {
Self::Gemini25FlashLitePreview
| Self::Gemini25Flash
| Self::Gemini25Pro
| Self::Gemini3ProPreview => {
GoogleModelMode::Thinking {
// By default these models are set to "auto", so we preserve that behavior
// but indicate they are capable of thinking mode
@@ -636,3 +650,109 @@ impl std::fmt::Display for Model {
write!(f, "{}", self.id())
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_function_call_part_with_signature_serializes_correctly() {
let part = FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some("test_signature".to_string()),
};
let serialized = serde_json::to_value(&part).unwrap();
assert_eq!(serialized["functionCall"]["name"], "test_function");
assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
assert_eq!(serialized["thoughtSignature"], "test_signature");
}
#[test]
fn test_function_call_part_without_signature_omits_field() {
let part = FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: None,
};
let serialized = serde_json::to_value(&part).unwrap();
assert_eq!(serialized["functionCall"]["name"], "test_function");
assert_eq!(serialized["functionCall"]["args"]["arg"], "value");
// thoughtSignature field should not be present when None
assert!(serialized.get("thoughtSignature").is_none());
}
#[test]
fn test_function_call_part_deserializes_with_signature() {
let json = json!({
"functionCall": {
"name": "test_function",
"args": {"arg": "value"}
},
"thoughtSignature": "test_signature"
});
let part: FunctionCallPart = serde_json::from_value(json).unwrap();
assert_eq!(part.function_call.name, "test_function");
assert_eq!(part.thought_signature, Some("test_signature".to_string()));
}
#[test]
fn test_function_call_part_deserializes_without_signature() {
let json = json!({
"functionCall": {
"name": "test_function",
"args": {"arg": "value"}
}
});
let part: FunctionCallPart = serde_json::from_value(json).unwrap();
assert_eq!(part.function_call.name, "test_function");
assert_eq!(part.thought_signature, None);
}
#[test]
fn test_function_call_part_round_trip() {
let original = FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value", "nested": {"key": "val"}}),
},
thought_signature: Some("round_trip_signature".to_string()),
};
let serialized = serde_json::to_value(&original).unwrap();
let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap();
assert_eq!(deserialized.function_call.name, original.function_call.name);
assert_eq!(deserialized.function_call.args, original.function_call.args);
assert_eq!(deserialized.thought_signature, original.thought_signature);
}
#[test]
fn test_function_call_part_with_empty_signature_serializes() {
let part = FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some("".to_string()),
};
let serialized = serde_json::to_value(&part).unwrap();
// Empty string should still be serialized (normalization happens at a higher level)
assert_eq!(serialized["thoughtSignature"], "");
}
}

View File

@@ -187,13 +187,12 @@ font-kit = { git = "https://github.com/zed-industries/font-kit", rev = "11052312
"source-fontconfig-dlopen",
], optional = true }
calloop = { version = "0.14.3" }
calloop = { version = "0.13.0" }
filedescriptor = { version = "0.8.2", optional = true }
open = { version = "5.2.0", optional = true }
# Wayland
calloop-wayland-source = { version = "0.4.1", optional = true }
calloop-wayland-source = { version = "0.3.0", optional = true }
wayland-backend = { version = "0.3.3", features = [
"client_system",
"dlopen",
@@ -266,6 +265,7 @@ naga.workspace = true
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.build-dependencies]
naga.workspace = true
[[example]]
name = "hello_world"
path = "examples/hello_world.rs"

View File

@@ -310,11 +310,6 @@ impl AsyncWindowContext {
.update(self, |_, window, cx| read(cx.global(), window, cx))
}
/// Returns an `AsyncApp` by cloning the one used by Self
pub fn to_app(&self) -> AsyncApp {
self.app.clone()
}
/// A convenience method for [`App::update_global`](BorrowAppContext::update_global).
/// for updating the global state of the specified type.
pub fn update_global<G, R>(

View File

@@ -233,9 +233,6 @@ impl<'a, T: 'static> Context<'a, T> {
/// Spawn the future returned by the given function.
/// The function is provided a weak handle to the entity owned by this context and a context that can be held across await points.
/// The returned task must be held or detached.
///
/// # Example
/// `cx.spawn(async move |some_weak_entity, cx| ...)`
#[track_caller]
pub fn spawn<AsyncFn, R>(&self, f: AsyncFn) -> Task<R>
where

View File

@@ -1389,6 +1389,10 @@ pub enum WindowBackgroundAppearance {
///
/// Not always supported.
Blurred,
/// The Mica backdrop material, supported on Windows 11.
MicaBackdrop,
/// The Mica Alt backdrop material, supported on Windows 11.
MicaAltBackdrop,
}
/// The options that can be configured for a file dialog prompt

View File

@@ -487,15 +487,12 @@ impl WaylandClient {
let (common, main_receiver) = LinuxCommon::new(event_loop.get_signal());
let handle = event_loop.handle(); // CHECK that wayland sources get higher prio
let handle = event_loop.handle();
handle
// These are all tasks spawned on the foreground executor;
// there is no concept of priority, so they are all treated equally.
.insert_source(main_receiver, {
let handle = handle.clone();
move |event, _, _: &mut WaylandClientStatePtr| {
if let calloop::channel::Event::Msg(runnable) = event {
// will only be called when the event loop has finished processing all pending events from the sources
handle.insert_idle(|_| {
let start = Instant::now();
let mut timing = match runnable {
@@ -653,7 +650,6 @@ impl WaylandClient {
event_loop: Some(event_loop),
}));
// MAGIC HERE IT IS
WaylandSource::new(conn, event_queue)
.insert(handle)
.unwrap();
@@ -1578,7 +1574,6 @@ fn linux_button_to_gpui(button: u32) -> Option<MouseButton> {
})
}
// how is this being called inside calloop
impl Dispatch<wl_pointer::WlPointer, ()> for WaylandClientStatePtr {
fn event(
this: &mut Self,
@@ -1669,7 +1664,7 @@ impl Dispatch<wl_pointer::WlPointer, ()> for WaylandClientStatePtr {
modifiers: state.modifiers,
});
drop(state);
window.handle_input(input); // How does this get into the event loop?
window.handle_input(input);
}
}
wl_pointer::Event::Button {

View File

@@ -18,6 +18,7 @@ use smallvec::SmallVec;
use windows::{
Win32::{
Foundation::*,
Graphics::Dwm::*,
Graphics::Gdi::*,
System::{Com::*, LibraryLoader::*, Ole::*, SystemServices::*},
UI::{Controls::*, HiDpi::*, Input::KeyboardAndMouse::*, Shell::*, WindowsAndMessaging::*},
@@ -773,20 +774,26 @@ impl PlatformWindow for WindowsWindow {
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
let hwnd = self.0.hwnd;
// Use DWM APIs for the Mica and MicaAlt backdrops;
// the other appearances go through set_window_composition_attribute.
match background_appearance {
WindowBackgroundAppearance::Opaque => {
// ACCENT_DISABLED
set_window_composition_attribute(hwnd, None, 0);
}
WindowBackgroundAppearance::Transparent => {
// Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background
set_window_composition_attribute(hwnd, None, 2);
}
WindowBackgroundAppearance::Blurred => {
// Enable acrylic blur
// ACCENT_ENABLE_ACRYLICBLURBEHIND
set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4);
}
WindowBackgroundAppearance::MicaBackdrop => {
// DWMSBT_MAINWINDOW => MicaBase
dwm_set_window_composition_attribute(hwnd, 2);
}
WindowBackgroundAppearance::MicaAltBackdrop => {
// DWMSBT_TABBEDWINDOW => MicaAlt
dwm_set_window_composition_attribute(hwnd, 4);
}
}
}
@@ -1330,9 +1337,34 @@ fn retrieve_window_placement(
Ok(placement)
}
fn dwm_set_window_composition_attribute(hwnd: HWND, backdrop_type: u32) {
let mut version = unsafe { std::mem::zeroed() };
let status = unsafe { windows::Wdk::System::SystemServices::RtlGetVersion(&mut version) };
// DWMWA_SYSTEMBACKDROP_TYPE is only available on build 22621 or later;
// on older builds, skip and leave the existing composition attribute in place.
if !status.is_ok() || version.dwBuildNumber < 22621 {
return;
}
unsafe {
// There is no fallback if DWM rejects the attribute, so ignore the result.
let _ = DwmSetWindowAttribute(
hwnd,
DWMWA_SYSTEMBACKDROP_TYPE,
&backdrop_type as *const _ as *const _,
std::mem::size_of_val(&backdrop_type) as u32,
);
}
}
fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32) {
let mut version = unsafe { std::mem::zeroed() };
let status = unsafe { windows::Wdk::System::SystemServices::RtlGetVersion(&mut version) };
if !status.is_ok() || version.dwBuildNumber < 17763 {
return;
}
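Both paths above gate on `RtlGetVersion` build numbers: 17763 for `set_window_composition_attribute` and 22621 for `DWMWA_SYSTEMBACKDROP_TYPE`. A minimal sketch of that gating logic, with the thresholds pulled out as named constants (the helpers themselves are hypothetical, not part of this diff):

```rust
/// Windows 11 22H2, the first build exposing DWMWA_SYSTEMBACKDROP_TYPE.
const MIN_BUILD_SYSTEM_BACKDROP: u32 = 22621;
/// Windows 10 1809, the minimum build the composition-attribute path accepts.
const MIN_BUILD_COMPOSITION_ATTRIBUTE: u32 = 17763;

fn supports_system_backdrop(build: u32) -> bool {
    build >= MIN_BUILD_SYSTEM_BACKDROP
}

fn supports_composition_attribute(build: u32) -> bool {
    build >= MIN_BUILD_COMPOSITION_ATTRIBUTE
}

fn main() {
    assert!(supports_system_backdrop(22621));
    // Windows 11 21H2 has Mica, but not via DWMWA_SYSTEMBACKDROP_TYPE.
    assert!(!supports_system_backdrop(22000));
    assert!(supports_composition_attribute(17763));
    assert!(!supports_composition_attribute(17134));
    println!("ok");
}
```

Keeping the thresholds in one place would make it easier to audit which Windows builds take which backdrop path.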

View File

@@ -41,7 +41,6 @@ tree-sitter-rust.workspace = true
ui_input.workspace = true
ui.workspace = true
util.workspace = true
vim.workspace = true
workspace.workspace = true
zed_actions.workspace = true

View File

@@ -1769,7 +1769,7 @@ impl Render for KeymapEditor {
)
.action(
"Vim Bindings",
vim::OpenDefaultKeymap.boxed_clone(),
zed_actions::vim::OpenDefaultKeymap.boxed_clone(),
)
}))
})

View File

@@ -12,7 +12,7 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
StatusUpdate(CompletionRequestStatus),
Queued {
position: usize,
},
Started,
RequestUsage {
amount: usize,
limit: UsageLimit,
},
ToolUseLimitReached,
Stop(StopReason),
Text(String),
Thinking {
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
StartMessage {
message_id: String,
},
UsageUpdate(TokenUsage),
TokenUsage(TokenUsage),
}
impl LanguageModelCompletionEvent {
pub fn from_completion_request_status(
status: CompletionRequestStatus,
) -> Result<Self, LanguageModelCompletionError> {
match status {
CompletionRequestStatus::Queued { position } => {
Ok(LanguageModelCompletionEvent::Queued { position })
}
CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
CompletionRequestStatus::UsageUpdated { amount, limit } => {
Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
}
CompletionRequestStatus::ToolUseLimitReached => {
Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
}
CompletionRequestStatus::Failed {
code,
message,
request_id: _,
retry_after,
} => Err(LanguageModelCompletionError::from_cloud_failure(
code,
message,
retry_after.map(Duration::from_secs_f64),
)),
}
}
}
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("missing API key")]
NoApiKey,
#[error("API rate limit exceeded")]
RateLimitExceeded { retry_after: Option<Duration> },
#[error("API servers are overloaded right now")]
ServerOverloaded { retry_after: Option<Duration> },
#[error("API server reported an internal server error: {message}")]
ApiInternalServerError { message: String },
#[error("{message}")]
UpstreamProviderError {
message: String,
status: StatusCode,
retry_after: Option<Duration>,
},
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
#[error("HTTP response error from API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("Permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("invalid request format to API: {message}")]
BadRequestFormat { message: String },
#[error("authentication error with API: {message}")]
AuthenticationError { message: String },
#[error("Permission error with API: {message}")]
PermissionError { message: String },
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiEndpointNotFound,
#[error("I/O error reading response from API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
#[error("error serializing request to API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
#[error("error building request body to API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
#[error("error sending HTTP request to API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
#[error("error deserializing API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
}
impl LanguageModelCompletionError {
fn display_format(&self, provider: LanguageModelProviderName) {
// match self {
// #[error("prompt too large for context window")]
// PromptTooLarge { tokens: Option<u64> },
// #[error("missing API key")]
// NoApiKey,
// #[error("API rate limit exceeded")]
// RateLimitExceeded { retry_after: Option<Duration> },
// #[error("API servers are overloaded right now")]
// ServerOverloaded { retry_after: Option<Duration> },
// #[error("API server reported an internal server error: {message}")]
// ApiInternalServerError { message: String },
// #[error("{message}")]
// UpstreamProviderError {
// message: String,
// status: StatusCode,
// retry_after: Option<Duration>,
// },
// #[error("HTTP response error from API: status {status_code} - {message:?}")]
// HttpResponseError {
// status_code: StatusCode,
// message: String,
// },
// // Client errors
// #[error("invalid request format to API: {message}")]
// BadRequestFormat { message: String },
// #[error("authentication error with API: {message}")]
// AuthenticationError { message: String },
// #[error("Permission error with API: {message}")]
// PermissionError { message: String },
// #[error("language model provider API endpoint not found")]
// ApiEndpointNotFound,
// #[error("I/O error reading response from API")]
// ApiReadResponseError {
// #[source]
// error: io::Error,
// },
// #[error("error serializing request to API")]
// SerializeRequest {
// #[source]
// error: serde_json::Error,
// },
// #[error("error building request body to API")]
// BuildRequestBody {
// #[source]
// error: http::Error,
// },
// #[error("error sending HTTP request to API")]
// HttpSend {
// #[source]
// error: anyhow::Error,
// },
// #[error("error deserializing API response")]
// DeserializeResponse {
// #[source]
// error: serde_json::Error,
// },
// // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
// #[error(transparent)]
// Other(#[from] anyhow::Error),
// }
}
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
let upstream_status = error_json
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
}
pub fn from_cloud_failure(
upstream_provider: LanguageModelProviderName,
code: String,
message: String,
retry_after: Option<Duration>,
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
if let Some((upstream_status, inner_message)) =
Self::parse_upstream_error_json(&message)
{
return Self::from_http_status(
upstream_provider,
upstream_status,
inner_message,
retry_after,
);
return Self::from_http_status(upstream_status, inner_message, retry_after);
}
anyhow!("completion request failed, code: {code}, message: {message}").into()
Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
} else if let Some(status_code) = code
.strip_prefix("upstream_http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(upstream_provider, status_code, message, retry_after)
Self::from_http_status(status_code, message, retry_after)
} else if let Some(status_code) = code
.strip_prefix("http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
Self::from_http_status(status_code, message, retry_after)
} else {
anyhow!("completion request failed, code: {code}, message: {message}").into()
Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
}
}
pub fn from_http_status(
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
retry_after: Option<Duration>,
) -> Self {
match status_code {
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
StatusCode::FORBIDDEN => Self::PermissionError { message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&message),
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
provider,
retry_after,
},
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
provider,
retry_after,
},
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
provider,
retry_after,
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
_ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
_ => Self::HttpResponseError {
provider,
status_code,
message,
},
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {
impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
AnthropicError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
AnthropicError::HttpSend(error) => Self::HttpSend { error },
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
AnthropicError::HttpResponseError {
status_code,
message,
} => Self::HttpResponseError {
provider,
status_code,
message,
},
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
AnthropicError::ServerOverloaded { retry_after } => {
Self::ServerOverloaded { retry_after }
}
AnthropicError::ApiError(api_error) => api_error.into(),
}
}
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
NotFoundError => Self::ApiEndpointNotFound,
RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message),
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
RateLimitError => Self::RateLimitExceeded { retry_after: None },
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
OverloadedError => Self::ServerOverloaded { retry_after: None },
},
None => Self::Other(error.into()),
}
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
fn from(error: open_ai::RequestError) -> Self {
match error {
open_ai::RequestError::HttpResponseError {
provider,
provider: _,
status_code,
body,
headers,
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
.map(Duration::from_secs);
Self::from_http_status(provider.into(), status_code, body, retry_after)
Self::from_http_status(status_code, body, retry_after)
}
open_ai::RequestError::Other(e) => Self::Other(e),
}
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
impl From<OpenRouterError> for LanguageModelCompletionError {
fn from(error: OpenRouterError) -> Self {
let provider = LanguageModelProviderName::new("OpenRouter");
match error {
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
OpenRouterError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
OpenRouterError::HttpSend(error) => Self::HttpSend { error },
OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
OpenRouterError::ServerOverloaded { retry_after } => {
Self::ServerOverloaded { retry_after }
}
OpenRouterError::ApiError(api_error) => api_error.into(),
}
}
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
RateLimitError => Self::RateLimitExceeded { retry_after: None },
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
OverloadedError => Self::ServerOverloaded { retry_after: None },
}
}
}
@@ -515,6 +546,9 @@ pub struct LanguageModelToolUse {
pub raw_input: String,
pub input: serde_json::Value,
pub is_input_complete: bool,
/// The thought signature sent by the model. Some models require this
/// signature to be preserved and echoed back in the conversation history for validation.
pub thought_signature: Option<String>,
}
pub struct LanguageModelTextStream {
@@ -630,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone();
async move {
match result {
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
Ok(LanguageModelCompletionEvent::Started) => None,
Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -640,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
..
}) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
}
@@ -829,16 +866,13 @@ mod tests {
#[test]
fn test_from_cloud_failure_with_upstream_http_error() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!(
"Expected ServerOverloaded error for 503 status, got: {:?}",
error
@@ -846,15 +880,13 @@ mod tests {
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(message, "Internal server error");
}
_ => panic!(
@@ -867,16 +899,13 @@ mod tests {
#[test]
fn test_from_cloud_failure_with_standard_format() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_503".to_string(),
"Service unavailable".to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
}
}
@@ -884,16 +913,13 @@ mod tests {
#[test]
fn test_upstream_http_error_connection_timeout() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!(
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
error
@@ -901,15 +927,13 @@ mod tests {
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(
message,
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
@@ -921,4 +945,85 @@ mod tests {
),
}
}
#[test]
fn test_language_model_tool_use_serializes_with_signature() {
use serde_json::json;
let tool_use = LanguageModelToolUse {
id: LanguageModelToolUseId::from("test_id"),
name: "test_tool".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: Some("test_signature".to_string()),
};
let serialized = serde_json::to_value(&tool_use).unwrap();
assert_eq!(serialized["id"], "test_id");
assert_eq!(serialized["name"], "test_tool");
assert_eq!(serialized["thought_signature"], "test_signature");
}
#[test]
fn test_language_model_tool_use_deserializes_with_missing_signature() {
use serde_json::json;
let json = json!({
"id": "test_id",
"name": "test_tool",
"raw_input": "{\"arg\":\"value\"}",
"input": {"arg": "value"},
"is_input_complete": true
});
let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
assert_eq!(tool_use.name.as_ref(), "test_tool");
assert_eq!(tool_use.thought_signature, None);
}
#[test]
fn test_language_model_tool_use_round_trip_with_signature() {
use serde_json::json;
let original = LanguageModelToolUse {
id: LanguageModelToolUseId::from("round_trip_id"),
name: "round_trip_tool".into(),
raw_input: json!({"key": "value"}).to_string(),
input: json!({"key": "value"}),
is_input_complete: true,
thought_signature: Some("round_trip_sig".to_string()),
};
let serialized = serde_json::to_value(&original).unwrap();
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
assert_eq!(deserialized.id, original.id);
assert_eq!(deserialized.name, original.name);
assert_eq!(deserialized.thought_signature, original.thought_signature);
}
#[test]
fn test_language_model_tool_use_round_trip_without_signature() {
use serde_json::json;
let original = LanguageModelToolUse {
id: LanguageModelToolUseId::from("no_sig_id"),
name: "no_sig_tool".into(),
raw_input: json!({"key": "value"}).to_string(),
input: json!({"key": "value"}),
is_input_complete: true,
thought_signature: None,
};
let serialized = serde_json::to_value(&original).unwrap();
let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
assert_eq!(deserialized.id, original.id);
assert_eq!(deserialized.name, original.name);
assert_eq!(deserialized.thought_signature, None);
}
}
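With the provider field removed, `from_http_status` above reduces to a pure status-code classification. A standalone sketch of that mapping (simplified illustrative types, not the crate's actual API; note that 503 and Anthropic's non-standard 529 both classify as overloaded):

```rust
#[derive(Debug, PartialEq)]
enum CompletionError {
    BadRequestFormat,
    Authentication,
    Permission,
    EndpointNotFound,
    PromptTooLarge,
    RateLimitExceeded,
    InternalServerError,
    ServerOverloaded,
    HttpResponse(u16),
}

/// Mirrors the match in `from_http_status`, keyed on the raw status code.
fn classify(status: u16) -> CompletionError {
    match status {
        400 => CompletionError::BadRequestFormat,
        401 => CompletionError::Authentication,
        403 => CompletionError::Permission,
        404 => CompletionError::EndpointNotFound,
        413 => CompletionError::PromptTooLarge,
        429 => CompletionError::RateLimitExceeded,
        500 => CompletionError::InternalServerError,
        503 | 529 => CompletionError::ServerOverloaded,
        other => CompletionError::HttpResponse(other),
    }
}

fn main() {
    assert_eq!(classify(503), CompletionError::ServerOverloaded);
    assert_eq!(classify(529), CompletionError::ServerOverloaded);
    assert_eq!(classify(418), CompletionError::HttpResponse(418));
    println!("ok");
}
```

This is also the shape the `test_from_cloud_failure_*` tests above exercise: an `upstream_status` of 503 must land on the overloaded variant, and 500 on the internal-server-error variant.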

View File

@@ -320,9 +320,7 @@ impl AnthropicModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = anthropic::stream_completion(
http_client.as_ref(),
@@ -711,6 +709,7 @@ impl AnthropicEventMapper {
is_input_complete: false,
raw_input: tool_use.input_json.clone(),
input,
thought_signature: None,
},
))];
}
@@ -734,6 +733,7 @@ impl AnthropicEventMapper {
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
thought_signature: None,
},
)),
Err(json_parse_err) => {
@@ -754,7 +754,7 @@ impl AnthropicEventMapper {
Event::MessageStart { message } => {
update_usage(&mut self.usage, &message.usage);
vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
))),
Ok(LanguageModelCompletionEvent::StartMessage {
@@ -776,9 +776,9 @@ impl AnthropicEventMapper {
}
};
}
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
))]
vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
)))]
}
Event::MessageStop => {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]

View File

@@ -970,11 +970,12 @@ pub fn map_to_language_model_completion_events(
is_input_complete: true,
raw_input: tool_use.input_json,
input,
thought_signature: None,
},
))
}),
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens: metadata

View File

@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
}
return LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
cloud_error.message,
None,
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
}
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
}
}
@@ -961,7 +955,7 @@ where
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
}
Ok(CompletionEvent::Event(event)) => map_callback(event),
})
@@ -1313,8 +1307,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider, PROVIDER_NAME);
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(message, "Regular internal server error");
}
_ => panic!(
@@ -1362,9 +1355,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
assert_eq!(provider, PROVIDER_NAME);
}
LanguageModelCompletionError::ApiInternalServerError { .. } => {}
_ => panic!(
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
completion_error

View File

@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
)));
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
match choice.finish_reason.as_deref() {
@@ -458,6 +456,7 @@ pub fn map_to_language_model_completion_events(
is_input_complete: true,
input,
raw_input: tool_call.arguments,
thought_signature: None,
},
)),
Err(error) => Ok(
@@ -560,6 +559,7 @@ impl CopilotResponsesEventMapper {
is_input_complete: true,
input,
raw_input: arguments.clone(),
thought_signature: None,
},
))),
Err(error) => {
@@ -608,7 +608,7 @@ impl CopilotResponsesEventMapper {
copilot::copilot_responses::StreamEvent::Completed { response } => {
let mut events = Vec::new();
if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0,
@@ -641,7 +641,7 @@ impl CopilotResponsesEventMapper {
let mut events = Vec::new();
if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0,
@@ -653,7 +653,6 @@ impl CopilotResponsesEventMapper {
}
copilot::copilot_responses::StreamEvent::Failed { response } => {
let provider = PROVIDER_NAME;
let (status_code, message) = match response.error {
Some(error) => {
let status_code = StatusCode::from_str(&error.code)
@@ -666,7 +665,6 @@ impl CopilotResponsesEventMapper {
),
};
vec![Err(LanguageModelCompletionError::HttpResponseError {
provider,
status_code,
message,
})]
@@ -1097,7 +1095,7 @@ mod tests {
));
assert!(matches!(
mapped[2],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 5,
output_tokens: 3,
..
@@ -1205,7 +1203,7 @@ mod tests {
let mapped = map_events(events);
assert!(matches!(
mapped[0],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 10,
output_tokens: 0,
..


@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
@@ -501,6 +499,7 @@ impl DeepSeekEventMapper {
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
thought_signature: None,
},
)),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {


@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}
.into());
return Err(LanguageModelCompletionError::NoApiKey.into());
};
let response = google_ai::count_tokens(
http_client.as_ref(),
@@ -439,11 +436,15 @@ pub fn into_google(
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
// Normalize empty string signatures to None
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
function_call: google_ai::FunctionCall {
name: tool_use.name.to_string(),
args: tool_use.input,
},
thought_signature,
})]
}
language_model::MessageContent::ToolResult(tool_result) => {
@@ -604,9 +605,9 @@ impl GoogleEventMapper {
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
))))
}
if let Some(prompt_feedback) = event.prompt_feedback
@@ -655,6 +656,11 @@ impl GoogleEventMapper {
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
// Normalize empty string signatures to None
let thought_signature = function_call_part
.thought_signature
.filter(|s| !s.is_empty());
events.push(Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id,
@@ -662,6 +668,7 @@ impl GoogleEventMapper {
is_input_complete: true,
raw_input: function_call_part.function_call.args.to_string(),
input: function_call_part.function_call.args,
thought_signature,
},
)));
}
@@ -891,3 +898,424 @@ impl Render for ConfigurationView {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use google_ai::{
Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
Part, Role as GoogleRole, TextPart,
};
use language_model::{LanguageModelToolUseId, MessageContent, Role};
use serde_json::json;
#[test]
fn test_function_call_with_signature_creates_tool_use_with_signature() {
let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some("test_signature_123".to_string()),
})],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
assert_eq!(events.len(), 2); // ToolUse event + Stop event
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.name.as_ref(), "test_function");
assert_eq!(
tool_use.thought_signature.as_deref(),
Some("test_signature_123")
);
} else {
panic!("Expected ToolUse event");
}
}
#[test]
fn test_function_call_without_signature_has_none() {
let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: None,
})],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.thought_signature, None);
} else {
panic!("Expected ToolUse event");
}
}
#[test]
fn test_empty_string_signature_normalized_to_none() {
let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some("".to_string()),
})],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.thought_signature, None);
} else {
panic!("Expected ToolUse event");
}
}
#[test]
fn test_parallel_function_calls_preserve_signatures() {
let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "function_1".to_string(),
args: json!({"arg": "value1"}),
},
thought_signature: Some("signature_1".to_string()),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "function_2".to_string(),
args: json!({"arg": "value2"}),
},
thought_signature: None,
}),
],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.name.as_ref(), "function_1");
assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
} else {
panic!("Expected ToolUse event for function_1");
}
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
assert_eq!(tool_use.name.as_ref(), "function_2");
assert_eq!(tool_use.thought_signature, None);
} else {
panic!("Expected ToolUse event for function_2");
}
}
#[test]
fn test_tool_use_with_signature_converts_to_function_call_part() {
let tool_use = language_model::LanguageModelToolUse {
id: LanguageModelToolUseId::from("test_id"),
name: "test_function".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: Some("test_signature_456".to_string()),
};
let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false,
}],
..Default::default()
},
"gemini-2.5-flash".to_string(),
GoogleModelMode::Default,
);
assert_eq!(request.contents[0].parts.len(), 1);
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
assert_eq!(fc_part.function_call.name, "test_function");
assert_eq!(
fc_part.thought_signature.as_deref(),
Some("test_signature_456")
);
} else {
panic!("Expected FunctionCallPart");
}
}
#[test]
fn test_tool_use_without_signature_omits_field() {
let tool_use = language_model::LanguageModelToolUse {
id: LanguageModelToolUseId::from("test_id"),
name: "test_function".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: None,
};
let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false,
}],
..Default::default()
},
"gemini-2.5-flash".to_string(),
GoogleModelMode::Default,
);
assert_eq!(request.contents[0].parts.len(), 1);
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
assert_eq!(fc_part.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
}
#[test]
fn test_empty_signature_in_tool_use_normalized_to_none() {
let tool_use = language_model::LanguageModelToolUse {
id: LanguageModelToolUseId::from("test_id"),
name: "test_function".into(),
raw_input: json!({"arg": "value"}).to_string(),
input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: Some("".to_string()),
};
let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false,
}],
..Default::default()
},
"gemini-2.5-flash".to_string(),
GoogleModelMode::Default,
);
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
assert_eq!(fc_part.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
}
#[test]
fn test_round_trip_preserves_signature() {
let mut mapper = GoogleEventMapper::new();
// Simulate receiving a response from Google with a signature
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some("round_trip_sig".to_string()),
})],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
tool_use.clone()
} else {
panic!("Expected ToolUse event");
};
// Convert back to Google format
let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::ToolUse(tool_use)],
cache: false,
}],
..Default::default()
},
"gemini-2.5-flash".to_string(),
GoogleModelMode::Default,
);
// Verify signature is preserved
if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
} else {
panic!("Expected FunctionCallPart");
}
}
#[test]
fn test_mixed_text_and_function_call_with_signature() {
let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![
Part::TextPart(TextPart {
text: "I'll help with that.".to_string(),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "helper_function".to_string(),
args: json!({"query": "help"}),
},
thought_signature: Some("mixed_sig".to_string()),
}),
],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
assert_eq!(text, "I'll help with that.");
} else {
panic!("Expected Text event");
}
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
assert_eq!(tool_use.name.as_ref(), "helper_function");
assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
} else {
panic!("Expected ToolUse event");
}
}
#[test]
fn test_special_characters_in_signature_preserved() {
let mut mapper = GoogleEventMapper::new();
let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
thought_signature: Some(signature_with_special_chars.clone()),
})],
role: GoogleRole::Model,
},
finish_reason: None,
finish_message: None,
safety_ratings: None,
citation_metadata: None,
}]),
prompt_feedback: None,
usage_metadata: None,
};
let events = mapper.map_event(response);
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(
tool_use.thought_signature.as_deref(),
Some(signature_with_special_chars.as_str())
);
} else {
panic!("Expected ToolUse event");
}
}
}


@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
@@ -569,6 +569,7 @@ impl LmStudioEventMapper {
is_input_complete: true,
input,
raw_input: tool_call.arguments,
thought_signature: None,
},
)),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {


@@ -291,9 +291,7 @@ impl MistralLanguageModel {
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -672,7 +670,7 @@ impl MistralEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
@@ -720,6 +718,7 @@ impl MistralEventMapper {
is_input_complete: true,
input,
raw_input: tool_call.arguments,
thought_signature: None,
},
))),
Err(error) => {


@@ -592,6 +592,7 @@ fn map_to_language_model_completion_events(
raw_input: function.arguments.to_string(),
input: function.arguments,
is_input_complete: true,
thought_signature: None,
});
events.push(Ok(event));
state.used_tools = true;
@@ -602,7 +603,7 @@ fn map_to_language_model_completion_events(
};
if delta.done {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: delta.prompt_eval_count.unwrap_or(0),
output_tokens: delta.eval_count.unwrap_or(0),
cache_creation_input_tokens: 0,


@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = stream_completion(
http_client.as_ref(),
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let mut events = Vec::new();
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
@@ -586,6 +586,7 @@ impl OpenAiEventMapper {
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
thought_signature: None,
},
)),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {


@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
let provider = self.provider_name.clone();
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = stream_completion(
http_client.as_ref(),


@@ -84,9 +84,7 @@ impl State {
let http_client = self.http_client.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let Some(api_key) = self.api_key_state.key(&api_url) else {
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}));
return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
};
cx.spawn(async move |this, cx| {
let models = list_models(http_client.as_ref(), &api_url, &api_key)
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
@@ -635,6 +631,7 @@ impl OpenRouterEventMapper {
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
thought_signature: None,
},
)),
Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {


@@ -222,7 +222,7 @@ impl VercelLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = open_ai::stream_completion(
http_client.as_ref(),


@@ -230,7 +230,7 @@ impl XAiLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = open_ai::stream_completion(
http_client.as_ref(),


@@ -18,7 +18,6 @@ workspace.workspace = true
util.workspace = true
serde_json.workspace = true
smol.workspace = true
log.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }


@@ -5,7 +5,10 @@ use std::{
};
use gpui::{
App, AppContext, Context, Entity, Hsla, InteractiveElement, IntoElement, ParentElement, Render, ScrollHandle, SerializedThreadTaskTimings, StatefulInteractiveElement, Styled, Task, TaskTiming, ThreadTaskTimings, TitlebarOptions, WindowBounds, WindowHandle, WindowOptions, div, prelude::FluentBuilder, px, relative, size
App, AppContext, Context, Entity, Hsla, InteractiveElement, IntoElement, ParentElement, Render,
ScrollHandle, SerializedTaskTiming, StatefulInteractiveElement, Styled, Task, TaskTiming,
TitlebarOptions, WindowBounds, WindowHandle, WindowOptions, div, prelude::FluentBuilder, px,
relative, size,
};
use util::ResultExt;
use workspace::{
@@ -284,13 +287,8 @@ impl Render for ProfilerWindow {
let Some(data) = this.get_timings() else {
return;
};
let timings = ThreadTaskTimings {
thread_name: Some("main".to_string()),
thread_id: std::thread::current().id(),
timings: data.clone()
};
let timings = Vec::from([SerializedThreadTaskTimings::convert(this.startup_time, timings)]);
let timings =
SerializedTaskTiming::convert(this.startup_time, &data);
let active_path = workspace
.read_with(cx, |workspace, cx| {
@@ -307,17 +305,12 @@ impl Render for ProfilerWindow {
);
cx.background_spawn(async move {
let path = match path.await.log_err() {
Some(Ok(Some(path))) => path,
Some(e @ Err(_)) => {
e.log_err();
log::warn!("Saving miniprof in workingdir");
std::path::Path::new(
"performance_profile.miniprof",
)
.to_path_buf()
}
Some(Ok(None)) | None => return,
let path = path.await;
let path =
path.log_err().and_then(|p| p.log_err()).flatten();
let Some(path) = path else {
return;
};
let Some(timings) =


@@ -43,7 +43,6 @@ text.workspace = true
theme.workspace = true
tree-sitter.workspace = true
util.workspace = true
zlog.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }


@@ -76,8 +76,6 @@ impl MultiBuffer {
context_line_count: u32,
cx: &mut Context<Self>,
) -> (Vec<Range<Anchor>>, bool) {
let _timer =
zlog::time!("set_excerpts_for_path").warn_if_gt(std::time::Duration::from_millis(100));
let buffer_snapshot = buffer.read(cx).snapshot();
let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot);


@@ -449,7 +449,7 @@ pub async fn handle_import_vscode_settings(
match settings::VsCodeSettings::load_user_settings(source, fs.clone()).await {
Ok(vscode_settings) => vscode_settings,
Err(err) => {
zlog::error!("{err}");
zlog::error!("{err:?}");
let _ = cx.prompt(
gpui::PromptLevel::Info,
&format!("Could not find or load a {source} settings file"),


@@ -99,13 +99,18 @@ pub enum ContextServerConfiguration {
command: ContextServerCommand,
settings: serde_json::Value,
},
Http {
url: url::Url,
headers: HashMap<String, String>,
},
}
impl ContextServerConfiguration {
pub fn command(&self) -> &ContextServerCommand {
pub fn command(&self) -> Option<&ContextServerCommand> {
match self {
ContextServerConfiguration::Custom { command } => command,
ContextServerConfiguration::Extension { command, .. } => command,
ContextServerConfiguration::Custom { command } => Some(command),
ContextServerConfiguration::Extension { command, .. } => Some(command),
ContextServerConfiguration::Http { .. } => None,
}
}
@@ -142,6 +147,14 @@ impl ContextServerConfiguration {
}
}
}
ContextServerSettings::Http {
enabled: _,
url,
headers: auth,
} => {
let url = url::Url::parse(&url).log_err()?;
Some(ContextServerConfiguration::Http { url, headers: auth })
}
}
}
}
@@ -207,7 +220,7 @@ impl ContextServerStore {
#[cfg(any(test, feature = "test-support"))]
pub fn test_maintain_server_loop(
context_server_factory: ContextServerFactory,
context_server_factory: Option<ContextServerFactory>,
registry: Entity<ContextServerDescriptorRegistry>,
worktree_store: Entity<WorktreeStore>,
weak_project: WeakEntity<Project>,
@@ -215,7 +228,7 @@ impl ContextServerStore {
) -> Self {
Self::new_internal(
true,
Some(context_server_factory),
context_server_factory,
registry,
worktree_store,
weak_project,
@@ -385,17 +398,6 @@ impl ContextServerStore {
result
}
pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
if let Some(state) = self.servers.get(id) {
let configuration = state.configuration();
self.stop_server(&state.server().id(), cx)?;
let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
self.run_server(new_server, configuration, cx);
}
Ok(())
}
fn run_server(
&mut self,
server: Arc<ContextServer>,
@@ -479,33 +481,42 @@ impl ContextServerStore {
id: ContextServerId,
configuration: Arc<ContextServerConfiguration>,
cx: &mut Context<Self>,
) -> Arc<ContextServer> {
let project = self.project.upgrade();
let mut root_path = None;
if let Some(project) = project {
let project = project.read(cx);
if project.is_local() {
if let Some(path) = project.active_project_directory(cx) {
root_path = Some(path);
} else {
for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
if let Some(path) = worktree.read(cx).root_dir() {
root_path = Some(path);
break;
}
}
}
}
};
) -> Result<Arc<ContextServer>> {
if let Some(factory) = self.context_server_factory.as_ref() {
factory(id, configuration)
} else {
Arc::new(ContextServer::stdio(
return Ok(factory(id, configuration));
}
match configuration.as_ref() {
ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
id,
configuration.command().clone(),
root_path,
))
url,
headers.clone(),
cx.http_client(),
cx.background_executor().clone(),
)?)),
_ => {
let root_path = self
.project
.read_with(cx, |project, cx| project.active_project_directory(cx))
.ok()
.flatten()
.or_else(|| {
self.worktree_store.read_with(cx, |store, cx| {
store.visible_worktrees(cx).fold(None, |acc, item| {
if acc.is_none() {
item.read(cx).root_dir()
} else {
acc
}
})
})
});
Ok(Arc::new(ContextServer::stdio(
id,
configuration.command().unwrap().clone(),
root_path,
)))
}
}
}
@@ -621,14 +632,16 @@ impl ContextServerStore {
let existing_config = state.as_ref().map(|state| state.configuration());
if existing_config.as_deref() != Some(&config) || is_stopped {
let config = Arc::new(config);
let server = this.create_context_server(id.clone(), config.clone(), cx);
let server = this.create_context_server(id.clone(), config.clone(), cx)?;
servers_to_start.push((server, config));
if this.servers.contains_key(&id) {
servers_to_stop.insert(id);
}
}
}
})?;
anyhow::Ok(())
})??;
this.update(cx, |this, cx| {
for id in servers_to_stop {
@@ -654,6 +667,7 @@ mod tests {
};
use context_server::test::create_fake_transport;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use http_client::{FakeHttpClient, Response};
use serde_json::json;
use std::{cell::RefCell, path::PathBuf, rc::Rc};
use util::path;
@@ -894,12 +908,12 @@ mod tests {
});
let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop(
Box::new(move |id, _| {
Some(Box::new(move |id, _| {
Arc::new(ContextServer::new(
id.clone(),
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
))
}),
})),
registry.clone(),
project.read(cx).worktree_store(),
project.downgrade(),
@@ -1130,12 +1144,12 @@ mod tests {
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop(
Box::new(move |id, _| {
Some(Box::new(move |id, _| {
Arc::new(ContextServer::new(
id.clone(),
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
))
}),
})),
registry.clone(),
project.read(cx).worktree_store(),
project.downgrade(),
@@ -1228,6 +1242,73 @@ mod tests {
});
}
#[gpui::test]
async fn test_remote_context_server(cx: &mut TestAppContext) {
const SERVER_ID: &str = "remote-server";
let server_id = ContextServerId(SERVER_ID.into());
let server_url = "http://example.com/api";
let (_fs, project) = setup_context_server_test(
cx,
json!({ "code.rs": "" }),
vec![(
SERVER_ID.into(),
ContextServerSettings::Http {
enabled: true,
url: server_url.to_string(),
headers: Default::default(),
},
)],
)
.await;
let client = FakeHttpClient::create(|_| async move {
use http_client::AsyncBody;
let response = Response::builder()
.status(200)
.header("Content-Type", "application/json")
.body(AsyncBody::from(
serde_json::to_string(&json!({
"jsonrpc": "2.0",
"id": 0,
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"serverInfo": {
"name": "test-server",
"version": "1.0.0"
}
}
}))
.unwrap(),
))
.unwrap();
Ok(response)
});
cx.update(|cx| cx.set_http_client(client));
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop(
None,
registry.clone(),
project.read(cx).worktree_store(),
project.downgrade(),
cx,
)
});
let _server_events = assert_server_events(
&store,
vec![
(server_id.clone(), ContextServerStatus::Starting),
(server_id.clone(), ContextServerStatus::Running),
],
cx,
);
cx.run_until_parked();
}
struct ServerEvents {
received_event_count: Rc<RefCell<usize>>,
expected_event_count: usize,


@@ -34,11 +34,11 @@ impl DiffBase {
pub struct BranchDiff {
diff_base: DiffBase,
pub repo: Option<Entity<Repository>>,
pub project: Entity<Project>,
repo: Option<Entity<Repository>>,
project: Entity<Project>,
base_commit: Option<SharedString>,
head_commit: Option<SharedString>,
pub tree_diff: Option<TreeDiff>,
tree_diff: Option<TreeDiff>,
_subscription: Subscription,
update_needed: postage::watch::Sender<()>,
_task: Task<()>,
@@ -283,11 +283,7 @@ impl BranchDiff {
else {
continue;
};
let repo = repo.clone();
let task = cx.spawn(async move |project, cx| {
Self::load_buffer(branch_diff, project_path, repo.clone(), project, cx).await
});
let task = Self::load_buffer(branch_diff, project_path, repo.clone(), cx);
output.push(DiffBuffer {
repo_path: item.repo_path.clone(),
@@ -307,11 +303,8 @@ impl BranchDiff {
let Some(project_path) = repo.read(cx).repo_path_to_project_path(&path, cx) else {
continue;
};
let repo = repo.clone();
let branch_diff2 = Some(branch_diff.clone());
let task = cx.spawn(async move |project, cx| {
Self::load_buffer(branch_diff2, project_path, repo, project, cx).await
});
let task =
Self::load_buffer(Some(branch_diff.clone()), project_path, repo.clone(), cx);
let file_status = diff_status_to_file_status(branch_diff);
@@ -325,40 +318,42 @@ impl BranchDiff {
output
}
pub async fn load_buffer(
fn load_buffer(
branch_diff: Option<git::status::TreeDiffStatus>,
project_path: crate::ProjectPath,
repo: Entity<Repository>,
project: WeakEntity<Project>,
cx: &mut gpui::AsyncApp, // making this generic over AppContext hangs the compiler
) -> Result<(Entity<Buffer>, Entity<BufferDiff>)> {
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
cx: &Context<'_, Project>,
) -> Task<Result<(Entity<Buffer>, Entity<BufferDiff>)>> {
let task = cx.spawn(async move |project, cx| {
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let languages = project.update(cx, |project, _cx| project.languages().clone())?;
let languages = project.update(cx, |project, _cx| project.languages().clone())?;
let changes = if let Some(entry) = branch_diff {
let oid = match entry {
git::status::TreeDiffStatus::Added { .. } => None,
git::status::TreeDiffStatus::Modified { old, .. }
| git::status::TreeDiffStatus::Deleted { old } => Some(old),
let changes = if let Some(entry) = branch_diff {
let oid = match entry {
git::status::TreeDiffStatus::Added { .. } => None,
git::status::TreeDiffStatus::Modified { old, .. }
| git::status::TreeDiffStatus::Deleted { old } => Some(old),
};
project
.update(cx, |project, cx| {
project.git_store().update(cx, |git_store, cx| {
git_store.open_diff_since(oid, buffer.clone(), repo, languages, cx)
})
})?
.await?
} else {
project
.update(cx, |project, cx| {
project.open_uncommitted_diff(buffer.clone(), cx)
})?
.await?
};
project
.update(cx, |project, cx| {
project.git_store().update(cx, |git_store, cx| {
git_store.open_diff_since(oid, buffer.clone(), repo, languages, cx)
})
})?
.await?
} else {
project
.update(cx, |project, cx| {
project.open_uncommitted_diff(buffer.clone(), cx)
})?
.await?
};
Ok((buffer, changes))
Ok((buffer, changes))
});
task
}
}

View File

@@ -135,6 +135,16 @@ pub enum ContextServerSettings {
/// are supported.
settings: serde_json::Value,
},
Http {
/// Whether the context server is enabled.
#[serde(default = "default_true")]
enabled: bool,
/// The URL of the remote context server.
url: String,
/// Optional authentication configuration for the remote server.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
headers: HashMap<String, String>,
},
}
impl From<settings::ContextServerSettingsContent> for ContextServerSettings {
@@ -146,6 +156,15 @@ impl From<settings::ContextServerSettingsContent> for ContextServerSettings {
settings::ContextServerSettingsContent::Extension { enabled, settings } => {
ContextServerSettings::Extension { enabled, settings }
}
settings::ContextServerSettingsContent::Http {
enabled,
url,
headers,
} => ContextServerSettings::Http {
enabled,
url,
headers,
},
}
}
}
@@ -158,6 +177,15 @@ impl Into<settings::ContextServerSettingsContent> for ContextServerSettings {
ContextServerSettings::Extension { enabled, settings } => {
settings::ContextServerSettingsContent::Extension { enabled, settings }
}
ContextServerSettings::Http {
enabled,
url,
headers,
} => settings::ContextServerSettingsContent::Http {
enabled,
url,
headers,
},
}
}
}
@@ -174,6 +202,7 @@ impl ContextServerSettings {
match self {
ContextServerSettings::Custom { enabled, .. } => *enabled,
ContextServerSettings::Extension { enabled, .. } => *enabled,
ContextServerSettings::Http { enabled, .. } => *enabled,
}
}
@@ -181,6 +210,7 @@ impl ContextServerSettings {
match self {
ContextServerSettings::Custom { enabled: e, .. } => *e = enabled,
ContextServerSettings::Extension { enabled: e, .. } => *e = enabled,
ContextServerSettings::Http { enabled: e, .. } => *e = enabled,
}
}
}

View File

@@ -67,7 +67,7 @@ use workspace::{
notifications::{DetachAndPromptErr, NotifyResultExt, NotifyTaskExt},
};
use worktree::CreatedEntry;
use zed_actions::workspace::OpenWithSystem;
use zed_actions::{project_panel::ToggleFocus, workspace::OpenWithSystem};
const PROJECT_PANEL_KEY: &str = "ProjectPanel";
const NEW_ENTRY_ID: ProjectEntryId = ProjectEntryId::MAX;
@@ -306,8 +306,6 @@ actions!(
OpenSplitVertical,
/// Opens the selected file in a horizontal split.
OpenSplitHorizontal,
/// Toggles focus on the project panel.
ToggleFocus,
/// Toggles visibility of git-ignored files.
ToggleHideGitIgnore,
/// Toggles visibility of hidden files.

View File

@@ -489,7 +489,7 @@ impl SshRemoteConnection {
let ssh_shell = socket.shell().await;
log::info!("Remote shell discovered: {}", ssh_shell);
let ssh_platform = socket.platform(ShellKind::new(&ssh_shell, false)).await?;
log::info!("Remote platform discovered: {}", ssh_shell);
log::info!("Remote platform discovered: {:?}", ssh_platform);
let ssh_path_style = match ssh_platform.os {
"windows" => PathStyle::Windows,
_ => PathStyle::Posix,

View File

@@ -92,7 +92,7 @@ impl WslRemoteConnection {
.detect_platform()
.await
.context("failed detecting platform")?;
log::info!("Remote platform discovered: {}", this.shell);
log::info!("Remote platform discovered: {:?}", this.platform);
this.remote_binary_path = Some(
this.ensure_server_binary(&delegate, release_channel, version, commit, cx)
.await

View File

@@ -13,32 +13,6 @@ use std::{
time::Duration,
};
// https://docs.rs/tokio/latest/src/tokio/task/yield_now.rs.html#39-64
pub async fn yield_now() {
/// Yield implementation
struct YieldNow {
yielded: bool,
}
impl Future for YieldNow {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
// use core::task::ready;
// ready!(crate::trace::trace_leaf(cx));
if self.yielded {
return Poll::Ready(());
}
self.yielded = true;
// context::defer(cx.waker());
Poll::Pending
}
}
YieldNow { yielded: false }.await;
}
#[derive(Clone)]
pub struct ForegroundExecutor {
session_id: SessionId,

View File

@@ -1039,218 +1039,3 @@ impl std::fmt::Display for DelayMs {
write!(f, "{}ms", self.0)
}
}
/// A wrapper type that distinguishes between an explicitly set value (including null) and an unset value.
///
/// This is useful for configuration where you need to differentiate between:
/// - A field that is not present in the configuration file (`Maybe::Unset`)
/// - A field that is explicitly set to `null` (`Maybe::Set(None)`)
/// - A field that is explicitly set to a value (`Maybe::Set(Some(value))`)
///
/// # Examples
///
/// In JSON:
/// - `{}` (field missing) deserializes to `Maybe::Unset`
/// - `{"field": null}` deserializes to `Maybe::Set(None)`
/// - `{"field": "value"}` deserializes to `Maybe::Set(Some("value"))`
///
/// WARN: This type should not be wrapped in an option inside of settings, otherwise the default `serde_json` behavior
/// of treating `null` and missing as the `Option::None` will be used
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumDiscriminants, Default)]
#[strum_discriminants(derive(strum::VariantArray, strum::VariantNames, strum::FromRepr))]
pub enum Maybe<T> {
/// An explicitly set value, which may be `None` (representing JSON `null`) or `Some(value)`.
Set(Option<T>),
/// A value that was not present in the configuration.
#[default]
Unset,
}
impl<T: Clone> merge_from::MergeFrom for Maybe<T> {
fn merge_from(&mut self, other: &Self) {
if self.is_unset() {
*self = other.clone();
}
}
}
impl<T> From<Option<Option<T>>> for Maybe<T> {
fn from(value: Option<Option<T>>) -> Self {
match value {
Some(value) => Maybe::Set(value),
None => Maybe::Unset,
}
}
}
impl<T> Maybe<T> {
pub fn is_set(&self) -> bool {
matches!(self, Maybe::Set(_))
}
pub fn is_unset(&self) -> bool {
matches!(self, Maybe::Unset)
}
pub fn into_inner(self) -> Option<T> {
match self {
Maybe::Set(value) => value,
Maybe::Unset => None,
}
}
pub fn as_ref(&self) -> Option<&Option<T>> {
match self {
Maybe::Set(value) => Some(value),
Maybe::Unset => None,
}
}
}
impl<T: serde::Serialize> serde::Serialize for Maybe<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Maybe::Set(value) => value.serialize(serializer),
Maybe::Unset => serializer.serialize_none(),
}
}
}
impl<'de, T: serde::Deserialize<'de>> serde::Deserialize<'de> for Maybe<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Option::<T>::deserialize(deserializer).map(Maybe::Set)
}
}
impl<T: JsonSchema> JsonSchema for Maybe<T> {
fn schema_name() -> std::borrow::Cow<'static, str> {
format!("Nullable<{}>", T::schema_name()).into()
}
fn json_schema(generator: &mut schemars::generate::SchemaGenerator) -> schemars::Schema {
let mut schema = generator.subschema_for::<Option<T>>();
// Add description explaining that null is an explicit value
let description = if let Some(existing_desc) =
schema.get("description").and_then(|desc| desc.as_str())
{
format!(
"{}. Note: `null` is treated as an explicit value, different from omitting the field entirely.",
existing_desc
)
} else {
"This field supports explicit `null` values. Omitting the field is different from setting it to `null`.".to_string()
};
schema.insert("description".to_string(), description.into());
schema
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_maybe() {
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct TestStruct {
#[serde(default)]
#[serde(skip_serializing_if = "Maybe::is_unset")]
field: Maybe<String>,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct NumericTest {
#[serde(default)]
value: Maybe<i32>,
}
let json = "{}";
let result: TestStruct = serde_json::from_str(json).unwrap();
assert!(result.field.is_unset());
assert_eq!(result.field, Maybe::Unset);
let json = r#"{"field": null}"#;
let result: TestStruct = serde_json::from_str(json).unwrap();
assert!(result.field.is_set());
assert_eq!(result.field, Maybe::Set(None));
let json = r#"{"field": "hello"}"#;
let result: TestStruct = serde_json::from_str(json).unwrap();
assert!(result.field.is_set());
assert_eq!(result.field, Maybe::Set(Some("hello".to_string())));
let test = TestStruct {
field: Maybe::Unset,
};
let json = serde_json::to_string(&test).unwrap();
assert_eq!(json, "{}");
let test = TestStruct {
field: Maybe::Set(None),
};
let json = serde_json::to_string(&test).unwrap();
assert_eq!(json, r#"{"field":null}"#);
let test = TestStruct {
field: Maybe::Set(Some("world".to_string())),
};
let json = serde_json::to_string(&test).unwrap();
assert_eq!(json, r#"{"field":"world"}"#);
let default_maybe: Maybe<i32> = Maybe::default();
assert!(default_maybe.is_unset());
let unset: Maybe<String> = Maybe::Unset;
assert!(unset.is_unset());
assert!(!unset.is_set());
let set_none: Maybe<String> = Maybe::Set(None);
assert!(set_none.is_set());
assert!(!set_none.is_unset());
let set_some: Maybe<String> = Maybe::Set(Some("value".to_string()));
assert!(set_some.is_set());
assert!(!set_some.is_unset());
let original = TestStruct {
field: Maybe::Set(Some("test".to_string())),
};
let json = serde_json::to_string(&original).unwrap();
let deserialized: TestStruct = serde_json::from_str(&json).unwrap();
assert_eq!(original, deserialized);
let json = r#"{"value": 42}"#;
let result: NumericTest = serde_json::from_str(json).unwrap();
assert_eq!(result.value, Maybe::Set(Some(42)));
let json = r#"{"value": null}"#;
let result: NumericTest = serde_json::from_str(json).unwrap();
assert_eq!(result.value, Maybe::Set(None));
let json = "{}";
let result: NumericTest = serde_json::from_str(json).unwrap();
assert_eq!(result.value, Maybe::Unset);
// Test JsonSchema implementation
use schemars::schema_for;
let schema = schema_for!(Maybe<String>);
let schema_json = serde_json::to_value(&schema).unwrap();
// Verify the description mentions that null is an explicit value
let description = schema_json["description"].as_str().unwrap();
assert!(
description.contains("null") && description.contains("explicit"),
"Schema description should mention that null is an explicit value. Got: {}",
description
);
}
}

View File

@@ -8,7 +8,7 @@ use settings_macros::MergeFrom;
use util::serde::default_true;
use crate::{
AllLanguageSettingsContent, DelayMs, ExtendingVec, Maybe, ProjectTerminalSettingsContent,
AllLanguageSettingsContent, DelayMs, ExtendingVec, ProjectTerminalSettingsContent,
SlashCommandSettings,
};
@@ -61,8 +61,8 @@ pub struct WorktreeSettingsContent {
///
/// Default: null
#[serde(default)]
#[serde(skip_serializing_if = "Maybe::is_unset")]
pub project_name: Maybe<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub project_name: Option<String>,
/// Whether to prevent this project from being shared in public channels.
///
@@ -196,7 +196,7 @@ pub struct SessionSettingsContent {
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom, Debug)]
#[serde(tag = "source", rename_all = "snake_case")]
#[serde(untagged, rename_all = "snake_case")]
pub enum ContextServerSettingsContent {
Custom {
/// Whether the context server is enabled.
@@ -206,6 +206,16 @@ pub enum ContextServerSettingsContent {
#[serde(flatten)]
command: ContextServerCommand,
},
Http {
/// Whether the context server is enabled.
#[serde(default = "default_true")]
enabled: bool,
/// The URL of the remote context server.
url: String,
/// Optional headers to send.
#[serde(skip_serializing_if = "HashMap::is_empty", default)]
headers: HashMap<String, String>,
},
Extension {
/// Whether the context server is enabled.
#[serde(default = "default_true")]
@@ -217,19 +227,24 @@ pub enum ContextServerSettingsContent {
settings: serde_json::Value,
},
}
impl ContextServerSettingsContent {
pub fn set_enabled(&mut self, enabled: bool) {
match self {
ContextServerSettingsContent::Custom {
enabled: custom_enabled,
command: _,
..
} => {
*custom_enabled = enabled;
}
ContextServerSettingsContent::Extension {
enabled: ext_enabled,
settings: _,
..
} => *ext_enabled = enabled,
ContextServerSettingsContent::Http {
enabled: remote_enabled,
..
} => *remote_enabled = enabled,
}
}
}

View File

@@ -870,7 +870,7 @@ impl VsCodeSettings {
fn worktree_settings_content(&self) -> WorktreeSettingsContent {
WorktreeSettingsContent {
project_name: crate::Maybe::Unset,
project_name: None,
prevent_sharing_in_public_channels: false,
file_scan_exclusions: self
.read_value("files.watcherExclude")

View File

@@ -33,10 +33,10 @@ pub(crate) fn settings_data(cx: &App) -> Vec<SettingsPage> {
SettingField {
json_path: Some("project_name"),
pick: |settings_content| {
settings_content.project.worktree.project_name.as_ref()?.as_ref().or(DEFAULT_EMPTY_STRING)
settings_content.project.worktree.project_name.as_ref().or(DEFAULT_EMPTY_STRING)
},
write: |settings_content, value| {
settings_content.project.worktree.project_name = settings::Maybe::Set(value.filter(|name| !name.is_empty()));
settings_content.project.worktree.project_name = value.filter(|name| !name.is_empty());
},
}
),

View File

@@ -507,7 +507,6 @@ fn init_renderers(cx: &mut App) {
.add_basic_renderer::<settings::BufferLineHeightDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::AutosaveSettingDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::WorkingDirectoryDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::MaybeDiscriminants>(render_dropdown)
.add_basic_renderer::<settings::IncludeIgnoredContent>(render_dropdown)
.add_basic_renderer::<settings::ShowIndentGuides>(render_dropdown)
.add_basic_renderer::<settings::ShellDiscriminants>(render_dropdown)

View File

@@ -37,10 +37,6 @@ pub struct SweepFeatureFlag;
impl FeatureFlag for SweepFeatureFlag {
const NAME: &str = "sweep-ai";
fn enabled_for_staff() -> bool {
false
}
}
#[derive(Clone)]

View File

@@ -77,6 +77,7 @@ impl RenderOnce for Modal {
.w_full()
.flex_1()
.gap(DynamicSpacing::Base08.rems(cx))
.when(self.footer.is_some(), |this| this.pb_4())
.when_some(
self.container_scroll_handler,
|this, container_scroll_handle| {
@@ -276,7 +277,6 @@ impl RenderOnce for ModalFooter {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
h_flex()
.w_full()
.mt_4()
.p(DynamicSpacing::Base08.rems(cx))
.flex_none()
.justify_between()

View File

@@ -35,7 +35,6 @@ multi_buffer.workspace = true
nvim-rs = { git = "https://github.com/KillTheMule/nvim-rs", rev = "764dd270c642f77f10f3e19d05cc178a6cbe69f3", features = ["use_tokio"], optional = true }
picker.workspace = true
project.workspace = true
project_panel.workspace = true
regex.workspace = true
schemars.workspace = true
search.workspace = true

View File

@@ -183,8 +183,6 @@ actions!(
InnerObject,
/// Maximizes the current pane.
MaximizePane,
/// Opens the default keymap file.
OpenDefaultKeymap,
/// Resets all pane sizes to default.
ResetPaneSizes,
/// Resizes the pane to the right.
@@ -314,7 +312,7 @@ pub fn init(cx: &mut App) {
workspace.register_action(|_, _: &ToggleProjectPanelFocus, window, cx| {
if Vim::take_count(cx).is_none() {
window.dispatch_action(project_panel::ToggleFocus.boxed_clone(), cx);
window.dispatch_action(zed_actions::project_panel::ToggleFocus.boxed_clone(), cx);
}
});
@@ -343,7 +341,7 @@ pub fn init(cx: &mut App) {
};
});
workspace.register_action(|_, _: &OpenDefaultKeymap, _, cx| {
workspace.register_action(|_, _: &zed_actions::vim::OpenDefaultKeymap, _, cx| {
cx.emit(workspace::Event::OpenBundledFile {
text: settings::vim_keymap(),
title: "Default Vim Bindings",

View File

@@ -66,7 +66,7 @@ impl Settings for WorktreeSettings {
.collect();
Self {
project_name: worktree.project_name.into_inner(),
project_name: worktree.project_name,
prevent_sharing_in_public_channels: worktree.prevent_sharing_in_public_channels,
file_scan_exclusions: path_matchers(file_scan_exclusions, "file_scan_exclusions")
.log_err()

View File

@@ -1002,7 +1002,7 @@ fn register_actions(
.register_action(open_project_debug_tasks_file)
.register_action(
|workspace: &mut Workspace,
_: &project_panel::ToggleFocus,
_: &zed_actions::project_panel::ToggleFocus,
window: &mut Window,
cx: &mut Context<Workspace>| {
workspace.toggle_panel_focus::<ProjectPanel>(window, cx);
@@ -4657,133 +4657,6 @@ mod tests {
});
}
/// Checks that action namespaces are the expected set. The purpose of this is to prevent typos
/// and let you know when introducing a new namespace.
#[gpui::test]
async fn test_action_namespaces(cx: &mut gpui::TestAppContext) {
use itertools::Itertools;
init_keymap_test(cx);
cx.update(|cx| {
let all_actions = cx.all_action_names();
let mut actions_without_namespace = Vec::new();
let all_namespaces = all_actions
.iter()
.filter_map(|action_name| {
let namespace = action_name
.split("::")
.collect::<Vec<_>>()
.into_iter()
.rev()
.skip(1)
.rev()
.join("::");
if namespace.is_empty() {
actions_without_namespace.push(*action_name);
}
if &namespace == "test_only" || &namespace == "stories" {
None
} else {
Some(namespace)
}
})
.sorted()
.dedup()
.collect::<Vec<_>>();
assert_eq!(actions_without_namespace, Vec::<&str>::new());
let expected_namespaces = vec![
"action",
"activity_indicator",
"agent",
#[cfg(not(target_os = "macos"))]
"app_menu",
"assistant",
"assistant2",
"auto_update",
"bedrock",
"branches",
"buffer_search",
"channel_modal",
"cli",
"client",
"collab",
"collab_panel",
"command_palette",
"console",
"context_server",
"copilot",
"debug_panel",
"debugger",
"dev",
"diagnostics",
"edit_prediction",
"editor",
"feedback",
"file_finder",
"git",
"git_onboarding",
"git_panel",
"go_to_line",
"icon_theme_selector",
"journal",
"keymap_editor",
"keystroke_input",
"language_selector",
"line_ending_selector",
"lsp_tool",
"markdown",
"menu",
"notebook",
"notification_panel",
"onboarding",
"outline",
"outline_panel",
"pane",
"panel",
"picker",
"project_panel",
"project_search",
"project_symbols",
"projects",
"repl",
"rules_library",
"search",
"settings_editor",
"settings_profile_selector",
"snippets",
"stash_picker",
"supermaven",
"svg",
"syntax_tree_view",
"tab_switcher",
"task",
"terminal",
"terminal_panel",
"theme_selector",
"toast",
"toolchain",
"variable_list",
"vim",
"window",
"workspace",
"zed",
"zed_actions",
"zed_predict_onboarding",
"zeta",
];
assert_eq!(
all_namespaces,
expected_namespaces
.into_iter()
.map(|namespace| namespace.to_string())
.sorted()
.collect::<Vec<_>>()
);
});
}
#[gpui::test]
fn test_bundled_settings_and_themes(cx: &mut App) {
cx.text_system()

View File

@@ -39,7 +39,7 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
],
}),
MenuItem::separator(),
MenuItem::action("Project Panel", project_panel::ToggleFocus),
MenuItem::action("Project Panel", zed_actions::project_panel::ToggleFocus),
MenuItem::action("Outline Panel", outline_panel::ToggleFocus),
MenuItem::action("Collab Panel", collab_panel::ToggleFocus),
MenuItem::action("Terminal Panel", terminal_panel::ToggleFocus),

View File

@@ -250,6 +250,17 @@ pub mod command_palette {
);
}
pub mod project_panel {
use gpui::actions;
actions!(
project_panel,
[
/// Toggles focus on the project panel.
ToggleFocus
]
);
}
pub mod feedback {
use gpui::actions;
@@ -532,6 +543,18 @@ actions!(
]
);
pub mod vim {
use gpui::actions;
actions!(
vim,
[
/// Opens the default keymap file.
OpenDefaultKeymap
]
);
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct WslConnectionOptions {
pub distro_name: String,

View File

@@ -2,6 +2,11 @@ use anyhow::{Context as _, Result};
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
use std::{cmp, ops::Range, path::Path, sync::Arc};
const EDITS_TAG_NAME: &'static str = "edits";
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
pub async fn parse_xml_edits<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
@@ -12,38 +17,22 @@ pub async fn parse_xml_edits<'a>(
}
async fn parse_xml_edits_inner<'a>(
mut input: &'a str,
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let edits_tag = parse_tag(&mut input, "edits")?.context("No edits tag")?;
let xml_edits = extract_xml_replacements(input)?;
input = edits_tag.body;
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
let file_path = edits_tag
.attributes
.trim_start()
.strip_prefix("path")
.context("no file attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');
let (buffer, context_ranges) = get_buffer(file_path.as_ref())
.with_context(|| format!("no buffer for file {file_path}"))?;
let mut edits = vec![];
while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? {
let new_text_tag =
parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?;
let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?;
let old_text = buffer
let mut all_edits = vec![];
for (old_text, new_text) in xml_edits.replacements {
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
let matched_old_text = buffer
.text_for_range(match_range.clone())
.collect::<String>();
let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body);
edits.extend(
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
all_edits.extend(
edits_within_hunk
.into_iter()
.map(move |(inner_range, inner_text)| {
@@ -56,7 +45,7 @@ async fn parse_xml_edits_inner<'a>(
);
}
Ok((buffer, edits))
Ok((buffer, all_edits))
}
fn fuzzy_match_in_ranges(
@@ -110,32 +99,128 @@ fn fuzzy_match_in_ranges(
);
}
struct ParsedTag<'a> {
attributes: &'a str,
body: &'a str,
#[derive(Debug)]
struct XmlEdits<'a> {
file_path: &'a str,
/// Vec of (old_text, new_text) pairs
replacements: Vec<(&'a str, &'a str)>,
}
fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result<Option<ParsedTag<'a>>> {
let open_tag = format!("<{}", tag);
let close_tag = format!("</{}>", tag);
let Some(start_ix) = input.find(&open_tag) else {
return Ok(None);
};
let start_ix = start_ix + open_tag.len();
let closing_bracket_ix = start_ix
+ input[start_ix..]
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
let mut cursor = 0;
let (edits_body_start, edits_attrs) =
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
let file_path = edits_attrs
.trim_start()
.strip_prefix("path")
.context("no path attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');
cursor = edits_body_start;
let mut edits_list = Vec::new();
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
let old_body_end = find_tag_close(input, &mut cursor)?;
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
.context("no new_text tag following old_text")?;
let new_body_end = find_tag_close(input, &mut cursor)?;
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
edits_list.push((old_text, new_text));
}
Ok(XmlEdits {
file_path,
replacements: edits_list,
})
}
/// Trims a single leading and trailing newline
fn trim_surrounding_newlines(input: &str) -> &str {
let start = input.strip_prefix('\n').unwrap_or(input);
let end = start.strip_suffix('\n').unwrap_or(start);
end
}
fn find_tag_open<'a>(
input: &'a str,
cursor: &mut usize,
expected_tag: &str,
) -> Result<Option<(usize, &'a str)>> {
let mut search_pos = *cursor;
while search_pos < input.len() {
let Some(tag_start) = input[search_pos..].find("<") else {
break;
};
let tag_start = search_pos + tag_start;
if !input[tag_start + 1..].starts_with(expected_tag) {
search_pos = tag_start + 1;
continue;
};
let after_tag_name = tag_start + expected_tag.len() + 1;
let close_bracket = input[after_tag_name..]
.find('>')
.with_context(|| format!("missing > after {tag}"))?;
let attributes = &input[start_ix..closing_bracket_ix].trim();
let end_ix = closing_bracket_ix
+ input[closing_bracket_ix..]
.find(&close_tag)
.with_context(|| format!("no `{close_tag}` tag"))?;
let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix];
let body = body.strip_prefix('\n').unwrap_or(body);
let body = body.strip_suffix('\n').unwrap_or(body);
*input = &input[end_ix + close_tag.len()..];
Ok(Some(ParsedTag { attributes, body }))
.with_context(|| format!("missing > after <{}", expected_tag))?;
let attrs_end = after_tag_name + close_bracket;
let body_start = attrs_end + 1;
let attributes = input[after_tag_name..attrs_end].trim();
*cursor = body_start;
return Ok(Some((body_start, attributes)));
}
Ok(None)
}
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
let mut depth = 1;
let mut search_pos = *cursor;
while search_pos < input.len() && depth > 0 {
let Some(bracket_offset) = input[search_pos..].find('<') else {
break;
};
let bracket_pos = search_pos + bracket_offset;
if input[bracket_pos..].starts_with("</")
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
{
let close_start = bracket_pos + 2;
let tag_name = input[close_start..close_start + close_end].trim();
if XML_TAGS.contains(&tag_name) {
depth -= 1;
if depth == 0 {
*cursor = close_start + close_end + 1;
return Ok(bracket_pos);
}
}
search_pos = close_start + close_end + 1;
continue;
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
let close_bracket_pos = bracket_pos + close_bracket_offset;
let tag_name = input[bracket_pos + 1..close_bracket_pos].trim();
if XML_TAGS.contains(&tag_name) {
depth += 1;
}
}
search_pos = bracket_pos + 1;
}
anyhow::bail!("no closing tag found")
}
const REPLACEMENT_COST: u32 = 1;
@@ -357,17 +442,128 @@ mod tests {
use util::path;
#[test]
fn test_parse_tags() {
let mut input = indoc! {r#"
Prelude
<tag attr="foo">
tag value
</tag>
"# };
let parsed = parse_tag(&mut input, "tag").unwrap().unwrap();
assert_eq!(parsed.attributes, "attr=\"foo\"");
assert_eq!(parsed.body, "tag value");
assert_eq!(input, "\n");
fn test_extract_xml_edits() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</old_text>
<new_text>
new content
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_wrong_closing_tags() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</new_text>
<new_text>
new content
</old_text>
</ edits >
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_xml_like_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<foo><bar></bar></foo>
</old_text>
<new_text>
<foo><bar><baz></baz></bar></foo>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
assert_eq!(
result.replacements[0].1,
"<foo><bar><baz></baz></bar></foo>"
);
}
#[test]
fn test_extract_xml_edits_with_conflicting_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<new_text></new_text>
</old_text>
<new_text>
<old_text></old_text>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
}
#[test]
fn test_extract_xml_edits_multiple_pairs() {
let input = indoc! {r#"
Some reasoning before edits. Lots of thinking going on here
<edits path="test.rs">
<old_text>
first old
</old_text>
<new_text>
first new
</new_text>
<old_text>
second old
</edits>
<new_text>
second new
</old_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 2);
assert_eq!(result.replacements[0].0, "first old");
assert_eq!(result.replacements[0].1, "first new");
assert_eq!(result.replacements[1].0, "second old");
assert_eq!(result.replacements[1].1, "second new");
}
#[test]
fn test_extract_xml_edits_unexpected_eof() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
first old
</
"#};
extract_xml_replacements(input).expect_err("Unexpected end of file");
}
#[gpui::test]

View File

@@ -183,7 +183,7 @@ macro_rules! time {
$crate::Timer::new($logger, $name)
};
($name:expr) => {
$crate::time!($crate::default_logger!() => $name)
time!($crate::default_logger!() => $name)
};
}

View File

@@ -126,6 +126,7 @@
- [Markdown](./languages/markdown.md)
- [Nim](./languages/nim.md)
- [OCaml](./languages/ocaml.md)
- [OpenTofu](./languages/opentofu.md)
- [PHP](./languages/php.md)
- [PowerShell](./languages/powershell.md)
- [Prisma](./languages/prisma.md)

View File

@@ -40,11 +40,14 @@ You can connect them by adding their commands directly to your `settings.json`,
```json [settings]
{
"context_servers": {
"your-mcp-server": {
"source": "custom",
"run-command": {
"command": "some-command",
"args": ["arg-1", "arg-2"],
"env": {}
},
"over-http": {
"url": "https://example.com/mcp",
"headers": { "Authorization": "Bearer <token>" }
}
}
}

View File

@@ -0,0 +1,20 @@
# OpenTofu
OpenTofu support is available through the [OpenTofu extension](https://github.com/ashpool37/zed-extension-opentofu).
- Tree-sitter: [MichaHoffmann/tree-sitter-hcl](https://github.com/MichaHoffmann/tree-sitter-hcl)
- Language Server: [opentofu/tofu-ls](https://github.com/opentofu/tofu-ls)
## Configuration
To automatically use the OpenTofu extension and language server when editing `.tf` and `.tfvars` files,
either uninstall the Terraform extension or add this to your `settings.json`:
```json
{
  "file_types": {
    "OpenTofu": ["tf"],
    "OpenTofu Vars": ["tfvars"]
  }
}
```
See the [tofu-ls settings documentation](https://github.com/opentofu/tofu-ls/blob/main/docs/SETTINGS.md) for the full list of server settings.