Compare commits

...

33 Commits

Author SHA1 Message Date
Joseph T. Lyons
99335128eb zed 0.181.3 2025-04-04 10:45:15 -04:00
Agus Zubiaga
e8457656c6 ai: Separate model settings for each feature (#28088)
Closes: https://github.com/zed-industries/zed/issues/20582

Allows users to select a specific model for each AI-powered feature:
- Agent panel
- Inline assistant
- Thread summarization
- Commit message generation

If unspecified for a given feature, it will use the `default_model`
setting.

Release Notes:

- Added support for configuring a specific model for each AI-powered
feature

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-04 10:43:27 -04:00
Thomas Mickley-Doyle
b037c913e5 assistant_eval: Add ACE framework (#27181)
Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <michael@zed.dev>
2025-04-04 10:43:17 -04:00
Agus Zubiaga
aca7f54773 agent: Add search to Thread History (#28085)
![CleanShot 2025-04-04 at 09 45
47@2x](https://github.com/user-attachments/assets/a8ec4086-f71e-4ff4-a5b3-4eb5d4c48294)


Release Notes:

- agent: Add search box to thread history

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
2025-04-04 09:28:14 -04:00
gcp-cherry-pick-bot[bot]
10a3ad078c Clear path-based excerpt data properly (cherry-pick #28026) (#28082)
Cherry-picked Clear path-based excerpt data properly (#28026)

Follow-up of https://github.com/zed-industries/zed/pull/27893

Release Notes:

- N/A

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-04 09:18:56 -04:00
gcp-cherry-pick-bot[bot]
c4bbdd03c5 Temporarily disable flaky conflicted-cherry-pick test (cherry-pick #27950) (#28087)
Cherry-picked Temporarily disable flaky conflicted-cherry-pick test
(#27950)

Closes #ISSUE

Release Notes:

- N/A

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-04 09:18:20 -04:00
Antonio Scandurra
27a47233c9 Implement edit rejection in ActionLog (#28080)
Release Notes:

- Fixed a bug that would prevent rejecting certain agent edits.
2025-04-04 08:32:26 -04:00
Bennet Bo Fenner
0b8c14f0d6 agent: Show which lines were read when using read_file tool (#28077)
This makes sure that we specify which lines the agent actually read,
avoids confusing scenarios such as:

<img width="642" alt="Screenshot 2025-04-04 at 10 22 10"
src="https://github.com/user-attachments/assets/2680c313-4f77-4971-8743-8e3f5327c18d"
/>

Here the agent starts out by actually only reading a certain amount of
lines when the first tool call happens, then it does a second tool call
to read the whole file. To the user this looks like to identical tool
calls.

Now:
<img width="621" alt="image"
src="https://github.com/user-attachments/assets/76222258-9cc8-4b7c-98c0-6d5cffb282f2"
/>
<img width="362" alt="image"
src="https://github.com/user-attachments/assets/293f2fc0-365d-4b84-8400-4c11474caeb8"
/>
<img width="420" alt="image"
src="https://github.com/user-attachments/assets/ca92493e-67ce-4d45-8f83-0168df575326"
/>



Release Notes:

- N/A
2025-04-04 08:31:37 -04:00
Bennet Bo Fenner
800f40524b agent: Differentiate @mentions from markdown links (#28073)
This ensures that we display @mentions and normal markdown links
differently:

<img width="670" alt="Screenshot 2025-04-04 at 11 07 51"
src="https://github.com/user-attachments/assets/0a4d0881-abb9-42a8-b3fa-912cd6873ae0"
/>


Release Notes:

- N/A
2025-04-04 08:31:19 -04:00
Antonio Scandurra
1f4a2d50a0 Scroll to first hunk when clicking on a file to review in Agent Panel (#28075)
Release Notes:

- Added the ability to scroll to a file when clicking on it in the Agent
Panel review section.
2025-04-04 08:31:14 -04:00
Marshall Bowers
c706acbfcb open_ai: Disable parallel_tool_calls (#28056)
This PR disables `parallel_tool_calls` for the models that support it,
as the Agent currently expects at most one tool use per turn.

It was a bit of trial and error to figure this out. OpenAI's API
annoyingly will return an error if passing `parallel_tool_calls` to a
model that doesn't support it.

Release Notes:

- N/A
2025-04-04 00:37:35 -04:00
Marshall Bowers
3f19ae1689 Add tool use support for OpenAI models (#28051)
This PR adds support for using tools to the OpenAI models.

Release Notes:

- agent: Added support for tool use with OpenAI models (Preview only).
2025-04-04 00:37:26 -04:00
Marshall Bowers
4744c4ff7e Remove unused extract_tool_args_from_events functions (#28038)
This PR removes the unused `extract_tool_args_from_events` functions
that were defined in some of the LLM provider crates.

Release Notes:

- N/A
2025-04-04 00:37:10 -04:00
Nate Butler
aa168c696b ui_input: TextField -> SingleLineInput (#28031)
- Rename `TextField` -> `SingleLineInput`
- Add a component preview for `SingleLineInput`
- Apply `SingleLineInput` to the AddContextServerModal

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
2025-04-03 23:57:13 -04:00
Agus Zubiaga
ce2918ef46 agent: Snapshot context in user message instead of recreating it (#27967)
This makes context essentially work the same way as `read-file`,
increasing the likelihood of cache hits.

Just like with `read-file`, we'll notify the model when the user makes
an edit to one of the tracked files. In the future, we want to send a
diff instead of just a list of files, but that's an orthogonal change.


Release Notes:
- agent: Improved caching of files in context

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
2025-04-03 23:56:49 -04:00
Joseph T. Lyons
3cfaadd2af zed 0.181.2 2025-04-03 14:45:38 -04:00
Danilo Leal
b69219e9e0 agent: Add token count in the thread view (#28037)
This PR adds the token count to the active thread view. It doesn't
behaves quite like Assistant 1 where it updates as you type, though; it
updates after you submit the message.

<img
src="https://github.com/user-attachments/assets/82d2a180-554a-43ee-b776-3743359b609b"
width="700" />

---

Release Notes:

- agent: Add token count in the thread view

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-03 14:44:32 -04:00
Antonio Scandurra
0ae3518b1c Fix soft-wrapping with fold creases (#28029)
Release Notes:

- Fixed a rendering bug that caused context in the agent to not wrap
properly.

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Zed AI <ai+claude-3.7@zed.dev>
2025-04-03 14:21:43 -04:00
Agus Zubiaga
73964353a4 agent: Handle tool use without text (#28030)
### Context 

The Anthropic API fails if a request message contains a tool use and no
`Text` segments or it only contains empty `Text` segments. These are
cases that the model itself produces, but the API doesn't support
sending them back.

#27917 fixed this by appending "Using tool..." in the thread's message,
but this causes the actual conversation to include it, so it would
appear in the UI (we would actually display a gap because we never
rendered its markdown, but "Using tool..." would show up when the thread
was restored).

### Solution

We'll now only append this placeholder when we build the request, so the
API still sees it, but the UI/Thread doesn't.

Another issue we found is that the model starts mimicking these
placeholders in later tool uses which is undesirable. So unfortunately,
we had to add logic to filter them out.

Release Notes:

- agent: Improved rendering of tool uses without text

---------

Co-authored-by: Bennet <bennet@zed.dev>
2025-04-03 14:20:56 -04:00
Danilo Leal
bf6e7cb6ee agent: Add button to continue iterating once all reviews are done (#28027)
This PR adds a button on the review tab empty state that toggles the
focus back to the agent panel so that users can keep iterating on the
thread that's active in the panel.

<img
src="https://github.com/user-attachments/assets/ace5cf93-8869-49bb-8106-e03a9e3c90f2"
width="700"/>

Release Notes:

- N/A
2025-04-03 14:20:37 -04:00
Joseph T. Lyons
36c4f6082c zed 0.181.1 2025-04-03 09:50:23 -04:00
Bennet Bo Fenner
3d032bcf2c agent: Fix thinking step showing up as pending when completion is cancelled (#28019)
Previously the "Thinking..." step would show up as pending, even though
the user cancelled the generation:
<img width="672" alt="image"
src="https://github.com/user-attachments/assets/c9cdce0a-d827-4e23-96f5-b150465911a7"
/>


Release Notes:

- Fixed an issue where the thinking step would show up as pending even
when the generation was cancelled
2025-04-03 09:49:24 -04:00
Agus Zubiaga
3e28fa2cc4 agent: Include active file in recent history (#27914)
This happened because of two reasons:

- `Workspace::recent_navigation_history` didn't include the current file
- The context picker added the current file to a exclude list

The latter was actually intentional because we already show the file in
the suggested context, but now that we actually have mentions, it's just
inconvenient not to have it there.

Release Notes:

- N/A
2025-04-03 09:49:18 -04:00
Julia Ryan
c42442974d workspace-hack: remove openssl from remote_server (#27990)
This was accidentally getting added due to increased feature
unification. We've manually excluded reqwest to go back to the desired
behavior: remote_server, doesn't depend on openssl.

Release Notes:

- N/A

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-04-03 08:54:26 -04:00
Julia Ryan
f23b972203 Add workspace-hack (#27277)
This adds a "workspace-hack" crate, see
[mozilla's](https://hg.mozilla.org/mozilla-central/file/3a265fdc9f33e5946f0ca0a04af73acd7e6d1a39/build/workspace-hack/Cargo.toml#l7)
for a concise explanation of why this is useful. For us in practice this
means that if I were to run all the tests (`cargo nextest r
--workspace`) and then `cargo r`, all the deps from the previous cargo
command will be reused. Before this PR it would rebuild many deps due to
resolving different sets of features for them. For me this frequently
caused long rebuilds when things "should" already be cached.

To avoid manually maintaining our workspace-hack crate, we will use
[cargo hakari](https://docs.rs/cargo-hakari) to update the build files
when there's a necessary change. I've added a step to CI that checks
whether the workspace-hack crate is up to date, and instructs you to
re-run `script/update-workspace-hack` when it fails.

Finally, to make sure that people can still depend on crates in our
workspace without pulling in all the workspace deps, we use a `[patch]`
section following [hakari's
instructions](https://docs.rs/cargo-hakari/0.9.36/cargo_hakari/patch_directive/index.html)

One possible followup task would be making guppy use our
`rust-toolchain.toml` instead of having to duplicate that list in its
config, I opened an issue for that upstream: guppy-rs/guppy#481.

TODO:
- [x] Fix the extension test failure
- [x] Ensure the dev dependencies aren't being unified by Hakari into
the main dependencies
- [x] Ensure that the remote-server binary continues to not depend on
LibSSL

Release Notes:

- N/A

---------

Co-authored-by: Mikayla <mikayla@zed.dev>
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-04-03 08:54:26 -04:00
5brian
07d9cd7e88 agent: Update thread label to use plural form (#27971)
Update thread label to match the other contexts.

|Before|After|
|--|--|

|![image](https://github.com/user-attachments/assets/6e02808e-50d7-480f-a9ca-251e9519a71d)|![image](https://github.com/user-attachments/assets/174aad84-9e55-4531-bb4a-1a1adaa46418)|

Release Notes:

- N/A
2025-04-03 08:54:01 -04:00
Marshall Bowers
a48238701d agent: Allow editing previous messages (#27965)
This PR adds the ability to edit previous user messages in the thread.

Release Notes:

- Agent: Added the ability to edit previous user messages
(Preview-only).
2025-04-03 08:54:01 -04:00
Danilo Leal
d17d747c62 agent: Change loading label if command is waiting on permission (#27955)
If there's a command pending confirmation, the label changes from
"Generating" to "Waiting for confirmation".

<img
src="https://github.com/user-attachments/assets/d804e382-5315-40b0-9588-c257cca2430c"
width="600"/>

Release Notes:

- N/A
2025-04-03 08:54:01 -04:00
Danilo Leal
bc08df2dfd agent: Refine feedback message input (#27948)
<img
src="https://github.com/user-attachments/assets/cde37a88-9973-4c27-80b7-459f5e986c74"
width="650" />

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-03 08:54:01 -04:00
Shardul Vaidya
0f4b734d91 aws_http_client: Copy response headers (#27941)
Preemptive fixes required for #26734

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
2025-04-03 08:53:37 -04:00
Marshall Bowers
6df29d3279 agent: Do some cleanup of feedback comments submission (#27940)
This PR does some stylistic cleanup of the feedback comments submission
code.

Release Notes:

- N/A
2025-04-03 08:53:29 -04:00
Michael Sloan
647cca8c8d Use worktree qualified paths in agent file context + some code cleanup (#27943)
Release Notes:

- N/A
2025-04-03 08:53:19 -04:00
Joseph T. Lyons
1f81674927 v0.181.x preview 2025-04-02 13:44:55 -04:00
258 changed files with 6601 additions and 2802 deletions

43
.config/hakari.toml Normal file
View File

@@ -0,0 +1,43 @@
# This file contains settings for `cargo hakari`.
# See https://docs.rs/cargo-hakari/latest/cargo_hakari/config for a full list of options.
hakari-package = "workspace-hack"
resolver = "2"
dep-format-version = "4"
workspace-hack-line-style = "workspace-dotted"
# this should be the same list as "targets" in ../rust-toolchain.toml
platforms = [
"x86_64-apple-darwin",
"aarch64-apple-darwin",
"x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu",
"x86_64-pc-windows-msvc",
"x86_64-unknown-linux-musl", # remote server
]
[traversal-excludes]
workspace-members = [
"remote_server",
]
third-party = [
{ name = "reqwest", version = "0.11.27" },
]
[final-excludes]
workspace-members = [
"zed_extension_api",
# exclude all extensions
"zed_emmet",
"zed_glsl",
"zed_html",
"perplexity",
"zed_proto",
"zed_ruff",
"slash_commands_example",
"zed_snippets",
"zed_test_extension",
"zed_toml",
]

View File

@@ -110,6 +110,37 @@ jobs:
input: "crates/proto/proto/"
against: "https://github.com/${GITHUB_REPOSITORY}.git#branch=${BUF_BASE_BRANCH},subdir=crates/proto/proto/"
workspace_hack:
timeout-minutes: 60
name: Check workspace-hack crate
needs: [job_spec]
if: github.repository_owner == 'zed-industries'
runs-on:
- buildjet-8vcpu-ubuntu-2204
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install cargo-hakari
uses: clechasseur/rs-cargo@8435b10f6e71c2e3d4d3b7573003a8ce4bfc6386 # v2
with:
command: install
args: cargo-hakari@0.9.35
- name: Check workspace-hack Cargo.toml is up-to-date
run: |
cargo hakari generate --diff || {
echo "To fix, run script/update-workspace-hack";
false
}
- name: Check all crates depend on workspace-hack
run: |
cargo hakari manage-deps --dry-run || {
echo "To fix, run script/update-workspace-hack"
false
}
style:
timeout-minutes: 60
name: Check formatting and spelling
@@ -432,6 +463,7 @@ jobs:
- job_spec
- style
- migration_checks
- workspace_hack
- linux_tests
- build_remote_server
- macos_tests

1817
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -192,6 +192,7 @@ members = [
# Tooling
#
"tooling/workspace-hack",
"tooling/xtask",
]
default-members = ["crates/zed"]
@@ -578,6 +579,7 @@ unicode-script = "0.5.7"
url = "2.2"
urlencoding = "2.1.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
walkdir = "2.3"
wasmparser = "0.221"
wasm-encoder = "0.221"
wasmtime = { version = "29", default-features = false, features = [
@@ -590,6 +592,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.4"
zstd = "0.11"
metal = "0.29"
@@ -660,6 +663,9 @@ real-async-tls = { git = "https://github.com/zed-industries/async-tls", rev = "1
notify = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" }
notify-types = { git = "https://github.com/zed-industries/notify.git", rev = "bbb9ea5ae52b253e095737847e367c30653a2e96" }
# Makes the workspace hack crate refer to the local one, but only when you're building locally
workspace-hack = { path = "tooling/workspace-hack" }
[profile.dev]
split-debuginfo = "unpacked"
debug = "limited"
@@ -772,4 +778,4 @@ let_underscore_future = "allow"
too_many_arguments = "allow"
[workspace.metadata.cargo-machete]
ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme"]
ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme", "workspace-hack"]

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-forward-icon lucide-forward"><polyline points="15 17 20 12 15 7"/><path d="M4 18v-2a4 4 0 0 1 4-4h12"/></svg>

After

Width:  |  Height:  |  Size: 312 B

View File

@@ -657,6 +657,15 @@
"alt-enter": "editor::Newline"
}
},
{
"context": "AgentFeedbackMessageEditor > Editor",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel",
"enter": "menu::Confirm",
"alt-enter": "editor::Newline"
}
},
{
"context": "ContextStrip",
"bindings": {

View File

@@ -317,6 +317,15 @@
"alt-enter": "editor::Newline"
}
},
{
"context": "AgentFeedbackMessageEditor > Editor",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel",
"enter": "menu::Confirm",
"alt-enter": "editor::Newline"
}
},
{
"context": "ContextStrip",
"use_key_equivalents": true,

View File

@@ -25,6 +25,7 @@ smallvec.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -81,11 +81,13 @@ theme.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
ui_input.workspace = true
util.workspace = true
uuid.workspace = true
vim_mode_setting.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View File

@@ -21,7 +21,7 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::ProjectItem as _;
use settings::{Settings as _, update_settings_file};
@@ -34,7 +34,7 @@ use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip,
use util::ResultExt as _;
use workspace::{OpenOptions, Workspace};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_store::ContextStore;
pub struct ActiveThread {
language_registry: Arc<LanguageRegistry>,
@@ -55,8 +55,7 @@ pub struct ActiveThread {
notifications: Vec<WindowHandle<AgentNotification>>,
_subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
showing_feedback_comments: bool,
feedback_comments_editor: Option<Entity<Editor>>,
feedback_message_editor: Option<Entity<Editor>>,
}
struct RenderedMessage {
@@ -246,6 +245,17 @@ fn render_markdown(
}),
..Default::default()
},
link_callback: Some(Rc::new(move |url, cx| {
if MentionLink::is_valid(url) {
let colors = cx.theme().colors();
Some(TextStyleRefinement {
background_color: Some(colors.element_background),
..Default::default()
})
} else {
None
}
})),
..Default::default()
};
@@ -321,6 +331,7 @@ fn open_markdown_link(
});
}
}),
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
None => cx.open_url(&text),
}
}
@@ -371,8 +382,7 @@ impl ActiveThread {
notifications: Vec::new(),
_subscriptions: subscriptions,
notification_subscriptions: HashMap::default(),
showing_feedback_comments: false,
feedback_comments_editor: None,
feedback_message_editor: None,
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@@ -595,54 +605,14 @@ impl ActiveThread {
}
if self.thread.read(cx).all_tools_finished() {
let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
thread.action_log().update(cx, |action_log, _cx| {
action_log.take_stale_buffers_in_context()
})
});
let context_update_task = if !pending_refresh_buffers.is_empty() {
let refresh_task = refresh_context_store_text(
self.context_store.clone(),
&pending_refresh_buffers,
cx,
);
cx.spawn(async move |this, cx| {
let updated_context_ids = refresh_task.await;
this.update(cx, |this, cx| {
this.context_store.read_with(cx, |context_store, _cx| {
context_store
.context()
.iter()
.filter(|context| {
updated_context_ids.contains(&context.id())
})
.cloned()
.collect()
})
})
})
} else {
Task::ready(anyhow::Ok(Vec::new()))
};
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
cx.spawn(async move |this, cx| {
let updated_context = context_update_task.await?;
this.update(cx, |this, cx| {
this.thread.update(cx, |thread, cx| {
thread.attach_tool_results(updated_context, cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
})
})
.detach();
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
}
}
}
@@ -844,38 +814,21 @@ impl ActiveThread {
}
});
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if model.provider.must_accept_terms(cx) {
cx.notify();
return;
}
self.thread.update(cx, |thread, cx| {
thread.send_to_model(model, RequestKind::Chat, cx)
thread.send_to_model(model.model, RequestKind::Chat, cx)
});
cx.notify();
}
fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
self.messages
.iter()
.rev()
.find(|message_id| {
self.thread
.read(cx)
.message(**message_id)
.map_or(false, |message| message.role == Role::User)
})
.cloned()
}
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
self.messages
.iter()
@@ -923,77 +876,59 @@ impl ActiveThread {
}
fn handle_show_feedback_comments(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.showing_feedback_comments = true;
if self.feedback_comments_editor.is_none() {
let buffer = cx.new(|cx| {
let empty_string = String::new();
MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx)
});
let editor = cx.new(|cx| {
Editor::new(
editor::EditorMode::AutoHeight { max_lines: 4 },
buffer,
None,
window,
cx,
)
});
self.feedback_comments_editor = Some(editor);
if self.feedback_message_editor.is_some() {
return;
}
let buffer = cx.new(|cx| {
let empty_string = String::new();
MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx)
});
let editor = cx.new(|cx| {
let mut editor = Editor::new(
editor::EditorMode::AutoHeight { max_lines: 4 },
buffer,
None,
window,
cx,
);
editor.set_placeholder_text(
"What went wrong? Share your feedback so we can improve.",
cx,
);
editor
});
editor.read(cx).focus_handle(cx).focus(window);
self.feedback_message_editor = Some(editor);
cx.notify();
}
fn handle_submit_comments(
&mut self,
_: &ClickEvent,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(editor) = self.feedback_comments_editor.clone() {
let comments = editor.read(cx).text(cx);
fn submit_feedback_message(&mut self, cx: &mut Context<Self>) {
let Some(editor) = self.feedback_message_editor.clone() else {
return;
};
// Submit negative feedback
let report = self.thread.update(cx, |thread, cx| {
thread.report_feedback(ThreadFeedback::Negative, cx)
});
let report_task = self.thread.update(cx, |thread, cx| {
thread.report_feedback(ThreadFeedback::Negative, cx)
});
if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone();
let comments_value = String::from(comments.as_str());
let comments = editor.read(cx).text(cx);
if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone();
// Log comments as a separate telemetry event
telemetry::event!(
"Assistant Thread Feedback Comments",
thread_id,
comments = comments_value
);
}
self.showing_feedback_comments = false;
self.feedback_comments_editor = None;
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
telemetry::event!("Assistant Thread Feedback Comments", thread_id, comments);
}
}
fn handle_cancel_comments(
&mut self,
_: &ClickEvent,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.showing_feedback_comments = false;
self.feedback_comments_editor = None;
cx.notify();
self.feedback_message_editor = None;
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report_task.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
@@ -1021,8 +956,7 @@ impl ActiveThread {
return Empty.into_any();
}
let allow_editing_message =
message.role == Role::User && self.last_user_message(cx) == Some(message_id);
let allow_editing_message = message.role == Role::User;
let edit_message_editor = self
.editing_message
@@ -1133,36 +1067,47 @@ impl ActiveThread {
.into_any_element(),
};
let message_content = v_flex()
.gap_1p5()
.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
} else {
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
cx,
))
},
)
.when(!context.is_empty(), |parent| {
parent.child(
h_flex()
.flex_wrap()
.gap_1()
.children(context.into_iter().map(|context| {
let context_id = context.id();
ContextPill::added(AddedContext::new(context, cx), false, false, None)
let message_is_empty = message.should_display_content();
let has_content = !message_is_empty || !context.is_empty();
let message_content =
has_content.then(|| {
v_flex()
.gap_1p5()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
div()
.key_context("EditMessageEditor")
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.child(edit_message_editor)
.into_any()
} else {
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
cx,
))
.into_any()
},
)
})
.when(!context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
context.into_iter().map(|context| {
let context_id = context.id();
ContextPill::added(
AddedContext::new(context, cx),
false,
false,
None,
)
.on_click(Rc::new(cx.listener({
let workspace = workspace.clone();
let context_store = context_store.clone();
@@ -1179,8 +1124,9 @@ impl ActiveThread {
}
}
})))
})),
)
}),
))
})
});
let styled_message = match message.role {
@@ -1229,10 +1175,6 @@ impl ActiveThread {
)
.child(
h_flex()
// DL: To double-check whether we want to fully remove
// the editing feature from meassages. Checkpoint sort of
// solve the same problem.
.invisible()
.gap_1()
.when_some(
edit_message_editor.clone(),
@@ -1299,7 +1241,7 @@ impl ActiveThread {
),
),
)
.child(div().p_2().child(message_content)),
.child(div().p_2().children(message_content)),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
@@ -1308,7 +1250,9 @@ impl ActiveThread {
.pr_4()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.child(message_content)
.children(message_content)
.gap_2p5()
.pb_2p5()
.when(!tool_uses.is_empty(), |parent| {
parent.child(
v_flex().children(
@@ -1322,7 +1266,7 @@ impl ActiveThread {
v_flex()
.bg(colors.editor_background)
.rounded_sm()
.child(div().p_4().child(message_content)),
.child(div().p_4().children(message_content)),
),
};
@@ -1404,51 +1348,76 @@ impl ActiveThread {
.when(
show_feedback && !self.thread.read(cx).is_generating(),
|parent| {
parent
.child(feedback_items)
.when(self.showing_feedback_comments, |parent| {
parent.child(feedback_items).when_some(
self.feedback_message_editor.clone(),
|parent, feedback_editor| {
let focus_handle = feedback_editor.focus_handle(cx);
parent.child(
v_flex()
.gap_1()
.px_4()
.child(
Label::new(
"Please share your feedback to help us improve:",
)
.size(LabelSize::Small),
)
.child(
div()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(
self.feedback_comments_editor
.as_ref()
.unwrap()
.clone(),
),
)
.key_context("AgentFeedbackMessageEditor")
.on_action(cx.listener(|this, _: &menu::Cancel, _, cx| {
this.feedback_message_editor = None;
cx.notify();
}))
.on_action(cx.listener(|this, _: &menu::Confirm, _, cx| {
this.submit_feedback_message(cx);
cx.notify();
}))
.on_action(cx.listener(Self::confirm_editing_message))
.mx_4()
.mb_3()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(feedback_editor)
.child(
h_flex()
.gap_1()
.justify_end()
.pb_2()
.child(
Button::new("cancel-comments", "Cancel").on_click(
cx.listener(Self::handle_cancel_comments),
),
Button::new("dismiss-feedback-message", "Cancel")
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Cancel,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(cx.listener(|this, _, _, cx| {
this.feedback_message_editor = None;
cx.notify();
})),
)
.child(
Button::new("submit-comments", "Submit").on_click(
cx.listener(Self::handle_submit_comments),
),
Button::new(
"submit-feedback-message",
"Share Feedback",
)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(
&menu::Confirm,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(cx.listener(|this, _, _, cx| {
this.submit_feedback_message(cx);
cx.notify();
})),
),
),
)
})
},
)
},
)
.into_any()
@@ -1462,7 +1431,8 @@ impl ActiveThread {
cx: &Context<Self>,
) -> impl IntoElement {
let is_last_message = self.messages.last() == Some(&message_id);
let pending_thinking_segment_index = if is_last_message && !has_tool_uses {
let is_generating = self.thread.read(cx).is_generating();
let pending_thinking_segment_index = if is_generating && is_last_message && !has_tool_uses {
rendered_message
.segments
.iter()
@@ -1828,7 +1798,7 @@ impl ActiveThread {
div().map(|element| {
if !tool_use.needs_confirmation {
element.py_2p5().child(
element.child(
v_flex()
.child(
h_flex()
@@ -1900,145 +1870,164 @@ impl ActiveThread {
}),
)
} else {
element.py_2().child(
v_flex()
.rounded_lg()
.border_1()
.border_color(self.tool_card_border_color(cx))
.overflow_hidden()
.child(
h_flex()
.group("disclosure-header")
.relative()
.justify_between()
.py_1()
.map(|element| {
if is_status_finished {
element.pl_2().pr_0p5()
} else {
element.px_2()
}
})
.bg(self.tool_card_header_bg(cx))
.map(|element| {
if is_open {
element.border_b_1().rounded_t_md()
} else if needs_confirmation {
element.rounded_t_md()
} else {
element.rounded_md()
}
})
.border_color(self.tool_card_border_color(cx))
.child(
h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(tool_use.icon)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
h_flex().pr_8().text_ui_sm(cx).children(
self.rendered_tool_use_labels
.get(&tool_use.id)
.cloned(),
),
),
)
.child(
h_flex()
.gap_1()
.child(
div().visible_on_hover("disclosure-header").child(
Disclosure::new("tool-use-disclosure", is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let tool_use_id = tool_use.id.clone();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_tool_uses
.entry(tool_use_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
),
)
.child(status_icons),
)
.child(gradient_overlay(self.tool_card_header_bg(cx))),
)
.map(|parent| {
if !is_open {
return parent;
}
parent.child(
v_flex()
.bg(cx.theme().colors().editor_background)
.map(|element| {
if needs_confirmation {
element.rounded_none()
} else {
element.rounded_b_lg()
}
})
.child(results_content),
)
})
.when(needs_confirmation, |this| {
this.child(
v_flex()
.rounded_lg()
.border_1()
.border_color(self.tool_card_border_color(cx))
.overflow_hidden()
.child(
h_flex()
.group("disclosure-header")
.relative()
.justify_between()
.py_1()
.map(|element| {
if is_status_finished {
element.pl_2().pr_0p5()
} else {
element.px_2()
}
})
.bg(self.tool_card_header_bg(cx))
.map(|element| {
if is_open {
element.border_b_1().rounded_t_md()
} else if needs_confirmation {
element.rounded_t_md()
} else {
element.rounded_md()
}
})
.border_color(self.tool_card_border_color(cx))
.child(
h_flex()
.py_1()
.pl_2()
.pr_1()
.gap_1()
.justify_between()
.bg(cx.theme().colors().editor_background)
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(Label::new("Action Confirmation").color(Color::Muted).size(LabelSize::Small))
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
h_flex()
.gap_0p5()
.child({
let tool_id = tool_use.id.clone();
Button::new(
"always-allow-tool-action",
"Always Allow",
Icon::new(tool_use.icon)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
h_flex().pr_8().text_ui_sm(cx).children(
self.rendered_tool_use_labels
.get(&tool_use.id)
.cloned(),
),
),
)
.child(
h_flex()
.gap_1()
.child(
div().visible_on_hover("disclosure-header").child(
Disclosure::new("tool-use-disclosure", is_open)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let tool_use_id = tool_use.id.clone();
move |this, _event, _window, _cx| {
let is_open = this
.expanded_tool_uses
.entry(tool_use_id.clone())
.or_insert(false);
*is_open = !*is_open;
}
})),
),
)
.child(status_icons),
)
.child(gradient_overlay(self.tool_card_header_bg(cx))),
)
.map(|parent| {
if !is_open {
return parent;
}
parent.child(
v_flex()
.bg(cx.theme().colors().editor_background)
.map(|element| {
if needs_confirmation {
element.rounded_none()
} else {
element.rounded_b_lg()
}
})
.child(results_content),
)
})
.when(needs_confirmation, |this| {
this.child(
h_flex()
.py_1()
.pl_2()
.pr_1()
.gap_1()
.justify_between()
.bg(cx.theme().colors().editor_background)
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(Label::new("Action Confirmation").color(Color::Muted).size(LabelSize::Small))
.child(
h_flex()
.gap_0p5()
.child({
let tool_id = tool_use.id.clone();
Button::new(
"always-allow-tool-action",
"Always Allow",
)
.label_size(LabelSize::Small)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Never ask for permission",
None,
"Restore the original behavior in your Agent Panel settings",
window,
cx,
)
.label_size(LabelSize::Small)
.icon(IconName::CheckDouble)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Never ask for permission",
None,
"Restore the original behavior in your Agent Panel settings",
})
.on_click(cx.listener(
move |this, event, window, cx| {
if let Some(fs) = fs.clone() {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
|settings, _| {
settings.set_always_allow_tool_actions(true);
},
);
}
this.handle_allow_tool(
tool_id.clone(),
event,
window,
cx,
)
})
},
))
})
.child(ui::Divider::vertical())
.child({
let tool_id = tool_use.id.clone();
Button::new("allow-tool-action", "Allow")
.label_size(LabelSize::Small)
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener(
move |this, event, window, cx| {
if let Some(fs) = fs.clone() {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
|settings, _| {
settings.set_always_allow_tool_actions(true);
},
);
}
this.handle_allow_tool(
tool_id.clone(),
event,
@@ -2047,52 +2036,31 @@ impl ActiveThread {
)
},
))
})
.child(ui::Divider::vertical())
.child({
let tool_id = tool_use.id.clone();
Button::new("allow-tool-action", "Allow")
.label_size(LabelSize::Small)
.icon(IconName::Check)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.on_click(cx.listener(
move |this, event, window, cx| {
this.handle_allow_tool(
tool_id.clone(),
event,
window,
cx,
)
},
))
})
.child({
let tool_id = tool_use.id.clone();
let tool_name: Arc<str> = tool_use.name.into();
Button::new("deny-tool", "Deny")
.label_size(LabelSize::Small)
.icon(IconName::Close)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener(
move |this, event, window, cx| {
this.handle_deny_tool(
tool_id.clone(),
tool_name.clone(),
event,
window,
cx,
)
},
))
}),
),
)
}),
)
})
.child({
let tool_id = tool_use.id.clone();
let tool_name: Arc<str> = tool_use.name.into();
Button::new("deny-tool", "Deny")
.label_size(LabelSize::Small)
.icon(IconName::Close)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Error)
.on_click(cx.listener(
move |this, event, window, cx| {
this.handle_deny_tool(
tool_id.clone(),
tool_name.clone(),
event,
window,
cx,
)
},
))
}),
),
)
})
}
})
}

View File

@@ -1,9 +1,9 @@
use crate::{Thread, ThreadEvent};
use crate::{Keep, Reject, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::HashSet;
use editor::{
AnchorRangeExt, Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
actions::{GoToHunk, GoToPreviousHunk},
scroll::Autoscroll,
};
@@ -26,6 +26,7 @@ use workspace::{
item::{BreadcrumbText, ItemEvent, TabContentParams},
searchable::SearchableItemHandle,
};
use zed_actions::assistant::ToggleFocus;
pub struct AgentDiff {
multibuffer: Entity<MultiBuffer>,
@@ -43,7 +44,7 @@ impl AgentDiff {
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Result<()> {
) -> Result<Entity<Self>> {
let existing_diff = workspace.update(cx, |workspace, cx| {
workspace
.items_of_type::<AgentDiff>(cx)
@@ -52,13 +53,15 @@ impl AgentDiff {
if let Some(existing_diff) = existing_diff {
workspace.update(cx, |workspace, cx| {
workspace.activate_item(&existing_diff, true, true, window, cx);
})
})?;
Ok(existing_diff)
} else {
let agent_diff =
cx.new(|cx| AgentDiff::new(thread.clone(), workspace.clone(), window, cx));
workspace.update(cx, |workspace, cx| {
workspace.add_item_to_center(Box::new(agent_diff), window, cx);
})
workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx);
})?;
Ok(agent_diff)
}
}
@@ -133,11 +136,11 @@ impl AgentDiff {
let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
for (buffer, diff_handle) in changed_buffers {
let Some(file) = buffer.read(cx).file().cloned() else {
if buffer.read(cx).file().is_none() {
continue;
};
}
let path_key = PathKey::namespaced(0, file.full_path(cx).into());
let path_key = PathKey::for_buffer(&buffer, cx);
paths_to_delete.remove(&path_key);
let snapshot = buffer.read(cx).snapshot();
@@ -240,6 +243,26 @@ impl AgentDiff {
}
}
pub fn move_to_path(&mut self, path_key: PathKey, window: &mut Window, cx: &mut Context<Self>) {
if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
self.editor.update(cx, |editor, cx| {
let first_hunk = editor
.diff_hunks_in_ranges(
&[position..editor::Anchor::max()],
&self.multibuffer.read(cx).read(cx),
)
.next();
if let Some(first_hunk) = first_hunk {
let first_hunk_start = first_hunk.multi_buffer_range().start;
editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
selections.select_anchor_ranges([first_hunk_start..first_hunk_start]);
})
}
});
}
}
fn keep(&mut self, _: &crate::Keep, window: &mut Window, cx: &mut Context<Self>) {
let ranges = self
.editor
@@ -327,13 +350,16 @@ impl AgentDiff {
self.update_selection(&diff_hunks_in_ranges, window, cx);
}
let point_ranges = ranges
.into_iter()
.map(|range| range.to_point(&snapshot))
.collect();
self.editor.update(cx, |editor, cx| {
editor.restore_hunks_in_ranges(point_ranges, window, cx)
});
for hunk in &diff_hunks_in_ranges {
let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer {
self.thread
.update(cx, |thread, cx| {
thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
})
.detach_and_log_err(cx);
}
}
}
fn update_selection(
@@ -553,11 +579,12 @@ impl Item for AgentDiff {
}
impl Render for AgentDiff {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_empty = self.multibuffer.read(cx).is_empty();
let focus_handle = &self.focus_handle;
div()
.track_focus(&self.focus_handle)
.track_focus(focus_handle)
.key_context(if is_empty { "EmptyPane" } else { "AgentDiff" })
.on_action(cx.listener(Self::keep))
.on_action(cx.listener(Self::reject))
@@ -568,7 +595,32 @@ impl Render for AgentDiff {
.items_center()
.justify_center()
.size_full()
.when(is_empty, |el| el.child("No changes to review"))
.when(is_empty, |el| {
el.child(
v_flex()
.items_center()
.gap_2()
.child("No changes to review")
.child(
Button::new("continue-iterating", "Continue Iterating")
.style(ButtonStyle::Filled)
.icon(IconName::ForwardArrow)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.full_width()
.key_binding(KeyBinding::for_action_in(
&ToggleFocus,
&focus_handle.clone(),
window,
cx,
))
.on_click(|_event, window, cx| {
window.dispatch_action(ToggleFocus.boxed_clone(), cx)
}),
),
)
})
.when(!is_empty, |el| el.child(self.editor.clone()))
}
}
@@ -604,7 +656,7 @@ fn render_diff_hunk_controls(
.disabled(is_created_file)
.key_binding(
KeyBinding::for_action_in(
&crate::Reject,
&Reject,
&editor.read(cx).focus_handle(cx),
window,
cx,
@@ -625,13 +677,8 @@ fn render_diff_hunk_controls(
}),
Button::new(("keep", row as u64), "Keep")
.key_binding(
KeyBinding::for_action_in(
&crate::Keep,
&editor.read(cx).focus_handle(cx),
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
KeyBinding::for_action_in(&Keep, &editor.read(cx).focus_handle(cx), window, cx)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click({
let agent_diff = agent_diff.clone();
@@ -942,7 +989,7 @@ mod tests {
Point::new(3, 0)..Point::new(3, 0)
);
// Restoring a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
// Rejecting a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
editor.update_in(cx, |editor, window, cx| {
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])

View File

@@ -1,17 +1,17 @@
use context_server::{ContextServerSettings, ServerCommand, ServerConfig};
use editor::Editor;
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, WeakEntity, prelude::*};
use serde_json::json;
use settings::update_settings_file;
use ui::{Modal, ModalFooter, ModalHeader, Section, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use workspace::{ModalView, Workspace};
use crate::AddContextServer;
pub struct AddContextServerModal {
workspace: WeakEntity<Workspace>,
name_editor: Entity<Editor>,
command_editor: Entity<Editor>,
name_editor: Entity<SingleLineInput>,
command_editor: Entity<SingleLineInput>,
}
impl AddContextServerModal {
@@ -33,15 +33,10 @@ impl AddContextServerModal {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let name_editor = cx.new(|cx| Editor::single_line(window, cx));
let command_editor = cx.new(|cx| Editor::single_line(window, cx));
name_editor.update(cx, |editor, cx| {
editor.set_placeholder_text("Context server name", cx);
});
command_editor.update(cx, |editor, cx| {
editor.set_placeholder_text("Command to run the context server", cx);
let name_editor =
cx.new(|cx| SingleLineInput::new(window, cx, "Your server name").label("Name"));
let command_editor = cx.new(|cx| {
SingleLineInput::new(window, cx, "Command").label("Command to run the context server")
});
Self {
@@ -52,8 +47,22 @@ impl AddContextServerModal {
}
fn confirm(&mut self, cx: &mut Context<Self>) {
let name = self.name_editor.read(cx).text(cx).trim().to_string();
let command = self.command_editor.read(cx).text(cx).trim().to_string();
let name = self
.name_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
let command = self
.command_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
if name.is_empty() || command.is_empty() {
return;
@@ -104,8 +113,8 @@ impl EventEmitter<DismissEvent> for AddContextServerModal {}
impl Render for AddContextServerModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_name_empty = self.name_editor.read(cx).text(cx).trim().is_empty();
let is_command_empty = self.command_editor.read(cx).text(cx).trim().is_empty();
let is_name_empty = self.name_editor.read(cx).is_empty(cx);
let is_command_empty = self.command_editor.read(cx).is_empty(cx);
div()
.elevation_3(cx)
@@ -122,18 +131,8 @@ impl Render for AddContextServerModal {
.header(ModalHeader::new().headline("Add Context Server"))
.section(
Section::new()
.child(
v_flex()
.gap_1()
.child(Label::new("Name"))
.child(self.name_editor.clone()),
)
.child(
v_flex()
.gap_1()
.child(Label::new("Command"))
.child(self.command_editor.clone()),
),
.child(self.name_editor.clone())
.child(self.command_editor.clone()),
)
.footer(
ModalFooter::new()

View File

@@ -202,43 +202,43 @@ impl PickerDelegate for ToolPickerDelegate {
let default_profile = self.profile.clone();
let tool = tool.clone();
move |settings, _cx| match settings {
AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
settings,
)) => {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
AssistantSettingsContent::Versioned(boxed) => {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
});
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
match tool.source {
ToolSource::Native => {
*profile.tools.entry(tool.name).or_default() = is_enabled;
}
ToolSource::ContextServer { id } => {
let preset = profile
.context_servers
.entry(id.clone().into())
.or_default();
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
}
}
}
}

View File

@@ -9,10 +9,17 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
}
impl AssistantModelSelector {
@@ -20,6 +27,7 @@ impl AssistantModelSelector {
fs: Arc<dyn Fs>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
window: &mut Window,
cx: &mut App,
) -> Self {
@@ -28,11 +36,32 @@ impl AssistantModelSelector {
let fs = fs.clone();
LanguageModelSelector::new(
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| settings.set_model(model.clone()),
);
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match model_type {
ModelType::Default => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model.clone());
},
);
}
ModelType::InlineAssistant => {
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _cx| {
settings.set_inline_assistant_model(
provider.clone(),
model_id.clone(),
);
},
);
}
}
},
window,
cx,
@@ -40,6 +69,7 @@ impl AssistantModelSelector {
}),
menu_handle,
focus_handle,
model_type,
}
}
@@ -50,10 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let focus_handle = self.focus_handle.clone();
let model_name = match active_model {
Some(model) => model.name().0,
let model_name = match model {
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
};

View File

@@ -1,5 +1,6 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, anyhow};
use assistant_context_editor::{
@@ -14,9 +15,9 @@ use client::zed_urls;
use editor::{Editor, MultiBuffer};
use fs::Fs;
use gpui::{
Action, AnyElement, App, AsyncWindowContext, Corner, Entity, EventEmitter, FocusHandle,
Focusable, FontWeight, KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity,
action_with_deprecated_aliases, prelude::*,
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task,
UpdateGlobal, WeakEntity, action_with_deprecated_aliases, prelude::*, pulsating_between,
};
use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
@@ -38,7 +39,7 @@ use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::message_editor::MessageEditor;
use crate::thread::{Thread, ThreadError, ThreadId};
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::{
@@ -227,7 +228,7 @@ impl AssistantPanel {
)
.unwrap(),
history_store: history_store.clone(),
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, cx)),
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
assistant_dropdown_menu_handle: PopoverMenuHandle::default(),
width: None,
height: None,
@@ -570,10 +571,8 @@ impl AssistantPanel {
match event {
AssistantConfigurationEvent::NewThread(provider) => {
if LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(true, |active_provider| {
active_provider.id() != provider.id()
})
.default_model()
.map_or(true, |model| model.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@@ -715,18 +714,21 @@ impl Panel for AssistantPanel {
impl AssistantPanel {
fn render_toolbar(&self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
let is_empty = thread.is_empty();
let active_thread = self.thread.read(cx);
let thread = active_thread.thread().read(cx);
let token_usage = thread.total_token_usage(cx);
let thread_id = thread.id().clone();
let thread_id = thread.thread().read(cx).id().clone();
let is_generating = thread.is_generating();
let is_empty = active_thread.is_empty();
let focus_handle = self.focus_handle(cx);
let title = match self.active_view {
ActiveView::Thread => {
if is_empty {
thread.summary_or_default(cx)
active_thread.summary_or_default(cx)
} else {
thread
active_thread
.summary(cx)
.unwrap_or_else(|| SharedString::from("Loading Summary…"))
}
@@ -742,6 +744,12 @@ impl AssistantPanel {
ActiveView::Configuration => "Settings".into(),
};
let show_token_count = match self.active_view {
ActiveView::Thread => !is_empty,
ActiveView::PromptEditor => self.context_editor.is_some(),
_ => false,
};
h_flex()
.id("assistant-toolbar")
.h(Tab::container_height(cx))
@@ -764,12 +772,67 @@ impl AssistantPanel {
.pl_2()
.gap_2()
.bg(cx.theme().colors().tab_bar_background)
.children(if matches!(self.active_view, ActiveView::PromptEditor) {
self.context_editor
.as_ref()
.and_then(|editor| render_remaining_tokens(editor, cx))
} else {
None
.when(show_token_count, |parent| match self.active_view {
ActiveView::Thread => {
if token_usage.total == 0 {
return parent;
}
let token_color = match token_usage.ratio {
TokenUsageRatio::Normal => Color::Muted,
TokenUsageRatio::Warning => Color::Warning,
TokenUsageRatio::Exceeded => Color::Error,
};
parent.child(
h_flex()
.gap_0p5()
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.total,
))
.size(LabelSize::Small)
.color(token_color)
.map(|label| {
if is_generating {
label
.with_animation(
"used-tokens-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(
0.6, 1.,
)),
|label, delta| label.alpha(delta),
)
.into_any()
} else {
label.into_any_element()
}
}),
)
.child(
Label::new("/").size(LabelSize::Small).color(Color::Muted),
)
.child(
Label::new(assistant_context_editor::humanize_token_count(
token_usage.max,
))
.size(LabelSize::Small)
.color(Color::Muted),
),
)
}
ActiveView::PromptEditor => {
let Some(editor) = self.context_editor.as_ref() else {
return parent;
};
let Some(element) = render_remaining_tokens(editor, cx) else {
return parent;
};
parent.child(element)
}
_ => parent,
})
.child(
h_flex()
@@ -857,16 +920,18 @@ impl AssistantPanel {
}
fn configuration_error(&self, cx: &App) -> Option<ConfigurationError> {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return Some(ConfigurationError::NoProvider);
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return Some(ConfigurationError::ProviderNotAuthenticated);
}
if provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
if model.provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(
model.provider,
));
}
None
@@ -1069,11 +1134,11 @@ impl AssistantPanel {
// TODO: Add keyboard navigation.
match entry {
HistoryEntry::Thread(thread) => {
PastThread::new(thread, cx.entity().downgrade(), false)
PastThread::new(thread, cx.entity().downgrade(), false, vec![])
.into_any_element()
}
HistoryEntry::Context(context) => {
PastContext::new(context, cx.entity().downgrade(), false)
PastContext::new(context, cx.entity().downgrade(), false, vec![])
.into_any_element()
}
}

View File

@@ -156,8 +156,9 @@ impl BufferCodegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)

View File

@@ -146,11 +146,11 @@ pub struct ContextSymbolId {
pub range: Range<Anchor>,
}
pub fn attach_context_to_message<'a>(
message: &mut LanguageModelRequestMessage,
/// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>,
cx: &App,
) {
) -> Option<String> {
let mut file_context = Vec::new();
let mut directory_context = Vec::new();
let mut symbol_context = Vec::new();
@@ -167,64 +167,78 @@ pub fn attach_context_to_message<'a>(
}
}
let mut context_chunks = Vec::new();
if file_context.is_empty()
&& directory_context.is_empty()
&& symbol_context.is_empty()
&& fetch_context.is_empty()
&& thread_context.is_empty()
{
return None;
}
let mut result = String::new();
result.push_str("\n<context>\n\
The following items were attached by the user. You don't need to use other tools to read them.\n\n");
if !file_context.is_empty() {
context_chunks.push("<files>\n");
result.push_str("<files>\n");
for context in file_context {
context_chunks.push(&context.context_buffer.text);
result.push_str(&context.context_buffer.text);
}
context_chunks.push("\n</files>\n");
result.push_str("</files>\n");
}
if !directory_context.is_empty() {
context_chunks.push("<directories>\n");
result.push_str("<directories>\n");
for context in directory_context {
for context_buffer in &context.context_buffers {
context_chunks.push(&context_buffer.text);
result.push_str(&context_buffer.text);
}
}
context_chunks.push("\n</directories>\n");
result.push_str("</directories>\n");
}
if !symbol_context.is_empty() {
context_chunks.push("<symbols>\n");
result.push_str("<symbols>\n");
for context in symbol_context {
context_chunks.push(&context.context_symbol.text);
result.push_str(&context.context_symbol.text);
result.push('\n');
}
context_chunks.push("\n</symbols>\n");
result.push_str("</symbols>\n");
}
if !fetch_context.is_empty() {
context_chunks.push("<fetched_urls>\n");
result.push_str("<fetched_urls>\n");
for context in &fetch_context {
context_chunks.push(&context.url);
context_chunks.push(&context.text);
result.push_str(&context.url);
result.push('\n');
result.push_str(&context.text);
result.push('\n');
}
context_chunks.push("\n</fetched_urls>\n");
result.push_str("</fetched_urls>\n");
}
// Need to own the SharedString for summary so that it can be referenced.
let mut thread_context_chunks = Vec::new();
if !thread_context.is_empty() {
context_chunks.push("<conversation_threads>\n");
result.push_str("<conversation_threads>\n");
for context in &thread_context {
thread_context_chunks.push(context.summary(cx));
thread_context_chunks.push(context.text.clone());
result.push_str(&context.summary(cx));
result.push('\n');
result.push_str(&context.text);
result.push('\n');
}
context_chunks.push("\n</conversation_threads>\n");
result.push_str("</conversation_threads>\n");
}
for chunk in &thread_context_chunks {
context_chunks.push(chunk);
}
result.push_str("</context>\n");
Some(result)
}
if !context_chunks.is_empty() {
message.content.push(
"\n<context>\n\
The following items were attached by the user. You don't need to use other tools to read them.\n\n".into(),
);
message.content.push(context_chunks.join("\n").into());
message.content.push("\n</context>\n".into());
pub fn attach_context_to_message<'a>(
message: &mut LanguageModelRequestMessage,
contexts: impl Iterator<Item = &'a AssistantContext>,
cx: &App,
) {
if let Some(context_string) = format_context_as_string(contexts, cx) {
message.content.push(context_string.into());
}
}

View File

@@ -76,7 +76,7 @@ impl ContextPickerMode {
Self::File => "Files & Directories",
Self::Symbol => "Symbols",
Self::Fetch => "Fetch",
Self::Thread => "Thread",
Self::Thread => "Threads",
}
}
@@ -360,73 +360,15 @@ impl ContextPicker {
}
fn recent_entries(&self, cx: &mut App) -> Vec<RecentEntry> {
let Some(workspace) = self.workspace.upgrade().map(|w| w.read(cx)) else {
let Some(workspace) = self.workspace.upgrade() else {
return vec![];
};
let Some(context_store) = self.context_store.upgrade().map(|cs| cs.read(cx)) else {
let Some(context_store) = self.context_store.upgrade() else {
return vec![];
};
let mut recent = Vec::with_capacity(6);
let mut current_files = context_store.file_paths(cx);
if let Some(active_path) = active_singleton_buffer_path(&workspace, cx) {
current_files.insert(active_path);
}
let project = workspace.project().read(cx);
recent.extend(
workspace
.recent_navigation_history_iter(cx)
.filter(|(path, _)| !current_files.contains(&path.path.to_path_buf()))
.take(4)
.filter_map(|(project_path, _)| {
project
.worktree_for_id(project_path.worktree_id, cx)
.map(|worktree| RecentEntry::File {
project_path,
path_prefix: worktree.read(cx).root_name().into(),
})
}),
);
let mut current_threads = context_store.thread_ids();
if let Some(active_thread) = workspace
.panel::<AssistantPanel>(cx)
.map(|panel| panel.read(cx).active_thread(cx))
{
current_threads.insert(active_thread.read(cx).id().clone());
}
let Some(thread_store) = self
.thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())
else {
return recent;
};
thread_store.update(cx, |thread_store, _cx| {
recent.extend(
thread_store
.threads()
.into_iter()
.filter(|thread| !current_threads.contains(&thread.id))
.take(2)
.map(|thread| {
RecentEntry::Thread(ThreadContextEntry {
id: thread.id,
summary: thread.summary,
})
}),
)
});
recent
recent_context_picker_entries(context_store, self.thread_store.clone(), workspace, cx)
}
}
@@ -480,16 +422,6 @@ fn supported_context_picker_modes(
modes
}
fn active_singleton_buffer_path(workspace: &Workspace, cx: &App) -> Option<PathBuf> {
let active_item = workspace.active_item(cx)?;
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let buffer = editor.buffer().read(cx).as_singleton()?;
let path = buffer.read(cx).file()?.path().to_path_buf();
Some(path)
}
fn recent_context_picker_entries(
context_store: Entity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
@@ -498,14 +430,8 @@ fn recent_context_picker_entries(
) -> Vec<RecentEntry> {
let mut recent = Vec::with_capacity(6);
let mut current_files = context_store.read(cx).file_paths(cx);
let current_files = context_store.read(cx).file_paths(cx);
let workspace = workspace.read(cx);
if let Some(active_path) = active_singleton_buffer_path(workspace, cx) {
current_files.insert(active_path);
}
let project = workspace.project().read(cx);
recent.extend(
@@ -683,24 +609,45 @@ fn fold_toggle(
pub enum MentionLink {
File(ProjectPath, Entry),
Symbol(ProjectPath, String),
Fetch(String),
Thread(ThreadId),
}
impl MentionLink {
const FILE: &str = "@file";
const SYMBOL: &str = "@symbol";
const THREAD: &str = "@thread";
const FETCH: &str = "@fetch";
const SEPARATOR: &str = ":";
pub fn is_valid(url: &str) -> bool {
url.starts_with(Self::FILE)
|| url.starts_with(Self::SYMBOL)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::THREAD)
}
pub fn for_file(file_name: &str, full_path: &str) -> String {
format!("[@{}](file:{})", file_name, full_path)
format!("[@{}]({}:{})", file_name, Self::FILE, full_path)
}
pub fn for_symbol(symbol_name: &str, full_path: &str) -> String {
format!("[@{}](symbol:{}:{})", symbol_name, full_path, symbol_name)
format!(
"[@{}]({}:{}:{})",
symbol_name,
Self::SYMBOL,
full_path,
symbol_name
)
}
pub fn for_fetch(url: &str) -> String {
format!("[@{}]({})", url, url)
format!("[@{}]({}:{})", url, Self::FETCH, url)
}
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}](thread:{})", thread.summary, thread.id)
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
}
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
@@ -723,17 +670,10 @@ impl MentionLink {
})
}
let (prefix, link, target) = {
let mut parts = link.splitn(3, ':');
let prefix = parts.next();
let link = parts.next();
let target = parts.next();
(prefix, link, target)
};
match (prefix, link, target) {
(Some("file"), Some(path), _) => {
let project_path = extract_project_path_from_link(path, workspace, cx)?;
let (prefix, argument) = link.split_once(Self::SEPARATOR)?;
match prefix {
Self::FILE => {
let project_path = extract_project_path_from_link(argument, workspace, cx)?;
let entry = workspace
.read(cx)
.project()
@@ -741,14 +681,16 @@ impl MentionLink {
.entry_for_path(&project_path, cx)?;
Some(MentionLink::File(project_path, entry))
}
(Some("symbol"), Some(path), Some(symbol_name)) => {
Self::SYMBOL => {
let (path, symbol) = argument.split_once(Self::SEPARATOR)?;
let project_path = extract_project_path_from_link(path, workspace, cx)?;
Some(MentionLink::Symbol(project_path, symbol_name.to_string()))
Some(MentionLink::Symbol(project_path, symbol.to_string()))
}
(Some("thread"), Some(thread_id), _) => {
let thread_id = ThreadId::from(thread_id);
Self::THREAD => {
let thread_id = ThreadId::from(argument);
Some(MentionLink::Thread(thread_id))
}
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
_ => None,
}
}

View File

@@ -890,10 +890,10 @@ mod tests {
assert_eq!(
current_completion_labels(editor),
&[
"editor dir/",
"seven.txt dir/b/",
"six.txt dir/b/",
"five.txt dir/b/",
"four.txt dir/a/",
"Files & Directories",
"Symbols",
"Fetch"
@@ -932,22 +932,22 @@ mod tests {
});
editor.update(&mut cx, |editor, cx| {
assert_eq!(editor.text(cx), "Lorem [@one.txt](file:dir/a/one.txt)",);
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt)",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 36)]
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
cx.simulate_input(" ");
editor.update(&mut cx, |editor, cx| {
assert_eq!(editor.text(cx), "Lorem [@one.txt](file:dir/a/one.txt) ",);
assert_eq!(editor.text(cx), "Lorem [@one.txt](@file:dir/a/one.txt) ",);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 36)]
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -956,12 +956,12 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:dir/a/one.txt) Ipsum ",
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum ",
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 36)]
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -970,12 +970,12 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:dir/a/one.txt) Ipsum @file ",
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum @file ",
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![Point::new(0, 6)..Point::new(0, 36)]
vec![Point::new(0, 6)..Point::new(0, 37)]
);
});
@@ -988,14 +988,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:dir/a/one.txt) Ipsum [@seven.txt](file:dir/b/seven.txt)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 36),
Point::new(0, 43)..Point::new(0, 77)
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
]
);
});
@@ -1005,14 +1005,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:dir/a/one.txt) Ipsum [@seven.txt](file:dir/b/seven.txt)\n@"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n@"
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 36),
Point::new(0, 43)..Point::new(0, 77)
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
]
);
});
@@ -1026,15 +1026,15 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](file:dir/a/one.txt) Ipsum [@seven.txt](file:dir/b/seven.txt)\n[@six.txt](file:dir/b/six.txt)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n[@seven.txt](@file:dir/b/seven.txt)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 36),
Point::new(0, 43)..Point::new(0, 77),
Point::new(1, 0)..Point::new(1, 30)
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71),
Point::new(1, 0)..Point::new(1, 35)
]
);
});

