Compare commits

..

2 Commits

Author SHA1 Message Date
Richard Feldman
e700d3dfa5 Fix test 2025-05-28 15:08:20 -04:00
Richard Feldman
9a89ffd1fa Make single-file review disabled by default. 2025-05-28 14:20:28 -04:00
224 changed files with 2755 additions and 7887 deletions

View File

@@ -482,9 +482,7 @@ jobs:
- macos_tests
- windows_clippy
- windows_tests
if: |
github.repository_owner == 'zed-industries' &&
always()
if: always()
steps:
- name: Check all tests passed
run: |
@@ -716,7 +714,6 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
nix-build:
name: Build with Nix
uses: ./.github/workflows/nix.yml
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
with:

View File

@@ -56,7 +56,6 @@ jobs:
name: zed
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
pushFilter: "${{ inputs.cachix-filter }}"
cachixArgs: '-v'
- run: nix build .#${{ inputs.flake-output }} -L --accept-flake-config

View File

@@ -168,7 +168,6 @@ jobs:
run: script/upload-nightly linux-targz
bundle-nix:
name: Build and cache Nix package
needs: tests
uses: ./.github/workflows/nix.yml

View File

@@ -2,11 +2,16 @@
{
"label": "Debug Zed (CodeLLDB)",
"adapter": "CodeLLDB",
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch"
},
{
"label": "Debug Zed (GDB)",
"adapter": "GDB",
"build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"initialize_args": {
"stopAtBeginningOfMainSubprogram": true
}
}
]

14
Cargo.lock generated
View File

@@ -525,7 +525,6 @@ dependencies = [
"fuzzy",
"gpui",
"indexed_docs",
"indoc",
"language",
"language_model",
"languages",
@@ -560,7 +559,6 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
]
[[package]]
@@ -685,7 +683,6 @@ dependencies = [
"language_model",
"language_models",
"log",
"lsp",
"markdown",
"open",
"paths",
@@ -2201,7 +2198,6 @@ dependencies = [
"editor",
"gpui",
"itertools 0.14.0",
"settings",
"theme",
"ui",
"workspace",
@@ -5050,7 +5046,6 @@ dependencies = [
"util",
"uuid",
"workspace-hack",
"zed_llm_client",
]
[[package]]
@@ -6153,7 +6148,6 @@ dependencies = [
"workspace",
"workspace-hack",
"zed_actions",
"zed_llm_client",
"zlog",
]
@@ -7071,7 +7065,6 @@ dependencies = [
"image",
"inventory",
"itertools 0.14.0",
"libc",
"log",
"lyon",
"media",
@@ -8937,7 +8930,6 @@ dependencies = [
"async-compression",
"async-tar",
"async-trait",
"chrono",
"collections",
"dap",
"futures 0.3.31",
@@ -8991,7 +8983,6 @@ dependencies = [
"tree-sitter-yaml",
"unindent",
"util",
"which 6.0.3",
"workspace",
"workspace-hack",
]
@@ -15591,7 +15582,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"hex",
"log",
"parking_lot",
"pretty_assertions",
"proto",
@@ -19885,9 +19875,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.4"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
checksum = "22a8b9575b215536ed8ad254ba07171e4e13bd029eda3b54cca4b184d2768050"
dependencies = [
"anyhow",
"serde",

View File

@@ -617,7 +617,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "0.8.4"
zed_llm_client = "0.8.3"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -1,4 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M11 13H10.4C9.76346 13 9.15302 12.7893 8.70296 12.4142C8.25284 12.0391 8 11.5304 8 11V5C8 4.46957 8.25284 3.96086 8.70296 3.58579C9.15302 3.21071 9.76346 3 10.4 3H11" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5 13H5.6C6.23654 13 6.84698 12.7893 7.29704 12.4142C7.74716 12.0391 8 11.5304 8 11V5C8 4.46957 7.74716 3.96086 7.29704 3.58579C6.84698 3.21071 6.23654 3 5.6 3H5" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M17 20H16C14.9391 20 13.9217 19.6629 13.1716 19.0627C12.4214 18.4626 12 17.6487 12 16.8V7.2C12 6.35131 12.4214 5.53737 13.1716 4.93726C13.9217 4.33714 14.9391 4 16 4H17" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M7 20H8C9.06087 20 10.0783 19.5786 10.8284 18.8284C11.5786 18.0783 12 17.0609 12 16V15" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M7 4H8C9.06087 4 10.0783 4.42143 10.8284 5.17157C11.5786 5.92172 12 6.93913 12 8V9" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 617 B

After

Width:  |  Height:  |  Size: 715 B

View File

@@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 3L13 8L4 13V3Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 214 B

View File

@@ -1,8 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 12C2.35977 11.85 1 10.575 1 9" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M1.00875 15.2C1.00875 13.625 0.683456 12.275 4.00001 12.2" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M7 9C7 10.575 5.62857 11.85 4 12" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4 12.2C6.98117 12.2 7 13.625 7 15.2" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<rect x="2.5" y="9" width="3" height="6" rx="1.5" fill="black"/>
<path d="M9 10L13 8L4 3V7.5" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 813 B

View File

@@ -1,8 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 5H4" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<path d="M8 5L14 5" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<path d="M12 11L14 11" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<path d="M2 11H8" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<circle cx="6" cy="5" r="2" fill="black" fill-opacity="0.1" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<circle cx="10" cy="11" r="2" fill="black" fill-opacity="0.1" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
<svg width="17" height="17" viewBox="0 0 17 17" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.36667 3.79167C5.53364 3.79167 4.85833 4.46697 4.85833 5.3C4.85833 6.13303 5.53364 6.80833 6.36667 6.80833C7.1997 6.80833 7.875 6.13303 7.875 5.3C7.875 4.46697 7.1997 3.79167 6.36667 3.79167ZM2.1 5.925H3.67944C3.9626 7.14732 5.05824 8.05833 6.36667 8.05833C7.67509 8.05833 8.77073 7.14732 9.05389 5.925H14.9C15.2452 5.925 15.525 5.64518 15.525 5.3C15.525 4.95482 15.2452 4.675 14.9 4.675H9.05389C8.77073 3.45268 7.67509 2.54167 6.36667 2.54167C5.05824 2.54167 3.9626 3.45268 3.67944 4.675H2.1C1.75482 4.675 1.475 4.95482 1.475 5.3C1.475 5.64518 1.75482 5.925 2.1 5.925ZM13.3206 12.325C13.0374 13.5473 11.9418 14.4583 10.6333 14.4583C9.32491 14.4583 8.22927 13.5473 7.94611 12.325H2.1C1.75482 12.325 1.475 12.0452 1.475 11.7C1.475 11.3548 1.75482 11.075 2.1 11.075H7.94611C8.22927 9.85268 9.32491 8.94167 10.6333 8.94167C11.9418 8.94167 13.0374 9.85268 13.3206 11.075H14.9C15.2452 11.075 15.525 11.3548 15.525 11.7C15.525 12.0452 15.2452 12.325 14.9 12.325H13.3206ZM9.125 11.7C9.125 10.867 9.8003 10.1917 10.6333 10.1917C11.4664 10.1917 12.1417 10.867 12.1417 11.7C12.1417 12.533 11.4664 13.2083 10.6333 13.2083C9.8003 13.2083 9.125 12.533 9.125 11.7Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 657 B

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -1,5 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 2L6.72534 5.87534C6.6601 6.07367 6.5492 6.25392 6.40155 6.40155C6.25392 6.5492 6.07367 6.6601 5.87534 6.72534L2 8L5.87534 9.27466C6.07367 9.3399 6.25392 9.4508 6.40155 9.59845C6.5492 9.74608 6.6601 9.92633 6.72534 10.1247L8 14L9.27466 10.1247C9.3399 9.92633 9.4508 9.74608 9.59845 9.59845C9.74608 9.4508 9.92633 9.3399 10.1247 9.27466L14 8L10.1247 6.72534C9.92633 6.6601 9.74608 6.5492 9.59845 6.40155C9.4508 6.25392 9.3399 6.07367 9.27466 5.87534L8 2Z" fill="black" fill-opacity="0.15" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M3.33334 2V4.66666M2 3.33334H4.66666" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.6665 11.3333V14M11.3333 12.6666H13.9999" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7 1.75L5.88467 5.14092C5.82759 5.31446 5.73055 5.47218 5.60136 5.60136C5.47218 5.73055 5.31446 5.82759 5.14092 5.88467L1.75 7L5.14092 8.11533C5.31446 8.17241 5.47218 8.26945 5.60136 8.39864C5.73055 8.52782 5.82759 8.68554 5.88467 8.85908L7 12.25L8.11533 8.85908C8.17241 8.68554 8.26945 8.52782 8.39864 8.39864C8.52782 8.26945 8.68554 8.17241 8.85908 8.11533L12.25 7L8.85908 5.88467C8.68554 5.82759 8.52782 5.73055 8.39864 5.60136C8.26945 5.47218 8.17241 5.31446 8.11533 5.14092L7 1.75Z" fill="black" fill-opacity="0.15" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2.91667 1.75V4.08333M1.75 2.91667H4.08333" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M11.0833 9.91667V12.25M9.91667 11.0833H12.25" stroke="black" stroke-opacity="0.75" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 998 B

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@@ -127,7 +127,9 @@
"shift-f10": "editor::OpenContextMenu",
"ctrl-shift-e": "editor::ToggleEditPrediction",
"f9": "editor::ToggleBreakpoint",
"shift-f9": "editor::EditLogBreakpoint"
"shift-f9": "editor::EditLogBreakpoint",
"ctrl-shift-backspace": "editor::GoToPreviousChange",
"ctrl-shift-alt-backspace": "editor::GoToNextChange"
}
},
{
@@ -146,8 +148,6 @@
"ctrl->": "assistant::QuoteSelection",
"ctrl-<": "assistant::InsertIntoEditor",
"ctrl-alt-e": "editor::SelectEnclosingSymbol",
"ctrl-shift-backspace": "editor::GoToPreviousChange",
"ctrl-shift-alt-backspace": "editor::GoToNextChange",
"alt-enter": "editor::OpenSelectionsInMultibuffer"
}
},
@@ -244,14 +244,13 @@
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "agent::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-j": "agent::ToggleNavigationMenu",
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl-alt-e": "agent::RemoveAllContext",
"ctrl-shift-e": "project_panel::ToggleFocus",
"ctrl-shift-enter": "agent::ContinueThread",
"alt-enter": "agent::ContinueWithBurnMode",
"ctrl-alt-b": "agent::ToggleBurnMode"
"alt-enter": "agent::ContinueWithBurnMode"
}
},
{
@@ -1019,12 +1018,5 @@
"bindings": {
"enter": "menu::Confirm"
}
},
{
"context": "RunModal",
"bindings": {
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
}
]

View File

@@ -279,14 +279,13 @@
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "agent::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-j": "agent::ToggleNavigationMenu",
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"cmd-alt-e": "agent::RemoveAllContext",
"cmd-shift-e": "project_panel::ToggleFocus",
"cmd-shift-enter": "agent::ContinueThread",
"alt-enter": "agent::ContinueWithBurnMode",
"cmd-alt-b": "agent::ToggleBurnMode"
"alt-enter": "agent::ContinueWithBurnMode"
}
},
{
@@ -546,7 +545,9 @@
"cmd-\\": "pane::SplitRight",
"cmd-k v": "markdown::OpenPreviewToTheSide",
"cmd-shift-v": "markdown::OpenPreview",
"ctrl-cmd-c": "editor::DisplayCursorNames"
"ctrl-cmd-c": "editor::DisplayCursorNames",
"cmd-shift-backspace": "editor::GoToPreviousChange",
"cmd-shift-alt-backspace": "editor::GoToNextChange"
}
},
{
@@ -554,9 +555,7 @@
"use_key_equivalents": true,
"bindings": {
"cmd-shift-o": "outline::Toggle",
"ctrl-g": "go_to_line::Toggle",
"cmd-shift-backspace": "editor::GoToPreviousChange",
"cmd-shift-alt-backspace": "editor::GoToNextChange"
"ctrl-g": "go_to_line::Toggle"
}
},
{
@@ -1109,13 +1108,5 @@
"bindings": {
"enter": "menu::Confirm"
}
},
{
"context": "RunModal",
"use_key_equivalents": true,
"bindings": {
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
}
]

View File

@@ -1,85 +0,0 @@
[
// Cursor for MacOS. See: https://docs.cursor.com/kbd
{
"context": "Workspace",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "agent::ToggleFocus",
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-l": "agent::ToggleFocus",
"ctrl-shift-l": "agent::ToggleFocus",
"ctrl-alt-b": "agent::ToggleFocus",
"ctrl-shift-j": "agent::OpenConfiguration"
}
},
{
"context": "Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "agent::ToggleFocus",
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
"ctrl-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
"ctrl-k": "assistant::InlineAssist",
"ctrl-shift-k": "assistant::InsertIntoEditor"
}
},
{
"context": "InlineAssistEditor",
"use_key_equivalents": true,
"bindings": {
"ctrl-shift-backspace": "editor::Cancel"
// "alt-enter": // Quick Question
// "ctrl-shift-enter": // Full File Context
// "ctrl-shift-k": // Toggle input focus (editor <> inline assist)
}
},
{
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
"use_key_equivalents": true,
"bindings": {
"ctrl-i": "workspace::ToggleRightDock",
"ctrl-shift-i": "workspace::ToggleRightDock",
"ctrl-l": "workspace::ToggleRightDock",
"ctrl-shift-l": "workspace::ToggleRightDock",
"ctrl-alt-b": "workspace::ToggleRightDock",
"ctrl-w": "workspace::ToggleRightDock", // technically should close chat
"ctrl-.": "agent::ToggleProfileSelector",
"ctrl-/": "agent::ToggleModelSelector",
"ctrl-shift-backspace": "editor::Cancel",
"ctrl-r": "agent::NewThread",
"ctrl-shift-v": "editor::Paste",
"ctrl-shift-k": "assistant::InsertIntoEditor"
// "escape": "agent::ToggleFocus"
///// Enable when Zed supports multiple thread tabs
// "ctrl-t": // new thread tab
// "ctrl-[": // next thread tab
// "ctrl-]": // next thread tab
///// Enable if Zed adds support for keyboard navigation of thread elements
// "tab": // cycle to next message
// "shift-tab": // cycle to previous message
}
},
{
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
"ctrl-enter": "agent::KeepAll",
"ctrl-backspace": "agent::RejectAll"
}
},
{
"context": "Editor && mode == full && edit_prediction",
"use_key_equivalents": true,
"bindings": {
"ctrl-right": "editor::AcceptPartialEditPrediction"
}
},
{
"context": "Terminal",
"use_key_equivalents": true,
"bindings": {
"ctrl-k": "assistant::InlineAssist"
}
}
]

View File

@@ -51,11 +51,7 @@
"ctrl-k ctrl-l": "editor::ConvertToLowerCase",
"shift-alt-m": "markdown::OpenPreviewToTheSide",
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
"ctrl-delete": "editor::DeleteToNextWordEnd",
"ctrl-right": "editor::MoveToNextSubwordEnd",
"ctrl-left": "editor::MoveToPreviousSubwordStart",
"ctrl-shift-right": "editor::SelectToNextSubwordEnd",
"ctrl-shift-left": "editor::SelectToPreviousSubwordStart"
"ctrl-delete": "editor::DeleteToNextWordEnd"
}
},
{

View File

@@ -1,85 +0,0 @@
[
// Cursor for MacOS. See: https://docs.cursor.com/kbd
{
"context": "Workspace",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "agent::ToggleFocus",
"cmd-shift-i": "agent::ToggleFocus",
"cmd-l": "agent::ToggleFocus",
"cmd-shift-l": "agent::ToggleFocus",
"cmd-alt-b": "agent::ToggleFocus",
"cmd-shift-j": "agent::OpenConfiguration"
}
},
{
"context": "Editor && mode == full",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "agent::ToggleFocus",
"cmd-shift-i": "agent::ToggleFocus",
"cmd-shift-l": "assistant::QuoteSelection", // In cursor uses "Ask" mode
"cmd-l": "assistant::QuoteSelection", // In cursor uses "Agent" mode
"cmd-k": "assistant::InlineAssist",
"cmd-shift-k": "assistant::InsertIntoEditor"
}
},
{
"context": "InlineAssistEditor",
"use_key_equivalents": true,
"bindings": {
"cmd-shift-backspace": "editor::Cancel"
// "alt-enter": // Quick Question
// "cmd-shift-enter": // Full File Context
// "cmd-shift-k": // Toggle input focus (editor <> inline assist)
}
},
{
"context": "AgentPanel || ContextEditor || (MessageEditor > Editor)",
"use_key_equivalents": true,
"bindings": {
"cmd-i": "workspace::ToggleRightDock",
"cmd-shift-i": "workspace::ToggleRightDock",
"cmd-l": "workspace::ToggleRightDock",
"cmd-shift-l": "workspace::ToggleRightDock",
"cmd-alt-b": "workspace::ToggleRightDock",
"cmd-w": "workspace::ToggleRightDock", // technically should close chat
"cmd-.": "agent::ToggleProfileSelector",
"cmd-/": "agent::ToggleModelSelector",
"cmd-shift-backspace": "editor::Cancel",
"cmd-r": "agent::NewThread",
"cmd-shift-v": "editor::Paste",
"cmd-shift-k": "assistant::InsertIntoEditor"
// "escape": "agent::ToggleFocus"
///// Enable when Zed supports multiple thread tabs
// "cmd-t": // new thread tab
// "cmd-[": // next thread tab
// "cmd-]": // next thread tab
///// Enable if Zed adds support for keyboard navigation of thread elements
// "tab": // cycle to next message
// "shift-tab": // cycle to previous message
}
},
{
"context": "Editor && editor_agent_diff",
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "agent::KeepAll",
"cmd-backspace": "agent::RejectAll"
}
},
{
"context": "Editor && mode == full && edit_prediction",
"use_key_equivalents": true,
"bindings": {
"cmd-right": "editor::AcceptPartialEditPrediction"
}
},
{
"context": "Terminal",
"use_key_equivalents": true,
"bindings": {
"cmd-k": "assistant::InlineAssist"
}
}
]

View File

@@ -53,11 +53,7 @@
"cmd-shift-j": "editor::JoinLines",
"shift-alt-m": "markdown::OpenPreviewToTheSide",
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
"ctrl-delete": "editor::DeleteToNextWordEnd",
"ctrl-right": "editor::MoveToNextSubwordEnd",
"ctrl-left": "editor::MoveToPreviousSubwordStart",
"ctrl-shift-right": "editor::SelectToNextSubwordEnd",
"ctrl-shift-left": "editor::SelectToPreviousSubwordStart"
"ctrl-delete": "editor::DeleteToNextWordEnd"
}
},
{

View File

@@ -714,7 +714,7 @@
"version": "2",
// Whether the agent is enabled.
"enabled": true,
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
/// What completion mode to start new threads in, if available. Can be 'normal' or 'max'.
"preferred_completion_mode": "normal",
// Whether to show the agent panel button in the status bar.
"button": true,
@@ -765,7 +765,7 @@
// When enabled, the agent will stream edits.
"stream_edits": false,
// When enabled, agent edits will be displayed in single-file editors for review
"single_file_review": true,
"single_file_review": false,
// When enabled, show voting thumbs for feedback on agent edits.
"enable_feedback": true,
"default_profile": "write",

View File

@@ -311,31 +311,6 @@ impl ActivityIndicator {
});
}
if let Some(session) = self
.project
.read(cx)
.dap_store()
.read(cx)
.sessions()
.find(|s| !s.read(cx).is_started())
{
return Some(Content {
icon: Some(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element(),
),
message: format!("Debug: {}", session.read(cx).adapter()),
tooltip_message: Some(session.read(cx).label().to_string()),
on_click: None,
});
}
let current_job = self
.project
.read(cx)

View File

@@ -55,7 +55,6 @@ use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
pub struct ActiveThread {
context_store: Entity<ContextStore>,
@@ -1437,7 +1436,6 @@ impl ActiveThread {
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![request_message],
tools: vec![],
@@ -1535,22 +1533,9 @@ impl ActiveThread {
});
}
fn cancel_editing_message(
&mut self,
_: &menu::Cancel,
window: &mut Window,
cx: &mut Context<Self>,
) {
fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
self.editing_message.take();
cx.notify();
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.focus_handle(cx).focus(window);
}
});
}
}
fn confirm_editing_message(
@@ -1612,12 +1597,7 @@ impl ActiveThread {
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
model.model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
thread.send_to_model(model.model, Some(window.window_handle()), cx);
});
this._load_edited_message_context_task = None;
cx.notify();
@@ -1838,7 +1818,6 @@ impl ActiveThread {
let colors = cx.theme().colors();
let editor_bg_color = colors.editor_background;
let panel_bg = colors.panel_background;
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::DocumentText)
.icon_size(IconSize::XSmall)
@@ -1859,6 +1838,7 @@ impl ActiveThread {
const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = thread.is_turn_end(ix);
let feedback_container = h_flex()
.group("feedback_container")
.mt_1()
@@ -2155,14 +2135,16 @@ impl ActiveThread {
message_id > *editing_message_id
});
let panel_background = cx.theme().colors().panel_background;
let backdrop = div()
.id(("backdrop", ix))
.size_full()
.id("backdrop")
.stop_mouse_events_except_scroll()
.absolute()
.inset_0()
.bg(panel_bg)
.size_full()
.bg(panel_background)
.opacity(0.8)
.block_mouse_except_scroll()
.on_click(cx.listener(Self::handle_cancel_click));
v_flex()
@@ -3709,8 +3691,7 @@ mod tests {
// Stream response to user message
thread.update(cx, |thread, cx| {
let request =
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx);
let request = thread.to_completion_request(model.clone(), cx);
thread.stream_completion(request, model, cx.active_window(), cx)
});
// Follow the agent

View File

@@ -89,7 +89,6 @@ actions!(
ResetTrialEndUpsell,
ContinueThread,
ContinueWithBurnMode,
ToggleBurnMode,
]
);

View File

