Compare commits
71 Commits
v0.199.8
...
agent-stil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbd27148d5 | ||
|
|
d6022dc87c | ||
|
|
0dd480d475 | ||
|
|
34fc2fd9d0 | ||
|
|
00701b5e99 | ||
|
|
cdfb3348ea | ||
|
|
35cd1b9ae1 | ||
|
|
bd402fdc7d | ||
|
|
c7d641ecb8 | ||
|
|
3d662ee282 | ||
|
|
7d4d8b8398 | ||
|
|
6912dc8399 | ||
|
|
952e3713d7 | ||
|
|
913e9adf90 | ||
|
|
50482a6bc2 | ||
|
|
d110459ef8 | ||
|
|
d693f02c63 | ||
|
|
90fa921756 | ||
|
|
11efa32fa7 | ||
|
|
e6dc6faccf | ||
|
|
070f7dbe1a | ||
|
|
106d4cfce9 | ||
|
|
a1080a0411 | ||
|
|
7679db99ac | ||
|
|
9ade399756 | ||
|
|
e8db429d24 | ||
|
|
53b69d29c5 | ||
|
|
e2e147ab0e | ||
|
|
fa2ff3ce1c | ||
|
|
c1d1d1cff6 | ||
|
|
efba2cbfd3 | ||
|
|
2234220618 | ||
|
|
dd840e4b27 | ||
|
|
262365ca24 | ||
|
|
90fa06dd61 | ||
|
|
740686b883 | ||
|
|
a5c25e0366 | ||
|
|
305c653c62 | ||
|
|
e227b5ac30 | ||
|
|
03876d076e | ||
|
|
4dbd24d75f | ||
|
|
c397027ec2 | ||
|
|
f5f837d39a | ||
|
|
5b1b3c51d4 | ||
|
|
b4a441f12f | ||
|
|
f1e69f6311 | ||
|
|
bd1c26cb5b | ||
|
|
1907b16fe6 | ||
|
|
c595a7576d | ||
|
|
8e290b446e | ||
|
|
58392b9c13 | ||
|
|
9358690337 | ||
|
|
3ea90e397b | ||
|
|
a5dd8d0052 | ||
|
|
250c51bb20 | ||
|
|
010441e23b | ||
|
|
f9038f6189 | ||
|
|
a80da784b7 | ||
|
|
fb1f9d1212 | ||
|
|
794098e5c9 | ||
|
|
b08e26df60 | ||
|
|
740597492b | ||
|
|
ebda6b8a94 | ||
|
|
55b4df4d9f | ||
|
|
b8e8fbd8e6 | ||
|
|
33f198fef1 | ||
|
|
3c602fecbf | ||
|
|
334bdd0efc | ||
|
|
69dc870828 | ||
|
|
22fa41e9c0 | ||
|
|
7e790f52c8 |
3
.github/actionlint.yml
vendored
@@ -24,9 +24,6 @@ self-hosted-runner:
|
||||
- namespace-profile-8x16-ubuntu-2204
|
||||
- namespace-profile-16x32-ubuntu-2204
|
||||
- namespace-profile-32x64-ubuntu-2204
|
||||
# Namespace Limited Preview
|
||||
- namespace-profile-8x16-ubuntu-2004-arm-m4
|
||||
- namespace-profile-8x32-ubuntu-2004-arm-m4
|
||||
# Self Hosted Runners
|
||||
- self-mini-macos
|
||||
- self-32vcpu-windows-2022
|
||||
|
||||
16
.github/workflows/ci.yml
vendored
@@ -511,8 +511,8 @@ jobs:
|
||||
runs-on:
|
||||
- self-mini-macos
|
||||
if: |
|
||||
( startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
needs: [macos_tests]
|
||||
env:
|
||||
MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }}
|
||||
@@ -599,8 +599,8 @@ jobs:
|
||||
runs-on:
|
||||
- namespace-profile-16x32-ubuntu-2004 # ubuntu 20.04 for minimal glibc
|
||||
if: |
|
||||
( startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
needs: [linux_tests]
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
@@ -650,7 +650,7 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
name: Linux arm64 release bundle
|
||||
runs-on:
|
||||
- namespace-profile-8x32-ubuntu-2004-arm-m4 # ubuntu 20.04 for minimal glibc
|
||||
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
|
||||
if: |
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
@@ -703,8 +703,10 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
runs-on: github-8vcpu-ubuntu-2404
|
||||
if: |
|
||||
false && ( startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
|
||||
false && (
|
||||
startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
)
|
||||
needs: [linux_tests]
|
||||
name: Build Zed on FreeBSD
|
||||
steps:
|
||||
|
||||
2
.github/workflows/deploy_collab.yml
vendored
@@ -137,12 +137,14 @@ jobs:
|
||||
|
||||
export ZED_SERVICE_NAME=collab
|
||||
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT
|
||||
export DATABASE_MAX_CONNECTIONS=850
|
||||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
||||
export ZED_SERVICE_NAME=api
|
||||
export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_API_LOAD_BALANCER_SIZE_UNIT
|
||||
export DATABASE_MAX_CONNECTIONS=60
|
||||
envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f -
|
||||
kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch
|
||||
echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}"
|
||||
|
||||
2
.github/workflows/release_nightly.yml
vendored
@@ -168,7 +168,7 @@ jobs:
|
||||
name: Create a Linux *.tar.gz bundle for ARM
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on:
|
||||
- namespace-profile-8x32-ubuntu-2004-arm-m4 # ubuntu 20.04 for minimal glibc
|
||||
- namespace-profile-32x64-ubuntu-2004-arm # ubuntu 20.04 for minimal glibc
|
||||
needs: tests
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
|
||||
23
Cargo.lock
generated
@@ -138,9 +138,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.20"
|
||||
version = "0.0.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12dbfec3d27680337ed9d3064eecafe97acf0b0f190148bb4e29d96707c9e403"
|
||||
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
@@ -159,6 +159,7 @@ dependencies = [
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
@@ -171,10 +172,14 @@ dependencies = [
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
@@ -185,6 +190,7 @@ dependencies = [
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
]
|
||||
@@ -7502,9 +7508,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "grid"
|
||||
version = "0.17.0"
|
||||
version = "0.18.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71b01d27060ad58be4663b9e4ac9e2d4806918e8876af8912afbddd1a91d5eaa"
|
||||
checksum = "12101ecc8225ea6d675bc70263074eab6169079621c2186fe0c66590b2df9681"
|
||||
|
||||
[[package]]
|
||||
name = "group"
|
||||
@@ -11184,7 +11190,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -12634,6 +12639,7 @@ dependencies = [
|
||||
"editor",
|
||||
"file_icons",
|
||||
"git",
|
||||
"git_ui",
|
||||
"gpui",
|
||||
"indexmap",
|
||||
"language",
|
||||
@@ -16217,9 +16223,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "taffy"
|
||||
version = "0.8.3"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7aaef0ac998e6527d6d0d5582f7e43953bb17221ac75bb8eb2fcc2db3396db1c"
|
||||
checksum = "a13e5d13f79d558b5d353a98072ca8ca0e99da429467804de959aa8c83c9a004"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"grid",
|
||||
@@ -20461,7 +20467,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.199.8"
|
||||
version = "0.200.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -20864,7 +20870,6 @@ dependencies = [
|
||||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
|
||||
@@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.20"
|
||||
agent-client-protocol = { version = "0.0.23" }
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
|
||||
1
Procfile
@@ -1,3 +1,4 @@
|
||||
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
|
||||
cloud: cd ../cloud; cargo make dev
|
||||
livekit: livekit-server --dev
|
||||
blob_store: ./script/run-local-minio
|
||||
|
||||
1
assets/icons/file_icons/puppet.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="17" height="16" fill="none"><path fill="#000" d="M12.5 9.639V6.354H9.683L8.18 5.007V2.5H4.5v3.285h2.817l1.51 1.347v1.736l-1.51 1.347H4.5V13.5h3.681v-2.507l1.51-1.347H12.5v-.007ZM5.727 3.595h1.227V4.69H5.727V3.595Zm1.227 8.803H5.727v-1.095h1.227v1.095Z"/></svg>
|
||||
|
After Width: | Height: | Size: 308 B |
|
Before Width: | Height: | Size: 776 B After Width: | Height: | Size: 776 B |
|
Before Width: | Height: | Size: 6.4 KiB |
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 5.3 KiB |
1
assets/images/pro_user_stamp.svg
Normal file
|
After Width: | Height: | Size: 5.5 KiB |
@@ -332,7 +332,9 @@
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff"
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -846,6 +848,7 @@
|
||||
"ctrl-delete": ["project_panel::Delete", { "skip_prompt": false }],
|
||||
"alt-ctrl-r": "project_panel::RevealInFileManager",
|
||||
"ctrl-shift-enter": "project_panel::OpenWithSystem",
|
||||
"alt-d": "project_panel::CompareMarkedFiles",
|
||||
"shift-find": "project_panel::NewSearchInDirectory",
|
||||
"ctrl-alt-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"shift-down": "menu::SelectNext",
|
||||
@@ -1100,6 +1103,13 @@
|
||||
"ctrl-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "OnboardingAiConfigurationModal",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
@@ -1176,7 +1186,8 @@
|
||||
"ctrl-1": "onboarding::ActivateBasicsPage",
|
||||
"ctrl-2": "onboarding::ActivateEditingPage",
|
||||
"ctrl-3": "onboarding::ActivateAISetupPage",
|
||||
"ctrl-escape": "onboarding::Finish"
|
||||
"ctrl-escape": "onboarding::Finish",
|
||||
"alt-tab": "onboarding::SignIn"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -384,7 +384,9 @@
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff"
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -905,6 +907,7 @@
|
||||
"cmd-delete": ["project_panel::Delete", { "skip_prompt": false }],
|
||||
"alt-cmd-r": "project_panel::RevealInFileManager",
|
||||
"ctrl-shift-enter": "project_panel::OpenWithSystem",
|
||||
"alt-d": "project_panel::CompareMarkedFiles",
|
||||
"cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }],
|
||||
"cmd-alt-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"shift-down": "menu::SelectNext",
|
||||
@@ -1202,6 +1205,13 @@
|
||||
"cmd-enter": "menu::Confirm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "OnboardingAiConfigurationModal",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"escape": "menu::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Diagnostics",
|
||||
"use_key_equivalents": true,
|
||||
@@ -1278,7 +1288,8 @@
|
||||
"cmd-1": "onboarding::ActivateBasicsPage",
|
||||
"cmd-2": "onboarding::ActivateEditingPage",
|
||||
"cmd-3": "onboarding::ActivateAISetupPage",
|
||||
"cmd-escape": "onboarding::Finish"
|
||||
"cmd-escape": "onboarding::Finish",
|
||||
"alt-tab": "onboarding::SignIn"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -813,6 +813,7 @@
|
||||
"p": "project_panel::Open",
|
||||
"x": "project_panel::RevealInFileManager",
|
||||
"s": "project_panel::OpenWithSystem",
|
||||
"z d": "project_panel::CompareMarkedFiles",
|
||||
"] c": "project_panel::SelectNextGitEntry",
|
||||
"[ c": "project_panel::SelectPrevGitEntry",
|
||||
"] d": "project_panel::SelectNextDiagnostic",
|
||||
|
||||
@@ -1171,6 +1171,9 @@
|
||||
// Sets a delay after which the inline blame information is shown.
|
||||
// Delay is restarted with every cursor movement.
|
||||
"delay_ms": 0,
|
||||
// The amount of padding between the end of the source line and the start
|
||||
// of the inline blame in units of em widths.
|
||||
"padding": 7,
|
||||
// Whether or not to display the git commit summary on the same line.
|
||||
"show_commit_summary": false,
|
||||
// The minimum column number to show the inline blame information at
|
||||
|
||||
@@ -166,6 +166,7 @@ pub struct ToolCall {
|
||||
pub status: ToolCallStatus,
|
||||
pub locations: Vec<acp::ToolCallLocation>,
|
||||
pub raw_input: Option<serde_json::Value>,
|
||||
pub raw_output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl ToolCall {
|
||||
@@ -194,6 +195,7 @@ impl ToolCall {
|
||||
locations: tool_call.locations,
|
||||
status,
|
||||
raw_input: tool_call.raw_input,
|
||||
raw_output: tool_call.raw_output,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,6 +212,7 @@ impl ToolCall {
|
||||
content,
|
||||
locations,
|
||||
raw_input,
|
||||
raw_output,
|
||||
} = fields;
|
||||
|
||||
if let Some(kind) = kind {
|
||||
@@ -221,7 +224,9 @@ impl ToolCall {
|
||||
}
|
||||
|
||||
if let Some(title) = title {
|
||||
self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
|
||||
self.label.update(cx, |label, cx| {
|
||||
label.replace(title, cx);
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(content) = content {
|
||||
@@ -238,6 +243,10 @@ impl ToolCall {
|
||||
if let Some(raw_input) = raw_input {
|
||||
self.raw_input = Some(raw_input);
|
||||
}
|
||||
|
||||
if let Some(raw_output) = raw_output {
|
||||
self.raw_output = Some(raw_output);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
|
||||
@@ -411,8 +420,6 @@ impl ToolCallContent {
|
||||
pub struct Diff {
|
||||
pub multibuffer: Entity<MultiBuffer>,
|
||||
pub path: PathBuf,
|
||||
pub new_buffer: Entity<Buffer>,
|
||||
pub old_buffer: Entity<Buffer>,
|
||||
_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
@@ -433,23 +440,34 @@ impl Diff {
|
||||
let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
|
||||
let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
|
||||
let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
|
||||
let old_buffer_snapshot = old_buffer.read(cx).snapshot();
|
||||
let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
|
||||
let diff_task = buffer_diff.update(cx, |diff, cx| {
|
||||
diff.set_base_text(
|
||||
old_buffer_snapshot,
|
||||
Some(language_registry.clone()),
|
||||
new_buffer_snapshot,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let task = cx.spawn({
|
||||
let multibuffer = multibuffer.clone();
|
||||
let path = path.clone();
|
||||
let new_buffer = new_buffer.clone();
|
||||
async move |cx| {
|
||||
diff_task.await?;
|
||||
let language = language_registry
|
||||
.language_for_file_path(&path)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
|
||||
|
||||
let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_language(language, cx);
|
||||
buffer.snapshot()
|
||||
})?;
|
||||
|
||||
buffer_diff
|
||||
.update(cx, |diff, cx| {
|
||||
diff.set_base_text(
|
||||
old_buffer_snapshot,
|
||||
Some(language_registry),
|
||||
new_buffer_snapshot,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
multibuffer
|
||||
.update(cx, |multibuffer, cx| {
|
||||
@@ -468,18 +486,10 @@ impl Diff {
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
multibuffer.add_diff(buffer_diff.clone(), cx);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(language) = language_registry
|
||||
.language_for_file_path(&path)
|
||||
.await
|
||||
.log_err()
|
||||
{
|
||||
new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
@@ -487,8 +497,6 @@ impl Diff {
|
||||
Self {
|
||||
multibuffer,
|
||||
path,
|
||||
new_buffer,
|
||||
old_buffer,
|
||||
_task: task,
|
||||
}
|
||||
}
|
||||
@@ -557,7 +565,7 @@ pub struct PlanEntry {
|
||||
impl PlanEntry {
|
||||
pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
|
||||
Self {
|
||||
content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
|
||||
content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
|
||||
priority: entry.priority,
|
||||
status: entry.status,
|
||||
}
|
||||
@@ -661,7 +669,7 @@ impl AcpThread {
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ThreadStatus {
|
||||
if self.send_task.is_some() {
|
||||
if self.send_task.as_ref().map_or(false, |t| !t.is_finished()) {
|
||||
if self.waiting_for_tool_confirmation() {
|
||||
ThreadStatus::WaitingForToolConfirmation
|
||||
} else {
|
||||
@@ -896,7 +904,7 @@ impl AcpThread {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn request_tool_call_permission(
|
||||
pub fn request_tool_call_authorization(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
options: Vec<acp::PermissionOption>,
|
||||
@@ -971,13 +979,26 @@ impl AcpThread {
|
||||
}
|
||||
|
||||
pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
|
||||
self.plan = Plan {
|
||||
entries: request
|
||||
.entries
|
||||
.into_iter()
|
||||
.map(|entry| PlanEntry::from_acp(entry, cx))
|
||||
.collect(),
|
||||
};
|
||||
let new_entries_len = request.entries.len();
|
||||
let mut new_entries = request.entries.into_iter();
|
||||
|
||||
// Reuse existing markdown to prevent flickering
|
||||
for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
|
||||
let PlanEntry {
|
||||
content,
|
||||
priority,
|
||||
status,
|
||||
} = old;
|
||||
content.update(cx, |old, cx| {
|
||||
old.replace(new.content, cx);
|
||||
});
|
||||
*priority = new.priority;
|
||||
*status = new.status;
|
||||
}
|
||||
for new in new_entries {
|
||||
self.plan.entries.push(PlanEntry::from_acp(new, cx))
|
||||
}
|
||||
self.plan.entries.truncate(new_entries_len);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
@@ -1038,8 +1059,8 @@ impl AcpThread {
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
tx.send(result).log_err();
|
||||
this.update(cx, |this, _cx| this.send_task.take())?;
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await
|
||||
@@ -1525,6 +1546,7 @@ mod tests {
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
@@ -1637,6 +1659,7 @@ mod tests {
|
||||
}],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -23,10 +24,13 @@ futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -36,7 +40,7 @@ smol.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
worktree.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -47,8 +51,10 @@ env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language = { workspace = true, "features" = ["test-support"] }
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
|
||||
@@ -1,16 +1,39 @@
|
||||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||
use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization};
|
||||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use futures::{future, StreamExt};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use project::Project;
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{templates::Templates, AgentResponseEvent, Thread};
|
||||
const RULES_FILE_NAMES: [&'static str; 9] = [
|
||||
".rules",
|
||||
".cursorrules",
|
||||
".windsurfrules",
|
||||
".clinerules",
|
||||
".github/copilot-instructions.md",
|
||||
"CLAUDE.md",
|
||||
"AGENT.md",
|
||||
"AGENTS.md",
|
||||
"GEMINI.md",
|
||||
];
|
||||
|
||||
pub struct RulesLoadingError {
|
||||
pub message: SharedString,
|
||||
}
|
||||
|
||||
/// Holds both the internal Thread and the AcpThread for a session
|
||||
struct Session {
|
||||
@@ -24,17 +47,247 @@ struct Session {
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
/// Shared project context for all threads
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
_maintain_project_context: Task<Result<()>>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
project: Entity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
impl NativeAgent {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
pub async fn new(
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<NativeAgent>> {
|
||||
log::info!("Creating new NativeAgent");
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
templates,
|
||||
|
||||
let project_context = cx
|
||||
.update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
||||
}
|
||||
|
||||
let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
|
||||
watch::channel(());
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
project_context: Rc::new(RefCell::new(project_context)),
|
||||
project_context_needs_refresh: project_context_needs_refresh_tx,
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||
}),
|
||||
templates,
|
||||
project,
|
||||
prompt_store,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn maintain_project_context(
|
||||
this: WeakEntity<Self>,
|
||||
mut needs_refresh: watch::Receiver<()>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
while needs_refresh.changed().await.is_ok() {
|
||||
let project_context = this
|
||||
.update(cx, |this, cx| {
|
||||
Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
|
||||
})?
|
||||
.await;
|
||||
this.update(cx, |this, _| this.project_context.replace(project_context))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_project_context(
|
||||
project: &Entity<Project>,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
cx: &mut App,
|
||||
) -> Task<ProjectContext> {
|
||||
let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
|
||||
let worktree_tasks = worktrees
|
||||
.into_iter()
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
prompt_store.read_with(cx, |prompt_store, cx| {
|
||||
let prompts = prompt_store.default_prompt_metadata();
|
||||
let load_tasks = prompts.into_iter().map(|prompt_metadata| {
|
||||
let contents = prompt_store.load(prompt_metadata.id, cx);
|
||||
async move { (contents.await, prompt_metadata) }
|
||||
});
|
||||
cx.background_spawn(future::join_all(load_tasks))
|
||||
})
|
||||
} else {
|
||||
Task::ready(vec![])
|
||||
};
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let (worktrees, default_user_rules) =
|
||||
future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
|
||||
|
||||
let worktrees = worktrees
|
||||
.into_iter()
|
||||
.map(|(worktree, _rules_error)| {
|
||||
// TODO: show error message
|
||||
// if let Some(rules_error) = rules_error {
|
||||
// this.update(cx, |_, cx| cx.emit(rules_error)).ok();
|
||||
// }
|
||||
worktree
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let default_user_rules = default_user_rules
|
||||
.into_iter()
|
||||
.flat_map(|(contents, prompt_metadata)| match contents {
|
||||
Ok(contents) => Some(UserRulesContext {
|
||||
uuid: match prompt_metadata.id {
|
||||
PromptId::User { uuid } => uuid,
|
||||
PromptId::EditWorkflow => return None,
|
||||
},
|
||||
title: prompt_metadata.title.map(|title| title.to_string()),
|
||||
contents,
|
||||
}),
|
||||
Err(_err) => {
|
||||
// TODO: show error message
|
||||
// this.update(cx, |_, cx| {
|
||||
// cx.emit(RulesLoadingError {
|
||||
// message: format!("{err:?}").into(),
|
||||
// });
|
||||
// })
|
||||
// .ok();
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ProjectContext::new(worktrees, default_user_rules)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
||||
let tree = worktree.read(cx);
|
||||
let root_name = tree.root_name().into();
|
||||
let abs_path = tree.abs_path();
|
||||
|
||||
let mut context = WorktreeContext {
|
||||
root_name,
|
||||
abs_path,
|
||||
rules_file: None,
|
||||
};
|
||||
|
||||
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((context, None));
|
||||
};
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
let (rules_file, rules_file_error) = match rules_task.await {
|
||||
Ok(rules_file) => (Some(rules_file), None),
|
||||
Err(err) => (
|
||||
None,
|
||||
Some(RulesLoadingError {
|
||||
message: format!("{err}").into(),
|
||||
}),
|
||||
),
|
||||
};
|
||||
context.rules_file = rules_file;
|
||||
(context, rules_file_error)
|
||||
})
|
||||
}
|
||||
|
||||
fn load_worktree_rules_file(
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<Task<Result<RulesFileContext>>> {
|
||||
let worktree = worktree.read(cx);
|
||||
let worktree_id = worktree.id();
|
||||
let selected_rules_file = RULES_FILE_NAMES
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
worktree
|
||||
.entry_for_path(name)
|
||||
.filter(|entry| entry.is_file())
|
||||
.map(|entry| entry.path.clone())
|
||||
})
|
||||
.next();
|
||||
|
||||
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
||||
// supported. This doesn't seem to occur often in GitHub repositories.
|
||||
selected_rules_file.map(|path_in_worktree| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: path_in_worktree.clone(),
|
||||
};
|
||||
let buffer_task =
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
let rope_task = cx.spawn(async move |cx| {
|
||||
buffer_task.await?.read_with(cx, |buffer, cx| {
|
||||
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
|
||||
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
|
||||
})?
|
||||
});
|
||||
// Build a string from the rope on a background thread.
|
||||
cx.background_spawn(async move {
|
||||
let (project_entry_id, rope) = rope_task.await?;
|
||||
anyhow::Ok(RulesFileContext {
|
||||
path_in_worktree,
|
||||
text: rope.to_string().trim().to_string(),
|
||||
project_entry_id: project_entry_id.to_usize(),
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
event: &project::Event,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
project::Event::WorktreeUpdatedEntries(_, items) => {
|
||||
if items.iter().any(|(path, _, _)| {
|
||||
RULES_FILE_NAMES
|
||||
.iter()
|
||||
.any(|name| path.as_ref() == Path::new(name))
|
||||
}) {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_prompts_updated_event(
|
||||
&mut self,
|
||||
_prompt_store: Entity<PromptStore>,
|
||||
_event: &prompt_store::PromptsUpdatedEvent,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,8 +373,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Starting thread creation in async context");
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
|
||||
})
|
||||
})?;
|
||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||
|
||||
// Create Thread
|
||||
let (session_id, thread) = agent.update(
|
||||
let thread = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
@@ -146,22 +412,18 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
|
||||
let thread = cx.new(|_| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread
|
||||
});
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
Ok((session_id, thread))
|
||||
Ok(thread)
|
||||
},
|
||||
)??;
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, cx| {
|
||||
agent.sessions.insert(
|
||||
@@ -264,6 +526,28 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let recv = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
{
|
||||
response
|
||||
.send(option)
|
||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
@@ -343,3 +627,77 @@ fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/",
|
||||
json!({
|
||||
"a": {}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(agent.project_context.borrow().worktrees, vec![])
|
||||
});
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| project.create_worktree("/a", true, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(
|
||||
agent.project_context.borrow().worktrees,
|
||||
vec![WorktreeContext {
|
||||
root_name: "a".into(),
|
||||
abs_path: Path::new("/a").into(),
|
||||
rules_file: None
|
||||
}]
|
||||
)
|
||||
});
|
||||
|
||||
// Creating `/a/.rules` updates the project context.
|
||||
fs.insert_file("/a/.rules", Vec::new()).await;
|
||||
cx.run_until_parked();
|
||||
agent.read_with(cx, |agent, cx| {
|
||||
let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
|
||||
assert_eq!(
|
||||
agent.project_context.borrow().worktrees,
|
||||
vec![WorktreeContext {
|
||||
root_name: "a".into(),
|
||||
abs_path: Path::new("/a").into(),
|
||||
rules_file: Some(RulesFileContext {
|
||||
path_in_worktree: Path::new(".rules").into(),
|
||||
text: "".into(),
|
||||
project_entry_id: rules_entry.id.to_usize()
|
||||
})
|
||||
}]
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
mod agent;
|
||||
mod native_agent_server;
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
@@ -11,3 +10,4 @@ mod tests;
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
||||
@@ -3,8 +3,9 @@ use std::rc::Rc;
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
|
||||
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
|
||||
|
||||
@@ -32,21 +33,22 @@ impl AgentServer for NativeAgentServer {
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
||||
log::info!(
|
||||
"NativeAgentServer::connect called for path: {:?}",
|
||||
_root_dir
|
||||
);
|
||||
let project = project.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
// Create templates (you might want to load these from files or resources)
|
||||
let templates = Templates::new();
|
||||
let prompt_store = prompt_store.await?;
|
||||
|
||||
// Create the native agent
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
pub struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl BasePrompt {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::SharedString;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
@@ -15,6 +15,8 @@ pub struct Templates(Handlebars<'static>);
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.set_strict_mode(true);
|
||||
handlebars.register_helper("contains", Box::new(contains));
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
@@ -32,26 +34,54 @@ pub trait Template: Sized {
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
pub struct SystemPromptTemplate<'a> {
|
||||
#[serde(flatten)]
|
||||
pub project: &'a prompt_store::ProjectContext,
|
||||
pub available_tools: Vec<SharedString>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
impl Template for SystemPromptTemplate<'_> {
|
||||
const TEMPLATE_NAME: &'static str = "system_prompt.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
/// Handlebars helper for checking if an item is in a list
|
||||
fn contains(
|
||||
h: &handlebars::Helper,
|
||||
_: &handlebars::Handlebars,
|
||||
_: &handlebars::Context,
|
||||
_: &mut handlebars::RenderContext,
|
||||
out: &mut dyn handlebars::Output,
|
||||
) -> handlebars::HelperResult {
|
||||
let list = h
|
||||
.param(0)
|
||||
.and_then(|v| v.value().as_array())
|
||||
.ok_or_else(|| {
|
||||
handlebars::RenderError::new("contains: missing or invalid list parameter")
|
||||
})?;
|
||||
let query = h.param(1).map(|v| v.value()).ok_or_else(|| {
|
||||
handlebars::RenderError::new("contains: missing or invalid query parameter")
|
||||
})?;
|
||||
|
||||
if list.contains(&query) {
|
||||
out.write("true")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
#[test]
|
||||
fn test_system_prompt_template() {
|
||||
let project = prompt_store::ProjectContext::default();
|
||||
let template = SystemPromptTemplate {
|
||||
project: &project,
|
||||
available_tools: vec!["echo".into()],
|
||||
};
|
||||
let templates = Templates::new();
|
||||
let rendered = template.render(&templates).unwrap();
|
||||
assert!(rendered.contains("## Fixing Diagnostics"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
@@ -1,8 +0,0 @@
|
||||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
||||
178
crates/agent2/src/templates/system_prompt.hbs
Normal file
@@ -0,0 +1,178 @@
|
||||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the user in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
{{#if (gt (len available_tools) 0)}}
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
|
||||
7. Avoid HTML entity escaping - use plain characters instead.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{abs_path}}`
|
||||
{{/each}}
|
||||
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
- When providing paths to tools, the path should always start with the name of a project root directory listed above.
|
||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||
{{# if (contains available_tools 'grep') }}
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
{{/if}}
|
||||
{{else}}
|
||||
You are being tasked with providing a response, but you have no ability to use tools or to read or write any aspect of the user's system (other than any context the user might have provided to you).
|
||||
|
||||
As such, if you need the user to perform any actions for you, you must request them explicitly. Bias towards giving a response to the best of your ability, and then making requests for the user to take action (e.g. to give you more context) only optionally.
|
||||
|
||||
The one exception to this is if the user references something you don't know about - for example, the name of a source code file, function, type, or other piece of code that you have no awareness of. In this case, you MUST NOT MAKE SOMETHING UP, or assume you know what that thing is or how it works. Instead, you must ask the user for clarification rather than giving a response.
|
||||
{{/if}}
|
||||
|
||||
## Code Block Formatting
|
||||
|
||||
Whenever you mention a code block, you MUST use ONLY use the following format:
|
||||
```path/to/Something.blah#L123-456
|
||||
(code goes here)
|
||||
```
|
||||
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
|
||||
is a path in the project. (If there is no valid path in the project, then you can use
|
||||
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
|
||||
does not understand the more common ```language syntax, or bare ``` blocks. It only
|
||||
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
|
||||
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
|
||||
You have made a mistake. You can only ever put paths after triple backticks!
|
||||
<example>
|
||||
Based on all the information I've gathered, here's a summary of how this system works:
|
||||
1. The README file is loaded into the system.
|
||||
2. The system finds the first two headers, including everything in between. In this case, that would be:
|
||||
```path/to/README.md#L8-12
|
||||
# First Header
|
||||
This is the info under the first header.
|
||||
## Sub-header
|
||||
```
|
||||
3. Then the system finds the last header in the README:
|
||||
```path/to/README.md#L27-29
|
||||
## Last Header
|
||||
This is the last header in the README.
|
||||
```
|
||||
4. Finally, it passes this information on to the next process.
|
||||
</example>
|
||||
<example>
|
||||
In Markdown, hash marks signify headings. For example:
|
||||
```/dev/null/example.md#L1-3
|
||||
# Level 1 heading
|
||||
## Level 2 heading
|
||||
### Level 3 heading
|
||||
```
|
||||
</example>
|
||||
Here are examples of ways you must never render code blocks:
|
||||
<bad_example_do_not_do_this>
|
||||
In Markdown, hash marks signify headings. For example:
|
||||
```
|
||||
# Level 1 heading
|
||||
## Level 2 heading
|
||||
### Level 3 heading
|
||||
```
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because it does not include the path.
|
||||
<bad_example_do_not_do_this>
|
||||
In Markdown, hash marks signify headings. For example:
|
||||
```markdown
|
||||
# Level 1 heading
|
||||
## Level 2 heading
|
||||
### Level 3 heading
|
||||
```
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because it has the language instead of the path.
|
||||
<bad_example_do_not_do_this>
|
||||
In Markdown, hash marks signify headings. For example:
|
||||
# Level 1 heading
|
||||
## Level 2 heading
|
||||
### Level 3 heading
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because it uses indentation to mark the code block
|
||||
instead of backticks with a path.
|
||||
<bad_example_do_not_do_this>
|
||||
In Markdown, hash marks signify headings. For example:
|
||||
```markdown
|
||||
/dev/null/example.md#L1-3
|
||||
# Level 1 heading
|
||||
## Level 2 heading
|
||||
### Level 3 heading
|
||||
```
|
||||
</bad_example_do_not_do_this>
|
||||
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
|
||||
|
||||
{{#if (gt (len available_tools) 0)}}
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
{{/if}}
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file(s). If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
|
||||
{{#if (or has_rules has_user_rules)}}
|
||||
## User's Custom Instructions
|
||||
|
||||
The following additional instructions are provided by the user, and should be followed to the best of your ability{{#if (gt (len available_tools) 0)}} without interfering with the tool use guidelines{{/if}}.
|
||||
|
||||
{{#if has_rules}}
|
||||
There are project rules that apply to these root directories:
|
||||
{{#each worktrees}}
|
||||
{{#if rules_file}}
|
||||
`{{root_name}}/{{rules_file.path_in_worktree}}`:
|
||||
``````
|
||||
{{{rules_file.text}}}
|
||||
``````
|
||||
{{/if}}
|
||||
{{/each}}
|
||||
{{/if}}
|
||||
|
||||
{{#if has_user_rules}}
|
||||
The user has specified the following rules that should be applied:
|
||||
{{#each user_rules}}
|
||||
|
||||
{{#if title}}
|
||||
Rules title: {{title}}
|
||||
{{/if}}
|
||||
``````
|
||||
{{contents}}}
|
||||
``````
|
||||
{{/each}}
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
@@ -1,30 +1,34 @@
|
||||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ActionLog;
|
||||
use client::{Client, UserStore};
|
||||
use fs::FakeFs;
|
||||
use futures::channel::mpsc::UnboundedReceiver;
|
||||
use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, MessageContent,
|
||||
StopReason,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
use util::path;
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
@@ -44,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_thinking(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
|
||||
|
||||
@@ -77,7 +81,46 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(
|
||||
pending_completions.len(),
|
||||
1,
|
||||
"unexpected pending completions: {:?}",
|
||||
pending_completions
|
||||
);
|
||||
|
||||
let pending_completion = pending_completions.pop().unwrap();
|
||||
assert_eq!(pending_completion.messages[0].role, Role::System);
|
||||
|
||||
let system_message = &pending_completion.messages[0];
|
||||
let system_prompt = system_message.content[0].to_str().unwrap();
|
||||
assert!(
|
||||
system_prompt.contains("test-shell"),
|
||||
"unexpected system message: {:?}",
|
||||
system_message
|
||||
);
|
||||
assert!(
|
||||
system_prompt.contains("## Fixing Diagnostics"),
|
||||
"unexpected system message: {:?}",
|
||||
system_message
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
@@ -127,7 +170,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
@@ -175,7 +218,161 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.send(model.clone(), "abc", cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_id_1".into(),
|
||||
name: ToolRequiringPermission.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_id_2".into(),
|
||||
name: ToolRequiringPermission.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
|
||||
let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
|
||||
|
||||
// Approve the first
|
||||
tool_call_auth_1
|
||||
.response
|
||||
.send(tool_call_auth_1.options[1].id.clone())
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Reject the second
|
||||
tool_call_auth_2
|
||||
.response
|
||||
.send(tool_call_auth_1.options[2].id.clone())
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: None
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: true,
|
||||
content: "Permission to run tool denied by user".into(),
|
||||
output: None
|
||||
})
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "tool_id_1".into(),
|
||||
name: "nonexistent_tool".into(),
|
||||
raw_input: "{}".into(),
|
||||
input: json!({}),
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
let tool_call = expect_tool_call(&mut events).await;
|
||||
assert_eq!(tool_call.title, "nonexistent_tool");
|
||||
assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||
}
|
||||
|
||||
async fn expect_tool_call(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> acp::ToolCall {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
match event {
|
||||
AgentResponseEvent::ToolCall(tool_call) => return tool_call,
|
||||
event => {
|
||||
panic!("Unexpected event {event:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn expect_tool_call_update(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> acp::ToolCallUpdate {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
match event {
|
||||
AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update,
|
||||
event => {
|
||||
panic!("Unexpected event {event:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn next_tool_call_authorization(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> ToolCallAuthorization {
|
||||
loop {
|
||||
let event = events
|
||||
.next()
|
||||
.await
|
||||
.expect("no tool call authorization event received")
|
||||
.unwrap();
|
||||
if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
|
||||
let permission_kinds = tool_call_authorization
|
||||
.options
|
||||
.iter()
|
||||
.map(|o| o.kind)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
permission_kinds,
|
||||
vec![
|
||||
acp::PermissionOptionKind::AllowAlways,
|
||||
acp::PermissionOptionKind::AllowOnce,
|
||||
acp::PermissionOptionKind::RejectOnce,
|
||||
]
|
||||
);
|
||||
return tool_call_authorization;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
@@ -214,7 +411,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "temporarily disabled until it can be run on CI"]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
|
||||
|
||||
@@ -281,12 +478,10 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_refusal(cx: &mut TestAppContext) {
|
||||
let fake_model = Arc::new(FakeLanguageModel::default());
|
||||
let ThreadTest { thread, .. } = setup(cx, TestModel::Fake(fake_model.clone())).await;
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(fake_model.clone(), "Hello", cx)
|
||||
});
|
||||
let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
@@ -343,8 +538,16 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
||||
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
@@ -366,12 +569,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
@@ -441,6 +639,77 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
|
||||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
|
||||
cx.run_until_parked();
|
||||
|
||||
let input = json!({ "content": "Thinking hard!" });
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: "1".into(),
|
||||
name: ThinkingTool.name().into(),
|
||||
raw_input: input.to_string(),
|
||||
input,
|
||||
is_input_complete: true,
|
||||
},
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let tool_call = expect_tool_call(&mut events).await;
|
||||
assert_eq!(
|
||||
tool_call,
|
||||
acp::ToolCall {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
title: "Thinking".into(),
|
||||
kind: acp::ToolKind::Think,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(json!({ "content": "Thinking hard!" })),
|
||||
raw_output: None,
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress,),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
content: Some(vec!["Thinking hard!".into()]),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
let update = expect_tool_call_update(&mut events).await;
|
||||
assert_eq!(
|
||||
update,
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
@@ -457,12 +726,13 @@ fn stop_events(
|
||||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
Sonnet4,
|
||||
Sonnet4Thinking,
|
||||
Fake(Arc<FakeLanguageModel>),
|
||||
Fake,
|
||||
}
|
||||
|
||||
impl TestModel {
|
||||
@@ -470,7 +740,7 @@ impl TestModel {
|
||||
match self {
|
||||
TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
|
||||
TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
|
||||
TestModel::Fake(fake_model) => fake_model.id(),
|
||||
TestModel::Fake => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -499,8 +769,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake(model) = model {
|
||||
Task::ready(model as Arc<_>)
|
||||
if let TestModel::Fake = model {
|
||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||
} else {
|
||||
let model_id = model.id();
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
@@ -520,9 +790,22 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(project, templates, model.clone()));
|
||||
|
||||
ThreadTest { model, thread }
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_| {
|
||||
Thread::new(
|
||||
project,
|
||||
project_context.clone(),
|
||||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
)
|
||||
});
|
||||
ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -19,7 +19,20 @@ impl AgentTool for EchoTool {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _: Self::Input) -> SharedString {
|
||||
"Echo".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
@@ -40,7 +53,20 @@ impl AgentTool for DelayTool {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
format!("Delay {}ms", input.ms).into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
@@ -51,6 +77,43 @@ impl AgentTool for DelayTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct ToolRequiringPermissionInput {}
|
||||
|
||||
pub struct ToolRequiringPermission;
|
||||
|
||||
impl AgentTool for ToolRequiringPermission {
|
||||
type Input = ToolRequiringPermissionInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"tool_requiring_permission".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"This tool requires permission".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let auth_check = self.authorize(input, event_stream);
|
||||
cx.foreground_executor().spawn(async move {
|
||||
auth_check.await?;
|
||||
Ok("Allowed".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct InfiniteToolInput {}
|
||||
|
||||
@@ -63,7 +126,20 @@ impl AgentTool for InfiniteTool {
|
||||
"infinite".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"This is the tool that never ends... it just goes on and on my friends!".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
future::pending::<()>().await;
|
||||
unreachable!()
|
||||
@@ -100,7 +176,20 @@ impl AgentTool for WordListTool {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"List of random words".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Other
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
use crate::{prompts::BasePrompt, templates::Templates};
|
||||
use crate::templates::{SystemPromptTemplate, Template, Templates};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use anyhow::{anyhow, Context as _, Result};
|
||||
use assistant_tool::{adapt_schema_to_format, ActionLog};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use collections::HashMap;
|
||||
use futures::{channel::mpsc, stream::FuturesUnordered};
|
||||
use gpui::{App, Context, Entity, ImageFormat, SharedString, Task};
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
stream::FuturesUnordered,
|
||||
};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, fmt::Write, sync::Arc};
|
||||
use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
|
||||
use util::{markdown::MarkdownCodeBlock, ResultExt};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -97,11 +102,15 @@ pub enum AgentResponseEvent {
|
||||
Thinking(String),
|
||||
ToolCall(acp::ToolCall),
|
||||
ToolCallUpdate(acp::ToolCallUpdate),
|
||||
ToolCallAuthorization(ToolCallAuthorization),
|
||||
Stop(acp::StopReason),
|
||||
}
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallAuthorization {
|
||||
pub tool_call: acp::ToolCall,
|
||||
pub options: Vec<acp::PermissionOption>,
|
||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
@@ -112,28 +121,31 @@ pub struct Thread {
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
_project: Entity<Project>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: vec![Arc::new(BasePrompt::new(project))],
|
||||
running_turn: None,
|
||||
pending_tool_uses: HashMap::default(),
|
||||
tools: BTreeMap::default(),
|
||||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,6 +200,7 @@ impl Thread {
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
|
||||
let user_message_ix = self.messages.len();
|
||||
self.messages.push(AgentMessage {
|
||||
@@ -222,12 +235,7 @@ impl Thread {
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
if let Some(reason) = to_acp_stop_reason(reason) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(reason)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.messages.truncate(user_message_ix);
|
||||
@@ -240,14 +248,16 @@ impl Thread {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event, &events_tx, cx,
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Error in completion stream: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
event_stream.send_error(error);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -266,11 +276,17 @@ impl Thread {
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
to_acp_tool_call_update(&tool_result),
|
||||
)))
|
||||
.ok();
|
||||
event_stream.send_tool_call_update(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.pending_tool_uses.remove(&tool_result.tool_use_id);
|
||||
@@ -291,7 +307,7 @@ impl Thread {
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
event_stream.send_error(error);
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
@@ -299,24 +315,24 @@ impl Thread {
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self, cx: &App) -> Option<AgentMessage> {
|
||||
pub fn action_log(&self) -> &Entity<ActionLog> {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn build_system_message(&self) -> AgentMessage {
|
||||
log::debug!("Building system message");
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
let prompt = SystemPromptTemplate {
|
||||
project: &self.project_context.borrow(),
|
||||
available_tools: self.tools.keys().cloned().collect(),
|
||||
}
|
||||
.render(&self.templates)
|
||||
.context("failed to build system prompt")
|
||||
.expect("Invalid template");
|
||||
log::debug!("System message built");
|
||||
AgentMessage {
|
||||
role: Role::System,
|
||||
content: vec![prompt.into()],
|
||||
}
|
||||
|
||||
let result = (!system_message.content.is_empty()).then_some(system_message);
|
||||
log::debug!("System message built: {}", result.is_some());
|
||||
result
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
@@ -325,7 +341,7 @@ impl Thread {
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
@@ -338,13 +354,13 @@ impl Thread {
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
Text(new_text) => self.handle_text_event(new_text, events_tx, cx),
|
||||
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
||||
Thinking { text, signature } => {
|
||||
self.handle_thinking_event(text, signature, events_tx, cx)
|
||||
self.handle_thinking_event(text, signature, event_stream, cx)
|
||||
}
|
||||
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, events_tx, cx);
|
||||
return self.handle_tool_use_event(tool_use, event_stream, cx);
|
||||
}
|
||||
ToolUseJsonParseError {
|
||||
id,
|
||||
@@ -369,12 +385,10 @@ impl Thread {
|
||||
fn handle_text_event(
|
||||
&mut self,
|
||||
new_text: String,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
events_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Text(new_text.clone())))
|
||||
.ok();
|
||||
events_stream.send_text(&new_text);
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
@@ -390,12 +404,10 @@ impl Thread {
|
||||
&mut self,
|
||||
new_text: String,
|
||||
new_signature: Option<String>,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::Thinking(new_text.clone())))
|
||||
.ok();
|
||||
event_stream.send_thinking(&new_text);
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
|
||||
@@ -423,11 +435,13 @@ impl Thread {
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
events_tx: &mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
|
||||
|
||||
self.pending_tool_uses
|
||||
.insert(tool_use.id.clone(), tool_use.clone());
|
||||
let last_message = self.last_assistant_message();
|
||||
@@ -445,82 +459,74 @@ impl Thread {
|
||||
true
|
||||
}
|
||||
});
|
||||
|
||||
if push_new_tool_use {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
title: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
})))
|
||||
.ok();
|
||||
event_stream.send_tool_call(tool.as_ref(), &tool_use);
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
} else {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
event_stream.send_tool_call_update(
|
||||
&tool_use.id,
|
||||
acp::ToolCallUpdateFields {
|
||||
raw_input: Some(tool_use.input.clone()),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
events_tx
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
let Some(tool) = tool else {
|
||||
let content = format!("No tool named {} exists", tool_use.name);
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
return Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(content)),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}));
|
||||
};
|
||||
|
||||
let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn run_tool(
|
||||
&self,
|
||||
tool: Arc<dyn AnyAgentTool>,
|
||||
tool_use: LanguageModelToolUse,
|
||||
event_stream: AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<String>> {
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream);
|
||||
tool_event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use_json_parse_error_event(
|
||||
@@ -575,7 +581,7 @@ impl Thread {
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||
|
||||
let messages = self.build_request_messages(cx);
|
||||
let messages = self.build_request_messages();
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
@@ -588,7 +594,7 @@ impl Thread {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
@@ -613,14 +619,13 @@ impl Thread {
|
||||
request
|
||||
}
|
||||
|
||||
fn build_request_messages(&self, cx: &App) -> Vec<LanguageModelRequestMessage> {
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
|
||||
let messages = self
|
||||
.build_system_message(cx)
|
||||
let messages = Some(self.build_system_message())
|
||||
.iter()
|
||||
.chain(self.messages.iter())
|
||||
.map(|message| {
|
||||
@@ -656,9 +661,10 @@ pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
@@ -669,13 +675,33 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
|
||||
/// The initial tool title to display. Can be updated during the tool run.
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString;
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
fn input_schema(&self) -> Schema {
|
||||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Allows the tool to authorize a given tool call with the user if necessary
|
||||
fn authorize(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
) -> impl use<Self> + Future<Output = Result<()>> {
|
||||
let json_input = serde_json::json!(&input);
|
||||
event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
@@ -687,8 +713,15 @@ pub struct Erased<T>(T);
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn kind(&self) -> acp::ToolKind;
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString>;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
@@ -703,52 +736,220 @@ where
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
fn kind(&self) -> agent_client_protocol::ToolKind {
|
||||
self.0.kind()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
fn initial_title(&self, input: serde_json::Value) -> Result<SharedString> {
|
||||
let parsed_input = serde_json::from_value(input)?;
|
||||
Ok(self.0.initial_title(parsed_input))
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
let mut json = serde_json::to_value(self.0.input_schema())?;
|
||||
adapt_schema_to_format(&mut json, format)?;
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Ok(input) => self.0.clone().run(input, event_stream, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_acp_stop_reason(reason: StopReason) -> Option<acp::StopReason> {
|
||||
match reason {
|
||||
StopReason::EndTurn => Some(acp::StopReason::EndTurn),
|
||||
StopReason::MaxTokens => Some(acp::StopReason::MaxTokens),
|
||||
StopReason::Refusal => Some(acp::StopReason::Refusal),
|
||||
StopReason::ToolUse => None,
|
||||
#[derive(Clone)]
|
||||
struct AgentResponseEventStream(
|
||||
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
|
||||
impl AgentResponseEventStream {
|
||||
fn send_text(&self, text: &str) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_thinking(&self, text: &str) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn authorize_tool_call(
|
||||
&self,
|
||||
id: &LanguageModelToolUseId,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
||||
ToolCallAuthorization {
|
||||
tool_call: Self::initial_tool_call(id, title, kind, input),
|
||||
options: vec![
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("always_allow".into()),
|
||||
name: "Always Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowAlways,
|
||||
},
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("allow".into()),
|
||||
name: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
},
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("deny".into()),
|
||||
name: "Deny".into(),
|
||||
kind: acp::PermissionOptionKind::RejectOnce,
|
||||
},
|
||||
],
|
||||
response: response_tx,
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
async move {
|
||||
match response_rx.await?.0.as_ref() {
|
||||
"allow" | "always_allow" => Ok(()),
|
||||
_ => Err(anyhow!("Permission to run tool denied by user")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_tool_call(
|
||||
&self,
|
||||
tool: Option<&Arc<dyn AnyAgentTool>>,
|
||||
tool_use: &LanguageModelToolUse,
|
||||
) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
|
||||
&tool_use.id,
|
||||
tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok())
|
||||
.map(|i| i.into())
|
||||
.unwrap_or_else(|| tool_use.name.to_string()),
|
||||
tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other),
|
||||
tool_use.input.clone(),
|
||||
))))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn initial_tool_call(
|
||||
id: &LanguageModelToolUseId,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> acp::ToolCall {
|
||||
acp::ToolCall {
|
||||
id: acp::ToolCallId(id.to_string().into()),
|
||||
title,
|
||||
kind,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(input),
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn send_tool_call_update(
|
||||
&self,
|
||||
tool_use_id: &LanguageModelToolUseId,
|
||||
fields: acp::ToolCallUpdateFields,
|
||||
) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.to_string().into()),
|
||||
fields,
|
||||
},
|
||||
)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_stop(&self, reason: StopReason) {
|
||||
match reason {
|
||||
StopReason::EndTurn => {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::MaxTokens => {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::Refusal => {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
|
||||
.ok();
|
||||
}
|
||||
StopReason::ToolUse => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_error(&self, error: LanguageModelCompletionError) {
|
||||
self.0.unbounded_send(Err(error)).ok();
|
||||
}
|
||||
}
|
||||
|
||||
fn to_acp_tool_call_update(tool_result: &LanguageModelToolResult) -> acp::ToolCallUpdate {
|
||||
let status = if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
};
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string().into(),
|
||||
LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => {
|
||||
acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::Image(acp::ImageContent {
|
||||
annotations: None,
|
||||
data: source.to_string(),
|
||||
mime_type: ImageFormat::Png.mime_type().to_string(),
|
||||
}),
|
||||
}
|
||||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl ToolCallEventStream {
|
||||
fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self {
|
||||
Self {
|
||||
tool_use_id,
|
||||
stream,
|
||||
}
|
||||
};
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(status),
|
||||
content: Some(vec![content]),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
|
||||
pub fn send_update(&self, fields: acp::ToolCallUpdateFields) {
|
||||
self.stream.send_tool_call_update(&self.tool_use_id, fields);
|
||||
}
|
||||
|
||||
pub fn authorize(
|
||||
&self,
|
||||
title: String,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.stream
|
||||
.authorize_tool_call(&self.tool_use_id, title, kind, input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct TestToolCallEventStream {
|
||||
stream: ToolCallEventStream,
|
||||
_events_rx: mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl TestToolCallEventStream {
|
||||
pub fn new() -> Self {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx));
|
||||
|
||||
Self {
|
||||
stream,
|
||||
_events_rx: events_rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stream(&self) -> ToolCallEventStream {
|
||||
self.stream.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
mod glob;
|
||||
mod find_path_tool;
|
||||
mod read_file_tool;
|
||||
mod thinking_tool;
|
||||
|
||||
pub use find_path_tool::*;
|
||||
pub use read_file_tool::*;
|
||||
pub use thinking_tool::*;
|
||||
|
||||
231
crates/agent2/src/tools/find_path_tool.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// Fast file path pattern matching tool that works with any codebase size
|
||||
///
|
||||
/// - Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
/// - Returns matching file paths sorted alphabetically
|
||||
/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths.
|
||||
/// - Use this tool when you need to find files by name patterns
|
||||
/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FindPathToolInput {
|
||||
/// The glob to match against every path in the project.
|
||||
///
|
||||
/// <example>
|
||||
/// If the project has the following root directories:
|
||||
///
|
||||
/// - directory1/a/something.txt
|
||||
/// - directory2/a/things.txt
|
||||
/// - directory3/a/other.txt
|
||||
///
|
||||
/// You can get back the first two paths by providing a glob of "*thing*.txt"
|
||||
/// </example>
|
||||
pub glob: String,
|
||||
|
||||
/// Optional starting position for paginated results (0-based).
|
||||
/// When not provided, starts from the beginning.
|
||||
#[serde(default)]
|
||||
pub offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FindPathToolOutput {
|
||||
paths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
||||
pub struct FindPathTool {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl FindPathTool {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for FindPathTool {
|
||||
type Input = FindPathToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"find_path".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Search
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
format!("Find paths matching “`{}`”", input.glob).into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let search_paths_task = search_paths(&input.glob, self.project.clone(), cx);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let matches = search_paths_task.await?;
|
||||
let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len())
|
||||
..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())];
|
||||
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
title: Some(if paginated_matches.len() == 0 {
|
||||
"No matches".into()
|
||||
} else if paginated_matches.len() == 1 {
|
||||
"1 match".into()
|
||||
} else {
|
||||
format!("{} matches", paginated_matches.len())
|
||||
}),
|
||||
content: Some(
|
||||
paginated_matches
|
||||
.iter()
|
||||
.map(|path| acp::ToolCallContent::Content {
|
||||
content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
|
||||
uri: format!("file://{}", path.display()),
|
||||
name: path.to_string_lossy().into(),
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
size: None,
|
||||
title: None,
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
raw_output: Some(serde_json::json!({
|
||||
"paths": &matches,
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
if matches.is_empty() {
|
||||
Ok("No matches found".into())
|
||||
} else {
|
||||
let mut message = format!("Found {} total matches.", matches.len());
|
||||
if matches.len() > RESULTS_PER_PAGE {
|
||||
write!(
|
||||
&mut message,
|
||||
"\nShowing results {}-{} (provide 'offset' parameter for more results):",
|
||||
input.offset + 1,
|
||||
input.offset + paginated_matches.len()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) {
|
||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
|
||||
let path_matcher = match PathMatcher::new([
|
||||
// Sometimes models try to search for "". In this case, return all paths in the project.
|
||||
if glob.is_empty() { "*" } else { glob },
|
||||
]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
|
||||
};
|
||||
let snapshots: Vec<_> = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
Ok(snapshots
|
||||
.iter()
|
||||
.flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::{FakeFs, Project};
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_find_path_tool(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/root",
|
||||
serde_json::json!({
|
||||
"apple": {
|
||||
"banana": {
|
||||
"carrot": "1",
|
||||
},
|
||||
"bandana": {
|
||||
"carbonara": "2",
|
||||
},
|
||||
"endive": "3"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let matches = cx
|
||||
.update(|cx| search_paths("root/**/car*", project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
matches,
|
||||
&[
|
||||
PathBuf::from("root/apple/banana/carrot"),
|
||||
PathBuf::from("root/apple/bandana/carbonara")
|
||||
]
|
||||
);
|
||||
|
||||
let matches = cx
|
||||
.update(|cx| search_paths("**/car*", project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
matches,
|
||||
&[
|
||||
PathBuf::from("root/apple/banana/carrot"),
|
||||
PathBuf::from("root/apple/bandana/carbonara")
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
||||
970
crates/agent2/src/tools/read_file_tool.rs
Normal file
@@ -0,0 +1,970 @@
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tool::{outline, ActionLog};
|
||||
use gpui::{Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::{Anchor, Point};
|
||||
use project::{AgentLocation, Project, WorktreeSettings};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use std::sync::Arc;
|
||||
use ui::{App, SharedString};
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// Reads the content of the given file in the project.
|
||||
///
|
||||
/// - Never attempt to read a path that hasn't been previously mentioned.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ReadFileToolInput {
|
||||
/// The relative path of the file to read.
|
||||
///
|
||||
/// This path should never be absolute, and the first component
|
||||
/// of the path should always be a root directory in a project.
|
||||
///
|
||||
/// <example>
|
||||
/// If the project has the following root directories:
|
||||
///
|
||||
/// - /a/b/directory1
|
||||
/// - /c/d/directory2
|
||||
///
|
||||
/// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`.
|
||||
/// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`.
|
||||
/// </example>
|
||||
pub path: String,
|
||||
|
||||
/// Optional line number to start reading on (1-based index)
|
||||
#[serde(default)]
|
||||
pub start_line: Option<u32>,
|
||||
|
||||
/// Optional line number to end reading on (1-based index, inclusive)
|
||||
#[serde(default)]
|
||||
pub end_line: Option<u32>,
|
||||
}
|
||||
|
||||
pub struct ReadFileTool {
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl ReadFileTool {
|
||||
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
|
||||
Self {
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for ReadFileTool {
|
||||
type Input = ReadFileToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"read_file".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Read
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Self::Input) -> SharedString {
|
||||
let path = &input.path;
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), Some(end)) => {
|
||||
format!(
|
||||
"[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))",
|
||||
path, start, end, path, start, end
|
||||
)
|
||||
}
|
||||
(Some(start), None) => {
|
||||
format!(
|
||||
"[Read file `{}` (from line {})](@selection:{}:({}-{}))",
|
||||
path, start, path, start, start
|
||||
)
|
||||
}
|
||||
_ => format!("[Read file `{}`](@file:{})", path, path),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
|
||||
};
|
||||
|
||||
// Error out if this path is either excluded or private in global settings
|
||||
let global_settings = WorktreeSettings::get_global(cx);
|
||||
if global_settings.is_path_excluded(&project_path.path) {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Cannot read file because its path matches the global `file_scan_exclusions` setting: {}",
|
||||
&input.path
|
||||
)));
|
||||
}
|
||||
|
||||
if global_settings.is_path_private(&project_path.path) {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Cannot read file because its path matches the global `private_files` setting: {}",
|
||||
&input.path
|
||||
)));
|
||||
}
|
||||
|
||||
// Error out if this path is either excluded or private in worktree settings
|
||||
let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx);
|
||||
if worktree_settings.is_path_excluded(&project_path.path) {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}",
|
||||
&input.path
|
||||
)));
|
||||
}
|
||||
|
||||
if worktree_settings.is_path_private(&project_path.path) {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Cannot read file because its path matches the worktree `private_files` setting: {}",
|
||||
&input.path
|
||||
)));
|
||||
}
|
||||
|
||||
let file_path = input.path.clone();
|
||||
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: project_path.path.to_path_buf(),
|
||||
line: input.start_line,
|
||||
// TODO (tracked): use full range
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// TODO (tracked): images
|
||||
// if image_store::is_image_file(&self.project, &project_path, cx) {
|
||||
// let model = &self.thread.read(cx).selected_model;
|
||||
|
||||
// if !model.supports_images() {
|
||||
// return Task::ready(Err(anyhow!(
|
||||
// "Attempted to read an image, but Zed doesn't currently support sending images to {}.",
|
||||
// model.name().0
|
||||
// )))
|
||||
// .into();
|
||||
// }
|
||||
|
||||
// return cx.spawn(async move |cx| -> Result<ToolResultOutput> {
|
||||
// let image_entity: Entity<ImageItem> = cx
|
||||
// .update(|cx| {
|
||||
// self.project.update(cx, |project, cx| {
|
||||
// project.open_image(project_path.clone(), cx)
|
||||
// })
|
||||
// })?
|
||||
// .await?;
|
||||
|
||||
// let image =
|
||||
// image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
// let language_model_image = cx
|
||||
// .update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
// .await
|
||||
// .context("processing image")?;
|
||||
|
||||
// Ok(ToolResultOutput {
|
||||
// content: ToolResultContent::Image(language_model_image),
|
||||
// output: None,
|
||||
// })
|
||||
// });
|
||||
// }
|
||||
//
|
||||
|
||||
let project = self.project.clone();
|
||||
let action_log = self.action_log.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = cx
|
||||
.update(|cx| {
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
|
||||
})?
|
||||
.await?;
|
||||
if buffer.read_with(cx, |buffer, _| {
|
||||
buffer
|
||||
.file()
|
||||
.as_ref()
|
||||
.map_or(true, |file| !file.disk_state().exists())
|
||||
})? {
|
||||
anyhow::bail!("{file_path} not found");
|
||||
}
|
||||
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: Anchor::MIN,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
|
||||
// Check if specific line ranges are provided
|
||||
if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let mut anchor = None;
|
||||
let result = buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
|
||||
let start = input.start_line.unwrap_or(1).max(1);
|
||||
let start_row = start - 1;
|
||||
if start_row <= buffer.max_point().row {
|
||||
let column = buffer.line_indent_for_row(start_row).raw_len();
|
||||
anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
|
||||
}
|
||||
|
||||
let lines = text.split('\n').skip(start_row as usize);
|
||||
if let Some(end) = input.end_line {
|
||||
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
|
||||
itertools::intersperse(lines.take(count as usize), "\n").collect::<String>()
|
||||
} else {
|
||||
itertools::intersperse(lines, "\n").collect::<String>()
|
||||
}
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
if let Some(anchor) = anchor {
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: anchor,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
|
||||
|
||||
if file_size <= outline::AUTO_OUTLINE_SIZE {
|
||||
// File is small enough, so return its contents.
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
} else {
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
let outline =
|
||||
outline::file_outline(project, file_path, action_log, None, cx).await?;
|
||||
Ok(formatdoc! {"
|
||||
This file was too big to read all at once.
|
||||
|
||||
Here is an outline of its symbols:
|
||||
|
||||
{outline}
|
||||
|
||||
Using the line numbers in this outline, you can call this tool again
|
||||
while specifying the start_line and end_line fields to see the
|
||||
implementations of symbols in the outline.
|
||||
|
||||
Alternatively, you can fall back to the `grep` tool (if available)
|
||||
to search the file for specific content."
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::TestToolCallEventStream;
|
||||
|
||||
use super::*;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_nonexistent_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/root"), json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/nonexistent_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
result.unwrap_err().to_string(),
|
||||
"root/nonexistent_file.txt not found"
|
||||
);
|
||||
}
|
||||
#[gpui::test]
|
||||
async fn test_read_small_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"small_file.txt": "This is a small file content"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/small_file.txt".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "This is a small file content");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_large_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::<Vec<_>>().join("\n")
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
|
||||
language_registry.add(Arc::new(rust_lang()));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let content = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
content.lines().skip(4).take(6).collect::<Vec<_>>(),
|
||||
vec![
|
||||
"struct Test0 [L1-4]",
|
||||
" a [L2]",
|
||||
" b [L3]",
|
||||
"struct Test1 [L5-8]",
|
||||
" a [L6]",
|
||||
" b [L7]",
|
||||
]
|
||||
);
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/large_file.rs".into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
let content = result.unwrap();
|
||||
let expected_content = (0..1000)
|
||||
.flat_map(|i| {
|
||||
vec![
|
||||
format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4),
|
||||
format!(" a [L{}]", i * 4 + 2),
|
||||
format!(" b [L{}]", i * 4 + 3),
|
||||
]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
pretty_assertions::assert_eq!(
|
||||
content
|
||||
.lines()
|
||||
.skip(4)
|
||||
.take(expected_content.len())
|
||||
.collect::<Vec<_>>(),
|
||||
expected_content
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_with_line_range(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(2),
|
||||
end_line: Some(4),
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// start_line of 0 should be treated as 1
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(0),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1\nLine 2");
|
||||
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(1),
|
||||
end_line: Some(0),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 1");
|
||||
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "root/multiline.txt".to_string(),
|
||||
start_line: Some(3),
|
||||
end_line: Some(2),
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 3");
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_outline_query(
|
||||
r#"
|
||||
(line_comment) @annotation
|
||||
|
||||
(struct_item
|
||||
"struct" @context
|
||||
name: (_) @name) @item
|
||||
(enum_item
|
||||
"enum" @context
|
||||
name: (_) @name) @item
|
||||
(enum_variant
|
||||
name: (_) @name) @item
|
||||
(field_declaration
|
||||
name: (_) @name) @item
|
||||
(impl_item
|
||||
"impl" @context
|
||||
trait: (_)? @name
|
||||
"for"? @context
|
||||
type: (_) @name
|
||||
body: (_ "{" (_)* "}")) @item
|
||||
(function_item
|
||||
"fn" @context
|
||||
name: (_) @name) @item
|
||||
(mod_item
|
||||
"mod" @context
|
||||
name: (_) @name) @item
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_security(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/"),
|
||||
json!({
|
||||
"project_root": {
|
||||
"allowed_file.txt": "This file is in the project",
|
||||
".mysecrets": "SECRET_KEY=abc123",
|
||||
".secretdir": {
|
||||
"config": "special configuration"
|
||||
},
|
||||
".mymetadata": "custom metadata",
|
||||
"subdir": {
|
||||
"normal_file.txt": "Normal file content",
|
||||
"special.privatekey": "private key content",
|
||||
"data.mysensitive": "sensitive data"
|
||||
}
|
||||
},
|
||||
"outside_project": {
|
||||
"sensitive_file.txt": "This file is outside the project"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.update(|cx| {
|
||||
use gpui::UpdateGlobal;
|
||||
use project::WorktreeSettings;
|
||||
use settings::SettingsStore;
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
|
||||
settings.file_scan_exclusions = Some(vec![
|
||||
"**/.secretdir".to_string(),
|
||||
"**/.mymetadata".to_string(),
|
||||
]);
|
||||
settings.private_files = Some(vec![
|
||||
"**/.mysecrets".to_string(),
|
||||
"**/*.privatekey".to_string(),
|
||||
"**/*.mysensitive".to_string(),
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project, action_log));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Reading a file outside the project worktree should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "/outside_project/sensitive_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read an absolute path outside a worktree"
|
||||
);
|
||||
|
||||
// Reading a file within the project should succeed
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/allowed_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"read_file_tool should be able to read files inside worktrees"
|
||||
);
|
||||
|
||||
// Reading files that match file_scan_exclusions should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/.secretdir/config".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)"
|
||||
);
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/.mymetadata".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)"
|
||||
);
|
||||
|
||||
// Reading private files should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/.mysecrets".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read .mysecrets (private_files)"
|
||||
);
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/special.privatekey".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read .privatekey files (private_files)"
|
||||
);
|
||||
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/data.mysensitive".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read .mysensitive files (private_files)"
|
||||
);
|
||||
|
||||
// Reading a normal file should still work, even with private_files configured
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/subdir/normal_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_ok(), "Should be able to read normal files");
|
||||
assert_eq!(result.unwrap(), "Normal file content");
|
||||
|
||||
// Path traversal attempts with .. should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "project_root/../outside_project/sensitive_file.txt".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"read_file_tool should error when attempting to read a relative path that resolves to outside a worktree"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
// Create first worktree with its own private_files setting
|
||||
fs.insert_tree(
|
||||
path!("/worktree1"),
|
||||
json!({
|
||||
"src": {
|
||||
"main.rs": "fn main() { println!(\"Hello from worktree1\"); }",
|
||||
"secret.rs": "const API_KEY: &str = \"secret_key_1\";",
|
||||
"config.toml": "[database]\nurl = \"postgres://localhost/db1\""
|
||||
},
|
||||
"tests": {
|
||||
"test.rs": "mod tests { fn test_it() {} }",
|
||||
"fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));"
|
||||
},
|
||||
".zed": {
|
||||
"settings.json": r#"{
|
||||
"file_scan_exclusions": ["**/fixture.*"],
|
||||
"private_files": ["**/secret.rs", "**/config.toml"]
|
||||
}"#
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create second worktree with different private_files setting
|
||||
fs.insert_tree(
|
||||
path!("/worktree2"),
|
||||
json!({
|
||||
"lib": {
|
||||
"public.js": "export function greet() { return 'Hello from worktree2'; }",
|
||||
"private.js": "const SECRET_TOKEN = \"private_token_2\";",
|
||||
"data.json": "{\"api_key\": \"json_secret_key\"}"
|
||||
},
|
||||
"docs": {
|
||||
"README.md": "# Public Documentation",
|
||||
"internal.md": "# Internal Secrets and Configuration"
|
||||
},
|
||||
".zed": {
|
||||
"settings.json": r#"{
|
||||
"file_scan_exclusions": ["**/internal.*"],
|
||||
"private_files": ["**/private.js", "**/data.json"]
|
||||
}"#
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Set global settings
|
||||
cx.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
store.update_user_settings::<WorktreeSettings>(cx, |settings| {
|
||||
settings.file_scan_exclusions =
|
||||
Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]);
|
||||
settings.private_files = Some(vec!["**/.env".to_string()]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
let project = Project::test(
|
||||
fs.clone(),
|
||||
[path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone()));
|
||||
let event_stream = TestToolCallEventStream::new();
|
||||
|
||||
// Test reading allowed files in worktree1
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/main.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }");
|
||||
|
||||
// Test reading private file in worktree1 should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/secret.rs".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("worktree `private_files` setting"),
|
||||
"Error should mention worktree private_files setting"
|
||||
);
|
||||
|
||||
// Test reading excluded file in worktree1 should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/tests/fixture.sql".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("worktree `file_scan_exclusions` setting"),
|
||||
"Error should mention worktree file_scan_exclusions setting"
|
||||
);
|
||||
|
||||
// Test reading allowed files in worktree2
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/lib/public.js".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"export function greet() { return 'Hello from worktree2'; }"
|
||||
);
|
||||
|
||||
// Test reading private file in worktree2 should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/lib/private.js".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("worktree `private_files` setting"),
|
||||
"Error should mention worktree private_files setting"
|
||||
);
|
||||
|
||||
// Test reading excluded file in worktree2 should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree2/docs/internal.md".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("worktree `file_scan_exclusions` setting"),
|
||||
"Error should mention worktree file_scan_exclusions setting"
|
||||
);
|
||||
|
||||
// Test that files allowed in one worktree but not in another are handled correctly
|
||||
// (e.g., config.toml is private in worktree1 but doesn't exist in worktree2)
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = ReadFileToolInput {
|
||||
path: "worktree1/src/config.toml".to_string(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
};
|
||||
tool.clone().run(input, event_stream.stream(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("worktree `private_files` setting"),
|
||||
"Config.toml should be blocked by worktree1's private_files setting"
|
||||
);
|
||||
}
|
||||
}
|
||||
48
crates/agent2/src/tools/thinking_tool.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, SharedString, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions.
|
||||
/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ThinkingToolInput {
|
||||
/// Content to think about. This should be a description of what to think about or
|
||||
/// a problem to solve.
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub struct ThinkingTool;
|
||||
|
||||
impl AgentTool for ThinkingTool {
|
||||
type Input = ThinkingToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"thinking".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Think
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: Self::Input) -> SharedString {
|
||||
"Thinking".into()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
event_stream.send_update(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![input.content.into()]),
|
||||
..Default::default()
|
||||
});
|
||||
Task::ready(Ok("Finished thinking.".to_string()))
|
||||
}
|
||||
}
|
||||
@@ -135,7 +135,7 @@ impl acp_old::Client for OldAcpClientDelegate {
|
||||
let response = cx
|
||||
.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(tool_call, acp_options, cx)
|
||||
thread.request_tool_call_authorization(tool_call, acp_options, cx)
|
||||
})
|
||||
})?
|
||||
.context("Failed to update thread")?
|
||||
@@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams)
|
||||
.map(into_new_tool_call_location)
|
||||
.collect(),
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ impl acp::Client for ClientDelegate {
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
|
||||
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
|
||||
})?;
|
||||
|
||||
let result = rx.await;
|
||||
|
||||
@@ -24,7 +24,7 @@ use futures::{
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, debug_panic};
|
||||
|
||||
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
@@ -153,16 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
})
|
||||
.detach();
|
||||
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let turn_state = Rc::new(RefCell::new(TurnState::None));
|
||||
|
||||
let handler_task = cx.spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let turn_state = turn_state.clone();
|
||||
let mut thread_rx = thread_rx.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentSession::handle_message(
|
||||
thread_rx.clone(),
|
||||
message,
|
||||
end_turn_tx.clone(),
|
||||
turn_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
@@ -188,7 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
|
||||
let session = ClaudeAgentSession {
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
turn_state,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: Some(permission_mcp_server),
|
||||
};
|
||||
@@ -220,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
)));
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
session.end_turn_tx.borrow_mut().replace(tx);
|
||||
let (end_tx, end_rx) = oneshot::channel();
|
||||
session.turn_state.replace(TurnState::InProgress { end_tx });
|
||||
|
||||
let mut content = String::new();
|
||||
for chunk in params.prompt {
|
||||
@@ -255,7 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
return Task::ready(Err(anyhow!(err)));
|
||||
}
|
||||
|
||||
cx.foreground_executor().spawn(async move { rx.await? })
|
||||
cx.foreground_executor().spawn(async move { end_rx.await? })
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||
@@ -265,18 +266,27 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
return;
|
||||
};
|
||||
|
||||
let request_id = new_request_id();
|
||||
|
||||
let turn_state = session.turn_state.take();
|
||||
let TurnState::InProgress { end_tx } = turn_state else {
|
||||
// Already cancelled or idle, put it back
|
||||
session.turn_state.replace(turn_state);
|
||||
return;
|
||||
};
|
||||
|
||||
session.turn_state.replace(TurnState::CancelRequested {
|
||||
end_tx,
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
|
||||
session
|
||||
.outgoing_tx
|
||||
.unbounded_send(SdkMessage::new_interrupt_message())
|
||||
.unbounded_send(SdkMessage::ControlRequest {
|
||||
request_id,
|
||||
request: ControlRequest::Interrupt,
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::Cancelled,
|
||||
}))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,26 +348,139 @@ fn spawn_claude(
|
||||
|
||||
struct ClaudeAgentSession {
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
turn_state: Rc<RefCell<TurnState>>,
|
||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
enum TurnState {
|
||||
#[default]
|
||||
None,
|
||||
InProgress {
|
||||
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||
},
|
||||
CancelRequested {
|
||||
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||
request_id: String,
|
||||
},
|
||||
CancelConfirmed {
|
||||
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
fn is_cancelled(&self) -> bool {
|
||||
matches!(self, TurnState::CancelConfirmed { .. })
|
||||
}
|
||||
|
||||
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
|
||||
match self {
|
||||
TurnState::None => None,
|
||||
TurnState::InProgress { end_tx, .. } => Some(end_tx),
|
||||
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
|
||||
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
|
||||
}
|
||||
}
|
||||
|
||||
fn confirm_cancellation(self, id: &str) -> Self {
|
||||
match self {
|
||||
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
|
||||
TurnState::CancelConfirmed { end_tx }
|
||||
}
|
||||
_ => self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClaudeAgentSession {
|
||||
async fn handle_message(
|
||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
message: SdkMessage,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
turn_state: Rc<RefCell<TurnState>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
match message {
|
||||
// we should only be sending these out, they don't need to be in the thread
|
||||
SdkMessage::ControlRequest { .. } => {}
|
||||
SdkMessage::Assistant {
|
||||
SdkMessage::User {
|
||||
message,
|
||||
session_id: _,
|
||||
} => {
|
||||
let Some(thread) = thread_rx
|
||||
.recv()
|
||||
.await
|
||||
.log_err()
|
||||
.and_then(|entity| entity.upgrade())
|
||||
else {
|
||||
log::error!("Received an SDK message but thread is gone");
|
||||
return;
|
||||
};
|
||||
|
||||
for chunk in message.content.chunks() {
|
||||
match chunk {
|
||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||
if !turn_state.borrow().is_cancelled() {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(text.into(), cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let content = content.to_string();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: if turn_state.borrow().is_cancelled() {
|
||||
// Do not set to completed if turn was cancelled
|
||||
None
|
||||
} else {
|
||||
Some(acp::ToolCallStatus::Completed)
|
||||
},
|
||||
content: (!content.is_empty())
|
||||
.then(|| vec![content.into()]),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::Thinking { .. }
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::ToolUse { .. } => {
|
||||
debug_panic!(
|
||||
"Should not get {:?} with role: assistant. should we handle this?",
|
||||
chunk
|
||||
);
|
||||
}
|
||||
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
format!("Unsupported content: {:?}", chunk).into(),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
| SdkMessage::User {
|
||||
SdkMessage::Assistant {
|
||||
message,
|
||||
session_id: _,
|
||||
} => {
|
||||
@@ -380,6 +503,24 @@ impl ClaudeAgentSession {
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::Thinking { thinking } => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(thinking.into(), true, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::RedactedThinking => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
"[REDACTED]".into(),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolUse { id, name, input } => {
|
||||
let claude_tool = ClaudeTool::infer(&name, input);
|
||||
|
||||
@@ -405,33 +546,12 @@ impl ClaudeAgentSession {
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let content = content.to_string();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
content: (!content.is_empty())
|
||||
.then(|| vec![content.into()]),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
|
||||
debug_panic!(
|
||||
"Should not get tool results with role: assistant. should we handle this?"
|
||||
);
|
||||
}
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
ContentChunk::Image | ContentChunk::Document => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
@@ -451,27 +571,41 @@ impl ClaudeAgentSession {
|
||||
result,
|
||||
..
|
||||
} => {
|
||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
||||
if is_error || subtype == ResultErrorType::ErrorDuringExecution {
|
||||
end_turn_tx
|
||||
.send(Err(anyhow!(
|
||||
"Error: {}",
|
||||
result.unwrap_or_else(|| subtype.to_string())
|
||||
)))
|
||||
.ok();
|
||||
} else {
|
||||
let stop_reason = match subtype {
|
||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||
ResultErrorType::ErrorDuringExecution => unreachable!(),
|
||||
};
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||
.ok();
|
||||
}
|
||||
let turn_state = turn_state.take();
|
||||
let was_cancelled = turn_state.is_cancelled();
|
||||
let Some(end_turn_tx) = turn_state.end_tx() else {
|
||||
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
|
||||
return;
|
||||
};
|
||||
|
||||
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
|
||||
{
|
||||
end_turn_tx
|
||||
.send(Err(anyhow!(
|
||||
"Error: {}",
|
||||
result.unwrap_or_else(|| subtype.to_string())
|
||||
)))
|
||||
.ok();
|
||||
} else {
|
||||
let stop_reason = match subtype {
|
||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
||||
};
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
|
||||
SdkMessage::ControlResponse { response } => {
|
||||
if matches!(response.subtype, ResultErrorType::Success) {
|
||||
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
|
||||
turn_state.replace(new_state);
|
||||
} else {
|
||||
log::error!("Control response error: {:?}", response);
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,11 +714,13 @@ enum ContentChunk {
|
||||
content: Content,
|
||||
tool_use_id: String,
|
||||
},
|
||||
Thinking {
|
||||
thinking: String,
|
||||
},
|
||||
RedactedThinking,
|
||||
// TODO
|
||||
Image,
|
||||
Document,
|
||||
Thinking,
|
||||
RedactedThinking,
|
||||
WebSearchToolResult,
|
||||
#[serde(untagged)]
|
||||
UntaggedText(String),
|
||||
@@ -594,12 +730,12 @@ impl Display for ContentChunk {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ContentChunk::Text { text } => write!(f, "{}", text),
|
||||
ContentChunk::Thinking { thinking } => write!(f, "Thinking: {}", thinking),
|
||||
ContentChunk::RedactedThinking => write!(f, "Thinking: [REDACTED]"),
|
||||
ContentChunk::UntaggedText(text) => write!(f, "{}", text),
|
||||
ContentChunk::ToolResult { content, .. } => write!(f, "{}", content),
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::Thinking
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::ToolUse { .. }
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
write!(f, "\n{:?}\n", &self)
|
||||
@@ -710,22 +846,15 @@ impl Display for ResultErrorType {
|
||||
}
|
||||
}
|
||||
|
||||
impl SdkMessage {
|
||||
fn new_interrupt_message() -> Self {
|
||||
use rand::Rng;
|
||||
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||
// `Math.random().toString(36).substring(2, 15)`
|
||||
let request_id = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(12)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
Self::ControlRequest {
|
||||
request_id,
|
||||
request: ControlRequest::Interrupt,
|
||||
}
|
||||
}
|
||||
fn new_request_id() -> String {
|
||||
use rand::Rng;
|
||||
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||
// `Math.random().toString(36).substring(2, 15)`
|
||||
rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(12)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -746,6 +875,8 @@ enum PermissionMode {
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::e2e_tests;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
|
||||
crate::common_e2e_tests!(ClaudeCode, allow_option_id = "allow");
|
||||
@@ -758,6 +889,68 @@ pub(crate) mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn test_todo_plan(cx: &mut TestAppContext) {
|
||||
let fs = e2e_tests::init_test(cx).await;
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let thread =
|
||||
e2e_tests::new_test_thread(ClaudeCode, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Create a todo plan for initializing a new React app. I'll follow it myself, do not execute on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut entries_len = 0;
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
entries_len = thread.plan().entries.len();
|
||||
assert!(thread.plan().entries.len() > 0, "Empty plan");
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Mark the first entry status as in progress without acting on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
thread.plan().entries[0].status,
|
||||
acp::PlanEntryStatus::InProgress
|
||||
));
|
||||
assert_eq!(thread.plan().entries.len(), entries_len);
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
"Now mark the first entry as completed without acting on it.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(
|
||||
thread.plan().entries[0].status,
|
||||
acp::PlanEntryStatus::Completed
|
||||
));
|
||||
assert_eq!(thread.plan().entries.len(), entries_len);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_content_untagged_text() {
|
||||
let json = json!("Hello, world!");
|
||||
|
||||
@@ -153,7 +153,7 @@ impl McpServerTool for PermissionTool {
|
||||
|
||||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
thread.request_tool_call_authorization(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
|
||||
@@ -143,25 +143,6 @@ impl ClaudeTool {
|
||||
Self::Grep(Some(params)) => vec![format!("`{params}`").into()],
|
||||
Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()],
|
||||
Self::WebSearch(Some(params)) => vec![params.to_string().into()],
|
||||
Self::TodoWrite(Some(params)) => vec![
|
||||
params
|
||||
.todos
|
||||
.iter()
|
||||
.map(|todo| {
|
||||
format!(
|
||||
"- {} {}: {}",
|
||||
match todo.status {
|
||||
TodoStatus::Completed => "✅",
|
||||
TodoStatus::InProgress => "🚧",
|
||||
TodoStatus::Pending => "⬜",
|
||||
},
|
||||
todo.priority,
|
||||
todo.content
|
||||
)
|
||||
})
|
||||
.join("\n")
|
||||
.into(),
|
||||
],
|
||||
Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()],
|
||||
Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
@@ -193,6 +174,10 @@ impl ClaudeTool {
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Self::TodoWrite(Some(_)) => {
|
||||
// These are mapped to plan updates later
|
||||
vec![]
|
||||
}
|
||||
Self::Task(None)
|
||||
| Self::NotebookRead(None)
|
||||
| Self::NotebookEdit(None)
|
||||
@@ -312,6 +297,7 @@ impl ClaudeTool {
|
||||
content: self.content(),
|
||||
locations: self.locations(),
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -488,10 +474,11 @@ impl std::fmt::Display for GrepToolParams {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
|
||||
#[derive(Default, Deserialize, Serialize, JsonSchema, strum::Display, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TodoPriority {
|
||||
High,
|
||||
#[default]
|
||||
Medium,
|
||||
Low,
|
||||
}
|
||||
@@ -526,14 +513,13 @@ impl Into<acp::PlanEntryStatus> for TodoStatus {
|
||||
|
||||
#[derive(Deserialize, Serialize, JsonSchema, Debug)]
|
||||
pub struct Todo {
|
||||
/// Unique identifier
|
||||
pub id: String,
|
||||
/// Task description
|
||||
pub content: String,
|
||||
/// Priority level of the todo
|
||||
pub priority: TodoPriority,
|
||||
/// Current status of the todo
|
||||
pub status: TodoStatus,
|
||||
/// Priority level of the todo
|
||||
#[serde(default)]
|
||||
pub priority: TodoPriority,
|
||||
}
|
||||
|
||||
impl Into<acp::PlanEntry> for Todo {
|
||||
|
||||
@@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
let _ = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
|
||||
cx,
|
||||
@@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
id.clone()
|
||||
});
|
||||
|
||||
let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
|
||||
full_turn.await.unwrap();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
status: ToolCallStatus::Canceled,
|
||||
..
|
||||
|
||||
@@ -45,6 +45,11 @@ impl<T> MessageHistory<T> {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn items(&self) -> &[T] {
|
||||
&self.items
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -21,20 +21,22 @@ use editor::{
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
|
||||
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
|
||||
list, percentage, point, prelude::*, pulsating_between,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
|
||||
SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
|
||||
Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
|
||||
linear_gradient, list, percentage, point, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use language::{Buffer, Language};
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use settings::Settings as _;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use text::{Anchor, BufferSnapshot};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Disclosure, Divider, DividerColor, KeyBinding, Tooltip, prelude::*};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage};
|
||||
@@ -69,6 +71,7 @@ pub struct AcpThreadView {
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
last_error: Option<Entity<Markdown>>,
|
||||
list_state: ListState,
|
||||
scrollbar_state: ScrollbarState,
|
||||
auth_task: Option<Task<()>>,
|
||||
expanded_tool_calls: HashSet<acp::ToolCallId>,
|
||||
expanded_thinking_blocks: HashSet<(usize, usize)>,
|
||||
@@ -77,6 +80,7 @@ pub struct AcpThreadView {
|
||||
editor_expanded: bool,
|
||||
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
|
||||
_cancel_task: Option<Task<()>>,
|
||||
_subscriptions: [Subscription; 1],
|
||||
}
|
||||
|
||||
enum ThreadState {
|
||||
@@ -175,6 +179,8 @@ impl AcpThreadView {
|
||||
|
||||
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
|
||||
|
||||
let subscription = cx.observe_global_in::<SettingsStore>(window, Self::settings_changed);
|
||||
|
||||
Self {
|
||||
agent: agent.clone(),
|
||||
workspace: workspace.clone(),
|
||||
@@ -187,7 +193,8 @@ impl AcpThreadView {
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
diff_editors: Default::default(),
|
||||
list_state: list_state,
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
last_error: None,
|
||||
auth_task: None,
|
||||
expanded_tool_calls: HashSet::default(),
|
||||
@@ -196,6 +203,7 @@ impl AcpThreadView {
|
||||
plan_expanded: false,
|
||||
editor_expanded: false,
|
||||
message_history,
|
||||
_subscriptions: [subscription],
|
||||
_cancel_task: None,
|
||||
}
|
||||
}
|
||||
@@ -373,6 +381,11 @@ impl AcpThreadView {
|
||||
editor.display_map.update(cx, |map, cx| {
|
||||
let snapshot = map.snapshot(cx);
|
||||
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
|
||||
// Skip creases that have been edited out of the message buffer.
|
||||
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(project_path) =
|
||||
self.mention_set.lock().path_for_crease_id(crease_id)
|
||||
{
|
||||
@@ -700,15 +713,7 @@ impl AcpThreadView {
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
editor
|
||||
});
|
||||
let entity_id = multibuffer.entity_id();
|
||||
@@ -854,6 +859,7 @@ impl AcpThreadView {
|
||||
.into_any()
|
||||
}
|
||||
AgentThreadEntry::ToolCall(tool_call) => div()
|
||||
.w_full()
|
||||
.py_1p5()
|
||||
.px_5()
|
||||
.child(self.render_tool_call(index, tool_call, window, cx))
|
||||
@@ -866,6 +872,7 @@ impl AcpThreadView {
|
||||
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
|
||||
if index == total_entries - 1 && !is_generating {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.child(primary)
|
||||
.child(self.render_thread_controls(cx))
|
||||
.into_any_element()
|
||||
@@ -898,6 +905,7 @@ impl AcpThreadView {
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
let header_id = SharedString::from(format!("thinking-block-header-{}", entry_ix));
|
||||
let card_header_id = SharedString::from("inner-card-header");
|
||||
let key = (entry_ix, chunk_ix);
|
||||
let is_open = self.expanded_thinking_blocks.contains(&key);
|
||||
|
||||
@@ -905,41 +913,53 @@ impl AcpThreadView {
|
||||
.child(
|
||||
h_flex()
|
||||
.id(header_id)
|
||||
.group("disclosure-header")
|
||||
.group(&card_header_id)
|
||||
.relative()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.gap_1p5()
|
||||
.opacity(0.8)
|
||||
.hover(|style| style.opacity(1.))
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(IconName::ToolBulb)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.size_4()
|
||||
.justify_center()
|
||||
.child(
|
||||
div()
|
||||
.text_size(self.tool_name_font_size())
|
||||
.child("Thinking"),
|
||||
.group_hover(&card_header_id, |s| s.invisible().w_0())
|
||||
.child(
|
||||
Icon::new(IconName::ToolThink)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.invisible()
|
||||
.justify_center()
|
||||
.group_hover(&card_header_id, |s| s.visible())
|
||||
.child(
|
||||
Disclosure::new(("expand", entry_ix), is_open)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronRight)
|
||||
.on_click(cx.listener({
|
||||
move |this, _event, _window, cx| {
|
||||
if is_open {
|
||||
this.expanded_thinking_blocks.remove(&key);
|
||||
} else {
|
||||
this.expanded_thinking_blocks.insert(key);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div().visible_on_hover("disclosure-header").child(
|
||||
Disclosure::new("thinking-disclosure", is_open)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener({
|
||||
move |this, _event, _window, cx| {
|
||||
if is_open {
|
||||
this.expanded_thinking_blocks.remove(&key);
|
||||
} else {
|
||||
this.expanded_thinking_blocks.insert(key);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
),
|
||||
div()
|
||||
.text_size(self.tool_name_font_size())
|
||||
.child("Thinking"),
|
||||
)
|
||||
.on_click(cx.listener({
|
||||
move |this, _event, _window, cx| {
|
||||
@@ -970,6 +990,67 @@ impl AcpThreadView {
|
||||
.into_any_element()
|
||||
}
|
||||
|
||||
fn render_tool_call_icon(
|
||||
&self,
|
||||
group_name: SharedString,
|
||||
entry_ix: usize,
|
||||
is_collapsible: bool,
|
||||
is_open: bool,
|
||||
tool_call: &ToolCall,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
let tool_icon = Icon::new(match tool_call.kind {
|
||||
acp::ToolKind::Read => IconName::ToolRead,
|
||||
acp::ToolKind::Edit => IconName::ToolPencil,
|
||||
acp::ToolKind::Delete => IconName::ToolDeleteFile,
|
||||
acp::ToolKind::Move => IconName::ArrowRightLeft,
|
||||
acp::ToolKind::Search => IconName::ToolSearch,
|
||||
acp::ToolKind::Execute => IconName::ToolTerminal,
|
||||
acp::ToolKind::Think => IconName::ToolThink,
|
||||
acp::ToolKind::Fetch => IconName::ToolWeb,
|
||||
acp::ToolKind::Other => IconName::ToolHammer,
|
||||
})
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted);
|
||||
|
||||
if is_collapsible {
|
||||
h_flex()
|
||||
.size_4()
|
||||
.justify_center()
|
||||
.child(
|
||||
div()
|
||||
.group_hover(&group_name, |s| s.invisible().w_0())
|
||||
.child(tool_icon),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.invisible()
|
||||
.justify_center()
|
||||
.group_hover(&group_name, |s| s.visible())
|
||||
.child(
|
||||
Disclosure::new(("expand", entry_ix), is_open)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronRight)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
div().child(tool_icon)
|
||||
}
|
||||
}
|
||||
|
||||
fn render_tool_call(
|
||||
&self,
|
||||
entry_ix: usize,
|
||||
@@ -977,7 +1058,8 @@ impl AcpThreadView {
|
||||
window: &Window,
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
let header_id = SharedString::from(format!("tool-call-header-{}", entry_ix));
|
||||
let header_id = SharedString::from(format!("outer-tool-call-header-{}", entry_ix));
|
||||
let card_header_id = SharedString::from("inner-tool-call-header");
|
||||
|
||||
let status_icon = match &tool_call.status {
|
||||
ToolCallStatus::Allowed {
|
||||
@@ -1026,6 +1108,21 @@ impl AcpThreadView {
|
||||
let is_collapsible = !tool_call.content.is_empty() && !needs_confirmation;
|
||||
let is_open = !is_collapsible || self.expanded_tool_calls.contains(&tool_call.id);
|
||||
|
||||
let gradient_color = cx.theme().colors().panel_background;
|
||||
let gradient_overlay = {
|
||||
div()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right_0()
|
||||
.w_12()
|
||||
.h_full()
|
||||
.bg(linear_gradient(
|
||||
90.,
|
||||
linear_color_stop(gradient_color, 1.),
|
||||
linear_color_stop(gradient_color.opacity(0.2), 0.),
|
||||
))
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.when(needs_confirmation, |this| {
|
||||
this.rounded_lg()
|
||||
@@ -1042,43 +1139,38 @@ impl AcpThreadView {
|
||||
.justify_between()
|
||||
.map(|this| {
|
||||
if needs_confirmation {
|
||||
this.px_2()
|
||||
this.pl_2()
|
||||
.pr_1()
|
||||
.py_1()
|
||||
.rounded_t_md()
|
||||
.bg(self.tool_card_header_bg(cx))
|
||||
.border_b_1()
|
||||
.border_color(self.tool_card_border_color(cx))
|
||||
.bg(self.tool_card_header_bg(cx))
|
||||
} else {
|
||||
this.opacity(0.8).hover(|style| style.opacity(1.))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.id("tool-call-header")
|
||||
.overflow_x_scroll()
|
||||
.group(&card_header_id)
|
||||
.relative()
|
||||
.w_full()
|
||||
.map(|this| {
|
||||
if needs_confirmation {
|
||||
this.text_xs()
|
||||
if tool_call.locations.len() == 1 {
|
||||
this.gap_0()
|
||||
} else {
|
||||
this.text_size(self.tool_name_font_size())
|
||||
this.gap_1p5()
|
||||
}
|
||||
})
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(match tool_call.kind {
|
||||
acp::ToolKind::Read => IconName::ToolRead,
|
||||
acp::ToolKind::Edit => IconName::ToolPencil,
|
||||
acp::ToolKind::Delete => IconName::ToolDeleteFile,
|
||||
acp::ToolKind::Move => IconName::ArrowRightLeft,
|
||||
acp::ToolKind::Search => IconName::ToolSearch,
|
||||
acp::ToolKind::Execute => IconName::ToolTerminal,
|
||||
acp::ToolKind::Think => IconName::ToolBulb,
|
||||
acp::ToolKind::Fetch => IconName::ToolWeb,
|
||||
acp::ToolKind::Other => IconName::ToolHammer,
|
||||
})
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.text_size(self.tool_name_font_size())
|
||||
.child(self.render_tool_call_icon(
|
||||
card_header_id,
|
||||
entry_ix,
|
||||
is_collapsible,
|
||||
is_open,
|
||||
tool_call,
|
||||
cx,
|
||||
))
|
||||
.child(if tool_call.locations.len() == 1 {
|
||||
let name = tool_call.locations[0]
|
||||
.path
|
||||
@@ -1089,13 +1181,11 @@ impl AcpThreadView {
|
||||
|
||||
h_flex()
|
||||
.id(("open-tool-call-location", entry_ix))
|
||||
.child(name)
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.pr_1()
|
||||
.gap_0p5()
|
||||
.cursor_pointer()
|
||||
.px_1p5()
|
||||
.rounded_sm()
|
||||
.overflow_x_scroll()
|
||||
.opacity(0.8)
|
||||
.hover(|label| {
|
||||
label.opacity(1.).bg(cx
|
||||
@@ -1104,53 +1194,49 @@ impl AcpThreadView {
|
||||
.element_hover
|
||||
.opacity(0.5))
|
||||
})
|
||||
.child(name)
|
||||
.tooltip(Tooltip::text("Jump to File"))
|
||||
.on_click(cx.listener(move |this, _, window, cx| {
|
||||
this.open_tool_call_location(entry_ix, 0, window, cx);
|
||||
}))
|
||||
.into_any_element()
|
||||
} else {
|
||||
self.render_markdown(
|
||||
tool_call.label.clone(),
|
||||
default_markdown_style(needs_confirmation, window, cx),
|
||||
)
|
||||
.into_any()
|
||||
h_flex()
|
||||
.id("non-card-label-container")
|
||||
.w_full()
|
||||
.relative()
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
h_flex()
|
||||
.id("non-card-label")
|
||||
.pr_8()
|
||||
.w_full()
|
||||
.overflow_x_scroll()
|
||||
.child(self.render_markdown(
|
||||
tool_call.label.clone(),
|
||||
default_markdown_style(
|
||||
needs_confirmation,
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
)),
|
||||
)
|
||||
.child(gradient_overlay)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
}))
|
||||
.into_any()
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.when(is_collapsible, |this| {
|
||||
this.child(
|
||||
Disclosure::new(("expand", entry_ix), is_open)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
)
|
||||
})
|
||||
.children(status_icon),
|
||||
)
|
||||
.on_click(cx.listener({
|
||||
let id = tool_call.id.clone();
|
||||
move |this: &mut Self, _, _, cx: &mut Context<Self>| {
|
||||
if is_open {
|
||||
this.expanded_tool_calls.remove(&id);
|
||||
} else {
|
||||
this.expanded_tool_calls.insert(id.clone());
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
.children(status_icon),
|
||||
)
|
||||
.when(is_open, |this| {
|
||||
this.child(
|
||||
@@ -1244,8 +1330,7 @@ impl AcpThreadView {
|
||||
cx: &Context<Self>,
|
||||
) -> Div {
|
||||
h_flex()
|
||||
.py_1p5()
|
||||
.px_1p5()
|
||||
.p_1p5()
|
||||
.gap_1()
|
||||
.justify_end()
|
||||
.when(!empty_content, |this| {
|
||||
@@ -1271,6 +1356,7 @@ impl AcpThreadView {
|
||||
})
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let tool_call_id = tool_call_id.clone();
|
||||
let option_id = option.id.clone();
|
||||
@@ -1520,7 +1606,7 @@ impl AcpThreadView {
|
||||
})
|
||||
})
|
||||
.when(!changed_buffers.is_empty(), |this| {
|
||||
this.child(Divider::horizontal())
|
||||
this.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_edits_summary(
|
||||
action_log,
|
||||
&changed_buffers,
|
||||
@@ -1550,6 +1636,7 @@ impl AcpThreadView {
|
||||
{
|
||||
h_flex()
|
||||
.w_full()
|
||||
.cursor_default()
|
||||
.gap_1()
|
||||
.text_xs()
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
@@ -1579,7 +1666,7 @@ impl AcpThreadView {
|
||||
let status_label = if stats.pending == 0 {
|
||||
"All Done".to_string()
|
||||
} else if stats.completed == 0 {
|
||||
format!("{}", plan.entries.len())
|
||||
format!("{} Tasks", plan.entries.len())
|
||||
} else {
|
||||
format!("{}/{}", stats.completed, plan.entries.len())
|
||||
};
|
||||
@@ -1693,7 +1780,6 @@ impl AcpThreadView {
|
||||
.child(
|
||||
h_flex()
|
||||
.id("edits-container")
|
||||
.cursor_pointer()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(Disclosure::new("edits-disclosure", expanded))
|
||||
@@ -2468,6 +2554,7 @@ impl AcpThreadView {
|
||||
}));
|
||||
|
||||
h_flex()
|
||||
.w_full()
|
||||
.mr_1()
|
||||
.pb_2()
|
||||
.px(RESPONSE_PADDING_X)
|
||||
@@ -2478,6 +2565,48 @@ impl AcpThreadView {
|
||||
.child(open_as_markdown)
|
||||
.child(scroll_to_top)
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.id("acp-thread-scrollbar")
|
||||
.occlude()
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
for diff_editor in self.diff_editors.values() {
|
||||
diff_editor.update(cx, |diff_editor, cx| {
|
||||
diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AcpThreadView {
|
||||
@@ -2552,6 +2681,7 @@ impl Render for AcpThreadView {
|
||||
.flex_grow()
|
||||
.into_any(),
|
||||
)
|
||||
.child(self.render_vertical_scrollbar(cx))
|
||||
.children(match thread_clone.read(cx).status() {
|
||||
ThreadStatus::Idle | ThreadStatus::WaitingForToolConfirmation => {
|
||||
None
|
||||
@@ -2754,6 +2884,18 @@ fn plan_label_markdown_style(
|
||||
}
|
||||
}
|
||||
|
||||
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
||||
TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use agent_client_protocol::SessionId;
|
||||
@@ -2761,8 +2903,12 @@ mod tests {
|
||||
use fs::FakeFs;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use lsp::{CompletionContext, CompletionTriggerKind};
|
||||
use project::CompletionIntent;
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -2842,6 +2988,7 @@ mod tests {
|
||||
content: vec!["hi".into()],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
};
|
||||
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
||||
.with_permission_requests(HashMap::from_iter([(
|
||||
@@ -2874,6 +3021,109 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_crease_removal(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({"file": ""})).await;
|
||||
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
|
||||
let agent = StubAgentServer::default();
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
let thread_view = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
AcpThreadView::new(
|
||||
Rc::new(agent),
|
||||
workspace.downgrade(),
|
||||
project,
|
||||
Rc::new(RefCell::new(MessageHistory::default())),
|
||||
1,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
let excerpt_id = message_editor.update(cx, |editor, cx| {
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.excerpt_ids()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
});
|
||||
let completions = message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello @", window, cx);
|
||||
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
|
||||
let completion_provider = editor.completion_provider().unwrap();
|
||||
completion_provider.completions(
|
||||
excerpt_id,
|
||||
&buffer,
|
||||
Anchor::MAX,
|
||||
CompletionContext {
|
||||
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
|
||||
trigger_character: Some("@".into()),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let [_, completion]: [_; 2] = completions
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.flat_map(|response| response.completions)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let start = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
|
||||
.unwrap();
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
|
||||
.unwrap();
|
||||
editor.edit([(start..end, completion.new_text)], cx);
|
||||
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Backspace over the inserted crease (and the following space).
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
});
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let content = thread_view.update_in(cx, |thread_view, _window, _cx| {
|
||||
thread_view
|
||||
.message_history
|
||||
.borrow()
|
||||
.items()
|
||||
.iter()
|
||||
.flatten()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
// We don't send a resource link for the deleted crease.
|
||||
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
|
||||
}
|
||||
|
||||
async fn setup_thread_view(
|
||||
agent: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
@@ -3027,7 +3277,7 @@ mod tests {
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
thread.request_tool_call_authorization(
|
||||
tool_call.clone(),
|
||||
options.clone(),
|
||||
cx,
|
||||
|
||||
@@ -69,8 +69,6 @@ pub struct ActiveThread {
|
||||
messages: Vec<MessageId>,
|
||||
list_state: ListState,
|
||||
scrollbar_state: ScrollbarState,
|
||||
show_scrollbar: bool,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
rendered_messages_by_id: HashMap<MessageId, RenderedMessage>,
|
||||
rendered_tool_uses: HashMap<LanguageModelToolUseId, RenderedToolUse>,
|
||||
editing_message: Option<(MessageId, EditingMessageState)>,
|
||||
@@ -805,9 +803,7 @@ impl ActiveThread {
|
||||
expanded_thinking_segments: HashMap::default(),
|
||||
expanded_code_blocks: HashMap::default(),
|
||||
list_state: list_state.clone(),
|
||||
scrollbar_state: ScrollbarState::new(list_state),
|
||||
show_scrollbar: false,
|
||||
hide_scrollbar_task: None,
|
||||
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
|
||||
editing_message: None,
|
||||
last_error: None,
|
||||
copied_code_block_ids: HashSet::default(),
|
||||
@@ -2628,7 +2624,7 @@ impl ActiveThread {
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(IconName::ToolBulb)
|
||||
Icon::new(IconName::ToolThink)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
@@ -3502,60 +3498,37 @@ impl ActiveThread {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !self.show_scrollbar && !self.scrollbar_state.is_dragging() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("active-thread-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("active-thread-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
fn hide_scrollbar_later(&mut self, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn(async move |thread, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
if !thread.scrollbar_state.is_dragging() {
|
||||
thread.show_scrollbar = false;
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
|
||||
@@ -3596,26 +3569,8 @@ impl Render for ActiveThread {
|
||||
.size_full()
|
||||
.relative()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.on_mouse_move(cx.listener(|this, _, _, cx| {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_later(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.on_scroll_wheel(cx.listener(|this, _, _, cx| {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_later(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|this, _, _, cx| {
|
||||
this.hide_scrollbar_later(cx);
|
||||
}),
|
||||
)
|
||||
.child(list(self.list_state.clone(), cx.processor(Self::render_message)).flex_grow())
|
||||
.when_some(self.render_vertical_scrollbar(cx), |this, scrollbar| {
|
||||
this.child(scrollbar)
|
||||
})
|
||||
.child(self.render_vertical_scrollbar(cx))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2343,6 +2343,16 @@ impl AgentPanel {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_backdrop(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.size_full()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.opacity(0.8)
|
||||
.block_mouse_except_scroll()
|
||||
}
|
||||
|
||||
fn render_trial_end_upsell(
|
||||
&self,
|
||||
_window: &mut Window,
|
||||
@@ -2352,15 +2362,24 @@ impl AgentPanel {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(EndTrialUpsell::new(Arc::new({
|
||||
let this = cx.entity();
|
||||
move |_, cx| {
|
||||
this.update(cx, |_this, cx| {
|
||||
TrialEndUpsell::set_dismissed(true, cx);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})))
|
||||
Some(
|
||||
v_flex()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.size_full()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.opacity(0.85)
|
||||
.block_mouse_except_scroll()
|
||||
.child(EndTrialUpsell::new(Arc::new({
|
||||
let this = cx.entity();
|
||||
move |_, cx| {
|
||||
this.update(cx, |_this, cx| {
|
||||
TrialEndUpsell::set_dismissed(true, cx);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}))),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_empty_state_section_header(
|
||||
@@ -3210,9 +3229,10 @@ impl Render for AgentPanel {
|
||||
// - Scrolling in all views works as expected
|
||||
// - Files can be dropped into the panel
|
||||
let content = v_flex()
|
||||
.key_context(self.key_context())
|
||||
.justify_between()
|
||||
.relative()
|
||||
.size_full()
|
||||
.justify_between()
|
||||
.key_context(self.key_context())
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.on_action(cx.listener(|this, action: &NewThread, window, cx| {
|
||||
this.new_thread(action, window, cx);
|
||||
@@ -3255,14 +3275,12 @@ impl Render for AgentPanel {
|
||||
.on_action(cx.listener(Self::toggle_burn_mode))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
.children(self.render_onboarding(window, cx))
|
||||
.children(self.render_trial_end_upsell(window, cx))
|
||||
.map(|parent| match &self.active_view {
|
||||
ActiveView::Thread {
|
||||
thread,
|
||||
message_editor,
|
||||
..
|
||||
} => parent
|
||||
.relative()
|
||||
.child(
|
||||
if thread.read(cx).is_empty() && !self.should_render_onboarding(cx) {
|
||||
self.render_thread_empty_state(window, cx)
|
||||
@@ -3299,21 +3317,10 @@ impl Render for AgentPanel {
|
||||
})
|
||||
.child(h_flex().relative().child(message_editor.clone()).when(
|
||||
!LanguageModelRegistry::read_global(cx).has_authenticated_provider(cx),
|
||||
|this| {
|
||||
this.child(
|
||||
div()
|
||||
.size_full()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.opacity(0.8)
|
||||
.block_mouse_except_scroll(),
|
||||
)
|
||||
},
|
||||
|this| this.child(self.render_backdrop(cx)),
|
||||
))
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => parent
|
||||
.relative()
|
||||
.child(thread_view.clone())
|
||||
.child(self.render_drag_target(cx)),
|
||||
ActiveView::History => parent.child(self.history.clone()),
|
||||
@@ -3352,7 +3359,8 @@ impl Render for AgentPanel {
|
||||
))
|
||||
}
|
||||
ActiveView::Configuration => parent.children(self.configuration.clone()),
|
||||
});
|
||||
})
|
||||
.children(self.render_trial_end_upsell(window, cx));
|
||||
|
||||
match self.active_view.which_font_size_used() {
|
||||
WhichFontSize::AgentFont => {
|
||||
|
||||
@@ -1,26 +1,33 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use client::{Client, zed_urls};
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, AnyElement, App, IntoElement, RenderOnce, Transformation, Window,
|
||||
percentage,
|
||||
Animation, AnimationExt, AnyElement, App, Entity, IntoElement, RenderOnce, Transformation,
|
||||
Window, percentage,
|
||||
};
|
||||
use ui::{Divider, Vector, VectorName, prelude::*};
|
||||
|
||||
use crate::{SignInStatus, plan_definitions::PlanDefinitions};
|
||||
use crate::{SignInStatus, YoungAccountBanner, plan_definitions::PlanDefinitions};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct AiUpsellCard {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
pub account_too_young: bool,
|
||||
pub user_plan: Option<Plan>,
|
||||
pub tab_index: Option<isize>,
|
||||
}
|
||||
|
||||
impl AiUpsellCard {
|
||||
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
|
||||
pub fn new(
|
||||
client: Arc<Client>,
|
||||
user_store: &Entity<UserStore>,
|
||||
user_plan: Option<Plan>,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let status = *client.status().borrow();
|
||||
let store = user_store.read(cx);
|
||||
|
||||
Self {
|
||||
user_plan,
|
||||
@@ -32,6 +39,7 @@ impl AiUpsellCard {
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
account_too_young: store.account_too_young(),
|
||||
tab_index: None,
|
||||
}
|
||||
}
|
||||
@@ -40,6 +48,7 @@ impl AiUpsellCard {
|
||||
impl RenderOnce for AiUpsellCard {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let plan_definitions = PlanDefinitions;
|
||||
let young_account_banner = YoungAccountBanner;
|
||||
|
||||
let pro_section = v_flex()
|
||||
.flex_grow()
|
||||
@@ -128,7 +137,7 @@ impl RenderOnce for AiUpsellCard {
|
||||
.size(rems_from_px(72.))
|
||||
.child(
|
||||
Vector::new(
|
||||
VectorName::CertifiedUserStamp,
|
||||
VectorName::ProUserStamp,
|
||||
rems_from_px(72.),
|
||||
rems_from_px(72.),
|
||||
)
|
||||
@@ -158,36 +167,70 @@ impl RenderOnce for AiUpsellCard {
|
||||
SignInStatus::SignedIn => match self.user_plan {
|
||||
None | Some(Plan::ZedFree) => card
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(description).color(Color::Muted)),
|
||||
)
|
||||
.child(plans_section)
|
||||
.child(
|
||||
footer_container
|
||||
.child(
|
||||
Button::new("start_trial", "Start 14-day Free Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.when_some(self.tab_index, |this, tab_index| {
|
||||
this.tab_index(tab_index)
|
||||
})
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Start Trial Clicked",
|
||||
state = "post-sign-in"
|
||||
);
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
.map(|this| {
|
||||
if self.account_too_young {
|
||||
this.child(young_account_banner).child(
|
||||
v_flex()
|
||||
.mt_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Label::new("Pro")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Accent)
|
||||
.buffer_font(cx),
|
||||
)
|
||||
.child(Divider::horizontal()),
|
||||
)
|
||||
.child(plan_definitions.pro_plan(true))
|
||||
.child(
|
||||
Button::new("pro", "Get Started")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Upgrade To Pro Clicked",
|
||||
state = "young-account"
|
||||
);
|
||||
cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
|
||||
}),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(description).color(Color::Muted)),
|
||||
)
|
||||
.child(plans_section)
|
||||
.child(
|
||||
Label::new("No credit card required")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
),
|
||||
footer_container
|
||||
.child(
|
||||
Button::new("start_trial", "Start 14-day Free Pro Trial")
|
||||
.full_width()
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
|
||||
.when_some(self.tab_index, |this, tab_index| {
|
||||
this.tab_index(tab_index)
|
||||
})
|
||||
.on_click(move |_, _window, cx| {
|
||||
telemetry::event!(
|
||||
"Start Trial Clicked",
|
||||
state = "post-sign-in"
|
||||
);
|
||||
cx.open_url(&zed_urls::start_trial_url(cx))
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Label::new("No credit card required")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
}
|
||||
}),
|
||||
Some(Plan::ZedProTrial) => card
|
||||
.child(pro_trial_stamp)
|
||||
.child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large))
|
||||
@@ -255,48 +298,68 @@ impl Component for AiUpsellCard {
|
||||
Some(
|
||||
v_flex()
|
||||
.gap_4()
|
||||
.children(vec![example_group(vec![
|
||||
single_example(
|
||||
"Signed Out State",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedOut,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
tab_index: Some(0),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: Some(Plan::ZedFree),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Trial",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: Some(Plan::ZedProTrial),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: Some(Plan::ZedPro),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
])])
|
||||
.items_center()
|
||||
.max_w_4_5()
|
||||
.child(single_example(
|
||||
"Signed Out State",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedOut,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
account_too_young: false,
|
||||
user_plan: None,
|
||||
tab_index: Some(0),
|
||||
}
|
||||
.into_any_element(),
|
||||
))
|
||||
.child(example_group_with_title(
|
||||
"Signed In States",
|
||||
vec![
|
||||
single_example(
|
||||
"Free Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
account_too_young: false,
|
||||
user_plan: Some(Plan::ZedFree),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free Plan but Young Account",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
account_too_young: true,
|
||||
user_plan: Some(Plan::ZedFree),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Trial",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
account_too_young: false,
|
||||
user_plan: Some(Plan::ZedProTrial),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Pro Plan",
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
account_too_young: false,
|
||||
user_plan: Some(Plan::ZedPro),
|
||||
tab_index: Some(1),
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
],
|
||||
))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -36,13 +36,12 @@ use crate::delete_path_tool::DeletePathTool;
|
||||
use crate::diagnostics_tool::DiagnosticsTool;
|
||||
use crate::edit_file_tool::EditFileTool;
|
||||
use crate::fetch_tool::FetchTool;
|
||||
use crate::find_path_tool::FindPathTool;
|
||||
use crate::list_directory_tool::ListDirectoryTool;
|
||||
use crate::now_tool::NowTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
|
||||
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use find_path_tool::*;
|
||||
pub use grep_tool::{GrepTool, GrepToolInput};
|
||||
pub use open_tool::OpenTool;
|
||||
pub use project_notifications_tool::ProjectNotificationsTool;
|
||||
|
||||
@@ -37,7 +37,7 @@ impl Tool for ThinkingTool {
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::ToolBulb
|
||||
IconName::ToolThink
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
|
||||
@@ -134,10 +134,15 @@ impl Settings for AutoUpdateSetting {
|
||||
type FileContent = Option<AutoUpdateSettingContent>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let auto_update = [sources.server, sources.release_channel, sources.user]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
let auto_update = [
|
||||
sources.server,
|
||||
sources.release_channel,
|
||||
sources.operating_system,
|
||||
sources.user,
|
||||
]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
|
||||
Ok(Self(auto_update.0))
|
||||
}
|
||||
|
||||
@@ -226,35 +226,17 @@ impl UserStore {
|
||||
match status {
|
||||
Status::Authenticated | Status::Connected { .. } => {
|
||||
if let Some(user_id) = client.user_id() {
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
let current_user_and_response = if let Some(response) = response {
|
||||
let user = Arc::new(User {
|
||||
id: user_id,
|
||||
github_login: response.user.github_login.clone().into(),
|
||||
avatar_uri: response.user.avatar_url.clone().into(),
|
||||
name: response.user.name.clone(),
|
||||
});
|
||||
|
||||
Some((user, response))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
current_user_tx
|
||||
.send(
|
||||
current_user_and_response
|
||||
.as_ref()
|
||||
.map(|(user, _)| user.clone()),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
let response = client.cloud_client().get_authenticated_user().await;
|
||||
let mut current_user = None;
|
||||
cx.update(|cx| {
|
||||
if let Some((user, response)) = current_user_and_response {
|
||||
if let Some(response) = response.log_err() {
|
||||
let user = Arc::new(User {
|
||||
id: user_id,
|
||||
github_login: response.user.github_login.clone().into(),
|
||||
avatar_uri: response.user.avatar_url.clone().into(),
|
||||
name: response.user.name.clone(),
|
||||
});
|
||||
current_user = Some(user.clone());
|
||||
this.update(cx, |this, cx| {
|
||||
this.by_github_login
|
||||
.insert(user.github_login.clone(), user_id);
|
||||
@@ -265,6 +247,7 @@ impl UserStore {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
current_user_tx.send(current_user).await.ok();
|
||||
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
|
||||
@@ -2,5 +2,6 @@ ZED_ENVIRONMENT=production
|
||||
RUST_LOG=info
|
||||
INVITE_LINK_PREFIX=https://zed.dev/invites/
|
||||
AUTO_JOIN_CHANNEL_ID=283
|
||||
DATABASE_MAX_CONNECTIONS=250
|
||||
# Set DATABASE_MAX_CONNECTIONS max connections in the `deploy_collab.yml`:
|
||||
# https://github.com/zed-industries/zed/blob/main/.github/workflows/deploy_collab.yml
|
||||
LLM_DATABASE_MAX_CONNECTIONS=25
|
||||
|
||||
@@ -1,477 +1,10 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use chrono::{DateTime, Utc};
|
||||
use sea_orm::ActiveValue;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
|
||||
use util::ResultExt;
|
||||
use std::sync::Arc;
|
||||
use stripe::SubscriptionStatus;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
use crate::db::{
|
||||
CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, billing_customer,
|
||||
};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::stripe_client::{
|
||||
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
|
||||
StripeSubscriptionId,
|
||||
};
|
||||
|
||||
/// The amount of time we wait in between each poll of Stripe events.
|
||||
///
|
||||
/// This value should strike a balance between:
|
||||
/// 1. Being short enough that we update quickly when something in Stripe changes
|
||||
/// 2. Being long enough that we don't eat into our rate limits.
|
||||
///
|
||||
/// As a point of reference, the Sequin folks say they have this at **500ms**:
|
||||
///
|
||||
/// > We poll the Stripe /events endpoint every 500ms per account
|
||||
/// >
|
||||
/// > — https://blog.sequinstream.com/events-not-webhooks/
|
||||
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
/// The maximum number of events to return per page.
|
||||
///
|
||||
/// We set this to 100 (the max) so we have to make fewer requests to Stripe.
|
||||
///
|
||||
/// > Limit can range between 1 and 100, and the default is 10.
|
||||
const EVENTS_LIMIT_PER_PAGE: u64 = 100;
|
||||
|
||||
/// The number of pages consisting entirely of already-processed events that we
|
||||
/// will see before we stop retrieving events.
|
||||
///
|
||||
/// This is used to prevent over-fetching the Stripe events API for events we've
|
||||
/// already seen and processed.
|
||||
const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
||||
|
||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
executor.sleep(POLL_EVENTS_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
real_stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_events_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
|
||||
if sync_events_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||
// so we need to unquote it.
|
||||
event_type.to_string().trim_matches('"').to_string()
|
||||
}
|
||||
|
||||
let event_types = [
|
||||
EventType::CustomerCreated,
|
||||
EventType::CustomerUpdated,
|
||||
EventType::CustomerSubscriptionCreated,
|
||||
EventType::CustomerSubscriptionUpdated,
|
||||
EventType::CustomerSubscriptionPaused,
|
||||
EventType::CustomerSubscriptionResumed,
|
||||
EventType::CustomerSubscriptionDeleted,
|
||||
]
|
||||
.into_iter()
|
||||
.map(event_type_to_string)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut pages_of_already_processed_events = 0;
|
||||
let mut unprocessed_events = Vec::new();
|
||||
|
||||
log::info!(
|
||||
"Stripe events: starting retrieval for {}",
|
||||
event_types.join(", ")
|
||||
);
|
||||
let mut params = ListEvents::new();
|
||||
params.types = Some(event_types.clone());
|
||||
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
||||
|
||||
let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
|
||||
.await?
|
||||
.paginate(params);
|
||||
|
||||
loop {
|
||||
let processed_event_ids = {
|
||||
let event_ids = event_pages
|
||||
.page
|
||||
.data
|
||||
.iter()
|
||||
.map(|event| event.id.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
app.db
|
||||
.get_processed_stripe_events_by_event_ids(&event_ids)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|event| event.stripe_event_id)
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let mut processed_events_in_page = 0;
|
||||
let events_in_page = event_pages.page.data.len();
|
||||
for event in &event_pages.page.data {
|
||||
if processed_event_ids.contains(&event.id.to_string()) {
|
||||
processed_events_in_page += 1;
|
||||
log::debug!("Stripe events: already processed '{}', skipping", event.id);
|
||||
} else {
|
||||
unprocessed_events.push(event.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if processed_events_in_page == events_in_page {
|
||||
pages_of_already_processed_events += 1;
|
||||
}
|
||||
|
||||
if event_pages.page.has_more {
|
||||
if pages_of_already_processed_events >= NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP
|
||||
{
|
||||
log::info!(
|
||||
"Stripe events: stopping, saw {pages_of_already_processed_events} pages of already-processed events"
|
||||
);
|
||||
break;
|
||||
} else {
|
||||
log::info!("Stripe events: retrieving next page");
|
||||
event_pages = event_pages.next(&real_stripe_client).await?;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Stripe events: unprocessed {}", unprocessed_events.len());
|
||||
|
||||
// Sort all of the unprocessed events in ascending order, so we can handle them in the order they occurred.
|
||||
unprocessed_events.sort_by(|a, b| a.created.cmp(&b.created).then_with(|| a.id.cmp(&b.id)));
|
||||
|
||||
for event in unprocessed_events {
|
||||
let event_id = event.id.clone();
|
||||
let processed_event_params = CreateProcessedStripeEventParams {
|
||||
stripe_event_id: event.id.to_string(),
|
||||
stripe_event_type: event_type_to_string(event.type_),
|
||||
stripe_event_created_timestamp: event.created,
|
||||
};
|
||||
|
||||
// If the event has happened too far in the past, we don't want to
|
||||
// process it and risk overwriting other more-recent updates.
|
||||
//
|
||||
// 1 day was chosen arbitrarily. This could be made longer or shorter.
|
||||
let one_day = Duration::from_secs(24 * 60 * 60);
|
||||
let a_day_ago = Utc::now() - one_day;
|
||||
if a_day_ago.timestamp() > event.created {
|
||||
log::info!(
|
||||
"Stripe events: event '{}' is more than {one_day:?} old, marking as processed",
|
||||
event_id
|
||||
);
|
||||
app.db
|
||||
.create_processed_stripe_event(&processed_event_params)
|
||||
.await?;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
let process_result = match event.type_ {
|
||||
EventType::CustomerCreated | EventType::CustomerUpdated => {
|
||||
handle_customer_event(app, real_stripe_client, event).await
|
||||
}
|
||||
EventType::CustomerSubscriptionCreated
|
||||
| EventType::CustomerSubscriptionUpdated
|
||||
| EventType::CustomerSubscriptionPaused
|
||||
| EventType::CustomerSubscriptionResumed
|
||||
| EventType::CustomerSubscriptionDeleted => {
|
||||
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
|
||||
}
|
||||
_ => Ok(()),
|
||||
};
|
||||
|
||||
if let Some(()) = process_result
|
||||
.with_context(|| format!("failed to process event {event_id} successfully"))
|
||||
.log_err()
|
||||
{
|
||||
app.db
|
||||
.create_processed_stripe_event(&processed_event_params)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_customer_event(
|
||||
app: &Arc<AppState>,
|
||||
_stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
let EventObject::Customer(customer) = event.data.object else {
|
||||
bail!("unexpected event payload for {}", event.id);
|
||||
};
|
||||
|
||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||
|
||||
let Some(email) = customer.email else {
|
||||
log::info!("Stripe customer has no email: skipping");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Some(user) = app.db.get_user_by_email(&email).await? else {
|
||||
log::info!("no user found for email: skipping");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let Some(existing_customer) = app
|
||||
.db
|
||||
.get_billing_customer_by_stripe_customer_id(&customer.id)
|
||||
.await?
|
||||
{
|
||||
app.db
|
||||
.update_billing_customer(
|
||||
existing_customer.id,
|
||||
&UpdateBillingCustomerParams {
|
||||
// For now we just leave the information as-is, as it is not
|
||||
// likely to change.
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
app.db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id: user.id,
|
||||
stripe_customer_id: customer.id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sync_subscription(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
subscription: StripeSubscription,
|
||||
) -> anyhow::Result<billing_customer::Model> {
|
||||
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
|
||||
stripe_billing
|
||||
.determine_subscription_kind(&subscription)
|
||||
.await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let billing_customer =
|
||||
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
|
||||
.await?
|
||||
.context("billing customer not found")?;
|
||||
|
||||
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
|
||||
if subscription.status == SubscriptionStatus::Trialing {
|
||||
let current_period_start =
|
||||
DateTime::from_timestamp(subscription.current_period_start, 0)
|
||||
.context("No trial subscription period start")?;
|
||||
|
||||
app.db
|
||||
.update_billing_customer(
|
||||
billing_customer.id,
|
||||
&UpdateBillingCustomerParams {
|
||||
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
|
||||
&& subscription
|
||||
.cancellation_details
|
||||
.as_ref()
|
||||
.and_then(|details| details.reason)
|
||||
.map_or(false, |reason| {
|
||||
reason == StripeCancellationDetailsReason::PaymentFailed
|
||||
});
|
||||
|
||||
if was_canceled_due_to_payment_failure {
|
||||
app.db
|
||||
.update_billing_customer(
|
||||
billing_customer.id,
|
||||
&UpdateBillingCustomerParams {
|
||||
has_overdue_invoices: ActiveValue::set(true),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(existing_subscription) = app
|
||||
.db
|
||||
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
|
||||
.await?
|
||||
{
|
||||
app.db
|
||||
.update_billing_subscription(
|
||||
existing_subscription.id,
|
||||
&UpdateBillingSubscriptionParams {
|
||||
billing_customer_id: ActiveValue::set(billing_customer.id),
|
||||
kind: ActiveValue::set(subscription_kind),
|
||||
stripe_subscription_id: ActiveValue::set(subscription.id.to_string()),
|
||||
stripe_subscription_status: ActiveValue::set(subscription.status.into()),
|
||||
stripe_cancel_at: ActiveValue::set(
|
||||
subscription
|
||||
.cancel_at
|
||||
.and_then(|cancel_at| DateTime::from_timestamp(cancel_at, 0))
|
||||
.map(|time| time.naive_utc()),
|
||||
),
|
||||
stripe_cancellation_reason: ActiveValue::set(
|
||||
subscription
|
||||
.cancellation_details
|
||||
.and_then(|details| details.reason)
|
||||
.map(|reason| reason.into()),
|
||||
),
|
||||
stripe_current_period_start: ActiveValue::set(Some(
|
||||
subscription.current_period_start,
|
||||
)),
|
||||
stripe_current_period_end: ActiveValue::set(Some(
|
||||
subscription.current_period_end,
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
if let Some(existing_subscription) = app
|
||||
.db
|
||||
.get_active_billing_subscription(billing_customer.user_id)
|
||||
.await?
|
||||
{
|
||||
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
|
||||
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
|
||||
{
|
||||
let stripe_subscription_id = StripeSubscriptionId(
|
||||
existing_subscription.stripe_subscription_id.clone().into(),
|
||||
);
|
||||
|
||||
stripe_client
|
||||
.cancel_subscription(&stripe_subscription_id)
|
||||
.await?;
|
||||
} else {
|
||||
// If the user already has an active billing subscription, ignore the
|
||||
// event and return an `Ok` to signal that it was processed
|
||||
// successfully.
|
||||
//
|
||||
// There is the possibility that this could cause us to not create a
|
||||
// subscription in the following scenario:
|
||||
//
|
||||
// 1. User has an active subscription A
|
||||
// 2. User cancels subscription A
|
||||
// 3. User creates a new subscription B
|
||||
// 4. We process the new subscription B before the cancellation of subscription A
|
||||
// 5. User ends up with no subscriptions
|
||||
//
|
||||
// In theory this situation shouldn't arise as we try to process the events in the order they occur.
|
||||
|
||||
log::info!(
|
||||
"user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
|
||||
user_id = billing_customer.user_id,
|
||||
subscription_id = subscription.id
|
||||
);
|
||||
return Ok(billing_customer);
|
||||
}
|
||||
}
|
||||
|
||||
app.db
|
||||
.create_billing_subscription(&CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
kind: subscription_kind,
|
||||
stripe_subscription_id: subscription.id.to_string(),
|
||||
stripe_subscription_status: subscription.status.into(),
|
||||
stripe_cancellation_reason: subscription
|
||||
.cancellation_details
|
||||
.and_then(|details| details.reason)
|
||||
.map(|reason| reason.into()),
|
||||
stripe_current_period_start: Some(subscription.current_period_start),
|
||||
stripe_current_period_end: Some(subscription.current_period_end),
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(stripe_billing) = app.stripe_billing.as_ref() {
|
||||
if subscription.status == SubscriptionStatus::Canceled
|
||||
|| subscription.status == SubscriptionStatus::Paused
|
||||
{
|
||||
let already_has_active_billing_subscription = app
|
||||
.db
|
||||
.has_active_billing_subscription(billing_customer.user_id)
|
||||
.await?;
|
||||
if !already_has_active_billing_subscription {
|
||||
let stripe_customer_id =
|
||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(billing_customer)
|
||||
}
|
||||
|
||||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
let EventObject::Subscription(subscription) = event.data.object else {
|
||||
bail!("unexpected event payload for {}", event.id);
|
||||
};
|
||||
|
||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||
|
||||
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
|
||||
|
||||
// When the user's subscription changes, push down any changes to their plan.
|
||||
rpc_server
|
||||
.update_plan_for_user_legacy(billing_customer.user_id)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||
// to either grant/revoke access.
|
||||
rpc_server
|
||||
.refresh_llm_tokens_for_user(billing_customer.user_id)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::{CreateBillingCustomerParams, billing_customer};
|
||||
use crate::stripe_client::{StripeClient, StripeCustomerId};
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
@@ -488,16 +21,6 @@ impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CancellationDetailsReason> for StripeCancellationReason {
|
||||
fn from(value: CancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
|
||||
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
|
||||
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
|
||||
@@ -7,6 +7,7 @@ use axum::{
|
||||
routing::get,
|
||||
};
|
||||
|
||||
use collab::ServiceMode;
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
@@ -15,7 +16,6 @@ use collab::{
|
||||
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
|
||||
executor::Executor, rpc::ResultExt,
|
||||
};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
|
||||
use db::Database;
|
||||
use std::{
|
||||
env::args,
|
||||
@@ -119,8 +119,6 @@ async fn main() -> Result<()> {
|
||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::api::routes(rpc_server.clone()))
|
||||
.merge(collab::rpc::routes(rpc_server.clone()));
|
||||
|
||||
@@ -41,9 +41,11 @@ use chrono::Utc;
|
||||
use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use futures::TryFutureExt as _;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::proto::{MultiLspQuery, split_repository_update};
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
use tracing::Span;
|
||||
|
||||
use futures::{
|
||||
FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
|
||||
@@ -94,8 +96,13 @@ const MAX_CONCURRENT_CONNECTIONS: usize = 512;
|
||||
|
||||
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
const TOTAL_DURATION_MS: &str = "total_duration_ms";
|
||||
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
|
||||
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
|
||||
const HOST_WAITING_MS: &str = "host_waiting_ms";
|
||||
|
||||
type MessageHandler =
|
||||
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
|
||||
Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
|
||||
|
||||
pub struct ConnectionGuard;
|
||||
|
||||
@@ -163,6 +170,42 @@ impl Principal {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MessageContext {
|
||||
session: Session,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Deref for MessageContext {
|
||||
type Target = Session;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.session
|
||||
}
|
||||
}
|
||||
|
||||
impl MessageContext {
|
||||
pub fn forward_request<T: RequestMessage>(
|
||||
&self,
|
||||
receiver_id: ConnectionId,
|
||||
request: T,
|
||||
) -> impl Future<Output = anyhow::Result<T::Response>> {
|
||||
let request_start_time = Instant::now();
|
||||
let span = self.span.clone();
|
||||
tracing::info!("start forwarding request");
|
||||
self.peer
|
||||
.forward_request(self.connection_id, receiver_id, request)
|
||||
.inspect(move |_| {
|
||||
span.record(
|
||||
HOST_WAITING_MS,
|
||||
request_start_time.elapsed().as_micros() as f64 / 1000.0,
|
||||
);
|
||||
})
|
||||
.inspect_err(|_| tracing::error!("error forwarding request"))
|
||||
.inspect_ok(|_| tracing::info!("finished forwarding request"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Session {
|
||||
principal: Principal,
|
||||
@@ -646,40 +689,37 @@ impl Server {
|
||||
|
||||
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
|
||||
F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
|
||||
Fut: 'static + Send + Future<Output = Result<()>>,
|
||||
M: EnvelopedMessage,
|
||||
{
|
||||
let prev_handler = self.handlers.insert(
|
||||
TypeId::of::<M>(),
|
||||
Box::new(move |envelope, session| {
|
||||
Box::new(move |envelope, session, span| {
|
||||
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
|
||||
let received_at = envelope.received_at;
|
||||
tracing::info!("message received");
|
||||
let start_time = Instant::now();
|
||||
let future = (handler)(*envelope, session);
|
||||
let future = (handler)(
|
||||
*envelope,
|
||||
MessageContext {
|
||||
session,
|
||||
span: span.clone(),
|
||||
},
|
||||
);
|
||||
async move {
|
||||
let result = future.await;
|
||||
let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
|
||||
let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
|
||||
let queue_duration_ms = total_duration_ms - processing_duration_ms;
|
||||
|
||||
span.record(TOTAL_DURATION_MS, total_duration_ms);
|
||||
span.record(PROCESSING_DURATION_MS, processing_duration_ms);
|
||||
span.record(QUEUE_DURATION_MS, queue_duration_ms);
|
||||
match result {
|
||||
Err(error) => {
|
||||
tracing::error!(
|
||||
?error,
|
||||
total_duration_ms,
|
||||
processing_duration_ms,
|
||||
queue_duration_ms,
|
||||
"error handling message"
|
||||
)
|
||||
tracing::error!(?error, "error handling message")
|
||||
}
|
||||
Ok(()) => tracing::info!(
|
||||
total_duration_ms,
|
||||
processing_duration_ms,
|
||||
queue_duration_ms,
|
||||
"finished handling message"
|
||||
),
|
||||
Ok(()) => tracing::info!("finished handling message"),
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
@@ -693,7 +733,7 @@ impl Server {
|
||||
|
||||
fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(M, Session) -> Fut,
|
||||
F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
|
||||
Fut: 'static + Send + Future<Output = Result<()>>,
|
||||
M: EnvelopedMessage,
|
||||
{
|
||||
@@ -703,7 +743,7 @@ impl Server {
|
||||
|
||||
fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
|
||||
where
|
||||
F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
|
||||
F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
|
||||
Fut: Send + Future<Output = Result<()>>,
|
||||
M: RequestMessage,
|
||||
{
|
||||
@@ -746,6 +786,7 @@ impl Server {
|
||||
address: String,
|
||||
principal: Principal,
|
||||
zed_version: ZedVersion,
|
||||
release_channel: Option<String>,
|
||||
user_agent: Option<String>,
|
||||
geoip_country_code: Option<String>,
|
||||
system_id: Option<String>,
|
||||
@@ -760,12 +801,16 @@ impl Server {
|
||||
login=field::Empty,
|
||||
impersonator=field::Empty,
|
||||
user_agent=field::Empty,
|
||||
geoip_country_code=field::Empty
|
||||
geoip_country_code=field::Empty,
|
||||
release_channel=field::Empty,
|
||||
);
|
||||
principal.update_span(&span);
|
||||
if let Some(user_agent) = user_agent {
|
||||
span.record("user_agent", user_agent);
|
||||
}
|
||||
if let Some(release_channel) = release_channel {
|
||||
span.record("release_channel", release_channel);
|
||||
}
|
||||
|
||||
if let Some(country_code) = geoip_country_code.as_ref() {
|
||||
span.record("geoip_country_code", country_code);
|
||||
@@ -884,12 +929,16 @@ impl Server {
|
||||
login=field::Empty,
|
||||
impersonator=field::Empty,
|
||||
multi_lsp_query_request=field::Empty,
|
||||
{ TOTAL_DURATION_MS }=field::Empty,
|
||||
{ PROCESSING_DURATION_MS }=field::Empty,
|
||||
{ QUEUE_DURATION_MS }=field::Empty,
|
||||
{ HOST_WAITING_MS }=field::Empty
|
||||
);
|
||||
principal.update_span(&span);
|
||||
let span_enter = span.enter();
|
||||
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
|
||||
let is_background = message.is_background();
|
||||
let handle_message = (handler)(message, session.clone());
|
||||
let handle_message = (handler)(message, session.clone(), span.clone());
|
||||
drop(span_enter);
|
||||
|
||||
let handle_message = async move {
|
||||
@@ -1181,6 +1230,35 @@ impl Header for AppVersionHeader {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReleaseChannelHeader(String);
|
||||
|
||||
impl Header for ReleaseChannelHeader {
|
||||
fn name() -> &'static HeaderName {
|
||||
static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
|
||||
ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
|
||||
}
|
||||
|
||||
fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
|
||||
where
|
||||
Self: Sized,
|
||||
I: Iterator<Item = &'i axum::http::HeaderValue>,
|
||||
{
|
||||
Ok(Self(
|
||||
values
|
||||
.next()
|
||||
.ok_or_else(axum::headers::Error::invalid)?
|
||||
.to_str()
|
||||
.map_err(|_| axum::headers::Error::invalid())?
|
||||
.to_owned(),
|
||||
))
|
||||
}
|
||||
|
||||
fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
|
||||
values.extend([self.0.parse().unwrap()]);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/rpc", get(handle_websocket_request))
|
||||
@@ -1196,6 +1274,7 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
|
||||
pub async fn handle_websocket_request(
|
||||
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
|
||||
app_version_header: Option<TypedHeader<AppVersionHeader>>,
|
||||
release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
|
||||
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
|
||||
Extension(server): Extension<Arc<Server>>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
@@ -1220,6 +1299,8 @@ pub async fn handle_websocket_request(
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let release_channel = release_channel_header.map(|header| header.0.0);
|
||||
|
||||
if !version.can_collaborate() {
|
||||
return (
|
||||
StatusCode::UPGRADE_REQUIRED,
|
||||
@@ -1255,6 +1336,7 @@ pub async fn handle_websocket_request(
|
||||
socket_address,
|
||||
principal,
|
||||
version,
|
||||
release_channel,
|
||||
user_agent.map(|header| header.to_string()),
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
system_id_header.map(|header| header.to_string()),
|
||||
@@ -1348,7 +1430,11 @@ async fn connection_lost(
|
||||
}
|
||||
|
||||
/// Acknowledges a ping from a client, used to keep the connection alive.
|
||||
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
|
||||
async fn ping(
|
||||
_: proto::Ping,
|
||||
response: Response<proto::Ping>,
|
||||
_session: MessageContext,
|
||||
) -> Result<()> {
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1357,7 +1443,7 @@ async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session
|
||||
async fn create_room(
|
||||
_request: proto::CreateRoom,
|
||||
response: Response<proto::CreateRoom>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let livekit_room = nanoid::nanoid!(30);
|
||||
|
||||
@@ -1397,7 +1483,7 @@ async fn create_room(
|
||||
async fn join_room(
|
||||
request: proto::JoinRoom,
|
||||
response: Response<proto::JoinRoom>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.id);
|
||||
|
||||
@@ -1464,7 +1550,7 @@ async fn join_room(
|
||||
async fn rejoin_room(
|
||||
request: proto::RejoinRoom,
|
||||
response: Response<proto::RejoinRoom>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let room;
|
||||
let channel;
|
||||
@@ -1592,15 +1678,15 @@ fn notify_rejoined_projects(
|
||||
}
|
||||
|
||||
// Stream this worktree's diagnostics.
|
||||
for summary in worktree.diagnostic_summaries {
|
||||
session.peer.send(
|
||||
session.connection_id,
|
||||
proto::UpdateDiagnosticSummary {
|
||||
project_id: project.id.to_proto(),
|
||||
worktree_id: worktree.id,
|
||||
summary: Some(summary),
|
||||
},
|
||||
)?;
|
||||
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
|
||||
if let Some(summary) = worktree_diagnostics.next() {
|
||||
let message = proto::UpdateDiagnosticSummary {
|
||||
project_id: project.id.to_proto(),
|
||||
worktree_id: worktree.id,
|
||||
summary: Some(summary),
|
||||
more_summaries: worktree_diagnostics.collect(),
|
||||
};
|
||||
session.peer.send(session.connection_id, message)?;
|
||||
}
|
||||
|
||||
for settings_file in worktree.settings_files {
|
||||
@@ -1641,7 +1727,7 @@ fn notify_rejoined_projects(
|
||||
async fn leave_room(
|
||||
_: proto::LeaveRoom,
|
||||
response: Response<proto::LeaveRoom>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
leave_room_for_session(&session, session.connection_id).await?;
|
||||
response.send(proto::Ack {})?;
|
||||
@@ -1652,7 +1738,7 @@ async fn leave_room(
|
||||
async fn set_room_participant_role(
|
||||
request: proto::SetRoomParticipantRole,
|
||||
response: Response<proto::SetRoomParticipantRole>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let user_id = UserId::from_proto(request.user_id);
|
||||
let role = ChannelRole::from(request.role());
|
||||
@@ -1700,7 +1786,7 @@ async fn set_room_participant_role(
|
||||
async fn call(
|
||||
request: proto::Call,
|
||||
response: Response<proto::Call>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
let calling_user_id = session.user_id();
|
||||
@@ -1769,7 +1855,7 @@ async fn call(
|
||||
async fn cancel_call(
|
||||
request: proto::CancelCall,
|
||||
response: Response<proto::CancelCall>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let called_user_id = UserId::from_proto(request.called_user_id);
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
@@ -1804,7 +1890,7 @@ async fn cancel_call(
|
||||
}
|
||||
|
||||
/// Decline an incoming call.
|
||||
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
|
||||
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(message.room_id);
|
||||
{
|
||||
let room = session
|
||||
@@ -1839,7 +1925,7 @@ async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<(
|
||||
async fn update_participant_location(
|
||||
request: proto::UpdateParticipantLocation,
|
||||
response: Response<proto::UpdateParticipantLocation>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
let location = request.location.context("invalid location")?;
|
||||
@@ -1858,7 +1944,7 @@ async fn update_participant_location(
|
||||
async fn share_project(
|
||||
request: proto::ShareProject,
|
||||
response: Response<proto::ShareProject>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let (project_id, room) = &*session
|
||||
.db()
|
||||
@@ -1879,7 +1965,7 @@ async fn share_project(
|
||||
}
|
||||
|
||||
/// Unshare a project from the room.
|
||||
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
|
||||
async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(message.project_id);
|
||||
unshare_project_internal(project_id, session.connection_id, &session).await
|
||||
}
|
||||
@@ -1926,7 +2012,7 @@ async fn unshare_project_internal(
|
||||
async fn join_project(
|
||||
request: proto::JoinProject,
|
||||
response: Response<proto::JoinProject>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
|
||||
@@ -2022,15 +2108,15 @@ async fn join_project(
|
||||
}
|
||||
|
||||
// Stream this worktree's diagnostics.
|
||||
for summary in worktree.diagnostic_summaries {
|
||||
session.peer.send(
|
||||
session.connection_id,
|
||||
proto::UpdateDiagnosticSummary {
|
||||
project_id: project_id.to_proto(),
|
||||
worktree_id: worktree.id,
|
||||
summary: Some(summary),
|
||||
},
|
||||
)?;
|
||||
let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
|
||||
if let Some(summary) = worktree_diagnostics.next() {
|
||||
let message = proto::UpdateDiagnosticSummary {
|
||||
project_id: project.id.to_proto(),
|
||||
worktree_id: worktree.id,
|
||||
summary: Some(summary),
|
||||
more_summaries: worktree_diagnostics.collect(),
|
||||
};
|
||||
session.peer.send(session.connection_id, message)?;
|
||||
}
|
||||
|
||||
for settings_file in worktree.settings_files {
|
||||
@@ -2073,7 +2159,7 @@ async fn join_project(
|
||||
}
|
||||
|
||||
/// Leave someone elses shared project.
|
||||
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
|
||||
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
|
||||
let sender_id = session.connection_id;
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let db = session.db().await;
|
||||
@@ -2096,7 +2182,7 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result
|
||||
async fn update_project(
|
||||
request: proto::UpdateProject,
|
||||
response: Response<proto::UpdateProject>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let (room, guest_connection_ids) = &*session
|
||||
@@ -2125,7 +2211,7 @@ async fn update_project(
|
||||
async fn update_worktree(
|
||||
request: proto::UpdateWorktree,
|
||||
response: Response<proto::UpdateWorktree>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2149,7 +2235,7 @@ async fn update_worktree(
|
||||
async fn update_repository(
|
||||
request: proto::UpdateRepository,
|
||||
response: Response<proto::UpdateRepository>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2173,7 +2259,7 @@ async fn update_repository(
|
||||
async fn remove_repository(
|
||||
request: proto::RemoveRepository,
|
||||
response: Response<proto::RemoveRepository>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2197,7 +2283,7 @@ async fn remove_repository(
|
||||
/// Updates other participants with changes to the diagnostics
|
||||
async fn update_diagnostic_summary(
|
||||
message: proto::UpdateDiagnosticSummary,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2221,7 +2307,7 @@ async fn update_diagnostic_summary(
|
||||
/// Updates other participants with changes to the worktree settings
|
||||
async fn update_worktree_settings(
|
||||
message: proto::UpdateWorktreeSettings,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2245,7 +2331,7 @@ async fn update_worktree_settings(
|
||||
/// Notify other participants that a language server has started.
|
||||
async fn start_language_server(
|
||||
request: proto::StartLanguageServer,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let guest_connection_ids = session
|
||||
.db()
|
||||
@@ -2268,7 +2354,7 @@ async fn start_language_server(
|
||||
/// Notify other participants that a language server has changed.
|
||||
async fn update_language_server(
|
||||
request: proto::UpdateLanguageServer,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let db = session.db().await;
|
||||
@@ -2301,7 +2387,7 @@ async fn update_language_server(
|
||||
async fn forward_read_only_project_request<T>(
|
||||
request: T,
|
||||
response: Response<T>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: EntityMessage + RequestMessage,
|
||||
@@ -2312,10 +2398,7 @@ where
|
||||
.await
|
||||
.host_for_read_only_project_request(project_id, session.connection_id)
|
||||
.await?;
|
||||
let payload = session
|
||||
.peer
|
||||
.forward_request(session.connection_id, host_connection_id, request)
|
||||
.await?;
|
||||
let payload = session.forward_request(host_connection_id, request).await?;
|
||||
response.send(payload)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -2325,7 +2408,7 @@ where
|
||||
async fn forward_mutating_project_request<T>(
|
||||
request: T,
|
||||
response: Response<T>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()>
|
||||
where
|
||||
T: EntityMessage + RequestMessage,
|
||||
@@ -2337,10 +2420,7 @@ where
|
||||
.await
|
||||
.host_for_mutating_project_request(project_id, session.connection_id)
|
||||
.await?;
|
||||
let payload = session
|
||||
.peer
|
||||
.forward_request(session.connection_id, host_connection_id, request)
|
||||
.await?;
|
||||
let payload = session.forward_request(host_connection_id, request).await?;
|
||||
response.send(payload)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -2348,7 +2428,7 @@ where
|
||||
async fn multi_lsp_query(
|
||||
request: MultiLspQuery,
|
||||
response: Response<MultiLspQuery>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
|
||||
tracing::info!("multi_lsp_query message received");
|
||||
@@ -2358,7 +2438,7 @@ async fn multi_lsp_query(
|
||||
/// Notify other participants that a new buffer has been created
|
||||
async fn create_buffer_for_peer(
|
||||
request: proto::CreateBufferForPeer,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
session
|
||||
.db()
|
||||
@@ -2380,7 +2460,7 @@ async fn create_buffer_for_peer(
|
||||
async fn update_buffer(
|
||||
request: proto::UpdateBuffer,
|
||||
response: Response<proto::UpdateBuffer>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.project_id);
|
||||
let mut capability = Capability::ReadOnly;
|
||||
@@ -2415,17 +2495,14 @@ async fn update_buffer(
|
||||
};
|
||||
|
||||
if host != session.connection_id {
|
||||
session
|
||||
.peer
|
||||
.forward_request(session.connection_id, host, request.clone())
|
||||
.await?;
|
||||
session.forward_request(host, request.clone()).await?;
|
||||
}
|
||||
|
||||
response.send(proto::Ack {})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
|
||||
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(message.project_id);
|
||||
|
||||
let operation = message.operation.as_ref().context("invalid operation")?;
|
||||
@@ -2470,7 +2547,7 @@ async fn update_context(message: proto::UpdateContext, session: Session) -> Resu
|
||||
/// Notify other participants that a project has been updated.
|
||||
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
|
||||
request: T,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let project_id = ProjectId::from_proto(request.remote_entity_id());
|
||||
let project_connection_ids = session
|
||||
@@ -2495,7 +2572,7 @@ async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProj
|
||||
async fn follow(
|
||||
request: proto::Follow,
|
||||
response: Response<proto::Follow>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
let project_id = request.project_id.map(ProjectId::from_proto);
|
||||
@@ -2508,10 +2585,7 @@ async fn follow(
|
||||
.check_room_participants(room_id, leader_id, session.connection_id)
|
||||
.await?;
|
||||
|
||||
let response_payload = session
|
||||
.peer
|
||||
.forward_request(session.connection_id, leader_id, request)
|
||||
.await?;
|
||||
let response_payload = session.forward_request(leader_id, request).await?;
|
||||
response.send(response_payload)?;
|
||||
|
||||
if let Some(project_id) = project_id {
|
||||
@@ -2527,7 +2601,7 @@ async fn follow(
|
||||
}
|
||||
|
||||
/// Stop following another user in a call.
|
||||
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
|
||||
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
let project_id = request.project_id.map(ProjectId::from_proto);
|
||||
let leader_id = request.leader_id.context("invalid leader id")?.into();
|
||||
@@ -2556,7 +2630,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
|
||||
}
|
||||
|
||||
/// Notify everyone following you of your current location.
|
||||
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
|
||||
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
|
||||
let room_id = RoomId::from_proto(request.room_id);
|
||||
let database = session.db.lock().await;
|
||||
|
||||
@@ -2591,7 +2665,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
|
||||
async fn get_users(
|
||||
request: proto::GetUsers,
|
||||
response: Response<proto::GetUsers>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let user_ids = request
|
||||
.user_ids
|
||||
@@ -2619,7 +2693,7 @@ async fn get_users(
|
||||
async fn fuzzy_search_users(
|
||||
request: proto::FuzzySearchUsers,
|
||||
response: Response<proto::FuzzySearchUsers>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let query = request.query;
|
||||
let users = match query.len() {
|
||||
@@ -2651,7 +2725,7 @@ async fn fuzzy_search_users(
|
||||
async fn request_contact(
|
||||
request: proto::RequestContact,
|
||||
response: Response<proto::RequestContact>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let requester_id = session.user_id();
|
||||
let responder_id = UserId::from_proto(request.responder_id);
|
||||
@@ -2698,7 +2772,7 @@ async fn request_contact(
|
||||
async fn respond_to_contact_request(
|
||||
request: proto::RespondToContactRequest,
|
||||
response: Response<proto::RespondToContactRequest>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let responder_id = session.user_id();
|
||||
let requester_id = UserId::from_proto(request.requester_id);
|
||||
@@ -2756,7 +2830,7 @@ async fn respond_to_contact_request(
|
||||
async fn remove_contact(
|
||||
request: proto::RemoveContact,
|
||||
response: Response<proto::RemoveContact>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let requester_id = session.user_id();
|
||||
let responder_id = UserId::from_proto(request.user_id);
|
||||
@@ -3015,7 +3089,10 @@ async fn update_user_plan(session: &Session) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
|
||||
async fn subscribe_to_channels(
|
||||
_: proto::SubscribeToChannels,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
subscribe_user_to_channels(session.user_id(), &session).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -3041,7 +3118,7 @@ async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Resul
|
||||
async fn create_channel(
|
||||
request: proto::CreateChannel,
|
||||
response: Response<proto::CreateChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
@@ -3096,7 +3173,7 @@ async fn create_channel(
|
||||
async fn delete_channel(
|
||||
request: proto::DeleteChannel,
|
||||
response: Response<proto::DeleteChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
@@ -3124,7 +3201,7 @@ async fn delete_channel(
|
||||
async fn invite_channel_member(
|
||||
request: proto::InviteChannelMember,
|
||||
response: Response<proto::InviteChannelMember>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3161,7 +3238,7 @@ async fn invite_channel_member(
|
||||
async fn remove_channel_member(
|
||||
request: proto::RemoveChannelMember,
|
||||
response: Response<proto::RemoveChannelMember>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3205,7 +3282,7 @@ async fn remove_channel_member(
|
||||
async fn set_channel_visibility(
|
||||
request: proto::SetChannelVisibility,
|
||||
response: Response<proto::SetChannelVisibility>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3250,7 +3327,7 @@ async fn set_channel_visibility(
|
||||
async fn set_channel_member_role(
|
||||
request: proto::SetChannelMemberRole,
|
||||
response: Response<proto::SetChannelMemberRole>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3298,7 +3375,7 @@ async fn set_channel_member_role(
|
||||
async fn rename_channel(
|
||||
request: proto::RenameChannel,
|
||||
response: Response<proto::RenameChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3330,7 +3407,7 @@ async fn rename_channel(
|
||||
async fn move_channel(
|
||||
request: proto::MoveChannel,
|
||||
response: Response<proto::MoveChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let to = ChannelId::from_proto(request.to);
|
||||
@@ -3372,7 +3449,7 @@ async fn move_channel(
|
||||
async fn reorder_channel(
|
||||
request: proto::ReorderChannel,
|
||||
response: Response<proto::ReorderChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let direction = request.direction();
|
||||
@@ -3418,7 +3495,7 @@ async fn reorder_channel(
|
||||
async fn get_channel_members(
|
||||
request: proto::GetChannelMembers,
|
||||
response: Response<proto::GetChannelMembers>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3438,7 +3515,7 @@ async fn get_channel_members(
|
||||
async fn respond_to_channel_invite(
|
||||
request: proto::RespondToChannelInvite,
|
||||
response: Response<proto::RespondToChannelInvite>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3479,7 +3556,7 @@ async fn respond_to_channel_invite(
|
||||
async fn join_channel(
|
||||
request: proto::JoinChannel,
|
||||
response: Response<proto::JoinChannel>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
join_channel_internal(channel_id, Box::new(response), session).await
|
||||
@@ -3502,7 +3579,7 @@ impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
|
||||
async fn join_channel_internal(
|
||||
channel_id: ChannelId,
|
||||
response: Box<impl JoinChannelInternalResponse>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let joined_room = {
|
||||
let mut db = session.db().await;
|
||||
@@ -3597,7 +3674,7 @@ async fn join_channel_internal(
|
||||
async fn join_channel_buffer(
|
||||
request: proto::JoinChannelBuffer,
|
||||
response: Response<proto::JoinChannelBuffer>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3628,7 +3705,7 @@ async fn join_channel_buffer(
|
||||
/// Edit the channel notes
|
||||
async fn update_channel_buffer(
|
||||
request: proto::UpdateChannelBuffer,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3680,7 +3757,7 @@ async fn update_channel_buffer(
|
||||
async fn rejoin_channel_buffers(
|
||||
request: proto::RejoinChannelBuffers,
|
||||
response: Response<proto::RejoinChannelBuffers>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let buffers = db
|
||||
@@ -3715,7 +3792,7 @@ async fn rejoin_channel_buffers(
|
||||
async fn leave_channel_buffer(
|
||||
request: proto::LeaveChannelBuffer,
|
||||
response: Response<proto::LeaveChannelBuffer>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
@@ -3777,7 +3854,7 @@ fn send_notifications(
|
||||
async fn send_channel_message(
|
||||
request: proto::SendChannelMessage,
|
||||
response: Response<proto::SendChannelMessage>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
// Validate the message body.
|
||||
let body = request.body.trim().to_string();
|
||||
@@ -3870,7 +3947,7 @@ async fn send_channel_message(
|
||||
async fn remove_channel_message(
|
||||
request: proto::RemoveChannelMessage,
|
||||
response: Response<proto::RemoveChannelMessage>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let message_id = MessageId::from_proto(request.message_id);
|
||||
@@ -3905,7 +3982,7 @@ async fn remove_channel_message(
|
||||
async fn update_channel_message(
|
||||
request: proto::UpdateChannelMessage,
|
||||
response: Response<proto::UpdateChannelMessage>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let message_id = MessageId::from_proto(request.message_id);
|
||||
@@ -3989,7 +4066,7 @@ async fn update_channel_message(
|
||||
/// Mark a channel message as read
|
||||
async fn acknowledge_channel_message(
|
||||
request: proto::AckChannelMessage,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let message_id = MessageId::from_proto(request.message_id);
|
||||
@@ -4009,7 +4086,7 @@ async fn acknowledge_channel_message(
|
||||
/// Mark a buffer version as synced
|
||||
async fn acknowledge_buffer_version(
|
||||
request: proto::AckBufferOperation,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let buffer_id = BufferId::from_proto(request.buffer_id);
|
||||
session
|
||||
@@ -4029,7 +4106,7 @@ async fn acknowledge_buffer_version(
|
||||
async fn get_supermaven_api_key(
|
||||
_request: proto::GetSupermavenApiKey,
|
||||
response: Response<proto::GetSupermavenApiKey>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let user_id: String = session.user_id().to_string();
|
||||
if !session.is_staff() {
|
||||
@@ -4058,7 +4135,7 @@ async fn get_supermaven_api_key(
|
||||
async fn join_channel_chat(
|
||||
request: proto::JoinChannelChat,
|
||||
response: Response<proto::JoinChannelChat>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
|
||||
@@ -4076,7 +4153,10 @@ async fn join_channel_chat(
|
||||
}
|
||||
|
||||
/// Stop receiving chat updates for a channel
|
||||
async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
|
||||
async fn leave_channel_chat(
|
||||
request: proto::LeaveChannelChat,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
session
|
||||
.db()
|
||||
@@ -4090,7 +4170,7 @@ async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session)
|
||||
async fn get_channel_messages(
|
||||
request: proto::GetChannelMessages,
|
||||
response: Response<proto::GetChannelMessages>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let messages = session
|
||||
@@ -4114,7 +4194,7 @@ async fn get_channel_messages(
|
||||
async fn get_channel_messages_by_id(
|
||||
request: proto::GetChannelMessagesById,
|
||||
response: Response<proto::GetChannelMessagesById>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let message_ids = request
|
||||
.message_ids
|
||||
@@ -4137,7 +4217,7 @@ async fn get_channel_messages_by_id(
|
||||
async fn get_notifications(
|
||||
request: proto::GetNotifications,
|
||||
response: Response<proto::GetNotifications>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let notifications = session
|
||||
.db()
|
||||
@@ -4159,7 +4239,7 @@ async fn get_notifications(
|
||||
async fn mark_notification_as_read(
|
||||
request: proto::MarkNotificationRead,
|
||||
response: Response<proto::MarkNotificationRead>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let database = &session.db().await;
|
||||
let notifications = database
|
||||
@@ -4181,7 +4261,7 @@ async fn mark_notification_as_read(
|
||||
async fn get_private_user_info(
|
||||
_request: proto::GetPrivateUserInfo,
|
||||
response: Response<proto::GetPrivateUserInfo>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
@@ -4205,7 +4285,7 @@ async fn get_private_user_info(
|
||||
async fn accept_terms_of_service(
|
||||
_request: proto::AcceptTermsOfService,
|
||||
response: Response<proto::AcceptTermsOfService>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
@@ -4229,7 +4309,7 @@ async fn accept_terms_of_service(
|
||||
async fn get_llm_api_token(
|
||||
_request: proto::GetLlmToken,
|
||||
response: Response<proto::GetLlmToken>,
|
||||
session: Session,
|
||||
session: MessageContext,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use stripe::SubscriptionStatus;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::stripe_client::{
|
||||
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateMeterEventParams,
|
||||
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
|
||||
StripeCustomerId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
|
||||
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||
UpdateSubscriptionParams,
|
||||
RealStripeClient, StripeAutomaticTax, StripeClient, StripeCreateSubscriptionItems,
|
||||
StripeCreateSubscriptionParams, StripeCustomerId, StripePrice, StripePriceId,
|
||||
StripeSubscription,
|
||||
};
|
||||
|
||||
pub struct StripeBilling {
|
||||
@@ -94,30 +88,6 @@ impl StripeBilling {
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn determine_subscription_kind(
|
||||
&self,
|
||||
subscription: &StripeSubscription,
|
||||
) -> Option<SubscriptionKind> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
|
||||
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
|
||||
|
||||
subscription.items.iter().find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
||||
if price.id == zed_pro_price_id {
|
||||
Some(if subscription.status == SubscriptionStatus::Trialing {
|
||||
SubscriptionKind::ZedProTrial
|
||||
} else {
|
||||
SubscriptionKind::ZedPro
|
||||
})
|
||||
} else if price.id == zed_free_price_id {
|
||||
Some(SubscriptionKind::ZedFree)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
|
||||
/// not already exist.
|
||||
///
|
||||
@@ -150,65 +120,6 @@ impl StripeBilling {
|
||||
Ok(customer_id)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &StripeSubscriptionId,
|
||||
price: &StripePrice,
|
||||
) -> Result<()> {
|
||||
let subscription = self.client.get_subscription(subscription_id).await?;
|
||||
|
||||
if subscription_contains_price(&subscription, &price.id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
|
||||
|
||||
let price_per_unit = price.unit_amount.unwrap_or_default();
|
||||
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
||||
|
||||
self.client
|
||||
.update_subscription(
|
||||
subscription_id,
|
||||
UpdateSubscriptionParams {
|
||||
items: Some(vec![UpdateSubscriptionItems {
|
||||
price: Some(price.id.clone()),
|
||||
}]),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
|
||||
},
|
||||
}),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_request_usage(
|
||||
&self,
|
||||
customer_id: &StripeCustomerId,
|
||||
event_name: &str,
|
||||
requests: i32,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let idempotency_key = Uuid::new_v4();
|
||||
|
||||
self.client
|
||||
.create_meter_event(StripeCreateMeterEventParams {
|
||||
identifier: &format!("model_requests/{}", idempotency_key),
|
||||
event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: requests as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_zed_free(
|
||||
&self,
|
||||
customer_id: StripeCustomerId,
|
||||
@@ -243,14 +154,3 @@ impl StripeBilling {
|
||||
Ok(subscription)
|
||||
}
|
||||
}
|
||||
|
||||
fn subscription_contains_price(
|
||||
subscription: &StripeSubscription,
|
||||
price_id: &StripePriceId,
|
||||
) -> bool {
|
||||
subscription.items.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id == *price_id)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3101,9 +3101,7 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
|
||||
// Turn inline-blame-off by default so no state is transferred without us explicitly doing so
|
||||
let inline_blame_off_settings = Some(InlineBlameSettings {
|
||||
enabled: false,
|
||||
delay_ms: None,
|
||||
min_column: None,
|
||||
show_commit_summary: false,
|
||||
..Default::default()
|
||||
});
|
||||
cx_a.update(|cx| {
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{Duration, Utc};
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{
|
||||
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
|
||||
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
|
||||
StripeSubscriptionItemId, UpdateSubscriptionItems,
|
||||
};
|
||||
use crate::stripe_client::{FakeStripeClient, StripePrice, StripePriceId, StripePriceRecurring};
|
||||
|
||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
let stripe_client = Arc::new(FakeStripeClient::new());
|
||||
@@ -21,24 +16,6 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
async fn test_initialize() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
// Add test meters
|
||||
let meter1 = StripeMeter {
|
||||
id: StripeMeterId("meter_1".into()),
|
||||
event_name: "event_1".to_string(),
|
||||
};
|
||||
let meter2 = StripeMeter {
|
||||
id: StripeMeterId("meter_2".into()),
|
||||
event_name: "event_2".to_string(),
|
||||
};
|
||||
stripe_client
|
||||
.meters
|
||||
.lock()
|
||||
.insert(meter1.id.clone(), meter1);
|
||||
stripe_client
|
||||
.meters
|
||||
.lock()
|
||||
.insert(meter2.id.clone(), meter2);
|
||||
|
||||
// Add test prices
|
||||
let price1 = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
@@ -144,217 +121,3 @@ async fn test_find_or_create_customer_by_email() {
|
||||
assert_eq!(customer.email.as_deref(), Some(email));
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subscribe_to_price() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let price = StripePrice {
|
||||
id: StripePriceId("price_test".into()),
|
||||
unit_amount: Some(2000),
|
||||
lookup_key: Some("test-price".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price.id.clone(), price.clone());
|
||||
|
||||
let now = Utc::now();
|
||||
let subscription = StripeSubscription {
|
||||
id: StripeSubscriptionId("sub_test".into()),
|
||||
customer: StripeCustomerId("cus_test".into()),
|
||||
status: stripe::SubscriptionStatus::Active,
|
||||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||
items: vec![],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client
|
||||
.subscriptions
|
||||
.lock()
|
||||
.insert(subscription.id.clone(), subscription.clone());
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_price(&subscription.id, &price)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let update_subscription_calls = stripe_client
|
||||
.update_subscription_calls
|
||||
.lock()
|
||||
.iter()
|
||||
.map(|(id, params)| (id.clone(), params.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(update_subscription_calls.len(), 1);
|
||||
assert_eq!(update_subscription_calls[0].0, subscription.id);
|
||||
assert_eq!(
|
||||
update_subscription_calls[0].1.items,
|
||||
Some(vec![UpdateSubscriptionItems {
|
||||
price: Some(price.id.clone())
|
||||
}])
|
||||
);
|
||||
|
||||
// Subscribing to a price that is already on the subscription is a no-op.
|
||||
{
|
||||
let now = Utc::now();
|
||||
let subscription = StripeSubscription {
|
||||
id: StripeSubscriptionId("sub_test".into()),
|
||||
customer: StripeCustomerId("cus_test".into()),
|
||||
status: stripe::SubscriptionStatus::Active,
|
||||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||
items: vec![StripeSubscriptionItem {
|
||||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client
|
||||
.subscriptions
|
||||
.lock()
|
||||
.insert(subscription.id.clone(), subscription.clone());
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_price(&subscription.id, &price)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subscribe_to_zed_free() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let zed_pro_price = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(0),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
|
||||
let zed_free_price = StripePrice {
|
||||
id: StripePriceId("price_2".into()),
|
||||
unit_amount: Some(0),
|
||||
lookup_key: Some("zed-free".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(zed_free_price.id.clone(), zed_free_price.clone());
|
||||
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
// Customer is subscribed to Zed Free when not already subscribed to a plan.
|
||||
{
|
||||
let customer_id = StripeCustomerId("cus_no_plan".into());
|
||||
|
||||
let subscription = stripe_billing
|
||||
.subscribe_to_zed_free(customer_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
|
||||
}
|
||||
|
||||
// Customer is not subscribed to Zed Free when they already have an active subscription.
|
||||
{
|
||||
let customer_id = StripeCustomerId("cus_active_subscription".into());
|
||||
|
||||
let now = Utc::now();
|
||||
let existing_subscription = StripeSubscription {
|
||||
id: StripeSubscriptionId("sub_existing_active".into()),
|
||||
customer: customer_id.clone(),
|
||||
status: stripe::SubscriptionStatus::Active,
|
||||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||
items: vec![StripeSubscriptionItem {
|
||||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(zed_pro_price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client.subscriptions.lock().insert(
|
||||
existing_subscription.id.clone(),
|
||||
existing_subscription.clone(),
|
||||
);
|
||||
|
||||
let subscription = stripe_billing
|
||||
.subscribe_to_zed_free(customer_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(subscription, existing_subscription);
|
||||
}
|
||||
|
||||
// Customer is not subscribed to Zed Free when they already have a trial subscription.
|
||||
{
|
||||
let customer_id = StripeCustomerId("cus_trial_subscription".into());
|
||||
|
||||
let now = Utc::now();
|
||||
let existing_subscription = StripeSubscription {
|
||||
id: StripeSubscriptionId("sub_existing_trial".into()),
|
||||
customer: customer_id.clone(),
|
||||
status: stripe::SubscriptionStatus::Trialing,
|
||||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(14)).timestamp(),
|
||||
items: vec![StripeSubscriptionItem {
|
||||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(zed_pro_price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client.subscriptions.lock().insert(
|
||||
existing_subscription.id.clone(),
|
||||
existing_subscription.clone(),
|
||||
);
|
||||
|
||||
let subscription = stripe_billing
|
||||
.subscribe_to_zed_free(customer_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(subscription, existing_subscription);
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_bill_model_request_usage() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let customer_id = StripeCustomerId("cus_test".into());
|
||||
|
||||
stripe_billing
|
||||
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let create_meter_event_calls = stripe_client
|
||||
.create_meter_event_calls
|
||||
.lock()
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_meter_event_calls.len(), 1);
|
||||
assert!(
|
||||
create_meter_event_calls[0]
|
||||
.identifier
|
||||
.starts_with("model_requests/")
|
||||
);
|
||||
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
|
||||
assert_eq!(
|
||||
create_meter_event_calls[0].event_name.as_ref(),
|
||||
"some_model/requests"
|
||||
);
|
||||
assert_eq!(create_meter_event_calls[0].value, 73);
|
||||
}
|
||||
|
||||
@@ -297,6 +297,7 @@ impl TestServer {
|
||||
client_name,
|
||||
Principal::User(user),
|
||||
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
||||
Some("test".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
|
||||
@@ -136,7 +136,10 @@ impl Focusable for CommandPalette {
|
||||
|
||||
impl Render for CommandPalette {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex().w(rems(34.)).child(self.picker.clone())
|
||||
v_flex()
|
||||
.key_context("CommandPalette")
|
||||
.w(rems(34.))
|
||||
.child(self.picker.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ use language::{
|
||||
point_from_lsp, point_to_lsp,
|
||||
};
|
||||
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
|
||||
use node_runtime::{NodeRuntime, VersionCheck};
|
||||
use node_runtime::NodeRuntime;
|
||||
use parking_lot::Mutex;
|
||||
use project::DisableAiSettings;
|
||||
use request::StatusNotification;
|
||||
@@ -1169,8 +1169,9 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
const SERVER_PATH: &str =
|
||||
"node_modules/@github/copilot-language-server/dist/language-server.js";
|
||||
|
||||
// pinning it: https://github.com/zed-industries/zed/issues/36093
|
||||
const PINNED_VERSION: &str = "1.354";
|
||||
let latest_version = node_runtime
|
||||
.npm_package_latest_version(PACKAGE_NAME)
|
||||
.await?;
|
||||
let server_path = paths::copilot_dir().join(SERVER_PATH);
|
||||
|
||||
fs.create_dir(paths::copilot_dir()).await?;
|
||||
@@ -1180,13 +1181,12 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
PACKAGE_NAME,
|
||||
&server_path,
|
||||
paths::copilot_dir(),
|
||||
&PINNED_VERSION,
|
||||
VersionCheck::VersionMismatch,
|
||||
&latest_version,
|
||||
)
|
||||
.await;
|
||||
if should_install {
|
||||
node_runtime
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)])
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
|
||||
.await?;
|
||||
}
|
||||
|
||||
|
||||
@@ -58,11 +58,19 @@ impl EditPredictionProvider for CopilotCompletionProvider {
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn show_tab_accept_marker() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_jump_to_edit() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
self.pending_refresh.is_some()
|
||||
self.pending_refresh.is_some() && self.completions.is_empty()
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
@@ -343,8 +351,8 @@ mod tests {
|
||||
executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT);
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
assert!(!editor.has_active_edit_prediction());
|
||||
// Since we have both, the copilot suggestion is not shown inline
|
||||
assert!(editor.has_active_edit_prediction());
|
||||
// Since we have both, the copilot suggestion is existing but does not show up as ghost text
|
||||
assert_eq!(editor.text(cx), "one.\ntwo\nthree\n");
|
||||
assert_eq!(editor.display_text(cx), "one.\ntwo\nthree\n");
|
||||
|
||||
@@ -934,8 +942,9 @@ mod tests {
|
||||
executor.advance_clock(COPILOT_DEBOUNCE_TIMEOUT);
|
||||
cx.update_editor(|editor, _, cx| {
|
||||
assert!(editor.context_menu_visible());
|
||||
assert!(!editor.has_active_edit_prediction(),);
|
||||
assert!(editor.has_active_edit_prediction());
|
||||
assert_eq!(editor.text(cx), "one\ntwo.\nthree\n");
|
||||
assert_eq!(editor.display_text(cx), "one\ntwo.\nthree\n");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1077,8 +1086,6 @@ mod tests {
|
||||
vec![complete_from_marker.clone(), replace_range_marker.clone()],
|
||||
);
|
||||
|
||||
let complete_from_position =
|
||||
cx.to_lsp(marked_ranges.remove(&complete_from_marker).unwrap()[0].start);
|
||||
let replace_range =
|
||||
cx.to_lsp_range(marked_ranges.remove(&replace_range_marker).unwrap()[0].clone());
|
||||
|
||||
@@ -1087,10 +1094,6 @@ mod tests {
|
||||
let completions = completions.clone();
|
||||
async move {
|
||||
assert_eq!(params.text_document_position.text_document.uri, url.clone());
|
||||
assert_eq!(
|
||||
params.text_document_position.position,
|
||||
complete_from_position
|
||||
);
|
||||
Ok(Some(lsp::CompletionResponse::Array(
|
||||
completions
|
||||
.iter()
|
||||
|
||||
@@ -74,6 +74,12 @@ impl Borrow<str> for DebugAdapterName {
|
||||
}
|
||||
}
|
||||
|
||||
impl Borrow<SharedString> for DebugAdapterName {
|
||||
fn borrow(&self) -> &SharedString {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DebugAdapterName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(&self.0, f)
|
||||
|
||||
@@ -87,7 +87,7 @@ impl DapRegistry {
|
||||
self.0.read().adapters.get(name).cloned()
|
||||
}
|
||||
|
||||
pub fn enumerate_adapters(&self) -> Vec<DebugAdapterName> {
|
||||
pub fn enumerate_adapters<B: FromIterator<DebugAdapterName>>(&self) -> B {
|
||||
self.0.read().adapters.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,9 +152,6 @@ impl PythonDebugAdapter {
|
||||
maybe!(async move {
|
||||
let response = latest_release.filter(|response| response.status().is_success())?;
|
||||
|
||||
let download_dir = debug_adapters_dir().join(Self::ADAPTER_NAME);
|
||||
std::fs::create_dir_all(&download_dir).ok()?;
|
||||
|
||||
let mut output = String::new();
|
||||
response
|
||||
.into_body()
|
||||
|
||||
@@ -300,7 +300,7 @@ impl DebugPanel {
|
||||
});
|
||||
|
||||
session.update(cx, |session, _| match &mut session.mode {
|
||||
SessionState::Building(state_task) => {
|
||||
SessionState::Booting(state_task) => {
|
||||
*state_task = Some(boot_task);
|
||||
}
|
||||
SessionState::Running(_) => {
|
||||
|
||||
@@ -299,59 +299,76 @@ pub fn init(cx: &mut App) {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let session = active_session
|
||||
.read(cx)
|
||||
.running_state
|
||||
.read(cx)
|
||||
.session()
|
||||
.read(cx);
|
||||
|
||||
if session.is_terminated() {
|
||||
return;
|
||||
}
|
||||
|
||||
let editor = cx.entity().downgrade();
|
||||
window.on_action(TypeId::of::<editor::actions::RunToCursor>(), {
|
||||
let editor = editor.clone();
|
||||
let active_session = active_session.clone();
|
||||
move |_, phase, _, cx| {
|
||||
if phase != DispatchPhase::Bubble {
|
||||
return;
|
||||
}
|
||||
maybe!({
|
||||
let (buffer, position, _) = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let cursor_point: language::Point =
|
||||
editor.selections.newest(cx).head();
|
||||
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.point_to_buffer_point(cursor_point, cx)
|
||||
})
|
||||
.ok()??;
|
||||
window.on_action_when(
|
||||
session.any_stopped_thread(),
|
||||
TypeId::of::<editor::actions::RunToCursor>(),
|
||||
{
|
||||
let editor = editor.clone();
|
||||
let active_session = active_session.clone();
|
||||
move |_, phase, _, cx| {
|
||||
if phase != DispatchPhase::Bubble {
|
||||
return;
|
||||
}
|
||||
maybe!({
|
||||
let (buffer, position, _) = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let cursor_point: language::Point =
|
||||
editor.selections.newest(cx).head();
|
||||
|
||||
let path =
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.point_to_buffer_point(cursor_point, cx)
|
||||
})
|
||||
.ok()??;
|
||||
|
||||
let path =
|
||||
debugger::breakpoint_store::BreakpointStore::abs_path_from_buffer(
|
||||
&buffer, cx,
|
||||
)?;
|
||||
|
||||
let source_breakpoint = SourceBreakpoint {
|
||||
row: position.row,
|
||||
path,
|
||||
message: None,
|
||||
condition: None,
|
||||
hit_condition: None,
|
||||
state: debugger::breakpoint_store::BreakpointState::Enabled,
|
||||
};
|
||||
let source_breakpoint = SourceBreakpoint {
|
||||
row: position.row,
|
||||
path,
|
||||
message: None,
|
||||
condition: None,
|
||||
hit_condition: None,
|
||||
state: debugger::breakpoint_store::BreakpointState::Enabled,
|
||||
};
|
||||
|
||||
active_session.update(cx, |session, cx| {
|
||||
session.running_state().update(cx, |state, cx| {
|
||||
if let Some(thread_id) = state.selected_thread_id() {
|
||||
state.session().update(cx, |session, cx| {
|
||||
session.run_to_position(
|
||||
source_breakpoint,
|
||||
thread_id,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
}
|
||||
active_session.update(cx, |session, cx| {
|
||||
session.running_state().update(cx, |state, cx| {
|
||||
if let Some(thread_id) = state.selected_thread_id() {
|
||||
state.session().update(cx, |session, cx| {
|
||||
session.run_to_position(
|
||||
source_breakpoint,
|
||||
thread_id,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
});
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
window.on_action(
|
||||
TypeId::of::<editor::actions::EvaluateSelectedText>(),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use collections::{FxHashMap, HashMap};
|
||||
use collections::{FxHashMap, HashMap, HashSet};
|
||||
use language::LanguageRegistry;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
@@ -450,7 +450,7 @@ impl NewProcessModal {
|
||||
.and_then(|buffer| buffer.read(cx).language())
|
||||
.cloned();
|
||||
|
||||
let mut available_adapters = workspace
|
||||
let mut available_adapters: Vec<_> = workspace
|
||||
.update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters())
|
||||
.unwrap_or_default();
|
||||
if let Some(language) = active_buffer_language {
|
||||
@@ -1054,6 +1054,9 @@ impl DebugDelegate {
|
||||
})
|
||||
})
|
||||
});
|
||||
|
||||
let valid_adapters: HashSet<_> = cx.global::<DapRegistry>().enumerate_adapters();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (recent, scenarios) = if let Some(task) = task {
|
||||
task.await
|
||||
@@ -1094,6 +1097,7 @@ impl DebugDelegate {
|
||||
} => !(hide_vscode && dir.ends_with(".vscode")),
|
||||
_ => true,
|
||||
})
|
||||
.filter(|(_, scenario)| valid_adapters.contains(&scenario.adapter))
|
||||
.map(|(kind, scenario)| {
|
||||
let (language, scenario) =
|
||||
Self::get_scenario_kind(&languages, &dap_registry, scenario);
|
||||
|
||||
@@ -1651,7 +1651,7 @@ impl RunningState {
|
||||
|
||||
let is_building = self.session.update(cx, |session, cx| {
|
||||
session.shutdown(cx).detach();
|
||||
matches!(session.mode, session::SessionState::Building(_))
|
||||
matches!(session.mode, session::SessionState::Booting(_))
|
||||
});
|
||||
|
||||
if is_building {
|
||||
|
||||
@@ -29,7 +29,6 @@ use ui::{
|
||||
Scrollbar, ScrollbarState, SharedString, StatefulInteractiveElement, Styled, Toggleable,
|
||||
Tooltip, Window, div, h_flex, px, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use zed_actions::{ToggleEnableBreakpoint, UnsetBreakpoint};
|
||||
|
||||
@@ -56,8 +55,6 @@ pub(crate) struct BreakpointList {
|
||||
scrollbar_state: ScrollbarState,
|
||||
breakpoints: Vec<BreakpointEntry>,
|
||||
session: Option<Entity<Session>>,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
show_scrollbar: bool,
|
||||
focus_handle: FocusHandle,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
selected_ix: Option<usize>,
|
||||
@@ -103,8 +100,6 @@ impl BreakpointList {
|
||||
worktree_store,
|
||||
scrollbar_state,
|
||||
breakpoints: Default::default(),
|
||||
hide_scrollbar_task: None,
|
||||
show_scrollbar: false,
|
||||
workspace,
|
||||
session,
|
||||
focus_handle,
|
||||
@@ -565,21 +560,6 @@ impl BreakpointList {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.show_scrollbar = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_list(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let selected_ix = self.selected_ix;
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
@@ -614,43 +594,39 @@ impl BreakpointList {
|
||||
.flex_grow()
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !(self.show_scrollbar || self.scrollbar_state.is_dragging()) {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("breakpoint-list-vertical-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("breakpoint-list-vertical-scrollbar")
|
||||
.on_mouse_move(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
pub(crate) fn render_control_strip(&self) -> AnyElement {
|
||||
let selection_kind = self.selection_kind();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
@@ -819,15 +795,6 @@ impl Render for BreakpointList {
|
||||
.id("breakpoint-list")
|
||||
.key_context("BreakpointList")
|
||||
.track_focus(&self.focus_handle)
|
||||
.on_hover(cx.listener(|this, hovered, window, cx| {
|
||||
if *hovered {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_task.take();
|
||||
cx.notify();
|
||||
} else if !this.focus_handle.contains_focused(window, cx) {
|
||||
this.hide_scrollbar(window, cx);
|
||||
}
|
||||
}))
|
||||
.on_action(cx.listener(Self::select_next))
|
||||
.on_action(cx.listener(Self::select_previous))
|
||||
.on_action(cx.listener(Self::select_first))
|
||||
@@ -844,7 +811,7 @@ impl Render for BreakpointList {
|
||||
v_flex()
|
||||
.size_full()
|
||||
.child(self.render_list(cx))
|
||||
.children(self.render_vertical_scrollbar(cx)),
|
||||
.child(self.render_vertical_scrollbar(cx)),
|
||||
)
|
||||
.when_some(self.strip_mode, |this, _| {
|
||||
this.child(Divider::horizontal()).child(
|
||||
|
||||
@@ -23,7 +23,6 @@ use ui::{
|
||||
ParentElement, Pixels, PopoverMenuHandle, Render, Scrollbar, ScrollbarState, SharedString,
|
||||
StatefulInteractiveElement, Styled, TextSize, Tooltip, Window, div, h_flex, px, v_flex,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{ToggleDataBreakpoint, session::running::stack_frame_list::StackFrameList};
|
||||
@@ -34,9 +33,7 @@ pub(crate) struct MemoryView {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
scroll_state: ScrollbarState,
|
||||
show_scrollbar: bool,
|
||||
stack_frame_list: WeakEntity<StackFrameList>,
|
||||
hide_scrollbar_task: Option<Task<()>>,
|
||||
focus_handle: FocusHandle,
|
||||
view_state: ViewState,
|
||||
query_editor: Entity<Editor>,
|
||||
@@ -150,8 +147,6 @@ impl MemoryView {
|
||||
scroll_state,
|
||||
scroll_handle,
|
||||
stack_frame_list,
|
||||
show_scrollbar: false,
|
||||
hide_scrollbar_task: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
view_state,
|
||||
query_editor,
|
||||
@@ -168,61 +163,42 @@ impl MemoryView {
|
||||
.detach();
|
||||
this
|
||||
}
|
||||
fn hide_scrollbar(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
const SCROLLBAR_SHOW_INTERVAL: Duration = Duration::from_secs(1);
|
||||
self.hide_scrollbar_task = Some(cx.spawn_in(window, async move |panel, cx| {
|
||||
cx.background_executor()
|
||||
.timer(SCROLLBAR_SHOW_INTERVAL)
|
||||
.await;
|
||||
panel
|
||||
.update(cx, |panel, cx| {
|
||||
panel.show_scrollbar = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}))
|
||||
}
|
||||
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Option<Stateful<Div>> {
|
||||
if !(self.show_scrollbar || self.scroll_state.is_dragging()) {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
div()
|
||||
.occlude()
|
||||
.id("memory-view-vertical-scrollbar")
|
||||
.on_drag_move(cx.listener(|this, evt, _, cx| {
|
||||
let did_handle = this.handle_scroll_drag(evt);
|
||||
cx.notify();
|
||||
if did_handle {
|
||||
cx.stop_propagation()
|
||||
}
|
||||
}))
|
||||
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
|
||||
.on_hover(|_, _, cx| {
|
||||
fn render_vertical_scrollbar(&self, cx: &mut Context<Self>) -> Stateful<Div> {
|
||||
div()
|
||||
.occlude()
|
||||
.id("memory-view-vertical-scrollbar")
|
||||
.on_drag_move(cx.listener(|this, evt, _, cx| {
|
||||
let did_handle = this.handle_scroll_drag(evt);
|
||||
cx.notify();
|
||||
if did_handle {
|
||||
cx.stop_propagation()
|
||||
}
|
||||
}))
|
||||
.on_drag(ScrollbarDragging, |_, _, _, cx| cx.new(|_| Empty))
|
||||
.on_hover(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_mouse_up(
|
||||
MouseButton::Left,
|
||||
cx.listener(|_, _, _, cx| {
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scroll_state.clone())),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.on_scroll_wheel(cx.listener(|_, _, _, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.h_full()
|
||||
.absolute()
|
||||
.right_1()
|
||||
.top_1()
|
||||
.bottom_0()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.children(Scrollbar::vertical(self.scroll_state.clone()).map(|s| s.auto_hide(cx)))
|
||||
}
|
||||
|
||||
fn render_memory(&self, cx: &mut Context<Self>) -> UniformList {
|
||||
@@ -920,15 +896,6 @@ impl Render for MemoryView {
|
||||
.on_action(cx.listener(Self::page_up))
|
||||
.size_full()
|
||||
.track_focus(&self.focus_handle)
|
||||
.on_hover(cx.listener(|this, hovered, window, cx| {
|
||||
if *hovered {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_task.take();
|
||||
cx.notify();
|
||||
} else if !this.focus_handle.contains_focused(window, cx) {
|
||||
this.hide_scrollbar(window, cx);
|
||||
}
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
@@ -978,7 +945,7 @@ impl Render for MemoryView {
|
||||
)
|
||||
.with_priority(1)
|
||||
}))
|
||||
.children(self.render_vertical_scrollbar(cx)),
|
||||
.child(self.render_vertical_scrollbar(cx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,7 +298,7 @@ async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppConte
|
||||
|
||||
let adapter_names = cx.update(|cx| {
|
||||
let registry = DapRegistry::global(cx);
|
||||
registry.enumerate_adapters()
|
||||
registry.enumerate_adapters::<Vec<_>>()
|
||||
});
|
||||
|
||||
let zed_config = ZedDebugConfig {
|
||||
|
||||
@@ -177,9 +177,9 @@ impl ProjectDiagnosticsEditor {
|
||||
}
|
||||
project::Event::DiagnosticsUpdated {
|
||||
language_server_id,
|
||||
path,
|
||||
paths,
|
||||
} => {
|
||||
this.paths_to_update.insert(path.clone());
|
||||
this.paths_to_update.extend(paths.clone());
|
||||
let project = project.clone();
|
||||
this.diagnostic_summary_update = cx.spawn(async move |this, cx| {
|
||||
cx.background_executor()
|
||||
@@ -193,9 +193,9 @@ impl ProjectDiagnosticsEditor {
|
||||
cx.emit(EditorEvent::TitleChanged);
|
||||
|
||||
if this.editor.focus_handle(cx).contains_focused(window, cx) || this.focus_handle.contains_focused(window, cx) {
|
||||
log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. recording change");
|
||||
log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. recording change");
|
||||
} else {
|
||||
log::debug!("diagnostics updated for server {language_server_id}, path {path:?}. updating excerpts");
|
||||
log::debug!("diagnostics updated for server {language_server_id}, paths {paths:?}. updating excerpts");
|
||||
this.update_stale_excerpts(window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -876,7 +876,7 @@ async fn test_random_diagnostics_with_inlays(cx: &mut TestAppContext, mut rng: S
|
||||
vec![Inlay::edit_prediction(
|
||||
post_inc(&mut next_inlay_id),
|
||||
snapshot.buffer_snapshot.anchor_before(position),
|
||||
format!("Test inlay {next_inlay_id}"),
|
||||
Rope::from_iter(["Test inlay ", "next_inlay_id"]),
|
||||
)],
|
||||
cx,
|
||||
);
|
||||
|
||||
@@ -61,6 +61,10 @@ pub trait EditPredictionProvider: 'static + Sized {
|
||||
fn show_tab_accept_marker() -> bool {
|
||||
false
|
||||
}
|
||||
fn supports_jump_to_edit() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
|
||||
DataCollectionState::Unsupported
|
||||
}
|
||||
@@ -116,6 +120,7 @@ pub trait EditPredictionProviderHandle {
|
||||
) -> bool;
|
||||
fn show_completions_in_menu(&self) -> bool;
|
||||
fn show_tab_accept_marker(&self) -> bool;
|
||||
fn supports_jump_to_edit(&self) -> bool;
|
||||
fn data_collection_state(&self, cx: &App) -> DataCollectionState;
|
||||
fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
|
||||
fn toggle_data_collection(&self, cx: &mut App);
|
||||
@@ -166,6 +171,10 @@ where
|
||||
T::show_tab_accept_marker()
|
||||
}
|
||||
|
||||
fn supports_jump_to_edit(&self) -> bool {
|
||||
T::supports_jump_to_edit()
|
||||
}
|
||||
|
||||
fn data_collection_state(&self, cx: &App) -> DataCollectionState {
|
||||
self.read(cx).data_collection_state(cx)
|
||||
}
|
||||
|
||||
@@ -491,7 +491,12 @@ impl EditPredictionButton {
|
||||
let subtle_mode = matches!(current_mode, EditPredictionsMode::Subtle);
|
||||
let eager_mode = matches!(current_mode, EditPredictionsMode::Eager);
|
||||
|
||||
if matches!(provider, EditPredictionProvider::Zed) {
|
||||
if matches!(
|
||||
provider,
|
||||
EditPredictionProvider::Zed
|
||||
| EditPredictionProvider::Copilot
|
||||
| EditPredictionProvider::Supermaven
|
||||
) {
|
||||
menu = menu
|
||||
.separator()
|
||||
.header("Display Modes")
|
||||
|
||||
@@ -745,5 +745,6 @@ actions!(
|
||||
UniqueLinesCaseInsensitive,
|
||||
/// Removes duplicate lines (case-sensitive).
|
||||
UniqueLinesCaseSensitive,
|
||||
UnwrapSyntaxNode
|
||||
]
|
||||
);
|
||||
|
||||
@@ -48,16 +48,16 @@ pub struct Inlay {
|
||||
impl Inlay {
|
||||
pub fn hint(id: usize, position: Anchor, hint: &project::InlayHint) -> Self {
|
||||
let mut text = hint.text();
|
||||
if hint.padding_right && !text.ends_with(' ') {
|
||||
text.push(' ');
|
||||
if hint.padding_right && text.chars_at(text.len().saturating_sub(1)).next() != Some(' ') {
|
||||
text.push(" ");
|
||||
}
|
||||
if hint.padding_left && !text.starts_with(' ') {
|
||||
text.insert(0, ' ');
|
||||
if hint.padding_left && text.chars_at(0).next() != Some(' ') {
|
||||
text.push_front(" ");
|
||||
}
|
||||
Self {
|
||||
id: InlayId::Hint(id),
|
||||
position,
|
||||
text: text.into(),
|
||||
text,
|
||||
color: None,
|
||||
}
|
||||
}
|
||||
@@ -737,13 +737,13 @@ impl InlayMap {
|
||||
Inlay::mock_hint(
|
||||
post_inc(next_inlay_id),
|
||||
snapshot.buffer.anchor_at(position, bias),
|
||||
text.clone(),
|
||||
&text,
|
||||
)
|
||||
} else {
|
||||
Inlay::edit_prediction(
|
||||
post_inc(next_inlay_id),
|
||||
snapshot.buffer.anchor_at(position, bias),
|
||||
text.clone(),
|
||||
&text,
|
||||
)
|
||||
};
|
||||
let inlay_id = next_inlay.id;
|
||||
@@ -1694,7 +1694,7 @@ mod tests {
|
||||
(offset, inlay.clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut expected_text = Rope::from(buffer_snapshot.text());
|
||||
let mut expected_text = Rope::from(&buffer_snapshot.text());
|
||||
for (offset, inlay) in inlays.iter().rev() {
|
||||
expected_text.replace(*offset..*offset, &inlay.text.to_string());
|
||||
}
|
||||
|
||||
@@ -228,6 +228,49 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext)
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui::TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default());
|
||||
assign_editor_completion_provider_non_zed(provider.clone(), &mut cx);
|
||||
|
||||
// Cursor is 2+ lines above the proposed edit
|
||||
cx.set_state(indoc! {"
|
||||
line 0
|
||||
line ˇ1
|
||||
line 2
|
||||
line 3
|
||||
line
|
||||
"});
|
||||
|
||||
propose_edits_non_zed(
|
||||
&provider,
|
||||
vec![(Point::new(4, 3)..Point::new(4, 3), " 4")],
|
||||
&mut cx,
|
||||
);
|
||||
|
||||
cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx));
|
||||
|
||||
// For non-Zed providers, there should be no move completion (jump functionality disabled)
|
||||
cx.editor(|editor, _, _| {
|
||||
if let Some(completion_state) = &editor.active_edit_prediction {
|
||||
// Should be an Edit prediction, not a Move prediction
|
||||
match &completion_state.completion {
|
||||
EditPrediction::Edit { .. } => {
|
||||
// This is expected for non-Zed providers
|
||||
}
|
||||
EditPrediction::Move { .. } => {
|
||||
panic!(
|
||||
"Non-Zed providers should not show Move predictions (jump functionality)"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn assert_editor_active_edit_completion(
|
||||
cx: &mut EditorTestContext,
|
||||
assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, String)>),
|
||||
@@ -301,6 +344,37 @@ fn assign_editor_completion_provider(
|
||||
})
|
||||
}
|
||||
|
||||
fn propose_edits_non_zed<T: ToOffset>(
|
||||
provider: &Entity<FakeNonZedEditPredictionProvider>,
|
||||
edits: Vec<(Range<T>, &str)>,
|
||||
cx: &mut EditorTestContext,
|
||||
) {
|
||||
let snapshot = cx.buffer_snapshot();
|
||||
let edits = edits.into_iter().map(|(range, text)| {
|
||||
let range = snapshot.anchor_after(range.start)..snapshot.anchor_before(range.end);
|
||||
(range, text.into())
|
||||
});
|
||||
|
||||
cx.update(|_, cx| {
|
||||
provider.update(cx, |provider, _| {
|
||||
provider.set_edit_prediction(Some(edit_prediction::EditPrediction {
|
||||
id: None,
|
||||
edits: edits.collect(),
|
||||
edit_preview: None,
|
||||
}))
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn assign_editor_completion_provider_non_zed(
|
||||
provider: Entity<FakeNonZedEditPredictionProvider>,
|
||||
cx: &mut EditorTestContext,
|
||||
) {
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct FakeEditPredictionProvider {
|
||||
pub completion: Option<edit_prediction::EditPrediction>,
|
||||
@@ -325,6 +399,84 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_jump_to_edit() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
&self,
|
||||
_buffer: &gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_cx: &gpui::App,
|
||||
) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn is_refreshing(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
_cx: &mut gpui::Context<Self>,
|
||||
) {
|
||||
}
|
||||
|
||||
fn cycle(
|
||||
&mut self,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_direction: edit_prediction::Direction,
|
||||
_cx: &mut gpui::Context<Self>,
|
||||
) {
|
||||
}
|
||||
|
||||
fn accept(&mut self, _cx: &mut gpui::Context<Self>) {}
|
||||
|
||||
fn discard(&mut self, _cx: &mut gpui::Context<Self>) {}
|
||||
|
||||
fn suggest<'a>(
|
||||
&mut self,
|
||||
_buffer: &gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_cx: &mut gpui::Context<Self>,
|
||||
) -> Option<edit_prediction::EditPrediction> {
|
||||
self.completion.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct FakeNonZedEditPredictionProvider {
|
||||
pub completion: Option<edit_prediction::EditPrediction>,
|
||||
}
|
||||
|
||||
impl FakeNonZedEditPredictionProvider {
|
||||
pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
|
||||
self.completion = completion;
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
|
||||
fn name() -> &'static str {
|
||||
"fake-non-zed-provider"
|
||||
}
|
||||
|
||||
fn display_name() -> &'static str {
|
||||
"Fake Non-Zed Provider"
|
||||
}
|
||||
|
||||
fn show_completions_in_menu() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_jump_to_edit() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn is_enabled(
|
||||
&self,
|
||||
_buffer: &gpui::Entity<language::Buffer>,
|
||||
|
||||
@@ -2705,6 +2705,11 @@ impl Editor {
|
||||
self.completion_provider = provider;
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn completion_provider(&self) -> Option<Rc<dyn CompletionProvider>> {
|
||||
self.completion_provider.clone()
|
||||
}
|
||||
|
||||
pub fn semantics_provider(&self) -> Option<Rc<dyn SemanticsProvider>> {
|
||||
self.semantics_provider.clone()
|
||||
}
|
||||
@@ -7760,8 +7765,14 @@ impl Editor {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let is_move =
|
||||
move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode;
|
||||
let supports_jump = self
|
||||
.edit_prediction_provider
|
||||
.as_ref()
|
||||
.map(|provider| provider.provider.supports_jump_to_edit())
|
||||
.unwrap_or(true);
|
||||
|
||||
let is_move = supports_jump
|
||||
&& (move_invalidation_row_range.is_some() || self.edit_predictions_hidden_for_vim_mode);
|
||||
let completion = if is_move {
|
||||
invalidation_row_range =
|
||||
move_invalidation_row_range.unwrap_or(edit_start_row..edit_end_row);
|
||||
@@ -8799,8 +8810,12 @@ impl Editor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let highlighted_edits =
|
||||
crate::edit_prediction_edit_text(&snapshot, edits, edit_preview.as_ref()?, false, cx);
|
||||
let highlighted_edits = if let Some(edit_preview) = edit_preview.as_ref() {
|
||||
crate::edit_prediction_edit_text(&snapshot, edits, edit_preview, false, cx)
|
||||
} else {
|
||||
// Fallback for providers without edit_preview
|
||||
crate::edit_prediction_fallback_text(edits, cx)
|
||||
};
|
||||
|
||||
let styled_text = highlighted_edits.to_styled_text(&style.text);
|
||||
let line_count = highlighted_edits.text.lines().count();
|
||||
@@ -9068,6 +9083,18 @@ impl Editor {
|
||||
let editor_bg_color = cx.theme().colors().editor_background;
|
||||
editor_bg_color.blend(accent_color.opacity(0.6))
|
||||
}
|
||||
fn get_prediction_provider_icon_name(
|
||||
provider: &Option<RegisteredEditPredictionProvider>,
|
||||
) -> IconName {
|
||||
match provider {
|
||||
Some(provider) => match provider.provider.name() {
|
||||
"copilot" => IconName::Copilot,
|
||||
"supermaven" => IconName::Supermaven,
|
||||
_ => IconName::ZedPredict,
|
||||
},
|
||||
None => IconName::ZedPredict,
|
||||
}
|
||||
}
|
||||
|
||||
fn render_edit_prediction_cursor_popover(
|
||||
&self,
|
||||
@@ -9080,6 +9107,7 @@ impl Editor {
|
||||
cx: &mut Context<Editor>,
|
||||
) -> Option<AnyElement> {
|
||||
let provider = self.edit_prediction_provider.as_ref()?;
|
||||
let provider_icon = Self::get_prediction_provider_icon_name(&self.edit_prediction_provider);
|
||||
|
||||
if provider.provider.needs_terms_acceptance(cx) {
|
||||
return Some(
|
||||
@@ -9106,7 +9134,7 @@ impl Editor {
|
||||
h_flex()
|
||||
.flex_1()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::ZedPredict))
|
||||
.child(Icon::new(provider_icon))
|
||||
.child(Label::new("Accept Terms of Service"))
|
||||
.child(div().w_full())
|
||||
.child(
|
||||
@@ -9122,12 +9150,8 @@ impl Editor {
|
||||
|
||||
let is_refreshing = provider.provider.is_refreshing(cx);
|
||||
|
||||
fn pending_completion_container() -> Div {
|
||||
h_flex()
|
||||
.h_full()
|
||||
.flex_1()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::ZedPredict))
|
||||
fn pending_completion_container(icon: IconName) -> Div {
|
||||
h_flex().h_full().flex_1().gap_2().child(Icon::new(icon))
|
||||
}
|
||||
|
||||
let completion = match &self.active_edit_prediction {
|
||||
@@ -9157,7 +9181,7 @@ impl Editor {
|
||||
Icon::new(IconName::ZedPredictUp)
|
||||
}
|
||||
}
|
||||
EditPrediction::Edit { .. } => Icon::new(IconName::ZedPredict),
|
||||
EditPrediction::Edit { .. } => Icon::new(provider_icon),
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
@@ -9224,15 +9248,15 @@ impl Editor {
|
||||
cx,
|
||||
)?,
|
||||
|
||||
None => {
|
||||
pending_completion_container().child(Label::new("...").size(LabelSize::Small))
|
||||
}
|
||||
None => pending_completion_container(provider_icon)
|
||||
.child(Label::new("...").size(LabelSize::Small)),
|
||||
},
|
||||
|
||||
None => pending_completion_container().child(Label::new("No Prediction")),
|
||||
None => pending_completion_container(provider_icon)
|
||||
.child(Label::new("...").size(LabelSize::Small)),
|
||||
};
|
||||
|
||||
let completion = if is_refreshing {
|
||||
let completion = if is_refreshing || self.active_edit_prediction.is_none() {
|
||||
completion
|
||||
.with_animation(
|
||||
"loading-completion",
|
||||
@@ -9332,23 +9356,35 @@ impl Editor {
|
||||
.child(Icon::new(arrow).color(Color::Muted).size(IconSize::Small))
|
||||
}
|
||||
|
||||
let supports_jump = self
|
||||
.edit_prediction_provider
|
||||
.as_ref()
|
||||
.map(|provider| provider.provider.supports_jump_to_edit())
|
||||
.unwrap_or(true);
|
||||
|
||||
match &completion.completion {
|
||||
EditPrediction::Move {
|
||||
target, snapshot, ..
|
||||
} => Some(
|
||||
h_flex()
|
||||
.px_2()
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(
|
||||
if target.text_anchor.to_point(&snapshot).row > cursor_point.row {
|
||||
Icon::new(IconName::ZedPredictDown)
|
||||
} else {
|
||||
Icon::new(IconName::ZedPredictUp)
|
||||
},
|
||||
)
|
||||
.child(Label::new("Jump to Edit")),
|
||||
),
|
||||
} => {
|
||||
if !supports_jump {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
h_flex()
|
||||
.px_2()
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(
|
||||
if target.text_anchor.to_point(&snapshot).row > cursor_point.row {
|
||||
Icon::new(IconName::ZedPredictDown)
|
||||
} else {
|
||||
Icon::new(IconName::ZedPredictUp)
|
||||
},
|
||||
)
|
||||
.child(Label::new("Jump to Edit")),
|
||||
)
|
||||
}
|
||||
|
||||
EditPrediction::Edit {
|
||||
edits,
|
||||
@@ -9358,14 +9394,13 @@ impl Editor {
|
||||
} => {
|
||||
let first_edit_row = edits.first()?.0.start.text_anchor.to_point(&snapshot).row;
|
||||
|
||||
let (highlighted_edits, has_more_lines) = crate::edit_prediction_edit_text(
|
||||
&snapshot,
|
||||
&edits,
|
||||
edit_preview.as_ref()?,
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
.first_line_preview();
|
||||
let (highlighted_edits, has_more_lines) =
|
||||
if let Some(edit_preview) = edit_preview.as_ref() {
|
||||
crate::edit_prediction_edit_text(&snapshot, &edits, edit_preview, true, cx)
|
||||
.first_line_preview()
|
||||
} else {
|
||||
crate::edit_prediction_fallback_text(&edits, cx).first_line_preview()
|
||||
};
|
||||
|
||||
let styled_text = gpui::StyledText::new(highlighted_edits.text)
|
||||
.with_default_highlights(&style.text, highlighted_edits.highlights);
|
||||
@@ -9376,11 +9411,13 @@ impl Editor {
|
||||
.child(styled_text)
|
||||
.when(has_more_lines, |parent| parent.child("…"));
|
||||
|
||||
let left = if first_edit_row != cursor_point.row {
|
||||
let left = if supports_jump && first_edit_row != cursor_point.row {
|
||||
render_relative_row_jump("", cursor_point.row, first_edit_row)
|
||||
.into_any_element()
|
||||
} else {
|
||||
Icon::new(IconName::ZedPredict).into_any_element()
|
||||
let icon_name =
|
||||
Editor::get_prediction_provider_icon_name(&self.edit_prediction_provider);
|
||||
Icon::new(icon_name).into_any_element()
|
||||
};
|
||||
|
||||
Some(
|
||||
@@ -14711,6 +14748,81 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unwrap_syntax_node(
|
||||
&mut self,
|
||||
_: &UnwrapSyntaxNode,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx);
|
||||
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let old_selections: Box<[_]> = self.selections.all::<usize>(cx).into();
|
||||
|
||||
let edits = old_selections
|
||||
.iter()
|
||||
// only consider the first selection for now
|
||||
.take(1)
|
||||
.map(|selection| {
|
||||
// Only requires two branches once if-let-chains stabilize (#53667)
|
||||
let selection_range = if !selection.is_empty() {
|
||||
selection.range()
|
||||
} else if let Some((_, ancestor_range)) =
|
||||
buffer.syntax_ancestor(selection.start..selection.end)
|
||||
{
|
||||
match ancestor_range {
|
||||
MultiOrSingleBufferOffsetRange::Single(range) => range,
|
||||
MultiOrSingleBufferOffsetRange::Multi(range) => range,
|
||||
}
|
||||
} else {
|
||||
selection.range()
|
||||
};
|
||||
|
||||
let mut new_range = selection_range.clone();
|
||||
while let Some((_, ancestor_range)) = buffer.syntax_ancestor(new_range.clone()) {
|
||||
new_range = match ancestor_range {
|
||||
MultiOrSingleBufferOffsetRange::Single(range) => range,
|
||||
MultiOrSingleBufferOffsetRange::Multi(range) => range,
|
||||
};
|
||||
if new_range.start < selection_range.start
|
||||
|| new_range.end > selection_range.end
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(selection, selection_range, new_range)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.transact(window, cx, |editor, window, cx| {
|
||||
for (_, child, parent) in &edits {
|
||||
let text = buffer.text_for_range(child.clone()).collect::<String>();
|
||||
editor.replace_text_in_range(Some(parent.clone()), &text, window, cx);
|
||||
}
|
||||
|
||||
editor.change_selections(
|
||||
SelectionEffects::scroll(Autoscroll::fit()),
|
||||
window,
|
||||
cx,
|
||||
|s| {
|
||||
s.select(
|
||||
edits
|
||||
.iter()
|
||||
.map(|(s, old, new)| Selection {
|
||||
id: s.id,
|
||||
start: new.start,
|
||||
end: new.start + old.len(),
|
||||
goal: SelectionGoal::None,
|
||||
reversed: s.reversed,
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
},
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Task<()> {
|
||||
if !EditorSettings::get_global(cx).gutter.runnables {
|
||||
self.clear_tasks();
|
||||
@@ -23195,6 +23307,33 @@ fn edit_prediction_edit_text(
|
||||
edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx)
|
||||
}
|
||||
|
||||
fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText {
|
||||
// Fallback for providers that don't provide edit_preview (like Copilot/Supermaven)
|
||||
// Just show the raw edit text with basic styling
|
||||
let mut text = String::new();
|
||||
let mut highlights = Vec::new();
|
||||
|
||||
let insertion_highlight_style = HighlightStyle {
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
for (_, edit_text) in edits {
|
||||
let start_offset = text.len();
|
||||
text.push_str(edit_text);
|
||||
let end_offset = text.len();
|
||||
|
||||
if start_offset < end_offset {
|
||||
highlights.push((start_offset..end_offset, insertion_highlight_style));
|
||||
}
|
||||
}
|
||||
|
||||
HighlightedText {
|
||||
text: text.into(),
|
||||
highlights,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diagnostic_style(severity: lsp::DiagnosticSeverity, colors: &StatusColors) -> Hsla {
|
||||
match severity {
|
||||
lsp::DiagnosticSeverity::ERROR => colors.error,
|
||||
|
||||
@@ -7969,6 +7969,38 @@ async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppConte
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_unwrap_syntax_node(cx: &mut gpui::TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let mut cx = EditorTestContext::new(cx).await;
|
||||
|
||||
let language = Arc::new(Language::new(
|
||||
LanguageConfig::default(),
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
));
|
||||
|
||||
cx.update_buffer(|buffer, cx| {
|
||||
buffer.set_language(Some(language), cx);
|
||||
});
|
||||
|
||||
cx.set_state(
|
||||
&r#"
|
||||
use mod1::mod2::{«mod3ˇ», mod4};
|
||||
"#
|
||||
.unindent(),
|
||||
);
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.unwrap_syntax_node(&UnwrapSyntaxNode, window, cx);
|
||||
});
|
||||
cx.assert_editor_state(
|
||||
&r#"
|
||||
use mod1::mod2::«mod3ˇ»;
|
||||
"#
|
||||
.unindent(),
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fold_function_bodies(cx: &mut TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
@@ -86,8 +86,6 @@ use util::post_inc;
|
||||
use util::{RangeExt, ResultExt, debug_panic};
|
||||
use workspace::{CollaboratorId, Workspace, item::Item, notifications::NotifyTaskExt};
|
||||
|
||||
const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 7.;
|
||||
|
||||
/// Determines what kinds of highlights should be applied to a lines background.
|
||||
#[derive(Clone, Copy, Default)]
|
||||
struct LineHighlightSpec {
|
||||
@@ -357,6 +355,7 @@ impl EditorElement {
|
||||
register_action(editor, window, Editor::toggle_comments);
|
||||
register_action(editor, window, Editor::select_larger_syntax_node);
|
||||
register_action(editor, window, Editor::select_smaller_syntax_node);
|
||||
register_action(editor, window, Editor::unwrap_syntax_node);
|
||||
register_action(editor, window, Editor::select_enclosing_symbol);
|
||||
register_action(editor, window, Editor::move_to_enclosing_bracket);
|
||||
register_action(editor, window, Editor::undo_selection);
|
||||
@@ -2427,10 +2426,13 @@ impl EditorElement {
|
||||
let editor = self.editor.read(cx);
|
||||
let blame = editor.blame.clone()?;
|
||||
let padding = {
|
||||
const INLINE_BLAME_PADDING_EM_WIDTHS: f32 = 6.;
|
||||
const INLINE_ACCEPT_SUGGESTION_EM_WIDTHS: f32 = 14.;
|
||||
|
||||
let mut padding = INLINE_BLAME_PADDING_EM_WIDTHS;
|
||||
let mut padding = ProjectSettings::get_global(cx)
|
||||
.git
|
||||
.inline_blame
|
||||
.unwrap_or_default()
|
||||
.padding as f32;
|
||||
|
||||
if let Some(edit_prediction) = editor.active_edit_prediction.as_ref() {
|
||||
match &edit_prediction.completion {
|
||||
@@ -2468,7 +2470,7 @@ impl EditorElement {
|
||||
let min_column_in_pixels = ProjectSettings::get_global(cx)
|
||||
.git
|
||||
.inline_blame
|
||||
.and_then(|settings| settings.min_column)
|
||||
.map(|settings| settings.min_column)
|
||||
.map(|col| self.column_pixels(col as usize, window))
|
||||
.unwrap_or(px(0.));
|
||||
let min_start = content_origin.x - scroll_pixel_position.x + min_column_in_pixels;
|
||||
@@ -3681,6 +3683,7 @@ impl EditorElement {
|
||||
.id("path header block")
|
||||
.size_full()
|
||||
.justify_between()
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
@@ -8363,7 +8366,13 @@ impl Element for EditorElement {
|
||||
})
|
||||
.flatten()?;
|
||||
let mut element = render_inline_blame_entry(blame_entry, &style, cx)?;
|
||||
let inline_blame_padding = INLINE_BLAME_PADDING_EM_WIDTHS * em_advance;
|
||||
let inline_blame_padding = ProjectSettings::get_global(cx)
|
||||
.git
|
||||
.inline_blame
|
||||
.unwrap_or_default()
|
||||
.padding
|
||||
as f32
|
||||
* em_advance;
|
||||
Some(
|
||||
element
|
||||
.layout_as_root(AvailableSpace::min_size(), window, cx)
|
||||
|
||||
@@ -3546,7 +3546,7 @@ pub mod tests {
|
||||
let excerpt_hints = excerpt_hints.read();
|
||||
for id in &excerpt_hints.ordered_hints {
|
||||
let hint = &excerpt_hints.hints_by_id[id];
|
||||
let mut label = hint.text();
|
||||
let mut label = hint.text().to_string();
|
||||
if hint.padding_left {
|
||||
label.insert(0, ' ');
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use crate::{
|
||||
Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DisplayPoint, DisplaySnapshot, Editor,
|
||||
EvaluateSelectedText, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation,
|
||||
GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, SelectionEffects,
|
||||
SelectionExt, ToDisplayPoint, ToggleCodeActions,
|
||||
GoToTypeDefinition, Paste, Rename, RevealInFileManager, RunToCursor, SelectMode,
|
||||
SelectionEffects, SelectionExt, ToDisplayPoint, ToggleCodeActions,
|
||||
actions::{Format, FormatSelections},
|
||||
selections_collection::SelectionsCollection,
|
||||
};
|
||||
@@ -200,15 +200,21 @@ pub fn deploy_context_menu(
|
||||
});
|
||||
|
||||
let evaluate_selection = window.is_action_available(&EvaluateSelectedText, cx);
|
||||
let run_to_cursor = window.is_action_available(&RunToCursor, cx);
|
||||
|
||||
ui::ContextMenu::build(window, cx, |menu, _window, _cx| {
|
||||
let builder = menu
|
||||
.on_blur_subscription(Subscription::new(|| {}))
|
||||
.when(evaluate_selection && has_selections, |builder| {
|
||||
builder
|
||||
.action("Evaluate Selection", Box::new(EvaluateSelectedText))
|
||||
.separator()
|
||||
.when(run_to_cursor, |builder| {
|
||||
builder.action("Run to Cursor", Box::new(RunToCursor))
|
||||
})
|
||||
.when(evaluate_selection && has_selections, |builder| {
|
||||
builder.action("Evaluate Selection", Box::new(EvaluateSelectedText))
|
||||
})
|
||||
.when(
|
||||
run_to_cursor || (evaluate_selection && has_selections),
|
||||
|builder| builder.separator(),
|
||||
)
|
||||
.action("Go to Definition", Box::new(GoToDefinition))
|
||||
.action("Go to Declaration", Box::new(GoToDeclaration))
|
||||
.action("Go to Type Definition", Box::new(GoToTypeDefinition))
|
||||
|
||||
@@ -191,7 +191,7 @@ impl Editor {
|
||||
|
||||
if let Some(language) = language {
|
||||
for signature in &mut signature_help.signatures {
|
||||
let text = Rope::from(signature.label.to_string());
|
||||
let text = Rope::from(signature.label.as_ref());
|
||||
let highlights = language
|
||||
.highlight_text(&text, 0..signature.label.len())
|
||||
.into_iter()
|
||||
|
||||
@@ -846,14 +846,12 @@ impl GitRepository for RealGitRepository {
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.spawn()?;
|
||||
child
|
||||
.stdin
|
||||
.take()
|
||||
.unwrap()
|
||||
.write_all(content.as_bytes())
|
||||
.await?;
|
||||
let mut stdin = child.stdin.take().unwrap();
|
||||
stdin.write_all(content.as_bytes()).await?;
|
||||
stdin.flush().await?;
|
||||
drop(stdin);
|
||||
let output = child.output().await?.stdout;
|
||||
let sha = String::from_utf8(output)?;
|
||||
let sha = str::from_utf8(&output)?.trim();
|
||||
|
||||
log::debug!("indexing SHA: {sha}, path {path:?}");
|
||||
|
||||
@@ -871,6 +869,7 @@ impl GitRepository for RealGitRepository {
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
} else {
|
||||
log::debug!("removing path {path:?} from the index");
|
||||
let output = new_smol_command(&git_binary_path)
|
||||
.current_dir(&working_directory)
|
||||
.envs(env.iter())
|
||||
@@ -921,6 +920,7 @@ impl GitRepository for RealGitRepository {
|
||||
for rev in &revs {
|
||||
write!(&mut stdin, "{rev}\n")?;
|
||||
}
|
||||
stdin.flush()?;
|
||||
drop(stdin);
|
||||
|
||||
let output = process.wait_with_output()?;
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use regex::Regex;
|
||||
use url::Url;
|
||||
|
||||
use git::{
|
||||
BuildCommitPermalinkParams, BuildPermalinkParams, GitHostingProvider, ParsedGitRemote,
|
||||
RemoteUrl,
|
||||
PullRequest, RemoteUrl,
|
||||
};
|
||||
|
||||
fn pull_request_regex() -> &'static Regex {
|
||||
static PULL_REQUEST_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
// This matches Bitbucket PR reference pattern: (pull request #xxx)
|
||||
Regex::new(r"\(pull request #(\d+)\)").unwrap()
|
||||
});
|
||||
&PULL_REQUEST_REGEX
|
||||
}
|
||||
|
||||
pub struct Bitbucket {
|
||||
name: String,
|
||||
base_url: Url,
|
||||
@@ -96,6 +106,22 @@ impl GitHostingProvider for Bitbucket {
|
||||
);
|
||||
permalink
|
||||
}
|
||||
|
||||
fn extract_pull_request(&self, remote: &ParsedGitRemote, message: &str) -> Option<PullRequest> {
|
||||
// Check first line of commit message for PR references
|
||||
let first_line = message.lines().next()?;
|
||||
|
||||
// Try to match against our PR patterns
|
||||
let capture = pull_request_regex().captures(first_line)?;
|
||||
let number = capture.get(1)?.as_str().parse::<u32>().ok()?;
|
||||
|
||||
// Construct the PR URL in Bitbucket format
|
||||
let mut url = self.base_url();
|
||||
let path = format!("/{}/{}/pull-requests/{}", remote.owner, remote.repo, number);
|
||||
url.set_path(&path);
|
||||
|
||||
Some(PullRequest { number, url })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -203,4 +229,34 @@ mod tests {
|
||||
"https://bitbucket.org/zed-industries/zed/src/f00b4r/main.rs#lines-24:48";
|
||||
assert_eq!(permalink.to_string(), expected_url.to_string())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bitbucket_pull_requests() {
|
||||
use indoc::indoc;
|
||||
|
||||
let remote = ParsedGitRemote {
|
||||
owner: "zed-industries".into(),
|
||||
repo: "zed".into(),
|
||||
};
|
||||
|
||||
let bitbucket = Bitbucket::public_instance();
|
||||
|
||||
// Test message without PR reference
|
||||
let message = "This does not contain a pull request";
|
||||
assert!(bitbucket.extract_pull_request(&remote, message).is_none());
|
||||
|
||||
// Pull request number at end of first line
|
||||
let message = indoc! {r#"
|
||||
Merged in feature-branch (pull request #123)
|
||||
|
||||
Some detailed description of the changes.
|
||||
"#};
|
||||
|
||||
let pr = bitbucket.extract_pull_request(&remote, message).unwrap();
|
||||
assert_eq!(pr.number, 123);
|
||||
assert_eq!(
|
||||
pr.url.as_str(),
|
||||
"https://bitbucket.org/zed-industries/zed/pull-requests/123"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,6 +180,7 @@ impl Focusable for BranchList {
|
||||
impl Render for BranchList {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.key_context("GitBranchSelector")
|
||||
.w(self.width)
|
||||
.on_modifiers_changed(cx.listener(Self::handle_modifiers_changed))
|
||||
.child(self.picker.clone())
|
||||
|
||||
@@ -109,7 +109,10 @@ impl Focusable for RepositorySelector {
|
||||
|
||||
impl Render for RepositorySelector {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div().w(self.width).child(self.picker.clone())
|
||||
div()
|
||||
.key_context("GitRepositorySelector")
|
||||
.w(self.width)
|
||||
.child(self.picker.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -308,10 +308,14 @@ impl Settings for LineIndicatorFormat {
|
||||
type FileContent = Option<LineIndicatorFormatContent>;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> anyhow::Result<Self> {
|
||||
let format = [sources.release_channel, sources.user]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
let format = [
|
||||
sources.release_channel,
|
||||
sources.operating_system,
|
||||
sources.user,
|
||||
]
|
||||
.into_iter()
|
||||
.find_map(|value| value.copied().flatten())
|
||||
.unwrap_or(sources.default.ok_or_else(Self::missing_default)?);
|
||||
|
||||
Ok(format.0)
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ smallvec.workspace = true
|
||||
smol.workspace = true
|
||||
strum.workspace = true
|
||||
sum_tree.workspace = true
|
||||
taffy = "=0.8.3"
|
||||
taffy = "=0.9.0"
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -79,6 +79,14 @@ impl<T> Task<T> {
|
||||
Task(TaskState::Spawned(task)) => task.detach(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the task has run to completion.
|
||||
pub fn is_finished(&self) -> bool {
|
||||
match self {
|
||||
Task(TaskState::Ready(_)) => true,
|
||||
Task(TaskState::Spawned(task)) => task.is_finished(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, T> Task<Result<T, E>>
|
||||
|
||||
@@ -150,7 +150,7 @@ pub struct MouseClickEvent {
|
||||
}
|
||||
|
||||
/// A click event that was generated by a keyboard button being pressed and released.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct KeyboardClickEvent {
|
||||
/// The keyboard button that was pressed to trigger the click.
|
||||
pub button: KeyboardButton,
|
||||
@@ -168,6 +168,12 @@ pub enum ClickEvent {
|
||||
Keyboard(KeyboardClickEvent),
|
||||
}
|
||||
|
||||
impl Default for ClickEvent {
|
||||
fn default() -> Self {
|
||||
ClickEvent::Keyboard(KeyboardClickEvent::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl ClickEvent {
|
||||
/// Returns the modifiers that were held during the click event
|
||||
///
|
||||
@@ -256,9 +262,10 @@ impl ClickEvent {
|
||||
}
|
||||
|
||||
/// An enum representing the keyboard button that was pressed for a click event.
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug)]
|
||||
#[derive(Hash, PartialEq, Eq, Copy, Clone, Debug, Default)]
|
||||
pub enum KeyboardButton {
|
||||
/// Enter key was clicked
|
||||
#[default]
|
||||
Enter,
|
||||
/// Space key was clicked
|
||||
Space,
|
||||
|
||||
@@ -4248,6 +4248,25 @@ impl Window {
|
||||
.on_action(action_type, Rc::new(listener));
|
||||
}
|
||||
|
||||
/// Register an action listener on the window for the next frame if the condition is true.
|
||||
/// The type of action is determined by the first parameter of the given listener.
|
||||
/// When the next frame is rendered the listener will be cleared.
|
||||
///
|
||||
/// This is a fairly low-level method, so prefer using action handlers on elements unless you have
|
||||
/// a specific need to register a global listener.
|
||||
pub fn on_action_when(
|
||||
&mut self,
|
||||
condition: bool,
|
||||
action_type: TypeId,
|
||||
listener: impl Fn(&dyn Any, DispatchPhase, &mut Window, &mut App) + 'static,
|
||||
) {
|
||||
if condition {
|
||||
self.next_frame
|
||||
.dispatch_tree
|
||||
.on_action(action_type, Rc::new(listener));
|
||||
}
|
||||
}
|
||||
|
||||
/// Read information about the GPU backing this window.
|
||||
/// Currently returns None on Mac and Windows.
|
||||
pub fn gpu_specs(&self) -> Option<GpuSpecs> {
|
||||
|
||||
@@ -261,7 +261,6 @@ pub enum IconName {
|
||||
TodoComplete,
|
||||
TodoPending,
|
||||
TodoProgress,
|
||||
ToolBulb,
|
||||
ToolCopy,
|
||||
ToolDeleteFile,
|
||||
ToolDiagnostics,
|
||||
@@ -273,6 +272,7 @@ pub enum IconName {
|
||||
ToolRegex,
|
||||
ToolSearch,
|
||||
ToolTerminal,
|
||||
ToolThink,
|
||||
ToolWeb,
|
||||
Trash,
|
||||
Triangle,
|
||||
|
||||
@@ -941,8 +941,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||
request,
|
||||
model.id(),
|
||||
model.supports_parallel_tool_calls(),
|
||||
model.supports_prompt_cache_key(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
|
||||
@@ -14,7 +14,7 @@ use language_model::{
|
||||
RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
|
||||
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
@@ -45,7 +45,6 @@ pub struct AvailableModel {
|
||||
pub max_tokens: u64,
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
pub reasoning_effort: Option<ReasoningEffort>,
|
||||
}
|
||||
|
||||
pub struct OpenAiLanguageModelProvider {
|
||||
@@ -214,7 +213,6 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
reasoning_effort: model.reasoning_effort.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -303,25 +301,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
use open_ai::Model;
|
||||
match &self.model {
|
||||
Model::FourOmni
|
||||
| Model::FourOmniMini
|
||||
| Model::FourPointOne
|
||||
| Model::FourPointOneMini
|
||||
| Model::FourPointOneNano
|
||||
| Model::Five
|
||||
| Model::FiveMini
|
||||
| Model::FiveNano
|
||||
| Model::O1
|
||||
| Model::O3
|
||||
| Model::O4Mini => true,
|
||||
Model::ThreePointFiveTurbo
|
||||
| Model::Four
|
||||
| Model::FourTurbo
|
||||
| Model::O3Mini
|
||||
| Model::Custom { .. } => false,
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
@@ -370,9 +350,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
request,
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.model.supports_prompt_cache_key(),
|
||||
self.max_output_tokens(),
|
||||
self.model.reasoning_effort(),
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
@@ -387,9 +365,7 @@ pub fn into_open_ai(
|
||||
request: LanguageModelRequest,
|
||||
model_id: &str,
|
||||
supports_parallel_tool_calls: bool,
|
||||
supports_prompt_cache_key: bool,
|
||||
max_output_tokens: Option<u64>,
|
||||
reasoning_effort: Option<ReasoningEffort>,
|
||||
) -> open_ai::Request {
|
||||
let stream = !model_id.starts_with("o1-");
|
||||
|
||||
@@ -479,11 +455,6 @@ pub fn into_open_ai(
|
||||
} else {
|
||||
None
|
||||
},
|
||||
prompt_cache_key: if supports_prompt_cache_key {
|
||||
request.thread_id
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
@@ -500,7 +471,6 @@ pub fn into_open_ai(
|
||||
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
|
||||
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
|
||||
}),
|
||||
reasoning_effort,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -355,16 +355,7 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let supports_parallel_tool_call = true;
|
||||
let supports_prompt_cache_key = false;
|
||||
let request = into_open_ai(
|
||||
request,
|
||||
&self.model.name,
|
||||
supports_parallel_tool_call,
|
||||
supports_prompt_cache_key,
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = OpenAiEventMapper::new();
|
||||
|
||||
@@ -355,9 +355,7 @@ impl LanguageModel for VercelLanguageModel {
|
||||
request,
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.model.supports_prompt_cache_key(),
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
|
||||
@@ -359,9 +359,7 @@ impl LanguageModel for XAiLanguageModel {
|
||||
request,
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.model.supports_prompt_cache_key(),
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
|
||||
@@ -86,7 +86,10 @@ impl LanguageSelector {
|
||||
|
||||
impl Render for LanguageSelector {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex().w(rems(34.)).child(self.picker.clone())
|
||||
v_flex()
|
||||
.key_context("LanguageSelector")
|
||||
.w(rems(34.))
|
||||
.child(self.picker.clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -75,9 +75,6 @@ impl super::LspAdapter for CLspAdapter {
|
||||
&*version.downcast::<GitHubLspBinaryVersion>().unwrap();
|
||||
let version_dir = container_dir.join(format!("clangd_{name}"));
|
||||
let binary_path = version_dir.join("bin/clangd");
|
||||
let expected_digest = digest
|
||||
.as_ref()
|
||||
.and_then(|digest| digest.strip_prefix("sha256:"));
|
||||
|
||||
let binary = LanguageServerBinary {
|
||||
path: binary_path.clone(),
|
||||
@@ -102,9 +99,7 @@ impl super::LspAdapter for CLspAdapter {
|
||||
log::warn!("Unable to run {binary_path:?} asset, redownloading: {err}",)
|
||||
})
|
||||
};
|
||||
if let (Some(actual_digest), Some(expected_digest)) =
|
||||
(&metadata.digest, expected_digest)
|
||||
{
|
||||
if let (Some(actual_digest), Some(expected_digest)) = (&metadata.digest, digest) {
|
||||
if actual_digest == expected_digest {
|
||||
if validity_check().await.is_ok() {
|
||||
return Ok(binary);
|
||||
|
||||