View File

@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap, HashSet};
use futures::future::join_all;
use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity};
use language::{Buffer, File};
use project::{ProjectItem, ProjectPath, Worktree};
use rope::Rope;
@@ -95,8 +95,8 @@ impl ContextStore {
project.open_buffer(project_path.clone(), cx)
})?;
let buffer_entity = open_buffer_task.await?;
let buffer_id = this.update(cx, |_, cx| buffer_entity.read(cx).remote_id())?;
let buffer = open_buffer_task.await?;
let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?;
let already_included = this.update(cx, |this, _cx| {
match this.will_include_buffer(buffer_id, &project_path.path) {
@@ -115,16 +115,8 @@ impl ContextStore {
return anyhow::Ok(());
}
let (buffer_info, text_task) = this.update(cx, |_, cx| {
let buffer = buffer_entity.read(cx);
collect_buffer_info_and_text(
project_path.path.clone(),
buffer_entity,
buffer,
None,
cx.to_async(),
)
})??;
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, None, cx))??;
let text = text_task.await;
@@ -138,23 +130,12 @@ impl ContextStore {
pub fn add_file_from_buffer(
&mut self,
buffer_entity: Entity<Buffer>,
buffer: Entity<Buffer>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let (buffer_info, text_task) = this.update(cx, |_, cx| {
let buffer = buffer_entity.read(cx);
let Some(file) = buffer.file() else {
return Err(anyhow!("Buffer has no path."));
};
collect_buffer_info_and_text(
file.path().clone(),
buffer_entity,
buffer,
None,
cx.to_async(),
)
})??;
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, None, cx))??;
let text = text_task.await;
@@ -169,10 +150,8 @@ impl ContextStore {
fn insert_file(&mut self, context_buffer: ContextBuffer) {
let id = self.next_context_id.post_inc();
self.files.insert(context_buffer.id, id);
self.context.push(AssistantContext::File(FileContext {
id,
context_buffer: context_buffer,
}));
self.context
.push(AssistantContext::File(FileContext { id, context_buffer }));
}
pub fn add_directory(
@@ -233,22 +212,13 @@ impl ContextStore {
let mut buffer_infos = Vec::new();
let mut text_tasks = Vec::new();
this.update(cx, |_, cx| {
for (path, buffer_entity) in files.into_iter().zip(buffers) {
// Skip all binary files and other non-UTF8 files
if let Ok(buffer_entity) = buffer_entity {
let buffer = buffer_entity.read(cx);
if let Some((buffer_info, text_task)) = collect_buffer_info_and_text(
path,
buffer_entity,
buffer,
None,
cx.to_async(),
)
.log_err()
{
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
}
// Skip all binary files and other non-UTF8 files
for buffer in buffers.into_iter().flatten() {
if let Some((buffer_info, text_task)) =
collect_buffer_info_and_text(buffer, None, cx).log_err()
{
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
}
}
anyhow::Ok(())
@@ -298,12 +268,8 @@ impl ContextStore {
cx: &mut Context<Self>,
) -> Task<Result<bool>> {
let buffer_ref = buffer.read(cx);
let Some(file) = buffer_ref.file() else {
return Task::ready(Err(anyhow!("Buffer has no path.")));
};
let Some(project_path) = buffer_ref.project_path(cx) else {
return Task::ready(Err(anyhow!("Buffer has no project path.")));
return Task::ready(Err(anyhow!("buffer has no path")));
};
if let Some(symbols_for_path) = self.symbols_by_path.get(&project_path) {
@@ -326,16 +292,11 @@ impl ContextStore {
}
}
let (buffer_info, collect_content_task) = match collect_buffer_info_and_text(
file.path().clone(),
buffer,
buffer_ref,
Some(symbol_enclosing_range.clone()),
cx.to_async(),
) {
Ok((buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
let (buffer_info, collect_content_task) =
match collect_buffer_info_and_text(buffer, Some(symbol_enclosing_range.clone()), cx) {
Ok((buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |this, cx| {
let content = collect_content_task.await;
@@ -616,16 +577,16 @@ pub enum FileInclusion {
// ContextBuffer without text.
struct BufferInfo {
buffer_entity: Entity<Buffer>,
file: Arc<dyn File>,
id: BufferId,
buffer: Entity<Buffer>,
file: Arc<dyn File>,
version: clock::Global,
}
fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer {
ContextBuffer {
id: info.id,
buffer: info.buffer_entity,
buffer: info.buffer,
file: info.file,
version: info.version,
text,
@@ -644,34 +605,37 @@ fn make_context_symbol(
id: ContextSymbolId { name, range, path },
buffer_version: info.version,
enclosing_range,
buffer: info.buffer_entity,
buffer: info.buffer,
text,
}
}
fn collect_buffer_info_and_text(
path: Arc<Path>,
buffer_entity: Entity<Buffer>,
buffer: &Buffer,
buffer: Entity<Buffer>,
range: Option<Range<Anchor>>,
cx: AsyncApp,
cx: &App,
) -> Result<(BufferInfo, Task<SharedString>)> {
let buffer_info = BufferInfo {
id: buffer.remote_id(),
buffer_entity,
file: buffer
.file()
.context("buffer context must have a file")?
.clone(),
version: buffer.version(),
};
let buffer_ref = buffer.read(cx);
let file = buffer_ref.file().context("file context must have a path")?;
// Important to collect version at the same time as content so that staleness logic is correct.
let version = buffer_ref.version();
let content = if let Some(range) = range {
buffer.text_for_range(range).collect::<Rope>()
buffer_ref.text_for_range(range).collect::<Rope>()
} else {
buffer.as_rope().clone()
buffer_ref.as_rope().clone()
};
let text_task = cx.background_spawn(async move { to_fenced_codeblock(&path, content) });
let buffer_info = BufferInfo {
buffer,
id: buffer_ref.remote_id(),
file: file.clone(),
version,
};
let full_path = file.full_path(cx);
let text_task = cx.background_spawn(async move { to_fenced_codeblock(&full_path, content) });
Ok((buffer_info, text_task))
}
@@ -920,16 +884,9 @@ fn refresh_context_buffer(
cx: &App,
) -> Option<impl Future<Output = ContextBuffer> + use<>> {
let buffer = context_buffer.buffer.read(cx);
let path = buffer_path_log_err(buffer, cx)?;
if buffer.version.changed_since(&context_buffer.version) {
let (buffer_info, text_task) = collect_buffer_info_and_text(
path,
context_buffer.buffer.clone(),
buffer,
None,
cx.to_async(),
)
.log_err()?;
let (buffer_info, text_task) =
collect_buffer_info_and_text(context_buffer.buffer.clone(), None, cx).log_err()?;
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
} else {
None
@@ -941,15 +898,12 @@ fn refresh_context_symbol(
cx: &App,
) -> Option<impl Future<Output = ContextSymbol> + use<>> {
let buffer = context_symbol.buffer.read(cx);
let path = buffer_path_log_err(buffer, cx)?;
let project_path = buffer.project_path(cx)?;
if buffer.version.changed_since(&context_symbol.buffer_version) {
let (buffer_info, text_task) = collect_buffer_info_and_text(
path,
context_symbol.buffer.clone(),
buffer,
Some(context_symbol.enclosing_range.clone()),
cx.to_async(),
cx,
)
.log_err()?;
let name = context_symbol.id.name.clone();

View File

@@ -4,6 +4,7 @@ use gpui::{Entity, prelude::*};
use crate::thread_store::{SerializedThreadMetadata, ThreadStore};
#[derive(Debug)]
pub enum HistoryEntry {
Thread(SerializedThreadMetadata),
Context(SavedContextMetadata),
@@ -21,25 +22,27 @@ impl HistoryEntry {
pub struct HistoryStore {
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
_subscriptions: Vec<gpui::Subscription>,
}
impl HistoryStore {
pub fn new(
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![
cx.observe(&thread_store, |_, _, cx| cx.notify()),
cx.observe(&context_store, |_, _, cx| cx.notify()),
];
Self {
thread_store,
context_store,
_subscriptions: subscriptions,
}
}
/// Returns the number of history entries.
pub fn entry_count(&self, cx: &mut Context<Self>) -> usize {
self.entries(cx).len()
}
pub fn entries(&self, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
let mut history_entries = Vec::new();

View File

@@ -239,8 +239,8 @@ impl InlineAssistant {
let is_authenticated = || {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
.inline_assistant_model()
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
@@ -279,8 +279,8 @@ impl InlineAssistant {
cx.spawn_in(window, async move |_workspace, cx| {
let Some(task) = cx.update(|_, cx| {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
.inline_assistant_model()
.map_or(None, |model| Some(model.provider.authenticate(cx)))
})?
else {
let answer = cx
@@ -401,14 +401,14 @@ impl InlineAssistant {
codegen_ranges.push(anchor_range);
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
phase: AssistantPhase::Invoked,
message_id: None,
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.provider.id().to_string(),
response_latency: None,
error_message: None,
language_name: buffer.language().map(|language| language.name().to_proto()),
@@ -976,7 +976,7 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(model) = LanguageModelRegistry::read_global(cx).inline_assistant_model() {
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let snapshot = multibuffer.snapshot(cx);
@@ -996,15 +996,15 @@ impl InlineAssistant {
} else {
AssistantPhase::Accepted
},
model: model.telemetry_id(),
model_provider: model.provider_id().to_string(),
model: model.model.telemetry_id(),
model_provider: model.model.provider_id().to_string(),
response_latency: None,
error_message: None,
language_name: language_name.map(|name| name.to_proto()),
},
Some(self.telemetry.clone()),
cx.http_client(),
model.api_key(cx),
model.model.api_key(cx),
cx.background_executor(),
);
}

View File

@@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@@ -582,7 +582,7 @@ impl<T: 'static> PromptEditor<T> {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
let default_model = model_registry.active_model();
let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@@ -890,6 +890,7 @@ impl PromptEditor<BufferCodegen> {
fs,
model_selector_menu_handle,
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)
@@ -1042,6 +1043,7 @@ impl PromptEditor<TerminalCodegen> {
fs,
model_selector_menu_handle.clone(),
prompt_editor.focus_handle(cx),
ModelType::InlineAssistant,
window,
cx,
)

View File

@@ -1,5 +1,6 @@
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
@@ -9,8 +10,10 @@ use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point,
};
use language_model::LanguageModelRegistry;
use language::Buffer;
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
use settings::Settings;
use std::time::Duration;
@@ -28,7 +31,7 @@ use crate::context_picker::{ConfirmBehavior, ContextPicker, ContextPickerComplet
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{RequestKind, Thread};
use crate::thread::{RequestKind, Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
@@ -137,6 +140,7 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default,
window,
cx,
)
@@ -189,7 +193,7 @@ impl MessageEditor {
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.active_model()
.default_model()
.is_some()
}
@@ -199,20 +203,16 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
return;
};
if provider.must_accept_terms(cx) {
cx.notify();
return;
}
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
return;
};
let user_message = self.editor.update(cx, |editor, cx| {
let text = editor.text(cx);
editor.clear(window, cx);
@@ -320,6 +320,19 @@ impl MessageEditor {
fn handle_review_click(&self, window: &mut Window, cx: &mut Context<Self>) {
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
}
fn handle_file_click(
&self,
buffer: Entity<Buffer>,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Ok(diff) = AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx)
{
let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx);
diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx));
}
}
}
impl Focusable for MessageEditor {
@@ -338,9 +351,12 @@ impl Render for MessageEditor {
let thread = self.thread.read(cx);
let is_generating = thread.is_generating();
let is_too_long = thread.is_getting_too_long(cx);
let total_token_usage = thread.total_token_usage(cx);
let is_model_selected = self.is_model_selected(cx);
let is_editor_empty = self.is_editor_empty(cx);
let needs_confirmation =
thread.has_pending_tool_uses() && thread.tools_needing_confirmation().next().is_some();
let submit_label_color = if is_editor_empty {
Color::Muted
} else {
@@ -432,11 +448,17 @@ impl Render for MessageEditor {
},
),
)
.child(
Label::new("Generating…")
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.child({
Label::new(if needs_confirmation {
"Waiting for confirmation…"
} else {
"Generating…"
})
.size(LabelSize::XSmall)
.color(Color::Muted)
})
.child(ui::Divider::vertical())
.child(
Button::new("cancel-generation", "Cancel")
@@ -478,11 +500,16 @@ impl Render for MessageEditor {
}])
.child(
h_flex()
.id("edits-container")
.p_1p5()
.justify_between()
.when(self.edits_expanded, |this| {
this.border_b_1().border_color(border_color)
})
.cursor_pointer()
.on_click(cx.listener(|this, _, window, cx| {
this.handle_review_click(window, cx)
}))
.child(
h_flex()
.gap_1()
@@ -596,11 +623,21 @@ impl Render for MessageEditor {
.justify_between()
.child(
h_flex()
.id("file-container")
.id(("file-container", index))
.pr_8()
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.cursor_pointer()
.on_click({
let buffer = buffer.clone();
cx.listener(move |this, _, window, cx| {
this.handle_file_click(buffer.clone(), window, cx);
})
})
.tooltip(
Tooltip::text(format!("Review {}", path.display()))
)
.child(file_icon)
.child(
h_flex()
@@ -779,7 +816,7 @@ impl Render for MessageEditor {
),
)
)
.when(is_too_long, |parent| {
.when(total_token_usage.ratio != TokenUsageRatio::Normal, |parent| {
parent.child(
h_flex()
.p_2()

View File

@@ -130,8 +130,8 @@ impl Render for ProfileSelector {
let model_registry = LanguageModelRegistry::read_global(cx);
let supports_tools = model_registry
.active_model()
.map_or(false, |model| model.supports_tools());
.default_model()
.map_or(false, |default| default.model.supports_tools());
let icon = match profile_id.as_str() {
"write" => IconName::Pencil,

View File

@@ -2,7 +2,9 @@ use crate::inline_prompt_editor::CodegenStatus;
use client::telemetry::Telemetry;
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
use language_model::{LanguageModelRegistry, LanguageModelRequest, report_assistant_event};
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
};
use std::{sync::Arc, time::Instant};
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
use terminal::Terminal;
@@ -31,7 +33,9 @@ impl TerminalCodegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};

View File

@@ -13,8 +13,8 @@ use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
report_assistant_event,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use prompt_store::PromptBuilder;
use std::sync::Arc;
@@ -286,7 +286,9 @@ impl TerminalInlineAssistant {
})
.log_err();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(

View File

@@ -7,17 +7,17 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use collections::{BTreeMap, HashMap};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage,
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
@@ -30,12 +30,12 @@ use settings::Settings;
use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, attach_context_to_message};
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -82,9 +82,16 @@ pub struct Message {
pub id: MessageId,
pub role: Role,
pub segments: Vec<MessageSegment>,
pub context: String,
}
impl Message {
/// Returns whether the message contains any meaningful text that should be displayed
/// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
pub fn should_display_content(&self) -> bool {
self.segments.iter().all(|segment| segment.should_display())
}
pub fn push_thinking(&mut self, text: &str) {
if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
segment.push_str(text);
@@ -104,6 +111,11 @@ impl Message {
pub fn to_string(&self) -> String {
let mut result = String::new();
if !self.context.is_empty() {
result.push_str(&self.context);
}
for segment in &self.segments {
match segment {
MessageSegment::Text(text) => result.push_str(text),
@@ -114,11 +126,12 @@ impl Message {
}
}
}
result
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
Text(String),
Thinking(String),
@@ -131,6 +144,16 @@ impl MessageSegment {
Self::Thinking(text) => text,
}
}
pub fn should_display(&self) -> bool {
// We add USING_TOOL_MARKER when making a request that includes tool uses
// without non-whitespace text around them, and this can cause the model
// to mimic the pattern, so we consider those segments not displayable.
match self {
Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -198,6 +221,21 @@ pub enum DetailedSummaryState {
},
}
#[derive(Default)]
pub struct TotalTokenUsage {
pub total: usize,
pub max: usize,
pub ratio: TokenUsageRatio,
}
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
#[default]
Normal,
Warning,
Exceeded,
}
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -252,7 +290,7 @@ impl Thread {
last_restore_checkpoint: None,
pending_checkpoint: None,
tool_use: ToolUseState::new(tools.clone()),
action_log: cx.new(|_| ActionLog::new()),
action_log: cx.new(|_| ActionLog::new(project.clone())),
initial_project_snapshot: {
let project_snapshot = Self::project_snapshot(project, cx);
cx.foreground_executor()
@@ -304,6 +342,7 @@ impl Thread {
}
})
.collect(),
context: message.context,
})
.collect(),
next_message_id,
@@ -315,11 +354,11 @@ impl Thread {
pending_completions: Vec::new(),
last_restore_checkpoint: None,
pending_checkpoint: None,
project,
project: project.clone(),
prompt_builder,
tools,
tool_use,
action_log: cx.new(|_| ActionLog::new()),
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
cumulative_token_usage: serialized.cumulative_token_usage,
feedback: None,
@@ -564,15 +603,58 @@ impl Thread {
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context
let text = text.into();
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
// Filter out contexts that have already been included in previous messages
let new_context: Vec<_> = context
.into_iter()
.filter(|ctx| !self.context.contains_key(&ctx.id()))
.collect();
if !new_context.is_empty() {
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.context = context_string;
}
}
self.action_log.update(cx, |log, cx| {
// Track all buffers added as context
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.buffer_added_as_context(
symbol_ctx.context_symbol.buffer.clone(),
cx,
);
}
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
}
}
});
}
let context_ids = new_context
.iter()
.map(|context| context.id())
.collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id(), context)));
self.context.extend(
new_context
.into_iter()
.map(|context| (context.id(), context)),
);
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
message_id,
@@ -589,7 +671,12 @@ impl Thread {
cx: &mut Context<Self>,
) -> MessageId {
let id = self.next_message_id.post_inc();
self.messages.push(Message { id, role, segments });
self.messages.push(Message {
id,
role,
segments,
context: String::new(),
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
id
@@ -695,6 +782,7 @@ impl Thread {
content: tool_result.content.clone(),
})
.collect(),
context: message.context.clone(),
})
.collect(),
initial_project_snapshot,
@@ -881,8 +969,6 @@ impl Thread {
log::error!("system_prompt_context not set.")
}
let mut added_context_ids = HashSet::<ContextId>::default();
for message in &self.messages {
let mut request_message = LanguageModelRequestMessage {
role: message.role,
@@ -903,23 +989,6 @@ impl Thread {
}
}
// Attach context to this message if it's the first to reference it
if let Some(context_ids) = self.context_by_message.get(&message.id) {
let new_context_ids: Vec<_> = context_ids
.iter()
.filter(|id| !added_context_ids.contains(id))
.collect();
if !new_context_ids.is_empty() {
let referenced_context = new_context_ids
.iter()
.filter_map(|context_id| self.context.get(*context_id));
attach_context_to_message(&mut request_message, referenced_context, cx);
added_context_ids.extend(context_ids.iter());
}
}
if !message.segments.is_empty() {
request_message
.content
@@ -939,11 +1008,9 @@ impl Thread {
request.messages.push(request_message);
}
// Set a cache breakpoint at the second-to-last message.
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
let breakpoint_index = request.messages.len() - 2;
for (index, message) in request.messages.iter_mut().enumerate() {
message.cache = index == breakpoint_index;
if let Some(last) = request.messages.last_mut() {
last.cache = true;
}
self.attached_tracked_files_state(&mut request.messages, cx);
@@ -968,7 +1035,7 @@ impl Thread {
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
@@ -1083,32 +1150,15 @@ impl Thread {
}
}
LanguageModelCompletionEvent::ToolUse(tool_use) => {
let last_assistant_message = thread
let last_assistant_message_id = thread
.messages
.iter_mut()
.rfind(|message| message.role == Role::Assistant);
.rfind(|message| message.role == Role::Assistant)
.map(|message| message.id)
.unwrap_or_else(|| {
thread.insert_message(Role::Assistant, vec![], cx)
});
let last_assistant_message_id =
if let Some(message) = last_assistant_message {
if let Some(segment) = message.segments.first_mut() {
let text = segment.text_mut();
if text.is_empty() {
text.push_str("Using tool...");
}
} else {
message.segments.push(MessageSegment::Text(
"Using tool...".to_string(),
));
}
message.id
} else {
thread.insert_message(
Role::Assistant,
vec![MessageSegment::Text("Using tool...".to_string())],
cx,
)
};
thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
@@ -1200,14 +1250,11 @@ impl Thread {
}
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
return;
};
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return;
}
@@ -1226,7 +1273,7 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut new_summary = String::new();
@@ -1270,8 +1317,8 @@ impl Thread {
_ => {}
}
let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let ConfiguredModel { model, provider } =
LanguageModelRegistry::read_global(cx).thread_summary_model()?;
if !provider.is_authenticated(cx) {
return None;
@@ -1439,17 +1486,7 @@ impl Thread {
})
}
pub fn attach_tool_results(
&mut self,
updated_context: Vec<AssistantContext>,
cx: &mut Context<Self>,
) {
self.context.extend(
updated_context
.into_iter()
.map(|context| (context.id(), context)),
);
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Insert a user message to contain the tool results.
self.insert_user_message(
// TODO: Sending up a user message without any content results in the model sending back
@@ -1658,6 +1695,11 @@ impl Thread {
Role::System => "System",
}
)?;
if !message.context.is_empty() {
writeln!(markdown, "{}", message.context)?;
}
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
@@ -1712,6 +1754,17 @@ impl Thread {
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
}
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<language::Buffer>,
buffer_range: Range<language::Anchor>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.action_log.update(cx, |action_log, cx| {
action_log.reject_edits_in_range(buffer, buffer_range, cx)
})
}
pub fn action_log(&self) -> &Entity<ActionLog> {
&self.action_log
}
@@ -1724,26 +1777,33 @@ impl Thread {
self.cumulative_token_usage.clone()
}
pub fn is_getting_too_long(&self, cx: &App) -> bool {
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.active_model() else {
return false;
let Some(model) = model_registry.default_model() else {
return TotalTokenUsage::default();
};
let max_tokens = model.max_token_count();
let current_usage =
self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens;
let max = model.model.max_token_count();
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.9".to_string())
.unwrap_or("0.8".to_string())
.parse()
.unwrap();
#[cfg(not(debug_assertions))]
let warning_threshold: f32 = 0.9;
let warning_threshold: f32 = 0.8;
current_usage as f32 >= (max_tokens as f32 * warning_threshold)
let total = self.cumulative_token_usage.total_tokens() as usize;
let ratio = if total >= max {
TokenUsageRatio::Exceeded
} else if total as f32 / max as f32 >= warning_threshold {
TokenUsageRatio::Warning
} else {
TokenUsageRatio::Normal
};
TotalTokenUsage { total, max, ratio }
}
pub fn deny_tool_use(
@@ -1807,3 +1867,415 @@ struct PendingCompletion {
id: usize,
_task: Task<()>,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{ThreadStore, context_store::ContextStore, thread_store};
use assistant_settings::AssistantSettings;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use prompt_store::PromptBuilder;
use serde_json::json;
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use theme::ThemeSettings;
use util::path;
use workspace::Workspace;
#[gpui::test]
async fn test_message_with_context(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
// Insert user message with context
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Please explain this code", vec![context], None, cx)
});
// Check content and context in message object
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
// Use different path format strings based on platform for the test
#[cfg(windows)]
let path_part = r"test\code.rs";
#[cfg(not(windows))]
let path_part = "test/code.rs";
let expected_context = format!(
r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.
<files>
```rs {path_part}
fn main() {{
println!("Hello, world!");
}}
```
</files>
</context>
"#
);
assert_eq!(message.role, Role::User);
assert_eq!(message.segments.len(), 1);
assert_eq!(
message.segments[0],
MessageSegment::Text("Please explain this code".to_string())
);
assert_eq!(message.context, expected_context);
// Check message in request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[0].string_contents(), expected_full_message);
}
#[gpui::test]
async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({
"file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n",
}),
)
.await;
let (_, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open files individually
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
.await
.unwrap();
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
.await
.unwrap();
// Get the context objects
let contexts = context_store.update(cx, |store, _| store.context().clone());
assert_eq!(contexts.len(), 3);
// First message with context 1
let message1_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
});
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 2",
vec![contexts[0].clone(), contexts[1].clone()],
None,
cx,
)
});
// Third message with all three contexts (contexts 1 and 2 should be skipped)
let message3_id = thread.update(cx, |thread, cx| {
thread.insert_user_message(
"Message 3",
vec![
contexts[0].clone(),
contexts[1].clone(),
contexts[2].clone(),
],
None,
cx,
)
});
// Check what contexts are included in each message
let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
(
thread.message(message1_id).unwrap().clone(),
thread.message(message2_id).unwrap().clone(),
thread.message(message3_id).unwrap().clone(),
)
});
// First message should include context 1
assert!(message1.context.contains("file1.rs"));
// Second message should include only context 2 (not 1)
assert!(!message2.context.contains("file1.rs"));
assert!(message2.context.contains("file2.rs"));
// Third message should include only context 3 (not 1 or 2)
assert!(!message3.context.contains("file1.rs"));
assert!(!message3.context.contains("file2.rs"));
assert!(message3.context.contains("file3.rs"));
// Check entire request to make sure all contexts are properly included
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 3);
// Check that the contexts are properly formatted in each message
assert!(request.messages[0].string_contents().contains("file1.rs"));
assert!(!request.messages[0].string_contents().contains("file2.rs"));
assert!(!request.messages[0].string_contents().contains("file3.rs"));
assert!(!request.messages[1].string_contents().contains("file1.rs"));
assert!(request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
assert!(!request.messages[2].string_contents().contains("file2.rs"));
assert!(request.messages[2].string_contents().contains("file3.rs"));
}
#[gpui::test]
async fn test_message_without_files(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_, _thread_store, thread, _context_store) =
setup_test_environment(cx, project.clone()).await;
// Insert user message without any context (empty context vector)
let message_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
});
// Check content and context in message object
let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
// Context should be empty when no files are included
assert_eq!(message.role, Role::User);
assert_eq!(message.segments.len(), 1);
assert_eq!(
message.segments[0],
MessageSegment::Text("What is the best way to learn Rust?".to_string())
);
assert_eq!(message.context, "");
// Check message in request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
assert_eq!(
request.messages[0].string_contents(),
"What is the best way to learn Rust?"
);
// Add second message, also without context
let message2_id = thread.update(cx, |thread, cx| {
thread.insert_user_message("Are there any good books?", vec![], None, cx)
});
let message2 =
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
assert_eq!(message2.context, "");
// Check that both messages appear in the request
let request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(
request.messages[0].string_contents(),
"What is the best way to learn Rust?"
);
assert_eq!(
request.messages[1].string_contents(),
"Are there any good books?"
);
}
#[gpui::test]
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, _thread_store, thread, context_store) =
setup_test_environment(cx, project.clone()).await;
// Open buffer and add it to context
let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
.await
.unwrap();
let context =
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
// Insert user message with the buffer as context
thread.update(cx, |thread, cx| {
thread.insert_user_message("Explain this code", vec![context], None, cx)
});
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
// Make sure we don't have a stale file warning yet
let has_stale_warning = initial_request.messages.iter().any(|msg| {
msg.string_contents()
.contains("These files changed since last read:")
});
assert!(
!has_stale_warning,
"Should not have stale buffer warning before buffer is modified"
);
// Modify the buffer
buffer.update(cx, |buffer, cx| {
// Find a position at the end of line 1
buffer.edit(
[(1..1, "\n println!(\"Added a new line\");\n")],
None,
cx,
);
});
// Insert another user message without context
thread.update(cx, |thread, cx| {
thread.insert_user_message("What does the code do now?", vec![], None, cx)
});
// Create a new request and check for the stale buffer warning
let new_request = thread.read_with(cx, |thread, cx| {
thread.to_completion_request(RequestKind::Chat, cx)
});
// We should have a stale file warning as the last message
let last_message = new_request
.messages
.last()
.expect("Request should have messages");
// The last message should be the stale buffer notification
assert_eq!(last_message.role, Role::User);
// Check the exact content of the message
let expected_content = "These files changed since last read:\n- code.rs\n";
assert_eq!(
last_message.string_contents(),
expected_content,
"Last message should be exactly the stale buffer notification"
);
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AssistantSettings::register(cx);
thread_store::init(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
});
}
// Helper to create a test project with test files
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
async fn setup_test_environment(
cx: &mut TestAppContext,
project: Entity<Project>,
) -> (
Entity<Workspace>,
Entity<ThreadStore>,
Entity<Thread>,
Entity<ContextStore>,
) {
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx.update(|_, cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));
(workspace, thread_store, thread, context_store)
}
async fn add_file_to_context(
project: &Entity<Project>,
context_store: &Entity<ContextStore>,
path: &str,
cx: &mut TestAppContext,
) -> Result<Entity<language::Buffer>> {
let buffer_path = project
.read_with(cx, |project, cx| project.find_project_path(path, cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.await
.unwrap();
context_store
.update(cx, |store, cx| {
store.add_file_from_buffer(buffer.clone(), cx)
})
.await?;
Ok(buffer)
}
}

View File

@@ -1,52 +1,176 @@
use std::sync::Arc;
use assistant_context_editor::SavedContextMetadata;
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, Entity, FocusHandle, Focusable, ScrollStrategy, UniformListScrollHandle, WeakEntity,
uniform_list,
App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity,
Window, uniform_list,
};
use time::{OffsetDateTime, UtcOffset};
use ui::{IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*};
use util::ResultExt;
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::thread_store::SerializedThreadMetadata;
use crate::{AssistantPanel, RemoveSelectedThread};
pub struct ThreadHistory {
focus_handle: FocusHandle,
assistant_panel: WeakEntity<AssistantPanel>,
history_store: Entity<HistoryStore>,
scroll_handle: UniformListScrollHandle,
selected_index: usize,
search_query: SharedString,
search_editor: Entity<Editor>,
all_entries: Arc<Vec<HistoryEntry>>,
matches: Vec<StringMatch>,
_subscriptions: Vec<gpui::Subscription>,
_search_task: Option<Task<()>>,
}
impl ThreadHistory {
pub(crate) fn new(
assistant_panel: WeakEntity<AssistantPanel>,
history_store: Entity<HistoryStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let search_editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text("Search threads...", cx);
editor
});
let search_editor_subscription = cx.subscribe_in(
&search_editor,
window,
|this, search_editor, event, window, cx| {
if let EditorEvent::BufferEdited = event {
let query = search_editor.read(cx).text(cx);
this.search_query = query.into();
this.update_search(window, cx);
}
},
);
let entries: Arc<Vec<_>> = history_store
.update(cx, |store, cx| store.entries(cx))
.into();
let history_store_subscription =
cx.observe_in(&history_store, window, |this, history_store, window, cx| {
this.all_entries = history_store
.update(cx, |store, cx| store.entries(cx))
.into();
this.matches.clear();
this.update_search(window, cx);
});
Self {
focus_handle: cx.focus_handle(),
assistant_panel,
history_store,
scroll_handle: UniformListScrollHandle::default(),
selected_index: 0,
search_query: SharedString::new_static(""),
all_entries: entries,
matches: Vec::new(),
search_editor,
_subscriptions: vec![search_editor_subscription, history_store_subscription],
_search_task: None,
}
}
fn update_search(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
self._search_task.take();
if self.has_search_query() {
self.perform_search(cx);
} else {
self.matches.clear();
self.set_selected_index(0, cx);
cx.notify();
}
}
fn perform_search(&mut self, cx: &mut Context<Self>) {
let query = self.search_query.clone();
let all_entries = self.all_entries.clone();
let task = cx.spawn(async move |this, cx| {
let executor = cx.background_executor().clone();
let matches = cx
.background_spawn(async move {
let mut candidates = Vec::with_capacity(all_entries.len());
for (idx, entry) in all_entries.iter().enumerate() {
match entry {
HistoryEntry::Thread(thread) => {
candidates.push(StringMatchCandidate::new(idx, &thread.summary));
}
HistoryEntry::Context(context) => {
candidates.push(StringMatchCandidate::new(idx, &context.title));
}
}
}
const MAX_MATCHES: usize = 100;
fuzzy::match_strings(
&candidates,
&query,
false,
MAX_MATCHES,
&Default::default(),
executor,
)
.await
})
.await;
this.update(cx, |this, cx| {
this.matches = matches;
this.set_selected_index(0, cx);
cx.notify();
})
.log_err();
});
self._search_task = Some(task);
}
fn has_search_query(&self) -> bool {
!self.search_query.is_empty()
}
fn matched_count(&self) -> usize {
if self.has_search_query() {
self.matches.len()
} else {
self.all_entries.len()
}
}
fn get_match(&self, ix: usize) -> Option<&HistoryEntry> {
if self.has_search_query() {
self.matches
.get(ix)
.and_then(|m| self.all_entries.get(m.candidate_id))
} else {
self.all_entries.get(ix)
}
}
pub fn select_previous(
&mut self,
_: &menu::SelectPrevious,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self
.history_store
.update(cx, |this, cx| this.entry_count(cx));
let count = self.matched_count();
if count > 0 {
if self.selected_index == 0 {
self.set_selected_index(count - 1, window, cx);
self.set_selected_index(count - 1, cx);
} else {
self.set_selected_index(self.selected_index - 1, window, cx);
self.set_selected_index(self.selected_index - 1, cx);
}
}
}
@@ -54,40 +178,39 @@ impl ThreadHistory {
pub fn select_next(
&mut self,
_: &menu::SelectNext,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self
.history_store
.update(cx, |this, cx| this.entry_count(cx));
let count = self.matched_count();
if count > 0 {
if self.selected_index == count - 1 {
self.set_selected_index(0, window, cx);
self.set_selected_index(0, cx);
} else {
self.set_selected_index(self.selected_index + 1, window, cx);
self.set_selected_index(self.selected_index + 1, cx);
}
}
}
fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context<Self>) {
let count = self
.history_store
.update(cx, |this, cx| this.entry_count(cx));
fn select_first(
&mut self,
_: &menu::SelectFirst,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let count = self.matched_count();
if count > 0 {
self.set_selected_index(0, window, cx);
self.set_selected_index(0, cx);
}
}
fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context<Self>) {
let count = self
.history_store
.update(cx, |this, cx| this.entry_count(cx));
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
let count = self.matched_count();
if count > 0 {
self.set_selected_index(count - 1, window, cx);
self.set_selected_index(count - 1, cx);
}
}
fn set_selected_index(&mut self, index: usize, _window: &mut Window, cx: &mut Context<Self>) {
fn set_selected_index(&mut self, index: usize, cx: &mut Context<Self>) {
self.selected_index = index;
self.scroll_handle
.scroll_to_item(index, ScrollStrategy::Top);
@@ -95,23 +218,23 @@ impl ThreadHistory {
}
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let entries = self.history_store.update(cx, |this, cx| this.entries(cx));
if let Some(entry) = self.get_match(self.selected_index) {
let task_result = match entry {
HistoryEntry::Thread(thread) => self
.assistant_panel
.update(cx, move |this, cx| this.open_thread(&thread.id, window, cx))
.ok(),
HistoryEntry::Context(context) => self
.assistant_panel
.update(cx, move |this, cx| {
this.open_saved_prompt_editor(context.path.clone(), window, cx)
})
.ok(),
};
if let Some(entry) = entries.get(self.selected_index) {
match entry {
HistoryEntry::Thread(thread) => {
self.assistant_panel
.update(cx, move |this, cx| this.open_thread(&thread.id, window, cx))
.ok();
}
HistoryEntry::Context(context) => {
self.assistant_panel
.update(cx, move |this, cx| {
this.open_saved_prompt_editor(context.path.clone(), window, cx)
})
.ok();
}
}
if let Some(task) = task_result {
task.detach_and_log_err(cx);
};
cx.notify();
}
@@ -120,12 +243,10 @@ impl ThreadHistory {
fn remove_selected_thread(
&mut self,
_: &RemoveSelectedThread,
_window: &mut Window,
window: &mut Window,
cx: &mut Context<Self>,
) {
let entries = self.history_store.update(cx, |this, cx| this.entries(cx));
if let Some(entry) = entries.get(self.selected_index) {
if let Some(entry) = self.get_match(self.selected_index) {
match entry {
HistoryEntry::Thread(thread) => {
self.assistant_panel
@@ -143,72 +264,117 @@ impl ThreadHistory {
}
}
self.update_search(window, cx);
cx.notify();
}
}
}
impl Focusable for ThreadHistory {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.search_editor.focus_handle(cx)
}
}
impl Render for ThreadHistory {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let history_entries = self.history_store.update(cx, |this, cx| this.entries(cx));
let selected_index = self.selected_index;
v_flex()
.id("thread-history-container")
.key_context("ThreadHistory")
.track_focus(&self.focus_handle)
.overflow_y_scroll()
.size_full()
.p_1()
.on_action(cx.listener(Self::select_previous))
.on_action(cx.listener(Self::select_next))
.on_action(cx.listener(Self::select_first))
.on_action(cx.listener(Self::select_last))
.on_action(cx.listener(Self::confirm))
.on_action(cx.listener(Self::remove_selected_thread))
.map(|history| {
if history_entries.is_empty() {
history
.justify_center()
.when(!self.all_entries.is_empty(), |parent| {
parent.child(
h_flex()
.h(px(41.)) // Match the toolbar perfectly
.w_full()
.py_1()
.px_2()
.gap_2()
.justify_between()
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
Icon::new(IconName::MagnifyingGlass)
.color(Color::Muted)
.size(IconSize::Small),
)
.child(self.search_editor.clone()),
)
})
.child({
let view = v_flex().overflow_hidden().flex_grow();
if self.all_entries.is_empty() {
view.justify_center()
.child(
h_flex().w_full().justify_center().child(
Label::new("You don't have any past threads yet.")
.size(LabelSize::Small),
),
)
} else if self.has_search_query() && self.matches.is_empty() {
view.justify_center().child(
h_flex().w_full().justify_center().child(
Label::new("No threads match your search.").size(LabelSize::Small),
),
)
} else {
history.child(
view.p_1().child(
uniform_list(
cx.entity().clone(),
"thread-history",
history_entries.len(),
self.matched_count(),
move |history, range, _window, _cx| {
history_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
history.assistant_panel.clone(),
selected_index == index,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
history.assistant_panel.clone(),
selected_index == index,
)
.into_any_element(),
})
let range_start = range.start;
let assistant_panel = history.assistant_panel.clone();
let render_item = |index: usize,
entry: &HistoryEntry,
highlight_positions: Vec<usize>|
-> Div {
h_flex().w_full().pb_1().child(match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
assistant_panel.clone(),
selected_index == index + range_start,
highlight_positions,
)
.into_any_element(),
})
.collect()
};
if history.has_search_query() {
history.matches[range]
.iter()
.enumerate()
.filter_map(|(index, m)| {
history.all_entries.get(m.candidate_id).map(|entry| {
render_item(index, entry, m.positions.clone())
})
})
.collect()
} else {
history.all_entries[range]
.iter()
.enumerate()
.map(|(index, entry)| render_item(index, entry, vec![]))
.collect()
}
},
)
.track_scroll(self.scroll_handle.clone())
@@ -224,6 +390,7 @@ pub struct PastThread {
thread: SerializedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
highlight_positions: Vec<usize>,
}
impl PastThread {
@@ -231,11 +398,13 @@ impl PastThread {
thread: SerializedThreadMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
highlight_positions: Vec<usize>,
) -> Self {
Self {
thread,
assistant_panel,
selected,
highlight_positions,
}
}
}
@@ -258,9 +427,11 @@ impl RenderOnce for PastThread {
.toggle_state(self.selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
div()
.max_w_4_5()
.child(Label::new(summary).size(LabelSize::Small).truncate()),
div().max_w_4_5().child(
HighlightedLabel::new(summary, self.highlight_positions)
.size(LabelSize::Small)
.truncate(),
),
)
.end_slot(
h_flex()
@@ -318,6 +489,7 @@ pub struct PastContext {
context: SavedContextMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
highlight_positions: Vec<usize>,
}
impl PastContext {
@@ -325,11 +497,13 @@ impl PastContext {
context: SavedContextMetadata,
assistant_panel: WeakEntity<AssistantPanel>,
selected: bool,
highlight_positions: Vec<usize>,
) -> Self {
Self {
context,
assistant_panel,
selected,
highlight_positions,
}
}
}
@@ -337,7 +511,6 @@ impl PastContext {
impl RenderOnce for PastContext {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let summary = self.context.title;
let context_timestamp = time_format::format_localized_timestamp(
OffsetDateTime::from_unix_timestamp(self.context.mtime.timestamp()).unwrap(),
OffsetDateTime::now_utc(),
@@ -354,9 +527,11 @@ impl RenderOnce for PastContext {
.toggle_state(self.selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
div()
.max_w_4_5()
.child(Label::new(summary).size(LabelSize::Small).truncate()),
div().max_w_4_5().child(
HighlightedLabel::new(summary, self.highlight_positions)
.size(LabelSize::Small)
.truncate(),
),
)
.end_slot(
h_flex()

View File

@@ -374,6 +374,8 @@ pub struct SerializedMessage {
pub tool_uses: Vec<SerializedToolUse>,
#[serde(default)]
pub tool_results: Vec<SerializedToolResult>,
#[serde(default)]
pub context: String,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -441,6 +443,7 @@ impl LegacySerializedMessage {
segments: vec![SerializedMessageSegment::Text { text: self.text }],
tool_uses: self.tool_uses,
tool_results: self.tool_results,
context: String::new(),
}
}
}

View File

@@ -43,6 +43,8 @@ pub struct ToolUseState {
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
Self {
@@ -357,8 +359,28 @@ impl ToolUseState {
request_message: &mut LanguageModelRequestMessage,
) {
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
let mut found_tool_use = false;
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
if !found_tool_use {
// The API fails if a message contains a tool use without any (non-whitespace) text around it
match request_message.content.last_mut() {
Some(MessageContent::Text(txt)) => {
if txt.is_empty() {
txt.push_str(USING_TOOL_MARKER);
}
}
None | Some(_) => {
request_message
.content
.push(MessageContent::Text(USING_TOOL_MARKER.into()));
}
};
}
found_tool_use = true;
// Do not send tool uses until they are completed
request_message
.content

View File

@@ -26,3 +26,4 @@ serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true

View File

@@ -1,10 +1,10 @@
mod supported_countries;
use std::{pin::Pin, str::FromStr};
use std::str::FromStr;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, Stream, StreamExt, io::BufReader, stream::BoxStream};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
@@ -437,50 +437,6 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
let mut tool_use_index = None;
while let Some(event) = events.next().await {
if let Event::ContentBlockStart {
index,
content_block: ResponseContent::ToolUse { name, .. },
} = event?
{
if name == tool_name {
tool_use_index = Some(index);
break;
}
}
}
let Some(tool_use_index) = tool_use_index else {
return Err(anyhow!("tool not used"));
};
Ok(events.filter_map(move |event| {
let result = match event {
Err(error) => Some(Err(error)),
Ok(Event::ContentBlockDelta { index, delta }) => match delta {
ContentDelta::TextDelta { .. } => None,
ContentDelta::ThinkingDelta { .. } => None,
ContentDelta::SignatureDelta { .. } => None,
ContentDelta::InputJsonDelta { partial_json } => {
if index == tool_use_index {
Some(Ok(partial_json))
} else {
None
}
}
},
_ => None,
};
async move { result }
}))
}
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {

View File

@@ -19,3 +19,4 @@ smol.workspace = true
tempfile.workspace = true
util.workspace = true
which.workspace = true
workspace-hack.workspace = true

View File

@@ -15,3 +15,4 @@ workspace = true
anyhow.workspace = true
gpui.workspace = true
rust-embed.workspace = true
workspace-hack.workspace = true

View File

@@ -69,6 +69,7 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
ctor.workspace = true

View File

@@ -161,12 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
fn update_active_language_model_from_settings(cx: &mut App) {
let settings = AssistantSettings::get_global(cx);
// Default model - used as fallback
let active_model_provider_name =
LanguageModelProviderId::from(settings.default_model.provider.clone());
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
let editor_provider_name =
LanguageModelProviderId::from(settings.editor_model.provider.clone());
let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
// Inline assistant model
let inline_assistant_model = settings
.inline_assistant_model
.as_ref()
.unwrap_or(&settings.default_model);
let inline_assistant_provider_name =
LanguageModelProviderId::from(inline_assistant_model.provider.clone());
let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
// Commit message model
let commit_message_model = settings
.commit_message_model
.as_ref()
.unwrap_or(&settings.default_model);
let commit_message_provider_name =
LanguageModelProviderId::from(commit_message_model.provider.clone());
let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
// Thread summary model
let thread_summary_model = settings
.thread_summary_model
.as_ref()
.unwrap_or(&settings.default_model);
let thread_summary_provider_name =
LanguageModelProviderId::from(thread_summary_model.provider.clone());
let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
let inline_alternatives = settings
.inline_alternatives
.iter()
@@ -177,9 +203,29 @@ fn update_active_language_model_from_settings(cx: &mut App) {
)
})
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
// Set the default model
registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
// Set the specific models
registry.select_inline_assistant_model(
&inline_assistant_provider_name,
&inline_assistant_model_id,
cx,
);
registry.select_commit_message_model(
&commit_message_provider_name,
&commit_message_model_id,
cx,
);
registry.select_thread_summary_model(
&thread_summary_provider_name,
&thread_summary_model_id,
cx,
);
// Set the alternatives
registry.select_inline_alternative_models(inline_alternatives, cx);
});
}

View File

@@ -22,7 +22,8 @@ use gpui::{
};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
@@ -298,8 +299,10 @@ impl AssistantPanel {
&LanguageModelRegistry::global(cx),
window,
|this, _, event: &language_model::Event, window, cx| match event {
language_model::Event::ActiveModelChanged
| language_model::Event::EditorModelChanged => {
language_model::Event::DefaultModelChanged
| language_model::Event::InlineAssistantModelChanged
| language_model::Event::CommitMessageModelChanged
| language_model::Event::ThreadSummaryModelChanged => {
this.completion_provider_changed(window, cx);
}
language_model::Event::ProviderStateChanged => {
@@ -468,12 +471,12 @@ impl AssistantPanel {
}
fn update_zed_ai_notice_visibility(&mut self, client_status: Status, cx: &mut Context<Self>) {
let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
let model = LanguageModelRegistry::read_global(cx).default_model();
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in.
let show_zed_ai_notice = client_status.is_signed_out()
&& active_provider.map_or(true, |provider| provider.id().0 == ZED_CLOUD_PROVIDER_ID);
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify();
@@ -541,8 +544,8 @@ impl AssistantPanel {
}
let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
.default_model()
.map(|default| default.provider.id())
else {
return;
};
@@ -568,7 +571,9 @@ impl AssistantPanel {
return;
}
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
let Some(ConfiguredModel { provider, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
else {
return;
};
@@ -976,8 +981,8 @@ impl AssistantPanel {
|this, _, event: &ConfigurationViewEvent, window, cx| match event {
ConfigurationViewEvent::NewProviderContextEditor(provider) => {
if LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(true, |p| p.id() != provider.id())
.default_model()
.map_or(true, |default| default.provider.id() != provider.id())
{
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
@@ -1155,8 +1160,8 @@ impl AssistantPanel {
fn is_authenticated(&mut self, cx: &mut Context<Self>) -> bool {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(false, |provider| provider.is_authenticated(cx))
.default_model()
.map_or(false, |default| default.provider.is_authenticated(cx))
}
fn authenticate(
@@ -1164,8 +1169,8 @@ impl AssistantPanel {
cx: &mut Context<Self>,
) -> Option<Task<Result<(), AuthenticateError>>> {
LanguageModelRegistry::read_global(cx)
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
.default_model()
.map_or(None, |default| Some(default.provider.authenticate(cx)))
}
fn restart_context_servers(

View File

@@ -34,8 +34,8 @@ use gpui::{
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelTextStream, Role, report_assistant_event,
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
@@ -312,7 +312,9 @@ impl InlineAssistant {
start..end,
));
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
{
self.telemetry.report_assistant_event(AssistantEvent {
conversation_id: None,
kind: AssistantKind::Inline,
@@ -877,7 +879,9 @@ impl InlineAssistant {
let active_alternative = assist.codegen.read(cx).active_alternative().clone();
let message_id = active_alternative.read(cx).message_id.clone();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).default_model()
{
let language_name = assist.editor.upgrade().and_then(|editor| {
let multibuffer = editor.read(cx).buffer().read(cx);
let multibuffer_snapshot = multibuffer.snapshot(cx);
@@ -1629,8 +1633,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.default_model()
.map(|default| default.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@@ -2077,7 +2081,7 @@ impl PromptEditor {
let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
let model_registry = LanguageModelRegistry::read_global(cx);
let default_model = model_registry.active_model();
let default_model = model_registry.default_model().map(|default| default.model);
let alternative_models = model_registry.inline_alternative_models();
let get_model_name = |index: usize| -> String {
@@ -2183,7 +2187,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.default_model()?
.model;
let token_counts = self.token_counts?;
let max_token_count = model.max_token_count();
@@ -2638,8 +2644,9 @@ impl Codegen {
}
let primary_model = LanguageModelRegistry::read_global(cx)
.active_model()
.context("no active model")?;
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
@@ -2863,7 +2870,9 @@ impl CodegenAlternative {
assistant_panel_context: Option<LanguageModelRequest>,
cx: &App,
) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
match request {
Ok(request) => {

View File

@@ -16,8 +16,8 @@ use gpui::{
};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
report_assistant_event,
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
@@ -318,7 +318,9 @@ impl TerminalInlineAssistant {
})
.log_err();
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
if let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
{
let codegen = assist.codegen.read(cx);
let executor = cx.background_executor().clone();
report_assistant_event(
@@ -652,8 +654,8 @@ impl Render for PromptEditor {
format!(
"Using {}",
LanguageModelRegistry::read_global(cx)
.active_model()
.map(|model| model.name().0)
.inline_assistant_model()
.map(|inline_assistant| inline_assistant.model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
@@ -822,7 +824,9 @@ impl PromptEditor {
fn count_tokens(&mut self, cx: &mut Context<Self>) {
let assist_id = self.id;
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
self.pending_token_count = cx.spawn(async move |this, cx| {
@@ -980,7 +984,9 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.inline_assistant_model()?
.model;
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@@ -1131,7 +1137,9 @@ impl Codegen {
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};

View File

@@ -54,6 +54,7 @@ ui.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
language_model = { workspace = true, features = ["test-support"] }

View File

@@ -1272,7 +1272,7 @@ impl AssistantContext {
// Assume it will be a Chat request, even though that takes fewer tokens (and risks going over the limit),
// because otherwise you see in the UI that your empty message has a bunch of tokens already used.
let request = self.to_completion_request(RequestType::Chat, cx);
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
let debounce = self.token_count.is_some();
@@ -1284,10 +1284,12 @@ impl AssistantContext {
.await;
}
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
let token_count = cx
.update(|cx| model.model.count_tokens(request, cx))?
.await?;
this.update(cx, |this, cx| {
this.token_count = Some(token_count);
this.start_cache_warming(&model, cx);
this.start_cache_warming(&model.model, cx);
cx.notify()
})
}
@@ -2304,14 +2306,16 @@ impl AssistantContext {
cx: &mut Context<Self>,
) -> Option<MessageAnchor> {
let model_registry = LanguageModelRegistry::read_global(cx);
let provider = model_registry.active_provider()?;
let model = model_registry.active_model()?;
let model = model_registry.default_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?;
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
let model = model.model;
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
@@ -2940,15 +2944,12 @@ impl AssistantContext {
}
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
return;
};
let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !provider.is_authenticated(cx) {
if !model.provider.is_authenticated(cx) {
return;
}
@@ -2964,7 +2965,7 @@ impl AssistantContext {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
let stream = model.stream_completion_text(request, &cx);
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
let mut replaced = !replace_old;

View File

@@ -384,7 +384,9 @@ impl ContextEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
@@ -2395,13 +2397,13 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let model = LanguageModelRegistry::read_global(cx).default_model();
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
&& provider
&& model
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));
.map_or(false, |model| model.provider.must_accept_terms(cx));
let disabled = has_configuration_error || needs_to_accept_terms;
ButtonLike::new("send_button")
@@ -2454,7 +2456,9 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
@@ -2500,7 +2504,9 @@ impl ContextEditor {
}
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx).active_model();
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
let focus_handle = self.editor().focus_handle(cx).clone();
let model_name = match active_model {
Some(model) => model.name().0,
@@ -3020,7 +3026,9 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
impl Render for ContextEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
let accept_terms = if self.show_accept_terms {
provider.as_ref().and_then(|provider| {
provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
@@ -3616,7 +3624,9 @@ enum TokenState {
fn token_state(context: &Entity<AssistantContext>, cx: &App) -> Option<TokenState> {
const WARNING_TOKEN_THRESHOLD: f32 = 0.8;
let model = LanguageModelRegistry::read_global(cx).active_model()?;
let model = LanguageModelRegistry::read_global(cx)
.default_model()?
.model;
let token_count = context.read(cx).token_count()?;
let max_token_count = model.max_token_count();
@@ -3669,16 +3679,16 @@ pub enum ConfigurationError {
}
fn configuration_error(cx: &App) -> Option<ConfigurationError> {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let is_authenticated = provider
let model = LanguageModelRegistry::read_global(cx).default_model();
let is_authenticated = model
.as_ref()
.map_or(false, |provider| provider.is_authenticated(cx));
.map_or(false, |model| model.provider.is_authenticated(cx));
if provider.is_some() && is_authenticated {
if model.is_some() && is_authenticated {
return None;
}
if provider.is_none() {
if model.is_none() {
return Some(ConfigurationError::NoProvider);
}
@@ -3703,6 +3713,18 @@ pub fn humanize_token_count(count: usize) -> String {
format!("{}.{}k", thousands, hundreds)
}
}
1_000_000..=9_999_999 => {
let millions = count / 1_000_000;
let hundred_thousands = (count % 1_000_000 + 50_000) / 100_000;
if hundred_thousands == 0 {
format!("{}M", millions)
} else if hundred_thousands == 10 {
format!("{}M", millions + 1)
} else {
format!("{}.{}M", millions, hundred_thousands)
}
}
10_000_000.. => format!("{}M", (count + 500_000) / 1_000_000),
_ => format!("{}k", (count + 500) / 1000),
}
}

View File

@@ -27,14 +27,12 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_models.workspace = true
node_runtime.workspace = true
project.workspace = true
prompt_store.workspace = true
regex.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
@@ -42,4 +40,7 @@ serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
tempfile.workspace = true
util.workspace = true
walkdir.workspace = true
workspace-hack.workspace = true

View File

@@ -1,34 +1,25 @@
# Tool Evals
A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
A framework for evaluating and benchmarking the agent panel generations.
## Overview
Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of:
1. Cloning and setting up test repositories
1. Setting up test code and repositories
2. Sending prompts to language models
3. Allowing the assistant to use tools to modify code
4. Collecting metrics on performance
4. Collecting metrics on performance and tool usage
5. Evaluating results against known good solutions
## How It Works
The system consists of several key components:
- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
- **Eval**: Loads exercises from the zed-ace-framework repository, creates temporary repos, and executes evaluations
- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
The evaluation flow:
1. An evaluation is loaded from the evaluation_data directory
2. The target repository is cloned and checked out at a specific commit
3. A HeadlessAssistant instance is created with the specified language model
4. The user prompt is sent to the assistant
5. The assistant responds and uses tools to modify code
6. Upon completion, a diff is generated from the changes
7. Results are saved including the diff, assistant's response, and performance metrics
8. If a reference solution exists, a Judge evaluates the similarity of the solution
- **Judge**: Evaluates AI-generated solutions against reference implementations and assigns scores
- **Templates**: Defines evaluation frameworks for different tasks (Project Creation, Code Modification, Conversational Guidance)
## Setup Requirements
@@ -36,6 +27,7 @@ The evaluation flow:
- Rust and Cargo
- Git
- Python (for report generation)
- Network access to clone repositories
- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
@@ -43,35 +35,34 @@ The evaluation flow:
Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
- `ZED_ANTHROPIC_API_KEY` for Claude models
- `ZED_OPENAI_API_KEY` for OpenAI models
- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
## Usage
### Running a Single Evaluation
To run a specific evaluation:
```bash
cargo run -p assistant_eval -- bubbletea-add-set-window-title
```
The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
```bash
cargo run -p assistant_eval -- bubbletea
```
To run all evaluations:
### Running Evaluations
```bash
# Run all tests
cargo run -p assistant_eval -- --all
# Run only specific languages
cargo run -p assistant_eval -- --all --languages python,rust
# Limit concurrent evaluations
cargo run -p assistant_eval -- --all --concurrency 5
# Limit number of exercises per language
cargo run -p assistant_eval -- --all --max-exercises-per-language 3
```
## Evaluation Data Structure
### Evaluation Template Types
Each evaluation should be placed in the `evaluation_data` directory with the following structure:
The system supports three types of evaluation templates:
* `prompt.txt`: The user's prompt.
* `original.diff`: The `git diff` of the change anticipated for this prompt.
* `setup.json`: Information about the repo used for the evaluation.
1. **ProjectCreation**: Tests the model's ability to create new implementations from scratch
2. **CodeModification**: Tests the model's ability to modify existing code to meet new requirements
3. **ConversationalGuidance**: Tests the model's ability to provide guidance without writing code
### Support Repo
The [zed-industries/zed-ace-framework](https://github.com/zed-industries/zed-ace-framework) contains the analytics and reporting scripts.

View File

@@ -1,6 +1,8 @@
use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name, templates_eval::Template};
use agent::RequestKind;
use anyhow::anyhow;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
@@ -10,19 +12,26 @@ use std::{
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
time::{Duration, SystemTime},
};
use util::command::new_smol_command;
pub struct Eval {
pub name: String,
pub path: PathBuf,
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
pub exercise_name: String,
pub template_name: String,
pub score: String,
pub diff: String,
pub assistant_response: String,
pub elapsed_time_ms: u128,
pub timestamp: u128,
// Token usage fields
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub tool_use_counts: usize,
pub judge_model_name: String, // Added field for judge model name
}
#[derive(Debug, Serialize)]
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
@@ -38,19 +47,31 @@ pub struct EvalSetup {
pub base_sha: String,
}
pub struct Eval {
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
impl Eval {
/// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
/// if necessary.
pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
// Move this internal function inside the load method since it's only used here
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
name,
path,
repo_path,
eval_setup,
user_prompt,
@@ -62,9 +83,9 @@ impl Eval {
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<anyhow::Result<EvalOutput>> {
) -> Task<Result<EvalOutput>> {
cx.spawn(async move |cx| {
checkout_repo(&self.eval_setup, &self.repo_path).await?;
run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
@@ -104,9 +125,43 @@ impl Eval {
done_rx.recv().await??;
// Add this section to check untracked files
println!("Checking for untracked files:");
let untracked = run_git(
&self.repo_path,
&["ls-files", "--others", "--exclude-standard"],
)
.await?;
if untracked.is_empty() {
println!("No untracked files found");
} else {
// Add all files to git so they appear in the diff
println!("Adding untracked files to git");
run_git(&self.repo_path, &["add", "."]).await?;
}
// get git status
let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
let elapsed_time = start_time.elapsed()?;
let diff = query_git(&self.repo_path, vec!["diff"]).await?;
// Get diff of staged changes (the files we just added)
let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
// Get diff of unstaged changes
let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
// Combine both diffs
let diff = if unstaged_diff.is_empty() {
staged_diff
} else if staged_diff.is_empty() {
unstaged_diff
} else {
format!(
"# Staged changes\n{}\n\n# Unstaged changes\n{}",
staged_diff, unstaged_diff
)
};
assistant.update(cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
@@ -132,12 +187,9 @@ impl Eval {
}
impl EvalOutput {
// Method to save the output to a directory
pub fn save_to_directory(
&self,
output_dir: &Path,
eval_output_value: String,
) -> anyhow::Result<()> {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
@@ -192,76 +244,305 @@ impl EvalOutput {
}
}
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
let instructions_path = exercise_path.join(".docs").join("instructions.md");
println!("Reading instructions from: {}", instructions_path.display());
let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
Ok(instructions)
}
async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
if !repo_path.exists() {
smol::unblock({
let repo_path = repo_path.to_path_buf();
|| std::fs::create_dir_all(repo_path)
})
.await?;
run_git(repo_path, vec!["init"]).await?;
run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
} else {
let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
if actual_origin != eval_setup.url {
return Err(anyhow!(
"remote origin {} does not match expected origin {}",
actual_origin,
eval_setup.url
));
}
pub async fn read_example_solution(exercise_path: &Path, language: &str) -> Result<String> {
// Map the language to the file extension
let language_extension = match language {
"python" => "py",
"go" => "go",
"rust" => "rs",
"typescript" => "ts",
"javascript" => "js",
"ruby" => "rb",
"php" => "php",
"bash" => "sh",
"multi" => "diff",
"internal" => "diff",
_ => return Err(anyhow!("Unsupported language: {}", language)),
};
let example_path = exercise_path
.join(".meta")
.join(format!("example.{}", language_extension));
println!("Reading example solution from: {}", example_path.display());
let example = smol::unblock(move || std::fs::read_to_string(&example_path)).await?;
Ok(example)
}
// TODO: consider including "-x" to remove ignored files. The downside of this is that it will
// also remove build artifacts, and so prevent incremental reuse there.
run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
let eval_dir = exercise_path.join("evaluation");
fs::create_dir_all(&eval_dir)?;
let eval_file = eval_dir.join("evals.json");
println!("Saving evaluation results to: {}", eval_file.display());
println!(
"Results to save: {} evaluations for exercise path: {}",
results.len(),
exercise_path.display()
);
// Check file existence before reading/writing
if eval_file.exists() {
println!("Existing evals.json file found, will update it");
} else {
println!("No existing evals.json file found, will create new one");
}
run_git(
repo_path,
vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
)
.await?;
run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
// Structure to organize evaluations by test name and timestamp
let mut eval_data: serde_json::Value = if eval_file.exists() {
let content = fs::read_to_string(&eval_file)?;
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Get current timestamp for this batch of results
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis()
.to_string();
// Group the new results by test name (exercise name)
for result in results {
let exercise_name = &result.exercise_name;
let template_name = &result.template_name;
println!(
"Adding result: exercise={}, template={}",
exercise_name, template_name
);
// Ensure the exercise entry exists
if eval_data.get(exercise_name).is_none() {
eval_data[exercise_name] = serde_json::json!({});
}
// Ensure the timestamp entry exists as an object
if eval_data[exercise_name].get(&timestamp).is_none() {
eval_data[exercise_name][&timestamp] = serde_json::json!({});
}
// Add this result under the timestamp with template name as key
eval_data[exercise_name][&timestamp][template_name] = serde_json::to_value(&result)?;
}
// Write back to file with pretty formatting
let json_content = serde_json::to_string_pretty(&eval_data)?;
match fs::write(&eval_file, json_content) {
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
Err(e) => println!("✗ Failed to write results file: {}", e),
}
Ok(())
}
async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
let exit_status = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.status()
.await?;
if exit_status.success() {
Ok(())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
exit_status,
))
}
}
pub async fn run_exercise_eval(
exercise_path: PathBuf,
template: Template,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
base_sha: String,
_framework_path: PathBuf,
cx: gpui::AsyncApp,
) -> Result<EvalResult> {
let exercise_name = get_exercise_name(&exercise_path);
let language = get_exercise_language(&exercise_path)?;
let mut instructions = read_instructions(&exercise_path).await?;
instructions.push_str(&format!(
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
language
));
let example_solution = read_example_solution(&exercise_path, &language).await?;
async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args.clone())
.output()
println!(
"Running evaluation for exercise: {} with template: {}",
exercise_name, template.name
);
// Create temporary directory with exercise files
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
let temp_path = temp_dir.path().to_path_buf();
if template.name == "ProjectCreation" {
for entry in fs::read_dir(&temp_path)? {
let entry = entry?;
let path = entry.path();
// Skip directories that start with dot (like .docs, .meta, .git)
if path.is_dir()
&& path
.file_name()
.and_then(|name| name.to_str())
.map(|name| name.starts_with("."))
.unwrap_or(false)
{
continue;
}
// Delete regular files
if path.is_file() {
println!(" Deleting file: {}", path.display());
fs::remove_file(path)?;
}
}
// Commit the deletion so it shows up in the diff
run_git(&temp_path, &["add", "."]).await?;
run_git(
&temp_path,
&["commit", "-m", "Remove root files for clean slate"],
)
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"`git {}` failed with {}",
args.join(" "),
output.status
))
}
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
// Prepare prompt based on template
let prompt = match template.name {
"ProjectCreation" => format!(
"I need to create a new implementation for this exercise. Please create all the necessary files in the best location.\n\n{}",
instructions
),
"CodeModification" => format!(
"I need help updating my code to meet these requirements. Please modify the appropriate files:\n\n{}",
instructions
),
"ConversationalGuidance" => format!(
"I'm trying to solve this coding exercise but I'm not sure where to start. Can you help me understand the requirements and guide me through the solution process without writing code for me?\n\n{}",
instructions
),
_ => instructions.clone(),
};
let start_time = SystemTime::now();
// Create a basic eval struct to work with the existing system
let eval = Eval {
repo_path: temp_path.clone(),
eval_setup: EvalSetup {
url: format!("file://{}", temp_path.display()),
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
},
user_prompt: prompt,
};
// Run the evaluation
let eval_output = cx
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
.await?;
// Get diff from git
let diff = eval_output.diff.clone();
// For project creation template, we need to compare with reference implementation
let judge_output = if template.name == "ProjectCreation" {
let project_judge_prompt = template
.content
.replace(
"<!-- ```requirements go here``` -->",
&format!("```\n{}\n```", instructions),
)
.replace(
"<!-- ```reference code goes here``` -->",
&format!("```{}\n{}\n```", language, example_solution),
)
.replace(
"<!-- ```git diff goes here``` -->",
&format!("```\n{}\n```", diff),
);
// Use the run_with_prompt method which we'll add to judge.rs
let judge = crate::judge::Judge {
original_diff: None,
original_message: Some(project_judge_prompt),
model: judge_model.clone(),
};
cx.update(|cx| judge.run_with_prompt(cx))?.await?
} else if template.name == "CodeModification" {
// For CodeModification, we'll compare the example solution with the LLM-generated solution
let code_judge_prompt = template
.content
.replace(
"<!-- ```reference code goes here``` -->",
&format!("```{}\n{}\n```", language, example_solution),
)
.replace(
"<!-- ```git diff goes here``` -->",
&format!("```\n{}\n```", diff),
);
// Use the run_with_prompt method
let judge = crate::judge::Judge {
original_diff: None,
original_message: Some(code_judge_prompt),
model: judge_model.clone(),
};
cx.update(|cx| judge.run_with_prompt(cx))?.await?
} else {
// Conversational template
let conv_judge_prompt = template
.content
.replace(
"<!-- ```query goes here``` -->",
&format!("```\n{}\n```", instructions),
)
.replace(
"<!-- ```transcript goes here``` -->",
&format!("```\n{}\n```", eval_output.last_message),
)
.replace(
"<!-- ```git diff goes here``` -->",
&format!("```\n{}\n```", diff),
);
// Use the run_with_prompt method for consistency
let judge = crate::judge::Judge {
original_diff: None,
original_message: Some(conv_judge_prompt),
model: judge_model.clone(),
};
cx.update(|cx| judge.run_with_prompt(cx))?.await?
};
let elapsed_time = start_time.elapsed()?;
// Calculate total tokens as the sum of input and output tokens
let input_tokens = eval_output.token_usage.input_tokens;
let output_tokens = eval_output.token_usage.output_tokens;
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
let total_tokens = input_tokens + output_tokens;
// Get judge model name
let judge_model_name = judge_model.id().0.to_string();
// Save results to evaluation directory
let result = EvalResult {
exercise_name: exercise_name.clone(),
template_name: template.name.to_string(),
score: judge_output.trim().to_string(),
diff,
assistant_response: eval_output.last_message.clone(),
elapsed_time_ms: elapsed_time.as_millis(),
timestamp: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis(),
// Convert u32 token counts to usize
input_tokens: input_tokens.try_into().unwrap(),
output_tokens: output_tokens.try_into().unwrap(),
total_tokens: total_tokens.try_into().unwrap(),
tool_use_counts: tool_use_counts.try_into().unwrap(),
judge_model_name, // Add judge model name to result
};
Ok(result)
}

View File

@@ -0,0 +1,149 @@
use anyhow::{Result, anyhow};
use std::{
fs,
path::{Path, PathBuf},
};
pub fn get_exercise_name(exercise_path: &Path) -> String {
exercise_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
}
pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
// Extract the language from path (data/python/exercises/... => python)
let parts: Vec<_> = exercise_path.components().collect();
for (i, part) in parts.iter().enumerate() {
if i > 0 && part.as_os_str() == "eval_code" {
if i + 1 < parts.len() {
let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
return Ok(language);
}
}
}
Err(anyhow!(
"Could not determine language from path: {:?}",
exercise_path
))
}
pub fn find_exercises(
framework_path: &Path,
languages: &[&str],
max_per_language: Option<usize>,
) -> Result<Vec<PathBuf>> {
let mut all_exercises = Vec::new();
println!("Searching for exercises in languages: {:?}", languages);
for language in languages {
let language_dir = framework_path
.join("eval_code")
.join(language)
.join("exercises")
.join("practice");
println!("Checking language directory: {:?}", language_dir);
if !language_dir.exists() {
println!("Warning: Language directory not found: {:?}", language_dir);
continue;
}
let mut exercises = Vec::new();
match fs::read_dir(&language_dir) {
Ok(entries) => {
for entry_result in entries {
match entry_result {
Ok(entry) => {
let path = entry.path();
if path.is_dir() {
// Special handling for "internal" directory
if *language == "internal" {
// Check for repo_info.json to validate it's an internal exercise
let repo_info_path = path.join(".meta").join("repo_info.json");
let instructions_path =
path.join(".docs").join("instructions.md");
if repo_info_path.exists() && instructions_path.exists() {
exercises.push(path);
}
} else {
// Map the language to the file extension - original code
let language_extension = match *language {
"python" => "py",
"go" => "go",
"rust" => "rs",
"typescript" => "ts",
"javascript" => "js",
"ruby" => "rb",
"php" => "php",
"bash" => "sh",
"multi" => "diff",
_ => continue, // Skip unsupported languages
};
// Check if this is a valid exercise with instructions and example
let instructions_path =
path.join(".docs").join("instructions.md");
let has_instructions = instructions_path.exists();
let example_path = path
.join(".meta")
.join(format!("example.{}", language_extension));
let has_example = example_path.exists();
if has_instructions && has_example {
exercises.push(path);
}
}
}
}
Err(err) => println!("Error reading directory entry: {}", err),
}
}
}
Err(err) => println!(
"Error reading directory {}: {}",
language_dir.display(),
err
),
}
// Sort exercises by name for consistent selection
exercises.sort_by(|a, b| {
let a_name = a.file_name().unwrap_or_default().to_string_lossy();
let b_name = b.file_name().unwrap_or_default().to_string_lossy();
a_name.cmp(&b_name)
});
// Apply the limit if specified
if let Some(limit) = max_per_language {
if exercises.len() > limit {
println!(
"Limiting {} exercises to {} for language {}",
exercises.len(),
limit,
language
);
exercises.truncate(limit);
}
}
println!(
"Found {} exercises for language {}: {:?}",
exercises.len(),
language,
exercises
.iter()
.map(|p| p.file_name().unwrap_or_default().to_string_lossy())
.collect::<Vec<_>>()
);
all_exercises.extend(exercises);
}
Ok(all_exercises)
}

View File

@@ -0,0 +1,125 @@
use anyhow::{Result, anyhow};
use serde::Deserialize;
use std::{fs, path::Path};
use tempfile::TempDir;
use util::command::new_smol_command;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
pub struct SetupConfig {
#[serde(rename = "base.sha")]
pub base_sha: String,
}
#[derive(Debug, Deserialize)]
pub struct RepoInfo {
pub remote_url: String,
pub head_sha: String,
}
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"Git command failed: {} with status: {}",
args.join(" "),
output.status
))
}
}
pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
let setup_path = framework_path.join("setup.json");
let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
Ok(setup_config.base_sha)
}
pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
println!("Reading repo info from: {}", repo_info_path.display());
let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
// Remove any quotes from the strings
let remote_url = repo_info.remote_url.trim_matches('"').to_string();
let head_sha = repo_info.head_sha.trim_matches('"').to_string();
Ok(RepoInfo {
remote_url,
head_sha,
})
}
pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
let temp_dir = TempDir::new()?;
// Check if this is an internal exercise by looking for repo_info.json
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
if repo_info_path.exists() {
// This is an internal exercise, handle it differently
let repo_info = read_repo_info(exercise_path).await?;
// Clone the repository to the temp directory
let url = repo_info.remote_url;
let clone_path = temp_dir.path();
println!(
"Cloning repository from {} to {}",
url,
clone_path.display()
);
run_git(
&std::env::current_dir()?,
&["clone", &url, &clone_path.to_string_lossy()],
)
.await?;
// Checkout the specified commit
println!("Checking out commit: {}", repo_info.head_sha);
run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
println!("Successfully set up internal repository");
} else {
// Original code for regular exercises
// Copy the exercise files to the temp directory, excluding .docs and .meta
for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
let entry = entry?;
let source_path = entry.path();
// Skip .docs and .meta directories completely
if source_path.starts_with(exercise_path.join(".docs"))
|| source_path.starts_with(exercise_path.join(".meta"))
{
continue;
}
if source_path.is_file() {
let relative_path = source_path.strip_prefix(exercise_path)?;
let dest_path = temp_dir.path().join(relative_path);
// Make sure parent directories exist
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(source_path, dest_path)?;
}
}
// Initialize git repo in the temp directory
run_git(temp_dir.path(), &["init"]).await?;
run_git(temp_dir.path(), &["add", "."]).await?;
run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
println!("Created temp repo without .docs and .meta directories");
}
Ok(temp_dir)
}

View File

@@ -102,6 +102,40 @@ impl HeadlessAssistant {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
// Get the tools needing confirmation
let tools_needing_confirmation: Vec<_> = thread
.read(cx)
.tools_needing_confirmation()
.cloned()
.collect();
// Run each tool that needs confirmation
for tool_use in tools_needing_confirmation {
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
thread.update(cx, |thread, cx| {
println!("Auto-approving tool: {}", tool_use.name);
// Create a request to send to the tool
let request = thread.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
// Run the tool
thread.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
tool,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
@@ -122,11 +156,15 @@ impl HeadlessAssistant {
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
if let Some(model) = model_registry.default_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(vec![], cx);
thread.send_to_model(model, RequestKind::Chat, cx);
thread.attach_tool_results(cx);
thread.send_to_model(model.model, RequestKind::Chat, cx);
});
} else {
println!(
"Warning: No active language model available to continue conversation"
);
}
}
}

View File

@@ -1,58 +1,28 @@
use crate::eval::EvalOutput;
use crate::headless_assistant::send_language_model_request;
use anyhow::anyhow;
use gpui::{App, Task};
use language_model::{
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use std::{path::Path, sync::Arc};
use std::sync::Arc;
pub struct Judge {
pub original_diff: Option<String>,
#[allow(dead_code)]
pub original_diff: Option<String>,
pub original_message: Option<String>,
pub model: Arc<dyn LanguageModel>,
}
impl Judge {
pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
let original_diff_path = eval_path.join("original.diff");
let original_diff = smol::unblock(move || {
if std::fs::exists(&original_diff_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
} else {
anyhow::Ok(None)
}
});
let original_message_path = eval_path.join("original_message.txt");
let original_message = smol::unblock(move || {
if std::fs::exists(&original_message_path)? {
anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
} else {
anyhow::Ok(None)
}
});
Ok(Self {
original_diff: original_diff.await?,
original_message: original_message.await?,
model,
})
}
pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
let Some(original_diff) = self.original_diff.as_ref() else {
return Task::ready(Err(anyhow!("No original.diff found")));
pub fn run_with_prompt(&self, cx: &mut App) -> Task<anyhow::Result<String>> {
let Some(prompt) = self.original_message.as_ref() else {
return Task::ready(Err(anyhow!("No prompt provided in original_message")));
};
// TODO: check for empty diff?
let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(prompt)],
content: vec![MessageContent::Text(prompt.clone())],
cache: false,
}],
temperature: Some(0.0),
@@ -61,61 +31,7 @@ impl Judge {
};
let model = self.model.clone();
let request = request.clone();
cx.spawn(async move |cx| send_language_model_request(model, request, cx).await)
}
}
pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
format!(
r#"# Git Diff Similarity Evaluation Template
## Instructions
Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
- 1.0 = Perfect functional match (achieves identical results)
- 0.0 = No functional similarity whatsoever
## Evaluation Criteria
Please consider the following aspects in order of importance:
1. **Functional Equivalence (60%)**
- Do both diffs achieve the same end result?
- Are the changes functionally equivalent despite possibly using different approaches?
- Do the modifications address the same issues or implement the same features?
2. **Logical Structure (20%)**
- Are the logical flows similar?
- Do the modifications affect the same code paths?
- Are control structures (if/else, loops, etc.) modified in similar ways?
3. **Code Content (15%)**
- Are similar lines added/removed?
- Are the same variables, functions, or methods being modified?
- Are the same APIs or libraries being used?
4. **File Layout (5%)**
- Are the same files being modified?
- Are changes occurring in similar locations within files?
## Input
Original Diff:
```git
{}
```
New Diff:
```git
{}
```
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
Example output:
0.85"#,
original_diff, new_diff
)
}

View File

@@ -1,18 +1,21 @@
mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
mod judge;
mod templates_eval;
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::future;
use gpui::{Application, AsyncApp};
use headless_assistant::{HeadlessAppState, authenticate_model_provider, find_model};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
use std::{path::PathBuf, sync::Arc};
use templates_eval::all_templates;
#[derive(Parser, Debug)]
#[command(
@@ -21,204 +24,231 @@ use std::{cmp, path::PathBuf, sync::Arc};
before_help = "Tool eval runner"
)]
struct Args {
/// Regexes to match the names of evals to run.
eval_name_regexes: Vec<String>,
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
/// Match the names of evals to run.
#[arg(long)]
exercise_names: Vec<String>,
/// Runs all exercises, causes the exercise_names to be ignored.
#[arg(long)]
all: bool,
/// Supported language types to evaluate (default: internal).
/// Internal is data generated from the agent panel
#[arg(long, default_value = "internal")]
languages: String,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 10)
#[arg(short, long, default_value = "10")]
/// Number of evaluations to run concurrently (default: 3)
#[arg(short, long, default_value = "3")]
concurrency: usize,
/// Maximum number of exercises to evaluate per language
#[arg(long)]
max_exercises_per_language: Option<usize>,
}
// First, let's define the order in which templates should be executed
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
"ProjectCreation",
"CodeModification",
"ConversationalGuidance",
];
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
// Path to the zed-ace-framework repo
let framework_path = PathBuf::from("../zed-ace-framework")
.canonicalize()
.unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap()
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
.collect::<Vec<_>>();
let evals_to_run = if args.all {
all_evals
} else {
args.eval_name_regexes
.into_iter()
.map(|regex_string| Regex::new(&regex_string).unwrap())
.flat_map(|regex| {
all_evals
.iter()
.filter(|eval_name| regex.is_match(eval_name))
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
};
if evals_to_run.is_empty() {
panic!("Names of evals to run must be provided or `--all` specified");
}
println!("Will run the following evals: {evals_to_run:?}");
println!("Running up to {} evals concurrently", args.concurrency);
let editor_model_name = if let Some(model_name) = args.editor_model_name {
model_name
} else {
args.model_name.clone()
};
let judge_model_name = if let Some(model_name) = args.judge_model_name {
model_name
} else {
args.model_name.clone()
};
println!("Using zed-ace-framework at: {:?}", framework_path);
println!("Evaluating languages: {:?}", languages);
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = find_model(&editor_model_name, cx).unwrap();
let judge_model = find_model(&judge_model_name, cx).unwrap();
let judge_model = if let Some(model_name) = &args.judge_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
registry.set_editor_model(Some(editor_model.clone()), cx);
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
let framework_path_clone = framework_path.clone();
let languages_clone = languages.clone();
let exercise_names = args.exercise_names.clone();
let all_flag = args.all;
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
let eval_load_futures = evals_to_run
// Read base SHA from setup.json
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
// Find all exercises for the specified languages
let all_exercises = find_exercises(
&framework_path_clone,
&languages_clone
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
args.max_exercises_per_language,
)
.unwrap();
println!("Found {} exercises total", all_exercises.len());
// Filter exercises if specific ones were requested
let exercises_to_run = if !exercise_names.is_empty() {
// If exercise names are specified, filter by them regardless of --all flag
all_exercises
.into_iter()
.filter(|path| {
let name = get_exercise_name(path);
exercise_names.iter().any(|filter| name.contains(filter))
})
.collect()
} else if all_flag {
// Only use all_flag if no exercise names are specified
all_exercises
} else {
// Default behavior (no filters)
all_exercises
};
println!("Will run {} exercises", exercises_to_run.len());
// Get all templates and sort them according to the execution order
let mut templates = all_templates();
templates.sort_by_key(|template| {
TEMPLATE_EXECUTION_ORDER
.iter()
.position(|&name| name == template.name)
.unwrap_or(usize::MAX)
});
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
let exercise_tasks: Vec<_> = exercises_to_run
.into_iter()
.map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name);
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
.map(|exercise_path| {
let exercise_name = get_exercise_name(&exercise_path);
let templates_clone = templates.clone();
let model_clone = model.clone();
let judge_model_clone = judge_model.clone();
let app_state_clone = app_state.clone();
let base_sha_clone = base_sha.clone();
let framework_path_clone = framework_path_clone.clone();
let cx_clone = cx.clone();
async move {
match load_future.await {
Ok(eval) => Some(eval),
println!("Processing exercise: {}", exercise_name);
let mut exercise_results = Vec::new();
// Determine the language for this exercise
let language = match get_exercise_language(&exercise_path) {
Ok(lang) => lang,
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error loading {eval_name}: {err}");
None
println!(
"Error determining language for {}: {}",
exercise_name, err
);
return exercise_results;
}
};
// Run each template sequentially for this exercise
for template in templates_clone {
// For "multi" or "internal" language, only run the CodeModification template
if (language == "multi" || language == "internal")
&& template.name != "CodeModification"
{
println!(
"Skipping {} template for {} language",
template.name, language
);
continue;
}
match run_exercise_eval(
exercise_path.clone(),
template.clone(),
model_clone.clone(),
judge_model_clone.clone(),
app_state_clone.clone(),
base_sha_clone.clone(),
framework_path_clone.clone(),
cx_clone.clone(),
)
.await
{
Ok(result) => {
println!(
"Completed {} with template {} - score: {}",
exercise_name, template.name, result.score
);
exercise_results.push(result);
}
Err(err) => {
println!(
"Error running {} with template {}: {}",
exercise_name, template.name, err
);
}
}
}
}
})
.collect::<Vec<_>>();
let loaded_evals = future::join_all(eval_load_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// The evals need to be loaded and grouped by URL before concurrently running, since
// evals that use the same remote URL will use the same working directory.
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
.into_iter()
.map(|eval| (eval.eval_setup.url.clone(), eval))
.into_group_map()
.into_values()
.collect::<Vec<_>>();
// Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| {
let model = model.clone();
let judge_model = judge_model.clone();
let app_state = app_state.clone();
let cx = cx.clone();
async move {
let mut results = Vec::new();
for eval in evals {
let name = eval.name.clone();
println!("Starting eval named {}", name);
let result = run_eval(
eval,
model.clone(),
judge_model.clone(),
app_state.clone(),
cx.clone(),
)
.await;
results.push((name, result));
// Save results for this exercise
if !exercise_results.is_empty() {
if let Err(err) =
save_eval_results(&exercise_path, exercise_results.clone()).await
{
println!("Error saving results for {}: {}", exercise_name, err);
} else {
println!("Saved results for {}", exercise_name);
}
}
results
exercise_results
}
})
.collect::<Vec<_>>();
.collect();
let results = future::join_all(result_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
println!(
"Running {} exercises with concurrency: {}",
exercise_tasks.len(),
args.concurrency
);
// Process results in order of completion
for (eval_name, result) in results {
match result {
Ok((eval_output, judge_output)) => {
println!("Generated diff for {eval_name}:\n");
println!("{}\n", eval_output.diff);
println!("Last message for {eval_name}:\n");
println!("{}\n", eval_output.last_message);
println!("Elapsed time: {:?}", eval_output.elapsed_time);
println!(
"Assistant response count: {}",
eval_output.assistant_response_count
);
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
println!("Judge output for {eval_name}: {judge_output}");
}
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error running {eval_name}: {err}");
}
}
}
// Run exercises concurrently, with each exercise running its templates sequentially
let all_results = stream::iter(exercise_tasks)
.buffer_unordered(args.concurrency)
.flat_map(stream::iter)
.collect::<Vec<_>>()
.await;
println!("Completed {} evaluation runs", all_results.len());
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
@@ -226,18 +256,3 @@ fn main() {
println!("Done running evals");
}
async fn run_eval(
eval: Eval,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
let path = eval.path.clone();
let judge = Judge::load(&path, judge_model).await?;
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
eval_output.save_to_directory(&path, judge_output.to_string())?;
Ok((eval_output, judge_output))
}

View File

@@ -0,0 +1,210 @@
#[derive(Clone, Debug)]
pub struct Template {
pub name: &'static str,
pub content: &'static str,
}
pub fn all_templates() -> Vec<Template> {
vec![
Template {
name: "ProjectCreation",
content: r#"
# Project Creation Evaluation Template
## Instructions
Evaluate how well the AI assistant created a new implementation from scratch. Score it between 0.0 and 1.0 based on quality and fulfillment of requirements.
- 1.0 = Perfect implementation that creates all necessary files with correct functionality.
- 0.0 = Completely fails to create working files or meet requirements.
Note: A git diff output is required. If no code changes are provided (i.e., no git diff output), the score must be 0.0.
## Evaluation Criteria
Please consider the following aspects in order of importance:
1. **File Creation (25%)**
- Did the assistant create all necessary files?
- Are the files appropriately named and organized?
- Did the assistant create a complete solution without missing components?
2. **Functional Correctness (40%)**
- Does the implementation fulfill all specified requirements?
- Does it handle edge cases properly?
- Is it free of logical errors and bugs?
- Do all components work together as expected?
3. **Code Quality (20%)**
- Is the code well-structured, readable and well-documented?
- Does it follow language-specific best practices?
- Is there proper error handling?
- Are naming conventions clear and consistent?
4. **Architecture Design (15%)**
- Is the code modular and extensible?
- Is there proper separation of concerns?
- Are appropriate design patterns used?
- Is the overall architecture appropriate for the requirements?
## Input
Requirements:
<!-- ```requirements go here``` -->
Reference Implementation:
<!-- ```reference code goes here``` -->
AI-Generated Implementation (git diff output):
<!-- ```git diff goes here``` -->
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
EXAMPLE ONE:
0.92
EXAMPLE TWO:
0.85
EXAMPLE THREE:
0.78
"#,
},
Template {
name: "CodeModification",
content: r#"
# Code Modification Evaluation Template
## Instructions
Evaluate how well the AI assistant modified existing code to meet requirements. Score between 0.0 and 1.0 based on quality and appropriateness of changes.
- 1.0 = Perfect modifications that correctly implement all requirements.
- 0.0 = Failed to make appropriate changes or introduced serious errors.
## Evaluation Criteria
Please consider the following aspects in order of importance:
1. **Functional Correctness (50%)**
- Do the modifications correctly implement the requirements?
- Did the assistant modify the right files and code sections?
- Are the changes free of bugs and logical errors?
- Do the modifications maintain compatibility with existing code?
2. **Modification Approach (25%)**
- Are the changes minimal and focused on what needs to be changed?
- Did the assistant avoid unnecessary modifications?
- Are the changes integrated seamlessly with the existing codebase?
- Did the assistant preserve the original code style and patterns?
3. **Code Quality (15%)**
- Are the modifications well-structured and documented?
- Do they follow the same conventions as the original code?
- Is there proper error handling in the modified code?
- Are the changes readable and maintainable?
4. **Solution Completeness (10%)**
- Do the modifications completely address all requirements?
- Are there any missing changes or overlooked requirements?
- Did the assistant consider all necessary edge cases?
## Input
Original:
<!-- ```reference code goes here``` -->
New (git diff output):
<!-- ```git diff goes here``` -->
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
EXAMPLE ONE:
0.92
EXAMPLE TWO:
0.85
EXAMPLE THREE:
0.78
"#,
},
Template {
name: "ConversationalGuidance",
content: r#"
# Conversational Guidance Evaluation Template
## Instructions
Evaluate the quality of the AI assistant's conversational guidance and score it between 0.0 and 1.0.
- 1.0 = Perfect guidance with ideal information gathering, clarification, and advice without writing code.
- 0.0 = Completely unhelpful, inappropriate guidance, or wrote code when it should not have.
## Evaluation Criteria
ABSOLUTE REQUIREMENT:
- The assistant should NOT generate complete code solutions in conversation mode.
- If the git diff shows the assistant wrote complete code, the score should be significantly reduced.
1. **Information Gathering Effectiveness (30%)**
- Did the assistant ask relevant and precise questions?
- Did it efficiently narrow down the problem scope?
- Did it avoid unnecessary or redundant questions?
- Was questioning appropriately paced and contextual?
2. **Conceptual Guidance (30%)**
- Did the assistant provide high-level approaches and strategies?
- Did it explain relevant concepts and algorithms?
- Did it offer planning advice without implementing the solution?
- Did it suggest a structured approach to solving the problem?
3. **Educational Value (20%)**
- Did the assistant help the user understand the problem better?
- Did it provide explanations that would help the user learn?
- Did it guide without simply giving away answers?
- Did it encourage the user to think through parts of the problem?
4. **Conversation Quality (20%)**
- Was the conversation logically structured and easy to follow?
- Did the assistant maintain appropriate context throughout?
- Was the interaction helpful without being condescending?
- Did the conversation reach a satisfactory conclusion with clear next steps?
## Input
Initial Query:
<!-- ```query goes here``` -->
Conversation Transcript:
<!-- ```transcript goes here``` -->
Git Diff:
<!-- ```git diff goes here``` -->
## Output Format
THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
EXAMPLE ONE:
0.92
EXAMPLE TWO:
0.85
EXAMPLE THREE:
0.78
"#,
},
]
}

View File

@@ -26,6 +26,7 @@ deepseek = { workspace = true, features = ["schemars"] }
schemars.workspace = true
serde.workspace = true
settings.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
fs.workspace = true

View File

@@ -77,7 +77,9 @@ pub struct AssistantSettings {
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: LanguageModelSelection,
pub editor_model: LanguageModelSelection,
pub inline_assistant_model: Option<LanguageModelSelection>,
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
@@ -95,13 +97,25 @@ impl AssistantSettings {
cx.is_staff() || self.enable_experimental_live_diffs
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
self.inline_assistant_model = Some(LanguageModelSelection { provider, model });
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
self.commit_message_model = Some(LanguageModelSelection { provider, model });
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
self.thread_summary_model = Some(LanguageModelSelection { provider, model });
}
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
Versioned(VersionedAssistantSettingsContent),
Versioned(Box<VersionedAssistantSettingsContent>),
Legacy(LegacyAssistantSettingsContent),
}
@@ -121,14 +135,14 @@ impl JsonSchema for AssistantSettingsContent {
impl Default for AssistantSettingsContent {
fn default() -> Self {
Self::Versioned(VersionedAssistantSettingsContent::default())
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
}
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
},
@@ -138,8 +152,8 @@ impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
@@ -186,7 +200,9 @@ impl AssistantSettingsContent {
})
}
}),
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -194,7 +210,7 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
enabled: None,
@@ -211,7 +227,9 @@ impl AssistantSettingsContent {
.id()
.to_string(),
}),
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -224,11 +242,11 @@ impl AssistantSettingsContent {
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
settings.dock = Some(dock);
}
VersionedAssistantSettingsContent::V2(settings) => {
VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.dock = Some(dock);
}
},
@@ -243,77 +261,79 @@ impl AssistantSettingsContent {
let provider = language_model.provider_id().0.to_string();
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
});
}
"ollama" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
});
}
"lmstudio" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None)),
api_url,
});
}
"openai" => {
let (api_url, available_models) = match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
match provider.as_ref() {
"zed.dev" => {
log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
});
}
"ollama" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::Ollama { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
});
}
"lmstudio" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::LmStudio { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::LmStudio {
default_model: Some(lmstudio::Model::new(&model, None, None)),
api_url,
});
}
"openai" => {
let (api_url, available_models) = match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
..
}) => (api_url.clone(), available_models.clone()),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: OpenAiModel::from_id(&model).ok(),
api_url,
available_models,
});
});
}
"deepseek" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
default_model: DeepseekModel::from_id(&model).ok(),
api_url,
});
}
_ => {}
}
"deepseek" => {
let api_url = match &settings.provider {
Some(AssistantProviderContentV1::DeepSeek { api_url, .. }) => {
api_url.clone()
}
_ => None,
};
settings.provider = Some(AssistantProviderContentV1::DeepSeek {
default_model: DeepseekModel::from_id(&model).ok(),
api_url,
});
}
_ => {}
},
VersionedAssistantSettingsContent::V2(settings) => {
}
VersionedAssistantSettingsContent::V2(ref mut settings) => {
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
@@ -325,23 +345,48 @@ impl AssistantSettingsContent {
}
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
settings.always_allow_tool_actions = Some(allow);
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.always_allow_tool_actions = Some(allow);
}
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
settings.default_profile = Some(profile_id);
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.default_profile = Some(profile_id);
}
}
pub fn create_profile(
@@ -349,37 +394,37 @@ impl AssistantSettingsContent {
profile_id: AgentProfileId,
profile: AgentProfile,
) -> Result<()> {
let AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(settings)) =
self
else {
let AssistantSettingsContent::Versioned(boxed) = self else {
return Ok(());
};
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
}
profiles.insert(
profile_id,
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
profiles.insert(
profile_id,
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
.map(|(server_id, preset)| {
(
server_id,
ContextServerPresetContent {
tools: preset.tools,
},
)
})
.collect(),
},
);
}
Ok(())
}
@@ -403,7 +448,9 @@ impl Default for VersionedAssistantSettingsContent {
default_width: None,
default_height: None,
default_model: None,
editor_model: None,
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
@@ -436,10 +483,14 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new chats.
/// The default model to use when creating new chats and for other features when a specific model is not specified.
default_model: Option<LanguageModelSelection>,
/// The model to use when applying edits from the assistant.
editor_model: Option<LanguageModelSelection>,
/// Model to use for the inline assistant. Defaults to default_model when not specified.
inline_assistant_model: Option<LanguageModelSelection>,
/// Model to use for generating git commit messages. Defaults to default_model when not specified.
commit_message_model: Option<LanguageModelSelection>,
/// Model to use for generating thread summaries. Defaults to default_model when not specified.
thread_summary_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
@@ -601,7 +652,15 @@ impl Settings for AssistantSettings {
value.default_height.map(Into::into),
);
merge(&mut settings.default_model, value.default_model);
merge(&mut settings.editor_model, value.editor_model);
settings.inline_assistant_model = value
.inline_assistant_model
.or(settings.inline_assistant_model.take());
settings.commit_message_model = value
.commit_message_model
.or(settings.commit_message_model.take());
settings.thread_summary_model = value
.thread_summary_model
.or(settings.thread_summary_model.take());
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
@@ -692,16 +751,15 @@ mod tests {
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
fs.clone(),
|settings, _| {
*settings = AssistantSettingsContent::Versioned(
*settings = AssistantSettingsContent::Versioned(Box::new(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
editor_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
@@ -714,7 +772,7 @@ mod tests {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
}),
)
))
},
);
});

View File

@@ -26,6 +26,7 @@ serde.workspace = true
serde_json.workspace = true
ui.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -42,6 +42,7 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
worktree.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
env_logger.workspace = true

View File

@@ -28,6 +28,7 @@ serde.workspace = true
serde_json.workspace = true
text.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View File

@@ -1,31 +1,31 @@
use anyhow::{Context as _, Result};
use buffer_diff::BufferDiff;
use collections::{BTreeMap, HashSet};
use collections::BTreeMap;
use futures::{StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
use project::{Project, ProjectItem};
use std::{cmp, ops::Range, sync::Arc};
use text::{Edit, Patch, Rope};
use util::RangeExt;
/// Tracks actions performed by tools in a thread
pub struct ActionLog {
/// Buffers that user manually added to the context, and whose content has
/// changed since the model last saw them.
stale_buffers_in_context: HashSet<Entity<Buffer>>,
/// Buffers that we want to notify the model about when they change.
tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
/// Has the model edited a file since it last checked diagnostics?
edited_since_project_diagnostics_check: bool,
/// The project this action log is associated with
project: Entity<Project>,
}
impl ActionLog {
/// Creates a new, empty action log.
pub fn new() -> Self {
/// Creates a new, empty action log associated with the given project.
pub fn new(project: Entity<Project>) -> Self {
Self {
stale_buffers_in_context: HashSet::default(),
tracked_buffers: BTreeMap::default(),
edited_since_project_diagnostics_check: false,
project,
}
}
@@ -259,6 +259,11 @@ impl ActionLog {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer that was added as context, so we can notify the model about user edits.
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer.clone(), true, cx);
@@ -268,7 +273,6 @@ impl ActionLog {
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
self.stale_buffers_in_context.insert(buffer.clone());
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
@@ -324,14 +328,14 @@ impl ActionLog {
{
true
} else {
let old_bytes = tracked_buffer
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let new_bytes = tracked_buffer
let new_range = tracked_buffer
.snapshot
.point_to_offset(Point::new(edit.new.start, 0))
..tracked_buffer.snapshot.point_to_offset(cmp::min(
@@ -339,10 +343,10 @@ impl ActionLog {
tracked_buffer.snapshot.max_point(),
));
tracked_buffer.base_text.replace(
old_bytes,
old_range,
&tracked_buffer
.snapshot
.text_for_range(new_bytes)
.text_for_range(new_range)
.collect::<String>(),
);
delta += edit.new_len() as i32 - edit.old_len() as i32;
@@ -354,6 +358,87 @@ impl ActionLog {
}
}
pub fn reject_edits_in_range(
&mut self,
buffer: Entity<Buffer>,
buffer_range: Range<impl language::ToPoint>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
return Task::ready(Ok(()));
};
match tracked_buffer.status {
TrackedBufferStatus::Created => {
let delete = buffer
.read(cx)
.entry_id(cx)
.and_then(|entry_id| {
self.project
.update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
})
.unwrap_or(Task::ready(Ok(())));
self.tracked_buffers.remove(&buffer);
cx.notify();
delete
}
TrackedBufferStatus::Deleted => {
buffer.update(cx, |buffer, cx| {
buffer.set_text(tracked_buffer.base_text.to_string(), cx)
});
let save = self
.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
// Clear all tracked changes for this buffer and start over as if we just read it.
self.tracked_buffers.remove(&buffer);
self.track_buffer(buffer.clone(), false, cx);
cx.notify();
save
}
TrackedBufferStatus::Modified => {
buffer.update(cx, |buffer, cx| {
let buffer_range =
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
let mut edits_to_revert = Vec::new();
for edit in tracked_buffer.unreviewed_changes.edits() {
if buffer_range.end.row < edit.new.start {
break;
} else if buffer_range.start.row > edit.new.end {
continue;
}
let old_range = tracked_buffer
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.base_text
.chunks_in_range(old_range)
.collect::<String>();
let new_range = tracked_buffer
.snapshot
.anchor_before(Point::new(edit.new.start, 0))
..tracked_buffer.snapshot.anchor_after(cmp::min(
Point::new(edit.new.end, 0),
tracked_buffer.snapshot.max_point(),
));
edits_to_revert.push((new_range, old_text));
}
buffer.edit(edits_to_revert, None, cx);
});
self.project
.update(cx, |project, cx| project.save_buffer(buffer, cx))
}
}
}
pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
self.tracked_buffers
.retain(|_buffer, tracked_buffer| match tracked_buffer.status {
@@ -391,11 +476,6 @@ impl ActionLog {
})
.map(|(buffer, _)| buffer)
}
/// Takes and returns the set of buffers pending refresh, clearing internal state.
pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
std::mem::take(&mut self.stale_buffers_in_context)
}
}
fn apply_non_conflicting_edits(
@@ -580,9 +660,22 @@ mod tests {
}
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
#[gpui::test(iterations = 10)]
async fn test_keep_edits(cx: &mut TestAppContext) {
let action_log = cx.new(|_| ActionLog::new());
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
@@ -648,7 +741,11 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_deletions(cx: &mut TestAppContext) {
let action_log = cx.new(|_| ActionLog::new());
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
cx.update(|cx| {
@@ -718,7 +815,11 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_overlapping_user_edits(cx: &mut TestAppContext) {
let action_log = cx.new(|_| ActionLog::new());
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
@@ -802,15 +903,12 @@ mod tests {
}
#[gpui::test(iterations = 10)]
async fn test_creation(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
async fn test_creating_files(cx: &mut TestAppContext) {
init_test(cx);
let action_log = cx.new(|_| ActionLog::new());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
@@ -869,12 +967,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_deleting_files(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
@@ -891,7 +984,7 @@ mod tests {
.read_with(cx, |project, cx| project.find_project_path("dir/file2", cx))
.unwrap();
let action_log = cx.new(|_| ActionLog::new());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let buffer1 = project
.update(cx, |project, cx| {
project.open_buffer(file1_path.clone(), cx)
@@ -981,15 +1074,222 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
.unwrap()
});
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
.unwrap()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndE\nXYZf\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![
HunkStatus {
range: Point::new(1, 0)..Point::new(3, 0),
diff_status: DiffHunkStatusKind::Modified,
old_text: "def\n".into(),
},
HunkStatus {
range: Point::new(5, 0)..Point::new(5, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}
],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmnO"
);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![HunkStatus {
range: Point::new(4, 0)..Point::new(4, 3),
diff_status: DiffHunkStatusKind::Modified,
old_text: "mno".into(),
}],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.text()),
"abc\ndef\nghi\njkl\nmno"
);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_deleted_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "content"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path.clone(), cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| {
project.delete_file(file_path.clone(), false, cx)
})
.unwrap()
.await
.unwrap();
cx.run_until_parked();
assert!(!fs.is_file(path!("/dir/file").as_ref()).await);
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![HunkStatus {
range: Point::new(0, 0)..Point::new(0, 0),
diff_status: DiffHunkStatusKind::Deleted,
old_text: "content".into(),
}]
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "content");
assert!(fs.is_file(path!("/dir/file").as_ref()).await);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 10)]
async fn test_reject_created_file(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| {
project.find_project_path("dir/new_file", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![HunkStatus {
range: Point::new(0, 0)..Point::new(0, 7),
diff_status: DiffHunkStatusKind::Added,
old_text: "".into(),
}],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
})
.await
.unwrap();
cx.run_until_parked();
assert!(!fs.is_file(path!("/dir/new_file").as_ref()).await);
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test(iterations = 100)]
async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
init_test(cx);
let operations = env::var("OPERATIONS")
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(20);
let action_log = cx.new(|_| ActionLog::new());
let text = RandomCharIter::new(&mut rng).take(50).collect::<String>();
let buffer = cx.new(|cx| Buffer::local(text, cx));
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": text})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
for _ in 0..operations {
@@ -997,10 +1297,20 @@ mod tests {
0..25 => {
action_log.update(cx, |log, cx| {
let range = buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("keeping all edits in range {:?}", range);
log::info!("keeping edits in range {:?}", range);
log.keep_edits_in_range(buffer.clone(), range, cx)
});
}
25..50 => {
action_log
.update(cx, |log, cx| {
let range = buffer.read(cx).random_byte_range(0, &mut rng);
log::info!("rejecting edits in range {:?}", range);
log.reject_edits_in_range(buffer.clone(), range, cx)
})
.await
.unwrap();
}
_ => {
let is_agent_change = rng.gen_bool(0.5);
if is_agent_change {

View File

@@ -39,6 +39,7 @@ util.workspace = true
workspace.workspace = true
worktree.workspace = true
open = { workspace = true }
workspace-hack.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }

View File

@@ -9,7 +9,7 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser, edit_model_prompt};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::LanguageModelToolSchemaFormat;
use language_model::{ConfiguredModel, LanguageModelToolSchemaFormat};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
@@ -205,8 +205,8 @@ impl EditToolRequest {
cx: &mut App,
) -> Task<Result<String>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.editor_model() else {
return Task::ready(Err(anyhow!("No editor model configured")));
let Some(ConfiguredModel { model, .. }) = model_registry.default_model() else {
return Task::ready(Err(anyhow!("No model configured")));
};
let mut messages = messages.to_vec();

View File

@@ -67,7 +67,11 @@ impl Tool for ReadFileTool {
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
Ok(input) => {
let path = MarkdownString::inline_code(&input.path.display().to_string());
format!("Read file {path}")
match (input.start_line, input.end_line) {
(Some(start), None) => format!("Read file {path} (from line {start})"),
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
_ => format!("Read file {path}"),
}
}
Err(_) => "Read file".to_string(),
}

View File

@@ -20,3 +20,4 @@ gpui.workspace = true
parking_lot.workspace = true
rodio = { version = "0.20.0", default-features = false, features = ["wav"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -29,3 +29,4 @@ smol.workspace = true
tempfile.workspace = true
which.workspace = true
workspace.workspace = true
workspace-hack.workspace = true

View File

@@ -25,3 +25,4 @@ serde_json.workspace = true
smol.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true

View File

@@ -20,3 +20,4 @@ aws-smithy-types.workspace = true
futures.workspace = true
http_client.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true

View File

@@ -9,7 +9,7 @@ use aws_smithy_runtime_api::client::http::{
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest as AwsHttpRequest, HttpResponse};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::http::StatusCode;
use aws_smithy_runtime_api::http::{Headers, StatusCode};
use aws_smithy_types::body::SdkBody;
use futures::AsyncReadExt;
use http_client::{AsyncBody, Inner};
@@ -52,10 +52,17 @@ impl AwsConnector for AwsHttpConnector {
let (parts, body) = response.into_parts();
let body = convert_to_sdk_body(body, handle).await;
Ok(HttpResponse::new(
StatusCode::try_from(parts.status.as_u16()).unwrap(),
body,
))
let mut response =
HttpResponse::new(StatusCode::try_from(parts.status.as_u16()).unwrap(), body);
let headers = match Headers::try_from(parts.headers) {
Ok(headers) => headers,
Err(err) => return Err(ConnectorError::other(err.into(), None)),
};
*response.headers_mut() = headers;
Ok(response)
})
}
}

View File

@@ -26,3 +26,4 @@ serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true

View File

@@ -20,6 +20,7 @@ theme.workspace = true
ui.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -27,6 +27,7 @@ rope.workspace = true
sum_tree.workspace = true
text.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
ctor.workspace = true

View File

@@ -42,6 +42,7 @@ telemetry.workspace = true
util.workspace = true
gpui_tokio.workspace = true
livekit_client.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }

View File

@@ -32,6 +32,7 @@ sum_tree.workspace = true
text.workspace = true
time.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }

View File

@@ -31,6 +31,7 @@ release_channel.workspace = true
serde.workspace = true
util.workspace = true
tempfile.workspace = true
workspace-hack.workspace = true
[target.'cfg(any(target_os = "linux", target_os = "freebsd"))'.dependencies]
exec.workspace = true

View File

@@ -51,6 +51,7 @@ url.workspace = true
util.workspace = true
worktree.workspace = true
telemetry.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -19,3 +19,4 @@ test-support = ["dep:parking_lot"]
parking_lot = { workspace = true, optional = true }
serde.workspace = true
smallvec.workspace = true
workspace-hack.workspace = true

View File

@@ -76,6 +76,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
assistant = { workspace = true, features = ["test-support"] }

View File

@@ -64,6 +64,7 @@ title_bar.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
call = { workspace = true, features = ["test-support"] }

View File

@@ -18,3 +18,4 @@ test-support = []
[dependencies]
indexmap.workspace = true
rustc-hash.workspace = true
workspace-hack.workspace = true

View File

@@ -31,6 +31,7 @@ util.workspace = true
telemetry.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
ctor.workspace = true

View File

@@ -16,3 +16,4 @@ doctest = false
collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
workspace-hack.workspace = true

View File

@@ -17,6 +17,7 @@ gpui.workspace = true
linkme.workspace = true
parking_lot.workspace = true
theme.workspace = true
workspace-hack.workspace = true
[features]
default = []

View File

@@ -24,3 +24,4 @@ ui.workspace = true
workspace.workspace = true
notifications.workspace = true
collections.workspace = true
workspace-hack.workspace = true

View File

@@ -33,3 +33,4 @@ settings.workspace = true
smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -19,3 +19,4 @@ schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
workspace-hack.workspace = true

View File

@@ -52,6 +52,7 @@ task.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[target.'cfg(windows)'.dependencies]
async-std = { version = "1.12.0", features = ["unstable"] }

View File

@@ -19,3 +19,4 @@ paths.workspace = true
release_channel.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true

View File

@@ -48,6 +48,7 @@ smallvec.workspace = true
smol.workspace = true
task.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
async-pipe.workspace = true

View File

@@ -32,6 +32,7 @@ serde.workspace = true
serde_json.workspace = true
task.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
dap = { workspace = true, features = ["test-support"] }

View File

@@ -26,6 +26,7 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -24,3 +24,4 @@ settings.workspace = true
smol.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true

View File

@@ -50,6 +50,7 @@ theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
dap = { workspace = true, features = ["test-support"] }

View File

@@ -22,3 +22,4 @@ http_client.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true

View File

@@ -30,6 +30,7 @@ theme.workspace = true
ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }

View File

@@ -14,6 +14,7 @@ serde_json.workspace = true
settings.workspace = true
regex.workspace = true
util.workspace = true
workspace-hack.workspace = true
[lints]
workspace = true

View File

@@ -87,6 +87,7 @@ util.workspace = true
uuid.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
ctor.workspace = true

View File

@@ -37,7 +37,7 @@ pub use block_map::{
use block_map::{BlockRow, BlockSnapshot};
use collections::{HashMap, HashSet};
pub use crease_map::*;
pub use fold_map::{Fold, FoldId, FoldPlaceholder, FoldPoint};
pub use fold_map::{ChunkRenderer, ChunkRendererContext, Fold, FoldId, FoldPlaceholder, FoldPoint};
use fold_map::{FoldMap, FoldSnapshot};
use gpui::{App, Context, Entity, Font, HighlightStyle, LineLayout, Pixels, UnderlineStyle};
pub use inlay_map::Inlay;
@@ -45,8 +45,7 @@ use inlay_map::{InlayMap, InlaySnapshot};
pub use inlay_map::{InlayOffset, InlayPoint};
pub use invisibles::{is_invisible, replacement};
use language::{
ChunkRenderer, OffsetUtf16, Point, Subscription as BufferSubscription,
language_settings::language_settings,
OffsetUtf16, Point, Subscription as BufferSubscription, language_settings::language_settings,
};
use lsp::DiagnosticSeverity;
use multi_buffer::{
@@ -515,6 +514,33 @@ impl DisplayMap {
.update(cx, |map, cx| map.set_wrap_width(width, cx))
}
pub fn update_fold_widths(
&mut self,
widths: impl IntoIterator<Item = (FoldId, Pixels)>,
cx: &mut Context<Self>,
) -> bool {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
let tab_size = Self::tab_size(&self.buffer, cx);
let (snapshot, edits) = self.inlay_map.sync(snapshot, edits);
let (mut fold_map, snapshot, edits) = self.fold_map.write(snapshot, edits);
let (snapshot, edits) = self.tab_map.sync(snapshot, edits, tab_size);
let (snapshot, edits) = self
.wrap_map
.update(cx, |map, cx| map.sync(snapshot, edits, cx));
self.block_map.read(snapshot, edits);
let (snapshot, edits) = fold_map.update_fold_widths(widths);
let widths_changed = !edits.is_empty();
let (snapshot, edits) = self.tab_map.sync(snapshot, edits, tab_size);
let (snapshot, edits) = self
.wrap_map
.update(cx, |map, cx| map.sync(snapshot, edits, cx));
self.block_map.read(snapshot, edits);
widths_changed
}
pub(crate) fn current_inlays(&self) -> impl Iterator<Item = &Inlay> {
self.inlay_map.current_inlays()
}

View File

@@ -1,11 +1,12 @@
use super::{
Highlights,
fold_map::Chunk,
wrap_map::{self, WrapEdit, WrapPoint, WrapSnapshot},
};
use crate::{EditorStyle, GutterDimensions};
use collections::{Bound, HashMap, HashSet};
use gpui::{AnyElement, App, EntityId, Pixels, Window};
use language::{Chunk, Patch, Point};
use language::{Patch, Point};
use multi_buffer::{
Anchor, ExcerptId, ExcerptInfo, MultiBuffer, MultiBufferRow, MultiBufferSnapshot, RowInfo,
ToOffset, ToPoint as _,

View File

@@ -2,8 +2,9 @@ use super::{
Highlights,
inlay_map::{InlayBufferRows, InlayChunks, InlayEdit, InlayOffset, InlayPoint, InlaySnapshot},
};
use gpui::{AnyElement, App, ElementId};
use language::{Chunk, ChunkRenderer, Edit, Point, TextSummary};
use gpui::{AnyElement, App, ElementId, HighlightStyle, Pixels, Window};
use language::{Edit, HighlightId, Point, TextSummary};
use lsp::DiagnosticSeverity;
use multi_buffer::{
Anchor, AnchorRangeExt, MultiBufferRow, MultiBufferSnapshot, RowInfo, ToOffset,
};
@@ -14,7 +15,7 @@ use std::{
ops::{Add, AddAssign, Deref, DerefMut, Range, Sub},
sync::Arc,
};
use sum_tree::{Bias, Cursor, FilterCursor, SumTree, Summary};
use sum_tree::{Bias, Cursor, FilterCursor, SumTree, Summary, TreeMap};
use ui::IntoElement as _;
use util::post_inc;
@@ -177,6 +178,13 @@ impl FoldMapWriter<'_> {
let mut new_tree = SumTree::new(buffer);
let mut cursor = self.0.snapshot.folds.cursor::<FoldRange>(buffer);
for fold in folds {
self.0.snapshot.fold_metadata_by_id.insert(
fold.id,
FoldMetadata {
range: fold.range.clone(),
width: None,
},
);
new_tree.append(cursor.slice(&fold.range, Bias::Right, buffer), buffer);
new_tree.push(fold, buffer);
}
@@ -240,6 +248,7 @@ impl FoldMapWriter<'_> {
});
}
fold_ixs_to_delete.push(*folds_cursor.start());
self.0.snapshot.fold_metadata_by_id.remove(&fold.id);
}
folds_cursor.next(buffer);
}
@@ -263,6 +272,42 @@ impl FoldMapWriter<'_> {
let edits = self.0.sync(snapshot.clone(), edits);
(self.0.snapshot.clone(), edits)
}
pub(crate) fn update_fold_widths(
&mut self,
new_widths: impl IntoIterator<Item = (FoldId, Pixels)>,
) -> (FoldSnapshot, Vec<FoldEdit>) {
let mut edits = Vec::new();
let inlay_snapshot = self.0.snapshot.inlay_snapshot.clone();
let buffer = &inlay_snapshot.buffer;
for (id, new_width) in new_widths {
if let Some(metadata) = self.0.snapshot.fold_metadata_by_id.get(&id).cloned() {
if Some(new_width) != metadata.width {
let buffer_start = metadata.range.start.to_offset(buffer);
let buffer_end = metadata.range.end.to_offset(buffer);
let inlay_range = inlay_snapshot.to_inlay_offset(buffer_start)
..inlay_snapshot.to_inlay_offset(buffer_end);
edits.push(InlayEdit {
old: inlay_range.clone(),
new: inlay_range.clone(),
});
self.0.snapshot.fold_metadata_by_id.insert(
id,
FoldMetadata {
range: metadata.range,
width: Some(new_width),
},
);
}
}
}
let edits = consolidate_inlay_edits(edits);
let edits = self.0.sync(inlay_snapshot, edits);
(self.0.snapshot.clone(), edits)
}
}
/// Decides where the fold indicators should be; also tracks parts of a source file that are currently folded.
@@ -290,6 +335,7 @@ impl FoldMap {
),
inlay_snapshot: inlay_snapshot.clone(),
version: 0,
fold_metadata_by_id: TreeMap::default(),
},
next_fold_id: FoldId::default(),
};
@@ -481,6 +527,7 @@ impl FoldMap {
placeholder: Some(TransformPlaceholder {
text: ELLIPSIS,
renderer: ChunkRenderer {
id: fold.id,
render: Arc::new(move |cx| {
(fold.placeholder.render)(
fold_id,
@@ -489,6 +536,7 @@ impl FoldMap {
)
}),
constrain_width: fold.placeholder.constrain_width,
measured_width: self.snapshot.fold_width(&fold_id),
},
}),
},
@@ -573,6 +621,7 @@ impl FoldMap {
pub struct FoldSnapshot {
transforms: SumTree<Transform>,
folds: SumTree<Fold>,
fold_metadata_by_id: TreeMap<FoldId, FoldMetadata>,
pub inlay_snapshot: InlaySnapshot,
pub version: usize,
}
@@ -582,6 +631,10 @@ impl FoldSnapshot {
&self.inlay_snapshot.buffer
}
fn fold_width(&self, fold_id: &FoldId) -> Option<Pixels> {
self.fold_metadata_by_id.get(fold_id)?.width
}
#[cfg(test)]
pub fn text(&self) -> String {
self.chunks(FoldOffset(0)..self.len(), false, Highlights::default())
@@ -1006,7 +1059,7 @@ impl sum_tree::Summary for TransformSummary {
}
}
#[derive(Copy, Clone, Eq, PartialEq, Debug, Default)]
#[derive(Copy, Clone, Eq, PartialEq, Debug, Default, Ord, PartialOrd, Hash)]
pub struct FoldId(usize);
impl From<FoldId> for ElementId {
@@ -1045,6 +1098,12 @@ impl Default for FoldRange {
}
}
#[derive(Clone, Debug)]
struct FoldMetadata {
range: FoldRange,
width: Option<Pixels>,
}
impl sum_tree::Item for Fold {
type Summary = FoldSummary;
@@ -1181,10 +1240,74 @@ impl Iterator for FoldRows<'_> {
}
}
/// A chunk of a buffer's text, along with its syntax highlight and
/// diagnostic status.
#[derive(Clone, Debug, Default)]
pub struct Chunk<'a> {
/// The text of the chunk.
pub text: &'a str,
/// The syntax highlighting style of the chunk.
pub syntax_highlight_id: Option<HighlightId>,
/// The highlight style that has been applied to this chunk in
/// the editor.
pub highlight_style: Option<HighlightStyle>,
/// The severity of diagnostic associated with this chunk, if any.
pub diagnostic_severity: Option<DiagnosticSeverity>,
/// Whether this chunk of text is marked as unnecessary.
pub is_unnecessary: bool,
/// Whether this chunk of text was originally a tab character.
pub is_tab: bool,
/// An optional recipe for how the chunk should be presented.
pub renderer: Option<ChunkRenderer>,
}
/// A recipe for how the chunk should be presented.
#[derive(Clone)]
pub struct ChunkRenderer {
/// The id of the fold associated with this chunk.
pub id: FoldId,
/// Creates a custom element to represent this chunk.
pub render: Arc<dyn Send + Sync + Fn(&mut ChunkRendererContext) -> AnyElement>,
/// If true, the element is constrained to the shaped width of the text.
pub constrain_width: bool,
/// The width of the element, as measured during the last layout pass.
///
/// This is None if the element has not been laid out yet.
pub measured_width: Option<Pixels>,
}
pub struct ChunkRendererContext<'a, 'b> {
pub window: &'a mut Window,
pub context: &'b mut App,
pub max_width: Pixels,
}
impl fmt::Debug for ChunkRenderer {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ChunkRenderer")
.field("constrain_width", &self.constrain_width)
.finish()
}
}
impl Deref for ChunkRendererContext<'_, '_> {
type Target = App;
fn deref(&self) -> &Self::Target {
self.context
}
}
impl DerefMut for ChunkRendererContext<'_, '_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.context
}
}
pub struct FoldChunks<'a> {
transform_cursor: Cursor<'a, Transform, (FoldOffset, InlayOffset)>,
inlay_chunks: InlayChunks<'a>,
inlay_chunk: Option<(InlayOffset, Chunk<'a>)>,
inlay_chunk: Option<(InlayOffset, language::Chunk<'a>)>,
inlay_offset: InlayOffset,
output_offset: FoldOffset,
max_output_offset: FoldOffset,
@@ -1292,7 +1415,15 @@ impl<'a> Iterator for FoldChunks<'a> {
self.inlay_offset = chunk_end;
self.output_offset.0 += chunk.text.len();
return Some(chunk);
return Some(Chunk {
text: chunk.text,
syntax_highlight_id: chunk.syntax_highlight_id,
highlight_style: chunk.highlight_style,
diagnostic_severity: chunk.diagnostic_severity,
is_unnecessary: chunk.is_unnecessary,
is_tab: chunk.is_tab,
renderer: None,
});
}
None

View File

@@ -1,8 +1,8 @@
use super::{
Highlights,
fold_map::{self, FoldChunks, FoldEdit, FoldPoint, FoldSnapshot},
fold_map::{self, Chunk, FoldChunks, FoldEdit, FoldPoint, FoldSnapshot},
};
use language::{Chunk, Point};
use language::Point;
use multi_buffer::MultiBufferSnapshot;
use std::{cmp, mem, num::NonZeroU32, ops::Range};
use sum_tree::Bias;

View File

@@ -1,10 +1,10 @@
use super::{
Highlights,
fold_map::FoldRows,
fold_map::{Chunk, FoldRows},
tab_map::{self, TabEdit, TabPoint, TabSnapshot},
};
use gpui::{App, AppContext as _, Context, Entity, Font, LineWrapper, Pixels, Task};
use language::{Chunk, Point};
use language::Point;
use multi_buffer::{MultiBufferSnapshot, RowInfo};
use smol::future::yield_now;
use std::sync::LazyLock;
@@ -454,6 +454,7 @@ impl WrapSnapshot {
}
let mut line = String::new();
let mut line_fragments = Vec::new();
let mut remaining = None;
let mut chunks = new_tab_snapshot.chunks(
TabPoint::new(edit.new_rows.start, 0)..new_tab_snapshot.max_point(),
@@ -462,15 +463,26 @@ impl WrapSnapshot {
);
let mut edit_transforms = Vec::<Transform>::new();
for _ in edit.new_rows.start..edit.new_rows.end {
while let Some(chunk) =
remaining.take().or_else(|| chunks.next().map(|c| c.text))
{
if let Some(ix) = chunk.find('\n') {
line.push_str(&chunk[..ix + 1]);
remaining = Some(&chunk[ix + 1..]);
while let Some(chunk) = remaining.take().or_else(|| chunks.next()) {
if let Some(ix) = chunk.text.find('\n') {
let (prefix, suffix) = chunk.text.split_at(ix + 1);
line_fragments.push(gpui::LineFragment::text(prefix));
line.push_str(prefix);
remaining = Some(Chunk {
text: suffix,
..chunk
});
break;
} else {
line.push_str(chunk)
if let Some(width) =
chunk.renderer.as_ref().and_then(|r| r.measured_width)
{
line_fragments
.push(gpui::LineFragment::element(width, chunk.text.len()));
} else {
line_fragments.push(gpui::LineFragment::text(chunk.text));
}
line.push_str(chunk.text);
}
}
@@ -479,7 +491,7 @@ impl WrapSnapshot {
}
let mut prev_boundary_ix = 0;
for boundary in line_wrapper.wrap_line(&line, wrap_width) {
for boundary in line_wrapper.wrap_line(&line_fragments, wrap_width) {
let wrapped = &line[prev_boundary_ix..boundary.ix];
push_isomorphic(&mut edit_transforms, TextSummary::from(wrapped));
edit_transforms.push(Transform::wrap(boundary.next_indent));
@@ -494,6 +506,7 @@ impl WrapSnapshot {
}
line.clear();
line_fragments.clear();
yield_now().await;
}
@@ -1173,7 +1186,7 @@ mod tests {
display_map::{fold_map::FoldMap, inlay_map::InlayMap, tab_map::TabMap},
test::test_font,
};
use gpui::{px, test::observe};
use gpui::{LineFragment, px, test::observe};
use rand::prelude::*;
use settings::SettingsStore;
use smol::stream::StreamExt;
@@ -1228,8 +1241,7 @@ mod tests {
log::info!("TabMap text: {:?}", tabs_snapshot.text());
let mut line_wrapper = text_system.line_wrapper(font.clone(), font_size);
let unwrapped_text = tabs_snapshot.text();
let expected_text = wrap_text(&unwrapped_text, wrap_width, &mut line_wrapper);
let expected_text = wrap_text(&tabs_snapshot, wrap_width, &mut line_wrapper);
let (wrap_map, _) =
cx.update(|cx| WrapMap::new(tabs_snapshot.clone(), font, font_size, wrap_width, cx));
@@ -1246,9 +1258,10 @@ mod tests {
let actual_text = initial_snapshot.text();
assert_eq!(
actual_text, expected_text,
actual_text,
expected_text,
"unwrapped text is: {:?}",
unwrapped_text
tabs_snapshot.text()
);
log::info!("Wrapped text: {:?}", actual_text);
@@ -1311,8 +1324,7 @@ mod tests {
let (tabs_snapshot, tab_edits) = tab_map.sync(fold_snapshot, fold_edits, tab_size);
log::info!("TabMap text: {:?}", tabs_snapshot.text());
let unwrapped_text = tabs_snapshot.text();
let expected_text = wrap_text(&unwrapped_text, wrap_width, &mut line_wrapper);
let expected_text = wrap_text(&tabs_snapshot, wrap_width, &mut line_wrapper);
let (mut snapshot, wrap_edits) =
wrap_map.update(cx, |map, cx| map.sync(tabs_snapshot.clone(), tab_edits, cx));
snapshot.check_invariants();
@@ -1328,8 +1340,9 @@ mod tests {
}
if !wrap_map.read_with(cx, |map, _| map.is_rewrapping()) {
let (mut wrapped_snapshot, wrap_edits) =
wrap_map.update(cx, |map, cx| map.sync(tabs_snapshot, Vec::new(), cx));
let (mut wrapped_snapshot, wrap_edits) = wrap_map.update(cx, |map, cx| {
map.sync(tabs_snapshot.clone(), Vec::new(), cx)
});
let actual_text = wrapped_snapshot.text();
let actual_longest_row = wrapped_snapshot.longest_row();
log::info!("Wrapping finished: {:?}", actual_text);
@@ -1337,9 +1350,10 @@ mod tests {
wrapped_snapshot.verify_chunks(&mut rng);
edits.push((wrapped_snapshot.clone(), wrap_edits));
assert_eq!(
actual_text, expected_text,
actual_text,
expected_text,
"unwrapped text is: {:?}",
unwrapped_text
tabs_snapshot.text()
);
let mut summary = TextSummary::default();
@@ -1425,19 +1439,19 @@ mod tests {
}
fn wrap_text(
unwrapped_text: &str,
tab_snapshot: &TabSnapshot,
wrap_width: Option<Pixels>,
line_wrapper: &mut LineWrapper,
) -> String {
if let Some(wrap_width) = wrap_width {
let mut wrapped_text = String::new();
for (row, line) in unwrapped_text.split('\n').enumerate() {
for (row, line) in tab_snapshot.text().split('\n').enumerate() {
if row > 0 {
wrapped_text.push('\n')
wrapped_text.push('\n');
}
let mut prev_ix = 0;
for boundary in line_wrapper.wrap_line(line, wrap_width) {
for boundary in line_wrapper.wrap_line(&[LineFragment::text(&line)], wrap_width) {
wrapped_text.push_str(&line[prev_ix..boundary.ix]);
wrapped_text.push('\n');
wrapped_text.push_str(&" ".repeat(boundary.next_indent as usize));
@@ -1445,9 +1459,10 @@ mod tests {
}
wrapped_text.push_str(&line[prev_ix..]);
}
wrapped_text
} else {
unwrapped_text.to_string()
tab_snapshot.text()
}
}

View File

@@ -58,7 +58,7 @@ use clock::ReplicaId;
use collections::{BTreeMap, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
use display_map::*;
pub use display_map::{DisplayPoint, FoldPlaceholder};
pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder};
use editor_settings::GoToDefinitionFallback;
pub use editor_settings::{
CurrentLineHighlight, EditorSettings, HideMouseMode, ScrollBeyondLastLine, SearchSettings,
@@ -15040,6 +15040,15 @@ impl Editor {
self.active_indent_guides_state.dirty = true;
}
pub fn update_fold_widths(
&mut self,
widths: impl IntoIterator<Item = (FoldId, Pixels)>,
cx: &mut Context<Self>,
) -> bool {
self.display_map
.update(cx, |map, cx| map.update_fold_widths(widths, cx))
}
pub fn default_fold_placeholder(&self, cx: &App) -> FoldPlaceholder {
self.display_map.read(cx).fold_placeholder.clone()
}

Some files were not shown because too many files have changed in this diff Show More