Compare commits

...

157 Commits

Author SHA1 Message Date
Mikayla
3a1f13645a Improve database and RPC API for moving and linking channels, improve test legibility 2023-09-09 18:20:14 -07:00
Mikayla
bf296ebbd7 Add move, link, and unlink operations 2023-09-09 13:24:04 -07:00
Mikayla
114961fc69 remove extraneous depth field 2023-09-09 12:10:18 -07:00
Mikayla
80f5e66efc Render the DAG 2023-09-09 12:06:17 -07:00
Mikayla
b7172d5e0d Finish integration tests for channel moving
Refactor channel store to combine the channels_by_id and channel_paths into a 'ChannelIndex'
2023-09-09 12:02:40 -07:00
Mikayla
1ab2007fcd WIP: Add channel DAG related RPC messages, change update message 2023-09-09 12:01:15 -07:00
Mikayla
441848d195 Improve channel deletion to be DAG aware 2023-09-09 12:00:36 -07:00
Mikayla
273fa9dd22 Add removing of previous channel, allowing for channel moving operations 2023-09-09 12:00:36 -07:00
Mikayla
e0602da8df Expand DAG tests to include more complex tree operations and removal behavior 2023-09-09 12:00:36 -07:00
Mikayla
65b795c213 Add channel linking operation 2023-09-09 12:00:34 -07:00
Mikayla
fe10ecebb6 Add channel moving test 2023-09-09 12:00:16 -07:00
Conrad Irwin
7cc05c99c2 Update getting started
Just ran through this again.
2023-09-08 23:46:12 -06:00
Conrad Irwin
e29ce489c8 vim: Add ZZ and ZQ (#2950)
The major change here is a refactoring to allow controlling the save
behaviour when closing items, which is pre-work needed for vim command
palette.

For zed-industries/community#1868

Release Notes:

- vim: Add `ZZ` and `ZQ` to close the current item.
([#1868](https://github.com/zed-industries/community/issues/1868))
2023-09-08 16:58:04 -06:00
Conrad Irwin
4c92172cca Partially roll back refactoring 2023-09-08 16:49:50 -06:00
Conrad Irwin
ba1c350dad vim: Add ZZ and ZQ
The major change here is a refactoring to allow controlling the save
behaviour when closing items, which is pre-work needed for vim command
palette.

For zed-industries/community#1868
2023-09-08 16:25:20 -06:00
Conrad Irwin
5d782b6cf0 vim . to replay (#2936)
Release Notes:

- vim: Add `.` to replay
([#946](https://github.com/zed-industries/community/issues/946))
- vim: Fix `J` in visual mode, and with counts.
2023-09-08 11:52:35 -06:00
Conrad Irwin
88dae22e3e Don't replay ShowCharacterPalette 2023-09-08 11:35:00 -06:00
Conrad Irwin
f069cd0485 Fix f,t on soft-wrapped lines (#2940)
Release Notes:

- vim: fix `f` and `t` on softwrapped lines
2023-09-08 11:34:12 -06:00
Joseph T. Lyons
e1d4d911b4 Add tooltip to language selector (#2949)
Release Notes:

- N/A
2023-09-08 12:48:37 -04:00
Joseph T. Lyons
a0701777d5 Make tooltip title case to match other tooltips 2023-09-08 12:44:49 -04:00
Joseph T. Lyons
f4a9d3f269 Add tooltip to language selector 2023-09-08 12:41:32 -04:00
Julia
87472a9de6 Fix Python's cached binary retrieval being borked (#2948)
We fixed this while brainstorming a better approach to handling server
binaries. Since we already have a fix for this one, we might as well
have this not be broken while the new mechanism is being built

Release Notes:

- Fixed Python language server not launching without a network
connection.
2023-09-08 12:21:18 -04:00
Conrad Irwin
5f897f45a8 Fix f,t on soft-wrapped lines
Also remove the (dangerously confusing) display_map.find_while
2023-09-08 10:16:46 -06:00
Julia
74ccb3df63 Fix Python's cached binary retrieval being borked
Co-Authored-By: Max Brunsfeld <max@zed.dev>
2023-09-08 12:09:31 -04:00
Antonio Scandurra
e9747d0fea Find keystrokes defined on a child but handled by a parent (#2947)
This fixes a bug that was preventing keystrokes from being shown on
tooltips for the "Buffer Search" and "Inline Assist" buttons in the
toolbar.

This pull request makes the behavior of `keystrokes_for_action` more
consistent with the behavior of `available_actions`. It seems reasonable
that, if a child view defines a keystroke for an action and that action
is handled on a parent, we should show the child's keystroke.

Release Notes:

- Fixed a bug that was preventing certain keystrokes from being shown in
tooltips.
2023-09-08 14:11:30 +02:00
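A minimal sketch of the lookup rule described above, using made-up types rather than gpui's real API: bindings are searched starting from the focused (child) view and then up its ancestor chain, so a keystroke defined on a child is reported even when the action is ultimately handled by a parent.

```rust
struct View {
    /// (keystroke, action) pairs bound on this view.
    bindings: Vec<(String, String)>,
    /// Index of the parent view in a flat arena, if any.
    parent: Option<usize>,
}

fn keystroke_for_action(views: &[View], focused: usize, action: &str) -> Option<String> {
    let mut current = Some(focused);
    while let Some(ix) = current {
        if let Some((keystroke, _)) = views[ix]
            .bindings
            .iter()
            .find(|(_, bound)| bound.as_str() == action)
        {
            // The child's binding wins because it is found first on the way up.
            return Some(keystroke.clone());
        }
        current = views[ix].parent;
    }
    None
}

fn main() {
    let toolbar_button = View {
        bindings: vec![("cmd-f".into(), "buffer_search::Deploy".into())],
        parent: Some(1),
    };
    let workspace = View { bindings: vec![], parent: None };
    let views = vec![toolbar_button, workspace];
    // The parent handles the action, but the keystroke shown in the tooltip
    // comes from the child that defines it.
    assert_eq!(
        keystroke_for_action(&views, 0, "buffer_search::Deploy").as_deref(),
        Some("cmd-f")
    );
}
```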
Antonio Scandurra
ddc8a126da Find keystrokes defined on a child but handled by a parent
This fixes a bug that was preventing keystrokes from being shown on tooltips
for the "Buffer Search" and "Inline Assist" buttons in the toolbar.

This commit makes the behavior of `keystrokes_for_action` more consistent with
the behavior of `available_actions`. It seems reasonable that, if a child view
defines a keystroke for an action and that action is handled on a parent, we
should show the child's keystroke.
2023-09-08 12:50:59 +02:00
Antonio Scandurra
6ad2ec4825 Make channel notes act as an editor to enable inline assistant (#2946)
I think it should be fine to make channel notes act as editors, so I'll
go ahead and merge this but cc'ing @mikayla-maki and @maxbrunsfeld, in
case I'm overlooking something.

Release Notes:

- Added the inline assistant to channel notes.
2023-09-08 11:51:14 +02:00
Antonio Scandurra
4e818fed4a Make channel notes act as an editor to enable inline assistant 2023-09-08 11:20:49 +02:00
Max Brunsfeld
e7b7ac9d8c Make toolbar horizontal padding more consistent (#2944)
* increase horizontal padding of toolbar itself, remove padding that was
added to individual toolbar items like feedback button.
* make feedback info text and breadcrumbs have the same additional
padding as quick action buttons.

Release Notes:

- Fixed some inconsistencies in the layout of toolbars.
2023-09-07 13:12:26 -07:00
Max Brunsfeld
56d9a578bd Make toolbar horizontal padding more consistent
* increase horizontal padding of toolbar itself, remove padding
  that was added to individual toolbar items like feedback button.
* make feedback info text and breadcrumbs have the same additional padding as
  quick action buttons.
2023-09-07 12:53:46 -07:00
Julia
5b0f4ac9e8 Stop LiveKitBridge Package.resolved from constantly updating (#2943)
Stop that damned LiveKitBridge Package.resolved from continually
changing and make it act more like a lock file

Release Notes:

- N/A
2023-09-07 14:58:55 -04:00
Julia
4d2933a4d7 Include JS template literal as string type for overrides (#2939)
Allows us to trigger Tailwind completions within template literals in
JSX elements

Release Notes:
- Fixed Tailwind autocomplete not appearing in template literals.
2023-09-07 14:58:41 -04:00
Kyle Caverly
560d6b1644 update semantic search to show no results if search query is blank (#2942)
Update semantic search to show no search results if search query is
blank
2023-09-07 14:58:00 -04:00
Julia
a6ce382368 Stop LiveKitBridge Package.resolved from constantly updating 2023-09-07 14:50:36 -04:00
KCaverly
cf5d1d91a4 update semantic search to go to no results if search query is blank 2023-09-07 14:43:41 -04:00
Antonio Scandurra
98999b1e9a Start indexing right away when project was already indexed before (#2941)
Release notes:
- Improved semantic search indexing to start in the background if the
project was already indexed before.
2023-09-07 19:47:26 +02:00
Antonio Scandurra
eda7e00645 Implement SemanticIndex::status and use it in project search
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-07 19:39:30 +02:00
Conrad Irwin
8e2e00e003 add vim-specific J (with repeatability) 2023-09-07 11:08:07 -06:00
Antonio Scandurra
47d7aa0b91 Allow searching before indexing is complete
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-07 19:04:45 +02:00
Antonio Scandurra
65e17e212d Eagerly index project on workspace creation if it was indexed before
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-07 18:51:55 +02:00
Conrad Irwin
48bb2a3321 TEMP 2023-09-07 10:51:18 -06:00
Conrad Irwin
1b1d7f22cc Add visual area repeating 2023-09-07 10:45:38 -06:00
Julia
1969a12a0b Include JS template literal as string type for overrides 2023-09-07 10:55:04 -04:00
Antonio Scandurra
3b784668c0 Rework how we track projects and worktrees in semantic index (#2938)
This pull request introduces several improvements to the semantic search
experience. We're still missing collaboration and searching modified
buffers, which we'll tackle after we take a detour into reducing the
number of tokens used to generate embeddings.

Release Notes:

- Fixed a bug that could prevent semantic search from working when
deploying right after opening a project.
- Fixed a panic that could sometimes occur when using semantic search
while simultaneously changing a file.
- Fixed a bug that prevented semantic search from including new
worktrees when adding them to a project.
2023-09-07 15:30:19 +02:00
Antonio Scandurra
a45c8c380f 💄 2023-09-07 15:25:23 +02:00
Antonio Scandurra
757a285852 Keep dropping the documents table if it exists
This is because we renamed `documents` to `spans`.

Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-07 15:15:16 +02:00
Antonio Scandurra
93b889a93b Merge remote-tracking branch 'origin/main' into semantic-search-watch-worktrees 2023-09-07 15:07:46 +02:00
Antonio Scandurra
3ad1befb11 Remove unneeded logging
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-07 15:07:21 +02:00
Joseph T. Lyons
425a3969c8 Allow call events to be logged without a room id (#2937)
Prior to this PR, we assumed that all call events needed a room_id, but
we now have call-based actions that don't need a room_id - for instance,
you can right click a channel and view the notes while not in a call. In
this case, there is no room_id. We want to be able to track these
events, which requires removing the restriction that requires a room_id.

Release Notes:

- N/A
2023-09-06 23:08:36 -04:00
Joseph T. Lyons
39e13b6675 Allow call events to be logged without a room id 2023-09-06 22:53:05 -04:00
Max Brunsfeld
d03a89ca19 Rejoin channel notes after brief connection loss (#2930)
* [x] Re-send operations that weren't sent while disconnected
* [x] Apply other clients' operations that were missed while
disconnected
* [x] Update collaborators that joined / left while disconnected
* [x] Inform current collaborators that your peer id has changed
* [x] Refresh channel buffer collaborators on server restart
* [x] randomized test
2023-09-06 15:11:21 -07:00
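A rough sketch of the rejoin idea in the checklist above, under a deliberately simplified model (these are not the real collab protocol types): every operation carries a (replica, sequence) stamp, each side keeps a version vector of what it has observed, and on reconnect each side re-sends exactly the operations the other side's version does not cover.

```rust
use std::collections::HashMap;

type ReplicaId = u16;
type Seq = u32;

struct Operation {
    replica: ReplicaId,
    seq: Seq,
    // payload elided
}

/// Version vector: highest sequence number observed per replica.
type Version = HashMap<ReplicaId, Seq>;

/// Operations the other side has not yet observed and therefore needs re-sent.
fn missing_from<'a>(ops: &'a [Operation], remote: &Version) -> Vec<&'a Operation> {
    ops.iter()
        .filter(|op| remote.get(&op.replica).copied().unwrap_or(0) < op.seq)
        .collect()
}

fn main() {
    let local_ops = vec![
        Operation { replica: 0, seq: 1 },
        Operation { replica: 0, seq: 2 },
        Operation { replica: 1, seq: 1 },
    ];
    // While we were disconnected the server only saw replica 0 up to seq 1.
    let server_version: Version = HashMap::from([(0, 1)]);
    // Operations (0, 2) and (1, 1) get re-sent; the server's response would
    // likewise carry the operations and collaborator changes we missed.
    assert_eq!(missing_from(&local_ops, &server_version).len(), 2);
}
```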
Max Brunsfeld
58f58a629b Tolerate channel buffer operations being re-sent 2023-09-06 14:58:25 -07:00
Max Brunsfeld
ed2aed4f93 Update test name in randomized-test-minimize script 2023-09-06 14:29:11 -07:00
Max Brunsfeld
b75e69d31b Check that channel notes text converges in randomized test 2023-09-06 14:25:07 -07:00
Max Brunsfeld
e779adfe46 Add basic randomized integration test for channel notes 2023-09-06 14:09:36 -07:00
Max Brunsfeld
66c3879306 Extract randomized test infrastructure for use in other tests 2023-09-06 14:08:43 -07:00
Conrad Irwin
f22d53eef9 Make test more deterministic
Otherwise these pass only when --features=neovim is set
2023-09-06 14:14:49 -06:00
Conrad Irwin
20f98e4d17 vim . to replay
Co-Authored-By: maxbrunsfeld@gmail.com
2023-09-06 13:49:55 -06:00
Kyle Caverly
bbeb82f884 Token count fix (#2935)
Fix token count for OpenAIEmbeddings

Release Notes (Preview Only)

- update token count calculation for truncated OpenAIEmbeddings
- increased request timeout for OpenAI
2023-09-06 15:15:02 -04:00
KCaverly
265d02a583 update request timeout for open ai embeddings 2023-09-06 15:09:46 -04:00
KCaverly
17237f748c update token_count for OpenAIEmbeddings to accommodate truncation 2023-09-06 15:09:15 -04:00
Joseph T. Lyons
f4237ace40 collab 0.20.0 2023-09-06 13:33:39 -04:00
Joseph T. Lyons
5b5c232cd1 Revert "Temporarily comment out cargo check commands"
This reverts commit 29e35531af.
2023-09-06 12:54:53 -04:00
Joseph T. Lyons
15609b4803 v0.104.x dev 2023-09-06 12:53:50 -04:00
Joseph T. Lyons
29e35531af Temporarily comment out cargo check commands 2023-09-06 12:53:50 -04:00
Nathan Sobo
a2e91e45d9 Use preview server when not on stable (#2909)
This PR updates our client code to connect to preview whenever we're not
on stable. This will make it more likely that we'll be able to
collaborate on a dev build, but obviously won't work if there's a
protocol change on main that hasn't made its way to preview yet.
2023-09-06 10:09:08 -06:00
Julia
246b699bfd Remove NodeRuntime static & add fake implementation for tests (#2934)
Release Notes:

- N/A
2023-09-06 11:27:28 -04:00
Julia
8d672f5d4c Remove NodeRuntime static & add fake implementation for tests 2023-09-06 11:18:55 -04:00
Antonio Scandurra
ce62173534 Rename Document to Span 2023-09-06 17:03:08 +02:00
Antonio Scandurra
de0f53b39f Ensure SemanticIndex::search waits for indexing to complete 2023-09-06 11:40:59 +02:00
Antonio Scandurra
c802680084 Clip ranges returned by SemanticIndex::search
The files may have changed since the last time they were parsed, so the
ranges returned by `SemanticIndex::search` may be out of bounds.
2023-09-06 09:41:51 +02:00
Joseph T. Lyons
9272e9354a Add operation for opening channel notes in channel-based calls (#2933)
Release Notes:

- N/A
2023-09-05 17:24:19 -04:00
Joseph T. Lyons
653d4976cd Add operation for opening channel notes in channel based calls 2023-09-05 17:13:09 -04:00
Max Brunsfeld
ec5ff20b4c Implement clearing stale channel buffer participants on server restart
Co-authored-by: Mikayla <mikayla@zed.dev>
2023-09-05 11:34:24 -07:00
Kyle Caverly
49af2874bb Eager background indexing (#2928)
This PR ships a series of optimizations for the semantic search engine.
Mostly focused on removing invalid states, optimizing requests to
OpenAI, and reducing token usage.

Release Notes (Preview-Only):

- Added eager incremental indexing in the background on a debounce.
- Added a local embeddings cache for reducing redundant calls to OpenAI.
- Moved to an Embeddings Queue model which ensures optimal batch sizes
at the token level, and atomic file & document writes.
- Adjusted OpenAI Embedding API requests to use provided backoff delays
during Rate Limiting.
- Removed flush races between parsing files step and embedding queue
steps.
- Moved truncation to parsing step reducing the probability that OpenAI
encounters bad data.
2023-09-05 13:15:54 -04:00
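A hedged sketch of the "Embeddings Queue" batching idea from the release notes above (illustrative names, not the actual zed types): spans are buffered until the pending batch would exceed the provider's token budget, and each flush becomes a single embedding request, so every API call carries close to the maximum number of tokens allowed per batch.

```rust
struct Span {
    text: String,
    token_count: usize,
}

struct EmbeddingQueue {
    pending: Vec<Span>,
    pending_tokens: usize,
    max_tokens_per_batch: usize,
}

impl EmbeddingQueue {
    /// Queue a span, returning a full batch if adding it would overflow the
    /// token budget.
    fn push(&mut self, span: Span) -> Option<Vec<Span>> {
        let mut flushed = None;
        if self.pending_tokens + span.token_count > self.max_tokens_per_batch {
            flushed = Some(self.flush());
        }
        self.pending_tokens += span.token_count;
        self.pending.push(span);
        flushed
    }

    /// Take the pending batch; in the real system this would become one
    /// embeddings request, written to the database atomically.
    fn flush(&mut self) -> Vec<Span> {
        self.pending_tokens = 0;
        std::mem::take(&mut self.pending)
    }
}

fn main() {
    let mut queue = EmbeddingQueue {
        pending: Vec::new(),
        pending_tokens: 0,
        max_tokens_per_batch: 4096, // assumed budget, for illustration only
    };
    for text in ["fn main() {}", "struct Foo;", "impl Foo {}"] {
        let span = Span { text: text.to_string(), token_count: 2000 };
        if let Some(batch) = queue.push(span) {
            // The third span would overflow the budget, so the first two flush here.
            println!(
                "flushing {} spans ({} chars)",
                batch.len(),
                batch.iter().map(|s| s.text.len()).sum::<usize>()
            );
        }
    }
}
```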
Conrad Irwin
c2c04616b4 vim S (#2929)
Release Notes:
- vim: Add `S` to substitute line ([#1897](https://github.com/zed-industries/community/issues/1897)).
2023-09-05 09:39:08 -06:00
Conrad Irwin
27143e2fb4 Split ContextMenu actions (#2931)
This should have no user-visible impact.

For vim `.` to repeat it's important that actions are replayable.
Currently editor::MoveDown *sometimes* moves the cursor down, and
*sometimes* selects the next completion.

For replay we need to be able to separate the two.
2023-09-05 09:38:08 -06:00
Antonio Scandurra
95b72a73ad Re-index project when a worktree is registered
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-05 17:17:58 +02:00
Antonio Scandurra
3c70b127bd Simplify SemanticIndex::index_project
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-05 16:54:48 +02:00
Nate Butler
4855063151 Fix cropped search filters (#2932)
Because of the way we set up tools that add rows inside the toolbar, it
is complicated to tighten up the spacing inside the toolbar.

This PR just reverts the changes I made previously. We'll need to
properly add rows below the toolbar, instead of rendering search inside
of it, so that tools of unequal height can descend from it.

Release Notes:

- Preview – Fixed an issue where search filters were partially cut off
in the UI.
2023-09-05 10:49:38 -04:00
Nate Butler
e2479a7172 Fix cropped search filters 2023-09-05 10:24:49 -04:00
Antonio Scandurra
6b1dc63fc0 Retrieve embeddings based on pending files
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-05 16:16:12 +02:00
Antonio Scandurra
7b5a41dda2 Move retrieval of embeddings from the db into reindex_changed_files
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-09-05 16:09:24 +02:00
Antonio Scandurra
d4cff68475 🎨 2023-09-05 15:52:36 +02:00
Kirill Bulatov
42976b6014 Add LSP logs clear button (#2913)
LSP logs tend to accumulate and hinder performance (e.g. search is
slower over 20 MB of json files).
Add a way to clear them.

Release Notes:

- N/A
2023-09-05 12:11:35 +03:00
Conrad Irwin
56db21d54b Split ContextMenu actions
This should have no user-visible impact.

For vim `.` to repeat it's important that actions are replayable.
Currently editor::MoveDown *sometimes* moves the cursor down, and
*sometimes* selects the next completion.

For replay we need to be able to separate the two.
2023-09-02 21:04:19 -06:00
Conrad Irwin
55dd0b176c Use consistent naming 2023-09-02 19:52:18 -06:00
Conrad Irwin
3a7b551e33 Fix tests with no neovim 2023-09-02 19:43:05 -06:00
Max Brunsfeld
6827ddf97d Start work on refreshing channel buffer collaborators on server restart 2023-09-01 17:51:00 -07:00
Max Brunsfeld
e6babce556 Broadcast new peer ids for rejoined channel collaborators 2023-09-01 17:23:55 -07:00
Max Brunsfeld
d7e4cb4ab1 executor: timers must be used 2023-09-01 16:52:41 -07:00
Max Brunsfeld
d370c72fbf Start work on rejoining channel buffers 2023-09-01 16:52:12 -07:00
KCaverly
8dbc0fe033 update pragma settings for improved database performance 2023-09-01 17:07:20 -04:00
Conrad Irwin
da16167db1 Fix find_{,preceding}boundary to work on buffer text (#2912)
Fixes movement::find_boundary to work on the buffer, not on display
points.

The user-visible impact is that the "until end of word" commands now
correctly go to the end of a soft-wrapped word (instead of to the first
character of the wrapped line).

It also fixes a bug where the callback passed to these methods was
called with the content of inlay hints.


Release Notes:

- fix finding end of word on soft-wrapped lines
2023-09-01 12:14:16 -07:00
Conrad Irwin
af12977d17 vim: Add S to substitute line
For zed-industries/community#1897
2023-09-01 13:13:59 -06:00
Conrad Irwin
aa7b65bbaf Merge branch 'main' into vim-softwrap-word 2023-09-01 12:23:56 -06:00
Conrad Irwin
0e41c6c5b3 Fix accidental visual selection on scroll (#2927)
Release Notes:

- vim: Fix bug where scrolling vertically would sometimes enter visual
mode
2023-09-01 10:58:10 -07:00
Conrad Irwin
6d7949654b Fix accidental visual selection on scroll
As part of this fix partial page distance calculations to more closely
match vim.
2023-09-01 11:14:27 -06:00
KCaverly
54235f4fb1 updated embeddings background delay to 5 minutes
Co-authored-by: Max <max@zed.dev>
2023-09-01 13:04:09 -04:00
KCaverly
e86964eb5d optimize insert file in vector database
Co-authored-by: Max <max@zed.dev>
2023-09-01 13:01:37 -04:00
KCaverly
524533cfb2 flush embeddings queue when no files are parsed for 250 milliseconds
Co-authored-by: Antonio <antonio@zed.dev>
2023-09-01 11:24:08 -04:00
KCaverly
c4db914f0a move embeddings queue to use single hashmap for all changed paths
Co-authored-by: Antonio <me@as-cii.com>
2023-09-01 08:59:25 -04:00
Antonio Scandurra
2bf417fa45 Avoid duplicate entries in inline assistant's prompt history (#2926)
Release Notes:

- Improved the inline assistant's prompt history to avoid including the
same entry multiple times. (preview-only)
2023-09-01 09:20:14 +02:00
Antonio Scandurra
d868ec920f Avoid duplicate entries in inline assistant's prompt history 2023-09-01 09:15:29 +02:00
Max Brunsfeld
7bcc59c8a5 Remove the concept of a local clock; use lamport clocks for all per-replica versioning (#2924)
### Background

Currently, our CRDT uses three different types of timestamps:

| clock type | representation | purpose |
|------------|----------------|---------|
| `Local` | replica id + u32 | uniquely identifies operations |
| `Lamport` | replica id + u32 | provides a consistent total ordering for all operations |
| `Global` | N local clocks | fully defines the partial ordering between all concurrent operations |

All text operations include *each* type of timestamp. And every
`Fragment` in a buffer's fragment tree contains both a local and a
lamport timestamp.

### Change

An operation can be uniquely identified by its lamport timestamp, so we
don't really need a concept of a local timestamp. In this PR, I've
removed the concept of a local timestamp. Version vectors
(`clock::Global`) now store vectors of *lamport* timestamps.

Eliminating local timestamps reduces the memory footprint of a buffer by
four bytes per fragment, reduces the size of our `UpdateBuffer` RPC
messages, and reduces the amount of data we need to store in our
database for channel buffers. It also makes our CRDT a bit easier to
understand, IMO, because there is now only one scalar value that we
increment per replica.

It's possible I'm missing something here though. @as-cii, @nathansobo
it'd be good to get your 👀
2023-08-31 16:47:08 -07:00
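A minimal sketch of the scheme described above, with assumed names rather than the exact zed `clock` crate: one Lamport counter per replica both uniquely identifies operations and totally orders them, and the version vector (`Global`) records the highest Lamport value observed per replica.

```rust
type ReplicaId = u16;
type Seq = u32;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Lamport {
    value: Seq,
    replica_id: ReplicaId,
}

impl Lamport {
    fn new(replica_id: ReplicaId) -> Self {
        Self { value: 1, replica_id }
    }

    /// Stamp a locally generated operation.
    fn tick(&mut self) -> Self {
        let timestamp = *self;
        self.value += 1;
        timestamp
    }

    /// Advance past a remote operation's timestamp.
    fn observe(&mut self, remote: Self) {
        self.value = self.value.max(remote.value) + 1;
    }
}

/// Version vector: the highest Lamport value seen from each replica,
/// indexed by replica id.
#[derive(Default)]
struct Global(Vec<Seq>);

impl Global {
    fn observe(&mut self, timestamp: Lamport) {
        let ix = timestamp.replica_id as usize;
        if ix >= self.0.len() {
            self.0.resize(ix + 1, 0);
        }
        self.0[ix] = self.0[ix].max(timestamp.value);
    }

    fn observed(&self, timestamp: Lamport) -> bool {
        self.0.get(timestamp.replica_id as usize).copied().unwrap_or(0) >= timestamp.value
    }
}

fn main() {
    let mut replica_a = Lamport::new(0);
    let mut replica_b = Lamport::new(1);
    let op_a = replica_a.tick();
    replica_b.observe(op_a); // b's clock jumps past a's
    let op_b = replica_b.tick();
    assert!(op_a < op_b); // the total order respects causality

    let mut version = Global::default();
    version.observe(op_a);
    assert!(version.observed(op_a) && !version.observed(op_b));
}
```

With the separate `Local` stamp gone, there is a single scalar per replica to increment, which is where the per-fragment memory saving described above comes from.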
Max Brunsfeld
1e60454643 Renumber protobuf fields, bump protocol version 2023-08-31 16:31:26 -07:00
Max Brunsfeld
03f0365d4d Remove local timestamps from CRDT operations
Use lamport timestamps for everything.
2023-08-31 16:23:06 -07:00
KCaverly
afa59abbcd WIP: work towards wiring up an embeddings_for_digest hashmap that is stored for all indexed files 2023-08-31 16:42:39 -04:00
Max Brunsfeld
00aae5abee Assistant: propagate cancel action if there is no pending inline assist (#2923)
Release Notes:

- Fixed a bug where modals could not be dismissed with `escape` when
certain views were active in the workspace (preview only).
2023-08-31 11:17:09 -07:00
Max Brunsfeld
eecd4e39cc Propagate Cancel action if there is no pending inline assist 2023-08-31 11:09:36 -07:00
KCaverly
50cfb067e7 fill embeddings with database values and skip during embeddings queue 2023-08-31 13:19:17 -04:00
Antonio Scandurra
220533ff1a WIP 2023-08-31 18:00:57 +02:00
Antonio Scandurra
2503d54d19 Rename Sha1 to DocumentDigest
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-08-31 18:00:36 +02:00
Antonio Scandurra
3001a46f69 Reify Embedding/Sha1 structs that can be (de)serialized from SQL
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-08-31 17:55:43 +02:00
Kirill Bulatov
fe2300fdaa Style the clear button better, add border to button constructor options 2023-08-31 18:31:21 +03:00
Kirill Bulatov
7b5974e8e9 Add LSP logs clear button 2023-08-31 18:31:21 +03:00
Antonio Scandurra
c763e728d1 Write to and read from the database in a transactional way
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-08-31 16:59:54 +02:00
Antonio Scandurra
35440be98e Abstract away how database transactions are executed
Co-Authored-By: Kyle Caverly <kyle@zed.dev>
2023-08-31 16:54:11 +02:00
Kirill Bulatov
ddc6214216 Tailwind autocomplete (#2920)
Release Notes:
- Added basic Tailwind CSS autocomplete support
([#746](https://github.com/zed-industries/community/issues/746)).
2023-08-31 16:55:46 +03:00
Kirill Bulatov
5731ef51cd Fix plugin LSP adapter interface 2023-08-31 15:32:24 +03:00
Kirill Bulatov
e682db7101 Route completion requests through remote protocol, if needed 2023-08-31 15:22:13 +03:00
Kirill Bulatov
5bc5831032 Fix wrong assertion in the test 2023-08-31 14:31:43 +03:00
Kirill Bulatov
292af55ebc Ensure all client LSP queries are forwarded via collab 2023-08-31 14:29:37 +03:00
Kirill Bulatov
fff385a585 Fix project tests 2023-08-31 13:01:53 +03:00
Kirill Bulatov
9e12df43d0 Post-rebase fixes 2023-08-31 11:53:46 +03:00
Julia
ff3865a4ad Merge branch 'main' into multi-server-completions-tailwind 2023-08-30 22:58:37 -04:00
Julia
529adb95a1 Scope Tailwind in JS/TS to within string
In some situations outside JSX elements Tailwind will never
respond to a completion request, holding up the tsserver completions.

Only submit the request to Tailwind when we wouldn't get tsserver
completions anyway, and don't submit to Tailwind when we know we won't
get Tailwind completions.

Co-Authored-By: Kirill Bulatov <kirill@zed.dev>
2023-08-30 21:14:39 -04:00
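A hedged sketch of the routing rule described above (the scope names and server lists are illustrative, not zed's real language configuration): completion requests made inside a string or template literal go only to the Tailwind server, while requests elsewhere skip Tailwind, so neither server holds up the other.

```rust
enum Scope {
    StringLiteral,
    TemplateLiteral,
    Code,
}

/// Decide which language servers should receive a completion request for the
/// given scope, so a server that cannot answer never delays one that can.
fn completion_servers(scope: Scope) -> &'static [&'static str] {
    match scope {
        // tsserver has nothing to offer inside a string, so only Tailwind is asked.
        Scope::StringLiteral | Scope::TemplateLiteral => &["tailwindcss-language-server"],
        // Outside strings Tailwind would never respond, so it is skipped.
        Scope::Code => &["typescript-language-server"],
    }
}

fn main() {
    assert_eq!(
        completion_servers(Scope::TemplateLiteral)[0],
        "tailwindcss-language-server"
    );
    assert_eq!(completion_servers(Scope::Code)[0], "typescript-language-server");
}
```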
KCaverly
7d4d6c871b fix bug for truncation ensuring no invalid inputs are sent to openai 2023-08-30 17:42:16 -04:00
KCaverly
5abad58b0d moved semantic index to use embeddings queue to batch and managed for atomic database writes
Co-authored-by: Max <max@zed.dev>
2023-08-30 16:58:45 -04:00
KCaverly
76ce52df4e move queuing to embedding_queue functionality and update embedding provider to include trait items for max tokens per batch
Co-authored-by: Max <max@zed.dev>
2023-08-30 16:01:28 -04:00
KCaverly
9781047156 move truncation to parsing step leveraging the EmbeddingProvider trait 2023-08-30 12:13:26 -04:00
KCaverly
76caea80f7 add should_truncate to embedding providers 2023-08-30 11:58:45 -04:00
Kirill Bulatov
7e5735c8f1 Reap overly long LSP requests with a 2m timeout
Co-authored-by: Julia Risley <julia@zed.dev>
2023-08-30 18:41:41 +03:00
KCaverly
e377ada1a9 added token count to documents during parsing 2023-08-30 11:05:46 -04:00
Conrad Irwin
d3650594c3 Fix find_{,preceding}boundary to work on buffer text
Before this change the boundary could mistakenly have happened on a soft
line wrap.

Also handles the interaction with inlays better.
2023-08-29 18:03:29 -07:00
Julia
e3a0252b04 Make multi-server completion requests not serial 2023-08-29 20:42:13 -04:00
KCaverly
a7e6a65deb reindex files in the background after they have not been edited for 10 minutes
Co-authored-by: Max <max@zed.dev>
2023-08-29 17:14:44 -04:00
KCaverly
4f8b95cf0d add proper handling for open ai rate limit delays 2023-08-29 15:44:51 -04:00
Julia
0e6c91818f Woooooops, don't notify the language server until initialized 2023-08-29 15:37:51 -04:00
Nathan Sobo
2d411303bb Use preview server when not on stable 2023-08-29 10:07:22 -06:00
Julia
15628af04b Style language server name in completion menu
Omit in buffers with one or zero running language servers with the
capability to provide completions

Co-Authored-By: Antonio Scandurra <antonio@zed.dev>
2023-08-29 11:21:02 -04:00
Julia
35b7787e02 Add Tailwind server to TSX 2023-08-28 15:19:16 -04:00
Julia
ded6decb29 Initial unstyled language server short name in completions
Co-Authored-By: Kirill Bulatov <kirill@zed.dev>
2023-08-28 11:27:45 -04:00
Julia
fc457d45f5 Add word_characters to language overrides & use for more things
Use word_characters to feed completion trigger characters as well and
also recognize kebab as a potential sub-word splitter. This is fine for
non-kebab-case languages because we'd only ever attempt to split a word
with a kebab in it in language scopes which are kebab-cased

Co-Authored-By: Max Brunsfeld <max@zed.dev>
2023-08-25 18:46:30 -04:00
Julia
a394aaa524 Add Tailwind server to JS/TS 2023-08-23 00:11:15 -04:00
Julia
68408f3838 Add VSCode CSS language server & add Tailwind to .css files 2023-08-22 23:50:40 -04:00
Julia
affb73d651 Only generate workspace/configuration for relevant adapter 2023-08-22 23:36:04 -04:00
Kirill Bulatov
814896de3f Reenable html, remove emmet due to the lack of the code 2023-08-22 12:51:14 +03:00
Kirill Bulatov
a35b3f39c5 Expand word characters for html and css 2023-08-22 12:41:59 +03:00
Piotr Osiewicz
007d1b09ac Z 2819 (#2872)
This PR adds new config option to language config called
`word_boundaries` that controls which characters should be recognised as
word boundary for a given language. This will improve our UX for
languages such as PHP and Tailwind.

Release Notes:

- Improved completions for PHP
[#1820](https://github.com/zed-industries/community/issues/1820)

---------

Co-authored-by: Julia Risley <julia@zed.dev>
2023-08-22 12:23:30 +03:00
Julia
c842e87079 Use updated lsp-types fork branch 2023-08-18 11:57:19 -04:00
Julia
a979e32127 Utilize LSP completion itemDefaults a bit
Tailwind likes to throw a lot of completion data at us; this gets it to
send less. Previously it would respond to a completion with a 2.5 MB JSON
blob; now it is more like 0.8 MB.

Relies on a local copy of lsp-types with the `itemDefaults` field added.
I don't have write perms to push to our fork of the crate atm, sorry :)
2023-08-17 21:57:39 -04:00
Kirill Bulatov
4f0fa21c04 Provide more data to tailwind langserver
Tailwind needs user languages and language-to-language-id mappings to
start providing completions for those languages.
It also has emmet completions disabled by default, so enable them.
2023-08-17 16:14:55 +03:00
Julia
e54f16f372 Register initial request handlers before launching server 2023-08-16 21:25:17 -04:00
Julia
8839b07a25 Add broken Tailwind language server 2023-08-16 11:53:05 -04:00
Julia
40ce099780 Use originating language server to resolve additional completion edits 2023-08-15 16:34:15 -04:00
Julia
7a67ec5743 Add support for querying multiple language servers for completions 2023-08-15 12:48:30 -04:00
148 changed files with 11240 additions and 6252 deletions

77
Cargo.lock generated
View File

@@ -1453,9 +1453,10 @@ dependencies = [
[[package]]
name = "collab"
version = "0.19.0"
version = "0.20.0"
dependencies = [
"anyhow",
"async-trait",
"async-tungstenite",
"audio",
"axum",
@@ -3539,7 +3540,7 @@ dependencies = [
"gif",
"jpeg-decoder",
"num-iter",
"num-rational",
"num-rational 0.3.2",
"num-traits",
"png",
"scoped_threadpool",
@@ -4177,8 +4178,7 @@ dependencies = [
[[package]]
name = "lsp-types"
version = "0.94.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
source = "git+https://github.com/zed-industries/lsp-types?branch=updated-completion-list-item-defaults#90a040a1d195687bd19e1df47463320a44e93d7a"
dependencies = [
"bitflags 1.3.2",
"serde",
@@ -4583,6 +4583,7 @@ dependencies = [
"anyhow",
"async-compression",
"async-tar",
"async-trait",
"futures 0.3.28",
"gpui",
"log",
@@ -4632,6 +4633,31 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "num"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
dependencies = [
"num-bigint 0.2.6",
"num-complex",
"num-integer",
"num-iter",
"num-rational 0.2.4",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
@@ -4660,6 +4686,16 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-complex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-derive"
version = "0.3.3"
@@ -4692,6 +4728,18 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
dependencies = [
"autocfg",
"num-bigint 0.2.6",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.3.2"
@@ -5008,6 +5056,17 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "parse_duration"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
dependencies = [
"lazy_static",
"num",
"regex",
]
[[package]]
name = "password-hash"
version = "0.2.3"
@@ -6663,6 +6722,7 @@ dependencies = [
"anyhow",
"async-trait",
"bincode",
"collections",
"ctor",
"editor",
"env_logger 0.9.3",
@@ -6675,6 +6735,7 @@ dependencies = [
"log",
"matrixmultiply",
"parking_lot 0.11.2",
"parse_duration",
"picker",
"postage",
"pretty_assertions",
@@ -7006,7 +7067,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
dependencies = [
"chrono",
"num-bigint",
"num-bigint 0.4.4",
"num-traits",
"thiserror",
]
@@ -7238,7 +7299,7 @@ dependencies = [
"log",
"md-5",
"memchr",
"num-bigint",
"num-bigint 0.4.4",
"once_cell",
"paste",
"percent-encoding",
@@ -8768,12 +8829,14 @@ dependencies = [
"collections",
"command_palette",
"editor",
"futures 0.3.28",
"gpui",
"indoc",
"itertools",
"language",
"language_selector",
"log",
"lsp",
"nvim-rs",
"parking_lot 0.11.2",
"project",
@@ -9702,7 +9765,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.103.0"
version = "0.104.0"
dependencies = [
"activity_indicator",
"ai",

View File

@@ -8,7 +8,31 @@ Welcome to Zed, a lightning-fast, collaborative code editor that makes your drea
### Dependencies
* Install [Postgres.app](https://postgresapp.com) and start it.
* Install Xcode from https://apps.apple.com/us/app/xcode/id497799835?mt=12, and accept the license:
```
sudo xcodebuild -license
```
* Install homebrew, rust and node
```
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
brew install rust
brew install node
```
* Ensure rust executables are in your $PATH
```
echo $HOME/.cargo/bin | sudo tee /etc/paths.d/10-rust
```
* Install postgres and configure the database
```
brew install postgresql@15
brew services start postgresql@15
psql -c "CREATE ROLE postgres SUPERUSER LOGIN" postgres
psql -U postgres -c "CREATE DATABASE zed"
```
* Install the `LiveKit` server and the `foreman` process supervisor:
```
@@ -41,6 +65,17 @@ Welcome to Zed, a lightning-fast, collaborative code editor that makes your drea
GITHUB_TOKEN=<$token> script/bootstrap
```
* Now try running zed with collaboration disabled:
```
cargo run
```
### Common errors
* `xcrun: error: unable to find utility "metal", not a developer tool or in PATH`
* You need to install Xcode and then run: `xcode-select --switch /Applications/Xcode.app/Contents/Developer`
* (see https://github.com/gfx-rs/gfx/issues/2309)
### Testing against locally-running servers
Start the web and collab servers:

View File

@@ -515,6 +515,17 @@
"enter": "editor::ConfirmCodeAction"
}
},
{
"context": "Editor && (showing_code_actions || showing_completions)",
"bindings": {
"up": "editor::ContextMenuPrev",
"ctrl-p": "editor::ContextMenuPrev",
"down": "editor::ContextMenuNext",
"ctrl-n": "editor::ContextMenuNext",
"pageup": "editor::ContextMenuFirst",
"pagedown": "editor::ContextMenuLast"
}
},
// Custom bindings
{
"bindings": {

View File

@@ -198,6 +198,18 @@
"z c": "editor::Fold",
"z o": "editor::UnfoldLines",
"z f": "editor::FoldSelectedRanges",
"shift-z shift-q": [
"pane::CloseActiveItem",
{
"saveBehavior": "dontSave"
}
],
"shift-z shift-z": [
"pane::CloseActiveItem",
{
"saveBehavior": "promptOnConflict"
}
],
// Count support
"1": [
"vim::Number",
@@ -316,6 +328,7 @@
{
"context": "Editor && vim_mode == normal && (vim_operator == none || vim_operator == n) && !VimWaiting",
"bindings": {
".": "vim::Repeat",
"c": [
"vim::PushOperator",
"Change"
@@ -326,15 +339,12 @@
"Delete"
],
"shift-d": "vim::DeleteToEndOfLine",
"shift-j": "editor::JoinLines",
"shift-j": "vim::JoinLines",
"y": [
"vim::PushOperator",
"Yank"
],
"i": [
"vim::SwitchMode",
"Insert"
],
"i": "vim::InsertBefore",
"shift-i": "vim::InsertFirstNonWhitespace",
"a": "vim::InsertAfter",
"shift-a": "vim::InsertEndOfLine",
@@ -371,6 +381,7 @@
"Replace"
],
"s": "vim::Substitute",
"shift-s": "vim::SubstituteLine",
"> >": "editor::Indent",
"< <": "editor::Outdent",
"ctrl-pagedown": "pane::ActivateNextItem",
@@ -446,13 +457,13 @@
}
],
"s": "vim::Substitute",
"shift-s": "vim::SubstituteLine",
"shift-r": "vim::SubstituteLine",
"c": "vim::Substitute",
"~": "vim::ChangeCase",
"shift-i": [
"vim::SwitchMode",
"Insert"
],
"shift-i": "vim::InsertBefore",
"shift-a": "vim::InsertAfter",
"shift-j": "vim::JoinLines",
"r": [
"vim::PushOperator",
"Replace"

View File

@@ -406,36 +406,30 @@ impl AssistantPanel {
_: &editor::Cancel,
cx: &mut ViewContext<Workspace>,
) {
let panel = if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel
} else {
return;
};
let editor = if let Some(editor) = workspace
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
{
editor
} else {
return;
};
let handled = panel.update(cx, |panel, cx| {
if let Some(assist_id) = panel
.pending_inline_assist_ids_by_editor
.get(&editor.downgrade())
.and_then(|assist_ids| assist_ids.last().copied())
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
if let Some(editor) = workspace
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
{
panel.close_inline_assist(assist_id, true, cx);
true
} else {
false
let handled = panel.update(cx, |panel, cx| {
if let Some(assist_id) = panel
.pending_inline_assist_ids_by_editor
.get(&editor.downgrade())
.and_then(|assist_ids| assist_ids.last().copied())
{
panel.close_inline_assist(assist_id, true, cx);
true
} else {
false
}
});
if handled {
return;
}
}
});
if !handled {
cx.propagate_action();
}
cx.propagate_action();
}
fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
@@ -513,10 +507,13 @@ impl AssistantPanel {
return;
};
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into());
if self.inline_prompt_history.len() > Self::INLINE_PROMPT_HISTORY_MAX_LEN {
self.inline_prompt_history.pop_front();
}
let range = pending_assist.range.clone();
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let selected_text = snapshot

View File

@@ -273,7 +273,13 @@ impl ActiveCall {
.borrow_mut()
.take()
.ok_or_else(|| anyhow!("no incoming call"))?;
Self::report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx);
Self::report_call_event_for_room(
"decline incoming",
Some(call.room_id),
None,
&self.client,
cx,
);
self.client.send(proto::DeclineCall {
room_id: call.room_id,
})?;
@@ -404,21 +410,19 @@ impl ActiveCall {
}
fn report_call_event(&self, operation: &'static str, cx: &AppContext) {
if let Some(room) = self.room() {
let room = room.read(cx);
Self::report_call_event_for_room(
operation,
room.id(),
room.channel_id(),
&self.client,
cx,
)
}
let (room_id, channel_id) = match self.room() {
Some(room) => {
let room = room.read(cx);
(Some(room.id()), room.channel_id())
}
None => (None, None),
};
Self::report_call_event_for_room(operation, room_id, channel_id, &self.client, cx)
}
pub fn report_call_event_for_room(
operation: &'static str,
room_id: u64,
room_id: Option<u64>,
channel_id: Option<u64>,
client: &Arc<Client>,
cx: &AppContext,

View File

@@ -10,6 +10,7 @@ pub(crate) fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer);
client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator);
client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator);
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator);
}
pub struct ChannelBuffer {
@@ -17,6 +18,7 @@ pub struct ChannelBuffer {
connected: bool,
collaborators: Vec<proto::Collaborator>,
buffer: ModelHandle<language::Buffer>,
buffer_epoch: u64,
client: Arc<Client>,
subscription: Option<client::Subscription>,
}
@@ -73,6 +75,7 @@ impl ChannelBuffer {
Self {
buffer,
buffer_epoch: response.epoch,
client,
connected: true,
collaborators,
@@ -82,6 +85,26 @@ impl ChannelBuffer {
}))
}
pub(crate) fn replace_collaborators(
&mut self,
collaborators: Vec<proto::Collaborator>,
cx: &mut ModelContext<Self>,
) {
for old_collaborator in &self.collaborators {
if collaborators
.iter()
.any(|c| c.replica_id == old_collaborator.replica_id)
{
self.buffer.update(cx, |buffer, cx| {
buffer.remove_peer(old_collaborator.replica_id as u16, cx)
});
}
}
self.collaborators = collaborators;
cx.emit(Event::CollaboratorsChanged);
cx.notify();
}
async fn handle_update_channel_buffer(
this: ModelHandle<Self>,
update_channel_buffer: TypedEnvelope<proto::UpdateChannelBuffer>,
@@ -149,6 +172,26 @@ impl ChannelBuffer {
Ok(())
}
async fn handle_update_channel_buffer_collaborator(
this: ModelHandle<Self>,
message: TypedEnvelope<proto::UpdateChannelBufferCollaborator>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
for collaborator in &mut this.collaborators {
if collaborator.peer_id == message.payload.old_peer_id {
collaborator.peer_id = message.payload.new_peer_id;
break;
}
}
cx.emit(Event::CollaboratorsChanged);
cx.notify();
});
Ok(())
}
fn on_buffer_update(
&mut self,
_: ModelHandle<language::Buffer>,
@@ -166,6 +209,10 @@ impl ChannelBuffer {
}
}
pub fn epoch(&self) -> u64 {
self.buffer_epoch
}
pub fn buffer(&self) -> ModelHandle<language::Buffer> {
self.buffer.clone()
}
@@ -179,6 +226,7 @@ impl ChannelBuffer {
}
pub(crate) fn disconnect(&mut self, cx: &mut ModelContext<Self>) {
log::info!("channel buffer {} disconnected", self.channel.id);
if self.connected {
self.connected = false;
self.subscription.take();

View File

@@ -1,18 +1,24 @@
mod channel_index;
use crate::channel_buffer::ChannelBuffer;
use anyhow::{anyhow, Result};
use client::{Client, Status, Subscription, User, UserId, UserStore};
use client::{Client, Subscription, User, UserId, UserStore};
use collections::{hash_map, HashMap, HashSet};
use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use rpc::{proto, TypedEnvelope};
use std::sync::Arc;
use std::{mem, sync::Arc, time::Duration};
use util::ResultExt;
use self::channel_index::ChannelIndex;
pub use self::channel_index::ChannelPath;
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub type ChannelId = u64;
pub struct ChannelStore {
channels_by_id: HashMap<ChannelId, Arc<Channel>>,
channel_paths: Vec<Vec<ChannelId>>,
channel_index: ChannelIndex,
channel_invitations: Vec<Arc<Channel>>,
channel_participants: HashMap<ChannelId, Vec<Arc<User>>>,
channels_with_admin_privileges: HashSet<ChannelId>,
@@ -22,7 +28,8 @@ pub struct ChannelStore {
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
_rpc_subscription: Subscription,
_watch_connection_status: Task<()>,
_watch_connection_status: Task<Option<()>>,
disconnect_channel_buffers_task: Option<Task<()>>,
_update_channels: Task<()>,
}
@@ -67,30 +74,25 @@ impl ChannelStore {
let rpc_subscription =
client.add_message_handler(cx.handle(), Self::handle_update_channels);
let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
let mut connection_status = client.status();
let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
while let Some(status) = connection_status.next().await {
if !status.is_connected() {
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
if matches!(status, Status::ConnectionLost | Status::SignedOut) {
this.handle_disconnect(cx);
} else {
this.disconnect_buffers(cx);
}
});
} else {
break;
}
let this = this.upgrade(&cx)?;
if status.is_connected() {
this.update(&mut cx, |this, cx| this.handle_connect(cx))
.await
.log_err()?;
} else {
this.update(&mut cx, |this, cx| this.handle_disconnect(cx));
}
}
Some(())
});
Self {
channels_by_id: HashMap::default(),
channel_invitations: Vec::default(),
channel_paths: Vec::default(),
channel_index: ChannelIndex::default(),
channel_participants: Default::default(),
channels_with_admin_privileges: Default::default(),
outgoing_invites: Default::default(),
@@ -100,6 +102,7 @@ impl ChannelStore {
user_store,
_rpc_subscription: rpc_subscription,
_watch_connection_status: watch_connection_status,
disconnect_channel_buffers_task: None,
_update_channels: cx.spawn_weak(|this, mut cx| async move {
while let Some(update_channels) = update_channels_rx.next().await {
if let Some(this) = this.upgrade(&cx) {
@@ -116,7 +119,7 @@ impl ChannelStore {
}
pub fn has_children(&self, channel_id: ChannelId) -> bool {
self.channel_paths.iter().any(|path| {
self.channel_index.iter().any(|path| {
if let Some(ix) = path.iter().position(|id| *id == channel_id) {
path.len() > ix + 1
} else {
@@ -126,22 +129,23 @@ impl ChannelStore {
}
pub fn channel_count(&self) -> usize {
self.channel_paths.len()
self.channel_index.len()
}
pub fn channels(&self) -> impl '_ + Iterator<Item = (usize, &Arc<Channel>)> {
self.channel_paths.iter().map(move |path| {
self.channel_index.iter().map(move |path| {
let id = path.last().unwrap();
let channel = self.channel_for_id(*id).unwrap();
(path.len() - 1, channel)
})
}
pub fn channel_at_index(&self, ix: usize) -> Option<(usize, &Arc<Channel>)> {
let path = self.channel_paths.get(ix)?;
pub fn channel_at_index(&self, ix: usize) -> Option<(&Arc<Channel>, &ChannelPath)> {
let path = self.channel_index.get(ix)?;
let id = path.last().unwrap();
let channel = self.channel_for_id(*id).unwrap();
Some((path.len() - 1, channel))
Some((channel, path))
}
pub fn channel_invitations(&self) -> &[Arc<Channel>] {
@@ -149,7 +153,16 @@ impl ChannelStore {
}
pub fn channel_for_id(&self, channel_id: ChannelId) -> Option<&Arc<Channel>> {
self.channels_by_id.get(&channel_id)
self.channel_index.by_id().get(&channel_id)
}
pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool {
if let Some(buffer) = self.opened_buffers.get(&channel_id) {
if let OpenedChannelBuffer::Open(buffer) = buffer {
return buffer.upgrade(cx).is_some();
}
}
false
}
pub fn open_channel_buffer(
@@ -221,7 +234,7 @@ impl ChannelStore {
}
pub fn is_user_admin(&self, channel_id: ChannelId) -> bool {
self.channel_paths.iter().any(|path| {
self.channel_index.iter().any(|path| {
if let Some(ix) = path.iter().position(|id| *id == channel_id) {
path[..=ix]
.iter()
@@ -276,6 +289,59 @@ impl ChannelStore {
})
}
pub fn link_channel(
&mut self,
channel_id: ChannelId,
to: ChannelId,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(|_, _| async move {
let _ = client
.request(proto::LinkChannel { channel_id, to })
.await?;
Ok(())
})
}
pub fn unlink_channel(
&mut self,
channel_id: ChannelId,
from: Option<ChannelId>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(|_, _| async move {
let _ = client
.request(proto::UnlinkChannel { channel_id, from })
.await?;
Ok(())
})
}
pub fn move_channel(
&mut self,
channel_id: ChannelId,
from: Option<ChannelId>,
to: ChannelId,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(|_, _| async move {
let _ = client
.request(proto::MoveChannel {
channel_id,
from,
to,
})
.await?;
Ok(())
})
}
pub fn invite_member(
&mut self,
channel_id: ChannelId,
@@ -455,7 +521,7 @@ impl ChannelStore {
pub fn remove_channel(&self, channel_id: ChannelId) -> impl Future<Output = Result<()>> {
let client = self.client.clone();
async move {
client.request(proto::RemoveChannel { channel_id }).await?;
client.request(proto::DeleteChannel { channel_id }).await?;
Ok(())
}
}
@@ -482,25 +548,130 @@ impl ChannelStore {
Ok(())
}
fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) {
self.disconnect_buffers(cx);
self.channels_by_id.clear();
self.channel_invitations.clear();
self.channel_participants.clear();
self.channels_with_admin_privileges.clear();
self.channel_paths.clear();
self.outgoing_invites.clear();
cx.notify();
}
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
self.disconnect_channel_buffers_task.take();
fn disconnect_buffers(&mut self, cx: &mut ModelContext<ChannelStore>) {
for (_, buffer) in self.opened_buffers.drain() {
let mut buffer_versions = Vec::new();
for buffer in self.opened_buffers.values() {
if let OpenedChannelBuffer::Open(buffer) = buffer {
if let Some(buffer) = buffer.upgrade(cx) {
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
let channel_buffer = buffer.read(cx);
let buffer = channel_buffer.buffer().read(cx);
buffer_versions.push(proto::ChannelBufferVersion {
channel_id: channel_buffer.channel().id,
epoch: channel_buffer.epoch(),
version: language::proto::serialize_version(&buffer.version()),
});
}
}
}
if buffer_versions.is_empty() {
return Task::ready(Ok(()));
}
let response = self.client.request(proto::RejoinChannelBuffers {
buffers: buffer_versions,
});
cx.spawn(|this, mut cx| async move {
let mut response = response.await?;
this.update(&mut cx, |this, cx| {
this.opened_buffers.retain(|_, buffer| match buffer {
OpenedChannelBuffer::Open(channel_buffer) => {
let Some(channel_buffer) = channel_buffer.upgrade(cx) else {
return false;
};
channel_buffer.update(cx, |channel_buffer, cx| {
let channel_id = channel_buffer.channel().id;
if let Some(remote_buffer) = response
.buffers
.iter_mut()
.find(|buffer| buffer.channel_id == channel_id)
{
let channel_id = channel_buffer.channel().id;
let remote_version =
language::proto::deserialize_version(&remote_buffer.version);
channel_buffer.replace_collaborators(
mem::take(&mut remote_buffer.collaborators),
cx,
);
let operations = channel_buffer
.buffer()
.update(cx, |buffer, cx| {
let outgoing_operations =
buffer.serialize_ops(Some(remote_version), cx);
let incoming_operations =
mem::take(&mut remote_buffer.operations)
.into_iter()
.map(language::proto::deserialize_operation)
.collect::<Result<Vec<_>>>()?;
buffer.apply_ops(incoming_operations, cx)?;
anyhow::Ok(outgoing_operations)
})
.log_err();
if let Some(operations) = operations {
let client = this.client.clone();
cx.background()
.spawn(async move {
let operations = operations.await;
for chunk in
language::proto::split_operations(operations)
{
client
.send(proto::UpdateChannelBuffer {
channel_id,
operations: chunk,
})
.ok();
}
})
.detach();
return true;
}
}
channel_buffer.disconnect(cx);
false
})
}
OpenedChannelBuffer::Loading(_) => true,
});
});
anyhow::Ok(())
})
}
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
self.channel_index.clear();
self.channel_invitations.clear();
self.channel_participants.clear();
self.channels_with_admin_privileges.clear();
self.channel_index.clear();
self.outgoing_invites.clear();
cx.notify();
self.disconnect_channel_buffers_task.get_or_insert_with(|| {
cx.spawn_weak(|this, mut cx| async move {
cx.background().timer(RECONNECT_TIMEOUT).await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
for (_, buffer) in this.opened_buffers.drain() {
if let OpenedChannelBuffer::Open(buffer) = buffer {
if let Some(buffer) = buffer.upgrade(cx) {
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
}
}
}
});
}
})
});
}
pub(crate) fn update_channels(
@@ -528,17 +699,16 @@ impl ChannelStore {
}
}
let channels_changed = !payload.channels.is_empty() || !payload.remove_channels.is_empty();
let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty();
if channels_changed {
if !payload.remove_channels.is_empty() {
self.channels_by_id
.retain(|channel_id, _| !payload.remove_channels.contains(channel_id));
if !payload.delete_channels.is_empty() {
self.channel_index.delete_channels(&payload.delete_channels);
self.channel_participants
.retain(|channel_id, _| !payload.remove_channels.contains(channel_id));
.retain(|channel_id, _| !payload.delete_channels.contains(channel_id));
self.channels_with_admin_privileges
.retain(|channel_id| !payload.remove_channels.contains(channel_id));
.retain(|channel_id| !payload.delete_channels.contains(channel_id));
for channel_id in &payload.remove_channels {
for channel_id in &payload.delete_channels {
let channel_id = *channel_id;
if let Some(OpenedChannelBuffer::Open(buffer)) =
self.opened_buffers.remove(&channel_id)
@@ -550,44 +720,15 @@ impl ChannelStore {
}
}
for channel_proto in payload.channels {
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
Arc::make_mut(existing_channel).name = channel_proto.name;
} else {
let channel = Arc::new(Channel {
id: channel_proto.id,
name: channel_proto.name,
});
self.channels_by_id.insert(channel.id, channel.clone());
if let Some(parent_id) = channel_proto.parent_id {
let mut ix = 0;
while ix < self.channel_paths.len() {
let path = &self.channel_paths[ix];
if path.ends_with(&[parent_id]) {
let mut new_path = path.clone();
new_path.push(channel.id);
self.channel_paths.insert(ix + 1, new_path);
ix += 1;
}
ix += 1;
}
} else {
self.channel_paths.push(vec![channel.id]);
}
}
let mut channel_index = self.channel_index.start_upsert();
for channel in payload.channels {
channel_index.upsert(channel)
}
}
self.channel_paths.sort_by(|a, b| {
let a = Self::channel_path_sorting_key(a, &self.channels_by_id);
let b = Self::channel_path_sorting_key(b, &self.channels_by_id);
a.cmp(b)
});
self.channel_paths.dedup();
self.channel_paths.retain(|path| {
path.iter()
.all(|channel_id| self.channels_by_id.contains_key(channel_id))
});
for edge in payload.delete_channel_edge {
self.channel_index
.delete_edge(edge.parent_id, edge.channel_id);
}
for permission in payload.channel_permissions {
@@ -645,12 +786,4 @@ impl ChannelStore {
anyhow::Ok(())
}))
}
fn channel_path_sorting_key<'a>(
path: &'a [ChannelId],
channels_by_id: &'a HashMap<ChannelId, Arc<Channel>>,
) -> impl 'a + Iterator<Item = Option<&'a str>> {
path.iter()
.map(|id| Some(channels_by_id.get(id)?.name.as_str()))
}
}

View File

@@ -0,0 +1,161 @@
use std::{sync::Arc, ops::Deref};
use collections::HashMap;
use rpc::proto;
use serde_derive::{Serialize, Deserialize};
use crate::{ChannelId, Channel};
pub type ChannelsById = HashMap<ChannelId, Arc<Channel>>;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ChannelPath(Arc<[ChannelId]>);
impl Deref for ChannelPath {
type Target = [ChannelId];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl ChannelPath {
pub fn parent_id(&self) -> Option<ChannelId> {
self.0.len().checked_sub(2).map(|i| {
self.0[i]
})
}
}
impl Default for ChannelPath {
fn default() -> Self {
ChannelPath(Arc::from([]))
}
}
#[derive(Default, Debug)]
pub struct ChannelIndex {
paths: Vec<ChannelPath>,
channels_by_id: ChannelsById,
}
impl ChannelIndex {
pub fn by_id(&self) -> &ChannelsById {
&self.channels_by_id
}
pub fn clear(&mut self) {
self.paths.clear();
self.channels_by_id.clear();
}
pub fn len(&self) -> usize {
self.paths.len()
}
pub fn get(&self, idx: usize) -> Option<&ChannelPath> {
self.paths.get(idx)
}
pub fn iter(&self) -> impl Iterator<Item = &ChannelPath> {
self.paths.iter()
}
/// Remove the given edge from this index. This will not remove the channel
/// and may result in dangling channels.
pub fn delete_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
self.paths.retain(|path| {
!path
.windows(2)
.any(|window| window == [parent_id, channel_id])
});
}
/// Delete the given channels from this index.
pub fn delete_channels(&mut self, channels: &[ChannelId]) {
self.channels_by_id.retain(|channel_id, _| !channels.contains(channel_id));
self.paths.retain(|channel_path| !channel_path.iter().any(|channel_id| {channels.contains(channel_id)}))
}
/// Upsert one or more channels into this index.
pub fn start_upsert(& mut self) -> ChannelPathsUpsertGuard {
ChannelPathsUpsertGuard {
paths: &mut self.paths,
channels_by_id: &mut self.channels_by_id,
}
}
}
/// A guard for ensuring that the paths index maintains its sort and uniqueness
/// invariants after a series of insertions
pub struct ChannelPathsUpsertGuard<'a> {
paths: &'a mut Vec<ChannelPath>,
channels_by_id: &'a mut ChannelsById,
}
impl<'a> ChannelPathsUpsertGuard<'a> {
pub fn upsert(&mut self, channel_proto: proto::Channel) {
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
Arc::make_mut(existing_channel).name = channel_proto.name;
if let Some(parent_id) = channel_proto.parent_id {
self.insert_edge(parent_id, channel_proto.id)
}
} else {
let channel = Arc::new(Channel {
id: channel_proto.id,
name: channel_proto.name,
});
self.channels_by_id.insert(channel.id, channel.clone());
if let Some(parent_id) = channel_proto.parent_id {
self.insert_edge(parent_id, channel.id);
} else {
self.insert_root(channel.id);
}
}
}
fn insert_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
let mut ix = 0;
while ix < self.paths.len() {
let path = &self.paths[ix];
if path.ends_with(&[parent_id]) {
let mut new_path = path.to_vec();
new_path.push(channel_id);
self.paths.insert(ix + 1, ChannelPath(new_path.into()));
ix += 1;
}
ix += 1;
}
}
fn insert_root(&mut self, channel_id: ChannelId) {
self.paths.push(ChannelPath(Arc::from([channel_id])));
}
}
impl<'a> Drop for ChannelPathsUpsertGuard<'a> {
fn drop(&mut self) {
self.paths.sort_by(|a, b| {
let a = channel_path_sorting_key(a, &self.channels_by_id);
let b = channel_path_sorting_key(b, &self.channels_by_id);
a.cmp(b)
});
self.paths.dedup();
self.paths.retain(|path| {
path.iter()
.all(|channel_id| self.channels_by_id.contains_key(channel_id))
});
}
}
fn channel_path_sorting_key<'a>(
path: &'a [ChannelId],
channels_by_id: &'a ChannelsById,
) -> impl 'a + Iterator<Item = Option<&'a str>> {
path.iter()
.map(|id| Some(channels_by_id.get(id)?.name.as_str()))
}

View File

@@ -127,7 +127,7 @@ fn test_dangling_channel_paths(cx: &mut AppContext) {
update_channels(
&channel_store,
proto::UpdateChannels {
remove_channels: vec![1, 2],
delete_channels: vec![1, 2],
..Default::default()
},
cx,


@@ -1011,9 +1011,9 @@ impl Client {
credentials: &Credentials,
cx: &AsyncAppContext,
) -> Task<Result<Connection, EstablishConnectionError>> {
let is_preview = cx.read(|cx| {
let use_preview_server = cx.read(|cx| {
if cx.has_global::<ReleaseChannel>() {
*cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
*cx.global::<ReleaseChannel>() != ReleaseChannel::Stable
} else {
false
}
@@ -1028,7 +1028,7 @@ impl Client {
let http = self.http.clone();
cx.background().spawn(async move {
let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
let rpc_host = rpc_url
.host_str()
.zip(rpc_url.port_or_known_default())


@@ -73,7 +73,7 @@ pub enum ClickhouseEvent {
},
Call {
operation: &'static str,
room_id: u64,
room_id: Option<u64>,
channel_id: Option<u64>,
},
}


@@ -2,70 +2,17 @@ use smallvec::SmallVec;
use std::{
cmp::{self, Ordering},
fmt, iter,
ops::{Add, AddAssign},
};
pub type ReplicaId = u16;
pub type Seq = u32;
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Ord, PartialOrd)]
pub struct Local {
pub replica_id: ReplicaId,
pub value: Seq,
}
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
pub struct Lamport {
pub replica_id: ReplicaId,
pub value: Seq,
}
impl Local {
pub const MIN: Self = Self {
replica_id: ReplicaId::MIN,
value: Seq::MIN,
};
pub const MAX: Self = Self {
replica_id: ReplicaId::MAX,
value: Seq::MAX,
};
pub fn new(replica_id: ReplicaId) -> Self {
Self {
replica_id,
value: 1,
}
}
pub fn tick(&mut self) -> Self {
let timestamp = *self;
self.value += 1;
timestamp
}
pub fn observe(&mut self, timestamp: Self) {
if timestamp.replica_id == self.replica_id {
self.value = cmp::max(self.value, timestamp.value + 1);
}
}
}
impl<'a> Add<&'a Self> for Local {
type Output = Local;
fn add(self, other: &'a Self) -> Self::Output {
*cmp::max(&self, other)
}
}
impl<'a> AddAssign<&'a Local> for Local {
fn add_assign(&mut self, other: &Self) {
if *self < *other {
*self = *other;
}
}
}
/// A vector clock
#[derive(Clone, Default, Hash, Eq, PartialEq)]
pub struct Global(SmallVec<[u32; 8]>);
@@ -79,7 +26,7 @@ impl Global {
self.0.get(replica_id as usize).copied().unwrap_or(0) as Seq
}
pub fn observe(&mut self, timestamp: Local) {
pub fn observe(&mut self, timestamp: Lamport) {
if timestamp.value > 0 {
let new_len = timestamp.replica_id as usize + 1;
if new_len > self.0.len() {
@@ -126,7 +73,7 @@ impl Global {
self.0.resize(new_len, 0);
}
pub fn observed(&self, timestamp: Local) -> bool {
pub fn observed(&self, timestamp: Lamport) -> bool {
self.get(timestamp.replica_id) >= timestamp.value
}
@@ -178,16 +125,16 @@ impl Global {
false
}
pub fn iter(&self) -> impl Iterator<Item = Local> + '_ {
self.0.iter().enumerate().map(|(replica_id, seq)| Local {
pub fn iter(&self) -> impl Iterator<Item = Lamport> + '_ {
self.0.iter().enumerate().map(|(replica_id, seq)| Lamport {
replica_id: replica_id as ReplicaId,
value: *seq,
})
}
}
impl FromIterator<Local> for Global {
fn from_iter<T: IntoIterator<Item = Local>>(locals: T) -> Self {
impl FromIterator<Lamport> for Global {
fn from_iter<T: IntoIterator<Item = Lamport>>(locals: T) -> Self {
let mut result = Self::new();
for local in locals {
result.observe(local);
@@ -212,6 +159,16 @@ impl PartialOrd for Lamport {
}
impl Lamport {
pub const MIN: Self = Self {
replica_id: ReplicaId::MIN,
value: Seq::MIN,
};
pub const MAX: Self = Self {
replica_id: ReplicaId::MAX,
value: Seq::MAX,
};
pub fn new(replica_id: ReplicaId) -> Self {
Self {
value: 1,
@@ -230,12 +187,6 @@ impl Lamport {
}
}
impl fmt::Debug for Local {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Local {{{}: {}}}", self.replica_id, self.value)
}
}
impl fmt::Debug for Lamport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Lamport {{{}: {}}}", self.replica_id, self.value)


@@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
default-run = "collab"
edition = "2021"
name = "collab"
version = "0.19.0"
version = "0.20.0"
publish = false
[[bin]]
@@ -72,7 +72,6 @@ fs = { path = "../fs", features = ["test-support"] }
git = { path = "../git", features = ["test-support"] }
live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
pretty_assertions.workspace = true
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
@@ -80,6 +79,8 @@ theme = { path = "../theme" }
workspace = { path = "../workspace", features = ["test-support"] }
collab_ui = { path = "../collab_ui", features = ["test-support"] }
async-trait.workspace = true
pretty_assertions.workspace = true
ctor.workspace = true
env_logger.workspace = true
indoc.workspace = true


@@ -435,6 +435,12 @@ pub struct ChannelsForUser {
pub channels_with_admin_privileges: HashSet<ChannelId>,
}
#[derive(Debug)]
pub struct RejoinedChannelBuffer {
pub buffer: proto::RejoinedChannelBuffer,
pub old_connection_id: ConnectionId,
}
#[derive(Clone)]
pub struct JoinRoom {
pub room: proto::Room,
@@ -498,6 +504,11 @@ pub struct RefreshedRoom {
pub canceled_calls_to_user_ids: Vec<UserId>,
}
pub struct RefreshedChannelBuffer {
pub connection_ids: Vec<ConnectionId>,
pub removed_collaborators: Vec<proto::RemoveChannelBufferCollaborator>,
}
pub struct Project {
pub collaborators: Vec<ProjectCollaborator>,
pub worktrees: BTreeMap<u64, Worktree>,


@@ -1,6 +1,6 @@
use super::*;
use prost::Message;
use text::{EditOperation, InsertionTimestamp, UndoOperation};
use text::{EditOperation, UndoOperation};
impl Database {
pub async fn join_channel_buffer(
@@ -10,8 +10,6 @@ impl Database {
connection: ConnectionId,
) -> Result<proto::JoinChannelBufferResponse> {
self.transaction(|tx| async move {
let tx = tx;
self.check_user_is_channel_member(channel_id, user_id, &tx)
.await?;
@@ -70,7 +68,6 @@ impl Database {
.await?;
collaborators.push(collaborator);
// Assemble the buffer state
let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
Ok(proto::JoinChannelBufferResponse {
@@ -78,6 +75,7 @@ impl Database {
replica_id: replica_id.to_proto() as u32,
base_text,
operations,
epoch: buffer.epoch as u64,
collaborators: collaborators
.into_iter()
.map(|collaborator| proto::Collaborator {
@@ -91,6 +89,154 @@ impl Database {
.await
}
pub async fn rejoin_channel_buffers(
&self,
buffers: &[proto::ChannelBufferVersion],
user_id: UserId,
connection_id: ConnectionId,
) -> Result<Vec<RejoinedChannelBuffer>> {
self.transaction(|tx| async move {
let mut results = Vec::new();
for client_buffer in buffers {
let channel_id = ChannelId::from_proto(client_buffer.channel_id);
if self
.check_user_is_channel_member(channel_id, user_id, &*tx)
.await
.is_err()
{
log::info!("user is not a member of channel");
continue;
}
let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
let mut collaborators = channel_buffer_collaborator::Entity::find()
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
.all(&*tx)
.await?;
// If the buffer epoch hasn't changed since the client lost
// connection, then the client's buffer can be synchronized with
// the server's buffer.
if buffer.epoch as u64 != client_buffer.epoch {
log::info!("can't rejoin buffer, epoch has changed");
continue;
}
// Find the collaborator record for this user's previous lost
// connection. Update it with the new connection id.
let server_id = ServerId(connection_id.owner_id as i32);
let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
c.user_id == user_id
&& (c.connection_lost || c.connection_server_id != server_id)
}) else {
log::info!("can't rejoin buffer, no previous collaborator found");
continue;
};
let old_connection_id = self_collaborator.connection();
*self_collaborator = channel_buffer_collaborator::ActiveModel {
id: ActiveValue::Unchanged(self_collaborator.id),
connection_id: ActiveValue::Set(connection_id.id as i32),
connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
connection_lost: ActiveValue::Set(false),
..Default::default()
}
.update(&*tx)
.await?;
let client_version = version_from_wire(&client_buffer.version);
let serialization_version = self
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
.await?;
let mut rows = buffer_operation::Entity::find()
.filter(
buffer_operation::Column::BufferId
.eq(buffer.id)
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
)
.stream(&*tx)
.await?;
// Find the server's version vector and any operations
// that the client has not seen.
let mut server_version = clock::Global::new();
let mut operations = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let timestamp = clock::Lamport {
replica_id: row.replica_id as u16,
value: row.lamport_timestamp as u32,
};
server_version.observe(timestamp);
if !client_version.observed(timestamp) {
operations.push(proto::Operation {
variant: Some(operation_from_storage(row, serialization_version)?),
})
}
}
results.push(RejoinedChannelBuffer {
old_connection_id,
buffer: proto::RejoinedChannelBuffer {
channel_id: client_buffer.channel_id,
version: version_to_wire(&server_version),
operations,
collaborators: collaborators
.into_iter()
.map(|collaborator| proto::Collaborator {
peer_id: Some(collaborator.connection().into()),
user_id: collaborator.user_id.to_proto(),
replica_id: collaborator.replica_id.0 as u32,
})
.collect(),
},
});
}
Ok(results)
})
.await
}
pub async fn clear_stale_channel_buffer_collaborators(
&self,
channel_id: ChannelId,
server_id: ServerId,
) -> Result<RefreshedChannelBuffer> {
self.transaction(|tx| async move {
let collaborators = channel_buffer_collaborator::Entity::find()
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
.all(&*tx)
.await?;
let mut connection_ids = Vec::new();
let mut removed_collaborators = Vec::new();
let mut collaborator_ids_to_remove = Vec::new();
for collaborator in &collaborators {
if !collaborator.connection_lost && collaborator.connection_server_id == server_id {
connection_ids.push(collaborator.connection());
} else {
removed_collaborators.push(proto::RemoveChannelBufferCollaborator {
channel_id: channel_id.to_proto(),
peer_id: Some(collaborator.connection().into()),
});
collaborator_ids_to_remove.push(collaborator.id);
}
}
channel_buffer_collaborator::Entity::delete_many()
.filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
.exec(&*tx)
.await?;
Ok(RefreshedChannelBuffer {
connection_ids,
removed_collaborators,
})
})
.await
}
pub async fn leave_channel_buffer(
&self,
channel_id: ChannelId,
@@ -103,6 +249,39 @@ impl Database {
.await
}
pub async fn leave_channel_buffers(
&self,
connection: ConnectionId,
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
self.transaction(|tx| async move {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryChannelIds {
ChannelId,
}
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::ChannelId)
.filter(Condition::all().add(
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
))
.into_values::<_, QueryChannelIds>()
.all(&*tx)
.await?;
let mut result = Vec::new();
for channel_id in channel_ids {
let collaborators = self
.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await?;
result.push((channel_id, collaborators));
}
Ok(result)
})
.await
}
pub async fn leave_channel_buffer_internal(
&self,
channel_id: ChannelId,
@@ -143,46 +322,12 @@ impl Database {
drop(rows);
if connections.is_empty() {
self.snapshot_buffer(channel_id, &tx).await?;
self.snapshot_channel_buffer(channel_id, &tx).await?;
}
Ok(connections)
}
pub async fn leave_channel_buffers(
&self,
connection: ConnectionId,
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
self.transaction(|tx| async move {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryChannelIds {
ChannelId,
}
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::ChannelId)
.filter(Condition::all().add(
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
))
.into_values::<_, QueryChannelIds>()
.all(&*tx)
.await?;
let mut result = Vec::new();
for channel_id in channel_ids {
let collaborators = self
.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await?;
result.push((channel_id, collaborators));
}
Ok(result)
})
.await
}
#[cfg(debug_assertions)]
pub async fn get_channel_buffer_collaborators(
&self,
channel_id: ChannelId,
@@ -225,20 +370,9 @@ impl Database {
.await?
.ok_or_else(|| anyhow!("no such buffer"))?;
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryVersion {
OperationSerializationVersion,
}
let serialization_version: i32 = buffer
.find_related(buffer_snapshot::Entity)
.select_only()
.column(buffer_snapshot::Column::OperationSerializationVersion)
.filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
.into_values::<_, QueryVersion>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("missing buffer snapshot"))?;
let serialization_version = self
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
.await?;
let operations = operations
.iter()
@@ -246,6 +380,16 @@ impl Database {
.collect::<Vec<_>>();
if !operations.is_empty() {
buffer_operation::Entity::insert_many(operations)
.on_conflict(
OnConflict::columns([
buffer_operation::Column::BufferId,
buffer_operation::Column::Epoch,
buffer_operation::Column::LamportTimestamp,
buffer_operation::Column::ReplicaId,
])
.do_nothing()
.to_owned(),
)
.exec(&*tx)
.await?;
}
@@ -271,6 +415,38 @@ impl Database {
.await
}
async fn get_buffer_operation_serialization_version(
&self,
buffer_id: BufferId,
epoch: i32,
tx: &DatabaseTransaction,
) -> Result<i32> {
Ok(buffer_snapshot::Entity::find()
.filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
.filter(buffer_snapshot::Column::Epoch.eq(epoch))
.select_only()
.column(buffer_snapshot::Column::OperationSerializationVersion)
.into_values::<_, QueryOperationSerializationVersion>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
}
async fn get_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<buffer::Model> {
Ok(channel::Model {
id: channel_id,
..Default::default()
}
.find_related(buffer::Entity)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such buffer"))?)
}
async fn get_buffer_state(
&self,
buffer: &buffer::Model,
@@ -304,27 +480,20 @@ impl Database {
.await?;
let mut operations = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
let operation = operation_from_storage(row, version)?;
operations.push(proto::Operation {
variant: Some(operation),
variant: Some(operation_from_storage(row?, version)?),
})
}
Ok((base_text, operations))
}
async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
let buffer = channel::Model {
id: channel_id,
..Default::default()
}
.find_related(buffer::Entity)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such buffer"))?;
async fn snapshot_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<()> {
let buffer = self.get_channel_buffer(channel_id, tx).await?;
let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
if operations.is_empty() {
return Ok(());
@@ -370,7 +539,6 @@ fn operation_to_storage(
operation.replica_id,
operation.lamport_timestamp,
storage::Operation {
local_timestamp: operation.local_timestamp,
version: version_to_storage(&operation.version),
is_undo: false,
edit_ranges: operation
@@ -389,7 +557,6 @@ fn operation_to_storage(
operation.replica_id,
operation.lamport_timestamp,
storage::Operation {
local_timestamp: operation.local_timestamp,
version: version_to_storage(&operation.version),
is_undo: true,
edit_ranges: Vec::new(),
@@ -399,7 +566,7 @@ fn operation_to_storage(
.iter()
.map(|entry| storage::UndoCount {
replica_id: entry.replica_id,
local_timestamp: entry.local_timestamp,
lamport_timestamp: entry.lamport_timestamp,
count: entry.count,
})
.collect(),
@@ -427,7 +594,6 @@ fn operation_from_storage(
Ok(if operation.is_undo {
proto::operation::Variant::Undo(proto::operation::Undo {
replica_id: row.replica_id as u32,
local_timestamp: operation.local_timestamp as u32,
lamport_timestamp: row.lamport_timestamp as u32,
version,
counts: operation
@@ -435,7 +601,7 @@ fn operation_from_storage(
.iter()
.map(|entry| proto::UndoCount {
replica_id: entry.replica_id,
local_timestamp: entry.local_timestamp,
lamport_timestamp: entry.lamport_timestamp,
count: entry.count,
})
.collect(),
@@ -443,7 +609,6 @@ fn operation_from_storage(
} else {
proto::operation::Variant::Edit(proto::operation::Edit {
replica_id: row.replica_id as u32,
local_timestamp: operation.local_timestamp as u32,
lamport_timestamp: row.lamport_timestamp as u32,
version,
ranges: operation
@@ -483,10 +648,9 @@ fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::
pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
match operation.variant? {
proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
timestamp: InsertionTimestamp {
timestamp: clock::Lamport {
replica_id: edit.replica_id as text::ReplicaId,
local: edit.local_timestamp,
lamport: edit.lamport_timestamp,
value: edit.lamport_timestamp,
},
version: version_from_wire(&edit.version),
ranges: edit
@@ -498,32 +662,26 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operatio
.collect(),
new_text: edit.new_text.into_iter().map(Arc::from).collect(),
})),
proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo {
lamport_timestamp: clock::Lamport {
proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
timestamp: clock::Lamport {
replica_id: undo.replica_id as text::ReplicaId,
value: undo.lamport_timestamp,
},
undo: UndoOperation {
id: clock::Local {
replica_id: undo.replica_id as text::ReplicaId,
value: undo.local_timestamp,
},
version: version_from_wire(&undo.version),
counts: undo
.counts
.into_iter()
.map(|c| {
(
clock::Local {
replica_id: c.replica_id as text::ReplicaId,
value: c.local_timestamp,
},
c.count,
)
})
.collect(),
},
}),
version: version_from_wire(&undo.version),
counts: undo
.counts
.into_iter()
.map(|c| {
(
clock::Lamport {
replica_id: c.replica_id as text::ReplicaId,
value: c.lamport_timestamp,
},
c.count,
)
})
.collect(),
})),
_ => None,
}
}
@@ -531,7 +689,7 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operatio
fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
let mut version = clock::Global::new();
for entry in message {
version.observe(clock::Local {
version.observe(clock::Lamport {
replica_id: entry.replica_id as text::ReplicaId,
value: entry.timestamp,
});
@@ -539,6 +697,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
version
}
fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
let mut message = Vec::new();
for entry in version.iter() {
message.push(proto::VectorClockEntry {
replica_id: entry.replica_id as u32,
timestamp: entry.value,
});
}
message
}
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
OperationSerializationVersion,
}
mod storage {
#![allow(non_snake_case)]
use prost::Message;
@@ -546,8 +720,6 @@ mod storage {
#[derive(Message)]
pub struct Operation {
#[prost(uint32, tag = "1")]
pub local_timestamp: u32,
#[prost(message, repeated, tag = "2")]
pub version: Vec<VectorClockEntry>,
#[prost(bool, tag = "3")]
@@ -581,7 +753,7 @@ mod storage {
#[prost(uint32, tag = "1")]
pub replica_id: u32,
#[prost(uint32, tag = "2")]
pub local_timestamp: u32,
pub lamport_timestamp: u32,
#[prost(uint32, tag = "3")]
pub count: u32,
}


@@ -1,6 +1,22 @@
use super::*;
type ChannelDescendants = HashMap<ChannelId, HashSet<ChannelId>>;
impl Database {
#[cfg(test)]
pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
self.transaction(move |tx| async move {
let mut channels = Vec::new();
let mut rows = channel::Entity::find().stream(&*tx).await?;
while let Some(row) = rows.next().await {
let row = row?;
channels.push((row.id, row.name));
}
Ok(channels)
})
.await
}
pub async fn create_root_channel(
&self,
name: &str,
@@ -86,7 +102,7 @@ impl Database {
.await
}
pub async fn remove_channel(
pub async fn delete_channel(
&self,
channel_id: ChannelId,
user_id: UserId,
@@ -135,6 +151,19 @@ impl Database {
.exec(&*tx)
.await?;
// Delete any other paths that include this channel
let sql = r#"
DELETE FROM channel_paths
WHERE
id_path LIKE '%' || $1 || '%'
"#;
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[channel_id.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
Ok((channels_to_remove.into_keys().collect(), members_to_notify))
})
.await
@@ -305,6 +334,43 @@ impl Database {
.await
}
async fn get_all_channels(
&self,
parents_by_child_id: ChannelDescendants,
tx: &DatabaseTransaction,
) -> Result<Vec<Channel>> {
let mut channels = Vec::with_capacity(parents_by_child_id.len());
{
let mut rows = channel::Entity::find()
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
// As these rows are pulled from the map's keys, this unwrap is safe.
let parents = parents_by_child_id.get(&row.id).unwrap();
if parents.len() > 0 {
for parent in parents {
channels.push(Channel {
id: row.id,
name: row.name.clone(),
parent_id: Some(*parent),
});
}
} else {
channels.push(Channel {
id: row.id,
name: row.name,
parent_id: None,
});
}
}
}
Ok(channels)
}
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
self.transaction(|tx| async move {
let tx = tx;
@@ -327,21 +393,7 @@ impl Database {
.filter_map(|membership| membership.admin.then_some(membership.channel_id))
.collect();
let mut channels = Vec::with_capacity(parents_by_child_id.len());
{
let mut rows = channel::Entity::find()
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
channels.push(Channel {
id: row.id,
name: row.name,
parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
});
}
}
let channels = self.get_all_channels(parents_by_child_id, &tx).await?;
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIdsAndChannelIds {
@@ -545,6 +597,7 @@ impl Database {
Ok(())
}
/// Returns the channel ancestors, deepest first
pub async fn get_channel_ancestors(
&self,
channel_id: ChannelId,
@@ -552,6 +605,7 @@ impl Database {
) -> Result<Vec<ChannelId>> {
let paths = channel_path::Entity::find()
.filter(channel_path::Column::ChannelId.eq(channel_id))
.order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
.all(tx)
.await?;
let mut channel_ids = Vec::new();
@@ -568,11 +622,25 @@ impl Database {
Ok(channel_ids)
}
/// Returns the channel descendants,
/// structured as a map from child ids to their parent ids.
/// For example, the descendants of 'a' in this DAG:
///
/// /- b -\
/// a -- c -- d
///
/// would be:
/// {
/// a: [],
/// b: [a],
/// c: [a],
/// d: [a, c],
/// }
async fn get_channel_descendants(
&self,
channel_ids: impl IntoIterator<Item = ChannelId>,
tx: &DatabaseTransaction,
) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
) -> Result<ChannelDescendants> {
let mut values = String::new();
for id in channel_ids {
if !values.is_empty() {
@@ -599,7 +667,7 @@ impl Database {
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let mut parents_by_child_id = HashMap::default();
let mut parents_by_child_id: ChannelDescendants = HashMap::default();
let mut paths = channel_path::Entity::find()
.from_raw_sql(stmt)
.stream(tx)
@@ -618,7 +686,10 @@ impl Database {
parent_id = Some(id);
}
}
parents_by_child_id.insert(path.channel_id, parent_id);
let entry = parents_by_child_id.entry(path.channel_id).or_default();
if let Some(parent_id) = parent_id {
entry.insert(parent_id);
}
}
Ok(parents_by_child_id)
@@ -689,6 +760,191 @@ impl Database {
})
.await
}
// Insert an edge from the given channel to the given other channel.
pub async fn link_channel(
&self,
user: UserId,
channel: ChannelId,
to: ChannelId,
) -> Result<Vec<Channel>> {
self.transaction(|tx| async move {
// Note that even with these maxed permissions, this linking operation
// is still insecure because you can't remove someone's permissions to a
// channel if they've linked the channel to one where they're an admin.
self.check_user_is_channel_admin(channel, user, &*tx)
.await?;
self.link_channel_internal(user, channel, to, &*tx).await
})
.await
}
pub async fn link_channel_internal(
&self,
user: UserId,
channel: ChannelId,
to: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Vec<Channel>> {
self.check_user_is_channel_admin(to, user, &*tx).await?;
let to_ancestors = self.get_channel_ancestors(to, &*tx).await?;
let mut from_descendants = self.get_channel_descendants([channel], &*tx).await?;
for ancestor in to_ancestors {
if from_descendants.contains_key(&ancestor) {
return Err(anyhow!("Cannot create a channel cycle").into());
}
}
let sql = r#"
INSERT INTO channel_paths
(id_path, channel_id)
SELECT
id_path || $1 || '/', $2
FROM
channel_paths
WHERE
channel_id = $3
ON CONFLICT (id_path) DO NOTHING;
"#;
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[
channel.to_proto().into(),
channel.to_proto().into(),
to.to_proto().into(),
],
);
tx.execute(channel_paths_stmt).await?;
for (from_id, to_ids) in from_descendants.iter().filter(|(id, _)| id != &&channel) {
for to_id in to_ids {
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[
from_id.to_proto().into(),
from_id.to_proto().into(),
to_id.to_proto().into(),
],
);
tx.execute(channel_paths_stmt).await?;
}
}
if let Some(channel) = from_descendants.get_mut(&channel) {
// Remove the other parents
channel.clear();
channel.insert(to);
}
let channels = self.get_all_channels(from_descendants, &*tx).await?;
Ok(channels)
}
/// Unlink a channel from a given parent. This will add a root edge if
/// the channel has no other parents after this operation.
pub async fn unlink_channel(
&self,
user: UserId,
channel: ChannelId,
from: Option<ChannelId>,
) -> Result<()> {
self.transaction(|tx| async move {
// Note that even with these maxed permissions, this linking operation
// is still insecure because you can't remove someone's permissions to a
// channel if they've linked the channel to one where they're an admin.
self.check_user_is_channel_admin(channel, user, &*tx)
.await?;
self.unlink_channel_internal(user, channel, from, &*tx)
.await?;
Ok(())
})
.await
}
pub async fn unlink_channel_internal(
&self,
user: UserId,
channel: ChannelId,
from: Option<ChannelId>,
tx: &DatabaseTransaction,
) -> Result<()> {
if let Some(from) = from {
self.check_user_is_channel_admin(from, user, &*tx).await?;
let sql = r#"
DELETE FROM channel_paths
WHERE
id_path LIKE '%' || $1 || '/' || $2 || '%'
"#;
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[from.to_proto().into(), channel.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
} else {
let sql = r#"
DELETE FROM channel_paths
WHERE
id_path = '/' || $1 || '/'
"#;
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[channel.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
}
// Make sure that there is always at least one path to the channel
let sql = r#"
INSERT INTO channel_paths
(id_path, channel_id)
SELECT
'/' || $1 || '/', $2
WHERE NOT EXISTS
(SELECT *
FROM channel_paths
WHERE channel_id = $2)
"#;
let channel_paths_stmt = Statement::from_sql_and_values(
self.pool.get_database_backend(),
sql,
[channel.to_proto().into(), channel.to_proto().into()],
);
tx.execute(channel_paths_stmt).await?;
Ok(())
}
/// Move a channel from one parent to another, returning the
/// channels that were moved so that clients can be notified.
pub async fn move_channel(
&self,
user: UserId,
channel: ChannelId,
from: Option<ChannelId>,
to: ChannelId,
) -> Result<Vec<Channel>> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel, user, &*tx)
.await?;
let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
self.unlink_channel_internal(user, channel, from, &*tx)
.await?;
Ok(moved_channels)
})
.await
}
}
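A small illustration of the `channel_paths.id_path` encoding that the raw SQL above appears to rely on: each row stores a root-to-channel path as a slash-delimited id string such as "/1/2/3/" (this encoding is inferred from the queries, so treat the exact format as an assumption). Linking a channel appends its id to every path ending at the new parent, unlinking deletes every path containing the parent/child edge, and move_channel is a link followed by an unlink; if no path remains for a channel, a fresh root path "/{id}/" is inserted. The sketch below mimics that behavior with plain strings and is not the database code itself.
fn link(paths: &mut Vec<String>, parent: u64, child: u64) {
    // Append the child to every path that currently ends at the parent
    // (the SQL does this with `id_path || $1 || '/'` filtered on channel_id).
    let suffix = format!("/{}/", parent);
    let new_paths: Vec<String> = paths
        .iter()
        .filter(|path| path.ends_with(suffix.as_str()))
        .map(|path| format!("{}{}/", path, child))
        .collect();
    paths.extend(new_paths);
}

fn unlink(paths: &mut Vec<String>, parent: u64, child: u64) {
    // Drop every path containing the parent/child edge, which is what the
    // `LIKE '%' || $1 || '/' || $2 || '%'` pattern matches.
    let edge = format!("/{}/{}/", parent, child);
    paths.retain(|path| !path.contains(edge.as_str()));
    // Mirror the "always at least one path" rule by re-inserting a root path.
    let root = format!("/{}/", child);
    if !paths.iter().any(|path| path.contains(root.as_str())) {
        paths.push(root);
    }
}

fn main() {
    // zed (1) is a root, crdb (2) lives under it.
    let mut paths = vec!["/1/".to_string(), "/1/2/".to_string()];
    link(&mut paths, 2, 3); // link livestreaming (3) under crdb
    assert_eq!(paths, ["/1/", "/1/2/", "/1/2/3/"]);
    unlink(&mut paths, 2, 3); // unlink it again: 3 becomes a root channel
    assert_eq!(paths, ["/1/", "/1/2/", "/3/"]);
}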
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]


@@ -1,7 +1,7 @@
use super::*;
impl Database {
pub async fn refresh_room(
pub async fn clear_stale_room_participants(
&self,
room_id: RoomId,
new_server_id: ServerId,


@@ -14,31 +14,49 @@ impl Database {
.await
}
pub async fn stale_room_ids(
pub async fn stale_server_resource_ids(
&self,
environment: &str,
new_server_id: ServerId,
) -> Result<Vec<RoomId>> {
) -> Result<(Vec<RoomId>, Vec<ChannelId>)> {
self.transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
enum QueryRoomIds {
RoomId,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryChannelIds {
ChannelId,
}
let stale_server_epochs = self
.stale_server_ids(environment, new_server_id, &tx)
.await?;
Ok(room_participant::Entity::find()
let room_ids = room_participant::Entity::find()
.select_only()
.column(room_participant::Column::RoomId)
.distinct()
.filter(
room_participant::Column::AnsweringConnectionServerId
.is_in(stale_server_epochs),
.is_in(stale_server_epochs.iter().copied()),
)
.into_values::<_, QueryAs>()
.into_values::<_, QueryRoomIds>()
.all(&*tx)
.await?)
.await?;
let channel_ids = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::ChannelId)
.distinct()
.filter(
channel_buffer_collaborator::Column::ConnectionServerId
.is_in(stale_server_epochs.iter().copied()),
)
.into_values::<_, QueryChannelIds>()
.all(&*tx)
.await?;
Ok((room_ids, channel_ids))
})
.await
}


@@ -241,7 +241,6 @@ impl Database {
result
}
#[cfg(debug_assertions)]
pub async fn create_user_flag(&self, flag: &str) -> Result<FlagId> {
self.transaction(|tx| async move {
let flag = feature_flag::Entity::insert(feature_flag::ActiveModel {
@@ -257,7 +256,6 @@ impl Database {
.await
}
#[cfg(debug_assertions)]
pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> {
self.transaction(|tx| async move {
user_feature::Entity::insert(user_feature::ActiveModel {


@@ -1,4 +1,5 @@
mod buffer_tests;
mod channel_tests;
mod db_tests;
mod feature_flag_tests;


@@ -0,0 +1,844 @@
use rpc::{proto, ConnectionId};
use crate::{
db::{Channel, ChannelId, Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
async fn test_channels(db: &Arc<Database>) {
let a_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let b_id = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
// Make sure that people cannot read channels they haven't been invited to
assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
db.invite_channel_member(zed_id, b_id, a_id, false)
.await
.unwrap();
db.respond_to_channel_invite(zed_id, b_id, true)
.await
.unwrap();
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(zed_id), "3", a_id)
.await
.unwrap();
let replace_id = db
.create_channel("replace", Some(zed_id), "4", a_id)
.await
.unwrap();
let mut members = db.get_channel_members(replace_id).await.unwrap();
members.sort();
assert_eq!(members, &[a_id, b_id]);
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
let cargo_id = db
.create_channel("cargo", Some(rust_id), "6", a_id)
.await
.unwrap();
let cargo_ra_id = db
.create_channel("cargo-ra", Some(cargo_id), "7", a_id)
.await
.unwrap();
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: rust_id,
name: "rust".to_string(),
parent_id: None,
},
Channel {
id: cargo_id,
name: "cargo".to_string(),
parent_id: Some(rust_id),
},
Channel {
id: cargo_ra_id,
name: "cargo-ra".to_string(),
parent_id: Some(cargo_id),
}
]
);
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
]
);
// Update member permissions
let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await;
assert!(set_subchannel_admin.is_err());
let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await;
assert!(set_channel_admin.is_ok());
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
]
);
// Remove a single channel
db.delete_channel(crdb_id, a_id).await.unwrap();
assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
// Remove a channel tree
let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
channel_ids.sort();
assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
assert_eq!(user_ids, &[a_id]);
assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
}
test_both_dbs!(
test_joining_channels,
test_joining_channels_postgres,
test_joining_channels_sqlite
);
async fn test_joining_channels(db: &Arc<Database>) {
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
// can join a room with membership to its channel
let joined_room = db
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
.await
.unwrap();
assert_eq!(joined_room.room.participants.len(), 1);
drop(joined_room);
// cannot join a room without membership to its channel
assert!(db
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
.await
.is_err());
}
test_both_dbs!(
test_channel_invites,
test_channel_invites_postgres,
test_channel_invites_sqlite
);
async fn test_channel_invites(db: &Arc<Database>) {
db.create_server("test").await.unwrap();
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_3 = db
.create_user(
"user3@example.com",
false,
NewUserParams {
github_login: "user3".into(),
github_user_id: 7,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let channel_1_2 = db
.create_root_channel("channel_2", "2", user_1)
.await
.unwrap();
db.invite_channel_member(channel_1_1, user_2, user_1, false)
.await
.unwrap();
db.invite_channel_member(channel_1_2, user_2, user_1, false)
.await
.unwrap();
db.invite_channel_member(channel_1_1, user_3, user_1, true)
.await
.unwrap();
let user_2_invites = db
.get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
.await
.unwrap()
.into_iter()
.map(|channel| channel.id)
.collect::<Vec<_>>();
assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
let user_3_invites = db
.get_channel_invites_for_user(user_3) // -> [channel_1_1]
.await
.unwrap()
.into_iter()
.map(|channel| channel.id)
.collect::<Vec<_>>();
assert_eq!(user_3_invites, &[channel_1_1]);
let members = db
.get_channel_member_details(channel_1_1, user_1)
.await
.unwrap();
assert_eq!(
members,
&[
proto::ChannelMember {
user_id: user_1.to_proto(),
kind: proto::channel_member::Kind::Member.into(),
admin: true,
},
proto::ChannelMember {
user_id: user_2.to_proto(),
kind: proto::channel_member::Kind::Invitee.into(),
admin: false,
},
proto::ChannelMember {
user_id: user_3.to_proto(),
kind: proto::channel_member::Kind::Invitee.into(),
admin: true,
},
]
);
db.respond_to_channel_invite(channel_1_1, user_2, true)
.await
.unwrap();
let channel_1_3 = db
.create_channel("channel_3", Some(channel_1_1), "1", user_1)
.await
.unwrap();
let members = db
.get_channel_member_details(channel_1_3, user_1)
.await
.unwrap();
assert_eq!(
members,
&[
proto::ChannelMember {
user_id: user_1.to_proto(),
kind: proto::channel_member::Kind::Member.into(),
admin: true,
},
proto::ChannelMember {
user_id: user_2.to_proto(),
kind: proto::channel_member::Kind::AncestorMember.into(),
admin: false,
},
]
);
}
test_both_dbs!(
test_channel_renames,
test_channel_renames_postgres,
test_channel_renames_sqlite
);
async fn test_channel_renames(db: &Arc<Database>) {
db.create_server("test").await.unwrap();
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
db.rename_channel(zed_id, user_1, "#zed-archive")
.await
.unwrap();
let zed_archive_id = zed_id;
let (channel, _) = db
.get_channel(zed_archive_id, user_1)
.await
.unwrap()
.unwrap();
assert_eq!(channel.name, "zed-archive");
let non_permissioned_rename = db
.rename_channel(zed_archive_id, user_2, "hacked-lol")
.await;
assert!(non_permissioned_rename.is_err());
let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
assert!(bad_name_rename.is_err())
}
test_both_dbs!(
test_channels_moving,
test_channels_moving_postgres,
test_channels_moving_sqlite
);
async fn test_channels_moving(db: &Arc<Database>) {
let a_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
.unwrap();
let gpui2_id = db
.create_channel("gpui2", Some(zed_id), "3", a_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(crdb_id), "4", a_id)
.await
.unwrap();
let livestreaming_dag_id = db
.create_channel("livestreaming_dag", Some(livestreaming_id), "5", a_id)
.await
.unwrap();
// ========================================================================
// sanity check
// Initial DAG:
// /- gpui2
// zed -- crdb - livestreaming - livestreaming_dag
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(
result.channels,
&[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
],
);
// Attempt to make a cycle
assert!(db
.link_channel(a_id, zed_id, livestreaming_id)
.await
.is_err());
// ========================================================================
// Make a link
db.link_channel(a_id, livestreaming_id, zed_id)
.await
.unwrap();
// DAG is now:
// /- gpui2
// zed -- crdb - livestreaming - livestreaming_dag
// \---------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
]);
// ========================================================================
// Create a new channel below a channel with multiple parents
let livestreaming_dag_sub_id = db
.create_channel(
"livestreaming_dag_sub",
Some(livestreaming_dag_id),
"6",
a_id,
)
.await
.unwrap();
// DAG is now:
// /- gpui2
// zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
// \---------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Test a complex DAG by making another link
let returned_channels = db
.link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id)
.await
.unwrap();
// DAG is now:
// /- gpui2 /---------------------\
// zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
// \--------/
// make sure we're getting just the new link
// Not using the assert_dag helper because we want to make sure we're returning the full data
pretty_assertions::assert_eq!(
returned_channels,
vec![Channel {
id: livestreaming_dag_sub_id,
name: "livestreaming_dag_sub".to_string(),
parent_id: Some(livestreaming_id),
}]
);
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Test a complex DAG by making another link
let returned_channels = db
.link_channel(a_id, livestreaming_id, gpui2_id)
.await
.unwrap();
// DAG is now:
// /- gpui2 -\ /---------------------\
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id
// \---------/
// Make sure that we're correctly getting the full sub-dag
pretty_assertions::assert_eq!(
returned_channels,
vec![
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(gpui2_id),
},
Channel {
id: livestreaming_dag_id,
name: "livestreaming_dag".to_string(),
parent_id: Some(livestreaming_id),
},
Channel {
id: livestreaming_dag_sub_id,
name: "livestreaming_dag_sub".to_string(),
parent_id: Some(livestreaming_id),
},
Channel {
id: livestreaming_dag_sub_id,
name: "livestreaming_dag_sub".to_string(),
parent_id: Some(livestreaming_dag_id),
}
]
);
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(gpui2_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Test unlinking in a complex DAG by removing the inner link
db
.unlink_channel(
a_id,
livestreaming_dag_sub_id,
Some(livestreaming_id),
)
.await
.unwrap();
// DAG is now:
// /- gpui2 -\
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub
// \---------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(gpui2_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Test unlinking in a complex DAG by removing the gpui2 link
db.unlink_channel(a_id, livestreaming_id, Some(gpui2_id))
.await
.unwrap();
// DAG is now:
// /- gpui2
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub
// \---------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Test moving DAG nodes by moving livestreaming to be below gpui2
db.move_channel(a_id, livestreaming_id, Some(crdb_id), gpui2_id)
.await
.unwrap();
// DAG is now:
// /- gpui2 -- livestreaming - livestreaming_dag - livestreaming_dag_sub
// zed - crdb /
// \---------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(gpui2_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(gpui2_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Deleting a channel should not delete children that still have other parents
db.delete_channel(gpui2_id, a_id).await.unwrap();
// DAG is now:
// zed - crdb
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Unlinking a channel from its parent should automatically promote it to a root channel
db.unlink_channel(a_id, crdb_id, Some(zed_id))
.await
.unwrap();
// DAG is now:
// crdb
// zed
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, None),
(livestreaming_id, Some(zed_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Unlinking a root channel should not have any effect
db.unlink_channel(a_id, crdb_id, None)
.await
.unwrap();
// DAG is now:
// crdb
// zed
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
//
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, None),
(livestreaming_id, Some(zed_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// You should be able to move a root channel into a non-root channel
db.move_channel(a_id, crdb_id, None, zed_id)
.await
.unwrap();
// DAG is now:
// zed - crdb
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Moving a non-root channel without a parent id should be the equivalent of a link operation
db.move_channel(a_id, livestreaming_id, None, crdb_id)
.await
.unwrap();
// DAG is now:
// zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub
// \--------/
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_dag(result.channels, &[
(zed_id, None),
(crdb_id, Some(zed_id)),
(livestreaming_id, Some(zed_id)),
(livestreaming_id, Some(crdb_id)),
(livestreaming_dag_id, Some(livestreaming_id)),
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
]);
// ========================================================================
// Deleting a parent of a DAG should delete the whole DAG:
db.delete_channel(zed_id, a_id).await.unwrap();
let result = db.get_channels_for_user(a_id).await.unwrap();
assert!(
result.channels.is_empty()
)
}
#[track_caller]
fn assert_dag(actual: Vec<Channel>, expected: &[(ChannelId, Option<ChannelId>)]) {
let actual = actual
.iter()
.map(|channel| (channel.id, channel.parent_id))
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(actual, expected)
}


@@ -877,458 +877,6 @@ async fn test_invite_codes() {
assert!(db.has_contact(user5, user1).await.unwrap());
}
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
async fn test_channels(db: &Arc<Database>) {
let a_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let b_id = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
// Make sure that people cannot read channels they haven't been invited to
assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
db.invite_channel_member(zed_id, b_id, a_id, false)
.await
.unwrap();
db.respond_to_channel_invite(zed_id, b_id, true)
.await
.unwrap();
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(zed_id), "3", a_id)
.await
.unwrap();
let replace_id = db
.create_channel("replace", Some(zed_id), "4", a_id)
.await
.unwrap();
let mut members = db.get_channel_members(replace_id).await.unwrap();
members.sort();
assert_eq!(members, &[a_id, b_id]);
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
let cargo_id = db
.create_channel("cargo", Some(rust_id), "6", a_id)
.await
.unwrap();
let cargo_ra_id = db
.create_channel("cargo-ra", Some(cargo_id), "7", a_id)
.await
.unwrap();
let result = db.get_channels_for_user(a_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: rust_id,
name: "rust".to_string(),
parent_id: None,
},
Channel {
id: cargo_id,
name: "cargo".to_string(),
parent_id: Some(rust_id),
},
Channel {
id: cargo_ra_id,
name: "cargo-ra".to_string(),
parent_id: Some(cargo_id),
}
]
);
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
]
);
// Update member permissions
let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await;
assert!(set_subchannel_admin.is_err());
let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await;
assert!(set_channel_admin.is_ok());
let result = db.get_channels_for_user(b_id).await.unwrap();
assert_eq!(
result.channels,
vec![
Channel {
id: zed_id,
name: "zed".to_string(),
parent_id: None,
},
Channel {
id: crdb_id,
name: "crdb".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: livestreaming_id,
name: "livestreaming".to_string(),
parent_id: Some(zed_id),
},
Channel {
id: replace_id,
name: "replace".to_string(),
parent_id: Some(zed_id),
},
]
);
// Remove a single channel
db.remove_channel(crdb_id, a_id).await.unwrap();
assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
// Remove a channel tree
let (mut channel_ids, user_ids) = db.remove_channel(rust_id, a_id).await.unwrap();
channel_ids.sort();
assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
assert_eq!(user_ids, &[a_id]);
assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
}
test_both_dbs!(
test_joining_channels,
test_joining_channels_postgres,
test_joining_channels_sqlite
);
async fn test_joining_channels(db: &Arc<Database>) {
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
// can join a room with membership to its channel
let joined_room = db
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
.await
.unwrap();
assert_eq!(joined_room.room.participants.len(), 1);
drop(joined_room);
// cannot join a room without membership to its channel
assert!(db
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
.await
.is_err());
}
test_both_dbs!(
test_channel_invites,
test_channel_invites_postgres,
test_channel_invites_sqlite
);
async fn test_channel_invites(db: &Arc<Database>) {
db.create_server("test").await.unwrap();
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_3 = db
.create_user(
"user3@example.com",
false,
NewUserParams {
github_login: "user3".into(),
github_user_id: 7,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let channel_1_2 = db
.create_root_channel("channel_2", "2", user_1)
.await
.unwrap();
db.invite_channel_member(channel_1_1, user_2, user_1, false)
.await
.unwrap();
db.invite_channel_member(channel_1_2, user_2, user_1, false)
.await
.unwrap();
db.invite_channel_member(channel_1_1, user_3, user_1, true)
.await
.unwrap();
let user_2_invites = db
.get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
.await
.unwrap()
.into_iter()
.map(|channel| channel.id)
.collect::<Vec<_>>();
assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
let user_3_invites = db
.get_channel_invites_for_user(user_3) // -> [channel_1_1]
.await
.unwrap()
.into_iter()
.map(|channel| channel.id)
.collect::<Vec<_>>();
assert_eq!(user_3_invites, &[channel_1_1]);
let members = db
.get_channel_member_details(channel_1_1, user_1)
.await
.unwrap();
assert_eq!(
members,
&[
proto::ChannelMember {
user_id: user_1.to_proto(),
kind: proto::channel_member::Kind::Member.into(),
admin: true,
},
proto::ChannelMember {
user_id: user_2.to_proto(),
kind: proto::channel_member::Kind::Invitee.into(),
admin: false,
},
proto::ChannelMember {
user_id: user_3.to_proto(),
kind: proto::channel_member::Kind::Invitee.into(),
admin: true,
},
]
);
db.respond_to_channel_invite(channel_1_1, user_2, true)
.await
.unwrap();
let channel_1_3 = db
.create_channel("channel_3", Some(channel_1_1), "1", user_1)
.await
.unwrap();
let members = db
.get_channel_member_details(channel_1_3, user_1)
.await
.unwrap();
assert_eq!(
members,
&[
proto::ChannelMember {
user_id: user_1.to_proto(),
kind: proto::channel_member::Kind::Member.into(),
admin: true,
},
proto::ChannelMember {
user_id: user_2.to_proto(),
kind: proto::channel_member::Kind::AncestorMember.into(),
admin: false,
},
]
);
}
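In the second assertion above, user_2 shows up on channel_1_3 as AncestorMember because their accepted membership lives on the parent channel_1_1, not on channel_1_3 itself. A hypothetical helper (illustration only; the real classification is done by the get_channel_member_details query) showing the precedence the assertion relies on:
use std::collections::HashSet;
// Hypothetical illustration of the membership kinds asserted above; not the
// actual database query.
fn member_kind(
    user: u64,
    direct_members: &HashSet<u64>,
    invitees: &HashSet<u64>,
    ancestor_members: &HashSet<u64>,
) -> Option<proto::channel_member::Kind> {
    if direct_members.contains(&user) {
        Some(proto::channel_member::Kind::Member)
    } else if invitees.contains(&user) {
        Some(proto::channel_member::Kind::Invitee)
    } else if ancestor_members.contains(&user) {
        Some(proto::channel_member::Kind::AncestorMember)
    } else {
        None
    }
}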
test_both_dbs!(
test_channel_renames,
test_channel_renames_postgres,
test_channel_renames_sqlite
);
async fn test_channel_renames(db: &Arc<Database>) {
db.create_server("test").await.unwrap();
let user_1 = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user2@example.com",
false,
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
db.rename_channel(zed_id, user_1, "#zed-archive")
.await
.unwrap();
let zed_archive_id = zed_id;
let (channel, _) = db
.get_channel(zed_archive_id, user_1)
.await
.unwrap()
.unwrap();
assert_eq!(channel.name, "zed-archive");
let non_permissioned_rename = db
.rename_channel(zed_archive_id, user_2, "hacked-lol")
.await;
assert!(non_permissioned_rename.is_err());
let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
assert!(bad_name_rename.is_err())
}
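The rename assertions above imply that a leading '#' is stripped from channel names and that a name which is empty after stripping is rejected. A minimal, hypothetical sketch of that sanitization (an assumption; the real logic presumably sits inside the rename_channel implementation):
// Hypothetical sketch of the name handling implied by the assertions above:
// strip leading '#' characters and reject an empty result.
fn sanitize_channel_name(name: &str) -> anyhow::Result<String> {
    let name = name.trim().trim_start_matches('#').trim().to_string();
    if name.is_empty() {
        anyhow::bail!("channel name cannot be empty");
    }
    Ok(name)
}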
#[gpui::test]
async fn test_multiple_signup_overwrite() {
let test_db = TestDb::postgres(build_background_executor());

View File

@@ -2,7 +2,10 @@ mod connection_pool;
use crate::{
auth,
db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
db::{
self, Channel, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User,
UserId,
},
executor::Executor,
AppState, Result,
};
@@ -243,7 +246,7 @@ impl Server {
.add_request_handler(remove_contact)
.add_request_handler(respond_to_contact_request)
.add_request_handler(create_channel)
.add_request_handler(remove_channel)
.add_request_handler(delete_channel)
.add_request_handler(invite_channel_member)
.add_request_handler(remove_channel_member)
.add_request_handler(set_channel_member_admin)
@@ -251,9 +254,13 @@ impl Server {
.add_request_handler(join_channel_buffer)
.add_request_handler(leave_channel_buffer)
.add_message_handler(update_channel_buffer)
.add_request_handler(rejoin_channel_buffers)
.add_request_handler(get_channel_members)
.add_request_handler(respond_to_channel_invite)
.add_request_handler(join_channel)
.add_request_handler(link_channel)
.add_request_handler(unlink_channel)
.add_request_handler(move_channel)
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
@@ -277,13 +284,33 @@ impl Server {
tracing::info!("waiting for cleanup timeout");
timeout.await;
tracing::info!("cleanup timeout expired, retrieving stale rooms");
if let Some(room_ids) = app_state
if let Some((room_ids, channel_ids)) = app_state
.db
.stale_room_ids(&app_state.config.zed_environment, server_id)
.stale_server_resource_ids(&app_state.config.zed_environment, server_id)
.await
.trace_err()
{
tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
tracing::info!(
stale_channel_buffer_count = channel_ids.len(),
"retrieved stale channel buffers"
);
for channel_id in channel_ids {
if let Some(refreshed_channel_buffer) = app_state
.db
.clear_stale_channel_buffer_collaborators(channel_id, server_id)
.await
.trace_err()
{
for connection_id in refreshed_channel_buffer.connection_ids {
for message in &refreshed_channel_buffer.removed_collaborators {
peer.send(connection_id, message.clone()).trace_err();
}
}
}
}
for room_id in room_ids {
let mut contacts_to_update = HashSet::default();
let mut canceled_calls_to_user_ids = Vec::new();
@@ -292,7 +319,7 @@ impl Server {
if let Some(mut refreshed_room) = app_state
.db
.refresh_room(room_id, server_id)
.clear_stale_room_participants(room_id, server_id)
.await
.trace_err()
{
@@ -854,13 +881,13 @@ async fn connection_lost(
.await
.trace_err();
leave_channel_buffers_for_session(&session)
.await
.trace_err();
futures::select_biased! {
_ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
leave_room_for_session(&session).await.trace_err();
leave_channel_buffers_for_session(&session)
.await
.trace_err();
if !session
.connection_pool()
@@ -2206,23 +2233,23 @@ async fn create_channel(
Ok(())
}
async fn remove_channel(
request: proto::RemoveChannel,
response: Response<proto::RemoveChannel>,
async fn delete_channel(
request: proto::DeleteChannel,
response: Response<proto::DeleteChannel>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let channel_id = request.channel_id;
let (removed_channels, member_ids) = db
.remove_channel(ChannelId::from_proto(channel_id), session.user_id)
.delete_channel(ChannelId::from_proto(channel_id), session.user_id)
.await?;
response.send(proto::Ack {})?;
// Notify members of removed channels
let mut update = proto::UpdateChannels::default();
update
.remove_channels
.delete_channels
.extend(removed_channels.into_iter().map(|id| id.to_proto()));
let connection_pool = session.connection_pool().await;
@@ -2282,7 +2309,7 @@ async fn remove_channel_member(
.await?;
let mut update = proto::UpdateChannels::default();
update.remove_channels.push(channel_id.to_proto());
update.delete_channels.push(channel_id.to_proto());
for connection_id in session
.connection_pool()
@@ -2366,6 +2393,126 @@ async fn rename_channel(
Ok(())
}
async fn link_channel(
request: proto::LinkChannel,
response: Response<proto::LinkChannel>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let to = ChannelId::from_proto(request.to);
let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
let members = db.get_channel_members(to).await?;
let connection_pool = session.connection_pool().await;
let update = proto::UpdateChannels {
channels: channels_to_send
.into_iter()
.map(|channel| proto::Channel {
id: channel.id.to_proto(),
name: channel.name,
parent_id: channel.parent_id.map(ChannelId::to_proto),
})
.collect(),
..Default::default()
};
for member_id in members {
for connection_id in connection_pool.user_connection_ids(member_id) {
session.peer.send(connection_id, update.clone())?;
}
}
response.send(Ack {})?;
Ok(())
}
async fn unlink_channel(
request: proto::UnlinkChannel,
response: Response<proto::UnlinkChannel>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let from = request.from.map(ChannelId::from_proto);
db.unlink_channel(session.user_id, channel_id, from).await?;
if let Some(from_parent) = from {
let members = db.get_channel_members(from_parent).await?;
let update = proto::UpdateChannels {
delete_channel_edge: vec![proto::ChannelEdge {
channel_id: channel_id.to_proto(),
parent_id: from_parent.to_proto(),
}],
..Default::default()
};
let connection_pool = session.connection_pool().await;
for member_id in members {
for connection_id in connection_pool.user_connection_ids(member_id) {
session.peer.send(connection_id, update.clone())?;
}
}
}
response.send(Ack {})?;
Ok(())
}
async fn move_channel(
request: proto::MoveChannel,
response: Response<proto::MoveChannel>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let from_parent = request.from.map(ChannelId::from_proto);
let to = ChannelId::from_proto(request.to);
let channels_to_send: Vec<Channel> = db
.move_channel(session.user_id, channel_id, from_parent, to)
.await?;
if let Some(from_parent) = from_parent {
let members = db.get_channel_members(from_parent).await?;
let update = proto::UpdateChannels {
delete_channel_edge: vec![proto::ChannelEdge {
channel_id: channel_id.to_proto(),
parent_id: from_parent.to_proto(),
}],
..Default::default()
};
let connection_pool = session.connection_pool().await;
for member_id in members {
for connection_id in connection_pool.user_connection_ids(member_id) {
session.peer.send(connection_id, update.clone())?;
}
}
}
let members = db.get_channel_members(to).await?;
let connection_pool = session.connection_pool().await;
let update = proto::UpdateChannels {
channels: channels_to_send
.into_iter()
.map(|channel| proto::Channel {
id: channel.id.to_proto(),
name: channel.name,
parent_id: channel.parent_id.map(ChannelId::to_proto),
})
.collect(),
..Default::default()
};
for member_id in members {
for connection_id in connection_pool.user_connection_ids(member_id) {
session.peer.send(connection_id, update.clone())?;
}
}
response.send(Ack {})?;
Ok(())
}
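link_channel, unlink_channel, and move_channel all broadcast their effect through proto::UpdateChannels: removed parent/child relationships travel in delete_channel_edge, and the (re)attached subtree travels in channels. A hypothetical sketch of how a receiver could fold such an update into a local child -> parents map (illustration only, not the actual ChannelStore logic):
use std::collections::{HashMap, HashSet};
// Hypothetical illustration: apply the edge-removal and channel entries of an
// UpdateChannels message to a local child-id -> parent-ids map.
fn apply_channel_update(
    parents_by_child: &mut HashMap<u64, HashSet<u64>>,
    update: &proto::UpdateChannels,
) {
    for edge in &update.delete_channel_edge {
        if let Some(edge_parents) = parents_by_child.get_mut(&edge.channel_id) {
            edge_parents.remove(&edge.parent_id);
        }
    }
    for channel in &update.channels {
        if let Some(parent_id) = channel.parent_id {
            parents_by_child
                .entry(channel.id)
                .or_default()
                .insert(parent_id);
        }
    }
}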
async fn get_channel_members(
request: proto::GetChannelMembers,
response: Response<proto::GetChannelMembers>,
@@ -2547,6 +2694,41 @@ async fn update_channel_buffer(
Ok(())
}
async fn rejoin_channel_buffers(
request: proto::RejoinChannelBuffers,
response: Response<proto::RejoinChannelBuffers>,
session: Session,
) -> Result<()> {
let db = session.db().await;
let buffers = db
.rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
.await?;
for buffer in &buffers {
let collaborators_to_notify = buffer
.buffer
.collaborators
.iter()
.filter_map(|c| Some(c.peer_id?.into()));
channel_buffer_updated(
session.connection_id,
collaborators_to_notify,
&proto::UpdateChannelBufferCollaborator {
channel_id: buffer.buffer.channel_id,
old_peer_id: Some(buffer.old_connection_id.into()),
new_peer_id: Some(session.connection_id.into()),
},
&session.peer,
);
}
response.send(proto::RejoinChannelBuffersResponse {
buffers: buffers.into_iter().map(|b| b.buffer).collect(),
})?;
Ok(())
}
async fn leave_channel_buffer(
request: proto::LeaveChannelBuffer,
response: Response<proto::LeaveChannelBuffer>,

View File

@@ -1,555 +1,18 @@
use crate::{
db::{tests::TestDb, NewUserParams, UserId},
executor::Executor,
rpc::{Server, CLEANUP_TIMEOUT},
AppState,
};
use anyhow::anyhow;
use call::{ActiveCall, Room};
use channel::ChannelStore;
use client::{
self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
};
use collections::{HashMap, HashSet};
use fs::FakeFs;
use futures::{channel::oneshot, StreamExt as _};
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use settings::SettingsStore;
use std::{
cell::{Ref, RefCell, RefMut},
env,
ops::{Deref, DerefMut},
path::Path,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
Arc,
},
};
use util::http::FakeHttpClient;
use workspace::Workspace;
use call::Room;
use gpui::{ModelHandle, TestAppContext};
mod channel_buffer_tests;
mod channel_tests;
mod integration_tests;
mod randomized_integration_tests;
mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod test_server;
struct TestServer {
app_state: Arc<AppState>,
server: Arc<Server>,
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
test_live_kit_server: Arc<live_kit_client::TestServer>,
}
impl TestServer {
async fn start(deterministic: &Arc<Deterministic>) -> Self {
static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
let use_postgres = env::var("USE_POSTGRES").ok();
let use_postgres = use_postgres.as_deref();
let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
TestDb::postgres(deterministic.build_background())
} else {
TestDb::sqlite(deterministic.build_background())
};
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
let live_kit_server = live_kit_client::TestServer::create(
format!("http://livekit.{}.test", live_kit_server_id),
format!("devkey-{}", live_kit_server_id),
format!("secret-{}", live_kit_server_id),
deterministic.build_background(),
)
.unwrap();
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
let epoch = app_state
.db
.create_server(&app_state.config.zed_environment)
.await
.unwrap();
let server = Server::new(
epoch,
app_state.clone(),
Executor::Deterministic(deterministic.build_background()),
);
server.start().await.unwrap();
// Advance clock to ensure the server's cleanup task is finished.
deterministic.advance_clock(CLEANUP_TIMEOUT);
Self {
app_state,
server,
connection_killers: Default::default(),
forbid_connections: Default::default(),
_test_db: test_db,
test_live_kit_server: live_kit_server,
}
}
async fn reset(&self) {
self.app_state.db.reset();
let epoch = self
.app_state
.db
.create_server(&self.app_state.config.zed_environment)
.await
.unwrap();
self.server.reset(epoch);
}
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
cx.update(|cx| {
if cx.has_global::<SettingsStore>() {
panic!("Same cx used to create two test clients")
}
cx.set_global(SettingsStore::test(cx));
});
let http = FakeHttpClient::with_404_response();
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
{
user.id
} else {
self.app_state
.db
.create_user(
&format!("{name}@example.com"),
false,
NewUserParams {
github_login: name.into(),
github_user_id: 0,
invite_count: 0,
},
)
.await
.expect("creating user failed")
.user_id
};
let client_name = name.to_string();
let mut client = cx.read(|cx| Client::new(http.clone(), cx));
let server = self.server.clone();
let db = self.app_state.db.clone();
let connection_killers = self.connection_killers.clone();
let forbid_connections = self.forbid_connections.clone();
Arc::get_mut(&mut client)
.unwrap()
.set_id(user_id.0 as usize)
.override_authenticate(move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok(Credentials {
user_id: user_id.0 as u64,
access_token,
})
})
})
.override_establish_connection(move |credentials, cx| {
assert_eq!(credentials.user_id, user_id.0 as u64);
assert_eq!(credentials.access_token, "the-token");
let server = server.clone();
let db = db.clone();
let connection_killers = connection_killers.clone();
let forbid_connections = forbid_connections.clone();
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
if forbid_connections.load(SeqCst) {
Err(EstablishConnectionError::other(anyhow!(
"server is forbidding connections"
)))
} else {
let (client_conn, server_conn, killed) =
Connection::in_memory(cx.background());
let (connection_id_tx, connection_id_rx) = oneshot::channel();
let user = db
.get_user_by_id(user_id)
.await
.expect("retrieving user failed")
.unwrap();
cx.background()
.spawn(server.handle_connection(
server_conn,
client_name,
user,
Some(connection_id_tx),
Executor::Deterministic(cx.background()),
))
.detach();
let connection_id = connection_id_rx.await.unwrap();
connection_killers
.lock()
.insert(connection_id.into(), killed);
Ok(client_conn)
}
})
});
let fs = FakeFs::new(cx.background());
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
let channel_store =
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
let app_state = Arc::new(workspace::AppState {
client: client.clone(),
user_store: user_store.clone(),
channel_store: channel_store.clone(),
languages: Arc::new(LanguageRegistry::test()),
fs: fs.clone(),
build_window_options: |_, _, _| Default::default(),
initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
background_actions: || &[],
});
cx.update(|cx| {
theme::init((), cx);
Project::init(&client, cx);
client::init(&client, cx);
language::init(cx);
editor::init_settings(cx);
workspace::init(app_state.clone(), cx);
audio::init((), cx);
call::init(client.clone(), user_store.clone(), cx);
channel::init(&client);
});
client
.authenticate_and_connect(false, &cx.to_async())
.await
.unwrap();
let client = TestClient {
app_state,
username: name.to_string(),
state: Default::default(),
};
client.wait_for_current_user(cx).await;
client
}
fn disconnect_client(&self, peer_id: PeerId) {
self.connection_killers
.lock()
.remove(&peer_id)
.unwrap()
.store(true, SeqCst);
}
fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
for ix in 1..clients.len() {
let (left, right) = clients.split_at_mut(ix);
let (client_a, cx_a) = left.last_mut().unwrap();
for (client_b, cx_b) in right {
client_a
.app_state
.user_store
.update(*cx_a, |store, cx| {
store.request_contact(client_b.user_id().unwrap(), cx)
})
.await
.unwrap();
cx_a.foreground().run_until_parked();
client_b
.app_state
.user_store
.update(*cx_b, |store, cx| {
store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
})
.await
.unwrap();
}
}
}
async fn make_channel(
&self,
channel: &str,
admin: (&TestClient, &mut TestAppContext),
members: &mut [(&TestClient, &mut TestAppContext)],
) -> u64 {
let (admin_client, admin_cx) = admin;
let channel_id = admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.create_channel(channel, None, cx)
})
.await
.unwrap();
for (member_client, member_cx) in members {
admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.invite_member(
channel_id,
member_client.user_id().unwrap(),
false,
cx,
)
})
.await
.unwrap();
admin_cx.foreground().run_until_parked();
member_client
.app_state
.channel_store
.update(*member_cx, |channels, _| {
channels.respond_to_channel_invite(channel_id, true)
})
.await
.unwrap();
}
channel_id
}
async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
self.make_contacts(clients).await;
let (left, right) = clients.split_at_mut(1);
let (_client_a, cx_a) = &mut left[0];
let active_call_a = cx_a.read(ActiveCall::global);
for (client_b, cx_b) in right {
let user_id_b = client_b.current_user_id(*cx_b).to_proto();
active_call_a
.update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
.await
.unwrap();
cx_b.foreground().run_until_parked();
let active_call_b = cx_b.read(ActiveCall::global);
active_call_b
.update(*cx_b, |call, cx| call.accept_incoming(cx))
.await
.unwrap();
}
}
async fn build_app_state(
test_db: &TestDb,
fake_server: &live_kit_client::TestServer,
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
config: Default::default(),
})
}
}
impl Deref for TestServer {
type Target = Server;
fn deref(&self) -> &Self::Target {
&self.server
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.server.teardown();
self.test_live_kit_server.teardown().unwrap();
}
}
struct TestClient {
username: String,
state: RefCell<TestClientState>,
app_state: Arc<workspace::AppState>,
}
#[derive(Default)]
struct TestClientState {
local_projects: Vec<ModelHandle<Project>>,
remote_projects: Vec<ModelHandle<Project>>,
buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
}
impl Deref for TestClient {
type Target = Arc<Client>;
fn deref(&self) -> &Self::Target {
&self.app_state.client
}
}
struct ContactsSummary {
pub current: Vec<String>,
pub outgoing_requests: Vec<String>,
pub incoming_requests: Vec<String>,
}
impl TestClient {
pub fn fs(&self) -> &FakeFs {
self.app_state.fs.as_fake()
}
pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
&self.app_state.channel_store
}
pub fn user_store(&self) -> &ModelHandle<UserStore> {
&self.app_state.user_store
}
pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
&self.app_state.languages
}
pub fn client(&self) -> &Arc<Client> {
&self.app_state.client
}
pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
UserId::from_proto(
self.app_state
.user_store
.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
)
}
async fn wait_for_current_user(&self, cx: &TestAppContext) {
let mut authed_user = self
.app_state
.user_store
.read_with(cx, |user_store, _| user_store.watch_current_user());
while authed_user.next().await.unwrap().is_none() {}
}
async fn clear_contacts(&self, cx: &mut TestAppContext) {
self.app_state
.user_store
.update(cx, |store, _| store.clear_contacts())
.await;
}
fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
Ref::map(self.state.borrow(), |state| &state.local_projects)
}
fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
Ref::map(self.state.borrow(), |state| &state.remote_projects)
}
fn local_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
}
fn remote_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
}
fn buffers_for_project<'a>(
&'a self,
project: &ModelHandle<Project>,
) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| {
state.buffers.entry(project.clone()).or_default()
})
}
fn buffers<'a>(
&'a self,
) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
{
RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
}
fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
self.app_state
.user_store
.read_with(cx, |store, _| ContactsSummary {
current: store
.contacts()
.iter()
.map(|contact| contact.user.github_login.clone())
.collect(),
outgoing_requests: store
.outgoing_contact_requests()
.iter()
.map(|user| user.github_login.clone())
.collect(),
incoming_requests: store
.incoming_contact_requests()
.iter()
.map(|user| user.github_login.clone())
.collect(),
})
}
async fn build_local_project(
&self,
root_path: impl AsRef<Path>,
cx: &mut TestAppContext,
) -> (ModelHandle<Project>, WorktreeId) {
let project = cx.update(|cx| {
Project::local(
self.client().clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
});
let (worktree, _) = project
.update(cx, |p, cx| {
p.find_or_create_local_worktree(root_path, true, cx)
})
.await
.unwrap();
worktree
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
.await;
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
async fn build_remote_project(
&self,
host_project_id: u64,
guest_cx: &mut TestAppContext,
) -> ModelHandle<Project> {
let active_call = guest_cx.read(ActiveCall::global);
let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
room.update(guest_cx, |room, cx| {
room.join_project(
host_project_id,
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
})
.await
.unwrap()
}
fn build_workspace(
&self,
project: &ModelHandle<Project>,
cx: &mut TestAppContext,
) -> WindowHandle<Workspace> {
cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
}
}
impl Drop for TestClient {
fn drop(&mut self) {
self.app_state.client.teardown();
}
}
pub use randomized_test_helpers::{
run_randomized_test, save_randomized_test_plan, RandomizedTest, TestError, UserTestPlan,
};
pub use test_server::{TestClient, TestServer};
#[derive(Debug, Eq, PartialEq)]
struct RoomParticipants {

View File

@@ -1,4 +1,7 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use crate::{
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
tests::TestServer,
};
use call::ActiveCall;
use channel::Channel;
use client::UserId;
@@ -21,20 +24,19 @@ async fn test_core_channel_buffers(
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let zed_id = server
let channel_id = server
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
// Client A joins the channel buffer
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
// Client A edits the buffer
let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer());
buffer_a.update(cx_a, |buffer, cx| {
buffer.edit([(0..0, "hello world")], None, cx)
});
@@ -45,17 +47,15 @@ async fn test_core_channel_buffers(
buffer.edit([(0..5, "goodbye")], None, cx)
});
buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx));
deterministic.run_until_parked();
assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world");
deterministic.run_until_parked();
// Client B joins the channel buffer
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
channel_buffer_b.read_with(cx_b, |buffer, _| {
assert_collaborators(
buffer.collaborators(),
@@ -91,9 +91,7 @@ async fn test_core_channel_buffers(
// Client A rejoins the channel buffer
let _channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channels, cx| {
channels.open_channel_buffer(zed_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
deterministic.run_until_parked();
@@ -136,7 +134,7 @@ async fn test_channel_buffer_replica_ids(
let channel_id = server
.make_channel(
"zed",
"the-channel",
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
@@ -160,23 +158,17 @@ async fn test_channel_buffer_replica_ids(
// C first so that the replica IDs in the project and the channel buffer are different
let channel_buffer_c = client_c
.channel_store()
.update(cx_c, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
@@ -286,28 +278,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
let channel_id = server
.make_channel("the-channel", (&client_a, cx_a), &mut [])
.await;
let channel_buffer_1 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
let channel_buffer_2 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
let channel_buffer_3 = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
// All concurrent tasks for opening a channel buffer return the same model handle.
let (channel_buffer_1, channel_buffer_2, channel_buffer_3) =
let (channel_buffer, channel_buffer_2, channel_buffer_3) =
future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3)
.await
.unwrap();
let model_id = channel_buffer_1.id();
assert_eq!(channel_buffer_1, channel_buffer_2);
assert_eq!(channel_buffer_1, channel_buffer_3);
let channel_buffer_model_id = channel_buffer.id();
assert_eq!(channel_buffer, channel_buffer_2);
assert_eq!(channel_buffer, channel_buffer_3);
channel_buffer_1.update(cx_a, |buffer, cx| {
channel_buffer.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "hello")], None, cx);
})
@@ -315,7 +309,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
deterministic.run_until_parked();
cx_a.update(|_| {
drop(channel_buffer_1);
drop(channel_buffer);
drop(channel_buffer_2);
drop(channel_buffer_3);
});
@@ -324,10 +318,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
// The channel buffer can be reopened after dropping it.
let channel_buffer = client_a
.channel_store()
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
assert_ne!(channel_buffer.id(), model_id);
assert_ne!(channel_buffer.id(), channel_buffer_model_id);
channel_buffer.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, _| {
assert_eq!(buffer.text(), "hello");
@@ -347,22 +341,17 @@ async fn test_channel_buffer_disconnect(
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |channel, cx| {
channel.open_channel_buffer(channel_id, cx)
})
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
@@ -375,7 +364,7 @@ async fn test_channel_buffer_disconnect(
buffer.channel().as_ref(),
&Channel {
id: channel_id,
name: "zed".to_string()
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
@@ -403,13 +392,180 @@ async fn test_channel_buffer_disconnect(
buffer.channel().as_ref(),
&Channel {
id: channel_id,
name: "zed".to_string()
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
});
}
#[gpui::test]
async fn test_rejoin_channel_buffer(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "1")], None, cx);
})
});
deterministic.run_until_parked();
// Client A disconnects.
server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap());
// Both clients make an edit.
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(1..1, "2")], None, cx);
})
});
channel_buffer_b.update(cx_b, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "0")], None, cx);
})
});
// Both clients see their own edit.
deterministic.run_until_parked();
channel_buffer_a.read_with(cx_a, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "12");
});
channel_buffer_b.read_with(cx_b, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "01");
});
// Client A reconnects. Both clients see each other's edits, and see
// the same collaborators.
server.allow_connections();
deterministic.advance_clock(RECEIVE_TIMEOUT);
channel_buffer_a.read_with(cx_a, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
channel_buffer_b.read_with(cx_b, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
channel_buffer_a.read_with(cx_a, |buffer_a, _| {
channel_buffer_b.read_with(cx_b, |buffer_b, _| {
assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
});
});
}
#[gpui::test]
async fn test_channel_buffers_and_server_restarts(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
let channel_id = server
.make_channel(
"the-channel",
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let _channel_buffer_c = client_c
.channel_store()
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "1")], None, cx);
})
});
deterministic.run_until_parked();
// Client C can't reconnect.
client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending()));
// Server stops.
server.reset().await;
deterministic.advance_clock(RECEIVE_TIMEOUT);
// While the server is down, both clients make an edit.
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(1..1, "2")], None, cx);
})
});
channel_buffer_b.update(cx_b, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "0")], None, cx);
})
});
// Server restarts.
server.start().await.unwrap();
deterministic.advance_clock(CLEANUP_TIMEOUT);
// Clients reconnect. Clients A and B see each other's edits, and see
// that client C has disconnected.
channel_buffer_a.read_with(cx_a, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
channel_buffer_b.read_with(cx_b, |buffer, cx| {
assert_eq!(buffer.buffer().read(cx).text(), "012");
});
channel_buffer_a.read_with(cx_a, |buffer_a, _| {
channel_buffer_b.read_with(cx_b, |buffer_b, _| {
assert_eq!(
buffer_a
.collaborators()
.iter()
.map(|c| c.user_id)
.collect::<Vec<_>>(),
vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()]
);
assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
});
});
}
#[track_caller]
fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
assert_eq!(

View File

@@ -874,6 +874,143 @@ async fn test_lost_channel_creation(
);
}
#[gpui::test]
async fn test_channel_moving(deterministic: Arc<Deterministic>, cx_a: &mut TestAppContext) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let channel_a_id = client_a
.channel_store()
.update(cx_a, |channel_store, cx| {
channel_store.create_channel("channel-a", None, cx)
})
.await
.unwrap();
let channel_b_id = client_a
.channel_store()
.update(cx_a, |channel_store, cx| {
channel_store.create_channel("channel-b", Some(channel_a_id), cx)
})
.await
.unwrap();
let channel_c_id = client_a
.channel_store()
.update(cx_a, |channel_store, cx| {
channel_store.create_channel("channel-c", Some(channel_b_id), cx)
})
.await
.unwrap();
// Current shape:
// a - b - c
deterministic.run_until_parked();
assert_channels(
client_a.channel_store(),
cx_a,
&[
ExpectedChannel {
id: channel_a_id,
name: "channel-a".to_string(),
depth: 0,
user_is_admin: true,
},
ExpectedChannel {
id: channel_b_id,
name: "channel-b".to_string(),
depth: 1,
user_is_admin: true,
},
ExpectedChannel {
id: channel_c_id,
name: "channel-c".to_string(),
depth: 2,
user_is_admin: true,
},
],
);
client_a
.channel_store()
.update(cx_a, |channel_store, cx| {
channel_store.move_channel(channel_c_id, Some(channel_b_id), channel_a_id, cx)
})
.await
.unwrap();
// Current shape:
// /- c
// a -- b
deterministic.run_until_parked();
assert_channels(
client_a.channel_store(),
cx_a,
&[
ExpectedChannel {
id: channel_a_id,
name: "channel-a".to_string(),
depth: 0,
user_is_admin: true,
},
ExpectedChannel {
id: channel_b_id,
name: "channel-b".to_string(),
depth: 1,
user_is_admin: true,
},
ExpectedChannel {
id: channel_c_id,
name: "channel-c".to_string(),
depth: 1,
user_is_admin: true,
},
],
);
client_a
.channel_store()
.update(cx_a, |channel_store, cx| {
channel_store.link_channel(channel_c_id, channel_b_id, cx)
})
.await
.unwrap();
// Current shape:
// /------\
// a -- b -- c
deterministic.run_until_parked();
assert_channels(
client_a.channel_store(),
cx_a,
&[
ExpectedChannel {
id: channel_a_id,
name: "channel-a".to_string(),
depth: 0,
user_is_admin: true,
},
ExpectedChannel {
id: channel_b_id,
name: "channel-b".to_string(),
depth: 1,
user_is_admin: true,
},
ExpectedChannel {
id: channel_c_id,
name: "channel-c".to_string(),
depth: 2,
user_is_admin: true,
},
ExpectedChannel {
id: channel_c_id,
name: "channel-c".to_string(),
depth: 1,
user_is_admin: true,
},
],
);
}
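After the final link, channel-c has two parents (channel-a and channel-b), so the flattened list asserted above contains it once per path: at depth 2 under channel-b and at depth 1 directly under channel-a. An illustrative sketch of the depth-first flattening that yields this shape, assuming a simple parent -> children map (not the actual ChannelStore rendering code):
use std::collections::HashMap;
// Illustrative DFS that flattens a channel DAG into (depth, id) rows; a
// channel reachable through two parents appears once per path.
fn flatten_channels(children: &HashMap<u64, Vec<u64>>, roots: &[u64]) -> Vec<(usize, u64)> {
    fn visit(children: &HashMap<u64, Vec<u64>>, id: u64, depth: usize, out: &mut Vec<(usize, u64)>) {
        out.push((depth, id));
        if let Some(child_ids) = children.get(&id) {
            for &child_id in child_ids {
                visit(children, child_id, depth + 1, out);
            }
        }
    }
    let mut out = Vec::new();
    for &root in roots {
        visit(children, root, 0, &mut out);
    }
    out
}
// With edges a -> [b, c] and b -> [c], flatten_channels(&children, &[a]) yields
// [(0, a), (1, b), (2, c), (1, c)], matching the expected channels above.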
#[derive(Debug, PartialEq)]
struct ExpectedChannel {
depth: usize,
@@ -920,5 +1057,5 @@ fn assert_channels(
})
.collect::<Vec<_>>()
});
assert_eq!(actual, expected_channels);
pretty_assertions::assert_eq!(actual, expected_channels);
}

View File

@@ -0,0 +1,288 @@
use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{executor::Deterministic, TestAppContext};
use rand::prelude::*;
use serde_derive::{Deserialize, Serialize};
use std::{ops::Range, rc::Rc, sync::Arc};
use text::Bias;
#[gpui::test(
iterations = 100,
on_failure = "crate::tests::save_randomized_test_plan"
)]
async fn test_random_channel_buffers(
cx: &mut TestAppContext,
deterministic: Arc<Deterministic>,
rng: StdRng,
) {
run_randomized_test::<RandomChannelBufferTest>(cx, deterministic, rng).await;
}
struct RandomChannelBufferTest;
#[derive(Clone, Serialize, Deserialize)]
enum ChannelBufferOperation {
JoinChannelNotes {
channel_name: String,
},
LeaveChannelNotes {
channel_name: String,
},
EditChannelNotes {
channel_name: String,
edits: Vec<(Range<usize>, Arc<str>)>,
},
Noop,
}
const CHANNEL_COUNT: usize = 3;
#[async_trait(?Send)]
impl RandomizedTest for RandomChannelBufferTest {
type Operation = ChannelBufferOperation;
async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) {
let db = &server.app_state.db;
for ix in 0..CHANNEL_COUNT {
let id = db
.create_channel(
&format!("channel-{ix}"),
None,
&format!("livekit-room-{ix}"),
users[0].user_id,
)
.await
.unwrap();
for user in &users[1..] {
db.invite_channel_member(id, user.user_id, users[0].user_id, false)
.await
.unwrap();
db.respond_to_channel_invite(id, user.user_id, true)
.await
.unwrap();
}
}
}
fn generate_operation(
client: &TestClient,
rng: &mut StdRng,
_: &mut UserTestPlan,
cx: &TestAppContext,
) -> ChannelBufferOperation {
let channel_store = client.channel_store().clone();
let channel_buffers = client.channel_buffers();
// When signed out, we can't do anything unless a channel buffer is
// already open.
if channel_buffers.is_empty()
&& channel_store.read_with(cx, |store, _| store.channel_count() == 0)
{
return ChannelBufferOperation::Noop;
}
loop {
match rng.gen_range(0..100_u32) {
0..=29 => {
let channel_name = client.channel_store().read_with(cx, |store, cx| {
store.channels().find_map(|(_, channel)| {
if store.has_open_channel_buffer(channel.id, cx) {
None
} else {
Some(channel.name.clone())
}
})
});
if let Some(channel_name) = channel_name {
break ChannelBufferOperation::JoinChannelNotes { channel_name };
}
}
30..=40 => {
if let Some(buffer) = channel_buffers.iter().choose(rng) {
let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone());
break ChannelBufferOperation::LeaveChannelNotes { channel_name };
}
}
_ => {
if let Some(buffer) = channel_buffers.iter().choose(rng) {
break buffer.read_with(cx, |b, _| {
let channel_name = b.channel().name.clone();
let edits = b
.buffer()
.read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3));
ChannelBufferOperation::EditChannelNotes {
channel_name,
edits,
}
});
}
}
}
}
}
async fn apply_operation(
client: &TestClient,
operation: ChannelBufferOperation,
cx: &mut TestAppContext,
) -> Result<(), TestError> {
match operation {
ChannelBufferOperation::JoinChannelNotes { channel_name } => {
let buffer = client.channel_store().update(cx, |store, cx| {
let channel_id = store
.channels()
.find(|(_, c)| c.name == channel_name)
.unwrap()
.1
.id;
if store.has_open_channel_buffer(channel_id, cx) {
Err(TestError::Inapplicable)
} else {
Ok(store.open_channel_buffer(channel_id, cx))
}
})?;
log::info!(
"{}: opening notes for channel {channel_name}",
client.username
);
client.channel_buffers().insert(buffer.await?);
}
ChannelBufferOperation::LeaveChannelNotes { channel_name } => {
let buffer = cx.update(|cx| {
let mut left_buffer = Err(TestError::Inapplicable);
client.channel_buffers().retain(|buffer| {
if buffer.read(cx).channel().name == channel_name {
left_buffer = Ok(buffer.clone());
false
} else {
true
}
});
left_buffer
})?;
log::info!(
"{}: closing notes for channel {channel_name}",
client.username
);
cx.update(|_| drop(buffer));
}
ChannelBufferOperation::EditChannelNotes {
channel_name,
edits,
} => {
let channel_buffer = cx
.read(|cx| {
client
.channel_buffers()
.iter()
.find(|buffer| buffer.read(cx).channel().name == channel_name)
.cloned()
})
.ok_or_else(|| TestError::Inapplicable)?;
log::info!(
"{}: editing notes for channel {channel_name} with {:?}",
client.username,
edits
);
channel_buffer.update(cx, |buffer, cx| {
let buffer = buffer.buffer();
buffer.update(cx, |buffer, cx| {
let snapshot = buffer.snapshot();
buffer.edit(
edits.into_iter().map(|(range, text)| {
let start = snapshot.clip_offset(range.start, Bias::Left);
let end = snapshot.clip_offset(range.end, Bias::Right);
(start..end, text)
}),
None,
cx,
);
});
});
}
ChannelBufferOperation::Noop => Err(TestError::Inapplicable)?,
}
Ok(())
}
async fn on_client_added(client: &Rc<TestClient>, cx: &mut TestAppContext) {
let channel_store = client.channel_store();
while channel_store.read_with(cx, |store, _| store.channel_count() == 0) {
channel_store.next_notification(cx).await;
}
}
async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc<TestClient>, TestAppContext)]) {
let channels = server.app_state.db.all_channels().await.unwrap();
for (client, client_cx) in clients.iter_mut() {
client_cx.update(|cx| {
client
.channel_buffers()
.retain(|b| b.read(cx).is_connected());
});
}
for (channel_id, channel_name) in channels {
let mut prev_text: Option<(u64, String)> = None;
let mut collaborator_user_ids = server
.app_state
.db
.get_channel_buffer_collaborators(channel_id)
.await
.unwrap()
.into_iter()
.map(|id| id.to_proto())
.collect::<Vec<_>>();
collaborator_user_ids.sort();
for (client, client_cx) in clients.iter() {
let user_id = client.user_id().unwrap();
client_cx.read(|cx| {
if let Some(channel_buffer) = client
.channel_buffers()
.iter()
.find(|b| b.read(cx).channel().id == channel_id.to_proto())
{
let channel_buffer = channel_buffer.read(cx);
// Assert that the channel buffer's text matches other clients' copies.
let text = channel_buffer.buffer().read(cx).text();
if let Some((prev_user_id, prev_text)) = &prev_text {
assert_eq!(
&text,
prev_text,
"client {user_id} has different text than client {prev_user_id} for channel {channel_name}",
);
} else {
prev_text = Some((user_id, text.clone()));
}
// Assert that all clients and the server agree about who is present in the
// channel buffer.
let collaborators = channel_buffer.collaborators();
let mut user_ids =
collaborators.iter().map(|c| c.user_id).collect::<Vec<_>>();
user_ids.sort();
assert_eq!(
user_ids,
collaborator_user_ids,
"client {user_id} has different user ids for channel {channel_name} than the server",
);
}
});
}
}
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large
View File

@@ -0,0 +1,689 @@
use crate::{
db::{self, NewUserParams, UserId},
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
tests::{TestClient, TestServer},
};
use async_trait::async_trait;
use futures::StreamExt;
use gpui::{executor::Deterministic, Task, TestAppContext};
use parking_lot::Mutex;
use rand::prelude::*;
use rpc::RECEIVE_TIMEOUT;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use settings::SettingsStore;
use std::{
env,
path::PathBuf,
rc::Rc,
sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
},
};
lazy_static::lazy_static! {
static ref PLAN_LOAD_PATH: Option<PathBuf> = path_env_var("LOAD_PLAN");
static ref PLAN_SAVE_PATH: Option<PathBuf> = path_env_var("SAVE_PLAN");
static ref MAX_PEERS: usize = env::var("MAX_PEERS")
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
.unwrap_or(3);
static ref MAX_OPERATIONS: usize = env::var("OPERATIONS")
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(10);
}
static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
struct TestPlan<T: RandomizedTest> {
rng: StdRng,
replay: bool,
stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
max_operations: usize,
operation_ix: usize,
users: Vec<UserTestPlan>,
next_batch_id: usize,
allow_server_restarts: bool,
allow_client_reconnection: bool,
allow_client_disconnection: bool,
}
pub struct UserTestPlan {
pub user_id: UserId,
pub username: String,
pub allow_client_reconnection: bool,
pub allow_client_disconnection: bool,
next_root_id: usize,
operation_ix: usize,
online: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StoredOperation<T> {
Server(ServerOperation),
Client {
user_id: UserId,
batch_id: usize,
operation: T,
},
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum ServerOperation {
AddConnection {
user_id: UserId,
},
RemoveConnection {
user_id: UserId,
},
BounceConnection {
user_id: UserId,
},
RestartServer,
MutateClients {
batch_id: usize,
#[serde(skip_serializing)]
#[serde(skip_deserializing)]
user_ids: Vec<UserId>,
quiesce: bool,
},
}
pub enum TestError {
Inapplicable,
Other(anyhow::Error),
}
#[async_trait(?Send)]
pub trait RandomizedTest: 'static + Sized {
type Operation: Send + Clone + Serialize + DeserializeOwned;
fn generate_operation(
client: &TestClient,
rng: &mut StdRng,
plan: &mut UserTestPlan,
cx: &TestAppContext,
) -> Self::Operation;
async fn apply_operation(
client: &TestClient,
operation: Self::Operation,
cx: &mut TestAppContext,
) -> Result<(), TestError>;
async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);
async fn on_client_added(client: &Rc<TestClient>, cx: &mut TestAppContext);
async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
}
pub async fn run_randomized_test<T: RandomizedTest>(
cx: &mut TestAppContext,
deterministic: Arc<Deterministic>,
rng: StdRng,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let plan = TestPlan::<T>::new(&mut server, rng).await;
LAST_PLAN.lock().replace({
let plan = plan.clone();
Box::new(move || plan.lock().serialize())
});
let mut clients = Vec::new();
let mut client_tasks = Vec::new();
let mut operation_channels = Vec::new();
loop {
let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
break;
};
applied.store(true, SeqCst);
let did_apply = TestPlan::apply_server_operation(
plan.clone(),
deterministic.clone(),
&mut server,
&mut clients,
&mut client_tasks,
&mut operation_channels,
next_operation,
cx,
)
.await;
if !did_apply {
applied.store(false, SeqCst);
}
}
drop(operation_channels);
deterministic.start_waiting();
futures::future::join_all(client_tasks).await;
deterministic.finish_waiting();
deterministic.run_until_parked();
T::on_quiesce(&mut server, &mut clients).await;
for (client, mut cx) in clients {
cx.update(|cx| {
let store = cx.remove_global::<SettingsStore>();
cx.clear_globals();
cx.set_global(store);
drop(client);
});
}
deterministic.run_until_parked();
if let Some(path) = &*PLAN_SAVE_PATH {
eprintln!("saved test plan to path {:?}", path);
std::fs::write(path, plan.lock().serialize()).unwrap();
}
}
pub fn save_randomized_test_plan() {
if let Some(serialize_plan) = LAST_PLAN.lock().take() {
if let Some(path) = &*PLAN_SAVE_PATH {
eprintln!("saved test plan to path {:?}", path);
std::fs::write(path, serialize_plan()).unwrap();
}
}
}
impl<T: RandomizedTest> TestPlan<T> {
pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
let allow_server_restarts = rng.gen_bool(0.7);
let allow_client_reconnection = rng.gen_bool(0.7);
let allow_client_disconnection = rng.gen_bool(0.1);
let mut users = Vec::new();
for ix in 0..*MAX_PEERS {
let username = format!("user-{}", ix + 1);
let user_id = server
.app_state
.db
.create_user(
&format!("{username}@example.com"),
false,
NewUserParams {
github_login: username.clone(),
github_user_id: (ix + 1) as i32,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
users.push(UserTestPlan {
user_id,
username,
online: false,
next_root_id: 0,
operation_ix: 0,
allow_client_disconnection,
allow_client_reconnection,
});
}
T::initialize(server, &users).await;
let plan = Arc::new(Mutex::new(Self {
replay: false,
allow_server_restarts,
allow_client_reconnection,
allow_client_disconnection,
stored_operations: Vec::new(),
operation_ix: 0,
next_batch_id: 0,
max_operations: *MAX_OPERATIONS,
users,
rng,
}));
if let Some(path) = &*PLAN_LOAD_PATH {
let json = LOADED_PLAN_JSON
.lock()
.get_or_insert_with(|| {
eprintln!("loaded test plan from path {:?}", path);
std::fs::read(path).unwrap()
})
.clone();
plan.lock().deserialize(json);
}
plan
}
fn deserialize(&mut self, json: Vec<u8>) {
let stored_operations: Vec<StoredOperation<T::Operation>> =
serde_json::from_slice(&json).unwrap();
self.replay = true;
self.stored_operations = stored_operations
.iter()
.cloned()
.enumerate()
.map(|(i, mut operation)| {
let did_apply = Arc::new(AtomicBool::new(false));
if let StoredOperation::Server(ServerOperation::MutateClients {
batch_id: current_batch_id,
user_ids,
..
}) = &mut operation
{
assert!(user_ids.is_empty());
user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
if let StoredOperation::Client {
user_id, batch_id, ..
} = operation
{
if batch_id == current_batch_id {
return Some(user_id);
}
}
None
}));
user_ids.sort_unstable();
}
(operation, did_apply)
})
.collect()
}
fn serialize(&mut self) -> Vec<u8> {
// Format each operation as one line
let mut json = Vec::new();
json.push(b'[');
for (operation, applied) in &self.stored_operations {
if !applied.load(SeqCst) {
continue;
}
if json.len() > 1 {
json.push(b',');
}
json.extend_from_slice(b"\n ");
serde_json::to_writer(&mut json, operation).unwrap();
}
json.extend_from_slice(b"\n]\n");
json
}
fn next_server_operation(
&mut self,
clients: &[(Rc<TestClient>, TestAppContext)],
) -> Option<(ServerOperation, Arc<AtomicBool>)> {
if self.replay {
while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
self.operation_ix += 1;
if let (StoredOperation::Server(operation), applied) = stored_operation {
return Some((operation.clone(), applied.clone()));
}
}
None
} else {
let operation = self.generate_server_operation(clients)?;
let applied = Arc::new(AtomicBool::new(false));
self.stored_operations
.push((StoredOperation::Server(operation.clone()), applied.clone()));
Some((operation, applied))
}
}
fn next_client_operation(
&mut self,
client: &TestClient,
current_batch_id: usize,
cx: &TestAppContext,
) -> Option<(T::Operation, Arc<AtomicBool>)> {
let current_user_id = client.current_user_id(cx);
let user_ix = self
.users
.iter()
.position(|user| user.user_id == current_user_id)
.unwrap();
let user_plan = &mut self.users[user_ix];
if self.replay {
while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
user_plan.operation_ix += 1;
if let (
StoredOperation::Client {
user_id, operation, ..
},
applied,
) = stored_operation
{
if user_id == &current_user_id {
return Some((operation.clone(), applied.clone()));
}
}
}
None
} else {
if self.operation_ix == self.max_operations {
return None;
}
self.operation_ix += 1;
let operation = T::generate_operation(
client,
&mut self.rng,
self.users
.iter_mut()
.find(|user| user.user_id == current_user_id)
.unwrap(),
cx,
);
let applied = Arc::new(AtomicBool::new(false));
self.stored_operations.push((
StoredOperation::Client {
user_id: current_user_id,
batch_id: current_batch_id,
operation: operation.clone(),
},
applied.clone(),
));
Some((operation, applied))
}
}
fn generate_server_operation(
&mut self,
clients: &[(Rc<TestClient>, TestAppContext)],
) -> Option<ServerOperation> {
if self.operation_ix == self.max_operations {
return None;
}
Some(loop {
break match self.rng.gen_range(0..100) {
0..=29 if clients.len() < self.users.len() => {
let user = self
.users
.iter()
.filter(|u| !u.online)
.choose(&mut self.rng)
.unwrap();
self.operation_ix += 1;
ServerOperation::AddConnection {
user_id: user.user_id,
}
}
30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
let user_id = client.current_user_id(cx);
self.operation_ix += 1;
ServerOperation::RemoveConnection { user_id }
}
35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
let user_id = client.current_user_id(cx);
self.operation_ix += 1;
ServerOperation::BounceConnection { user_id }
}
40..=44 if self.allow_server_restarts && clients.len() > 1 => {
self.operation_ix += 1;
ServerOperation::RestartServer
}
_ if !clients.is_empty() => {
let count = self
.rng
.gen_range(1..10)
.min(self.max_operations - self.operation_ix);
let batch_id = util::post_inc(&mut self.next_batch_id);
let mut user_ids = (0..count)
.map(|_| {
let ix = self.rng.gen_range(0..clients.len());
let (client, cx) = &clients[ix];
client.current_user_id(cx)
})
.collect::<Vec<_>>();
user_ids.sort_unstable();
ServerOperation::MutateClients {
user_ids,
batch_id,
quiesce: self.rng.gen_bool(0.7),
}
}
_ => continue,
};
})
}
async fn apply_server_operation(
plan: Arc<Mutex<Self>>,
deterministic: Arc<Deterministic>,
server: &mut TestServer,
clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
client_tasks: &mut Vec<Task<()>>,
operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
operation: ServerOperation,
cx: &mut TestAppContext,
) -> bool {
match operation {
ServerOperation::AddConnection { user_id } => {
let username;
{
let mut plan = plan.lock();
let user = plan.user(user_id);
if user.online {
return false;
}
user.online = true;
username = user.username.clone();
};
log::info!("adding new connection for {}", username);
let next_entity_id = (user_id.0 * 10_000) as usize;
let mut client_cx = TestAppContext::new(
cx.foreground_platform(),
cx.platform(),
deterministic.build_foreground(user_id.0 as usize),
deterministic.build_background(),
cx.font_cache(),
cx.leak_detector(),
next_entity_id,
cx.function_name.clone(),
);
let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
let client = Rc::new(server.create_client(&mut client_cx, &username).await);
operation_channels.push(operation_tx);
clients.push((client.clone(), client_cx.clone()));
client_tasks.push(client_cx.foreground().spawn(Self::simulate_client(
plan.clone(),
client,
operation_rx,
client_cx,
)));
log::info!("added connection for {}", username);
}
ServerOperation::RemoveConnection {
user_id: removed_user_id,
} => {
log::info!("simulating full disconnection of user {}", removed_user_id);
let client_ix = clients
.iter()
.position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
let Some(client_ix) = client_ix else {
return false;
};
let user_connection_ids = server
.connection_pool
.lock()
.user_connection_ids(removed_user_id)
.collect::<Vec<_>>();
assert_eq!(user_connection_ids.len(), 1);
let removed_peer_id = user_connection_ids[0].into();
let (client, mut client_cx) = clients.remove(client_ix);
let client_task = client_tasks.remove(client_ix);
operation_channels.remove(client_ix);
server.forbid_connections();
server.disconnect_client(removed_peer_id);
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
deterministic.start_waiting();
log::info!("waiting for user {} to exit...", removed_user_id);
client_task.await;
deterministic.finish_waiting();
server.allow_connections();
for project in client.remote_projects().iter() {
project.read_with(&client_cx, |project, _| {
assert!(
project.is_read_only(),
"project {:?} should be read only",
project.remote_id()
)
});
}
for (client, cx) in clients {
let contacts = server
.app_state
.db
.get_contacts(client.current_user_id(cx))
.await
.unwrap();
let pool = server.connection_pool.lock();
for contact in contacts {
if let db::Contact::Accepted { user_id, busy, .. } = contact {
if user_id == removed_user_id {
assert!(!pool.is_user_online(user_id));
assert!(!busy);
}
}
}
}
log::info!("{} removed", client.username);
plan.lock().user(removed_user_id).online = false;
client_cx.update(|cx| {
cx.clear_globals();
drop(client);
});
}
ServerOperation::BounceConnection { user_id } => {
log::info!("simulating temporary disconnection of user {}", user_id);
let user_connection_ids = server
.connection_pool
.lock()
.user_connection_ids(user_id)
.collect::<Vec<_>>();
if user_connection_ids.is_empty() {
return false;
}
assert_eq!(user_connection_ids.len(), 1);
let peer_id = user_connection_ids[0].into();
server.disconnect_client(peer_id);
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
}
ServerOperation::RestartServer => {
log::info!("simulating server restart");
server.reset().await;
deterministic.advance_clock(RECEIVE_TIMEOUT);
server.start().await.unwrap();
deterministic.advance_clock(CLEANUP_TIMEOUT);
let environment = &server.app_state.config.zed_environment;
let (stale_room_ids, _) = server
.app_state
.db
.stale_server_resource_ids(environment, server.id())
.await
.unwrap();
assert_eq!(stale_room_ids, vec![]);
}
ServerOperation::MutateClients {
user_ids,
batch_id,
quiesce,
} => {
let mut applied = false;
for user_id in user_ids {
let client_ix = clients
.iter()
.position(|(client, cx)| client.current_user_id(cx) == user_id);
let Some(client_ix) = client_ix else { continue };
applied = true;
if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
log::error!("error signaling user {user_id}: {err}");
}
}
if quiesce && applied {
deterministic.run_until_parked();
T::on_quiesce(server, clients).await;
}
return applied;
}
}
true
}
async fn simulate_client(
plan: Arc<Mutex<Self>>,
client: Rc<TestClient>,
mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
mut cx: TestAppContext,
) {
T::on_client_added(&client, &mut cx).await;
while let Some(batch_id) = operation_rx.next().await {
let Some((operation, applied)) =
plan.lock().next_client_operation(&client, batch_id, &cx)
else {
break;
};
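// Mark the operation as applied up front; it is reset below if the client
// reports the operation as inapplicable.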
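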
applied.store(true, SeqCst);
match T::apply_operation(&client, operation, &mut cx).await {
Ok(()) => {}
Err(TestError::Inapplicable) => {
applied.store(false, SeqCst);
log::info!("skipped operation");
}
Err(TestError::Other(error)) => {
log::error!("{} error: {}", client.username, error);
}
}
cx.background().simulate_random_delay().await;
}
log::info!("{}: done", client.username);
}
fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
self.users
.iter_mut()
.find(|user| user.user_id == user_id)
.unwrap()
}
}
impl UserTestPlan {
pub fn next_root_dir_name(&mut self) -> String {
let user_id = self.user_id;
let root_id = util::post_inc(&mut self.next_root_id);
format!("dir-{user_id}-{root_id}")
}
}
impl From<anyhow::Error> for TestError {
fn from(value: anyhow::Error) -> Self {
Self::Other(value)
}
}
fn path_env_var(name: &str) -> Option<PathBuf> {
let value = env::var(name).ok()?;
let mut path = PathBuf::from(value);
if path.is_relative() {
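// Resolve relative paths two directories above this crate's CARGO_MANIFEST_DIR.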
let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
abs_path.pop();
abs_path.pop();
abs_path.push(path);
path = abs_path
}
Some(path)
}

View File

@@ -0,0 +1,558 @@
use crate::{
db::{tests::TestDb, NewUserParams, UserId},
executor::Executor,
rpc::{Server, CLEANUP_TIMEOUT},
AppState,
};
use anyhow::anyhow;
use call::ActiveCall;
use channel::{channel_buffer::ChannelBuffer, ChannelStore};
use client::{
self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
};
use collections::{HashMap, HashSet};
use fs::FakeFs;
use futures::{channel::oneshot, StreamExt as _};
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use settings::SettingsStore;
use std::{
cell::{Ref, RefCell, RefMut},
env,
ops::{Deref, DerefMut},
path::Path,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
Arc,
},
};
use util::http::FakeHttpClient;
use workspace::Workspace;
pub struct TestServer {
pub app_state: Arc<AppState>,
pub test_live_kit_server: Arc<live_kit_client::TestServer>,
server: Arc<Server>,
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
pub struct TestClient {
pub username: String,
pub app_state: Arc<workspace::AppState>,
state: RefCell<TestClientState>,
}
#[derive(Default)]
struct TestClientState {
local_projects: Vec<ModelHandle<Project>>,
remote_projects: Vec<ModelHandle<Project>>,
buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
}
pub struct ContactsSummary {
pub current: Vec<String>,
pub outgoing_requests: Vec<String>,
pub incoming_requests: Vec<String>,
}
impl TestServer {
pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
let use_postgres = env::var("USE_POSTGRES").ok();
let use_postgres = use_postgres.as_deref();
let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
TestDb::postgres(deterministic.build_background())
} else {
TestDb::sqlite(deterministic.build_background())
};
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
let live_kit_server = live_kit_client::TestServer::create(
format!("http://livekit.{}.test", live_kit_server_id),
format!("devkey-{}", live_kit_server_id),
format!("secret-{}", live_kit_server_id),
deterministic.build_background(),
)
.unwrap();
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
let epoch = app_state
.db
.create_server(&app_state.config.zed_environment)
.await
.unwrap();
let server = Server::new(
epoch,
app_state.clone(),
Executor::Deterministic(deterministic.build_background()),
);
server.start().await.unwrap();
// Advance clock to ensure the server's cleanup task is finished.
deterministic.advance_clock(CLEANUP_TIMEOUT);
Self {
app_state,
server,
connection_killers: Default::default(),
forbid_connections: Default::default(),
_test_db: test_db,
test_live_kit_server: live_kit_server,
}
}
pub async fn reset(&self) {
self.app_state.db.reset();
let epoch = self
.app_state
.db
.create_server(&self.app_state.config.zed_environment)
.await
.unwrap();
self.server.reset(epoch);
}
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
cx.update(|cx| {
if cx.has_global::<SettingsStore>() {
panic!("Same cx used to create two test clients")
}
cx.set_global(SettingsStore::test(cx));
});
let http = FakeHttpClient::with_404_response();
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
{
user.id
} else {
self.app_state
.db
.create_user(
&format!("{name}@example.com"),
false,
NewUserParams {
github_login: name.into(),
github_user_id: 0,
invite_count: 0,
},
)
.await
.expect("creating user failed")
.user_id
};
let client_name = name.to_string();
let mut client = cx.read(|cx| Client::new(http.clone(), cx));
let server = self.server.clone();
let db = self.app_state.db.clone();
let connection_killers = self.connection_killers.clone();
let forbid_connections = self.forbid_connections.clone();
Arc::get_mut(&mut client)
.unwrap()
.set_id(user_id.0 as usize)
.override_authenticate(move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok(Credentials {
user_id: user_id.0 as u64,
access_token,
})
})
})
.override_establish_connection(move |credentials, cx| {
assert_eq!(credentials.user_id, user_id.0 as u64);
assert_eq!(credentials.access_token, "the-token");
let server = server.clone();
let db = db.clone();
let connection_killers = connection_killers.clone();
let forbid_connections = forbid_connections.clone();
let client_name = client_name.clone();
cx.spawn(move |cx| async move {
if forbid_connections.load(SeqCst) {
Err(EstablishConnectionError::other(anyhow!(
"server is forbidding connections"
)))
} else {
let (client_conn, server_conn, killed) =
Connection::in_memory(cx.background());
let (connection_id_tx, connection_id_rx) = oneshot::channel();
let user = db
.get_user_by_id(user_id)
.await
.expect("retrieving user failed")
.unwrap();
cx.background()
.spawn(server.handle_connection(
server_conn,
client_name,
user,
Some(connection_id_tx),
Executor::Deterministic(cx.background()),
))
.detach();
let connection_id = connection_id_rx.await.unwrap();
connection_killers
.lock()
.insert(connection_id.into(), killed);
Ok(client_conn)
}
})
});
let fs = FakeFs::new(cx.background());
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
let channel_store =
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
let app_state = Arc::new(workspace::AppState {
client: client.clone(),
user_store: user_store.clone(),
channel_store: channel_store.clone(),
languages: Arc::new(LanguageRegistry::test()),
fs: fs.clone(),
build_window_options: |_, _, _| Default::default(),
initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
background_actions: || &[],
});
cx.update(|cx| {
theme::init((), cx);
Project::init(&client, cx);
client::init(&client, cx);
language::init(cx);
editor::init_settings(cx);
workspace::init(app_state.clone(), cx);
audio::init((), cx);
call::init(client.clone(), user_store.clone(), cx);
channel::init(&client);
});
client
.authenticate_and_connect(false, &cx.to_async())
.await
.unwrap();
let client = TestClient {
app_state,
username: name.to_string(),
state: Default::default(),
};
client.wait_for_current_user(cx).await;
client
}
pub fn disconnect_client(&self, peer_id: PeerId) {
self.connection_killers
.lock()
.remove(&peer_id)
.unwrap()
.store(true, SeqCst);
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
pub fn allow_connections(&self) {
self.forbid_connections.store(false, SeqCst);
}
pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
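// Make every pair of clients mutual contacts: each client sends a request to
// every client after it in the slice, and the recipient accepts it.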
for ix in 1..clients.len() {
let (left, right) = clients.split_at_mut(ix);
let (client_a, cx_a) = left.last_mut().unwrap();
for (client_b, cx_b) in right {
client_a
.app_state
.user_store
.update(*cx_a, |store, cx| {
store.request_contact(client_b.user_id().unwrap(), cx)
})
.await
.unwrap();
cx_a.foreground().run_until_parked();
client_b
.app_state
.user_store
.update(*cx_b, |store, cx| {
store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
})
.await
.unwrap();
}
}
}
pub async fn make_channel(
&self,
channel: &str,
admin: (&TestClient, &mut TestAppContext),
members: &mut [(&TestClient, &mut TestAppContext)],
) -> u64 {
let (admin_client, admin_cx) = admin;
let channel_id = admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.create_channel(channel, None, cx)
})
.await
.unwrap();
for (member_client, member_cx) in members {
admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.invite_member(
channel_id,
member_client.user_id().unwrap(),
false,
cx,
)
})
.await
.unwrap();
admin_cx.foreground().run_until_parked();
member_client
.app_state
.channel_store
.update(*member_cx, |channels, _| {
channels.respond_to_channel_invite(channel_id, true)
})
.await
.unwrap();
}
channel_id
}
pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
self.make_contacts(clients).await;
let (left, right) = clients.split_at_mut(1);
let (_client_a, cx_a) = &mut left[0];
let active_call_a = cx_a.read(ActiveCall::global);
for (client_b, cx_b) in right {
let user_id_b = client_b.current_user_id(*cx_b).to_proto();
active_call_a
.update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
.await
.unwrap();
cx_b.foreground().run_until_parked();
let active_call_b = cx_b.read(ActiveCall::global);
active_call_b
.update(*cx_b, |call, cx| call.accept_incoming(cx))
.await
.unwrap();
}
}
pub async fn build_app_state(
test_db: &TestDb,
fake_server: &live_kit_client::TestServer,
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
config: Default::default(),
})
}
}
impl Deref for TestServer {
type Target = Server;
fn deref(&self) -> &Self::Target {
&self.server
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.server.teardown();
self.test_live_kit_server.teardown().unwrap();
}
}
impl Deref for TestClient {
type Target = Arc<Client>;
fn deref(&self) -> &Self::Target {
&self.app_state.client
}
}
impl TestClient {
pub fn fs(&self) -> &FakeFs {
self.app_state.fs.as_fake()
}
pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
&self.app_state.channel_store
}
pub fn user_store(&self) -> &ModelHandle<UserStore> {
&self.app_state.user_store
}
pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
&self.app_state.languages
}
pub fn client(&self) -> &Arc<Client> {
&self.app_state.client
}
pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
UserId::from_proto(
self.app_state
.user_store
.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
)
}
pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
let mut authed_user = self
.app_state
.user_store
.read_with(cx, |user_store, _| user_store.watch_current_user());
while authed_user.next().await.unwrap().is_none() {}
}
pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
self.app_state
.user_store
.update(cx, |store, _| store.clear_contacts())
.await;
}
pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
Ref::map(self.state.borrow(), |state| &state.local_projects)
}
pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
Ref::map(self.state.borrow(), |state| &state.remote_projects)
}
pub fn local_projects_mut<'a>(
&'a self,
) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
}
pub fn remote_projects_mut<'a>(
&'a self,
) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
}
pub fn buffers_for_project<'a>(
&'a self,
project: &ModelHandle<Project>,
) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| {
state.buffers.entry(project.clone()).or_default()
})
}
pub fn buffers<'a>(
&'a self,
) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
{
RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
}
pub fn channel_buffers<'a>(
&'a self,
) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
}
pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
self.app_state
.user_store
.read_with(cx, |store, _| ContactsSummary {
current: store
.contacts()
.iter()
.map(|contact| contact.user.github_login.clone())
.collect(),
outgoing_requests: store
.outgoing_contact_requests()
.iter()
.map(|user| user.github_login.clone())
.collect(),
incoming_requests: store
.incoming_contact_requests()
.iter()
.map(|user| user.github_login.clone())
.collect(),
})
}
pub async fn build_local_project(
&self,
root_path: impl AsRef<Path>,
cx: &mut TestAppContext,
) -> (ModelHandle<Project>, WorktreeId) {
let project = cx.update(|cx| {
Project::local(
self.client().clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
});
let (worktree, _) = project
.update(cx, |p, cx| {
p.find_or_create_local_worktree(root_path, true, cx)
})
.await
.unwrap();
worktree
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
.await;
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub async fn build_remote_project(
&self,
host_project_id: u64,
guest_cx: &mut TestAppContext,
) -> ModelHandle<Project> {
let active_call = guest_cx.read(ActiveCall::global);
let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
room.update(guest_cx, |room, cx| {
room.join_project(
host_project_id,
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
})
.await
.unwrap()
}
pub fn build_workspace(
&self,
project: &ModelHandle<Project>,
cx: &mut TestAppContext,
) -> WindowHandle<Workspace> {
cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
}
}
impl Drop for TestClient {
fn drop(&mut self) {
self.app_state.client.teardown();
}
}

View File

@@ -15,7 +15,7 @@ use gpui::{
ViewContext, ViewHandle,
};
use project::Project;
use std::any::Any;
use std::any::{Any, TypeId};
use workspace::{
item::{FollowableItem, Item, ItemHandle},
register_followable_item,
@@ -189,6 +189,21 @@ impl View for ChannelView {
}
impl Item for ChannelView {
fn act_as_type<'a>(
&'a self,
type_id: TypeId,
self_handle: &'a ViewHandle<Self>,
_: &'a AppContext,
) -> Option<&'a AnyViewHandle> {
if type_id == TypeId::of::<Self>() {
Some(self_handle)
} else if type_id == TypeId::of::<Editor>() {
Some(&self.editor)
} else {
None
}
}
fn tab_content<V: 'static>(
&self,
_: Option<usize>,

View File

@@ -4,7 +4,7 @@ mod panel_settings;
use anyhow::Result;
use call::ActiveCall;
use channel::{Channel, ChannelEvent, ChannelId, ChannelStore};
use channel::{Channel, ChannelEvent, ChannelId, ChannelPath, ChannelStore};
use client::{proto::PeerId, Client, Contact, User, UserStore};
use context_menu::{ContextMenu, ContextMenuItem};
use db::kvp::KEY_VALUE_STORE;
@@ -35,7 +35,7 @@ use panel_settings::{CollaborationPanelDockPosition, CollaborationPanelSettings}
use project::{Fs, Project};
use serde_derive::{Deserialize, Serialize};
use settings::SettingsStore;
use std::{borrow::Cow, mem, sync::Arc};
use std::{borrow::Cow, hash::Hash, mem, sync::Arc};
use theme::{components::ComponentExt, IconButton};
use util::{iife, ResultExt, TryFutureExt};
use workspace::{
@@ -54,37 +54,59 @@ use self::contact_finder::ContactFinder;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct RemoveChannel {
channel_id: u64,
channel_id: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct ToggleCollapse {
channel_id: u64,
location: ChannelLocation<'static>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct NewChannel {
channel_id: u64,
location: ChannelLocation<'static>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct InviteMembers {
channel_id: u64,
channel_id: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct ManageMembers {
channel_id: u64,
channel_id: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct RenameChannel {
channel_id: u64,
location: ChannelLocation<'static>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct OpenChannelBuffer {
channel_id: u64,
channel_id: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct LinkChannel {
channel_id: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct MoveChannel {
channel_id: ChannelId,
parent_id: Option<ChannelId>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct PutChannel {
to: ChannelId,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct UnlinkChannel {
channel_id: ChannelId,
parent_id: Option<ChannelId>,
}
actions!(
@@ -107,12 +129,40 @@ impl_actions!(
ManageMembers,
RenameChannel,
ToggleCollapse,
OpenChannelBuffer
OpenChannelBuffer,
LinkChannel,
MoveChannel,
PutChannel,
UnlinkChannel
]
);
const COLLABORATION_PANEL_KEY: &'static str = "CollaborationPanel";
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct ChannelLocation<'a> {
channel: ChannelId,
path: Cow<'a, ChannelPath>,
}
impl From<(ChannelId, ChannelPath)> for ChannelLocation<'static> {
fn from(value: (ChannelId, ChannelPath)) -> Self {
ChannelLocation {
channel: value.0,
path: Cow::Owned(value.1),
}
}
}
impl<'a> From<(ChannelId, &'a ChannelPath)> for ChannelLocation<'a> {
fn from(value: (ChannelId, &'a ChannelPath)) -> Self {
ChannelLocation {
channel: value.0,
path: Cow::Borrowed(value.1),
}
}
}
pub fn init(_client: Arc<Client>, cx: &mut AppContext) {
settings::register::<panel_settings::CollaborationPanelSettings>(cx);
contact_finder::init(cx);
@@ -135,16 +185,65 @@ pub fn init(_client: Arc<Client>, cx: &mut AppContext) {
cx.add_action(CollabPanel::collapse_selected_channel);
cx.add_action(CollabPanel::expand_selected_channel);
cx.add_action(CollabPanel::open_channel_buffer);
cx.add_action(
|panel: &mut CollabPanel, action: &LinkChannel, _: &mut ViewContext<CollabPanel>| {
panel.link_or_move = Some(ChannelCopy::Link(action.channel_id));
},
);
cx.add_action(
|panel: &mut CollabPanel, action: &MoveChannel, _: &mut ViewContext<CollabPanel>| {
panel.link_or_move = Some(ChannelCopy::Move {
channel_id: action.channel_id,
parent_id: action.parent_id,
});
},
);
cx.add_action(
|panel: &mut CollabPanel, action: &PutChannel, cx: &mut ViewContext<CollabPanel>| {
if let Some(copy) = panel.link_or_move.take() {
match copy {
ChannelCopy::Move {
channel_id,
parent_id,
} => panel.channel_store.update(cx, |channel_store, cx| {
channel_store
.move_channel(channel_id, parent_id, action.to, cx)
.detach_and_log_err(cx)
}),
ChannelCopy::Link(channel) => {
panel.channel_store.update(cx, |channel_store, cx| {
channel_store
.link_channel(channel, action.to, cx)
.detach_and_log_err(cx)
})
}
}
}
},
);
cx.add_action(
|panel: &mut CollabPanel, action: &UnlinkChannel, cx: &mut ViewContext<CollabPanel>| {
panel.channel_store.update(cx, |channel_store, cx| {
channel_store
.unlink_channel(action.channel_id, action.parent_id, cx)
.detach_and_log_err(cx)
})
},
);
}
#[derive(Debug)]
pub enum ChannelEditingState {
Create {
parent_id: Option<u64>,
location: Option<ChannelLocation<'static>>,
pending_name: Option<String>,
},
Rename {
channel_id: u64,
location: ChannelLocation<'static>,
pending_name: Option<String>,
},
}
@@ -158,10 +257,36 @@ impl ChannelEditingState {
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChannelCopy {
Move {
channel_id: u64,
parent_id: Option<u64>,
},
Link(u64),
}
impl ChannelCopy {
fn channel_id(&self) -> u64 {
match self {
ChannelCopy::Move { channel_id, .. } => *channel_id,
ChannelCopy::Link(channel_id) => *channel_id,
}
}
fn is_move(&self) -> bool {
match self {
ChannelCopy::Move { .. } => true,
ChannelCopy::Link(_) => false,
}
}
}
pub struct CollabPanel {
width: Option<f32>,
fs: Arc<dyn Fs>,
has_focus: bool,
link_or_move: Option<ChannelCopy>,
pending_serialization: Task<Option<()>>,
context_menu: ViewHandle<ContextMenu>,
filter_editor: ViewHandle<Editor>,
@@ -177,7 +302,7 @@ pub struct CollabPanel {
list_state: ListState<Self>,
subscriptions: Vec<Subscription>,
collapsed_sections: Vec<Section>,
collapsed_channels: Vec<ChannelId>,
collapsed_channels: Vec<ChannelLocation<'static>>,
workspace: WeakViewHandle<Workspace>,
context_menu_on_selected: bool,
}
@@ -185,7 +310,7 @@ pub struct CollabPanel {
#[derive(Serialize, Deserialize)]
struct SerializedCollabPanel {
width: Option<f32>,
collapsed_channels: Option<Vec<ChannelId>>,
collapsed_channels: Option<Vec<ChannelLocation<'static>>>,
}
#[derive(Debug)]
@@ -229,6 +354,7 @@ enum ListEntry {
Channel {
channel: Arc<Channel>,
depth: usize,
path: ChannelPath,
},
ChannelNotes {
channel_id: ChannelId,
@@ -353,10 +479,15 @@ impl CollabPanel {
cx,
)
}
ListEntry::Channel { channel, depth } => {
ListEntry::Channel {
channel,
depth,
path,
} => {
let channel_row = this.render_channel(
&*channel,
*depth,
path.to_owned(),
&theme.collab_panel,
is_selected,
cx,
@@ -425,6 +556,7 @@ impl CollabPanel {
let mut this = Self {
width: None,
has_focus: false,
link_or_move: None,
fs: workspace.app_state().fs.clone(),
pending_serialization: Task::ready(None),
context_menu: cx.add_view(|cx| ContextMenu::new(view_id, cx)),
@@ -512,7 +644,13 @@ impl CollabPanel {
.log_err()
.flatten()
{
Some(serde_json::from_str::<SerializedCollabPanel>(&panel)?)
match serde_json::from_str::<SerializedCollabPanel>(&panel) {
Ok(panel) => Some(panel),
Err(err) => {
log::error!("Failed to deserialize collaboration panel: {}", err);
None
}
}
} else {
None
};
@@ -702,28 +840,24 @@ impl CollabPanel {
executor.clone(),
));
if let Some(state) = &self.channel_editing_state {
if matches!(
state,
ChannelEditingState::Create {
parent_id: None,
..
}
) {
if matches!(state, ChannelEditingState::Create { location: None, .. }) {
self.entries.push(ListEntry::ChannelEditor { depth: 0 });
}
}
let mut collapse_depth = None;
for mat in matches {
let (depth, channel) =
channel_store.channel_at_index(mat.candidate_id).unwrap();
let (channel, path) = channel_store.channel_at_index(mat.candidate_id).unwrap();
let depth = path.len() - 1;
if collapse_depth.is_none() && self.is_channel_collapsed(channel.id) {
let location: ChannelLocation<'_> = (channel.id, path).into();
if collapse_depth.is_none() && self.is_channel_collapsed(&location) {
collapse_depth = Some(depth);
} else if let Some(collapsed_depth) = collapse_depth {
if depth > collapsed_depth {
continue;
}
if self.is_channel_collapsed(channel.id) {
if self.is_channel_collapsed(&location) {
collapse_depth = Some(depth);
} else {
collapse_depth = None;
@@ -731,18 +865,21 @@ impl CollabPanel {
}
match &self.channel_editing_state {
Some(ChannelEditingState::Create { parent_id, .. })
if *parent_id == Some(channel.id) =>
{
Some(ChannelEditingState::Create {
location: parent_id,
..
}) if *parent_id == Some(location) => {
self.entries.push(ListEntry::Channel {
channel: channel.clone(),
depth,
path: path.clone(),
});
self.entries
.push(ListEntry::ChannelEditor { depth: depth + 1 });
}
Some(ChannelEditingState::Rename { channel_id, .. })
if *channel_id == channel.id =>
Some(ChannelEditingState::Rename { location, .. })
if location.channel == channel.id
&& location.path == Cow::Borrowed(path) =>
{
self.entries.push(ListEntry::ChannelEditor { depth });
}
@@ -750,6 +887,7 @@ impl CollabPanel {
self.entries.push(ListEntry::Channel {
channel: channel.clone(),
depth,
path: path.clone(),
});
}
}
@@ -1546,14 +1684,21 @@ impl CollabPanel {
&self,
channel: &Channel,
depth: usize,
path: ChannelPath,
theme: &theme::CollabPanel,
is_selected: bool,
cx: &mut ViewContext<Self>,
) -> AnyElement<Self> {
let channel_id = channel.id;
let has_children = self.channel_store.read(cx).has_children(channel_id);
let disclosed =
has_children.then(|| !self.collapsed_channels.binary_search(&channel_id).is_ok());
let disclosed = {
let location = ChannelLocation {
channel: channel_id,
path: Cow::Borrowed(&path),
};
has_children.then(|| !self.collapsed_channels.binary_search(&location).is_ok())
};
let is_active = iife!({
let call_channel = ActiveCall::global(cx)
@@ -1567,7 +1712,7 @@ impl CollabPanel {
const FACEPILE_LIMIT: usize = 3;
MouseEventHandler::new::<Channel, _>(channel.id as usize, cx, |state, cx| {
MouseEventHandler::new::<Channel, _>(id(&path) as usize, cx, |state, cx| {
Flex::<Self>::row()
.with_child(
Svg::new("icons/hash.svg")
@@ -1618,8 +1763,13 @@ impl CollabPanel {
})
.align_children_center()
.styleable_component()
.disclosable(disclosed, Box::new(ToggleCollapse { channel_id }))
.with_id(channel_id as usize)
.disclosable(
disclosed,
Box::new(ToggleCollapse {
location: (channel_id, path.clone()).into(),
}),
)
.with_id(id(&path) as usize)
.with_style(theme.disclosure.clone())
.element()
.constrained()
@@ -1635,7 +1785,11 @@ impl CollabPanel {
this.join_channel(channel_id, cx);
})
.on_click(MouseButton::Right, move |e, this, cx| {
this.deploy_channel_context_menu(Some(e.position), channel_id, cx);
this.deploy_channel_context_menu(
Some(e.position),
&(channel_id, path.clone()).into(),
cx,
);
})
.with_cursor_style(CursorStyle::PointingHand)
.into_any()
@@ -1882,11 +2036,20 @@ impl CollabPanel {
fn deploy_channel_context_menu(
&mut self,
position: Option<Vector2F>,
channel_id: u64,
location: &ChannelLocation<'static>,
cx: &mut ViewContext<Self>,
) {
self.context_menu_on_selected = position.is_none();
let operation_details = self.link_or_move.as_ref().and_then(|link_or_move| {
let channel_name = self
.channel_store
.read(cx)
.channel_for_id(link_or_move.channel_id())
.map(|channel| channel.name.clone())?;
Some((channel_name, link_or_move.is_move()))
});
self.context_menu.update(cx, |context_menu, cx| {
context_menu.set_position_mode(if self.context_menu_on_selected {
OverlayPositionMode::Local
@@ -1894,27 +2057,112 @@ impl CollabPanel {
OverlayPositionMode::Window
});
let expand_action_name = if self.is_channel_collapsed(channel_id) {
let mut items = Vec::new();
if let Some((channel_name, is_move)) = operation_details {
items.push(ContextMenuItem::action(
format!(
"{} '#{}' here",
if is_move { "Move" } else { "Link" },
channel_name
),
PutChannel {
to: location.channel,
},
));
items.push(ContextMenuItem::Separator)
}
let expand_action_name = if self.is_channel_collapsed(&location) {
"Expand Subchannels"
} else {
"Collapse Subchannels"
};
let mut items = vec![
ContextMenuItem::action(expand_action_name, ToggleCollapse { channel_id }),
ContextMenuItem::action("Open Notes", OpenChannelBuffer { channel_id }),
];
items.extend([
ContextMenuItem::action(
expand_action_name,
ToggleCollapse {
location: location.clone(),
},
),
ContextMenuItem::action(
"Open Notes",
OpenChannelBuffer {
channel_id: location.channel,
},
),
]);
if self.channel_store.read(cx).is_user_admin(location.channel) {
let parent_id = location.path.parent_id();
if self.channel_store.read(cx).is_user_admin(channel_id) {
items.extend([
ContextMenuItem::Separator,
ContextMenuItem::action("New Subchannel", NewChannel { channel_id }),
ContextMenuItem::action("Rename", RenameChannel { channel_id }),
ContextMenuItem::action(
"New Subchannel",
NewChannel {
location: location.clone(),
},
),
ContextMenuItem::action(
"Rename",
RenameChannel {
location: location.clone(),
},
),
ContextMenuItem::Separator,
ContextMenuItem::action("Invite Members", InviteMembers { channel_id }),
ContextMenuItem::action("Manage Members", ManageMembers { channel_id }),
]);
items.push(ContextMenuItem::action(
if parent_id.is_some() {
"Unlink from parent"
} else {
"Unlink from root"
},
UnlinkChannel {
channel_id: location.channel,
parent_id,
},
));
items.extend([
ContextMenuItem::action(
"Link this channel",
LinkChannel {
channel_id: location.channel,
},
),
ContextMenuItem::action(
"Move this channel",
MoveChannel {
channel_id: location.channel,
parent_id,
},
),
]);
items.extend([
ContextMenuItem::Separator,
ContextMenuItem::action("Delete", RemoveChannel { channel_id }),
ContextMenuItem::action(
"Invite Members",
InviteMembers {
channel_id: location.channel,
},
),
ContextMenuItem::action(
"Manage Members",
ManageMembers {
channel_id: location.channel,
},
),
ContextMenuItem::Separator,
ContextMenuItem::action(
"Delete",
RemoveChannel {
channel_id: location.channel,
},
),
]);
}
@@ -2040,7 +2288,7 @@ impl CollabPanel {
if let Some(editing_state) = &mut self.channel_editing_state {
match editing_state {
ChannelEditingState::Create {
parent_id,
location,
pending_name,
..
} => {
@@ -2053,13 +2301,17 @@ impl CollabPanel {
self.channel_store
.update(cx, |channel_store, cx| {
channel_store.create_channel(&channel_name, *parent_id, cx)
channel_store.create_channel(
&channel_name,
location.as_ref().map(|location| location.channel),
cx,
)
})
.detach();
cx.notify();
}
ChannelEditingState::Rename {
channel_id,
location,
pending_name,
} => {
if pending_name.is_some() {
@@ -2070,7 +2322,7 @@ impl CollabPanel {
self.channel_store
.update(cx, |channel_store, cx| {
channel_store.rename(*channel_id, &channel_name, cx)
channel_store.rename(location.channel, &channel_name, cx)
})
.detach();
cx.notify();
@@ -2097,38 +2349,58 @@ impl CollabPanel {
_: &CollapseSelectedChannel,
cx: &mut ViewContext<Self>,
) {
let Some(channel_id) = self.selected_channel().map(|channel| channel.id) else {
let Some((channel_id, path)) = self
.selected_channel()
.map(|(channel, parent)| (channel.id, parent))
else {
return;
};
if self.is_channel_collapsed(channel_id) {
let path = path.to_owned();
if self.is_channel_collapsed(&(channel_id, path.clone()).into()) {
return;
}
self.toggle_channel_collapsed(&ToggleCollapse { channel_id }, cx)
self.toggle_channel_collapsed(
&ToggleCollapse {
location: (channel_id, path).into(),
},
cx,
)
}
fn expand_selected_channel(&mut self, _: &ExpandSelectedChannel, cx: &mut ViewContext<Self>) {
let Some(channel_id) = self.selected_channel().map(|channel| channel.id) else {
let Some((channel_id, path)) = self
.selected_channel()
.map(|(channel, parent)| (channel.id, parent))
else {
return;
};
if !self.is_channel_collapsed(channel_id) {
let path = path.to_owned();
if !self.is_channel_collapsed(&(channel_id, path.clone()).into()) {
return;
}
self.toggle_channel_collapsed(&ToggleCollapse { channel_id }, cx)
self.toggle_channel_collapsed(
&ToggleCollapse {
location: (channel_id, path).into(),
},
cx,
)
}
fn toggle_channel_collapsed(&mut self, action: &ToggleCollapse, cx: &mut ViewContext<Self>) {
let channel_id = action.channel_id;
let location = action.location.clone();
match self.collapsed_channels.binary_search(&channel_id) {
match self.collapsed_channels.binary_search(&location) {
Ok(ix) => {
self.collapsed_channels.remove(ix);
}
Err(ix) => {
self.collapsed_channels.insert(ix, channel_id);
self.collapsed_channels.insert(ix, location);
}
};
self.serialize(cx);
@@ -2137,8 +2409,8 @@ impl CollabPanel {
cx.focus_self();
}
fn is_channel_collapsed(&self, channel: ChannelId) -> bool {
self.collapsed_channels.binary_search(&channel).is_ok()
fn is_channel_collapsed(&self, location: &ChannelLocation) -> bool {
self.collapsed_channels.binary_search(location).is_ok()
}
fn leave_call(cx: &mut ViewContext<Self>) {
@@ -2163,7 +2435,7 @@ impl CollabPanel {
fn new_root_channel(&mut self, cx: &mut ViewContext<Self>) {
self.channel_editing_state = Some(ChannelEditingState::Create {
parent_id: None,
location: None,
pending_name: None,
});
self.update_entries(false, cx);
@@ -2181,9 +2453,9 @@ impl CollabPanel {
fn new_subchannel(&mut self, action: &NewChannel, cx: &mut ViewContext<Self>) {
self.collapsed_channels
.retain(|&channel| channel != action.channel_id);
.retain(|channel| *channel != action.location);
self.channel_editing_state = Some(ChannelEditingState::Create {
parent_id: Some(action.channel_id),
location: Some(action.location.to_owned()),
pending_name: None,
});
self.update_entries(false, cx);
@@ -2201,16 +2473,16 @@ impl CollabPanel {
}
fn remove(&mut self, _: &Remove, cx: &mut ViewContext<Self>) {
if let Some(channel) = self.selected_channel() {
if let Some((channel, _)) = self.selected_channel() {
self.remove_channel(channel.id, cx)
}
}
fn rename_selected_channel(&mut self, _: &menu::SecondaryConfirm, cx: &mut ViewContext<Self>) {
if let Some(channel) = self.selected_channel() {
if let Some((channel, parent)) = self.selected_channel() {
self.rename_channel(
&RenameChannel {
channel_id: channel.id,
location: (channel.id, parent.to_owned()).into(),
},
cx,
);
@@ -2219,12 +2491,15 @@ impl CollabPanel {
fn rename_channel(&mut self, action: &RenameChannel, cx: &mut ViewContext<Self>) {
let channel_store = self.channel_store.read(cx);
if !channel_store.is_user_admin(action.channel_id) {
if !channel_store.is_user_admin(action.location.channel) {
return;
}
if let Some(channel) = channel_store.channel_for_id(action.channel_id).cloned() {
if let Some(channel) = channel_store
.channel_for_id(action.location.channel)
.cloned()
{
self.channel_editing_state = Some(ChannelEditingState::Rename {
channel_id: action.channel_id,
location: action.location.to_owned(),
pending_name: None,
});
self.channel_name_editor.update(cx, |editor, cx| {
@@ -2240,7 +2515,8 @@ impl CollabPanel {
fn open_channel_buffer(&mut self, action: &OpenChannelBuffer, cx: &mut ViewContext<Self>) {
if let Some(workspace) = self.workspace.upgrade(cx) {
let pane = workspace.read(cx).active_pane().clone();
let channel_view = ChannelView::open(action.channel_id, pane.clone(), workspace, cx);
let channel_id = action.channel_id;
let channel_view = ChannelView::open(channel_id, pane.clone(), workspace, cx);
cx.spawn(|_, mut cx| async move {
let channel_view = channel_view.await?;
pane.update(&mut cx, |pane, cx| {
@@ -2249,22 +2525,38 @@ impl CollabPanel {
anyhow::Ok(())
})
.detach();
let room_id = ActiveCall::global(cx)
.read(cx)
.room()
.map(|room| room.read(cx).id());
ActiveCall::report_call_event_for_room(
"open channel notes",
room_id,
Some(channel_id),
&self.client,
cx,
);
}
}
fn show_inline_context_menu(&mut self, _: &menu::ShowContextMenu, cx: &mut ViewContext<Self>) {
let Some(channel) = self.selected_channel() else {
let Some((channel, path)) = self.selected_channel() else {
return;
};
self.deploy_channel_context_menu(None, channel.id, cx);
self.deploy_channel_context_menu(None, &(channel.id, path.to_owned()).into(), cx);
}
fn selected_channel(&self) -> Option<&Arc<Channel>> {
fn selected_channel(&self) -> Option<(&Arc<Channel>, &ChannelPath)> {
self.selection
.and_then(|ix| self.entries.get(ix))
.and_then(|entry| match entry {
ListEntry::Channel { channel, .. } => Some(channel),
ListEntry::Channel {
channel,
path: parent,
..
} => Some((channel, parent)),
_ => None,
})
}
@@ -2644,13 +2936,17 @@ impl PartialEq for ListEntry {
ListEntry::Channel {
channel: channel_1,
depth: depth_1,
path: parent_1,
} => {
if let ListEntry::Channel {
channel: channel_2,
depth: depth_2,
path: parent_2,
} = other
{
return channel_1.id == channel_2.id && depth_1 == depth_2;
return channel_1.id == channel_2.id
&& depth_1 == depth_2
&& parent_1 == parent_2;
}
}
ListEntry::ChannelNotes { channel_id } => {
@@ -2713,3 +3009,26 @@ fn render_icon_button(style: &IconButton, svg_path: &'static str) -> impl Elemen
.contained()
.with_style(style.container)
}
/// Hash a channel path to a u64, for use as a mouse id
/// Based on the Fowler-Noll-Vo hash:
/// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
fn id(path: &[ChannelId]) -> u64 {
// I probably should have done this, but I didn't
// let hasher = DefaultHasher::new();
// let path = path.hash(&mut hasher);
// let x = hasher.finish();
const OFFSET: u64 = 14695981039346656037;
const PRIME: u64 = 1099511628211;
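// Standard 64-bit FNV offset basis and prime.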
let mut hash = OFFSET;
for id in path.iter() {
for id in id.to_ne_bytes() {
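// FNV-1a step: XOR in each byte, then multiply by the prime
// (widened to u128 to avoid overflow, truncated back to u64).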
hash = hash ^ (id as u64);
hash = (hash as u128 * PRIME as u128) as u64;
}
}
hash
}

View File

@@ -771,7 +771,7 @@ impl CollabTitlebarItem {
})
.with_tooltip::<ToggleUserMenu>(
0,
"Toggle user menu".to_owned(),
"Toggle User Menu".to_owned(),
Some(Box::new(ToggleUserMenu)),
tooltip,
cx,

View File

@@ -49,7 +49,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) {
if room.is_screen_sharing() {
ActiveCall::report_call_event_for_room(
"disable screen share",
room.id(),
Some(room.id()),
room.channel_id(),
&client,
cx,
@@ -58,7 +58,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) {
} else {
ActiveCall::report_call_event_for_room(
"enable screen share",
room.id(),
Some(room.id()),
room.channel_id(),
&client,
cx,
@@ -78,7 +78,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) {
if room.is_muted(cx) {
ActiveCall::report_call_event_for_room(
"enable microphone",
room.id(),
Some(room.id()),
room.channel_id(),
&client,
cx,
@@ -86,7 +86,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) {
} else {
ActiveCall::report_call_event_for_room(
"disable microphone",
room.id(),
Some(room.id()),
room.channel_id(),
&client,
cx,

View File

@@ -41,7 +41,7 @@ actions!(
[Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
);
pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<dyn NodeRuntime>, cx: &mut AppContext) {
let copilot = cx.add_model({
let node_runtime = node_runtime.clone();
move |cx| Copilot::start(http, node_runtime, cx)
@@ -265,7 +265,7 @@ pub struct Completion {
pub struct Copilot {
http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>,
node_runtime: Arc<dyn NodeRuntime>,
server: CopilotServer,
buffers: HashSet<WeakModelHandle<Buffer>>,
}
@@ -299,7 +299,7 @@ impl Copilot {
fn start(
http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>,
node_runtime: Arc<dyn NodeRuntime>,
cx: &mut ModelContext<Self>,
) -> Self {
let mut this = Self {
@@ -335,12 +335,15 @@ impl Copilot {
#[cfg(any(test, feature = "test-support"))]
pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
use node_runtime::FakeNodeRuntime;
let (server, fake_server) =
LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
let node_runtime = FakeNodeRuntime::new();
let this = cx.add_model(|_| Self {
http: http.clone(),
node_runtime: NodeRuntime::instance(http),
node_runtime,
server: CopilotServer::Running(RunningCopilotServer {
lsp: Arc::new(server),
sign_in_status: SignInStatus::Authorized,
@@ -353,7 +356,7 @@ impl Copilot {
fn start_language_server(
http: Arc<dyn HttpClient>,
node_runtime: Arc<NodeRuntime>,
node_runtime: Arc<dyn NodeRuntime>,
this: ModelHandle<Self>,
mut cx: AsyncAppContext,
) -> impl Future<Output = ()> {

View File

@@ -555,67 +555,6 @@ impl DisplaySnapshot {
})
}
/// Returns an iterator of the start positions of the occurrences of `target` in `self` after `from`
/// Stops if `condition` returns false for any of the character position pairs observed.
pub fn find_while<'a>(
&'a self,
from: DisplayPoint,
target: &str,
condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
) -> impl Iterator<Item = DisplayPoint> + 'a {
Self::find_internal(self.chars_at(from), target.chars().collect(), condition)
}
/// Returns an iterator of the end positions of the occurrences of `target` in `self` before `from`
/// Stops if `condition` returns false for any of the character position pairs observed.
pub fn reverse_find_while<'a>(
&'a self,
from: DisplayPoint,
target: &str,
condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
) -> impl Iterator<Item = DisplayPoint> + 'a {
Self::find_internal(
self.reverse_chars_at(from),
target.chars().rev().collect(),
condition,
)
}
fn find_internal<'a>(
iterator: impl Iterator<Item = (char, DisplayPoint)> + 'a,
target: Vec<char>,
mut condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
) -> impl Iterator<Item = DisplayPoint> + 'a {
// List of partial matches with the index of the last seen character in target and the starting point of the match
let mut partial_matches: Vec<(usize, DisplayPoint)> = Vec::new();
iterator
.take_while(move |(ch, point)| condition(*ch, *point))
.filter_map(move |(ch, point)| {
if Some(&ch) == target.get(0) {
partial_matches.push((0, point));
}
let mut found = None;
// Keep partial matches that have the correct next character
partial_matches.retain_mut(|(match_position, match_start)| {
if target.get(*match_position) == Some(&ch) {
*match_position += 1;
if *match_position == target.len() {
found = Some(match_start.clone());
// This match is completed. No need to keep tracking it
false
} else {
true
}
} else {
false
}
});
found
})
}
pub fn column_to_chars(&self, display_row: u32, target: u32) -> u32 {
let mut count = 0;
let mut column = 0;
@@ -933,7 +872,7 @@ pub mod tests {
use smol::stream::StreamExt;
use std::{env, sync::Arc};
use theme::SyntaxTheme;
use util::test::{marked_text_offsets, marked_text_ranges, sample_text};
use util::test::{marked_text_ranges, sample_text};
use Bias::*;
#[gpui::test(iterations = 100)]
@@ -1744,32 +1683,6 @@ pub mod tests {
)
}
#[test]
fn test_find_internal() {
assert("This is a ˇtest of find internal", "test");
assert("Some text ˇaˇaˇaa with repeated characters", "aa");
fn assert(marked_text: &str, target: &str) {
let (text, expected_offsets) = marked_text_offsets(marked_text);
let chars = text
.chars()
.enumerate()
.map(|(index, ch)| (ch, DisplayPoint::new(0, index as u32)));
let target = target.chars();
assert_eq!(
expected_offsets
.into_iter()
.map(|offset| offset as u32)
.collect::<Vec<_>>(),
DisplaySnapshot::find_internal(chars, target.collect(), |_, _| true)
.map(|point| point.column())
.collect::<Vec<_>>()
)
}
}
fn syntax_chunks<'a>(
rows: Range<u32>,
map: &ModelHandle<DisplayMap>,

View File

@@ -44,7 +44,7 @@ use gpui::{
elements::*,
executor,
fonts::{self, HighlightStyle, TextStyle},
geometry::vector::Vector2F,
geometry::vector::{vec2f, Vector2F},
impl_actions,
keymap_matcher::KeymapContext,
platform::{CursorStyle, MouseButton},
@@ -312,6 +312,10 @@ actions!(
CopyPath,
CopyRelativePath,
CopyHighlightJson,
ContextMenuFirst,
ContextMenuPrev,
ContextMenuNext,
ContextMenuLast,
]
);
@@ -468,6 +472,10 @@ pub fn init(cx: &mut AppContext) {
cx.add_action(Editor::next_copilot_suggestion);
cx.add_action(Editor::previous_copilot_suggestion);
cx.add_action(Editor::copilot_suggest);
cx.add_action(Editor::context_menu_first);
cx.add_action(Editor::context_menu_prev);
cx.add_action(Editor::context_menu_next);
cx.add_action(Editor::context_menu_last);
hover_popover::init(cx);
scroll::actions::init(cx);
@@ -564,7 +572,7 @@ pub struct Editor {
project: Option<ModelHandle<Project>>,
focused: bool,
blink_manager: ModelHandle<BlinkManager>,
show_local_selections: bool,
pub show_local_selections: bool,
mode: EditorMode,
replica_id_mapping: Option<HashMap<ReplicaId, ReplicaId>>,
show_gutter: bool,
@@ -820,6 +828,7 @@ struct CompletionsMenu {
id: CompletionId,
initial_position: Anchor,
buffer: ModelHandle<Buffer>,
project: Option<ModelHandle<Project>>,
completions: Arc<[Completion]>,
match_candidates: Vec<StringMatchCandidate>,
matches: Arc<[StringMatch]>,
@@ -863,6 +872,48 @@ impl CompletionsMenu {
fn render(&self, style: EditorStyle, cx: &mut ViewContext<Editor>) -> AnyElement<Editor> {
enum CompletionTag {}
let language_servers = self.project.as_ref().map(|project| {
project
.read(cx)
.language_servers_for_buffer(self.buffer.read(cx), cx)
.filter(|(_, server)| server.capabilities().completion_provider.is_some())
.map(|(adapter, server)| (server.server_id(), adapter.short_name))
.collect::<Vec<_>>()
});
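// Only label completions with a language server name when more than one
// server offers completions for this buffer.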
let needs_server_name = language_servers
.as_ref()
.map_or(false, |servers| servers.len() > 1);
let get_server_name =
move |lookup_server_id: lsp::LanguageServerId| -> Option<&'static str> {
language_servers
.iter()
.flatten()
.find_map(|(server_id, server_name)| {
if *server_id == lookup_server_id {
Some(*server_name)
} else {
None
}
})
};
let widest_completion_ix = self
.matches
.iter()
.enumerate()
.max_by_key(|(_, mat)| {
let completion = &self.completions[mat.candidate_id];
let mut len = completion.label.text.chars().count();
if let Some(server_name) = get_server_name(completion.server_id) {
len += server_name.chars().count();
}
len
})
.map(|(ix, _)| ix);
let completions = self.completions.clone();
let matches = self.matches.clone();
let selected_item = self.selected_item;
@@ -889,19 +940,83 @@ impl CompletionsMenu {
style.autocomplete.item
};
Text::new(completion.label.text.clone(), style.text.clone())
.with_soft_wrap(false)
.with_highlights(combine_syntax_and_fuzzy_match_highlights(
&completion.label.text,
style.text.color.into(),
styled_runs_for_code_label(
&completion.label,
&style.syntax,
),
&mat.positions,
))
.contained()
.with_style(item_style)
let completion_label =
Text::new(completion.label.text.clone(), style.text.clone())
.with_soft_wrap(false)
.with_highlights(
combine_syntax_and_fuzzy_match_highlights(
&completion.label.text,
style.text.color.into(),
styled_runs_for_code_label(
&completion.label,
&style.syntax,
),
&mat.positions,
),
);
if let Some(server_name) = get_server_name(completion.server_id) {
Flex::row()
.with_child(completion_label)
.with_children((|| {
if !needs_server_name {
return None;
}
let text_style = TextStyle {
color: style.autocomplete.server_name_color,
font_size: style.text.font_size
* style.autocomplete.server_name_size_percent,
..style.text.clone()
};
let label = Text::new(server_name, text_style)
.aligned()
.constrained()
.dynamically(move |constraint, _, _| {
gpui::SizeConstraint {
min: constraint.min,
max: vec2f(
constraint.max.x(),
constraint.min.y(),
),
}
});
if Some(item_ix) == widest_completion_ix {
Some(
label
.contained()
.with_style(
style
.autocomplete
.server_name_container,
)
.into_any(),
)
} else {
Some(label.flex_float().into_any())
}
})())
.into_any()
} else {
completion_label.into_any()
}
.contained()
.with_style(item_style)
.constrained()
.dynamically(
move |constraint, _, _| {
if Some(item_ix) == widest_completion_ix {
constraint
} else {
gpui::SizeConstraint {
min: constraint.min,
max: constraint.min,
}
}
},
)
},
)
.with_cursor_style(CursorStyle::PointingHand)
@@ -918,19 +1033,7 @@ impl CompletionsMenu {
}
},
)
.with_width_from_item(
self.matches
.iter()
.enumerate()
.max_by_key(|(_, mat)| {
self.completions[mat.candidate_id]
.label
.text
.chars()
.count()
})
.map(|(ix, _)| ix),
)
.with_width_from_item(widest_completion_ix)
.contained()
.with_style(container_style)
.into_any()
@@ -1559,7 +1662,7 @@ impl Editor {
.excerpt_containing(self.selections.newest_anchor().head(), cx)
}
fn style(&self, cx: &AppContext) -> EditorStyle {
pub fn style(&self, cx: &AppContext) -> EditorStyle {
build_style(
settings::get::<ThemeSettings>(cx),
self.get_field_editor_theme.as_deref(),
@@ -2166,10 +2269,6 @@ impl Editor {
if self.read_only {
return;
}
if !self.input_enabled {
cx.emit(Event::InputIgnored { text });
return;
}
let selections = self.selections.all_adjusted(cx);
let mut brace_inserted = false;
@@ -2983,6 +3082,7 @@ impl Editor {
});
let id = post_inc(&mut self.next_completion_id);
let project = self.project.clone();
let task = cx.spawn(|this, mut cx| {
async move {
let menu = if let Some(completions) = completions.await.log_err() {
@@ -3001,6 +3101,7 @@ impl Editor {
})
.collect(),
buffer,
project,
completions: completions.into(),
matches: Vec::new().into(),
selected_item: 0,
@@ -3102,17 +3203,30 @@ impl Editor {
.count();
let snapshot = self.buffer.read(cx).snapshot(cx);
let mut range_to_replace: Option<Range<isize>> = None;
let mut ranges = Vec::new();
for selection in &selections {
if snapshot.contains_str_at(selection.start.saturating_sub(lookbehind), &old_text) {
let start = selection.start.saturating_sub(lookbehind);
let end = selection.end + lookahead;
if selection.id == newest_selection.id {
range_to_replace = Some(
((start + common_prefix_len) as isize - selection.start as isize)
..(end as isize - selection.start as isize),
);
}
ranges.push(start + common_prefix_len..end);
} else {
common_prefix_len = 0;
ranges.clear();
ranges.extend(selections.iter().map(|s| {
if s.id == newest_selection.id {
range_to_replace = Some(
old_range.start.to_offset_utf16(&snapshot).0 as isize
- selection.start as isize
..old_range.end.to_offset_utf16(&snapshot).0 as isize
- selection.start as isize,
);
old_range.clone()
} else {
s.start..s.end
@@ -3123,6 +3237,11 @@ impl Editor {
}
let text = &text[common_prefix_len..];
cx.emit(Event::InputHandled {
utf16_range_to_replace: range_to_replace,
text: text.into(),
});
self.transact(cx, |this, cx| {
if let Some(mut snippet) = snippet {
snippet.text = text.to_string();
@@ -3580,6 +3699,10 @@ impl Editor {
self.report_copilot_event(Some(completion.uuid.clone()), true, cx)
}
cx.emit(Event::InputHandled {
utf16_range_to_replace: None,
text: suggestion.text.to_string().into(),
});
self.insert_with_autoindent_mode(&suggestion.text.to_string(), None, cx);
cx.notify();
true
@@ -5069,12 +5192,6 @@ impl Editor {
return;
}
if let Some(context_menu) = self.context_menu.as_mut() {
if context_menu.select_prev(cx) {
return;
}
}
if matches!(self.mode, EditorMode::SingleLine) {
cx.propagate_action();
return;
@@ -5097,15 +5214,6 @@ impl Editor {
return;
}
if self
.context_menu
.as_mut()
.map(|menu| menu.select_first(cx))
.unwrap_or(false)
{
return;
}
if matches!(self.mode, EditorMode::SingleLine) {
cx.propagate_action();
return;
@@ -5145,12 +5253,6 @@ impl Editor {
pub fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
self.take_rename(true, cx);
if let Some(context_menu) = self.context_menu.as_mut() {
if context_menu.select_next(cx) {
return;
}
}
if self.mode == EditorMode::SingleLine {
cx.propagate_action();
return;
@@ -5218,6 +5320,30 @@ impl Editor {
});
}
pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext<Self>) {
if let Some(context_menu) = self.context_menu.as_mut() {
context_menu.select_first(cx);
}
}
pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext<Self>) {
if let Some(context_menu) = self.context_menu.as_mut() {
context_menu.select_prev(cx);
}
}
pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext<Self>) {
if let Some(context_menu) = self.context_menu.as_mut() {
context_menu.select_next(cx);
}
}
pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext<Self>) {
if let Some(context_menu) = self.context_menu.as_mut() {
context_menu.select_last(cx);
}
}
pub fn move_to_previous_word_start(
&mut self,
_: &MoveToPreviousWordStart,
@@ -8328,6 +8454,41 @@ impl Editor {
pub fn inlay_hint_cache(&self) -> &InlayHintCache {
&self.inlay_hint_cache
}
pub fn replay_insert_event(
&mut self,
text: &str,
relative_utf16_range: Option<Range<isize>>,
cx: &mut ViewContext<Self>,
) {
if !self.input_enabled {
cx.emit(Event::InputIgnored { text: text.into() });
return;
}
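// When a relative range is provided, reposition each selection relative to
// its head before replaying the inserted text.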
if let Some(relative_utf16_range) = relative_utf16_range {
let selections = self.selections.all::<OffsetUtf16>(cx);
self.change_selections(None, cx, |s| {
let new_ranges = selections.into_iter().map(|range| {
let start = OffsetUtf16(
range
.head()
.0
.saturating_add_signed(relative_utf16_range.start),
);
let end = OffsetUtf16(
range
.head()
.0
.saturating_add_signed(relative_utf16_range.end),
);
start..end
});
s.select_ranges(new_ranges);
});
}
self.handle_input(text, cx);
}
}
fn document_to_inlay_range(
@@ -8416,6 +8577,10 @@ pub enum Event {
InputIgnored {
text: Arc<str>,
},
InputHandled {
utf16_range_to_replace: Option<Range<isize>>,
text: Arc<str>,
},
ExcerptsAdded {
buffer: ModelHandle<Buffer>,
predecessor: ExcerptId,
@@ -8569,17 +8734,20 @@ impl View for Editor {
if self.pending_rename.is_some() {
keymap.add_identifier("renaming");
}
match self.context_menu.as_ref() {
Some(ContextMenu::Completions(_)) => {
keymap.add_identifier("menu");
keymap.add_identifier("showing_completions")
if self.context_menu_visible() {
match self.context_menu.as_ref() {
Some(ContextMenu::Completions(_)) => {
keymap.add_identifier("menu");
keymap.add_identifier("showing_completions")
}
Some(ContextMenu::CodeActions(_)) => {
keymap.add_identifier("menu");
keymap.add_identifier("showing_code_actions")
}
None => {}
}
Some(ContextMenu::CodeActions(_)) => {
keymap.add_identifier("menu");
keymap.add_identifier("showing_code_actions")
}
None => {}
}
for layer in self.keymap_context_layers.values() {
keymap.extend(layer);
}
@@ -8633,29 +8801,51 @@ impl View for Editor {
text: &str,
cx: &mut ViewContext<Self>,
) {
self.transact(cx, |this, cx| {
if this.input_enabled {
let new_selected_ranges = if let Some(range_utf16) = range_utf16 {
let range_utf16 = OffsetUtf16(range_utf16.start)..OffsetUtf16(range_utf16.end);
Some(this.selection_replacement_ranges(range_utf16, cx))
} else {
this.marked_text_ranges(cx)
};
if !self.input_enabled {
cx.emit(Event::InputIgnored { text: text.into() });
return;
}
if let Some(new_selected_ranges) = new_selected_ranges {
this.change_selections(None, cx, |selections| {
selections.select_ranges(new_selected_ranges)
});
}
self.transact(cx, |this, cx| {
let new_selected_ranges = if let Some(range_utf16) = range_utf16 {
let range_utf16 = OffsetUtf16(range_utf16.start)..OffsetUtf16(range_utf16.end);
Some(this.selection_replacement_ranges(range_utf16, cx))
} else {
this.marked_text_ranges(cx)
};
let range_to_replace = new_selected_ranges.as_ref().and_then(|ranges_to_replace| {
let newest_selection_id = this.selections.newest_anchor().id;
this.selections
.all::<OffsetUtf16>(cx)
.iter()
.zip(ranges_to_replace.iter())
.find_map(|(selection, range)| {
if selection.id == newest_selection_id {
Some(
(range.start.0 as isize - selection.head().0 as isize)
..(range.end.0 as isize - selection.head().0 as isize),
)
} else {
None
}
})
});
cx.emit(Event::InputHandled {
utf16_range_to_replace: range_to_replace,
text: text.into(),
});
if let Some(new_selected_ranges) = new_selected_ranges {
this.change_selections(None, cx, |selections| {
selections.select_ranges(new_selected_ranges)
});
}
this.handle_input(text, cx);
});
if !self.input_enabled {
return;
}
if let Some(transaction) = self.ime_transaction {
self.buffer.update(cx, |buffer, cx| {
buffer.group_until_transaction(transaction, cx);
@@ -8673,6 +8863,7 @@ impl View for Editor {
cx: &mut ViewContext<Self>,
) {
if !self.input_enabled {
cx.emit(Event::InputIgnored { text: text.into() });
return;
}
@@ -8697,6 +8888,29 @@ impl View for Editor {
None
};
let range_to_replace = ranges_to_replace.as_ref().and_then(|ranges_to_replace| {
let newest_selection_id = this.selections.newest_anchor().id;
this.selections
.all::<OffsetUtf16>(cx)
.iter()
.zip(ranges_to_replace.iter())
.find_map(|(selection, range)| {
if selection.id == newest_selection_id {
Some(
(range.start.0 as isize - selection.head().0 as isize)
..(range.end.0 as isize - selection.head().0 as isize),
)
} else {
None
}
})
});
cx.emit(Event::InputHandled {
utf16_range_to_replace: range_to_replace,
text: text.into(),
});
if let Some(ranges) = ranges_to_replace {
this.change_selections(None, cx, |s| s.select_ranges(ranges));
}
@@ -9186,6 +9400,7 @@ pub fn split_words<'a>(text: &'a str) -> impl std::iter::Iterator<Item = &'a str
None
})
.flat_map(|word| word.split_inclusive('_'))
.flat_map(|word| word.split_inclusive('-'))
}
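// A standalone check (not part of the diff above) of the extra '-' split:
// split_words now emits subwords at both '_' and '-' boundaries (the completion
// test later in this diff relies on "yellow" being treated as a subword of
// "bg-yellow"). The real function also applies an earlier filtering step before
// these splits.
fn main() {
    let parts: Vec<&str> = "bg-yellow_dark"
        .split_inclusive('_')
        .flat_map(|word| word.split_inclusive('-'))
        .collect();
    assert_eq!(parts, ["bg-", "yellow_", "dark"]);
}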
trait RangeToAnchorExt {

View File

@@ -19,7 +19,8 @@ use gpui::{
use indoc::indoc;
use language::{
language_settings::{AllLanguageSettings, AllLanguageSettingsContent, LanguageSettingsContent},
BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageRegistry, Point,
BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageConfigOverride, LanguageRegistry,
Override, Point,
};
use parking_lot::Mutex;
use project::project_settings::{LspSettings, ProjectSettings};
@@ -5339,7 +5340,7 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
cx.condition(|editor, _| editor.context_menu_visible())
.await;
let apply_additional_edits = cx.update_editor(|editor, cx| {
editor.move_down(&MoveDown, cx);
editor.context_menu_next(&Default::default(), cx);
editor
.confirm_completion(&ConfirmCompletion::default(), cx)
.unwrap()
@@ -7688,6 +7689,105 @@ async fn test_completions_with_additional_edits(cx: &mut gpui::TestAppContext) {
cx.assert_editor_state(indoc! {"fn main() { let a = Some(2)ˇ; }"});
}
#[gpui::test]
async fn test_completions_in_languages_with_extra_word_characters(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorLspTestContext::new(
Language::new(
LanguageConfig {
path_suffixes: vec!["jsx".into()],
overrides: [(
"element".into(),
LanguageConfigOverride {
word_characters: Override::Set(['-'].into_iter().collect()),
..Default::default()
},
)]
.into_iter()
.collect(),
..Default::default()
},
Some(tree_sitter_typescript::language_tsx()),
)
.with_override_query("(jsx_self_closing_element) @element")
.unwrap(),
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![":".to_string()]),
..Default::default()
}),
..Default::default()
},
cx,
)
.await;
cx.lsp
.handle_request::<lsp::request::Completion, _, _>(move |_, _| async move {
Ok(Some(lsp::CompletionResponse::Array(vec![
lsp::CompletionItem {
label: "bg-blue".into(),
..Default::default()
},
lsp::CompletionItem {
label: "bg-red".into(),
..Default::default()
},
lsp::CompletionItem {
label: "bg-yellow".into(),
..Default::default()
},
])))
});
cx.set_state(r#"<p class="bgˇ" />"#);
// Trigger completion when typing a dash, because the dash is an extra
// word character in the 'element' scope, which contains the cursor.
cx.simulate_keystroke("-");
cx.foreground().run_until_parked();
cx.update_editor(|editor, _| {
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
assert_eq!(
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
&["bg-red", "bg-blue", "bg-yellow"]
);
} else {
panic!("expected completion menu to be open");
}
});
cx.simulate_keystroke("l");
cx.foreground().run_until_parked();
cx.update_editor(|editor, _| {
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
assert_eq!(
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
&["bg-blue", "bg-yellow"]
);
} else {
panic!("expected completion menu to be open");
}
});
// When filtering completions, consider the character after the '-' to
// be the start of a subword.
cx.set_state(r#"<p class="yelˇ" />"#);
cx.simulate_keystroke("l");
cx.foreground().run_until_parked();
cx.update_editor(|editor, _| {
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
assert_eq!(
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
&["bg-yellow"]
);
} else {
panic!("expected completion menu to be open");
}
});
}
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
let point = DisplayPoint::new(row as u32, column as u32);
point..point
@@ -7707,7 +7807,7 @@ fn assert_selection_ranges(marked_text: &str, view: &mut Editor, cx: &mut ViewCo
/// Handles a completion request. The marked string specifies where the completion
/// should be triggered (with the '|' character), which range should be replaced
/// (delimited by '<' and '>'), and which completions should be returned.

fn handle_completion_request<'a>(
pub fn handle_completion_request<'a>(
cx: &mut EditorLspTestContext<'a>,
marked_string: &str,
completions: Vec<&'static str>,

View File

@@ -1,8 +1,14 @@
use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint};
use crate::{char_kind, CharKind, ToPoint};
use crate::{char_kind, CharKind, ToOffset, ToPoint};
use language::Point;
use std::ops::Range;
#[derive(Debug, PartialEq)]
pub enum FindRange {
SingleLine,
MultiLine,
}
pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint {
if point.column() > 0 {
*point.column_mut() -= 1;
@@ -177,20 +183,21 @@ pub fn line_end(
pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
let raw_point = point.to_point(map);
let language = map.buffer_snapshot.language_at(raw_point);
let scope = map.buffer_snapshot.language_scope_at(raw_point);
find_preceding_boundary(map, point, |left, right| {
(char_kind(language, left) != char_kind(language, right) && !right.is_whitespace())
find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
(char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace())
|| left == '\n'
})
}
pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
let raw_point = point.to_point(map);
let language = map.buffer_snapshot.language_at(raw_point);
find_preceding_boundary(map, point, |left, right| {
let scope = map.buffer_snapshot.language_scope_at(raw_point);
find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
let is_word_start =
char_kind(language, left) != char_kind(language, right) && !right.is_whitespace();
char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace();
let is_subword_start =
left == '_' && right != '_' || left.is_lowercase() && right.is_uppercase();
is_word_start || is_subword_start || left == '\n'
@@ -199,19 +206,21 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis
pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
let raw_point = point.to_point(map);
let language = map.buffer_snapshot.language_at(raw_point);
find_boundary(map, point, |left, right| {
(char_kind(language, left) != char_kind(language, right) && !left.is_whitespace())
let scope = map.buffer_snapshot.language_scope_at(raw_point);
find_boundary(map, point, FindRange::MultiLine, |left, right| {
(char_kind(&scope, left) != char_kind(&scope, right) && !left.is_whitespace())
|| right == '\n'
})
}
pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
let raw_point = point.to_point(map);
let language = map.buffer_snapshot.language_at(raw_point);
find_boundary(map, point, |left, right| {
let scope = map.buffer_snapshot.language_scope_at(raw_point);
find_boundary(map, point, FindRange::MultiLine, |left, right| {
let is_word_end =
(char_kind(language, left) != char_kind(language, right)) && !left.is_whitespace();
(char_kind(&scope, left) != char_kind(&scope, right)) && !left.is_whitespace();
let is_subword_end =
left != '_' && right == '_' || left.is_lowercase() && right.is_uppercase();
is_word_end || is_subword_end || right == '\n'
@@ -272,79 +281,34 @@ pub fn end_of_paragraph(
map.max_point()
}
/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the
/// given predicate returning true. The predicate is called with the character to the left and right
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
/// or end of a line.
/// Scans for a boundary preceding the given start point `from` until a boundary is found,
/// indicated by the given predicate returning true.
/// The predicate is called with the character to the left and right of the candidate boundary location.
/// If FindRange::SingleLine is specified and no boundary is found before the start of the current line, the start of the current line will be returned.
pub fn find_preceding_boundary(
map: &DisplaySnapshot,
from: DisplayPoint,
find_range: FindRange,
mut is_boundary: impl FnMut(char, char) -> bool,
) -> DisplayPoint {
let mut start_column = 0;
let mut soft_wrap_row = from.row() + 1;
let mut prev_ch = None;
let mut offset = from.to_point(map).to_offset(&map.buffer_snapshot);
let mut prev = None;
for (ch, point) in map.reverse_chars_at(from) {
// Recompute soft_wrap_indent if the row has changed
if point.row() != soft_wrap_row {
soft_wrap_row = point.row();
if point.row() == 0 {
start_column = 0;
} else if let Some(indent) = map.soft_wrap_indent(point.row() - 1) {
start_column = indent;
}
}
// If the current point is in the soft_wrap, skip comparing it
if point.column() < start_column {
continue;
}
if let Some((prev_ch, prev_point)) = prev {
if is_boundary(ch, prev_ch) {
return map.clip_point(prev_point, Bias::Left);
}
}
prev = Some((ch, point));
}
map.clip_point(DisplayPoint::zero(), Bias::Left)
}
/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the
/// given predicate returning true. The predicate is called with the character to the left and right
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
/// or end of a line. If no boundary is found, the start of the line is returned.
pub fn find_preceding_boundary_in_line(
map: &DisplaySnapshot,
from: DisplayPoint,
mut is_boundary: impl FnMut(char, char) -> bool,
) -> DisplayPoint {
let mut start_column = 0;
if from.row() > 0 {
if let Some(indent) = map.soft_wrap_indent(from.row() - 1) {
start_column = indent;
}
}
let mut prev = None;
for (ch, point) in map.reverse_chars_at(from) {
if let Some((prev_ch, prev_point)) = prev {
if is_boundary(ch, prev_ch) {
return map.clip_point(prev_point, Bias::Left);
}
}
if ch == '\n' || point.column() < start_column {
for ch in map.buffer_snapshot.reversed_chars_at(offset) {
if find_range == FindRange::SingleLine && ch == '\n' {
break;
}
if let Some(prev_ch) = prev_ch {
if is_boundary(ch, prev_ch) {
break;
}
}
prev = Some((ch, point));
offset -= ch.len_utf8();
prev_ch = Some(ch);
}
map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Left)
map.clip_point(offset.to_display_point(map), Bias::Left)
}
/// Scans for a boundary following the given start point until a boundary is found, indicated by the
@@ -354,59 +318,38 @@ pub fn find_preceding_boundary_in_line(
pub fn find_boundary(
map: &DisplaySnapshot,
from: DisplayPoint,
find_range: FindRange,
mut is_boundary: impl FnMut(char, char) -> bool,
) -> DisplayPoint {
let mut offset = from.to_offset(&map, Bias::Right);
let mut prev_ch = None;
for (ch, point) in map.chars_at(from) {
if let Some(prev_ch) = prev_ch {
if is_boundary(prev_ch, ch) {
return map.clip_point(point, Bias::Right);
}
}
prev_ch = Some(ch);
}
map.clip_point(map.max_point(), Bias::Right)
}
/// Scans for a boundary following the given start point until a boundary is found, indicated by the
/// given predicate returning true. The predicate is called with the character to the left and right
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
/// or end of a line. If no boundary is found, the end of the line is returned
pub fn find_boundary_in_line(
map: &DisplaySnapshot,
from: DisplayPoint,
mut is_boundary: impl FnMut(char, char) -> bool,
) -> DisplayPoint {
let mut prev = None;
for (ch, point) in map.chars_at(from) {
if let Some((prev_ch, _)) = prev {
if is_boundary(prev_ch, ch) {
return map.clip_point(point, Bias::Right);
}
}
prev = Some((ch, point));
if ch == '\n' {
for ch in map.buffer_snapshot.chars_at(offset) {
if find_range == FindRange::SingleLine && ch == '\n' {
break;
}
}
if let Some(prev_ch) = prev_ch {
if is_boundary(prev_ch, ch) {
break;
}
}
// Return the last position checked so that we give a point right before the newline or eof.
map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Right)
offset += ch.len_utf8();
prev_ch = Some(ch);
}
map.clip_point(offset.to_display_point(map), Bias::Right)
}
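// A standalone sketch (over a plain &str rather than a DisplaySnapshot, not
// part of the diff above) of the forward scan in find_boundary: advance
// byte-by-byte until the predicate reports a boundary, or stop early at '\n'
// when the search is restricted to a single line.
#[derive(PartialEq)]
enum FindRange {
    SingleLine,
    MultiLine,
}

fn find_boundary(
    text: &str,
    mut offset: usize,
    range: FindRange,
    mut is_boundary: impl FnMut(char, char) -> bool,
) -> usize {
    let mut prev_ch = None;
    for ch in text[offset..].chars() {
        if range == FindRange::SingleLine && ch == '\n' {
            break;
        }
        if let Some(prev) = prev_ch {
            if is_boundary(prev, ch) {
                break;
            }
        }
        offset += ch.len_utf8();
        prev_ch = Some(ch);
    }
    offset
}

fn main() {
    // Stops right after "hello", on the space that follows it.
    let text = "hello world";
    let end = find_boundary(text, 0, FindRange::MultiLine, |left, right| {
        left.is_alphanumeric() && right == ' '
    });
    assert_eq!(&text[..end], "hello");
}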
pub fn is_inside_word(map: &DisplaySnapshot, point: DisplayPoint) -> bool {
let raw_point = point.to_point(map);
let language = map.buffer_snapshot.language_at(raw_point);
let scope = map.buffer_snapshot.language_scope_at(raw_point);
let ix = map.clip_point(point, Bias::Left).to_offset(map, Bias::Left);
let text = &map.buffer_snapshot;
let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(language, c));
let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(&scope, c));
let prev_char_kind = text
.reversed_chars_at(ix)
.next()
.map(|c| char_kind(language, c));
.map(|c| char_kind(&scope, c));
prev_char_kind.zip(next_char_kind) == Some((CharKind::Word, CharKind::Word))
}
@@ -533,7 +476,12 @@ mod tests {
) {
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
assert_eq!(
find_preceding_boundary(&snapshot, display_points[1], is_boundary),
find_preceding_boundary(
&snapshot,
display_points[1],
FindRange::MultiLine,
is_boundary
),
display_points[0]
);
}
@@ -612,21 +560,15 @@ mod tests {
find_preceding_boundary(
&snapshot,
buffer_snapshot.len().to_display_point(&snapshot),
|left, _| left == 'a',
FindRange::MultiLine,
|left, _| left == 'e',
),
0.to_display_point(&snapshot),
snapshot
.buffer_snapshot
.offset_to_point(5)
.to_display_point(&snapshot),
"Should not stop at inlays when looking for boundaries"
);
assert_eq!(
find_preceding_boundary_in_line(
&snapshot,
buffer_snapshot.len().to_display_point(&snapshot),
|left, _| left == 'a',
),
0.to_display_point(&snapshot),
"Should not stop at inlays when looking for boundaries in line"
);
}
#[gpui::test]
@@ -699,7 +641,12 @@ mod tests {
) {
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
assert_eq!(
find_boundary(&snapshot, display_points[0], is_boundary),
find_boundary(
&snapshot,
display_points[0],
FindRange::MultiLine,
is_boundary
),
display_points[1]
);
}

View File

@@ -1417,13 +1417,13 @@ impl MultiBuffer {
return false;
}
let language = self.language_at(position.clone(), cx);
if char_kind(language.as_ref(), char) == CharKind::Word {
let snapshot = self.snapshot(cx);
let position = position.to_offset(&snapshot);
let scope = snapshot.language_scope_at(position);
if char_kind(&scope, char) == CharKind::Word {
return true;
}
let snapshot = self.snapshot(cx);
let anchor = snapshot.anchor_before(position);
anchor
.buffer_id
@@ -1925,8 +1925,8 @@ impl MultiBufferSnapshot {
let mut next_chars = self.chars_at(start).peekable();
let mut prev_chars = self.reversed_chars_at(start).peekable();
let language = self.language_at(start);
let kind = |c| char_kind(language, c);
let scope = self.language_scope_at(start);
let kind = |c| char_kind(&scope, c);
let word_kind = cmp::max(
prev_chars.peek().copied().map(kind),
next_chars.peek().copied().map(kind),

View File

@@ -378,10 +378,6 @@ impl Editor {
return;
}
if amount.move_context_menu_selection(self, cx) {
return;
}
let cur_position = self.scroll_position(cx);
let new_pos = cur_position + vec2f(0., amount.lines(self));
self.set_scroll_position(new_pos, cx);

View File

@@ -1,8 +1,5 @@
use gpui::ViewContext;
use serde::Deserialize;
use util::iife;
use crate::Editor;
use serde::Deserialize;
#[derive(Clone, PartialEq, Deserialize)]
pub enum ScrollAmount {
@@ -13,25 +10,6 @@ pub enum ScrollAmount {
}
impl ScrollAmount {
pub fn move_context_menu_selection(
&self,
editor: &mut Editor,
cx: &mut ViewContext<Editor>,
) -> bool {
iife!({
let context_menu = editor.context_menu.as_mut()?;
match self {
Self::Line(c) if *c > 0. => context_menu.select_next(cx),
Self::Line(_) => context_menu.select_prev(cx),
Self::Page(c) if *c > 0. => context_menu.select_last(cx),
Self::Page(_) => context_menu.select_first(cx),
}
.then_some(())
})
.is_some()
}
pub fn lines(&self, editor: &mut Editor) -> f32 {
match self {
Self::Line(count) => *count,
@@ -39,7 +17,7 @@ impl ScrollAmount {
.visible_line_count()
// subtract one to leave an anchor line
// round towards zero (so page-up and page-down are symmetric)
.map(|l| ((l - 1.) * count).trunc())
.map(|l| (l * count).trunc() - count.signum())
.unwrap_or(0.),
}
}
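// A worked example (not part of the diff above) of the revised scroll formula,
// assuming a pane with 10 visible lines: the product truncates toward zero and
// is then pulled back by one line toward zero, so a full page in either
// direction moves 9 lines and keeps one anchor line of context.
fn lines(visible_line_count: f32, count: f32) -> f32 {
    (visible_line_count * count).trunc() - count.signum()
}

fn main() {
    assert_eq!(lines(10.0, 1.0), 9.0); // page down
    assert_eq!(lines(10.0, -1.0), -9.0); // page up mirrors it exactly
    assert_eq!(lines(10.0, 0.5), 4.0); // half page
    assert_eq!(lines(10.0, -0.5), -4.0);
}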

View File

@@ -51,7 +51,7 @@ impl<'a> EditorLspTestContext<'a> {
language
.path_suffixes()
.first()
.unwrap_or(&"txt".to_string())
.expect("language must have a path suffix for EditorLspTestContext")
);
let mut fake_servers = language

View File

@@ -42,14 +42,14 @@ impl View for FeedbackInfoText {
)
.with_child(
MouseEventHandler::new::<OpenZedCommunityRepo, _>(0, cx, |state, _| {
let contained_text = if state.hovered() {
let style = if state.hovered() {
&theme.feedback.link_text_hover
} else {
&theme.feedback.link_text_default
};
Label::new("community repo", contained_text.text.clone())
Label::new("community repo", style.text.clone())
.contained()
.with_style(style.container)
.aligned()
.left()
.clipped()
@@ -64,6 +64,8 @@ impl View for FeedbackInfoText {
.with_soft_wrap(false)
.aligned(),
)
.contained()
.with_style(theme.feedback.info_text_default.container)
.aligned()
.left()
.clipped()

View File

@@ -1528,8 +1528,13 @@ mod tests {
let active_pane = cx.read(|cx| workspace.read(cx).active_pane().clone());
active_pane
.update(cx, |pane, cx| {
pane.close_active_item(&workspace::CloseActiveItem, cx)
.unwrap()
pane.close_active_item(
&workspace::CloseActiveItem {
save_behavior: None,
},
cx,
)
.unwrap()
})
.await
.unwrap();

View File

@@ -3513,14 +3513,12 @@ impl<'a, 'b, 'c, V> LayoutContext<'a, 'b, 'c, V> {
handler_depth = Some(contexts.len())
}
let action_contexts = if let Some(depth) = handler_depth {
&contexts[depth..]
} else {
&contexts
};
self.keystroke_matcher
.keystrokes_for_action(action, action_contexts)
let handler_depth = handler_depth.unwrap_or(0);
(0..=handler_depth).find_map(|depth| {
let contexts = &contexts[depth..];
self.keystroke_matcher
.keystrokes_for_action(action, contexts)
})
}
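// A standalone sketch (with a toy lookup in place of the real keystroke
// matcher, not part of the diff above) of the depth search: rather than
// consulting only the contexts at the handler's depth, every slice
// contexts[depth..] for depth in 0..=handler_depth is tried, and the first
// binding that resolves wins.
fn keystrokes_for_action(
    contexts: &[String],
    handler_depth: usize,
    lookup: fn(&[String]) -> Option<&'static str>,
) -> Option<&'static str> {
    (0..=handler_depth).find_map(|depth| lookup(&contexts[depth..]))
}

// Toy matcher: only resolves while the slice still starts with the deepest context.
fn lookup(contexts: &[String]) -> Option<&'static str> {
    (contexts.first().map(String::as_str) == Some("View2")).then_some("c")
}

fn main() {
    let contexts = vec!["View2".to_string(), "View1".to_string()];
    // The binding resolves at depth 0 even though the handler sits at depth 1.
    assert_eq!(keystrokes_for_action(&contexts, 1, lookup), Some("c"));
}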
fn notify_if_view_ancestors_change(&mut self, view_id: usize) {
@@ -6499,7 +6497,7 @@ mod tests {
#[crate::test(self)]
fn test_keystrokes_for_action(cx: &mut TestAppContext) {
actions!(test, [Action1, Action2, GlobalAction]);
actions!(test, [Action1, Action2, Action3, GlobalAction]);
struct View1 {
child: ViewHandle<View2>,
@@ -6542,12 +6540,14 @@ mod tests {
cx.update(|cx| {
cx.add_action(|_: &mut View1, _: &Action1, _cx| {});
cx.add_action(|_: &mut View1, _: &Action3, _cx| {});
cx.add_action(|_: &mut View2, _: &Action2, _cx| {});
cx.add_global_action(|_: &GlobalAction, _| {});
cx.add_bindings(vec![
Binding::new("a", Action1, Some("View1")),
Binding::new("b", Action2, Some("View1 > View2")),
Binding::new("c", GlobalAction, Some("View3")), // View 3 does not exist
Binding::new("c", Action3, Some("View2")),
Binding::new("d", GlobalAction, Some("View3")), // View 3 does not exist
]);
});
@@ -6577,6 +6577,14 @@ mod tests {
.as_slice(),
&[Keystroke::parse("b").unwrap()]
);
assert_eq!(layout_cx.keystrokes_for_action(view_1.id(), &Action3), None);
assert_eq!(
layout_cx
.keystrokes_for_action(view_2.id(), &Action3)
.unwrap()
.as_slice(),
&[Keystroke::parse("c").unwrap()]
);
// The 'a' keystroke propagates up the view tree from view_2
// to view_1. The action, Action1, is handled by view_1.
@@ -6604,7 +6612,8 @@ mod tests {
&available_actions(window.into(), view_1.id(), cx),
&[
("test::Action1", vec![Keystroke::parse("a").unwrap()]),
("test::GlobalAction", vec![])
("test::Action3", vec![]),
("test::GlobalAction", vec![]),
],
);
@@ -6614,6 +6623,7 @@ mod tests {
&[
("test::Action1", vec![Keystroke::parse("a").unwrap()]),
("test::Action2", vec![Keystroke::parse("b").unwrap()]),
("test::Action3", vec![Keystroke::parse("c").unwrap()]),
("test::GlobalAction", vec![]),
],
);

View File

@@ -1110,7 +1110,7 @@ impl<'a> WindowContext<'a> {
self.window.is_fullscreen
}
pub(crate) fn dispatch_action(&mut self, view_id: Option<usize>, action: &dyn Action) -> bool {
pub fn dispatch_action(&mut self, view_id: Option<usize>, action: &dyn Action) -> bool {
if let Some(view_id) = view_id {
self.halt_action_dispatch = false;
self.visit_dispatch_path(view_id, |view_id, capture_phase, cx| {

View File

@@ -106,6 +106,7 @@ pub struct Deterministic {
parker: parking_lot::Mutex<parking::Parker>,
}
#[must_use]
pub enum Timer {
Production(smol::Timer),
#[cfg(any(test, feature = "test-support"))]

View File

@@ -37,8 +37,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
Some("seed") => starting_seed = parse_int(&meta.lit)?,
Some("on_failure") => {
if let Lit::Str(name) = meta.lit {
let ident = Ident::new(&name.value(), name.span());
on_failure_fn_name = quote!(Some(#ident));
let mut path = syn::Path {
leading_colon: None,
segments: Default::default(),
};
for part in name.value().split("::") {
path.segments.push(Ident::new(part, name.span()).into());
}
on_failure_fn_name = quote!(Some(#path));
} else {
return Err(TokenStream::from(
syn::Error::new(

View File

@@ -148,6 +148,7 @@ pub struct Completion {
pub old_range: Range<Anchor>,
pub new_text: String,
pub label: CodeLabel,
pub server_id: LanguageServerId,
pub lsp_completion: lsp::CompletionItem,
}
@@ -438,7 +439,7 @@ impl Buffer {
operations.extend(
text_operations
.iter()
.filter(|(_, op)| !since.observed(op.local_timestamp()))
.filter(|(_, op)| !since.observed(op.timestamp()))
.map(|(_, op)| proto::serialize_operation(&Operation::Buffer(op.clone()))),
);
operations.sort_unstable_by_key(proto::lamport_timestamp_for_operation);
@@ -1303,7 +1304,7 @@ impl Buffer {
pub fn wait_for_edits(
&mut self,
edit_ids: impl IntoIterator<Item = clock::Local>,
edit_ids: impl IntoIterator<Item = clock::Lamport>,
) -> impl Future<Output = Result<()>> {
self.text.wait_for_edits(edit_ids)
}
@@ -1361,7 +1362,7 @@ impl Buffer {
}
}
pub fn set_text<T>(&mut self, text: T, cx: &mut ModelContext<Self>) -> Option<clock::Local>
pub fn set_text<T>(&mut self, text: T, cx: &mut ModelContext<Self>) -> Option<clock::Lamport>
where
T: Into<Arc<str>>,
{
@@ -1374,7 +1375,7 @@ impl Buffer {
edits_iter: I,
autoindent_mode: Option<AutoindentMode>,
cx: &mut ModelContext<Self>,
) -> Option<clock::Local>
) -> Option<clock::Lamport>
where
I: IntoIterator<Item = (Range<S>, T)>,
S: ToOffset,
@@ -1411,7 +1412,7 @@ impl Buffer {
.and_then(|mode| self.language.as_ref().map(|_| (self.snapshot(), mode)));
let edit_operation = self.text.edit(edits.iter().cloned());
let edit_id = edit_operation.local_timestamp();
let edit_id = edit_operation.timestamp();
if let Some((before_edit, mode)) = autoindent_request {
let mut delta = 0isize;
@@ -2216,8 +2217,8 @@ impl BufferSnapshot {
let mut next_chars = self.chars_at(start).peekable();
let mut prev_chars = self.reversed_chars_at(start).peekable();
let language = self.language_at(start);
let kind = |c| char_kind(language, c);
let scope = self.language_scope_at(start);
let kind = |c| char_kind(&scope, c);
let word_kind = cmp::max(
prev_chars.peek().copied().map(kind),
next_chars.peek().copied().map(kind),
@@ -3031,17 +3032,21 @@ pub fn contiguous_ranges(
})
}
pub fn char_kind(language: Option<&Arc<Language>>, c: char) -> CharKind {
pub fn char_kind(scope: &Option<LanguageScope>, c: char) -> CharKind {
if c.is_whitespace() {
return CharKind::Whitespace;
} else if c.is_alphanumeric() || c == '_' {
return CharKind::Word;
}
if let Some(language) = language {
if language.config.word_characters.contains(&c) {
return CharKind::Word;
if let Some(scope) = scope {
if let Some(characters) = scope.word_characters() {
if characters.contains(&c) {
return CharKind::Word;
}
}
}
CharKind::Punctuation
}
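// A minimal standalone mirror (not the crate's real types, not part of the
// diff above) of the scope-aware classification: '-' is punctuation by
// default, but a scope that lists it in word_characters promotes it to a word
// character, which is what the JSX class-attribute completion test earlier in
// this diff exercises.
use std::collections::HashSet;

#[derive(Debug, PartialEq)]
enum CharKind {
    Whitespace,
    Word,
    Punctuation,
}

fn char_kind(extra_word_chars: Option<&HashSet<char>>, c: char) -> CharKind {
    if c.is_whitespace() {
        CharKind::Whitespace
    } else if c.is_alphanumeric() || c == '_' {
        CharKind::Word
    } else if extra_word_chars.map_or(false, |chars| chars.contains(&c)) {
        CharKind::Word
    } else {
        CharKind::Punctuation
    }
}

fn main() {
    let element_scope: HashSet<char> = ['-'].into_iter().collect();
    assert_eq!(char_kind(None, '-'), CharKind::Punctuation);
    assert_eq!(char_kind(Some(&element_scope), '-'), CharKind::Word);
}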

View File

@@ -46,7 +46,7 @@ use theme::{SyntaxTheme, Theme};
use tree_sitter::{self, Query};
use unicase::UniCase;
use util::{http::HttpClient, paths::PathExt};
use util::{merge_json_value_into, post_inc, ResultExt, TryFutureExt as _, UnwrapFuture};
use util::{post_inc, ResultExt, TryFutureExt as _, UnwrapFuture};
#[cfg(any(test, feature = "test-support"))]
use futures::channel::mpsc;
@@ -91,6 +91,7 @@ pub struct LanguageServerName(pub Arc<str>);
/// once at startup, and caches the results.
pub struct CachedLspAdapter {
pub name: LanguageServerName,
pub short_name: &'static str,
pub initialization_options: Option<Value>,
pub disk_based_diagnostic_sources: Vec<String>,
pub disk_based_diagnostics_progress_token: Option<String>,
@@ -101,6 +102,7 @@ pub struct CachedLspAdapter {
impl CachedLspAdapter {
pub async fn new(adapter: Arc<dyn LspAdapter>) -> Arc<Self> {
let name = adapter.name().await;
let short_name = adapter.short_name();
let initialization_options = adapter.initialization_options().await;
let disk_based_diagnostic_sources = adapter.disk_based_diagnostic_sources().await;
let disk_based_diagnostics_progress_token =
@@ -109,6 +111,7 @@ impl CachedLspAdapter {
Arc::new(CachedLspAdapter {
name,
short_name,
initialization_options,
disk_based_diagnostic_sources,
disk_based_diagnostics_progress_token,
@@ -176,10 +179,7 @@ impl CachedLspAdapter {
self.adapter.code_action_kinds()
}
pub fn workspace_configuration(
&self,
cx: &mut AppContext,
) -> Option<BoxFuture<'static, Value>> {
pub fn workspace_configuration(&self, cx: &mut AppContext) -> BoxFuture<'static, Value> {
self.adapter.workspace_configuration(cx)
}
@@ -220,6 +220,8 @@ pub trait LspAdapterDelegate: Send + Sync {
pub trait LspAdapter: 'static + Send + Sync {
async fn name(&self) -> LanguageServerName;
fn short_name(&self) -> &'static str;
async fn fetch_latest_server_version(
&self,
delegate: &dyn LspAdapterDelegate,
@@ -288,8 +290,8 @@ pub trait LspAdapter: 'static + Send + Sync {
None
}
fn workspace_configuration(&self, _: &mut AppContext) -> Option<BoxFuture<'static, Value>> {
None
fn workspace_configuration(&self, _: &mut AppContext) -> BoxFuture<'static, Value> {
futures::future::ready(serde_json::json!({})).boxed()
}
fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
@@ -344,6 +346,8 @@ pub struct LanguageConfig {
#[serde(default)]
pub block_comment: Option<(Arc<str>, Arc<str>)>,
#[serde(default)]
pub scope_opt_in_language_servers: Vec<String>,
#[serde(default)]
pub overrides: HashMap<String, LanguageConfigOverride>,
#[serde(default)]
pub word_characters: HashSet<char>,
@@ -374,6 +378,10 @@ pub struct LanguageConfigOverride {
pub block_comment: Override<(Arc<str>, Arc<str>)>,
#[serde(skip_deserializing)]
pub disabled_bracket_ixs: Vec<u16>,
#[serde(default)]
pub word_characters: Override<HashSet<char>>,
#[serde(default)]
pub opt_into_language_servers: Vec<String>,
}
#[derive(Clone, Deserialize, Debug)]
@@ -412,6 +420,7 @@ impl Default for LanguageConfig {
autoclose_before: Default::default(),
line_comment: Default::default(),
block_comment: Default::default(),
scope_opt_in_language_servers: Default::default(),
overrides: Default::default(),
collapsed_placeholder: Default::default(),
word_characters: Default::default(),
@@ -686,41 +695,6 @@ impl LanguageRegistry {
result
}
pub fn workspace_configuration(&self, cx: &mut AppContext) -> Task<serde_json::Value> {
let lsp_adapters = {
let state = self.state.read();
state
.available_languages
.iter()
.filter(|l| !l.loaded)
.flat_map(|l| l.lsp_adapters.clone())
.chain(
state
.languages
.iter()
.flat_map(|language| &language.adapters)
.map(|adapter| adapter.adapter.clone()),
)
.collect::<Vec<_>>()
};
let mut language_configs = Vec::new();
for adapter in &lsp_adapters {
if let Some(language_config) = adapter.workspace_configuration(cx) {
language_configs.push(language_config);
}
}
cx.background().spawn(async move {
let mut config = serde_json::json!({});
let language_configs = futures::future::join_all(language_configs).await;
for language_config in language_configs {
merge_json_value_into(language_config, &mut config);
}
config
})
}
pub fn add(&self, language: Arc<Language>) {
self.state.write().add(language);
}
@@ -1384,13 +1358,23 @@ impl Language {
Ok(self)
}
pub fn with_override_query(mut self, source: &str) -> Result<Self> {
pub fn with_override_query(mut self, source: &str) -> anyhow::Result<Self> {
let query = Query::new(self.grammar_mut().ts_language, source)?;
let mut override_configs_by_id = HashMap::default();
for (ix, name) in query.capture_names().iter().enumerate() {
if !name.starts_with('_') {
let value = self.config.overrides.remove(name).unwrap_or_default();
for server_name in &value.opt_into_language_servers {
if !self
.config
.scope_opt_in_language_servers
.contains(server_name)
{
util::debug_panic!("Server {server_name:?} has been opted-in by scope {name:?} but has not been marked as an opt-in server");
}
}
override_configs_by_id.insert(ix as u32, (name.clone(), value));
}
}
@@ -1596,6 +1580,13 @@ impl LanguageScope {
.map(|e| (&e.0, &e.1))
}
pub fn word_characters(&self) -> Option<&HashSet<char>> {
Override::as_option(
self.config_override().map(|o| &o.word_characters),
Some(&self.language.config.word_characters),
)
}
pub fn brackets(&self) -> impl Iterator<Item = (&BracketPair, bool)> {
let mut disabled_ids = self
.config_override()
@@ -1622,6 +1613,20 @@ impl LanguageScope {
c.is_whitespace() || self.language.config.autoclose_before.contains(c)
}
pub fn language_allowed(&self, name: &LanguageServerName) -> bool {
let config = &self.language.config;
let opt_in_servers = &config.scope_opt_in_language_servers;
if opt_in_servers.iter().any(|o| *o == *name.0) {
if let Some(over) = self.config_override() {
over.opt_into_language_servers.iter().any(|o| *o == *name.0)
} else {
false
}
} else {
true
}
}
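// A standalone sketch (not part of the diff above) of the opt-in rule in
// language_allowed: a server named in the language's
// scope_opt_in_language_servers list is only allowed when the current override
// scope opts back into it; servers that never opted in stay allowed everywhere.
fn language_allowed(opt_in_servers: &[&str], scope_opt_ins: Option<&[&str]>, server: &str) -> bool {
    if opt_in_servers.contains(&server) {
        scope_opt_ins.map_or(false, |opted| opted.contains(&server))
    } else {
        true
    }
}

fn main() {
    // Hypothetical server names, purely for illustration.
    assert!(language_allowed(&[], None, "regular-server"));
    assert!(!language_allowed(&["opt-in-server"], None, "opt-in-server"));
    assert!(language_allowed(
        &["opt-in-server"],
        Some(["opt-in-server"].as_slice()),
        "opt-in-server"
    ));
}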
fn config_override(&self) -> Option<&LanguageConfigOverride> {
let id = self.override_id?;
let grammar = self.language.grammar.as_ref()?;
@@ -1726,6 +1731,10 @@ impl LspAdapter for Arc<FakeLspAdapter> {
LanguageServerName(self.name.into())
}
fn short_name(&self) -> &'static str {
"FakeLspAdapter"
}
async fn fetch_latest_server_version(
&self,
_: &dyn LspAdapterDelegate,

View File

@@ -41,24 +41,22 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation {
proto::operation::Variant::Edit(serialize_edit_operation(edit))
}
crate::Operation::Buffer(text::Operation::Undo {
undo,
lamport_timestamp,
}) => proto::operation::Variant::Undo(proto::operation::Undo {
replica_id: undo.id.replica_id as u32,
local_timestamp: undo.id.value,
lamport_timestamp: lamport_timestamp.value,
version: serialize_version(&undo.version),
counts: undo
.counts
.iter()
.map(|(edit_id, count)| proto::UndoCount {
replica_id: edit_id.replica_id as u32,
local_timestamp: edit_id.value,
count: *count,
})
.collect(),
}),
crate::Operation::Buffer(text::Operation::Undo(undo)) => {
proto::operation::Variant::Undo(proto::operation::Undo {
replica_id: undo.timestamp.replica_id as u32,
lamport_timestamp: undo.timestamp.value,
version: serialize_version(&undo.version),
counts: undo
.counts
.iter()
.map(|(edit_id, count)| proto::UndoCount {
replica_id: edit_id.replica_id as u32,
lamport_timestamp: edit_id.value,
count: *count,
})
.collect(),
})
}
crate::Operation::UpdateSelections {
selections,
@@ -101,8 +99,7 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation {
pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::Edit {
proto::operation::Edit {
replica_id: operation.timestamp.replica_id as u32,
local_timestamp: operation.timestamp.local,
lamport_timestamp: operation.timestamp.lamport,
lamport_timestamp: operation.timestamp.value,
version: serialize_version(&operation.version),
ranges: operation.ranges.iter().map(serialize_range).collect(),
new_text: operation
@@ -114,7 +111,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::
}
pub fn serialize_undo_map_entry(
(edit_id, counts): (&clock::Local, &[(clock::Local, u32)]),
(edit_id, counts): (&clock::Lamport, &[(clock::Lamport, u32)]),
) -> proto::UndoMapEntry {
proto::UndoMapEntry {
replica_id: edit_id.replica_id as u32,
@@ -123,13 +120,38 @@ pub fn serialize_undo_map_entry(
.iter()
.map(|(undo_id, count)| proto::UndoCount {
replica_id: undo_id.replica_id as u32,
local_timestamp: undo_id.value,
lamport_timestamp: undo_id.value,
count: *count,
})
.collect(),
}
}
pub fn split_operations(
mut operations: Vec<proto::Operation>,
) -> impl Iterator<Item = Vec<proto::Operation>> {
#[cfg(any(test, feature = "test-support"))]
const CHUNK_SIZE: usize = 5;
#[cfg(not(any(test, feature = "test-support")))]
const CHUNK_SIZE: usize = 100;
let mut done = false;
std::iter::from_fn(move || {
if done {
return None;
}
let operations = operations
.drain(..std::cmp::min(CHUNK_SIZE, operations.len()))
.collect::<Vec<_>>();
if operations.is_empty() {
done = true;
}
Some(operations)
})
}
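// A behavior sketch (not part of the diff above) of split_operations: the
// queue is drained in CHUNK_SIZE batches and, because `done` only flips when a
// drain comes up empty, the iterator always ends with one trailing empty chunk.
fn split(mut ops: Vec<u32>, chunk_size: usize) -> Vec<Vec<u32>> {
    let mut done = false;
    std::iter::from_fn(move || {
        if done {
            return None;
        }
        let chunk: Vec<u32> = ops.drain(..std::cmp::min(chunk_size, ops.len())).collect();
        if chunk.is_empty() {
            done = true;
        }
        Some(chunk)
    })
    .collect()
}

fn main() {
    // Seven operations with the test-support chunk size of 5.
    assert_eq!(
        split((0..7).collect(), 5),
        vec![vec![0, 1, 2, 3, 4], vec![5, 6], vec![]]
    );
}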
pub fn serialize_selections(selections: &Arc<[Selection<Anchor>]>) -> Vec<proto::Selection> {
selections.iter().map(serialize_selection).collect()
}
@@ -197,7 +219,7 @@ pub fn serialize_diagnostics<'a>(
pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor {
proto::Anchor {
replica_id: anchor.timestamp.replica_id as u32,
local_timestamp: anchor.timestamp.value,
timestamp: anchor.timestamp.value,
offset: anchor.offset as u64,
bias: match anchor.bias {
Bias::Left => proto::Bias::Left as i32,
@@ -218,32 +240,26 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operati
crate::Operation::Buffer(text::Operation::Edit(deserialize_edit_operation(edit)))
}
proto::operation::Variant::Undo(undo) => {
crate::Operation::Buffer(text::Operation::Undo {
lamport_timestamp: clock::Lamport {
crate::Operation::Buffer(text::Operation::Undo(UndoOperation {
timestamp: clock::Lamport {
replica_id: undo.replica_id as ReplicaId,
value: undo.lamport_timestamp,
},
undo: UndoOperation {
id: clock::Local {
replica_id: undo.replica_id as ReplicaId,
value: undo.local_timestamp,
},
version: deserialize_version(&undo.version),
counts: undo
.counts
.into_iter()
.map(|c| {
(
clock::Local {
replica_id: c.replica_id as ReplicaId,
value: c.local_timestamp,
},
c.count,
)
})
.collect(),
},
})
version: deserialize_version(&undo.version),
counts: undo
.counts
.into_iter()
.map(|c| {
(
clock::Lamport {
replica_id: c.replica_id as ReplicaId,
value: c.lamport_timestamp,
},
c.count,
)
})
.collect(),
}))
}
proto::operation::Variant::UpdateSelections(message) => {
let selections = message
@@ -298,10 +314,9 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operati
pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation {
EditOperation {
timestamp: InsertionTimestamp {
timestamp: clock::Lamport {
replica_id: edit.replica_id as ReplicaId,
local: edit.local_timestamp,
lamport: edit.lamport_timestamp,
value: edit.lamport_timestamp,
},
version: deserialize_version(&edit.version),
ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
@@ -311,9 +326,9 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation
pub fn deserialize_undo_map_entry(
entry: proto::UndoMapEntry,
) -> (clock::Local, Vec<(clock::Local, u32)>) {
) -> (clock::Lamport, Vec<(clock::Lamport, u32)>) {
(
clock::Local {
clock::Lamport {
replica_id: entry.replica_id as u16,
value: entry.local_timestamp,
},
@@ -322,9 +337,9 @@ pub fn deserialize_undo_map_entry(
.into_iter()
.map(|undo_count| {
(
clock::Local {
clock::Lamport {
replica_id: undo_count.replica_id as u16,
value: undo_count.local_timestamp,
value: undo_count.lamport_timestamp,
},
undo_count.count,
)
@@ -384,9 +399,9 @@ pub fn deserialize_diagnostics(
pub fn deserialize_anchor(anchor: proto::Anchor) -> Option<Anchor> {
Some(Anchor {
timestamp: clock::Local {
timestamp: clock::Lamport {
replica_id: anchor.replica_id as ReplicaId,
value: anchor.local_timestamp,
value: anchor.timestamp,
},
offset: anchor.offset as usize,
bias: match proto::Bias::from_i32(anchor.bias)? {
@@ -434,6 +449,7 @@ pub fn serialize_completion(completion: &Completion) -> proto::Completion {
old_start: Some(serialize_anchor(&completion.old_range.start)),
old_end: Some(serialize_anchor(&completion.old_range.end)),
new_text: completion.new_text.clone(),
server_id: completion.server_id.0 as u64,
lsp_completion: serde_json::to_vec(&completion.lsp_completion).unwrap(),
}
}
@@ -466,6 +482,7 @@ pub async fn deserialize_completion(
lsp_completion.filter_text.as_deref(),
)
}),
server_id: LanguageServerId(completion.server_id as usize),
lsp_completion,
})
}
@@ -498,12 +515,12 @@ pub fn deserialize_code_action(action: proto::CodeAction) -> Result<CodeAction>
pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
proto::Transaction {
id: Some(serialize_local_timestamp(transaction.id)),
id: Some(serialize_timestamp(transaction.id)),
edit_ids: transaction
.edit_ids
.iter()
.copied()
.map(serialize_local_timestamp)
.map(serialize_timestamp)
.collect(),
start: serialize_version(&transaction.start),
}
@@ -511,7 +528,7 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transaction> {
Ok(Transaction {
id: deserialize_local_timestamp(
id: deserialize_timestamp(
transaction
.id
.ok_or_else(|| anyhow!("missing transaction id"))?,
@@ -519,21 +536,21 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transa
edit_ids: transaction
.edit_ids
.into_iter()
.map(deserialize_local_timestamp)
.map(deserialize_timestamp)
.collect(),
start: deserialize_version(&transaction.start),
})
}
pub fn serialize_local_timestamp(timestamp: clock::Local) -> proto::LocalTimestamp {
proto::LocalTimestamp {
pub fn serialize_timestamp(timestamp: clock::Lamport) -> proto::LamportTimestamp {
proto::LamportTimestamp {
replica_id: timestamp.replica_id as u32,
value: timestamp.value,
}
}
pub fn deserialize_local_timestamp(timestamp: proto::LocalTimestamp) -> clock::Local {
clock::Local {
pub fn deserialize_timestamp(timestamp: proto::LamportTimestamp) -> clock::Lamport {
clock::Lamport {
replica_id: timestamp.replica_id as ReplicaId,
value: timestamp.value,
}
@@ -553,7 +570,7 @@ pub fn deserialize_range(range: proto::Range) -> Range<FullOffset> {
pub fn deserialize_version(message: &[proto::VectorClockEntry]) -> clock::Global {
let mut version = clock::Global::new();
for entry in message {
version.observe(clock::Local {
version.observe(clock::Lamport {
replica_id: entry.replica_id as ReplicaId,
value: entry.timestamp,
});

View File

@@ -52,6 +52,7 @@ impl View for ActiveBufferLanguage {
} else {
"Unknown".to_string()
};
let theme = theme::current(cx).clone();
MouseEventHandler::new::<Self, _>(0, cx, |state, cx| {
let theme = &theme::current(cx).workspace.status_bar;
@@ -68,6 +69,7 @@ impl View for ActiveBufferLanguage {
});
}
})
.with_tooltip::<Self>(0, "Select Language", None, theme.tooltip.clone(), cx)
.into_any()
} else {
Empty::new().into_any()

View File

@@ -570,10 +570,12 @@ impl View for LspLogToolbarItemView {
let Some(log_view) = self.log_view.as_ref() else {
return Empty::new().into_any();
};
let log_view = log_view.read(cx);
let menu_rows = log_view.menu_items(cx).unwrap_or_default();
let (menu_rows, current_server_id) = log_view.update(cx, |log_view, cx| {
let menu_rows = log_view.menu_items(cx).unwrap_or_default();
let current_server_id = log_view.current_server_id;
(menu_rows, current_server_id)
});
let current_server_id = log_view.current_server_id;
let current_server = current_server_id.and_then(|current_server_id| {
if let Ok(ix) = menu_rows.binary_search_by_key(&current_server_id, |e| e.server_id) {
Some(menu_rows[ix].clone())
@@ -581,10 +583,10 @@ impl View for LspLogToolbarItemView {
None
}
});
let server_selected = current_server.is_some();
enum Menu {}
Stack::new()
let lsp_menu = Stack::new()
.with_child(Self::render_language_server_menu_header(
current_server,
&theme,
@@ -631,8 +633,47 @@ impl View for LspLogToolbarItemView {
})
.aligned()
.left()
.clipped()
.into_any()
.clipped();
enum LspCleanupButton {}
let log_cleanup_button =
MouseEventHandler::new::<LspCleanupButton, _>(1, cx, |state, cx| {
let theme = theme::current(cx).clone();
let style = theme
.workspace
.toolbar
.toggleable_text_tool
.in_state(server_selected)
.style_for(state);
Label::new("Clear", style.text.clone())
.aligned()
.contained()
.with_style(style.container)
.constrained()
.with_height(theme.toolbar_dropdown_menu.row_height / 6.0 * 5.0)
})
.on_click(MouseButton::Left, move |_, this, cx| {
if let Some(log_view) = this.log_view.as_ref() {
log_view.update(cx, |log_view, cx| {
log_view.editor.update(cx, |editor, cx| {
editor.set_read_only(false);
editor.clear(cx);
editor.set_read_only(true);
});
})
}
})
.with_cursor_style(CursorStyle::PointingHand)
.aligned()
.right();
Flex::row()
.with_child(lsp_menu)
.with_child(log_cleanup_button)
.contained()
.aligned()
.left()
.into_any_named("lsp log controls")
}
}

View File

@@ -63,6 +63,7 @@ fn build_bridge(swift_target: &SwiftTarget) {
let swift_target_folder = swift_target_folder();
if !Command::new("swift")
.arg("build")
.arg("--disable-automatic-resolution")
.args(["--configuration", &env::var("PROFILE").unwrap()])
.args(["--triple", &swift_target.target.triple])
.args(["--build-path".into(), swift_target_folder])

View File

@@ -20,7 +20,7 @@ anyhow.workspace = true
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553", optional = true }
futures.workspace = true
log.workspace = true
lsp-types = "0.94"
lsp-types = { git = "https://github.com/zed-industries/lsp-types", branch = "updated-completion-list-item-defaults" }
parking_lot.workspace = true
postage.workspace = true
serde.workspace = true

View File

@@ -4,7 +4,7 @@ pub use lsp_types::*;
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite, FutureExt};
use gpui::{executor, AsyncAppContext, Task};
use parking_lot::Mutex;
use postage::{barrier, prelude::Stream};
@@ -26,12 +26,14 @@ use std::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc, Weak,
},
time::{Duration, Instant},
};
use std::{path::Path, process::Stdio};
use util::{ResultExt, TryFutureExt};
const JSON_RPC_VERSION: &str = "2.0";
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
@@ -303,7 +305,7 @@ impl LanguageServer {
stdout.read_exact(&mut buffer).await?;
if let Ok(message) = str::from_utf8(&buffer) {
log::trace!("incoming message:{}", message);
log::trace!("incoming message: {}", message);
for handler in io_handlers.lock().values_mut() {
handler(IoKind::StdOut, message);
}
@@ -468,6 +470,14 @@ impl LanguageServer {
}),
..Default::default()
}),
completion_list: Some(CompletionListCapability {
item_defaults: Some(vec![
"commitCharacters".to_owned(),
"editRange".to_owned(),
"insertTextMode".to_owned(),
"data".to_owned(),
]),
}),
..Default::default()
}),
rename: Some(RenameClientCapabilities {
@@ -740,7 +750,7 @@ impl LanguageServer {
outbound_tx: &channel::Sender<String>,
executor: &Arc<executor::Background>,
params: T::Params,
) -> impl 'static + Future<Output = Result<T::Result>>
) -> impl 'static + Future<Output = anyhow::Result<T::Result>>
where
T::Result: 'static + Send,
{
@@ -781,10 +791,25 @@ impl LanguageServer {
.try_send(message)
.context("failed to write to language server's stdin");
let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse();
let started = Instant::now();
async move {
handle_response?;
send?;
rx.await?
let method = T::METHOD;
futures::select! {
response = rx.fuse() => {
let elapsed = started.elapsed();
log::trace!("Took {elapsed:?} to recieve response to {method:?} id {id}");
response?
}
_ = timeout => {
log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}");
anyhow::bail!("LSP request timeout");
}
}
}
}
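// A standalone sketch (not part of the diff above) of the same
// request-vs-timeout race, using smol's timer and executor in place of the
// crate's background executor: whichever future finishes first decides whether
// the caller sees a response or a timeout error.
use futures::{channel::oneshot, FutureExt};
use std::time::Duration;

fn main() {
    smol::block_on(async {
        let (tx, rx) = oneshot::channel::<&str>();
        smol::spawn(async move {
            smol::Timer::after(Duration::from_millis(10)).await;
            tx.send("response").ok();
        })
        .detach();

        let mut timeout = smol::Timer::after(Duration::from_secs(2)).fuse();
        futures::select! {
            response = rx.fuse() => println!("got: {:?}", response),
            _ = timeout => eprintln!("request timed out"),
        }
    });
}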

View File

@@ -14,6 +14,7 @@ util = { path = "../util" }
async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
async-tar = "0.4.2"
futures.workspace = true
async-trait.workspace = true
anyhow.workspace = true
parking_lot.workspace = true
serde.workspace = true

View File

@@ -7,14 +7,12 @@ use std::process::{Output, Stdio};
use std::{
env::consts,
path::{Path, PathBuf},
sync::{Arc, OnceLock},
sync::Arc,
};
use util::http::HttpClient;
const VERSION: &str = "v18.15.0";
static RUNTIME_INSTANCE: OnceLock<Arc<NodeRuntime>> = OnceLock::new();
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct NpmInfo {
@@ -28,23 +26,88 @@ pub struct NpmInfoDistTags {
latest: Option<String>,
}
pub struct NodeRuntime {
#[async_trait::async_trait]
pub trait NodeRuntime: Send + Sync {
async fn binary_path(&self) -> Result<PathBuf>;
async fn run_npm_subcommand(
&self,
directory: Option<&Path>,
subcommand: &str,
args: &[&str],
) -> Result<Output>;
async fn npm_package_latest_version(&self, name: &str) -> Result<String>;
async fn npm_install_packages(&self, directory: &Path, packages: &[(&str, &str)])
-> Result<()>;
}
pub struct RealNodeRuntime {
http: Arc<dyn HttpClient>,
}
impl NodeRuntime {
pub fn instance(http: Arc<dyn HttpClient>) -> Arc<NodeRuntime> {
RUNTIME_INSTANCE
.get_or_init(|| Arc::new(NodeRuntime { http }))
.clone()
impl RealNodeRuntime {
pub fn new(http: Arc<dyn HttpClient>) -> Arc<dyn NodeRuntime> {
Arc::new(RealNodeRuntime { http })
}
pub async fn binary_path(&self) -> Result<PathBuf> {
async fn install_if_needed(&self) -> Result<PathBuf> {
log::info!("Node runtime install_if_needed");
let arch = match consts::ARCH {
"x86_64" => "x64",
"aarch64" => "arm64",
other => bail!("Running on unsupported platform: {other}"),
};
let folder_name = format!("node-{VERSION}-darwin-{arch}");
let node_containing_dir = util::paths::SUPPORT_DIR.join("node");
let node_dir = node_containing_dir.join(folder_name);
let node_binary = node_dir.join("bin/node");
let npm_file = node_dir.join("bin/npm");
let result = Command::new(&node_binary)
.arg(npm_file)
.arg("--version")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.status()
.await;
let valid = matches!(result, Ok(status) if status.success());
if !valid {
_ = fs::remove_dir_all(&node_containing_dir).await;
fs::create_dir(&node_containing_dir)
.await
.context("error creating node containing dir")?;
let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz");
let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}");
let mut response = self
.http
.get(&url, Default::default(), true)
.await
.context("error downloading Node binary tarball")?;
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
let archive = Archive::new(decompressed_bytes);
archive.unpack(&node_containing_dir).await?;
}
anyhow::Ok(node_dir)
}
}
#[async_trait::async_trait]
impl NodeRuntime for RealNodeRuntime {
async fn binary_path(&self) -> Result<PathBuf> {
let installation_path = self.install_if_needed().await?;
Ok(installation_path.join("bin/node"))
}
pub async fn run_npm_subcommand(
async fn run_npm_subcommand(
&self,
directory: Option<&Path>,
subcommand: &str,
@@ -106,7 +169,7 @@ impl NodeRuntime {
output.map_err(|e| anyhow!("{e}"))
}
pub async fn npm_package_latest_version(&self, name: &str) -> Result<String> {
async fn npm_package_latest_version(&self, name: &str) -> Result<String> {
let output = self
.run_npm_subcommand(
None,
@@ -131,10 +194,10 @@ impl NodeRuntime {
.ok_or_else(|| anyhow!("no version found for npm package {}", name))
}
pub async fn npm_install_packages(
async fn npm_install_packages(
&self,
directory: &Path,
packages: impl IntoIterator<Item = (&str, &str)>,
packages: &[(&str, &str)],
) -> Result<()> {
let packages: Vec<_> = packages
.into_iter()
@@ -155,51 +218,31 @@ impl NodeRuntime {
.await?;
Ok(())
}
}
async fn install_if_needed(&self) -> Result<PathBuf> {
log::info!("Node runtime install_if_needed");
pub struct FakeNodeRuntime;
let arch = match consts::ARCH {
"x86_64" => "x64",
"aarch64" => "arm64",
other => bail!("Running on unsupported platform: {other}"),
};
let folder_name = format!("node-{VERSION}-darwin-{arch}");
let node_containing_dir = util::paths::SUPPORT_DIR.join("node");
let node_dir = node_containing_dir.join(folder_name);
let node_binary = node_dir.join("bin/node");
let npm_file = node_dir.join("bin/npm");
let result = Command::new(&node_binary)
.arg(npm_file)
.arg("--version")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.status()
.await;
let valid = matches!(result, Ok(status) if status.success());
if !valid {
_ = fs::remove_dir_all(&node_containing_dir).await;
fs::create_dir(&node_containing_dir)
.await
.context("error creating node containing dir")?;
let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz");
let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}");
let mut response = self
.http
.get(&url, Default::default(), true)
.await
.context("error downloading Node binary tarball")?;
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
let archive = Archive::new(decompressed_bytes);
archive.unpack(&node_containing_dir).await?;
}
anyhow::Ok(node_dir)
impl FakeNodeRuntime {
pub fn new() -> Arc<dyn NodeRuntime> {
Arc::new(FakeNodeRuntime)
}
}
#[async_trait::async_trait]
impl NodeRuntime for FakeNodeRuntime {
async fn binary_path(&self) -> Result<PathBuf> {
unreachable!()
}
async fn run_npm_subcommand(&self, _: Option<&Path>, _: &str, _: &[&str]) -> Result<Output> {
unreachable!()
}
async fn npm_package_latest_version(&self, _: &str) -> Result<String> {
unreachable!()
}
async fn npm_install_packages(&self, _: &Path, _: &[(&str, &str)]) -> Result<()> {
unreachable!()
}
}
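// A sketch of how the trait object is meant to be consumed (the helper below
// is assumed for illustration, not part of the diff): callers take an
// Arc<dyn NodeRuntime>, so tests can be handed FakeNodeRuntime and never
// download Node or shell out to npm.
fn node_runtime(for_tests: bool, http: Arc<dyn HttpClient>) -> Arc<dyn NodeRuntime> {
    if for_tests {
        FakeNodeRuntime::new()
    } else {
        RealNodeRuntime::new(http)
    }
}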

View File

@@ -16,7 +16,10 @@ use language::{
CodeAction, Completion, OffsetRangeExt, PointUtf16, ToOffset, ToPointUtf16, Transaction,
Unclipped,
};
use lsp::{DocumentHighlightKind, LanguageServer, LanguageServerId, OneOf, ServerCapabilities};
use lsp::{
CompletionListItemDefaultsEditRange, DocumentHighlightKind, LanguageServer, LanguageServerId,
OneOf, ServerCapabilities,
};
use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
use text::LineEnding;
@@ -1340,13 +1343,19 @@ impl LspCommand for GetCompletions {
completions: Option<lsp::CompletionResponse>,
_: ModelHandle<Project>,
buffer: ModelHandle<Buffer>,
_: LanguageServerId,
server_id: LanguageServerId,
cx: AsyncAppContext,
) -> Result<Vec<Completion>> {
let mut response_list = None;
let completions = if let Some(completions) = completions {
match completions {
lsp::CompletionResponse::Array(completions) => completions,
lsp::CompletionResponse::List(list) => list.items,
lsp::CompletionResponse::List(mut list) => {
let items = std::mem::take(&mut list.items);
response_list = Some(list);
items
}
}
} else {
Default::default()
@@ -1356,6 +1365,7 @@ impl LspCommand for GetCompletions {
let language = buffer.language().cloned();
let snapshot = buffer.snapshot();
let clipped_position = buffer.clip_point_utf16(Unclipped(self.position), Bias::Left);
let mut range_for_token = None;
completions
.into_iter()
@@ -1376,6 +1386,7 @@ impl LspCommand for GetCompletions {
edit.new_text.clone(),
)
}
// If the language server does not provide a range, then infer
// the range based on the syntax tree.
None => {
@@ -1383,27 +1394,51 @@ impl LspCommand for GetCompletions {
log::info!("completion out of expected range");
return None;
}
let Range { start, end } = range_for_token
.get_or_insert_with(|| {
let offset = self.position.to_offset(&snapshot);
let (range, kind) = snapshot.surrounding_word(offset);
if kind == Some(CharKind::Word) {
range
} else {
offset..offset
}
})
.clone();
let default_edit_range = response_list
.as_ref()
.and_then(|list| list.item_defaults.as_ref())
.and_then(|defaults| defaults.edit_range.as_ref())
.and_then(|range| match range {
CompletionListItemDefaultsEditRange::Range(r) => Some(r),
_ => None,
});
let range = if let Some(range) = default_edit_range {
let range = range_from_lsp(range.clone());
let start = snapshot.clip_point_utf16(range.start, Bias::Left);
let end = snapshot.clip_point_utf16(range.end, Bias::Left);
if start != range.start.0 || end != range.end.0 {
log::info!("completion out of expected range");
return None;
}
snapshot.anchor_before(start)..snapshot.anchor_after(end)
} else {
range_for_token
.get_or_insert_with(|| {
let offset = self.position.to_offset(&snapshot);
let (range, kind) = snapshot.surrounding_word(offset);
let range = if kind == Some(CharKind::Word) {
range
} else {
offset..offset
};
snapshot.anchor_before(range.start)
..snapshot.anchor_after(range.end)
})
.clone()
};
let text = lsp_completion
.insert_text
.as_ref()
.unwrap_or(&lsp_completion.label)
.clone();
(
snapshot.anchor_before(start)..snapshot.anchor_after(end),
text,
)
(range, text)
}
Some(lsp::CompletionTextEdit::InsertAndReplace(_)) => {
log::info!("unsupported insert/replace completion");
return None;
@@ -1427,6 +1462,7 @@ impl LspCommand for GetCompletions {
lsp_completion.filter_text.as_deref(),
)
}),
server_id,
lsp_completion,
}
})
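// A standalone sketch (not part of the diff above) of the new fallback order
// for a completion's replacement range; the real code resolves anchors against
// the buffer snapshot, but the priority is:
//   1. the item's own text edit range, when the server provides one;
//   2. otherwise the list-wide itemDefaults.editRange;
//   3. otherwise the word surrounding the cursor, computed once and reused.
fn completion_range(
    item_edit: Option<std::ops::Range<usize>>,
    list_default: Option<std::ops::Range<usize>>,
    surrounding_word: std::ops::Range<usize>,
) -> std::ops::Range<usize> {
    item_edit.or(list_default).unwrap_or(surrounding_word)
}

fn main() {
    assert_eq!(completion_range(None, Some(3..5), 2..6), 3..5);
    assert_eq!(completion_range(None, None, 2..6), 2..6);
}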

View File

@@ -35,7 +35,7 @@ use language::{
point_to_lsp,
proto::{
deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version,
serialize_anchor, serialize_version,
serialize_anchor, serialize_version, split_operations,
},
range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction,
CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent,
@@ -156,6 +156,11 @@ struct DelayedDebounced {
cancel_channel: Option<oneshot::Sender<()>>,
}
enum LanguageServerToQuery {
Primary,
Other(LanguageServerId),
}
impl DelayedDebounced {
fn new() -> DelayedDebounced {
DelayedDebounced {
@@ -634,7 +639,7 @@ impl Project {
cx.observe_global::<SettingsStore, _>(Self::on_settings_changed)
],
_maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx),
_maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx),
_maintain_workspace_config: Self::maintain_workspace_config(cx),
active_entry: None,
languages,
client,
@@ -704,7 +709,7 @@ impl Project {
collaborators: Default::default(),
join_project_response_message_id: response.message_id,
_maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx),
_maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx),
_maintain_workspace_config: Self::maintain_workspace_config(cx),
languages,
user_store: user_store.clone(),
fs,
@@ -2472,35 +2477,42 @@ impl Project {
})
}
fn maintain_workspace_config(
languages: Arc<LanguageRegistry>,
cx: &mut ModelContext<Project>,
) -> Task<()> {
fn maintain_workspace_config(cx: &mut ModelContext<Project>) -> Task<()> {
let (mut settings_changed_tx, mut settings_changed_rx) = watch::channel();
let _ = postage::stream::Stream::try_recv(&mut settings_changed_rx);
let settings_observation = cx.observe_global::<SettingsStore, _>(move |_, _| {
*settings_changed_tx.borrow_mut() = ();
});
cx.spawn_weak(|this, mut cx| async move {
while let Some(_) = settings_changed_rx.next().await {
let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await;
if let Some(this) = this.upgrade(&cx) {
this.read_with(&cx, |this, _| {
for server_state in this.language_servers.values() {
if let LanguageServerState::Running { server, .. } = server_state {
server
.notify::<lsp::notification::DidChangeConfiguration>(
lsp::DidChangeConfigurationParams {
settings: workspace_config.clone(),
},
)
.ok();
}
}
})
} else {
let Some(this) = this.upgrade(&cx) else {
break;
};
let servers: Vec<_> = this.read_with(&cx, |this, _| {
this.language_servers
.values()
.filter_map(|state| match state {
LanguageServerState::Starting(_) => None,
LanguageServerState::Running {
adapter, server, ..
} => Some((adapter.clone(), server.clone())),
})
.collect()
});
for (adapter, server) in servers {
let workspace_config =
cx.update(|cx| adapter.workspace_configuration(cx)).await;
server
.notify::<lsp::notification::DidChangeConfiguration>(
lsp::DidChangeConfigurationParams {
settings: workspace_config.clone(),
},
)
.ok();
}
}
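
The rewritten loop above first snapshots the running (adapter, server) pairs while the project is borrowed, then fetches each adapter's workspace configuration and notifies that server outside the borrow, so configuration is now produced per adapter rather than once for the whole language registry. A rough std-only sketch of that shape, with RwLock, Mutex, and plain strings standing in for gpui models and JSON settings:

use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

#[derive(Clone)]
struct Adapter {
    name: &'static str,
}

struct Server {
    last_settings: Mutex<Option<String>>,
}

impl Server {
    fn did_change_configuration(&self, settings: String) {
        *self.last_settings.lock().unwrap() = Some(settings);
    }
}

fn on_settings_changed(servers: &RwLock<HashMap<u64, (Adapter, Arc<Server>)>>) {
    // Snapshot the (adapter, server) pairs so the lock is not held while notifying.
    let snapshot: Vec<_> = servers
        .read()
        .unwrap()
        .values()
        .map(|(adapter, server)| (adapter.clone(), server.clone()))
        .collect();
    for (adapter, server) in snapshot {
        // Each adapter produces its own configuration payload.
        server.did_change_configuration(format!("settings for {}", adapter.name));
    }
}

fn main() {
    let server = Arc::new(Server { last_settings: Mutex::new(None) });
    let servers = RwLock::new(HashMap::from([(
        1u64,
        (Adapter { name: "rust-analyzer" }, server.clone()),
    )]));
    on_settings_changed(&servers);
    assert!(server.last_settings.lock().unwrap().is_some());
}
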
@@ -2615,7 +2627,6 @@ impl Project {
let state = LanguageServerState::Starting({
let adapter = adapter.clone();
let server_name = adapter.name.0.clone();
let languages = self.languages.clone();
let language = language.clone();
let key = key.clone();
@@ -2625,7 +2636,6 @@ impl Project {
initialization_options,
pending_server,
adapter.clone(),
languages,
language.clone(),
server_id,
key,
@@ -2729,7 +2739,6 @@ impl Project {
initialization_options: Option<serde_json::Value>,
pending_server: PendingLanguageServer,
adapter: Arc<CachedLspAdapter>,
languages: Arc<LanguageRegistry>,
language: Arc<Language>,
server_id: LanguageServerId,
key: (WorktreeId, LanguageServerName),
@@ -2740,7 +2749,6 @@ impl Project {
initialization_options,
pending_server,
adapter.clone(),
languages,
server_id,
cx,
);
@@ -2773,16 +2781,13 @@ impl Project {
initialization_options: Option<serde_json::Value>,
pending_server: PendingLanguageServer,
adapter: Arc<CachedLspAdapter>,
languages: Arc<LanguageRegistry>,
server_id: LanguageServerId,
cx: &mut AsyncAppContext,
) -> Result<Option<Arc<LanguageServer>>> {
let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await;
let workspace_config = cx.update(|cx| adapter.workspace_configuration(cx)).await;
let language_server = match pending_server.task.await? {
Some(server) => server.initialize(initialization_options).await?,
None => {
return Ok(None);
}
Some(server) => server,
None => return Ok(None),
};
language_server
@@ -2821,12 +2826,12 @@ impl Project {
language_server
.on_request::<lsp::request::WorkspaceConfiguration, _, _>({
let languages = languages.clone();
let adapter = adapter.clone();
move |params, mut cx| {
let languages = languages.clone();
let adapter = adapter.clone();
async move {
let workspace_config =
cx.update(|cx| languages.workspace_configuration(cx)).await;
cx.update(|cx| adapter.workspace_configuration(cx)).await;
Ok(params
.items
.into_iter()
@@ -2932,6 +2937,8 @@ impl Project {
})
.detach();
let language_server = language_server.initialize(initialization_options).await?;
language_server
.notify::<lsp::notification::DidChangeConfiguration>(
lsp::DidChangeConfigurationParams {
@@ -3892,7 +3899,7 @@ impl Project {
let file = File::from_dyn(buffer.file())?;
let buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
let server = self
.primary_language_servers_for_buffer(buffer, cx)
.primary_language_server_for_buffer(buffer, cx)
.map(|s| s.1.clone());
Some((buffer_handle, buffer_abs_path, server))
})
@@ -4197,7 +4204,12 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<LocationLink>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetDefinition { position }, cx)
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
GetDefinition { position },
cx,
)
}
pub fn type_definition<T: ToPointUtf16>(
@@ -4207,7 +4219,12 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<LocationLink>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetTypeDefinition { position }, cx)
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
GetTypeDefinition { position },
cx,
)
}
pub fn references<T: ToPointUtf16>(
@@ -4217,7 +4234,12 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Location>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetReferences { position }, cx)
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
GetReferences { position },
cx,
)
}
pub fn document_highlights<T: ToPointUtf16>(
@@ -4227,7 +4249,12 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<DocumentHighlight>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetDocumentHighlights { position }, cx)
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
GetDocumentHighlights { position },
cx,
)
}
pub fn symbols(&self, query: &str, cx: &mut ModelContext<Self>) -> Task<Result<Vec<Symbol>>> {
@@ -4455,17 +4482,66 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Option<Hover>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetHover { position }, cx)
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
GetHover { position },
cx,
)
}
pub fn completions<T: ToPointUtf16>(
pub fn completions<T: ToOffset + ToPointUtf16>(
&self,
buffer: &ModelHandle<Buffer>,
position: T,
cx: &mut ModelContext<Self>,
) -> Task<Result<Vec<Completion>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer.clone(), GetCompletions { position }, cx)
if self.is_local() {
let snapshot = buffer.read(cx).snapshot();
let offset = position.to_offset(&snapshot);
let scope = snapshot.language_scope_at(offset);
let server_ids: Vec<_> = self
.language_servers_for_buffer(buffer.read(cx), cx)
.filter(|(_, server)| server.capabilities().completion_provider.is_some())
.filter(|(adapter, _)| {
scope
.as_ref()
.map(|scope| scope.language_allowed(&adapter.name))
.unwrap_or(true)
})
.map(|(_, server)| server.server_id())
.collect();
let buffer = buffer.clone();
cx.spawn(|this, mut cx| async move {
let mut tasks = Vec::with_capacity(server_ids.len());
this.update(&mut cx, |this, cx| {
for server_id in server_ids {
tasks.push(this.request_lsp(
buffer.clone(),
LanguageServerToQuery::Other(server_id),
GetCompletions { position },
cx,
));
}
});
let mut completions = Vec::new();
for task in tasks {
if let Ok(new_completions) = task.await {
completions.extend_from_slice(&new_completions);
}
}
Ok(completions)
})
} else if let Some(project_id) = self.remote_id() {
self.send_lsp_proto_request(buffer.clone(), project_id, GetCompletions { position }, cx)
} else {
Task::ready(Ok(Default::default()))
}
}
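
Project::completions now fans out to every running language server that advertises a completion provider (and that the language scope allows) and concatenates whatever each returns, instead of asking only the primary server. A rough sketch of that fan-out-and-merge shape, using futures::future::join_all in place of gpui tasks and strings in place of Completion values:

use futures::future::join_all;

async fn fan_out_completions<F, Fut>(server_ids: Vec<u64>, mut request: F) -> Vec<String>
where
    F: FnMut(u64) -> Fut,
    Fut: std::future::Future<Output = anyhow::Result<Vec<String>>>,
{
    let tasks: Vec<_> = server_ids.into_iter().map(|id| request(id)).collect();
    let mut completions = Vec::new();
    // Await every server and keep whatever succeeded, ignoring individual failures.
    for result in join_all(tasks).await {
        if let Ok(mut batch) = result {
            completions.append(&mut batch);
        }
    }
    completions
}

fn main() {
    let completions = futures::executor::block_on(fan_out_completions(vec![1, 2], |id| async move {
        anyhow::Ok(vec![format!("item from server {id}")])
    }));
    assert_eq!(completions.len(), 2);
}
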
pub fn apply_additional_edits_for_completion(
@@ -4479,7 +4555,8 @@ impl Project {
let buffer_id = buffer.remote_id();
if self.is_local() {
let lang_server = match self.primary_language_servers_for_buffer(buffer, cx) {
let server_id = completion.server_id;
let lang_server = match self.language_server_for_buffer(buffer, server_id, cx) {
Some((_, server)) => server.clone(),
_ => return Task::ready(Ok(Default::default())),
};
@@ -4586,7 +4663,12 @@ impl Project {
) -> Task<Result<Vec<CodeAction>>> {
let buffer = buffer_handle.read(cx);
let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end);
self.request_lsp(buffer_handle.clone(), GetCodeActions { range }, cx)
self.request_lsp(
buffer_handle.clone(),
LanguageServerToQuery::Primary,
GetCodeActions { range },
cx,
)
}
pub fn apply_code_action(
@@ -4942,7 +5024,12 @@ impl Project {
cx: &mut ModelContext<Self>,
) -> Task<Result<Option<Range<Anchor>>>> {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(buffer, PrepareRename { position }, cx)
self.request_lsp(
buffer,
LanguageServerToQuery::Primary,
PrepareRename { position },
cx,
)
}
pub fn perform_rename<T: ToPointUtf16>(
@@ -4956,6 +5043,7 @@ impl Project {
let position = position.to_point_utf16(buffer.read(cx));
self.request_lsp(
buffer,
LanguageServerToQuery::Primary,
PerformRename {
position,
new_name,
@@ -4983,6 +5071,7 @@ impl Project {
});
self.request_lsp(
buffer.clone(),
LanguageServerToQuery::Primary,
OnTypeFormatting {
position,
trigger,
@@ -5008,7 +5097,12 @@ impl Project {
let lsp_request = InlayHints { range };
if self.is_local() {
let lsp_request_task = self.request_lsp(buffer_handle.clone(), lsp_request, cx);
let lsp_request_task = self.request_lsp(
buffer_handle.clone(),
LanguageServerToQuery::Primary,
lsp_request,
cx,
);
cx.spawn(|_, mut cx| async move {
buffer_handle
.update(&mut cx, |buffer, _| {
@@ -5441,10 +5535,10 @@ impl Project {
.await;
}
// TODO: Wire this up to allow selecting a server?
fn request_lsp<R: LspCommand>(
&self,
buffer_handle: ModelHandle<Buffer>,
server: LanguageServerToQuery,
request: R,
cx: &mut ModelContext<Self>,
) -> Task<Result<R::Response>>
@@ -5453,11 +5547,19 @@ impl Project {
{
let buffer = buffer_handle.read(cx);
if self.is_local() {
let language_server = match server {
LanguageServerToQuery::Primary => {
match self.primary_language_server_for_buffer(buffer, cx) {
Some((_, server)) => Some(Arc::clone(server)),
None => return Task::ready(Ok(Default::default())),
}
}
LanguageServerToQuery::Other(id) => self
.language_server_for_buffer(buffer, id, cx)
.map(|(_, server)| Arc::clone(server)),
};
let file = File::from_dyn(buffer.file()).and_then(File::as_local);
if let Some((file, language_server)) = file.zip(
self.primary_language_servers_for_buffer(buffer, cx)
.map(|(_, server)| server.clone()),
) {
if let (Some(file), Some(language_server)) = (file, language_server) {
let lsp_params = request.to_lsp(&file.abs_path(cx), buffer, &language_server, cx);
return cx.spawn(|this, cx| async move {
if !request.check_capabilities(language_server.capabilities()) {
@@ -5490,31 +5592,40 @@ impl Project {
});
}
} else if let Some(project_id) = self.remote_id() {
let rpc = self.client.clone();
let message = request.to_proto(project_id, buffer);
return cx.spawn_weak(|this, cx| async move {
// Ensure the project is still alive by the time the task
// is scheduled.
this.upgrade(&cx)
.ok_or_else(|| anyhow!("project dropped"))?;
let response = rpc.request(message).await?;
let this = this
.upgrade(&cx)
.ok_or_else(|| anyhow!("project dropped"))?;
if this.read_with(&cx, |this, _| this.is_read_only()) {
Err(anyhow!("disconnected before completing request"))
} else {
request
.response_from_proto(response, this, buffer_handle, cx)
.await
}
});
return self.send_lsp_proto_request(buffer_handle, project_id, request, cx);
}
Task::ready(Ok(Default::default()))
}
fn send_lsp_proto_request<R: LspCommand>(
&self,
buffer: ModelHandle<Buffer>,
project_id: u64,
request: R,
cx: &mut ModelContext<'_, Project>,
) -> Task<anyhow::Result<<R as LspCommand>::Response>> {
let rpc = self.client.clone();
let message = request.to_proto(project_id, buffer.read(cx));
cx.spawn_weak(|this, cx| async move {
// Ensure the project is still alive by the time the task
// is scheduled.
this.upgrade(&cx)
.ok_or_else(|| anyhow!("project dropped"))?;
let response = rpc.request(message).await?;
let this = this
.upgrade(&cx)
.ok_or_else(|| anyhow!("project dropped"))?;
if this.read_with(&cx, |this, _| this.is_read_only()) {
Err(anyhow!("disconnected before completing request"))
} else {
request
.response_from_proto(response, this, buffer, cx)
.await
}
})
}
fn sort_candidates_and_open_buffers(
mut matching_paths_rx: Receiver<SearchMatchCandidate>,
cx: &mut ModelContext<Self>,
@@ -7150,7 +7261,7 @@ impl Project {
let buffer_version = buffer_handle.read_with(&cx, |buffer, _| buffer.version());
let response = this
.update(&mut cx, |this, cx| {
this.request_lsp(buffer_handle, request, cx)
this.request_lsp(buffer_handle, LanguageServerToQuery::Primary, request, cx)
})
.await?;
this.update(&mut cx, |this, cx| {
@@ -7867,7 +7978,7 @@ impl Project {
})
}
fn primary_language_servers_for_buffer(
fn primary_language_server_for_buffer(
&self,
buffer: &Buffer,
cx: &AppContext,
@@ -8089,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate {
}
}
fn split_operations(
mut operations: Vec<proto::Operation>,
) -> impl Iterator<Item = Vec<proto::Operation>> {
#[cfg(any(test, feature = "test-support"))]
const CHUNK_SIZE: usize = 5;
#[cfg(not(any(test, feature = "test-support")))]
const CHUNK_SIZE: usize = 100;
let mut done = false;
std::iter::from_fn(move || {
if done {
return None;
}
let operations = operations
.drain(..cmp::min(CHUNK_SIZE, operations.len()))
.collect::<Vec<_>>();
if operations.is_empty() {
done = true;
}
Some(operations)
})
}
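
The split_operations helper above is deleted from project.rs and now comes from language::proto instead (see the import change earlier in this file). For reference, its chunking pattern drains the vector in fixed-size batches and always yields one final, possibly empty, batch, presumably so a consumer still produces at least one message when there are no operations. A stand-alone sketch over plain integers:

fn split_into_chunks<T>(mut items: Vec<T>, chunk_size: usize) -> impl Iterator<Item = Vec<T>> {
    let mut done = false;
    std::iter::from_fn(move || {
        if done {
            return None;
        }
        // Drain at most `chunk_size` items; an empty batch marks the end.
        let chunk: Vec<T> = items.drain(..chunk_size.min(items.len())).collect();
        if chunk.is_empty() {
            done = true;
        }
        Some(chunk)
    })
}

fn main() {
    let chunks: Vec<Vec<u32>> = split_into_chunks((0..7).collect(), 3).collect();
    assert_eq!(chunks, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6], vec![]]);
}
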
fn serialize_symbol(symbol: &Symbol) -> proto::Symbol {
proto::Symbol {
language_server_name: symbol.language_server_name.0.to_string(),


@@ -2272,7 +2272,18 @@ async fn test_completions_without_edit_ranges(cx: &mut gpui::TestAppContext) {
},
Some(tree_sitter_typescript::language_typescript()),
);
let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await;
let mut fake_language_servers = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![":".to_string()]),
..Default::default()
}),
..Default::default()
},
..Default::default()
}))
.await;
let fs = FakeFs::new(cx.background());
fs.insert_tree(
@@ -2358,7 +2369,18 @@ async fn test_completions_with_carriage_returns(cx: &mut gpui::TestAppContext) {
},
Some(tree_sitter_typescript::language_typescript()),
);
let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await;
let mut fake_language_servers = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![":".to_string()]),
..Default::default()
}),
..Default::default()
},
..Default::default()
}))
.await;
let fs = FakeFs::new(cx.background());
fs.insert_tree(


@@ -225,15 +225,14 @@ impl SearchQuery {
if self.as_str().is_empty() {
return Default::default();
}
let language = buffer.language_at(0);
let range_offset = subrange.as_ref().map(|r| r.start).unwrap_or(0);
let rope = if let Some(range) = subrange {
buffer.as_rope().slice(range)
} else {
buffer.as_rope().clone()
};
let kind = |c| char_kind(language, c);
let mut matches = Vec::new();
match self {
Self::Text {
@@ -249,6 +248,9 @@ impl SearchQuery {
let mat = mat.unwrap();
if *whole_word {
let scope = buffer.language_scope_at(range_offset + mat.start());
let kind = |c| char_kind(&scope, c);
let prev_kind = rope.reversed_chars_at(mat.start()).next().map(kind);
let start_kind = kind(rope.chars_at(mat.start()).next().unwrap());
let end_kind = kind(rope.reversed_chars_at(mat.end()).next().unwrap());


@@ -1,6 +1,8 @@
syntax = "proto3";
package zed.messages;
// Looking for a number? Search "// Current max"
message PeerId {
uint32 owner_id = 1;
uint32 id = 2;
@@ -139,7 +141,7 @@ message Envelope {
RespondToChannelInvite respond_to_channel_invite = 123;
UpdateChannels update_channels = 124;
JoinChannel join_channel = 125;
RemoveChannel remove_channel = 126;
DeleteChannel delete_channel = 126;
GetChannelMembers get_channel_members = 127;
GetChannelMembersResponse get_channel_members_response = 128;
SetChannelMemberAdmin set_channel_member_admin = 129;
@@ -151,6 +153,12 @@ message Envelope {
LeaveChannelBuffer leave_channel_buffer = 134;
AddChannelBufferCollaborator add_channel_buffer_collaborator = 135;
RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136;
UpdateChannelBufferCollaborator update_channel_buffer_collaborator = 139;
RejoinChannelBuffers rejoin_channel_buffers = 140;
RejoinChannelBuffersResponse rejoin_channel_buffers_response = 141;
LinkChannel link_channel = 142;
UnlinkChannel unlink_channel = 143;
MoveChannel move_channel = 144; // Current max
}
}
@@ -430,6 +438,12 @@ message RemoveChannelBufferCollaborator {
PeerId peer_id = 2;
}
message UpdateChannelBufferCollaborator {
uint64 channel_id = 1;
PeerId old_peer_id = 2;
PeerId new_peer_id = 3;
}
message GetDefinition {
uint64 project_id = 1;
uint64 buffer_id = 2;
@@ -616,6 +630,12 @@ message BufferVersion {
repeated VectorClockEntry version = 2;
}
message ChannelBufferVersion {
uint64 channel_id = 1;
repeated VectorClockEntry version = 2;
uint64 epoch = 3;
}
enum FormatTrigger {
Save = 0;
Manual = 1;
@@ -657,7 +677,8 @@ message Completion {
Anchor old_start = 1;
Anchor old_end = 2;
string new_text = 3;
bytes lsp_completion = 4;
uint64 server_id = 4;
bytes lsp_completion = 5;
}
message GetCodeActions {
@@ -860,12 +881,12 @@ message ProjectTransaction {
}
message Transaction {
LocalTimestamp id = 1;
repeated LocalTimestamp edit_ids = 2;
LamportTimestamp id = 1;
repeated LamportTimestamp edit_ids = 2;
repeated VectorClockEntry start = 3;
}
message LocalTimestamp {
message LamportTimestamp {
uint32 replica_id = 1;
uint32 value = 2;
}
@@ -927,11 +948,17 @@ message LspDiskBasedDiagnosticsUpdated {}
message UpdateChannels {
repeated Channel channels = 1;
repeated uint64 remove_channels = 2;
repeated Channel channel_invitations = 3;
repeated uint64 remove_channel_invitations = 4;
repeated ChannelParticipants channel_participants = 5;
repeated ChannelPermission channel_permissions = 6;
repeated ChannelEdge delete_channel_edge = 2;
repeated uint64 delete_channels = 3;
repeated Channel channel_invitations = 4;
repeated uint64 remove_channel_invitations = 5;
repeated ChannelParticipants channel_participants = 6;
repeated ChannelPermission channel_permissions = 7;
}
message ChannelEdge {
uint64 channel_id = 1;
uint64 parent_id = 2;
}
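
UpdateChannels now carries explicit ChannelEdge entries, so clients rebuild the channel DAG from (channel_id, parent_id) pairs rather than a flat channel list. One operation this enables (and that the link/move handling has to reason about) is an ancestry check; the sketch below walks parent edges with the orientation assumed to be child-to-parent, and is illustrative rather than Zed's actual traversal.

use std::collections::{HashMap, HashSet};

// Edges are (channel_id, parent_id) pairs, mirroring the ChannelEdge message.
fn is_descendant(edges: &[(u64, u64)], channel: u64, ancestor: u64) -> bool {
    let mut parents: HashMap<u64, Vec<u64>> = HashMap::new();
    for &(child, parent) in edges {
        parents.entry(child).or_default().push(parent);
    }
    let mut stack = vec![channel];
    let mut visited = HashSet::new();
    while let Some(current) = stack.pop() {
        if current == ancestor {
            return true;
        }
        if visited.insert(current) {
            if let Some(ps) = parents.get(&current) {
                stack.extend(ps.iter().copied());
            }
        }
    }
    false
}

fn main() {
    // Channel 3 has two parents (1 and 2), and 2 is itself a child of 1.
    let edges = [(3, 1), (3, 2), (2, 1)];
    assert!(is_descendant(&edges, 3, 1));
    assert!(is_descendant(&edges, 3, 2));
    assert!(!is_descendant(&edges, 1, 3));
}
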
message ChannelPermission {
@@ -948,7 +975,7 @@ message JoinChannel {
uint64 channel_id = 1;
}
message RemoveChannel {
message DeleteChannel {
uint64 channel_id = 1;
}
@@ -1003,16 +1030,48 @@ message RenameChannel {
string name = 2;
}
message LinkChannel {
uint64 channel_id = 1;
uint64 to = 2;
}
message UnlinkChannel {
uint64 channel_id = 1;
optional uint64 from = 2;
}
message MoveChannel {
uint64 channel_id = 1;
optional uint64 from = 2;
uint64 to = 3;
}
message JoinChannelBuffer {
uint64 channel_id = 1;
}
message RejoinChannelBuffers {
repeated ChannelBufferVersion buffers = 1;
}
message RejoinChannelBuffersResponse {
repeated RejoinedChannelBuffer buffers = 1;
}
message JoinChannelBufferResponse {
uint64 buffer_id = 1;
uint32 replica_id = 2;
string base_text = 3;
repeated Operation operations = 4;
repeated Collaborator collaborators = 5;
uint64 epoch = 6;
}
message RejoinedChannelBuffer {
uint64 channel_id = 1;
repeated VectorClockEntry version = 2;
repeated Operation operations = 3;
repeated Collaborator collaborators = 4;
}
message LeaveChannelBuffer {
@@ -1279,7 +1338,7 @@ message Excerpt {
message Anchor {
uint32 replica_id = 1;
uint32 local_timestamp = 2;
uint32 timestamp = 2;
uint64 offset = 3;
Bias bias = 4;
optional uint64 buffer_id = 5;
@@ -1323,19 +1382,17 @@ message Operation {
message Edit {
uint32 replica_id = 1;
uint32 local_timestamp = 2;
uint32 lamport_timestamp = 3;
repeated VectorClockEntry version = 4;
repeated Range ranges = 5;
repeated string new_text = 6;
uint32 lamport_timestamp = 2;
repeated VectorClockEntry version = 3;
repeated Range ranges = 4;
repeated string new_text = 5;
}
message Undo {
uint32 replica_id = 1;
uint32 local_timestamp = 2;
uint32 lamport_timestamp = 3;
repeated VectorClockEntry version = 4;
repeated UndoCount counts = 5;
uint32 lamport_timestamp = 2;
repeated VectorClockEntry version = 3;
repeated UndoCount counts = 4;
}
message UpdateSelections {
@@ -1361,7 +1418,7 @@ message UndoMapEntry {
message UndoCount {
uint32 replica_id = 1;
uint32 local_timestamp = 2;
uint32 lamport_timestamp = 2;
uint32 count = 3;
}


@@ -229,13 +229,18 @@ messages!(
(StartLanguageServer, Foreground),
(SynchronizeBuffers, Foreground),
(SynchronizeBuffersResponse, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
(Test, Foreground),
(Unfollow, Foreground),
(UnshareProject, Foreground),
(UpdateBuffer, Foreground),
(UpdateBufferFile, Foreground),
(UpdateContacts, Foreground),
(RemoveChannel, Foreground),
(DeleteChannel, Foreground),
(MoveChannel, Foreground),
(LinkChannel, Foreground),
(UnlinkChannel, Foreground),
(UpdateChannels, Foreground),
(UpdateDiagnosticSummary, Foreground),
(UpdateFollowers, Foreground),
@@ -257,6 +262,7 @@ messages!(
(UpdateChannelBuffer, Foreground),
(RemoveChannelBufferCollaborator, Foreground),
(AddChannelBufferCollaborator, Foreground),
(UpdateChannelBufferCollaborator, Foreground),
);
request_messages!(
@@ -312,13 +318,17 @@ request_messages!(
(SetChannelMemberAdmin, Ack),
(GetChannelMembers, GetChannelMembersResponse),
(JoinChannel, JoinRoomResponse),
(RemoveChannel, Ack),
(DeleteChannel, Ack),
(RenameProjectEntry, ProjectEntryResponse),
(RenameChannel, ChannelResponse),
(LinkChannel, Ack),
(UnlinkChannel, Ack),
(MoveChannel, Ack),
(SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(Test, Test),
(UpdateBuffer, Ack),
(UpdateParticipantLocation, Ack),
@@ -386,7 +396,8 @@ entity_messages!(
channel_id,
UpdateChannelBuffer,
RemoveChannelBufferCollaborator,
AddChannelBufferCollaborator
AddChannelBufferCollaborator,
UpdateChannelBufferCollaborator
);
const KIB: usize = 1024;


@@ -6,4 +6,4 @@ pub use conn::Connection;
pub use peer::*;
mod macros;
pub const PROTOCOL_VERSION: u32 = 61;
pub const PROTOCOL_VERSION: u32 = 62;


@@ -12,22 +12,19 @@ use editor::{
SelectAll, MAX_TAB_TITLE_LEN,
};
use futures::StreamExt;
use gpui::platform::PromptLevel;
use gpui::{
actions, elements::*, platform::MouseButton, Action, AnyElement, AnyViewHandle, AppContext,
Entity, ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle,
WeakModelHandle, WeakViewHandle,
actions,
elements::*,
platform::{MouseButton, PromptLevel},
Action, AnyElement, AnyViewHandle, AppContext, Entity, ModelContext, ModelHandle, Subscription,
Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
};
use menu::Confirm;
use postage::stream::Stream;
use project::{
search::{PathMatcher, SearchInputs, SearchQuery},
Entry, Project,
};
use semantic_index::SemanticIndex;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use smallvec::SmallVec;
use std::{
any::{Any, TypeId},
@@ -118,7 +115,7 @@ pub struct ProjectSearchView {
model: ModelHandle<ProjectSearch>,
query_editor: ViewHandle<Editor>,
results_editor: ViewHandle<Editor>,
semantic_state: Option<SemanticSearchState>,
semantic_state: Option<SemanticState>,
semantic_permissioned: Option<bool>,
search_options: SearchOptions,
panels_with_errors: HashSet<InputPanel>,
@@ -131,10 +128,9 @@ pub struct ProjectSearchView {
current_mode: SearchMode,
}
struct SemanticSearchState {
file_count: usize,
outstanding_file_count: usize,
_progress_task: Task<()>,
struct SemanticState {
index_status: SemanticIndexStatus,
_subscription: Subscription,
}
pub struct ProjectSearchBar {
@@ -233,7 +229,7 @@ impl ProjectSearch {
self.search_id += 1;
self.match_ranges.clear();
self.search_history.add(inputs.as_str().to_string());
self.no_results = Some(true);
self.no_results = None;
self.pending_search = Some(cx.spawn(|this, mut cx| async move {
let results = search?.await.log_err()?;
let matches = results
@@ -241,9 +237,10 @@ impl ProjectSearch {
.map(|result| (result.buffer, vec![result.range.start..result.range.start]));
this.update(&mut cx, |this, cx| {
this.no_results = Some(true);
this.excerpts.update(cx, |excerpts, cx| {
excerpts.clear(cx);
})
});
});
for (buffer, ranges) in matches {
let mut match_ranges = this.update(&mut cx, |this, cx| {
@@ -318,19 +315,20 @@ impl View for ProjectSearchView {
}
};
let semantic_status = if let Some(semantic) = &self.semantic_state {
if semantic.outstanding_file_count > 0 {
format!(
"Indexing: {} of {}...",
semantic.file_count - semantic.outstanding_file_count,
semantic.file_count
)
} else {
"Indexing complete".to_string()
let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
let status = semantic.index_status;
match status {
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
SemanticIndexStatus::Indexing { remaining_files } => {
if remaining_files == 0 {
Some(format!("Indexing..."))
} else {
Some(format!("Remaining files to index: {}", remaining_files))
}
}
SemanticIndexStatus::NotIndexed => None,
}
} else {
"Indexing: ...".to_string()
};
});
let minor_text = if let Some(no_results) = model.no_results {
if model.pending_search.is_none() && no_results {
@@ -340,12 +338,16 @@ impl View for ProjectSearchView {
}
} else {
match current_mode {
SearchMode::Semantic => vec![
"".to_owned(),
semantic_status,
"Simply explain the code you are looking to find.".to_owned(),
"ex. 'prompt user for permissions to index their project'".to_owned(),
],
SearchMode::Semantic => {
let mut minor_text = Vec::new();
minor_text.push("".into());
minor_text.extend(semantic_status);
minor_text.push("Simply explain the code you are looking to find.".into());
minor_text.push(
"ex. 'prompt user for permissions to index their project'".into(),
);
minor_text
}
_ => vec![
"".to_owned(),
"Include/exclude specific paths with the filter option.".to_owned(),
@@ -641,40 +643,29 @@ impl ProjectSearchView {
let project = self.model.read(cx).project.clone();
let index_task = semantic_index.update(cx, |semantic_index, cx| {
semantic_index.index_project(project, cx)
semantic_index.update(cx, |semantic_index, cx| {
semantic_index
.index_project(project.clone(), cx)
.detach_and_log_err(cx);
});
cx.spawn(|search_view, mut cx| async move {
let (files_to_index, mut files_remaining_rx) = index_task.await?;
self.semantic_state = Some(SemanticState {
index_status: semantic_index.read(cx).status(&project),
_subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
});
cx.notify();
}
}
search_view.update(&mut cx, |search_view, cx| {
cx.notify();
search_view.semantic_state = Some(SemanticSearchState {
file_count: files_to_index,
outstanding_file_count: files_to_index,
_progress_task: cx.spawn(|search_view, mut cx| async move {
while let Some(count) = files_remaining_rx.recv().await {
search_view
.update(&mut cx, |search_view, cx| {
if let Some(semantic_search_state) =
&mut search_view.semantic_state
{
semantic_search_state.outstanding_file_count = count;
cx.notify();
if count == 0 {
return;
}
}
})
.ok();
}
}),
});
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
fn semantic_index_changed(
&mut self,
semantic_index: ModelHandle<SemanticIndex>,
cx: &mut ViewContext<Self>,
) {
let project = self.model.read(cx).project.clone();
if let Some(semantic_state) = self.semantic_state.as_mut() {
semantic_state.index_status = semantic_index.read(cx).status(&project);
cx.notify();
}
}
@@ -873,7 +864,7 @@ impl ProjectSearchView {
SemanticIndex::global(cx)
.map(|semantic| {
let project = self.model.read(cx).project.clone();
semantic.update(cx, |this, cx| this.project_previously_indexed(project, cx))
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
})
.unwrap_or(Task::ready(Ok(false)))
}
@@ -958,11 +949,7 @@ impl ProjectSearchView {
let mode = self.current_mode;
match mode {
SearchMode::Semantic => {
if let Some(semantic) = &mut self.semantic_state {
if semantic.outstanding_file_count > 0 {
return;
}
if self.semantic_state.is_some() {
if let Some(query) = self.build_search_query(cx) {
self.model
.update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));


@@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
doctest = false
[dependencies]
collections = { path = "../collections" }
gpui = { path = "../gpui" }
language = { path = "../language" }
project = { path = "../project" }
@@ -39,8 +40,10 @@ rand.workspace = true
schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
parse_duration = "2.1.1"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }


@@ -1,20 +1,26 @@
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use crate::{
embedding::Embedding,
parsing::{Span, SpanDigest},
SEMANTIC_INDEX_VERSION,
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
params,
types::{FromSql, FromSqlResult, ValueRef},
};
use rusqlite::params;
use rusqlite::types::Value;
use std::{
cmp::Ordering,
collections::HashMap,
future::Future,
ops::Range,
path::{Path, PathBuf},
rc::Rc,
sync::Arc,
time::SystemTime,
};
use util::TryFutureExt;
#[derive(Debug)]
pub struct FileRecord {
@@ -23,286 +29,366 @@ pub struct FileRecord {
pub mtime: Timestamp,
}
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
return Ok(Embedding(embedding.unwrap()));
}
}
impl FromSql for Sha1 {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if sha1.is_err() {
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
}
return Ok(Sha1(sha1.unwrap()));
}
}
#[derive(Clone)]
pub struct VectorDatabase {
db: rusqlite::Connection,
path: Arc<Path>,
transactions:
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
impl VectorDatabase {
pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
pub async fn new(
fs: Arc<dyn Fs>,
path: Arc<Path>,
executor: Arc<executor::Background>,
) -> Result<Self> {
if let Some(db_directory) = path.parent() {
fs.create_dir(db_directory).await?;
}
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
>();
executor
.spawn({
let path = path.clone();
async move {
let mut connection = rusqlite::Connection::open(&path)?;
connection.pragma_update(None, "journal_mode", "wal")?;
connection.pragma_update(None, "synchronous", "normal")?;
connection.pragma_update(None, "cache_size", 1000000)?;
connection.pragma_update(None, "temp_store", "MEMORY")?;
while let Ok(transaction) = transactions_rx.recv().await {
transaction(&mut connection);
}
anyhow::Ok(())
}
.log_err()
})
.detach();
let this = Self {
db: rusqlite::Connection::open(path.as_path())?,
transactions: transactions_tx,
path,
};
this.initialize_database()?;
this.initialize_database().await?;
Ok(this)
}
fn get_existing_version(&self) -> Result<i64> {
let mut version_query = self
.db
.prepare("SELECT version from semantic_index_config")?;
version_query
.query_row([], |row| Ok(row.get::<_, i64>(0)?))
.map_err(|err| anyhow!("version query failed: {err}"))
pub fn path(&self) -> &Arc<Path> {
&self.path
}
fn initialize_database(&self) -> Result<()> {
rusqlite::vtab::array::load_module(&self.db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
if self
.get_existing_version()
.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
{
log::trace!("vector database schema up to date");
return Ok(());
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
where
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
T: 'static + Send,
{
let (tx, rx) = oneshot::channel();
let transactions = self.transactions.clone();
async move {
if transactions
.send(Box::new(|connection| {
let result = connection
.transaction()
.map_err(|err| anyhow!(err))
.and_then(|transaction| {
let result = f(&transaction)?;
transaction.commit()?;
Ok(result)
});
let _ = tx.send(result);
}))
.await
.is_err()
{
return Err(anyhow!("connection was dropped"))?;
}
rx.await?
}
log::trace!("vector database schema out of date. updating...");
self.db
.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
self.db
.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
self.db
.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
self.db
.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
self.db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
self.db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
self.db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
self.db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
self.db.execute(
"CREATE TABLE documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
sha1 BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
}
pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
self.db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
self.transact(|db| {
rusqlite::vtab::array::load_module(&db)?;
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
let version_query = db.prepare("SELECT version from semantic_index_config");
let version = version_query
.and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
log::trace!("vector database schema up to date");
return Ok(());
}
log::trace!("vector database schema out of date. updating...");
// We renamed the `documents` table to `spans`, so we want to drop
// `documents` without recreating it if it exists.
db.execute("DROP TABLE IF EXISTS documents", [])
.context("failed to drop 'documents' table")?;
db.execute("DROP TABLE IF EXISTS spans", [])
.context("failed to drop 'spans' table")?;
db.execute("DROP TABLE IF EXISTS files", [])
.context("failed to drop 'files' table")?;
db.execute("DROP TABLE IF EXISTS worktrees", [])
.context("failed to drop 'worktrees' table")?;
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
.context("failed to drop 'semantic_index_config' table")?;
// Initialize Vector Databasing Tables
db.execute(
"CREATE TABLE semantic_index_config (
version INTEGER NOT NULL
)",
[],
)?;
db.execute(
"INSERT INTO semantic_index_config (version) VALUES (?1)",
params![SEMANTIC_INDEX_VERSION],
)?;
db.execute(
"CREATE TABLE worktrees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
absolute_path VARCHAR NOT NULL
);
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
",
[],
)?;
db.execute(
"CREATE TABLE files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
worktree_id INTEGER NOT NULL,
relative_path VARCHAR NOT NULL,
mtime_seconds INTEGER NOT NULL,
mtime_nanos INTEGER NOT NULL,
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
)",
[],
)?;
db.execute(
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
[],
)?;
db.execute(
"CREATE TABLE spans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_id INTEGER NOT NULL,
start_byte INTEGER NOT NULL,
end_byte INTEGER NOT NULL,
name VARCHAR NOT NULL,
embedding BLOB NOT NULL,
digest BLOB NOT NULL,
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
log::trace!("vector database initialized with updated schema.");
Ok(())
})
}
pub fn delete_file(
&self,
worktree_id: i64,
delete_path: Arc<Path>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
params![worktree_id, delete_path.to_str()],
)?;
Ok(())
})
}
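
In the rewrite above, every VectorDatabase method funnels through the transact helper: the rusqlite connection is owned by a single background task that receives boxed closures over a channel and hands each result back on a oneshot. The std-only sketch below shows just that ownership pattern, with a plain thread and a Vec standing in for smol and the SQLite connection; it is a simplified illustration, not the actual code.

use std::sync::mpsc;
use std::thread;

type Job = Box<dyn FnOnce(&mut Vec<String>) + Send>;

struct Worker {
    jobs: mpsc::Sender<Job>,
}

impl Worker {
    fn new() -> Self {
        let (tx, rx) = mpsc::channel::<Job>();
        thread::spawn(move || {
            // The worker thread exclusively owns the "connection" (here just a Vec).
            let mut state = Vec::new();
            while let Ok(job) = rx.recv() {
                job(&mut state);
            }
        });
        Worker { jobs: tx }
    }

    fn transact<T, F>(&self, f: F) -> T
    where
        F: FnOnce(&mut Vec<String>) -> T + Send + 'static,
        T: Send + 'static,
    {
        // Each call gets its own return channel, like the oneshot in the real code.
        let (result_tx, result_rx) = mpsc::channel();
        self.jobs
            .send(Box::new(move |state| {
                let _ = result_tx.send(f(state));
            }))
            .expect("worker thread is gone");
        result_rx.recv().expect("worker dropped the result")
    }
}

fn main() {
    let db = Worker::new();
    db.transact(|state| state.push("hello".to_string()));
    assert_eq!(db.transact(|state| state.len()), 1);
}
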
pub fn insert_file(
&self,
worktree_id: i64,
path: PathBuf,
path: Arc<Path>,
mtime: SystemTime,
documents: Vec<Document>,
) -> Result<()> {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
let existing_id = existing_id_query
.query_row(
spans: Vec<Span>,
) -> impl Future<Output = Result<()>> {
self.transact(move |db| {
// Return the existing ID, if both the file and mtime match
let mtime = Timestamp::from(mtime);
db.execute(
"
REPLACE INTO files
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
VALUES (?1, ?2, ?3, ?4)
",
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|row| Ok(row.get::<_, i64>(0)?),
)
.map_err(|err| anyhow!(err));
let file_id = if existing_id.is_ok() {
// If already exists, just return the existing id
existing_id.unwrap()
} else {
// Delete Existing Row
self.db.execute(
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
params![worktree_id, path.to_str()],
)?;
self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
self.db.last_insert_rowid()
};
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
let sha_blob = bincode::serialize(&document.sha1)?;
let file_id = db.last_insert_rowid();
self.db.execute(
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
let mut query = db.prepare(
"
INSERT INTO spans
(file_id, start_byte, end_byte, name, embedding, digest)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)?;
for span in spans {
query.execute(params![
file_id,
document.range.start.to_string(),
document.range.end.to_string(),
document.name,
embedding_blob,
sha_blob
],
)?;
}
span.range.start.to_string(),
span.range.end.to_string(),
span.name,
span.embedding,
span.digest
])?;
}
Ok(())
Ok(())
})
}
pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
pub fn worktree_previously_indexed(
&self,
worktree_root_path: &Path,
) -> impl Future<Output = Result<bool>> {
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
if worktree_id.is_ok() {
return Ok(true);
} else {
return Ok(false);
}
if worktree_id.is_ok() {
return Ok(true);
} else {
return Ok(false);
}
})
}
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
// Check that the absolute path doesnt exist
let mut worktree_query = self
.db
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
})
.map_err(|err| anyhow!(err));
if worktree_id.is_ok() {
return worktree_id;
}
// If worktree_id is Err, insert new worktree
self.db.execute(
"
INSERT into worktrees (absolute_path) VALUES (?1)
pub fn embeddings_for_files(
&self,
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
self.transact(move |db| {
let mut query = db.prepare(
"
SELECT digest, embedding
FROM spans
LEFT JOIN files ON files.id = spans.file_id
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(self.db.last_insert_rowid())
)?;
let mut embeddings_by_digest = HashMap::default();
for (worktree_id, file_paths) in worktree_id_file_paths {
let file_paths = Rc::new(
file_paths
.into_iter()
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
.collect::<Vec<_>>(),
);
let rows = query.query_map(params![worktree_id, file_paths], |row| {
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
})?;
for row in rows {
if let Ok(row) = row {
embeddings_by_digest.insert(row.0, row.1);
}
}
}
Ok(embeddings_by_digest)
})
}
pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
let mut statement = self.db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
pub fn find_or_create_worktree(
&self,
worktree_root_path: Arc<Path>,
) -> impl Future<Output = Result<i64>> {
self.transact(move |db| {
let mut worktree_query =
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
let worktree_id = worktree_query
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
Ok(row.get::<_, i64>(0)?)
});
if worktree_id.is_ok() {
return Ok(worktree_id?);
}
// If worktree_id is Err, insert new worktree
db.execute(
"INSERT into worktrees (absolute_path) VALUES (?1)",
params![worktree_root_path.to_string_lossy()],
)?;
Ok(db.last_insert_rowid())
})
}
pub fn get_file_mtimes(
&self,
worktree_id: i64,
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT relative_path, mtime_seconds, mtime_nanos
FROM files
WHERE worktree_id = ?1
ORDER BY relative_path",
)?;
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
for row in statement.query_map(params![worktree_id], |row| {
Ok((
row.get::<_, String>(0)?.into(),
Timestamp {
seconds: row.get(1)?,
nanos: row.get(2)?,
}
.into(),
))
})? {
let row = row?;
result.insert(row.0, row.1);
}
Ok(result)
})
}
pub fn top_k_search(
&self,
query_embedding: &Vec<f32>,
query_embedding: &Embedding,
limit: usize,
file_ids: &[i64],
) -> Result<Vec<(i64, f32)>> {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
self.for_each_document(file_ids, |id, embedding| {
let similarity = dot(&embedding, &query_embedding);
let ix = match results
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
{
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
let query_embedding = query_embedding.clone();
let file_ids = file_ids.to_vec();
self.transact(move |db| {
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
Self::for_each_span(db, &file_ids, |id, embedding| {
let similarity = embedding.similarity(&query_embedding);
let ix = match results.binary_search_by(|(_, s)| {
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
}) {
Ok(ix) => ix,
Err(ix) => ix,
};
results.insert(ix, (id, similarity));
results.truncate(limit);
})?;
Ok(results)
anyhow::Ok(results)
})
}
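
top_k_search above keeps only the `limit` highest-similarity spans while scanning: each candidate is inserted at its binary-search position in a list held in descending score order, and the list is truncated back to `limit`. The same bookkeeping in isolation, as a small sketch:

fn push_top_k(results: &mut Vec<(i64, f32)>, id: i64, score: f32, limit: usize) {
    // The comparator is written so the list stays sorted by descending score.
    let ix = results
        .binary_search_by(|(_, s)| score.partial_cmp(s).unwrap_or(std::cmp::Ordering::Equal))
        .unwrap_or_else(|ix| ix);
    results.insert(ix, (id, score));
    results.truncate(limit);
}

fn main() {
    let mut results = Vec::new();
    for (id, score) in [(1, 0.2), (2, 0.9), (3, 0.5), (4, 0.7)] {
        push_top_k(&mut results, id, score, 2);
    }
    // Only the two best-scoring ids survive, best first.
    assert_eq!(results.iter().map(|(id, _)| *id).collect::<Vec<_>>(), vec![2, 4]);
}
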
pub fn retrieve_included_file_ids(
@@ -310,42 +396,51 @@ impl VectorDatabase {
worktree_ids: &[i64],
includes: &[PathMatcher],
excludes: &[PathMatcher],
) -> Result<Vec<i64>> {
let mut file_query = self.db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
) -> impl Future<Output = Result<Vec<i64>>> {
let worktree_ids = worktree_ids.to_vec();
let includes = includes.to_vec();
let excludes = excludes.to_vec();
self.transact(move |db| {
let mut file_query = db.prepare(
"
SELECT
id, relative_path
FROM
files
WHERE
worktree_id IN rarray(?)
",
)?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
let mut file_ids = Vec::<i64>::new();
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
while let Some(row) = rows.next()? {
let file_id = row.get(0)?;
let relative_path = row.get_ref(1)?.as_str()?;
let included =
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
if included && !excluded {
file_ids.push(file_id);
}
}
}
Ok(file_ids)
anyhow::Ok(file_ids)
})
}
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
let mut query_statement = self.db.prepare(
fn for_each_span(
db: &rusqlite::Connection,
file_ids: &[i64],
mut f: impl FnMut(i64, Embedding),
) -> Result<()> {
let mut query_statement = db.prepare(
"
SELECT
id, embedding
FROM
documents
spans
WHERE
file_id IN rarray(?)
",
@@ -356,51 +451,57 @@ impl VectorDatabase {
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
})?
.filter_map(|row| row.ok())
.for_each(|(id, embedding)| f(id, embedding.0));
.for_each(|(id, embedding)| f(id, embedding));
Ok(())
}
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
let mut statement = self.db.prepare(
"
SELECT
documents.id,
files.worktree_id,
files.relative_path,
documents.start_byte,
documents.end_byte
FROM
documents, files
WHERE
documents.file_id = files.id AND
documents.id in rarray(?)
",
)?;
pub fn spans_for_ids(
&self,
ids: &[i64],
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
let ids = ids.to_vec();
self.transact(move |db| {
let mut statement = db.prepare(
"
SELECT
spans.id,
files.worktree_id,
files.relative_path,
spans.start_byte,
spans.end_byte
FROM
spans, files
WHERE
spans.file_id = files.id AND
spans.id in rarray(?)
",
)?;
let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
Ok((
row.get::<_, i64>(0)?,
row.get::<_, i64>(1)?,
row.get::<_, String>(2)?.into(),
row.get(3)?..row.get(4)?,
))
})?;
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
for row in result_iter {
let (id, worktree_id, path, range) = row?;
values_by_id.insert(id, (worktree_id, path, range));
}
let mut results = Vec::with_capacity(ids.len());
for id in ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing document id {}", id))?;
results.push(value);
}
let mut results = Vec::with_capacity(ids.len());
for id in &ids {
let value = values_by_id
.remove(id)
.ok_or(anyhow!("missing span id {}", id))?;
results.push(value);
}
Ok(results)
Ok(results)
})
}
}
@@ -412,29 +513,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
.collect::<Vec<_>>(),
)
}
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
let len = vec_a.len();
assert_eq!(len, vec_b.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
vec_a.as_ptr(),
len as isize,
1,
vec_b.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}


@@ -7,6 +7,9 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use parse_duration::parse;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
@@ -19,6 +22,62 @@ lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(Vec<f32>);
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> f32 {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
result
}
}
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
if embedding.is_err() {
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
}
Ok(Embedding(embedding.unwrap()))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
@@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = vec![0.32 as f32; 1536];
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
return Ok(vec![dummy_vec; spans.len()]);
}
fn max_tokens_per_batch(&self) -> usize {
OPENAI_INPUT_LIMIT
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let token_count = tokens.len();
let output = if token_count > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
new_input.ok().unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
}
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
fn truncate(span: String) -> String {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
if result.is_ok() {
let transformed = result.unwrap();
return transformed;
}
}
span
}
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
async fn send_request(
&self,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(4))
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
@@ -105,7 +175,26 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
tokens.truncate(OPENAI_INPUT_LIMIT);
OPENAI_BPE_TOKENIZER
.decode(tokens.clone())
.ok()
.unwrap_or_else(|| span.to_string())
} else {
span.to_string()
};
(output, tokens.len())
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -114,45 +203,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
let mut truncated = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
while request_number < MAX_RETRIES {
response = self
.send_request(api_key, spans.iter().map(|x| &**x).collect())
.send_request(
api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
return Err(anyhow!(
"openai max retries, error: {:?}",
&response.status()
));
}
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
log::trace!(
"open ai rate limiting, delaying request by {:?} seconds",
delay.as_secs()
);
self.executor.timer(delay).await;
}
StatusCode::BAD_REQUEST => {
// Only truncate if it hasnt been truncated before
if !truncated {
for span in spans.iter_mut() {
*span = Self::truncate(span.clone());
}
truncated = true;
} else {
// If failing once already truncated, log the error and break the loop
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
break;
}
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
@@ -163,18 +228,96 @@ impl EmbeddingProvider for OpenAIEmbeddings {
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
return Ok(response
.data
.into_iter()
.map(|embedding| embedding.embedding)
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
return Err(anyhow!("openai embedding failed {}", response.status()));
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai embedding failed"))
Err(anyhow!("openai max retries"))
}
}
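
The retry loop above now prefers the delay advertised in the x-ratelimit-reset-tokens response header when it parses, falling back to the fixed backoff schedule otherwise, and separately grows the request timeout on REQUEST_TIMEOUT. Below is a small sketch of just the delay selection, reusing the parse_duration crate this file imports; the helper name and signature are illustrative.

use std::time::Duration;

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];

fn rate_limit_delay(attempt: usize, reset_header: Option<&str>) -> Duration {
    // Fall back to the fixed schedule, clamped to its last entry.
    let fallback = Duration::from_secs(BACKOFF_SECONDS[attempt.min(BACKOFF_SECONDS.len() - 1)]);
    reset_header
        .and_then(|value| parse_duration::parse(value).ok())
        .unwrap_or(fallback)
}

fn main() {
    assert_eq!(rate_limit_delay(0, None), Duration::from_secs(3));
    assert_eq!(rate_limit_delay(1, Some("2s")), Duration::from_secs(2));
}
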
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
}


@@ -0,0 +1,165 @@
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
#[derive(Clone)]
pub struct FileToEmbed {
pub worktree_id: i64,
pub path: Arc<Path>,
pub mtime: SystemTime,
pub spans: Vec<Span>,
pub job_handle: JobHandle,
}
impl std::fmt::Debug for FileToEmbed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FileToEmbed")
.field("worktree_id", &self.worktree_id)
.field("path", &self.path)
.field("mtime", &self.mtime)
.field("spans", &self.spans)
.finish_non_exhaustive()
}
}
impl PartialEq for FileToEmbed {
fn eq(&self, other: &Self) -> bool {
self.worktree_id == other.worktree_id
&& self.path == other.path
&& self.mtime == other.mtime
&& self.spans == other.spans
}
}
pub struct EmbeddingQueue {
embedding_provider: Arc<dyn EmbeddingProvider>,
pending_batch: Vec<FileFragmentToEmbed>,
executor: Arc<Background>,
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
}
#[derive(Clone)]
pub struct FileFragmentToEmbed {
file: Arc<Mutex<FileToEmbed>>,
span_range: Range<usize>,
}
impl EmbeddingQueue {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
executor,
pending_batch: Vec::new(),
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
}
}
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
return;
}
let file = Arc::new(Mutex::new(file));
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: 0..0,
});
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
for (ix, span) in file.lock().spans.iter().enumerate() {
let span_token_count = if span.embedding.is_none() {
span.token_count
} else {
0
};
let next_token_count = self.pending_batch_token_count + span_token_count;
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
let range_end = fragment_range.end;
self.flush();
self.pending_batch.push(FileFragmentToEmbed {
file: file.clone(),
span_range: range_end..range_end,
});
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
}
fragment_range.end = ix + 1;
self.pending_batch_token_count += span_token_count;
}
}
pub fn flush(&mut self) {
let batch = mem::take(&mut self.pending_batch);
self.pending_batch_token_count = 0;
if batch.is_empty() {
return;
}
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
self.executor
.spawn(async move {
let mut spans = Vec::new();
for fragment in &batch {
let file = fragment.file.lock();
spans.extend(
file.spans[fragment.span_range.clone()]
.iter()
.filter(|d| d.embedding.is_none())
.map(|d| d.content.clone()),
);
}
// If there are no spans to embed, send each file to the finished channel once this is its last outstanding fragment.
if spans.is_empty() {
for fragment in batch.clone() {
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
return;
};
match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
.iter_mut()
.filter(|d| d.embedding.is_none())
{
if let Some(embedding) = embeddings.next() {
span.embedding = Some(embedding);
} else {
log::error!("number of embeddings != number of documents");
}
}
if let Some(file) = Arc::into_inner(fragment.file) {
finished_files_tx.try_send(file.into_inner()).unwrap();
}
}
}
Err(error) => {
log::error!("{:?}", error);
}
}
})
.detach();
}
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
self.finished_files_rx.clone()
}
}
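A hedged sketch of how a caller might drive the queue end to end; the provider and executor here are assumptions standing in for whatever the semantic index passes, and in the real crate the finished files are consumed asynchronously rather than drained with try_recv:

// Hypothetical driver, not part of the diff.
fn embed_files(
    files: Vec<FileToEmbed>,
    provider: Arc<dyn EmbeddingProvider>,
    executor: Arc<Background>,
) {
    let mut queue = EmbeddingQueue::new(provider, executor);
    let finished = queue.finished_files();
    for file in files {
        // push() flushes on its own whenever adding a span would exceed
        // the provider's max_tokens_per_batch().
        queue.push(file);
    }
    // Flush whatever is left in the final, partially filled batch.
    queue.flush();
    // Files arrive here once the spawned batch tasks finish; a real caller awaits them.
    while let Ok(file) = finished.try_recv() {
        log::info!("embedded {:?}", file.path);
    }
}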

View File

@@ -1,5 +1,10 @@
use anyhow::{anyhow, Ok, Result};
use crate::embedding::{Embedding, EmbeddingProvider};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
ToSql,
};
use sha1::{Digest, Sha1};
use std::{
cmp::{self, Reverse},
@@ -10,13 +15,44 @@ use std::{
};
use tree_sitter::{Parser, QueryCursor};
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct SpanDigest([u8; 20]);
impl FromSql for SpanDigest {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let blob = value.as_blob()?;
let bytes =
blob.try_into()
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
expected_size: 20,
blob_size: blob.len(),
})?;
return Ok(SpanDigest(bytes));
}
}
impl ToSql for SpanDigest {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
self.0.to_sql()
}
}
impl From<&'_ str> for SpanDigest {
fn from(value: &'_ str) -> Self {
let mut sha1 = Sha1::new();
sha1.update(value);
Self(sha1.finalize().into())
}
}
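A hypothetical check (not one of the tests in the diff) of the property the index relies on: the digest is the SHA-1 of the rendered span content, so identical content always produces the same 20-byte key and can be skipped on re-indexing.

#[test]
fn span_digest_is_deterministic() {
    // Same content, same digest; any edit changes it.
    assert_eq!(SpanDigest::from("fn main() {}"), SpanDigest::from("fn main() {}"));
    assert_ne!(SpanDigest::from("fn main() {}"), SpanDigest::from("fn main() { }"));
}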
#[derive(Debug, PartialEq, Clone)]
pub struct Document {
pub struct Span {
pub name: String,
pub range: Range<usize>,
pub content: String,
pub embedding: Vec<f32>,
pub sha1: [u8; 20],
pub embedding: Option<Embedding>,
pub digest: SpanDigest,
pub token_count: usize,
}
const CODE_CONTEXT_TEMPLATE: &str =
@@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
pub struct CodeContextRetriever {
pub parser: Parser,
pub cursor: QueryCursor,
pub embedding_provider: Arc<dyn EmbeddingProvider>,
}
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@@ -47,10 +84,11 @@ pub struct CodeContextMatch {
}
impl CodeContextRetriever {
pub fn new() -> Self {
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
Self {
parser: Parser::new(),
cursor: QueryCursor::new(),
embedding_provider,
}
}
@@ -59,38 +97,36 @@ impl CodeContextRetriever {
relative_path: &Path,
language_name: Arc<str>,
content: &str,
) -> Result<Vec<Document>> {
) -> Result<Vec<Span>> {
let document_span = ENTIRE_FILE_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
Ok(vec![Document {
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: Default::default(),
name: language_name.to_string(),
sha1: sha1.finalize().into(),
digest,
token_count,
}])
}
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
let document_span = MARKDOWN_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<item>", &content);
let mut sha1 = Sha1::new();
sha1.update(&document_span);
Ok(vec![Document {
let digest = SpanDigest::from(document_span.as_str());
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
embedding: Vec::new(),
embedding: None,
name: "Markdown".to_string(),
sha1: sha1.finalize().into(),
digest,
token_count,
}])
}
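Both file-level parsers now pass the rendered text through `EmbeddingProvider::truncate`, which hands back the (possibly shortened) content together with its token count, so `Span::token_count` is known at parse time. A sketch of that contract under an assumed toy tokenizer that counts whitespace-separated words (illustrative only, not the crate's tokenizer):

// Illustrative truncate() shape: (possibly shortened text, token count).
fn truncate_words(span: &str, max_tokens: usize) -> (String, usize) {
    let words: Vec<&str> = span.split_whitespace().collect();
    let kept = &words[..words.len().min(max_tokens)];
    (kept.join(" "), kept.len())
}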
@@ -155,26 +191,32 @@ impl CodeContextRetriever {
relative_path: &Path,
content: &str,
language: Arc<Language>,
) -> Result<Vec<Document>> {
) -> Result<Vec<Span>> {
let language_name = language.name();
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
return self.parse_entire_file(relative_path, language_name, &content);
} else if &language_name.to_string() == &"Markdown".to_string() {
} else if language_name.as_ref() == "Markdown" {
return self.parse_markdown_file(relative_path, &content);
}
let mut documents = self.parse_file(content, language)?;
for document in &mut documents {
document.content = CODE_CONTEXT_TEMPLATE
let mut spans = self.parse_file(content, language)?;
for span in &mut spans {
let document_content = CODE_CONTEXT_TEMPLATE
.replace("<path>", relative_path.to_string_lossy().as_ref())
.replace("<language>", language_name.as_ref())
.replace("item", &document.content);
.replace("item", &span.content);
let (document_content, token_count) =
self.embedding_provider.truncate(&document_content);
span.content = document_content;
span.token_count = token_count;
}
Ok(documents)
Ok(spans)
}
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
let grammar = language
.grammar()
.ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -185,7 +227,7 @@ impl CodeContextRetriever {
let language_scope = language.default_scope();
let placeholder = language_scope.collapsed_placeholder();
let mut documents = Vec::new();
let mut spans = Vec::new();
let mut collapsed_ranges_within = Vec::new();
let mut parsed_name_ranges = HashSet::new();
for (i, context_match) in matches.iter().enumerate() {
@@ -225,22 +267,22 @@ impl CodeContextRetriever {
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
let mut document_content = String::new();
let mut span_content = String::new();
for context_range in &context_match.context_ranges {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
context_range.clone(),
context_match.start_col,
);
document_content.push_str("\n");
span_content.push_str("\n");
}
let mut offset = item_range.start;
for collapsed_range in &collapsed_ranges_within {
if collapsed_range.start > offset {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
offset..collapsed_range.start,
context_match.start_col,
@@ -249,33 +291,32 @@ impl CodeContextRetriever {
}
if collapsed_range.end > offset {
document_content.push_str(placeholder);
span_content.push_str(placeholder);
offset = collapsed_range.end;
}
}
if offset < item_range.end {
add_content_from_range(
&mut document_content,
&mut span_content,
content,
offset..item_range.end,
context_match.start_col,
);
}
let mut sha1 = Sha1::new();
sha1.update(&document_content);
documents.push(Document {
let sha1 = SpanDigest::from(span_content.as_str());
spans.push(Span {
name,
content: document_content,
content: span_content,
range: item_range.clone(),
embedding: vec![],
sha1: sha1.finalize().into(),
embedding: None,
digest: sha1,
token_count: 0,
})
}
return Ok(documents);
return Ok(spans);
}
}

File diff suppressed because it is too large

View File

@@ -1,14 +1,15 @@
use crate::{
db::dot,
embedding::EmbeddingProvider,
parsing::{subtract_ranges, CodeContextRetriever, Document},
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
embedding_queue::EmbeddingQueue,
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
semantic_index_settings::SemanticIndexSettings,
SearchResult, SemanticIndex,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
use anyhow::Result;
use async_trait::async_trait;
use gpui::{Task, TestAppContext};
use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
use rand::{rngs::StdRng, Rng};
@@ -20,8 +21,10 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
time::SystemTime,
};
use unindent::Unindent;
use util::RandomCharIter;
#[ctor::ctor]
fn init_logger() {
@@ -31,12 +34,8 @@ fn init_logger() {
}
#[gpui::test]
async fn test_semantic_index(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<SemanticIndexSettings>(cx);
settings::register::<ProjectSettings>(cx);
});
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background());
fs.insert_tree(
@@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
fn bbb() {
println!(\"bbbbbbbbbbbbb!\");
}
struct pqpqpqp {}
".unindent(),
"file3.toml": "
ZZZZZZZZZZZZZZZZZZ = 5
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let db_path = db_dir.path().join("db.sqlite");
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let store = SemanticIndex::new(
let semantic_index = SemanticIndex::new(
fs.clone(),
db_path,
embedding_provider.clone(),
@@ -87,34 +87,24 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
let _ = store
.update(cx, |store, cx| {
store.initialize_project(project.clone(), cx)
})
.await;
let (file_count, outstanding_file_count) = store
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 3);
cx.foreground().run_until_parked();
assert_eq!(*outstanding_file_count.borrow(), 0);
let search_results = store
.update(cx, |store, cx| {
store.search_project(
project.clone(),
"aaaaaabbbbzz".to_string(),
5,
vec![],
vec![],
cx,
)
})
.await
.unwrap();
let search_results = semantic_index.update(cx, |store, cx| {
store.search_project(
project.clone(),
"aaaaaabbbbzz".to_string(),
5,
vec![],
vec![],
cx,
)
});
let pending_file_count =
semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
deterministic.run_until_parked();
assert_eq!(*pending_file_count.borrow(), 3);
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*pending_file_count.borrow(), 0);
let search_results = search_results.await.unwrap();
assert_search_results(
&search_results,
&[
@@ -122,6 +112,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file3.toml").into(), 0),
(Path::new("src/file1.rs").into(), 45),
(Path::new("src/file2.rs").into(), 45),
],
cx,
);
@@ -129,7 +120,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
// Test Include Files Functionality
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
let rust_only_search_results = store
let rust_only_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -149,11 +140,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
(Path::new("src/file1.rs").into(), 0),
(Path::new("src/file2.rs").into(), 0),
(Path::new("src/file1.rs").into(), 45),
(Path::new("src/file2.rs").into(), 45),
],
cx,
);
let no_rust_search_results = store
let no_rust_search_results = semantic_index
.update(cx, |store, cx| {
store.search_project(
project.clone(),
@@ -186,24 +178,85 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
.await
.unwrap();
cx.foreground().run_until_parked();
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
let prev_embedding_count = embedding_provider.embedding_count();
let (file_count, outstanding_file_count) = store
.update(cx, |store, cx| store.index_project(project.clone(), cx))
.await
.unwrap();
assert_eq!(file_count, 1);
cx.foreground().run_until_parked();
assert_eq!(*outstanding_file_count.borrow(), 0);
let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
deterministic.run_until_parked();
assert_eq!(*pending_file_count.borrow(), 1);
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
assert_eq!(*pending_file_count.borrow(), 0);
index.await.unwrap();
assert_eq!(
embedding_provider.embedding_count() - prev_embedding_count,
2
1
);
}
#[gpui::test(iterations = 10)]
async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let (outstanding_job_count, _) = postage::watch::channel_with(0);
let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
let files = (1..=3)
.map(|file_ix| FileToEmbed {
worktree_id: 5,
path: Path::new(&format!("path-{file_ix}")).into(),
mtime: SystemTime::now(),
spans: (0..rng.gen_range(4..22))
.map(|document_ix| {
let content_len = rng.gen_range(10..100);
let content = RandomCharIter::new(&mut rng)
.with_simple_text()
.take(content_len)
.collect::<String>();
let digest = SpanDigest::from(content.as_str());
Span {
range: 0..10,
embedding: None,
name: format!("document {document_ix}"),
content,
digest,
token_count: rng.gen_range(10..30),
}
})
.collect(),
job_handle: JobHandle::new(&outstanding_job_count),
})
.collect::<Vec<_>>();
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
queue.flush();
cx.foreground().run_until_parked();
let finished_files = queue.finished_files();
let mut embedded_files: Vec<_> = files
.iter()
.map(|_| finished_files.try_recv().expect("no finished file"))
.collect();
let expected_files: Vec<_> = files
.iter()
.map(|file| {
let mut file = file.clone();
for doc in &mut file.spans {
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
}
file
})
.collect();
embedded_files.sort_by_key(|f| f.path.clone());
assert_eq!(embedded_files, expected_files);
}
#[track_caller]
fn assert_search_results(
actual: &[SearchResult],
@@ -227,7 +280,8 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/// A doc comment
@@ -314,7 +368,8 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
{
@@ -382,7 +437,7 @@ async fn test_code_context_retrieval_json() {
}
fn assert_documents_eq(
documents: &[Document],
documents: &[Span],
expected_contents_and_start_offsets: &[(String, usize)],
) {
assert_eq!(
@@ -397,7 +452,8 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/* globals importScripts, backend */
@@ -495,7 +551,8 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
-- Creates a new class
@@ -568,7 +625,8 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
defmodule File.Stream do
@@ -684,7 +742,8 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/**
@@ -836,7 +895,8 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
# This concern is inspired by "sudo mode" on GitHub. It
@@ -1026,7 +1086,8 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
<?php
@@ -1173,36 +1234,6 @@ async fn test_code_context_retrieval_php() {
);
}
#[gpui::test]
fn test_dot_product(mut rng: StdRng) {
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
assert_eq!(
round_to_decimals(dot(&a, &b), 1),
round_to_decimals(reference_dot(&a, &b), 1)
);
}
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
let factor = (10.0 as f32).powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
}
}
#[derive(Default)]
struct FakeEmbeddingProvider {
embedding_count: AtomicUsize,
@@ -1212,35 +1243,42 @@ impl FakeEmbeddingProvider {
fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
}
fn max_tokens_per_batch(&self) -> usize {
200
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans
.iter()
.map(|span| {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result
})
.collect())
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
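The fake provider's `embed_sync` builds a 26-slot letter-frequency vector (baseline 1.0 per slot, plus 1.0 per occurrence of each ASCII letter) and L2-normalizes it, so texts with similar letter distributions score close to 1.0 under the dot-product similarity. A hypothetical assertion of that property (not one of the tests in the diff):

#[test]
fn fake_embeddings_rank_similar_text_higher() {
    let provider = FakeEmbeddingProvider::default();
    let query = provider.embed_sync("aaaaaabbbbzz");
    let close = provider.embed_sync("aaaaaabbbbzz and more");
    let far = provider.embed_sync("println!(\"hello world\")");
    // All three vectors are unit length, so the dot product is cosine similarity.
    assert!(query.similarity(&close) > query.similarity(&far));
}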
@@ -1684,3 +1722,11 @@ fn test_subtract_ranges() {
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
cx.set_global(SettingsStore::test(cx));
settings::register::<SemanticIndexSettings>(cx);
settings::register::<ProjectSettings>(cx);
});
}

View File

@@ -2,13 +2,13 @@ Design notes:
This crate is split into two conceptual halves:
- The terminal.rs file and the src/mappings/ folder, these contain the code for interacting with Alacritty and maintaining the pty event loop. Some behavior in this file is constrained by terminal protocols and standards. The Zed init function is also placed here.
- Everything else. These other files integrate the `Terminal` struct created in terminal.rs into the rest of GPUI. The main entry point for GPUI is the terminal_view.rs file and the modal.rs file.
- Everything else. These other files integrate the `Terminal` struct created in terminal.rs into the rest of GPUI. The main entry point for GPUI is the terminal_view.rs file and the modal.rs file.
ttys are created externally, and so can fail in unexpected ways. However, GPUI currently does not have an API for models that can fail to instantiate. `TerminalBuilder` solves this by using Rust's type system to split tty instantiation into a two-step process: first attempt to create the file handles with `TerminalBuilder::new()`, check the result, then call `TerminalBuilder::subscribe(cx)` from within a model context.
The TerminalView struct abstracts over failed and successful terminals, passing focus through to the associated view and allowing clients to build a terminal without worrying about errors.
#Input
#Input
There are currently many distinct paths for getting keystrokes to the terminal:
@@ -18,6 +18,6 @@ There are currently many distinct paths for getting keystrokes to the terminal:
3. IME text. When the special character mappings fail, we pass the keystroke back to GPUI to hand it to the IME system. This comes back to us in the `View::replace_text_in_range()` method, and we then send that to the terminal directly, bypassing `try_keystroke()`.
4. Pasted text has a separate pathway.
4. Pasted text has a separate pathway.
Generally, there's a distinction between 'keystrokes that need to be mapped' and 'strings which need to be written'. I've attempted to unify these under the '.try_keystroke()' API and the `.input()` API (which try_keystroke uses) so we have consistent input handling across the terminal
Generally, there's a distinction between 'keystrokes that need to be mapped' and 'strings which need to be written'. I've attempted to unify these under the '.try_keystroke()' API and the `.input()` API (which try_keystroke uses) so we have consistent input handling across the terminal

View File

@@ -283,7 +283,12 @@ impl TerminalView {
pub fn deploy_context_menu(&mut self, position: Vector2F, cx: &mut ViewContext<Self>) {
let menu_entries = vec![
ContextMenuItem::action("Clear", Clear),
ContextMenuItem::action("Close", pane::CloseActiveItem),
ContextMenuItem::action(
"Close",
pane::CloseActiveItem {
save_behavior: None,
},
),
];
self.context_menu.update(cx, |menu, cx| {

View File

@@ -31,6 +31,7 @@ regex.workspace = true
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
rand.workspace = true

View File

@@ -8,7 +8,7 @@ use sum_tree::Bias;
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default)]
pub struct Anchor {
pub timestamp: clock::Local,
pub timestamp: clock::Lamport,
pub offset: usize,
pub bias: Bias,
pub buffer_id: Option<u64>,
@@ -16,14 +16,14 @@ pub struct Anchor {
impl Anchor {
pub const MIN: Self = Self {
timestamp: clock::Local::MIN,
timestamp: clock::Lamport::MIN,
offset: usize::MIN,
bias: Bias::Left,
buffer_id: None,
};
pub const MAX: Self = Self {
timestamp: clock::Local::MAX,
timestamp: clock::Lamport::MAX,
offset: usize::MAX,
bias: Bias::Right,
buffer_id: None,

View File

@@ -46,18 +46,16 @@ lazy_static! {
static ref LINE_SEPARATORS_REGEX: Regex = Regex::new("\r\n|\r|\u{2028}|\u{2029}").unwrap();
}
pub type TransactionId = clock::Local;
pub type TransactionId = clock::Lamport;
pub struct Buffer {
snapshot: BufferSnapshot,
history: History,
deferred_ops: OperationQueue<Operation>,
deferred_replicas: HashSet<ReplicaId>,
replica_id: ReplicaId,
local_clock: clock::Local,
pub lamport_clock: clock::Lamport,
subscriptions: Topic,
edit_id_resolvers: HashMap<clock::Local, Vec<oneshot::Sender<()>>>,
edit_id_resolvers: HashMap<clock::Lamport, Vec<oneshot::Sender<()>>>,
wait_for_version_txs: Vec<(clock::Global, oneshot::Sender<()>)>,
}
@@ -85,7 +83,7 @@ pub struct HistoryEntry {
#[derive(Clone, Debug)]
pub struct Transaction {
pub id: TransactionId,
pub edit_ids: Vec<clock::Local>,
pub edit_ids: Vec<clock::Lamport>,
pub start: clock::Global,
}
@@ -97,8 +95,8 @@ impl HistoryEntry {
struct History {
base_text: Rope,
operations: TreeMap<clock::Local, Operation>,
insertion_slices: HashMap<clock::Local, Vec<InsertionSlice>>,
operations: TreeMap<clock::Lamport, Operation>,
insertion_slices: HashMap<clock::Lamport, Vec<InsertionSlice>>,
undo_stack: Vec<HistoryEntry>,
redo_stack: Vec<HistoryEntry>,
transaction_depth: usize,
@@ -107,7 +105,7 @@ struct History {
#[derive(Clone, Debug)]
struct InsertionSlice {
insertion_id: clock::Local,
insertion_id: clock::Lamport,
range: Range<usize>,
}
@@ -129,18 +127,18 @@ impl History {
}
fn push(&mut self, op: Operation) {
self.operations.insert(op.local_timestamp(), op);
self.operations.insert(op.timestamp(), op);
}
fn start_transaction(
&mut self,
start: clock::Global,
now: Instant,
local_clock: &mut clock::Local,
clock: &mut clock::Lamport,
) -> Option<TransactionId> {
self.transaction_depth += 1;
if self.transaction_depth == 1 {
let id = local_clock.tick();
let id = clock.tick();
self.undo_stack.push(HistoryEntry {
transaction: Transaction {
id,
@@ -251,7 +249,7 @@ impl History {
self.redo_stack.clear();
}
fn push_undo(&mut self, op_id: clock::Local) {
fn push_undo(&mut self, op_id: clock::Lamport) {
assert_ne!(self.transaction_depth, 0);
if let Some(Operation::Edit(_)) = self.operations.get(&op_id) {
let last_transaction = self.undo_stack.last_mut().unwrap();
@@ -412,37 +410,14 @@ impl<D1, D2> Edit<(D1, D2)> {
}
}
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
pub struct InsertionTimestamp {
pub replica_id: ReplicaId,
pub local: clock::Seq,
pub lamport: clock::Seq,
}
impl InsertionTimestamp {
pub fn local(&self) -> clock::Local {
clock::Local {
replica_id: self.replica_id,
value: self.local,
}
}
pub fn lamport(&self) -> clock::Lamport {
clock::Lamport {
replica_id: self.replica_id,
value: self.lamport,
}
}
}
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct Fragment {
pub id: Locator,
pub insertion_timestamp: InsertionTimestamp,
pub timestamp: clock::Lamport,
pub insertion_offset: usize,
pub len: usize,
pub visible: bool,
pub deletions: HashSet<clock::Local>,
pub deletions: HashSet<clock::Lamport>,
pub max_undos: clock::Global,
}
@@ -470,29 +445,26 @@ impl<'a> sum_tree::Dimension<'a, FragmentSummary> for FragmentTextSummary {
#[derive(Eq, PartialEq, Clone, Debug)]
struct InsertionFragment {
timestamp: clock::Local,
timestamp: clock::Lamport,
split_offset: usize,
fragment_id: Locator,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct InsertionFragmentKey {
timestamp: clock::Local,
timestamp: clock::Lamport,
split_offset: usize,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Operation {
Edit(EditOperation),
Undo {
undo: UndoOperation,
lamport_timestamp: clock::Lamport,
},
Undo(UndoOperation),
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EditOperation {
pub timestamp: InsertionTimestamp,
pub timestamp: clock::Lamport,
pub version: clock::Global,
pub ranges: Vec<Range<FullOffset>>,
pub new_text: Vec<Arc<str>>,
@@ -500,9 +472,9 @@ pub struct EditOperation {
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UndoOperation {
pub id: clock::Local,
pub counts: HashMap<clock::Local, u32>,
pub timestamp: clock::Lamport,
pub version: clock::Global,
pub counts: HashMap<clock::Lamport, u32>,
}
impl Buffer {
@@ -514,24 +486,21 @@ impl Buffer {
let mut fragments = SumTree::new();
let mut insertions = SumTree::new();
let mut local_clock = clock::Local::new(replica_id);
let mut lamport_clock = clock::Lamport::new(replica_id);
let mut version = clock::Global::new();
let visible_text = history.base_text.clone();
if !visible_text.is_empty() {
let insertion_timestamp = InsertionTimestamp {
let insertion_timestamp = clock::Lamport {
replica_id: 0,
local: 1,
lamport: 1,
value: 1,
};
local_clock.observe(insertion_timestamp.local());
lamport_clock.observe(insertion_timestamp.lamport());
version.observe(insertion_timestamp.local());
lamport_clock.observe(insertion_timestamp);
version.observe(insertion_timestamp);
let fragment_id = Locator::between(&Locator::min(), &Locator::max());
let fragment = Fragment {
id: fragment_id,
insertion_timestamp,
timestamp: insertion_timestamp,
insertion_offset: 0,
len: visible_text.len(),
visible: true,
@@ -557,8 +526,6 @@ impl Buffer {
history,
deferred_ops: OperationQueue::new(),
deferred_replicas: HashSet::default(),
replica_id,
local_clock,
lamport_clock,
subscriptions: Default::default(),
edit_id_resolvers: Default::default(),
@@ -575,7 +542,7 @@ impl Buffer {
}
pub fn replica_id(&self) -> ReplicaId {
self.local_clock.replica_id
self.lamport_clock.replica_id
}
pub fn remote_id(&self) -> u64 {
@@ -602,16 +569,12 @@ impl Buffer {
.map(|(range, new_text)| (range, new_text.into()));
self.start_transaction();
let timestamp = InsertionTimestamp {
replica_id: self.replica_id,
local: self.local_clock.tick().value,
lamport: self.lamport_clock.tick().value,
};
let timestamp = self.lamport_clock.tick();
let operation = Operation::Edit(self.apply_local_edit(edits, timestamp));
self.history.push(operation.clone());
self.history.push_undo(operation.local_timestamp());
self.snapshot.version.observe(operation.local_timestamp());
self.history.push_undo(operation.timestamp());
self.snapshot.version.observe(operation.timestamp());
self.end_transaction();
operation
}
@@ -619,7 +582,7 @@ impl Buffer {
fn apply_local_edit<S: ToOffset, T: Into<Arc<str>>>(
&mut self,
edits: impl ExactSizeIterator<Item = (Range<S>, T)>,
timestamp: InsertionTimestamp,
timestamp: clock::Lamport,
) -> EditOperation {
let mut edits_patch = Patch::default();
let mut edit_op = EditOperation {
@@ -696,7 +659,7 @@ impl Buffer {
.item()
.map_or(&Locator::max(), |old_fragment| &old_fragment.id),
),
insertion_timestamp: timestamp,
timestamp,
insertion_offset,
len: new_text.len(),
deletions: Default::default(),
@@ -726,7 +689,7 @@ impl Buffer {
intersection.insertion_offset += fragment_start - old_fragments.start().visible;
intersection.id =
Locator::between(&new_fragments.summary().max_id, &intersection.id);
intersection.deletions.insert(timestamp.local());
intersection.deletions.insert(timestamp);
intersection.visible = false;
}
if intersection.len > 0 {
@@ -781,7 +744,7 @@ impl Buffer {
self.subscriptions.publish_mut(&edits_patch);
self.history
.insertion_slices
.insert(timestamp.local(), insertion_slices);
.insert(timestamp, insertion_slices);
edit_op
}
@@ -808,28 +771,23 @@ impl Buffer {
fn apply_op(&mut self, op: Operation) -> Result<()> {
match op {
Operation::Edit(edit) => {
if !self.version.observed(edit.timestamp.local()) {
if !self.version.observed(edit.timestamp) {
self.apply_remote_edit(
&edit.version,
&edit.ranges,
&edit.new_text,
edit.timestamp,
);
self.snapshot.version.observe(edit.timestamp.local());
self.local_clock.observe(edit.timestamp.local());
self.lamport_clock.observe(edit.timestamp.lamport());
self.resolve_edit(edit.timestamp.local());
self.snapshot.version.observe(edit.timestamp);
self.lamport_clock.observe(edit.timestamp);
self.resolve_edit(edit.timestamp);
}
}
Operation::Undo {
undo,
lamport_timestamp,
} => {
if !self.version.observed(undo.id) {
Operation::Undo(undo) => {
if !self.version.observed(undo.timestamp) {
self.apply_undo(&undo)?;
self.snapshot.version.observe(undo.id);
self.local_clock.observe(undo.id);
self.lamport_clock.observe(lamport_timestamp);
self.snapshot.version.observe(undo.timestamp);
self.lamport_clock.observe(undo.timestamp);
}
}
}
@@ -849,7 +807,7 @@ impl Buffer {
version: &clock::Global,
ranges: &[Range<FullOffset>],
new_text: &[Arc<str>],
timestamp: InsertionTimestamp,
timestamp: clock::Lamport,
) {
if ranges.is_empty() {
return;
@@ -916,9 +874,7 @@ impl Buffer {
// Skip over insertions that are concurrent to this edit, but have a lower lamport
// timestamp.
while let Some(fragment) = old_fragments.item() {
if fragment_start == range.start
&& fragment.insertion_timestamp.lamport() > timestamp.lamport()
{
if fragment_start == range.start && fragment.timestamp > timestamp {
new_ropes.push_fragment(fragment, fragment.visible);
new_fragments.push(fragment.clone(), &None);
old_fragments.next(&cx);
@@ -955,7 +911,7 @@ impl Buffer {
.item()
.map_or(&Locator::max(), |old_fragment| &old_fragment.id),
),
insertion_timestamp: timestamp,
timestamp,
insertion_offset,
len: new_text.len(),
deletions: Default::default(),
@@ -986,7 +942,7 @@ impl Buffer {
fragment_start - old_fragments.start().0.full_offset();
intersection.id =
Locator::between(&new_fragments.summary().max_id, &intersection.id);
intersection.deletions.insert(timestamp.local());
intersection.deletions.insert(timestamp);
intersection.visible = false;
insertion_slices.push(intersection.insertion_slice());
}
@@ -1038,13 +994,13 @@ impl Buffer {
self.snapshot.insertions.edit(new_insertions, &());
self.history
.insertion_slices
.insert(timestamp.local(), insertion_slices);
.insert(timestamp, insertion_slices);
self.subscriptions.publish_mut(&edits_patch)
}
fn fragment_ids_for_edits<'a>(
&'a self,
edit_ids: impl Iterator<Item = &'a clock::Local>,
edit_ids: impl Iterator<Item = &'a clock::Lamport>,
) -> Vec<&'a Locator> {
// Get all of the insertion slices changed by the given edits.
let mut insertion_slices = Vec::new();
@@ -1105,7 +1061,7 @@ impl Buffer {
let fragment_was_visible = fragment.visible;
fragment.visible = fragment.is_visible(&self.undo_map);
fragment.max_undos.observe(undo.id);
fragment.max_undos.observe(undo.timestamp);
let old_start = old_fragments.start().1;
let new_start = new_fragments.summary().text.visible;
@@ -1159,10 +1115,10 @@ impl Buffer {
if self.deferred_replicas.contains(&op.replica_id()) {
false
} else {
match op {
Operation::Edit(edit) => self.version.observed_all(&edit.version),
Operation::Undo { undo, .. } => self.version.observed_all(&undo.version),
}
self.version.observed_all(match op {
Operation::Edit(edit) => &edit.version,
Operation::Undo(undo) => &undo.version,
})
}
}
@@ -1180,7 +1136,7 @@ impl Buffer {
pub fn start_transaction_at(&mut self, now: Instant) -> Option<TransactionId> {
self.history
.start_transaction(self.version.clone(), now, &mut self.local_clock)
.start_transaction(self.version.clone(), now, &mut self.lamport_clock)
}
pub fn end_transaction(&mut self) -> Option<(TransactionId, clock::Global)> {
@@ -1209,7 +1165,7 @@ impl Buffer {
&self.history.base_text
}
pub fn operations(&self) -> &TreeMap<clock::Local, Operation> {
pub fn operations(&self) -> &TreeMap<clock::Lamport, Operation> {
&self.history.operations
}
@@ -1289,16 +1245,13 @@ impl Buffer {
}
let undo = UndoOperation {
id: self.local_clock.tick(),
timestamp: self.lamport_clock.tick(),
version: self.version(),
counts,
};
self.apply_undo(&undo)?;
let operation = Operation::Undo {
undo,
lamport_timestamp: self.lamport_clock.tick(),
};
self.snapshot.version.observe(operation.local_timestamp());
self.snapshot.version.observe(undo.timestamp);
let operation = Operation::Undo(undo);
self.history.push(operation.clone());
Ok(operation)
}
@@ -1363,7 +1316,7 @@ impl Buffer {
pub fn wait_for_edits(
&mut self,
edit_ids: impl IntoIterator<Item = clock::Local>,
edit_ids: impl IntoIterator<Item = clock::Lamport>,
) -> impl 'static + Future<Output = Result<()>> {
let mut futures = Vec::new();
for edit_id in edit_ids {
@@ -1435,7 +1388,7 @@ impl Buffer {
self.wait_for_version_txs.clear();
}
fn resolve_edit(&mut self, edit_id: clock::Local) {
fn resolve_edit(&mut self, edit_id: clock::Lamport) {
for mut tx in self
.edit_id_resolvers
.remove(&edit_id)
@@ -1513,7 +1466,7 @@ impl Buffer {
.insertions
.get(
&InsertionFragmentKey {
timestamp: fragment.insertion_timestamp.local(),
timestamp: fragment.timestamp,
split_offset: fragment.insertion_offset,
},
&(),
@@ -1996,7 +1949,7 @@ impl BufferSnapshot {
let fragment = fragment_cursor.item().unwrap();
let overshoot = offset - *fragment_cursor.start();
Anchor {
timestamp: fragment.insertion_timestamp.local(),
timestamp: fragment.timestamp,
offset: fragment.insertion_offset + overshoot,
bias,
buffer_id: Some(self.remote_id),
@@ -2188,15 +2141,14 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo
break;
}
let timestamp = fragment.insertion_timestamp.local();
let start_anchor = Anchor {
timestamp,
timestamp: fragment.timestamp,
offset: fragment.insertion_offset,
bias: Bias::Right,
buffer_id: Some(self.buffer_id),
};
let end_anchor = Anchor {
timestamp,
timestamp: fragment.timestamp,
offset: fragment.insertion_offset + fragment.len,
bias: Bias::Left,
buffer_id: Some(self.buffer_id),
@@ -2269,19 +2221,17 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo
impl Fragment {
fn insertion_slice(&self) -> InsertionSlice {
InsertionSlice {
insertion_id: self.insertion_timestamp.local(),
insertion_id: self.timestamp,
range: self.insertion_offset..self.insertion_offset + self.len,
}
}
fn is_visible(&self, undos: &UndoMap) -> bool {
!undos.is_undone(self.insertion_timestamp.local())
&& self.deletions.iter().all(|d| undos.is_undone(*d))
!undos.is_undone(self.timestamp) && self.deletions.iter().all(|d| undos.is_undone(*d))
}
fn was_visible(&self, version: &clock::Global, undos: &UndoMap) -> bool {
(version.observed(self.insertion_timestamp.local())
&& !undos.was_undone(self.insertion_timestamp.local(), version))
(version.observed(self.timestamp) && !undos.was_undone(self.timestamp, version))
&& self
.deletions
.iter()
@@ -2294,14 +2244,14 @@ impl sum_tree::Item for Fragment {
fn summary(&self) -> Self::Summary {
let mut max_version = clock::Global::new();
max_version.observe(self.insertion_timestamp.local());
max_version.observe(self.timestamp);
for deletion in &self.deletions {
max_version.observe(*deletion);
}
max_version.join(&self.max_undos);
let mut min_insertion_version = clock::Global::new();
min_insertion_version.observe(self.insertion_timestamp.local());
min_insertion_version.observe(self.timestamp);
let max_insertion_version = min_insertion_version.clone();
if self.visible {
FragmentSummary {
@@ -2378,7 +2328,7 @@ impl sum_tree::KeyedItem for InsertionFragment {
impl InsertionFragment {
fn new(fragment: &Fragment) -> Self {
Self {
timestamp: fragment.insertion_timestamp.local(),
timestamp: fragment.timestamp,
split_offset: fragment.insertion_offset,
fragment_id: fragment.id.clone(),
}
@@ -2501,10 +2451,10 @@ impl Operation {
operation_queue::Operation::lamport_timestamp(self).replica_id
}
pub fn local_timestamp(&self) -> clock::Local {
pub fn timestamp(&self) -> clock::Lamport {
match self {
Operation::Edit(edit) => edit.timestamp.local(),
Operation::Undo { undo, .. } => undo.id,
Operation::Edit(edit) => edit.timestamp,
Operation::Undo(undo) => undo.timestamp,
}
}
@@ -2523,10 +2473,8 @@ impl Operation {
impl operation_queue::Operation for Operation {
fn lamport_timestamp(&self) -> clock::Lamport {
match self {
Operation::Edit(edit) => edit.timestamp.lamport(),
Operation::Undo {
lamport_timestamp, ..
} => *lamport_timestamp,
Operation::Edit(edit) => edit.timestamp,
Operation::Undo(undo) => undo.timestamp,
}
}
}
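This file's migration from `clock::Local` plus a separate Lamport stamp down to a single `clock::Lamport` per operation leans on the usual Lamport-clock invariant: locally generated operations call `tick`, remotely applied ones call `observe`, and ordering by (value, replica_id) stays consistent with causality. A minimal sketch of such a clock, assuming nothing about the crate's actual `clock` implementation:

type ReplicaId = u16;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct Lamport {
    // Deriving Ord on (value, replica_id) gives a total order that respects causality.
    value: u32,
    replica_id: ReplicaId,
}

impl Lamport {
    fn new(replica_id: ReplicaId) -> Self {
        Self { value: 1, replica_id }
    }

    // Stamp an operation generated on this replica.
    fn tick(&mut self) -> Lamport {
        let timestamp = *self;
        self.value += 1;
        timestamp
    }

    // Catch up when applying an operation stamped by another replica.
    fn observe(&mut self, timestamp: Lamport) {
        self.value = self.value.max(timestamp.value) + 1;
    }
}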

View File

@@ -26,8 +26,8 @@ impl sum_tree::KeyedItem for UndoMapEntry {
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UndoMapKey {
edit_id: clock::Local,
undo_id: clock::Local,
edit_id: clock::Lamport,
undo_id: clock::Lamport,
}
impl sum_tree::Summary for UndoMapKey {
@@ -50,7 +50,7 @@ impl UndoMap {
sum_tree::Edit::Insert(UndoMapEntry {
key: UndoMapKey {
edit_id: *edit_id,
undo_id: undo.id,
undo_id: undo.timestamp,
},
undo_count: *count,
})
@@ -59,11 +59,11 @@ impl UndoMap {
self.0.edit(edits, &());
}
pub fn is_undone(&self, edit_id: clock::Local) -> bool {
pub fn is_undone(&self, edit_id: clock::Lamport) -> bool {
self.undo_count(edit_id) % 2 == 1
}
pub fn was_undone(&self, edit_id: clock::Local, version: &clock::Global) -> bool {
pub fn was_undone(&self, edit_id: clock::Lamport, version: &clock::Global) -> bool {
let mut cursor = self.0.cursor::<UndoMapKey>();
cursor.seek(
&UndoMapKey {
@@ -88,7 +88,7 @@ impl UndoMap {
undo_count % 2 == 1
}
pub fn undo_count(&self, edit_id: clock::Local) -> u32 {
pub fn undo_count(&self, edit_id: clock::Lamport) -> u32 {
let mut cursor = self.0.cursor::<UndoMapKey>();
cursor.seek(
&UndoMapKey {

View File

@@ -408,6 +408,7 @@ pub struct Toolbar {
pub height: f32,
pub item_spacing: f32,
pub toggleable_tool: Toggleable<Interactive<IconButton>>,
pub toggleable_text_tool: Toggleable<Interactive<ContainedText>>,
pub breadcrumb_height: f32,
pub breadcrumbs: Interactive<ContainedText>,
}
@@ -834,6 +835,9 @@ pub struct AutocompleteStyle {
pub selected_item: ContainerStyle,
pub hovered_item: ContainerStyle,
pub match_highlight: HighlightStyle,
pub server_name_container: ContainerStyle,
pub server_name_color: Color,
pub server_name_size_percent: f32,
}
#[derive(Clone, Copy, Default, Deserialize, JsonSchema)]

View File

@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
Defer(Some(f))
}
pub struct RandomCharIter<T: Rng>(T);
pub struct RandomCharIter<T: Rng> {
rng: T,
simple_text: bool,
}
impl<T: Rng> RandomCharIter<T> {
pub fn new(rng: T) -> Self {
Self(rng)
Self {
rng,
simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
}
}
pub fn with_simple_text(mut self) -> Self {
self.simple_text = true;
self
}
}
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
type Item = char;
fn next(&mut self) -> Option<Self::Item> {
if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
return if self.0.gen_range(0..100) < 5 {
if self.simple_text {
return if self.rng.gen_range(0..100) < 5 {
Some('\n')
} else {
Some(self.0.gen_range(b'a'..b'z' + 1).into())
Some(self.rng.gen_range(b'a'..b'z' + 1).into())
};
}
match self.0.gen_range(0..100) {
match self.rng.gen_range(0..100) {
// whitespace
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
// two-byte greek letters
20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
// // three-byte characters
33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
33..=45 => ['✋', '✅', '❌', '❎', '⭐']
.choose(&mut self.rng)
.copied(),
// // four-byte characters
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
// ascii letters
_ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
_ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
}
}
}
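With the simple-text flag stored on the struct, a test can opt in explicitly instead of setting the SIMPLE_TEXT environment variable. A small hypothetical usage (seed and length are arbitrary):

use rand::{rngs::StdRng, SeedableRng};

fn simple_fixture() -> String {
    let rng = StdRng::seed_from_u64(42);
    // Only lowercase ASCII plus occasional newlines: legible test fixtures.
    let text: String = RandomCharIter::new(rng).with_simple_text().take(64).collect();
    debug_assert!(text.chars().all(|c| c == '\n' || c.is_ascii_lowercase()));
    text
}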

View File

@@ -38,6 +38,7 @@ language_selector = { path = "../language_selector"}
[dev-dependencies]
indoc.workspace = true
parking_lot.workspace = true
futures.workspace = true
editor = { path = "../editor", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
@@ -47,3 +48,4 @@ util = { path = "../util", features = ["test-support"] }
settings = { path = "../settings" }
workspace = { path = "../workspace", features = ["test-support"] }
theme = { path = "../theme", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }

View File

@@ -34,6 +34,7 @@ fn focused(EditorFocused(editor): &EditorFocused, cx: &mut AppContext) {
fn blurred(EditorBlurred(editor): &EditorBlurred, cx: &mut AppContext) {
editor.window().update(cx, |cx| {
Vim::update(cx, |vim, cx| {
vim.workspace_state.recording = false;
if let Some(previous_editor) = vim.active_editor.clone() {
if previous_editor == editor.clone() {
vim.active_editor = None;

View File

@@ -11,8 +11,9 @@ pub fn init(cx: &mut AppContext) {
}
fn normal_before(_: &mut Workspace, _: &NormalBefore, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |state, cx| {
state.update_active_editor(cx, |editor, cx| {
Vim::update(cx, |vim, cx| {
vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_cursors_with(|map, mut cursor, _| {
*cursor.column_mut() = cursor.column().saturating_sub(1);
@@ -20,7 +21,7 @@ fn normal_before(_: &mut Workspace, _: &NormalBefore, cx: &mut ViewContext<Works
});
});
});
state.switch_mode(Mode::Normal, false, cx);
vim.switch_mode(Mode::Normal, false, cx);
})
}

View File

@@ -1,9 +1,10 @@
use std::{cmp, sync::Arc};
use std::cmp;
use editor::{
char_kind,
display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint},
movement, Bias, CharKind, DisplayPoint, ToOffset,
movement::{self, find_boundary, find_preceding_boundary, FindRange},
Bias, CharKind, DisplayPoint, ToOffset,
};
use gpui::{actions, impl_actions, AppContext, WindowContext};
use language::{Point, Selection, SelectionGoal};
@@ -36,8 +37,8 @@ pub enum Motion {
StartOfDocument,
EndOfDocument,
Matching,
FindForward { before: bool, text: Arc<str> },
FindBackward { after: bool, text: Arc<str> },
FindForward { before: bool, char: char },
FindBackward { after: bool, char: char },
NextLineStart,
}
@@ -64,9 +65,9 @@ struct PreviousWordStart {
#[derive(Clone, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct Up {
pub(crate) struct Up {
#[serde(default)]
display_lines: bool,
pub(crate) display_lines: bool,
}
#[derive(Clone, Deserialize, PartialEq)]
@@ -92,9 +93,9 @@ struct EndOfLine {
#[derive(Clone, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct StartOfLine {
pub struct StartOfLine {
#[serde(default)]
display_lines: bool,
pub(crate) display_lines: bool,
}
#[derive(Clone, Deserialize, PartialEq)]
@@ -232,25 +233,25 @@ pub(crate) fn motion(motion: Motion, cx: &mut WindowContext) {
fn repeat_motion(backwards: bool, cx: &mut WindowContext) {
let find = match Vim::read(cx).workspace_state.last_find.clone() {
Some(Motion::FindForward { before, text }) => {
Some(Motion::FindForward { before, char }) => {
if backwards {
Motion::FindBackward {
after: before,
text,
char,
}
} else {
Motion::FindForward { before, text }
Motion::FindForward { before, char }
}
}
Some(Motion::FindBackward { after, text }) => {
Some(Motion::FindBackward { after, char }) => {
if backwards {
Motion::FindForward {
before: after,
text,
char,
}
} else {
Motion::FindBackward { after, text }
Motion::FindBackward { after, char }
}
}
_ => return,
@@ -402,12 +403,12 @@ impl Motion {
SelectionGoal::None,
),
Matching => (matching(map, point), SelectionGoal::None),
FindForward { before, text } => (
find_forward(map, point, *before, text.clone(), times),
FindForward { before, char } => (
find_forward(map, point, *before, *char, times),
SelectionGoal::None,
),
FindBackward { after, text } => (
find_backward(map, point, *after, text.clone(), times),
FindBackward { after, char } => (
find_backward(map, point, *after, *char, times),
SelectionGoal::None,
),
NextLineStart => (next_line_start(map, point, times), SelectionGoal::None),
@@ -589,12 +590,12 @@ pub(crate) fn next_word_start(
ignore_punctuation: bool,
times: usize,
) -> DisplayPoint {
let language = map.buffer_snapshot.language_at(point.to_point(map));
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
for _ in 0..times {
let mut crossed_newline = false;
point = movement::find_boundary(map, point, |left, right| {
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| {
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
let at_newline = right == '\n';
let found = (left_kind != right_kind && right_kind != CharKind::Whitespace)
@@ -614,12 +615,17 @@ fn next_word_end(
ignore_punctuation: bool,
times: usize,
) -> DisplayPoint {
let language = map.buffer_snapshot.language_at(point.to_point(map));
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
for _ in 0..times {
*point.column_mut() += 1;
point = movement::find_boundary(map, point, |left, right| {
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
if point.column() < map.line_len(point.row()) {
*point.column_mut() += 1;
} else if point.row() < map.max_buffer_row() {
*point.row_mut() += 1;
*point.column_mut() = 0;
}
point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| {
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
left_kind != right_kind && left_kind != CharKind::Whitespace
});
@@ -645,16 +651,17 @@ fn previous_word_start(
ignore_punctuation: bool,
times: usize,
) -> DisplayPoint {
let language = map.buffer_snapshot.language_at(point.to_point(map));
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
for _ in 0..times {
// This works even though find_preceding_boundary is called for every character in the line
// containing the cursor, because the newline is checked only once.
point = movement::find_preceding_boundary(map, point, |left, right| {
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
point =
movement::find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
(left_kind != right_kind && !right.is_whitespace()) || left == '\n'
});
(left_kind != right_kind && !right.is_whitespace()) || left == '\n'
});
}
point
}
@@ -665,7 +672,7 @@ fn first_non_whitespace(
from: DisplayPoint,
) -> DisplayPoint {
let mut last_point = start_of_line(map, display_lines, from);
let language = map.buffer_snapshot.language_at(from.to_point(map));
let scope = map.buffer_snapshot.language_scope_at(from.to_point(map));
for (ch, point) in map.chars_at(last_point) {
if ch == '\n' {
return from;
@@ -673,7 +680,7 @@ fn first_non_whitespace(
last_point = point;
if char_kind(language, ch) != CharKind::Whitespace {
if char_kind(&scope, ch) != CharKind::Whitespace {
break;
}
}
@@ -786,44 +793,55 @@ fn find_forward(
map: &DisplaySnapshot,
from: DisplayPoint,
before: bool,
target: Arc<str>,
target: char,
times: usize,
) -> DisplayPoint {
map.find_while(from, target.as_ref(), |ch, _| ch != '\n')
.skip_while(|found_at| found_at == &from)
.nth(times - 1)
.map(|mut found| {
if before {
*found.column_mut() -= 1;
found = map.clip_point(found, Bias::Right);
found
} else {
found
}
})
.unwrap_or(from)
let mut to = from;
let mut found = false;
for _ in 0..times {
found = false;
to = find_boundary(map, to, FindRange::SingleLine, |_, right| {
found = right == target;
found
});
}
if found {
if before && to.column() > 0 {
*to.column_mut() -= 1;
map.clip_point(to, Bias::Left)
} else {
to
}
} else {
from
}
}
fn find_backward(
map: &DisplaySnapshot,
from: DisplayPoint,
after: bool,
target: Arc<str>,
target: char,
times: usize,
) -> DisplayPoint {
map.reverse_find_while(from, target.as_ref(), |ch, _| ch != '\n')
.skip_while(|found_at| found_at == &from)
.nth(times - 1)
.map(|mut found| {
if after {
*found.column_mut() += 1;
found = map.clip_point(found, Bias::Left);
found
} else {
found
}
})
.unwrap_or(from)
let mut to = from;
for _ in 0..times {
to = find_preceding_boundary(map, to, FindRange::SingleLine, |_, right| right == target);
}
if map.buffer_snapshot.chars_at(to.to_point(map)).next() == Some(target) {
if after {
*to.column_mut() += 1;
map.clip_point(to, Bias::Right)
} else {
to
}
} else {
from
}
}
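The rewritten `find_forward` and `find_backward` search one target character at a time with `FindRange::SingleLine`, so counts like `3fx` step occurrence by occurrence and never escape the current line even when it is soft-wrapped on screen. A rough illustration of the count semantics over a plain string (hypothetical; the real code works on `DisplayPoint`s):

// Byte offset of the `times`-th occurrence of `target` strictly after `from`,
// or None if the line runs out first (mirroring vim leaving the cursor unmoved).
fn nth_occurrence_after(line: &str, from: usize, target: char, times: usize) -> Option<usize> {
    line.char_indices()
        .filter(|&(ix, ch)| ix > from && ch == target)
        .map(|(ix, _)| ix)
        .nth(times.saturating_sub(1))
}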
fn next_line_start(map: &DisplaySnapshot, point: DisplayPoint, times: usize) -> DisplayPoint {

View File

@@ -2,6 +2,7 @@ mod case;
mod change;
mod delete;
mod paste;
mod repeat;
mod scroll;
mod search;
pub mod substitute;
@@ -27,7 +28,6 @@ use self::{
case::change_case,
change::{change_motion, change_object},
delete::{delete_motion, delete_object},
substitute::substitute,
yank::{yank_motion, yank_object},
};
@@ -35,6 +35,7 @@ actions!(
vim,
[
InsertAfter,
InsertBefore,
InsertFirstNonWhitespace,
InsertEndOfLine,
InsertLineAbove,
@@ -44,39 +45,43 @@ actions!(
ChangeToEndOfLine,
DeleteToEndOfLine,
Yank,
Substitute,
ChangeCase,
JoinLines,
]
);
pub fn init(cx: &mut AppContext) {
paste::init(cx);
repeat::init(cx);
scroll::init(cx);
search::init(cx);
substitute::init(cx);
cx.add_action(insert_after);
cx.add_action(insert_before);
cx.add_action(insert_first_non_whitespace);
cx.add_action(insert_end_of_line);
cx.add_action(insert_line_above);
cx.add_action(insert_line_below);
cx.add_action(change_case);
search::init(cx);
cx.add_action(|_: &mut Workspace, _: &Substitute, cx| {
Vim::update(cx, |vim, cx| {
let times = vim.pop_number_operator(cx);
substitute(vim, times, cx);
})
});
cx.add_action(|_: &mut Workspace, _: &DeleteLeft, cx| {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
let times = vim.pop_number_operator(cx);
delete_motion(vim, Motion::Left, times, cx);
})
});
cx.add_action(|_: &mut Workspace, _: &DeleteRight, cx| {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
let times = vim.pop_number_operator(cx);
delete_motion(vim, Motion::Right, times, cx);
})
});
cx.add_action(|_: &mut Workspace, _: &ChangeToEndOfLine, cx| {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
let times = vim.pop_number_operator(cx);
change_motion(
vim,
@@ -90,6 +95,7 @@ pub fn init(cx: &mut AppContext) {
});
cx.add_action(|_: &mut Workspace, _: &DeleteToEndOfLine, cx| {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
let times = vim.pop_number_operator(cx);
delete_motion(
vim,
@@ -101,8 +107,26 @@ pub fn init(cx: &mut AppContext) {
);
})
});
scroll::init(cx);
paste::init(cx);
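// J joins the current line with the line below; a count of N joins N lines
// (so 2J is the same as J), and in visual mode the selection decides the range.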
cx.add_action(|_: &mut Workspace, _: &JoinLines, cx| {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
let mut times = vim.pop_number_operator(cx).unwrap_or(1);
if vim.state().mode.is_visual() {
times = 1;
} else if times > 1 {
// 2J joins two lines together (same as J or 1J)
times -= 1;
}
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
for _ in 0..times {
editor.join_lines(&Default::default(), cx)
}
})
})
})
})
}
pub fn normal_motion(
@@ -158,6 +182,7 @@ fn move_cursor(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut Win
fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
@@ -169,12 +194,20 @@ fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspa
});
}
fn insert_before(_: &mut Workspace, _: &InsertBefore, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
});
}
fn insert_first_non_whitespace(
_: &mut Workspace,
_: &InsertFirstNonWhitespace,
cx: &mut ViewContext<Workspace>,
) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
@@ -191,6 +224,7 @@ fn insert_first_non_whitespace(
fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
@@ -204,6 +238,7 @@ fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewConte
fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
@@ -236,6 +271,7 @@ fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContex
fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
@@ -267,6 +303,7 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
pub(crate) fn normal_replace(text: Arc<str>, cx: &mut WindowContext) {
Vim::update(cx, |vim, cx| {
vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
@@ -445,7 +482,7 @@ mod test {
}
#[gpui::test]
async fn test_e(cx: &mut gpui::TestAppContext) {
async fn test_end_of_word(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await.binding(["e"]);
cx.assert_all(indoc! {"
Thˇe quicˇkˇ-browˇn
@@ -787,6 +824,7 @@ mod test {
#[gpui::test]
async fn test_f_and_t(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
for count in 1..=3 {
let test_case = indoc! {"
ˇaaaˇbˇ ˇbˇ ˇbˇbˇ aˇaaˇbaaa


@@ -7,6 +7,7 @@ use crate::{normal::ChangeCase, state::Mode, Vim};
pub fn change_case(_: &mut Workspace, _: &ChangeCase, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
let count = vim.pop_number_operator(cx).unwrap_or(1) as u32;
vim.update_active_editor(cx, |editor, cx| {
let mut ranges = Vec::new();
@@ -21,10 +22,16 @@ pub fn change_case(_: &mut Workspace, _: &ChangeCase, cx: &mut ViewContext<Works
ranges.push(start..end);
cursor_positions.push(start..start);
}
Mode::Visual | Mode::VisualBlock => {
Mode::Visual => {
ranges.push(selection.start..selection.end);
cursor_positions.push(selection.start..selection.start);
}
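// In visual block mode every selected row contributes a range to re-case, but
// only the first row keeps a cursor so the block collapses to a single cursor.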
Mode::VisualBlock => {
ranges.push(selection.start..selection.end);
if cursor_positions.len() == 0 {
cursor_positions.push(selection.start..selection.start);
}
}
Mode::Insert | Mode::Normal => {
let start = selection.start;
let mut end = start;
@@ -96,6 +103,11 @@ mod test {
cx.simulate_shared_keystrokes(["shift-v", "~"]).await;
cx.assert_shared_state("ˇABc\n").await;
// works in visual block mode
cx.set_shared_state("ˇaa\nbb\ncc").await;
cx.simulate_shared_keystrokes(["ctrl-v", "j", "~"]).await;
cx.assert_shared_state("ˇAa\nBb\ncc").await;
// works with multiple cursors (zed only)
cx.set_state("aˇßcdˇe\n", Mode::Normal);
cx.simulate_keystroke("~");


@@ -1,7 +1,10 @@
use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_content, Vim};
use editor::{
char_kind, display_map::DisplaySnapshot, movement, scroll::autoscroll::Autoscroll, CharKind,
DisplayPoint,
char_kind,
display_map::DisplaySnapshot,
movement::{self, FindRange},
scroll::autoscroll::Autoscroll,
CharKind, DisplayPoint,
};
use gpui::WindowContext;
use language::Selection;
@@ -86,22 +89,24 @@ fn expand_changed_word_selection(
ignore_punctuation: bool,
) -> bool {
if times.is_none() || times.unwrap() == 1 {
let language = map
let scope = map
.buffer_snapshot
.language_at(selection.start.to_point(map));
.language_scope_at(selection.start.to_point(map));
let in_word = map
.chars_at(selection.head())
.next()
.map(|(c, _)| char_kind(language, c) != CharKind::Whitespace)
.map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace)
.unwrap_or_default();
if in_word {
selection.end = movement::find_boundary(map, selection.end, |left, right| {
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
selection.end =
movement::find_boundary(map, selection.end, FindRange::MultiLine, |left, right| {
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
let right_kind =
char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
left_kind != right_kind && left_kind != CharKind::Whitespace
});
left_kind != right_kind && left_kind != CharKind::Whitespace
});
true
} else {
Motion::NextWordStart { ignore_punctuation }


@@ -4,6 +4,7 @@ use editor::{display_map::ToDisplayPoint, scroll::autoscroll::Autoscroll, Bias};
use gpui::WindowContext;
pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
@@ -37,6 +38,7 @@ pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
}
pub fn delete_object(vim: &mut Vim, object: Object, around: bool, cx: &mut WindowContext) {
vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);


@@ -28,6 +28,7 @@ pub(crate) fn init(cx: &mut AppContext) {
fn paste(_: &mut Workspace, action: &Paste, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
vim.update_active_editor(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);


@@ -0,0 +1,427 @@
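// Dot-repeat: the last change is captured as a list of replayable steps plus
// the shape of the selection it was applied to, and `.` replays those steps
// against the active editor.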
use crate::{
motion::Motion,
state::{Mode, RecordedSelection, ReplayableAction},
visual::visual_motion,
Vim,
};
use gpui::{actions, Action, AppContext};
use workspace::Workspace;
actions!(vim, [Repeat, EndRepeat,]);
fn should_replay(action: &Box<dyn Action>) -> bool {
// skip so that we don't leave the character palette open
if editor::ShowCharacterPalette.id() == action.id() {
return false;
}
true
}
pub(crate) fn init(cx: &mut AppContext) {
cx.add_action(|_: &mut Workspace, _: &EndRepeat, cx| {
Vim::update(cx, |vim, cx| {
vim.workspace_state.replaying = false;
vim.update_active_editor(cx, |editor, _| {
editor.show_local_selections = true;
});
vim.switch_mode(Mode::Normal, false, cx)
});
});
cx.add_action(|_: &mut Workspace, _: &Repeat, cx| {
let Some((actions, editor, selection)) = Vim::update(cx, |vim, cx| {
let actions = vim.workspace_state.recorded_actions.clone();
let Some(editor) = vim.active_editor.clone() else {
return None;
};
let count = vim.pop_number_operator(cx);
vim.workspace_state.replaying = true;
let selection = vim.workspace_state.recorded_selection.clone();
match selection {
RecordedSelection::SingleLine { .. } | RecordedSelection::Visual { .. } => {
vim.workspace_state.recorded_count = None;
vim.switch_mode(Mode::Visual, false, cx)
}
RecordedSelection::VisualLine { .. } => {
vim.workspace_state.recorded_count = None;
vim.switch_mode(Mode::VisualLine, false, cx)
}
RecordedSelection::VisualBlock { .. } => {
vim.workspace_state.recorded_count = None;
vim.switch_mode(Mode::VisualBlock, false, cx)
}
RecordedSelection::None => {
if let Some(count) = count {
vim.workspace_state.recorded_count = Some(count);
}
}
}
if let Some(editor) = editor.upgrade(cx) {
editor.update(cx, |editor, _| {
editor.show_local_selections = false;
})
} else {
return None;
}
Some((actions, editor, selection))
}) else {
return;
};
match selection {
RecordedSelection::SingleLine { cols } => {
if cols > 1 {
visual_motion(Motion::Right, Some(cols as usize - 1), cx)
}
}
RecordedSelection::Visual { rows, cols } => {
visual_motion(
Motion::Down {
display_lines: false,
},
Some(rows as usize),
cx,
);
visual_motion(
Motion::StartOfLine {
display_lines: false,
},
None,
cx,
);
if cols > 1 {
visual_motion(Motion::Right, Some(cols as usize - 1), cx)
}
}
RecordedSelection::VisualBlock { rows, cols } => {
visual_motion(
Motion::Down {
display_lines: false,
},
Some(rows as usize),
cx,
);
if cols > 1 {
visual_motion(Motion::Right, Some(cols as usize - 1), cx);
}
}
RecordedSelection::VisualLine { rows } => {
visual_motion(
Motion::Down {
display_lines: false,
},
Some(rows as usize),
cx,
);
}
RecordedSelection::None => {}
}
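// Replay the recorded steps in a spawned task so they run after the current
// keystroke completes; each step is either re-dispatched as an action or
// re-inserted as text, and EndRepeat restores normal mode when done.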
let window = cx.window();
cx.app_context()
.spawn(move |mut cx| async move {
for action in actions {
match action {
ReplayableAction::Action(action) => {
if should_replay(&action) {
window
.dispatch_action(editor.id(), action.as_ref(), &mut cx)
.ok_or_else(|| anyhow::anyhow!("window was closed"))
} else {
Ok(())
}
}
ReplayableAction::Insertion {
text,
utf16_range_to_replace,
} => editor.update(&mut cx, |editor, cx| {
editor.replay_insert_event(&text, utf16_range_to_replace.clone(), cx)
}),
}?
}
window
.dispatch_action(editor.id(), &EndRepeat, &mut cx)
.ok_or_else(|| anyhow::anyhow!("window was closed"))
})
.detach_and_log_err(cx);
});
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use editor::test::editor_lsp_test_context::EditorLspTestContext;
use futures::StreamExt;
use indoc::indoc;
use gpui::{executor::Deterministic, View};
use crate::{
state::Mode,
test::{NeovimBackedTestContext, VimTestContext},
};
#[gpui::test]
async fn test_dot_repeat(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
// "o"
cx.set_shared_state("ˇhello").await;
cx.simulate_shared_keystrokes(["o", "w", "o", "r", "l", "d", "escape"])
.await;
cx.assert_shared_state("hello\nworlˇd").await;
cx.simulate_shared_keystrokes(["."]).await;
deterministic.run_until_parked();
cx.assert_shared_state("hello\nworld\nworlˇd").await;
// "d"
cx.simulate_shared_keystrokes(["^", "d", "f", "o"]).await;
cx.simulate_shared_keystrokes(["g", "g", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state("ˇ\nworld\nrld").await;
// "p" (note that it pastes the current clipboard)
cx.simulate_shared_keystrokes(["j", "y", "y", "p"]).await;
cx.simulate_shared_keystrokes(["shift-g", "y", "y", "."])
.await;
deterministic.run_until_parked();
cx.assert_shared_state("\nworld\nworld\nrld\nˇrld").await;
// "~" (note that counts apply to the action taken, not . itself)
cx.set_shared_state("ˇthe quick brown fox").await;
cx.simulate_shared_keystrokes(["2", "~", "."]).await;
deterministic.run_until_parked();
cx.set_shared_state("THE ˇquick brown fox").await;
cx.simulate_shared_keystrokes(["3", "."]).await;
deterministic.run_until_parked();
cx.set_shared_state("THE QUIˇck brown fox").await;
deterministic.run_until_parked();
cx.simulate_shared_keystrokes(["."]).await;
deterministic.run_until_parked();
cx.set_shared_state("THE QUICK ˇbrown fox").await;
}
#[gpui::test]
async fn test_repeat_ime(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
cx.set_state("hˇllo", Mode::Normal);
cx.simulate_keystrokes(["i"]);
// Simulate Brazilian keyboard input for ä (a marked dead key, then the composed character).
cx.update_editor(|editor, cx| {
editor.replace_and_mark_text_in_range(None, "\"", Some(1..1), cx);
editor.replace_text_in_range(None, "ä", cx);
});
cx.simulate_keystrokes(["escape"]);
cx.assert_state("hˇällo", Mode::Normal);
cx.simulate_keystrokes(["."]);
deterministic.run_until_parked();
cx.assert_state("hˇäällo", Mode::Normal);
}
#[gpui::test]
async fn test_repeat_completion(
deterministic: Arc<Deterministic>,
cx: &mut gpui::TestAppContext,
) {
let cx = EditorLspTestContext::new_rust(
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
trigger_characters: Some(vec![".".to_string(), ":".to_string()]),
resolve_provider: Some(true),
..Default::default()
}),
..Default::default()
},
cx,
)
.await;
let mut cx = VimTestContext::new_with_lsp(cx, true);
cx.set_state(
indoc! {"
onˇe
two
three
"},
Mode::Normal,
);
let mut request =
cx.handle_request::<lsp::request::Completion, _, _>(move |_, params, _| async move {
let position = params.text_document_position.position;
Ok(Some(lsp::CompletionResponse::Array(vec![
lsp::CompletionItem {
label: "first".to_string(),
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
range: lsp::Range::new(position.clone(), position.clone()),
new_text: "first".to_string(),
})),
..Default::default()
},
lsp::CompletionItem {
label: "second".to_string(),
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
range: lsp::Range::new(position.clone(), position.clone()),
new_text: "second".to_string(),
})),
..Default::default()
},
])))
});
cx.simulate_keystrokes(["a", "."]);
request.next().await;
cx.condition(|editor, _| editor.context_menu_visible())
.await;
cx.simulate_keystrokes(["down", "enter", "!", "escape"]);
cx.assert_state(
indoc! {"
one.secondˇ!
two
three
"},
Mode::Normal,
);
cx.simulate_keystrokes(["j", "."]);
deterministic.run_until_parked();
cx.assert_state(
indoc! {"
one.second!
two.secondˇ!
three
"},
Mode::Normal,
);
}
#[gpui::test]
async fn test_repeat_visual(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
// single-line (3 columns)
cx.set_shared_state(indoc! {
"ˇthe quick brown
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["v", "i", "w", "s", "o", "escape"])
.await;
cx.assert_shared_state(indoc! {
"ˇo quick brown
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["j", "w", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"o quick brown
fox ˇops over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["f", "r", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"o quick brown
fox ops oveˇothe lazy dog"
})
.await;
// visual
cx.set_shared_state(indoc! {
"the ˇquick brown
fox jumps over
fox jumps over
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["v", "j", "x"]).await;
cx.assert_shared_state(indoc! {
"the ˇumps over
fox jumps over
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"the ˇumps over
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["w", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"the umps ˇumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["j", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"the umps umps over
the ˇog"
})
.await;
// block mode (3 rows)
cx.set_shared_state(indoc! {
"ˇthe quick brown
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["ctrl-v", "j", "j", "shift-i", "o", "escape"])
.await;
cx.assert_shared_state(indoc! {
"ˇothe quick brown
ofox jumps over
othe lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["j", "4", "l", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"othe quick brown
ofoxˇo jumps over
otheo lazy dog"
})
.await;
// line mode
cx.set_shared_state(indoc! {
"ˇthe quick brown
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["shift-v", "shift-r", "o", "escape"])
.await;
cx.assert_shared_state(indoc! {
"ˇo
fox jumps over
the lazy dog"
})
.await;
cx.simulate_shared_keystrokes(["j", "."]).await;
deterministic.run_until_parked();
cx.assert_shared_state(indoc! {
"o
ˇo
the lazy dog"
})
.await;
}
}


@@ -67,7 +67,8 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex
let top_anchor = editor.scroll_manager.anchor().anchor;
editor.change_selections(None, cx, |s| {
s.move_heads_with(|map, head, goal| {
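// move_with lets an empty selection simply follow the scroll (normal mode)
// while a non-empty selection keeps its tail and only moves its head (visual mode).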
s.move_with(|map, selection| {
let head = selection.head();
let top = top_anchor.to_display_point(map);
let min_row = top.row() + VERTICAL_SCROLL_MARGIN as u32;
let max_row = top.row() + visible_rows - VERTICAL_SCROLL_MARGIN as u32 - 1;
@@ -79,7 +80,11 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex
} else {
head
};
(new_head, goal)
if selection.is_empty() {
selection.collapse_to(new_head, selection.goal)
} else {
selection.set_head(new_head, selection.goal)
};
})
});
}
@@ -90,12 +95,35 @@ mod test {
use crate::{state::Mode, test::VimTestContext};
use gpui::geometry::vector::vec2f;
use indoc::indoc;
use language::Point;
#[gpui::test]
async fn test_scroll(cx: &mut gpui::TestAppContext) {
let mut cx = VimTestContext::new(cx, true).await;
cx.set_state(indoc! {"ˇa\nb\nc\nd\ne\n"}, Mode::Normal);
let window = cx.window;
let line_height =
cx.editor(|editor, cx| editor.style(cx).text.line_height(cx.font_cache()));
window.simulate_resize(vec2f(1000., 8.0 * line_height - 1.0), &mut cx);
cx.set_state(
indoc!(
"ˇone
two
three
four
five
six
seven
eight
nine
ten
eleven
twelve
"
),
Mode::Normal,
);
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
@@ -112,5 +140,33 @@ mod test {
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.))
});
// does not select in normal mode
cx.simulate_keystrokes(["g", "g"]);
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
});
cx.simulate_keystrokes(["ctrl-d"]);
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0));
assert_eq!(
editor.selections.newest(cx).range(),
Point::new(5, 0)..Point::new(5, 0)
)
});
// does select in visual mode
cx.simulate_keystrokes(["g", "g"]);
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
});
cx.simulate_keystrokes(["v", "ctrl-d"]);
cx.update_editor(|editor, cx| {
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0));
assert_eq!(
editor.selections.newest(cx).range(),
Point::new(0, 0)..Point::new(5, 1)
)
});
}
}


@@ -1,10 +1,34 @@
use gpui::WindowContext;
use editor::movement;
use gpui::{actions, AppContext, WindowContext};
use language::Point;
use workspace::Workspace;
use crate::{motion::Motion, utils::copy_selections_content, Mode, Vim};
pub fn substitute(vim: &mut Vim, count: Option<usize>, cx: &mut WindowContext) {
let line_mode = vim.state().mode == Mode::VisualLine;
actions!(vim, [Substitute, SubstituteLine]);
pub(crate) fn init(cx: &mut AppContext) {
cx.add_action(|_: &mut Workspace, _: &Substitute, cx| {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
let count = vim.pop_number_operator(cx);
substitute(vim, count, vim.state().mode == Mode::VisualLine, cx);
})
});
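// shift-s substitutes whole lines; from characterwise or blockwise visual mode
// we first switch to linewise visual so complete lines are replaced.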
cx.add_action(|_: &mut Workspace, _: &SubstituteLine, cx| {
Vim::update(cx, |vim, cx| {
vim.start_recording(cx);
if matches!(vim.state().mode, Mode::VisualBlock | Mode::Visual) {
vim.switch_mode(Mode::VisualLine, false, cx)
}
let count = vim.pop_number_operator(cx);
substitute(vim, count, true, cx)
})
});
}
pub fn substitute(vim: &mut Vim, count: Option<usize>, line_mode: bool, cx: &mut WindowContext) {
vim.update_active_editor(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
editor.transact(cx, |editor, cx| {
@@ -14,6 +38,11 @@ pub fn substitute(vim: &mut Vim, count: Option<usize>, cx: &mut WindowContext) {
Motion::Right.expand_selection(map, selection, count, true);
}
if line_mode {
// In visual mode, when the selection includes the newline at the end of the
// line, exclude it so the line below is not pulled into the substitution.
if !selection.is_empty() && selection.end.column() == 0 {
selection.end = movement::left(map, selection.end);
}
Motion::CurrentLine.expand_selection(map, selection, None, false);
if let Some((point, _)) = (Motion::FirstNonWhitespace {
display_lines: false,
@@ -166,4 +195,68 @@ mod test {
the laˇzy dog"})
.await;
}
#[gpui::test]
async fn test_substitute_line(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
let initial_state = indoc! {"
The quick brown
fox juˇmps over
the lazy dog
"};
// normal mode
cx.set_shared_state(initial_state).await;
cx.simulate_shared_keystrokes(["shift-s", "o"]).await;
cx.assert_shared_state(indoc! {"
The quick brown
the lazy dog
"})
.await;
// visual mode
cx.set_shared_state(initial_state).await;
cx.simulate_shared_keystrokes(["v", "k", "shift-s", "o"])
.await;
cx.assert_shared_state(indoc! {"
the lazy dog
"})
.await;
// visual block mode
cx.set_shared_state(initial_state).await;
cx.simulate_shared_keystrokes(["ctrl-v", "j", "shift-s", "o"])
.await;
cx.assert_shared_state(indoc! {"
The quick brown
"})
.await;
// visual mode including newline
cx.set_shared_state(initial_state).await;
cx.simulate_shared_keystrokes(["v", "$", "shift-s", "o"])
.await;
cx.assert_shared_state(indoc! {"
The quick brown
the lazy dog
"})
.await;
// indentation
cx.set_neovim_option("shiftwidth=4").await;
cx.set_shared_state(initial_state).await;
cx.simulate_shared_keystrokes([">", ">", "shift-s", "o"])
.await;
cx.assert_shared_state(indoc! {"
The quick brown
the lazy dog
"})
.await;
}
}


@@ -1,6 +1,11 @@
use std::ops::Range;
use editor::{char_kind, display_map::DisplaySnapshot, movement, Bias, CharKind, DisplayPoint};
use editor::{
char_kind,
display_map::DisplaySnapshot,
movement::{self, FindRange},
Bias, CharKind, DisplayPoint,
};
use gpui::{actions, impl_actions, AppContext, WindowContext};
use language::Selection;
use serde::Deserialize;
@@ -177,18 +182,22 @@ fn in_word(
ignore_punctuation: bool,
) -> Option<Range<DisplayPoint>> {
// Use motion::right so that we consider the character under the cursor when looking for the start
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
let start = movement::find_preceding_boundary_in_line(
let scope = map
.buffer_snapshot
.language_scope_at(relative_to.to_point(map));
let start = movement::find_preceding_boundary(
map,
right(map, relative_to, 1),
movement::FindRange::SingleLine,
|left, right| {
char_kind(language, left).coerce_punctuation(ignore_punctuation)
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
},
);
let end = movement::find_boundary_in_line(map, relative_to, |left, right| {
char_kind(language, left).coerce_punctuation(ignore_punctuation)
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
let end = movement::find_boundary(map, relative_to, FindRange::SingleLine, |left, right| {
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
});
Some(start..end)
@@ -211,11 +220,13 @@ fn around_word(
relative_to: DisplayPoint,
ignore_punctuation: bool,
) -> Option<Range<DisplayPoint>> {
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
let scope = map
.buffer_snapshot
.language_scope_at(relative_to.to_point(map));
let in_word = map
.chars_at(relative_to)
.next()
.map(|(c, _)| char_kind(language, c) != CharKind::Whitespace)
.map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace)
.unwrap_or(false);
if in_word {
@@ -239,21 +250,24 @@ fn around_next_word(
relative_to: DisplayPoint,
ignore_punctuation: bool,
) -> Option<Range<DisplayPoint>> {
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
let scope = map
.buffer_snapshot
.language_scope_at(relative_to.to_point(map));
// Get the start of the word
let start = movement::find_preceding_boundary_in_line(
let start = movement::find_preceding_boundary(
map,
right(map, relative_to, 1),
FindRange::SingleLine,
|left, right| {
char_kind(language, left).coerce_punctuation(ignore_punctuation)
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
},
);
let mut word_found = false;
let end = movement::find_boundary(map, relative_to, |left, right| {
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
let end = movement::find_boundary(map, relative_to, FindRange::MultiLine, |left, right| {
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
let found = (word_found && left_kind != right_kind) || right == '\n' && left == '\n';
@@ -566,11 +580,18 @@ mod test {
async fn test_visual_word_object(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
cx.set_shared_state("The quick ˇbrown\nfox").await;
/*
cx.set_shared_state("The quick ˇbrown\nfox").await;
cx.simulate_shared_keystrokes(["v"]).await;
cx.assert_shared_state("The quick «bˇ»rown\nfox").await;
cx.simulate_shared_keystrokes(["i", "w"]).await;
cx.assert_shared_state("The quick «brownˇ»\nfox").await;
*/
cx.set_shared_state("The quick brown\nˇ\nfox").await;
cx.simulate_shared_keystrokes(["v"]).await;
cx.assert_shared_state("The quick «bˇ»rown\nfox").await;
cx.assert_shared_state("The quick brown\n«\nˇ»fox").await;
cx.simulate_shared_keystrokes(["i", "w"]).await;
cx.assert_shared_state("The quick «brownˇ»\nfox").await;
cx.assert_shared_state("The quick brown\n«\nˇ»fox").await;
cx.assert_binding_matches_all(["v", "i", "w"], WORD_LOCATIONS)
.await;


@@ -1,4 +1,6 @@
use gpui::keymap_matcher::KeymapContext;
use std::{ops::Range, sync::Arc};
use gpui::{keymap_matcher::KeymapContext, Action};
use language::CursorShape;
use serde::{Deserialize, Serialize};
use workspace::searchable::Direction;
@@ -48,10 +50,61 @@ pub struct EditorState {
pub operator_stack: Vec<Operator>,
}
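// The shape of the selection that was active when a change was recorded, kept
// as row/column deltas so dot-repeat can rebuild an equivalent selection before
// replaying.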
#[derive(Default, Clone, Debug)]
pub enum RecordedSelection {
#[default]
None,
Visual {
rows: u32,
cols: u32,
},
SingleLine {
cols: u32,
},
VisualBlock {
rows: u32,
cols: u32,
},
VisualLine {
rows: u32,
},
}
#[derive(Default, Clone)]
pub struct WorkspaceState {
pub search: SearchState,
pub last_find: Option<Motion>,
pub recording: bool,
pub stop_recording_after_next_action: bool,
pub replaying: bool,
pub recorded_count: Option<usize>,
pub recorded_actions: Vec<ReplayableAction>,
pub recorded_selection: RecordedSelection,
}
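// One recorded step of a change: either a dispatched action or raw text
// inserted by the editor/IME, replayed via replay_insert_event.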
#[derive(Debug)]
pub enum ReplayableAction {
Action(Box<dyn Action>),
Insertion {
text: Arc<str>,
utf16_range_to_replace: Option<Range<isize>>,
},
}
impl Clone for ReplayableAction {
fn clone(&self) -> Self {
match self {
Self::Action(action) => Self::Action(action.boxed_clone()),
Self::Insertion {
text,
utf16_range_to_replace,
} => Self::Insertion {
text: text.clone(),
utf16_range_to_replace: utf16_range_to_replace.clone(),
},
}
}
}
#[derive(Clone)]


@@ -286,6 +286,55 @@ async fn test_word_characters(cx: &mut gpui::TestAppContext) {
)
}
#[gpui::test]
async fn test_join_lines(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
cx.set_shared_state(indoc! {"
ˇone
two
three
four
five
six
"})
.await;
cx.simulate_shared_keystrokes(["shift-j"]).await;
cx.assert_shared_state(indoc! {"
oneˇ two
three
four
five
six
"})
.await;
cx.simulate_shared_keystrokes(["3", "shift-j"]).await;
cx.assert_shared_state(indoc! {"
one two threeˇ four
five
six
"})
.await;
cx.set_shared_state(indoc! {"
ˇone
two
three
four
five
six
"})
.await;
cx.simulate_shared_keystrokes(["j", "v", "3", "j", "shift-j"])
.await;
cx.assert_shared_state(indoc! {"
one
two three fourˇ five
six
"})
.await;
}
#[gpui::test]
async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -431,6 +480,31 @@ async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) {
twelve char
"})
.await;
// line wraps as:
// fourteen ch
// ar
// fourteen ch
// ar
cx.set_shared_state(indoc! { "
fourteen chaˇr
fourteen char
"})
.await;
cx.simulate_shared_keystrokes(["d", "i", "w"]).await;
cx.assert_shared_state(indoc! {"
fourteenˇ•
fourteen char
"})
.await;
cx.simulate_shared_keystrokes(["j", "shift-f", "e", "f", "r"])
.await;
cx.assert_shared_state(indoc! {"
fourteen•
fourteen chaˇr
"})
.await;
}
#[gpui::test]


@@ -153,6 +153,7 @@ impl<'a> NeovimBackedTestContext<'a> {
}
pub async fn assert_shared_state(&mut self, marked_text: &str) {
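// "•" marks a trailing space in shared test fixtures; translate it back to a
// real space before comparing against neovim and the editor state.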
let marked_text = marked_text.replace("•", " ");
let neovim = self.neovim_state().await;
let editor = self.editor_state();
if neovim == marked_text && neovim == editor {
@@ -184,9 +185,9 @@ impl<'a> NeovimBackedTestContext<'a> {
message,
initial_state,
self.recent_keystrokes.join(" "),
marked_text,
neovim,
editor
marked_text.replace(" \n", "\n"),
neovim.replace(" \n", "\n"),
editor.replace(" \n", "\n")
)
}

Some files were not shown because too many files have changed in this diff.