Compare commits
140 Commits
inspector-
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f7f22e50b | ||
|
|
938a9ad405 | ||
|
|
190bdddfc1 | ||
|
|
7ad3816ea9 | ||
|
|
530d2a0ccd | ||
|
|
0cd0ee6fb9 | ||
|
|
936972d9b0 | ||
|
|
e9533423db | ||
|
|
ba480295c1 | ||
|
|
9106f4495b | ||
|
|
1feb1296fe | ||
|
|
582a247922 | ||
|
|
c2881a4537 | ||
|
|
187f851613 | ||
|
|
a77db45865 | ||
|
|
6bb6be826d | ||
|
|
7d9a55d101 | ||
|
|
57d8397f53 | ||
|
|
17ecf94f6f | ||
|
|
d492939bed | ||
|
|
720dfee803 | ||
|
|
a98c648201 | ||
|
|
c147daae4a | ||
|
|
d3911e34de | ||
|
|
87f85f1863 | ||
|
|
1a4dab97db | ||
|
|
cd365b0cf5 | ||
|
|
58604fba86 | ||
|
|
b0609272c0 | ||
|
|
a17807d8b1 | ||
|
|
f81e65ae7c | ||
|
|
952fe34aaa | ||
|
|
f527df6fa1 | ||
|
|
b54bbebc03 | ||
|
|
8bb7a1f9e7 | ||
|
|
e70d8d4dfd | ||
|
|
ea5ce2a1a4 | ||
|
|
fd8eeb537d | ||
|
|
92f21ee39d | ||
|
|
fcfeea4825 | ||
|
|
c0f8e0f605 | ||
|
|
9d10489607 | ||
|
|
8836c6fb42 | ||
|
|
f125353b6f | ||
|
|
fef2681cfa | ||
|
|
8b5835de17 | ||
|
|
2124b7ea99 | ||
|
|
74442b68ea | ||
|
|
ba3d82629e | ||
|
|
ecc600a68f | ||
|
|
218496744c | ||
|
|
d095bab8ad | ||
|
|
f8c3fe7871 | ||
|
|
aa161078fb | ||
|
|
f11c749353 | ||
|
|
40b5a1b028 | ||
|
|
2d43818c04 | ||
|
|
636c6e7f2d | ||
|
|
45d3f5168a | ||
|
|
8366cd0b52 | ||
|
|
f6774ae60d | ||
|
|
92e810bfec | ||
|
|
724c935196 | ||
|
|
ef54b58346 | ||
|
|
01bdd170ec | ||
|
|
4b9f4feff1 | ||
|
|
19fb1e1b0d | ||
|
|
f2cb6d69d5 | ||
|
|
822b6f837d | ||
|
|
09db31288a | ||
|
|
a320d324f1 | ||
|
|
266c41ed9a | ||
|
|
4f4bbf264f | ||
|
|
990ca48744 | ||
|
|
f69aeb6311 | ||
|
|
d5f3fbdc88 | ||
|
|
76a78b550b | ||
|
|
e515b2c714 | ||
|
|
55ea481707 | ||
|
|
5e31d86f1f | ||
|
|
4a8f114528 | ||
|
|
ce1a674eba | ||
|
|
0d3fe474db | ||
|
|
6a009b447a | ||
|
|
75ab8ff9a1 | ||
|
|
3705986fac | ||
|
|
aefb3aa2fa | ||
|
|
8e7c145f20 | ||
|
|
a2a502f026 | ||
|
|
c231c95521 | ||
|
|
fcc6a86c90 | ||
|
|
338a6a3b7e | ||
|
|
a0eaede13d | ||
|
|
abf2b9d7d3 | ||
|
|
a50fbc9b5c | ||
|
|
9bbc2e0fb2 | ||
|
|
6caf34ab7e | ||
|
|
8607c7d3ee | ||
|
|
e26bb05567 | ||
|
|
b3b89c8443 | ||
|
|
962b024248 | ||
|
|
833653a3ea | ||
|
|
886f0b7214 | ||
|
|
207fb04969 | ||
|
|
36d02de784 | ||
|
|
36da97935a | ||
|
|
19b547565d | ||
|
|
109f1d43fc | ||
|
|
a5852d4537 | ||
|
|
10ded0ab75 | ||
|
|
b0b620af56 | ||
|
|
eca6d5a04e | ||
|
|
3357736aea | ||
|
|
458ffaa134 | ||
|
|
b14356d1d3 | ||
|
|
19ef56ba7c | ||
|
|
dfbd132d9f | ||
|
|
2e8ee9b64f | ||
|
|
c15382c4d8 | ||
|
|
70c51b513b | ||
|
|
38afae86a9 | ||
|
|
9249919b7a | ||
|
|
9fe4a14f73 | ||
|
|
7cc3c03b08 | ||
|
|
4f2f9ff762 | ||
|
|
7aa0fa1543 | ||
|
|
3b31860d52 | ||
|
|
733cd6b68c | ||
|
|
e8fe0eb2e6 | ||
|
|
0f3ac38332 | ||
|
|
32e9757a85 | ||
|
|
be76942a69 | ||
|
|
942d4eb126 | ||
|
|
9d35f0389d | ||
|
|
d13cd007a2 | ||
|
|
f8ac6eef75 | ||
|
|
6d2bdc3bac | ||
|
|
9a3434efb4 | ||
|
|
333de5d673 | ||
|
|
97ab0980d1 |
71
.github/workflows/eval.yml
vendored
Normal file
71
.github/workflows/eval.yml
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
name: Run Agent Eval
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
- "**"
|
||||
types: [opened, synchronize, reopened, labeled]
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
# Allow only one workflow per any non-`main` branch.
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
CARGO_INCREMENTAL: 0
|
||||
RUST_BACKTRACE: 1
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
ZED_EVAL_TELEMETRY: 1
|
||||
|
||||
jobs:
|
||||
run_eval:
|
||||
timeout-minutes: 60
|
||||
name: Run Agent Eval
|
||||
if: >
|
||||
github.repository_owner == 'zed-industries' &&
|
||||
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
|
||||
runs-on:
|
||||
- buildjet-16vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
cache-provider: "buildjet"
|
||||
|
||||
- name: Install Linux dependencies
|
||||
run: ./script/linux
|
||||
|
||||
- name: Configure CI
|
||||
run: |
|
||||
mkdir -p ./../.cargo
|
||||
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
|
||||
|
||||
- name: Compile eval
|
||||
run: cargo build --package=eval
|
||||
|
||||
- name: Run eval
|
||||
run: cargo run --package=eval -- --repetitions=3 --concurrency=1
|
||||
|
||||
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
|
||||
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code
|
||||
# to clean up the config file, I’ve included the cleanup code here as a precaution.
|
||||
# While it’s not strictly necessary at this moment, I believe it’s better to err on the side of caution.
|
||||
- name: Clean CI config file
|
||||
if: always()
|
||||
run: rm -rf ./../.cargo
|
||||
28
.github/workflows/run_agent_eval_daily.yml
vendored
28
.github/workflows/run_agent_eval_daily.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Run Eval Daily
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
CARGO_INCREMENTAL: 0
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
run_eval:
|
||||
name: Run Eval
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Run cargo eval
|
||||
run: cargo run -p eval
|
||||
11
.rules
11
.rules
@@ -5,6 +5,7 @@
|
||||
* Prefer implementing functionality in existing files unless it is a new logical component. Avoid creating many small files.
|
||||
* Avoid using functions that panic like `unwrap()`, instead use mechanisms like `?` to propagate errors.
|
||||
* Be careful with operations like indexing which may panic if the indexes are out of bounds.
|
||||
* Never create files with `mod.rs` paths - prefer `src/some_module.rs` instead of `src/some_module/mod.rs`.
|
||||
|
||||
# GPUI
|
||||
|
||||
@@ -108,3 +109,13 @@ When a view's state has changed in a way that may affect its rendering, it shoul
|
||||
While updating an entity (`cx: Context<T>`), it can emit an event using `cx.emit(event)`. Entities register which events they can emit by declaring `impl EventEmittor<EventType> for EntityType {}`.
|
||||
|
||||
Other entities can then register a callback to handle these events by doing `cx.subscribe(other_entity, |this, other_entity, event, cx| ...)`. This will return a `Subscription` which deregisters the callback when dropped. Typically `cx.subscribe` happens when creating a new entity and the subscriptions are stored in a `_subscriptions: Vec<Subscription>` field.
|
||||
|
||||
## Recent API changes
|
||||
|
||||
GPUI has had some changes to its APIs. Always write code using the new APIs:
|
||||
|
||||
* `spawn` methods now take async closures (`AsyncFn`), and so should be called like `cx.spawn(async move |cx| ...)`.
|
||||
* Use `Entity<T>`. This replaces `Model<T>` and `View<T>` which longer exists and should NEVER be used.
|
||||
* Use `App` references. This replaces `AppContext` which no longer exists and should NEVER be used.
|
||||
* Use `Context<T>` references. This replaces `ModelContext<T>` which no longer exists and should NEVER be used.
|
||||
* `Window` is now passed around explicitly. The new interface adds a `Window` reference parameter to some methods, and adds some new "*_in" methods for plumbing `Window`. The old types `WindowContext` and `ViewContext<T>` should NEVER be used.
|
||||
|
||||
531
Cargo.lock
generated
531
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -109,7 +109,7 @@ members = [
|
||||
"crates/project",
|
||||
"crates/project_panel",
|
||||
"crates/project_symbols",
|
||||
"crates/prompt_library",
|
||||
"crates/rules_library",
|
||||
"crates/prompt_store",
|
||||
"crates/proto",
|
||||
"crates/recent_projects",
|
||||
@@ -296,6 +296,7 @@ livekit_api = { path = "crates/livekit_api" }
|
||||
livekit_client = { path = "crates/livekit_client" }
|
||||
lmstudio = { path = "crates/lmstudio" }
|
||||
lsp = { path = "crates/lsp" }
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
|
||||
markdown = { path = "crates/markdown" }
|
||||
markdown_preview = { path = "crates/markdown_preview" }
|
||||
media = { path = "crates/media" }
|
||||
@@ -318,7 +319,7 @@ prettier = { path = "crates/prettier" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
prompt_library = { path = "crates/prompt_library" }
|
||||
rules_library = { path = "crates/rules_library" }
|
||||
prompt_store = { path = "crates/prompt_store" }
|
||||
proto = { path = "crates/proto" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
@@ -480,6 +481,7 @@ num-format = "0.4.4"
|
||||
ordered-float = "2.1.1"
|
||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||
parking_lot = "0.12.1"
|
||||
partial-json-fixer = "0.5.3"
|
||||
pathdiff = "0.2"
|
||||
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
@@ -498,6 +500,7 @@ prost-types = "0.9"
|
||||
pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
quote = "1.0.9"
|
||||
rand = "0.8.5"
|
||||
ref-cast = "1.0.24"
|
||||
rayon = "1.8"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
@@ -604,7 +607,7 @@ wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.6.1"
|
||||
zed_llm_client = "0.7.1"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
|
||||
1
assets/icons/image.svg
Normal file
1
assets/icons/image.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-image-icon lucide-image"><rect width="18" height="18" x="3" y="3" rx="2" ry="2"/><circle cx="9" cy="9" r="2"/><path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/></svg>
|
||||
|
After Width: | Height: | Size: 372 B |
@@ -212,7 +212,7 @@
|
||||
"ctrl-shift-g": "search::SelectPreviousMatch",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-k h": "assistant::DeployHistory",
|
||||
"ctrl-k l": "assistant::OpenPromptLibrary",
|
||||
"ctrl-k l": "assistant::OpenRulesLibrary",
|
||||
"new": "assistant::NewChat",
|
||||
"ctrl-t": "assistant::NewChat",
|
||||
"ctrl-n": "assistant::NewChat"
|
||||
@@ -241,7 +241,7 @@
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "assistant::OpenPromptLibrary",
|
||||
"ctrl-alt-p": "assistant::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -308,9 +308,9 @@
|
||||
{
|
||||
"context": "PromptLibrary",
|
||||
"bindings": {
|
||||
"new": "prompt_library::NewPrompt",
|
||||
"ctrl-n": "prompt_library::NewPrompt",
|
||||
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
|
||||
"new": "rules_library::NewRule",
|
||||
"ctrl-n": "rules_library::NewRule",
|
||||
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -675,7 +675,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
"shift-enter": "editor::ExpandExcerpts",
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
"context": "PromptLibrary",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-n": "prompt_library::NewPrompt",
|
||||
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
|
||||
"cmd-n": "rules_library::NewRule",
|
||||
"cmd-shift-s": "rules_library::ToggleDefaultRule",
|
||||
"cmd-w": "workspace::CloseWindow"
|
||||
}
|
||||
},
|
||||
@@ -257,7 +257,7 @@
|
||||
"cmd-shift-g": "search::SelectPreviousMatch",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-k h": "assistant::DeployHistory",
|
||||
"cmd-k l": "assistant::OpenPromptLibrary",
|
||||
"cmd-k l": "assistant::OpenRulesLibrary",
|
||||
"cmd-t": "assistant::NewChat",
|
||||
"cmd-n": "assistant::NewChat"
|
||||
}
|
||||
@@ -286,7 +286,7 @@
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "assistant::OpenPromptLibrary",
|
||||
"cmd-alt-p": "assistant::OpenRulesLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -738,7 +738,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
@@ -1028,10 +1028,10 @@
|
||||
// Using `ctrl-shift-space` in Zed requires disabling the macOS global shortcut.
|
||||
// System Preferences->Keyboard->Keyboard Shortcuts->Input Sources->Select the previous input source (uncheck)
|
||||
"ctrl-shift-space": "terminal::ToggleViMode",
|
||||
"ctrl-k up": "pane::SplitUp",
|
||||
"ctrl-k down": "pane::SplitDown",
|
||||
"ctrl-k left": "pane::SplitLeft",
|
||||
"ctrl-k right": "pane::SplitRight"
|
||||
"ctrl-alt-up": "pane::SplitUp",
|
||||
"ctrl-alt-down": "pane::SplitDown",
|
||||
"ctrl-alt-left": "pane::SplitLeft",
|
||||
"ctrl-alt-right": "pane::SplitRight"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -830,5 +830,13 @@
|
||||
// and Windows.
|
||||
"alt-l": "editor::AcceptEditPrediction"
|
||||
}
|
||||
},
|
||||
{
|
||||
// Fixes https://github.com/zed-industries/zed/issues/29095 by ensuring that
|
||||
// the last binding for editor::ToggleComments is not ctrl-c.
|
||||
"context": "hack_to_fix_ctrl-c",
|
||||
"bindings": {
|
||||
"g c": "editor::ToggleComments"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -31,6 +31,9 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
{{! TODO: Only mention tools if they are enabled }}
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
@@ -73,9 +76,9 @@ There are project rules that apply to these root directories:
|
||||
{{/each}}
|
||||
{{/if}}
|
||||
|
||||
{{#if has_default_user_rules}}
|
||||
{{#if has_user_rules}}
|
||||
The user has specified the following rules that should be applied:
|
||||
{{#each default_user_rules}}
|
||||
{{#each user_rules}}
|
||||
|
||||
{{#if title}}
|
||||
Rules title: {{title}}
|
||||
|
||||
@@ -646,7 +646,7 @@
|
||||
"fetch": true,
|
||||
"list_directory": false,
|
||||
"now": true,
|
||||
"path_search": true,
|
||||
"find_path": true,
|
||||
"read_file": true,
|
||||
"grep": true,
|
||||
"thinking": true,
|
||||
@@ -670,7 +670,7 @@
|
||||
"list_directory": true,
|
||||
"move_path": false,
|
||||
"now": false,
|
||||
"path_search": true,
|
||||
"find_path": true,
|
||||
"read_file": true,
|
||||
"grep": true,
|
||||
"rename": false,
|
||||
@@ -1489,7 +1489,12 @@
|
||||
"use_multiline_find": false,
|
||||
"use_smartcase_find": false,
|
||||
"highlight_on_yank_duration": 200,
|
||||
"custom_digraphs": {}
|
||||
"custom_digraphs": {},
|
||||
// Cursor shape for the each mode.
|
||||
// Specify the mode as the key and the shape as the value.
|
||||
// The mode can be one of the following: "normal", "replace", "insert", "visual".
|
||||
// The shape can be one of the following: "block", "bar", "underline", "hollow".
|
||||
"cursor_shape": {}
|
||||
},
|
||||
// The server to connect to. If the environment variable
|
||||
// ZED_SERVER_URL is set, it will override this setting.
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
allow-private-module-inception = true
|
||||
avoid-breaking-exported-api = false
|
||||
ignore-interior-mutability = [
|
||||
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
|
||||
# and Hash impls do not use fields with interior mutability.
|
||||
"agent::context::AgentContextKey"
|
||||
]
|
||||
|
||||
@@ -28,7 +28,6 @@ async-watch.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -62,9 +61,10 @@ parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
picker.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel};
|
||||
use anyhow::Result;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
@@ -8,8 +8,8 @@ use editor::{
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
|
||||
Subscription, Task, WeakEntity, Window, prelude::*,
|
||||
Action, AnyElement, AnyView, App, Empty, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
|
||||
};
|
||||
use language::{Capability, DiskState, OffsetRangeExt, Point};
|
||||
use multi_buffer::PathKey;
|
||||
@@ -307,6 +307,10 @@ impl AgentDiff {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.thread.read(cx).is_generating() {
|
||||
return;
|
||||
}
|
||||
|
||||
let snapshot = self.multibuffer.read(cx).snapshot(cx);
|
||||
let diff_hunks_in_ranges = self
|
||||
.editor
|
||||
@@ -339,6 +343,10 @@ impl AgentDiff {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.thread.read(cx).is_generating() {
|
||||
return;
|
||||
}
|
||||
|
||||
let snapshot = self.multibuffer.read(cx).snapshot(cx);
|
||||
let diff_hunks_in_ranges = self
|
||||
.editor
|
||||
@@ -650,6 +658,11 @@ fn render_diff_hunk_controls(
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let editor = editor.clone();
|
||||
|
||||
if agent_diff.read(cx).thread.read(cx).is_generating() {
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
h_flex()
|
||||
.h(line_height)
|
||||
.mr_0p5()
|
||||
@@ -857,8 +870,14 @@ impl Render for AgentDiffToolbar {
|
||||
None => return div(),
|
||||
};
|
||||
|
||||
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
|
||||
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
|
||||
if is_generating {
|
||||
return div()
|
||||
.w(rems(6.5625)) // Arbitrary 105px size—so the label doesn't dance around
|
||||
.child(AnimatedLabel::new("Generating"));
|
||||
}
|
||||
|
||||
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
|
||||
if is_empty {
|
||||
return div();
|
||||
}
|
||||
@@ -943,11 +962,13 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prompt_store = None;
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
prompt_store,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
@@ -969,7 +990,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|_, cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit(
|
||||
|
||||
@@ -39,6 +39,7 @@ use thread::ThreadId;
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
|
||||
@@ -16,7 +16,7 @@ use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageMod
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
|
||||
Switch, Tooltip, prelude::*,
|
||||
Switch, SwitchColor, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
@@ -236,6 +236,7 @@ impl AssistantConfiguration {
|
||||
"always-allow-tool-actions-switch",
|
||||
always_allow_tool_actions.into(),
|
||||
)
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let fs = self.fs.clone();
|
||||
move |state, _window, cx| {
|
||||
@@ -332,41 +333,44 @@ impl AssistantConfiguration {
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into()).on_click({
|
||||
let context_server_manager =
|
||||
self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected | ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx)
|
||||
.log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(
|
||||
context_server,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager =
|
||||
self.context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
move |state, _window, cx| match state {
|
||||
ToggleState::Unselected
|
||||
| ToggleState::Indeterminate => {
|
||||
context_server_manager.update(cx, |this, cx| {
|
||||
this.stop_server(context_server.clone(), cx)
|
||||
.log_err();
|
||||
});
|
||||
}
|
||||
ToggleState::Selected => {
|
||||
cx.spawn({
|
||||
let context_server_manager =
|
||||
context_server_manager.clone();
|
||||
let context_server = context_server.clone();
|
||||
async move |cx| {
|
||||
if let Some(start_server_task) =
|
||||
context_server_manager
|
||||
.update(cx, |this, cx| {
|
||||
this.start_server(
|
||||
context_server,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
start_server_task.await.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
}),
|
||||
),
|
||||
)
|
||||
.map(|parent| {
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{
|
||||
AgentProfile, AgentProfileContent, AgentProfileId, AssistantSettings, AssistantSettingsContent,
|
||||
ContextServerPresetContent, VersionedAssistantSettingsContent,
|
||||
ContextServerPresetContent,
|
||||
};
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use fs::Fs;
|
||||
@@ -201,10 +201,10 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
let profile_id = self.profile_id.clone();
|
||||
let default_profile = self.profile.clone();
|
||||
let tool = tool.clone();
|
||||
move |settings, _cx| match settings {
|
||||
AssistantSettingsContent::Versioned(boxed) => {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
move |settings: &mut AssistantSettingsContent, _cx| {
|
||||
settings
|
||||
.v2_setting(|v2_settings| {
|
||||
let profiles = v2_settings.profiles.get_or_insert_default();
|
||||
let profile =
|
||||
profiles
|
||||
.entry(profile_id)
|
||||
@@ -240,9 +240,10 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -12,21 +12,21 @@ use assistant_settings::{AssistantDockPosition, AssistantSettings};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
|
||||
use client::zed_urls;
|
||||
use client::{UserStore, zed_urls};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task,
|
||||
UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
|
||||
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
|
||||
Corner, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels,
|
||||
Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId};
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use time::UtcOffset;
|
||||
use ui::{
|
||||
@@ -36,7 +36,7 @@ use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
use workspace::dock::{DockPosition, Panel, PanelEvent};
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
|
||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||
|
||||
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
|
||||
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
|
||||
@@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
@@ -180,6 +180,7 @@ impl ActiveView {
|
||||
|
||||
pub struct AssistantPanel {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
project: Entity<Project>,
|
||||
fs: Arc<dyn Fs>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
@@ -188,6 +189,7 @@ pub struct AssistantPanel {
|
||||
message_editor: Entity<MessageEditor>,
|
||||
_active_thread_subscriptions: Vec<Subscription>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
configuration: Option<Entity<AssistantConfiguration>>,
|
||||
configuration_subscription: Option<Subscription>,
|
||||
local_timezone: UtcOffset,
|
||||
@@ -204,14 +206,25 @@ impl AssistantPanel {
|
||||
pub fn load(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: AsyncWindowContext,
|
||||
mut cx: AsyncWindowContext,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = match prompt_store {
|
||||
Ok(prompt_store) => prompt_store.await.ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
let tools = cx.new(|_| ToolWorkingSet::default())?;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
ThreadStore::load(
|
||||
project,
|
||||
tools.clone(),
|
||||
prompt_store.clone(),
|
||||
prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
@@ -229,7 +242,16 @@ impl AssistantPanel {
|
||||
.await?;
|
||||
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
|
||||
cx.new(|cx| {
|
||||
Self::new(
|
||||
workspace,
|
||||
thread_store,
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -238,11 +260,13 @@ impl AssistantPanel {
|
||||
workspace: &Workspace,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
let user_store = workspace.app_state().user_store.clone();
|
||||
let project = workspace.project();
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
let workspace = workspace.weak_handle();
|
||||
@@ -260,6 +284,7 @@ impl AssistantPanel {
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
thread.clone(),
|
||||
window,
|
||||
@@ -291,7 +316,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -307,6 +331,7 @@ impl AssistantPanel {
|
||||
Self {
|
||||
active_view,
|
||||
workspace,
|
||||
user_store,
|
||||
project: project.clone(),
|
||||
fs: fs.clone(),
|
||||
language_registry,
|
||||
@@ -319,6 +344,7 @@ impl AssistantPanel {
|
||||
message_editor_subscription,
|
||||
],
|
||||
context_store,
|
||||
prompt_store,
|
||||
configuration: None,
|
||||
configuration_subscription: None,
|
||||
local_timezone: UtcOffset::from_whole_seconds(
|
||||
@@ -352,18 +378,17 @@ impl AssistantPanel {
|
||||
self.local_timezone
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
|
||||
&self.thread_store
|
||||
}
|
||||
|
||||
fn cancel(
|
||||
&mut self,
|
||||
_: &editor::actions::Cancel,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
|
||||
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -413,7 +438,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -432,6 +456,7 @@ impl AssistantPanel {
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -486,13 +511,13 @@ impl AssistantPanel {
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.language_registry.clone(),
|
||||
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
|
||||
Arc::new(|| {
|
||||
@@ -502,7 +527,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -598,7 +625,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
this.thread_store.clone(),
|
||||
this.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
this.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -617,6 +643,7 @@ impl AssistantPanel {
|
||||
this.fs.clone(),
|
||||
this.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
this.prompt_store.clone(),
|
||||
this.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -1120,7 +1147,7 @@ impl AssistantPanel {
|
||||
"New Text Thread",
|
||||
NewTextThread.boxed_clone(),
|
||||
)
|
||||
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
|
||||
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.separator()
|
||||
.header("MCPs")
|
||||
@@ -1546,9 +1573,19 @@ impl AssistantPanel {
|
||||
}
|
||||
|
||||
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
let plan = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => zed_llm_client::Plan::Free,
|
||||
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(zed_llm_client::Plan::Free);
|
||||
let usage = self.thread.read(cx).last_usage()?;
|
||||
|
||||
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage.amount).into_any_element())
|
||||
Some(UsageBanner::new(plan, usage).into_any_element())
|
||||
}
|
||||
|
||||
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
@@ -1603,6 +1640,8 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
@@ -1649,6 +1688,8 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, _, cx| {
|
||||
@@ -1714,6 +1755,8 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(error_message))
|
||||
.child(
|
||||
Button::new("subscribe", call_to_action).on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
@@ -1745,6 +1788,7 @@ impl AssistantPanel {
|
||||
message: SharedString,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let message_with_header = format!("{}\n{}", header, message);
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
@@ -1759,12 +1803,14 @@ impl AssistantPanel {
|
||||
.id("error-message")
|
||||
.max_h_32()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(message)),
|
||||
.child(Label::new(message.clone())),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(message_with_header))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
@@ -1778,6 +1824,15 @@ impl AssistantPanel {
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
|
||||
let message = message.into();
|
||||
IconButton::new("copy", IconName::Copy)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
|
||||
})
|
||||
.tooltip(Tooltip::text("Copy Error Message"))
|
||||
}
|
||||
|
||||
fn key_context(&self) -> KeyContext {
|
||||
let mut key_context = KeyContext::new_with_defaults();
|
||||
key_context.add("AgentPanel");
|
||||
@@ -1805,7 +1860,7 @@ impl Render for AssistantPanel {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
.on_action(cx.listener(Self::deploy_prompt_library))
|
||||
.on_action(cx.listener(Self::deploy_rules_library))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.on_action(cx.listener(Self::go_back))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
@@ -1832,13 +1887,13 @@ impl PromptLibraryInlineAssist {
|
||||
}
|
||||
}
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
_initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
let Some(project) = self
|
||||
@@ -1848,11 +1903,14 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let prompt_store = None;
|
||||
let thread_store = None;
|
||||
assistant.assist(
|
||||
&prompt_editor,
|
||||
self.workspace.clone(),
|
||||
project,
|
||||
None,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1931,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
// being updated.
|
||||
cx.defer_in(window, move |panel, window, cx| {
|
||||
if panel.has_active_thread() {
|
||||
panel.thread.update(cx, |thread, cx| {
|
||||
thread.context_store().update(cx, |store, cx| {
|
||||
panel.message_editor.update(cx, |message_editor, cx| {
|
||||
message_editor.context_store().update(cx, |store, cx| {
|
||||
let buffer = buffer.read(cx);
|
||||
let selection_ranges = selection_ranges
|
||||
.into_iter()
|
||||
@@ -1949,7 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (buffer, range) in selection_ranges {
|
||||
store.add_excerpt(range, buffer, cx).detach_and_log_err(cx);
|
||||
store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context::ContextLoadResult;
|
||||
use crate::inline_prompt_editor::CodegenStatus;
|
||||
use crate::{context::load_context, context_store::ContextStore};
|
||||
use anyhow::Result;
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
@@ -14,7 +16,9 @@ use language_model::{
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
@@ -39,6 +43,8 @@ pub struct BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
pub is_insertion: bool,
|
||||
@@ -50,6 +56,8 @@ impl BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -60,6 +68,8 @@ impl BufferCodegen {
|
||||
range.clone(),
|
||||
false,
|
||||
Some(context_store.clone()),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
Some(telemetry.clone()),
|
||||
builder.clone(),
|
||||
cx,
|
||||
@@ -75,6 +85,8 @@ impl BufferCodegen {
|
||||
range,
|
||||
initial_transaction_id,
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
builder,
|
||||
};
|
||||
@@ -153,6 +165,8 @@ impl BufferCodegen {
|
||||
self.range.clone(),
|
||||
false,
|
||||
Some(self.context_store.clone()),
|
||||
self.project.clone(),
|
||||
self.prompt_store.clone(),
|
||||
Some(self.telemetry.clone()),
|
||||
self.builder.clone(),
|
||||
cx,
|
||||
@@ -229,13 +243,14 @@ pub struct CodegenAlternative {
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
builder: Arc<PromptBuilder>,
|
||||
active: bool,
|
||||
edits: Vec<(Range<Anchor>, String)>,
|
||||
line_operations: Vec<LineOperation>,
|
||||
request: Option<LanguageModelRequest>,
|
||||
elapsed_time: Option<f64>,
|
||||
completion: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
@@ -249,6 +264,8 @@ impl CodegenAlternative {
|
||||
range: Range<Anchor>,
|
||||
active: bool,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -290,6 +307,8 @@ impl CodegenAlternative {
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
builder,
|
||||
@@ -297,7 +316,6 @@ impl CodegenAlternative {
|
||||
edits: Vec::new(),
|
||||
line_operations: Vec::new(),
|
||||
range,
|
||||
request: None,
|
||||
elapsed_time: None,
|
||||
completion: None,
|
||||
}
|
||||
@@ -366,16 +384,18 @@ impl CodegenAlternative {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(user_prompt, cx)?;
|
||||
self.request = Some(request.clone());
|
||||
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
@@ -408,30 +428,44 @@ impl CodegenAlternative {
|
||||
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
let context_task = self.context_store.as_ref().map(|context_store| {
|
||||
if let Some(project) = self.project.upgrade() {
|
||||
let context = context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
load_context(context, &project, &self.prompt_store, cx)
|
||||
} else {
|
||||
Task::ready(ContextLoadResult::default())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(context_store) = &self.context_store {
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
if let Some(context_task) = context_task {
|
||||
context_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
@@ -508,7 +542,9 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
@@ -1034,6 +1070,7 @@ impl Diff {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fs::FakeFs;
|
||||
use futures::{
|
||||
Stream,
|
||||
stream::{self},
|
||||
@@ -1076,12 +1113,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1140,12 +1181,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1207,12 +1252,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1274,12 +1323,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1329,12 +1382,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
false,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::usize;
|
||||
use std::{ops::Range, path::Path, sync::Arc};
|
||||
|
||||
use gpui::{App, Entity, SharedString};
|
||||
use language::{Buffer, File};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use project::{ProjectPath, Worktree};
|
||||
use rope::Point;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use text::{Anchor, BufferId};
|
||||
use ui::IconName;
|
||||
use util::post_inc;
|
||||
use collections::HashSet;
|
||||
use futures::future;
|
||||
use futures::{FutureExt, future::Shared};
|
||||
use gpui::{App, AppContext as _, Entity, SharedString, Task};
|
||||
use language::Buffer;
|
||||
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
|
||||
use project::{Project, ProjectEntryId, ProjectPath, Worktree};
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
use ref_cast::RefCast;
|
||||
use rope::{Point, Rope};
|
||||
use text::{Anchor, OffsetRangeExt as _};
|
||||
use ui::{ElementId, IconName};
|
||||
use util::{ResultExt as _, post_inc};
|
||||
|
||||
use crate::thread::Thread;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct ContextId(pub(crate) usize);
|
||||
pub const RULES_ICON: IconName = IconName::Context;
|
||||
|
||||
impl ContextId {
|
||||
pub fn post_inc(&mut self) -> Self {
|
||||
Self(post_inc(&mut self.0))
|
||||
}
|
||||
}
|
||||
pub enum ContextKind {
|
||||
File,
|
||||
Directory,
|
||||
Symbol,
|
||||
Excerpt,
|
||||
Selection,
|
||||
FetchedUrl,
|
||||
Thread,
|
||||
Rules,
|
||||
Image,
|
||||
}
|
||||
|
||||
impl ContextKind {
|
||||
@@ -35,244 +37,770 @@ impl ContextKind {
|
||||
ContextKind::File => IconName::File,
|
||||
ContextKind::Directory => IconName::Folder,
|
||||
ContextKind::Symbol => IconName::Code,
|
||||
ContextKind::Excerpt => IconName::Code,
|
||||
ContextKind::Selection => IconName::Context,
|
||||
ContextKind::FetchedUrl => IconName::Globe,
|
||||
ContextKind::Thread => IconName::MessageBubbles,
|
||||
ContextKind::Rules => RULES_ICON,
|
||||
ContextKind::Image => IconName::Image,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle for context that can be added to a user message.
|
||||
///
|
||||
/// This uses IDs that are stable enough for tracking renames and identifying when context has
|
||||
/// already been added to the thread. To use this in a set, wrap it in `AgentContextKey` to opt in
|
||||
/// to `PartialEq` and `Hash` impls that use the subset of the fields used for this stable identity.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AssistantContext {
|
||||
pub enum AgentContext {
|
||||
File(FileContext),
|
||||
Directory(DirectoryContext),
|
||||
Symbol(SymbolContext),
|
||||
Selection(SelectionContext),
|
||||
FetchedUrl(FetchedUrlContext),
|
||||
Thread(ThreadContext),
|
||||
Excerpt(ExcerptContext),
|
||||
Rules(RulesContext),
|
||||
Image(ImageContext),
|
||||
}
|
||||
|
||||
impl AssistantContext {
|
||||
pub fn id(&self) -> ContextId {
|
||||
impl AgentContext {
|
||||
fn id(&self) -> ContextId {
|
||||
match self {
|
||||
Self::File(file) => file.id,
|
||||
Self::Directory(directory) => directory.id,
|
||||
Self::Symbol(symbol) => symbol.id,
|
||||
Self::FetchedUrl(url) => url.id,
|
||||
Self::Thread(thread) => thread.id,
|
||||
Self::Excerpt(excerpt) => excerpt.id,
|
||||
Self::File(context) => context.context_id,
|
||||
Self::Directory(context) => context.context_id,
|
||||
Self::Symbol(context) => context.context_id,
|
||||
Self::Selection(context) => context.context_id,
|
||||
Self::FetchedUrl(context) => context.context_id,
|
||||
Self::Thread(context) => context.context_id,
|
||||
Self::Rules(context) => context.context_id,
|
||||
Self::Image(context) => context.context_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn element_id(&self, name: SharedString) -> ElementId {
|
||||
ElementId::NamedInteger(name, self.id().0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileContext {
|
||||
pub id: ContextId,
|
||||
pub context_buffer: ContextBuffer,
|
||||
/// ID created at time of context add, for use in ElementId. This is not the stable identity of a
|
||||
/// context, instead that's handled by the `PartialEq` and `Hash` impls of `AgentContextKey`.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct ContextId(usize);
|
||||
|
||||
impl ContextId {
|
||||
pub fn zero() -> Self {
|
||||
ContextId(0)
|
||||
}
|
||||
|
||||
fn for_lookup() -> Self {
|
||||
ContextId(usize::MAX)
|
||||
}
|
||||
|
||||
pub fn post_inc(&mut self) -> Self {
|
||||
Self(post_inc(&mut self.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// File context provides the entire contents of a file.
|
||||
///
|
||||
/// This holds an `Entity<Buffer>` so that file path renames affect its display and so that it can
|
||||
/// be opened even if the file has been deleted. An alternative might be to use `ProjectEntryId`,
|
||||
/// but then when deleted there is no path info or ability to open.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileContext {
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl FileContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.buffer == other.buffer
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.buffer.hash(state)
|
||||
}
|
||||
|
||||
pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
|
||||
let file = self.buffer.read(cx).file()?;
|
||||
Some(ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
|
||||
let buffer_ref = self.buffer.read(cx);
|
||||
let Some(file) = buffer_ref.file() else {
|
||||
log::error!("file context missing path");
|
||||
return None;
|
||||
};
|
||||
let full_path = file.full_path(cx);
|
||||
let rope = buffer_ref.as_rope().clone();
|
||||
let buffer = self.buffer.clone();
|
||||
Some(
|
||||
cx.background_spawn(
|
||||
async move { (to_fenced_codeblock(&full_path, rope, None), buffer) },
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Directory contents provides the entire contents of text files in a directory.
|
||||
///
|
||||
/// This has a `ProjectEntryId` so that it follows renames.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DirectoryContext {
|
||||
pub id: ContextId,
|
||||
pub worktree: Entity<Worktree>,
|
||||
pub path: Arc<Path>,
|
||||
/// Buffers of the files within the directory.
|
||||
pub context_buffers: Vec<ContextBuffer>,
|
||||
pub entry_id: ProjectEntryId,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl DirectoryContext {
|
||||
pub fn project_path(&self, cx: &App) -> ProjectPath {
|
||||
ProjectPath {
|
||||
worktree_id: self.worktree.read(cx).id(),
|
||||
path: self.path.clone(),
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.entry_id == other.entry_id
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.entry_id.hash(state)
|
||||
}
|
||||
|
||||
fn load(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<Task<Vec<(String, Entity<Buffer>)>>> {
|
||||
let worktree = project.read(cx).worktree_for_entry(self.entry_id, cx)?;
|
||||
let worktree_ref = worktree.read(cx);
|
||||
let entry = worktree_ref.entry_for_id(self.entry_id)?;
|
||||
if entry.is_file() {
|
||||
log::error!("DirectoryContext unexpectedly refers to a file.");
|
||||
return None;
|
||||
}
|
||||
|
||||
let file_paths = collect_files_in_path(worktree_ref, entry.path.as_ref());
|
||||
let texts_future = future::join_all(file_paths.into_iter().map(|path| {
|
||||
load_file_path_text_as_fenced_codeblock(project.clone(), worktree.clone(), path, cx)
|
||||
}));
|
||||
|
||||
Some(cx.background_spawn(async move {
|
||||
texts_future.await.into_iter().flatten().collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SymbolContext {
|
||||
pub id: ContextId,
|
||||
pub context_symbol: ContextSymbol,
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub symbol: SharedString,
|
||||
pub range: Range<Anchor>,
|
||||
/// The range that fully contain the symbol. e.g. for function symbol, this will include not
|
||||
/// only the signature, but also the body. Not used by `PartialEq` or `Hash` for `AgentContextKey`.
|
||||
pub enclosing_range: Range<Anchor>,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl SymbolContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.buffer == other.buffer && self.symbol == other.symbol && self.range == other.range
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.buffer.hash(state);
|
||||
self.symbol.hash(state);
|
||||
self.range.hash(state);
|
||||
}
|
||||
|
||||
fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
|
||||
let buffer_ref = self.buffer.read(cx);
|
||||
let Some(file) = buffer_ref.file() else {
|
||||
log::error!("symbol context's file has no path");
|
||||
return None;
|
||||
};
|
||||
let full_path = file.full_path(cx);
|
||||
let rope = buffer_ref
|
||||
.text_for_range(self.enclosing_range.clone())
|
||||
.collect::<Rope>();
|
||||
let line_range = self.enclosing_range.to_point(&buffer_ref.snapshot());
|
||||
let buffer = self.buffer.clone();
|
||||
Some(cx.background_spawn(async move {
|
||||
(
|
||||
to_fenced_codeblock(&full_path, rope, Some(line_range)),
|
||||
buffer,
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SelectionContext {
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub range: Range<Anchor>,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl SelectionContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.buffer == other.buffer && self.range == other.range
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.buffer.hash(state);
|
||||
self.range.hash(state);
|
||||
}
|
||||
|
||||
fn load(&self, cx: &App) -> Option<Task<(String, Entity<Buffer>)>> {
|
||||
let buffer_ref = self.buffer.read(cx);
|
||||
let Some(file) = buffer_ref.file() else {
|
||||
log::error!("selection context's file has no path");
|
||||
return None;
|
||||
};
|
||||
let full_path = file.full_path(cx);
|
||||
let rope = buffer_ref
|
||||
.text_for_range(self.range.clone())
|
||||
.collect::<Rope>();
|
||||
let line_range = self.range.to_point(&buffer_ref.snapshot());
|
||||
let buffer = self.buffer.clone();
|
||||
Some(cx.background_spawn(async move {
|
||||
(
|
||||
to_fenced_codeblock(&full_path, rope, Some(line_range)),
|
||||
buffer,
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FetchedUrlContext {
|
||||
pub id: ContextId,
|
||||
pub url: SharedString,
|
||||
/// Text contents of the fetched url. Unlike other context types, the contents of this gets
|
||||
/// populated when added rather than when sending the message. Not used by `PartialEq` or `Hash`
|
||||
/// for `AgentContextKey`.
|
||||
pub text: SharedString,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl FetchedUrlContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.url == other.url
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.url.hash(state);
|
||||
}
|
||||
|
||||
pub fn lookup_key(url: SharedString) -> AgentContextKey {
|
||||
AgentContextKey(AgentContext::FetchedUrl(FetchedUrlContext {
|
||||
url,
|
||||
text: "".into(),
|
||||
context_id: ContextId::for_lookup(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ThreadContext {
|
||||
pub id: ContextId,
|
||||
// TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
|
||||
// a WeakEntity and handle removal from the UI when it has dropped.
|
||||
pub thread: Entity<Thread>,
|
||||
pub text: SharedString,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
impl ThreadContext {
|
||||
pub fn summary(&self, cx: &App) -> SharedString {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.thread == other.thread
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.thread.hash(state)
|
||||
}
|
||||
|
||||
pub fn name(&self, cx: &App) -> SharedString {
|
||||
self.thread
|
||||
.read(cx)
|
||||
.summary()
|
||||
.unwrap_or("New thread".into())
|
||||
.unwrap_or_else(|| "New thread".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ContextBuffer {
|
||||
pub id: BufferId,
|
||||
// TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
|
||||
// a WeakEntity and handle removal from the UI when it has dropped.
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub file: Arc<dyn File>,
|
||||
pub version: clock::Global,
|
||||
pub text: SharedString,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ContextBuffer {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ContextBuffer")
|
||||
.field("id", &self.id)
|
||||
.field("buffer", &self.buffer)
|
||||
.field("version", &self.version)
|
||||
.field("text", &self.text)
|
||||
.finish()
|
||||
pub fn load(&self, cx: &App) -> String {
|
||||
let name = self.name(cx);
|
||||
let contents = self.thread.read(cx).latest_detailed_summary_or_text();
|
||||
let mut text = String::new();
|
||||
text.push_str(&name);
|
||||
text.push('\n');
|
||||
text.push_str(&contents.trim());
|
||||
text.push('\n');
|
||||
text
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextSymbol {
|
||||
pub id: ContextSymbolId,
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub buffer_version: clock::Global,
|
||||
/// The range that the symbol encloses, e.g. for function symbol, this will
|
||||
/// include not only the signature, but also the body
|
||||
pub enclosing_range: Range<Anchor>,
|
||||
pub text: SharedString,
|
||||
pub struct RulesContext {
|
||||
pub prompt_id: UserPromptId,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ContextSymbolId {
|
||||
pub path: ProjectPath,
|
||||
pub name: SharedString,
|
||||
pub range: Range<Anchor>,
|
||||
impl RulesContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.prompt_id == other.prompt_id
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.prompt_id.hash(state)
|
||||
}
|
||||
|
||||
pub fn lookup_key(prompt_id: UserPromptId) -> AgentContextKey {
|
||||
AgentContextKey(AgentContext::Rules(RulesContext {
|
||||
prompt_id,
|
||||
context_id: ContextId::for_lookup(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn load(
|
||||
&self,
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
cx: &App,
|
||||
) -> Task<Option<String>> {
|
||||
let Some(prompt_store) = prompt_store.as_ref() else {
|
||||
return Task::ready(None);
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let prompt_id = self.prompt_id.into();
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return Task::ready(None);
|
||||
};
|
||||
let contents_task = prompt_store.load(prompt_id, cx);
|
||||
cx.background_spawn(async move {
|
||||
let contents = contents_task.await.ok()?;
|
||||
let mut text = String::new();
|
||||
if let Some(title) = metadata.title {
|
||||
text.push_str("Rules title: ");
|
||||
text.push_str(&title);
|
||||
text.push('\n');
|
||||
}
|
||||
text.push_str("``````\n");
|
||||
text.push_str(contents.trim());
|
||||
text.push_str("\n``````\n");
|
||||
Some(text)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExcerptContext {
|
||||
pub id: ContextId,
|
||||
pub range: Range<Anchor>,
|
||||
pub line_range: Range<Point>,
|
||||
pub context_buffer: ContextBuffer,
|
||||
pub struct ImageContext {
|
||||
pub original_image: Arc<gpui::Image>,
|
||||
// TODO: handle this elsewhere and remove `ignore-interior-mutability` opt-out in clippy.toml
|
||||
// needed due to a false positive of `clippy::mutable_key_type`.
|
||||
pub image_task: Shared<Task<Option<LanguageModelImage>>>,
|
||||
pub context_id: ContextId,
|
||||
}
|
||||
|
||||
/// Formats a collection of contexts into a string representation
|
||||
pub fn format_context_as_string<'a>(
|
||||
contexts: impl Iterator<Item = &'a AssistantContext>,
|
||||
cx: &App,
|
||||
) -> Option<String> {
|
||||
let mut file_context = Vec::new();
|
||||
let mut directory_context = Vec::new();
|
||||
let mut symbol_context = Vec::new();
|
||||
let mut excerpt_context = Vec::new();
|
||||
let mut fetch_context = Vec::new();
|
||||
let mut thread_context = Vec::new();
|
||||
pub enum ImageStatus {
|
||||
Loading,
|
||||
Error,
|
||||
Ready,
|
||||
}
|
||||
|
||||
for context in contexts {
|
||||
match context {
|
||||
AssistantContext::File(context) => file_context.push(context),
|
||||
AssistantContext::Directory(context) => directory_context.push(context),
|
||||
AssistantContext::Symbol(context) => symbol_context.push(context),
|
||||
AssistantContext::Excerpt(context) => excerpt_context.push(context),
|
||||
AssistantContext::FetchedUrl(context) => fetch_context.push(context),
|
||||
AssistantContext::Thread(context) => thread_context.push(context),
|
||||
impl ImageContext {
|
||||
pub fn eq_for_key(&self, other: &Self) -> bool {
|
||||
self.original_image.id == other.original_image.id
|
||||
}
|
||||
|
||||
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
|
||||
self.original_image.id.hash(state);
|
||||
}
|
||||
|
||||
pub fn image(&self) -> Option<LanguageModelImage> {
|
||||
self.image_task.clone().now_or_never().flatten()
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ImageStatus {
|
||||
match self.image_task.clone().now_or_never() {
|
||||
None => ImageStatus::Loading,
|
||||
Some(None) => ImageStatus::Error,
|
||||
Some(Some(_)) => ImageStatus::Ready,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if file_context.is_empty()
|
||||
&& directory_context.is_empty()
|
||||
&& symbol_context.is_empty()
|
||||
&& excerpt_context.is_empty()
|
||||
&& fetch_context.is_empty()
|
||||
&& thread_context.is_empty()
|
||||
{
|
||||
return None;
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ContextLoadResult {
|
||||
pub loaded_context: LoadedContext,
|
||||
pub referenced_buffers: HashSet<Entity<Buffer>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct LoadedContext {
|
||||
pub contexts: Vec<AgentContext>,
|
||||
pub text: String,
|
||||
pub images: Vec<LanguageModelImage>,
|
||||
}
|
||||
|
||||
impl LoadedContext {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.text.is_empty() && self.images.is_empty()
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
result.push_str("\n<context>\n\
|
||||
The following items were attached by the user. You don't need to use other tools to read them.\n\n");
|
||||
|
||||
if !file_context.is_empty() {
|
||||
result.push_str("<files>\n");
|
||||
for context in file_context {
|
||||
result.push_str(&context.context_buffer.text);
|
||||
pub fn add_to_request_message(&self, request_message: &mut LanguageModelRequestMessage) {
|
||||
if !self.text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(self.text.to_string()));
|
||||
}
|
||||
result.push_str("</files>\n");
|
||||
}
|
||||
|
||||
if !directory_context.is_empty() {
|
||||
result.push_str("<directories>\n");
|
||||
for context in directory_context {
|
||||
for context_buffer in &context.context_buffers {
|
||||
result.push_str(&context_buffer.text);
|
||||
if !self.images.is_empty() {
|
||||
// Some providers only support image parts after an initial text part
|
||||
if request_message.content.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text("Images attached by user:".to_string()));
|
||||
}
|
||||
|
||||
for image in &self.images {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Image(image.clone()))
|
||||
}
|
||||
}
|
||||
result.push_str("</directories>\n");
|
||||
}
|
||||
|
||||
if !symbol_context.is_empty() {
|
||||
result.push_str("<symbols>\n");
|
||||
for context in symbol_context {
|
||||
result.push_str(&context.context_symbol.text);
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("</symbols>\n");
|
||||
}
|
||||
|
||||
if !excerpt_context.is_empty() {
|
||||
result.push_str("<excerpts>\n");
|
||||
for context in excerpt_context {
|
||||
result.push_str(&context.context_buffer.text);
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("</excerpts>\n");
|
||||
}
|
||||
|
||||
if !fetch_context.is_empty() {
|
||||
result.push_str("<fetched_urls>\n");
|
||||
for context in &fetch_context {
|
||||
result.push_str(&context.url);
|
||||
result.push('\n');
|
||||
result.push_str(&context.text);
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("</fetched_urls>\n");
|
||||
}
|
||||
|
||||
if !thread_context.is_empty() {
|
||||
result.push_str("<conversation_threads>\n");
|
||||
for context in &thread_context {
|
||||
result.push_str(&context.summary(cx));
|
||||
result.push('\n');
|
||||
result.push_str(&context.text);
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("</conversation_threads>\n");
|
||||
}
|
||||
|
||||
result.push_str("</context>\n");
|
||||
Some(result)
|
||||
}
|
||||
|
||||
pub fn attach_context_to_message<'a>(
|
||||
message: &mut LanguageModelRequestMessage,
|
||||
contexts: impl Iterator<Item = &'a AssistantContext>,
|
||||
cx: &App,
|
||||
) {
|
||||
if let Some(context_string) = format_context_as_string(contexts, cx) {
|
||||
message.content.push(context_string.into());
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads and formats a collection of contexts.
|
||||
pub fn load_context(
|
||||
contexts: Vec<AgentContext>,
|
||||
project: &Entity<Project>,
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
cx: &mut App,
|
||||
) -> Task<ContextLoadResult> {
|
||||
let mut file_tasks = Vec::new();
|
||||
let mut directory_tasks = Vec::new();
|
||||
let mut symbol_tasks = Vec::new();
|
||||
let mut selection_tasks = Vec::new();
|
||||
let mut fetch_context = Vec::new();
|
||||
let mut thread_context = Vec::new();
|
||||
let mut rules_tasks = Vec::new();
|
||||
let mut image_tasks = Vec::new();
|
||||
|
||||
for context in contexts.iter().cloned() {
|
||||
match context {
|
||||
AgentContext::File(context) => file_tasks.extend(context.load(cx)),
|
||||
AgentContext::Directory(context) => {
|
||||
directory_tasks.extend(context.load(project.clone(), cx))
|
||||
}
|
||||
AgentContext::Symbol(context) => symbol_tasks.extend(context.load(cx)),
|
||||
AgentContext::Selection(context) => selection_tasks.extend(context.load(cx)),
|
||||
AgentContext::FetchedUrl(context) => fetch_context.push(context),
|
||||
AgentContext::Thread(context) => thread_context.push(context.load(cx)),
|
||||
AgentContext::Rules(context) => rules_tasks.push(context.load(prompt_store, cx)),
|
||||
AgentContext::Image(context) => image_tasks.push(context.image_task.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let (
|
||||
file_context,
|
||||
directory_context,
|
||||
symbol_context,
|
||||
selection_context,
|
||||
rules_context,
|
||||
images,
|
||||
) = futures::join!(
|
||||
future::join_all(file_tasks),
|
||||
future::join_all(directory_tasks),
|
||||
future::join_all(symbol_tasks),
|
||||
future::join_all(selection_tasks),
|
||||
future::join_all(rules_tasks),
|
||||
future::join_all(image_tasks)
|
||||
);
|
||||
|
||||
let directory_context = directory_context.into_iter().flatten().collect::<Vec<_>>();
|
||||
let rules_context = rules_context.into_iter().flatten().collect::<Vec<_>>();
|
||||
let images = images.into_iter().flatten().collect::<Vec<_>>();
|
||||
|
||||
let mut referenced_buffers = HashSet::default();
|
||||
let mut text = String::new();
|
||||
|
||||
if file_context.is_empty()
|
||||
&& directory_context.is_empty()
|
||||
&& symbol_context.is_empty()
|
||||
&& selection_context.is_empty()
|
||||
&& fetch_context.is_empty()
|
||||
&& thread_context.is_empty()
|
||||
&& rules_context.is_empty()
|
||||
{
|
||||
return ContextLoadResult {
|
||||
loaded_context: LoadedContext {
|
||||
contexts,
|
||||
text,
|
||||
images,
|
||||
},
|
||||
referenced_buffers,
|
||||
};
|
||||
}
|
||||
|
||||
text.push_str(
|
||||
"\n<context>\n\
|
||||
The following items were attached by the user. \
|
||||
You don't need to use other tools to read them.\n\n",
|
||||
);
|
||||
|
||||
if !file_context.is_empty() {
|
||||
text.push_str("<files>");
|
||||
for (file_text, buffer) in file_context {
|
||||
text.push('\n');
|
||||
text.push_str(&file_text);
|
||||
referenced_buffers.insert(buffer);
|
||||
}
|
||||
text.push_str("</files>\n");
|
||||
}
|
||||
|
||||
if !directory_context.is_empty() {
|
||||
text.push_str("<directories>");
|
||||
for (file_text, buffer) in directory_context {
|
||||
text.push('\n');
|
||||
text.push_str(&file_text);
|
||||
referenced_buffers.insert(buffer);
|
||||
}
|
||||
text.push_str("</directories>\n");
|
||||
}
|
||||
|
||||
if !symbol_context.is_empty() {
|
||||
text.push_str("<symbols>");
|
||||
for (symbol_text, buffer) in symbol_context {
|
||||
text.push('\n');
|
||||
text.push_str(&symbol_text);
|
||||
referenced_buffers.insert(buffer);
|
||||
}
|
||||
text.push_str("</symbols>\n");
|
||||
}
|
||||
|
||||
if !selection_context.is_empty() {
|
||||
text.push_str("<selections>");
|
||||
for (selection_text, buffer) in selection_context {
|
||||
text.push('\n');
|
||||
text.push_str(&selection_text);
|
||||
referenced_buffers.insert(buffer);
|
||||
}
|
||||
text.push_str("</selections>\n");
|
||||
}
|
||||
|
||||
if !fetch_context.is_empty() {
|
||||
text.push_str("<fetched_urls>");
|
||||
for context in fetch_context {
|
||||
text.push('\n');
|
||||
text.push_str(&context.url);
|
||||
text.push('\n');
|
||||
text.push_str(&context.text);
|
||||
}
|
||||
text.push_str("</fetched_urls>\n");
|
||||
}
|
||||
|
||||
if !thread_context.is_empty() {
|
||||
text.push_str("<conversation_threads>");
|
||||
for thread_text in thread_context {
|
||||
text.push('\n');
|
||||
text.push_str(&thread_text);
|
||||
}
|
||||
text.push_str("</conversation_threads>\n");
|
||||
}
|
||||
|
||||
if !rules_context.is_empty() {
|
||||
text.push_str(
|
||||
"<user_rules>\n\
|
||||
The user has specified the following rules that should be applied:\n",
|
||||
);
|
||||
for rules_text in rules_context {
|
||||
text.push('\n');
|
||||
text.push_str(&rules_text);
|
||||
}
|
||||
text.push_str("</user_rules>\n");
|
||||
}
|
||||
|
||||
text.push_str("</context>\n");
|
||||
|
||||
ContextLoadResult {
|
||||
loaded_context: LoadedContext {
|
||||
contexts,
|
||||
text,
|
||||
images,
|
||||
},
|
||||
referenced_buffers,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_files_in_path(worktree: &Worktree, path: &Path) -> Vec<Arc<Path>> {
|
||||
let mut files = Vec::new();
|
||||
|
||||
for entry in worktree.child_entries(path) {
|
||||
if entry.is_dir() {
|
||||
files.extend(collect_files_in_path(worktree, &entry.path));
|
||||
} else if entry.is_file() {
|
||||
files.push(entry.path.clone());
|
||||
}
|
||||
}
|
||||
|
||||
files
|
||||
}
|
||||
|
||||
fn load_file_path_text_as_fenced_codeblock(
|
||||
project: Entity<Project>,
|
||||
worktree: Entity<Worktree>,
|
||||
path: Arc<Path>,
|
||||
cx: &mut App,
|
||||
) -> Task<Option<(String, Entity<Buffer>)>> {
|
||||
let worktree_ref = worktree.read(cx);
|
||||
let worktree_id = worktree_ref.id();
|
||||
let full_path = worktree_ref.full_path(&path);
|
||||
|
||||
let open_task = project.update(cx, |project, cx| {
|
||||
project.buffer_store().update(cx, |buffer_store, cx| {
|
||||
let project_path = ProjectPath { worktree_id, path };
|
||||
buffer_store.open_buffer(project_path, cx)
|
||||
})
|
||||
});
|
||||
|
||||
let rope_task = cx.spawn(async move |cx| {
|
||||
let buffer = open_task.await.log_err()?;
|
||||
let rope = buffer
|
||||
.read_with(cx, |buffer, _cx| buffer.as_rope().clone())
|
||||
.log_err()?;
|
||||
Some((rope, buffer))
|
||||
});
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let (rope, buffer) = rope_task.await?;
|
||||
Some((to_fenced_codeblock(&full_path, rope, None), buffer))
|
||||
})
|
||||
}
|
||||
|
||||
fn to_fenced_codeblock(
|
||||
full_path: &Path,
|
||||
content: Rope,
|
||||
line_range: Option<Range<Point>>,
|
||||
) -> String {
|
||||
let line_range_text = line_range.map(|range| {
|
||||
if range.start.row == range.end.row {
|
||||
format!(":{}", range.start.row + 1)
|
||||
} else {
|
||||
format!(":{}-{}", range.start.row + 1, range.end.row + 1)
|
||||
}
|
||||
});
|
||||
|
||||
let path_extension = full_path.extension().and_then(|ext| ext.to_str());
|
||||
let path_string = full_path.to_string_lossy();
|
||||
let capacity = 3
|
||||
+ path_extension.map_or(0, |extension| extension.len() + 1)
|
||||
+ path_string.len()
|
||||
+ line_range_text.as_ref().map_or(0, |text| text.len())
|
||||
+ 1
|
||||
+ content.len()
|
||||
+ 5;
|
||||
let mut buffer = String::with_capacity(capacity);
|
||||
|
||||
buffer.push_str("```");
|
||||
|
||||
if let Some(extension) = path_extension {
|
||||
buffer.push_str(extension);
|
||||
buffer.push(' ');
|
||||
}
|
||||
buffer.push_str(&path_string);
|
||||
|
||||
if let Some(line_range_text) = line_range_text {
|
||||
buffer.push_str(&line_range_text);
|
||||
}
|
||||
|
||||
buffer.push('\n');
|
||||
for chunk in content.chunks() {
|
||||
buffer.push_str(chunk);
|
||||
}
|
||||
|
||||
if !buffer.ends_with('\n') {
|
||||
buffer.push('\n');
|
||||
}
|
||||
|
||||
buffer.push_str("```\n");
|
||||
|
||||
debug_assert!(
|
||||
buffer.len() == capacity - 1 || buffer.len() == capacity,
|
||||
"to_fenced_codeblock calculated capacity of {}, but length was {}",
|
||||
capacity,
|
||||
buffer.len(),
|
||||
);
|
||||
|
||||
buffer
|
||||
}
|
||||
|
||||
/// Wraps `AgentContext` to opt-in to `PartialEq` and `Hash` impls which use a subset of fields
|
||||
/// needed for stable context identity.
|
||||
#[derive(Debug, Clone, RefCast)]
|
||||
#[repr(transparent)]
|
||||
pub struct AgentContextKey(pub AgentContext);
|
||||
|
||||
impl AsRef<AgentContext> for AgentContextKey {
|
||||
fn as_ref(&self) -> &AgentContext {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for AgentContextKey {}
|
||||
|
||||
impl PartialEq for AgentContextKey {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match &self.0 {
|
||||
AgentContext::File(context) => {
|
||||
if let AgentContext::File(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Directory(context) => {
|
||||
if let AgentContext::Directory(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Symbol(context) => {
|
||||
if let AgentContext::Symbol(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Selection(context) => {
|
||||
if let AgentContext::Selection(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::FetchedUrl(context) => {
|
||||
if let AgentContext::FetchedUrl(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Thread(context) => {
|
||||
if let AgentContext::Thread(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Rules(context) => {
|
||||
if let AgentContext::Rules(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
AgentContext::Image(context) => {
|
||||
if let AgentContext::Image(other_context) = &other.0 {
|
||||
return context.eq_for_key(other_context);
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for AgentContextKey {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match &self.0 {
|
||||
AgentContext::File(context) => context.hash_for_key(state),
|
||||
AgentContext::Directory(context) => context.hash_for_key(state),
|
||||
AgentContext::Symbol(context) => context.hash_for_key(state),
|
||||
AgentContext::Selection(context) => context.hash_for_key(state),
|
||||
AgentContext::FetchedUrl(context) => context.hash_for_key(state),
|
||||
AgentContext::Thread(context) => context.hash_for_key(state),
|
||||
AgentContext::Rules(context) => context.hash_for_key(state),
|
||||
AgentContext::Image(context) => context.hash_for_key(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod completion_provider;
|
||||
mod fetch_context_picker;
|
||||
mod file_context_picker;
|
||||
mod rules_context_picker;
|
||||
mod symbol_context_picker;
|
||||
mod thread_context_picker;
|
||||
|
||||
@@ -9,37 +10,96 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
use gpui::{
|
||||
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
|
||||
WeakEntity,
|
||||
};
|
||||
use language::Buffer;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use project::{Entry, ProjectPath};
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
|
||||
use symbol_context_picker::SymbolContextPicker;
|
||||
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
|
||||
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
|
||||
use ui::{
|
||||
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::AssistantPanel;
|
||||
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::context_picker::fetch_context_picker::FetchContextPicker;
|
||||
use crate::context_picker::file_context_picker::FileContextPicker;
|
||||
use crate::context_picker::thread_context_picker::ThreadContextPicker;
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerEntry {
|
||||
Mode(ContextPickerMode),
|
||||
Action(ContextPickerAction),
|
||||
}
|
||||
|
||||
impl ContextPickerEntry {
|
||||
pub fn keyword(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mode(mode) => mode.keyword(),
|
||||
Self::Action(action) => action.keyword(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Mode(mode) => mode.label(),
|
||||
Self::Action(action) => action.label(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn icon(&self) -> IconName {
|
||||
match self {
|
||||
Self::Mode(mode) => mode.icon(),
|
||||
Self::Action(action) => action.icon(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerMode {
|
||||
File,
|
||||
Symbol,
|
||||
Fetch,
|
||||
Thread,
|
||||
Rules,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerAction {
|
||||
AddSelections,
|
||||
}
|
||||
|
||||
impl ContextPickerAction {
|
||||
pub fn keyword(&self) -> &'static str {
|
||||
match self {
|
||||
Self::AddSelections => "selection",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
Self::AddSelections => "Selection",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn icon(&self) -> IconName {
|
||||
match self {
|
||||
Self::AddSelections => IconName::Context,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ContextPickerMode {
|
||||
@@ -51,18 +111,20 @@ impl TryFrom<&str> for ContextPickerMode {
|
||||
"symbol" => Ok(Self::Symbol),
|
||||
"fetch" => Ok(Self::Fetch),
|
||||
"thread" => Ok(Self::Thread),
|
||||
"rules" => Ok(Self::Rules),
|
||||
_ => Err(format!("Invalid context picker mode: {}", value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ContextPickerMode {
|
||||
pub fn mention_prefix(&self) -> &'static str {
|
||||
pub fn keyword(&self) -> &'static str {
|
||||
match self {
|
||||
Self::File => "file",
|
||||
Self::Symbol => "symbol",
|
||||
Self::Fetch => "fetch",
|
||||
Self::Thread => "thread",
|
||||
Self::Rules => "rules",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +134,7 @@ impl ContextPickerMode {
|
||||
Self::Symbol => "Symbols",
|
||||
Self::Fetch => "Fetch",
|
||||
Self::Thread => "Threads",
|
||||
Self::Rules => "Rules",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +144,7 @@ impl ContextPickerMode {
|
||||
Self::Symbol => IconName::Code,
|
||||
Self::Fetch => IconName::Globe,
|
||||
Self::Thread => IconName::MessageBubbles,
|
||||
Self::Rules => RULES_ICON,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,6 +156,7 @@ enum ContextPickerState {
|
||||
Symbol(Entity<SymbolContextPicker>),
|
||||
Fetch(Entity<FetchContextPicker>),
|
||||
Thread(Entity<ThreadContextPicker>),
|
||||
Rules(Entity<RulesContextPicker>),
|
||||
}
|
||||
|
||||
pub(super) struct ContextPicker {
|
||||
@@ -99,6 +164,7 @@ pub(super) struct ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -126,6 +192,13 @@ impl ContextPicker {
|
||||
)
|
||||
.collect::<Vec<Subscription>>();
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
ContextPicker {
|
||||
mode: ContextPickerState::Default(ContextMenu::build(
|
||||
window,
|
||||
@@ -135,6 +208,7 @@ impl ContextPicker {
|
||||
workspace,
|
||||
context_store,
|
||||
thread_store,
|
||||
prompt_store,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
@@ -155,7 +229,18 @@ impl ContextPicker {
|
||||
.enumerate()
|
||||
.map(|(ix, entry)| self.recent_menu_item(context_picker.clone(), ix, entry));
|
||||
|
||||
let modes = supported_context_picker_modes(&self.thread_store);
|
||||
let entries = self
|
||||
.workspace
|
||||
.upgrade()
|
||||
.map(|workspace| {
|
||||
available_context_picker_entries(
|
||||
&self.prompt_store,
|
||||
&self.thread_store,
|
||||
&workspace,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
menu.when(has_recent, |menu| {
|
||||
menu.custom_row(|_, _| {
|
||||
@@ -171,15 +256,15 @@ impl ContextPicker {
|
||||
})
|
||||
.extend(recent_entries)
|
||||
.when(has_recent, |menu| menu.separator())
|
||||
.extend(modes.into_iter().map(|mode| {
|
||||
.extend(entries.into_iter().map(|entry| {
|
||||
let context_picker = context_picker.clone();
|
||||
|
||||
ContextMenuEntry::new(mode.label())
|
||||
.icon(mode.icon())
|
||||
ContextMenuEntry::new(entry.label())
|
||||
.icon(entry.icon())
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
|
||||
context_picker.update(cx, |this, cx| this.select_entry(entry, window, cx))
|
||||
})
|
||||
}))
|
||||
.keep_open_on_confirm()
|
||||
@@ -198,61 +283,87 @@ impl ContextPicker {
|
||||
self.thread_store.is_some()
|
||||
}
|
||||
|
||||
fn select_mode(
|
||||
fn select_entry(
|
||||
&mut self,
|
||||
mode: ContextPickerMode,
|
||||
entry: ContextPickerEntry,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let context_picker = cx.entity().downgrade();
|
||||
|
||||
match mode {
|
||||
ContextPickerMode::File => {
|
||||
self.mode = ContextPickerState::File(cx.new(|cx| {
|
||||
FileContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Symbol => {
|
||||
self.mode = ContextPickerState::Symbol(cx.new(|cx| {
|
||||
SymbolContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Fetch => {
|
||||
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
|
||||
FetchContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Thread => {
|
||||
if let Some(thread_store) = self.thread_store.as_ref() {
|
||||
self.mode = ContextPickerState::Thread(cx.new(|cx| {
|
||||
ThreadContextPicker::new(
|
||||
thread_store.clone(),
|
||||
match entry {
|
||||
ContextPickerEntry::Mode(mode) => match mode {
|
||||
ContextPickerMode::File => {
|
||||
self.mode = ContextPickerState::File(cx.new(|cx| {
|
||||
FileContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
}
|
||||
ContextPickerMode::Symbol => {
|
||||
self.mode = ContextPickerState::Symbol(cx.new(|cx| {
|
||||
SymbolContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Rules => {
|
||||
if let Some(prompt_store) = self.prompt_store.as_ref() {
|
||||
self.mode = ContextPickerState::Rules(cx.new(|cx| {
|
||||
RulesContextPicker::new(
|
||||
prompt_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
}
|
||||
ContextPickerMode::Fetch => {
|
||||
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
|
||||
FetchContextPicker::new(
|
||||
context_picker.clone(),
|
||||
self.workspace.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Thread => {
|
||||
if let Some(thread_store) = self.thread_store.as_ref() {
|
||||
self.mode = ContextPickerState::Thread(cx.new(|cx| {
|
||||
ThreadContextPicker::new(
|
||||
thread_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}));
|
||||
}
|
||||
}
|
||||
},
|
||||
ContextPickerEntry::Action(action) => match action {
|
||||
ContextPickerAction::AddSelections => {
|
||||
if let Some((context_store, workspace)) =
|
||||
self.context_store.upgrade().zip(self.workspace.upgrade())
|
||||
{
|
||||
add_selections_as_context(&context_store, &workspace, cx);
|
||||
}
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
@@ -381,6 +492,7 @@ impl ContextPicker {
|
||||
ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()),
|
||||
ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()),
|
||||
ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()),
|
||||
ContextPickerState::Rules(entity) => entity.update(cx, |_, cx| cx.notify()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -395,6 +507,7 @@ impl Focusable for ContextPicker {
|
||||
ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx),
|
||||
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
|
||||
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
|
||||
ContextPickerState::Rules(user_rules_picker) => user_rules_picker.focus_handle(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -410,6 +523,9 @@ impl Render for ContextPicker {
|
||||
ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()),
|
||||
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
|
||||
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
|
||||
ContextPickerState::Rules(user_rules_picker) => {
|
||||
parent.child(user_rules_picker.clone())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -421,18 +537,41 @@ enum RecentEntry {
|
||||
Thread(ThreadContextEntry),
|
||||
}
|
||||
|
||||
fn supported_context_picker_modes(
|
||||
fn available_context_picker_entries(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
thread_store: &Option<WeakEntity<ThreadStore>>,
|
||||
) -> Vec<ContextPickerMode> {
|
||||
let mut modes = vec![
|
||||
ContextPickerMode::File,
|
||||
ContextPickerMode::Symbol,
|
||||
ContextPickerMode::Fetch,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Vec<ContextPickerEntry> {
|
||||
let mut entries = vec![
|
||||
ContextPickerEntry::Mode(ContextPickerMode::File),
|
||||
ContextPickerEntry::Mode(ContextPickerMode::Symbol),
|
||||
];
|
||||
if thread_store.is_some() {
|
||||
modes.push(ContextPickerMode::Thread);
|
||||
|
||||
let has_selection = workspace
|
||||
.read(cx)
|
||||
.active_item(cx)
|
||||
.and_then(|item| item.downcast::<Editor>())
|
||||
.map_or(false, |editor| {
|
||||
editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))
|
||||
});
|
||||
if has_selection {
|
||||
entries.push(ContextPickerEntry::Action(
|
||||
ContextPickerAction::AddSelections,
|
||||
));
|
||||
}
|
||||
modes
|
||||
|
||||
if thread_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
|
||||
}
|
||||
|
||||
if prompt_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
|
||||
}
|
||||
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch));
|
||||
|
||||
entries
|
||||
}
|
||||
|
||||
fn recent_context_picker_entries(
|
||||
@@ -462,22 +601,21 @@ fn recent_context_picker_entries(
|
||||
}),
|
||||
);
|
||||
|
||||
let mut current_threads = context_store.read(cx).thread_ids();
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
if let Some(active_thread) = workspace
|
||||
let active_thread_id = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|panel| panel.read(cx).active_thread(cx))
|
||||
{
|
||||
current_threads.insert(active_thread.read(cx).id().clone());
|
||||
}
|
||||
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
|
||||
|
||||
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
|
||||
recent.extend(
|
||||
thread_store
|
||||
.read(cx)
|
||||
.threads()
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.filter(|thread| !current_threads.contains(&thread.id))
|
||||
.filter(|thread| {
|
||||
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
|
||||
})
|
||||
.take(2)
|
||||
.map(|thread| {
|
||||
RecentEntry::Thread(ThreadContextEntry {
|
||||
@@ -491,6 +629,52 @@ fn recent_context_picker_entries(
|
||||
recent
|
||||
}
|
||||
|
||||
fn add_selections_as_context(
|
||||
context_store: &Entity<ContextStore>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let selection_ranges = selection_ranges(workspace, cx);
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in selection_ranges {
|
||||
context_store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn selection_ranges(
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Vec<(Entity<Buffer>, Range<text::Anchor>)> {
|
||||
let Some(editor) = workspace
|
||||
.read(cx)
|
||||
.active_item(cx)
|
||||
.and_then(|item| item.act_as::<Editor>(cx))
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let selections = editor.selections.all_adjusted(cx);
|
||||
|
||||
let buffer = editor.buffer().clone().read(cx);
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
|
||||
selections
|
||||
.into_iter()
|
||||
.map(|s| snapshot.anchor_after(s.start)..snapshot.anchor_before(s.end))
|
||||
.flat_map(|range| {
|
||||
let (start_buffer, start) = buffer.text_anchor_for_position(range.start, cx)?;
|
||||
let (end_buffer, end) = buffer.text_anchor_for_position(range.end, cx)?;
|
||||
if start_buffer != end_buffer {
|
||||
return None;
|
||||
}
|
||||
Some((start_buffer, start..end))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn insert_fold_for_mention(
|
||||
excerpt_id: ExcerptId,
|
||||
crease_start: text::Anchor,
|
||||
@@ -510,24 +694,11 @@ pub(crate) fn insert_fold_for_mention(
|
||||
let start = start.bias_right(&snapshot);
|
||||
let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len);
|
||||
|
||||
let placeholder = FoldPlaceholder {
|
||||
render: render_fold_icon_button(
|
||||
crease_icon_path,
|
||||
crease_label,
|
||||
editor_entity.downgrade(),
|
||||
),
|
||||
merge_adjacent: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let render_trailer =
|
||||
move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any();
|
||||
|
||||
let crease = Crease::inline(
|
||||
let crease = crease_for_mention(
|
||||
crease_label,
|
||||
crease_icon_path,
|
||||
start..end,
|
||||
placeholder.clone(),
|
||||
fold_toggle("mention"),
|
||||
render_trailer,
|
||||
editor_entity.downgrade(),
|
||||
);
|
||||
|
||||
editor.display_map.update(cx, |display_map, cx| {
|
||||
@@ -536,6 +707,29 @@ pub(crate) fn insert_fold_for_mention(
|
||||
});
|
||||
}
|
||||
|
||||
pub fn crease_for_mention(
|
||||
label: SharedString,
|
||||
icon_path: SharedString,
|
||||
range: Range<Anchor>,
|
||||
editor_entity: WeakEntity<Editor>,
|
||||
) -> Crease<Anchor> {
|
||||
let placeholder = FoldPlaceholder {
|
||||
render: render_fold_icon_button(icon_path, label, editor_entity),
|
||||
merge_adjacent: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let render_trailer = move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any();
|
||||
|
||||
let crease = Crease::inline(
|
||||
range,
|
||||
placeholder.clone(),
|
||||
fold_toggle("mention"),
|
||||
render_trailer,
|
||||
);
|
||||
crease
|
||||
}
|
||||
|
||||
fn render_fold_icon_button(
|
||||
icon_path: SharedString,
|
||||
label: SharedString,
|
||||
@@ -624,15 +818,19 @@ fn fold_toggle(
|
||||
pub enum MentionLink {
|
||||
File(ProjectPath, Entry),
|
||||
Symbol(ProjectPath, String),
|
||||
Selection(ProjectPath, Range<usize>),
|
||||
Fetch(String),
|
||||
Thread(ThreadId),
|
||||
Rules(UserPromptId),
|
||||
}
|
||||
|
||||
impl MentionLink {
|
||||
const FILE: &str = "@file";
|
||||
const SYMBOL: &str = "@symbol";
|
||||
const SELECTION: &str = "@selection";
|
||||
const THREAD: &str = "@thread";
|
||||
const FETCH: &str = "@fetch";
|
||||
const RULES: &str = "@rules";
|
||||
|
||||
const SEPARATOR: &str = ":";
|
||||
|
||||
@@ -640,7 +838,9 @@ impl MentionLink {
|
||||
url.starts_with(Self::FILE)
|
||||
|| url.starts_with(Self::SYMBOL)
|
||||
|| url.starts_with(Self::FETCH)
|
||||
|| url.starts_with(Self::SELECTION)
|
||||
|| url.starts_with(Self::THREAD)
|
||||
|| url.starts_with(Self::RULES)
|
||||
}
|
||||
|
||||
pub fn for_file(file_name: &str, full_path: &str) -> String {
|
||||
@@ -657,14 +857,31 @@ impl MentionLink {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn for_fetch(url: &str) -> String {
|
||||
format!("[@{}]({}:{})", url, Self::FETCH, url)
|
||||
pub fn for_selection(file_name: &str, full_path: &str, line_range: Range<usize>) -> String {
|
||||
format!(
|
||||
"[@{} ({}-{})]({}:{}:{}-{})",
|
||||
file_name,
|
||||
line_range.start,
|
||||
line_range.end,
|
||||
Self::SELECTION,
|
||||
full_path,
|
||||
line_range.start,
|
||||
line_range.end
|
||||
)
|
||||
}
|
||||
|
||||
pub fn for_thread(thread: &ThreadContextEntry) -> String {
|
||||
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
|
||||
}
|
||||
|
||||
pub fn for_fetch(url: &str) -> String {
|
||||
format!("[@{}]({}:{})", url, Self::FETCH, url)
|
||||
}
|
||||
|
||||
pub fn for_rules(rules: &RulesContextEntry) -> String {
|
||||
format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0)
|
||||
}
|
||||
|
||||
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
|
||||
fn extract_project_path_from_link(
|
||||
path: &str,
|
||||
@@ -701,11 +918,29 @@ impl MentionLink {
|
||||
let project_path = extract_project_path_from_link(path, workspace, cx)?;
|
||||
Some(MentionLink::Symbol(project_path, symbol.to_string()))
|
||||
}
|
||||
Self::SELECTION => {
|
||||
let (path, line_args) = argument.split_once(Self::SEPARATOR)?;
|
||||
let project_path = extract_project_path_from_link(path, workspace, cx)?;
|
||||
|
||||
let line_range = {
|
||||
let (start, end) = line_args
|
||||
.trim_start_matches('(')
|
||||
.trim_end_matches(')')
|
||||
.split_once('-')?;
|
||||
start.parse::<usize>().ok()?..end.parse::<usize>().ok()?
|
||||
};
|
||||
|
||||
Some(MentionLink::Selection(project_path, line_range))
|
||||
}
|
||||
Self::THREAD => {
|
||||
let thread_id = ThreadId::from(argument);
|
||||
Some(MentionLink::Thread(thread_id))
|
||||
}
|
||||
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
|
||||
Self::RULES => {
|
||||
let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?);
|
||||
Some(MentionLink::Rules(prompt_id))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +1,64 @@
|
||||
use std::cell::RefCell;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::Result;
|
||||
use editor::{CompletionProvider, Editor, ExcerptId};
|
||||
use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _};
|
||||
use file_icons::FileIcons;
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{App, Entity, Task, WeakEntity};
|
||||
use http_client::HttpClientWithUrl;
|
||||
use itertools::Itertools;
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Point;
|
||||
use text::{Anchor, ToPoint};
|
||||
use text::{Anchor, OffsetRangeExt, ToPoint};
|
||||
use ui::prelude::*;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context_picker::file_context_picker::search_files;
|
||||
use crate::context_picker::symbol_context_picker::search_symbols;
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
use super::fetch_context_picker::fetch_url_content;
|
||||
use super::file_context_picker::FileMatch;
|
||||
use super::file_context_picker::{FileMatch, search_files};
|
||||
use super::rules_context_picker::{RulesContextEntry, search_rules};
|
||||
use super::symbol_context_picker::SymbolMatch;
|
||||
use super::symbol_context_picker::search_symbols;
|
||||
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
|
||||
use super::{
|
||||
ContextPickerMode, MentionLink, RecentEntry, recent_context_picker_entries,
|
||||
supported_context_picker_modes,
|
||||
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
|
||||
available_context_picker_entries, recent_context_picker_entries, selection_ranges,
|
||||
};
|
||||
|
||||
pub(crate) enum Match {
|
||||
Symbol(SymbolMatch),
|
||||
File(FileMatch),
|
||||
Symbol(SymbolMatch),
|
||||
Thread(ThreadMatch),
|
||||
Fetch(SharedString),
|
||||
Mode(ModeMatch),
|
||||
Rules(RulesContextEntry),
|
||||
Entry(EntryMatch),
|
||||
}
|
||||
|
||||
pub struct ModeMatch {
|
||||
pub struct EntryMatch {
|
||||
mat: Option<StringMatch>,
|
||||
mode: ContextPickerMode,
|
||||
entry: ContextPickerEntry,
|
||||
}
|
||||
|
||||
impl Match {
|
||||
pub fn score(&self) -> f64 {
|
||||
match self {
|
||||
Match::File(file) => file.mat.score,
|
||||
Match::Mode(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
|
||||
Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
|
||||
Match::Thread(_) => 1.,
|
||||
Match::Symbol(_) => 1.,
|
||||
Match::Fetch(_) => 1.,
|
||||
Match::Rules(_) => 1.,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,6 +68,7 @@ fn search(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
recent_entries: Vec<RecentEntry>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -79,6 +85,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Symbol) => {
|
||||
let search_symbols_task =
|
||||
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
@@ -90,6 +97,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Thread) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
let search_threads_task =
|
||||
@@ -105,6 +113,7 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Fetch) => {
|
||||
if !query.is_empty() {
|
||||
Task::ready(vec![Match::Fetch(query.into())])
|
||||
@@ -112,6 +121,23 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Rules) => {
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
let search_rules_task =
|
||||
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_rules_task
|
||||
.await
|
||||
.into_iter()
|
||||
.map(Match::Rules)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
} else {
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
if query.is_empty() {
|
||||
let mut matches = recent_entries
|
||||
@@ -142,9 +168,14 @@ fn search(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
matches.extend(
|
||||
supported_context_picker_modes(&thread_store)
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
|
||||
.into_iter()
|
||||
.map(|mode| Match::Mode(ModeMatch { mode, mat: None })),
|
||||
.map(|mode| {
|
||||
Match::Entry(EntryMatch {
|
||||
entry: mode,
|
||||
mat: None,
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
Task::ready(matches)
|
||||
@@ -154,11 +185,12 @@ fn search(
|
||||
let search_files_task =
|
||||
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
|
||||
let modes = supported_context_picker_modes(&thread_store);
|
||||
let mode_candidates = modes
|
||||
let entries =
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
|
||||
let entry_candidates = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, mode)| StringMatchCandidate::new(ix, mode.mention_prefix()))
|
||||
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
@@ -168,8 +200,8 @@ fn search(
|
||||
.map(Match::File)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mode_matches = fuzzy::match_strings(
|
||||
&mode_candidates,
|
||||
let entry_matches = fuzzy::match_strings(
|
||||
&entry_candidates,
|
||||
&query,
|
||||
false,
|
||||
100,
|
||||
@@ -178,9 +210,9 @@ fn search(
|
||||
)
|
||||
.await;
|
||||
|
||||
matches.extend(mode_matches.into_iter().map(|mat| {
|
||||
Match::Mode(ModeMatch {
|
||||
mode: modes[mat.candidate_id],
|
||||
matches.extend(entry_matches.into_iter().map(|mat| {
|
||||
Match::Entry(EntryMatch {
|
||||
entry: entries[mat.candidate_id],
|
||||
mat: Some(mat),
|
||||
})
|
||||
}));
|
||||
@@ -220,19 +252,139 @@ impl ContextPickerCompletionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
fn completion_for_mode(source_range: Range<Anchor>, mode: ContextPickerMode) -> Completion {
|
||||
Completion {
|
||||
replace_range: source_range.clone(),
|
||||
new_text: format!("@{} ", mode.mention_prefix()),
|
||||
label: CodeLabel::plain(mode.label().to_string(), None),
|
||||
icon_path: Some(mode.icon().path().into()),
|
||||
documentation: None,
|
||||
source: project::CompletionSource::Custom,
|
||||
insert_text_mode: None,
|
||||
// This ensures that when a user accepts this completion, the
|
||||
// completion menu will still be shown after "@category " is
|
||||
// inserted
|
||||
confirm: Some(Arc::new(|_, _, _| true)),
|
||||
fn completion_for_entry(
|
||||
entry: ContextPickerEntry,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
context_store: Entity<ContextStore>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Option<Completion> {
|
||||
match entry {
|
||||
ContextPickerEntry::Mode(mode) => Some(Completion {
|
||||
replace_range: source_range.clone(),
|
||||
new_text: format!("@{} ", mode.keyword()),
|
||||
label: CodeLabel::plain(mode.label().to_string(), None),
|
||||
icon_path: Some(mode.icon().path().into()),
|
||||
documentation: None,
|
||||
source: project::CompletionSource::Custom,
|
||||
insert_text_mode: None,
|
||||
// This ensures that when a user accepts this completion, the
|
||||
// completion menu will still be shown after "@category " is
|
||||
// inserted
|
||||
confirm: Some(Arc::new(|_, _, _| true)),
|
||||
}),
|
||||
ContextPickerEntry::Action(action) => {
|
||||
let (new_text, on_action) = match action {
|
||||
ContextPickerAction::AddSelections => {
|
||||
let selections = selection_ranges(workspace, cx);
|
||||
|
||||
let selection_infos = selections
|
||||
.iter()
|
||||
.map(|(buffer, range)| {
|
||||
let full_path = buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| PathBuf::from("untitled"));
|
||||
let file_name = full_path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
let line_range = range.to_point(&buffer.read(cx).snapshot());
|
||||
|
||||
let link = MentionLink::for_selection(
|
||||
&file_name,
|
||||
&full_path.to_string_lossy(),
|
||||
line_range.start.row as usize..line_range.end.row as usize,
|
||||
);
|
||||
(file_name, link, line_range)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let new_text = selection_infos.iter().map(|(_, link, _)| link).join(" ");
|
||||
|
||||
let callback = Arc::new({
|
||||
let context_store = context_store.clone();
|
||||
let selections = selections.clone();
|
||||
let selection_infos = selection_infos.clone();
|
||||
move |_, _: &mut Window, cx: &mut App| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in &selections {
|
||||
context_store.add_selection(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
let editor = editor.clone();
|
||||
let selection_infos = selection_infos.clone();
|
||||
cx.defer(move |cx| {
|
||||
let mut current_offset = 0;
|
||||
for (file_name, link, line_range) in selection_infos.iter() {
|
||||
let snapshot =
|
||||
editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let Some(start) = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, source_range.start)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let offset = start.to_offset(&snapshot) + current_offset;
|
||||
let text_len = link.len();
|
||||
|
||||
let range = snapshot.anchor_after(offset)
|
||||
..snapshot.anchor_after(offset + text_len);
|
||||
|
||||
let crease = super::crease_for_mention(
|
||||
format!(
|
||||
"{} ({}-{})",
|
||||
file_name,
|
||||
line_range.start.row + 1,
|
||||
line_range.end.row + 1
|
||||
)
|
||||
.into(),
|
||||
IconName::Context.path().into(),
|
||||
range,
|
||||
editor.downgrade(),
|
||||
);
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.display_map.update(cx, |display_map, cx| {
|
||||
display_map.fold(vec![crease], cx);
|
||||
});
|
||||
});
|
||||
|
||||
current_offset += text_len + 1;
|
||||
}
|
||||
});
|
||||
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
(new_text, callback)
|
||||
}
|
||||
};
|
||||
|
||||
Some(Completion {
|
||||
replace_range: source_range.clone(),
|
||||
new_text,
|
||||
label: CodeLabel::plain(action.label().to_string(), None),
|
||||
icon_path: Some(action.icon().path().into()),
|
||||
documentation: None,
|
||||
source: project::CompletionSource::Custom,
|
||||
insert_text_mode: None,
|
||||
// This ensures that when a user accepts this completion, the
|
||||
// completion menu will still be shown after "@category " is
|
||||
// inserted
|
||||
confirm: Some(on_action),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,6 +439,40 @@ impl ContextPickerCompletionProvider {
|
||||
}
|
||||
}
|
||||
|
||||
fn completion_for_rules(
|
||||
rules: RulesContextEntry,
|
||||
excerpt_id: ExcerptId,
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
context_store: Entity<ContextStore>,
|
||||
) -> Completion {
|
||||
let new_text = MentionLink::for_rules(&rules);
|
||||
let new_text_len = new_text.len();
|
||||
Completion {
|
||||
replace_range: source_range.clone(),
|
||||
new_text,
|
||||
label: CodeLabel::plain(rules.title.to_string(), None),
|
||||
documentation: None,
|
||||
insert_text_mode: None,
|
||||
source: project::CompletionSource::Custom,
|
||||
icon_path: Some(RULES_ICON.path().into()),
|
||||
confirm: Some(confirm_completion_callback(
|
||||
RULES_ICON.path().into(),
|
||||
rules.title.clone(),
|
||||
excerpt_id,
|
||||
source_range.start,
|
||||
new_text_len,
|
||||
editor.clone(),
|
||||
move |cx| {
|
||||
let user_prompt_id = rules.prompt_id;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(user_prompt_id, false, cx);
|
||||
});
|
||||
},
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn completion_for_fetch(
|
||||
source_range: Range<Anchor>,
|
||||
url_to_fetch: SharedString,
|
||||
@@ -318,7 +504,7 @@ impl ContextPickerCompletionProvider {
|
||||
let url_to_fetch = url_to_fetch.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if context_store.update(cx, |context_store, _| {
|
||||
context_store.includes_url(&url_to_fetch).is_some()
|
||||
context_store.includes_url(&url_to_fetch)
|
||||
})? {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -394,7 +580,7 @@ impl ContextPickerCompletionProvider {
|
||||
move |cx| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
let task = if is_directory {
|
||||
context_store.add_directory(project_path.clone(), false, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, false, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path.clone(), false, cx)
|
||||
};
|
||||
@@ -534,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
);
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
let search_task = search(
|
||||
mode,
|
||||
query,
|
||||
Arc::<AtomicBool>::default(),
|
||||
recent_entries,
|
||||
prompt_store,
|
||||
thread_store.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -570,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
|
||||
symbol,
|
||||
excerpt_id,
|
||||
@@ -579,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
workspace.clone(),
|
||||
cx,
|
||||
),
|
||||
|
||||
Match::Thread(ThreadMatch {
|
||||
thread, is_recent, ..
|
||||
}) => {
|
||||
@@ -593,6 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Rules(user_rules) => Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
)),
|
||||
|
||||
Match::Fetch(url) => Some(Self::completion_for_fetch(
|
||||
source_range.clone(),
|
||||
url,
|
||||
@@ -601,9 +806,16 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
context_store.clone(),
|
||||
http_client.clone(),
|
||||
)),
|
||||
Match::Mode(ModeMatch { mode, .. }) => {
|
||||
Some(Self::completion_for_mode(source_range.clone(), mode))
|
||||
}
|
||||
|
||||
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
|
||||
entry,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
&workspace,
|
||||
cx,
|
||||
),
|
||||
})
|
||||
.collect()
|
||||
})?))
|
||||
|
||||
@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let added = self.context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store.read(cx).includes_url(&self.url).is_some()
|
||||
context_store.read(cx).includes_url(&self.url)
|
||||
});
|
||||
|
||||
Some(
|
||||
|
||||
@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
if is_directory {
|
||||
context_store.add_directory(project_path, true, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, true, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path, true, cx)
|
||||
context_store.add_file_from_path(project_path.clone(), true, cx)
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
|
||||
path: path.clone(),
|
||||
};
|
||||
if is_directory {
|
||||
context_store.read(cx).includes_directory(&project_path)
|
||||
} else {
|
||||
context_store
|
||||
.read(cx)
|
||||
.will_include_file_path(&project_path, cx)
|
||||
.path_included_in_directory(&project_path, cx)
|
||||
} else {
|
||||
context_store.read(cx).file_path_included(&project_path, cx)
|
||||
}
|
||||
});
|
||||
|
||||
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
|
||||
})),
|
||||
)
|
||||
.when_some(added, |el, added| match added {
|
||||
FileInclusion::Direct(_) => el.child(
|
||||
FileInclusion::Direct => el.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_end()
|
||||
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Added").size(LabelSize::Small)),
|
||||
),
|
||||
FileInclusion::InDirectory(directory_project_path) => {
|
||||
// TODO: Consider using worktree full_path to include worktree name.
|
||||
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
|
||||
FileInclusion::InDirectory { full_path } => {
|
||||
let directory_full_path = full_path.to_string_lossy().into_owned();
|
||||
|
||||
el.child(
|
||||
h_flex()
|
||||
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Included").size(LabelSize::Small)),
|
||||
)
|
||||
.tooltip(Tooltip::text(format!("in {directory_path}")))
|
||||
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
224
crates/agent/src/context_picker/rules_context_picker.rs
Normal file
224
crates/agent/src/context_picker/rules_context_picker.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use prompt_store::{PromptId, PromptStore, UserPromptId};
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{self, ContextStore};
|
||||
|
||||
pub struct RulesContextPicker {
|
||||
picker: Entity<Picker<RulesContextPickerDelegate>>,
|
||||
}
|
||||
|
||||
impl RulesContextPicker {
|
||||
pub fn new(
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
RulesContextPicker { picker }
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for RulesContextPicker {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.picker.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for RulesContextPicker {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
self.picker.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RulesContextEntry {
|
||||
pub prompt_id: UserPromptId,
|
||||
pub title: SharedString,
|
||||
}
|
||||
|
||||
pub struct RulesContextPickerDelegate {
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
matches: Vec<RulesContextEntry>,
|
||||
selected_index: usize,
|
||||
}
|
||||
|
||||
impl RulesContextPickerDelegate {
|
||||
pub fn new(
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
) -> Self {
|
||||
RulesContextPickerDelegate {
|
||||
prompt_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
matches: Vec::new(),
|
||||
selected_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for RulesContextPickerDelegate {
|
||||
type ListItem = ListItem;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.matches.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) {
|
||||
self.selected_index = ix;
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
"Search available rules…".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let search_task = search_rules(
|
||||
query,
|
||||
Arc::new(AtomicBool::default()),
|
||||
&self.prompt_store,
|
||||
cx,
|
||||
);
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let matches = search_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate.matches = matches;
|
||||
this.delegate.selected_index = 0;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
let Some(entry) = self.matches.get(self.selected_index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(entry.prompt_id, true, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
self.context_picker
|
||||
.update(cx, |_, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let thread = &self.matches[ix];
|
||||
|
||||
Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
|
||||
render_thread_context_entry(thread, self.context_store.clone(), cx),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render_thread_context_entry(
|
||||
user_rules: &RulesContextEntry,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store
|
||||
.read(cx)
|
||||
.includes_user_rules(user_rules.prompt_id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.max_w_72()
|
||||
.child(
|
||||
Icon::new(RULES_ICON)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new(user_rules.title.clone()).truncate()),
|
||||
)
|
||||
.when(added, |el| {
|
||||
el.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(
|
||||
Icon::new(IconName::Check)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Success),
|
||||
)
|
||||
.child(Label::new("Added").size(LabelSize::Small)),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn search_rules(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
prompt_store: &Entity<PromptStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<RulesContextEntry>> {
|
||||
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_task
|
||||
.await
|
||||
.into_iter()
|
||||
.flat_map(|metadata| {
|
||||
// Default prompts are filtered out as they are automatically included.
|
||||
if metadata.default {
|
||||
None
|
||||
} else {
|
||||
match metadata.id {
|
||||
PromptId::EditWorkflow => None,
|
||||
PromptId::User { uuid } => Some(RulesContextEntry {
|
||||
prompt_id: uuid,
|
||||
title: metadata.title?,
|
||||
}),
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}
|
||||
@@ -10,7 +10,6 @@ use gpui::{
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{DocumentSymbol, Symbol};
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
|
||||
)
|
||||
})?;
|
||||
|
||||
context_store
|
||||
.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
context_store.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -353,38 +350,13 @@ fn compute_symbol_entries(
|
||||
context_store: &ContextStore,
|
||||
cx: &App,
|
||||
) -> Vec<SymbolEntry> {
|
||||
let mut symbol_entries = Vec::with_capacity(symbols.len());
|
||||
for SymbolMatch { symbol, .. } in symbols {
|
||||
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
|
||||
let is_included = if let Some(symbols_for_path) = symbols_for_path {
|
||||
let mut is_included = false;
|
||||
for included_symbol_id in symbols_for_path {
|
||||
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
|
||||
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let included_symbol_range =
|
||||
included_symbol_id.range.to_point_utf16(&snapshot);
|
||||
|
||||
if included_symbol_range.start == symbol.range.start.0
|
||||
&& included_symbol_range.end == symbol.range.end.0
|
||||
{
|
||||
is_included = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
is_included
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
symbol_entries.push(SymbolEntry {
|
||||
symbols
|
||||
.into_iter()
|
||||
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
|
||||
is_included: context_store.includes_symbol(&symbol, cx),
|
||||
symbol,
|
||||
is_included,
|
||||
})
|
||||
}
|
||||
symbol_entries
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
|
||||
|
||||
@@ -103,11 +103,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let Some(threads) = self.thread_store.upgrade() else {
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return Task::ready(());
|
||||
};
|
||||
|
||||
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx);
|
||||
let search_task = search_threads(query, Arc::new(AtomicBool::default()), thread_store, cx);
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let matches = search_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store.read(cx).includes_thread(&thread.id).is_some()
|
||||
ctx_store.read(cx).includes_thread(&thread.id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -217,15 +217,15 @@ pub(crate) fn search_threads(
|
||||
thread_store: Entity<ThreadStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<ThreadMatch>> {
|
||||
let threads = thread_store.update(cx, |this, _cx| {
|
||||
this.threads()
|
||||
.into_iter()
|
||||
.map(|thread| ThreadContextEntry {
|
||||
id: thread.id,
|
||||
summary: thread.summary,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
let threads = thread_store
|
||||
.read(cx)
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.map(|thread| ThreadContextEntry {
|
||||
id: thread.id,
|
||||
summary: thread.summary,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let executor = cx.background_executor().clone();
|
||||
cx.background_spawn(async move {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,9 @@ use itertools::Itertools;
|
||||
use language::Buffer;
|
||||
use project::ProjectItem;
|
||||
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::{ContextId, ContextKind};
|
||||
use crate::context::{AgentContext, ContextKind};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::Thread;
|
||||
@@ -32,6 +32,7 @@ pub struct ContextStrip {
|
||||
focus_handle: FocusHandle,
|
||||
suggest_context_kind: SuggestContextKind,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
focused_index: Option<usize>,
|
||||
children_bounds: Option<Vec<Bounds<Pixels>>>,
|
||||
@@ -73,12 +74,31 @@ impl ContextStrip {
|
||||
focus_handle,
|
||||
suggest_context_kind,
|
||||
workspace,
|
||||
thread_store,
|
||||
_subscriptions: subscriptions,
|
||||
focused_index: None,
|
||||
children_bounds: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
let prompt_store = self
|
||||
.thread_store
|
||||
.as_ref()
|
||||
.and_then(|thread_store| thread_store.upgrade())
|
||||
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
match self.suggest_context_kind {
|
||||
SuggestContextKind::File => self.suggested_file(cx),
|
||||
@@ -93,22 +113,19 @@ impl ContextStrip {
|
||||
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
|
||||
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
|
||||
let active_buffer = active_buffer_entity.read(cx);
|
||||
|
||||
let project_path = active_buffer.project_path(cx)?;
|
||||
|
||||
if self
|
||||
.context_store
|
||||
.read(cx)
|
||||
.will_include_buffer(active_buffer.remote_id(), &project_path)
|
||||
.file_path_included(&project_path, cx)
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let file_name = active_buffer.file()?.file_name(cx);
|
||||
|
||||
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
|
||||
|
||||
Some(SuggestedContext::File {
|
||||
name: file_name.to_string_lossy().into_owned().into(),
|
||||
buffer: active_buffer_entity.downgrade(),
|
||||
@@ -135,7 +152,6 @@ impl ContextStrip {
|
||||
.context_store
|
||||
.read(cx)
|
||||
.includes_thread(active_thread.id())
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -272,12 +288,12 @@ impl ContextStrip {
|
||||
best.map(|(index, _, _)| index)
|
||||
}
|
||||
|
||||
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
|
||||
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
|
||||
crate::active_thread::open_context(context, workspace, window, cx);
|
||||
}
|
||||
|
||||
fn remove_focused_context(
|
||||
@@ -287,17 +303,17 @@ impl ContextStrip {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(index) = self.focused_index {
|
||||
let mut is_empty = false;
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let Some(context) = added_contexts.get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.context_store.update(cx, |this, cx| {
|
||||
if let Some(item) = this.context().get(index) {
|
||||
this.remove_context(item.id(), cx);
|
||||
}
|
||||
|
||||
is_empty = this.context().is_empty();
|
||||
this.remove_context(&context.context, cx);
|
||||
});
|
||||
|
||||
if is_empty {
|
||||
let is_now_empty = added_contexts.len() == 1;
|
||||
if is_now_empty {
|
||||
cx.emit(ContextStripEvent::BlurredEmpty);
|
||||
} else {
|
||||
self.focused_index = Some(index.saturating_sub(1));
|
||||
@@ -306,49 +322,28 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
|
||||
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
|
||||
// We only suggest one item after the actual context
|
||||
self.focused_index == Some(context.len())
|
||||
self.focused_index == Some(added_contexts.len())
|
||||
}
|
||||
|
||||
fn accept_suggested_context(
|
||||
&mut self,
|
||||
_: &AcceptSuggestedContext,
|
||||
window: &mut Window,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(suggested) = self.suggested_context(cx) {
|
||||
let context_store = self.context_store.read(cx);
|
||||
|
||||
if self.is_suggested_focused(context_store.context()) {
|
||||
self.add_suggested_context(&suggested, window, cx);
|
||||
if self.is_suggested_focused(&self.added_contexts(cx)) {
|
||||
self.add_suggested_context(&suggested, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_suggested_context(
|
||||
&mut self,
|
||||
suggested: &SuggestedContext,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let task = self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.accept_suggested_context(&suggested, cx)
|
||||
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
|
||||
self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_suggested_context(&suggested, cx)
|
||||
});
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
match task.await.notify_async_err(cx) {
|
||||
None => {}
|
||||
Some(()) => {
|
||||
if let Some(this) = this.upgrade() {
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
@@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
|
||||
|
||||
impl Render for ContextStrip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let context_store = self.context_store.read(cx);
|
||||
let context = context_store.context();
|
||||
let context_picker = self.context_picker.clone();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let suggested_context = self.suggested_context(cx);
|
||||
|
||||
let added_contexts = context
|
||||
.iter()
|
||||
.map(|c| AddedContext::new(c, cx))
|
||||
.collect::<Vec<_>>();
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let dupe_names = added_contexts
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
@@ -380,6 +368,14 @@ impl Render for ContextStrip {
|
||||
.filter(|(a, b)| a == b)
|
||||
.map(|(a, _)| a)
|
||||
.collect::<HashSet<SharedString>>();
|
||||
let no_added_context = added_contexts.is_empty();
|
||||
|
||||
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
|
||||
(
|
||||
suggested_context,
|
||||
self.is_suggested_focused(&added_contexts),
|
||||
)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
.flex_wrap()
|
||||
@@ -436,7 +432,7 @@ impl Render for ContextStrip {
|
||||
})
|
||||
.with_handle(self.context_picker_menu_handle.clone()),
|
||||
)
|
||||
.when(context.is_empty() && suggested_context.is_none(), {
|
||||
.when(no_added_context && suggested_context.is_none(), {
|
||||
|parent| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
@@ -466,16 +462,17 @@ impl Render for ContextStrip {
|
||||
.enumerate()
|
||||
.map(|(i, added_context)| {
|
||||
let name = added_context.name.clone();
|
||||
let id = added_context.id;
|
||||
let context = added_context.context.clone();
|
||||
ContextPill::added(
|
||||
added_context,
|
||||
dupe_names.contains(&name),
|
||||
self.focused_index == Some(i),
|
||||
Some({
|
||||
let context = context.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
Rc::new(cx.listener(move |_this, _event, _window, cx| {
|
||||
context_store.update(cx, |this, cx| {
|
||||
this.remove_context(id, cx);
|
||||
this.remove_context(&context, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}))
|
||||
@@ -484,7 +481,7 @@ impl Render for ContextStrip {
|
||||
.on_click({
|
||||
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count > 1 {
|
||||
this.open_context(id, window, cx);
|
||||
this.open_context(&context, window, cx);
|
||||
} else {
|
||||
this.focused_index = Some(i);
|
||||
}
|
||||
@@ -493,22 +490,22 @@ impl Render for ContextStrip {
|
||||
})
|
||||
}),
|
||||
)
|
||||
.when_some(suggested_context, |el, suggested| {
|
||||
.when_some(suggested_context, |el, (suggested, focused)| {
|
||||
el.child(
|
||||
ContextPill::suggested(
|
||||
suggested.name().clone(),
|
||||
suggested.icon_path(),
|
||||
suggested.kind(),
|
||||
self.is_suggested_focused(&context),
|
||||
focused,
|
||||
)
|
||||
.on_click(Rc::new(cx.listener(
|
||||
move |this, _event, window, cx| {
|
||||
this.add_suggested_context(&suggested, window, cx);
|
||||
move |this, _event, _window, cx| {
|
||||
this.add_suggested_context(&suggested, cx);
|
||||
},
|
||||
))),
|
||||
)
|
||||
})
|
||||
.when(!context.is_empty(), {
|
||||
.when(!no_added_context, {
|
||||
move |parent| {
|
||||
parent.child(
|
||||
IconButton::new("remove-all-context", IconName::Eraser)
|
||||
@@ -534,6 +531,7 @@ impl Render for ContextStrip {
|
||||
)
|
||||
}
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,10 @@ impl HistoryStore {
|
||||
return history_entries;
|
||||
}
|
||||
|
||||
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
|
||||
for thread in self
|
||||
.thread_store
|
||||
.update(cx, |this, _cx| this.reverse_chronological_threads())
|
||||
{
|
||||
history_entries.push(HistoryEntry::Thread(thread));
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ use project::LspAction;
|
||||
use project::Project;
|
||||
use project::{CodeAction, ProjectTransaction};
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
|
||||
@@ -245,9 +246,13 @@ impl InlineAssistant {
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let thread_store = workspace
|
||||
let assistant_panel = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
|
||||
.map(|assistant_panel| assistant_panel.read(cx));
|
||||
let prompt_store = assistant_panel
|
||||
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
|
||||
let thread_store =
|
||||
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
|
||||
|
||||
let handle_assist =
|
||||
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
|
||||
@@ -257,6 +262,7 @@ impl InlineAssistant {
|
||||
&active_editor,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -269,6 +275,7 @@ impl InlineAssistant {
|
||||
&active_terminal,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -323,6 +330,7 @@ impl InlineAssistant {
|
||||
editor: &Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -437,6 +445,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
None,
|
||||
context_store.clone(),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -525,6 +535,7 @@ impl InlineAssistant {
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
focus: bool,
|
||||
workspace: Entity<Workspace>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -543,7 +554,7 @@ impl InlineAssistant {
|
||||
}
|
||||
|
||||
let project = workspace.read(cx).project().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
|
||||
|
||||
let codegen = cx.new(|cx| {
|
||||
BufferCodegen::new(
|
||||
@@ -551,6 +562,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
initial_transaction_id,
|
||||
context_store.clone(),
|
||||
project,
|
||||
prompt_store,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -1328,7 +1341,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<InlineAssist>(
|
||||
row_range,
|
||||
cx.theme().status().info_background,
|
||||
false,
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1393,7 +1406,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<DeletedLines>(
|
||||
Anchor::min()..Anchor::max(),
|
||||
cx.theme().status().deleted_background,
|
||||
false,
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
let editor = self.editor.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let workspace = workspace.upgrade().context("workspace was released")?;
|
||||
let editor = editor.upgrade().context("editor was released")?;
|
||||
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
})?
|
||||
.context("invalid range")?;
|
||||
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
|
||||
let assist_id = assistant.suggest_assist(
|
||||
&editor,
|
||||
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
None,
|
||||
true,
|
||||
workspace,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -13,7 +13,7 @@ use editor::{
|
||||
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
|
||||
actions::{MoveDown, MoveUp},
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedPro};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
|
||||
@@ -132,7 +132,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -931,7 +931,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
|
||||
@@ -2,26 +2,29 @@ use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::ModelType;
|
||||
use crate::context::{AssistantContext, format_context_as_string};
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
use editor::actions::MoveUp;
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::{
|
||||
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
|
||||
EditorStyle, MultiBuffer,
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
|
||||
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
|
||||
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, Language};
|
||||
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
@@ -31,7 +34,7 @@ use workspace::Workspace;
|
||||
|
||||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
|
||||
use crate::context_store::{ContextStore, refresh_context_store_text};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
@@ -49,9 +52,12 @@ pub struct MessageEditor {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
last_loaded_context: Option<ContextLoadResult>,
|
||||
context_load_task: Option<Shared<Task<()>>>,
|
||||
profile_selector: Entity<ProfileSelector>,
|
||||
edits_expanded: bool,
|
||||
editor_is_expanded: bool,
|
||||
@@ -68,6 +74,7 @@ impl MessageEditor {
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
window: &mut Window,
|
||||
@@ -135,13 +142,11 @@ impl MessageEditor {
|
||||
let subscriptions = vec![
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
|
||||
cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
EditorEvent::BufferEdited => {
|
||||
this.message_or_context_changed(true, cx);
|
||||
}
|
||||
EditorEvent::BufferEdited => this.handle_message_changed(cx),
|
||||
_ => {}
|
||||
}),
|
||||
cx.observe(&context_store, |this, _, cx| {
|
||||
this.message_or_context_changed(false, cx);
|
||||
let _ = this.start_context_load(cx);
|
||||
}),
|
||||
];
|
||||
|
||||
@@ -152,8 +157,11 @@ impl MessageEditor {
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
context_store,
|
||||
prompt_store,
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
context_load_task: None,
|
||||
last_loaded_context: None,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
fs.clone(),
|
||||
@@ -175,6 +183,10 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_store(&self) -> &Entity<ContextStore> {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
@@ -195,6 +207,7 @@ impl MessageEditor {
|
||||
editor.set_mode(EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: false,
|
||||
})
|
||||
} else {
|
||||
editor.set_mode(EditorMode::AutoHeight {
|
||||
@@ -213,6 +226,7 @@ impl MessageEditor {
|
||||
) {
|
||||
self.context_picker_menu_handle.toggle(window, cx);
|
||||
}
|
||||
|
||||
pub fn remove_all_context(
|
||||
&mut self,
|
||||
_: &RemoveAllContext,
|
||||
@@ -269,49 +283,44 @@ impl MessageEditor {
|
||||
self.last_estimated_token_count.take();
|
||||
cx.emit(MessageEditorEvent::EstimatedTokenCount);
|
||||
|
||||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let context_task = self.load_context(cx);
|
||||
let window_handle = window.window_handle();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
|
||||
let loaded_context = loaded_context.unwrap_or_default();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let context = context_store.read(cx).context().clone();
|
||||
thread.insert_user_message(user_message, context, checkpoint, cx);
|
||||
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
let excerpt_ids = context_store
|
||||
.context()
|
||||
.iter()
|
||||
.filter(|ctx| matches!(ctx, AssistantContext::Excerpt(_)))
|
||||
.map(|ctx| ctx.id())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for id in excerpt_ids {
|
||||
context_store.remove_context(id, cx);
|
||||
}
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model, Some(window_handle), cx);
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
|
||||
let context_store = self.context_store.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(wait_for_summaries) = context_store
|
||||
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
|
||||
.log_err()
|
||||
.ok()
|
||||
{
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = true;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
.ok();
|
||||
|
||||
wait_for_summaries.await;
|
||||
|
||||
@@ -319,24 +328,15 @@ impl MessageEditor {
|
||||
this.waiting_for_summaries_to_send = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
.ok();
|
||||
}
|
||||
|
||||
// Send to model after summaries are done
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model, cx);
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cancelled = self
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
let cancelled = self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
});
|
||||
|
||||
if cancelled {
|
||||
self.set_editor_is_expanded(false, cx);
|
||||
@@ -370,8 +370,38 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_review_click(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn paste(&mut self, _: &Paste, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let images = cx
|
||||
.read_from_clipboard()
|
||||
.map(|item| {
|
||||
item.into_entries()
|
||||
.filter_map(|entry| {
|
||||
if let ClipboardEntry::Image(image) = entry {
|
||||
Some(image)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if images.is_empty() {
|
||||
return;
|
||||
}
|
||||
cx.stop_propagation();
|
||||
|
||||
self.context_store.update(cx, |store, cx| {
|
||||
for image in images {
|
||||
store.add_image(Arc::new(image), cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.edits_expanded = true;
|
||||
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_file_click(
|
||||
@@ -445,6 +475,7 @@ impl MessageEditor {
|
||||
.on_action(cx.listener(Self::move_up))
|
||||
.on_action(cx.listener(Self::toggle_chat_mode))
|
||||
.on_action(cx.listener(Self::expand_message_editor))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.gap_2()
|
||||
.p_2()
|
||||
.bg(editor_bg_color)
|
||||
@@ -975,6 +1006,48 @@ impl MessageEditor {
|
||||
self.update_token_count_task.is_some()
|
||||
}
|
||||
|
||||
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
|
||||
self.message_or_context_changed(true, cx);
|
||||
}
|
||||
|
||||
fn start_context_load(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
|
||||
let summaries_task = self.wait_for_summaries(cx);
|
||||
let load_task = cx.spawn(async move |this, cx| {
|
||||
// Waits for detailed summaries before `load_context`, as it directly reads these from
|
||||
// the thread. TODO: Would be cleaner to have context loading await on summarization.
|
||||
summaries_task.await;
|
||||
let Ok(load_task) = this.update(cx, |this, cx| {
|
||||
let new_context = this.context_store.read_with(cx, |context_store, cx| {
|
||||
context_store.new_context_for_thread(this.thread.read(cx))
|
||||
});
|
||||
load_context(new_context, &this.project, &this.prompt_store, cx)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let result = load_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.last_loaded_context = Some(result);
|
||||
this.context_load_task = None;
|
||||
this.message_or_context_changed(false, cx);
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
// Replace existing load task, if any, causing it to be cancelled.
|
||||
let load_task = load_task.shared();
|
||||
self.context_load_task = Some(load_task.clone());
|
||||
load_task
|
||||
}
|
||||
|
||||
fn load_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
|
||||
let context_load_task = self.start_context_load(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
context_load_task.await;
|
||||
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Changed);
|
||||
self.update_token_count_task.take();
|
||||
@@ -984,9 +1057,7 @@ impl MessageEditor {
|
||||
return;
|
||||
};
|
||||
|
||||
let context_store = self.context_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let thread = self.thread.clone();
|
||||
|
||||
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
|
||||
if debounce {
|
||||
@@ -995,27 +1066,33 @@ impl MessageEditor {
|
||||
.await;
|
||||
}
|
||||
|
||||
let token_count = if let Some(task) = cx.update(|cx| {
|
||||
let context = context_store.read(cx).context().iter();
|
||||
let new_context = thread.read(cx).filter_new_context(context);
|
||||
let context_text =
|
||||
format_context_as_string(new_context, cx).unwrap_or(String::new());
|
||||
let token_count = if let Some(task) = this.update(cx, |this, cx| {
|
||||
let loaded_context = this
|
||||
.last_loaded_context
|
||||
.as_ref()
|
||||
.map(|context_load_result| &context_load_result.loaded_context);
|
||||
let message_text = editor.read(cx).text(cx);
|
||||
|
||||
let content = context_text + &message_text;
|
||||
|
||||
if content.is_empty() {
|
||||
if message_text.is_empty()
|
||||
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
if let Some(loaded_context) = loaded_context {
|
||||
loaded_context.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
let request = language_model::LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![content.into()],
|
||||
cache: false,
|
||||
}],
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
|
||||
@@ -32,7 +32,7 @@ impl TerminalCodegen {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
@@ -45,6 +45,7 @@ impl TerminalCodegen {
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
self.generation = cx.spawn(async move |this, cx| {
|
||||
let prompt = prompt_task.await;
|
||||
let model_telemetry_id = model.telemetry_id();
|
||||
let model_provider_id = model.provider_id();
|
||||
let response = model.stream_completion_text(prompt, &cx).await;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context::load_context;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::{
|
||||
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
|
||||
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::{PromptBuilder, PromptStore};
|
||||
use std::sync::Arc;
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::TerminalView;
|
||||
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
|
||||
terminal_view: &Entity<TerminalView>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
|
||||
prompt_editor,
|
||||
workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
|
||||
.log_err();
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
return;
|
||||
};
|
||||
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
|
||||
}
|
||||
|
||||
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
|
||||
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
|
||||
&self,
|
||||
assist_id: TerminalInlineAssistId,
|
||||
cx: &mut App,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||
|
||||
let shell = std::env::var("SHELL").ok();
|
||||
@@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
|
||||
&latest_output,
|
||||
)?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
let contexts = assist
|
||||
.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project();
|
||||
load_context(contexts, project, &assist.prompt_store, cx)
|
||||
})?;
|
||||
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
assist.context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
Ok(cx.background_spawn(async move {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
context_load_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn finish_assist(
|
||||
@@ -380,6 +394,7 @@ struct TerminalInlineAssist {
|
||||
codegen: Entity<TerminalCodegen>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -390,6 +405,7 @@ impl TerminalInlineAssist {
|
||||
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
@@ -400,6 +416,7 @@ impl TerminalInlineAssist {
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
_subscriptions: vec![
|
||||
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
|
||||
TerminalInlineAssistant::update_global(cx, |this, cx| {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{
|
||||
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
|
||||
PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
|
||||
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
|
||||
UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
@@ -62,6 +62,7 @@ pub struct ThreadStore {
|
||||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
@@ -81,12 +82,11 @@ impl ThreadStore {
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
let (thread_store, ready_rx) = cx.update(|cx| {
|
||||
let mut option_ready_rx = None;
|
||||
let thread_store = cx.new(|cx| {
|
||||
@@ -135,6 +135,7 @@ impl ThreadStore {
|
||||
let (ready_tx, ready_rx) = oneshot::channel();
|
||||
let mut ready_tx = Some(ready_tx);
|
||||
let reload_system_prompt_task = cx.spawn({
|
||||
let prompt_store = prompt_store.clone();
|
||||
async move |thread_store, cx| {
|
||||
loop {
|
||||
let Some(reload_task) = thread_store
|
||||
@@ -158,6 +159,7 @@ impl ThreadStore {
|
||||
project,
|
||||
tools,
|
||||
prompt_builder,
|
||||
prompt_store,
|
||||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
@@ -245,7 +247,7 @@ impl ThreadStore {
|
||||
let default_user_rules = default_user_rules
|
||||
.into_iter()
|
||||
.flat_map(|(contents, prompt_metadata)| match contents {
|
||||
Ok(contents) => Some(DefaultUserRulesContext {
|
||||
Ok(contents) => Some(UserRulesContext {
|
||||
uuid: match prompt_metadata.id {
|
||||
PromptId::User { uuid } => uuid,
|
||||
PromptId::EditWorkflow => return None,
|
||||
@@ -346,6 +348,10 @@ impl ThreadStore {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Entity<ToolWorkingSet> {
|
||||
self.tools.clone()
|
||||
}
|
||||
@@ -355,16 +361,12 @@ impl ThreadStore {
|
||||
self.threads.len()
|
||||
}
|
||||
|
||||
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
|
||||
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
|
||||
threads
|
||||
}
|
||||
|
||||
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
|
||||
self.threads().into_iter().take(limit).collect()
|
||||
}
|
||||
|
||||
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
||||
cx.new(|cx| {
|
||||
Thread::new(
|
||||
@@ -615,12 +617,17 @@ pub struct SerializedThread {
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThreadV0_1_0::VERSION => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
@@ -642,6 +649,38 @@ impl SerializedThread {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedThreadV0_1_0(
|
||||
// The structure did not change, so we are reusing the latest SerializedThread.
|
||||
// When making the next version, make sure this points to SerializedThreadV0_2_0
|
||||
SerializedThread,
|
||||
);
|
||||
|
||||
impl SerializedThreadV0_1_0 {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
|
||||
|
||||
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
|
||||
|
||||
for message in self.0.messages {
|
||||
if message.role == Role::User && !message.tool_results.is_empty() {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
|
||||
last_message.tool_results = message.tool_results;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
SerializedThread { messages, ..self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
|
||||
@@ -7,13 +7,13 @@ use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
};
|
||||
use ui::IconName;
|
||||
use util::truncate_lines_to_byte_limit;
|
||||
|
||||
use crate::thread::MessageId;
|
||||
use crate::thread::{MessageId, PromptId, ThreadId};
|
||||
use crate::thread_store::SerializedMessage;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -27,15 +27,13 @@ pub struct ToolUse {
|
||||
pub needs_confirmation: bool,
|
||||
}
|
||||
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
|
||||
}
|
||||
|
||||
impl ToolUseState {
|
||||
@@ -43,10 +41,10 @@ impl ToolUseState {
|
||||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
tool_use_metadata_by_id: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +54,6 @@ impl ToolUseState {
|
||||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
@@ -68,11 +65,12 @@ impl ToolUseState {
|
||||
let tool_uses = message
|
||||
.tool_uses
|
||||
.iter()
|
||||
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||
.map(|tool_use| LanguageModelToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
raw_input: tool_use.input.to_string(),
|
||||
input: tool_use.input.clone(),
|
||||
is_input_complete: true,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -84,14 +82,6 @@ impl ToolUseState {
|
||||
|
||||
this.tool_uses_by_assistant_message
|
||||
.insert(message.id, tool_uses);
|
||||
}
|
||||
}
|
||||
Role::User => {
|
||||
if !message.tool_results.is_empty() {
|
||||
let tool_uses_by_user_message = this
|
||||
.tool_uses_by_user_message
|
||||
.entry(message.id)
|
||||
.or_default();
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
let tool_use_id = tool_result.tool_use_id.clone();
|
||||
@@ -100,11 +90,6 @@ impl ToolUseState {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||
this.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
@@ -117,7 +102,7 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System => {}
|
||||
Role::System | Role::User => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,6 +159,9 @@ impl ToolUseState {
|
||||
PendingToolUseStatus::Error(ref err) => {
|
||||
ToolUseStatus::Error(err.clone().into())
|
||||
}
|
||||
PendingToolUseStatus::InputStillStreaming => {
|
||||
ToolUseStatus::InputStillStreaming
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ToolUseStatus::Pending
|
||||
@@ -190,7 +178,12 @@ impl ToolUseState {
|
||||
tool_uses.push(ToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
|
||||
ui_text: self.tool_ui_label(
|
||||
&tool_use.name,
|
||||
&tool_use.input,
|
||||
tool_use.is_input_complete,
|
||||
cx,
|
||||
),
|
||||
input: tool_use.input.clone(),
|
||||
status,
|
||||
icon,
|
||||
@@ -205,29 +198,40 @@ impl ToolUseState {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
input: &serde_json::Value,
|
||||
is_input_complete: bool,
|
||||
cx: &App,
|
||||
) -> SharedString {
|
||||
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
|
||||
tool.ui_text(input).into()
|
||||
if is_input_complete {
|
||||
tool.ui_text(input).into()
|
||||
} else {
|
||||
tool.still_streaming_ui_text(input).into()
|
||||
}
|
||||
} else {
|
||||
format!("Unknown tool {tool_name:?}").into()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
let empty = Vec::new();
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
let Some(tool_uses) = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.unwrap_or(&empty)
|
||||
tool_uses
|
||||
.iter()
|
||||
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
@@ -254,20 +258,44 @@ impl ToolUseState {
|
||||
&mut self,
|
||||
assistant_message_id: MessageId,
|
||||
tool_use: LanguageModelToolUse,
|
||||
metadata: ToolUseMetadata,
|
||||
cx: &App,
|
||||
) {
|
||||
self.tool_uses_by_assistant_message
|
||||
) -> Arc<str> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.entry(assistant_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.clone());
|
||||
.or_default();
|
||||
|
||||
// The tool use is being requested by the Assistant, so we want to
|
||||
// attach the tool results to the next user message.
|
||||
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||
self.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
let mut existing_tool_use_found = false;
|
||||
|
||||
for existing_tool_use in tool_uses.iter_mut() {
|
||||
if existing_tool_use.id == tool_use.id {
|
||||
*existing_tool_use = tool_use.clone();
|
||||
existing_tool_use_found = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !existing_tool_use_found {
|
||||
tool_uses.push(tool_use.clone());
|
||||
}
|
||||
|
||||
let status = if tool_use.is_input_complete {
|
||||
self.tool_use_metadata_by_id
|
||||
.insert(tool_use.id.clone(), metadata);
|
||||
|
||||
PendingToolUseStatus::Idle
|
||||
} else {
|
||||
PendingToolUseStatus::InputStillStreaming
|
||||
};
|
||||
|
||||
let ui_text: Arc<str> = self
|
||||
.tool_ui_label(
|
||||
&tool_use.name,
|
||||
&tool_use.input,
|
||||
tool_use.is_input_complete,
|
||||
cx,
|
||||
)
|
||||
.into();
|
||||
|
||||
self.pending_tool_uses_by_id.insert(
|
||||
tool_use.id.clone(),
|
||||
@@ -275,13 +303,13 @@ impl ToolUseState {
|
||||
assistant_message_id,
|
||||
id: tool_use.id,
|
||||
name: tool_use.name.clone(),
|
||||
ui_text: self
|
||||
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
|
||||
.into(),
|
||||
ui_text: ui_text.clone(),
|
||||
input: tool_use.input,
|
||||
status: PendingToolUseStatus::Idle,
|
||||
status,
|
||||
},
|
||||
);
|
||||
|
||||
ui_text
|
||||
}
|
||||
|
||||
pub fn run_pending_tool(
|
||||
@@ -327,7 +355,21 @@ impl ToolUseState {
|
||||
output: Result<String>,
|
||||
cx: &App,
|
||||
) -> Option<PendingToolUse> {
|
||||
telemetry::event!("Agent Tool Finished", tool_name, success = output.is_ok());
|
||||
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
|
||||
|
||||
telemetry::event!(
|
||||
"Agent Tool Finished",
|
||||
model = metadata
|
||||
.as_ref()
|
||||
.map(|metadata| metadata.model.telemetry_id()),
|
||||
model_provider = metadata
|
||||
.as_ref()
|
||||
.map(|metadata| metadata.model.provider_id().to_string()),
|
||||
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
|
||||
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
|
||||
tool_name,
|
||||
success = output.is_ok()
|
||||
);
|
||||
|
||||
match output {
|
||||
Ok(tool_result) => {
|
||||
@@ -390,28 +432,8 @@ impl ToolUseState {
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
|
||||
let mut found_tool_use = false;
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if self.tool_results.contains_key(&tool_use.id) {
|
||||
if !found_tool_use {
|
||||
// The API fails if a message contains a tool use without any (non-whitespace) text around it
|
||||
match request_message.content.last_mut() {
|
||||
Some(MessageContent::Text(txt)) => {
|
||||
if txt.is_empty() {
|
||||
txt.push_str(USING_TOOL_MARKER);
|
||||
}
|
||||
}
|
||||
None | Some(_) => {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(USING_TOOL_MARKER.into()));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
found_tool_use = true;
|
||||
|
||||
// Do not send tool uses until they are completed
|
||||
request_message
|
||||
.content
|
||||
@@ -426,31 +448,49 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message.content.push(MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,6 +517,7 @@ pub struct Confirmation {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PendingToolUseStatus {
|
||||
InputStillStreaming,
|
||||
Idle,
|
||||
NeedsConfirmation(Arc<Confirmation>),
|
||||
Running { _task: Shared<Task<()>> },
|
||||
@@ -496,3 +537,10 @@ impl PendingToolUseStatus {
|
||||
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolUseMetadata {
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
pub thread_id: ThreadId,
|
||||
pub prompt_id: PromptId,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
mod agent_notification;
|
||||
mod animated_label;
|
||||
mod context_pill;
|
||||
mod usage_banner;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use animated_label::*;
|
||||
pub use context_pill::*;
|
||||
pub use usage_banner::*;
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct AgentNotification {
|
||||
title: SharedString,
|
||||
caption: SharedString,
|
||||
icon: IconName,
|
||||
project_name: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl AgentNotification {
|
||||
@@ -19,11 +20,13 @@ impl AgentNotification {
|
||||
title: impl Into<SharedString>,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
project_name: Option<impl Into<SharedString>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
title: title.into(),
|
||||
caption: caption.into(),
|
||||
icon,
|
||||
project_name: project_name.map(|name| name.into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,11 +133,34 @@ impl Render for AgentNotification {
|
||||
.child(gradient_overflow()),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
h_flex()
|
||||
.relative()
|
||||
.gap_1p5()
|
||||
.text_size(px(12.))
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.truncate()
|
||||
.when_some(
|
||||
self.project_name.clone(),
|
||||
|description, project_name| {
|
||||
description.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
div()
|
||||
.max_w_16()
|
||||
.truncate()
|
||||
.child(project_name),
|
||||
)
|
||||
.child(
|
||||
div().size(px(3.)).rounded_full().bg(cx
|
||||
.theme()
|
||||
.colors()
|
||||
.text
|
||||
.opacity(0.5)),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
.child(self.caption.clone())
|
||||
.child(gradient_overflow()),
|
||||
),
|
||||
|
||||
121
crates/agent/src/ui/animated_label.rs
Normal file
121
crates/agent/src/ui/animated_label.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
|
||||
use std::time::Duration;
|
||||
use ui::prelude::*;
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct AnimatedLabel {
|
||||
base: Label,
|
||||
text: SharedString,
|
||||
}
|
||||
|
||||
impl AnimatedLabel {
|
||||
pub fn new(text: impl Into<SharedString>) -> Self {
|
||||
let text = text.into();
|
||||
AnimatedLabel {
|
||||
base: Label::new(text.clone()),
|
||||
text,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelCommon for AnimatedLabel {
|
||||
fn size(mut self, size: LabelSize) -> Self {
|
||||
self.base = self.base.size(size);
|
||||
self
|
||||
}
|
||||
|
||||
fn weight(mut self, weight: FontWeight) -> Self {
|
||||
self.base = self.base.weight(weight);
|
||||
self
|
||||
}
|
||||
|
||||
fn line_height_style(mut self, line_height_style: LineHeightStyle) -> Self {
|
||||
self.base = self.base.line_height_style(line_height_style);
|
||||
self
|
||||
}
|
||||
|
||||
fn color(mut self, color: Color) -> Self {
|
||||
self.base = self.base.color(color);
|
||||
self
|
||||
}
|
||||
|
||||
fn strikethrough(mut self) -> Self {
|
||||
self.base = self.base.strikethrough();
|
||||
self
|
||||
}
|
||||
|
||||
fn italic(mut self) -> Self {
|
||||
self.base = self.base.italic();
|
||||
self
|
||||
}
|
||||
|
||||
fn alpha(mut self, alpha: f32) -> Self {
|
||||
self.base = self.base.alpha(alpha);
|
||||
self
|
||||
}
|
||||
|
||||
fn underline(mut self) -> Self {
|
||||
self.base = self.base.underline();
|
||||
self
|
||||
}
|
||||
|
||||
fn truncate(mut self) -> Self {
|
||||
self.base = self.base.truncate();
|
||||
self
|
||||
}
|
||||
|
||||
fn single_line(mut self) -> Self {
|
||||
self.base = self.base.single_line();
|
||||
self
|
||||
}
|
||||
|
||||
fn buffer_font(mut self, cx: &App) -> Self {
|
||||
self.base = self.base.buffer_font(cx);
|
||||
self
|
||||
}
|
||||
|
||||
fn inline_code(mut self, cx: &App) -> Self {
|
||||
self.base = self.base.inline_code(cx);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for AnimatedLabel {
|
||||
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
let text = self.text.clone();
|
||||
|
||||
self.base
|
||||
.color(Color::Muted)
|
||||
.with_animations(
|
||||
"animated-label",
|
||||
vec![
|
||||
Animation::new(Duration::from_secs(1)),
|
||||
Animation::new(Duration::from_secs(1)).repeat(),
|
||||
],
|
||||
move |mut label, animation_ix, delta| {
|
||||
match animation_ix {
|
||||
0 => {
|
||||
let chars_to_show = (delta * text.len() as f32).ceil() as usize;
|
||||
let text = SharedString::from(text[0..chars_to_show].to_string());
|
||||
label.set_text(text);
|
||||
}
|
||||
1 => match delta {
|
||||
d if d < 0.25 => label.set_text(text.clone()),
|
||||
d if d < 0.5 => label.set_text(format!("{}.", text)),
|
||||
d if d < 0.75 => label.set_text(format!("{}..", text)),
|
||||
_ => label.set_text(format!("{}...", text)),
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
label
|
||||
},
|
||||
)
|
||||
.with_animation(
|
||||
"pulsating-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
|label, delta| label.map_element(|label| label.alpha(delta)),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
use std::{rc::Rc, time::Duration};
|
||||
|
||||
use file_icons::FileIcons;
|
||||
use gpui::ClickEvent;
|
||||
use gpui::{Animation, AnimationExt as _, pulsating_between};
|
||||
use ui::{IconButtonShape, Tooltip, prelude::*};
|
||||
use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, ContextKind};
|
||||
use crate::context::{AgentContext, ContextKind, ImageStatus};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub enum ContextPill {
|
||||
@@ -70,9 +72,7 @@ impl ContextPill {
|
||||
|
||||
pub fn id(&self) -> ElementId {
|
||||
match self {
|
||||
Self::Added { context, .. } => {
|
||||
ElementId::NamedInteger("context-pill".into(), context.id.0)
|
||||
}
|
||||
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
|
||||
Self::Suggested { .. } => "suggested-context-pill".into(),
|
||||
}
|
||||
}
|
||||
@@ -120,43 +120,86 @@ impl RenderOnce for ContextPill {
|
||||
on_remove,
|
||||
focused,
|
||||
on_click,
|
||||
} => base_pill
|
||||
.bg(color.element_background)
|
||||
.border_color(if *focused {
|
||||
color.border_focused
|
||||
} else {
|
||||
color.border.opacity(0.5)
|
||||
})
|
||||
.pr(if on_remove.is_some() { px(2.) } else { px(4.) })
|
||||
.child(
|
||||
h_flex()
|
||||
.id("context-data")
|
||||
.gap_1()
|
||||
.child(
|
||||
div().max_w_64().child(
|
||||
Label::new(context.name.clone())
|
||||
.size(LabelSize::Small)
|
||||
.truncate(),
|
||||
),
|
||||
)
|
||||
.when_some(context.parent.as_ref(), |element, parent_name| {
|
||||
if *dupe_name {
|
||||
element.child(
|
||||
Label::new(parent_name.clone())
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
} else {
|
||||
element
|
||||
}
|
||||
})
|
||||
.when_some(context.tooltip.as_ref(), |element, tooltip| {
|
||||
element.tooltip(Tooltip::text(tooltip.clone()))
|
||||
}),
|
||||
)
|
||||
.when_some(on_remove.as_ref(), |element, on_remove| {
|
||||
element.child(
|
||||
IconButton::new(("remove", context.id.0), IconName::Close)
|
||||
} => {
|
||||
let status_is_error = matches!(context.status, ContextStatus::Error { .. });
|
||||
|
||||
base_pill
|
||||
.pr(if on_remove.is_some() { px(2.) } else { px(4.) })
|
||||
.map(|pill| {
|
||||
if status_is_error {
|
||||
pill.bg(cx.theme().status().error_background)
|
||||
.border_color(cx.theme().status().error_border)
|
||||
} else if *focused {
|
||||
pill.bg(color.element_background)
|
||||
.border_color(color.border_focused)
|
||||
} else {
|
||||
pill.bg(color.element_background)
|
||||
.border_color(color.border.opacity(0.5))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.id("context-data")
|
||||
.gap_1()
|
||||
.child(
|
||||
div().max_w_64().child(
|
||||
Label::new(context.name.clone())
|
||||
.size(LabelSize::Small)
|
||||
.truncate(),
|
||||
),
|
||||
)
|
||||
.when_some(context.parent.as_ref(), |element, parent_name| {
|
||||
if *dupe_name {
|
||||
element.child(
|
||||
Label::new(parent_name.clone())
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
} else {
|
||||
element
|
||||
}
|
||||
})
|
||||
.when_some(context.tooltip.as_ref(), |element, tooltip| {
|
||||
element.tooltip(Tooltip::text(tooltip.clone()))
|
||||
})
|
||||
.map(|element| match &context.status {
|
||||
ContextStatus::Ready => element
|
||||
.when_some(
|
||||
context.render_preview.as_ref(),
|
||||
|element, render_preview| {
|
||||
element.hoverable_tooltip({
|
||||
let render_preview = render_preview.clone();
|
||||
move |_, cx| {
|
||||
cx.new(|_| ContextPillPreview {
|
||||
render_preview: render_preview.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
.into_any(),
|
||||
ContextStatus::Loading { message } => element
|
||||
.tooltip(ui::Tooltip::text(message.clone()))
|
||||
.with_animation(
|
||||
"pulsating-ctx-pill",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.opacity(delta),
|
||||
)
|
||||
.into_any_element(),
|
||||
ContextStatus::Error { message } => element
|
||||
.tooltip(ui::Tooltip::text(message.clone()))
|
||||
.into_any_element(),
|
||||
}),
|
||||
)
|
||||
.when_some(on_remove.as_ref(), |element, on_remove| {
|
||||
element.child(
|
||||
IconButton::new(
|
||||
context.context.element_id("remove".into()),
|
||||
IconName::Close,
|
||||
)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
@@ -164,30 +207,16 @@ impl RenderOnce for ContextPill {
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element
|
||||
.cursor_pointer()
|
||||
.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
})
|
||||
.map(|element| {
|
||||
if context.summarizing {
|
||||
)
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
let on_click = on_click.clone();
|
||||
element
|
||||
.tooltip(ui::Tooltip::text("Summarizing..."))
|
||||
.with_animation(
|
||||
"pulsating-ctx-pill",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 0.8)),
|
||||
|label, delta| label.opacity(delta),
|
||||
)
|
||||
.into_any_element()
|
||||
} else {
|
||||
element.into_any()
|
||||
}
|
||||
}),
|
||||
.cursor_pointer()
|
||||
.on_click(move |event, window, cx| on_click(event, window, cx))
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
ContextPill::Suggested {
|
||||
name,
|
||||
icon_path: _,
|
||||
@@ -198,15 +227,15 @@ impl RenderOnce for ContextPill {
|
||||
.cursor_pointer()
|
||||
.pr_1()
|
||||
.border_dashed()
|
||||
.border_color(if *focused {
|
||||
color.border_focused
|
||||
} else {
|
||||
color.border
|
||||
.map(|pill| {
|
||||
if *focused {
|
||||
pill.border_color(color.border_focused)
|
||||
.bg(color.element_background.opacity(0.5))
|
||||
} else {
|
||||
pill.border_color(color.border)
|
||||
}
|
||||
})
|
||||
.hover(|style| style.bg(color.element_hover.opacity(0.5)))
|
||||
.when(*focused, |this| {
|
||||
this.bg(color.element_background.opacity(0.5))
|
||||
})
|
||||
.child(
|
||||
div().max_w_64().child(
|
||||
Label::new(name.clone())
|
||||
@@ -227,21 +256,40 @@ impl RenderOnce for ContextPill {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ContextStatus {
|
||||
Ready,
|
||||
Loading { message: SharedString },
|
||||
Error { message: SharedString },
|
||||
}
|
||||
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
//
|
||||
// #[derive(RegisterComponent)]
|
||||
pub struct AddedContext {
|
||||
pub id: ContextId,
|
||||
pub context: AgentContext,
|
||||
pub kind: ContextKind,
|
||||
pub name: SharedString,
|
||||
pub parent: Option<SharedString>,
|
||||
pub tooltip: Option<SharedString>,
|
||||
pub icon_path: Option<SharedString>,
|
||||
pub summarizing: bool,
|
||||
pub status: ContextStatus,
|
||||
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement + 'static>>,
|
||||
}
|
||||
|
||||
impl AddedContext {
|
||||
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
|
||||
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
|
||||
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
|
||||
///
|
||||
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
|
||||
pub fn new(
|
||||
context: AgentContext,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
project: &Project,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
match context {
|
||||
AssistantContext::File(file_context) => {
|
||||
let full_path = file_context.context_buffer.file.full_path(cx);
|
||||
AgentContext::File(ref file_context) => {
|
||||
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -252,22 +300,24 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: file_context.id,
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::File,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
summarizing: false,
|
||||
}
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Directory(directory_context) => {
|
||||
let full_path = directory_context
|
||||
.worktree
|
||||
.read(cx)
|
||||
.full_path(&directory_context.path);
|
||||
AgentContext::Directory(ref directory_context) => {
|
||||
let worktree = project
|
||||
.worktree_for_entry(directory_context.entry_id, cx)?
|
||||
.read(cx);
|
||||
let entry = worktree.entry_for_id(directory_context.entry_id)?;
|
||||
let full_path = worktree.full_path(&entry.path);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -278,40 +328,42 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: directory_context.id,
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Directory,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: None,
|
||||
summarizing: false,
|
||||
}
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Symbol(symbol_context) => AddedContext {
|
||||
id: symbol_context.id,
|
||||
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
|
||||
kind: ContextKind::Symbol,
|
||||
name: symbol_context.context_symbol.id.name.clone(),
|
||||
name: symbol_context.symbol.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
summarizing: false,
|
||||
},
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
|
||||
AssistantContext::Excerpt(excerpt_context) => {
|
||||
let full_path = excerpt_context.context_buffer.file.full_path(cx);
|
||||
AgentContext::Selection(ref selection_context) => {
|
||||
let buffer = selection_context.buffer.read(cx);
|
||||
let full_path = buffer.file()?.full_path(cx);
|
||||
let mut full_path_string = full_path.to_string_lossy().into_owned();
|
||||
let mut name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
|
||||
let line_range_text = format!(
|
||||
" ({}-{})",
|
||||
excerpt_context.line_range.start.row + 1,
|
||||
excerpt_context.line_range.end.row + 1
|
||||
);
|
||||
let line_range = selection_context.range.to_point(&buffer.snapshot());
|
||||
|
||||
let line_range_text =
|
||||
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
|
||||
|
||||
full_path_string.push_str(&line_range_text);
|
||||
name.push_str(&line_range_text);
|
||||
@@ -321,39 +373,198 @@ impl AddedContext {
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
|
||||
AddedContext {
|
||||
id: excerpt_context.id,
|
||||
kind: ContextKind::File, // Use File icon for excerpts
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Selection,
|
||||
name: name.into(),
|
||||
parent,
|
||||
tooltip: Some(full_path_string.into()),
|
||||
tooltip: None,
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
summarizing: false,
|
||||
}
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
/*
|
||||
render_preview: Some(Rc::new({
|
||||
let content = selection_context.text.clone();
|
||||
move |_, cx| {
|
||||
div()
|
||||
.id("context-pill-selection-preview")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(Label::new(content.clone()).buffer_font(cx))
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
*/
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
|
||||
id: fetched_url_context.id,
|
||||
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
|
||||
kind: ContextKind::FetchedUrl,
|
||||
name: fetched_url_context.url.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
summarizing: false,
|
||||
},
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
|
||||
AssistantContext::Thread(thread_context) => AddedContext {
|
||||
id: thread_context.id,
|
||||
AgentContext::Thread(ref thread_context) => Some(AddedContext {
|
||||
kind: ContextKind::Thread,
|
||||
name: thread_context.summary(cx),
|
||||
name: thread_context.name(cx),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
summarizing: thread_context
|
||||
status: if thread_context
|
||||
.thread
|
||||
.read(cx)
|
||||
.is_generating_detailed_summary(),
|
||||
},
|
||||
.is_generating_detailed_summary()
|
||||
{
|
||||
ContextStatus::Loading {
|
||||
message: "Summarizing…".into(),
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
|
||||
AgentContext::Rules(ref user_rules_context) => {
|
||||
let name = prompt_store
|
||||
.as_ref()?
|
||||
.read(cx)
|
||||
.metadata(user_rules_context.prompt_id.into())?
|
||||
.title?;
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Rules,
|
||||
name: name.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AgentContext::Image(ref image_context) => Some(AddedContext {
|
||||
kind: ContextKind::Image,
|
||||
name: "Image".into(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: match image_context.status() {
|
||||
ImageStatus::Loading => ContextStatus::Loading {
|
||||
message: "Loading…".into(),
|
||||
},
|
||||
ImageStatus::Error => ContextStatus::Error {
|
||||
message: "Failed to load image".into(),
|
||||
},
|
||||
ImageStatus::Ready => ContextStatus::Ready,
|
||||
},
|
||||
render_preview: Some(Rc::new({
|
||||
let image = image_context.original_image.clone();
|
||||
move |_, _| {
|
||||
gpui::img(image.clone())
|
||||
.max_w_96()
|
||||
.max_h_96()
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
context,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextPillPreview {
|
||||
render_preview: Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>,
|
||||
}
|
||||
|
||||
impl Render for ContextPillPreview {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
tooltip_container(window, cx, move |this, window, cx| {
|
||||
this.occlude()
|
||||
.on_mouse_move(|_, _, cx| cx.stop_propagation())
|
||||
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
|
||||
.child((self.render_preview)(window, cx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
/*
|
||||
impl Component for AddedContext {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AddedContext"
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
let next_context_id = ContextId::zero();
|
||||
let image_ready = (
|
||||
"Ready",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
);
|
||||
|
||||
let image_loading = (
|
||||
"Loading",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: cx
|
||||
.background_spawn(async move {
|
||||
smol::Timer::after(Duration::from_secs(60 * 5)).await;
|
||||
Some(LanguageModelImage::empty())
|
||||
})
|
||||
.shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
);
|
||||
|
||||
let image_error = (
|
||||
"Error",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(None).shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
);
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.gap_6()
|
||||
.children(
|
||||
vec![image_ready, image_loading, image_error]
|
||||
.into_iter()
|
||||
.map(|(text, context)| {
|
||||
single_example(
|
||||
text,
|
||||
ContextPill::added(context, false, false, None).into_any_element(),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
@@ -1,31 +1,30 @@
|
||||
use client::zed_urls;
|
||||
use language_model::RequestUsage;
|
||||
use ui::{Banner, ProgressBar, Severity, prelude::*};
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct UsageBanner {
|
||||
plan: Plan,
|
||||
requests: i32,
|
||||
usage: RequestUsage,
|
||||
}
|
||||
|
||||
impl UsageBanner {
|
||||
pub fn new(plan: Plan, requests: i32) -> Self {
|
||||
Self { plan, requests }
|
||||
pub fn new(plan: Plan, usage: RequestUsage) -> Self {
|
||||
Self { plan, usage }
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for UsageBanner {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let request_limit = self.plan.model_requests_limit();
|
||||
|
||||
let used_percentage = match request_limit {
|
||||
UsageLimit::Limited(limit) => Some((self.requests as f32 / limit as f32) * 100.),
|
||||
let used_percentage = match self.usage.limit {
|
||||
UsageLimit::Limited(limit) => Some((self.usage.amount as f32 / limit as f32) * 100.),
|
||||
UsageLimit::Unlimited => None,
|
||||
};
|
||||
|
||||
let (severity, message) = match request_limit {
|
||||
let (severity, message) = match self.usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
if self.requests >= limit {
|
||||
if self.usage.amount >= limit {
|
||||
let message = match self.plan {
|
||||
Plan::ZedPro => "Monthly request limit reached",
|
||||
Plan::ZedProTrial => "Trial request limit reached",
|
||||
@@ -33,7 +32,7 @@ impl RenderOnce for UsageBanner {
|
||||
};
|
||||
|
||||
(Severity::Error, message)
|
||||
} else if (self.requests as f32 / limit as f32) >= 0.9 {
|
||||
} else if (self.usage.amount as f32 / limit as f32) >= 0.9 {
|
||||
(Severity::Warning, "Approaching request limit")
|
||||
} else {
|
||||
let message = match self.plan {
|
||||
@@ -81,11 +80,11 @@ impl RenderOnce for UsageBanner {
|
||||
.child(ProgressBar::new("usage", percent, 100., cx))
|
||||
}))
|
||||
.child(
|
||||
Label::new(match request_limit {
|
||||
Label::new(match self.usage.limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
format!("{} / {limit}", self.requests)
|
||||
format!("{} / {limit}", self.usage.amount)
|
||||
}
|
||||
UsageLimit::Unlimited => format!("{} / ∞", self.requests),
|
||||
UsageLimit::Unlimited => format!("{} / ∞", self.usage.amount),
|
||||
})
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -104,74 +103,131 @@ impl Component for UsageBanner {
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
let trial_limit = Plan::ZedProTrial.model_requests_limit();
|
||||
let trial_examples = vec![
|
||||
single_example(
|
||||
"Zed Pro Trial - New User",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 10))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 10,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Trial - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 135))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 135,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Trial - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 150))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 150,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
let free_limit = Plan::Free.model_requests_limit();
|
||||
let free_examples = vec![
|
||||
single_example(
|
||||
"Free - Normal Usage",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::Free, 25))
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 25,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::Free, 45))
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 45,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::Free, 50))
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 50,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
let zed_pro_limit = Plan::ZedPro.model_requests_limit();
|
||||
let zed_pro_examples = vec![
|
||||
single_example(
|
||||
"Zed Pro - Normal Usage",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedPro, 250))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 250,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedPro, 450))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 450,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(Plan::ZedPro, 500))
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 500,
|
||||
},
|
||||
))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mod supported_countries;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -11,8 +9,6 @@ use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
|
||||
@@ -1,225 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by Anthropic.
|
||||
///
|
||||
/// <https://www.anthropic.com/supported-countries>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by Anthropic.
|
||||
///
|
||||
/// https://www.anthropic.com/supported-countries
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"AL", // Albania
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa (US)
|
||||
"AD", // Andorra
|
||||
"AO", // Angola
|
||||
"AI", // Anguilla (UK)
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda (UK)
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BA", // Bosnia and Herzegovina
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory (UK)
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"KY", // Cayman Islands (UK)
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island (AU)
|
||||
"CC", // Cocos (Keeling) Islands (AU)
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CG", // Congo (Brazzaville)
|
||||
"CK", // Cook Islands (NZ)
|
||||
"CR", // Costa Rica
|
||||
"CI", // Côte d'Ivoire
|
||||
"HR", // Croatia
|
||||
"CY", // Cyprus
|
||||
"CZ", // Czechia (Czech Republic)
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini
|
||||
"FK", // Falkland Islands (UK)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GF", // French Guiana (FR)
|
||||
"PF", // French Polynesia (FR)
|
||||
"TF", // French Southern Territories
|
||||
"GA", // Gabon
|
||||
"GM", // Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar (UK)
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GT", // Guatemala
|
||||
"GU", // Guam (US)
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands (AU)
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KW", // Kuwait
|
||||
"KG", // Kyrgyzstan
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MD", // Moldova
|
||||
"MC", // Monaco
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat (UK)
|
||||
"ME", // Montenegro
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NZ", // New Zealand
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NF", // Norfolk Island (AU)
|
||||
"MK", // North Macedonia
|
||||
"MI", // Northern Mariana Islands (UK)
|
||||
"NO", // Norway
|
||||
"NU", // Niue (NZ)
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn (UK)
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico (US)
|
||||
"QA", // Qatar
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy (FR)
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"MF", // Saint Martin (FR)
|
||||
"PM", // Saint Pierre and Miquelon (FR)
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"WS", // Samoa
|
||||
"SM", // San Marino
|
||||
"ST", // São Tomé and Príncipe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"RS", // Serbia
|
||||
"SC", // Seychelles
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"ZA", // South Africa
|
||||
"KR", // South Korea
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau (NZ)
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Türkiye (Turkey)
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands (UK)
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
|
||||
"AE", // United Arab Emirates
|
||||
"GB", // United Kingdom
|
||||
"UM", // United States Minor Outlying Islands (US)
|
||||
"US", // United States of America
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VA", // Vatican City
|
||||
"VN", // Vietnam
|
||||
"VI", // Virgin Islands (US)
|
||||
"VG", // Virgin Islands (UK)
|
||||
"WF", // Wallis and Futuna (FR)
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -49,7 +49,7 @@ menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rope.workspace = true
|
||||
|
||||
@@ -101,7 +101,7 @@ pub fn init(
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
assistant_context_editor::init(client.clone(), cx);
|
||||
prompt_library::init(cx);
|
||||
rules_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -23,11 +23,10 @@ use gpui::{
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId};
|
||||
use prompt_store::{PromptBuilder, UserPromptId};
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
|
||||
use search::{BufferSearchBar, buffer_search::DivRegistrar};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -44,7 +43,7 @@ use workspace::{
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
pane,
|
||||
};
|
||||
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
|
||||
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
|
||||
@@ -58,11 +57,11 @@ pub fn init(cx: &mut App) {
|
||||
.register_action(AssistantPanel::show_configuration)
|
||||
.register_action(AssistantPanel::create_new_context)
|
||||
.register_action(AssistantPanel::restart_context_servers)
|
||||
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -273,8 +272,8 @@ impl AssistantPanel {
|
||||
.action("New Chat", Box::new(NewChat))
|
||||
.action("History", Box::new(DeployHistory))
|
||||
.action(
|
||||
"Prompt Library",
|
||||
Box::new(OpenPromptLibrary::default()),
|
||||
"Rules Library",
|
||||
Box::new(OpenRulesLibrary::default()),
|
||||
)
|
||||
.action("Configure", Box::new(ShowConfiguration))
|
||||
.action(zoom_label, Box::new(ToggleZoom))
|
||||
@@ -489,8 +488,8 @@ impl AssistantPanel {
|
||||
|
||||
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
||||
// the provider, we want to show a nudge to sign in.
|
||||
let show_zed_ai_notice = client_status.is_signed_out()
|
||||
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
||||
let show_zed_ai_notice =
|
||||
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
|
||||
|
||||
self.show_zed_ai_notice = show_zed_ai_notice;
|
||||
cx.notify();
|
||||
@@ -1044,13 +1043,13 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.languages.clone(),
|
||||
Box::new(PromptLibraryInlineAssist),
|
||||
Arc::new(|| {
|
||||
@@ -1060,7 +1059,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -1234,7 +1235,7 @@ impl Render for AssistantPanel {
|
||||
this.show_configuration_tab(window, cx)
|
||||
}))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_history))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
|
||||
.child(registrar.size_full().child(self.pane.clone()))
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -1349,13 +1350,13 @@ impl Focusable for AssistantPanel {
|
||||
|
||||
struct PromptLibraryInlineAssist;
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)
|
||||
|
||||
@@ -18,11 +18,11 @@ use editor::{
|
||||
},
|
||||
};
|
||||
use feature_flags::{
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedPro,
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedProFeatureFlag,
|
||||
};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt,
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
join,
|
||||
@@ -1226,7 +1226,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<InlineAssist>(
|
||||
row_range,
|
||||
cx.theme().status().info_background,
|
||||
false,
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1291,7 +1291,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<DeletedLines>(
|
||||
Anchor::min()..Anchor::max(),
|
||||
cx.theme().status().deleted_background,
|
||||
false,
|
||||
Default::default(),
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
@@ -1652,7 +1652,7 @@ impl Render for PromptEditor {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -1966,7 +1966,7 @@ impl PromptEditor {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
@@ -3056,7 +3056,8 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks =
|
||||
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -44,4 +44,6 @@ impl Settings for SlashCommandSettings {
|
||||
.chain(sources.server),
|
||||
)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
@@ -2089,7 +2089,7 @@ impl ContextEditor {
|
||||
continue;
|
||||
};
|
||||
let image_id = image.id();
|
||||
let image_task = LanguageModelImage::from_image(image, cx).shared();
|
||||
let image_task = LanguageModelImage::from_image(Arc::new(image), cx).shared();
|
||||
|
||||
for image_position in image_positions.iter() {
|
||||
context.insert_content(
|
||||
|
||||
@@ -112,13 +112,27 @@ impl AssistantSettings {
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
|
||||
pub struct AssistantSettingsContent {
|
||||
#[serde(flatten)]
|
||||
pub inner: Option<AssistantSettingsContentInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum AssistantSettingsContent {
|
||||
pub enum AssistantSettingsContentInner {
|
||||
Versioned(Box<VersionedAssistantSettingsContent>),
|
||||
Legacy(LegacyAssistantSettingsContent),
|
||||
}
|
||||
|
||||
impl AssistantSettingsContentInner {
|
||||
fn for_v2(content: AssistantSettingsContentV2) -> Self {
|
||||
AssistantSettingsContentInner::Versioned(Box::new(VersionedAssistantSettingsContent::V2(
|
||||
content,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for AssistantSettingsContent {
|
||||
fn schema_name() -> String {
|
||||
VersionedAssistantSettingsContent::schema_name()
|
||||
@@ -133,26 +147,21 @@ impl JsonSchema for AssistantSettingsContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
pub fn is_version_outdated(&self) -> bool {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
match &self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(_) => true,
|
||||
VersionedAssistantSettingsContent::V2(_) => false,
|
||||
},
|
||||
AssistantSettingsContent::Legacy(_) => true,
|
||||
Some(AssistantSettingsContentInner::Legacy(_)) => true,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> AssistantSettingsContentV2 {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
match &self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
|
||||
enabled: settings.enabled,
|
||||
button: settings.button,
|
||||
@@ -212,7 +221,7 @@ impl AssistantSettingsContent {
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => AssistantSettingsContentV2 {
|
||||
enabled: None,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
@@ -237,12 +246,13 @@ impl AssistantSettingsContent {
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
},
|
||||
None => AssistantSettingsContentV2::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
match &mut self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
@@ -250,9 +260,17 @@ impl AssistantSettingsContent {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
None => {
|
||||
self.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
dock: Some(dock),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,8 +278,8 @@ impl AssistantSettingsContent {
|
||||
let model = language_model.id().0.to_string();
|
||||
let provider = language_model.provider_id().0.to_string();
|
||||
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
match &mut self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
@@ -337,56 +355,80 @@ impl AssistantSettingsContent {
|
||||
settings.default_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => {
|
||||
if let Ok(model) = OpenAiModel::from_id(&language_model.id().0) {
|
||||
settings.default_open_ai_model = Some(model);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection { provider, model }),
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
self.v2_setting(|setting| {
|
||||
setting.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
self.v2_setting(|setting| {
|
||||
setting.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn v2_setting(
|
||||
&mut self,
|
||||
f: impl FnOnce(&mut AssistantSettingsContentV2) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
match self.inner.get_or_insert_with(|| {
|
||||
AssistantSettingsContentInner::for_v2(AssistantSettingsContentV2 {
|
||||
..Default::default()
|
||||
})
|
||||
}) {
|
||||
AssistantSettingsContentInner::Versioned(boxed) => {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
f(settings)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
self.v2_setting(|setting| {
|
||||
setting.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.always_allow_tool_actions = Some(allow);
|
||||
}
|
||||
self.v2_setting(|setting| {
|
||||
setting.always_allow_tool_actions = Some(allow);
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.default_profile = Some(profile_id);
|
||||
}
|
||||
self.v2_setting(|setting| {
|
||||
setting.default_profile = Some(profile_id);
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn create_profile(
|
||||
@@ -394,11 +436,7 @@ impl AssistantSettingsContent {
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
) -> Result<()> {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
self.v2_setting(|settings| {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
if profiles.contains_key(&profile_id) {
|
||||
bail!("profile with ID '{profile_id}' already exists");
|
||||
@@ -424,9 +462,9 @@ impl AssistantSettingsContent {
|
||||
.collect(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,7 +499,7 @@ impl Default for VersionedAssistantSettingsContent {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
|
||||
pub struct AssistantSettingsContentV2 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
@@ -708,6 +746,39 @@ impl Settings for AssistantSettings {
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
if let Some(b) = vscode
|
||||
.read_value("chat.agent.enabled")
|
||||
.and_then(|b| b.as_bool())
|
||||
{
|
||||
match &mut current.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(versioned)) => {
|
||||
match versioned.as_mut() {
|
||||
VersionedAssistantSettingsContent::V1(setting) => {
|
||||
setting.enabled = Some(b);
|
||||
setting.button = Some(b);
|
||||
}
|
||||
|
||||
VersionedAssistantSettingsContent::V2(setting) => {
|
||||
setting.enabled = Some(b);
|
||||
setting.button = Some(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(AssistantSettingsContentInner::Legacy(setting)) => setting.button = Some(b),
|
||||
None => {
|
||||
current.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
enabled: Some(b),
|
||||
button: Some(b),
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge<T>(target: &mut T, value: Option<T>) {
|
||||
@@ -751,28 +822,30 @@ mod tests {
|
||||
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
|settings, _| {
|
||||
*settings = AssistantSettingsContent::Versioned(Box::new(
|
||||
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
}),
|
||||
))
|
||||
*settings = AssistantSettingsContent {
|
||||
inner: Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
},
|
||||
)),
|
||||
}
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -44,9 +44,10 @@ impl SlashCommand for PromptSlashCommand {
|
||||
let store = PromptStore::global(cx);
|
||||
let query = arguments.to_owned().join(" ");
|
||||
cx.spawn(async move |cx| {
|
||||
let cancellation_flag = Arc::new(AtomicBool::default());
|
||||
let prompts: Vec<PromptMetadata> = store
|
||||
.await?
|
||||
.read_with(cx, |store, cx| store.search(query, cx))?
|
||||
.read_with(cx, |store, cx| store.search(query, cancellation_flag, cx))?
|
||||
.await;
|
||||
Ok(prompts
|
||||
.into_iter()
|
||||
|
||||
@@ -28,6 +28,7 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
text.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -39,10 +39,9 @@ impl ActionLog {
|
||||
self.edited_since_project_diagnostics_check
|
||||
}
|
||||
|
||||
fn track_buffer(
|
||||
fn track_buffer_internal(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
created: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut TrackedBuffer {
|
||||
let tracked_buffer = self
|
||||
@@ -59,7 +58,11 @@ impl ActionLog {
|
||||
let base_text;
|
||||
let status;
|
||||
let unreviewed_changes;
|
||||
if created {
|
||||
if buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(true, |file| !file.disk_state().exists())
|
||||
{
|
||||
base_text = Rope::default();
|
||||
status = TrackedBufferStatus::Created;
|
||||
unreviewed_changes = Patch::new(vec![Edit {
|
||||
@@ -146,7 +149,7 @@ impl ActionLog {
|
||||
// resurrected externally, we want to clear the changes we
|
||||
// were tracking and reset the buffer's state.
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
self.track_buffer(buffer, false, cx);
|
||||
self.track_buffer_internal(buffer, cx);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -260,26 +263,15 @@ impl ActionLog {
|
||||
}
|
||||
|
||||
/// Track a buffer as read, so we can notify the model about user edits.
|
||||
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer that was added as context, so we can notify the model about user edits.
|
||||
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer as read, so we can notify the model about user edits.
|
||||
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer.clone(), true, cx);
|
||||
self.buffer_edited(buffer, cx)
|
||||
pub fn track_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer_internal(buffer, cx);
|
||||
}
|
||||
|
||||
/// Mark a buffer as edited, so we can refresh it in the context
|
||||
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
|
||||
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
|
||||
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
|
||||
tracked_buffer.status = TrackedBufferStatus::Modified;
|
||||
}
|
||||
@@ -287,7 +279,7 @@ impl ActionLog {
|
||||
}
|
||||
|
||||
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
|
||||
match tracked_buffer.status {
|
||||
TrackedBufferStatus::Created => {
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
@@ -397,7 +389,7 @@ impl ActionLog {
|
||||
|
||||
// Clear all tracked changes for this buffer and start over as if we just read it.
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
self.track_buffer(buffer.clone(), false, cx);
|
||||
self.track_buffer_internal(buffer.clone(), cx);
|
||||
cx.notify();
|
||||
save
|
||||
}
|
||||
@@ -695,12 +687,20 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
|
||||
@@ -765,12 +765,23 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
fs.insert_tree(
|
||||
path!("/dir"),
|
||||
json!({"file": "abc\ndef\nghi\njkl\nmno\npqr"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 0)..Point::new(2, 0), "")], None, cx)
|
||||
@@ -839,12 +850,20 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
|
||||
@@ -928,25 +947,21 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({})).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
|
||||
.unwrap();
|
||||
|
||||
// Simulate file2 being recreated by a tool.
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
@@ -1067,8 +1082,9 @@ mod tests {
|
||||
.update(cx, |project, cx| project.open_buffer(file2_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer2.clone(), cx));
|
||||
buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx))
|
||||
.await
|
||||
@@ -1113,7 +1129,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
|
||||
@@ -1248,7 +1264,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
|
||||
@@ -1381,8 +1397,9 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
@@ -1438,7 +1455,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
|
||||
for _ in 0..operations {
|
||||
match rng.gen_range(0..100) {
|
||||
@@ -1490,7 +1507,7 @@ mod tests {
|
||||
log::info!("quiescing...");
|
||||
cx.run_until_parked();
|
||||
action_log.update(cx, |log, cx| {
|
||||
let tracked_buffer = log.track_buffer(buffer.clone(), false, cx);
|
||||
let tracked_buffer = log.track_buffer_internal(buffer.clone(), cx);
|
||||
let mut old_text = tracked_buffer.base_text.clone();
|
||||
let new_text = buffer.read(cx).as_rope();
|
||||
for edit in tracked_buffer.unreviewed_changes.edits() {
|
||||
|
||||
@@ -10,14 +10,16 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AnyElement;
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
pub use crate::tool_registry::*;
|
||||
@@ -30,6 +32,7 @@ pub fn init(cx: &mut App) {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolUseStatus {
|
||||
InputStillStreaming,
|
||||
NeedsConfirmation,
|
||||
Pending,
|
||||
Running,
|
||||
@@ -41,6 +44,7 @@ impl ToolUseStatus {
|
||||
pub fn text(&self) -> SharedString {
|
||||
match self {
|
||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||
ToolUseStatus::InputStillStreaming => "".into(),
|
||||
ToolUseStatus::Pending => "".into(),
|
||||
ToolUseStatus::Running => "".into(),
|
||||
ToolUseStatus::Finished(out) => out.clone(),
|
||||
@@ -63,6 +67,7 @@ pub trait ToolCard: 'static + Sized {
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement;
|
||||
}
|
||||
@@ -74,6 +79,7 @@ pub struct AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement,
|
||||
}
|
||||
@@ -84,11 +90,14 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let entity = entity.downcast::<T>().unwrap();
|
||||
entity.update(cx, |entity, cx| {
|
||||
entity.render(status, window, cx).into_any_element()
|
||||
entity
|
||||
.render(status, window, workspace, cx)
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -100,8 +109,14 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
}
|
||||
|
||||
impl AnyToolCard {
|
||||
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, cx)
|
||||
pub fn render(
|
||||
&self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, workspace, cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +163,12 @@ pub trait Tool: 'static + Send + Sync {
|
||||
/// Returns markdown to be displayed in the UI for this tool.
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String;
|
||||
|
||||
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
|
||||
/// (so information may be missing).
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
self.ui_text(input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
@@ -155,6 +176,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,11 @@ pub fn adapt_schema_to_format(
|
||||
json: &mut Value,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
obj.remove("$schema");
|
||||
obj.remove("title");
|
||||
}
|
||||
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
|
||||
@@ -30,10 +35,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
obj.remove("format");
|
||||
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
|
||||
@@ -14,12 +14,14 @@ path = "src/assistant_tools.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
feature_flags.workspace = true
|
||||
editor.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
@@ -30,23 +32,32 @@ linkme.workspace = true
|
||||
open.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smallvec.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_models.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-rust.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -7,33 +7,37 @@ mod create_directory_tool;
|
||||
mod create_file_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_agent;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
mod grep_tool;
|
||||
mod list_directory_tool;
|
||||
mod move_path_tool;
|
||||
mod now_tool;
|
||||
mod open_tool;
|
||||
mod path_search_tool;
|
||||
mod read_file_tool;
|
||||
mod rename_tool;
|
||||
mod replace;
|
||||
mod schema;
|
||||
mod symbol_info_tool;
|
||||
mod templates;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod ui;
|
||||
mod web_search_tool;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use move_path_tool::MovePathTool;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
pub(crate) use templates::*;
|
||||
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
use crate::code_symbols_tool::CodeSymbolsTool;
|
||||
@@ -44,17 +48,22 @@ use crate::delete_path_tool::DeletePathTool;
|
||||
use crate::diagnostics_tool::DiagnosticsTool;
|
||||
use crate::edit_file_tool::EditFileTool;
|
||||
use crate::fetch_tool::FetchTool;
|
||||
use crate::find_path_tool::FindPathTool;
|
||||
use crate::grep_tool::GrepTool;
|
||||
use crate::list_directory_tool::ListDirectoryTool;
|
||||
use crate::now_tool::NowTool;
|
||||
use crate::open_tool::OpenTool;
|
||||
use crate::path_search_tool::PathSearchTool;
|
||||
use crate::read_file_tool::ReadFileTool;
|
||||
use crate::rename_tool::RenameTool;
|
||||
use crate::symbol_info_tool::SymbolInfoTool;
|
||||
use crate::terminal_tool::TerminalTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
|
||||
pub use create_file_tool::CreateFileToolInput;
|
||||
pub use edit_file_tool::EditFileToolInput;
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use read_file_tool::ReadFileToolInput;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -75,41 +84,79 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(CodeSymbolsTool);
|
||||
registry.register_tool(ContentsTool);
|
||||
registry.register_tool(PathSearchTool);
|
||||
registry.register_tool(FindPathTool);
|
||||
registry.register_tool(ReadFileTool);
|
||||
registry.register_tool(GrepTool);
|
||||
registry.register_tool(RenameTool);
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
move |registry, event, cx| match event {
|
||||
language_model::Event::DefaultModelChanged => {
|
||||
let using_zed_provider = registry
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.map_or(false, |default| default.is_provided_by_zed());
|
||||
if using_zed_provider {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use http_client::FakeHttpClient;
|
||||
|
||||
use super::*;
|
||||
use client::Client;
|
||||
use clock::FakeSystemClock;
|
||||
use http_client::FakeHttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Serialize;
|
||||
|
||||
#[test]
|
||||
fn test_json_schema() {
|
||||
#[derive(Serialize, JsonSchema)]
|
||||
struct GetWeatherTool {
|
||||
location: String,
|
||||
}
|
||||
|
||||
let schema = schema::json_schema_for::<GetWeatherTool>(
|
||||
language_model::LanguageModelToolSchemaFormat::JsonSchema,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
schema,
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
crate::init(
|
||||
Arc::new(http_client::HttpClientWithUrl::new(
|
||||
FakeHttpClient::with_200_response(),
|
||||
"https://zed.dev",
|
||||
None,
|
||||
)),
|
||||
settings::init(cx);
|
||||
|
||||
let client = Client::new(
|
||||
Arc::new(FakeSystemClock::new()),
|
||||
FakeHttpClient::with_200_response(),
|
||||
cx,
|
||||
);
|
||||
language_model::init(client.clone(), cx);
|
||||
crate::init(client.http_client(), cx);
|
||||
|
||||
for tool in ToolRegistry::global(cx).tools() {
|
||||
let actual_schema = tool
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
|
||||
use futures::future::join_all;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -97,7 +97,7 @@ pub struct BatchToolInput {
|
||||
/// }
|
||||
/// },
|
||||
/// {
|
||||
/// "name": "path_search",
|
||||
/// "name": "find_path",
|
||||
/// "input": {
|
||||
/// "glob": "**/*test*.rs"
|
||||
/// }
|
||||
@@ -218,6 +218,7 @@ impl Tool for BatchTool {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<BatchToolInput>(input) {
|
||||
@@ -258,7 +259,9 @@ impl Tool for BatchTool {
|
||||
let action_log = action_log.clone();
|
||||
let messages = messages.clone();
|
||||
let tool_result = cx
|
||||
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
|
||||
.update(|cx| {
|
||||
tool.run(invocation.input, &messages, project, action_log, window, cx)
|
||||
})
|
||||
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
|
||||
|
||||
tasks.push(tool_result.output);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{self, LspAction, Project};
|
||||
@@ -140,6 +140,7 @@ impl Tool for CodeActionTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
|
||||
@@ -159,7 +160,7 @@ impl Tool for CodeActionTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let range = {
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use collections::IndexMap;
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use language::{OutlineItem, ParseStatus, Point};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, Symbol};
|
||||
@@ -128,6 +128,7 @@ impl Tool for CodeSymbolsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
|
||||
@@ -174,7 +175,7 @@ pub async fn file_outline(
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
// Wait until the buffer has been fully parsed, so that we can read its outline.
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -102,6 +102,7 @@ impl Tool for ContentsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ContentsToolInput>(input) {
|
||||
@@ -209,7 +210,7 @@ impl Tool for ContentsTool {
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.track_buffer(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -221,7 +222,7 @@ impl Tool for ContentsTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.track_buffer(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -76,6 +77,7 @@ impl Tool for CopyPathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -67,6 +68,7 @@ impl Tool for CreateDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -23,6 +24,9 @@ pub struct CreateFileToolInput {
|
||||
///
|
||||
/// You can create a new file by providing a path of "directory1/new_file.txt"
|
||||
/// </example>
|
||||
///
|
||||
/// Make sure to include this field before the `contents` field in the input object
|
||||
/// so that we can display it immediately.
|
||||
pub path: String,
|
||||
|
||||
/// The text contents of the file to create.
|
||||
@@ -33,8 +37,18 @@ pub struct CreateFileToolInput {
|
||||
pub contents: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
struct PartialInput {
|
||||
#[serde(default)]
|
||||
path: String,
|
||||
#[serde(default)]
|
||||
contents: String,
|
||||
}
|
||||
|
||||
pub struct CreateFileTool;
|
||||
|
||||
const DEFAULT_UI_TEXT: &str = "Create file";
|
||||
|
||||
impl Tool for CreateFileTool {
|
||||
fn name(&self) -> String {
|
||||
"create_file".into()
|
||||
@@ -62,7 +76,14 @@ impl Tool for CreateFileTool {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
format!("Create file {path}")
|
||||
}
|
||||
Err(_) => "Create file".to_string(),
|
||||
Err(_) => DEFAULT_UI_TEXT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||
Some(input) if !input.path.is_empty() => input.path,
|
||||
_ => DEFAULT_UI_TEXT.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +93,7 @@ impl Tool for CreateFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
|
||||
@@ -95,9 +117,12 @@ impl Tool for CreateFileTool {
|
||||
.await
|
||||
.map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx)
|
||||
});
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.will_create_buffer(buffer.clone(), cx)
|
||||
action_log.buffer_edited(buffer.clone(), cx)
|
||||
});
|
||||
})?;
|
||||
|
||||
@@ -111,3 +136,60 @@ impl Tool for CreateFileTool {
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path() {
|
||||
let tool = CreateFileTool;
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_without_path() {
|
||||
let tool = CreateFileTool;
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_null() {
|
||||
let tool = CreateFileTool;
|
||||
let input = serde_json::Value::Null;
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ui_text_with_valid_input() {
|
||||
let tool = CreateFileTool;
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
|
||||
});
|
||||
|
||||
assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ui_text_with_invalid_input() {
|
||||
let tool = CreateFileTool;
|
||||
let input = json!({
|
||||
"invalid": "field"
|
||||
});
|
||||
|
||||
assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
@@ -62,6 +62,7 @@ impl Tool for DeletePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -82,6 +82,7 @@ impl Tool for DiagnosticsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
match serde_json::from_value::<DiagnosticsToolInput>(input)
|
||||
|
||||
526
crates/assistant_tools/src/edit_agent.rs
Normal file
526
crates/assistant_tools/src/edit_agent.rs
Normal file
@@ -0,0 +1,526 @@
|
||||
mod edit_parser;
|
||||
|
||||
use crate::{Template, Templates};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::ActionLog;
|
||||
use edit_parser::EditParser;
|
||||
use futures::{Stream, StreamExt, stream};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Bias, Buffer, BufferSnapshot};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use smallvec::SmallVec;
|
||||
use std::{ops::Range, path::PathBuf, sync::Arc};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EditAgentTemplate {
|
||||
path: Option<PathBuf>,
|
||||
file_content: String,
|
||||
instructions: String,
|
||||
}
|
||||
|
||||
impl Template for EditAgentTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
|
||||
}
|
||||
|
||||
pub struct EditAgent {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl EditAgent {
|
||||
pub fn new(
|
||||
model: Arc<dyn LanguageModel>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
) -> Self {
|
||||
EditAgent {
|
||||
model,
|
||||
action_log,
|
||||
templates,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn edit(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
instructions: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let edits = self.stream_edits(buffer.clone(), instructions, cx).await?;
|
||||
self.apply_edits(buffer, edits, cx).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn apply_edits(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edits: impl Stream<Item = Result<(Range<Anchor>, String)>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
// todo!("group all edits into one transaction")
|
||||
// todo!("add tests for this")
|
||||
|
||||
// Ensure the buffer is tracked by the action log.
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
|
||||
|
||||
futures::pin_mut!(edits);
|
||||
while let Some(edit) = edits.next().await {
|
||||
let (range, content) = edit?;
|
||||
// Edit the buffer and report the edit as part of the same effect cycle, otherwise
|
||||
// the edit will be reported as if the user made it.
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(range, content)], None, cx));
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream_edits(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
instructions: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<impl use<> + Stream<Item = Result<(Range<Anchor>, String)>>> {
|
||||
println!("{}\n\n", instructions);
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
// todo!("move to background")
|
||||
let prompt = EditAgentTemplate {
|
||||
path,
|
||||
file_content: snapshot.text(),
|
||||
instructions,
|
||||
}
|
||||
.render(&self.templates)?;
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let mut parser = EditParser::new();
|
||||
let stream = self.model.stream_completion_text(request, cx).await?.stream;
|
||||
Ok(stream.flat_map(move |chunk| {
|
||||
let mut edits = SmallVec::new();
|
||||
let mut error = None;
|
||||
let snapshot = snapshot.clone();
|
||||
match chunk {
|
||||
Ok(chunk) => edits = parser.push(&chunk),
|
||||
Err(err) => error = Some(Err(anyhow!(err))),
|
||||
}
|
||||
stream::iter(
|
||||
edits
|
||||
.into_iter()
|
||||
.map(move |edit| {
|
||||
dbg!(&edit);
|
||||
let range = Self::resolve_location(&snapshot, &edit.old_text);
|
||||
Ok((range, edit.new_text))
|
||||
})
|
||||
.chain(error),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
|
||||
const INSERTION_COST: u32 = 3;
|
||||
const DELETION_COST: u32 = 10;
|
||||
const WHITESPACE_INSERTION_COST: u32 = 1;
|
||||
const WHITESPACE_DELETION_COST: u32 = 1;
|
||||
|
||||
let buffer_len = buffer.len();
|
||||
let query_len = search_query.len();
|
||||
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
|
||||
let mut leading_deletion_cost = 0_u32;
|
||||
for (row, query_byte) in search_query.bytes().enumerate() {
|
||||
let deletion_cost = if query_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_DELETION_COST
|
||||
} else {
|
||||
DELETION_COST
|
||||
};
|
||||
|
||||
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
|
||||
matrix.set(
|
||||
row + 1,
|
||||
0,
|
||||
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
|
||||
);
|
||||
|
||||
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
|
||||
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_INSERTION_COST
|
||||
} else {
|
||||
INSERTION_COST
|
||||
};
|
||||
|
||||
let up = SearchState::new(
|
||||
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_byte == *buffer_byte {
|
||||
matrix.get(row, col).cost
|
||||
} else {
|
||||
matrix
|
||||
.get(row, col)
|
||||
.cost
|
||||
.saturating_add(deletion_cost + insertion_cost)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
|
||||
}
|
||||
}
|
||||
|
||||
// Traceback to find the best match
|
||||
let mut best_buffer_end = buffer_len;
|
||||
let mut best_cost = u32::MAX;
|
||||
for col in 1..=buffer_len {
|
||||
let cost = matrix.get(query_len, col).cost;
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
best_buffer_end = col;
|
||||
}
|
||||
}
|
||||
|
||||
let mut query_ix = query_len;
|
||||
let mut buffer_ix = best_buffer_end;
|
||||
while query_ix > 0 && buffer_ix > 0 {
|
||||
let current = matrix.get(query_ix, buffer_ix);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_ix -= 1;
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_ix -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
|
||||
start.column = 0;
|
||||
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
|
||||
if end.column > 0 {
|
||||
end.column = buffer.line_len(end.row);
|
||||
}
|
||||
|
||||
buffer.anchor_after(start)..buffer.anchor_before(end)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(rows: usize, cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
|
||||
self.data[row * self.cols + col] = cost;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashSet;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde_json::json;
|
||||
use std::{fmt::Write as _, io::Write as _, path::Path, sync::mpsc};
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_delete_run_git_blame() {
|
||||
eval(
|
||||
100,
|
||||
0.9,
|
||||
Eval {
|
||||
input_path: "root/blame.rs".into(),
|
||||
input_content: include_str!("fixtures/delete_run_git_blame/before.rs").into(),
|
||||
instructions: indoc! {r#"
|
||||
Let's delete the `run_git_blame` function while keeping all other code intact:
|
||||
|
||||
// ... existing code ...
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
// ... existing code ...
|
||||
"#}
|
||||
.into(),
|
||||
expected_output: include_str!("fixtures/delete_run_git_blame/after.rs").into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_handle_command_output() {
|
||||
eval(
|
||||
100,
|
||||
0.9,
|
||||
Eval {
|
||||
input_path: "root/blame.rs".into(),
|
||||
input_content: include_str!("fixtures/extract_handle_command_output/before.rs").into(),
|
||||
instructions: indoc! {r#"
|
||||
Extract `handle_command_output` method from `run_git_blame`.
|
||||
|
||||
// ... existing code ...
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
handle_command_output(output)
|
||||
}
|
||||
|
||||
fn handle_command_output(output: std::process::Output) -> Result<String> {
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
// ... existing code ...
|
||||
"#}
|
||||
.into(),
|
||||
expected_output: include_str!("fixtures/extract_handle_command_output/after.rs").into()
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Eval {
|
||||
input_path: PathBuf,
|
||||
input_content: String,
|
||||
instructions: String,
|
||||
expected_output: String,
|
||||
}
|
||||
|
||||
fn eval(iterations: usize, expected_pass_ratio: f32, eval: Eval) {
|
||||
let executor = gpui::background_executor();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
for _ in 0..iterations {
|
||||
let eval = eval.clone();
|
||||
let tx = tx.clone();
|
||||
executor
|
||||
.spawn(async move {
|
||||
let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
|
||||
let mut cx = TestAppContext::build(dispatcher, None);
|
||||
let output = cx.executor().block_test(async {
|
||||
let test = agent_test(&mut cx).await;
|
||||
apply_edits(
|
||||
eval.input_path,
|
||||
eval.input_content,
|
||||
eval.instructions,
|
||||
&test,
|
||||
&mut cx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
tx.send(output).unwrap();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
drop(tx);
|
||||
|
||||
let mut evaluated_count = 0;
|
||||
report_progress(evaluated_count, iterations);
|
||||
|
||||
let mut failed_count = 0;
|
||||
let mut failed_message = String::new();
|
||||
let mut failed_outputs = HashSet::default();
|
||||
while let Ok(output) = rx.recv() {
|
||||
if output != eval.expected_output {
|
||||
failed_count += 1;
|
||||
if failed_outputs.insert(output.clone()) {
|
||||
writeln!(
|
||||
failed_message,
|
||||
"=======\n{}\n=======",
|
||||
pretty_assertions::StrComparison::new(&output, &eval.expected_output)
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
evaluated_count += 1;
|
||||
report_progress(evaluated_count, iterations);
|
||||
}
|
||||
|
||||
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
|
||||
println!("Actual pass ratio: {}\n", actual_pass_ratio);
|
||||
assert!(
|
||||
actual_pass_ratio >= expected_pass_ratio,
|
||||
"Expected pass ratio: {}\nActual pass ratio: {}\nFailures: {}",
|
||||
expected_pass_ratio,
|
||||
actual_pass_ratio,
|
||||
failed_message
|
||||
);
|
||||
}
|
||||
|
||||
fn report_progress(evaluated_count: usize, iterations: usize) {
|
||||
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
async fn apply_edits(
|
||||
path: impl AsRef<Path>,
|
||||
content: impl Into<Arc<str>>,
|
||||
instructions: impl Into<String>,
|
||||
test: &EditAgentTest,
|
||||
cx: &mut TestAppContext,
|
||||
) -> String {
|
||||
let path = test
|
||||
.project
|
||||
.read_with(cx, |project, cx| project.find_project_path(path, cx))
|
||||
.unwrap();
|
||||
let buffer = test
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(content, cx));
|
||||
test.agent
|
||||
.edit(buffer.clone(), instructions.into(), &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, _cx| buffer.text())
|
||||
}
|
||||
|
||||
struct EditAgentTest {
|
||||
agent: EditAgent,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
async fn agent_test(cx: &mut TestAppContext) -> EditAgentTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(language::init);
|
||||
cx.update(gpui_tokio::init);
|
||||
cx.update(client::init_settings);
|
||||
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "gemini-2.0-flash")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
EditAgentTest {
|
||||
agent: EditAgent::new(model, action_log, Templates::new()),
|
||||
project,
|
||||
}
|
||||
}
|
||||
}
|
||||
246
crates/assistant_tools/src/edit_agent/edit_parser.rs
Normal file
246
crates/assistant_tools/src/edit_agent/edit_parser.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
use smallvec::SmallVec;
|
||||
use std::mem;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub struct Edit {
|
||||
pub old_text: String,
|
||||
pub new_text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditParser {
|
||||
state: EditParserState,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum EditParserState {
|
||||
Pending,
|
||||
WithinOldText,
|
||||
AfterOldText { old_text: String },
|
||||
WithinNewText { old_text: String },
|
||||
}
|
||||
|
||||
impl EditParser {
|
||||
pub fn new() -> Self {
|
||||
EditParser {
|
||||
state: EditParserState::Pending,
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, chunk: &str) -> SmallVec<[Edit; 1]> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
let mut edits = SmallVec::new();
|
||||
loop {
|
||||
match &mut self.state {
|
||||
EditParserState::Pending => {
|
||||
if let Some(start) = self.buffer.find("<old_text>") {
|
||||
self.buffer.drain(..start + "<old_text>".len());
|
||||
self.state = EditParserState::WithinOldText;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinOldText => {
|
||||
if let Some(end) = self.buffer.find("</old_text>") {
|
||||
let mut start = 0;
|
||||
if self.buffer.starts_with('\n') {
|
||||
start = 1;
|
||||
}
|
||||
let mut old_text = self.buffer[start..end].to_string();
|
||||
if old_text.ends_with('\n') {
|
||||
old_text.pop();
|
||||
}
|
||||
|
||||
self.buffer.drain(..end + "</old_text>".len());
|
||||
self.state = EditParserState::AfterOldText { old_text };
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::AfterOldText { old_text } => {
|
||||
if let Some(start) = self.buffer.find("<new_text>") {
|
||||
self.buffer.drain(..start + "<new_text>".len());
|
||||
self.state = EditParserState::WithinNewText {
|
||||
old_text: mem::take(old_text),
|
||||
};
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinNewText { old_text } => {
|
||||
if let Some(end) = self.buffer.find("</new_text>") {
|
||||
let mut start = 0;
|
||||
if self.buffer.starts_with('\n') {
|
||||
start = 1;
|
||||
}
|
||||
let mut new_text = self.buffer[start..end].to_string();
|
||||
if new_text.ends_with('\n') {
|
||||
new_text.pop();
|
||||
}
|
||||
edits.push(Edit {
|
||||
old_text: mem::take(old_text),
|
||||
new_text,
|
||||
});
|
||||
|
||||
self.buffer.drain(..end + "</new_text>".len());
|
||||
self.state = EditParserState::Pending;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edits
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
use rand::prelude::*;
|
||||
use std::cmp;
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_single_edit(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>original</old_text><new_text>updated</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "original".to_string(),
|
||||
new_text: "updated".to_string(),
|
||||
}]
|
||||
)
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_multiple_edits(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
indoc! {"
|
||||
<old_text>
|
||||
first old
|
||||
</old_text><new_text>first new</new_text>
|
||||
<old_text>second old</old_text><new_text>
|
||||
second new
|
||||
</new_text>
|
||||
"},
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "first old".to_string(),
|
||||
new_text: "first new".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second old".to_string(),
|
||||
new_text: "second new".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_edits_with_extra_text(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
indoc! {"
|
||||
ignore this <old_text>
|
||||
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
|
||||
more text <old_text>second item
|
||||
</old_text>middle text<new_text>modified second item</new_text>end
|
||||
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
|
||||
"},
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "content".to_string(),
|
||||
new_text: "updated content".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second item".to_string(),
|
||||
new_text: "modified second item".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "third case".to_string(),
|
||||
new_text: "improved third case".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_nested_tags(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "code with <tag>nested</tag> elements".to_string(),
|
||||
new_text: "new <code>content</code>".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_empty_old_and_new_text(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse("<old_text></old_text><new_text></new_text>", &mut rng),
|
||||
vec![Edit {
|
||||
old_text: "".to_string(),
|
||||
new_text: "".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_with_special_characters(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>function(x) { return x * 2; }</old_text><new_text>function(x) { return x ** 2; }</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "function(x) { return x * 2; }".to_string(),
|
||||
new_text: "function(x) { return x ** 2; }".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
fn test_multiline_content(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "line1\nline2\nline3".to_string(),
|
||||
new_text: "line1\nmodified line2\nline3".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
fn parse(input: &str, rng: &mut StdRng) -> Vec<Edit> {
|
||||
let mut parser = EditParser::new();
|
||||
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
|
||||
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
|
||||
chunk_indices.sort();
|
||||
|
||||
let mut edits = Vec::new();
|
||||
let mut last_ix = 0;
|
||||
for chunk_ix in chunk_indices {
|
||||
edits.extend(parser.push(&input[last_ix..chunk_ix]));
|
||||
last_ix = chunk_ix;
|
||||
}
|
||||
edits.extend(parser.push(&input[last_ix..]));
|
||||
edits
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,39 @@
|
||||
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use crate::{Templates, edit_agent::EditAgent, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
|
||||
};
|
||||
use language::{
|
||||
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
|
||||
};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
|
||||
use crate::replace::replace_exact;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use ui::{Disclosure, Tooltip, Window, prelude::*};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct EditFileToolInput {
|
||||
/// A user-friendly markdown description of the edit. This will be shown in the UI.
|
||||
///
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
///
|
||||
/// Make sure to include this field before all the others in the input object
|
||||
/// so that we can display it immediately.
|
||||
pub display_description: String,
|
||||
|
||||
/// The full path of the file to modify in the project.
|
||||
///
|
||||
/// WARNING: When specifying which file path need changing, you MUST
|
||||
@@ -34,21 +55,49 @@ pub struct EditFileToolInput {
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
|
||||
/// Edit instructions that will be interpreted by a less intelligent model,
|
||||
/// which will quickly apply the edits. You should make it clear what the
|
||||
/// edits are, while also minimizing the unchanged code you write. The model
|
||||
/// does not have access to this conversation, so you must make sure the
|
||||
/// instructions are self-contained and do not rely on external context.
|
||||
///
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
pub display_description: String,
|
||||
/// Insert `// ... existing code ...` comments in your output to represent
|
||||
/// unchanged code ABOVE, BELOW, and IN BETWEEN edited lines.
|
||||
///
|
||||
/// Bias towards repeating as few lines of the original file as possible to
|
||||
/// convey the change. However, each edit should contain sufficient context
|
||||
/// of unchanged lines to resolve ambiguity. When you want to delete a piece
|
||||
/// of code, indicate a few lines above and below the code you want to
|
||||
/// delete (surrounded by `// ... existing code ...`).
|
||||
///
|
||||
/// Never forget to include `// ... existing code ...` comments to represent
|
||||
/// unchanged lines, otherwise the small model may not understand the
|
||||
/// context of your edit and will delete important code!
|
||||
///
|
||||
/// <your_output>
|
||||
/// // ... existing code ...
|
||||
/// FIRST_EDIT
|
||||
/// // ... existing code ...
|
||||
/// SECOND_EDIT
|
||||
/// // ... existing code ...
|
||||
/// THIRD_EDIT
|
||||
/// // ... existing code ...
|
||||
/// </your_output>
|
||||
pub edit_instructions: String,
|
||||
}
|
||||
|
||||
/// The text to replace.
|
||||
pub old_string: String,
|
||||
|
||||
/// The text to replace it with.
|
||||
pub new_string: String,
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
struct PartialInput {
|
||||
#[serde(default)]
|
||||
path: String,
|
||||
#[serde(default)]
|
||||
display_description: String,
|
||||
}
|
||||
|
||||
pub struct EditFileTool;
|
||||
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
|
||||
impl Tool for EditFileTool {
|
||||
fn name(&self) -> String {
|
||||
"edit_file".into()
|
||||
@@ -73,16 +122,33 @@ impl Tool for EditFileTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Edit file".to_string(),
|
||||
Err(_) => "Editing file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
|
||||
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
|
||||
let description = input.display_description.trim();
|
||||
if !description.is_empty() {
|
||||
return description.to_string();
|
||||
}
|
||||
|
||||
let path = input.path.trim();
|
||||
if !path.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
DEFAULT_UI_TEXT.to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<EditFileToolInput>(input) {
|
||||
@@ -90,94 +156,516 @@ impl Tool for EditFileTool {
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&input.path, cx)
|
||||
.context("Path not found in project")
|
||||
})??;
|
||||
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Path {} not found in project",
|
||||
input.path.display()
|
||||
)))
|
||||
.into();
|
||||
};
|
||||
let Some(worktree) = project
|
||||
.read(cx)
|
||||
.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
|
||||
};
|
||||
let exists = worktree.update(cx, |worktree, cx| {
|
||||
worktree.file_exists(&project_path.path, cx)
|
||||
});
|
||||
|
||||
let card = window.and_then(|window| {
|
||||
window
|
||||
.update(cx, |_, window, cx| {
|
||||
cx.new(|cx| {
|
||||
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
let card_clone = card.clone();
|
||||
// todo!("read model from settings...")
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "gemini-2.0-flash")
|
||||
.unwrap();
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
// todo!("reuse templates")
|
||||
let edit_agent = EditAgent::new(model, action_log, Templates::new());
|
||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
authenticated.await?;
|
||||
if !exists.await? {
|
||||
return Err(anyhow!("{} not found", input.path.display()));
|
||||
}
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
edit_agent
|
||||
.edit(buffer.clone(), input.edit_instructions.clone(), cx)
|
||||
.await?;
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
|
||||
.await?;
|
||||
|
||||
if input.old_string.is_empty() {
|
||||
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
|
||||
}
|
||||
let old_text = cx.background_spawn({
|
||||
let old_snapshot = old_snapshot.clone();
|
||||
async move { old_snapshot.text() }
|
||||
});
|
||||
let new_text = cx.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
async move { new_snapshot.text() }
|
||||
});
|
||||
let diff = cx.background_spawn(async move {
|
||||
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
|
||||
});
|
||||
let (old_text, new_text, diff) = futures::join!(old_text, new_text, diff);
|
||||
|
||||
if input.old_string == input.new_string {
|
||||
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
|
||||
}
|
||||
|
||||
let result = cx
|
||||
.background_spawn(async move {
|
||||
// Try to match exactly
|
||||
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
|
||||
.await
|
||||
// If that fails, try being flexible about indentation
|
||||
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
|
||||
|
||||
if diff.edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let old_text = snapshot.text();
|
||||
|
||||
Some((old_text, diff))
|
||||
if let Some(card) = card_clone {
|
||||
card.update(cx, |card, cx| {
|
||||
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
|
||||
})
|
||||
.await;
|
||||
.log_err();
|
||||
}
|
||||
|
||||
let Some((old_text, diff)) = result else {
|
||||
let err = buffer.read_with(cx, |buffer, _cx| {
|
||||
let file_exists = buffer
|
||||
.file()
|
||||
.map_or(false, |file| file.disk_state().exists());
|
||||
Ok(format!(
|
||||
"Edited {}:\n\n```diff\n{}\n```",
|
||||
input.path.display(),
|
||||
diff
|
||||
))
|
||||
});
|
||||
|
||||
if !file_exists {
|
||||
anyhow!("{} does not exist", input.path.display())
|
||||
} else if buffer.is_empty() {
|
||||
anyhow!(
|
||||
"{} is empty, so the provided `old_string` wasn't found.",
|
||||
input.path.display()
|
||||
)
|
||||
} else {
|
||||
anyhow!("Failed to match the provided `old_string`")
|
||||
}
|
||||
})?;
|
||||
|
||||
return Err(err)
|
||||
};
|
||||
|
||||
let snapshot = cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer.clone(), cx)
|
||||
});
|
||||
let snapshot = buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.apply_diff(diff, cx);
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.snapshot()
|
||||
});
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx)
|
||||
});
|
||||
snapshot
|
||||
})?;
|
||||
|
||||
project.update( cx, |project, cx| {
|
||||
project.save_buffer(buffer, cx)
|
||||
})?.await?;
|
||||
|
||||
let diff_str = cx.background_spawn(async move {
|
||||
let new_text = snapshot.text();
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
}).await;
|
||||
|
||||
|
||||
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
|
||||
|
||||
}).into()
|
||||
ToolResult {
|
||||
output: task,
|
||||
card: card.map(AnyToolCard::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EditFileToolCard {
|
||||
path: PathBuf,
|
||||
editor: Entity<Editor>,
|
||||
multibuffer: Entity<MultiBuffer>,
|
||||
project: Entity<Project>,
|
||||
diff_task: Option<Task<Result<()>>>,
|
||||
preview_expanded: bool,
|
||||
full_height_expanded: bool,
|
||||
editor_unique_id: EntityId,
|
||||
}
|
||||
|
||||
impl EditFileToolCard {
|
||||
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
|
||||
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
Some(project.clone()),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_scrollbars(false, cx);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_scrolling(cx);
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor
|
||||
});
|
||||
Self {
|
||||
editor_unique_id: editor.entity_id(),
|
||||
path,
|
||||
project,
|
||||
editor,
|
||||
multibuffer,
|
||||
diff_task: None,
|
||||
preview_expanded: true,
|
||||
full_height_expanded: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_diff(
|
||||
&mut self,
|
||||
path: Arc<Path>,
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
self.diff_task = Some(cx.spawn(async move |this, cx| {
|
||||
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
|
||||
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.multibuffer.update(cx, |multibuffer, cx| {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let diff = buffer_diff.read(cx);
|
||||
let diff_hunk_ranges = diff
|
||||
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
|
||||
.collect::<Vec<_>>();
|
||||
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&buffer, cx),
|
||||
buffer,
|
||||
diff_hunk_ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
debug_assert!(is_newly_added);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
});
|
||||
cx.notify();
|
||||
})
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for EditFileToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let failed = matches!(status, ToolUseStatus::Error(_));
|
||||
|
||||
let path_label_button = h_flex()
|
||||
.id(("edit-tool-path-label-button", self.editor_unique_id))
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.px_1()
|
||||
.gap_0p5()
|
||||
.cursor_pointer()
|
||||
.rounded_sm()
|
||||
.opacity(0.8)
|
||||
.hover(|label| {
|
||||
label
|
||||
.opacity(1.)
|
||||
.bg(cx.theme().colors().element_hover.opacity(0.5))
|
||||
})
|
||||
.tooltip(Tooltip::text("Jump to File"))
|
||||
.child(
|
||||
h_flex()
|
||||
.child(
|
||||
Icon::new(IconName::Pencil)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.text_size(rems(0.8125))
|
||||
.child(self.path.display().to_string())
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::ArrowUpRight)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Ignored),
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let path = self.path.clone();
|
||||
let workspace = workspace.clone();
|
||||
move |_, window, cx| {
|
||||
workspace
|
||||
.update(cx, {
|
||||
|workspace, cx| {
|
||||
let Some(project_path) =
|
||||
workspace.project().read(cx).find_project_path(&path, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let open_task =
|
||||
workspace.open_path(project_path, None, true, window, cx);
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let item = open_task.await?;
|
||||
if let Some(active_editor) = item.downcast::<Editor>() {
|
||||
active_editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(
|
||||
language::Point::new(0, 0),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.into_any_element();
|
||||
|
||||
let codeblock_header_bg = cx
|
||||
.theme()
|
||||
.colors()
|
||||
.element_background
|
||||
.blend(cx.theme().colors().editor_foreground.opacity(0.025));
|
||||
|
||||
let codeblock_header = h_flex()
|
||||
.flex_none()
|
||||
.p_1()
|
||||
.gap_1()
|
||||
.justify_between()
|
||||
.rounded_t_md()
|
||||
.when(!failed, |header| header.bg(codeblock_header_bg))
|
||||
.child(path_label_button)
|
||||
.map(|container| {
|
||||
if failed {
|
||||
container.child(
|
||||
Icon::new(IconName::Close)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error),
|
||||
)
|
||||
} else {
|
||||
container.child(
|
||||
Disclosure::new(
|
||||
("edit-file-disclosure", self.editor_unique_id),
|
||||
self.preview_expanded,
|
||||
)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener(
|
||||
move |this, _event, _window, _cx| {
|
||||
this.preview_expanded = !this.preview_expanded;
|
||||
},
|
||||
)),
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
let editor = self.editor.update(cx, |editor, cx| {
|
||||
editor.render(window, cx).into_any_element()
|
||||
});
|
||||
|
||||
let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
|
||||
(IconName::ChevronUp, "Collapse Code Block")
|
||||
} else {
|
||||
(IconName::ChevronDown, "Expand Code Block")
|
||||
};
|
||||
|
||||
let gradient_overlay = div()
|
||||
.absolute()
|
||||
.bottom_0()
|
||||
.left_0()
|
||||
.w_full()
|
||||
.h_2_5()
|
||||
.rounded_b_lg()
|
||||
.bg(gpui::linear_gradient(
|
||||
0.,
|
||||
gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
|
||||
gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
|
||||
));
|
||||
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
|
||||
v_flex()
|
||||
.mb_2()
|
||||
.border_1()
|
||||
.when(failed, |card| card.border_dashed())
|
||||
.border_color(border_color)
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(codeblock_header)
|
||||
.when(!failed && self.preview_expanded, |card| {
|
||||
card.child(
|
||||
v_flex()
|
||||
.relative()
|
||||
.overflow_hidden()
|
||||
.border_t_1()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.map(|editor_container| {
|
||||
if self.full_height_expanded {
|
||||
editor_container.h_full()
|
||||
} else {
|
||||
editor_container.max_h_64()
|
||||
}
|
||||
})
|
||||
.child(div().pl_1().child(editor))
|
||||
.when(!self.full_height_expanded, |editor_container| {
|
||||
editor_container.child(gradient_overlay)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when(!failed && self.preview_expanded, |card| {
|
||||
card.child(
|
||||
h_flex()
|
||||
.id(("edit-tool-card-inner-hflex", self.editor_unique_id))
|
||||
.flex_none()
|
||||
.cursor_pointer()
|
||||
.h_5()
|
||||
.justify_center()
|
||||
.rounded_b_md()
|
||||
.border_t_1()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
|
||||
.child(
|
||||
Icon::new(full_height_icon)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.tooltip(Tooltip::text(full_height_tooltip_label))
|
||||
.on_click(cx.listener(move |this, _event, _window, _cx| {
|
||||
this.full_height_expanded = !this.full_height_expanded;
|
||||
})),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_buffer(
|
||||
mut text: String,
|
||||
path: Arc<Path>,
|
||||
language_registry: &Arc<language::LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Buffer>> {
|
||||
let line_ending = LineEnding::detect(&text);
|
||||
LineEnding::normalize(&mut text);
|
||||
let text = Rope::from(text);
|
||||
let language = cx
|
||||
.update(|_cx| language_registry.language_for_file_path(&path))?
|
||||
.await
|
||||
.ok();
|
||||
let buffer = cx.new(|cx| {
|
||||
let buffer = TextBuffer::new_normalized(
|
||||
0,
|
||||
cx.entity_id().as_non_zero_u64().into(),
|
||||
line_ending,
|
||||
text,
|
||||
);
|
||||
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
|
||||
buffer.set_language(language, cx);
|
||||
buffer
|
||||
})?;
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
async fn build_buffer_diff(
|
||||
mut old_text: String,
|
||||
buffer: &Entity<Buffer>,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<BufferDiff>> {
|
||||
LineEnding::normalize(&mut old_text);
|
||||
|
||||
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
|
||||
let base_buffer = cx
|
||||
.update(|cx| {
|
||||
Buffer::build_snapshot(
|
||||
old_text.clone().into(),
|
||||
buffer.language().cloned(),
|
||||
Some(language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let diff_snapshot = cx
|
||||
.update(|cx| {
|
||||
BufferDiffSnapshot::new_with_base_buffer(
|
||||
buffer.text.clone(),
|
||||
Some(old_text.into()),
|
||||
base_buffer,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&buffer.text, cx);
|
||||
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
|
||||
diff
|
||||
})
|
||||
}
|
||||
|
||||
// todo!("add unit tests for failure modes of edit, like file not found, etc.")
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path() {
|
||||
let tool = EditFileTool;
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_description() {
|
||||
let tool = EditFileTool;
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_path_and_description() {
|
||||
let tool = EditFileTool;
|
||||
let input = json!({
|
||||
"path": "src/main.rs",
|
||||
"display_description": "Fix error handling",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_no_path_or_description() {
|
||||
let tool = EditFileTool;
|
||||
let input = json!({
|
||||
"path": "",
|
||||
"display_description": "",
|
||||
"old_string": "old code",
|
||||
"new_string": "new code"
|
||||
});
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn still_streaming_ui_text_with_null() {
|
||||
let tool = EditFileTool;
|
||||
let input = serde_json::Value::Null;
|
||||
|
||||
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,39 +7,4 @@ Before using this tool:
|
||||
2. Verify the directory path is correct (only applicable when creating new files):
|
||||
- Use the `list_directory` tool to verify the parent directory exists and is the correct location
|
||||
|
||||
To make a file edit, provide the following:
|
||||
1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project.
|
||||
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
|
||||
3. new_string: The edited text, which will replace the old_string in the file.
|
||||
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file.
|
||||
|
||||
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
|
||||
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
- Include AT LEAST 3-5 lines of context BEFORE the change point
|
||||
- Include AT LEAST 3-5 lines of context AFTER the change point
|
||||
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
|
||||
|
||||
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Make separate calls to this tool for each instance
|
||||
- Each call must uniquely identify its specific instance using extensive context
|
||||
|
||||
3. VERIFICATION: Before using this tool:
|
||||
- Check how many instances of the target text exist in the file
|
||||
- If multiple instances exist, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance
|
||||
|
||||
WARNING: If you do not follow these requirements:
|
||||
- The tool will fail if old_string matches multiple locations
|
||||
- The tool will fail if old_string doesn't match exactly (including whitespace)
|
||||
- You may change the wrong instance if you don't include enough context
|
||||
|
||||
When making edits:
|
||||
- Ensure the edit results in idiomatic, correct code
|
||||
- Do not leave the code in a broken state
|
||||
- Always use fully-qualified project paths (starting with the name of one of the project's root directories)
|
||||
|
||||
If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`.
|
||||
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
||||
Group coherent edits together and include all of them in a single call to this tool. Add the full context needed for a small model to understand the edits.
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -145,6 +145,7 @@ impl Tool for FetchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FetchToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -12,7 +12,7 @@ use util::paths::PathMatcher;
|
||||
use worktree::Snapshot;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PathSearchToolInput {
|
||||
pub struct FindPathToolInput {
|
||||
/// The glob to match against every path in the project.
|
||||
///
|
||||
/// <example>
|
||||
@@ -34,11 +34,11 @@ pub struct PathSearchToolInput {
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
||||
pub struct PathSearchTool;
|
||||
pub struct FindPathTool;
|
||||
|
||||
impl Tool for PathSearchTool {
|
||||
impl Tool for FindPathTool {
|
||||
fn name(&self) -> String {
|
||||
"path_search".into()
|
||||
"find_path".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
@@ -46,7 +46,7 @@ impl Tool for PathSearchTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
include_str!("./path_search_tool/description.md").into()
|
||||
include_str!("./find_path_tool/description.md").into()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -54,11 +54,11 @@ impl Tool for PathSearchTool {
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<PathSearchToolInput>(format)
|
||||
json_schema_for::<FindPathToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
|
||||
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
|
||||
Err(_) => "Search paths".to_string(),
|
||||
}
|
||||
@@ -70,9 +70,10 @@ impl Tool for PathSearchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
|
||||
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
|
||||
Ok(input) => (input.offset, input.glob),
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
@@ -143,7 +144,7 @@ mod test {
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_path_search_tool(cx: &mut TestAppContext) {
|
||||
async fn test_find_path_tool(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -1,4 +1,4 @@
|
||||
Fast file pattern matching tool that works with any codebase size
|
||||
Fast file path pattern matching tool that works with any codebase size
|
||||
|
||||
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
- Returns matching file paths sorted alphabetically
|
||||
@@ -1,268 +0,0 @@
|
||||
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
|
||||
use crate::replace::replace_exact;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FindReplaceFileToolInput {
|
||||
/// The path of the file to modify.
|
||||
///
|
||||
/// WARNING: When specifying which file path need changing, you MUST
|
||||
/// start each path with one of the project's root directories.
|
||||
///
|
||||
/// The following examples assume we have two root directories in the project:
|
||||
/// - backend
|
||||
/// - frontend
|
||||
///
|
||||
/// <example>
|
||||
/// `backend/src/main.rs`
|
||||
///
|
||||
/// Notice how the file path starts with root-1. Without that, the path
|
||||
/// would be ambiguous and the call would fail!
|
||||
/// </example>
|
||||
///
|
||||
/// <example>
|
||||
/// `frontend/db.js`
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
|
||||
///
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
pub display_description: String,
|
||||
|
||||
/// The unique string to find in the file. This string cannot be empty;
|
||||
/// if the string is empty, the tool call will fail. Remember, do not use this tool
|
||||
/// to create new files from scratch, or to overwrite existing files! Use a different
|
||||
/// approach if you want to do that.
|
||||
///
|
||||
/// If this string appears more than once in the file, this tool call will fail,
|
||||
/// so it is absolutely critical that you verify ahead of time that the string
|
||||
/// is unique. You can search within the file to verify this.
|
||||
///
|
||||
/// To make the string more likely to be unique, include a minimum of 3 lines of context
|
||||
/// before the string you actually want to find, as well as a minimum of 3 lines of
|
||||
/// context after the string you want to find. (These lines of context should appear
|
||||
/// in the `replace` string as well.) If 3 lines of context is not enough to obtain
|
||||
/// a string that appears only once in the file, then double the number of context lines
|
||||
/// until the string becomes unique. (Start with 3 lines before and 3 lines after
|
||||
/// though, because too much context is needlessly costly.)
|
||||
///
|
||||
/// Do not alter the context lines of code in any way, and make sure to preserve all
|
||||
/// whitespace and indentation for all lines of code. This string must be exactly as
|
||||
/// it appears in the file, because this tool will do a literal find/replace, and if
|
||||
/// even one character in this string is different in any way from how it appears
|
||||
/// in the file, then the tool call will fail.
|
||||
///
|
||||
/// If you get an error that the `find` string was not found, this means that either
|
||||
/// you made a mistake, or that the file has changed since you last looked at it.
|
||||
/// Either way, when this happens, you should retry doing this tool call until it
|
||||
/// succeeds, up to 3 times. Each time you retry, you should take another look at
|
||||
/// the exact text of the file in question, to make sure that you are searching for
|
||||
/// exactly the right string. Regardless of whether it was because you made a mistake
|
||||
/// or because the file changed since you last looked at it, you should be extra
|
||||
/// careful when retrying in this way. It's a bad experience for the user if
|
||||
/// this `find` string isn't found, so be super careful to get it exactly right!
|
||||
///
|
||||
/// <example>
|
||||
/// If a file contains this code:
|
||||
///
|
||||
/// ```ignore
|
||||
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
|
||||
/// // Check if user exists first
|
||||
/// let user = database.find_user(user_id)?;
|
||||
///
|
||||
/// // This is the part we want to modify
|
||||
/// if user.role == "admin" {
|
||||
/// return Ok(true);
|
||||
/// }
|
||||
///
|
||||
/// // Check other permissions
|
||||
/// check_custom_permissions(user_id)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Your find string should include at least 3 lines of context before and after the part
|
||||
/// you want to change:
|
||||
///
|
||||
/// ```ignore
|
||||
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
|
||||
/// // Check if user exists first
|
||||
/// let user = database.find_user(user_id)?;
|
||||
///
|
||||
/// // This is the part we want to modify
|
||||
/// if user.role == "admin" {
|
||||
/// return Ok(true);
|
||||
/// }
|
||||
///
|
||||
/// // Check other permissions
|
||||
/// check_custom_permissions(user_id)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// And your replace string might look like:
|
||||
///
|
||||
/// ```ignore
|
||||
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
|
||||
/// // Check if user exists first
|
||||
/// let user = database.find_user(user_id)?;
|
||||
///
|
||||
/// // This is the part we want to modify
|
||||
/// if user.role == "admin" || user.role == "superuser" {
|
||||
/// return Ok(true);
|
||||
/// }
|
||||
///
|
||||
/// // Check other permissions
|
||||
/// check_custom_permissions(user_id)
|
||||
/// }
|
||||
/// ```
|
||||
/// </example>
|
||||
pub find: String,
|
||||
|
||||
/// The string to replace the one unique occurrence of the find string with.
|
||||
pub replace: String,
|
||||
}
|
||||
|
||||
pub struct FindReplaceFileTool;
|
||||
|
||||
impl Tool for FindReplaceFileTool {
|
||||
fn name(&self) -> String {
|
||||
"find_replace_file".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
include_str!("find_replace_tool/description.md").to_string()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Pencil
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<FindReplaceFileToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<FindReplaceFileToolInput>(input.clone()) {
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Edit file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&input.path, cx)
|
||||
.context("Path not found in project")
|
||||
})??;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
if input.find.is_empty() {
|
||||
return Err(anyhow!("`find` string cannot be empty. Use a different tool if you want to create a file."));
|
||||
}
|
||||
|
||||
if input.find == input.replace {
|
||||
return Err(anyhow!("The `find` and `replace` strings are identical, so no changes would be made."));
|
||||
}
|
||||
|
||||
let result = cx
|
||||
.background_spawn(async move {
|
||||
// Try to match exactly
|
||||
let diff = replace_exact(&input.find, &input.replace, &snapshot)
|
||||
.await
|
||||
// If that fails, try being flexible about indentation
|
||||
.or_else(|| replace_with_flexible_indent(&input.find, &input.replace, &snapshot))?;
|
||||
|
||||
if diff.edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let old_text = snapshot.text();
|
||||
|
||||
Some((old_text, diff))
|
||||
})
|
||||
.await;
|
||||
|
||||
let Some((old_text, diff)) = result else {
|
||||
let err = buffer.read_with(cx, |buffer, _cx| {
|
||||
let file_exists = buffer
|
||||
.file()
|
||||
.map_or(false, |file| file.disk_state().exists());
|
||||
|
||||
if !file_exists {
|
||||
anyhow!("{} does not exist", input.path.display())
|
||||
} else if buffer.is_empty() {
|
||||
anyhow!(
|
||||
"{} is empty, so the provided `find` string wasn't found.",
|
||||
input.path.display()
|
||||
)
|
||||
} else {
|
||||
anyhow!("Failed to match the provided `find` string")
|
||||
}
|
||||
})?;
|
||||
|
||||
return Err(err)
|
||||
};
|
||||
|
||||
let snapshot = cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer.clone(), cx)
|
||||
});
|
||||
let snapshot = buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.apply_diff(diff, cx);
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.snapshot()
|
||||
});
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx)
|
||||
});
|
||||
snapshot
|
||||
})?;
|
||||
|
||||
project.update( cx, |project, cx| {
|
||||
project.save_buffer(buffer, cx)
|
||||
})?.await?;
|
||||
|
||||
let diff_str = cx.background_spawn(async move {
|
||||
let new_text = snapshot.text();
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
}).await;
|
||||
|
||||
|
||||
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
|
||||
|
||||
}).into()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,378 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
handle_command_output(output)
|
||||
}
|
||||
|
||||
fn handle_command_output(output: std::process::Output) -> Result<String> {
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::OffsetRangeExt;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{
|
||||
@@ -20,6 +20,8 @@ use util::paths::PathMatcher;
|
||||
pub struct GrepToolInput {
|
||||
/// A regex pattern to search for in the entire project. Note that the regex
|
||||
/// will be parsed by the Rust `regex` crate.
|
||||
///
|
||||
/// Do NOT specify a path here! This will only be matched against the code **content**.
|
||||
pub regex: String,
|
||||
|
||||
/// A glob pattern for the paths of files to include in the search.
|
||||
@@ -96,6 +98,7 @@ impl Tool for GrepTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
@@ -405,7 +408,7 @@ mod tests {
|
||||
) -> String {
|
||||
let tool = Arc::new(GrepTool);
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||
let task = cx.update(|cx| tool.run(input, &[], project, action_log, cx));
|
||||
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
|
||||
|
||||
match task.output.await {
|
||||
Ok(result) => result,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -76,6 +76,7 @@ impl Tool for ListDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1 +1 @@
|
||||
Lists files and directories in a given path. Prefer the `grep` or `path_search` tools when searching the codebase.
|
||||
Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -89,6 +89,7 @@ impl Tool for MovePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<MovePathToolInput>(input) {
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -59,6 +59,7 @@ impl Tool for NowTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: NowToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -52,6 +52,7 @@ impl Tool for OpenTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: OpenToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -87,6 +88,7 @@ impl Tool for ReadFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
|
||||
@@ -134,7 +136,7 @@ impl Tool for ReadFileTool {
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.track_buffer(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -147,7 +149,7 @@ impl Tool for ReadFileTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.track_buffer(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -193,7 +195,7 @@ mod test {
|
||||
"path": "root/nonexistent_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -223,7 +225,7 @@ mod test {
|
||||
"path": "root/small_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -253,7 +255,7 @@ mod test {
|
||||
"path": "root/large_file.rs"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log.clone(), cx)
|
||||
.run(input, &[], project.clone(), action_log.clone(), None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -277,7 +279,7 @@ mod test {
|
||||
"offset": 1
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -323,7 +325,7 @@ mod test {
|
||||
"end_line": 4
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{self, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -87,6 +87,7 @@ impl Tool for RenameTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<RenameToolInput>(input) {
|
||||
@@ -106,7 +107,7 @@ impl Tool for RenameTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let position = {
|
||||
|
||||
@@ -1,872 +0,0 @@
|
||||
use language::{BufferSnapshot, Diff, Point, ToOffset};
|
||||
use project::search::SearchQuery;
|
||||
use std::iter;
|
||||
use util::{ResultExt as _, paths::PathMatcher};
|
||||
|
||||
/// Performs an exact string replacement in a buffer, requiring precise character-for-character matching.
|
||||
/// Uses the search functionality to locate the first occurrence of the exact string.
|
||||
/// Returns None if no exact match is found in the buffer.
|
||||
pub async fn replace_exact(old: &str, new: &str, snapshot: &BufferSnapshot) -> Option<Diff> {
|
||||
let query = SearchQuery::text(
|
||||
old,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
PathMatcher::new(iter::empty::<&str>()).ok()?,
|
||||
PathMatcher::new(iter::empty::<&str>()).ok()?,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.log_err()?;
|
||||
|
||||
let matches = query.search(&snapshot, None).await;
|
||||
|
||||
if matches.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let edit_range = matches[0].clone();
|
||||
let diff = language::text_diff(&old, &new);
|
||||
|
||||
let edits = diff
|
||||
.into_iter()
|
||||
.map(|(old_range, text)| {
|
||||
let start = edit_range.start + old_range.start;
|
||||
let end = edit_range.start + old_range.end;
|
||||
(start..end, text)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let diff = language::Diff {
|
||||
base_version: snapshot.version().clone(),
|
||||
line_ending: snapshot.line_ending(),
|
||||
edits,
|
||||
};
|
||||
|
||||
Some(diff)
|
||||
}
|
||||
|
||||
/// Performs a replacement that's indentation-aware - matches text content ignoring leading whitespace differences.
|
||||
/// When replacing, preserves the indentation level found in the buffer at each matching line.
|
||||
/// Returns None if no match found or if indentation is offset inconsistently across matched lines.
|
||||
pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapshot) -> Option<Diff> {
|
||||
let (old_lines, old_min_indent) = lines_with_min_indent(old);
|
||||
let (new_lines, new_min_indent) = lines_with_min_indent(new);
|
||||
let min_indent = old_min_indent.min(new_min_indent);
|
||||
|
||||
let old_lines = drop_lines_prefix(&old_lines, min_indent);
|
||||
let new_lines = drop_lines_prefix(&new_lines, min_indent);
|
||||
|
||||
let max_row = buffer.max_point().row;
|
||||
|
||||
'windows: for start_row in 0..max_row + 1 {
|
||||
let end_row = start_row + old_lines.len().saturating_sub(1) as u32;
|
||||
|
||||
if end_row > max_row {
|
||||
// The buffer ends before fully matching the pattern
|
||||
return None;
|
||||
}
|
||||
|
||||
let start_point = Point::new(start_row, 0);
|
||||
let end_point = Point::new(end_row, buffer.line_len(end_row));
|
||||
let range = start_point.to_offset(buffer)..end_point.to_offset(buffer);
|
||||
|
||||
let window_text = buffer.text_for_range(range.clone());
|
||||
let mut window_lines = window_text.lines();
|
||||
let mut old_lines_iter = old_lines.iter();
|
||||
|
||||
let mut common_mismatch = None;
|
||||
|
||||
#[derive(Eq, PartialEq)]
|
||||
enum Mismatch {
|
||||
OverIndented(String),
|
||||
UnderIndented(String),
|
||||
}
|
||||
|
||||
while let (Some(window_line), Some(old_line)) = (window_lines.next(), old_lines_iter.next())
|
||||
{
|
||||
let line_trimmed = window_line.trim_start();
|
||||
|
||||
if line_trimmed != old_line.trim_start() {
|
||||
continue 'windows;
|
||||
}
|
||||
|
||||
if line_trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let line_mismatch = if window_line.len() > old_line.len() {
|
||||
let prefix = window_line[..window_line.len() - old_line.len()].to_string();
|
||||
Mismatch::UnderIndented(prefix)
|
||||
} else {
|
||||
let prefix = old_line[..old_line.len() - window_line.len()].to_string();
|
||||
Mismatch::OverIndented(prefix)
|
||||
};
|
||||
|
||||
match &common_mismatch {
|
||||
Some(common_mismatch) if common_mismatch != &line_mismatch => {
|
||||
continue 'windows;
|
||||
}
|
||||
Some(_) => (),
|
||||
None => common_mismatch = Some(line_mismatch),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(common_mismatch) = &common_mismatch {
|
||||
let line_ending = buffer.line_ending();
|
||||
let replacement = new_lines
|
||||
.iter()
|
||||
.map(|new_line| {
|
||||
if new_line.trim().is_empty() {
|
||||
new_line.to_string()
|
||||
} else {
|
||||
match common_mismatch {
|
||||
Mismatch::UnderIndented(prefix) => prefix.to_string() + new_line,
|
||||
Mismatch::OverIndented(prefix) => new_line
|
||||
.strip_prefix(prefix)
|
||||
.unwrap_or(new_line)
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(line_ending.as_str());
|
||||
|
||||
let diff = Diff {
|
||||
base_version: buffer.version().clone(),
|
||||
line_ending,
|
||||
edits: vec![(range, replacement.into())],
|
||||
};
|
||||
|
||||
return Some(diff);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn drop_lines_prefix<'a>(lines: &'a [&str], prefix_len: usize) -> Vec<&'a str> {
|
||||
lines
|
||||
.iter()
|
||||
.map(|line| line.get(prefix_len..).unwrap_or(""))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn lines_with_min_indent(input: &str) -> (Vec<&str>, usize) {
|
||||
let mut lines = Vec::new();
|
||||
let mut min_indent: Option<usize> = None;
|
||||
|
||||
for line in input.lines() {
|
||||
lines.push(line);
|
||||
if !line.trim().is_empty() {
|
||||
let indent = line.len() - line.trim_start().len();
|
||||
min_indent = Some(min_indent.map_or(indent, |m| m.min(indent)));
|
||||
}
|
||||
}
|
||||
|
||||
(lines, min_indent.unwrap_or(0))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod replace_exact_tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use gpui::prelude::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn basic(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, Some("let x = 42;".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn no_match(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let y = 42;", "let y = 43;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn multi_line(cx: &mut TestAppContext) {
|
||||
let whole = "fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}";
|
||||
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
|
||||
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
|
||||
let result = test_replace_exact(cx, whole, old_text, new_text).await;
|
||||
assert_eq!(
|
||||
result,
|
||||
Some("fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn multiple_occurrences(cx: &mut TestAppContext) {
|
||||
let whole = "let x = 41;\nlet y = 41;\nlet z = 41;";
|
||||
let result = test_replace_exact(cx, whole, "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(
|
||||
result,
|
||||
Some("let x = 42;\nlet y = 41;\nlet z = 41;".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn empty_buffer(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn partial_match(cx: &mut TestAppContext) {
|
||||
let whole = "let x = 41; let y = 42;";
|
||||
let result = test_replace_exact(cx, whole, "let x = 41", "let x = 42").await;
|
||||
assert_eq!(result, Some("let x = 42; let y = 42;".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn whitespace_sensitive(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", " let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn entire_buffer(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, Some("let x = 42;".to_string()));
|
||||
}
|
||||
|
||||
async fn test_replace_exact(
|
||||
cx: &mut TestAppContext,
|
||||
whole: &str,
|
||||
old: &str,
|
||||
new: &str,
|
||||
) -> Option<String> {
|
||||
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
|
||||
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact(old, new, &buffer_snapshot).await;
|
||||
diff.map(|diff| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod flexible_indent_tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use gpui::prelude::*;
|
||||
use unindent::Unindent;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_underindented_single_line(cx: &mut TestAppContext) {
|
||||
let cur = " let a = 41;".to_string();
|
||||
let old = " let a = 41;".to_string();
|
||||
let new = " let a = 42;".to_string();
|
||||
let exp = " let a = 42;".to_string();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(exp.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_overindented_single_line(cx: &mut TestAppContext) {
|
||||
let cur = " let a = 41;".to_string();
|
||||
let old = " let a = 41;".to_string();
|
||||
let new = " let a = 42;".to_string();
|
||||
let exp = " let a = 42;".to_string();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(exp.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_underindented_multi_line(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
let y = 10;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 42;
|
||||
println!("New value: {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
let x = 42;
|
||||
println!("New value: {}", x);
|
||||
let y = 10;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_overindented_multi_line(cx: &mut TestAppContext) {
|
||||
let cur = r#"
|
||||
fn foo() {
|
||||
let a = 41;
|
||||
let b = 3.13;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
// 6 space indent instead of 4
|
||||
let old = " let a = 41;\n let b = 3.13;";
|
||||
let new = " let a = 42;\n let b = 3.14;";
|
||||
|
||||
let expected = r#"
|
||||
fn foo() {
|
||||
let a = 42;
|
||||
let b = 3.14;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(expected.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_inconsistent_indentation(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
if condition {
|
||||
println!("{}", 43);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
if condition {
|
||||
println!("{}", 43);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
if condition {
|
||||
println!("{}", 42);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_with_empty_lines(cx: &mut TestAppContext) {
|
||||
// Test with empty lines
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
|
||||
println!("x = {}", x);
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
|
||||
println!("New x: {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
let x = 10;
|
||||
|
||||
println!("New x: {}", x);
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_no_match(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let y = 20;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_ends_before_matching_old(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
// Should return None because whole doesn't fully contain the old text
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_is_shorter_than_old(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 5;
|
||||
let y = 20;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_old_is_empty(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = "";
|
||||
let new = r#"
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_is_empty(cx: &mut TestAppContext) {
|
||||
let whole = "";
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lines_with_min_indent() {
|
||||
// Empty string
|
||||
assert_eq!(lines_with_min_indent(""), (vec![], 0));
|
||||
|
||||
// Single line without indentation
|
||||
assert_eq!(lines_with_min_indent("hello"), (vec!["hello"], 0));
|
||||
|
||||
// Multiple lines with no indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent("line1\nline2\nline3"),
|
||||
(vec!["line1", "line2", "line3"], 0)
|
||||
);
|
||||
|
||||
// Multiple lines with consistent indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n line2\n line3"),
|
||||
(vec![" line1", " line2", " line3"], 2)
|
||||
);
|
||||
|
||||
// Multiple lines with varying indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n line2\n line3"),
|
||||
(vec![" line1", " line2", " line3"], 2)
|
||||
);
|
||||
|
||||
// Lines with mixed indentation and empty lines
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n\n line2"),
|
||||
(vec![" line1", "", " line2"], 2)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_with_missing_indent_uneven_match(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
if true {
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 42;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
if true {
|
||||
let x = 42;
|
||||
println!("x = {}", x);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_big_example(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_group_people_by_age() {
|
||||
let people = vec![
|
||||
Person::new("Young One", 5, "young@example.com").unwrap(),
|
||||
Person::new("Teen One", 15, "teen@example.com").unwrap(),
|
||||
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
|
||||
Person::new("Adult One", 25, "adult@example.com").unwrap(),
|
||||
];
|
||||
|
||||
let groups = group_people_by_age(&people);
|
||||
|
||||
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
|
||||
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
|
||||
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
let expected = r#"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_group_people_by_age() {
|
||||
let people = vec![
|
||||
Person::new("Young One", 5, "young@example.com").unwrap(),
|
||||
Person::new("Teen One", 15, "teen@example.com").unwrap(),
|
||||
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
|
||||
Person::new("Adult One", 25, "adult@example.com").unwrap(),
|
||||
];
|
||||
|
||||
let groups = group_people_by_age(&people);
|
||||
|
||||
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
|
||||
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
|
||||
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_drop_lines_prefix() {
|
||||
// Empty array
|
||||
assert_eq!(drop_lines_prefix(&[], 2), Vec::<&str>::new());
|
||||
|
||||
// Zero prefix length
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&["line1", "line2"], 0),
|
||||
vec!["line1", "line2"]
|
||||
);
|
||||
|
||||
// Normal prefix drop
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&[" line1", " line2"], 2),
|
||||
vec!["line1", "line2"]
|
||||
);
|
||||
|
||||
// Prefix longer than some lines
|
||||
assert_eq!(drop_lines_prefix(&[" line1", "a"], 2), vec!["line1", ""]);
|
||||
|
||||
// Prefix longer than all lines
|
||||
assert_eq!(drop_lines_prefix(&["a", "b"], 5), vec!["", ""]);
|
||||
|
||||
// Mixed length lines
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&[" line1", " line2", " line3"], 2),
|
||||
vec![" line1", "line2", " line3"]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_basic(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
assert_eq!(diff.edits.len(), 1);
|
||||
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_no_match(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let y = 42;", "let y = 43;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_multi_line(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| {
|
||||
language::Buffer::local(
|
||||
"fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}",
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
|
||||
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
|
||||
let diff = replace_exact(old_text, new_text, &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_multiple_occurrences(cx: &mut TestAppContext) {
|
||||
let buffer =
|
||||
cx.new(|cx| language::Buffer::local("let x = 41;\nlet y = 41;\nlet z = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Should replace only the first occurrence
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;\nlet y = 41;\nlet z = 41;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_empty_buffer(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_partial_match(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41; let y = 42;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Verify substring replacement actually works
|
||||
let diff = replace_exact("let x = 41", "let x = 42", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42; let y = 42;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_whitespace_sensitive(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact(" let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_entire_buffer(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;");
|
||||
}
|
||||
|
||||
fn test_replace_with_flexible_indent(
|
||||
cx: &mut TestAppContext,
|
||||
whole: &str,
|
||||
old: &str,
|
||||
new: &str,
|
||||
) -> Option<String> {
|
||||
// Create a local buffer with the test content
|
||||
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
|
||||
|
||||
// Get the buffer snapshot
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Call replace_flexible and transform the result
|
||||
replace_with_flexible_indent(old, new, &buffer_snapshot).map(|diff| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -121,6 +121,7 @@ impl Tool for SymbolInfoTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
|
||||
@@ -140,7 +141,7 @@ impl Tool for SymbolInfoTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let position = {
|
||||
|
||||
32
crates/assistant_tools/src/templates.rs
Normal file
32
crates/assistant_tools/src/templates.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
handlebars.register_escape_fn(|text| text.into());
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
46
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
46
crates/assistant_tools/src/templates/edit_agent.hbs
Normal file
@@ -0,0 +1,46 @@
|
||||
You are an expert text editor. Taking the following file as an input:
|
||||
|
||||
```{{path}}
|
||||
{{file_content}}
|
||||
```
|
||||
|
||||
Your response must be a series of edits in the following format:
|
||||
|
||||
<edits>
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 2 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 2 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 3 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 3 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
Rules for editing:
|
||||
|
||||
- `old_text` represents full lines (including indentation) in the input file that will be replaced with `new_text`
|
||||
- It is crucial that `old_text` is unique and unambiguous.
|
||||
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
|
||||
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
|
||||
- Don't explain why you made a change, just report the edits.
|
||||
- Never do MORE than what the user has requested.
|
||||
- Never do LESS than what the user has requested.
|
||||
|
||||
<user_instructions>
|
||||
{{instructions}}
|
||||
</user_instructions>
|
||||
|
||||
<edits>
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::io::BufReader;
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -78,6 +78,7 @@ impl Tool for TerminalTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: TerminalToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -50,6 +50,7 @@ impl Tool for ThinkingTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
// This tool just "thinks out loud" and doesn't perform any actions.
|
||||
|
||||
3
crates/assistant_tools/src/ui.rs
Normal file
3
crates/assistant_tools/src/ui.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod tool_call_card_header;
|
||||
|
||||
pub use tool_call_card_header::*;
|
||||
102
crates/assistant_tools/src/ui/tool_call_card_header.rs
Normal file
102
crates/assistant_tools/src/ui/tool_call_card_header.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use gpui::{Animation, AnimationExt, App, IntoElement, pulsating_between};
|
||||
use std::time::Duration;
|
||||
use ui::{Tooltip, prelude::*};
|
||||
|
||||
/// A reusable header component for tool call cards.
|
||||
#[derive(IntoElement)]
|
||||
pub struct ToolCallCardHeader {
|
||||
icon: IconName,
|
||||
primary_text: SharedString,
|
||||
secondary_text: Option<SharedString>,
|
||||
is_loading: bool,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
impl ToolCallCardHeader {
|
||||
pub fn new(icon: IconName, primary_text: impl Into<SharedString>) -> Self {
|
||||
Self {
|
||||
icon,
|
||||
primary_text: primary_text.into(),
|
||||
secondary_text: None,
|
||||
is_loading: false,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_secondary_text(mut self, text: impl Into<SharedString>) -> Self {
|
||||
self.secondary_text = Some(text.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn loading(mut self) -> Self {
|
||||
self.is_loading = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_error(mut self, error: impl Into<String>) -> Self {
|
||||
self.error = Some(error.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for ToolCallCardHeader {
|
||||
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let font_size = rems(0.8125);
|
||||
let secondary_text = self.secondary_text;
|
||||
|
||||
h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.opacity(0.8)
|
||||
.child(
|
||||
h_flex().h(window.line_height()).justify_center().child(
|
||||
Icon::new(self.icon)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.h(window.line_height())
|
||||
.gap_1p5()
|
||||
.text_size(font_size)
|
||||
.map(|this| {
|
||||
if let Some(error) = &self.error {
|
||||
this.child(format!("{} failed", self.primary_text)).child(
|
||||
IconButton::new("error_info", IconName::Warning)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Warning)
|
||||
.tooltip(Tooltip::text(error.clone())),
|
||||
)
|
||||
} else {
|
||||
this.child(self.primary_text.clone())
|
||||
}
|
||||
})
|
||||
.when_some(secondary_text, |this, secondary_text| {
|
||||
this.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
)
|
||||
.child(div().text_size(font_size).child(secondary_text.clone()))
|
||||
})
|
||||
.with_animation(
|
||||
"loading-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
move |this, delta| {
|
||||
if self.is_loading {
|
||||
this.opacity(delta)
|
||||
} else {
|
||||
this
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use crate::ui::ToolCallCardHeader;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
|
||||
pulsating_between,
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -14,6 +14,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use workspace::Workspace;
|
||||
use zed_llm_client::{WebSearchCitation, WebSearchResponse};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -47,7 +48,7 @@ impl Tool for WebSearchTool {
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Web Search".to_string()
|
||||
"Searching the Web".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
@@ -56,6 +57,7 @@ impl Tool for WebSearchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||
@@ -113,63 +115,33 @@ impl ToolCard for WebSearchToolCard {
|
||||
&mut self,
|
||||
_status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
_workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let header = h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.child(
|
||||
Icon::new(IconName::Globe)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(match self.response.as_ref() {
|
||||
Some(Ok(response)) => {
|
||||
let text: SharedString = if response.citations.len() == 1 {
|
||||
"1 result".into()
|
||||
} else {
|
||||
format!("{} results", response.citations.len()).into()
|
||||
};
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Label::new("Searched the Web").size(LabelSize::Small))
|
||||
.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
)
|
||||
.child(Label::new(text).size(LabelSize::Small))
|
||||
.into_any_element()
|
||||
}
|
||||
Some(Err(error)) => div()
|
||||
.id("web-search-error")
|
||||
.child(Label::new("Web Search failed").size(LabelSize::Small))
|
||||
.tooltip(Tooltip::text(error.to_string()))
|
||||
.into_any_element(),
|
||||
|
||||
None => Label::new("Searching the Web…")
|
||||
.size(LabelSize::Small)
|
||||
.with_animation(
|
||||
"web-search-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
|label, delta| label.alpha(delta),
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
.into_any();
|
||||
let header = match self.response.as_ref() {
|
||||
Some(Ok(response)) => {
|
||||
let text: SharedString = if response.citations.len() == 1 {
|
||||
"1 result".into()
|
||||
} else {
|
||||
format!("{} results", response.citations.len()).into()
|
||||
};
|
||||
ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
|
||||
.with_secondary_text(text)
|
||||
}
|
||||
Some(Err(error)) => {
|
||||
ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
|
||||
}
|
||||
None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
|
||||
};
|
||||
|
||||
let content =
|
||||
self.response.as_ref().and_then(|response| match response {
|
||||
Ok(response) => {
|
||||
Some(
|
||||
v_flex()
|
||||
.overflow_hidden()
|
||||
.ml_1p5()
|
||||
.pl_1p5()
|
||||
.pl(px(5.))
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.gap_1()
|
||||
@@ -209,7 +181,7 @@ impl ToolCard for WebSearchToolCard {
|
||||
Err(_) => None,
|
||||
});
|
||||
|
||||
v_flex().my_2().gap_1().child(header).children(content)
|
||||
v_flex().mb_3().gap_1().child(header).children(content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,8 +225,13 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(in_progress_search.update(cx, |tool, cx| {
|
||||
tool.render(&ToolUseStatus::Pending, window, cx)
|
||||
.into_any_element()
|
||||
tool.render(
|
||||
&ToolUseStatus::Pending,
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -263,8 +240,13 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(successful_search.update(cx, |tool, cx| {
|
||||
tool.render(&ToolUseStatus::Finished("".into()), window, cx)
|
||||
.into_any_element()
|
||||
tool.render(
|
||||
&ToolUseStatus::Finished("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -273,8 +255,13 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(error_search.update(cx, |tool, cx| {
|
||||
tool.render(&ToolUseStatus::Error("".into()), window, cx)
|
||||
.into_any_element()
|
||||
tool.render(
|
||||
&ToolUseStatus::Error("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
|
||||
@@ -118,6 +118,13 @@ impl Settings for AutoUpdateSetting {
|
||||
|
||||
Ok(Self(auto_update.0))
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
vscode.enum_setting("update.mode", current, |s| match s {
|
||||
"none" | "manual" => Some(AutoUpdateSettingContent(false)),
|
||||
_ => Some(AutoUpdateSettingContent(true)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
||||
@@ -32,4 +32,6 @@ impl Settings for CallSettings {
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
sources.json_merge()
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user