diff --git a/.github/ISSUE_TEMPLATE/99_other.yml b/.github/ISSUE_TEMPLATE/99_other.yml new file mode 100644 index 0000000000..9383a576b1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/99_other.yml @@ -0,0 +1,19 @@ +name: Other [Staff Only] +description: Zed Staff Only +body: + - type: textarea + attributes: + label: Summary + value: | + + SUMMARY_SENTENCE_HERE + + ### Description + + IF YOU DO NOT WORK FOR ZED INDUSTRIES DO NOT CREATE ISSUES WITH THIS TEMPLATE. + THEY WILL BE AUTO-CLOSED AND MAY RESULT IN YOU BEING BANNED FROM THE ZED ISSUE TRACKER. + + FEATURE REQUESTS / SUPPORT REQUESTS SHOULD BE OPENED AS DISCUSSIONS: + https://github.com/zed-industries/zed/discussions/new/choose + validations: + required: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6375d74f15..4ddb5173e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -594,7 +594,7 @@ jobs: timeout-minutes: 60 name: Linux x86_x64 release bundle runs-on: - - buildjet-16vcpu-ubuntu-2004 + - buildjet-16vcpu-ubuntu-2004 # ubuntu 20.04 for minimal glibc if: | startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') @@ -622,26 +622,23 @@ jobs: - name: Create Linux .tar.gz bundle run: script/bundle-linux - - name: Upload Linux bundle to workflow run if main branch or specific label + - name: Upload Artifact to Workflow - zed (run-bundling) uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 - if: | - github.ref == 'refs/heads/main' - || contains(github.event.pull_request.labels.*.name, 'run-bundling') + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') with: name: zed-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.tar.gz path: target/release/zed-*.tar.gz - - name: Upload Linux remote server to workflow run if main branch or specific label + - name: Upload Artifact to Workflow - zed-remote-server (run-bundling) uses: 
actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 - if: | - github.ref == 'refs/heads/main' - || contains(github.event.pull_request.labels.*.name, 'run-bundling') + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') with: name: zed-remote-server-${{ github.event.pull_request.head.sha || github.sha }}-x86_64-unknown-linux-gnu.gz path: target/zed-remote-server-linux-x86_64.gz - - name: Upload app bundle to release + - name: Upload Artifacts to release uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1 + if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }} with: draft: true prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }} @@ -680,29 +677,26 @@ jobs: # This exports RELEASE_CHANNEL into env (GITHUB_ENV) script/determine-release-channel - - name: Create and upload Linux .tar.gz bundle + - name: Create and upload Linux .tar.gz bundles run: script/bundle-linux - - name: Upload Linux bundle to workflow run if main branch or specific label + - name: Upload Artifact to Workflow - zed (run-bundling) uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 - if: | - github.ref == 'refs/heads/main' - || contains(github.event.pull_request.labels.*.name, 'run-bundling') + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') with: name: zed-${{ github.event.pull_request.head.sha || github.sha }}-aarch64-unknown-linux-gnu.tar.gz path: target/release/zed-*.tar.gz - - name: Upload Linux remote server to workflow run if main branch or specific label + - name: Upload Artifact to Workflow - zed-remote-server (run-bundling) uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4 - if: | - github.ref == 'refs/heads/main' - || contains(github.event.pull_request.labels.*.name, 'run-bundling') + if: contains(github.event.pull_request.labels.*.name, 'run-bundling') with: name: zed-remote-server-${{ github.event.pull_request.head.sha || 
github.sha }}-aarch64-unknown-linux-gnu.gz path: target/zed-remote-server-linux-aarch64.gz - - name: Upload app bundle to release + - name: Upload Artifacts to release uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1 + if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) }} with: draft: true prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }} diff --git a/.github/workflows/deploy_collab.yml b/.github/workflows/deploy_collab.yml index d921a08bf1..eb5875afcc 100644 --- a/.github/workflows/deploy_collab.yml +++ b/.github/workflows/deploy_collab.yml @@ -117,12 +117,10 @@ jobs: export ZED_KUBE_NAMESPACE=production export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=10 export ZED_API_LOAD_BALANCER_SIZE_UNIT=2 - export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=2 elif [[ $GITHUB_REF_NAME = "collab-staging" ]]; then export ZED_KUBE_NAMESPACE=staging export ZED_COLLAB_LOAD_BALANCER_SIZE_UNIT=1 export ZED_API_LOAD_BALANCER_SIZE_UNIT=1 - export ZED_LLM_LOAD_BALANCER_SIZE_UNIT=1 else echo "cowardly refusing to deploy from an unknown branch" exit 1 @@ -147,9 +145,3 @@ jobs: envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" - - export ZED_SERVICE_NAME=llm - export ZED_LOAD_BALANCER_SIZE_UNIT=$ZED_LLM_LOAD_BALANCER_SIZE_UNIT - envsubst < crates/collab/k8s/collab.template.yml | kubectl apply -f - - kubectl -n "$ZED_KUBE_NAMESPACE" rollout status deployment/$ZED_SERVICE_NAME --watch - echo "deployed ${ZED_SERVICE_NAME} to ${ZED_KUBE_NAMESPACE}" diff --git a/Cargo.lock b/Cargo.lock index 6ba576975a..87c50bd31f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -326,7 +326,6 @@ dependencies = [ "serde_json", "strum", "thiserror 2.0.12", - "util", "workspace-hack", ] @@ -1183,6 +1182,18 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "auto_update_helper" +version = "0.1.0" 
+dependencies = [ + "anyhow", + "log", + "simplelog", + "windows 0.61.1", + "winresource", + "workspace-hack", +] + [[package]] name = "auto_update_ui" version = "0.1.0" @@ -2932,7 +2943,6 @@ dependencies = [ name = "collab" version = "0.44.0" dependencies = [ - "anthropic", "anyhow", "assistant", "assistant_context_editor", @@ -3176,14 +3186,18 @@ dependencies = [ name = "component_preview" version = "0.1.0" dependencies = [ + "anyhow", "client", "collections", "component", + "db", "gpui", "languages", "notifications", "project", + "serde", "ui", + "ui_input", "workspace", "workspace-hack", ] @@ -3988,7 +4002,6 @@ dependencies = [ "node_runtime", "parking_lot", "paths", - "regex", "schemars", "serde", "serde_json", @@ -4020,7 +4033,6 @@ dependencies = [ "gpui", "language", "paths", - "regex", "serde", "serde_json", "task", @@ -4164,6 +4176,7 @@ dependencies = [ "collections", "command_palette_hooks", "dap", + "db", "editor", "env_logger 0.11.8", "feature_flags", @@ -4863,25 +4876,37 @@ dependencies = [ "assistant_settings", "assistant_tool", "assistant_tools", + "async-watch", + "chrono", + "clap", "client", + "collections", "context_server", "dap", "env_logger 0.11.8", + "extension", "fs", "futures 0.3.31", "gpui", "gpui_tokio", + "handlebars 4.5.0", "language", + "language_extension", "language_model", "language_models", + "languages", "node_runtime", + "paths", "project", "prompt_store", "release_channel", "reqwest_client", "serde", "settings", + "shellexpand 2.1.2", "toml 0.8.20", + "unindent", + "util", "workspace-hack", ] @@ -4976,10 +5001,10 @@ dependencies = [ "async-tar", "async-trait", "collections", - "convert_case 0.8.0", "fs", "futures 0.3.31", "gpui", + "heck 0.5.0", "http_client", "language", "log", @@ -7654,6 +7679,7 @@ dependencies = [ name = "language_model_selector" version = "0.1.0" dependencies = [ + "collections", "feature_flags", "gpui", "language_model", @@ -7704,6 +7730,7 @@ dependencies = [ "smol", "strum", "theme", + "thiserror 2.0.12", 
"tiktoken-rs", "tokio", "ui", @@ -17628,6 +17655,7 @@ dependencies = [ "ui", "util", "uuid", + "windows 0.61.1", "workspace-hack", "zed_actions", ] @@ -17790,6 +17818,8 @@ dependencies = [ "wasmtime-cranelift", "wasmtime-environ", "winapi", + "windows-core 0.61.0", + "windows-numerics", "windows-sys 0.48.0", "windows-sys 0.52.0", "windows-sys 0.59.0", @@ -18134,7 +18164,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.182.0" +version = "0.183.0" dependencies = [ "activity_indicator", "agent", @@ -18230,7 +18260,6 @@ dependencies = [ "settings", "settings_ui", "shellexpand 2.1.2", - "simplelog", "smol", "snippet_provider", "snippets_ui", @@ -18581,7 +18610,9 @@ name = "zlog" version = "0.1.0" dependencies = [ "anyhow", + "chrono", "log", + "tempfile", "workspace-hack", ] diff --git a/Cargo.toml b/Cargo.toml index 7ba482e8c7..844ff7e7c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "crates/assistant_tools", "crates/audio", "crates/auto_update", + "crates/auto_update_helper", "crates/auto_update_ui", "crates/aws_http_client", "crates/bedrock", @@ -224,6 +225,7 @@ assistant_tool = { path = "crates/assistant_tool" } assistant_tools = { path = "crates/assistant_tools" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } +auto_update_helper = { path = "crates/auto_update_helper" } auto_update_ui = { path = "crates/auto_update_ui" } aws_http_client = { path = "crates/aws_http_client" } bedrock = { path = "crates/bedrock" } @@ -403,8 +405,12 @@ async-tungstenite = "0.29.1" async-watch = "0.3.1" async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] } aws-config = { version = "1.6.1", features = ["behavior-version-latest"] } -aws-credential-types = { version = "1.2.2", features = ["hardcoded-credentials"] } -aws-sdk-bedrockruntime = { version = "1.80.0", features = ["behavior-version-latest"] } +aws-credential-types = { version = "1.2.2", features = [ + "hardcoded-credentials", +] } 
+aws-sdk-bedrockruntime = { version = "1.80.0", features = [ + "behavior-version-latest", +] } aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] } aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] } base64 = "0.22" @@ -443,6 +449,7 @@ futures-lite = "1.13" git2 = { version = "0.20.1", default-features = false } globset = "0.4" handlebars = "4.3" +heck = "0.5" heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" html5ever = "0.27.0" @@ -619,12 +626,10 @@ features = [ [workspace.dependencies.windows] version = "0.61" features = [ - "Foundation_Collections", "Foundation_Numerics", "Storage_Search", "Storage_Streams", "System_Threading", - "UI_StartScreen", "UI_ViewManagement", "Wdk_System_SystemServices", "Win32_Globalization", @@ -651,6 +656,7 @@ features = [ "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", + "Win32_System_Variant", "Win32_System_WinRT", "Win32_UI_Controls", "Win32_UI_HiDpi", @@ -658,6 +664,7 @@ features = [ "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_Shell_Common", + "Win32_UI_Shell_PropertiesSystem", "Win32_UI_WindowsAndMessaging", ] @@ -781,4 +788,12 @@ let_underscore_future = "allow" too_many_arguments = "allow" [workspace.metadata.cargo-machete] -ignored = ["bindgen", "cbindgen", "prost_build", "serde", "component", "linkme", "workspace-hack"] +ignored = [ + "bindgen", + "cbindgen", + "prost_build", + "serde", + "component", + "linkme", + "workspace-hack", +] diff --git a/assets/icons/ai_anthropic_hosted.svg b/assets/icons/ai_anthropic_hosted.svg deleted file mode 100644 index b088520490..0000000000 --- a/assets/icons/ai_anthropic_hosted.svg +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - - - diff --git a/assets/icons/layout.svg b/assets/icons/layout.svg new file mode 100644 index 0000000000..79464013b1 --- /dev/null +++ b/assets/icons/layout.svg @@ -0,0 +1,5 @@ + + + + + diff --git 
a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 2fd742c5ae..9b94ed32a1 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -150,7 +150,9 @@ "context": "AgentDiff", "bindings": { "ctrl-y": "agent::Keep", - "ctrl-n": "agent::Reject" + "ctrl-n": "agent::Reject", + "ctrl-shift-y": "agent::KeepAll", + "ctrl-shift-n": "agent::RejectAll" } }, { @@ -352,11 +354,11 @@ "alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection "ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection "ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word - "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], - "ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match - "ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], - "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], - "ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], + "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand + "ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch + "ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch + "ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip + "ctrl-k ctrl-shift-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "ctrl-k ctrl-i": "editor::Hover", "ctrl-/": ["editor::ToggleComments", { "advance_downwards": false }], "ctrl-u": "editor::UndoSelection", @@ -780,6 +782,7 @@ "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", "ctrl-enter": "git::Commit", + 
"ctrl-shift-enter": "git::Amend", "alt-enter": "menu::SecondaryConfirm", "delete": ["git::RestoreFile", { "skip_prompt": false }], "backspace": ["git::RestoreFile", { "skip_prompt": false }], @@ -788,12 +791,20 @@ "ctrl-delete": ["git::RestoreFile", { "skip_prompt": false }] } }, + { + "context": "GitPanel && CommitEditor", + "use_key_equivalents": true, + "bindings": { + "escape": "git::Cancel" + } + }, { "context": "GitCommit > Editor", "bindings": { "escape": "menu::Cancel", "enter": "editor::Newline", "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", "alt-l": "git::GenerateCommitMessage" } }, @@ -815,6 +826,7 @@ "context": "GitDiff > Editor", "bindings": { "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", "ctrl-space": "git::StageAll", "ctrl-shift-space": "git::UnstageAll" } @@ -833,6 +845,7 @@ "shift-tab": "git_panel::FocusChanges", "enter": "editor::Newline", "ctrl-enter": "git::Commit", + "ctrl-shift-enter": "git::Amend", "alt-up": "git_panel::FocusChanges", "alt-l": "git::GenerateCommitMessage" } diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 034ac3b8a2..c5cf9e019b 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -242,7 +242,9 @@ "use_key_equivalents": true, "bindings": { "cmd-y": "agent::Keep", - "cmd-n": "agent::Reject" + "cmd-n": "agent::Reject", + "cmd-shift-y": "agent::KeepAll", + "cmd-shift-n": "agent::RejectAll" } }, { @@ -489,12 +491,15 @@ "alt-shift-down": "editor::DuplicateLineDown", "ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand Selection "ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection - "cmd-d": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match + "cmd-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand "cmd-shift-l": "editor::SelectAllMatches", // Select all occurrences of current 
selection "cmd-f2": "editor::SelectAllMatches", // Select all occurrences of current word - "ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], - "cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }], - "cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], + "cmd-k cmd-d": ["editor::SelectNext", { "replace_newest": true }], // editor.action.moveSelectionToNextFindMatch / find_under_expand_skip + // macOS binds `ctrl-cmd-d` to Show Dictionary which breaks these two binds + // To use `ctrl-cmd-d` or `ctrl-k ctrl-cmd-d` in Zed you must execute this command and then restart: + // defaults write com.apple.symbolichotkeys AppleSymbolicHotKeys -dict-add 70 'enabled' + "ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": false }], // editor.action.addSelectionToPreviousFindMatch + "cmd-k ctrl-cmd-d": ["editor::SelectPrevious", { "replace_newest": true }], // editor.action.moveSelectionToPreviousFindMatch "cmd-k cmd-i": "editor::Hover", "cmd-/": ["editor::ToggleComments", { "advance_downwards": false }], "cmd-u": "editor::UndoSelection", @@ -850,17 +855,26 @@ "shift-tab": "git_panel::FocusEditor", "escape": "git_panel::ToggleFocus", "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend", "backspace": ["git::RestoreFile", { "skip_prompt": false }], "delete": ["git::RestoreFile", { "skip_prompt": false }], "cmd-backspace": ["git::RestoreFile", { "skip_prompt": true }], "cmd-delete": ["git::RestoreFile", { "skip_prompt": true }] } }, + { + "context": "GitPanel && CommitEditor", + "use_key_equivalents": true, + "bindings": { + "escape": "git::Cancel" + } + }, { "context": "GitDiff > Editor", "use_key_equivalents": true, "bindings": { "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend", "cmd-ctrl-y": "git::StageAll", "cmd-ctrl-shift-y": "git::UnstageAll" } @@ -871,6 +885,7 @@ "bindings": { "enter": "editor::Newline", "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend", "tab": 
"git_panel::FocusChanges", "shift-tab": "git_panel::FocusChanges", "alt-up": "git_panel::FocusChanges", @@ -900,6 +915,7 @@ "enter": "editor::Newline", "escape": "menu::Cancel", "cmd-enter": "git::Commit", + "cmd-shift-enter": "git::Amend", "alt-tab": "git::GenerateCommitMessage" } }, diff --git a/assets/keymaps/linux/sublime_text.json b/assets/keymaps/linux/sublime_text.json index 258b8a5629..c3f56350b9 100644 --- a/assets/keymaps/linux/sublime_text.json +++ b/assets/keymaps/linux/sublime_text.json @@ -37,6 +37,8 @@ "ctrl-shift-a": "editor::SelectLargerSyntaxNode", "ctrl-shift-d": "editor::DuplicateSelection", "alt-f3": "editor::SelectAllMatches", // find_all_under + // "ctrl-f3": "", // find_under (cancels any selections) + // "cmd-alt-shift-g": "" // find_under_prev (cancels any selections) "f9": "editor::SortLinesCaseSensitive", "ctrl-f9": "editor::SortLinesCaseInsensitive", "f12": "editor::GoToDefinition", diff --git a/assets/keymaps/macos/sublime_text.json b/assets/keymaps/macos/sublime_text.json index d3929af9e9..6251ae0ccd 100644 --- a/assets/keymaps/macos/sublime_text.json +++ b/assets/keymaps/macos/sublime_text.json @@ -38,6 +38,8 @@ "cmd-shift-a": "editor::SelectLargerSyntaxNode", "cmd-shift-d": "editor::DuplicateSelection", "ctrl-cmd-g": "editor::SelectAllMatches", // find_all_under + // "cmd-alt-g": "", // find_under (cancels any selections) + // "cmd-alt-shift-g": "" // find_under_prev (cancels any selections) "f5": "editor::SortLinesCaseSensitive", "ctrl-f5": "editor::SortLinesCaseInsensitive", "shift-f12": "editor::FindAllReferences", diff --git a/assets/keymaps/vim.json b/assets/keymaps/vim.json index 40c96ed5d8..452731ebc2 100644 --- a/assets/keymaps/vim.json +++ b/assets/keymaps/vim.json @@ -539,6 +539,7 @@ "bindings": { "d": "vim::CurrentLine", "s": "vim::PushDeleteSurrounds", + "v": "vim::PushForcedMotion", // "d v" "o": "editor::ToggleSelectedDiffHunks", // "d o" "shift-o": "git::ToggleStaged", "p": "git::Restore", // "d p" @@ -587,6 +588,7 @@ 
"context": "vim_operator == y", "bindings": { "y": "vim::CurrentLine", + "v": "vim::PushForcedMotion", "s": ["vim::PushAddSurrounds", {}] } }, diff --git a/assets/settings/default.json b/assets/settings/default.json index 9276cf8117..98ee37f213 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -80,6 +80,8 @@ // Values are clamped to the [0.0, 1.0] range. "inactive_opacity": 1.0 }, + // Layout mode of the bottom dock. Defaults to "contained" + "bottom_dock_layout": "contained", // The direction that you want to split panes horizontally. Defaults to "up" "pane_split_direction_horizontal": "up", // The direction that you want to split panes horizontally. Defaults to "left" @@ -642,6 +644,7 @@ // We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default. // "enable_all_context_servers": true, "tools": { + "contents": true, "diagnostics": true, "fetch": true, "list_directory": false, @@ -661,6 +664,7 @@ "batch_tool": true, "code_actions": true, "code_symbols": true, + "contents": true, "copy_path": false, "create_file": true, "delete_path": false, diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 830646e38d..0ef73946d8 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -13,18 +13,18 @@ use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting}; use assistant_tool::ToolUseStatus; use collections::{HashMap, HashSet}; use editor::scroll::Autoscroll; -use editor::{Editor, MultiBuffer}; +use editor::{Editor, EditorElement, EditorStyle, MultiBuffer}; use gpui::{ AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, ClipboardItem, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Hsla, ListAlignment, ListState, MouseButton, PlatformDisplay, ScrollHandle, Stateful, StyleRefinement, Subscription, Task, - TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, + 
TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, WindowHandle, linear_color_stop, linear_gradient, list, percentage, pulsating_between, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; -use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; +use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use project::ProjectItem as _; use rope::Point; use settings::{Settings as _, update_settings_file}; @@ -34,7 +34,9 @@ use std::sync::Arc; use std::time::Duration; use text::ToPoint; use theme::ThemeSettings; -use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip, prelude::*}; +use ui::{ + Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*, +}; use util::ResultExt as _; use workspace::{OpenOptions, Workspace}; @@ -66,8 +68,6 @@ pub struct ActiveThread { open_feedback_editors: HashMap>, } -const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 5; - struct RenderedMessage { language_registry: Arc, segments: Vec, @@ -176,11 +176,37 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { }); MarkdownStyle { - base_text_style: text_style, + base_text_style: text_style.clone(), syntax: cx.theme().syntax().clone(), selection_background_color: cx.theme().players().local().selection, code_block_overflow_x_scroll: true, table_overflow_x_scroll: true, + heading_level_styles: Some(HeadingLevelStyles { + h1: Some(TextStyleRefinement { + font_size: Some(rems(1.15).into()), + ..Default::default() + }), + h2: Some(TextStyleRefinement { + font_size: Some(rems(1.1).into()), + ..Default::default() + }), + h3: Some(TextStyleRefinement { + font_size: Some(rems(1.05).into()), + ..Default::default() + }), + h4: Some(TextStyleRefinement { + font_size: Some(rems(1.).into()), + 
..Default::default() + }), + h5: Some(TextStyleRefinement { + font_size: Some(rems(0.95).into()), + ..Default::default() + }), + h6: Some(TextStyleRefinement { + font_size: Some(rems(0.875).into()), + ..Default::default() + }), + }), code_block: StyleRefinement { padding: EdgesRefinement { top: Some(DefiniteLength::Absolute(AbsoluteLength::Pixels(Pixels(8.)))), @@ -292,6 +318,8 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle { } } +const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10; + fn render_markdown_code_block( message_id: MessageId, ix: usize, @@ -578,7 +606,7 @@ fn render_markdown_code_block( if is_expanded { this.h_full() } else { - this.max_h_40() + this.max_h_80() } }, ) @@ -1497,12 +1525,36 @@ impl ActiveThread { .when(!message_is_empty, |parent| { parent.child( if let Some(edit_message_editor) = edit_message_editor.clone() { + let settings = ThemeSettings::get_global(cx); + let font_size = TextSize::Small.rems(cx); + let line_height = font_size.to_pixels(window.rem_size()) * 1.5; + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + div() .key_context("EditMessageEditor") .on_action(cx.listener(Self::cancel_editing_message)) .on_action(cx.listener(Self::confirm_editing_message)) .min_h_6() - .child(edit_message_editor) + .pt_1() + .child(EditorElement::new( + &edit_message_editor, + EditorStyle { + background: colors.editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + )) .into_any() } else { div() @@ -1667,11 +1719,9 @@ impl ActiveThread { ), Role::Assistant => v_flex() .id(("message-container", ix)) - .ml_2() + .ml_2p5() .pl_2() .pr_4() - 
.border_l_1() - .border_color(cx.theme().colors().border_variant) .children(message_content) .when(has_tool_uses, |parent| { parent.children( diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index c3bc120ead..8fdcbbcb58 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1,4 +1,4 @@ -use crate::{Keep, Reject, Thread, ThreadEvent}; +use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent}; use anyhow::Result; use buffer_diff::DiffHunkStatus; use collections::HashSet; @@ -843,7 +843,7 @@ impl ToolbarItemView for AgentDiffToolbar { } impl Render for AgentDiffToolbar { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let agent_diff = match self.agent_diff(cx) { Some(ad) => ad, None => return div(), @@ -855,6 +855,8 @@ impl Render for AgentDiffToolbar { return div(); } + let focus_handle = agent_diff.focus_handle(cx); + h_group_xl() .my_neg_1() .items_center() @@ -864,15 +866,25 @@ impl Render for AgentDiffToolbar { .child( h_group_sm() .child( - Button::new("reject-all", "Reject All").on_click(cx.listener( - |this, _, window, cx| { - this.dispatch_action(&crate::RejectAll, window, cx) - }, - )), + Button::new("reject-all", "Reject All") + .key_binding({ + KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.dispatch_action(&RejectAll, window, cx) + })), ) - .child(Button::new("keep-all", "Keep All").on_click(cx.listener( - |this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx), - ))), + .child( + Button::new("keep-all", "Keep All") + .key_binding({ + KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx) + .map(|kb| kb.size(rems_from_px(12.))) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.dispatch_action(&KeepAll, window, cx) + })), + ), ) 
} } @@ -882,6 +894,7 @@ mod tests { use super::*; use crate::{ThreadStore, thread_store}; use assistant_settings::AssistantSettings; + use assistant_tool::ToolWorkingSet; use context_server::ContextServerSettings; use editor::EditorSettings; use gpui::TestAppContext; @@ -925,7 +938,7 @@ mod tests { .update(|cx| { ThreadStore::load( project.clone(), - Arc::default(), + cx.new(|_| ToolWorkingSet::default()), Arc::new(PromptBuilder::new(None).unwrap()), cx, ) diff --git a/crates/agent/src/assistant_configuration.rs b/crates/agent/src/assistant_configuration.rs index 8972d786e2..7616b1f8b0 100644 --- a/crates/agent/src/assistant_configuration.rs +++ b/crates/agent/src/assistant_configuration.rs @@ -12,7 +12,9 @@ use fs::Fs; use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription}; use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry}; use settings::{Settings, update_settings_file}; -use ui::{Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, prelude::*}; +use ui::{ + Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*, +}; use util::ResultExt as _; use zed_actions::ExtensionCategoryFilter; @@ -27,7 +29,7 @@ pub struct AssistantConfiguration { configuration_views_by_provider: HashMap, context_server_manager: Entity, expanded_context_server_tools: HashMap, bool>, - tools: Arc, + tools: Entity, _registry_subscription: Subscription, } @@ -35,7 +37,7 @@ impl AssistantConfiguration { pub fn new( fs: Arc, context_server_manager: Entity, - tools: Arc, + tools: Entity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -224,7 +226,7 @@ impl AssistantConfiguration { fn render_context_servers_section(&mut self, cx: &mut Context) -> impl IntoElement { let context_servers = self.context_server_manager.read(cx).all_servers().clone(); - let tools_by_source = self.tools.tools_by_source(cx); + let tools_by_source = 
self.tools.read(cx).tools_by_source(cx); let empty = Vec::new(); const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly."; @@ -236,7 +238,10 @@ impl AssistantConfiguration { .child( v_flex() .gap_0p5() - .child(Headline::new("Context Servers (MCP)").size(HeadlineSize::Small)) + .child( + Headline::new("Model Context Protocol (MCP) Servers") + .size(HeadlineSize::Small), + ) .child(Label::new(SUBHEADING).color(Color::Muted)), ) .children(context_servers.into_iter().map(|context_server| { @@ -262,10 +267,9 @@ impl AssistantConfiguration { .bg(cx.theme().colors().editor_background) .child( h_flex() + .p_1() .justify_between() - .px_2() - .py_1() - .when(are_tools_expanded, |element| { + .when(are_tools_expanded && tool_count > 1, |element| { element .border_b_1() .border_color(cx.theme().colors().border) @@ -275,6 +279,7 @@ impl AssistantConfiguration { .gap_2() .child( Disclosure::new("tool-list-disclosure", are_tools_expanded) + .disabled(tool_count == 0) .on_click(cx.listener({ let context_server_id = context_server.id(); move |this, _event, _window, _cx| { @@ -295,10 +300,11 @@ impl AssistantConfiguration { .child(Label::new(context_server.id())) .child( Label::new(format!("{tool_count} tools")) - .color(Color::Muted), + .color(Color::Muted) + .size(LabelSize::Small), ), ) - .child(h_flex().child( + .child( Switch::new("context-server-switch", is_running.into()).on_click({ let context_server_manager = self.context_server_manager.clone(); @@ -334,7 +340,7 @@ impl AssistantConfiguration { } } }), - )), + ), ) .map(|parent| { if !are_tools_expanded { @@ -344,14 +350,29 @@ impl AssistantConfiguration { parent.child(v_flex().children(tools.into_iter().enumerate().map( |(ix, tool)| { h_flex() - .px_2() + .id("tool-item") + .pl_2() + .pr_1() .py_1() + .gap_2() + .justify_between() .when(ix < tool_count - 1, |element| { element .border_b_1() - .border_color(cx.theme().colors().border) + 
.border_color(cx.theme().colors().border_variant) }) - .child(Label::new(tool.name())) + .child( + Label::new(tool.name()) + .buffer_font(cx) + .size(LabelSize::Small), + ) + .child( + IconButton::new(("tool-description", ix), IconName::Info) + .shape(ui::IconButtonShape::Square) + .icon_size(IconSize::Small) + .icon_color(Color::Ignored) + .tooltip(Tooltip::text(tool.description())), + ) }, ))) }) @@ -362,7 +383,7 @@ impl AssistantConfiguration { .gap_2() .child( h_flex().w_full().child( - Button::new("add-context-server", "Add Context Server") + Button::new("add-context-server", "Add MCPs Directly") .style(ButtonStyle::Filled) .layer(ElevationIndex::ModalSurface) .full_width() @@ -378,7 +399,7 @@ impl AssistantConfiguration { h_flex().w_full().child( Button::new( "install-context-server-extensions", - "Install Context Server Extensions", + "Install MCP Extensions", ) .style(ButtonStyle::Filled) .layer(ElevationIndex::ModalSurface) diff --git a/crates/agent/src/assistant_configuration/manage_profiles_modal.rs b/crates/agent/src/assistant_configuration/manage_profiles_modal.rs index a4c72cbac9..6f5172a8d4 100644 --- a/crates/agent/src/assistant_configuration/manage_profiles_modal.rs +++ b/crates/agent/src/assistant_configuration/manage_profiles_modal.rs @@ -84,7 +84,7 @@ pub struct NewProfileMode { pub struct ManageProfilesModal { fs: Arc, - tools: Arc, + tools: Entity, thread_store: WeakEntity, focus_handle: FocusHandle, mode: Mode, @@ -117,7 +117,7 @@ impl ManageProfilesModal { pub fn new( fs: Arc, - tools: Arc, + tools: Entity, thread_store: WeakEntity, window: &mut Window, cx: &mut Context, diff --git a/crates/agent/src/assistant_configuration/tool_picker.rs b/crates/agent/src/assistant_configuration/tool_picker.rs index eabd9e172b..2b105e87a2 100644 --- a/crates/agent/src/assistant_configuration/tool_picker.rs +++ b/crates/agent/src/assistant_configuration/tool_picker.rs @@ -60,7 +60,7 @@ pub struct ToolPickerDelegate { impl ToolPickerDelegate { pub fn new( 
fs: Arc, - tool_set: Arc, + tool_set: Entity, thread_store: WeakEntity, profile_id: AgentProfileId, profile: AgentProfile, @@ -68,7 +68,7 @@ impl ToolPickerDelegate { ) -> Self { let mut tool_entries = Vec::new(); - for (source, tools) in tool_set.tools_by_source(cx) { + for (source, tools) in tool_set.read(cx).tools_by_source(cx) { tool_entries.extend(tools.into_iter().map(|tool| ToolEntry { name: tool.name().into(), source: source.clone(), @@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate { if active_profile_id == &self.profile_id { self.thread_store .update(cx, |this, cx| { - this.load_profile(&self.profile, cx); + this.load_profile(self.profile.clone(), cx); }) .log_err(); } diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs index 11726b2574..091071af29 100644 --- a/crates/agent/src/assistant_model_selector.rs +++ b/crates/agent/src/assistant_model_selector.rs @@ -80,17 +80,16 @@ impl AssistantModelSelector { impl Render for AssistantModelSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let model_registry = LanguageModelRegistry::read_global(cx); + let focus_handle = self.focus_handle.clone(); + let model_registry = LanguageModelRegistry::read_global(cx); let model = match self.model_type { ModelType::Default => model_registry.default_model(), ModelType::InlineAssistant => model_registry.inline_assistant_model(), }; - - let focus_handle = self.focus_handle.clone(); - let model_name = match model { - Some(model) => model.model.name().0, - _ => SharedString::from("No model selected"), + let (model_name, model_icon) = match model { + Some(model) => (model.model.name().0, Some(model.provider.icon())), + _ => (SharedString::from("No model selected"), None), }; LanguageModelSelectorPopoverMenu::new( @@ -100,10 +99,16 @@ impl Render for AssistantModelSelector { .child( h_flex() .gap_0p5() + .children( + model_icon.map(|icon| { + 
Icon::new(icon).color(Color::Muted).size(IconSize::Small) + }), + ) .child( Label::new(model_name) .size(LabelSize::Small) - .color(Color::Muted), + .color(Color::Muted) + .ml_1(), ) .child( Icon::new(IconName::ChevronDown) diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs index 75f4db9ff3..fa953d93a8 100644 --- a/crates/agent/src/assistant_panel.rs +++ b/crates/agent/src/assistant_panel.rs @@ -44,8 +44,8 @@ use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; use crate::thread_history::{PastContext, PastThread, ThreadHistory}; use crate::thread_store::ThreadStore; use crate::{ - AgentDiff, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, - OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker, + AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread, + OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker, }; pub fn init(cx: &mut App) { @@ -90,6 +90,16 @@ pub fn init(cx: &mut App) { let thread = panel.read(cx).thread.read(cx).thread().clone(); AgentDiff::deploy_in_workspace(thread, workspace, window, cx); } + }) + .register_action(|workspace, _: &ExpandMessageEditor, window, cx| { + if let Some(panel) = workspace.panel::(cx) { + workspace.focus_panel::(window, cx); + panel.update(cx, |panel, cx| { + panel.message_editor.update(cx, |editor, cx| { + editor.expand_message_editor(&ExpandMessageEditor, window, cx); + }); + }); + } }); }, ) @@ -193,7 +203,7 @@ impl AssistantPanel { cx: AsyncWindowContext, ) -> Task>> { cx.spawn(async move |cx| { - let tools = Arc::new(ToolWorkingSet::default()); + let tools = cx.new(|_| ToolWorkingSet::default())?; let thread_store = workspace .update(cx, |workspace, cx| { let project = workspace.project().clone(); @@ -559,6 +569,7 @@ impl AssistantPanel { ActiveView::Configuration | ActiveView::History => { self.active_view = ActiveView::thread(self.thread.read(cx).thread().clone(), window, cx); + 
self.message_editor.focus_handle(cx).focus(window); cx.notify(); } _ => {} @@ -1088,20 +1099,30 @@ impl AssistantPanel { window, cx, |menu, _window, _cx| { - menu.action( + menu + .when(!is_empty, |menu| { + menu.action( + "Start New From Summary", + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + ).separator() + }) + .action( "New Text Thread", NewTextThread.boxed_clone(), ) - .when(!is_empty, |menu| { - menu.action( - "Continue in New Thread", - Box::new(NewThread { - from_thread_id: Some(thread_id.clone()), - }), - ) - }) - .separator() .action("Settings", OpenConfiguration.boxed_clone()) + .separator() + .action( + "Install MCPs", + zed_actions::Extensions { + category_filter: Some( + zed_actions::ExtensionCategoryFilter::ContextServers, + ), + } + .boxed_clone(), + ) }, )) }), diff --git a/crates/agent/src/context_picker.rs b/crates/agent/src/context_picker.rs index bcbee38b73..9e578c4fc0 100644 --- a/crates/agent/src/context_picker.rs +++ b/crates/agent/src/context_picker.rs @@ -34,12 +34,6 @@ use crate::context_store::ContextStore; use crate::thread::ThreadId; use crate::thread_store::ThreadStore; -#[derive(Debug, Clone, Copy)] -pub enum ConfirmBehavior { - KeepOpen, - Close, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ContextPickerMode { File, @@ -105,7 +99,6 @@ pub(super) struct ContextPicker { workspace: WeakEntity, context_store: WeakEntity, thread_store: Option>, - confirm_behavior: ConfirmBehavior, _subscriptions: Vec, } @@ -114,7 +107,6 @@ impl ContextPicker { workspace: WeakEntity, thread_store: Option>, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, window: &mut Window, cx: &mut Context, ) -> Self { @@ -143,7 +135,6 @@ impl ContextPicker { workspace, context_store, thread_store, - confirm_behavior, _subscriptions: subscriptions, } } @@ -166,37 +157,32 @@ impl ContextPicker { let modes = supported_context_picker_modes(&self.thread_store); - let menu = menu - .when(has_recent, |menu| { - 
menu.custom_row(|_, _| { - div() - .mb_1() - .child( - Label::new("Recent") - .color(Color::Muted) - .size(LabelSize::Small), - ) - .into_any_element() - }) + menu.when(has_recent, |menu| { + menu.custom_row(|_, _| { + div() + .mb_1() + .child( + Label::new("Recent") + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any_element() }) - .extend(recent_entries) - .when(has_recent, |menu| menu.separator()) - .extend(modes.into_iter().map(|mode| { - let context_picker = context_picker.clone(); + }) + .extend(recent_entries) + .when(has_recent, |menu| menu.separator()) + .extend(modes.into_iter().map(|mode| { + let context_picker = context_picker.clone(); - ContextMenuEntry::new(mode.label()) - .icon(mode.icon()) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .handler(move |window, cx| { - context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx)) - }) - })); - - match self.confirm_behavior { - ConfirmBehavior::KeepOpen => menu.keep_open_on_confirm(), - ConfirmBehavior::Close => menu, - } + ContextMenuEntry::new(mode.label()) + .icon(mode.icon()) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .handler(move |window, cx| { + context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx)) + }) + })) + .keep_open_on_confirm() }); cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| { @@ -227,7 +213,6 @@ impl ContextPicker { context_picker.clone(), self.workspace.clone(), self.context_store.clone(), - self.confirm_behavior, window, cx, ) @@ -239,7 +224,6 @@ impl ContextPicker { context_picker.clone(), self.workspace.clone(), self.context_store.clone(), - self.confirm_behavior, window, cx, ) @@ -251,7 +235,6 @@ impl ContextPicker { context_picker.clone(), self.workspace.clone(), self.context_store.clone(), - self.confirm_behavior, window, cx, ) @@ -264,7 +247,6 @@ impl ContextPicker { thread_store.clone(), context_picker.clone(), self.context_store.clone(), - self.confirm_behavior, window, cx, ) diff --git 
a/crates/agent/src/context_picker/fetch_context_picker.rs b/crates/agent/src/context_picker/fetch_context_picker.rs index c4a9dd1211..5c7795237b 100644 --- a/crates/agent/src/context_picker/fetch_context_picker.rs +++ b/crates/agent/src/context_picker/fetch_context_picker.rs @@ -11,7 +11,7 @@ use picker::{Picker, PickerDelegate}; use ui::{Context, ListItem, Window, prelude::*}; use workspace::Workspace; -use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; pub struct FetchContextPicker { @@ -23,16 +23,10 @@ impl FetchContextPicker { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = FetchContextPickerDelegate::new( - context_picker, - workspace, - context_store, - confirm_behavior, - ); + let delegate = FetchContextPickerDelegate::new(context_picker, workspace, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); Self { picker } @@ -62,7 +56,6 @@ pub struct FetchContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, url: String, } @@ -71,13 +64,11 @@ impl FetchContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, ) -> Self { FetchContextPickerDelegate { context_picker, workspace, context_store, - confirm_behavior, url: String::new(), } } @@ -204,25 +195,15 @@ impl PickerDelegate for FetchContextPickerDelegate { let http_client = workspace.read(cx).client().http_client().clone(); let url = self.url.clone(); - let confirm_behavior = self.confirm_behavior; cx.spawn_in(window, async move |this, cx| { let text = cx .background_spawn(fetch_url_content(http_client, url.clone())) .await?; - this.update_in(cx, |this, window, cx| { - this.delegate - 
.context_store - .update(cx, |context_store, cx| { - context_store.add_fetched_url(url, text, cx) - })?; - - match confirm_behavior { - ConfirmBehavior::KeepOpen => {} - ConfirmBehavior::Close => this.delegate.dismissed(window, cx), - } - - anyhow::Ok(()) + this.update(cx, |this, cx| { + this.delegate.context_store.update(cx, |context_store, cx| { + context_store.add_fetched_url(url, text, cx) + }) })??; anyhow::Ok(()) diff --git a/crates/agent/src/context_picker/file_context_picker.rs b/crates/agent/src/context_picker/file_context_picker.rs index 965f4a530e..5981b471c2 100644 --- a/crates/agent/src/context_picker/file_context_picker.rs +++ b/crates/agent/src/context_picker/file_context_picker.rs @@ -11,9 +11,9 @@ use picker::{Picker, PickerDelegate}; use project::{PathMatchCandidateSet, ProjectPath, WorktreeId}; use ui::{ListItem, Tooltip, prelude::*}; use util::ResultExt as _; -use workspace::{Workspace, notifications::NotifyResultExt}; +use workspace::Workspace; -use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_picker::ContextPicker; use crate::context_store::{ContextStore, FileInclusion}; pub struct FileContextPicker { @@ -25,16 +25,10 @@ impl FileContextPicker { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = FileContextPickerDelegate::new( - context_picker, - workspace, - context_store, - confirm_behavior, - ); + let delegate = FileContextPickerDelegate::new(context_picker, workspace, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); Self { picker } @@ -57,7 +51,6 @@ pub struct FileContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, matches: Vec, selected_index: usize, } @@ -67,13 +60,11 @@ impl FileContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, 
context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, ) -> Self { Self { context_picker, workspace, context_store, - confirm_behavior, matches: Vec::new(), selected_index: 0, } @@ -127,7 +118,7 @@ impl PickerDelegate for FileContextPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else { return; }; @@ -153,17 +144,7 @@ impl PickerDelegate for FileContextPickerDelegate { return; }; - let confirm_behavior = self.confirm_behavior; - cx.spawn_in(window, async move |this, cx| { - match task.await.notify_async_err(cx) { - None => anyhow::Ok(()), - Some(()) => this.update_in(cx, |this, window, cx| match confirm_behavior { - ConfirmBehavior::KeepOpen => {} - ConfirmBehavior::Close => this.delegate.dismissed(window, cx), - }), - } - }) - .detach_and_log_err(cx); + task.detach_and_log_err(cx); } fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { diff --git a/crates/agent/src/context_picker/symbol_context_picker.rs b/crates/agent/src/context_picker/symbol_context_picker.rs index 608accc098..b76d4a8093 100644 --- a/crates/agent/src/context_picker/symbol_context_picker.rs +++ b/crates/agent/src/context_picker/symbol_context_picker.rs @@ -15,7 +15,7 @@ use ui::{ListItem, prelude::*}; use util::ResultExt as _; use workspace::Workspace; -use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; pub struct SymbolContextPicker { @@ -27,16 +27,10 @@ impl SymbolContextPicker { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = SymbolContextPickerDelegate::new( - context_picker, - workspace, - context_store, - confirm_behavior, - ); + let 
delegate = SymbolContextPickerDelegate::new(context_picker, workspace, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); Self { picker } @@ -59,7 +53,6 @@ pub struct SymbolContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, matches: Vec, selected_index: usize, } @@ -69,13 +62,11 @@ impl SymbolContextPickerDelegate { context_picker: WeakEntity, workspace: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, ) -> Self { Self { context_picker, workspace, context_store, - confirm_behavior, matches: Vec::new(), selected_index: 0, } @@ -135,7 +126,7 @@ impl PickerDelegate for SymbolContextPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { let Some(mat) = self.matches.get(self.selected_index) else { return; }; @@ -143,7 +134,6 @@ impl PickerDelegate for SymbolContextPickerDelegate { return; }; - let confirm_behavior = self.confirm_behavior; let add_symbol_task = add_symbol( mat.symbol.clone(), true, @@ -153,16 +143,12 @@ impl PickerDelegate for SymbolContextPickerDelegate { ); let selected_index = self.selected_index; - cx.spawn_in(window, async move |this, cx| { + cx.spawn(async move |this, cx| { let included = add_symbol_task.await?; - this.update_in(cx, |this, window, cx| { + this.update(cx, |this, _| { if let Some(mat) = this.delegate.matches.get_mut(selected_index) { mat.is_included = included; } - match confirm_behavior { - ConfirmBehavior::KeepOpen => {} - ConfirmBehavior::Close => this.delegate.dismissed(window, cx), - } }) }) .detach_and_log_err(cx); diff --git a/crates/agent/src/context_picker/thread_context_picker.rs b/crates/agent/src/context_picker/thread_context_picker.rs index 98f62b3073..941926a898 100644 --- a/crates/agent/src/context_picker/thread_context_picker.rs +++ 
b/crates/agent/src/context_picker/thread_context_picker.rs @@ -6,7 +6,7 @@ use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; use picker::{Picker, PickerDelegate}; use ui::{ListItem, prelude::*}; -use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_picker::ContextPicker; use crate::context_store::{self, ContextStore}; use crate::thread::ThreadId; use crate::thread_store::ThreadStore; @@ -20,16 +20,11 @@ impl ThreadContextPicker { thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, window: &mut Window, cx: &mut Context, ) -> Self { - let delegate = ThreadContextPickerDelegate::new( - thread_store, - context_picker, - context_store, - confirm_behavior, - ); + let delegate = + ThreadContextPickerDelegate::new(thread_store, context_picker, context_store); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); ThreadContextPicker { picker } @@ -58,7 +53,6 @@ pub struct ThreadContextPickerDelegate { thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, matches: Vec, selected_index: usize, } @@ -68,13 +62,11 @@ impl ThreadContextPickerDelegate { thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, - confirm_behavior: ConfirmBehavior, ) -> Self { ThreadContextPickerDelegate { thread_store, context_picker, context_store, - confirm_behavior, matches: Vec::new(), selected_index: 0, } @@ -127,7 +119,7 @@ impl PickerDelegate for ThreadContextPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { let Some(entry) = self.matches.get(self.selected_index) else { return; }; @@ -138,20 +130,15 @@ impl PickerDelegate for ThreadContextPickerDelegate { let open_thread_task = thread_store.update(cx, |this, cx| 
this.open_thread(&entry.id, cx)); - cx.spawn_in(window, async move |this, cx| { + cx.spawn(async move |this, cx| { let thread = open_thread_task.await?; - this.update_in(cx, |this, window, cx| { + this.update(cx, |this, cx| { this.delegate .context_store .update(cx, |context_store, cx| { context_store.add_thread(thread, true, cx) }) .ok(); - - match this.delegate.confirm_behavior { - ConfirmBehavior::KeepOpen => {} - ConfirmBehavior::Close => this.delegate.dismissed(window, cx), - } }) }) .detach_and_log_err(cx); diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index afc61f46ce..6245f88998 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -15,7 +15,7 @@ use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; use workspace::{Workspace, notifications::NotifyResultExt}; use crate::context::{ContextId, ContextKind}; -use crate::context_picker::{ConfirmBehavior, ContextPicker}; +use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; use crate::thread::Thread; use crate::thread_store::ThreadStore; @@ -52,7 +52,6 @@ impl ContextStrip { workspace.clone(), thread_store.clone(), context_store.downgrade(), - ConfirmBehavior::KeepOpen, window, cx, ) diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 81adea8945..c61e99ad04 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -1,6 +1,8 @@ +use std::collections::BTreeMap; use std::sync::Arc; use crate::assistant_model_selector::ModelType; +use buffer_diff::BufferDiff; use collections::HashSet; use editor::actions::MoveUp; use editor::{ @@ -10,8 +12,8 @@ use editor::{ use file_icons::FileIcons; use fs::Fs; use gpui::{ - Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle, - WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, + Animation, AnimationExt, App, Entity, Focusable, 
Subscription, TextStyle, WeakEntity, + linear_color_stop, linear_gradient, point, pulsating_between, }; use language::{Buffer, Language}; use language_model::{ConfiguredModel, LanguageModelRegistry}; @@ -21,12 +23,12 @@ use project::Project; use settings::Settings; use std::time::Duration; use theme::ThemeSettings; -use ui::{Disclosure, KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; +use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*}; use util::ResultExt as _; use workspace::Workspace; use crate::assistant_model_selector::AssistantModelSelector; -use crate::context_picker::{ConfirmBehavior, ContextPicker, ContextPickerCompletionProvider}; +use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; use crate::context_store::{ContextStore, refresh_context_store_text}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::profile_selector::ProfileSelector; @@ -46,8 +48,6 @@ pub struct MessageEditor { context_store: Entity, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, - inline_context_picker: Entity, - inline_context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, profile_selector: Entity, edits_expanded: bool, @@ -56,7 +56,7 @@ pub struct MessageEditor { _subscriptions: Vec, } -const MAX_EDITOR_LINES: usize = 10; +const MAX_EDITOR_LINES: usize = 8; impl MessageEditor { pub fn new( @@ -69,7 +69,6 @@ impl MessageEditor { cx: &mut Context, ) -> Self { let context_picker_menu_handle = PopoverMenuHandle::default(); - let inline_context_picker_menu_handle = PopoverMenuHandle::default(); let model_selector_menu_handle = PopoverMenuHandle::default(); let language = Language::new( @@ -94,6 +93,7 @@ impl MessageEditor { ); editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx); editor.set_show_indent_guides(false, cx); + editor.set_soft_wrap(); editor.set_context_menu_options(ContextMenuOptions { min_entries_visible: 12, 
max_entries_visible: 12, @@ -112,17 +112,6 @@ impl MessageEditor { )))); }); - let inline_context_picker = cx.new(|cx| { - ContextPicker::new( - workspace.clone(), - Some(thread_store.clone()), - context_store.downgrade(), - ConfirmBehavior::Close, - window, - cx, - ) - }); - let context_strip = cx.new(|cx| { ContextStrip::new( context_store.clone(), @@ -135,14 +124,8 @@ impl MessageEditor { ) }); - let subscriptions = vec![ - cx.subscribe_in( - &inline_context_picker, - window, - Self::handle_inline_context_picker_event, - ), - cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), - ]; + let subscriptions = + vec![cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event)]; Self { editor: editor.clone(), @@ -152,8 +135,6 @@ impl MessageEditor { context_store, context_strip, context_picker_menu_handle, - inline_context_picker, - inline_context_picker_menu_handle, model_selector: cx.new(|cx| { AssistantModelSelector::new( fs.clone(), @@ -177,7 +158,7 @@ impl MessageEditor { cx.notify(); } - fn expand_message_editor( + pub fn expand_message_editor( &mut self, _: &ExpandMessageEditor, _window: &mut Window, @@ -316,17 +297,6 @@ impl MessageEditor { .detach(); } - fn handle_inline_context_picker_event( - &mut self, - _inline_context_picker: &Entity, - _event: &DismissEvent, - window: &mut Window, - cx: &mut Context, - ) { - let editor_focus_handle = self.editor.focus_handle(cx); - window.focus(&editor_focus_handle); - } - fn handle_context_strip_event( &mut self, _context_strip: &Entity, @@ -346,9 +316,7 @@ impl MessageEditor { } fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context) { - if self.context_picker_menu_handle.is_deployed() - || self.inline_context_picker_menu_handle.is_deployed() - { + if self.context_picker_menu_handle.is_deployed() { cx.propagate(); } else { self.context_strip.focus_handle(cx).focus(window); @@ -371,6 +339,503 @@ impl MessageEditor { diff.update(cx, |diff, cx| 
diff.move_to_path(path_key, window, cx)); } } + + fn render_editor( + &self, + font_size: Rems, + line_height: Pixels, + window: &mut Window, + cx: &mut Context, + ) -> Div { + let thread = self.thread.read(cx); + + let editor_bg_color = cx.theme().colors().editor_background; + let is_generating = thread.is_generating(); + let focus_handle = self.editor.focus_handle(cx); + + let is_model_selected = self.is_model_selected(cx); + let is_editor_empty = self.is_editor_empty(cx); + + let is_editor_expanded = self.editor_is_expanded; + let expand_icon = if is_editor_expanded { + IconName::Minimize + } else { + IconName::Maximize + }; + + v_flex() + .key_context("MessageEditor") + .on_action(cx.listener(Self::chat)) + .on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| { + this.profile_selector + .read(cx) + .menu_handle() + .toggle(window, cx); + })) + .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { + this.model_selector + .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); + })) + .on_action(cx.listener(Self::toggle_context_picker)) + .on_action(cx.listener(Self::remove_all_context)) + .on_action(cx.listener(Self::move_up)) + .on_action(cx.listener(Self::toggle_chat_mode)) + .on_action(cx.listener(Self::expand_message_editor)) + .gap_2() + .p_2() + .bg(editor_bg_color) + .border_t_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .items_start() + .justify_between() + .child(self.context_strip.clone()) + .child( + IconButton::new("toggle-height", expand_icon) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + let expand_label = if is_editor_expanded { + "Minimize Message Editor".to_string() + } else { + "Expand Message Editor".to_string() + }; + + Tooltip::for_action_in( + expand_label, + &ExpandMessageEditor, + &focus_handle, + window, + cx, + ) + } + }) + .on_click(cx.listener(|_, _, window, cx| { + 
window.dispatch_action(Box::new(ExpandMessageEditor), cx); + })), + ), + ) + .child( + v_flex() + .size_full() + .gap_4() + .when(is_editor_expanded, |this| { + this.h(vh(0.8, window)).justify_between() + }) + .child( + div() + .min_h_16() + .when(is_editor_expanded, |this| this.h_full()) + .child({ + let settings = ThemeSettings::get_global(cx); + + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_features: settings.buffer_font.features.clone(), + font_size: font_size.into(), + line_height: line_height.into(), + ..Default::default() + }; + + EditorElement::new( + &self.editor, + EditorStyle { + background: editor_bg_color, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ) + .into_any() + }), + ) + .child( + h_flex() + .flex_none() + .justify_between() + .child(h_flex().gap_2().child(self.profile_selector.clone())) + .child(h_flex().gap_1().child(self.model_selector.clone()).map({ + let focus_handle = focus_handle.clone(); + move |parent| { + if is_generating { + parent.child( + IconButton::new( + "stop-generation", + IconName::StopFilled, + ) + .icon_color(Color::Error) + .style(ButtonStyle::Tinted(ui::TintColor::Error)) + .tooltip(move |window, cx| { + Tooltip::for_action( + "Stop Generation", + &editor::actions::Cancel, + window, + cx, + ) + }) + .on_click({ + let focus_handle = focus_handle.clone(); + move |_event, window, cx| { + focus_handle.dispatch_action( + &editor::actions::Cancel, + window, + cx, + ); + } + }) + .with_animation( + "pulsating-label", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 1.0)), + |icon_button, delta| icon_button.alpha(delta), + ), + ) + } else { + parent.child( + IconButton::new("send-message", IconName::Send) + .icon_color(Color::Accent) + .style(ButtonStyle::Filled) + 
.disabled( + is_editor_empty + || !is_model_selected + || self.waiting_for_summaries_to_send, + ) + .on_click({ + let focus_handle = focus_handle.clone(); + move |_event, window, cx| { + focus_handle + .dispatch_action(&Chat, window, cx); + } + }) + .when( + !is_editor_empty && is_model_selected, + |button| { + button.tooltip(move |window, cx| { + Tooltip::for_action( + "Send", &Chat, window, cx, + ) + }) + }, + ) + .when(is_editor_empty, |button| { + button.tooltip(Tooltip::text( + "Type a message to submit", + )) + }) + .when(!is_model_selected, |button| { + button.tooltip(Tooltip::text( + "Select a model to continue", + )) + }), + ) + } + } + })), + ), + ) + } + + fn render_changed_buffers( + &self, + changed_buffers: &BTreeMap, Entity>, + window: &mut Window, + cx: &mut Context, + ) -> Div { + let focus_handle = self.editor.focus_handle(cx); + + let editor_bg_color = cx.theme().colors().editor_background; + let border_color = cx.theme().colors().border; + let active_color = cx.theme().colors().element_selected; + let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); + let is_edit_changes_expanded = self.edits_expanded; + + v_flex() + .mx_2() + .bg(bg_edit_files_disclosure) + .border_1() + .border_b_0() + .border_color(border_color) + .rounded_t_md() + .shadow(smallvec::smallvec![gpui::BoxShadow { + color: gpui::black().opacity(0.15), + offset: point(px(1.), px(-1.)), + blur_radius: px(3.), + spread_radius: px(0.), + }]) + .child( + h_flex() + .id("edits-container") + .cursor_pointer() + .p_1p5() + .justify_between() + .when(is_edit_changes_expanded, |this| { + this.border_b_1().border_color(border_color) + }) + .on_click( + cx.listener(|this, _, window, cx| this.handle_review_click(window, cx)), + ) + .child( + h_flex() + .gap_1() + .child( + Disclosure::new("edits-disclosure", is_edit_changes_expanded) + .on_click(cx.listener(|this, _ev, _window, cx| { + this.edits_expanded = !this.edits_expanded; + cx.notify(); + })), + ) + .child( 
+ Label::new("Edits") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted)) + .child( + Label::new(format!( + "{} {}", + changed_buffers.len(), + if changed_buffers.len() == 1 { + "file" + } else { + "files" + } + )) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .child( + Button::new("review", "Review Changes") + .label_size(LabelSize::Small) + .key_binding( + KeyBinding::for_action_in( + &OpenAgentDiff, + &focus_handle, + window, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_review_click(window, cx) + })), + ), + ) + .when(is_edit_changes_expanded, |parent| { + parent.child( + v_flex().children(changed_buffers.into_iter().enumerate().flat_map( + |(index, (buffer, _diff))| { + let file = buffer.read(cx).file()?; + let path = file.path(); + + let parent_label = path.parent().and_then(|parent| { + let parent_str = parent.to_string_lossy(); + + if parent_str.is_empty() { + None + } else { + Some( + Label::new(format!( + "/{}{}", + parent_str, + std::path::MAIN_SEPARATOR_STR + )) + .color(Color::Muted) + .size(LabelSize::XSmall) + .buffer_font(cx), + ) + } + }); + + let name_label = path.file_name().map(|name| { + Label::new(name.to_string_lossy().to_string()) + .size(LabelSize::XSmall) + .buffer_font(cx) + }); + + let file_icon = FileIcons::get_icon(&path, cx) + .map(Icon::from_path) + .map(|icon| icon.color(Color::Muted).size(IconSize::Small)) + .unwrap_or_else(|| { + Icon::new(IconName::File) + .color(Color::Muted) + .size(IconSize::Small) + }); + + let hover_color = cx + .theme() + .colors() + .element_background + .blend(cx.theme().colors().editor_foreground.opacity(0.025)); + + let overlay_gradient = linear_gradient( + 90., + linear_color_stop(editor_bg_color, 1.), + linear_color_stop(editor_bg_color.opacity(0.2), 0.), + ); + + let overlay_gradient_hover = linear_gradient( + 90., + 
linear_color_stop(hover_color, 1.), + linear_color_stop(hover_color.opacity(0.2), 0.), + ); + + let element = h_flex() + .group("edited-code") + .id(("file-container", index)) + .cursor_pointer() + .relative() + .py_1() + .pl_2() + .pr_1() + .gap_2() + .justify_between() + .bg(cx.theme().colors().editor_background) + .hover(|style| style.bg(hover_color)) + .when(index + 1 < changed_buffers.len(), |parent| { + parent.border_color(border_color).border_b_1() + }) + .child( + h_flex() + .id("file-name") + .pr_8() + .gap_1p5() + .max_w_full() + .overflow_x_scroll() + .child(file_icon) + .child( + h_flex() + .gap_0p5() + .children(name_label) + .children(parent_label), + ) // TODO: show lines changed + .child(Label::new("+").color(Color::Created)) + .child(Label::new("-").color(Color::Deleted)), + ) + .child( + div().visible_on_hover("edited-code").child( + Button::new("review", "Review") + .label_size(LabelSize::Small) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.handle_file_click( + buffer.clone(), + window, + cx, + ); + }) + }), + ), + ) + .child( + div() + .id("gradient-overlay") + .absolute() + .h_5_6() + .w_12() + .bottom_0() + .right(px(52.)) + .bg(overlay_gradient) + .group_hover("edited-code", |style| { + style.bg(overlay_gradient_hover) + }), + ) + .on_click({ + let buffer = buffer.clone(); + cx.listener(move |this, _, window, cx| { + this.handle_file_click(buffer.clone(), window, cx); + }) + }); + + Some(element) + }, + )), + ) + }) + } + + fn render_token_limit_callout( + &self, + line_height: Pixels, + token_usage_ratio: TokenUsageRatio, + cx: &mut Context, + ) -> Div { + let heading = if token_usage_ratio == TokenUsageRatio::Exceeded { + "Thread reached the token limit" + } else { + "Thread reaching the token limit soon" + }; + + h_flex() + .p_2() + .gap_2() + .flex_wrap() + .justify_between() + .bg( + if token_usage_ratio == TokenUsageRatio::Exceeded { + cx.theme().status().error_background.opacity(0.1) + } 
else { + cx.theme().status().warning_background.opacity(0.1) + }) + .border_t_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .gap_2() + .items_start() + .child( + h_flex() + .h(line_height) + .justify_center() + .child( + if token_usage_ratio == TokenUsageRatio::Exceeded { + Icon::new(IconName::X) + .color(Color::Error) + .size(IconSize::XSmall) + } else { + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::XSmall) + } + ), + ) + .child( + v_flex() + .mr_auto() + .child(Label::new(heading).size(LabelSize::Small)) + .child( + Label::new( + "Start a new thread from a summary to continue the conversation.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ), + ) + .child( + Button::new("new-thread", "Start New Thread") + .on_click(cx.listener(|this, _, window, cx| { + let from_thread_id = Some(this.thread.read(cx).id().clone()); + + window.dispatch_action(Box::new(NewThread { + from_thread_id + }), cx); + })) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .label_size(LabelSize::Small), + ) + } } impl Focusable for MessageEditor { @@ -381,35 +846,14 @@ impl Focusable for MessageEditor { impl Render for MessageEditor { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let font_size = TextSize::Small.rems(cx); - let line_height = font_size.to_pixels(window.rem_size()) * 1.5; - - let focus_handle = self.editor.focus_handle(cx); - let focus_handle_clone = focus_handle.clone(); - let inline_context_picker = self.inline_context_picker.clone(); - - let is_editor_expanded = self.editor_is_expanded; - let expand_icon = if is_editor_expanded { - IconName::Minimize - } else { - IconName::Maximize - }; - let thread = self.thread.read(cx); - let is_generating = thread.is_generating(); let total_token_usage = thread.total_token_usage(cx); - let is_model_selected = self.is_model_selected(cx); - 
let is_editor_empty = self.is_editor_empty(cx); - let is_edit_changes_expanded = self.edits_expanded; let action_log = self.thread.read(cx).action_log(); let changed_buffers = action_log.read(cx).changed_buffers(cx); - let changed_buffers_count = changed_buffers.len(); - let editor_bg_color = cx.theme().colors().editor_background; - let border_color = cx.theme().colors().border; - let active_color = cx.theme().colors().element_selected; - let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3)); + let font_size = TextSize::Small.rems(cx); + let line_height = font_size.to_pixels(window.rem_size()) * 1.5; v_flex() .size_full() @@ -420,7 +864,7 @@ impl Render for MessageEditor { .flex_none() .px_2() .py_2() - .bg(editor_bg_color) + .bg(cx.theme().colors().editor_background) .border_1() .border_color(cx.theme().colors().border_variant) .rounded_lg() @@ -448,477 +892,19 @@ impl Render for MessageEditor { ), ) }) - .when(changed_buffers_count > 0, |parent| { - parent.child( - v_flex() - .mx_2() - .bg(bg_edit_files_disclosure) - .border_1() - .border_b_0() - .border_color(border_color) - .rounded_t_md() - .shadow(smallvec::smallvec![gpui::BoxShadow { - color: gpui::black().opacity(0.15), - offset: point(px(1.), px(-1.)), - blur_radius: px(3.), - spread_radius: px(0.), - }]) - .child( - h_flex() - .id("edits-container") - .cursor_pointer() - .p_1p5() - .justify_between() - .when(is_edit_changes_expanded, |this| { - this.border_b_1().border_color(border_color) - }) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_review_click(window, cx) - })) - .child( - h_flex() - .gap_1() - .child( - Disclosure::new( - "edits-disclosure", - is_edit_changes_expanded, - ) - .on_click( - cx.listener(|this, _ev, _window, cx| { - this.edits_expanded = !this.edits_expanded; - cx.notify(); - }), - ), - ) - .child( - Label::new("Edits") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - Label::new("•") - .size(LabelSize::XSmall) - 
.color(Color::Muted), - ) - .child( - Label::new(format!( - "{} {}", - changed_buffers_count, - if changed_buffers_count == 1 { - "file" - } else { - "files" - } - )) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ) - .child( - Button::new("review", "Review Changes") - .label_size(LabelSize::Small) - .key_binding( - KeyBinding::for_action_in( - &OpenAgentDiff, - &focus_handle, - window, - cx, - ) - .map(|kb| kb.size(rems_from_px(12.))), - ) - .on_click(cx.listener(|this, _, window, cx| { - this.handle_review_click(window, cx) - })), - ), - ) - .when(is_edit_changes_expanded, |parent| { - parent.child( - v_flex().children( - changed_buffers.into_iter().enumerate().flat_map( - |(index, (buffer, _diff))| { - let file = buffer.read(cx).file()?; - let path = file.path(); - - let parent_label = path.parent().and_then(|parent| { - let parent_str = parent.to_string_lossy(); - - if parent_str.is_empty() { - None - } else { - Some( - Label::new(format!( - "/{}{}", - parent_str, - std::path::MAIN_SEPARATOR_STR - )) - .color(Color::Muted) - .size(LabelSize::XSmall) - .buffer_font(cx), - ) - } - }); - - let name_label = path.file_name().map(|name| { - Label::new(name.to_string_lossy().to_string()) - .size(LabelSize::XSmall) - .buffer_font(cx) - }); - - let file_icon = FileIcons::get_icon(&path, cx) - .map(Icon::from_path) - .map(|icon| { - icon.color(Color::Muted).size(IconSize::Small) - }) - .unwrap_or_else(|| { - Icon::new(IconName::File) - .color(Color::Muted) - .size(IconSize::Small) - }); - - let hover_color = cx.theme() - .colors() - .element_background - .blend(cx.theme().colors().editor_foreground.opacity(0.025)); - - let overlay_gradient = linear_gradient( - 90., - linear_color_stop( - editor_bg_color, - 1., - ), - linear_color_stop( - editor_bg_color - .opacity(0.2), - 0., - ), - ); - - let overlay_gradient_hover = linear_gradient( - 90., - linear_color_stop( - hover_color, - 1., - ), - linear_color_stop( - hover_color - .opacity(0.2), - 0., - ), - ); - - 
let element = h_flex() - .group("edited-code") - .id(("file-container", index)) - .cursor_pointer() - .relative() - .py_1() - .pl_2() - .pr_1() - .gap_2() - .justify_between() - .bg(cx.theme().colors().editor_background) - .hover(|style| style.bg(hover_color)) - .when(index + 1 < changed_buffers_count, |parent| { - parent.border_color(border_color).border_b_1() - }) - .child( - h_flex() - .id("file-name") - .pr_8() - .gap_1p5() - .max_w_full() - .overflow_x_scroll() - .child(file_icon) - .child( - h_flex() - .gap_0p5() - .children(name_label) - .children(parent_label) - ) // TODO: show lines changed - .child( - Label::new("+") - .color(Color::Created), - ) - .child( - Label::new("-") - .color(Color::Deleted), - ), - ) - .child( - div().visible_on_hover("edited-code").child( - Button::new("review", "Review") - .label_size(LabelSize::Small) - .on_click({ - let buffer = buffer.clone(); - cx.listener(move |this, _, window, cx| { - this.handle_file_click(buffer.clone(), window, cx); - }) - }) - ) - ) - .child( - div() - .id("gradient-overlay") - .absolute() - .h_5_6() - .w_12() - .bottom_0() - .right(px(52.)) - .bg(overlay_gradient) - .group_hover("edited-code", |style| style.bg(overlay_gradient_hover)) - , - ) - .on_click({ - let buffer = buffer.clone(); - cx.listener(move |this, _, window, cx| { - this.handle_file_click(buffer.clone(), window, cx); - }) - }); - - Some(element) - }, - ), - ), - ) - }), - ) + .when(changed_buffers.len() > 0, |parent| { + parent.child(self.render_changed_buffers(&changed_buffers, window, cx)) }) - .child( - v_flex() - .key_context("MessageEditor") - .on_action(cx.listener(Self::chat)) - .on_action(cx.listener(|this, _: &ToggleProfileSelector, window, cx| { - this.profile_selector - .read(cx) - .menu_handle() - .toggle(window, cx); - })) - .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { - this.model_selector - .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); - })) - 
.on_action(cx.listener(Self::toggle_context_picker)) - .on_action(cx.listener(Self::remove_all_context)) - .on_action(cx.listener(Self::move_up)) - .on_action(cx.listener(Self::toggle_chat_mode)) - .on_action(cx.listener(Self::expand_message_editor)) - .gap_2() - .p_2() - .bg(editor_bg_color) - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - h_flex() - .items_start() - .justify_between() - .child(self.context_strip.clone()) - .child( - IconButton::new("toggle-height", expand_icon) - .icon_size(IconSize::XSmall) - .icon_color(Color::Muted) - .tooltip(move |window, cx| { - let focus_handle = focus_handle.clone(); - let expand_label = if is_editor_expanded { - "Minimize Message Editor".to_string() - } else { - "Expand Message Editor".to_string() - }; - - Tooltip::for_action_in( - expand_label, - &ExpandMessageEditor, - &focus_handle, - window, - cx, - ) - }) - .on_click(cx.listener(|_, _, window, cx| { - window.dispatch_action(Box::new(ExpandMessageEditor), cx); - })) - ) - ) - .child( - v_flex() - .size_full() - .gap_4() - .when(is_editor_expanded, |this| this.h(vh(0.8, window)).justify_between()) - .child(div().when(is_editor_expanded, |this| this.h_full()).child({ - let settings = ThemeSettings::get_global(cx); - - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.buffer_font.family.clone(), - font_fallbacks: settings.buffer_font.fallbacks.clone(), - font_features: settings.buffer_font.features.clone(), - font_size: font_size.into(), - line_height: line_height.into(), - ..Default::default() - }; - - EditorElement::new( - &self.editor, - EditorStyle { - background: editor_bg_color, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), - ..Default::default() - }, - ).into_any() - })) - .child( - PopoverMenu::new("inline-context-picker") - .menu(move |window, cx| { - inline_context_picker.update(cx, |this, cx| { - this.init(window, cx); - }); - 
Some(inline_context_picker.clone()) - }) - .attach(gpui::Corner::TopLeft) - .anchor(gpui::Corner::BottomLeft) - .offset(gpui::Point { - x: px(0.0), - y: (-ThemeSettings::get_global(cx).ui_font_size(cx) * 2) - - px(4.0), - }) - .with_handle(self.inline_context_picker_menu_handle.clone()), - ) - .child( - h_flex() - .flex_none() - .justify_between() - .child(h_flex().gap_2().child(self.profile_selector.clone())) - .child( - h_flex().gap_1() - .child(self.model_selector.clone()) - .map(move |parent| { - if is_generating { - parent.child( - IconButton::new("stop-generation", IconName::StopFilled) - .icon_color(Color::Error) - .style(ButtonStyle::Tinted(ui::TintColor::Error)) - .tooltip(move |window, cx| { - Tooltip::for_action( - "Stop Generation", - &editor::actions::Cancel, - window, - cx, - ) - }) - .on_click({ - let focus_handle = focus_handle_clone.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action( - &editor::actions::Cancel, - window, - cx, - ); - } - }) - .with_animation( - "pulsating-label", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.4, 1.0)), - |icon_button, delta| icon_button.alpha(delta), - ), - ) - } else { - parent.child( - IconButton::new("send-message", IconName::Send) - .icon_color(Color::Accent) - .style(ButtonStyle::Filled) - .disabled( - is_editor_empty - || !is_model_selected - || self.waiting_for_summaries_to_send - ) - .on_click({ - let focus_handle = focus_handle_clone.clone(); - move |_event, window, cx| { - focus_handle.dispatch_action(&Chat, window, cx); - } - }) - .when(!is_editor_empty && is_model_selected, |button| { - button.tooltip(move |window, cx| { - Tooltip::for_action( - "Send", - &Chat, - window, - cx, - ) - }) - }) - .when(is_editor_empty, |button| { - button.tooltip(Tooltip::text( - "Type a message to submit", - )) - }) - .when(!is_model_selected, |button| { - button.tooltip(Tooltip::text( - "Select a model to continue", - )) - }) - ) - } - }) - ), - ), - ) + 
.child(self.render_editor(font_size, line_height, window, cx)) + .when( + total_token_usage.ratio != TokenUsageRatio::Normal, + |parent| { + parent.child(self.render_token_limit_callout( + line_height, + total_token_usage.ratio, + cx, + )) + }, ) - .when(total_token_usage.ratio != TokenUsageRatio::Normal, |parent| { - parent.child( - h_flex() - .p_2() - .gap_2() - .flex_wrap() - .justify_between() - .bg(cx.theme().status().warning_background.opacity(0.1)) - .border_t_1() - .border_color(cx.theme().colors().border) - .child( - h_flex() - .gap_2() - .items_start() - .child( - h_flex() - .h(line_height) - .justify_center() - .child( - Icon::new(IconName::Warning) - .color(Color::Warning) - .size(IconSize::XSmall), - ), - ) - .child( - v_flex() - .mr_auto() - .child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small)) - .child( - Label::new( - "Start a new thread from a summary to continue the conversation.", - ) - .size(LabelSize::Small) - .color(Color::Muted), - ), - ), - ) - .child( - Button::new("new-thread", "Start New Thread") - .on_click(cx.listener(|this, _, window, cx| { - let from_thread_id = Some(this.thread.read(cx).id().clone()); - - window.dispatch_action(Box::new(NewThread { - from_thread_id - }), cx); - })) - .icon(IconName::Plus) - .icon_position(IconPosition::Start) - .icon_size(IconSize::Small) - .style(ButtonStyle::Tinted(ui::TintColor::Accent)) - .label_size(LabelSize::Small), - ), - ) - }) } } diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index dfcafba5fc..c033bf9c58 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -86,7 +86,7 @@ impl ProfileSelector { thread_store .update(cx, |this, cx| { - this.load_profile_by_id(&profile_id, cx); + this.load_profile_by_id(profile_id.clone(), cx); }) .log_err(); } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 749cf25780..80cade75ae 100644 --- a/crates/agent/src/thread.rs 
+++ b/crates/agent/src/thread.rs @@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _}; use git::repository::DiffType; use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, - LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, - LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, - PaymentRequiredError, Role, StopReason, TokenUsage, + ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, + Role, StopReason, TokenUsage, }; use project::Project; use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState}; @@ -228,7 +229,7 @@ pub struct TotalTokenUsage { pub ratio: TokenUsageRatio, } -#[derive(Default, PartialEq, Eq)] +#[derive(Debug, Default, PartialEq, Eq)] pub enum TokenUsageRatio { #[default] Normal, @@ -253,22 +254,31 @@ pub struct Thread { pending_completions: Vec, project: Entity, prompt_builder: Arc, - tools: Arc, + tools: Entity, tool_use: ToolUseState, action_log: Entity, last_restore_checkpoint: Option, pending_checkpoint: Option, initial_project_snapshot: Shared>>>, cumulative_token_usage: TokenUsage, + exceeded_window_error: Option, feedback: Option, message_feedback: HashMap, last_auto_capture_at: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExceededWindowError { + /// Model used when last message exceeded context window + model_id: LanguageModelId, + /// Token count including last message + token_count: usize, +} + impl Thread { pub fn new( project: Entity, - tools: Arc, + tools: Entity, prompt_builder: Arc, system_prompt: 
SharedProjectContext, cx: &mut Context, @@ -301,6 +311,7 @@ impl Thread { .shared() }, cumulative_token_usage: TokenUsage::default(), + exceeded_window_error: None, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, @@ -311,7 +322,7 @@ impl Thread { id: ThreadId, serialized: SerializedThread, project: Entity, - tools: Arc, + tools: Entity, prompt_builder: Arc, project_context: SharedProjectContext, cx: &mut Context, @@ -367,6 +378,7 @@ impl Thread { action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), cumulative_token_usage: serialized.cumulative_token_usage, + exceeded_window_error: None, feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, @@ -446,7 +458,7 @@ impl Thread { !self.pending_completions.is_empty() || !self.all_tools_finished() } - pub fn tools(&self) -> &Arc { + pub fn tools(&self) -> &Entity { &self.tools } @@ -819,8 +831,9 @@ impl Thread { }) .collect(), initial_project_snapshot, - cumulative_token_usage: this.cumulative_token_usage.clone(), + cumulative_token_usage: this.cumulative_token_usage, detailed_summary_state: this.detailed_summary_state.clone(), + exceeded_window_error: this.exceeded_window_error.clone(), }) }) } @@ -835,13 +848,21 @@ impl Thread { if model.supports_tools() { request.tools = { let mut tools = Vec::new(); - tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| { - LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool.input_schema(model.tool_input_format()), - } - })); + tools.extend( + self.tools() + .read(cx) + .enabled_tools(cx) + .into_iter() + .filter_map(|tool| { + // Skip tools that cannot be supported + let input_schema = tool.input_schema(model.tool_input_format()).ok()?; + Some(LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema, + }) + }), + ); tools }; @@ -1000,7 
+1021,7 @@ impl Thread { let task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion(request, &cx); let initial_token_usage = - thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone()); + thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { let mut events = stream.await?; let mut stop_reason = StopReason::EndTurn; @@ -1022,9 +1043,9 @@ impl Thread { stop_reason = reason; } LanguageModelCompletionEvent::UsageUpdate(token_usage) => { - thread.cumulative_token_usage = - thread.cumulative_token_usage.clone() + token_usage.clone() - - current_token_usage.clone(); + thread.cumulative_token_usage = thread.cumulative_token_usage + + token_usage + - current_token_usage; current_token_usage = token_usage; } LanguageModelCompletionEvent::Text(chunk) => { @@ -1133,6 +1154,20 @@ impl Thread { cx.emit(ThreadEvent::ShowError( ThreadError::MaxMonthlySpendReached, )); + } else if let Some(known_error) = + error.downcast_ref::() + { + match known_error { + LanguageModelKnownError::ContextWindowLimitExceeded { + tokens, + } => { + thread.exceeded_window_error = Some(ExceededWindowError { + model_id: model.id(), + token_count: *tokens, + }); + cx.notify(); + } + } } else { let error_message = error .chain() @@ -1153,7 +1188,7 @@ impl Thread { thread.auto_capture_telemetry(cx); if let Ok(initial_usage) = initial_token_usage { - let usage = thread.cumulative_token_usage.clone() - initial_usage; + let usage = thread.cumulative_token_usage - initial_usage; telemetry::event!( "Assistant Thread Completion", @@ -1324,7 +1359,7 @@ impl Thread { .collect::>(); for tool_use in pending_tool_uses.iter() { - if let Some(tool) = self.tools.tool(&tool_use.name, cx) { + if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { if tool.needs_confirmation(&tool_use.input, cx) && !AssistantSettings::get_global(cx).always_allow_tool_actions { @@ -1376,7 +1411,7 @@ impl Thread { ) -> Task<()> { let tool_name: 
Arc = tool.name().into(); - let tool_result = if self.tools.is_disabled(&tool.source(), &tool_name) { + let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) { ToolResult { output: Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))), card: None, @@ -1500,6 +1535,7 @@ impl Thread { let enabled_tool_names: Vec = self .tools() + .read(cx) .enabled_tools(cx) .iter() .map(|tool| tool.name().to_string()) @@ -1797,10 +1833,6 @@ impl Thread { &self.project } - pub fn cumulative_token_usage(&self) -> TokenUsage { - self.cumulative_token_usage.clone() - } - pub fn auto_capture_telemetry(&mut self, cx: &mut Context) { if !cx.has_flag::() { return; @@ -1845,6 +1877,10 @@ impl Thread { .detach(); } + pub fn cumulative_token_usage(&self) -> TokenUsage { + self.cumulative_token_usage + } + pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = model_registry.default_model() else { @@ -1853,6 +1889,16 @@ impl Thread { let max = model.model.max_token_count(); + if let Some(exceeded_error) = &self.exceeded_window_error { + if model.model.id() == exceeded_error.model_id { + return TotalTokenUsage { + total: exceeded_error.token_count, + max, + ratio: TokenUsageRatio::Exceeded, + }; + } + } + #[cfg(debug_assertions)] let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") .unwrap_or("0.8".to_string()) @@ -2310,7 +2356,7 @@ fn main() {{ .update(|_, cx| { ThreadStore::load( project.clone(), - Arc::default(), + cx.new(|_| ToolWorkingSet::default()), Arc::new(PromptBuilder::new(None).unwrap()), cx, ) diff --git a/crates/agent/src/thread_history.rs b/crates/agent/src/thread_history.rs index bb0d8ca3fd..ecf5e958a7 100644 --- a/crates/agent/src/thread_history.rs +++ b/crates/agent/src/thread_history.rs @@ -4,11 +4,14 @@ use assistant_context_editor::SavedContextMetadata; use editor::{Editor, EditorEvent}; use fuzzy::{StringMatch, StringMatchCandidate}; 
use gpui::{ - App, Entity, FocusHandle, Focusable, ScrollStrategy, Task, UniformListScrollHandle, WeakEntity, - Window, uniform_list, + App, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task, UniformListScrollHandle, + WeakEntity, Window, uniform_list, }; use time::{OffsetDateTime, UtcOffset}; -use ui::{HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Tooltip, prelude::*}; +use ui::{ + HighlightedLabel, IconButtonShape, ListItem, ListItemSpacing, Scrollbar, ScrollbarState, + Tooltip, prelude::*, +}; use util::ResultExt; use crate::history_store::{HistoryEntry, HistoryStore}; @@ -26,6 +29,8 @@ pub struct ThreadHistory { matches: Vec, _subscriptions: Vec, _search_task: Option>, + scrollbar_visibility: bool, + scrollbar_state: ScrollbarState, } impl ThreadHistory { @@ -58,10 +63,13 @@ impl ThreadHistory { this.update_all_entries(cx); }); + let scroll_handle = UniformListScrollHandle::default(); + let scrollbar_state = ScrollbarState::new(scroll_handle.clone()); + Self { assistant_panel, history_store, - scroll_handle: UniformListScrollHandle::default(), + scroll_handle, selected_index: 0, search_query: SharedString::new_static(""), all_entries: entries, @@ -69,6 +77,8 @@ impl ThreadHistory { search_editor, _subscriptions: vec![search_editor_subscription, history_store_subscription], _search_task: None, + scrollbar_visibility: true, + scrollbar_state, } } @@ -220,6 +230,43 @@ impl ThreadHistory { cx.notify(); } + fn render_scrollbar(&self, cx: &mut Context) -> Option> { + if !(self.scrollbar_visibility || self.scrollbar_state.is_dragging()) { + return None; + } + + Some( + div() + .occlude() + .id("thread-history-scroll") + .h_full() + .bg(cx.theme().colors().panel_background.opacity(0.8)) + .border_l_1() + .border_color(cx.theme().colors().border_variant) + .absolute() + .right_1() + .top_0() + .bottom_0() + .w_4() + .pl_1() + .cursor_default() + .on_mouse_move(cx.listener(|_, _, _window, cx| { + cx.notify(); + cx.stop_propagation() + })) + 
.on_hover(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_any_mouse_down(|_, _window, cx| { + cx.stop_propagation(); + }) + .on_scroll_wheel(cx.listener(|_, _, _window, cx| { + cx.notify(); + })) + .children(Scrollbar::vertical(self.scrollbar_state.clone())), + ) + } + fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { if let Some(entry) = self.get_match(self.selected_index) { let task_result = match entry { @@ -305,7 +352,11 @@ impl Render for ThreadHistory { ) }) .child({ - let view = v_flex().overflow_hidden().flex_grow(); + let view = v_flex() + .id("list-container") + .relative() + .overflow_hidden() + .flex_grow(); if self.all_entries.is_empty() { view.justify_center() @@ -322,59 +373,70 @@ impl Render for ThreadHistory { ), ) } else { - view.p_1().child( - uniform_list( - cx.entity().clone(), - "thread-history", - self.matched_count(), - move |history, range, _window, _cx| { - let range_start = range.start; - let assistant_panel = history.assistant_panel.clone(); + view.pr_5() + .child( + uniform_list( + cx.entity().clone(), + "thread-history", + self.matched_count(), + move |history, range, _window, _cx| { + let range_start = range.start; + let assistant_panel = history.assistant_panel.clone(); - let render_item = |index: usize, - entry: &HistoryEntry, - highlight_positions: Vec| - -> Div { - h_flex().w_full().pb_1().child(match entry { - HistoryEntry::Thread(thread) => PastThread::new( - thread.clone(), - assistant_panel.clone(), - selected_index == index + range_start, - highlight_positions, - ) - .into_any_element(), - HistoryEntry::Context(context) => PastContext::new( - context.clone(), - assistant_panel.clone(), - selected_index == index + range_start, - highlight_positions, - ) - .into_any_element(), - }) - }; - - if history.has_search_query() { - history.matches[range] - .iter() - .enumerate() - .filter_map(|(index, m)| { - history.all_entries.get(m.candidate_id).map(|entry| { - render_item(index, entry, 
m.positions.clone()) - }) + let render_item = |index: usize, + entry: &HistoryEntry, + highlight_positions: Vec| + -> Div { + h_flex().w_full().pb_1().child(match entry { + HistoryEntry::Thread(thread) => PastThread::new( + thread.clone(), + assistant_panel.clone(), + selected_index == index + range_start, + highlight_positions, + ) + .into_any_element(), + HistoryEntry::Context(context) => PastContext::new( + context.clone(), + assistant_panel.clone(), + selected_index == index + range_start, + highlight_positions, + ) + .into_any_element(), }) - .collect() - } else { - history.all_entries[range] - .iter() - .enumerate() - .map(|(index, entry)| render_item(index, entry, vec![])) - .collect() - } - }, + }; + + if history.has_search_query() { + history.matches[range] + .iter() + .enumerate() + .filter_map(|(index, m)| { + history.all_entries.get(m.candidate_id).map( + |entry| { + render_item( + index, + entry, + m.positions.clone(), + ) + }, + ) + }) + .collect() + } else { + history.all_entries[range] + .iter() + .enumerate() + .map(|(index, entry)| render_item(index, entry, vec![])) + .collect() + } + }, + ) + .p_1() + .track_scroll(self.scroll_handle.clone()) + .flex_grow(), ) - .track_scroll(self.scroll_handle.clone()) - .flex_grow(), - ) + .when_some(self.render_scrollbar(cx), |div, scrollbar| { + div.child(scrollbar) + }) } }) } @@ -440,6 +502,7 @@ impl RenderOnce for PastThread { IconButton::new("delete", IconName::TrashAlt) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) .tooltip(move |window, cx| { Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) }) @@ -531,6 +594,7 @@ impl RenderOnce for PastContext { IconButton::new("delete", IconName::TrashAlt) .shape(IconButtonShape::Square) .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) .tooltip(move |window, cx| { Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx) }) diff --git a/crates/agent/src/thread_store.rs 
b/crates/agent/src/thread_store.rs index c8f8d239a2..6fb0f6c7a2 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; -use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId}; +use crate::thread::{ + DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId, +}; const RULES_FILE_NAMES: [&'static str; 6] = [ ".rules", @@ -54,7 +56,7 @@ impl SharedProjectContext { pub struct ThreadStore { project: Entity, - tools: Arc, + tools: Entity, prompt_builder: Arc, context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, @@ -72,7 +74,7 @@ impl EventEmitter for ThreadStore {} impl ThreadStore { pub fn load( project: Entity, - tools: Arc, + tools: Entity, prompt_builder: Arc, cx: &mut App, ) -> Task> { @@ -86,7 +88,7 @@ impl ThreadStore { fn new( project: Entity, - tools: Arc, + tools: Entity, prompt_builder: Arc, cx: &mut Context, ) -> Self { @@ -246,7 +248,7 @@ impl ThreadStore { self.context_server_manager.clone() } - pub fn tools(&self) -> Arc { + pub fn tools(&self) -> Entity { self.tools.clone() } @@ -353,52 +355,60 @@ impl ThreadStore { }) } - fn load_default_profile(&self, cx: &Context) { + fn load_default_profile(&self, cx: &mut Context) { let assistant_settings = AssistantSettings::get_global(cx); - self.load_profile_by_id(&assistant_settings.default_profile, cx); + self.load_profile_by_id(assistant_settings.default_profile.clone(), cx); } - pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context) { + pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context) { let assistant_settings = AssistantSettings::get_global(cx); - if let Some(profile) = assistant_settings.profiles.get(profile_id) { - self.load_profile(profile, cx); + if let Some(profile) = assistant_settings.profiles.get(&profile_id) { + 
self.load_profile(profile.clone(), cx); } } - pub fn load_profile(&self, profile: &AgentProfile, cx: &Context) { - self.tools.disable_all_tools(); - self.tools.enable( - ToolSource::Native, - &profile - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ); + pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context) { + self.tools.update(cx, |tools, cx| { + tools.disable_all_tools(cx); + tools.enable( + ToolSource::Native, + &profile + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + cx, + ); + }); if profile.enable_all_context_servers { for context_server in self.context_server_manager.read(cx).all_servers() { - self.tools.enable_source( - ToolSource::ContextServer { - id: context_server.id().into(), - }, - cx, - ); + self.tools.update(cx, |tools, cx| { + tools.enable_source( + ToolSource::ContextServer { + id: context_server.id().into(), + }, + cx, + ); + }); } } else { for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) + self.tools.update(cx, |tools, cx| { + tools.enable( + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + cx, + ) + }) } } } @@ -432,29 +442,36 @@ impl ThreadStore { if protocol.capable(context_server::protocol::ServerCapability::Tools) { if let Some(tools) = protocol.list_tools().await.log_err() { - let tool_ids = tools - .tools - .into_iter() - .map(|tool| { - log::info!( - "registering context server tool: {:?}", - tool.name - ); - tool_working_set.insert(Arc::new( - ContextServerTool::new( - context_server_manager.clone(), - server.id(), - tool, - ), - )) + let tool_ids = tool_working_set + 
.update(cx, |tool_working_set, _| { + tools + .tools + .into_iter() + .map(|tool| { + log::info!( + "registering context server tool: {:?}", + tool.name + ); + tool_working_set.insert(Arc::new( + ContextServerTool::new( + context_server_manager.clone(), + server.id(), + tool, + ), + )) + }) + .collect::>() }) - .collect::>(); + .log_err(); - this.update(cx, |this, cx| { - this.context_server_tool_ids.insert(server_id, tool_ids); - this.load_default_profile(cx); - }) - .log_err(); + if let Some(tool_ids) = tool_ids { + this.update(cx, |this, cx| { + this.context_server_tool_ids + .insert(server_id, tool_ids); + this.load_default_profile(cx); + }) + .log_err(); + } } } } @@ -464,7 +481,9 @@ impl ThreadStore { } context_server::manager::Event::ServerStopped { server_id } => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { - tool_working_set.remove(&tool_ids); + tool_working_set.update(cx, |tool_working_set, _| { + tool_working_set.remove(&tool_ids); + }); self.load_default_profile(cx); } } @@ -491,6 +510,8 @@ pub struct SerializedThread { pub cumulative_token_usage: TokenUsage, #[serde(default)] pub detailed_summary_state: DetailedSummaryState, + #[serde(default)] + pub exceeded_window_error: Option, } impl SerializedThread { @@ -577,6 +598,7 @@ impl LegacySerializedThread { initial_project_snapshot: self.initial_project_snapshot, cumulative_token_usage: TokenUsage::default(), detailed_summary_state: DetailedSummaryState::default(), + exceeded_window_error: None, } } } diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 7abecbd4a0..32876a100c 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -5,7 +5,7 @@ use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet}; use collections::HashMap; use futures::FutureExt as _; use futures::future::Shared; -use gpui::{App, SharedString, Task}; +use gpui::{App, Entity, SharedString, Task}; use language_model::{ LanguageModelRegistry, 
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, @@ -30,7 +30,7 @@ pub struct ToolUse { pub const USING_TOOL_MARKER: &str = ""; pub struct ToolUseState { - tools: Arc, + tools: Entity, tool_uses_by_assistant_message: HashMap>, tool_uses_by_user_message: HashMap>, tool_results: HashMap, @@ -39,7 +39,7 @@ pub struct ToolUseState { } impl ToolUseState { - pub fn new(tools: Arc) -> Self { + pub fn new(tools: Entity) -> Self { Self { tools, tool_uses_by_assistant_message: HashMap::default(), @@ -54,7 +54,7 @@ impl ToolUseState { /// /// Accepts a function to filter the tools that should be used to populate the state. pub fn from_serialized_messages( - tools: Arc, + tools: Entity, messages: &[SerializedMessage], mut filter_by_tool_name: impl FnMut(&str) -> bool, ) -> Self { @@ -180,12 +180,12 @@ impl ToolUseState { } })(); - let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx) - { - (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) - } else { - (IconName::Cog, false) - }; + let (icon, needs_confirmation) = + if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { + (tool.icon(), tool.needs_confirmation(&tool_use.input, cx)) + } else { + (IconName::Cog, false) + }; tool_uses.push(ToolUse { id: tool_use.id.clone(), @@ -207,7 +207,7 @@ impl ToolUseState { input: &serde_json::Value, cx: &App, ) -> SharedString { - if let Some(tool) = self.tools.tool(tool_name, cx) { + if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) { tool.ui_text(input).into() } else { format!("Unknown tool {tool_name:?}").into() diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 1735579729..8e82c7cdd6 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -25,5 +25,4 @@ serde.workspace = true serde_json.workspace = true strum.workspace = true thiserror.workspace = true -util.workspace = true workspace-hack.workspace = 
true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index e0c215bc3a..266d3c7642 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -10,7 +10,6 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use strum::{EnumIter, EnumString}; use thiserror::Error; -use util::ResultExt as _; pub use supported_countries::*; @@ -363,11 +362,25 @@ pub struct RateLimitInfo { impl RateLimitInfo { fn from_headers(headers: &HeaderMap) -> Self { + // Check if any rate limit headers exist + let has_rate_limit_headers = headers + .keys() + .any(|k| k.as_str().starts_with("anthropic-ratelimit-")); + + if !has_rate_limit_headers { + return Self { + requests: None, + tokens: None, + input_tokens: None, + output_tokens: None, + }; + } + Self { - requests: RateLimit::from_headers("requests", headers).log_err(), - tokens: RateLimit::from_headers("tokens", headers).log_err(), - input_tokens: RateLimit::from_headers("input-tokens", headers).log_err(), - output_tokens: RateLimit::from_headers("output-tokens", headers).log_err(), + requests: RateLimit::from_headers("requests", headers).ok(), + tokens: RateLimit::from_headers("tokens", headers).ok(), + input_tokens: RateLimit::from_headers("input-tokens", headers).ok(), + output_tokens: RateLimit::from_headers("output-tokens", headers).ok(), } } } @@ -724,4 +737,54 @@ impl ApiError { pub fn is_rate_limit_error(&self) -> bool { matches!(self.error_type.as_str(), "rate_limit_error") } + + pub fn match_window_exceeded(&self) -> Option { + let Some(ApiErrorCode::InvalidRequestError) = self.code() else { + return None; + }; + + parse_prompt_too_long(&self.message) + } +} + +pub fn parse_prompt_too_long(message: &str) -> Option { + message + .strip_prefix("prompt is too long: ")? + .split_once(" tokens")? 
+ .0 + .parse::() + .ok() +} + +#[test] +fn test_match_window_exceeded() { + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: 220000 tokens > 200000".to_string(), + }; + assert_eq!(error.match_window_exceeded(), Some(220_000)); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: 1234953 tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), Some(1234953)); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "not a prompt length error".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); + + let error = ApiError { + error_type: "rate_limit_error".to_string(), + message: "prompt is too long: 12345 tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); + + let error = ApiError { + error_type: "invalid_request_error".to_string(), + message: "prompt is too long: invalid tokens".to_string(), + }; + assert_eq!(error.match_window_exceeded(), None); } diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index 1c0911c189..fa305a512e 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -4,7 +4,7 @@ use collections::BTreeMap; use futures::{StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; use language::{Anchor, Buffer, BufferEvent, DiskState, Point}; -use project::{Project, ProjectItem}; +use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; use util::RangeExt; @@ -49,6 +49,10 @@ impl ActionLog { .tracked_buffers .entry(buffer.clone()) .or_insert_with(|| { + let open_lsp_handle = self.project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + let text_snapshot = buffer.read(cx).text_snapshot(); 
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx)); let (diff_update_tx, diff_update_rx) = mpsc::unbounded(); @@ -76,6 +80,7 @@ impl ActionLog { version: buffer.read(cx).version(), diff, diff_update: diff_update_tx, + _open_lsp_handle: open_lsp_handle, _maintain_diff: cx.spawn({ let buffer = buffer.clone(); async move |this, cx| { @@ -615,6 +620,7 @@ struct TrackedBuffer { diff: Entity, snapshot: text::BufferSnapshot, diff_update: mpsc::UnboundedSender<(ChangeAuthor, text::BufferSnapshot)>, + _open_lsp_handle: OpenLspBufferHandle, _maintain_diff: Task<()>, _subscription: Subscription, } diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index c91ac3b8a5..cb7f0ff518 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -1,5 +1,6 @@ mod action_log; mod tool_registry; +mod tool_schema; mod tool_working_set; use std::fmt; @@ -20,6 +21,7 @@ use project::Project; pub use crate::action_log::*; pub use crate::tool_registry::*; +pub use crate::tool_schema::*; pub use crate::tool_working_set::*; pub fn init(cx: &mut App) { @@ -139,8 +141,8 @@ pub trait Tool: 'static + Send + Sync { fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value { - serde_json::Value::Object(serde_json::Map::default()) + fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result { + Ok(serde_json::Value::Object(serde_json::Map::default())) } /// Returns markdown to be displayed in the UI for this tool. 
diff --git a/crates/assistant_tool/src/tool_schema.rs b/crates/assistant_tool/src/tool_schema.rs new file mode 100644 index 0000000000..225c1c22ef --- /dev/null +++ b/crates/assistant_tool/src/tool_schema.rs @@ -0,0 +1,236 @@ +use anyhow::Result; +use serde_json::Value; + +use crate::LanguageModelToolSchemaFormat; + +/// Tries to adapt a JSON schema representation to be compatible with the specified format. +/// +/// If the json cannot be made compatible with the specified format, an error is returned. +pub fn adapt_schema_to_format( + json: &mut Value, + format: LanguageModelToolSchemaFormat, +) -> Result<()> { + match format { + LanguageModelToolSchemaFormat::JsonSchema => Ok(()), + LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json), + } +} + +/// Tries to adapt the json schema so that it is compatible with https://ai.google.dev/api/caching#Schema +fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> { + if let Value::Object(obj) = json { + const UNSUPPORTED_KEYS: [&str; 4] = ["if", "then", "else", "$ref"]; + + for key in UNSUPPORTED_KEYS { + if obj.contains_key(key) { + return Err(anyhow::anyhow!( + "Schema cannot be made compatible because it contains \"{}\" ", + key + )); + } + } + + const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"]; + for key in KEYS_TO_REMOVE { + obj.remove(key); + } + + if let Some(default) = obj.get("default") { + let is_null = default.is_null(); + // Default is not supported, so we need to remove it + obj.remove("default"); + if is_null { + obj.insert("nullable".to_string(), Value::Bool(true)); + } + } + + // If a type is not specified for an input parameter, add a default type + if obj.contains_key("description") + && !obj.contains_key("type") + && !(obj.contains_key("anyOf") + || obj.contains_key("oneOf") + || obj.contains_key("allOf")) + { + obj.insert("type".to_string(), Value::String("string".to_string())); + } + + // Handle oneOf -> anyOf conversion + if let Some(subschemas) = 
obj.get_mut("oneOf") { + if subschemas.is_array() { + let subschemas_clone = subschemas.clone(); + obj.remove("oneOf"); + obj.insert("anyOf".to_string(), subschemas_clone); + } + } + + // Recursively process all nested objects and arrays + for (_, value) in obj.iter_mut() { + if let Value::Object(_) | Value::Array(_) = value { + adapt_to_json_schema_subset(value)?; + } + } + } else if let Value::Array(arr) = json { + for item in arr.iter_mut() { + adapt_to_json_schema_subset(item)?; + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_transform_default_null_to_nullable() { + let mut json = json!({ + "description": "A test field", + "type": "string", + "default": null + }); + + adapt_to_json_schema_subset(&mut json).unwrap(); + + assert_eq!( + json, + json!({ + "description": "A test field", + "type": "string", + "nullable": true + }) + ); + } + + #[test] + fn test_transform_adds_type_when_missing() { + let mut json = json!({ + "description": "A test field without type" + }); + + adapt_to_json_schema_subset(&mut json).unwrap(); + + assert_eq!( + json, + json!({ + "description": "A test field without type", + "type": "string" + }) + ); + } + + #[test] + fn test_transform_removes_format() { + let mut json = json!({ + "description": "A test field", + "type": "integer", + "format": "uint32" + }); + + adapt_to_json_schema_subset(&mut json).unwrap(); + + assert_eq!( + json, + json!({ + "description": "A test field", + "type": "integer" + }) + ); + } + + #[test] + fn test_transform_one_of_to_any_of() { + let mut json = json!({ + "description": "A test field", + "oneOf": [ + { "type": "string" }, + { "type": "integer" } + ] + }); + + adapt_to_json_schema_subset(&mut json).unwrap(); + + assert_eq!( + json, + json!({ + "description": "A test field", + "anyOf": [ + { "type": "string" }, + { "type": "integer" } + ] + }) + ); + } + + #[test] + fn test_transform_nested_objects() { + let mut json = json!({ + "type": 
"object", + "properties": { + "nested": { + "oneOf": [ + { "type": "string" }, + { "type": "null" } + ], + "format": "email" + } + } + }); + + adapt_to_json_schema_subset(&mut json).unwrap(); + + assert_eq!( + json, + json!({ + "type": "object", + "properties": { + "nested": { + "anyOf": [ + { "type": "string" }, + { "type": "null" } + ] + } + } + }) + ); + } + + #[test] + fn test_transform_fails_if_unsupported_keys_exist() { + let mut json = json!({ + "type": "object", + "properties": { + "$ref": "#/definitions/User", + } + }); + + assert!(adapt_to_json_schema_subset(&mut json).is_err()); + + let mut json = json!({ + "type": "object", + "properties": { + "if": "...", + } + }); + + assert!(adapt_to_json_schema_subset(&mut json).is_err()); + + let mut json = json!({ + "type": "object", + "properties": { + "then": "...", + } + }); + + assert!(adapt_to_json_schema_subset(&mut json).is_err()); + + let mut json = json!({ + "type": "object", + "properties": { + "else": "...", + } + }); + + assert!(adapt_to_json_schema_subset(&mut json).is_err()); + } +} diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 97060cfdad..c7e20d3517 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,8 +1,7 @@ use std::sync::Arc; use collections::{HashMap, HashSet, IndexMap}; -use gpui::App; -use parking_lot::Mutex; +use gpui::{App, Context, EventEmitter}; use crate::{Tool, ToolRegistry, ToolSource}; @@ -12,11 +11,6 @@ pub struct ToolId(usize); /// A working set of tools for use in one instance of the Assistant Panel. 
#[derive(Default)] pub struct ToolWorkingSet { - state: Mutex, -} - -#[derive(Default)] -struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, enabled_sources: HashSet, @@ -24,99 +18,27 @@ struct WorkingSetState { next_tool_id: ToolId, } +pub enum ToolWorkingSetEvent { + EnabledToolsChanged, +} + +impl EventEmitter for ToolWorkingSet {} + impl ToolWorkingSet { pub fn tool(&self, name: &str, cx: &App) -> Option> { - self.state - .lock() - .context_server_tools_by_name + self.context_server_tools_by_name .get(name) .cloned() .or_else(|| ToolRegistry::global(cx).tool(name)) } pub fn tools(&self, cx: &App) -> Vec> { - self.state.lock().tools(cx) - } - - pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { - self.state.lock().tools_by_source(cx) - } - - pub fn enabled_tools(&self, cx: &App) -> Vec> { - self.state.lock().enabled_tools(cx) - } - - pub fn disable_all_tools(&self) { - let mut state = self.state.lock(); - state.disable_all_tools(); - } - - pub fn enable_source(&self, source: ToolSource, cx: &App) { - let mut state = self.state.lock(); - state.enable_source(source, cx); - } - - pub fn disable_source(&self, source: &ToolSource) { - let mut state = self.state.lock(); - state.disable_source(source); - } - - pub fn insert(&self, tool: Arc) -> ToolId { - let mut state = self.state.lock(); - let tool_id = state.next_tool_id; - state.next_tool_id.0 += 1; - state - .context_server_tools_by_id - .insert(tool_id, tool.clone()); - state.tools_changed(); - tool_id - } - - pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.state.lock().is_enabled(source, name) - } - - pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.state.lock().is_disabled(source, name) - } - - pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc]) { - let mut state = self.state.lock(); - state.enable(source, tools_to_enable); - } - - pub fn disable(&self, source: ToolSource, 
tools_to_disable: &[Arc]) { - let mut state = self.state.lock(); - state.disable(source, tools_to_disable); - } - - pub fn remove(&self, tool_ids_to_remove: &[ToolId]) { - let mut state = self.state.lock(); - state - .context_server_tools_by_id - .retain(|id, _| !tool_ids_to_remove.contains(id)); - state.tools_changed(); - } -} - -impl WorkingSetState { - fn tools_changed(&mut self) { - self.context_server_tools_by_name.clear(); - self.context_server_tools_by_name.extend( - self.context_server_tools_by_id - .values() - .map(|tool| (tool.name(), tool.clone())), - ); - } - - fn tools(&self, cx: &App) -> Vec> { let mut tools = ToolRegistry::global(cx).tools(); tools.extend(self.context_server_tools_by_id.values().cloned()); - tools } - fn tools_by_source(&self, cx: &App) -> IndexMap>> { + pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { let mut tools_by_source = IndexMap::default(); for tool in self.tools(cx) { @@ -135,7 +57,7 @@ impl WorkingSetState { tools_by_source } - fn enabled_tools(&self, cx: &App) -> Vec> { + pub fn enabled_tools(&self, cx: &App) -> Vec> { let all_tools = self.tools(cx); all_tools @@ -144,31 +66,12 @@ impl WorkingSetState { .collect() } - fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.enabled_tools_by_source - .get(source) - .map_or(false, |enabled_tools| enabled_tools.contains(name)) + pub fn disable_all_tools(&mut self, cx: &mut Context) { + self.enabled_tools_by_source.clear(); + cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); } - fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_enabled(source, name) - } - - fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc]) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .extend(tools_to_enable.into_iter().cloned()); - } - - fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc]) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .retain(|name| !tools_to_disable.contains(name)); - } 
- - fn enable_source(&mut self, source: ToolSource, cx: &App) { + pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context) { self.enabled_sources.insert(source.clone()); let tools_by_source = self.tools_by_source(cx); @@ -181,14 +84,72 @@ impl WorkingSetState { .collect::>(), ); } + cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); } - fn disable_source(&mut self, source: &ToolSource) { + pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context) { self.enabled_sources.remove(source); self.enabled_tools_by_source.remove(source); + cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); } - fn disable_all_tools(&mut self) { - self.enabled_tools_by_source.clear(); + pub fn insert(&mut self, tool: Arc) -> ToolId { + let tool_id = self.next_tool_id; + self.next_tool_id.0 += 1; + self.context_server_tools_by_id + .insert(tool_id, tool.clone()); + self.tools_changed(); + tool_id + } + + pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { + self.enabled_tools_by_source + .get(source) + .map_or(false, |enabled_tools| enabled_tools.contains(name)) + } + + pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { + !self.is_enabled(source, name) + } + + pub fn enable( + &mut self, + source: ToolSource, + tools_to_enable: &[Arc], + cx: &mut Context, + ) { + self.enabled_tools_by_source + .entry(source) + .or_default() + .extend(tools_to_enable.into_iter().cloned()); + cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); + } + + pub fn disable( + &mut self, + source: ToolSource, + tools_to_disable: &[Arc], + cx: &mut Context, + ) { + self.enabled_tools_by_source + .entry(source) + .or_default() + .retain(|name| !tools_to_disable.contains(name)); + cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); + } + + pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) { + self.context_server_tools_by_id + .retain(|id, _| !tool_ids_to_remove.contains(id)); + self.tools_changed(); + } + + fn tools_changed(&mut self) { + 
self.context_server_tools_by_name.clear(); + self.context_server_tools_by_name.extend( + self.context_server_tools_by_id + .values() + .map(|tool| (tool.name(), tool.clone())), + ); } } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index d7d901fd86..adf273798a 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -1,6 +1,7 @@ mod batch_tool; mod code_action_tool; mod code_symbols_tool; +mod contents_tool; mod copy_path_tool; mod create_directory_tool; mod create_file_tool; @@ -35,6 +36,7 @@ use web_search_tool::WebSearchTool; use crate::batch_tool::BatchTool; use crate::code_action_tool::CodeActionTool; use crate::code_symbols_tool::CodeSymbolsTool; +use crate::contents_tool::ContentsTool; use crate::create_directory_tool::CreateDirectoryTool; use crate::create_file_tool::CreateFileTool; use crate::delete_path_tool::DeletePathTool; @@ -59,6 +61,7 @@ pub fn init(http_client: Arc, cx: &mut App) { registry.register_tool(BatchTool); registry.register_tool(CodeActionTool); registry.register_tool(CodeSymbolsTool); + registry.register_tool(ContentsTool); registry.register_tool(CopyPathTool); registry.register_tool(CreateDirectoryTool); registry.register_tool(CreateFileTool); @@ -79,3 +82,42 @@ pub fn init(http_client: Arc, cx: &mut App) { registry.register_tool(ThinkingTool); registry.register_tool(WebSearchTool); } + +#[cfg(test)] +mod tests { + use http_client::FakeHttpClient; + + use super::*; + + #[gpui::test] + fn test_builtin_tool_schema_compatibility(cx: &mut App) { + crate::init( + Arc::new(http_client::HttpClientWithUrl::new( + FakeHttpClient::with_200_response(), + "https://zed.dev", + None, + )), + cx, + ); + + for tool in ToolRegistry::global(cx).tools() { + let actual_schema = tool + .input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset) + .unwrap(); + let mut expected_schema = actual_schema.clone(); + 
assistant_tool::adapt_schema_to_format( + &mut expected_schema, + language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, + ) + .unwrap(); + + let error_message = format!( + "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\ + Are you using `schema::json_schema_for(format)` to generate the schema?", + tool.name(), + ); + + assert_eq!(actual_schema, expected_schema, "{}", error_message) + } + } +} diff --git a/crates/assistant_tools/src/batch_tool.rs b/crates/assistant_tools/src/batch_tool.rs index a195cee4d6..87e70e1e62 100644 --- a/crates/assistant_tools/src/batch_tool.rs +++ b/crates/assistant_tools/src/batch_tool.rs @@ -172,7 +172,7 @@ impl Tool for BatchTool { IconName::Cog } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/code_action_tool.rs b/crates/assistant_tools/src/code_action_tool.rs index 119cce7669..4bf39f1deb 100644 --- a/crates/assistant_tools/src/code_action_tool.rs +++ b/crates/assistant_tools/src/code_action_tool.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use gpui::{App, Entity, Task}; use language::{self, Anchor, Buffer, ToPointUtf16}; -use language_model::LanguageModelRequestMessage; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::{self, LspAction, Project}; use regex::Regex; use schemars::JsonSchema; @@ -10,6 +10,8 @@ use serde::{Deserialize, Serialize}; use std::{ops::Range, sync::Arc}; use ui::IconName; +use crate::schema::json_schema_for; + #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct CodeActionToolInput { /// The relative path to the file containing the text range. 
@@ -95,12 +97,8 @@ impl Tool for CodeActionTool { IconName::Wand } - fn input_schema( - &self, - _format: language_model::LanguageModelToolSchemaFormat, - ) -> serde_json::Value { - let schema = schemars::schema_for!(CodeActionToolInput); - serde_json::to_value(&schema).unwrap() + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) } fn ui_text(&self, input: &serde_json::Value) -> String { diff --git a/crates/assistant_tools/src/code_symbols_tool.rs b/crates/assistant_tools/src/code_symbols_tool.rs index eabc18c486..78dea96d5b 100644 --- a/crates/assistant_tools/src/code_symbols_tool.rs +++ b/crates/assistant_tools/src/code_symbols_tool.rs @@ -91,7 +91,7 @@ impl Tool for CodeSymbolsTool { IconName::Code } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/contents_tool.rs b/crates/assistant_tools/src/contents_tool.rs new file mode 100644 index 0000000000..be7c4927cb --- /dev/null +++ b/crates/assistant_tools/src/contents_tool.rs @@ -0,0 +1,239 @@ +use std::sync::Arc; + +use crate::{code_symbols_tool::file_outline, schema::json_schema_for}; +use anyhow::{Result, anyhow}; +use assistant_tool::{ActionLog, Tool}; +use gpui::{App, Entity, Task}; +use itertools::Itertools; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{fmt::Write, path::Path}; +use ui::IconName; +use util::markdown::MarkdownString; + +/// If the model requests to read a file whose size exceeds this, then +/// the tool will return the file's symbol outline instead of its contents, +/// and suggest trying again using line ranges from the outline. 
+const MAX_FILE_SIZE_TO_READ: usize = 16384; + +/// If the model requests to list the entries in a directory with more +/// entries than this, then the tool will return a subset of the entries +/// and suggest trying again. +const MAX_DIR_ENTRIES: usize = 1024; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ContentsToolInput { + /// The relative path of the file or directory to access. + /// + /// This path should never be absolute, and the first component + /// of the path should always be a root directory in a project. + /// + /// + /// If the project has the following root directories: + /// + /// - directory1 + /// - directory2 + /// + /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`. + /// If you want to list contents in the directory `directory2/subfolder`, you should use the path `directory2/subfolder`. + /// + pub path: String, + + /// Optional position (1-based index) to start reading on, if you want to read a subset of the contents. + /// When reading a file, this refers to a line number in the file (e.g. 1 is the first line). + /// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry). + /// + /// Defaults to 1. + pub start: Option, + + /// Optional position (1-based index) to end reading on, if you want to read a subset of the contents. + /// When reading a file, this refers to a line number in the file (e.g. 1 is the first line). + /// When reading a directory, this refers to the number of the directory entry (e.g. 1 is the first entry). + /// + /// Defaults to reading until the end of the file or directory. 
+ pub end: Option, +} + +pub struct ContentsTool; + +impl Tool for ContentsTool { + fn name(&self) -> String { + "contents".into() + } + + fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + false + } + + fn description(&self) -> String { + include_str!("./contents_tool/description.md").into() + } + + fn icon(&self) -> IconName { + IconName::FileSearch + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) + } + + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => { + let path = MarkdownString::inline_code(&input.path); + + match (input.start, input.end) { + (Some(start), None) => format!("Read {path} (from line {start})"), + (Some(start), Some(end)) => { + format!("Read {path} (lines {start}-{end})") + } + _ => format!("Read {path}"), + } + } + Err(_) => "Read file or directory".to_string(), + } + } + + fn run( + self: Arc, + input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], + project: Entity, + action_log: Entity, + cx: &mut App, + ) -> Task> { + let input = match serde_json::from_value::(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))), + }; + + // Sometimes models will return these even though we tell it to give a path and not a glob. + // When this happens, just list the root worktree directories. + if matches!(input.path.as_str(), "." 
| "" | "./" | "*") { + let output = project + .read(cx) + .worktrees(cx) + .filter_map(|worktree| { + worktree.read(cx).root_entry().and_then(|entry| { + if entry.is_dir() { + entry.path.to_str() + } else { + None + } + }) + }) + .collect::>() + .join("\n"); + + return Task::ready(Ok(output)); + } + + let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { + return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); + }; + + let Some(worktree) = project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + else { + return Task::ready(Err(anyhow!("Worktree not found"))); + }; + let worktree = worktree.read(cx); + + let Some(entry) = worktree.entry_for_path(&project_path.path) else { + return Task::ready(Err(anyhow!("Path not found: {}", input.path))); + }; + + // If it's a directory, list its contents + if entry.is_dir() { + let mut output = String::new(); + let start_index = input + .start + .map(|line| (line as usize).saturating_sub(1)) + .unwrap_or(0); + let end_index = input + .end + .map(|line| (line as usize).saturating_sub(1)) + .unwrap_or(MAX_DIR_ENTRIES); + let mut skipped = 0; + + for (index, entry) in worktree.child_entries(&project_path.path).enumerate() { + if index >= start_index && index <= end_index { + writeln!( + output, + "{}", + Path::new(worktree.root_name()).join(&entry.path).display(), + ) + .unwrap(); + } else { + skipped += 1; + } + } + + if output.is_empty() { + output.push_str(&input.path); + output.push_str(" is empty."); + } + + if skipped > 0 { + write!( + output, + "\n\nNote: Skipped {skipped} entries. Adjust start and end to see other entries.", + ).ok(); + } + + Task::ready(Ok(output)) + } else { + // It's a file, so read its contents + let file_path = input.path.clone(); + cx.spawn(async move |cx| { + let buffer = cx + .update(|cx| { + project.update(cx, |project, cx| project.open_buffer(project_path, cx)) + })? 
+ .await?; + + if input.start.is_some() || input.end.is_some() { + let result = buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + let start = input.start.unwrap_or(1); + let lines = text.split('\n').skip(start as usize - 1); + if let Some(end) = input.end { + let count = end.saturating_sub(start).max(1); // Ensure at least 1 line + Itertools::intersperse(lines.take(count as usize), "\n").collect() + } else { + Itertools::intersperse(lines, "\n").collect() + } + })?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + Ok(result) + } else { + // No line ranges specified, so check file size to see if it's too big. + let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; + + if file_size <= MAX_FILE_SIZE_TO_READ { + let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?; + + action_log.update(cx, |log, cx| { + log.buffer_read(buffer, cx); + })?; + + Ok(result) + } else { + // File is too big, so return its outline and a suggestion to + // read again with a line number range specified. + let outline = file_outline(project, file_path, action_log, None, 0, cx).await?; + + Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline.")) + } + } + }) + } + } +} diff --git a/crates/assistant_tools/src/contents_tool/description.md b/crates/assistant_tools/src/contents_tool/description.md new file mode 100644 index 0000000000..b532f7c534 --- /dev/null +++ b/crates/assistant_tools/src/contents_tool/description.md @@ -0,0 +1,9 @@ +Reads the contents of a path on the filesystem. + +If the path is a directory, this lists all files and directories within that path. +If the path is a file, this returns the file's contents. 
+ +When reading a file, if the file is too big and no line range is specified, an outline of the file's code symbols is listed instead, which can be used to request specific line ranges in a subsequent call. + +Similarly, if a directory has too many entries to show at once, a subset of entries will be shown, +and subsequent requests can use starting and ending line numbers to get other subsets. diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 3d7e407605..d2cc3f006b 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -55,7 +55,7 @@ impl Tool for CopyPathTool { IconName::Clipboard } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index bf8cd1027e..aa5538d6e9 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -45,7 +45,7 @@ impl Tool for CreateDirectoryTool { IconName::Folder } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/create_file_tool.rs b/crates/assistant_tools/src/create_file_tool.rs index 2fb3b0c895..24ef5e186e 100644 --- a/crates/assistant_tools/src/create_file_tool.rs +++ b/crates/assistant_tools/src/create_file_tool.rs @@ -52,7 +52,7 @@ impl Tool for CreateFileTool { IconName::FileCreate } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git 
a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index 5f96db189f..c63bcc9507 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -45,7 +45,7 @@ impl Tool for DeletePathTool { IconName::FileDelete } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 234fa17e16..956bee397a 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -58,7 +58,7 @@ impl Tool for DiagnosticsTool { IconName::XCircle } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 306eb5041d..d2d596e17e 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -128,7 +128,7 @@ impl Tool for FetchTool { IconName::Globe } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/find_replace_file_tool.rs b/crates/assistant_tools/src/find_replace_file_tool.rs index 11d1109fed..b58ea142e5 100644 --- a/crates/assistant_tools/src/find_replace_file_tool.rs +++ b/crates/assistant_tools/src/find_replace_file_tool.rs @@ -151,7 +151,7 @@ impl Tool for FindReplaceFileTool { IconName::Pencil } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) 
-> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index f2295fc3ba..f241fc47da 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -56,7 +56,7 @@ impl Tool for ListDirectoryTool { IconName::Folder } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 8e3a70bc3d..f457be6590 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -54,7 +54,7 @@ impl Tool for MovePathTool { IconName::ArrowRightLeft } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index eebe040f3e..af1f174a30 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -45,7 +45,7 @@ impl Tool for NowTool { IconName::Info } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 2429039087..cac095ef74 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -35,7 +35,7 @@ impl Tool for OpenTool { IconName::ArrowUpRight } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { 
json_schema_for::(format) } diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index 52aa4ef510..672adbfa90 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -53,7 +53,7 @@ impl Tool for PathSearchTool { IconName::SearchCode } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index dd92c02486..8b941e7ccf 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -63,7 +63,7 @@ impl Tool for ReadFileTool { IconName::FileSearch } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/regex_search_tool.rs b/crates/assistant_tools/src/regex_search_tool.rs index a9ef728f47..7005cbdf4a 100644 --- a/crates/assistant_tools/src/regex_search_tool.rs +++ b/crates/assistant_tools/src/regex_search_tool.rs @@ -60,7 +60,7 @@ impl Tool for RegexSearchTool { IconName::Regex } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/rename_tool.rs b/crates/assistant_tools/src/rename_tool.rs index e53e49d68c..7ad1a90a9f 100644 --- a/crates/assistant_tools/src/rename_tool.rs +++ b/crates/assistant_tools/src/rename_tool.rs @@ -2,13 +2,15 @@ use anyhow::{Context as _, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use gpui::{App, Entity, Task}; use language::{self, Buffer, ToPointUtf16}; -use 
language_model::LanguageModelRequestMessage; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::sync::Arc; use ui::IconName; +use crate::schema::json_schema_for; + #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct RenameToolInput { /// The relative path to the file containing the symbol to rename. @@ -66,12 +68,8 @@ impl Tool for RenameTool { IconName::Pencil } - fn input_schema( - &self, - _format: language_model::LanguageModelToolSchemaFormat, - ) -> serde_json::Value { - let schema = schemars::schema_for!(RenameToolInput); - serde_json::to_value(&schema).unwrap() + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) } fn ui_text(&self, input: &serde_json::Value) -> String { diff --git a/crates/assistant_tools/src/schema.rs b/crates/assistant_tools/src/schema.rs index 10ae594ecd..4a71d47d2c 100644 --- a/crates/assistant_tools/src/schema.rs +++ b/crates/assistant_tools/src/schema.rs @@ -5,23 +5,20 @@ use schemars::{ schema::{RootSchema, Schema, SchemaObject}, }; -pub fn json_schema_for(format: LanguageModelToolSchemaFormat) -> serde_json::Value { +pub fn json_schema_for( + format: LanguageModelToolSchemaFormat, +) -> Result { let schema = root_schema_for::(format); - schema_to_json(&schema, format).expect("Failed to convert tool calling schema to JSON") + schema_to_json(&schema, format) } -pub fn schema_to_json( +fn schema_to_json( schema: &RootSchema, format: LanguageModelToolSchemaFormat, ) -> Result { let mut value = serde_json::to_value(schema)?; - match format { - LanguageModelToolSchemaFormat::JsonSchema => Ok(value), - LanguageModelToolSchemaFormat::JsonSchemaSubset => { - transform_fields_to_json_schema_subset(&mut value); - Ok(value) - } - } + assistant_tool::adapt_schema_to_format(&mut value, format)?; + Ok(value) } fn root_schema_for(format: 
LanguageModelToolSchemaFormat) -> RootSchema { @@ -79,42 +76,3 @@ impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor { schemars::visit::visit_schema_object(self, schema) } } - -fn transform_fields_to_json_schema_subset(json: &mut serde_json::Value) { - if let serde_json::Value::Object(obj) = json { - if let Some(default) = obj.get("default") { - let is_null = default.is_null(); - //Default is not supported, so we need to remove it. - obj.remove("default"); - if is_null { - obj.insert("nullable".to_string(), serde_json::Value::Bool(true)); - } - } - - // If a type is not specified for an input parameter we need to add it. - if obj.contains_key("description") - && !obj.contains_key("type") - && !(obj.contains_key("anyOf") - || obj.contains_key("oneOf") - || obj.contains_key("allOf")) - { - obj.insert( - "type".to_string(), - serde_json::Value::String("string".to_string()), - ); - } - - //Format field is only partially supported (e.g. not uint compatibility) - obj.remove("format"); - - for (_, value) in obj.iter_mut() { - if let serde_json::Value::Object(_) | serde_json::Value::Array(_) = value { - transform_fields_to_json_schema_subset(value); - } - } - } else if let serde_json::Value::Array(arr) = json { - for item in arr.iter_mut() { - transform_fields_to_json_schema_subset(item); - } - } -} diff --git a/crates/assistant_tools/src/symbol_info_tool.rs b/crates/assistant_tools/src/symbol_info_tool.rs index 68d9b6afc9..d92c328467 100644 --- a/crates/assistant_tools/src/symbol_info_tool.rs +++ b/crates/assistant_tools/src/symbol_info_tool.rs @@ -84,7 +84,7 @@ impl Tool for SymbolInfoTool { IconName::Code } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 6295c6f1b0..4f7343492e 100644 --- 
a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -44,7 +44,7 @@ impl Tool for TerminalTool { IconName::Terminal } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 63486854b4..179f92791a 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -36,7 +36,7 @@ impl Tool for ThinkingTool { IconName::LightBulb } - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { json_schema_for::(format) } diff --git a/crates/auto_update/Cargo.toml b/crates/auto_update/Cargo.toml index 84b4e5d739..1a772710c9 100644 --- a/crates/auto_update/Cargo.toml +++ b/crates/auto_update/Cargo.toml @@ -27,6 +27,8 @@ serde_json.workspace = true settings.workspace = true smol.workspace = true tempfile.workspace = true -which.workspace = true workspace.workspace = true workspace-hack.workspace = true + +[target.'cfg(not(target_os = "windows"))'.dependencies] +which.workspace = true diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index 77d2037288..390400c048 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -23,7 +23,6 @@ use std::{ sync::Arc, time::Duration, }; -use which::which; use workspace::Workspace; const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification"; @@ -63,7 +62,7 @@ pub struct AutoUpdater { pending_poll: Option>>, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct JsonRelease { pub version: String, pub url: String, @@ -237,6 +236,46 @@ pub fn view_release_notes(_: &ViewReleaseNotes, 
cx: &mut App) -> Option<()> { None } +#[cfg(not(target_os = "windows"))] +struct InstallerDir(tempfile::TempDir); + +#[cfg(not(target_os = "windows"))] +impl InstallerDir { + async fn new() -> Result { + Ok(Self( + tempfile::Builder::new() + .prefix("zed-auto-update") + .tempdir()?, + )) + } + + fn path(&self) -> &Path { + self.0.path() + } +} + +#[cfg(target_os = "windows")] +struct InstallerDir(PathBuf); + +#[cfg(target_os = "windows")] +impl InstallerDir { + async fn new() -> Result { + let installer_dir = std::env::current_exe()? + .parent() + .context("No parent dir for Zed.exe")? + .join("updates"); + if smol::fs::metadata(&installer_dir).await.is_ok() { + smol::fs::remove_dir_all(&installer_dir).await?; + } + smol::fs::create_dir(&installer_dir).await?; + Ok(Self(installer_dir)) + } + + fn path(&self) -> &Path { + self.0.as_path() + } +} + impl AutoUpdater { pub fn get(cx: &mut App) -> Option> { cx.default_global::().0.clone() @@ -469,22 +508,21 @@ impl AutoUpdater { cx.notify(); })?; - let temp_dir = tempfile::Builder::new() - .prefix("zed-auto-update") - .tempdir()?; - + let installer_dir = InstallerDir::new().await?; let filename = match OS { "macos" => Ok("Zed.dmg"), "linux" => Ok("zed.tar.gz"), + "windows" => Ok("ZedUpdateInstaller.exe"), _ => Err(anyhow!("not supported: {:?}", OS)), }?; + #[cfg(not(target_os = "windows"))] anyhow::ensure!( - which("rsync").is_ok(), + which::which("rsync").is_ok(), "Aborting. Could not find rsync which is required for auto-updates." 
); - let downloaded_asset = temp_dir.path().join(filename); + let downloaded_asset = installer_dir.path().join(filename); download_release(&downloaded_asset, release, client, &cx).await?; this.update(&mut cx, |this, cx| { @@ -493,8 +531,9 @@ impl AutoUpdater { })?; let binary_path = match OS { - "macos" => install_release_macos(&temp_dir, downloaded_asset, &cx).await, - "linux" => install_release_linux(&temp_dir, downloaded_asset, &cx).await, + "macos" => install_release_macos(&installer_dir, downloaded_asset, &cx).await, + "linux" => install_release_linux(&installer_dir, downloaded_asset, &cx).await, + "windows" => install_release_windows(downloaded_asset).await, _ => Err(anyhow!("not supported: {:?}", OS)), }?; @@ -629,7 +668,7 @@ async fn download_release( } async fn install_release_linux( - temp_dir: &tempfile::TempDir, + temp_dir: &InstallerDir, downloaded_tar_gz: PathBuf, cx: &AsyncApp, ) -> Result { @@ -696,7 +735,7 @@ async fn install_release_linux( } async fn install_release_macos( - temp_dir: &tempfile::TempDir, + temp_dir: &InstallerDir, downloaded_dmg: PathBuf, cx: &AsyncApp, ) -> Result { @@ -743,3 +782,41 @@ async fn install_release_macos( Ok(running_app_path) } + +async fn install_release_windows(downloaded_installer: PathBuf) -> Result { + let output = Command::new(downloaded_installer) + .arg("/verysilent") + .arg("/update=true") + .arg("!desktopicon") + .arg("!quicklaunchicon") + .output() + .await?; + anyhow::ensure!( + output.status.success(), + "failed to start installer: {:?}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(std::env::current_exe()?) 
+} + +pub fn check_pending_installation() -> bool { + let Some(installer_path) = std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|p| p.join("updates"))) + else { + return false; + }; + + // The installer will create a flag file after it finishes updating + let flag_file = installer_path.join("versions.txt"); + if flag_file.exists() { + if let Some(helper) = installer_path + .parent() + .map(|p| p.join("tools\\auto_update_helper.exe")) + { + let _ = std::process::Command::new(helper).spawn(); + return true; + } + } + false +} diff --git a/crates/auto_update_helper/Cargo.toml b/crates/auto_update_helper/Cargo.toml new file mode 100644 index 0000000000..6581de48d2 --- /dev/null +++ b/crates/auto_update_helper/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "auto_update_helper" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "auto_update_helper" +path = "src/auto_update_helper.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +log.workspace = true +simplelog.workspace = true +workspace-hack.workspace = true + +[target.'cfg(target_os = "windows")'.dependencies] +windows.workspace = true + +[target.'cfg(target_os = "windows")'.build-dependencies] +winresource = "0.1" + +[package.metadata.docs.rs] +targets = ["x86_64-pc-windows-msvc"] diff --git a/crates/auto_update_helper/LICENSE-GPL b/crates/auto_update_helper/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/auto_update_helper/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/auto_update_helper/app-icon.ico b/crates/auto_update_helper/app-icon.ico new file mode 100644 index 0000000000..321e90fcfa Binary files /dev/null and b/crates/auto_update_helper/app-icon.ico differ diff --git a/crates/auto_update_helper/build.rs b/crates/auto_update_helper/build.rs new file mode 100644 index 0000000000..2910632c7f --- 
/dev/null +++ b/crates/auto_update_helper/build.rs @@ -0,0 +1,15 @@ +fn main() { + #[cfg(target_os = "windows")] + { + println!("cargo:rerun-if-changed=manifest.xml"); + + let mut res = winresource::WindowsResource::new(); + res.set_manifest_file("manifest.xml"); + res.set_icon("app-icon.ico"); + + if let Err(e) = res.compile() { + eprintln!("{}", e); + std::process::exit(1); + } + } +} diff --git a/crates/auto_update_helper/manifest.xml b/crates/auto_update_helper/manifest.xml new file mode 100644 index 0000000000..5a69b43486 --- /dev/null +++ b/crates/auto_update_helper/manifest.xml @@ -0,0 +1,16 @@ + + + + true + PerMonitorV2 + + + + + + + + diff --git a/crates/auto_update_helper/src/auto_update_helper.rs b/crates/auto_update_helper/src/auto_update_helper.rs new file mode 100644 index 0000000000..b8e4ba26d1 --- /dev/null +++ b/crates/auto_update_helper/src/auto_update_helper.rs @@ -0,0 +1,94 @@ +#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] + +#[cfg(target_os = "windows")] +mod dialog; +#[cfg(target_os = "windows")] +mod updater; + +#[cfg(target_os = "windows")] +fn main() { + if let Err(e) = windows_impl::run() { + log::error!("Error: Zed update failed, {:?}", e); + windows_impl::show_error(format!("Error: {:?}", e)); + } +} + +#[cfg(not(target_os = "windows"))] +fn main() {} + +#[cfg(target_os = "windows")] +mod windows_impl { + use std::path::Path; + + use super::dialog::create_dialog_window; + use super::updater::perform_update; + use anyhow::{Context, Result}; + use windows::{ + Win32::{ + Foundation::{HWND, LPARAM, WPARAM}, + UI::WindowsAndMessaging::{ + DispatchMessageW, GetMessageW, MB_ICONERROR, MB_SYSTEMMODAL, MSG, MessageBoxW, + PostMessageW, WM_USER, + }, + }, + core::HSTRING, + }; + + pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1; + pub(crate) const WM_TERMINATE: u32 = WM_USER + 2; + + pub(crate) fn run() -> Result<()> { + let helper_dir = std::env::current_exe()? + .parent() + .context("No parent directory")? 
+ .to_path_buf(); + init_log(&helper_dir)?; + let app_dir = helper_dir + .parent() + .context("No parent directory")? + .to_path_buf(); + + log::info!("======= Starting Zed update ======="); + let (tx, rx) = std::sync::mpsc::channel(); + let hwnd = create_dialog_window(rx)?.0 as isize; + std::thread::spawn(move || { + let result = perform_update(app_dir.as_path(), Some(hwnd)); + tx.send(result).ok(); + unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok(); + }); + unsafe { + let mut message = MSG::default(); + while GetMessageW(&mut message, None, 0, 0).as_bool() { + DispatchMessageW(&message); + } + } + Ok(()) + } + + fn init_log(helper_dir: &Path) -> Result<()> { + simplelog::WriteLogger::init( + simplelog::LevelFilter::Info, + simplelog::Config::default(), + std::fs::File::options() + .append(true) + .create(true) + .open(helper_dir.join("auto_update_helper.log"))?, + )?; + Ok(()) + } + + pub(crate) fn show_error(mut content: String) { + if content.len() > 600 { + content.truncate(600); + content.push_str("...\n"); + } + let _ = unsafe { + MessageBoxW( + None, + &HSTRING::from(content), + windows::core::w!("Error: Zed update failed."), + MB_ICONERROR | MB_SYSTEMMODAL, + ) + }; + } +} diff --git a/crates/auto_update_helper/src/dialog.rs b/crates/auto_update_helper/src/dialog.rs new file mode 100644 index 0000000000..010ebb4875 --- /dev/null +++ b/crates/auto_update_helper/src/dialog.rs @@ -0,0 +1,236 @@ +use std::{cell::RefCell, sync::mpsc::Receiver}; + +use anyhow::{Context as _, Result}; +use windows::{ + Win32::{ + Foundation::{HWND, LPARAM, LRESULT, RECT, WPARAM}, + Graphics::Gdi::{ + BeginPaint, CLEARTYPE_QUALITY, CLIP_DEFAULT_PRECIS, CreateFontW, DEFAULT_CHARSET, + DeleteObject, EndPaint, FW_NORMAL, LOGFONTW, OUT_TT_ONLY_PRECIS, PAINTSTRUCT, + ReleaseDC, SelectObject, TextOutW, + }, + System::LibraryLoader::GetModuleHandleW, + UI::{ + Controls::{PBM_SETRANGE, PBM_SETSTEP, PBM_STEPIT, PROGRESS_CLASS}, + 
WindowsAndMessaging::{ + CREATESTRUCTW, CS_HREDRAW, CS_VREDRAW, CreateWindowExW, DefWindowProcW, + GWLP_USERDATA, GetDesktopWindow, GetWindowLongPtrW, GetWindowRect, HICON, + IMAGE_ICON, LR_DEFAULTSIZE, LR_SHARED, LoadImageW, PostQuitMessage, RegisterClassW, + SPI_GETICONTITLELOGFONT, SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS, SendMessageW, + SetWindowLongPtrW, SystemParametersInfoW, WINDOW_EX_STYLE, WM_CLOSE, WM_CREATE, + WM_DESTROY, WM_NCCREATE, WM_PAINT, WNDCLASSW, WS_CAPTION, WS_CHILD, WS_EX_TOPMOST, + WS_POPUP, WS_VISIBLE, + }, + }, + }, + core::HSTRING, +}; + +use crate::{ + updater::JOBS, + windows_impl::{WM_JOB_UPDATED, WM_TERMINATE, show_error}, +}; + +#[repr(C)] +#[derive(Debug)] +struct DialogInfo { + rx: Receiver>, + progress_bar: isize, +} + +pub(crate) fn create_dialog_window(receiver: Receiver>) -> Result { + unsafe { + let class_name = windows::core::w!("Zed-Auto-Updater-Dialog-Class"); + let module = GetModuleHandleW(None).context("unable to get module handle")?; + let handle = LoadImageW( + Some(module.into()), + windows::core::PCWSTR(1 as _), + IMAGE_ICON, + 0, + 0, + LR_DEFAULTSIZE | LR_SHARED, + ) + .context("unable to load icon file")?; + let wc = WNDCLASSW { + lpfnWndProc: Some(wnd_proc), + lpszClassName: class_name, + style: CS_HREDRAW | CS_VREDRAW, + hIcon: HICON(handle.0), + ..Default::default() + }; + RegisterClassW(&wc); + let mut rect = RECT::default(); + GetWindowRect(GetDesktopWindow(), &mut rect) + .context("unable to get desktop window rect")?; + let width = 400; + let height = 150; + let info = Box::new(RefCell::new(DialogInfo { + rx: receiver, + progress_bar: 0, + })); + + let hwnd = CreateWindowExW( + WS_EX_TOPMOST, + class_name, + windows::core::w!("Zed Editor"), + WS_VISIBLE | WS_POPUP | WS_CAPTION, + rect.right / 2 - width / 2, + rect.bottom / 2 - height / 2, + width, + height, + None, + None, + None, + Some(Box::into_raw(info) as _), + ) + .context("unable to create dialog window")?; + Ok(hwnd) + } +} + +macro_rules! 
return_if_failed { + ($e:expr) => { + match $e { + Ok(v) => v, + Err(e) => { + return LRESULT(e.code().0 as _); + } + } + }; +} + +macro_rules! make_lparam { + ($l:expr, $h:expr) => { + LPARAM(($l as u32 | ($h as u32) << 16) as isize) + }; +} + +unsafe extern "system" fn wnd_proc( + hwnd: HWND, + msg: u32, + wparam: WPARAM, + lparam: LPARAM, +) -> LRESULT { + match msg { + WM_NCCREATE => unsafe { + let create_struct = lparam.0 as *const CREATESTRUCTW; + let info = (*create_struct).lpCreateParams as *mut RefCell; + let info = Box::from_raw(info); + SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(info) as _); + DefWindowProcW(hwnd, msg, wparam, lparam) + }, + WM_CREATE => unsafe { + // Create progress bar + let mut rect = RECT::default(); + return_if_failed!(GetWindowRect(hwnd, &mut rect)); + let progress_bar = return_if_failed!(CreateWindowExW( + WINDOW_EX_STYLE(0), + PROGRESS_CLASS, + None, + WS_CHILD | WS_VISIBLE, + 20, + 50, + 340, + 35, + Some(hwnd), + None, + None, + None, + )); + SendMessageW( + progress_bar, + PBM_SETRANGE, + None, + Some(make_lparam!(0, JOBS.len() * 10)), + ); + SendMessageW(progress_bar, PBM_SETSTEP, Some(WPARAM(10)), None); + with_dialog_data(hwnd, |data| { + data.borrow_mut().progress_bar = progress_bar.0 as isize + }); + LRESULT(0) + }, + WM_PAINT => unsafe { + let mut ps = PAINTSTRUCT::default(); + let hdc = BeginPaint(hwnd, &mut ps); + + let font_name = get_system_ui_font_name(); + let font = CreateFontW( + 24, + 0, + 0, + 0, + FW_NORMAL.0 as _, + 0, + 0, + 0, + DEFAULT_CHARSET, + OUT_TT_ONLY_PRECIS, + CLIP_DEFAULT_PRECIS, + CLEARTYPE_QUALITY, + 0, + &HSTRING::from(font_name), + ); + let temp = SelectObject(hdc, font.into()); + let string = HSTRING::from("Zed Editor is updating..."); + return_if_failed!(TextOutW(hdc, 20, 15, &string).ok()); + return_if_failed!(DeleteObject(temp).ok()); + + return_if_failed!(EndPaint(hwnd, &ps).ok()); + ReleaseDC(Some(hwnd), hdc); + + LRESULT(0) + }, + WM_JOB_UPDATED => with_dialog_data(hwnd, 
|data| { + let progress_bar = data.borrow().progress_bar; + unsafe { SendMessageW(HWND(progress_bar as _), PBM_STEPIT, None, None) } + }), + WM_TERMINATE => { + with_dialog_data(hwnd, |data| { + if let Ok(result) = data.borrow_mut().rx.recv() { + if let Err(e) = result { + log::error!("Failed to update Zed: {:?}", e); + show_error(format!("Error: {:?}", e)); + } + } + }); + unsafe { PostQuitMessage(0) }; + LRESULT(0) + } + WM_CLOSE => LRESULT(0), // Prevent user occasionally closing the window + WM_DESTROY => { + unsafe { PostQuitMessage(0) }; + LRESULT(0) + } + _ => unsafe { DefWindowProcW(hwnd, msg, wparam, lparam) }, + } +} + +fn with_dialog_data(hwnd: HWND, f: F) -> T +where + F: FnOnce(&RefCell) -> T, +{ + let raw = unsafe { GetWindowLongPtrW(hwnd, GWLP_USERDATA) as *mut RefCell }; + let data = unsafe { Box::from_raw(raw) }; + let result = f(data.as_ref()); + unsafe { SetWindowLongPtrW(hwnd, GWLP_USERDATA, Box::into_raw(data) as _) }; + result +} + +fn get_system_ui_font_name() -> String { + unsafe { + let mut info: LOGFONTW = std::mem::zeroed(); + if SystemParametersInfoW( + SPI_GETICONTITLELOGFONT, + std::mem::size_of::() as u32, + Some(&mut info as *mut _ as _), + SYSTEM_PARAMETERS_INFO_UPDATE_FLAGS(0), + ) + .is_ok() + { + let font_name = String::from_utf16_lossy(&info.lfFaceName); + font_name.trim_matches(char::from(0)).to_owned() + } else { + "MS Shell Dlg".to_owned() + } + } +} diff --git a/crates/auto_update_helper/src/updater.rs b/crates/auto_update_helper/src/updater.rs new file mode 100644 index 0000000000..1c3fc10655 --- /dev/null +++ b/crates/auto_update_helper/src/updater.rs @@ -0,0 +1,171 @@ +use std::{ + os::windows::process::CommandExt, + path::Path, + time::{Duration, Instant}, +}; + +use anyhow::{Context, Result}; +use windows::Win32::{ + Foundation::{HWND, LPARAM, WPARAM}, + System::Threading::CREATE_NEW_PROCESS_GROUP, + UI::WindowsAndMessaging::PostMessageW, +}; + +use crate::windows_impl::WM_JOB_UPDATED; + +type Job = fn(&Path) -> 
Result<()>; + +#[cfg(not(test))] +pub(crate) const JOBS: [Job; 6] = [ + // Delete old files + |app_dir| { + let zed_executable = app_dir.join("Zed.exe"); + log::info!("Removing old file: {}", zed_executable.display()); + std::fs::remove_file(&zed_executable).context(format!( + "Failed to remove old file {}", + zed_executable.display() + )) + }, + |app_dir| { + let zed_cli = app_dir.join("bin\\zed.exe"); + log::info!("Removing old file: {}", zed_cli.display()); + std::fs::remove_file(&zed_cli) + .context(format!("Failed to remove old file {}", zed_cli.display())) + }, + // Copy new files + |app_dir| { + let zed_executable_source = app_dir.join("install\\Zed.exe"); + let zed_executable_dest = app_dir.join("Zed.exe"); + log::info!( + "Copying new file {} to {}", + zed_executable_source.display(), + zed_executable_dest.display() + ); + std::fs::copy(&zed_executable_source, &zed_executable_dest) + .map(|_| ()) + .context(format!( + "Failed to copy new file {} to {}", + zed_executable_source.display(), + zed_executable_dest.display() + )) + }, + |app_dir| { + let zed_cli_source = app_dir.join("install\\bin\\zed.exe"); + let zed_cli_dest = app_dir.join("bin\\zed.exe"); + log::info!( + "Copying new file {} to {}", + zed_cli_source.display(), + zed_cli_dest.display() + ); + std::fs::copy(&zed_cli_source, &zed_cli_dest) + .map(|_| ()) + .context(format!( + "Failed to copy new file {} to {}", + zed_cli_source.display(), + zed_cli_dest.display() + )) + }, + // Clean up installer folder and updates folder + |app_dir| { + let updates_folder = app_dir.join("updates"); + log::info!("Cleaning up: {}", updates_folder.display()); + std::fs::remove_dir_all(&updates_folder).context(format!( + "Failed to remove updates folder {}", + updates_folder.display() + )) + }, + |app_dir| { + let installer_folder = app_dir.join("install"); + log::info!("Cleaning up: {}", installer_folder.display()); + std::fs::remove_dir_all(&installer_folder).context(format!( + "Failed to remove installer folder 
{}", + installer_folder.display() + )) + }, +]; + +#[cfg(test)] +pub(crate) const JOBS: [Job; 2] = [ + |_| { + std::thread::sleep(Duration::from_millis(1000)); + if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") { + match config.as_str() { + "err" => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Simulated error", + )) + .context("Anyhow!"), + _ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config), + } + } else { + Ok(()) + } + }, + |_| { + std::thread::sleep(Duration::from_millis(1000)); + if let Ok(config) = std::env::var("ZED_AUTO_UPDATE") { + match config.as_str() { + "err" => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Simulated error", + )) + .context("Anyhow!"), + _ => panic!("Unknown ZED_AUTO_UPDATE value: {}", config), + } + } else { + Ok(()) + } + }, +]; + +pub(crate) fn perform_update(app_dir: &Path, hwnd: Option) -> Result<()> { + let hwnd = hwnd.map(|ptr| HWND(ptr as _)); + + for job in JOBS.iter() { + let start = Instant::now(); + loop { + if start.elapsed().as_secs() > 2 { + return Err(anyhow::anyhow!("Timed out")); + } + match (*job)(app_dir) { + Ok(_) => { + unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? }; + break; + } + Err(err) => { + // Check if it's a "not found" error + let io_err = err.downcast_ref::().unwrap(); + if io_err.kind() == std::io::ErrorKind::NotFound { + log::warn!("File or folder not found."); + unsafe { PostMessageW(hwnd, WM_JOB_UPDATED, WPARAM(0), LPARAM(0))? 
}; + break; + } + + log::error!("Operation failed: {}", err); + std::thread::sleep(Duration::from_millis(50)); + } + } + } + } + let _ = std::process::Command::new(app_dir.join("Zed.exe")) + .creation_flags(CREATE_NEW_PROCESS_GROUP.0) + .spawn(); + log::info!("Update completed successfully"); + Ok(()) +} + +#[cfg(test)] +mod test { + use super::perform_update; + + #[test] + fn test_perform_update() { + let app_dir = std::path::Path::new("C:/"); + assert!(perform_update(app_dir, None).is_ok()); + + // Simulate a timeout + unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") }; + let ret = perform_update(app_dir, None); + assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out")); + } +} diff --git a/crates/cli/build.rs b/crates/cli/build.rs index f07d12546a..d41647c696 100644 --- a/crates/cli/build.rs +++ b/crates/cli/build.rs @@ -7,8 +7,6 @@ fn main() { if cfg!(target_os = "macos") { println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7"); - // Weakly link ScreenCaptureKit to ensure can be used on macOS 10.15+. 
- println!("cargo:rustc-link-arg=-Wl,-weak_framework,ScreenCaptureKit"); } // Populate git sha environment variable if git is available diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index f60131e0de..c4aa90e2c2 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -18,7 +18,6 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"] test-support = ["sqlite"] [dependencies] -anthropic.workspace = true anyhow.workspace = true async-stripe.workspace = true async-tungstenite.workspace = true diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 334d015d4b..2e682d2878 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -253,7 +253,6 @@ impl Config { pub enum ServiceMode { Api, Collab, - Llm, All, } @@ -265,10 +264,6 @@ impl ServiceMode { pub fn is_api(&self) -> bool { matches!(self, Self::Api | Self::All) } - - pub fn is_llm(&self) -> bool { - matches!(self, Self::Llm | Self::All) - } } pub struct AppState { diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 8c6fd772df..13d503e7d4 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,448 +1,10 @@ -mod authorization; pub mod db; mod token; -use crate::api::CloudflareIpCountryHeader; -use crate::api::events::SnowflakeRow; -use crate::build_kinesis_client; -use crate::rpc::MIN_ACCOUNT_AGE_FOR_LLM_USE; -use crate::{Cents, Config, Error, Result, db::UserId, executor::Executor}; -use anyhow::{Context as _, anyhow}; -use authorization::authorize_access_to_language_model; -use axum::routing::get; -use axum::{ - Extension, Json, Router, TypedHeader, - body::Body, - http::{self, HeaderName, HeaderValue, Request, StatusCode}, - middleware::{self, Next}, - response::{IntoResponse, Response}, - routing::post, -}; -use chrono::{DateTime, Duration, Utc}; -use collections::HashMap; -use db::TokenUsage; -use db::{ActiveUserCount, LlmDatabase, usage_measure::UsageMeasure}; -use futures::{Stream, StreamExt as _}; -use 
reqwest_client::ReqwestClient; -use rpc::{ - EXPIRED_LLM_TOKEN_HEADER_NAME, LanguageModelProvider, PerformCompletionParams, proto::Plan, -}; -use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; -use serde_json::json; -use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use strum::IntoEnumIterator; -use tokio::sync::RwLock; -use util::ResultExt; +use crate::Cents; pub use token::*; -const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); - -pub struct LlmState { - pub config: Config, - pub executor: Executor, - pub db: Arc, - pub http_client: ReqwestClient, - pub kinesis_client: Option, - active_user_count_by_model: - RwLock, ActiveUserCount)>>, -} - -impl LlmState { - pub async fn new(config: Config, executor: Executor) -> Result> { - let database_url = config - .llm_database_url - .as_ref() - .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; - let max_connections = config - .llm_database_max_connections - .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; - - let mut db_options = db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - let mut db = LlmDatabase::new(db_options, executor.clone()).await?; - db.initialize().await?; - - let db = Arc::new(db); - - let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); - let http_client = - ReqwestClient::user_agent(&user_agent).context("failed to construct http client")?; - - let this = Self { - executor, - db, - http_client, - kinesis_client: if config.kinesis_access_key.is_some() { - build_kinesis_client(&config).await.log_err() - } else { - None - }, - active_user_count_by_model: RwLock::new(HashMap::default()), - config, - }; - - Ok(Arc::new(this)) - } - - pub async fn get_active_user_count( - &self, - provider: LanguageModelProvider, - model: &str, - ) -> Result { - let now = Utc::now(); - - { - let active_user_count_by_model = self.active_user_count_by_model.read().await; - if let 
Some((last_updated, count)) = - active_user_count_by_model.get(&(provider, model.to_string())) - { - if now - *last_updated < ACTIVE_USER_COUNT_CACHE_DURATION { - return Ok(*count); - } - } - } - - let mut cache = self.active_user_count_by_model.write().await; - let new_count = self.db.get_active_user_count(provider, model, now).await?; - cache.insert((provider, model.to_string()), (now, new_count)); - Ok(new_count) - } -} - -pub fn routes() -> Router<(), Body> { - Router::new() - .route("/models", get(list_models)) - .route("/completion", post(perform_completion)) - .layer(middleware::from_fn(validate_api_token)) -} - -async fn validate_api_token(mut req: Request, next: Next) -> impl IntoResponse { - let token = req - .headers() - .get(http::header::AUTHORIZATION) - .and_then(|header| header.to_str().ok()) - .ok_or_else(|| { - Error::http( - StatusCode::BAD_REQUEST, - "missing authorization header".to_string(), - ) - })? - .strip_prefix("Bearer ") - .ok_or_else(|| { - Error::http( - StatusCode::BAD_REQUEST, - "invalid authorization header".to_string(), - ) - })?; - - let state = req.extensions().get::>().unwrap(); - match LlmTokenClaims::validate(token, &state.config) { - Ok(claims) => { - if state.db.is_access_token_revoked(&claims.jti).await? 
{ - return Err(Error::http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - )); - } - - tracing::Span::current() - .record("user_id", claims.user_id) - .record("login", claims.github_user_login.clone()) - .record("authn.jti", &claims.jti) - .record("is_staff", claims.is_staff); - - req.extensions_mut().insert(claims); - Ok::<_, Error>(next.run(req).await.into_response()) - } - Err(ValidateLlmTokenError::Expired) => Err(Error::Http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - [( - HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME), - HeaderValue::from_static("true"), - )] - .into_iter() - .collect(), - )), - Err(_err) => Err(Error::http( - StatusCode::UNAUTHORIZED, - "unauthorized".to_string(), - )), - } -} - -async fn list_models( - Extension(state): Extension>, - Extension(claims): Extension, - country_code_header: Option>, -) -> Result> { - let country_code = country_code_header.map(|header| header.to_string()); - - let mut accessible_models = Vec::new(); - - for (provider, model) in state.db.all_models() { - let authorize_result = authorize_access_to_language_model( - &state.config, - &claims, - country_code.as_deref(), - provider, - &model.name, - ); - - if authorize_result.is_ok() { - accessible_models.push(rpc::LanguageModel { - provider, - name: model.name, - }); - } - } - - Ok(Json(ListModelsResponse { - models: accessible_models, - })) -} - -async fn perform_completion( - Extension(state): Extension>, - Extension(claims): Extension, - country_code_header: Option>, - Json(params): Json, -) -> Result { - let model = normalize_model_name( - state.db.model_names_for_provider(params.provider), - params.model, - ); - - let bypass_account_age_check = claims.has_llm_subscription || claims.bypass_account_age_check; - if !bypass_account_age_check { - if Utc::now().naive_utc() - claims.account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE { - Err(anyhow!("account too young"))? 
- } - } - - authorize_access_to_language_model( - &state.config, - &claims, - country_code_header - .map(|header| header.to_string()) - .as_deref(), - params.provider, - &model, - )?; - - check_usage_limit(&state, params.provider, &model, &claims).await?; - - let stream = match params.provider { - LanguageModelProvider::Anthropic => { - let api_key = if claims.is_staff { - state - .config - .anthropic_staff_api_key - .as_ref() - .context("no Anthropic AI staff API key configured on the server")? - } else { - state - .config - .anthropic_api_key - .as_ref() - .context("no Anthropic AI API key configured on the server")? - }; - - let mut request: anthropic::Request = - serde_json::from_str(params.provider_request.get())?; - - // Override the model on the request with the latest version of the model that is - // known to the server. - // - // Right now, we use the version that's defined in `model.id()`, but we will likely - // want to change this code once a new version of an Anthropic model is released, - // so that users can use the new version, without having to update Zed. 
- request.model = match model.as_str() { - "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(), - "claude-3-7-sonnet" => anthropic::Model::Claude3_7Sonnet.id().to_string(), - "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(), - "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(), - "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(), - _ => request.model, - }; - - let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info( - &state.http_client, - anthropic::ANTHROPIC_API_URL, - api_key, - request, - ) - .await - .map_err(|err| match err { - anthropic::AnthropicError::ApiError(ref api_error) => match api_error.code() { - Some(anthropic::ApiErrorCode::RateLimitError) => { - tracing::info!( - target: "upstream rate limit exceeded", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = params.provider.to_string(), - model = model - ); - - Error::http( - StatusCode::TOO_MANY_REQUESTS, - "Upstream Anthropic rate limit exceeded.".to_string(), - ) - } - Some(anthropic::ApiErrorCode::InvalidRequestError) => { - Error::http(StatusCode::BAD_REQUEST, api_error.message.clone()) - } - Some(anthropic::ApiErrorCode::OverloadedError) => { - Error::http(StatusCode::SERVICE_UNAVAILABLE, api_error.message.clone()) - } - Some(_) => { - Error::http(StatusCode::INTERNAL_SERVER_ERROR, api_error.message.clone()) - } - None => Error::Internal(anyhow!(err)), - }, - anthropic::AnthropicError::Other(err) => Error::Internal(err), - })?; - - if let Some(rate_limit_info) = rate_limit_info { - tracing::info!( - target: "upstream rate limit", - is_staff = claims.is_staff, - provider = params.provider.to_string(), - model = model, - tokens_remaining = rate_limit_info.tokens.as_ref().map(|limits| limits.remaining), - input_tokens_remaining = rate_limit_info.input_tokens.as_ref().map(|limits| limits.remaining), - 
output_tokens_remaining = rate_limit_info.output_tokens.as_ref().map(|limits| limits.remaining), - requests_remaining = rate_limit_info.requests.as_ref().map(|limits| limits.remaining), - requests_reset = ?rate_limit_info.requests.as_ref().map(|limits| limits.reset), - tokens_reset = ?rate_limit_info.tokens.as_ref().map(|limits| limits.reset), - input_tokens_reset = ?rate_limit_info.input_tokens.as_ref().map(|limits| limits.reset), - output_tokens_reset = ?rate_limit_info.output_tokens.as_ref().map(|limits| limits.reset), - ); - } - - chunks - .map(move |event| { - let chunk = event?; - let ( - input_tokens, - output_tokens, - cache_creation_input_tokens, - cache_read_input_tokens, - ) = match &chunk { - anthropic::Event::MessageStart { - message: anthropic::Response { usage, .. }, - } - | anthropic::Event::MessageDelta { usage, .. } => ( - usage.input_tokens.unwrap_or(0) as usize, - usage.output_tokens.unwrap_or(0) as usize, - usage.cache_creation_input_tokens.unwrap_or(0) as usize, - usage.cache_read_input_tokens.unwrap_or(0) as usize, - ), - _ => (0, 0, 0, 0), - }; - - anyhow::Ok(CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens, - output_tokens, - cache_creation_input_tokens, - cache_read_input_tokens, - }) - }) - .boxed() - } - LanguageModelProvider::OpenAi => { - let api_key = state - .config - .openai_api_key - .as_ref() - .context("no OpenAI API key configured on the server")?; - let chunks = open_ai::stream_completion( - &state.http_client, - open_ai::OPEN_AI_API_URL, - api_key, - serde_json::from_str(params.provider_request.get())?, - ) - .await?; - - chunks - .map(|event| { - event.map(|chunk| { - let input_tokens = - chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize; - let output_tokens = - chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize; - CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens, - output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 
0, - } - }) - }) - .boxed() - } - LanguageModelProvider::Google => { - let api_key = state - .config - .google_ai_api_key - .as_ref() - .context("no Google AI API key configured on the server")?; - let chunks = google_ai::stream_generate_content( - &state.http_client, - google_ai::API_URL, - api_key, - serde_json::from_str(params.provider_request.get())?, - ) - .await?; - - chunks - .map(|event| { - event.map(|chunk| { - // TODO - implement token counting for Google AI - CompletionChunk { - bytes: serde_json::to_vec(&chunk).unwrap(), - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - }) - }) - .boxed() - } - }; - - Ok(Response::new(Body::wrap_stream(TokenCountingStream { - state, - claims, - provider: params.provider, - model, - tokens: TokenUsage::default(), - inner_stream: stream, - }))) -} - -fn normalize_model_name(known_models: Vec, name: String) -> String { - if let Some(known_model_name) = known_models - .iter() - .filter(|known_model_name| name.starts_with(known_model_name.as_str())) - .max_by_key(|known_model_name| known_model_name.len()) - { - known_model_name.to_string() - } else { - name - } -} - /// The maximum monthly spending an individual user can reach on the free tier /// before they have to pay. pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10); @@ -452,330 +14,3 @@ pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10); /// /// Used to prevent surprise bills. 
pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10); - -async fn check_usage_limit( - state: &Arc, - provider: LanguageModelProvider, - model_name: &str, - claims: &LlmTokenClaims, -) -> Result<()> { - if claims.is_staff { - return Ok(()); - } - - let user_id = UserId::from_proto(claims.user_id); - let model = state.db.model(provider, model_name)?; - let free_tier = claims.free_tier_monthly_spending_limit(); - - let spending_this_month = state - .db - .get_user_spending_for_month(user_id, Utc::now()) - .await?; - if spending_this_month >= free_tier { - if !claims.has_llm_subscription { - return Err(Error::http( - StatusCode::PAYMENT_REQUIRED, - "Maximum spending limit reached for this month.".to_string(), - )); - } - - let monthly_spend = spending_this_month.saturating_sub(free_tier); - if monthly_spend >= Cents(claims.max_monthly_spend_in_cents) { - return Err(Error::Http( - StatusCode::FORBIDDEN, - "Maximum spending limit reached for this month.".to_string(), - [( - HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME), - HeaderValue::from_static("true"), - )] - .into_iter() - .collect(), - )); - } - } - - let active_users = state.get_active_user_count(provider, model_name).await?; - - let users_in_recent_minutes = active_users.users_in_recent_minutes.max(1); - let users_in_recent_days = active_users.users_in_recent_days.max(1); - - let per_user_max_requests_per_minute = - model.max_requests_per_minute as usize / users_in_recent_minutes; - let per_user_max_tokens_per_minute = - model.max_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_input_tokens_per_minute = - model.max_input_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_output_tokens_per_minute = - model.max_output_tokens_per_minute as usize / users_in_recent_minutes; - let per_user_max_tokens_per_day = model.max_tokens_per_day as usize / users_in_recent_days; - - let usage = state - .db - .get_usage(user_id, provider, 
model_name, Utc::now()) - .await?; - - let checks = match (provider, model_name) { - (LanguageModelProvider::Anthropic, "claude-3-7-sonnet") => vec![ - ( - usage.requests_this_minute, - per_user_max_requests_per_minute, - UsageMeasure::RequestsPerMinute, - ), - ( - usage.input_tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::InputTokensPerMinute, - ), - ( - usage.output_tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::OutputTokensPerMinute, - ), - ( - usage.tokens_this_day, - per_user_max_tokens_per_day, - UsageMeasure::TokensPerDay, - ), - ], - _ => vec![ - ( - usage.requests_this_minute, - per_user_max_requests_per_minute, - UsageMeasure::RequestsPerMinute, - ), - ( - usage.tokens_this_minute, - per_user_max_tokens_per_minute, - UsageMeasure::TokensPerMinute, - ), - ( - usage.tokens_this_day, - per_user_max_tokens_per_day, - UsageMeasure::TokensPerDay, - ), - ], - }; - - for (used, limit, usage_measure) in checks { - if used > limit { - let resource = match usage_measure { - UsageMeasure::RequestsPerMinute => "requests_per_minute", - UsageMeasure::TokensPerMinute => "tokens_per_minute", - UsageMeasure::InputTokensPerMinute => "input_tokens_per_minute", - UsageMeasure::OutputTokensPerMinute => "output_tokens_per_minute", - UsageMeasure::TokensPerDay => "tokens_per_day", - }; - - tracing::info!( - target: "user rate limit", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = provider.to_string(), - model = model.name, - usage_measure = resource, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - tokens_this_day = usage.tokens_this_day, - users_in_recent_minutes = users_in_recent_minutes, - users_in_recent_days = users_in_recent_days, - max_requests_per_minute = 
per_user_max_requests_per_minute, - max_tokens_per_minute = per_user_max_tokens_per_minute, - max_input_tokens_per_minute = per_user_max_input_tokens_per_minute, - max_output_tokens_per_minute = per_user_max_output_tokens_per_minute, - max_tokens_per_day = per_user_max_tokens_per_day, - ); - - SnowflakeRow::new( - "Language Model Rate Limited", - Some(claims.metrics_id), - claims.is_staff, - claims.system_id.clone(), - json!({ - "usage": usage, - "users_in_recent_minutes": users_in_recent_minutes, - "users_in_recent_days": users_in_recent_days, - "max_requests_per_minute": per_user_max_requests_per_minute, - "max_tokens_per_minute": per_user_max_tokens_per_minute, - "max_input_tokens_per_minute": per_user_max_input_tokens_per_minute, - "max_output_tokens_per_minute": per_user_max_output_tokens_per_minute, - "max_tokens_per_day": per_user_max_tokens_per_day, - "plan": match claims.plan { - Plan::Free => "free".to_string(), - Plan::ZedPro => "zed_pro".to_string(), - }, - "model": model.name.clone(), - "provider": provider.to_string(), - "usage_measure": resource.to_string(), - }), - ) - .write(&state.kinesis_client, &state.config.kinesis_stream) - .await - .log_err(); - - return Err(Error::http( - StatusCode::TOO_MANY_REQUESTS, - format!("Rate limit exceeded. 
Maximum {} reached.", resource), - )); - } - } - - Ok(()) -} - -struct CompletionChunk { - bytes: Vec, - input_tokens: usize, - output_tokens: usize, - cache_creation_input_tokens: usize, - cache_read_input_tokens: usize, -} - -struct TokenCountingStream { - state: Arc, - claims: LlmTokenClaims, - provider: LanguageModelProvider, - model: String, - tokens: TokenUsage, - inner_stream: S, -} - -impl Stream for TokenCountingStream -where - S: Stream> + Unpin, -{ - type Item = Result, anyhow::Error>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner_stream).poll_next(cx) { - Poll::Ready(Some(Ok(mut chunk))) => { - chunk.bytes.push(b'\n'); - self.tokens.input += chunk.input_tokens; - self.tokens.output += chunk.output_tokens; - self.tokens.input_cache_creation += chunk.cache_creation_input_tokens; - self.tokens.input_cache_read += chunk.cache_read_input_tokens; - Poll::Ready(Some(Ok(chunk.bytes))) - } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} - -impl Drop for TokenCountingStream { - fn drop(&mut self) { - let state = self.state.clone(); - let claims = self.claims.clone(); - let provider = self.provider; - let model = std::mem::take(&mut self.model); - let tokens = self.tokens; - self.state.executor.spawn_detached(async move { - let usage = state - .db - .record_usage( - UserId::from_proto(claims.user_id), - claims.is_staff, - provider, - &model, - tokens, - claims.has_llm_subscription, - Cents(claims.max_monthly_spend_in_cents), - claims.free_tier_monthly_spending_limit(), - Utc::now(), - ) - .await - .log_err(); - - if let Some(usage) = usage { - tracing::info!( - target: "user usage", - user_id = claims.user_id, - login = claims.github_user_login, - authn.jti = claims.jti, - is_staff = claims.is_staff, - provider = provider.to_string(), - model = model, - requests_this_minute = usage.requests_this_minute, - 
tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - ); - - let properties = json!({ - "has_llm_subscription": claims.has_llm_subscription, - "max_monthly_spend_in_cents": claims.max_monthly_spend_in_cents, - "plan": match claims.plan { - Plan::Free => "free".to_string(), - Plan::ZedPro => "zed_pro".to_string(), - }, - "model": model, - "provider": provider, - "usage": usage, - "tokens": tokens - }); - SnowflakeRow::new( - "Language Model Used", - Some(claims.metrics_id), - claims.is_staff, - claims.system_id.clone(), - properties, - ) - .write(&state.kinesis_client, &state.config.kinesis_stream) - .await - .log_err(); - } - }) - } -} - -pub fn log_usage_periodically(state: Arc) { - state.executor.clone().spawn_detached(async move { - loop { - state - .executor - .sleep(std::time::Duration::from_secs(30)) - .await; - - for provider in LanguageModelProvider::iter() { - for model in state.db.model_names_for_provider(provider) { - if let Some(active_user_count) = state - .get_active_user_count(provider, &model) - .await - .log_err() - { - tracing::info!( - target: "active user counts", - provider = provider.to_string(), - model = model, - users_in_recent_minutes = active_user_count.users_in_recent_minutes, - users_in_recent_days = active_user_count.users_in_recent_days, - ); - } - } - } - - if let Some(usages) = state - .db - .get_application_wide_usages_by_model(Utc::now()) - .await - .log_err() - { - for usage in usages { - tracing::info!( - target: "computed usage", - provider = usage.provider.to_string(), - model = usage.model, - requests_this_minute = usage.requests_this_minute, - tokens_this_minute = usage.tokens_this_minute, - input_tokens_this_minute = usage.input_tokens_this_minute, - output_tokens_this_minute = usage.output_tokens_this_minute, - ); - } - } - } - }) -} diff --git a/crates/collab/src/llm/authorization.rs 
b/crates/collab/src/llm/authorization.rs deleted file mode 100644 index 1ce7d7afdc..0000000000 --- a/crates/collab/src/llm/authorization.rs +++ /dev/null @@ -1,330 +0,0 @@ -use reqwest::StatusCode; -use rpc::LanguageModelProvider; - -use crate::llm::LlmTokenClaims; -use crate::{Config, Error, Result}; - -pub fn authorize_access_to_language_model( - config: &Config, - claims: &LlmTokenClaims, - country_code: Option<&str>, - provider: LanguageModelProvider, - model: &str, -) -> Result<()> { - authorize_access_for_country(config, country_code, provider)?; - authorize_access_to_model(config, claims, provider, model)?; - Ok(()) -} - -fn authorize_access_to_model( - config: &Config, - claims: &LlmTokenClaims, - provider: LanguageModelProvider, - model: &str, -) -> Result<()> { - if claims.is_staff { - return Ok(()); - } - - if provider == LanguageModelProvider::Anthropic { - if model == "claude-3-5-sonnet" || model == "claude-3-7-sonnet" { - return Ok(()); - } - - if claims.has_llm_closed_beta_feature_flag - && Some(model) == config.llm_closed_beta_model_name.as_deref() - { - return Ok(()); - } - } - - Err(Error::http( - StatusCode::FORBIDDEN, - format!("access to model {model:?} is not included in your plan"), - )) -} - -fn authorize_access_for_country( - config: &Config, - country_code: Option<&str>, - provider: LanguageModelProvider, -) -> Result<()> { - // In development we won't have the `CF-IPCountry` header, so we can't check - // the country code. - // - // This shouldn't be necessary, as anyone running in development will need to provide - // their own API credentials in order to use an LLM provider. - if config.is_development() { - return Ok(()); - } - - // https://developers.cloudflare.com/fundamentals/reference/http-request-headers/#cf-ipcountry - let country_code = match country_code { - // `XX` - Used for clients without country code data. 
- None | Some("XX") => Err(Error::http( - StatusCode::BAD_REQUEST, - "no country code".to_string(), - ))?, - // `T1` - Used for clients using the Tor network. - Some("T1") => Err(Error::http( - StatusCode::FORBIDDEN, - format!("access to {provider:?} models is not available over Tor"), - ))?, - Some(country_code) => country_code, - }; - - let is_country_supported_by_provider = match provider { - LanguageModelProvider::Anthropic => anthropic::is_supported_country(country_code), - LanguageModelProvider::OpenAi => open_ai::is_supported_country(country_code), - LanguageModelProvider::Google => google_ai::is_supported_country(country_code), - }; - if !is_country_supported_by_provider { - Err(Error::http( - StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, - format!( - "access to {provider:?} models is not available in your region ({country_code})" - ), - ))? - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use axum::response::IntoResponse; - use pretty_assertions::assert_eq; - use rpc::proto::Plan; - - use super::*; - - #[gpui::test] - async fn test_authorize_access_to_language_model_with_supported_country( - _cx: &mut gpui::TestAppContext, - ) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - is_staff: true, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "US"), // United States - (LanguageModelProvider::Anthropic, "GB"), // United Kingdom - (LanguageModelProvider::OpenAi, "US"), // United States - (LanguageModelProvider::OpenAi, "GB"), // United Kingdom - (LanguageModelProvider::Google, "US"), // United States - (LanguageModelProvider::Google, "GB"), // United Kingdom - ]; - - for (provider, country_code) in cases { - authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .unwrap_or_else(|_| { - panic!("expected authorization to return Ok for {provider:?}: {country_code}") - }) - } - } - - #[gpui::test] - async fn 
test_authorize_access_to_language_model_with_unsupported_country( - _cx: &mut gpui::TestAppContext, - ) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "AF"), // Afghanistan - (LanguageModelProvider::Anthropic, "BY"), // Belarus - (LanguageModelProvider::Anthropic, "CF"), // Central African Republic - (LanguageModelProvider::Anthropic, "CN"), // China - (LanguageModelProvider::Anthropic, "CU"), // Cuba - (LanguageModelProvider::Anthropic, "ER"), // Eritrea - (LanguageModelProvider::Anthropic, "ET"), // Ethiopia - (LanguageModelProvider::Anthropic, "IR"), // Iran - (LanguageModelProvider::Anthropic, "KP"), // North Korea - (LanguageModelProvider::Anthropic, "XK"), // Kosovo - (LanguageModelProvider::Anthropic, "LY"), // Libya - (LanguageModelProvider::Anthropic, "MM"), // Myanmar - (LanguageModelProvider::Anthropic, "RU"), // Russia - (LanguageModelProvider::Anthropic, "SO"), // Somalia - (LanguageModelProvider::Anthropic, "SS"), // South Sudan - (LanguageModelProvider::Anthropic, "SD"), // Sudan - (LanguageModelProvider::Anthropic, "SY"), // Syria - (LanguageModelProvider::Anthropic, "VE"), // Venezuela - (LanguageModelProvider::Anthropic, "YE"), // Yemen - (LanguageModelProvider::OpenAi, "KP"), // North Korea - (LanguageModelProvider::Google, "KP"), // North Korea - ]; - - for (provider, country_code) in cases { - let error_response = authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .expect_err(&format!( - "expected authorization to return an error for {provider:?}: {country_code}" - )) - .into_response(); - - assert_eq!( - error_response.status(), - StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS - ); - let response_body = hyper::body::to_bytes(error_response.into_body()) - .await - .unwrap() - .to_vec(); - assert_eq!( - String::from_utf8(response_body).unwrap(), - 
format!( - "access to {provider:?} models is not available in your region ({country_code})" - ) - ); - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_with_tor(_cx: &mut gpui::TestAppContext) { - let config = Config::test(); - - let claims = LlmTokenClaims { - user_id: 99, - plan: Plan::ZedPro, - ..Default::default() - }; - - let cases = vec![ - (LanguageModelProvider::Anthropic, "T1"), // Tor - (LanguageModelProvider::OpenAi, "T1"), // Tor - (LanguageModelProvider::Google, "T1"), // Tor - ]; - - for (provider, country_code) in cases { - let error_response = authorize_access_to_language_model( - &config, - &claims, - Some(country_code), - provider, - "the-model", - ) - .expect_err(&format!( - "expected authorization to return an error for {provider:?}: {country_code}" - )) - .into_response(); - - assert_eq!(error_response.status(), StatusCode::FORBIDDEN); - let response_body = hyper::body::to_bytes(error_response.into_body()) - .await - .unwrap() - .to_vec(); - assert_eq!( - String::from_utf8(response_body).unwrap(), - format!("access to {provider:?} models is not available over Tor") - ); - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_based_on_plan() { - let config = Config::test(); - - let test_cases = vec![ - // Pro plan should have access to claude-3.5-sonnet - ( - Plan::ZedPro, - LanguageModelProvider::Anthropic, - "claude-3-5-sonnet", - true, - ), - // Free plan should have access to claude-3.5-sonnet - ( - Plan::Free, - LanguageModelProvider::Anthropic, - "claude-3-5-sonnet", - true, - ), - // Pro plan should NOT have access to other Anthropic models - ( - Plan::ZedPro, - LanguageModelProvider::Anthropic, - "claude-3-opus", - false, - ), - ]; - - for (plan, provider, model, expected_access) in test_cases { - let claims = LlmTokenClaims { - plan, - ..Default::default() - }; - - let result = - authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); - - if expected_access { - 
assert!( - result.is_ok(), - "Expected access to be granted for plan {:?}, provider {:?}, model {}", - plan, - provider, - model - ); - } else { - let error = result.expect_err(&format!( - "Expected access to be denied for plan {:?}, provider {:?}, model {}", - plan, provider, model - )); - let response = error.into_response(); - assert_eq!(response.status(), StatusCode::FORBIDDEN); - } - } - } - - #[gpui::test] - async fn test_authorize_access_to_language_model_for_staff() { - let config = Config::test(); - - let claims = LlmTokenClaims { - is_staff: true, - ..Default::default() - }; - - // Staff should have access to all models - let test_cases = vec![ - (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"), - (LanguageModelProvider::Anthropic, "claude-2"), - (LanguageModelProvider::Anthropic, "claude-123-agi"), - (LanguageModelProvider::OpenAi, "gpt-4"), - (LanguageModelProvider::Google, "gemini-pro"), - ]; - - for (provider, model) in test_cases { - let result = - authorize_access_to_language_model(&config, &claims, Some("US"), provider, model); - - assert!( - result.is_ok(), - "Expected staff to have access to provider {:?}, model {}", - provider, - model - ); - } - } -} diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 6a46184171..f56e9e61e3 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -20,7 +20,6 @@ use std::future::Future; use std::sync::Arc; use anyhow::anyhow; -pub use queries::usages::{ActiveUserCount, TokenUsage}; pub use sea_orm::ConnectOptions; use sea_orm::prelude::*; use sea_orm::{ diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 79a17999b7..4a4a10fb51 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -2,5 +2,4 @@ use super::*; pub mod billing_events; pub mod providers; -pub mod revoked_access_tokens; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/revoked_access_tokens.rs 
b/crates/collab/src/llm/db/queries/revoked_access_tokens.rs deleted file mode 100644 index 31d70192a0..0000000000 --- a/crates/collab/src/llm/db/queries/revoked_access_tokens.rs +++ /dev/null @@ -1,15 +0,0 @@ -use super::*; - -impl LlmDatabase { - /// Returns whether the access token with the given `jti` has been revoked. - pub async fn is_access_token_revoked(&self, jti: &str) -> Result { - self.transaction(|tx| async move { - Ok(revoked_access_token::Entity::find() - .filter(revoked_access_token::Column::Jti.eq(jti)) - .one(&*tx) - .await? - .is_some()) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 3dee5a41f6..6313e7572c 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,56 +1,12 @@ use crate::db::UserId; use crate::llm::Cents; -use chrono::{Datelike, Duration}; +use chrono::Datelike; use futures::StreamExt as _; -use rpc::LanguageModelProvider; -use sea_orm::QuerySelect; -use std::{iter, str::FromStr}; +use std::str::FromStr; use strum::IntoEnumIterator as _; use super::*; -#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)] -pub struct TokenUsage { - pub input: usize, - pub input_cache_creation: usize, - pub input_cache_read: usize, - pub output: usize, -} - -impl TokenUsage { - pub fn total(&self) -> usize { - self.input + self.input_cache_creation + self.input_cache_read + self.output - } -} - -#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)] -pub struct Usage { - pub requests_this_minute: usize, - pub tokens_this_minute: usize, - pub input_tokens_this_minute: usize, - pub output_tokens_this_minute: usize, - pub tokens_this_day: usize, - pub tokens_this_month: TokenUsage, - pub spending_this_month: Cents, - pub lifetime_spending: Cents, -} - -#[derive(Debug, PartialEq, Clone)] -pub struct ApplicationWideUsage { - pub provider: LanguageModelProvider, - pub model: String, - pub 
requests_this_minute: usize, - pub tokens_this_minute: usize, - pub input_tokens_this_minute: usize, - pub output_tokens_this_minute: usize, -} - -#[derive(Clone, Copy, Debug, Default)] -pub struct ActiveUserCount { - pub users_in_recent_minutes: usize, - pub users_in_recent_days: usize, -} - impl LlmDatabase { pub async fn initialize_usage_measures(&mut self) -> Result<()> { let all_measures = self @@ -90,100 +46,6 @@ impl LlmDatabase { Ok(()) } - pub async fn get_application_wide_usages_by_model( - &self, - now: DateTimeUtc, - ) -> Result> { - self.transaction(|tx| async move { - let past_minute = now - Duration::minutes(1); - let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute]; - let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute]; - let input_tokens_per_minute = - self.usage_measure_ids[&UsageMeasure::InputTokensPerMinute]; - let output_tokens_per_minute = - self.usage_measure_ids[&UsageMeasure::OutputTokensPerMinute]; - - let mut results = Vec::new(); - for ((provider, model_name), model) in self.models.iter() { - let mut usages = usage::Entity::find() - .filter( - usage::Column::Timestamp - .gte(past_minute.naive_utc()) - .and(usage::Column::IsStaff.eq(false)) - .and(usage::Column::ModelId.eq(model.id)) - .and( - usage::Column::MeasureId - .eq(requests_per_minute) - .or(usage::Column::MeasureId.eq(tokens_per_minute)), - ), - ) - .stream(&*tx) - .await?; - - let mut requests_this_minute = 0; - let mut tokens_this_minute = 0; - let mut input_tokens_this_minute = 0; - let mut output_tokens_this_minute = 0; - while let Some(usage) = usages.next().await { - let usage = usage?; - if usage.measure_id == requests_per_minute { - requests_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::RequestsPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == tokens_per_minute { - tokens_this_minute += Self::get_live_buckets( - &usage, - 
now.naive_utc(), - UsageMeasure::TokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == input_tokens_per_minute { - input_tokens_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::InputTokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } else if usage.measure_id == output_tokens_per_minute { - output_tokens_this_minute += Self::get_live_buckets( - &usage, - now.naive_utc(), - UsageMeasure::OutputTokensPerMinute, - ) - .0 - .iter() - .copied() - .sum::() as usize; - } - } - - results.push(ApplicationWideUsage { - provider: *provider, - model: model_name.clone(), - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - }) - } - - Ok(results) - }) - .await - } - pub async fn get_user_spending_for_month( &self, user_id: UserId, @@ -223,499 +85,6 @@ impl LlmDatabase { }) .await } - - pub async fn get_usage( - &self, - user_id: UserId, - provider: LanguageModelProvider, - model_name: &str, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let model = self - .models - .get(&(provider, model_name.to_string())) - .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; - - let usages = usage::Entity::find() - .filter( - usage::Column::UserId - .eq(user_id) - .and(usage::Column::ModelId.eq(model.id)), - ) - .all(&*tx) - .await?; - - let month = now.date_naive().month() as i32; - let year = now.date_naive().year(); - let monthly_usage = monthly_usage::Entity::find() - .filter( - monthly_usage::Column::UserId - .eq(user_id) - .and(monthly_usage::Column::ModelId.eq(model.id)) - .and(monthly_usage::Column::Month.eq(month)) - .and(monthly_usage::Column::Year.eq(year)), - ) - .one(&*tx) - .await?; - let lifetime_usage = lifetime_usage::Entity::find() - .filter( - lifetime_usage::Column::UserId - .eq(user_id) - .and(lifetime_usage::Column::ModelId.eq(model.id)), - ) - .one(&*tx) - .await?; - - let 
requests_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?; - let tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerMinute)?; - let input_tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::InputTokensPerMinute)?; - let output_tokens_this_minute = - self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMinute)?; - let tokens_this_day = - self.get_usage_for_measure(&usages, now, UsageMeasure::TokensPerDay)?; - let spending_this_month = if let Some(monthly_usage) = &monthly_usage { - calculate_spending( - model, - monthly_usage.input_tokens as usize, - monthly_usage.cache_creation_input_tokens as usize, - monthly_usage.cache_read_input_tokens as usize, - monthly_usage.output_tokens as usize, - ) - } else { - Cents::ZERO - }; - let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage { - calculate_spending( - model, - lifetime_usage.input_tokens as usize, - lifetime_usage.cache_creation_input_tokens as usize, - lifetime_usage.cache_read_input_tokens as usize, - lifetime_usage.output_tokens as usize, - ) - } else { - Cents::ZERO - }; - - Ok(Usage { - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - tokens_this_day, - tokens_this_month: TokenUsage { - input: monthly_usage - .as_ref() - .map_or(0, |usage| usage.input_tokens as usize), - input_cache_creation: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_creation_input_tokens as usize), - input_cache_read: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_read_input_tokens as usize), - output: monthly_usage - .as_ref() - .map_or(0, |usage| usage.output_tokens as usize), - }, - spending_this_month, - lifetime_spending, - }) - }) - .await - } - - pub async fn record_usage( - &self, - user_id: UserId, - is_staff: bool, - provider: LanguageModelProvider, - model_name: &str, - tokens: TokenUsage, - has_llm_subscription: 
bool, - max_monthly_spend: Cents, - free_tier_monthly_spending_limit: Cents, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let model = self.model(provider, model_name)?; - - let usages = usage::Entity::find() - .filter( - usage::Column::UserId - .eq(user_id) - .and(usage::Column::ModelId.eq(model.id)), - ) - .all(&*tx) - .await?; - - let requests_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::RequestsPerMinute, - now, - 1, - &tx, - ) - .await?; - let tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::TokensPerMinute, - now, - tokens.total(), - &tx, - ) - .await?; - let input_tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::InputTokensPerMinute, - now, - // Cache read input tokens are not counted for the purposes of rate limits (but they are still billed). - tokens.input + tokens.input_cache_creation, - &tx, - ) - .await?; - let output_tokens_this_minute = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::OutputTokensPerMinute, - now, - tokens.output, - &tx, - ) - .await?; - let tokens_this_day = self - .update_usage_for_measure( - user_id, - is_staff, - model.id, - &usages, - UsageMeasure::TokensPerDay, - now, - tokens.total(), - &tx, - ) - .await?; - - let month = now.date_naive().month() as i32; - let year = now.date_naive().year(); - - // Update monthly usage - let monthly_usage = monthly_usage::Entity::find() - .filter( - monthly_usage::Column::UserId - .eq(user_id) - .and(monthly_usage::Column::ModelId.eq(model.id)) - .and(monthly_usage::Column::Month.eq(month)) - .and(monthly_usage::Column::Year.eq(year)), - ) - .one(&*tx) - .await?; - - let monthly_usage = match monthly_usage { - Some(usage) => { - monthly_usage::Entity::update(monthly_usage::ActiveModel { - id: 
ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + tokens.input_cache_read as i64, - ), - output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), - ..Default::default() - }) - .exec(&*tx) - .await? - } - None => { - monthly_usage::ActiveModel { - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - month: ActiveValue::set(month), - year: ActiveValue::set(year), - input_tokens: ActiveValue::set(tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - ..Default::default() - } - .insert(&*tx) - .await? 
- } - }; - - let spending_this_month = calculate_spending( - model, - monthly_usage.input_tokens as usize, - monthly_usage.cache_creation_input_tokens as usize, - monthly_usage.cache_read_input_tokens as usize, - monthly_usage.output_tokens as usize, - ); - - if !is_staff - && spending_this_month > free_tier_monthly_spending_limit - && has_llm_subscription - && (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend - { - billing_event::ActiveModel { - id: ActiveValue::not_set(), - idempotency_key: ActiveValue::not_set(), - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - input_tokens: ActiveValue::set(tokens.input as i64), - input_cache_creation_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - } - .insert(&*tx) - .await?; - } - - // Update lifetime usage - let lifetime_usage = lifetime_usage::Entity::find() - .filter( - lifetime_usage::Column::UserId - .eq(user_id) - .and(lifetime_usage::Column::ModelId.eq(model.id)), - ) - .one(&*tx) - .await?; - - let lifetime_usage = match lifetime_usage { - Some(usage) => { - lifetime_usage::Entity::update(lifetime_usage::ActiveModel { - id: ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + tokens.input_cache_read as i64, - ), - output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), - ..Default::default() - }) - .exec(&*tx) - .await? 
- } - None => { - lifetime_usage::ActiveModel { - user_id: ActiveValue::set(user_id), - model_id: ActiveValue::set(model.id), - input_tokens: ActiveValue::set(tokens.input as i64), - cache_creation_input_tokens: ActiveValue::set( - tokens.input_cache_creation as i64, - ), - cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), - output_tokens: ActiveValue::set(tokens.output as i64), - ..Default::default() - } - .insert(&*tx) - .await? - } - }; - - let lifetime_spending = calculate_spending( - model, - lifetime_usage.input_tokens as usize, - lifetime_usage.cache_creation_input_tokens as usize, - lifetime_usage.cache_read_input_tokens as usize, - lifetime_usage.output_tokens as usize, - ); - - Ok(Usage { - requests_this_minute, - tokens_this_minute, - input_tokens_this_minute, - output_tokens_this_minute, - tokens_this_day, - tokens_this_month: TokenUsage { - input: monthly_usage.input_tokens as usize, - input_cache_creation: monthly_usage.cache_creation_input_tokens as usize, - input_cache_read: monthly_usage.cache_read_input_tokens as usize, - output: monthly_usage.output_tokens as usize, - }, - spending_this_month, - lifetime_spending, - }) - }) - .await - } - - /// Returns the active user count for the specified model. 
- pub async fn get_active_user_count( - &self, - provider: LanguageModelProvider, - model_name: &str, - now: DateTimeUtc, - ) -> Result { - self.transaction(|tx| async move { - let minute_since = now - Duration::minutes(5); - let day_since = now - Duration::days(5); - - let model = self - .models - .get(&(provider, model_name.to_string())) - .ok_or_else(|| anyhow!("unknown model {provider}:{model_name}"))?; - - let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute]; - - let users_in_recent_minutes = usage::Entity::find() - .filter( - usage::Column::ModelId - .eq(model.id) - .and(usage::Column::MeasureId.eq(tokens_per_minute)) - .and(usage::Column::Timestamp.gte(minute_since.naive_utc())) - .and(usage::Column::IsStaff.eq(false)), - ) - .select_only() - .column(usage::Column::UserId) - .group_by(usage::Column::UserId) - .count(&*tx) - .await? as usize; - - let users_in_recent_days = usage::Entity::find() - .filter( - usage::Column::ModelId - .eq(model.id) - .and(usage::Column::MeasureId.eq(tokens_per_minute)) - .and(usage::Column::Timestamp.gte(day_since.naive_utc())) - .and(usage::Column::IsStaff.eq(false)), - ) - .select_only() - .column(usage::Column::UserId) - .group_by(usage::Column::UserId) - .count(&*tx) - .await? 
as usize; - - Ok(ActiveUserCount { - users_in_recent_minutes, - users_in_recent_days, - }) - }) - .await - } - - async fn update_usage_for_measure( - &self, - user_id: UserId, - is_staff: bool, - model_id: ModelId, - usages: &[usage::Model], - usage_measure: UsageMeasure, - now: DateTimeUtc, - usage_to_add: usize, - tx: &DatabaseTransaction, - ) -> Result { - let now = now.naive_utc(); - let measure_id = *self - .usage_measure_ids - .get(&usage_measure) - .ok_or_else(|| anyhow!("usage measure {usage_measure} not found"))?; - - let mut id = None; - let mut timestamp = now; - let mut buckets = vec![0_i64]; - - if let Some(old_usage) = usages.iter().find(|usage| usage.measure_id == measure_id) { - id = Some(old_usage.id); - let (live_buckets, buckets_since) = - Self::get_live_buckets(old_usage, now, usage_measure); - if !live_buckets.is_empty() { - buckets.clear(); - buckets.extend_from_slice(live_buckets); - buckets.extend(iter::repeat(0).take(buckets_since)); - timestamp = - old_usage.timestamp + (usage_measure.bucket_duration() * buckets_since as i32); - } - } - - *buckets.last_mut().unwrap() += usage_to_add as i64; - let total_usage = buckets.iter().sum::() as usize; - - let mut model = usage::ActiveModel { - user_id: ActiveValue::set(user_id), - is_staff: ActiveValue::set(is_staff), - model_id: ActiveValue::set(model_id), - measure_id: ActiveValue::set(measure_id), - timestamp: ActiveValue::set(timestamp), - buckets: ActiveValue::set(buckets), - ..Default::default() - }; - - if let Some(id) = id { - model.id = ActiveValue::unchanged(id); - model.update(tx).await?; - } else { - usage::Entity::insert(model) - .exec_without_returning(tx) - .await?; - } - - Ok(total_usage) - } - - fn get_usage_for_measure( - &self, - usages: &[usage::Model], - now: DateTimeUtc, - usage_measure: UsageMeasure, - ) -> Result { - let now = now.naive_utc(); - let measure_id = *self - .usage_measure_ids - .get(&usage_measure) - .ok_or_else(|| anyhow!("usage measure {usage_measure} not 
found"))?; - let Some(usage) = usages.iter().find(|usage| usage.measure_id == measure_id) else { - return Ok(0); - }; - - let (live_buckets, _) = Self::get_live_buckets(usage, now, usage_measure); - Ok(live_buckets.iter().sum::() as _) - } - - fn get_live_buckets( - usage: &usage::Model, - now: chrono::NaiveDateTime, - measure: UsageMeasure, - ) -> (&[i64], usize) { - let seconds_since_usage = (now - usage.timestamp).num_seconds().max(0); - let buckets_since_usage = - seconds_since_usage as f32 / measure.bucket_duration().num_seconds() as f32; - let buckets_since_usage = buckets_since_usage.ceil() as usize; - let mut live_buckets = &[] as &[i64]; - if buckets_since_usage < measure.bucket_count() { - let expired_bucket_count = - (usage.buckets.len() + buckets_since_usage).saturating_sub(measure.bucket_count()); - live_buckets = &usage.buckets[expired_bucket_count..]; - while live_buckets.first() == Some(&0) { - live_buckets = &live_buckets[1..]; - } - } - (live_buckets, buckets_since_usage) - } } fn calculate_spending( @@ -741,32 +110,3 @@ fn calculate_spending( + output_token_cost; Cents::new(spending as u32) } - -const MINUTE_BUCKET_COUNT: usize = 12; -const DAY_BUCKET_COUNT: usize = 48; - -impl UsageMeasure { - fn bucket_count(&self) -> usize { - match self { - UsageMeasure::RequestsPerMinute => MINUTE_BUCKET_COUNT, - UsageMeasure::TokensPerMinute - | UsageMeasure::InputTokensPerMinute - | UsageMeasure::OutputTokensPerMinute => MINUTE_BUCKET_COUNT, - UsageMeasure::TokensPerDay => DAY_BUCKET_COUNT, - } - } - - fn total_duration(&self) -> Duration { - match self { - UsageMeasure::RequestsPerMinute => Duration::minutes(1), - UsageMeasure::TokensPerMinute - | UsageMeasure::InputTokensPerMinute - | UsageMeasure::OutputTokensPerMinute => Duration::minutes(1), - UsageMeasure::TokensPerDay => Duration::hours(24), - } - } - - fn bucket_duration(&self) -> Duration { - self.total_duration() / self.bucket_count() as i32 - } -} diff --git a/crates/collab/src/llm/db/tables.rs 
b/crates/collab/src/llm/db/tables.rs index 407c5c8fd0..5f2d357a87 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,8 +1,6 @@ pub mod billing_event; -pub mod lifetime_usage; pub mod model; pub mod monthly_usage; pub mod provider; -pub mod revoked_access_token; pub mod usage; pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/lifetime_usage.rs b/crates/collab/src/llm/db/tables/lifetime_usage.rs deleted file mode 100644 index fc8354699b..0000000000 --- a/crates/collab/src/llm/db/tables/lifetime_usage.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::{db::UserId, llm::db::ModelId}; -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "lifetime_usages")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: i32, - pub user_id: UserId, - pub model_id: ModelId, - pub input_tokens: i64, - pub cache_creation_input_tokens: i64, - pub cache_read_input_tokens: i64, - pub output_tokens: i64, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/revoked_access_token.rs b/crates/collab/src/llm/db/tables/revoked_access_token.rs deleted file mode 100644 index 364963be88..0000000000 --- a/crates/collab/src/llm/db/tables/revoked_access_token.rs +++ /dev/null @@ -1,19 +0,0 @@ -use chrono::NaiveDateTime; -use sea_orm::entity::prelude::*; - -use crate::llm::db::RevokedAccessTokenId; - -/// A revoked access token. 
-#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "revoked_access_tokens")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: RevokedAccessTokenId, - pub jti: String, - pub revoked_at: NaiveDateTime, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests.rs b/crates/collab/src/llm/db/tests.rs index 59f92958c7..43a1b8b0d4 100644 --- a/crates/collab/src/llm/db/tests.rs +++ b/crates/collab/src/llm/db/tests.rs @@ -1,6 +1,4 @@ -mod billing_tests; mod provider_tests; -mod usage_tests; use gpui::BackgroundExecutor; use parking_lot::Mutex; diff --git a/crates/collab/src/llm/db/tests/billing_tests.rs b/crates/collab/src/llm/db/tests/billing_tests.rs deleted file mode 100644 index 3a95610bc2..0000000000 --- a/crates/collab/src/llm/db/tests/billing_tests.rs +++ /dev/null @@ -1,152 +0,0 @@ -use crate::{ - Cents, - db::UserId, - llm::{ - FREE_TIER_MONTHLY_SPENDING_LIMIT, - db::{LlmDatabase, TokenUsage, queries::providers::ModelParams}, - }, - test_llm_db, -}; -use chrono::{DateTime, Utc}; -use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; - -test_llm_db!( - test_billing_limit_exceeded, - test_billing_limit_exceeded_postgres -); - -async fn test_billing_limit_exceeded(db: &mut LlmDatabase) { - let provider = LanguageModelProvider::Anthropic; - let model = "fake-claude-limerick"; - const PRICE_PER_MILLION_INPUT_TOKENS: i32 = 5; - const PRICE_PER_MILLION_OUTPUT_TOKENS: i32 = 5; - - // Initialize the database and insert the model - db.initialize().await.unwrap(); - db.insert_models(&[ModelParams { - provider, - name: model.to_string(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 50_000, - price_per_million_input_tokens: PRICE_PER_MILLION_INPUT_TOKENS, - price_per_million_output_tokens: PRICE_PER_MILLION_OUTPUT_TOKENS, - }]) - .await - .unwrap(); - - // Set a fixed datetime 
for consistent testing - let now = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z") - .unwrap() - .with_timezone(&Utc); - - let user_id = UserId::from_proto(123); - - let max_monthly_spend = Cents::from_dollars(11); - - // Record usage that brings us close to the limit but doesn't exceed it - // Let's say we use $10.50 worth of tokens - let tokens_to_use = 210_000_000; // This will cost $10.50 at $0.05 per 1 million tokens - let usage = TokenUsage { - input: tokens_to_use, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - - // Verify that before we record any usage, there are 0 billing events - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 0); - - db.record_usage( - user_id, - false, - provider, - model, - usage, - true, - max_monthly_spend, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - // Verify the recorded usage and spending - let recorded_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - // Verify that we exceeded the free tier usage - assert_eq!(recorded_usage.spending_this_month, Cents::new(1050)); - assert!(recorded_usage.spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT); - - // Verify that there is one `billing_event` record - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 1); - - let (billing_event, _model) = &billing_events[0]; - assert_eq!(billing_event.user_id, user_id); - assert_eq!(billing_event.input_tokens, tokens_to_use as i64); - assert_eq!(billing_event.input_cache_creation_tokens, 0); - assert_eq!(billing_event.input_cache_read_tokens, 0); - assert_eq!(billing_event.output_tokens, 0); - - // Record usage that puts us at $20.50 - let usage_2 = TokenUsage { - input: 200_000_000, // This will cost $10 more, pushing us from $10.50 to $20.50, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - db.record_usage( - user_id, - false, - provider, - model, - 
usage_2, - true, - max_monthly_spend, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - // Verify the updated usage and spending - let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!(updated_usage.spending_this_month, Cents::new(2050)); - - // Verify that there are now two billing events - let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 2); - - let tokens_to_exceed = 20_000_000; // This will cost $1.00 more, pushing us from $20.50 to $21.50, which is over the $11 monthly maximum limit - let usage_exceeding = TokenUsage { - input: tokens_to_exceed, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }; - - // This should still create a billing event as it's the first request that exceeds the limit - db.record_usage( - user_id, - false, - provider, - model, - usage_exceeding, - true, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - max_monthly_spend, - now, - ) - .await - .unwrap(); - // Verify the updated usage and spending - let updated_usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!(updated_usage.spending_this_month, Cents::new(2150)); - - // Verify that we never exceed the user max spending for the user - // and avoid charging them. 
- let billing_events = db.get_billing_events().await.unwrap(); - assert_eq!(billing_events.len(), 2); -} diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs deleted file mode 100644 index 0a4ef7f4cf..0000000000 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ /dev/null @@ -1,306 +0,0 @@ -use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT; -use crate::{ - Cents, - db::UserId, - llm::db::{ - LlmDatabase, TokenUsage, - queries::{providers::ModelParams, usages::Usage}, - }, - test_llm_db, -}; -use chrono::{DateTime, Duration, Utc}; -use pretty_assertions::assert_eq; -use rpc::LanguageModelProvider; - -test_llm_db!(test_tracking_usage, test_tracking_usage_postgres); - -async fn test_tracking_usage(db: &mut LlmDatabase) { - let provider = LanguageModelProvider::Anthropic; - let model = "claude-3-5-sonnet"; - - db.initialize().await.unwrap(); - db.insert_models(&[ModelParams { - provider, - name: model.to_string(), - max_requests_per_minute: 5, - max_tokens_per_minute: 10_000, - max_tokens_per_day: 50_000, - price_per_million_input_tokens: 50, - price_per_million_output_tokens: 50, - }]) - .await - .unwrap(); - - // We're using a fixed datetime to prevent flakiness based on the clock. 
- let t0 = DateTime::parse_from_rfc3339("2024-08-08T22:46:33Z") - .unwrap() - .with_timezone(&Utc); - let user_id = UserId::from_proto(123); - - let now = t0; - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let now = t0 + Duration::seconds(10); - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 2000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 3000, - input_tokens_this_minute: 3000, - output_tokens_this_minute: 0, - tokens_this_day: 3000, - tokens_this_month: TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let now = t0 + Duration::seconds(60); - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 2000, - input_tokens_this_minute: 2000, - output_tokens_this_minute: 0, - tokens_this_day: 3000, - tokens_this_month: TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let now = t0 + Duration::seconds(60); - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 3000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, 
now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 5000, - input_tokens_this_minute: 5000, - output_tokens_this_minute: 0, - tokens_this_day: 6000, - tokens_this_month: TokenUsage { - input: 6000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - let t1 = t0 + Duration::hours(24); - let now = t1; - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 0, - tokens_this_minute: 0, - input_tokens_this_minute: 0, - output_tokens_this_minute: 0, - tokens_this_day: 5000, - tokens_this_month: TokenUsage { - input: 6000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 4000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 4000, - input_tokens_this_minute: 4000, - output_tokens_this_minute: 0, - tokens_this_day: 9000, - tokens_this_month: TokenUsage { - input: 10000, - input_cache_creation: 0, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - // We're using a fixed datetime to prevent flakiness based on the clock. 
- let now = DateTime::parse_from_rfc3339("2024-10-08T22:15:58Z") - .unwrap() - .with_timezone(&Utc); - - // Test cache creation input tokens - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 500, - input_cache_read: 0, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 1, - tokens_this_minute: 1500, - input_tokens_this_minute: 1500, - output_tokens_this_minute: 0, - tokens_this_day: 1500, - tokens_this_month: TokenUsage { - input: 1000, - input_cache_creation: 500, - input_cache_read: 0, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); - - // Test cache read input tokens - db.record_usage( - user_id, - false, - provider, - model, - TokenUsage { - input: 1000, - input_cache_creation: 0, - input_cache_read: 300, - output: 0, - }, - false, - Cents::ZERO, - FREE_TIER_MONTHLY_SPENDING_LIMIT, - now, - ) - .await - .unwrap(); - - let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); - assert_eq!( - usage, - Usage { - requests_this_minute: 2, - tokens_this_minute: 2800, - input_tokens_this_minute: 2500, - output_tokens_this_minute: 0, - tokens_this_day: 2800, - tokens_this_month: TokenUsage { - input: 2000, - input_cache_creation: 500, - input_cache_read: 300, - output: 0, - }, - spending_this_month: Cents::ZERO, - lifetime_spending: Cents::ZERO, - } - ); -} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 30dab40cce..8f850ee847 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -9,14 +9,14 @@ use axum::{ use collab::api::CloudflareIpCountryHeader; use collab::api::billing::sync_llm_usage_with_stripe_periodically; -use collab::llm::{db::LlmDatabase, log_usage_periodically}; +use collab::llm::db::LlmDatabase; 
use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; use collab::{ AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, }; -use collab::{ServiceMode, api::billing::poll_stripe_events_periodically, llm::LlmState}; +use collab::{ServiceMode, api::billing::poll_stripe_events_periodically}; use db::Database; use std::{ env::args, @@ -74,11 +74,10 @@ async fn main() -> Result<()> { let mode = match args.next().as_deref() { Some("collab") => ServiceMode::Collab, Some("api") => ServiceMode::Api, - Some("llm") => ServiceMode::Llm, Some("all") => ServiceMode::All, _ => { return Err(anyhow!( - "usage: collab >" + "usage: collab >" ))?; } }; @@ -97,20 +96,9 @@ async fn main() -> Result<()> { let mut on_shutdown = None; - if mode.is_llm() { - setup_llm_database(&config).await?; - - let state = LlmState::new(config.clone(), Executor::Production).await?; - - log_usage_periodically(state.clone()); - - app = app - .merge(collab::llm::routes()) - .layer(Extension(state.clone())); - } - if mode.is_collab() || mode.is_api() { setup_app_database(&config).await?; + setup_llm_database(&config).await?; let state = AppState::new(config, Executor::Production).await?; @@ -336,18 +324,11 @@ async fn handle_root(Extension(mode): Extension) -> String { format!("zed:{mode} v{VERSION} ({})", REVISION.unwrap_or("unknown")) } -async fn handle_liveness_probe( - app_state: Option>>, - llm_state: Option>>, -) -> Result { +async fn handle_liveness_probe(app_state: Option>>) -> Result { if let Some(state) = app_state { state.db.get_all_users(0, 1).await?; } - if let Some(llm_state) = llm_state { - llm_state.db.list_providers().await?; - } - Ok("ok".to_string()) } diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 719b8643f2..8a039da882 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ 
b/crates/collab/src/tests/editor_tests.rs @@ -694,7 +694,15 @@ async fn test_collaborating_with_code_actions( // Confirming the code action will trigger a resolve request. let confirm_action = editor_b .update_in(cx_b, |editor, window, cx| { - Editor::confirm_code_action(editor, &ConfirmCodeAction { item_ix: Some(0) }, window, cx) + Editor::confirm_code_action( + editor, + &ConfirmCodeAction { + item_ix: Some(0), + from_mouse_context_menu: false, + }, + window, + cx, + ) }) .unwrap(); fake_language_server.set_request_handler::( diff --git a/crates/component/src/component.rs b/crates/component/src/component.rs index 31ed169743..db847d5538 100644 --- a/crates/component/src/component.rs +++ b/crates/component/src/component.rs @@ -191,6 +191,14 @@ pub fn components() -> AllComponents { all_components } +// #[derive(Debug, Clone, PartialEq, Eq, Hash)] +// pub enum ComponentStatus { +// WorkInProgress, +// EngineeringReady, +// Live, +// Deprecated, +// } + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ComponentScope { Collaboration, @@ -241,24 +249,30 @@ pub struct ComponentExample { impl RenderOnce for ComponentExample { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { div() + .pt_2() .w_full() .flex() .flex_col() .gap_3() .child( div() - .child(self.variant_name.clone()) - .text_size(rems(1.25)) - .text_color(cx.theme().colors().text), + .flex() + .flex_col() + .child( + div() + .child(self.variant_name.clone()) + .text_size(rems(1.0)) + .text_color(cx.theme().colors().text), + ) + .when_some(self.description, |this, description| { + this.child( + div() + .text_size(rems(0.875)) + .text_color(cx.theme().colors().text_muted) + .child(description.clone()), + ) + }), ) - .when_some(self.description, |this, description| { - this.child( - div() - .text_size(rems(0.9375)) - .text_color(cx.theme().colors().text_muted) - .child(description.clone()), - ) - }) .child( div() .flex() @@ -268,11 +282,11 @@ impl RenderOnce for ComponentExample { 
.justify_center() .p_8() .border_1() - .border_color(cx.theme().colors().border) + .border_color(cx.theme().colors().border.opacity(0.5)) .bg(pattern_slash( cx.theme().colors().surface_background.opacity(0.5), - 24.0, - 24.0, + 12.0, + 12.0, )) .shadow_sm() .child(self.element), diff --git a/crates/component_preview/Cargo.toml b/crates/component_preview/Cargo.toml index d8b4df34dd..e01e7d2208 100644 --- a/crates/component_preview/Cargo.toml +++ b/crates/component_preview/Cargo.toml @@ -16,12 +16,16 @@ default = [] [dependencies] client.workspace = true +collections.workspace = true component.workspace = true gpui.workspace = true languages.workspace = true +notifications.workspace = true project.workspace = true ui.workspace = true -workspace.workspace = true -notifications.workspace = true -collections.workspace = true +ui_input.workspace = true workspace-hack.workspace = true +workspace.workspace = true +db.workspace = true +anyhow.workspace = true +serde.workspace = true diff --git a/crates/component_preview/src/component_preview.rs b/crates/component_preview/src/component_preview.rs index 5109b8692d..276271828e 100644 --- a/crates/component_preview/src/component_preview.rs +++ b/crates/component_preview/src/component_preview.rs @@ -2,6 +2,8 @@ //! //! A view for exploring Zed components. 
+mod persistence; + use std::iter::Iterator; use std::sync::Arc; @@ -9,24 +11,27 @@ use client::UserStore; use component::{ComponentId, ComponentMetadata, components}; use gpui::{ App, Entity, EventEmitter, FocusHandle, Focusable, Task, WeakEntity, Window, list, prelude::*, - uniform_list, }; use collections::HashMap; -use gpui::{ListState, ScrollHandle, UniformListScrollHandle}; +use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle}; use languages::LanguageRegistry; use notifications::status_toast::{StatusToast, ToastIcon}; +use persistence::COMPONENT_PREVIEW_DB; use project::Project; -use ui::{Divider, ListItem, ListSubHeader, prelude::*}; +use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*}; +use ui_input::SingleLineInput; use workspace::{AppState, ItemId, SerializableItem}; use workspace::{Item, Workspace, WorkspaceId, item::ItemEvent}; pub fn init(app_state: Arc, cx: &mut App) { + workspace::register_serializable_item::(cx); + let app_state = app_state.clone(); - cx.observe_new(move |workspace: &mut Workspace, _, cx| { + cx.observe_new(move |workspace: &mut Workspace, _window, cx| { let app_state = app_state.clone(); let weak_workspace = cx.entity().downgrade(); @@ -44,6 +49,7 @@ pub fn init(app_state: Arc, cx: &mut App) { user_store, None, None, + window, cx, ) }); @@ -64,13 +70,13 @@ pub fn init(app_state: Arc, cx: &mut App) { enum PreviewEntry { AllComponents, Separator, - Component(ComponentMetadata), + Component(ComponentMetadata, Option>), SectionHeader(SharedString), } impl From for PreviewEntry { fn from(component: ComponentMetadata) -> Self { - PreviewEntry::Component(component) + PreviewEntry::Component(component, None) } } @@ -88,6 +94,7 @@ enum PreviewPage { } struct ComponentPreview { + workspace_id: Option, focus_handle: FocusHandle, _view_scroll_handle: ScrollHandle, nav_scroll_handle: UniformListScrollHandle, @@ -99,6 +106,8 @@ struct ComponentPreview { language_registry: Arc, workspace: WeakEntity, 
user_store: Entity, + filter_editor: Entity, + filter_text: String, } impl ComponentPreview { @@ -108,11 +117,14 @@ impl ComponentPreview { user_store: Entity, selected_index: impl Into>, active_page: Option, + window: &mut Window, cx: &mut Context, ) -> Self { let sorted_components = components().all_sorted(); let selected_index = selected_index.into().unwrap_or(0); let active_page = active_page.unwrap_or(PreviewPage::AllComponents); + let filter_editor = + cx.new(|cx| SingleLineInput::new(window, cx, "Find components or usages…")); let component_list = ListState::new( sorted_components.len(), @@ -132,6 +144,7 @@ impl ComponentPreview { ); let mut component_preview = Self { + workspace_id: None, focus_handle: cx.focus_handle(), _view_scroll_handle: ScrollHandle::new(), nav_scroll_handle: UniformListScrollHandle::new(), @@ -143,6 +156,8 @@ impl ComponentPreview { components: sorted_components, component_list, cursor_index: selected_index, + filter_editor, + filter_text: String::new(), }; if component_preview.cursor_index > 0 { @@ -154,6 +169,13 @@ impl ComponentPreview { component_preview } + pub fn active_page_id(&self, _cx: &App) -> ActivePageId { + match &self.active_page { + PreviewPage::AllComponents => ActivePageId::default(), + PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()), + } + } + fn scroll_to_preview(&mut self, ix: usize, cx: &mut Context) { self.component_list.scroll_to_reveal_item(ix); self.cursor_index = ix; @@ -162,6 +184,7 @@ impl ComponentPreview { fn set_active_page(&mut self, page: PreviewPage, cx: &mut Context) { self.active_page = page; + cx.emit(ItemEvent::UpdateTab); cx.notify(); } @@ -169,20 +192,94 @@ impl ComponentPreview { self.components[ix].clone() } + fn filtered_components(&self) -> Vec { + if self.filter_text.is_empty() { + return self.components.clone(); + } + + let filter = self.filter_text.to_lowercase(); + self.components + .iter() + .filter(|component| { + let component_name = 
component.name().to_lowercase(); + let scope_name = component.scope().to_string().to_lowercase(); + let description = component + .description() + .map(|d| d.to_lowercase()) + .unwrap_or_default(); + + component_name.contains(&filter) + || scope_name.contains(&filter) + || description.contains(&filter) + }) + .cloned() + .collect() + } + fn scope_ordered_entries(&self) -> Vec { use std::collections::HashMap; - let mut scope_groups: HashMap> = HashMap::default(); + let mut scope_groups: HashMap< + ComponentScope, + Vec<(ComponentMetadata, Option>)>, + > = HashMap::default(); + let lowercase_filter = self.filter_text.to_lowercase(); for component in &self.components { - scope_groups - .entry(component.scope()) - .or_insert_with(Vec::new) - .push(component.clone()); + if self.filter_text.is_empty() { + scope_groups + .entry(component.scope()) + .or_insert_with(Vec::new) + .push((component.clone(), None)); + continue; + } + + // let full_component_name = component.name(); + let scopeless_name = component.scopeless_name(); + let scope_name = component.scope().to_string(); + let description = component.description().unwrap_or_default(); + + let lowercase_scopeless = scopeless_name.to_lowercase(); + let lowercase_scope = scope_name.to_lowercase(); + let lowercase_desc = description.to_lowercase(); + + if lowercase_scopeless.contains(&lowercase_filter) { + if let Some(index) = lowercase_scopeless.find(&lowercase_filter) { + let end = index + lowercase_filter.len(); + + if end <= scopeless_name.len() { + let mut positions = Vec::new(); + for i in index..end { + if scopeless_name.is_char_boundary(i) { + positions.push(i); + } + } + + if !positions.is_empty() { + scope_groups + .entry(component.scope()) + .or_insert_with(Vec::new) + .push((component.clone(), Some(positions))); + continue; + } + } + } + } + + if lowercase_scopeless.contains(&lowercase_filter) + || lowercase_scope.contains(&lowercase_filter) + || lowercase_desc.contains(&lowercase_filter) + { + scope_groups + 
.entry(component.scope()) + .or_insert_with(Vec::new) + .push((component.clone(), None)); + } } + // Sort the components in each group for components in scope_groups.values_mut() { - components.sort_by_key(|c| c.name().to_lowercase()); + components.sort_by_key(|(c, _)| c.sort_name()); } let mut entries = Vec::new(); @@ -204,10 +301,10 @@ impl ComponentPreview { if !components.is_empty() { entries.push(PreviewEntry::SectionHeader(scope.to_string().into())); let mut sorted_components = components; - sorted_components.sort_by_key(|component| component.sort_name()); + sorted_components.sort_by_key(|(component, _)| component.sort_name()); - for component in sorted_components { - entries.push(PreviewEntry::Component(component)); + for (component, positions) in sorted_components { + entries.push(PreviewEntry::Component(component, positions)); } } } @@ -219,10 +316,10 @@ impl ComponentPreview { entries.push(PreviewEntry::Separator); entries.push(PreviewEntry::SectionHeader("Uncategorized".into())); let mut sorted_components = components.clone(); - sorted_components.sort_by_key(|c| c.sort_name()); + sorted_components.sort_by_key(|(c, _)| c.sort_name()); - for component in sorted_components { - entries.push(PreviewEntry::Component(component.clone())); + for (component, positions) in sorted_components { + entries.push(PreviewEntry::Component(component, positions)); } } } @@ -237,14 +334,33 @@ impl ComponentPreview { cx: &Context, ) -> impl IntoElement + use<> { match entry { - PreviewEntry::Component(component_metadata) => { + PreviewEntry::Component(component_metadata, highlight_positions) => { let id = component_metadata.id(); let selected = self.active_page == PreviewPage::Component(id.clone()); + let name = component_metadata.scopeless_name(); + ListItem::new(ix) - .child( - Label::new(component_metadata.scopeless_name().clone()) - .color(Color::Default), - ) + .child(if let Some(_positions) = highlight_positions { + let name_lower = name.to_lowercase(); + let 
filter_lower = self.filter_text.to_lowercase(); + let valid_positions = if let Some(start) = name_lower.find(&filter_lower) { + let end = start + filter_lower.len(); + (start..end).collect() + } else { + Vec::new() + }; + if valid_positions.is_empty() { + Label::new(name.clone()) + .color(Color::Default) + .into_any_element() + } else { + HighlightedLabel::new(name.clone(), valid_positions).into_any_element() + } + } else { + Label::new(name.clone()) + .color(Color::Default) + .into_any_element() + }) .selectable(true) .toggle_state(selected) .inset(true) @@ -282,20 +398,70 @@ impl ComponentPreview { } fn update_component_list(&mut self, cx: &mut Context) { - let new_len = self.scope_ordered_entries().len(); let entries = self.scope_ordered_entries(); + let new_len = entries.len(); let weak_entity = cx.entity().downgrade(); + if new_len > 0 { + self.nav_scroll_handle + .scroll_to_item(0, ScrollStrategy::Top); + } + + let filtered_components = self.filtered_components(); + + if !self.filter_text.is_empty() && !matches!(self.active_page, PreviewPage::AllComponents) { + if let PreviewPage::Component(ref component_id) = self.active_page { + let component_still_visible = filtered_components + .iter() + .any(|component| component.id() == *component_id); + + if !component_still_visible { + if !filtered_components.is_empty() { + let first_component = &filtered_components[0]; + self.set_active_page(PreviewPage::Component(first_component.id()), cx); + } else { + self.set_active_page(PreviewPage::AllComponents, cx); + } + } + } + } + + self.component_list = ListState::new( + filtered_components.len(), + gpui::ListAlignment::Top, + px(1500.0), + { + let components = filtered_components.clone(); + let this = cx.entity().downgrade(); + move |ix, window: &mut Window, cx: &mut App| { + if ix >= components.len() { + return div().w_full().h_0().into_any_element(); + } + + this.update(cx, |this, cx| { + let component = &components[ix]; + this.render_preview(component, window, cx) + 
.into_any_element() + }) + .unwrap() + } + }, + ); + let new_list = ListState::new( new_len, gpui::ListAlignment::Top, px(1500.0), move |ix, window, cx| { + if ix >= entries.len() { + return div().w_full().h_0().into_any_element(); + } + let entry = &entries[ix]; weak_entity .update(cx, |this, cx| match entry { - PreviewEntry::Component(component) => this + PreviewEntry::Component(component, _) => this .render_preview(component, window, cx) .into_any_element(), PreviewEntry::SectionHeader(shared_string) => this @@ -309,6 +475,7 @@ impl ComponentPreview { ); self.component_list = new_list; + cx.emit(ItemEvent::UpdateTab); } fn render_scope_header( @@ -377,16 +544,27 @@ impl ComponentPreview { .into_any_element() } - fn render_all_components(&self) -> impl IntoElement { + fn render_all_components(&self, cx: &Context) -> impl IntoElement { v_flex() .id("component-list") .px_8() .pt_4() .size_full() .child( - list(self.component_list.clone()) - .flex_grow() - .with_sizing_behavior(gpui::ListSizingBehavior::Auto), + if self.filtered_components().is_empty() && !self.filter_text.is_empty() { + div() + .size_full() + .items_center() + .justify_center() + .text_color(cx.theme().colors().text_muted) + .child(format!("No components matching '{}'.", self.filter_text)) + .into_any_element() + } else { + list(self.component_list.clone()) + .flex_grow() + .with_sizing_behavior(gpui::ListSizingBehavior::Auto) + .into_any_element() + }, ) } @@ -432,6 +610,19 @@ impl ComponentPreview { impl Render for ComponentPreview { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + // TODO: move this into the struct + let current_filter = self.filter_editor.update(cx, |input, cx| { + if input.is_empty(cx) { + String::new() + } else { + input.editor().read(cx).text(cx).to_string() + } + }); + + if current_filter != self.filter_text { + self.filter_text = current_filter; + self.update_component_list(cx); + } let sidebar_entries = self.scope_ordered_entries(); let 
active_page = self.active_page.clone(); @@ -449,14 +640,22 @@ impl Render for ComponentPreview { .border_color(cx.theme().colors().border) .h_full() .child( - uniform_list( + gpui::uniform_list( cx.entity().clone(), "component-nav", sidebar_entries.len(), move |this, range, _window, cx| { range - .map(|ix| { - this.render_sidebar_entry(ix, &sidebar_entries[ix], cx) + .filter_map(|ix| { + if ix < sidebar_entries.len() { + Some(this.render_sidebar_entry( + ix, + &sidebar_entries[ix], + cx, + )) + } else { + None + } }) .collect() }, @@ -481,12 +680,29 @@ impl Render for ComponentPreview { ), ), ) - .child(match active_page { - PreviewPage::AllComponents => self.render_all_components().into_any_element(), - PreviewPage::Component(id) => self - .render_component_page(&id, window, cx) - .into_any_element(), - }) + .child( + v_flex() + .id("content-area") + .flex_1() + .size_full() + .overflow_hidden() + .child( + div() + .p_2() + .w_full() + .border_b_1() + .border_color(cx.theme().colors().border) + .child(self.filter_editor.clone()), + ) + .child(match active_page { + PreviewPage::AllComponents => { + self.render_all_components(cx).into_any_element() + } + PreviewPage::Component(id) => self + .render_component_page(&id, window, cx) + .into_any_element(), + }), + ) } } @@ -498,6 +714,21 @@ impl Focusable for ComponentPreview { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ActivePageId(pub String); + +impl Default for ActivePageId { + fn default() -> Self { + ActivePageId("AllComponents".to_string()) + } +} + +impl From for ActivePageId { + fn from(id: ComponentId) -> Self { + ActivePageId(id.0.to_string()) + } +} + impl Item for ComponentPreview { type Event = ItemEvent; @@ -516,7 +747,7 @@ impl Item for ComponentPreview { fn clone_on_split( &self, _workspace_id: Option, - _window: &mut Window, + window: &mut Window, cx: &mut Context, ) -> Option> where @@ -535,6 +766,7 @@ impl Item for ComponentPreview { user_store, selected_index, Some(active_page), + 
window, cx, ) })) @@ -543,6 +775,15 @@ impl Item for ComponentPreview { fn to_item_events(event: &Self::Event, mut f: impl FnMut(workspace::item::ItemEvent)) { f(*event) } + + fn added_to_workspace( + &mut self, + workspace: &mut Workspace, + _window: &mut Window, + _cx: &mut Context, + ) { + self.workspace_id = workspace.database_id(); + } } impl SerializableItem for ComponentPreview { @@ -553,26 +794,53 @@ impl SerializableItem for ComponentPreview { fn deserialize( project: Entity, workspace: WeakEntity, - _workspace_id: WorkspaceId, - _item_id: ItemId, + workspace_id: WorkspaceId, + item_id: ItemId, window: &mut Window, cx: &mut App, ) -> Task>> { + let deserialized_active_page = + match COMPONENT_PREVIEW_DB.get_active_page(item_id, workspace_id) { + Ok(page) => { + if let Some(page) = page { + ActivePageId(page) + } else { + ActivePageId::default() + } + } + Err(_) => ActivePageId::default(), + }; + let user_store = project.read(cx).user_store().clone(); let language_registry = project.read(cx).languages().clone(); + let preview_page = if deserialized_active_page.0 == ActivePageId::default().0 { + Some(PreviewPage::default()) + } else { + let component_str = deserialized_active_page.0; + let component_registry = components(); + let all_components = component_registry.all(); + let found_component = all_components.iter().find(|c| c.id().0 == component_str); + + if let Some(component) = found_component { + Some(PreviewPage::Component(component.id().clone())) + } else { + Some(PreviewPage::default()) + } + }; window.spawn(cx, async move |cx| { let user_store = user_store.clone(); let language_registry = language_registry.clone(); let weak_workspace = workspace.clone(); - cx.update(|_, cx| { + cx.update(move |window, cx| { Ok(cx.new(|cx| { ComponentPreview::new( weak_workspace, language_registry, user_store, None, - None, + preview_page, + window, cx, ) })) @@ -581,34 +849,41 @@ impl SerializableItem for ComponentPreview { } fn cleanup( - _workspace_id: 
WorkspaceId, - _alive_items: Vec, + workspace_id: WorkspaceId, + alive_items: Vec, _window: &mut Window, - _cx: &mut App, + cx: &mut App, ) -> Task> { - Task::ready(Ok(())) - // window.spawn(cx, |_| { - // ... - // }) + cx.background_spawn(async move { + COMPONENT_PREVIEW_DB + .delete_unloaded_items(workspace_id, alive_items) + .await + }) } fn serialize( &mut self, _workspace: &mut Workspace, - _item_id: ItemId, + item_id: ItemId, _closing: bool, _window: &mut Window, - _cx: &mut Context, + cx: &mut Context, ) -> Option>> { - // TODO: Serialize the active index so we can re-open to the same place - None + let active_page = self.active_page_id(cx); + let workspace_id = self.workspace_id?; + Some(cx.background_spawn(async move { + COMPONENT_PREVIEW_DB + .save_active_page(item_id, workspace_id, active_page.0) + .await + })) } - fn should_serialize(&self, _event: &Self::Event) -> bool { - false + fn should_serialize(&self, event: &Self::Event) -> bool { + matches!(event, ItemEvent::UpdateTab) } } +// TODO: use language registry to allow rendering markdown #[derive(IntoElement)] pub struct ComponentPreviewPage { // languages: Arc, diff --git a/crates/component_preview/src/persistence.rs b/crates/component_preview/src/persistence.rs new file mode 100644 index 0000000000..a3fb0c698b --- /dev/null +++ b/crates/component_preview/src/persistence.rs @@ -0,0 +1,73 @@ +use anyhow::Result; +use db::{define_connection, query, sqlez::statement::Statement, sqlez_macros::sql}; +use workspace::{ItemId, WorkspaceDb, WorkspaceId}; + +define_connection! 
{ + pub static ref COMPONENT_PREVIEW_DB: ComponentPreviewDb = + &[sql!( + CREATE TABLE component_previews ( + workspace_id INTEGER, + item_id INTEGER UNIQUE, + active_page_id TEXT, + PRIMARY KEY(workspace_id, item_id), + FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) + ON DELETE CASCADE + ) STRICT; + )]; +} + +impl ComponentPreviewDb { + pub async fn save_active_page( + &self, + item_id: ItemId, + workspace_id: WorkspaceId, + active_page_id: String, + ) -> Result<()> { + let query = "INSERT INTO component_previews(item_id, workspace_id, active_page_id) + VALUES (?1, ?2, ?3) + ON CONFLICT DO UPDATE SET + active_page_id = ?3"; + self.write(move |conn| { + let mut statement = Statement::prepare(conn, query)?; + let mut next_index = statement.bind(&item_id, 1)?; + next_index = statement.bind(&workspace_id, next_index)?; + statement.bind(&active_page_id, next_index)?; + statement.exec() + }) + .await + } + + query! { + pub fn get_active_page(item_id: ItemId, workspace_id: WorkspaceId) -> Result> { + SELECT active_page_id + FROM component_previews + WHERE item_id = ? AND workspace_id = ? + } + } + + pub async fn delete_unloaded_items( + &self, + workspace: WorkspaceId, + alive_items: Vec, + ) -> Result<()> { + let placeholders = alive_items + .iter() + .map(|_| "?") + .collect::>() + .join(", "); + + let query = format!( + "DELETE FROM component_previews WHERE workspace_id = ? 
AND item_id NOT IN ({placeholders})" + ); + + self.write(move |conn| { + let mut statement = Statement::prepare(conn, query)?; + let mut next_index = statement.bind(&workspace, 1)?; + for id in alive_items { + next_index = statement.bind(&id, next_index)?; + } + statement.exec() + }) + .await + } +} diff --git a/crates/context_server/src/context_server_tool.rs b/crates/context_server/src/context_server_tool.rs index cb04ee6804..2a82819be3 100644 --- a/crates/context_server/src/context_server_tool.rs +++ b/crates/context_server/src/context_server_tool.rs @@ -53,16 +53,18 @@ impl Tool for ContextServerTool { true } - fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> serde_json::Value { - match &self.tool.input_schema { + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + let mut schema = self.tool.input_schema.clone(); + assistant_tool::adapt_schema_to_format(&mut schema, format)?; + Ok(match schema { serde_json::Value::Null => { serde_json::json!({ "type": "object", "properties": [] }) } serde_json::Value::Object(map) if map.is_empty() => { serde_json::json!({ "type": "object", "properties": [] }) } - _ => self.tool.input_schema.clone(), - } + _ => schema, + }) } fn ui_text(&self, _input: &serde_json::Value) -> String { diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs index e2c1499838..bc9308f7fe 100644 --- a/crates/copilot/src/copilot_chat.rs +++ b/crates/copilot/src/copilot_chat.rs @@ -33,6 +33,8 @@ pub enum Model { Gpt4o, #[serde(alias = "gpt-4", rename = "gpt-4")] Gpt4, + #[serde(alias = "gpt-4.1", rename = "gpt-4.1")] + Gpt4_1, #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")] Gpt3_5Turbo, #[serde(alias = "o1", rename = "o1")] @@ -50,6 +52,8 @@ pub enum Model { Claude3_7SonnetThinking, #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")] Gemini20Flash, + #[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")] + Gemini25Pro, } impl Model { @@ -57,11 +61,12 @@ 
impl Model { match self { Self::Gpt4o | Self::Gpt4 + | Self::Gpt4_1 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet | Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => true, - Self::O3Mini | Self::O1 | Self::Gemini20Flash => false, + Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false, } } @@ -69,6 +74,7 @@ impl Model { match id { "gpt-4o" => Ok(Self::Gpt4o), "gpt-4" => Ok(Self::Gpt4), + "gpt-4.1" => Ok(Self::Gpt4_1), "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo), "o1" => Ok(Self::O1), "o3-mini" => Ok(Self::O3Mini), @@ -76,6 +82,7 @@ impl Model { "claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet), "claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking), "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash), + "gemini-2.5-pro" => Ok(Self::Gemini25Pro), _ => Err(anyhow!("Invalid model id: {}", id)), } } @@ -84,6 +91,7 @@ impl Model { match self { Self::Gpt3_5Turbo => "gpt-3.5-turbo", Self::Gpt4 => "gpt-4", + Self::Gpt4_1 => "gpt-4.1", Self::Gpt4o => "gpt-4o", Self::O3Mini => "o3-mini", Self::O1 => "o1", @@ -91,6 +99,7 @@ impl Model { Self::Claude3_7Sonnet => "claude-3-7-sonnet", Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought", Self::Gemini20Flash => "gemini-2.0-flash-001", + Self::Gemini25Pro => "gemini-2.5-pro", } } @@ -98,6 +107,7 @@ impl Model { match self { Self::Gpt3_5Turbo => "GPT-3.5", Self::Gpt4 => "GPT-4", + Self::Gpt4_1 => "GPT-4.1", Self::Gpt4o => "GPT-4o", Self::O3Mini => "o3-mini", Self::O1 => "o1", @@ -105,6 +115,7 @@ impl Model { Self::Claude3_7Sonnet => "Claude 3.7 Sonnet", Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking", Self::Gemini20Flash => "Gemini 2.0 Flash", + Self::Gemini25Pro => "Gemini 2.5 Pro", } } @@ -112,13 +123,15 @@ impl Model { match self { Self::Gpt4o => 64_000, Self::Gpt4 => 32_768, + Self::Gpt4_1 => 1_047_576, Self::Gpt3_5Turbo => 12_288, Self::O3Mini => 64_000, Self::O1 => 20_000, Self::Claude3_5Sonnet => 200_000, Self::Claude3_7Sonnet => 90_000, Self::Claude3_7SonnetThinking => 
90_000, - Model::Gemini20Flash => 128_000, + Self::Gemini20Flash => 128_000, + Self::Gemini25Pro => 128_000, } } } diff --git a/crates/dap/Cargo.toml b/crates/dap/Cargo.toml index 5a44d6b946..0fdd19c93e 100644 --- a/crates/dap/Cargo.toml +++ b/crates/dap/Cargo.toml @@ -39,7 +39,6 @@ log.workspace = true node_runtime.workspace = true parking_lot.workspace = true paths.workspace = true -regex.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/dap/src/adapters.rs b/crates/dap/src/adapters.rs index 175fdd8c2d..5f1083d0d6 100644 --- a/crates/dap/src/adapters.rs +++ b/crates/dap/src/adapters.rs @@ -20,7 +20,7 @@ use std::{ net::Ipv4Addr, ops::Deref, path::PathBuf, - sync::{Arc, LazyLock}, + sync::Arc, }; use task::{DebugAdapterConfig, DebugTaskDefinition}; use util::ResultExt; @@ -291,14 +291,7 @@ pub trait DebugAdapter: 'static + Send + Sync { /// Should return base configuration to make the debug adapter work fn request_args(&self, config: &DebugTaskDefinition) -> Value; - - fn attach_processes_filter(&self) -> regex::Regex { - EMPTY_REGEX.clone() - } } - -static EMPTY_REGEX: LazyLock = - LazyLock::new(|| regex::Regex::new("").expect("Regex compilation to succeed")); #[cfg(any(test, feature = "test-support"))] pub struct FakeAdapter {} @@ -375,10 +368,4 @@ impl DebugAdapter for FakeAdapter { }, }) } - - fn attach_processes_filter(&self) -> regex::Regex { - static REGEX: LazyLock = - LazyLock::new(|| regex::Regex::new("^fake-binary").unwrap()); - REGEX.clone() - } } diff --git a/crates/dap/src/registry.rs b/crates/dap/src/registry.rs index 2a3f0869fb..b6c8efea40 100644 --- a/crates/dap/src/registry.rs +++ b/crates/dap/src/registry.rs @@ -8,7 +8,7 @@ struct DapRegistryState { adapters: BTreeMap>, } -#[derive(Default)] +#[derive(Clone, Default)] /// Stores available debug adapters. 
pub struct DapRegistry(Arc>); diff --git a/crates/dap_adapters/Cargo.toml b/crates/dap_adapters/Cargo.toml index 40ca634a26..0a11724aa2 100644 --- a/crates/dap_adapters/Cargo.toml +++ b/crates/dap_adapters/Cargo.toml @@ -27,7 +27,6 @@ dap.workspace = true gpui.workspace = true language.workspace = true paths.workspace = true -regex.workspace = true serde.workspace = true serde_json.workspace = true task.workspace = true diff --git a/crates/dap_adapters/src/dap_adapters.rs b/crates/dap_adapters/src/dap_adapters.rs index 320b5336fc..f6c6f7844c 100644 --- a/crates/dap_adapters/src/dap_adapters.rs +++ b/crates/dap_adapters/src/dap_adapters.rs @@ -31,7 +31,7 @@ pub fn init(registry: Arc) { registry.add_adapter(Arc::from(CodeLldbDebugAdapter::default())); registry.add_adapter(Arc::from(PythonDebugAdapter)); registry.add_adapter(Arc::from(PhpDebugAdapter)); - registry.add_adapter(Arc::from(JsDebugAdapter::default())); + registry.add_adapter(Arc::from(JsDebugAdapter)); registry.add_adapter(Arc::from(LldbDebugAdapter)); registry.add_adapter(Arc::from(GoDebugAdapter)); registry.add_adapter(Arc::from(GdbDebugAdapter)); diff --git a/crates/dap_adapters/src/javascript.rs b/crates/dap_adapters/src/javascript.rs index 5022f0ac76..11dee971b1 100644 --- a/crates/dap_adapters/src/javascript.rs +++ b/crates/dap_adapters/src/javascript.rs @@ -1,24 +1,13 @@ use adapters::latest_github_release; use gpui::AsyncApp; -use regex::Regex; use std::path::PathBuf; use task::{DebugRequestType, DebugTaskDefinition}; use crate::*; #[derive(Debug)] -pub(crate) struct JsDebugAdapter { - attach_processes: Regex, -} +pub(crate) struct JsDebugAdapter; -impl Default for JsDebugAdapter { - fn default() -> Self { - Self { - attach_processes: Regex::new(r"(?i)^(?:node|bun|iojs)(?:$|\b)") - .expect("Regex compilation to succeed"), - } - } -} impl JsDebugAdapter { const ADAPTER_NAME: &'static str = "JavaScript"; const ADAPTER_NPM_NAME: &'static str = "vscode-js-debug"; @@ -149,8 +138,4 @@ impl DebugAdapter 
for JsDebugAdapter { } args } - - fn attach_processes_filter(&self) -> Regex { - self.attach_processes.clone() - } } diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs index c9d994d34d..d501368c85 100644 --- a/crates/db/src/kvp.rs +++ b/crates/db/src/kvp.rs @@ -1,6 +1,7 @@ use sqlez_macros::sql; use crate::{define_connection, query}; +pub static DEBUGGER_PANEL_PREFIX: &str = "debugger_panel_"; define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> = &[sql!( diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index 239c47d07d..ac180f308f 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -28,6 +28,7 @@ client.workspace = true collections.workspace = true command_palette_hooks.workspace = true dap.workspace = true +db.workspace = true editor.workspace = true feature_flags.workspace = true futures.workspace = true diff --git a/crates/debugger_ui/src/attach_modal.rs b/crates/debugger_ui/src/attach_modal.rs index 0ba8efc5e3..870d09aa3f 100644 --- a/crates/debugger_ui/src/attach_modal.rs +++ b/crates/debugger_ui/src/attach_modal.rs @@ -4,7 +4,6 @@ use gpui::Subscription; use gpui::{DismissEvent, Entity, EventEmitter, Focusable, Render}; use picker::{Picker, PickerDelegate}; -use std::cell::LazyCell; use std::sync::Arc; use sysinfo::System; use ui::{Context, Tooltip, prelude::*}; @@ -24,7 +23,7 @@ pub(crate) struct AttachModalDelegate { matches: Vec, placeholder_text: Arc, project: Entity, - debug_config: task::DebugTaskDefinition, + pub(crate) debug_config: task::DebugTaskDefinition, candidates: Arc<[Candidate]>, } @@ -58,7 +57,7 @@ impl AttachModal { window: &mut Window, cx: &mut Context, ) -> Self { - let mut processes: Vec<_> = System::new_all() + let mut processes: Box<[_]> = System::new_all() .processes() .values() .map(|process| { @@ -75,30 +74,18 @@ impl AttachModal { }) .collect(); processes.sort_by_key(|k| k.name.clone()); + let processes = processes.into_iter().collect(); 
Self::with_processes(project, debug_config, processes, modal, window, cx) } pub(super) fn with_processes( project: Entity, debug_config: task::DebugTaskDefinition, - processes: Vec, + processes: Arc<[Candidate]>, modal: bool, window: &mut Window, cx: &mut Context, ) -> Self { - let adapter = project - .read(cx) - .debug_adapters() - .adapter(&debug_config.adapter); - let filter = LazyCell::new(|| adapter.map(|adapter| adapter.attach_processes_filter())); - let processes = processes - .into_iter() - .filter(|process| { - filter - .as_ref() - .map_or(false, |filter| filter.is_match(&process.name)) - }) - .collect(); let picker = cx.new(|cx| { Picker::uniform_list( AttachModalDelegate::new(project, debug_config, processes), @@ -117,9 +104,10 @@ impl AttachModal { } impl Render for AttachModal { - fn render(&mut self, _window: &mut Window, _: &mut Context) -> impl ui::IntoElement { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl ui::IntoElement { v_flex() .key_context("AttachModal") + .track_focus(&self.focus_handle(cx)) .w(rems(34.)) .child(self.picker.clone()) } diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index d9327711f9..9cec74d27e 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -1,6 +1,6 @@ use crate::{ ClearAllBreakpoints, Continue, CreateDebuggingSession, Disconnect, Pause, Restart, StepBack, - StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints, + StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints, persistence, }; use crate::{new_session_modal::NewSessionModal, session::DebugSession}; use anyhow::{Result, anyhow}; @@ -293,35 +293,49 @@ impl DebugPanel { ); }; - let Some(project) = self.project.upgrade() else { - return log::error!("Debug Panel out lived it's weak reference to Project"); - }; + let adapter_name = session.read(cx).adapter_name(); - if self - .sessions - .iter() - .any(|item| 
item.read(cx).session_id(cx) == *session_id) - { - // We already have an item for this session. - return; - } - let session_item = DebugSession::running( - project, - self.workspace.clone(), - session, - cx.weak_entity(), - window, - cx, - ); + let session_id = *session_id; + cx.spawn_in(window, async move |this, cx| { + let serialized_layout = + persistence::get_serialized_pane_layout(adapter_name).await; - if let Some(running) = session_item.read(cx).mode().as_running().cloned() { - // We might want to make this an event subscription and only notify when a new thread is selected - // This is used to filter the command menu correctly - cx.observe(&running, |_, _, cx| cx.notify()).detach(); - } + this.update_in(cx, |this, window, cx| { + let Some(project) = this.project.upgrade() else { + return log::error!( + "Debug Panel out lived it's weak reference to Project" + ); + }; - self.sessions.push(session_item.clone()); - self.activate_session(session_item, window, cx); + if this + .sessions + .iter() + .any(|item| item.read(cx).session_id(cx) == session_id) + { + // We already have an item for this session. 
+ return; + } + let session_item = DebugSession::running( + project, + this.workspace.clone(), + session, + cx.weak_entity(), + serialized_layout, + window, + cx, + ); + + if let Some(running) = session_item.read(cx).mode().as_running().cloned() { + // We might want to make this an event subscription and only notify when a new thread is selected + // This is used to filter the command menu correctly + cx.observe(&running, |_, _, cx| cx.notify()).detach(); + } + + this.sessions.push(session_item.clone()); + this.activate_session(session_item, window, cx); + }) + }) + .detach(); } dap_store::DapStoreEvent::RunInTerminal { title, @@ -415,32 +429,58 @@ impl DebugPanel { }) } - fn close_session(&mut self, entity_id: EntityId, cx: &mut Context) { + fn close_session(&mut self, entity_id: EntityId, window: &mut Window, cx: &mut Context) { let Some(session) = self .sessions .iter() .find(|other| entity_id == other.entity_id()) + .cloned() else { return; }; - session.update(cx, |session, cx| session.shutdown(cx)); + let session_id = session.update(cx, |this, cx| this.session_id(cx)); + let should_prompt = self + .project + .update(cx, |this, cx| { + let session = this.dap_store().read(cx).session_by_id(session_id); + session.map(|session| !session.read(cx).is_terminated()) + }) + .ok() + .flatten() + .unwrap_or_default(); - self.sessions.retain(|other| entity_id != other.entity_id()); - - if let Some(active_session_id) = self - .active_session - .as_ref() - .map(|session| session.entity_id()) - { - if active_session_id == entity_id { - self.active_session = self.sessions.first().cloned(); + cx.spawn_in(window, async move |this, cx| { + if should_prompt { + let response = cx.prompt( + gpui::PromptLevel::Warning, + "This Debug Session is still running. 
Are you sure you want to terminate it?", + None, + &["Yes", "No"], + ); + if response.await == Ok(1) { + return; + } } - } + session.update(cx, |session, cx| session.shutdown(cx)).ok(); + this.update(cx, |this, cx| { + this.sessions.retain(|other| entity_id != other.entity_id()); - cx.notify(); + if let Some(active_session_id) = this + .active_session + .as_ref() + .map(|session| session.entity_id()) + { + if active_session_id == entity_id { + this.active_session = this.sessions.first().cloned(); + } + } + cx.notify() + }) + .ok(); + }) + .detach(); } - fn sessions_drop_down_menu( &self, active_session: &Entity, @@ -487,8 +527,11 @@ impl DebugPanel { let weak = weak.clone(); move |_, window, cx| { weak.update(cx, |panel, cx| { - panel - .close_session(weak_session_id, cx); + panel.close_session( + weak_session_id, + window, + cx, + ); }) .ok(); context_menu diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index 7a89b275ea..55eea2315e 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -13,6 +13,7 @@ use workspace::{ShutdownDebugAdapters, Workspace}; pub mod attach_modal; pub mod debugger_panel; mod new_session_modal; +mod persistence; pub(crate) mod session; #[cfg(test)] diff --git a/crates/debugger_ui/src/new_session_modal.rs b/crates/debugger_ui/src/new_session_modal.rs index b97af68fd6..c6f10bae25 100644 --- a/crates/debugger_ui/src/new_session_modal.rs +++ b/crates/debugger_ui/src/new_session_modal.rs @@ -11,6 +11,7 @@ use gpui::{ App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle, WeakEntity, }; +use project::Project; use settings::Settings; use task::{DebugTaskDefinition, LaunchConfig}; use theme::ThemeSettings; @@ -59,7 +60,7 @@ impl NewSessionModal { debug_panel: WeakEntity, workspace: WeakEntity, window: &mut Window, - cx: &mut App, + cx: &mut Context, ) -> Self { let debugger = past_debug_definition .as_ref() @@ -171,25 
+172,13 @@ impl NewSessionModal { attach.update(cx, |this, cx| { if selected_debugger != this.debug_definition.adapter { this.debug_definition.adapter = selected_debugger.into(); - if let Some(project) = this - .workspace - .read_with(cx, |workspace, _| workspace.project().clone()) - .ok() - { - this.attach_picker = Some(cx.new(|cx| { - let modal = AttachModal::new( - project, - this.debug_definition.clone(), - false, - window, - cx, - ); - window.focus(&modal.focus_handle(cx)); - - modal - })); - } + this.attach_picker.update(cx, |this, cx| { + this.picker.update(cx, |this, cx| { + this.delegate.debug_config.adapter = selected_debugger.into(); + this.focus(window, cx); + }) + }); } cx.notify(); @@ -256,7 +245,6 @@ impl NewSessionModal { ContextMenu::build(window, cx, move |mut menu, _, cx| { let setter_for_name = |task: DebugTaskDefinition| { let weak = weak.clone(); - let workspace = workspace.clone(); move |window: &mut Window, cx: &mut App| { weak.update(cx, |this, cx| { this.last_selected_profile_name = Some(SharedString::from(&task.label)); @@ -271,12 +259,19 @@ impl NewSessionModal { ); } DebugRequestType::Attach(_) => { + let Ok(project) = this + .workspace + .read_with(cx, |this, _| this.project().clone()) + else { + return; + }; this.mode = NewSessionMode::attach( this.debugger.clone(), - workspace.clone(), + project, window, cx, ); + this.mode.focus_handle(cx).focus(window); if let Some((debugger, attach)) = this.debugger.as_ref().zip(this.mode.as_attach()) { @@ -365,18 +360,16 @@ impl LaunchMode { #[derive(Clone)] struct AttachMode { - workspace: WeakEntity, debug_definition: DebugTaskDefinition, - attach_picker: Option>, - focus_handle: FocusHandle, + attach_picker: Entity, } impl AttachMode { fn new( debugger: Option, - workspace: WeakEntity, + project: Entity, window: &mut Window, - cx: &mut App, + cx: &mut Context, ) -> Entity { let debug_definition = DebugTaskDefinition { label: "Attach New Session Setup".into(), @@ -387,27 +380,15 @@ impl 
AttachMode { initialize_args: None, stop_on_entry: Some(false), }; + let attach_picker = cx.new(|cx| { + let modal = AttachModal::new(project, debug_definition.clone(), false, window, cx); + window.focus(&modal.focus_handle(cx)); - let attach_picker = if let Some(project) = debugger.and( - workspace - .read_with(cx, |workspace, _| workspace.project().clone()) - .ok(), - ) { - Some(cx.new(|cx| { - let modal = AttachModal::new(project, debug_definition.clone(), false, window, cx); - window.focus(&modal.focus_handle(cx)); - - modal - })) - } else { - None - }; - - cx.new(|cx| Self { - workspace, + modal + }); + cx.new(|_| Self { debug_definition, attach_picker, - focus_handle: cx.focus_handle(), }) } fn debug_task(&self) -> task::AttachConfig { @@ -444,7 +425,7 @@ impl Focusable for NewSessionMode { fn focus_handle(&self, cx: &App) -> FocusHandle { match &self { NewSessionMode::Launch(entity) => entity.read(cx).program.focus_handle(cx), - NewSessionMode::Attach(entity) => entity.read(cx).focus_handle.clone(), + NewSessionMode::Attach(entity) => entity.read(cx).attach_picker.focus_handle(cx), } } } @@ -476,8 +457,11 @@ impl RenderOnce for LaunchMode { } impl RenderOnce for AttachMode { - fn render(self, _: &mut Window, _: &mut App) -> impl IntoElement { - v_flex().w_full().children(self.attach_picker.clone()) + fn render(self, _: &mut Window, cx: &mut App) -> impl IntoElement { + v_flex() + .w_full() + .track_focus(&self.attach_picker.focus_handle(cx)) + .child(self.attach_picker.clone()) } } @@ -497,13 +481,17 @@ impl RenderOnce for NewSessionMode { impl NewSessionMode { fn attach( debugger: Option, - workspace: WeakEntity, + project: Entity, window: &mut Window, - cx: &mut App, + cx: &mut Context, ) -> Self { - Self::Attach(AttachMode::new(debugger, workspace, window, cx)) + Self::Attach(AttachMode::new(debugger, project, window, cx)) } - fn launch(past_launch_config: Option, window: &mut Window, cx: &mut App) -> Self { + fn launch( + past_launch_config: Option, + 
window: &mut Window, + cx: &mut Context, + ) -> Self { Self::Launch(LaunchMode::new(past_launch_config, window, cx)) } } @@ -592,18 +580,25 @@ impl Render for NewSessionModal { .toggle_state(matches!(self.mode, NewSessionMode::Attach(_))) .style(ui::ButtonStyle::Subtle) .on_click(cx.listener(|this, _, window, cx| { + let Ok(project) = this + .workspace + .read_with(cx, |this, _| this.project().clone()) + else { + return; + }; this.mode = NewSessionMode::attach( this.debugger.clone(), - this.workspace.clone(), + project, window, cx, ); + this.mode.focus_handle(cx).focus(window); if let Some((debugger, attach)) = this.debugger.as_ref().zip(this.mode.as_attach()) { Self::update_attach_picker(&attach, &debugger, window, cx); } - this.mode.focus_handle(cx).focus(window); + cx.notify(); })) .last(), diff --git a/crates/debugger_ui/src/persistence.rs b/crates/debugger_ui/src/persistence.rs new file mode 100644 index 0000000000..0472675268 --- /dev/null +++ b/crates/debugger_ui/src/persistence.rs @@ -0,0 +1,259 @@ +use collections::HashMap; +use db::kvp::KEY_VALUE_STORE; +use gpui::{Axis, Context, Entity, EntityId, Focusable, Subscription, WeakEntity, Window}; +use project::Project; +use serde::{Deserialize, Serialize}; +use ui::{App, SharedString}; +use util::ResultExt; +use workspace::{Member, Pane, PaneAxis, Workspace}; + +use crate::session::running::{ + self, RunningState, SubView, breakpoint_list::BreakpointList, console::Console, + module_list::ModuleList, stack_frame_list::StackFrameList, variable_list::VariableList, +}; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) enum DebuggerPaneItem { + Console, + Variables, + BreakpointList, + Frames, + Modules, +} + +impl DebuggerPaneItem { + pub(crate) fn to_shared_string(self) -> SharedString { + match self { + DebuggerPaneItem::Console => SharedString::new_static("Console"), + DebuggerPaneItem::Variables => SharedString::new_static("Variables"), + DebuggerPaneItem::BreakpointList => 
SharedString::new_static("Breakpoints"), + DebuggerPaneItem::Frames => SharedString::new_static("Frames"), + DebuggerPaneItem::Modules => SharedString::new_static("Modules"), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct SerializedAxis(pub Axis); + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) enum SerializedPaneLayout { + Pane(SerializedPane), + Group { + axis: SerializedAxis, + flexes: Option>, + children: Vec, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct SerializedPane { + pub children: Vec, + pub active_item: Option, +} + +pub(crate) async fn serialize_pane_layout( + adapter_name: SharedString, + pane_group: SerializedPaneLayout, +) -> anyhow::Result<()> { + if let Ok(serialized_pane_group) = serde_json::to_string(&pane_group) { + KEY_VALUE_STORE + .write_kvp( + format!("{}-{adapter_name}", db::kvp::DEBUGGER_PANEL_PREFIX), + serialized_pane_group, + ) + .await + } else { + Err(anyhow::anyhow!( + "Failed to serialize pane group with serde_json as a string" + )) + } +} + +pub(crate) fn build_serialized_pane_layout( + pane_group: &Member, + cx: &mut App, +) -> SerializedPaneLayout { + match pane_group { + Member::Axis(PaneAxis { + axis, + members, + flexes, + bounding_boxes: _, + }) => SerializedPaneLayout::Group { + axis: SerializedAxis(*axis), + children: members + .iter() + .map(|member| build_serialized_pane_layout(member, cx)) + .collect::>(), + flexes: Some(flexes.lock().clone()), + }, + Member::Pane(pane_handle) => SerializedPaneLayout::Pane(serialize_pane(pane_handle, cx)), + } +} + +fn serialize_pane(pane: &Entity, cx: &mut App) -> SerializedPane { + let pane = pane.read(cx); + let children = pane + .items() + .filter_map(|item| { + item.act_as::(cx) + .map(|view| view.read(cx).view_kind()) + }) + .collect::>(); + + let active_item = pane + .active_item() + .and_then(|item| item.act_as::(cx)) + .map(|view| view.read(cx).view_kind()); + + SerializedPane { + children, + active_item, + } +} + 
+pub(crate) async fn get_serialized_pane_layout( + adapter_name: impl AsRef, +) -> Option { + let key = format!( + "{}-{}", + db::kvp::DEBUGGER_PANEL_PREFIX, + adapter_name.as_ref() + ); + + KEY_VALUE_STORE + .read_kvp(&key) + .log_err() + .flatten() + .and_then(|value| serde_json::from_str::(&value).ok()) +} + +pub(crate) fn deserialize_pane_layout( + serialized: SerializedPaneLayout, + workspace: &WeakEntity, + project: &Entity, + stack_frame_list: &Entity, + variable_list: &Entity, + module_list: &Entity, + console: &Entity, + breakpoint_list: &Entity, + subscriptions: &mut HashMap, + window: &mut Window, + cx: &mut Context, +) -> Option { + match serialized { + SerializedPaneLayout::Group { + axis, + flexes, + children, + } => { + let mut members = Vec::new(); + for child in children { + if let Some(new_member) = deserialize_pane_layout( + child, + workspace, + project, + stack_frame_list, + variable_list, + module_list, + console, + breakpoint_list, + subscriptions, + window, + cx, + ) { + members.push(new_member); + } + } + + if members.is_empty() { + return None; + } + + if members.len() == 1 { + return Some(members.remove(0)); + } + + Some(Member::Axis(PaneAxis::load( + axis.0, + members, + flexes.clone(), + ))) + } + SerializedPaneLayout::Pane(serialized_pane) => { + let pane = running::new_debugger_pane(workspace.clone(), project.clone(), window, cx); + subscriptions.insert( + pane.entity_id(), + cx.subscribe_in(&pane, window, RunningState::handle_pane_event), + ); + + let sub_views: Vec<_> = serialized_pane + .children + .iter() + .map(|child| match child { + DebuggerPaneItem::Frames => Box::new(SubView::new( + pane.focus_handle(cx), + stack_frame_list.clone().into(), + DebuggerPaneItem::Frames, + None, + cx, + )), + DebuggerPaneItem::Variables => Box::new(SubView::new( + variable_list.focus_handle(cx), + variable_list.clone().into(), + DebuggerPaneItem::Variables, + None, + cx, + )), + DebuggerPaneItem::BreakpointList => Box::new(SubView::new( + 
breakpoint_list.focus_handle(cx), + breakpoint_list.clone().into(), + DebuggerPaneItem::BreakpointList, + None, + cx, + )), + DebuggerPaneItem::Modules => Box::new(SubView::new( + pane.focus_handle(cx), + module_list.clone().into(), + DebuggerPaneItem::Modules, + None, + cx, + )), + + DebuggerPaneItem::Console => Box::new(SubView::new( + pane.focus_handle(cx), + console.clone().into(), + DebuggerPaneItem::Console, + Some(Box::new({ + let console = console.clone().downgrade(); + move |cx| { + console + .read_with(cx, |console, cx| console.show_indicator(cx)) + .unwrap_or_default() + } + })), + cx, + )), + }) + .collect(); + + pane.update(cx, |pane, cx| { + let mut active_idx = 0; + for (idx, sub_view) in sub_views.into_iter().enumerate() { + if serialized_pane + .active_item + .is_some_and(|active| active == sub_view.read(cx).view_kind()) + { + active_idx = idx; + } + pane.add_item(sub_view, false, false, None, window, cx); + } + + pane.activate_item(active_idx, false, false, window, cx); + }); + + Some(Member::Pane(pane.clone())) + } + } +} diff --git a/crates/debugger_ui/src/session.rs b/crates/debugger_ui/src/session.rs index c69a2259b2..93fbdc1111 100644 --- a/crates/debugger_ui/src/session.rs +++ b/crates/debugger_ui/src/session.rs @@ -16,6 +16,7 @@ use workspace::{ }; use crate::debugger_panel::DebugPanel; +use crate::persistence::SerializedPaneLayout; pub(crate) enum DebugSessionState { Running(Entity), @@ -52,6 +53,7 @@ impl DebugSession { workspace: WeakEntity, session: Entity, _debug_panel: WeakEntity, + serialized_pane_layout: Option, window: &mut Window, cx: &mut App, ) -> Entity { @@ -60,6 +62,7 @@ impl DebugSession { session.clone(), project.clone(), workspace.clone(), + serialized_pane_layout, window, cx, ) diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index b836a05db9..d3d3d5637e 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1,11 +1,13 @@ 
-mod breakpoint_list; -mod console; -mod loaded_source_list; -mod module_list; +pub(crate) mod breakpoint_list; +pub(crate) mod console; +pub(crate) mod loaded_source_list; +pub(crate) mod module_list; pub mod stack_frame_list; pub mod variable_list; -use std::{any::Any, ops::ControlFlow, sync::Arc}; +use std::{any::Any, ops::ControlFlow, sync::Arc, time::Duration}; + +use crate::persistence::{self, DebuggerPaneItem, SerializedPaneLayout}; use super::DebugPanelItemEvent; use breakpoint_list::BreakpointList; @@ -14,7 +16,7 @@ use console::Console; use dap::{Capabilities, Thread, client::SessionId, debugger_settings::DebuggerSettings}; use gpui::{ Action as _, AnyView, AppContext, Entity, EntityId, EventEmitter, FocusHandle, Focusable, - NoAction, Subscription, WeakEntity, + NoAction, Subscription, Task, WeakEntity, }; use loaded_source_list::LoadedSourceList; use module_list::ModuleList; @@ -33,8 +35,8 @@ use ui::{ use util::ResultExt; use variable_list::VariableList; use workspace::{ - ActivePaneDecorator, DraggedTab, Item, Pane, PaneGroup, Workspace, item::TabContentParams, - move_item, pane::Event, + ActivePaneDecorator, DraggedTab, Item, Member, Pane, PaneGroup, Workspace, + item::TabContentParams, move_item, pane::Event, }; pub struct RunningState { @@ -51,6 +53,7 @@ pub struct RunningState { _console: Entity, panes: PaneGroup, pane_close_subscriptions: HashMap, + _schedule_serialize: Option>, } impl Render for RunningState { @@ -84,28 +87,32 @@ impl Render for RunningState { } } -struct SubView { +pub(crate) struct SubView { inner: AnyView, pane_focus_handle: FocusHandle, - tab_name: SharedString, + kind: DebuggerPaneItem, show_indicator: Box bool>, } impl SubView { - fn new( + pub(crate) fn new( pane_focus_handle: FocusHandle, view: AnyView, - tab_name: SharedString, + kind: DebuggerPaneItem, show_indicator: Option bool>>, cx: &mut App, ) -> Entity { cx.new(|_| Self { - tab_name, + kind, inner: view, pane_focus_handle, show_indicator: 
show_indicator.unwrap_or(Box::new(|_| false)), }) } + + pub(crate) fn view_kind(&self) -> DebuggerPaneItem { + self.kind + } } impl Focusable for SubView { fn focus_handle(&self, _: &App) -> FocusHandle { @@ -116,13 +123,19 @@ impl EventEmitter<()> for SubView {} impl Item for SubView { type Event = (); + /// This is used to serialize debugger pane layouts + /// A SharedString gets converted to a enum and back during serialization/deserialization. + fn tab_content_text(&self, _window: &Window, _cx: &App) -> Option { + Some(self.kind.to_shared_string()) + } + fn tab_content( &self, params: workspace::item::TabContentParams, _: &Window, cx: &App, ) -> AnyElement { - let label = Label::new(self.tab_name.clone()) + let label = Label::new(self.kind.to_shared_string()) .size(ui::LabelSize::Small) .color(params.text_color()) .line_height_style(ui::LineHeightStyle::UiLabel); @@ -146,7 +159,7 @@ impl Render for SubView { } } -fn new_debugger_pane( +pub(crate) fn new_debugger_pane( workspace: WeakEntity, project: Entity, window: &mut Window, @@ -185,7 +198,7 @@ fn new_debugger_pane( new_debugger_pane(workspace.clone(), project.clone(), window, cx); let _previous_subscription = running.pane_close_subscriptions.insert( new_pane.entity_id(), - cx.subscribe(&new_pane, RunningState::handle_pane_event), + cx.subscribe_in(&new_pane, window, RunningState::handle_pane_event), ); debug_assert!(_previous_subscription.is_none()); running @@ -275,6 +288,8 @@ fn new_debugger_pane( let active_pane_item = pane.active_item(); h_flex() .w_full() + .px_2() + .gap_1() .h(Tab::container_height(cx)) .drag_over::(|bar, _, _, cx| { bar.bg(cx.theme().colors().drop_target_background) @@ -352,6 +367,7 @@ impl RunningState { session: Entity, project: Entity, workspace: WeakEntity, + serialized_pane_layout: Option, window: &mut Window, cx: &mut Context, ) -> Self { @@ -380,6 +396,8 @@ impl RunningState { ) }); + let breakpoints = BreakpointList::new(session.clone(), workspace.clone(), &project, cx); + 
let _subscriptions = vec![ cx.observe(&module_list, |_, _, cx| cx.notify()), cx.subscribe_in(&session, window, |this, _, event, window, cx| { @@ -405,112 +423,40 @@ impl RunningState { }), ]; - let leftmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); - leftmost_pane.update(cx, |this, cx| { - this.add_item( - Box::new(SubView::new( - this.focus_handle(cx), - stack_frame_list.clone().into(), - SharedString::new_static("Frames"), - None, - cx, - )), - true, - false, - None, + let mut pane_close_subscriptions = HashMap::default(); + let panes = if let Some(root) = serialized_pane_layout.and_then(|serialized_layout| { + persistence::deserialize_pane_layout( + serialized_layout, + &workspace, + &project, + &stack_frame_list, + &variable_list, + &module_list, + &console, + &breakpoints, + &mut pane_close_subscriptions, + window, + cx, + ) + }) { + workspace::PaneGroup::with_root(root) + } else { + pane_close_subscriptions.clear(); + let root = Self::default_pane_layout( + project, + &workspace, + &stack_frame_list, + &variable_list, + &module_list, + &console, + breakpoints, + &mut pane_close_subscriptions, window, cx, ); - let breakpoints = BreakpointList::new(session.clone(), workspace.clone(), &project, cx); - this.add_item( - Box::new(SubView::new( - breakpoints.focus_handle(cx), - breakpoints.into(), - SharedString::new_static("Breakpoints"), - None, - cx, - )), - true, - false, - None, - window, - cx, - ); - this.activate_item(0, false, false, window, cx); - }); - let center_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); - center_pane.update(cx, |this, cx| { - this.add_item( - Box::new(SubView::new( - variable_list.focus_handle(cx), - variable_list.clone().into(), - SharedString::new_static("Variables"), - None, - cx, - )), - true, - false, - None, - window, - cx, - ); - this.add_item( - Box::new(SubView::new( - this.focus_handle(cx), - module_list.clone().into(), - SharedString::new_static("Modules"), - None, 
- cx, - )), - false, - false, - None, - window, - cx, - ); - this.activate_item(0, false, false, window, cx); - }); - let rightmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); - rightmost_pane.update(cx, |this, cx| { - let weak_console = console.downgrade(); - this.add_item( - Box::new(SubView::new( - this.focus_handle(cx), - console.clone().into(), - SharedString::new_static("Console"), - Some(Box::new(move |cx| { - weak_console - .read_with(cx, |console, cx| console.show_indicator(cx)) - .unwrap_or_default() - })), - cx, - )), - true, - false, - None, - window, - cx, - ); - }); - let pane_close_subscriptions = HashMap::from_iter( - [&leftmost_pane, ¢er_pane, &rightmost_pane] - .into_iter() - .map(|entity| { - ( - entity.entity_id(), - cx.subscribe(entity, Self::handle_pane_event), - ) - }), - ); - let group_root = workspace::PaneAxis::new( - gpui::Axis::Horizontal, - [leftmost_pane, center_pane, rightmost_pane] - .into_iter() - .map(workspace::Member::Pane) - .collect(), - ); - let panes = PaneGroup::with_root(workspace::Member::Axis(group_root)); + workspace::PaneGroup::with_root(root) + }; Self { session, @@ -526,21 +472,57 @@ impl RunningState { _module_list: module_list, _console: console, pane_close_subscriptions, + _schedule_serialize: None, } } - fn handle_pane_event( + fn serialize_layout(&mut self, window: &mut Window, cx: &mut Context) { + if self._schedule_serialize.is_none() { + self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| { + cx.background_executor() + .timer(Duration::from_millis(100)) + .await; + + let Some((adapter_name, pane_group)) = this + .update(cx, |this, cx| { + let adapter_name = this.session.read(cx).adapter_name(); + ( + adapter_name, + persistence::build_serialized_pane_layout(&this.panes.root, cx), + ) + }) + .ok() + else { + return; + }; + + persistence::serialize_pane_layout(adapter_name, pane_group) + .await + .log_err(); + + this.update(cx, |this, _| { + 
this._schedule_serialize.take(); + }) + .ok(); + })); + } + } + + pub(crate) fn handle_pane_event( this: &mut RunningState, - source_pane: Entity, + source_pane: &Entity, event: &Event, + window: &mut Window, cx: &mut Context, ) { + this.serialize_layout(window, cx); if let Event::Remove { .. } = event { let _did_find_pane = this.panes.remove(&source_pane).is_ok(); debug_assert!(_did_find_pane); cx.notify(); } } + pub(crate) fn go_to_selected_stack_frame(&self, window: &Window, cx: &mut Context) { if self.thread_id.is_some() { self.stack_frame_list @@ -584,7 +566,7 @@ impl RunningState { .find_map(|pane| { pane.read(cx) .items_of_type::() - .position(|view| view.read(cx).tab_name == *"Modules") + .position(|view| view.read(cx).view_kind().to_shared_string() == *"Modules") .map(|view| (view, pane)) }) .unwrap(); @@ -800,6 +782,127 @@ impl RunningState { }), ) } + + fn default_pane_layout( + project: Entity, + workspace: &WeakEntity, + stack_frame_list: &Entity, + variable_list: &Entity, + module_list: &Entity, + console: &Entity, + breakpoints: Entity, + subscriptions: &mut HashMap, + window: &mut Window, + cx: &mut Context<'_, RunningState>, + ) -> Member { + let leftmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); + leftmost_pane.update(cx, |this, cx| { + this.add_item( + Box::new(SubView::new( + this.focus_handle(cx), + stack_frame_list.clone().into(), + DebuggerPaneItem::Frames, + None, + cx, + )), + true, + false, + None, + window, + cx, + ); + this.add_item( + Box::new(SubView::new( + breakpoints.focus_handle(cx), + breakpoints.into(), + DebuggerPaneItem::BreakpointList, + None, + cx, + )), + true, + false, + None, + window, + cx, + ); + this.activate_item(0, false, false, window, cx); + }); + let center_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); + center_pane.update(cx, |this, cx| { + this.add_item( + Box::new(SubView::new( + variable_list.focus_handle(cx), + variable_list.clone().into(), + 
DebuggerPaneItem::Variables, + None, + cx, + )), + true, + false, + None, + window, + cx, + ); + this.add_item( + Box::new(SubView::new( + this.focus_handle(cx), + module_list.clone().into(), + DebuggerPaneItem::Modules, + None, + cx, + )), + false, + false, + None, + window, + cx, + ); + this.activate_item(0, false, false, window, cx); + }); + let rightmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx); + rightmost_pane.update(cx, |this, cx| { + let weak_console = console.downgrade(); + this.add_item( + Box::new(SubView::new( + this.focus_handle(cx), + console.clone().into(), + DebuggerPaneItem::Console, + Some(Box::new(move |cx| { + weak_console + .read_with(cx, |console, cx| console.show_indicator(cx)) + .unwrap_or_default() + })), + cx, + )), + true, + false, + None, + window, + cx, + ); + }); + + subscriptions.extend( + [&leftmost_pane, ¢er_pane, &rightmost_pane] + .into_iter() + .map(|entity| { + ( + entity.entity_id(), + cx.subscribe_in(entity, window, Self::handle_pane_event), + ) + }), + ); + + let group_root = workspace::PaneAxis::new( + gpui::Axis::Horizontal, + [leftmost_pane, center_pane, rightmost_pane] + .into_iter() + .map(workspace::Member::Pane) + .collect(), + ); + + Member::Axis(group_root) + } } impl EventEmitter for RunningState {} diff --git a/crates/debugger_ui/src/session/running/breakpoint_list.rs b/crates/debugger_ui/src/session/running/breakpoint_list.rs index ff7ead4123..d9972284fe 100644 --- a/crates/debugger_ui/src/session/running/breakpoint_list.rs +++ b/crates/debugger_ui/src/session/running/breakpoint_list.rs @@ -27,7 +27,7 @@ use ui::{ use util::{ResultExt, maybe}; use workspace::Workspace; -pub(super) struct BreakpointList { +pub(crate) struct BreakpointList { workspace: WeakEntity, breakpoint_store: Entity, worktree_store: Entity, diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index 12dbef1722..bbb6e1f684 100644 --- 
a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -321,11 +321,15 @@ impl StackFrameList { let source = stack_frame.source.clone(); let is_selected_frame = Some(stack_frame.id) == self.selected_stack_frame_id; - let formatted_path = format!( - "{}:{}", - source.clone().and_then(|s| s.name).unwrap_or_default(), - stack_frame.line, - ); + let path = source.clone().and_then(|s| s.path.or(s.name)); + let formatted_path = path.map(|path| format!("{}:{}", path, stack_frame.line,)); + let formatted_path = formatted_path.map(|path| { + Label::new(path) + .size(LabelSize::XSmall) + .line_height_style(LineHeightStyle::UiLabel) + .truncate() + .color(Color::Muted) + }); let supports_frame_restart = self .session @@ -334,32 +338,19 @@ impl StackFrameList { .supports_restart_frame .unwrap_or_default(); - let origin = stack_frame - .source - .to_owned() - .and_then(|source| source.origin); - + let should_deemphasize = matches!( + stack_frame.presentation_hint, + Some( + dap::StackFramePresentationHint::Subtle + | dap::StackFramePresentationHint::Deemphasize + ) + ); h_flex() .rounded_md() .justify_between() .w_full() .group("") .id(("stack-frame", stack_frame.id)) - .tooltip({ - let formatted_path = formatted_path.clone(); - move |_window, app| { - app.new(|_| { - let mut tooltip = Tooltip::new(formatted_path.clone()); - - if let Some(origin) = &origin { - tooltip = tooltip.meta(origin); - } - - tooltip - }) - .into() - } - }) .p_1() .when(is_selected_frame, |this| { this.bg(cx.theme().colors().element_hover) @@ -374,21 +365,14 @@ impl StackFrameList { .hover(|style| style.bg(cx.theme().colors().element_hover).cursor_pointer()) .child( v_flex() + .gap_0p5() .child( - h_flex() - .gap_0p5() - .text_ui_sm(cx) + Label::new(stack_frame.name.clone()) + .size(LabelSize::Small) .truncate() - .child(stack_frame.name.clone()) - .child(formatted_path), + .when(should_deemphasize, |this| 
this.color(Color::Muted)), ) - .child( - h_flex() - .text_ui_xs(cx) - .truncate() - .text_color(cx.theme().colors().text_muted) - .when_some(source.and_then(|s| s.path), |this, path| this.child(path)), - ), + .children(formatted_path), ) .when( supports_frame_restart && stack_frame.can_restart.unwrap_or(true), diff --git a/crates/debugger_ui/src/tests.rs b/crates/debugger_ui/src/tests.rs index 0db43bb379..6fbe84a9e5 100644 --- a/crates/debugger_ui/src/tests.rs +++ b/crates/debugger_ui/src/tests.rs @@ -68,6 +68,7 @@ pub async fn init_test_workspace( workspace_handle } +#[track_caller] pub fn active_debug_session_panel( workspace: WindowHandle, cx: &mut TestAppContext, diff --git a/crates/debugger_ui/src/tests/attach_modal.rs b/crates/debugger_ui/src/tests/attach_modal.rs index 868191f22d..0c7465ca26 100644 --- a/crates/debugger_ui/src/tests/attach_modal.rs +++ b/crates/debugger_ui/src/tests/attach_modal.rs @@ -100,7 +100,7 @@ async fn test_show_attach_modal_and_select_process( }, Candidate { pid: 3, - name: "non-fake-binary-1".into(), + name: "real-binary-1".into(), command: vec![], }, Candidate { @@ -108,7 +108,9 @@ async fn test_show_attach_modal_and_select_process( name: "fake-binary-2".into(), command: vec![], }, - ], + ] + .into_iter() + .collect(), true, window, cx, @@ -121,17 +123,30 @@ async fn test_show_attach_modal_and_select_process( cx.run_until_parked(); + // assert we got the expected processes + workspace + .update(cx, |_, window, cx| { + let names = + attach_modal.update(cx, |modal, cx| attach_modal::_process_names(&modal, cx)); + // Initially all processes are visible. 
+ assert_eq!(3, names.len()); + attach_modal.update(cx, |this, cx| { + this.picker.update(cx, |this, cx| { + this.set_query("fakb", window, cx); + }) + }) + }) + .unwrap(); + cx.run_until_parked(); // assert we got the expected processes workspace .update(cx, |_, _, cx| { let names = attach_modal.update(cx, |modal, cx| attach_modal::_process_names(&modal, cx)); - - // we filtered out all processes that are not starting with `fake-binary` + // Initially all processes are visible. assert_eq!(2, names.len()); }) .unwrap(); - // select the only existing process cx.dispatch_action(Confirm); diff --git a/crates/debugger_ui/src/tests/debugger_panel.rs b/crates/debugger_ui/src/tests/debugger_panel.rs index 0dd83f7143..a19d852a85 100644 --- a/crates/debugger_ui/src/tests/debugger_panel.rs +++ b/crates/debugger_ui/src/tests/debugger_panel.rs @@ -81,6 +81,8 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test }) .await; + cx.run_until_parked(); + // assert we have a debug panel item before the session has stopped workspace .update(cx, |workspace, _window, cx| { @@ -229,6 +231,8 @@ async fn test_we_can_only_have_one_panel_per_debug_session( }) .await; + cx.run_until_parked(); + // assert we have a debug panel item before the session has stopped workspace .update(cx, |workspace, _window, cx| { @@ -1052,6 +1056,8 @@ async fn test_debug_panel_item_thread_status_reset_on_failure( })) .await; + cx.run_until_parked(); + let running_state = active_debug_session_panel(workspace, cx).update_in(cx, |item, _, _| { item.mode() .as_running() diff --git a/crates/debugger_ui/src/tests/variable_list.rs b/crates/debugger_ui/src/tests/variable_list.rs index 8c352731fb..c4001cc5e2 100644 --- a/crates/debugger_ui/src/tests/variable_list.rs +++ b/crates/debugger_ui/src/tests/variable_list.rs @@ -1538,6 +1538,8 @@ async fn test_variable_list_only_sends_requests_when_rendering( }) .await; + cx.run_until_parked(); + let running_state = 
active_debug_session_panel(workspace, cx).update_in(cx, |item, _, _| { let state = item .mode() diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index ecc0823eb1..5ee1492b5b 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -99,6 +99,9 @@ pub struct ComposeCompletion { pub struct ConfirmCodeAction { #[serde(default)] pub item_ix: Option, + #[serde(default)] + #[serde(skip)] + pub from_mouse_context_menu: bool, } #[derive(PartialEq, Clone, Deserialize, Default, JsonSchema)] diff --git a/crates/editor/src/code_context_menus.rs b/crates/editor/src/code_context_menus.rs index caf555bc30..de8416a57f 100644 --- a/crates/editor/src/code_context_menus.rs +++ b/crates/editor/src/code_context_menus.rs @@ -774,7 +774,7 @@ pub struct AvailableCodeAction { pub provider: Rc, } -#[derive(Clone)] +#[derive(Clone, Default)] pub struct CodeActionContents { pub tasks: Option>, pub actions: Option>, @@ -790,7 +790,7 @@ impl CodeActionContents { } } - fn is_empty(&self) -> bool { + pub fn is_empty(&self) -> bool { match (&self.tasks, &self.actions) { (Some(tasks), Some(actions)) => actions.is_empty() && tasks.templates.is_empty(), (Some(tasks), None) => tasks.templates.is_empty(), @@ -799,7 +799,7 @@ impl CodeActionContents { } } - fn iter(&self) -> impl Iterator + '_ { + pub fn iter(&self) -> impl Iterator + '_ { self.tasks .iter() .flat_map(|tasks| { @@ -867,14 +867,14 @@ pub enum CodeActionsItem { } impl CodeActionsItem { - fn as_task(&self) -> Option<&ResolvedTask> { + pub fn as_task(&self) -> Option<&ResolvedTask> { let Self::Task(_, task) = self else { return None; }; Some(task) } - fn as_code_action(&self) -> Option<&CodeAction> { + pub fn as_code_action(&self) -> Option<&CodeAction> { let Self::CodeAction { action, .. 
} = self else { return None; }; @@ -1014,6 +1014,7 @@ impl CodeActionsMenu { if let Some(task) = editor.confirm_code_action( &ConfirmCodeAction { item_ix: Some(item_ix), + from_mouse_context_menu: false, }, window, cx, @@ -1039,6 +1040,7 @@ impl CodeActionsMenu { if let Some(task) = editor.confirm_code_action( &ConfirmCodeAction { item_ix: Some(item_ix), + from_mouse_context_menu: false, }, window, cx, diff --git a/crates/editor/src/display_map.rs b/crates/editor/src/display_map.rs index 896ee0be81..d94d9cb51d 100644 --- a/crates/editor/src/display_map.rs +++ b/crates/editor/src/display_map.rs @@ -1742,7 +1742,6 @@ pub mod tests { } } - #[cfg(target_os = "macos")] #[gpui::test(retries = 5)] async fn test_soft_wraps(cx: &mut gpui::TestAppContext) { cx.background_executor @@ -1760,7 +1759,7 @@ pub mod tests { editor.update(cx, |editor, _cx| editor.text_layout_details(window)); let font_size = px(12.0); - let wrap_width = Some(px(64.)); + let wrap_width = Some(px(96.)); let text = "one two three four five\nsix seven eight"; let buffer = MultiBuffer::build_simple(text, cx); @@ -2411,8 +2410,6 @@ pub mod tests { } } - // todo(linux) fails due to pixel differences in text rendering - #[cfg(target_os = "macos")] #[gpui::test] async fn test_chunks_with_soft_wrapping(cx: &mut gpui::TestAppContext) { cx.background_executor diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 3e8aaaaefb..f57ae8b96b 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -2277,7 +2277,6 @@ mod tests { } } - #[cfg(target_os = "macos")] #[gpui::test] fn test_blocks_on_wrapped_lines(cx: &mut gpui::TestAppContext) { cx.update(init_test); @@ -2292,7 +2291,7 @@ mod tests { let (_, fold_snapshot) = FoldMap::new(inlay_snapshot); let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap()); let (_, wraps_snapshot) = cx.update(|cx| { - WrapMap::new(tab_snapshot, font("Helvetica"), 
px(14.0), Some(px(60.)), cx) + WrapMap::new(tab_snapshot, font("Helvetica"), px(14.0), Some(px(90.)), cx) }); let mut block_map = BlockMap::new(wraps_snapshot.clone(), 1, 1); diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 8bbf9ec774..6e779ebed4 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1693,6 +1693,7 @@ impl Editor { self.mouse_context_menu = Some(MouseContextMenu::new( crate::mouse_context_menu::MenuPosition::PinnedToScreen(position), context_menu, + None, window, cx, )); @@ -3168,6 +3169,7 @@ impl Editor { let mut is_bracket_pair_start = false; let mut is_bracket_pair_end = false; if !text.is_empty() { + let mut bracket_pair_matching_end = None; // `text` can be empty when a user is using IME (e.g. Chinese Wubi Simplified) // and they are removing the character that triggered IME popup. for (pair, enabled) in scope.brackets() { @@ -3192,12 +3194,17 @@ impl Editor { break; } } - if pair.end.as_str() == text.as_ref() { - bracket_pair = Some(pair.clone()); - is_bracket_pair_end = true; - break; + if pair.end.as_str() == text.as_ref() && bracket_pair_matching_end.is_none() + { + // take first bracket pair matching end, but don't break in case a later bracket + // pair matches start + bracket_pair_matching_end = Some(pair.clone()); } } + if bracket_pair.is_none() && bracket_pair_matching_end.is_some() { + bracket_pair = Some(bracket_pair_matching_end.unwrap()); + is_bracket_pair_end = true; + } } if let Some(bracket_pair) = bracket_pair { @@ -4673,61 +4680,69 @@ impl Editor { snippet = None; new_text = completion.new_text.clone(); }; - let selections = self.selections.all::(cx); let replace_range = choose_completion_range(&completion, intent, &buffer_handle, cx); let buffer = buffer_handle.read(cx); - let old_text = buffer - .text_for_range(replace_range.clone()) - .collect::(); - - let newest_selection = self.selections.newest_anchor(); - if newest_selection.start.buffer_id != 
Some(buffer_handle.read(cx).remote_id()) { + let snapshot = self.buffer.read(cx).snapshot(cx); + let replace_range_multibuffer = { + let excerpt = snapshot + .excerpt_containing(self.selections.newest_anchor().range()) + .unwrap(); + let multibuffer_anchor = snapshot + .anchor_in_excerpt(excerpt.id(), buffer.anchor_before(replace_range.start)) + .unwrap() + ..snapshot + .anchor_in_excerpt(excerpt.id(), buffer.anchor_before(replace_range.end)) + .unwrap(); + multibuffer_anchor.start.to_offset(&snapshot) + ..multibuffer_anchor.end.to_offset(&snapshot) + }; + let newest_anchor = self.selections.newest_anchor(); + if newest_anchor.head().buffer_id != Some(buffer.remote_id()) { return None; } - let lookbehind = newest_selection + let old_text = buffer + .text_for_range(replace_range.clone()) + .collect::(); + let lookbehind = newest_anchor .start .text_anchor .to_offset(buffer) .saturating_sub(replace_range.start); let lookahead = replace_range .end - .saturating_sub(newest_selection.end.text_anchor.to_offset(buffer)); - let mut common_prefix_len = 0; - for (a, b) in old_text.chars().zip(new_text.chars()) { - if a == b { - common_prefix_len += a.len_utf8(); - } else { - break; - } - } + .saturating_sub(newest_anchor.end.text_anchor.to_offset(buffer)); + let prefix = &old_text[..old_text.len() - lookahead]; + let suffix = &old_text[lookbehind..]; - let snapshot = self.buffer.read(cx).snapshot(cx); - let mut range_to_replace: Option> = None; - let mut ranges = Vec::new(); + let selections = self.selections.all::(cx); + let mut edits = Vec::new(); let mut linked_edits = HashMap::<_, Vec<_>>::default(); + for selection in &selections { - if snapshot.contains_str_at(selection.start.saturating_sub(lookbehind), &old_text) { - let start = selection.start.saturating_sub(lookbehind); - let end = selection.end + lookahead; - if selection.id == newest_selection.id { - range_to_replace = Some(start + common_prefix_len..end); - } - ranges.push(start + common_prefix_len..end); + let 
edit = if selection.id == newest_anchor.id { + (replace_range_multibuffer.clone(), new_text.as_str()) } else { - common_prefix_len = 0; - ranges.clear(); - ranges.extend(selections.iter().map(|s| { - if s.id == newest_selection.id { - range_to_replace = Some(replace_range.clone()); - replace_range.clone() - } else { - s.start..s.end + let mut range = selection.range(); + let mut text = new_text.as_str(); + + // if prefix is present, don't duplicate it + if snapshot.contains_str_at(range.start.saturating_sub(lookbehind), prefix) { + text = &new_text[lookbehind..]; + + // if suffix is also present, mimic the newest cursor and replace it + if selection.id != newest_anchor.id + && snapshot.contains_str_at(range.end, suffix) + { + range.end += lookahead; } - })); - break; - } + } + (range, text) + }; + + edits.push(edit); + if !self.linked_edit_ranges.is_empty() { let start_anchor = snapshot.anchor_before(selection.head()); let end_anchor = snapshot.anchor_after(selection.tail()); @@ -4735,45 +4750,30 @@ impl Editor { .linked_editing_ranges_for(start_anchor.text_anchor..end_anchor.text_anchor, cx) { for (buffer, edits) in ranges { - linked_edits.entry(buffer.clone()).or_default().extend( - edits - .into_iter() - .map(|range| (range, new_text[common_prefix_len..].to_owned())), - ); + linked_edits + .entry(buffer.clone()) + .or_default() + .extend(edits.into_iter().map(|range| (range, new_text.to_owned()))); } } } } - let text = &new_text[common_prefix_len..]; - let utf16_range_to_replace = range_to_replace.map(|range| { - let newest_selection = self.selections.newest::(cx).range(); - let selection_start_utf16 = newest_selection.start.0 as isize; - - range.start.to_offset_utf16(&snapshot).0 as isize - selection_start_utf16 - ..range.end.to_offset_utf16(&snapshot).0 as isize - selection_start_utf16 - }); cx.emit(EditorEvent::InputHandled { - utf16_range_to_replace, - text: text.into(), + utf16_range_to_replace: None, + text: new_text.clone().into(), }); 
self.transact(window, cx, |this, window, cx| { if let Some(mut snippet) = snippet { - snippet.text = text.to_string(); - for tabstop in snippet - .tabstops - .iter_mut() - .flat_map(|tabstop| tabstop.ranges.iter_mut()) - { - tabstop.start -= common_prefix_len as isize; - tabstop.end -= common_prefix_len as isize; - } - + snippet.text = new_text.to_string(); + let ranges = edits + .iter() + .map(|(range, _)| range.clone()) + .collect::>(); this.insert_snippet(&ranges, snippet, window, cx).log_err(); } else { this.buffer.update(cx, |buffer, cx| { - let edits = ranges.iter().map(|range| (range.clone(), text)); let auto_indent = if completion.insert_text_mode == Some(InsertTextMode::AS_IS) { None @@ -4833,6 +4833,89 @@ impl Editor { })) } + fn prepare_code_actions_task( + &mut self, + action: &ToggleCodeActions, + window: &mut Window, + cx: &mut Context, + ) -> Task, CodeActionContents)>> { + let snapshot = self.snapshot(window, cx); + let multibuffer_point = action + .deployed_from_indicator + .map(|row| DisplayPoint::new(row, 0).to_point(&snapshot)) + .unwrap_or_else(|| self.selections.newest::(cx).head()); + + let Some((buffer, buffer_row)) = snapshot + .buffer_snapshot + .buffer_line_for_row(MultiBufferRow(multibuffer_point.row)) + .and_then(|(buffer_snapshot, range)| { + self.buffer + .read(cx) + .buffer(buffer_snapshot.remote_id()) + .map(|buffer| (buffer, range.start.row)) + }) + else { + return Task::ready(None); + }; + + let (_, code_actions) = self + .available_code_actions + .clone() + .and_then(|(location, code_actions)| { + let snapshot = location.buffer.read(cx).snapshot(); + let point_range = location.range.to_point(&snapshot); + let point_range = point_range.start.row..=point_range.end.row; + if point_range.contains(&buffer_row) { + Some((location, code_actions)) + } else { + None + } + }) + .unzip(); + + let buffer_id = buffer.read(cx).remote_id(); + let tasks = self + .tasks + .get(&(buffer_id, buffer_row)) + .map(|t| Arc::new(t.to_owned())); + + if 
tasks.is_none() && code_actions.is_none() { + return Task::ready(None); + } + + self.completion_tasks.clear(); + self.discard_inline_completion(false, cx); + + let task_context = tasks + .as_ref() + .zip(self.project.clone()) + .map(|(tasks, project)| { + Self::build_tasks_context(&project, &buffer, buffer_row, tasks, cx) + }); + + cx.spawn_in(window, async move |_, _| { + let task_context = match task_context { + Some(task_context) => task_context.await, + None => None, + }; + let resolved_tasks = tasks.zip(task_context).map(|(tasks, task_context)| { + Rc::new(ResolvedTasks { + templates: tasks.resolve(&task_context).collect(), + position: snapshot + .buffer_snapshot + .anchor_before(Point::new(multibuffer_point.row, tasks.column)), + }) + }); + Some(( + buffer, + CodeActionContents { + actions: code_actions, + tasks: resolved_tasks, + }, + )) + }) + } + pub fn toggle_code_actions( &mut self, action: &ToggleCodeActions, @@ -4853,113 +4936,58 @@ impl Editor { } } drop(context_menu); - let snapshot = self.snapshot(window, cx); + let deployed_from_indicator = action.deployed_from_indicator; let mut task = self.code_actions_task.take(); let action = action.clone(); + cx.spawn_in(window, async move |editor, cx| { while let Some(prev_task) = task { prev_task.await.log_err(); task = editor.update(cx, |this, _| this.code_actions_task.take())?; } - let spawned_test_task = editor.update_in(cx, |editor, window, cx| { - if editor.focus_handle.is_focused(window) { - let multibuffer_point = action - .deployed_from_indicator - .map(|row| DisplayPoint::new(row, 0).to_point(&snapshot)) - .unwrap_or_else(|| editor.selections.newest::(cx).head()); - let (buffer, buffer_row) = snapshot - .buffer_snapshot - .buffer_line_for_row(MultiBufferRow(multibuffer_point.row)) - .and_then(|(buffer_snapshot, range)| { - editor - .buffer - .read(cx) - .buffer(buffer_snapshot.remote_id()) - .map(|buffer| (buffer, range.start.row)) - })?; - let (_, code_actions) = editor - .available_code_actions - 
.clone() - .and_then(|(location, code_actions)| { - let snapshot = location.buffer.read(cx).snapshot(); - let point_range = location.range.to_point(&snapshot); - let point_range = point_range.start.row..=point_range.end.row; - if point_range.contains(&buffer_row) { - Some((location, code_actions)) - } else { - None - } - }) - .unzip(); - let buffer_id = buffer.read(cx).remote_id(); - let tasks = editor - .tasks - .get(&(buffer_id, buffer_row)) - .map(|t| Arc::new(t.to_owned())); - if tasks.is_none() && code_actions.is_none() { - return None; - } - - editor.completion_tasks.clear(); - editor.discard_inline_completion(false, cx); - let task_context = - tasks - .as_ref() - .zip(editor.project.clone()) - .map(|(tasks, project)| { - Self::build_tasks_context(&project, &buffer, buffer_row, tasks, cx) - }); - - let debugger_flag = cx.has_flag::(); - - Some(cx.spawn_in(window, async move |editor, cx| { - let task_context = match task_context { - Some(task_context) => task_context.await, - None => None, - }; - let resolved_tasks = - tasks.zip(task_context).map(|(tasks, task_context)| { - Rc::new(ResolvedTasks { - templates: tasks.resolve(&task_context).collect(), - position: snapshot.buffer_snapshot.anchor_before(Point::new( - multibuffer_point.row, - tasks.column, - )), - }) - }); - let spawn_straight_away = resolved_tasks.as_ref().map_or(false, |tasks| { - tasks - .templates - .iter() - .filter(|task| { - if matches!(task.1.task_type(), task::TaskType::Debug(_)) { - debugger_flag - } else { - true - } - }) - .count() - == 1 - }) && code_actions - .as_ref() - .map_or(true, |actions| actions.is_empty()); + let context_menu_task = editor.update_in(cx, |editor, window, cx| { + if !editor.focus_handle.is_focused(window) { + return Some(Task::ready(Ok(()))); + } + let debugger_flag = cx.has_flag::(); + let code_actions_task = editor.prepare_code_actions_task(&action, window, cx); + Some(cx.spawn_in(window, async move |editor, cx| { + if let Some((buffer, code_action_contents)) 
= code_actions_task.await { + let spawn_straight_away = + code_action_contents.tasks.as_ref().map_or(false, |tasks| { + tasks + .templates + .iter() + .filter(|task| { + if matches!(task.1.task_type(), task::TaskType::Debug(_)) { + debugger_flag + } else { + true + } + }) + .count() + == 1 + }) && code_action_contents + .actions + .as_ref() + .map_or(true, |actions| actions.is_empty()); if let Ok(task) = editor.update_in(cx, |editor, window, cx| { *editor.context_menu.borrow_mut() = Some(CodeContextMenu::CodeActions(CodeActionsMenu { buffer, - actions: CodeActionContents { - tasks: resolved_tasks, - actions: code_actions, - }, + actions: code_action_contents, selected_item: Default::default(), scroll_handle: UniformListScrollHandle::default(), deployed_from_indicator, })); if spawn_straight_away { if let Some(task) = editor.confirm_code_action( - &ConfirmCodeAction { item_ix: Some(0) }, + &ConfirmCodeAction { + item_ix: Some(0), + from_mouse_context_menu: false, + }, window, cx, ) { @@ -4974,12 +5002,12 @@ impl Editor { } else { Ok(()) } - })) - } else { - Some(Task::ready(Ok(()))) - } + } else { + Ok(()) + } + })) })?; - if let Some(task) = spawned_test_task { + if let Some(task) = context_menu_task { task.await?; } @@ -4996,17 +5024,27 @@ impl Editor { ) -> Option>> { self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction); - let actions_menu = - if let CodeContextMenu::CodeActions(menu) = self.hide_context_menu(window, cx)? { - menu + let (action, buffer) = if action.from_mouse_context_menu { + if let Some(menu) = self.mouse_context_menu.take() { + let code_action = menu.code_action?; + let index = action.item_ix?; + let action = code_action.actions.get(index)?; + (action, code_action.buffer) } else { return None; - }; + } + } else { + if let CodeContextMenu::CodeActions(menu) = self.hide_context_menu(window, cx)? 
{ + let action_ix = action.item_ix.unwrap_or(menu.selected_item); + let action = menu.actions.get(action_ix)?; + let buffer = menu.buffer; + (action, buffer) + } else { + return None; + } + }; - let action_ix = action.item_ix.unwrap_or(actions_menu.selected_item); - let action = actions_menu.actions.get(action_ix)?; let title = action.label(); - let buffer = actions_menu.buffer; let workspace = self.workspace()?; match action { @@ -8803,6 +8841,7 @@ impl Editor { self, source, clicked_point, + None, context_menu, window, cx, @@ -18853,143 +18892,158 @@ fn snippet_completions( buffer_position: text::Anchor, cx: &mut App, ) -> Task>> { - let language = buffer.read(cx).language_at(buffer_position); - let language_name = language.as_ref().map(|language| language.lsp_id()); + let languages = buffer.read(cx).languages_at(buffer_position); let snippet_store = project.snippets().read(cx); - let snippets = snippet_store.snippets_for(language_name, cx); - if snippets.is_empty() { + let scopes: Vec<_> = languages + .iter() + .filter_map(|language| { + let language_name = language.lsp_id(); + let snippets = snippet_store.snippets_for(Some(language_name), cx); + + if snippets.is_empty() { + None + } else { + Some((language.default_scope(), snippets)) + } + }) + .collect(); + + if scopes.is_empty() { return Task::ready(Ok(vec![])); } + let snapshot = buffer.read(cx).text_snapshot(); let chars: String = snapshot .reversed_chars_for_range(text::Anchor::MIN..buffer_position) .collect(); - - let scope = language.map(|language| language.default_scope()); let executor = cx.background_executor().clone(); cx.background_spawn(async move { - let classifier = CharClassifier::new(scope).for_completion(true); - let mut last_word = chars - .chars() - .take_while(|c| classifier.is_word(*c)) - .collect::(); - last_word = last_word.chars().rev().collect(); + let mut all_results: Vec = Vec::new(); + for (scope, snippets) in scopes.into_iter() { + let classifier = 
CharClassifier::new(Some(scope)).for_completion(true); + let mut last_word = chars + .chars() + .take_while(|c| classifier.is_word(*c)) + .collect::(); + last_word = last_word.chars().rev().collect(); - if last_word.is_empty() { - return Ok(vec![]); - } + if last_word.is_empty() { + return Ok(vec![]); + } - let as_offset = text::ToOffset::to_offset(&buffer_position, &snapshot); - let to_lsp = |point: &text::Anchor| { - let end = text::ToPointUtf16::to_point_utf16(point, &snapshot); - point_to_lsp(end) - }; - let lsp_end = to_lsp(&buffer_position); + let as_offset = text::ToOffset::to_offset(&buffer_position, &snapshot); + let to_lsp = |point: &text::Anchor| { + let end = text::ToPointUtf16::to_point_utf16(point, &snapshot); + point_to_lsp(end) + }; + let lsp_end = to_lsp(&buffer_position); - let candidates = snippets - .iter() - .enumerate() - .flat_map(|(ix, snippet)| { - snippet - .prefix - .iter() - .map(move |prefix| StringMatchCandidate::new(ix, &prefix)) - }) - .collect::>(); - - let mut matches = fuzzy::match_strings( - &candidates, - &last_word, - last_word.chars().any(|c| c.is_uppercase()), - 100, - &Default::default(), - executor, - ) - .await; - - // Remove all candidates where the query's start does not match the start of any word in the candidate - if let Some(query_start) = last_word.chars().next() { - matches.retain(|string_match| { - split_words(&string_match.string).any(|word| { - // Check that the first codepoint of the word as lowercase matches the first - // codepoint of the query as lowercase - word.chars() - .flat_map(|codepoint| codepoint.to_lowercase()) - .zip(query_start.to_lowercase()) - .all(|(word_cp, query_cp)| word_cp == query_cp) + let candidates = snippets + .iter() + .enumerate() + .flat_map(|(ix, snippet)| { + snippet + .prefix + .iter() + .map(move |prefix| StringMatchCandidate::new(ix, &prefix)) }) - }); - } + .collect::>(); - let matched_strings = matches - .into_iter() - .map(|m| m.string) - .collect::>(); + let mut matches = 
fuzzy::match_strings( + &candidates, + &last_word, + last_word.chars().any(|c| c.is_uppercase()), + 100, + &Default::default(), + executor.clone(), + ) + .await; - let result: Vec = snippets - .into_iter() - .filter_map(|snippet| { - let matching_prefix = snippet - .prefix - .iter() - .find(|prefix| matched_strings.contains(*prefix))?; - let start = as_offset - last_word.len(); - let start = snapshot.anchor_before(start); - let range = start..buffer_position; - let lsp_start = to_lsp(&start); - let lsp_range = lsp::Range { - start: lsp_start, - end: lsp_end, - }; - Some(Completion { - replace_range: range, - new_text: snippet.body.clone(), - source: CompletionSource::Lsp { - insert_range: None, - server_id: LanguageServerId(usize::MAX), - resolved: true, - lsp_completion: Box::new(lsp::CompletionItem { - label: snippet.prefix.first().unwrap().clone(), - kind: Some(CompletionItemKind::SNIPPET), - label_details: snippet.description.as_ref().map(|description| { - lsp::CompletionItemLabelDetails { - detail: Some(description.clone()), - description: None, - } + // Remove all candidates where the query's start does not match the start of any word in the candidate + if let Some(query_start) = last_word.chars().next() { + matches.retain(|string_match| { + split_words(&string_match.string).any(|word| { + // Check that the first codepoint of the word as lowercase matches the first + // codepoint of the query as lowercase + word.chars() + .flat_map(|codepoint| codepoint.to_lowercase()) + .zip(query_start.to_lowercase()) + .all(|(word_cp, query_cp)| word_cp == query_cp) + }) + }); + } + + let matched_strings = matches + .into_iter() + .map(|m| m.string) + .collect::>(); + + let mut result: Vec = snippets + .iter() + .filter_map(|snippet| { + let matching_prefix = snippet + .prefix + .iter() + .find(|prefix| matched_strings.contains(*prefix))?; + let start = as_offset - last_word.len(); + let start = snapshot.anchor_before(start); + let range = start..buffer_position; + let 
lsp_start = to_lsp(&start); + let lsp_range = lsp::Range { + start: lsp_start, + end: lsp_end, + }; + Some(Completion { + replace_range: range, + new_text: snippet.body.clone(), + source: CompletionSource::Lsp { + insert_range: None, + server_id: LanguageServerId(usize::MAX), + resolved: true, + lsp_completion: Box::new(lsp::CompletionItem { + label: snippet.prefix.first().unwrap().clone(), + kind: Some(CompletionItemKind::SNIPPET), + label_details: snippet.description.as_ref().map(|description| { + lsp::CompletionItemLabelDetails { + detail: Some(description.clone()), + description: None, + } + }), + insert_text_format: Some(InsertTextFormat::SNIPPET), + text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace( + lsp::InsertReplaceEdit { + new_text: snippet.body.clone(), + insert: lsp_range, + replace: lsp_range, + }, + )), + filter_text: Some(snippet.body.clone()), + sort_text: Some(char::MAX.to_string()), + ..lsp::CompletionItem::default() }), - insert_text_format: Some(InsertTextFormat::SNIPPET), - text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace( - lsp::InsertReplaceEdit { - new_text: snippet.body.clone(), - insert: lsp_range, - replace: lsp_range, - }, - )), - filter_text: Some(snippet.body.clone()), - sort_text: Some(char::MAX.to_string()), - ..lsp::CompletionItem::default() + lsp_defaults: None, + }, + label: CodeLabel { + text: matching_prefix.clone(), + runs: Vec::new(), + filter_range: 0..matching_prefix.len(), + }, + icon_path: None, + documentation: snippet.description.clone().map(|description| { + CompletionDocumentation::SingleLine(description.into()) }), - lsp_defaults: None, - }, - label: CodeLabel { - text: matching_prefix.clone(), - runs: Vec::new(), - filter_range: 0..matching_prefix.len(), - }, - icon_path: None, - documentation: snippet - .description - .clone() - .map(|description| CompletionDocumentation::SingleLine(description.into())), - insert_text_mode: None, - confirm: None, + insert_text_mode: None, + confirm: None, + }) }) - 
}) - .collect(); + .collect(); - Ok(result) + all_results.append(&mut result); + } + + Ok(all_results) }) } diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index b7354bdb9e..51647b0226 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -1316,8 +1316,6 @@ fn test_move_cursor(cx: &mut TestAppContext) { }); } -// TODO: Re-enable this test -#[cfg(target_os = "macos")] #[gpui::test] fn test_move_cursor_multibyte(cx: &mut TestAppContext) { init_test(cx, |_| {}); @@ -9679,9 +9677,9 @@ async fn test_completion_mode(cx: &mut TestAppContext) { buffer_marked_text: "before after".into(), completion_text: "editor", expected_with_insert_mode: "before editorˇtor after".into(), - expected_with_replace_mode: "before ediˇtor after".into(), - expected_with_replace_subsequence_mode: "before ediˇtor after".into(), - expected_with_replace_suffix_mode: "before ediˇtor after".into(), + expected_with_replace_mode: "before editorˇ after".into(), + expected_with_replace_subsequence_mode: "before editorˇ after".into(), + expected_with_replace_suffix_mode: "before editorˇ after".into(), }, Run { run_description: "End of word matches completion text -- cursor at end", @@ -9729,9 +9727,9 @@ async fn test_completion_mode(cx: &mut TestAppContext) { buffer_marked_text: "[]".into(), completion_text: "element", expected_with_insert_mode: "[elementˇelement]".into(), - expected_with_replace_mode: "[elˇement]".into(), + expected_with_replace_mode: "[elementˇ]".into(), expected_with_replace_subsequence_mode: "[elementˇelement]".into(), - expected_with_replace_suffix_mode: "[elˇement]".into(), + expected_with_replace_suffix_mode: "[elementˇ]".into(), }, Run { run_description: "Ends with matching suffix", @@ -9925,6 +9923,270 @@ async fn test_completion_with_mode_specified_by_action(cx: &mut TestAppContext) apply_additional_edits.await.unwrap(); } +#[gpui::test] +async fn test_completion_replacing_suffix_in_multicursors(cx: &mut 
TestAppContext) { + init_test(cx, |_| {}); + let mut cx = EditorLspTestContext::new_rust( + lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + resolve_provider: Some(true), + ..Default::default() + }), + ..Default::default() + }, + cx, + ) + .await; + + let initial_state = indoc! {" + 1. buf.to_offˇsuffix + 2. buf.to_offˇsuf + 3. buf.to_offˇfix + 4. buf.to_offˇ + 5. into_offˇensive + 6. ˇsuffix + 7. let ˇ // + 8. aaˇzz + 9. buf.to_off«zzzzzˇ»suffix + 10. buf.«ˇzzzzz»suffix + 11. to_off«ˇzzzzz» + + buf.to_offˇsuffix // newest cursor + "}; + let completion_marked_buffer = indoc! {" + 1. buf.to_offsuffix + 2. buf.to_offsuf + 3. buf.to_offfix + 4. buf.to_off + 5. into_offensive + 6. suffix + 7. let // + 8. aazz + 9. buf.to_offzzzzzsuffix + 10. buf.zzzzzsuffix + 11. to_offzzzzz + + buf. // newest cursor + "}; + let completion_text = "to_offset"; + let expected = indoc! {" + 1. buf.to_offsetˇ + 2. buf.to_offsetˇsuf + 3. buf.to_offsetˇfix + 4. buf.to_offsetˇ + 5. into_offsetˇensive + 6. to_offsetˇsuffix + 7. let to_offsetˇ // + 8. aato_offsetˇzz + 9. buf.to_offsetˇ + 10. buf.to_offsetˇsuffix + 11. 
to_offsetˇ + + buf.to_offsetˇ // newest cursor + "}; + + cx.set_state(initial_state); + cx.update_editor(|editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + + let counter = Arc::new(AtomicUsize::new(0)); + handle_completion_request_with_insert_and_replace( + &mut cx, + completion_marked_buffer, + vec![completion_text], + counter.clone(), + ) + .await; + cx.condition(|editor, _| editor.context_menu_visible()) + .await; + assert_eq!(counter.load(atomic::Ordering::Acquire), 1); + + let apply_additional_edits = cx.update_editor(|editor, window, cx| { + editor + .confirm_completion_replace(&ConfirmCompletionReplace, window, cx) + .unwrap() + }); + cx.assert_editor_state(expected); + handle_resolve_completion_request(&mut cx, None).await; + apply_additional_edits.await.unwrap(); +} + +// This used to crash +#[gpui::test] +async fn test_completion_in_multibuffer_with_replace_range(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + + let buffer_text = indoc! {" + fn main() { + 10.satu; + + // + // separate cursors so they open in different excerpts (manually reproducible) + // + + 10.satu20; + } + "}; + let multibuffer_text_with_selections = indoc! {" + fn main() { + 10.satuˇ; + + // + + // + + 10.satuˇ20; + } + "}; + let expected_multibuffer = indoc! 
{" + fn main() { + 10.saturating_sub()ˇ; + + // + + // + + 10.saturating_sub()ˇ; + } + "}; + + let first_excerpt_end = buffer_text.find("//").unwrap() + 3; + let second_excerpt_end = buffer_text.rfind("//").unwrap() - 4; + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/a"), + json!({ + "main.rs": buffer_text, + }), + ) + .await; + + let project = Project::test(fs, [path!("/a").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_lang()); + let mut fake_servers = language_registry.register_fake_lsp( + "Rust", + FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + resolve_provider: None, + ..lsp::CompletionOptions::default() + }), + ..lsp::ServerCapabilities::default() + }, + ..FakeLspAdapter::default() + }, + ); + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/a/main.rs"), cx) + }) + .await + .unwrap(); + + let multi_buffer = cx.new(|cx| { + let mut multi_buffer = MultiBuffer::new(Capability::ReadWrite); + multi_buffer.push_excerpts( + buffer.clone(), + [ExcerptRange::new(0..first_excerpt_end)], + cx, + ); + multi_buffer.push_excerpts( + buffer.clone(), + [ExcerptRange::new(second_excerpt_end..buffer_text.len())], + cx, + ); + multi_buffer + }); + + let editor = workspace + .update(cx, |_, window, cx| { + cx.new(|cx| { + Editor::new( + EditorMode::Full { + scale_ui_elements_with_buffer_font_size: false, + show_active_line_background: false, + }, + multi_buffer.clone(), + Some(project.clone()), + window, + cx, + ) + }) + }) + .unwrap(); + + let pane = workspace + .update(cx, |workspace, _, _| workspace.active_pane().clone()) + .unwrap(); + pane.update_in(cx, |pane, window, cx| { + 
pane.add_item(Box::new(editor.clone()), true, true, None, window, cx); + }); + + let fake_server = fake_servers.next().await.unwrap(); + + editor.update_in(cx, |editor, window, cx| { + editor.change_selections(None, window, cx, |s| { + s.select_ranges([ + Point::new(1, 11)..Point::new(1, 11), + Point::new(7, 11)..Point::new(7, 11), + ]) + }); + + assert_text_with_selections(editor, multibuffer_text_with_selections, cx); + }); + + editor.update_in(cx, |editor, window, cx| { + editor.show_completions(&ShowCompletions { trigger: None }, window, cx); + }); + + fake_server + .set_request_handler::(move |_, _| async move { + let completion_item = lsp::CompletionItem { + label: "saturating_sub()".into(), + text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace( + lsp::InsertReplaceEdit { + new_text: "saturating_sub()".to_owned(), + insert: lsp::Range::new( + lsp::Position::new(7, 7), + lsp::Position::new(7, 11), + ), + replace: lsp::Range::new( + lsp::Position::new(7, 7), + lsp::Position::new(7, 13), + ), + }, + )), + ..lsp::CompletionItem::default() + }; + + Ok(Some(lsp::CompletionResponse::Array(vec![completion_item]))) + }) + .next() + .await + .unwrap(); + + cx.condition(&editor, |editor, _| editor.context_menu_visible()) + .await; + + editor + .update_in(cx, |editor, window, cx| { + editor + .confirm_completion_replace(&ConfirmCompletionReplace, window, cx) + .unwrap() + }) + .await + .unwrap(); + + editor.update(cx, |editor, cx| { + assert_text_with_selections(editor, expected_multibuffer, cx); + }) +} + #[gpui::test] async fn test_completion(cx: &mut TestAppContext) { init_test(cx, |_| {}); diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index dffa4da3e1..a05e39e2f0 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -6638,7 +6638,7 @@ impl Element for EditorElement { } }; - if editor.set_wrap_width(wrap_width, cx) { + if editor.set_wrap_width(wrap_width.map(|w| w.ceil()), cx) { editor.snapshot(window, cx) } 
else { snapshot diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index f876fab52c..fd53c4b0ad 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -8,8 +8,8 @@ use crate::{ use gpui::{ AnyElement, AsyncWindowContext, Context, Entity, Focusable as _, FontWeight, Hsla, InteractiveElement, IntoElement, MouseButton, ParentElement, Pixels, ScrollHandle, Size, - Stateful, StatefulInteractiveElement, StyleRefinement, Styled, Task, TextStyleRefinement, - Window, div, px, + Stateful, StatefulInteractiveElement, StyleRefinement, Styled, Subscription, Task, + TextStyleRefinement, Window, div, px, }; use itertools::Itertools; use language::{DiagnosticEntry, Language, LanguageRegistry}; @@ -64,26 +64,31 @@ pub fn show_keyboard_hover( window: &mut Window, cx: &mut Context, ) -> bool { - let info_popovers = editor.hover_state.info_popovers.clone(); - for p in info_popovers { - let keyboard_grace = p.keyboard_grace.borrow(); - if *keyboard_grace { - if let Some(anchor) = p.anchor { - show_hover(editor, anchor, false, window, cx); - return true; - } + if let Some(anchor) = editor.hover_state.info_popovers.iter().find_map(|p| { + if *p.keyboard_grace.borrow() { + p.anchor + } else { + None } + }) { + show_hover(editor, anchor, false, window, cx); + return true; } - let diagnostic_popover = editor.hover_state.diagnostic_popover.clone(); - if let Some(d) = diagnostic_popover { - let keyboard_grace = d.keyboard_grace.borrow(); - if *keyboard_grace { - if let Some(anchor) = d.anchor { - show_hover(editor, anchor, false, window, cx); - return true; + if let Some(anchor) = editor + .hover_state + .diagnostic_popover + .as_ref() + .and_then(|d| { + if *d.keyboard_grace.borrow() { + d.anchor + } else { + None } - } + }) + { + show_hover(editor, anchor, false, window, cx); + return true; } false @@ -164,6 +169,18 @@ pub fn hover_at_inlay( let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await; 
let scroll_handle = ScrollHandle::new(); + + let subscription = this + .update(cx, |_, cx| { + if let Some(parsed_content) = &parsed_content { + Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) + } else { + None + } + }) + .ok() + .flatten(); + let hover_popover = InfoPopover { symbol_range: RangeInEditor::Inlay(inlay_hover.range.clone()), parsed_content, @@ -171,6 +188,7 @@ pub fn hover_at_inlay( scroll_handle, keyboard_grace: Rc::new(RefCell::new(false)), anchor: None, + _subscription: subscription, }; this.update(cx, |this, cx| { @@ -307,40 +325,44 @@ fn show_hover( .anchor_after(local_diagnostic.range.end), }; - let mut border_color: Option = None; - let mut background_color: Option = None; + let (background_color, border_color) = cx.update(|_, cx| { + let status_colors = cx.theme().status(); + match local_diagnostic.diagnostic.severity { + DiagnosticSeverity::ERROR => { + (status_colors.error_background, status_colors.error_border) + } + DiagnosticSeverity::WARNING => ( + status_colors.warning_background, + status_colors.warning_border, + ), + DiagnosticSeverity::INFORMATION => { + (status_colors.info_background, status_colors.info_border) + } + DiagnosticSeverity::HINT => { + (status_colors.hint_background, status_colors.hint_border) + } + _ => ( + status_colors.ignored_background, + status_colors.ignored_border, + ), + } + })?; let parsed_content = cx - .new_window_entity(|_window, cx| { - let status_colors = cx.theme().status(); - - match local_diagnostic.diagnostic.severity { - DiagnosticSeverity::ERROR => { - background_color = Some(status_colors.error_background); - border_color = Some(status_colors.error_border); - } - DiagnosticSeverity::WARNING => { - background_color = Some(status_colors.warning_background); - border_color = Some(status_colors.warning_border); - } - DiagnosticSeverity::INFORMATION => { - background_color = Some(status_colors.info_background); - border_color = Some(status_colors.info_border); - } - DiagnosticSeverity::HINT => 
{ - background_color = Some(status_colors.hint_background); - border_color = Some(status_colors.hint_border); - } - _ => { - background_color = Some(status_colors.ignored_background); - border_color = Some(status_colors.ignored_border); - } - }; - - Markdown::new_text(SharedString::new(text), cx) - }) + .new(|cx| Markdown::new_text(SharedString::new(text), cx)) .ok(); + let subscription = this + .update(cx, |_, cx| { + if let Some(parsed_content) = &parsed_content { + Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) + } else { + None + } + }) + .ok() + .flatten(); + Some(DiagnosticPopover { local_diagnostic, parsed_content, @@ -348,6 +370,7 @@ fn show_hover( background_color, keyboard_grace: Rc::new(RefCell::new(ignore_timeout)), anchor: Some(anchor), + _subscription: subscription, }) } else { None @@ -400,6 +423,16 @@ fn show_hover( }]; let parsed_content = parse_blocks(&blocks, &language_registry, None, cx).await; let scroll_handle = ScrollHandle::new(); + let subscription = this + .update(cx, |_, cx| { + if let Some(parsed_content) = &parsed_content { + Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) + } else { + None + } + }) + .ok() + .flatten(); info_popovers.push(InfoPopover { symbol_range: RangeInEditor::Text(range), parsed_content, @@ -407,6 +440,7 @@ fn show_hover( scroll_handle, keyboard_grace: Rc::new(RefCell::new(ignore_timeout)), anchor: Some(anchor), + _subscription: subscription, }) } @@ -440,6 +474,16 @@ fn show_hover( let parsed_content = parse_blocks(&blocks, &language_registry, language, cx).await; let scroll_handle = ScrollHandle::new(); hover_highlights.push(range.clone()); + let subscription = this + .update(cx, |_, cx| { + if let Some(parsed_content) = &parsed_content { + Some(cx.observe(parsed_content, |_, _, cx| cx.notify())) + } else { + None + } + }) + .ok() + .flatten(); info_popovers.push(InfoPopover { symbol_range: RangeInEditor::Text(range), parsed_content, @@ -447,6 +491,7 @@ fn show_hover( scroll_handle, 
keyboard_grace: Rc::new(RefCell::new(ignore_timeout)), anchor: Some(anchor), + _subscription: subscription, }); } @@ -660,7 +705,7 @@ pub fn open_markdown_url(link: SharedString, window: &mut Window, cx: &mut App) cx.open_url(&link); } -#[derive(Default, Debug)] +#[derive(Default)] pub struct HoverState { pub info_popovers: Vec, pub diagnostic_popover: Option, @@ -742,7 +787,6 @@ impl HoverState { } } -#[derive(Debug, Clone)] pub(crate) struct InfoPopover { pub(crate) symbol_range: RangeInEditor, pub(crate) parsed_content: Option>, @@ -750,6 +794,7 @@ pub(crate) struct InfoPopover { pub(crate) scrollbar_state: ScrollbarState, pub(crate) keyboard_grace: Rc>, pub(crate) anchor: Option, + _subscription: Option, } impl InfoPopover { @@ -760,7 +805,7 @@ impl InfoPopover { cx: &mut Context, ) -> AnyElement { let keyboard_grace = Rc::clone(&self.keyboard_grace); - let mut d = div() + div() .id("info_popover") .elevation_2(cx) // Prevent a mouse down/move on the popover from being propagated to the editor, @@ -770,11 +815,9 @@ impl InfoPopover { let mut keyboard_grace = keyboard_grace.borrow_mut(); *keyboard_grace = false; cx.stop_propagation(); - }); - - if let Some(markdown) = &self.parsed_content { - d = d - .child( + }) + .when_some(self.parsed_content.clone(), |this, markdown| { + this.child( div() .id("info-md-container") .overflow_y_scroll() @@ -783,19 +826,16 @@ impl InfoPopover { .p_2() .track_scroll(&self.scroll_handle) .child( - MarkdownElement::new( - markdown.clone(), - hover_markdown_style(window, cx), - ) - .code_block_renderer(markdown::CodeBlockRenderer::Default { - copy_button: false, - }) - .on_url_click(open_markdown_url), + MarkdownElement::new(markdown, hover_markdown_style(window, cx)) + .code_block_renderer(markdown::CodeBlockRenderer::Default { + copy_button: false, + }) + .on_url_click(open_markdown_url), ), ) - .child(self.render_vertical_scrollbar(cx)); - } - d.into_any_element() + .child(self.render_vertical_scrollbar(cx)) + }) + 
.into_any_element() } pub fn scroll(&self, amount: &ScrollAmount, window: &mut Window, cx: &mut Context) { @@ -842,14 +882,14 @@ impl InfoPopover { } } -#[derive(Debug, Clone)] pub struct DiagnosticPopover { pub(crate) local_diagnostic: DiagnosticEntry, parsed_content: Option>, - border_color: Option, - background_color: Option, + border_color: Hsla, + background_color: Hsla, pub keyboard_grace: Rc>, pub anchor: Option, + _subscription: Option, } impl DiagnosticPopover { @@ -860,53 +900,7 @@ impl DiagnosticPopover { cx: &mut Context, ) -> AnyElement { let keyboard_grace = Rc::clone(&self.keyboard_grace); - let mut markdown_div = div().py_1().px_2(); - if let Some(markdown) = &self.parsed_content { - let settings = ThemeSettings::get_global(cx); - let mut base_text_style = window.text_style(); - base_text_style.refine(&TextStyleRefinement { - font_family: Some(settings.ui_font.family.clone()), - font_fallbacks: settings.ui_font.fallbacks.clone(), - font_size: Some(settings.ui_font_size(cx).into()), - color: Some(cx.theme().colors().editor_foreground), - background_color: Some(gpui::transparent_black()), - ..Default::default() - }); - let markdown_style = MarkdownStyle { - base_text_style, - selection_background_color: { cx.theme().players().local().selection }, - link: TextStyleRefinement { - underline: Some(gpui::UnderlineStyle { - thickness: px(1.), - color: Some(cx.theme().colors().editor_foreground), - wavy: false, - }), - ..Default::default() - }, - ..Default::default() - }; - - markdown_div = markdown_div.child( - MarkdownElement::new(markdown.clone(), markdown_style) - .code_block_renderer(markdown::CodeBlockRenderer::Default { - copy_button: false, - }) - .on_url_click(open_markdown_url), - ); - } - - if let Some(background_color) = &self.background_color { - markdown_div = markdown_div.bg(*background_color); - } - - if let Some(border_color) = &self.border_color { - markdown_div = markdown_div - .border_1() - .border_color(*border_color) - .rounded_lg(); - 
} - - let diagnostic_div = div() + div() .id("diagnostic") .block() .max_h(max_size.height) @@ -928,9 +922,51 @@ impl DiagnosticPopover { *keyboard_grace = false; cx.stop_propagation(); }) - .child(markdown_div); - - diagnostic_div.into_any_element() + .when_some(self.parsed_content.clone(), |this, markdown| { + this.child( + div() + .py_1() + .px_2() + .child( + MarkdownElement::new(markdown, { + let settings = ThemeSettings::get_global(cx); + let mut base_text_style = window.text_style(); + base_text_style.refine(&TextStyleRefinement { + font_family: Some(settings.ui_font.family.clone()), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: Some(settings.ui_font_size(cx).into()), + color: Some(cx.theme().colors().editor_foreground), + background_color: Some(gpui::transparent_black()), + ..Default::default() + }); + MarkdownStyle { + base_text_style, + selection_background_color: { + cx.theme().players().local().selection + }, + link: TextStyleRefinement { + underline: Some(gpui::UnderlineStyle { + thickness: px(1.), + color: Some(cx.theme().colors().editor_foreground), + wavy: false, + }), + ..Default::default() + }, + ..Default::default() + } + }) + .code_block_renderer(markdown::CodeBlockRenderer::Default { + copy_button: false, + }) + .on_url_click(open_markdown_url), + ) + .bg(self.background_color) + .border_1() + .border_color(self.border_color) + .rounded_lg(), + ) + }) + .into_any_element() } } @@ -1070,7 +1106,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, "Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor .hover_state @@ -1110,7 +1146,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, "Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor .hover_state @@ -1205,7 +1241,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, 
"Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor .hover_state @@ -1270,7 +1306,7 @@ mod tests { editor.hover_state.info_popovers.len(), 0, "Expected no hovers but got but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); }); @@ -1294,7 +1330,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, "Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor @@ -1352,7 +1388,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, "Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor .hover_state @@ -1418,7 +1454,7 @@ mod tests { editor.hover_state.info_popovers.len(), 1, "Expected exactly one hover but got: {:?}", - editor.hover_state.info_popovers + editor.hover_state.info_popovers.len() ); let rendered_text = editor .hover_state @@ -1795,7 +1831,7 @@ mod tests { assert!( hover_state.diagnostic_popover.is_none() && hover_state.info_popovers.len() == 1 ); - let popover = hover_state.info_popovers.first().cloned().unwrap(); + let popover = hover_state.info_popovers.first().unwrap(); let buffer_snapshot = editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx)); assert_eq!( popover.symbol_range, @@ -1850,7 +1886,7 @@ mod tests { assert!( hover_state.diagnostic_popover.is_none() && hover_state.info_popovers.len() == 1 ); - let popover = hover_state.info_popovers.first().cloned().unwrap(); + let popover = hover_state.info_popovers.first().unwrap(); let buffer_snapshot = editor.buffer().update(cx, |buffer, cx| buffer.snapshot(cx)); assert_eq!( popover.symbol_range, diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index 9450ea4562..bcad4ef3c0 100644 --- a/crates/editor/src/mouse_context_menu.rs 
+++ b/crates/editor/src/mouse_context_menu.rs @@ -1,15 +1,22 @@ use crate::{ - Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DebuggerEvaluateSelectedText, DisplayPoint, - DisplaySnapshot, Editor, FindAllReferences, GoToDeclaration, GoToDefinition, + ConfirmCodeAction, Copy, CopyAndTrim, CopyPermalinkToLine, Cut, DebuggerEvaluateSelectedText, + DisplayPoint, DisplaySnapshot, Editor, FindAllReferences, GoToDeclaration, GoToDefinition, GoToImplementation, GoToTypeDefinition, Paste, Rename, RevealInFileManager, SelectMode, ToDisplayPoint, ToggleCodeActions, actions::{Format, FormatSelections}, + code_context_menus::CodeActionContents, selections_collection::SelectionsCollection, }; +use feature_flags::{Debugger, FeatureFlagAppExt as _}; use gpui::prelude::FluentBuilder; -use gpui::{Context, DismissEvent, Entity, Focusable as _, Pixels, Point, Subscription, Window}; +use gpui::{ + Context, DismissEvent, Entity, FocusHandle, Focusable as _, Pixels, Point, Subscription, Task, + Window, +}; use std::ops::Range; use text::PointUtf16; +use ui::ContextMenu; +use util::ResultExt; use workspace::OpenInTerminal; #[derive(Debug)] @@ -25,12 +32,23 @@ pub enum MenuPosition { }, } +pub struct MouseCodeAction { + pub actions: CodeActionContents, + pub buffer: Entity, +} + pub struct MouseContextMenu { pub(crate) position: MenuPosition, pub(crate) context_menu: Entity, + pub(crate) code_action: Option, _subscription: Subscription, } +enum CodeActionLoadState { + Loading, + Loaded(CodeActionContents), +} + impl std::fmt::Debug for MouseContextMenu { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MouseContextMenu") @@ -45,6 +63,7 @@ impl MouseContextMenu { editor: &mut Editor, source: multi_buffer::Anchor, position: Point, + code_action: Option, context_menu: Entity, window: &mut Window, cx: &mut Context, @@ -63,6 +82,7 @@ impl MouseContextMenu { return Some(MouseContextMenu::new( menu_position, context_menu, + code_action, window, cx, )); @@ -71,6 
+91,7 @@ impl MouseContextMenu { pub(crate) fn new( position: MenuPosition, context_menu: Entity, + code_action: Option, window: &mut Window, cx: &mut Context, ) -> Self { @@ -91,6 +112,7 @@ impl MouseContextMenu { Self { position, context_menu, + code_action, _subscription, } } @@ -129,13 +151,13 @@ pub fn deploy_context_menu( let display_map = editor.selections.display_map(cx); let source_anchor = display_map.display_point_to_anchor(point, text::Bias::Right); - let context_menu = if let Some(custom) = editor.custom_context_menu.take() { + if let Some(custom) = editor.custom_context_menu.take() { let menu = custom(editor, point, window, cx); editor.custom_context_menu = Some(custom); let Some(menu) = menu else { return; }; - menu + set_context_menu(editor, menu, source_anchor, position, None, window, cx); } else { // Don't show the context menu if there isn't a project associated with this editor let Some(project) = editor.project.clone() else { @@ -174,74 +196,223 @@ pub fn deploy_context_menu( !filter.is_hidden(&DebuggerEvaluateSelectedText) }); - ui::ContextMenu::build(window, cx, |menu, _window, _cx| { - let builder = menu - .on_blur_subscription(Subscription::new(|| {})) - .when(evaluate_selection && has_selections, |builder| { - builder - .action("Evaluate Selection", Box::new(DebuggerEvaluateSelectedText)) - .separator() - }) - .action("Go to Definition", Box::new(GoToDefinition)) - .action("Go to Declaration", Box::new(GoToDeclaration)) - .action("Go to Type Definition", Box::new(GoToTypeDefinition)) - .action("Go to Implementation", Box::new(GoToImplementation)) - .action("Find All References", Box::new(FindAllReferences)) - .separator() - .action("Rename Symbol", Box::new(Rename)) - .action("Format Buffer", Box::new(Format)) - .when(has_selections, |cx| { - cx.action("Format Selections", Box::new(FormatSelections)) - }) - .action( - "Code Actions", - Box::new(ToggleCodeActions { - deployed_from_indicator: None, - }), - ) - .separator() - .action("Cut", 
Box::new(Cut)) - .action("Copy", Box::new(Copy)) - .action("Copy and trim", Box::new(CopyAndTrim)) - .action("Paste", Box::new(Paste)) - .separator() - .map(|builder| { - let reveal_in_finder_label = if cfg!(target_os = "macos") { - "Reveal in Finder" - } else { - "Reveal in File Manager" - }; - const OPEN_IN_TERMINAL_LABEL: &str = "Open in Terminal"; - if has_reveal_target { - builder - .action(reveal_in_finder_label, Box::new(RevealInFileManager)) - .action(OPEN_IN_TERMINAL_LABEL, Box::new(OpenInTerminal)) - } else { - builder - .disabled_action(reveal_in_finder_label, Box::new(RevealInFileManager)) - .disabled_action(OPEN_IN_TERMINAL_LABEL, Box::new(OpenInTerminal)) - } - }) - .map(|builder| { - const COPY_PERMALINK_LABEL: &str = "Copy Permalink"; - if has_git_repo { - builder.action(COPY_PERMALINK_LABEL, Box::new(CopyPermalinkToLine)) - } else { - builder.disabled_action(COPY_PERMALINK_LABEL, Box::new(CopyPermalinkToLine)) - } - }); - match focus { - Some(focus) => builder.context(focus), - None => builder, - } - }) - }; + let menu = build_context_menu( + focus, + has_selections, + has_reveal_target, + has_git_repo, + evaluate_selection, + Some(CodeActionLoadState::Loading), + window, + cx, + ); + set_context_menu(editor, menu, source_anchor, position, None, window, cx); + + let mut actions_task = editor.code_actions_task.take(); + cx.spawn_in(window, async move |editor, cx| { + while let Some(prev_task) = actions_task { + prev_task.await.log_err(); + actions_task = editor.update(cx, |this, _| this.code_actions_task.take())?; + } + let action = ToggleCodeActions { + deployed_from_indicator: Some(point.row()), + }; + let context_menu_task = editor.update_in(cx, |editor, window, cx| { + let code_actions_task = editor.prepare_code_actions_task(&action, window, cx); + Some(cx.spawn_in(window, async move |editor, cx| { + let code_action_result = code_actions_task.await; + if let Ok(editor_task) = editor.update_in(cx, |editor, window, cx| { + let 
Some(mouse_context_menu) = editor.mouse_context_menu.take() else { + return Task::ready(Ok::<_, anyhow::Error>(())); + }; + if mouse_context_menu + .context_menu + .focus_handle(cx) + .contains_focused(window, cx) + { + window.focus(&editor.focus_handle(cx)); + } + drop(mouse_context_menu); + let (state, code_action) = + if let Some((buffer, actions)) = code_action_result { + ( + CodeActionLoadState::Loaded(actions.clone()), + Some(MouseCodeAction { actions, buffer }), + ) + } else { + ( + CodeActionLoadState::Loaded(CodeActionContents::default()), + None, + ) + }; + let menu = build_context_menu( + window.focused(cx), + has_selections, + has_reveal_target, + has_git_repo, + evaluate_selection, + Some(state), + window, + cx, + ); + set_context_menu( + editor, + menu, + source_anchor, + position, + code_action, + window, + cx, + ); + Task::ready(Ok(())) + }) { + editor_task.await + } else { + Ok(()) + } + })) + })?; + if let Some(task) = context_menu_task { + task.await?; + } + Ok::<_, anyhow::Error>(()) + }) + .detach_and_log_err(cx); + }; +} + +fn build_context_menu( + focus: Option, + has_selections: bool, + has_reveal_target: bool, + has_git_repo: bool, + evaluate_selection: bool, + code_action_load_state: Option, + window: &mut Window, + cx: &mut Context, +) -> Entity { + ui::ContextMenu::build(window, cx, |menu, _window, cx| { + let menu = menu + .on_blur_subscription(Subscription::new(|| {})) + .when_some(code_action_load_state, |menu, state| { + match state { + CodeActionLoadState::Loading => menu.disabled_action( + "Loading code actions...", + Box::new(ConfirmCodeAction { + item_ix: None, + from_mouse_context_menu: true, + }), + ), + CodeActionLoadState::Loaded(actions) => { + if actions.is_empty() { + menu.disabled_action( + "No code actions available", + Box::new(ConfirmCodeAction { + item_ix: None, + from_mouse_context_menu: true, + }), + ) + } else { + actions + .iter() + .filter(|action| { + if action + .as_task() + .map(|task| { + 
matches!(task.task_type(), task::TaskType::Debug(_)) + }) + .unwrap_or(false) + { + cx.has_flag::() + } else { + true + } + }) + .enumerate() + .fold(menu, |menu, (ix, action)| { + menu.action( + action.label(), + Box::new(ConfirmCodeAction { + item_ix: Some(ix), + from_mouse_context_menu: true, + }), + ) + }) + } + } + } + .separator() + }) + .when(evaluate_selection && has_selections, |builder| { + builder + .action("Evaluate Selection", Box::new(DebuggerEvaluateSelectedText)) + .separator() + }) + .action("Go to Definition", Box::new(GoToDefinition)) + .action("Go to Declaration", Box::new(GoToDeclaration)) + .action("Go to Type Definition", Box::new(GoToTypeDefinition)) + .action("Go to Implementation", Box::new(GoToImplementation)) + .action("Find All References", Box::new(FindAllReferences)) + .separator() + .action("Rename Symbol", Box::new(Rename)) + .action("Format Buffer", Box::new(Format)) + .when(has_selections, |cx| { + cx.action("Format Selections", Box::new(FormatSelections)) + }) + .separator() + .action("Cut", Box::new(Cut)) + .action("Copy", Box::new(Copy)) + .action("Copy and trim", Box::new(CopyAndTrim)) + .action("Paste", Box::new(Paste)) + .separator() + .map(|builder| { + let reveal_in_finder_label = if cfg!(target_os = "macos") { + "Reveal in Finder" + } else { + "Reveal in File Manager" + }; + const OPEN_IN_TERMINAL_LABEL: &str = "Open in Terminal"; + if has_reveal_target { + builder + .action(reveal_in_finder_label, Box::new(RevealInFileManager)) + .action(OPEN_IN_TERMINAL_LABEL, Box::new(OpenInTerminal)) + } else { + builder + .disabled_action(reveal_in_finder_label, Box::new(RevealInFileManager)) + .disabled_action(OPEN_IN_TERMINAL_LABEL, Box::new(OpenInTerminal)) + } + }) + .map(|builder| { + const COPY_PERMALINK_LABEL: &str = "Copy Permalink"; + if has_git_repo { + builder.action(COPY_PERMALINK_LABEL, Box::new(CopyPermalinkToLine)) + } else { + builder.disabled_action(COPY_PERMALINK_LABEL, Box::new(CopyPermalinkToLine)) + } + }); + 
match focus { + Some(focus) => menu.context(focus), + None => menu, + } + }) +} + +fn set_context_menu( + editor: &mut Editor, + context_menu: Entity, + source_anchor: multi_buffer::Anchor, + position: Option>, + code_action: Option, + window: &mut Window, + cx: &mut Context, +) { editor.mouse_context_menu = match position { Some(position) => MouseContextMenu::pinned_to_editor( editor, source_anchor, position, + code_action, context_menu, window, cx, @@ -255,6 +426,7 @@ pub fn deploy_context_menu( Some(MouseContextMenu::new( menu_position, context_menu, + code_action, window, cx, )) diff --git a/crates/editor/src/test.rs b/crates/editor/src/test.rs index e207b9988e..b197c56bbc 100644 --- a/crates/editor/src/test.rs +++ b/crates/editor/src/test.rs @@ -1,8 +1,7 @@ pub mod editor_lsp_test_context; pub mod editor_test_context; -use std::sync::LazyLock; - +pub use crate::rust_analyzer_ext::expand_macro_recursively; use crate::{ DisplayPoint, Editor, EditorMode, FoldPlaceholder, MultiBuffer, display_map::{DisplayMap, DisplaySnapshot, ToDisplayPoint}, @@ -11,11 +10,11 @@ use gpui::{ AppContext as _, Context, Entity, Font, FontFeatures, FontStyle, FontWeight, Pixels, Window, font, }; +use pretty_assertions::assert_eq; use project::Project; +use std::sync::LazyLock; use util::test::{marked_text_offsets, marked_text_ranges}; -pub use crate::rust_analyzer_ext::expand_macro_recursively; - #[cfg(test)] #[ctor::ctor] fn init_logger() { @@ -96,8 +95,12 @@ pub fn assert_text_with_selections( cx: &mut Context, ) { let (unmarked_text, text_ranges) = marked_text_ranges(marked_text, true); - assert_eq!(editor.text(cx), unmarked_text); - assert_eq!(editor.selections.ranges(cx), text_ranges); + assert_eq!(editor.text(cx), unmarked_text, "text doesn't match"); + assert_eq!( + editor.selections.ranges(cx), + text_ranges, + "selections don't match", + ); } // RA thinks this is dead code even though it is used in a whole lot of tests diff --git a/crates/eval/.gitignore 
b/crates/eval/.gitignore new file mode 100644 index 0000000000..89fb02c122 --- /dev/null +++ b/crates/eval/.gitignore @@ -0,0 +1,3 @@ +repos/ +worktrees/ +runs/ diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index 0249c24dcf..42597393a1 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -7,28 +7,40 @@ edition.workspace = true [dependencies] agent.workspace = true anyhow.workspace = true +async-watch.workspace = true +assistant_settings.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true -assistant_settings.workspace = true +chrono.workspace = true +clap.workspace = true client.workspace = true +collections.workspace = true context_server.workspace = true dap.workspace = true env_logger.workspace = true +extension.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true gpui_tokio.workspace = true +handlebars.workspace = true language.workspace = true +language_extension.workspace = true language_model.workspace = true language_models.workspace = true +languages.workspace = true node_runtime.workspace = true +paths.workspace = true project.workspace = true prompt_store.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true settings.workspace = true +shellexpand.workspace = true toml.workspace = true +unindent.workspace = true +util.workspace = true workspace-hack.workspace = true [[bin]] diff --git a/crates/eval/examples/auth_session_management/base.toml b/crates/eval/examples/auth_session_management/base.toml new file mode 100644 index 0000000000..f34b9a0e44 --- /dev/null +++ b/crates/eval/examples/auth_session_management/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/workos/authkit-js.git" +revision = "949345d85782a93e8f1738ec31823948ffc26301" +language_extension = "ts" diff --git a/crates/eval/examples/auth_session_management/criteria.md b/crates/eval/examples/auth_session_management/criteria.md new file mode 100644 index 
0000000000..cfb483450c --- /dev/null +++ b/crates/eval/examples/auth_session_management/criteria.md @@ -0,0 +1,10 @@ +1. Add a new test case in `create-client.test.ts` for when the `returnTo` option is provided during sign-out. It verifies that the sign-out URL includes the correct `return_to` query parameter with the provided URL. The test sets up a mock client, calls signOut with a returnTo value, and asserts that the resulting URL contains the expected session_id and return_to parameters while maintaining the correct API endpoint structure. +2. Modifies the `signOut` method in `create-client.ts` to accept an optional options parameter containing a returnTo string. Instead of directly passing the sessionId to getLogoutUrl, it now passes an object containing both the sessionId and the returnTo value from the options. The method maintains its existing behavior of checking for an access token and clearing session data when a URL is available. +3. Updates the HTTP client tests in `http-client.test.ts` to reflect the new getLogoutUrl signature. It adds a test case for the basic logout URL and a new describe block for when returnTo is provided, verifying that the URL includes the properly encoded return_to parameter. The test ensures the URL construction handles both cases correctly. +4. Modifies the `getLogoutUrl` method in `http-client.ts` to accept an object parameter with sessionId and returnTo properties instead of just a sessionId string. It maintains the base URL construction but now conditionally adds the return_to query parameter only when a returnTo value is provided, while always including the session_id parameter. The method handles URL construction and parameter encoding internally. +5. Updates the session initialization logic in `create-client.ts` to check for either a `workos-has-session` cookie or a refresh token (retrieved via `getRefreshToken`). 
This allows the client to refresh sessions even if no `code` is present in the URL, especially in development environments. +6. Adds corresponding test coverage in `create-client.test.ts`: + - When no code is in the URL but the `workos-has-session` cookie exists, the session should be refreshed. + - When devMode is enabled and a refresh token is present in localStorage, the session should be refreshed. + - When devMode is enabled but no refresh token exists, the client should be created without making any network requests. + - When neither a code, cookie, nor refresh token is present, the client should initialize without refreshing. diff --git a/crates/eval/examples/auth_session_management/prompt.md b/crates/eval/examples/auth_session_management/prompt.md new file mode 100644 index 0000000000..19081fa060 --- /dev/null +++ b/crates/eval/examples/auth_session_management/prompt.md @@ -0,0 +1,3 @@ +I need to improve our logout feature. When users sign out, they should be able to specify a return URL to redirect to afterward. Right now, signing out just takes them to a default page, but we want to support custom redirects (like back to the homepage or a login screen). The URL should be safely included in the logout request. Make sure existing logouts still work normally when no redirect is specified. + +Also, note that we updated how the client initializes its session. It should now check for either a `workos-has-session` cookie or a valid refresh token (even in devMode). This ensures that sessions are refreshed appropriately even without a code in the URL. Be sure this logic is covered by the minimum tests. 
diff --git a/crates/eval/examples/checkpoint_stability/base.toml b/crates/eval/examples/checkpoint_stability/base.toml new file mode 100644 index 0000000000..bdd1e912d0 --- /dev/null +++ b/crates/eval/examples/checkpoint_stability/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/cline/cline.git" +revision = "a26494e5cc453f9c7e148d35895fda3f74d03284" +language_extension = "ts" diff --git a/crates/eval/examples/checkpoint_stability/criteria.md b/crates/eval/examples/checkpoint_stability/criteria.md new file mode 100644 index 0000000000..b104713a24 --- /dev/null +++ b/crates/eval/examples/checkpoint_stability/criteria.md @@ -0,0 +1,5 @@ +1. A new changeset file is created to document a patch that improves diff editing animations and enhances prompts for large file edits. An indicator showing the number of diff edits is also added next to each file path. +2. In `diff.ts`, the error message thrown when a `SEARCH` block doesn’t match content has been updated to clarify that the mismatch could be due to out-of-order blocks. +3. In `responses.ts`, the assistant response for diff mismatches now recommends limiting to 1–3 `SEARCH/REPLACE` blocks at a time for large files. It also simplifies fallback instructions for using the `write_to_file` tool. +4. The `DiffViewProvider.ts` file has been updated to replace line-by-line animations with chunk-based updates for better performance. For large diffs, a smooth scrolling animation is introduced to maintain visual context. Small diffs still scroll directly. +5. In `CodeAccordian.tsx`, a new visual indicator displays the number of `REPLACE` blocks in the code diff using a diff icon and count, providing quick insight into the volume of changes. 
diff --git a/crates/eval/examples/checkpoint_stability/prompt.md b/crates/eval/examples/checkpoint_stability/prompt.md new file mode 100644 index 0000000000..4c97e52ca7 --- /dev/null +++ b/crates/eval/examples/checkpoint_stability/prompt.md @@ -0,0 +1,7 @@ +We're trying to improve both performance and usability when working with large diffs in the editor. A few areas need attention: + +First, the current diff animation applies updates line-by-line, which can feel slow and visually jarring for large edits. Could you revise the logic so that we update the editor in larger chunks instead? For smaller diffs, direct scrolling to the edited line is fine, but for larger changes, it would be great to implement a smooth scrolling animation that steps through the affected region before settling at the final line. + +Second, the current error message when a SEARCH block doesn't match is a bit too vague. Let's make it clearer that the issue could be due to out-of-order or imprecise SEARCH/REPLACE blocks, especially when working with multiple blocks. It might also help to add a suggestion that users try only 1–3 changes at a time for large files before retrying. + +Finally, in the file accordion UI, it would be useful to show how many edits a file contains. Could you parse the diff content and display a count of REPLACE blocks next to the file path, maybe with a small icon for clarity? 
diff --git a/crates/eval/examples/dd_iaptic_mcp_server_integration/base.toml b/crates/eval/examples/dd_iaptic_mcp_server_integration/base.toml new file mode 100644 index 0000000000..dcf989bca8 --- /dev/null +++ b/crates/eval/examples/dd_iaptic_mcp_server_integration/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/punkpeye/awesome-mcp-servers.git" +revision = "5480a9849b01ae8a5c1433d75ad0415975609571" +language_extension = "md" diff --git a/crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md b/crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md new file mode 100644 index 0000000000..fa74ab9d9f --- /dev/null +++ b/crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md @@ -0,0 +1,5 @@ +1. The diff shows changes to `README.md`, specifically adding a new entry to the "Tools and integrations" list. The new entry is for `@iaptic/mcp-server-iaptic`, which provides access to customer purchase and revenue data. +2. The added line includes: + - The GitHub repository URL + - Three emojis: 🎖️ (possibly representing awards or achievements), 📇 (profiles or contacts), and ☁️ (cloud) + - A description of the tool's functionality: "Connect with [iaptic](https://www.iaptic.com) to ask about your Customer Purchases, Transaction data and App Revenue statistics" diff --git a/crates/eval/examples/dd_iaptic_mcp_server_integration/prompt.md b/crates/eval/examples/dd_iaptic_mcp_server_integration/prompt.md new file mode 100644 index 0000000000..cc88ae4c7f --- /dev/null +++ b/crates/eval/examples/dd_iaptic_mcp_server_integration/prompt.md @@ -0,0 +1,3 @@ +Please add a new tool entry to the README.md file's integration list: "@iaptic/mcp-server-iaptic" with GitHub link, described as "Connect with [iaptic](https://www.iaptic.com) to ask about your Customer Purchases, Transaction data and App Revenue statistics", tagged with the following emojis: 🎖️ 📇 ☁️. 
Place it appropriately in the existing tools section, following the current alphabetical or category-based order. + +Edit the README file with the above, new resource diff --git a/crates/eval/examples/debian_image_builder/base.toml b/crates/eval/examples/debian_image_builder/base.toml new file mode 100644 index 0000000000..0d0fa6e4a1 --- /dev/null +++ b/crates/eval/examples/debian_image_builder/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/avkcode/container-tools.git" +revision = "34137bb453b4d2dd28b08bd80e26bc3105a50ada" +language_extension = "sh" diff --git a/crates/eval/examples/debian_image_builder/criteria.md b/crates/eval/examples/debian_image_builder/criteria.md new file mode 100644 index 0000000000..5c1dd4ccfc --- /dev/null +++ b/crates/eval/examples/debian_image_builder/criteria.md @@ -0,0 +1,4 @@ +1. Changes to the Makefile where the parameter "--keyrign" was corrected to "--keyring" in multiple build targets including debian11, debian11-java, debian11-java-slim, debian11-graal, debian11-graal-slim, debian11-corretto, debian11-java-slim-maven, debian11-java-slim-gradle, debian11-graal-slim-maven, and debian11-graal-slim-gradle. This appears to be a typo fix across all Java-related build configurations in the Makefile. +2. Introduces significant enhancements to the debian/mkimage.sh script, including adding a usage function with detailed documentation, improving error handling for command-line arguments, and fixing the "--keyrign" parameter to "--keyring" to match the Makefile changes. It also adds better validation for required arguments and more descriptive error messages when values are missing. The script now includes comprehensive documentation about its purpose and usage examples. +3. Shows extensive improvements to the script's functionality and robustness, including adding tracing capabilities, better error handling, and more informative logging. 
It introduces new helper functions like usage(), die(), warn(), and info() for better user feedback. The script now properly checks for required commands (debootstrap, unzip, trivy) and provides installation instructions if they're missing. It also includes better system checks (Linux OS verification, root privileges check, SELinux status) and implements a more reliable way to handle GPG keys by setting up the correct directory structure and permissions before key import. +4. Continues the script improvements with better package management, repository configuration, and container setup. It adds proper apt repository configuration in the target system, implements package installation with retries, and includes Docker-specific optimizations. The script now provides clearer output about installed packages and their sizes. It also includes better cleanup procedures and more informative completion messages with clear instructions on how to load and run the resulting Docker image. The output now includes example commands and proper formatting for better readability. diff --git a/crates/eval/examples/debian_image_builder/prompt.md b/crates/eval/examples/debian_image_builder/prompt.md new file mode 100644 index 0000000000..4e3651c3d1 --- /dev/null +++ b/crates/eval/examples/debian_image_builder/prompt.md @@ -0,0 +1 @@ +I need to make several improvements to our Debian image-building scripts. First, fix the typo in the `Makefile` where `--keyrign` is incorrectly used instead of `--keyring` across all build targets, including the standard Debian image and Java variants like `debian11-java`, `debian11-graal`, and `debian11-corretto`. Second, enhance the `debian/mkimage.sh` script to include proper error handling, usage documentation, and command-line argument validation. The script should check for required tools like `debootstrap`, `unzip`, and `trivy`, and provide installation instructions if they're missing. 
Improve the GPG key setup by ensuring the `/root/.gnupg` directory is properly configured before importing keys. Add structured logging with timestamps, warnings, and informational messages. Implement better package installation with retries and proper cleanup. Finally, include clear instructions at the end on how to load and run the generated Docker image, with example commands for verification. The script should be robust, well-documented, and fail early with meaningful error messages if system requirements aren't met. diff --git a/crates/eval/examples/docs_restructure/base.toml b/crates/eval/examples/docs_restructure/base.toml new file mode 100644 index 0000000000..c0917ebe5b --- /dev/null +++ b/crates/eval/examples/docs_restructure/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/YuhangSong/Arena-Baselines.git" +revision = "801ed8566110ddc4a6ada0cc70171c636d78dbb8" +language_extension = "py" diff --git a/crates/eval/examples/docs_restructure/criteria.md b/crates/eval/examples/docs_restructure/criteria.md new file mode 100644 index 0000000000..2a30e3657f --- /dev/null +++ b/crates/eval/examples/docs_restructure/criteria.md @@ -0,0 +1,12 @@ +1. README.md Features Section Reorganization +The features section has been reorganized into two subsections ("Baselines" and "Games") with markdown tables added. The previous bullet points were replaced with more structured content including supported/benchmarked status indicators. A new "Visualization" section was added with TensorBoard and port forwarding instructions. +2. Content Relocation and File Restructuring +The Tennis game documentation and action space details were moved from README.md to a new games.md file. The README was cleaned up by removing commented-out content and consolidating documentation sections. YAML config files (Benchmark-2T1P-Discrete.yaml and Test-Pong.yaml) were modified to replace `selfplay_recent_prob` with `playing_policy_load_recent_prob` and adjust population size options. +3. 
train.py Refactoring +Significant changes to train.py including: +- Renamed `selfplay_recent_prob` parameter to `playing_policy_load_recent_prob` +- Simplified the nested grid search structure by removing unnecessary loops +- Improved policy loading logic with better checkpoint path handling +- Enhanced error handling and logging for policy saving/reloading +- Removed redundant code and improved code organization +- Added more descriptive console output during policy operations diff --git a/crates/eval/examples/docs_restructure/prompt.md b/crates/eval/examples/docs_restructure/prompt.md new file mode 100644 index 0000000000..08c5c793e8 --- /dev/null +++ b/crates/eval/examples/docs_restructure/prompt.md @@ -0,0 +1,13 @@ +I need to refactor the multi-agent configuration system in our Arena-Baselines repository. The current policy_assignment parameter (self_play, independent) is too coarse. I want to replace it with a more flexible set of parameters to better support advanced training schemes like population-based training (PBT) and sophisticated self-play with historical opponents. + +Specifically, I will introduce four new configuration parameters: + +iterations_per_reload: Controls the frequency (in training iterations) at which policies are saved and potentially reloaded. +num_learning_policies: Explicitly defines how many agents use policies that are actively being trained (can be an integer or 'all'). +selfplay_recent_prob: For non-learning agents (players), this determines the probability of loading the latest version of a learning policy versus loading a uniformly random historical version during reloads. +size_population: Specifies the number of distinct policy versions maintained for each learning agent, enabling PBT-style experiments. +To implement this, I will significantly modify train.py. 
This includes updating the argument parser, changing how experiment configurations are expanded (especially with grid_search), and implementing a new callback function (on_train_result). This callback will handle the periodic saving (using pickle) of learning policies to structured directories and the reloading of all policies (learning and playing) according to the new parameters (iterations_per_reload, selfplay_recent_prob, size_population). Playing policies will use deterministic actions. + +I'll also reorganize the codebase by renaming arena/rllib_env.py to arena/arena.py and creating a new arena/utils.py file to house utility functions (like configuration helpers, ID generators, DeterministicCategorical) and constants. + +Finally, I will update the example configuration files (Benchmark-2T1P-Discrete.yaml, Test-Pong.yaml) to remove policy_assignment and demonstrate the usage of the new parameters, including within grid_search. diff --git a/crates/eval/examples/expand_laravel_php_support/base.toml b/crates/eval/examples/expand_laravel_php_support/base.toml new file mode 100644 index 0000000000..175c1fca4d --- /dev/null +++ b/crates/eval/examples/expand_laravel_php_support/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/calebporzio/sushi.git" +revision = "01dd34fe3374f5fb7ce63756c0419385e31cd532" +language_extension = "php" diff --git a/crates/eval/examples/expand_laravel_php_support/criteria.md b/crates/eval/examples/expand_laravel_php_support/criteria.md new file mode 100644 index 0000000000..e9ccd7a7d9 --- /dev/null +++ b/crates/eval/examples/expand_laravel_php_support/criteria.md @@ -0,0 +1,3 @@ +1. The GitHub workflow file has been significantly updated to expand testing coverage and improve the CI process. The changes introduce a new `fail-fast: false` setting to allow all matrix combinations to complete even if some fail. The testing matrix now includes PHP 8.4 and Laravel 12.* alongside the existing versions. 
The configuration includes specific testbench version mappings for Laravel 12.* and removes the DBAL requirement for Laravel 11.* tests. Numerous new test combinations have been added across all Laravel versions to include PHP 8.4 testing. The dependency installation process has been restructured into separate steps - one specifically for DBAL when needed, and another for general dependencies using updated composer commands with precise version constraints. +2. The composer.json file has been updated to support the newly added Laravel 12.* version in both the main requirements and development dependencies. The testbench package now explicitly includes versions 5.* and 10.* in its supported range. For testing tools, PHPUnit 11.* has been added to the list of supported versions while maintaining backward compatibility with older versions. These changes ensure the package can be used with the latest Laravel ecosystem components while preserving compatibility with existing installations. +3. The test file modifications primarily focus on adapting to changes in Laravel 11+ where column type handling was updated. The changes introduce version-aware assertions that check whether to expect 'string' or 'varchar' as column types based on the Laravel version being tested. A new import for the version comparison function was added to support these conditional checks. Additional safeguards were implemented, including a check for the HandlesAnnotations trait before running database migration tests, making the test suite more robust when running in different environments. The column type assertions in multiple test methods were updated to use these version-aware checks to maintain compatibility across Laravel versions. 
diff --git a/crates/eval/examples/expand_laravel_php_support/prompt.md b/crates/eval/examples/expand_laravel_php_support/prompt.md new file mode 100644 index 0000000000..e193cdb3c6 --- /dev/null +++ b/crates/eval/examples/expand_laravel_php_support/prompt.md @@ -0,0 +1,11 @@ + +I'd like to update our Laravel package's CI workflow and dependencies to ensure compatibility with the upcoming Laravel 12 release and PHP 8.4. Currently, our package supports Laravel versions 5.8 through 11 and PHP versions 7.1 through 8.3, and we'll need to extend this support while maintaining backward compatibility. + +**Key Changes Needed:** +First, we'll need to update composer.json to explicitly support Laravel 12. The CI test matrix should also be expanded to include PHP 8.4 testing across all supported Laravel versions. The workflow configuration will require adjustments to properly handle these new version combinations. + +There are some test compatibility issues we'll need to address - particularly around how we check string column types in Laravel 11+ (where 'string' was changed to 'varchar'), and we should add conditional skipping for tests that depend on traits that might not be available in all test environments. + +While making these changes, we could also implement some workflow improvements: enabling the fail-fast: false option to get complete test results even with individual failures, modernizing our dependency installation approach using the newer composer update syntax, and making the DBAL dependency installation conditional since it's not needed for all test cases. + +Would you be able to help review these changes or suggest any additional considerations we should keep in mind for this compatibility update? I want to make sure we maintain stability while expanding our support coverage. 
diff --git a/crates/eval/examples/find_and_replace_diff_card/base.toml b/crates/eval/examples/find_and_replace_diff_card/base.toml index 2b14a64530..c88298997d 100644 --- a/crates/eval/examples/find_and_replace_diff_card/base.toml +++ b/crates/eval/examples/find_and_replace_diff_card/base.toml @@ -1,2 +1,3 @@ -path = "../zed_worktree" +url = "https://github.com/zed-industries/zed.git" revision = "38fcadf9481d018543c65f36ac3bafeba190179b" +language_extension = "rs" diff --git a/crates/eval/examples/find_and_replace_diff_card/criteria.md b/crates/eval/examples/find_and_replace_diff_card/criteria.md new file mode 100644 index 0000000000..393056f134 --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/criteria.md @@ -0,0 +1,2 @@ +1. The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the string we were returning before, and a new `card` field that contains a view for the card +2. The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged. 
diff --git a/crates/eval/examples/find_and_replace_diff_card/rubric.md b/crates/eval/examples/find_and_replace_diff_card/rubric.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/eval/examples/finnish_translation/base.toml b/crates/eval/examples/finnish_translation/base.toml new file mode 100644 index 0000000000..a54cbb4626 --- /dev/null +++ b/crates/eval/examples/finnish_translation/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/sdras/array-explorer.git" +revision = "8ff1a72f7ba24d44946bf591c3586b0dcccc2381" +language_extension = "js" diff --git a/crates/eval/examples/finnish_translation/criteria.md b/crates/eval/examples/finnish_translation/criteria.md new file mode 100644 index 0000000000..356aee78e0 --- /dev/null +++ b/crates/eval/examples/finnish_translation/criteria.md @@ -0,0 +1,12 @@ +1. **EditorConfig Change** +Added a new setting `quote_type = single` to the `.editorconfig` file. This specifies that single quotes should be used for quoting in the codebase. +2. **New Finnish Locale Files** +Added two new Finnish language files: + - `src/locale/fi/index.js`: Contains Finnish translations for UI strings and method descriptions + - `store/fi/index.js`: Contains Finnish translations for all array method documentation (298 lines) + - `store/fi/meta.json`: Metadata about the Finnish translation (language code "fi", full name "Finnish", created by "sjarva") +3. 
**Store Integration Updates** +Modified `store/index.js` to: + - Import the new Finnish locale files (`import fi from './fi/index'` and `import translationsFi from '../src/locale/fi/index'`) + - Add Finnish to the Vuex store state (`fi`) + - Register Finnish translations with Vue I18n (`Vue.i18n.add('fi', translationsFi)`) diff --git a/crates/eval/examples/finnish_translation/prompt.md b/crates/eval/examples/finnish_translation/prompt.md new file mode 100644 index 0000000000..d4782f41ea --- /dev/null +++ b/crates/eval/examples/finnish_translation/prompt.md @@ -0,0 +1,5 @@ +I’m working on adding Finnish (fi) language support to our array method reference application, which helps users determine the right JavaScript array methods based on their needs. To achieve this, I’ll need to: + +First, create the Finnish locale file containing translations for method selection options, method types (such as add, remove, find, and iterate), and primary action choices. Next, I’ll add Finnish translations to the store, covering all array methods (like splice, push, and unshift), including detailed descriptions of their behaviors, parameters, return values, and example code with outputs. + +Additionally, I’ll generate a Finnish meta file with language metadata (language code, full name, and contributor info). Finally, I’ll update the main store index to integrate Finnish alongside existing languages like English, Spanish, and German. 
diff --git a/crates/eval/examples/language_model_file_support/base.toml b/crates/eval/examples/language_model_file_support/base.toml new file mode 100644 index 0000000000..5fb211f3af --- /dev/null +++ b/crates/eval/examples/language_model_file_support/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/vercel/ai.git" +revision = "1766edec300deb05c84bb7fefc034af4c2bc1165" +language_extension = "ts" diff --git a/crates/eval/examples/language_model_file_support/criteria.md b/crates/eval/examples/language_model_file_support/criteria.md new file mode 100644 index 0000000000..0f2f6ba492 --- /dev/null +++ b/crates/eval/examples/language_model_file_support/criteria.md @@ -0,0 +1,3 @@ +1. Introduces a new changeset file that documents a patch for the '@ai-sdk/provider' package. The changeset indicates a chore task where 'LanguageModelV2File' is being extracted, suggesting a refactoring effort to modularize the codebase by separating file-related types into their own module. +2. Modifications to the language model v2 index file where a new export statement for 'language-model-v2-file' has been added. This change reflects the extraction mentioned in the changeset and makes the new file type available to other parts of the application. Additionally, there are significant changes to the language model v2 implementation file where the inline file type definition has been replaced with the newly extracted 'LanguageModelV2File' type, both in the main model interface and in the stream part union type, demonstrating the consolidation of file-related types into a single, reusable definition. +3. Present the newly created 'language-model-v2-file.ts' file which defines the 'LanguageModelV2File' type with comprehensive documentation. 
The type includes two properties: 'mediaType' which specifies the IANA media type of the file with a reference to the official media types registry, and 'data' which can be either a base64 encoded string or binary data, with clear documentation about maintaining the original format from the API without unnecessary conversion. This new file represents the extracted type that is now being used throughout the codebase. diff --git a/crates/eval/examples/language_model_file_support/prompt.md b/crates/eval/examples/language_model_file_support/prompt.md new file mode 100644 index 0000000000..71c5b9fba4 --- /dev/null +++ b/crates/eval/examples/language_model_file_support/prompt.md @@ -0,0 +1 @@ +We need to improve how our language model handles file attachments by making the file type definitions more modular and reusable. Currently, file-related properties are defined inline within the model’s response and stream types, which makes maintenance harder and duplicates documentation. The goal is to extract these definitions into a dedicated type that can be shared consistently across both static responses and streaming payloads. The new type should include clear documentation about media types (referencing IANA standards) and support both base64 and binary data formats without unnecessary conversions. This change should maintain backward compatibility while centralizing the file structure definition for better type safety and readability. Focus on clean separation of concerns, and ensure the extracted type is properly exported and imported where needed. 
diff --git a/crates/eval/examples/license_management/base.toml b/crates/eval/examples/license_management/base.toml new file mode 100644 index 0000000000..cb63ccc048 --- /dev/null +++ b/crates/eval/examples/license_management/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/SAP-samples/abap-cheat-sheets.git" +revision = "262c0472eeb03e05ff8235767356a328d97850e6" +require_lsp = false diff --git a/crates/eval/examples/license_management/criteria.md b/crates/eval/examples/license_management/criteria.md new file mode 100644 index 0000000000..ad270f4ccf --- /dev/null +++ b/crates/eval/examples/license_management/criteria.md @@ -0,0 +1,3 @@ +1. The file `.reuse/dep5` has been deleted. This file previously contained copyright and licensing information in Debian's copyright format, including details about API usage with SAP products, copyright notice (2022 SAP SE or affiliates), and Apache-2.0 license information. +2. A new file `REUSE.toml` has been created with similar copyright and licensing information but in a different format. It includes the package name, supplier information, download location, and the same detailed disclaimer about API usage with SAP products that was in the deleted file. +3. The new `REUSE.toml` file also contains annotations specifying that the copyright text and Apache-2.0 license apply to all files (`path = "**"`) with aggregate precedence, effectively maintaining the same licensing terms but in a different configuration format. diff --git a/crates/eval/examples/license_management/prompt.md b/crates/eval/examples/license_management/prompt.md new file mode 100644 index 0000000000..df6901fc16 --- /dev/null +++ b/crates/eval/examples/license_management/prompt.md @@ -0,0 +1,17 @@ +I need to switch our license stuff from the old .reuse/dep5 file to the new REUSE.toml format. basically same info, just different format. 
here's what's in the old file: + +project name: abap-cheat-sheets +contact: daniel reger's email +repo link +that long SAP API disclaimer +copyright: SAP + contributors, 2022 +license: Apache-2.0 +need to: + +delete the old .reuse/dep5 file +make a new REUSE.toml with: +same project info (name, contact, repo) +same exact API disclaimer text +SPDX-style copyright & license fields +apply to all files (** glob) with aggregate precedence +not changing any actual license terms, just updating the format. can you give me the exact REUSE.toml file we need? diff --git a/crates/eval/examples/metal_i64_support/base.toml b/crates/eval/examples/metal_i64_support/base.toml new file mode 100644 index 0000000000..01b0703231 --- /dev/null +++ b/crates/eval/examples/metal_i64_support/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/huggingface/candle.git" +revision = "3164a19a5dc18f5e0f7a063ae85a0cfd289e98f1" +language_extension = "rs" diff --git a/crates/eval/examples/metal_i64_support/criteria.md b/crates/eval/examples/metal_i64_support/criteria.md new file mode 100644 index 0000000000..35741151c9 --- /dev/null +++ b/crates/eval/examples/metal_i64_support/criteria.md @@ -0,0 +1,4 @@ +1. The changes improve the configurability of the `TextGeneration` struct and its initialization by refactoring generation parameters (`temperature`, `top_p`) to use non-optional types with default values, simplifying their use throughout the codebase. +2. The argument parser is updated to enhance usability: `verbose_prompt` is renamed to a more general `verbose` flag, several arguments are given default values (e.g., `temperature`, `top_p`, `sample_len`), and optional arguments like `cache_path` and `weight_path` are now properly handled with conditional logic and fallbacks. +3. 
The code loading the model configuration is updated to support deserializing from a JSON config file using Serde, and the `Config` struct is extended with a new `rope_ratio` field with a default value via a helper function, improving flexibility for different model setups. +4. Import statements and general code layout are cleaned up for clarity and consistency, including reorganizing imports and removing unnecessary unwraps or panics, while maintaining the same core functionality of the text generation pipeline. diff --git a/crates/eval/examples/metal_i64_support/prompt.md b/crates/eval/examples/metal_i64_support/prompt.md new file mode 100644 index 0000000000..bdc365b1cd --- /dev/null +++ b/crates/eval/examples/metal_i64_support/prompt.md @@ -0,0 +1 @@ +I'd like to improve the configurability and usability of the text generation script for the CodeGeeX4-9B model. Please refactor the argument parsing to set more user-friendly defaults where possible, especially for generation parameters like temperature and top-p, and change fields like verbose_prompt to a more general verbose flag. Simplify the handling of optional paths like cache or weight paths, making them truly optional with fallbacks. I also want the model config to support deserialization from a JSON file instead of relying on hardcoded defaults, including support for a rope_ratio parameter with a sensible default. Lastly, please clean up the code for consistency—such as import ordering—and ensure everything aligns with these improvements without changing the overall functionality. 
diff --git a/crates/eval/examples/nan_diff_handling/base.toml b/crates/eval/examples/nan_diff_handling/base.toml new file mode 100644 index 0000000000..7a046a28f4 --- /dev/null +++ b/crates/eval/examples/nan_diff_handling/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/AsyncBanana/microdiff" +revision = "ce2055948483d01fb1e96def4ab98d6339d3b2f9" +language_extension = "js" diff --git a/crates/eval/examples/nan_diff_handling/criteria.md b/crates/eval/examples/nan_diff_handling/criteria.md new file mode 100644 index 0000000000..abb19b15bc --- /dev/null +++ b/crates/eval/examples/nan_diff_handling/criteria.md @@ -0,0 +1,6 @@ +1. **NaN Comparison Logic Update**: +The diff modifies the comparison function to explicitly handle NaN values as equivalent. Previously, the function relied on string conversion for NaN comparison, but now it first checks if both values are NaN using Number.isNaN() before proceeding with other comparison logic. This change ensures consistent behavior when comparing NaN values in objects. +2. **New NaN Test Suite - Object Operations**: +A comprehensive test suite is added to verify NaN handling in object operations. The tests cover: creating new objects with NaN values, changing NaN values to other numbers, verifying no changes when NaN values remain the same, and removing properties with NaN values. Each test case validates the diff output structure and type of operation. +3. **New NaN Test Suite - Array Operations**: +The test suite extends to array operations with similar test cases as objects but adapted for array contexts. It tests: adding NaN to arrays, replacing NaN with other numbers, maintaining arrays with unchanged NaN values, and removing NaN elements from arrays. The tests ensure consistent behavior between object and array operations involving NaN values. 
diff --git a/crates/eval/examples/nan_diff_handling/prompt.md b/crates/eval/examples/nan_diff_handling/prompt.md new file mode 100644 index 0000000000..79e362c69a --- /dev/null +++ b/crates/eval/examples/nan_diff_handling/prompt.md @@ -0,0 +1 @@ +The goal of this update is to fix NaN value handling in our JavaScript object diffing functionality. Currently, the diff function fails to properly recognize that two NaN values should be treated as equal due to JavaScript's native behavior where `NaN !== NaN`. This causes incorrect change detection when comparing objects or arrays containing NaN values. The solution involves modifying the diff function to explicitly check for NaN values using `Number.isNaN()` during comparisons of object keys and values, ensuring NaN values are treated as equivalent. The implementation requires adding specific NaN equivalence checks while maintaining existing comparison logic. Additionally, comprehensive unit tests are being added to verify correct handling across various scenarios: creating objects/arrays with NaN values, changing NaN values to other values, ensuring no false positives when NaN values remain unchanged, and properly tracking removal of NaN values from both objects and arrays. This change will bring the diff behavior in line with mathematical expectations for NaN comparisons while maintaining all other existing functionality. 
diff --git a/crates/eval/examples/optimizer_schema_refactor/base.toml b/crates/eval/examples/optimizer_schema_refactor/base.toml new file mode 100644 index 0000000000..b29a97f5e8 --- /dev/null +++ b/crates/eval/examples/optimizer_schema_refactor/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/redis/redis-vl-python.git" +revision = "494e5e2f8cf800b90c7383385095c2e503404bc5" +language_extension = "py" diff --git a/crates/eval/examples/optimizer_schema_refactor/criteria.md b/crates/eval/examples/optimizer_schema_refactor/criteria.md new file mode 100644 index 0000000000..97e7c2b8ad --- /dev/null +++ b/crates/eval/examples/optimizer_schema_refactor/criteria.md @@ -0,0 +1,3 @@ +1. The changes involve renaming the `TestData` class to `LabeledData` across multiple files. This includes updating the import statements in `__init__.py`, `cache.py`, `router.py`, `schema.py`, and `utils.py` to reflect this new class name. The `__all__` list in `__init__.py` is also updated to export `LabeledData` instead of `TestData`. This appears to be a conceptual renaming to better reflect the purpose of the data structure. +2. The modifications update all function signatures and type hints that previously used `TestData` to now use `LabeledData`. This affects several functions in `cache.py` including `_generate_run_cache`, `_eval_cache`, and `_grid_search_opt_cache`, as well as functions in `router.py` like `_generate_run_router` and `_eval_router`. The utility functions in `utils.py` are also updated to work with `LabeledData` instead of `TestData`. +3. The changes introduce a new `search_step` parameter in the router optimization logic within `router.py`, with a default value of 0.10. This parameter is passed through to the `_router_random_search` function and is used in the optimization process. 
The test file `test_threshold_optimizer.py` is updated to explicitly set this parameter to 0.5 when calling the optimize method, demonstrating how it can be configured for different search granularities during threshold optimization. diff --git a/crates/eval/examples/optimizer_schema_refactor/prompt.md b/crates/eval/examples/optimizer_schema_refactor/prompt.md new file mode 100644 index 0000000000..4a4635d1e9 --- /dev/null +++ b/crates/eval/examples/optimizer_schema_refactor/prompt.md @@ -0,0 +1 @@ +I need to refactor our codebase to improve the clarity and consistency of our data model, particularly around how we handle labeled evaluation data for our threshold optimization system. Currently, the naming and structure might imply that this data is only used for testing, when in reality it represents labeled examples that power both training and evaluation. The changes should better reflect that these are curated data points with known outcomes, not just test cases. Focus on updating the core data model and ensuring all dependent components—like the cache optimizer, router, and evaluation utilities—properly reference this updated concept. The implementation should maintain all existing functionality while making the naming more semantically accurate. Where relevant, consider adding parameters to fine-tune optimization behavior, like allowing control over the granularity of threshold searches. 
diff --git a/crates/eval/examples/rate_limit_endpoints/base.toml b/crates/eval/examples/rate_limit_endpoints/base.toml new file mode 100644 index 0000000000..0a3437f288 --- /dev/null +++ b/crates/eval/examples/rate_limit_endpoints/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/matryer/goblueprints.git" +revision = "68041a598865cc3f4fa2acd4119081a2ea0826bf" +language_extension = "go" diff --git a/crates/eval/examples/rate_limit_endpoints/criteria.md b/crates/eval/examples/rate_limit_endpoints/criteria.md new file mode 100644 index 0000000000..feebae4439 --- /dev/null +++ b/crates/eval/examples/rate_limit_endpoints/criteria.md @@ -0,0 +1,12 @@ +1. The main.go changes introduce rate-limited endpoints by creating them via `MakeEndpoints` and passing them to both HTTP and gRPC servers instead of directly using the service. This includes: + - Adding endpoint creation before server startup + - Modifying HTTP server to use endpoints + - Modifying gRPC server to use endpoints +2. The server_grpc.go changes update the gRPC server implementation to use the provided endpoints instead of creating them internally. This affects both hash and validate endpoints which are now taken from the Endpoints struct rather than being created via makeHashEndpoint/makeValidateEndpoint. +3. The server_http.go changes mirror the gRPC server changes, modifying the HTTP server to use endpoints from the Endpoints struct rather than creating them internally for both hash and validate routes. +4. 
The service.go changes include: + - Renaming makeHashEndpoint to MakeHashEndpoint and making it public + - Renaming makeValidateEndpoint to MakeValidateEndpoint and making it public + - Adding new MakeEndpoints function that creates rate-limited endpoints using a token bucket (5 requests per second) + - Adding new dependencies for rate limiting (kitrl and ratelimit packages) + - The Endpoints struct remains the same but is now populated with rate-limited versions of the endpoints diff --git a/crates/eval/examples/rate_limit_endpoints/prompt.md b/crates/eval/examples/rate_limit_endpoints/prompt.md new file mode 100644 index 0000000000..91416ed7ad --- /dev/null +++ b/crates/eval/examples/rate_limit_endpoints/prompt.md @@ -0,0 +1,14 @@ +### **Request: Add Rate Limiting to Vault Service** + +We need to introduce rate limiting to our vault service to protect it from excessive traffic and ensure fair usage. The service currently handles password hashing and validation through both HTTP and gRPC, and we want to enforce a controlled request rate across all endpoints. + +#### **Key Requirements:** +- Apply a global rate limit (e.g., 5 requests per second) to prevent abuse. +- Ensure the rate limiting works consistently across both HTTP and gRPC interfaces. +- Refactor the service to cleanly support rate limiting without breaking existing functionality. +- Maintain flexibility so that limits can be adjusted if needed. + +#### **Implementation Approach (High-Level):** +- Use a token bucket or similar algorithm for smooth rate limiting. +- Integrate with our existing middleware/request pipeline. +- Keep the changes minimal but scalable for future adjustments. 
diff --git a/crates/eval/examples/request_to_axios_migration/base.toml b/crates/eval/examples/request_to_axios_migration/base.toml new file mode 100644 index 0000000000..85fdcd1569 --- /dev/null +++ b/crates/eval/examples/request_to_axios_migration/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/localtunnel/localtunnel.git" +revision = "4c136a265c2005bcb81bf47709c8ca9b634f2fc1" +language_extension = "js" diff --git a/crates/eval/examples/request_to_axios_migration/criteria.md b/crates/eval/examples/request_to_axios_migration/criteria.md new file mode 100644 index 0000000000..a7c09ce84c --- /dev/null +++ b/crates/eval/examples/request_to_axios_migration/criteria.md @@ -0,0 +1,3 @@ +1. The first change replaces the `request` module import with `axios` in Tunnel.js. This is accompanied by modifications to the request parameters where `path` and `json` fields are removed and replaced with `responseType: 'json'`. The request URI construction is also slightly modified to separate the base URI from the parameters. +2. The second chunk shows significant changes to the request handling logic in Tunnel.js. The callback-based `request` implementation is replaced with a promise-based `axios.get` approach. The error handling is restructured to use `.catch()` instead of checking for errors in the callback. The success case now extracts data from `res.data` instead of directly from the response body, and the status code check looks at `res.status` instead of `res.statusCode`. +3. The third chunk shows changes to package.json where the `request` dependency is removed and replaced with `axios` at version 0.17.1. The dependencies are also reordered, with `debug` and `openurl` moved up and `yargs` moved to the end of the list, though their versions remain unchanged. The devDependencies section remains untouched. 
diff --git a/crates/eval/examples/request_to_axios_migration/prompt.md b/crates/eval/examples/request_to_axios_migration/prompt.md new file mode 100644 index 0000000000..c5408efee6 --- /dev/null +++ b/crates/eval/examples/request_to_axios_migration/prompt.md @@ -0,0 +1 @@ +I need help modernizing the HTTP client in my Node.js tunneling service. The current implementation uses the older `request` library, which is now deprecated, and I'd like to switch to a more modern, promise-based alternative like `axios`. The changes should maintain all existing functionality—including error handling, retry logic, and response parsing—but improve readability and maintainability by using async/await or proper promise chaining where possible. The request parameters and response handling should be updated to match the new library's conventions while preserving the same behavior for downstream consumers. Additionally, ensure the package.json dependencies are updated accordingly, removing deprecated packages and cleaning up the dependency list. The core tunneling logic should remain unchanged; this is purely about updating the HTTP client layer to be more future-proof. diff --git a/crates/eval/examples/runtime_script_refactor/base.toml b/crates/eval/examples/runtime_script_refactor/base.toml new file mode 100644 index 0000000000..f354196301 --- /dev/null +++ b/crates/eval/examples/runtime_script_refactor/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/thalissonvs/pydoll.git" +revision = "9ea9e91c716b60a7cc8f11ecd865093d460f31aa" +language_extension = "py" diff --git a/crates/eval/examples/runtime_script_refactor/criteria.md b/crates/eval/examples/runtime_script_refactor/criteria.md new file mode 100644 index 0000000000..bc7d77ef58 --- /dev/null +++ b/crates/eval/examples/runtime_script_refactor/criteria.md @@ -0,0 +1,6 @@ +1. **Added RuntimeCommands import and WebElement to page.py** +The changes add an import for `RuntimeCommands` and `WebElement` to `page.py`. 
The `execute_js_script` method is renamed to `execute_script` and enhanced to support execution in the context of a WebElement. The method now uses `RuntimeCommands` for script evaluation. +2. **Refactored Runtime-related commands from DomCommands to new RuntimeCommands class** +The changes move all Runtime-related command templates and methods from `DomCommands` in `dom.py` to a new `runtime.py` file. This includes `EVALUATE_TEMPLATE`, `CALL_FUNCTION_ON_TEMPLATE`, `GET_PROPERTIES`, and their associated methods. The DomCommands class now uses RuntimeCommands for JavaScript evaluation. +3. **Added Scripts constants and enhanced WebElement functionality** +The changes add a new `Scripts` class to `constants.py` containing JavaScript snippets for common operations. The `element.py` file is significantly enhanced with new methods for script execution, visibility checking, and improved click handling. New exceptions are added to `exceptions.py` for better error handling. diff --git a/crates/eval/examples/runtime_script_refactor/prompt.md b/crates/eval/examples/runtime_script_refactor/prompt.md new file mode 100644 index 0000000000..1c1bfeb6ee --- /dev/null +++ b/crates/eval/examples/runtime_script_refactor/prompt.md @@ -0,0 +1,7 @@ +I'm looking to improve our Python web automation library (pydoll) to make it more robust and maintainable, particularly around JavaScript execution and element interactions. Currently, we need to better organize our Runtime-related commands and enhance how scripts are executed in the browser context. + +The main focus areas include creating a dedicated RuntimeCommands class to centralize all JavaScript-related operations, moving these functions out of DomCommands for cleaner separation of concerns. This new class would handle script evaluation, function calling, and property lookups. 
We should also enhance the existing page.execute_js_script method—renaming it to execute_script for clarity—and expand its functionality to support execution within specific WebElement contexts, including passing elements as arguments. + +For element interactions, we need more reliable mechanisms, particularly around clicking elements. The improvements would include visibility checks, verifying elements aren't obscured, and implementing proper error handling with descriptive exceptions when interactions fail. The current click implementation should be moved to realistic_click, while the new click method would incorporate these safety checks. Additionally, we should consolidate commonly used JavaScript snippets into a centralized Scripts class for better maintainability. + +The overall goal is to strengthen the library's reliability for automation tasks while making the codebase more organized and easier to maintain. These changes will provide better error handling, clearer structure, and more intuitive APIs for working with page elements and JavaScript execution. Would you be able to help break this down into actionable steps or suggest any improvements to this approach? diff --git a/crates/eval/examples/standardized_docker_dependency_checks/base.toml b/crates/eval/examples/standardized_docker_dependency_checks/base.toml new file mode 100644 index 0000000000..ccd751d5b5 --- /dev/null +++ b/crates/eval/examples/standardized_docker_dependency_checks/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/basecamp/kamal.git" +revision = "0174b872bfc34b66852cffb58514ae079f21d299" +language_extension = "rb" diff --git a/crates/eval/examples/standardized_docker_dependency_checks/criteria.md b/crates/eval/examples/standardized_docker_dependency_checks/criteria.md new file mode 100644 index 0000000000..c8526dab5c --- /dev/null +++ b/crates/eval/examples/standardized_docker_dependency_checks/criteria.md @@ -0,0 +1,7 @@ +1. 
The changes introduce a new `DependencyError` class in `kamal/cli.rb` alongside other error classes like `BootError` and `HookError`. This new error class will be used to handle dependency-related failures. +2. In `kamal/cli/base.rb`, a new method `ensure_docker_installed` is added which checks for Docker and buildx plugin installation locally. It raises the new `DependencyError` with appropriate messages if either Docker or buildx plugin are not found, replacing similar functionality that was previously scattered elsewhere. +3. The `kamal/cli/build.rb` file is modified to use the new `ensure_docker_installed` method instead of the removed `verify_local_dependencies` method. The error handling is now consistent, using `DependencyError` instead of `BuildError` for dependency-related failures. +4. The `kamal/cli/registry.rb` file now includes a call to `ensure_docker_installed` at the start of the login method, ensuring Docker is available before attempting registry operations. +5. The `kamal/commands/base.rb` file adds a new public method `ensure_docker_installed` that combines checks for both Docker and buildx plugin installation, moving this functionality from the Builder class. +6. The `kamal/commands/builder.rb` file is simplified by removing the `ensure_local_dependencies_installed` method and related private methods, as this functionality has been moved to the base commands class. +7. Test files are updated to reflect these changes, with `build_test.rb` now expecting `DependencyError` instead of `BuildError` for dependency failures, and `registry_test.rb` adding a new test case for Docker dependency checking during login. 
diff --git a/crates/eval/examples/standardized_docker_dependency_checks/prompt.md b/crates/eval/examples/standardized_docker_dependency_checks/prompt.md new file mode 100644 index 0000000000..b2b13cf579 --- /dev/null +++ b/crates/eval/examples/standardized_docker_dependency_checks/prompt.md @@ -0,0 +1 @@ +I need to improve how our codebase handles Docker dependency checks and error reporting. Right now, the logic for verifying Docker and buildx installations is scattered across different classes, and the error messages aren't consistent. I'd like a more unified approach where we centralize these checks in a single place, making it easier to maintain and reuse. Additionally, we should introduce a dedicated error type for dependency-related failures instead of repurposing existing errors like BuildError. The changes should ensure that any command requiring Docker (like builds or registry logins) properly validates dependencies first, with clear error messages if something is missing. The solution should be clean, follow existing patterns in the codebase, and include any necessary test updates to reflect the new behavior. diff --git a/crates/eval/examples/table_metrics_sorting/base.toml b/crates/eval/examples/table_metrics_sorting/base.toml new file mode 100644 index 0000000000..ef915651e1 --- /dev/null +++ b/crates/eval/examples/table_metrics_sorting/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/duyet/clickhouse-monitoring.git" +revision = "b8ab1a957115f41c916e7061b432ae00b1bbe7db" +language_extension = "ts" diff --git a/crates/eval/examples/table_metrics_sorting/criteria.md b/crates/eval/examples/table_metrics_sorting/criteria.md new file mode 100644 index 0000000000..8a595402eb --- /dev/null +++ b/crates/eval/examples/table_metrics_sorting/criteria.md @@ -0,0 +1,5 @@ +1. The SQL query in tables-overview.ts has been enhanced to include additional metrics for part sizes, both average and maximum. 
New fields have been added for compressed and uncompressed average part sizes with their readable formats and percentage calculations. Similarly, maximum part size metrics have been added with the same set of calculations. These additions provide more granular visibility into table partition characteristics while maintaining the existing percentage calculations relative to the maximum values across all tables. +2. The column ordering and formatting in tables-overview.ts has been updated to accommodate the new part size metrics. The new readable_avg_part_size and readable_max_part_size columns have been added to the columns array and configured with BackgroundBar formatting. The engine column has been moved to the end of the list for better grouping of related metrics. The sortingFns configuration has been added to specify custom sorting behavior for various compressed and uncompressed size columns. +3. The column definitions system has been enhanced to support custom sorting functions. A new sorting-fns.ts file has been created containing a sort_column_using_actual_value function that enables sorting based on underlying numeric values rather than formatted strings. The getColumnDefs function now checks for both custom and built-in sorting functions in the config and applies them appropriately to column definitions. +4. The data table component has been updated to include custom sorting functions in its configuration. The getCustomSortingFns function is now passed to the table's sortingFns option, making these functions available for all columns. The ValueOf utility type has been added to generic.ts to support proper typing of the sorting functions. +5. The query config type has been extended to include a new optional sortingFns property. This property allows specifying custom sorting functions for specific columns in the table configuration. 
The type imports have been reorganized, and CustomSortingFnNames is now properly imported and used in the QueryConfig interface. diff --git a/crates/eval/examples/table_metrics_sorting/prompt.md b/crates/eval/examples/table_metrics_sorting/prompt.md new file mode 100644 index 0000000000..903c4ad001 --- /dev/null +++ b/crates/eval/examples/table_metrics_sorting/prompt.md @@ -0,0 +1 @@ +I need to enhance our data table functionality to support more advanced sorting capabilities, particularly for columns that display formatted values (like readable sizes or percentages) but should sort based on their underlying raw numeric values. The table should also include additional metrics for average and maximum part sizes (both compressed and uncompressed) to give better insights into table storage characteristics. These new metrics should follow the same pattern as existing columns, with formatted readable versions, percentage calculations relative to the dataset maximum, and proper sorting behavior. The sorting system should be flexible enough to support both custom sorting logic (like comparing raw numbers behind formatted strings) and built-in sorting methods, with a clean way to configure which columns use which sorting approach. The implementation should maintain consistency with our existing column formatting system and integrate smoothly with the React Table setup we already have in place. 
diff --git a/crates/eval/examples/tax_id_validation/base.toml b/crates/eval/examples/tax_id_validation/base.toml new file mode 100644 index 0000000000..a3cff0bbbf --- /dev/null +++ b/crates/eval/examples/tax_id_validation/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/go-playground/validator.git" +revision = "4676b8e43bb907ef07f3bcc4ae2a218b05d60397" +language_extension = "go" diff --git a/crates/eval/examples/tax_id_validation/criteria.md b/crates/eval/examples/tax_id_validation/criteria.md new file mode 100644 index 0000000000..3b26ca1812 --- /dev/null +++ b/crates/eval/examples/tax_id_validation/criteria.md @@ -0,0 +1,3 @@ +1. Documentation updates in README.md, where a new validation type for Employer Identification Numbers (EIN) was added to the supported validators table. This addition was carefully positioned between the existing "e164" phone number format and "email" validators to maintain alphabetical ordering. The entry follows the established table format with pipe-separated columns and includes a clear description indicating its purpose for validating U.S. Employer Identification Numbers. Notably, this change was made without modifying any of the existing documentation entries, preserving all current validator descriptions while expanding the supported validation types. +2. Core implementation of the EIN validation across multiple files. In baked_in.go, this involved adding an "ein" entry to the validator map that points to a newly created isEIN function, following the same pattern as other validator registrations. The isEIN() function itself implements the validation logic, checking for both length requirements (exactly 10 characters) and pattern matching using a new regular expression. The regexes.go file was updated with a new einRegexString constant defining the EIN pattern (##-#######) and corresponding regex variable initialization, utilizing the existing lazyRegexCompile helper function for consistency. 
Documentation was added in doc.go following the established format for validator descriptions, complete with a simple usage example. Throughout these changes, careful attention was paid to maintain consistent error handling patterns and code organization while removing unnecessary newlines in several functions to improve readability. +3. Testing improvements and code quality enhancements, primarily in validator_test.go. A comprehensive TestEINStringValidation test case was added, covering various valid and invalid EIN formats, including tests for length requirements and hyphen positioning. This new test follows the same structure and assertion patterns as existing validation tests. Numerous code quality improvements were made throughout the test file, including grouping interface declarations, fixing comment formatting, removing unnecessary newlines in struct declarations, correcting indentation in test cases, and adding missing newlines between tests. These changes significantly improved code readability while maintaining all existing test logic and ensuring backward compatibility. The improvements demonstrate careful attention to maintaining consistent patterns throughout the test suite while adding thorough test coverage for the new EIN validation functionality. diff --git a/crates/eval/examples/tax_id_validation/prompt.md b/crates/eval/examples/tax_id_validation/prompt.md new file mode 100644 index 0000000000..e1a8ed4d5f --- /dev/null +++ b/crates/eval/examples/tax_id_validation/prompt.md @@ -0,0 +1,10 @@ + +Add validation support for Employer Identification Numbers (EIN) to the Go validator library + +I need to implement a new validator function for US Employer Identification Numbers (EIN) in this Go validation library. The EIN validator should: + +1. Create a new tag called "ein" that validates if a string is a valid US Employer Identification Number +2. Follow the pattern of ##-#######, where # is a digit (regex pattern would be ^(\d{2}-\d{7})$) +3. 
Ensure the field contains exactly 10 characters (including the hyphen) +4. Document the new validator in the README.md and doc.go files +5. Add proper unit tests to verify validation works correctly for valid and invalid EINs diff --git a/crates/eval/examples/test_infrastructure/base.toml b/crates/eval/examples/test_infrastructure/base.toml new file mode 100644 index 0000000000..2a4fe2d3dd --- /dev/null +++ b/crates/eval/examples/test_infrastructure/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/dagster-io/dagster.git" +revision = "c9ed914a76baa6fb761a97f3236f96cd7d5361e6" +language_extension = "py" diff --git a/crates/eval/examples/test_infrastructure/criteria.md b/crates/eval/examples/test_infrastructure/criteria.md new file mode 100644 index 0000000000..0cdfe6b394 --- /dev/null +++ b/crates/eval/examples/test_infrastructure/criteria.md @@ -0,0 +1,3 @@ +1. Introduce a new docker-compose.yml file in the integration tests directory for the monitoring daemon test suite. This file defines two services: a PostgreSQL database with test credentials exposed on port 5432, and a localstack S3 service exposed on port 4566. These services provide the necessary infrastructure for running the monitoring tests. +2. Shows significant modifications to the test_monitoring.py file, including new imports (boto3, Path, and docker_compose_cm), removal of the dagster_aws tests import, and the addition of new fixtures. The new fixtures handle docker-compose setup, provide hostnames for services, configure AWS environment variables with test credentials, and initialize an S3 bucket for testing purposes. The changes reflect a shift from using external AWS credentials to using localstack for S3 testing. +3. Reveals structural changes to the test file, where the aws_env fixture has been moved from the bottom of the file to be grouped with other fixtures. 
The original implementation that relied on get_aws_creds() has been replaced with a new implementation that uses localstack with hardcoded test credentials, and the test_docker_monitoring_run_out_of_attempts function remains at the end of the file but now uses the new aws_env fixture implementation. diff --git a/crates/eval/examples/test_infrastructure/prompt.md b/crates/eval/examples/test_infrastructure/prompt.md new file mode 100644 index 0000000000..7428d6b362 --- /dev/null +++ b/crates/eval/examples/test_infrastructure/prompt.md @@ -0,0 +1 @@ +Refactor the monitoring daemon integration tests to use local Docker-managed dependencies instead of direct AWS dependencies. First, create a docker-compose.yml file with two services: a PostgreSQL container with test credentials exposed on port 5432, and a LocalStack S3 container exposed on port 4566. Next, modify the test file to remove reliance on external AWS credentials and replace them with fixtures that configure a LocalStack S3 mock. The fixtures should include session-scoped setup for hostnames, PostgreSQL connections, and AWS environment variables with hardcoded test credentials (e.g., fake access keys). Ensure the S3 fixture initializes a test bucket. Move the AWS environment fixture to align with other fixtures and update the test logic to use the new LocalStack endpoint URL, handling both local and Buildkite environments. Keep the core test cases (like monitoring run attempts) intact but adapt them to use the new Docker-based dependencies. 
diff --git a/crates/eval/examples/tool_response_handling/base.toml b/crates/eval/examples/tool_response_handling/base.toml new file mode 100644 index 0000000000..cd499cefb3 --- /dev/null +++ b/crates/eval/examples/tool_response_handling/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/block/goose.git" +revision = "d7308457fe3f1b9c7253de45b2f81ddc4f005fe5" +language_extension = "rs" diff --git a/crates/eval/examples/tool_response_handling/criteria.md b/crates/eval/examples/tool_response_handling/criteria.md new file mode 100644 index 0000000000..9aaaa83b43 --- /dev/null +++ b/crates/eval/examples/tool_response_handling/criteria.md @@ -0,0 +1,3 @@ +1. All Goose packages (`goose`, `goose-bench`, `goose-cli`, `goose-mcp`, `goose-server`) were updated from version `1.0.17` to `1.0.18` in `Cargo.lock`. These updates ensure compatibility and consistency across related packages. +2. The `goose-app` version in `ui/desktop/package-lock.json` was also updated to `1.0.18`, maintaining alignment with the backend and shared libraries. +3. In `App.tsx`, the `useConfig` hook was destructured to directly use `addExtension` instead of the older `addExtensionToConfig` function. All occurrences of the old function name were updated, including inside effects and async calls, to use the new unified method. This change simplifies extension handling logic while preserving current behavior. diff --git a/crates/eval/examples/tool_response_handling/prompt.md b/crates/eval/examples/tool_response_handling/prompt.md new file mode 100644 index 0000000000..3358ad6eec --- /dev/null +++ b/crates/eval/examples/tool_response_handling/prompt.md @@ -0,0 +1 @@ +Upgrade all Goose-related packages and apps from version 1.0.17 to 1.0.18 throughout the codebase. This includes updating version references in Cargo.lock, package-lock.json, and source files where applicable. 
In addition, streamline the addExtension logic in App.tsx by removing the outdated addExtensionToConfig references and replacing them with the new unified addExtension function. Ensure that all function dependencies and hooks reflect this updated usage. The goal is to improve maintainability and consistency across the codebase without introducing any functional changes. diff --git a/crates/eval/examples/toolbar_endpoints/base.toml b/crates/eval/examples/toolbar_endpoints/base.toml new file mode 100644 index 0000000000..016078cbdf --- /dev/null +++ b/crates/eval/examples/toolbar_endpoints/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/django-cms/django-cms.git" +revision = "0b775f27300c4347be18a5bb7b1b172d6a943ccf" +language_extension = "py" diff --git a/crates/eval/examples/toolbar_endpoints/criteria.md b/crates/eval/examples/toolbar_endpoints/criteria.md new file mode 100644 index 0000000000..cc2aba9271 --- /dev/null +++ b/crates/eval/examples/toolbar_endpoints/criteria.md @@ -0,0 +1,3 @@ +1. The changes add two new URL patterns ('cms_placeholder_add_plugin' and 'cms_placeholder_edit_plugin') to the list of endpoints in the toolbar middleware configuration. These endpoints will now be recognized by the toolbar system. +2. The changes add test cases for the new toolbar endpoints in the test file. The first test case verifies that the toolbar is properly attached to requests for the 'cms_placeholder_add_plugin' admin endpoint. The test creates a mock request and checks that the toolbar attribute is present after middleware processing. +3. The changes include a second test case that verifies toolbar functionality for the 'cms_placeholder_edit_plugin' admin endpoint. Similar to the first test, it creates a mock request with plugin ID (1) and checks for the presence of the toolbar attribute after middleware processing. This maintains consistency with the existing test for 'cms_placeholder_clear_placeholder'. 
diff --git a/crates/eval/examples/toolbar_endpoints/prompt.md b/crates/eval/examples/toolbar_endpoints/prompt.md new file mode 100644 index 0000000000..1291e80a78 --- /dev/null +++ b/crates/eval/examples/toolbar_endpoints/prompt.md @@ -0,0 +1,3 @@ +I'm working on improving the Django CMS toolbar middleware to better support plugin management functionality. Currently, the toolbar is only enabled for specific views defined in the `TOOLBAR_URL_PREFIXES` within toolbar.py, but I've noticed we're missing support for two critical plugin-related operations: adding and editing plugins through the `cms_placeholder_add_plugin` and `cms_placeholder_edit_plugin` views. These views should have access to the toolbar object just like our other administrative actions, as they're fundamental to the content editing experience. + +To implement this enhancement, we'll need to make two key changes. First, we should add both 'cms_placeholder_add_plugin' and 'cms_placeholder_edit_plugin' to the allowed URL prefixes list in cms/middleware/toolbar.py. Second, we should expand our test coverage in cms/tests/test_toolbar.py to verify that the toolbar object is properly attached to requests hitting these endpoints, maintaining consistency with how we test other toolbar-enabled views. This change will ensure a more complete and reliable toolbar experience throughout the entire plugin management workflow. 
diff --git a/crates/eval/examples/war_and_uri_corrections/base.toml b/crates/eval/examples/war_and_uri_corrections/base.toml new file mode 100644 index 0000000000..bcdfeb9614 --- /dev/null +++ b/crates/eval/examples/war_and_uri_corrections/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/jetty/jetty.project.git" +revision = "dc685b6f84e94ad2eb6a3930769e6eab0cab3fa6" +language_extension = "java" diff --git a/crates/eval/examples/war_and_uri_corrections/criteria.md b/crates/eval/examples/war_and_uri_corrections/criteria.md new file mode 100644 index 0000000000..1e263fd59f --- /dev/null +++ b/crates/eval/examples/war_and_uri_corrections/criteria.md @@ -0,0 +1,7 @@ +1. The changes add an import for `URIUtil` and modify the URL creation in `OSGiApp.java` to use `URIUtil.correctURI()` for proper URI handling. The modification ensures correct URI formatting before converting to URL. +2. The changes add an import for `URIUtil` and modify the URI creation in `Util.java` to use `URIUtil.correctURI()` when handling file paths. This ensures proper URI formatting for paths starting with "file:/". +3. The changes in both `WebInfConfiguration.java` files (EE10 and EE9 versions) refactor the war file handling logic. The modifications: + - Add explanatory comments about looking for sibling directories + - Change how the war path is obtained (using webApp.getPath() instead of creating new resources) + - Restructure the conditional logic for better clarity + - Maintain the same functionality but with improved safety checks and documentation diff --git a/crates/eval/examples/war_and_uri_corrections/prompt.md b/crates/eval/examples/war_and_uri_corrections/prompt.md new file mode 100644 index 0000000000..3c0ac029df --- /dev/null +++ b/crates/eval/examples/war_and_uri_corrections/prompt.md @@ -0,0 +1,7 @@ +I’m working on improvements to a Jetty OSGi application’s file path handling and deployment logic. The changes focus on two main areas: URI normalization and WAR file extraction. 
+ +First, the URI handling logic needs updates to ensure consistent formatting, particularly when dealing with file paths. Currently, there are cases where paths aren’t properly normalized, especially when converting between file URIs and URLs. This affects both core OSGi resource resolution and utility methods that process path strings. The goal is to apply systematic corrections so that paths are reliably formatted across different scenarios. + +Second, the WAR file extraction process requires refinement to make it more robust. The current implementation checks for pre-extracted sibling directories, but the logic could be strengthened by using the resolved webApp path directly rather than reconstructing it from strings. Additionally, the code would benefit from clearer documentation and added safeguards to handle edge cases gracefully. These changes will apply to both the EE9 and EE10 WebApp configurations, ensuring consistent behavior across versions. + +The overarching aim is to reduce deployment failures and improve maintainability while keeping the changes backward-compatible. diff --git a/crates/eval/examples/window_title_support/base.toml b/crates/eval/examples/window_title_support/base.toml new file mode 100644 index 0000000000..3b0e37d2c2 --- /dev/null +++ b/crates/eval/examples/window_title_support/base.toml @@ -0,0 +1,3 @@ +url = "https://github.com/charmbracelet/bubbletea.git" +revision = "bc1c475eb0263aba13ef430f191677e153dc0320" +language_extension = "go" diff --git a/crates/eval/examples/window_title_support/criteria.md b/crates/eval/examples/window_title_support/criteria.md new file mode 100644 index 0000000000..d64b009b5a --- /dev/null +++ b/crates/eval/examples/window_title_support/criteria.md @@ -0,0 +1,4 @@ +1. Adds a new `setWindowTitle` method to the `standardRenderer` struct that sets the terminal window title using the OSC 0 escape sequence. 
It includes thread safety with mutex locking and uses fmt.Fprintf to send the escape sequence with the provided title. +2. Modifies the `handleMessages` method in `standardRenderer` to handle a new `setWindowTitleMsg` message type by calling the new `setWindowTitle` method. This completes the rendering-side implementation for window title updates. +3. Updates the event loop in the Program struct to properly handle `setWindowTitleMsg` messages by passing them through to the renderer without additional processing, similar to other renderer-specific messages. +4. Adds documentation to the commands tutorial README explaining how to set window titles in Bubble Tea applications. It shows examples of using `tea.SetWindowTitle()` in both Init and Update methods, and explains its usefulness for reflecting application state in the window title. diff --git a/crates/eval/examples/window_title_support/prompt.md b/crates/eval/examples/window_title_support/prompt.md new file mode 100644 index 0000000000..f47418f8a3 --- /dev/null +++ b/crates/eval/examples/window_title_support/prompt.md @@ -0,0 +1,11 @@ +I’d like to add the ability to set terminal window titles in our Bubble Tea framework. This would let applications dynamically update the title bar (e.g., to show status or app names). + +Requirements: + +Expose a user-friendly way to set titles (e.g., a SetWindowTitle command). +Ensure it works cross-platform with standard terminal escape codes. +Include a minimal example and docs showing usage. +Constraints: + +Follow existing patterns for commands/messages. +Thread-safe rendering. 
diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 88cca63852..eb2c07e0dc 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -1,32 +1,80 @@ mod example; use assistant_settings::AssistantSettings; -use client::{Client, UserStore}; +use client::{Client, ProxySettings, UserStore}; pub(crate) use example::*; use ::fs::RealFs; -use anyhow::anyhow; -use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task}; +use anyhow::{Result, anyhow}; +use clap::Parser; +use extension::ExtensionHostProxy; +use futures::future; +use gpui::http_client::{Uri, read_proxy_from_env}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task}; +use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; -use node_runtime::NodeRuntime; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; use project::Project; +use project::project_settings::ProjectSettings; use prompt_store::PromptBuilder; +use release_channel::AppVersion; use reqwest_client::ReqwestClient; use settings::{Settings, SettingsStore}; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use util::ResultExt as _; + +pub const RUNS_DIR: &str = "./crates/eval/runs"; + +#[derive(Parser, Debug)] +#[command(name = "eval", disable_version_flag = true)] +struct Args { + /// Runs all examples that contain these substrings. If unspecified, all examples are run. + #[arg(value_name = "EXAMPLE_SUBSTRING")] + examples: Vec, + /// Model to use (default: "claude-3-7-sonnet-latest") + #[arg(long, default_value = "claude-3-7-sonnet-latest")] + model: String, + /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run. 
+ #[arg(long, value_delimiter = ',')] + languages: Option>, +} fn main() { env_logger::init(); + + let args = Args::parse(); + let all_available_examples = list_all_examples().unwrap(); + let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]); + + let example_paths = all_available_examples + .iter() + .filter_map(|example_path| { + let name = example_path.file_name()?.to_string_lossy(); + if args.examples.is_empty() + || args + .examples + .iter() + .any(|name_substring| name.contains(name_substring)) + { + Some(example_path.clone()) + } else { + None + } + }) + .collect::>(); + let http_client = Arc::new(ReqwestClient::new()); let app = Application::headless().with_http_client(http_client.clone()); app.run(move |cx| { let app_state = init(cx); - let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap(); + let model = find_model("claude-3-7-sonnet-latest", cx).unwrap(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model(Some(model.clone()), cx); @@ -39,17 +87,155 @@ fn main() { cx.spawn(async move |cx| { authenticate.await.unwrap(); - let example = - Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?; - example.setup()?; - cx.update(|cx| example.run(model, app_state, cx))?.await?; + std::fs::create_dir_all(REPOS_DIR)?; + std::fs::create_dir_all(WORKTREES_DIR)?; - anyhow::Ok(()) + let run_dir = Path::new(RUNS_DIR).join(format!( + "{}", + chrono::Local::now().format("%Y-%m-%d_%H-%M-%S") + )); + std::fs::create_dir_all(&run_dir)?; + + let mut examples = Vec::new(); + for example_path in example_paths { + let example = Example::load_from_directory(&example_path, &run_dir)?; + + if !example + .base + .language_extension + .as_ref() + .map_or(false, |lang| languages.contains(lang)) + { + println!("Skipping {}", example.name); + continue; + } + + println!("{}> Logging to {:?}", example.name, example.log_file_path); + + examples.push(example); + } + let mut repo_urls = 
HashSet::new(); + + let mut clone_tasks = Vec::new(); + + for example in examples.iter() { + let repo_url = example.base.url.clone(); + if repo_urls.insert(repo_url.clone()) { + let repo_path = repo_path_for_url(&repo_url); + + if !repo_path.join(".git").is_dir() { + println!("Cloning: {}", repo_url); + + let git_task = cx.spawn(async move |_cx| { + std::fs::create_dir_all(&repo_path)?; + run_git(&repo_path, &["init"]).await?; + run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await + }); + + clone_tasks.push(git_task); + } else { + println!("Already cloned: {}", repo_url); + + let actual_origin = + run_git(&repo_path, &["remote", "get-url", "origin"]).await?; + if actual_origin != repo_url { + return Err(anyhow!( + "remote origin {} does not match expected origin {}", + actual_origin, + repo_url, + )); + } + } + } + } + + future::join_all(clone_tasks).await; + + for example in examples.iter() { + example.setup().await?; + } + + let tasks = examples + .into_iter() + .map(|example| { + let app_state = app_state.clone(); + let model = model.clone(); + cx.spawn(async move |cx| { + (run_example(&example, model, app_state, cx).await, example) + }) + }) + .collect::>(); + + let results: Vec<(Result, Example)> = future::join_all(tasks).await; + + println!("\n\n"); + println!("========================================"); + println!(" EVAL RESULTS "); + println!("========================================"); + println!(""); + + let mut judge_scores = Vec::new(); + + for (result, example) in results { + println!("📜 {:<30}: {:?}", example.name, example.log_file_path); + match result { + Err(err) => { + println!("💥 {:<30}: {:?}", example.name, err); + } + Ok(judge_output) => { + const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"]; + + println!( + "{} {:<30}: {}", + SCORES[judge_output.score.min(5) as usize], + example.name, + judge_output.score, + ); + judge_scores.push(judge_output.score); + } + } + } + + let score_count = judge_scores.len(); + let average_score 
= judge_scores + .into_iter() + .map(|score| score as f32) + .sum::() + / (score_count as f32); + println!("\nAverage score: {average_score}"); + + cx.update(|cx| cx.quit()) }) .detach_and_log_err(cx); }); } +async fn run_example( + example: &Example, + model: Arc, + app_state: Arc, + cx: &mut AsyncApp, +) -> Result { + cx.update(|cx| example.run(model.clone(), app_state, cx))? + .await?; + let diff = example.repository_diff().await?; + example.judge(model, diff, cx).await +} + +fn list_all_examples() -> Result> { + let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap(); + let entries = std::fs::read_dir(path).unwrap(); + let mut result_paths = Vec::new(); + for entry in entries { + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + result_paths.push(path); + } + } + Ok(result_paths) +} + /// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. pub struct AgentAppState { pub languages: Arc, @@ -72,6 +258,27 @@ pub fn init(cx: &mut App) -> Arc { .unwrap(); cx.set_global(settings_store); client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + AppVersion::global(cx), + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse::().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + Project::init_settings(cx); let client = Client::production(cx); @@ -83,13 +290,47 @@ pub fn init(cx: &mut App) -> Arc { cx.background_executor().clone(), )); - let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + 
languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + extension::init(cx); + + let (tx, rx) = async_watch::channel(None); + cx.observe_global::(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(), + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client().clone(), rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + language::init(cx); + language_extension::init(extension_host_proxy.clone(), languages.clone()); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); assistant_tools::init(client.http_client().clone(), cx); context_server::init(cx); let stdout_is_a_pty = false; @@ -109,7 +350,7 @@ pub fn init(cx: &mut App) -> Arc { client, user_store, fs, - node_runtime: NodeRuntime::unavailable(), + node_runtime, prompt_builder, }) } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 5d1e4b72a6..de8f1d241f 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -1,83 +1,170 @@ use agent::{RequestKind, ThreadEvent, ThreadStore}; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use assistant_tool::ToolWorkingSet; +use 
client::proto::LspWorkProgress; +use collections::HashMap; use dap::DapRegistry; -use futures::channel::oneshot; -use gpui::{App, Task}; -use language_model::{LanguageModel, StopReason}; -use project::Project; -use serde::Deserialize; -use std::process::Command; -use std::sync::Arc; +use futures::channel::mpsc; +use futures::{FutureExt, StreamExt as _, select_biased}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; +use handlebars::Handlebars; +use language::{DiagnosticSeverity, OffsetRangeExt}; +use language_model::{ + LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, + StopReason, TokenUsage, +}; +use project::{LspStore, Project, ProjectPath}; +use serde::{Deserialize, Serialize}; +use std::fmt::Write as _; +use std::fs::File; +use std::io::Write as _; +use std::sync::{Arc, Mutex}; +use std::time::Duration; use std::{ fs, path::{Path, PathBuf}, }; +use unindent::Unindent as _; +use util::ResultExt as _; +use util::command::new_smol_command; +use util::serde::default_true; use crate::AgentAppState; -#[derive(Debug, Deserialize)] +pub const EXAMPLES_DIR: &str = "./crates/eval/examples"; +pub const REPOS_DIR: &str = "./crates/eval/repos"; +pub const WORKTREES_DIR: &str = "./crates/eval/worktrees"; + +const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); + +#[derive(Clone, Debug, Deserialize)] pub struct ExampleBase { - pub path: PathBuf, + pub url: String, pub revision: String, + pub language_extension: Option, + pub insert_id: Option, + #[serde(default = "default_true")] + pub require_lsp: bool, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Example { + pub name: String, + /// Content of `base.toml` pub base: ExampleBase, - - /// Content of the prompt.md file + /// Content of `prompt.md` pub prompt: String, + /// Content of `criteria.md` + pub criteria: String, + /// Markdown log file to append to + pub log_file: Arc>, + /// Path to markdown log file + pub log_file_path: PathBuf, +} - /// 
Content of the rubric.md file - pub _rubric: String, +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RunOutput { + pub repository_diff: String, + pub diagnostics: String, + pub response_count: usize, + pub token_usage: TokenUsage, + pub tool_use_counts: HashMap, u32>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeInput { + pub repository_diff: String, + pub criteria: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeOutput { + pub analysis: String, + pub score: u32, } impl Example { - /// Load an example from a directory containing base.toml, prompt.md, and rubric.md - pub fn load_from_directory>(dir_path: P) -> Result { - let base_path = dir_path.as_ref().join("base.toml"); - let prompt_path = dir_path.as_ref().join("prompt.md"); - let rubric_path = dir_path.as_ref().join("rubric.md"); + /// Load an example from a directory containing base.toml, prompt.md, and criteria.md + pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result { + let name = dir_path.file_name().unwrap().to_string_lossy().to_string(); + let base_path = dir_path.join("base.toml"); + let prompt_path = dir_path.join("prompt.md"); + let criteria_path = dir_path.join("criteria.md"); - let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; - base.path = base.path.canonicalize()?; + let log_file_path = run_dir.join(format!( + "{}.md", + dir_path.file_name().unwrap().to_str().unwrap() + )); + let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap())); + println!("{}> Logging to {:?}", name, log_file_path); Ok(Example { - base, - prompt: fs::read_to_string(prompt_path)?, - _rubric: fs::read_to_string(rubric_path)?, + name, + base: toml::from_str(&fs::read_to_string(&base_path)?)?, + prompt: fs::read_to_string(prompt_path.clone())?, + criteria: fs::read_to_string(criteria_path.clone())?, + log_file, + log_file_path, }) } - /// Set up the example by checking out the specified Git 
revision - pub fn setup(&self) -> Result<()> { - // Check if the directory exists - let path = Path::new(&self.base.path); - anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path); + pub fn worktree_path(&self) -> PathBuf { + Path::new(WORKTREES_DIR) + .canonicalize() + .context(format!("No such directory {WORKTREES_DIR}")) + .unwrap() + .join(&self.name) + } - // Change to the project directory and checkout the specified revision - let output = Command::new("git") - .current_dir(&self.base.path) - .arg("checkout") - .arg(&self.base.revision) - .output()?; - anyhow::ensure!( - output.status.success(), - "Failed to checkout revision {}: {}", - self.base.revision, - String::from_utf8_lossy(&output.stderr), - ); + /// Set up the example by checking out the specified Git revision + pub async fn setup(&self) -> Result<()> { + let repo_path = repo_path_for_url(&self.base.url); + + println!("{}> Fetching", self.name); + + run_git( + &repo_path, + &["fetch", "--depth", "1", "origin", &self.base.revision], + ) + .await?; + + let worktree_path = self.worktree_path(); + + if worktree_path.is_dir() { + println!("{}> Resetting existing worktree", self.name); + + // TODO: consider including "-x" to remove ignored files. The downside of this is that + // it will also remove build artifacts, and so prevent incremental reuse there. 
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", &self.base.revision]).await?; + } else { + println!("{}> Creating worktree", self.name); + + let worktree_path_string = worktree_path.to_string_lossy().to_string(); + + run_git( + &repo_path, + &[ + "worktree", + "add", + "-f", + &worktree_path_string, + &self.base.revision, + ], + ) + .await?; + } Ok(()) } pub fn run( - self, + &self, model: Arc, app_state: Arc, cx: &mut App, - ) -> Task> { + ) -> Task> { let project = Project::local( app_state.client.clone(), app_state.node_runtime.clone(), @@ -89,91 +176,504 @@ impl Example { cx, ); + let worktree_path = self.worktree_path(); let worktree = project.update(cx, |project, cx| { - project.create_worktree(self.base.path, true, cx) + project.create_worktree(&worktree_path, true, cx) }); - let tools = Arc::new(ToolWorkingSet::default()); + let tools = cx.new(|_| ToolWorkingSet::default()); let thread_store = ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + let this = self.clone(); - println!("USER:"); - println!("{}", self.prompt); - println!("ASSISTANT:"); cx.spawn(async move |cx| { - worktree.await?; + let worktree = worktree.await?; + + // Wait for worktree scan to finish before choosing a file to open. + worktree + .update(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let lsp_open_handle_and_store = if this.base.require_lsp { + let language_extension = this.base.language_extension.as_deref().context( + "language_extension field is required in base.toml when `require_lsp == true`", + )?; + + // Open a file that matches the language to cause LSP to start. 
+ let language_file = worktree.read_with(cx, |worktree, _cx| { + worktree + .files(false, 0) + .find_map(|e| { + if e.path.clone().extension().and_then(|ext| ext.to_str()) + == Some(language_extension) + { + Some(ProjectPath { + worktree_id: worktree.id(), + path: e.path.clone(), + }) + } else { + None + } + }) + .context("Failed to find a file for example language") + })??; + + let open_language_file_buffer_task = project.update(cx, |project, cx| { + project.open_buffer(language_file.clone(), cx) + })?; + + let language_file_buffer = open_language_file_buffer_task.await?; + + let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| { + ( + project.register_buffer_with_language_servers(&language_file_buffer, cx), + project.lsp_store().clone(), + ) + })?; + + // TODO: remove this once the diagnostics tool waits for new diagnostics + cx.background_executor().timer(Duration::new(5, 0)).await; + wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?; + + lsp_store.update(cx, |lsp_store, cx| { + lsp_open_handle.update(cx, |buffer, cx| { + buffer.update(cx, |buffer, cx| { + let has_language_server = lsp_store + .language_servers_for_local_buffer(buffer, cx) + .next() + .is_some(); + if has_language_server { + Ok(()) + } else { + Err(anyhow!( + "`{:?}` was opened to cause the language server to start, \ + but no language servers are registered for its buffer. 
\ + Set `require_lsp = false` in `base.toml` to skip this.", + language_file + )) + } + }) + }) + })??; + + Some((lsp_open_handle, lsp_store)) + } else { + None + }; + + if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { + return Err(anyhow!("Setup only mode")); + } + let thread_store = thread_store.await; let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; - let (tx, rx) = oneshot::channel(); - let mut tx = Some(tx); + { + let mut log_file = this.log_file.lock().unwrap(); + writeln!(&mut log_file, "👤 USER:").log_err(); + writeln!(&mut log_file, "{}", this.prompt).log_err(); + writeln!(&mut log_file, "🤖 ASSISTANT:").log_err(); + log_file.flush().log_err(); + } - let _subscription = - cx.subscribe( - &thread, - move |thread, event: &ThreadEvent, cx| match event { - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn) => { - if let Some(tx) = tx.take() { - tx.send(Ok(())).ok(); + let tool_use_counts: Arc, u32>>> = + Mutex::new(HashMap::default()).into(); + + let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded(); + + let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| { + thread_event_tx.unbounded_send(event.clone()).log_err(); + }); + + let event_handler_task = cx.spawn({ + let log_file = this.log_file.clone(); + let name = this.name.clone(); + let tool_use_counts = tool_use_counts.clone(); + let thread = thread.downgrade(); + async move |cx| { + loop { + let event = select_biased! 
{ + event = thread_event_rx.next() => event, + _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { + return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT)); + } + }; + let Some(event) = event else { + return Err(anyhow!("ThreadEvent channel ended early")); + }; + + let mut log_file = log_file.lock().unwrap(); + + match event { + ThreadEvent::Stopped(reason) => match reason { + Ok(StopReason::EndTurn) => { + return Ok(()); + } + Ok(StopReason::MaxTokens) => { + return Err(anyhow!("Exceeded maximum tokens")); + } + Ok(StopReason::ToolUse) => {} + Err(error) => { + return Err(anyhow!(error.clone())); + } + }, + ThreadEvent::ShowError(thread_error) => { + break Err(anyhow!(thread_error.clone())); + } + ThreadEvent::StreamedAssistantText(_, chunk) => { + write!(&mut log_file, "{}", chunk).log_err(); + } + ThreadEvent::StreamedAssistantThinking(_, chunk) => { + write!(&mut log_file, "{}", chunk).log_err(); + } + ThreadEvent::UsePendingTools { tool_uses } => { + writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err(); + for tool_use in tool_uses { + writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input) + .log_err(); } } - Ok(StopReason::MaxTokens) => { - if let Some(tx) = tx.take() { - tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok(); + ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + .. 
+ } => { + if let Some(tool_use) = pending_tool_use { + let message = format!("TOOL FINISHED: {}", tool_use.name); + println!("{name}> {message}"); + writeln!(&mut log_file, "\n{}", message).log_err(); } + thread.update(cx, |thread, _cx| { + if let Some(tool_result) = thread.tool_result(&tool_use_id) { + writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err(); + let mut tool_use_counts = tool_use_counts.lock().unwrap(); + *tool_use_counts + .entry(tool_result.tool_name.clone()) + .or_insert(0) += 1; + } + })?; } - Ok(StopReason::ToolUse) => {} - Err(error) => { - if let Some(tx) = tx.take() { - tx.send(Err(anyhow!(error.clone()))).ok(); - } - } - }, - ThreadEvent::ShowError(thread_error) => { - if let Some(tx) = tx.take() { - tx.send(Err(anyhow!(thread_error.clone()))).ok(); - } + _ => {} } - ThreadEvent::StreamedAssistantText(_, chunk) => { - print!("{}", chunk); - } - ThreadEvent::StreamedAssistantThinking(_, chunk) => { - print!("{}", chunk); - } - ThreadEvent::UsePendingTools { tool_uses } => { - println!("\n\nUSING TOOLS:"); - for tool_use in tool_uses { - println!("{}: {}", tool_use.name, tool_use.input); - } - } - ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - .. 
- } => { - if let Some(tool_use) = pending_tool_use { - println!("\nTOOL FINISHED: {}", tool_use.name); - } - if let Some(tool_result) = thread.read(cx).output_for_tool(tool_use_id) - { - println!("\n{}\n", tool_result); - } - } - _ => {} - }, - )?; + log_file.flush().log_err(); + } + } + }); thread.update(cx, |thread, cx| { let context = vec![]; - thread.insert_user_message(self.prompt.clone(), context, None, cx); + thread.insert_user_message(this.prompt.clone(), context, None, cx); thread.send_to_model(model, RequestKind::Chat, cx); })?; - rx.await??; + event_handler_task.await?; - Ok(()) + if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() { + wait_for_lang_server(lsp_store, this.name.clone(), cx).await?; + } + + let repository_diff = this.repository_diff().await?; + let diagnostics = cx + .update(move |cx| { + cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) + })? + .await?; + + drop(subscription); + drop(lsp_open_handle_and_store); + + thread.update(cx, |thread, _cx| { + let response_count = thread + .messages() + .filter(|message| message.role == language_model::Role::Assistant) + .count(); + RunOutput { + repository_diff, + diagnostics, + response_count, + token_usage: thread.cumulative_token_usage(), + tool_use_counts: tool_use_counts.lock().unwrap().clone(), + } + }) }) } + + pub async fn judge( + &self, + model: Arc, + repository_diff: String, + cx: &AsyncApp, + ) -> Result { + let judge_prompt = include_str!("judge_prompt.hbs"); + let judge_prompt_name = "judge_prompt"; + let mut handlebars = Handlebars::new(); + handlebars.register_template_string(judge_prompt_name, judge_prompt)?; + let prompt = handlebars.render( + judge_prompt_name, + &JudgeInput { + repository_diff, + criteria: self.criteria.clone(), + }, + )?; + + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(prompt)], + cache: false, + }], + temperature: None, + tools: 
Vec::new(), + stop: Vec::new(), + }; + + let response = send_language_model_request(model, request, cx).await?; + + let mut log_file = self.log_file.lock().unwrap(); + + writeln!(&mut log_file, "\n\n").log_err(); + writeln!(&mut log_file, "========================================").log_err(); + writeln!(&mut log_file, " JUDGE OUTPUT ").log_err(); + writeln!(&mut log_file, "========================================").log_err(); + writeln!(&mut log_file, "\n{}", &response).log_err(); + + parse_judge_output(&response) + } + + pub async fn repository_diff(&self) -> Result { + let worktree_path = self.worktree_path(); + run_git(&worktree_path, &["add", "-N"]).await?; + run_git(&worktree_path, &["diff"]).await + } +} + +fn wait_for_lang_server( + lsp_store: &Entity, + name: String, + cx: &mut AsyncApp, +) -> Task> { + if cx + .update(|cx| !has_pending_lang_server_work(lsp_store, cx)) + .unwrap() + || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok() + { + return Task::ready(anyhow::Ok(())); + } + + println!("{}> ⏵ Waiting for language server", name); + + let (mut tx, mut rx) = mpsc::channel(1); + + let subscription = + cx.subscribe(&lsp_store, { + let name = name.clone(); + move |lsp_store, event, cx| { + match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{name}> ⟲ {message}"), + _ => {} + } + + if !has_pending_lang_server_work(&lsp_store, cx) { + tx.try_send(()).ok(); + } + } + }); + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! 
{ + _ = rx.next() => { + println!("{}> ⚑ Language server idle", name); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + Err(anyhow!("LSP wait timed out after 5 minutes")) + } + }; + drop(subscription); + result + }) +} + +fn has_pending_lang_server_work(lsp_store: &Entity, cx: &App) -> bool { + lsp_store + .read(cx) + .language_server_statuses() + .any(|(_, status)| !status.pending_work.is_empty()) +} + +async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> Result { + let paths_with_diagnostics = project.update(cx, |project, cx| { + project + .diagnostic_summaries(true, cx) + .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0) + .map(|(project_path, _, _)| project_path) + .collect::>() + })?; + + let mut output = String::new(); + for project_path in paths_with_diagnostics { + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for (_, group) in snapshot.diagnostic_groups(None) { + let entry = &group.entries[group.primary_ix]; + let range = entry.range.to_point(&snapshot); + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + _ => continue, + }; + + writeln!( + output, + "{} at line {}: {}", + severity, + range.start.row + 1, + entry.diagnostic.message + )?; + } + } + anyhow::Ok(output) +} + +fn parse_judge_output(response: &str) -> Result { + let analysis = get_tag("analysis", response)?.to_string(); + let score = get_tag("score", response)? 
+ .parse() + .context("error parsing score")?; + + Ok(JudgeOutput { analysis, score }) +} + +fn get_tag(name: &'static str, response: &str) -> Result { + let start_tag = format!("<{}>", name); + let end_tag = format!("", name); + + let start_ix = response + .find(&start_tag) + .context(format!("{} start tag not found", name))?; + let content_start_ix = start_ix + start_tag.len(); + + let end_ix = content_start_ix + + response[content_start_ix..] + .find(&end_tag) + .context(format!("{} end tag not found", name))?; + + let content = response[content_start_ix..end_ix].trim().unindent(); + + anyhow::Ok(content) +} + +pub fn repo_path_for_url(repo_url: &str) -> PathBuf { + let repo_name = repo_url + .trim_start_matches("https://") + .replace(|c: char| !c.is_alphanumeric(), "-"); + Path::new(REPOS_DIR) + .canonicalize() + .context(format!("No such directory {REPOS_DIR}")) + .unwrap() + .join(repo_name) +} + +pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { + let output = new_smol_command("git") + .current_dir(repo_path) + .args(args) + .output() + .await?; + + if output.status.success() { + Ok(String::from_utf8(output.stdout)?.trim().to_string()) + } else { + Err(anyhow!( + "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", + args.join(" "), + repo_path.display(), + output.status, + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout), + )) + } +} + +pub async fn send_language_model_request( + model: Arc, + request: LanguageModelRequest, + cx: &AsyncApp, +) -> anyhow::Result { + match model.stream_completion_text(request, &cx).await { + Ok(mut stream) => { + let mut full_response = String::new(); + while let Some(chunk_result) = stream.stream.next().await { + match chunk_result { + Ok(chunk_str) => { + print!("{}", &chunk_str); + full_response.push_str(&chunk_str); + } + Err(err) => { + return Err(anyhow!( + "Error receiving response from language model: {err}" + )); + } + } + } + Ok(full_response) + } 
+ Err(err) => Err(anyhow!( + "Failed to get response from language model. Error was: {err}" + )), + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_judge_output() { + let response = r#" + The model did a good job but there were still compilations errors. + 3 + "# + .unindent(); + + let output = parse_judge_output(&response).unwrap(); + assert_eq!( + output.analysis, + "The model did a good job but there were still compilations errors." + ); + assert_eq!(output.score, 3); + + let response = r#" + Text around ignored + + + Failed to compile: + - Error 1 + - Error 2 + + + 1 + "# + .unindent(); + + let output = parse_judge_output(&response).unwrap(); + assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2"); + assert_eq!(output.score, 1); + } } diff --git a/crates/eval/src/judge_prompt.hbs b/crates/eval/src/judge_prompt.hbs new file mode 100644 index 0000000000..862cc0985c --- /dev/null +++ b/crates/eval/src/judge_prompt.hbs @@ -0,0 +1,26 @@ +You are an expert software developer tasked with evaluating the following changes to a codebase: + + +{{repository_diff}} + + +Use the following criteria to score the above changes. + + +{{criteria}} + + +Based on these criteria, give the test output a score between 0 and 5. +The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats. + +- 5 means: changes meet all criteria +- 0 means: changes don't meet any criteria + +Be suspicious of the changes because they were generated by an LLM. +Sometimes the LLM decides to change random code, so if the changes are not mentioned in the criteria, penalize the score. +Analyze the diff hunk by hunk and describe how each change meets or fails to meet the criteria. 
+ +``` +{YOUR ANALYSIS HERE} +{YOUR SCORE HERE} +``` diff --git a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index 5031e1cb85..cf89f41dda 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -17,10 +17,10 @@ async-compression.workspace = true async-tar.workspace = true async-trait.workspace = true collections.workspace = true -convert_case.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true +heck.workspace = true http_client.workspace = true language.workspace = true log.workspace = true diff --git a/crates/extension/src/extension_builder.rs b/crates/extension/src/extension_builder.rs index 162f926dda..c6636f03d2 100644 --- a/crates/extension/src/extension_builder.rs +++ b/crates/extension/src/extension_builder.rs @@ -4,9 +4,9 @@ use crate::{ use anyhow::{Context as _, Result, anyhow, bail}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; -use convert_case::{Case, Casing as _}; use futures::AsyncReadExt; use futures::io::BufReader; +use heck::ToSnakeCase; use http_client::{self, AsyncBody, HttpClient}; use serde::Deserialize; use std::{ @@ -106,7 +106,7 @@ impl ExtensionBuilder { } for (grammar_name, grammar_metadata) in &extension_manifest.grammars { - let snake_cased_grammar_name = grammar_name.to_case(Case::Snake); + let snake_cased_grammar_name = grammar_name.to_snake_case(); if grammar_name.as_ref() != snake_cased_grammar_name.as_str() { bail!( "grammar name '{grammar_name}' must be written in snake_case: {snake_cased_grammar_name}" diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs index d5d3582858..d2a5f1402d 100644 --- a/crates/file_finder/src/file_finder_tests.rs +++ b/crates/file_finder/src/file_finder_tests.rs @@ -2133,18 +2133,28 @@ async fn test_repeat_toggle_action(cx: &mut gpui::TestAppContext) { cx.dispatch_action(ToggleFileFinder::default()); let picker = active_file_picker(&workspace, cx); + + 
picker.update_in(cx, |picker, window, cx| { + picker.update_matches(".txt".to_string(), window, cx) + }); + + cx.run_until_parked(); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 6); assert_eq!(picker.delegate.selected_index, 0); - assert_eq!(picker.logical_scroll_top_index(), 0); }); // When toggling repeatedly, the picker scrolls to reveal the selected item. cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default()); + + cx.run_until_parked(); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 6); assert_eq!(picker.delegate.selected_index, 3); - assert_eq!(picker.logical_scroll_top_index(), 3); }); } diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 584abd4cf7..1a2d28e241 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -5,8 +5,8 @@ use futures::future::{self, BoxFuture}; use git::{ blame::Blame, repository::{ - AskPassDelegate, Branch, CommitDetails, GitRepository, GitRepositoryCheckpoint, - PushOptions, Remote, RepoPath, ResetMode, + AskPassDelegate, Branch, CommitDetails, CommitOptions, GitRepository, + GitRepositoryCheckpoint, PushOptions, Remote, RepoPath, ResetMode, }, status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus}, }; @@ -21,11 +21,12 @@ pub struct FakeGitRepository { pub(crate) fs: Arc, pub(crate) executor: BackgroundExecutor, pub(crate) dot_git_path: PathBuf, + pub(crate) repository_dir_path: PathBuf, + pub(crate) common_dir_path: PathBuf, } #[derive(Debug, Clone)] pub struct FakeGitRepositoryState { - pub path: PathBuf, pub event_emitter: smol::channel::Sender, pub unmerged_paths: HashMap, pub head_contents: HashMap, @@ -37,9 +38,8 @@ pub struct FakeGitRepositoryState { } impl FakeGitRepositoryState { - pub fn new(path: PathBuf, event_emitter: smol::channel::Sender) -> Self { + pub fn new(event_emitter: 
smol::channel::Sender) -> Self { FakeGitRepositoryState { - path, event_emitter, head_contents: Default::default(), index_contents: Default::default(), @@ -53,15 +53,6 @@ impl FakeGitRepositoryState { } impl FakeGitRepository { - fn with_state(&self, f: F) -> T - where - F: FnOnce(&mut FakeGitRepositoryState) -> T, - { - self.fs - .with_git_state(&self.dot_git_path, false, f) - .unwrap() - } - fn with_state_async(&self, write: bool, f: F) -> BoxFuture<'static, Result> where F: 'static + Send + FnOnce(&mut FakeGitRepositoryState) -> Result, @@ -172,11 +163,11 @@ impl GitRepository for FakeGitRepository { } fn path(&self) -> PathBuf { - self.with_state(|state| state.path.clone()) + self.repository_dir_path.clone() } fn main_repository_path(&self) -> PathBuf { - self.path() + self.common_dir_path.clone() } fn merge_message(&self) -> BoxFuture> { @@ -207,8 +198,9 @@ impl GitRepository for FakeGitRepository { .files() .iter() .filter_map(|path| { + // TODO better simulate git status output in the case of submodules and worktrees let repo_path = path.strip_prefix(workdir_path).ok()?; - let mut is_ignored = false; + let mut is_ignored = repo_path.starts_with(".git"); for ignore in &ignores { match ignore.matched_path_or_any_parents(path, false) { ignore::Match::None => {} @@ -373,6 +365,7 @@ impl GitRepository for FakeGitRepository { &self, _message: gpui::SharedString, _name_and_email: Option<(gpui::SharedString, gpui::SharedString)>, + _options: CommitOptions, _env: Arc>, ) -> BoxFuture> { unimplemented!() diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index 272a05e9b8..bc60e8a2fd 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -851,7 +851,7 @@ impl Watcher for RealWatcher { pub struct FakeFs { this: std::sync::Weak, // Use an unfair lock to ensure tests are deterministic. 
- state: Mutex, + state: Arc>, executor: gpui::BackgroundExecutor, } @@ -878,6 +878,8 @@ enum FakeFsEntry { mtime: MTime, len: u64, content: Vec, + // The path to the repository state directory, if this is a gitfile. + git_dir_path: Option, }, Dir { inode: u64, @@ -1036,7 +1038,7 @@ impl FakeFs { let this = Arc::new_cyclic(|this| Self { this: this.clone(), executor: executor.clone(), - state: Mutex::new(FakeFsState { + state: Arc::new(Mutex::new(FakeFsState { root: Arc::new(Mutex::new(FakeFsEntry::Dir { inode: 0, mtime: MTime(UNIX_EPOCH), @@ -1054,7 +1056,7 @@ impl FakeFs { metadata_call_count: 0, moves: Default::default(), home_dir: None, - }), + })), }); executor.spawn({ @@ -1097,6 +1099,7 @@ impl FakeFs { mtime: new_mtime, content: Vec::new(), len: 0, + git_dir_path: None, }))); } btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() { @@ -1154,6 +1157,7 @@ impl FakeFs { mtime: new_mtime, len: new_len, content: new_content, + git_dir_path: None, }))); } btree_map::Entry::Occupied(mut e) => { @@ -1278,9 +1282,14 @@ impl FakeFs { .boxed() } - pub fn with_git_state(&self, dot_git: &Path, emit_git_event: bool, f: F) -> Result + pub fn with_git_state_and_paths( + &self, + dot_git: &Path, + emit_git_event: bool, + f: F, + ) -> Result where - F: FnOnce(&mut FakeGitRepositoryState) -> T, + F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T, { let mut state = self.state.lock(); let entry = state.read_path(dot_git).context("open .git")?; @@ -1288,25 +1297,75 @@ impl FakeFs { if let FakeFsEntry::Dir { git_repo_state, .. 
} = &mut *entry { let repo_state = git_repo_state.get_or_insert_with(|| { + log::debug!("insert git state for {dot_git:?}"); Arc::new(Mutex::new(FakeGitRepositoryState::new( - dot_git.to_path_buf(), state.git_event_tx.clone(), ))) }); let mut repo_state = repo_state.lock(); - let result = f(&mut repo_state); + let result = f(&mut repo_state, dot_git, dot_git); if emit_git_event { state.emit_event([(dot_git, None)]); } + Ok(result) + } else if let FakeFsEntry::File { + content, + git_dir_path, + .. + } = &mut *entry + { + let path = match git_dir_path { + Some(path) => path, + None => { + let path = std::str::from_utf8(content) + .ok() + .and_then(|content| content.strip_prefix("gitdir:")) + .ok_or_else(|| anyhow!("not a valid gitfile"))? + .trim(); + git_dir_path.insert(normalize_path(&dot_git.parent().unwrap().join(path))) + } + } + .clone(); + drop(entry); + let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else { + anyhow::bail!("pointed-to git dir {path:?} not found") + }; + let FakeFsEntry::Dir { git_repo_state, .. 
} = &mut *git_dir_entry.lock() else { + anyhow::bail!("gitfile points to a non-directory") + }; + let common_dir = canonical_path + .ancestors() + .find(|ancestor| ancestor.ends_with(".git")) + .ok_or_else(|| anyhow!("repository dir not contained in any .git"))?; + let repo_state = git_repo_state.get_or_insert_with(|| { + Arc::new(Mutex::new(FakeGitRepositoryState::new( + state.git_event_tx.clone(), + ))) + }); + let mut repo_state = repo_state.lock(); + + let result = f(&mut repo_state, &canonical_path, common_dir); + + if emit_git_event { + state.emit_event([(canonical_path, None)]); + } + Ok(result) } else { - Err(anyhow!("not a directory")) + Err(anyhow!("not a valid git repository")) } } + pub fn with_git_state(&self, dot_git: &Path, emit_git_event: bool, f: F) -> Result + where + F: FnOnce(&mut FakeGitRepositoryState) -> T, + { + self.with_git_state_and_paths(dot_git, emit_git_event, |state, _, _| f(state)) + } + pub fn set_branch_name(&self, dot_git: &Path, branch: Option>) { self.with_git_state(dot_git, true, |state| { let branch = branch.map(Into::into); @@ -1663,11 +1722,25 @@ impl FakeFsEntry { } #[cfg(any(test, feature = "test-support"))] -struct FakeWatcher {} +struct FakeWatcher { + tx: smol::channel::Sender>, + original_path: PathBuf, + fs_state: Arc>, + prefixes: Mutex>, +} #[cfg(any(test, feature = "test-support"))] impl Watcher for FakeWatcher { - fn add(&self, _: &Path) -> Result<()> { + fn add(&self, path: &Path) -> Result<()> { + if path.starts_with(&self.original_path) { + return Ok(()); + } + self.fs_state + .try_lock() + .unwrap() + .event_txs + .push((path.to_owned(), self.tx.clone())); + self.prefixes.lock().push(path.to_owned()); Ok(()) } @@ -1745,6 +1818,7 @@ impl Fs for FakeFs { mtime, len: 0, content: Vec::new(), + git_dir_path: None, })); let mut kind = Some(PathEventKind::Created); state.write_path(path, |entry| { @@ -1901,6 +1975,7 @@ impl Fs for FakeFs { mtime, len: content.len() as u64, content, + git_dir_path: None, }))) 
.clone(), )), @@ -2137,42 +2212,54 @@ impl Fs for FakeFs { self.simulate_random_delay().await; let (tx, rx) = smol::channel::unbounded(); let path = path.to_path_buf(); - self.state.lock().event_txs.push((path.clone(), tx)); + self.state.lock().event_txs.push((path.clone(), tx.clone())); let executor = self.executor.clone(); + let watcher = Arc::new(FakeWatcher { + tx, + original_path: path.to_owned(), + fs_state: self.state.clone(), + prefixes: Mutex::new(vec![path.to_owned()]), + }); ( - Box::pin(futures::StreamExt::filter(rx, move |events| { - let result = events - .iter() - .any(|evt_path| evt_path.path.starts_with(&path)); - let executor = executor.clone(); - async move { - executor.simulate_random_delay().await; - result + Box::pin(futures::StreamExt::filter(rx, { + let watcher = watcher.clone(); + move |events| { + let result = events.iter().any(|evt_path| { + let result = watcher + .prefixes + .lock() + .iter() + .any(|prefix| evt_path.path.starts_with(prefix)); + result + }); + let executor = executor.clone(); + async move { + executor.simulate_random_delay().await; + result + } } })), - Arc::new(FakeWatcher {}), + watcher, ) } fn open_repo(&self, abs_dot_git: &Path) -> Option> { - let state = self.state.lock(); - let entry = state.read_path(abs_dot_git).unwrap(); - let mut entry = entry.lock(); - if let FakeFsEntry::Dir { git_repo_state, .. 
} = &mut *entry { - git_repo_state.get_or_insert_with(|| { - Arc::new(Mutex::new(FakeGitRepositoryState::new( - abs_dot_git.to_path_buf(), - state.git_event_tx.clone(), - ))) - }); - Some(Arc::new(fake_git_repo::FakeGitRepository { - fs: self.this.upgrade().unwrap(), - executor: self.executor.clone(), - dot_git_path: abs_dot_git.to_path_buf(), - })) - } else { - None - } + use util::ResultExt as _; + + self.with_git_state_and_paths( + abs_dot_git, + false, + |_, repository_dir_path, common_dir_path| { + Arc::new(fake_git_repo::FakeGitRepository { + fs: self.this.upgrade().unwrap(), + executor: self.executor.clone(), + dot_git_path: abs_dot_git.to_path_buf(), + repository_dir_path: repository_dir_path.to_owned(), + common_dir_path: common_dir_path.to_owned(), + }) as _ + }, + ) + .log_err() } fn git_init( diff --git a/crates/git/src/git.rs b/crates/git/src/git.rs index 615d807c38..668d5f9ac7 100644 --- a/crates/git/src/git.rs +++ b/crates/git/src/git.rs @@ -50,6 +50,8 @@ actions!( Pull, Fetch, Commit, + Amend, + Cancel, ExpandCommitEditor, GenerateCommitMessage, Init, diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index da68a532e3..28f0d1c910 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -74,6 +74,11 @@ impl Upstream { } } +#[derive(Clone, Copy, Default)] +pub struct CommitOptions { + pub amend: bool, +} + #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] pub enum UpstreamTracking { /// Remote ref not present in local repository. @@ -229,12 +234,6 @@ pub trait GitRepository: Send + Sync { /// worktree's gitdir within the main repository (typically `.git/worktrees/`). fn path(&self) -> PathBuf; - /// Returns the absolute path to the ".git" dir for the main repository, typically a `.git` - /// folder. For worktrees, this will be the path to the repository the worktree was created - /// from. Otherwise, this is the same value as `path()`. 
- /// - /// Git documentation calls this the "commondir", and for git CLI is overridden by - /// `GIT_COMMON_DIR`. fn main_repository_path(&self) -> PathBuf; /// Updates the index to match the worktree at the given paths. @@ -258,6 +257,7 @@ pub trait GitRepository: Send + Sync { &self, message: SharedString, name_and_email: Option<(SharedString, SharedString)>, + options: CommitOptions, env: Arc>, ) -> BoxFuture>; @@ -374,8 +374,8 @@ impl RealGitRepository { #[derive(Clone, Debug)] pub struct GitRepositoryCheckpoint { - ref_name: String, - commit_sha: Oid, + pub ref_name: String, + pub commit_sha: Oid, } impl GitRepository for RealGitRepository { @@ -963,6 +963,7 @@ impl GitRepository for RealGitRepository { &self, message: SharedString, name_and_email: Option<(SharedString, SharedString)>, + options: CommitOptions, env: Arc>, ) -> BoxFuture> { let working_directory = self.working_directory(); @@ -975,6 +976,10 @@ impl GitRepository for RealGitRepository { .arg(&message.to_string()) .arg("--cleanup=strip"); + if options.amend { + cmd.arg("--amend"); + } + if let Some((name, email)) = name_and_email { cmd.arg("--author").arg(&format!("{name} <{email}>")); } @@ -1771,6 +1776,7 @@ mod tests { repo.commit( "Initial commit".into(), None, + CommitOptions::default(), Arc::new(checkpoint_author_envs()), ) .await @@ -1799,6 +1805,7 @@ mod tests { repo.commit( "Commit after checkpoint".into(), None, + CommitOptions::default(), Arc::new(checkpoint_author_envs()), ) .await diff --git a/crates/git_ui/src/commit_modal.rs b/crates/git_ui/src/commit_modal.rs index 16b8525f75..dd897eb46a 100644 --- a/crates/git_ui/src/commit_modal.rs +++ b/crates/git_ui/src/commit_modal.rs @@ -1,8 +1,11 @@ use crate::branch_picker::{self, BranchList}; use crate::git_panel::{GitPanel, commit_message_editor}; -use git::{Commit, GenerateCommitMessage}; +use git::repository::CommitOptions; +use git::{Amend, Commit, GenerateCommitMessage}; use panel::{panel_button, panel_editor_style, 
panel_filled_button}; -use ui::{KeybindingHint, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*}; +use ui::{ + ContextMenu, KeybindingHint, PopoverMenu, PopoverMenuHandle, SplitButton, Tooltip, prelude::*, +}; use editor::{Editor, EditorElement}; use gpui::*; @@ -58,6 +61,7 @@ pub struct CommitModal { restore_dock: RestoreDock, properties: ModalContainerProperties, branch_list_handle: PopoverMenuHandle, + commit_menu_handle: PopoverMenuHandle, } impl Focusable for CommitModal { @@ -95,19 +99,47 @@ struct RestoreDock { active_index: Option, } +pub enum ForceMode { + Amend, + Commit, +} + impl CommitModal { pub fn register(workspace: &mut Workspace) { workspace.register_action(|workspace, _: &Commit, window, cx| { - CommitModal::toggle(workspace, window, cx); + CommitModal::toggle(workspace, Some(ForceMode::Commit), window, cx); + }); + workspace.register_action(|workspace, _: &Amend, window, cx| { + CommitModal::toggle(workspace, Some(ForceMode::Amend), window, cx); }); } - pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { + pub fn toggle( + workspace: &mut Workspace, + force_mode: Option, + window: &mut Window, + cx: &mut Context, + ) { let Some(git_panel) = workspace.panel::(cx) else { return; }; git_panel.update(cx, |git_panel, cx| { + if let Some(force_mode) = force_mode { + match force_mode { + ForceMode::Amend => { + if !git_panel.amend_pending() { + git_panel.set_amend_pending(true, cx); + git_panel.load_last_commit_message_if_empty(cx); + } + } + ForceMode::Commit => { + if git_panel.amend_pending() { + git_panel.set_amend_pending(false, cx); + } + } + } + } git_panel.set_modal_open(true, cx); }); @@ -164,7 +196,9 @@ impl CommitModal { let focus_handle = commit_editor.focus_handle(cx); cx.on_focus_out(&focus_handle, window, |this, _, window, cx| { - if !this.branch_list_handle.is_focused(window, cx) { + if !this.branch_list_handle.is_focused(window, cx) + && !this.commit_menu_handle.is_focused(window, cx) + { 
cx.emit(DismissEvent); } }) @@ -178,6 +212,7 @@ impl CommitModal { restore_dock, properties, branch_list_handle: PopoverMenuHandle::default(), + commit_menu_handle: PopoverMenuHandle::default(), } } @@ -214,23 +249,68 @@ impl CommitModal { ) } + fn render_git_commit_menu( + &self, + id: impl Into, + keybinding_target: Option, + ) -> impl IntoElement { + PopoverMenu::new(id.into()) + .trigger( + ui::ButtonLike::new_rounded_right("commit-split-button-right") + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::None) + .child( + div() + .px_1() + .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), + ), + ) + .menu(move |window, cx| { + Some(ContextMenu::build(window, cx, |context_menu, _, _| { + context_menu + .when_some(keybinding_target.clone(), |el, keybinding_target| { + el.context(keybinding_target.clone()) + }) + .action("Amend...", Amend.boxed_clone()) + })) + }) + .with_handle(self.commit_menu_handle.clone()) + .anchor(Corner::TopRight) + } + pub fn render_footer(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { - let (can_commit, tooltip, commit_label, co_authors, generate_commit_message, active_repo) = - self.git_panel.update(cx, |git_panel, cx| { - let (can_commit, tooltip) = git_panel.configure_commit_button(cx); - let title = git_panel.commit_button_title(); - let co_authors = git_panel.render_co_authors(cx); - let generate_commit_message = git_panel.render_generate_commit_message_button(cx); - let active_repo = git_panel.active_repository.clone(); - ( - can_commit, - tooltip, - title, - co_authors, - generate_commit_message, - active_repo, - ) - }); + let ( + can_commit, + tooltip, + commit_label, + co_authors, + generate_commit_message, + active_repo, + is_amend_pending, + has_previous_commit, + ) = self.git_panel.update(cx, |git_panel, cx| { + let (can_commit, tooltip) = git_panel.configure_commit_button(cx); + let title = git_panel.commit_button_title(); + let co_authors = 
git_panel.render_co_authors(cx); + let generate_commit_message = git_panel.render_generate_commit_message_button(cx); + let active_repo = git_panel.active_repository.clone(); + let is_amend_pending = git_panel.amend_pending(); + let has_previous_commit = active_repo + .as_ref() + .and_then(|repo| repo.read(cx).branch.as_ref()) + .and_then(|branch| branch.most_recent_commit.as_ref()) + .is_some(); + ( + can_commit, + tooltip, + title, + co_authors, + generate_commit_message, + active_repo, + is_amend_pending, + has_previous_commit, + ) + }); let branch = active_repo .as_ref() @@ -277,21 +357,6 @@ impl CommitModal { None }; - let commit_button = panel_filled_button(commit_label) - .tooltip({ - let panel_editor_focus_handle = focus_handle.clone(); - move |window, cx| { - Tooltip::for_action_in(tooltip, &Commit, &panel_editor_focus_handle, window, cx) - } - }) - .disabled(!can_commit) - .on_click(cx.listener(move |this, _: &ClickEvent, window, cx| { - telemetry::event!("Git Committed", source = "Git Modal"); - this.git_panel - .update(cx, |git_panel, cx| git_panel.commit_changes(window, cx)); - cx.emit(DismissEvent); - })); - h_flex() .group("commit_editor_footer") .flex_none() @@ -324,21 +389,143 @@ impl CommitModal { .px_1() .gap_4() .children(close_kb_hint) - .child(commit_button), + .when(is_amend_pending, |this| { + let focus_handle = focus_handle.clone(); + this.child( + panel_filled_button(commit_label) + .tooltip(move |window, cx| { + if can_commit { + Tooltip::for_action_in( + tooltip, + &Amend, + &focus_handle, + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + }) + .disabled(!can_commit) + .on_click(move |_, window, cx| { + window.dispatch_action(Box::new(git::Commit), cx); + }), + ) + }) + .when(!is_amend_pending, |this| { + this.when(has_previous_commit, |this| { + this.child(SplitButton::new( + ui::ButtonLike::new_rounded_left(ElementId::Name( + format!("split-button-left-{}", commit_label).into(), + )) + 
.layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::Compact) + .child( + div() + .child(Label::new(commit_label).size(LabelSize::Small)) + .mr_0p5(), + ) + .on_click(move |_, window, cx| { + window.dispatch_action(Box::new(git::Commit), cx); + }) + .disabled(!can_commit) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + "git commit", + &focus_handle.clone(), + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + } + }), + self.render_git_commit_menu( + ElementId::Name( + format!("split-button-right-{}", commit_label).into(), + ), + Some(focus_handle.clone()), + ) + .into_any_element(), + )) + }) + .when(!has_previous_commit, |this| { + this.child( + panel_filled_button(commit_label) + .tooltip(move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + "git commit", + &focus_handle, + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + }) + .disabled(!can_commit) + .on_click(move |_, window, cx| { + window.dispatch_action(Box::new(git::Commit), cx); + }), + ) + }) + }), ) } fn dismiss(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { - cx.emit(DismissEvent); + if self.git_panel.read(cx).amend_pending() { + self.git_panel + .update(cx, |git_panel, cx| git_panel.set_amend_pending(false, cx)); + } else { + cx.emit(DismissEvent); + } } fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context) { + if self.git_panel.read(cx).amend_pending() { + return; + } telemetry::event!("Git Committed", source = "Git Modal"); - self.git_panel - .update(cx, |git_panel, cx| git_panel.commit_changes(window, cx)); + self.git_panel.update(cx, |git_panel, cx| { + git_panel.commit_changes(CommitOptions { amend: false }, window, cx) + }); cx.emit(DismissEvent); } + fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context) { + if self + .commit_editor + 
.focus_handle(cx) + .contains_focused(window, cx) + { + if !self.git_panel.read(cx).amend_pending() { + self.git_panel.update(cx, |git_panel, cx| { + git_panel.set_amend_pending(true, cx); + git_panel.load_last_commit_message_if_empty(cx); + }); + } else { + telemetry::event!("Git Amended", source = "Git Panel"); + self.git_panel.update(cx, |git_panel, cx| { + git_panel.set_amend_pending(false, cx); + git_panel.commit_changes(CommitOptions { amend: true }, window, cx); + }); + cx.emit(DismissEvent); + } + } else { + cx.propagate(); + } + } + fn toggle_branch_selector(&mut self, window: &mut Window, cx: &mut Context) { if self.branch_list_handle.is_focused(window, cx) { self.focus_handle(cx).focus(window) @@ -361,6 +548,7 @@ impl Render for CommitModal { .key_context("GitCommit") .on_action(cx.listener(Self::dismiss)) .on_action(cx.listener(Self::commit)) + .on_action(cx.listener(Self::amend)) .on_action(cx.listener(|this, _: &GenerateCommitMessage, _, cx| { this.git_panel.update(cx, |panel, cx| { panel.generate_commit_message(cx); diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 3b0acc4161..e8d27cd442 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -21,11 +21,11 @@ use editor::{ use futures::StreamExt as _; use git::blame::ParsedCommitMessage; use git::repository::{ - Branch, CommitDetails, CommitSummary, DiffType, PushOptions, Remote, RemoteCommandOutput, - ResetMode, Upstream, UpstreamTracking, UpstreamTrackingStatus, + Branch, CommitDetails, CommitOptions, CommitSummary, DiffType, PushOptions, Remote, + RemoteCommandOutput, ResetMode, Upstream, UpstreamTracking, UpstreamTrackingStatus, }; use git::status::StageStatus; -use git::{Commit, ToggleStaged, repository::RepoPath, status::FileStatus}; +use git::{Amend, ToggleStaged, repository::RepoPath, status::FileStatus}; use git::{ExpandCommitEditor, RestoreTrackedFiles, StageAll, TrashUntrackedFiles, UnstageAll}; use gpui::{ Action, Animation, 
AnimationExt as _, Axis, ClickEvent, Corner, DismissEvent, Entity, @@ -59,8 +59,8 @@ use std::{collections::HashSet, sync::Arc, time::Duration, usize}; use strum::{IntoEnumIterator, VariantNames}; use time::OffsetDateTime; use ui::{ - Checkbox, ContextMenu, ElevationIndex, PopoverMenu, Scrollbar, ScrollbarState, Tooltip, - prelude::*, + Checkbox, ContextMenu, ElevationIndex, PopoverMenu, Scrollbar, ScrollbarState, SplitButton, + Tooltip, prelude::*, }; use util::{ResultExt, TryFutureExt, maybe}; use workspace::AppState; @@ -167,7 +167,7 @@ pub fn register(workspace: &mut Workspace) { workspace.toggle_panel_focus::(window, cx); }); workspace.register_action(|workspace, _: &ExpandCommitEditor, window, cx| { - CommitModal::toggle(workspace, window, cx) + CommitModal::toggle(workspace, None, window, cx) }); } @@ -340,6 +340,7 @@ pub struct GitPanel { new_staged_count: usize, pending: Vec, pending_commit: Option>, + amend_pending: bool, pending_serialization: Task>, pub(crate) project: Entity, scroll_handle: UniformListScrollHandle, @@ -492,6 +493,7 @@ impl GitPanel { new_staged_count: 0, pending: Vec::new(), pending_commit: None, + amend_pending: false, pending_serialization: Task::ready(None), single_staged_entry: None, single_tracked_entry: None, @@ -1417,18 +1419,81 @@ impl GitPanel { } fn commit(&mut self, _: &git::Commit, window: &mut Window, cx: &mut Context) { + if self.amend_pending { + return; + } if self .commit_editor .focus_handle(cx) .contains_focused(window, cx) { telemetry::event!("Git Committed", source = "Git Panel"); - self.commit_changes(window, cx) + self.commit_changes(CommitOptions { amend: false }, window, cx) } else { cx.propagate(); } } + pub fn load_last_commit_message_if_empty(&mut self, cx: &mut Context) { + if !self.commit_editor.read(cx).is_empty(cx) { + return; + } + let Some(active_repository) = self.active_repository.as_ref() else { + return; + }; + let Some(branch) = active_repository.read(cx).branch.as_ref() else { + return; + }; + 
let Some(recent_sha) = branch + .most_recent_commit + .as_ref() + .map(|commit| commit.sha.to_string()) + else { + return; + }; + let detail_task = self.load_commit_details(recent_sha, cx); + cx.spawn(async move |this, cx| { + if let Ok(message) = detail_task.await.map(|detail| detail.message) { + this.update(cx, |this, cx| { + this.commit_message_buffer(cx).update(cx, |buffer, cx| { + let start = buffer.anchor_before(0); + let end = buffer.anchor_after(buffer.len()); + buffer.edit([(start..end, message)], None, cx); + }); + }) + .log_err(); + } + }) + .detach(); + } + + fn amend(&mut self, _: &git::Amend, window: &mut Window, cx: &mut Context) { + if self + .commit_editor + .focus_handle(cx) + .contains_focused(window, cx) + { + if !self.amend_pending { + self.amend_pending = true; + cx.notify(); + self.load_last_commit_message_if_empty(cx); + } else { + telemetry::event!("Git Amended", source = "Git Panel"); + self.amend_pending = false; + self.commit_changes(CommitOptions { amend: true }, window, cx); + } + } else { + cx.propagate(); + } + } + + fn cancel(&mut self, _: &git::Cancel, _: &mut Window, cx: &mut Context) { + if self.amend_pending { + self.amend_pending = false; + cx.notify(); + } + } + fn custom_or_suggested_commit_message(&self, cx: &mut Context) -> Option { let message = self.commit_editor.read(cx).text(cx); @@ -1440,7 +1505,12 @@ impl GitPanel { .filter(|message| !message.trim().is_empty()) } - pub(crate) fn commit_changes(&mut self, window: &mut Window, cx: &mut Context) { + pub(crate) fn commit_changes( + &mut self, + options: CommitOptions, + window: &mut Window, + cx: &mut Context, + ) { let Some(active_repository) = self.active_repository.clone() else { return; }; @@ -1474,8 +1544,9 @@ impl GitPanel { let task = if self.has_staged_changes() { // Repository serializes all git operations, so we can just send a commit immediately - let commit_task = - active_repository.update(cx, |repo, cx| repo.commit(message.into(), None, cx)); + let 
commit_task = active_repository.update(cx, |repo, cx| { + repo.commit(message.into(), None, options, cx) + }); cx.background_spawn(async move { commit_task.await? }) } else { let changed_files = self @@ -1495,8 +1566,9 @@ impl GitPanel { active_repository.update(cx, |repo, cx| repo.stage_entries(changed_files, cx)); cx.spawn(async move |_, cx| { stage_task.await?; - let commit_task = active_repository - .update(cx, |repo, cx| repo.commit(message.into(), None, cx))?; + let commit_task = active_repository.update(cx, |repo, cx| { + repo.commit(message.into(), None, options, cx) + })?; commit_task.await? }) }; @@ -2722,6 +2794,34 @@ impl GitPanel { } } + fn render_git_commit_menu( + &self, + id: impl Into, + keybinding_target: Option, + ) -> impl IntoElement { + PopoverMenu::new(id.into()) + .trigger( + ui::ButtonLike::new_rounded_right("commit-split-button-right") + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::None) + .child( + div() + .px_1() + .child(Icon::new(IconName::ChevronDownSmall).size(IconSize::XSmall)), + ), + ) + .menu(move |window, cx| { + Some(ContextMenu::build(window, cx, |context_menu, _, _| { + context_menu + .when_some(keybinding_target.clone(), |el, keybinding_target| { + el.context(keybinding_target.clone()) + }) + .action("Amend...", Amend.boxed_clone()) + })) + }) + .anchor(Corner::TopRight) + } + pub fn configure_commit_button(&self, cx: &mut Context) -> (bool, &'static str) { if self.has_unstaged_conflicts() { (false, "You must resolve conflicts before committing") @@ -2739,10 +2839,18 @@ impl GitPanel { } pub fn commit_button_title(&self) -> &'static str { - if self.has_staged_changes() { - "Commit" + if self.amend_pending { + if self.has_staged_changes() { + "Amend" + } else { + "Amend Tracked" + } } else { - "Commit Tracked" + if self.has_staged_changes() { + "Commit" + } else { + "Commit Tracked" + } } } @@ -2756,7 +2864,7 @@ impl GitPanel { window.defer(cx, move |window, cx| { workspace .update(cx, |workspace, cx| { - 
CommitModal::toggle(workspace, window, cx) + CommitModal::toggle(workspace, None, window, cx) }) .ok(); }) @@ -2885,6 +2993,10 @@ impl GitPanel { let editor_is_long = self.commit_editor.update(cx, |editor, cx| { editor.max_point(cx).row().0 >= MAX_PANEL_EDITOR_LINES as u32 }); + let has_previous_commit = branch + .as_ref() + .and_then(|branch| branch.most_recent_commit.as_ref()) + .is_some(); let footer = v_flex() .child(PanelRepoFooter::new(display_name, branch, Some(git_panel))) @@ -2920,32 +3032,140 @@ impl GitPanel { .unwrap_or_else(|| div().into_any_element()), ) .child( - h_flex().gap_0p5().children(enable_coauthors).child( - panel_filled_button(title) - .tooltip(move |window, cx| { - if can_commit { - Tooltip::for_action_in( - tooltip, - &Commit, - &commit_tooltip_focus_handle, - window, - cx, + h_flex() + .gap_0p5() + .children(enable_coauthors) + .when(self.amend_pending, { + |this| { + this.h_flex() + .gap_1() + .child( + panel_filled_button("Cancel") + .tooltip({ + let handle = + commit_tooltip_focus_handle.clone(); + move |window, cx| { + Tooltip::for_action_in( + "Cancel amend", + &git::Cancel, + &handle, + window, + cx, + ) + } + }) + .on_click(move |_, window, cx| { + window.dispatch_action( + Box::new(git::Cancel), + cx, + ); + }), ) - } else { - Tooltip::simple(tooltip, cx) - } + .child( + panel_filled_button(title) + .tooltip({ + let handle = + commit_tooltip_focus_handle.clone(); + move |window, cx| { + if can_commit { + Tooltip::for_action_in( + tooltip, &Amend, &handle, + window, cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + } + }) + .disabled(!can_commit || self.modal_open) + .on_click(move |_, window, cx| { + window.dispatch_action( + Box::new(git::Amend), + cx, + ); + }), + ) + } + }) + .when(!self.amend_pending, |this| { + this.when(has_previous_commit, |this| { + this.child(SplitButton::new( + ui::ButtonLike::new_rounded_left(ElementId::Name( + format!("split-button-left-{}", title).into(), + )) + 
.layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::Compact) + .child( + div() + .child( + Label::new(title) + .size(LabelSize::Small), + ) + .mr_0p5(), + ) + .on_click(move |_, window, cx| { + window + .dispatch_action(Box::new(git::Commit), cx); + }) + .disabled(!can_commit || self.modal_open) + .tooltip({ + let handle = + commit_tooltip_focus_handle.clone(); + move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + "git commit", + &handle.clone(), + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + } + }), + self.render_git_commit_menu( + ElementId::Name( + format!("split-button-right-{}", title) + .into(), + ), + Some(commit_tooltip_focus_handle.clone()), + ) + .into_any_element(), + )) }) - .disabled(!can_commit || self.modal_open) - .on_click({ - cx.listener(move |this, _: &ClickEvent, window, cx| { - telemetry::event!( - "Git Committed", - source = "Git Panel" - ); - this.commit_changes(window, cx) - }) - }), - ), + .when( + !has_previous_commit, + |this| { + this.child( + panel_filled_button(title) + .tooltip(move |window, cx| { + if can_commit { + Tooltip::with_meta_in( + tooltip, + Some(&git::Commit), + "git commit", + &commit_tooltip_focus_handle, + window, + cx, + ) + } else { + Tooltip::simple(tooltip, cx) + } + }) + .disabled(!can_commit || self.modal_open) + .on_click(move |_, window, cx| { + window.dispatch_action( + Box::new(git::Commit), + cx, + ); + }), + ) + }, + ) + }), ), ) .child( @@ -2994,6 +3214,17 @@ impl GitPanel { Some(footer) } + fn render_pending_amend(&self, cx: &mut Context) -> impl IntoElement { + div() + .py_2() + .px(px(8.)) + .border_color(cx.theme().colors().border) + .child( + Label::new("Your changes will modify your most recent commit. 
If you want to make these changes as a new commit, you can cancel the amend operation.") + .size(LabelSize::Small), + ) + } + fn render_previous_commit(&self, cx: &mut Context) -> Option { let active_repository = self.active_repository.as_ref()?; let branch = active_repository.read(cx).branch.as_ref()?; @@ -3448,7 +3679,7 @@ impl GitPanel { .into_any_element() } - fn load_commit_details( + pub fn load_commit_details( &self, sha: String, cx: &mut Context, @@ -3766,6 +3997,15 @@ impl GitPanel { fn has_write_access(&self, cx: &App) -> bool { !self.project.read(cx).is_read_only(cx) } + + pub fn amend_pending(&self) -> bool { + self.amend_pending + } + + pub fn set_amend_pending(&mut self, value: bool, cx: &mut Context) { + self.amend_pending = value; + cx.notify(); + } } fn current_language_model(cx: &Context<'_, GitPanel>) -> Option> { @@ -3806,6 +4046,8 @@ impl Render for GitPanel { .when(has_write_access && !project.is_read_only(cx), |this| { this.on_action(cx.listener(Self::toggle_staged_for_selected)) .on_action(cx.listener(GitPanel::commit)) + .on_action(cx.listener(GitPanel::amend)) + .on_action(cx.listener(GitPanel::cancel)) .on_action(cx.listener(Self::stage_all)) .on_action(cx.listener(Self::unstage_all)) .on_action(cx.listener(Self::stage_selected)) @@ -3852,7 +4094,12 @@ impl Render for GitPanel { } }) .children(self.render_footer(window, cx)) - .children(self.render_previous_commit(cx)) + .when(self.amend_pending, |this| { + this.child(self.render_pending_amend(cx)) + }) + .when(!self.amend_pending, |this| { + this.children(self.render_previous_commit(cx)) + }) .into_any_element(), ) .children(self.context_menu.as_ref().map(|(menu, position, _)| { diff --git a/crates/git_ui/src/git_ui.rs b/crates/git_ui/src/git_ui.rs index 5edceb90fe..ac0a1ef859 100644 --- a/crates/git_ui/src/git_ui.rs +++ b/crates/git_ui/src/git_ui.rs @@ -368,6 +368,7 @@ mod remote_button { }) .anchor(Corner::TopRight) } + #[allow(clippy::too_many_arguments)] fn split_button( id: 
SharedString, diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index cd9fb181d8..4ed22717fc 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,7 +1,7 @@ mod supported_countries; use anyhow::{Result, anyhow, bail}; -use futures::{AsyncBufReadExt, AsyncReadExt, Stream, StreamExt, io::BufReader, stream::BoxStream}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; @@ -455,24 +455,3 @@ impl std::fmt::Display for Model { write!(f, "{}", self.id()) } } - -pub fn extract_text_from_events( - events: impl Stream>, -) -> impl Stream> { - events.filter_map(|event| async move { - match event { - Ok(event) => event.candidates.and_then(|candidates| { - candidates.into_iter().next().and_then(|candidate| { - candidate.content.parts.into_iter().next().and_then(|part| { - if let Part::TextPart(TextPart { text }) = part { - Some(Ok(text)) - } else { - None - } - }) - }) - }), - Err(error) => Some(Err(error)), - } - }) -} diff --git a/crates/gpui/build.rs b/crates/gpui/build.rs index 9c2b0bafa9..e30a7648a8 100644 --- a/crates/gpui/build.rs +++ b/crates/gpui/build.rs @@ -77,8 +77,8 @@ mod macos { fn generate_dispatch_bindings() { println!("cargo:rustc-link-lib=framework=System"); - println!("cargo:rustc-link-lib=framework=ScreenCaptureKit"); - println!("cargo:rerun-if-changed=src/platform/mac/dispatch.h"); + // weak link to support Catalina + println!("cargo:rustc-link-arg=-Wl,-weak_framework,ScreenCaptureKit"); let bindings = bindgen::Builder::default() .header("src/platform/mac/dispatch.h") diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index c7c2818b7e..525f9d6ac0 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -25,6 +25,7 @@ use collections::{FxHashMap, FxHashSet, HashMap, VecDeque}; pub use context::*; pub use 
entity_map::*; use http_client::HttpClient; +use smallvec::SmallVec; #[cfg(any(test, feature = "test-support"))] pub use test_context::*; use util::ResultExt; @@ -1430,7 +1431,7 @@ impl App { /// Sets the right click menu for the app icon in the dock pub fn set_dock_menu(&self, menus: Vec) { - self.platform.set_dock_menu(menus, &self.keymap.borrow()); + self.platform.set_dock_menu(menus, &self.keymap.borrow()) } /// Performs the action associated with the given dock menu item, only used on Windows for now. @@ -1446,6 +1447,16 @@ impl App { self.platform.add_recent_document(path); } + /// Updates the jump list with the updated list of recent paths for the application, only used on Windows for now. + /// Note that this also sets the dock menu on Windows. + pub fn update_jump_list( + &self, + menus: Vec, + entries: Vec>, + ) -> Vec> { + self.platform.update_jump_list(menus, entries) + } + /// Dispatch an action to the currently active window or global action handler /// See [`crate::Action`] for more information on how actions work pub fn dispatch_action(&mut self, action: &dyn Action) { diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 4a6ebc92f8..969e51e44d 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -37,9 +37,10 @@ use crate::{ DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun, ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput, Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, ScaledPixels, Scene, - SharedString, Size, SvgRenderer, SvgSize, Task, TaskLabel, Window, point, + ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer, SvgSize, Task, TaskLabel, Window, + point, px, size, }; -use anyhow::{Result, anyhow}; +use anyhow::Result; use async_task::Runnable; use futures::channel::oneshot; use image::codecs::gif::GifDecoder; @@ -203,6 +204,13 @@ pub(crate) trait Platform: 'static { fn 
set_dock_menu(&self, menu: Vec, keymap: &Keymap); fn perform_dock_menu_action(&self, _action: usize) {} fn add_recent_document(&self, _path: &Path) {} + fn update_jump_list( + &self, + _menus: Vec, + _entries: Vec>, + ) -> Vec> { + Vec::new() + } fn on_app_menu_action(&self, callback: Box); fn on_will_open_app_menu(&self, callback: Box); fn on_validate_app_menu_command(&self, callback: Box bool>); @@ -525,40 +533,105 @@ impl PlatformTextSystem for NoopTextSystem { Vec::new() } - fn font_id(&self, descriptor: &Font) -> Result { - Err(anyhow!("No font found for {:?}", descriptor)) + fn font_id(&self, _descriptor: &Font) -> Result { + return Ok(FontId(1)); } fn font_metrics(&self, _font_id: FontId) -> FontMetrics { - unimplemented!() + FontMetrics { + units_per_em: 1000, + ascent: 1025.0, + descent: -275.0, + line_gap: 0.0, + underline_position: -95.0, + underline_thickness: 60.0, + cap_height: 698.0, + x_height: 516.0, + bounding_box: Bounds { + origin: Point { + x: -260.0, + y: -245.0, + }, + size: Size { + width: 1501.0, + height: 1364.0, + }, + }, + } } - fn typographic_bounds(&self, font_id: FontId, _glyph_id: GlyphId) -> Result> { - Err(anyhow!("No font found for {:?}", font_id)) + fn typographic_bounds(&self, _font_id: FontId, _glyph_id: GlyphId) -> Result> { + Ok(Bounds { + origin: Point { x: 54.0, y: 0.0 }, + size: size(392.0, 528.0), + }) } - fn advance(&self, font_id: FontId, _glyph_id: GlyphId) -> Result> { - Err(anyhow!("No font found for {:?}", font_id)) + fn advance(&self, _font_id: FontId, glyph_id: GlyphId) -> Result> { + Ok(size(600.0 * glyph_id.0 as f32, 0.0)) } - fn glyph_for_char(&self, _font_id: FontId, _ch: char) -> Option { - None + fn glyph_for_char(&self, _font_id: FontId, ch: char) -> Option { + Some(GlyphId(ch.len_utf16() as u32)) } - fn glyph_raster_bounds(&self, params: &RenderGlyphParams) -> Result> { - Err(anyhow!("No font found for {:?}", params)) + fn glyph_raster_bounds(&self, _params: &RenderGlyphParams) -> Result> { + 
Ok(Default::default()) } fn rasterize_glyph( &self, - params: &RenderGlyphParams, - _raster_bounds: Bounds, + _params: &RenderGlyphParams, + raster_bounds: Bounds, ) -> Result<(Size, Vec)> { - Err(anyhow!("No font found for {:?}", params)) + Ok((raster_bounds.size, Vec::new())) } - fn layout_line(&self, _text: &str, _font_size: Pixels, _runs: &[FontRun]) -> LineLayout { - unimplemented!() + fn layout_line(&self, text: &str, font_size: Pixels, _runs: &[FontRun]) -> LineLayout { + let mut position = px(0.); + let metrics = self.font_metrics(FontId(0)); + let em_width = font_size + * self + .advance(FontId(0), self.glyph_for_char(FontId(0), 'm').unwrap()) + .unwrap() + .width + / metrics.units_per_em as f32; + let mut glyphs = SmallVec::default(); + for (ix, c) in text.char_indices() { + if let Some(glyph) = self.glyph_for_char(FontId(0), c) { + glyphs.push(ShapedGlyph { + id: glyph, + position: point(position, px(0.)), + index: ix, + is_emoji: glyph.0 == 2, + }); + if glyph.0 == 2 { + position += em_width * 2.0; + } else { + position += em_width; + } + } else { + position += em_width + } + } + let mut runs = Vec::default(); + if glyphs.len() > 0 { + runs.push(ShapedRun { + font_id: FontId(0), + glyphs, + }); + } else { + position = px(0.); + } + + LineLayout { + font_size, + width: position, + ascent: font_size * (metrics.ascent / metrics.units_per_em as f32), + descent: font_size * (metrics.descent / metrics.units_per_em as f32), + runs, + len: text.len(), + } } } diff --git a/crates/gpui/src/platform/linux/platform.rs b/crates/gpui/src/platform/linux/platform.rs index d02eea6dac..445192f07a 100644 --- a/crates/gpui/src/platform/linux/platform.rs +++ b/crates/gpui/src/platform/linux/platform.rs @@ -440,7 +440,9 @@ impl Platform for P { self.with_common(|common| Some(common.menus.clone())) } - fn set_dock_menu(&self, _menu: Vec, _keymap: &Keymap) {} + fn set_dock_menu(&self, _menu: Vec, _keymap: &Keymap) { + // todo(linux) + } fn path_for_auxiliary_executable(&self, 
_name: &str) -> Result { Err(anyhow::Error::msg( diff --git a/crates/gpui/src/platform/mac/platform.rs b/crates/gpui/src/platform/mac/platform.rs index 0bda71369e..759e5462d0 100644 --- a/crates/gpui/src/platform/mac/platform.rs +++ b/crates/gpui/src/platform/mac/platform.rs @@ -2,7 +2,7 @@ use super::{ BoolExt, attributed_string::{NSAttributedString, NSMutableAttributedString}, events::key_to_native, - renderer, screen_capture, + is_macos_version_at_least, renderer, screen_capture, }; use crate::{ Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString, @@ -22,8 +22,8 @@ use cocoa::{ }, base::{BOOL, NO, YES, id, nil, selector}, foundation::{ - NSArray, NSAutoreleasePool, NSBundle, NSData, NSInteger, NSProcessInfo, NSRange, NSString, - NSUInteger, NSURL, + NSArray, NSAutoreleasePool, NSBundle, NSData, NSInteger, NSOperatingSystemVersion, + NSProcessInfo, NSRange, NSString, NSUInteger, NSURL, }, }; use core_foundation::{ @@ -553,7 +553,8 @@ impl Platform for MacPlatform { } fn is_screen_capture_supported(&self) -> bool { - true + let min_version = NSOperatingSystemVersion::new(12, 3, 0); + is_macos_version_at_least(min_version) } fn screen_capture_sources( diff --git a/crates/gpui/src/platform/mac/screen_capture.rs b/crates/gpui/src/platform/mac/screen_capture.rs index 8e9fc3d3f9..ac2503bb20 100644 --- a/crates/gpui/src/platform/mac/screen_capture.rs +++ b/crates/gpui/src/platform/mac/screen_capture.rs @@ -37,9 +37,6 @@ pub struct MacScreenCaptureStream { sc_stream_output: id, } -#[link(name = "ScreenCaptureKit", kind = "framework")] -unsafe extern "C" {} - static mut DELEGATE_CLASS: *const Class = ptr::null(); static mut OUTPUT_CLASS: *const Class = ptr::null(); const FRAME_CALLBACK_IVAR: &str = "frame_callback"; diff --git a/crates/gpui/src/platform/mac/window.rs b/crates/gpui/src/platform/mac/window.rs index 532856d890..26a62aeadf 100644 --- a/crates/gpui/src/platform/mac/window.rs +++ b/crates/gpui/src/platform/mac/window.rs 
@@ -1568,7 +1568,7 @@ extern "C" fn window_will_exit_fullscreen(this: &Object, _: Sel, _: id) { } } -fn is_macos_version_at_least(version: NSOperatingSystemVersion) -> bool { +pub(crate) fn is_macos_version_at_least(version: NSOperatingSystemVersion) -> bool { unsafe { NSProcessInfo::processInfo(nil).isOperatingSystemAtLeastVersion(version) } } diff --git a/crates/gpui/src/platform/test/platform.rs b/crates/gpui/src/platform/test/platform.rs index 90e3cf2fa6..3902a98a65 100644 --- a/crates/gpui/src/platform/test/platform.rs +++ b/crates/gpui/src/platform/test/platform.rs @@ -1,8 +1,8 @@ use crate::{ AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DevicePixels, - ForegroundExecutor, Keymap, Platform, PlatformDisplay, PlatformTextSystem, ScreenCaptureFrame, - ScreenCaptureSource, ScreenCaptureStream, Size, Task, TestDisplay, TestWindow, - WindowAppearance, WindowParams, size, + ForegroundExecutor, Keymap, NoopTextSystem, Platform, PlatformDisplay, PlatformTextSystem, + ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, Size, Task, TestDisplay, + TestWindow, WindowAppearance, WindowParams, size, }; use anyhow::Result; use collections::VecDeque; @@ -91,17 +91,7 @@ impl TestPlatform { ) }; - #[cfg(target_os = "macos")] - let text_system = Arc::new(crate::platform::mac::MacTextSystem::new()); - - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - let text_system = Arc::new(crate::platform::linux::CosmicTextSystem::new()); - - #[cfg(target_os = "windows")] - let text_system = Arc::new( - crate::platform::windows::DirectWriteTextSystem::new(&bitmap_factory) - .expect("Unable to initialize direct write."), - ); + let text_system = Arc::new(NoopTextSystem); Rc::new_cyclic(|weak| TestPlatform { background_executor: executor, diff --git a/crates/gpui/src/platform/windows.rs b/crates/gpui/src/platform/windows.rs index 51d09f0013..b3a89e9635 100644 --- a/crates/gpui/src/platform/windows.rs +++ b/crates/gpui/src/platform/windows.rs @@ -1,4 +1,5 
@@ mod clipboard; +mod destination_list; mod direct_write; mod dispatcher; mod display; @@ -10,6 +11,7 @@ mod window; mod wrapper; pub(crate) use clipboard::*; +pub(crate) use destination_list::*; pub(crate) use direct_write::*; pub(crate) use dispatcher::*; pub(crate) use display::*; diff --git a/crates/gpui/src/platform/windows/destination_list.rs b/crates/gpui/src/platform/windows/destination_list.rs new file mode 100644 index 0000000000..09b47a3ea4 --- /dev/null +++ b/crates/gpui/src/platform/windows/destination_list.rs @@ -0,0 +1,211 @@ +use std::path::PathBuf; + +use itertools::Itertools; +use smallvec::SmallVec; +use windows::{ + Win32::{ + Foundation::PROPERTYKEY, + Globalization::u_strlen, + System::Com::{CLSCTX_INPROC_SERVER, CoCreateInstance, StructuredStorage::PROPVARIANT}, + UI::{ + Controls::INFOTIPSIZE, + Shell::{ + Common::{IObjectArray, IObjectCollection}, + DestinationList, EnumerableObjectCollection, ICustomDestinationList, IShellLinkW, + PropertiesSystem::IPropertyStore, + ShellLink, + }, + }, + }, + core::{GUID, HSTRING, Interface}, +}; + +use crate::{Action, MenuItem}; + +pub(crate) struct JumpList { + pub(crate) dock_menus: Vec, + pub(crate) recent_workspaces: Vec>, +} + +impl JumpList { + pub(crate) fn new() -> Self { + Self { + dock_menus: Vec::new(), + recent_workspaces: Vec::new(), + } + } +} + +pub(crate) struct DockMenuItem { + pub(crate) name: String, + pub(crate) description: String, + pub(crate) action: Box, +} + +impl DockMenuItem { + pub(crate) fn new(item: MenuItem) -> anyhow::Result { + match item { + MenuItem::Action { name, action, .. } => Ok(Self { + name: name.clone().into(), + description: if name == "New Window" { + "Opens a new window".to_string() + } else { + name.into() + }, + action, + }), + _ => Err(anyhow::anyhow!( + "Only `MenuItem::Action` is supported for dock menu on Windows." 
+ )), + } + } +} + +// This code is based on the example from Microsoft: +// https://github.com/microsoft/Windows-classic-samples/blob/main/Samples/Win7Samples/winui/shell/appshellintegration/RecipePropertyHandler/RecipePropertyHandler.cpp +pub(crate) fn update_jump_list( + jump_list: &JumpList, +) -> anyhow::Result>> { + let (list, removed) = create_destination_list()?; + add_recent_folders(&list, &jump_list.recent_workspaces, removed.as_ref())?; + add_dock_menu(&list, &jump_list.dock_menus)?; + unsafe { list.CommitList() }?; + Ok(removed) +} + +// Copied from: +// https://github.com/microsoft/windows-rs/blob/0fc3c2e5a13d4316d242bdeb0a52af611eba8bd4/crates/libs/windows/src/Windows/Win32/Storage/EnhancedStorage/mod.rs#L1881 +const PKEY_TITLE: PROPERTYKEY = PROPERTYKEY { + fmtid: GUID::from_u128(0xf29f85e0_4ff9_1068_ab91_08002b27b3d9), + pid: 2, +}; + +fn create_destination_list() -> anyhow::Result<(ICustomDestinationList, Vec>)> +{ + let list: ICustomDestinationList = + unsafe { CoCreateInstance(&DestinationList, None, CLSCTX_INPROC_SERVER) }?; + + let mut slots = 0; + let user_removed: IObjectArray = unsafe { list.BeginList(&mut slots) }?; + + let count = unsafe { user_removed.GetCount() }?; + if count == 0 { + return Ok((list, Vec::new())); + } + + let mut removed = Vec::with_capacity(count as usize); + for i in 0..count { + let shell_link: IShellLinkW = unsafe { user_removed.GetAt(i)? }; + let description = { + // INFOTIPSIZE is the maximum size of the buffer + // see https://learn.microsoft.com/en-us/windows/win32/api/shobjidl_core/nf-shobjidl_core-ishelllinkw-getdescription + let mut buffer = [0u16; INFOTIPSIZE as usize]; + unsafe { shell_link.GetDescription(&mut buffer)? 
}; + let len = unsafe { u_strlen(buffer.as_ptr()) }; + String::from_utf16_lossy(&buffer[..len as usize]) + }; + let args = description.split('\n').map(PathBuf::from).collect(); + + removed.push(args); + } + + Ok((list, removed)) +} + +fn add_dock_menu(list: &ICustomDestinationList, dock_menus: &[DockMenuItem]) -> anyhow::Result<()> { + unsafe { + let tasks: IObjectCollection = + CoCreateInstance(&EnumerableObjectCollection, None, CLSCTX_INPROC_SERVER)?; + for (idx, dock_menu) in dock_menus.iter().enumerate() { + let argument = HSTRING::from(format!("--dock-action {}", idx)); + let description = HSTRING::from(dock_menu.description.as_str()); + let display = dock_menu.name.as_str(); + let task = create_shell_link(argument, description, None, display)?; + tasks.AddObject(&task)?; + } + list.AddUserTasks(&tasks)?; + Ok(()) + } +} + +fn add_recent_folders( + list: &ICustomDestinationList, + entries: &[SmallVec<[PathBuf; 2]>], + removed: &Vec>, +) -> anyhow::Result<()> { + unsafe { + let tasks: IObjectCollection = + CoCreateInstance(&EnumerableObjectCollection, None, CLSCTX_INPROC_SERVER)?; + + for folder_path in entries + .iter() + .filter(|path| !is_item_in_array(path, removed)) + { + let argument = HSTRING::from( + folder_path + .iter() + .map(|path| format!("\"{}\"", path.display())) + .join(" "), + ); + + let description = HSTRING::from( + folder_path + .iter() + .map(|path| path.to_string_lossy()) + .collect::>() + .join("\n"), + ); + // simulate folder icon + // https://github.com/microsoft/vscode/blob/7a5dc239516a8953105da34f84bae152421a8886/src/vs/platform/workspaces/electron-main/workspacesHistoryMainService.ts#L380 + let icon = HSTRING::from("explorer.exe"); + + let display = folder_path + .iter() + .map(|p| { + p.file_name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or_else(|| p.to_string_lossy().to_string()) + }) + .join(", "); + + tasks.AddObject(&create_shell_link( + argument, + description, + Some(icon), + &display, + )?)?; + } + + 
list.AppendCategory(&HSTRING::from("Recent Folders"), &tasks)?; + Ok(()) + } +} + +#[inline] +fn is_item_in_array(item: &SmallVec<[PathBuf; 2]>, removed: &Vec>) -> bool { + removed.iter().any(|removed_item| removed_item == item) +} + +fn create_shell_link( + argument: HSTRING, + description: HSTRING, + icon: Option, + display: &str, +) -> anyhow::Result { + unsafe { + let link: IShellLinkW = CoCreateInstance(&ShellLink, None, CLSCTX_INPROC_SERVER)?; + let exe_path = HSTRING::from(std::env::current_exe()?.as_os_str()); + link.SetPath(&exe_path)?; + link.SetArguments(&argument)?; + link.SetDescription(&description)?; + if let Some(icon) = icon { + link.SetIconLocation(&icon, 0)?; + } + let store: IPropertyStore = link.cast()?; + let title = PROPVARIANT::from(display); + store.SetValue(&PKEY_TITLE, &title)?; + store.Commit()?; + + Ok(link) + } +} diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 116b2253d1..7889c89a9e 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -14,10 +14,7 @@ use itertools::Itertools; use parking_lot::RwLock; use smallvec::SmallVec; use windows::{ - UI::{ - StartScreen::{JumpList, JumpListItem}, - ViewManagement::UISettings, - }, + UI::ViewManagement::UISettings, Win32::{ Foundation::*, Graphics::{ @@ -52,7 +49,7 @@ pub(crate) struct WindowsPlatform { pub(crate) struct WindowsPlatformState { callbacks: PlatformCallbacks, menus: Vec, - dock_menu_actions: Vec>, + jump_list: JumpList, // NOTE: standard cursor handles don't need to close. 
pub(crate) current_cursor: Option, } @@ -70,12 +67,12 @@ struct PlatformCallbacks { impl WindowsPlatformState { fn new() -> Self { let callbacks = PlatformCallbacks::default(); - let dock_menu_actions = Vec::new(); + let jump_list = JumpList::new(); let current_cursor = load_cursor(CursorStyle::Arrow); Self { callbacks, - dock_menu_actions, + jump_list, current_cursor, menus: Vec::new(), } @@ -189,9 +186,10 @@ impl WindowsPlatform { let mut lock = self.state.borrow_mut(); if let Some(mut callback) = lock.callbacks.app_menu_action.take() { let Some(action) = lock - .dock_menu_actions + .jump_list + .dock_menus .get(action_idx) - .map(|action| action.boxed_clone()) + .map(|dock_menu| dock_menu.action.boxed_clone()) else { lock.callbacks.app_menu_action = Some(callback); log::error!("Dock menu for index {action_idx} not found"); @@ -254,33 +252,35 @@ impl WindowsPlatform { false } - fn configure_jump_list(&self, menus: Vec) -> Result<()> { - let jump_list = JumpList::LoadCurrentAsync()?.get()?; - let items = jump_list.Items()?; - items.Clear()?; + fn set_dock_menus(&self, menus: Vec) { let mut actions = Vec::new(); - for item in menus.into_iter() { - let item = match item { - MenuItem::Separator => JumpListItem::CreateSeparator()?, - MenuItem::Submenu(_) => { - log::error!("Set `MenuItemSubmenu` for dock menu on Windows is not supported."); - continue; - } - MenuItem::Action { name, action, .. } => { - let idx = actions.len(); - actions.push(action.boxed_clone()); - let item_args = format!("--dock-action {}", idx); - JumpListItem::CreateWithArguments( - &HSTRING::from(item_args), - &HSTRING::from(name.as_ref()), - )? 
- } - }; - items.Append(&item)?; - } - jump_list.SaveAsync()?.get()?; - self.state.borrow_mut().dock_menu_actions = actions; - Ok(()) + menus.into_iter().for_each(|menu| { + if let Some(dock_menu) = DockMenuItem::new(menu).log_err() { + actions.push(dock_menu); + } + }); + let mut lock = self.state.borrow_mut(); + lock.jump_list.dock_menus = actions; + update_jump_list(&lock.jump_list).log_err(); + } + + fn update_jump_list( + &self, + menus: Vec, + entries: Vec>, + ) -> Vec> { + let mut actions = Vec::new(); + menus.into_iter().for_each(|menu| { + if let Some(dock_menu) = DockMenuItem::new(menu).log_err() { + actions.push(dock_menu); + } + }); + let mut lock = self.state.borrow_mut(); + lock.jump_list.dock_menus = actions; + lock.jump_list.recent_workspaces = entries; + update_jump_list(&lock.jump_list) + .log_err() + .unwrap_or_default() } } @@ -535,7 +535,7 @@ impl Platform for WindowsPlatform { } fn set_dock_menu(&self, menus: Vec, _keymap: &Keymap) { - self.configure_jump_list(menus).log_err(); + self.set_dock_menus(menus); } fn on_app_menu_action(&self, callback: Box) { @@ -673,6 +673,14 @@ impl Platform for WindowsPlatform { .log_err(); } } + + fn update_jump_list( + &self, + menus: Vec, + entries: Vec>, + ) -> Vec> { + self.update_jump_list(menus, entries) + } } impl Drop for WindowsPlatform { diff --git a/crates/gpui/src/text_system/line_wrapper.rs b/crates/gpui/src/text_system/line_wrapper.rs index f0dfc927e5..d7bc4c1f24 100644 --- a/crates/gpui/src/text_system/line_wrapper.rs +++ b/crates/gpui/src/text_system/line_wrapper.rs @@ -330,13 +330,6 @@ mod tests { fn build_wrapper() -> LineWrapper { let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); let cx = TestAppContext::new(dispatcher, None); - cx.text_system() - .add_fonts(vec![ - std::fs::read("../../assets/fonts/plex-mono/ZedPlexMono-Regular.ttf") - .unwrap() - .into(), - ]) - .unwrap(); let id = cx.text_system().font_id(&font("Zed Plex Mono")).unwrap(); LineWrapper::new(id, px(16.), 
cx.text_system().platform_text_system.clone()) } @@ -734,16 +727,16 @@ mod tests { lines[0].layout.wrap_boundaries(), &[ WrapBoundary { - run_ix: 1, - glyph_ix: 3 + run_ix: 0, + glyph_ix: 7 }, WrapBoundary { - run_ix: 2, - glyph_ix: 3 + run_ix: 0, + glyph_ix: 12 }, WrapBoundary { - run_ix: 4, - glyph_ix: 2 + run_ix: 0, + glyph_ix: 18 } ], ); diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index aa8dcaf587..d7f4a820da 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -11,7 +11,6 @@ pub enum IconName { Ai, AiAnthropic, AiBedrock, - AiAnthropicHosted, AiDeepSeek, AiEdit, AiGoogle, @@ -61,6 +60,7 @@ pub enum IconName { CircleOff, Clipboard, Close, + Cloud, Code, Cog, Command, @@ -74,22 +74,22 @@ pub enum IconName { CountdownTimer, CursorIBeam, Dash, + DatabaseZap, + Debug, DebugBreakpoint, + DebugContinue, DebugDisabledBreakpoint, DebugDisabledLogBreakpoint, + DebugDisconnect, DebugIgnoreBreakpoints, + DebugLogBreakpoint, DebugPause, - DebugContinue, - DebugStepOver, + DebugRestart, + DebugStepBack, DebugStepInto, DebugStepOut, - DebugStepBack, - DebugRestart, - Debug, + DebugStepOver, DebugStop, - DebugDisconnect, - DebugLogBreakpoint, - DatabaseZap, Delete, Diff, Disconnected, @@ -99,18 +99,18 @@ pub enum IconName { Envelope, Eraser, Escape, - ExpandVertical, Exit, - ExternalLink, - ExpandUp, ExpandDown, + ExpandUp, + ExpandVertical, + ExternalLink, Eye, File, FileCode, FileCreate, FileDelete, - FileDoc, FileDiff, + FileDoc, FileGeneric, FileGit, FileLock, @@ -133,16 +133,17 @@ pub enum IconName { GenericMaximize, GenericMinimize, GenericRestore, - Github, - Globe, GitBranch, GitBranchSmall, + Github, + Globe, Hash, HistoryRerun, Indicator, Info, InlayHint, Keyboard, + Layout, Library, LightBulb, LineHeight, @@ -155,7 +156,6 @@ pub enum IconName { Maximize, Menu, MessageBubbles, - Cloud, Mic, MicMute, Microscope, @@ -227,8 +227,8 @@ pub enum IconName { Tab, Terminal, TextSnippet, - ThumbsUp, ThumbsDown, + ThumbsUp, Trash, 
TrashAlt, Triangle, @@ -247,10 +247,10 @@ pub enum IconName { ZedAssistant, ZedAssistantFilled, ZedPredict, - ZedPredictUp, - ZedPredictDown, ZedPredictDisabled, + ZedPredictDown, ZedPredictError, + ZedPredictUp, ZedXCopilot, } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 45535971f7..0d6f797bc3 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1373,6 +1373,25 @@ impl Buffer { .or_else(|| self.language.clone()) } + /// Returns each [`Language`] for the active syntax layers at the given location. + pub fn languages_at(&self, position: D) -> Vec> { + let offset = position.to_offset(self); + let mut languages: Vec> = self + .syntax_map + .lock() + .layers_for_range(offset..offset, &self.text, false) + .map(|info| info.language.clone()) + .collect(); + + if languages.is_empty() { + if let Some(buffer_language) = self.language() { + languages.push(buffer_language.clone()); + } + } + + languages + } + /// An integer version number that accounts for all updates besides /// the buffer's text itself (which is versioned via a version vector). 
pub fn non_text_state_update_count(&self) -> usize { diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs index a575e08022..7ba3f3b0ae 100644 --- a/crates/language/src/language_registry.rs +++ b/crates/language/src/language_registry.rs @@ -8,7 +8,7 @@ use crate::{ with_parser, }; use anyhow::{Context as _, Result, anyhow}; -use collections::{HashMap, HashSet, hash_map}; +use collections::{FxHashMap, HashMap, HashSet, hash_map}; use futures::{ Future, @@ -16,13 +16,17 @@ use futures::{ }; use globset::GlobSet; use gpui::{App, BackgroundExecutor, SharedString}; +use itertools::FoldWhile::{Continue, Done}; +use itertools::Itertools; use lsp::LanguageServerId; use parking_lot::{Mutex, RwLock}; use postage::watch; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; use std::{ borrow::{Borrow, Cow}, + cell::LazyCell, ffi::OsStr, ops::Not, path::{Path, PathBuf}, @@ -163,6 +167,20 @@ impl AvailableLanguage { } } +#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] +enum LanguageMatchPrecedence { + #[default] + Undetermined, + PathOrContent, + UserConfigured, +} + +impl LanguageMatchPrecedence { + fn best_possible_match(&self) -> bool { + *self == LanguageMatchPrecedence::UserConfigured + } +} + enum AvailableGrammar { Native(tree_sitter::Language), Loaded(#[allow(unused)] PathBuf, tree_sitter::Language), @@ -600,12 +618,10 @@ impl LanguageRegistry { name: &str, ) -> impl Future>> + use<> { let name = UniCase::new(name); - let rx = self.get_or_load_language(|language_name, _| { - if UniCase::new(&language_name.0) == name { - 1 - } else { - 0 - } + let rx = self.get_or_load_language(|language_name, _, current_best_match| { + (current_best_match < LanguageMatchPrecedence::PathOrContent + && UniCase::new(&language_name.0) == name) + .then_some(LanguageMatchPrecedence::PathOrContent) }); async move { rx.await? 
} } @@ -615,17 +631,14 @@ impl LanguageRegistry { string: &str, ) -> impl Future>> { let string = UniCase::new(string); - let rx = self.get_or_load_language(|name, config| { - if UniCase::new(&name.0) == string - || config - .path_suffixes - .iter() - .any(|suffix| UniCase::new(suffix) == string) - { - 1 - } else { - 0 - } + let rx = self.get_or_load_language(|name, config, current_best_match| { + (current_best_match < LanguageMatchPrecedence::PathOrContent + && (UniCase::new(&name.0) == string + || config + .path_suffixes + .iter() + .any(|suffix| UniCase::new(suffix) == string))) + .then_some(LanguageMatchPrecedence::PathOrContent) }); async move { rx.await? } } @@ -674,7 +687,7 @@ impl LanguageRegistry { self: &Arc, path: &Path, content: Option<&Rope>, - user_file_types: Option<&HashMap, GlobSet>>, + user_file_types: Option<&FxHashMap, GlobSet>>, ) -> Option { let filename = path.file_name().and_then(|name| name.to_str()); // `Path.extension()` returns None for files with a leading '.' 
@@ -682,57 +695,94 @@ impl LanguageRegistry { // as we want `.zshrc` to result in extension being `Some("zshrc")` let extension = filename.and_then(|filename| filename.split('.').next_back()); let path_suffixes = [extension, filename, path.to_str()]; - let empty = GlobSet::empty(); + let path_suffixes_candidates = path_suffixes + .iter() + .filter_map(|suffix| suffix.map(globset::Candidate::new)) + .collect::>(); + let content = LazyCell::new(|| { + content.map(|content| { + let end = content.clip_point(Point::new(0, 256), Bias::Left); + let end = content.point_to_offset(end); + content.chunks_in_range(0..end).collect::() + }) + }); + self.find_matching_language(move |language_name, config, current_best_match| { + let path_matches_default_suffix = || { + config + .path_suffixes + .iter() + .any(|suffix| path_suffixes.contains(&Some(suffix.as_str()))) + }; + let path_matches_custom_suffix = || { + user_file_types + .and_then(|types| types.get(language_name.as_ref())) + .map_or(false, |custom_suffixes| { + path_suffixes_candidates + .iter() + .any(|suffix| custom_suffixes.is_match_candidate(suffix)) + }) + }; + let content_matches = || { + config.first_line_pattern.as_ref().map_or(false, |pattern| { + content + .as_ref() + .is_some_and(|content| pattern.is_match(content)) + }) + }; - self.find_matching_language(move |language_name, config| { - let path_matches_default_suffix = config - .path_suffixes - .iter() - .any(|suffix| path_suffixes.contains(&Some(suffix.as_str()))); - let custom_suffixes = user_file_types - .and_then(|types| types.get(language_name.as_ref())) - .unwrap_or(&empty); - let path_matches_custom_suffix = path_suffixes - .iter() - .map(|suffix| suffix.unwrap_or("")) - .any(|suffix| custom_suffixes.is_match(suffix)); - let content_matches = content.zip(config.first_line_pattern.as_ref()).map_or( - false, - |(content, pattern)| { - let end = content.clip_point(Point::new(0, 256), Bias::Left); - let end = content.point_to_offset(end); - let text = 
content.chunks_in_range(0..end).collect::(); - pattern.is_match(&text) - }, - ); - if path_matches_custom_suffix { - 2 - } else if path_matches_default_suffix || content_matches { - 1 - } else { - 0 + // Only return a match for the given file if we have a better match than + // the current one. + match current_best_match { + LanguageMatchPrecedence::PathOrContent | LanguageMatchPrecedence::Undetermined + if path_matches_custom_suffix() => + { + Some(LanguageMatchPrecedence::UserConfigured) + } + LanguageMatchPrecedence::Undetermined + if path_matches_default_suffix() || content_matches() => + { + Some(LanguageMatchPrecedence::PathOrContent) + } + _ => None, } }) } fn find_matching_language( self: &Arc, - callback: impl Fn(&LanguageName, &LanguageMatcher) -> usize, + callback: impl Fn( + &LanguageName, + &LanguageMatcher, + LanguageMatchPrecedence, + ) -> Option, ) -> Option { let state = self.state.read(); let available_language = state .available_languages .iter() - .filter_map(|language| { - let score = callback(&language.name, &language.matcher); - if score > 0 { - Some((language.clone(), score)) - } else { - None + .rev() + .fold_while(None, |best_language_match, language| { + let current_match_type = best_language_match + .as_ref() + .map_or(LanguageMatchPrecedence::default(), |(_, score)| *score); + let language_score = + callback(&language.name, &language.matcher, current_match_type); + debug_assert!( + language_score.is_none_or(|new_score| new_score > current_match_type), + "Matching callback should only return a better match than the current one" + ); + + match language_score { + Some(new_score) if new_score.best_possible_match() => { + Done(Some((language.clone(), new_score))) + } + Some(new_score) if current_match_type < new_score => { + Continue(Some((language.clone(), new_score))) + } + _ => Continue(best_language_match), } }) - .max_by_key(|e| e.1) - .clone() + .into_inner() .map(|(available_language, _)| available_language); drop(state); 
available_language @@ -827,7 +877,11 @@ impl LanguageRegistry { fn get_or_load_language( self: &Arc, - callback: impl Fn(&LanguageName, &LanguageMatcher) -> usize, + callback: impl Fn( + &LanguageName, + &LanguageMatcher, + LanguageMatchPrecedence, + ) -> Option, ) -> oneshot::Receiver>> { let Some(language) = self.find_matching_language(callback) else { let (tx, rx) = oneshot::channel(); diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index ca2c33419f..56ffbbef2f 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -2,7 +2,7 @@ use crate::{File, Language, LanguageName, LanguageServerName}; use anyhow::Result; -use collections::{HashMap, HashSet}; +use collections::{FxHashMap, HashMap, HashSet}; use core::slice; use ec4rs::{ Properties as EditorconfigProperties, @@ -63,7 +63,7 @@ pub struct AllLanguageSettings { pub edit_predictions: EditPredictionSettings, pub defaults: LanguageSettings, languages: HashMap, - pub(crate) file_types: HashMap, GlobSet>, + pub(crate) file_types: FxHashMap, GlobSet>, } /// The settings for a particular language. 
@@ -1217,7 +1217,7 @@ impl settings::Settings for AllLanguageSettings { .map(|settings| settings.enabled_in_assistant) .unwrap_or(true); - let mut file_types: HashMap, GlobSet> = HashMap::default(); + let mut file_types: FxHashMap, GlobSet> = FxHashMap::default(); for (language, suffixes) in &default_value.file_types { let mut builder = GlobSetBuilder::new(); diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index aa060f7b30..a0e38c629e 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -83,7 +83,7 @@ pub enum StopReason { ToolUse, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] pub input_tokens: u32, @@ -174,10 +174,6 @@ impl Default for LanguageModelTextStream { pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; - /// If None, falls back to [LanguageModelProvider::icon] - fn icon(&self) -> Option { - None - } fn provider_id(&self) -> LanguageModelProviderId; fn provider_name(&self) -> LanguageModelProviderName; fn telemetry_id(&self) -> String; @@ -282,6 +278,12 @@ pub trait LanguageModel: Send + Sync { } } +#[derive(Debug, Error)] +pub enum LanguageModelKnownError { + #[error("Context window limit exceeded ({tokens})")] + ContextWindowLimitExceeded { tokens: usize }, +} + pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { fn name() -> String; fn description() -> String; @@ -304,6 +306,9 @@ pub trait LanguageModelProvider: 'static { } fn default_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; + fn recommended_models(&self, _cx: &App) -> Vec> { + Vec::new() + } fn load_model(&self, _model: Arc, _cx: &App) {} fn is_authenticated(&self, cx: &App) -> bool; fn 
authenticate(&self, cx: &mut App) -> Task>; @@ -348,7 +353,7 @@ pub trait LanguageModelProviderState: 'static { } } -#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] +#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)] pub struct LanguageModelId(pub SharedString); #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index e5c66670d8..3c12cb1bd5 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -6,7 +6,6 @@ use client::Client; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use icons::IconName; use proto::{Plan, TypedEnvelope}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -53,13 +52,6 @@ impl CloudModel { } } - pub fn icon(&self) -> Option { - match self { - Self::Anthropic(_) => Some(IconName::AiAnthropicHosted), - _ => None, - } - } - pub fn max_token_count(&self) -> usize { match self { Self::Anthropic(model) => model.max_token_count(), @@ -91,6 +83,9 @@ impl CloudModel { | open_ai::Model::FourTurbo | open_ai::Model::FourOmni | open_ai::Model::FourOmniMini + | open_ai::Model::FourPointOne + | open_ai::Model::FourPointOneMini + | open_ai::Model::FourPointOneNano | open_ai::Model::O1Mini | open_ai::Model::O1Preview | open_ai::Model::O1 diff --git a/crates/language_model_selector/Cargo.toml b/crates/language_model_selector/Cargo.toml index 1257ae564c..39bc8a59f9 100644 --- a/crates/language_model_selector/Cargo.toml +++ b/crates/language_model_selector/Cargo.toml @@ -12,6 +12,7 @@ workspace = true path = "src/language_model_selector.rs" [dependencies] +collections.workspace = true feature_flags.workspace = true gpui.workspace = true language_model.workspace = true diff --git a/crates/language_model_selector/src/language_model_selector.rs 
b/crates/language_model_selector/src/language_model_selector.rs index 90747a01f3..c7b5d9cd48 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -1,12 +1,13 @@ use std::sync::Arc; +use collections::{HashSet, IndexMap}; use feature_flags::{Assistant2FeatureFlag, ZedPro}; use gpui::{ Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, }; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry, + AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; use picker::{Picker, PickerDelegate}; use proto::Plan; @@ -24,9 +25,6 @@ type OnModelChanged = Arc, &App) + 'static>; pub struct LanguageModelSelector { picker: Entity>, - /// The task used to update the picker's matches when there is a change to - /// the language model registry. 
- update_matches_task: Option>, _authenticate_all_providers_task: Task<()>, _subscriptions: Vec, } @@ -40,16 +38,18 @@ impl LanguageModelSelector { let on_model_changed = Arc::new(on_model_changed); let all_models = Self::all_models(cx); + let entries = all_models.entries(); + let delegate = LanguageModelPickerDelegate { language_model_selector: cx.entity().downgrade(), on_model_changed: on_model_changed.clone(), - all_models: all_models.clone(), - filtered_models: all_models, - selected_index: Self::get_active_model_index(cx), + all_models: Arc::new(all_models), + selected_index: Self::get_active_model_index(&entries, cx), + filtered_entries: entries, }; let picker = cx.new(|cx| { - Picker::uniform_list(delegate, window, cx) + Picker::list(delegate, window, cx) .show_scrollbar(true) .width(rems(20.)) .max_height(Some(rems(20.).into())) @@ -59,7 +59,6 @@ impl LanguageModelSelector { LanguageModelSelector { picker, - update_matches_task: None, _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![ cx.subscribe_in( @@ -83,12 +82,13 @@ impl LanguageModelSelector { language_model::Event::ProviderStateChanged | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { - let task = self.picker.update(cx, |this, cx| { + self.picker.update(cx, |this, cx| { let query = this.query(cx); - this.delegate.all_models = Self::all_models(cx); - this.delegate.update_matches(query, window, cx) + this.delegate.all_models = Arc::new(Self::all_models(cx)); + // Update matches will automatically drop the previous task + // if we get a provider event again + this.update_matches(query, window, cx) }); - self.update_matches_task = Some(task); } _ => {} } @@ -144,34 +144,72 @@ impl LanguageModelSelector { }) } - fn all_models(cx: &App) -> Vec { - LanguageModelRegistry::global(cx) + fn all_models(cx: &App) -> GroupedModels { + let mut recommended = Vec::new(); + let mut recommended_set = HashSet::default(); + for provider 
in LanguageModelRegistry::global(cx) .read(cx) .providers() .iter() - .flat_map(|provider| { - let icon = provider.icon(); - - provider.provided_models(cx).into_iter().map(move |model| { - let model = model.clone(); - let icon = model.icon().unwrap_or(icon); - - ModelInfo { + { + let models = provider.recommended_models(cx); + recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id()))); + recommended.extend( + provider + .recommended_models(cx) + .into_iter() + .map(move |model| ModelInfo { model: model.clone(), - icon, - availability: model.availability(), - } - }) + icon: provider.icon(), + }), + ); + } + + let other_models = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .iter() + .map(|provider| { + ( + provider.id(), + provider + .provided_models(cx) + .into_iter() + .filter_map(|model| { + let not_included = + !recommended_set.contains(&(model.provider_id(), model.id())); + not_included.then(|| ModelInfo { + model: model.clone(), + icon: provider.icon(), + }) + }) + .collect::>(), + ) }) - .collect::>() + .collect::>(); + + GroupedModels { + recommended, + other: other_models, + } } - fn get_active_model_index(cx: &App) -> usize { + fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize { let active_model = LanguageModelRegistry::read_global(cx).default_model(); - Self::all_models(cx) + entries .iter() - .position(|model_info| { - Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id()) + .position(|entry| { + if let LanguageModelPickerEntry::Model(model) = entry { + active_model + .as_ref() + .map(|active_model| { + active_model.model.id() == model.model.id() + && active_model.model.provider_id() == model.model.provider_id() + }) + .unwrap_or_default() + } else { + false + } }) .unwrap_or(0) } @@ -254,22 +292,61 @@ where struct ModelInfo { model: Arc, icon: IconName, - availability: LanguageModelAvailability, } pub struct LanguageModelPickerDelegate { 
language_model_selector: WeakEntity, on_model_changed: OnModelChanged, - all_models: Vec, - filtered_models: Vec, + all_models: Arc, + filtered_entries: Vec, selected_index: usize, } +struct GroupedModels { + recommended: Vec, + other: IndexMap>, +} + +impl GroupedModels { + fn entries(&self) -> Vec { + let mut entries = Vec::new(); + + if !self.recommended.is_empty() { + entries.push(LanguageModelPickerEntry::Separator("Recommended".into())); + entries.extend( + self.recommended + .iter() + .map(|info| LanguageModelPickerEntry::Model(info.clone())), + ); + } + + for models in self.other.values() { + if models.is_empty() { + continue; + } + entries.push(LanguageModelPickerEntry::Separator( + models[0].model.provider_name().0, + )); + entries.extend( + models + .iter() + .map(|info| LanguageModelPickerEntry::Model(info.clone())), + ); + } + entries + } +} + +enum LanguageModelPickerEntry { + Model(ModelInfo), + Separator(SharedString), +} + impl PickerDelegate for LanguageModelPickerDelegate { - type ListItem = ListItem; + type ListItem = AnyElement; fn match_count(&self) -> usize { - self.filtered_models.len() + self.filtered_entries.len() } fn selected_index(&self) -> usize { @@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate { } fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { - self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1)); + self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); cx.notify(); } + fn can_select( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + match self.filtered_entries.get(ix) { + Some(LanguageModelPickerEntry::Model(_)) => true, + Some(LanguageModelPickerEntry::Separator(_)) | None => false, + } + } + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { - "Select a model...".into() + "Select a model…".into() } fn update_matches( @@ -307,44 +396,51 @@ impl PickerDelegate for 
LanguageModelPickerDelegate { cx.spawn_in(window, async move |this, cx| { let filtered_models = cx .background_spawn(async move { - let displayed_models = if configured_providers.is_empty() { - all_models - } else { - all_models - .into_iter() - .filter(|model_info| { - configured_providers.contains(&model_info.model.provider_id()) - }) - .collect::>() + let matches = |info: &ModelInfo| { + info.model + .name() + .0 + .to_lowercase() + .contains(&query.to_lowercase()) }; - if query.is_empty() { - displayed_models - } else { - displayed_models - .into_iter() - .filter(|model_info| { - model_info - .model - .name() - .0 - .to_lowercase() - .contains(&query.to_lowercase()) - }) - .collect() + let recommended_models = all_models + .recommended + .iter() + .filter(|r| { + configured_providers.contains(&r.model.provider_id()) && matches(r) + }) + .cloned() + .collect(); + let mut other_models = IndexMap::default(); + for (provider_id, models) in &all_models.other { + if configured_providers.contains(&provider_id) { + other_models.insert( + provider_id.clone(), + models + .iter() + .filter(|m| matches(m)) + .cloned() + .collect::>(), + ); + } + } + GroupedModels { + recommended: recommended_models, + other: other_models, } }) .await; this.update_in(cx, |this, window, cx| { - this.delegate.filtered_models = filtered_models; + this.delegate.filtered_entries = filtered_models.entries(); // Preserve selection focus - let new_index = if current_index >= this.delegate.filtered_models.len() { + let new_index = if current_index >= this.delegate.filtered_entries.len() { 0 } else { current_index }; - this.delegate.set_selected_index(new_index, window, cx); + this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); cx.notify(); }) .ok(); @@ -352,7 +448,9 @@ impl PickerDelegate for LanguageModelPickerDelegate { } fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { - if let Some(model_info) = 
self.filtered_models.get(self.selected_index) { + if let Some(LanguageModelPickerEntry::Model(model_info)) = + self.filtered_entries.get(self.selected_index) + { let model = model_info.model.clone(); (self.on_model_changed)(model.clone(), cx); @@ -369,29 +467,6 @@ impl PickerDelegate for LanguageModelPickerDelegate { .ok(); } - fn render_header(&self, _: &mut Window, cx: &mut Context>) -> Option { - let configured_models_count = LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .iter() - .filter(|provider| provider.is_authenticated(cx)) - .count(); - - if configured_models_count > 0 { - Some( - Label::new("Configured Models") - .size(LabelSize::Small) - .color(Color::Muted) - .mt_1() - .mb_0p5() - .ml_2() - .into_any_element(), - ) - } else { - None - } - } - fn render_match( &self, ix: usize, @@ -399,77 +474,68 @@ impl PickerDelegate for LanguageModelPickerDelegate { _: &mut Window, cx: &mut Context>, ) -> Option { - use feature_flags::FeatureFlagAppExt; - let show_badges = cx.has_flag::(); + match self.filtered_entries.get(ix)? 
{ + LanguageModelPickerEntry::Separator(title) => Some( + div() + .px_2() + .pb_1() + .when(ix > 1, |this| { + this.mt_1() + .pt_2() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + Label::new(title) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + ), + LanguageModelPickerEntry::Model(model_info) => { + let active_model = LanguageModelRegistry::read_global(cx).default_model(); - let model_info = self.filtered_models.get(ix)?; - let provider_name: String = model_info.model.provider_name().0.clone().into(); + let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); + let active_model_id = active_model.map(|m| m.model.id()); - let active_model = LanguageModelRegistry::read_global(cx).default_model(); + let is_selected = Some(model_info.model.provider_id()) == active_provider_id + && Some(model_info.model.id()) == active_model_id; - let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); - let active_model_id = active_model.map(|m| m.model.id()); + let model_icon_color = if is_selected { + Color::Accent + } else { + Color::Muted + }; - let is_selected = Some(model_info.model.provider_id()) == active_provider_id - && Some(model_info.model.id()) == active_model_id; - - let model_icon_color = if is_selected { - Color::Accent - } else { - Color::Muted - }; - - Some( - ListItem::new(ix) - .inset(true) - .spacing(ListItemSpacing::Sparse) - .toggle_state(selected) - .start_slot( - Icon::new(model_info.icon) - .color(model_icon_color) - .size(IconSize::Small), - ) - .child( - h_flex() - .w_full() - .items_center() - .gap_1p5() - .pl_0p5() - .w(px(240.)) - .child( - div() - .max_w_40() - .child(Label::new(model_info.model.name().0.clone()).truncate()), + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot( + Icon::new(model_info.icon) + .color(model_icon_color) + .size(IconSize::Small), ) .child( h_flex() - 
.gap_0p5() - .child( - Label::new(provider_name) - .size(LabelSize::XSmall) - .color(Color::Muted), - ) - .children(match model_info.availability { - LanguageModelAvailability::Public => None, - LanguageModelAvailability::RequiresPlan(Plan::Free) => None, - LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => { - show_badges.then(|| { - Label::new("Pro") - .size(LabelSize::XSmall) - .color(Color::Muted) - }) - } - }), - ), + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.model.name().0.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })) + .into_any_element(), ) - .end_slot(div().pr_3().when(is_selected, |this| { - this.child( - Icon::new(IconName::Check) - .color(Color::Accent) - .size(IconSize::Small), - ) - })), - ) + } + } } fn render_footer( diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index c1bea29691..6f2e11f493 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -47,6 +47,7 @@ settings.workspace = true smol.workspace = true strum.workspace = true theme.workspace = true +thiserror.workspace = true tiktoken-rs.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } ui.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index bce985a872..7746d214b4 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -13,8 +13,9 @@ use gpui::{ use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role, + 
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent, + RateLimiter, Role, }; use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; use schemars::JsonSchema; @@ -192,6 +193,16 @@ impl AnthropicLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: anthropic::Model) -> Arc { + Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + } } impl LanguageModelProviderState for AnthropicLanguageModelProvider { @@ -226,6 +237,16 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { })) } + fn recommended_models(&self, _cx: &App) -> Vec> { + [ + anthropic::Model::Claude3_7Sonnet, + anthropic::Model::Claude3_7SonnetThinking, + ] + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); @@ -266,15 +287,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { models .into_values() - .map(|model| { - Arc::new(AnthropicModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } @@ -442,7 +455,12 @@ impl LanguageModel for AnthropicModel { ); let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { - let response = request.await.map_err(|err| anyhow!(err))?; + let response = request + .await + .map_err(|err| match err.downcast::() { + Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err), + Err(err) => anyhow!(err), + })?; 
Ok(map_to_language_model_completion_events(response)) }); async move { Ok(future.await?.boxed()) }.boxed() @@ -734,7 +752,7 @@ pub fn map_to_language_model_completion_events( _ => {} }, Err(err) => { - return Some((vec![Err(anyhow!(err))], state)); + return Some((vec![Err(anthropic_err_to_anyhow(err))], state)); } } } @@ -745,6 +763,16 @@ pub fn map_to_language_model_completion_events( .flat_map(futures::stream::iter) } +pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error { + if let AnthropicError::ApiError(api_err) = &err { + if let Some(tokens) = api_err.match_window_exceeded() { + return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens }); + } + } + + anyhow!(err) +} + /// Updates usage data by preferring counts from `new`. fn update_usage(usage: &mut Usage, new: &Usage) { if let Some(input_tokens) = new.input_tokens { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 9377bf315f..38d8c79d35 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,4 +1,4 @@ -use anthropic::{AnthropicError, AnthropicModelMode}; +use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::{ Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, @@ -14,7 +14,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use language_model::{ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, - LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID, }; @@ 
-33,6 +33,7 @@ use std::{ time::Duration, }; use strum::IntoEnumIterator; +use thiserror::Error; use ui::{TintColor, prelude::*}; use crate::AllLanguageModelSettings; @@ -225,6 +226,20 @@ impl CloudLanguageModelProvider { _maintain_client_status: maintain_client_status, } } + + fn create_language_model( + &self, + model: CloudModel, + llm_api_token: LlmApiToken, + ) -> Arc { + Arc::new(CloudLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + llm_api_token: llm_api_token.clone(), + client: self.client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + } } impl LanguageModelProviderState for CloudLanguageModelProvider { @@ -260,6 +275,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider { })) } + fn recommended_models(&self, cx: &App) -> Vec> { + let llm_api_token = self.state.read(cx).llm_api_token.clone(); + [ + CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet), + CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking), + ] + .into_iter() + .map(|model| self.create_language_model(model, llm_api_token.clone())) + .collect() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); @@ -345,15 +371,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { let llm_api_token = self.state.read(cx).llm_api_token.clone(); models .into_values() - .map(|model| { - Arc::new(CloudLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - llm_api_token: llm_api_token.clone(), - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model, llm_api_token.clone())) .collect() } @@ -558,14 +576,19 @@ impl CloudLanguageModel { } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "cloud language model completion failed with status {status}: {body}", - )); + return Err(anyhow!(ApiError { status, body })); } } } } 
+#[derive(Debug, Error)] +#[error("cloud language model completion failed with status {status}: {body}")] +struct ApiError { + status: StatusCode, + body: String, +} + impl LanguageModel for CloudLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -575,10 +598,6 @@ impl LanguageModel for CloudLanguageModel { LanguageModelName::from(self.model.display_name().to_string()) } - fn icon(&self) -> Option { - self.model.icon() - } - fn provider_id(&self) -> LanguageModelProviderId { LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) } @@ -683,7 +702,23 @@ impl LanguageModel for CloudLanguageModel { )?)?, }, ) - .await?; + .await + .map_err(|err| match err.downcast::() { + Ok(api_err) => { + if api_err.status == StatusCode::BAD_REQUEST { + if let Some(tokens) = parse_prompt_too_long(&api_err.body) { + return anyhow!( + LanguageModelKnownError::ContextWindowLimitExceeded { + tokens + } + ); + } + } + anyhow!(api_err) + } + Err(err) => anyhow!(err), + })?; + Ok( crate::provider::anthropic::map_to_language_model_completion_events( Box::pin(response_lines(response).map_err(AnthropicError::Other)), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 827ca3f190..cde252e04a 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -210,17 +210,21 @@ impl LanguageModel for CopilotChatLanguageModel { CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx), CopilotChatModel::Claude3_7Sonnet => count_anthropic_tokens(request, cx), CopilotChatModel::Claude3_7SonnetThinking => count_anthropic_tokens(request, cx), - CopilotChatModel::Gemini20Flash => count_google_tokens(request, cx), + CopilotChatModel::Gemini20Flash | CopilotChatModel::Gemini25Pro => { + count_google_tokens(request, cx) + } _ => { let model = match self.model { CopilotChatModel::Gpt4o => open_ai::Model::FourOmni, CopilotChatModel::Gpt4 => 
open_ai::Model::Four, + CopilotChatModel::Gpt4_1 => open_ai::Model::FourPointOne, CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo, CopilotChatModel::O1 | CopilotChatModel::O3Mini => open_ai::Model::Four, CopilotChatModel::Claude3_5Sonnet | CopilotChatModel::Claude3_7Sonnet | CopilotChatModel::Claude3_7SonnetThinking - | CopilotChatModel::Gemini20Flash => { + | CopilotChatModel::Gemini20Flash + | CopilotChatModel::Gemini25Pro => { unreachable!() } }; diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 4da5f255d2..36a01a30c2 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -417,17 +417,19 @@ pub fn into_google( top_k: None, }), safety_settings: None, - tools: Some(vec![google_ai::Tool { - function_declarations: request - .tools - .into_iter() - .map(|tool| FunctionDeclaration { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }) - .collect(), - }]), + tools: (request.tools.len() > 0).then(|| { + vec![google_ai::Tool { + function_declarations: request + .tools + .into_iter() + .map(|tool| FunctionDeclaration { + name: tool.name, + description: tool.description, + parameters: tool.input_schema, + }) + .collect(), + }] + }), tool_config: None, } } diff --git a/crates/languages/src/python/config.toml b/crates/languages/src/python/config.toml index 836059bf96..6749f39060 100644 --- a/crates/languages/src/python/config.toml +++ b/crates/languages/src/python/config.toml @@ -5,6 +5,18 @@ first_line_pattern = '^#!.*\bpython[0-9.]*\b' line_comments = ["# "] autoclose_before = ";:.,=}])>" brackets = [ + { start = "f\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "f'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "b\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = 
"b'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "u\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "u'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "r\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "r'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "rb\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "rb'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "t\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] }, + { start = "t'", end = "'", close = true, newline = false, not_in = ["string", "comment"] }, { start = "\"\"\"", end = "\"\"\"", close = true, newline = false, not_in = ["string"] }, { start = "'''", end = "'''", close = true, newline = false, not_in = ["string"] }, { start = "{", end = "}", close = true, newline = true }, diff --git a/crates/languages/src/python/runnables.scm b/crates/languages/src/python/runnables.scm index 3b32556707..84a024de94 100644 --- a/crates/languages/src/python/runnables.scm +++ b/crates/languages/src/python/runnables.scm @@ -83,6 +83,26 @@ ) ) +; decorated pytest class methods +( + (module + (class_definition + name: (identifier) @_pytest_class_name + (#match? @_pytest_class_name "^Test") + body: (block + (decorated_definition + (decorator)+ @_decorator + definition: (function_definition + name: (identifier) @run @_pytest_method_name + (#match? @_pytest_method_name "^test_") + ) + ) + ) @_python-pytest-method + (#set! 
tag python-pytest-method) + ) + ) +) + ; module main method ( (module diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index 85dcefc40a..16f4f621e0 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -32,6 +32,17 @@ use crate::parser::CodeBlockKind; /// If the callback returns `None`, the default link style will be used. type LinkStyleCallback = Rc Option>; +/// Defines custom style refinements for each heading level (H1-H6) +#[derive(Clone, Default)] +pub struct HeadingLevelStyles { + pub h1: Option, + pub h2: Option, + pub h3: Option, + pub h4: Option, + pub h5: Option, + pub h6: Option, +} + #[derive(Clone)] pub struct MarkdownStyle { pub base_text_style: TextStyle, @@ -46,6 +57,7 @@ pub struct MarkdownStyle { pub syntax: Arc, pub selection_background_color: Hsla, pub heading: StyleRefinement, + pub heading_level_styles: Option, pub table_overflow_x_scroll: bool, } @@ -64,6 +76,7 @@ impl Default for MarkdownStyle { syntax: Arc::new(SyntaxTheme::default()), selection_background_color: Default::default(), heading: Default::default(), + heading_level_styles: None, table_overflow_x_scroll: false, } } @@ -628,17 +641,19 @@ impl Element for MarkdownElement { } MarkdownTag::Heading { level, .. 
} => { let mut heading = div().mb_2(); - heading = match level { - pulldown_cmark::HeadingLevel::H1 => heading.text_3xl(), - pulldown_cmark::HeadingLevel::H2 => heading.text_2xl(), - pulldown_cmark::HeadingLevel::H3 => heading.text_xl(), - pulldown_cmark::HeadingLevel::H4 => heading.text_lg(), - _ => heading, - }; - heading.style().refine(&self.style.heading); - builder.push_text_style( - self.style.heading.text_style().clone().unwrap_or_default(), + + heading = apply_heading_style( + heading, + *level, + self.style.heading_level_styles.as_ref(), ); + + heading.style().refine(&self.style.heading); + + let text_style = + self.style.heading.text_style().clone().unwrap_or_default(); + + builder.push_text_style(text_style); builder.push_div(heading, range, markdown_end); } MarkdownTag::BlockQuote => { @@ -1043,6 +1058,38 @@ impl Element for MarkdownElement { } } +fn apply_heading_style( + mut heading: Div, + level: pulldown_cmark::HeadingLevel, + custom_styles: Option<&HeadingLevelStyles>, +) -> Div { + heading = match level { + pulldown_cmark::HeadingLevel::H1 => heading.text_3xl(), + pulldown_cmark::HeadingLevel::H2 => heading.text_2xl(), + pulldown_cmark::HeadingLevel::H3 => heading.text_xl(), + pulldown_cmark::HeadingLevel::H4 => heading.text_lg(), + pulldown_cmark::HeadingLevel::H5 => heading.text_base(), + pulldown_cmark::HeadingLevel::H6 => heading.text_sm(), + }; + + if let Some(styles) = custom_styles { + let style_opt = match level { + pulldown_cmark::HeadingLevel::H1 => &styles.h1, + pulldown_cmark::HeadingLevel::H2 => &styles.h2, + pulldown_cmark::HeadingLevel::H3 => &styles.h3, + pulldown_cmark::HeadingLevel::H4 => &styles.h4, + pulldown_cmark::HeadingLevel::H5 => &styles.h5, + pulldown_cmark::HeadingLevel::H6 => &styles.h6, + }; + + if let Some(style) = style_opt { + heading.style().text = Some(style.clone()); + } + } + + heading +} + fn render_copy_code_block_button( id: usize, code: String, diff --git a/crates/multi_buffer/src/multi_buffer.rs 
b/crates/multi_buffer/src/multi_buffer.rs index 18c17b3b02..994815910c 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -7694,7 +7694,12 @@ impl ToOffset for Point { impl ToOffset for usize { #[track_caller] fn to_offset<'a>(&self, snapshot: &MultiBufferSnapshot) -> usize { - assert!(*self <= snapshot.len(), "offset is out of range"); + assert!( + *self <= snapshot.len(), + "offset {} is greater than the snapshot.len() {}", + *self, + snapshot.len(), + ); *self } } diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index b9aa2ce7f0..0aee8f4345 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -71,6 +71,12 @@ pub enum Model { FourOmni, #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")] FourOmniMini, + #[serde(rename = "gpt-4.1", alias = "gpt-4.1")] + FourPointOne, + #[serde(rename = "gpt-4.1-mini", alias = "gpt-4.1-mini")] + FourPointOneMini, + #[serde(rename = "gpt-4.1-nano", alias = "gpt-4.1-nano")] + FourPointOneNano, #[serde(rename = "o1", alias = "o1")] O1, #[serde(rename = "o1-preview", alias = "o1-preview")] @@ -99,6 +105,9 @@ impl Model { "gpt-4-turbo-preview" => Ok(Self::FourTurbo), "gpt-4o" => Ok(Self::FourOmni), "gpt-4o-mini" => Ok(Self::FourOmniMini), + "gpt-4.1" => Ok(Self::FourPointOne), + "gpt-4.1-mini" => Ok(Self::FourPointOneMini), + "gpt-4.1-nano" => Ok(Self::FourPointOneNano), "o1" => Ok(Self::O1), "o1-preview" => Ok(Self::O1Preview), "o1-mini" => Ok(Self::O1Mini), @@ -114,6 +123,9 @@ impl Model { Self::FourTurbo => "gpt-4-turbo", Self::FourOmni => "gpt-4o", Self::FourOmniMini => "gpt-4o-mini", + Self::FourPointOne => "gpt-4.1", + Self::FourPointOneMini => "gpt-4.1-mini", + Self::FourPointOneNano => "gpt-4.1-nano", Self::O1 => "o1", Self::O1Preview => "o1-preview", Self::O1Mini => "o1-mini", @@ -129,6 +141,9 @@ impl Model { Self::FourTurbo => "gpt-4-turbo", Self::FourOmni => "gpt-4o", Self::FourOmniMini => "gpt-4o-mini", + 
Self::FourPointOne => "gpt-4.1", + Self::FourPointOneMini => "gpt-4.1-mini", + Self::FourPointOneNano => "gpt-4.1-nano", Self::O1 => "o1", Self::O1Preview => "o1-preview", Self::O1Mini => "o1-mini", @@ -146,6 +161,9 @@ impl Model { Self::FourTurbo => 128_000, Self::FourOmni => 128_000, Self::FourOmniMini => 128_000, + Self::FourPointOne => 1_047_576, + Self::FourPointOneMini => 1_047_576, + Self::FourPointOneNano => 1_047_576, Self::O1 => 200_000, Self::O1Preview => 128_000, Self::O1Mini => 128_000, @@ -173,6 +191,9 @@ impl Model { | Self::FourTurbo | Self::FourOmni | Self::FourOmniMini + | Self::FourPointOne + | Self::FourPointOneMini + | Self::FourPointOneNano | Self::O1 | Self::O1Preview | Self::O1Mini => true, diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 622e4a67f3..9f7644ee9b 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -375,6 +375,11 @@ pub fn local_settings_folder_relative_path() -> &'static Path { Path::new(".zed") } +/// Returns the relative path to a `.vscode` folder within a project. +pub fn local_vscode_folder_relative_path() -> &'static Path { + Path::new(".vscode") +} + /// Returns the relative path to a `settings.json` file within a project. 
pub fn local_settings_file_relative_path() -> &'static Path { Path::new(".zed/settings.json") diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 2caa9ff756..54b50453ce 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -3,8 +3,8 @@ use editor::{Editor, scroll::Autoscroll}; use gpui::{ AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render, - ScrollHandle, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, - impl_actions, list, prelude::*, uniform_list, + ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, impl_actions, + list, prelude::*, uniform_list, }; use head::Head; use schemars::JsonSchema; @@ -24,6 +24,11 @@ enum ElementContainer { UniformList(UniformListScrollHandle), } +pub enum Direction { + Up, + Down, +} + actions!(picker, [ConfirmCompletion]); /// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally, @@ -86,6 +91,15 @@ pub trait PickerDelegate: Sized + 'static { window: &mut Window, cx: &mut Context>, ); + fn can_select( + &mut self, + _ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + true + } + // Allows binding some optional effect to when the selection changes. fn selected_index_changed( &self, @@ -271,10 +285,7 @@ impl Picker { ElementContainer::UniformList(scroll_handle) => { ScrollbarState::new(scroll_handle.clone()) } - ElementContainer::List(_) => { - // todo smit: implement for list - ScrollbarState::new(ScrollHandle::new()) - } + ElementContainer::List(state) => ScrollbarState::new(state.clone()), }; let focus_handle = cx.focus_handle(); let mut this = Self { @@ -359,16 +370,58 @@ impl Picker { } /// Handles the selecting an index, and passing the change to the delegate. 
- /// If `scroll_to_index` is true, the new selected index will be scrolled into view. + /// If `fallback_direction` is set to `None`, the index will not be selected + /// if the element at that index cannot be selected. + /// If `fallback_direction` is set to + /// `Some(..)`, the next selectable element will be selected in the + /// specified direction (Down or Up), cycling through all elements until + /// finding one that can be selected or returning if there are no selectable elements. + /// If `scroll_to_index` is true, the new selected index will be scrolled into + /// view. /// /// If some effect is bound to `selected_index_changed`, it will be executed. pub fn set_selected_index( &mut self, - ix: usize, + mut ix: usize, + fallback_direction: Option, scroll_to_index: bool, window: &mut Window, cx: &mut Context, ) { + let match_count = self.delegate.match_count(); + if match_count == 0 { + return; + } + + if let Some(bias) = fallback_direction { + let mut curr_ix = ix; + while !self.delegate.can_select(curr_ix, window, cx) { + curr_ix = match bias { + Direction::Down => { + if curr_ix == match_count - 1 { + 0 + } else { + curr_ix + 1 + } + } + Direction::Up => { + if curr_ix == 0 { + match_count - 1 + } else { + curr_ix - 1 + } + } + }; + // There is no item that can be selected + if ix == curr_ix { + return; + } + } + ix = curr_ix; + } else if !self.delegate.can_select(ix, window, cx) { + return; + } + let previous_index = self.delegate.selected_index(); self.delegate.set_selected_index(ix, window, cx); let current_index = self.delegate.selected_index(); @@ -393,7 +446,7 @@ impl Picker { if count > 0 { let index = self.delegate.selected_index(); let ix = if index == count - 1 { 0 } else { index + 1 }; - self.set_selected_index(ix, true, window, cx); + self.set_selected_index(ix, Some(Direction::Down), true, window, cx); cx.notify(); } } @@ -408,7 +461,7 @@ impl Picker { if count > 0 { let index = self.delegate.selected_index(); let ix = if index == 0 { count 
- 1 } else { index - 1 }; - self.set_selected_index(ix, true, window, cx); + self.set_selected_index(ix, Some(Direction::Up), true, window, cx); cx.notify(); } } @@ -416,7 +469,7 @@ impl Picker { fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context) { let count = self.delegate.match_count(); if count > 0 { - self.set_selected_index(0, true, window, cx); + self.set_selected_index(0, Some(Direction::Down), true, window, cx); cx.notify(); } } @@ -424,7 +477,7 @@ impl Picker { fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context) { let count = self.delegate.match_count(); if count > 0 { - self.set_selected_index(count - 1, true, window, cx); + self.set_selected_index(count - 1, Some(Direction::Up), true, window, cx); cx.notify(); } } @@ -433,7 +486,7 @@ impl Picker { let count = self.delegate.match_count(); let index = self.delegate.selected_index(); let new_index = if index + 1 == count { 0 } else { index + 1 }; - self.set_selected_index(new_index, true, window, cx); + self.set_selected_index(new_index, Some(Direction::Down), true, window, cx); cx.notify(); } @@ -506,14 +559,14 @@ impl Picker { ) { cx.stop_propagation(); window.prevent_default(); - self.set_selected_index(ix, false, window, cx); + self.set_selected_index(ix, None, false, window, cx); self.do_confirm(secondary, window, cx) } fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context) { if let Some(update_query) = self.delegate.confirm_update_query(window, cx) { self.set_query(update_query, window, cx); - self.delegate.set_selected_index(0, window, cx); + self.set_selected_index(0, Some(Direction::Down), false, window, cx); } else { self.delegate.confirm(secondary, window, cx) } diff --git a/crates/project/src/buffer_store.rs b/crates/project/src/buffer_store.rs index c31f1bfbe5..815ba19ea9 100644 --- a/crates/project/src/buffer_store.rs +++ b/crates/project/src/buffer_store.rs @@ -36,6 +36,7 @@ pub struct 
BufferStore { loading_buffers: HashMap, Arc>>>>, worktree_store: Entity, opened_buffers: HashMap, + path_to_buffer_id: HashMap, downstream_client: Option<(AnyProtoClient, u64)>, shared_buffers: HashMap>, } @@ -62,7 +63,6 @@ struct RemoteBufferStore { } struct LocalBufferStore { - local_buffer_ids_by_path: HashMap, local_buffer_ids_by_entry_id: HashMap, worktree_store: Entity, _subscription: Subscription, @@ -368,8 +368,9 @@ impl LocalBufferStore { let line_ending = buffer.line_ending(); let version = buffer.version(); let buffer_id = buffer.remote_id(); - if buffer - .file() + let file = buffer.file().cloned(); + if file + .as_ref() .is_some_and(|file| file.disk_state() == DiskState::New) { has_changed_file = true; @@ -462,13 +463,11 @@ impl LocalBufferStore { path: path.clone(), }; - let buffer_id = { - let local = this.as_local_mut()?; - match local.local_buffer_ids_by_entry_id.get(&entry_id) { - Some(&buffer_id) => buffer_id, - None => local.local_buffer_ids_by_path.get(&project_path).copied()?, - } - }; + let buffer_id = this + .as_local_mut() + .and_then(|local| local.local_buffer_ids_by_entry_id.get(&entry_id)) + .copied() + .or_else(|| this.path_to_buffer_id.get(&project_path).copied())?; let buffer = if let Some(buffer) = this.get(buffer_id) { Some(buffer) @@ -480,14 +479,13 @@ impl LocalBufferStore { let buffer = if let Some(buffer) = buffer { buffer } else { + this.path_to_buffer_id.remove(&project_path); let this = this.as_local_mut()?; - this.local_buffer_ids_by_path.remove(&project_path); this.local_buffer_ids_by_entry_id.remove(&entry_id); return None; }; let events = buffer.update(cx, |buffer, cx| { - let local = this.as_local_mut()?; let file = buffer.file()?; let old_file = File::from_dyn(Some(file))?; if old_file.worktree != *worktree { @@ -528,11 +526,11 @@ impl LocalBufferStore { let mut events = Vec::new(); if new_file.path != old_file.path { - local.local_buffer_ids_by_path.remove(&ProjectPath { + this.path_to_buffer_id.remove(&ProjectPath { 
path: old_file.path.clone(), worktree_id: old_file.worktree_id(cx), }); - local.local_buffer_ids_by_path.insert( + this.path_to_buffer_id.insert( ProjectPath { worktree_id: new_file.worktree_id(cx), path: new_file.path.clone(), @@ -544,7 +542,7 @@ impl LocalBufferStore { old_file: buffer.file().cloned(), }); } - + let local = this.as_local_mut()?; if new_file.entry_id != old_file.entry_id { if let Some(entry_id) = old_file.entry_id { local.local_buffer_ids_by_entry_id.remove(&entry_id); @@ -577,32 +575,6 @@ impl LocalBufferStore { None } - fn buffer_changed_file(&mut self, buffer: Entity, cx: &mut App) -> Option<()> { - let file = File::from_dyn(buffer.read(cx).file())?; - - let remote_id = buffer.read(cx).remote_id(); - if let Some(entry_id) = file.entry_id { - match self.local_buffer_ids_by_entry_id.get(&entry_id) { - Some(_) => { - return None; - } - None => { - self.local_buffer_ids_by_entry_id - .insert(entry_id, remote_id); - } - } - }; - self.local_buffer_ids_by_path.insert( - ProjectPath { - worktree_id: file.worktree_id(cx), - path: file.path.clone(), - }, - remote_id, - ); - - Some(()) - } - fn save_buffer( &self, buffer: Entity, @@ -677,15 +649,14 @@ impl LocalBufferStore { this.add_buffer(buffer.clone(), cx)?; let buffer_id = buffer.read(cx).remote_id(); if let Some(file) = File::from_dyn(buffer.read(cx).file()) { - let this = this.as_local_mut().unwrap(); - this.local_buffer_ids_by_path.insert( + this.path_to_buffer_id.insert( ProjectPath { worktree_id: file.worktree_id(cx), path: file.path.clone(), }, buffer_id, ); - + let this = this.as_local_mut().unwrap(); if let Some(entry_id) = file.entry_id { this.local_buffer_ids_by_entry_id .insert(entry_id, buffer_id); @@ -748,7 +719,6 @@ impl BufferStore { pub fn local(worktree_store: Entity, cx: &mut Context) -> Self { Self { state: BufferStoreState::Local(LocalBufferStore { - local_buffer_ids_by_path: Default::default(), local_buffer_ids_by_entry_id: Default::default(), worktree_store: 
worktree_store.clone(), _subscription: cx.subscribe(&worktree_store, |this, _, event, cx| { @@ -760,6 +730,7 @@ impl BufferStore { }), downstream_client: None, opened_buffers: Default::default(), + path_to_buffer_id: Default::default(), shared_buffers: Default::default(), loading_buffers: Default::default(), worktree_store, @@ -783,19 +754,13 @@ impl BufferStore { }), downstream_client: None, opened_buffers: Default::default(), + path_to_buffer_id: Default::default(), loading_buffers: Default::default(), shared_buffers: Default::default(), worktree_store, } } - fn as_local(&self) -> Option<&LocalBufferStore> { - match &self.state { - BufferStoreState::Local(state) => Some(state), - _ => None, - } - } - fn as_local_mut(&mut self) -> Option<&mut LocalBufferStore> { match &mut self.state { BufferStoreState::Local(state) => Some(state), @@ -915,6 +880,10 @@ impl BufferStore { fn add_buffer(&mut self, buffer_entity: Entity, cx: &mut Context) -> Result<()> { let buffer = buffer_entity.read(cx); let remote_id = buffer.remote_id(); + let path = File::from_dyn(buffer.file()).map(|file| ProjectPath { + path: file.path.clone(), + worktree_id: file.worktree_id(cx), + }); let is_remote = buffer.replica_id() != 0; let open_buffer = OpenBuffer::Complete { buffer: buffer_entity.downgrade(), @@ -931,10 +900,11 @@ impl BufferStore { }) .detach() }); - + let _expect_path_to_exist; match self.opened_buffers.entry(remote_id) { hash_map::Entry::Vacant(entry) => { entry.insert(open_buffer); + _expect_path_to_exist = false; } hash_map::Entry::Occupied(mut entry) => { if let OpenBuffer::Operations(operations) = entry.get_mut() { @@ -948,9 +918,14 @@ impl BufferStore { } } entry.insert(open_buffer); + _expect_path_to_exist = true; } } + if let Some(path) = path { + self.path_to_buffer_id.insert(path, remote_id); + } + cx.subscribe(&buffer_entity, Self::on_buffer_event).detach(); cx.emit(BufferStoreEvent::BufferAdded(buffer_entity)); Ok(()) @@ -972,18 +947,13 @@ impl BufferStore { } pub fn 
buffer_id_for_project_path(&self, project_path: &ProjectPath) -> Option<&BufferId> { - self.as_local() - .and_then(|state| state.local_buffer_ids_by_path.get(project_path)) + self.path_to_buffer_id.get(project_path) } - pub fn get_by_path(&self, path: &ProjectPath, cx: &App) -> Option> { - self.buffers().find_map(|buffer| { - let file = File::from_dyn(buffer.read(cx).file())?; - if file.worktree_id(cx) == path.worktree_id && file.path == path.path { - Some(buffer) - } else { - None - } + pub fn get_by_path(&self, path: &ProjectPath, _cx: &App) -> Option> { + self.path_to_buffer_id.get(path).and_then(|buffer_id| { + let buffer = self.get(*buffer_id); + buffer }) } @@ -1055,6 +1025,35 @@ impl BufferStore { .retain(|_, buffer| !matches!(buffer, OpenBuffer::Operations(_))); } + fn buffer_changed_file(&mut self, buffer: Entity, cx: &mut App) -> Option<()> { + let file = File::from_dyn(buffer.read(cx).file())?; + + let remote_id = buffer.read(cx).remote_id(); + if let Some(entry_id) = file.entry_id { + if let Some(local) = self.as_local_mut() { + match local.local_buffer_ids_by_entry_id.get(&entry_id) { + Some(_) => { + return None; + } + None => { + local + .local_buffer_ids_by_entry_id + .insert(entry_id, remote_id); + } + } + } + self.path_to_buffer_id.insert( + ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path.clone(), + }, + remote_id, + ); + }; + + Some(()) + } + pub fn find_search_candidates( &mut self, query: &SearchQuery, @@ -1118,9 +1117,7 @@ impl BufferStore { ) { match event { BufferEvent::FileHandleChanged => { - if let Some(local) = self.as_local_mut() { - local.buffer_changed_file(buffer, cx); - } + self.buffer_changed_file(buffer, cx); } BufferEvent::Reloaded => { let Some((downstream_client, project_id)) = self.downstream_client.as_ref() else { @@ -1316,6 +1313,7 @@ impl BufferStore { let old_file = buffer.update(cx, |buffer, cx| { let old_file = buffer.file().cloned(); let new_path = file.path.clone(); + 
buffer.file_updated(Arc::new(file), cx); if old_file .as_ref() @@ -1606,18 +1604,17 @@ impl BufferStore { self.add_buffer(buffer.clone(), cx).log_err(); let buffer_id = buffer.read(cx).remote_id(); - let this = self - .as_local_mut() - .expect("local-only method called in a non-local context"); if let Some(file) = File::from_dyn(buffer.read(cx).file()) { - this.local_buffer_ids_by_path.insert( + self.path_to_buffer_id.insert( ProjectPath { worktree_id: file.worktree_id(cx), path: file.path.clone(), }, buffer_id, ); - + let this = self + .as_local_mut() + .expect("local-only method called in a non-local context"); if let Some(entry_id) = file.entry_id { this.local_buffer_ids_by_entry_id .insert(entry_id, buffer_id); diff --git a/crates/project/src/debugger/session.rs b/crates/project/src/debugger/session.rs index ff861771ef..401494ddcc 100644 --- a/crates/project/src/debugger/session.rs +++ b/crates/project/src/debugger/session.rs @@ -30,7 +30,8 @@ use dap::{ use futures::channel::oneshot; use futures::{FutureExt, future::Shared}; use gpui::{ - App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, Task, WeakEntity, + App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, SharedString, + Task, WeakEntity, }; use rpc::AnyProtoClient; use serde_json::{Value, json}; @@ -125,6 +126,7 @@ type UpstreamProjectId = u64; struct RemoteConnection { _client: AnyProtoClient, _upstream_project_id: UpstreamProjectId, + _adapter_name: SharedString, } impl RemoteConnection { @@ -996,6 +998,7 @@ impl Session { ) -> Self { Self { mode: Mode::Remote(RemoteConnection { + _adapter_name: SharedString::new(""), // todo(debugger) we need to pipe in the right values to deserialize the debugger pane layout _client: client, _upstream_project_id: upstream_project_id, }), @@ -1044,6 +1047,13 @@ impl Session { &self.capabilities } + pub fn adapter_name(&self) -> SharedString { + match &self.mode { + Mode::Local(local_mode) => 
local_mode.adapter.name().into(), + Mode::Remote(remote_mode) => remote_mode._adapter_name.clone(), + } + } + pub fn configuration(&self) -> Option { if let Mode::Local(local_mode) = &self.mode { Some(local_mode.config.clone()) diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index a2386ad34c..024b347d19 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -21,7 +21,7 @@ use git::{ blame::Blame, parse_git_remote_url, repository::{ - Branch, CommitDetails, CommitDiff, CommitFile, DiffType, GitRepository, + Branch, CommitDetails, CommitDiff, CommitFile, CommitOptions, DiffType, GitRepository, GitRepositoryCheckpoint, PushOptions, Remote, RemoteCommandOutput, RepoPath, ResetMode, UpstreamTrackingStatus, }, @@ -61,7 +61,8 @@ use sum_tree::{Edit, SumTree, TreeSet}; use text::{Bias, BufferId}; use util::{ResultExt, debug_panic, post_inc}; use worktree::{ - File, PathKey, PathProgress, PathSummary, PathTarget, UpdatedGitRepositoriesSet, Worktree, + File, PathKey, PathProgress, PathSummary, PathTarget, UpdatedGitRepositoriesSet, + UpdatedGitRepository, Worktree, }; pub struct GitStore { @@ -1144,18 +1145,23 @@ impl GitStore { } else { removed_ids.push(*id); } - } else if let Some((work_directory_abs_path, dot_git_abs_path)) = update - .new_work_directory_abs_path - .clone() - .zip(update.dot_git_abs_path.clone()) + } else if let UpdatedGitRepository { + new_work_directory_abs_path: Some(work_directory_abs_path), + dot_git_abs_path: Some(dot_git_abs_path), + repository_dir_abs_path: Some(repository_dir_abs_path), + common_dir_abs_path: Some(common_dir_abs_path), + .. 
+ } = update { let id = RepositoryId(next_repository_id.fetch_add(1, atomic::Ordering::Release)); let git_store = cx.weak_entity(); let repo = cx.new(|cx| { let mut repo = Repository::local( id, - work_directory_abs_path, - dot_git_abs_path, + work_directory_abs_path.clone(), + dot_git_abs_path.clone(), + repository_dir_abs_path.clone(), + common_dir_abs_path.clone(), project_environment.downgrade(), fs.clone(), git_store, @@ -1650,10 +1656,18 @@ impl GitStore { let message = SharedString::from(envelope.payload.message); let name = envelope.payload.name.map(SharedString::from); let email = envelope.payload.email.map(SharedString::from); + let options = envelope.payload.options.unwrap_or_default(); repository_handle .update(&mut cx, |repository_handle, cx| { - repository_handle.commit(message, name.zip(email), cx) + repository_handle.commit( + message, + name.zip(email), + CommitOptions { + amend: options.amend, + }, + cx, + ) })? .await??; Ok(proto::Ack {}) @@ -2542,6 +2556,8 @@ impl Repository { id: RepositoryId, work_directory_abs_path: Arc, dot_git_abs_path: Arc, + repository_dir_abs_path: Arc, + common_dir_abs_path: Arc, project_environment: WeakEntity, fs: Arc, git_store: WeakEntity, @@ -2559,6 +2575,8 @@ impl Repository { job_sender: Repository::spawn_local_git_worker( work_directory_abs_path, dot_git_abs_path, + repository_dir_abs_path, + common_dir_abs_path, project_environment, fs, cx, @@ -3238,6 +3256,7 @@ impl Repository { &mut self, message: SharedString, name_and_email: Option<(SharedString, SharedString)>, + options: CommitOptions, _cx: &mut App, ) -> oneshot::Receiver> { let id = self.id; @@ -3248,7 +3267,11 @@ impl Repository { backend, environment, .. 
- } => backend.commit(message, name_and_email, environment).await, + } => { + backend + .commit(message, name_and_email, options, environment) + .await + } RepositoryState::Remote { project_id, client } => { let (name, email) = name_and_email.unzip(); client @@ -3258,6 +3281,9 @@ impl Repository { message: String::from(message), name: name.map(String::from), email: email.map(String::from), + options: Some(proto::commit::CommitOptions { + amend: options.amend, + }), }) .await .context("sending commit request")?; @@ -3796,12 +3822,6 @@ impl Repository { updates_tx: Option>, cx: &mut Context, ) { - self.paths_changed( - vec![git::repository::WORK_DIRECTORY_REPO_PATH.clone()], - updates_tx.clone(), - cx, - ); - let this = cx.weak_entity(); let _ = self.send_keyed_job( Some(GitJobKey::ReloadGitState), @@ -3842,6 +3862,8 @@ impl Repository { fn spawn_local_git_worker( work_directory_abs_path: Arc, dot_git_abs_path: Arc, + _repository_dir_abs_path: Arc, + _common_dir_abs_path: Arc, project_environment: WeakEntity, fs: Arc, cx: &mut Context, @@ -4054,7 +4076,7 @@ impl Repository { for (repo_path, status) in &*statuses.entries { changed_paths.remove(repo_path); if cursor.seek_forward(&PathTarget::Path(repo_path), Bias::Left, &()) { - if &cursor.item().unwrap().status == status { + if cursor.item().is_some_and(|entry| entry.status == *status) { continue; } } @@ -4098,6 +4120,10 @@ impl Repository { pub fn current_job(&self) -> Option { self.active_jobs.values().next().cloned() } + + pub fn barrier(&mut self) -> oneshot::Receiver<()> { + self.send_job(None, |_, _| async {}) + } } fn get_permalink_in_rust_registry_src( diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 7ce0d92379..d27ffc070c 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -48,6 +48,8 @@ use debugger::{ session::Session, }; pub use environment::ProjectEnvironment; +#[cfg(test)] +use futures::future::join_all; use futures::{ StreamExt, 
channel::mpsc::{self, UnboundedReceiver}, @@ -4808,6 +4810,30 @@ impl Project { &self.git_store } + #[cfg(test)] + fn git_scans_complete(&self, cx: &Context) -> Task<()> { + cx.spawn(async move |this, cx| { + let scans_complete = this + .read_with(cx, |this, cx| { + this.worktrees(cx) + .filter_map(|worktree| Some(worktree.read(cx).as_local()?.scan_complete())) + .collect::>() + }) + .unwrap(); + join_all(scans_complete).await; + let barriers = this + .update(cx, |this, cx| { + let repos = this.repositories(cx).values().cloned().collect::>(); + repos + .into_iter() + .map(|repo| repo.update(cx, |repo, _| repo.barrier())) + .collect::>() + }) + .unwrap(); + join_all(barriers).await; + }) + } + pub fn active_repository(&self, cx: &App) -> Option> { self.git_store.read(cx).active_repository() } diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index 93537e4e46..d981696a08 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -5425,6 +5425,87 @@ async fn test_search_in_gitignored_dirs(cx: &mut gpui::TestAppContext) { ); } +#[gpui::test] +async fn test_search_with_unicode(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/dir"), + json!({ + "one.rs": "// ПРИВЕТ? привет!", + "two.rs": "// ПРИВЕТ.", + "three.rs": "// привет", + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + + let unicode_case_sensitive_query = SearchQuery::text( + "привет", + false, + true, + false, + Default::default(), + Default::default(), + None, + ); + assert_matches!(unicode_case_sensitive_query, Ok(SearchQuery::Text { .. 
})); + assert_eq!( + search(&project, unicode_case_sensitive_query.unwrap(), cx) + .await + .unwrap(), + HashMap::from_iter([ + (separator!("dir/one.rs").to_string(), vec![17..29]), + (separator!("dir/three.rs").to_string(), vec![3..15]), + ]) + ); + + let unicode_case_insensitive_query = SearchQuery::text( + "привет", + false, + false, + false, + Default::default(), + Default::default(), + None, + ); + assert_matches!( + unicode_case_insensitive_query, + Ok(SearchQuery::Regex { .. }) + ); + assert_eq!( + search(&project, unicode_case_insensitive_query.unwrap(), cx) + .await + .unwrap(), + HashMap::from_iter([ + (separator!("dir/one.rs").to_string(), vec![3..15, 17..29]), + (separator!("dir/two.rs").to_string(), vec![3..15]), + (separator!("dir/three.rs").to_string(), vec![3..15]), + ]) + ); + + assert_eq!( + search( + &project, + SearchQuery::text( + "привет.", + false, + false, + false, + Default::default(), + Default::default(), + None, + ) + .unwrap(), + cx + ) + .await + .unwrap(), + HashMap::from_iter([(separator!("dir/two.rs").to_string(), vec![3..16]),]) + ); +} + #[gpui::test] async fn test_create_entry(cx: &mut gpui::TestAppContext) { init_test(cx); @@ -7192,7 +7273,7 @@ async fn test_repository_and_path_for_project_path( let tree_id = tree.read_with(cx, |tree, _| tree.id()); tree.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete()) .await; - tree.flush_fs_events(cx).await; + cx.run_until_parked(); project.read_with(cx, |project, cx| { let git_store = project.git_store().read(cx); @@ -7233,7 +7314,7 @@ async fn test_repository_and_path_for_project_path( fs.remove_dir(path!("/root/dir1/.git").as_ref(), RemoveOptions::default()) .await .unwrap(); - tree.flush_fs_events(cx).await; + cx.run_until_parked(); project.read_with(cx, |project, cx| { let git_store = project.git_store().read(cx); @@ -7493,49 +7574,51 @@ async fn test_git_status_postprocessing(cx: &mut gpui::TestAppContext) { } #[gpui::test] -async fn 
test_repository_subfolder_git_status(cx: &mut gpui::TestAppContext) { +async fn test_repository_subfolder_git_status( + executor: gpui::BackgroundExecutor, + cx: &mut gpui::TestAppContext, +) { init_test(cx); - cx.executor().allow_parking(); - let root = TempTree::new(json!({ - "my-repo": { - // .git folder will go here - "a.txt": "a", - "sub-folder-1": { - "sub-folder-2": { - "c.txt": "cc", - "d": { - "e.txt": "eee" - } - }, - } - }, - })); + let fs = FakeFs::new(executor); + fs.insert_tree( + path!("/root"), + json!({ + "my-repo": { + ".git": {}, + "a.txt": "a", + "sub-folder-1": { + "sub-folder-2": { + "c.txt": "cc", + "d": { + "e.txt": "eee" + } + }, + } + }, + }), + ) + .await; const C_TXT: &str = "sub-folder-1/sub-folder-2/c.txt"; const E_TXT: &str = "sub-folder-1/sub-folder-2/d/e.txt"; - // Set up git repository before creating the worktree. - let git_repo_work_dir = root.path().join("my-repo"); - let repo = git_init(git_repo_work_dir.as_path()); - git_add(C_TXT, &repo); - git_commit("Initial commit", &repo); - - // Open the worktree in subfolder - let project_root = Path::new("my-repo/sub-folder-1/sub-folder-2"); + fs.set_status_for_repo( + path!("/root/my-repo/.git").as_ref(), + &[(E_TXT.as_ref(), FileStatus::Untracked)], + ); let project = Project::test( - Arc::new(RealFs::new(None, cx.executor())), - [root.path().join(project_root).as_path()], + fs.clone(), + [path!("/root/my-repo/sub-folder-1/sub-folder-2").as_ref()], cx, ) .await; - let tree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap()); - tree.flush_fs_events(cx).await; - cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete()) + project + .update(cx, |project, cx| project.git_scans_complete(cx)) .await; - cx.executor().run_until_parked(); + cx.run_until_parked(); let repository = project.read_with(cx, |project, cx| { project.repositories(cx).values().next().unwrap().clone() @@ -7544,8 +7627,8 @@ async fn test_repository_subfolder_git_status(cx: &mut 
gpui::TestAppContext) { // Ensure that the git status is loaded correctly repository.read_with(cx, |repository, _cx| { assert_eq!( - repository.work_directory_abs_path.canonicalize().unwrap(), - root.path().join("my-repo").canonicalize().unwrap() + repository.work_directory_abs_path, + Path::new(path!("/root/my-repo")).into() ); assert_eq!(repository.status_for_path(&C_TXT.into()), None); @@ -7555,13 +7638,11 @@ async fn test_repository_subfolder_git_status(cx: &mut gpui::TestAppContext) { ); }); - // Now we simulate FS events, but ONLY in the .git folder that's outside - // of out project root. - // Meaning: we don't produce any FS events for files inside the project. - git_add(E_TXT, &repo); - git_commit("Second commit", &repo); - tree.flush_fs_events_in_root_git_repository(cx).await; - cx.executor().run_until_parked(); + fs.set_status_for_repo(path!("/root/my-repo/.git").as_ref(), &[]); + project + .update(cx, |project, cx| project.git_scans_complete(cx)) + .await; + cx.run_until_parked(); repository.read_with(cx, |repository, _cx| { assert_eq!(repository.status_for_path(&C_TXT.into()), None); @@ -8182,6 +8263,104 @@ async fn test_rescan_with_gitignore(cx: &mut gpui::TestAppContext) { }); } +#[gpui::test] +async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + ".git": { + "worktrees": { + "some-worktree": {} + }, + }, + "src": { + "a.txt": "A", + }, + "some-worktree": { + ".git": "gitdir: ../.git/worktrees/some-worktree", + "src": { + "b.txt": "B", + } + } + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let scan_complete = project.update(cx, |project, cx| { + project + .worktrees(cx) + .next() + .unwrap() + .read(cx) + .as_local() + .unwrap() + .scan_complete() + }); + scan_complete.await; + + let mut repositories = project.update(cx, |project, cx| { + project + 
.repositories(cx) + .values() + .map(|repo| repo.read(cx).work_directory_abs_path.clone()) + .collect::>() + }); + repositories.sort(); + pretty_assertions::assert_eq!( + repositories, + [ + Path::new(path!("/project")).into(), + Path::new(path!("/project/some-worktree")).into(), + ] + ); + + fs.with_git_state( + path!("/project/some-worktree/.git").as_ref(), + true, + |state| { + state + .head_contents + .insert("src/b.txt".into(), "b".to_owned()); + state + .index_contents + .insert("src/b.txt".into(), "b".to_owned()); + }, + ) + .unwrap(); + cx.run_until_parked(); + + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/some-worktree/src/b.txt"), cx) + }) + .await + .unwrap(); + let (worktree_repo, barrier) = project.update(cx, |project, cx| { + let (repo, _) = project + .git_store() + .read(cx) + .repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx) + .unwrap(); + pretty_assertions::assert_eq!( + repo.read(cx).work_directory_abs_path, + Path::new(path!("/project/some-worktree")).into(), + ); + let barrier = repo.update(cx, |repo, _| repo.barrier()); + (repo.clone(), barrier) + }); + barrier.await.unwrap(); + worktree_repo.update(cx, |repo, _| { + pretty_assertions::assert_eq!( + repo.status_for_path(&"src/b.txt".into()).unwrap().status, + StatusCode::Modified.worktree(), + ); + }); +} + #[gpui::test] async fn test_repository_deduplication(cx: &mut gpui::TestAppContext) { init_test(cx); diff --git a/crates/project/src/search.rs b/crates/project/src/search.rs index 06745c82f4..d23bb9a9b8 100644 --- a/crates/project/src/search.rs +++ b/crates/project/src/search.rs @@ -93,6 +93,21 @@ impl SearchQuery { buffers: Option>>, ) -> Result { let query = query.to_string(); + if !case_sensitive && !query.is_ascii() { + // AhoCorasickBuilder doesn't support case-insensitive search with unicode characters + // Fallback to regex search as recommended by + // 
https://docs.rs/aho-corasick/1.1/aho_corasick/struct.AhoCorasickBuilder.html#method.ascii_case_insensitive + return Self::regex( + regex::escape(&query), + whole_word, + case_sensitive, + include_ignored, + false, + files_to_include, + files_to_exclude, + buffers, + ); + } let search = AhoCorasickBuilder::new() .ascii_case_insensitive(!case_sensitive) .build([&query])?; diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index e35e5d25c5..990b446dcb 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -2070,6 +2070,20 @@ async fn test_select_git_entry(cx: &mut gpui::TestAppContext) { cx, ) .await; + + let (scan1_complete, scan2_complete) = project.update(cx, |project, cx| { + let mut worktrees = project.worktrees(cx); + let worktree1 = worktrees.next().unwrap(); + let worktree2 = worktrees.next().unwrap(); + ( + worktree1.read(cx).as_local().unwrap().scan_complete(), + worktree2.read(cx).as_local().unwrap().scan_complete(), + ) + }); + scan1_complete.await; + scan2_complete.await; + cx.run_until_parked(); + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); let cx = &mut VisualTestContext::from_window(*workspace, cx); let panel = workspace.update(cx, ProjectPanel::new).unwrap(); diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs index c2c1f3da60..7fff6d1258 100644 --- a/crates/prompt_library/src/prompt_library.rs +++ b/crates/prompt_library/src/prompt_library.rs @@ -657,7 +657,7 @@ impl PromptLibrary { .iter() .position(|mat| mat.id == prompt_id) { - picker.set_selected_index(ix, true, window, cx); + picker.set_selected_index(ix, None, true, window, cx); } } } else { diff --git a/crates/proto/proto/git.proto b/crates/proto/proto/git.proto index 7774a5293b..0d94bcb469 100644 --- a/crates/proto/proto/git.proto +++ b/crates/proto/proto/git.proto @@ 
-292,6 +292,11 @@ message Commit { optional string name = 4; optional string email = 5; string message = 6; + optional CommitOptions options = 7; + + message CommitOptions { + bool amend = 1; + } } message OpenCommitMessageBuffer { diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 77feba9f2c..3d65bcac02 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -24,8 +24,8 @@ use std::{ use ui::{KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*, tooltip_container}; use util::{ResultExt, paths::PathExt}; use workspace::{ - CloseIntent, ModalView, OpenOptions, SerializedWorkspaceLocation, WORKSPACE_DB, Workspace, - WorkspaceId, + CloseIntent, HistoryManager, ModalView, OpenOptions, SerializedWorkspaceLocation, WORKSPACE_DB, + Workspace, WorkspaceId, }; use zed_actions::{OpenRecent, OpenRemote}; @@ -553,7 +553,13 @@ impl RecentProjectsDelegate { .delegate .set_selected_index(ix.saturating_sub(1), window, cx); picker.delegate.reset_selected_match_index = false; - picker.update_matches(picker.query(cx), window, cx) + picker.update_matches(picker.query(cx), window, cx); + // After deleting a project, we want to update the history manager to reflect the change. + // But we do not emit a update event when user opens a project, because it's handled in `workspace::load_workspace`. 
+ if let Some(history_manager) = HistoryManager::global(cx) { + history_manager + .update(cx, |this, cx| this.delete_history(workspace_id, cx)); + } }) }) .detach(); diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index ed929cb1f3..225c3613ce 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -36,7 +36,7 @@ use ui::{ IconWithIndicator, Indicator, PopoverMenu, Tooltip, h_flex, prelude::*, }; use util::ResultExt; -use workspace::{Workspace, notifications::NotifyResultExt}; +use workspace::{BottomDockLayout, Workspace, notifications::NotifyResultExt}; use zed_actions::{OpenBrowser, OpenRecent, OpenRemote}; pub use onboarding_banner::restore_banner; @@ -210,6 +210,7 @@ impl Render for TitleBar { .pr_1() .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()) .children(self.render_call_controls(window, cx)) + .child(self.render_bottom_dock_layout_menu(cx)) .map(|el| { let status = self.client.status(); let status = &*status.borrow(); @@ -622,6 +623,101 @@ impl TitleBar { } } + pub fn render_bottom_dock_layout_menu(&self, cx: &mut Context) -> impl IntoElement { + let workspace = self.workspace.upgrade().unwrap(); + let current_layout = workspace.update(cx, |workspace, _cx| workspace.bottom_dock_layout()); + + PopoverMenu::new("layout-menu") + .trigger( + IconButton::new("toggle_layout", IconName::Layout) + .icon_size(IconSize::Small) + .tooltip(Tooltip::text("Toggle Layout Menu")), + ) + .anchor(gpui::Corner::TopRight) + .menu(move |window, cx| { + ContextMenu::build(window, cx, { + let workspace = workspace.clone(); + move |menu, _, _| { + menu.label("Bottom Dock") + .separator() + .toggleable_entry( + "Contained", + current_layout == BottomDockLayout::Contained, + ui::IconPosition::End, + None, + { + let workspace = workspace.clone(); + move |window, cx| { + workspace.update(cx, |workspace, cx| { + workspace.set_bottom_dock_layout( + BottomDockLayout::Contained, + window, + cx, + ); + 
}); + } + }, + ) + .toggleable_entry( + "Full", + current_layout == BottomDockLayout::Full, + ui::IconPosition::End, + None, + { + let workspace = workspace.clone(); + move |window, cx| { + workspace.update(cx, |workspace, cx| { + workspace.set_bottom_dock_layout( + BottomDockLayout::Full, + window, + cx, + ); + }); + } + }, + ) + .toggleable_entry( + "Left Aligned", + current_layout == BottomDockLayout::LeftAligned, + ui::IconPosition::End, + None, + { + let workspace = workspace.clone(); + move |window, cx| { + workspace.update(cx, |workspace, cx| { + workspace.set_bottom_dock_layout( + BottomDockLayout::LeftAligned, + window, + cx, + ); + }); + } + }, + ) + .toggleable_entry( + "Right Aligned", + current_layout == BottomDockLayout::RightAligned, + ui::IconPosition::End, + None, + { + let workspace = workspace.clone(); + move |window, cx| { + workspace.update(cx, |workspace, cx| { + workspace.set_bottom_dock_layout( + BottomDockLayout::RightAligned, + window, + cx, + ); + }); + } + }, + ) + } + }) + .into() + }) + } + pub fn render_sign_in_button(&mut self, _: &mut Context) -> Button { let client = self.client.clone(); Button::new("sign_in", "Sign in") diff --git a/crates/ui/src/components/avatar.rs b/crates/ui/src/components/avatar.rs index 668bdd5285..3ab31acc1b 100644 --- a/crates/ui/src/components/avatar.rs +++ b/crates/ui/src/components/avatar.rs @@ -236,53 +236,30 @@ impl Component for Avatar { v_flex() .gap_6() .children(vec![ + example_group(vec![ + single_example("Default", Avatar::new(example_avatar).into_any_element()), + single_example( + "Grayscale", + Avatar::new(example_avatar) + .grayscale(true) + .into_any_element(), + ), + single_example( + "Border", + Avatar::new(example_avatar) + .border_color(cx.theme().colors().border) + .into_any_element(), + ).description("Can be used to create visual space by setting the border color to match the background, which creates the appearance of a gap around the avatar."), + ]), example_group_with_title( - 
"Sizes", - vec![ - single_example( - "Default", - Avatar::new(example_avatar).into_any_element(), - ), - single_example( - "Small", - Avatar::new(example_avatar).size(px(24.)).into_any_element(), - ), - single_example( - "Large", - Avatar::new(example_avatar).size(px(48.)).into_any_element(), - ), - ], - ), - example_group_with_title( - "Styles", - vec![ - single_example( - "Default", - Avatar::new(example_avatar).into_any_element(), - ), - single_example( - "Grayscale", - Avatar::new(example_avatar) - .grayscale(true) - .into_any_element(), - ), - single_example( - "With Border", - Avatar::new(example_avatar) - .border_color(cx.theme().colors().border) - .into_any_element(), - ), - ], - ), - example_group_with_title( - "Audio Status", + "Indicator Styles", vec![ single_example( "Muted", Avatar::new(example_avatar) .indicator(AvatarAudioStatusIndicator::new(AudioStatus::Muted)) .into_any_element(), - ), + ).description("Indicates the collaborator's mic is muted."), single_example( "Deafened", Avatar::new(example_avatar) @@ -290,28 +267,23 @@ impl Component for Avatar { AudioStatus::Deafened, )) .into_any_element(), - ), - ], - ), - example_group_with_title( - "Availability", - vec![ + ).description("Indicates that both the collaborator's mic and audio are muted."), single_example( - "Free", + "Availability: Free", Avatar::new(example_avatar) .indicator(AvatarAvailabilityIndicator::new( CollaboratorAvailability::Free, )) .into_any_element(), - ), + ).description("Indicates that the person is free, usually meaning they are not in a call."), single_example( - "Busy", + "Availability: Busy", Avatar::new(example_avatar) .indicator(AvatarAvailabilityIndicator::new( CollaboratorAvailability::Busy, )) .into_any_element(), - ), + ).description("Indicates that the person is busy, usually meaning they are in a channel or direct call."), ], ), ]) diff --git a/crates/ui/src/components/button/split_button.rs b/crates/ui/src/components/button/split_button.rs index 
6ceeb88377..3d50340755 100644 --- a/crates/ui/src/components/button/split_button.rs +++ b/crates/ui/src/components/button/split_button.rs @@ -20,6 +20,12 @@ pub struct SplitButton { pub right: AnyElement, } +impl SplitButton { + pub fn new(left: ButtonLike, right: AnyElement) -> Self { + Self { left, right } + } +} + impl RenderOnce for SplitButton { fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { h_flex() diff --git a/crates/ui/src/components/disclosure.rs b/crates/ui/src/components/disclosure.rs index 6460d059a1..a1fab02e54 100644 --- a/crates/ui/src/components/disclosure.rs +++ b/crates/ui/src/components/disclosure.rs @@ -9,6 +9,7 @@ pub struct Disclosure { id: ElementId, is_open: bool, selected: bool, + disabled: bool, on_toggle: Option>, cursor_style: CursorStyle, opened_icon: IconName, @@ -21,6 +22,7 @@ impl Disclosure { id: id.into(), is_open, selected: false, + disabled: false, on_toggle: None, cursor_style: CursorStyle::PointingHand, opened_icon: IconName::ChevronDown, @@ -45,6 +47,11 @@ impl Disclosure { self.closed_icon = icon; self } + + pub fn disabled(mut self, disabled: bool) -> Self { + self.disabled = disabled; + self + } } impl Toggleable for Disclosure { @@ -78,6 +85,7 @@ impl RenderOnce for Disclosure { .shape(IconButtonShape::Square) .icon_color(Color::Muted) .icon_size(IconSize::Small) + .disabled(self.disabled) .toggle_state(self.selected) .when_some(self.on_toggle, move |this, on_toggle| { this.on_click(move |event, window, cx| on_toggle(event, window, cx)) @@ -120,13 +128,7 @@ impl Component for Disclosure { "Toggleable", v_flex() .gap_2() - .child( - Disclosure::new("interactive", false) - // .on_toggle(Some(Arc::new(|_, _, cx| { - // cx.refresh(); - // }))) - .into_any_element(), - ) + .child(Disclosure::new("interactive", false).into_any_element()) .child(Label::new("Click to toggle")) .into_any_element(), )], diff --git a/crates/ui/src/components/keybinding.rs b/crates/ui/src/components/keybinding.rs index 
db9bb3008f..1b3746515d 100644 --- a/crates/ui/src/components/keybinding.rs +++ b/crates/ui/src/components/keybinding.rs @@ -451,7 +451,7 @@ fn keystroke_text(keystroke: &Keystroke, platform_style: PlatformStyle, vim_mode impl Component for KeyBinding { fn scope() -> ComponentScope { - ComponentScope::Input + ComponentScope::Typography } fn name() -> &'static str { diff --git a/crates/vim/src/change_list.rs b/crates/vim/src/change_list.rs index baa72dbc1f..0d0a5898b2 100644 --- a/crates/vim/src/change_list.rs +++ b/crates/vim/src/change_list.rs @@ -24,6 +24,7 @@ impl Vim { cx: &mut Context, ) { let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); if self.change_list.is_empty() { return; } diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 9245701bf3..264edca0b1 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -234,6 +234,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { return; }; let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let n = if count > 1 { format!(".,.+{}", count.saturating_sub(1)) } else { @@ -962,7 +963,15 @@ pub fn command_interceptor(mut input: &str, cx: &App) -> Vec, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -1335,7 +1345,13 @@ impl Vim { let start = editor.selections.newest_display(cx); let text_layout_details = editor.text_layout_details(window); let (mut range, _) = motion - .range(&snapshot, start.clone(), times, &text_layout_details) + .range( + &snapshot, + start.clone(), + times, + &text_layout_details, + forced_motion, + ) .unwrap_or((start.range(), MotionKind::Exclusive)); if range.start != start.start { editor.change_selections(None, window, cx, |s| { diff --git a/crates/vim/src/indent.rs b/crates/vim/src/indent.rs index 3f0ed5251f..ac708a7e89 100644 --- a/crates/vim/src/indent.rs +++ b/crates/vim/src/indent.rs @@ -18,6 +18,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { 
Vim::action(editor, cx, |vim, _: &Indent, window, cx| { vim.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); vim.update_editor(window, cx, |vim, editor, window, cx| { editor.transact(window, cx, |editor, window, cx| { @@ -36,6 +37,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Outdent, window, cx| { vim.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); vim.update_editor(window, cx, |vim, editor, window, cx| { editor.transact(window, cx, |editor, window, cx| { @@ -54,6 +56,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &AutoIndent, window, cx| { vim.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); vim.update_editor(window, cx, |vim, editor, window, cx| { editor.transact(window, cx, |editor, window, cx| { @@ -75,6 +78,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, dir: IndentDirection, window: &mut Window, cx: &mut Context, @@ -88,7 +92,13 @@ impl Vim { s.move_with(|map, selection| { let anchor = map.display_point_to_anchor(selection.head(), Bias::Right); selection_starts.insert(selection.id, anchor); - motion.expand_selection(map, selection, times, &text_layout_details); + motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); }); }); match dir { diff --git a/crates/vim/src/insert.rs b/crates/vim/src/insert.rs index 550d5b57fd..561ceec0a8 100644 --- a/crates/vim/src/insert.rs +++ b/crates/vim/src/insert.rs @@ -23,6 +23,7 @@ impl Vim { return; } let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); self.stop_recording_immediately(action.boxed_clone(), cx); if count <= 1 || 
Vim::globals(cx).dot_replaying { self.create_mark("^".into(), window, cx); diff --git a/crates/vim/src/motion.rs b/crates/vim/src/motion.rs index bbe1ef01a6..fcf1e07749 100644 --- a/crates/vim/src/motion.rs +++ b/crates/vim/src/motion.rs @@ -650,6 +650,7 @@ impl Vim { } let count = Vim::take_count(cx); + let forced_motion = Vim::take_forced_motion(cx); let active_operator = self.active_operator(); let mut waiting_operator: Option = None; match self.mode { @@ -659,7 +660,14 @@ impl Vim { target: Some(SurroundsType::Motion(motion)), }); } else { - self.normal_motion(motion.clone(), active_operator.clone(), count, window, cx) + self.normal_motion( + motion.clone(), + active_operator.clone(), + count, + forced_motion, + window, + cx, + ) } } Mode::Visual | Mode::VisualLine | Mode::VisualBlock => { @@ -1183,7 +1191,6 @@ impl Motion { SelectionGoal::None, ), }; - (new_point != point || infallible).then_some((new_point, goal)) } @@ -1194,6 +1201,7 @@ impl Motion { selection: Selection, times: Option, text_layout_details: &TextLayoutDetails, + forced_motion: bool, ) -> Option<(Range, MotionKind)> { if let Motion::ZedSearchResult { prior_selections, @@ -1221,18 +1229,29 @@ impl Motion { return None; } } - - let (new_head, goal) = self.move_point( + let maybe_new_point = self.move_point( map, selection.head(), selection.goal, times, text_layout_details, - )?; + ); + + let (new_head, goal) = match (maybe_new_point, forced_motion) { + (Some((p, g)), _) => Some((p, g)), + (None, false) => None, + (None, true) => Some((selection.head(), selection.goal)), + }?; + let mut selection = selection.clone(); selection.set_head(new_head, goal); - let mut kind = self.default_kind(); + let mut kind = match (self.default_kind(), forced_motion) { + (MotionKind::Linewise, true) => MotionKind::Exclusive, + (MotionKind::Exclusive, true) => MotionKind::Inclusive, + (MotionKind::Inclusive, true) => MotionKind::Exclusive, + (kind, false) => kind, + }; if let Motion::NextWordStart { 
ignore_punctuation: _, @@ -1259,6 +1278,12 @@ impl Motion { } else if kind == MotionKind::Exclusive && !self.skip_exclusive_special_case() { let start_point = selection.start.to_point(map); let mut end_point = selection.end.to_point(map); + let mut next_point = selection.end; + *next_point.column_mut() += 1; + next_point = map.clip_point(next_point, Bias::Right); + if next_point.to_point(map) == end_point && forced_motion { + selection.end = movement::saturating_left(map, selection.end); + } if end_point.row > start_point.row { let first_non_blank_of_start_row = map @@ -1304,8 +1329,15 @@ impl Motion { selection: &mut Selection, times: Option, text_layout_details: &TextLayoutDetails, + forced_motion: bool, ) -> Option { - let (range, kind) = self.range(map, selection.clone(), times, text_layout_details)?; + let (range, kind) = self.range( + map, + selection.clone(), + times, + text_layout_details, + forced_motion, + )?; selection.start = range.start; selection.end = range.end; Some(kind) @@ -3816,6 +3848,7 @@ mod test { Mode::Normal, ); } + #[gpui::test] async fn test_delete_key_can_remove_last_character(cx: &mut gpui::TestAppContext) { let mut cx = NeovimBackedTestContext::new(cx).await; @@ -3823,4 +3856,147 @@ mod test { cx.simulate_shared_keystrokes("delete").await; cx.shared_state().await.assert_eq("aˇb"); } + + #[gpui::test] + async fn test_forced_motion_delete_to_start_of_line(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + ˇthe quick brown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v 0").await; + cx.shared_state().await.assert_eq(indoc! {" + ˇhe quick brown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick bˇrown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v 0").await; + cx.shared_state().await.assert_eq(indoc! 
{" + ˇown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick brown foˇx + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v 0").await; + cx.shared_state().await.assert_eq(indoc! {" + ˇ + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + } + + #[gpui::test] + async fn test_forced_motion_delete_to_end_of_line(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + the quick brown foˇx + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v $").await; + cx.shared_state().await.assert_eq(indoc! {" + the quick brown foˇx + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + ˇthe quick brown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v $").await; + cx.shared_state().await.assert_eq(indoc! {" + ˇx + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + } + + #[gpui::test] + async fn test_forced_motion_yank(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + ˇthe quick brown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("y v j p").await; + cx.shared_state().await.assert_eq(indoc! {" + the quick brown fox + ˇthe quick brown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick bˇrown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("y v j p").await; + cx.shared_state().await.assert_eq(indoc! {" + the quick brˇrown fox + jumped overown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! 
{" + the quick brown foˇx + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("y v j p").await; + cx.shared_state().await.assert_eq(indoc! {" + the quick brown foxˇx + jumped over the la + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick brown fox + jˇumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("y v k p").await; + cx.shared_state().await.assert_eq(indoc! {" + thˇhe quick brown fox + je quick brown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + } + + #[gpui::test] + async fn test_inclusive_to_exclusive_delete(cx: &mut gpui::TestAppContext) { + let mut cx = NeovimBackedTestContext::new(cx).await; + + cx.set_shared_state(indoc! {" + ˇthe quick brown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v e").await; + cx.shared_state().await.assert_eq(indoc! {" + ˇe quick brown fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick bˇrown fox + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v e").await; + cx.shared_state().await.assert_eq(indoc! {" + the quick bˇn fox + jumped over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + + cx.set_shared_state(indoc! {" + the quick brown foˇx + jumped over the lazy dog"}) + .await; + cx.simulate_shared_keystrokes("d v e").await; + cx.shared_state().await.assert_eq(indoc! 
{" + the quick brown foˇd over the lazy dog"}); + assert_eq!(cx.cx.forced_motion(), false); + } } diff --git a/crates/vim/src/normal.rs b/crates/vim/src/normal.rs index 43657ffd73..7781891050 100644 --- a/crates/vim/src/normal.rs +++ b/crates/vim/src/normal.rs @@ -86,12 +86,14 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &DeleteLeft, window, cx| { vim.record_current_action(cx); let times = Vim::take_count(cx); - vim.delete_motion(Motion::Left, times, window, cx); + let forced_motion = Vim::take_forced_motion(cx); + vim.delete_motion(Motion::Left, times, forced_motion, window, cx); }); Vim::action(editor, cx, |vim, _: &DeleteRight, window, cx| { vim.record_current_action(cx); let times = Vim::take_count(cx); - vim.delete_motion(Motion::Right, times, window, cx); + let forced_motion = Vim::take_forced_motion(cx); + vim.delete_motion(Motion::Right, times, forced_motion, window, cx); }); Vim::action(editor, cx, |vim, _: &HelixDelete, window, cx| { @@ -111,11 +113,13 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &ChangeToEndOfLine, window, cx| { vim.start_recording(cx); let times = Vim::take_count(cx); + let forced_motion = Vim::take_forced_motion(cx); vim.change_motion( Motion::EndOfLine { display_lines: false, }, times, + forced_motion, window, cx, ); @@ -123,11 +127,13 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &DeleteToEndOfLine, window, cx| { vim.record_current_action(cx); let times = Vim::take_count(cx); + let forced_motion = Vim::take_forced_motion(cx); vim.delete_motion( Motion::EndOfLine { display_lines: false, }, times, + forced_motion, window, cx, ); @@ -142,6 +148,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Undo, window, cx| { let times = Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.update_editor(window, cx, |_, editor, 
window, cx| { for _ in 0..times.unwrap_or(1) { editor.undo(&editor::actions::Undo, window, cx); @@ -150,6 +157,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { }); Vim::action(editor, cx, |vim, _: &Redo, window, cx| { let times = Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.update_editor(window, cx, |_, editor, window, cx| { for _ in 0..times.unwrap_or(1) { editor.redo(&editor::actions::Redo, window, cx); @@ -170,48 +178,93 @@ impl Vim { motion: Motion, operator: Option, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { match operator { None => self.move_cursor(motion, times, window, cx), - Some(Operator::Change) => self.change_motion(motion, times, window, cx), - Some(Operator::Delete) => self.delete_motion(motion, times, window, cx), - Some(Operator::Yank) => self.yank_motion(motion, times, window, cx), + Some(Operator::Change) => self.change_motion(motion, times, forced_motion, window, cx), + Some(Operator::Delete) => self.delete_motion(motion, times, forced_motion, window, cx), + Some(Operator::Yank) => self.yank_motion(motion, times, forced_motion, window, cx), Some(Operator::AddSurrounds { target: None }) => {} - Some(Operator::Indent) => { - self.indent_motion(motion, times, IndentDirection::In, window, cx) - } - Some(Operator::Rewrap) => self.rewrap_motion(motion, times, window, cx), - Some(Operator::Outdent) => { - self.indent_motion(motion, times, IndentDirection::Out, window, cx) - } - Some(Operator::AutoIndent) => { - self.indent_motion(motion, times, IndentDirection::Auto, window, cx) - } - Some(Operator::ShellCommand) => self.shell_command_motion(motion, times, window, cx), - Some(Operator::Lowercase) => { - self.convert_motion(motion, times, ConvertTarget::LowerCase, window, cx) - } - Some(Operator::Uppercase) => { - self.convert_motion(motion, times, ConvertTarget::UpperCase, window, cx) - } - Some(Operator::OppositeCase) => { - self.convert_motion(motion, times, 
ConvertTarget::OppositeCase, window, cx) - } - Some(Operator::Rot13) => { - self.convert_motion(motion, times, ConvertTarget::Rot13, window, cx) - } - Some(Operator::Rot47) => { - self.convert_motion(motion, times, ConvertTarget::Rot47, window, cx) + Some(Operator::Indent) => self.indent_motion( + motion, + times, + forced_motion, + IndentDirection::In, + window, + cx, + ), + Some(Operator::Rewrap) => self.rewrap_motion(motion, times, forced_motion, window, cx), + Some(Operator::Outdent) => self.indent_motion( + motion, + times, + forced_motion, + IndentDirection::Out, + window, + cx, + ), + Some(Operator::AutoIndent) => self.indent_motion( + motion, + times, + forced_motion, + IndentDirection::Auto, + window, + cx, + ), + Some(Operator::ShellCommand) => { + self.shell_command_motion(motion, times, forced_motion, window, cx) } + Some(Operator::Lowercase) => self.convert_motion( + motion, + times, + forced_motion, + ConvertTarget::LowerCase, + window, + cx, + ), + Some(Operator::Uppercase) => self.convert_motion( + motion, + times, + forced_motion, + ConvertTarget::UpperCase, + window, + cx, + ), + Some(Operator::OppositeCase) => self.convert_motion( + motion, + times, + forced_motion, + ConvertTarget::OppositeCase, + window, + cx, + ), + Some(Operator::Rot13) => self.convert_motion( + motion, + times, + forced_motion, + ConvertTarget::Rot13, + window, + cx, + ), + Some(Operator::Rot47) => self.convert_motion( + motion, + times, + forced_motion, + ConvertTarget::Rot47, + window, + cx, + ), Some(Operator::ToggleComments) => { - self.toggle_comments_motion(motion, times, window, cx) + self.toggle_comments_motion(motion, times, forced_motion, window, cx) } Some(Operator::ReplaceWithRegister) => { - self.replace_with_register_motion(motion, times, window, cx) + self.replace_with_register_motion(motion, times, forced_motion, window, cx) + } + Some(Operator::Exchange) => { + self.exchange_motion(motion, times, forced_motion, window, cx) } - Some(Operator::Exchange) => 
self.exchange_motion(motion, times, window, cx), Some(operator) => { // Can't do anything for text objects, Ignoring error!("Unexpected normal mode motion operator: {:?}", operator) @@ -492,6 +545,7 @@ impl Vim { ) { self.record_current_action(cx); let mut times = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); if self.mode.is_visual() { times = 1; } else if times > 1 { @@ -513,11 +567,19 @@ impl Vim { fn yank_line(&mut self, _: &YankLine, window: &mut Window, cx: &mut Context) { let count = Vim::take_count(cx); - self.yank_motion(motion::Motion::CurrentLine, count, window, cx) + let forced_motion = Vim::take_forced_motion(cx); + self.yank_motion( + motion::Motion::CurrentLine, + count, + forced_motion, + window, + cx, + ) } fn show_location(&mut self, _: &ShowLocation, window: &mut Window, cx: &mut Context) { let count = Vim::take_count(cx); + Vim::take_forced_motion(cx); self.update_editor(window, cx, |vim, editor, _window, cx| { let selection = editor.selections.newest_anchor(); if let Some((_, buffer, _)) = editor.active_excerpt(cx) { @@ -577,6 +639,7 @@ impl Vim { cx: &mut Context, ) { let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); self.stop_recording(cx); self.update_editor(window, cx, |_, editor, window, cx| { editor.transact(window, cx, |editor, window, cx| { diff --git a/crates/vim/src/normal/change.rs b/crates/vim/src/normal/change.rs index 199ac8b0c7..7e27cda949 100644 --- a/crates/vim/src/normal/change.rs +++ b/crates/vim/src/normal/change.rs @@ -18,6 +18,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -59,6 +60,7 @@ impl Vim { selection, times, &text_layout_details, + forced_motion, ); if let Motion::CurrentLine = motion { let mut start_offset = @@ -181,7 +183,7 @@ fn expand_changed_word_selection( } else { Motion::NextWordStart { ignore_punctuation } }; - motion.expand_selection(map, selection, times, text_layout_details) + 
motion.expand_selection(map, selection, times, text_layout_details, false) } } diff --git a/crates/vim/src/normal/convert.rs b/crates/vim/src/normal/convert.rs index af0154d3c2..31aac771c2 100644 --- a/crates/vim/src/normal/convert.rs +++ b/crates/vim/src/normal/convert.rs @@ -25,6 +25,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, mode: ConvertTarget, window: &mut Window, cx: &mut Context, @@ -39,7 +40,13 @@ impl Vim { s.move_with(|map, selection| { let anchor = map.display_point_to_anchor(selection.head(), Bias::Left); selection_starts.insert(selection.id, anchor); - motion.expand_selection(map, selection, times, &text_layout_details); + motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); }); }); match mode { @@ -185,6 +192,7 @@ impl Vim { self.record_current_action(cx); self.store_visual_marks(window, cx); let count = Vim::take_count(cx).unwrap_or(1) as u32; + Vim::take_forced_motion(cx); self.update_editor(window, cx, |vim, editor, window, cx| { let mut ranges = Vec::new(); diff --git a/crates/vim/src/normal/delete.rs b/crates/vim/src/normal/delete.rs index afd6bc402c..583e775fc6 100644 --- a/crates/vim/src/normal/delete.rs +++ b/crates/vim/src/normal/delete.rs @@ -18,6 +18,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -33,9 +34,13 @@ impl Vim { s.move_with(|map, selection| { let original_head = selection.head(); original_columns.insert(selection.id, original_head.column()); - let kind = - motion.expand_selection(map, selection, times, &text_layout_details); - + let kind = motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); ranges_to_copy .push(selection.start.to_point(map)..selection.end.to_point(map)); diff --git a/crates/vim/src/normal/increment.rs b/crates/vim/src/normal/increment.rs index 194a5c8803..e092249e32 100644 --- 
a/crates/vim/src/normal/increment.rs +++ b/crates/vim/src/normal/increment.rs @@ -29,12 +29,14 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, action: &Increment, window, cx| { vim.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let step = if action.step { count as i32 } else { 0 }; vim.increment(count as i64, step, window, cx) }); Vim::action(editor, cx, |vim, action: &Decrement, window, cx| { vim.record_current_action(cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let step = if action.step { -1 * (count as i32) } else { 0 }; vim.increment(-(count as i64), step, window, cx) }); diff --git a/crates/vim/src/normal/paste.rs b/crates/vim/src/normal/paste.rs index 2aaa2a4b7c..3d0a3e44c8 100644 --- a/crates/vim/src/normal/paste.rs +++ b/crates/vim/src/normal/paste.rs @@ -28,6 +28,7 @@ impl Vim { self.record_current_action(cx); self.store_visual_marks(window, cx); let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); self.update_editor(window, cx, |vim, editor, window, cx| { let text_layout_details = editor.text_layout_details(window); @@ -247,6 +248,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -258,7 +260,13 @@ impl Vim { editor.set_clip_at_line_ends(false, cx); editor.change_selections(None, window, cx, |s| { s.move_with(|map, selection| { - motion.expand_selection(map, selection, times, &text_layout_details); + motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); }); }); diff --git a/crates/vim/src/normal/repeat.rs b/crates/vim/src/normal/repeat.rs index d396d0ae4d..49f07954ff 100644 --- a/crates/vim/src/normal/repeat.rs +++ b/crates/vim/src/normal/repeat.rs @@ -170,6 +170,7 @@ impl Vim { cx: &mut Context, ) { let mut count = Vim::take_count(cx).unwrap_or(1); + 
Vim::take_forced_motion(cx); self.clear_operator(window, cx); let globals = Vim::globals(cx); @@ -201,6 +202,7 @@ impl Vim { cx: &mut Context, ) { let count = Vim::take_count(cx); + Vim::take_forced_motion(cx); let Some((mut actions, selection, mode)) = Vim::update_globals(cx, |globals, _| { let actions = globals.recorded_actions.clone(); diff --git a/crates/vim/src/normal/scroll.rs b/crates/vim/src/normal/scroll.rs index 0b87a3b345..dfca3aa280 100644 --- a/crates/vim/src/normal/scroll.rs +++ b/crates/vim/src/normal/scroll.rs @@ -55,6 +55,7 @@ impl Vim { by: fn(c: Option) -> ScrollAmount, ) { let amount = by(Vim::take_count(cx).map(|c| c as f32)); + Vim::take_forced_motion(cx); self.update_editor(window, cx, |_, editor, window, cx| { scroll_editor(editor, move_cursor, &amount, window, cx) }); diff --git a/crates/vim/src/normal/search.rs b/crates/vim/src/normal/search.rs index 98972097ae..da8f65c1cf 100644 --- a/crates/vim/src/normal/search.rs +++ b/crates/vim/src/normal/search.rs @@ -138,6 +138,7 @@ impl Vim { Direction::Next }; let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); pane.update(cx, |pane, cx| { if let Some(search_bar) = pane.toolbar().read(cx).item_of_type::() { @@ -261,6 +262,7 @@ impl Vim { return; }; let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); let success = pane.update(cx, |pane, cx| { @@ -303,6 +305,7 @@ impl Vim { return; }; let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); let prior_selections = self.editor_selections(window, cx); let cursor_word = self.editor_cursor_word(window, cx); let vim = cx.entity().clone(); diff --git a/crates/vim/src/normal/substitute.rs b/crates/vim/src/normal/substitute.rs index 78c9ec5b3f..1199356995 100644 --- a/crates/vim/src/normal/substitute.rs +++ b/crates/vim/src/normal/substitute.rs @@ -13,6 +13,7 @@ pub(crate) 
fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Substitute, window, cx| { vim.start_recording(cx); let count = Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.substitute(count, vim.mode == Mode::VisualLine, window, cx); }); @@ -22,6 +23,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { vim.switch_mode(Mode::VisualLine, false, window, cx) } let count = Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.substitute(count, true, window, cx) }); } @@ -47,6 +49,7 @@ impl Vim { selection, count, &text_layout_details, + false, ); } if line_mode { @@ -60,6 +63,7 @@ impl Vim { selection, None, &text_layout_details, + false, ); if let Some((point, _)) = (Motion::FirstNonWhitespace { display_lines: false, diff --git a/crates/vim/src/normal/toggle_comments.rs b/crates/vim/src/normal/toggle_comments.rs index 363215ffe2..1df381acbe 100644 --- a/crates/vim/src/normal/toggle_comments.rs +++ b/crates/vim/src/normal/toggle_comments.rs @@ -9,6 +9,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -21,7 +22,13 @@ impl Vim { s.move_with(|map, selection| { let anchor = map.display_point_to_anchor(selection.head(), Bias::Right); selection_starts.insert(selection.id, anchor); - motion.expand_selection(map, selection, times, &text_layout_details); + motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); }); }); editor.toggle_comments(&Default::default(), window, cx); diff --git a/crates/vim/src/normal/yank.rs b/crates/vim/src/normal/yank.rs index 0ec19f654b..6f83b954b2 100644 --- a/crates/vim/src/normal/yank.rs +++ b/crates/vim/src/normal/yank.rs @@ -21,6 +21,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -33,8 +34,19 @@ impl Vim { editor.change_selections(None, window, cx, |s| { s.move_with(|map, selection| { let 
original_position = (selection.head(), selection.goal); - original_positions.insert(selection.id, original_position); - kind = motion.expand_selection(map, selection, times, &text_layout_details); + kind = motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); + if kind == Some(MotionKind::Exclusive) { + original_positions + .insert(selection.id, (selection.start, selection.goal)); + } else { + original_positions.insert(selection.id, original_position); + } }) }); let Some(kind) = kind else { return }; diff --git a/crates/vim/src/replace.rs b/crates/vim/src/replace.rs index 26437550a1..f975aefa33 100644 --- a/crates/vim/src/replace.rs +++ b/crates/vim/src/replace.rs @@ -27,6 +27,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { return; } let count = Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.undo_replace(count, window, cx) }); } @@ -179,6 +180,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -188,7 +190,13 @@ impl Vim { let text_layout_details = editor.text_layout_details(window); let mut selection = editor.selections.newest_display(cx); let snapshot = editor.snapshot(window, cx); - motion.expand_selection(&snapshot, &mut selection, times, &text_layout_details); + motion.expand_selection( + &snapshot, + &mut selection, + times, + &text_layout_details, + forced_motion, + ); let start = snapshot .buffer_snapshot .anchor_before(selection.start.to_point(&snapshot)); diff --git a/crates/vim/src/rewrap.rs b/crates/vim/src/rewrap.rs index f7f234c742..b5d69ef0ae 100644 --- a/crates/vim/src/rewrap.rs +++ b/crates/vim/src/rewrap.rs @@ -10,6 +10,7 @@ pub(crate) fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &Rewrap, window, cx| { vim.record_current_action(cx); Vim::take_count(cx); + Vim::take_forced_motion(cx); vim.store_visual_marks(window, cx); vim.update_editor(window, cx, |vim, editor, 
window, cx| { editor.transact(window, cx, |editor, window, cx| { @@ -43,6 +44,7 @@ impl Vim { &mut self, motion: Motion, times: Option, + forced_motion: bool, window: &mut Window, cx: &mut Context, ) { @@ -55,7 +57,13 @@ impl Vim { s.move_with(|map, selection| { let anchor = map.display_point_to_anchor(selection.head(), Bias::Right); selection_starts.insert(selection.id, anchor); - motion.expand_selection(map, selection, times, &text_layout_details); + motion.expand_selection( + map, + selection, + times, + &text_layout_details, + forced_motion, + ); }); }); editor.rewrap_impl( diff --git a/crates/vim/src/state.rs b/crates/vim/src/state.rs index 6e7f753def..6b1a87aec7 100644 --- a/crates/vim/src/state.rs +++ b/crates/vim/src/state.rs @@ -202,7 +202,7 @@ pub struct VimGlobals { pub pre_count: Option, /// post_count is the number after an operator is specified (2 in 3d2d) pub post_count: Option, - + pub forced_motion: bool, pub stop_recording_after_next_action: bool, pub ignore_current_insertion: bool, pub recorded_count: Option, diff --git a/crates/vim/src/surrounds.rs b/crates/vim/src/surrounds.rs index 3c450292e1..6697742e4d 100644 --- a/crates/vim/src/surrounds.rs +++ b/crates/vim/src/surrounds.rs @@ -27,6 +27,7 @@ impl Vim { ) { self.stop_recording(cx); let count = Vim::take_count(cx); + let forced_motion = Vim::take_forced_motion(cx); let mode = self.mode; self.update_editor(window, cx, |_, editor, window, cx| { let text_layout_details = editor.text_layout_details(window); @@ -55,7 +56,13 @@ impl Vim { } SurroundsType::Motion(motion) => { motion - .range(&display_map, selection.clone(), count, &text_layout_details) + .range( + &display_map, + selection.clone(), + count, + &text_layout_details, + forced_motion, + ) .map(|(mut range, _)| { // The Motion::CurrentLine operation will contain the newline of the current line and leading/trailing whitespace if let Motion::CurrentLine = motion { diff --git a/crates/vim/src/test/neovim_backed_test_context.rs 
b/crates/vim/src/test/neovim_backed_test_context.rs index e2189da86b..053e1e587e 100644 --- a/crates/vim/src/test/neovim_backed_test_context.rs +++ b/crates/vim/src/test/neovim_backed_test_context.rs @@ -13,7 +13,7 @@ use super::{VimTestContext, neovim_connection::NeovimConnection}; use crate::state::{Mode, VimGlobals}; pub struct NeovimBackedTestContext { - cx: VimTestContext, + pub(crate) cx: VimTestContext, pub(crate) neovim: NeovimConnection, last_set_state: Option, diff --git a/crates/vim/src/test/vim_test_context.rs b/crates/vim/src/test/vim_test_context.rs index 32e8de3af5..188ae1c248 100644 --- a/crates/vim/src/test/vim_test_context.rs +++ b/crates/vim/src/test/vim_test_context.rs @@ -142,6 +142,10 @@ impl VimTestContext { self.update_editor(|editor, _, cx| editor.addon::().unwrap().entity.read(cx).mode) } + pub fn forced_motion(&mut self) -> bool { + self.update_editor(|_, _, cx| cx.global::().forced_motion) + } + pub fn active_operator(&mut self) -> Option { self.update_editor(|editor, _, cx| { editor diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index c4efb2b513..a1ecab13c3 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -145,6 +145,7 @@ actions!( PushDeleteSurrounds, PushMark, ToggleMarksView, + PushForcedMotion, PushIndent, PushOutdent, PushAutoIndent, @@ -233,6 +234,7 @@ pub fn init(cx: &mut App) { workspace.register_action(|workspace, _: &ResizePaneRight, window, cx| { let count = Vim::take_count(cx).unwrap_or(1) as f32; + Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); let Ok(font_id) = window.text_system().font_id(&theme.buffer_font) else { return; @@ -248,6 +250,7 @@ pub fn init(cx: &mut App) { workspace.register_action(|workspace, _: &ResizePaneLeft, window, cx| { let count = Vim::take_count(cx).unwrap_or(1) as f32; + Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); let Ok(font_id) = window.text_system().font_id(&theme.buffer_font) else { return; @@ -263,6 +266,7 @@ pub 
fn init(cx: &mut App) { workspace.register_action(|workspace, _: &ResizePaneUp, window, cx| { let count = Vim::take_count(cx).unwrap_or(1) as f32; + Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); let height = theme.buffer_font_size(cx) * theme.buffer_line_height.value(); workspace.resize_pane(Axis::Vertical, height * count, window, cx); @@ -270,6 +274,7 @@ pub fn init(cx: &mut App) { workspace.register_action(|workspace, _: &ResizePaneDown, window, cx| { let count = Vim::take_count(cx).unwrap_or(1) as f32; + Vim::take_forced_motion(cx); let theme = ThemeSettings::get_global(cx); let height = theme.buffer_font_size(cx) * theme.buffer_line_height.value(); workspace.resize_pane(Axis::Vertical, -height * count, window, cx); @@ -472,7 +477,9 @@ impl Vim { vim.switch_mode(Mode::HelixNormal, false, window, cx) }, ); - + Vim::action(editor, cx, |_, _: &PushForcedMotion, _, cx| { + Vim::globals(cx).forced_motion = true; + }); Vim::action(editor, cx, |vim, action: &PushObject, window, cx| { vim.push_operator( Operator::Object { @@ -907,6 +914,7 @@ impl Vim { self.current_tx.take(); self.current_anchor.take(); } + Vim::take_forced_motion(cx); if mode != Mode::Insert && mode != Mode::Replace { Vim::take_count(cx); } @@ -1011,6 +1019,13 @@ impl Vim { count } + pub fn take_forced_motion(cx: &mut App) -> bool { + let global_state = cx.global_mut::(); + let forced_motion = global_state.forced_motion; + global_state.forced_motion = false; + forced_motion + } + pub fn cursor_shape(&self, cx: &mut App) -> CursorShape { match self.mode { Mode::Normal => { @@ -1372,6 +1387,7 @@ impl Vim { fn clear_operator(&mut self, window: &mut Window, cx: &mut Context) { Vim::take_count(cx); + Vim::take_forced_motion(cx); self.selected_register.take(); self.operator_stack.clear(); self.sync_vim_settings(window, cx); diff --git a/crates/vim/src/visual.rs b/crates/vim/src/visual.rs index a96d49a43c..6827c2c055 100644 --- a/crates/vim/src/visual.rs +++ 
b/crates/vim/src/visual.rs @@ -85,6 +85,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { Vim::action(editor, cx, |vim, _: &SelectLargerSyntaxNode, window, cx| { let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); for _ in 0..count { vim.update_editor(window, cx, |_, editor, window, cx| { editor.select_larger_syntax_node(&Default::default(), window, cx); @@ -97,6 +98,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { cx, |vim, _: &SelectSmallerSyntaxNode, window, cx| { let count = Vim::take_count(cx).unwrap_or(1); + Vim::take_forced_motion(cx); for _ in 0..count { vim.update_editor(window, cx, |_, editor, window, cx| { editor.select_smaller_syntax_node(&Default::default(), window, cx); @@ -682,6 +684,7 @@ impl Vim { } pub fn select_next(&mut self, _: &SelectNext, window: &mut Window, cx: &mut Context) { + Vim::take_forced_motion(cx); let count = Vim::take_count(cx).unwrap_or_else(|| if self.mode.is_visual() { 1 } else { 2 }); self.update_editor(window, cx, |_, editor, window, cx| { @@ -704,6 +707,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { + Vim::take_forced_motion(cx); let count = Vim::take_count(cx).unwrap_or_else(|| if self.mode.is_visual() { 1 } else { 2 }); self.update_editor(window, cx, |_, editor, window, cx| { @@ -725,6 +729,7 @@ impl Vim { window: &mut Window, cx: &mut Context, ) { + Vim::take_forced_motion(cx); let count = Vim::take_count(cx).unwrap_or(1); let Some(pane) = self.pane(window, cx) else { return; diff --git a/crates/vim/test_data/test_forced_motion_delete_to_end_of_line.json b/crates/vim/test_data/test_forced_motion_delete_to_end_of_line.json new file mode 100644 index 0000000000..4df916befb --- /dev/null +++ b/crates/vim/test_data/test_forced_motion_delete_to_end_of_line.json @@ -0,0 +1,10 @@ +{"Put":{"state":"the quick brown foˇx\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"$"} +{"Get":{"state":"the quick brown foˇx\njumped over the lazy dog","mode":"Normal"}} 
+{"Put":{"state":"ˇthe quick brown fox\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"$"} +{"Get":{"state":"ˇx\njumped over the lazy dog","mode":"Normal"}} diff --git a/crates/vim/test_data/test_forced_motion_delete_to_start_of_line.json b/crates/vim/test_data/test_forced_motion_delete_to_start_of_line.json new file mode 100644 index 0000000000..8aae77c8de --- /dev/null +++ b/crates/vim/test_data/test_forced_motion_delete_to_start_of_line.json @@ -0,0 +1,15 @@ +{"Put":{"state":"ˇthe quick brown fox\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"0"} +{"Get":{"state":"ˇhe quick brown fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick bˇrown fox\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"0"} +{"Get":{"state":"ˇown fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick brown foˇx\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"0"} +{"Get":{"state":"ˇ\njumped over the lazy dog","mode":"Normal"}} diff --git a/crates/vim/test_data/test_forced_motion_yank.json b/crates/vim/test_data/test_forced_motion_yank.json new file mode 100644 index 0000000000..208c22d689 --- /dev/null +++ b/crates/vim/test_data/test_forced_motion_yank.json @@ -0,0 +1,24 @@ +{"Put":{"state":"ˇthe quick brown fox\njumped over the lazy dog"}} +{"Key":"y"} +{"Key":"v"} +{"Key":"j"} +{"Key":"p"} +{"Get":{"state":"the quick brown fox\nˇthe quick brown fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick bˇrown fox\njumped over the lazy dog"}} +{"Key":"y"} +{"Key":"v"} +{"Key":"j"} +{"Key":"p"} +{"Get":{"state":"the quick brˇrown fox\njumped overown fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick brown foˇx\njumped over the lazy dog"}} +{"Key":"y"} +{"Key":"v"} +{"Key":"j"} +{"Key":"p"} +{"Get":{"state":"the quick brown foxˇx\njumped over the la\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick brown fox\njˇumped 
over the lazy dog"}} +{"Key":"y"} +{"Key":"v"} +{"Key":"k"} +{"Key":"p"} +{"Get":{"state":"thˇhe quick brown fox\nje quick brown fox\njumped over the lazy dog","mode":"Normal"}} diff --git a/crates/vim/test_data/test_inclusive_to_exclusive_delete.json b/crates/vim/test_data/test_inclusive_to_exclusive_delete.json new file mode 100644 index 0000000000..3d25b9fc67 --- /dev/null +++ b/crates/vim/test_data/test_inclusive_to_exclusive_delete.json @@ -0,0 +1,15 @@ +{"Put":{"state":"ˇthe quick brown fox\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"e"} +{"Get":{"state":"ˇe quick brown fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick bˇrown fox\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"e"} +{"Get":{"state":"the quick bˇn fox\njumped over the lazy dog","mode":"Normal"}} +{"Put":{"state":"the quick brown foˇx\njumped over the lazy dog"}} +{"Key":"d"} +{"Key":"v"} +{"Key":"e"} +{"Get":{"state":"the quick brown foˇd over the lazy dog","mode":"Normal"}} diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index aa257a5fc9..63a57fe14a 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -67,6 +67,9 @@ uuid.workspace = true zed_actions.workspace = true workspace-hack.workspace = true +[target.'cfg(target_os = "windows")'.dependencies] +windows.workspace = true + [dev-dependencies] call = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } @@ -78,5 +81,5 @@ gpui = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } session = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } -http_client = { workspace = true, features = ["test-support"] } +http_client = { workspace = true, features = ["test-support"] } tempfile.workspace = true diff --git a/crates/workspace/src/history_manager.rs 
b/crates/workspace/src/history_manager.rs new file mode 100644 index 0000000000..e63b1823ea --- /dev/null +++ b/crates/workspace/src/history_manager.rs @@ -0,0 +1,129 @@ +use std::path::PathBuf; + +use gpui::{AppContext, Entity, Global, MenuItem}; +use smallvec::SmallVec; +use ui::App; +use util::{ResultExt, paths::PathExt}; + +use crate::{NewWindow, SerializedWorkspaceLocation, WORKSPACE_DB, WorkspaceId}; + +pub fn init(cx: &mut App) { + let manager = cx.new(|_| HistoryManager::new()); + HistoryManager::set_global(manager.clone(), cx); + HistoryManager::init(manager, cx); +} + +pub struct HistoryManager { + /// The history of workspaces that have been opened in the past, in reverse order. + /// The most recent workspace is at the end of the vector. + history: Vec, +} + +#[derive(Debug)] +pub struct HistoryManagerEntry { + pub id: WorkspaceId, + pub path: SmallVec<[PathBuf; 2]>, +} + +struct GlobalHistoryManager(Entity); + +impl Global for GlobalHistoryManager {} + +impl HistoryManager { + fn new() -> Self { + Self { + history: Vec::new(), + } + } + + fn init(this: Entity, cx: &App) { + cx.spawn(async move |cx| { + let recent_folders = WORKSPACE_DB + .recent_workspaces_on_disk() + .await + .unwrap_or_default() + .into_iter() + .rev() + .map(|(id, location)| HistoryManagerEntry::new(id, &location)) + .collect::>(); + this.update(cx, |this, cx| { + this.history = recent_folders; + this.update_jump_list(cx); + }) + }) + .detach(); + } + + pub fn global(cx: &App) -> Option> { + cx.try_global::() + .map(|model| model.0.clone()) + } + + fn set_global(history_manager: Entity, cx: &mut App) { + cx.set_global(GlobalHistoryManager(history_manager)); + } + + pub fn update_history(&mut self, id: WorkspaceId, entry: HistoryManagerEntry, cx: &App) { + if let Some(pos) = self.history.iter().position(|e| e.id == id) { + self.history.remove(pos); + } + self.history.push(entry); + self.update_jump_list(cx); + } + + pub fn delete_history(&mut self, id: WorkspaceId, cx: &App) { + let 
Some(pos) = self.history.iter().position(|e| e.id == id) else { + return; + }; + self.history.remove(pos); + self.update_jump_list(cx); + } + + fn update_jump_list(&mut self, cx: &App) { + let menus = vec![MenuItem::action("New Window", NewWindow)]; + let entries = self + .history + .iter() + .rev() + .map(|entry| entry.path.clone()) + .collect::>(); + let user_removed = cx.update_jump_list(menus, entries); + self.remove_user_removed_workspaces(user_removed, cx); + } + + pub fn remove_user_removed_workspaces( + &mut self, + user_removed: Vec>, + cx: &App, + ) { + if user_removed.is_empty() { + return; + } + let mut deleted_ids = Vec::new(); + for idx in (0..self.history.len()).rev() { + if let Some(entry) = self.history.get(idx) { + if user_removed.contains(&entry.path) { + deleted_ids.push(entry.id); + self.history.remove(idx); + } + } + } + cx.spawn(async move |_| { + for id in deleted_ids.iter() { + WORKSPACE_DB.delete_workspace_by_id(*id).await.log_err(); + } + }) + .detach(); + } +} + +impl HistoryManagerEntry { + pub fn new(id: WorkspaceId, location: &SerializedWorkspaceLocation) -> Self { + let path = location + .sorted_paths() + .iter() + .map(|path| path.compact()) + .collect::>(); + Self { id, path } + } +} diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 286a744569..06a84773ce 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -745,7 +745,7 @@ impl WorkspaceDb { conn.exec_bound(sql!( DELETE FROM pane_groups WHERE workspace_id = ?1; DELETE FROM panes WHERE workspace_id = ?1;))?(workspace.id) - .context("Clearing old panes")?; + .context("Clearing old panes")?; conn.exec_bound(sql!(DELETE FROM breakpoints WHERE workspace_id = ?1))?(workspace.id).context("Clearing old breakpoints")?; diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 5a9dce7c01..6c06489c44 100644 --- a/crates/workspace/src/workspace.rs +++ 
b/crates/workspace/src/workspace.rs @@ -1,4 +1,5 @@ pub mod dock; +pub mod history_manager; pub mod item; mod modal_layer; pub mod notifications; @@ -43,6 +44,7 @@ use gpui::{ WindowHandle, WindowId, WindowOptions, action_as, actions, canvas, impl_action_as, impl_actions, point, relative, size, transparent_black, }; +pub use history_manager::*; pub use item::{ FollowableItem, FollowableItemHandle, Item, ItemHandle, ItemSettings, PreviewTabsSettings, ProjectItem, SerializableItem, SerializableItemHandle, WeakItemHandle, @@ -102,7 +104,7 @@ use ui::prelude::*; use util::{ResultExt, TryFutureExt, paths::SanitizedPath, serde::default_true}; use uuid::Uuid; pub use workspace_settings::{ - AutosaveSetting, RestoreOnStartupBehavior, TabBarSettings, WorkspaceSettings, + AutosaveSetting, BottomDockLayout, RestoreOnStartupBehavior, TabBarSettings, WorkspaceSettings, }; use crate::notifications::NotificationId; @@ -387,6 +389,7 @@ pub fn init(app_state: Arc, cx: &mut App) { component::init(); theme_preview::init(cx); toast_layer::init(cx); + history_manager::init(cx); cx.on_action(Workspace::close_global); cx.on_action(reload); @@ -819,6 +822,7 @@ pub struct Workspace { center: PaneGroup, left_dock: Entity, bottom_dock: Entity, + bottom_dock_layout: BottomDockLayout, right_dock: Entity, panes: Vec>, panes_by_item: HashMap>, @@ -901,6 +905,9 @@ impl Workspace { project::Event::WorktreeRemoved(_) | project::Event::WorktreeAdded(_) => { this.update_window_title(window, cx); this.serialize_workspace(window, cx); + // This event could be triggered by `AddFolderToProject` or `RemoveFromProject`. + // So we need to update the history. 
+ this.update_history(cx); } project::Event::DisconnectedFromHost => { @@ -1044,6 +1051,7 @@ impl Workspace { let modal_layer = cx.new(|_| ModalLayer::new()); let toast_layer = cx.new(|_| ToastLayer::new()); + let bottom_dock_layout = WorkspaceSettings::get_global(cx).bottom_dock_layout; let left_dock = Dock::new(DockPosition::Left, modal_layer.clone(), window, cx); let bottom_dock = Dock::new(DockPosition::Bottom, modal_layer.clone(), window, cx); let right_dock = Dock::new(DockPosition::Right, modal_layer.clone(), window, cx); @@ -1141,6 +1149,7 @@ impl Workspace { notifications: Default::default(), left_dock, bottom_dock, + bottom_dock_layout, right_dock, project: project.clone(), follower_states: Default::default(), @@ -1331,7 +1340,10 @@ impl Workspace { .unwrap_or_default(); window - .update(cx, |_, window, _| window.activate_window()) + .update(cx, |workspace, window, cx| { + window.activate_window(); + workspace.update_history(cx); + }) .log_err(); Ok((window, opened_items)) }) @@ -1349,6 +1361,26 @@ impl Workspace { &self.bottom_dock } + pub fn bottom_dock_layout(&self) -> BottomDockLayout { + self.bottom_dock_layout + } + + pub fn set_bottom_dock_layout( + &mut self, + layout: BottomDockLayout, + window: &mut Window, + cx: &mut Context, + ) { + let fs = self.project().read(cx).fs(); + settings::update_settings_file::(fs.clone(), cx, move |content, _cx| { + content.bottom_dock_layout = Some(layout); + }); + + self.bottom_dock_layout = layout; + cx.notify(); + self.serialize_workspace(window, cx); + } + pub fn right_dock(&self) -> &Entity { &self.right_dock } @@ -4684,19 +4716,7 @@ impl Workspace { } } - let location = if let Some(ssh_project) = &self.serialized_ssh_project { - Some(SerializedWorkspaceLocation::Ssh(ssh_project.clone())) - } else if let Some(local_paths) = self.local_paths(cx) { - if !local_paths.is_empty() { - Some(SerializedWorkspaceLocation::from_local_paths(local_paths)) - } else { - None - } - } else { - None - }; - - if let 
Some(location) = location { + if let Some(location) = self.serialize_workspace_location(cx) { let breakpoints = self.project.update(cx, |project, cx| { project.breakpoint_store().read(cx).all_breakpoints(cx) }); @@ -4716,13 +4736,42 @@ impl Workspace { breakpoints, window_id: Some(window.window_handle().window_id().as_u64()), }; + return window.spawn(cx, async move |_| { - persistence::DB.save_workspace(serialized_workspace).await + persistence::DB.save_workspace(serialized_workspace).await; }); } Task::ready(()) } + fn serialize_workspace_location(&self, cx: &App) -> Option { + if let Some(ssh_project) = &self.serialized_ssh_project { + Some(SerializedWorkspaceLocation::Ssh(ssh_project.clone())) + } else if let Some(local_paths) = self.local_paths(cx) { + if !local_paths.is_empty() { + Some(SerializedWorkspaceLocation::from_local_paths(local_paths)) + } else { + None + } + } else { + None + } + } + + fn update_history(&self, cx: &mut App) { + let Some(id) = self.database_id() else { + return; + }; + let Some(location) = self.serialize_workspace_location(cx) else { + return; + }; + if let Some(manager) = HistoryManager::global(cx) { + manager.update(cx, |this, cx| { + this.update_history(id, HistoryManagerEntry::new(id, &location), cx); + }); + } + } + async fn serialize_items( this: &WeakEntity, items_rx: UnboundedReceiver>, @@ -5535,64 +5584,248 @@ impl Render for Workspace { }, )) }) - .child( - div() - .flex() - .flex_row() - .h_full() - // Left Dock - .children(self.render_dock( - DockPosition::Left, - &self.left_dock, - window, - cx, - )) - // Panes - .child( - div() - .flex() - .flex_col() - .flex_1() - .overflow_hidden() - .child( - h_flex() - .flex_1() - .when_some(paddings.0, |this, p| { - this.child(p.border_r_1()) - }) - .child(self.center.render( - self.zoomed.as_ref(), - &PaneRenderContext { - follower_states: - &self.follower_states, - active_call: self.active_call(), - active_pane: &self.active_pane, - app_state: &self.app_state, - project: 
&self.project, - workspace: &self.weak_self, - }, - window, - cx, - )) - .when_some(paddings.1, |this, p| { - this.child(p.border_l_1()) - }), - ) - .children(self.render_dock( - DockPosition::Bottom, - &self.bottom_dock, - window, - cx, - )), - ) - // Right Dock - .children(self.render_dock( - DockPosition::Right, - &self.right_dock, - window, - cx, - )), - ) + .child({ + match self.bottom_dock_layout { + BottomDockLayout::Full => div() + .flex() + .flex_col() + .h_full() + .child( + div() + .flex() + .flex_row() + .flex_1() + .overflow_hidden() + .children(self.render_dock( + DockPosition::Left, + &self.left_dock, + window, + cx, + )) + .child( + div() + .flex() + .flex_col() + .flex_1() + .overflow_hidden() + .child( + h_flex() + .flex_1() + .when_some( + paddings.0, + |this, p| { + this.child( + p.border_r_1(), + ) + }, + ) + .child(self.center.render( + self.zoomed.as_ref(), + &PaneRenderContext { + follower_states: + &self.follower_states, + active_call: self.active_call(), + active_pane: &self.active_pane, + app_state: &self.app_state, + project: &self.project, + workspace: &self.weak_self, + }, + window, + cx, + )) + .when_some( + paddings.1, + |this, p| { + this.child( + p.border_l_1(), + ) + }, + ), + ), + ) + .children(self.render_dock( + DockPosition::Right, + &self.right_dock, + window, + cx, + )), + ) + .child(div().w_full().children(self.render_dock( + DockPosition::Bottom, + &self.bottom_dock, + window, + cx + ))), + + BottomDockLayout::LeftAligned => div() + .flex() + .flex_row() + .h_full() + .child( + div() + .flex() + .flex_col() + .flex_1() + .h_full() + .child( + div() + .flex() + .flex_row() + .flex_1() + .children(self.render_dock(DockPosition::Left, &self.left_dock, window, cx)) + .child( + div() + .flex() + .flex_col() + .flex_1() + .overflow_hidden() + .child( + h_flex() + .flex_1() + .when_some(paddings.0, |this, p| this.child(p.border_r_1())) + .child(self.center.render( + self.zoomed.as_ref(), + &PaneRenderContext { + follower_states: 
+ &self.follower_states, + active_call: self.active_call(), + active_pane: &self.active_pane, + app_state: &self.app_state, + project: &self.project, + workspace: &self.weak_self, + }, + window, + cx, + )) + .when_some(paddings.1, |this, p| this.child(p.border_l_1())), + ) + ) + ) + .child( + div() + .w_full() + .children(self.render_dock(DockPosition::Bottom, &self.bottom_dock, window, cx)) + ), + ) + .children(self.render_dock( + DockPosition::Right, + &self.right_dock, + window, + cx, + )), + + BottomDockLayout::RightAligned => div() + .flex() + .flex_row() + .h_full() + .children(self.render_dock( + DockPosition::Left, + &self.left_dock, + window, + cx, + )) + .child( + div() + .flex() + .flex_col() + .flex_1() + .h_full() + .child( + div() + .flex() + .flex_row() + .flex_1() + .child( + div() + .flex() + .flex_col() + .flex_1() + .overflow_hidden() + .child( + h_flex() + .flex_1() + .when_some(paddings.0, |this, p| this.child(p.border_r_1())) + .child(self.center.render( + self.zoomed.as_ref(), + &PaneRenderContext { + follower_states: + &self.follower_states, + active_call: self.active_call(), + active_pane: &self.active_pane, + app_state: &self.app_state, + project: &self.project, + workspace: &self.weak_self, + }, + window, + cx, + )) + .when_some(paddings.1, |this, p| this.child(p.border_l_1())), + ) + ) + .children(self.render_dock(DockPosition::Right, &self.right_dock, window, cx)) + ) + .child( + div() + .w_full() + .children(self.render_dock(DockPosition::Bottom, &self.bottom_dock, window, cx)) + ), + ), + + BottomDockLayout::Contained => div() + .flex() + .flex_row() + .h_full() + .children(self.render_dock( + DockPosition::Left, + &self.left_dock, + window, + cx, + )) + .child( + div() + .flex() + .flex_col() + .flex_1() + .overflow_hidden() + .child( + h_flex() + .flex_1() + .when_some(paddings.0, |this, p| { + this.child(p.border_r_1()) + }) + .child(self.center.render( + self.zoomed.as_ref(), + &PaneRenderContext { + follower_states: + 
&self.follower_states, + active_call: self.active_call(), + active_pane: &self.active_pane, + app_state: &self.app_state, + project: &self.project, + workspace: &self.weak_self, + }, + window, + cx, + )) + .when_some(paddings.1, |this, p| { + this.child(p.border_l_1()) + }), + ) + .children(self.render_dock( + DockPosition::Bottom, + &self.bottom_dock, + window, + cx, + )), + ) + .children(self.render_dock( + DockPosition::Right, + &self.right_dock, + window, + cx, + )), + } + }) .children(self.zoomed.as_ref().and_then(|view| { let zoomed_view = view.upgrade()?; let div = div() @@ -6407,6 +6640,7 @@ async fn open_ssh_project_inner( let mut workspace = Workspace::new(Some(workspace_id), project, app_state.clone(), window, cx); workspace.set_serialized_ssh_project(serialized_ssh_project); + workspace.update_history(cx); workspace }); })?; diff --git a/crates/workspace/src/workspace_settings.rs b/crates/workspace/src/workspace_settings.rs index 2dda042288..a61a987b1c 100644 --- a/crates/workspace/src/workspace_settings.rs +++ b/crates/workspace/src/workspace_settings.rs @@ -10,6 +10,7 @@ use settings::{Settings, SettingsSources}; #[derive(Deserialize)] pub struct WorkspaceSettings { pub active_pane_modifiers: ActivePanelModifiers, + pub bottom_dock_layout: BottomDockLayout, pub pane_split_direction_horizontal: PaneSplitDirectionHorizontal, pub pane_split_direction_vertical: PaneSplitDirectionVertical, pub centered_layout: CenteredLayoutSettings, @@ -71,6 +72,20 @@ pub struct ActivePanelModifiers { pub inactive_opacity: Option, } +#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum BottomDockLayout { + /// Contained between the left and right docks + #[default] + Contained, + /// Takes up the full width of the window + Full, + /// Extends under the left dock while snapping to the right dock + LeftAligned, + /// Extends under the right dock while snapping to the left dock + RightAligned, +} 
+ #[derive(Copy, Clone, Default, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum CloseWindowWhenNoItems { @@ -109,6 +124,10 @@ pub enum RestoreOnStartupBehavior { pub struct WorkspaceSettingsContent { /// Active pane styling settings. pub active_pane_modifiers: Option, + /// Layout mode for the bottom dock + /// + /// Default: contained + pub bottom_dock_layout: Option, /// Direction to split horizontally. /// /// Default: "up" diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 920b89770b..fa4bf202d8 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -29,7 +29,7 @@ use ignore::IgnoreStack; use language::DiskState; use parking_lot::Mutex; -use paths::local_settings_folder_relative_path; +use paths::{local_settings_folder_relative_path, local_vscode_folder_relative_path}; use postage::{ barrier, prelude::{Sink as _, Stream as _}, @@ -384,12 +384,23 @@ struct LocalRepositoryEntry { work_directory: WorkDirectory, work_directory_abs_path: Arc, git_dir_scan_id: usize, - original_dot_git_abs_path: Arc, - /// Absolute path to the actual .git folder. - /// Note: if .git is a file, this points to the folder indicated by the .git file - dot_git_dir_abs_path: Arc, - /// Absolute path to the .git file, if we're in a git worktree. - dot_git_worktree_abs_path: Option>, + /// Absolute path to the original .git entry that caused us to create this repository. + /// + /// This is normally a directory, but may be a "gitfile" that points to a directory elsewhere + /// (whose path we then store in `repository_dir_abs_path`). + dot_git_abs_path: Arc, + /// Absolute path to the "commondir" for this repository. + /// + /// This is always a directory. For a normal repository, this is the same as dot_git_abs_path, + /// but in the case of a submodule or a worktree it is the path to the "parent" .git directory + /// from which the submodule/worktree was derived. 
+ common_dir_abs_path: Arc, + /// Absolute path to the directory holding the repository's state. + /// + /// For a normal repository, this is a directory and coincides with `dot_git_abs_path` and + /// `common_dir_abs_path`. For a submodule or worktree, this is some subdirectory of the + /// commondir like `/project/.git/modules/foo`. + repository_dir_abs_path: Arc, } impl sum_tree::Item for LocalRepositoryEntry { @@ -1351,7 +1362,11 @@ impl LocalWorktree { new_work_directory_abs_path: Some( new_repo.work_directory_abs_path.clone(), ), - dot_git_abs_path: Some(new_repo.original_dot_git_abs_path.clone()), + dot_git_abs_path: Some(new_repo.dot_git_abs_path.clone()), + repository_dir_abs_path: Some( + new_repo.repository_dir_abs_path.clone(), + ), + common_dir_abs_path: Some(new_repo.common_dir_abs_path.clone()), }); new_repos.next(); } @@ -1368,9 +1383,11 @@ impl LocalWorktree { new_work_directory_abs_path: Some( new_repo.work_directory_abs_path.clone(), ), - dot_git_abs_path: Some( - new_repo.original_dot_git_abs_path.clone(), + dot_git_abs_path: Some(new_repo.dot_git_abs_path.clone()), + repository_dir_abs_path: Some( + new_repo.repository_dir_abs_path.clone(), ), + common_dir_abs_path: Some(new_repo.common_dir_abs_path.clone()), }); } new_repos.next(); @@ -1384,6 +1401,8 @@ impl LocalWorktree { ), new_work_directory_abs_path: None, dot_git_abs_path: None, + repository_dir_abs_path: None, + common_dir_abs_path: None, }); old_repos.next(); } @@ -1394,7 +1413,9 @@ impl LocalWorktree { work_directory_id: entry_id, old_work_directory_abs_path: None, new_work_directory_abs_path: Some(repo.work_directory_abs_path.clone()), - dot_git_abs_path: Some(repo.original_dot_git_abs_path.clone()), + dot_git_abs_path: Some(repo.dot_git_abs_path.clone()), + repository_dir_abs_path: Some(repo.repository_dir_abs_path.clone()), + common_dir_abs_path: Some(repo.common_dir_abs_path.clone()), }); new_repos.next(); } @@ -1403,7 +1424,9 @@ impl LocalWorktree { work_directory_id: entry_id, 
old_work_directory_abs_path: Some(repo.work_directory_abs_path.clone()), new_work_directory_abs_path: None, - dot_git_abs_path: Some(repo.original_dot_git_abs_path.clone()), + dot_git_abs_path: Some(repo.dot_git_abs_path.clone()), + repository_dir_abs_path: Some(repo.repository_dir_abs_path.clone()), + common_dir_abs_path: Some(repo.common_dir_abs_path.clone()), }); old_repos.next(); } @@ -2842,6 +2865,7 @@ impl BackgroundScannerState { (!entry.is_external && (!entry.is_ignored || entry.is_always_included)) || entry.path.file_name() == Some(*DOT_GIT) || entry.path.file_name() == Some(local_settings_folder_relative_path().as_os_str()) + || entry.path.file_name() == Some(local_vscode_folder_relative_path().as_os_str()) || self.scanned_dirs.contains(&entry.id) // If we've ever scanned it, keep scanning || self .paths_to_scan @@ -3042,9 +3066,6 @@ impl BackgroundScannerState { ); return; }; - log::debug!( - "building git repository, `.git` path in the worktree: {dot_git_path:?}" - ); parent_dir.into() } @@ -3075,7 +3096,6 @@ impl BackgroundScannerState { fs: &dyn Fs, watcher: &dyn Watcher, ) -> Option { - log::trace!("insert git repository for {dot_git_path:?}"); let work_dir_entry = self.snapshot.entry_for_path(work_directory.path_key().0)?; let work_directory_abs_path = self .snapshot @@ -3092,46 +3112,51 @@ impl BackgroundScannerState { return None; } - let dot_git_abs_path = self.snapshot.abs_path.as_path().join(&dot_git_path); + let dot_git_abs_path: Arc = self + .snapshot + .abs_path + .as_path() + .join(&dot_git_path) + .as_path() + .into(); - // TODO add these watchers without building a whole repository by parsing .git-with-indirection - let t0 = Instant::now(); - let repository = fs.open_repo(&dot_git_abs_path)?; - log::trace!("opened git repo for {dot_git_abs_path:?}"); - - let repository_path = repository.path(); - watcher.add(&repository_path).log_err()?; - - let actual_dot_git_dir_abs_path = repository.main_repository_path(); - let 
dot_git_worktree_abs_path = if actual_dot_git_dir_abs_path == dot_git_abs_path { - None - } else { - // The two paths could be different because we opened a git worktree. - // When that happens: - // - // * `dot_git_abs_path` is a file that points to the worktree-subdirectory in the actual - // .git directory. - // - // * `repository_path` is the worktree-subdirectory. - // - // * `actual_dot_git_dir_abs_path` is the path to the actual .git directory. In git - // documentation this is called the "commondir". - watcher.add(&dot_git_abs_path).log_err()?; - Some(Arc::from(dot_git_abs_path.as_path())) + let mut common_dir_abs_path = dot_git_abs_path.clone(); + let mut repository_dir_abs_path = dot_git_abs_path.clone(); + // Parse .git if it's a "gitfile" pointing to a repository directory elsewhere. + if let Some(dot_git_contents) = smol::block_on(fs.load(&dot_git_abs_path)).ok() { + if let Some(path) = dot_git_contents.strip_prefix("gitdir:") { + let path = path.trim(); + let path = dot_git_abs_path + .parent() + .unwrap_or(Path::new("")) + .join(path); + if let Some(path) = smol::block_on(fs.canonicalize(&path)).log_err() { + repository_dir_abs_path = Path::new(&path).into(); + common_dir_abs_path = repository_dir_abs_path.clone(); + if let Some(ancestor_dot_git) = path + .ancestors() + .skip(1) + .find(|ancestor| smol::block_on(is_git_dir(ancestor, fs))) + { + common_dir_abs_path = ancestor_dot_git.into(); + } + } + } else { + log::error!("failed to parse contents of .git file: {dot_git_contents:?}"); + } }; - - log::trace!("constructed libgit2 repo in {:?}", t0.elapsed()); + watcher.add(&common_dir_abs_path).log_err(); let work_directory_id = work_dir_entry.id; let local_repository = LocalRepositoryEntry { work_directory_id, work_directory, - git_dir_scan_id: 0, - original_dot_git_abs_path: dot_git_abs_path.as_path().into(), - dot_git_dir_abs_path: actual_dot_git_dir_abs_path.into(), work_directory_abs_path: work_directory_abs_path.as_path().into(), - 
dot_git_worktree_abs_path, + git_dir_scan_id: 0, + dot_git_abs_path, + common_dir_abs_path, + repository_dir_abs_path, }; self.snapshot @@ -3454,6 +3479,8 @@ pub struct UpdatedGitRepository { /// For a normal git repository checkout, the absolute path to the .git directory. /// For a worktree, the absolute path to the worktree's subdirectory inside the .git directory. pub dot_git_abs_path: Option>, + pub repository_dir_abs_path: Option>, + pub common_dir_abs_path: Option>, } pub type UpdatedEntriesSet = Arc<[(Arc, ProjectEntryId, PathChange)]>; @@ -4010,8 +4037,8 @@ impl BackgroundScanner { if abs_path.0.file_name() == Some(*GITIGNORE) { for (_, repo) in snapshot.git_repositories.iter().filter(|(_, repo)| repo.directory_contains(&relative_path)) { - if !dot_git_abs_paths.iter().any(|dot_git_abs_path| dot_git_abs_path == repo.dot_git_dir_abs_path.as_ref()) { - dot_git_abs_paths.push(repo.dot_git_dir_abs_path.to_path_buf()); + if !dot_git_abs_paths.iter().any(|dot_git_abs_path| dot_git_abs_path == repo.common_dir_abs_path.as_ref()) { + dot_git_abs_paths.push(repo.common_dir_abs_path.to_path_buf()); } } } @@ -4070,7 +4097,6 @@ impl BackgroundScanner { } } self.send_status_update(false, SmallVec::new()); - // send_status_update_inner(phase, state, status_update_tx, false, SmallVec::new()); } async fn forcibly_load_paths(&self, paths: &[Arc]) -> bool { @@ -4709,8 +4735,8 @@ impl BackgroundScanner { .git_repositories .iter() .find_map(|(_, repo)| { - if repo.dot_git_dir_abs_path.as_ref() == &dot_git_dir - || repo.dot_git_worktree_abs_path.as_deref() == Some(&dot_git_dir) + if repo.common_dir_abs_path.as_ref() == &dot_git_dir + || repo.repository_dir_abs_path.as_ref() == &dot_git_dir { Some(repo.clone()) } else { @@ -4752,7 +4778,7 @@ impl BackgroundScanner { if exists_in_snapshot || matches!( - smol::block_on(self.fs.metadata(&entry.dot_git_dir_abs_path)), + smol::block_on(self.fs.metadata(&entry.common_dir_abs_path)), Ok(Some(_)) ) { @@ -5081,7 +5107,7 @@ impl 
WorktreeModelHandle for Entity { .unwrap(); ( tree.fs.clone(), - local_repo_entry.dot_git_dir_abs_path.clone(), + local_repo_entry.common_dir_abs_path.clone(), local_repo_entry.git_dir_scan_id, ) }); diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 66f8f2838b..ae253abdfc 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.182.0" +version = "0.183.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] @@ -108,7 +108,6 @@ session.workspace = true settings.workspace = true settings_ui.workspace = true shellexpand.workspace = true -simplelog.workspace = true smol.workspace = true snippet_provider.workspace = true snippets_ui.workspace = true diff --git a/crates/zed/src/logger.rs b/crates/zed/src/logger.rs deleted file mode 100644 index 5a889424d0..0000000000 --- a/crates/zed/src/logger.rs +++ /dev/null @@ -1,122 +0,0 @@ -use chrono::Offset; -use env_logger::Builder; -use log::LevelFilter; -use simplelog::ConfigBuilder; -use std::fs::{self, File, OpenOptions}; -use std::io::{self, Write}; -use time::UtcOffset; - -pub fn init_logger() { - let level = LevelFilter::Info; - - // Prevent log file from becoming too large. 
- const KIB: u64 = 1024; - const MIB: u64 = 1024 * KIB; - const MAX_LOG_BYTES: u64 = MIB; - if std::fs::metadata(paths::log_file()).map_or(false, |metadata| metadata.len() > MAX_LOG_BYTES) - { - let _ = std::fs::rename(paths::log_file(), paths::old_log_file()); - } - - match LogWriter::new(MAX_LOG_BYTES) { - Ok(writer) => { - let mut config_builder = ConfigBuilder::new(); - - config_builder.set_time_format_rfc3339(); - let local_offset = chrono::Local::now().offset().fix().local_minus_utc(); - if let Ok(offset) = UtcOffset::from_whole_seconds(local_offset) { - config_builder.set_time_offset(offset); - } - - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - { - config_builder.add_filter_ignore_str("zbus"); - config_builder.add_filter_ignore_str("blade_graphics::hal::resource"); - config_builder.add_filter_ignore_str("naga::back::spv::writer"); - } - - let config = config_builder.build(); - simplelog::WriteLogger::init(level, config, writer) - .expect("could not initialize logger"); - } - Err(err) => { - init_stdout_logger(); - log::error!( - "could not open log file, defaulting to stdout logging: {}", - err - ); - } - } -} - -pub fn init_stdout_logger() { - Builder::new() - .parse_default_env() - .format(|buf, record| { - use env_logger::fmt::style::{AnsiColor, Style}; - - let subtle = Style::new().fg_color(Some(AnsiColor::BrightBlack.into())); - write!(buf, "{subtle}[{subtle:#}")?; - write!( - buf, - "{} ", - chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z") - )?; - let level_style = buf.default_level_style(record.level()); - write!(buf, "{level_style}{:<5}{level_style:#}", record.level())?; - if let Some(path) = record.module_path() { - write!(buf, " {path}")?; - } - write!(buf, "{subtle}]{subtle:#}")?; - writeln!(buf, " {}", record.args()) - }) - .init(); -} - -struct LogWriter { - file: File, - max_size: u64, - current_size: u64, -} - -impl LogWriter { - fn new(max_size: u64) -> io::Result { - let file = OpenOptions::new() - .create(true) - 
.append(true) - .open(paths::log_file())?; - let current_size = file.metadata()?.len(); - - Ok(LogWriter { - file, - max_size, - current_size, - }) - } - - fn replace(&mut self) -> io::Result<()> { - self.file.sync_all()?; - fs::rename(paths::log_file(), paths::old_log_file())?; - self.file = OpenOptions::new() - .create(true) - .append(true) - .open(paths::log_file())?; - self.current_size = 0; - Ok(()) - } -} - -impl Write for LogWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - if self.current_size + buf.len() as u64 > self.max_size { - self.replace()?; - } - let bytes = self.file.write(buf)?; - self.current_size += bytes as u64; - Ok(bytes) - } - - fn flush(&mut self) -> io::Result<()> { - self.file.flush() - } -} diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index ba4d20b949..967f4aac14 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -1,7 +1,6 @@ // Disable command line from opening on release mode #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] -mod logger; mod reliability; mod zed; @@ -28,7 +27,6 @@ use prompt_store::PromptBuilder; use reqwest_client::ReqwestClient; use assets::Assets; -use logger::{init_logger, init_stdout_logger}; use node_runtime::{NodeBinaryOptions, NodeRuntime}; use parking_lot::Mutex; use project::project_settings::ProjectSettings; @@ -170,6 +168,16 @@ fn fail_to_open_window(e: anyhow::Error, _cx: &mut App) { } fn main() { + // Check if there is a pending installer + // If there is, run the installer and exit + // And we don't want to run the installer if we are not the first instance + #[cfg(target_os = "windows")] + let is_first_instance = crate::zed::windows_only_instance::is_first_instance(); + #[cfg(target_os = "windows")] + if is_first_instance && auto_update::check_pending_installation() { + return; + } + let args = Args::parse(); // Set custom data directory. 
@@ -195,11 +203,15 @@ fn main() { return; } - zlog::init_from_env(); + zlog::init(); if stdout_is_a_pty() { - init_stdout_logger(); + zlog::init_output_stdout(); } else { - init_logger(); + let result = zlog::init_output_file(paths::log_file(), Some(paths::old_log_file())); + if let Err(err) = result { + eprintln!("Could not open log file: {}... Defaulting to stdout", err); + zlog::init_output_stdout(); + }; } log::info!("========== starting zed =========="); @@ -234,27 +246,30 @@ fn main() { let (open_listener, mut open_rx) = OpenListener::new(); - let failed_single_instance_check = if *db::ZED_STATELESS - || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev - { - false - } else { - #[cfg(any(target_os = "linux", target_os = "freebsd"))] - { - crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() - } + let failed_single_instance_check = + if *db::ZED_STATELESS || *release_channel::RELEASE_CHANNEL == ReleaseChannel::Dev { + false + } else { + #[cfg(any(target_os = "linux", target_os = "freebsd"))] + { + crate::zed::listen_for_cli_connections(open_listener.clone()).is_err() + } - #[cfg(target_os = "windows")] - { - !crate::zed::windows_only_instance::check_single_instance(open_listener.clone(), &args) - } + #[cfg(target_os = "windows")] + { + !crate::zed::windows_only_instance::handle_single_instance( + open_listener.clone(), + &args, + is_first_instance, + ) + } - #[cfg(target_os = "macos")] - { - use zed::mac_only_instance::*; - ensure_only_instance() != IsOnlyInstance::Yes - } - }; + #[cfg(target_os = "macos")] + { + use zed::mac_only_instance::*; + ensure_only_instance() != IsOnlyInstance::Yes + } + }; if failed_single_instance_check { println!("zed is already running"); return; diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index d3d048e422..691de1edca 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -26,9 +26,9 @@ use git_ui::git_panel::GitPanel; use git_ui::project_diff::ProjectDiffToolbar; use gpui::{ 
Action, App, AppContext as _, AsyncWindowContext, Context, DismissEvent, Element, Entity, - Focusable, KeyBinding, MenuItem, ParentElement, PathPromptOptions, PromptLevel, ReadGlobal, - SharedString, Styled, Task, TitlebarOptions, UpdateGlobal, Window, WindowKind, WindowOptions, - actions, point, px, + Focusable, KeyBinding, ParentElement, PathPromptOptions, PromptLevel, ReadGlobal, SharedString, + Styled, Task, TitlebarOptions, UpdateGlobal, Window, WindowKind, WindowOptions, actions, point, + px, }; use image_viewer::ImageInfo; use migrate::{MigrationBanner, MigrationEvent, MigrationNotification, MigrationType}; @@ -1386,7 +1386,12 @@ fn reload_keymaps(cx: &mut App, user_key_bindings: Vec) { load_default_keymap(cx); cx.bind_keys(user_key_bindings); cx.set_menus(app_menus()); - cx.set_dock_menu(vec![MenuItem::action("New Window", workspace::NewWindow)]); + // On Windows, this is set in the `update_jump_list` method of the `HistoryManager`. + #[cfg(not(target_os = "windows"))] + cx.set_dock_menu(vec![gpui::MenuItem::action( + "New Window", + workspace::NewWindow, + )]); } pub fn load_default_keymap(cx: &mut App) { diff --git a/crates/zed/src/zed/windows_only_instance.rs b/crates/zed/src/zed/windows_only_instance.rs index 92295b5006..972cad38fe 100644 --- a/crates/zed/src/zed/windows_only_instance.rs +++ b/crates/zed/src/zed/windows_only_instance.rs @@ -25,7 +25,7 @@ use windows::{ use crate::{Args, OpenListener}; -pub fn check_single_instance(opener: OpenListener, args: &Args) -> bool { +pub fn is_first_instance() -> bool { unsafe { CreateMutexW( None, @@ -34,9 +34,11 @@ pub fn check_single_instance(opener: OpenListener, args: &Args) -> bool { ) .expect("Unable to create instance mutex.") }; - let first_instance = unsafe { GetLastError() } != ERROR_ALREADY_EXISTS; + unsafe { GetLastError() != ERROR_ALREADY_EXISTS } +} - if first_instance { +pub fn handle_single_instance(opener: OpenListener, args: &Args, is_first_instance: bool) -> bool { + if is_first_instance { 
// We are the first instance, listen for messages sent from other instances std::thread::spawn(move || with_pipe(|url| opener.open_urls(vec![url]))); } else if !args.foreground { @@ -44,7 +46,7 @@ pub fn check_single_instance(opener: OpenListener, args: &Args) -> bool { send_args_to_instance(args).log_err(); } - first_instance + is_first_instance } fn with_pipe(f: impl Fn(String)) { diff --git a/crates/zlog/Cargo.toml b/crates/zlog/Cargo.toml index b64b72633d..d0632d14f2 100644 --- a/crates/zlog/Cargo.toml +++ b/crates/zlog/Cargo.toml @@ -15,6 +15,10 @@ path = "src/zlog.rs" default = [] [dependencies] +chrono.workspace = true log.workspace = true workspace-hack.workspace = true anyhow.workspace = true + +[dev-dependencies] +tempfile.workspace = true diff --git a/crates/zlog/src/filter.rs b/crates/zlog/src/filter.rs new file mode 100644 index 0000000000..255f17a6d4 --- /dev/null +++ b/crates/zlog/src/filter.rs @@ -0,0 +1,568 @@ +use std::{ + collections::{HashMap, VecDeque}, + hash::{DefaultHasher, Hasher}, + sync::{ + OnceLock, RwLock, + atomic::{AtomicU8, Ordering}, + }, + usize, +}; + +use crate::{SCOPE_DEPTH_MAX, SCOPE_STRING_SEP_STR, Scope, ScopeAlloc, env_config}; + +use log; + +static ENV_FILTER: OnceLock = OnceLock::new(); +static SCOPE_MAP: RwLock> = RwLock::new(None); +struct GlobalScopeMap { + map: ScopeMap, + hash: u64, +} + +const LEVEL_ENABLED_MAX_DEFAULT: log::LevelFilter = log::LevelFilter::Info; +/// The maximum log level of verbosity that is enabled by default. +/// All messages more verbose than this level will be discarded +/// by default unless specially configured. 
+/// +/// This is used instead of the `log::max_level` as we need to tell the `log` +/// crate that the max level is everything, so that we can dynamically enable +/// logs that are more verbose than this level without the `log` crate throwing +/// them away before we see them +static mut LEVEL_ENABLED_MAX_STATIC: log::LevelFilter = LEVEL_ENABLED_MAX_DEFAULT; + +/// A cache of the true maximum log level that _could_ be printed. This is based +/// on the maximally verbose level that is configured by the user, and is used +/// to filter out logs more verbose than any configured level. +/// +/// E.g. if `LEVEL_ENABLED_MAX_STATIC `is 'info' but a user has configured some +/// scope to print at a `debug` level, then this will be `debug`, and all +/// `trace` logs will be discarded. +/// Therefore, it should always be `>= LEVEL_ENABLED_MAX_STATIC` +// PERF: this doesn't need to be an atomic, we don't actually care about race conditions here +static LEVEL_ENABLED_MAX_CONFIG: AtomicU8 = AtomicU8::new(LEVEL_ENABLED_MAX_DEFAULT as u8); + +pub fn init_env_filter(filter: env_config::EnvFilter) { + if let Some(level_max) = filter.level_global { + unsafe { LEVEL_ENABLED_MAX_STATIC = level_max } + } + if ENV_FILTER.set(filter).is_err() { + panic!("Environment filter cannot be initialized twice"); + } +} + +pub fn is_possibly_enabled_level(level: log::Level) -> bool { + return LEVEL_ENABLED_MAX_CONFIG.load(Ordering::Relaxed) <= level as u8; +} + +pub fn is_scope_enabled(scope: &Scope, level: log::Level) -> bool { + if level <= unsafe { LEVEL_ENABLED_MAX_STATIC } { + // [FAST PATH] + // if the message is at or below the minimum printed log level + // (where error < warn < info etc) then always enable + return true; + } + if !is_possibly_enabled_level(level) { + // [FAST PATH PT. 
2] + // if the message is above the maximum enabled log level + // (where error < warn < info etc) then disable without checking + // scope map + return false; + } + let global_scope_map = SCOPE_MAP.read().unwrap_or_else(|err| { + SCOPE_MAP.clear_poison(); + return err.into_inner(); + }); + + let Some(GlobalScopeMap { map, .. }) = global_scope_map.as_ref() else { + // on failure, return false because it's not <= LEVEL_ENABLED_MAX_STATIC + return false; + }; + + if map.is_empty() { + // if no scopes are enabled, return false because it's not <= LEVEL_ENABLED_MAX_STATIC + return false; + } + let enabled_status = map.is_enabled(&scope, level); + return match enabled_status { + // if it isn't configured, then it it's disabled because it's not <= LEVEL_ENABLED_MAX_STATIC + EnabledStatus::NotConfigured => false, + EnabledStatus::Enabled => true, + EnabledStatus::Disabled => false, + }; +} + +fn hash_scope_map_settings(map: &HashMap) -> u64 { + let mut hasher = DefaultHasher::new(); + let mut items = map.iter().collect::>(); + items.sort(); + for (key, value) in items { + Hasher::write(&mut hasher, key.as_bytes()); + Hasher::write(&mut hasher, value.as_bytes()); + } + return hasher.finish(); +} + +pub(crate) fn refresh() { + refresh_from_settings(&HashMap::default()); +} + +pub fn refresh_from_settings(settings: &HashMap) { + let hash_old = { + SCOPE_MAP + .read() + .unwrap_or_else(|err| { + SCOPE_MAP.clear_poison(); + err.into_inner() + }) + .as_ref() + .map(|scope_map| scope_map.hash) + }; + let hash_new = hash_scope_map_settings(settings); + if hash_old == Some(hash_new) { + return; + } + let env_config = ENV_FILTER.get(); + let map_new = ScopeMap::new_from_settings_and_env(settings, env_config); + let mut level_enabled_max = unsafe { LEVEL_ENABLED_MAX_STATIC }; + for entry in &map_new.entries { + if let Some(level) = entry.enabled { + level_enabled_max = level_enabled_max.max(level.to_level_filter()); + } + } + LEVEL_ENABLED_MAX_CONFIG.store(level_enabled_max as u8, 
Ordering::Release); + + let mut global_map = SCOPE_MAP.write().unwrap_or_else(|err| { + SCOPE_MAP.clear_poison(); + err.into_inner() + }); + global_map.replace(GlobalScopeMap { + map: map_new, + hash: hash_new, + }); +} + +fn level_from_level_str(level_str: &String) -> Option { + let level = match level_str.to_ascii_lowercase().as_str() { + "" => log::Level::Trace, + "trace" => log::Level::Trace, + "debug" => log::Level::Debug, + "info" => log::Level::Info, + "warn" => log::Level::Warn, + "error" => log::Level::Error, + "off" | "disable" | "no" | "none" | "disabled" => { + crate::warn!( + "Invalid log level \"{level_str}\", set to error to disable non-error logging. Defaulting to error" + ); + log::Level::Error + } + _ => { + crate::warn!("Invalid log level \"{level_str}\", ignoring"); + return None; + } + }; + return Some(level); +} + +fn scope_alloc_from_scope_str(scope_str: &String) -> Option { + let mut scope_buf = [""; SCOPE_DEPTH_MAX]; + let mut index = 0; + let mut scope_iter = scope_str.split(SCOPE_STRING_SEP_STR); + while index < SCOPE_DEPTH_MAX { + let Some(scope) = scope_iter.next() else { + break; + }; + if scope == "" { + continue; + } + scope_buf[index] = scope; + index += 1; + } + if index == 0 { + return None; + } + if let Some(_) = scope_iter.next() { + crate::warn!( + "Invalid scope key, too many nested scopes: '{scope_str}'. 
Max depth is {SCOPE_DEPTH_MAX}", + ); + return None; + } + let scope = scope_buf.map(|s| s.to_string()); + return Some(scope); +} + +pub struct ScopeMap { + entries: Vec, + root_count: usize, +} + +pub struct ScopeMapEntry { + scope: String, + enabled: Option, + descendants: std::ops::Range, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EnabledStatus { + Enabled, + Disabled, + NotConfigured, +} + +impl ScopeMap { + pub fn new_from_settings_and_env( + items_input_map: &HashMap, + env_config: Option<&env_config::EnvFilter>, + ) -> Self { + let mut items = Vec::with_capacity( + items_input_map.len() + env_config.map_or(0, |c| c.directive_names.len()), + ); + if let Some(env_filter) = env_config { + // TODO: parse on load instead of every reload + items.extend( + env_filter + .directive_names + .iter() + .zip(env_filter.directive_levels.iter()) + .filter_map(|(scope, level_filter)| { + if items_input_map.get(scope).is_some() { + return None; + } + let scope = scope_alloc_from_scope_str(scope)?; + // TODO: use level filters instead of scopes in scope map + let level = level_filter.to_level()?; + + Some((scope, level)) + }), + ); + } + items.extend( + items_input_map + .into_iter() + .filter_map(|(scope_str, level_str)| { + let scope = scope_alloc_from_scope_str(&scope_str)?; + let level = level_from_level_str(&level_str)?; + return Some((scope, level)); + }), + ); + + items.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut this = Self { + entries: Vec::with_capacity(items.len() * SCOPE_DEPTH_MAX), + root_count: 0, + }; + + let items_count = items.len(); + + struct ProcessQueueEntry { + parent_index: usize, + depth: usize, + items_range: std::ops::Range, + } + let mut process_queue = VecDeque::new(); + process_queue.push_back(ProcessQueueEntry { + parent_index: usize::MAX, + depth: 0, + items_range: 0..items_count, + }); + + let empty_range = 0..0; + + while let Some(process_entry) = process_queue.pop_front() { + let ProcessQueueEntry { + items_range, + depth, + 
parent_index, + } = process_entry; + let mut cursor = items_range.start; + let res_entries_start = this.entries.len(); + while cursor < items_range.end { + let sub_items_start = cursor; + cursor += 1; + let scope_name = &items[sub_items_start].0[depth]; + while cursor < items_range.end && &items[cursor].0[depth] == scope_name { + cursor += 1; + } + let sub_items_end = cursor; + if scope_name == "" { + assert_eq!(sub_items_start + 1, sub_items_end); + assert_ne!(depth, 0); + assert_ne!(parent_index, usize::MAX); + assert!(this.entries[parent_index].enabled.is_none()); + this.entries[parent_index].enabled = Some(items[sub_items_start].1); + continue; + } + let is_valid_scope = scope_name != ""; + let is_last = depth + 1 == SCOPE_DEPTH_MAX || !is_valid_scope; + let mut enabled = None; + if is_last { + assert_eq!( + sub_items_start + 1, + sub_items_end, + "Expected one item: got: {:?}", + &items[items_range.clone()] + ); + enabled = Some(items[sub_items_start].1); + } else { + let entry_index = this.entries.len(); + process_queue.push_back(ProcessQueueEntry { + items_range: sub_items_start..sub_items_end, + parent_index: entry_index, + depth: depth + 1, + }); + } + this.entries.push(ScopeMapEntry { + scope: scope_name.to_owned(), + enabled, + descendants: empty_range.clone(), + }); + } + let res_entries_end = this.entries.len(); + if parent_index != usize::MAX { + this.entries[parent_index].descendants = res_entries_start..res_entries_end; + } else { + this.root_count = res_entries_end; + } + } + + return this; + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } + + pub fn is_enabled(&self, scope: &[S; SCOPE_DEPTH_MAX], level: log::Level) -> EnabledStatus + where + S: AsRef, + { + let mut enabled = None; + let mut cur_range = &self.entries[0..self.root_count]; + let mut depth = 0; + + 'search: while !cur_range.is_empty() + && depth < SCOPE_DEPTH_MAX + && scope[depth].as_ref() != "" + { + for entry in cur_range { + if entry.scope == 
scope[depth].as_ref() { + // note: + enabled = entry.enabled.or(enabled); + cur_range = &self.entries[entry.descendants.clone()]; + depth += 1; + continue 'search; + } + } + break 'search; + } + + return enabled.map_or(EnabledStatus::NotConfigured, |level_enabled| { + if level <= level_enabled { + EnabledStatus::Enabled + } else { + EnabledStatus::Disabled + } + }); + } +} + +#[cfg(test)] +mod tests { + use crate::private::scope_new; + + use super::*; + + fn scope_map_from_keys(kv: &[(&str, &str)]) -> ScopeMap { + let hash_map: HashMap = kv + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + ScopeMap::new_from_settings_and_env(&hash_map, None) + } + + #[test] + fn test_initialization() { + let map = scope_map_from_keys(&[("a.b.c.d", "trace")]); + assert_eq!(map.root_count, 1); + assert_eq!(map.entries.len(), 4); + + let map = scope_map_from_keys(&[]); + assert_eq!(map.root_count, 0); + assert_eq!(map.entries.len(), 0); + + let map = scope_map_from_keys(&[("", "trace")]); + assert_eq!(map.root_count, 0); + assert_eq!(map.entries.len(), 0); + + let map = scope_map_from_keys(&[("foo..bar", "trace")]); + assert_eq!(map.root_count, 1); + assert_eq!(map.entries.len(), 2); + + let map = scope_map_from_keys(&[ + ("a.b.c.d", "trace"), + ("e.f.g.h", "debug"), + ("i.j.k.l", "info"), + ("m.n.o.p", "warn"), + ("q.r.s.t", "error"), + ]); + assert_eq!(map.root_count, 5); + assert_eq!(map.entries.len(), 20); + assert_eq!(map.entries[0].scope, "a"); + assert_eq!(map.entries[1].scope, "e"); + assert_eq!(map.entries[2].scope, "i"); + assert_eq!(map.entries[3].scope, "m"); + assert_eq!(map.entries[4].scope, "q"); + } + + fn scope_from_scope_str(scope_str: &'static str) -> Scope { + let mut scope_buf = [""; SCOPE_DEPTH_MAX]; + let mut index = 0; + let mut scope_iter = scope_str.split(SCOPE_STRING_SEP_STR); + while index < SCOPE_DEPTH_MAX { + let Some(scope) = scope_iter.next() else { + break; + }; + if scope == "" { + continue; + } + scope_buf[index] = scope; + 
index += 1; + } + assert_ne!(index, 0); + assert!(scope_iter.next().is_none()); + return scope_buf; + } + + #[test] + fn test_is_enabled() { + let map = scope_map_from_keys(&[ + ("a.b.c.d", "trace"), + ("e.f.g.h", "debug"), + ("i.j.k.l", "info"), + ("m.n.o.p", "warn"), + ("q.r.s.t", "error"), + ]); + use log::Level; + assert_eq!( + map.is_enabled(&scope_from_scope_str("a.b.c.d"), Level::Trace), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("a.b.c.d"), Level::Debug), + EnabledStatus::Enabled + ); + + assert_eq!( + map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Debug), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Info), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Trace), + EnabledStatus::Disabled + ); + + assert_eq!( + map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Info), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Warn), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Debug), + EnabledStatus::Disabled + ); + + assert_eq!( + map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Warn), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Error), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Info), + EnabledStatus::Disabled + ); + + assert_eq!( + map.is_enabled(&scope_from_scope_str("q.r.s.t"), Level::Error), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_from_scope_str("q.r.s.t"), Level::Warn), + EnabledStatus::Disabled + ); + } + + fn scope_map_from_keys_and_env(kv: &[(&str, &str)], env: &env_config::EnvFilter) -> ScopeMap { + let hash_map: HashMap = kv + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + 
ScopeMap::new_from_settings_and_env(&hash_map, Some(env)) + } + + #[test] + fn test_initialization_with_env() { + let env_filter = env_config::parse("a.b=debug,u=error").unwrap(); + let map = scope_map_from_keys_and_env(&[], &env_filter); + assert_eq!(map.root_count, 2); + assert_eq!(map.entries.len(), 3); + assert_eq!( + map.is_enabled(&scope_new(&["a"]), log::Level::Debug), + EnabledStatus::NotConfigured + ); + assert_eq!( + map.is_enabled(&scope_new(&["a", "b"]), log::Level::Debug), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_new(&["a", "b", "c"]), log::Level::Trace), + EnabledStatus::Disabled + ); + + let env_filter = env_config::parse("a.b=debug,e.f.g.h=trace,u=error").unwrap(); + let map = scope_map_from_keys_and_env( + &[ + ("a.b.c.d", "trace"), + ("e.f.g.h", "debug"), + ("i.j.k.l", "info"), + ("m.n.o.p", "warn"), + ("q.r.s.t", "error"), + ], + &env_filter, + ); + assert_eq!(map.root_count, 6); + assert_eq!(map.entries.len(), 21); + assert_eq!(map.entries[0].scope, "a"); + assert_eq!(map.entries[1].scope, "e"); + assert_eq!(map.entries[2].scope, "i"); + assert_eq!(map.entries[3].scope, "m"); + assert_eq!(map.entries[4].scope, "q"); + assert_eq!(map.entries[5].scope, "u"); + assert_eq!( + map.is_enabled(&scope_new(&["a", "b", "c", "d"]), log::Level::Trace), + EnabledStatus::Enabled + ); + assert_eq!( + map.is_enabled(&scope_new(&["a", "b", "c"]), log::Level::Trace), + EnabledStatus::Disabled + ); + assert_eq!( + map.is_enabled(&scope_new(&["u", "v"]), log::Level::Warn), + EnabledStatus::Disabled + ); + // settings override env + assert_eq!( + map.is_enabled(&scope_new(&["e", "f", "g", "h"]), log::Level::Trace), + EnabledStatus::Disabled, + ); + } +} diff --git a/crates/zlog/src/sink.rs b/crates/zlog/src/sink.rs new file mode 100644 index 0000000000..6a2a041b2a --- /dev/null +++ b/crates/zlog/src/sink.rs @@ -0,0 +1,257 @@ +use std::{ + fs, + io::{self, Write}, + path::PathBuf, + sync::{ + Mutex, OnceLock, + atomic::{AtomicU64, 
Ordering}, + }, +}; + +use crate::{SCOPE_STRING_SEP_CHAR, Scope}; + +// ANSI color escape codes for log levels +const ANSI_RESET: &str = "\x1b[0m"; +const ANSI_BOLD: &str = "\x1b[1m"; +const ANSI_RED: &str = "\x1b[31m"; +const ANSI_YELLOW: &str = "\x1b[33m"; +const ANSI_GREEN: &str = "\x1b[32m"; +const ANSI_BLUE: &str = "\x1b[34m"; +const ANSI_MAGENTA: &str = "\x1b[35m"; + +/// Whether stdout output is enabled. +static mut ENABLED_SINKS_STDOUT: bool = false; + +/// Is Some(file) if file output is enabled. +static ENABLED_SINKS_FILE: Mutex> = Mutex::new(None); +static SINK_FILE_PATH: OnceLock<&'static PathBuf> = OnceLock::new(); +static SINK_FILE_PATH_ROTATE: OnceLock<&'static PathBuf> = OnceLock::new(); +/// Atomic counter for the size of the log file in bytes. +// TODO: make non-atomic if writing single threaded +static SINK_FILE_SIZE_BYTES: AtomicU64 = AtomicU64::new(0); +/// Maximum size of the log file before it will be rotated, in bytes. +const SINK_FILE_SIZE_BYTES_MAX: u64 = 1024 * 1024; // 1 MB + +pub fn init_output_stdout() { + unsafe { + ENABLED_SINKS_STDOUT = true; + } +} + +pub fn init_output_file( + path: &'static PathBuf, + path_rotate: Option<&'static PathBuf>, +) -> io::Result<()> { + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(path)?; + + SINK_FILE_PATH + .set(path) + .expect("Init file output should only be called once"); + if let Some(path_rotate) = path_rotate { + SINK_FILE_PATH_ROTATE + .set(path_rotate) + .expect("Init file output should only be called once"); + } + + let mut enabled_sinks_file = ENABLED_SINKS_FILE + .try_lock() + .expect("Log file lock is available during init"); + + let size_bytes = file.metadata().map_or(0, |metadata| metadata.len()); + if size_bytes >= SINK_FILE_SIZE_BYTES_MAX { + rotate_log_file(&mut file, Some(path), path_rotate, &SINK_FILE_SIZE_BYTES); + } else { + SINK_FILE_SIZE_BYTES.store(size_bytes, Ordering::Relaxed); + } + + *enabled_sinks_file = Some(file); + + Ok(()) +} + 
+const LEVEL_OUTPUT_STRINGS: [&str; 6] = [ + " ", // nop: ERROR = 1 + "ERROR", // + "WARN ", // + "INFO ", // + "DEBUG", // + "TRACE", // +]; + +// Colors for different log levels +static LEVEL_ANSI_COLORS: [&str; 6] = [ + "", // nop + ANSI_RED, // Error: Red + ANSI_YELLOW, // Warn: Yellow + ANSI_GREEN, // Info: Green + ANSI_BLUE, // Debug: Blue + ANSI_MAGENTA, // Trace: Magenta +]; + +pub fn submit(record: Record) { + if unsafe { ENABLED_SINKS_STDOUT } { + let mut stdout = std::io::stdout().lock(); + _ = writeln!( + &mut stdout, + "{} {}{}{}{} {}[{}]{} {}", + chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), + ANSI_BOLD, + LEVEL_ANSI_COLORS[record.level as usize], + LEVEL_OUTPUT_STRINGS[record.level as usize], + ANSI_RESET, + ANSI_BOLD, + ScopeFmt(record.scope), + ANSI_RESET, + record.message + ); + } + let mut file = ENABLED_SINKS_FILE.lock().unwrap_or_else(|handle| { + ENABLED_SINKS_FILE.clear_poison(); + handle.into_inner() + }); + if let Some(file) = file.as_mut() { + struct SizedWriter<'a> { + file: &'a mut std::fs::File, + written: u64, + } + impl io::Write for SizedWriter<'_> { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.file.write(buf)?; + self.written += buf.len() as u64; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + self.file.flush() + } + } + let file_size_bytes = { + let mut writer = SizedWriter { file, written: 0 }; + _ = writeln!( + &mut writer, + "{} {} [{}] {}", + chrono::Local::now().format("%Y-%m-%dT%H:%M:%S%:z"), + LEVEL_OUTPUT_STRINGS[record.level as usize], + ScopeFmt(record.scope), + record.message + ); + SINK_FILE_SIZE_BYTES.fetch_add(writer.written, Ordering::Relaxed) + writer.written + }; + if file_size_bytes > SINK_FILE_SIZE_BYTES_MAX { + rotate_log_file( + file, + SINK_FILE_PATH.get(), + SINK_FILE_PATH_ROTATE.get(), + &SINK_FILE_SIZE_BYTES, + ); + } + } +} + +pub fn flush() { + _ = std::io::stdout().lock().flush(); +} + +struct ScopeFmt(Scope); + +impl std::fmt::Display for ScopeFmt { + fn fmt(&self, 
f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use std::fmt::Write; + f.write_str(self.0[0])?; + for scope in &self.0[1..] { + if !scope.is_empty() { + f.write_char(SCOPE_STRING_SEP_CHAR)?; + } + f.write_str(scope)?; + } + Ok(()) + } +} + +pub struct Record<'a> { + pub scope: Scope, + pub level: log::Level, + pub message: &'a std::fmt::Arguments<'a>, +} + +fn rotate_log_file( + file: &mut fs::File, + path: Option, + path_rotate: Option, + atomic_size: &AtomicU64, +) where + PathRef: AsRef, +{ + if let Err(err) = file.flush() { + eprintln!( + "Failed to flush log file before rotating, some logs may be lost: {}", + err + ); + } + let rotation_error = match (path, path_rotate) { + (Some(_), None) => Some(anyhow::anyhow!("No rotation log file path configured")), + (None, _) => Some(anyhow::anyhow!("No log file path configured")), + (Some(path), Some(path_rotate)) => fs::copy(path, path_rotate) + .err() + .map(|err| anyhow::anyhow!(err)), + }; + if let Some(err) = rotation_error { + eprintln!( + "Log file rotation failed. 
Truncating log file anyways: {}", + err, + ); + } + _ = file.set_len(0); + + // SAFETY: It is safe to set size to 0 even if set_len fails as + // according to the documentation, it only fails if: + // - the file is not writeable: should never happen, + // - the size would cause an overflow (implementation specific): 0 should never cause an overflow + atomic_size.store(0, Ordering::Relaxed); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rotate_log_file() { + let temp_dir = tempfile::tempdir().unwrap(); + let log_file_path = temp_dir.path().join("log.txt"); + let rotation_log_file_path = temp_dir.path().join("log_rotated.txt"); + + let mut file = fs::File::create(&log_file_path).unwrap(); + let contents = String::from("Hello, world!"); + file.write_all(contents.as_bytes()).unwrap(); + + let size = AtomicU64::new(contents.len() as u64); + + rotate_log_file( + &mut file, + Some(&log_file_path), + Some(&rotation_log_file_path), + &size, + ); + + assert!(log_file_path.exists()); + assert_eq!(log_file_path.metadata().unwrap().len(), 0); + assert!(rotation_log_file_path.exists()); + assert_eq!( + std::fs::read_to_string(&rotation_log_file_path).unwrap(), + contents, + ); + assert_eq!(size.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_log_level_names() { + assert_eq!(LEVEL_OUTPUT_STRINGS[log::Level::Error as usize], "ERROR"); + assert_eq!(LEVEL_OUTPUT_STRINGS[log::Level::Warn as usize], "WARN "); + assert_eq!(LEVEL_OUTPUT_STRINGS[log::Level::Info as usize], "INFO "); + assert_eq!(LEVEL_OUTPUT_STRINGS[log::Level::Debug as usize], "DEBUG"); + assert_eq!(LEVEL_OUTPUT_STRINGS[log::Level::Trace as usize], "TRACE"); + } +} diff --git a/crates/zlog/src/zlog.rs b/crates/zlog/src/zlog.rs index b953409223..9191335e41 100644 --- a/crates/zlog/src/zlog.rs +++ b/crates/zlog/src/zlog.rs @@ -2,18 +2,27 @@ pub use log as log_impl; mod env_config; +pub mod filter; +pub mod sink; + +pub use sink::{init_output_file, init_output_stdout}; pub const 
SCOPE_DEPTH_MAX: usize = 4; -pub fn init_from_env() { +pub fn init() { + process_env(); + log::set_logger(&ZLOG).expect("Logger should not be initialized twice"); + log::set_max_level(log::LevelFilter::max()); +} + +pub fn process_env() { let Ok(env_config) = std::env::var("ZED_LOG").or_else(|_| std::env::var("RUST_LOG")) else { return; }; match env_config::parse(&env_config) { Ok(filter) => { - scope_map::init_env_filter(filter); - scope_map::refresh(); - // TODO: set max level once removing `env_logger` and `simple_log` crates + filter::init_env_filter(filter); + filter::refresh(); } Err(err) => { eprintln!("Failed to parse log filter: {}", err); @@ -21,25 +30,43 @@ pub fn init_from_env() { } } -/// because we are currently just wrapping the `log` crate in `zlog`, -/// we need to work around the fact that the `log` crate only provides a -/// single global level filter. In order to have more precise control until -/// we no longer wrap `log`, we bump up the priority of log level so that it -/// will be logged, even if the actual level is lower -/// This is fine for now, as we use a `info` level filter by default in releases, -/// which hopefully won't result in confusion like `warn` or `error` levels might. 
-pub fn min_printed_log_level(level: log_impl::Level) -> log_impl::Level { - // this logic is defined based on the logic used in the `log` crate, - // which checks that a logs level is <= both of these values, - // so we take the minimum of the two values to ensure that check passes - let level_min_static = log_impl::STATIC_MAX_LEVEL; - let level_min_dynamic = log_impl::max_level(); - if level <= level_min_static && level <= level_min_dynamic { - return level; +static ZLOG: Zlog = Zlog {}; + +pub struct Zlog {} + +impl log::Log for Zlog { + fn enabled(&self, metadata: &log::Metadata) -> bool { + filter::is_possibly_enabled_level(metadata.level()) + } + + fn log(&self, record: &log::Record) { + if !self.enabled(record.metadata()) { + return; + } + let scope = match record.module_path_static() { + Some(module_path) => { + // TODO: better module name -> scope translation + let crate_name = private::extract_crate_name_from_module_path(module_path); + private::scope_new(&[crate_name]) + } + // TODO: when do we hit this + None => private::scope_new(&["*unknown*"]), + }; + let level = record.metadata().level(); + if !filter::is_scope_enabled(&scope, level) { + return; + } + sink::submit(sink::Record { + scope, + level, + message: record.args(), + }); + } + + fn flush(&self) { + // todo: necessary? + sink::flush(); } - return log_impl::LevelFilter::min(level_min_static, level_min_dynamic) - .to_level() - .unwrap_or(level); } #[macro_export] @@ -47,9 +74,13 @@ macro_rules! 
log { ($logger:expr, $level:expr, $($arg:tt)+) => { let level = $level; let logger = $logger; - let (enabled, level) = $crate::scope_map::is_scope_enabled(&logger.scope, level); + let enabled = $crate::filter::is_scope_enabled(&logger.scope, level); if enabled { - $crate::log_impl::log!(level, "[{}]: {}", &logger.fmt_scope(), format!($($arg)+)); + $crate::sink::submit($crate::sink::Record { + scope: logger.scope, + level, + message: &format_args!($($arg)+), + }); } } } @@ -205,27 +236,14 @@ pub mod private { pub type Scope = [&'static str; SCOPE_DEPTH_MAX]; pub type ScopeAlloc = [String; SCOPE_DEPTH_MAX]; -const SCOPE_STRING_SEP: &'static str = "."; +const SCOPE_STRING_SEP_STR: &'static str = "."; +const SCOPE_STRING_SEP_CHAR: char = '.'; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct Logger { pub scope: Scope, } -impl Logger { - pub fn fmt_scope(&self) -> String { - let mut last = 0; - for s in self.scope { - if s.is_empty() { - break; - } - last += 1; - } - - return self.scope[0..last].join(SCOPE_STRING_SEP); - } -} - pub struct Timer { pub logger: Logger, pub start_time: std::time::Instant, @@ -288,543 +306,6 @@ impl Timer { } } -pub mod scope_map { - use std::{ - collections::{HashMap, VecDeque}, - hash::{DefaultHasher, Hasher}, - sync::{ - OnceLock, RwLock, - atomic::{AtomicU64, Ordering}, - }, - usize, - }; - - use super::*; - static ENV_FILTER: OnceLock = OnceLock::new(); - static SCOPE_MAP: RwLock> = RwLock::new(None); - static SCOPE_MAP_HASH: AtomicU64 = AtomicU64::new(0); - - pub fn init_env_filter(filter: env_config::EnvFilter) { - if ENV_FILTER.set(filter).is_err() { - panic!("Environment filter cannot be initialized twice"); - } - } - - pub fn is_scope_enabled(scope: &Scope, level: log_impl::Level) -> (bool, log_impl::Level) { - let level_min = min_printed_log_level(level); - if level <= level_min { - // [FAST PATH] - // if the message is at or below the minimum printed log level - // (where error < warn < info etc) then always enable - return 
(true, level); - } - - let Ok(map) = SCOPE_MAP.read() else { - // on failure, default to enabled detection done by `log` crate - return (true, level); - }; - - let Some(map) = map.as_ref() else { - // on failure, default to enabled detection done by `log` crate - return (true, level); - }; - - if map.is_empty() { - // if no scopes are enabled, default to enabled detection done by `log` crate - return (true, level); - } - let enabled_status = map.is_enabled(&scope, level); - match enabled_status { - EnabledStatus::NotConfigured => { - // if this scope isn't configured, default to enabled detection done by `log` crate - return (true, level); - } - EnabledStatus::Enabled => { - // if this scope is enabled, enable logging - // note: bumping level to min level that will be printed - // to work around log crate limitations - return (true, level_min); - } - EnabledStatus::Disabled => { - // if the configured level is lower than the requested level, disable logging - // note: err = 0, warn = 1, etc. 
- return (false, level); - } - } - } - - fn hash_scope_map_settings(map: &HashMap) -> u64 { - let mut hasher = DefaultHasher::new(); - let mut items = map.iter().collect::>(); - items.sort(); - for (key, value) in items { - Hasher::write(&mut hasher, key.as_bytes()); - Hasher::write(&mut hasher, value.as_bytes()); - } - return hasher.finish(); - } - - pub(crate) fn refresh() { - refresh_from_settings(&HashMap::default()); - } - - pub fn refresh_from_settings(settings: &HashMap) { - let hash_old = SCOPE_MAP_HASH.load(Ordering::Acquire); - let hash_new = hash_scope_map_settings(settings); - if hash_old == hash_new && hash_old != 0 { - return; - } - let env_config = ENV_FILTER.get(); - let map_new = ScopeMap::new_from_settings_and_env(settings, env_config); - - if let Ok(_) = SCOPE_MAP_HASH.compare_exchange( - hash_old, - hash_new, - Ordering::Release, - Ordering::Relaxed, - ) { - let mut map = SCOPE_MAP.write().unwrap_or_else(|err| { - SCOPE_MAP.clear_poison(); - err.into_inner() - }); - *map = Some(map_new); - } - } - - fn level_from_level_str(level_str: &String) -> Option { - let level = match level_str.to_ascii_lowercase().as_str() { - "" => log_impl::Level::Trace, - "trace" => log_impl::Level::Trace, - "debug" => log_impl::Level::Debug, - "info" => log_impl::Level::Info, - "warn" => log_impl::Level::Warn, - "error" => log_impl::Level::Error, - "off" | "disable" | "no" | "none" | "disabled" => { - crate::warn!( - "Invalid log level \"{level_str}\", set to error to disable non-error logging. 
Defaulting to error" - ); - log_impl::Level::Error - } - _ => { - crate::warn!("Invalid log level \"{level_str}\", ignoring"); - return None; - } - }; - return Some(level); - } - - fn scope_alloc_from_scope_str(scope_str: &String) -> Option { - let mut scope_buf = [""; SCOPE_DEPTH_MAX]; - let mut index = 0; - let mut scope_iter = scope_str.split(SCOPE_STRING_SEP); - while index < SCOPE_DEPTH_MAX { - let Some(scope) = scope_iter.next() else { - break; - }; - if scope == "" { - continue; - } - scope_buf[index] = scope; - index += 1; - } - if index == 0 { - return None; - } - if let Some(_) = scope_iter.next() { - crate::warn!( - "Invalid scope key, too many nested scopes: '{scope_str}'. Max depth is {SCOPE_DEPTH_MAX}", - ); - return None; - } - let scope = scope_buf.map(|s| s.to_string()); - return Some(scope); - } - - pub struct ScopeMap { - entries: Vec, - root_count: usize, - } - - pub struct ScopeMapEntry { - scope: String, - enabled: Option, - descendants: std::ops::Range, - } - - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub enum EnabledStatus { - Enabled, - Disabled, - NotConfigured, - } - - impl ScopeMap { - pub fn new_from_settings_and_env( - items_input_map: &HashMap, - env_config: Option<&env_config::EnvFilter>, - ) -> Self { - let mut items = Vec::with_capacity( - items_input_map.len() + env_config.map_or(0, |c| c.directive_names.len()), - ); - if let Some(env_filter) = env_config { - // TODO: parse on load instead of every reload - items.extend( - env_filter - .directive_names - .iter() - .zip(env_filter.directive_levels.iter()) - .filter_map(|(scope, level_filter)| { - if items_input_map.get(scope).is_some() { - return None; - } - let scope = scope_alloc_from_scope_str(scope)?; - // TODO: use level filters instead of scopes in scope map - let level = level_filter.to_level()?; - - Some((scope, level)) - }), - ); - } - items.extend( - items_input_map - .into_iter() - .filter_map(|(scope_str, level_str)| { - let scope = 
scope_alloc_from_scope_str(&scope_str)?; - let level = level_from_level_str(&level_str)?; - return Some((scope, level)); - }), - ); - - items.sort_by(|a, b| a.0.cmp(&b.0)); - - let mut this = Self { - entries: Vec::with_capacity(items.len() * SCOPE_DEPTH_MAX), - root_count: 0, - }; - - let items_count = items.len(); - - struct ProcessQueueEntry { - parent_index: usize, - depth: usize, - items_range: std::ops::Range, - } - let mut process_queue = VecDeque::new(); - process_queue.push_back(ProcessQueueEntry { - parent_index: usize::MAX, - depth: 0, - items_range: 0..items_count, - }); - - let empty_range = 0..0; - - while let Some(process_entry) = process_queue.pop_front() { - let ProcessQueueEntry { - items_range, - depth, - parent_index, - } = process_entry; - let mut cursor = items_range.start; - let res_entries_start = this.entries.len(); - while cursor < items_range.end { - let sub_items_start = cursor; - cursor += 1; - let scope_name = &items[sub_items_start].0[depth]; - while cursor < items_range.end && &items[cursor].0[depth] == scope_name { - cursor += 1; - } - let sub_items_end = cursor; - if scope_name == "" { - assert_eq!(sub_items_start + 1, sub_items_end); - assert_ne!(depth, 0); - assert_ne!(parent_index, usize::MAX); - assert!(this.entries[parent_index].enabled.is_none()); - this.entries[parent_index].enabled = Some(items[sub_items_start].1); - continue; - } - let is_valid_scope = scope_name != ""; - let is_last = depth + 1 == SCOPE_DEPTH_MAX || !is_valid_scope; - let mut enabled = None; - if is_last { - assert_eq!( - sub_items_start + 1, - sub_items_end, - "Expected one item: got: {:?}", - &items[items_range.clone()] - ); - enabled = Some(items[sub_items_start].1); - } else { - let entry_index = this.entries.len(); - process_queue.push_back(ProcessQueueEntry { - items_range: sub_items_start..sub_items_end, - parent_index: entry_index, - depth: depth + 1, - }); - } - this.entries.push(ScopeMapEntry { - scope: scope_name.to_owned(), - enabled, - 
descendants: empty_range.clone(), - }); - } - let res_entries_end = this.entries.len(); - if parent_index != usize::MAX { - this.entries[parent_index].descendants = res_entries_start..res_entries_end; - } else { - this.root_count = res_entries_end; - } - } - - return this; - } - - pub fn is_empty(&self) -> bool { - self.entries.is_empty() - } - - pub fn is_enabled( - &self, - scope: &[S; SCOPE_DEPTH_MAX], - level: log_impl::Level, - ) -> EnabledStatus - where - S: AsRef, - { - let mut enabled = None; - let mut cur_range = &self.entries[0..self.root_count]; - let mut depth = 0; - - 'search: while !cur_range.is_empty() - && depth < SCOPE_DEPTH_MAX - && scope[depth].as_ref() != "" - { - for entry in cur_range { - if entry.scope == scope[depth].as_ref() { - // note: - enabled = entry.enabled.or(enabled); - cur_range = &self.entries[entry.descendants.clone()]; - depth += 1; - continue 'search; - } - } - break 'search; - } - - return enabled.map_or(EnabledStatus::NotConfigured, |level_enabled| { - if level <= level_enabled { - EnabledStatus::Enabled - } else { - EnabledStatus::Disabled - } - }); - } - } - - #[cfg(test)] - mod tests { - use crate::private::scope_new; - - use super::*; - - fn scope_map_from_keys(kv: &[(&str, &str)]) -> ScopeMap { - let hash_map: HashMap = kv - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - ScopeMap::new_from_settings_and_env(&hash_map, None) - } - - #[test] - fn test_initialization() { - let map = scope_map_from_keys(&[("a.b.c.d", "trace")]); - assert_eq!(map.root_count, 1); - assert_eq!(map.entries.len(), 4); - - let map = scope_map_from_keys(&[]); - assert_eq!(map.root_count, 0); - assert_eq!(map.entries.len(), 0); - - let map = scope_map_from_keys(&[("", "trace")]); - assert_eq!(map.root_count, 0); - assert_eq!(map.entries.len(), 0); - - let map = scope_map_from_keys(&[("foo..bar", "trace")]); - assert_eq!(map.root_count, 1); - assert_eq!(map.entries.len(), 2); - - let map = scope_map_from_keys(&[ - ("a.b.c.d", 
"trace"), - ("e.f.g.h", "debug"), - ("i.j.k.l", "info"), - ("m.n.o.p", "warn"), - ("q.r.s.t", "error"), - ]); - assert_eq!(map.root_count, 5); - assert_eq!(map.entries.len(), 20); - assert_eq!(map.entries[0].scope, "a"); - assert_eq!(map.entries[1].scope, "e"); - assert_eq!(map.entries[2].scope, "i"); - assert_eq!(map.entries[3].scope, "m"); - assert_eq!(map.entries[4].scope, "q"); - } - - fn scope_from_scope_str(scope_str: &'static str) -> Scope { - let mut scope_buf = [""; SCOPE_DEPTH_MAX]; - let mut index = 0; - let mut scope_iter = scope_str.split(SCOPE_STRING_SEP); - while index < SCOPE_DEPTH_MAX { - let Some(scope) = scope_iter.next() else { - break; - }; - if scope == "" { - continue; - } - scope_buf[index] = scope; - index += 1; - } - assert_ne!(index, 0); - assert!(scope_iter.next().is_none()); - return scope_buf; - } - - #[test] - fn test_is_enabled() { - let map = scope_map_from_keys(&[ - ("a.b.c.d", "trace"), - ("e.f.g.h", "debug"), - ("i.j.k.l", "info"), - ("m.n.o.p", "warn"), - ("q.r.s.t", "error"), - ]); - use log_impl::Level; - assert_eq!( - map.is_enabled(&scope_from_scope_str("a.b.c.d"), Level::Trace), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("a.b.c.d"), Level::Debug), - EnabledStatus::Enabled - ); - - assert_eq!( - map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Debug), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Info), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("e.f.g.h"), Level::Trace), - EnabledStatus::Disabled - ); - - assert_eq!( - map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Info), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Warn), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("i.j.k.l"), Level::Debug), - EnabledStatus::Disabled - ); - - assert_eq!( - 
map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Warn), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Error), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("m.n.o.p"), Level::Info), - EnabledStatus::Disabled - ); - - assert_eq!( - map.is_enabled(&scope_from_scope_str("q.r.s.t"), Level::Error), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_from_scope_str("q.r.s.t"), Level::Warn), - EnabledStatus::Disabled - ); - } - - fn scope_map_from_keys_and_env( - kv: &[(&str, &str)], - env: &env_config::EnvFilter, - ) -> ScopeMap { - let hash_map: HashMap = kv - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - ScopeMap::new_from_settings_and_env(&hash_map, Some(env)) - } - - #[test] - fn test_initialization_with_env() { - let env_filter = env_config::parse("a.b=debug,u=error").unwrap(); - let map = scope_map_from_keys_and_env(&[], &env_filter); - assert_eq!(map.root_count, 2); - assert_eq!(map.entries.len(), 3); - assert_eq!( - map.is_enabled(&scope_new(&["a"]), log_impl::Level::Debug), - EnabledStatus::NotConfigured - ); - assert_eq!( - map.is_enabled(&scope_new(&["a", "b"]), log_impl::Level::Debug), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_new(&["a", "b", "c"]), log_impl::Level::Trace), - EnabledStatus::Disabled - ); - - let env_filter = env_config::parse("a.b=debug,e.f.g.h=trace,u=error").unwrap(); - let map = scope_map_from_keys_and_env( - &[ - ("a.b.c.d", "trace"), - ("e.f.g.h", "debug"), - ("i.j.k.l", "info"), - ("m.n.o.p", "warn"), - ("q.r.s.t", "error"), - ], - &env_filter, - ); - assert_eq!(map.root_count, 6); - assert_eq!(map.entries.len(), 21); - assert_eq!(map.entries[0].scope, "a"); - assert_eq!(map.entries[1].scope, "e"); - assert_eq!(map.entries[2].scope, "i"); - assert_eq!(map.entries[3].scope, "m"); - assert_eq!(map.entries[4].scope, "q"); - assert_eq!(map.entries[5].scope, "u"); - 
assert_eq!( - map.is_enabled(&scope_new(&["a", "b", "c", "d"]), log_impl::Level::Trace), - EnabledStatus::Enabled - ); - assert_eq!( - map.is_enabled(&scope_new(&["a", "b", "c"]), log_impl::Level::Trace), - EnabledStatus::Disabled - ); - assert_eq!( - map.is_enabled(&scope_new(&["u", "v"]), log_impl::Level::Warn), - EnabledStatus::Disabled - ); - // settings override env - assert_eq!( - map.is_enabled(&scope_new(&["e", "f", "g", "h"]), log_impl::Level::Trace), - EnabledStatus::Disabled, - ); - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/zlog_settings/src/zlog_settings.rs b/crates/zlog_settings/src/zlog_settings.rs index 36e01dab24..fde28ba918 100644 --- a/crates/zlog_settings/src/zlog_settings.rs +++ b/crates/zlog_settings/src/zlog_settings.rs @@ -10,7 +10,7 @@ pub fn init(cx: &mut App) { cx.observe_global::(|cx| { let zlog_settings = ZlogSettings::get_global(cx); - zlog::scope_map::refresh_from_settings(&zlog_settings.scopes); + zlog::filter::refresh_from_settings(&zlog_settings.scopes); }) .detach(); } diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index ab1c6f5d09..9064370fe0 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -127,6 +127,7 @@ - [Vue](./languages/vue.md) - [XML](./languages/xml.md) - [YAML](./languages/yaml.md) +- [Yara](./languages/yara.md) - [Yarn](./languages/yarn.md) - [Zig](./languages/zig.md) diff --git a/docs/src/languages.md b/docs/src/languages.md index 4e7f05bb57..faee42176c 100644 --- a/docs/src/languages.md +++ b/docs/src/languages.md @@ -71,6 +71,7 @@ Some work out-of-the box and others rely on 3rd party extensions. 
- [Vue](./languages/vue.md) - [XML](./languages/xml.md) - [YAML](./languages/yaml.md) \* +- [Yara](./languages/yara.md) +- [Yarn](./languages/yarn.md) - [Zig](./languages/zig.md) diff --git a/docs/src/languages/haskell.md b/docs/src/languages/haskell.md index 8075ec2904..fec9142a5f 100644 --- a/docs/src/languages/haskell.md +++ b/docs/src/languages/haskell.md @@ -33,4 +33,19 @@ If you need to configure haskell-language-server (hls) you can add configuration } ``` -See: official [configuring haskell-language-server](https://haskell-language-server.readthedocs.io/en/latest/configuration.html) docs for more. +See the official [configuring haskell-language-server](https://haskell-language-server.readthedocs.io/en/latest/configuration.html) docs for more options. + +If you would like to use a specific hls binary, or perhaps use [static-ls](https://github.com/josephsumabat/static-ls) as a drop-in replacement instead, you can specify the binary path and arguments: + +```json +{ + "lsp": { + "hls": { + "binary": { + "path": "static-ls", + "arguments": ["--experimentalFeatures"] + } + } + } +} +``` diff --git a/docs/src/languages/lua.md b/docs/src/languages/lua.md index b10aa068b4..4ad143ce41 100644 --- a/docs/src/languages/lua.md +++ b/docs/src/languages/lua.md @@ -9,29 +9,107 @@ Lua support is available through the [Lua extension](https://github.com/zed-exte To configure LuaLS you can create a `.luarc.json` file in the root of your workspace. -See [LuaLS Settings Documentation](https://luals.github.io/wiki/settings/) for all available configuration options.
+```json +{ + "$schema": "https://raw.githubusercontent.com/LuaLS/vscode-lua/master/setting/schema.json", + "runtime.version": "Lua 5.4", + "format.enable": true, + "workspace.library": ["../somedir/library"] +} +``` + +See [LuaLS Settings Documentation](https://luals.github.io/wiki/settings/) for all available configuration options, or when editing this file in Zed available settings options will autocomplete, (e.g `runtime.version` will show `"Lua 5.1"`, `"Lua 5.2"`, `"Lua 5.3"`, `"Lua 5.4"` and `"LuaJIT"` as allowed values). Note when importing settings options from VSCode, remove the `Lua.` prefix. (e.g. `runtime.version` instead of `Lua.runtime.version`). + +### LuaCATS Definitions + +LuaLS can provide enhanced LSP autocompletion suggestions and type validation with the help of LuaCATS (Lua Comment and Type System) definitions. These definitions are available for many common Lua libraries, and local paths containing them can be specified via `workspace.library` in `.luarc.json`. You can do this via relative paths if you checkout your definitions into the same parent directory of your project (`../playdate-luacats`, `../love2d`, etc). Alternatively you can create submodule(s) inside your project for each LuaCATS definition repo. + +### LÖVE (Love2D) {#love2d} + +To use [LÖVE (Love2D)](https://love2d.org/) in Zed, checkout [LuaCATS/love2d](https://github.com/LuaCATS/love2d) into a folder called `love2d-luacats` into the parent folder of your project: + +```sh +cd ..
&& git clone https://github.com/LuaCATS/love2d love2d-luacats +``` + +Then in your `.luarc.json`: + +``` +{ + "$schema": "https://raw.githubusercontent.com/LuaLS/vscode-lua/master/setting/schema.json", + "runtime.version": "Lua 5.4", + "workspace.library": ["../love2d-luacats"], + "runtime.special": { + "love.filesystem.load": "loadfile" + } +} +``` + +### PlaydateSDK + +To use [Playdate Lua SDK](https://play.date/dev/) in Zed, checkout [playdate-luacats](https://github.com/notpeter/playdate-luacats) into the parent folder of your project: + +```sh +cd .. && git clone https://github.com/notpeter/playdate-luacats +``` + +Then in your `.luarc.json`: ```json { "$schema": "https://raw.githubusercontent.com/LuaLS/vscode-lua/master/setting/schema.json", "runtime.version": "Lua 5.4", - "diagnostics.severity": { - "duplicate-set-field": "Hint" - }, - "format.enable": true, + "runtime.nonstandardSymbol": [ + "+=", + "-=", + "*=", + "/=", + "//=", + "%=", + "<<=", + ">>=", + "&=", + "|=", + "^=" + ], + "diagnostics.severity": { "duplicate-set-field": "Hint" }, + "diagnostics.globals": ["import"], + "workspace.library": ["../playdate-luacats"], "format.defaultConfig": { "indent_style": "space", "indent_size": "4" }, - "workspace.library": ["../somedir/library"] + "format.enable": true, + "runtime.builtin": { "io": "disable", "os": "disable", "package": "disable" } } ``` +### Inlay Hints + +To enable [Inlay Hints](../configuring-languages#inlay-hints) for LuaLS in Zed + +1. Add the following to your Zed settings.json: + +```json + "languages": { + "Lua": { + "inlay_hints": { + "enabled": true, + "show_type_hints": true, + "show_parameter_hints": true, + "show_other_hints": true + } + } + } +``` + +2. Add `"hint.enable": true` to your `.luarc.json`. 
+ ## Formatting ### LuaLS -To enable auto-formatting with your LuaLS, make sure you have `"format.enable": true,` in your .luarc.json add the following to your Zed `settings.json`: +To enable auto-formatting with your LuaLS (provided by [CppCXY/EmmyLuaCodeStyle](https://github.com/CppCXY/EmmyLuaCodeStyle)) make sure you have `"format.enable": true,` in your .luarc.json add the following to your Zed `settings.json`: ```json { @@ -44,9 +122,11 @@ To enable auto-formatting with your LuaLS, make sure you have `"format.enable": } ``` +You can customize various EmmyLuaCodeStyle style options via `.editorconfig`, see [lua.template.editorconfig](https://github.com/CppCXY/EmmyLuaCodeStyle/blob/master/lua.template.editorconfig) for all available options. + ### StyLua -Alternative you can use [StyLua](https://github.com/JohnnyMorganz/StyLua): +Alternatively to use [StyLua](https://github.com/JohnnyMorganz/StyLua) for auto-formatting: 1. Install [StyLua](https://github.com/JohnnyMorganz/StyLua): `brew install stylua` or `cargo install stylua --features lua52,lua53,lua54,luau,luajit` (feel free to remove any Lua versions you don't need). 2. Add the following to your `settings.json`: diff --git a/docs/src/languages/yara.md b/docs/src/languages/yara.md new file mode 100644 index 0000000000..f95ab2f778 --- /dev/null +++ b/docs/src/languages/yara.md @@ -0,0 +1,6 @@ +# Yara + +`Yara` language support in Zed is provided by the [Yara](https://github.com/egibs/yara.zed) extension. Please report issues to [https://github.com/egibs/yara.zed/issues](https://github.com/egibs/yara.zed). 
+ +- Tree-sitter: [egibs/tree-sitter-yara](https://github.com/egibs/tree-sitter-yara) +- Language Server: [avast/yls](https://github.com/avast/yls) diff --git a/docs/src/tasks.md b/docs/src/tasks.md index 7eb6fd7f8d..557bedb118 100644 --- a/docs/src/tasks.md +++ b/docs/src/tasks.md @@ -65,7 +65,7 @@ Keep `"use_new_terminal": false` and set `"allow_concurrent_runs": true` to allo Tasks can be defined: - in the global `tasks.json` file; such tasks are available in all Zed projects you work on. This file is usually located in `~/.config/zed/tasks.json`. You can edit them by using the `zed: open tasks` action. -- in the worktree-specific (local) `.zed/tasks.json` file; such tasks are available only when working on a project with that worktree included. You can edit worktree-specific tasks by using the `zed: open local tasks` action. +- in the worktree-specific (local) `.zed/tasks.json` file; such tasks are available only when working on a project with that worktree included. You can edit worktree-specific tasks by using the `zed: open project tasks` action. - on the fly with [oneshot tasks](#oneshot-tasks). These tasks are project-specific and do not persist across sessions. - by language extension. 
diff --git a/flake.lock b/flake.lock index 45fbb2d49f..c09eb90d2d 100644 --- a/flake.lock +++ b/flake.lock @@ -61,11 +61,11 @@ ] }, "locked": { - "lastModified": 1743215516, - "narHash": "sha256-52qbrkG65U1hyrQWltgHTgH4nm0SJL+9TWv2UDCEPNI=", + "lastModified": 1743906877, + "narHash": "sha256-Thah1oU8Vy0gs9bh5QhNcQh1iuQiowMnZPbrkURonZA=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "524463199fdee49338006b049bc376b965a2cfed", + "rev": "9d00c6b69408dd40d067603012938d9fbe95cfcd", "type": "github" }, "original": { diff --git a/script/generate-licenses b/script/generate-licenses index 901e1040a2..9fcb2bd513 100755 --- a/script/generate-licenses +++ b/script/generate-licenses @@ -18,7 +18,7 @@ echo -n "" >"$OUTPUT_FILE" echo -e "\n# ###### CODE LICENSES ######\n" } >>"$OUTPUT_FILE" -if ! cargo about --version | grep "cargo-about $CARGO_ABOUT_VERSION" 2>&1 >/dev/null; then +if ! cargo about --version | grep "cargo-about $CARGO_ABOUT_VERSION" &>/dev/null; then echo "Installing cargo-about@^$CARGO_ABOUT_VERSION..." 
cargo install "cargo-about@^$CARGO_ABOUT_VERSION" else diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index fac5fec310..c7a2529172 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -512,6 +512,8 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } winapi = { version = "0.3", default-features = false, features = ["cfg", "consoleapi", "errhandlingapi", "evntrace", "fileapi", "handleapi", "in6addr", "inaddr", "knownfolders", "minwinbase", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "sysinfoapi", "winbase", "windef", "winerror", "winioctl"] } +windows-core = { version = "0.61" } +windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", 
"Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } @@ -533,6 +535,8 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } winapi = { version = "0.3", default-features = false, features = ["cfg", "consoleapi", "errhandlingapi", "evntrace", "fileapi", "handleapi", "in6addr", "inaddr", "knownfolders", "minwinbase", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "sysinfoapi", "winbase", "windef", "winerror", "winioctl"] } +windows-core = { version = "0.61" } +windows-numerics = { version = "0.2" } windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { 
package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_UI_Shell"] } diff --git a/typos.toml b/typos.toml index f4ec8a1220..1c90bf5926 100644 --- a/typos.toml +++ b/typos.toml @@ -41,6 +41,10 @@ extend-exclude = [ "docs/theme/css/", # Spellcheck triggers on `|Fixe[sd]|` regex part. "script/danger/dangerfile.ts", + # Eval examples for prompts and criteria + "crates/eval/examples/checkpoint_stability/criteria.md", + "crates/eval/examples/tax_id_validation/prompt.md", + "crates/eval/examples/tax_id_validation/criteria.md" ] [default]