Compare commits

...

41 Commits

Author SHA1 Message Date
Joseph T. Lyons
2bfac793cd zed 0.184.5 2025-04-28 12:53:16 -04:00
gcp-cherry-pick-bot[bot]
a62a6c46a8 askpass: Shell escape Zed path in askpass script (cherry-pick #29447) (#29449)
Cherry-picked askpass: Shell escape Zed path in askpass script (#29447)

Closes #29439

Add shell escaping as well as additional sanity check for Zed path when
used in askpass. This caused issues on preview and nightly as the
standard paths for those releases contain spaces which were not escaped
appropriately leading to erroneous "Permission denied" errors from SSH
when the askpass script failed

Release Notes:

- Fixed a missing shell-escape in askpass resulting in erroneous
"Permission denied" errors when trying to connect to a remote server
over ssh (effecting preview release v0.184.1 and nightly only)

Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-04-28 11:20:08 -04:00
Danilo Leal
d68df740f7 agent: Simplify elements of the thread design (#29533)
Namely, this PR removes the layout shift when you click on a user
message to edit it and displays the feedback disclaimer only upon
hovering the thumbs up/down button container.

Release Notes:

- N/A
2025-04-28 11:02:02 -04:00
Danilo Leal
17e2746c16 agent: Bring title editing back to text threads (#29425)
This also fixes a little UI bug where the text thread title would push
the buttons away from the UI when there was still space.

Release Notes:

- agent: Made text thread titles editable again.

---------

Co-authored-by: Michael Sloan <mgsloan@gmail.com>
2025-04-28 10:29:10 -04:00
Oleksiy Syvokon
45ebaac2a4 agent tools: Make read_file.end_line inclusive (#29524)
One motivation is that the outlines returned by `read_file` for large
files list line numbers assuming an inclusive `end_line`. As a result,
when the agent uses these outlines for `read_line` calls, it would
otherwise miss the last line.

Release Notes:

- N/A
2025-04-28 10:29:01 -04:00
Antonio Scandurra
0f051dd36a Fix error when deserializing Gemini streams (#29470)
Sometimes Gemini would report `Content` without a `parts` field.

Release Notes:

- Fixed a bug that would sometimes cause Gemini models to fail streaming
their response.
2025-04-28 10:27:21 -04:00
Joseph T. Lyons
18c213709b zed 0.184.4 2025-04-25 18:14:55 -04:00
Smit Barmase
d92de5ee61 editor: Improve fuzzy match bucket logic for code completions (#29442)
Add new test and improve fuzzy match bucket logic which results into far
better balance between LSP and fuzzy search.

Before:
<img width="500" alt="before"
src="https://github.com/user-attachments/assets/3e8900a6-c0ff-4f37-b88e-b0e3783b7e9a"
/>

After:
<img width="500" alt="after"
src="https://github.com/user-attachments/assets/738c074c-d446-4697-aac6-9814362e88db"
/>

Release Notes:

- N/A
2025-04-25 18:13:30 -04:00
Smit Barmase
0078f673c3 editor: Add setting for snippet sorting behavior for code completion (#29429)
Added `snippet_sort_order`, which determines how snippets are sorted
relative to other completion items. It can have the values `top`,
`bottom`, or `inline`, with `inline` being the default.

This mimics VS Code’s setting:
https://code.visualstudio.com/docs/editing/intellisense#_snippets-in-suggestions

Release Notes:

- Added support for `snippet_sort_order` to control snippet sorting
behavior in code completion menus.
2025-04-25 18:13:17 -04:00
Danilo Leal
77f1609822 ui: Add inline_code method to label (#29306)
This makes it easy to have a label that looks like Markdown inline code
via the `inline_code(cx)` method.

Release Notes:

- N/A
2025-04-25 17:55:18 -04:00
Michael Sloan
93302ab78c Fix inclusion of message when counting tokens from message editor (#29443)
Accidentally omitted this in #29233

Release Notes:

- N/A
2025-04-25 17:38:58 -04:00
Danilo Leal
7c93c8e572 agent: Render path search results with ToolCard (#28894)
Implementing the `ToolCard` for the path_search tool. It also adds the
"jump to file" functionality if you expand the results.

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-25 17:38:42 -04:00
Danilo Leal
01d206626d agent: Make markdown code blocks uncollapsed by default (#29424)
Seeing the markdown code stream in is actually more helpful than hiding
parts of it; parts that you may be interested in.

Release Notes:

- N/A
2025-04-25 17:38:42 -04:00
Bennet Bo Fenner
dfea962e1d agent: Allow to explictly disable tools when using enable_all_context_servers (#29414)
Previously, all MCP tools would be completed regardless if they were
disabled/enabled for the profile. This meant that the "Write" profile
was always using all MCP tools, even if you disabled them in the
settings.

Now, when `enable_all_context_servers` is set to `true`, we will enable
all tools from all MCP servers by default but disable the ones that are
explicitly disabled for the profile.

Also fixes an issue where the tools would not show up as enabled when
using `enable_all_context_servers: true`

Release Notes:

- agent: Fix an issue where MCP tools could not be enabled/disabled
2025-04-25 13:30:26 -04:00
Zed Bot
0f3c56f8d9 Bump to 0.184.3 for @bennetbo 2025-04-25 13:21:10 +00:00
gcp-cherry-pick-bot[bot]
1cbdce4607 assistant: Fix issue when using inline assistant with Gemini models (cherry-pick #29407) (#29409)
Cherry-picked assistant: Fix issue when using inline assistant with
Gemini models (#29407)

Closes #29020

Release Notes:

- assistant: Fix issue when using inline assistant with Gemini models

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>
2025-04-25 14:47:17 +02:00
Marshall Bowers
745e7d60c4 language_models: Use POST /completions endpoint for Zed provider (#29389)
This PR updates the Zed provider to use the `POST /completions`
endpoint.

There is no functional difference from `POST /completion`, but the
pluralized version reads better.

Release Notes:

- N/A
2025-04-25 08:22:43 -04:00
Michael Sloan
19de45c638 Bring back reload of agent context before sending message (#29385)
Realized after merging #29233 that this behavior is desired

Release Notes:

- N/A
2025-04-25 08:22:17 -04:00
Max Brunsfeld
f8c3a01523 Remove unnecessary fields from the tool schemas (#29381)
This PR removes two fields from JSON schemas (`$schema` and `title`),
which are not expected by any model provider, but were spuriously
included by our JSON schema library, `schemars`.

These added noise to requests and cost wasted input tokens.

### Old

```json
{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "title": "FetchToolInput",
  "type": "object",
  "required": [
    "url"
  ],
  "properties": {
    "url": {
      "description": "The URL to fetch.",
      "type": "string"
    }
  }
}
```

### New:

```json
{
  "properties": {
    "url": {
      "description": "The URL to fetch.",
      "type": "string"
    }
  },
  "required": [
    "url"
  ],
  "type": "object"
}
```

- N/A
2025-04-25 08:22:09 -04:00
Michael Sloan
08cb7034af Restructure agent context (#29233)
Simplifies the data structures involved in agent context by removing
caching and limiting the use of ContextId:

* `AssistantContext` enum is now like an ID / handle to context that
does not need to be updated. `ContextId` still exists but is only used
for generating unique `ElementId`.
* `ContextStore` has a `IndexMap<ContextSetEntry>`. Only need to keep a
`HashSet<ThreadId>` consistent with it. `ContextSetEntry` is a newtype
wrapper around `AssistantContext` which implements eq / hash on a subset
of fields.
* Thread `Message` directly stores its context.

Fixes the following bugs:

* If a context entry is removed from the strip and added again, it was
reincluded in the next message.
* Clicking file context in the thread that has been removed from the
context strip didn't jump to the file.
* Refresh of directory context didn't reflect added / removed files.
* Deleted directories would remain in the message editor context strip.
* Token counting requests didn't include image context.
* File, directory, and symbol context deduplication relied on
`ProjectPath` for identity, and so didn't handle renames.
* Symbol context line numbers didn't update when shifted

Known bugs (not fixed):

* Deleting a directory causes it to disappear from messages in threads.
Fixing this in a nice way is tricky. One easy fix is to store the
original path and show that on deletion. It's weird that deletion would
cause the name to "revert", though. Another possibility would be to
snapshot context metadata on add (ala `AddedContext`), and keep that
around despite deletion.

Release Notes:

- N/A
2025-04-24 21:04:33 -04:00
Richard Feldman
43a04dc46f Treat invalid JSON in tool calls as failed tool calls (#29375)
Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-04-24 17:00:14 -04:00
Danilo Leal
00e92e37fd Rename "Prompt Library" to "Rules Library" (#29349)
There's probably more to do to fully make the transition, and we'll
still debate a bit internally whether this is the name, but just opening
this PR up now for visibility.

Release Notes:

- N/A
2025-04-24 17:00:14 -04:00
Bennet Bo Fenner
9e703accdb gemini: Fix issue when deserializing tool call (#29363)
Fixes a regression introduced in #29322

Release Notes:

- N/A

Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-24 14:26:27 -04:00
Nathan Sobo
44d5976849 Introduce LanguageModelToolUse::raw_input (#29322)
This is to enable alternative streaming solutions at the application
layer. I'm not sure we really should have performed parsing of the input
at this layer. Either way I want to experiment with streaming approaches
in a separate crate on a branch, and this will help.

/cc @maxdeviant @bennetbo @rtfeldman

Closes #ISSUE

Release Notes:

- N/A
2025-04-24 14:26:16 -04:00
Agus Zubiaga
44fc0045dd agent: Do not reuse assistant message across generations (#29360)
#29354 introduced a bug where we would append tool uses to the last
assistant message even if it was from a previous request.

Release Notes:

- N/A

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-24 14:25:19 -04:00
Joseph T. Lyons
c60b731f2f zed 0.184.2 2025-04-24 13:40:32 -04:00
Agus Zubiaga
b85bfbad67 agent: Do not create user messages for tool results in thread (#29354)
We used to insert empty user messages into the `Thread::messages` `Vec`
when tools finished running and then we would attach the results when
creating the request. This approach was very easy to mess up during
state handling, leading to empty user messages displayed in the
conversation and API failures.

Instead, we will no longer insert actual user messages for tool results
to the `Thread`, and will only do this on the fly when creating the
model request. This simplifies a lot of code and show fix the mentioned
errors.

Release Notes:

- agent: Improve reliability of LLM requests when including tool results

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
2025-04-24 12:46:52 -04:00
Danilo Leal
f743cdc41a agent: Simplify user message design more (#29326)
Follow-up to https://github.com/zed-industries/zed/pull/29165 where the
user message design is simplified even more. The edit button is not
visible anymore, and you can click on the whole message block to edit a
message.

Release Notes:

- N/A
2025-04-24 11:21:08 -04:00
Peter Tripp
808d9c9361 Fix ctrl-enter opening inline-assistant in assistant text threads (#29313)
Closes: https://github.com/zed-industries/zed/issues/24501

This has been broken for a while on linux (at least since Feb 8th!) for Assistant1.
It is also broken for Text Threads in Assitant2 (on macos and linux).

This should fix both.

Potentially related:
- https://github.com/zed-industries/zed/pull/29107

Release Notes:

- Fix for `ctrl-enter` shortcut in Assistant text threads incorrectly
opening inline assist instead of triggering Send.

Co-authored-by: Conrad Irwin <conrad@zed.dev>
2025-04-24 11:21:00 -04:00
Agus Zubiaga
a53e561d5e agent: Improve initial file search quality (#29317)
This PR significantly improves the quality of the initial file search
that occurs when the model doesn't yet know the full path to a file it
needs to read/edit.

Previously, the assertions in file_search often failed on main as the
model attempted to guess full file paths. On this branch, it reliably
calls `find_path` (previously `path_search`) before reading files.

After getting the model to find paths first, I noticed it would try
using `grep` instead of `path_search`. This motivated renaming
`path_search` to `find_path` (continuing the analogy to unix commands)
and adding system prompt instructions about proper tool selection.

Note: I know the command is just called `find`, but that seemed too
general.

In my eval runs, the `file_search` example improved from 40% ± 10% to
98% ± 2%. The only assertion I'm seeing occasionally fail is "glob
starts with `**` or project". We can probably add some instructions in
that regard.

Release Notes:

- N/A
2025-04-24 09:07:37 -04:00
Agus Zubiaga
80bac1ca92 eval: New add_arg_to_trait_method example (#29297)
Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
2025-04-24 09:07:21 -04:00
Agus Zubiaga
4171bbbe9e agent: Encourage model to include displayed fields first (#29308)
Instructs the model to include the fields that we display first in the
input object, so that e.g the user can see the path of a file while the
model generates the content.

Release Notes:

- N/A
2025-04-24 09:05:57 -04:00
Marshall Bowers
2a88867aa0 agent: Read the user's plan from the UserStore (#29305)
This PR updates the Agent panel to read the user's plan from the
`UserStore` instead of hard-coding it.

Release Notes:

- N/A
2025-04-24 09:05:41 -04:00
Oleksiy Syvokon
9d420ae8f1 agent: Add "copy to clipboard" button to error message popups (#29299)
This change makes agent errors copy-able to clipboard:


![image](https://github.com/user-attachments/assets/bd34a3f2-ecd4-4092-9b3b-960953ed1879)



Release Notes:

- N/A
2025-04-24 09:05:29 -04:00
Finn Evers
b431886188 editor: Fix broken mouse scrolling on main (#29307)
This PR is a quick follow-up to #29234 , which unfortunately broke
scrolling with the mouse in editors on main.

The linked PR introduced the possiblilty to completely disable scrolling
for editors. Unfortunately, it also disabled scrolling for editors by
default. This PR fixes this by re-enabling it by default.

This change also needs to be backported to v0.184.x. Otherwise, mouse
scrolling in the next preview release will not work!

Release Notes:

- N/A
2025-04-23 15:38:09 -07:00
Danilo Leal
43c99ad066 agent: Improve feedback text and buttons wrapping (#29302)
Just a little UI improvement here.

Release Notes:

- N/A
2025-04-23 16:32:43 -04:00
Danilo Leal
f802b56299 agent: Render diffs for the edit file tool (#29234)
This PR implements the `ToolCard` for the edit file tool, which allow us
to display an editor with a diff in the thread view with the changes
performed by the model.

- [x] Fix buffer sometimes displaying empty
- [x] Stop buffer from scrolling together with the thread
- [x] Fix multibuffer header sometimes appearing
- [x] Fix buffer height issue
- [x] Implement "full height" expand button
- [x] Add "Jump To File" functionality
- [x] Polish and refine styles

Release Notes:

- agent: Added diff preview cards in the thread view for edits performed
by the agent.

---------

Co-authored-by: João Marcos <marcospb19@hotmail.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-23 16:32:43 -04:00
Joseph T. Lyons
0c292b6cc6 zed 0.184.1 2025-04-23 14:31:23 -04:00
Richard Feldman
8446005112 More graceful invalid JSON handling (#29295)
Now we're more tolerant of invalid JSON coming back from the model
(possibly because it was incomplete and we're streaming), plus if we do
end up with invalid JSON once it has all streamed back, we report what
the malformed JSON actually was:

<img width="444" alt="Screenshot 2025-04-23 at 1 49 14 PM"
src="https://github.com/user-attachments/assets/480f5da7-869b-49f3-9ffd-8f08ccddb33d"
/>

Release Notes:

- N/A
2025-04-23 14:30:05 -04:00
Kirill Bulatov
5b028ac2d0 Fix relative paths not properly resolved in the terminal during cmd-click (#29289)
Closes https://github.com/zed-industries/zed/pull/28342
Closes https://github.com/zed-industries/zed/issues/28339
Fixes
https://github.com/zed-industries/zed/pull/29274#issuecomment-2824794396

Release Notes:

- Fixed relative paths not properly resolved in the terminal during
cmd-click
2025-04-23 19:37:15 +03:00
Joseph T. Lyons
57d0276180 v0.184.x preview 2025-04-23 11:57:41 -04:00
119 changed files with 5361 additions and 3533 deletions

485
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -109,7 +109,7 @@ members = [
"crates/project",
"crates/project_panel",
"crates/project_symbols",
"crates/prompt_library",
"crates/rules_library",
"crates/prompt_store",
"crates/proto",
"crates/recent_projects",
@@ -318,7 +318,7 @@ prettier = { path = "crates/prettier" }
project = { path = "crates/project" }
project_panel = { path = "crates/project_panel" }
project_symbols = { path = "crates/project_symbols" }
prompt_library = { path = "crates/prompt_library" }
rules_library = { path = "crates/rules_library" }
prompt_store = { path = "crates/prompt_store" }
proto = { path = "crates/proto" }
recent_projects = { path = "crates/recent_projects" }
@@ -499,6 +499,7 @@ prost-types = "0.9"
pulldown-cmark = { version = "0.12.0", default-features = false }
quote = "1.0.9"
rand = "0.8.5"
ref-cast = "1.0.24"
rayon = "1.8"
regex = "1.5"
repair_json = "0.1.0"

View File

@@ -212,7 +212,7 @@
"ctrl-shift-g": "search::SelectPreviousMatch",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-k h": "assistant::DeployHistory",
"ctrl-k l": "assistant::OpenPromptLibrary",
"ctrl-k l": "assistant::OpenRulesLibrary",
"new": "assistant::NewChat",
"ctrl-t": "assistant::NewChat",
"ctrl-n": "assistant::NewChat"
@@ -241,7 +241,7 @@
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "assistant::OpenPromptLibrary",
"ctrl-alt-p": "assistant::OpenRulesLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
@@ -308,9 +308,9 @@
{
"context": "PromptLibrary",
"bindings": {
"new": "prompt_library::NewPrompt",
"ctrl-n": "prompt_library::NewPrompt",
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
"new": "rules_library::NewRule",
"ctrl-n": "rules_library::NewRule",
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
}
},
{
@@ -675,7 +675,7 @@
}
},
{
"context": "Editor && mode == full",
"context": "!ContextEditor > Editor && mode == full",
"bindings": {
"alt-enter": "editor::OpenExcerpts",
"shift-enter": "editor::ExpandExcerpts",

View File

@@ -5,8 +5,8 @@
"context": "PromptLibrary",
"use_key_equivalents": true,
"bindings": {
"cmd-n": "prompt_library::NewPrompt",
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
"cmd-n": "rules_library::NewRule",
"cmd-shift-s": "rules_library::ToggleDefaultRule",
"cmd-w": "workspace::CloseWindow"
}
},
@@ -257,7 +257,7 @@
"cmd-shift-g": "search::SelectPreviousMatch",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-k h": "assistant::DeployHistory",
"cmd-k l": "assistant::OpenPromptLibrary",
"cmd-k l": "assistant::OpenRulesLibrary",
"cmd-t": "assistant::NewChat",
"cmd-n": "assistant::NewChat"
}
@@ -286,7 +286,7 @@
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "assistant::OpenPromptLibrary",
"cmd-alt-p": "assistant::OpenRulesLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
@@ -738,7 +738,7 @@
}
},
{
"context": "Editor && mode == full",
"context": "!ContextEditor > Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"alt-enter": "editor::OpenExcerpts",

View File

@@ -31,6 +31,9 @@ If appropriate, use tool calls to explore the current project, which contains th
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- Bias towards not asking the user for help if you can find the answer yourself.
{{! TODO: Only mention tools if they are enabled }}
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
## Fixing Diagnostics

View File

@@ -167,7 +167,23 @@
// Default: not set, defaults to "bar"
"cursor_shape": null,
// Determines when the mouse cursor should be hidden in an editor or input box.
//
// 1. Never hide the mouse cursor:
// "never"
// 2. Hide only when typing:
// "on_typing"
// 3. Hide on both typing and cursor movement:
// "on_typing_and_movement"
"hide_mouse": "on_typing_and_movement",
// Determines how snippets are sorted relative to other completion items.
//
// 1. Place snippets at the top of the completion list:
// "top"
// 2. Place snippets normally without any preference:
// "inline"
// 3. Place snippets at the bottom of the completion list:
// "bottom"
"snippet_sort_order": "inline",
// How to highlight the current line in the editor.
//
// 1. Don't highlight the current line:
@@ -646,7 +662,7 @@
"fetch": true,
"list_directory": false,
"now": true,
"path_search": true,
"find_path": true,
"read_file": true,
"grep": true,
"thinking": true,
@@ -670,7 +686,7 @@
"list_directory": true,
"move_path": false,
"now": false,
"path_search": true,
"find_path": true,
"read_file": true,
"grep": true,
"rename": false,

View File

@@ -1,2 +1,7 @@
allow-private-module-inception = true
avoid-breaking-exported-api = false
ignore-interior-mutability = [
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
# and Hash impls do not use fields with interior mutability.
"agent::context::AgentContextKey"
]

View File

@@ -28,7 +28,6 @@ async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -62,9 +61,10 @@ parking_lot.workspace = true
paths.workspace = true
picker.workspace = true
project.workspace = true
prompt_library.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
ref-cast.workspace = true
release_channel.workspace = true
rope.workspace = true
schemars.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -962,11 +962,13 @@ mod tests {
})
.unwrap();
let prompt_store = None;
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View File

@@ -39,6 +39,7 @@ use thread::ThreadId;
pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;

View File

@@ -271,7 +271,7 @@ impl PickerDelegate for ToolPickerDelegate {
.get(id.as_ref())
.and_then(|preset| preset.tools.get(&tool.name))
.copied()
.unwrap_or(false),
.unwrap_or(self.profile.enable_all_context_servers),
};
Some(

View File

@@ -5,28 +5,29 @@ use std::time::Duration;
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider,
humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens,
AssistantContext, AssistantPanelDelegate, ConfigurationError, ContextEditor, ContextEvent,
SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate,
render_remaining_tokens,
};
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::zed_urls;
use client::{UserStore, zed_urls};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task,
UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
Corner, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels,
Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
};
use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library};
use settings::{Settings, update_settings_file};
use time::UtcOffset;
use ui::{
@@ -36,7 +37,7 @@ use util::ResultExt as _;
use workspace::Workspace;
use workspace::dock::{DockPosition, Panel, PanelEvent};
use zed_actions::agent::OpenConfiguration;
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
@@ -79,11 +80,11 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
}
})
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_rules_library(action, window, cx)
});
}
})
@@ -116,6 +117,8 @@ enum ActiveView {
},
PromptEditor {
context_editor: Entity<ContextEditor>,
title_editor: Entity<Editor>,
_subscriptions: Vec<gpui::Subscription>,
},
History,
Configuration,
@@ -176,10 +179,88 @@ impl ActiveView {
_subscriptions: subscriptions,
}
}
pub fn prompt_editor(
context_editor: Entity<ContextEditor>,
window: &mut Window,
cx: &mut App,
) -> Self {
let title = context_editor.read(cx).title(cx).to_string();
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_text(title, window, cx);
editor
});
// This is a workaround for `editor.set_text` emitting a `BufferEdited` event, which would
// cause a custom summary to be set. The presence of this custom summary would cause
// summarization to not happen.
let mut suppress_first_edit = true;
let subscriptions = vec![
window.subscribe(&editor, cx, {
{
let context_editor = context_editor.clone();
move |editor, event, window, cx| match event {
EditorEvent::BufferEdited => {
if suppress_first_edit {
suppress_first_edit = false;
return;
}
let new_summary = editor.read(cx).text(cx);
context_editor.update(cx, |context_editor, cx| {
context_editor
.context()
.update(cx, |assistant_context, cx| {
assistant_context.set_custom_summary(new_summary, cx);
})
})
}
EditorEvent::Blurred => {
if editor.read(cx).text(cx).is_empty() {
let summary = context_editor
.read(cx)
.context()
.read(cx)
.summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
});
}
}
_ => {}
}
}
}),
window.subscribe(&context_editor.read(cx).context().clone(), cx, {
let editor = editor.clone();
move |assistant_context, event, window, cx| match event {
ContextEvent::SummaryGenerated => {
let summary = assistant_context.read(cx).summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
})
}
_ => {}
}
}),
];
Self::PromptEditor {
context_editor,
title_editor: editor,
_subscriptions: subscriptions,
}
}
}
pub struct AssistantPanel {
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
project: Entity<Project>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
@@ -188,6 +269,7 @@ pub struct AssistantPanel {
message_editor: Entity<MessageEditor>,
_active_thread_subscriptions: Vec<Subscription>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
configuration: Option<Entity<AssistantConfiguration>>,
configuration_subscription: Option<Subscription>,
local_timezone: UtcOffset,
@@ -204,14 +286,25 @@ impl AssistantPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
prompt_builder: Arc<PromptBuilder>,
cx: AsyncWindowContext,
mut cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
cx.spawn(async move |cx| {
let prompt_store = match prompt_store {
Ok(prompt_store) => prompt_store.await.ok(),
Err(_) => None,
};
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
ThreadStore::load(
project,
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
cx,
)
})?
.await?;
@@ -229,7 +322,16 @@ impl AssistantPanel {
.await?;
workspace.update_in(cx, |workspace, window, cx| {
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
cx.new(|cx| {
Self::new(
workspace,
thread_store,
context_store,
prompt_store,
window,
cx,
)
})
})
})
}
@@ -238,11 +340,13 @@ impl AssistantPanel {
workspace: &Workspace,
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let project = workspace.project();
let language_registry = project.read(cx).languages().clone();
let workspace = workspace.weak_handle();
@@ -260,6 +364,7 @@ impl AssistantPanel {
fs.clone(),
workspace.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
thread.clone(),
window,
@@ -291,7 +396,6 @@ impl AssistantPanel {
thread.clone(),
thread_store.clone(),
language_registry.clone(),
message_editor_context_store.clone(),
workspace.clone(),
window,
cx,
@@ -307,6 +411,7 @@ impl AssistantPanel {
Self {
active_view,
workspace,
user_store,
project: project.clone(),
fs: fs.clone(),
language_registry,
@@ -319,6 +424,7 @@ impl AssistantPanel {
message_editor_subscription,
],
context_store,
prompt_store,
configuration: None,
configuration_subscription: None,
local_timezone: UtcOffset::from_whole_seconds(
@@ -352,18 +458,17 @@ impl AssistantPanel {
self.local_timezone
}
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
&self.thread_store
}
fn cancel(
&mut self,
_: &editor::actions::Cancel,
_window: &mut Window,
cx: &mut Context<Self>,
) {
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
}
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
@@ -413,7 +518,6 @@ impl AssistantPanel {
thread.clone(),
self.thread_store.clone(),
self.language_registry.clone(),
message_editor_context_store.clone(),
self.workspace.clone(),
window,
cx,
@@ -432,6 +536,7 @@ impl AssistantPanel {
self.fs.clone(),
self.workspace.clone(),
message_editor_context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
window,
@@ -477,22 +582,20 @@ impl AssistantPanel {
});
self.set_active_view(
ActiveView::PromptEditor {
context_editor: context_editor.clone(),
},
ActiveView::prompt_editor(context_editor.clone(), window, cx),
window,
cx,
);
context_editor.focus_handle(cx).focus(window);
}
fn deploy_prompt_library(
fn deploy_rules_library(
&mut self,
action: &OpenPromptLibrary,
action: &OpenRulesLibrary,
_window: &mut Window,
cx: &mut Context<Self>,
) {
open_prompt_library(
open_rules_library(
self.language_registry.clone(),
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
Arc::new(|| {
@@ -502,9 +605,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -553,10 +656,9 @@ impl AssistantPanel {
cx,
)
});
this.set_active_view(
ActiveView::PromptEditor {
context_editor: editor,
},
ActiveView::prompt_editor(editor.clone(), window, cx),
window,
cx,
);
@@ -600,7 +702,6 @@ impl AssistantPanel {
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
message_editor_context_store.clone(),
this.workspace.clone(),
window,
cx,
@@ -619,6 +720,7 @@ impl AssistantPanel {
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.prompt_store.clone(),
this.thread_store.downgrade(),
thread,
window,
@@ -796,7 +898,7 @@ impl AssistantPanel {
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
match &self.active_view {
ActiveView::PromptEditor { context_editor } => Some(context_editor.clone()),
ActiveView::PromptEditor { context_editor, .. } => Some(context_editor.clone()),
_ => None,
}
}
@@ -839,7 +941,7 @@ impl Focusable for AssistantPanel {
match &self.active_view {
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::PromptEditor { context_editor } => context_editor.focus_handle(cx),
ActiveView::PromptEditor { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
configuration.focus_handle(cx)
@@ -963,9 +1065,34 @@ impl AssistantPanel {
.into_any_element()
}
}
ActiveView::PromptEditor { context_editor } => {
let title = SharedString::from(context_editor.read(cx).title(cx).to_string());
Label::new(title).ml_2().truncate().into_any_element()
ActiveView::PromptEditor {
title_editor,
context_editor,
..
} => {
let context_editor = context_editor.read(cx);
let summary = context_editor.context().read(cx).summary();
match summary {
None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone())
.truncate()
.ml_2()
.into_any_element(),
Some(summary) => {
if summary.done {
div()
.ml_2()
.w_full()
.child(title_editor.clone())
.into_any_element()
} else {
Label::new(LOADING_SUMMARY_PLACEHOLDER)
.ml_2()
.truncate()
.into_any_element()
}
}
}
}
ActiveView::History => Label::new("History").truncate().into_any_element(),
ActiveView::Configuration => Label::new("Settings").truncate().into_any_element(),
@@ -1122,7 +1249,7 @@ impl AssistantPanel {
"New Text Thread",
NewTextThread.boxed_clone(),
)
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.separator()
.header("MCPs")
@@ -1238,7 +1365,7 @@ impl AssistantPanel {
Some(token_count)
}
ActiveView::PromptEditor { context_editor } => {
ActiveView::PromptEditor { context_editor, .. } => {
let element = render_remaining_tokens(context_editor, cx)?;
Some(element.into_any_element())
@@ -1548,9 +1675,19 @@ impl AssistantPanel {
}
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let plan = self
.user_store
.read(cx)
.current_plan()
.map(|plan| match plan {
Plan::Free => zed_llm_client::Plan::Free,
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
})
.unwrap_or(zed_llm_client::Plan::Free);
let usage = self.thread.read(cx).last_usage()?;
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage).into_any_element())
Some(UsageBanner::new(plan, usage).into_any_element())
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -1605,6 +1742,8 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
@@ -1651,6 +1790,8 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, _, cx| {
@@ -1716,6 +1857,8 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(error_message))
.child(
Button::new("subscribe", call_to_action).on_click(cx.listener(
|this, _, _, cx| {
@@ -1747,6 +1890,7 @@ impl AssistantPanel {
message: SharedString,
cx: &mut Context<Self>,
) -> AnyElement {
let message_with_header = format!("{}\n{}", header, message);
v_flex()
.gap_0p5()
.child(
@@ -1761,12 +1905,14 @@ impl AssistantPanel {
.id("error-message")
.max_h_32()
.overflow_y_scroll()
.child(Label::new(message)),
.child(Label::new(message.clone())),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(message_with_header))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
@@ -1780,6 +1926,15 @@ impl AssistantPanel {
.into_any()
}
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
let message = message.into();
IconButton::new("copy", IconName::Copy)
.on_click(move |_, _, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
})
.tooltip(Tooltip::text("Copy Error Message"))
}
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");
@@ -1807,7 +1962,7 @@ impl Render for AssistantPanel {
this.open_configuration(window, cx);
}))
.on_action(cx.listener(Self::open_active_thread_as_markdown))
.on_action(cx.listener(Self::deploy_prompt_library))
.on_action(cx.listener(Self::deploy_rules_library))
.on_action(cx.listener(Self::open_agent_diff))
.on_action(cx.listener(Self::go_back))
.child(self.render_toolbar(window, cx))
@@ -1818,7 +1973,9 @@ impl Render for AssistantPanel {
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::PromptEditor { context_editor } => parent.child(context_editor.clone()),
ActiveView::PromptEditor { context_editor, .. } => {
parent.child(context_editor.clone())
}
ActiveView::Configuration => parent.children(self.configuration.clone()),
})
}
@@ -1834,13 +1991,13 @@ impl PromptLibraryInlineAssist {
}
}
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
fn assist(
&self,
prompt_editor: &Entity<Editor>,
_initial_prompt: Option<String>,
window: &mut Window,
cx: &mut Context<PromptLibrary>,
cx: &mut Context<RulesLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
let Some(project) = self
@@ -1850,11 +2007,14 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
else {
return;
};
let prompt_store = None;
let thread_store = None;
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
None,
prompt_store,
thread_store,
window,
cx,
)
@@ -1933,8 +2093,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
// being updated.
cx.defer_in(window, move |panel, window, cx| {
if panel.has_active_thread() {
panel.thread.update(cx, |thread, cx| {
thread.context_store().update(cx, |store, cx| {
panel.message_editor.update(cx, |message_editor, cx| {
message_editor.context_store().update(cx, |store, cx| {
let buffer = buffer.read(cx);
let selection_ranges = selection_ranges
.into_iter()
@@ -1951,9 +2111,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
store.add_selection(buffer, range, cx);
}
})
})

View File

@@ -1,12 +1,14 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::context::ContextLoadResult;
use crate::inline_prompt_editor::CodegenStatus;
use crate::{context::load_context, context_store::ContextStore};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
};
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
@@ -14,7 +16,9 @@ use language_model::{
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use rope::Rope;
use smol::future::FutureExt;
use std::{
@@ -39,6 +43,8 @@ pub struct BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
pub is_insertion: bool,
@@ -50,6 +56,8 @@ impl BufferCodegen {
range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
context_store: Entity<ContextStore>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Arc<Telemetry>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -60,6 +68,8 @@ impl BufferCodegen {
range.clone(),
false,
Some(context_store.clone()),
project.clone(),
prompt_store.clone(),
Some(telemetry.clone()),
builder.clone(),
cx,
@@ -75,6 +85,8 @@ impl BufferCodegen {
range,
initial_transaction_id,
context_store,
project,
prompt_store,
telemetry,
builder,
};
@@ -153,6 +165,8 @@ impl BufferCodegen {
self.range.clone(),
false,
Some(self.context_store.clone()),
self.project.clone(),
self.prompt_store.clone(),
Some(self.telemetry.clone()),
self.builder.clone(),
cx,
@@ -229,13 +243,14 @@ pub struct CodegenAlternative {
generation: Task<()>,
diff: Diff,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
active: bool,
edits: Vec<(Range<Anchor>, String)>,
line_operations: Vec<LineOperation>,
request: Option<LanguageModelRequest>,
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
@@ -249,6 +264,8 @@ impl CodegenAlternative {
range: Range<Anchor>,
active: bool,
context_store: Option<Entity<ContextStore>>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
@@ -290,6 +307,8 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
context_store,
project,
prompt_store,
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
@@ -297,7 +316,6 @@ impl CodegenAlternative {
edits: Vec::new(),
line_operations: Vec::new(),
range,
request: None,
elapsed_time: None,
completion: None,
}
@@ -366,16 +384,18 @@ impl CodegenAlternative {
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, cx)?;
self.request = Some(request.clone());
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
Ok(())
}
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
fn build_request(
&self,
user_prompt: String,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -408,30 +428,44 @@ impl CodegenAlternative {
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
let context_task = self.context_store.as_ref().map(|context_store| {
if let Some(project) = self.project.upgrade() {
let context = context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
load_context(context, &project, &self.prompt_store, cx)
} else {
Task::ready(ContextLoadResult::default())
}
});
if let Some(context_store) = &self.context_store {
attach_context_to_message(
&mut request_message,
context_store.read(cx).context().iter(),
cx,
);
}
Ok(cx.spawn(async move |_cx| {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: Vec::new(),
cache: false,
};
request_message.content.push(prompt.into());
if let Some(context_task) = context_task {
context_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
}
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
messages: vec![request_message],
}
}))
}
pub fn handle_stream(
@@ -508,7 +542,9 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks = StripInvalidSpans::new(
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -1034,6 +1070,7 @@ impl Diff {
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use futures::{
Stream,
stream::{self},
@@ -1076,12 +1113,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1140,12 +1181,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1207,12 +1252,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1274,12 +1323,16 @@ mod tests {
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
true,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,
@@ -1329,12 +1382,16 @@ mod tests {
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, vec![], cx).await;
let codegen = cx.new(|cx| {
CodegenAlternative::new(
buffer.clone(),
range.clone(),
false,
None,
project.downgrade(),
None,
None,
prompt_builder,
cx,

File diff suppressed because it is too large Load Diff

View File

@@ -10,8 +10,11 @@ use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
use fetch_context_picker::FetchContextPicker;
use file_context_picker::FileContextPicker;
use file_context_picker::render_file_context_entry;
use gpui::{
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
@@ -20,10 +23,10 @@ use gpui::{
use language::Buffer;
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use prompt_store::{PromptStore, UserPromptId};
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
@@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
@@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -193,6 +192,13 @@ impl ContextPicker {
)
.collect::<Vec<Subscription>>();
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
ContextPicker {
mode: ContextPickerState::Default(ContextMenu::build(
window,
@@ -202,6 +208,7 @@ impl ContextPicker {
workspace,
context_store,
thread_store,
prompt_store,
_subscriptions: subscriptions,
}
}
@@ -226,7 +233,12 @@ impl ContextPicker {
.workspace
.upgrade()
.map(|workspace| {
available_context_picker_entries(&self.thread_store, &workspace, cx)
available_context_picker_entries(
&self.prompt_store,
&self.thread_store,
&workspace,
cx,
)
})
.unwrap_or_default();
@@ -304,10 +316,10 @@ impl ContextPicker {
}));
}
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
if let Some(prompt_store) = self.prompt_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
prompt_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
@@ -526,6 +538,7 @@ enum RecentEntry {
}
fn available_context_picker_entries(
prompt_store: &Option<Entity<PromptStore>>,
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
cx: &mut App,
@@ -550,6 +563,9 @@ fn available_context_picker_entries(
if thread_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
}
if prompt_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
}
@@ -585,22 +601,21 @@ fn recent_context_picker_entries(
}),
);
let mut current_threads = context_store.read(cx).thread_ids();
let current_threads = context_store.read(cx).thread_ids();
if let Some(active_thread) = workspace
let active_thread_id = workspace
.panel::<AssistantPanel>(cx)
.map(|panel| panel.read(cx).active_thread(cx))
{
current_threads.insert(active_thread.read(cx).id().clone());
}
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
recent.extend(
thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.filter(|thread| !current_threads.contains(&thread.id))
.filter(|thread| {
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
})
.take(2)
.map(|thread| {
RecentEntry::Thread(ThreadContextEntry {
@@ -622,9 +637,7 @@ fn add_selections_as_context(
let selection_ranges = selection_ranges(workspace, cx);
context_store.update(cx, |context_store, cx| {
for (buffer, range) in selection_ranges {
context_store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
context_store.add_selection(buffer, range, cx);
}
})
}

View File

@@ -15,22 +15,21 @@ use itertools::Itertools;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use prompt_store::PromptStore;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch;
use super::file_context_picker::{FileMatch, search_files};
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
use super::symbol_context_picker::search_symbols;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
@@ -38,8 +37,8 @@ use super::{
};
pub(crate) enum Match {
Symbol(SymbolMatch),
File(FileMatch),
Symbol(SymbolMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
@@ -69,6 +68,7 @@ fn search(
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
@@ -85,6 +85,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
@@ -96,6 +97,7 @@ fn search(
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
@@ -111,6 +113,7 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
@@ -118,10 +121,11 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
if let Some(prompt_store) = prompt_store.as_ref() {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
@@ -133,6 +137,7 @@ fn search(
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
@@ -163,7 +168,7 @@ fn search(
.collect::<Vec<_>>();
matches.extend(
available_context_picker_entries(&thread_store, &workspace, cx)
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
@@ -180,7 +185,8 @@ fn search(
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
let entries =
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
let entry_candidates = entries
.iter()
.enumerate()
@@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
move |_, _: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
context_store
.add_selection(buffer.clone(), range.clone(), cx)
.detach_and_log_err(cx)
context_store.add_selection(
buffer.clone(),
range.clone(),
cx,
);
}
});
@@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
@@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
let user_prompt_id = rules.prompt_id;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(user_prompt_id, false, cx);
});
},
)),
}
@@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
let url_to_fetch = url_to_fetch.clone();
cx.spawn(async move |cx| {
if context_store.update(cx, |context_store, _| {
context_store.includes_url(&url_to_fetch).is_some()
context_store.includes_url(&url_to_fetch)
})? {
return Ok(());
}
@@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
move |cx| {
context_store.update(cx, |context_store, cx| {
let task = if is_directory {
context_store.add_directory(project_path.clone(), false, cx)
Task::ready(context_store.add_directory(&project_path, false, cx))
} else {
context_store.add_file_from_path(project_path.clone(), false, cx)
};
@@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
);
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
thread_store
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
.ok()
.flatten()
});
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
prompt_store,
thread_store.clone(),
workspace.clone(),
cx,
@@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
cx,
))
}
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
@@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
workspace.clone(),
cx,
),
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
@@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Rules(user_rules) => Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
)),
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
@@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,

View File

@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let added = self.context_store.upgrade().map_or(false, |context_store| {
context_store.read(cx).includes_url(&self.url).is_some()
context_store.read(cx).includes_url(&self.url)
});
Some(

View File

@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
.context_store
.update(cx, |context_store, cx| {
if is_directory {
context_store.add_directory(project_path, true, cx)
Task::ready(context_store.add_directory(&project_path, true, cx))
} else {
context_store.add_file_from_path(project_path, true, cx)
context_store.add_file_from_path(project_path.clone(), true, cx)
}
})
.ok()
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
path: path.clone(),
};
if is_directory {
context_store.read(cx).includes_directory(&project_path)
} else {
context_store
.read(cx)
.will_include_file_path(&project_path, cx)
.path_included_in_directory(&project_path, cx)
} else {
context_store.read(cx).file_path_included(&project_path, cx)
}
});
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
})),
)
.when_some(added, |el, added| match added {
FileInclusion::Direct(_) => el.child(
FileInclusion::Direct => el.child(
h_flex()
.w_full()
.justify_end()
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
FileInclusion::InDirectory(directory_project_path) => {
// TODO: Consider using worktree full_path to include worktree name.
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
FileInclusion::InDirectory { full_path } => {
let directory_full_path = full_path.to_string_lossy().into_owned();
el.child(
h_flex()
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
.tooltip(Tooltip::text(format!("in {directory_path}")))
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
}
})
}

View File

@@ -1,16 +1,15 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use prompt_store::{PromptId, PromptStore, UserPromptId};
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
@@ -18,13 +17,13 @@ pub struct RulesContextPicker {
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
@@ -50,7 +49,7 @@ pub struct RulesContextEntry {
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
@@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
prompt_store: Entity<PromptStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
prompt_store,
context_picker,
context_store,
matches: Vec::new(),
@@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
let search_task = search_rules(
query,
Arc::new(AtomicBool::default()),
&self.prompt_store,
cx,
);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
self.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(entry.prompt_id, true, cx)
})
})
.detach_and_log_err(cx);
.log_err();
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
@@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
let added = context_store.upgrade().map_or(false, |context_store| {
context_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
.includes_user_rules(user_rules.prompt_id)
});
h_flex()
@@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
prompt_store: &Entity<PromptStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task

View File

@@ -10,7 +10,6 @@ use gpui::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use project::{DocumentSymbol, Symbol};
use text::OffsetRangeExt;
use ui::{ListItem, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
@@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
)
})?;
context_store
.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})?
.await
context_store.update(cx, move |context_store, cx| {
context_store.add_symbol(
buffer,
name.into(),
range,
enclosing_range,
remove_if_exists,
cx,
)
})
})
}
@@ -353,38 +350,13 @@ fn compute_symbol_entries(
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
let mut symbol_entries = Vec::with_capacity(symbols.len());
for SymbolMatch { symbol, .. } in symbols {
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
let is_included = if let Some(symbols_for_path) = symbols_for_path {
let mut is_included = false;
for included_symbol_id in symbols_for_path {
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
let snapshot = buffer.read(cx).snapshot();
let included_symbol_range =
included_symbol_id.range.to_point_utf16(&snapshot);
if included_symbol_range.start == symbol.range.start.0
&& included_symbol_range.end == symbol.range.end.0
{
is_included = true;
break;
}
}
}
}
is_included
} else {
false
};
symbol_entries.push(SymbolEntry {
symbols
.into_iter()
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
is_included: context_store.includes_symbol(&symbol, cx),
symbol,
is_included,
})
}
symbol_entries
.collect::<Vec<_>>()
}
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {

View File

@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store.read(cx).includes_thread(&thread.id).is_some()
ctx_store.read(cx).includes_thread(&thread.id)
});
h_flex()
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store
.read(cx)
.threads()
.reverse_chronological_threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,

File diff suppressed because it is too large Load Diff

View File

@@ -12,9 +12,9 @@ use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
use workspace::Workspace;
use crate::context::{ContextId, ContextKind};
use crate::context::{AgentContext, ContextKind};
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
use crate::thread::Thread;
@@ -32,6 +32,7 @@ pub struct ContextStrip {
focus_handle: FocusHandle,
suggest_context_kind: SuggestContextKind,
workspace: WeakEntity<Workspace>,
thread_store: Option<WeakEntity<ThreadStore>>,
_subscriptions: Vec<Subscription>,
focused_index: Option<usize>,
children_bounds: Option<Vec<Bounds<Pixels>>>,
@@ -73,12 +74,31 @@ impl ContextStrip {
focus_handle,
suggest_context_kind,
workspace,
thread_store,
_subscriptions: subscriptions,
focused_index: None,
children_bounds: None,
}
}
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
if let Some(workspace) = self.workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
let prompt_store = self
.thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
self.context_store
.read(cx)
.context()
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
.collect::<Vec<_>>()
} else {
Vec::new()
}
}
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
match self.suggest_context_kind {
SuggestContextKind::File => self.suggested_file(cx),
@@ -93,22 +113,19 @@ impl ContextStrip {
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), &project_path)
.file_path_included(&project_path, cx)
.is_some()
{
return None;
}
let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
Some(SuggestedContext::File {
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
@@ -135,7 +152,6 @@ impl ContextStrip {
.context_store
.read(cx)
.includes_thread(active_thread.id())
.is_some()
{
return None;
}
@@ -272,12 +288,12 @@ impl ContextStrip {
best.map(|(index, _, _)| index)
}
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
crate::active_thread::open_context(context, workspace, window, cx);
}
fn remove_focused_context(
@@ -287,17 +303,17 @@ impl ContextStrip {
cx: &mut Context<Self>,
) {
if let Some(index) = self.focused_index {
let mut is_empty = false;
let added_contexts = self.added_contexts(cx);
let Some(context) = added_contexts.get(index) else {
return;
};
self.context_store.update(cx, |this, cx| {
if let Some(item) = this.context().get(index) {
this.remove_context(item.id(), cx);
}
is_empty = this.context().is_empty();
this.remove_context(&context.context, cx);
});
if is_empty {
let is_now_empty = added_contexts.len() == 1;
if is_now_empty {
cx.emit(ContextStripEvent::BlurredEmpty);
} else {
self.focused_index = Some(index.saturating_sub(1));
@@ -306,49 +322,28 @@ impl ContextStrip {
}
}
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
// We only suggest one item after the actual context
self.focused_index == Some(context.len())
self.focused_index == Some(added_contexts.len())
}
fn accept_suggested_context(
&mut self,
_: &AcceptSuggestedContext,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(suggested) = self.suggested_context(cx) {
let context_store = self.context_store.read(cx);
if self.is_suggested_focused(context_store.context()) {
self.add_suggested_context(&suggested, window, cx);
if self.is_suggested_focused(&self.added_contexts(cx)) {
self.add_suggested_context(&suggested, cx);
}
}
}
fn add_suggested_context(
&mut self,
suggested: &SuggestedContext,
window: &mut Window,
cx: &mut Context<Self>,
) {
let task = self.context_store.update(cx, |context_store, cx| {
context_store.accept_suggested_context(&suggested, cx)
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
self.context_store.update(cx, |context_store, cx| {
context_store.add_suggested_context(&suggested, cx)
});
cx.spawn_in(window, async move |this, cx| {
match task.await.notify_async_err(cx) {
None => {}
Some(()) => {
if let Some(this) = this.upgrade() {
this.update(cx, |_, cx| cx.notify())?;
}
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
cx.notify();
}
}
@@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
impl Render for ContextStrip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let context_store = self.context_store.read(cx);
let context = context_store.context();
let context_picker = self.context_picker.clone();
let focus_handle = self.focus_handle.clone();
let suggested_context = self.suggested_context(cx);
let added_contexts = context
.iter()
.map(|c| AddedContext::new(c, cx))
.collect::<Vec<_>>();
let added_contexts = self.added_contexts(cx);
let dupe_names = added_contexts
.iter()
.map(|c| c.name.clone())
@@ -380,6 +368,14 @@ impl Render for ContextStrip {
.filter(|(a, b)| a == b)
.map(|(a, _)| a)
.collect::<HashSet<SharedString>>();
let no_added_context = added_contexts.is_empty();
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
(
suggested_context,
self.is_suggested_focused(&added_contexts),
)
});
h_flex()
.flex_wrap()
@@ -436,7 +432,7 @@ impl Render for ContextStrip {
})
.with_handle(self.context_picker_menu_handle.clone()),
)
.when(context.is_empty() && suggested_context.is_none(), {
.when(no_added_context && suggested_context.is_none(), {
|parent| {
parent.child(
h_flex()
@@ -466,16 +462,17 @@ impl Render for ContextStrip {
.enumerate()
.map(|(i, added_context)| {
let name = added_context.name.clone();
let id = added_context.id;
let context = added_context.context.clone();
ContextPill::added(
added_context,
dupe_names.contains(&name),
self.focused_index == Some(i),
Some({
let context = context.clone();
let context_store = self.context_store.clone();
Rc::new(cx.listener(move |_this, _event, _window, cx| {
context_store.update(cx, |this, cx| {
this.remove_context(id, cx);
this.remove_context(&context, cx);
});
cx.notify();
}))
@@ -484,7 +481,7 @@ impl Render for ContextStrip {
.on_click({
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
if event.down.click_count > 1 {
this.open_context(id, window, cx);
this.open_context(&context, window, cx);
} else {
this.focused_index = Some(i);
}
@@ -493,22 +490,22 @@ impl Render for ContextStrip {
})
}),
)
.when_some(suggested_context, |el, suggested| {
.when_some(suggested_context, |el, (suggested, focused)| {
el.child(
ContextPill::suggested(
suggested.name().clone(),
suggested.icon_path(),
suggested.kind(),
self.is_suggested_focused(&context),
focused,
)
.on_click(Rc::new(cx.listener(
move |this, _event, window, cx| {
this.add_suggested_context(&suggested, window, cx);
move |this, _event, _window, cx| {
this.add_suggested_context(&suggested, cx);
},
))),
)
})
.when(!context.is_empty(), {
.when(!no_added_context, {
move |parent| {
parent.child(
IconButton::new("remove-all-context", IconName::Eraser)
@@ -534,6 +531,7 @@ impl Render for ContextStrip {
)
}
})
.into_any()
}
}

View File

@@ -51,7 +51,10 @@ impl HistoryStore {
return history_entries;
}
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
for thread in self
.thread_store
.update(cx, |this, _cx| this.reverse_chronological_threads())
{
history_entries.push(HistoryEntry::Thread(thread));
}

View File

@@ -32,6 +32,7 @@ use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
use prompt_store::PromptStore;
use settings::{Settings, SettingsStore};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
@@ -245,9 +246,13 @@ impl InlineAssistant {
.map_or(false, |model| model.provider.is_authenticated(cx))
};
let thread_store = workspace
let assistant_panel = workspace
.panel::<AssistantPanel>(cx)
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
.map(|assistant_panel| assistant_panel.read(cx));
let prompt_store = assistant_panel
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
let thread_store =
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
let handle_assist =
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
@@ -257,6 +262,7 @@ impl InlineAssistant {
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@@ -269,6 +275,7 @@ impl InlineAssistant {
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
prompt_store,
thread_store,
window,
cx,
@@ -323,6 +330,7 @@ impl InlineAssistant {
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -437,6 +445,8 @@ impl InlineAssistant {
range.clone(),
None,
context_store.clone(),
project.clone(),
prompt_store.clone(),
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -525,6 +535,7 @@ impl InlineAssistant {
initial_transaction_id: Option<TransactionId>,
focus: bool,
workspace: Entity<Workspace>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -543,7 +554,7 @@ impl InlineAssistant {
}
let project = workspace.read(cx).project().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
let codegen = cx.new(|cx| {
BufferCodegen::new(
@@ -551,6 +562,8 @@ impl InlineAssistant {
range.clone(),
initial_transaction_id,
context_store.clone(),
project,
prompt_store,
self.telemetry.clone(),
self.prompt_builder.clone(),
cx,
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
let editor = self.editor.clone();
let workspace = self.workspace.clone();
let thread_store = self.thread_store.clone();
let prompt_store = PromptStore::global(cx);
window.spawn(cx, async move |cx| {
let workspace = workspace.upgrade().context("workspace was released")?;
let editor = editor.upgrade().context("editor was released")?;
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
})?
.context("invalid range")?;
let prompt_store = prompt_store.await.ok();
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
let assist_id = assistant.suggest_assist(
&editor,
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
None,
true,
workspace,
prompt_store,
thread_store,
window,
cx,

View File

@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use crate::context::{AssistantContext, format_context_as_string};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
@@ -13,15 +13,20 @@ use editor::{
};
use file_icons::FileIcons;
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt as _, future};
use gpui::{
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@@ -31,7 +36,7 @@ use workspace::Workspace;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::{ContextStore, refresh_context_store_text};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::profile_selector::ProfileSelector;
use crate::thread::{Thread, TokenUsageRatio};
@@ -49,9 +54,12 @@ pub struct MessageEditor {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
last_loaded_context: Option<ContextLoadResult>,
context_load_task: Option<Shared<Task<()>>>,
profile_selector: Entity<ProfileSelector>,
edits_expanded: bool,
editor_is_expanded: bool,
@@ -68,6 +76,7 @@ impl MessageEditor {
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
thread: Entity<Thread>,
window: &mut Window,
@@ -135,13 +144,11 @@ impl MessageEditor {
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
cx.subscribe(&editor, |this, _, event, cx| match event {
EditorEvent::BufferEdited => {
this.message_or_context_changed(true, cx);
}
EditorEvent::BufferEdited => this.handle_message_changed(cx),
_ => {}
}),
cx.observe(&context_store, |this, _, cx| {
this.message_or_context_changed(false, cx);
let _ = this.start_context_load(cx);
}),
];
@@ -152,8 +159,11 @@ impl MessageEditor {
incompatible_tools_state: incompatible_tools.clone(),
workspace,
context_store,
prompt_store,
context_strip,
context_picker_menu_handle,
context_load_task: None,
last_loaded_context: None,
model_selector: cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
@@ -175,6 +185,10 @@ impl MessageEditor {
}
}
pub fn context_store(&self) -> &Entity<ContextStore> {
&self.context_store
}
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
cx.notify();
}
@@ -195,6 +209,7 @@ impl MessageEditor {
editor.set_mode(EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: false,
})
} else {
editor.set_mode(EditorMode::AutoHeight {
@@ -213,6 +228,7 @@ impl MessageEditor {
) {
self.context_picker_menu_handle.toggle(window, cx);
}
pub fn remove_all_context(
&mut self,
_: &RemoveAllContext,
@@ -269,56 +285,44 @@ impl MessageEditor {
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let context_task = self.load_context(cx);
let window_handle = window.window_handle();
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
wait_for_images.await;
cx.spawn(async move |_this, cx| {
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
let loaded_context = loaded_context.unwrap_or_default();
thread
.update(cx, |thread, cx| {
let context = context_store.read(cx).context().clone();
thread.insert_user_message(user_message, context, checkpoint, cx);
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
})
.log_err();
context_store
.update(cx, |context_store, cx| {
let excerpt_ids = context_store
.context()
.iter()
.filter(|ctx| {
matches!(
ctx,
AssistantContext::Selection(_) | AssistantContext::Image(_)
)
})
.map(|ctx| ctx.id())
.collect::<Vec<_>>();
for id in excerpt_ids {
context_store.remove_context(id, cx);
}
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model, Some(window_handle), cx);
})
.log_err();
})
.detach();
}
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
let context_store = self.context_store.clone();
cx.spawn(async move |this, cx| {
if let Some(wait_for_summaries) = context_store
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
.log_err()
.ok()
{
this.update(cx, |this, cx| {
this.waiting_for_summaries_to_send = true;
cx.notify();
})
.log_err();
.ok();
wait_for_summaries.await;
@@ -326,24 +330,15 @@ impl MessageEditor {
this.waiting_for_summaries_to_send = false;
cx.notify();
})
.log_err();
.ok();
}
// Send to model after summaries are done
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model, cx);
})
.log_err();
})
.detach();
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cancelled = self
.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
let cancelled = self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
});
if cancelled {
self.set_editor_is_expanded(false, cx);
@@ -1013,6 +1008,48 @@ impl MessageEditor {
self.update_token_count_task.is_some()
}
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
self.message_or_context_changed(true, cx);
}
fn start_context_load(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
let summaries_task = self.wait_for_summaries(cx);
let load_task = cx.spawn(async move |this, cx| {
// Waits for detailed summaries before `load_context`, as it directly reads these from
// the thread. TODO: Would be cleaner to have context loading await on summarization.
summaries_task.await;
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx))
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
return;
};
let result = load_task.await;
this.update(cx, |this, cx| {
this.last_loaded_context = Some(result);
this.context_load_task = None;
this.message_or_context_changed(false, cx);
})
.ok();
});
// Replace existing load task, if any, causing it to be cancelled.
let load_task = load_task.shared();
self.context_load_task = Some(load_task.clone());
load_task
}
fn load_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
let context_load_task = self.start_context_load(cx);
cx.spawn(async move |this, cx| {
context_load_task.await;
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
.ok()
.flatten()
})
}
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
@@ -1022,9 +1059,7 @@ impl MessageEditor {
return;
};
let context_store = self.context_store.clone();
let editor = self.editor.clone();
let thread = self.thread.clone();
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
if debounce {
@@ -1033,27 +1068,39 @@ impl MessageEditor {
.await;
}
let token_count = if let Some(task) = cx.update(|cx| {
let context = context_store.read(cx).context().iter();
let new_context = thread.read(cx).filter_new_context(context);
let context_text =
format_context_as_string(new_context, cx).unwrap_or(String::new());
let token_count = if let Some(task) = this.update(cx, |this, cx| {
let loaded_context = this
.last_loaded_context
.as_ref()
.map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
let content = context_text + &message_text;
if content.is_empty() {
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
{
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
if let Some(loaded_context) = loaded_context {
loaded_context.add_to_request_message(&mut request_message);
}
if !message_text.is_empty() {
request_message
.content
.push(MessageContent::Text(message_text));
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![LanguageModelRequestMessage {
role: language_model::Role::User,
content: vec![content.into()],
cache: false,
}],
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,

View File

@@ -32,7 +32,7 @@ impl TerminalCodegen {
}
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
@@ -45,6 +45,7 @@ impl TerminalCodegen {
self.status = CodegenStatus::Pending;
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
self.generation = cx.spawn(async move |this, cx| {
let prompt = prompt_task.await;
let model_telemetry_id = model.telemetry_id();
let model_provider_id = model.provider_id();
let response = model.stream_completion_text(prompt, &cx).await;

View File

@@ -1,4 +1,4 @@
use crate::context::attach_context_to_message;
use crate::context::load_context;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::{
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
use language::Buffer;
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
use prompt_store::PromptBuilder;
use prompt_store::{PromptBuilder, PromptStore};
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use terminal_view::TerminalView;
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
terminal_view: &Entity<TerminalView>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
prompt_editor,
workspace.clone(),
context_store,
prompt_store,
window,
cx,
);
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
.log_err();
let codegen = assist.codegen.clone();
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
return;
};
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
}
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
&self,
assist_id: TerminalInlineAssistId,
cx: &mut App,
) -> Result<LanguageModelRequest> {
) -> Result<Task<LanguageModelRequest>> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let shell = std::env::var("SHELL").ok();
@@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
&latest_output,
)?;
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
let contexts = assist
.context_store
.read(cx)
.context()
.cloned()
.collect::<Vec<_>>();
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
let project = workspace.project();
load_context(contexts, project, &assist.prompt_store, cx)
})?;
attach_context_to_message(
&mut request_message,
assist.context_store.read(cx).context().iter(),
cx,
);
Ok(cx.background_spawn(async move {
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
request_message.content.push(prompt.into());
context_load_task
.await
.loaded_context
.add_to_request_message(&mut request_message);
Ok(LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
})
request_message.content.push(prompt.into());
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: vec![request_message],
tools: Vec::new(),
stop: Vec::new(),
temperature: None,
}
}))
}
fn finish_assist(
@@ -380,6 +394,7 @@ struct TerminalInlineAssist {
codegen: Entity<TerminalCodegen>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
_subscriptions: Vec<Subscription>,
}
@@ -390,6 +405,7 @@ impl TerminalInlineAssist {
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
workspace: WeakEntity<Workspace>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
window: &mut Window,
cx: &mut App,
) -> Self {
@@ -400,6 +416,7 @@ impl TerminalInlineAssist {
codegen: codegen.clone(),
workspace: workspace.clone(),
context_store,
prompt_store,
_subscriptions: vec![
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
TerminalInlineAssistant::update_global(cx, |this, cx| {

File diff suppressed because it is too large Load Diff

View File

@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
UserRulesContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@@ -82,12 +82,11 @@ impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
let prompt_store = prompt_store.await.ok();
let (thread_store, ready_rx) = cx.update(|cx| {
let mut option_ready_rx = None;
let thread_store = cx.new(|cx| {
@@ -349,25 +348,8 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
&self.prompt_store
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
@@ -379,16 +361,12 @@ impl ThreadStore {
self.threads.len()
}
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
threads
}
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
self.threads().into_iter().take(limit).collect()
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
cx.new(|cx| {
Thread::new(
@@ -516,6 +494,22 @@ impl ThreadStore {
);
});
}
// Enable all the tools from all context servers, but disable the ones that are explicitly disabled
for (context_server_id, preset) in &profile.context_servers {
self.tools.update(cx, |tools, cx| {
tools.disable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.update(cx, |tools, cx| {
@@ -639,12 +633,17 @@ pub struct SerializedThread {
}
impl SerializedThread {
pub const VERSION: &'static str = "0.1.0";
pub const VERSION: &'static str = "0.2.0";
pub fn from_json(json: &[u8]) -> Result<Self> {
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
match saved_thread_json.get("version") {
Some(serde_json::Value::String(version)) => match version.as_str() {
SerializedThreadV0_1_0::VERSION => {
let saved_thread =
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
Ok(saved_thread.upgrade())
}
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
saved_thread_json,
)?),
@@ -666,6 +665,38 @@ impl SerializedThread {
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedThreadV0_1_0(
// The structure did not change, so we are reusing the latest SerializedThread.
// When making the next version, make sure this points to SerializedThreadV0_2_0
SerializedThread,
);
impl SerializedThreadV0_1_0 {
pub const VERSION: &'static str = "0.1.0";
pub fn upgrade(self) -> SerializedThread {
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
for message in self.0.messages {
if message.role == Role::User && !message.tool_results.is_empty() {
if let Some(last_message) = messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
last_message.tool_results = message.tool_results;
continue;
}
}
messages.push(message);
}
SerializedThread { messages, ..self.0 }
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
pub id: MessageId,

View File

@@ -30,7 +30,6 @@ pub struct ToolUse {
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
@@ -42,7 +41,6 @@ impl ToolUseState {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_uses_by_user_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
@@ -56,7 +54,6 @@ impl ToolUseState {
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
@@ -68,10 +65,10 @@ impl ToolUseState {
let tool_uses = message
.tool_uses
.iter()
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
@@ -85,14 +82,6 @@ impl ToolUseState {
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
}
}
Role::User => {
if !message.tool_results.is_empty() {
let tool_uses_by_user_message = this
.tool_uses_by_user_message
.entry(message.id)
.or_default();
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
@@ -101,11 +90,6 @@ impl ToolUseState {
continue;
};
if !(filter_by_tool_name)(tool_use.as_ref()) {
continue;
}
tool_uses_by_user_message.push(tool_use_id.clone());
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
@@ -118,7 +102,7 @@ impl ToolUseState {
}
}
}
Role::System => {}
Role::System | Role::User => {}
}
}
@@ -228,20 +212,26 @@ impl ToolUseState {
}
}
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
let empty = Vec::new();
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
self.tool_uses_by_user_message
.get(&message_id)
.unwrap_or(&empty)
tool_uses
.iter()
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_uses_by_user_message
.get(&message_id)
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
@@ -293,14 +283,6 @@ impl ToolUseState {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
@@ -466,31 +448,49 @@ impl ToolUseState {
}
}
pub fn attach_tool_results(
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results_message(
&self,
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
for tool_use_id in tool_uses {
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
request_message.content.push(MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
assistant_message_id: MessageId,
) -> Option<LanguageModelRequestMessage> {
let tool_uses = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)?;
if tool_uses.is_empty() {
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: Role::User,
content: vec![],
cache: false,
};
for tool_use in tool_uses {
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
request_message
.content
.push(MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_use.id.clone(),
tool_name: tool_result.tool_name.clone(),
is_error: tool_result.is_error,
content: if tool_result.content.is_empty() {
// Surprisingly, the API fails if we return an empty string here.
// It thinks we are sending a tool use without a tool result.
"<Tool returned an empty string>".into()
} else {
tool_result.content.clone()
},
));
}
}));
}
}
Some(request_message)
}
}

View File

@@ -73,6 +73,11 @@ impl LabelCommon for AnimatedLabel {
self.base = self.base.buffer_font(cx);
self
}
fn inline_code(mut self, cx: &App) -> Self {
self.base = self.base.inline_code(cx);
self
}
}
impl RenderOnce for AnimatedLabel {

View File

@@ -1,14 +1,13 @@
use std::sync::Arc;
use std::{rc::Rc, time::Duration};
use file_icons::FileIcons;
use futures::FutureExt;
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
use gpui::{ClickEvent, Task};
use language_model::LanguageModelImage;
use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between};
use project::Project;
use prompt_store::PromptStore;
use text::OffsetRangeExt;
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
use crate::context::{AgentContext, ContextKind, ImageStatus};
#[derive(IntoElement)]
pub enum ContextPill {
@@ -73,9 +72,7 @@ impl ContextPill {
pub fn id(&self) -> ElementId {
match self {
Self::Added { context, .. } => {
ElementId::NamedInteger("context-pill".into(), context.id.0)
}
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
Self::Suggested { .. } => "suggested-context-pill".into(),
}
}
@@ -199,14 +196,17 @@ impl RenderOnce for ContextPill {
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
IconButton::new(("remove", context.id.0), IconName::Close)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
IconButton::new(
context.context.element_id("remove".into()),
IconName::Close,
)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
)
})
.when_some(on_click.as_ref(), |element, on_click| {
@@ -262,9 +262,11 @@ pub enum ContextStatus {
Error { message: SharedString },
}
#[derive(RegisterComponent)]
// TODO: Component commented out due to new dependency on `Project`.
//
// #[derive(RegisterComponent)]
pub struct AddedContext {
pub id: ContextId,
pub context: AgentContext,
pub kind: ContextKind,
pub name: SharedString,
pub parent: Option<SharedString>,
@@ -275,10 +277,19 @@ pub struct AddedContext {
}
impl AddedContext {
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
///
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
pub fn new(
context: AgentContext,
prompt_store: Option<&Entity<PromptStore>>,
project: &Project,
cx: &App,
) -> Option<AddedContext> {
match context {
AssistantContext::File(file_context) => {
let full_path = file_context.context_buffer.full_path(cx);
AgentContext::File(ref file_context) => {
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -289,8 +300,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: file_context.id,
Some(AddedContext {
kind: ContextKind::File,
name,
parent,
@@ -298,18 +308,16 @@ impl AddedContext {
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Directory(directory_context) => {
let worktree = directory_context.worktree.read(cx);
// If the directory no longer exists, use its last known path.
let full_path = worktree
.entry_for_id(directory_context.entry_id)
.map_or_else(
|| directory_context.last_path.clone(),
|entry| worktree.full_path(&entry.path).into(),
);
AgentContext::Directory(ref directory_context) => {
let worktree = project
.worktree_for_entry(directory_context.entry_id, cx)?
.read(cx);
let entry = worktree.entry_for_id(directory_context.entry_id)?;
let full_path = worktree.full_path(&entry.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -320,8 +328,7 @@ impl AddedContext {
.parent()
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: directory_context.id,
Some(AddedContext {
kind: ContextKind::Directory,
name,
parent,
@@ -329,33 +336,34 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
}
context,
})
}
AssistantContext::Symbol(symbol_context) => AddedContext {
id: symbol_context.id,
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
kind: ContextKind::Symbol,
name: symbol_context.context_symbol.id.name.clone(),
name: symbol_context.symbol.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Selection(selection_context) => {
let full_path = selection_context.context_buffer.full_path(cx);
AgentContext::Selection(ref selection_context) => {
let buffer = selection_context.buffer.read(cx);
let full_path = buffer.file()?.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
.map(|n| n.to_string_lossy().into_owned())
.unwrap_or_else(|| full_path_string.clone());
let line_range_text = format!(
" ({}-{})",
selection_context.line_range.start.row + 1,
selection_context.line_range.end.row + 1
);
let line_range = selection_context.range.to_point(&buffer.snapshot());
let line_range_text =
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
full_path_string.push_str(&line_range_text);
name.push_str(&line_range_text);
@@ -365,16 +373,17 @@ impl AddedContext {
.and_then(|p| p.file_name())
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: selection_context.id,
Some(AddedContext {
kind: ContextKind::Selection,
name: name.into(),
parent,
tooltip: None,
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
/*
render_preview: Some(Rc::new({
let content = selection_context.context_buffer.text.clone();
let content = selection_context.text.clone();
move |_, cx| {
div()
.id("context-pill-selection-preview")
@@ -385,11 +394,12 @@ impl AddedContext {
.into_any_element()
}
})),
}
*/
context,
})
}
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
id: fetched_url_context.id,
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
kind: ContextKind::FetchedUrl,
name: fetched_url_context.url.clone(),
parent: None,
@@ -397,12 +407,12 @@ impl AddedContext {
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
context,
}),
AssistantContext::Thread(thread_context) => AddedContext {
id: thread_context.id,
AgentContext::Thread(ref thread_context) => Some(AddedContext {
kind: ContextKind::Thread,
name: thread_context.summary(cx),
name: thread_context.name(cx),
parent: None,
tooltip: None,
icon_path: None,
@@ -418,36 +428,41 @@ impl AddedContext {
ContextStatus::Ready
},
render_preview: None,
},
context,
}),
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
AgentContext::Rules(ref user_rules_context) => {
let name = prompt_store
.as_ref()?
.read(cx)
.metadata(user_rules_context.prompt_id.into())?
.title?;
Some(AddedContext {
kind: ContextKind::Rules,
name: name.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
context,
})
}
AssistantContext::Image(image_context) => AddedContext {
id: image_context.id,
AgentContext::Image(ref image_context) => Some(AddedContext {
kind: ContextKind::Image,
name: "Image".into(),
parent: None,
tooltip: None,
icon_path: None,
status: if image_context.is_loading() {
ContextStatus::Loading {
status: match image_context.status() {
ImageStatus::Loading => ContextStatus::Loading {
message: "Loading…".into(),
}
} else if image_context.is_error() {
ContextStatus::Error {
},
ImageStatus::Error => ContextStatus::Error {
message: "Failed to load image".into(),
}
} else {
ContextStatus::Ready
},
ImageStatus::Ready => ContextStatus::Ready,
},
render_preview: Some(Rc::new({
let image = image_context.original_image.clone();
@@ -458,7 +473,8 @@ impl AddedContext {
.into_any_element()
}
})),
},
context,
}),
}
}
}
@@ -478,6 +494,8 @@ impl Render for ContextPillPreview {
}
}
// TODO: Component commented out due to new dependency on `Project`.
/*
impl Component for AddedContext {
fn scope() -> ComponentScope {
ComponentScope::Agent
@@ -487,12 +505,13 @@ impl Component for AddedContext {
"AddedContext"
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let next_context_id = ContextId::zero();
let image_ready = (
"Ready",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(0),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
}),
@@ -503,8 +522,8 @@ impl Component for AddedContext {
let image_loading = (
"Loading",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(1),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: cx
.background_spawn(async move {
@@ -520,8 +539,8 @@ impl Component for AddedContext {
let image_error = (
"Error",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(2),
AgentContext::Image(ImageContext {
context_id: next_context_id.post_inc(),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
}),
@@ -544,5 +563,8 @@ impl Component for AddedContext {
)
.into_any(),
)
None
}
}
*/

View File

@@ -15,6 +15,7 @@ path = "src/askpass.rs"
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
shlex.workspace = true
smol.workspace = true
tempfile.workspace = true
util.workspace = true

View File

@@ -72,8 +72,7 @@ impl AskPassSession {
let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
let listener =
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
let zed_path = std::env::current_exe()
.context("Failed to figure out current executable path for use in askpass")?;
let zed_path = get_shell_safe_zed_path()?;
let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
let mut kill_tx = Some(askpass_kill_master_tx);
@@ -115,7 +114,7 @@ impl AskPassSession {
// Create an askpass script that communicates back to this process.
let askpass_script = format!(
"{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
zed_exe = zed_path.display(),
zed_exe = zed_path,
askpass_socket = askpass_socket.display(),
print_args = "printf '%s\\0' \"$@\"",
shebang = "#!/bin/sh",
@@ -161,6 +160,32 @@ impl AskPassSession {
}
}
#[cfg(unix)]
fn get_shell_safe_zed_path() -> anyhow::Result<String> {
let zed_path = std::env::current_exe()
.context("Failed to figure out current executable path for use in askpass")?
.to_string_lossy()
.to_string();
// sanity check on unix systems that the path exists and is executable
// todo(windows): implement this check for windows (or just use `is-executable` crate)
use std::os::unix::fs::MetadataExt;
let metadata = std::fs::metadata(&zed_path)
.context("Failed to check metadata of Zed executable path for use in askpass")?;
let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
anyhow::ensure!(
is_executable,
"Failed to verify Zed executable path for use in askpass"
);
// As of writing, this can only be fail if the path contains a null byte, which shouldn't be possible
// but shlex has annotated the error as #[non_exhaustive] so we can't make it a compile error if other
// errors are introduced in the future :(
let zed_path_escaped = shlex::try_quote(&zed_path)
.context("Failed to shell-escape Zed executable path for use in askpass")?;
return Ok(zed_path_escaped.to_string());
}
/// The main function for when Zed is running in netcat mode for use in askpass.
/// Called from both the remote server binary and the zed binary in their respective main functions.
#[cfg(unix)]

View File

@@ -49,7 +49,7 @@ menu.workspace = true
multi_buffer.workspace = true
parking_lot.workspace = true
project.workspace = true
prompt_library.workspace = true
rules_library.workspace = true
prompt_store.workspace = true
proto.workspace = true
rope.workspace = true

View File

@@ -101,7 +101,7 @@ pub fn init(
SlashCommandSettings::register(cx);
assistant_context_editor::init(client.clone(), cx);
prompt_library::init(cx);
rules_library::init(cx);
init_language_model_settings(cx);
assistant_slash_command::init(cx);
assistant_tool::init(cx);

View File

@@ -25,8 +25,8 @@ use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, UserPromptId};
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file};
@@ -43,7 +43,7 @@ use workspace::{
dock::{DockPosition, Panel, PanelEvent},
pane,
};
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
pub fn init(cx: &mut App) {
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
@@ -57,11 +57,11 @@ pub fn init(cx: &mut App) {
.register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers)
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_rules_library(action, window, cx)
});
}
});
@@ -272,8 +272,8 @@ impl AssistantPanel {
.action("New Chat", Box::new(NewChat))
.action("History", Box::new(DeployHistory))
.action(
"Prompt Library",
Box::new(OpenPromptLibrary::default()),
"Rules Library",
Box::new(OpenRulesLibrary::default()),
)
.action("Configure", Box::new(ShowConfiguration))
.action(zoom_label, Box::new(ToggleZoom))
@@ -476,7 +476,7 @@ impl AssistantPanel {
{
return;
}
context.custom_summary(new_summary, cx)
context.set_custom_summary(new_summary, cx)
});
});
}
@@ -1043,13 +1043,13 @@ impl AssistantPanel {
}
}
fn deploy_prompt_library(
fn deploy_rules_library(
&mut self,
action: &OpenPromptLibrary,
action: &OpenRulesLibrary,
_window: &mut Window,
cx: &mut Context<Self>,
) {
open_prompt_library(
open_rules_library(
self.languages.clone(),
Box::new(PromptLibraryInlineAssist),
Arc::new(|| {
@@ -1059,9 +1059,9 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action
.prompt_to_select
.map(|uuid| UserPromptId(uuid).into()),
cx,
)
.detach_and_log_err(cx);
@@ -1235,7 +1235,7 @@ impl Render for AssistantPanel {
this.show_configuration_tab(window, cx)
}))
.on_action(cx.listener(AssistantPanel::deploy_history))
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
.child(registrar.size_full().child(self.pane.clone()))
.into_any_element()
}
@@ -1350,13 +1350,13 @@ impl Focusable for AssistantPanel {
struct PromptLibraryInlineAssist;
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
fn assist(
&self,
prompt_editor: &Entity<Editor>,
initial_prompt: Option<String>,
window: &mut Window,
cx: &mut Context<PromptLibrary>,
cx: &mut Context<RulesLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)

View File

@@ -22,7 +22,7 @@ use feature_flags::{
};
use fs::Fs;
use futures::{
SinkExt, Stream, StreamExt,
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{BoxFuture, LocalBoxFuture},
join,
@@ -3056,7 +3056,8 @@ impl CodegenAlternative {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(stream?.stream);
let chunks =
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();

View File

@@ -459,6 +459,7 @@ pub enum ContextEvent {
ShowMaxMonthlySpendReachedError,
MessagesEdited,
SummaryChanged,
SummaryGenerated,
StreamedCompletion,
StartedThoughtProcess(Range<language::Anchor>),
EndedThoughtProcess(language::Anchor),
@@ -482,7 +483,7 @@ pub enum ContextEvent {
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
pub text: String,
done: bool,
pub done: bool,
timestamp: clock::Lamport,
}
@@ -640,7 +641,7 @@ pub struct AssistantContext {
contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>,
summary: Option<ContextSummary>,
pending_summary: Task<Option<()>>,
summary_task: Task<Option<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
token_count: Option<usize>,
@@ -741,7 +742,7 @@ impl AssistantContext {
thought_process_output_sections: Vec::new(),
edits_since_last_parse: edits_since_last_slash_command_parse,
summary: None,
pending_summary: Task::ready(None),
summary_task: Task::ready(None),
completion_count: Default::default(),
pending_completions: Default::default(),
token_count: None,
@@ -951,7 +952,7 @@ impl AssistantContext {
fn flush_ops(&mut self, cx: &mut Context<AssistantContext>) {
let mut changed_messages = HashSet::default();
let mut summary_changed = false;
let mut summary_generated = false;
self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
for op in mem::take(&mut self.pending_ops) {
@@ -993,7 +994,7 @@ impl AssistantContext {
.map_or(true, |summary| new_summary.timestamp > summary.timestamp)
{
self.summary = Some(new_summary);
summary_changed = true;
summary_generated = true;
}
}
ContextOperation::SlashCommandStarted {
@@ -1072,8 +1073,9 @@ impl AssistantContext {
cx.notify();
}
if summary_changed {
if summary_generated {
cx.emit(ContextEvent::SummaryChanged);
cx.emit(ContextEvent::SummaryGenerated);
cx.notify();
}
}
@@ -2611,7 +2613,9 @@ impl AssistantContext {
.map(MessageContent::Text),
);
completion_request.messages.push(request_message);
if !request_message.contents_empty() {
completion_request.messages.push(request_message);
}
}
if let RequestType::SuggestEdits = request_type {
@@ -2945,7 +2949,7 @@ impl AssistantContext {
self.message_anchors.insert(insertion_ix, new_anchor);
}
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
pub fn summarize(&mut self, mut replace_old: bool, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
@@ -2965,7 +2969,18 @@ impl AssistantContext {
cache: false,
});
self.pending_summary = cx.spawn(async move |this, cx| {
// If there is no summary, it is set with `done: false` so that "Loading Summary…" can
// be displayed.
if self.summary.is_none() {
self.summary = Some(ContextSummary {
text: "".to_string(),
done: false,
timestamp: clock::Lamport::default(),
});
replace_old = true;
}
self.summary_task = cx.spawn(async move |this, cx| {
async move {
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
@@ -2990,6 +3005,7 @@ impl AssistantContext {
};
this.push_op(operation, cx);
cx.emit(ContextEvent::SummaryChanged);
cx.emit(ContextEvent::SummaryGenerated);
})?;
// Stop if the LLM generated multiple lines.
@@ -3010,6 +3026,7 @@ impl AssistantContext {
};
this.push_op(operation, cx);
cx.emit(ContextEvent::SummaryChanged);
cx.emit(ContextEvent::SummaryGenerated);
}
})?;
@@ -3182,7 +3199,7 @@ impl AssistantContext {
});
}
pub fn custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
let timestamp = self.next_timestamp();
let summary = self.summary.get_or_insert(ContextSummary::default());
summary.timestamp = timestamp;
@@ -3190,6 +3207,15 @@ impl AssistantContext {
summary.text = custom_summary;
cx.emit(ContextEvent::SummaryChanged);
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread");
pub fn summary_or_default(&self) -> SharedString {
self.summary
.as_ref()
.map(|summary| summary.text.clone().into())
.unwrap_or(Self::DEFAULT_SUMMARY)
}
}
fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range<text::Anchor>) -> String {

View File

@@ -48,7 +48,7 @@ use project::{Project, Worktree};
use rope::Point;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore, update_settings_file};
use std::{any::TypeId, borrow::Cow, cmp, ops::Range, path::PathBuf, sync::Arc, time::Duration};
use std::{any::TypeId, cmp, ops::Range, path::PathBuf, sync::Arc, time::Duration};
use text::SelectionGoal;
use ui::{
ButtonLike, Disclosure, ElevationIndex, KeyBinding, PopoverMenuHandle, TintColor, Tooltip,
@@ -618,6 +618,7 @@ impl ContextEditor {
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
});
}
ContextEvent::SummaryGenerated => {}
ContextEvent::StartedThoughtProcess(range) => {
let creases = self.insert_thought_process_output_sections(
[(
@@ -2179,13 +2180,8 @@ impl ContextEditor {
});
}
pub fn title(&self, cx: &App) -> Cow<str> {
self.context
.read(cx)
.summary()
.map(|summary| summary.text.clone())
.map(Cow::Owned)
.unwrap_or_else(|| Cow::Borrowed(DEFAULT_TAB_TITLE))
pub fn title(&self, cx: &App) -> SharedString {
self.context.read(cx).summary_or_default()
}
fn render_patch_block(

View File

@@ -28,6 +28,7 @@ serde.workspace = true
serde_json.workspace = true
text.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]

View File

@@ -10,14 +10,16 @@ use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use workspace::Workspace;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
@@ -65,6 +67,7 @@ pub trait ToolCard: 'static + Sized {
&mut self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
@@ -76,6 +79,7 @@ pub struct AnyToolCard {
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement,
}
@@ -86,11 +90,14 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity.render(status, window, cx).into_any_element()
entity
.render(status, window, workspace, cx)
.into_any_element()
})
}
@@ -102,8 +109,14 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
}
impl AnyToolCard {
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
pub fn render(
&self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement {
(self.render)(self.entity.clone(), status, window, workspace, cx)
}
}
@@ -163,6 +176,7 @@ pub trait Tool: 'static + Send + Sync {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
}

View File

@@ -10,6 +10,11 @@ pub fn adapt_schema_to_format(
json: &mut Value,
format: LanguageModelToolSchemaFormat,
) -> Result<()> {
if let Value::Object(obj) = json {
obj.remove("$schema");
obj.remove("title");
}
match format {
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
@@ -30,10 +35,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
}
}
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
for key in KEYS_TO_REMOVE {
obj.remove(key);
}
obj.remove("format");
if let Some(default) = obj.get("default") {
let is_null = default.is_null();

View File

@@ -14,9 +14,11 @@ path = "src/assistant_tools.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@@ -36,7 +38,7 @@ ui.workspace = true
util.workspace = true
web_search.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]

View File

@@ -9,12 +9,12 @@ mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
mod grep_tool;
mod list_directory_tool;
mod move_path_tool;
mod now_tool;
mod open_tool;
mod path_search_tool;
mod read_file_tool;
mod rename_tool;
mod replace;
@@ -45,18 +45,21 @@ use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::grep_tool::GrepTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use path_search_tool::PathSearchToolInput;
pub use create_file_tool::CreateFileToolInput;
pub use edit_file_tool::EditFileToolInput;
pub use find_path_tool::FindPathToolInput;
pub use read_file_tool::ReadFileToolInput;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
@@ -78,7 +81,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(PathSearchTool);
registry.register_tool(FindPathTool);
registry.register_tool(ReadFileTool);
registry.register_tool(GrepTool);
registry.register_tool(RenameTool);
@@ -107,11 +110,38 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
#[cfg(test)]
mod tests {
use super::*;
use client::Client;
use clock::FakeSystemClock;
use http_client::FakeHttpClient;
use schemars::JsonSchema;
use serde::Serialize;
use super::*;
#[test]
fn test_json_schema() {
#[derive(Serialize, JsonSchema)]
struct GetWeatherTool {
location: String,
}
let schema = schema::json_schema_for::<GetWeatherTool>(
language_model::LanguageModelToolSchemaFormat::JsonSchema,
)
.unwrap();
assert_eq!(
schema,
serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string"
}
},
"required": ["location"],
})
);
}
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use futures::future::join_all;
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -97,7 +97,7 @@ pub struct BatchToolInput {
/// }
/// },
/// {
/// "name": "path_search",
/// "name": "find_path",
/// "input": {
/// "glob": "**/*test*.rs"
/// }
@@ -218,6 +218,7 @@ impl Tool for BatchTool {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<BatchToolInput>(input) {
@@ -258,7 +259,9 @@ impl Tool for BatchTool {
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.update(|cx| {
tool.run(invocation.input, &messages, project, action_log, window, cx)
})
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
@@ -140,6 +140,7 @@ impl Tool for CodeActionTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{App, AsyncApp, Entity, Task};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
@@ -128,6 +128,7 @@ impl Tool for CodeSymbolsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -102,6 +102,7 @@ impl Tool for ContentsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -76,6 +77,7 @@ impl Tool for CopyPathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -67,6 +68,7 @@ impl Tool for CreateDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {

View File

@@ -1,6 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -23,6 +24,9 @@ pub struct CreateFileToolInput {
///
/// You can create a new file by providing a path of "directory1/new_file.txt"
/// </example>
///
/// Make sure to include this field before the `contents` field in the input object
/// so that we can display it immediately.
pub path: String,
/// The text contents of the file to create.
@@ -89,6 +93,7 @@ impl Tool for CreateFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateFileToolInput>(input) {

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
@@ -62,6 +62,7 @@ impl Tool for DeletePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -82,6 +82,7 @@ impl Tool for DiagnosticsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input)

View File

@@ -1,18 +1,40 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use crate::{
replace::{replace_exact, replace_with_flexible_indent},
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
};
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use ui::{Disclosure, Tooltip, Window, prelude::*};
use util::ResultExt;
use workspace::Workspace;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
@@ -34,12 +56,6 @@ pub struct EditFileToolInput {
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The text to replace.
pub old_string: String,
@@ -113,6 +129,7 @@ impl Tool for EditFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
@@ -120,7 +137,18 @@ impl Tool for EditFileTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
@@ -128,26 +156,38 @@ impl Tool for EditFileTool {
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
return Err(anyhow!(
"`old_string` can't be empty, use another tool if you want to create a file."
));
}
if input.old_string == input.new_string {
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
return Err(anyhow!(
"The `old_string` and `new_string` are identical, so no changes would be made."
));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
.await
// If that fails, try being flexible about indentation
.or_else(|| {
replace_with_flexible_indent(
&input.old_string,
&input.new_string,
&snapshot,
)
})?;
if diff.edits.is_empty() {
return None;
@@ -177,41 +217,409 @@ impl Tool for EditFileTool {
}
})?;
return Err(err)
return Err(err);
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.track_buffer(buffer.clone(), cx)
});
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
snapshot
})?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
let new_text = snapshot.text();
let diff_str = cx
.background_spawn({
let old_text = old_text.clone();
let new_text = new_text.clone();
async move { language::unified_diff(&old_text, &new_text) }
})
.await;
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
Ok(format!(
"Edited {}:\n\n```diff\n{}\n```",
input.path.display(),
diff_str
))
});
}).into()
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}
}
pub struct EditFileToolCard {
path: PathBuf,
editor: Entity<Editor>,
multibuffer: Entity<MultiBuffer>,
project: Entity<Project>,
diff_task: Option<Task<Result<()>>>,
preview_expanded: bool,
full_height_expanded: bool,
editor_unique_id: EntityId,
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
Some(project.clone()),
window,
cx,
);
editor.set_show_scrollbars(false, cx);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_scrolling(cx);
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor
});
Self {
editor_unique_id: editor.entity_id(),
path,
project,
editor,
multibuffer,
diff_task: None,
preview_expanded: true,
full_height_expanded: false,
}
}
fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
new_text: String,
cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
self.diff_task = Some(cx.spawn(async move |this, cx| {
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
this.update(cx, |this, cx| {
this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
diff_hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
debug_assert!(is_newly_added);
multibuffer.add_diff(buffer_diff, cx);
});
cx.notify();
})
}));
}
}
impl ToolCard for EditFileToolCard {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let failed = matches!(status, ToolUseStatus::Error(_));
let path_label_button = h_flex()
.id(("edit-tool-path-label-button", self.editor_unique_id))
.w_full()
.max_w_full()
.px_1()
.gap_0p5()
.cursor_pointer()
.rounded_sm()
.opacity(0.8)
.hover(|label| {
label
.opacity(1.)
.bg(cx.theme().colors().element_hover.opacity(0.5))
})
.tooltip(Tooltip::text("Jump to File"))
.child(
h_flex()
.child(
Icon::new(IconName::Pencil)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
div()
.text_size(rems(0.8125))
.child(self.path.display().to_string())
.ml_1p5()
.mr_0p5(),
)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
)
.on_click({
let path = self.path.clone();
let workspace = workspace.clone();
move |_, window, cx| {
workspace
.update(cx, {
|workspace, cx| {
let Some(project_path) =
workspace.project().read(cx).find_project_path(&path, cx)
else {
return;
};
let open_task =
workspace.open_path(project_path, None, true, window, cx);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) = item.downcast::<Editor>() {
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
language::Point::new(0, 0),
window,
cx,
);
})
.log_err();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
})
.ok();
}
})
.into_any_element();
let codeblock_header_bg = cx
.theme()
.colors()
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.025));
let codeblock_header = h_flex()
.flex_none()
.p_1()
.gap_1()
.justify_between()
.rounded_t_md()
.when(!failed, |header| header.bg(codeblock_header_bg))
.child(path_label_button)
.map(|container| {
if failed {
container.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
} else {
container.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
self.preview_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.preview_expanded = !this.preview_expanded;
},
)),
)
}
});
let editor = self.editor.update(cx, |editor, cx| {
editor.render(window, cx).into_any_element()
});
let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
(IconName::ChevronUp, "Collapse Code Block")
} else {
(IconName::ChevronDown, "Expand Code Block")
};
let gradient_overlay = div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_2_5()
.rounded_b_lg()
.bg(gpui::linear_gradient(
0.,
gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
));
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
.mb_2()
.border_1()
.when(failed, |card| card.border_dashed())
.border_color(border_color)
.rounded_lg()
.overflow_hidden()
.child(codeblock_header)
.when(!failed && self.preview_expanded, |card| {
card.child(
v_flex()
.relative()
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container.max_h_64()
}
})
.child(div().pl_1().child(editor))
.when(!self.full_height_expanded, |editor_container| {
editor_container.child(gradient_overlay)
}),
)
})
.when(!failed && self.preview_expanded, |card| {
card.child(
h_flex()
.id(("edit-tool-card-inner-hflex", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
})
}
}
async fn build_buffer(
mut text: String,
path: Arc<Path>,
language_registry: &Arc<language::LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<Buffer>> {
let line_ending = LineEnding::detect(&text);
LineEnding::normalize(&mut text);
let text = Rope::from(text);
let language = cx
.update(|_cx| language_registry.language_for_file_path(&path))?
.await
.ok();
let buffer = cx.new(|cx| {
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
line_ending,
text,
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
})?;
Ok(buffer)
}
async fn build_buffer_diff(
mut old_text: String,
buffer: &Entity<Buffer>,
language_registry: &Arc<LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
LineEnding::normalize(&mut old_text);
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text.clone().into(),
buffer.language().cloned(),
Some(language_registry.clone()),
cx,
)
})?
.await;
let diff_snapshot = cx
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text.into()),
base_buffer,
cx,
)
})?
.await;
cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer.text, cx);
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
diff
})
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext as _, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -145,6 +145,7 @@ impl Tool for FetchTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {

View File

@@ -0,0 +1,421 @@
use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use editor::Editor;
use futures::channel::oneshot::{self, Receiver};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
};
use language;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::{cmp, path::PathBuf, sync::Arc};
use ui::{Disclosure, Tooltip, prelude::*};
use util::{ResultExt, paths::PathMatcher};
use workspace::Workspace;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
/// The glob to match against every path in the project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1/a/something.txt
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can get back the first two paths by providing a glob of "*thing*.txt"
/// </example>
pub glob: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: usize,
}
const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
impl Tool for FindPathTool {
fn name(&self) -> String {
"find_path".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./find_path_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindPathToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let (sender, receiver) = oneshot::channel();
let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
let search_paths_task = search_paths(&glob, project, cx);
let task = cx.background_spawn(async move {
let matches = search_paths_task.await?;
let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
sender.send(paginated_matches.to_vec()).log_err();
if matches.is_empty() {
Ok("No matches found".to_string())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
offset + 1,
offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
});
ToolResult {
output: task,
card: Some(card.into()),
}
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<_> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
struct FindPathToolCard {
paths: Vec<PathBuf>,
expanded: bool,
glob: String,
_receiver_task: Option<Task<Result<()>>>,
}
impl FindPathToolCard {
fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
let _receiver_task = cx.spawn(async move |this, cx| {
let paths = receiver.await?;
this.update(cx, |this, _cx| {
this.paths = paths;
})
.log_err();
Ok(())
});
Self {
paths: Vec::new(),
expanded: false,
glob,
_receiver_task: Some(_receiver_task),
}
}
}
impl ToolCard for FindPathToolCard {
fn render(
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let matches_label: SharedString = if self.paths.len() == 0 {
"No matches".into()
} else if self.paths.len() == 1 {
"1 match".into()
} else {
format!("{} matches", self.paths.len()).into()
};
let glob_label = self.glob.to_string();
let content = if !self.paths.is_empty() && self.expanded {
Some(
v_flex()
.relative()
.ml_1p5()
.px_1p5()
.gap_0p5()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.children(self.paths.iter().enumerate().map(|(index, path)| {
let path_clone = path.clone();
let workspace_clone = workspace.clone();
let button_label = path.to_string_lossy().to_string();
Button::new(("path", index), button_label)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.label_size(LabelSize::Small)
.color(Color::Muted)
.tooltip(Tooltip::text("Jump to File"))
.on_click(move |_, window, cx| {
workspace_clone
.update(cx, |workspace, cx| {
let path = PathBuf::from(&path_clone);
let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path, cx)
else {
return;
};
let open_task = workspace.open_path(
project_path,
None,
true,
window,
cx,
);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) =
item.downcast::<Editor>()
{
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
language::Point::new(0, 0),
window,
cx,
);
})
.log_err();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
})
.ok();
})
}))
.into_any(),
)
} else {
None
};
v_flex()
.mb_2()
.gap_1()
.child(
ToolCallCardHeader::new(IconName::SearchCode, matches_label)
.with_code_path(glob_label)
.disclosure_slot(
Disclosure::new("path-search-disclosure", self.expanded)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.disabled(self.paths.is_empty())
.on_click(cx.listener(move |this, _, _, _cx| {
this.expanded = !this.expanded;
})),
),
)
.children(content)
}
}
impl Component for FindPathTool {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"FindPathTool"
}
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let successful_card = cx.new(|_| FindPathToolCard {
paths: vec![
PathBuf::from("src/main.rs"),
PathBuf::from("src/lib.rs"),
PathBuf::from("tests/test.rs"),
],
expanded: true,
glob: "*.rs".to_string(),
_receiver_task: None,
});
let empty_card = cx.new(|_| FindPathToolCard {
paths: Vec::new(),
expanded: false,
glob: "*.nonexistent".to_string(),
_receiver_task: None,
});
Some(
v_flex()
.gap_6()
.children(vec![example_group(vec![
single_example(
"With Paths",
div()
.size_full()
.child(successful_card.update(cx, |tool, cx| {
tool.render(
&ToolUseStatus::Finished("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
}))
.into_any_element(),
),
single_example(
"No Paths",
div()
.size_full()
.child(empty_card.update(cx, |tool, cx| {
tool.render(
&ToolUseStatus::Finished("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
}))
.into_any_element(),
),
])])
.into_any_element(),
)
}
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_find_path_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,4 +1,4 @@
Fast file pattern matching tool that works with any codebase size
Fast file path pattern matching tool that works with any codebase size
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching file paths sorted alphabetically

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::OffsetRangeExt;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
@@ -20,6 +20,8 @@ use util::paths::PathMatcher;
pub struct GrepToolInput {
/// A regex pattern to search for in the entire project. Note that the regex
/// will be parsed by the Rust `regex` crate.
///
/// Do NOT specify a path here! This will only be matched against the code **content**.
pub regex: String,
/// A glob pattern for the paths of files to include in the search.
@@ -96,6 +98,7 @@ impl Tool for GrepTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
@@ -405,7 +408,7 @@ mod tests {
) -> String {
let tool = Arc::new(GrepTool);
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let task = cx.update(|cx| tool.run(input, &[], project, action_log, cx));
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
match task.output.await {
Ok(result) => result,

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -76,6 +76,7 @@ impl Tool for ListDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View File

@@ -1 +1 @@
Lists files and directories in a given path. Prefer the `grep` or `path_search` tools when searching the codebase.
Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase.

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -89,6 +89,7 @@ impl Tool for MovePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {

View File

@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -59,6 +59,7 @@ impl Tool for NowTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -52,6 +52,7 @@ impl Tool for OpenTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {

View File

@@ -1,200 +0,0 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
use ui::IconName;
use util::paths::PathMatcher;
use worktree::Snapshot;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct PathSearchToolInput {
/// The glob to match against every path in the project.
///
/// <example>
/// If the project has the following root directories:
///
/// - directory1/a/something.txt
/// - directory2/a/things.txt
/// - directory3/a/other.txt
///
/// You can get back the first two paths by providing a glob of "*thing*.txt"
/// </example>
pub glob: String,
/// Optional starting position for paginated results (0-based).
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: u32,
}
const RESULTS_PER_PAGE: usize = 50;
pub struct PathSearchTool;
impl Tool for PathSearchTool {
fn name(&self) -> String {
"path_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./path_search_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::SearchCode
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<PathSearchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let offset = offset as usize;
let task = search_paths(&glob, project, cx);
cx.background_spawn(async move {
let matches = task.await?;
let paginated_matches = &matches[cmp::min(offset, matches.len())
..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
if matches.is_empty() {
Ok("No matches found".to_string())
} else {
let mut message = format!("Found {} total matches.", matches.len());
if matches.len() > RESULTS_PER_PAGE {
write!(
&mut message,
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
offset + 1,
offset + paginated_matches.len()
)
.unwrap();
}
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(message)
}
})
.into()
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
let path_matcher = match PathMatcher::new([
// Sometimes models try to search for "". In this case, return all paths in the project.
if glob.is_empty() { "*" } else { glob },
]) {
Ok(matcher) => matcher,
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
};
let snapshots: Vec<Snapshot> = project
.read(cx)
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect();
cx.background_spawn(async move {
Ok(snapshots
.iter()
.flat_map(|snapshot| {
let root_name = PathBuf::from(snapshot.root_name());
snapshot
.entries(false, 0)
.map(move |entry| root_name.join(&entry.path))
.filter(|path| path_matcher.is_match(&path))
})
.collect())
})
}
#[cfg(test)]
mod test {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use util::path;
#[gpui::test]
async fn test_path_search_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
serde_json::json!({
"apple": {
"banana": {
"carrot": "1",
},
"bandana": {
"carbonara": "2",
},
"endive": "3"
}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let matches = cx
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
let matches = cx
.update(|cx| search_paths("**/car*", project.clone(), cx))
.await
.unwrap();
assert_eq!(
matches,
&[
PathBuf::from("root/apple/banana/carrot"),
PathBuf::from("root/apple/bandana/carbonara")
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,7 +1,8 @@
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use indoc::formatdoc;
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -39,7 +40,7 @@ pub struct ReadFileToolInput {
#[serde(default)]
pub start_line: Option<usize>,
/// Optional line number to end reading on (1-based index)
/// Optional line number to end reading on (1-based index, inclusive)
#[serde(default)]
pub end_line: Option<usize>,
}
@@ -87,6 +88,7 @@ impl Tool for ReadFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
@@ -126,7 +128,7 @@ impl Tool for ReadFileTool {
let start = input.start_line.unwrap_or(1);
let lines = text.split('\n').skip(start - 1);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
@@ -193,7 +195,7 @@ mod test {
"path": "root/nonexistent_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
@@ -223,7 +225,7 @@ mod test {
"path": "root/small_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
@@ -253,7 +255,7 @@ mod test {
"path": "root/large_file.rs"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), cx)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.output
})
.await;
@@ -277,7 +279,7 @@ mod test {
"offset": 1
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
@@ -323,11 +325,11 @@ mod test {
"end_line": 4
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, cx)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 2\nLine 3");
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
}
fn init_test(cx: &mut TestAppContext) {

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -87,6 +87,7 @@ impl Tool for RenameTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<RenameToolInput>(input) {

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AsyncApp, Entity, Task};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -121,6 +121,7 @@ impl Tool for SymbolInfoTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {

View File

@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{App, AppContext, Entity, Task};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -78,6 +78,7 @@ impl Tool for TerminalTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -50,6 +50,7 @@ impl Tool for ThinkingTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.

View File

@@ -1,4 +1,4 @@
use gpui::{Animation, AnimationExt, App, IntoElement, pulsating_between};
use gpui::{Animation, AnimationExt, AnyElement, App, IntoElement, pulsating_between};
use std::time::Duration;
use ui::{Tooltip, prelude::*};
@@ -8,6 +8,8 @@ pub struct ToolCallCardHeader {
icon: IconName,
primary_text: SharedString,
secondary_text: Option<SharedString>,
code_path: Option<SharedString>,
disclosure_slot: Option<AnyElement>,
is_loading: bool,
error: Option<String>,
}
@@ -18,6 +20,8 @@ impl ToolCallCardHeader {
icon,
primary_text: primary_text.into(),
secondary_text: None,
code_path: None,
disclosure_slot: None,
is_loading: false,
error: None,
}
@@ -28,6 +32,16 @@ impl ToolCallCardHeader {
self
}
pub fn with_code_path(mut self, text: impl Into<SharedString>) -> Self {
self.code_path = Some(text.into());
self
}
pub fn disclosure_slot(mut self, element: impl IntoElement) -> Self {
self.disclosure_slot = Some(element.into_any_element());
self
}
pub fn loading(mut self) -> Self {
self.is_loading = true;
self
@@ -42,26 +56,36 @@ impl ToolCallCardHeader {
impl RenderOnce for ToolCallCardHeader {
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
let font_size = rems(0.8125);
let line_height = window.line_height();
let secondary_text = self.secondary_text;
let code_path = self.code_path;
let bullet_divider = || {
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text)
};
h_flex()
.id("tool-label-container")
.gap_1p5()
.gap_2()
.max_w_full()
.overflow_x_scroll()
.opacity(0.8)
.child(
h_flex().h(window.line_height()).justify_center().child(
Icon::new(self.icon)
.size(IconSize::XSmall)
.color(Color::Muted),
),
)
.child(
h_flex()
.h(window.line_height())
.h(line_height)
.gap_1p5()
.text_size(font_size)
.child(
h_flex().h(line_height).justify_center().child(
Icon::new(self.icon)
.size(IconSize::XSmall)
.color(Color::Muted),
),
)
.map(|this| {
if let Some(error) = &self.error {
this.child(format!("{} failed", self.primary_text)).child(
@@ -76,13 +100,15 @@ impl RenderOnce for ToolCallCardHeader {
}
})
.when_some(secondary_text, |this, secondary_text| {
this.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
this.child(bullet_divider())
.child(div().text_size(font_size).child(secondary_text.clone()))
})
.when_some(code_path, |this, code_path| {
this.child(bullet_divider()).child(
Label::new(code_path.clone())
.size(LabelSize::Small)
.inline_code(cx),
)
.child(div().text_size(font_size).child(secondary_text.clone()))
})
.with_animation(
"loading-label",
@@ -98,5 +124,11 @@ impl RenderOnce for ToolCallCardHeader {
},
),
)
.when_some(self.disclosure_slot, |container, disclosure_slot| {
container
.group("disclosure")
.justify_between()
.child(div().visible_on_hover("disclosure").child(disclosure_slot))
})
}
}

View File

@@ -5,13 +5,16 @@ use crate::ui::ToolCallCardHeader;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use workspace::Workspace;
use zed_llm_client::{WebSearchCitation, WebSearchResponse};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -54,6 +57,7 @@ impl Tool for WebSearchTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
@@ -111,6 +115,7 @@ impl ToolCard for WebSearchToolCard {
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
_workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = match self.response.as_ref() {
@@ -220,8 +225,13 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(in_progress_search.update(cx, |tool, cx| {
tool.render(&ToolUseStatus::Pending, window, cx)
.into_any_element()
tool.render(
&ToolUseStatus::Pending,
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
}))
.into_any_element(),
),
@@ -230,8 +240,13 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(successful_search.update(cx, |tool, cx| {
tool.render(&ToolUseStatus::Finished("".into()), window, cx)
.into_any_element()
tool.render(
&ToolUseStatus::Finished("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
}))
.into_any_element(),
),
@@ -240,8 +255,13 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(error_search.update(cx, |tool, cx| {
tool.render(&ToolUseStatus::Error("".into()), window, cx)
.into_any_element()
tool.render(
&ToolUseStatus::Error("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
}))
.into_any_element(),
),

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use gpui::{App, Entity, Task};
use gpui::{AnyWindowHandle, App, Entity, Task};
use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -77,6 +77,7 @@ impl Tool for ContextServerTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {

View File

@@ -1,4 +1,7 @@
use crate::code_context_menus::{CompletionsMenu, SortableMatch};
use crate::{
code_context_menus::{CompletionsMenu, SortableMatch},
editor_settings::SnippetSortOrder,
};
use fuzzy::StringMatch;
use gpui::TestAppContext;
@@ -74,7 +77,7 @@ fn test_sort_matches_local_variable_over_global_variable(_cx: &mut TestAppContex
sort_key: (2, "floorf128"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"foo_bar_qux",
@@ -122,7 +125,7 @@ fn test_sort_matches_local_variable_over_global_variable(_cx: &mut TestAppContex
sort_key: (1, "foo_bar_qux"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"foo_bar_qux",
@@ -185,7 +188,7 @@ fn test_sort_matches_local_variable_over_global_enum(_cx: &mut TestAppContext) {
sort_key: (0, "while let"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"element_type",
@@ -234,7 +237,7 @@ fn test_sort_matches_local_variable_over_global_enum(_cx: &mut TestAppContext) {
sort_key: (2, "REPLACEMENT_CHARACTER"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"element_type",
@@ -272,7 +275,7 @@ fn test_sort_matches_local_variable_over_global_enum(_cx: &mut TestAppContext) {
sort_key: (1, "element_type"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"ElementType",
@@ -335,7 +338,7 @@ fn test_sort_matches_for_unreachable(_cx: &mut TestAppContext) {
sort_key: (2, "unreachable_unchecked"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"unreachable!(…)",
@@ -379,7 +382,7 @@ fn test_sort_matches_for_unreachable(_cx: &mut TestAppContext) {
sort_key: (3, "unreachable_unchecked"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"unreachable!(…)",
@@ -423,7 +426,7 @@ fn test_sort_matches_for_unreachable(_cx: &mut TestAppContext) {
sort_key: (2, "unreachable_unchecked"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"unreachable!(…)",
@@ -467,7 +470,7 @@ fn test_sort_matches_for_unreachable(_cx: &mut TestAppContext) {
sort_key: (2, "unreachable_unchecked"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string.as_str(),
"unreachable!(…)",
@@ -503,7 +506,7 @@ fn test_sort_matches_variable_and_constants_over_function(_cx: &mut TestAppConte
sort_key: (1, "var"), // variable
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.candidate_id, 1,
"Match order not expected"
@@ -539,7 +542,7 @@ fn test_sort_matches_variable_and_constants_over_function(_cx: &mut TestAppConte
sort_key: (2, "var"), // constant
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.candidate_id, 1,
"Match order not expected"
@@ -622,7 +625,7 @@ fn test_sort_matches_jsx_event_handler(_cx: &mut TestAppContext) {
sort_key: (3, "className?"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches[0].string_match.string, "onCut?",
"Match order not expected"
@@ -944,7 +947,7 @@ fn test_sort_matches_jsx_event_handler(_cx: &mut TestAppContext) {
sort_key: (3, "onLoadedData?"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::default());
assert_eq!(
matches
.iter()
@@ -996,10 +999,159 @@ fn test_sort_matches_for_snippets(_cx: &mut TestAppContext) {
sort_key: (2, "println!(…)"),
},
];
CompletionsMenu::sort_matches(&mut matches, query);
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::Top);
assert_eq!(
matches[0].string_match.string.as_str(),
"println!(…)",
"Match order not expected"
);
}
#[gpui::test]
fn test_sort_matches_for_exact_match(_cx: &mut TestAppContext) {
// Case 1: "set_text"
let query: Option<&str> = Some("set_text");
let mut matches: Vec<SortableMatch<'_>> = vec![
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 1.0,
positions: vec![],
string: "set_text".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "set_text"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.32000000000000006,
positions: vec![],
string: "set_placeholder_text".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "set_placeholder_text"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.32,
positions: vec![],
string: "set_text_style_refinement".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "set_text_style_refinement"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.16666666666666666,
positions: vec![],
string: "set_context_menu_options".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "set_context_menu_options"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.08695652173913043,
positions: vec![],
string: "select_to_next_word_end".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_next_word_end"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.07692307692307693,
positions: vec![],
string: "select_to_next_subword_end".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_next_subword_end"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.06956521739130435,
positions: vec![],
string: "set_custom_context_menu".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "set_custom_context_menu"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.06,
positions: vec![],
string: "select_to_end_of_excerpt".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_end_of_excerpt"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.055384615384615386,
positions: vec![],
string: "select_to_start_of_excerpt".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_start_of_excerpt"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.0464516129032258,
positions: vec![],
string: "select_to_start_of_next_excerpt".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_start_of_next_excerpt"),
},
SortableMatch {
string_match: StringMatch {
candidate_id: 0,
score: 0.04363636363636363,
positions: vec![],
string: "select_to_end_of_previous_excerpt".to_string(),
},
is_snippet: false,
sort_text: Some("7fffffff"),
sort_key: (3, "select_to_end_of_previous_excerpt"),
},
];
CompletionsMenu::sort_matches(&mut matches, query, SnippetSortOrder::Top);
assert_eq!(
matches
.iter()
.map(|m| m.string_match.string.as_str())
.collect::<Vec<&str>>(),
vec![
"set_text",
"set_context_menu_options",
"set_placeholder_text",
"set_text_style_refinement",
"select_to_end_of_excerpt",
"select_to_end_of_previous_excerpt",
"select_to_next_subword_end",
"select_to_next_word_end",
"select_to_start_of_excerpt",
"select_to_start_of_next_excerpt",
"set_custom_context_menu"
]
);
}

View File

@@ -25,6 +25,7 @@ use task::ResolvedTask;
use ui::{Color, IntoElement, ListItem, Pixels, Popover, Styled, prelude::*};
use util::ResultExt;
use crate::editor_settings::SnippetSortOrder;
use crate::hover_popover::{hover_markdown_style, open_markdown_url};
use crate::{
CodeActionProvider, CompletionId, CompletionItemKind, CompletionProvider, DisplayRow, Editor,
@@ -184,6 +185,7 @@ pub struct CompletionsMenu {
pub(super) ignore_completion_provider: bool,
last_rendered_range: Rc<RefCell<Option<Range<usize>>>>,
markdown_element: Option<Entity<Markdown>>,
snippet_sort_order: SnippetSortOrder,
}
impl CompletionsMenu {
@@ -195,6 +197,7 @@ impl CompletionsMenu {
initial_position: Anchor,
buffer: Entity<Buffer>,
completions: Box<[Completion]>,
snippet_sort_order: SnippetSortOrder,
) -> Self {
let match_candidates = completions
.iter()
@@ -217,6 +220,7 @@ impl CompletionsMenu {
resolve_completions: true,
last_rendered_range: RefCell::new(None).into(),
markdown_element: None,
snippet_sort_order,
}
}
@@ -226,6 +230,7 @@ impl CompletionsMenu {
choices: &Vec<String>,
selection: Range<Anchor>,
buffer: Entity<Buffer>,
snippet_sort_order: SnippetSortOrder,
) -> Self {
let completions = choices
.iter()
@@ -275,6 +280,7 @@ impl CompletionsMenu {
ignore_completion_provider: false,
last_rendered_range: RefCell::new(None).into(),
markdown_element: None,
snippet_sort_order,
}
}
@@ -657,11 +663,15 @@ impl CompletionsMenu {
)
}
pub fn sort_matches(matches: &mut Vec<SortableMatch<'_>>, query: Option<&str>) {
pub fn sort_matches(
matches: &mut Vec<SortableMatch<'_>>,
query: Option<&str>,
snippet_sort_order: SnippetSortOrder,
) {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
enum MatchTier<'a> {
WordStartMatch {
sort_score_int: Reverse<i32>,
sort_bucket: Reverse<i32>,
sort_snippet: Reverse<i32>,
sort_text: Option<&'a str>,
sort_key: (usize, &'a str),
@@ -674,11 +684,9 @@ impl CompletionsMenu {
// Our goal here is to intelligently sort completion suggestions. We want to
// balance the raw fuzzy match score with hints from the language server
//
// We first primary sort using fuzzy score by putting matches into two buckets
// strong one and weak one. Among these buckets matches are then compared by
// We first primary sort using fuzzy score by putting matches into multiple
// buckets. Among these buckets matches are then compared by
// various criteria like snippet, LSP hints, kind, label text etc.
//
const FUZZY_THRESHOLD: f64 = 0.1317;
let query_start_lower = query
.and_then(|q| q.chars().next())
@@ -702,10 +710,20 @@ impl CompletionsMenu {
let sort_score = Reverse(OrderedFloat(score));
MatchTier::OtherMatch { sort_score }
} else {
let sort_score_int = Reverse(if score >= FUZZY_THRESHOLD { 1 } else { 0 });
let sort_snippet = Reverse(if mat.is_snippet { 1 } else { 0 });
// Convert fuzzy match score (0.0-1.0) to a priority bucket (0-3)
let sort_bucket = Reverse(match (score * 10.0).floor() as i32 {
s if s >= 7 => 3,
s if s >= 1 => 2,
s if s > 0 => 1,
_ => 0,
});
let sort_snippet = match snippet_sort_order {
SnippetSortOrder::Top => Reverse(if mat.is_snippet { 1 } else { 0 }),
SnippetSortOrder::Bottom => Reverse(if mat.is_snippet { 0 } else { 1 }),
SnippetSortOrder::Inline => Reverse(0),
};
MatchTier::WordStartMatch {
sort_score_int,
sort_bucket,
sort_snippet,
sort_text: mat.sort_text,
sort_key: mat.sort_key,
@@ -770,7 +788,7 @@ impl CompletionsMenu {
})
.collect();
Self::sort_matches(&mut sortable_items, query);
Self::sort_matches(&mut sortable_items, query, self.snippet_sort_order);
matches = sortable_items
.into_iter()

View File

@@ -427,6 +427,8 @@ pub enum EditorMode {
scale_ui_elements_with_buffer_font_size: bool,
/// When set to `true`, the editor will render a background for the active line.
show_active_line_background: bool,
/// When set to `true`, the editor's height will be determined by its content.
sized_by_content: bool,
},
}
@@ -435,6 +437,7 @@ impl EditorMode {
Self::Full {
scale_ui_elements_with_buffer_font_size: true,
show_active_line_background: true,
sized_by_content: false,
}
}
@@ -788,6 +791,8 @@ pub struct Editor {
show_breadcrumbs: bool,
show_gutter: bool,
show_scrollbars: bool,
disable_scrolling: bool,
disable_expand_excerpt_buttons: bool,
show_line_numbers: Option<bool>,
use_relative_line_numbers: Option<bool>,
show_git_diff_gutter: Option<bool>,
@@ -1563,11 +1568,13 @@ impl Editor {
blink_manager: blink_manager.clone(),
show_local_selections: true,
show_scrollbars: true,
disable_scrolling: false,
mode,
show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs,
show_gutter: mode.is_full(),
show_line_numbers: None,
use_relative_line_numbers: None,
disable_expand_excerpt_buttons: false,
show_git_diff_gutter: None,
show_code_actions: None,
show_runnables: None,
@@ -4592,6 +4599,8 @@ impl Editor {
.as_ref()
.map_or(true, |provider| provider.filter_completions());
let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order;
let id = post_inc(&mut self.next_completion_id);
let task = cx.spawn_in(window, async move |editor, cx| {
async move {
@@ -4639,6 +4648,7 @@ impl Editor {
position,
buffer.clone(),
completions.into(),
snippet_sort_order,
);
menu.filter(
@@ -8009,10 +8019,18 @@ impl Editor {
let buffer_id = selection.start.buffer_id.unwrap();
let buffer = self.buffer().read(cx).buffer(buffer_id);
let id = post_inc(&mut self.next_completion_id);
let snippet_sort_order = EditorSettings::get_global(cx).snippet_sort_order;
if let Some(buffer) = buffer {
*self.context_menu.borrow_mut() = Some(CodeContextMenu::Completions(
CompletionsMenu::new_snippet_choices(id, true, choices, selection, buffer),
CompletionsMenu::new_snippet_choices(
id,
true,
choices,
selection,
buffer,
snippet_sort_order,
),
));
}
}
@@ -16146,11 +16164,21 @@ impl Editor {
cx.notify();
}
pub fn disable_scrolling(&mut self, cx: &mut Context<Self>) {
self.disable_scrolling = true;
cx.notify();
}
pub fn set_show_line_numbers(&mut self, show_line_numbers: bool, cx: &mut Context<Self>) {
self.show_line_numbers = Some(show_line_numbers);
cx.notify();
}
pub fn disable_expand_excerpt_buttons(&mut self, cx: &mut Context<Self>) {
self.disable_expand_excerpt_buttons = true;
cx.notify();
}
pub fn set_show_git_diff_gutter(&mut self, show_git_diff_gutter: bool, cx: &mut Context<Self>) {
self.show_git_diff_gutter = Some(show_git_diff_gutter);
cx.notify();

View File

@@ -39,6 +39,7 @@ pub struct EditorSettings {
pub go_to_definition_fallback: GoToDefinitionFallback,
pub jupyter: Jupyter,
pub hide_mouse: Option<HideMouseMode>,
pub snippet_sort_order: SnippetSortOrder,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
@@ -239,6 +240,21 @@ pub enum HideMouseMode {
OnTypingAndMovement,
}
/// Determines how snippets are sorted relative to other completion items.
///
/// Default: inline
#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum SnippetSortOrder {
/// Place snippets at the top of the completion list
Top,
/// Sort snippets normally using the default comparison logic
#[default]
Inline,
/// Place snippets at the bottom of the completion list
Bottom,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct EditorSettingsContent {
/// Whether the cursor blinks in the editor.
@@ -254,6 +270,10 @@ pub struct EditorSettingsContent {
///
/// Default: on_typing_and_movement
pub hide_mouse: Option<HideMouseMode>,
/// Determines how snippets are sorted relative to other completion items.
///
/// Default: inline
pub snippet_sort_order: Option<SnippetSortOrder>,
/// How to highlight the current line in the editor.
///
/// Default: all

View File

@@ -10419,6 +10419,7 @@ async fn test_completion_in_multibuffer_with_replace_range(cx: &mut TestAppConte
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: false,
},
multi_buffer.clone(),
Some(project.clone()),

View File

@@ -2183,6 +2183,10 @@ impl EditorElement {
window: &mut Window,
cx: &mut App,
) -> Vec<Option<(AnyElement, gpui::Point<Pixels>)>> {
if self.editor.read(cx).disable_expand_excerpt_buttons {
return vec![];
}
let editor_font_size = self.style.text.font_size.to_pixels(window.rem_size()) * 1.2;
let scroll_top = scroll_position.y * line_height;
@@ -5512,7 +5516,9 @@ impl EditorElement {
}
fn paint_mouse_listeners(&mut self, layout: &EditorLayout, window: &mut Window, cx: &mut App) {
self.paint_scroll_wheel_listener(layout, window, cx);
if !self.editor.read(cx).disable_scrolling {
self.paint_scroll_wheel_listener(layout, window, cx);
}
window.on_mouse_event({
let position_map = layout.position_map.clone();
@@ -6563,10 +6569,21 @@ impl Element for EditorElement {
},
)
}
EditorMode::Full { .. } => {
EditorMode::Full {
sized_by_content, ..
} => {
let mut style = Style::default();
style.size.width = relative(1.).into();
style.size.height = relative(1.).into();
if sized_by_content {
let snapshot = editor.snapshot(window, cx);
let line_height =
self.style.text.line_height_in_pixels(window.rem_size());
let scroll_height =
(snapshot.max_point().row().next_row().0 as f32) * line_height;
style.size.height = scroll_height.into();
} else {
style.size.height = relative(1.).into();
}
window.request_layout(style, None, cx)
}
};

View File

@@ -11,6 +11,7 @@ assistant_tool.workspace = true
assistant_tools.workspace = true
async-trait.workspace = true
async-watch.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true

View File

@@ -1,6 +1,7 @@
use std::{
error::Error,
fmt::{self, Debug},
path::Path,
sync::{Arc, Mutex},
time::Duration,
};
@@ -9,9 +10,11 @@ use crate::{
ToolMetrics,
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
};
use agent::ThreadEvent;
use agent::{ContextLoadResult, ThreadEvent};
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
use collections::HashMap;
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
use gpui::{AppContext, AsyncApp, Entity};
use language_model::{LanguageModel, Role, StopReason};
@@ -112,7 +115,12 @@ impl ExampleContext {
pub fn push_user_message(&mut self, text: impl ToString) {
self.app
.update_entity(&self.agent_thread, |thread, cx| {
thread.insert_user_message(text.to_string(), vec![], None, cx);
thread.insert_user_message(
text.to_string(),
ContextLoadResult::default(),
None,
cx,
);
})
.unwrap();
}
@@ -234,9 +242,9 @@ impl ExampleContext {
let mut tool_metrics = tool_metrics.lock().unwrap();
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
let message = if tool_result.is_error {
format!("TOOL FAILED: {}", tool_use.name)
format!("✖︎ {}", tool_use.name)
} else {
format!("TOOL FINISHED: {}", tool_use.name)
format!("✔︎ {}", tool_use.name)
};
println!("{log_prefix}{message}");
tool_metrics
@@ -250,6 +258,9 @@ impl ExampleContext {
}
});
}
ThreadEvent::InvalidToolInput { .. } => {
println!("{log_prefix} invalid tool input");
}
ThreadEvent::ToolConfirmationNeeded => {
panic!(
"{}Bug: Tool confirmation should not be required in eval",
@@ -278,7 +289,7 @@ impl ExampleContext {
let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
thread.set_remaining_turns(iterations);
thread.send_to_model(model, cx);
thread.send_to_model(model, None, cx);
thread.messages().len()
})?;
@@ -320,6 +331,36 @@ impl ExampleContext {
Ok(response)
}
pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
self.app
.read_entity(&self.agent_thread, |thread, cx| {
let action_log = thread.action_log().read(cx);
HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
|(buffer, diff)| {
let snapshot = buffer.read(cx).snapshot();
let file = snapshot.file().unwrap();
let diff = diff.read(cx);
let base_text = diff.base_text().text();
let hunks = diff
.hunks(&snapshot, cx)
.map(|hunk| FileEditHunk {
base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
text: snapshot
.text_for_range(hunk.range.clone())
.collect::<String>(),
status: hunk.status(),
})
.collect();
(file.path().clone(), FileEdits { hunks })
},
))
})
.unwrap()
}
}
#[derive(Debug)]
@@ -344,6 +385,10 @@ impl Response {
});
cx.assert_some(result, format!("called `{}`", tool_name))
}
pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
self.messages.iter().flat_map(|msg| &msg.tool_use)
}
}
#[derive(Debug)]
@@ -355,17 +400,37 @@ pub struct Message {
#[derive(Debug)]
pub struct ToolUse {
name: String,
pub name: String,
value: serde_json::Value,
}
impl ToolUse {
pub fn expect_input<Input>(&self, cx: &mut ExampleContext) -> Result<Input>
pub fn parse_input<Input>(&self) -> Result<Input>
where
Input: for<'de> serde::Deserialize<'de>,
{
let result =
serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err));
cx.log_assertion(result, format!("valid `{}` input", &self.name))
serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
}
}
#[derive(Debug)]
pub struct FileEdits {
hunks: Vec<FileEditHunk>,
}
#[derive(Debug)]
struct FileEditHunk {
base_text: String,
text: String,
status: DiffHunkStatus,
}
impl FileEdits {
pub fn has_added_line(&self, line: &str) -> bool {
self.hunks.iter().any(|hunk| {
hunk.status == DiffHunkStatus::added_none()
&& hunk.base_text.is_empty()
&& hunk.text.contains(line)
})
}
}

View File

@@ -0,0 +1,147 @@
use std::{collections::HashSet, path::Path};
use anyhow::Result;
use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
use async_trait::async_trait;
use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
pub struct AddArgToTraitMethod;
#[async_trait(?Send)]
impl Example for AddArgToTraitMethod {
fn meta(&self) -> ExampleMetadata {
ExampleMetadata {
name: "add_arg_to_trait_method".to_string(),
url: "https://github.com/zed-industries/zed.git".to_string(),
revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
language_server: Some(LanguageServer {
file_extension: "rs".to_string(),
allow_preexisting_diagnostics: false,
}),
max_assertions: None,
}
}
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
const FILENAME: &str = "assistant_tool.rs";
cx.push_user_message(format!(
r#"
Add a `window: Option<gpui::AnyWindowHandle>` argument to the `Tool::run` trait method in {FILENAME},
and update all the implementations of the trait and call sites accordingly.
"#
));
let response = cx.run_to_end().await?;
// Reads files before it edits them
let mut read_files = HashSet::new();
for tool_use in response.tool_uses() {
match tool_use.name.as_str() {
"read_file" => {
if let Ok(input) = tool_use.parse_input::<ReadFileToolInput>() {
read_files.insert(input.path);
}
}
"create_file" => {
if let Ok(input) = tool_use.parse_input::<CreateFileToolInput>() {
read_files.insert(input.path);
}
}
"edit_file" => {
if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
cx.assert(
read_files.contains(input.path.to_str().unwrap()),
format!(
"Read before edit: {}",
&input.path.file_stem().unwrap().to_str().unwrap()
),
)
.ok();
}
}
_ => {}
}
}
// Adds ignored argument to all but `batch_tool`
let add_ignored_window_paths = &[
"code_action_tool",
"code_symbols_tool",
"contents_tool",
"copy_path_tool",
"create_directory_tool",
"create_file_tool",
"delete_path_tool",
"diagnostics_tool",
"edit_file_tool",
"fetch_tool",
"grep_tool",
"list_directory_tool",
"move_path_tool",
"now_tool",
"open_tool",
"path_search_tool",
"read_file_tool",
"rename_tool",
"symbol_info_tool",
"terminal_tool",
"thinking_tool",
"web_search_tool",
];
let edits = cx.edits();
for tool_name in add_ignored_window_paths {
let path_str = format!("crates/assistant_tools/src/{}.rs", tool_name);
let edits = edits.get(Path::new(&path_str));
let ignored = edits.map_or(false, |edits| {
edits.has_added_line(" _window: Option<gpui::AnyWindowHandle>,\n")
});
let uningored = edits.map_or(false, |edits| {
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
});
cx.assert(ignored || uningored, format!("Argument: {}", tool_name))
.ok();
cx.assert(ignored, format!("`_` prefix: {}", tool_name))
.ok();
}
// Adds unignored argument to `batch_tool`
let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs"));
cx.assert(
batch_tool_edits.map_or(false, |edits| {
edits.has_added_line(" window: Option<gpui::AnyWindowHandle>,\n")
}),
"Argument: batch_tool",
)
.ok();
Ok(())
}
fn diff_assertions(&self) -> Vec<JudgeAssertion> {
vec![
JudgeAssertion {
id: "batch tool passes window to each".to_string(),
description:
"batch_tool is modified to pass a clone of the window to each tool it calls."
.to_string(),
},
JudgeAssertion {
id: "tool tests updated".to_string(),
description:
"tool tests are updated to pass the new `window` argument (`None` is ok)."
.to_string(),
},
]
}
}

View File

@@ -1,5 +1,5 @@
use anyhow::Result;
use assistant_tools::PathSearchToolInput;
use assistant_tools::FindPathToolInput;
use async_trait::async_trait;
use regex::Regex;
@@ -15,7 +15,7 @@ impl Example for FileSearchExample {
url: "https://github.com/zed-industries/zed.git".to_string(),
revision: "03ecb88fe30794873f191ddb728f597935b3101c".to_string(),
language_server: None,
max_assertions: Some(4),
max_assertions: Some(3),
}
}
@@ -32,21 +32,18 @@ impl Example for FileSearchExample {
));
let response = cx.run_turn().await?;
let tool_use = response.expect_tool("path_search", cx)?;
let input = tool_use.expect_input::<PathSearchToolInput>(cx)?;
let tool_use = response.expect_tool("find_path", cx)?;
let input = tool_use.parse_input::<FindPathToolInput>()?;
let glob = input.glob;
cx.assert(
glob.ends_with(FILENAME),
format!("glob ends with `{FILENAME}`"),
)?;
cx.assert(glob.ends_with(FILENAME), "glob ends with file name")?;
let without_filename = glob.replace(FILENAME, "");
let matches = Regex::new("(\\*\\*|zed)/(\\*\\*?/)?")
.unwrap()
.is_match(&without_filename);
cx.assert(matches, "glob starts with either `**` or `zed`")?;
cx.assert(matches, "glob starts with `**` or project")?;
Ok(())
}

View File

@@ -11,10 +11,14 @@ use util::serde::default_true;
use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
mod add_arg_to_trait_method;
mod file_search;
pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
let mut threads: Vec<Rc<dyn Example>> = vec![Rc::new(file_search::FileSearchExample)];
let mut threads: Vec<Rc<dyn Example>> = vec![
Rc::new(file_search::FileSearchExample),
Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
];
for example_path in list_declarative_examples(examples_dir).unwrap() {
threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));

View File

@@ -218,8 +218,14 @@ impl ExampleInstance {
});
let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let prompt_store = None;
let thread_store = ThreadStore::load(
project.clone(),
tools,
prompt_store,
app_state.prompt_builder.clone(),
cx,
);
let meta = self.thread.meta();
let this = self.clone();

View File

@@ -52,7 +52,10 @@ pub async fn stream_generate_content(
if let Some(line) = line.strip_prefix("data: ") {
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
Err(error) => Some(Err(anyhow!(format!(
"Error parsing JSON: {:?}\n{:?}",
error, line
)))),
}
} else {
None
@@ -156,6 +159,7 @@ pub struct GenerateContentCandidate {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
#[serde(default)]
pub parts: Vec<Part>,
pub role: Role,
}

View File

@@ -15,7 +15,7 @@ use std::{
num::NonZeroU64,
sync::{
Arc, Weak,
atomic::{AtomicUsize, Ordering::SeqCst},
atomic::{AtomicU64, AtomicUsize, Ordering::SeqCst},
},
thread::panicking,
};
@@ -572,6 +572,30 @@ impl AnyWeakEntity {
)
}
}
/// Creates a weak entity that can never be upgraded.
pub fn new_invalid() -> Self {
/// To hold the invariant that all ids are unique, and considering that slotmap
/// increases their IDs from `0`, we can decrease ours from `u64::MAX` so these
/// two will never conflict (u64 is way too large).
static UNIQUE_NON_CONFLICTING_ID_GENERATOR: AtomicU64 = AtomicU64::new(u64::MAX);
let entity_id = UNIQUE_NON_CONFLICTING_ID_GENERATOR.fetch_sub(1, SeqCst);
Self {
// Safety:
// Docs say this is safe but can be unspecified if slotmap changes the representation
// after `1.0.7`, that said, providing a valid entity_id here is not necessary as long
// as we guarantee that that `entity_id` is never used if `entity_ref_counts` equals
// to `Weak::new()` (that is, it's unable to upgrade), that is the invariant that
// actually needs to be hold true.
//
// And there is no sane reason to read an entity slot if `entity_ref_counts` can't be
// read in the first place, so we're good!
entity_id: entity_id.into(),
entity_type: TypeId::of::<()>(),
entity_ref_counts: Weak::new(),
}
}
}
impl std::fmt::Debug for AnyWeakEntity {
@@ -707,6 +731,14 @@ impl<T: 'static> WeakEntity<T> {
.map(|this| cx.read_entity(&this, read)),
)
}
/// Create a new weak entity that can never be upgraded.
pub fn new_invalid() -> Self {
Self {
any_entity: AnyWeakEntity::new_invalid(),
entity_type: PhantomData,
}
}
}
impl<T> Hash for WeakEntity<T> {

View File

@@ -1,7 +1,7 @@
use crate::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
};
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
@@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
&self,
request: LanguageModelRequest,
_: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs.lock().push((request, tx));
async move {

View File

@@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
@@ -186,13 +199,14 @@ where
pub struct LanguageModelToolUse {
pub id: LanguageModelToolUseId,
pub name: Arc<str>,
pub raw_input: String,
pub input: serde_json::Value,
pub is_input_complete: bool,
}
pub struct LanguageModelTextStream {
pub message_id: Option<String>,
pub stream: BoxStream<'static, Result<String>>,
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
// Has complete token usage after the stream has finished
pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
@@ -245,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
>;
fn stream_completion_with_usage(
&self,
@@ -254,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {

View File

@@ -12,10 +12,10 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -27,7 +27,7 @@ use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, maybe};
use util::ResultExt;
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: &str = "Anthropic";
@@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_anthropic(
request,
self.model.request_id().into(),
@@ -626,7 +631,7 @@ pub fn into_anthropic(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -714,61 +719,58 @@ pub fn map_to_language_model_completion_events(
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
return Some((
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
// Try to convert invalid (incomplete) JSON into
// valid JSON that serde can accept, e.g. by closing
// unclosed delimiters. This way, we can update the
// UI with whatever has been streamed back so far.
if let Ok(input) = serde_json::Value::from_str(
&partial_json_fixer::fix_json(&tool_use.input_json),
) {
return Some((
vec![Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.clone().into(),
name: tool_use.name.clone().into(),
is_input_complete: false,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
// Convert invalid (incomplete) JSON into
// JSON that serde will accept, e.g. by closing
// unclosed delimiters. This way, we can update
// the UI with whatever has been streamed back so far.
&partial_json_fixer::fix_json(
&tool_use.input_json,
),
)
.map_err(|err| anyhow!(err))?
},
raw_input: tool_use.input_json.clone(),
input,
},
))
})],
state,
));
))],
state,
));
}
}
}
},
Event::ContentBlockStop { index } => {
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
return Some((
vec![maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
)
} else {
serde_json::Value::from_str(
&tool_use.input_json,
)
.map_err(|err| anyhow!(err))?
},
},
))
})],
state,
));
let input_json = tool_use.input_json.trim();
let input_value = if input_json.is_empty() {
Ok(serde_json::Value::Object(serde_json::Map::default()))
} else {
serde_json::Value::from_str(input_json)
};
let event_result = match input_value {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
input,
raw_input: tool_use.input_json.clone(),
},
)),
Err(json_parse_err) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_use.id.into(),
tool_name: tool_use.name.into(),
raw_input: input_json.into(),
json_parse_error: json_parse_err.to_string(),
})
}
};
return Some((vec![event_result], state));
}
}
Event::MessageStart { message } => {
@@ -815,14 +817,19 @@ pub fn map_to_language_model_completion_events(
}
Event::Error { error } => {
return Some((
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
vec![Err(LanguageModelCompletionError::Other(anyhow!(
AnthropicError::ApiError(error)
)))],
state,
));
}
_ => {}
},
Err(err) => {
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View File

@@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
RateLimiter, Role, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings
let region = state
@@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
struct RawToolUse {
id: String,
name: String,
@@ -894,6 +900,7 @@ pub fn map_to_language_model_completion_events(
id: tool_use.id.into(),
name: tool_use.name.into(),
is_input_complete: true,
raw_input: tool_use.input_json.clone(),
input: if tool_use.input_json.is_empty() {
Value::Null
} else {
@@ -970,7 +977,7 @@ pub fn map_to_language_model_completion_events(
_ => {}
},
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
}
}
None

View File

@@ -10,11 +10,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
ZED_CLOUD_PROVIDER_ID,
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -531,7 +531,7 @@ impl CloudLanguageModel {
{
request_builder.uri(completions_url)
} else {
request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
};
let request = request_builder
.header("Content-Type", "application/json")
@@ -711,7 +711,12 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
self.stream_completion_with_usage(request, cx)
.map(|result| result.map(|(stream, _)| stream))
.boxed()
@@ -724,7 +729,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<
'static,
Result<(
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
Option<RequestUsage>,
)>,
> {

View File

@@ -17,16 +17,16 @@ use gpui::{
Transformation, percentage, svg,
};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
@@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
#[derive(Default)]
struct RawToolCall {
id: String,
@@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
Ok(event) => {
let Some(choice) = event.choices.first() else {
return Some((
vec![Err(anyhow!("Response contained no choices"))],
vec![Err(anyhow!("Response contained no choices").into())],
state,
));
};
@@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
let Some(delta) = delta else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
vec![Err(anyhow!("Response contained no delta").into())],
state,
));
};
@@ -361,19 +366,26 @@ pub fn map_to_language_model_completion_events(
}
Some("tool_calls") => {
events.extend(state.tool_calls_by_index.drain().map(
|(_, tool_call)| {
maybe!({
Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
},
))
})
|(_, tool_call)| match serde_json::Value::from_str(
&tool_call.arguments,
) {
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
id: tool_call.id.clone().into(),
name: tool_call.name.as_str().into(),
is_input_complete: true,
input,
raw_input: tool_call.arguments.clone(),
},
)),
Err(error) => {
Err(LanguageModelCompletionError::BadInputJson {
id: tool_call.id.into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
})
}
},
));
@@ -392,7 +404,7 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => return Some((vec![Err(err)], state)),
Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
}
}
@@ -452,9 +464,11 @@ impl CopilotChatLanguageModel {
}
}
messages.push(ChatMessage::User {
content: text_content,
});
if !text_content.is_empty() {
messages.push(ChatMessage::User {
content: text_content,
});
}
}
Role::Assistant => {
let mut tool_calls = Vec::new();

View File

@@ -9,9 +9,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_deepseek(
request,
self.model.id().to_string(),
@@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View File

@@ -11,8 +11,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
>,
> {
let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(map_to_language_model_completion_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -434,13 +442,20 @@ pub fn into_google(
contents: request
.messages
.into_iter()
.map(|message| google_ai::Content {
parts: map_content(message.content),
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
.filter_map(|message| {
let parts = map_content(message.content);
if parts.is_empty() {
None
} else {
Some(google_ai::Content {
parts,
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
}
})
.collect(),
generation_config: Some(google_ai::GenerationConfig {
@@ -471,7 +486,7 @@ pub fn into_google(
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
use std::sync::atomic::{AtomicU64, Ordering};
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
@@ -492,7 +507,7 @@ pub fn map_to_language_model_completion_events(
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
let mut events: Vec<_> = Vec::new();
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut state.usage, &usage_metadata);
@@ -537,6 +552,10 @@ pub fn map_to_language_model_completion_events(
id,
name,
is_input_complete: true,
raw_input: function_call_part
.function_call
.args
.to_string(),
input: function_call_part.function_call.args,
},
)));
@@ -555,7 +574,10 @@ pub fn map_to_language_model_completion_events(
return Some((events, state));
}
Err(err) => {
return Some((vec![Err(anyhow!(err))], state));
return Some((
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
state,
));
}
}
}

View File

@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_lmstudio_request(request);
let http_client = self.http_client.clone();
@@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

View File

@@ -8,9 +8,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use futures::stream::BoxStream;
@@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = into_mistral(
request,
self.model.id().to_string(),
@@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
let stream = stream.await?;
Ok(stream
.map(|result| {
result.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
result
.and_then(|response| {
response
.choices
.first()
.ok_or_else(|| anyhow!("Empty response"))
.map(|choice| {
choice
.delta
.content
.clone()
.unwrap_or_default()
.map(LanguageModelCompletionEvent::Text)
})
})
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}

View File

@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
@@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
@@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()

Some files were not shown because too many files have changed in this diff Show More