@@ -699,7 +699,7 @@ fn render_diff_hunk_controls(
.rounded_b_md()
.bg(cx.theme().colors().editor_background)
.gap_1()
.block_mouse_except_scroll()
.stop_mouse_events_except_scroll()
.shadow_md()
.children(vec![
Button::new(("reject", row as u64), "Reject")
@@ -1919,6 +1919,13 @@ mod tests {
EditorSettings::register(cx);
language_model::init_settings(cx);
workspace::register_project_item::<Editor>(cx);
// Explicitly set single_file_review to true since it's now false by default
SettingsStore::update_global(cx, |store, _cx| {
let mut agent_settings = store.get::<AgentSettings>(None).clone();
agent_settings.single_file_review = true;
store.override_global(agent_settings);
});
});
let fs = FakeFs::new(cx.executor());

View File

@@ -1,11 +1,10 @@
use agent_settings::AgentSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use picker::popover_menu::PickerPopoverMenu;
use crate::Thread;
use assistant_context_editor::language_model_selector::{
LanguageModelSelector, ToggleModelSelector, language_model_selector,
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
use language_model::{ConfiguredModel, LanguageModelRegistry};
use settings::update_settings_file;
@@ -36,7 +35,7 @@ impl AgentModelSelector {
Self {
selector: cx.new(move |cx| {
let fs = fs.clone();
language_model_selector(
LanguageModelSelector::new(
{
let model_type = model_type.clone();
move |cx| match &model_type {
@@ -101,14 +100,15 @@ impl AgentModelSelector {
}
impl Render for AgentModelSelector {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone();
let model = self.selector.read(cx).delegate.active_model(cx);
let model = self.selector.read(cx).active_model(cx);
let model_name = model
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("No model selected"));
PickerPopoverMenu::new(
LanguageModelSelectorPopoverMenu::new(
self.selector.clone(),
Button::new("active-model", model_name)
.label_size(LabelSize::Small)
@@ -127,9 +127,7 @@ impl Render for AgentModelSelector {
)
},
gpui::Corner::BottomRight,
cx,
)
.with_handle(self.menu_handle.clone())
.render(window, cx)
}
}

View File

@@ -52,7 +52,7 @@ use workspace::{
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
use zed_llm_client::{CompletionIntent, UsageLimit};
use zed_llm_client::UsageLimit;
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
@@ -67,8 +67,8 @@ use crate::{
AddContextServer, AgentDiffPane, ContextStore, ContinueThread, ContinueWithBurnMode,
DeleteRecentlyOpenThread, ExpandMessageEditor, Follow, InlineAssistant, NewTextThread,
NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialEndUpsell,
ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleBurnMode, ToggleContextPicker,
ToggleNavigationMenu, ToggleOptionsMenu,
ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleContextPicker, ToggleNavigationMenu,
ToggleOptionsMenu,
};
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -174,7 +174,7 @@ enum ActiveView {
thread: WeakEntity<Thread>,
_subscriptions: Vec<gpui::Subscription>,
},
TextThread {
PromptEditor {
context_editor: Entity<ContextEditor>,
title_editor: Entity<Editor>,
buffer_search_bar: Entity<BufferSearchBar>,
@@ -194,7 +194,7 @@ impl ActiveView {
pub fn which_font_size_used(&self) -> WhichFontSize {
match self {
ActiveView::Thread { .. } | ActiveView::History => WhichFontSize::AgentFont,
ActiveView::TextThread { .. } => WhichFontSize::BufferFont,
ActiveView::PromptEditor { .. } => WhichFontSize::BufferFont,
ActiveView::Configuration => WhichFontSize::None,
}
}
@@ -333,7 +333,7 @@ impl ActiveView {
buffer_search_bar.set_active_pane_item(Some(&context_editor), window, cx)
});
Self::TextThread {
Self::PromptEditor {
context_editor,
title_editor: editor,
buffer_search_bar,
@@ -1084,23 +1084,9 @@ impl AgentPanel {
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
match self.active_view {
ActiveView::Configuration | ActiveView::History => {
if let Some(previous_view) = self.previous_view.take() {
self.active_view = previous_view;
match &self.active_view {
ActiveView::Thread { .. } => {
self.message_editor.focus_handle(cx).focus(window);
}
ActiveView::TextThread { context_editor, .. } => {
context_editor.focus_handle(cx).focus(window);
}
_ => {}
}
} else {
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
}
self.active_view =
ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx);
self.message_editor.focus_handle(cx).focus(window);
cx.notify();
}
_ => {}
@@ -1310,12 +1296,7 @@ impl AgentPanel {
active_thread.thread().update(cx, |thread, cx| {
thread.insert_invisible_continue_message(cx);
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
thread.send_to_model(model, Some(window.window_handle()), cx);
});
});
} else {
@@ -1323,27 +1304,9 @@ impl AgentPanel {
}
}
fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
let current_mode = thread.completion_mode();
thread.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
});
});
}
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
match &self.active_view {
ActiveView::TextThread { context_editor, .. } => Some(context_editor.clone()),
ActiveView::PromptEditor { context_editor, .. } => Some(context_editor.clone()),
_ => None,
}
}
@@ -1366,12 +1329,6 @@ impl AgentPanel {
let current_is_history = matches!(self.active_view, ActiveView::History);
let new_is_history = matches!(new_view, ActiveView::History);
let current_is_config = matches!(self.active_view, ActiveView::Configuration);
let new_is_config = matches!(new_view, ActiveView::Configuration);
let current_is_special = current_is_history || current_is_config;
let new_is_special = new_is_history || new_is_config;
match &self.active_view {
ActiveView::Thread { thread, .. } => {
if let Some(thread) = thread.upgrade() {
@@ -1383,7 +1340,7 @@ impl AgentPanel {
}
}
}
ActiveView::TextThread { context_editor, .. } => {
ActiveView::PromptEditor { context_editor, .. } => {
let context = context_editor.read(cx).context();
// When switching away from an unsaved text thread, delete its entry.
if context.read(cx).path().is_none() {
@@ -1403,7 +1360,7 @@ impl AgentPanel {
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
}
}),
ActiveView::TextThread { context_editor, .. } => {
ActiveView::PromptEditor { context_editor, .. } => {
self.history_store.update(cx, |store, cx| {
let context = context_editor.read(cx).context().clone();
store.push_recently_opened_entry(RecentEntry::Context(context), cx)
@@ -1412,12 +1369,12 @@ impl AgentPanel {
_ => {}
}
if current_is_special && !new_is_special {
if current_is_history && !new_is_history {
self.active_view = new_view;
} else if !current_is_special && new_is_special {
} else if !current_is_history && new_is_history {
self.previous_view = Some(std::mem::replace(&mut self.active_view, new_view));
} else {
if !new_is_special {
if !new_is_history {
self.previous_view = None;
}
self.active_view = new_view;
@@ -1432,7 +1389,7 @@ impl Focusable for AgentPanel {
match &self.active_view {
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
ActiveView::History => self.history.focus_handle(cx),
ActiveView::TextThread { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::PromptEditor { context_editor, .. } => context_editor.focus_handle(cx),
ActiveView::Configuration => {
if let Some(configuration) = self.configuration.as_ref() {
configuration.focus_handle(cx)
@@ -1584,7 +1541,7 @@ impl AgentPanel {
.into_any_element(),
}
}
ActiveView::TextThread {
ActiveView::PromptEditor {
title_editor,
context_editor,
..
@@ -1676,7 +1633,7 @@ impl AgentPanel {
let show_token_count = match &self.active_view {
ActiveView::Thread { .. } => !is_empty || !editor_empty,
ActiveView::TextThread { .. } => true,
ActiveView::PromptEditor { .. } => true,
_ => false,
};
@@ -1992,7 +1949,7 @@ impl AgentPanel {
Some(token_count)
}
ActiveView::TextThread { context_editor, .. } => {
ActiveView::PromptEditor { context_editor, .. } => {
let element = render_remaining_tokens(context_editor, cx)?;
Some(element.into_any_element())
@@ -2706,7 +2663,7 @@ impl AgentPanel {
.on_click(cx.listener(|this, _, window, cx| {
this.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
thread.set_completion_mode(CompletionMode::Max);
});
});
this.continue_conversation(window, cx);
@@ -2910,7 +2867,7 @@ impl AgentPanel {
) -> Div {
let mut registrar = buffer_search::DivRegistrar::new(
|this, _, _cx| match &this.active_view {
ActiveView::TextThread {
ActiveView::PromptEditor {
buffer_search_bar, ..
} => Some(buffer_search_bar.clone()),
_ => None,
@@ -3028,7 +2985,7 @@ impl AgentPanel {
.detach();
});
}
ActiveView::TextThread { context_editor, .. } => {
ActiveView::PromptEditor { context_editor, .. } => {
context_editor.update(cx, |context_editor, cx| {
ContextEditor::insert_dragged_files(
context_editor,
@@ -3055,7 +3012,7 @@ impl AgentPanel {
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");
if matches!(self.active_view, ActiveView::TextThread { .. }) {
if matches!(self.active_view, ActiveView::PromptEditor { .. }) {
key_context.add("prompt_editor");
}
key_context
@@ -3103,12 +3060,11 @@ impl Render for AgentPanel {
.on_action(cx.listener(|this, _: &ContinueWithBurnMode, window, cx| {
this.thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
thread.set_completion_mode(CompletionMode::Max);
});
});
this.continue_conversation(window, cx);
}))
.on_action(cx.listener(Self::toggle_burn_mode))
.child(self.render_toolbar(window, cx))
.children(self.render_upsell(window, cx))
.children(self.render_trial_end_upsell(window, cx))
@@ -3121,7 +3077,7 @@ impl Render for AgentPanel {
.children(self.render_last_error(cx))
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::TextThread {
ActiveView::PromptEditor {
context_editor,
buffer_search_bar,
..

View File

@@ -34,7 +34,6 @@ use std::{
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use zed_llm_client::CompletionIntent;
pub struct BufferCodegen {
alternatives: Vec<Entity<CodegenAlternative>>,
@@ -465,7 +464,6 @@ impl CodegenAlternative {
LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(CompletionIntent::InlineAssist),
mode: None,
tools: Vec::new(),
tool_choice: None,

View File

@@ -1445,7 +1445,7 @@ impl InlineAssistant {
style: BlockStyle::Flex,
render: Arc::new(move |cx| {
div()
.block_mouse_except_scroll()
.block_mouse_down()
.bg(cx.theme().status().deleted_background)
.size_full()
.h(height as f32 * cx.window.line_height())

View File

@@ -100,7 +100,7 @@ impl<T: 'static> Render for PromptEditor<T> {
v_flex()
.key_context("PromptEditor")
.bg(cx.theme().colors().editor_background)
.block_mouse_except_scroll()
.block_mouse_down()
.gap_0p5()
.border_y_1()
.border_color(cx.theme().status().info_border)

View File

@@ -42,7 +42,6 @@ use theme::ThemeSettings;
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
use util::{ResultExt as _, maybe};
use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_store::ContextStore;
@@ -52,7 +51,7 @@ use crate::thread::{MessageCrease, Thread, TokenUsageRatio};
use crate::thread_store::{TextThreadStore, ThreadStore};
use crate::{
ActiveThread, AgentDiffPane, Chat, ChatWithFollow, ExpandMessageEditor, Follow, NewThread,
OpenAgentDiff, RemoveAllContext, ToggleBurnMode, ToggleContextPicker, ToggleProfileSelector,
OpenAgentDiff, RemoveAllContext, ToggleContextPicker, ToggleProfileSelector,
register_agent_preview,
};
@@ -376,12 +375,7 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window_handle),
cx,
);
thread.send_to_model(model, Some(window_handle), cx);
})
.log_err();
})
@@ -477,22 +471,6 @@ impl MessageEditor {
}
}
pub fn toggle_burn_mode(
&mut self,
_: &ToggleBurnMode,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |thread, _cx| {
let active_completion_mode = thread.completion_mode();
thread.set_completion_mode(match active_completion_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
});
}
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.thread.read(cx);
let model = thread.configured_model();
@@ -501,8 +479,8 @@ impl MessageEditor {
}
let active_completion_mode = thread.completion_mode();
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
let icon = if burn_mode_enabled {
let max_mode_enabled = active_completion_mode == CompletionMode::Max;
let icon = if max_mode_enabled {
IconName::ZedBurnModeOn
} else {
IconName::ZedBurnMode
@@ -512,13 +490,18 @@ impl MessageEditor {
IconButton::new("burn-mode", icon)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.toggle_state(burn_mode_enabled)
.toggle_state(max_mode_enabled)
.selected_icon_color(Color::Error)
.on_click(cx.listener(|this, _event, window, cx| {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
.on_click(cx.listener(move |this, _event, _window, cx| {
this.thread.update(cx, |thread, _cx| {
thread.set_completion_mode(match active_completion_mode {
CompletionMode::Max => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Max,
});
});
}))
.tooltip(move |_window, cx| {
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
cx.new(|_| MaxModeTooltip::new().selected(max_mode_enabled))
.into()
})
.into_any_element(),
@@ -613,7 +596,6 @@ impl MessageEditor {
.on_action(cx.listener(Self::remove_all_context))
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::expand_message_editor))
.on_action(cx.listener(Self::toggle_burn_mode))
.capture_action(cx.listener(Self::paste))
.gap_2()
.p_2()
@@ -1286,7 +1268,6 @@ impl MessageEditor {
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![request_message],
tools: vec![],

View File

@@ -1 +0,0 @@
These files changed since last read:

View File

@@ -1,6 +0,0 @@
Generate a detailed summary of this conversation. Include:
1. A brief overview of what was discussed
2. Key facts or information discovered
3. Outcomes or conclusions reached
4. Any action items or next steps if any
Format it in Markdown with headings and bullet points.

View File

@@ -1,4 +0,0 @@
Generate a concise 3-7 word title for this conversation, omitting punctuation.
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`.
If the conversation is about a specific subject, include it in the title.
Be descriptive. DO NOT speak in the first person.

View File

@@ -25,7 +25,6 @@ use terminal_view::TerminalView;
use ui::prelude::*;
use util::ResultExt;
use workspace::{Toast, Workspace, notifications::NotificationId};
use zed_llm_client::CompletionIntent;
pub fn init(
fs: Arc<dyn Fs>,
@@ -292,7 +291,6 @@ impl TerminalInlineAssistant {
thread_id: None,
prompt_id: None,
mode: None,
intent: Some(CompletionIntent::TerminalInlineAssist),
messages: vec![request_message],
tools: Vec::new(),
tool_choice: None,

View File

@@ -38,7 +38,7 @@ use thiserror::Error;
use ui::Window;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
use zed_llm_client::CompletionRequestStatus;
use crate::ThreadStore;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -1184,7 +1184,6 @@ impl Thread {
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
@@ -1194,7 +1193,7 @@ impl Thread {
self.remaining_turns -= 1;
let request = self.to_completion_request(model.clone(), intent, cx);
let request = self.to_completion_request(model.clone(), cx);
self.stream_completion(request, model, window, cx);
}
@@ -1214,13 +1213,11 @@ impl Thread {
pub fn to_completion_request(
&self,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
cx: &mut Context<Self>,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: Some(self.id.to_string()),
prompt_id: Some(self.last_prompt_id.to_string()),
intent: Some(intent),
mode: None,
messages: vec![],
tools: Vec::new(),
@@ -1374,14 +1371,12 @@ impl Thread {
fn to_summarize_request(
&self,
model: &Arc<dyn LanguageModel>,
intent: CompletionIntent,
added_user_message: String,
cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(intent),
mode: None,
messages: vec![],
tools: Vec::new(),
@@ -1428,7 +1423,7 @@ impl Thread {
messages: &mut Vec<LanguageModelRequestMessage>,
cx: &App,
) {
const STALE_FILES_HEADER: &str = include_str!("./prompts/stale_files_prompt_header.txt");
const STALE_FILES_HEADER: &str = "These files changed since last read:";
let mut stale_message = String::new();
@@ -1440,7 +1435,7 @@ impl Thread {
};
if stale_message.is_empty() {
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER.trim()).ok();
write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
}
writeln!(&mut stale_message, "- {}", file.path().display()).ok();
@@ -1854,14 +1849,12 @@ impl Thread {
return;
}
let added_user_message = include_str!("./prompts/summarize_thread_prompt.txt");
let added_user_message = "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
If the conversation is about a specific subject, include it in the title. \
Be descriptive. DO NOT speak in the first person.";
let request = self.to_summarize_request(
&model.model,
CompletionIntent::ThreadSummarization,
added_user_message.into(),
cx,
);
let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
self.summary = ThreadSummary::Generating;
@@ -1955,14 +1948,14 @@ impl Thread {
return;
}
let added_user_message = include_str!("./prompts/summarize_thread_detailed_prompt.txt");
let added_user_message = "Generate a detailed summary of this conversation. Include:\n\
1. A brief overview of what was discussed\n\
2. Key facts or information discovered\n\
3. Outcomes or conclusions reached\n\
4. Any action items or next steps if any\n\
Format it in Markdown with headings and bullet points.";
let request = self.to_summarize_request(
&model,
CompletionIntent::ThreadContextSummarization,
added_user_message.into(),
cx,
);
let request = self.to_summarize_request(&model, added_user_message.into(), cx);
*self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating {
message_id: last_message_id,
@@ -2054,8 +2047,7 @@ impl Thread {
model: Arc<dyn LanguageModel>,
) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request =
Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx));
let request = Arc::new(self.to_completion_request(model.clone(), cx));
let pending_tool_uses = self
.tool_use
.pending_tool_uses()
@@ -2251,7 +2243,7 @@ impl Thread {
if self.all_tools_finished() {
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx);
self.send_to_model(model.clone(), window, cx);
}
self.auto_capture_telemetry(cx);
}
@@ -2942,7 +2934,7 @@ fn main() {{
// Check message in request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 2);
@@ -3037,7 +3029,7 @@ fn main() {{
// Check entire request to make sure all contexts are properly included
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
// The request should contain all 3 messages
@@ -3144,7 +3136,7 @@ fn main() {{
// Check message in request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 2);
@@ -3170,7 +3162,7 @@ fn main() {{
// Check that both messages appear in the request
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.messages.len(), 3);
@@ -3215,7 +3207,7 @@ fn main() {{
// Create a request and check that it doesn't have a stale buffer warning yet
let initial_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
// Make sure we don't have a stale file warning yet
@@ -3251,7 +3243,7 @@ fn main() {{
// Create a new request and check for the stale buffer warning
let new_request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
// We should have a stale file warning as the last message
@@ -3301,7 +3293,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3321,7 +3313,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3341,7 +3333,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.temperature, Some(0.66));
@@ -3361,7 +3353,7 @@ fn main() {{
});
let request = thread.update(cx, |thread, cx| {
thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx)
thread.to_completion_request(model.clone(), cx)
});
assert_eq!(request.temperature, None);
}
@@ -3393,12 +3385,7 @@ fn main() {{
// Send a message
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(
model.clone(),
CompletionIntent::ThreadSummarization,
None,
cx,
);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();
@@ -3493,7 +3480,7 @@ fn main() {{
vec![],
cx,
);
thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();
@@ -3531,12 +3518,7 @@ fn main() {{
) {
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(
model.clone(),
CompletionIntent::ThreadSummarization,
None,
cx,
);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();

View File

@@ -1,6 +1,5 @@
use crate::ToggleBurnMode;
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{KeyBinding, prelude::*, tooltip_container};
use gpui::{Context, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct MaxModeTooltip {
selected: bool,
@@ -19,48 +18,39 @@ impl MaxModeTooltip {
impl Render for MaxModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
let icon = if self.selected {
IconName::ZedBurnModeOn
} else {
(IconName::ZedBurnMode, Color::Default)
IconName::ZedBurnMode
};
let turned_on = h_flex()
.h_4()
.px_1()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().text_accent.opacity(0.1))
.rounded_sm()
.child(
Label::new("ON")
.size(LabelSize::XSmall)
.weight(FontWeight::SEMIBOLD)
.color(Color::Accent),
);
let title = h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(color))
.child(Label::new("Burn Mode"))
.when(self.selected, |title| title.child(turned_on));
let keybinding = KeyBinding::for_action(&ToggleBurnMode, window, cx)
.map(|kb| kb.size(rems_from_px(12.)));
.gap_1()
.child(Icon::new(icon).size(IconSize::Small))
.child(Label::new("Burn Mode"));
tooltip_container(window, cx, |this, _, _| {
this
.child(
h_flex()
.justify_between()
.child(title)
.children(keybinding)
)
this.gap_0p5()
.map(|header| if self.selected {
header.child(
h_flex()
.justify_between()
.child(title)
.child(
h_flex()
.gap_0p5()
.child(Icon::new(IconName::Check).size(IconSize::XSmall).color(Color::Accent))
.child(Label::new("Turned On").size(LabelSize::XSmall).color(Color::Accent))
)
)
} else {
header.child(title)
})
.child(
div()
.max_w_64()
.max_w_72()
.child(
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
.size(LabelSize::Small)
.color(Color::Muted)
)

View File

@@ -663,7 +663,7 @@ pub struct AgentSettingsContentV2 {
stream_edits: Option<bool>,
/// Whether to display agent edits in single-file editors in addition to the review multibuffer pane.
///
/// Default: true
/// Default: false
single_file_review: Option<bool>,
/// Additional parameters for language model requests. When making a request
/// to a model, parameters will be taken from the last entry in this list
@@ -689,15 +689,14 @@ pub struct AgentSettingsContentV2 {
pub enum CompletionMode {
#[default]
Normal,
#[serde(alias = "max")]
Burn,
Max,
}
impl From<CompletionMode> for zed_llm_client::CompletionMode {
fn from(value: CompletionMode) -> Self {
match value {
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
CompletionMode::Max => zed_llm_client::CompletionMode::Max,
}
}
}

View File

@@ -57,10 +57,8 @@ uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
indoc.workspace = true
language_model = { workspace = true, features = ["test-support"] }
languages = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true

View File

@@ -45,7 +45,6 @@ use text::{BufferSnapshot, ToPoint};
use ui::IconName;
use util::{ResultExt, TryFutureExt, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionIntent;
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextId(String);
@@ -2273,7 +2272,6 @@ impl AssistantContext {
let mut completion_request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: Some(CompletionIntent::UserPrompt),
mode: None,
messages: Vec::new(),
tools: Vec::new(),

View File

@@ -1,6 +1,6 @@
use crate::{
language_model_selector::{
LanguageModelSelector, ToggleModelSelector, language_model_selector,
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
},
max_mode_tooltip::MaxModeTooltip,
};
@@ -43,7 +43,7 @@ use language_model::{
Role,
};
use multi_buffer::MultiBufferRow;
use picker::{Picker, popover_menu::PickerPopoverMenu};
use picker::Picker;
use project::{Project, Worktree};
use project::{ProjectPath, lsp_store::LocalLspAdapterDelegate};
use rope::Point;
@@ -283,7 +283,7 @@ impl ContextEditor {
slash_menu_handle: Default::default(),
dragged_file_worktrees: Vec::new(),
language_model_selector: cx.new(|cx| {
language_model_selector(
LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AgentSettings>(
@@ -1646,35 +1646,34 @@ impl ContextEditor {
let context = self.context.read(cx);
let mut text = String::new();
// If selection is empty, we want to copy the entire line
if selection.range().is_empty() {
let snapshot = context.buffer().read(cx).snapshot();
let point = snapshot.offset_to_point(selection.range().start);
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
selection.end = snapshot
.point_to_offset(cmp::min(Point::new(point.row + 1, 0), snapshot.max_point()));
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
text.push_str(chunk);
}
} else {
for message in context.messages(cx) {
if message.offset_range.start >= selection.range().end {
break;
} else if message.offset_range.end >= selection.range().start {
let range = cmp::max(message.offset_range.start, selection.range().start)
..cmp::min(message.offset_range.end, selection.range().end);
if !range.is_empty() {
for chunk in context.buffer().read(cx).text_for_range(range) {
text.push_str(chunk);
}
if message.offset_range.end < selection.range().end {
text.push('\n');
}
for message in context.messages(cx) {
if message.offset_range.start >= selection.range().end {
break;
} else if message.offset_range.end >= selection.range().start {
let range = cmp::max(message.offset_range.start, selection.range().start)
..cmp::min(message.offset_range.end, selection.range().end);
if range.is_empty() {
let snapshot = context.buffer().read(cx).snapshot();
let point = snapshot.offset_to_point(range.start);
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
selection.end = snapshot.point_to_offset(cmp::min(
Point::new(point.row + 1, 0),
snapshot.max_point(),
));
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
text.push_str(chunk);
}
} else {
for chunk in context.buffer().read(cx).text_for_range(range) {
text.push_str(chunk);
}
if message.offset_range.end < selection.range().end {
text.push('\n');
}
}
}
}
(text, CopyMetadata { creases }, vec![selection])
}
@@ -2072,8 +2071,8 @@ impl ContextEditor {
}
let active_completion_mode = context.completion_mode();
let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
let icon = if burn_mode_enabled {
let max_mode_enabled = active_completion_mode == CompletionMode::Max;
let icon = if max_mode_enabled {
IconName::ZedBurnModeOn
} else {
IconName::ZedBurnMode
@@ -2083,29 +2082,25 @@ impl ContextEditor {
IconButton::new("burn-mode", icon)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.toggle_state(burn_mode_enabled)
.toggle_state(max_mode_enabled)
.selected_icon_color(Color::Error)
.on_click(cx.listener(move |this, _event, _window, cx| {
this.context().update(cx, |context, _cx| {
context.set_completion_mode(match active_completion_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
CompletionMode::Max => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Max,
});
});
}))
.tooltip(move |_window, cx| {
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
cx.new(|_| MaxModeTooltip::new().selected(max_mode_enabled))
.into()
})
.into_any_element(),
)
}
fn render_language_model_selector(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
fn render_language_model_selector(&self, cx: &mut Context<Self>) -> impl IntoElement {
let active_model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model);
@@ -2115,7 +2110,7 @@ impl ContextEditor {
None => SharedString::from("No model selected"),
};
PickerPopoverMenu::new(
LanguageModelSelectorPopoverMenu::new(
self.language_model_selector.clone(),
ButtonLike::new("active-model")
.style(ButtonStyle::Subtle)
@@ -2143,10 +2138,8 @@ impl ContextEditor {
)
},
gpui::Corner::BottomLeft,
cx,
)
.with_handle(self.language_model_selector_menu_handle.clone())
.render(window, cx)
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -2622,7 +2615,7 @@ impl Render for ContextEditor {
.child(
h_flex()
.gap_1()
.child(self.render_language_model_selector(window, cx))
.child(self.render_language_model_selector(cx))
.child(self.render_send_button(window, cx)),
),
)
@@ -3265,92 +3258,74 @@ mod tests {
use super::*;
use fs::FakeFs;
use gpui::{App, TestAppContext, VisualTestContext};
use indoc::indoc;
use language::{Buffer, LanguageRegistry};
use pretty_assertions::assert_eq;
use prompt_store::PromptBuilder;
use text::OffsetRangeExt;
use unindent::Unindent;
use util::path;
#[gpui::test]
async fn test_copy_paste_whole_message(cx: &mut TestAppContext) {
let (context, context_editor, mut cx) = setup_context_editor_text(vec![
(Role::User, "What is the Zed editor?"),
(
Role::Assistant,
"Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.",
),
(Role::User, ""),
],cx).await;
// Select & Copy whole user message
assert_copy_paste_context_editor(
&context_editor,
message_range(&context, 0, &mut cx),
indoc! {"
What is the Zed editor?
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
What is the Zed editor?
"},
&mut cx,
);
// Select & Copy whole assistant message
assert_copy_paste_context_editor(
&context_editor,
message_range(&context, 1, &mut cx),
indoc! {"
What is the Zed editor?
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
What is the Zed editor?
Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
"},
&mut cx,
);
}
#[gpui::test]
async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
let (context, context_editor, mut cx) = setup_context_editor_text(
vec![
(Role::User, "user1"),
(Role::Assistant, "assistant1"),
(Role::Assistant, "assistant2"),
(Role::User, ""),
],
cx,
)
.await;
cx.update(init_test);
// Copy and paste first assistant message
let message_2_range = message_range(&context, 1, &mut cx);
assert_copy_paste_context_editor(
&context_editor,
message_2_range.start..message_2_range.start,
indoc! {"
user1
assistant1
assistant2
assistant1
"},
&mut cx,
);
let fs = FakeFs::new(cx.executor());
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let workspace = window.root(cx).unwrap();
let cx = &mut VisualTestContext::from_window(*window, cx);
// Copy and cut second assistant message
let message_3_range = message_range(&context, 2, &mut cx);
assert_copy_paste_context_editor(
&context_editor,
message_3_range.start..message_3_range.start,
indoc! {"
user1
assistant1
assistant2
assistant1
assistant2
"},
&mut cx,
);
let context_editor = window
.update(cx, |_, window, cx| {
cx.new(|cx| {
ContextEditor::for_context(
context,
fs,
workspace.downgrade(),
project,
None,
window,
cx,
)
})
})
.unwrap();
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.set_text("abc\ndef\nghi", window, cx);
editor.move_to_beginning(&Default::default(), window, cx);
})
});
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.copy(&Default::default(), window, cx);
editor.paste(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
})
});
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.cut(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\ndef\nghi");
editor.paste(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
})
});
}
#[gpui::test]
@@ -3427,129 +3402,6 @@ mod tests {
}
}
async fn setup_context_editor_text(
messages: Vec<(Role, &str)>,
cx: &mut TestAppContext,
) -> (
Entity<AssistantContext>,
Entity<ContextEditor>,
VisualTestContext,
) {
cx.update(init_test);
let fs = FakeFs::new(cx.executor());
let context = create_context_with_messages(messages, cx);
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let workspace = window.root(cx).unwrap();
let mut cx = VisualTestContext::from_window(*window, cx);
let context_editor = window
.update(&mut cx, |_, window, cx| {
cx.new(|cx| {
let editor = ContextEditor::for_context(
context.clone(),
fs,
workspace.downgrade(),
project,
None,
window,
cx,
);
editor
})
})
.unwrap();
(context, context_editor, cx)
}
fn message_range(
context: &Entity<AssistantContext>,
message_ix: usize,
cx: &mut TestAppContext,
) -> Range<usize> {
context.update(cx, |context, cx| {
context
.messages(cx)
.nth(message_ix)
.unwrap()
.anchor_range
.to_offset(&context.buffer().read(cx).snapshot())
})
}
fn assert_copy_paste_context_editor<T: editor::ToOffset>(
context_editor: &Entity<ContextEditor>,
range: Range<T>,
expected_text: &str,
cx: &mut VisualTestContext,
) {
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.change_selections(None, window, cx, |s| s.select_ranges([range]));
});
context_editor.copy(&Default::default(), window, cx);
context_editor.editor.update(cx, |editor, cx| {
editor.move_to_end(&Default::default(), window, cx);
});
context_editor.paste(&Default::default(), window, cx);
context_editor.editor.update(cx, |editor, cx| {
assert_eq!(editor.text(cx), expected_text);
});
});
}
fn create_context_with_messages(
mut messages: Vec<(Role, &str)>,
cx: &mut TestAppContext,
) -> Entity<AssistantContext> {
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
cx.new(|cx| {
let mut context = AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
);
let mut message_1 = context.messages(cx).next().unwrap();
let (role, text) = messages.remove(0);
loop {
if role == message_1.role {
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(message_1.offset_range, text)], None, cx);
});
break;
}
let mut ids = HashSet::default();
ids.insert(message_1.id);
context.cycle_message_roles(ids, cx);
message_1 = context.messages(cx).next().unwrap();
}
let mut last_message_id = message_1.id;
for (role, text) in messages {
context.insert_message_after(last_message_id, role, MessageStatus::Done, cx);
let message = context.messages(cx).last().unwrap();
last_message_id = message.id;
context.buffer().update(cx, |buffer, cx| {
buffer.edit([(message.offset_range, text)], None, cx);
})
}
context
})
}
fn init_test(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
prompt_store::init(cx);

View File

@@ -4,7 +4,8 @@ use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
action_with_deprecated_aliases,
};
use language_model::{
@@ -14,7 +15,7 @@ use language_model::{
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, prelude::*};
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
action_with_deprecated_aliases!(
agent,
@@ -30,128 +31,77 @@ const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
pub fn language_model_selector(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<LanguageModelSelector>,
) -> LanguageModelSelector {
let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
}
fn all_models(cx: &App) -> GroupedModels {
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let recommended = providers
.iter()
.flat_map(|provider| {
provider
.recommended_models(cx)
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
})
})
.collect();
let other = providers
.iter()
.flat_map(|provider| {
provider
.provided_models(cx)
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
})
})
.collect();
GroupedModels::new(other, recommended)
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
}
pub struct LanguageModelPickerDelegate {
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
}
impl LanguageModelPickerDelegate {
fn new(
impl LanguageModelSelector {
pub fn new(
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
cx: &mut Context<Self>,
) -> Self {
let on_model_changed = Arc::new(on_model_changed);
let models = all_models(cx);
let entries = models.entries();
Self {
let all_models = Self::all_models(cx);
let entries = all_models.entries();
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(models),
all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
};
let picker = cx.new(|cx| {
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
});
let subscription = cx.subscribe(&picker, |_, _, _, cx| cx.emit(DismissEvent));
LanguageModelSelector {
picker,
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
|picker, _, event, window, cx| {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx);
picker.delegate.all_models = Arc::new(all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
picker.update_matches(query, window, cx)
}
_ => {}
}
},
)],
_subscriptions: vec![
cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
Self::handle_language_model_registry_event,
),
subscription,
],
}
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
active_model: Option<ConfiguredModel>,
) -> usize {
entries
.iter()
.position(|entry| {
if let LanguageModelPickerEntry::Model(model) = entry {
active_model
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
false
}
})
.unwrap_or(0)
fn handle_language_model_registry_event(
&mut self,
_registry: &Entity<LanguageModelRegistry>,
event: &language_model::Event,
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
self.picker.update(cx, |this, cx| {
let query = this.query(cx);
this.delegate.all_models = Arc::new(Self::all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
this.update_matches(query, window, cx)
});
}
_ => {}
}
}
/// Authenticates all providers in the [`LanguageModelRegistry`].
@@ -204,9 +154,169 @@ impl LanguageModelPickerDelegate {
})
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.get_active_model)(cx)
fn all_models(cx: &App) -> GroupedModels {
let mut recommended = Vec::new();
let mut recommended_set = HashSet::default();
for provider in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
{
let models = provider.recommended_models(cx);
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
recommended.extend(
provider
.recommended_models(cx)
.into_iter()
.map(move |model| ModelInfo {
model: model.clone(),
icon: provider.icon(),
}),
);
}
let other_models = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| {
(
provider.id(),
provider
.provided_models(cx)
.into_iter()
.filter_map(|model| {
let not_included =
!recommended_set.contains(&(model.provider_id(), model.id()));
not_included.then(|| ModelInfo {
model: model.clone(),
icon: provider.icon(),
})
})
.collect::<Vec<_>>(),
)
})
.collect::<IndexMap<_, _>>();
GroupedModels {
recommended,
other: other_models,
}
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
(self.picker.read(cx).delegate.get_active_model)(cx)
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
active_model: Option<ConfiguredModel>,
) -> usize {
entries
.iter()
.position(|entry| {
if let LanguageModelPickerEntry::Model(model) = entry {
active_model
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
false
}
})
.unwrap_or(0)
}
}
impl EventEmitter<DismissEvent> for LanguageModelSelector {}
impl Focusable for LanguageModelSelector {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for LanguageModelSelector {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
#[derive(IntoElement)]
pub struct LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
language_model_selector: Entity<LanguageModelSelector>,
trigger: T,
tooltip: TT,
handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
anchor: Corner,
}
impl<T, TT> LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
pub fn new(
language_model_selector: Entity<LanguageModelSelector>,
trigger: T,
tooltip: TT,
anchor: Corner,
) -> Self {
Self {
language_model_selector,
trigger,
tooltip,
handle: None,
anchor,
}
}
pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
self.handle = Some(handle);
self
}
}
impl<T, TT> RenderOnce for LanguageModelSelectorPopoverMenu<T, TT>
where
T: PopoverTrigger + ButtonCommon,
TT: Fn(&mut Window, &mut App) -> AnyView + 'static,
{
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let language_model_selector = self.language_model_selector.clone();
PopoverMenu::new("model-switcher")
.menu(move |_window, _cx| Some(language_model_selector.clone()))
.trigger_with_tooltip(self.trigger, self.tooltip)
.anchor(self.anchor)
.when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
.offset(gpui::Point {
x: px(0.0),
y: px(-2.0),
})
}
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
}
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
}
struct GroupedModels {
@@ -216,14 +326,11 @@ struct GroupedModels {
impl GroupedModels {
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
let recommended_ids = recommended
.iter()
.map(|info| (info.model.provider_id(), info.model.id()))
.collect::<HashSet<_>>();
let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in other {
if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
if recommended_ids.contains(&model.model.id()) {
continue;
}
@@ -470,7 +577,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
cx.emit(DismissEvent);
self.language_model_selector
.update(cx, |_this, cx| cx.emit(DismissEvent))
.ok();
}
fn render_match(
@@ -808,26 +917,4 @@ mod tests {
// Recommended models should not appear in "other"
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
}
#[gpui::test]
fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models(vec![
("zed", "claude"), // Should be filtered out from "other"
("zed", "gemini"),
("copilot", "claude"), // Should not be filtered out from "other"
]);
let grouped_models = GroupedModels::new(all_models, recommended_models);
let actual_other_models = grouped_models
.other
.values()
.flatten()
.cloned()
.collect::<Vec<_>>();
// Recommended models should not appear in "other"
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
}
}

View File

@@ -1,4 +1,4 @@
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use gpui::{Context, IntoElement, Render, Window};
use ui::{prelude::*, tooltip_container};
pub struct MaxModeTooltip {
@@ -18,40 +18,39 @@ impl MaxModeTooltip {
impl Render for MaxModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
let icon = if self.selected {
IconName::ZedBurnModeOn
} else {
(IconName::ZedBurnMode, Color::Default)
IconName::ZedBurnMode
};
let turned_on = h_flex()
.h_4()
.px_1()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().text_accent.opacity(0.1))
.rounded_sm()
.child(
Label::new("ON")
.size(LabelSize::XSmall)
.weight(FontWeight::SEMIBOLD)
.color(Color::Accent),
);
let title = h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::Small).color(color))
.child(Label::new("Burn Mode"))
.when(self.selected, |title| title.child(turned_on));
.gap_1()
.child(Icon::new(icon).size(IconSize::Small))
.child(Label::new("Burn Mode"));
tooltip_container(window, cx, |this, _, _| {
this
.child(title)
this.gap_0p5()
.map(|header| if self.selected {
header.child(
h_flex()
.justify_between()
.child(title)
.child(
h_flex()
.gap_0p5()
.child(Icon::new(IconName::Check).size(IconSize::XSmall).color(Color::Accent))
.child(Label::new("Turned On").size(LabelSize::XSmall).color(Color::Accent))
)
)
} else {
header.child(title)
})
.child(
div()
.max_w_64()
.max_w_72()
.child(
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning, offering an unfettered agentic experience.")
.size(LabelSize::Small)
.color(Color::Muted)
)

View File

@@ -415,38 +415,14 @@ impl ActionLog {
self.project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
} else {
// For a file created by AI with no pre-existing content,
// only delete the file if we're certain it contains only AI content
// with no edits from the user.
let initial_version = tracked_buffer.version.clone();
let current_version = buffer.read(cx).version();
let current_content = buffer.read(cx).text();
let tracked_content = tracked_buffer.snapshot.text();
let is_ai_only_content =
initial_version == current_version && current_content == tracked_content;
if is_ai_only_content {
buffer
.read(cx)
.entry_id(cx)
.and_then(|entry_id| {
self.project.update(cx, |project, cx| {
project.delete_entry(entry_id, false, cx)
})
})
.unwrap_or(Task::ready(Ok(())))
} else {
// Not sure how to disentangle edits made by the user
// from edits made by the AI at this point.
// For now, preserve both to avoid data loss.
//
// TODO: Better solution (disable "Reject" after user makes some
// edit or find a way to differentiate between AI and user edits)
Task::ready(Ok(()))
}
buffer
.read(cx)
.entry_id(cx)
.and_then(|entry_id| {
self.project
.update(cx, |project, cx| project.delete_entry(entry_id, false, cx))
})
.unwrap_or(Task::ready(Ok(())))
};
self.tracked_buffers.remove(&buffer);
@@ -1600,6 +1576,7 @@ mod tests {
project.find_project_path("dir/new_file", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
@@ -1642,72 +1619,6 @@ mod tests {
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
}
#[gpui::test]
async fn test_reject_created_file_with_user_edits(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| {
project.find_project_path("dir/new_file", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
// AI creates file with initial content
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("ai content", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
cx.run_until_parked();
// User makes additional edits
cx.update(|cx| {
buffer.update(cx, |buffer, cx| {
buffer.edit([(10..10, "\nuser added this line")], None, cx);
});
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
// Reject all
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(
buffer.clone(),
vec![Point::new(0, 0)..Point::new(100, 0)],
cx,
)
})
.await
.unwrap();
cx.run_until_parked();
// File should still contain all the content
assert!(fs.is_file(path!("/dir/new_file").as_ref()).await);
let content = buffer.read_with(cx, |buffer, _| buffer.text());
assert_eq!(content, "ai content\nuser added this line");
}
#[gpui::test(iterations = 100)]
async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) {
init_test(cx);

View File

@@ -36,7 +36,6 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
open.workspace = true
paths.workspace = true
@@ -65,7 +64,6 @@ workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
lsp = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }

View File

@@ -28,7 +28,6 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
use zed_llm_client::CompletionIntent;
#[derive(Serialize)]
struct CreateFilePromptTemplate {
@@ -107,9 +106,7 @@ impl EditAgent {
edit_description,
}
.render(&this.templates)?;
let new_chunks = this
.request(conversation, CompletionIntent::CreateFile, prompt, cx)
.await?;
let new_chunks = this.request(conversation, prompt, cx).await?;
let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
while let Some(event) = inner_events.next().await {
@@ -216,9 +213,7 @@ impl EditAgent {
edit_description,
}
.render(&this.templates)?;
let edit_chunks = this
.request(conversation, CompletionIntent::EditFile, prompt, cx)
.await?;
let edit_chunks = this.request(conversation, prompt, cx).await?;
this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
.await
});
@@ -594,7 +589,6 @@ impl EditAgent {
async fn request(
&self,
mut conversation: LanguageModelRequest,
intent: CompletionIntent,
prompt: String,
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
@@ -652,7 +646,6 @@ impl EditAgent {
let request = LanguageModelRequest {
thread_id: conversation.thread_id,
prompt_id: conversation.prompt_id,
intent: Some(intent),
mode: conversation.mode,
messages: conversation.messages,
tool_choice,

View File

@@ -4,7 +4,7 @@ use std::cell::LazyCell;
use util::debug_panic;
const START_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n?```\S*\n").unwrap());
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"(^|\n)```\s*$").unwrap());
const END_MARKER: LazyCell<Regex> = LazyCell::new(|| Regex::new(r"\n```\s*$").unwrap());
#[derive(Debug)]
pub enum CreateFileParserEvent {
@@ -184,22 +184,6 @@ mod tests {
);
}
#[gpui::test(iterations = 10)]
fn test_empty_file(mut rng: StdRng) {
let mut parser = CreateFileParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
```
```
"},
&mut parser,
&mut rng
),
"".to_string()
);
}
fn parse_random_chunks(input: &str, parser: &mut CreateFileParser, rng: &mut StdRng) -> String {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);

View File

@@ -18,21 +18,16 @@ use gpui::{
use indoc::formatdoc;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
TextBuffer,
language_settings::{self, FormatOnSave, SoftWrap},
TextBuffer, language_settings::SoftWrap,
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use project::{
Project, ProjectPath,
lsp_store::{FormatTrigger, LspFormatTarget},
};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
cmp::Reverse,
collections::HashSet,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
@@ -194,10 +189,8 @@ impl Tool for EditFileTool {
});
let card_clone = card.clone();
let action_log_clone = action_log.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
let edit_agent =
EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {
@@ -251,53 +244,19 @@ impl Tool for EditFileTool {
}
let agent_output = output.await?;
// If format_on_save is enabled, format the buffer
let format_on_save_enabled = buffer
.read_with(cx, |buffer, cx| {
let settings = language_settings::language_settings(
buffer.language().map(|l| l.name()),
buffer.file(),
cx,
);
!matches!(settings.format_on_save, FormatOnSave::Off)
})
.unwrap_or(false);
if format_on_save_enabled {
let format_task = project.update(cx, |project, cx| {
project.format(
HashSet::from_iter([buffer.clone()]),
LspFormatTarget::Buffers,
false, // Don't push to history since the tool did it.
FormatTrigger::Save,
cx,
)
})?;
format_task.await.log_err();
}
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
// Notify the action log that we've edited the buffer (*after* formatting has completed).
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx);
})?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let (new_text, diff) = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
let old_text = old_text.clone();
async move {
let new_text = new_snapshot.text();
let diff = language::unified_diff(&old_text, &new_text);
(new_text, diff)
}
})
.await;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
let output = EditFileToolOutput {
original_path: project_path.path.to_path_buf(),
@@ -1140,8 +1099,8 @@ async fn build_buffer_diff(
mod tests {
use super::*;
use client::TelemetrySettings;
use fs::{FakeFs, Fs};
use gpui::{TestAppContext, UpdateGlobal};
use fs::FakeFs;
use gpui::TestAppContext;
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
@@ -1351,340 +1310,4 @@ mod tests {
Project::init_settings(cx);
});
}
#[gpui::test]
async fn test_format_on_save(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
// Set up a Rust language with LSP formatting support
let rust_language = Arc::new(language::Language::new(
language::LanguageConfig {
name: "Rust".into(),
matcher: language::LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
None,
));
// Register the language and fake LSP
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
language_registry.add(rust_language);
let mut fake_language_servers = language_registry.register_fake_lsp(
"Rust",
language::FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default()
},
..Default::default()
},
);
// Create the file
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
// Open the buffer to trigger LSP initialization
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/src/main.rs"), cx)
})
.await
.unwrap();
// Register the buffer with language servers
let _handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
});
const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
const FORMATTED_CONTENT: &str =
"This file was formatted by the fake formatter in the test.\n";
// Get the fake language server and set up formatting handler
let fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
|_, _| async move {
Ok(Some(vec![lsp::TextEdit {
range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
new_text: FORMATTED_CONTENT.to_string(),
}]))
}
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with format_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::On);
settings.defaults.formatter =
Some(language::language_settings::SelectedFormatter::Auto);
},
);
});
});
// Have the model stream unformatted content
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify it was formatted automatically
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
FORMATTED_CONTENT,
"Code should be formatted when format_on_save is enabled"
);
let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
assert_eq!(
stale_buffer_count, 0,
"BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
This causes the agent to think the file was modified externally when it was just formatted.",
stale_buffer_count
);
// Next, test with format_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.format_on_save = Some(FormatOnSave::Off);
},
);
});
});
// Stream unformatted edits again
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the unformatted content
cx.executor().run_until_parked();
model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file was not formatted
let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
new_content.replace("\r\n", "\n"),
UNFORMATTED_CONTENT,
"Code should not be formatted when format_on_save is disabled"
);
}
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/root", json!({"src": {}})).await;
// Create a simple file with trailing whitespace
fs.save(
path!("/root/src/main.rs").as_ref(),
&"initial content".into(),
language::LineEnding::Unix,
)
.await
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
// First, test with remove_trailing_whitespace_on_save enabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(true);
},
);
});
});
const CONTENT_WITH_TRAILING_WHITESPACE: &str =
"fn main() { \n println!(\"Hello!\"); \n}\n";
// Have the model stream content that contains trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Create main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Read the file to verify trailing whitespace was removed automatically
assert_eq!(
// Ignore carriage returns on Windows
fs.load(path!("/root/src/main.rs").as_ref())
.await
.unwrap()
.replace("\r\n", "\n"),
"fn main() {\n println!(\"Hello!\");\n}\n",
"Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
);
// Next, test with remove_trailing_whitespace_on_save disabled
cx.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<language::language_settings::AllLanguageSettings>(
cx,
|settings| {
settings.defaults.remove_trailing_whitespace_on_save = Some(false);
},
);
});
});
// Stream edits again with trailing whitespace
let edit_result = {
let edit_task = cx.update(|cx| {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Update main function".into(),
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
})
.unwrap();
Arc::new(EditFileTool)
.run(
input,
Arc::default(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
.output
});
// Stream the content with trailing whitespace
cx.executor().run_until_parked();
model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
model.end_last_completion_stream();
edit_task.await
};
assert!(edit_result.is_ok());
// Wait for any async operations (e.g. formatting) to complete
cx.executor().run_until_parked();
// Verify the file still has trailing whitespace
// Read the file again - it should still have trailing whitespace
let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
assert_eq!(
// Ignore carriage returns on Windows
final_content.replace("\r\n", "\n"),
CONTENT_WITH_TRAILING_WHITESPACE,
"Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
);
}
}

View File

@@ -1,6 +1,6 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
use std::{borrow::Cow, cell::RefCell};
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
@@ -39,11 +39,10 @@ impl FetchTool {
}
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
let url = if !url.starts_with("https://") && !url.starts_with("http://") {
Cow::Owned(format!("https://{url}"))
} else {
Cow::Borrowed(url)
};
let mut url = url.to_owned();
if !url.starts_with("https://") && !url.starts_with("http://") {
url = format!("https://{url}");
}
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
@@ -157,7 +156,8 @@ impl Tool for FetchTool {
let text = cx.background_spawn({
let http_client = self.http_client.clone();
async move { Self::build_message(http_client, &input.url).await }
let url = input.url.clone();
async move { Self::build_message(http_client, &url).await }
});
cx.foreground_executor()

View File

@@ -119,16 +119,14 @@ impl Tool for FindPathTool {
)
.unwrap();
}
for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
let output = FindPathToolOutput {
glob,
paths: matches,
paths: matches.clone(),
};
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(ToolResultOutput {
content: ToolResultContent::Text(message),
output: Some(serde_json::to_value(output)?),
@@ -237,6 +235,8 @@ impl ToolCard for FindPathToolCard {
format!("{} matches", self.paths.len()).into()
};
let glob_label = self.glob.to_string();
let content = if !self.paths.is_empty() && self.expanded {
Some(
v_flex()
@@ -310,7 +310,7 @@ impl ToolCard for FindPathToolCard {
.gap_1()
.child(
ToolCallCardHeader::new(IconName::SearchCode, matches_label)
.with_code_path(&self.glob)
.with_code_path(glob_label)
.disclosure_slot(
Disclosure::new("path-search-disclosure", self.expanded)
.opened_icon(IconName::ChevronUp)

View File

@@ -182,8 +182,9 @@ impl Tool for TerminalTool {
let mut child = pair.slave.spawn_command(cmd)?;
let mut reader = pair.master.try_clone_reader()?;
drop(pair);
let mut content = String::new();
reader.read_to_string(&mut content)?;
let mut content = Vec::new();
reader.read_to_end(&mut content)?;
let mut content = String::from_utf8(content)?;
// Massage the pty output a bit to try to match what the terminal codepath gives us
LineEnding::normalize(&mut content);
content = content

View File

@@ -166,7 +166,7 @@ impl ToolCard for WebSearchToolCard {
.gap_1()
.children(response.results.iter().enumerate().map(|(index, result)| {
let title = result.title.clone();
let url = SharedString::from(result.url.clone());
let url = result.url.clone();
Button::new(("result", index), title)
.label_size(LabelSize::Small)

View File

@@ -91,7 +91,7 @@ fn view_release_notes_locally(
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let tab_content = Some(SharedString::from(body.title.to_string()));
let tab_content = SharedString::from(body.title.to_string());
let editor = cx.new(|cx| {
Editor::for_multibuffer(buffer, Some(project), window, cx)
});

View File

@@ -16,7 +16,6 @@ doctest = false
editor.workspace = true
gpui.workspace = true
itertools.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true
workspace.workspace = true

View File

@@ -1,15 +1,14 @@
use editor::Editor;
use gpui::{
Context, Element, EventEmitter, Focusable, FontWeight, IntoElement, ParentElement, Render,
StyledText, Subscription, Window,
Context, Element, EventEmitter, Focusable, IntoElement, ParentElement, Render, StyledText,
Subscription, Window,
};
use itertools::Itertools;
use settings::Settings;
use std::cmp;
use theme::ActiveTheme;
use ui::{ButtonLike, ButtonStyle, Label, Tooltip, prelude::*};
use workspace::{
TabBarSettings, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
item::{BreadcrumbText, ItemEvent, ItemHandle},
};
@@ -72,23 +71,16 @@ impl Render for Breadcrumbs {
);
}
let highlighted_segments = segments.into_iter().enumerate().map(|(index, segment)| {
let highlighted_segments = segments.into_iter().map(|segment| {
let mut text_style = window.text_style();
if let Some(ref font) = segment.font {
text_style.font_family = font.family.clone();
text_style.font_features = font.features.clone();
if let Some(font) = segment.font {
text_style.font_family = font.family;
text_style.font_features = font.features;
text_style.font_style = font.style;
text_style.font_weight = font.weight;
}
text_style.color = Color::Muted.color(cx);
if index == 0 && !TabBarSettings::get_global(cx).show && active_item.is_dirty(cx) {
if let Some(styled_element) = apply_dirty_filename_style(&segment, &text_style, cx)
{
return styled_element;
}
}
StyledText::new(segment.text.replace('\n', ""))
.with_default_highlights(&text_style, segment.highlights.unwrap_or_default())
.into_any()
@@ -192,46 +184,3 @@ impl ToolbarItemView for Breadcrumbs {
self.pane_focused = pane_focused;
}
}
fn apply_dirty_filename_style(
segment: &BreadcrumbText,
text_style: &gpui::TextStyle,
cx: &mut Context<Breadcrumbs>,
) -> Option<gpui::AnyElement> {
let text = segment.text.replace('\n', "");
let filename_position = std::path::Path::new(&segment.text)
.file_name()
.and_then(|f| {
let filename_str = f.to_string_lossy();
segment.text.rfind(filename_str.as_ref())
})?;
let bold_weight = FontWeight::BOLD;
let default_color = Color::Default.color(cx);
if filename_position == 0 {
let mut filename_style = text_style.clone();
filename_style.font_weight = bold_weight;
filename_style.color = default_color;
return Some(
StyledText::new(text)
.with_default_highlights(&filename_style, [])
.into_any(),
);
}
let highlight_style = gpui::HighlightStyle {
font_weight: Some(bold_weight),
color: Some(default_color),
..Default::default()
};
let highlight = vec![(filename_position..text.len(), highlight_style)];
Some(
StyledText::new(text)
.with_default_highlights(&text_style, highlight)
.into_any(),
)
}

View File

@@ -20,7 +20,6 @@ test-support = ["sqlite"]
[dependencies]
anyhow.workspace = true
async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }

View File

@@ -17,8 +17,8 @@ use stripe::{
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
@@ -29,10 +29,6 @@ use crate::db::billing_subscription::{
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId,
};
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{
@@ -58,6 +54,10 @@ pub fn router() -> Router {
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
.route(
"/billing/subscriptions/migrate",
post(migrate_to_new_billing),
)
.route(
"/billing/subscriptions/sync",
post(sync_billing_subscription),
@@ -282,6 +282,7 @@ async fn list_billing_subscriptions(
enum ProductCode {
ZedPro,
ZedProTrial,
ZedFree,
}
#[derive(Debug, Deserialize)]
@@ -337,7 +338,8 @@ async fn create_billing_subscription(
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
@@ -352,7 +354,7 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
ProductCode::ZedPro => {
stripe_billing
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.await?
}
ProductCode::ZedProTrial => {
@@ -369,13 +371,18 @@ async fn create_billing_subscription(
stripe_billing
.checkout_with_zed_pro_trial(
&customer_id,
customer_id,
&user.github_login,
feature_flags,
&success_url,
)
.await?
}
ProductCode::ZedFree => {
stripe_billing
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
.await?
}
};
Ok(Json(CreateBillingSubscriptionResponse {
@@ -425,7 +432,7 @@ async fn manage_billing_subscription(
.await?
.context("user not found")?;
let Some(stripe_client) = app.real_stripe_client.clone() else {
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
@@ -491,10 +498,8 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id: stripe::PriceId =
stripe_billing.zed_pro_price_id().await?.try_into()?;
let zed_free_price_id: stripe::PriceId =
stripe_billing.zed_free_price_id().await?.try_into()?;
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
@@ -628,6 +633,86 @@ async fn manage_billing_subscription(
}))
}
#[derive(Debug, Deserialize)]
struct MigrateToNewBillingBody {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct MigrateToNewBillingResponse {
/// The ID of the subscription that was canceled.
canceled_subscription_id: Option<String>,
}
async fn migrate_to_new_billing(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
) -> Result<Json<MigrateToNewBillingResponse>> {
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.context("user not found")?;
let old_billing_subscriptions_by_user = app
.db
.get_active_billing_subscriptions(HashSet::from_iter([user.id]))
.await?;
let canceled_subscription_id = if let Some((_billing_customer, billing_subscription)) =
old_billing_subscriptions_by_user.get(&user.id)
{
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: Some(true),
..Default::default()
},
)
.await?;
Some(stripe_subscription_id)
} else {
None
};
let all_feature_flags = app.db.list_feature_flags().await?;
let user_feature_flags = app.db.get_user_flags(user.id).await?;
for feature_flag in ["new-billing", "assistant2"] {
let already_in_feature_flag = user_feature_flags.iter().any(|flag| flag == feature_flag);
if already_in_feature_flag {
continue;
}
let feature_flag = all_feature_flags
.iter()
.find(|flag| flag.flag == feature_flag)
.context("failed to find feature flag: {feature_flag:?}")?;
app.db.add_user_flag(user.id, feature_flag.id).await?;
}
Ok(Json(MigrateToNewBillingResponse {
canceled_subscription_id: canceled_subscription_id
.map(|subscription_id| subscription_id.to_string()),
}))
}
#[derive(Debug, Deserialize)]
struct SyncBillingSubscriptionBody {
github_user_id: i32,
@@ -661,13 +746,23 @@ async fn sync_billing_subscription(
.get_billing_customer_by_user_id(user.id)
.await?
.context("billing customer not found")?;
let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let subscriptions = stripe_client
.list_subscriptions_for_customer(&stripe_customer_id)
.await?;
let subscriptions = Subscription::list(
&stripe_client,
&stripe::ListSubscriptions {
customer: Some(stripe_customer_id),
// Sync all non-canceled subscriptions.
status: None,
..Default::default()
},
)
.await?;
for subscription in subscriptions {
for subscription in subscriptions.data {
let subscription_id = subscription.id.clone();
sync_subscription(&app, &stripe_client, subscription)
@@ -715,10 +810,6 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
@@ -729,7 +820,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
poll_stripe_events(&app, &rpc_server, &stripe_client)
.await
.log_err();
@@ -742,8 +833,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
@@ -775,7 +865,7 @@ async fn poll_stripe_events(
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
let mut event_pages = stripe::Event::list(&stripe_client, &params)
.await?
.paginate(params);
@@ -819,7 +909,7 @@ async fn poll_stripe_events(
break;
} else {
log::info!("Stripe events: retrieving next page");
event_pages = event_pages.next(&real_stripe_client).await?;
event_pages = event_pages.next(&stripe_client).await?;
}
} else {
break;
@@ -859,7 +949,7 @@ async fn poll_stripe_events(
let process_result = match event.type_ {
EventType::CustomerCreated | EventType::CustomerUpdated => {
handle_customer_event(app, real_stripe_client, event).await
handle_customer_event(app, stripe_client, event).await
}
EventType::CustomerSubscriptionCreated
| EventType::CustomerSubscriptionUpdated
@@ -934,8 +1024,8 @@ async fn handle_customer_event(
async fn sync_subscription(
app: &Arc<AppState>,
stripe_client: &Arc<dyn StripeClient>,
subscription: StripeSubscription,
stripe_client: &stripe::Client,
subscription: stripe::Subscription,
) -> anyhow::Result<billing_customer::Model> {
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
stripe_billing
@@ -946,7 +1036,7 @@ async fn sync_subscription(
};
let billing_customer =
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
find_or_create_billing_customer(app, stripe_client, subscription.customer)
.await?
.context("billing customer not found")?;
@@ -974,7 +1064,7 @@ async fn sync_subscription(
.as_ref()
.and_then(|details| details.reason)
.map_or(false, |reason| {
reason == StripeCancellationDetailsReason::PaymentFailed
reason == CancellationDetailsReason::PaymentFailed
});
if was_canceled_due_to_payment_failure {
@@ -991,7 +1081,7 @@ async fn sync_subscription(
if let Some(existing_subscription) = app
.db
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
.await?
{
app.db
@@ -1032,13 +1122,20 @@ async fn sync_subscription(
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
{
let stripe_subscription_id = StripeSubscriptionId(
existing_subscription.stripe_subscription_id.clone().into(),
);
let stripe_subscription_id = existing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
stripe_client
.cancel_subscription(&stripe_subscription_id)
.await?;
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: None,
..Default::default()
},
)
.await?;
} else {
// If the user already has an active billing subscription, ignore the
// event and return an `Ok` to signal that it was processed
@@ -1089,8 +1186,10 @@ async fn sync_subscription(
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
@@ -1105,7 +1204,7 @@ async fn sync_subscription(
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &Arc<dyn StripeClient>,
stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Subscription(subscription) = event.data.object else {
@@ -1114,7 +1213,7 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
// When the user's subscription changes, push down any changes to their plan.
rpc_server
@@ -1310,20 +1409,30 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &dyn StripeClient,
customer_id: &StripeCustomerId,
stripe_client: &stripe::Client,
customer_or_id: Expandable<Customer>,
) -> anyhow::Result<Option<billing_customer::Model>> {
let customer_id = match &customer_or_id {
Expandable::Id(id) => id,
Expandable::Object(customer) => customer.id.as_ref(),
};
// If we already have a billing customer record associated with the Stripe customer,
// there's nothing more we need to do.
if let Some(billing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
.get_billing_customer_by_stripe_customer_id(customer_id)
.await?
{
return Ok(Some(billing_customer));
}
let customer = stripe_client.get_customer(customer_id).await?;
// If all we have is a customer ID, resolve it to a full customer record by
// hitting the Stripe API.
let customer = match customer_or_id {
Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
Expandable::Object(customer) => *customer,
};
let Some(email) = customer.email else {
return Ok(None);
@@ -1433,10 +1542,14 @@ async fn sync_model_request_usage_with_stripe(
);
};
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription_id =
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
let model = llm_db.model_by_id(usage_meter.model_id)?;

View File

@@ -1,5 +1,4 @@
use crate::db::{BillingCustomerId, BillingSubscriptionId};
use crate::stripe_client;
use chrono::{Datelike as _, NaiveDate, Utc};
use sea_orm::entity::prelude::*;
use serde::Serialize;
@@ -160,17 +159,3 @@ pub enum StripeCancellationReason {
#[sea_orm(string_value = "payment_failed")]
PaymentFailed,
}
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
match value {
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
Self::CancellationRequested
}
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
Self::PaymentDisputed
}
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}

View File

@@ -9,7 +9,6 @@ pub mod migrations;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
pub mod stripe_client;
pub mod user_backfiller;
#[cfg(test)]
@@ -30,7 +29,6 @@ use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{RealStripeClient, StripeClient};
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -271,10 +269,7 @@ pub struct AppState {
pub llm_db: Option<Arc<LlmDatabase>>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
/// This is a real instance of the Stripe client; we're working to replace references to this with the
/// [`StripeClient`] trait.
pub real_stripe_client: Option<Arc<stripe::Client>>,
pub stripe_client: Option<Arc<dyn StripeClient>>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
@@ -327,9 +322,7 @@ impl AppState {
stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
real_stripe_client: stripe_client.clone(),
stripe_client: stripe_client
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
stripe_client,
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()

View File

@@ -5,7 +5,6 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@@ -4034,26 +4033,31 @@ async fn get_llm_api_token(
.as_ref()
.context("failed to retrieve Stripe billing object")?;
let billing_customer = if let Some(billing_customer) =
db.get_billing_customer_by_user_id(user.id).await?
{
billing_customer
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
let billing_customer =
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
billing_customer
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
find_or_create_billing_customer(
&session.app_state,
&stripe_client,
stripe::Expandable::Id(customer_id),
)
.await?
.context("billing customer not found")?
};
};
let billing_subscription =
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)

View File

@@ -1,49 +1,30 @@
use std::sync::Arc;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
};
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<dyn StripeClient>,
client: Arc<stripe::Client>,
}
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, StripePriceId>,
prices_by_lookup_key: HashMap<String, StripePrice>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
prices_by_lookup_key: HashMap<String, stripe::Price>,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
client,
state: RwLock::default(),
@@ -55,16 +36,24 @@ impl StripeBilling {
let mut state = self.state.write().await;
let (meters, prices) =
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.client),
stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
}
)
)?;
for meter in meters {
for meter in meters.data {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
}
for price in prices {
for price in prices.data {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
@@ -81,15 +70,15 @@ impl StripeBilling {
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
self.state
.read()
.await
@@ -99,7 +88,7 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
.await
@@ -111,12 +100,12 @@ impl StripeBilling {
pub async fn determine_subscription_kind(
&self,
subscription: &StripeSubscription,
subscription: &stripe::Subscription,
) -> Option<SubscriptionKind> {
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
subscription.items.iter().find_map(|item| {
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_pro_price_id {
@@ -140,11 +129,18 @@ impl StripeBilling {
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<StripeCustomerId> {
) -> Result<CustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = self.client.list_customers_by_email(email).await?;
let customers = Customer::list(
&self.client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
customers.first().cloned()
customers.data.first().cloned()
} else {
None
};
@@ -152,12 +148,14 @@ impl StripeBilling {
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = self
.client
.create_customer(crate::stripe_client::CreateCustomerParams {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: email_address,
})
.await?;
..Default::default()
},
)
.await?;
customer.id
};
@@ -167,10 +165,11 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &StripeSubscriptionId,
price: &StripePrice,
subscription_id: &stripe::SubscriptionId,
price: &stripe::Price,
) -> Result<()> {
let subscription = self.client.get_subscription(subscription_id).await?;
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
@@ -181,36 +180,39 @@ impl StripeBilling {
let price_per_unit = price.unit_amount.unwrap_or_default();
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
self.client
.update_subscription(
subscription_id,
UpdateSubscriptionParams {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
)
.await?;
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price.id.to_string()),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
..Default::default()
},
)
.await?;
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &StripeCustomerId,
customer_id: &stripe::CustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
self.client
.create_meter_event(StripeCreateMeterEventParams {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
@@ -218,37 +220,39 @@ impl StripeBilling {
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
})
.await?;
},
)
.await?;
Ok(())
}
pub async fn checkout_with_zed_pro(
&self,
customer_id: &StripeCustomerId,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let mut params = StripeCreateCheckoutSessionParams::default();
params.mode = Some(StripeCheckoutSessionMode::Subscription);
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = self.client.create_checkout_session(params).await?;
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro_trial(
&self,
customer_id: &StripeCustomerId,
customer_id: stripe::CustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
@@ -269,75 +273,172 @@ impl StripeBilling {
);
}
let mut params = StripeCreateCheckoutSessionParams::default();
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
trial_period_days: Some(trial_period_days),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
}
}),
metadata: if !subscription_metadata.is_empty() {
Some(subscription_metadata)
} else {
None
},
..Default::default()
});
params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = self.client.create_checkout_session(params).await?;
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn subscribe_to_zed_free(
&self,
customer_id: StripeCustomerId,
) -> Result<StripeSubscription> {
customer_id: stripe::CustomerId,
) -> Result<stripe::Subscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = self
.client
.list_subscriptions_for_customer(&customer_id)
.await?;
let existing_subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id.clone()),
status: None,
..Default::default()
},
)
.await?;
let existing_active_subscription =
existing_subscriptions.into_iter().find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
existing_subscriptions
.data
.into_iter()
.find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
let params = StripeCreateSubscriptionParams {
customer: customer_id,
items: vec![StripeCreateSubscriptionItems {
price: Some(zed_free_price_id),
quantity: Some(1),
}],
};
let mut params = stripe::CreateSubscription::new(customer_id);
params.items = Some(vec![stripe::CreateSubscriptionItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let subscription = self.client.create_subscription(params).await?;
let subscription = stripe::Subscription::create(&self.client, params).await?;
Ok(subscription)
}
pub async fn checkout_with_zed_free(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_free_price_id = self.zed_free_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
event_name: String,
}
impl StripeMeter {
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
client.get_query("/billing/meters", Params { limit: Some(100) })
}
}
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,
}
impl StripeMeterEvent {
pub async fn create(
client: &stripe::Client,
params: StripeCreateMeterEventParams<'_>,
) -> Result<Self, stripe::StripeError> {
let identifier = params.identifier;
match client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(Self {
identifier: identifier.to_string(),
})
} else {
Err(stripe::StripeError::Stripe(error))
}
}
Err(error) => Err(error),
}
}
}
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
identifier: &'a str,
event_name: &'a str,
payload: StripeCreateMeterEventPayload<'a>,
timestamp: Option<i64>,
}
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
value: u64,
stripe_customer_id: &'a stripe::CustomerId,
}
fn subscription_contains_price(
subscription: &StripeSubscription,
price_id: &StripePriceId,
subscription: &stripe::Subscription,
price_id: &stripe::PriceId,
) -> bool {
subscription.items.iter().any(|item| {
subscription.items.data.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)

View File

@@ -1,229 +0,0 @@
#[cfg(test)]
mod fake_stripe_client;
mod real_stripe_client;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripeCustomer {
pub id: StripeCustomerId,
pub email: Option<String>,
}
#[derive(Debug)]
pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
pub cancel_at: Option<i64>,
pub cancellation_details: Option<StripeCancellationDetails>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct StripeCancellationDetails {
pub reason: Option<StripeCancellationDetailsReason>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCancellationDetailsReason {
CancellationRequested,
PaymentDisputed,
PaymentFailed,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettings {
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
pub lookup_key: Option<String>,
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
pub struct StripeMeterId(pub Arc<str>);
#[derive(Debug, Clone, Deserialize)]
pub struct StripeMeter {
pub id: StripeMeterId,
pub event_name: String,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventParams<'a> {
pub identifier: &'a str,
pub event_name: &'a str,
pub payload: StripeCreateMeterEventPayload<'a>,
pub timestamp: Option<i64>,
}
#[derive(Debug, Serialize)]
pub struct StripeCreateMeterEventPayload<'a> {
pub value: u64,
pub stripe_customer_id: &'a StripeCustomerId,
}
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
pub client_reference_id: Option<&'a str>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionMode {
Payment,
Setup,
Subscription,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionLineItems {
pub price: Option<String>,
pub quantity: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionPaymentMethodCollection {
Always,
IfRequired,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionSubscriptionData {
pub metadata: Option<HashMap<String, String>>,
pub trial_period_days: Option<u32>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug)]
pub struct StripeCheckoutSession {
pub url: Option<String>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()>;
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession>;
}

View File

@@ -1,224 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
pub struct StripeCreateMeterEventCall {
pub identifier: Arc<str>,
pub event_name: Arc<str>,
pub value: u64,
pub stripe_customer_id: StripeCustomerId,
pub timestamp: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
pub customer: Option<StripeCustomerId>,
pub client_reference_id: Option<String>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
}
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
pub update_subscription_calls:
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
subscriptions: Arc::new(Mutex::new(HashMap::default())),
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
#[async_trait]
impl StripeClient for FakeStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
Ok(self
.customers
.lock()
.values()
.filter(|customer| customer.email.as_deref() == Some(email))
.cloned()
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
self.customers
.lock()
.get(customer_id)
.cloned()
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = StripeCustomer {
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
email: params.email.map(|email| email.to_string()),
};
self.customers
.lock()
.insert(customer.id.clone(), customer.clone());
Ok(customer)
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
self.subscriptions
.lock()
.get(subscription_id)
.cloned()
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
cancel_at: None,
cancellation_details: None,
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription = self.get_subscription(subscription_id).await?;
self.update_subscription_calls
.lock()
.push((subscription.id, params));
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
// TODO: Implement fake subscription cancellation.
let _ = subscription_id;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
Ok(prices)
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
let meters = self.meters.lock().values().cloned().collect();
Ok(meters)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
self.create_meter_event_calls
.lock()
.push(StripeCreateMeterEventCall {
identifier: params.identifier.into(),
event_name: params.event_name.into(),
value: params.payload.value,
stripe_customer_id: params.payload.stripe_customer_id.clone(),
timestamp: params.timestamp,
});
Ok(())
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
self.create_checkout_session_calls
.lock()
.push(StripeCreateCheckoutSessionCall {
customer: params.customer.cloned(),
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
mode: params.mode,
line_items: params.line_items,
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
});
Ok(StripeCheckoutSession {
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
})
}
}

View File

@@ -1,500 +0,0 @@
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::Serialize;
use stripe::{
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
use crate::stripe_client::{
CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason,
StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
}
impl RealStripeClient {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl StripeClient for RealStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
let response = Customer::list(
&self.client,
&ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
Ok(response
.data
.into_iter()
.map(StripeCustomer::from)
.collect())
}
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
let customer_id = customer_id.try_into()?;
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
Ok(StripeCustomer::from(customer))
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription> {
let subscription_id = subscription_id.try_into()?;
let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
Ok(StripeSubscription::from(subscription))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
params: UpdateSubscriptionParams,
) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
stripe::Subscription::update(
&self.client,
&subscription_id,
stripe::UpdateSubscription {
items: params.items.map(|items| {
items
.into_iter()
.map(|item| UpdateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
..Default::default()
})
.collect()
}),
trial_settings: params.trial_settings.map(Into::into),
..Default::default()
},
)
.await?;
Ok(())
}
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
let subscription_id = subscription_id.try_into()?;
Subscription::cancel(
&self.client,
&subscription_id,
stripe::CancelSubscription {
invoice_now: None,
..Default::default()
},
)
.await?;
Ok(())
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
},
)
.await?;
Ok(response.data.into_iter().map(StripePrice::from).collect())
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
let response = self
.client
.get_query::<stripe::List<StripeMeter>, _>(
"/billing/meters",
Params { limit: Some(100) },
)
.await?;
Ok(response.data)
}
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
let identifier = params.identifier;
match self.client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(())
} else {
Err(anyhow!(stripe::StripeError::Stripe(error)))
}
}
Err(error) => Err(anyhow!(error)),
}
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
let params = params.try_into()?;
let session = CheckoutSession::create(&self.client, params).await?;
Ok(session.into())
}
}
impl From<CustomerId> for StripeCustomerId {
fn from(value: CustomerId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl TryFrom<&StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
id: value.id.into(),
email: value.email,
}
}
}
impl From<SubscriptionId> for StripeSubscriptionId {
fn from(value: SubscriptionId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<&StripeSubscriptionId> for SubscriptionId {
type Error = anyhow::Error;
fn try_from(value: &StripeSubscriptionId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID")
}
}
impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
cancel_at: value.cancel_at,
cancellation_details: value.cancellation_details.map(Into::into),
}
}
}
impl From<CancellationDetails> for StripeCancellationDetails {
fn from(value: CancellationDetails) -> Self {
Self {
reason: value.reason.map(Into::into),
}
}
}
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
fn from(value: CancellationDetailsReason) -> Self {
match value {
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
}
}
}
impl From<SubscriptionItemId> for StripeSubscriptionItemId {
fn from(value: SubscriptionItemId) -> Self {
Self(value.as_str().into())
}
}
impl From<SubscriptionItem> for StripeSubscriptionItem {
fn from(value: SubscriptionItem) -> Self {
Self {
id: value.id.into(),
price: value.price.map(Into::into),
}
}
}
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripePriceId> for PriceId {
type Error = anyhow::Error;
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
}
}
impl From<Price> for StripePrice {
fn from(value: Price) -> Self {
Self {
id: value.id.into(),
unit_amount: value.unit_amount,
lookup_key: value.lookup_key,
recurring: value.recurring.map(StripePriceRecurring::from),
}
}
}
impl From<Recurring> for StripePriceRecurring {
fn from(value: Recurring) -> Self {
Self { meter: value.meter }
}
}
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
type Error = anyhow::Error;
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
Ok(Self {
customer: value
.customer
.map(|customer_id| customer_id.try_into())
.transpose()?,
client_reference_id: value.client_reference_id,
mode: value.mode.map(Into::into),
line_items: value
.line_items
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
..Default::default()
})
}
}
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
fn from(value: StripeCheckoutSessionMode) -> Self {
match value {
StripeCheckoutSessionMode::Payment => Self::Payment,
StripeCheckoutSessionMode::Setup => Self::Setup,
StripeCheckoutSessionMode::Subscription => Self::Subscription,
}
}
}
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
Self {
price: value.price,
quantity: value.quantity,
..Default::default()
}
}
}
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
match value {
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
}
}
}
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
Self {
trial_period_days: value.trial_period_days,
trial_settings: value.trial_settings.map(Into::into),
metadata: value.metadata,
..Default::default()
}
}
}
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<CheckoutSession> for StripeCheckoutSession {
fn from(value: CheckoutSession) -> Self {
Self { url: value.url }
}
}

View File

@@ -18,7 +18,6 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod stripe_billing_tests;
mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

View File

@@ -1010,6 +1010,7 @@ async fn test_peers_following_each_other(cx_a: &mut TestAppContext, cx_b: &mut T
workspace_b.update_in(cx_b, |workspace, window, cx| {
workspace.active_pane().update(cx, |pane, cx| {
pane.close_inactive_items(&Default::default(), window, cx)
.unwrap()
.detach();
});
});

View File

@@ -1,565 +0,0 @@
use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
let stripe_billing = StripeBilling::test(stripe_client.clone());
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test meters
let meter1 = StripeMeter {
id: StripeMeterId("meter_1".into()),
event_name: "event_1".to_string(),
};
let meter2 = StripeMeter {
id: StripeMeterId("meter_2".into()),
event_name: "event_2".to_string(),
};
stripe_client
.meters
.lock()
.insert(meter1.id.clone(), meter1);
stripe_client
.meters
.lock()
.insert(meter2.id.clone(), meter2);
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(1_000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
let price2 = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
let price3 = StripePrice {
id: StripePriceId("price_3".into()),
unit_amount: Some(500),
lookup_key: None,
recurring: Some(StripePriceRecurring {
meter: Some("meter_1".to_string()),
}),
};
stripe_client
.prices
.lock()
.insert(price1.id.clone(), price1);
stripe_client
.prices
.lock()
.insert(price2.id.clone(), price2);
stripe_client
.prices
.lock()
.insert(price3.id.clone(), price3);
// Initialize the billing system
stripe_billing.initialize().await.unwrap();
// Verify that prices can be found by lookup key
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
assert_eq!(zed_pro_price_id.to_string(), "price_1");
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
assert_eq!(zed_free_price_id.to_string(), "price_2");
// Verify that a price can be found by lookup key
let zed_pro_price = stripe_billing
.find_price_by_lookup_key("zed-pro")
.await
.unwrap();
assert_eq!(zed_pro_price.id.to_string(), "price_1");
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
// Verify that finding a non-existent lookup key returns an error
let result = stripe_billing
.find_price_by_lookup_key("non-existent")
.await;
assert!(result.is_err());
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Create a customer with an email that doesn't yet correspond to a customer.
{
let email = "user@example.com";
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
// Create a customer with an email that corresponds to an existing customer.
{
let email = "user2@example.com";
let existing_customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
assert_eq!(customer_id, existing_customer_id);
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
}
#[gpui::test]
async fn test_subscribe_to_price() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let price = StripePrice {
id: StripePriceId("price_test".into()),
unit_amount: Some(2000),
lookup_key: Some("test-price".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
let update_subscription_calls = stripe_client
.update_subscription_calls
.lock()
.iter()
.map(|(id, params)| (id.clone(), params.clone()))
.collect::<Vec<_>>();
assert_eq!(update_subscription_calls.len(), 1);
assert_eq!(update_subscription_calls[0].0, subscription.id);
assert_eq!(
update_subscription_calls[0].1.items,
Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone())
}])
);
// Subscribing to a price that is already on the subscription is a no-op.
{
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client
.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
stripe_billing
.subscribe_to_price(&subscription.id, &price)
.await
.unwrap();
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
}
}
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
cancel_at: None,
cancellation_details: None,
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
stripe_billing
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
.await
.unwrap();
let create_meter_event_calls = stripe_client
.create_meter_event_calls
.lock()
.iter()
.cloned()
.collect::<Vec<_>>();
assert_eq!(create_meter_event_calls.len(), 1);
assert!(
create_meter_event_calls[0]
.identifier
.starts_with("model_requests/")
);
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
assert_eq!(
create_meter_event_calls[0].event_name.as_ref(),
"some_model/requests"
);
assert_eq!(create_meter_event_calls[0].value, 73);
}
#[gpui::test]
async fn test_checkout_with_zed_pro() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
// Successful checkout.
{
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
let checkout_url = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(call.payment_method_collection, None);
assert_eq!(call.subscription_data, None);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}
#[gpui::test]
async fn test_checkout_with_zed_pro_trial() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
// Successful checkout.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer.as_ref(), Some(&customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(14),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: None,
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
// Successful checkout with extended trial.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(
&customer_id,
github_login,
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
success_url,
)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(60),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: Some(std::collections::HashMap::from_iter([(
"promo_feature_flag".into(),
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
)])),
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}

View File

@@ -1,4 +1,3 @@
use crate::stripe_client::FakeStripeClient;
use crate::{
AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
@@ -523,8 +522,7 @@ impl TestServer {
llm_db: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
real_stripe_client: None,
stripe_client: Some(Arc::new(FakeStripeClient::new())),
stripe_client: None,
stripe_billing: None,
executor,
kinesis_client: None,

View File

@@ -298,7 +298,6 @@ pub async fn download_adapter_from_github(
response.status().to_string()
);
delegate.output_to_console("Download complete".to_owned());
match file_type {
DownloadedFileType::GzipTar => {
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
@@ -370,19 +369,21 @@ pub trait DebugAdapter: 'static + Send + Sync {
None
}
/// Extracts the kind (attach/launch) of debug configuration from the given JSON config.
/// This method should only return error when the kind cannot be determined for a given configuration;
/// in particular, it *should not* validate whether the request as a whole is valid, because that's best left to the debug adapter itself to decide.
fn request_kind(
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
match config.get("request") {
Some(val) if val == "launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
Some(val) if val == "attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
_ => Err(anyhow!(
"missing or invalid `request` field in config. Expected 'launch' or 'attach'"
)),
let map = config.as_object().context("Config isn't an object")?;
let request_variant = map
.get("request")
.and_then(|val| val.as_str())
.context("request argument is not found or invalid")?;
match request_variant {
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
}
}
@@ -412,7 +413,7 @@ impl DebugAdapter for FakeAdapter {
serde_json::Value::Null
}
fn request_kind(
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
@@ -457,7 +458,7 @@ impl DebugAdapter for FakeAdapter {
envs: HashMap::default(),
cwd: None,
request_args: StartDebuggingRequestArguments {
request: self.request_kind(&task_definition.config)?,
request: self.validate_config(&task_definition.config)?,
configuration: task_definition.config.clone(),
},
})

View File

@@ -52,7 +52,7 @@ pub fn send_telemetry(scenario: &DebugScenario, location: TelemetrySpawnLocation
return;
};
let kind = adapter
.request_kind(&scenario.config)
.validate_config(&scenario.config)
.ok()
.map(serde_json::to_value)
.and_then(Result::ok);

View File

@@ -4,7 +4,7 @@ use dap_types::{
messages::{Message, Response},
};
use futures::{AsyncRead, AsyncReadExt as _, AsyncWrite, FutureExt as _, channel::oneshot, select};
use gpui::{AppContext as _, AsyncApp, Task};
use gpui::AsyncApp;
use settings::Settings as _;
use smallvec::SmallVec;
use smol::{
@@ -22,7 +22,7 @@ use std::{
time::Duration,
};
use task::TcpArgumentsTemplate;
use util::{ConnectionResult, ResultExt as _};
use util::{ResultExt as _, TryFutureExt};
use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
@@ -126,7 +126,7 @@ pub(crate) struct TransportDelegate {
pending_requests: Requests,
transport: Transport,
server_tx: Arc<Mutex<Option<Sender<Message>>>>,
_tasks: Vec<Task<()>>,
_tasks: Vec<gpui::Task<Option<()>>>,
}
impl TransportDelegate {
@@ -141,7 +141,7 @@ impl TransportDelegate {
log_handlers: Default::default(),
current_requests: Default::default(),
pending_requests: Default::default(),
_tasks: Vec::new(),
_tasks: Default::default(),
};
let messages = this.start_handlers(transport_pipes, cx).await?;
Ok((messages, this))
@@ -166,76 +166,45 @@ impl TransportDelegate {
None
};
let adapter_log_handler = log_handler.clone();
cx.update(|cx| {
if let Some(stdout) = params.stdout.take() {
self._tasks.push(cx.background_spawn(async move {
match Self::handle_adapter_log(stdout, adapter_log_handler).await {
ConnectionResult::Timeout => {
log::error!("Timed out when handling debugger log");
}
ConnectionResult::ConnectionReset => {
log::info!("Debugger logs connection closed");
}
ConnectionResult::Result(Ok(())) => {}
ConnectionResult::Result(Err(e)) => {
log::error!("Error handling debugger log: {e}");
}
}
}));
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()).log_err()),
);
}
let pending_requests = self.pending_requests.clone();
let output_log_handler = log_handler.clone();
self._tasks.push(cx.background_spawn(async move {
match Self::handle_output(
params.output,
client_tx,
pending_requests,
output_log_handler,
)
.await
{
Ok(()) => {}
Err(e) => log::error!("Error handling debugger output: {e}"),
}
}));
self._tasks.push(
cx.background_executor().spawn(
Self::handle_output(
params.output,
client_tx,
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
if let Some(stderr) = params.stderr.take() {
let log_handlers = self.log_handlers.clone();
self._tasks.push(cx.background_spawn(async move {
match Self::handle_error(stderr, log_handlers).await {
ConnectionResult::Timeout => {
log::error!("Timed out reading debugger error stream")
}
ConnectionResult::ConnectionReset => {
log::info!("Debugger closed its error stream")
}
ConnectionResult::Result(Ok(())) => {}
ConnectionResult::Result(Err(e)) => {
log::error!("Error handling debugger error: {e}")
}
}
}));
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_error(stderr, self.log_handlers.clone()).log_err()),
);
}
let current_requests = self.current_requests.clone();
let pending_requests = self.pending_requests.clone();
let log_handler = log_handler.clone();
self._tasks.push(cx.background_spawn(async move {
match Self::handle_input(
params.input,
client_rx,
current_requests,
pending_requests,
log_handler,
)
.await
{
Ok(()) => {}
Err(e) => log::error!("Error handling debugger input: {e}"),
}
}));
self._tasks.push(
cx.background_executor().spawn(
Self::handle_input(
params.input,
client_rx,
self.current_requests.clone(),
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
})?;
{
@@ -266,7 +235,7 @@ impl TransportDelegate {
async fn handle_adapter_log<Stdout>(
stdout: Stdout,
log_handlers: Option<LogHandlers>,
) -> ConnectionResult<()>
) -> Result<()>
where
Stdout: AsyncRead + Unpin + Send + 'static,
{
@@ -276,14 +245,13 @@ impl TransportDelegate {
let result = loop {
line.truncate(0);
match reader
.read_line(&mut line)
.await
.context("reading adapter log line")
{
Ok(0) => break ConnectionResult::ConnectionReset,
Ok(_) => {}
Err(e) => break ConnectionResult::Result(Err(e)),
let bytes_read = match reader.read_line(&mut line).await {
Ok(bytes_read) => bytes_read,
Err(e) => break Err(e.into()),
};
if bytes_read == 0 {
anyhow::bail!("Debugger log stream closed");
}
if let Some(log_handlers) = log_handlers.as_ref() {
@@ -369,35 +337,35 @@ impl TransportDelegate {
let mut reader = BufReader::new(server_stdout);
let result = loop {
match Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
.await
{
ConnectionResult::Timeout => anyhow::bail!("Timed out when connecting to debugger"),
ConnectionResult::ConnectionReset => {
log::info!("Debugger closed the connection");
return Ok(());
}
ConnectionResult::Result(Ok(Message::Response(res))) => {
let message =
Self::receive_server_message(&mut reader, &mut recv_buffer, log_handlers.as_ref())
.await;
match message {
Ok(Message::Response(res)) => {
if let Some(tx) = pending_requests.lock().await.remove(&res.request_seq) {
if let Err(e) = tx.send(Self::process_response(res)) {
log::trace!("Did not send response `{:?}` for a cancelled", e);
}
} else {
client_tx.send(Message::Response(res)).await?;
}
};
}
ConnectionResult::Result(Ok(message)) => client_tx.send(message).await?,
ConnectionResult::Result(Err(e)) => break Err(e),
Ok(message) => {
client_tx.send(message).await?;
}
Err(e) => break Err(e),
}
};
drop(client_tx);
log::debug!("Handle adapter output dropped");
result
}
async fn handle_error<Stderr>(stderr: Stderr, log_handlers: LogHandlers) -> ConnectionResult<()>
async fn handle_error<Stderr>(stderr: Stderr, log_handlers: LogHandlers) -> Result<()>
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
@@ -407,12 +375,8 @@ impl TransportDelegate {
let mut reader = BufReader::new(stderr);
let result = loop {
match reader
.read_line(&mut buffer)
.await
.context("reading error log line")
{
Ok(0) => break ConnectionResult::ConnectionReset,
match reader.read_line(&mut buffer).await {
Ok(0) => anyhow::bail!("debugger error stream closed"),
Ok(_) => {
for (kind, log_handler) in log_handlers.lock().iter_mut() {
if matches!(kind, LogKind::Adapter) {
@@ -422,7 +386,7 @@ impl TransportDelegate {
buffer.truncate(0);
}
Err(error) => break ConnectionResult::Result(Err(error)),
Err(error) => break Err(error.into()),
}
};
@@ -456,7 +420,7 @@ impl TransportDelegate {
reader: &mut BufReader<Stdout>,
buffer: &mut String,
log_handlers: Option<&LogHandlers>,
) -> ConnectionResult<Message>
) -> Result<Message>
where
Stdout: AsyncRead + Unpin + Send + 'static,
{
@@ -464,58 +428,48 @@ impl TransportDelegate {
loop {
buffer.truncate(0);
match reader
if reader
.read_line(buffer)
.await
.with_context(|| "reading a message from server")
.with_context(|| "reading a message from server")?
== 0
{
Ok(0) => return ConnectionResult::ConnectionReset,
Ok(_) => {}
Err(e) => return ConnectionResult::Result(Err(e)),
anyhow::bail!("debugger reader stream closed");
};
if buffer == "\r\n" {
break;
}
if let Some(("Content-Length", value)) = buffer.trim().split_once(": ") {
match value.parse().context("invalid content length") {
Ok(length) => content_length = Some(length),
Err(e) => return ConnectionResult::Result(Err(e)),
let parts = buffer.trim().split_once(": ");
match parts {
Some(("Content-Length", value)) => {
content_length = Some(value.parse().context("invalid content length")?);
}
_ => {}
}
}
let content_length = match content_length.context("missing content length") {
Ok(length) => length,
Err(e) => return ConnectionResult::Result(Err(e)),
};
let content_length = content_length.context("missing content length")?;
let mut content = vec![0; content_length];
if let Err(e) = reader
reader
.read_exact(&mut content)
.await
.with_context(|| "reading after a loop")
{
return ConnectionResult::Result(Err(e));
}
.with_context(|| "reading after a loop")?;
let message_str = match std::str::from_utf8(&content).context("invalid utf8 from server") {
Ok(str) => str,
Err(e) => return ConnectionResult::Result(Err(e)),
};
let message = std::str::from_utf8(&content).context("invalid utf8 from server")?;
if let Some(log_handlers) = log_handlers {
for (kind, log_handler) in log_handlers.lock().iter_mut() {
if matches!(kind, LogKind::Rpc) {
log_handler(IoKind::StdOut, message_str);
log_handler(IoKind::StdOut, &message);
}
}
}
ConnectionResult::Result(
serde_json::from_str::<Message>(message_str).context("deserializing server message"),
)
Ok(serde_json::from_str::<Message>(message)?)
}
pub async fn shutdown(&self) -> Result<()> {
@@ -704,13 +658,9 @@ impl StdioTransport {
.stderr(Stdio::piped())
.kill_on_drop(true);
let mut process = command.spawn().with_context(|| {
format!(
"failed to spawn command `{} {}`.",
binary.command,
binary.arguments.join(" ")
)
})?;
let mut process = command
.spawn()
.with_context(|| "failed to spawn command.")?;
let stdin = process.stdin.take().context("Failed to open stdin")?;
let stdout = process.stdout.take().context("Failed to open stdout")?;
@@ -823,55 +773,73 @@ impl FakeTransport {
let response_handlers = this.response_handlers.clone();
let stdout_writer = Arc::new(Mutex::new(stdout_writer));
cx.background_spawn(async move {
let mut reader = BufReader::new(stdin_reader);
let mut buffer = String::new();
cx.background_executor()
.spawn(async move {
let mut reader = BufReader::new(stdin_reader);
let mut buffer = String::new();
loop {
match TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
.await
{
ConnectionResult::Timeout => {
anyhow::bail!("Timed out when connecting to debugger");
}
ConnectionResult::ConnectionReset => {
log::info!("Debugger closed the connection");
break Ok(());
}
ConnectionResult::Result(Err(e)) => break Err(e),
ConnectionResult::Result(Ok(message)) => {
match message {
Message::Request(request) => {
// redirect reverse requests to stdout writer/reader
if request.command == RunInTerminal::COMMAND
|| request.command == StartDebugging::COMMAND
{
let message =
serde_json::to_string(&Message::Request(request)).unwrap();
loop {
let message =
TransportDelegate::receive_server_message(&mut reader, &mut buffer, None)
.await;
let mut writer = stdout_writer.lock().await;
writer
.write_all(
TransportDelegate::build_rpc_message(message)
.as_bytes(),
)
.await
.unwrap();
writer.flush().await.unwrap();
} else {
let response = if let Some(handle) =
request_handlers.lock().get_mut(request.command.as_str())
match message {
Err(error) => {
break anyhow::anyhow!(error);
}
Ok(message) => {
match message {
Message::Request(request) => {
// redirect reverse requests to stdout writer/reader
if request.command == RunInTerminal::COMMAND
|| request.command == StartDebugging::COMMAND
{
handle(request.seq, request.arguments.unwrap_or(json!({})))
} else {
panic!("No request handler for {}", request.command);
};
let message =
serde_json::to_string(&Message::Response(response))
let message =
serde_json::to_string(&Message::Request(request))
.unwrap();
let mut writer = stdout_writer.lock().await;
writer
.write_all(
TransportDelegate::build_rpc_message(message)
.as_bytes(),
)
.await
.unwrap();
writer.flush().await.unwrap();
} else {
let response = if let Some(handle) = request_handlers
.lock()
.get_mut(request.command.as_str())
{
handle(
request.seq,
request.arguments.unwrap_or(json!({})),
)
} else {
panic!("No request handler for {}", request.command);
};
let message =
serde_json::to_string(&Message::Response(response))
.unwrap();
let mut writer = stdout_writer.lock().await;
writer
.write_all(
TransportDelegate::build_rpc_message(message)
.as_bytes(),
)
.await
.unwrap();
writer.flush().await.unwrap();
}
}
Message::Event(event) => {
let message =
serde_json::to_string(&Message::Event(event)).unwrap();
let mut writer = stdout_writer.lock().await;
writer
.write_all(
TransportDelegate::build_rpc_message(message)
@@ -881,35 +849,21 @@ impl FakeTransport {
.unwrap();
writer.flush().await.unwrap();
}
}
Message::Event(event) => {
let message =
serde_json::to_string(&Message::Event(event)).unwrap();
let mut writer = stdout_writer.lock().await;
writer
.write_all(
TransportDelegate::build_rpc_message(message).as_bytes(),
)
.await
.unwrap();
writer.flush().await.unwrap();
}
Message::Response(response) => {
if let Some(handle) =
response_handlers.lock().get(response.command.as_str())
{
handle(response);
} else {
log::error!("No response handler for {}", response.command);
Message::Response(response) => {
if let Some(handle) =
response_handlers.lock().get(response.command.as_str())
{
handle(response);
} else {
log::error!("No response handler for {}", response.command);
}
}
}
}
}
}
}
})
.detach();
})
.detach();
Ok((
TransportPipe::new(Box::new(stdin_writer), Box::new(stdout_reader), None, None),

View File

@@ -1,8 +1,11 @@
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use dap::adapters::{DebugTaskDefinition, latest_github_release};
use dap::{
StartDebuggingRequestArgumentsRequest,
adapters::{DebugTaskDefinition, latest_github_release},
};
use futures::StreamExt;
use gpui::AsyncApp;
use serde_json::Value;
@@ -34,7 +37,7 @@ impl CodeLldbDebugAdapter {
Value::String(String::from(task_definition.label.as_ref())),
);
let request = self.request_kind(&configuration)?;
let request = self.validate_config(&configuration)?;
Ok(dap::StartDebuggingRequestArguments {
request,
@@ -86,6 +89,48 @@ impl DebugAdapter for CodeLldbDebugAdapter {
DebugAdapterName(Self::ADAPTER_NAME.into())
}
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
let map = config
.as_object()
.ok_or_else(|| anyhow!("Config isn't an object"))?;
let request_variant = map
.get("request")
.and_then(|r| r.as_str())
.ok_or_else(|| anyhow!("request field is required and must be a string"))?;
match request_variant {
"launch" => {
// For launch, verify that one of the required configs exists
if !(map.contains_key("program")
|| map.contains_key("targetCreateCommands")
|| map.contains_key("cargo"))
{
return Err(anyhow!(
"launch request requires either 'program', 'targetCreateCommands', or 'cargo' field"
));
}
Ok(StartDebuggingRequestArgumentsRequest::Launch)
}
"attach" => {
// For attach, verify that either pid or program exists
if !(map.contains_key("pid") || map.contains_key("program")) {
return Err(anyhow!(
"attach request requires either 'pid' or 'program' field"
));
}
Ok(StartDebuggingRequestArgumentsRequest::Attach)
}
_ => Err(anyhow!(
"request must be either 'launch' or 'attach', got '{}'",
request_variant
)),
}
}
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
let mut configuration = json!({
"request": match zed_scenario.request {

View File

@@ -37,7 +37,7 @@ pub fn init(cx: &mut App) {
registry.add_adapter(Arc::from(PhpDebugAdapter::default()));
registry.add_adapter(Arc::from(JsDebugAdapter::default()));
registry.add_adapter(Arc::from(RubyDebugAdapter));
registry.add_adapter(Arc::from(GoDebugAdapter::default()));
registry.add_adapter(Arc::from(GoDebugAdapter));
registry.add_adapter(Arc::from(GdbDebugAdapter));
#[cfg(any(test, feature = "test-support"))]

View File

@@ -178,7 +178,7 @@ impl DebugAdapter for GdbDebugAdapter {
let gdb_path = user_setting_path.unwrap_or(gdb_path?);
let request_args = StartDebuggingRequestArguments {
request: self.request_kind(&config.config)?,
request: self.validate_config(&config.config)?,
configuration: config.config.clone(),
};

View File

@@ -1,87 +1,22 @@
use anyhow::{Context as _, bail};
use anyhow::{Context as _, anyhow, bail};
use dap::{
StartDebuggingRequestArguments,
adapters::{
DebugTaskDefinition, DownloadedFileType, download_adapter_from_github,
latest_github_release,
},
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
adapters::DebugTaskDefinition,
};
use gpui::{AsyncApp, SharedString};
use language::LanguageName;
use std::{collections::HashMap, env::consts, ffi::OsStr, path::PathBuf, sync::OnceLock};
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
use util;
use crate::*;
#[derive(Default, Debug)]
pub(crate) struct GoDebugAdapter {
shim_path: OnceLock<PathBuf>,
}
pub(crate) struct GoDebugAdapter;
impl GoDebugAdapter {
const ADAPTER_NAME: &'static str = "Delve";
async fn fetch_latest_adapter_version(
delegate: &Arc<dyn DapDelegate>,
) -> Result<AdapterVersion> {
let release = latest_github_release(
&"zed-industries/delve-shim-dap",
true,
false,
delegate.http_client(),
)
.await?;
let os = match consts::OS {
"macos" => "apple-darwin",
"linux" => "unknown-linux-gnu",
"windows" => "pc-windows-msvc",
other => bail!("Running on unsupported os: {other}"),
};
let suffix = if consts::OS == "windows" {
".zip"
} else {
".tar.gz"
};
let asset_name = format!("delve-shim-dap-{}-{os}{suffix}", consts::ARCH);
let asset = release
.assets
.iter()
.find(|asset| asset.name == asset_name)
.with_context(|| format!("no asset found matching `{asset_name:?}`"))?;
Ok(AdapterVersion {
tag_name: release.tag_name,
url: asset.browser_download_url.clone(),
})
}
async fn install_shim(&self, delegate: &Arc<dyn DapDelegate>) -> anyhow::Result<PathBuf> {
if let Some(path) = self.shim_path.get().cloned() {
return Ok(path);
}
let asset = Self::fetch_latest_adapter_version(delegate).await?;
let ty = if consts::OS == "windows" {
DownloadedFileType::Zip
} else {
DownloadedFileType::GzipTar
};
download_adapter_from_github(
"delve-shim-dap".into(),
asset.clone(),
ty,
delegate.as_ref(),
)
.await?;
let path = paths::debug_adapters_dir()
.join("delve-shim-dap")
.join(format!("delve-shim-dap_{}", asset.tag_name))
.join(format!("delve-shim-dap{}", std::env::consts::EXE_SUFFIX));
self.shim_path.set(path.clone()).ok();
Ok(path)
}
const DEFAULT_TIMEOUT_MS: u64 = 60000;
}
#[async_trait(?Send)]
@@ -350,6 +285,24 @@ impl DebugAdapter for GoDebugAdapter {
})
}
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
let map = config.as_object().context("Config isn't an object")?;
let request_variant = map
.get("request")
.and_then(|val| val.as_str())
.context("request argument is not found or invalid")?;
match request_variant {
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
}
}
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
let mut args = match &zed_scenario.request {
dap::DebugRequest::Attach(attach_config) => {
@@ -396,15 +349,13 @@ impl DebugAdapter for GoDebugAdapter {
&self,
delegate: &Arc<dyn DapDelegate>,
task_definition: &DebugTaskDefinition,
user_installed_path: Option<PathBuf>,
_user_installed_path: Option<PathBuf>,
_cx: &mut AsyncApp,
) -> Result<DebugAdapterBinary> {
let adapter_path = paths::debug_adapters_dir().join(&Self::ADAPTER_NAME);
let dlv_path = adapter_path.join("dlv");
let delve_path = if let Some(path) = user_installed_path {
path.to_string_lossy().to_string()
} else if let Some(path) = delegate.which(OsStr::new("dlv")).await {
let delve_path = if let Some(path) = delegate.which(OsStr::new("dlv")).await {
path.to_string_lossy().to_string()
} else if delegate.fs().is_file(&dlv_path).await {
dlv_path.to_string_lossy().to_string()
@@ -433,10 +384,16 @@ impl DebugAdapter for GoDebugAdapter {
adapter_path.join("dlv").to_string_lossy().to_string()
};
let minidelve_path = self.install_shim(delegate).await?;
let tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
let (host, port, _) = crate::configure_tcp_connection(tcp_connection).await?;
let mut tcp_connection = task_definition.tcp_connection.clone().unwrap_or_default();
if tcp_connection.timeout.is_none()
|| tcp_connection.timeout.unwrap_or(0) < Self::DEFAULT_TIMEOUT_MS
{
tcp_connection.timeout = Some(Self::DEFAULT_TIMEOUT_MS);
}
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
let cwd = task_definition
.config
@@ -447,7 +404,6 @@ impl DebugAdapter for GoDebugAdapter {
let arguments = if cfg!(windows) {
vec![
delve_path,
"dap".into(),
"--listen".into(),
format!("{}:{}", host, port),
@@ -455,7 +411,6 @@ impl DebugAdapter for GoDebugAdapter {
]
} else {
vec![
delve_path,
"dap".into(),
"--listen".into(),
format!("{}:{}", host, port),
@@ -463,14 +418,18 @@ impl DebugAdapter for GoDebugAdapter {
};
Ok(DebugAdapterBinary {
command: minidelve_path.to_string_lossy().into_owned(),
command: delve_path,
arguments,
cwd: Some(cwd),
envs: HashMap::default(),
connection: None,
connection: Some(adapters::TcpArguments {
host,
port,
timeout,
}),
request_args: StartDebuggingRequestArguments {
configuration: task_definition.config.clone(),
request: self.request_kind(&task_definition.config)?,
request: self.validate_config(&task_definition.config)?,
},
})
}

View File

@@ -1,6 +1,9 @@
use adapters::latest_github_release;
use anyhow::Context as _;
use dap::{StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
use anyhow::{Context as _, anyhow};
use dap::{
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
adapters::DebugTaskDefinition,
};
use gpui::AsyncApp;
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
use task::DebugRequest;
@@ -23,7 +26,7 @@ impl JsDebugAdapter {
delegate: &Arc<dyn DapDelegate>,
) -> Result<AdapterVersion> {
let release = latest_github_release(
&format!("microsoft/{}", Self::ADAPTER_NPM_NAME),
&format!("{}/{}", "microsoft", Self::ADAPTER_NPM_NAME),
true,
false,
delegate.http_client(),
@@ -92,7 +95,7 @@ impl JsDebugAdapter {
}),
request_args: StartDebuggingRequestArguments {
configuration: task_definition.config.clone(),
request: self.request_kind(&task_definition.config)?,
request: self.validate_config(&task_definition.config)?,
},
})
}
@@ -104,6 +107,29 @@ impl DebugAdapter for JsDebugAdapter {
DebugAdapterName(Self::ADAPTER_NAME.into())
}
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<dap::StartDebuggingRequestArgumentsRequest> {
match config.get("request") {
Some(val) if val == "launch" => {
if config.get("program").is_none() && config.get("url").is_none() {
return Err(anyhow!(
"either program or url is required for launch request"
));
}
Ok(StartDebuggingRequestArgumentsRequest::Launch)
}
Some(val) if val == "attach" => {
if !config.get("processId").is_some_and(|val| val.is_u64()) {
return Err(anyhow!("processId must be a number"));
}
Ok(StartDebuggingRequestArgumentsRequest::Attach)
}
_ => Err(anyhow!("missing or invalid request field in config")),
}
}
fn config_from_zed_format(&self, zed_scenario: ZedDebugConfig) -> Result<DebugScenario> {
let mut args = json!({
"type": "pwa-node",
@@ -423,8 +449,6 @@ impl DebugAdapter for JsDebugAdapter {
delegate.as_ref(),
)
.await?;
} else {
delegate.output_to_console(format!("{} debug adapter is up to date", self.name()));
}
}

View File

@@ -94,7 +94,7 @@ impl PhpDebugAdapter {
envs: HashMap::default(),
request_args: StartDebuggingRequestArguments {
configuration: task_definition.config.clone(),
request: <Self as DebugAdapter>::request_kind(self, &task_definition.config)?,
request: <Self as DebugAdapter>::validate_config(self, &task_definition.config)?,
},
})
}
@@ -282,7 +282,10 @@ impl DebugAdapter for PhpDebugAdapter {
Some(SharedString::new_static("PHP").into())
}
fn request_kind(&self, _: &serde_json::Value) -> Result<StartDebuggingRequestArgumentsRequest> {
fn validate_config(
&self,
_: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
Ok(StartDebuggingRequestArgumentsRequest::Launch)
}

View File

@@ -1,6 +1,9 @@
use crate::*;
use anyhow::Context as _;
use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
use anyhow::{Context as _, anyhow};
use dap::{
DebugRequest, StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
adapters::DebugTaskDefinition,
};
use gpui::{AsyncApp, SharedString};
use json_dotpath::DotPaths;
use language::{LanguageName, Toolchain};
@@ -83,7 +86,7 @@ impl PythonDebugAdapter {
&self,
task_definition: &DebugTaskDefinition,
) -> Result<StartDebuggingRequestArguments> {
let request = self.request_kind(&task_definition.config)?;
let request = self.validate_config(&task_definition.config)?;
let mut configuration = task_definition.config.clone();
if let Ok(console) = configuration.dot_get_mut("console") {
@@ -251,6 +254,24 @@ impl DebugAdapter for PythonDebugAdapter {
})
}
fn validate_config(
&self,
config: &serde_json::Value,
) -> Result<StartDebuggingRequestArgumentsRequest> {
let map = config.as_object().context("Config isn't an object")?;
let request_variant = map
.get("request")
.and_then(|val| val.as_str())
.context("request is not valid")?;
match request_variant {
"launch" => Ok(StartDebuggingRequestArgumentsRequest::Launch),
"attach" => Ok(StartDebuggingRequestArgumentsRequest::Attach),
_ => Err(anyhow!("request must be either 'launch' or 'attach'")),
}
}
async fn dap_schema(&self) -> serde_json::Value {
json!({
"properties": {
@@ -639,7 +660,7 @@ impl DebugAdapter for PythonDebugAdapter {
}
}
self.get_installed_binary(delegate, &config, None, toolchain, false)
self.get_installed_binary(delegate, &config, None, None, false)
.await
}
}

View File

@@ -265,7 +265,7 @@ impl DebugAdapter for RubyDebugAdapter {
cwd: None,
envs: std::collections::HashMap::default(),
request_args: StartDebuggingRequestArguments {
request: self.request_kind(&definition.config)?,
request: self.validate_config(&definition.config)?,
configuration: definition.config.clone(),
},
})

View File

@@ -5,7 +5,7 @@ use crate::{
ClearAllBreakpoints, Continue, Detach, FocusBreakpointList, FocusConsole, FocusFrames,
FocusLoadedSources, FocusModules, FocusTerminal, FocusVariables, Pause, Restart,
ShowStackTrace, StepBack, StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints,
ToggleSessionPicker, ToggleThreadPicker, persistence, spawn_task_or_modal,
ToggleSessionPicker, ToggleThreadPicker, persistence,
};
use anyhow::{Context as _, Result, anyhow};
use command_palette_hooks::CommandPaletteFilter;
@@ -65,7 +65,6 @@ pub struct DebugPanel {
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
debug_scenario_scheduled_last: bool,
pub(crate) thread_picker_menu_handle: PopoverMenuHandle<ContextMenu>,
pub(crate) session_picker_menu_handle: PopoverMenuHandle<ContextMenu>,
fs: Arc<dyn Fs>,
@@ -104,7 +103,6 @@ impl DebugPanel {
thread_picker_menu_handle,
session_picker_menu_handle,
_subscriptions: [focus_subscription],
debug_scenario_scheduled_last: true,
}
})
}
@@ -266,7 +264,6 @@ impl DebugPanel {
cx,
)
});
self.debug_scenario_scheduled_last = true;
if let Some(inventory) = self
.project
.read(cx)
@@ -435,10 +432,7 @@ impl DebugPanel {
};
let dap_store_handle = self.project.read(cx).dap_store().clone();
let mut label = parent_session.read(cx).label().clone();
if !label.ends_with("(child)") {
label = format!("{label} (child)").into();
}
let label = parent_session.read(cx).label().clone();
let adapter = parent_session.read(cx).adapter().clone();
let mut binary = parent_session.read(cx).binary().clone();
binary.request_args = request.clone();
@@ -1384,30 +1378,4 @@ impl workspace::DebuggerProvider for DebuggerProvider {
})
})
}
fn spawn_task_or_modal(
&self,
workspace: &mut Workspace,
action: &tasks_ui::Spawn,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
spawn_task_or_modal(workspace, action, window, cx);
}
fn debug_scenario_scheduled(&self, cx: &mut App) {
self.0.update(cx, |this, _| {
this.debug_scenario_scheduled_last = true;
});
}
fn task_scheduled(&self, cx: &mut App) {
self.0.update(cx, |this, _| {
this.debug_scenario_scheduled_last = false;
})
}
fn debug_scenario_scheduled_last(&self, cx: &App) -> bool {
self.0.read(cx).debug_scenario_scheduled_last
}
}

View File

@@ -3,12 +3,11 @@ use debugger_panel::{DebugPanel, ToggleFocus};
use editor::Editor;
use feature_flags::{DebuggerFeatureFlag, FeatureFlagViewExt};
use gpui::{App, EntityInputHandler, actions};
use new_session_modal::{NewSessionModal, NewSessionMode};
use new_session_modal::NewSessionModal;
use project::debugger::{self, breakpoint_store::SourceBreakpoint};
use session::DebugSession;
use settings::Settings;
use stack_trace_view::StackTraceView;
use tasks_ui::{Spawn, TaskOverrides};
use util::maybe;
use workspace::{ItemHandle, ShutdownDebugAdapters, Workspace};
@@ -63,7 +62,6 @@ pub fn init(cx: &mut App) {
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |workspace, _, _| {
workspace
.register_action(spawn_task_or_modal)
.register_action(|workspace, _: &ToggleFocus, window, cx| {
workspace.toggle_panel_focus::<DebugPanel>(window, cx);
})
@@ -210,7 +208,7 @@ pub fn init(cx: &mut App) {
},
)
.register_action(|workspace: &mut Workspace, _: &Start, window, cx| {
NewSessionModal::show(workspace, window, NewSessionMode::Launch, None, cx);
NewSessionModal::show(workspace, window, cx);
})
.register_action(
|workspace: &mut Workspace, _: &RerunLastSession, window, cx| {
@@ -311,48 +309,3 @@ pub fn init(cx: &mut App) {
})
.detach();
}
fn spawn_task_or_modal(
workspace: &mut Workspace,
action: &Spawn,
window: &mut ui::Window,
cx: &mut ui::Context<Workspace>,
) {
match action {
Spawn::ByName {
task_name,
reveal_target,
} => {
let overrides = reveal_target.map(|reveal_target| TaskOverrides {
reveal_target: Some(reveal_target),
});
let name = task_name.clone();
tasks_ui::spawn_tasks_filtered(
move |(_, task)| task.label.eq(&name),
overrides,
window,
cx,
)
.detach_and_log_err(cx)
}
Spawn::ByTag {
task_tag,
reveal_target,
} => {
let overrides = reveal_target.map(|reveal_target| TaskOverrides {
reveal_target: Some(reveal_target),
});
let tag = task_tag.clone();
tasks_ui::spawn_tasks_filtered(
move |(_, task)| task.tags.contains(&tag),
overrides,
window,
cx,
)
.detach_and_log_err(cx)
}
Spawn::ViaModal { reveal_target } => {
NewSessionModal::show(workspace, window, NewSessionMode::Task, *reveal_target, cx);
}
}
}

View File

@@ -1,6 +1,4 @@
use std::time::Duration;
use gpui::{Animation, AnimationExt as _, Entity, Transformation, percentage};
use gpui::Entity;
use project::debugger::session::{ThreadId, ThreadStatus};
use ui::{ContextMenu, DropdownMenu, DropdownStyle, Indicator, prelude::*};
@@ -25,40 +23,31 @@ impl DebugPanel {
let sessions = self.sessions().clone();
let weak = cx.weak_entity();
let running_state = running_state.read(cx);
let label = if let Some(active_session) = active_session.clone() {
let label = if let Some(active_session) = active_session {
active_session.read(cx).session(cx).read(cx).label()
} else {
SharedString::new_static("Unknown Session")
};
let is_terminated = running_state.session().read(cx).is_terminated();
let is_started = active_session
.is_some_and(|session| session.read(cx).session(cx).read(cx).is_started());
let session_state_indicator = if is_terminated {
Indicator::dot().color(Color::Error).into_any_element()
} else if !is_started {
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Muted)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
)
.into_any_element()
} else {
match running_state.thread_status(cx).unwrap_or_default() {
ThreadStatus::Stopped => {
Indicator::dot().color(Color::Conflict).into_any_element()
let session_state_indicator = {
if is_terminated {
Some(Indicator::dot().color(Color::Error))
} else {
match running_state.thread_status(cx).unwrap_or_default() {
project::debugger::session::ThreadStatus::Stopped => {
Some(Indicator::dot().color(Color::Conflict))
}
_ => Some(Indicator::dot().color(Color::Success)),
}
_ => Indicator::dot().color(Color::Success).into_any_element(),
}
};
let trigger = h_flex()
.gap_2()
.child(session_state_indicator)
.when_some(session_state_indicator, |this, indicator| {
this.child(indicator)
})
.justify_between()
.child(
DebugPanel::dropdown_label(label)

View File

@@ -8,7 +8,6 @@ use std::{
time::Duration,
usize,
};
use tasks_ui::{TaskOverrides, TasksModal};
use dap::{
DapRegistry, DebugRequest, TelemetrySpawnLocation, adapters::DebugAdapterName, send_telemetry,
@@ -17,19 +16,19 @@ use editor::{Anchor, Editor, EditorElement, EditorStyle, scroll::Autoscroll};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
Animation, AnimationExt as _, App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, KeyContext, Render, Subscription, TextStyle, Transformation, WeakEntity, percentage,
Focusable, Render, Subscription, TextStyle, Transformation, WeakEntity, percentage,
};
use picker::{Picker, PickerDelegate, highlighted_match_with_paths::HighlightedMatch};
use project::{ProjectPath, TaskContexts, TaskSourceKind, task_store::TaskStore};
use settings::Settings;
use task::{DebugScenario, LaunchRequest, RevealTarget, ZedDebugConfig};
use task::{DebugScenario, LaunchRequest, ZedDebugConfig};
use theme::ThemeSettings;
use ui::{
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
ContextMenu, Disableable, DropdownMenu, FluentBuilder, Icon, IconButton, IconName, IconSize,
IconWithIndicator, Indicator, InteractiveElement, IntoElement, Label, LabelCommon as _,
ListItem, ListItemSpacing, ParentElement, RenderOnce, SharedString, Styled, StyledExt,
ToggleButton, ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
InteractiveElement, IntoElement, Label, LabelCommon as _, ListItem, ListItemSpacing,
ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton, ToggleState,
Toggleable, Window, div, h_flex, relative, rems, v_flex,
};
use util::ResultExt;
use workspace::{ModalView, Workspace, pane};
@@ -48,11 +47,10 @@ pub(super) struct NewSessionModal {
mode: NewSessionMode,
launch_picker: Entity<Picker<DebugScenarioDelegate>>,
attach_mode: Entity<AttachMode>,
configure_mode: Entity<ConfigureMode>,
task_mode: TaskMode,
custom_mode: Entity<CustomMode>,
debugger: Option<DebugAdapterName>,
save_scenario_state: Option<SaveScenarioState>,
_subscriptions: [Subscription; 3],
_subscriptions: [Subscription; 2],
}
fn suggested_label(request: &DebugRequest, debugger: &str) -> SharedString {
@@ -77,8 +75,6 @@ impl NewSessionModal {
pub(super) fn show(
workspace: &mut Workspace,
window: &mut Window,
mode: NewSessionMode,
reveal_target: Option<RevealTarget>,
cx: &mut Context<Workspace>,
) {
let Some(debug_panel) = workspace.panel::<DebugPanel>(cx) else {
@@ -88,50 +84,20 @@ impl NewSessionModal {
let languages = workspace.app_state().languages.clone();
cx.spawn_in(window, async move |workspace, cx| {
let task_contexts = workspace
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
let task_contexts = Arc::new(task_contexts);
workspace.update_in(cx, |workspace, window, cx| {
let workspace_handle = workspace.weak_handle();
workspace.toggle_modal(window, cx, |window, cx| {
let attach_mode = AttachMode::new(None, workspace_handle.clone(), window, cx);
let launch_picker = cx.new(|cx| {
let mut delegate =
DebugScenarioDelegate::new(debug_panel.downgrade(), task_store.clone());
delegate.task_contexts_loaded(task_contexts.clone(), languages, window, cx);
Picker::uniform_list(delegate, window, cx).modal(false)
Picker::uniform_list(
DebugScenarioDelegate::new(debug_panel.downgrade(), task_store),
window,
cx,
)
.modal(false)
});
let configure_mode = ConfigureMode::new(None, window, cx);
if let Some(active_cwd) = task_contexts
.active_context()
.and_then(|context| context.cwd.clone())
{
configure_mode.update(cx, |configure_mode, cx| {
configure_mode.load(active_cwd, window, cx);
});
}
let task_overrides = Some(TaskOverrides { reveal_target });
let task_mode = TaskMode {
task_modal: cx.new(|cx| {
TasksModal::new(
task_store.clone(),
task_contexts,
task_overrides,
false,
workspace_handle.clone(),
window,
cx,
)
}),
};
let _subscriptions = [
cx.subscribe(&launch_picker, |_, _, _, cx| {
cx.emit(DismissEvent);
@@ -142,18 +108,52 @@ impl NewSessionModal {
cx.emit(DismissEvent);
},
),
cx.subscribe(&task_mode.task_modal, |_, _, _: &DismissEvent, cx| {
cx.emit(DismissEvent)
}),
];
let custom_mode = CustomMode::new(None, window, cx);
cx.spawn_in(window, {
let workspace_handle = workspace_handle.clone();
async move |this, cx| {
let task_contexts = workspace_handle
.update_in(cx, |workspace, window, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})?
.await;
this.update_in(cx, |this, window, cx| {
if let Some(active_cwd) = task_contexts
.active_context()
.and_then(|context| context.cwd.clone())
{
this.custom_mode.update(cx, |custom, cx| {
custom.load(active_cwd, window, cx);
});
this.debugger = None;
}
this.launch_picker.update(cx, |picker, cx| {
picker.delegate.task_contexts_loaded(
task_contexts,
languages,
window,
cx,
);
picker.refresh(window, cx);
cx.notify();
});
})
}
})
.detach();
Self {
launch_picker,
attach_mode,
configure_mode,
task_mode,
custom_mode,
debugger: None,
mode,
mode: NewSessionMode::Launch,
debug_panel: debug_panel.downgrade(),
workspace: workspace_handle,
save_scenario_state: None,
@@ -170,17 +170,10 @@ impl NewSessionModal {
fn render_mode(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
let dap_menu = self.adapter_drop_down_menu(window, cx);
match self.mode {
NewSessionMode::Task => self
.task_mode
.task_modal
.read(cx)
.picker
.clone()
.into_any_element(),
NewSessionMode::Attach => self.attach_mode.update(cx, |this, cx| {
this.clone().render(window, cx).into_any_element()
}),
NewSessionMode::Configure => self.configure_mode.update(cx, |this, cx| {
NewSessionMode::Custom => self.custom_mode.update(cx, |this, cx| {
this.clone().render(dap_menu, window, cx).into_any_element()
}),
NewSessionMode::Launch => v_flex()
@@ -192,17 +185,16 @@ impl NewSessionModal {
fn mode_focus_handle(&self, cx: &App) -> FocusHandle {
match self.mode {
NewSessionMode::Task => self.task_mode.task_modal.focus_handle(cx),
NewSessionMode::Attach => self.attach_mode.read(cx).attach_picker.focus_handle(cx),
NewSessionMode::Configure => self.configure_mode.read(cx).program.focus_handle(cx),
NewSessionMode::Custom => self.custom_mode.read(cx).program.focus_handle(cx),
NewSessionMode::Launch => self.launch_picker.focus_handle(cx),
}
}
fn debug_scenario(&self, debugger: &str, cx: &App) -> Option<DebugScenario> {
let request = match self.mode {
NewSessionMode::Configure => Some(DebugRequest::Launch(
self.configure_mode.read(cx).debug_request(cx),
NewSessionMode::Custom => Some(DebugRequest::Launch(
self.custom_mode.read(cx).debug_request(cx),
)),
NewSessionMode::Attach => Some(DebugRequest::Attach(
self.attach_mode.read(cx).debug_request(),
@@ -211,8 +203,8 @@ impl NewSessionModal {
}?;
let label = suggested_label(&request, debugger);
let stop_on_entry = if let NewSessionMode::Configure = &self.mode {
Some(self.configure_mode.read(cx).stop_on_entry.selected())
let stop_on_entry = if let NewSessionMode::Custom = &self.mode {
Some(self.custom_mode.read(cx).stop_on_entry.selected())
} else {
None
};
@@ -535,8 +527,7 @@ static SELECT_DEBUGGER_LABEL: SharedString = SharedString::new_static("Select De
#[derive(Clone)]
pub(crate) enum NewSessionMode {
Task,
Configure,
Custom,
Attach,
Launch,
}
@@ -544,10 +535,9 @@ pub(crate) enum NewSessionMode {
impl std::fmt::Display for NewSessionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mode = match self {
NewSessionMode::Task => "Run",
NewSessionMode::Launch => "Debug",
NewSessionMode::Attach => "Attach",
NewSessionMode::Configure => "Configure Debugger",
NewSessionMode::Launch => "Launch".to_owned(),
NewSessionMode::Attach => "Attach".to_owned(),
NewSessionMode::Custom => "Custom".to_owned(),
};
write!(f, "{}", mode)
@@ -607,39 +597,36 @@ impl Render for NewSessionModal {
v_flex()
.size_full()
.w(rems(34.))
.key_context({
let mut key_context = KeyContext::new_with_defaults();
key_context.add("Pane");
key_context.add("RunModal");
key_context
})
.key_context("Pane")
.elevation_3(cx)
.bg(cx.theme().colors().elevated_surface_background)
.on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
cx.emit(DismissEvent);
}))
.on_action(cx.listener(|this, _: &pane::ActivateNextItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Task => NewSessionMode::Launch,
NewSessionMode::Launch => NewSessionMode::Attach,
NewSessionMode::Attach => NewSessionMode::Configure,
NewSessionMode::Configure => NewSessionMode::Task,
};
this.mode_focus_handle(cx).focus(window);
}))
.on_action(
cx.listener(|this, _: &pane::ActivatePreviousItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Task => NewSessionMode::Configure,
NewSessionMode::Launch => NewSessionMode::Task,
NewSessionMode::Attach => NewSessionMode::Launch,
NewSessionMode::Configure => NewSessionMode::Attach,
NewSessionMode::Launch => NewSessionMode::Attach,
_ => {
return;
}
};
this.mode_focus_handle(cx).focus(window);
}),
)
.on_action(cx.listener(|this, _: &pane::ActivateNextItem, window, cx| {
this.mode = match this.mode {
NewSessionMode::Attach => NewSessionMode::Launch,
NewSessionMode::Launch => NewSessionMode::Attach,
_ => {
return;
}
};
this.mode_focus_handle(cx).focus(window);
}))
.child(
h_flex()
.w_full()
@@ -650,73 +637,37 @@ impl Render for NewSessionModal {
.justify_start()
.w_full()
.child(
ToggleButton::new(
"debugger-session-ui-tasks-button",
NewSessionMode::Task.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Task))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Task;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
ToggleButton::new("debugger-session-ui-picker-button", "Launch")
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Launch))
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Launch;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.first(),
)
.child(
ToggleButton::new(
"debugger-session-ui-launch-button",
NewSessionMode::Launch.to_string(),
)
.size(ButtonSize::Default)
.style(ui::ButtonStyle::Subtle)
.toggle_state(matches!(self.mode, NewSessionMode::Launch))
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Launch;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.middle(),
)
.child(
ToggleButton::new(
"debugger-session-ui-attach-button",
NewSessionMode::Attach.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Attach))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Attach;
ToggleButton::new("debugger-session-ui-attach-button", "Attach")
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Attach))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Attach;
if let Some(debugger) = this.debugger.as_ref() {
Self::update_attach_picker(
&this.attach_mode,
&debugger,
window,
cx,
);
}
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.middle(),
)
.child(
ToggleButton::new(
"debugger-session-ui-custom-button",
NewSessionMode::Configure.to_string(),
)
.size(ButtonSize::Default)
.toggle_state(matches!(self.mode, NewSessionMode::Configure))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.mode = NewSessionMode::Configure;
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.last(),
if let Some(debugger) = this.debugger.as_ref() {
Self::update_attach_picker(
&this.attach_mode,
&debugger,
window,
cx,
);
}
this.mode_focus_handle(cx).focus(window);
cx.notify();
}))
.last(),
),
)
.justify_between()
@@ -724,83 +675,83 @@ impl Render for NewSessionModal {
.border_b_1(),
)
.child(v_flex().child(self.render_mode(window, cx)))
.map(|el| {
let container = h_flex()
.child(
h_flex()
.justify_between()
.gap_2()
.p_2()
.border_color(cx.theme().colors().border_variant)
.border_t_1()
.w_full();
match self.mode {
NewSessionMode::Configure => el.child(
container
.w_full()
.child(match self.mode {
NewSessionMode::Attach => {
div().child(self.adapter_drop_down_menu(window, cx))
}
NewSessionMode::Launch => div().child(
Button::new("new-session-modal-custom", "Custom").on_click({
let this = cx.weak_entity();
move |_, window, cx| {
this.update(cx, |this, cx| {
this.mode = NewSessionMode::Custom;
this.mode_focus_handle(cx).focus(window);
})
.ok();
}
}),
),
NewSessionMode::Custom => h_flex()
.child(
h_flex()
.child(
Button::new(
"new-session-modal-back",
"Save to .zed/debug.json...",
)
.on_click(cx.listener(|this, _, window, cx| {
this.save_debug_scenario(window, cx);
}))
.disabled(
self.debugger.is_none()
|| self
.configure_mode
.read(cx)
.program
.read(cx)
.is_empty(cx)
|| self.save_scenario_state.is_some(),
),
)
.child(self.render_save_state(cx)),
)
.child(
Button::new("debugger-spawn", "Start")
Button::new("new-session-modal-back", "Save to .zed/debug.json...")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx)
this.save_debug_scenario(window, cx);
}))
.disabled(
self.debugger.is_none()
|| self
.configure_mode
.custom_mode
.read(cx)
.program
.read(cx)
.is_empty(cx),
.is_empty(cx)
|| self.save_scenario_state.is_some(),
),
),
)
.child(self.render_save_state(cx)),
})
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| match &this.mode {
NewSessionMode::Launch => {
this.launch_picker.update(cx, |picker, cx| {
picker.delegate.confirm(true, window, cx)
})
}
_ => this.start_new_session(window, cx),
}))
.disabled(match self.mode {
NewSessionMode::Launch => {
!self.launch_picker.read(cx).delegate.matches.is_empty()
}
NewSessionMode::Attach => {
self.debugger.is_none()
|| self
.attach_mode
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
== 0
}
NewSessionMode::Custom => {
self.debugger.is_none()
|| self.custom_mode.read(cx).program.read(cx).is_empty(cx)
}
}),
),
NewSessionMode::Attach => el.child(
container
.child(div().child(self.adapter_drop_down_menu(window, cx)))
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx)
}))
.disabled(
self.debugger.is_none()
|| self
.attach_mode
.read(cx)
.attach_picker
.read(cx)
.picker
.read(cx)
.delegate
.match_count()
== 0,
),
),
),
NewSessionMode::Launch => el,
NewSessionMode::Task => el,
}
})
)
}
}
@@ -823,13 +774,13 @@ impl RenderOnce for AttachMode {
}
#[derive(Clone)]
pub(super) struct ConfigureMode {
pub(super) struct CustomMode {
program: Entity<Editor>,
cwd: Entity<Editor>,
stop_on_entry: ToggleState,
}
impl ConfigureMode {
impl CustomMode {
pub(super) fn new(
past_launch_config: Option<LaunchRequest>,
window: &mut Window,
@@ -989,11 +940,6 @@ impl AttachMode {
}
}
#[derive(Clone)]
pub(super) struct TaskMode {
pub(super) task_modal: Entity<TasksModal>,
}
pub(super) struct DebugScenarioDelegate {
task_store: Entity<TaskStore>,
candidates: Vec<(Option<TaskSourceKind>, DebugScenario)>,
@@ -1049,12 +995,12 @@ impl DebugScenarioDelegate {
pub fn task_contexts_loaded(
&mut self,
task_contexts: Arc<TaskContexts>,
task_contexts: TaskContexts,
languages: Arc<LanguageRegistry>,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) {
self.task_contexts = Some(task_contexts);
self.task_contexts = Some(Arc::new(task_contexts));
let (recent, scenarios) = self
.task_store
@@ -1222,32 +1168,21 @@ impl PickerDelegate for DebugScenarioDelegate {
let task_kind = &self.candidates[hit.candidate_id].0;
let icon = match task_kind {
Some(TaskSourceKind::Lsp(..)) => Some(Icon::new(IconName::BoltFilled)),
Some(TaskSourceKind::UserInput) => Some(Icon::new(IconName::Terminal)),
Some(TaskSourceKind::AbsPath { .. }) => Some(Icon::new(IconName::Settings)),
Some(TaskSourceKind::Worktree { .. }) => Some(Icon::new(IconName::FileTree)),
Some(TaskSourceKind::Lsp {
language_name: name,
..
})
| Some(TaskSourceKind::Language { name }) => file_icons::FileIcons::get(cx)
Some(TaskSourceKind::Language { name }) => file_icons::FileIcons::get(cx)
.get_icon_for_type(&name.to_lowercase(), cx)
.map(Icon::from_path),
None => Some(Icon::new(IconName::HistoryRerun)),
}
.map(|icon| icon.color(Color::Muted).size(IconSize::Small));
let indicator = if matches!(task_kind, Some(TaskSourceKind::Lsp { .. })) {
Some(Indicator::icon(
Icon::new(IconName::BoltFilled).color(Color::Muted),
))
} else {
None
};
let icon = icon.map(|icon| IconWithIndicator::new(icon, indicator));
.map(|icon| icon.color(Color::Muted).size(ui::IconSize::Small));
Some(
ListItem::new(SharedString::from(format!("debug-scenario-selection-{ix}")))
.inset(true)
.start_slot::<IconWithIndicator>(icon)
.start_slot::<Icon>(icon)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.child(highlighted_location.render(window, cx)),
@@ -1271,7 +1206,7 @@ pub(crate) fn resolve_path(path: &mut String) {
#[cfg(test)]
impl NewSessionModal {
pub(crate) fn set_configure(
pub(crate) fn set_custom(
&mut self,
program: impl AsRef<str>,
cwd: impl AsRef<str>,
@@ -1279,21 +1214,21 @@ impl NewSessionModal {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.mode = NewSessionMode::Configure;
self.mode = NewSessionMode::Custom;
self.debugger = Some(dap::adapters::DebugAdapterName("fake-adapter".into()));
self.configure_mode.update(cx, |configure, cx| {
configure.program.update(cx, |editor, cx| {
self.custom_mode.update(cx, |custom, cx| {
custom.program.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.set_text(program.as_ref(), window, cx);
});
configure.cwd.update(cx, |editor, cx| {
custom.cwd.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.set_text(cwd.as_ref(), window, cx);
});
configure.stop_on_entry = match stop_on_entry {
custom.stop_on_entry = match stop_on_entry {
true => ToggleState::Selected,
_ => ToggleState::Unselected,
}
@@ -1304,3 +1239,28 @@ impl NewSessionModal {
self.save_debug_scenario(window, cx);
}
}
#[cfg(test)]
mod tests {
use paths::home_dir;
#[test]
fn test_normalize_paths() {
let sep = std::path::MAIN_SEPARATOR;
let home = home_dir().to_string_lossy().to_string();
let resolve_path = |path: &str| -> String {
let mut path = path.to_string();
super::resolve_path(&mut path);
path
};
assert_eq!(resolve_path("bin"), format!("bin"));
assert_eq!(resolve_path(&format!("{sep}foo")), format!("{sep}foo"));
assert_eq!(resolve_path(""), format!(""));
assert_eq!(
resolve_path(&format!("~{sep}blah")),
format!("{home}{sep}blah")
);
assert_eq!(resolve_path("~"), home);
}
}

View File

@@ -61,28 +61,6 @@ impl DebuggerPaneItem {
DebuggerPaneItem::Terminal => SharedString::new_static("Terminal"),
}
}
pub(crate) fn tab_tooltip(self) -> SharedString {
let tooltip = match self {
DebuggerPaneItem::Console => {
"Displays program output and allows manual input of debugger commands."
}
DebuggerPaneItem::Variables => {
"Shows current values of local and global variables in the current stack frame."
}
DebuggerPaneItem::BreakpointList => "Lists all active breakpoints set in the code.",
DebuggerPaneItem::Frames => {
"Displays the call stack, letting you navigate between function calls."
}
DebuggerPaneItem::Modules => "Shows all modules or libraries loaded by the program.",
DebuggerPaneItem::LoadedSources => {
"Lists all source files currently loaded and used by the debugger."
}
DebuggerPaneItem::Terminal => {
"Provides an interactive terminal session within the debugging environment."
}
};
SharedString::new_static(tooltip)
}
}
impl From<DebuggerPaneItem> for SharedString {

View File

@@ -173,10 +173,6 @@ impl Item for SubView {
self.kind.to_shared_string()
}
fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> {
Some(self.kind.tab_tooltip())
}
fn tab_content(
&self,
params: workspace::item::TabContentParams,
@@ -403,9 +399,6 @@ pub(crate) fn new_debugger_pane(
.p_1()
.rounded_md()
.cursor_pointer()
.when_some(item.tab_tooltip_text(cx), |this, tooltip| {
this.tooltip(Tooltip::text(tooltip))
})
.map(|this| {
let theme = cx.theme();
if selected {
@@ -812,7 +805,7 @@ impl RunningState {
let request_type = dap_registry
.adapter(&adapter)
.ok_or_else(|| anyhow!("{}: is not a valid adapter name", &adapter))
.and_then(|adapter| adapter.request_kind(&config));
.and_then(|adapter| adapter.validate_config(&config));
let config_is_valid = request_type.is_ok();
@@ -881,6 +874,7 @@ impl RunningState {
args,
..task.resolved.clone()
};
let terminal = project
.update_in(cx, |project, window, cx| {
project.create_terminal(
@@ -925,12 +919,6 @@ impl RunningState {
};
if config_is_valid {
// Ok(DebugTaskDefinition {
// label,
// adapter: DebugAdapterName(adapter),
// config,
// tcp_connection,
// })
} else if let Some((task, locator_name)) = build_output {
let locator_name =
locator_name.context("Could not find a valid locator for a build task")?;
@@ -949,15 +937,12 @@ impl RunningState {
let scenario = dap_registry
.adapter(&adapter)
.ok_or_else(|| anyhow!("{}: is not a valid adapter name", &adapter))
.context(format!("{}: is not a valid adapter name", &adapter))
.map(|adapter| adapter.config_from_zed_format(zed_config))??;
config = scenario.config;
Self::substitute_variables_in_config(&mut config, &task_context);
} else {
let Err(e) = request_type else {
unreachable!();
};
anyhow::bail!("Zed cannot determine how to run this debug scenario. `build` field was not provided and Debug Adapter won't accept provided configuration because: {e}");
anyhow::bail!("No request or build provided");
};
Ok(DebugTaskDefinition {

View File

@@ -110,7 +110,7 @@ impl Console {
}
fn is_running(&self, cx: &Context<Self>) -> bool {
self.session.read(cx).is_running()
self.session.read(cx).is_local()
}
fn handle_stack_frame_list_events(
@@ -176,18 +176,16 @@ impl Console {
}
fn render_console(&self, cx: &Context<Self>) -> impl IntoElement {
EditorElement::new(&self.console, Self::editor_style(&self.console, cx))
EditorElement::new(&self.console, self.editor_style(cx))
}
fn editor_style(editor: &Entity<Editor>, cx: &Context<Self>) -> EditorStyle {
let is_read_only = editor.read(cx).read_only(cx);
fn editor_style(&self, cx: &Context<Self>) -> EditorStyle {
let settings = ThemeSettings::get_global(cx);
let theme = cx.theme();
let text_style = TextStyle {
color: if is_read_only {
theme.colors().text_muted
color: if self.console.read(cx).read_only(cx) {
cx.theme().colors().text_disabled
} else {
theme.colors().text
cx.theme().colors().text
},
font_family: settings.buffer_font.family.clone(),
font_features: settings.buffer_font.features.clone(),
@@ -197,15 +195,15 @@ impl Console {
..Default::default()
};
EditorStyle {
background: theme.colors().editor_background,
local_player: theme.players().local(),
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
}
}
fn render_query_bar(&self, cx: &Context<Self>) -> impl IntoElement {
EditorElement::new(&self.query_bar, Self::editor_style(&self.query_bar, cx))
EditorElement::new(&self.query_bar, self.editor_style(cx))
}
fn update_output(&mut self, window: &mut Window, cx: &mut Context<Self>) {

View File

@@ -250,6 +250,9 @@ impl StackFrameList {
let Some(abs_path) = Self::abs_path_from_stack_frame(&stack_frame) else {
return Task::ready(Err(anyhow!("Project path not found")));
};
if abs_path.starts_with("<node_internals>") {
return Task::ready(Ok(()));
}
let row = stack_frame.line.saturating_sub(1) as u32;
cx.emit(StackFrameListEvent::SelectedStackFrameChanged(
stack_frame_id,
@@ -342,7 +345,6 @@ impl StackFrameList {
s.path
.as_deref()
.map(|path| Arc::<Path>::from(Path::new(path)))
.filter(|path| path.is_absolute())
})
}

View File

@@ -7,7 +7,6 @@ use std::sync::atomic::{AtomicBool, Ordering};
use task::{DebugRequest, DebugScenario, LaunchRequest, TaskContext, VariableName, ZedDebugConfig};
use util::path;
use crate::new_session_modal::NewSessionMode;
use crate::tests::{init_test, init_test_workspace};
#[gpui::test]
@@ -171,13 +170,7 @@ async fn test_save_debug_scenario_to_file(executor: BackgroundExecutor, cx: &mut
workspace
.update(cx, |workspace, window, cx| {
crate::new_session_modal::NewSessionModal::show(
workspace,
window,
NewSessionMode::Launch,
None,
cx,
);
crate::new_session_modal::NewSessionModal::show(workspace, window, cx);
})
.unwrap();
@@ -191,7 +184,7 @@ async fn test_save_debug_scenario_to_file(executor: BackgroundExecutor, cx: &mut
.expect("Modal should be active");
modal.update_in(cx, |modal, window, cx| {
modal.set_configure("/project/main", "/project", false, window, cx);
modal.set_custom("/project/main", "/project", false, window, cx);
modal.save_scenario(window, cx);
});
@@ -220,7 +213,7 @@ async fn test_save_debug_scenario_to_file(executor: BackgroundExecutor, cx: &mut
pretty_assertions::assert_eq!(expected_content, actual_lines);
modal.update_in(cx, |modal, window, cx| {
modal.set_configure("/project/other", "/project", true, window, cx);
modal.set_custom("/project/other", "/project", true, window, cx);
modal.save_scenario(window, cx);
});
@@ -322,7 +315,7 @@ async fn test_dap_adapter_config_conversion_and_validation(cx: &mut TestAppConte
);
let request_type = adapter
.request_kind(&debug_scenario.config)
.validate_config(&debug_scenario.config)
.unwrap_or_else(|_| {
panic!(
"Adapter {} should validate the config successfully",

View File

@@ -936,8 +936,6 @@ pub struct Editor {
select_next_state: Option<SelectNextState>,
select_prev_state: Option<SelectNextState>,
selection_history: SelectionHistory,
defer_selection_effects: bool,
deferred_selection_effects_state: Option<DeferredSelectionEffectsState>,
autoclose_regions: Vec<AutocloseRegion>,
snippet_stack: InvalidationStack<SnippetState>,
select_syntax_node_history: SelectSyntaxNodeHistory,
@@ -1197,14 +1195,6 @@ impl Default for SelectionHistoryMode {
}
}
struct DeferredSelectionEffectsState {
changed: bool,
show_completions: bool,
autoscroll: Option<Autoscroll>,
old_cursor_position: Anchor,
history_entry: SelectionHistoryEntry,
}
#[derive(Default)]
struct SelectionHistory {
#[allow(clippy::type_complexity)]
@@ -1801,8 +1791,6 @@ impl Editor {
select_next_state: None,
select_prev_state: None,
selection_history: SelectionHistory::default(),
defer_selection_effects: false,
deferred_selection_effects_state: None,
autoclose_regions: Vec::new(),
snippet_stack: InvalidationStack::default(),
select_syntax_node_history: SelectSyntaxNodeHistory::default(),
@@ -2966,9 +2954,6 @@ impl Editor {
Subscription::join(other_subscription, this_subscription)
}
/// Changes selections using the provided mutation function. Changes to `self.selections` occur
/// immediately, but when run within `transact` or `with_selection_effects_deferred` other
/// effects of selection change occur at the end of the transaction.
pub fn change_selections<R>(
&mut self,
autoscroll: Option<Autoscroll>,
@@ -2976,105 +2961,39 @@ impl Editor {
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
self.change_selections_inner(true, autoscroll, window, cx, change)
}
pub(crate) fn change_selections_without_showing_completions<R>(
&mut self,
autoscroll: Option<Autoscroll>,
window: &mut Window,
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
self.change_selections_inner(false, autoscroll, window, cx, change)
self.change_selections_inner(autoscroll, true, window, cx, change)
}
fn change_selections_inner<R>(
&mut self,
show_completions: bool,
autoscroll: Option<Autoscroll>,
request_completions: bool,
window: &mut Window,
cx: &mut Context<Self>,
change: impl FnOnce(&mut MutableSelectionsCollection<'_>) -> R,
) -> R {
if let Some(state) = &mut self.deferred_selection_effects_state {
state.autoscroll = autoscroll.or(state.autoscroll);
state.show_completions = show_completions;
let (changed, result) = self.selections.change_with(cx, change);
state.changed |= changed;
return result;
}
let mut state = DeferredSelectionEffectsState {
changed: false,
show_completions,
autoscroll,
old_cursor_position: self.selections.newest_anchor().head(),
history_entry: SelectionHistoryEntry {
selections: self.selections.disjoint_anchors(),
select_next_state: self.select_next_state.clone(),
select_prev_state: self.select_prev_state.clone(),
add_selections_state: self.add_selections_state.clone(),
},
};
let old_cursor_position = self.selections.newest_anchor().head();
self.push_to_selection_history();
let (changed, result) = self.selections.change_with(cx, change);
state.changed = state.changed || changed;
if self.defer_selection_effects {
self.deferred_selection_effects_state = Some(state);
} else {
self.apply_selection_effects(state, window, cx);
}
result
}
/// Defers the effects of selection change, so that the effects of multiple calls to
/// `change_selections` are applied at the end. This way these intermediate states aren't added
/// to selection history and the state of popovers based on selection position aren't
/// erroneously updated.
pub fn with_selection_effects_deferred<R>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
update: impl FnOnce(&mut Self, &mut Window, &mut Context<Self>) -> R,
) -> R {
let already_deferred = self.defer_selection_effects;
self.defer_selection_effects = true;
let result = update(self, window, cx);
if !already_deferred {
self.defer_selection_effects = false;
if let Some(state) = self.deferred_selection_effects_state.take() {
self.apply_selection_effects(state, window, cx);
}
}
result
}
fn apply_selection_effects(
&mut self,
state: DeferredSelectionEffectsState,
window: &mut Window,
cx: &mut Context<Self>,
) {
if state.changed {
self.selection_history.push(state.history_entry);
if let Some(autoscroll) = state.autoscroll {
if changed {
if let Some(autoscroll) = autoscroll {
self.request_autoscroll(autoscroll, cx);
}
self.selections_did_change(true, &old_cursor_position, request_completions, window, cx);
let old_cursor_position = &state.old_cursor_position;
self.selections_did_change(
true,
if self.should_open_signature_help_automatically(
&old_cursor_position,
state.show_completions,
window,
self.signature_help_state.backspace_pressed(),
cx,
);
if self.should_open_signature_help_automatically(&old_cursor_position, cx) {
) {
self.show_signature_help(&ShowSignatureHelp, window, cx);
}
self.signature_help_state.set_backspace_pressed(false);
}
result
}
pub fn edit<I, S, T>(&mut self, edits: I, cx: &mut Context<Self>)
@@ -3958,12 +3877,9 @@ impl Editor {
}
let had_active_inline_completion = this.has_active_inline_completion();
this.change_selections_without_showing_completions(
Some(Autoscroll::fit()),
window,
cx,
|s| s.select(new_selections),
);
this.change_selections_inner(Some(Autoscroll::fit()), false, window, cx, |s| {
s.select(new_selections)
});
if !bracket_inserted {
if let Some(on_type_format_task) =
@@ -5368,6 +5284,7 @@ impl Editor {
mat.candidate_id
};
let buffer_handle = completions_menu.buffer;
let completion = completions_menu
.completions
.borrow()
@@ -5375,23 +5292,34 @@ impl Editor {
.clone();
cx.stop_propagation();
let buffer_handle = completions_menu.buffer;
let CompletionEdit {
new_text,
snippet,
replace_range,
} = process_completion_for_edit(
&completion,
intent,
&buffer_handle,
&completions_menu.initial_position.text_anchor,
cx,
);
let buffer = buffer_handle.read(cx);
let snapshot = self.buffer.read(cx).snapshot(cx);
let newest_anchor = self.selections.newest_anchor();
let snippet;
let new_text;
if completion.is_snippet() {
let mut snippet_source = completion.new_text.clone();
if let Some(scope) = snapshot.language_scope_at(newest_anchor.head()) {
if scope.prefers_label_for_snippet_in_completion() {
if let Some(label) = completion.label() {
if matches!(
completion.kind(),
Some(CompletionItemKind::FUNCTION) | Some(CompletionItemKind::METHOD)
) {
snippet_source = label;
}
}
}
}
snippet = Some(Snippet::parse(&snippet_source).log_err()?);
new_text = snippet.as_ref().unwrap().text.clone();
} else {
snippet = None;
new_text = completion.new_text.clone();
};
let replace_range = choose_completion_range(&completion, intent, &buffer_handle, cx);
let buffer = buffer_handle.read(cx);
let replace_range_multibuffer = {
let excerpt = snapshot.excerpt_containing(newest_anchor.range()).unwrap();
let multibuffer_anchor = snapshot
@@ -9105,6 +9033,7 @@ impl Editor {
}
}
this.signature_help_state.set_backspace_pressed(true);
this.change_selections(Some(Autoscroll::fit()), window, cx, |s| {
s.select(selections)
});
@@ -12826,6 +12755,7 @@ impl Editor {
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
self.select_next_match_internal(&display_map, false, None, window, cx)?;
@@ -12878,6 +12808,7 @@ impl Editor {
cx: &mut Context<Self>,
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
self.select_next_match_internal(
&display_map,
@@ -12896,6 +12827,7 @@ impl Editor {
cx: &mut Context<Self>,
) -> Result<()> {
self.hide_mouse_cursor(&HideMouseCursorOrigin::MovementAction);
self.push_to_selection_history();
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let buffer = &display_map.buffer_snapshot;
let mut selections = self.selections.all::<usize>(cx);
@@ -15089,7 +15021,7 @@ impl Editor {
text_style = text_style.highlight(highlight_style);
}
div()
.block_mouse_except_scroll()
.block_mouse_down()
.pl(cx.anchor_x)
.child(EditorElement::new(
&rename_editor,
@@ -15691,14 +15623,15 @@ impl Editor {
None
};
self.inline_diagnostics_update = cx.spawn_in(window, async move |editor, cx| {
let editor = editor.upgrade().unwrap();
if let Some(debounce) = debounce {
cx.background_executor().timer(debounce).await;
}
let Some(snapshot) = editor.upgrade().and_then(|editor| {
editor
.update(cx, |editor, cx| editor.buffer().read(cx).snapshot(cx))
.ok()
}) else {
let Some(snapshot) = editor
.update(cx, |editor, cx| editor.buffer().read(cx).snapshot(cx))
.ok()
else {
return;
};
@@ -15764,17 +15697,24 @@ impl Editor {
self.selections_did_change(false, &old_cursor_position, true, window, cx);
}
fn push_to_selection_history(&mut self) {
self.selection_history.push(SelectionHistoryEntry {
selections: self.selections.disjoint_anchors(),
select_next_state: self.select_next_state.clone(),
select_prev_state: self.select_prev_state.clone(),
add_selections_state: self.add_selections_state.clone(),
});
}
pub fn transact(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
update: impl FnOnce(&mut Self, &mut Window, &mut Context<Self>),
) -> Option<TransactionId> {
self.with_selection_effects_deferred(window, cx, |this, window, cx| {
this.start_transaction_at(Instant::now(), window, cx);
update(this, window, cx);
this.end_transaction_at(Instant::now(), cx)
})
self.start_transaction_at(Instant::now(), window, cx);
update(self, window, cx);
self.end_transaction_at(Instant::now(), cx)
}
pub fn start_transaction_at(
@@ -18695,20 +18635,16 @@ impl Editor {
}
let minimap_settings = EditorSettings::get_global(cx).minimap;
if self.minimap_visibility != MinimapVisibility::Disabled {
if self.minimap_visibility.settings_visibility()
!= minimap_settings.minimap_enabled()
{
self.set_minimap_visibility(
MinimapVisibility::for_mode(self.mode(), cx),
window,
cx,
);
} else if let Some(minimap_entity) = self.minimap.as_ref() {
minimap_entity.update(cx, |minimap_editor, cx| {
minimap_editor.update_minimap_configuration(minimap_settings, cx)
})
}
if self.minimap_visibility.settings_visibility() != minimap_settings.minimap_enabled() {
self.set_minimap_visibility(
MinimapVisibility::for_mode(self.mode(), cx),
window,
cx,
);
} else if let Some(minimap_entity) = self.minimap.as_ref() {
minimap_entity.update(cx, |minimap_editor, cx| {
minimap_editor.update_minimap_configuration(minimap_settings, cx)
})
}
}
@@ -19502,152 +19438,79 @@ fn vim_enabled(cx: &App) -> bool {
== Some(&serde_json::Value::Bool(true))
}
fn process_completion_for_edit(
// Consider user intent and default settings
fn choose_completion_range(
completion: &Completion,
intent: CompletionIntent,
buffer: &Entity<Buffer>,
cursor_position: &text::Anchor,
cx: &mut Context<Editor>,
) -> CompletionEdit {
) -> Range<usize> {
fn should_replace(
completion: &Completion,
insert_range: &Range<text::Anchor>,
intent: CompletionIntent,
completion_mode_setting: LspInsertMode,
buffer: &Buffer,
) -> bool {
// specific actions take precedence over settings
match intent {
CompletionIntent::CompleteWithInsert => return false,
CompletionIntent::CompleteWithReplace => return true,
CompletionIntent::Complete | CompletionIntent::Compose => {}
}
match completion_mode_setting {
LspInsertMode::Insert => false,
LspInsertMode::Replace => true,
LspInsertMode::ReplaceSubsequence => {
let mut text_to_replace = buffer.chars_for_range(
buffer.anchor_before(completion.replace_range.start)
..buffer.anchor_after(completion.replace_range.end),
);
let mut completion_text = completion.new_text.chars();
// is `text_to_replace` a subsequence of `completion_text`
text_to_replace
.all(|needle_ch| completion_text.any(|haystack_ch| haystack_ch == needle_ch))
}
LspInsertMode::ReplaceSuffix => {
let range_after_cursor = insert_range.end..completion.replace_range.end;
let text_after_cursor = buffer
.text_for_range(
buffer.anchor_before(range_after_cursor.start)
..buffer.anchor_after(range_after_cursor.end),
)
.collect::<String>();
completion.new_text.ends_with(&text_after_cursor)
}
}
}
let buffer = buffer.read(cx);
let buffer_snapshot = buffer.snapshot();
let (snippet, new_text) = if completion.is_snippet() {
let mut snippet_source = completion.new_text.clone();
if let Some(scope) = buffer_snapshot.language_scope_at(cursor_position) {
if scope.prefers_label_for_snippet_in_completion() {
if let Some(label) = completion.label() {
if matches!(
completion.kind(),
Some(CompletionItemKind::FUNCTION) | Some(CompletionItemKind::METHOD)
) {
snippet_source = label;
}
}
}
}
match Snippet::parse(&snippet_source).log_err() {
Some(parsed_snippet) => (Some(parsed_snippet.clone()), parsed_snippet.text),
None => (None, completion.new_text.clone()),
}
} else {
(None, completion.new_text.clone())
};
let mut range_to_replace = {
let replace_range = &completion.replace_range;
if let CompletionSource::Lsp {
insert_range: Some(insert_range),
..
} = &completion.source
{
debug_assert_eq!(
insert_range.start, replace_range.start,
"insert_range and replace_range should start at the same position"
);
debug_assert!(
insert_range
.start
.cmp(&cursor_position, &buffer_snapshot)
.is_le(),
"insert_range should start before or at cursor position"
);
debug_assert!(
replace_range
.start
.cmp(&cursor_position, &buffer_snapshot)
.is_le(),
"replace_range should start before or at cursor position"
);
debug_assert!(
insert_range
.end
.cmp(&cursor_position, &buffer_snapshot)
.is_le(),
"insert_range should end before or at cursor position"
);
let should_replace = match intent {
CompletionIntent::CompleteWithInsert => false,
CompletionIntent::CompleteWithReplace => true,
CompletionIntent::Complete | CompletionIntent::Compose => {
let insert_mode =
language_settings(buffer.language().map(|l| l.name()), buffer.file(), cx)
.completions
.lsp_insert_mode;
match insert_mode {
LspInsertMode::Insert => false,
LspInsertMode::Replace => true,
LspInsertMode::ReplaceSubsequence => {
let mut text_to_replace = buffer.chars_for_range(
buffer.anchor_before(replace_range.start)
..buffer.anchor_after(replace_range.end),
);
let mut current_needle = text_to_replace.next();
for haystack_ch in completion.label.text.chars() {
if let Some(needle_ch) = current_needle {
if haystack_ch.eq_ignore_ascii_case(&needle_ch) {
current_needle = text_to_replace.next();
}
}
}
current_needle.is_none()
}
LspInsertMode::ReplaceSuffix => {
if replace_range
.end
.cmp(&cursor_position, &buffer_snapshot)
.is_gt()
{
let range_after_cursor = *cursor_position..replace_range.end;
let text_after_cursor = buffer
.text_for_range(
buffer.anchor_before(range_after_cursor.start)
..buffer.anchor_after(range_after_cursor.end),
)
.collect::<String>()
.to_ascii_lowercase();
completion
.label
.text
.to_ascii_lowercase()
.ends_with(&text_after_cursor)
} else {
true
}
}
}
}
};
if should_replace {
replace_range.clone()
} else {
insert_range.clone()
}
} else {
replace_range.clone()
}
};
if range_to_replace
.end
.cmp(&cursor_position, &buffer_snapshot)
.is_lt()
if let CompletionSource::Lsp {
insert_range: Some(insert_range),
..
} = &completion.source
{
range_to_replace.end = *cursor_position;
let completion_mode_setting =
language_settings(buffer.language().map(|l| l.name()), buffer.file(), cx)
.completions
.lsp_insert_mode;
if !should_replace(
completion,
&insert_range,
intent,
completion_mode_setting,
buffer,
) {
return insert_range.to_offset(buffer);
}
}
CompletionEdit {
new_text,
replace_range: range_to_replace.to_offset(&buffer),
snippet,
}
}
struct CompletionEdit {
new_text: String,
replace_range: Range<usize>,
snippet: Option<Snippet>,
completion.replace_range.to_offset(buffer)
}
fn insert_extra_newline_brackets(
@@ -22040,7 +21903,7 @@ fn render_diff_hunk_controls(
.rounded_b_lg()
.bg(cx.theme().colors().editor_background)
.gap_1()
.block_mouse_except_scroll()
.stop_mouse_events_except_scroll()
.shadow_md()
.child(if status.has_secondary_hunk() {
Button::new(("stage", row as u64), "Stage")

View File

@@ -10479,7 +10479,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: &'static str,
initial_state: String,
buffer_marked_text: String,
completion_label: &'static str,
completion_text: &'static str,
expected_with_insert_mode: String,
expected_with_replace_mode: String,
@@ -10492,7 +10491,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Start of word matches completion text",
initial_state: "before ediˇ after".into(),
buffer_marked_text: "before <edi|> after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇ after".into(),
expected_with_replace_mode: "before editorˇ after".into(),
@@ -10503,7 +10501,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Accept same text at the middle of the word",
initial_state: "before ediˇtor after".into(),
buffer_marked_text: "before <edi|tor> after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇtor after".into(),
expected_with_replace_mode: "before editorˇ after".into(),
@@ -10514,7 +10511,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "End of word matches completion text -- cursor at end",
initial_state: "before torˇ after".into(),
buffer_marked_text: "before <tor|> after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇ after".into(),
expected_with_replace_mode: "before editorˇ after".into(),
@@ -10525,7 +10521,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "End of word matches completion text -- cursor at start",
initial_state: "before ˇtor after".into(),
buffer_marked_text: "before <|tor> after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇtor after".into(),
expected_with_replace_mode: "before editorˇ after".into(),
@@ -10536,7 +10531,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Prepend text containing whitespace",
initial_state: "pˇfield: bool".into(),
buffer_marked_text: "<p|field>: bool".into(),
completion_label: "pub ",
completion_text: "pub ",
expected_with_insert_mode: "pub ˇfield: bool".into(),
expected_with_replace_mode: "pub ˇ: bool".into(),
@@ -10547,7 +10541,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Add element to start of list",
initial_state: "[element_ˇelement_2]".into(),
buffer_marked_text: "[<element_|element_2>]".into(),
completion_label: "element_1",
completion_text: "element_1",
expected_with_insert_mode: "[element_1ˇelement_2]".into(),
expected_with_replace_mode: "[element_1ˇ]".into(),
@@ -10558,7 +10551,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Add element to start of list -- first and second elements are equal",
initial_state: "[elˇelement]".into(),
buffer_marked_text: "[<el|element>]".into(),
completion_label: "element",
completion_text: "element",
expected_with_insert_mode: "[elementˇelement]".into(),
expected_with_replace_mode: "[elementˇ]".into(),
@@ -10569,7 +10561,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Ends with matching suffix",
initial_state: "SubˇError".into(),
buffer_marked_text: "<Sub|Error>".into(),
completion_label: "SubscriptionError",
completion_text: "SubscriptionError",
expected_with_insert_mode: "SubscriptionErrorˇError".into(),
expected_with_replace_mode: "SubscriptionErrorˇ".into(),
@@ -10580,7 +10571,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Suffix is a subsequence -- contiguous",
initial_state: "SubˇErr".into(),
buffer_marked_text: "<Sub|Err>".into(),
completion_label: "SubscriptionError",
completion_text: "SubscriptionError",
expected_with_insert_mode: "SubscriptionErrorˇErr".into(),
expected_with_replace_mode: "SubscriptionErrorˇ".into(),
@@ -10591,7 +10581,6 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Suffix is a subsequence -- non-contiguous -- replace intended",
initial_state: "Suˇscrirr".into(),
buffer_marked_text: "<Su|scrirr>".into(),
completion_label: "SubscriptionError",
completion_text: "SubscriptionError",
expected_with_insert_mode: "SubscriptionErrorˇscrirr".into(),
expected_with_replace_mode: "SubscriptionErrorˇ".into(),
@@ -10602,46 +10591,12 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
run_description: "Suffix is a subsequence -- non-contiguous -- replace unintended",
initial_state: "foo(indˇix)".into(),
buffer_marked_text: "foo(<ind|ix>)".into(),
completion_label: "node_index",
completion_text: "node_index",
expected_with_insert_mode: "foo(node_indexˇix)".into(),
expected_with_replace_mode: "foo(node_indexˇ)".into(),
expected_with_replace_subsequence_mode: "foo(node_indexˇix)".into(),
expected_with_replace_suffix_mode: "foo(node_indexˇix)".into(),
},
Run {
run_description: "Replace range ends before cursor - should extend to cursor",
initial_state: "before editˇo after".into(),
buffer_marked_text: "before <{ed}>it|o after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇo after".into(),
expected_with_replace_mode: "before editorˇo after".into(),
expected_with_replace_subsequence_mode: "before editorˇo after".into(),
expected_with_replace_suffix_mode: "before editorˇo after".into(),
},
Run {
run_description: "Uses label for suffix matching",
initial_state: "before ediˇtor after".into(),
buffer_marked_text: "before <edi|tor> after".into(),
completion_label: "editor",
completion_text: "editor()",
expected_with_insert_mode: "before editor()ˇtor after".into(),
expected_with_replace_mode: "before editor()ˇ after".into(),
expected_with_replace_subsequence_mode: "before editor()ˇ after".into(),
expected_with_replace_suffix_mode: "before editor()ˇ after".into(),
},
Run {
run_description: "Case insensitive subsequence and suffix matching",
initial_state: "before EDiˇtoR after".into(),
buffer_marked_text: "before <EDi|toR> after".into(),
completion_label: "editor",
completion_text: "editor",
expected_with_insert_mode: "before editorˇtoR after".into(),
expected_with_replace_mode: "before editorˇ after".into(),
expected_with_replace_subsequence_mode: "before editorˇ after".into(),
expected_with_replace_suffix_mode: "before editorˇ after".into(),
},
];
for run in runs {
@@ -10682,7 +10637,7 @@ async fn test_completion_mode(cx: &mut TestAppContext) {
handle_completion_request_with_insert_and_replace(
&mut cx,
&run.buffer_marked_text,
vec![(run.completion_label, run.completion_text)],
vec![run.completion_text],
counter.clone(),
)
.await;
@@ -10742,7 +10697,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext)
handle_completion_request_with_insert_and_replace(
&mut cx,
&buffer_marked_text,
vec![(completion_text, completion_text)],
vec![completion_text],
counter.clone(),
)
.await;
@@ -10776,7 +10731,7 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext)
handle_completion_request_with_insert_and_replace(
&mut cx,
&buffer_marked_text,
vec![(completion_text, completion_text)],
vec![completion_text],
counter.clone(),
)
.await;
@@ -10863,7 +10818,7 @@ async fn test_completion_replacing_surrounding_text_with_multicursors(cx: &mut T
handle_completion_request_with_insert_and_replace(
&mut cx,
completion_marked_buffer,
vec![(completion_text, completion_text)],
vec![completion_text],
Arc::new(AtomicUsize::new(0)),
)
.await;
@@ -10917,7 +10872,7 @@ async fn test_completion_replacing_surrounding_text_with_multicursors(cx: &mut T
handle_completion_request_with_insert_and_replace(
&mut cx,
completion_marked_buffer,
vec![(completion_text, completion_text)],
vec![completion_text],
Arc::new(AtomicUsize::new(0)),
)
.await;
@@ -10966,7 +10921,7 @@ async fn test_completion_replacing_surrounding_text_with_multicursors(cx: &mut T
handle_completion_request_with_insert_and_replace(
&mut cx,
completion_marked_buffer,
vec![(completion_text, completion_text)],
vec![completion_text],
Arc::new(AtomicUsize::new(0)),
)
.await;
@@ -16813,9 +16768,9 @@ fn indent_guide(buffer_id: BufferId, start_row: u32, end_row: u32, depth: u32) -
async fn test_indent_guide_single_line(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
}"
fn main() {
let a = 1;
}"
.unindent(),
cx,
)
@@ -16828,10 +16783,10 @@ async fn test_indent_guide_single_line(cx: &mut TestAppContext) {
async fn test_indent_guide_simple_block(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
let b = 2;
}"
fn main() {
let a = 1;
let b = 2;
}"
.unindent(),
cx,
)
@@ -16844,14 +16799,14 @@ async fn test_indent_guide_simple_block(cx: &mut TestAppContext) {
async fn test_indent_guide_nested(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
if a == 3 {
let b = 2;
} else {
let c = 3;
}
}"
fn main() {
let a = 1;
if a == 3 {
let b = 2;
} else {
let c = 3;
}
}"
.unindent(),
cx,
)
@@ -16873,11 +16828,11 @@ async fn test_indent_guide_nested(cx: &mut TestAppContext) {
async fn test_indent_guide_tab(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
let a = 1;
let b = 2;
let c = 3;
}"
fn main() {
let a = 1;
let b = 2;
let c = 3;
}"
.unindent(),
cx,
)
@@ -17007,72 +16962,6 @@ async fn test_indent_guide_ends_off_screen(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_indent_guide_with_folds(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
&"
fn main() {
if a {
b(
c,
d,
)
} else {
e(
f
)
}
}"
.unindent(),
cx,
)
.await;
assert_indent_guides(
0..11,
vec![
indent_guide(buffer_id, 1, 10, 0),
indent_guide(buffer_id, 2, 5, 1),
indent_guide(buffer_id, 7, 9, 1),
indent_guide(buffer_id, 3, 4, 2),
indent_guide(buffer_id, 8, 8, 2),
],
None,
&mut cx,
);
cx.update_editor(|editor, window, cx| {
editor.fold_at(MultiBufferRow(2), window, cx);
assert_eq!(
editor.display_text(cx),
"
fn main() {
if a {
b(⋯
)
} else {
e(
f
)
}
}"
.unindent()
);
});
assert_indent_guides(
0..11,
vec![
indent_guide(buffer_id, 1, 10, 0),
indent_guide(buffer_id, 2, 5, 1),
indent_guide(buffer_id, 7, 9, 1),
indent_guide(buffer_id, 8, 8, 2),
],
None,
&mut cx,
);
}
#[gpui::test]
async fn test_indent_guide_without_brackets(cx: &mut TestAppContext) {
let (buffer_id, mut cx) = setup_indent_guides_editor(
@@ -20061,6 +19950,7 @@ println!("5");
pane_1
.update_in(cx, |pane, window, cx| {
pane.close_inactive_items(&CloseInactiveItems::default(), window, cx)
.unwrap()
})
.await
.unwrap();
@@ -20097,6 +19987,7 @@ println!("5");
pane_2
.update_in(cx, |pane, window, cx| {
pane.close_inactive_items(&CloseInactiveItems::default(), window, cx)
.unwrap()
})
.await
.unwrap();
@@ -20272,6 +20163,7 @@ println!("5");
});
pane.update_in(cx, |pane, window, cx| {
pane.close_all_items(&CloseAllItems::default(), window, cx)
.unwrap()
})
.await
.unwrap();
@@ -20625,6 +20517,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) {
pane.update_in(cx, |pane, window, cx| {
pane.close_active_item(&CloseActiveItem::default(), window, cx)
})
.unwrap()
.await
.unwrap();
pane.update_in(cx, |pane, window, cx| {
@@ -21109,27 +21002,19 @@ pub fn handle_completion_request(
/// Similar to `handle_completion_request`, but a [`CompletionTextEdit::InsertAndReplace`] will be
/// given instead, which also contains an `insert` range.
///
/// This function uses markers to define ranges:
/// - `|` marks the cursor position
/// - `<>` marks the replace range
/// - `[]` marks the insert range (optional, defaults to `replace_range.start..cursor_pos`which is what Rust-Analyzer provides)
/// This function uses the cursor position to mimic what Rust-Analyzer provides as the `insert` range,
/// that is, `replace_range.start..cursor_pos`.
pub fn handle_completion_request_with_insert_and_replace(
cx: &mut EditorLspTestContext,
marked_string: &str,
completions: Vec<(&'static str, &'static str)>, // (label, new_text)
completions: Vec<&'static str>,
counter: Arc<AtomicUsize>,
) -> impl Future<Output = ()> {
let complete_from_marker: TextRangeMarker = '|'.into();
let replace_range_marker: TextRangeMarker = ('<', '>').into();
let insert_range_marker: TextRangeMarker = ('{', '}').into();
let (_, mut marked_ranges) = marked_text_ranges_by(
marked_string,
vec![
complete_from_marker.clone(),
replace_range_marker.clone(),
insert_range_marker.clone(),
],
vec![complete_from_marker.clone(), replace_range_marker.clone()],
);
let complete_from_position =
@@ -21137,14 +21022,6 @@ pub fn handle_completion_request_with_insert_and_replace(
let replace_range =
cx.to_lsp_range(marked_ranges.remove(&replace_range_marker).unwrap()[0].clone());
let insert_range = match marked_ranges.remove(&insert_range_marker) {
Some(ranges) if !ranges.is_empty() => cx.to_lsp_range(ranges[0].clone()),
_ => lsp::Range {
start: replace_range.start,
end: complete_from_position,
},
};
let mut request =
cx.set_request_handler::<lsp::request::Completion, _, _>(move |url, params, _| {
let completions = completions.clone();
@@ -21158,13 +21035,16 @@ pub fn handle_completion_request_with_insert_and_replace(
Ok(Some(lsp::CompletionResponse::Array(
completions
.iter()
.map(|(label, new_text)| lsp::CompletionItem {
label: label.to_string(),
.map(|completion_text| lsp::CompletionItem {
label: completion_text.to_string(),
text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace(
lsp::InsertReplaceEdit {
insert: insert_range,
insert: lsp::Range {
start: replace_range.start,
end: complete_from_position,
},
replace: replace_range,
new_text: new_text.to_string(),
new_text: completion_text.to_string(),
},
)),
..Default::default()

View File

@@ -42,13 +42,13 @@ use git::{
use gpui::{
Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle,
Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges,
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox,
HitboxBehavior, Hsla, InteractiveElement, IntoElement, IsZero, Keystroke, Length,
ModifiersChangedEvent, MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad,
ParentElement, Pixels, ScrollDelta, ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString,
Size, StatefulInteractiveElement, Style, Styled, TextRun, TextStyleRefinement, WeakEntity,
Window, anchored, deferred, div, fill, linear_color_stop, linear_gradient, outline, point, px,
quad, relative, size, solid_background, transparent_black,
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
InteractiveElement, IntoElement, IsZero, Keystroke, Length, ModifiersChangedEvent, MouseButton,
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
transparent_black,
};
use itertools::Itertools;
use language::language_settings::{
@@ -1512,17 +1512,6 @@ impl EditorElement {
ShowScrollbar::Never => return None,
};
// The horizontal scrollbar is usually slightly offset to align nicely with
// indent guides. However, this offset is not needed if indent guides are
// disabled for the current editor.
let content_offset = self
.editor
.read(cx)
.show_indent_guides
.is_none_or(|should_show| should_show)
.then_some(content_offset)
.unwrap_or_default();
Some(EditorScrollbars::from_scrollbar_axes(
ScrollbarAxes {
horizontal: scrollbar_settings.axes.horizontal
@@ -1620,7 +1609,7 @@ impl EditorElement {
);
let layout = ScrollbarLayout::for_minimap(
window.insert_hitbox(minimap_bounds, HitboxBehavior::Normal),
window.insert_hitbox(minimap_bounds, false),
visible_editor_lines,
total_editor_lines,
minimap_line_height,
@@ -1791,7 +1780,7 @@ impl EditorElement {
if matches!(hunk, DisplayDiffHunk::Unfolded { .. }) {
let hunk_bounds =
Self::diff_hunk_bounds(snapshot, line_height, gutter_hitbox.bounds, hunk);
*hitbox = Some(window.insert_hitbox(hunk_bounds, HitboxBehavior::BlockMouse));
*hitbox = Some(window.insert_hitbox(hunk_bounds, true));
}
}
}
@@ -2883,7 +2872,7 @@ impl EditorElement {
let hitbox = line_origin.map(|line_origin| {
window.insert_hitbox(
Bounds::new(line_origin, size(shaped_line.width, line_height)),
HitboxBehavior::Normal,
false,
)
});
#[cfg(test)]
@@ -6371,7 +6360,7 @@ impl EditorElement {
}
};
if phase == DispatchPhase::Bubble && hitbox.should_handle_scroll(window) {
if phase == DispatchPhase::Bubble && hitbox.is_hovered(window) {
delta = delta.coalesce(event.delta);
editor.update(cx, |editor, cx| {
let position_map: &PositionMap = &position_map;
@@ -7651,17 +7640,15 @@ impl Element for EditorElement {
.map(|(guide, active)| (self.column_pixels(*guide, window, cx), *active))
.collect::<SmallVec<[_; 2]>>();
let hitbox = window.insert_hitbox(bounds, HitboxBehavior::Normal);
let gutter_hitbox = window.insert_hitbox(
gutter_bounds(bounds, gutter_dimensions),
HitboxBehavior::Normal,
);
let hitbox = window.insert_hitbox(bounds, false);
let gutter_hitbox =
window.insert_hitbox(gutter_bounds(bounds, gutter_dimensions), false);
let text_hitbox = window.insert_hitbox(
Bounds {
origin: gutter_hitbox.top_right(),
size: size(text_width, bounds.size.height),
},
HitboxBehavior::Normal,
false,
);
let content_origin = text_hitbox.origin + content_offset;
@@ -8882,7 +8869,7 @@ impl EditorScrollbars {
})
.map(|(viewport_size, scroll_range)| {
ScrollbarLayout::new(
window.insert_hitbox(scrollbar_bounds_for(axis), HitboxBehavior::Normal),
window.insert_hitbox(scrollbar_bounds_for(axis), false),
viewport_size,
scroll_range,
glyph_grid_cell.along(axis),

View File

@@ -1,9 +1,9 @@
use std::{cmp::Ordering, ops::Range, time::Duration};
use std::{ops::Range, time::Duration};
use collections::HashSet;
use gpui::{App, AppContext as _, Context, Task, Window};
use language::language_settings::language_settings;
use multi_buffer::{IndentGuide, MultiBufferRow, ToPoint};
use multi_buffer::{IndentGuide, MultiBufferRow};
use text::{LineIndent, Point};
use util::ResultExt;
@@ -154,28 +154,12 @@ pub fn indent_guides_in_range(
snapshot: &DisplaySnapshot,
cx: &App,
) -> Vec<IndentGuide> {
let start_offset = snapshot
let start_anchor = snapshot
.buffer_snapshot
.point_to_offset(Point::new(visible_buffer_range.start.0, 0));
let end_offset = snapshot
.anchor_before(Point::new(visible_buffer_range.start.0, 0));
let end_anchor = snapshot
.buffer_snapshot
.point_to_offset(Point::new(visible_buffer_range.end.0, 0));
let start_anchor = snapshot.buffer_snapshot.anchor_before(start_offset);
let end_anchor = snapshot.buffer_snapshot.anchor_after(end_offset);
let mut fold_ranges = Vec::<Range<Point>>::new();
let mut folds = snapshot.folds_in_range(start_offset..end_offset).peekable();
while let Some(fold) = folds.next() {
let start = fold.range.start.to_point(&snapshot.buffer_snapshot);
let end = fold.range.end.to_point(&snapshot.buffer_snapshot);
if let Some(last_range) = fold_ranges.last_mut() {
if last_range.end >= start {
last_range.end = last_range.end.max(end);
continue;
}
}
fold_ranges.push(start..end);
}
.anchor_after(Point::new(visible_buffer_range.end.0, 0));
snapshot
.buffer_snapshot
@@ -185,19 +169,15 @@ pub fn indent_guides_in_range(
return false;
}
let has_containing_fold = fold_ranges
.binary_search_by(|fold_range| {
if fold_range.start >= Point::new(indent_guide.start_row.0, 0) {
Ordering::Greater
} else if fold_range.end < Point::new(indent_guide.end_row.0, 0) {
Ordering::Less
} else {
Ordering::Equal
}
})
.is_ok();
!has_containing_fold
let start = MultiBufferRow(indent_guide.start_row.0.saturating_sub(1));
// Filter out indent guides that are inside a fold
// All indent guides that are starting "offscreen" have a start value of the first visible row minus one
// Therefore checking if a line is folded at first visible row minus one causes the other indent guides that are not related to the fold to disappear as well
let is_folded = snapshot.is_line_folded(start);
let line_indent = snapshot.line_indent_for_buffer_row(start);
let contained_in_fold =
line_indent.len(indent_guide.tab_size) <= indent_guide.indent_level();
!(is_folded && contained_in_fold)
})
.collect()
}

View File

@@ -600,7 +600,7 @@ pub(crate) fn handle_from(
})
.collect::<Vec<_>>();
this.update_in(cx, |this, window, cx| {
this.change_selections_without_showing_completions(None, window, cx, |s| {
this.change_selections_inner(None, false, window, cx, |s| {
s.select(base_selections);
});
})

View File

@@ -22,7 +22,6 @@ use smol::stream::StreamExt;
use task::ResolvedTask;
use task::TaskContext;
use text::BufferId;
use ui::SharedString;
use util::ResultExt as _;
pub(crate) fn find_specific_language_server_in_selection<F>(
@@ -134,22 +133,13 @@ pub fn lsp_tasks(
cx.spawn(async move |cx| {
cx.spawn(async move |cx| {
let mut lsp_tasks = HashMap::default();
let mut lsp_tasks = Vec::new();
while let Some(server_to_query) = lsp_task_sources.next().await {
if let Some((server_id, buffers)) = server_to_query {
let source_kind = TaskSourceKind::Lsp(server_id);
let id_base = source_kind.to_id_base();
let mut new_lsp_tasks = Vec::new();
for buffer in buffers {
let source_kind = match buffer.update(cx, |buffer, _| {
buffer.language().map(|language| language.name())
}) {
Ok(Some(language_name)) => TaskSourceKind::Lsp {
server: server_id,
language_name: SharedString::from(language_name),
},
Ok(None) => continue,
Err(_) => return Vec::new(),
};
let id_base = source_kind.to_id_base();
let lsp_buffer_context = lsp_task_context(&project, &buffer, cx)
.await
.unwrap_or_default();
@@ -178,14 +168,11 @@ pub fn lsp_tasks(
);
}
}
lsp_tasks
.entry(source_kind)
.or_insert_with(Vec::new)
.append(&mut new_lsp_tasks);
}
lsp_tasks.push((source_kind, new_lsp_tasks));
}
}
lsp_tasks.into_iter().collect()
lsp_tasks
})
.race({
// `lsp::LSP_REQUEST_TIMEOUT` is larger than we want for the modal to open fast

View File

@@ -74,6 +74,8 @@ impl Editor {
pub(super) fn should_open_signature_help_automatically(
&mut self,
old_cursor_position: &Anchor,
backspace_pressed: bool,
cx: &mut Context<Self>,
) -> bool {
if !(self.signature_help_state.is_shown() || self.auto_signature_help_enabled(cx)) {
@@ -82,7 +84,9 @@ impl Editor {
let newest_selection = self.selections.newest::<usize>(cx);
let head = newest_selection.head();
if !newest_selection.is_empty() && head != newest_selection.tail() {
// There are two cases where the head and tail of a selection are different: selecting multiple ranges and using backspace.
// If we dont exclude the backspace case, signature_help will blink every time backspace is pressed, so we need to prevent this.
if !newest_selection.is_empty() && !backspace_pressed && head != newest_selection.tail() {
self.signature_help_state
.hide(SignatureHelpHiddenBy::Selection);
return false;
@@ -228,6 +232,7 @@ pub struct SignatureHelpState {
task: Option<Task<()>>,
popover: Option<SignatureHelpPopover>,
hidden_by: Option<SignatureHelpHiddenBy>,
backspace_pressed: bool,
}
impl SignatureHelpState {
@@ -249,6 +254,14 @@ impl SignatureHelpState {
self.popover.as_mut()
}
pub fn backspace_pressed(&self) -> bool {
self.backspace_pressed
}
pub fn set_backspace_pressed(&mut self, backspace_pressed: bool) {
self.backspace_pressed = backspace_pressed;
}
pub fn set_popover(&mut self, popover: SignatureHelpPopover) {
self.popover = Some(popover);
self.hidden_by = None;

View File

@@ -67,4 +67,3 @@ unindent.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -19,7 +19,6 @@ use collections::HashMap;
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
use gpui::{App, AppContext, AsyncApp, Entity};
use language_model::{LanguageModel, Role, StopReason};
use zed_llm_client::CompletionIntent;
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
@@ -50,7 +49,6 @@ pub struct ExampleMetadata {
pub max_assertions: Option<usize>,
pub profile_id: AgentProfileId,
pub existing_thread_json: Option<String>,
pub max_turns: Option<u32>,
}
#[derive(Clone, Debug)]
@@ -309,7 +307,7 @@ impl ExampleContext {
let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
thread.set_remaining_turns(iterations);
thread.send_to_model(model, CompletionIntent::UserPrompt, None, cx);
thread.send_to_model(model, None, cx);
thread.messages().len()
})?;

View File

@@ -22,7 +22,6 @@ impl Example for AddArgToTraitMethod {
max_assertions: None,
profile_id: AgentProfileId::default(),
existing_thread_json: None,
max_turns: None,
}
}

View File

@@ -23,7 +23,6 @@ impl Example for CodeBlockCitations {
max_assertions: None,
profile_id: AgentProfileId::default(),
existing_thread_json: None,
max_turns: None,
}
}

View File

@@ -17,7 +17,6 @@ impl Example for CommentTranslation {
max_assertions: Some(1),
profile_id: AgentProfileId::default(),
existing_thread_json: None,
max_turns: None,
}
}

View File

@@ -19,7 +19,6 @@ impl Example for FileSearchExample {
max_assertions: Some(3),
profile_id: AgentProfileId::default(),
existing_thread_json: None,
max_turns: None,
}
}

View File

@@ -82,7 +82,6 @@ impl DeclarativeExample {
max_assertions: None,
profile_id,
existing_thread_json,
max_turns: base.max_turns,
};
Ok(DeclarativeExample {
@@ -125,8 +124,6 @@ pub struct ExampleToml {
pub thread_assertions: BTreeMap<String, String>,
#[serde(default)]
pub existing_thread_path: Option<String>,
#[serde(default)]
pub max_turns: Option<u32>,
}
#[async_trait(?Send)]
@@ -137,8 +134,7 @@ impl Example for DeclarativeExample {
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
cx.push_user_message(&self.prompt);
let max_turns = self.metadata.max_turns.unwrap_or(1000);
let _ = cx.run_turns(max_turns).await;
let _ = cx.run_to_end().await;
Ok(())
}

View File

@@ -31,7 +31,6 @@ impl Example for FileOverwriteExample {
max_assertions: Some(1),
profile_id: AgentProfileId::default(),
existing_thread_json: Some(thread_json.to_string()),
max_turns: None,
}
}

Some files were not shown because too many files have changed in this diff Show More