Compare commits

..

35 Commits

Author SHA1 Message Date
Richard Feldman
7682ddbafe Merge remote-tracking branch 'origin/main' into anthropic_base_url 2025-09-09 10:31:34 -04:00
Smit Barmase
734f94b71c agent_ui: Fix crash when typing multibyte character after mention (#37847)
Closes #36333

Release Notes:

- Fixed a crash that occurred when typing an IME character right after a
mention in the Agent Panel.
2025-09-09 19:16:22 +05:30
Danilo Leal
136468a4df keymap editor: Add some adjustments to the UI (#37819)
- Makes the keymap editor search container more consistent with the
project & file search corresponding elements
- Changes the keymap editor menu item in the user menu be called "Keymap
Editor", as opposed to "Key Binding", to match with the tab and action
name

Design note: Still a bit unsure about the extra space on the right for
the keymap editor. This makes it way more consistent with the other
search views, but it also just feels like space that could be used. On
the other hand, though, it's very unlikely anyone will ever use more
than 30% of the search bar width as search queries here are likely
pretty short; definitely much shorter than project search queries.

<img width="600" height="552" alt="Screenshot 2025-09-09 at 1  02@2x"
src="https://github.com/user-attachments/assets/9825a129-2c5a-4852-9837-c586b88e9332"
/>


Release Notes:

- N/A
2025-09-09 09:36:12 -03:00
Lukas Wirth
adf43d691a project: Remove non searchable buffer entries on buffer close (#37841)
Release Notes:

- N/A
2025-09-09 11:21:51 +00:00
localcc
466a2e22d5 Improve font rendering on macOS (#37622)
Part of https://github.com/zed-industries/zed/issues/7992

Release Notes:

- N/A
2025-09-09 13:46:59 +03:00
Finn Evers
365c5ab45f editor: Remove unnecessary clone (#37833)
The style is taken by reference everywhere, so no need to clone it at
the start of every `prepaint`.

Release Notes:

- N/A
2025-09-09 08:04:27 +00:00
Jakub Konka
11d81b95d4 Revert "git: Use self.git_binary_path instead raw git string" (#37828)
Reverts zed-industries/zed#37757
2025-09-09 06:36:25 +00:00
Michael Sloan
4b3b2acf75 Fix hot reload of builtin TreeSitter queries on Linux (#37825)
`fs.watch` is recursive on mac and non-recursive on Linux.

Release Notes:

- N/A
2025-09-09 05:55:28 +00:00
Danilo Leal
849424740f Fix code action menu items font size in toolbar (#37824)
Closes https://github.com/zed-industries/zed/issues/36478

Release Notes:

- N/A
2025-09-09 02:20:23 -03:00
Danilo Leal
3e605c2c4b docs: Fix casing on mentions to some brand names (#37822)
- MacOS → macOS
- VSCode → VS Code
- SublimeText → Sublime Text
- Javascript/Typescript → JavaScript/TypeScript

Release Notes:

- N/A
2025-09-09 01:45:55 -03:00
Danilo Leal
82b11bf77c docs: Include Cursor in the list of supported base keymaps (#37821)
We were missing that in the /key-bindings page. Also took advantage of
the opportunity to add a bunch of small writing tweaks.

Release Notes:

- N/A
2025-09-09 01:39:52 -03:00
Conrad Irwin
3a437fd888 Remove Chat (#37789)
At RustConf we were demo'ing zed, and it continually popped open the
chat panel.

We're usually inured to this because the Chat panel doesn't open unless
a Guest
is in the channel, but it made me sad that we were showing a long stream
of
vacuous comments and unresponded to questions on every demo screen.

We may bring chat back in the future, but we need more thought on the
UX, and
we need to rebuild the backend to not use the existing collab server
that we're
trying to move off of.

Release Notes:

- Removed the chat feature from Zed (Sorry to the 5 of you who use this
on the regular!)
2025-09-08 21:53:17 -06:00
Conrad Irwin
96c429d2c3 Only reject agent actions, don't restore checkpoint on revert (#37801)
Updates #37623

Release Notes:

- Changed the behaviour when editing an old message in a native agent
thread.
Prior to this, it would automatically restore the checkpoint (which
could
lead to a surprising amount of work being discarded). Now it will just
reject
any unaccepted agent edits, and you can use the "restore checkpoint"
button
  for the original behavior.
2025-09-08 20:18:40 -06:00
Marshall Bowers
ea4073e50e cloud_llm_client: Add new Plan variants (#37810)
This PR adds new variants to the `Plan` enum.

Release Notes:

- N/A
2025-09-09 00:18:43 +00:00
Ben Kunkle
8c93112869 settings_ui: Add Basic Implementation of Language Settings (#37803)
Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Anthony <anthony@zed.dev>
2025-09-08 20:14:36 -04:00
Marshall Bowers
1feffad5e8 Remove zed-pro feature flag (#37807)
This PR removes the `zed-pro` feature flag, as it was not being used.

Release Notes:

- N/A
2025-09-08 23:23:16 +00:00
Ivan Danov
ae54a4e1b8 Add commands to select next/previous siblings in the syntax tree (#35053)
Closes #5133 and discussion
https://github.com/zed-industries/zed/discussions/33493

This PR adds two new commands to select next/previous siblings in the
syntax tree. These commands were modelled after the existing ones about
expand/shrink selection. With this PR I've added new key bindings
inspired by `helix` for previous / next / expand / shrink selections.



https://github.com/user-attachments/assets/4ef7fadb-0b82-4897-95c7-1737827bf4ac


Release Notes:

- Add commands to select next/previous siblings in the syntax tree

---------

Co-authored-by: Joseph T. Lyons <JosephTLyons@gmail.com>
2025-09-08 23:11:53 +00:00
Danilo Leal
4a0a7d1d27 Add item for the debugger panel in the app view menu (#37805)
Release Notes:

- Enabled the debugger panel to be opened via the app's "View" menu
option
2025-09-08 19:19:10 -03:00
Danilo Leal
5934d3789b python: Improve Basedpyright banner styles (#37802)
Just tidying this up a bit.

Release Notes:

- N/A
2025-09-08 18:51:13 -03:00
Danilo Leal
acde79dae7 agent: Improve popover trigger styles (#37800)
This PR mostly adds some style treatment to popover button triggers in
the agent panel, either making them better aligned with their trigger or
adjusting the color to better clarify which button is triggering the
currently opened menu.

Moving forward, I think the selected styles at least should probably be
tackled at the component level, whether that's a context menu or a
popover, so we don't have to ever do this manually (and running the risk
of forgetting to do it).

Release Notes:

- N/A
2025-09-08 18:39:49 -03:00
Patsakula Nikita
246c644316 agent_servers: Fix proxy configuration for Gemini (#37790)
Closes #37487 

Proxy settings are now taken from the Zed configuration and passed to
Gemini via the "--proxy" flag.

Release Notes:

- acp: Gemini ACP server now uses proxy settings from Zed configuration.
2025-09-08 20:44:40 +00:00
Marshall Bowers
e4de26e5dc cloud_llm_client: Remove unused code (#37799)
This PR removes some unused code from the `cloud_llm_client`.

This was only used on the server, so we can move it there.

Release Notes:

- N/A
2025-09-08 20:42:42 +00:00
ZhangJun
7091c70a1e open_ai: Trim newline before "data:" prefix and account for the possibility of no space after ":" (#37644)
I'am using an openai compatible model, but got nothing in agent thread
panel, and Zed log has "Model generated an empty summary" line.

I add one log to open_ai.rs:
<img width="2454" height="626" alt="图片"
src="https://github.com/user-attachments/assets/85354c7d-a0cc-4bba-86fd-2a640038a13e"
/>

and got:

<img width="3456" height="278" alt="图片"
src="https://github.com/user-attachments/assets/7746aedd-5d76-44b5-90f2-e129a1507178"
/>

It appear that `let line = line.strip_prefix("data: ")?;` can not handle
correctly.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-09-08 22:01:55 +02:00
Cole Miller
fa0df6da1c python: Replace pyright with basedpyright (#35362)
Follow-up to #35250. Let's experiment with having this by default on
nightly.

Release Notes:

- Added built-in support for the basedpyright language server for Python
code. basedpyright is now enabled by default, and pyright (previously
the primary Python language server) remains available but is disabled by
default. This supersedes the basedpyright extension, which can be
uninstalled. Advantages of basedpyright over pyright include support for
inlay hints, semantic highlighting, auto-import code actions, and
stricter type checking. To switch back to pyright, add the following
configuration to settings.json:

```json
{
  "languages": {
    "Python": {
      "language_servers": ["pyright", "pylsp", "!basedpyright"]
    }
  }
}
```

---------

Co-authored-by: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com>
Co-authored-by: Lukas Wirth <lukas@zed.dev>
2025-09-08 19:15:17 +00:00
Cole Miller
99102a84fa ACP over SSH (#37725)
This PR adds support for using external agents in SSH projects via ACP,
including automatic installation of Gemini CLI and Claude Code,
authentication with API keys (for Gemini) and CLI login, and custom
agents from user configuration.

Co-authored-by: maan2003 <manmeetmann2003@gmail.com>

Release Notes:

- agent: Gemini CLI, Claude Code, and custom external agents can now be
used in SSH projects.

---------

Co-authored-by: maan2003 <manmeetmann2003@gmail.com>
2025-09-08 14:19:41 -04:00
Cole Miller
5f01f6d75f agent: Make read_file and edit_file tool call titles more specific (#37639)
For read_file and edit_file, show the worktree-relative path if there's
only one visible worktree, and the "full path" otherwise. Also restores
the display of line numbers for read_file calls.

Release Notes:

- N/A
2025-09-08 12:57:22 -04:00
Dave Waggoner
a66cd820b3 Fix line endings in terminal_hyperlinks.rs (#37654)
Fixes Windows line endings in `terminal_hyperlinks.rs`, which was
accidentally originally added with them.

Release Notes:

- N/A
2025-09-08 12:21:47 -04:00
hong jihwan
f07da9d9f2 Correctly parse backslash character on replacement (#37014)
When a keybind contains a backslash character (\\), it is parsed
incorrectly, which results in an invalid keybind configuration.


This patch fixes the issue by ensuring that backslashes are properly
escaped during the parsing process. This allows them to be used as
intended in keybind definitions.

Release Notes:

- Fixed an issue where keybinds containing a backslash character (\\)
failed to be replaced correctly


## Screenshots
<img width="912" height="530" alt="SCR-20250828-borp"
src="https://github.com/user-attachments/assets/561a040f-575b-4222-ac75-17ab4fa71d07"
/>
<img width="912" height="530" alt="SCR-20250828-bosx"
src="https://github.com/user-attachments/assets/b8e0fb99-549e-4fc9-8609-9b9aa2004656"
/>
2025-09-08 12:17:48 -04:00
Iha Shin (신의하)
8d05bb090c Add injections for Isograph function calls in JavaScript and TypeScript (#36320)
Required for https://github.com/isographlabs/isograph/pull/568 to work
properly. Tested with a local build and made sure everything's working
great!

Release Notes:

- JavaScript/TypeScript/JSX: Added support for injecting Isograph language support into `iso`
function calls
2025-09-08 16:04:37 +00:00
Dino
2325f14713 diagnostics: Current file diagnostics view (#34430)
These changes introduce a new command to the Diagnostics panel,
`diagnostics: deploy current file`, which allows the user to view the
diagnostics only for the currently opened file.

Here's a screen recording showing these changes in action 🔽 

[diagnostics: deploy current
file](https://github.com/user-attachments/assets/b0e87eea-3b3a-4888-95f8-9e21aff8ea97)

Closes #4739 

Release Notes:

- Added new `diagnostics: deploy current file` command to view
diagnostics for the currently open file

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-09-08 09:14:24 -06:00
Ben Kunkle
fe2aa3f4cb onboarding: Fix font loading frame delay (#37668)
Closes #ISSUE

Fixed an issue where the first frame of the `Editing` page in onboarding
would have a slight delay before rendering the first time it was
navigated to. This was caused by listing the OS fonts on the main
thread, blocking rendering. This PR fixes the issue by adding a new
method to the font family cache to prefill the cache on a background
thread.

Release Notes:

- N/A *or* Added/Fixed/Improved ...

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Anthony Eid <hello@anthonyeid.me>
Co-authored-by: Anthony <anthony@zed.dev>
2025-09-08 11:09:54 -04:00
fantacell
10989c702c helix: Add match operator (#34060)
This is an implementation of matching like "m i (", as well as "] (" and
"[ (" in `helix_mode` with a few supported objects and a basis for more.

Release Notes:

- Added helix operators for selecting text objects

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-09-08 08:48:47 -06:00
张小白
3f80ac0127 macos: Fix menu bar flickering (#37707)
Closes #37526

Release Notes:

- Fixed menu bar flickering when using some IMEs on macOS.
2025-09-08 10:44:19 -04:00
Nia Espera
1884d83e6f fixup 2025-09-04 20:40:30 +02:00
Richard Feldman
370fe8ce23 wip 2025-09-04 12:48:45 -04:00
183 changed files with 7845 additions and 10187 deletions

View File

@@ -26,7 +26,7 @@ third-party = [
# build of remote_server should not include scap / its x11 dependency
{ name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" },
# build of remote_server should not need to include on libalsa through rodio
{ name = "rodio", git = "https://github.com/RustAudio/rodio", branch = "better_wav_output"},
{ name = "rodio" },
]
[final-excludes]

154
Cargo.lock generated
View File

@@ -308,22 +308,18 @@ dependencies = [
"libc",
"log",
"nix 0.29.0",
"node_runtime",
"paths",
"project",
"reqwest_client",
"schemars",
"semver",
"serde",
"serde_json",
"settings",
"smol",
"task",
"tempfile",
"thiserror 2.0.12",
"ui",
"util",
"watch",
"which 6.0.3",
"workspace-hack",
]
@@ -486,6 +482,7 @@ dependencies = [
"client",
"cloud_llm_client",
"component",
"feature_flags",
"gpui",
"language_model",
"serde",
@@ -1385,19 +1382,12 @@ name = "audio"
version = "0.1.0"
dependencies = [
"anyhow",
"async-tar",
"collections",
"crossbeam",
"gpui",
"libwebrtc",
"log",
"parking_lot",
"rodio",
"schemars",
"serde",
"settings",
"smol",
"thiserror 2.0.12",
"util",
"workspace-hack",
]
@@ -2618,7 +2608,6 @@ dependencies = [
"audio",
"client",
"collections",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
@@ -2894,11 +2883,9 @@ dependencies = [
"language",
"log",
"postage",
"rand 0.9.1",
"release_channel",
"rpc",
"settings",
"sum_tree",
"text",
"time",
"util",
@@ -3386,12 +3373,10 @@ dependencies = [
"collections",
"db",
"editor",
"emojis",
"futures 0.3.31",
"fuzzy",
"gpui",
"http_client",
"language",
"log",
"menu",
"notifications",
@@ -3399,7 +3384,6 @@ dependencies = [
"pretty_assertions",
"project",
"release_channel",
"rich_text",
"rpc",
"schemars",
"serde",
@@ -4152,19 +4136,6 @@ dependencies = [
"itertools 0.10.5",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
@@ -9206,6 +9177,19 @@ dependencies = [
"x_ai",
]
[[package]]
name = "language_onboarding"
version = "0.1.0"
dependencies = [
"db",
"editor",
"gpui",
"project",
"ui",
"workspace",
"workspace-hack",
]
[[package]]
name = "language_selector"
version = "0.1.0"
@@ -9268,7 +9252,6 @@ dependencies = [
"chrono",
"collections",
"dap",
"feature_flags",
"futures 0.3.31",
"gpui",
"http_client",
@@ -9673,7 +9656,6 @@ dependencies = [
"scap",
"serde",
"serde_json",
"serde_urlencoded",
"settings",
"sha2",
"simplelog",
@@ -12627,6 +12609,7 @@ dependencies = [
"remote",
"rpc",
"schemars",
"semver",
"serde",
"serde_json",
"settings",
@@ -12646,6 +12629,7 @@ dependencies = [
"unindent",
"url",
"util",
"watch",
"which 6.0.3",
"workspace-hack",
"worktree",
@@ -13875,15 +13859,15 @@ dependencies = [
[[package]]
name = "rodio"
version = "0.21.1"
source = "git+https://github.com/RustAudio/rodio?branch=better_wav_output#82514bd1f2c6cfd9a1a885019b26a8ffea75bc5c"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183"
dependencies = [
"cpal",
"dasp_sample",
"hound",
"num-rational",
"rtrb",
"symphonia",
"thiserror 2.0.12",
"tracing",
]
[[package]]
@@ -13957,12 +13941,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rtrb"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad8388ea1a9e0ea807e442e8263a699e7edcb320ecbcd21b4fa8ff859acce3ba"
[[package]]
name = "rules_library"
version = "0.1.0"
@@ -15924,53 +15902,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9"
dependencies = [
"lazy_static",
"symphonia-bundle-flac",
"symphonia-bundle-mp3",
"symphonia-codec-aac",
"symphonia-codec-pcm",
"symphonia-codec-vorbis",
"symphonia-core",
"symphonia-format-isomp4",
"symphonia-format-ogg",
"symphonia-format-riff",
"symphonia-metadata",
]
[[package]]
name = "symphonia-bundle-flac"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97"
dependencies = [
"log",
"symphonia-core",
"symphonia-metadata",
"symphonia-utils-xiph",
]
[[package]]
name = "symphonia-bundle-mp3"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c01c2aae70f0f1fb096b6f0ff112a930b1fb3626178fba3ae68b09dce71706d4"
dependencies = [
"lazy_static",
"log",
"symphonia-core",
"symphonia-metadata",
]
[[package]]
name = "symphonia-codec-aac"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdbf25b545ad0d3ee3e891ea643ad115aff4ca92f6aec472086b957a58522f70"
dependencies = [
"lazy_static",
"log",
"symphonia-core",
]
[[package]]
name = "symphonia-codec-pcm"
version = "0.5.4"
@@ -15981,17 +15918,6 @@ dependencies = [
"symphonia-core",
]
[[package]]
name = "symphonia-codec-vorbis"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30"
dependencies = [
"log",
"symphonia-core",
"symphonia-utils-xiph",
]
[[package]]
name = "symphonia-core"
version = "0.5.4"
@@ -16005,31 +15931,6 @@ dependencies = [
"log",
]
[[package]]
name = "symphonia-format-isomp4"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844"
dependencies = [
"encoding_rs",
"log",
"symphonia-core",
"symphonia-metadata",
"symphonia-utils-xiph",
]
[[package]]
name = "symphonia-format-ogg"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ada3505789516bcf00fc1157c67729eded428b455c27ca370e41f4d785bfa931"
dependencies = [
"log",
"symphonia-core",
"symphonia-metadata",
"symphonia-utils-xiph",
]
[[package]]
name = "symphonia-format-riff"
version = "0.5.4"
@@ -16054,16 +15955,6 @@ dependencies = [
"symphonia-core",
]
[[package]]
name = "symphonia-utils-xiph"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe"
dependencies = [
"symphonia-core",
"symphonia-metadata",
]
[[package]]
name = "syn"
version = "1.0.109"
@@ -19954,7 +19845,6 @@ dependencies = [
"core-foundation-sys",
"cranelift-codegen",
"crc32fast",
"crossbeam-channel",
"crossbeam-epoch",
"crossbeam-utils",
"crypto-common",
@@ -19998,7 +19888,6 @@ dependencies = [
"libsqlite3-sys",
"linux-raw-sys 0.4.15",
"linux-raw-sys 0.9.4",
"livekit-runtime",
"log",
"lyon",
"lyon_path",
@@ -20484,7 +20373,6 @@ dependencies = [
"acp_tools",
"activity_indicator",
"agent",
"agent_servers",
"agent_settings",
"agent_ui",
"anyhow",
@@ -20548,11 +20436,13 @@ dependencies = [
"language_extension",
"language_model",
"language_models",
"language_onboarding",
"language_selector",
"language_tools",
"languages",
"libc",
"line_ending_selector",
"livekit_client",
"log",
"markdown",
"markdown_preview",

View File

@@ -94,6 +94,7 @@ members = [
"crates/language_extension",
"crates/language_model",
"crates/language_models",
"crates/language_onboarding",
"crates/language_selector",
"crates/language_tools",
"crates/languages",
@@ -276,7 +277,6 @@ context_server = { path = "crates/context_server" }
copilot = { path = "crates/copilot" }
crashes = { path = "crates/crashes" }
credentials_provider = { path = "crates/credentials_provider" }
crossbeam = "0.8.4"
dap = { path = "crates/dap" }
dap_adapters = { path = "crates/dap_adapters" }
db = { path = "crates/db" }
@@ -321,6 +321,7 @@ language = { path = "crates/language" }
language_extension = { path = "crates/language_extension" }
language_model = { path = "crates/language_model" }
language_models = { path = "crates/language_models" }
language_onboarding = { path = "crates/language_onboarding" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }
@@ -368,7 +369,7 @@ remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
reqwest_client = { path = "crates/reqwest_client" }
rich_text = { path = "crates/rich_text" }
rodio = { git = "https://github.com/RustAudio/rodio", branch = "better_wav_output"}
rodio = { version = "0.21.1", default-features = false }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
rules_library = { path = "crates/rules_library" }

View File

@@ -32,34 +32,6 @@
"(": "vim::SentenceBackward",
")": "vim::SentenceForward",
"|": "vim::GoToColumn",
"] ]": "vim::NextSectionStart",
"] [": "vim::NextSectionEnd",
"[ [": "vim::PreviousSectionStart",
"[ ]": "vim::PreviousSectionEnd",
"] m": "vim::NextMethodStart",
"] shift-m": "vim::NextMethodEnd",
"[ m": "vim::PreviousMethodStart",
"[ shift-m": "vim::PreviousMethodEnd",
"[ *": "vim::PreviousComment",
"[ /": "vim::PreviousComment",
"] *": "vim::NextComment",
"] /": "vim::NextComment",
"[ -": "vim::PreviousLesserIndent",
"[ +": "vim::PreviousGreaterIndent",
"[ =": "vim::PreviousSameIndent",
"] -": "vim::NextLesserIndent",
"] +": "vim::NextGreaterIndent",
"] =": "vim::NextSameIndent",
"] b": "pane::ActivateNextItem",
"[ b": "pane::ActivatePreviousItem",
"] shift-b": "pane::ActivateLastItem",
"[ shift-b": ["pane::ActivateItem", 0],
"] space": "vim::InsertEmptyLineBelow",
"[ space": "vim::InsertEmptyLineAbove",
"[ e": "editor::MoveLineUp",
"] e": "editor::MoveLineDown",
"[ f": "workspace::FollowNextCollaborator",
"] f": "workspace::FollowNextCollaborator",
// Word motions
"w": "vim::NextWordStart",
@@ -83,10 +55,6 @@
"n": "vim::MoveToNextMatch",
"shift-n": "vim::MoveToPreviousMatch",
"%": "vim::Matching",
"] }": ["vim::UnmatchedForward", { "char": "}" }],
"[ {": ["vim::UnmatchedBackward", { "char": "{" }],
"] )": ["vim::UnmatchedForward", { "char": ")" }],
"[ (": ["vim::UnmatchedBackward", { "char": "(" }],
"f": ["vim::PushFindForward", { "before": false, "multiline": false }],
"t": ["vim::PushFindForward", { "before": true, "multiline": false }],
"shift-f": ["vim::PushFindBackward", { "after": false, "multiline": false }],
@@ -219,6 +187,46 @@
".": "vim::Repeat"
}
},
{
"context": "vim_mode == normal || vim_mode == visual || vim_mode == operator",
"bindings": {
"] ]": "vim::NextSectionStart",
"] [": "vim::NextSectionEnd",
"[ [": "vim::PreviousSectionStart",
"[ ]": "vim::PreviousSectionEnd",
"] m": "vim::NextMethodStart",
"] shift-m": "vim::NextMethodEnd",
"[ m": "vim::PreviousMethodStart",
"[ shift-m": "vim::PreviousMethodEnd",
"[ *": "vim::PreviousComment",
"[ /": "vim::PreviousComment",
"] *": "vim::NextComment",
"] /": "vim::NextComment",
"[ -": "vim::PreviousLesserIndent",
"[ +": "vim::PreviousGreaterIndent",
"[ =": "vim::PreviousSameIndent",
"] -": "vim::NextLesserIndent",
"] +": "vim::NextGreaterIndent",
"] =": "vim::NextSameIndent",
"] b": "pane::ActivateNextItem",
"[ b": "pane::ActivatePreviousItem",
"] shift-b": "pane::ActivateLastItem",
"[ shift-b": ["pane::ActivateItem", 0],
"] space": "vim::InsertEmptyLineBelow",
"[ space": "vim::InsertEmptyLineAbove",
"[ e": "editor::MoveLineUp",
"] e": "editor::MoveLineDown",
"[ f": "workspace::FollowNextCollaborator",
"] f": "workspace::FollowNextCollaborator",
"] }": ["vim::UnmatchedForward", { "char": "}" }],
"[ {": ["vim::UnmatchedBackward", { "char": "{" }],
"] )": ["vim::UnmatchedForward", { "char": ")" }],
"[ (": ["vim::UnmatchedBackward", { "char": "(" }],
// tree-sitter related commands
"[ x": "vim::SelectLargerSyntaxNode",
"] x": "vim::SelectSmallerSyntaxNode"
}
},
{
"context": "vim_mode == normal",
"bindings": {
@@ -249,9 +257,6 @@
"g w": "vim::PushRewrap",
"g q": "vim::PushRewrap",
"insert": "vim::InsertBefore",
// tree-sitter related commands
"[ x": "vim::SelectLargerSyntaxNode",
"] x": "vim::SelectSmallerSyntaxNode",
"] d": "editor::GoToDiagnostic",
"[ d": "editor::GoToPreviousDiagnostic",
"] c": "editor::GoToHunk",
@@ -317,10 +322,7 @@
"g w": "vim::Rewrap",
"g ?": "vim::ConvertToRot13",
// "g ?": "vim::ConvertToRot47",
"\"": "vim::PushRegister",
// tree-sitter related commands
"[ x": "editor::SelectLargerSyntaxNode",
"] x": "editor::SelectSmallerSyntaxNode"
"\"": "vim::PushRegister"
}
},
{
@@ -397,6 +399,9 @@
"ctrl-[": "editor::Cancel",
";": "vim::HelixCollapseSelection",
":": "command_palette::Toggle",
"m": "vim::PushHelixMatch",
"]": ["vim::PushHelixNext", { "around": true }],
"[": ["vim::PushHelixPrevious", { "around": true }],
"left": "vim::WrappingLeft",
"right": "vim::WrappingRight",
"h": "vim::WrappingLeft",
@@ -419,13 +424,6 @@
"insert": "vim::InsertBefore",
"alt-.": "vim::RepeatFind",
"alt-s": ["editor::SplitSelectionIntoLines", { "keep_selections": true }],
// tree-sitter related commands
"[ x": "editor::SelectLargerSyntaxNode",
"] x": "editor::SelectSmallerSyntaxNode",
"] d": "editor::GoToDiagnostic",
"[ d": "editor::GoToPreviousDiagnostic",
"] c": "editor::GoToHunk",
"[ c": "editor::GoToPreviousHunk",
// Goto mode
"g n": "pane::ActivateNextItem",
"g p": "pane::ActivatePreviousItem",
@@ -469,9 +467,6 @@
"space c": "editor::ToggleComments",
"space y": "editor::Copy",
"space p": "editor::Paste",
// Match mode
"m m": "vim::Matching",
"m i w": ["workspace::SendKeystrokes", "v i w"],
"shift-u": "editor::Redo",
"ctrl-c": "editor::ToggleComments",
"d": "vim::HelixDelete",
@@ -540,7 +535,7 @@
}
},
{
"context": "vim_operator == a || vim_operator == i || vim_operator == cs",
"context": "vim_operator == a || vim_operator == i || vim_operator == cs || vim_operator == helix_next || vim_operator == helix_previous",
"bindings": {
"w": "vim::Word",
"shift-w": ["vim::Word", { "ignore_punctuation": true }],
@@ -577,6 +572,48 @@
"e": "vim::EntireFile"
}
},
{
"context": "vim_operator == helix_m",
"bindings": {
"m": "vim::Matching"
}
},
{
"context": "vim_operator == helix_next",
"bindings": {
"z": "vim::NextSectionStart",
"shift-z": "vim::NextSectionEnd",
"*": "vim::NextComment",
"/": "vim::NextComment",
"-": "vim::NextLesserIndent",
"+": "vim::NextGreaterIndent",
"=": "vim::NextSameIndent",
"b": "pane::ActivateNextItem",
"shift-b": "pane::ActivateLastItem",
"x": "editor::SelectSmallerSyntaxNode",
"d": "editor::GoToDiagnostic",
"c": "editor::GoToHunk",
"space": "vim::InsertEmptyLineBelow"
}
},
{
"context": "vim_operator == helix_previous",
"bindings": {
"z": "vim::PreviousSectionStart",
"shift-z": "vim::PreviousSectionEnd",
"*": "vim::PreviousComment",
"/": "vim::PreviousComment",
"-": "vim::PreviousLesserIndent",
"+": "vim::PreviousGreaterIndent",
"=": "vim::PreviousSameIndent",
"b": "pane::ActivatePreviousItem",
"shift-b": ["pane::ActivateItem", 0],
"x": "editor::SelectLargerSyntaxNode",
"d": "editor::GoToPreviousDiagnostic",
"c": "editor::GoToPreviousHunk",
"space": "vim::InsertEmptyLineAbove"
}
},
{
"context": "vim_operator == c",
"bindings": {

View File

@@ -740,16 +740,6 @@
// Default width of the collaboration panel.
"default_width": 240
},
"chat_panel": {
// When to show the chat panel button in the status bar.
// Can be 'never', 'always', or 'when_in_call',
// or a boolean (interpreted as 'never'/'always').
"button": "when_in_call",
// Where to dock the chat panel. Can be 'left' or 'right'.
"dock": "right",
// Default width of the chat panel.
"default_width": 240
},
"git_panel": {
// Whether to show the git panel button in the status bar.
"button": true,

View File

@@ -1640,13 +1640,13 @@ impl AcpThread {
cx.foreground_executor().spawn(send_task)
}
/// Rewinds this thread to before the entry at `index`, removing it and all
/// subsequent entries while reverting any changes made from that point.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
return Task::ready(Err(anyhow!("not supported")));
};
let Some(message) = self.user_message(&id) else {
/// Restores the git working tree to the state at the given checkpoint (if one exists)
pub fn restore_checkpoint(
&mut self,
id: UserMessageId,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some((_, message)) = self.user_message_mut(&id) else {
return Task::ready(Err(anyhow!("message not found")));
};
@@ -1654,15 +1654,30 @@ impl AcpThread {
.checkpoint
.as_ref()
.map(|c| c.git_checkpoint.clone());
let rewind = self.rewind(id.clone(), cx);
let git_store = self.project.read(cx).git_store().clone();
cx.spawn(async move |this, cx| {
cx.spawn(async move |_, cx| {
rewind.await?;
if let Some(checkpoint) = checkpoint {
git_store
.update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
.await?;
}
Ok(())
})
}
/// Rewinds this thread to before the entry at `index`, removing it and all
/// subsequent entries while rejecting any action_log changes made from that point.
/// Unlike `restore_checkpoint`, this method does not restore from git.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
return Task::ready(Err(anyhow!("not supported")));
};
cx.spawn(async move |this, cx| {
cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
this.update(cx, |this, cx| {
if let Some((ix, _)) = this.user_message_mut(&id) {
@@ -1670,7 +1685,11 @@ impl AcpThread {
this.entries.truncate(ix);
cx.emit(AcpThreadEvent::EntriesRemoved(range));
}
})
this.action_log()
.update(cx, |action_log, cx| action_log.reject_all_edits(cx))
})?
.await;
Ok(())
})
}
@@ -1727,20 +1746,6 @@ impl AcpThread {
})
}
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry {
if message.id.as_ref() == Some(id) {
Some(message)
} else {
None
}
} else {
None
}
})
}
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry {
@@ -2684,7 +2689,7 @@ mod tests {
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
panic!("unexpected entries {:?}", thread.entries)
};
thread.rewind(message.id.clone().unwrap(), cx)
thread.restore_checkpoint(message.id.clone().unwrap(), cx)
})
.await
.unwrap();
@@ -2758,7 +2763,7 @@ mod tests {
}));
let thread = cx
.update(|cx| connection.new_thread(project, Path::new("/test"), cx))
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
.await
.unwrap();

View File

@@ -35,10 +35,15 @@ impl AgentServer for NativeAgentServer {
fn connect(
&self,
_root_dir: &Path,
_root_dir: Option<&Path>,
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
) -> Task<
Result<(
Rc<dyn acp_thread::AgentConnection>,
Option<task::SpawnInTerminal>,
)>,
> {
log::debug!(
"NativeAgentServer::connect called for path: {:?}",
_root_dir
@@ -60,7 +65,10 @@ impl AgentServer for NativeAgentServer {
let connection = NativeAgentConnection(agent);
log::debug!("NativeAgentServer connection established successfully");
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
Ok((
Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
None,
))
})
}

View File

@@ -24,7 +24,11 @@ impl AgentTool for EchoTool {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Echo".into()
}
@@ -55,7 +59,11 @@ impl AgentTool for DelayTool {
"delay"
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
format!("Delay {}ms", input.ms).into()
} else {
@@ -100,7 +108,11 @@ impl AgentTool for ToolRequiringPermission {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"This tool requires permission".into()
}
@@ -135,7 +147,11 @@ impl AgentTool for InfiniteTool {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Infinite Tool".into()
}
@@ -186,7 +202,11 @@ impl AgentTool for WordListTool {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"List of random words".into()
}

View File

@@ -741,7 +741,7 @@ impl Thread {
return;
};
let title = tool.initial_title(tool_use.input.clone());
let title = tool.initial_title(tool_use.input.clone(), cx);
let kind = tool.kind();
stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
@@ -1062,7 +1062,11 @@ impl Thread {
self.action_log.clone(),
));
self.add_tool(DiagnosticsTool::new(self.project.clone()));
self.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
self.add_tool(EditFileTool::new(
self.project.clone(),
cx.weak_entity(),
language_registry,
));
self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));
self.add_tool(FindPathTool::new(self.project.clone()));
self.add_tool(GrepTool::new(self.project.clone()));
@@ -1514,7 +1518,7 @@ impl Thread {
let mut title = SharedString::from(&tool_use.name);
let mut kind = acp::ToolKind::Other;
if let Some(tool) = tool.as_ref() {
title = tool.initial_title(tool_use.input.clone());
title = tool.initial_title(tool_use.input.clone(), cx);
kind = tool.kind();
}
@@ -2148,7 +2152,11 @@ where
fn kind() -> acp::ToolKind;
/// The initial tool title to display. Can be updated during the tool run.
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString;
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
cx: &mut App,
) -> SharedString;
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema {
@@ -2196,7 +2204,7 @@ pub trait AnyAgentTool {
fn name(&self) -> SharedString;
fn description(&self) -> SharedString;
fn kind(&self) -> acp::ToolKind;
fn initial_title(&self, input: serde_json::Value) -> SharedString;
fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString;
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool {
true
@@ -2232,9 +2240,9 @@ where
T::kind()
}
fn initial_title(&self, input: serde_json::Value) -> SharedString {
fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString {
let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input);
self.0.initial_title(parsed_input)
self.0.initial_title(parsed_input, _cx)
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {

View File

@@ -145,7 +145,7 @@ impl AnyAgentTool for ContextServerTool {
ToolKind::Other
}
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
fn initial_title(&self, _input: serde_json::Value, _cx: &mut App) -> SharedString {
format!("Run MCP tool `{}`", self.tool.name).into()
}
@@ -176,7 +176,7 @@ impl AnyAgentTool for ContextServerTool {
return Task::ready(Err(anyhow!("Context server not found")));
};
let tool_name = self.tool.name.clone();
let authorize = event_stream.authorize(self.initial_title(input.clone()), cx);
let authorize = event_stream.authorize(self.initial_title(input.clone(), cx), cx);
cx.spawn(async move |_cx| {
authorize.await?;

View File

@@ -58,7 +58,11 @@ impl AgentTool for CopyPathTool {
ToolKind::Move
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> ui::SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> ui::SharedString {
if let Ok(input) = input {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);

View File

@@ -49,7 +49,11 @@ impl AgentTool for CreateDirectoryTool {
ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
format!("Create directory {}", MarkdownInlineCode(&input.path)).into()
} else {

View File

@@ -52,7 +52,11 @@ impl AgentTool for DeletePathTool {
ToolKind::Delete
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
format!("Delete “`{}`”", input.path).into()
} else {

View File

@@ -71,7 +71,11 @@ impl AgentTool for DiagnosticsTool {
acp::ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Some(path) = input.ok().and_then(|input| match input.path {
Some(path) if !path.is_empty() => Some(path),
_ => None,

View File

@@ -120,11 +120,17 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
pub struct EditFileTool {
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
project: Entity<Project>,
}
impl EditFileTool {
pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
pub fn new(
project: Entity<Project>,
thread: WeakEntity<Thread>,
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
project,
thread,
language_registry,
}
@@ -195,22 +201,50 @@ impl AgentTool for EditFileTool {
acp::ToolKind::Edit
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
cx: &mut App,
) -> SharedString {
match input {
Ok(input) => input.display_description.into(),
Ok(input) => self
.project
.read(cx)
.find_project_path(&input.path, cx)
.and_then(|project_path| {
self.project
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
})
.unwrap_or(Path::new(&input.path).into())
.to_string_lossy()
.to_string()
.into(),
Err(raw_input) => {
if let Some(input) =
serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
{
let path = input.path.trim();
if !path.is_empty() {
return self
.project
.read(cx)
.find_project_path(&input.path, cx)
.and_then(|project_path| {
self.project
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
})
.unwrap_or(Path::new(&input.path).into())
.to_string_lossy()
.to_string()
.into();
}
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string().into();
}
let path = input.path.trim().to_string();
if !path.is_empty() {
return path.into();
}
}
DEFAULT_UI_TEXT.into()
@@ -545,7 +579,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
@@ -560,11 +594,12 @@ mod tests {
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
};
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
Arc::new(EditFileTool::new(
project,
thread.downgrade(),
language_registry,
))
.run(input, ToolCallEventStream::test().0, cx)
})
.await;
assert_eq!(
@@ -743,7 +778,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
@@ -775,6 +810,7 @@ mod tests {
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry.clone(),
))
@@ -833,11 +869,12 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
))
.run(input, ToolCallEventStream::test().0, cx)
});
// Stream the unformatted content
@@ -885,7 +922,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
@@ -918,6 +955,7 @@ mod tests {
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry.clone(),
))
@@ -969,11 +1007,12 @@ mod tests {
path: "root/src/main.rs".into(),
mode: EditFileMode::Overwrite,
};
Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
input,
ToolCallEventStream::test().0,
cx,
)
Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
))
.run(input, ToolCallEventStream::test().0, cx)
});
// Stream the content with trailing whitespace
@@ -1012,7 +1051,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
@@ -1020,7 +1059,11 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
));
fs.insert_tree("/root", json!({})).await;
// Test 1: Path with .zed component should require confirmation
@@ -1148,7 +1191,7 @@ mod tests {
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
project.clone(),
cx.new(|_cx| ProjectContext::default()),
context_server_registry,
Templates::new(),
@@ -1156,7 +1199,11 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
));
// Test global config paths - these should require confirmation if they exist and are outside the project
let test_cases = vec![
@@ -1264,7 +1311,11 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
));
// Test files in different worktrees
let test_cases = vec![
@@ -1344,7 +1395,11 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
));
// Test edge cases
let test_cases = vec![
@@ -1427,7 +1482,11 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
language_registry,
));
// Test different EditFileMode values
let modes = vec![
@@ -1507,48 +1566,67 @@ mod tests {
cx,
)
});
let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
let tool = Arc::new(EditFileTool::new(
project,
thread.downgrade(),
language_registry,
));
assert_eq!(
tool.initial_title(Err(json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
}))),
"src/main.rs"
);
assert_eq!(
tool.initial_title(Err(json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
}))),
"Fix error handling"
);
assert_eq!(
tool.initial_title(Err(json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
}))),
"Fix error handling"
);
assert_eq!(
tool.initial_title(Err(json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
}))),
DEFAULT_UI_TEXT
);
assert_eq!(
tool.initial_title(Err(serde_json::Value::Null)),
DEFAULT_UI_TEXT
);
cx.update(|cx| {
// ...
assert_eq!(
tool.initial_title(
Err(json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
})),
cx
),
"src/main.rs"
);
assert_eq!(
tool.initial_title(
Err(json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
})),
cx
),
"Fix error handling"
);
assert_eq!(
tool.initial_title(
Err(json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
})),
cx
),
"src/main.rs"
);
assert_eq!(
tool.initial_title(
Err(json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
})),
cx
),
DEFAULT_UI_TEXT
);
assert_eq!(
tool.initial_title(Err(serde_json::Value::Null), cx),
DEFAULT_UI_TEXT
);
});
}
#[gpui::test]
@@ -1575,7 +1653,11 @@ mod tests {
// Ensure the diff is finalized after the edit completes.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
languages.clone(),
));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
@@ -1600,7 +1682,11 @@ mod tests {
// Ensure the diff is finalized if an error occurs while editing.
{
model.forbid_requests();
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
languages.clone(),
));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(
@@ -1623,7 +1709,11 @@ mod tests {
// Ensure the diff is finalized if the tool call gets dropped.
{
let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
let tool = Arc::new(EditFileTool::new(
project.clone(),
thread.downgrade(),
languages.clone(),
));
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
let edit = cx.update(|cx| {
tool.run(

View File

@@ -126,7 +126,11 @@ impl AgentTool for FetchTool {
acp::ToolKind::Fetch
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
match input {
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
Err(_) => "Fetch URL".into(),

View File

@@ -93,7 +93,11 @@ impl AgentTool for FindPathTool {
acp::ToolKind::Search
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
let mut title = "Find paths".to_string();
if let Ok(input) = input {
title.push_str(&format!(" matching “`{}`”", input.glob));

View File

@@ -75,7 +75,11 @@ impl AgentTool for GrepTool {
acp::ToolKind::Search
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
match input {
Ok(input) => {
let page = input.page();

View File

@@ -59,7 +59,11 @@ impl AgentTool for ListDirectoryTool {
ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
let path = MarkdownInlineCode(&input.path);
format!("List the {path} directory's contents").into()

View File

@@ -60,7 +60,11 @@ impl AgentTool for MovePathTool {
ToolKind::Move
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
let src = MarkdownInlineCode(&input.source_path);
let dest = MarkdownInlineCode(&input.destination_path);

View File

@@ -41,7 +41,11 @@ impl AgentTool for NowTool {
acp::ToolKind::Other
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Get current time".into()
}

View File

@@ -45,7 +45,11 @@ impl AgentTool for OpenTool {
ToolKind::Execute
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
format!("Open `{}`", MarkdownEscaped(&input.path_or_url)).into()
} else {
@@ -61,7 +65,7 @@ impl AgentTool for OpenTool {
) -> Task<Result<Self::Output>> {
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx);
cx.background_spawn(async move {
authorize.await?;

View File

@@ -10,7 +10,7 @@ use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{path::Path, sync::Arc};
use std::sync::Arc;
use util::markdown::MarkdownCodeBlock;
use crate::{AgentTool, ToolCallEventStream};
@@ -68,13 +68,31 @@ impl AgentTool for ReadFileTool {
acp::ToolKind::Read
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
input
.ok()
.as_ref()
.and_then(|input| Path::new(&input.path).file_name())
.map(|file_name| file_name.to_string_lossy().to_string().into())
.unwrap_or_default()
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
cx: &mut App,
) -> SharedString {
if let Ok(input) = input
&& let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx)
&& let Some(path) = self
.project
.read(cx)
.short_full_path_for_project_path(&project_path, cx)
{
match (input.start_line, input.end_line) {
(Some(start), Some(end)) => {
format!("Read file `{}` (lines {}-{})", path.display(), start, end,)
}
(Some(start), None) => {
format!("Read file `{}` (from line {})", path.display(), start)
}
_ => format!("Read file `{}`", path.display()),
}
.into()
} else {
"Read file".into()
}
}
fn run(
@@ -86,6 +104,12 @@ impl AgentTool for ReadFileTool {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!("Path {} not found in project", &input.path)));
};
let Some(abs_path) = self.project.read(cx).absolute_path(&project_path, cx) else {
return Task::ready(Err(anyhow!(
"Failed to convert {} to absolute path",
&input.path
)));
};
// Error out if this path is either excluded or private in global settings
let global_settings = WorktreeSettings::get_global(cx);
@@ -121,6 +145,14 @@ impl AgentTool for ReadFileTool {
let file_path = input.path.clone();
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: input.start_line.map(|line| line.saturating_sub(1)),
}]),
..Default::default()
});
if image_store::is_image_file(&self.project, &project_path, cx) {
return cx.spawn(async move |cx| {
let image_entity: Entity<ImageItem> = cx
@@ -229,34 +261,25 @@ impl AgentTool for ReadFileTool {
};
project.update(cx, |project, cx| {
if let Some(abs_path) = project.absolute_path(&project_path, cx) {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor.unwrap_or(text::Anchor::MIN),
}),
cx,
);
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor.unwrap_or(text::Anchor::MIN),
}),
cx,
);
if let Ok(LanguageModelToolResultContent::Text(text)) = &result {
let markdown = MarkdownCodeBlock {
tag: &input.path,
text,
}
.to_string();
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: input.start_line.map(|line| line.saturating_sub(1)),
content: Some(vec![acp::ToolCallContent::Content {
content: markdown.into(),
}]),
..Default::default()
});
if let Ok(LanguageModelToolResultContent::Text(text)) = &result {
let markdown = MarkdownCodeBlock {
tag: &input.path,
text,
}
.to_string();
event_stream.update_fields(ToolCallUpdateFields {
content: Some(vec![acp::ToolCallContent::Content {
content: markdown.into(),
}]),
..Default::default()
})
}
})
}
})?;

View File

@@ -60,7 +60,11 @@ impl AgentTool for TerminalTool {
acp::ToolKind::Execute
}
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
if let Ok(input) = input {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
@@ -93,7 +97,7 @@ impl AgentTool for TerminalTool {
Err(err) => return Task::ready(Err(err)),
};
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
let authorize = event_stream.authorize(self.initial_title(Ok(input.clone()), cx), cx);
cx.spawn(async move |cx| {
authorize.await?;

View File

@@ -29,7 +29,11 @@ impl AgentTool for ThinkingTool {
acp::ToolKind::Think
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Thinking".into()
}

View File

@@ -48,7 +48,11 @@ impl AgentTool for WebSearchTool {
acp::ToolKind::Fetch
}
fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
fn initial_title(
&self,
_input: Result<Self::Input, serde_json::Value>,
_cx: &mut App,
) -> SharedString {
"Searching the Web".into()
}

View File

@@ -23,7 +23,7 @@ action_log.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
client = { workspace = true, optional = true }
client.workspace = true
collections.workspace = true
env_logger = { workspace = true, optional = true }
fs.workspace = true
@@ -35,22 +35,18 @@ language.workspace = true
language_model.workspace = true
language_models.workspace = true
log.workspace = true
node_runtime.workspace = true
paths.workspace = true
project.workspace = true
reqwest_client = { workspace = true, optional = true }
schemars.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
task.workspace = true
tempfile.workspace = true
thiserror.workspace = true
ui.workspace = true
util.workspace = true
watch.workspace = true
which.workspace = true
workspace-hack.workspace = true
[target.'cfg(unix)'.dependencies]

View File

@@ -1,4 +1,3 @@
use crate::AgentServerCommand;
use acp_thread::AgentConnection;
use acp_tools::AcpConnectionRegistry;
use action_log::ActionLog;
@@ -8,8 +7,10 @@ use collections::HashMap;
use futures::AsyncBufReadExt as _;
use futures::io::BufReader;
use project::Project;
use project::agent_server_store::AgentServerCommand;
use serde::Deserialize;
use std::path::PathBuf;
use std::{any::Any, cell::RefCell};
use std::{path::Path, rc::Rc};
use thiserror::Error;
@@ -29,6 +30,7 @@ pub struct AcpConnection {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
agent_capabilities: acp::AgentCapabilities,
root_dir: PathBuf,
_io_task: Task<Result<()>>,
_wait_task: Task<Result<()>>,
_stderr_task: Task<Result<()>>,
@@ -43,9 +45,10 @@ pub async fn connect(
server_name: SharedString,
command: AgentServerCommand,
root_dir: &Path,
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Rc<dyn AgentConnection>> {
let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, is_remote, cx).await?;
Ok(Rc::new(conn) as _)
}
@@ -56,17 +59,21 @@ impl AcpConnection {
server_name: SharedString,
command: AgentServerCommand,
root_dir: &Path,
is_remote: bool,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(command.path)
let mut child = util::command::new_smol_command(command.path);
child
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
.kill_on_drop(true);
if !is_remote {
child.current_dir(root_dir);
}
let mut child = child.spawn()?;
let stdout = child.stdout.take().context("Failed to take stdout")?;
let stdin = child.stdin.take().context("Failed to take stdin")?;
@@ -145,6 +152,7 @@ impl AcpConnection {
Ok(Self {
auth_methods: response.auth_methods,
root_dir: root_dir.to_owned(),
connection,
server_name,
sessions,
@@ -158,6 +166,10 @@ impl AcpConnection {
pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
&self.agent_capabilities.prompt_capabilities
}
pub fn root_dir(&self) -> &Path {
&self.root_dir
}
}
impl AgentConnection for AcpConnection {
@@ -171,29 +183,36 @@ impl AgentConnection for AcpConnection {
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
let context_server_store = project.read(cx).context_server_store().read(cx);
let mcp_servers = context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
let command = configuration.command();
Some(acp::McpServer {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
})
.collect()
} else {
vec![]
},
let mcp_servers = if project.read(cx).is_local() {
context_server_store
.configured_server_ids()
.iter()
.filter_map(|id| {
let configuration = context_server_store.configuration_for_server(id)?;
let command = configuration.command();
Some(acp::McpServer {
name: id.0.to_string(),
command: command.path.clone(),
args: command.args.clone(),
env: if let Some(env) = command.env.as_ref() {
env.iter()
.map(|(name, value)| acp::EnvVariable {
name: name.clone(),
value: value.clone(),
})
.collect()
} else {
vec![]
},
})
})
})
.collect();
.collect()
} else {
// In SSH projects, the external agent is running on the remote
// machine, and currently we only run MCP servers on the local
// machine. So don't pass any MCP servers to the agent in that case.
Vec::new()
};
cx.spawn(async move |cx| {
let response = conn

View File

@@ -2,47 +2,25 @@ mod acp;
mod claude;
mod custom;
mod gemini;
mod settings;
#[cfg(any(test, feature = "test-support"))]
pub mod e2e_tests;
use anyhow::Context as _;
pub use claude::*;
pub use custom::*;
use fs::Fs;
use fs::RemoveOptions;
use fs::RenameOptions;
use futures::StreamExt as _;
pub use gemini::*;
use gpui::AppContext;
use node_runtime::NodeRuntime;
pub use settings::*;
use project::agent_server_store::AgentServerStore;
use acp_thread::AgentConnection;
use acp_thread::LoadError;
use anyhow::Result;
use anyhow::anyhow;
use collections::HashMap;
use gpui::{App, AsyncApp, Entity, SharedString, Task};
use gpui::{App, Entity, SharedString, Task};
use project::Project;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize};
use std::str::FromStr as _;
use std::{
any::Any,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
};
use util::ResultExt as _;
use std::{any::Any, path::Path, rc::Rc};
pub fn init(cx: &mut App) {
settings::init(cx);
}
pub use acp::AcpConnection;
pub struct AgentServerDelegate {
store: Entity<AgentServerStore>,
project: Entity<Project>,
status_tx: Option<watch::Sender<SharedString>>,
new_version_available: Option<watch::Sender<Option<String>>>,
@@ -50,11 +28,13 @@ pub struct AgentServerDelegate {
impl AgentServerDelegate {
pub fn new(
store: Entity<AgentServerStore>,
project: Entity<Project>,
status_tx: Option<watch::Sender<SharedString>>,
new_version_tx: Option<watch::Sender<Option<String>>>,
) -> Self {
Self {
store,
project,
status_tx,
new_version_available: new_version_tx,
@@ -64,188 +44,6 @@ impl AgentServerDelegate {
pub fn project(&self) -> &Entity<Project> {
&self.project
}
fn get_or_npm_install_builtin_agent(
self,
binary_name: SharedString,
package_name: SharedString,
entrypoint_path: PathBuf,
ignore_system_version: bool,
minimum_version: Option<Version>,
cx: &mut App,
) -> Task<Result<AgentServerCommand>> {
let project = self.project;
let fs = project.read(cx).fs().clone();
let Some(node_runtime) = project.read(cx).node_runtime().cloned() else {
return Task::ready(Err(anyhow!(
"External agents are not yet available in remote projects."
)));
};
let status_tx = self.status_tx;
let new_version_available = self.new_version_available;
cx.spawn(async move |cx| {
if !ignore_system_version {
if let Some(bin) = find_bin_in_path(binary_name.clone(), &project, cx).await {
return Ok(AgentServerCommand {
path: bin,
args: Vec::new(),
env: Default::default(),
});
}
}
cx.spawn(async move |cx| {
let node_path = node_runtime.binary_path().await?;
let dir = paths::data_dir()
.join("external_agents")
.join(binary_name.as_str());
fs.create_dir(&dir).await?;
let mut stream = fs.read_dir(&dir).await?;
let mut versions = Vec::new();
let mut to_delete = Vec::new();
while let Some(entry) = stream.next().await {
let Ok(entry) = entry else { continue };
let Some(file_name) = entry.file_name() else {
continue;
};
if let Some(name) = file_name.to_str()
&& let Some(version) = semver::Version::from_str(name).ok()
&& fs
.is_file(&dir.join(file_name).join(&entrypoint_path))
.await
{
versions.push((version, file_name.to_owned()));
} else {
to_delete.push(file_name.to_owned())
}
}
versions.sort();
let newest_version = if let Some((version, file_name)) = versions.last().cloned()
&& minimum_version.is_none_or(|minimum_version| version >= minimum_version)
{
versions.pop();
Some(file_name)
} else {
None
};
log::debug!("existing version of {package_name}: {newest_version:?}");
to_delete.extend(versions.into_iter().map(|(_, file_name)| file_name));
cx.background_spawn({
let fs = fs.clone();
let dir = dir.clone();
async move {
for file_name in to_delete {
fs.remove_dir(
&dir.join(file_name),
RemoveOptions {
recursive: true,
ignore_if_not_exists: false,
},
)
.await
.ok();
}
}
})
.detach();
let version = if let Some(file_name) = newest_version {
cx.background_spawn({
let file_name = file_name.clone();
let dir = dir.clone();
let fs = fs.clone();
async move {
let latest_version =
node_runtime.npm_package_latest_version(&package_name).await;
if let Ok(latest_version) = latest_version
&& &latest_version != &file_name.to_string_lossy()
{
Self::download_latest_version(
fs,
dir.clone(),
node_runtime,
package_name,
)
.await
.log_err();
if let Some(mut new_version_available) = new_version_available {
new_version_available.send(Some(latest_version)).ok();
}
}
}
})
.detach();
file_name
} else {
if let Some(mut status_tx) = status_tx {
status_tx.send("Installing…".into()).ok();
}
let dir = dir.clone();
cx.background_spawn(Self::download_latest_version(
fs.clone(),
dir.clone(),
node_runtime,
package_name,
))
.await?
.into()
};
let agent_server_path = dir.join(version).join(entrypoint_path);
let agent_server_path_exists = fs.is_file(&agent_server_path).await;
anyhow::ensure!(
agent_server_path_exists,
"Missing entrypoint path {} after installation",
agent_server_path.to_string_lossy()
);
anyhow::Ok(AgentServerCommand {
path: node_path,
args: vec![agent_server_path.to_string_lossy().to_string()],
env: Default::default(),
})
})
.await
.map_err(|e| LoadError::FailedToInstall(e.to_string().into()).into())
})
}
async fn download_latest_version(
fs: Arc<dyn Fs>,
dir: PathBuf,
node_runtime: NodeRuntime,
package_name: SharedString,
) -> Result<String> {
log::debug!("downloading latest version of {package_name}");
let tmp_dir = tempfile::tempdir_in(&dir)?;
node_runtime
.npm_install_packages(tmp_dir.path(), &[(&package_name, "latest")])
.await?;
let version = node_runtime
.npm_package_installed_version(tmp_dir.path(), &package_name)
.await?
.context("expected package to be installed")?;
fs.rename(
&tmp_dir.keep(),
&dir.join(&version),
RenameOptions {
ignore_if_exists: true,
overwrite: false,
},
)
.await?;
anyhow::Ok(version)
}
}
pub trait AgentServer: Send {
@@ -255,10 +53,10 @@ pub trait AgentServer: Send {
fn connect(
&self,
root_dir: &Path,
root_dir: Option<&Path>,
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>>;
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>>;
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
@@ -268,120 +66,3 @@ impl dyn AgentServer {
self.into_any().downcast().ok()
}
}
impl std::fmt::Debug for AgentServerCommand {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let filtered_env = self.env.as_ref().map(|env| {
env.iter()
.map(|(k, v)| {
(
k,
if util::redact::should_redact(k) {
"[REDACTED]"
} else {
v
},
)
})
.collect::<Vec<_>>()
});
f.debug_struct("AgentServerCommand")
.field("path", &self.path)
.field("args", &self.args)
.field("env", &filtered_env)
.finish()
}
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
pub struct AgentServerCommand {
#[serde(rename = "command")]
pub path: PathBuf,
#[serde(default)]
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}
impl AgentServerCommand {
pub async fn resolve(
path_bin_name: &'static str,
extra_args: &[&'static str],
fallback_path: Option<&Path>,
settings: Option<BuiltinAgentServerSettings>,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Option<Self> {
if let Some(settings) = settings
&& let Some(command) = settings.custom_command()
{
Some(command)
} else {
match find_bin_in_path(path_bin_name.into(), project, cx).await {
Some(path) => Some(Self {
path,
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
env: None,
}),
None => fallback_path.and_then(|path| {
if path.exists() {
Some(Self {
path: path.to_path_buf(),
args: extra_args.iter().map(|arg| arg.to_string()).collect(),
env: None,
})
} else {
None
}
}),
}
}
}
}
async fn find_bin_in_path(
bin_name: SharedString,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Option<PathBuf> {
let (env_task, root_dir) = project
.update(cx, |project, cx| {
let worktree = project.visible_worktrees(cx).next();
match worktree {
Some(worktree) => {
let env_task = project.environment().update(cx, |env, cx| {
env.get_worktree_environment(worktree.clone(), cx)
});
let path = worktree.read(cx).abs_path();
(env_task, path)
}
None => {
let path: Arc<Path> = paths::home_dir().as_path().into();
let env_task = project.environment().update(cx, |env, cx| {
env.get_directory_environment(path.clone(), cx)
});
(env_task, path)
}
}
})
.log_err()?;
cx.background_executor()
.spawn(async move {
let which_result = if cfg!(windows) {
which::which(bin_name.as_str())
} else {
let env = env_task.await.unwrap_or_default();
let shell_path = env.get("PATH").cloned();
which::which_in(bin_name.as_str(), shell_path.as_ref(), root_dir.as_ref())
};
if let Err(which::Error::CannotFindBinaryPath) = which_result {
return None;
}
which_result.log_err()
})
.await
}

View File

@@ -1,60 +1,22 @@
use settings::SettingsStore;
use std::path::Path;
use std::rc::Rc;
use std::{any::Any, path::PathBuf};
use anyhow::Result;
use gpui::{App, AppContext as _, SharedString, Task};
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
use project::agent_server_store::CLAUDE_CODE_NAME;
use crate::{AgentServer, AgentServerDelegate, AllAgentServersSettings};
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::AgentConnection;
#[derive(Clone)]
pub struct ClaudeCode;
pub struct ClaudeCodeLoginCommand {
pub struct AgentServerLoginCommand {
pub path: PathBuf,
pub arguments: Vec<String>,
}
impl ClaudeCode {
const BINARY_NAME: &'static str = "claude-code-acp";
const PACKAGE_NAME: &'static str = "@zed-industries/claude-code-acp";
pub fn login_command(
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<ClaudeCodeLoginCommand>> {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).claude.clone()
});
cx.spawn(async move |cx| {
let mut command = if let Some(settings) = settings {
settings.command
} else {
cx.update(|cx| {
delegate.get_or_npm_install_builtin_agent(
Self::BINARY_NAME.into(),
Self::PACKAGE_NAME.into(),
"node_modules/@anthropic-ai/claude-code/cli.js".into(),
true,
Some("0.2.5".parse().unwrap()),
cx,
)
})?
.await?
};
command.args.push("/login".into());
Ok(ClaudeCodeLoginCommand {
path: command.path,
arguments: command.args,
})
})
}
}
impl AgentServer for ClaudeCode {
fn telemetry_id(&self) -> &'static str {
"claude-code"
@@ -70,56 +32,42 @@ impl AgentServer for ClaudeCode {
fn connect(
&self,
root_dir: &Path,
root_dir: Option<&Path>,
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let fs = delegate.project().read(cx).fs().clone();
let server_name = self.name();
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).claude.clone()
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
// Get the project environment variables for the root directory
let project_env = delegate.project().update(cx, |project, cx| {
if let Some(path) = project.active_project_directory(cx) {
Some(project.directory_environment(path, cx))
} else {
None
}
});
let project = delegate.project().clone();
cx.spawn(async move |cx| {
let mut project_env = project
.update(cx, |project, cx| {
project.directory_environment(root_dir.as_path().into(), cx)
})?
.await
.unwrap_or_default();
let mut command = if let Some(settings) = settings {
settings.command
} else {
cx.update(|cx| {
delegate.get_or_npm_install_builtin_agent(
Self::BINARY_NAME.into(),
Self::PACKAGE_NAME.into(),
format!("node_modules/{}/dist/index.js", Self::PACKAGE_NAME).into(),
true,
None,
cx,
)
})?
.await?
};
project_env.extend(command.env.take().unwrap_or_default());
command.env = Some(project_env);
command
.env
.get_or_insert_default()
.insert("ANTHROPIC_API_KEY".to_owned(), "".to_owned());
let root_dir_exists = fs.is_dir(&root_dir).await;
anyhow::ensure!(
root_dir_exists,
"Session root {} does not exist or is not a directory",
root_dir.to_string_lossy()
);
crate::acp::connect(server_name, command.clone(), &root_dir, cx).await
let (command, root_dir, login) = store
.update(cx, |store, cx| {
let agent = store
.get_external_agent(&CLAUDE_CODE_NAME.into())
.context("Claude Code is not registered")?;
anyhow::Ok(agent.get_command(
root_dir.as_deref(),
Default::default(),
delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),
))
})??
.await?;
let connection =
crate::acp::connect(name, command, root_dir.as_ref(), is_remote, cx).await?;
Ok((connection, login))
})
}

View File

@@ -1,19 +1,19 @@
use crate::{AgentServerCommand, AgentServerDelegate};
use crate::AgentServerDelegate;
use acp_thread::AgentConnection;
use anyhow::Result;
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
use project::agent_server_store::ExternalAgentServerName;
use std::{path::Path, rc::Rc};
use ui::IconName;
/// A generic agent server implementation for custom user-defined agents
pub struct CustomAgentServer {
name: SharedString,
command: AgentServerCommand,
}
impl CustomAgentServer {
pub fn new(name: SharedString, command: AgentServerCommand) -> Self {
Self { name, command }
pub fn new(name: SharedString) -> Self {
Self { name }
}
}
@@ -32,14 +32,36 @@ impl crate::AgentServer for CustomAgentServer {
fn connect(
&self,
root_dir: &Path,
_delegate: AgentServerDelegate,
root_dir: Option<&Path>,
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let server_name = self.name();
let command = self.command.clone();
let root_dir = root_dir.to_path_buf();
cx.spawn(async move |cx| crate::acp::connect(server_name, command, &root_dir, cx).await)
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
cx.spawn(async move |cx| {
let (command, root_dir, login) = store
.update(cx, |store, cx| {
let agent = store
.get_external_agent(&ExternalAgentServerName(name.clone()))
.with_context(|| {
format!("Custom agent server `{}` is not registered", name)
})?;
anyhow::Ok(agent.get_command(
root_dir.as_deref(),
Default::default(),
delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),
))
})??
.await?;
let connection =
crate::acp::connect(name, command, root_dir.as_ref(), is_remote, cx).await?;
Ok((connection, login))
})
}
fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {

View File

@@ -1,12 +1,12 @@
use crate::{AgentServer, AgentServerDelegate};
#[cfg(test)]
use crate::{AgentServerCommand, CustomAgentServerSettings};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{AppContext, Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
#[cfg(test)]
use project::agent_server_store::{AgentServerCommand, CustomAgentServerSettings};
use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings};
use std::{
path::{Path, PathBuf},
sync::Arc,
@@ -449,7 +449,6 @@ pub use common_e2e_tests;
// Helpers
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
#[cfg(test)]
use settings::Settings;
env_logger::try_init().ok();
@@ -468,11 +467,11 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
language_model::init(client.clone(), cx);
language_models::init(user_store, client, cx);
agent_settings::init(cx);
crate::settings::init(cx);
AllAgentServersSettings::register(cx);
#[cfg(test)]
crate::AllAgentServersSettings::override_global(
crate::AllAgentServersSettings {
AllAgentServersSettings::override_global(
AllAgentServersSettings {
claude: Some(CustomAgentServerSettings {
command: AgentServerCommand {
path: "claude-code-acp".into(),
@@ -498,10 +497,11 @@ pub async fn new_test_thread(
current_dir: impl AsRef<Path>,
cx: &mut TestAppContext,
) -> Entity<AcpThread> {
let delegate = AgentServerDelegate::new(project.clone(), None, None);
let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
let connection = cx
.update(|cx| server.connect(current_dir.as_ref(), delegate, cx))
let (connection, _) = cx
.update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
.await
.unwrap();

View File

@@ -1,21 +1,19 @@
use std::rc::Rc;
use std::{any::Any, path::Path};
use crate::acp::AcpConnection;
use crate::{AgentServer, AgentServerDelegate};
use acp_thread::{AgentConnection, LoadError};
use anyhow::Result;
use gpui::{App, AppContext as _, SharedString, Task};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use client::ProxySettings;
use collections::HashMap;
use gpui::{App, AppContext, SharedString, Task};
use language_models::provider::google::GoogleLanguageModelProvider;
use project::agent_server_store::GEMINI_NAME;
use settings::SettingsStore;
use crate::AllAgentServersSettings;
#[derive(Clone)]
pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp";
impl AgentServer for Gemini {
fn telemetry_id(&self) -> &'static str {
"gemini-cli"
@@ -31,126 +29,49 @@ impl AgentServer for Gemini {
fn connect(
&self,
root_dir: &Path,
root_dir: Option<&Path>,
delegate: AgentServerDelegate,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let root_dir = root_dir.to_path_buf();
let fs = delegate.project().read(cx).fs().clone();
let server_name = self.name();
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
let name = self.name();
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
let is_remote = delegate.project.read(cx).is_via_remote_server();
let store = delegate.store.downgrade();
let proxy_url = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<ProxySettings>(None).proxy.clone()
});
let project = delegate.project().clone();
cx.spawn(async move |cx| {
let ignore_system_version = settings
.as_ref()
.and_then(|settings| settings.ignore_system_version)
.unwrap_or(true);
let mut project_env = project
.update(cx, |project, cx| {
project.directory_environment(root_dir.as_path().into(), cx)
})?
.await
.unwrap_or_default();
let mut command = if let Some(settings) = settings
&& let Some(command) = settings.custom_command()
{
command
} else {
cx.update(|cx| {
delegate.get_or_npm_install_builtin_agent(
Self::BINARY_NAME.into(),
Self::PACKAGE_NAME.into(),
format!("node_modules/{}/dist/index.js", Self::PACKAGE_NAME).into(),
ignore_system_version,
Some(Self::MINIMUM_VERSION.parse().unwrap()),
cx,
)
})?
.await?
};
if !command.args.contains(&ACP_ARG.into()) {
command.args.push(ACP_ARG.into());
}
let mut extra_env = HashMap::default();
if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
project_env
.insert("GEMINI_API_KEY".to_owned(), api_key.key);
extra_env.insert("GEMINI_API_KEY".into(), api_key.key);
}
project_env.extend(command.env.take().unwrap_or_default());
command.env = Some(project_env);
let (mut command, root_dir, login) = store
.update(cx, |store, cx| {
let agent = store
.get_external_agent(&GEMINI_NAME.into())
.context("Gemini CLI is not registered")?;
anyhow::Ok(agent.get_command(
root_dir.as_deref(),
extra_env,
delegate.status_tx,
delegate.new_version_available,
&mut cx.to_async(),
))
})??
.await?;
let root_dir_exists = fs.is_dir(&root_dir).await;
anyhow::ensure!(
root_dir_exists,
"Session root {} does not exist or is not a directory",
root_dir.to_string_lossy()
);
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
match &result {
Ok(connection) => {
if let Some(connection) = connection.clone().downcast::<AcpConnection>()
&& !connection.prompt_capabilities().image
{
let version_output = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output()
.await;
let current_version =
String::from_utf8(version_output?.stdout)?.trim().to_owned();
log::error!("connected to gemini, but missing prompt_capabilities.image (version is {current_version})");
return Err(LoadError::Unsupported {
current_version: current_version.into(),
command: (command.path.to_string_lossy().to_string() + " " + &command.args.join(" ")).into(),
minimum_version: Self::MINIMUM_VERSION.into(),
}
.into());
}
}
Err(e) => {
let version_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output();
let help_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--help")
.kill_on_drop(true)
.output();
let (version_output, help_output) =
futures::future::join(version_fut, help_fut).await;
let Some(version_output) = version_output.ok().and_then(|output| String::from_utf8(output.stdout).ok()) else {
return result;
};
let Some((help_stdout, help_stderr)) = help_output.ok().and_then(|output| String::from_utf8(output.stdout).ok().zip(String::from_utf8(output.stderr).ok())) else {
return result;
};
let current_version = version_output.trim().to_string();
let supported = help_stdout.contains(ACP_ARG) || current_version.parse::<semver::Version>().is_ok_and(|version| version >= Self::MINIMUM_VERSION.parse::<semver::Version>().unwrap());
log::error!("failed to create ACP connection to gemini (version is {current_version}, supported: {supported}): {e}");
log::debug!("gemini --help stdout: {help_stdout:?}");
log::debug!("gemini --help stderr: {help_stderr:?}");
if !supported {
return Err(LoadError::Unsupported {
current_version: current_version.into(),
command: (command.path.to_string_lossy().to_string() + " " + &command.args.join(" ")).into(),
minimum_version: Self::MINIMUM_VERSION.into(),
}
.into());
}
}
// Add proxy flag if proxy settings are configured in Zed and not in the args
if let Some(proxy_url_value) = &proxy_url
&& !command.args.iter().any(|arg| arg.contains("--proxy"))
{
command.args.push("--proxy".into());
command.args.push(proxy_url_value.clone());
}
result
let connection =
crate::acp::connect(name, command, root_dir.as_ref(), is_remote, cx).await?;
Ok((connection, login))
})
}
@@ -159,18 +80,11 @@ impl AgentServer for Gemini {
}
}
impl Gemini {
const PACKAGE_NAME: &str = "@google/gemini-cli";
const MINIMUM_VERSION: &str = "0.2.1";
const BINARY_NAME: &str = "gemini";
}
#[cfg(test)]
pub(crate) mod tests {
use project::agent_server_store::AgentServerCommand;
use super::*;
use crate::AgentServerCommand;
use std::path::Path;
crate::common_e2e_tests!(async |_, _, _| Gemini, allow_option_id = "proceed_once");

View File

@@ -1,110 +0,0 @@
use std::path::PathBuf;
use crate::AgentServerCommand;
use anyhow::Result;
use collections::HashMap;
use gpui::{App, SharedString};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsKey, SettingsSources, SettingsUi};
pub fn init(cx: &mut App) {
AllAgentServersSettings::register(cx);
}
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey)]
#[settings_key(key = "agent_servers")]
pub struct AllAgentServersSettings {
pub gemini: Option<BuiltinAgentServerSettings>,
pub claude: Option<CustomAgentServerSettings>,
/// Custom agent servers configured by the user
#[serde(flatten)]
pub custom: HashMap<SharedString, CustomAgentServerSettings>,
}
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
pub struct BuiltinAgentServerSettings {
/// Absolute path to a binary to be used when launching this agent.
///
/// This can be used to run a specific binary without automatic downloads or searching `$PATH`.
#[serde(rename = "command")]
pub path: Option<PathBuf>,
/// If a binary is specified in `command`, it will be passed these arguments.
pub args: Option<Vec<String>>,
/// If a binary is specified in `command`, it will be passed these environment variables.
pub env: Option<HashMap<String, String>>,
/// Whether to skip searching `$PATH` for an agent server binary when
/// launching this agent.
///
/// This has no effect if a `command` is specified. Otherwise, when this is
/// `false`, Zed will search `$PATH` for an agent server binary and, if one
/// is found, use it for threads with this agent. If no agent binary is
/// found on `$PATH`, Zed will automatically install and use its own binary.
/// When this is `true`, Zed will not search `$PATH`, and will always use
/// its own binary.
///
/// Default: true
pub ignore_system_version: Option<bool>,
}
impl BuiltinAgentServerSettings {
pub(crate) fn custom_command(self) -> Option<AgentServerCommand> {
self.path.map(|path| AgentServerCommand {
path,
args: self.args.unwrap_or_default(),
env: self.env,
})
}
}
impl From<AgentServerCommand> for BuiltinAgentServerSettings {
fn from(value: AgentServerCommand) -> Self {
BuiltinAgentServerSettings {
path: Some(value.path),
args: Some(value.args),
env: value.env,
..Default::default()
}
}
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
pub struct CustomAgentServerSettings {
#[serde(flatten)]
pub command: AgentServerCommand,
}
impl settings::Settings for AllAgentServersSettings {
type FileContent = Self;
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings {
gemini,
claude,
custom,
} in sources.defaults_and_customizations()
{
if gemini.is_some() {
settings.gemini = gemini.clone();
}
if claude.is_some() {
settings.claude = claude.clone();
}
// Merge custom agents
for (name, config) in custom {
// Skip built-in agent names to avoid conflicts
if name != "gemini" && name != "claude" {
settings.custom.insert(name.clone(), config.clone());
}
}
}
Ok(settings)
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}

View File

@@ -1066,13 +1066,21 @@ struct MentionCompletion {
impl MentionCompletion {
fn try_parse(allow_non_file_mentions: bool, line: &str, offset_to_line: usize) -> Option<Self> {
let last_mention_start = line.rfind('@')?;
if last_mention_start >= line.len() {
return Some(Self::default());
// No whitespace immediately after '@'
if line[last_mention_start + 1..]
.chars()
.next()
.is_some_and(|c| c.is_whitespace())
{
return None;
}
// Must be a word boundary before '@'
if last_mention_start > 0
&& line
&& line[..last_mention_start]
.chars()
.nth(last_mention_start - 1)
.last()
.is_some_and(|c| !c.is_whitespace())
{
return None;
@@ -1085,7 +1093,9 @@ impl MentionCompletion {
let mut parts = rest_of_line.split_whitespace();
let mut end = last_mention_start + 1;
if let Some(mode_text) = parts.next() {
// Safe since we check no leading whitespace above
end += mode_text.len();
if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok()
@@ -1278,5 +1288,23 @@ mod tests {
argument: Some("main".to_string()),
})
);
assert_eq!(
MentionCompletion::try_parse(true, "Lorem@symbol", 0),
None,
"Should not parse mention inside word"
);
assert_eq!(
MentionCompletion::try_parse(true, "Lorem @ file", 0),
None,
"Should not parse with a space after @"
);
assert_eq!(
MentionCompletion::try_parse(true, "@ file", 0),
None,
"Should not parse with a space after @ at the start of the line"
);
}
}

View File

@@ -699,10 +699,15 @@ impl MessageEditor {
self.project.read(cx).fs().clone(),
self.history_store.clone(),
));
let delegate = AgentServerDelegate::new(self.project.clone(), None, None);
let connection = server.connect(Path::new(""), delegate, cx);
let delegate = AgentServerDelegate::new(
self.project.read(cx).agent_server_store().clone(),
self.project.clone(),
None,
None,
);
let connection = server.connect(None, delegate, cx);
cx.spawn(async move |_, cx| {
let agent = connection.await?;
let (agent, _) = connection.await?;
let agent = agent.downcast::<agent2::NativeAgentConnection>().unwrap();
let summary = agent
.0

View File

@@ -5,7 +5,8 @@ use agent_client_protocol as acp;
use gpui::{Entity, FocusHandle};
use picker::popover_menu::PickerPopoverMenu;
use ui::{
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*,
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, TintColor, Tooltip, Window,
prelude::*,
};
use zed_actions::agent::ToggleModelSelector;
@@ -58,15 +59,22 @@ impl Render for AcpModelSelectorPopover {
let focus_handle = self.focus_handle.clone();
let color = if self.menu_handle.is_deployed() {
Color::Accent
} else {
Color::Muted
};
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.child(
Label::new(model_name)
.color(Color::Muted)
.color(color)
.size(LabelSize::Small)
.ml_0p5(),
)

View File

@@ -6,7 +6,7 @@ use acp_thread::{
use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog;
use agent_client_protocol::{self as acp, PromptCapabilities};
use agent_servers::{AgentServer, AgentServerDelegate, ClaudeCode};
use agent_servers::{AgentServer, AgentServerDelegate};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
use anyhow::{Context as _, Result, anyhow, bail};
@@ -40,7 +40,6 @@ use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use std::{collections::BTreeMap, rc::Rc, time::Duration};
use task::SpawnInTerminal;
use terminal_view::terminal_panel::TerminalPanel;
use text::Anchor;
use theme::{AgentFontSize, ThemeSettings};
@@ -263,6 +262,7 @@ pub struct AcpThreadView {
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_state: ThreadState,
login: Option<task::SpawnInTerminal>,
history_store: Entity<HistoryStore>,
hovered_recent_history_item: Option<usize>,
entry_view_state: Entity<EntryViewState>,
@@ -392,6 +392,7 @@ impl AcpThreadView {
project: project.clone(),
entry_view_state,
thread_state: Self::initial_state(agent, resume_thread, workspace, project, window, cx),
login: None,
message_editor,
model_selector: None,
profile_selector: None,
@@ -444,9 +445,11 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) -> ThreadState {
if !project.read(cx).is_local() && agent.clone().downcast::<NativeAgentServer>().is_none() {
if project.read(cx).is_via_collab()
&& agent.clone().downcast::<NativeAgentServer>().is_none()
{
return ThreadState::LoadError(LoadError::Other(
"External agents are not yet supported for remote projects.".into(),
"External agents are not yet supported in shared projects.".into(),
));
}
let mut worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
@@ -466,20 +469,23 @@ impl AcpThreadView {
Some(worktree.read(cx).abs_path())
}
})
.next()
.unwrap_or_else(|| paths::home_dir().as_path().into());
.next();
let (status_tx, mut status_rx) = watch::channel("Loading…".into());
let (new_version_available_tx, mut new_version_available_rx) = watch::channel(None);
let delegate = AgentServerDelegate::new(
project.read(cx).agent_server_store().clone(),
project.clone(),
Some(status_tx),
Some(new_version_available_tx),
);
let connect_task = agent.connect(&root_dir, delegate, cx);
let connect_task = agent.connect(root_dir.as_deref(), delegate, cx);
let load_task = cx.spawn_in(window, async move |this, cx| {
let connection = match connect_task.await {
Ok(connection) => connection,
Ok((connection, login)) => {
this.update(cx, |this, _| this.login = login).ok();
connection
}
Err(err) => {
this.update_in(cx, |this, window, cx| {
if err.downcast_ref::<LoadError>().is_some() {
@@ -506,6 +512,14 @@ impl AcpThreadView {
})
.log_err()
} else {
let root_dir = if let Some(acp_agent) = connection
.clone()
.downcast::<agent_servers::AcpConnection>()
{
acp_agent.root_dir().into()
} else {
root_dir.unwrap_or(paths::home_dir().as_path().into())
};
cx.update(|_, cx| {
connection
.clone()
@@ -913,7 +927,7 @@ impl AcpThreadView {
}
}
ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
self.regenerate(event.entry_index, editor, window, cx);
self.regenerate(event.entry_index, editor.clone(), window, cx);
}
ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
self.cancel_editing(&Default::default(), window, cx);
@@ -1137,7 +1151,7 @@ impl AcpThreadView {
fn regenerate(
&mut self,
entry_ix: usize,
message_editor: &Entity<MessageEditor>,
message_editor: Entity<MessageEditor>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -1154,16 +1168,18 @@ impl AcpThreadView {
return;
};
let contents = message_editor.update(cx, |message_editor, cx| message_editor.contents(cx));
let task = cx.spawn(async move |_, cx| {
let contents = contents.await?;
cx.spawn_in(window, async move |this, cx| {
thread
.update(cx, |thread, cx| thread.rewind(user_message_id, cx))?
.await?;
Ok(contents)
});
self.send_impl(task, window, cx);
let contents =
message_editor.update(cx, |message_editor, cx| message_editor.contents(cx))?;
this.update_in(cx, |this, window, cx| {
this.send_impl(contents, window, cx);
})?;
anyhow::Ok(())
})
.detach();
}
fn open_agent_diff(&mut self, _: &OpenAgentDiff, window: &mut Window, cx: &mut Context<Self>) {
@@ -1462,9 +1478,12 @@ impl AcpThreadView {
self.thread_error.take();
configuration_view.take();
pending_auth_method.replace(method.clone());
let authenticate = if method.0.as_ref() == "claude-login" {
let authenticate = if (method.0.as_ref() == "claude-login"
|| method.0.as_ref() == "spawn-gemini-cli")
&& let Some(login) = self.login.clone()
{
if let Some(workspace) = self.workspace.upgrade() {
Self::spawn_claude_login(&workspace, window, cx)
Self::spawn_external_agent_login(login, workspace, false, window, cx)
} else {
Task::ready(Ok(()))
}
@@ -1511,31 +1530,28 @@ impl AcpThreadView {
}));
}
fn spawn_claude_login(
workspace: &Entity<Workspace>,
fn spawn_external_agent_login(
login: task::SpawnInTerminal,
workspace: Entity<Workspace>,
previous_attempt: bool,
window: &mut Window,
cx: &mut App,
) -> Task<Result<()>> {
let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
return Task::ready(Ok(()));
};
let project_entity = workspace.read(cx).project();
let project = project_entity.read(cx);
let cwd = project.first_project_directory(cx);
let shell = project.terminal_settings(&cwd, cx).shell.clone();
let delegate = AgentServerDelegate::new(project_entity.clone(), None, None);
let command = ClaudeCode::login_command(delegate, cx);
let project = workspace.read(cx).project().clone();
let cwd = project.read(cx).first_project_directory(cx);
let shell = project.read(cx).terminal_settings(&cwd, cx).shell.clone();
window.spawn(cx, async move |cx| {
let login_command = command.await?;
let command = login_command
.path
.to_str()
.with_context(|| format!("invalid login command: {:?}", login_command.path))?;
let command = shlex::try_quote(command)?;
let args = login_command
.arguments
let mut task = login.clone();
task.command = task
.command
.map(|command| anyhow::Ok(shlex::try_quote(&command)?.to_string()))
.transpose()?;
task.args = task
.args
.iter()
.map(|arg| {
Ok(shlex::try_quote(arg)
@@ -1543,26 +1559,16 @@ impl AcpThreadView {
.to_string())
})
.collect::<Result<Vec<_>>>()?;
task.full_label = task.label.clone();
task.id = task::TaskId(format!("external-agent-{}-login", task.label));
task.command_label = task.label.clone();
task.use_new_terminal = true;
task.allow_concurrent_runs = true;
task.hide = task::HideStrategy::Always;
task.shell = shell;
let terminal = terminal_panel.update_in(cx, |terminal_panel, window, cx| {
terminal_panel.spawn_task(
&SpawnInTerminal {
id: task::TaskId("claude-login".into()),
full_label: "claude /login".to_owned(),
label: "claude /login".to_owned(),
command: Some(command.into()),
args,
command_label: "claude /login".to_owned(),
cwd,
use_new_terminal: true,
allow_concurrent_runs: true,
hide: task::HideStrategy::Always,
shell,
..Default::default()
},
window,
cx,
)
terminal_panel.spawn_task(&login, window, cx)
})?;
let terminal = terminal.await?;
@@ -1578,7 +1584,9 @@ impl AcpThreadView {
cx.background_executor().timer(Duration::from_secs(1)).await;
let content =
terminal.update(cx, |terminal, _cx| terminal.get_content())?;
if content.contains("Login successful") {
if content.contains("Login successful")
|| content.contains("Type your message")
{
return anyhow::Ok(());
}
}
@@ -1594,6 +1602,9 @@ impl AcpThreadView {
}
}
_ = exit_status => {
if !previous_attempt && project.read_with(cx, |project, _| project.is_via_remote_server())? && login.label.contains("gemini") {
return cx.update(|window, cx| Self::spawn_external_agent_login(login, workspace, true, window, cx))?.await
}
return Err(anyhow!("exited before logging in"));
}
}
@@ -1626,14 +1637,16 @@ impl AcpThreadView {
cx.notify();
}
fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) {
fn restore_checkpoint(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) {
let Some(thread) = self.thread() else {
return;
};
thread
.update(cx, |thread, cx| thread.rewind(message_id.clone(), cx))
.update(cx, |thread, cx| {
thread.restore_checkpoint(message_id.clone(), cx)
})
.detach_and_log_err(cx);
cx.notify();
}
fn render_entry(
@@ -1703,8 +1716,9 @@ impl AcpThreadView {
.label_size(LabelSize::XSmall)
.icon_color(Color::Muted)
.color(Color::Muted)
.tooltip(Tooltip::text("Restores all files in the project to the content they had at this point in the conversation."))
.on_click(cx.listener(move |this, _, _window, cx| {
this.rewind(&message_id, cx);
this.restore_checkpoint(&message_id, cx);
}))
)
.child(Divider::horizontal())
@@ -1775,7 +1789,7 @@ impl AcpThreadView {
let editor = editor.clone();
move |this, _, window, cx| {
this.regenerate(
entry_ix, &editor, window, cx,
entry_ix, editor.clone(), window, cx,
);
}
})).into_any_element()
@@ -2024,35 +2038,34 @@ impl AcpThreadView {
window: &Window,
cx: &Context<Self>,
) -> Div {
let has_location = tool_call.locations.len() == 1;
let card_header_id = SharedString::from("inner-tool-call-header");
let tool_icon =
if tool_call.kind == acp::ToolKind::Edit && tool_call.locations.len() == 1 {
FileIcons::get_icon(&tool_call.locations[0].path, cx)
.map(Icon::from_path)
.unwrap_or(Icon::new(IconName::ToolPencil))
} else {
Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolSearch,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolThink,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
}
.size(IconSize::Small)
.color(Color::Muted);
let tool_icon = if tool_call.kind == acp::ToolKind::Edit && has_location {
FileIcons::get_icon(&tool_call.locations[0].path, cx)
.map(Icon::from_path)
.unwrap_or(Icon::new(IconName::ToolPencil))
} else {
Icon::new(match tool_call.kind {
acp::ToolKind::Read => IconName::ToolSearch,
acp::ToolKind::Edit => IconName::ToolPencil,
acp::ToolKind::Delete => IconName::ToolDeleteFile,
acp::ToolKind::Move => IconName::ArrowRightLeft,
acp::ToolKind::Search => IconName::ToolSearch,
acp::ToolKind::Execute => IconName::ToolTerminal,
acp::ToolKind::Think => IconName::ToolThink,
acp::ToolKind::Fetch => IconName::ToolWeb,
acp::ToolKind::Other => IconName::ToolHammer,
})
}
.size(IconSize::Small)
.color(Color::Muted);
let failed_or_canceled = match &tool_call.status {
ToolCallStatus::Rejected | ToolCallStatus::Canceled | ToolCallStatus::Failed => true,
_ => false,
};
let has_location = tool_call.locations.len() == 1;
let needs_confirmation = matches!(
tool_call.status,
ToolCallStatus::WaitingForConfirmation { .. }
@@ -2195,13 +2208,6 @@ impl AcpThreadView {
.overflow_hidden()
.child(tool_icon)
.child(if has_location {
let name = tool_call.locations[0]
.path
.file_name()
.unwrap_or_default()
.display()
.to_string();
h_flex()
.id(("open-tool-call-location", entry_ix))
.w_full()
@@ -2212,7 +2218,13 @@ impl AcpThreadView {
this.text_color(cx.theme().colors().text_muted)
}
})
.child(name)
.child(self.render_markdown(
tool_call.label.clone(),
MarkdownStyle {
prevent_mouse_interaction: true,
..default_markdown_style(false, true, window, cx)
},
))
.tooltip(Tooltip::text("Jump to File"))
.on_click(cx.listener(move |this, _, window, cx| {
this.open_tool_call_location(entry_ix, 0, window, cx);
@@ -3090,26 +3102,38 @@ impl AcpThreadView {
})
.children(connection.auth_methods().iter().enumerate().rev().map(
|(ix, method)| {
Button::new(
SharedString::from(method.id.0.clone()),
method.name.clone(),
)
.when(ix == 0, |el| {
el.style(ButtonStyle::Tinted(ui::TintColor::Warning))
})
.label_size(LabelSize::Small)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
telemetry::event!(
"Authenticate Agent Started",
agent = this.agent.telemetry_id(),
method = method_id
);
let (method_id, name) = if self
.project
.read(cx)
.is_via_remote_server()
&& method.id.0.as_ref() == "oauth-personal"
&& method.name == "Log in with Google"
{
("spawn-gemini-cli".into(), "Log in with Gemini CLI".into())
} else {
(method.id.0.clone(), method.name.clone())
};
this.authenticate(method_id.clone(), window, cx)
Button::new(SharedString::from(method_id.clone()), name)
.when(ix == 0, |el| {
el.style(ButtonStyle::Tinted(ui::TintColor::Warning))
})
.label_size(LabelSize::Small)
.on_click({
cx.listener(move |this, _, window, cx| {
telemetry::event!(
"Authenticate Agent Started",
agent = this.agent.telemetry_id(),
method = method_id
);
this.authenticate(
acp::AuthMethodId(method_id.clone()),
window,
cx,
)
})
})
})
},
)),
)
@@ -4986,6 +5010,7 @@ impl AcpThreadView {
cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => {
"Upgrade to Zed Pro for more prompts."
}
cloud_llm_client::Plan::ZedProV2 | cloud_llm_client::Plan::ZedProTrialV2 => "",
};
Callout::new()
@@ -5712,11 +5737,11 @@ pub(crate) mod tests {
fn connect(
&self,
_root_dir: &Path,
_root_dir: Option<&Path>,
_delegate: AgentServerDelegate,
_cx: &mut App,
) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
Task::ready(Ok(Rc::new(self.connection.clone())))
) -> Task<gpui::Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
Task::ready(Ok((Rc::new(self.connection.clone()), None)))
}
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {

View File

@@ -5,7 +5,6 @@ mod tool_picker;
use std::{ops::Range, sync::Arc};
use agent_servers::{AgentServerCommand, AllAgentServersSettings, CustomAgentServerSettings};
use agent_settings::AgentSettings;
use anyhow::Result;
use assistant_tool::{ToolSource, ToolWorkingSet};
@@ -26,6 +25,10 @@ use language_model::{
};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
agent_server_store::{
AgentServerCommand, AgentServerStore, AllAgentServersSettings, CLAUDE_CODE_NAME,
CustomAgentServerSettings, GEMINI_NAME,
},
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
project_settings::{ContextServerSettings, ProjectSettings},
};
@@ -45,11 +48,13 @@ pub(crate) use manage_profiles_modal::ManageProfilesModal;
use crate::{
AddContextServer, ExternalAgent, NewExternalAgentThread,
agent_configuration::add_llm_provider_modal::{AddLlmProviderModal, LlmCompatibleProvider},
placeholder_command,
};
pub struct AgentConfiguration {
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
agent_server_store: Entity<AgentServerStore>,
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
@@ -66,6 +71,7 @@ pub struct AgentConfiguration {
impl AgentConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
agent_server_store: Entity<AgentServerStore>,
context_server_store: Entity<ContextServerStore>,
tools: Entity<ToolWorkingSet>,
language_registry: Arc<LanguageRegistry>,
@@ -104,6 +110,7 @@ impl AgentConfiguration {
workspace,
focus_handle,
configuration_views_by_provider: HashMap::default(),
agent_server_store,
context_server_store,
expanded_context_server_tools: HashMap::default(),
expanded_provider_configurations: HashMap::default(),
@@ -509,8 +516,10 @@ impl AgentConfiguration {
let (plan_name, label_color, bg_color) = match plan {
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
Plan::ZedProTrial | Plan::ZedProTrialV2 => {
("Pro Trial", Color::Accent, pro_chip_bg)
}
Plan::ZedPro | Plan::ZedProV2 => ("Pro", Color::Accent, pro_chip_bg),
};
Chip::new(plan_name.to_string())
@@ -991,17 +1000,30 @@ impl AgentConfiguration {
}
fn render_agent_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AllAgentServersSettings::get_global(cx).clone();
let user_defined_agents = settings
let custom_settings = cx
.global::<SettingsStore>()
.get::<AllAgentServersSettings>(None)
.custom
.iter()
.map(|(name, settings)| {
.clone();
let user_defined_agents = self
.agent_server_store
.read(cx)
.external_agents()
.filter(|name| name.0 != GEMINI_NAME && name.0 != CLAUDE_CODE_NAME)
.cloned()
.collect::<Vec<_>>();
let user_defined_agents = user_defined_agents
.into_iter()
.map(|name| {
self.render_agent_server(
IconName::Ai,
name.clone(),
ExternalAgent::Custom {
name: name.clone(),
command: settings.command.clone(),
name: name.clone().into(),
command: custom_settings
.get(&name.0)
.map(|settings| settings.command.clone())
.unwrap_or(placeholder_command()),
},
cx,
)

View File

@@ -5,9 +5,11 @@ use std::sync::Arc;
use std::time::Duration;
use acp_thread::AcpThread;
use agent_servers::AgentServerCommand;
use agent2::{DbThreadMetadata, HistoryEntry};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use project::agent_server_store::{
AgentServerCommand, AllAgentServersSettings, CLAUDE_CODE_NAME, GEMINI_NAME,
};
use serde::{Deserialize, Serialize};
use zed_actions::OpenBrowser;
use zed_actions::agent::{OpenClaudeCodeOnboardingModal, ReauthenticateAgent};
@@ -33,7 +35,9 @@ use crate::{
thread_history::{HistoryEntryElement, ThreadHistory},
ui::{AgentOnboardingModal, EndTrialUpsell},
};
use crate::{ExternalAgent, NewExternalAgentThread, NewNativeAgentThreadFromSummary};
use crate::{
ExternalAgent, NewExternalAgentThread, NewNativeAgentThreadFromSummary, placeholder_command,
};
use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
context_store::ContextStore,
@@ -62,7 +66,7 @@ use project::{DisableAiSettings, Project, ProjectPath, Worktree};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search};
use settings::{Settings, update_settings_file};
use settings::{Settings, SettingsStore, update_settings_file};
use theme::ThemeSettings;
use time::UtcOffset;
use ui::utils::WithRemSize;
@@ -1094,7 +1098,7 @@ impl AgentPanel {
let workspace = self.workspace.clone();
let project = self.project.clone();
let fs = self.fs.clone();
let is_not_local = !self.project.read(cx).is_local();
let is_via_collab = self.project.read(cx).is_via_collab();
const LAST_USED_EXTERNAL_AGENT_KEY: &str = "agent_panel__last_used_external_agent";
@@ -1126,7 +1130,7 @@ impl AgentPanel {
agent
}
None => {
if is_not_local {
if is_via_collab {
ExternalAgent::NativeAgent
} else {
cx.background_spawn(async move {
@@ -1503,6 +1507,7 @@ impl AgentPanel {
}
pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let agent_server_store = self.project.read(cx).agent_server_store().clone();
let context_server_store = self.project.read(cx).context_server_store();
let tools = self.thread_store.read(cx).tools();
let fs = self.fs.clone();
@@ -1511,6 +1516,7 @@ impl AgentPanel {
self.configuration = Some(cx.new(|cx| {
AgentConfiguration::new(
fs,
agent_server_store,
context_server_store,
tools,
self.language_registry.clone(),
@@ -2503,6 +2509,7 @@ impl AgentPanel {
}
fn render_toolbar_new(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let agent_server_store = self.project.read(cx).agent_server_store().clone();
let focus_handle = self.focus_handle(cx);
let active_thread = match &self.active_view {
@@ -2531,12 +2538,14 @@ impl AgentPanel {
}
},
)
.anchor(Corner::TopLeft)
.anchor(Corner::TopRight)
.with_handle(self.new_thread_menu_handle.clone())
.menu({
let workspace = self.workspace.clone();
let is_not_local = workspace
.update(cx, |workspace, cx| !workspace.project().read(cx).is_local())
let is_via_collab = workspace
.update(cx, |workspace, cx| {
workspace.project().read(cx).is_via_collab()
})
.unwrap_or_default();
move |window, cx| {
@@ -2628,7 +2637,7 @@ impl AgentPanel {
ContextMenuEntry::new("New Gemini CLI Thread")
.icon(IconName::AiGemini)
.icon_color(Color::Muted)
.disabled(is_not_local)
.disabled(is_via_collab)
.handler({
let workspace = workspace.clone();
move |window, cx| {
@@ -2655,7 +2664,7 @@ impl AgentPanel {
menu.item(
ContextMenuEntry::new("New Claude Code Thread")
.icon(IconName::AiClaude)
.disabled(is_not_local)
.disabled(is_via_collab)
.icon_color(Color::Muted)
.handler({
let workspace = workspace.clone();
@@ -2680,19 +2689,25 @@ impl AgentPanel {
)
})
.when(cx.has_flag::<GeminiAndNativeFeatureFlag>(), |mut menu| {
// Add custom agents from settings
let settings =
agent_servers::AllAgentServersSettings::get_global(cx);
for (agent_name, agent_settings) in &settings.custom {
let agent_names = agent_server_store
.read(cx)
.external_agents()
.filter(|name| {
name.0 != GEMINI_NAME && name.0 != CLAUDE_CODE_NAME
})
.cloned()
.collect::<Vec<_>>();
let custom_settings = cx.global::<SettingsStore>().get::<AllAgentServersSettings>(None).custom.clone();
for agent_name in agent_names {
menu = menu.item(
ContextMenuEntry::new(format!("New {} Thread", agent_name))
.icon(IconName::Terminal)
.icon_color(Color::Muted)
.disabled(is_not_local)
.disabled(is_via_collab)
.handler({
let workspace = workspace.clone();
let agent_name = agent_name.clone();
let agent_settings = agent_settings.clone();
let custom_settings = custom_settings.clone();
move |window, cx| {
if let Some(workspace) = workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
@@ -2703,10 +2718,9 @@ impl AgentPanel {
panel.new_agent_thread(
AgentType::Custom {
name: agent_name
.clone(),
command: agent_settings
.command
.clone(),
.clone()
.into(),
command: custom_settings.get(&agent_name.0).map(|settings| settings.command.clone()).unwrap_or(placeholder_command())
},
window,
cx,
@@ -3504,6 +3518,7 @@ impl AgentPanel {
let error_message = match plan {
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.",
Plan::ZedProV2 | Plan::ZedProTrialV2 => "",
};
Callout::new()

View File

@@ -28,7 +28,6 @@ use std::rc::Rc;
use std::sync::Arc;
use agent::{Thread, ThreadId};
use agent_servers::AgentServerCommand;
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection};
use assistant_slash_command::SlashCommandRegistry;
use client::Client;
@@ -41,6 +40,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelId, LanguageModelProviderId, LanguageModelRegistry,
};
use project::DisableAiSettings;
use project::agent_server_store::AgentServerCommand;
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -174,6 +174,14 @@ enum ExternalAgent {
},
}
fn placeholder_command() -> AgentServerCommand {
AgentServerCommand {
path: "/placeholder".into(),
args: vec![],
env: None,
}
}
impl ExternalAgent {
fn name(&self) -> &'static str {
match self {
@@ -193,10 +201,9 @@ impl ExternalAgent {
Self::Gemini => Rc::new(agent_servers::Gemini),
Self::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
Self::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)),
Self::Custom { name, command } => Rc::new(agent_servers::CustomAgentServer::new(
name.clone(),
command.clone(),
)),
Self::Custom { name, command: _ } => {
Rc::new(agent_servers::CustomAgentServer::new(name.clone()))
}
}
}
}

View File

@@ -1,29 +1,18 @@
use crate::agent_model_selector::AgentModelSelector;
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases};
use crate::terminal_codegen::TerminalCodegen;
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
use crate::{RemoveAllContext, ToggleContextPicker};
use agent::{
context_store::ContextStore,
thread_store::{TextThreadStore, ThreadStore},
};
use client::ErrorExt;
use collections::VecDeque;
use db::kvp::Dismissable;
use editor::actions::Paste;
use editor::display_map::EditorMargins;
use editor::{
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
actions::{MoveDown, MoveUp},
};
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
use fs::Fs;
use gpui::{
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable,
Subscription, TextStyle, WeakEntity, Window,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use parking_lot::Mutex;
@@ -33,12 +22,19 @@ use std::rc::Rc;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::utils::WithRemSize;
use ui::{
CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, prelude::*,
};
use ui::{IconButtonShape, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::Workspace;
use zed_actions::agent::ToggleModelSelector;
use crate::agent_model_selector::AgentModelSelector;
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{ContextCreasesAddon, extract_message_creases, insert_message_creases};
use crate::terminal_codegen::TerminalCodegen;
use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist, ModelUsageContext};
use crate::{RemoveAllContext, ToggleContextPicker};
pub struct PromptEditor<T> {
pub editor: Entity<Editor>,
mode: PromptEditorMode,
@@ -144,47 +140,16 @@ impl<T: 'static> Render for PromptEditor<T> {
};
let error_message = SharedString::from(error.to_string());
if error.error_code() == proto::ErrorCode::RateLimitExceeded
&& cx.has_flag::<ZedProFeatureFlag>()
{
el.child(
v_flex()
.child(
IconButton::new(
"rate-limit-error",
IconName::XCircle,
)
.toggle_state(self.show_rate_limit_notice)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.on_click(
cx.listener(Self::toggle_rate_limit_notice),
),
)
.children(self.show_rate_limit_notice.then(|| {
deferred(
anchored()
.position_mode(
gpui::AnchoredPositionMode::Local,
)
.position(point(px(0.), px(24.)))
.anchor(gpui::Corner::TopLeft)
.child(self.render_rate_limit_notice(cx)),
)
})),
)
} else {
el.child(
div()
.id("error")
.tooltip(Tooltip::text(error_message))
.child(
Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error),
),
)
}
el.child(
div()
.id("error")
.tooltip(Tooltip::text(error_message))
.child(
Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error),
),
)
}),
)
.child(
@@ -310,19 +275,6 @@ impl<T: 'static> PromptEditor<T> {
crate::active_thread::attach_pasted_images_as_context(&self.context_store, cx);
}
fn toggle_rate_limit_notice(
&mut self,
_: &ClickEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.show_rate_limit_notice = !self.show_rate_limit_notice;
if self.show_rate_limit_notice {
window.focus(&self.editor.focus_handle(cx));
}
cx.notify();
}
fn handle_prompt_editor_events(
&mut self,
_: &Entity<Editor>,
@@ -707,61 +659,6 @@ impl<T: 'static> PromptEditor<T> {
.into_any_element()
}
fn render_rate_limit_notice(&self, cx: &mut Context<Self>) -> impl IntoElement {
Popover::new().child(
v_flex()
.occlude()
.p_2()
.child(
Label::new("Out of Tokens")
.size(LabelSize::Small)
.weight(FontWeight::BOLD),
)
.child(Label::new(
"Try Zed Pro for higher limits, a wider range of models, and more.",
))
.child(
h_flex()
.justify_between()
.child(CheckboxWithLabel::new(
"dont-show-again",
Label::new("Don't show again"),
if RateLimitNotice::dismissed() {
ui::ToggleState::Selected
} else {
ui::ToggleState::Unselected
},
|selection, _, cx| {
let is_dismissed = match selection {
ui::ToggleState::Unselected => false,
ui::ToggleState::Indeterminate => return,
ui::ToggleState::Selected => true,
};
RateLimitNotice::set_dismissed(is_dismissed, cx);
},
))
.child(
h_flex()
.gap_2()
.child(
Button::new("dismiss", "Dismiss")
.style(ButtonStyle::Transparent)
.on_click(cx.listener(Self::toggle_rate_limit_notice)),
)
.child(Button::new("more-info", "More Info").on_click(
|_event, window, cx| {
window.dispatch_action(
Box::new(zed_actions::OpenAccountSettings),
cx,
)
},
)),
),
),
)
}
fn render_editor(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let colors = cx.theme().colors();
@@ -978,15 +875,7 @@ impl PromptEditor<BufferCodegen> {
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
}
CodegenStatus::Error(error) => {
if cx.has_flag::<ZedProFeatureFlag>()
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
&& !RateLimitNotice::dismissed()
{
self.show_rate_limit_notice = true;
cx.notify();
}
CodegenStatus::Error(_error) => {
self.edited_since_done = false;
self.editor
.update(cx, |editor, _| editor.set_read_only(false));
@@ -1189,12 +1078,6 @@ impl PromptEditor<TerminalCodegen> {
}
}
struct RateLimitNotice;
impl Dismissable for RateLimitNotice {
const KEY: &'static str = "dismissed-rate-limit-notice";
}
pub enum CodegenStatus {
Idle,
Pending,

View File

@@ -1,8 +1,6 @@
use std::{cmp::Reverse, sync::Arc};
use cloud_llm_client::Plan;
use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
use language_model::{
@@ -13,8 +11,6 @@ use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use ui::{ListItem, ListItemSpacing, prelude::*};
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
@@ -531,13 +527,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
fn render_footer(
&self,
_: &mut Window,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<gpui::AnyElement> {
use feature_flags::FeatureFlagAppExt;
let plan = Plan::ZedPro;
Some(
h_flex()
.w_full()
@@ -546,28 +538,6 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.p_1()
.gap_4()
.justify_between()
.when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
this.child(match plan {
Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
.icon(IconName::ZedAssistant)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(|_, window, cx| {
window
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
}),
Plan::ZedFree | Plan::ZedProTrial => Button::new(
"try-pro",
if plan == Plan::ZedProTrial {
"Upgrade to Pro"
} else {
"Try Pro"
},
)
.on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
})
})
.child(
Button::new("configure", "Configure")
.icon(IconName::Settings)

View File

@@ -6,8 +6,8 @@ use gpui::{Action, Entity, FocusHandle, Subscription, prelude::*};
use settings::{Settings as _, SettingsStore, update_settings_file};
use std::sync::Arc;
use ui::{
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*,
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, TintColor,
Tooltip, prelude::*,
};
/// Trait for types that can provide and manage agent profiles
@@ -170,7 +170,8 @@ impl Render for ProfileSelector {
.icon(IconName::ChevronDown)
.icon_size(IconSize::XSmall)
.icon_position(IconPosition::End)
.icon_color(Color::Muted);
.icon_color(Color::Muted)
.selected_style(ButtonStyle::Tinted(TintColor::Accent));
PopoverMenu::new("profile-selector")
.trigger_with_tooltip(trigger_button, {
@@ -195,6 +196,10 @@ impl Render for ProfileSelector {
.menu(move |window, cx| {
Some(this.update(cx, |this, cx| this.build_context_menu(window, cx)))
})
.offset(gpui::Point {
x: px(0.0),
y: px(-2.0),
})
.into_any_element()
} else {
Button::new("tools-not-supported-button", "Tools Unsupported")

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use ai_onboarding::{AgentPanelOnboardingCard, PlanDefinitions};
use client::zed_urls;
use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt as _};
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Divider, Tooltip, prelude::*};
@@ -18,8 +19,6 @@ impl EndTrialUpsell {
impl RenderOnce for EndTrialUpsell {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let plan_definitions = PlanDefinitions;
let pro_section = v_flex()
.gap_1()
.child(
@@ -33,7 +32,7 @@ impl RenderOnce for EndTrialUpsell {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(false))
.child(PlanDefinitions.pro_plan(cx.has_flag::<BillingV2FeatureFlag>(), false))
.child(
Button::new("cta-button", "Upgrade to Zed Pro")
.full_width()
@@ -64,7 +63,7 @@ impl RenderOnce for EndTrialUpsell {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.free_plan());
.child(PlanDefinitions.free_plan(cx.has_flag::<BillingV2FeatureFlag>()));
AgentPanelOnboardingCard::new()
.child(Headline::new("Your Zed Pro Trial has expired"))

View File

@@ -45,13 +45,13 @@ impl RenderOnce for UsageCallout {
"Upgrade",
zed_urls::account_url(cx),
),
Plan::ZedProTrial => (
Plan::ZedProTrial | Plan::ZedProTrialV2 => (
"Out of trial prompts",
"Upgrade to Zed Pro to continue, or switch to API key.".to_string(),
"Upgrade",
zed_urls::account_url(cx),
),
Plan::ZedPro => (
Plan::ZedPro | Plan::ZedProV2 => (
"Out of included prompts",
"Enable usage-based billing to continue.".to_string(),
"Manage",

View File

@@ -18,6 +18,7 @@ default = []
client.workspace = true
cloud_llm_client.workspace = true
component.workspace = true
feature_flags.workspace = true
gpui.workspace = true
language_model.workspace = true
serde.workspace = true

View File

@@ -18,6 +18,7 @@ pub use young_account_banner::YoungAccountBanner;
use std::sync::Arc;
use client::{Client, UserStore, zed_urls};
use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt as _};
use gpui::{AnyElement, Entity, IntoElement, ParentElement};
use ui::{Divider, RegisterComponent, Tooltip, prelude::*};
@@ -84,9 +85,8 @@ impl ZedAiOnboarding {
self
}
fn render_sign_in_disclaimer(&self, _cx: &mut App) -> AnyElement {
fn render_sign_in_disclaimer(&self, cx: &mut App) -> AnyElement {
let signing_in = matches!(self.sign_in_status, SignInStatus::SigningIn);
let plan_definitions = PlanDefinitions;
v_flex()
.gap_1()
@@ -96,7 +96,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_plan(false))
.child(PlanDefinitions.pro_plan(cx.has_flag::<BillingV2FeatureFlag>(), false))
.child(
Button::new("sign_in", "Try Zed Pro for Free")
.disabled(signing_in)
@@ -114,16 +114,13 @@ impl ZedAiOnboarding {
}
fn render_free_plan_state(&self, cx: &mut App) -> AnyElement {
let young_account_banner = YoungAccountBanner;
let plan_definitions = PlanDefinitions;
if self.account_too_young {
v_flex()
.relative()
.max_w_full()
.gap_1()
.child(Headline::new("Welcome to Zed AI"))
.child(young_account_banner)
.child(YoungAccountBanner)
.child(
v_flex()
.mt_2()
@@ -139,7 +136,9 @@ impl ZedAiOnboarding {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(true))
.child(
PlanDefinitions.pro_plan(cx.has_flag::<BillingV2FeatureFlag>(), true),
)
.child(
Button::new("pro", "Get Started")
.full_width()
@@ -182,7 +181,7 @@ impl ZedAiOnboarding {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.free_plan()),
.child(PlanDefinitions.free_plan(cx.has_flag::<BillingV2FeatureFlag>())),
)
.when_some(
self.dismiss_onboarding.as_ref(),
@@ -220,7 +219,9 @@ impl ZedAiOnboarding {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_trial(true))
.child(
PlanDefinitions.pro_trial(cx.has_flag::<BillingV2FeatureFlag>(), true),
)
.child(
Button::new("pro", "Start Free Trial")
.full_width()
@@ -238,9 +239,7 @@ impl ZedAiOnboarding {
}
}
fn render_trial_state(&self, _cx: &mut App) -> AnyElement {
let plan_definitions = PlanDefinitions;
fn render_trial_state(&self, is_v2: bool, _cx: &mut App) -> AnyElement {
v_flex()
.relative()
.gap_1()
@@ -250,7 +249,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_trial(false))
.child(PlanDefinitions.pro_trial(is_v2, false))
.when_some(
self.dismiss_onboarding.as_ref(),
|this, dismiss_callback| {
@@ -274,9 +273,7 @@ impl ZedAiOnboarding {
.into_any_element()
}
fn render_pro_plan_state(&self, _cx: &mut App) -> AnyElement {
let plan_definitions = PlanDefinitions;
fn render_pro_plan_state(&self, is_v2: bool, _cx: &mut App) -> AnyElement {
v_flex()
.gap_1()
.child(Headline::new("Welcome to Zed Pro"))
@@ -285,7 +282,7 @@ impl ZedAiOnboarding {
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_plan(false))
.child(PlanDefinitions.pro_plan(is_v2, false))
.when_some(
self.dismiss_onboarding.as_ref(),
|this, dismiss_callback| {
@@ -315,8 +312,10 @@ impl RenderOnce for ZedAiOnboarding {
if matches!(self.sign_in_status, SignInStatus::SignedIn) {
match self.plan {
None | Some(Plan::ZedFree) => self.render_free_plan_state(cx),
Some(Plan::ZedProTrial) => self.render_trial_state(cx),
Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
Some(Plan::ZedProTrial) => self.render_trial_state(false, cx),
Some(Plan::ZedProTrialV2) => self.render_trial_state(true, cx),
Some(Plan::ZedPro) => self.render_pro_plan_state(false, cx),
Some(Plan::ZedProV2) => self.render_pro_plan_state(true, cx),
}
} else {
self.render_sign_in_disclaimer(cx)

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::Plan;
use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt};
use gpui::{AnyElement, App, Entity, IntoElement, RenderOnce, Window};
use ui::{CommonAnimationExt, Divider, Vector, VectorName, prelude::*};
@@ -49,9 +50,6 @@ impl AiUpsellCard {
impl RenderOnce for AiUpsellCard {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let plan_definitions = PlanDefinitions;
let young_account_banner = YoungAccountBanner;
let pro_section = v_flex()
.flex_grow()
.w_full()
@@ -67,7 +65,7 @@ impl RenderOnce for AiUpsellCard {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(false));
.child(PlanDefinitions.pro_plan(cx.has_flag::<BillingV2FeatureFlag>(), false));
let free_section = v_flex()
.flex_grow()
@@ -84,7 +82,7 @@ impl RenderOnce for AiUpsellCard {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.free_plan());
.child(PlanDefinitions.free_plan(cx.has_flag::<BillingV2FeatureFlag>()));
let grid_bg = h_flex()
.absolute()
@@ -173,7 +171,7 @@ impl RenderOnce for AiUpsellCard {
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.map(|this| {
if self.account_too_young {
this.child(young_account_banner).child(
this.child(YoungAccountBanner).child(
v_flex()
.mt_2()
.gap_1()
@@ -188,7 +186,10 @@ impl RenderOnce for AiUpsellCard {
)
.child(Divider::horizontal()),
)
.child(plan_definitions.pro_plan(true))
.child(
PlanDefinitions
.pro_plan(cx.has_flag::<BillingV2FeatureFlag>(), true),
)
.child(
Button::new("pro", "Get Started")
.full_width()
@@ -235,7 +236,7 @@ impl RenderOnce for AiUpsellCard {
)
}
}),
Some(Plan::ZedProTrial) => card
Some(plan @ Plan::ZedProTrial | plan @ Plan::ZedProTrialV2) => card
.child(pro_trial_stamp)
.child(Label::new("You're in the Zed Pro Trial").size(LabelSize::Large))
.child(
@@ -243,8 +244,8 @@ impl RenderOnce for AiUpsellCard {
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_trial(false)),
Some(Plan::ZedPro) => card
.child(PlanDefinitions.pro_trial(plan == Plan::ZedProTrialV2, false)),
Some(plan @ Plan::ZedPro | plan @ Plan::ZedProV2) => card
.child(certified_user_stamp)
.child(Label::new("You're in the Zed Pro plan").size(LabelSize::Large))
.child(
@@ -252,7 +253,7 @@ impl RenderOnce for AiUpsellCard {
.color(Color::Muted)
.mb_2(),
)
.child(plan_definitions.pro_plan(false)),
.child(PlanDefinitions.pro_plan(plan == Plan::ZedProV2, false)),
},
// Signed Out State
_ => card

View File

@@ -7,13 +7,13 @@ pub struct PlanDefinitions;
impl PlanDefinitions {
pub const AI_DESCRIPTION: &'static str = "Zed offers a complete agentic experience, with robust editing and reviewing features to collaborate with AI.";
pub fn free_plan(&self) -> impl IntoElement {
pub fn free_plan(&self, _is_v2: bool) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("50 prompts with Claude models"))
.child(ListBulletItem::new("2,000 accepted edit predictions"))
}
pub fn pro_trial(&self, period: bool) -> impl IntoElement {
pub fn pro_trial(&self, _is_v2: bool, period: bool) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("150 prompts with Claude models"))
.child(ListBulletItem::new(
@@ -26,7 +26,7 @@ impl PlanDefinitions {
})
}
pub fn pro_plan(&self, price: bool) -> impl IntoElement {
pub fn pro_plan(&self, _is_v2: bool, price: bool) -> impl IntoElement {
List::new()
.child(ListBulletItem::new("500 prompts with Claude models"))
.child(ListBulletItem::new(

View File

@@ -14,20 +14,11 @@ doctest = false
[dependencies]
anyhow.workspace = true
async-tar.workspace = true
collections.workspace = true
crossbeam.workspace = true
gpui.workspace = true
log.workspace = true
parking_lot.workspace = true
rodio = { workspace = true, features = [ "wav", "playback", "wav_output" ] }
settings.workspace = true
schemars.workspace = true
serde.workspace = true
settings.workspace = true
smol.workspace = true
thiserror.workspace = true
rodio = { workspace = true, features = [ "wav", "playback", "tracing" ] }
util.workspace = true
workspace-hack.workspace = true
[target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies]
libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" }

View File

@@ -1,56 +1,19 @@
use anyhow::{Context as _, Result};
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
use gpui::{App, BackgroundExecutor, BorrowAppContext, Global};
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
mod non_windows_and_freebsd_deps {
pub(super) use gpui::AsyncApp;
pub(super) use libwebrtc::native::apm;
pub(super) use log::info;
pub(super) use parking_lot::Mutex;
pub(super) use rodio::cpal::Sample;
pub(super) use rodio::source::{LimitSettings, UniformSourceIterator};
pub(super) use std::sync::Arc;
}
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
use non_windows_and_freebsd_deps::*;
use rodio::{
Decoder, OutputStream, OutputStreamBuilder, Source, mixer::Mixer, nz, source::Buffered,
};
use gpui::{App, BorrowAppContext, Global};
use rodio::{Decoder, OutputStream, OutputStreamBuilder, Source, source::Buffered};
use settings::Settings;
use std::{io::Cursor, num::NonZero, path::PathBuf, sync::atomic::Ordering, time::Duration};
use std::io::Cursor;
use util::ResultExt;
mod audio_settings;
mod replays;
mod rodio_ext;
pub use audio_settings::AudioSettings;
pub use rodio_ext::RodioExt;
use crate::audio_settings::LIVE_SETTINGS;
// NOTE: We used to use WebRTC's mixer which only supported
// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
// for audio output devices like speakers/bluetooth, we just hard-code
// this; and downsample when we need to.
//
// Since most noise cancelling requires 16kHz we will move to
// that in the future.
pub const SAMPLE_RATE: NonZero<u32> = nz!(48000);
pub const CHANNEL_COUNT: NonZero<u16> = nz!(2);
pub const BUFFER_SIZE: usize = // echo canceller and livekit want 10ms of audio
(SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize;
pub const REPLAY_DURATION: Duration = Duration::from_secs(30);
pub fn init(cx: &mut App) {
AudioSettings::register(cx);
LIVE_SETTINGS.initialize(cx);
}
#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)]
#[derive(Copy, Clone, Eq, Hash, PartialEq)]
pub enum Sound {
Joined,
Leave,
@@ -75,152 +38,32 @@ impl Sound {
}
}
#[derive(Default)]
pub struct Audio {
output_handle: Option<OutputStream>,
output_mixer: Option<Mixer>,
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
pub echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
source_cache: HashMap<Sound, Buffered<Decoder<Cursor<Vec<u8>>>>>,
replays: replays::Replays,
}
impl Default for Audio {
fn default() -> Self {
Self {
output_handle: Default::default(),
output_mixer: Default::default(),
#[cfg(not(any(
all(target_os = "windows", target_env = "gnu"),
target_os = "freebsd"
)))]
echo_canceller: Arc::new(Mutex::new(apm::AudioProcessingModule::new(
true, false, false, false,
))),
source_cache: Default::default(),
replays: Default::default(),
}
}
}
impl Global for Audio {}
impl Audio {
fn ensure_output_exists(&mut self) -> Result<&Mixer> {
fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
if self.output_handle.is_none() {
self.output_handle = Some(
OutputStreamBuilder::open_default_stream()
.context("Could not open default output stream")?,
);
if let Some(output_handle) = &self.output_handle {
let (mixer, source) = rodio::mixer::mixer(CHANNEL_COUNT, SAMPLE_RATE);
// or the mixer will end immediately as its empty.
mixer.add(rodio::source::Zero::new(CHANNEL_COUNT, SAMPLE_RATE));
self.output_mixer = Some(mixer);
// The webrtc apm is not yet compiling for windows & freebsd
#[cfg(not(any(
any(all(target_os = "windows", target_env = "gnu")),
target_os = "freebsd"
)))]
let echo_canceller = Arc::clone(&self.echo_canceller);
#[cfg(not(any(
any(all(target_os = "windows", target_env = "gnu")),
target_os = "freebsd"
)))]
let source = source.inspect_buffer::<BUFFER_SIZE, _>(move |buffer| {
let mut buf: [i16; _] = buffer.map(|s| s.to_sample());
echo_canceller
.lock()
.process_reverse_stream(
&mut buf,
SAMPLE_RATE.get() as i32,
CHANNEL_COUNT.get().into(),
)
.expect("Audio input and output threads should not panic");
});
output_handle.mixer().add(source);
}
self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
}
Ok(self
.output_mixer
.as_ref()
.expect("we only get here if opening the outputstream succeeded"))
self.output_handle.as_ref()
}
pub fn save_replays(
&self,
executor: BackgroundExecutor,
) -> gpui::Task<anyhow::Result<(PathBuf, Duration)>> {
self.replays.replays_to_tar(executor)
}
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
pub fn open_microphone(voip_parts: VoipParts) -> anyhow::Result<impl Source> {
let stream = rodio::microphone::MicrophoneBuilder::new()
.default_device()?
.default_config()?
.prefer_sample_rates([SAMPLE_RATE, SAMPLE_RATE.saturating_mul(nz!(2))])
.prefer_channel_counts([nz!(1), nz!(2)])
.prefer_buffer_sizes(512..)
.open_stream()?;
info!("Opened microphone: {:?}", stream.config());
let (replay, stream) = UniformSourceIterator::new(stream, CHANNEL_COUNT, SAMPLE_RATE)
.limit(LimitSettings::live_performance())
.process_buffer::<BUFFER_SIZE, _>(move |buffer| {
let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample());
if voip_parts
.echo_canceller
.lock()
.process_stream(
&mut int_buffer,
SAMPLE_RATE.get() as i32,
CHANNEL_COUNT.get() as i32,
)
.context("livekit audio processor error")
.log_err()
.is_some()
{
for (sample, processed) in buffer.iter_mut().zip(&int_buffer) {
*sample = (*processed).to_sample();
}
}
})
.automatic_gain_control(1.0, 4.0, 0.0, 5.0)
.periodic_access(Duration::from_millis(100), move |agc_source| {
agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed));
})
.replayable(REPLAY_DURATION)?;
voip_parts
.replays
.add_voip_stream("local microphone".to_string(), replay);
Ok(stream)
}
pub fn play_voip_stream(
pub fn play_source(
source: impl rodio::Source + Send + 'static,
speaker_name: String,
is_staff: bool,
cx: &mut App,
) -> anyhow::Result<()> {
let (replay_source, source) = source
.automatic_gain_control(1.0, 4.0, 0.0, 5.0)
.periodic_access(Duration::from_millis(100), move |agc_source| {
agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed));
})
.replayable(REPLAY_DURATION)
.expect("REPLAY_DURATION is longer then 100ms");
cx.update_default_global(|this: &mut Self, _cx| {
let output_mixer = this
let output_handle = this
.ensure_output_exists()
.context("Could not get output mixer")?;
output_mixer.add(source);
if is_staff {
this.replays.add_voip_stream(speaker_name, replay_source);
}
.ok_or_else(|| anyhow!("Could not open audio output"))?;
output_handle.mixer().add(source);
Ok(())
})
}
@@ -228,12 +71,8 @@ impl Audio {
pub fn play_sound(sound: Sound, cx: &mut App) {
cx.update_default_global(|this: &mut Self, cx| {
let source = this.sound_source(sound, cx).log_err()?;
let output_mixer = this
.ensure_output_exists()
.context("Could not get output mixer")
.log_err()?;
output_mixer.add(source);
let output_handle = this.ensure_output_exists()?;
output_handle.mixer().add(source);
Some(())
});
}
@@ -264,23 +103,3 @@ impl Audio {
Ok(source)
}
}
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
pub struct VoipParts {
echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
replays: replays::Replays,
}
#[cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))]
impl VoipParts {
pub fn new(cx: &AsyncApp) -> anyhow::Result<Self> {
let (apm, replays) = cx.try_read_default_global::<Audio, _>(|audio, _| {
(Arc::clone(&audio.echo_canceller), audio.replays.clone())
})?;
Ok(Self {
echo_canceller: apm,
replays,
})
}
}

View File

@@ -1,29 +1,14 @@
use std::sync::atomic::{AtomicBool, Ordering};
use anyhow::Result;
use gpui::App;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsKey, SettingsSources, SettingsStore, SettingsUi};
use settings::{Settings, SettingsKey, SettingsSources, SettingsUi};
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi)]
pub struct AudioSettings {
/// Opt into the new audio system.
#[serde(rename = "experimental.rodio_audio", default)]
pub rodio_audio: bool, // default is false
/// Requires 'rodio_audio: true'
///
/// Use the new audio systems automatic gain control for your microphone.
/// This affects how loud you sound to others.
#[serde(rename = "experimental.control_input_volume", default)]
pub control_input_volume: bool,
/// Requires 'rodio_audio: true'
///
/// Use the new audio systems automatic gain control on everyone in the
/// call. This makes call members who are too quite louder and those who are
/// too loud quieter. This only affects how things sound for you.
#[serde(rename = "experimental.control_output_volume", default)]
pub control_output_volume: bool,
}
/// Configuration of audio in Zed.
@@ -31,22 +16,9 @@ pub struct AudioSettings {
#[serde(default)]
#[settings_key(key = "audio")]
pub struct AudioSettingsContent {
/// Opt into the new audio system.
/// Whether to use the experimental audio system
#[serde(rename = "experimental.rodio_audio", default)]
pub rodio_audio: bool, // default is false
/// Requires 'rodio_audio: true'
///
/// Use the new audio systems automatic gain control for your microphone.
/// This affects how loud you sound to others.
#[serde(rename = "experimental.control_input_volume", default)]
pub control_input_volume: bool,
/// Requires 'rodio_audio: true'
///
/// Use the new audio systems automatic gain control on everyone in the
/// call. This makes call members who are too quite louder and those who are
/// too loud quieter. This only affects how things sound for you.
#[serde(rename = "experimental.control_output_volume", default)]
pub control_output_volume: bool,
pub rodio_audio: bool,
}
impl Settings for AudioSettings {
@@ -58,42 +30,3 @@ impl Settings for AudioSettings {
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
/// See docs on [LIVE_SETTINGS]
pub(crate) struct LiveSettings {
pub(crate) control_input_volume: AtomicBool,
pub(crate) control_output_volume: AtomicBool,
}
impl LiveSettings {
pub(crate) fn initialize(&self, cx: &mut App) {
cx.observe_global::<SettingsStore>(move |cx| {
LIVE_SETTINGS.control_input_volume.store(
AudioSettings::get_global(cx).control_input_volume,
Ordering::Relaxed,
);
LIVE_SETTINGS.control_output_volume.store(
AudioSettings::get_global(cx).control_output_volume,
Ordering::Relaxed,
);
})
.detach();
let init_settings = AudioSettings::get_global(cx);
LIVE_SETTINGS
.control_input_volume
.store(init_settings.control_input_volume, Ordering::Relaxed);
LIVE_SETTINGS
.control_output_volume
.store(init_settings.control_output_volume, Ordering::Relaxed);
}
}
/// Allows access to settings from the audio thread. Updated by
/// observer of SettingsStore. Needed because audio playback and recording are
/// real time and must each run in a dedicated OS thread, therefore we can not
/// use the background executor.
pub(crate) static LIVE_SETTINGS: LiveSettings = LiveSettings {
control_input_volume: AtomicBool::new(true),
control_output_volume: AtomicBool::new(true),
};

View File

@@ -1,77 +0,0 @@
use anyhow::{Context, anyhow};
use async_tar::{Builder, Header};
use gpui::{BackgroundExecutor, Task};
use collections::HashMap;
use parking_lot::Mutex;
use rodio::Source;
use smol::fs::File;
use std::{io, path::PathBuf, sync::Arc, time::Duration};
use crate::{REPLAY_DURATION, rodio_ext::Replay};
#[derive(Default, Clone)]
pub(crate) struct Replays(Arc<Mutex<HashMap<String, Replay>>>);
impl Replays {
pub(crate) fn add_voip_stream(&self, stream_name: String, source: Replay) {
let mut map = self.0.lock();
map.retain(|_, replay| replay.source_is_active());
map.insert(stream_name, source);
}
pub(crate) fn replays_to_tar(
&self,
executor: BackgroundExecutor,
) -> Task<anyhow::Result<(PathBuf, Duration)>> {
let map = Arc::clone(&self.0);
executor.spawn(async move {
let recordings: Vec<_> = map
.lock()
.iter_mut()
.map(|(name, replay)| {
let queued = REPLAY_DURATION.min(replay.duration_ready());
(name.clone(), replay.take_duration(queued).record())
})
.collect();
let longest = recordings
.iter()
.map(|(_, r)| {
r.total_duration()
.expect("SamplesBuffer always returns a total duration")
})
.max()
.ok_or(anyhow!("There is no audio to capture"))?;
let path = std::env::current_dir()
.context("Could not get current dir")?
.join("replays.tar");
let tar = File::create(&path)
.await
.context("Could not create file for tar")?;
let mut tar = Builder::new(tar);
for (name, recording) in recordings {
let mut writer = io::Cursor::new(Vec::new());
rodio::wav_to_writer(recording, &mut writer).context("failed to encode wav")?;
let wav_data = writer.into_inner();
let path = name.replace(' ', "_") + ".wav";
let mut header = Header::new_gnu();
// rw permissions for everyone
header.set_mode(0o666);
header.set_size(wav_data.len() as u64);
tar.append_data(&mut header, path, wav_data.as_slice())
.await
.context("failed to apped wav to tar")?;
}
tar.into_inner()
.await
.context("Could not finish writing tar")?
.sync_all()
.await
.context("Could not flush tar file to disk")?;
Ok((path, longest))
})
}
}

View File

@@ -1,598 +0,0 @@
use std::{
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use crossbeam::queue::ArrayQueue;
use rodio::{ChannelCount, Sample, SampleRate, Source};
#[derive(Debug, thiserror::Error)]
#[error("Replay duration is too short must be >= 100ms")]
pub struct ReplayDurationTooShort;
pub trait RodioExt: Source + Sized {
fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
where
F: FnMut(&mut [Sample; N]);
fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
where
F: FnMut(&[Sample; N]);
fn replayable(
self,
duration: Duration,
) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
fn take_samples(self, n: usize) -> TakeSamples<Self>;
}
impl<S: Source> RodioExt for S {
fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
where
F: FnMut(&mut [Sample; N]),
{
ProcessBuffer {
inner: self,
callback,
buffer: [0.0; N],
next: N,
}
}
fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
where
F: FnMut(&[Sample; N]),
{
InspectBuffer {
inner: self,
callback,
buffer: [0.0; N],
free: 0,
}
}
/// Maintains a live replay with a history of at least `duration` seconds.
///
/// Note:
/// History can be 100ms longer if the source drops before or while the
/// replay is being read
///
/// # Errors
/// If duration is smaller then 100ms
fn replayable(
self,
duration: Duration,
) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
if duration < Duration::from_millis(100) {
return Err(ReplayDurationTooShort);
}
let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
let samples_to_queue =
(samples_to_queue as usize).next_multiple_of(self.channels().get().into());
let chunk_size =
(samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
let is_active = Arc::new(AtomicBool::new(true));
let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
Ok((
Replay {
rx: Arc::clone(&queue),
buffer: Vec::new().into_iter(),
sleep_duration: duration / 2,
sample_rate: self.sample_rate(),
channel_count: self.channels(),
source_is_active: is_active.clone(),
},
Replayable {
tx: queue,
inner: self,
buffer: Vec::with_capacity(chunk_size),
chunk_size,
is_active,
},
))
}
fn take_samples(self, n: usize) -> TakeSamples<S> {
TakeSamples {
inner: self,
left_to_take: n,
}
}
}
pub struct TakeSamples<S> {
inner: S,
left_to_take: usize,
}
impl<S: Source> Iterator for TakeSamples<S> {
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
if self.left_to_take == 0 {
None
} else {
self.left_to_take -= 1;
self.inner.next()
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.left_to_take))
}
}
impl<S: Source> Source for TakeSamples<S> {
fn current_span_len(&self) -> Option<usize> {
None // does not support spans
}
fn channels(&self) -> ChannelCount {
self.inner.channels()
}
fn sample_rate(&self) -> SampleRate {
self.inner.sample_rate()
}
fn total_duration(&self) -> Option<Duration> {
Some(Duration::from_secs_f64(
self.left_to_take as f64
/ self.sample_rate().get() as f64
/ self.channels().get() as f64,
))
}
}
#[derive(Debug)]
struct ReplayQueue {
inner: ArrayQueue<Vec<Sample>>,
normal_chunk_len: usize,
/// The last chunk in the queue may be smaller then
/// the normal chunk size. This is always equal to the
/// size of the last element in the queue.
/// (so normally chunk_size)
last_chunk: Mutex<Vec<Sample>>,
}
impl ReplayQueue {
fn new(queue_len: usize, chunk_size: usize) -> Self {
Self {
inner: ArrayQueue::new(queue_len),
normal_chunk_len: chunk_size,
last_chunk: Mutex::new(Vec::new()),
}
}
/// Returns the length in samples
fn len(&self) -> usize {
self.inner.len().saturating_sub(1) * self.normal_chunk_len
+ self
.last_chunk
.lock()
.expect("Self::push_last can not poison this lock")
.len()
}
fn pop(&self) -> Option<Vec<Sample>> {
self.inner.pop() // removes element that was inserted first
}
fn push_last(&self, mut samples: Vec<Sample>) {
let mut last_chunk = self
.last_chunk
.lock()
.expect("Self::len can not poison this lock");
std::mem::swap(&mut *last_chunk, &mut samples);
}
fn push_normal(&self, samples: Vec<Sample>) {
let _pushed_out_of_ringbuf = self.inner.force_push(samples);
}
}
pub struct ProcessBuffer<const N: usize, S, F>
where
S: Source + Sized,
F: FnMut(&mut [Sample; N]),
{
inner: S,
callback: F,
/// Buffer used for both input and output.
buffer: [Sample; N],
/// Next already processed sample is at this index
/// in buffer.
///
/// If this is equal to the length of the buffer we have no more samples and
/// we must get new ones and process them
next: usize,
}
impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
where
S: Source + Sized,
F: FnMut(&mut [Sample; N]),
{
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
self.next += 1;
if self.next < self.buffer.len() {
let sample = self.buffer[self.next];
return Some(sample);
}
for sample in &mut self.buffer {
*sample = self.inner.next()?
}
(self.callback)(&mut self.buffer);
self.next = 0;
Some(self.buffer[0])
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
where
S: Source + Sized,
F: FnMut(&mut [Sample; N]),
{
fn current_span_len(&self) -> Option<usize> {
None
}
fn channels(&self) -> rodio::ChannelCount {
self.inner.channels()
}
fn sample_rate(&self) -> rodio::SampleRate {
self.inner.sample_rate()
}
fn total_duration(&self) -> Option<std::time::Duration> {
self.inner.total_duration()
}
}
pub struct InspectBuffer<const N: usize, S, F>
where
S: Source + Sized,
F: FnMut(&[Sample; N]),
{
inner: S,
callback: F,
/// Stores already emitted samples, once its full we call the callback.
buffer: [Sample; N],
/// Next free element in buffer. If this is equal to the buffer length
/// we have no more free lements.
free: usize,
}
impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
where
S: Source + Sized,
F: FnMut(&[Sample; N]),
{
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
let Some(sample) = self.inner.next() else {
return None;
};
self.buffer[self.free] = sample;
self.free += 1;
if self.free == self.buffer.len() {
(self.callback)(&self.buffer);
self.free = 0
}
Some(sample)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
where
S: Source + Sized,
F: FnMut(&[Sample; N]),
{
fn current_span_len(&self) -> Option<usize> {
None
}
fn channels(&self) -> rodio::ChannelCount {
self.inner.channels()
}
fn sample_rate(&self) -> rodio::SampleRate {
self.inner.sample_rate()
}
fn total_duration(&self) -> Option<std::time::Duration> {
self.inner.total_duration()
}
}
#[derive(Debug)]
pub struct Replayable<S: Source> {
inner: S,
buffer: Vec<Sample>,
chunk_size: usize,
tx: Arc<ReplayQueue>,
is_active: Arc<AtomicBool>,
}
impl<S: Source> Iterator for Replayable<S> {
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
if let Some(sample) = self.inner.next() {
self.buffer.push(sample);
// If the buffer is full send it
if self.buffer.len() == self.chunk_size {
self.tx.push_normal(std::mem::take(&mut self.buffer));
}
Some(sample)
} else {
let last_chunk = std::mem::take(&mut self.buffer);
self.tx.push_last(last_chunk);
self.is_active.store(false, Ordering::Relaxed);
None
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl<S: Source> Source for Replayable<S> {
fn current_span_len(&self) -> Option<usize> {
self.inner.current_span_len()
}
fn channels(&self) -> ChannelCount {
self.inner.channels()
}
fn sample_rate(&self) -> SampleRate {
self.inner.sample_rate()
}
fn total_duration(&self) -> Option<Duration> {
self.inner.total_duration()
}
}
#[derive(Debug)]
pub struct Replay {
rx: Arc<ReplayQueue>,
buffer: std::vec::IntoIter<Sample>,
sleep_duration: Duration,
sample_rate: SampleRate,
channel_count: ChannelCount,
source_is_active: Arc<AtomicBool>,
}
impl Replay {
pub fn source_is_active(&self) -> bool {
// - source could return None and not drop
// - source could be dropped before returning None
self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
}
/// Duration of what is in the buffer and can be returned without blocking.
pub fn duration_ready(&self) -> Duration {
let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
Duration::from_secs_f64(seconds_queued)
}
/// Number of samples in the buffer and can be returned without blocking.
pub fn samples_ready(&self) -> usize {
self.rx.len() + self.buffer.len()
}
}
impl Iterator for Replay {
type Item = Sample;
fn next(&mut self) -> Option<Self::Item> {
if let Some(sample) = self.buffer.next() {
return Some(sample);
}
loop {
if let Some(new_buffer) = self.rx.pop() {
self.buffer = new_buffer.into_iter();
return self.buffer.next();
}
if !self.source_is_active() {
return None;
}
// The queue does not support blocking on a next item. We want this queue as it
// is quite fast and provides a fixed size. We know how many samples are in a
// buffer so if we do not get one now we must be getting one after `sleep_duration`.
std::thread::sleep(self.sleep_duration);
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
((self.rx.len() + self.buffer.len()), None)
}
}
impl Source for Replay {
fn current_span_len(&self) -> Option<usize> {
None // source is not compatible with spans
}
fn channels(&self) -> ChannelCount {
self.channel_count
}
fn sample_rate(&self) -> SampleRate {
self.sample_rate
}
fn total_duration(&self) -> Option<Duration> {
None
}
}
#[cfg(test)]
mod tests {
use rodio::{nz, static_buffer::StaticSamplesBuffer};
use super::*;
const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
fn test_source() -> StaticSamplesBuffer {
StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
}
mod process_buffer {
use super::*;
#[test]
fn callback_gets_all_samples() {
let input = test_source();
let _ = input
.process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
.count();
}
#[test]
fn callback_modifies_yielded() {
let input = test_source();
let yielded: Vec<_> = input
.process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
for sample in buffer {
*sample += 1.0;
}
})
.collect();
assert_eq!(
yielded,
SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
)
}
#[test]
fn source_truncates_to_whole_buffers() {
let input = test_source();
let yielded = input
.process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
.count();
assert_eq!(yielded, 3)
}
}
mod inspect_buffer {
use super::*;
#[test]
fn callback_gets_all_samples() {
let input = test_source();
let _ = input
.inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
.count();
}
#[test]
fn source_does_not_truncate() {
let input = test_source();
let yielded = input
.inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
.count();
assert_eq!(yielded, SAMPLES.len())
}
}
mod instant_replay {
use super::*;
#[test]
fn continues_after_history() {
let input = test_source();
let (mut replay, mut source) = input
.replayable(Duration::from_secs(3))
.expect("longer then 100ms");
source.by_ref().take(3).count();
let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
assert_eq!(&yielded, &SAMPLES[0..3],);
source.count();
let yielded: Vec<Sample> = replay.collect();
assert_eq!(&yielded, &SAMPLES[3..5],);
}
#[test]
fn keeps_only_latest() {
let input = test_source();
let (mut replay, mut source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
source.by_ref().take(5).count(); // get all items but do not end the source
let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
assert_eq!(&yielded, &SAMPLES[3..5]);
source.count(); // exhaust source
assert_eq!(replay.next(), None);
}
#[test]
fn keeps_correct_amount_of_seconds() {
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
let (replay, mut source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
// exhaust but do not yet end source
source.by_ref().take(40_000).count();
// take all samples we can without blocking
let ready = replay.samples_ready();
let n_yielded = replay.take_samples(ready).count();
let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
let margin = 16_000 / 10; // 100ms
assert!(n_yielded as u32 >= max - margin);
}
#[test]
fn samples_ready() {
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
let (mut replay, source) = input
.replayable(Duration::from_secs(2))
.expect("longer then 100ms");
assert_eq!(replay.by_ref().samples_ready(), 0);
source.take(8000).count(); // half a second
let margin = 16_000 / 10; // 100ms
let ready = replay.samples_ready();
assert!(ready >= 8000 - margin);
}
}
}

View File

@@ -1,5 +1,4 @@
use auto_update::AutoUpdater;
use client::proto::UpdateNotification;
use editor::{Editor, MultiBuffer};
use gpui::{App, Context, DismissEvent, Entity, Window, actions, prelude::*};
use http_client::HttpClient;
@@ -138,6 +137,8 @@ pub fn notify_if_app_was_updated(cx: &mut App) {
return;
}
struct UpdateNotification;
let should_show_notification = updater.read(cx).should_show_update_notification(cx);
cx.spawn(async move |cx| {
let should_show_notification = should_show_notification.await?;

View File

@@ -29,7 +29,6 @@ client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
feature_flags.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
language.workspace = true
log.workspace = true

View File

@@ -9,7 +9,6 @@ use client::{
proto::{self, PeerId},
};
use collections::{BTreeMap, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use futures::StreamExt;
use gpui::{
@@ -1323,18 +1322,8 @@ impl Room {
return Task::ready(Err(anyhow!("live-kit was not initialized")));
};
let is_staff = cx.is_staff();
let user_name = self
.user_store
.read(cx)
.current_user()
.and_then(|user| user.name.clone())
.unwrap_or_else(|| "unknown".to_string());
cx.spawn(async move |this, cx| {
let publication = room
.publish_local_microphone_track(user_name, is_staff, cx)
.await;
let publication = room.publish_local_microphone_track(cx).await;
this.update(cx, |this, cx| {
let live_kit = this
.live_kit

View File

@@ -25,11 +25,9 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true
rand.workspace = true
release_channel.workspace = true
rpc.workspace = true
settings.workspace = true
sum_tree.workspace = true
text.workspace = true
time.workspace = true
util.workspace = true

View File

@@ -1,5 +1,4 @@
mod channel_buffer;
mod channel_chat;
mod channel_store;
use client::{Client, UserStore};
@@ -7,10 +6,6 @@ use gpui::{App, Entity};
use std::sync::Arc;
pub use channel_buffer::{ACKNOWLEDGE_DEBOUNCE_INTERVAL, ChannelBuffer, ChannelBufferEvent};
pub use channel_chat::{
ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, MessageParams,
mentions_to_proto,
};
pub use channel_store::{Channel, ChannelEvent, ChannelMembership, ChannelStore};
#[cfg(test)]
@@ -19,5 +14,4 @@ mod channel_store_tests;
pub fn init(client: &Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
channel_store::init(client, user_store, cx);
channel_buffer::init(&client.clone().into());
channel_chat::init(&client.clone().into());
}

View File

@@ -1,861 +0,0 @@
use crate::{Channel, ChannelStore};
use anyhow::{Context as _, Result};
use client::{
ChannelId, Client, Subscription, TypedEnvelope, UserId, proto,
user::{User, UserStore},
};
use collections::HashSet;
use futures::lock::Mutex;
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use rand::prelude::*;
use rpc::AnyProtoClient;
use std::{
ops::{ControlFlow, Range},
sync::Arc,
};
use sum_tree::{Bias, Dimensions, SumTree};
use time::OffsetDateTime;
use util::{ResultExt as _, TryFutureExt, post_inc};
pub struct ChannelChat {
pub channel_id: ChannelId,
messages: SumTree<ChannelMessage>,
acknowledged_message_ids: HashSet<u64>,
channel_store: Entity<ChannelStore>,
loaded_all_messages: bool,
last_acknowledged_id: Option<u64>,
next_pending_message_id: usize,
first_loaded_message_id: Option<u64>,
user_store: Entity<UserStore>,
rpc: Arc<Client>,
outgoing_messages_lock: Arc<Mutex<()>>,
rng: StdRng,
_subscription: Subscription,
}
#[derive(Debug, PartialEq, Eq)]
pub struct MessageParams {
pub text: String,
pub mentions: Vec<(Range<usize>, UserId)>,
pub reply_to_message_id: Option<u64>,
}
#[derive(Clone, Debug)]
pub struct ChannelMessage {
pub id: ChannelMessageId,
pub body: String,
pub timestamp: OffsetDateTime,
pub sender: Arc<User>,
pub nonce: u128,
pub mentions: Vec<(Range<usize>, UserId)>,
pub reply_to_message_id: Option<u64>,
pub edited_at: Option<OffsetDateTime>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ChannelMessageId {
Saved(u64),
Pending(usize),
}
impl From<ChannelMessageId> for Option<u64> {
fn from(val: ChannelMessageId) -> Self {
match val {
ChannelMessageId::Saved(id) => Some(id),
ChannelMessageId::Pending(_) => None,
}
}
}
#[derive(Clone, Debug, Default)]
pub struct ChannelMessageSummary {
max_id: ChannelMessageId,
count: usize,
}
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelChatEvent {
MessagesUpdated {
old_range: Range<usize>,
new_count: usize,
},
UpdateMessage {
message_id: ChannelMessageId,
message_ix: usize,
},
NewMessage {
channel_id: ChannelId,
message_id: u64,
},
}
impl EventEmitter<ChannelChatEvent> for ChannelChat {}
pub fn init(client: &AnyProtoClient) {
client.add_entity_message_handler(ChannelChat::handle_message_sent);
client.add_entity_message_handler(ChannelChat::handle_message_removed);
client.add_entity_message_handler(ChannelChat::handle_message_updated);
}
impl ChannelChat {
pub async fn new(
channel: Arc<Channel>,
channel_store: Entity<ChannelStore>,
user_store: Entity<UserStore>,
client: Arc<Client>,
cx: &mut AsyncApp,
) -> Result<Entity<Self>> {
let channel_id = channel.id;
let subscription = client.subscribe_to_entity(channel_id.0).unwrap();
let response = client
.request(proto::JoinChannelChat {
channel_id: channel_id.0,
})
.await?;
let handle = cx.new(|cx| {
cx.on_release(Self::release).detach();
Self {
channel_id: channel.id,
user_store: user_store.clone(),
channel_store,
rpc: client.clone(),
outgoing_messages_lock: Default::default(),
messages: Default::default(),
acknowledged_message_ids: Default::default(),
loaded_all_messages: false,
next_pending_message_id: 0,
last_acknowledged_id: None,
rng: StdRng::from_os_rng(),
first_loaded_message_id: None,
_subscription: subscription.set_entity(&cx.entity(), &cx.to_async()),
}
})?;
Self::handle_loaded_messages(
handle.downgrade(),
user_store,
client,
response.messages,
response.done,
cx,
)
.await?;
Ok(handle)
}
fn release(&mut self, _: &mut App) {
self.rpc
.send(proto::LeaveChannelChat {
channel_id: self.channel_id.0,
})
.log_err();
}
pub fn channel(&self, cx: &App) -> Option<Arc<Channel>> {
self.channel_store
.read(cx)
.channel_for_id(self.channel_id)
.cloned()
}
pub fn client(&self) -> &Arc<Client> {
&self.rpc
}
pub fn send_message(
&mut self,
message: MessageParams,
cx: &mut Context<Self>,
) -> Result<Task<Result<u64>>> {
anyhow::ensure!(
!message.text.trim().is_empty(),
"message body can't be empty"
);
let current_user = self
.user_store
.read(cx)
.current_user()
.context("current_user is not present")?;
let channel_id = self.channel_id;
let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
let nonce = self.rng.random();
self.insert_messages(
SumTree::from_item(
ChannelMessage {
id: pending_id,
body: message.text.clone(),
sender: current_user,
timestamp: OffsetDateTime::now_utc(),
mentions: message.mentions.clone(),
nonce,
reply_to_message_id: message.reply_to_message_id,
edited_at: None,
},
&(),
),
cx,
);
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let outgoing_messages_lock = self.outgoing_messages_lock.clone();
// todo - handle messages that fail to send (e.g. >1024 chars)
Ok(cx.spawn(async move |this, cx| {
let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage {
channel_id: channel_id.0,
body: message.text,
nonce: Some(nonce.into()),
mentions: mentions_to_proto(&message.mentions),
reply_to_message_id: message.reply_to_message_id,
});
let response = request.await?;
drop(outgoing_message_guard);
let response = response.message.context("invalid message")?;
let id = response.id;
let message = ChannelMessage::from_proto(response, &user_store, cx).await?;
this.update(cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
if this.first_loaded_message_id.is_none() {
this.first_loaded_message_id = Some(id);
}
})?;
Ok(id)
}))
}
pub fn remove_message(&mut self, id: u64, cx: &mut Context<Self>) -> Task<Result<()>> {
let response = self.rpc.request(proto::RemoveChannelMessage {
channel_id: self.channel_id.0,
message_id: id,
});
cx.spawn(async move |this, cx| {
response.await?;
this.update(cx, |this, cx| {
this.message_removed(id, cx);
})?;
Ok(())
})
}
pub fn update_message(
&mut self,
id: u64,
message: MessageParams,
cx: &mut Context<Self>,
) -> Result<Task<Result<()>>> {
self.message_update(
ChannelMessageId::Saved(id),
message.text.clone(),
message.mentions.clone(),
Some(OffsetDateTime::now_utc()),
cx,
);
let nonce: u128 = self.rng.random();
let request = self.rpc.request(proto::UpdateChannelMessage {
channel_id: self.channel_id.0,
message_id: id,
body: message.text,
nonce: Some(nonce.into()),
mentions: mentions_to_proto(&message.mentions),
});
Ok(cx.spawn(async move |_, _| {
request.await?;
Ok(())
}))
}
pub fn load_more_messages(&mut self, cx: &mut Context<Self>) -> Option<Task<Option<()>>> {
if self.loaded_all_messages {
return None;
}
let rpc = self.rpc.clone();
let user_store = self.user_store.clone();
let channel_id = self.channel_id;
let before_message_id = self.first_loaded_message_id()?;
Some(cx.spawn(async move |this, cx| {
async move {
let response = rpc
.request(proto::GetChannelMessages {
channel_id: channel_id.0,
before_message_id,
})
.await?;
Self::handle_loaded_messages(
this,
user_store,
rpc,
response.messages,
response.done,
cx,
)
.await?;
anyhow::Ok(())
}
.log_err()
.await
}))
}
pub fn first_loaded_message_id(&mut self) -> Option<u64> {
self.first_loaded_message_id
}
/// Load a message by its id, if it's already stored locally.
pub fn find_loaded_message(&self, id: u64) -> Option<&ChannelMessage> {
self.messages.iter().find(|message| match message.id {
ChannelMessageId::Saved(message_id) => message_id == id,
ChannelMessageId::Pending(_) => false,
})
}
/// Load all of the chat messages since a certain message id.
///
/// For now, we always maintain a suffix of the channel's messages.
pub async fn load_history_since_message(
chat: Entity<Self>,
message_id: u64,
mut cx: AsyncApp,
) -> Option<usize> {
loop {
let step = chat
.update(&mut cx, |chat, cx| {
if let Some(first_id) = chat.first_loaded_message_id()
&& first_id <= message_id
{
let mut cursor = chat
.messages
.cursor::<Dimensions<ChannelMessageId, Count>>(&());
let message_id = ChannelMessageId::Saved(message_id);
cursor.seek(&message_id, Bias::Left);
return ControlFlow::Break(
if cursor
.item()
.is_some_and(|message| message.id == message_id)
{
Some(cursor.start().1.0)
} else {
None
},
);
}
ControlFlow::Continue(chat.load_more_messages(cx))
})
.log_err()?;
match step {
ControlFlow::Break(ix) => return ix,
ControlFlow::Continue(task) => task?.await?,
}
}
}
pub fn acknowledge_last_message(&mut self, cx: &mut Context<Self>) {
if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id
&& self
.last_acknowledged_id
.is_none_or(|acknowledged_id| acknowledged_id < latest_message_id)
{
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel_id.0,
message_id: latest_message_id,
})
.ok();
self.last_acknowledged_id = Some(latest_message_id);
self.channel_store.update(cx, |store, cx| {
store.acknowledge_message_id(self.channel_id, latest_message_id, cx);
});
}
}
async fn handle_loaded_messages(
this: WeakEntity<Self>,
user_store: Entity<UserStore>,
rpc: Arc<Client>,
proto_messages: Vec<proto::ChannelMessage>,
loaded_all_messages: bool,
cx: &mut AsyncApp,
) -> Result<()> {
let loaded_messages = messages_from_proto(proto_messages, &user_store, cx).await?;
let first_loaded_message_id = loaded_messages.first().map(|m| m.id);
let loaded_message_ids = this.read_with(cx, |this, _| {
let mut loaded_message_ids: HashSet<u64> = HashSet::default();
for message in loaded_messages.iter() {
if let Some(saved_message_id) = message.id.into() {
loaded_message_ids.insert(saved_message_id);
}
}
for message in this.messages.iter() {
if let Some(saved_message_id) = message.id.into() {
loaded_message_ids.insert(saved_message_id);
}
}
loaded_message_ids
})?;
let missing_ancestors = loaded_messages
.iter()
.filter_map(|message| {
if let Some(ancestor_id) = message.reply_to_message_id
&& !loaded_message_ids.contains(&ancestor_id)
{
return Some(ancestor_id);
}
None
})
.collect::<Vec<_>>();
let loaded_ancestors = if missing_ancestors.is_empty() {
None
} else {
let response = rpc
.request(proto::GetChannelMessagesById {
message_ids: missing_ancestors,
})
.await?;
Some(messages_from_proto(response.messages, &user_store, cx).await?)
};
this.update(cx, |this, cx| {
this.first_loaded_message_id = first_loaded_message_id.and_then(|msg_id| msg_id.into());
this.loaded_all_messages = loaded_all_messages;
this.insert_messages(loaded_messages, cx);
if let Some(loaded_ancestors) = loaded_ancestors {
this.insert_messages(loaded_ancestors, cx);
}
})?;
Ok(())
}
pub fn rejoin(&mut self, cx: &mut Context<Self>) {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
let channel_id = self.channel_id;
cx.spawn(async move |this, cx| {
async move {
let response = rpc
.request(proto::JoinChannelChat {
channel_id: channel_id.0,
})
.await?;
Self::handle_loaded_messages(
this.clone(),
user_store.clone(),
rpc.clone(),
response.messages,
response.done,
cx,
)
.await?;
let pending_messages = this.read_with(cx, |this, _| {
this.pending_messages().cloned().collect::<Vec<_>>()
})?;
for pending_message in pending_messages {
let request = rpc.request(proto::SendChannelMessage {
channel_id: channel_id.0,
body: pending_message.body,
mentions: mentions_to_proto(&pending_message.mentions),
nonce: Some(pending_message.nonce.into()),
reply_to_message_id: pending_message.reply_to_message_id,
});
let response = request.await?;
let message = ChannelMessage::from_proto(
response.message.context("invalid message")?,
&user_store,
cx,
)
.await?;
this.update(cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
})?;
}
anyhow::Ok(())
}
.log_err()
.await
})
.detach();
}
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
pub fn messages(&self) -> &SumTree<ChannelMessage> {
&self.messages
}
pub fn message(&self, ix: usize) -> &ChannelMessage {
let mut cursor = self.messages.cursor::<Count>(&());
cursor.seek(&Count(ix), Bias::Right);
cursor.item().unwrap()
}
pub fn acknowledge_message(&mut self, id: u64) {
if self.acknowledged_message_ids.insert(id) {
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel_id.0,
message_id: id,
})
.ok();
}
}
pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<Count>(&());
cursor.seek(&Count(range.start), Bias::Right);
cursor.take(range.len())
}
pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
cursor.seek(&ChannelMessageId::Pending(0), Bias::Left);
cursor
}
async fn handle_message_sent(
this: Entity<Self>,
message: TypedEnvelope<proto::ChannelMessageSent>,
mut cx: AsyncApp,
) -> Result<()> {
let user_store = this.read_with(&cx, |this, _| this.user_store.clone())?;
let message = message.payload.message.context("empty message")?;
let message_id = message.id;
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
cx.emit(ChannelChatEvent::NewMessage {
channel_id: this.channel_id,
message_id,
})
})?;
Ok(())
}
async fn handle_message_removed(
this: Entity<Self>,
message: TypedEnvelope<proto::RemoveChannelMessage>,
mut cx: AsyncApp,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.message_removed(message.payload.message_id, cx)
})?;
Ok(())
}
async fn handle_message_updated(
this: Entity<Self>,
message: TypedEnvelope<proto::ChannelMessageUpdate>,
mut cx: AsyncApp,
) -> Result<()> {
let user_store = this.read_with(&cx, |this, _| this.user_store.clone())?;
let message = message.payload.message.context("empty message")?;
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.message_update(
message.id,
message.body,
message.mentions,
message.edited_at,
cx,
)
})?;
Ok(())
}
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut Context<Self>) {
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
let nonces = messages
.cursor::<()>(&())
.map(|m| m.nonce)
.collect::<HashSet<_>>();
let mut old_cursor = self
.messages
.cursor::<Dimensions<ChannelMessageId, Count>>(&());
let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left);
let start_ix = old_cursor.start().1.0;
let removed_messages = old_cursor.slice(&last_message.id, Bias::Right);
let removed_count = removed_messages.summary().count;
let new_count = messages.summary().count;
let end_ix = start_ix + removed_count;
new_messages.append(messages, &());
let mut ranges = Vec::<Range<usize>>::new();
if new_messages.last().unwrap().is_pending() {
new_messages.append(old_cursor.suffix(), &());
} else {
new_messages.append(
old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left),
&(),
);
while let Some(message) = old_cursor.item() {
let message_ix = old_cursor.start().1.0;
if nonces.contains(&message.nonce) {
if ranges.last().is_some_and(|r| r.end == message_ix) {
ranges.last_mut().unwrap().end += 1;
} else {
ranges.push(message_ix..message_ix + 1);
}
} else {
new_messages.push(message.clone(), &());
}
old_cursor.next();
}
}
drop(old_cursor);
self.messages = new_messages;
for range in ranges.into_iter().rev() {
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: range,
new_count: 0,
});
}
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count,
});
cx.notify();
}
}
fn message_removed(&mut self, id: u64, cx: &mut Context<Self>) {
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left);
if let Some(item) = cursor.item()
&& item.id == ChannelMessageId::Saved(id)
{
let deleted_message_ix = messages.summary().count;
cursor.next();
messages.append(cursor.suffix(), &());
drop(cursor);
self.messages = messages;
// If the message that was deleted was the last acknowledged message,
// replace the acknowledged message with an earlier one.
self.channel_store.update(cx, |store, _| {
let summary = self.messages.summary();
if summary.count == 0 {
store.set_acknowledged_message_id(self.channel_id, None);
} else if deleted_message_ix == summary.count
&& let ChannelMessageId::Saved(id) = summary.max_id
{
store.set_acknowledged_message_id(self.channel_id, Some(id));
}
});
cx.emit(ChannelChatEvent::MessagesUpdated {
old_range: deleted_message_ix..deleted_message_ix + 1,
new_count: 0,
});
}
}
fn message_update(
&mut self,
id: ChannelMessageId,
body: String,
mentions: Vec<(Range<usize>, u64)>,
edited_at: Option<OffsetDateTime>,
cx: &mut Context<Self>,
) {
let mut cursor = self.messages.cursor::<ChannelMessageId>(&());
let mut messages = cursor.slice(&id, Bias::Left);
let ix = messages.summary().count;
if let Some(mut message_to_update) = cursor.item().cloned() {
message_to_update.body = body;
message_to_update.mentions = mentions;
message_to_update.edited_at = edited_at;
messages.push(message_to_update, &());
cursor.next();
}
messages.append(cursor.suffix(), &());
drop(cursor);
self.messages = messages;
cx.emit(ChannelChatEvent::UpdateMessage {
message_ix: ix,
message_id: id,
});
cx.notify();
}
}
async fn messages_from_proto(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &Entity<UserStore>,
cx: &mut AsyncApp,
) -> Result<SumTree<ChannelMessage>> {
let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
let mut result = SumTree::default();
result.extend(messages, &());
Ok(result)
}
impl ChannelMessage {
pub async fn from_proto(
message: proto::ChannelMessage,
user_store: &Entity<UserStore>,
cx: &mut AsyncApp,
) -> Result<Self> {
let sender = user_store
.update(cx, |user_store, cx| {
user_store.get_user(message.sender_id, cx)
})?
.await?;
let edited_at = message.edited_at.and_then(|t| -> Option<OffsetDateTime> {
if let Ok(a) = OffsetDateTime::from_unix_timestamp(t as i64) {
return Some(a);
}
None
});
Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id),
body: message.body,
mentions: message
.mentions
.into_iter()
.filter_map(|mention| {
let range = mention.range?;
Some((range.start as usize..range.end as usize, mention.user_id))
})
.collect(),
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
sender,
nonce: message.nonce.context("nonce is required")?.into(),
reply_to_message_id: message.reply_to_message_id,
edited_at,
})
}
pub fn is_pending(&self) -> bool {
matches!(self.id, ChannelMessageId::Pending(_))
}
pub async fn from_proto_vec(
proto_messages: Vec<proto::ChannelMessage>,
user_store: &Entity<UserStore>,
cx: &mut AsyncApp,
) -> Result<Vec<Self>> {
let unique_user_ids = proto_messages
.iter()
.map(|m| m.sender_id)
.collect::<HashSet<_>>()
.into_iter()
.collect();
user_store
.update(cx, |user_store, cx| {
user_store.get_users(unique_user_ids, cx)
})?
.await?;
let mut messages = Vec::with_capacity(proto_messages.len());
for message in proto_messages {
messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
}
Ok(messages)
}
}
pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
mentions
.iter()
.map(|(range, user_id)| proto::ChatMention {
range: Some(proto::Range {
start: range.start as u64,
end: range.end as u64,
}),
user_id: *user_id,
})
.collect()
}
impl sum_tree::Item for ChannelMessage {
type Summary = ChannelMessageSummary;
fn summary(&self, _cx: &()) -> Self::Summary {
ChannelMessageSummary {
max_id: self.id,
count: 1,
}
}
}
impl Default for ChannelMessageId {
fn default() -> Self {
Self::Saved(0)
}
}
impl sum_tree::Summary for ChannelMessageSummary {
type Context = ();
fn zero(_cx: &Self::Context) -> Self {
Default::default()
}
fn add_summary(&mut self, summary: &Self, _: &()) {
self.max_id = summary.max_id;
self.count += summary.count;
}
}
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for ChannelMessageId {
fn zero(_cx: &()) -> Self {
Default::default()
}
fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
debug_assert!(summary.max_id > *self);
*self = summary.max_id;
}
}
impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
fn zero(_cx: &()) -> Self {
Default::default()
}
fn add_summary(&mut self, summary: &'a ChannelMessageSummary, _: &()) {
self.0 += summary.count;
}
}
impl<'a> From<&'a str> for MessageParams {
fn from(value: &'a str) -> Self {
Self {
text: value.into(),
mentions: Vec::new(),
reply_to_message_id: None,
}
}
}

View File

@@ -1,6 +1,6 @@
mod channel_index;
use crate::{ChannelMessage, channel_buffer::ChannelBuffer, channel_chat::ChannelChat};
use crate::channel_buffer::ChannelBuffer;
use anyhow::{Context as _, Result, anyhow};
use channel_index::ChannelIndex;
use client::{ChannelId, Client, ClientSettings, Subscription, User, UserId, UserStore};
@@ -41,7 +41,6 @@ pub struct ChannelStore {
outgoing_invites: HashSet<(ChannelId, UserId)>,
update_channels_tx: mpsc::UnboundedSender<proto::UpdateChannels>,
opened_buffers: HashMap<ChannelId, OpenEntityHandle<ChannelBuffer>>,
opened_chats: HashMap<ChannelId, OpenEntityHandle<ChannelChat>>,
client: Arc<Client>,
did_subscribe: bool,
channels_loaded: (watch::Sender<bool>, watch::Receiver<bool>),
@@ -63,10 +62,8 @@ pub struct Channel {
#[derive(Default, Debug)]
pub struct ChannelState {
latest_chat_message: Option<u64>,
latest_notes_version: NotesVersion,
observed_notes_version: NotesVersion,
observed_chat_message: Option<u64>,
role: Option<ChannelRole>,
}
@@ -196,7 +193,6 @@ impl ChannelStore {
channel_participants: Default::default(),
outgoing_invites: Default::default(),
opened_buffers: Default::default(),
opened_chats: Default::default(),
update_channels_tx,
client,
user_store,
@@ -362,89 +358,12 @@ impl ChannelStore {
)
}
pub fn fetch_channel_messages(
&self,
message_ids: Vec<u64>,
cx: &mut Context<Self>,
) -> Task<Result<Vec<ChannelMessage>>> {
let request = if message_ids.is_empty() {
None
} else {
Some(
self.client
.request(proto::GetChannelMessagesById { message_ids }),
)
};
cx.spawn(async move |this, cx| {
if let Some(request) = request {
let response = request.await?;
let this = this.upgrade().context("channel store dropped")?;
let user_store = this.read_with(cx, |this, _| this.user_store.clone())?;
ChannelMessage::from_proto_vec(response.messages, &user_store, cx).await
} else {
Ok(Vec::new())
}
})
}
pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> bool {
self.channel_states
.get(&channel_id)
.is_some_and(|state| state.has_channel_buffer_changed())
}
pub fn has_new_messages(&self, channel_id: ChannelId) -> bool {
self.channel_states
.get(&channel_id)
.is_some_and(|state| state.has_new_messages())
}
pub fn set_acknowledged_message_id(&mut self, channel_id: ChannelId, message_id: Option<u64>) {
if let Some(state) = self.channel_states.get_mut(&channel_id) {
state.latest_chat_message = message_id;
}
}
pub fn last_acknowledge_message_id(&self, channel_id: ChannelId) -> Option<u64> {
self.channel_states.get(&channel_id).and_then(|state| {
if let Some(last_message_id) = state.latest_chat_message
&& state
.last_acknowledged_message_id()
.is_some_and(|id| id < last_message_id)
{
return state.last_acknowledged_message_id();
}
None
})
}
pub fn acknowledge_message_id(
&mut self,
channel_id: ChannelId,
message_id: u64,
cx: &mut Context<Self>,
) {
self.channel_states
.entry(channel_id)
.or_default()
.acknowledge_message_id(message_id);
cx.notify();
}
pub fn update_latest_message_id(
&mut self,
channel_id: ChannelId,
message_id: u64,
cx: &mut Context<Self>,
) {
self.channel_states
.entry(channel_id)
.or_default()
.update_latest_message_id(message_id);
cx.notify();
}
pub fn acknowledge_notes_version(
&mut self,
channel_id: ChannelId,
@@ -473,23 +392,6 @@ impl ChannelStore {
cx.notify()
}
pub fn open_channel_chat(
&mut self,
channel_id: ChannelId,
cx: &mut Context<Self>,
) -> Task<Result<Entity<ChannelChat>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
let this = cx.entity();
self.open_channel_resource(
channel_id,
"chat",
|this| &mut this.opened_chats,
async move |channel, cx| ChannelChat::new(channel, this, user_store, client, cx).await,
cx,
)
}
/// Asynchronously open a given resource associated with a channel.
///
/// Make sure that the resource is only opened once, even if this method
@@ -931,13 +833,6 @@ impl ChannelStore {
cx,
);
}
for message_id in message.payload.observed_channel_message_id {
this.acknowledge_message_id(
ChannelId(message_id.channel_id),
message_id.message_id,
cx,
);
}
for membership in message.payload.channel_memberships {
if let Some(role) = ChannelRole::from_i32(membership.role) {
this.channel_states
@@ -957,16 +852,6 @@ impl ChannelStore {
self.outgoing_invites.clear();
self.disconnect_channel_buffers_task.take();
for chat in self.opened_chats.values() {
if let OpenEntityHandle::Open(chat) = chat
&& let Some(chat) = chat.upgrade()
{
chat.update(cx, |chat, cx| {
chat.rejoin(cx);
});
}
}
let mut buffer_versions = Vec::new();
for buffer in self.opened_buffers.values() {
if let OpenEntityHandle::Open(buffer) = buffer
@@ -1094,7 +979,6 @@ impl ChannelStore {
self.channel_participants.clear();
self.outgoing_invites.clear();
self.opened_buffers.clear();
self.opened_chats.clear();
self.disconnect_channel_buffers_task = None;
self.channel_states.clear();
}
@@ -1131,7 +1015,6 @@ impl ChannelStore {
let channels_changed = !payload.channels.is_empty()
|| !payload.delete_channels.is_empty()
|| !payload.latest_channel_message_ids.is_empty()
|| !payload.latest_channel_buffer_versions.is_empty();
if channels_changed {
@@ -1181,13 +1064,6 @@ impl ChannelStore {
.update_latest_notes_version(latest_buffer_version.epoch, &version)
}
for latest_channel_message in payload.latest_channel_message_ids {
self.channel_states
.entry(ChannelId(latest_channel_message.channel_id))
.or_default()
.update_latest_message_id(latest_channel_message.message_id);
}
self.channels_loaded.0.try_send(true).log_err();
}
@@ -1251,29 +1127,6 @@ impl ChannelState {
.changed_since(&self.observed_notes_version.version))
}
fn has_new_messages(&self) -> bool {
let latest_message_id = self.latest_chat_message;
let observed_message_id = self.observed_chat_message;
latest_message_id.is_some_and(|latest_message_id| {
latest_message_id > observed_message_id.unwrap_or_default()
})
}
fn last_acknowledged_message_id(&self) -> Option<u64> {
self.observed_chat_message
}
fn acknowledge_message_id(&mut self, message_id: u64) {
let observed = self.observed_chat_message.get_or_insert(message_id);
*observed = (*observed).max(message_id);
}
fn update_latest_message_id(&mut self, message_id: u64) {
self.latest_chat_message =
Some(message_id.max(self.latest_chat_message.unwrap_or_default()));
}
fn acknowledge_notes_version(&mut self, epoch: u64, version: &clock::Global) {
if self.observed_notes_version.epoch == epoch {
self.observed_notes_version.version.join(version);

View File

@@ -1,9 +1,7 @@
use crate::channel_chat::ChannelChatEvent;
use super::*;
use client::{Client, UserStore, test::FakeServer};
use client::{Client, UserStore};
use clock::FakeSystemClock;
use gpui::{App, AppContext as _, Entity, SemanticVersion, TestAppContext};
use gpui::{App, AppContext as _, Entity, SemanticVersion};
use http_client::FakeHttpClient;
use rpc::proto::{self};
use settings::SettingsStore;
@@ -235,201 +233,6 @@ fn test_dangling_channel_paths(cx: &mut App) {
assert_channels(&channel_store, &[(0, "a".to_string())], cx);
}
#[gpui::test]
async fn test_channel_messages(cx: &mut TestAppContext) {
let user_id = 5;
let channel_id = 5;
let channel_store = cx.update(init_test);
let client = channel_store.read_with(cx, |s, _| s.client());
let server = FakeServer::for_client(user_id, &client, cx).await;
// Get the available channels.
server.send(proto::UpdateChannels {
channels: vec![proto::Channel {
id: channel_id,
name: "the-channel".to_string(),
visibility: proto::ChannelVisibility::Members as i32,
parent_path: vec![],
channel_order: 1,
}],
..Default::default()
});
cx.executor().run_until_parked();
cx.update(|cx| {
assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx);
});
// Join a channel and populate its existing messages.
let channel = channel_store.update(cx, |store, cx| {
let channel_id = store.ordered_channels().next().unwrap().1.id;
store.open_channel_chat(channel_id, cx)
});
let join_channel = server.receive::<proto::JoinChannelChat>().await.unwrap();
server.respond(
join_channel.receipt(),
proto::JoinChannelChatResponse {
messages: vec![
proto::ChannelMessage {
id: 10,
body: "a".into(),
timestamp: 1000,
sender_id: 5,
mentions: vec![],
nonce: Some(1.into()),
reply_to_message_id: None,
edited_at: None,
},
proto::ChannelMessage {
id: 11,
body: "b".into(),
timestamp: 1001,
sender_id: 6,
mentions: vec![],
nonce: Some(2.into()),
reply_to_message_id: None,
edited_at: None,
},
],
done: false,
},
);
cx.executor().start_waiting();
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
get_users.payload.user_ids.sort();
assert_eq!(get_users.payload.user_ids, vec![6]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 6,
github_login: "maxbrunsfeld".into(),
avatar_url: "http://avatar.com/maxbrunsfeld".into(),
name: None,
}],
},
);
let channel = channel.await.unwrap();
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("user-5".into(), "a".into()),
("maxbrunsfeld".into(), "b".into())
]
);
});
// Receive a new message.
server.send(proto::ChannelMessageSent {
channel_id,
message: Some(proto::ChannelMessage {
id: 12,
body: "c".into(),
timestamp: 1002,
sender_id: 7,
mentions: vec![],
nonce: Some(3.into()),
reply_to_message_id: None,
edited_at: None,
}),
});
// Client requests user for message since they haven't seen them yet
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![7]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 7,
github_login: "as-cii".into(),
avatar_url: "http://avatar.com/as-cii".into(),
name: None,
}],
},
);
assert_eq!(
channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
}
);
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(2..3)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[("as-cii".into(), "c".into())]
)
});
// Scroll up to view older messages.
channel.update(cx, |channel, cx| {
channel.load_more_messages(cx).unwrap().detach();
});
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server.respond(
get_messages.receipt(),
proto::GetChannelMessagesResponse {
done: true,
messages: vec![
proto::ChannelMessage {
id: 8,
body: "y".into(),
timestamp: 998,
sender_id: 5,
nonce: Some(4.into()),
mentions: vec![],
reply_to_message_id: None,
edited_at: None,
},
proto::ChannelMessage {
id: 9,
body: "z".into(),
timestamp: 999,
sender_id: 6,
nonce: Some(5.into()),
mentions: vec![],
reply_to_message_id: None,
edited_at: None,
},
],
},
);
assert_eq!(
channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
);
channel.update(cx, |channel, _| {
assert_eq!(
channel
.messages_in_range(0..2)
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("user-5".into(), "y".into()),
("maxbrunsfeld".into(), "z".into())
]
);
});
}
fn init_test(cx: &mut App) -> Entity<ChannelStore> {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);

View File

@@ -82,34 +82,10 @@ pub enum Plan {
ZedFree,
#[serde(alias = "ZedPro")]
ZedPro,
ZedProV2,
#[serde(alias = "ZedProTrial")]
ZedProTrial,
}
impl Plan {
pub fn as_str(&self) -> &'static str {
match self {
Plan::ZedFree => "zed_free",
Plan::ZedPro => "zed_pro",
Plan::ZedProTrial => "zed_pro_trial",
}
}
pub fn model_requests_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Limited(500),
Plan::ZedProTrial => UsageLimit::Limited(150),
Plan::ZedFree => UsageLimit::Limited(50),
}
}
pub fn edit_predictions_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Unlimited,
Plan::ZedProTrial => UsageLimit::Unlimited,
Plan::ZedFree => UsageLimit::Limited(2_000),
}
}
ZedProTrialV2,
}
impl FromStr for Plan {
@@ -353,6 +329,12 @@ mod tests {
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
assert_eq!(plan, Plan::ZedProTrial);
let plan = serde_json::from_value::<Plan>(json!("zed_pro_v2")).unwrap();
assert_eq!(plan, Plan::ZedProV2);
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial_v2")).unwrap();
assert_eq!(plan, Plan::ZedProTrialV2);
}
#[test]

View File

@@ -26,7 +26,6 @@ use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use std::ops::RangeInclusive;
use std::{
fmt::Write as _,
future::Future,
marker::PhantomData,
ops::{Deref, DerefMut},
@@ -486,9 +485,7 @@ pub struct ChannelsForUser {
pub invited_channels: Vec<Channel>,
pub observed_buffer_versions: Vec<proto::ChannelBufferVersion>,
pub observed_channel_messages: Vec<proto::ChannelMessageId>,
pub latest_buffer_versions: Vec<proto::ChannelBufferVersion>,
pub latest_channel_messages: Vec<proto::ChannelMessageId>,
}
#[derive(Debug)]

View File

@@ -7,7 +7,6 @@ pub mod contacts;
pub mod contributors;
pub mod embeddings;
pub mod extensions;
pub mod messages;
pub mod notifications;
pub mod projects;
pub mod rooms;

View File

@@ -618,25 +618,17 @@ impl Database {
}
drop(rows);
let latest_channel_messages = self.latest_channel_messages(&channel_ids, tx).await?;
let observed_buffer_versions = self
.observed_channel_buffer_changes(&channel_ids_by_buffer_id, user_id, tx)
.await?;
let observed_channel_messages = self
.observed_channel_messages(&channel_ids, user_id, tx)
.await?;
Ok(ChannelsForUser {
channel_memberships,
channels,
invited_channels,
channel_participants,
latest_buffer_versions,
latest_channel_messages,
observed_buffer_versions,
observed_channel_messages,
})
}

View File

@@ -1,725 +0,0 @@
use super::*;
use anyhow::Context as _;
use rpc::Notification;
use sea_orm::{SelectColumns, TryInsertResult};
use time::OffsetDateTime;
use util::ResultExt;
impl Database {
/// Inserts a record representing a user joining the chat for a given channel.
pub async fn join_channel_chat(
&self,
channel_id: ChannelId,
connection_id: ConnectionId,
user_id: UserId,
) -> Result<()> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
channel_chat_participant::ActiveModel {
id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel_id),
user_id: ActiveValue::Set(user_id),
connection_id: ActiveValue::Set(connection_id.id as i32),
connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
}
.insert(&*tx)
.await?;
Ok(())
})
.await
}
/// Removes `channel_chat_participant` records associated with the given connection ID.
pub async fn channel_chat_connection_lost(
&self,
connection_id: ConnectionId,
tx: &DatabaseTransaction,
) -> Result<()> {
channel_chat_participant::Entity::delete_many()
.filter(
Condition::all()
.add(
channel_chat_participant::Column::ConnectionServerId
.eq(connection_id.owner_id),
)
.add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id)),
)
.exec(tx)
.await?;
Ok(())
}
/// Removes `channel_chat_participant` records associated with the given user ID so they
/// will no longer get chat notifications.
pub async fn leave_channel_chat(
&self,
channel_id: ChannelId,
connection_id: ConnectionId,
_user_id: UserId,
) -> Result<()> {
self.transaction(|tx| async move {
channel_chat_participant::Entity::delete_many()
.filter(
Condition::all()
.add(
channel_chat_participant::Column::ConnectionServerId
.eq(connection_id.owner_id),
)
.add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
.add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Retrieves the messages in the specified channel.
///
/// Use `before_message_id` to paginate through the channel's messages.
pub async fn get_channel_messages(
&self,
channel_id: ChannelId,
user_id: UserId,
count: usize,
before_message_id: Option<MessageId>,
) -> Result<Vec<proto::ChannelMessage>> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
let mut condition =
Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
if let Some(before_message_id) = before_message_id {
condition = condition.add(channel_message::Column::Id.lt(before_message_id));
}
let rows = channel_message::Entity::find()
.filter(condition)
.order_by_desc(channel_message::Column::Id)
.limit(count as u64)
.all(&*tx)
.await?;
self.load_channel_messages(rows, &tx).await
})
.await
}
/// Returns the channel messages with the given IDs.
pub async fn get_channel_messages_by_id(
&self,
user_id: UserId,
message_ids: &[MessageId],
) -> Result<Vec<proto::ChannelMessage>> {
self.transaction(|tx| async move {
let rows = channel_message::Entity::find()
.filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
.order_by_desc(channel_message::Column::Id)
.all(&*tx)
.await?;
let mut channels = HashMap::<ChannelId, channel::Model>::default();
for row in &rows {
channels.insert(
row.channel_id,
self.get_channel_internal(row.channel_id, &tx).await?,
);
}
for (_, channel) in channels {
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
}
let messages = self.load_channel_messages(rows, &tx).await?;
Ok(messages)
})
.await
}
async fn load_channel_messages(
&self,
rows: Vec<channel_message::Model>,
tx: &DatabaseTransaction,
) -> Result<Vec<proto::ChannelMessage>> {
let mut messages = rows
.into_iter()
.map(|row| {
let nonce = row.nonce.as_u64_pair();
proto::ChannelMessage {
id: row.id.to_proto(),
sender_id: row.sender_id.to_proto(),
body: row.body,
timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
mentions: vec![],
nonce: Some(proto::Nonce {
upper_half: nonce.0,
lower_half: nonce.1,
}),
reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
edited_at: row
.edited_at
.map(|t| t.assume_utc().unix_timestamp() as u64),
}
})
.collect::<Vec<_>>();
messages.reverse();
let mut mentions = channel_message_mention::Entity::find()
.filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
.order_by_asc(channel_message_mention::Column::MessageId)
.order_by_asc(channel_message_mention::Column::StartOffset)
.stream(tx)
.await?;
let mut message_ix = 0;
while let Some(mention) = mentions.next().await {
let mention = mention?;
let message_id = mention.message_id.to_proto();
while let Some(message) = messages.get_mut(message_ix) {
if message.id < message_id {
message_ix += 1;
} else {
if message.id == message_id {
message.mentions.push(proto::ChatMention {
range: Some(proto::Range {
start: mention.start_offset as u64,
end: mention.end_offset as u64,
}),
user_id: mention.user_id.to_proto(),
});
}
break;
}
}
}
Ok(messages)
}
fn format_mentions_to_entities(
&self,
message_id: MessageId,
body: &str,
mentions: &[proto::ChatMention],
) -> Result<Vec<tables::channel_message_mention::ActiveModel>> {
Ok(mentions
.iter()
.filter_map(|mention| {
let range = mention.range.as_ref()?;
if !body.is_char_boundary(range.start as usize)
|| !body.is_char_boundary(range.end as usize)
{
return None;
}
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>())
}
/// Creates a new channel message.
pub async fn create_channel_message(
&self,
channel_id: ChannelId,
user_id: UserId,
body: &str,
mentions: &[proto::ChatMention],
timestamp: OffsetDateTime,
nonce: u128,
reply_to_message_id: Option<MessageId>,
) -> Result<CreatedChannelMessage> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut participant_connection_ids = HashSet::default();
let mut participant_user_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_user_ids.push(row.user_id);
participant_connection_ids.insert(row.connection());
}
drop(rows);
if !is_participant {
Err(anyhow!("not a chat participant"))?;
}
let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
let result = channel_message::Entity::insert(channel_message::ActiveModel {
channel_id: ActiveValue::Set(channel_id),
sender_id: ActiveValue::Set(user_id),
body: ActiveValue::Set(body.to_string()),
sent_at: ActiveValue::Set(timestamp),
nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
id: ActiveValue::NotSet,
reply_to_message_id: ActiveValue::Set(reply_to_message_id),
edited_at: ActiveValue::NotSet,
})
.on_conflict(
OnConflict::columns([
channel_message::Column::SenderId,
channel_message::Column::Nonce,
])
.do_nothing()
.to_owned(),
)
.do_nothing()
.exec(&*tx)
.await?;
let message_id;
let mut notifications = Vec::new();
match result {
TryInsertResult::Inserted(result) => {
message_id = result.last_insert_id;
let mentioned_user_ids =
mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
let mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
if !mentions.is_empty() {
channel_message_mention::Entity::insert_many(mentions)
.exec(&*tx)
.await?;
}
for mentioned_user in mentioned_user_ids {
notifications.extend(
self.create_notification(
UserId::from_proto(mentioned_user),
rpc::Notification::ChannelMessageMention {
message_id: message_id.to_proto(),
sender_id: user_id.to_proto(),
channel_id: channel_id.to_proto(),
},
false,
&tx,
)
.await?,
);
}
self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
.await?;
}
_ => {
message_id = channel_message::Entity::find()
.filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
.one(&*tx)
.await?
.context("failed to insert message")?
.id;
}
}
Ok(CreatedChannelMessage {
message_id,
participant_connection_ids,
notifications,
})
})
.await
}
pub async fn observe_channel_message(
&self,
channel_id: ChannelId,
user_id: UserId,
message_id: MessageId,
) -> Result<NotificationBatch> {
self.transaction(|tx| async move {
self.observe_channel_message_internal(channel_id, user_id, message_id, &tx)
.await?;
let mut batch = NotificationBatch::default();
batch.extend(
self.mark_notification_as_read(
user_id,
&Notification::ChannelMessageMention {
message_id: message_id.to_proto(),
sender_id: Default::default(),
channel_id: Default::default(),
},
&tx,
)
.await?,
);
Ok(batch)
})
.await
}
async fn observe_channel_message_internal(
&self,
channel_id: ChannelId,
user_id: UserId,
message_id: MessageId,
tx: &DatabaseTransaction,
) -> Result<()> {
observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
user_id: ActiveValue::Set(user_id),
channel_id: ActiveValue::Set(channel_id),
channel_message_id: ActiveValue::Set(message_id),
})
.on_conflict(
OnConflict::columns([
observed_channel_messages::Column::ChannelId,
observed_channel_messages::Column::UserId,
])
.update_column(observed_channel_messages::Column::ChannelMessageId)
.action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
.to_owned(),
)
// TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
.exec_without_returning(tx)
.await?;
Ok(())
}
pub async fn observed_channel_messages(
&self,
channel_ids: &[ChannelId],
user_id: UserId,
tx: &DatabaseTransaction,
) -> Result<Vec<proto::ChannelMessageId>> {
let rows = observed_channel_messages::Entity::find()
.filter(observed_channel_messages::Column::UserId.eq(user_id))
.filter(
observed_channel_messages::Column::ChannelId
.is_in(channel_ids.iter().map(|id| id.0)),
)
.all(tx)
.await?;
Ok(rows
.into_iter()
.map(|message| proto::ChannelMessageId {
channel_id: message.channel_id.to_proto(),
message_id: message.channel_message_id.to_proto(),
})
.collect())
}
pub async fn latest_channel_messages(
&self,
channel_ids: &[ChannelId],
tx: &DatabaseTransaction,
) -> Result<Vec<proto::ChannelMessageId>> {
let mut values = String::new();
for id in channel_ids {
if !values.is_empty() {
values.push_str(", ");
}
write!(&mut values, "({})", id).unwrap();
}
if values.is_empty() {
return Ok(Vec::default());
}
let sql = format!(
r#"
SELECT
*
FROM (
SELECT
*,
row_number() OVER (
PARTITION BY channel_id
ORDER BY id DESC
) as row_number
FROM channel_messages
WHERE
channel_id in ({values})
) AS messages
WHERE
row_number = 1
"#,
);
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let mut last_messages = channel_message::Model::find_by_statement(stmt)
.stream(tx)
.await?;
let mut results = Vec::new();
while let Some(result) = last_messages.next().await {
let message = result?;
results.push(proto::ChannelMessageId {
channel_id: message.channel_id.to_proto(),
message_id: message.id.to_proto(),
});
}
Ok(results)
}
fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option<i32> {
self.notification_kinds_by_id
.iter()
.find(|(_, kind)| **kind == notification_kind)
.map(|kind| kind.0.0)
}
/// Removes the channel message with the given ID.
pub async fn remove_channel_message(
&self,
channel_id: ChannelId,
message_id: MessageId,
user_id: UserId,
) -> Result<(Vec<ConnectionId>, Vec<NotificationId>)> {
self.transaction(|tx| async move {
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut participant_connection_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_connection_ids.push(row.connection());
}
drop(rows);
if !is_participant {
Err(anyhow!("not a chat participant"))?;
}
let result = channel_message::Entity::delete_by_id(message_id)
.filter(channel_message::Column::SenderId.eq(user_id))
.exec(&*tx)
.await?;
if result.rows_affected == 0 {
let channel = self.get_channel_internal(channel_id, &tx).await?;
if self
.check_user_is_channel_admin(&channel, user_id, &tx)
.await
.is_ok()
{
let result = channel_message::Entity::delete_by_id(message_id)
.exec(&*tx)
.await?;
if result.rows_affected == 0 {
Err(anyhow!("no such message"))?;
}
} else {
Err(anyhow!("operation could not be completed"))?;
}
}
let notification_kind_id =
self.get_notification_kind_id_by_name("ChannelMessageMention");
let existing_notifications = notification::Entity::find()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.select_column(notification::Column::Id)
.all(&*tx)
.await?;
let existing_notification_ids = existing_notifications
.into_iter()
.map(|notification| notification.id)
.collect();
// remove all the mention notifications for this message
notification::Entity::delete_many()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.exec(&*tx)
.await?;
Ok((participant_connection_ids, existing_notification_ids))
})
.await
}
/// Updates the channel message with the given ID, body and timestamp(edited_at).
pub async fn update_channel_message(
&self,
channel_id: ChannelId,
message_id: MessageId,
user_id: UserId,
body: &str,
mentions: &[proto::ChatMention],
edited_at: OffsetDateTime,
) -> Result<UpdatedChannelMessage> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut participant_connection_ids = Vec::new();
let mut participant_user_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_user_ids.push(row.user_id);
participant_connection_ids.push(row.connection());
}
drop(rows);
if !is_participant {
Err(anyhow!("not a chat participant"))?;
}
let channel_message = channel_message::Entity::find_by_id(message_id)
.filter(channel_message::Column::SenderId.eq(user_id))
.one(&*tx)
.await?;
let Some(channel_message) = channel_message else {
Err(anyhow!("Channel message not found"))?
};
let edited_at = edited_at.to_offset(time::UtcOffset::UTC);
let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time());
let updated_message = channel_message::ActiveModel {
body: ActiveValue::Set(body.to_string()),
edited_at: ActiveValue::Set(Some(edited_at)),
reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id),
id: ActiveValue::Unchanged(message_id),
channel_id: ActiveValue::Unchanged(channel_id),
sender_id: ActiveValue::Unchanged(user_id),
sent_at: ActiveValue::Unchanged(channel_message.sent_at),
nonce: ActiveValue::Unchanged(channel_message.nonce),
};
let result = channel_message::Entity::update_many()
.set(updated_message)
.filter(channel_message::Column::Id.eq(message_id))
.filter(channel_message::Column::SenderId.eq(user_id))
.exec(&*tx)
.await?;
if result.rows_affected == 0 {
return Err(anyhow!(
"Attempted to edit a message (id: {message_id}) which does not exist anymore."
))?;
}
// we have to fetch the old mentions,
// so we don't send a notification when the message has been edited that you are mentioned in
let old_mentions = channel_message_mention::Entity::find()
.filter(channel_message_mention::Column::MessageId.eq(message_id))
.all(&*tx)
.await?;
// remove all existing mentions
channel_message_mention::Entity::delete_many()
.filter(channel_message_mention::Column::MessageId.eq(message_id))
.exec(&*tx)
.await?;
let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
if !new_mentions.is_empty() {
// insert new mentions
channel_message_mention::Entity::insert_many(new_mentions)
.exec(&*tx)
.await?;
}
let mut update_mention_user_ids = HashSet::default();
let mut new_mention_user_ids =
mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
// Filter out users that were mentioned before
for mention in &old_mentions {
if new_mention_user_ids.contains(&mention.user_id.to_proto()) {
update_mention_user_ids.insert(mention.user_id.to_proto());
}
new_mention_user_ids.remove(&mention.user_id.to_proto());
}
let notification_kind_id =
self.get_notification_kind_id_by_name("ChannelMessageMention");
let existing_notifications = notification::Entity::find()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.all(&*tx)
.await?;
// determine which notifications should be updated or deleted
let mut deleted_notification_ids = HashSet::default();
let mut updated_mention_notifications = Vec::new();
for notification in existing_notifications {
if update_mention_user_ids.contains(&notification.recipient_id.to_proto()) {
if let Some(notification) =
self::notifications::model_to_proto(self, notification).log_err()
{
updated_mention_notifications.push(notification);
}
} else {
deleted_notification_ids.insert(notification.id);
}
}
let mut notifications = Vec::new();
for mentioned_user in new_mention_user_ids {
notifications.extend(
self.create_notification(
UserId::from_proto(mentioned_user),
rpc::Notification::ChannelMessageMention {
message_id: message_id.to_proto(),
sender_id: user_id.to_proto(),
channel_id: channel_id.to_proto(),
},
false,
&tx,
)
.await?,
);
}
Ok(UpdatedChannelMessage {
message_id,
participant_connection_ids,
notifications,
reply_to_message_id: channel_message.reply_to_message_id,
timestamp: channel_message.sent_at,
deleted_mention_notification_ids: deleted_notification_ids
.into_iter()
.collect::<Vec<_>>(),
updated_mention_notifications,
})
})
.await
}
}

View File

@@ -1193,7 +1193,6 @@ impl Database {
self.transaction(|tx| async move {
self.room_connection_lost(connection, &tx).await?;
self.channel_buffer_connection_lost(connection, &tx).await?;
self.channel_chat_connection_lost(connection, &tx).await?;
Ok(())
})
.await

View File

@@ -7,7 +7,6 @@ mod db_tests;
mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod user_tests;
use crate::migrations::run_database_migrations;
@@ -21,7 +20,7 @@ use sqlx::migrate::MigrateDatabase;
use std::{
sync::{
Arc,
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
atomic::{AtomicI32, Ordering::SeqCst},
},
time::Duration,
};
@@ -224,11 +223,3 @@ async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
.unwrap()
.user_id
}
static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
fn new_test_connection(server: ServerId) -> ConnectionId {
ConnectionId {
id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
owner_id: server.0 as u32,
}
}

View File

@@ -1,7 +1,7 @@
use crate::{
db::{
Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
tests::{assert_channel_tree_matches, channel_tree, new_test_connection, new_test_user},
tests::{assert_channel_tree_matches, channel_tree, new_test_user},
},
test_both_dbs,
};
@@ -949,41 +949,6 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
)
}
test_both_dbs!(
test_guest_access,
test_guest_access_postgres,
test_guest_access_sqlite
);
async fn test_guest_access(db: &Arc<Database>) {
let server = db.create_server("test").await.unwrap();
let admin = new_test_user(db, "admin@example.com").await;
let guest = new_test_user(db, "guest@example.com").await;
let guest_connection = new_test_connection(server);
let zed_channel = db.create_root_channel("zed", admin).await.unwrap();
db.set_channel_visibility(zed_channel, crate::db::ChannelVisibility::Public, admin)
.await
.unwrap();
assert!(
db.join_channel_chat(zed_channel, guest_connection, guest)
.await
.is_err()
);
db.join_channel(zed_channel, guest, guest_connection)
.await
.unwrap();
assert!(
db.join_channel_chat(zed_channel, guest_connection, guest)
.await
.is_ok()
)
}
#[track_caller]
fn assert_channel_tree(actual: Vec<Channel>, expected: &[(ChannelId, &[ChannelId])]) {
let actual = actual

View File

@@ -1,421 +0,0 @@
use super::new_test_user;
use crate::{
db::{ChannelRole, Database, MessageId},
test_both_dbs,
};
use channel::mentions_to_proto;
use std::sync::Arc;
use time::OffsetDateTime;
test_both_dbs!(
test_channel_message_retrieval,
test_channel_message_retrieval_postgres,
test_channel_message_retrieval_sqlite
);
async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = new_test_user(db, "user@example.com").await;
let channel = db.create_channel("channel", None, user).await.unwrap().0;
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel.id, rpc::ConnectionId { owner_id, id: 0 }, user)
.await
.unwrap();
let mut all_messages = Vec::new();
for i in 0..10 {
all_messages.push(
db.create_channel_message(
channel.id,
user,
&i.to_string(),
&[],
OffsetDateTime::now_utc(),
i,
None,
)
.await
.unwrap()
.message_id
.to_proto(),
);
}
let messages = db
.get_channel_messages(channel.id, user, 3, None)
.await
.unwrap()
.into_iter()
.map(|message| message.id)
.collect::<Vec<_>>();
assert_eq!(messages, &all_messages[7..10]);
let messages = db
.get_channel_messages(
channel.id,
user,
4,
Some(MessageId::from_proto(all_messages[6])),
)
.await
.unwrap()
.into_iter()
.map(|message| message.id)
.collect::<Vec<_>>();
assert_eq!(messages, &all_messages[2..6]);
}
test_both_dbs!(
test_channel_message_nonces,
test_channel_message_nonces_postgres,
test_channel_message_nonces_sqlite
);
async fn test_channel_message_nonces(db: &Arc<Database>) {
let user_a = new_test_user(db, "user_a@example.com").await;
let user_b = new_test_user(db, "user_b@example.com").await;
let user_c = new_test_user(db, "user_c@example.com").await;
let channel = db.create_root_channel("channel", user_a).await.unwrap();
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await
.unwrap();
db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_c, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a)
.await
.unwrap();
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b)
.await
.unwrap();
// As user A, create messages that reuse the same nonces. The requests
// succeed, but return the same ids.
let id1 = db
.create_channel_message(
channel,
user_a,
"hi @user_b",
&mentions_to_proto(&[(3..10, user_b.to_proto())]),
OffsetDateTime::now_utc(),
100,
None,
)
.await
.unwrap()
.message_id;
let id2 = db
.create_channel_message(
channel,
user_a,
"hello, fellow users",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
None,
)
.await
.unwrap()
.message_id;
let id3 = db
.create_channel_message(
channel,
user_a,
"bye @user_c (same nonce as first message)",
&mentions_to_proto(&[(4..11, user_c.to_proto())]),
OffsetDateTime::now_utc(),
100,
None,
)
.await
.unwrap()
.message_id;
let id4 = db
.create_channel_message(
channel,
user_a,
"omg (same nonce as second message)",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
200,
None,
)
.await
.unwrap()
.message_id;
// As a different user, reuse one of the same nonces. This request succeeds
// and returns a different id.
let id5 = db
.create_channel_message(
channel,
user_b,
"omg @user_a (same nonce as user_a's first message)",
&mentions_to_proto(&[(4..11, user_a.to_proto())]),
OffsetDateTime::now_utc(),
100,
None,
)
.await
.unwrap()
.message_id;
assert_ne!(id1, id2);
assert_eq!(id1, id3);
assert_eq!(id2, id4);
assert_ne!(id5, id1);
let messages = db
.get_channel_messages(channel, user_a, 5, None)
.await
.unwrap()
.into_iter()
.map(|m| (m.id, m.body, m.mentions))
.collect::<Vec<_>>();
assert_eq!(
messages,
&[
(
id1.to_proto(),
"hi @user_b".into(),
mentions_to_proto(&[(3..10, user_b.to_proto())]),
),
(
id2.to_proto(),
"hello, fellow users".into(),
mentions_to_proto(&[])
),
(
id5.to_proto(),
"omg @user_a (same nonce as user_a's first message)".into(),
mentions_to_proto(&[(4..11, user_a.to_proto())]),
),
]
);
}
test_both_dbs!(
test_unseen_channel_messages,
test_unseen_channel_messages_postgres,
test_unseen_channel_messages_sqlite
);
async fn test_unseen_channel_messages(db: &Arc<Database>) {
let user = new_test_user(db, "user_a@example.com").await;
let observer = new_test_user(db, "user_b@example.com").await;
let channel_1 = db.create_root_channel("channel", user).await.unwrap();
let channel_2 = db.create_root_channel("channel-2", user).await.unwrap();
db.invite_channel_member(channel_1, observer, user, ChannelRole::Member)
.await
.unwrap();
db.invite_channel_member(channel_2, observer, user, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel_1, observer, true)
.await
.unwrap();
db.respond_to_channel_invite(channel_2, observer, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user_connection_id = rpc::ConnectionId { owner_id, id: 0 };
db.join_channel_chat(channel_1, user_connection_id, user)
.await
.unwrap();
let _ = db
.create_channel_message(
channel_1,
user,
"1_1",
&[],
OffsetDateTime::now_utc(),
1,
None,
)
.await
.unwrap();
let _ = db
.create_channel_message(
channel_1,
user,
"1_2",
&[],
OffsetDateTime::now_utc(),
2,
None,
)
.await
.unwrap();
let third_message = db
.create_channel_message(
channel_1,
user,
"1_3",
&[],
OffsetDateTime::now_utc(),
3,
None,
)
.await
.unwrap()
.message_id;
db.join_channel_chat(channel_2, user_connection_id, user)
.await
.unwrap();
let fourth_message = db
.create_channel_message(
channel_2,
user,
"2_1",
&[],
OffsetDateTime::now_utc(),
4,
None,
)
.await
.unwrap()
.message_id;
// Check that observer has new messages
let latest_messages = db
.transaction(|tx| async move {
db.latest_channel_messages(&[channel_1, channel_2], &tx)
.await
})
.await
.unwrap();
assert_eq!(
latest_messages,
[
rpc::proto::ChannelMessageId {
channel_id: channel_1.to_proto(),
message_id: third_message.to_proto(),
},
rpc::proto::ChannelMessageId {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
},
]
);
}
test_both_dbs!(
test_channel_message_mentions,
test_channel_message_mentions_postgres,
test_channel_message_mentions_sqlite
);
async fn test_channel_message_mentions(db: &Arc<Database>) {
let user_a = new_test_user(db, "user_a@example.com").await;
let user_b = new_test_user(db, "user_b@example.com").await;
let user_c = new_test_user(db, "user_c@example.com").await;
let channel = db
.create_channel("channel", None, user_a)
.await
.unwrap()
.0
.id;
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await
.unwrap();
db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let connection_id = rpc::ConnectionId { owner_id, id: 0 };
db.join_channel_chat(channel, connection_id, user_a)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"hi @user_b and @user_c",
&mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
OffsetDateTime::now_utc(),
1,
None,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"bye @user_c",
&mentions_to_proto(&[(4..11, user_c.to_proto())]),
OffsetDateTime::now_utc(),
2,
None,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"umm",
&mentions_to_proto(&[]),
OffsetDateTime::now_utc(),
3,
None,
)
.await
.unwrap();
db.create_channel_message(
channel,
user_a,
"@user_b, stop.",
&mentions_to_proto(&[(0..7, user_b.to_proto())]),
OffsetDateTime::now_utc(),
4,
None,
)
.await
.unwrap();
let messages = db
.get_channel_messages(channel, user_b, 5, None)
.await
.unwrap()
.into_iter()
.map(|m| (m.body, m.mentions))
.collect::<Vec<_>>();
assert_eq!(
&messages,
&[
(
"hi @user_b and @user_c".into(),
mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
),
(
"bye @user_c".into(),
mentions_to_proto(&[(4..11, user_c.to_proto())]),
),
("umm".into(), mentions_to_proto(&[]),),
(
"@user_b, stop.".into(),
mentions_to_proto(&[(0..7, user_b.to_proto())]),
),
]
);
}

View File

@@ -4,10 +4,9 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::{
AppState, Error, Result, auth,
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
},
executor::Executor,
};
@@ -66,7 +65,6 @@ use std::{
},
time::{Duration, Instant},
};
use time::OffsetDateTime;
use tokio::sync::{Semaphore, watch};
use tower::ServiceBuilder;
use tracing::{
@@ -80,8 +78,6 @@ pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_CONCURRENT_CONNECTIONS: usize = 512;
@@ -3597,235 +3593,36 @@ fn send_notifications(
/// Send a message to the channel
async fn send_channel_message(
request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>,
session: MessageContext,
_request: proto::SendChannelMessage,
_response: Response<proto::SendChannelMessage>,
_session: MessageContext,
) -> Result<()> {
// Validate the message body.
let body = request.body.trim().to_string();
if body.len() > MAX_MESSAGE_LEN {
return Err(anyhow!("message is too long"))?;
}
if body.is_empty() {
return Err(anyhow!("message can't be blank"))?;
}
// TODO: adjust mentions if body is trimmed
let timestamp = OffsetDateTime::now_utc();
let nonce = request.nonce.context("nonce can't be blank")?;
let channel_id = ChannelId::from_proto(request.channel_id);
let CreatedChannelMessage {
message_id,
participant_connection_ids,
notifications,
} = session
.db()
.await
.create_channel_message(
channel_id,
session.user_id(),
&body,
&request.mentions,
timestamp,
nonce.clone().into(),
request.reply_to_message_id.map(MessageId::from_proto),
)
.await?;
let message = proto::ChannelMessage {
sender_id: session.user_id().to_proto(),
id: message_id.to_proto(),
body,
mentions: request.mentions,
timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce),
reply_to_message_id: request.reply_to_message_id,
edited_at: None,
};
broadcast(
Some(session.connection_id),
participant_connection_ids.clone(),
|connection| {
session.peer.send(
connection,
proto::ChannelMessageSent {
channel_id: channel_id.to_proto(),
message: Some(message.clone()),
},
)
},
);
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})?;
let pool = &*session.connection_pool().await;
let non_participants =
pool.channel_connection_ids(channel_id)
.filter_map(|(connection_id, _)| {
if participant_connection_ids.contains(&connection_id) {
None
} else {
Some(connection_id)
}
});
broadcast(None, non_participants, |peer_id| {
session.peer.send(
peer_id,
proto::UpdateChannels {
latest_channel_message_ids: vec![proto::ChannelMessageId {
channel_id: channel_id.to_proto(),
message_id: message_id.to_proto(),
}],
..Default::default()
},
)
});
send_notifications(pool, &session.peer, notifications);
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Delete a channel message
async fn remove_channel_message(
request: proto::RemoveChannelMessage,
response: Response<proto::RemoveChannelMessage>,
session: MessageContext,
_request: proto::RemoveChannelMessage,
_response: Response<proto::RemoveChannelMessage>,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
let (connection_ids, existing_notification_ids) = session
.db()
.await
.remove_channel_message(channel_id, message_id, session.user_id())
.await?;
broadcast(
Some(session.connection_id),
connection_ids,
move |connection| {
session.peer.send(connection, request.clone())?;
for notification_id in &existing_notification_ids {
session.peer.send(
connection,
proto::DeleteNotification {
notification_id: (*notification_id).to_proto(),
},
)?;
}
Ok(())
},
);
response.send(proto::Ack {})?;
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
async fn update_channel_message(
request: proto::UpdateChannelMessage,
response: Response<proto::UpdateChannelMessage>,
session: MessageContext,
_request: proto::UpdateChannelMessage,
_response: Response<proto::UpdateChannelMessage>,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
let updated_at = OffsetDateTime::now_utc();
let UpdatedChannelMessage {
message_id,
participant_connection_ids,
notifications,
reply_to_message_id,
timestamp,
deleted_mention_notification_ids,
updated_mention_notifications,
} = session
.db()
.await
.update_channel_message(
channel_id,
message_id,
session.user_id(),
request.body.as_str(),
&request.mentions,
updated_at,
)
.await?;
let nonce = request.nonce.clone().context("nonce can't be blank")?;
let message = proto::ChannelMessage {
sender_id: session.user_id().to_proto(),
id: message_id.to_proto(),
body: request.body.clone(),
mentions: request.mentions.clone(),
timestamp: timestamp.assume_utc().unix_timestamp() as u64,
nonce: Some(nonce),
reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
edited_at: Some(updated_at.unix_timestamp() as u64),
};
response.send(proto::Ack {})?;
let pool = &*session.connection_pool().await;
broadcast(
Some(session.connection_id),
participant_connection_ids,
|connection| {
session.peer.send(
connection,
proto::ChannelMessageUpdate {
channel_id: channel_id.to_proto(),
message: Some(message.clone()),
},
)?;
for notification_id in &deleted_mention_notification_ids {
session.peer.send(
connection,
proto::DeleteNotification {
notification_id: (*notification_id).to_proto(),
},
)?;
}
for notification in &updated_mention_notifications {
session.peer.send(
connection,
proto::UpdateNotification {
notification: Some(notification.clone()),
},
)?;
}
Ok(())
},
);
send_notifications(pool, &session.peer, notifications);
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Mark a channel message as read
async fn acknowledge_channel_message(
request: proto::AckChannelMessage,
session: MessageContext,
_request: proto::AckChannelMessage,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
let notifications = session
.db()
.await
.observe_channel_message(channel_id, session.user_id(), message_id)
.await?;
send_notifications(
&*session.connection_pool().await,
&session.peer,
notifications,
);
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Mark a buffer version as synced
@@ -3878,84 +3675,37 @@ async fn get_supermaven_api_key(
/// Start receiving chat updates for a channel
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
session: MessageContext,
_request: proto::JoinChannelChat,
_response: Response<proto::JoinChannelChat>,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let db = session.db().await;
db.join_channel_chat(channel_id, session.connection_id, session.user_id())
.await?;
let messages = db
.get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
.await?;
response.send(proto::JoinChannelChatResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Stop receiving chat updates for a channel
async fn leave_channel_chat(
request: proto::LeaveChannelChat,
session: MessageContext,
_request: proto::LeaveChannelChat,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
session
.db()
.await
.leave_channel_chat(channel_id, session.connection_id, session.user_id())
.await?;
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Retrieve the chat history for a channel
async fn get_channel_messages(
request: proto::GetChannelMessages,
response: Response<proto::GetChannelMessages>,
session: MessageContext,
_request: proto::GetChannelMessages,
_response: Response<proto::GetChannelMessages>,
_session: MessageContext,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let messages = session
.db()
.await
.get_channel_messages(
channel_id,
session.user_id(),
MESSAGE_COUNT_PER_PAGE,
Some(MessageId::from_proto(request.before_message_id)),
)
.await?;
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Retrieve specific chat messages
async fn get_channel_messages_by_id(
request: proto::GetChannelMessagesById,
response: Response<proto::GetChannelMessagesById>,
session: MessageContext,
_request: proto::GetChannelMessagesById,
_response: Response<proto::GetChannelMessagesById>,
_session: MessageContext,
) -> Result<()> {
let message_ids = request
.message_ids
.iter()
.map(|id| MessageId::from_proto(*id))
.collect::<Vec<_>>();
let messages = session
.db()
.await
.get_channel_messages_by_id(session.user_id(), &message_ids)
.await?;
response.send(proto::GetChannelMessagesResponse {
done: messages.len() < MESSAGE_COUNT_PER_PAGE,
messages,
})?;
Ok(())
Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
/// Retrieve the current users notifications
@@ -4095,7 +3845,6 @@ fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserCh
})
.collect(),
observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
observed_channel_message_id: channels.observed_channel_messages.clone(),
}
}
@@ -4107,7 +3856,6 @@ fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
}
update.latest_channel_buffer_versions = channels.latest_buffer_versions;
update.latest_channel_message_ids = channels.latest_channel_messages;
for (channel_id, participants) in channels.channel_participants {
update

View File

@@ -6,7 +6,6 @@ use gpui::{Entity, TestAppContext};
mod channel_buffer_tests;
mod channel_guest_tests;
mod channel_message_tests;
mod channel_tests;
mod editor_tests;
mod following_tests;

View File

@@ -1,725 +0,0 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use channel::{ChannelChat, ChannelMessageId, MessageParams};
use collab_ui::chat_panel::ChatPanel;
use gpui::{BackgroundExecutor, Entity, TestAppContext};
use rpc::Notification;
use workspace::dock::Panel;
#[gpui::test]
async fn test_basic_channel_messages(
executor: BackgroundExecutor,
mut cx_a: &mut TestAppContext,
mut cx_b: &mut TestAppContext,
mut cx_c: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let message_id = channel_chat_a
.update(cx_a, |c, cx| {
c.send_message(
MessageParams {
text: "hi @user_c!".into(),
mentions: vec![(3..10, client_c.id())],
reply_to_message_id: None,
},
cx,
)
.unwrap()
})
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
channel_chat_b
.update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
let channel_chat_c = client_c
.channel_store()
.update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
for (chat, cx) in [
(&channel_chat_a, &mut cx_a),
(&channel_chat_b, &mut cx_b),
(&channel_chat_c, &mut cx_c),
] {
chat.update(*cx, |c, _| {
assert_eq!(
c.messages()
.iter()
.map(|m| (m.body.as_str(), m.mentions.as_slice()))
.collect::<Vec<_>>(),
vec![
("hi @user_c!", [(3..10, client_c.id())].as_slice()),
("two", &[]),
("three", &[])
],
"results for user {}",
c.client().id(),
);
});
}
client_c.notification_store().update(cx_c, |store, _| {
assert_eq!(store.notification_count(), 2);
assert_eq!(store.unread_notification_count(), 1);
assert_eq!(
store.notification_at(0).unwrap().notification,
Notification::ChannelMessageMention {
message_id,
sender_id: client_a.id(),
channel_id: channel_id.0,
}
);
assert_eq!(
store.notification_at(1).unwrap().notification,
Notification::ChannelInvitation {
channel_id: channel_id.0,
channel_name: "the-channel".to_string(),
inviter_id: client_a.id()
}
);
});
}
#[gpui::test]
async fn test_rejoin_channel_chat(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
channel_chat_b
.update(cx_b, |c, cx| c.send_message("two".into(), cx).unwrap())
.await
.unwrap();
server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap());
// While client A is disconnected, clients A and B both send new messages.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap_err();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap())
.await
.unwrap_err();
channel_chat_b
.update(cx_b, |c, cx| c.send_message("five".into(), cx).unwrap())
.await
.unwrap();
channel_chat_b
.update(cx_b, |c, cx| c.send_message("six".into(), cx).unwrap())
.await
.unwrap();
// Client A reconnects.
server.allow_connections();
executor.advance_clock(RECONNECT_TIMEOUT);
// Client A fetches the messages that were sent while they were disconnected
// and resends their own messages which failed to send.
let expected_messages = &["one", "two", "five", "six", "three", "four"];
assert_messages(&channel_chat_a, expected_messages, cx_a);
assert_messages(&channel_chat_b, expected_messages, cx_b);
}
#[gpui::test]
async fn test_remove_channel_message(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
cx_c: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
// Client A sends some messages.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
let msg_id_2 = channel_chat_a
.update(cx_a, |c, cx| {
c.send_message(
MessageParams {
text: "two @user_b".to_string(),
mentions: vec![(4..12, client_b.id())],
reply_to_message_id: None,
},
cx,
)
.unwrap()
})
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap();
// Clients A and B see all of the messages.
executor.run_until_parked();
let expected_messages = &["one", "two @user_b", "three"];
assert_messages(&channel_chat_a, expected_messages, cx_a);
assert_messages(&channel_chat_b, expected_messages, cx_b);
// Ensure that client B received a notification for the mention.
client_b.notification_store().read_with(cx_b, |store, _| {
assert_eq!(store.notification_count(), 2);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ChannelMessageMention {
message_id: msg_id_2,
sender_id: client_a.id(),
channel_id: channel_id.0,
}
);
});
// Client A deletes one of their messages.
channel_chat_a
.update(cx_a, |c, cx| {
let ChannelMessageId::Saved(id) = c.message(1).id else {
panic!("message not saved")
};
c.remove_message(id, cx)
})
.await
.unwrap();
// Client B sees that the message is gone.
executor.run_until_parked();
let expected_messages = &["one", "three"];
assert_messages(&channel_chat_a, expected_messages, cx_a);
assert_messages(&channel_chat_b, expected_messages, cx_b);
// Client C joins the channel chat, and does not see the deleted message.
let channel_chat_c = client_c
.channel_store()
.update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
assert_messages(&channel_chat_c, expected_messages, cx_c);
// Ensure we remove the notifications when the message is removed
client_b.notification_store().read_with(cx_b, |store, _| {
// First notification is the channel invitation, second would be the mention
// notification, which should now be removed.
assert_eq!(store.notification_count(), 1);
});
}
#[track_caller]
fn assert_messages(chat: &Entity<ChannelChat>, messages: &[&str], cx: &mut TestAppContext) {
assert_eq!(
chat.read_with(cx, |chat, _| {
chat.messages()
.iter()
.map(|m| m.body.clone())
.collect::<Vec<_>>()
}),
messages
);
}
#[gpui::test]
async fn test_channel_message_changes(
executor: BackgroundExecutor,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
// Client A sends a message, client B should see that there is a new message.
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
let b_has_messages = cx_b.update(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
});
assert!(b_has_messages);
// Opening the chat should clear the changed flag.
cx_b.update(|cx| {
collab_ui::init(&client_b.app_state, cx);
});
let project_b = client_b.build_empty_local_project(cx_b);
let (workspace_b, cx_b) = client_b.build_workspace(&project_b, cx_b);
let chat_panel_b = workspace_b.update_in(cx_b, ChatPanel::new);
chat_panel_b
.update_in(cx_b, |chat_panel, window, cx| {
chat_panel.set_active(true, window, cx);
chat_panel.select_channel(channel_id, None, cx)
})
.await
.unwrap();
executor.run_until_parked();
let b_has_messages = cx_b.update(|_, cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
});
assert!(!b_has_messages);
// Sending a message while the chat is open should not change the flag.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
let b_has_messages = cx_b.update(|_, cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
});
assert!(!b_has_messages);
// Sending a message while the chat is closed should change the flag.
chat_panel_b.update_in(cx_b, |chat_panel, window, cx| {
chat_panel.set_active(false, window, cx);
});
// Sending a message while the chat is open should not change the flag.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
let b_has_messages = cx_b.update(|_, cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
});
assert!(b_has_messages);
// Closing the chat should re-enable change tracking
cx_b.update(|_, _| drop(chat_panel_b));
channel_chat_a
.update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap())
.await
.unwrap();
executor.run_until_parked();
let b_has_messages = cx_b.update(|_, cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
});
assert!(b_has_messages);
}
#[gpui::test]
async fn test_chat_replies(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
let mut server = TestServer::start(cx_a.executor()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
// Client A sends a message, client B should see that there is a new message.
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let msg_id = channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
cx_a.run_until_parked();
let reply_id = channel_chat_b
.update(cx_b, |c, cx| {
c.send_message(
MessageParams {
text: "reply".into(),
reply_to_message_id: Some(msg_id),
mentions: Vec::new(),
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
channel_chat_a.update(cx_a, |channel_chat, _| {
assert_eq!(
channel_chat
.find_loaded_message(reply_id)
.unwrap()
.reply_to_message_id,
Some(msg_id),
)
});
}
#[gpui::test]
async fn test_chat_editing(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
let mut server = TestServer::start(cx_a.executor()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
// Client A sends a message, client B should see that there is a new message.
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let channel_chat_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
let msg_id = channel_chat_a
.update(cx_a, |c, cx| {
c.send_message(
MessageParams {
text: "Initial message".into(),
reply_to_message_id: None,
mentions: Vec::new(),
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
channel_chat_a
.update(cx_a, |c, cx| {
c.update_message(
msg_id,
MessageParams {
text: "Updated body".into(),
reply_to_message_id: None,
mentions: Vec::new(),
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
channel_chat_a.update(cx_a, |channel_chat, _| {
let update_message = channel_chat.find_loaded_message(msg_id).unwrap();
assert_eq!(update_message.body, "Updated body");
assert_eq!(update_message.mentions, Vec::new());
});
channel_chat_b.update(cx_b, |channel_chat, _| {
let update_message = channel_chat.find_loaded_message(msg_id).unwrap();
assert_eq!(update_message.body, "Updated body");
assert_eq!(update_message.mentions, Vec::new());
});
// test mentions are updated correctly
client_b.notification_store().read_with(cx_b, |store, _| {
assert_eq!(store.notification_count(), 1);
let entry = store.notification_at(0).unwrap();
assert!(matches!(
entry.notification,
Notification::ChannelInvitation { .. }
),);
});
channel_chat_a
.update(cx_a, |c, cx| {
c.update_message(
msg_id,
MessageParams {
text: "Updated body including a mention for @user_b".into(),
reply_to_message_id: None,
mentions: vec![(37..45, client_b.id())],
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
channel_chat_a.update(cx_a, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body including a mention for @user_b",
)
});
channel_chat_b.update(cx_b, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body including a mention for @user_b",
)
});
client_b.notification_store().read_with(cx_b, |store, _| {
assert_eq!(store.notification_count(), 2);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ChannelMessageMention {
message_id: msg_id,
sender_id: client_a.id(),
channel_id: channel_id.0,
}
);
});
// Test update message and keep the mention and check that the body is updated correctly
channel_chat_a
.update(cx_a, |c, cx| {
c.update_message(
msg_id,
MessageParams {
text: "Updated body v2 including a mention for @user_b".into(),
reply_to_message_id: None,
mentions: vec![(37..45, client_b.id())],
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
channel_chat_a.update(cx_a, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body v2 including a mention for @user_b",
)
});
channel_chat_b.update(cx_b, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body v2 including a mention for @user_b",
)
});
client_b.notification_store().read_with(cx_b, |store, _| {
let message = store.channel_message_for_id(msg_id);
assert!(message.is_some());
assert_eq!(
message.unwrap().body,
"Updated body v2 including a mention for @user_b"
);
assert_eq!(store.notification_count(), 2);
let entry = store.notification_at(0).unwrap();
assert_eq!(
entry.notification,
Notification::ChannelMessageMention {
message_id: msg_id,
sender_id: client_a.id(),
channel_id: channel_id.0,
}
);
});
// If we remove a mention from a message the corresponding mention notification
// should also be removed.
channel_chat_a
.update(cx_a, |c, cx| {
c.update_message(
msg_id,
MessageParams {
text: "Updated body without a mention".into(),
reply_to_message_id: None,
mentions: vec![],
},
cx,
)
.unwrap()
})
.await
.unwrap();
cx_a.run_until_parked();
cx_b.run_until_parked();
channel_chat_a.update(cx_a, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body without a mention",
)
});
channel_chat_b.update(cx_b, |channel_chat, _| {
assert_eq!(
channel_chat.find_loaded_message(msg_id).unwrap().body,
"Updated body without a mention",
)
});
client_b.notification_store().read_with(cx_b, |store, _| {
// First notification is the channel invitation, second would be the mention
// notification, which should now be removed.
assert_eq!(store.notification_count(), 1);
});
}

View File

@@ -37,18 +37,15 @@ client.workspace = true
collections.workspace = true
db.workspace = true
editor.workspace = true
emojis.workspace = true
futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
menu.workspace = true
notifications.workspace = true
picker.workspace = true
project.workspace = true
release_channel.workspace = true
rich_text.workspace = true
rpc.workspace = true
schemars.workspace = true
serde.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -1,548 +0,0 @@
use anyhow::{Context as _, Result};
use channel::{ChannelChat, ChannelStore, MessageParams};
use client::{UserId, UserStore};
use collections::HashSet;
use editor::{AnchorRangeExt, CompletionProvider, Editor, EditorElement, EditorStyle, ExcerptId};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
AsyncApp, AsyncWindowContext, Context, Entity, Focusable, FontStyle, FontWeight,
HighlightStyle, IntoElement, Render, Task, TextStyle, WeakEntity, Window,
};
use language::{
Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry, ToOffset,
language_settings::SoftWrap,
};
use project::{
Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource, search::SearchQuery,
};
use settings::Settings;
use std::{
ops::Range,
rc::Rc,
sync::{Arc, LazyLock},
time::Duration,
};
use theme::ThemeSettings;
use ui::{TextSize, prelude::*};
use crate::panel_settings::MessageEditorSettings;
const MENTIONS_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(50);
static MENTIONS_SEARCH: LazyLock<SearchQuery> = LazyLock::new(|| {
SearchQuery::regex(
"@[-_\\w]+",
false,
false,
false,
false,
Default::default(),
Default::default(),
false,
None,
)
.unwrap()
});
pub struct MessageEditor {
pub editor: Entity<Editor>,
user_store: Entity<UserStore>,
channel_chat: Option<Entity<ChannelChat>>,
mentions: Vec<UserId>,
mentions_task: Option<Task<()>>,
reply_to_message_id: Option<u64>,
edit_message_id: Option<u64>,
}
struct MessageEditorCompletionProvider(WeakEntity<MessageEditor>);
impl CompletionProvider for MessageEditorCompletionProvider {
fn completions(
&self,
_excerpt_id: ExcerptId,
buffer: &Entity<Buffer>,
buffer_position: language::Anchor,
_: editor::CompletionContext,
_window: &mut Window,
cx: &mut Context<Editor>,
) -> Task<Result<Vec<CompletionResponse>>> {
let Some(handle) = self.0.upgrade() else {
return Task::ready(Ok(Vec::new()));
};
handle.update(cx, |message_editor, cx| {
message_editor.completions(buffer, buffer_position, cx)
})
}
fn is_completion_trigger(
&self,
_buffer: &Entity<Buffer>,
_position: language::Anchor,
text: &str,
_trigger_in_words: bool,
_menu_is_open: bool,
_cx: &mut Context<Editor>,
) -> bool {
text == "@"
}
}
impl MessageEditor {
pub fn new(
language_registry: Arc<LanguageRegistry>,
user_store: Entity<UserStore>,
channel_chat: Option<Entity<ChannelChat>>,
editor: Entity<Editor>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let this = cx.entity().downgrade();
editor.update(cx, |editor, cx| {
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor.set_offset_content(false, cx);
editor.set_use_autoclose(false);
editor.set_show_gutter(false, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_indent_guides(false, cx);
editor.set_completion_provider(Some(Rc::new(MessageEditorCompletionProvider(this))));
editor.set_auto_replace_emoji_shortcode(
MessageEditorSettings::get_global(cx)
.auto_replace_emoji_shortcode
.unwrap_or_default(),
);
});
let buffer = editor
.read(cx)
.buffer()
.read(cx)
.as_singleton()
.expect("message editor must be singleton");
cx.subscribe_in(&buffer, window, Self::on_buffer_event)
.detach();
cx.observe_global::<settings::SettingsStore>(|this, cx| {
this.editor.update(cx, |editor, cx| {
editor.set_auto_replace_emoji_shortcode(
MessageEditorSettings::get_global(cx)
.auto_replace_emoji_shortcode
.unwrap_or_default(),
)
})
})
.detach();
let markdown = language_registry.language_for_name("Markdown");
cx.spawn_in(window, async move |_, cx| {
let markdown = markdown.await.context("failed to load Markdown language")?;
buffer.update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx))
})
.detach_and_log_err(cx);
Self {
editor,
user_store,
channel_chat,
mentions: Vec::new(),
mentions_task: None,
reply_to_message_id: None,
edit_message_id: None,
}
}
pub fn reply_to_message_id(&self) -> Option<u64> {
self.reply_to_message_id
}
pub fn set_reply_to_message_id(&mut self, reply_to_message_id: u64) {
self.reply_to_message_id = Some(reply_to_message_id);
}
pub fn clear_reply_to_message_id(&mut self) {
self.reply_to_message_id = None;
}
pub fn edit_message_id(&self) -> Option<u64> {
self.edit_message_id
}
pub fn set_edit_message_id(&mut self, edit_message_id: u64) {
self.edit_message_id = Some(edit_message_id);
}
pub fn clear_edit_message_id(&mut self) {
self.edit_message_id = None;
}
pub fn set_channel_chat(&mut self, chat: Entity<ChannelChat>, cx: &mut Context<Self>) {
let channel_id = chat.read(cx).channel_id;
self.channel_chat = Some(chat);
let channel_name = ChannelStore::global(cx)
.read(cx)
.channel_for_id(channel_id)
.map(|channel| channel.name.clone());
self.editor.update(cx, |editor, cx| {
if let Some(channel_name) = channel_name {
editor.set_placeholder_text(format!("Message #{channel_name}"), cx);
} else {
editor.set_placeholder_text("Message Channel", cx);
}
});
}
pub fn take_message(&mut self, window: &mut Window, cx: &mut Context<Self>) -> MessageParams {
self.editor.update(cx, |editor, cx| {
let highlights = editor.text_highlights::<Self>(cx);
let text = editor.text(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let mentions = if let Some((_, ranges)) = highlights {
ranges
.iter()
.map(|range| range.to_offset(&snapshot))
.zip(self.mentions.iter().copied())
.collect()
} else {
Vec::new()
};
editor.clear(window, cx);
self.mentions.clear();
let reply_to_message_id = std::mem::take(&mut self.reply_to_message_id);
MessageParams {
text,
mentions,
reply_to_message_id,
}
})
}
fn on_buffer_event(
&mut self,
buffer: &Entity<Buffer>,
event: &language::BufferEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let language::BufferEvent::Reparsed | language::BufferEvent::Edited = event {
let buffer = buffer.read(cx).snapshot();
self.mentions_task = Some(cx.spawn_in(window, async move |this, cx| {
cx.background_executor()
.timer(MENTIONS_DEBOUNCE_INTERVAL)
.await;
Self::find_mentions(this, buffer, cx).await;
}));
}
}
fn completions(
&mut self,
buffer: &Entity<Buffer>,
end_anchor: Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Vec<CompletionResponse>>> {
if let Some((start_anchor, query, candidates)) =
self.collect_mention_candidates(buffer, end_anchor, cx)
&& !candidates.is_empty()
{
return cx.spawn(async move |_, cx| {
let completion_response = Self::completions_for_candidates(
cx,
query.as_str(),
&candidates,
start_anchor..end_anchor,
Self::completion_for_mention,
)
.await;
Ok(vec![completion_response])
});
}
if let Some((start_anchor, query, candidates)) =
self.collect_emoji_candidates(buffer, end_anchor, cx)
&& !candidates.is_empty()
{
return cx.spawn(async move |_, cx| {
let completion_response = Self::completions_for_candidates(
cx,
query.as_str(),
candidates,
start_anchor..end_anchor,
Self::completion_for_emoji,
)
.await;
Ok(vec![completion_response])
});
}
Task::ready(Ok(vec![CompletionResponse {
completions: Vec::new(),
display_options: CompletionDisplayOptions::default(),
is_incomplete: false,
}]))
}
async fn completions_for_candidates(
cx: &AsyncApp,
query: &str,
candidates: &[StringMatchCandidate],
range: Range<Anchor>,
completion_fn: impl Fn(&StringMatch) -> (String, CodeLabel),
) -> CompletionResponse {
const LIMIT: usize = 10;
let matches = fuzzy::match_strings(
candidates,
query,
true,
true,
LIMIT,
&Default::default(),
cx.background_executor().clone(),
)
.await;
let completions = matches
.into_iter()
.map(|mat| {
let (new_text, label) = completion_fn(&mat);
Completion {
replace_range: range.clone(),
new_text,
label,
icon_path: None,
confirm: None,
documentation: None,
insert_text_mode: None,
source: CompletionSource::Custom,
}
})
.collect::<Vec<_>>();
CompletionResponse {
is_incomplete: completions.len() >= LIMIT,
display_options: CompletionDisplayOptions::default(),
completions,
}
}
fn completion_for_mention(mat: &StringMatch) -> (String, CodeLabel) {
let label = CodeLabel {
filter_range: 1..mat.string.len() + 1,
text: format!("@{}", mat.string),
runs: Vec::new(),
};
(mat.string.clone(), label)
}
fn completion_for_emoji(mat: &StringMatch) -> (String, CodeLabel) {
let emoji = emojis::get_by_shortcode(&mat.string).unwrap();
let label = CodeLabel {
filter_range: 1..mat.string.len() + 1,
text: format!(":{}: {}", mat.string, emoji),
runs: Vec::new(),
};
(emoji.to_string(), label)
}
fn collect_mention_candidates(
&mut self,
buffer: &Entity<Buffer>,
end_anchor: Anchor,
cx: &mut Context<Self>,
) -> Option<(Anchor, String, Vec<StringMatchCandidate>)> {
let end_offset = end_anchor.to_offset(buffer.read(cx));
let query = buffer.read_with(cx, |buffer, _| {
let mut query = String::new();
for ch in buffer.reversed_chars_at(end_offset).take(100) {
if ch == '@' {
return Some(query.chars().rev().collect::<String>());
}
if ch.is_whitespace() || !ch.is_ascii() {
break;
}
query.push(ch);
}
None
})?;
let start_offset = end_offset - query.len();
let start_anchor = buffer.read(cx).anchor_before(start_offset);
let mut names = HashSet::default();
if let Some(chat) = self.channel_chat.as_ref() {
let chat = chat.read(cx);
for participant in ChannelStore::global(cx)
.read(cx)
.channel_participants(chat.channel_id)
{
names.insert(participant.github_login.clone());
}
for message in chat
.messages_in_range(chat.message_count().saturating_sub(100)..chat.message_count())
{
names.insert(message.sender.github_login.clone());
}
}
let candidates = names
.into_iter()
.map(|user| StringMatchCandidate::new(0, &user))
.collect::<Vec<_>>();
Some((start_anchor, query, candidates))
}
fn collect_emoji_candidates(
&mut self,
buffer: &Entity<Buffer>,
end_anchor: Anchor,
cx: &mut Context<Self>,
) -> Option<(Anchor, String, &'static [StringMatchCandidate])> {
static EMOJI_FUZZY_MATCH_CANDIDATES: LazyLock<Vec<StringMatchCandidate>> =
LazyLock::new(|| {
emojis::iter()
.flat_map(|s| s.shortcodes())
.map(|emoji| StringMatchCandidate::new(0, emoji))
.collect::<Vec<_>>()
});
let end_offset = end_anchor.to_offset(buffer.read(cx));
let query = buffer.read_with(cx, |buffer, _| {
let mut query = String::new();
for ch in buffer.reversed_chars_at(end_offset).take(100) {
if ch == ':' {
let next_char = buffer
.reversed_chars_at(end_offset - query.len() - 1)
.next();
// Ensure we are at the start of the message or that the previous character is a whitespace
if next_char.is_none() || next_char.unwrap().is_whitespace() {
return Some(query.chars().rev().collect::<String>());
}
// If the previous character is not a whitespace, we are in the middle of a word
// and we only want to complete the shortcode if the word is made up of other emojis
let mut containing_word = String::new();
for ch in buffer
.reversed_chars_at(end_offset - query.len() - 1)
.take(100)
{
if ch.is_whitespace() {
break;
}
containing_word.push(ch);
}
let containing_word = containing_word.chars().rev().collect::<String>();
if util::word_consists_of_emojis(containing_word.as_str()) {
return Some(query.chars().rev().collect::<String>());
}
break;
}
if ch.is_whitespace() || !ch.is_ascii() {
break;
}
query.push(ch);
}
None
})?;
let start_offset = end_offset - query.len() - 1;
let start_anchor = buffer.read(cx).anchor_before(start_offset);
Some((start_anchor, query, &EMOJI_FUZZY_MATCH_CANDIDATES))
}
async fn find_mentions(
this: WeakEntity<MessageEditor>,
buffer: BufferSnapshot,
cx: &mut AsyncWindowContext,
) {
let (buffer, ranges) = cx
.background_spawn(async move {
let ranges = MENTIONS_SEARCH.search(&buffer, None).await;
(buffer, ranges)
})
.await;
this.update(cx, |this, cx| {
let mut anchor_ranges = Vec::new();
let mut mentioned_user_ids = Vec::new();
let mut text = String::new();
this.editor.update(cx, |editor, cx| {
let multi_buffer = editor.buffer().read(cx).snapshot(cx);
for range in ranges {
text.clear();
text.extend(buffer.text_for_range(range.clone()));
if let Some(username) = text.strip_prefix('@')
&& let Some(user) = this
.user_store
.read(cx)
.cached_user_by_github_login(username)
{
let start = multi_buffer.anchor_after(range.start);
let end = multi_buffer.anchor_after(range.end);
mentioned_user_ids.push(user.id);
anchor_ranges.push(start..end);
}
}
editor.clear_highlights::<Self>(cx);
editor.highlight_text::<Self>(
anchor_ranges,
HighlightStyle {
font_weight: Some(FontWeight::BOLD),
..Default::default()
},
cx,
)
});
this.mentions = mentioned_user_ids;
this.mentions_task.take();
})
.ok();
}
pub(crate) fn focus_handle(&self, cx: &gpui::App) -> gpui::FocusHandle {
self.editor.read(cx).focus_handle(cx)
}
}
impl Render for MessageEditor {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: if self.editor.read(cx).read_only(cx) {
cx.theme().colors().text_disabled
} else {
cx.theme().colors().text
},
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: TextSize::Small.rems(cx).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
};
div()
.w_full()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_sm()
.child(EditorElement::new(
&self.editor,
EditorStyle {
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
))
}
}

View File

@@ -2,7 +2,7 @@ mod channel_modal;
mod contact_finder;
use self::channel_modal::ChannelModal;
use crate::{CollaborationPanelSettings, channel_view::ChannelView, chat_panel::ChatPanel};
use crate::{CollaborationPanelSettings, channel_view::ChannelView};
use anyhow::Context as _;
use call::ActiveCall;
use channel::{Channel, ChannelEvent, ChannelStore};
@@ -38,7 +38,7 @@ use util::{ResultExt, TryFutureExt, maybe};
use workspace::{
Deafen, LeaveCall, Mute, OpenChannelNotes, ScreenShare, ShareProject, Workspace,
dock::{DockPosition, Panel, PanelEvent},
notifications::{DetachAndPromptErr, NotifyResultExt, NotifyTaskExt},
notifications::{DetachAndPromptErr, NotifyResultExt},
};
actions!(
@@ -261,9 +261,6 @@ enum ListEntry {
ChannelNotes {
channel_id: ChannelId,
},
ChannelChat {
channel_id: ChannelId,
},
ChannelEditor {
depth: usize,
},
@@ -495,7 +492,6 @@ impl CollabPanel {
&& let Some(channel_id) = room.channel_id()
{
self.entries.push(ListEntry::ChannelNotes { channel_id });
self.entries.push(ListEntry::ChannelChat { channel_id });
}
// Populate the active user.
@@ -1089,39 +1085,6 @@ impl CollabPanel {
.tooltip(Tooltip::text("Open Channel Notes"))
}
fn render_channel_chat(
&self,
channel_id: ChannelId,
is_selected: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let channel_store = self.channel_store.read(cx);
let has_messages_notification = channel_store.has_new_messages(channel_id);
ListItem::new("channel-chat")
.toggle_state(is_selected)
.on_click(cx.listener(move |this, _, window, cx| {
this.join_channel_chat(channel_id, window, cx);
}))
.start_slot(
h_flex()
.relative()
.gap_1()
.child(render_tree_branch(false, false, window, cx))
.child(IconButton::new(0, IconName::Chat))
.children(has_messages_notification.then(|| {
div()
.w_1p5()
.absolute()
.right(px(2.))
.top(px(4.))
.child(Indicator::dot().color(Color::Info))
})),
)
.child(Label::new("chat"))
.tooltip(Tooltip::text("Open Chat"))
}
fn has_subchannels(&self, ix: usize) -> bool {
self.entries.get(ix).is_some_and(|entry| {
if let ListEntry::Channel { has_children, .. } = entry {
@@ -1296,13 +1259,6 @@ impl CollabPanel {
this.open_channel_notes(channel_id, window, cx)
}),
)
.entry(
"Open Chat",
None,
window.handler_for(&this, move |this, window, cx| {
this.join_channel_chat(channel_id, window, cx)
}),
)
.entry(
"Copy Channel Link",
None,
@@ -1632,9 +1588,6 @@ impl CollabPanel {
ListEntry::ChannelNotes { channel_id } => {
self.open_channel_notes(*channel_id, window, cx)
}
ListEntry::ChannelChat { channel_id } => {
self.join_channel_chat(*channel_id, window, cx)
}
ListEntry::OutgoingRequest(_) => {}
ListEntry::ChannelEditor { .. } => {}
}
@@ -2258,28 +2211,6 @@ impl CollabPanel {
.detach_and_prompt_err("Failed to join channel", window, cx, |_, _, _| None)
}
fn join_channel_chat(
&mut self,
channel_id: ChannelId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
window.defer(cx, move |window, cx| {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<ChatPanel>(window, cx) {
panel.update(cx, |panel, cx| {
panel
.select_channel(channel_id, None, cx)
.detach_and_notify_err(window, cx);
});
}
});
});
}
fn copy_channel_link(&mut self, channel_id: ChannelId, cx: &mut Context<Self>) {
let channel_store = self.channel_store.read(cx);
let Some(channel) = channel_store.channel_for_id(channel_id) else {
@@ -2398,9 +2329,6 @@ impl CollabPanel {
ListEntry::ChannelNotes { channel_id } => self
.render_channel_notes(*channel_id, is_selected, window, cx)
.into_any_element(),
ListEntry::ChannelChat { channel_id } => self
.render_channel_chat(*channel_id, is_selected, window, cx)
.into_any_element(),
}
}
@@ -2781,7 +2709,6 @@ impl CollabPanel {
let disclosed =
has_children.then(|| self.collapsed_channels.binary_search(&channel.id).is_err());
let has_messages_notification = channel_store.has_new_messages(channel_id);
let has_notes_notification = channel_store.has_channel_buffer_changed(channel_id);
const FACEPILE_LIMIT: usize = 3;
@@ -2909,21 +2836,6 @@ impl CollabPanel {
.rounded_l_sm()
.gap_1()
.px_1()
.child(
IconButton::new("channel_chat", IconName::Chat)
.style(ButtonStyle::Filled)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(if has_messages_notification {
Color::Default
} else {
Color::Muted
})
.on_click(cx.listener(move |this, _, window, cx| {
this.join_channel_chat(channel_id, window, cx)
}))
.tooltip(Tooltip::text("Open channel chat")),
)
.child(
IconButton::new("channel_notes", IconName::Reader)
.style(ButtonStyle::Filled)
@@ -3183,14 +3095,6 @@ impl PartialEq for ListEntry {
return channel_id == other_id;
}
}
ListEntry::ChannelChat { channel_id } => {
if let ListEntry::ChannelChat {
channel_id: other_id,
} = other
{
return channel_id == other_id;
}
}
ListEntry::ChannelInvite(channel_1) => {
if let ListEntry::ChannelInvite(channel_2) = other {
return channel_1.id == channel_2.id;

View File

@@ -1,5 +1,4 @@
pub mod channel_view;
pub mod chat_panel;
pub mod collab_panel;
pub mod notification_panel;
pub mod notifications;
@@ -13,9 +12,7 @@ use gpui::{
WindowDecorations, WindowKind, WindowOptions, point,
};
use panel_settings::MessageEditorSettings;
pub use panel_settings::{
ChatPanelButton, ChatPanelSettings, CollaborationPanelSettings, NotificationPanelSettings,
};
pub use panel_settings::{CollaborationPanelSettings, NotificationPanelSettings};
use release_channel::ReleaseChannel;
use settings::Settings;
use ui::px;
@@ -23,12 +20,10 @@ use workspace::AppState;
pub fn init(app_state: &Arc<AppState>, cx: &mut App) {
CollaborationPanelSettings::register(cx);
ChatPanelSettings::register(cx);
NotificationPanelSettings::register(cx);
MessageEditorSettings::register(cx);
channel_view::init(cx);
chat_panel::init(cx);
collab_panel::init(cx);
notification_panel::init(cx);
notifications::init(app_state, cx);

View File

@@ -1,4 +1,4 @@
use crate::{NotificationPanelSettings, chat_panel::ChatPanel};
use crate::NotificationPanelSettings;
use anyhow::Result;
use channel::ChannelStore;
use client::{ChannelId, Client, Notification, User, UserStore};
@@ -6,8 +6,8 @@ use collections::HashMap;
use db::kvp::KEY_VALUE_STORE;
use futures::StreamExt;
use gpui::{
AnyElement, App, AsyncWindowContext, ClickEvent, Context, CursorStyle, DismissEvent, Element,
Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment,
AnyElement, App, AsyncWindowContext, ClickEvent, Context, DismissEvent, Element, Entity,
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment,
ListScrollEvent, ListState, ParentElement, Render, StatefulInteractiveElement, Styled, Task,
WeakEntity, Window, actions, div, img, list, px,
};
@@ -71,7 +71,6 @@ pub struct NotificationPresenter {
pub text: String,
pub icon: &'static str,
pub needs_response: bool,
pub can_navigate: bool,
}
actions!(
@@ -234,7 +233,6 @@ impl NotificationPanel {
actor,
text,
needs_response,
can_navigate,
..
} = self.present_notification(entry, cx)?;
@@ -269,14 +267,6 @@ impl NotificationPanel {
.py_1()
.gap_2()
.hover(|style| style.bg(cx.theme().colors().element_hover))
.when(can_navigate, |el| {
el.cursor(CursorStyle::PointingHand).on_click({
let notification = notification.clone();
cx.listener(move |this, _, window, cx| {
this.did_click_notification(&notification, window, cx)
})
})
})
.children(actor.map(|actor| {
img(actor.avatar_uri.clone())
.flex_none()
@@ -369,7 +359,6 @@ impl NotificationPanel {
text: format!("{} wants to add you as a contact", requester.github_login),
needs_response: user_store.has_incoming_contact_request(requester.id),
actor: Some(requester),
can_navigate: false,
})
}
Notification::ContactRequestAccepted { responder_id } => {
@@ -379,7 +368,6 @@ impl NotificationPanel {
text: format!("{} accepted your contact invite", responder.github_login),
needs_response: false,
actor: Some(responder),
can_navigate: false,
})
}
Notification::ChannelInvitation {
@@ -396,29 +384,6 @@ impl NotificationPanel {
),
needs_response: channel_store.has_channel_invitation(ChannelId(channel_id)),
actor: Some(inviter),
can_navigate: false,
})
}
Notification::ChannelMessageMention {
sender_id,
channel_id,
message_id,
} => {
let sender = user_store.get_cached_user(sender_id)?;
let channel = channel_store.channel_for_id(ChannelId(channel_id))?;
let message = self
.notification_store
.read(cx)
.channel_message_for_id(message_id)?;
Some(NotificationPresenter {
icon: "icons/conversations.svg",
text: format!(
"{} mentioned you in #{}:\n{}",
sender.github_login, channel.name, message.body,
),
needs_response: false,
actor: Some(sender),
can_navigate: true,
})
}
}
@@ -433,9 +398,7 @@ impl NotificationPanel {
) {
let should_mark_as_read = match notification {
Notification::ContactRequestAccepted { .. } => true,
Notification::ContactRequest { .. }
| Notification::ChannelInvitation { .. }
| Notification::ChannelMessageMention { .. } => false,
Notification::ContactRequest { .. } | Notification::ChannelInvitation { .. } => false,
};
if should_mark_as_read {
@@ -457,55 +420,6 @@ impl NotificationPanel {
}
}
fn did_click_notification(
&mut self,
notification: &Notification,
window: &mut Window,
cx: &mut Context<Self>,
) {
if let Notification::ChannelMessageMention {
message_id,
channel_id,
..
} = notification.clone()
&& let Some(workspace) = self.workspace.upgrade()
{
window.defer(cx, move |window, cx| {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<ChatPanel>(window, cx) {
panel.update(cx, |panel, cx| {
panel
.select_channel(ChannelId(channel_id), Some(message_id), cx)
.detach_and_log_err(cx);
});
}
});
});
}
}
fn is_showing_notification(&self, notification: &Notification, cx: &mut Context<Self>) -> bool {
if !self.active {
return false;
}
if let Notification::ChannelMessageMention { channel_id, .. } = &notification
&& let Some(workspace) = self.workspace.upgrade()
{
return if let Some(panel) = workspace.read(cx).panel::<ChatPanel>(cx) {
let panel = panel.read(cx);
panel.is_scrolled_to_bottom()
&& panel
.active_chat()
.is_some_and(|chat| chat.read(cx).channel_id.0 == *channel_id)
} else {
false
};
}
false
}
fn on_notification_event(
&mut self,
_: &Entity<NotificationStore>,
@@ -515,9 +429,7 @@ impl NotificationPanel {
) {
match event {
NotificationEvent::NewNotification { entry } => {
if !self.is_showing_notification(&entry.notification, cx) {
self.unseen_notifications.push(entry.clone());
}
self.unseen_notifications.push(entry.clone());
self.add_toast(entry, window, cx);
}
NotificationEvent::NotificationRemoved { entry }
@@ -541,10 +453,6 @@ impl NotificationPanel {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.is_showing_notification(&entry.notification, cx) {
return;
}
let Some(NotificationPresenter { actor, text, .. }) = self.present_notification(entry, cx)
else {
return;
@@ -568,7 +476,6 @@ impl NotificationPanel {
workspace.show_notification(id, cx, |cx| {
let workspace = cx.entity().downgrade();
cx.new(|cx| NotificationToast {
notification_id,
actor,
text,
workspace,
@@ -781,7 +688,6 @@ impl Panel for NotificationPanel {
}
pub struct NotificationToast {
notification_id: u64,
actor: Option<Arc<User>>,
text: String,
workspace: WeakEntity<Workspace>,
@@ -799,22 +705,10 @@ impl WorkspaceNotification for NotificationToast {}
impl NotificationToast {
fn focus_notification_panel(&self, window: &mut Window, cx: &mut Context<Self>) {
let workspace = self.workspace.clone();
let notification_id = self.notification_id;
window.defer(cx, move |window, cx| {
workspace
.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<NotificationPanel>(window, cx) {
panel.update(cx, |panel, cx| {
let store = panel.notification_store.read(cx);
if let Some(entry) = store.notification_for_id(notification_id) {
panel.did_click_notification(
&entry.clone().notification,
window,
cx,
);
}
});
}
workspace.focus_panel::<NotificationPanel>(window, cx)
})
.ok();
})

View File

@@ -11,39 +11,6 @@ pub struct CollaborationPanelSettings {
pub default_width: Pixels,
}
#[derive(Clone, Copy, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ChatPanelButton {
Never,
Always,
#[default]
WhenInCall,
}
#[derive(Deserialize, Debug)]
pub struct ChatPanelSettings {
pub button: ChatPanelButton,
pub dock: DockPosition,
pub default_width: Pixels,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)]
#[settings_key(key = "chat_panel")]
pub struct ChatPanelSettingsContent {
/// When to show the panel button in the status bar.
///
/// Default: only when in a call
pub button: Option<ChatPanelButton>,
/// Where to dock the panel.
///
/// Default: right
pub dock: Option<DockPosition>,
/// Default width of the panel in pixels.
///
/// Default: 240
pub default_width: Option<f32>,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug, SettingsUi, SettingsKey)]
#[settings_key(key = "collaboration_panel")]
pub struct PanelSettingsContent {
@@ -108,19 +75,6 @@ impl Settings for CollaborationPanelSettings {
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Settings for ChatPanelSettings {
type FileContent = ChatPanelSettingsContent;
fn load(
sources: SettingsSources<Self::FileContent>,
_: &mut gpui::App,
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Settings for NotificationPanelSettings {
type FileContent = NotificationPanelSettingsContent;

View File

@@ -0,0 +1,982 @@
use crate::{
DIAGNOSTICS_UPDATE_DELAY, IncludeWarnings, ToggleWarnings, context_range_for_entry,
diagnostic_renderer::{DiagnosticBlock, DiagnosticRenderer},
toolbar_controls::DiagnosticsToolbarEditor,
};
use anyhow::Result;
use collections::HashMap;
use editor::{
Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey,
display_map::{BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
multibuffer_context_lines,
};
use gpui::{
AnyElement, App, AppContext, Context, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
InteractiveElement, IntoElement, ParentElement, Render, SharedString, Styled, Subscription,
Task, WeakEntity, Window, actions, div,
};
use language::{Buffer, DiagnosticEntry, Point};
use project::{
DiagnosticSummary, Event, Project, ProjectItem, ProjectPath,
project_settings::{DiagnosticSeverity, ProjectSettings},
};
use settings::Settings;
use std::{
any::{Any, TypeId},
cmp::Ordering,
sync::Arc,
};
use text::{Anchor, BufferSnapshot, OffsetRangeExt};
use ui::{Button, ButtonStyle, Icon, IconName, Label, Tooltip, h_flex, prelude::*};
use util::paths::PathExt;
use workspace::{
ItemHandle, ItemNavHistory, ToolbarItemLocation, Workspace,
item::{BreadcrumbText, Item, ItemEvent, TabContentParams},
};
actions!(
diagnostics,
[
/// Opens the project diagnostics view for the currently focused file.
DeployCurrentFile,
]
);
/// The `BufferDiagnosticsEditor` is meant to be used when dealing specifically
/// with diagnostics for a single buffer, as only the excerpts of the buffer
/// where diagnostics are available are displayed.
pub(crate) struct BufferDiagnosticsEditor {
pub project: Entity<Project>,
focus_handle: FocusHandle,
editor: Entity<Editor>,
/// The current diagnostic entries in the `BufferDiagnosticsEditor`. Used to
/// allow quick comparison of updated diagnostics, to confirm if anything
/// has changed.
pub(crate) diagnostics: Vec<DiagnosticEntry<Anchor>>,
/// The blocks used to display the diagnostics' content in the editor, next
/// to the excerpts where the diagnostic originated.
blocks: Vec<CustomBlockId>,
/// Multibuffer to contain all excerpts that contain diagnostics, which are
/// to be rendered in the editor.
multibuffer: Entity<MultiBuffer>,
/// The buffer for which the editor is displaying diagnostics and excerpts
/// for.
buffer: Option<Entity<Buffer>>,
/// The path for which the editor is displaying diagnostics for.
project_path: ProjectPath,
/// Summary of the number of warnings and errors for the path. Used to
/// display the number of warnings and errors in the tab's content.
summary: DiagnosticSummary,
/// Whether to include warnings in the list of diagnostics shown in the
/// editor.
pub(crate) include_warnings: bool,
/// Keeps track of whether there's a background task already running to
/// update the excerpts, in order to avoid firing multiple tasks for this purpose.
pub(crate) update_excerpts_task: Option<Task<Result<()>>>,
/// The project's subscription, responsible for processing events related to
/// diagnostics.
_subscription: Subscription,
}
impl BufferDiagnosticsEditor {
/// Creates new instance of the `BufferDiagnosticsEditor` which can then be
/// displayed by adding it to a pane.
pub fn new(
project_path: ProjectPath,
project_handle: Entity<Project>,
buffer: Option<Entity<Buffer>>,
include_warnings: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
// Subscribe to project events related to diagnostics so the
// `BufferDiagnosticsEditor` can update its state accordingly.
let project_event_subscription = cx.subscribe_in(
&project_handle,
window,
|buffer_diagnostics_editor, _project, event, window, cx| match event {
Event::DiskBasedDiagnosticsStarted { .. } => {
cx.notify();
}
Event::DiskBasedDiagnosticsFinished { .. } => {
buffer_diagnostics_editor.update_all_excerpts(window, cx);
}
Event::DiagnosticsUpdated {
paths,
language_server_id,
} => {
// When diagnostics have been updated, the
// `BufferDiagnosticsEditor` should update its state only if
// one of the paths matches its `project_path`, otherwise
// the event should be ignored.
if paths.contains(&buffer_diagnostics_editor.project_path) {
buffer_diagnostics_editor.update_diagnostic_summary(cx);
if buffer_diagnostics_editor.editor.focus_handle(cx).contains_focused(window, cx) || buffer_diagnostics_editor.focus_handle.contains_focused(window, cx) {
log::debug!("diagnostics updated for server {language_server_id}. recording change");
} else {
log::debug!("diagnostics updated for server {language_server_id}. updating excerpts");
buffer_diagnostics_editor.update_all_excerpts(window, cx);
}
}
}
_ => {}
},
);
let focus_handle = cx.focus_handle();
cx.on_focus_in(
&focus_handle,
window,
|buffer_diagnostics_editor, window, cx| buffer_diagnostics_editor.focus_in(window, cx),
)
.detach();
cx.on_focus_out(
&focus_handle,
window,
|buffer_diagnostics_editor, _event, window, cx| {
buffer_diagnostics_editor.focus_out(window, cx)
},
)
.detach();
let summary = project_handle
.read(cx)
.diagnostic_summary_for_path(&project_path, cx);
let multibuffer = cx.new(|cx| MultiBuffer::new(project_handle.read(cx).capability()));
let max_severity = Self::max_diagnostics_severity(include_warnings);
let editor = cx.new(|cx| {
let mut editor = Editor::for_multibuffer(
multibuffer.clone(),
Some(project_handle.clone()),
window,
cx,
);
editor.set_vertical_scroll_margin(5, cx);
editor.disable_inline_diagnostics();
editor.set_max_diagnostics_severity(max_severity, cx);
editor.set_all_diagnostics_active(cx);
editor
});
// Subscribe to events triggered by the editor in order to correctly
// update the buffer's excerpts.
cx.subscribe_in(
&editor,
window,
|buffer_diagnostics_editor, _editor, event: &EditorEvent, window, cx| {
cx.emit(event.clone());
match event {
// If the user tries to focus on the editor but there's actually
// no excerpts for the buffer, focus back on the
// `BufferDiagnosticsEditor` instance.
EditorEvent::Focused => {
if buffer_diagnostics_editor.multibuffer.read(cx).is_empty() {
window.focus(&buffer_diagnostics_editor.focus_handle);
}
}
EditorEvent::Blurred => {
buffer_diagnostics_editor.update_all_excerpts(window, cx)
}
_ => {}
}
},
)
.detach();
let diagnostics = vec![];
let update_excerpts_task = None;
let mut buffer_diagnostics_editor = Self {
project: project_handle,
focus_handle,
editor,
diagnostics,
blocks: Default::default(),
multibuffer,
buffer,
project_path,
summary,
include_warnings,
update_excerpts_task,
_subscription: project_event_subscription,
};
buffer_diagnostics_editor.update_all_diagnostics(window, cx);
buffer_diagnostics_editor
}
fn deploy(
workspace: &mut Workspace,
_: &DeployCurrentFile,
window: &mut Window,
cx: &mut Context<Workspace>,
) {
// Determine the currently opened path by finding the active editor and
// finding the project path for the buffer.
// If there's no active editor with a project path, avoiding deploying
// the buffer diagnostics view.
if let Some(editor) = workspace.active_item_as::<Editor>(cx)
&& let Some(project_path) = editor.project_path(cx)
{
// Check if there's already a `BufferDiagnosticsEditor` tab for this
// same path, and if so, focus on that one instead of creating a new
// one.
let existing_editor = workspace
.items_of_type::<BufferDiagnosticsEditor>(cx)
.find(|editor| editor.read(cx).project_path == project_path);
if let Some(editor) = existing_editor {
workspace.activate_item(&editor, true, true, window, cx);
} else {
let include_warnings = match cx.try_global::<IncludeWarnings>() {
Some(include_warnings) => include_warnings.0,
None => ProjectSettings::get_global(cx).diagnostics.include_warnings,
};
let item = cx.new(|cx| {
Self::new(
project_path,
workspace.project().clone(),
editor.read(cx).buffer().read(cx).as_singleton(),
include_warnings,
window,
cx,
)
});
workspace.add_item_to_active_pane(Box::new(item), None, true, window, cx);
}
}
}
pub fn register(
workspace: &mut Workspace,
_window: Option<&mut Window>,
_: &mut Context<Workspace>,
) {
workspace.register_action(Self::deploy);
}
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.update_all_excerpts(window, cx);
}
fn update_diagnostic_summary(&mut self, cx: &mut Context<Self>) {
let project = self.project.read(cx);
self.summary = project.diagnostic_summary_for_path(&self.project_path, cx);
}
/// Enqueue an update to the excerpts and diagnostic blocks being shown in
/// the editor.
pub(crate) fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
// If there's already a task updating the excerpts, early return and let
// the other task finish.
if self.update_excerpts_task.is_some() {
return;
}
let buffer = self.buffer.clone();
self.update_excerpts_task = Some(cx.spawn_in(window, async move |editor, cx| {
cx.background_executor()
.timer(DIAGNOSTICS_UPDATE_DELAY)
.await;
if let Some(buffer) = buffer {
editor
.update_in(cx, |editor, window, cx| {
editor.update_excerpts(buffer, window, cx)
})?
.await?;
};
let _ = editor.update(cx, |editor, cx| {
editor.update_excerpts_task = None;
cx.notify();
});
Ok(())
}));
}
/// Updates the excerpts in the `BufferDiagnosticsEditor` for a single
/// buffer.
fn update_excerpts(
&mut self,
buffer: Entity<Buffer>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let was_empty = self.multibuffer.read(cx).is_empty();
let multibuffer_context = multibuffer_context_lines(cx);
let buffer_snapshot = buffer.read(cx).snapshot();
let buffer_snapshot_max = buffer_snapshot.max_point();
let max_severity = Self::max_diagnostics_severity(self.include_warnings)
.into_lsp()
.unwrap_or(lsp::DiagnosticSeverity::WARNING);
cx.spawn_in(window, async move |buffer_diagnostics_editor, mut cx| {
// Fetch the diagnostics for the whole of the buffer
// (`Point::zero()..buffer_snapshot.max_point()`) so we can confirm
// if the diagnostics changed, if it didn't, early return as there's
// nothing to update.
let diagnostics = buffer_snapshot
.diagnostics_in_range::<_, Anchor>(Point::zero()..buffer_snapshot_max, false)
.collect::<Vec<_>>();
let unchanged =
buffer_diagnostics_editor.update(cx, |buffer_diagnostics_editor, _cx| {
if buffer_diagnostics_editor
.diagnostics_are_unchanged(&diagnostics, &buffer_snapshot)
{
return true;
}
buffer_diagnostics_editor.set_diagnostics(&diagnostics);
return false;
})?;
if unchanged {
return Ok(());
}
// Mapping between the Group ID and a vector of DiagnosticEntry.
let mut grouped: HashMap<usize, Vec<_>> = HashMap::default();
for entry in diagnostics {
grouped
.entry(entry.diagnostic.group_id)
.or_default()
.push(DiagnosticEntry {
range: entry.range.to_point(&buffer_snapshot),
diagnostic: entry.diagnostic,
})
}
let mut blocks: Vec<DiagnosticBlock> = Vec::new();
for (_, group) in grouped {
// If the minimum severity of the group is higher than the
// maximum severity, or it doesn't even have severity, skip this
// group.
if group
.iter()
.map(|d| d.diagnostic.severity)
.min()
.is_none_or(|severity| severity > max_severity)
{
continue;
}
let diagnostic_blocks = cx.update(|_window, cx| {
DiagnosticRenderer::diagnostic_blocks_for_group(
group,
buffer_snapshot.remote_id(),
Some(Arc::new(buffer_diagnostics_editor.clone())),
cx,
)
})?;
// For each of the diagnostic blocks to be displayed in the
// editor, figure out its index in the list of blocks.
//
// The following rules are used to determine the order:
// 1. Blocks with a lower start position should come first.
// 2. If two blocks have the same start position, the one with
// the higher end position should come first.
for diagnostic_block in diagnostic_blocks {
let index = blocks.partition_point(|probe| {
match probe
.initial_range
.start
.cmp(&diagnostic_block.initial_range.start)
{
Ordering::Less => true,
Ordering::Greater => false,
Ordering::Equal => {
probe.initial_range.end > diagnostic_block.initial_range.end
}
}
});
blocks.insert(index, diagnostic_block);
}
}
// Build the excerpt ranges for this specific buffer's diagnostics,
// so those excerpts can later be used to update the excerpts shown
// in the editor.
// This is done by iterating over the list of diagnostic blocks and
// determine what range does the diagnostic block span.
let mut excerpt_ranges: Vec<ExcerptRange<Point>> = Vec::new();
for diagnostic_block in blocks.iter() {
let excerpt_range = context_range_for_entry(
diagnostic_block.initial_range.clone(),
multibuffer_context,
buffer_snapshot.clone(),
&mut cx,
)
.await;
let index = excerpt_ranges
.binary_search_by(|probe| {
probe
.context
.start
.cmp(&excerpt_range.start)
.then(probe.context.end.cmp(&excerpt_range.end))
.then(
probe
.primary
.start
.cmp(&diagnostic_block.initial_range.start),
)
.then(probe.primary.end.cmp(&diagnostic_block.initial_range.end))
.then(Ordering::Greater)
})
.unwrap_or_else(|index| index);
excerpt_ranges.insert(
index,
ExcerptRange {
context: excerpt_range,
primary: diagnostic_block.initial_range.clone(),
},
)
}
// Finally, update the editor's content with the new excerpt ranges
// for this editor, as well as the diagnostic blocks.
buffer_diagnostics_editor.update_in(cx, |buffer_diagnostics_editor, window, cx| {
// Remove the list of `CustomBlockId` from the editor's display
// map, ensuring that if any diagnostics have been solved, the
// associated block stops being shown.
let block_ids = buffer_diagnostics_editor.blocks.clone();
buffer_diagnostics_editor.editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.remove_blocks(block_ids.into_iter().collect(), cx);
})
});
let (anchor_ranges, _) =
buffer_diagnostics_editor
.multibuffer
.update(cx, |multibuffer, cx| {
multibuffer.set_excerpt_ranges_for_path(
PathKey::for_buffer(&buffer, cx),
buffer.clone(),
&buffer_snapshot,
excerpt_ranges,
cx,
)
});
if was_empty {
if let Some(anchor_range) = anchor_ranges.first() {
let range_to_select = anchor_range.start..anchor_range.start;
buffer_diagnostics_editor.editor.update(cx, |editor, cx| {
editor.change_selections(Default::default(), window, cx, |selection| {
selection.select_anchor_ranges([range_to_select])
})
});
// If the `BufferDiagnosticsEditor` is currently
// focused, move focus to its editor.
if buffer_diagnostics_editor.focus_handle.is_focused(window) {
buffer_diagnostics_editor
.editor
.read(cx)
.focus_handle(cx)
.focus(window);
}
}
}
// Cloning the blocks before moving ownership so these can later
// be used to set the block contents for testing purposes.
#[cfg(test)]
let cloned_blocks = blocks.clone();
// Build new diagnostic blocks to be added to the editor's
// display map for the new diagnostics. Update the `blocks`
// property before finishing, to ensure the blocks are removed
// on the next execution.
let editor_blocks =
anchor_ranges
.into_iter()
.zip(blocks.into_iter())
.map(|(anchor, block)| {
let editor = buffer_diagnostics_editor.editor.downgrade();
BlockProperties {
placement: BlockPlacement::Near(anchor.start),
height: Some(1),
style: BlockStyle::Flex,
render: Arc::new(move |block_context| {
block.render_block(editor.clone(), block_context)
}),
priority: 1,
}
});
let block_ids = buffer_diagnostics_editor.editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.insert_blocks(editor_blocks, cx)
})
});
// In order to be able to verify which diagnostic blocks are
// rendered in the editor, the `set_block_content_for_tests`
// function must be used, so that the
// `editor::test::editor_content_with_blocks` function can then
// be called to fetch these blocks.
#[cfg(test)]
{
for (block_id, block) in block_ids.iter().zip(cloned_blocks.iter()) {
let markdown = block.markdown.clone();
editor::test::set_block_content_for_tests(
&buffer_diagnostics_editor.editor,
*block_id,
cx,
move |cx| {
markdown::MarkdownElement::rendered_text(
markdown.clone(),
cx,
editor::hover_popover::diagnostics_markdown_style,
)
},
);
}
}
buffer_diagnostics_editor.blocks = block_ids;
cx.notify()
})
})
}
fn set_diagnostics(&mut self, diagnostics: &Vec<DiagnosticEntry<Anchor>>) {
self.diagnostics = diagnostics.clone();
}
fn diagnostics_are_unchanged(
&self,
diagnostics: &Vec<DiagnosticEntry<Anchor>>,
snapshot: &BufferSnapshot,
) -> bool {
if self.diagnostics.len() != diagnostics.len() {
return false;
}
self.diagnostics
.iter()
.zip(diagnostics.iter())
.all(|(existing, new)| {
existing.diagnostic.message == new.diagnostic.message
&& existing.diagnostic.severity == new.diagnostic.severity
&& existing.diagnostic.is_primary == new.diagnostic.is_primary
&& existing.range.to_offset(snapshot) == new.range.to_offset(snapshot)
})
}
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
// If the `BufferDiagnosticsEditor` is focused and the multibuffer is
// not empty, focus on the editor instead, which will allow the user to
// start interacting and editing the buffer's contents.
if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
self.editor.focus_handle(cx).focus(window)
}
}
fn focus_out(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if !self.focus_handle.is_focused(window) && !self.editor.focus_handle(cx).is_focused(window)
{
self.update_all_excerpts(window, cx);
}
}
pub fn toggle_warnings(
&mut self,
_: &ToggleWarnings,
window: &mut Window,
cx: &mut Context<Self>,
) {
let include_warnings = !self.include_warnings;
let max_severity = Self::max_diagnostics_severity(include_warnings);
self.editor.update(cx, |editor, cx| {
editor.set_max_diagnostics_severity(max_severity, cx);
});
self.include_warnings = include_warnings;
self.diagnostics.clear();
self.update_all_diagnostics(window, cx);
}
fn max_diagnostics_severity(include_warnings: bool) -> DiagnosticSeverity {
match include_warnings {
true => DiagnosticSeverity::Warning,
false => DiagnosticSeverity::Error,
}
}
#[cfg(test)]
pub fn editor(&self) -> &Entity<Editor> {
&self.editor
}
#[cfg(test)]
pub fn summary(&self) -> &DiagnosticSummary {
&self.summary
}
}
impl Focusable for BufferDiagnosticsEditor {
fn focus_handle(&self, _: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
impl EventEmitter<EditorEvent> for BufferDiagnosticsEditor {}
impl Item for BufferDiagnosticsEditor {
type Event = EditorEvent;
fn act_as_type<'a>(
&'a self,
type_id: std::any::TypeId,
self_handle: &'a Entity<Self>,
_: &'a App,
) -> Option<gpui::AnyView> {
if type_id == TypeId::of::<Self>() {
Some(self_handle.to_any())
} else if type_id == TypeId::of::<Editor>() {
Some(self.editor.to_any())
} else {
None
}
}
fn added_to_workspace(
&mut self,
workspace: &mut Workspace,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.editor.update(cx, |editor, cx| {
editor.added_to_workspace(workspace, window, cx)
});
}
fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation {
ToolbarItemLocation::PrimaryLeft
}
fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option<Vec<BreadcrumbText>> {
self.editor.breadcrumbs(theme, cx)
}
fn can_save(&self, _cx: &App) -> bool {
true
}
fn clone_on_split(
&self,
_workspace_id: Option<workspace::WorkspaceId>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<Entity<Self>>
where
Self: Sized,
{
Some(cx.new(|cx| {
BufferDiagnosticsEditor::new(
self.project_path.clone(),
self.project.clone(),
self.buffer.clone(),
self.include_warnings,
window,
cx,
)
}))
}
fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.editor
.update(cx, |editor, cx| editor.deactivated(window, cx));
}
fn for_each_project_item(&self, cx: &App, f: &mut dyn FnMut(EntityId, &dyn ProjectItem)) {
self.editor.for_each_project_item(cx, f);
}
fn has_conflict(&self, cx: &App) -> bool {
self.multibuffer.read(cx).has_conflict(cx)
}
fn has_deleted_file(&self, cx: &App) -> bool {
self.multibuffer.read(cx).has_deleted_file(cx)
}
fn is_dirty(&self, cx: &App) -> bool {
self.multibuffer.read(cx).is_dirty(cx)
}
fn is_singleton(&self, _cx: &App) -> bool {
false
}
fn navigate(
&mut self,
data: Box<dyn Any>,
window: &mut Window,
cx: &mut Context<Self>,
) -> bool {
self.editor
.update(cx, |editor, cx| editor.navigate(data, window, cx))
}
fn reload(
&mut self,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.editor.reload(project, window, cx)
}
fn save(
&mut self,
options: workspace::item::SaveOptions,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.editor.save(options, project, window, cx)
}
fn save_as(
&mut self,
_project: Entity<Project>,
_path: ProjectPath,
_window: &mut Window,
_cx: &mut Context<Self>,
) -> Task<Result<()>> {
unreachable!()
}
fn set_nav_history(
&mut self,
nav_history: ItemNavHistory,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.editor.update(cx, |editor, _| {
editor.set_nav_history(Some(nav_history));
})
}
// Builds the content to be displayed in the tab.
fn tab_content(&self, params: TabContentParams, _window: &Window, _cx: &App) -> AnyElement {
let error_count = self.summary.error_count;
let warning_count = self.summary.warning_count;
let label = Label::new(
self.project_path
.path
.file_name()
.map(|f| f.to_sanitized_string())
.unwrap_or_else(|| self.project_path.path.to_sanitized_string()),
);
h_flex()
.gap_1()
.child(label)
.when(error_count == 0 && warning_count == 0, |parent| {
parent.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success)),
)
})
.when(error_count > 0, |parent| {
parent.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new(error_count.to_string()).color(params.text_color())),
)
})
.when(warning_count > 0, |parent| {
parent.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Warning).color(Color::Warning))
.child(Label::new(warning_count.to_string()).color(params.text_color())),
)
})
.into_any_element()
}
fn tab_content_text(&self, _detail: usize, _app: &App) -> SharedString {
"Buffer Diagnostics".into()
}
fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> {
Some(
format!(
"Buffer Diagnostics - {}",
self.project_path.path.to_sanitized_string()
)
.into(),
)
}
fn telemetry_event_text(&self) -> Option<&'static str> {
Some("Buffer Diagnostics Opened")
}
fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
Editor::to_item_events(event, f)
}
}
impl Render for BufferDiagnosticsEditor {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let filename = self.project_path.path.to_sanitized_string();
let error_count = self.summary.error_count;
let warning_count = match self.include_warnings {
true => self.summary.warning_count,
false => 0,
};
let child = if error_count + warning_count == 0 {
let label = match warning_count {
0 => "No problems in",
_ => "No errors in",
};
v_flex()
.key_context("EmptyPane")
.size_full()
.gap_1()
.justify_center()
.items_center()
.text_center()
.bg(cx.theme().colors().editor_background)
.child(
div()
.h_flex()
.child(Label::new(label).color(Color::Muted))
.child(
Button::new("open-file", filename)
.style(ButtonStyle::Transparent)
.tooltip(Tooltip::text("Open File"))
.on_click(cx.listener(|buffer_diagnostics, _, window, cx| {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace
.open_path(
buffer_diagnostics.project_path.clone(),
None,
true,
window,
cx,
)
.detach_and_log_err(cx);
})
}
})),
),
)
.when(self.summary.warning_count > 0, |div| {
let label = match self.summary.warning_count {
1 => "Show 1 warning".into(),
warning_count => format!("Show {} warnings", warning_count),
};
div.child(
Button::new("diagnostics-show-warning-label", label).on_click(cx.listener(
|buffer_diagnostics_editor, _, window, cx| {
buffer_diagnostics_editor.toggle_warnings(
&Default::default(),
window,
cx,
);
cx.notify();
},
)),
)
})
} else {
div().size_full().child(self.editor.clone())
};
div()
.key_context("Diagnostics")
.track_focus(&self.focus_handle(cx))
.size_full()
.child(child)
}
}
impl DiagnosticsToolbarEditor for WeakEntity<BufferDiagnosticsEditor> {
fn include_warnings(&self, cx: &App) -> bool {
self.read_with(cx, |buffer_diagnostics_editor, _cx| {
buffer_diagnostics_editor.include_warnings
})
.unwrap_or(false)
}
fn has_stale_excerpts(&self, _cx: &App) -> bool {
false
}
fn is_updating(&self, cx: &App) -> bool {
self.read_with(cx, |buffer_diagnostics_editor, cx| {
buffer_diagnostics_editor.update_excerpts_task.is_some()
|| buffer_diagnostics_editor
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some()
})
.unwrap_or(false)
}
fn stop_updating(&self, cx: &mut App) {
let _ = self.update(cx, |buffer_diagnostics_editor, cx| {
buffer_diagnostics_editor.update_excerpts_task = None;
cx.notify();
});
}
fn refresh_diagnostics(&self, window: &mut Window, cx: &mut App) {
let _ = self.update(cx, |buffer_diagnostics_editor, cx| {
buffer_diagnostics_editor.update_all_excerpts(window, cx);
});
}
fn toggle_warnings(&self, window: &mut Window, cx: &mut App) {
let _ = self.update(cx, |buffer_diagnostics_editor, cx| {
buffer_diagnostics_editor.toggle_warnings(&Default::default(), window, cx);
});
}
fn get_diagnostics_for_buffer(
&self,
_buffer_id: text::BufferId,
cx: &App,
) -> Vec<language::DiagnosticEntry<text::Anchor>> {
self.read_with(cx, |buffer_diagnostics_editor, _cx| {
buffer_diagnostics_editor.diagnostics.clone()
})
.unwrap_or_default()
}
}

View File

@@ -18,7 +18,7 @@ use ui::{
};
use util::maybe;
use crate::ProjectDiagnosticsEditor;
use crate::toolbar_controls::DiagnosticsToolbarEditor;
pub struct DiagnosticRenderer;
@@ -26,7 +26,7 @@ impl DiagnosticRenderer {
pub fn diagnostic_blocks_for_group(
diagnostic_group: Vec<DiagnosticEntry<Point>>,
buffer_id: BufferId,
diagnostics_editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
diagnostics_editor: Option<Arc<dyn DiagnosticsToolbarEditor>>,
cx: &mut App,
) -> Vec<DiagnosticBlock> {
let Some(primary_ix) = diagnostic_group
@@ -130,6 +130,7 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer {
cx: &mut App,
) -> Vec<BlockProperties<Anchor>> {
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, cx);
blocks
.into_iter()
.map(|block| {
@@ -182,7 +183,7 @@ pub(crate) struct DiagnosticBlock {
pub(crate) initial_range: Range<Point>,
pub(crate) severity: DiagnosticSeverity,
pub(crate) markdown: Entity<Markdown>,
pub(crate) diagnostics_editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
pub(crate) diagnostics_editor: Option<Arc<dyn DiagnosticsToolbarEditor>>,
}
impl DiagnosticBlock {
@@ -233,7 +234,7 @@ impl DiagnosticBlock {
pub fn open_link(
editor: &mut Editor,
diagnostics_editor: &Option<WeakEntity<ProjectDiagnosticsEditor>>,
diagnostics_editor: &Option<Arc<dyn DiagnosticsToolbarEditor>>,
link: SharedString,
window: &mut Window,
cx: &mut Context<Editor>,
@@ -254,18 +255,10 @@ impl DiagnosticBlock {
if let Some(diagnostics_editor) = diagnostics_editor {
if let Some(diagnostic) = diagnostics_editor
.read_with(cx, |diagnostics, _| {
diagnostics
.diagnostics
.get(&buffer_id)
.cloned()
.unwrap_or_default()
.into_iter()
.filter(|d| d.diagnostic.group_id == group_id)
.nth(ix)
})
.ok()
.flatten()
.get_diagnostics_for_buffer(buffer_id, cx)
.into_iter()
.filter(|d| d.diagnostic.group_id == group_id)
.nth(ix)
{
let multibuffer = editor.buffer().read(cx);
let Some(snapshot) = multibuffer
@@ -297,9 +290,9 @@ impl DiagnosticBlock {
};
}
fn jump_to<T: ToOffset>(
fn jump_to<I: ToOffset>(
editor: &mut Editor,
range: Range<T>,
range: Range<I>,
window: &mut Window,
cx: &mut Context<Editor>,
) {

View File

@@ -1,12 +1,14 @@
pub mod items;
mod toolbar_controls;
mod buffer_diagnostics;
mod diagnostic_renderer;
#[cfg(test)]
mod diagnostics_tests;
use anyhow::Result;
use buffer_diagnostics::BufferDiagnosticsEditor;
use collections::{BTreeSet, HashMap};
use diagnostic_renderer::DiagnosticBlock;
use editor::{
@@ -36,6 +38,7 @@ use std::{
};
use text::{BufferId, OffsetRangeExt};
use theme::ActiveTheme;
use toolbar_controls::DiagnosticsToolbarEditor;
pub use toolbar_controls::ToolbarControls;
use ui::{Icon, IconName, Label, h_flex, prelude::*};
use util::ResultExt;
@@ -64,6 +67,7 @@ impl Global for IncludeWarnings {}
pub fn init(cx: &mut App) {
editor::set_diagnostic_renderer(diagnostic_renderer::DiagnosticRenderer {}, cx);
cx.observe_new(ProjectDiagnosticsEditor::register).detach();
cx.observe_new(BufferDiagnosticsEditor::register).detach();
}
pub(crate) struct ProjectDiagnosticsEditor {
@@ -85,6 +89,7 @@ pub(crate) struct ProjectDiagnosticsEditor {
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
const DIAGNOSTICS_UPDATE_DELAY: Duration = Duration::from_millis(50);
const DIAGNOSTICS_SUMMARY_UPDATE_DELAY: Duration = Duration::from_millis(30);
impl Render for ProjectDiagnosticsEditor {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
@@ -144,7 +149,7 @@ impl Render for ProjectDiagnosticsEditor {
}
impl ProjectDiagnosticsEditor {
fn register(
pub fn register(
workspace: &mut Workspace,
_window: Option<&mut Window>,
_: &mut Context<Workspace>,
@@ -160,7 +165,7 @@ impl ProjectDiagnosticsEditor {
cx: &mut Context<Self>,
) -> Self {
let project_event_subscription =
cx.subscribe_in(&project_handle, window, |this, project, event, window, cx| match event {
cx.subscribe_in(&project_handle, window, |this, _project, event, window, cx| match event {
project::Event::DiskBasedDiagnosticsStarted { .. } => {
cx.notify();
}
@@ -173,13 +178,12 @@ impl ProjectDiagnosticsEditor {
paths,
} => {
this.paths_to_update.extend(paths.clone());
let project = project.clone();
this.diagnostic_summary_update = cx.spawn(async move |this, cx| {
cx.background_executor()
.timer(Duration::from_millis(30))
.timer(DIAGNOSTICS_SUMMARY_UPDATE_DELAY)
.await;
this.update(cx, |this, cx| {
this.summary = project.read(cx).diagnostic_summary(false, cx);
this.update_diagnostic_summary(cx);
})
.log_err();
});
@@ -326,6 +330,7 @@ impl ProjectDiagnosticsEditor {
let is_active = workspace
.active_item(cx)
.is_some_and(|item| item.item_id() == existing.item_id());
workspace.activate_item(&existing, true, !is_active, window, cx);
} else {
let workspace_handle = cx.entity().downgrade();
@@ -383,22 +388,25 @@ impl ProjectDiagnosticsEditor {
/// currently have diagnostics or are currently present in this view.
fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.project.update(cx, |project, cx| {
let mut paths = project
let mut project_paths = project
.diagnostic_summaries(false, cx)
.map(|(path, _, _)| path)
.map(|(project_path, _, _)| project_path)
.collect::<BTreeSet<_>>();
self.multibuffer.update(cx, |multibuffer, cx| {
for buffer in multibuffer.all_buffers() {
if let Some(file) = buffer.read(cx).file() {
paths.insert(ProjectPath {
project_paths.insert(ProjectPath {
path: file.path().clone(),
worktree_id: file.worktree_id(cx),
});
}
}
});
self.paths_to_update = paths;
self.paths_to_update = project_paths;
});
self.update_stale_excerpts(window, cx);
}
@@ -428,6 +436,7 @@ impl ProjectDiagnosticsEditor {
let was_empty = self.multibuffer.read(cx).is_empty();
let buffer_snapshot = buffer.read(cx).snapshot();
let buffer_id = buffer_snapshot.remote_id();
let max_severity = if self.include_warnings {
lsp::DiagnosticSeverity::WARNING
} else {
@@ -441,6 +450,7 @@ impl ProjectDiagnosticsEditor {
false,
)
.collect::<Vec<_>>();
let unchanged = this.update(cx, |this, _| {
if this.diagnostics.get(&buffer_id).is_some_and(|existing| {
this.diagnostics_are_unchanged(existing, &diagnostics, &buffer_snapshot)
@@ -475,7 +485,7 @@ impl ProjectDiagnosticsEditor {
crate::diagnostic_renderer::DiagnosticRenderer::diagnostic_blocks_for_group(
group,
buffer_snapshot.remote_id(),
Some(this.clone()),
Some(Arc::new(this.clone())),
cx,
)
})?;
@@ -505,6 +515,7 @@ impl ProjectDiagnosticsEditor {
cx,
)
.await;
let i = excerpt_ranges
.binary_search_by(|probe| {
probe
@@ -574,6 +585,7 @@ impl ProjectDiagnosticsEditor {
priority: 1,
}
});
let block_ids = this.editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.insert_blocks(editor_blocks, cx)
@@ -604,6 +616,10 @@ impl ProjectDiagnosticsEditor {
})
})
}
fn update_diagnostic_summary(&mut self, cx: &mut Context<Self>) {
self.summary = self.project.read(cx).diagnostic_summary(false, cx);
}
}
impl Focusable for ProjectDiagnosticsEditor {
@@ -812,6 +828,68 @@ impl Item for ProjectDiagnosticsEditor {
}
}
impl DiagnosticsToolbarEditor for WeakEntity<ProjectDiagnosticsEditor> {
fn include_warnings(&self, cx: &App) -> bool {
self.read_with(cx, |project_diagnostics_editor, _cx| {
project_diagnostics_editor.include_warnings
})
.unwrap_or(false)
}
fn has_stale_excerpts(&self, cx: &App) -> bool {
self.read_with(cx, |project_diagnostics_editor, _cx| {
!project_diagnostics_editor.paths_to_update.is_empty()
})
.unwrap_or(false)
}
fn is_updating(&self, cx: &App) -> bool {
self.read_with(cx, |project_diagnostics_editor, cx| {
project_diagnostics_editor.update_excerpts_task.is_some()
|| project_diagnostics_editor
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some()
})
.unwrap_or(false)
}
fn stop_updating(&self, cx: &mut App) {
let _ = self.update(cx, |project_diagnostics_editor, cx| {
project_diagnostics_editor.update_excerpts_task = None;
cx.notify();
});
}
fn refresh_diagnostics(&self, window: &mut Window, cx: &mut App) {
let _ = self.update(cx, |project_diagnostics_editor, cx| {
project_diagnostics_editor.update_all_excerpts(window, cx);
});
}
fn toggle_warnings(&self, window: &mut Window, cx: &mut App) {
let _ = self.update(cx, |project_diagnostics_editor, cx| {
project_diagnostics_editor.toggle_warnings(&Default::default(), window, cx);
});
}
fn get_diagnostics_for_buffer(
&self,
buffer_id: text::BufferId,
cx: &App,
) -> Vec<language::DiagnosticEntry<text::Anchor>> {
self.read_with(cx, |project_diagnostics_editor, _cx| {
project_diagnostics_editor
.diagnostics
.get(&buffer_id)
.cloned()
.unwrap_or_default()
})
.unwrap_or_default()
}
}
const DIAGNOSTIC_EXPANSION_ROW_LIMIT: u32 = 32;
async fn context_range_for_entry(

View File

@@ -1567,6 +1567,440 @@ async fn go_to_diagnostic_with_severity(cx: &mut TestAppContext) {
cx.assert_editor_state(indoc! {"error ˇwarning info hint"});
}
#[gpui::test]
async fn test_buffer_diagnostics(cx: &mut TestAppContext) {
init_test(cx);
// We'll be creating two different files, both with diagnostics, so we can
// later verify that, since the `BufferDiagnosticsEditor` only shows
// diagnostics for the provided path, the diagnostics for the other file
// will not be shown, contrary to what happens with
// `ProjectDiagnosticsEditor`.
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/test"),
json!({
"main.rs": "
fn main() {
let x = vec![];
let y = vec![];
a(x);
b(y);
c(y);
d(x);
}
"
.unindent(),
"other.rs": "
fn other() {
let unused = 42;
undefined_function();
}
"
.unindent(),
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*window, cx);
let project_path = project::ProjectPath {
worktree_id: project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
}),
path: Arc::from(Path::new("main.rs")),
};
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
.await
.ok();
// Create the diagnostics for `main.rs`.
let language_server_id = LanguageServerId(0);
let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap();
let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
lsp_store.update(cx, |lsp_store, cx| {
lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams {
uri: uri.clone(),
diagnostics: vec![
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: Some(vec![
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(2, 8), lsp::Position::new(2, 9))),
message: "move occurs because `y` has type `Vec<char>`, which does not implement the `Copy` trait".to_string()
},
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(4, 6), lsp::Position::new(4, 7))),
message: "value moved here".to_string()
},
]),
..Default::default()
},
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)),
severity: Some(lsp::DiagnosticSeverity::ERROR),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: Some(vec![
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 9))),
message: "move occurs because `x` has type `Vec<char>`, which does not implement the `Copy` trait".to_string()
},
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(3, 6), lsp::Position::new(3, 7))),
message: "value moved here".to_string()
},
]),
..Default::default()
}
],
version: None
}, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap();
// Create diagnostics for other.rs to ensure that the file and
// diagnostics are not included in `BufferDiagnosticsEditor` when it is
// deployed for main.rs.
lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams {
uri: lsp::Uri::from_file_path(path!("/test/other.rs")).unwrap(),
diagnostics: vec![
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 14)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "unused variable: `unused`".to_string(),
..Default::default()
},
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(2, 4), lsp::Position::new(2, 22)),
severity: Some(lsp::DiagnosticSeverity::ERROR),
message: "cannot find function `undefined_function` in this scope".to_string(),
..Default::default()
}
],
version: None
}, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap();
});
let buffer_diagnostics = window.build_entity(cx, |window, cx| {
BufferDiagnosticsEditor::new(
project_path.clone(),
project.clone(),
buffer,
true,
window,
cx,
)
});
let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _| {
buffer_diagnostics.editor().clone()
});
// Since the excerpt updates is handled by a background task, we need to
// wait a little bit to ensure that the buffer diagnostic's editor content
// is rendered.
cx.executor()
.advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10));
pretty_assertions::assert_eq!(
editor_content_with_blocks(&editor, cx),
indoc::indoc! {
"§ main.rs
§ -----
fn main() {
let x = vec![];
§ move occurs because `x` has type `Vec<char>`, which does not implement
§ the `Copy` trait (back)
let y = vec![];
§ move occurs because `y` has type `Vec<char>`, which does not implement
§ the `Copy` trait
a(x); § value moved here
b(y); § value moved here
c(y);
§ use of moved value
§ value used here after move
d(x);
§ use of moved value
§ value used here after move
§ hint: move occurs because `x` has type `Vec<char>`, which does not
§ implement the `Copy` trait
}"
}
);
}
#[gpui::test]
async fn test_buffer_diagnostics_without_warnings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/test"),
json!({
"main.rs": "
fn main() {
let x = vec![];
let y = vec![];
a(x);
b(y);
c(y);
d(x);
}
"
.unindent(),
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*window, cx);
let project_path = project::ProjectPath {
worktree_id: project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
}),
path: Arc::from(Path::new("main.rs")),
};
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
.await
.ok();
let language_server_id = LanguageServerId(0);
let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap();
let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
lsp_store.update(cx, |lsp_store, cx| {
lsp_store.update_diagnostics(language_server_id, lsp::PublishDiagnosticsParams {
uri: uri.clone(),
diagnostics: vec![
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: Some(vec![
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(2, 8), lsp::Position::new(2, 9))),
message: "move occurs because `y` has type `Vec<char>`, which does not implement the `Copy` trait".to_string()
},
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(4, 6), lsp::Position::new(4, 7))),
message: "value moved here".to_string()
},
]),
..Default::default()
},
lsp::Diagnostic{
range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)),
severity: Some(lsp::DiagnosticSeverity::ERROR),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: Some(vec![
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 9))),
message: "move occurs because `x` has type `Vec<char>`, which does not implement the `Copy` trait".to_string()
},
lsp::DiagnosticRelatedInformation {
location: lsp::Location::new(uri.clone(), lsp::Range::new(lsp::Position::new(3, 6), lsp::Position::new(3, 7))),
message: "value moved here".to_string()
},
]),
..Default::default()
}
],
version: None
}, None, DiagnosticSourceKind::Pushed, &[], cx).unwrap();
});
let include_warnings = false;
let buffer_diagnostics = window.build_entity(cx, |window, cx| {
BufferDiagnosticsEditor::new(
project_path.clone(),
project.clone(),
buffer,
include_warnings,
window,
cx,
)
});
let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _cx| {
buffer_diagnostics.editor().clone()
});
// Since the excerpt updates is handled by a background task, we need to
// wait a little bit to ensure that the buffer diagnostic's editor content
// is rendered.
cx.executor()
.advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10));
pretty_assertions::assert_eq!(
editor_content_with_blocks(&editor, cx),
indoc::indoc! {
"§ main.rs
§ -----
fn main() {
let x = vec![];
§ move occurs because `x` has type `Vec<char>`, which does not implement
§ the `Copy` trait (back)
let y = vec![];
a(x); § value moved here
b(y);
c(y);
d(x);
§ use of moved value
§ value used here after move
§ hint: move occurs because `x` has type `Vec<char>`, which does not
§ implement the `Copy` trait
}"
}
);
}
#[gpui::test]
async fn test_buffer_diagnostics_multiple_servers(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/test"),
json!({
"main.rs": "
fn main() {
let x = vec![];
let y = vec![];
a(x);
b(y);
c(y);
d(x);
}
"
.unindent(),
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*window, cx);
let project_path = project::ProjectPath {
worktree_id: project.read_with(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
}),
path: Arc::from(Path::new("main.rs")),
};
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
.await
.ok();
// Create the diagnostics for `main.rs`.
// Two warnings are being created, one for each language server, in order to
// assert that both warnings are rendered in the editor.
let language_server_id_a = LanguageServerId(0);
let language_server_id_b = LanguageServerId(1);
let uri = lsp::Uri::from_file_path(path!("/test/main.rs")).unwrap();
let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
lsp_store.update(cx, |lsp_store, cx| {
lsp_store
.update_diagnostics(
language_server_id_a,
lsp::PublishDiagnosticsParams {
uri: uri.clone(),
diagnostics: vec![lsp::Diagnostic {
range: lsp::Range::new(lsp::Position::new(5, 6), lsp::Position::new(5, 7)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: None,
..Default::default()
}],
version: None,
},
None,
DiagnosticSourceKind::Pushed,
&[],
cx,
)
.unwrap();
lsp_store
.update_diagnostics(
language_server_id_b,
lsp::PublishDiagnosticsParams {
uri: uri.clone(),
diagnostics: vec![lsp::Diagnostic {
range: lsp::Range::new(lsp::Position::new(6, 6), lsp::Position::new(6, 7)),
severity: Some(lsp::DiagnosticSeverity::WARNING),
message: "use of moved value\nvalue used here after move".to_string(),
related_information: None,
..Default::default()
}],
version: None,
},
None,
DiagnosticSourceKind::Pushed,
&[],
cx,
)
.unwrap();
});
let buffer_diagnostics = window.build_entity(cx, |window, cx| {
BufferDiagnosticsEditor::new(
project_path.clone(),
project.clone(),
buffer,
true,
window,
cx,
)
});
let editor = buffer_diagnostics.update(cx, |buffer_diagnostics, _| {
buffer_diagnostics.editor().clone()
});
// Since the excerpt updates is handled by a background task, we need to
// wait a little bit to ensure that the buffer diagnostic's editor content
// is rendered.
cx.executor()
.advance_clock(DIAGNOSTICS_UPDATE_DELAY + Duration::from_millis(10));
pretty_assertions::assert_eq!(
editor_content_with_blocks(&editor, cx),
indoc::indoc! {
"§ main.rs
§ -----
a(x);
b(y);
c(y);
§ use of moved value
§ value used here after move
d(x);
§ use of moved value
§ value used here after move
}"
}
);
buffer_diagnostics.update(cx, |buffer_diagnostics, _cx| {
assert_eq!(
*buffer_diagnostics.summary(),
DiagnosticSummary {
warning_count: 2,
error_count: 0
}
);
})
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
zlog::init_test();

View File

@@ -1,33 +1,56 @@
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
use crate::{BufferDiagnosticsEditor, ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
use gpui::{Context, EventEmitter, ParentElement, Render, Window};
use language::DiagnosticEntry;
use text::{Anchor, BufferId};
use ui::prelude::*;
use ui::{IconButton, IconButtonShape, IconName, Tooltip};
use workspace::{ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, item::ItemHandle};
pub struct ToolbarControls {
editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
editor: Option<Box<dyn DiagnosticsToolbarEditor>>,
}
pub(crate) trait DiagnosticsToolbarEditor: Send + Sync {
/// Informs the toolbar whether warnings are included in the diagnostics.
fn include_warnings(&self, cx: &App) -> bool;
/// Toggles whether warning diagnostics should be displayed by the
/// diagnostics editor.
fn toggle_warnings(&self, window: &mut Window, cx: &mut App);
/// Indicates whether any of the excerpts displayed by the diagnostics
/// editor are stale.
fn has_stale_excerpts(&self, cx: &App) -> bool;
/// Indicates whether the diagnostics editor is currently updating the
/// diagnostics.
fn is_updating(&self, cx: &App) -> bool;
/// Requests that the diagnostics editor stop updating the diagnostics.
fn stop_updating(&self, cx: &mut App);
/// Requests that the diagnostics editor updates the displayed diagnostics
/// with the latest information.
fn refresh_diagnostics(&self, window: &mut Window, cx: &mut App);
/// Returns a list of diagnostics for the provided buffer id.
fn get_diagnostics_for_buffer(
&self,
buffer_id: BufferId,
cx: &App,
) -> Vec<DiagnosticEntry<Anchor>>;
}
impl Render for ToolbarControls {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let mut include_warnings = false;
let mut has_stale_excerpts = false;
let mut include_warnings = false;
let mut is_updating = false;
if let Some(editor) = self.diagnostics() {
let diagnostics = editor.read(cx);
include_warnings = diagnostics.include_warnings;
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
is_updating = diagnostics.update_excerpts_task.is_some()
|| diagnostics
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some();
match &self.editor {
Some(editor) => {
include_warnings = editor.include_warnings(cx);
has_stale_excerpts = editor.has_stale_excerpts(cx);
is_updating = editor.is_updating(cx);
}
None => {}
}
let tooltip = if include_warnings {
let warning_tooltip = if include_warnings {
"Exclude Warnings"
} else {
"Include Warnings"
@@ -52,11 +75,12 @@ impl Render for ToolbarControls {
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener(move |toolbar_controls, _, _, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
diagnostics.update(cx, |diagnostics, cx| {
diagnostics.update_excerpts_task = None;
match toolbar_controls.editor() {
Some(editor) => {
editor.stop_updating(cx);
cx.notify();
});
}
None => {}
}
})),
)
@@ -71,12 +95,11 @@ impl Render for ToolbarControls {
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener({
move |toolbar_controls, _, window, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
diagnostics.update(cx, move |diagnostics, cx| {
diagnostics.update_all_excerpts(window, cx);
});
}
move |toolbar_controls, _, window, cx| match toolbar_controls
.editor()
{
Some(editor) => editor.refresh_diagnostics(window, cx),
None => {}
}
})),
)
@@ -86,13 +109,10 @@ impl Render for ToolbarControls {
IconButton::new("toggle-warnings", IconName::Warning)
.icon_color(warning_color)
.shape(IconButtonShape::Square)
.tooltip(Tooltip::text(tooltip))
.on_click(cx.listener(|this, _, window, cx| {
if let Some(editor) = this.diagnostics() {
editor.update(cx, |editor, cx| {
editor.toggle_warnings(&Default::default(), window, cx);
});
}
.tooltip(Tooltip::text(warning_tooltip))
.on_click(cx.listener(|this, _, window, cx| match &this.editor {
Some(editor) => editor.toggle_warnings(window, cx),
None => {}
})),
)
}
@@ -109,7 +129,10 @@ impl ToolbarItemView for ToolbarControls {
) -> ToolbarItemLocation {
if let Some(pane_item) = active_pane_item.as_ref() {
if let Some(editor) = pane_item.downcast::<ProjectDiagnosticsEditor>() {
self.editor = Some(editor.downgrade());
self.editor = Some(Box::new(editor.downgrade()));
ToolbarItemLocation::PrimaryRight
} else if let Some(editor) = pane_item.downcast::<BufferDiagnosticsEditor>() {
self.editor = Some(Box::new(editor.downgrade()));
ToolbarItemLocation::PrimaryRight
} else {
ToolbarItemLocation::Hidden
@@ -131,7 +154,7 @@ impl ToolbarControls {
ToolbarControls { editor: None }
}
fn diagnostics(&self) -> Option<Entity<ProjectDiagnosticsEditor>> {
self.editor.as_ref()?.upgrade()
fn editor(&self) -> Option<&dyn DiagnosticsToolbarEditor> {
self.editor.as_deref()
}
}

View File

@@ -640,6 +640,10 @@ actions!(
SelectEnclosingSymbol,
/// Selects the next larger syntax node.
SelectLargerSyntaxNode,
/// Selects the next syntax node sibling.
SelectNextSyntaxNode,
/// Selects the previous syntax node sibling.
SelectPreviousSyntaxNode,
/// Extends selection left.
SelectLeft,
/// Selects the current line.

View File

@@ -1502,6 +1502,7 @@ impl CodeActionsMenu {
this.child(
h_flex()
.overflow_hidden()
.text_sm()
.child(
// TASK: It would be good to make lsp_action.title a SharedString to avoid allocating here.
action.lsp_action.title().replace("\n", ""),

View File

@@ -15138,6 +15138,104 @@ impl Editor {
});
}
pub fn select_next_syntax_node(
&mut self,
_: &SelectNextSyntaxNode,
window: &mut Window,
cx: &mut Context<Self>,
) {
let old_selections: Box<[_]> = self.selections.all::<usize>(cx).into();
if old_selections.is_empty() {
return;
}
self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx);
let buffer = self.buffer.read(cx).snapshot(cx);
let mut selected_sibling = false;
let new_selections = old_selections
.iter()
.map(|selection| {
let old_range = selection.start..selection.end;
if let Some(node) = buffer.syntax_next_sibling(old_range) {
let new_range = node.byte_range();
selected_sibling = true;
Selection {
id: selection.id,
start: new_range.start,
end: new_range.end,
goal: SelectionGoal::None,
reversed: selection.reversed,
}
} else {
selection.clone()
}
})
.collect::<Vec<_>>();
if selected_sibling {
self.change_selections(
SelectionEffects::scroll(Autoscroll::fit()),
window,
cx,
|s| {
s.select(new_selections);
},
);
}
}
pub fn select_prev_syntax_node(
&mut self,
_: &SelectPreviousSyntaxNode,
window: &mut Window,
cx: &mut Context<Self>,
) {
let old_selections: Box<[_]> = self.selections.all::<usize>(cx).into();
if old_selections.is_empty() {
return;
}
self.hide_mouse_cursor(HideMouseCursorOrigin::MovementAction, cx);
let buffer = self.buffer.read(cx).snapshot(cx);
let mut selected_sibling = false;
let new_selections = old_selections
.iter()
.map(|selection| {
let old_range = selection.start..selection.end;
if let Some(node) = buffer.syntax_prev_sibling(old_range) {
let new_range = node.byte_range();
selected_sibling = true;
Selection {
id: selection.id,
start: new_range.start,
end: new_range.end,
goal: SelectionGoal::None,
reversed: selection.reversed,
}
} else {
selection.clone()
}
})
.collect::<Vec<_>>();
if selected_sibling {
self.change_selections(
SelectionEffects::scroll(Autoscroll::fit()),
window,
cx,
|s| {
s.select(new_selections);
},
);
}
}
fn refresh_runnables(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Task<()> {
if !EditorSettings::get_global(cx).gutter.runnables {
self.clear_tasks();
@@ -18998,6 +19096,8 @@ impl Editor {
}
}
/// Returns the project path for the editor's buffer, if any buffer is
/// opened in the editor.
pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
if let Some(buffer) = self.buffer.read(cx).as_singleton() {
buffer.read(cx).project_path(cx)

View File

@@ -25330,6 +25330,101 @@ async fn test_non_utf_8_opens(cx: &mut TestAppContext) {
);
}
#[gpui::test]
async fn test_select_next_prev_syntax_node(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let language = Arc::new(Language::new(
LanguageConfig::default(),
Some(tree_sitter_rust::LANGUAGE.into()),
));
// Test hierarchical sibling navigation
let text = r#"
fn outer() {
if condition {
let a = 1;
}
let b = 2;
}
fn another() {
let c = 3;
}
"#;
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx));
// Wait for parsing to complete
editor
.condition::<crate::EditorEvent>(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx))
.await;
editor.update_in(cx, |editor, window, cx| {
// Start by selecting "let a = 1;" inside the if block
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
s.select_display_ranges([
DisplayPoint::new(DisplayRow(3), 16)..DisplayPoint::new(DisplayRow(3), 26)
]);
});
let initial_selection = editor.selections.display_ranges(cx);
assert_eq!(initial_selection.len(), 1, "Should have one selection");
// Test select next sibling - should move up levels to find the next sibling
// Since "let a = 1;" has no siblings in the if block, it should move up
// to find "let b = 2;" which is a sibling of the if block
editor.select_next_syntax_node(&SelectNextSyntaxNode, window, cx);
let next_selection = editor.selections.display_ranges(cx);
// Should have a selection and it should be different from the initial
assert_eq!(
next_selection.len(),
1,
"Should have one selection after next"
);
assert_ne!(
next_selection[0], initial_selection[0],
"Next sibling selection should be different"
);
// Test hierarchical navigation by going to the end of the current function
// and trying to navigate to the next function
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
s.select_display_ranges([
DisplayPoint::new(DisplayRow(5), 12)..DisplayPoint::new(DisplayRow(5), 22)
]);
});
editor.select_next_syntax_node(&SelectNextSyntaxNode, window, cx);
let function_next_selection = editor.selections.display_ranges(cx);
// Should move to the next function
assert_eq!(
function_next_selection.len(),
1,
"Should have one selection after function next"
);
// Test select previous sibling navigation
editor.select_prev_syntax_node(&SelectPreviousSyntaxNode, window, cx);
let prev_selection = editor.selections.display_ranges(cx);
// Should have a selection and it should be different
assert_eq!(
prev_selection.len(),
1,
"Should have one selection after prev"
);
assert_ne!(
prev_selection[0], function_next_selection[0],
"Previous sibling selection should be different from next"
);
});
}
#[track_caller]
fn extract_color_inlays(editor: &Editor, cx: &App) -> Vec<Rgba> {
editor

View File

@@ -365,6 +365,8 @@ impl EditorElement {
register_action(editor, window, Editor::toggle_comments);
register_action(editor, window, Editor::select_larger_syntax_node);
register_action(editor, window, Editor::select_smaller_syntax_node);
register_action(editor, window, Editor::select_next_syntax_node);
register_action(editor, window, Editor::select_prev_syntax_node);
register_action(editor, window, Editor::unwrap_syntax_node);
register_action(editor, window, Editor::select_enclosing_symbol);
register_action(editor, window, Editor::move_to_enclosing_bracket);
@@ -8296,7 +8298,7 @@ impl Element for EditorElement {
let (mut snapshot, is_read_only) = self.editor.update(cx, |editor, cx| {
(editor.snapshot(window, cx), editor.read_only(cx))
});
let style = self.style.clone();
let style = &self.style;
let rem_size = window.rem_size();
let font_id = window.text_system().resolve_font(&style.text.font());
@@ -8771,7 +8773,7 @@ impl Element for EditorElement {
blame.blame_for_rows(&[row_infos], cx).next()
})
.flatten()?;
let mut element = render_inline_blame_entry(blame_entry, &style, cx)?;
let mut element = render_inline_blame_entry(blame_entry, style, cx)?;
let inline_blame_padding = ProjectSettings::get_global(cx)
.git
.inline_blame
@@ -8791,7 +8793,7 @@ impl Element for EditorElement {
let longest_line_width = layout_line(
snapshot.longest_row(),
&snapshot,
&style,
style,
editor_width,
is_row_soft_wrapped,
window,
@@ -8949,7 +8951,7 @@ impl Element for EditorElement {
scroll_pixel_position,
newest_selection_head,
editor_width,
&style,
style,
window,
cx,
)
@@ -8967,7 +8969,7 @@ impl Element for EditorElement {
end_row,
line_height,
em_width,
&style,
style,
window,
cx,
);
@@ -9112,7 +9114,7 @@ impl Element for EditorElement {
&line_layouts,
newest_selection_head,
newest_selection_point,
&style,
style,
window,
cx,
)

View File

@@ -4,7 +4,7 @@
use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint};
use crate::{DisplayRow, EditorStyle, ToOffset, ToPoint, scroll::ScrollAnchor};
use gpui::{Pixels, WindowTextSystem};
use language::Point;
use language::{CharClassifier, Point};
use multi_buffer::{MultiBufferRow, MultiBufferSnapshot};
use serde::Deserialize;
use workspace::searchable::Direction;
@@ -405,15 +405,18 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis
let classifier = map.buffer_snapshot.char_classifier_at(raw_point);
find_preceding_boundary_display_point(map, point, FindRange::MultiLine, |left, right| {
let is_word_start =
classifier.kind(left) != classifier.kind(right) && !right.is_whitespace();
let is_subword_start = classifier.is_word('-') && left == '-' && right != '-'
|| left == '_' && right != '_'
|| left.is_lowercase() && right.is_uppercase();
is_word_start || is_subword_start || left == '\n'
is_subword_start(left, right, &classifier) || left == '\n'
})
}
pub fn is_subword_start(left: char, right: char, classifier: &CharClassifier) -> bool {
let is_word_start = classifier.kind(left) != classifier.kind(right) && !right.is_whitespace();
let is_subword_start = classifier.is_word('-') && left == '-' && right != '-'
|| left == '_' && right != '_'
|| left.is_lowercase() && right.is_uppercase();
is_word_start || is_subword_start
}
/// Returns a position of the next word boundary, where a word character is defined as either
/// uppercase letter, lowercase letter, '_' character or language-specific word character (like '-' in CSS).
pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
@@ -463,15 +466,19 @@ pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPo
let classifier = map.buffer_snapshot.char_classifier_at(raw_point);
find_boundary(map, point, FindRange::MultiLine, |left, right| {
let is_word_end =
(classifier.kind(left) != classifier.kind(right)) && !classifier.is_whitespace(left);
let is_subword_end = classifier.is_word('-') && left != '-' && right == '-'
|| left != '_' && right == '_'
|| left.is_lowercase() && right.is_uppercase();
is_word_end || is_subword_end || right == '\n'
is_subword_end(left, right, &classifier) || right == '\n'
})
}
pub fn is_subword_end(left: char, right: char, classifier: &CharClassifier) -> bool {
let is_word_end =
(classifier.kind(left) != classifier.kind(right)) && !classifier.is_whitespace(left);
let is_subword_end = classifier.is_word('-') && left != '-' && right == '-'
|| left != '_' && right == '_'
|| left.is_lowercase() && right.is_uppercase();
is_word_end || is_subword_end
}
/// Returns a position of the start of the current paragraph, where a paragraph
/// is defined as a run of non-blank lines.
pub fn start_of_paragraph(

View File

@@ -66,9 +66,10 @@ impl FeatureFlag for LlmClosedBetaFeatureFlag {
const NAME: &'static str = "llm-closed-beta";
}
pub struct ZedProFeatureFlag {}
impl FeatureFlag for ZedProFeatureFlag {
const NAME: &'static str = "zed-pro";
pub struct BillingV2FeatureFlag {}
impl FeatureFlag for BillingV2FeatureFlag {
const NAME: &'static str = "billing-v2";
}
pub struct NotebookFeatureFlag;

View File

@@ -1205,10 +1205,9 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
self.executor
.spawn(async move {
let mut cmd = new_smol_command(&git_binary_path);
let mut cmd = new_smol_command("git");
cmd.current_dir(&working_directory?)
.envs(env.iter())
.args(["stash", "push", "--quiet"])
@@ -1230,10 +1229,9 @@ impl GitRepository for RealGitRepository {
fn stash_pop(&self, env: Arc<HashMap<String, String>>) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
self.executor
.spawn(async move {
let mut cmd = new_smol_command(&git_binary_path);
let mut cmd = new_smol_command("git");
cmd.current_dir(&working_directory?)
.envs(env.iter())
.args(["stash", "pop"]);
@@ -1258,10 +1256,9 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
self.executor
.spawn(async move {
let mut cmd = new_smol_command(&git_binary_path);
let mut cmd = new_smol_command("git");
cmd.current_dir(&working_directory?)
.envs(env.iter())
.args(["commit", "--quiet", "-m"])
@@ -1305,7 +1302,7 @@ impl GitRepository for RealGitRepository {
let executor = cx.background_executor().clone();
async move {
let working_directory = working_directory?;
let mut command = new_smol_command(&self.git_binary_path);
let mut command = new_smol_command("git");
command
.envs(env.iter())
.current_dir(&working_directory)
@@ -1336,7 +1333,7 @@ impl GitRepository for RealGitRepository {
let working_directory = self.working_directory();
let executor = cx.background_executor().clone();
async move {
let mut command = new_smol_command(&self.git_binary_path);
let mut command = new_smol_command("git");
command
.envs(env.iter())
.current_dir(&working_directory?)
@@ -1362,7 +1359,7 @@ impl GitRepository for RealGitRepository {
let remote_name = format!("{}", fetch_options);
let executor = cx.background_executor().clone();
async move {
let mut command = new_smol_command(&self.git_binary_path);
let mut command = new_smol_command("git");
command
.envs(env.iter())
.current_dir(&working_directory?)

View File

@@ -218,23 +218,6 @@ impl AsyncApp {
Some(read(app.try_global()?, &app))
}
/// Reads the global state of the specified type, passing it to the given callback.
/// A default value is assigned if a global of this type has not yet been assigned.
///
/// # Errors
/// If the app has ben dropped this returns an error.
pub fn try_read_default_global<G: Global + Default, R>(
&self,
read: impl FnOnce(&G, &App) -> R,
) -> Result<R> {
let app = self.app.upgrade().context("app was released")?;
let mut app = app.borrow_mut();
app.update(|cx| {
cx.default_global::<G>();
});
Ok(read(app.try_global().context("app was released")?, &app))
}
/// A convenience method for [`App::update_global`](BorrowAppContext::update_global)
/// for updating the global state of the specified type.
pub fn update_global<G: Global, R>(

View File

@@ -16,7 +16,7 @@ use core_foundation::{
use core_graphics::{
base::{CGGlyph, kCGImageAlphaPremultipliedLast},
color_space::CGColorSpace,
context::CGContext,
context::{CGContext, CGTextDrawingMode},
display::CGPoint,
};
use core_text::{
@@ -396,6 +396,12 @@ impl MacTextSystemState {
let subpixel_shift = params
.subpixel_variant
.map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
cx.set_allows_font_smoothing(true);
cx.set_should_smooth_fonts(true);
cx.set_text_drawing_mode(CGTextDrawingMode::CGTextFill);
cx.set_gray_fill_color(0.0, 1.0);
cx.set_allows_antialiasing(true);
cx.set_should_antialias(true);
cx.set_allows_font_subpixel_positioning(true);
cx.set_should_subpixel_position_fonts(true);
cx.set_allows_font_subpixel_quantization(false);

View File

@@ -1576,33 +1576,6 @@ impl Render for KeymapEditor {
.child(
h_flex()
.gap_2()
.child(
right_click_menu("open-keymap-menu")
.menu(|window, cx| {
ContextMenu::build(window, cx, |menu, _, _| {
menu.header("Open Keymap JSON")
.action("User", zed_actions::OpenKeymap.boxed_clone())
.action("Zed Default", zed_actions::OpenDefaultKeymap.boxed_clone())
.action("Vim Default", vim::OpenDefaultKeymap.boxed_clone())
})
})
.anchor(gpui::Corner::TopLeft)
.trigger(|open, _, _|
IconButton::new(
"OpenKeymapJsonButton",
IconName::Json
)
.shape(ui::IconButtonShape::Square)
.when(!open, |this|
this.tooltip(move |window, cx| {
Tooltip::with_meta("Open Keymap JSON", Some(&zed_actions::OpenKeymap),"Right click to view more options", window, cx)
})
)
.on_click(|_, window, cx| {
window.dispatch_action(zed_actions::OpenKeymap.boxed_clone(), cx);
})
)
)
.child(
div()
.key_context({
@@ -1617,73 +1590,139 @@ impl Render for KeymapEditor {
.py_1()
.border_1()
.border_color(theme.colors().border)
.rounded_lg()
.rounded_md()
.child(self.filter_editor.clone()),
)
.child(
IconButton::new(
"KeymapEditorToggleFiltersIcon",
IconName::Keyboard,
)
.shape(ui::IconButtonShape::Square)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Search by Keystroke",
&ToggleKeystrokeSearch,
&focus_handle.clone(),
window,
cx,
h_flex()
.gap_1()
.min_w_64()
.child(
IconButton::new(
"KeymapEditorToggleFiltersIcon",
IconName::Keyboard,
)
}
})
.toggle_state(matches!(
self.search_mode,
SearchMode::KeyStroke { .. }
))
.on_click(|_, window, cx| {
window.dispatch_action(ToggleKeystrokeSearch.boxed_clone(), cx);
}),
)
.child(
IconButton::new("KeymapEditorConflictIcon", IconName::Warning)
.shape(ui::IconButtonShape::Square)
.when(
self.keybinding_conflict_state.any_user_binding_conflicts(),
|this| {
this.indicator(Indicator::dot().color(Color::Warning))
},
)
.tooltip({
let filter_state = self.filter_state;
let focus_handle = focus_handle.clone();
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
match filter_state {
FilterState::All => "Show Conflicts",
FilterState::Conflicts => "Hide Conflicts",
},
&ToggleConflictFilter,
&focus_handle.clone(),
window,
move |window, cx| {
Tooltip::for_action_in(
"Search by Keystroke",
&ToggleKeystrokeSearch,
&focus_handle.clone(),
window,
cx,
)
}
})
.toggle_state(matches!(
self.search_mode,
SearchMode::KeyStroke { .. }
))
.on_click(|_, window, cx| {
window.dispatch_action(
ToggleKeystrokeSearch.boxed_clone(),
cx,
);
}),
)
.child(
IconButton::new("KeymapEditorConflictIcon", IconName::Warning)
.icon_size(IconSize::Small)
.when(
self.keybinding_conflict_state
.any_user_binding_conflicts(),
|this| {
this.indicator(
Indicator::dot().color(Color::Warning),
)
},
)
}
})
.selected_icon_color(Color::Warning)
.toggle_state(matches!(
self.filter_state,
FilterState::Conflicts
))
.on_click(|_, window, cx| {
window.dispatch_action(
ToggleConflictFilter.boxed_clone(),
cx,
);
}),
.tooltip({
let filter_state = self.filter_state;
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
match filter_state {
FilterState::All => "Show Conflicts",
FilterState::Conflicts => {
"Hide Conflicts"
}
},
&ToggleConflictFilter,
&focus_handle.clone(),
window,
cx,
)
}
})
.selected_icon_color(Color::Warning)
.toggle_state(matches!(
self.filter_state,
FilterState::Conflicts
))
.on_click(|_, window, cx| {
window.dispatch_action(
ToggleConflictFilter.boxed_clone(),
cx,
);
}),
)
.child(
div()
.ml_1()
.pl_2()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.child(
right_click_menu("open-keymap-menu")
.menu(|window, cx| {
ContextMenu::build(window, cx, |menu, _, _| {
menu.header("Open Keymap JSON")
.action(
"User",
zed_actions::OpenKeymap.boxed_clone(),
)
.action(
"Zed Default",
zed_actions::OpenDefaultKeymap
.boxed_clone(),
)
.action(
"Vim Default",
vim::OpenDefaultKeymap.boxed_clone(),
)
})
})
.anchor(gpui::Corner::TopLeft)
.trigger(|open, _, _| {
IconButton::new(
"OpenKeymapJsonButton",
IconName::Json,
)
.icon_size(IconSize::Small)
.when(!open, |this| {
this.tooltip(move |window, cx| {
Tooltip::with_meta(
"Open keymap.json",
Some(&zed_actions::OpenKeymap),
"Right click to view more options",
window,
cx,
)
})
})
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::OpenKeymap.boxed_clone(),
cx,
);
})
}),
),
)
),
)
.when_some(
@@ -1694,48 +1733,42 @@ impl Render for KeymapEditor {
|this, exact_match| {
this.child(
h_flex()
.map(|this| {
if self
.keybinding_conflict_state
.any_user_binding_conflicts()
{
this.pr(rems_from_px(54.))
} else {
this.pr_7()
}
})
.gap_2()
.child(self.keystroke_editor.clone())
.child(
IconButton::new(
"keystrokes-exact-match",
IconName::CaseSensitive,
)
.tooltip({
let keystroke_focus_handle =
self.keystroke_editor.read(cx).focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Toggle Exact Match Mode",
&ToggleExactKeystrokeMatching,
&keystroke_focus_handle,
window,
cx,
h_flex()
.min_w_64()
.child(
IconButton::new(
"keystrokes-exact-match",
IconName::CaseSensitive,
)
}
})
.shape(IconButtonShape::Square)
.toggle_state(exact_match)
.on_click(
cx.listener(|_, _, window, cx| {
window.dispatch_action(
ToggleExactKeystrokeMatching.boxed_clone(),
cx,
);
}),
),
),
.tooltip({
let keystroke_focus_handle =
self.keystroke_editor.read(cx).focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Toggle Exact Match Mode",
&ToggleExactKeystrokeMatching,
&keystroke_focus_handle,
window,
cx,
)
}
})
.shape(IconButtonShape::Square)
.toggle_state(exact_match)
.on_click(
cx.listener(|_, _, window, cx| {
window.dispatch_action(
ToggleExactKeystrokeMatching.boxed_clone(),
cx,
);
}),
),
),
)
)
},
),

View File

@@ -461,7 +461,7 @@ impl Render for KeystrokeInput {
let is_focused = self.outer_focus_handle.contains_focused(window, cx);
let is_recording = self.is_recording(window);
let horizontal_padding = rems_from_px(64.);
let width = rems_from_px(64.);
let recording_bg_color = colors
.editor_background
@@ -528,6 +528,9 @@ impl Render for KeystrokeInput {
h_flex()
.id("keystroke-input")
.track_focus(&self.outer_focus_handle)
.key_context(Self::key_context())
.on_action(cx.listener(Self::start_recording))
.on_action(cx.listener(Self::clear_keystrokes))
.py_2()
.px_3()
.gap_2()
@@ -535,7 +538,7 @@ impl Render for KeystrokeInput {
.w_full()
.flex_1()
.justify_between()
.rounded_sm()
.rounded_md()
.overflow_hidden()
.map(|this| {
if is_recording {
@@ -545,16 +548,16 @@ impl Render for KeystrokeInput {
}
})
.border_1()
.border_color(colors.border_variant)
.when(is_focused, |parent| {
parent.border_color(colors.border_focused)
.map(|this| {
if is_focused {
this.border_color(colors.border_focused)
} else {
this.border_color(colors.border_variant)
}
})
.key_context(Self::key_context())
.on_action(cx.listener(Self::start_recording))
.on_action(cx.listener(Self::clear_keystrokes))
.child(
h_flex()
.w(horizontal_padding)
.w(width)
.gap_0p5()
.justify_start()
.flex_none()
@@ -573,14 +576,13 @@ impl Render for KeystrokeInput {
.id("keystroke-input-inner")
.track_focus(&self.inner_focus_handle)
.on_modifiers_changed(cx.listener(Self::on_modifiers_changed))
.size_full()
.when(!self.search, |this| {
this.focus(|mut style| {
style.border_color = Some(colors.border_focused);
style
})
})
.w_full()
.size_full()
.min_w_0()
.justify_center()
.flex_wrap()
@@ -589,7 +591,7 @@ impl Render for KeystrokeInput {
)
.child(
h_flex()
.w(horizontal_padding)
.w(width)
.gap_0p5()
.justify_end()
.flex_none()
@@ -641,9 +643,7 @@ impl Render for KeystrokeInput {
"Clear Keystrokes",
&ClearKeystrokes,
))
.when(!is_recording || !is_focused, |this| {
this.icon_color(Color::Muted)
})
.when(!is_focused, |this| this.icon_color(Color::Muted))
.on_click(cx.listener(|this, _event, window, cx| {
this.clear_keystrokes(&ClearKeystrokes, window, cx);
})),

View File

@@ -3459,46 +3459,66 @@ impl BufferSnapshot {
}
/// Returns the closest syntax node enclosing the given range.
/// Positions a tree cursor at the leaf node that contains or touches the given range.
/// This is shared logic used by syntax navigation methods.
fn position_cursor_at_range(cursor: &mut tree_sitter::TreeCursor, range: &Range<usize>) {
// Descend to the first leaf that touches the start of the range.
//
// If the range is non-empty and the current node ends exactly at the start,
// move to the next sibling to find a node that extends beyond the start.
//
// If the range is empty and the current node starts after the range position,
// move to the previous sibling to find the node that contains the position.
while cursor.goto_first_child_for_byte(range.start).is_some() {
if !range.is_empty() && cursor.node().end_byte() == range.start {
cursor.goto_next_sibling();
}
if range.is_empty() && cursor.node().start_byte() > range.start {
cursor.goto_previous_sibling();
}
}
}
/// Moves the cursor to find a node that contains the given range.
/// Returns true if such a node is found, false otherwise.
/// This is shared logic used by syntax navigation methods.
fn find_containing_node(
cursor: &mut tree_sitter::TreeCursor,
range: &Range<usize>,
strict: bool,
) -> bool {
loop {
let node_range = cursor.node().byte_range();
if node_range.start <= range.start
&& node_range.end >= range.end
&& (!strict || node_range.len() > range.len())
{
return true;
}
if !cursor.goto_parent() {
return false;
}
}
}
pub fn syntax_ancestor<'a, T: ToOffset>(
&'a self,
range: Range<T>,
) -> Option<tree_sitter::Node<'a>> {
let range = range.start.to_offset(self)..range.end.to_offset(self);
let mut result: Option<tree_sitter::Node<'a>> = None;
'outer: for layer in self
for layer in self
.syntax
.layers_for_range(range.clone(), &self.text, true)
{
let mut cursor = layer.node().walk();
// Descend to the first leaf that touches the start of the range.
//
// If the range is non-empty and the current node ends exactly at the start,
// move to the next sibling to find a node that extends beyond the start.
//
// If the range is empty and the current node starts after the range position,
// move to the previous sibling to find the node that contains the position.
while cursor.goto_first_child_for_byte(range.start).is_some() {
if !range.is_empty() && cursor.node().end_byte() == range.start {
cursor.goto_next_sibling();
}
if range.is_empty() && cursor.node().start_byte() > range.start {
cursor.goto_previous_sibling();
}
}
Self::position_cursor_at_range(&mut cursor, &range);
// Ascend to the smallest ancestor that strictly contains the range.
loop {
let node_range = cursor.node().byte_range();
if node_range.start <= range.start
&& node_range.end >= range.end
&& node_range.len() > range.len()
{
break;
}
if !cursor.goto_parent() {
continue 'outer;
}
if !Self::find_containing_node(&mut cursor, &range, true) {
continue;
}
let left_node = cursor.node();
@@ -3541,6 +3561,112 @@ impl BufferSnapshot {
result
}
/// Find the previous sibling syntax node at the given range.
///
/// This function locates the syntax node that precedes the node containing
/// the given range. It searches hierarchically by:
/// 1. Finding the node that contains the given range
/// 2. Looking for the previous sibling at the same tree level
/// 3. If no sibling is found, moving up to parent levels and searching for siblings
///
/// Returns `None` if there is no previous sibling at any ancestor level.
pub fn syntax_prev_sibling<'a, T: ToOffset>(
&'a self,
range: Range<T>,
) -> Option<tree_sitter::Node<'a>> {
let range = range.start.to_offset(self)..range.end.to_offset(self);
let mut result: Option<tree_sitter::Node<'a>> = None;
for layer in self
.syntax
.layers_for_range(range.clone(), &self.text, true)
{
let mut cursor = layer.node().walk();
Self::position_cursor_at_range(&mut cursor, &range);
// Find the node that contains the range
if !Self::find_containing_node(&mut cursor, &range, false) {
continue;
}
// Look for the previous sibling, moving up ancestor levels if needed
loop {
if cursor.goto_previous_sibling() {
let layer_result = cursor.node();
if let Some(previous_result) = &result {
if previous_result.byte_range().end < layer_result.byte_range().end {
continue;
}
}
result = Some(layer_result);
break;
}
// No sibling found at this level, try moving up to parent
if !cursor.goto_parent() {
break;
}
}
}
result
}
/// Find the next sibling syntax node at the given range.
///
/// This function locates the syntax node that follows the node containing
/// the given range. It searches hierarchically by:
/// 1. Finding the node that contains the given range
/// 2. Looking for the next sibling at the same tree level
/// 3. If no sibling is found, moving up to parent levels and searching for siblings
///
/// Returns `None` if there is no next sibling at any ancestor level.
pub fn syntax_next_sibling<'a, T: ToOffset>(
&'a self,
range: Range<T>,
) -> Option<tree_sitter::Node<'a>> {
let range = range.start.to_offset(self)..range.end.to_offset(self);
let mut result: Option<tree_sitter::Node<'a>> = None;
for layer in self
.syntax
.layers_for_range(range.clone(), &self.text, true)
{
let mut cursor = layer.node().walk();
Self::position_cursor_at_range(&mut cursor, &range);
// Find the node that contains the range
if !Self::find_containing_node(&mut cursor, &range, false) {
continue;
}
// Look for the next sibling, moving up ancestor levels if needed
loop {
if cursor.goto_next_sibling() {
let layer_result = cursor.node();
if let Some(previous_result) = &result {
if previous_result.byte_range().start > layer_result.byte_range().start {
continue;
}
}
result = Some(layer_result);
break;
}
// No sibling found at this level, try moving up to parent
if !cursor.goto_parent() {
break;
}
}
}
result
}
/// Returns the root syntax node within the given row
pub fn syntax_root_ancestor(&self, position: Anchor) -> Option<tree_sitter::Node<'_>> {
let start_offset = position.to_offset(self);

Some files were not shown because too many files have changed in this diff Show More