Compare commits
157 Commits
gpui_extra
...
channel-da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a1f13645a | ||
|
|
bf296ebbd7 | ||
|
|
114961fc69 | ||
|
|
80f5e66efc | ||
|
|
b7172d5e0d | ||
|
|
1ab2007fcd | ||
|
|
441848d195 | ||
|
|
273fa9dd22 | ||
|
|
e0602da8df | ||
|
|
65b795c213 | ||
|
|
fe10ecebb6 | ||
|
|
7cc05c99c2 | ||
|
|
e29ce489c8 | ||
|
|
4c92172cca | ||
|
|
ba1c350dad | ||
|
|
5d782b6cf0 | ||
|
|
88dae22e3e | ||
|
|
f069cd0485 | ||
|
|
e1d4d911b4 | ||
|
|
a0701777d5 | ||
|
|
f4a9d3f269 | ||
|
|
87472a9de6 | ||
|
|
5f897f45a8 | ||
|
|
74ccb3df63 | ||
|
|
e9747d0fea | ||
|
|
ddc8a126da | ||
|
|
6ad2ec4825 | ||
|
|
4e818fed4a | ||
|
|
e7b7ac9d8c | ||
|
|
56d9a578bd | ||
|
|
5b0f4ac9e8 | ||
|
|
4d2933a4d7 | ||
|
|
560d6b1644 | ||
|
|
a6ce382368 | ||
|
|
cf5d1d91a4 | ||
|
|
98999b1e9a | ||
|
|
eda7e00645 | ||
|
|
8e2e00e003 | ||
|
|
47d7aa0b91 | ||
|
|
65e17e212d | ||
|
|
48bb2a3321 | ||
|
|
1b1d7f22cc | ||
|
|
1969a12a0b | ||
|
|
3b784668c0 | ||
|
|
a45c8c380f | ||
|
|
757a285852 | ||
|
|
93b889a93b | ||
|
|
3ad1befb11 | ||
|
|
425a3969c8 | ||
|
|
39e13b6675 | ||
|
|
d03a89ca19 | ||
|
|
58f58a629b | ||
|
|
ed2aed4f93 | ||
|
|
b75e69d31b | ||
|
|
e779adfe46 | ||
|
|
66c3879306 | ||
|
|
f22d53eef9 | ||
|
|
20f98e4d17 | ||
|
|
bbeb82f884 | ||
|
|
265d02a583 | ||
|
|
17237f748c | ||
|
|
f4237ace40 | ||
|
|
5b5c232cd1 | ||
|
|
15609b4803 | ||
|
|
29e35531af | ||
|
|
a2e91e45d9 | ||
|
|
246b699bfd | ||
|
|
8d672f5d4c | ||
|
|
ce62173534 | ||
|
|
de0f53b39f | ||
|
|
c802680084 | ||
|
|
9272e9354a | ||
|
|
653d4976cd | ||
|
|
ec5ff20b4c | ||
|
|
49af2874bb | ||
|
|
c2c04616b4 | ||
|
|
27143e2fb4 | ||
|
|
95b72a73ad | ||
|
|
3c70b127bd | ||
|
|
4855063151 | ||
|
|
e2479a7172 | ||
|
|
6b1dc63fc0 | ||
|
|
7b5a41dda2 | ||
|
|
d4cff68475 | ||
|
|
42976b6014 | ||
|
|
56db21d54b | ||
|
|
55dd0b176c | ||
|
|
3a7b551e33 | ||
|
|
6827ddf97d | ||
|
|
e6babce556 | ||
|
|
d7e4cb4ab1 | ||
|
|
d370c72fbf | ||
|
|
8dbc0fe033 | ||
|
|
da16167db1 | ||
|
|
af12977d17 | ||
|
|
aa7b65bbaf | ||
|
|
0e41c6c5b3 | ||
|
|
6d7949654b | ||
|
|
54235f4fb1 | ||
|
|
e86964eb5d | ||
|
|
524533cfb2 | ||
|
|
c4db914f0a | ||
|
|
2bf417fa45 | ||
|
|
d868ec920f | ||
|
|
7bcc59c8a5 | ||
|
|
1e60454643 | ||
|
|
03f0365d4d | ||
|
|
afa59abbcd | ||
|
|
00aae5abee | ||
|
|
eecd4e39cc | ||
|
|
50cfb067e7 | ||
|
|
220533ff1a | ||
|
|
2503d54d19 | ||
|
|
3001a46f69 | ||
|
|
fe2300fdaa | ||
|
|
7b5974e8e9 | ||
|
|
c763e728d1 | ||
|
|
35440be98e | ||
|
|
ddc6214216 | ||
|
|
5731ef51cd | ||
|
|
e682db7101 | ||
|
|
5bc5831032 | ||
|
|
292af55ebc | ||
|
|
fff385a585 | ||
|
|
9e12df43d0 | ||
|
|
ff3865a4ad | ||
|
|
529adb95a1 | ||
|
|
7d4d6c871b | ||
|
|
5abad58b0d | ||
|
|
76ce52df4e | ||
|
|
9781047156 | ||
|
|
76caea80f7 | ||
|
|
7e5735c8f1 | ||
|
|
e377ada1a9 | ||
|
|
d3650594c3 | ||
|
|
e3a0252b04 | ||
|
|
a7e6a65deb | ||
|
|
4f8b95cf0d | ||
|
|
0e6c91818f | ||
|
|
2d411303bb | ||
|
|
15628af04b | ||
|
|
35b7787e02 | ||
|
|
ded6decb29 | ||
|
|
fc457d45f5 | ||
|
|
a394aaa524 | ||
|
|
68408f3838 | ||
|
|
affb73d651 | ||
|
|
814896de3f | ||
|
|
a35b3f39c5 | ||
|
|
007d1b09ac | ||
|
|
c842e87079 | ||
|
|
a979e32127 | ||
|
|
4f0fa21c04 | ||
|
|
e54f16f372 | ||
|
|
8839b07a25 | ||
|
|
40ce099780 | ||
|
|
7a67ec5743 |
77
Cargo.lock
generated
77
Cargo.lock
generated
@@ -1453,9 +1453,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "collab"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
"audio",
|
||||
"axum",
|
||||
@@ -3539,7 +3540,7 @@ dependencies = [
|
||||
"gif",
|
||||
"jpeg-decoder",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-rational 0.3.2",
|
||||
"num-traits",
|
||||
"png",
|
||||
"scoped_threadpool",
|
||||
@@ -4177,8 +4178,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "lsp-types"
|
||||
version = "0.94.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1"
|
||||
source = "git+https://github.com/zed-industries/lsp-types?branch=updated-completion-list-item-defaults#90a040a1d195687bd19e1df47463320a44e93d7a"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"serde",
|
||||
@@ -4583,6 +4583,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-compression",
|
||||
"async-tar",
|
||||
"async-trait",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"log",
|
||||
@@ -4632,6 +4633,31 @@ dependencies = [
|
||||
"winapi 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
|
||||
dependencies = [
|
||||
"num-bigint 0.2.6",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational 0.2.4",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.4"
|
||||
@@ -4660,6 +4686,16 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.3.3"
|
||||
@@ -4692,6 +4728,18 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-bigint 0.2.6",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.3.2"
|
||||
@@ -5008,6 +5056,17 @@ dependencies = [
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parse_duration"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"num",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "password-hash"
|
||||
version = "0.2.3"
|
||||
@@ -6663,6 +6722,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"collections",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.9.3",
|
||||
@@ -6675,6 +6735,7 @@ dependencies = [
|
||||
"log",
|
||||
"matrixmultiply",
|
||||
"parking_lot 0.11.2",
|
||||
"parse_duration",
|
||||
"picker",
|
||||
"postage",
|
||||
"pretty_assertions",
|
||||
@@ -7006,7 +7067,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"num-bigint",
|
||||
"num-bigint 0.4.4",
|
||||
"num-traits",
|
||||
"thiserror",
|
||||
]
|
||||
@@ -7238,7 +7299,7 @@ dependencies = [
|
||||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"num-bigint",
|
||||
"num-bigint 0.4.4",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"percent-encoding",
|
||||
@@ -8768,12 +8829,14 @@ dependencies = [
|
||||
"collections",
|
||||
"command_palette",
|
||||
"editor",
|
||||
"futures 0.3.28",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"itertools",
|
||||
"language",
|
||||
"language_selector",
|
||||
"log",
|
||||
"lsp",
|
||||
"nvim-rs",
|
||||
"parking_lot 0.11.2",
|
||||
"project",
|
||||
@@ -9702,7 +9765,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.103.0"
|
||||
version = "0.104.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"ai",
|
||||
|
||||
37
README.md
37
README.md
@@ -8,7 +8,31 @@ Welcome to Zed, a lightning-fast, collaborative code editor that makes your drea
|
||||
|
||||
### Dependencies
|
||||
|
||||
* Install [Postgres.app](https://postgresapp.com) and start it.
|
||||
* Install Xcode from https://apps.apple.com/us/app/xcode/id497799835?mt=12, and accept the license:
|
||||
```
|
||||
sudo xcodebuild -license
|
||||
```
|
||||
|
||||
* Install homebrew, rust and node
|
||||
```
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
brew install rust
|
||||
brew install node
|
||||
```
|
||||
|
||||
* Ensure rust executables are in your $PATH
|
||||
```
|
||||
echo $HOME/.cargo/bin | sudo tee /etc/paths.d/10-rust
|
||||
```
|
||||
|
||||
* Install postgres and configure the database
|
||||
```
|
||||
brew install postgresql@15
|
||||
brew services start postgresql@15
|
||||
psql -c "CREATE ROLE postgres SUPERUSER LOGIN" postgres
|
||||
psql -U postgres -c "CREATE DATABASE zed"
|
||||
```
|
||||
|
||||
* Install the `LiveKit` server and the `foreman` process supervisor:
|
||||
|
||||
```
|
||||
@@ -41,6 +65,17 @@ Welcome to Zed, a lightning-fast, collaborative code editor that makes your drea
|
||||
GITHUB_TOKEN=<$token> script/bootstrap
|
||||
```
|
||||
|
||||
* Now try running zed with collaboration disabled:
|
||||
```
|
||||
cargo run
|
||||
```
|
||||
|
||||
### Common errors
|
||||
|
||||
* `xcrun: error: unable to find utility "metal", not a developer tool or in PATH`
|
||||
* You need to install Xcode and then run: `xcode-select --switch /Applications/Xcode.app/Contents/Developer`
|
||||
* (see https://github.com/gfx-rs/gfx/issues/2309)
|
||||
|
||||
### Testing against locally-running servers
|
||||
|
||||
Start the web and collab servers:
|
||||
|
||||
@@ -515,6 +515,17 @@
|
||||
"enter": "editor::ConfirmCodeAction"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && (showing_code_actions || showing_completions)",
|
||||
"bindings": {
|
||||
"up": "editor::ContextMenuPrev",
|
||||
"ctrl-p": "editor::ContextMenuPrev",
|
||||
"down": "editor::ContextMenuNext",
|
||||
"ctrl-n": "editor::ContextMenuNext",
|
||||
"pageup": "editor::ContextMenuFirst",
|
||||
"pagedown": "editor::ContextMenuLast"
|
||||
}
|
||||
},
|
||||
// Custom bindings
|
||||
{
|
||||
"bindings": {
|
||||
|
||||
@@ -198,6 +198,18 @@
|
||||
"z c": "editor::Fold",
|
||||
"z o": "editor::UnfoldLines",
|
||||
"z f": "editor::FoldSelectedRanges",
|
||||
"shift-z shift-q": [
|
||||
"pane::CloseActiveItem",
|
||||
{
|
||||
"saveBehavior": "dontSave"
|
||||
}
|
||||
],
|
||||
"shift-z shift-z": [
|
||||
"pane::CloseActiveItem",
|
||||
{
|
||||
"saveBehavior": "promptOnConflict"
|
||||
}
|
||||
],
|
||||
// Count support
|
||||
"1": [
|
||||
"vim::Number",
|
||||
@@ -316,6 +328,7 @@
|
||||
{
|
||||
"context": "Editor && vim_mode == normal && (vim_operator == none || vim_operator == n) && !VimWaiting",
|
||||
"bindings": {
|
||||
".": "vim::Repeat",
|
||||
"c": [
|
||||
"vim::PushOperator",
|
||||
"Change"
|
||||
@@ -326,15 +339,12 @@
|
||||
"Delete"
|
||||
],
|
||||
"shift-d": "vim::DeleteToEndOfLine",
|
||||
"shift-j": "editor::JoinLines",
|
||||
"shift-j": "vim::JoinLines",
|
||||
"y": [
|
||||
"vim::PushOperator",
|
||||
"Yank"
|
||||
],
|
||||
"i": [
|
||||
"vim::SwitchMode",
|
||||
"Insert"
|
||||
],
|
||||
"i": "vim::InsertBefore",
|
||||
"shift-i": "vim::InsertFirstNonWhitespace",
|
||||
"a": "vim::InsertAfter",
|
||||
"shift-a": "vim::InsertEndOfLine",
|
||||
@@ -371,6 +381,7 @@
|
||||
"Replace"
|
||||
],
|
||||
"s": "vim::Substitute",
|
||||
"shift-s": "vim::SubstituteLine",
|
||||
"> >": "editor::Indent",
|
||||
"< <": "editor::Outdent",
|
||||
"ctrl-pagedown": "pane::ActivateNextItem",
|
||||
@@ -446,13 +457,13 @@
|
||||
}
|
||||
],
|
||||
"s": "vim::Substitute",
|
||||
"shift-s": "vim::SubstituteLine",
|
||||
"shift-r": "vim::SubstituteLine",
|
||||
"c": "vim::Substitute",
|
||||
"~": "vim::ChangeCase",
|
||||
"shift-i": [
|
||||
"vim::SwitchMode",
|
||||
"Insert"
|
||||
],
|
||||
"shift-i": "vim::InsertBefore",
|
||||
"shift-a": "vim::InsertAfter",
|
||||
"shift-j": "vim::JoinLines",
|
||||
"r": [
|
||||
"vim::PushOperator",
|
||||
"Replace"
|
||||
|
||||
@@ -406,36 +406,30 @@ impl AssistantPanel {
|
||||
_: &editor::Cancel,
|
||||
cx: &mut ViewContext<Workspace>,
|
||||
) {
|
||||
let panel = if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
panel
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
let editor = if let Some(editor) = workspace
|
||||
.active_item(cx)
|
||||
.and_then(|item| item.downcast::<Editor>())
|
||||
{
|
||||
editor
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
let handled = panel.update(cx, |panel, cx| {
|
||||
if let Some(assist_id) = panel
|
||||
.pending_inline_assist_ids_by_editor
|
||||
.get(&editor.downgrade())
|
||||
.and_then(|assist_ids| assist_ids.last().copied())
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(editor) = workspace
|
||||
.active_item(cx)
|
||||
.and_then(|item| item.downcast::<Editor>())
|
||||
{
|
||||
panel.close_inline_assist(assist_id, true, cx);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
let handled = panel.update(cx, |panel, cx| {
|
||||
if let Some(assist_id) = panel
|
||||
.pending_inline_assist_ids_by_editor
|
||||
.get(&editor.downgrade())
|
||||
.and_then(|assist_ids| assist_ids.last().copied())
|
||||
{
|
||||
panel.close_inline_assist(assist_id, true, cx);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
if handled {
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if !handled {
|
||||
cx.propagate_action();
|
||||
}
|
||||
|
||||
cx.propagate_action();
|
||||
}
|
||||
|
||||
fn close_inline_assist(&mut self, assist_id: usize, undo: bool, cx: &mut ViewContext<Self>) {
|
||||
@@ -513,10 +507,13 @@ impl AssistantPanel {
|
||||
return;
|
||||
};
|
||||
|
||||
self.inline_prompt_history
|
||||
.retain(|prompt| prompt != user_prompt);
|
||||
self.inline_prompt_history.push_back(user_prompt.into());
|
||||
if self.inline_prompt_history.len() > Self::INLINE_PROMPT_HISTORY_MAX_LEN {
|
||||
self.inline_prompt_history.pop_front();
|
||||
}
|
||||
|
||||
let range = pending_assist.range.clone();
|
||||
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
|
||||
let selected_text = snapshot
|
||||
|
||||
@@ -273,7 +273,13 @@ impl ActiveCall {
|
||||
.borrow_mut()
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("no incoming call"))?;
|
||||
Self::report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx);
|
||||
Self::report_call_event_for_room(
|
||||
"decline incoming",
|
||||
Some(call.room_id),
|
||||
None,
|
||||
&self.client,
|
||||
cx,
|
||||
);
|
||||
self.client.send(proto::DeclineCall {
|
||||
room_id: call.room_id,
|
||||
})?;
|
||||
@@ -404,21 +410,19 @@ impl ActiveCall {
|
||||
}
|
||||
|
||||
fn report_call_event(&self, operation: &'static str, cx: &AppContext) {
|
||||
if let Some(room) = self.room() {
|
||||
let room = room.read(cx);
|
||||
Self::report_call_event_for_room(
|
||||
operation,
|
||||
room.id(),
|
||||
room.channel_id(),
|
||||
&self.client,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
let (room_id, channel_id) = match self.room() {
|
||||
Some(room) => {
|
||||
let room = room.read(cx);
|
||||
(Some(room.id()), room.channel_id())
|
||||
}
|
||||
None => (None, None),
|
||||
};
|
||||
Self::report_call_event_for_room(operation, room_id, channel_id, &self.client, cx)
|
||||
}
|
||||
|
||||
pub fn report_call_event_for_room(
|
||||
operation: &'static str,
|
||||
room_id: u64,
|
||||
room_id: Option<u64>,
|
||||
channel_id: Option<u64>,
|
||||
client: &Arc<Client>,
|
||||
cx: &AppContext,
|
||||
|
||||
@@ -10,6 +10,7 @@ pub(crate) fn init(client: &Arc<Client>) {
|
||||
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer);
|
||||
client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator);
|
||||
client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator);
|
||||
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator);
|
||||
}
|
||||
|
||||
pub struct ChannelBuffer {
|
||||
@@ -17,6 +18,7 @@ pub struct ChannelBuffer {
|
||||
connected: bool,
|
||||
collaborators: Vec<proto::Collaborator>,
|
||||
buffer: ModelHandle<language::Buffer>,
|
||||
buffer_epoch: u64,
|
||||
client: Arc<Client>,
|
||||
subscription: Option<client::Subscription>,
|
||||
}
|
||||
@@ -73,6 +75,7 @@ impl ChannelBuffer {
|
||||
|
||||
Self {
|
||||
buffer,
|
||||
buffer_epoch: response.epoch,
|
||||
client,
|
||||
connected: true,
|
||||
collaborators,
|
||||
@@ -82,6 +85,26 @@ impl ChannelBuffer {
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn replace_collaborators(
|
||||
&mut self,
|
||||
collaborators: Vec<proto::Collaborator>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
for old_collaborator in &self.collaborators {
|
||||
if collaborators
|
||||
.iter()
|
||||
.any(|c| c.replica_id == old_collaborator.replica_id)
|
||||
{
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.remove_peer(old_collaborator.replica_id as u16, cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
self.collaborators = collaborators;
|
||||
cx.emit(Event::CollaboratorsChanged);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
async fn handle_update_channel_buffer(
|
||||
this: ModelHandle<Self>,
|
||||
update_channel_buffer: TypedEnvelope<proto::UpdateChannelBuffer>,
|
||||
@@ -149,6 +172,26 @@ impl ChannelBuffer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_update_channel_buffer_collaborator(
|
||||
this: ModelHandle<Self>,
|
||||
message: TypedEnvelope<proto::UpdateChannelBufferCollaborator>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
for collaborator in &mut this.collaborators {
|
||||
if collaborator.peer_id == message.payload.old_peer_id {
|
||||
collaborator.peer_id = message.payload.new_peer_id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
cx.emit(Event::CollaboratorsChanged);
|
||||
cx.notify();
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn on_buffer_update(
|
||||
&mut self,
|
||||
_: ModelHandle<language::Buffer>,
|
||||
@@ -166,6 +209,10 @@ impl ChannelBuffer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn epoch(&self) -> u64 {
|
||||
self.buffer_epoch
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> ModelHandle<language::Buffer> {
|
||||
self.buffer.clone()
|
||||
}
|
||||
@@ -179,6 +226,7 @@ impl ChannelBuffer {
|
||||
}
|
||||
|
||||
pub(crate) fn disconnect(&mut self, cx: &mut ModelContext<Self>) {
|
||||
log::info!("channel buffer {} disconnected", self.channel.id);
|
||||
if self.connected {
|
||||
self.connected = false;
|
||||
self.subscription.take();
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
mod channel_index;
|
||||
|
||||
use crate::channel_buffer::ChannelBuffer;
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{Client, Status, Subscription, User, UserId, UserStore};
|
||||
use client::{Client, Subscription, User, UserId, UserStore};
|
||||
use collections::{hash_map, HashMap, HashSet};
|
||||
use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
|
||||
use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
|
||||
use rpc::{proto, TypedEnvelope};
|
||||
use std::sync::Arc;
|
||||
use std::{mem, sync::Arc, time::Duration};
|
||||
use util::ResultExt;
|
||||
|
||||
use self::channel_index::ChannelIndex;
|
||||
pub use self::channel_index::ChannelPath;
|
||||
|
||||
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
pub type ChannelId = u64;
|
||||
|
||||
pub struct ChannelStore {
|
||||
channels_by_id: HashMap<ChannelId, Arc<Channel>>,
|
||||
channel_paths: Vec<Vec<ChannelId>>,
|
||||
channel_index: ChannelIndex,
|
||||
channel_invitations: Vec<Arc<Channel>>,
|
||||
channel_participants: HashMap<ChannelId, Vec<Arc<User>>>,
|
||||
channels_with_admin_privileges: HashSet<ChannelId>,
|
||||
@@ -22,7 +28,8 @@ pub struct ChannelStore {
|
||||
client: Arc<Client>,
|
||||
user_store: ModelHandle<UserStore>,
|
||||
_rpc_subscription: Subscription,
|
||||
_watch_connection_status: Task<()>,
|
||||
_watch_connection_status: Task<Option<()>>,
|
||||
disconnect_channel_buffers_task: Option<Task<()>>,
|
||||
_update_channels: Task<()>,
|
||||
}
|
||||
|
||||
@@ -67,30 +74,25 @@ impl ChannelStore {
|
||||
let rpc_subscription =
|
||||
client.add_message_handler(cx.handle(), Self::handle_update_channels);
|
||||
|
||||
let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
|
||||
let mut connection_status = client.status();
|
||||
let (update_channels_tx, mut update_channels_rx) = mpsc::unbounded();
|
||||
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
|
||||
while let Some(status) = connection_status.next().await {
|
||||
if !status.is_connected() {
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if matches!(status, Status::ConnectionLost | Status::SignedOut) {
|
||||
this.handle_disconnect(cx);
|
||||
} else {
|
||||
this.disconnect_buffers(cx);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
let this = this.upgrade(&cx)?;
|
||||
if status.is_connected() {
|
||||
this.update(&mut cx, |this, cx| this.handle_connect(cx))
|
||||
.await
|
||||
.log_err()?;
|
||||
} else {
|
||||
this.update(&mut cx, |this, cx| this.handle_disconnect(cx));
|
||||
}
|
||||
}
|
||||
Some(())
|
||||
});
|
||||
|
||||
Self {
|
||||
channels_by_id: HashMap::default(),
|
||||
channel_invitations: Vec::default(),
|
||||
channel_paths: Vec::default(),
|
||||
channel_index: ChannelIndex::default(),
|
||||
channel_participants: Default::default(),
|
||||
channels_with_admin_privileges: Default::default(),
|
||||
outgoing_invites: Default::default(),
|
||||
@@ -100,6 +102,7 @@ impl ChannelStore {
|
||||
user_store,
|
||||
_rpc_subscription: rpc_subscription,
|
||||
_watch_connection_status: watch_connection_status,
|
||||
disconnect_channel_buffers_task: None,
|
||||
_update_channels: cx.spawn_weak(|this, mut cx| async move {
|
||||
while let Some(update_channels) = update_channels_rx.next().await {
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
@@ -116,7 +119,7 @@ impl ChannelStore {
|
||||
}
|
||||
|
||||
pub fn has_children(&self, channel_id: ChannelId) -> bool {
|
||||
self.channel_paths.iter().any(|path| {
|
||||
self.channel_index.iter().any(|path| {
|
||||
if let Some(ix) = path.iter().position(|id| *id == channel_id) {
|
||||
path.len() > ix + 1
|
||||
} else {
|
||||
@@ -126,22 +129,23 @@ impl ChannelStore {
|
||||
}
|
||||
|
||||
pub fn channel_count(&self) -> usize {
|
||||
self.channel_paths.len()
|
||||
self.channel_index.len()
|
||||
}
|
||||
|
||||
pub fn channels(&self) -> impl '_ + Iterator<Item = (usize, &Arc<Channel>)> {
|
||||
self.channel_paths.iter().map(move |path| {
|
||||
self.channel_index.iter().map(move |path| {
|
||||
let id = path.last().unwrap();
|
||||
let channel = self.channel_for_id(*id).unwrap();
|
||||
(path.len() - 1, channel)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn channel_at_index(&self, ix: usize) -> Option<(usize, &Arc<Channel>)> {
|
||||
let path = self.channel_paths.get(ix)?;
|
||||
pub fn channel_at_index(&self, ix: usize) -> Option<(&Arc<Channel>, &ChannelPath)> {
|
||||
let path = self.channel_index.get(ix)?;
|
||||
let id = path.last().unwrap();
|
||||
let channel = self.channel_for_id(*id).unwrap();
|
||||
Some((path.len() - 1, channel))
|
||||
|
||||
Some((channel, path))
|
||||
}
|
||||
|
||||
pub fn channel_invitations(&self) -> &[Arc<Channel>] {
|
||||
@@ -149,7 +153,16 @@ impl ChannelStore {
|
||||
}
|
||||
|
||||
pub fn channel_for_id(&self, channel_id: ChannelId) -> Option<&Arc<Channel>> {
|
||||
self.channels_by_id.get(&channel_id)
|
||||
self.channel_index.by_id().get(&channel_id)
|
||||
}
|
||||
|
||||
pub fn has_open_channel_buffer(&self, channel_id: ChannelId, cx: &AppContext) -> bool {
|
||||
if let Some(buffer) = self.opened_buffers.get(&channel_id) {
|
||||
if let OpenedChannelBuffer::Open(buffer) = buffer {
|
||||
return buffer.upgrade(cx).is_some();
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn open_channel_buffer(
|
||||
@@ -221,7 +234,7 @@ impl ChannelStore {
|
||||
}
|
||||
|
||||
pub fn is_user_admin(&self, channel_id: ChannelId) -> bool {
|
||||
self.channel_paths.iter().any(|path| {
|
||||
self.channel_index.iter().any(|path| {
|
||||
if let Some(ix) = path.iter().position(|id| *id == channel_id) {
|
||||
path[..=ix]
|
||||
.iter()
|
||||
@@ -276,6 +289,59 @@ impl ChannelStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn link_channel(
|
||||
&mut self,
|
||||
channel_id: ChannelId,
|
||||
to: ChannelId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(|_, _| async move {
|
||||
let _ = client
|
||||
.request(proto::LinkChannel { channel_id, to })
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn unlink_channel(
|
||||
&mut self,
|
||||
channel_id: ChannelId,
|
||||
from: Option<ChannelId>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(|_, _| async move {
|
||||
let _ = client
|
||||
.request(proto::UnlinkChannel { channel_id, from })
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn move_channel(
|
||||
&mut self,
|
||||
channel_id: ChannelId,
|
||||
from: Option<ChannelId>,
|
||||
to: ChannelId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(|_, _| async move {
|
||||
let _ = client
|
||||
.request(proto::MoveChannel {
|
||||
channel_id,
|
||||
from,
|
||||
to,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn invite_member(
|
||||
&mut self,
|
||||
channel_id: ChannelId,
|
||||
@@ -455,7 +521,7 @@ impl ChannelStore {
|
||||
pub fn remove_channel(&self, channel_id: ChannelId) -> impl Future<Output = Result<()>> {
|
||||
let client = self.client.clone();
|
||||
async move {
|
||||
client.request(proto::RemoveChannel { channel_id }).await?;
|
||||
client.request(proto::DeleteChannel { channel_id }).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -482,25 +548,130 @@ impl ChannelStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_disconnect(&mut self, cx: &mut ModelContext<'_, ChannelStore>) {
|
||||
self.disconnect_buffers(cx);
|
||||
self.channels_by_id.clear();
|
||||
self.channel_invitations.clear();
|
||||
self.channel_participants.clear();
|
||||
self.channels_with_admin_privileges.clear();
|
||||
self.channel_paths.clear();
|
||||
self.outgoing_invites.clear();
|
||||
cx.notify();
|
||||
}
|
||||
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
self.disconnect_channel_buffers_task.take();
|
||||
|
||||
fn disconnect_buffers(&mut self, cx: &mut ModelContext<ChannelStore>) {
|
||||
for (_, buffer) in self.opened_buffers.drain() {
|
||||
let mut buffer_versions = Vec::new();
|
||||
for buffer in self.opened_buffers.values() {
|
||||
if let OpenedChannelBuffer::Open(buffer) = buffer {
|
||||
if let Some(buffer) = buffer.upgrade(cx) {
|
||||
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
|
||||
let channel_buffer = buffer.read(cx);
|
||||
let buffer = channel_buffer.buffer().read(cx);
|
||||
buffer_versions.push(proto::ChannelBufferVersion {
|
||||
channel_id: channel_buffer.channel().id,
|
||||
epoch: channel_buffer.epoch(),
|
||||
version: language::proto::serialize_version(&buffer.version()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if buffer_versions.is_empty() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let response = self.client.request(proto::RejoinChannelBuffers {
|
||||
buffers: buffer_versions,
|
||||
});
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let mut response = response.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.opened_buffers.retain(|_, buffer| match buffer {
|
||||
OpenedChannelBuffer::Open(channel_buffer) => {
|
||||
let Some(channel_buffer) = channel_buffer.upgrade(cx) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
channel_buffer.update(cx, |channel_buffer, cx| {
|
||||
let channel_id = channel_buffer.channel().id;
|
||||
if let Some(remote_buffer) = response
|
||||
.buffers
|
||||
.iter_mut()
|
||||
.find(|buffer| buffer.channel_id == channel_id)
|
||||
{
|
||||
let channel_id = channel_buffer.channel().id;
|
||||
let remote_version =
|
||||
language::proto::deserialize_version(&remote_buffer.version);
|
||||
|
||||
channel_buffer.replace_collaborators(
|
||||
mem::take(&mut remote_buffer.collaborators),
|
||||
cx,
|
||||
);
|
||||
|
||||
let operations = channel_buffer
|
||||
.buffer()
|
||||
.update(cx, |buffer, cx| {
|
||||
let outgoing_operations =
|
||||
buffer.serialize_ops(Some(remote_version), cx);
|
||||
let incoming_operations =
|
||||
mem::take(&mut remote_buffer.operations)
|
||||
.into_iter()
|
||||
.map(language::proto::deserialize_operation)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
buffer.apply_ops(incoming_operations, cx)?;
|
||||
anyhow::Ok(outgoing_operations)
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(operations) = operations {
|
||||
let client = this.client.clone();
|
||||
cx.background()
|
||||
.spawn(async move {
|
||||
let operations = operations.await;
|
||||
for chunk in
|
||||
language::proto::split_operations(operations)
|
||||
{
|
||||
client
|
||||
.send(proto::UpdateChannelBuffer {
|
||||
channel_id,
|
||||
operations: chunk,
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
channel_buffer.disconnect(cx);
|
||||
false
|
||||
})
|
||||
}
|
||||
OpenedChannelBuffer::Loading(_) => true,
|
||||
});
|
||||
});
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
|
||||
self.channel_index.clear();
|
||||
self.channel_invitations.clear();
|
||||
self.channel_participants.clear();
|
||||
self.channels_with_admin_privileges.clear();
|
||||
self.channel_index.clear();
|
||||
self.outgoing_invites.clear();
|
||||
cx.notify();
|
||||
|
||||
self.disconnect_channel_buffers_task.get_or_insert_with(|| {
|
||||
cx.spawn_weak(|this, mut cx| async move {
|
||||
cx.background().timer(RECONNECT_TIMEOUT).await;
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
for (_, buffer) in this.opened_buffers.drain() {
|
||||
if let OpenedChannelBuffer::Open(buffer) = buffer {
|
||||
if let Some(buffer) = buffer.upgrade(cx) {
|
||||
buffer.update(cx, |buffer, cx| buffer.disconnect(cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn update_channels(
|
||||
@@ -528,17 +699,16 @@ impl ChannelStore {
|
||||
}
|
||||
}
|
||||
|
||||
let channels_changed = !payload.channels.is_empty() || !payload.remove_channels.is_empty();
|
||||
let channels_changed = !payload.channels.is_empty() || !payload.delete_channels.is_empty();
|
||||
if channels_changed {
|
||||
if !payload.remove_channels.is_empty() {
|
||||
self.channels_by_id
|
||||
.retain(|channel_id, _| !payload.remove_channels.contains(channel_id));
|
||||
if !payload.delete_channels.is_empty() {
|
||||
self.channel_index.delete_channels(&payload.delete_channels);
|
||||
self.channel_participants
|
||||
.retain(|channel_id, _| !payload.remove_channels.contains(channel_id));
|
||||
.retain(|channel_id, _| !payload.delete_channels.contains(channel_id));
|
||||
self.channels_with_admin_privileges
|
||||
.retain(|channel_id| !payload.remove_channels.contains(channel_id));
|
||||
.retain(|channel_id| !payload.delete_channels.contains(channel_id));
|
||||
|
||||
for channel_id in &payload.remove_channels {
|
||||
for channel_id in &payload.delete_channels {
|
||||
let channel_id = *channel_id;
|
||||
if let Some(OpenedChannelBuffer::Open(buffer)) =
|
||||
self.opened_buffers.remove(&channel_id)
|
||||
@@ -550,44 +720,15 @@ impl ChannelStore {
|
||||
}
|
||||
}
|
||||
|
||||
for channel_proto in payload.channels {
|
||||
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
|
||||
Arc::make_mut(existing_channel).name = channel_proto.name;
|
||||
} else {
|
||||
let channel = Arc::new(Channel {
|
||||
id: channel_proto.id,
|
||||
name: channel_proto.name,
|
||||
});
|
||||
self.channels_by_id.insert(channel.id, channel.clone());
|
||||
|
||||
if let Some(parent_id) = channel_proto.parent_id {
|
||||
let mut ix = 0;
|
||||
while ix < self.channel_paths.len() {
|
||||
let path = &self.channel_paths[ix];
|
||||
if path.ends_with(&[parent_id]) {
|
||||
let mut new_path = path.clone();
|
||||
new_path.push(channel.id);
|
||||
self.channel_paths.insert(ix + 1, new_path);
|
||||
ix += 1;
|
||||
}
|
||||
ix += 1;
|
||||
}
|
||||
} else {
|
||||
self.channel_paths.push(vec![channel.id]);
|
||||
}
|
||||
}
|
||||
let mut channel_index = self.channel_index.start_upsert();
|
||||
for channel in payload.channels {
|
||||
channel_index.upsert(channel)
|
||||
}
|
||||
}
|
||||
|
||||
self.channel_paths.sort_by(|a, b| {
|
||||
let a = Self::channel_path_sorting_key(a, &self.channels_by_id);
|
||||
let b = Self::channel_path_sorting_key(b, &self.channels_by_id);
|
||||
a.cmp(b)
|
||||
});
|
||||
self.channel_paths.dedup();
|
||||
self.channel_paths.retain(|path| {
|
||||
path.iter()
|
||||
.all(|channel_id| self.channels_by_id.contains_key(channel_id))
|
||||
});
|
||||
for edge in payload.delete_channel_edge {
|
||||
self.channel_index
|
||||
.delete_edge(edge.parent_id, edge.channel_id);
|
||||
}
|
||||
|
||||
for permission in payload.channel_permissions {
|
||||
@@ -645,12 +786,4 @@ impl ChannelStore {
|
||||
anyhow::Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
fn channel_path_sorting_key<'a>(
|
||||
path: &'a [ChannelId],
|
||||
channels_by_id: &'a HashMap<ChannelId, Arc<Channel>>,
|
||||
) -> impl 'a + Iterator<Item = Option<&'a str>> {
|
||||
path.iter()
|
||||
.map(|id| Some(channels_by_id.get(id)?.name.as_str()))
|
||||
}
|
||||
}
|
||||
|
||||
161
crates/channel/src/channel_store/channel_index.rs
Normal file
161
crates/channel/src/channel_store/channel_index.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
use std::{sync::Arc, ops::Deref};
|
||||
|
||||
use collections::HashMap;
|
||||
use rpc::proto;
|
||||
use serde_derive::{Serialize, Deserialize};
|
||||
|
||||
use crate::{ChannelId, Channel};
|
||||
|
||||
pub type ChannelsById = HashMap<ChannelId, Arc<Channel>>;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelPath(Arc<[ChannelId]>);
|
||||
|
||||
impl Deref for ChannelPath {
|
||||
type Target = [ChannelId];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl ChannelPath {
|
||||
pub fn parent_id(&self) -> Option<ChannelId> {
|
||||
self.0.len().checked_sub(2).map(|i| {
|
||||
self.0[i]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChannelPath {
|
||||
fn default() -> Self {
|
||||
ChannelPath(Arc::from([]))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ChannelIndex {
|
||||
paths: Vec<ChannelPath>,
|
||||
channels_by_id: ChannelsById,
|
||||
}
|
||||
|
||||
|
||||
impl ChannelIndex {
|
||||
pub fn by_id(&self) -> &ChannelsById {
|
||||
&self.channels_by_id
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.paths.clear();
|
||||
self.channels_by_id.clear();
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.paths.len()
|
||||
}
|
||||
|
||||
pub fn get(&self, idx: usize) -> Option<&ChannelPath> {
|
||||
self.paths.get(idx)
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &ChannelPath> {
|
||||
self.paths.iter()
|
||||
}
|
||||
|
||||
/// Remove the given edge from this index. This will not remove the channel
|
||||
/// and may result in dangling channels.
|
||||
pub fn delete_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
|
||||
self.paths.retain(|path| {
|
||||
!path
|
||||
.windows(2)
|
||||
.any(|window| window == [parent_id, channel_id])
|
||||
});
|
||||
}
|
||||
|
||||
/// Delete the given channels from this index.
|
||||
pub fn delete_channels(&mut self, channels: &[ChannelId]) {
|
||||
self.channels_by_id.retain(|channel_id, _| !channels.contains(channel_id));
|
||||
self.paths.retain(|channel_path| !channel_path.iter().any(|channel_id| {channels.contains(channel_id)}))
|
||||
}
|
||||
|
||||
/// Upsert one or more channels into this index.
|
||||
pub fn start_upsert(& mut self) -> ChannelPathsUpsertGuard {
|
||||
ChannelPathsUpsertGuard {
|
||||
paths: &mut self.paths,
|
||||
channels_by_id: &mut self.channels_by_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A guard for ensuring that the paths index maintains its sort and uniqueness
|
||||
/// invariants after a series of insertions
|
||||
pub struct ChannelPathsUpsertGuard<'a> {
|
||||
paths: &'a mut Vec<ChannelPath>,
|
||||
channels_by_id: &'a mut ChannelsById,
|
||||
}
|
||||
|
||||
impl<'a> ChannelPathsUpsertGuard<'a> {
|
||||
pub fn upsert(&mut self, channel_proto: proto::Channel) {
|
||||
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
|
||||
Arc::make_mut(existing_channel).name = channel_proto.name;
|
||||
|
||||
if let Some(parent_id) = channel_proto.parent_id {
|
||||
self.insert_edge(parent_id, channel_proto.id)
|
||||
}
|
||||
} else {
|
||||
let channel = Arc::new(Channel {
|
||||
id: channel_proto.id,
|
||||
name: channel_proto.name,
|
||||
});
|
||||
self.channels_by_id.insert(channel.id, channel.clone());
|
||||
|
||||
if let Some(parent_id) = channel_proto.parent_id {
|
||||
self.insert_edge(parent_id, channel.id);
|
||||
} else {
|
||||
self.insert_root(channel.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_edge(&mut self, parent_id: ChannelId, channel_id: ChannelId) {
|
||||
let mut ix = 0;
|
||||
while ix < self.paths.len() {
|
||||
let path = &self.paths[ix];
|
||||
if path.ends_with(&[parent_id]) {
|
||||
let mut new_path = path.to_vec();
|
||||
new_path.push(channel_id);
|
||||
self.paths.insert(ix + 1, ChannelPath(new_path.into()));
|
||||
ix += 1;
|
||||
}
|
||||
ix += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_root(&mut self, channel_id: ChannelId) {
|
||||
self.paths.push(ChannelPath(Arc::from([channel_id])));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for ChannelPathsUpsertGuard<'a> {
|
||||
fn drop(&mut self) {
|
||||
self.paths.sort_by(|a, b| {
|
||||
let a = channel_path_sorting_key(a, &self.channels_by_id);
|
||||
let b = channel_path_sorting_key(b, &self.channels_by_id);
|
||||
a.cmp(b)
|
||||
});
|
||||
self.paths.dedup();
|
||||
self.paths.retain(|path| {
|
||||
path.iter()
|
||||
.all(|channel_id| self.channels_by_id.contains_key(channel_id))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn channel_path_sorting_key<'a>(
|
||||
path: &'a [ChannelId],
|
||||
channels_by_id: &'a ChannelsById,
|
||||
) -> impl 'a + Iterator<Item = Option<&'a str>> {
|
||||
path.iter()
|
||||
.map(|id| Some(channels_by_id.get(id)?.name.as_str()))
|
||||
}
|
||||
@@ -127,7 +127,7 @@ fn test_dangling_channel_paths(cx: &mut AppContext) {
|
||||
update_channels(
|
||||
&channel_store,
|
||||
proto::UpdateChannels {
|
||||
remove_channels: vec![1, 2],
|
||||
delete_channels: vec![1, 2],
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
|
||||
@@ -1011,9 +1011,9 @@ impl Client {
|
||||
credentials: &Credentials,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Task<Result<Connection, EstablishConnectionError>> {
|
||||
let is_preview = cx.read(|cx| {
|
||||
let use_preview_server = cx.read(|cx| {
|
||||
if cx.has_global::<ReleaseChannel>() {
|
||||
*cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
|
||||
*cx.global::<ReleaseChannel>() != ReleaseChannel::Stable
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -1028,7 +1028,7 @@ impl Client {
|
||||
|
||||
let http = self.http.clone();
|
||||
cx.background().spawn(async move {
|
||||
let mut rpc_url = Self::get_rpc_url(http, is_preview).await?;
|
||||
let mut rpc_url = Self::get_rpc_url(http, use_preview_server).await?;
|
||||
let rpc_host = rpc_url
|
||||
.host_str()
|
||||
.zip(rpc_url.port_or_known_default())
|
||||
|
||||
@@ -73,7 +73,7 @@ pub enum ClickhouseEvent {
|
||||
},
|
||||
Call {
|
||||
operation: &'static str,
|
||||
room_id: u64,
|
||||
room_id: Option<u64>,
|
||||
channel_id: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2,70 +2,17 @@ use smallvec::SmallVec;
|
||||
use std::{
|
||||
cmp::{self, Ordering},
|
||||
fmt, iter,
|
||||
ops::{Add, AddAssign},
|
||||
};
|
||||
|
||||
pub type ReplicaId = u16;
|
||||
pub type Seq = u32;
|
||||
|
||||
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, Ord, PartialOrd)]
|
||||
pub struct Local {
|
||||
pub replica_id: ReplicaId,
|
||||
pub value: Seq,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Default, Eq, Hash, PartialEq)]
|
||||
pub struct Lamport {
|
||||
pub replica_id: ReplicaId,
|
||||
pub value: Seq,
|
||||
}
|
||||
|
||||
impl Local {
|
||||
pub const MIN: Self = Self {
|
||||
replica_id: ReplicaId::MIN,
|
||||
value: Seq::MIN,
|
||||
};
|
||||
pub const MAX: Self = Self {
|
||||
replica_id: ReplicaId::MAX,
|
||||
value: Seq::MAX,
|
||||
};
|
||||
|
||||
pub fn new(replica_id: ReplicaId) -> Self {
|
||||
Self {
|
||||
replica_id,
|
||||
value: 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tick(&mut self) -> Self {
|
||||
let timestamp = *self;
|
||||
self.value += 1;
|
||||
timestamp
|
||||
}
|
||||
|
||||
pub fn observe(&mut self, timestamp: Self) {
|
||||
if timestamp.replica_id == self.replica_id {
|
||||
self.value = cmp::max(self.value, timestamp.value + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Add<&'a Self> for Local {
|
||||
type Output = Local;
|
||||
|
||||
fn add(self, other: &'a Self) -> Self::Output {
|
||||
*cmp::max(&self, other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AddAssign<&'a Local> for Local {
|
||||
fn add_assign(&mut self, other: &Self) {
|
||||
if *self < *other {
|
||||
*self = *other;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A vector clock
|
||||
#[derive(Clone, Default, Hash, Eq, PartialEq)]
|
||||
pub struct Global(SmallVec<[u32; 8]>);
|
||||
@@ -79,7 +26,7 @@ impl Global {
|
||||
self.0.get(replica_id as usize).copied().unwrap_or(0) as Seq
|
||||
}
|
||||
|
||||
pub fn observe(&mut self, timestamp: Local) {
|
||||
pub fn observe(&mut self, timestamp: Lamport) {
|
||||
if timestamp.value > 0 {
|
||||
let new_len = timestamp.replica_id as usize + 1;
|
||||
if new_len > self.0.len() {
|
||||
@@ -126,7 +73,7 @@ impl Global {
|
||||
self.0.resize(new_len, 0);
|
||||
}
|
||||
|
||||
pub fn observed(&self, timestamp: Local) -> bool {
|
||||
pub fn observed(&self, timestamp: Lamport) -> bool {
|
||||
self.get(timestamp.replica_id) >= timestamp.value
|
||||
}
|
||||
|
||||
@@ -178,16 +125,16 @@ impl Global {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = Local> + '_ {
|
||||
self.0.iter().enumerate().map(|(replica_id, seq)| Local {
|
||||
pub fn iter(&self) -> impl Iterator<Item = Lamport> + '_ {
|
||||
self.0.iter().enumerate().map(|(replica_id, seq)| Lamport {
|
||||
replica_id: replica_id as ReplicaId,
|
||||
value: *seq,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<Local> for Global {
|
||||
fn from_iter<T: IntoIterator<Item = Local>>(locals: T) -> Self {
|
||||
impl FromIterator<Lamport> for Global {
|
||||
fn from_iter<T: IntoIterator<Item = Lamport>>(locals: T) -> Self {
|
||||
let mut result = Self::new();
|
||||
for local in locals {
|
||||
result.observe(local);
|
||||
@@ -212,6 +159,16 @@ impl PartialOrd for Lamport {
|
||||
}
|
||||
|
||||
impl Lamport {
|
||||
pub const MIN: Self = Self {
|
||||
replica_id: ReplicaId::MIN,
|
||||
value: Seq::MIN,
|
||||
};
|
||||
|
||||
pub const MAX: Self = Self {
|
||||
replica_id: ReplicaId::MAX,
|
||||
value: Seq::MAX,
|
||||
};
|
||||
|
||||
pub fn new(replica_id: ReplicaId) -> Self {
|
||||
Self {
|
||||
value: 1,
|
||||
@@ -230,12 +187,6 @@ impl Lamport {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Local {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Local {{{}: {}}}", self.replica_id, self.value)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Lamport {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Lamport {{{}: {}}}", self.replica_id, self.value)
|
||||
|
||||
@@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
|
||||
default-run = "collab"
|
||||
edition = "2021"
|
||||
name = "collab"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0"
|
||||
publish = false
|
||||
|
||||
[[bin]]
|
||||
@@ -72,7 +72,6 @@ fs = { path = "../fs", features = ["test-support"] }
|
||||
git = { path = "../git", features = ["test-support"] }
|
||||
live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
|
||||
lsp = { path = "../lsp", features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
rpc = { path = "../rpc", features = ["test-support"] }
|
||||
settings = { path = "../settings", features = ["test-support"] }
|
||||
@@ -80,6 +79,8 @@ theme = { path = "../theme" }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
collab_ui = { path = "../collab_ui", features = ["test-support"] }
|
||||
|
||||
async-trait.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
indoc.workspace = true
|
||||
|
||||
@@ -435,6 +435,12 @@ pub struct ChannelsForUser {
|
||||
pub channels_with_admin_privileges: HashSet<ChannelId>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RejoinedChannelBuffer {
|
||||
pub buffer: proto::RejoinedChannelBuffer,
|
||||
pub old_connection_id: ConnectionId,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JoinRoom {
|
||||
pub room: proto::Room,
|
||||
@@ -498,6 +504,11 @@ pub struct RefreshedRoom {
|
||||
pub canceled_calls_to_user_ids: Vec<UserId>,
|
||||
}
|
||||
|
||||
pub struct RefreshedChannelBuffer {
|
||||
pub connection_ids: Vec<ConnectionId>,
|
||||
pub removed_collaborators: Vec<proto::RemoveChannelBufferCollaborator>,
|
||||
}
|
||||
|
||||
pub struct Project {
|
||||
pub collaborators: Vec<ProjectCollaborator>,
|
||||
pub worktrees: BTreeMap<u64, Worktree>,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use prost::Message;
|
||||
use text::{EditOperation, InsertionTimestamp, UndoOperation};
|
||||
use text::{EditOperation, UndoOperation};
|
||||
|
||||
impl Database {
|
||||
pub async fn join_channel_buffer(
|
||||
@@ -10,8 +10,6 @@ impl Database {
|
||||
connection: ConnectionId,
|
||||
) -> Result<proto::JoinChannelBufferResponse> {
|
||||
self.transaction(|tx| async move {
|
||||
let tx = tx;
|
||||
|
||||
self.check_user_is_channel_member(channel_id, user_id, &tx)
|
||||
.await?;
|
||||
|
||||
@@ -70,7 +68,6 @@ impl Database {
|
||||
.await?;
|
||||
collaborators.push(collaborator);
|
||||
|
||||
// Assemble the buffer state
|
||||
let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
|
||||
|
||||
Ok(proto::JoinChannelBufferResponse {
|
||||
@@ -78,6 +75,7 @@ impl Database {
|
||||
replica_id: replica_id.to_proto() as u32,
|
||||
base_text,
|
||||
operations,
|
||||
epoch: buffer.epoch as u64,
|
||||
collaborators: collaborators
|
||||
.into_iter()
|
||||
.map(|collaborator| proto::Collaborator {
|
||||
@@ -91,6 +89,154 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn rejoin_channel_buffers(
|
||||
&self,
|
||||
buffers: &[proto::ChannelBufferVersion],
|
||||
user_id: UserId,
|
||||
connection_id: ConnectionId,
|
||||
) -> Result<Vec<RejoinedChannelBuffer>> {
|
||||
self.transaction(|tx| async move {
|
||||
let mut results = Vec::new();
|
||||
for client_buffer in buffers {
|
||||
let channel_id = ChannelId::from_proto(client_buffer.channel_id);
|
||||
if self
|
||||
.check_user_is_channel_member(channel_id, user_id, &*tx)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
log::info!("user is not a member of channel");
|
||||
continue;
|
||||
}
|
||||
|
||||
let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
|
||||
let mut collaborators = channel_buffer_collaborator::Entity::find()
|
||||
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
// If the buffer epoch hasn't changed since the client lost
|
||||
// connection, then the client's buffer can be syncronized with
|
||||
// the server's buffer.
|
||||
if buffer.epoch as u64 != client_buffer.epoch {
|
||||
log::info!("can't rejoin buffer, epoch has changed");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find the collaborator record for this user's previous lost
|
||||
// connection. Update it with the new connection id.
|
||||
let server_id = ServerId(connection_id.owner_id as i32);
|
||||
let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
|
||||
c.user_id == user_id
|
||||
&& (c.connection_lost || c.connection_server_id != server_id)
|
||||
}) else {
|
||||
log::info!("can't rejoin buffer, no previous collaborator found");
|
||||
continue;
|
||||
};
|
||||
let old_connection_id = self_collaborator.connection();
|
||||
*self_collaborator = channel_buffer_collaborator::ActiveModel {
|
||||
id: ActiveValue::Unchanged(self_collaborator.id),
|
||||
connection_id: ActiveValue::Set(connection_id.id as i32),
|
||||
connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
|
||||
connection_lost: ActiveValue::Set(false),
|
||||
..Default::default()
|
||||
}
|
||||
.update(&*tx)
|
||||
.await?;
|
||||
|
||||
let client_version = version_from_wire(&client_buffer.version);
|
||||
let serialization_version = self
|
||||
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
|
||||
.await?;
|
||||
|
||||
let mut rows = buffer_operation::Entity::find()
|
||||
.filter(
|
||||
buffer_operation::Column::BufferId
|
||||
.eq(buffer.id)
|
||||
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
|
||||
)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
// Find the server's version vector and any operations
|
||||
// that the client has not seen.
|
||||
let mut server_version = clock::Global::new();
|
||||
let mut operations = Vec::new();
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
let timestamp = clock::Lamport {
|
||||
replica_id: row.replica_id as u16,
|
||||
value: row.lamport_timestamp as u32,
|
||||
};
|
||||
server_version.observe(timestamp);
|
||||
if !client_version.observed(timestamp) {
|
||||
operations.push(proto::Operation {
|
||||
variant: Some(operation_from_storage(row, serialization_version)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
results.push(RejoinedChannelBuffer {
|
||||
old_connection_id,
|
||||
buffer: proto::RejoinedChannelBuffer {
|
||||
channel_id: client_buffer.channel_id,
|
||||
version: version_to_wire(&server_version),
|
||||
operations,
|
||||
collaborators: collaborators
|
||||
.into_iter()
|
||||
.map(|collaborator| proto::Collaborator {
|
||||
peer_id: Some(collaborator.connection().into()),
|
||||
user_id: collaborator.user_id.to_proto(),
|
||||
replica_id: collaborator.replica_id.0 as u32,
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn clear_stale_channel_buffer_collaborators(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
server_id: ServerId,
|
||||
) -> Result<RefreshedChannelBuffer> {
|
||||
self.transaction(|tx| async move {
|
||||
let collaborators = channel_buffer_collaborator::Entity::find()
|
||||
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut connection_ids = Vec::new();
|
||||
let mut removed_collaborators = Vec::new();
|
||||
let mut collaborator_ids_to_remove = Vec::new();
|
||||
for collaborator in &collaborators {
|
||||
if !collaborator.connection_lost && collaborator.connection_server_id == server_id {
|
||||
connection_ids.push(collaborator.connection());
|
||||
} else {
|
||||
removed_collaborators.push(proto::RemoveChannelBufferCollaborator {
|
||||
channel_id: channel_id.to_proto(),
|
||||
peer_id: Some(collaborator.connection().into()),
|
||||
});
|
||||
collaborator_ids_to_remove.push(collaborator.id);
|
||||
}
|
||||
}
|
||||
|
||||
channel_buffer_collaborator::Entity::delete_many()
|
||||
.filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(RefreshedChannelBuffer {
|
||||
connection_ids,
|
||||
removed_collaborators,
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn leave_channel_buffer(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
@@ -103,6 +249,39 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn leave_channel_buffers(
|
||||
&self,
|
||||
connection: ConnectionId,
|
||||
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
|
||||
self.transaction(|tx| async move {
|
||||
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
|
||||
enum QueryChannelIds {
|
||||
ChannelId,
|
||||
}
|
||||
|
||||
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
|
||||
.select_only()
|
||||
.column(channel_buffer_collaborator::Column::ChannelId)
|
||||
.filter(Condition::all().add(
|
||||
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
|
||||
))
|
||||
.into_values::<_, QueryChannelIds>()
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
for channel_id in channel_ids {
|
||||
let collaborators = self
|
||||
.leave_channel_buffer_internal(channel_id, connection, &*tx)
|
||||
.await?;
|
||||
result.push((channel_id, collaborators));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn leave_channel_buffer_internal(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
@@ -143,46 +322,12 @@ impl Database {
|
||||
drop(rows);
|
||||
|
||||
if connections.is_empty() {
|
||||
self.snapshot_buffer(channel_id, &tx).await?;
|
||||
self.snapshot_channel_buffer(channel_id, &tx).await?;
|
||||
}
|
||||
|
||||
Ok(connections)
|
||||
}
|
||||
|
||||
pub async fn leave_channel_buffers(
|
||||
&self,
|
||||
connection: ConnectionId,
|
||||
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
|
||||
self.transaction(|tx| async move {
|
||||
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
|
||||
enum QueryChannelIds {
|
||||
ChannelId,
|
||||
}
|
||||
|
||||
let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
|
||||
.select_only()
|
||||
.column(channel_buffer_collaborator::Column::ChannelId)
|
||||
.filter(Condition::all().add(
|
||||
channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
|
||||
))
|
||||
.into_values::<_, QueryChannelIds>()
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
for channel_id in channel_ids {
|
||||
let collaborators = self
|
||||
.leave_channel_buffer_internal(channel_id, connection, &*tx)
|
||||
.await?;
|
||||
result.push((channel_id, collaborators));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub async fn get_channel_buffer_collaborators(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
@@ -225,20 +370,9 @@ impl Database {
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("no such buffer"))?;
|
||||
|
||||
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
|
||||
enum QueryVersion {
|
||||
OperationSerializationVersion,
|
||||
}
|
||||
|
||||
let serialization_version: i32 = buffer
|
||||
.find_related(buffer_snapshot::Entity)
|
||||
.select_only()
|
||||
.column(buffer_snapshot::Column::OperationSerializationVersion)
|
||||
.filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
|
||||
.into_values::<_, QueryVersion>()
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("missing buffer snapshot"))?;
|
||||
let serialization_version = self
|
||||
.get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
|
||||
.await?;
|
||||
|
||||
let operations = operations
|
||||
.iter()
|
||||
@@ -246,6 +380,16 @@ impl Database {
|
||||
.collect::<Vec<_>>();
|
||||
if !operations.is_empty() {
|
||||
buffer_operation::Entity::insert_many(operations)
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
buffer_operation::Column::BufferId,
|
||||
buffer_operation::Column::Epoch,
|
||||
buffer_operation::Column::LamportTimestamp,
|
||||
buffer_operation::Column::ReplicaId,
|
||||
])
|
||||
.do_nothing()
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
}
|
||||
@@ -271,6 +415,38 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_buffer_operation_serialization_version(
|
||||
&self,
|
||||
buffer_id: BufferId,
|
||||
epoch: i32,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<i32> {
|
||||
Ok(buffer_snapshot::Entity::find()
|
||||
.filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
|
||||
.filter(buffer_snapshot::Column::Epoch.eq(epoch))
|
||||
.select_only()
|
||||
.column(buffer_snapshot::Column::OperationSerializationVersion)
|
||||
.into_values::<_, QueryOperationSerializationVersion>()
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
|
||||
}
|
||||
|
||||
async fn get_channel_buffer(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<buffer::Model> {
|
||||
Ok(channel::Model {
|
||||
id: channel_id,
|
||||
..Default::default()
|
||||
}
|
||||
.find_related(buffer::Entity)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("no such buffer"))?)
|
||||
}
|
||||
|
||||
async fn get_buffer_state(
|
||||
&self,
|
||||
buffer: &buffer::Model,
|
||||
@@ -304,27 +480,20 @@ impl Database {
|
||||
.await?;
|
||||
let mut operations = Vec::new();
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
|
||||
let operation = operation_from_storage(row, version)?;
|
||||
operations.push(proto::Operation {
|
||||
variant: Some(operation),
|
||||
variant: Some(operation_from_storage(row?, version)?),
|
||||
})
|
||||
}
|
||||
|
||||
Ok((base_text, operations))
|
||||
}
|
||||
|
||||
async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
|
||||
let buffer = channel::Model {
|
||||
id: channel_id,
|
||||
..Default::default()
|
||||
}
|
||||
.find_related(buffer::Entity)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("no such buffer"))?;
|
||||
|
||||
async fn snapshot_channel_buffer(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<()> {
|
||||
let buffer = self.get_channel_buffer(channel_id, tx).await?;
|
||||
let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
|
||||
if operations.is_empty() {
|
||||
return Ok(());
|
||||
@@ -370,7 +539,6 @@ fn operation_to_storage(
|
||||
operation.replica_id,
|
||||
operation.lamport_timestamp,
|
||||
storage::Operation {
|
||||
local_timestamp: operation.local_timestamp,
|
||||
version: version_to_storage(&operation.version),
|
||||
is_undo: false,
|
||||
edit_ranges: operation
|
||||
@@ -389,7 +557,6 @@ fn operation_to_storage(
|
||||
operation.replica_id,
|
||||
operation.lamport_timestamp,
|
||||
storage::Operation {
|
||||
local_timestamp: operation.local_timestamp,
|
||||
version: version_to_storage(&operation.version),
|
||||
is_undo: true,
|
||||
edit_ranges: Vec::new(),
|
||||
@@ -399,7 +566,7 @@ fn operation_to_storage(
|
||||
.iter()
|
||||
.map(|entry| storage::UndoCount {
|
||||
replica_id: entry.replica_id,
|
||||
local_timestamp: entry.local_timestamp,
|
||||
lamport_timestamp: entry.lamport_timestamp,
|
||||
count: entry.count,
|
||||
})
|
||||
.collect(),
|
||||
@@ -427,7 +594,6 @@ fn operation_from_storage(
|
||||
Ok(if operation.is_undo {
|
||||
proto::operation::Variant::Undo(proto::operation::Undo {
|
||||
replica_id: row.replica_id as u32,
|
||||
local_timestamp: operation.local_timestamp as u32,
|
||||
lamport_timestamp: row.lamport_timestamp as u32,
|
||||
version,
|
||||
counts: operation
|
||||
@@ -435,7 +601,7 @@ fn operation_from_storage(
|
||||
.iter()
|
||||
.map(|entry| proto::UndoCount {
|
||||
replica_id: entry.replica_id,
|
||||
local_timestamp: entry.local_timestamp,
|
||||
lamport_timestamp: entry.lamport_timestamp,
|
||||
count: entry.count,
|
||||
})
|
||||
.collect(),
|
||||
@@ -443,7 +609,6 @@ fn operation_from_storage(
|
||||
} else {
|
||||
proto::operation::Variant::Edit(proto::operation::Edit {
|
||||
replica_id: row.replica_id as u32,
|
||||
local_timestamp: operation.local_timestamp as u32,
|
||||
lamport_timestamp: row.lamport_timestamp as u32,
|
||||
version,
|
||||
ranges: operation
|
||||
@@ -483,10 +648,9 @@ fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::
|
||||
pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
|
||||
match operation.variant? {
|
||||
proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
|
||||
timestamp: InsertionTimestamp {
|
||||
timestamp: clock::Lamport {
|
||||
replica_id: edit.replica_id as text::ReplicaId,
|
||||
local: edit.local_timestamp,
|
||||
lamport: edit.lamport_timestamp,
|
||||
value: edit.lamport_timestamp,
|
||||
},
|
||||
version: version_from_wire(&edit.version),
|
||||
ranges: edit
|
||||
@@ -498,32 +662,26 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operatio
|
||||
.collect(),
|
||||
new_text: edit.new_text.into_iter().map(Arc::from).collect(),
|
||||
})),
|
||||
proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo {
|
||||
lamport_timestamp: clock::Lamport {
|
||||
proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
|
||||
timestamp: clock::Lamport {
|
||||
replica_id: undo.replica_id as text::ReplicaId,
|
||||
value: undo.lamport_timestamp,
|
||||
},
|
||||
undo: UndoOperation {
|
||||
id: clock::Local {
|
||||
replica_id: undo.replica_id as text::ReplicaId,
|
||||
value: undo.local_timestamp,
|
||||
},
|
||||
version: version_from_wire(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
(
|
||||
clock::Local {
|
||||
replica_id: c.replica_id as text::ReplicaId,
|
||||
value: c.local_timestamp,
|
||||
},
|
||||
c.count,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
}),
|
||||
version: version_from_wire(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
(
|
||||
clock::Lamport {
|
||||
replica_id: c.replica_id as text::ReplicaId,
|
||||
value: c.lamport_timestamp,
|
||||
},
|
||||
c.count,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
})),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -531,7 +689,7 @@ pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operatio
|
||||
fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
|
||||
let mut version = clock::Global::new();
|
||||
for entry in message {
|
||||
version.observe(clock::Local {
|
||||
version.observe(clock::Lamport {
|
||||
replica_id: entry.replica_id as text::ReplicaId,
|
||||
value: entry.timestamp,
|
||||
});
|
||||
@@ -539,6 +697,22 @@ fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
|
||||
version
|
||||
}
|
||||
|
||||
fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
|
||||
let mut message = Vec::new();
|
||||
for entry in version.iter() {
|
||||
message.push(proto::VectorClockEntry {
|
||||
replica_id: entry.replica_id as u32,
|
||||
timestamp: entry.value,
|
||||
});
|
||||
}
|
||||
message
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
|
||||
enum QueryOperationSerializationVersion {
|
||||
OperationSerializationVersion,
|
||||
}
|
||||
|
||||
mod storage {
|
||||
#![allow(non_snake_case)]
|
||||
use prost::Message;
|
||||
@@ -546,8 +720,6 @@ mod storage {
|
||||
|
||||
#[derive(Message)]
|
||||
pub struct Operation {
|
||||
#[prost(uint32, tag = "1")]
|
||||
pub local_timestamp: u32,
|
||||
#[prost(message, repeated, tag = "2")]
|
||||
pub version: Vec<VectorClockEntry>,
|
||||
#[prost(bool, tag = "3")]
|
||||
@@ -581,7 +753,7 @@ mod storage {
|
||||
#[prost(uint32, tag = "1")]
|
||||
pub replica_id: u32,
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub local_timestamp: u32,
|
||||
pub lamport_timestamp: u32,
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub count: u32,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
use super::*;
|
||||
|
||||
type ChannelDescendants = HashMap<ChannelId, HashSet<ChannelId>>;
|
||||
|
||||
impl Database {
|
||||
#[cfg(test)]
|
||||
pub async fn all_channels(&self) -> Result<Vec<(ChannelId, String)>> {
|
||||
self.transaction(move |tx| async move {
|
||||
let mut channels = Vec::new();
|
||||
let mut rows = channel::Entity::find().stream(&*tx).await?;
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
channels.push((row.id, row.name));
|
||||
}
|
||||
Ok(channels)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn create_root_channel(
|
||||
&self,
|
||||
name: &str,
|
||||
@@ -86,7 +102,7 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn remove_channel(
|
||||
pub async fn delete_channel(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
user_id: UserId,
|
||||
@@ -135,6 +151,19 @@ impl Database {
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
// Delete any other paths that incldue this channel
|
||||
let sql = r#"
|
||||
DELETE FROM channel_paths
|
||||
WHERE
|
||||
id_path LIKE '%' || $1 || '%'
|
||||
"#;
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[channel_id.to_proto().into()],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
|
||||
Ok((channels_to_remove.into_keys().collect(), members_to_notify))
|
||||
})
|
||||
.await
|
||||
@@ -305,6 +334,43 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_all_channels(
|
||||
&self,
|
||||
parents_by_child_id: ChannelDescendants,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<Vec<Channel>> {
|
||||
let mut channels = Vec::with_capacity(parents_by_child_id.len());
|
||||
{
|
||||
let mut rows = channel::Entity::find()
|
||||
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
|
||||
// As these rows are pulled from the map's keys, this unwrap is safe.
|
||||
let parents = parents_by_child_id.get(&row.id).unwrap();
|
||||
if parents.len() > 0 {
|
||||
for parent in parents {
|
||||
channels.push(Channel {
|
||||
id: row.id,
|
||||
name: row.name.clone(),
|
||||
parent_id: Some(*parent),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
channels.push(Channel {
|
||||
id: row.id,
|
||||
name: row.name,
|
||||
parent_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(channels)
|
||||
}
|
||||
|
||||
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
|
||||
self.transaction(|tx| async move {
|
||||
let tx = tx;
|
||||
@@ -327,21 +393,7 @@ impl Database {
|
||||
.filter_map(|membership| membership.admin.then_some(membership.channel_id))
|
||||
.collect();
|
||||
|
||||
let mut channels = Vec::with_capacity(parents_by_child_id.len());
|
||||
{
|
||||
let mut rows = channel::Entity::find()
|
||||
.filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
channels.push(Channel {
|
||||
id: row.id,
|
||||
name: row.name,
|
||||
parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
|
||||
});
|
||||
}
|
||||
}
|
||||
let channels = self.get_all_channels(parents_by_child_id, &tx).await?;
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryUserIdsAndChannelIds {
|
||||
@@ -545,6 +597,7 @@ impl Database {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the channel ancestors, deepest first
|
||||
pub async fn get_channel_ancestors(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
@@ -552,6 +605,7 @@ impl Database {
|
||||
) -> Result<Vec<ChannelId>> {
|
||||
let paths = channel_path::Entity::find()
|
||||
.filter(channel_path::Column::ChannelId.eq(channel_id))
|
||||
.order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
|
||||
.all(tx)
|
||||
.await?;
|
||||
let mut channel_ids = Vec::new();
|
||||
@@ -568,11 +622,25 @@ impl Database {
|
||||
Ok(channel_ids)
|
||||
}
|
||||
|
||||
/// Returns the channel descendants,
|
||||
/// Structured as a map from child ids to their parent ids
|
||||
/// For example, the descendants of 'a' in this DAG:
|
||||
///
|
||||
/// /- b -\
|
||||
/// a -- c -- d
|
||||
///
|
||||
/// would be:
|
||||
/// {
|
||||
/// a: [],
|
||||
/// b: [a],
|
||||
/// c: [a],
|
||||
/// d: [a, c],
|
||||
/// }
|
||||
async fn get_channel_descendants(
|
||||
&self,
|
||||
channel_ids: impl IntoIterator<Item = ChannelId>,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
|
||||
) -> Result<ChannelDescendants> {
|
||||
let mut values = String::new();
|
||||
for id in channel_ids {
|
||||
if !values.is_empty() {
|
||||
@@ -599,7 +667,7 @@ impl Database {
|
||||
|
||||
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
|
||||
|
||||
let mut parents_by_child_id = HashMap::default();
|
||||
let mut parents_by_child_id: ChannelDescendants = HashMap::default();
|
||||
let mut paths = channel_path::Entity::find()
|
||||
.from_raw_sql(stmt)
|
||||
.stream(tx)
|
||||
@@ -618,7 +686,10 @@ impl Database {
|
||||
parent_id = Some(id);
|
||||
}
|
||||
}
|
||||
parents_by_child_id.insert(path.channel_id, parent_id);
|
||||
let entry = parents_by_child_id.entry(path.channel_id).or_default();
|
||||
if let Some(parent_id) = parent_id {
|
||||
entry.insert(parent_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parents_by_child_id)
|
||||
@@ -689,6 +760,191 @@ impl Database {
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// Insert an edge from the given channel to the given other channel.
|
||||
pub async fn link_channel(
|
||||
&self,
|
||||
user: UserId,
|
||||
channel: ChannelId,
|
||||
to: ChannelId,
|
||||
) -> Result<Vec<Channel>> {
|
||||
self.transaction(|tx| async move {
|
||||
// Note that even with these maxed permissions, this linking operation
|
||||
// is still insecure because you can't remove someone's permissions to a
|
||||
// channel if they've linked the channel to one where they're an admin.
|
||||
self.check_user_is_channel_admin(channel, user, &*tx)
|
||||
.await?;
|
||||
|
||||
self.link_channel_internal(user, channel, to, &*tx).await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn link_channel_internal(
|
||||
&self,
|
||||
user: UserId,
|
||||
channel: ChannelId,
|
||||
to: ChannelId,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<Vec<Channel>> {
|
||||
self.check_user_is_channel_admin(to, user, &*tx).await?;
|
||||
|
||||
let to_ancestors = self.get_channel_ancestors(to, &*tx).await?;
|
||||
let mut from_descendants = self.get_channel_descendants([channel], &*tx).await?;
|
||||
for ancestor in to_ancestors {
|
||||
if from_descendants.contains_key(&ancestor) {
|
||||
return Err(anyhow!("Cannot create a channel cycle").into());
|
||||
}
|
||||
}
|
||||
let sql = r#"
|
||||
INSERT INTO channel_paths
|
||||
(id_path, channel_id)
|
||||
SELECT
|
||||
id_path || $1 || '/', $2
|
||||
FROM
|
||||
channel_paths
|
||||
WHERE
|
||||
channel_id = $3
|
||||
ON CONFLICT (id_path) DO NOTHING;
|
||||
"#;
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[
|
||||
channel.to_proto().into(),
|
||||
channel.to_proto().into(),
|
||||
to.to_proto().into(),
|
||||
],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
for (from_id, to_ids) in from_descendants.iter().filter(|(id, _)| id != &&channel) {
|
||||
for to_id in to_ids {
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[
|
||||
from_id.to_proto().into(),
|
||||
from_id.to_proto().into(),
|
||||
to_id.to_proto().into(),
|
||||
],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(channel) = from_descendants.get_mut(&channel) {
|
||||
// Remove the other parents
|
||||
channel.clear();
|
||||
channel.insert(to);
|
||||
}
|
||||
|
||||
let channels = self.get_all_channels(from_descendants, &*tx).await?;
|
||||
|
||||
Ok(channels)
|
||||
}
|
||||
|
||||
/// Unlink a channel from a given parent. This will add in a root edge if
|
||||
/// the channel has no other parents after this operation.
|
||||
pub async fn unlink_channel(
|
||||
&self,
|
||||
user: UserId,
|
||||
channel: ChannelId,
|
||||
from: Option<ChannelId>,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
// Note that even with these maxed permissions, this linking operation
|
||||
// is still insecure because you can't remove someone's permissions to a
|
||||
// channel if they've linked the channel to one where they're an admin.
|
||||
self.check_user_is_channel_admin(channel, user, &*tx)
|
||||
.await?;
|
||||
|
||||
self.unlink_channel_internal(user, channel, from, &*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn unlink_channel_internal(
|
||||
&self,
|
||||
user: UserId,
|
||||
channel: ChannelId,
|
||||
from: Option<ChannelId>,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<()> {
|
||||
if let Some(from) = from {
|
||||
self.check_user_is_channel_admin(from, user, &*tx).await?;
|
||||
|
||||
let sql = r#"
|
||||
DELETE FROM channel_paths
|
||||
WHERE
|
||||
id_path LIKE '%' || $1 || '/' || $2 || '%'
|
||||
"#;
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[from.to_proto().into(), channel.to_proto().into()],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
} else {
|
||||
let sql = r#"
|
||||
DELETE FROM channel_paths
|
||||
WHERE
|
||||
id_path = '/' || $1 || '/'
|
||||
"#;
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[channel.to_proto().into()],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
}
|
||||
|
||||
// Make sure that there is always at least one path to the channel
|
||||
let sql = r#"
|
||||
INSERT INTO channel_paths
|
||||
(id_path, channel_id)
|
||||
SELECT
|
||||
'/' || $1 || '/', $2
|
||||
WHERE NOT EXISTS
|
||||
(SELECT *
|
||||
FROM channel_paths
|
||||
WHERE channel_id = $2)
|
||||
"#;
|
||||
|
||||
let channel_paths_stmt = Statement::from_sql_and_values(
|
||||
self.pool.get_database_backend(),
|
||||
sql,
|
||||
[channel.to_proto().into(), channel.to_proto().into()],
|
||||
);
|
||||
tx.execute(channel_paths_stmt).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Move a channel from one parent to another, returns the
|
||||
/// Channels that were moved for notifying clients
|
||||
pub async fn move_channel(
|
||||
&self,
|
||||
user: UserId,
|
||||
channel: ChannelId,
|
||||
from: Option<ChannelId>,
|
||||
to: ChannelId,
|
||||
) -> Result<Vec<Channel>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.check_user_is_channel_admin(channel, user, &*tx)
|
||||
.await?;
|
||||
|
||||
let moved_channels = self.link_channel_internal(user, channel, to, &*tx).await?;
|
||||
|
||||
self.unlink_channel_internal(user, channel, from, &*tx)
|
||||
.await?;
|
||||
|
||||
Ok(moved_channels)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
|
||||
impl Database {
|
||||
pub async fn refresh_room(
|
||||
pub async fn clear_stale_room_participants(
|
||||
&self,
|
||||
room_id: RoomId,
|
||||
new_server_id: ServerId,
|
||||
|
||||
@@ -14,31 +14,49 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn stale_room_ids(
|
||||
pub async fn stale_server_resource_ids(
|
||||
&self,
|
||||
environment: &str,
|
||||
new_server_id: ServerId,
|
||||
) -> Result<Vec<RoomId>> {
|
||||
) -> Result<(Vec<RoomId>, Vec<ChannelId>)> {
|
||||
self.transaction(|tx| async move {
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryAs {
|
||||
enum QueryRoomIds {
|
||||
RoomId,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
|
||||
enum QueryChannelIds {
|
||||
ChannelId,
|
||||
}
|
||||
|
||||
let stale_server_epochs = self
|
||||
.stale_server_ids(environment, new_server_id, &tx)
|
||||
.await?;
|
||||
Ok(room_participant::Entity::find()
|
||||
let room_ids = room_participant::Entity::find()
|
||||
.select_only()
|
||||
.column(room_participant::Column::RoomId)
|
||||
.distinct()
|
||||
.filter(
|
||||
room_participant::Column::AnsweringConnectionServerId
|
||||
.is_in(stale_server_epochs),
|
||||
.is_in(stale_server_epochs.iter().copied()),
|
||||
)
|
||||
.into_values::<_, QueryAs>()
|
||||
.into_values::<_, QueryRoomIds>()
|
||||
.all(&*tx)
|
||||
.await?)
|
||||
.await?;
|
||||
let channel_ids = channel_buffer_collaborator::Entity::find()
|
||||
.select_only()
|
||||
.column(channel_buffer_collaborator::Column::ChannelId)
|
||||
.distinct()
|
||||
.filter(
|
||||
channel_buffer_collaborator::Column::ConnectionServerId
|
||||
.is_in(stale_server_epochs.iter().copied()),
|
||||
)
|
||||
.into_values::<_, QueryChannelIds>()
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok((room_ids, channel_ids))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -241,7 +241,6 @@ impl Database {
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub async fn create_user_flag(&self, flag: &str) -> Result<FlagId> {
|
||||
self.transaction(|tx| async move {
|
||||
let flag = feature_flag::Entity::insert(feature_flag::ActiveModel {
|
||||
@@ -257,7 +256,6 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
user_feature::Entity::insert(user_feature::ActiveModel {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod buffer_tests;
|
||||
mod channel_tests;
|
||||
mod db_tests;
|
||||
mod feature_flag_tests;
|
||||
|
||||
|
||||
844
crates/collab/src/db/tests/channel_tests.rs
Normal file
844
crates/collab/src/db/tests/channel_tests.rs
Normal file
@@ -0,0 +1,844 @@
|
||||
use rpc::{proto, ConnectionId};
|
||||
|
||||
use crate::{
|
||||
db::{Channel, ChannelId, Database, NewUserParams},
|
||||
test_both_dbs,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
|
||||
|
||||
async fn test_channels(db: &Arc<Database>) {
|
||||
let a_id = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let b_id = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
|
||||
|
||||
// Make sure that people cannot read channels they haven't been invited to
|
||||
assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
|
||||
|
||||
db.invite_channel_member(zed_id, b_id, a_id, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.respond_to_channel_invite(zed_id, b_id, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let crdb_id = db
|
||||
.create_channel("crdb", Some(zed_id), "2", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let livestreaming_id = db
|
||||
.create_channel("livestreaming", Some(zed_id), "3", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let replace_id = db
|
||||
.create_channel("replace", Some(zed_id), "4", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut members = db.get_channel_members(replace_id).await.unwrap();
|
||||
members.sort();
|
||||
assert_eq!(members, &[a_id, b_id]);
|
||||
|
||||
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
|
||||
let cargo_id = db
|
||||
.create_channel("cargo", Some(rust_id), "6", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cargo_ra_id = db
|
||||
.create_channel("cargo-ra", Some(cargo_id), "7", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: rust_id,
|
||||
name: "rust".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: cargo_id,
|
||||
name: "cargo".to_string(),
|
||||
parent_id: Some(rust_id),
|
||||
},
|
||||
Channel {
|
||||
id: cargo_ra_id,
|
||||
name: "cargo-ra".to_string(),
|
||||
parent_id: Some(cargo_id),
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
let result = db.get_channels_for_user(b_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Update member permissions
|
||||
let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await;
|
||||
assert!(set_subchannel_admin.is_err());
|
||||
let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await;
|
||||
assert!(set_channel_admin.is_ok());
|
||||
|
||||
let result = db.get_channels_for_user(b_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Remove a single channel
|
||||
db.delete_channel(crdb_id, a_id).await.unwrap();
|
||||
assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
|
||||
|
||||
// Remove a channel tree
|
||||
let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
|
||||
channel_ids.sort();
|
||||
assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
|
||||
assert_eq!(user_ids, &[a_id]);
|
||||
|
||||
assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
|
||||
assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
|
||||
assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_joining_channels,
|
||||
test_joining_channels_postgres,
|
||||
test_joining_channels_sqlite
|
||||
);
|
||||
|
||||
async fn test_joining_channels(db: &Arc<Database>) {
|
||||
let owner_id = db.create_server("test").await.unwrap().0 as u32;
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let channel_1 = db
|
||||
.create_root_channel("channel_1", "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
|
||||
|
||||
// can join a room with membership to its channel
|
||||
let joined_room = db
|
||||
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(joined_room.room.participants.len(), 1);
|
||||
|
||||
drop(joined_room);
|
||||
// cannot join a room without membership to its channel
|
||||
assert!(db
|
||||
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
|
||||
.await
|
||||
.is_err());
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_channel_invites,
|
||||
test_channel_invites_postgres,
|
||||
test_channel_invites_sqlite
|
||||
);
|
||||
|
||||
async fn test_channel_invites(db: &Arc<Database>) {
|
||||
db.create_server("test").await.unwrap();
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user_3 = db
|
||||
.create_user(
|
||||
"user3@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user3".into(),
|
||||
github_user_id: 7,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let channel_1_1 = db
|
||||
.create_root_channel("channel_1", "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let channel_1_2 = db
|
||||
.create_root_channel("channel_2", "2", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.invite_channel_member(channel_1_1, user_2, user_1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.invite_channel_member(channel_1_2, user_2, user_1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.invite_channel_member(channel_1_1, user_3, user_1, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user_2_invites = db
|
||||
.get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|channel| channel.id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
|
||||
|
||||
let user_3_invites = db
|
||||
.get_channel_invites_for_user(user_3) // -> [channel_1_1]
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|channel| channel.id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(user_3_invites, &[channel_1_1]);
|
||||
|
||||
let members = db
|
||||
.get_channel_member_details(channel_1_1, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
members,
|
||||
&[
|
||||
proto::ChannelMember {
|
||||
user_id: user_1.to_proto(),
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
admin: true,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_2.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
admin: false,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_3.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
admin: true,
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
db.respond_to_channel_invite(channel_1_1, user_2, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let channel_1_3 = db
|
||||
.create_channel("channel_3", Some(channel_1_1), "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let members = db
|
||||
.get_channel_member_details(channel_1_3, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
members,
|
||||
&[
|
||||
proto::ChannelMember {
|
||||
user_id: user_1.to_proto(),
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
admin: true,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_2.to_proto(),
|
||||
kind: proto::channel_member::Kind::AncestorMember.into(),
|
||||
admin: false,
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_channel_renames,
|
||||
test_channel_renames_postgres,
|
||||
test_channel_renames_sqlite
|
||||
);
|
||||
|
||||
async fn test_channel_renames(db: &Arc<Database>) {
|
||||
db.create_server("test").await.unwrap();
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
|
||||
|
||||
db.rename_channel(zed_id, user_1, "#zed-archive")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let zed_archive_id = zed_id;
|
||||
|
||||
let (channel, _) = db
|
||||
.get_channel(zed_archive_id, user_1)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(channel.name, "zed-archive");
|
||||
|
||||
let non_permissioned_rename = db
|
||||
.rename_channel(zed_archive_id, user_2, "hacked-lol")
|
||||
.await;
|
||||
assert!(non_permissioned_rename.is_err());
|
||||
|
||||
let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
|
||||
assert!(bad_name_rename.is_err())
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_channels_moving,
|
||||
test_channels_moving_postgres,
|
||||
test_channels_moving_sqlite
|
||||
);
|
||||
|
||||
async fn test_channels_moving(db: &Arc<Database>) {
|
||||
let a_id = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
|
||||
|
||||
let crdb_id = db
|
||||
.create_channel("crdb", Some(zed_id), "2", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let gpui2_id = db
|
||||
.create_channel("gpui2", Some(zed_id), "3", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let livestreaming_id = db
|
||||
.create_channel("livestreaming", Some(crdb_id), "4", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let livestreaming_dag_id = db
|
||||
.create_channel("livestreaming_dag", Some(livestreaming_id), "5", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// ========================================================================
|
||||
// sanity check
|
||||
// Initial DAG:
|
||||
// /- gpui2
|
||||
// zed -- crdb - livestreaming - livestreaming_dag
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(
|
||||
result.channels,
|
||||
&[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
],
|
||||
);
|
||||
|
||||
// Attempt to make a cycle
|
||||
assert!(db
|
||||
.link_channel(a_id, zed_id, livestreaming_id)
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
// ========================================================================
|
||||
// Make a link
|
||||
db.link_channel(a_id, livestreaming_id, zed_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2
|
||||
// zed -- crdb - livestreaming - livestreaming_dag
|
||||
// \---------/
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Create a new channel below a channel with multiple parents
|
||||
let livestreaming_dag_sub_id = db
|
||||
.create_channel(
|
||||
"livestreaming_dag_sub",
|
||||
Some(livestreaming_dag_id),
|
||||
"6",
|
||||
a_id,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2
|
||||
// zed -- crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
|
||||
// \---------/
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Test a complex DAG by making another link
|
||||
let returned_channels = db
|
||||
.link_channel(a_id, livestreaming_dag_sub_id, livestreaming_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2 /---------------------\
|
||||
// zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub_id
|
||||
// \--------/
|
||||
|
||||
// make sure we're getting just the new link
|
||||
// Not using the assert_dag helper because we want to make sure we're returning the full data
|
||||
pretty_assertions::assert_eq!(
|
||||
returned_channels,
|
||||
vec![Channel {
|
||||
id: livestreaming_dag_sub_id,
|
||||
name: "livestreaming_dag_sub".to_string(),
|
||||
parent_id: Some(livestreaming_id),
|
||||
}]
|
||||
);
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Test a complex DAG by making another link
|
||||
let returned_channels = db
|
||||
.link_channel(a_id, livestreaming_id, gpui2_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2 -\ /---------------------\
|
||||
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub_id
|
||||
// \---------/
|
||||
|
||||
// Make sure that we're correctly getting the full sub-dag
|
||||
pretty_assertions::assert_eq!(
|
||||
returned_channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(gpui2_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_dag_id,
|
||||
name: "livestreaming_dag".to_string(),
|
||||
parent_id: Some(livestreaming_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_dag_sub_id,
|
||||
name: "livestreaming_dag_sub".to_string(),
|
||||
parent_id: Some(livestreaming_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_dag_sub_id,
|
||||
name: "livestreaming_dag_sub".to_string(),
|
||||
parent_id: Some(livestreaming_dag_id),
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(gpui2_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Test unlinking in a complex DAG by removing the inner link
|
||||
db
|
||||
.unlink_channel(
|
||||
a_id,
|
||||
livestreaming_dag_sub_id,
|
||||
Some(livestreaming_id),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2 -\
|
||||
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
// \---------/
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(gpui2_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Test unlinking in a complex DAG by removing the inner link
|
||||
db.unlink_channel(a_id, livestreaming_id, Some(gpui2_id))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2
|
||||
// zed - crdb -- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
// \---------/
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Test moving DAG nodes by moving livestreaming to be below gpui2
|
||||
db.move_channel(a_id, livestreaming_id, Some(crdb_id), gpui2_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// /- gpui2 -- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
// zed - crdb /
|
||||
// \---------/
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(gpui2_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(gpui2_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Deleting a channel should not delete children that still have other parents
|
||||
db.delete_channel(gpui2_id, a_id).await.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// zed - crdb
|
||||
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Unlinking a channel from it's parent should automatically promote it to a root channel
|
||||
db.unlink_channel(a_id, crdb_id, Some(zed_id))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// crdb
|
||||
// zed
|
||||
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, None),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Unlinking a root channel should not have any effect
|
||||
db.unlink_channel(a_id, crdb_id, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// crdb
|
||||
// zed
|
||||
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
//
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, None),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// You should be able to move a root channel into a non-root channel
|
||||
db.move_channel(a_id, crdb_id, None, zed_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// zed - crdb
|
||||
// \- livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
|
||||
// ========================================================================
|
||||
// Moving a non-root channel without a parent id should be the equivalent of a link operation
|
||||
db.move_channel(a_id, livestreaming_id, None, crdb_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// DAG is now:
|
||||
// zed - crdb - livestreaming - livestreaming_dag - livestreaming_dag_sub
|
||||
// \--------/
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_dag(result.channels, &[
|
||||
(zed_id, None),
|
||||
(crdb_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(zed_id)),
|
||||
(livestreaming_id, Some(crdb_id)),
|
||||
(livestreaming_dag_id, Some(livestreaming_id)),
|
||||
(livestreaming_dag_sub_id, Some(livestreaming_dag_id)),
|
||||
]);
|
||||
|
||||
// ========================================================================
|
||||
// Deleting a parent of a DAG should delete the whole DAG:
|
||||
db.delete_channel(zed_id, a_id).await.unwrap();
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert!(
|
||||
result.channels.is_empty()
|
||||
)
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_dag(actual: Vec<Channel>, expected: &[(ChannelId, Option<ChannelId>)]) {
|
||||
let actual = actual
|
||||
.iter()
|
||||
.map(|channel| (channel.id, channel.parent_id))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
pretty_assertions::assert_eq!(actual, expected)
|
||||
}
|
||||
@@ -877,458 +877,6 @@ async fn test_invite_codes() {
|
||||
assert!(db.has_contact(user5, user1).await.unwrap());
|
||||
}
|
||||
|
||||
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
|
||||
|
||||
async fn test_channels(db: &Arc<Database>) {
|
||||
let a_id = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let b_id = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
|
||||
|
||||
// Make sure that people cannot read channels they haven't been invited to
|
||||
assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
|
||||
|
||||
db.invite_channel_member(zed_id, b_id, a_id, false)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.respond_to_channel_invite(zed_id, b_id, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let crdb_id = db
|
||||
.create_channel("crdb", Some(zed_id), "2", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let livestreaming_id = db
|
||||
.create_channel("livestreaming", Some(zed_id), "3", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let replace_id = db
|
||||
.create_channel("replace", Some(zed_id), "4", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut members = db.get_channel_members(replace_id).await.unwrap();
|
||||
members.sort();
|
||||
assert_eq!(members, &[a_id, b_id]);
|
||||
|
||||
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
|
||||
let cargo_id = db
|
||||
.create_channel("cargo", Some(rust_id), "6", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cargo_ra_id = db
|
||||
.create_channel("cargo-ra", Some(cargo_id), "7", a_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = db.get_channels_for_user(a_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: rust_id,
|
||||
name: "rust".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: cargo_id,
|
||||
name: "cargo".to_string(),
|
||||
parent_id: Some(rust_id),
|
||||
},
|
||||
Channel {
|
||||
id: cargo_ra_id,
|
||||
name: "cargo-ra".to_string(),
|
||||
parent_id: Some(cargo_id),
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
let result = db.get_channels_for_user(b_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Update member permissions
|
||||
let set_subchannel_admin = db.set_channel_member_admin(crdb_id, a_id, b_id, true).await;
|
||||
assert!(set_subchannel_admin.is_err());
|
||||
let set_channel_admin = db.set_channel_member_admin(zed_id, a_id, b_id, true).await;
|
||||
assert!(set_channel_admin.is_ok());
|
||||
|
||||
let result = db.get_channels_for_user(b_id).await.unwrap();
|
||||
assert_eq!(
|
||||
result.channels,
|
||||
vec![
|
||||
Channel {
|
||||
id: zed_id,
|
||||
name: "zed".to_string(),
|
||||
parent_id: None,
|
||||
},
|
||||
Channel {
|
||||
id: crdb_id,
|
||||
name: "crdb".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: livestreaming_id,
|
||||
name: "livestreaming".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
Channel {
|
||||
id: replace_id,
|
||||
name: "replace".to_string(),
|
||||
parent_id: Some(zed_id),
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Remove a single channel
|
||||
db.remove_channel(crdb_id, a_id).await.unwrap();
|
||||
assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
|
||||
|
||||
// Remove a channel tree
|
||||
let (mut channel_ids, user_ids) = db.remove_channel(rust_id, a_id).await.unwrap();
|
||||
channel_ids.sort();
|
||||
assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
|
||||
assert_eq!(user_ids, &[a_id]);
|
||||
|
||||
assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
|
||||
assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
|
||||
assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_joining_channels,
|
||||
test_joining_channels_postgres,
|
||||
test_joining_channels_sqlite
|
||||
);
|
||||
|
||||
async fn test_joining_channels(db: &Arc<Database>) {
|
||||
let owner_id = db.create_server("test").await.unwrap().0 as u32;
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let channel_1 = db
|
||||
.create_root_channel("channel_1", "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
|
||||
|
||||
// can join a room with membership to its channel
|
||||
let joined_room = db
|
||||
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(joined_room.room.participants.len(), 1);
|
||||
|
||||
drop(joined_room);
|
||||
// cannot join a room without membership to its channel
|
||||
assert!(db
|
||||
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
|
||||
.await
|
||||
.is_err());
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_channel_invites,
|
||||
test_channel_invites_postgres,
|
||||
test_channel_invites_sqlite
|
||||
);
|
||||
|
||||
async fn test_channel_invites(db: &Arc<Database>) {
|
||||
db.create_server("test").await.unwrap();
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user_3 = db
|
||||
.create_user(
|
||||
"user3@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user3".into(),
|
||||
github_user_id: 7,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let channel_1_1 = db
|
||||
.create_root_channel("channel_1", "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let channel_1_2 = db
|
||||
.create_root_channel("channel_2", "2", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.invite_channel_member(channel_1_1, user_2, user_1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.invite_channel_member(channel_1_2, user_2, user_1, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.invite_channel_member(channel_1_1, user_3, user_1, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let user_2_invites = db
|
||||
.get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|channel| channel.id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
|
||||
|
||||
let user_3_invites = db
|
||||
.get_channel_invites_for_user(user_3) // -> [channel_1_1]
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|channel| channel.id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(user_3_invites, &[channel_1_1]);
|
||||
|
||||
let members = db
|
||||
.get_channel_member_details(channel_1_1, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
members,
|
||||
&[
|
||||
proto::ChannelMember {
|
||||
user_id: user_1.to_proto(),
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
admin: true,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_2.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
admin: false,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_3.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
admin: true,
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
db.respond_to_channel_invite(channel_1_1, user_2, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let channel_1_3 = db
|
||||
.create_channel("channel_3", Some(channel_1_1), "1", user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let members = db
|
||||
.get_channel_member_details(channel_1_3, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
members,
|
||||
&[
|
||||
proto::ChannelMember {
|
||||
user_id: user_1.to_proto(),
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
admin: true,
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_2.to_proto(),
|
||||
kind: proto::channel_member::Kind::AncestorMember.into(),
|
||||
admin: false,
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
test_both_dbs!(
|
||||
test_channel_renames,
|
||||
test_channel_renames_postgres,
|
||||
test_channel_renames_sqlite
|
||||
);
|
||||
|
||||
async fn test_channel_renames(db: &Arc<Database>) {
|
||||
db.create_server("test").await.unwrap();
|
||||
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user1@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user1".into(),
|
||||
github_user_id: 5,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user2@example.com",
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user2".into(),
|
||||
github_user_id: 6,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
|
||||
|
||||
db.rename_channel(zed_id, user_1, "#zed-archive")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let zed_archive_id = zed_id;
|
||||
|
||||
let (channel, _) = db
|
||||
.get_channel(zed_archive_id, user_1)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(channel.name, "zed-archive");
|
||||
|
||||
let non_permissioned_rename = db
|
||||
.rename_channel(zed_archive_id, user_2, "hacked-lol")
|
||||
.await;
|
||||
assert!(non_permissioned_rename.is_err());
|
||||
|
||||
let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
|
||||
assert!(bad_name_rename.is_err())
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_multiple_signup_overwrite() {
|
||||
let test_db = TestDb::postgres(build_background_executor());
|
||||
|
||||
@@ -2,7 +2,10 @@ mod connection_pool;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
|
||||
db::{
|
||||
self, Channel, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User,
|
||||
UserId,
|
||||
},
|
||||
executor::Executor,
|
||||
AppState, Result,
|
||||
};
|
||||
@@ -243,7 +246,7 @@ impl Server {
|
||||
.add_request_handler(remove_contact)
|
||||
.add_request_handler(respond_to_contact_request)
|
||||
.add_request_handler(create_channel)
|
||||
.add_request_handler(remove_channel)
|
||||
.add_request_handler(delete_channel)
|
||||
.add_request_handler(invite_channel_member)
|
||||
.add_request_handler(remove_channel_member)
|
||||
.add_request_handler(set_channel_member_admin)
|
||||
@@ -251,9 +254,13 @@ impl Server {
|
||||
.add_request_handler(join_channel_buffer)
|
||||
.add_request_handler(leave_channel_buffer)
|
||||
.add_message_handler(update_channel_buffer)
|
||||
.add_request_handler(rejoin_channel_buffers)
|
||||
.add_request_handler(get_channel_members)
|
||||
.add_request_handler(respond_to_channel_invite)
|
||||
.add_request_handler(join_channel)
|
||||
.add_request_handler(link_channel)
|
||||
.add_request_handler(unlink_channel)
|
||||
.add_request_handler(move_channel)
|
||||
.add_request_handler(follow)
|
||||
.add_message_handler(unfollow)
|
||||
.add_message_handler(update_followers)
|
||||
@@ -277,13 +284,33 @@ impl Server {
|
||||
tracing::info!("waiting for cleanup timeout");
|
||||
timeout.await;
|
||||
tracing::info!("cleanup timeout expired, retrieving stale rooms");
|
||||
if let Some(room_ids) = app_state
|
||||
if let Some((room_ids, channel_ids)) = app_state
|
||||
.db
|
||||
.stale_room_ids(&app_state.config.zed_environment, server_id)
|
||||
.stale_server_resource_ids(&app_state.config.zed_environment, server_id)
|
||||
.await
|
||||
.trace_err()
|
||||
{
|
||||
tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
|
||||
tracing::info!(
|
||||
stale_channel_buffer_count = channel_ids.len(),
|
||||
"retrieved stale channel buffers"
|
||||
);
|
||||
|
||||
for channel_id in channel_ids {
|
||||
if let Some(refreshed_channel_buffer) = app_state
|
||||
.db
|
||||
.clear_stale_channel_buffer_collaborators(channel_id, server_id)
|
||||
.await
|
||||
.trace_err()
|
||||
{
|
||||
for connection_id in refreshed_channel_buffer.connection_ids {
|
||||
for message in &refreshed_channel_buffer.removed_collaborators {
|
||||
peer.send(connection_id, message.clone()).trace_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for room_id in room_ids {
|
||||
let mut contacts_to_update = HashSet::default();
|
||||
let mut canceled_calls_to_user_ids = Vec::new();
|
||||
@@ -292,7 +319,7 @@ impl Server {
|
||||
|
||||
if let Some(mut refreshed_room) = app_state
|
||||
.db
|
||||
.refresh_room(room_id, server_id)
|
||||
.clear_stale_room_participants(room_id, server_id)
|
||||
.await
|
||||
.trace_err()
|
||||
{
|
||||
@@ -854,13 +881,13 @@ async fn connection_lost(
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
leave_channel_buffers_for_session(&session)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
futures::select_biased! {
|
||||
_ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
|
||||
log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
|
||||
leave_room_for_session(&session).await.trace_err();
|
||||
leave_channel_buffers_for_session(&session)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
if !session
|
||||
.connection_pool()
|
||||
@@ -2206,23 +2233,23 @@ async fn create_channel(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_channel(
|
||||
request: proto::RemoveChannel,
|
||||
response: Response<proto::RemoveChannel>,
|
||||
async fn delete_channel(
|
||||
request: proto::DeleteChannel,
|
||||
response: Response<proto::DeleteChannel>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let channel_id = request.channel_id;
|
||||
let (removed_channels, member_ids) = db
|
||||
.remove_channel(ChannelId::from_proto(channel_id), session.user_id)
|
||||
.delete_channel(ChannelId::from_proto(channel_id), session.user_id)
|
||||
.await?;
|
||||
response.send(proto::Ack {})?;
|
||||
|
||||
// Notify members of removed channels
|
||||
let mut update = proto::UpdateChannels::default();
|
||||
update
|
||||
.remove_channels
|
||||
.delete_channels
|
||||
.extend(removed_channels.into_iter().map(|id| id.to_proto()));
|
||||
|
||||
let connection_pool = session.connection_pool().await;
|
||||
@@ -2282,7 +2309,7 @@ async fn remove_channel_member(
|
||||
.await?;
|
||||
|
||||
let mut update = proto::UpdateChannels::default();
|
||||
update.remove_channels.push(channel_id.to_proto());
|
||||
update.delete_channels.push(channel_id.to_proto());
|
||||
|
||||
for connection_id in session
|
||||
.connection_pool()
|
||||
@@ -2366,6 +2393,126 @@ async fn rename_channel(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn link_channel(
|
||||
request: proto::LinkChannel,
|
||||
response: Response<proto::LinkChannel>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let to = ChannelId::from_proto(request.to);
|
||||
let channels_to_send = db.link_channel(session.user_id, channel_id, to).await?;
|
||||
|
||||
let members = db.get_channel_members(to).await?;
|
||||
let connection_pool = session.connection_pool().await;
|
||||
let update = proto::UpdateChannels {
|
||||
channels: channels_to_send
|
||||
.into_iter()
|
||||
.map(|channel| proto::Channel {
|
||||
id: channel.id.to_proto(),
|
||||
name: channel.name,
|
||||
parent_id: channel.parent_id.map(ChannelId::to_proto),
|
||||
})
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
for member_id in members {
|
||||
for connection_id in connection_pool.user_connection_ids(member_id) {
|
||||
session.peer.send(connection_id, update.clone())?;
|
||||
}
|
||||
}
|
||||
|
||||
response.send(Ack {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn unlink_channel(
|
||||
request: proto::UnlinkChannel,
|
||||
response: Response<proto::UnlinkChannel>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let from = request.from.map(ChannelId::from_proto);
|
||||
db.unlink_channel(session.user_id, channel_id, from).await?;
|
||||
|
||||
if let Some(from_parent) = from {
|
||||
let members = db.get_channel_members(from_parent).await?;
|
||||
let update = proto::UpdateChannels {
|
||||
delete_channel_edge: vec![proto::ChannelEdge {
|
||||
channel_id: channel_id.to_proto(),
|
||||
parent_id: from_parent.to_proto(),
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let connection_pool = session.connection_pool().await;
|
||||
for member_id in members {
|
||||
for connection_id in connection_pool.user_connection_ids(member_id) {
|
||||
session.peer.send(connection_id, update.clone())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response.send(Ack {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn move_channel(
|
||||
request: proto::MoveChannel,
|
||||
response: Response<proto::MoveChannel>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let from_parent = request.from.map(ChannelId::from_proto);
|
||||
let to = ChannelId::from_proto(request.to);
|
||||
let channels_to_send: Vec<Channel> = db
|
||||
.move_channel(session.user_id, channel_id, from_parent, to)
|
||||
.await?;
|
||||
|
||||
if let Some(from_parent) = from_parent {
|
||||
let members = db.get_channel_members(from_parent).await?;
|
||||
let update = proto::UpdateChannels {
|
||||
delete_channel_edge: vec![proto::ChannelEdge {
|
||||
channel_id: channel_id.to_proto(),
|
||||
parent_id: from_parent.to_proto(),
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let connection_pool = session.connection_pool().await;
|
||||
for member_id in members {
|
||||
for connection_id in connection_pool.user_connection_ids(member_id) {
|
||||
session.peer.send(connection_id, update.clone())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let members = db.get_channel_members(to).await?;
|
||||
let connection_pool = session.connection_pool().await;
|
||||
let update = proto::UpdateChannels {
|
||||
channels: channels_to_send
|
||||
.into_iter()
|
||||
.map(|channel| proto::Channel {
|
||||
id: channel.id.to_proto(),
|
||||
name: channel.name,
|
||||
parent_id: channel.parent_id.map(ChannelId::to_proto),
|
||||
})
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
for member_id in members {
|
||||
for connection_id in connection_pool.user_connection_ids(member_id) {
|
||||
session.peer.send(connection_id, update.clone())?;
|
||||
}
|
||||
}
|
||||
|
||||
response.send(Ack {})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_channel_members(
|
||||
request: proto::GetChannelMembers,
|
||||
response: Response<proto::GetChannelMembers>,
|
||||
@@ -2547,6 +2694,41 @@ async fn update_channel_buffer(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn rejoin_channel_buffers(
|
||||
request: proto::RejoinChannelBuffers,
|
||||
response: Response<proto::RejoinChannelBuffers>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let buffers = db
|
||||
.rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
|
||||
.await?;
|
||||
|
||||
for buffer in &buffers {
|
||||
let collaborators_to_notify = buffer
|
||||
.buffer
|
||||
.collaborators
|
||||
.iter()
|
||||
.filter_map(|c| Some(c.peer_id?.into()));
|
||||
channel_buffer_updated(
|
||||
session.connection_id,
|
||||
collaborators_to_notify,
|
||||
&proto::UpdateChannelBufferCollaborator {
|
||||
channel_id: buffer.buffer.channel_id,
|
||||
old_peer_id: Some(buffer.old_connection_id.into()),
|
||||
new_peer_id: Some(session.connection_id.into()),
|
||||
},
|
||||
&session.peer,
|
||||
);
|
||||
}
|
||||
|
||||
response.send(proto::RejoinChannelBuffersResponse {
|
||||
buffers: buffers.into_iter().map(|b| b.buffer).collect(),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn leave_channel_buffer(
|
||||
request: proto::LeaveChannelBuffer,
|
||||
response: Response<proto::LeaveChannelBuffer>,
|
||||
|
||||
@@ -1,555 +1,18 @@
|
||||
use crate::{
|
||||
db::{tests::TestDb, NewUserParams, UserId},
|
||||
executor::Executor,
|
||||
rpc::{Server, CLEANUP_TIMEOUT},
|
||||
AppState,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use call::{ActiveCall, Room};
|
||||
use channel::ChannelStore;
|
||||
use client::{
|
||||
self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
|
||||
};
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::FakeFs;
|
||||
use futures::{channel::oneshot, StreamExt as _};
|
||||
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
|
||||
use language::LanguageRegistry;
|
||||
use parking_lot::Mutex;
|
||||
use project::{Project, WorktreeId};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Ref, RefCell, RefMut},
|
||||
env,
|
||||
ops::{Deref, DerefMut},
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use util::http::FakeHttpClient;
|
||||
use workspace::Workspace;
|
||||
use call::Room;
|
||||
use gpui::{ModelHandle, TestAppContext};
|
||||
|
||||
mod channel_buffer_tests;
|
||||
mod channel_tests;
|
||||
mod integration_tests;
|
||||
mod randomized_integration_tests;
|
||||
mod random_channel_buffer_tests;
|
||||
mod random_project_collaboration_tests;
|
||||
mod randomized_test_helpers;
|
||||
mod test_server;
|
||||
|
||||
struct TestServer {
|
||||
app_state: Arc<AppState>,
|
||||
server: Arc<Server>,
|
||||
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
|
||||
forbid_connections: Arc<AtomicBool>,
|
||||
_test_db: TestDb,
|
||||
test_live_kit_server: Arc<live_kit_client::TestServer>,
|
||||
}
|
||||
|
||||
impl TestServer {
|
||||
async fn start(deterministic: &Arc<Deterministic>) -> Self {
|
||||
static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
let use_postgres = env::var("USE_POSTGRES").ok();
|
||||
let use_postgres = use_postgres.as_deref();
|
||||
let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
|
||||
TestDb::postgres(deterministic.build_background())
|
||||
} else {
|
||||
TestDb::sqlite(deterministic.build_background())
|
||||
};
|
||||
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
|
||||
let live_kit_server = live_kit_client::TestServer::create(
|
||||
format!("http://livekit.{}.test", live_kit_server_id),
|
||||
format!("devkey-{}", live_kit_server_id),
|
||||
format!("secret-{}", live_kit_server_id),
|
||||
deterministic.build_background(),
|
||||
)
|
||||
.unwrap();
|
||||
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
|
||||
let epoch = app_state
|
||||
.db
|
||||
.create_server(&app_state.config.zed_environment)
|
||||
.await
|
||||
.unwrap();
|
||||
let server = Server::new(
|
||||
epoch,
|
||||
app_state.clone(),
|
||||
Executor::Deterministic(deterministic.build_background()),
|
||||
);
|
||||
server.start().await.unwrap();
|
||||
// Advance clock to ensure the server's cleanup task is finished.
|
||||
deterministic.advance_clock(CLEANUP_TIMEOUT);
|
||||
Self {
|
||||
app_state,
|
||||
server,
|
||||
connection_killers: Default::default(),
|
||||
forbid_connections: Default::default(),
|
||||
_test_db: test_db,
|
||||
test_live_kit_server: live_kit_server,
|
||||
}
|
||||
}
|
||||
|
||||
async fn reset(&self) {
|
||||
self.app_state.db.reset();
|
||||
let epoch = self
|
||||
.app_state
|
||||
.db
|
||||
.create_server(&self.app_state.config.zed_environment)
|
||||
.await
|
||||
.unwrap();
|
||||
self.server.reset(epoch);
|
||||
}
|
||||
|
||||
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
||||
cx.update(|cx| {
|
||||
if cx.has_global::<SettingsStore>() {
|
||||
panic!("Same cx used to create two test clients")
|
||||
}
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
});
|
||||
|
||||
let http = FakeHttpClient::with_404_response();
|
||||
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
|
||||
{
|
||||
user.id
|
||||
} else {
|
||||
self.app_state
|
||||
.db
|
||||
.create_user(
|
||||
&format!("{name}@example.com"),
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: name.into(),
|
||||
github_user_id: 0,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("creating user failed")
|
||||
.user_id
|
||||
};
|
||||
let client_name = name.to_string();
|
||||
let mut client = cx.read(|cx| Client::new(http.clone(), cx));
|
||||
let server = self.server.clone();
|
||||
let db = self.app_state.db.clone();
|
||||
let connection_killers = self.connection_killers.clone();
|
||||
let forbid_connections = self.forbid_connections.clone();
|
||||
|
||||
Arc::get_mut(&mut client)
|
||||
.unwrap()
|
||||
.set_id(user_id.0 as usize)
|
||||
.override_authenticate(move |cx| {
|
||||
cx.spawn(|_| async move {
|
||||
let access_token = "the-token".to_string();
|
||||
Ok(Credentials {
|
||||
user_id: user_id.0 as u64,
|
||||
access_token,
|
||||
})
|
||||
})
|
||||
})
|
||||
.override_establish_connection(move |credentials, cx| {
|
||||
assert_eq!(credentials.user_id, user_id.0 as u64);
|
||||
assert_eq!(credentials.access_token, "the-token");
|
||||
|
||||
let server = server.clone();
|
||||
let db = db.clone();
|
||||
let connection_killers = connection_killers.clone();
|
||||
let forbid_connections = forbid_connections.clone();
|
||||
let client_name = client_name.clone();
|
||||
cx.spawn(move |cx| async move {
|
||||
if forbid_connections.load(SeqCst) {
|
||||
Err(EstablishConnectionError::other(anyhow!(
|
||||
"server is forbidding connections"
|
||||
)))
|
||||
} else {
|
||||
let (client_conn, server_conn, killed) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (connection_id_tx, connection_id_rx) = oneshot::channel();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
.await
|
||||
.expect("retrieving user failed")
|
||||
.unwrap();
|
||||
cx.background()
|
||||
.spawn(server.handle_connection(
|
||||
server_conn,
|
||||
client_name,
|
||||
user,
|
||||
Some(connection_id_tx),
|
||||
Executor::Deterministic(cx.background()),
|
||||
))
|
||||
.detach();
|
||||
let connection_id = connection_id_rx.await.unwrap();
|
||||
connection_killers
|
||||
.lock()
|
||||
.insert(connection_id.into(), killed);
|
||||
Ok(client_conn)
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
|
||||
let channel_store =
|
||||
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
|
||||
let app_state = Arc::new(workspace::AppState {
|
||||
client: client.clone(),
|
||||
user_store: user_store.clone(),
|
||||
channel_store: channel_store.clone(),
|
||||
languages: Arc::new(LanguageRegistry::test()),
|
||||
fs: fs.clone(),
|
||||
build_window_options: |_, _, _| Default::default(),
|
||||
initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
|
||||
background_actions: || &[],
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
theme::init((), cx);
|
||||
Project::init(&client, cx);
|
||||
client::init(&client, cx);
|
||||
language::init(cx);
|
||||
editor::init_settings(cx);
|
||||
workspace::init(app_state.clone(), cx);
|
||||
audio::init((), cx);
|
||||
call::init(client.clone(), user_store.clone(), cx);
|
||||
channel::init(&client);
|
||||
});
|
||||
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let client = TestClient {
|
||||
app_state,
|
||||
username: name.to_string(),
|
||||
state: Default::default(),
|
||||
};
|
||||
client.wait_for_current_user(cx).await;
|
||||
client
|
||||
}
|
||||
|
||||
fn disconnect_client(&self, peer_id: PeerId) {
|
||||
self.connection_killers
|
||||
.lock()
|
||||
.remove(&peer_id)
|
||||
.unwrap()
|
||||
.store(true, SeqCst);
|
||||
}
|
||||
|
||||
fn forbid_connections(&self) {
|
||||
self.forbid_connections.store(true, SeqCst);
|
||||
}
|
||||
|
||||
fn allow_connections(&self) {
|
||||
self.forbid_connections.store(false, SeqCst);
|
||||
}
|
||||
|
||||
async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
|
||||
for ix in 1..clients.len() {
|
||||
let (left, right) = clients.split_at_mut(ix);
|
||||
let (client_a, cx_a) = left.last_mut().unwrap();
|
||||
for (client_b, cx_b) in right {
|
||||
client_a
|
||||
.app_state
|
||||
.user_store
|
||||
.update(*cx_a, |store, cx| {
|
||||
store.request_contact(client_b.user_id().unwrap(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx_a.foreground().run_until_parked();
|
||||
client_b
|
||||
.app_state
|
||||
.user_store
|
||||
.update(*cx_b, |store, cx| {
|
||||
store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn make_channel(
|
||||
&self,
|
||||
channel: &str,
|
||||
admin: (&TestClient, &mut TestAppContext),
|
||||
members: &mut [(&TestClient, &mut TestAppContext)],
|
||||
) -> u64 {
|
||||
let (admin_client, admin_cx) = admin;
|
||||
let channel_id = admin_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(admin_cx, |channel_store, cx| {
|
||||
channel_store.create_channel(channel, None, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for (member_client, member_cx) in members {
|
||||
admin_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(admin_cx, |channel_store, cx| {
|
||||
channel_store.invite_member(
|
||||
channel_id,
|
||||
member_client.user_id().unwrap(),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
admin_cx.foreground().run_until_parked();
|
||||
|
||||
member_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(*member_cx, |channels, _| {
|
||||
channels.respond_to_channel_invite(channel_id, true)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
channel_id
|
||||
}
|
||||
|
||||
async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
|
||||
self.make_contacts(clients).await;
|
||||
|
||||
let (left, right) = clients.split_at_mut(1);
|
||||
let (_client_a, cx_a) = &mut left[0];
|
||||
let active_call_a = cx_a.read(ActiveCall::global);
|
||||
|
||||
for (client_b, cx_b) in right {
|
||||
let user_id_b = client_b.current_user_id(*cx_b).to_proto();
|
||||
active_call_a
|
||||
.update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx_b.foreground().run_until_parked();
|
||||
let active_call_b = cx_b.read(ActiveCall::global);
|
||||
active_call_b
|
||||
.update(*cx_b, |call, cx| call.accept_incoming(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_app_state(
|
||||
test_db: &TestDb,
|
||||
fake_server: &live_kit_client::TestServer,
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
|
||||
config: Default::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for TestServer {
|
||||
type Target = Server;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.server
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
self.server.teardown();
|
||||
self.test_live_kit_server.teardown().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
struct TestClient {
|
||||
username: String,
|
||||
state: RefCell<TestClientState>,
|
||||
app_state: Arc<workspace::AppState>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestClientState {
|
||||
local_projects: Vec<ModelHandle<Project>>,
|
||||
remote_projects: Vec<ModelHandle<Project>>,
|
||||
buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
|
||||
}
|
||||
|
||||
impl Deref for TestClient {
|
||||
type Target = Arc<Client>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.app_state.client
|
||||
}
|
||||
}
|
||||
|
||||
struct ContactsSummary {
|
||||
pub current: Vec<String>,
|
||||
pub outgoing_requests: Vec<String>,
|
||||
pub incoming_requests: Vec<String>,
|
||||
}
|
||||
|
||||
impl TestClient {
|
||||
pub fn fs(&self) -> &FakeFs {
|
||||
self.app_state.fs.as_fake()
|
||||
}
|
||||
|
||||
pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
|
||||
&self.app_state.channel_store
|
||||
}
|
||||
|
||||
pub fn user_store(&self) -> &ModelHandle<UserStore> {
|
||||
&self.app_state.user_store
|
||||
}
|
||||
|
||||
pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
|
||||
&self.app_state.languages
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Arc<Client> {
|
||||
&self.app_state.client
|
||||
}
|
||||
|
||||
pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
|
||||
UserId::from_proto(
|
||||
self.app_state
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
|
||||
)
|
||||
}
|
||||
|
||||
async fn wait_for_current_user(&self, cx: &TestAppContext) {
|
||||
let mut authed_user = self
|
||||
.app_state
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _| user_store.watch_current_user());
|
||||
while authed_user.next().await.unwrap().is_none() {}
|
||||
}
|
||||
|
||||
async fn clear_contacts(&self, cx: &mut TestAppContext) {
|
||||
self.app_state
|
||||
.user_store
|
||||
.update(cx, |store, _| store.clear_contacts())
|
||||
.await;
|
||||
}
|
||||
|
||||
fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
Ref::map(self.state.borrow(), |state| &state.local_projects)
|
||||
}
|
||||
|
||||
fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
Ref::map(self.state.borrow(), |state| &state.remote_projects)
|
||||
}
|
||||
|
||||
fn local_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
|
||||
}
|
||||
|
||||
fn remote_projects_mut<'a>(&'a self) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
|
||||
}
|
||||
|
||||
fn buffers_for_project<'a>(
|
||||
&'a self,
|
||||
project: &ModelHandle<Project>,
|
||||
) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| {
|
||||
state.buffers.entry(project.clone()).or_default()
|
||||
})
|
||||
}
|
||||
|
||||
fn buffers<'a>(
|
||||
&'a self,
|
||||
) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
|
||||
{
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
|
||||
}
|
||||
|
||||
fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
|
||||
self.app_state
|
||||
.user_store
|
||||
.read_with(cx, |store, _| ContactsSummary {
|
||||
current: store
|
||||
.contacts()
|
||||
.iter()
|
||||
.map(|contact| contact.user.github_login.clone())
|
||||
.collect(),
|
||||
outgoing_requests: store
|
||||
.outgoing_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
incoming_requests: store
|
||||
.incoming_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn build_local_project(
|
||||
&self,
|
||||
root_path: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> (ModelHandle<Project>, WorktreeId) {
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
self.client().clone(),
|
||||
self.app_state.user_store.clone(),
|
||||
self.app_state.languages.clone(),
|
||||
self.app_state.fs.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let (worktree, _) = project
|
||||
.update(cx, |p, cx| {
|
||||
p.find_or_create_local_worktree(root_path, true, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
worktree
|
||||
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
|
||||
.await;
|
||||
(project, worktree.read_with(cx, |tree, _| tree.id()))
|
||||
}
|
||||
|
||||
async fn build_remote_project(
|
||||
&self,
|
||||
host_project_id: u64,
|
||||
guest_cx: &mut TestAppContext,
|
||||
) -> ModelHandle<Project> {
|
||||
let active_call = guest_cx.read(ActiveCall::global);
|
||||
let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
|
||||
room.update(guest_cx, |room, cx| {
|
||||
room.join_project(
|
||||
host_project_id,
|
||||
self.app_state.languages.clone(),
|
||||
self.app_state.fs.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn build_workspace(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> WindowHandle<Workspace> {
|
||||
cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestClient {
|
||||
fn drop(&mut self) {
|
||||
self.app_state.client.teardown();
|
||||
}
|
||||
}
|
||||
pub use randomized_test_helpers::{
|
||||
run_randomized_test, save_randomized_test_plan, RandomizedTest, TestError, UserTestPlan,
|
||||
};
|
||||
pub use test_server::{TestClient, TestServer};
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
struct RoomParticipants {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
|
||||
use crate::{
|
||||
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
|
||||
tests::TestServer,
|
||||
};
|
||||
use call::ActiveCall;
|
||||
use channel::Channel;
|
||||
use client::UserId;
|
||||
@@ -21,20 +24,19 @@ async fn test_core_channel_buffers(
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
let zed_id = server
|
||||
let channel_id = server
|
||||
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
|
||||
.await;
|
||||
|
||||
// Client A joins the channel buffer
|
||||
let channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Client A edits the buffer
|
||||
let buffer_a = channel_buffer_a.read_with(cx_a, |buffer, _| buffer.buffer());
|
||||
|
||||
buffer_a.update(cx_a, |buffer, cx| {
|
||||
buffer.edit([(0..0, "hello world")], None, cx)
|
||||
});
|
||||
@@ -45,17 +47,15 @@ async fn test_core_channel_buffers(
|
||||
buffer.edit([(0..5, "goodbye")], None, cx)
|
||||
});
|
||||
buffer_a.update(cx_a, |buffer, cx| buffer.undo(cx));
|
||||
deterministic.run_until_parked();
|
||||
|
||||
assert_eq!(buffer_text(&buffer_a, cx_a), "hello, cruel world");
|
||||
deterministic.run_until_parked();
|
||||
|
||||
// Client B joins the channel buffer
|
||||
let channel_buffer_b = client_b
|
||||
.channel_store()
|
||||
.update(cx_b, |channel, cx| channel.open_channel_buffer(zed_id, cx))
|
||||
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
channel_buffer_b.read_with(cx_b, |buffer, _| {
|
||||
assert_collaborators(
|
||||
buffer.collaborators(),
|
||||
@@ -91,9 +91,7 @@ async fn test_core_channel_buffers(
|
||||
// Client A rejoins the channel buffer
|
||||
let _channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channels, cx| {
|
||||
channels.open_channel_buffer(zed_id, cx)
|
||||
})
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
deterministic.run_until_parked();
|
||||
@@ -136,7 +134,7 @@ async fn test_channel_buffer_replica_ids(
|
||||
|
||||
let channel_id = server
|
||||
.make_channel(
|
||||
"zed",
|
||||
"the-channel",
|
||||
(&client_a, cx_a),
|
||||
&mut [(&client_b, cx_b), (&client_c, cx_c)],
|
||||
)
|
||||
@@ -160,23 +158,17 @@ async fn test_channel_buffer_replica_ids(
|
||||
// C first so that the replica IDs in the project and the channel buffer are different
|
||||
let channel_buffer_c = client_c
|
||||
.channel_store()
|
||||
.update(cx_c, |channel, cx| {
|
||||
channel.open_channel_buffer(channel_id, cx)
|
||||
})
|
||||
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_buffer_b = client_b
|
||||
.channel_store()
|
||||
.update(cx_b, |channel, cx| {
|
||||
channel.open_channel_buffer(channel_id, cx)
|
||||
})
|
||||
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| {
|
||||
channel.open_channel_buffer(channel_id, cx)
|
||||
})
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -286,28 +278,30 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
|
||||
let mut server = TestServer::start(&deterministic).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
|
||||
let zed_id = server.make_channel("zed", (&client_a, cx_a), &mut []).await;
|
||||
let channel_id = server
|
||||
.make_channel("the-channel", (&client_a, cx_a), &mut [])
|
||||
.await;
|
||||
|
||||
let channel_buffer_1 = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
|
||||
let channel_buffer_2 = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
|
||||
let channel_buffer_3 = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx));
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx));
|
||||
|
||||
// All concurrent tasks for opening a channel buffer return the same model handle.
|
||||
let (channel_buffer_1, channel_buffer_2, channel_buffer_3) =
|
||||
let (channel_buffer, channel_buffer_2, channel_buffer_3) =
|
||||
future::try_join3(channel_buffer_1, channel_buffer_2, channel_buffer_3)
|
||||
.await
|
||||
.unwrap();
|
||||
let model_id = channel_buffer_1.id();
|
||||
assert_eq!(channel_buffer_1, channel_buffer_2);
|
||||
assert_eq!(channel_buffer_1, channel_buffer_3);
|
||||
let channel_buffer_model_id = channel_buffer.id();
|
||||
assert_eq!(channel_buffer, channel_buffer_2);
|
||||
assert_eq!(channel_buffer, channel_buffer_3);
|
||||
|
||||
channel_buffer_1.update(cx_a, |buffer, cx| {
|
||||
channel_buffer.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "hello")], None, cx);
|
||||
})
|
||||
@@ -315,7 +309,7 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
|
||||
deterministic.run_until_parked();
|
||||
|
||||
cx_a.update(|_| {
|
||||
drop(channel_buffer_1);
|
||||
drop(channel_buffer);
|
||||
drop(channel_buffer_2);
|
||||
drop(channel_buffer_3);
|
||||
});
|
||||
@@ -324,10 +318,10 @@ async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mu
|
||||
// The channel buffer can be reopened after dropping it.
|
||||
let channel_buffer = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| channel.open_channel_buffer(zed_id, cx))
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_ne!(channel_buffer.id(), model_id);
|
||||
assert_ne!(channel_buffer.id(), channel_buffer_model_id);
|
||||
channel_buffer.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, _| {
|
||||
assert_eq!(buffer.text(), "hello");
|
||||
@@ -347,22 +341,17 @@ async fn test_channel_buffer_disconnect(
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
let channel_id = server
|
||||
.make_channel("zed", (&client_a, cx_a), &mut [(&client_b, cx_b)])
|
||||
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
|
||||
.await;
|
||||
|
||||
let channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel, cx| {
|
||||
channel.open_channel_buffer(channel_id, cx)
|
||||
})
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let channel_buffer_b = client_b
|
||||
.channel_store()
|
||||
.update(cx_b, |channel, cx| {
|
||||
channel.open_channel_buffer(channel_id, cx)
|
||||
})
|
||||
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -375,7 +364,7 @@ async fn test_channel_buffer_disconnect(
|
||||
buffer.channel().as_ref(),
|
||||
&Channel {
|
||||
id: channel_id,
|
||||
name: "zed".to_string()
|
||||
name: "the-channel".to_string()
|
||||
}
|
||||
);
|
||||
assert!(!buffer.is_connected());
|
||||
@@ -403,13 +392,180 @@ async fn test_channel_buffer_disconnect(
|
||||
buffer.channel().as_ref(),
|
||||
&Channel {
|
||||
id: channel_id,
|
||||
name: "zed".to_string()
|
||||
name: "the-channel".to_string()
|
||||
}
|
||||
);
|
||||
assert!(!buffer.is_connected());
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_rejoin_channel_buffer(
|
||||
deterministic: Arc<Deterministic>,
|
||||
cx_a: &mut TestAppContext,
|
||||
cx_b: &mut TestAppContext,
|
||||
) {
|
||||
deterministic.forbid_parking();
|
||||
let mut server = TestServer::start(&deterministic).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
let channel_id = server
|
||||
.make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
|
||||
.await;
|
||||
|
||||
let channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_buffer_b = client_b
|
||||
.channel_store()
|
||||
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
channel_buffer_a.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "1")], None, cx);
|
||||
})
|
||||
});
|
||||
deterministic.run_until_parked();
|
||||
|
||||
// Client A disconnects.
|
||||
server.forbid_connections();
|
||||
server.disconnect_client(client_a.peer_id().unwrap());
|
||||
|
||||
// Both clients make an edit.
|
||||
channel_buffer_a.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(1..1, "2")], None, cx);
|
||||
})
|
||||
});
|
||||
channel_buffer_b.update(cx_b, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "0")], None, cx);
|
||||
})
|
||||
});
|
||||
|
||||
// Both clients see their own edit.
|
||||
deterministic.run_until_parked();
|
||||
channel_buffer_a.read_with(cx_a, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "12");
|
||||
});
|
||||
channel_buffer_b.read_with(cx_b, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "01");
|
||||
});
|
||||
|
||||
// Client A reconnects. Both clients see each other's edits, and see
|
||||
// the same collaborators.
|
||||
server.allow_connections();
|
||||
deterministic.advance_clock(RECEIVE_TIMEOUT);
|
||||
channel_buffer_a.read_with(cx_a, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "012");
|
||||
});
|
||||
channel_buffer_b.read_with(cx_b, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "012");
|
||||
});
|
||||
|
||||
channel_buffer_a.read_with(cx_a, |buffer_a, _| {
|
||||
channel_buffer_b.read_with(cx_b, |buffer_b, _| {
|
||||
assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_buffers_and_server_restarts(
|
||||
deterministic: Arc<Deterministic>,
|
||||
cx_a: &mut TestAppContext,
|
||||
cx_b: &mut TestAppContext,
|
||||
cx_c: &mut TestAppContext,
|
||||
) {
|
||||
deterministic.forbid_parking();
|
||||
let mut server = TestServer::start(&deterministic).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
let client_c = server.create_client(cx_c, "user_c").await;
|
||||
|
||||
let channel_id = server
|
||||
.make_channel(
|
||||
"the-channel",
|
||||
(&client_a, cx_a),
|
||||
&mut [(&client_b, cx_b), (&client_c, cx_c)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let channel_buffer_a = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_buffer_b = client_b
|
||||
.channel_store()
|
||||
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let _channel_buffer_c = client_c
|
||||
.channel_store()
|
||||
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
channel_buffer_a.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "1")], None, cx);
|
||||
})
|
||||
});
|
||||
deterministic.run_until_parked();
|
||||
|
||||
// Client C can't reconnect.
|
||||
client_c.override_establish_connection(|_, cx| cx.spawn(|_| future::pending()));
|
||||
|
||||
// Server stops.
|
||||
server.reset().await;
|
||||
deterministic.advance_clock(RECEIVE_TIMEOUT);
|
||||
|
||||
// While the server is down, both clients make an edit.
|
||||
channel_buffer_a.update(cx_a, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(1..1, "2")], None, cx);
|
||||
})
|
||||
});
|
||||
channel_buffer_b.update(cx_b, |buffer, cx| {
|
||||
buffer.buffer().update(cx, |buffer, cx| {
|
||||
buffer.edit([(0..0, "0")], None, cx);
|
||||
})
|
||||
});
|
||||
|
||||
// Server restarts.
|
||||
server.start().await.unwrap();
|
||||
deterministic.advance_clock(CLEANUP_TIMEOUT);
|
||||
|
||||
// Clients reconnects. Clients A and B see each other's edits, and see
|
||||
// that client C has disconnected.
|
||||
channel_buffer_a.read_with(cx_a, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "012");
|
||||
});
|
||||
channel_buffer_b.read_with(cx_b, |buffer, cx| {
|
||||
assert_eq!(buffer.buffer().read(cx).text(), "012");
|
||||
});
|
||||
|
||||
channel_buffer_a.read_with(cx_a, |buffer_a, _| {
|
||||
channel_buffer_b.read_with(cx_b, |buffer_b, _| {
|
||||
assert_eq!(
|
||||
buffer_a
|
||||
.collaborators()
|
||||
.iter()
|
||||
.map(|c| c.user_id)
|
||||
.collect::<Vec<_>>(),
|
||||
vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()]
|
||||
);
|
||||
assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
|
||||
assert_eq!(
|
||||
|
||||
@@ -874,6 +874,143 @@ async fn test_lost_channel_creation(
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_channel_moving(deterministic: Arc<Deterministic>, cx_a: &mut TestAppContext) {
|
||||
deterministic.forbid_parking();
|
||||
let mut server = TestServer::start(&deterministic).await;
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
|
||||
let channel_a_id = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel_store, cx| {
|
||||
channel_store.create_channel("channel-a", None, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_b_id = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel_store, cx| {
|
||||
channel_store.create_channel("channel-b", Some(channel_a_id), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let channel_c_id = client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel_store, cx| {
|
||||
channel_store.create_channel("channel-c", Some(channel_b_id), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Current shape:
|
||||
// a - b - c
|
||||
deterministic.run_until_parked();
|
||||
assert_channels(
|
||||
client_a.channel_store(),
|
||||
cx_a,
|
||||
&[
|
||||
ExpectedChannel {
|
||||
id: channel_a_id,
|
||||
name: "channel-a".to_string(),
|
||||
depth: 0,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_b_id,
|
||||
name: "channel-b".to_string(),
|
||||
depth: 1,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_c_id,
|
||||
name: "channel-c".to_string(),
|
||||
depth: 2,
|
||||
user_is_admin: true,
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel_store, cx| {
|
||||
channel_store.move_channel(channel_c_id, Some(channel_b_id), channel_a_id, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Current shape:
|
||||
// /- c
|
||||
// a -- b
|
||||
deterministic.run_until_parked();
|
||||
assert_channels(
|
||||
client_a.channel_store(),
|
||||
cx_a,
|
||||
&[
|
||||
ExpectedChannel {
|
||||
id: channel_a_id,
|
||||
name: "channel-a".to_string(),
|
||||
depth: 0,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_b_id,
|
||||
name: "channel-b".to_string(),
|
||||
depth: 1,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_c_id,
|
||||
name: "channel-c".to_string(),
|
||||
depth: 1,
|
||||
user_is_admin: true,
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
client_a
|
||||
.channel_store()
|
||||
.update(cx_a, |channel_store, cx| {
|
||||
channel_store.link_channel(channel_c_id, channel_b_id, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Current shape:
|
||||
// /------\
|
||||
// a -- b -- c
|
||||
deterministic.run_until_parked();
|
||||
assert_channels(
|
||||
client_a.channel_store(),
|
||||
cx_a,
|
||||
&[
|
||||
ExpectedChannel {
|
||||
id: channel_a_id,
|
||||
name: "channel-a".to_string(),
|
||||
depth: 0,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_b_id,
|
||||
name: "channel-b".to_string(),
|
||||
depth: 1,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_c_id,
|
||||
name: "channel-c".to_string(),
|
||||
depth: 2,
|
||||
user_is_admin: true,
|
||||
},
|
||||
ExpectedChannel {
|
||||
id: channel_c_id,
|
||||
name: "channel-c".to_string(),
|
||||
depth: 1,
|
||||
user_is_admin: true,
|
||||
},
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct ExpectedChannel {
|
||||
depth: usize,
|
||||
@@ -920,5 +1057,5 @@ fn assert_channels(
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
assert_eq!(actual, expected_channels);
|
||||
pretty_assertions::assert_eq!(actual, expected_channels);
|
||||
}
|
||||
|
||||
288
crates/collab/src/tests/random_channel_buffer_tests.rs
Normal file
288
crates/collab/src/tests/random_channel_buffer_tests.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use super::{run_randomized_test, RandomizedTest, TestClient, TestError, TestServer, UserTestPlan};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{executor::Deterministic, TestAppContext};
|
||||
use rand::prelude::*;
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
use std::{ops::Range, rc::Rc, sync::Arc};
|
||||
use text::Bias;
|
||||
|
||||
#[gpui::test(
|
||||
iterations = 100,
|
||||
on_failure = "crate::tests::save_randomized_test_plan"
|
||||
)]
|
||||
async fn test_random_channel_buffers(
|
||||
cx: &mut TestAppContext,
|
||||
deterministic: Arc<Deterministic>,
|
||||
rng: StdRng,
|
||||
) {
|
||||
run_randomized_test::<RandomChannelBufferTest>(cx, deterministic, rng).await;
|
||||
}
|
||||
|
||||
struct RandomChannelBufferTest;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
enum ChannelBufferOperation {
|
||||
JoinChannelNotes {
|
||||
channel_name: String,
|
||||
},
|
||||
LeaveChannelNotes {
|
||||
channel_name: String,
|
||||
},
|
||||
EditChannelNotes {
|
||||
channel_name: String,
|
||||
edits: Vec<(Range<usize>, Arc<str>)>,
|
||||
},
|
||||
Noop,
|
||||
}
|
||||
|
||||
const CHANNEL_COUNT: usize = 3;
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl RandomizedTest for RandomChannelBufferTest {
|
||||
type Operation = ChannelBufferOperation;
|
||||
|
||||
async fn initialize(server: &mut TestServer, users: &[UserTestPlan]) {
|
||||
let db = &server.app_state.db;
|
||||
for ix in 0..CHANNEL_COUNT {
|
||||
let id = db
|
||||
.create_channel(
|
||||
&format!("channel-{ix}"),
|
||||
None,
|
||||
&format!("livekit-room-{ix}"),
|
||||
users[0].user_id,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
for user in &users[1..] {
|
||||
db.invite_channel_member(id, user.user_id, users[0].user_id, false)
|
||||
.await
|
||||
.unwrap();
|
||||
db.respond_to_channel_invite(id, user.user_id, true)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_operation(
|
||||
client: &TestClient,
|
||||
rng: &mut StdRng,
|
||||
_: &mut UserTestPlan,
|
||||
cx: &TestAppContext,
|
||||
) -> ChannelBufferOperation {
|
||||
let channel_store = client.channel_store().clone();
|
||||
let channel_buffers = client.channel_buffers();
|
||||
|
||||
// When signed out, we can't do anything unless a channel buffer is
|
||||
// already open.
|
||||
if channel_buffers.is_empty()
|
||||
&& channel_store.read_with(cx, |store, _| store.channel_count() == 0)
|
||||
{
|
||||
return ChannelBufferOperation::Noop;
|
||||
}
|
||||
|
||||
loop {
|
||||
match rng.gen_range(0..100_u32) {
|
||||
0..=29 => {
|
||||
let channel_name = client.channel_store().read_with(cx, |store, cx| {
|
||||
store.channels().find_map(|(_, channel)| {
|
||||
if store.has_open_channel_buffer(channel.id, cx) {
|
||||
None
|
||||
} else {
|
||||
Some(channel.name.clone())
|
||||
}
|
||||
})
|
||||
});
|
||||
if let Some(channel_name) = channel_name {
|
||||
break ChannelBufferOperation::JoinChannelNotes { channel_name };
|
||||
}
|
||||
}
|
||||
|
||||
30..=40 => {
|
||||
if let Some(buffer) = channel_buffers.iter().choose(rng) {
|
||||
let channel_name = buffer.read_with(cx, |b, _| b.channel().name.clone());
|
||||
break ChannelBufferOperation::LeaveChannelNotes { channel_name };
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
if let Some(buffer) = channel_buffers.iter().choose(rng) {
|
||||
break buffer.read_with(cx, |b, _| {
|
||||
let channel_name = b.channel().name.clone();
|
||||
let edits = b
|
||||
.buffer()
|
||||
.read_with(cx, |buffer, _| buffer.get_random_edits(rng, 3));
|
||||
ChannelBufferOperation::EditChannelNotes {
|
||||
channel_name,
|
||||
edits,
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_operation(
|
||||
client: &TestClient,
|
||||
operation: ChannelBufferOperation,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Result<(), TestError> {
|
||||
match operation {
|
||||
ChannelBufferOperation::JoinChannelNotes { channel_name } => {
|
||||
let buffer = client.channel_store().update(cx, |store, cx| {
|
||||
let channel_id = store
|
||||
.channels()
|
||||
.find(|(_, c)| c.name == channel_name)
|
||||
.unwrap()
|
||||
.1
|
||||
.id;
|
||||
if store.has_open_channel_buffer(channel_id, cx) {
|
||||
Err(TestError::Inapplicable)
|
||||
} else {
|
||||
Ok(store.open_channel_buffer(channel_id, cx))
|
||||
}
|
||||
})?;
|
||||
|
||||
log::info!(
|
||||
"{}: opening notes for channel {channel_name}",
|
||||
client.username
|
||||
);
|
||||
client.channel_buffers().insert(buffer.await?);
|
||||
}
|
||||
|
||||
ChannelBufferOperation::LeaveChannelNotes { channel_name } => {
|
||||
let buffer = cx.update(|cx| {
|
||||
let mut left_buffer = Err(TestError::Inapplicable);
|
||||
client.channel_buffers().retain(|buffer| {
|
||||
if buffer.read(cx).channel().name == channel_name {
|
||||
left_buffer = Ok(buffer.clone());
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
left_buffer
|
||||
})?;
|
||||
|
||||
log::info!(
|
||||
"{}: closing notes for channel {channel_name}",
|
||||
client.username
|
||||
);
|
||||
cx.update(|_| drop(buffer));
|
||||
}
|
||||
|
||||
ChannelBufferOperation::EditChannelNotes {
|
||||
channel_name,
|
||||
edits,
|
||||
} => {
|
||||
let channel_buffer = cx
|
||||
.read(|cx| {
|
||||
client
|
||||
.channel_buffers()
|
||||
.iter()
|
||||
.find(|buffer| buffer.read(cx).channel().name == channel_name)
|
||||
.cloned()
|
||||
})
|
||||
.ok_or_else(|| TestError::Inapplicable)?;
|
||||
|
||||
log::info!(
|
||||
"{}: editing notes for channel {channel_name} with {:?}",
|
||||
client.username,
|
||||
edits
|
||||
);
|
||||
|
||||
channel_buffer.update(cx, |buffer, cx| {
|
||||
let buffer = buffer.buffer();
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot();
|
||||
buffer.edit(
|
||||
edits.into_iter().map(|(range, text)| {
|
||||
let start = snapshot.clip_offset(range.start, Bias::Left);
|
||||
let end = snapshot.clip_offset(range.end, Bias::Right);
|
||||
(start..end, text)
|
||||
}),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
ChannelBufferOperation::Noop => Err(TestError::Inapplicable)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_client_added(client: &Rc<TestClient>, cx: &mut TestAppContext) {
|
||||
let channel_store = client.channel_store();
|
||||
while channel_store.read_with(cx, |store, _| store.channel_count() == 0) {
|
||||
channel_store.next_notification(cx).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_quiesce(server: &mut TestServer, clients: &mut [(Rc<TestClient>, TestAppContext)]) {
|
||||
let channels = server.app_state.db.all_channels().await.unwrap();
|
||||
|
||||
for (client, client_cx) in clients.iter_mut() {
|
||||
client_cx.update(|cx| {
|
||||
client
|
||||
.channel_buffers()
|
||||
.retain(|b| b.read(cx).is_connected());
|
||||
});
|
||||
}
|
||||
|
||||
for (channel_id, channel_name) in channels {
|
||||
let mut prev_text: Option<(u64, String)> = None;
|
||||
|
||||
let mut collaborator_user_ids = server
|
||||
.app_state
|
||||
.db
|
||||
.get_channel_buffer_collaborators(channel_id)
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|id| id.to_proto())
|
||||
.collect::<Vec<_>>();
|
||||
collaborator_user_ids.sort();
|
||||
|
||||
for (client, client_cx) in clients.iter() {
|
||||
let user_id = client.user_id().unwrap();
|
||||
client_cx.read(|cx| {
|
||||
if let Some(channel_buffer) = client
|
||||
.channel_buffers()
|
||||
.iter()
|
||||
.find(|b| b.read(cx).channel().id == channel_id.to_proto())
|
||||
{
|
||||
let channel_buffer = channel_buffer.read(cx);
|
||||
|
||||
// Assert that channel buffer's text matches other clients' copies.
|
||||
let text = channel_buffer.buffer().read(cx).text();
|
||||
if let Some((prev_user_id, prev_text)) = &prev_text {
|
||||
assert_eq!(
|
||||
&text,
|
||||
prev_text,
|
||||
"client {user_id} has different text than client {prev_user_id} for channel {channel_name}",
|
||||
);
|
||||
} else {
|
||||
prev_text = Some((user_id, text.clone()));
|
||||
}
|
||||
|
||||
// Assert that all clients and the server agree about who is present in the
|
||||
// channel buffer.
|
||||
let collaborators = channel_buffer.collaborators();
|
||||
let mut user_ids =
|
||||
collaborators.iter().map(|c| c.user_id).collect::<Vec<_>>();
|
||||
user_ids.sort();
|
||||
assert_eq!(
|
||||
user_ids,
|
||||
collaborator_user_ids,
|
||||
"client {user_id} has different user ids for channel {channel_name} than the server",
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1585
crates/collab/src/tests/random_project_collaboration_tests.rs
Normal file
1585
crates/collab/src/tests/random_project_collaboration_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
689
crates/collab/src/tests/randomized_test_helpers.rs
Normal file
689
crates/collab/src/tests/randomized_test_helpers.rs
Normal file
@@ -0,0 +1,689 @@
|
||||
use crate::{
|
||||
db::{self, NewUserParams, UserId},
|
||||
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
|
||||
tests::{TestClient, TestServer},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use rpc::RECEIVE_TIMEOUT;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
env,
|
||||
path::PathBuf,
|
||||
rc::Rc,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref PLAN_LOAD_PATH: Option<PathBuf> = path_env_var("LOAD_PLAN");
|
||||
static ref PLAN_SAVE_PATH: Option<PathBuf> = path_env_var("SAVE_PLAN");
|
||||
static ref MAX_PEERS: usize = env::var("MAX_PEERS")
|
||||
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
|
||||
.unwrap_or(3);
|
||||
static ref MAX_OPERATIONS: usize = env::var("OPERATIONS")
|
||||
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
|
||||
.unwrap_or(10);
|
||||
|
||||
}
|
||||
|
||||
static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
|
||||
static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
|
||||
|
||||
struct TestPlan<T: RandomizedTest> {
|
||||
rng: StdRng,
|
||||
replay: bool,
|
||||
stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
|
||||
max_operations: usize,
|
||||
operation_ix: usize,
|
||||
users: Vec<UserTestPlan>,
|
||||
next_batch_id: usize,
|
||||
allow_server_restarts: bool,
|
||||
allow_client_reconnection: bool,
|
||||
allow_client_disconnection: bool,
|
||||
}
|
||||
|
||||
pub struct UserTestPlan {
|
||||
pub user_id: UserId,
|
||||
pub username: String,
|
||||
pub allow_client_reconnection: bool,
|
||||
pub allow_client_disconnection: bool,
|
||||
next_root_id: usize,
|
||||
operation_ix: usize,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum StoredOperation<T> {
|
||||
Server(ServerOperation),
|
||||
Client {
|
||||
user_id: UserId,
|
||||
batch_id: usize,
|
||||
operation: T,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
enum ServerOperation {
|
||||
AddConnection {
|
||||
user_id: UserId,
|
||||
},
|
||||
RemoveConnection {
|
||||
user_id: UserId,
|
||||
},
|
||||
BounceConnection {
|
||||
user_id: UserId,
|
||||
},
|
||||
RestartServer,
|
||||
MutateClients {
|
||||
batch_id: usize,
|
||||
#[serde(skip_serializing)]
|
||||
#[serde(skip_deserializing)]
|
||||
user_ids: Vec<UserId>,
|
||||
quiesce: bool,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum TestError {
|
||||
Inapplicable,
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
pub trait RandomizedTest: 'static + Sized {
|
||||
type Operation: Send + Clone + Serialize + DeserializeOwned;
|
||||
|
||||
fn generate_operation(
|
||||
client: &TestClient,
|
||||
rng: &mut StdRng,
|
||||
plan: &mut UserTestPlan,
|
||||
cx: &TestAppContext,
|
||||
) -> Self::Operation;
|
||||
|
||||
async fn apply_operation(
|
||||
client: &TestClient,
|
||||
operation: Self::Operation,
|
||||
cx: &mut TestAppContext,
|
||||
) -> Result<(), TestError>;
|
||||
|
||||
async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);
|
||||
|
||||
async fn on_client_added(client: &Rc<TestClient>, cx: &mut TestAppContext);
|
||||
|
||||
async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
|
||||
}
|
||||
|
||||
pub async fn run_randomized_test<T: RandomizedTest>(
|
||||
cx: &mut TestAppContext,
|
||||
deterministic: Arc<Deterministic>,
|
||||
rng: StdRng,
|
||||
) {
|
||||
deterministic.forbid_parking();
|
||||
let mut server = TestServer::start(&deterministic).await;
|
||||
let plan = TestPlan::<T>::new(&mut server, rng).await;
|
||||
|
||||
LAST_PLAN.lock().replace({
|
||||
let plan = plan.clone();
|
||||
Box::new(move || plan.lock().serialize())
|
||||
});
|
||||
|
||||
let mut clients = Vec::new();
|
||||
let mut client_tasks = Vec::new();
|
||||
let mut operation_channels = Vec::new();
|
||||
loop {
|
||||
let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
|
||||
break;
|
||||
};
|
||||
applied.store(true, SeqCst);
|
||||
let did_apply = TestPlan::apply_server_operation(
|
||||
plan.clone(),
|
||||
deterministic.clone(),
|
||||
&mut server,
|
||||
&mut clients,
|
||||
&mut client_tasks,
|
||||
&mut operation_channels,
|
||||
next_operation,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
if !did_apply {
|
||||
applied.store(false, SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
drop(operation_channels);
|
||||
deterministic.start_waiting();
|
||||
futures::future::join_all(client_tasks).await;
|
||||
deterministic.finish_waiting();
|
||||
|
||||
deterministic.run_until_parked();
|
||||
T::on_quiesce(&mut server, &mut clients).await;
|
||||
|
||||
for (client, mut cx) in clients {
|
||||
cx.update(|cx| {
|
||||
let store = cx.remove_global::<SettingsStore>();
|
||||
cx.clear_globals();
|
||||
cx.set_global(store);
|
||||
drop(client);
|
||||
});
|
||||
}
|
||||
deterministic.run_until_parked();
|
||||
|
||||
if let Some(path) = &*PLAN_SAVE_PATH {
|
||||
eprintln!("saved test plan to path {:?}", path);
|
||||
std::fs::write(path, plan.lock().serialize()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_randomized_test_plan() {
|
||||
if let Some(serialize_plan) = LAST_PLAN.lock().take() {
|
||||
if let Some(path) = &*PLAN_SAVE_PATH {
|
||||
eprintln!("saved test plan to path {:?}", path);
|
||||
std::fs::write(path, serialize_plan()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RandomizedTest> TestPlan<T> {
|
||||
pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
|
||||
let allow_server_restarts = rng.gen_bool(0.7);
|
||||
let allow_client_reconnection = rng.gen_bool(0.7);
|
||||
let allow_client_disconnection = rng.gen_bool(0.1);
|
||||
|
||||
let mut users = Vec::new();
|
||||
for ix in 0..*MAX_PEERS {
|
||||
let username = format!("user-{}", ix + 1);
|
||||
let user_id = server
|
||||
.app_state
|
||||
.db
|
||||
.create_user(
|
||||
&format!("{username}@example.com"),
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: username.clone(),
|
||||
github_user_id: (ix + 1) as i32,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
users.push(UserTestPlan {
|
||||
user_id,
|
||||
username,
|
||||
online: false,
|
||||
next_root_id: 0,
|
||||
operation_ix: 0,
|
||||
allow_client_disconnection,
|
||||
allow_client_reconnection,
|
||||
});
|
||||
}
|
||||
|
||||
T::initialize(server, &users).await;
|
||||
|
||||
let plan = Arc::new(Mutex::new(Self {
|
||||
replay: false,
|
||||
allow_server_restarts,
|
||||
allow_client_reconnection,
|
||||
allow_client_disconnection,
|
||||
stored_operations: Vec::new(),
|
||||
operation_ix: 0,
|
||||
next_batch_id: 0,
|
||||
max_operations: *MAX_OPERATIONS,
|
||||
users,
|
||||
rng,
|
||||
}));
|
||||
|
||||
if let Some(path) = &*PLAN_LOAD_PATH {
|
||||
let json = LOADED_PLAN_JSON
|
||||
.lock()
|
||||
.get_or_insert_with(|| {
|
||||
eprintln!("loaded test plan from path {:?}", path);
|
||||
std::fs::read(path).unwrap()
|
||||
})
|
||||
.clone();
|
||||
plan.lock().deserialize(json);
|
||||
}
|
||||
|
||||
plan
|
||||
}
|
||||
|
||||
fn deserialize(&mut self, json: Vec<u8>) {
|
||||
let stored_operations: Vec<StoredOperation<T::Operation>> =
|
||||
serde_json::from_slice(&json).unwrap();
|
||||
self.replay = true;
|
||||
self.stored_operations = stored_operations
|
||||
.iter()
|
||||
.cloned()
|
||||
.enumerate()
|
||||
.map(|(i, mut operation)| {
|
||||
let did_apply = Arc::new(AtomicBool::new(false));
|
||||
if let StoredOperation::Server(ServerOperation::MutateClients {
|
||||
batch_id: current_batch_id,
|
||||
user_ids,
|
||||
..
|
||||
}) = &mut operation
|
||||
{
|
||||
assert!(user_ids.is_empty());
|
||||
user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
|
||||
if let StoredOperation::Client {
|
||||
user_id, batch_id, ..
|
||||
} = operation
|
||||
{
|
||||
if batch_id == current_batch_id {
|
||||
return Some(user_id);
|
||||
}
|
||||
}
|
||||
None
|
||||
}));
|
||||
user_ids.sort_unstable();
|
||||
}
|
||||
(operation, did_apply)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn serialize(&mut self) -> Vec<u8> {
|
||||
// Format each operation as one line
|
||||
let mut json = Vec::new();
|
||||
json.push(b'[');
|
||||
for (operation, applied) in &self.stored_operations {
|
||||
if !applied.load(SeqCst) {
|
||||
continue;
|
||||
}
|
||||
if json.len() > 1 {
|
||||
json.push(b',');
|
||||
}
|
||||
json.extend_from_slice(b"\n ");
|
||||
serde_json::to_writer(&mut json, operation).unwrap();
|
||||
}
|
||||
json.extend_from_slice(b"\n]\n");
|
||||
json
|
||||
}
|
||||
|
||||
fn next_server_operation(
|
||||
&mut self,
|
||||
clients: &[(Rc<TestClient>, TestAppContext)],
|
||||
) -> Option<(ServerOperation, Arc<AtomicBool>)> {
|
||||
if self.replay {
|
||||
while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
|
||||
self.operation_ix += 1;
|
||||
if let (StoredOperation::Server(operation), applied) = stored_operation {
|
||||
return Some((operation.clone(), applied.clone()));
|
||||
}
|
||||
}
|
||||
None
|
||||
} else {
|
||||
let operation = self.generate_server_operation(clients)?;
|
||||
let applied = Arc::new(AtomicBool::new(false));
|
||||
self.stored_operations
|
||||
.push((StoredOperation::Server(operation.clone()), applied.clone()));
|
||||
Some((operation, applied))
|
||||
}
|
||||
}
|
||||
|
||||
fn next_client_operation(
|
||||
&mut self,
|
||||
client: &TestClient,
|
||||
current_batch_id: usize,
|
||||
cx: &TestAppContext,
|
||||
) -> Option<(T::Operation, Arc<AtomicBool>)> {
|
||||
let current_user_id = client.current_user_id(cx);
|
||||
let user_ix = self
|
||||
.users
|
||||
.iter()
|
||||
.position(|user| user.user_id == current_user_id)
|
||||
.unwrap();
|
||||
let user_plan = &mut self.users[user_ix];
|
||||
|
||||
if self.replay {
|
||||
while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
|
||||
user_plan.operation_ix += 1;
|
||||
if let (
|
||||
StoredOperation::Client {
|
||||
user_id, operation, ..
|
||||
},
|
||||
applied,
|
||||
) = stored_operation
|
||||
{
|
||||
if user_id == ¤t_user_id {
|
||||
return Some((operation.clone(), applied.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
} else {
|
||||
if self.operation_ix == self.max_operations {
|
||||
return None;
|
||||
}
|
||||
self.operation_ix += 1;
|
||||
let operation = T::generate_operation(
|
||||
client,
|
||||
&mut self.rng,
|
||||
self.users
|
||||
.iter_mut()
|
||||
.find(|user| user.user_id == current_user_id)
|
||||
.unwrap(),
|
||||
cx,
|
||||
);
|
||||
let applied = Arc::new(AtomicBool::new(false));
|
||||
self.stored_operations.push((
|
||||
StoredOperation::Client {
|
||||
user_id: current_user_id,
|
||||
batch_id: current_batch_id,
|
||||
operation: operation.clone(),
|
||||
},
|
||||
applied.clone(),
|
||||
));
|
||||
Some((operation, applied))
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_server_operation(
|
||||
&mut self,
|
||||
clients: &[(Rc<TestClient>, TestAppContext)],
|
||||
) -> Option<ServerOperation> {
|
||||
if self.operation_ix == self.max_operations {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(loop {
|
||||
break match self.rng.gen_range(0..100) {
|
||||
0..=29 if clients.len() < self.users.len() => {
|
||||
let user = self
|
||||
.users
|
||||
.iter()
|
||||
.filter(|u| !u.online)
|
||||
.choose(&mut self.rng)
|
||||
.unwrap();
|
||||
self.operation_ix += 1;
|
||||
ServerOperation::AddConnection {
|
||||
user_id: user.user_id,
|
||||
}
|
||||
}
|
||||
30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
|
||||
let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
|
||||
let user_id = client.current_user_id(cx);
|
||||
self.operation_ix += 1;
|
||||
ServerOperation::RemoveConnection { user_id }
|
||||
}
|
||||
35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
|
||||
let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
|
||||
let user_id = client.current_user_id(cx);
|
||||
self.operation_ix += 1;
|
||||
ServerOperation::BounceConnection { user_id }
|
||||
}
|
||||
40..=44 if self.allow_server_restarts && clients.len() > 1 => {
|
||||
self.operation_ix += 1;
|
||||
ServerOperation::RestartServer
|
||||
}
|
||||
_ if !clients.is_empty() => {
|
||||
let count = self
|
||||
.rng
|
||||
.gen_range(1..10)
|
||||
.min(self.max_operations - self.operation_ix);
|
||||
let batch_id = util::post_inc(&mut self.next_batch_id);
|
||||
let mut user_ids = (0..count)
|
||||
.map(|_| {
|
||||
let ix = self.rng.gen_range(0..clients.len());
|
||||
let (client, cx) = &clients[ix];
|
||||
client.current_user_id(cx)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
user_ids.sort_unstable();
|
||||
ServerOperation::MutateClients {
|
||||
user_ids,
|
||||
batch_id,
|
||||
quiesce: self.rng.gen_bool(0.7),
|
||||
}
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
async fn apply_server_operation(
|
||||
plan: Arc<Mutex<Self>>,
|
||||
deterministic: Arc<Deterministic>,
|
||||
server: &mut TestServer,
|
||||
clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
|
||||
client_tasks: &mut Vec<Task<()>>,
|
||||
operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
|
||||
operation: ServerOperation,
|
||||
cx: &mut TestAppContext,
|
||||
) -> bool {
|
||||
match operation {
|
||||
ServerOperation::AddConnection { user_id } => {
|
||||
let username;
|
||||
{
|
||||
let mut plan = plan.lock();
|
||||
let user = plan.user(user_id);
|
||||
if user.online {
|
||||
return false;
|
||||
}
|
||||
user.online = true;
|
||||
username = user.username.clone();
|
||||
};
|
||||
log::info!("adding new connection for {}", username);
|
||||
let next_entity_id = (user_id.0 * 10_000) as usize;
|
||||
let mut client_cx = TestAppContext::new(
|
||||
cx.foreground_platform(),
|
||||
cx.platform(),
|
||||
deterministic.build_foreground(user_id.0 as usize),
|
||||
deterministic.build_background(),
|
||||
cx.font_cache(),
|
||||
cx.leak_detector(),
|
||||
next_entity_id,
|
||||
cx.function_name.clone(),
|
||||
);
|
||||
|
||||
let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
|
||||
let client = Rc::new(server.create_client(&mut client_cx, &username).await);
|
||||
operation_channels.push(operation_tx);
|
||||
clients.push((client.clone(), client_cx.clone()));
|
||||
client_tasks.push(client_cx.foreground().spawn(Self::simulate_client(
|
||||
plan.clone(),
|
||||
client,
|
||||
operation_rx,
|
||||
client_cx,
|
||||
)));
|
||||
|
||||
log::info!("added connection for {}", username);
|
||||
}
|
||||
|
||||
ServerOperation::RemoveConnection {
|
||||
user_id: removed_user_id,
|
||||
} => {
|
||||
log::info!("simulating full disconnection of user {}", removed_user_id);
|
||||
let client_ix = clients
|
||||
.iter()
|
||||
.position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
|
||||
let Some(client_ix) = client_ix else {
|
||||
return false;
|
||||
};
|
||||
let user_connection_ids = server
|
||||
.connection_pool
|
||||
.lock()
|
||||
.user_connection_ids(removed_user_id)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(user_connection_ids.len(), 1);
|
||||
let removed_peer_id = user_connection_ids[0].into();
|
||||
let (client, mut client_cx) = clients.remove(client_ix);
|
||||
let client_task = client_tasks.remove(client_ix);
|
||||
operation_channels.remove(client_ix);
|
||||
server.forbid_connections();
|
||||
server.disconnect_client(removed_peer_id);
|
||||
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
|
||||
deterministic.start_waiting();
|
||||
log::info!("waiting for user {} to exit...", removed_user_id);
|
||||
client_task.await;
|
||||
deterministic.finish_waiting();
|
||||
server.allow_connections();
|
||||
|
||||
for project in client.remote_projects().iter() {
|
||||
project.read_with(&client_cx, |project, _| {
|
||||
assert!(
|
||||
project.is_read_only(),
|
||||
"project {:?} should be read only",
|
||||
project.remote_id()
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
for (client, cx) in clients {
|
||||
let contacts = server
|
||||
.app_state
|
||||
.db
|
||||
.get_contacts(client.current_user_id(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let pool = server.connection_pool.lock();
|
||||
for contact in contacts {
|
||||
if let db::Contact::Accepted { user_id, busy, .. } = contact {
|
||||
if user_id == removed_user_id {
|
||||
assert!(!pool.is_user_online(user_id));
|
||||
assert!(!busy);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("{} removed", client.username);
|
||||
plan.lock().user(removed_user_id).online = false;
|
||||
client_cx.update(|cx| {
|
||||
cx.clear_globals();
|
||||
drop(client);
|
||||
});
|
||||
}
|
||||
|
||||
ServerOperation::BounceConnection { user_id } => {
|
||||
log::info!("simulating temporary disconnection of user {}", user_id);
|
||||
let user_connection_ids = server
|
||||
.connection_pool
|
||||
.lock()
|
||||
.user_connection_ids(user_id)
|
||||
.collect::<Vec<_>>();
|
||||
if user_connection_ids.is_empty() {
|
||||
return false;
|
||||
}
|
||||
assert_eq!(user_connection_ids.len(), 1);
|
||||
let peer_id = user_connection_ids[0].into();
|
||||
server.disconnect_client(peer_id);
|
||||
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
|
||||
}
|
||||
|
||||
ServerOperation::RestartServer => {
|
||||
log::info!("simulating server restart");
|
||||
server.reset().await;
|
||||
deterministic.advance_clock(RECEIVE_TIMEOUT);
|
||||
server.start().await.unwrap();
|
||||
deterministic.advance_clock(CLEANUP_TIMEOUT);
|
||||
let environment = &server.app_state.config.zed_environment;
|
||||
let (stale_room_ids, _) = server
|
||||
.app_state
|
||||
.db
|
||||
.stale_server_resource_ids(environment, server.id())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(stale_room_ids, vec![]);
|
||||
}
|
||||
|
||||
ServerOperation::MutateClients {
|
||||
user_ids,
|
||||
batch_id,
|
||||
quiesce,
|
||||
} => {
|
||||
let mut applied = false;
|
||||
for user_id in user_ids {
|
||||
let client_ix = clients
|
||||
.iter()
|
||||
.position(|(client, cx)| client.current_user_id(cx) == user_id);
|
||||
let Some(client_ix) = client_ix else { continue };
|
||||
applied = true;
|
||||
if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
|
||||
log::error!("error signaling user {user_id}: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
if quiesce && applied {
|
||||
deterministic.run_until_parked();
|
||||
T::on_quiesce(server, clients).await;
|
||||
}
|
||||
|
||||
return applied;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
async fn simulate_client(
|
||||
plan: Arc<Mutex<Self>>,
|
||||
client: Rc<TestClient>,
|
||||
mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
|
||||
mut cx: TestAppContext,
|
||||
) {
|
||||
T::on_client_added(&client, &mut cx).await;
|
||||
|
||||
while let Some(batch_id) = operation_rx.next().await {
|
||||
let Some((operation, applied)) =
|
||||
plan.lock().next_client_operation(&client, batch_id, &cx)
|
||||
else {
|
||||
break;
|
||||
};
|
||||
applied.store(true, SeqCst);
|
||||
match T::apply_operation(&client, operation, &mut cx).await {
|
||||
Ok(()) => {}
|
||||
Err(TestError::Inapplicable) => {
|
||||
applied.store(false, SeqCst);
|
||||
log::info!("skipped operation");
|
||||
}
|
||||
Err(TestError::Other(error)) => {
|
||||
log::error!("{} error: {}", client.username, error);
|
||||
}
|
||||
}
|
||||
cx.background().simulate_random_delay().await;
|
||||
}
|
||||
log::info!("{}: done", client.username);
|
||||
}
|
||||
|
||||
fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
|
||||
self.users
|
||||
.iter_mut()
|
||||
.find(|user| user.user_id == user_id)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl UserTestPlan {
|
||||
pub fn next_root_dir_name(&mut self) -> String {
|
||||
let user_id = self.user_id;
|
||||
let root_id = util::post_inc(&mut self.next_root_id);
|
||||
format!("dir-{user_id}-{root_id}")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for TestError {
|
||||
fn from(value: anyhow::Error) -> Self {
|
||||
Self::Other(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn path_env_var(name: &str) -> Option<PathBuf> {
|
||||
let value = env::var(name).ok()?;
|
||||
let mut path = PathBuf::from(value);
|
||||
if path.is_relative() {
|
||||
let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
abs_path.pop();
|
||||
abs_path.pop();
|
||||
abs_path.push(path);
|
||||
path = abs_path
|
||||
}
|
||||
Some(path)
|
||||
}
|
||||
558
crates/collab/src/tests/test_server.rs
Normal file
558
crates/collab/src/tests/test_server.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
use crate::{
|
||||
db::{tests::TestDb, NewUserParams, UserId},
|
||||
executor::Executor,
|
||||
rpc::{Server, CLEANUP_TIMEOUT},
|
||||
AppState,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use call::ActiveCall;
|
||||
use channel::{channel_buffer::ChannelBuffer, ChannelStore};
|
||||
use client::{
|
||||
self, proto::PeerId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
|
||||
};
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::FakeFs;
|
||||
use futures::{channel::oneshot, StreamExt as _};
|
||||
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
|
||||
use language::LanguageRegistry;
|
||||
use parking_lot::Mutex;
|
||||
use project::{Project, WorktreeId};
|
||||
use settings::SettingsStore;
|
||||
use std::{
|
||||
cell::{Ref, RefCell, RefMut},
|
||||
env,
|
||||
ops::{Deref, DerefMut},
|
||||
path::Path,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use util::http::FakeHttpClient;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub struct TestServer {
|
||||
pub app_state: Arc<AppState>,
|
||||
pub test_live_kit_server: Arc<live_kit_client::TestServer>,
|
||||
server: Arc<Server>,
|
||||
connection_killers: Arc<Mutex<HashMap<PeerId, Arc<AtomicBool>>>>,
|
||||
forbid_connections: Arc<AtomicBool>,
|
||||
_test_db: TestDb,
|
||||
}
|
||||
|
||||
pub struct TestClient {
|
||||
pub username: String,
|
||||
pub app_state: Arc<workspace::AppState>,
|
||||
state: RefCell<TestClientState>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestClientState {
|
||||
local_projects: Vec<ModelHandle<Project>>,
|
||||
remote_projects: Vec<ModelHandle<Project>>,
|
||||
buffers: HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>,
|
||||
channel_buffers: HashSet<ModelHandle<ChannelBuffer>>,
|
||||
}
|
||||
|
||||
pub struct ContactsSummary {
|
||||
pub current: Vec<String>,
|
||||
pub outgoing_requests: Vec<String>,
|
||||
pub incoming_requests: Vec<String>,
|
||||
}
|
||||
|
||||
impl TestServer {
|
||||
pub async fn start(deterministic: &Arc<Deterministic>) -> Self {
|
||||
static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
let use_postgres = env::var("USE_POSTGRES").ok();
|
||||
let use_postgres = use_postgres.as_deref();
|
||||
let test_db = if use_postgres == Some("true") || use_postgres == Some("1") {
|
||||
TestDb::postgres(deterministic.build_background())
|
||||
} else {
|
||||
TestDb::sqlite(deterministic.build_background())
|
||||
};
|
||||
let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst);
|
||||
let live_kit_server = live_kit_client::TestServer::create(
|
||||
format!("http://livekit.{}.test", live_kit_server_id),
|
||||
format!("devkey-{}", live_kit_server_id),
|
||||
format!("secret-{}", live_kit_server_id),
|
||||
deterministic.build_background(),
|
||||
)
|
||||
.unwrap();
|
||||
let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
|
||||
let epoch = app_state
|
||||
.db
|
||||
.create_server(&app_state.config.zed_environment)
|
||||
.await
|
||||
.unwrap();
|
||||
let server = Server::new(
|
||||
epoch,
|
||||
app_state.clone(),
|
||||
Executor::Deterministic(deterministic.build_background()),
|
||||
);
|
||||
server.start().await.unwrap();
|
||||
// Advance clock to ensure the server's cleanup task is finished.
|
||||
deterministic.advance_clock(CLEANUP_TIMEOUT);
|
||||
Self {
|
||||
app_state,
|
||||
server,
|
||||
connection_killers: Default::default(),
|
||||
forbid_connections: Default::default(),
|
||||
_test_db: test_db,
|
||||
test_live_kit_server: live_kit_server,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn reset(&self) {
|
||||
self.app_state.db.reset();
|
||||
let epoch = self
|
||||
.app_state
|
||||
.db
|
||||
.create_server(&self.app_state.config.zed_environment)
|
||||
.await
|
||||
.unwrap();
|
||||
self.server.reset(epoch);
|
||||
}
|
||||
|
||||
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
||||
cx.update(|cx| {
|
||||
if cx.has_global::<SettingsStore>() {
|
||||
panic!("Same cx used to create two test clients")
|
||||
}
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
});
|
||||
|
||||
let http = FakeHttpClient::with_404_response();
|
||||
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
|
||||
{
|
||||
user.id
|
||||
} else {
|
||||
self.app_state
|
||||
.db
|
||||
.create_user(
|
||||
&format!("{name}@example.com"),
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: name.into(),
|
||||
github_user_id: 0,
|
||||
invite_count: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("creating user failed")
|
||||
.user_id
|
||||
};
|
||||
let client_name = name.to_string();
|
||||
let mut client = cx.read(|cx| Client::new(http.clone(), cx));
|
||||
let server = self.server.clone();
|
||||
let db = self.app_state.db.clone();
|
||||
let connection_killers = self.connection_killers.clone();
|
||||
let forbid_connections = self.forbid_connections.clone();
|
||||
|
||||
Arc::get_mut(&mut client)
|
||||
.unwrap()
|
||||
.set_id(user_id.0 as usize)
|
||||
.override_authenticate(move |cx| {
|
||||
cx.spawn(|_| async move {
|
||||
let access_token = "the-token".to_string();
|
||||
Ok(Credentials {
|
||||
user_id: user_id.0 as u64,
|
||||
access_token,
|
||||
})
|
||||
})
|
||||
})
|
||||
.override_establish_connection(move |credentials, cx| {
|
||||
assert_eq!(credentials.user_id, user_id.0 as u64);
|
||||
assert_eq!(credentials.access_token, "the-token");
|
||||
|
||||
let server = server.clone();
|
||||
let db = db.clone();
|
||||
let connection_killers = connection_killers.clone();
|
||||
let forbid_connections = forbid_connections.clone();
|
||||
let client_name = client_name.clone();
|
||||
cx.spawn(move |cx| async move {
|
||||
if forbid_connections.load(SeqCst) {
|
||||
Err(EstablishConnectionError::other(anyhow!(
|
||||
"server is forbidding connections"
|
||||
)))
|
||||
} else {
|
||||
let (client_conn, server_conn, killed) =
|
||||
Connection::in_memory(cx.background());
|
||||
let (connection_id_tx, connection_id_rx) = oneshot::channel();
|
||||
let user = db
|
||||
.get_user_by_id(user_id)
|
||||
.await
|
||||
.expect("retrieving user failed")
|
||||
.unwrap();
|
||||
cx.background()
|
||||
.spawn(server.handle_connection(
|
||||
server_conn,
|
||||
client_name,
|
||||
user,
|
||||
Some(connection_id_tx),
|
||||
Executor::Deterministic(cx.background()),
|
||||
))
|
||||
.detach();
|
||||
let connection_id = connection_id_rx.await.unwrap();
|
||||
connection_killers
|
||||
.lock()
|
||||
.insert(connection_id.into(), killed);
|
||||
Ok(client_conn)
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
|
||||
let channel_store =
|
||||
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
|
||||
let app_state = Arc::new(workspace::AppState {
|
||||
client: client.clone(),
|
||||
user_store: user_store.clone(),
|
||||
channel_store: channel_store.clone(),
|
||||
languages: Arc::new(LanguageRegistry::test()),
|
||||
fs: fs.clone(),
|
||||
build_window_options: |_, _, _| Default::default(),
|
||||
initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
|
||||
background_actions: || &[],
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
theme::init((), cx);
|
||||
Project::init(&client, cx);
|
||||
client::init(&client, cx);
|
||||
language::init(cx);
|
||||
editor::init_settings(cx);
|
||||
workspace::init(app_state.clone(), cx);
|
||||
audio::init((), cx);
|
||||
call::init(client.clone(), user_store.clone(), cx);
|
||||
channel::init(&client);
|
||||
});
|
||||
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let client = TestClient {
|
||||
app_state,
|
||||
username: name.to_string(),
|
||||
state: Default::default(),
|
||||
};
|
||||
client.wait_for_current_user(cx).await;
|
||||
client
|
||||
}
|
||||
|
||||
pub fn disconnect_client(&self, peer_id: PeerId) {
|
||||
self.connection_killers
|
||||
.lock()
|
||||
.remove(&peer_id)
|
||||
.unwrap()
|
||||
.store(true, SeqCst);
|
||||
}
|
||||
|
||||
pub fn forbid_connections(&self) {
|
||||
self.forbid_connections.store(true, SeqCst);
|
||||
}
|
||||
|
||||
pub fn allow_connections(&self) {
|
||||
self.forbid_connections.store(false, SeqCst);
|
||||
}
|
||||
|
||||
pub async fn make_contacts(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
|
||||
for ix in 1..clients.len() {
|
||||
let (left, right) = clients.split_at_mut(ix);
|
||||
let (client_a, cx_a) = left.last_mut().unwrap();
|
||||
for (client_b, cx_b) in right {
|
||||
client_a
|
||||
.app_state
|
||||
.user_store
|
||||
.update(*cx_a, |store, cx| {
|
||||
store.request_contact(client_b.user_id().unwrap(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx_a.foreground().run_until_parked();
|
||||
client_b
|
||||
.app_state
|
||||
.user_store
|
||||
.update(*cx_b, |store, cx| {
|
||||
store.respond_to_contact_request(client_a.user_id().unwrap(), true, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn make_channel(
|
||||
&self,
|
||||
channel: &str,
|
||||
admin: (&TestClient, &mut TestAppContext),
|
||||
members: &mut [(&TestClient, &mut TestAppContext)],
|
||||
) -> u64 {
|
||||
let (admin_client, admin_cx) = admin;
|
||||
let channel_id = admin_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(admin_cx, |channel_store, cx| {
|
||||
channel_store.create_channel(channel, None, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for (member_client, member_cx) in members {
|
||||
admin_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(admin_cx, |channel_store, cx| {
|
||||
channel_store.invite_member(
|
||||
channel_id,
|
||||
member_client.user_id().unwrap(),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
admin_cx.foreground().run_until_parked();
|
||||
|
||||
member_client
|
||||
.app_state
|
||||
.channel_store
|
||||
.update(*member_cx, |channels, _| {
|
||||
channels.respond_to_channel_invite(channel_id, true)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
channel_id
|
||||
}
|
||||
|
||||
pub async fn create_room(&self, clients: &mut [(&TestClient, &mut TestAppContext)]) {
|
||||
self.make_contacts(clients).await;
|
||||
|
||||
let (left, right) = clients.split_at_mut(1);
|
||||
let (_client_a, cx_a) = &mut left[0];
|
||||
let active_call_a = cx_a.read(ActiveCall::global);
|
||||
|
||||
for (client_b, cx_b) in right {
|
||||
let user_id_b = client_b.current_user_id(*cx_b).to_proto();
|
||||
active_call_a
|
||||
.update(*cx_a, |call, cx| call.invite(user_id_b, None, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx_b.foreground().run_until_parked();
|
||||
let active_call_b = cx_b.read(ActiveCall::global);
|
||||
active_call_b
|
||||
.update(*cx_b, |call, cx| call.accept_incoming(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn build_app_state(
|
||||
test_db: &TestDb,
|
||||
fake_server: &live_kit_client::TestServer,
|
||||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
live_kit_client: Some(Arc::new(fake_server.create_api_client())),
|
||||
config: Default::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for TestServer {
|
||||
type Target = Server;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.server
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
self.server.teardown();
|
||||
self.test_live_kit_server.teardown().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for TestClient {
|
||||
type Target = Arc<Client>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.app_state.client
|
||||
}
|
||||
}
|
||||
|
||||
impl TestClient {
|
||||
pub fn fs(&self) -> &FakeFs {
|
||||
self.app_state.fs.as_fake()
|
||||
}
|
||||
|
||||
pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
|
||||
&self.app_state.channel_store
|
||||
}
|
||||
|
||||
pub fn user_store(&self) -> &ModelHandle<UserStore> {
|
||||
&self.app_state.user_store
|
||||
}
|
||||
|
||||
pub fn language_registry(&self) -> &Arc<LanguageRegistry> {
|
||||
&self.app_state.languages
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Arc<Client> {
|
||||
&self.app_state.client
|
||||
}
|
||||
|
||||
pub fn current_user_id(&self, cx: &TestAppContext) -> UserId {
|
||||
UserId::from_proto(
|
||||
self.app_state
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _| user_store.current_user().unwrap().id),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn wait_for_current_user(&self, cx: &TestAppContext) {
|
||||
let mut authed_user = self
|
||||
.app_state
|
||||
.user_store
|
||||
.read_with(cx, |user_store, _| user_store.watch_current_user());
|
||||
while authed_user.next().await.unwrap().is_none() {}
|
||||
}
|
||||
|
||||
pub async fn clear_contacts(&self, cx: &mut TestAppContext) {
|
||||
self.app_state
|
||||
.user_store
|
||||
.update(cx, |store, _| store.clear_contacts())
|
||||
.await;
|
||||
}
|
||||
|
||||
pub fn local_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
Ref::map(self.state.borrow(), |state| &state.local_projects)
|
||||
}
|
||||
|
||||
pub fn remote_projects<'a>(&'a self) -> impl Deref<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
Ref::map(self.state.borrow(), |state| &state.remote_projects)
|
||||
}
|
||||
|
||||
pub fn local_projects_mut<'a>(
|
||||
&'a self,
|
||||
) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.local_projects)
|
||||
}
|
||||
|
||||
pub fn remote_projects_mut<'a>(
|
||||
&'a self,
|
||||
) -> impl DerefMut<Target = Vec<ModelHandle<Project>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.remote_projects)
|
||||
}
|
||||
|
||||
pub fn buffers_for_project<'a>(
|
||||
&'a self,
|
||||
project: &ModelHandle<Project>,
|
||||
) -> impl DerefMut<Target = HashSet<ModelHandle<language::Buffer>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| {
|
||||
state.buffers.entry(project.clone()).or_default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn buffers<'a>(
|
||||
&'a self,
|
||||
) -> impl DerefMut<Target = HashMap<ModelHandle<Project>, HashSet<ModelHandle<language::Buffer>>>> + 'a
|
||||
{
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.buffers)
|
||||
}
|
||||
|
||||
pub fn channel_buffers<'a>(
|
||||
&'a self,
|
||||
) -> impl DerefMut<Target = HashSet<ModelHandle<ChannelBuffer>>> + 'a {
|
||||
RefMut::map(self.state.borrow_mut(), |state| &mut state.channel_buffers)
|
||||
}
|
||||
|
||||
pub fn summarize_contacts(&self, cx: &TestAppContext) -> ContactsSummary {
|
||||
self.app_state
|
||||
.user_store
|
||||
.read_with(cx, |store, _| ContactsSummary {
|
||||
current: store
|
||||
.contacts()
|
||||
.iter()
|
||||
.map(|contact| contact.user.github_login.clone())
|
||||
.collect(),
|
||||
outgoing_requests: store
|
||||
.outgoing_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
incoming_requests: store
|
||||
.incoming_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn build_local_project(
|
||||
&self,
|
||||
root_path: impl AsRef<Path>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> (ModelHandle<Project>, WorktreeId) {
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
self.client().clone(),
|
||||
self.app_state.user_store.clone(),
|
||||
self.app_state.languages.clone(),
|
||||
self.app_state.fs.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let (worktree, _) = project
|
||||
.update(cx, |p, cx| {
|
||||
p.find_or_create_local_worktree(root_path, true, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
worktree
|
||||
.read_with(cx, |tree, _| tree.as_local().unwrap().scan_complete())
|
||||
.await;
|
||||
(project, worktree.read_with(cx, |tree, _| tree.id()))
|
||||
}
|
||||
|
||||
pub async fn build_remote_project(
|
||||
&self,
|
||||
host_project_id: u64,
|
||||
guest_cx: &mut TestAppContext,
|
||||
) -> ModelHandle<Project> {
|
||||
let active_call = guest_cx.read(ActiveCall::global);
|
||||
let room = active_call.read_with(guest_cx, |call, _| call.room().unwrap().clone());
|
||||
room.update(guest_cx, |room, cx| {
|
||||
room.join_project(
|
||||
host_project_id,
|
||||
self.app_state.languages.clone(),
|
||||
self.app_state.fs.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn build_workspace(
|
||||
&self,
|
||||
project: &ModelHandle<Project>,
|
||||
cx: &mut TestAppContext,
|
||||
) -> WindowHandle<Workspace> {
|
||||
cx.add_window(|cx| Workspace::new(0, project.clone(), self.app_state.clone(), cx))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestClient {
|
||||
fn drop(&mut self) {
|
||||
self.app_state.client.teardown();
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ use gpui::{
|
||||
ViewContext, ViewHandle,
|
||||
};
|
||||
use project::Project;
|
||||
use std::any::Any;
|
||||
use std::any::{Any, TypeId};
|
||||
use workspace::{
|
||||
item::{FollowableItem, Item, ItemHandle},
|
||||
register_followable_item,
|
||||
@@ -189,6 +189,21 @@ impl View for ChannelView {
|
||||
}
|
||||
|
||||
impl Item for ChannelView {
|
||||
fn act_as_type<'a>(
|
||||
&'a self,
|
||||
type_id: TypeId,
|
||||
self_handle: &'a ViewHandle<Self>,
|
||||
_: &'a AppContext,
|
||||
) -> Option<&'a AnyViewHandle> {
|
||||
if type_id == TypeId::of::<Self>() {
|
||||
Some(self_handle)
|
||||
} else if type_id == TypeId::of::<Editor>() {
|
||||
Some(&self.editor)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn tab_content<V: 'static>(
|
||||
&self,
|
||||
_: Option<usize>,
|
||||
|
||||
@@ -4,7 +4,7 @@ mod panel_settings;
|
||||
|
||||
use anyhow::Result;
|
||||
use call::ActiveCall;
|
||||
use channel::{Channel, ChannelEvent, ChannelId, ChannelStore};
|
||||
use channel::{Channel, ChannelEvent, ChannelId, ChannelPath, ChannelStore};
|
||||
use client::{proto::PeerId, Client, Contact, User, UserStore};
|
||||
use context_menu::{ContextMenu, ContextMenuItem};
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
@@ -35,7 +35,7 @@ use panel_settings::{CollaborationPanelDockPosition, CollaborationPanelSettings}
|
||||
use project::{Fs, Project};
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
use settings::SettingsStore;
|
||||
use std::{borrow::Cow, mem, sync::Arc};
|
||||
use std::{borrow::Cow, hash::Hash, mem, sync::Arc};
|
||||
use theme::{components::ComponentExt, IconButton};
|
||||
use util::{iife, ResultExt, TryFutureExt};
|
||||
use workspace::{
|
||||
@@ -54,37 +54,59 @@ use self::contact_finder::ContactFinder;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct RemoveChannel {
|
||||
channel_id: u64,
|
||||
channel_id: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct ToggleCollapse {
|
||||
channel_id: u64,
|
||||
location: ChannelLocation<'static>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct NewChannel {
|
||||
channel_id: u64,
|
||||
location: ChannelLocation<'static>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct InviteMembers {
|
||||
channel_id: u64,
|
||||
channel_id: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct ManageMembers {
|
||||
channel_id: u64,
|
||||
channel_id: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct RenameChannel {
|
||||
channel_id: u64,
|
||||
location: ChannelLocation<'static>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct OpenChannelBuffer {
|
||||
channel_id: u64,
|
||||
channel_id: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct LinkChannel {
|
||||
channel_id: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct MoveChannel {
|
||||
channel_id: ChannelId,
|
||||
parent_id: Option<ChannelId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct PutChannel {
|
||||
to: ChannelId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
struct UnlinkChannel {
|
||||
channel_id: ChannelId,
|
||||
parent_id: Option<ChannelId>,
|
||||
}
|
||||
|
||||
actions!(
|
||||
@@ -107,12 +129,40 @@ impl_actions!(
|
||||
ManageMembers,
|
||||
RenameChannel,
|
||||
ToggleCollapse,
|
||||
OpenChannelBuffer
|
||||
OpenChannelBuffer,
|
||||
LinkChannel,
|
||||
MoveChannel,
|
||||
PutChannel,
|
||||
UnlinkChannel
|
||||
]
|
||||
);
|
||||
|
||||
const COLLABORATION_PANEL_KEY: &'static str = "CollaborationPanel";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
||||
pub struct ChannelLocation<'a> {
|
||||
channel: ChannelId,
|
||||
path: Cow<'a, ChannelPath>,
|
||||
}
|
||||
|
||||
impl From<(ChannelId, ChannelPath)> for ChannelLocation<'static> {
|
||||
fn from(value: (ChannelId, ChannelPath)) -> Self {
|
||||
ChannelLocation {
|
||||
channel: value.0,
|
||||
path: Cow::Owned(value.1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<(ChannelId, &'a ChannelPath)> for ChannelLocation<'a> {
|
||||
fn from(value: (ChannelId, &'a ChannelPath)) -> Self {
|
||||
ChannelLocation {
|
||||
channel: value.0,
|
||||
path: Cow::Borrowed(value.1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(_client: Arc<Client>, cx: &mut AppContext) {
|
||||
settings::register::<panel_settings::CollaborationPanelSettings>(cx);
|
||||
contact_finder::init(cx);
|
||||
@@ -135,16 +185,65 @@ pub fn init(_client: Arc<Client>, cx: &mut AppContext) {
|
||||
cx.add_action(CollabPanel::collapse_selected_channel);
|
||||
cx.add_action(CollabPanel::expand_selected_channel);
|
||||
cx.add_action(CollabPanel::open_channel_buffer);
|
||||
|
||||
cx.add_action(
|
||||
|panel: &mut CollabPanel, action: &LinkChannel, _: &mut ViewContext<CollabPanel>| {
|
||||
panel.link_or_move = Some(ChannelCopy::Link(action.channel_id));
|
||||
},
|
||||
);
|
||||
|
||||
cx.add_action(
|
||||
|panel: &mut CollabPanel, action: &MoveChannel, _: &mut ViewContext<CollabPanel>| {
|
||||
panel.link_or_move = Some(ChannelCopy::Move {
|
||||
channel_id: action.channel_id,
|
||||
parent_id: action.parent_id,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
cx.add_action(
|
||||
|panel: &mut CollabPanel, action: &PutChannel, cx: &mut ViewContext<CollabPanel>| {
|
||||
if let Some(copy) = panel.link_or_move.take() {
|
||||
match copy {
|
||||
ChannelCopy::Move {
|
||||
channel_id,
|
||||
parent_id,
|
||||
} => panel.channel_store.update(cx, |channel_store, cx| {
|
||||
channel_store
|
||||
.move_channel(channel_id, parent_id, action.to, cx)
|
||||
.detach_and_log_err(cx)
|
||||
}),
|
||||
ChannelCopy::Link(channel) => {
|
||||
panel.channel_store.update(cx, |channel_store, cx| {
|
||||
channel_store
|
||||
.link_channel(channel, action.to, cx)
|
||||
.detach_and_log_err(cx)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
cx.add_action(
|
||||
|panel: &mut CollabPanel, action: &UnlinkChannel, cx: &mut ViewContext<CollabPanel>| {
|
||||
panel.channel_store.update(cx, |channel_store, cx| {
|
||||
channel_store
|
||||
.unlink_channel(action.channel_id, action.parent_id, cx)
|
||||
.detach_and_log_err(cx)
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ChannelEditingState {
|
||||
Create {
|
||||
parent_id: Option<u64>,
|
||||
location: Option<ChannelLocation<'static>>,
|
||||
pending_name: Option<String>,
|
||||
},
|
||||
Rename {
|
||||
channel_id: u64,
|
||||
location: ChannelLocation<'static>,
|
||||
pending_name: Option<String>,
|
||||
},
|
||||
}
|
||||
@@ -158,10 +257,36 @@ impl ChannelEditingState {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum ChannelCopy {
|
||||
Move {
|
||||
channel_id: u64,
|
||||
parent_id: Option<u64>,
|
||||
},
|
||||
Link(u64),
|
||||
}
|
||||
|
||||
impl ChannelCopy {
|
||||
fn channel_id(&self) -> u64 {
|
||||
match self {
|
||||
ChannelCopy::Move { channel_id, .. } => *channel_id,
|
||||
ChannelCopy::Link(channel_id) => *channel_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_move(&self) -> bool {
|
||||
match self {
|
||||
ChannelCopy::Move { .. } => true,
|
||||
ChannelCopy::Link(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CollabPanel {
|
||||
width: Option<f32>,
|
||||
fs: Arc<dyn Fs>,
|
||||
has_focus: bool,
|
||||
link_or_move: Option<ChannelCopy>,
|
||||
pending_serialization: Task<Option<()>>,
|
||||
context_menu: ViewHandle<ContextMenu>,
|
||||
filter_editor: ViewHandle<Editor>,
|
||||
@@ -177,7 +302,7 @@ pub struct CollabPanel {
|
||||
list_state: ListState<Self>,
|
||||
subscriptions: Vec<Subscription>,
|
||||
collapsed_sections: Vec<Section>,
|
||||
collapsed_channels: Vec<ChannelId>,
|
||||
collapsed_channels: Vec<ChannelLocation<'static>>,
|
||||
workspace: WeakViewHandle<Workspace>,
|
||||
context_menu_on_selected: bool,
|
||||
}
|
||||
@@ -185,7 +310,7 @@ pub struct CollabPanel {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SerializedCollabPanel {
|
||||
width: Option<f32>,
|
||||
collapsed_channels: Option<Vec<ChannelId>>,
|
||||
collapsed_channels: Option<Vec<ChannelLocation<'static>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -229,6 +354,7 @@ enum ListEntry {
|
||||
Channel {
|
||||
channel: Arc<Channel>,
|
||||
depth: usize,
|
||||
path: ChannelPath,
|
||||
},
|
||||
ChannelNotes {
|
||||
channel_id: ChannelId,
|
||||
@@ -353,10 +479,15 @@ impl CollabPanel {
|
||||
cx,
|
||||
)
|
||||
}
|
||||
ListEntry::Channel { channel, depth } => {
|
||||
ListEntry::Channel {
|
||||
channel,
|
||||
depth,
|
||||
path,
|
||||
} => {
|
||||
let channel_row = this.render_channel(
|
||||
&*channel,
|
||||
*depth,
|
||||
path.to_owned(),
|
||||
&theme.collab_panel,
|
||||
is_selected,
|
||||
cx,
|
||||
@@ -425,6 +556,7 @@ impl CollabPanel {
|
||||
let mut this = Self {
|
||||
width: None,
|
||||
has_focus: false,
|
||||
link_or_move: None,
|
||||
fs: workspace.app_state().fs.clone(),
|
||||
pending_serialization: Task::ready(None),
|
||||
context_menu: cx.add_view(|cx| ContextMenu::new(view_id, cx)),
|
||||
@@ -512,7 +644,13 @@ impl CollabPanel {
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
Some(serde_json::from_str::<SerializedCollabPanel>(&panel)?)
|
||||
match serde_json::from_str::<SerializedCollabPanel>(&panel) {
|
||||
Ok(panel) => Some(panel),
|
||||
Err(err) => {
|
||||
log::error!("Failed to deserialize collaboration panel: {}", err);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -702,28 +840,24 @@ impl CollabPanel {
|
||||
executor.clone(),
|
||||
));
|
||||
if let Some(state) = &self.channel_editing_state {
|
||||
if matches!(
|
||||
state,
|
||||
ChannelEditingState::Create {
|
||||
parent_id: None,
|
||||
..
|
||||
}
|
||||
) {
|
||||
if matches!(state, ChannelEditingState::Create { location: None, .. }) {
|
||||
self.entries.push(ListEntry::ChannelEditor { depth: 0 });
|
||||
}
|
||||
}
|
||||
let mut collapse_depth = None;
|
||||
for mat in matches {
|
||||
let (depth, channel) =
|
||||
channel_store.channel_at_index(mat.candidate_id).unwrap();
|
||||
let (channel, path) = channel_store.channel_at_index(mat.candidate_id).unwrap();
|
||||
let depth = path.len() - 1;
|
||||
|
||||
if collapse_depth.is_none() && self.is_channel_collapsed(channel.id) {
|
||||
let location: ChannelLocation<'_> = (channel.id, path).into();
|
||||
|
||||
if collapse_depth.is_none() && self.is_channel_collapsed(&location) {
|
||||
collapse_depth = Some(depth);
|
||||
} else if let Some(collapsed_depth) = collapse_depth {
|
||||
if depth > collapsed_depth {
|
||||
continue;
|
||||
}
|
||||
if self.is_channel_collapsed(channel.id) {
|
||||
if self.is_channel_collapsed(&location) {
|
||||
collapse_depth = Some(depth);
|
||||
} else {
|
||||
collapse_depth = None;
|
||||
@@ -731,18 +865,21 @@ impl CollabPanel {
|
||||
}
|
||||
|
||||
match &self.channel_editing_state {
|
||||
Some(ChannelEditingState::Create { parent_id, .. })
|
||||
if *parent_id == Some(channel.id) =>
|
||||
{
|
||||
Some(ChannelEditingState::Create {
|
||||
location: parent_id,
|
||||
..
|
||||
}) if *parent_id == Some(location) => {
|
||||
self.entries.push(ListEntry::Channel {
|
||||
channel: channel.clone(),
|
||||
depth,
|
||||
path: path.clone(),
|
||||
});
|
||||
self.entries
|
||||
.push(ListEntry::ChannelEditor { depth: depth + 1 });
|
||||
}
|
||||
Some(ChannelEditingState::Rename { channel_id, .. })
|
||||
if *channel_id == channel.id =>
|
||||
Some(ChannelEditingState::Rename { location, .. })
|
||||
if location.channel == channel.id
|
||||
&& location.path == Cow::Borrowed(path) =>
|
||||
{
|
||||
self.entries.push(ListEntry::ChannelEditor { depth });
|
||||
}
|
||||
@@ -750,6 +887,7 @@ impl CollabPanel {
|
||||
self.entries.push(ListEntry::Channel {
|
||||
channel: channel.clone(),
|
||||
depth,
|
||||
path: path.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1546,14 +1684,21 @@ impl CollabPanel {
|
||||
&self,
|
||||
channel: &Channel,
|
||||
depth: usize,
|
||||
path: ChannelPath,
|
||||
theme: &theme::CollabPanel,
|
||||
is_selected: bool,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> AnyElement<Self> {
|
||||
let channel_id = channel.id;
|
||||
let has_children = self.channel_store.read(cx).has_children(channel_id);
|
||||
let disclosed =
|
||||
has_children.then(|| !self.collapsed_channels.binary_search(&channel_id).is_ok());
|
||||
|
||||
let disclosed = {
|
||||
let location = ChannelLocation {
|
||||
channel: channel_id,
|
||||
path: Cow::Borrowed(&path),
|
||||
};
|
||||
has_children.then(|| !self.collapsed_channels.binary_search(&location).is_ok())
|
||||
};
|
||||
|
||||
let is_active = iife!({
|
||||
let call_channel = ActiveCall::global(cx)
|
||||
@@ -1567,7 +1712,7 @@ impl CollabPanel {
|
||||
|
||||
const FACEPILE_LIMIT: usize = 3;
|
||||
|
||||
MouseEventHandler::new::<Channel, _>(channel.id as usize, cx, |state, cx| {
|
||||
MouseEventHandler::new::<Channel, _>(id(&path) as usize, cx, |state, cx| {
|
||||
Flex::<Self>::row()
|
||||
.with_child(
|
||||
Svg::new("icons/hash.svg")
|
||||
@@ -1618,8 +1763,13 @@ impl CollabPanel {
|
||||
})
|
||||
.align_children_center()
|
||||
.styleable_component()
|
||||
.disclosable(disclosed, Box::new(ToggleCollapse { channel_id }))
|
||||
.with_id(channel_id as usize)
|
||||
.disclosable(
|
||||
disclosed,
|
||||
Box::new(ToggleCollapse {
|
||||
location: (channel_id, path.clone()).into(),
|
||||
}),
|
||||
)
|
||||
.with_id(id(&path) as usize)
|
||||
.with_style(theme.disclosure.clone())
|
||||
.element()
|
||||
.constrained()
|
||||
@@ -1635,7 +1785,11 @@ impl CollabPanel {
|
||||
this.join_channel(channel_id, cx);
|
||||
})
|
||||
.on_click(MouseButton::Right, move |e, this, cx| {
|
||||
this.deploy_channel_context_menu(Some(e.position), channel_id, cx);
|
||||
this.deploy_channel_context_menu(
|
||||
Some(e.position),
|
||||
&(channel_id, path.clone()).into(),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.with_cursor_style(CursorStyle::PointingHand)
|
||||
.into_any()
|
||||
@@ -1882,11 +2036,20 @@ impl CollabPanel {
|
||||
fn deploy_channel_context_menu(
|
||||
&mut self,
|
||||
position: Option<Vector2F>,
|
||||
channel_id: u64,
|
||||
location: &ChannelLocation<'static>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
self.context_menu_on_selected = position.is_none();
|
||||
|
||||
let operation_details = self.link_or_move.as_ref().and_then(|link_or_move| {
|
||||
let channel_name = self
|
||||
.channel_store
|
||||
.read(cx)
|
||||
.channel_for_id(link_or_move.channel_id())
|
||||
.map(|channel| channel.name.clone())?;
|
||||
Some((channel_name, link_or_move.is_move()))
|
||||
});
|
||||
|
||||
self.context_menu.update(cx, |context_menu, cx| {
|
||||
context_menu.set_position_mode(if self.context_menu_on_selected {
|
||||
OverlayPositionMode::Local
|
||||
@@ -1894,27 +2057,112 @@ impl CollabPanel {
|
||||
OverlayPositionMode::Window
|
||||
});
|
||||
|
||||
let expand_action_name = if self.is_channel_collapsed(channel_id) {
|
||||
let mut items = Vec::new();
|
||||
|
||||
if let Some((channel_name, is_move)) = operation_details {
|
||||
items.push(ContextMenuItem::action(
|
||||
format!(
|
||||
"{} '#{}' here",
|
||||
if is_move { "Move" } else { "Link" },
|
||||
channel_name
|
||||
),
|
||||
PutChannel {
|
||||
to: location.channel,
|
||||
},
|
||||
));
|
||||
items.push(ContextMenuItem::Separator)
|
||||
}
|
||||
|
||||
let expand_action_name = if self.is_channel_collapsed(&location) {
|
||||
"Expand Subchannels"
|
||||
} else {
|
||||
"Collapse Subchannels"
|
||||
};
|
||||
|
||||
let mut items = vec![
|
||||
ContextMenuItem::action(expand_action_name, ToggleCollapse { channel_id }),
|
||||
ContextMenuItem::action("Open Notes", OpenChannelBuffer { channel_id }),
|
||||
];
|
||||
items.extend([
|
||||
ContextMenuItem::action(
|
||||
expand_action_name,
|
||||
ToggleCollapse {
|
||||
location: location.clone(),
|
||||
},
|
||||
),
|
||||
ContextMenuItem::action(
|
||||
"Open Notes",
|
||||
OpenChannelBuffer {
|
||||
channel_id: location.channel,
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
if self.channel_store.read(cx).is_user_admin(location.channel) {
|
||||
let parent_id = location.path.parent_id();
|
||||
|
||||
if self.channel_store.read(cx).is_user_admin(channel_id) {
|
||||
items.extend([
|
||||
ContextMenuItem::Separator,
|
||||
ContextMenuItem::action("New Subchannel", NewChannel { channel_id }),
|
||||
ContextMenuItem::action("Rename", RenameChannel { channel_id }),
|
||||
ContextMenuItem::action(
|
||||
"New Subchannel",
|
||||
NewChannel {
|
||||
location: location.clone(),
|
||||
},
|
||||
),
|
||||
ContextMenuItem::action(
|
||||
"Rename",
|
||||
RenameChannel {
|
||||
location: location.clone(),
|
||||
},
|
||||
),
|
||||
ContextMenuItem::Separator,
|
||||
ContextMenuItem::action("Invite Members", InviteMembers { channel_id }),
|
||||
ContextMenuItem::action("Manage Members", ManageMembers { channel_id }),
|
||||
]);
|
||||
|
||||
items.push(ContextMenuItem::action(
|
||||
if parent_id.is_some() {
|
||||
"Unlink from parent"
|
||||
} else {
|
||||
"Unlink from root"
|
||||
},
|
||||
UnlinkChannel {
|
||||
channel_id: location.channel,
|
||||
parent_id,
|
||||
},
|
||||
));
|
||||
|
||||
items.extend([
|
||||
ContextMenuItem::action(
|
||||
"Link this channel",
|
||||
LinkChannel {
|
||||
channel_id: location.channel,
|
||||
},
|
||||
),
|
||||
ContextMenuItem::action(
|
||||
"Move this channel",
|
||||
MoveChannel {
|
||||
channel_id: location.channel,
|
||||
parent_id,
|
||||
},
|
||||
),
|
||||
]);
|
||||
|
||||
items.extend([
|
||||
ContextMenuItem::Separator,
|
||||
ContextMenuItem::action("Delete", RemoveChannel { channel_id }),
|
||||
ContextMenuItem::action(
|
||||
"Invite Members",
|
||||
InviteMembers {
|
||||
channel_id: location.channel,
|
||||
},
|
||||
),
|
||||
ContextMenuItem::action(
|
||||
"Manage Members",
|
||||
ManageMembers {
|
||||
channel_id: location.channel,
|
||||
},
|
||||
),
|
||||
ContextMenuItem::Separator,
|
||||
ContextMenuItem::action(
|
||||
"Delete",
|
||||
RemoveChannel {
|
||||
channel_id: location.channel,
|
||||
},
|
||||
),
|
||||
]);
|
||||
}
|
||||
|
||||
@@ -2040,7 +2288,7 @@ impl CollabPanel {
|
||||
if let Some(editing_state) = &mut self.channel_editing_state {
|
||||
match editing_state {
|
||||
ChannelEditingState::Create {
|
||||
parent_id,
|
||||
location,
|
||||
pending_name,
|
||||
..
|
||||
} => {
|
||||
@@ -2053,13 +2301,17 @@ impl CollabPanel {
|
||||
|
||||
self.channel_store
|
||||
.update(cx, |channel_store, cx| {
|
||||
channel_store.create_channel(&channel_name, *parent_id, cx)
|
||||
channel_store.create_channel(
|
||||
&channel_name,
|
||||
location.as_ref().map(|location| location.channel),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.detach();
|
||||
cx.notify();
|
||||
}
|
||||
ChannelEditingState::Rename {
|
||||
channel_id,
|
||||
location,
|
||||
pending_name,
|
||||
} => {
|
||||
if pending_name.is_some() {
|
||||
@@ -2070,7 +2322,7 @@ impl CollabPanel {
|
||||
|
||||
self.channel_store
|
||||
.update(cx, |channel_store, cx| {
|
||||
channel_store.rename(*channel_id, &channel_name, cx)
|
||||
channel_store.rename(location.channel, &channel_name, cx)
|
||||
})
|
||||
.detach();
|
||||
cx.notify();
|
||||
@@ -2097,38 +2349,58 @@ impl CollabPanel {
|
||||
_: &CollapseSelectedChannel,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
let Some(channel_id) = self.selected_channel().map(|channel| channel.id) else {
|
||||
let Some((channel_id, path)) = self
|
||||
.selected_channel()
|
||||
.map(|(channel, parent)| (channel.id, parent))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
if self.is_channel_collapsed(channel_id) {
|
||||
let path = path.to_owned();
|
||||
|
||||
if self.is_channel_collapsed(&(channel_id, path.clone()).into()) {
|
||||
return;
|
||||
}
|
||||
|
||||
self.toggle_channel_collapsed(&ToggleCollapse { channel_id }, cx)
|
||||
self.toggle_channel_collapsed(
|
||||
&ToggleCollapse {
|
||||
location: (channel_id, path).into(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn expand_selected_channel(&mut self, _: &ExpandSelectedChannel, cx: &mut ViewContext<Self>) {
|
||||
let Some(channel_id) = self.selected_channel().map(|channel| channel.id) else {
|
||||
let Some((channel_id, path)) = self
|
||||
.selected_channel()
|
||||
.map(|(channel, parent)| (channel.id, parent))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
if !self.is_channel_collapsed(channel_id) {
|
||||
let path = path.to_owned();
|
||||
|
||||
if !self.is_channel_collapsed(&(channel_id, path.clone()).into()) {
|
||||
return;
|
||||
}
|
||||
|
||||
self.toggle_channel_collapsed(&ToggleCollapse { channel_id }, cx)
|
||||
self.toggle_channel_collapsed(
|
||||
&ToggleCollapse {
|
||||
location: (channel_id, path).into(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn toggle_channel_collapsed(&mut self, action: &ToggleCollapse, cx: &mut ViewContext<Self>) {
|
||||
let channel_id = action.channel_id;
|
||||
let location = action.location.clone();
|
||||
|
||||
match self.collapsed_channels.binary_search(&channel_id) {
|
||||
match self.collapsed_channels.binary_search(&location) {
|
||||
Ok(ix) => {
|
||||
self.collapsed_channels.remove(ix);
|
||||
}
|
||||
Err(ix) => {
|
||||
self.collapsed_channels.insert(ix, channel_id);
|
||||
self.collapsed_channels.insert(ix, location);
|
||||
}
|
||||
};
|
||||
self.serialize(cx);
|
||||
@@ -2137,8 +2409,8 @@ impl CollabPanel {
|
||||
cx.focus_self();
|
||||
}
|
||||
|
||||
fn is_channel_collapsed(&self, channel: ChannelId) -> bool {
|
||||
self.collapsed_channels.binary_search(&channel).is_ok()
|
||||
fn is_channel_collapsed(&self, location: &ChannelLocation) -> bool {
|
||||
self.collapsed_channels.binary_search(location).is_ok()
|
||||
}
|
||||
|
||||
fn leave_call(cx: &mut ViewContext<Self>) {
|
||||
@@ -2163,7 +2435,7 @@ impl CollabPanel {
|
||||
|
||||
fn new_root_channel(&mut self, cx: &mut ViewContext<Self>) {
|
||||
self.channel_editing_state = Some(ChannelEditingState::Create {
|
||||
parent_id: None,
|
||||
location: None,
|
||||
pending_name: None,
|
||||
});
|
||||
self.update_entries(false, cx);
|
||||
@@ -2181,9 +2453,9 @@ impl CollabPanel {
|
||||
|
||||
fn new_subchannel(&mut self, action: &NewChannel, cx: &mut ViewContext<Self>) {
|
||||
self.collapsed_channels
|
||||
.retain(|&channel| channel != action.channel_id);
|
||||
.retain(|channel| *channel != action.location);
|
||||
self.channel_editing_state = Some(ChannelEditingState::Create {
|
||||
parent_id: Some(action.channel_id),
|
||||
location: Some(action.location.to_owned()),
|
||||
pending_name: None,
|
||||
});
|
||||
self.update_entries(false, cx);
|
||||
@@ -2201,16 +2473,16 @@ impl CollabPanel {
|
||||
}
|
||||
|
||||
fn remove(&mut self, _: &Remove, cx: &mut ViewContext<Self>) {
|
||||
if let Some(channel) = self.selected_channel() {
|
||||
if let Some((channel, _)) = self.selected_channel() {
|
||||
self.remove_channel(channel.id, cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn rename_selected_channel(&mut self, _: &menu::SecondaryConfirm, cx: &mut ViewContext<Self>) {
|
||||
if let Some(channel) = self.selected_channel() {
|
||||
if let Some((channel, parent)) = self.selected_channel() {
|
||||
self.rename_channel(
|
||||
&RenameChannel {
|
||||
channel_id: channel.id,
|
||||
location: (channel.id, parent.to_owned()).into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
@@ -2219,12 +2491,15 @@ impl CollabPanel {
|
||||
|
||||
fn rename_channel(&mut self, action: &RenameChannel, cx: &mut ViewContext<Self>) {
|
||||
let channel_store = self.channel_store.read(cx);
|
||||
if !channel_store.is_user_admin(action.channel_id) {
|
||||
if !channel_store.is_user_admin(action.location.channel) {
|
||||
return;
|
||||
}
|
||||
if let Some(channel) = channel_store.channel_for_id(action.channel_id).cloned() {
|
||||
if let Some(channel) = channel_store
|
||||
.channel_for_id(action.location.channel)
|
||||
.cloned()
|
||||
{
|
||||
self.channel_editing_state = Some(ChannelEditingState::Rename {
|
||||
channel_id: action.channel_id,
|
||||
location: action.location.to_owned(),
|
||||
pending_name: None,
|
||||
});
|
||||
self.channel_name_editor.update(cx, |editor, cx| {
|
||||
@@ -2240,7 +2515,8 @@ impl CollabPanel {
|
||||
fn open_channel_buffer(&mut self, action: &OpenChannelBuffer, cx: &mut ViewContext<Self>) {
|
||||
if let Some(workspace) = self.workspace.upgrade(cx) {
|
||||
let pane = workspace.read(cx).active_pane().clone();
|
||||
let channel_view = ChannelView::open(action.channel_id, pane.clone(), workspace, cx);
|
||||
let channel_id = action.channel_id;
|
||||
let channel_view = ChannelView::open(channel_id, pane.clone(), workspace, cx);
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
let channel_view = channel_view.await?;
|
||||
pane.update(&mut cx, |pane, cx| {
|
||||
@@ -2249,22 +2525,38 @@ impl CollabPanel {
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
let room_id = ActiveCall::global(cx)
|
||||
.read(cx)
|
||||
.room()
|
||||
.map(|room| room.read(cx).id());
|
||||
|
||||
ActiveCall::report_call_event_for_room(
|
||||
"open channel notes",
|
||||
room_id,
|
||||
Some(channel_id),
|
||||
&self.client,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn show_inline_context_menu(&mut self, _: &menu::ShowContextMenu, cx: &mut ViewContext<Self>) {
|
||||
let Some(channel) = self.selected_channel() else {
|
||||
let Some((channel, path)) = self.selected_channel() else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.deploy_channel_context_menu(None, channel.id, cx);
|
||||
self.deploy_channel_context_menu(None, &(channel.id, path.to_owned()).into(), cx);
|
||||
}
|
||||
|
||||
fn selected_channel(&self) -> Option<&Arc<Channel>> {
|
||||
fn selected_channel(&self) -> Option<(&Arc<Channel>, &ChannelPath)> {
|
||||
self.selection
|
||||
.and_then(|ix| self.entries.get(ix))
|
||||
.and_then(|entry| match entry {
|
||||
ListEntry::Channel { channel, .. } => Some(channel),
|
||||
ListEntry::Channel {
|
||||
channel,
|
||||
path: parent,
|
||||
..
|
||||
} => Some((channel, parent)),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
@@ -2644,13 +2936,17 @@ impl PartialEq for ListEntry {
|
||||
ListEntry::Channel {
|
||||
channel: channel_1,
|
||||
depth: depth_1,
|
||||
path: parent_1,
|
||||
} => {
|
||||
if let ListEntry::Channel {
|
||||
channel: channel_2,
|
||||
depth: depth_2,
|
||||
path: parent_2,
|
||||
} = other
|
||||
{
|
||||
return channel_1.id == channel_2.id && depth_1 == depth_2;
|
||||
return channel_1.id == channel_2.id
|
||||
&& depth_1 == depth_2
|
||||
&& parent_1 == parent_2;
|
||||
}
|
||||
}
|
||||
ListEntry::ChannelNotes { channel_id } => {
|
||||
@@ -2713,3 +3009,26 @@ fn render_icon_button(style: &IconButton, svg_path: &'static str) -> impl Elemen
|
||||
.contained()
|
||||
.with_style(style.container)
|
||||
}
|
||||
|
||||
/// Hash a channel path to a u64, for use as a mouse id
|
||||
/// Based on the Fowler–Noll–Vo hash:
|
||||
/// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
||||
fn id(path: &[ChannelId]) -> u64 {
|
||||
// I probably should have done this, but I didn't
|
||||
// let hasher = DefaultHasher::new();
|
||||
// let path = path.hash(&mut hasher);
|
||||
// let x = hasher.finish();
|
||||
|
||||
const OFFSET: u64 = 14695981039346656037;
|
||||
const PRIME: u64 = 1099511628211;
|
||||
|
||||
let mut hash = OFFSET;
|
||||
for id in path.iter() {
|
||||
for id in id.to_ne_bytes() {
|
||||
hash = hash ^ (id as u64);
|
||||
hash = (hash as u128 * PRIME as u128) as u64;
|
||||
}
|
||||
}
|
||||
|
||||
hash
|
||||
}
|
||||
|
||||
@@ -771,7 +771,7 @@ impl CollabTitlebarItem {
|
||||
})
|
||||
.with_tooltip::<ToggleUserMenu>(
|
||||
0,
|
||||
"Toggle user menu".to_owned(),
|
||||
"Toggle User Menu".to_owned(),
|
||||
Some(Box::new(ToggleUserMenu)),
|
||||
tooltip,
|
||||
cx,
|
||||
|
||||
@@ -49,7 +49,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) {
|
||||
if room.is_screen_sharing() {
|
||||
ActiveCall::report_call_event_for_room(
|
||||
"disable screen share",
|
||||
room.id(),
|
||||
Some(room.id()),
|
||||
room.channel_id(),
|
||||
&client,
|
||||
cx,
|
||||
@@ -58,7 +58,7 @@ pub fn toggle_screen_sharing(_: &ToggleScreenSharing, cx: &mut AppContext) {
|
||||
} else {
|
||||
ActiveCall::report_call_event_for_room(
|
||||
"enable screen share",
|
||||
room.id(),
|
||||
Some(room.id()),
|
||||
room.channel_id(),
|
||||
&client,
|
||||
cx,
|
||||
@@ -78,7 +78,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) {
|
||||
if room.is_muted(cx) {
|
||||
ActiveCall::report_call_event_for_room(
|
||||
"enable microphone",
|
||||
room.id(),
|
||||
Some(room.id()),
|
||||
room.channel_id(),
|
||||
&client,
|
||||
cx,
|
||||
@@ -86,7 +86,7 @@ pub fn toggle_mute(_: &ToggleMute, cx: &mut AppContext) {
|
||||
} else {
|
||||
ActiveCall::report_call_event_for_room(
|
||||
"disable microphone",
|
||||
room.id(),
|
||||
Some(room.id()),
|
||||
room.channel_id(),
|
||||
&client,
|
||||
cx,
|
||||
|
||||
@@ -41,7 +41,7 @@ actions!(
|
||||
[Suggest, NextSuggestion, PreviousSuggestion, Reinstall]
|
||||
);
|
||||
|
||||
pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<NodeRuntime>, cx: &mut AppContext) {
|
||||
pub fn init(http: Arc<dyn HttpClient>, node_runtime: Arc<dyn NodeRuntime>, cx: &mut AppContext) {
|
||||
let copilot = cx.add_model({
|
||||
let node_runtime = node_runtime.clone();
|
||||
move |cx| Copilot::start(http, node_runtime, cx)
|
||||
@@ -265,7 +265,7 @@ pub struct Completion {
|
||||
|
||||
pub struct Copilot {
|
||||
http: Arc<dyn HttpClient>,
|
||||
node_runtime: Arc<NodeRuntime>,
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
server: CopilotServer,
|
||||
buffers: HashSet<WeakModelHandle<Buffer>>,
|
||||
}
|
||||
@@ -299,7 +299,7 @@ impl Copilot {
|
||||
|
||||
fn start(
|
||||
http: Arc<dyn HttpClient>,
|
||||
node_runtime: Arc<NodeRuntime>,
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let mut this = Self {
|
||||
@@ -335,12 +335,15 @@ impl Copilot {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle<Self>, lsp::FakeLanguageServer) {
|
||||
use node_runtime::FakeNodeRuntime;
|
||||
|
||||
let (server, fake_server) =
|
||||
LanguageServer::fake("copilot".into(), Default::default(), cx.to_async());
|
||||
let http = util::http::FakeHttpClient::create(|_| async { unreachable!() });
|
||||
let node_runtime = FakeNodeRuntime::new();
|
||||
let this = cx.add_model(|_| Self {
|
||||
http: http.clone(),
|
||||
node_runtime: NodeRuntime::instance(http),
|
||||
node_runtime,
|
||||
server: CopilotServer::Running(RunningCopilotServer {
|
||||
lsp: Arc::new(server),
|
||||
sign_in_status: SignInStatus::Authorized,
|
||||
@@ -353,7 +356,7 @@ impl Copilot {
|
||||
|
||||
fn start_language_server(
|
||||
http: Arc<dyn HttpClient>,
|
||||
node_runtime: Arc<NodeRuntime>,
|
||||
node_runtime: Arc<dyn NodeRuntime>,
|
||||
this: ModelHandle<Self>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> impl Future<Output = ()> {
|
||||
|
||||
@@ -555,67 +555,6 @@ impl DisplaySnapshot {
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns an iterator of the start positions of the occurrences of `target` in the `self` after `from`
|
||||
/// Stops if `condition` returns false for any of the character position pairs observed.
|
||||
pub fn find_while<'a>(
|
||||
&'a self,
|
||||
from: DisplayPoint,
|
||||
target: &str,
|
||||
condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
|
||||
) -> impl Iterator<Item = DisplayPoint> + 'a {
|
||||
Self::find_internal(self.chars_at(from), target.chars().collect(), condition)
|
||||
}
|
||||
|
||||
/// Returns an iterator of the end positions of the occurrences of `target` in the `self` before `from`
|
||||
/// Stops if `condition` returns false for any of the character position pairs observed.
|
||||
pub fn reverse_find_while<'a>(
|
||||
&'a self,
|
||||
from: DisplayPoint,
|
||||
target: &str,
|
||||
condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
|
||||
) -> impl Iterator<Item = DisplayPoint> + 'a {
|
||||
Self::find_internal(
|
||||
self.reverse_chars_at(from),
|
||||
target.chars().rev().collect(),
|
||||
condition,
|
||||
)
|
||||
}
|
||||
|
||||
fn find_internal<'a>(
|
||||
iterator: impl Iterator<Item = (char, DisplayPoint)> + 'a,
|
||||
target: Vec<char>,
|
||||
mut condition: impl FnMut(char, DisplayPoint) -> bool + 'a,
|
||||
) -> impl Iterator<Item = DisplayPoint> + 'a {
|
||||
// List of partial matches with the index of the last seen character in target and the starting point of the match
|
||||
let mut partial_matches: Vec<(usize, DisplayPoint)> = Vec::new();
|
||||
iterator
|
||||
.take_while(move |(ch, point)| condition(*ch, *point))
|
||||
.filter_map(move |(ch, point)| {
|
||||
if Some(&ch) == target.get(0) {
|
||||
partial_matches.push((0, point));
|
||||
}
|
||||
|
||||
let mut found = None;
|
||||
// Keep partial matches that have the correct next character
|
||||
partial_matches.retain_mut(|(match_position, match_start)| {
|
||||
if target.get(*match_position) == Some(&ch) {
|
||||
*match_position += 1;
|
||||
if *match_position == target.len() {
|
||||
found = Some(match_start.clone());
|
||||
// This match is completed. No need to keep tracking it
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
found
|
||||
})
|
||||
}
|
||||
|
||||
pub fn column_to_chars(&self, display_row: u32, target: u32) -> u32 {
|
||||
let mut count = 0;
|
||||
let mut column = 0;
|
||||
@@ -933,7 +872,7 @@ pub mod tests {
|
||||
use smol::stream::StreamExt;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::SyntaxTheme;
|
||||
use util::test::{marked_text_offsets, marked_text_ranges, sample_text};
|
||||
use util::test::{marked_text_ranges, sample_text};
|
||||
use Bias::*;
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
@@ -1744,32 +1683,6 @@ pub mod tests {
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_internal() {
|
||||
assert("This is a ˇtest of find internal", "test");
|
||||
assert("Some text ˇaˇaˇaa with repeated characters", "aa");
|
||||
|
||||
fn assert(marked_text: &str, target: &str) {
|
||||
let (text, expected_offsets) = marked_text_offsets(marked_text);
|
||||
|
||||
let chars = text
|
||||
.chars()
|
||||
.enumerate()
|
||||
.map(|(index, ch)| (ch, DisplayPoint::new(0, index as u32)));
|
||||
let target = target.chars();
|
||||
|
||||
assert_eq!(
|
||||
expected_offsets
|
||||
.into_iter()
|
||||
.map(|offset| offset as u32)
|
||||
.collect::<Vec<_>>(),
|
||||
DisplaySnapshot::find_internal(chars, target.collect(), |_, _| true)
|
||||
.map(|point| point.column())
|
||||
.collect::<Vec<_>>()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn syntax_chunks<'a>(
|
||||
rows: Range<u32>,
|
||||
map: &ModelHandle<DisplayMap>,
|
||||
|
||||
@@ -44,7 +44,7 @@ use gpui::{
|
||||
elements::*,
|
||||
executor,
|
||||
fonts::{self, HighlightStyle, TextStyle},
|
||||
geometry::vector::Vector2F,
|
||||
geometry::vector::{vec2f, Vector2F},
|
||||
impl_actions,
|
||||
keymap_matcher::KeymapContext,
|
||||
platform::{CursorStyle, MouseButton},
|
||||
@@ -312,6 +312,10 @@ actions!(
|
||||
CopyPath,
|
||||
CopyRelativePath,
|
||||
CopyHighlightJson,
|
||||
ContextMenuFirst,
|
||||
ContextMenuPrev,
|
||||
ContextMenuNext,
|
||||
ContextMenuLast,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -468,6 +472,10 @@ pub fn init(cx: &mut AppContext) {
|
||||
cx.add_action(Editor::next_copilot_suggestion);
|
||||
cx.add_action(Editor::previous_copilot_suggestion);
|
||||
cx.add_action(Editor::copilot_suggest);
|
||||
cx.add_action(Editor::context_menu_first);
|
||||
cx.add_action(Editor::context_menu_prev);
|
||||
cx.add_action(Editor::context_menu_next);
|
||||
cx.add_action(Editor::context_menu_last);
|
||||
|
||||
hover_popover::init(cx);
|
||||
scroll::actions::init(cx);
|
||||
@@ -564,7 +572,7 @@ pub struct Editor {
|
||||
project: Option<ModelHandle<Project>>,
|
||||
focused: bool,
|
||||
blink_manager: ModelHandle<BlinkManager>,
|
||||
show_local_selections: bool,
|
||||
pub show_local_selections: bool,
|
||||
mode: EditorMode,
|
||||
replica_id_mapping: Option<HashMap<ReplicaId, ReplicaId>>,
|
||||
show_gutter: bool,
|
||||
@@ -820,6 +828,7 @@ struct CompletionsMenu {
|
||||
id: CompletionId,
|
||||
initial_position: Anchor,
|
||||
buffer: ModelHandle<Buffer>,
|
||||
project: Option<ModelHandle<Project>>,
|
||||
completions: Arc<[Completion]>,
|
||||
match_candidates: Vec<StringMatchCandidate>,
|
||||
matches: Arc<[StringMatch]>,
|
||||
@@ -863,6 +872,48 @@ impl CompletionsMenu {
|
||||
fn render(&self, style: EditorStyle, cx: &mut ViewContext<Editor>) -> AnyElement<Editor> {
|
||||
enum CompletionTag {}
|
||||
|
||||
let language_servers = self.project.as_ref().map(|project| {
|
||||
project
|
||||
.read(cx)
|
||||
.language_servers_for_buffer(self.buffer.read(cx), cx)
|
||||
.filter(|(_, server)| server.capabilities().completion_provider.is_some())
|
||||
.map(|(adapter, server)| (server.server_id(), adapter.short_name))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
let needs_server_name = language_servers
|
||||
.as_ref()
|
||||
.map_or(false, |servers| servers.len() > 1);
|
||||
|
||||
let get_server_name =
|
||||
move |lookup_server_id: lsp::LanguageServerId| -> Option<&'static str> {
|
||||
language_servers
|
||||
.iter()
|
||||
.flatten()
|
||||
.find_map(|(server_id, server_name)| {
|
||||
if *server_id == lookup_server_id {
|
||||
Some(*server_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let widest_completion_ix = self
|
||||
.matches
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by_key(|(_, mat)| {
|
||||
let completion = &self.completions[mat.candidate_id];
|
||||
let mut len = completion.label.text.chars().count();
|
||||
|
||||
if let Some(server_name) = get_server_name(completion.server_id) {
|
||||
len += server_name.chars().count();
|
||||
}
|
||||
|
||||
len
|
||||
})
|
||||
.map(|(ix, _)| ix);
|
||||
|
||||
let completions = self.completions.clone();
|
||||
let matches = self.matches.clone();
|
||||
let selected_item = self.selected_item;
|
||||
@@ -889,19 +940,83 @@ impl CompletionsMenu {
|
||||
style.autocomplete.item
|
||||
};
|
||||
|
||||
Text::new(completion.label.text.clone(), style.text.clone())
|
||||
.with_soft_wrap(false)
|
||||
.with_highlights(combine_syntax_and_fuzzy_match_highlights(
|
||||
&completion.label.text,
|
||||
style.text.color.into(),
|
||||
styled_runs_for_code_label(
|
||||
&completion.label,
|
||||
&style.syntax,
|
||||
),
|
||||
&mat.positions,
|
||||
))
|
||||
.contained()
|
||||
.with_style(item_style)
|
||||
let completion_label =
|
||||
Text::new(completion.label.text.clone(), style.text.clone())
|
||||
.with_soft_wrap(false)
|
||||
.with_highlights(
|
||||
combine_syntax_and_fuzzy_match_highlights(
|
||||
&completion.label.text,
|
||||
style.text.color.into(),
|
||||
styled_runs_for_code_label(
|
||||
&completion.label,
|
||||
&style.syntax,
|
||||
),
|
||||
&mat.positions,
|
||||
),
|
||||
);
|
||||
|
||||
if let Some(server_name) = get_server_name(completion.server_id) {
|
||||
Flex::row()
|
||||
.with_child(completion_label)
|
||||
.with_children((|| {
|
||||
if !needs_server_name {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: style.autocomplete.server_name_color,
|
||||
font_size: style.text.font_size
|
||||
* style.autocomplete.server_name_size_percent,
|
||||
..style.text.clone()
|
||||
};
|
||||
|
||||
let label = Text::new(server_name, text_style)
|
||||
.aligned()
|
||||
.constrained()
|
||||
.dynamically(move |constraint, _, _| {
|
||||
gpui::SizeConstraint {
|
||||
min: constraint.min,
|
||||
max: vec2f(
|
||||
constraint.max.x(),
|
||||
constraint.min.y(),
|
||||
),
|
||||
}
|
||||
});
|
||||
|
||||
if Some(item_ix) == widest_completion_ix {
|
||||
Some(
|
||||
label
|
||||
.contained()
|
||||
.with_style(
|
||||
style
|
||||
.autocomplete
|
||||
.server_name_container,
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
} else {
|
||||
Some(label.flex_float().into_any())
|
||||
}
|
||||
})())
|
||||
.into_any()
|
||||
} else {
|
||||
completion_label.into_any()
|
||||
}
|
||||
.contained()
|
||||
.with_style(item_style)
|
||||
.constrained()
|
||||
.dynamically(
|
||||
move |constraint, _, _| {
|
||||
if Some(item_ix) == widest_completion_ix {
|
||||
constraint
|
||||
} else {
|
||||
gpui::SizeConstraint {
|
||||
min: constraint.min,
|
||||
max: constraint.min,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
.with_cursor_style(CursorStyle::PointingHand)
|
||||
@@ -918,19 +1033,7 @@ impl CompletionsMenu {
|
||||
}
|
||||
},
|
||||
)
|
||||
.with_width_from_item(
|
||||
self.matches
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by_key(|(_, mat)| {
|
||||
self.completions[mat.candidate_id]
|
||||
.label
|
||||
.text
|
||||
.chars()
|
||||
.count()
|
||||
})
|
||||
.map(|(ix, _)| ix),
|
||||
)
|
||||
.with_width_from_item(widest_completion_ix)
|
||||
.contained()
|
||||
.with_style(container_style)
|
||||
.into_any()
|
||||
@@ -1559,7 +1662,7 @@ impl Editor {
|
||||
.excerpt_containing(self.selections.newest_anchor().head(), cx)
|
||||
}
|
||||
|
||||
fn style(&self, cx: &AppContext) -> EditorStyle {
|
||||
pub fn style(&self, cx: &AppContext) -> EditorStyle {
|
||||
build_style(
|
||||
settings::get::<ThemeSettings>(cx),
|
||||
self.get_field_editor_theme.as_deref(),
|
||||
@@ -2166,10 +2269,6 @@ impl Editor {
|
||||
if self.read_only {
|
||||
return;
|
||||
}
|
||||
if !self.input_enabled {
|
||||
cx.emit(Event::InputIgnored { text });
|
||||
return;
|
||||
}
|
||||
|
||||
let selections = self.selections.all_adjusted(cx);
|
||||
let mut brace_inserted = false;
|
||||
@@ -2983,6 +3082,7 @@ impl Editor {
|
||||
});
|
||||
|
||||
let id = post_inc(&mut self.next_completion_id);
|
||||
let project = self.project.clone();
|
||||
let task = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let menu = if let Some(completions) = completions.await.log_err() {
|
||||
@@ -3001,6 +3101,7 @@ impl Editor {
|
||||
})
|
||||
.collect(),
|
||||
buffer,
|
||||
project,
|
||||
completions: completions.into(),
|
||||
matches: Vec::new().into(),
|
||||
selected_item: 0,
|
||||
@@ -3102,17 +3203,30 @@ impl Editor {
|
||||
.count();
|
||||
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
let mut range_to_replace: Option<Range<isize>> = None;
|
||||
let mut ranges = Vec::new();
|
||||
for selection in &selections {
|
||||
if snapshot.contains_str_at(selection.start.saturating_sub(lookbehind), &old_text) {
|
||||
let start = selection.start.saturating_sub(lookbehind);
|
||||
let end = selection.end + lookahead;
|
||||
if selection.id == newest_selection.id {
|
||||
range_to_replace = Some(
|
||||
((start + common_prefix_len) as isize - selection.start as isize)
|
||||
..(end as isize - selection.start as isize),
|
||||
);
|
||||
}
|
||||
ranges.push(start + common_prefix_len..end);
|
||||
} else {
|
||||
common_prefix_len = 0;
|
||||
ranges.clear();
|
||||
ranges.extend(selections.iter().map(|s| {
|
||||
if s.id == newest_selection.id {
|
||||
range_to_replace = Some(
|
||||
old_range.start.to_offset_utf16(&snapshot).0 as isize
|
||||
- selection.start as isize
|
||||
..old_range.end.to_offset_utf16(&snapshot).0 as isize
|
||||
- selection.start as isize,
|
||||
);
|
||||
old_range.clone()
|
||||
} else {
|
||||
s.start..s.end
|
||||
@@ -3123,6 +3237,11 @@ impl Editor {
|
||||
}
|
||||
let text = &text[common_prefix_len..];
|
||||
|
||||
cx.emit(Event::InputHandled {
|
||||
utf16_range_to_replace: range_to_replace,
|
||||
text: text.into(),
|
||||
});
|
||||
|
||||
self.transact(cx, |this, cx| {
|
||||
if let Some(mut snippet) = snippet {
|
||||
snippet.text = text.to_string();
|
||||
@@ -3580,6 +3699,10 @@ impl Editor {
|
||||
|
||||
self.report_copilot_event(Some(completion.uuid.clone()), true, cx)
|
||||
}
|
||||
cx.emit(Event::InputHandled {
|
||||
utf16_range_to_replace: None,
|
||||
text: suggestion.text.to_string().into(),
|
||||
});
|
||||
self.insert_with_autoindent_mode(&suggestion.text.to_string(), None, cx);
|
||||
cx.notify();
|
||||
true
|
||||
@@ -5069,12 +5192,6 @@ impl Editor {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
if context_menu.select_prev(cx) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(self.mode, EditorMode::SingleLine) {
|
||||
cx.propagate_action();
|
||||
return;
|
||||
@@ -5097,15 +5214,6 @@ impl Editor {
|
||||
return;
|
||||
}
|
||||
|
||||
if self
|
||||
.context_menu
|
||||
.as_mut()
|
||||
.map(|menu| menu.select_first(cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if matches!(self.mode, EditorMode::SingleLine) {
|
||||
cx.propagate_action();
|
||||
return;
|
||||
@@ -5145,12 +5253,6 @@ impl Editor {
|
||||
pub fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
|
||||
self.take_rename(true, cx);
|
||||
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
if context_menu.select_next(cx) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if self.mode == EditorMode::SingleLine {
|
||||
cx.propagate_action();
|
||||
return;
|
||||
@@ -5218,6 +5320,30 @@ impl Editor {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn context_menu_first(&mut self, _: &ContextMenuFirst, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
context_menu.select_first(cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_prev(&mut self, _: &ContextMenuPrev, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
context_menu.select_prev(cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_next(&mut self, _: &ContextMenuNext, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
context_menu.select_next(cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_menu_last(&mut self, _: &ContextMenuLast, cx: &mut ViewContext<Self>) {
|
||||
if let Some(context_menu) = self.context_menu.as_mut() {
|
||||
context_menu.select_last(cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn move_to_previous_word_start(
|
||||
&mut self,
|
||||
_: &MoveToPreviousWordStart,
|
||||
@@ -8328,6 +8454,41 @@ impl Editor {
|
||||
pub fn inlay_hint_cache(&self) -> &InlayHintCache {
|
||||
&self.inlay_hint_cache
|
||||
}
|
||||
|
||||
pub fn replay_insert_event(
|
||||
&mut self,
|
||||
text: &str,
|
||||
relative_utf16_range: Option<Range<isize>>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
if !self.input_enabled {
|
||||
cx.emit(Event::InputIgnored { text: text.into() });
|
||||
return;
|
||||
}
|
||||
if let Some(relative_utf16_range) = relative_utf16_range {
|
||||
let selections = self.selections.all::<OffsetUtf16>(cx);
|
||||
self.change_selections(None, cx, |s| {
|
||||
let new_ranges = selections.into_iter().map(|range| {
|
||||
let start = OffsetUtf16(
|
||||
range
|
||||
.head()
|
||||
.0
|
||||
.saturating_add_signed(relative_utf16_range.start),
|
||||
);
|
||||
let end = OffsetUtf16(
|
||||
range
|
||||
.head()
|
||||
.0
|
||||
.saturating_add_signed(relative_utf16_range.end),
|
||||
);
|
||||
start..end
|
||||
});
|
||||
s.select_ranges(new_ranges);
|
||||
});
|
||||
}
|
||||
|
||||
self.handle_input(text, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn document_to_inlay_range(
|
||||
@@ -8416,6 +8577,10 @@ pub enum Event {
|
||||
InputIgnored {
|
||||
text: Arc<str>,
|
||||
},
|
||||
InputHandled {
|
||||
utf16_range_to_replace: Option<Range<isize>>,
|
||||
text: Arc<str>,
|
||||
},
|
||||
ExcerptsAdded {
|
||||
buffer: ModelHandle<Buffer>,
|
||||
predecessor: ExcerptId,
|
||||
@@ -8569,17 +8734,20 @@ impl View for Editor {
|
||||
if self.pending_rename.is_some() {
|
||||
keymap.add_identifier("renaming");
|
||||
}
|
||||
match self.context_menu.as_ref() {
|
||||
Some(ContextMenu::Completions(_)) => {
|
||||
keymap.add_identifier("menu");
|
||||
keymap.add_identifier("showing_completions")
|
||||
if self.context_menu_visible() {
|
||||
match self.context_menu.as_ref() {
|
||||
Some(ContextMenu::Completions(_)) => {
|
||||
keymap.add_identifier("menu");
|
||||
keymap.add_identifier("showing_completions")
|
||||
}
|
||||
Some(ContextMenu::CodeActions(_)) => {
|
||||
keymap.add_identifier("menu");
|
||||
keymap.add_identifier("showing_code_actions")
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
Some(ContextMenu::CodeActions(_)) => {
|
||||
keymap.add_identifier("menu");
|
||||
keymap.add_identifier("showing_code_actions")
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
for layer in self.keymap_context_layers.values() {
|
||||
keymap.extend(layer);
|
||||
}
|
||||
@@ -8633,29 +8801,51 @@ impl View for Editor {
|
||||
text: &str,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
self.transact(cx, |this, cx| {
|
||||
if this.input_enabled {
|
||||
let new_selected_ranges = if let Some(range_utf16) = range_utf16 {
|
||||
let range_utf16 = OffsetUtf16(range_utf16.start)..OffsetUtf16(range_utf16.end);
|
||||
Some(this.selection_replacement_ranges(range_utf16, cx))
|
||||
} else {
|
||||
this.marked_text_ranges(cx)
|
||||
};
|
||||
if !self.input_enabled {
|
||||
cx.emit(Event::InputIgnored { text: text.into() });
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(new_selected_ranges) = new_selected_ranges {
|
||||
this.change_selections(None, cx, |selections| {
|
||||
selections.select_ranges(new_selected_ranges)
|
||||
});
|
||||
}
|
||||
self.transact(cx, |this, cx| {
|
||||
let new_selected_ranges = if let Some(range_utf16) = range_utf16 {
|
||||
let range_utf16 = OffsetUtf16(range_utf16.start)..OffsetUtf16(range_utf16.end);
|
||||
Some(this.selection_replacement_ranges(range_utf16, cx))
|
||||
} else {
|
||||
this.marked_text_ranges(cx)
|
||||
};
|
||||
|
||||
let range_to_replace = new_selected_ranges.as_ref().and_then(|ranges_to_replace| {
|
||||
let newest_selection_id = this.selections.newest_anchor().id;
|
||||
this.selections
|
||||
.all::<OffsetUtf16>(cx)
|
||||
.iter()
|
||||
.zip(ranges_to_replace.iter())
|
||||
.find_map(|(selection, range)| {
|
||||
if selection.id == newest_selection_id {
|
||||
Some(
|
||||
(range.start.0 as isize - selection.head().0 as isize)
|
||||
..(range.end.0 as isize - selection.head().0 as isize),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
cx.emit(Event::InputHandled {
|
||||
utf16_range_to_replace: range_to_replace,
|
||||
text: text.into(),
|
||||
});
|
||||
|
||||
if let Some(new_selected_ranges) = new_selected_ranges {
|
||||
this.change_selections(None, cx, |selections| {
|
||||
selections.select_ranges(new_selected_ranges)
|
||||
});
|
||||
}
|
||||
|
||||
this.handle_input(text, cx);
|
||||
});
|
||||
|
||||
if !self.input_enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(transaction) = self.ime_transaction {
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.group_until_transaction(transaction, cx);
|
||||
@@ -8673,6 +8863,7 @@ impl View for Editor {
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
if !self.input_enabled {
|
||||
cx.emit(Event::InputIgnored { text: text.into() });
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -8697,6 +8888,29 @@ impl View for Editor {
|
||||
None
|
||||
};
|
||||
|
||||
let range_to_replace = ranges_to_replace.as_ref().and_then(|ranges_to_replace| {
|
||||
let newest_selection_id = this.selections.newest_anchor().id;
|
||||
this.selections
|
||||
.all::<OffsetUtf16>(cx)
|
||||
.iter()
|
||||
.zip(ranges_to_replace.iter())
|
||||
.find_map(|(selection, range)| {
|
||||
if selection.id == newest_selection_id {
|
||||
Some(
|
||||
(range.start.0 as isize - selection.head().0 as isize)
|
||||
..(range.end.0 as isize - selection.head().0 as isize),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
cx.emit(Event::InputHandled {
|
||||
utf16_range_to_replace: range_to_replace,
|
||||
text: text.into(),
|
||||
});
|
||||
|
||||
if let Some(ranges) = ranges_to_replace {
|
||||
this.change_selections(None, cx, |s| s.select_ranges(ranges));
|
||||
}
|
||||
@@ -9186,6 +9400,7 @@ pub fn split_words<'a>(text: &'a str) -> impl std::iter::Iterator<Item = &'a str
|
||||
None
|
||||
})
|
||||
.flat_map(|word| word.split_inclusive('_'))
|
||||
.flat_map(|word| word.split_inclusive('-'))
|
||||
}
|
||||
|
||||
trait RangeToAnchorExt {
|
||||
|
||||
@@ -19,7 +19,8 @@ use gpui::{
|
||||
use indoc::indoc;
|
||||
use language::{
|
||||
language_settings::{AllLanguageSettings, AllLanguageSettingsContent, LanguageSettingsContent},
|
||||
BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageRegistry, Point,
|
||||
BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageConfigOverride, LanguageRegistry,
|
||||
Override, Point,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use project::project_settings::{LspSettings, ProjectSettings};
|
||||
@@ -5339,7 +5340,7 @@ async fn test_completion(cx: &mut gpui::TestAppContext) {
|
||||
cx.condition(|editor, _| editor.context_menu_visible())
|
||||
.await;
|
||||
let apply_additional_edits = cx.update_editor(|editor, cx| {
|
||||
editor.move_down(&MoveDown, cx);
|
||||
editor.context_menu_next(&Default::default(), cx);
|
||||
editor
|
||||
.confirm_completion(&ConfirmCompletion::default(), cx)
|
||||
.unwrap()
|
||||
@@ -7688,6 +7689,105 @@ async fn test_completions_with_additional_edits(cx: &mut gpui::TestAppContext) {
|
||||
cx.assert_editor_state(indoc! {"fn main() { let a = Some(2)ˇ; }"});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_completions_in_languages_with_extra_word_characters(cx: &mut gpui::TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let mut cx = EditorLspTestContext::new(
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
path_suffixes: vec!["jsx".into()],
|
||||
overrides: [(
|
||||
"element".into(),
|
||||
LanguageConfigOverride {
|
||||
word_characters: Override::Set(['-'].into_iter().collect()),
|
||||
..Default::default()
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_typescript::language_tsx()),
|
||||
)
|
||||
.with_override_query("(jsx_self_closing_element) @element")
|
||||
.unwrap(),
|
||||
lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
trigger_characters: Some(vec![":".to_string()]),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.lsp
|
||||
.handle_request::<lsp::request::Completion, _, _>(move |_, _| async move {
|
||||
Ok(Some(lsp::CompletionResponse::Array(vec![
|
||||
lsp::CompletionItem {
|
||||
label: "bg-blue".into(),
|
||||
..Default::default()
|
||||
},
|
||||
lsp::CompletionItem {
|
||||
label: "bg-red".into(),
|
||||
..Default::default()
|
||||
},
|
||||
lsp::CompletionItem {
|
||||
label: "bg-yellow".into(),
|
||||
..Default::default()
|
||||
},
|
||||
])))
|
||||
});
|
||||
|
||||
cx.set_state(r#"<p class="bgˇ" />"#);
|
||||
|
||||
// Trigger completion when typing a dash, because the dash is an extra
|
||||
// word character in the 'element' scope, which contains the cursor.
|
||||
cx.simulate_keystroke("-");
|
||||
cx.foreground().run_until_parked();
|
||||
cx.update_editor(|editor, _| {
|
||||
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
|
||||
assert_eq!(
|
||||
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
|
||||
&["bg-red", "bg-blue", "bg-yellow"]
|
||||
);
|
||||
} else {
|
||||
panic!("expected completion menu to be open");
|
||||
}
|
||||
});
|
||||
|
||||
cx.simulate_keystroke("l");
|
||||
cx.foreground().run_until_parked();
|
||||
cx.update_editor(|editor, _| {
|
||||
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
|
||||
assert_eq!(
|
||||
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
|
||||
&["bg-blue", "bg-yellow"]
|
||||
);
|
||||
} else {
|
||||
panic!("expected completion menu to be open");
|
||||
}
|
||||
});
|
||||
|
||||
// When filtering completions, consider the character after the '-' to
|
||||
// be the start of a subword.
|
||||
cx.set_state(r#"<p class="yelˇ" />"#);
|
||||
cx.simulate_keystroke("l");
|
||||
cx.foreground().run_until_parked();
|
||||
cx.update_editor(|editor, _| {
|
||||
if let Some(ContextMenu::Completions(menu)) = &editor.context_menu {
|
||||
assert_eq!(
|
||||
menu.matches.iter().map(|m| &m.string).collect::<Vec<_>>(),
|
||||
&["bg-yellow"]
|
||||
);
|
||||
} else {
|
||||
panic!("expected completion menu to be open");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
|
||||
let point = DisplayPoint::new(row as u32, column as u32);
|
||||
point..point
|
||||
@@ -7707,7 +7807,7 @@ fn assert_selection_ranges(marked_text: &str, view: &mut Editor, cx: &mut ViewCo
|
||||
/// Handle completion request passing a marked string specifying where the completion
|
||||
/// should be triggered from using '|' character, what range should be replaced, and what completions
|
||||
/// should be returned using '<' and '>' to delimit the range
|
||||
fn handle_completion_request<'a>(
|
||||
pub fn handle_completion_request<'a>(
|
||||
cx: &mut EditorLspTestContext<'a>,
|
||||
marked_string: &str,
|
||||
completions: Vec<&'static str>,
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint};
|
||||
use crate::{char_kind, CharKind, ToPoint};
|
||||
use crate::{char_kind, CharKind, ToOffset, ToPoint};
|
||||
use language::Point;
|
||||
use std::ops::Range;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum FindRange {
|
||||
SingleLine,
|
||||
MultiLine,
|
||||
}
|
||||
|
||||
pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint {
|
||||
if point.column() > 0 {
|
||||
*point.column_mut() -= 1;
|
||||
@@ -177,20 +183,21 @@ pub fn line_end(
|
||||
|
||||
pub fn previous_word_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
|
||||
let raw_point = point.to_point(map);
|
||||
let language = map.buffer_snapshot.language_at(raw_point);
|
||||
let scope = map.buffer_snapshot.language_scope_at(raw_point);
|
||||
|
||||
find_preceding_boundary(map, point, |left, right| {
|
||||
(char_kind(language, left) != char_kind(language, right) && !right.is_whitespace())
|
||||
find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
(char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace())
|
||||
|| left == '\n'
|
||||
})
|
||||
}
|
||||
|
||||
pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
|
||||
let raw_point = point.to_point(map);
|
||||
let language = map.buffer_snapshot.language_at(raw_point);
|
||||
find_preceding_boundary(map, point, |left, right| {
|
||||
let scope = map.buffer_snapshot.language_scope_at(raw_point);
|
||||
|
||||
find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
let is_word_start =
|
||||
char_kind(language, left) != char_kind(language, right) && !right.is_whitespace();
|
||||
char_kind(&scope, left) != char_kind(&scope, right) && !right.is_whitespace();
|
||||
let is_subword_start =
|
||||
left == '_' && right != '_' || left.is_lowercase() && right.is_uppercase();
|
||||
is_word_start || is_subword_start || left == '\n'
|
||||
@@ -199,19 +206,21 @@ pub fn previous_subword_start(map: &DisplaySnapshot, point: DisplayPoint) -> Dis
|
||||
|
||||
pub fn next_word_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
|
||||
let raw_point = point.to_point(map);
|
||||
let language = map.buffer_snapshot.language_at(raw_point);
|
||||
find_boundary(map, point, |left, right| {
|
||||
(char_kind(language, left) != char_kind(language, right) && !left.is_whitespace())
|
||||
let scope = map.buffer_snapshot.language_scope_at(raw_point);
|
||||
|
||||
find_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
(char_kind(&scope, left) != char_kind(&scope, right) && !left.is_whitespace())
|
||||
|| right == '\n'
|
||||
})
|
||||
}
|
||||
|
||||
pub fn next_subword_end(map: &DisplaySnapshot, point: DisplayPoint) -> DisplayPoint {
|
||||
let raw_point = point.to_point(map);
|
||||
let language = map.buffer_snapshot.language_at(raw_point);
|
||||
find_boundary(map, point, |left, right| {
|
||||
let scope = map.buffer_snapshot.language_scope_at(raw_point);
|
||||
|
||||
find_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
let is_word_end =
|
||||
(char_kind(language, left) != char_kind(language, right)) && !left.is_whitespace();
|
||||
(char_kind(&scope, left) != char_kind(&scope, right)) && !left.is_whitespace();
|
||||
let is_subword_end =
|
||||
left != '_' && right == '_' || left.is_lowercase() && right.is_uppercase();
|
||||
is_word_end || is_subword_end || right == '\n'
|
||||
@@ -272,79 +281,34 @@ pub fn end_of_paragraph(
|
||||
map.max_point()
|
||||
}
|
||||
|
||||
/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the
|
||||
/// given predicate returning true. The predicate is called with the character to the left and right
|
||||
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
|
||||
/// or end of a line.
|
||||
/// Scans for a boundary preceding the given start point `from` until a boundary is found,
|
||||
/// indicated by the given predicate returning true.
|
||||
/// The predicate is called with the character to the left and right of the candidate boundary location.
|
||||
/// If FindRange::SingleLine is specified and no boundary is found before the start of the current line, the start of the current line will be returned.
|
||||
pub fn find_preceding_boundary(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
find_range: FindRange,
|
||||
mut is_boundary: impl FnMut(char, char) -> bool,
|
||||
) -> DisplayPoint {
|
||||
let mut start_column = 0;
|
||||
let mut soft_wrap_row = from.row() + 1;
|
||||
let mut prev_ch = None;
|
||||
let mut offset = from.to_point(map).to_offset(&map.buffer_snapshot);
|
||||
|
||||
let mut prev = None;
|
||||
for (ch, point) in map.reverse_chars_at(from) {
|
||||
// Recompute soft_wrap_indent if the row has changed
|
||||
if point.row() != soft_wrap_row {
|
||||
soft_wrap_row = point.row();
|
||||
|
||||
if point.row() == 0 {
|
||||
start_column = 0;
|
||||
} else if let Some(indent) = map.soft_wrap_indent(point.row() - 1) {
|
||||
start_column = indent;
|
||||
}
|
||||
}
|
||||
|
||||
// If the current point is in the soft_wrap, skip comparing it
|
||||
if point.column() < start_column {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((prev_ch, prev_point)) = prev {
|
||||
if is_boundary(ch, prev_ch) {
|
||||
return map.clip_point(prev_point, Bias::Left);
|
||||
}
|
||||
}
|
||||
|
||||
prev = Some((ch, point));
|
||||
}
|
||||
map.clip_point(DisplayPoint::zero(), Bias::Left)
|
||||
}
|
||||
|
||||
/// Scans for a boundary preceding the given start point `from` until a boundary is found, indicated by the
|
||||
/// given predicate returning true. The predicate is called with the character to the left and right
|
||||
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
|
||||
/// or end of a line. If no boundary is found, the start of the line is returned.
|
||||
pub fn find_preceding_boundary_in_line(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
mut is_boundary: impl FnMut(char, char) -> bool,
|
||||
) -> DisplayPoint {
|
||||
let mut start_column = 0;
|
||||
if from.row() > 0 {
|
||||
if let Some(indent) = map.soft_wrap_indent(from.row() - 1) {
|
||||
start_column = indent;
|
||||
}
|
||||
}
|
||||
|
||||
let mut prev = None;
|
||||
for (ch, point) in map.reverse_chars_at(from) {
|
||||
if let Some((prev_ch, prev_point)) = prev {
|
||||
if is_boundary(ch, prev_ch) {
|
||||
return map.clip_point(prev_point, Bias::Left);
|
||||
}
|
||||
}
|
||||
|
||||
if ch == '\n' || point.column() < start_column {
|
||||
for ch in map.buffer_snapshot.reversed_chars_at(offset) {
|
||||
if find_range == FindRange::SingleLine && ch == '\n' {
|
||||
break;
|
||||
}
|
||||
if let Some(prev_ch) = prev_ch {
|
||||
if is_boundary(ch, prev_ch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
prev = Some((ch, point));
|
||||
offset -= ch.len_utf8();
|
||||
prev_ch = Some(ch);
|
||||
}
|
||||
|
||||
map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Left)
|
||||
map.clip_point(offset.to_display_point(map), Bias::Left)
|
||||
}
|
||||
|
||||
/// Scans for a boundary following the given start point until a boundary is found, indicated by the
|
||||
@@ -354,59 +318,38 @@ pub fn find_preceding_boundary_in_line(
|
||||
pub fn find_boundary(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
find_range: FindRange,
|
||||
mut is_boundary: impl FnMut(char, char) -> bool,
|
||||
) -> DisplayPoint {
|
||||
let mut offset = from.to_offset(&map, Bias::Right);
|
||||
let mut prev_ch = None;
|
||||
for (ch, point) in map.chars_at(from) {
|
||||
if let Some(prev_ch) = prev_ch {
|
||||
if is_boundary(prev_ch, ch) {
|
||||
return map.clip_point(point, Bias::Right);
|
||||
}
|
||||
}
|
||||
|
||||
prev_ch = Some(ch);
|
||||
}
|
||||
map.clip_point(map.max_point(), Bias::Right)
|
||||
}
|
||||
|
||||
/// Scans for a boundary following the given start point until a boundary is found, indicated by the
|
||||
/// given predicate returning true. The predicate is called with the character to the left and right
|
||||
/// of the candidate boundary location, and will be called with `\n` characters indicating the start
|
||||
/// or end of a line. If no boundary is found, the end of the line is returned
|
||||
pub fn find_boundary_in_line(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
mut is_boundary: impl FnMut(char, char) -> bool,
|
||||
) -> DisplayPoint {
|
||||
let mut prev = None;
|
||||
for (ch, point) in map.chars_at(from) {
|
||||
if let Some((prev_ch, _)) = prev {
|
||||
if is_boundary(prev_ch, ch) {
|
||||
return map.clip_point(point, Bias::Right);
|
||||
}
|
||||
}
|
||||
|
||||
prev = Some((ch, point));
|
||||
|
||||
if ch == '\n' {
|
||||
for ch in map.buffer_snapshot.chars_at(offset) {
|
||||
if find_range == FindRange::SingleLine && ch == '\n' {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Some(prev_ch) = prev_ch {
|
||||
if is_boundary(prev_ch, ch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the last position checked so that we give a point right before the newline or eof.
|
||||
map.clip_point(prev.map(|(_, point)| point).unwrap_or(from), Bias::Right)
|
||||
offset += ch.len_utf8();
|
||||
prev_ch = Some(ch);
|
||||
}
|
||||
map.clip_point(offset.to_display_point(map), Bias::Right)
|
||||
}
|
||||
|
||||
pub fn is_inside_word(map: &DisplaySnapshot, point: DisplayPoint) -> bool {
|
||||
let raw_point = point.to_point(map);
|
||||
let language = map.buffer_snapshot.language_at(raw_point);
|
||||
let scope = map.buffer_snapshot.language_scope_at(raw_point);
|
||||
let ix = map.clip_point(point, Bias::Left).to_offset(map, Bias::Left);
|
||||
let text = &map.buffer_snapshot;
|
||||
let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(language, c));
|
||||
let next_char_kind = text.chars_at(ix).next().map(|c| char_kind(&scope, c));
|
||||
let prev_char_kind = text
|
||||
.reversed_chars_at(ix)
|
||||
.next()
|
||||
.map(|c| char_kind(language, c));
|
||||
.map(|c| char_kind(&scope, c));
|
||||
prev_char_kind.zip(next_char_kind) == Some((CharKind::Word, CharKind::Word))
|
||||
}
|
||||
|
||||
@@ -533,7 +476,12 @@ mod tests {
|
||||
) {
|
||||
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
|
||||
assert_eq!(
|
||||
find_preceding_boundary(&snapshot, display_points[1], is_boundary),
|
||||
find_preceding_boundary(
|
||||
&snapshot,
|
||||
display_points[1],
|
||||
FindRange::MultiLine,
|
||||
is_boundary
|
||||
),
|
||||
display_points[0]
|
||||
);
|
||||
}
|
||||
@@ -612,21 +560,15 @@ mod tests {
|
||||
find_preceding_boundary(
|
||||
&snapshot,
|
||||
buffer_snapshot.len().to_display_point(&snapshot),
|
||||
|left, _| left == 'a',
|
||||
FindRange::MultiLine,
|
||||
|left, _| left == 'e',
|
||||
),
|
||||
0.to_display_point(&snapshot),
|
||||
snapshot
|
||||
.buffer_snapshot
|
||||
.offset_to_point(5)
|
||||
.to_display_point(&snapshot),
|
||||
"Should not stop at inlays when looking for boundaries"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
find_preceding_boundary_in_line(
|
||||
&snapshot,
|
||||
buffer_snapshot.len().to_display_point(&snapshot),
|
||||
|left, _| left == 'a',
|
||||
),
|
||||
0.to_display_point(&snapshot),
|
||||
"Should not stop at inlays when looking for boundaries in line"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -699,7 +641,12 @@ mod tests {
|
||||
) {
|
||||
let (snapshot, display_points) = marked_display_snapshot(marked_text, cx);
|
||||
assert_eq!(
|
||||
find_boundary(&snapshot, display_points[0], is_boundary),
|
||||
find_boundary(
|
||||
&snapshot,
|
||||
display_points[0],
|
||||
FindRange::MultiLine,
|
||||
is_boundary
|
||||
),
|
||||
display_points[1]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1417,13 +1417,13 @@ impl MultiBuffer {
|
||||
return false;
|
||||
}
|
||||
|
||||
let language = self.language_at(position.clone(), cx);
|
||||
|
||||
if char_kind(language.as_ref(), char) == CharKind::Word {
|
||||
let snapshot = self.snapshot(cx);
|
||||
let position = position.to_offset(&snapshot);
|
||||
let scope = snapshot.language_scope_at(position);
|
||||
if char_kind(&scope, char) == CharKind::Word {
|
||||
return true;
|
||||
}
|
||||
|
||||
let snapshot = self.snapshot(cx);
|
||||
let anchor = snapshot.anchor_before(position);
|
||||
anchor
|
||||
.buffer_id
|
||||
@@ -1925,8 +1925,8 @@ impl MultiBufferSnapshot {
|
||||
let mut next_chars = self.chars_at(start).peekable();
|
||||
let mut prev_chars = self.reversed_chars_at(start).peekable();
|
||||
|
||||
let language = self.language_at(start);
|
||||
let kind = |c| char_kind(language, c);
|
||||
let scope = self.language_scope_at(start);
|
||||
let kind = |c| char_kind(&scope, c);
|
||||
let word_kind = cmp::max(
|
||||
prev_chars.peek().copied().map(kind),
|
||||
next_chars.peek().copied().map(kind),
|
||||
|
||||
@@ -378,10 +378,6 @@ impl Editor {
|
||||
return;
|
||||
}
|
||||
|
||||
if amount.move_context_menu_selection(self, cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cur_position = self.scroll_position(cx);
|
||||
let new_pos = cur_position + vec2f(0., amount.lines(self));
|
||||
self.set_scroll_position(new_pos, cx);
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
use gpui::ViewContext;
|
||||
use serde::Deserialize;
|
||||
use util::iife;
|
||||
|
||||
use crate::Editor;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Clone, PartialEq, Deserialize)]
|
||||
pub enum ScrollAmount {
|
||||
@@ -13,25 +10,6 @@ pub enum ScrollAmount {
|
||||
}
|
||||
|
||||
impl ScrollAmount {
|
||||
pub fn move_context_menu_selection(
|
||||
&self,
|
||||
editor: &mut Editor,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
iife!({
|
||||
let context_menu = editor.context_menu.as_mut()?;
|
||||
|
||||
match self {
|
||||
Self::Line(c) if *c > 0. => context_menu.select_next(cx),
|
||||
Self::Line(_) => context_menu.select_prev(cx),
|
||||
Self::Page(c) if *c > 0. => context_menu.select_last(cx),
|
||||
Self::Page(_) => context_menu.select_first(cx),
|
||||
}
|
||||
.then_some(())
|
||||
})
|
||||
.is_some()
|
||||
}
|
||||
|
||||
pub fn lines(&self, editor: &mut Editor) -> f32 {
|
||||
match self {
|
||||
Self::Line(count) => *count,
|
||||
@@ -39,7 +17,7 @@ impl ScrollAmount {
|
||||
.visible_line_count()
|
||||
// subtract one to leave an anchor line
|
||||
// round towards zero (so page-up and page-down are symmetric)
|
||||
.map(|l| ((l - 1.) * count).trunc())
|
||||
.map(|l| (l * count).trunc() - count.signum())
|
||||
.unwrap_or(0.),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ impl<'a> EditorLspTestContext<'a> {
|
||||
language
|
||||
.path_suffixes()
|
||||
.first()
|
||||
.unwrap_or(&"txt".to_string())
|
||||
.expect("language must have a path suffix for EditorLspTestContext")
|
||||
);
|
||||
|
||||
let mut fake_servers = language
|
||||
|
||||
@@ -42,14 +42,14 @@ impl View for FeedbackInfoText {
|
||||
)
|
||||
.with_child(
|
||||
MouseEventHandler::new::<OpenZedCommunityRepo, _>(0, cx, |state, _| {
|
||||
let contained_text = if state.hovered() {
|
||||
let style = if state.hovered() {
|
||||
&theme.feedback.link_text_hover
|
||||
} else {
|
||||
&theme.feedback.link_text_default
|
||||
};
|
||||
|
||||
Label::new("community repo", contained_text.text.clone())
|
||||
Label::new("community repo", style.text.clone())
|
||||
.contained()
|
||||
.with_style(style.container)
|
||||
.aligned()
|
||||
.left()
|
||||
.clipped()
|
||||
@@ -64,6 +64,8 @@ impl View for FeedbackInfoText {
|
||||
.with_soft_wrap(false)
|
||||
.aligned(),
|
||||
)
|
||||
.contained()
|
||||
.with_style(theme.feedback.info_text_default.container)
|
||||
.aligned()
|
||||
.left()
|
||||
.clipped()
|
||||
|
||||
@@ -1528,8 +1528,13 @@ mod tests {
|
||||
let active_pane = cx.read(|cx| workspace.read(cx).active_pane().clone());
|
||||
active_pane
|
||||
.update(cx, |pane, cx| {
|
||||
pane.close_active_item(&workspace::CloseActiveItem, cx)
|
||||
.unwrap()
|
||||
pane.close_active_item(
|
||||
&workspace::CloseActiveItem {
|
||||
save_behavior: None,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -3513,14 +3513,12 @@ impl<'a, 'b, 'c, V> LayoutContext<'a, 'b, 'c, V> {
|
||||
handler_depth = Some(contexts.len())
|
||||
}
|
||||
|
||||
let action_contexts = if let Some(depth) = handler_depth {
|
||||
&contexts[depth..]
|
||||
} else {
|
||||
&contexts
|
||||
};
|
||||
|
||||
self.keystroke_matcher
|
||||
.keystrokes_for_action(action, action_contexts)
|
||||
let handler_depth = handler_depth.unwrap_or(0);
|
||||
(0..=handler_depth).find_map(|depth| {
|
||||
let contexts = &contexts[depth..];
|
||||
self.keystroke_matcher
|
||||
.keystrokes_for_action(action, contexts)
|
||||
})
|
||||
}
|
||||
|
||||
fn notify_if_view_ancestors_change(&mut self, view_id: usize) {
|
||||
@@ -6499,7 +6497,7 @@ mod tests {
|
||||
|
||||
#[crate::test(self)]
|
||||
fn test_keystrokes_for_action(cx: &mut TestAppContext) {
|
||||
actions!(test, [Action1, Action2, GlobalAction]);
|
||||
actions!(test, [Action1, Action2, Action3, GlobalAction]);
|
||||
|
||||
struct View1 {
|
||||
child: ViewHandle<View2>,
|
||||
@@ -6542,12 +6540,14 @@ mod tests {
|
||||
|
||||
cx.update(|cx| {
|
||||
cx.add_action(|_: &mut View1, _: &Action1, _cx| {});
|
||||
cx.add_action(|_: &mut View1, _: &Action3, _cx| {});
|
||||
cx.add_action(|_: &mut View2, _: &Action2, _cx| {});
|
||||
cx.add_global_action(|_: &GlobalAction, _| {});
|
||||
cx.add_bindings(vec![
|
||||
Binding::new("a", Action1, Some("View1")),
|
||||
Binding::new("b", Action2, Some("View1 > View2")),
|
||||
Binding::new("c", GlobalAction, Some("View3")), // View 3 does not exist
|
||||
Binding::new("c", Action3, Some("View2")),
|
||||
Binding::new("d", GlobalAction, Some("View3")), // View 3 does not exist
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -6577,6 +6577,14 @@ mod tests {
|
||||
.as_slice(),
|
||||
&[Keystroke::parse("b").unwrap()]
|
||||
);
|
||||
assert_eq!(layout_cx.keystrokes_for_action(view_1.id(), &Action3), None);
|
||||
assert_eq!(
|
||||
layout_cx
|
||||
.keystrokes_for_action(view_2.id(), &Action3)
|
||||
.unwrap()
|
||||
.as_slice(),
|
||||
&[Keystroke::parse("c").unwrap()]
|
||||
);
|
||||
|
||||
// The 'a' keystroke propagates up the view tree from view_2
|
||||
// to view_1. The action, Action1, is handled by view_1.
|
||||
@@ -6604,7 +6612,8 @@ mod tests {
|
||||
&available_actions(window.into(), view_1.id(), cx),
|
||||
&[
|
||||
("test::Action1", vec![Keystroke::parse("a").unwrap()]),
|
||||
("test::GlobalAction", vec![])
|
||||
("test::Action3", vec![]),
|
||||
("test::GlobalAction", vec![]),
|
||||
],
|
||||
);
|
||||
|
||||
@@ -6614,6 +6623,7 @@ mod tests {
|
||||
&[
|
||||
("test::Action1", vec![Keystroke::parse("a").unwrap()]),
|
||||
("test::Action2", vec![Keystroke::parse("b").unwrap()]),
|
||||
("test::Action3", vec![Keystroke::parse("c").unwrap()]),
|
||||
("test::GlobalAction", vec![]),
|
||||
],
|
||||
);
|
||||
|
||||
@@ -1110,7 +1110,7 @@ impl<'a> WindowContext<'a> {
|
||||
self.window.is_fullscreen
|
||||
}
|
||||
|
||||
pub(crate) fn dispatch_action(&mut self, view_id: Option<usize>, action: &dyn Action) -> bool {
|
||||
pub fn dispatch_action(&mut self, view_id: Option<usize>, action: &dyn Action) -> bool {
|
||||
if let Some(view_id) = view_id {
|
||||
self.halt_action_dispatch = false;
|
||||
self.visit_dispatch_path(view_id, |view_id, capture_phase, cx| {
|
||||
|
||||
@@ -106,6 +106,7 @@ pub struct Deterministic {
|
||||
parker: parking_lot::Mutex<parking::Parker>,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub enum Timer {
|
||||
Production(smol::Timer),
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
|
||||
@@ -37,8 +37,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
|
||||
Some("seed") => starting_seed = parse_int(&meta.lit)?,
|
||||
Some("on_failure") => {
|
||||
if let Lit::Str(name) = meta.lit {
|
||||
let ident = Ident::new(&name.value(), name.span());
|
||||
on_failure_fn_name = quote!(Some(#ident));
|
||||
let mut path = syn::Path {
|
||||
leading_colon: None,
|
||||
segments: Default::default(),
|
||||
};
|
||||
for part in name.value().split("::") {
|
||||
path.segments.push(Ident::new(part, name.span()).into());
|
||||
}
|
||||
on_failure_fn_name = quote!(Some(#path));
|
||||
} else {
|
||||
return Err(TokenStream::from(
|
||||
syn::Error::new(
|
||||
|
||||
@@ -148,6 +148,7 @@ pub struct Completion {
|
||||
pub old_range: Range<Anchor>,
|
||||
pub new_text: String,
|
||||
pub label: CodeLabel,
|
||||
pub server_id: LanguageServerId,
|
||||
pub lsp_completion: lsp::CompletionItem,
|
||||
}
|
||||
|
||||
@@ -438,7 +439,7 @@ impl Buffer {
|
||||
operations.extend(
|
||||
text_operations
|
||||
.iter()
|
||||
.filter(|(_, op)| !since.observed(op.local_timestamp()))
|
||||
.filter(|(_, op)| !since.observed(op.timestamp()))
|
||||
.map(|(_, op)| proto::serialize_operation(&Operation::Buffer(op.clone()))),
|
||||
);
|
||||
operations.sort_unstable_by_key(proto::lamport_timestamp_for_operation);
|
||||
@@ -1303,7 +1304,7 @@ impl Buffer {
|
||||
|
||||
pub fn wait_for_edits(
|
||||
&mut self,
|
||||
edit_ids: impl IntoIterator<Item = clock::Local>,
|
||||
edit_ids: impl IntoIterator<Item = clock::Lamport>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.text.wait_for_edits(edit_ids)
|
||||
}
|
||||
@@ -1361,7 +1362,7 @@ impl Buffer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_text<T>(&mut self, text: T, cx: &mut ModelContext<Self>) -> Option<clock::Local>
|
||||
pub fn set_text<T>(&mut self, text: T, cx: &mut ModelContext<Self>) -> Option<clock::Lamport>
|
||||
where
|
||||
T: Into<Arc<str>>,
|
||||
{
|
||||
@@ -1374,7 +1375,7 @@ impl Buffer {
|
||||
edits_iter: I,
|
||||
autoindent_mode: Option<AutoindentMode>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Option<clock::Local>
|
||||
) -> Option<clock::Lamport>
|
||||
where
|
||||
I: IntoIterator<Item = (Range<S>, T)>,
|
||||
S: ToOffset,
|
||||
@@ -1411,7 +1412,7 @@ impl Buffer {
|
||||
.and_then(|mode| self.language.as_ref().map(|_| (self.snapshot(), mode)));
|
||||
|
||||
let edit_operation = self.text.edit(edits.iter().cloned());
|
||||
let edit_id = edit_operation.local_timestamp();
|
||||
let edit_id = edit_operation.timestamp();
|
||||
|
||||
if let Some((before_edit, mode)) = autoindent_request {
|
||||
let mut delta = 0isize;
|
||||
@@ -2216,8 +2217,8 @@ impl BufferSnapshot {
|
||||
let mut next_chars = self.chars_at(start).peekable();
|
||||
let mut prev_chars = self.reversed_chars_at(start).peekable();
|
||||
|
||||
let language = self.language_at(start);
|
||||
let kind = |c| char_kind(language, c);
|
||||
let scope = self.language_scope_at(start);
|
||||
let kind = |c| char_kind(&scope, c);
|
||||
let word_kind = cmp::max(
|
||||
prev_chars.peek().copied().map(kind),
|
||||
next_chars.peek().copied().map(kind),
|
||||
@@ -3031,17 +3032,21 @@ pub fn contiguous_ranges(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn char_kind(language: Option<&Arc<Language>>, c: char) -> CharKind {
|
||||
pub fn char_kind(scope: &Option<LanguageScope>, c: char) -> CharKind {
|
||||
if c.is_whitespace() {
|
||||
return CharKind::Whitespace;
|
||||
} else if c.is_alphanumeric() || c == '_' {
|
||||
return CharKind::Word;
|
||||
}
|
||||
if let Some(language) = language {
|
||||
if language.config.word_characters.contains(&c) {
|
||||
return CharKind::Word;
|
||||
|
||||
if let Some(scope) = scope {
|
||||
if let Some(characters) = scope.word_characters() {
|
||||
if characters.contains(&c) {
|
||||
return CharKind::Word;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CharKind::Punctuation
|
||||
}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ use theme::{SyntaxTheme, Theme};
|
||||
use tree_sitter::{self, Query};
|
||||
use unicase::UniCase;
|
||||
use util::{http::HttpClient, paths::PathExt};
|
||||
use util::{merge_json_value_into, post_inc, ResultExt, TryFutureExt as _, UnwrapFuture};
|
||||
use util::{post_inc, ResultExt, TryFutureExt as _, UnwrapFuture};
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
use futures::channel::mpsc;
|
||||
@@ -91,6 +91,7 @@ pub struct LanguageServerName(pub Arc<str>);
|
||||
/// once at startup, and caches the results.
|
||||
pub struct CachedLspAdapter {
|
||||
pub name: LanguageServerName,
|
||||
pub short_name: &'static str,
|
||||
pub initialization_options: Option<Value>,
|
||||
pub disk_based_diagnostic_sources: Vec<String>,
|
||||
pub disk_based_diagnostics_progress_token: Option<String>,
|
||||
@@ -101,6 +102,7 @@ pub struct CachedLspAdapter {
|
||||
impl CachedLspAdapter {
|
||||
pub async fn new(adapter: Arc<dyn LspAdapter>) -> Arc<Self> {
|
||||
let name = adapter.name().await;
|
||||
let short_name = adapter.short_name();
|
||||
let initialization_options = adapter.initialization_options().await;
|
||||
let disk_based_diagnostic_sources = adapter.disk_based_diagnostic_sources().await;
|
||||
let disk_based_diagnostics_progress_token =
|
||||
@@ -109,6 +111,7 @@ impl CachedLspAdapter {
|
||||
|
||||
Arc::new(CachedLspAdapter {
|
||||
name,
|
||||
short_name,
|
||||
initialization_options,
|
||||
disk_based_diagnostic_sources,
|
||||
disk_based_diagnostics_progress_token,
|
||||
@@ -176,10 +179,7 @@ impl CachedLspAdapter {
|
||||
self.adapter.code_action_kinds()
|
||||
}
|
||||
|
||||
pub fn workspace_configuration(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
) -> Option<BoxFuture<'static, Value>> {
|
||||
pub fn workspace_configuration(&self, cx: &mut AppContext) -> BoxFuture<'static, Value> {
|
||||
self.adapter.workspace_configuration(cx)
|
||||
}
|
||||
|
||||
@@ -220,6 +220,8 @@ pub trait LspAdapterDelegate: Send + Sync {
|
||||
pub trait LspAdapter: 'static + Send + Sync {
|
||||
async fn name(&self) -> LanguageServerName;
|
||||
|
||||
fn short_name(&self) -> &'static str;
|
||||
|
||||
async fn fetch_latest_server_version(
|
||||
&self,
|
||||
delegate: &dyn LspAdapterDelegate,
|
||||
@@ -288,8 +290,8 @@ pub trait LspAdapter: 'static + Send + Sync {
|
||||
None
|
||||
}
|
||||
|
||||
fn workspace_configuration(&self, _: &mut AppContext) -> Option<BoxFuture<'static, Value>> {
|
||||
None
|
||||
fn workspace_configuration(&self, _: &mut AppContext) -> BoxFuture<'static, Value> {
|
||||
futures::future::ready(serde_json::json!({})).boxed()
|
||||
}
|
||||
|
||||
fn code_action_kinds(&self) -> Option<Vec<CodeActionKind>> {
|
||||
@@ -344,6 +346,8 @@ pub struct LanguageConfig {
|
||||
#[serde(default)]
|
||||
pub block_comment: Option<(Arc<str>, Arc<str>)>,
|
||||
#[serde(default)]
|
||||
pub scope_opt_in_language_servers: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub overrides: HashMap<String, LanguageConfigOverride>,
|
||||
#[serde(default)]
|
||||
pub word_characters: HashSet<char>,
|
||||
@@ -374,6 +378,10 @@ pub struct LanguageConfigOverride {
|
||||
pub block_comment: Override<(Arc<str>, Arc<str>)>,
|
||||
#[serde(skip_deserializing)]
|
||||
pub disabled_bracket_ixs: Vec<u16>,
|
||||
#[serde(default)]
|
||||
pub word_characters: Override<HashSet<char>>,
|
||||
#[serde(default)]
|
||||
pub opt_into_language_servers: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
@@ -412,6 +420,7 @@ impl Default for LanguageConfig {
|
||||
autoclose_before: Default::default(),
|
||||
line_comment: Default::default(),
|
||||
block_comment: Default::default(),
|
||||
scope_opt_in_language_servers: Default::default(),
|
||||
overrides: Default::default(),
|
||||
collapsed_placeholder: Default::default(),
|
||||
word_characters: Default::default(),
|
||||
@@ -686,41 +695,6 @@ impl LanguageRegistry {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn workspace_configuration(&self, cx: &mut AppContext) -> Task<serde_json::Value> {
|
||||
let lsp_adapters = {
|
||||
let state = self.state.read();
|
||||
state
|
||||
.available_languages
|
||||
.iter()
|
||||
.filter(|l| !l.loaded)
|
||||
.flat_map(|l| l.lsp_adapters.clone())
|
||||
.chain(
|
||||
state
|
||||
.languages
|
||||
.iter()
|
||||
.flat_map(|language| &language.adapters)
|
||||
.map(|adapter| adapter.adapter.clone()),
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let mut language_configs = Vec::new();
|
||||
for adapter in &lsp_adapters {
|
||||
if let Some(language_config) = adapter.workspace_configuration(cx) {
|
||||
language_configs.push(language_config);
|
||||
}
|
||||
}
|
||||
|
||||
cx.background().spawn(async move {
|
||||
let mut config = serde_json::json!({});
|
||||
let language_configs = futures::future::join_all(language_configs).await;
|
||||
for language_config in language_configs {
|
||||
merge_json_value_into(language_config, &mut config);
|
||||
}
|
||||
config
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add(&self, language: Arc<Language>) {
|
||||
self.state.write().add(language);
|
||||
}
|
||||
@@ -1384,13 +1358,23 @@ impl Language {
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn with_override_query(mut self, source: &str) -> Result<Self> {
|
||||
pub fn with_override_query(mut self, source: &str) -> anyhow::Result<Self> {
|
||||
let query = Query::new(self.grammar_mut().ts_language, source)?;
|
||||
|
||||
let mut override_configs_by_id = HashMap::default();
|
||||
for (ix, name) in query.capture_names().iter().enumerate() {
|
||||
if !name.starts_with('_') {
|
||||
let value = self.config.overrides.remove(name).unwrap_or_default();
|
||||
for server_name in &value.opt_into_language_servers {
|
||||
if !self
|
||||
.config
|
||||
.scope_opt_in_language_servers
|
||||
.contains(server_name)
|
||||
{
|
||||
util::debug_panic!("Server {server_name:?} has been opted-in by scope {name:?} but has not been marked as an opt-in server");
|
||||
}
|
||||
}
|
||||
|
||||
override_configs_by_id.insert(ix as u32, (name.clone(), value));
|
||||
}
|
||||
}
|
||||
@@ -1596,6 +1580,13 @@ impl LanguageScope {
|
||||
.map(|e| (&e.0, &e.1))
|
||||
}
|
||||
|
||||
pub fn word_characters(&self) -> Option<&HashSet<char>> {
|
||||
Override::as_option(
|
||||
self.config_override().map(|o| &o.word_characters),
|
||||
Some(&self.language.config.word_characters),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn brackets(&self) -> impl Iterator<Item = (&BracketPair, bool)> {
|
||||
let mut disabled_ids = self
|
||||
.config_override()
|
||||
@@ -1622,6 +1613,20 @@ impl LanguageScope {
|
||||
c.is_whitespace() || self.language.config.autoclose_before.contains(c)
|
||||
}
|
||||
|
||||
pub fn language_allowed(&self, name: &LanguageServerName) -> bool {
|
||||
let config = &self.language.config;
|
||||
let opt_in_servers = &config.scope_opt_in_language_servers;
|
||||
if opt_in_servers.iter().any(|o| *o == *name.0) {
|
||||
if let Some(over) = self.config_override() {
|
||||
over.opt_into_language_servers.iter().any(|o| *o == *name.0)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn config_override(&self) -> Option<&LanguageConfigOverride> {
|
||||
let id = self.override_id?;
|
||||
let grammar = self.language.grammar.as_ref()?;
|
||||
@@ -1726,6 +1731,10 @@ impl LspAdapter for Arc<FakeLspAdapter> {
|
||||
LanguageServerName(self.name.into())
|
||||
}
|
||||
|
||||
fn short_name(&self) -> &'static str {
|
||||
"FakeLspAdapter"
|
||||
}
|
||||
|
||||
async fn fetch_latest_server_version(
|
||||
&self,
|
||||
_: &dyn LspAdapterDelegate,
|
||||
|
||||
@@ -41,24 +41,22 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation {
|
||||
proto::operation::Variant::Edit(serialize_edit_operation(edit))
|
||||
}
|
||||
|
||||
crate::Operation::Buffer(text::Operation::Undo {
|
||||
undo,
|
||||
lamport_timestamp,
|
||||
}) => proto::operation::Variant::Undo(proto::operation::Undo {
|
||||
replica_id: undo.id.replica_id as u32,
|
||||
local_timestamp: undo.id.value,
|
||||
lamport_timestamp: lamport_timestamp.value,
|
||||
version: serialize_version(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.iter()
|
||||
.map(|(edit_id, count)| proto::UndoCount {
|
||||
replica_id: edit_id.replica_id as u32,
|
||||
local_timestamp: edit_id.value,
|
||||
count: *count,
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
crate::Operation::Buffer(text::Operation::Undo(undo)) => {
|
||||
proto::operation::Variant::Undo(proto::operation::Undo {
|
||||
replica_id: undo.timestamp.replica_id as u32,
|
||||
lamport_timestamp: undo.timestamp.value,
|
||||
version: serialize_version(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.iter()
|
||||
.map(|(edit_id, count)| proto::UndoCount {
|
||||
replica_id: edit_id.replica_id as u32,
|
||||
lamport_timestamp: edit_id.value,
|
||||
count: *count,
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
crate::Operation::UpdateSelections {
|
||||
selections,
|
||||
@@ -101,8 +99,7 @@ pub fn serialize_operation(operation: &crate::Operation) -> proto::Operation {
|
||||
pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::Edit {
|
||||
proto::operation::Edit {
|
||||
replica_id: operation.timestamp.replica_id as u32,
|
||||
local_timestamp: operation.timestamp.local,
|
||||
lamport_timestamp: operation.timestamp.lamport,
|
||||
lamport_timestamp: operation.timestamp.value,
|
||||
version: serialize_version(&operation.version),
|
||||
ranges: operation.ranges.iter().map(serialize_range).collect(),
|
||||
new_text: operation
|
||||
@@ -114,7 +111,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::
|
||||
}
|
||||
|
||||
pub fn serialize_undo_map_entry(
|
||||
(edit_id, counts): (&clock::Local, &[(clock::Local, u32)]),
|
||||
(edit_id, counts): (&clock::Lamport, &[(clock::Lamport, u32)]),
|
||||
) -> proto::UndoMapEntry {
|
||||
proto::UndoMapEntry {
|
||||
replica_id: edit_id.replica_id as u32,
|
||||
@@ -123,13 +120,38 @@ pub fn serialize_undo_map_entry(
|
||||
.iter()
|
||||
.map(|(undo_id, count)| proto::UndoCount {
|
||||
replica_id: undo_id.replica_id as u32,
|
||||
local_timestamp: undo_id.value,
|
||||
lamport_timestamp: undo_id.value,
|
||||
count: *count,
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn split_operations(
|
||||
mut operations: Vec<proto::Operation>,
|
||||
) -> impl Iterator<Item = Vec<proto::Operation>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
const CHUNK_SIZE: usize = 5;
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
|
||||
let mut done = false;
|
||||
std::iter::from_fn(move || {
|
||||
if done {
|
||||
return None;
|
||||
}
|
||||
|
||||
let operations = operations
|
||||
.drain(..std::cmp::min(CHUNK_SIZE, operations.len()))
|
||||
.collect::<Vec<_>>();
|
||||
if operations.is_empty() {
|
||||
done = true;
|
||||
}
|
||||
Some(operations)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_selections(selections: &Arc<[Selection<Anchor>]>) -> Vec<proto::Selection> {
|
||||
selections.iter().map(serialize_selection).collect()
|
||||
}
|
||||
@@ -197,7 +219,7 @@ pub fn serialize_diagnostics<'a>(
|
||||
pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor {
|
||||
proto::Anchor {
|
||||
replica_id: anchor.timestamp.replica_id as u32,
|
||||
local_timestamp: anchor.timestamp.value,
|
||||
timestamp: anchor.timestamp.value,
|
||||
offset: anchor.offset as u64,
|
||||
bias: match anchor.bias {
|
||||
Bias::Left => proto::Bias::Left as i32,
|
||||
@@ -218,32 +240,26 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operati
|
||||
crate::Operation::Buffer(text::Operation::Edit(deserialize_edit_operation(edit)))
|
||||
}
|
||||
proto::operation::Variant::Undo(undo) => {
|
||||
crate::Operation::Buffer(text::Operation::Undo {
|
||||
lamport_timestamp: clock::Lamport {
|
||||
crate::Operation::Buffer(text::Operation::Undo(UndoOperation {
|
||||
timestamp: clock::Lamport {
|
||||
replica_id: undo.replica_id as ReplicaId,
|
||||
value: undo.lamport_timestamp,
|
||||
},
|
||||
undo: UndoOperation {
|
||||
id: clock::Local {
|
||||
replica_id: undo.replica_id as ReplicaId,
|
||||
value: undo.local_timestamp,
|
||||
},
|
||||
version: deserialize_version(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
(
|
||||
clock::Local {
|
||||
replica_id: c.replica_id as ReplicaId,
|
||||
value: c.local_timestamp,
|
||||
},
|
||||
c.count,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
})
|
||||
version: deserialize_version(&undo.version),
|
||||
counts: undo
|
||||
.counts
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
(
|
||||
clock::Lamport {
|
||||
replica_id: c.replica_id as ReplicaId,
|
||||
value: c.lamport_timestamp,
|
||||
},
|
||||
c.count,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
}))
|
||||
}
|
||||
proto::operation::Variant::UpdateSelections(message) => {
|
||||
let selections = message
|
||||
@@ -298,10 +314,9 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operati
|
||||
|
||||
pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation {
|
||||
EditOperation {
|
||||
timestamp: InsertionTimestamp {
|
||||
timestamp: clock::Lamport {
|
||||
replica_id: edit.replica_id as ReplicaId,
|
||||
local: edit.local_timestamp,
|
||||
lamport: edit.lamport_timestamp,
|
||||
value: edit.lamport_timestamp,
|
||||
},
|
||||
version: deserialize_version(&edit.version),
|
||||
ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
|
||||
@@ -311,9 +326,9 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation
|
||||
|
||||
pub fn deserialize_undo_map_entry(
|
||||
entry: proto::UndoMapEntry,
|
||||
) -> (clock::Local, Vec<(clock::Local, u32)>) {
|
||||
) -> (clock::Lamport, Vec<(clock::Lamport, u32)>) {
|
||||
(
|
||||
clock::Local {
|
||||
clock::Lamport {
|
||||
replica_id: entry.replica_id as u16,
|
||||
value: entry.local_timestamp,
|
||||
},
|
||||
@@ -322,9 +337,9 @@ pub fn deserialize_undo_map_entry(
|
||||
.into_iter()
|
||||
.map(|undo_count| {
|
||||
(
|
||||
clock::Local {
|
||||
clock::Lamport {
|
||||
replica_id: undo_count.replica_id as u16,
|
||||
value: undo_count.local_timestamp,
|
||||
value: undo_count.lamport_timestamp,
|
||||
},
|
||||
undo_count.count,
|
||||
)
|
||||
@@ -384,9 +399,9 @@ pub fn deserialize_diagnostics(
|
||||
|
||||
pub fn deserialize_anchor(anchor: proto::Anchor) -> Option<Anchor> {
|
||||
Some(Anchor {
|
||||
timestamp: clock::Local {
|
||||
timestamp: clock::Lamport {
|
||||
replica_id: anchor.replica_id as ReplicaId,
|
||||
value: anchor.local_timestamp,
|
||||
value: anchor.timestamp,
|
||||
},
|
||||
offset: anchor.offset as usize,
|
||||
bias: match proto::Bias::from_i32(anchor.bias)? {
|
||||
@@ -434,6 +449,7 @@ pub fn serialize_completion(completion: &Completion) -> proto::Completion {
|
||||
old_start: Some(serialize_anchor(&completion.old_range.start)),
|
||||
old_end: Some(serialize_anchor(&completion.old_range.end)),
|
||||
new_text: completion.new_text.clone(),
|
||||
server_id: completion.server_id.0 as u64,
|
||||
lsp_completion: serde_json::to_vec(&completion.lsp_completion).unwrap(),
|
||||
}
|
||||
}
|
||||
@@ -466,6 +482,7 @@ pub async fn deserialize_completion(
|
||||
lsp_completion.filter_text.as_deref(),
|
||||
)
|
||||
}),
|
||||
server_id: LanguageServerId(completion.server_id as usize),
|
||||
lsp_completion,
|
||||
})
|
||||
}
|
||||
@@ -498,12 +515,12 @@ pub fn deserialize_code_action(action: proto::CodeAction) -> Result<CodeAction>
|
||||
|
||||
pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
|
||||
proto::Transaction {
|
||||
id: Some(serialize_local_timestamp(transaction.id)),
|
||||
id: Some(serialize_timestamp(transaction.id)),
|
||||
edit_ids: transaction
|
||||
.edit_ids
|
||||
.iter()
|
||||
.copied()
|
||||
.map(serialize_local_timestamp)
|
||||
.map(serialize_timestamp)
|
||||
.collect(),
|
||||
start: serialize_version(&transaction.start),
|
||||
}
|
||||
@@ -511,7 +528,7 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
|
||||
|
||||
pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transaction> {
|
||||
Ok(Transaction {
|
||||
id: deserialize_local_timestamp(
|
||||
id: deserialize_timestamp(
|
||||
transaction
|
||||
.id
|
||||
.ok_or_else(|| anyhow!("missing transaction id"))?,
|
||||
@@ -519,21 +536,21 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transa
|
||||
edit_ids: transaction
|
||||
.edit_ids
|
||||
.into_iter()
|
||||
.map(deserialize_local_timestamp)
|
||||
.map(deserialize_timestamp)
|
||||
.collect(),
|
||||
start: deserialize_version(&transaction.start),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_local_timestamp(timestamp: clock::Local) -> proto::LocalTimestamp {
|
||||
proto::LocalTimestamp {
|
||||
pub fn serialize_timestamp(timestamp: clock::Lamport) -> proto::LamportTimestamp {
|
||||
proto::LamportTimestamp {
|
||||
replica_id: timestamp.replica_id as u32,
|
||||
value: timestamp.value,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_local_timestamp(timestamp: proto::LocalTimestamp) -> clock::Local {
|
||||
clock::Local {
|
||||
pub fn deserialize_timestamp(timestamp: proto::LamportTimestamp) -> clock::Lamport {
|
||||
clock::Lamport {
|
||||
replica_id: timestamp.replica_id as ReplicaId,
|
||||
value: timestamp.value,
|
||||
}
|
||||
@@ -553,7 +570,7 @@ pub fn deserialize_range(range: proto::Range) -> Range<FullOffset> {
|
||||
pub fn deserialize_version(message: &[proto::VectorClockEntry]) -> clock::Global {
|
||||
let mut version = clock::Global::new();
|
||||
for entry in message {
|
||||
version.observe(clock::Local {
|
||||
version.observe(clock::Lamport {
|
||||
replica_id: entry.replica_id as ReplicaId,
|
||||
value: entry.timestamp,
|
||||
});
|
||||
|
||||
@@ -52,6 +52,7 @@ impl View for ActiveBufferLanguage {
|
||||
} else {
|
||||
"Unknown".to_string()
|
||||
};
|
||||
let theme = theme::current(cx).clone();
|
||||
|
||||
MouseEventHandler::new::<Self, _>(0, cx, |state, cx| {
|
||||
let theme = &theme::current(cx).workspace.status_bar;
|
||||
@@ -68,6 +69,7 @@ impl View for ActiveBufferLanguage {
|
||||
});
|
||||
}
|
||||
})
|
||||
.with_tooltip::<Self>(0, "Select Language", None, theme.tooltip.clone(), cx)
|
||||
.into_any()
|
||||
} else {
|
||||
Empty::new().into_any()
|
||||
|
||||
@@ -570,10 +570,12 @@ impl View for LspLogToolbarItemView {
|
||||
let Some(log_view) = self.log_view.as_ref() else {
|
||||
return Empty::new().into_any();
|
||||
};
|
||||
let log_view = log_view.read(cx);
|
||||
let menu_rows = log_view.menu_items(cx).unwrap_or_default();
|
||||
let (menu_rows, current_server_id) = log_view.update(cx, |log_view, cx| {
|
||||
let menu_rows = log_view.menu_items(cx).unwrap_or_default();
|
||||
let current_server_id = log_view.current_server_id;
|
||||
(menu_rows, current_server_id)
|
||||
});
|
||||
|
||||
let current_server_id = log_view.current_server_id;
|
||||
let current_server = current_server_id.and_then(|current_server_id| {
|
||||
if let Ok(ix) = menu_rows.binary_search_by_key(¤t_server_id, |e| e.server_id) {
|
||||
Some(menu_rows[ix].clone())
|
||||
@@ -581,10 +583,10 @@ impl View for LspLogToolbarItemView {
|
||||
None
|
||||
}
|
||||
});
|
||||
let server_selected = current_server.is_some();
|
||||
|
||||
enum Menu {}
|
||||
|
||||
Stack::new()
|
||||
let lsp_menu = Stack::new()
|
||||
.with_child(Self::render_language_server_menu_header(
|
||||
current_server,
|
||||
&theme,
|
||||
@@ -631,8 +633,47 @@ impl View for LspLogToolbarItemView {
|
||||
})
|
||||
.aligned()
|
||||
.left()
|
||||
.clipped()
|
||||
.into_any()
|
||||
.clipped();
|
||||
|
||||
enum LspCleanupButton {}
|
||||
let log_cleanup_button =
|
||||
MouseEventHandler::new::<LspCleanupButton, _>(1, cx, |state, cx| {
|
||||
let theme = theme::current(cx).clone();
|
||||
let style = theme
|
||||
.workspace
|
||||
.toolbar
|
||||
.toggleable_text_tool
|
||||
.in_state(server_selected)
|
||||
.style_for(state);
|
||||
Label::new("Clear", style.text.clone())
|
||||
.aligned()
|
||||
.contained()
|
||||
.with_style(style.container)
|
||||
.constrained()
|
||||
.with_height(theme.toolbar_dropdown_menu.row_height / 6.0 * 5.0)
|
||||
})
|
||||
.on_click(MouseButton::Left, move |_, this, cx| {
|
||||
if let Some(log_view) = this.log_view.as_ref() {
|
||||
log_view.update(cx, |log_view, cx| {
|
||||
log_view.editor.update(cx, |editor, cx| {
|
||||
editor.set_read_only(false);
|
||||
editor.clear(cx);
|
||||
editor.set_read_only(true);
|
||||
});
|
||||
})
|
||||
}
|
||||
})
|
||||
.with_cursor_style(CursorStyle::PointingHand)
|
||||
.aligned()
|
||||
.right();
|
||||
|
||||
Flex::row()
|
||||
.with_child(lsp_menu)
|
||||
.with_child(log_cleanup_button)
|
||||
.contained()
|
||||
.aligned()
|
||||
.left()
|
||||
.into_any_named("lsp log controls")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ fn build_bridge(swift_target: &SwiftTarget) {
|
||||
let swift_target_folder = swift_target_folder();
|
||||
if !Command::new("swift")
|
||||
.arg("build")
|
||||
.arg("--disable-automatic-resolution")
|
||||
.args(["--configuration", &env::var("PROFILE").unwrap()])
|
||||
.args(["--triple", &swift_target.target.triple])
|
||||
.args(["--build-path".into(), swift_target_folder])
|
||||
|
||||
@@ -20,7 +20,7 @@ anyhow.workspace = true
|
||||
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553", optional = true }
|
||||
futures.workspace = true
|
||||
log.workspace = true
|
||||
lsp-types = "0.94"
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", branch = "updated-completion-list-item-defaults" }
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -4,7 +4,7 @@ pub use lsp_types::*;
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite};
|
||||
use futures::{channel::oneshot, io::BufWriter, AsyncRead, AsyncWrite, FutureExt};
|
||||
use gpui::{executor, AsyncAppContext, Task};
|
||||
use parking_lot::Mutex;
|
||||
use postage::{barrier, prelude::Stream};
|
||||
@@ -26,12 +26,14 @@ use std::{
|
||||
atomic::{AtomicUsize, Ordering::SeqCst},
|
||||
Arc, Weak,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{path::Path, process::Stdio};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
const JSON_RPC_VERSION: &str = "2.0";
|
||||
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
|
||||
const LSP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60 * 2);
|
||||
|
||||
type NotificationHandler = Box<dyn Send + FnMut(Option<usize>, &str, AsyncAppContext)>;
|
||||
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
|
||||
@@ -303,7 +305,7 @@ impl LanguageServer {
|
||||
stdout.read_exact(&mut buffer).await?;
|
||||
|
||||
if let Ok(message) = str::from_utf8(&buffer) {
|
||||
log::trace!("incoming message:{}", message);
|
||||
log::trace!("incoming message: {}", message);
|
||||
for handler in io_handlers.lock().values_mut() {
|
||||
handler(IoKind::StdOut, message);
|
||||
}
|
||||
@@ -468,6 +470,14 @@ impl LanguageServer {
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
completion_list: Some(CompletionListCapability {
|
||||
item_defaults: Some(vec![
|
||||
"commitCharacters".to_owned(),
|
||||
"editRange".to_owned(),
|
||||
"insertTextMode".to_owned(),
|
||||
"data".to_owned(),
|
||||
]),
|
||||
}),
|
||||
..Default::default()
|
||||
}),
|
||||
rename: Some(RenameClientCapabilities {
|
||||
@@ -740,7 +750,7 @@ impl LanguageServer {
|
||||
outbound_tx: &channel::Sender<String>,
|
||||
executor: &Arc<executor::Background>,
|
||||
params: T::Params,
|
||||
) -> impl 'static + Future<Output = Result<T::Result>>
|
||||
) -> impl 'static + Future<Output = anyhow::Result<T::Result>>
|
||||
where
|
||||
T::Result: 'static + Send,
|
||||
{
|
||||
@@ -781,10 +791,25 @@ impl LanguageServer {
|
||||
.try_send(message)
|
||||
.context("failed to write to language server's stdin");
|
||||
|
||||
let mut timeout = executor.timer(LSP_REQUEST_TIMEOUT).fuse();
|
||||
let started = Instant::now();
|
||||
async move {
|
||||
handle_response?;
|
||||
send?;
|
||||
rx.await?
|
||||
|
||||
let method = T::METHOD;
|
||||
futures::select! {
|
||||
response = rx.fuse() => {
|
||||
let elapsed = started.elapsed();
|
||||
log::trace!("Took {elapsed:?} to recieve response to {method:?} id {id}");
|
||||
response?
|
||||
}
|
||||
|
||||
_ = timeout => {
|
||||
log::error!("Cancelled LSP request task for {method:?} id {id} which took over {LSP_REQUEST_TIMEOUT:?}");
|
||||
anyhow::bail!("LSP request timeout");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ util = { path = "../util" }
|
||||
async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
|
||||
async-tar = "0.4.2"
|
||||
futures.workspace = true
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -7,14 +7,12 @@ use std::process::{Output, Stdio};
|
||||
use std::{
|
||||
env::consts,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, OnceLock},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::http::HttpClient;
|
||||
|
||||
const VERSION: &str = "v18.15.0";
|
||||
|
||||
static RUNTIME_INSTANCE: OnceLock<Arc<NodeRuntime>> = OnceLock::new();
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct NpmInfo {
|
||||
@@ -28,23 +26,88 @@ pub struct NpmInfoDistTags {
|
||||
latest: Option<String>,
|
||||
}
|
||||
|
||||
pub struct NodeRuntime {
|
||||
#[async_trait::async_trait]
|
||||
pub trait NodeRuntime: Send + Sync {
|
||||
async fn binary_path(&self) -> Result<PathBuf>;
|
||||
|
||||
async fn run_npm_subcommand(
|
||||
&self,
|
||||
directory: Option<&Path>,
|
||||
subcommand: &str,
|
||||
args: &[&str],
|
||||
) -> Result<Output>;
|
||||
|
||||
async fn npm_package_latest_version(&self, name: &str) -> Result<String>;
|
||||
|
||||
async fn npm_install_packages(&self, directory: &Path, packages: &[(&str, &str)])
|
||||
-> Result<()>;
|
||||
}
|
||||
|
||||
pub struct RealNodeRuntime {
|
||||
http: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
impl NodeRuntime {
|
||||
pub fn instance(http: Arc<dyn HttpClient>) -> Arc<NodeRuntime> {
|
||||
RUNTIME_INSTANCE
|
||||
.get_or_init(|| Arc::new(NodeRuntime { http }))
|
||||
.clone()
|
||||
impl RealNodeRuntime {
|
||||
pub fn new(http: Arc<dyn HttpClient>) -> Arc<dyn NodeRuntime> {
|
||||
Arc::new(RealNodeRuntime { http })
|
||||
}
|
||||
|
||||
pub async fn binary_path(&self) -> Result<PathBuf> {
|
||||
async fn install_if_needed(&self) -> Result<PathBuf> {
|
||||
log::info!("Node runtime install_if_needed");
|
||||
|
||||
let arch = match consts::ARCH {
|
||||
"x86_64" => "x64",
|
||||
"aarch64" => "arm64",
|
||||
other => bail!("Running on unsupported platform: {other}"),
|
||||
};
|
||||
|
||||
let folder_name = format!("node-{VERSION}-darwin-{arch}");
|
||||
let node_containing_dir = util::paths::SUPPORT_DIR.join("node");
|
||||
let node_dir = node_containing_dir.join(folder_name);
|
||||
let node_binary = node_dir.join("bin/node");
|
||||
let npm_file = node_dir.join("bin/npm");
|
||||
|
||||
let result = Command::new(&node_binary)
|
||||
.arg(npm_file)
|
||||
.arg("--version")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.await;
|
||||
let valid = matches!(result, Ok(status) if status.success());
|
||||
|
||||
if !valid {
|
||||
_ = fs::remove_dir_all(&node_containing_dir).await;
|
||||
fs::create_dir(&node_containing_dir)
|
||||
.await
|
||||
.context("error creating node containing dir")?;
|
||||
|
||||
let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz");
|
||||
let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}");
|
||||
let mut response = self
|
||||
.http
|
||||
.get(&url, Default::default(), true)
|
||||
.await
|
||||
.context("error downloading Node binary tarball")?;
|
||||
|
||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let archive = Archive::new(decompressed_bytes);
|
||||
archive.unpack(&node_containing_dir).await?;
|
||||
}
|
||||
|
||||
anyhow::Ok(node_dir)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl NodeRuntime for RealNodeRuntime {
|
||||
async fn binary_path(&self) -> Result<PathBuf> {
|
||||
let installation_path = self.install_if_needed().await?;
|
||||
Ok(installation_path.join("bin/node"))
|
||||
}
|
||||
|
||||
pub async fn run_npm_subcommand(
|
||||
async fn run_npm_subcommand(
|
||||
&self,
|
||||
directory: Option<&Path>,
|
||||
subcommand: &str,
|
||||
@@ -106,7 +169,7 @@ impl NodeRuntime {
|
||||
output.map_err(|e| anyhow!("{e}"))
|
||||
}
|
||||
|
||||
pub async fn npm_package_latest_version(&self, name: &str) -> Result<String> {
|
||||
async fn npm_package_latest_version(&self, name: &str) -> Result<String> {
|
||||
let output = self
|
||||
.run_npm_subcommand(
|
||||
None,
|
||||
@@ -131,10 +194,10 @@ impl NodeRuntime {
|
||||
.ok_or_else(|| anyhow!("no version found for npm package {}", name))
|
||||
}
|
||||
|
||||
pub async fn npm_install_packages(
|
||||
async fn npm_install_packages(
|
||||
&self,
|
||||
directory: &Path,
|
||||
packages: impl IntoIterator<Item = (&str, &str)>,
|
||||
packages: &[(&str, &str)],
|
||||
) -> Result<()> {
|
||||
let packages: Vec<_> = packages
|
||||
.into_iter()
|
||||
@@ -155,51 +218,31 @@ impl NodeRuntime {
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn install_if_needed(&self) -> Result<PathBuf> {
|
||||
log::info!("Node runtime install_if_needed");
|
||||
pub struct FakeNodeRuntime;
|
||||
|
||||
let arch = match consts::ARCH {
|
||||
"x86_64" => "x64",
|
||||
"aarch64" => "arm64",
|
||||
other => bail!("Running on unsupported platform: {other}"),
|
||||
};
|
||||
|
||||
let folder_name = format!("node-{VERSION}-darwin-{arch}");
|
||||
let node_containing_dir = util::paths::SUPPORT_DIR.join("node");
|
||||
let node_dir = node_containing_dir.join(folder_name);
|
||||
let node_binary = node_dir.join("bin/node");
|
||||
let npm_file = node_dir.join("bin/npm");
|
||||
|
||||
let result = Command::new(&node_binary)
|
||||
.arg(npm_file)
|
||||
.arg("--version")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.await;
|
||||
let valid = matches!(result, Ok(status) if status.success());
|
||||
|
||||
if !valid {
|
||||
_ = fs::remove_dir_all(&node_containing_dir).await;
|
||||
fs::create_dir(&node_containing_dir)
|
||||
.await
|
||||
.context("error creating node containing dir")?;
|
||||
|
||||
let file_name = format!("node-{VERSION}-darwin-{arch}.tar.gz");
|
||||
let url = format!("https://nodejs.org/dist/{VERSION}/{file_name}");
|
||||
let mut response = self
|
||||
.http
|
||||
.get(&url, Default::default(), true)
|
||||
.await
|
||||
.context("error downloading Node binary tarball")?;
|
||||
|
||||
let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut()));
|
||||
let archive = Archive::new(decompressed_bytes);
|
||||
archive.unpack(&node_containing_dir).await?;
|
||||
}
|
||||
|
||||
anyhow::Ok(node_dir)
|
||||
impl FakeNodeRuntime {
|
||||
pub fn new() -> Arc<dyn NodeRuntime> {
|
||||
Arc::new(FakeNodeRuntime)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl NodeRuntime for FakeNodeRuntime {
|
||||
async fn binary_path(&self) -> Result<PathBuf> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
async fn run_npm_subcommand(&self, _: Option<&Path>, _: &str, _: &[&str]) -> Result<Output> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
async fn npm_package_latest_version(&self, _: &str) -> Result<String> {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
async fn npm_install_packages(&self, _: &Path, _: &[(&str, &str)]) -> Result<()> {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,10 @@ use language::{
|
||||
CodeAction, Completion, OffsetRangeExt, PointUtf16, ToOffset, ToPointUtf16, Transaction,
|
||||
Unclipped,
|
||||
};
|
||||
use lsp::{DocumentHighlightKind, LanguageServer, LanguageServerId, OneOf, ServerCapabilities};
|
||||
use lsp::{
|
||||
CompletionListItemDefaultsEditRange, DocumentHighlightKind, LanguageServer, LanguageServerId,
|
||||
OneOf, ServerCapabilities,
|
||||
};
|
||||
use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
|
||||
use text::LineEnding;
|
||||
|
||||
@@ -1340,13 +1343,19 @@ impl LspCommand for GetCompletions {
|
||||
completions: Option<lsp::CompletionResponse>,
|
||||
_: ModelHandle<Project>,
|
||||
buffer: ModelHandle<Buffer>,
|
||||
_: LanguageServerId,
|
||||
server_id: LanguageServerId,
|
||||
cx: AsyncAppContext,
|
||||
) -> Result<Vec<Completion>> {
|
||||
let mut response_list = None;
|
||||
let completions = if let Some(completions) = completions {
|
||||
match completions {
|
||||
lsp::CompletionResponse::Array(completions) => completions,
|
||||
lsp::CompletionResponse::List(list) => list.items,
|
||||
|
||||
lsp::CompletionResponse::List(mut list) => {
|
||||
let items = std::mem::take(&mut list.items);
|
||||
response_list = Some(list);
|
||||
items
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Default::default()
|
||||
@@ -1356,6 +1365,7 @@ impl LspCommand for GetCompletions {
|
||||
let language = buffer.language().cloned();
|
||||
let snapshot = buffer.snapshot();
|
||||
let clipped_position = buffer.clip_point_utf16(Unclipped(self.position), Bias::Left);
|
||||
|
||||
let mut range_for_token = None;
|
||||
completions
|
||||
.into_iter()
|
||||
@@ -1376,6 +1386,7 @@ impl LspCommand for GetCompletions {
|
||||
edit.new_text.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
// If the language server does not provide a range, then infer
|
||||
// the range based on the syntax tree.
|
||||
None => {
|
||||
@@ -1383,27 +1394,51 @@ impl LspCommand for GetCompletions {
|
||||
log::info!("completion out of expected range");
|
||||
return None;
|
||||
}
|
||||
let Range { start, end } = range_for_token
|
||||
.get_or_insert_with(|| {
|
||||
let offset = self.position.to_offset(&snapshot);
|
||||
let (range, kind) = snapshot.surrounding_word(offset);
|
||||
if kind == Some(CharKind::Word) {
|
||||
range
|
||||
} else {
|
||||
offset..offset
|
||||
}
|
||||
})
|
||||
.clone();
|
||||
|
||||
let default_edit_range = response_list
|
||||
.as_ref()
|
||||
.and_then(|list| list.item_defaults.as_ref())
|
||||
.and_then(|defaults| defaults.edit_range.as_ref())
|
||||
.and_then(|range| match range {
|
||||
CompletionListItemDefaultsEditRange::Range(r) => Some(r),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let range = if let Some(range) = default_edit_range {
|
||||
let range = range_from_lsp(range.clone());
|
||||
let start = snapshot.clip_point_utf16(range.start, Bias::Left);
|
||||
let end = snapshot.clip_point_utf16(range.end, Bias::Left);
|
||||
if start != range.start.0 || end != range.end.0 {
|
||||
log::info!("completion out of expected range");
|
||||
return None;
|
||||
}
|
||||
|
||||
snapshot.anchor_before(start)..snapshot.anchor_after(end)
|
||||
} else {
|
||||
range_for_token
|
||||
.get_or_insert_with(|| {
|
||||
let offset = self.position.to_offset(&snapshot);
|
||||
let (range, kind) = snapshot.surrounding_word(offset);
|
||||
let range = if kind == Some(CharKind::Word) {
|
||||
range
|
||||
} else {
|
||||
offset..offset
|
||||
};
|
||||
|
||||
snapshot.anchor_before(range.start)
|
||||
..snapshot.anchor_after(range.end)
|
||||
})
|
||||
.clone()
|
||||
};
|
||||
|
||||
let text = lsp_completion
|
||||
.insert_text
|
||||
.as_ref()
|
||||
.unwrap_or(&lsp_completion.label)
|
||||
.clone();
|
||||
(
|
||||
snapshot.anchor_before(start)..snapshot.anchor_after(end),
|
||||
text,
|
||||
)
|
||||
(range, text)
|
||||
}
|
||||
|
||||
Some(lsp::CompletionTextEdit::InsertAndReplace(_)) => {
|
||||
log::info!("unsupported insert/replace completion");
|
||||
return None;
|
||||
@@ -1427,6 +1462,7 @@ impl LspCommand for GetCompletions {
|
||||
lsp_completion.filter_text.as_deref(),
|
||||
)
|
||||
}),
|
||||
server_id,
|
||||
lsp_completion,
|
||||
}
|
||||
})
|
||||
|
||||
@@ -35,7 +35,7 @@ use language::{
|
||||
point_to_lsp,
|
||||
proto::{
|
||||
deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version,
|
||||
serialize_anchor, serialize_version,
|
||||
serialize_anchor, serialize_version, split_operations,
|
||||
},
|
||||
range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction,
|
||||
CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent,
|
||||
@@ -156,6 +156,11 @@ struct DelayedDebounced {
|
||||
cancel_channel: Option<oneshot::Sender<()>>,
|
||||
}
|
||||
|
||||
enum LanguageServerToQuery {
|
||||
Primary,
|
||||
Other(LanguageServerId),
|
||||
}
|
||||
|
||||
impl DelayedDebounced {
|
||||
fn new() -> DelayedDebounced {
|
||||
DelayedDebounced {
|
||||
@@ -634,7 +639,7 @@ impl Project {
|
||||
cx.observe_global::<SettingsStore, _>(Self::on_settings_changed)
|
||||
],
|
||||
_maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx),
|
||||
_maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx),
|
||||
_maintain_workspace_config: Self::maintain_workspace_config(cx),
|
||||
active_entry: None,
|
||||
languages,
|
||||
client,
|
||||
@@ -704,7 +709,7 @@ impl Project {
|
||||
collaborators: Default::default(),
|
||||
join_project_response_message_id: response.message_id,
|
||||
_maintain_buffer_languages: Self::maintain_buffer_languages(languages.clone(), cx),
|
||||
_maintain_workspace_config: Self::maintain_workspace_config(languages.clone(), cx),
|
||||
_maintain_workspace_config: Self::maintain_workspace_config(cx),
|
||||
languages,
|
||||
user_store: user_store.clone(),
|
||||
fs,
|
||||
@@ -2472,35 +2477,42 @@ impl Project {
|
||||
})
|
||||
}
|
||||
|
||||
fn maintain_workspace_config(
|
||||
languages: Arc<LanguageRegistry>,
|
||||
cx: &mut ModelContext<Project>,
|
||||
) -> Task<()> {
|
||||
fn maintain_workspace_config(cx: &mut ModelContext<Project>) -> Task<()> {
|
||||
let (mut settings_changed_tx, mut settings_changed_rx) = watch::channel();
|
||||
let _ = postage::stream::Stream::try_recv(&mut settings_changed_rx);
|
||||
|
||||
let settings_observation = cx.observe_global::<SettingsStore, _>(move |_, _| {
|
||||
*settings_changed_tx.borrow_mut() = ();
|
||||
});
|
||||
|
||||
cx.spawn_weak(|this, mut cx| async move {
|
||||
while let Some(_) = settings_changed_rx.next().await {
|
||||
let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await;
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.read_with(&cx, |this, _| {
|
||||
for server_state in this.language_servers.values() {
|
||||
if let LanguageServerState::Running { server, .. } = server_state {
|
||||
server
|
||||
.notify::<lsp::notification::DidChangeConfiguration>(
|
||||
lsp::DidChangeConfigurationParams {
|
||||
settings: workspace_config.clone(),
|
||||
},
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let Some(this) = this.upgrade(&cx) else {
|
||||
break;
|
||||
};
|
||||
|
||||
let servers: Vec<_> = this.read_with(&cx, |this, _| {
|
||||
this.language_servers
|
||||
.values()
|
||||
.filter_map(|state| match state {
|
||||
LanguageServerState::Starting(_) => None,
|
||||
LanguageServerState::Running {
|
||||
adapter, server, ..
|
||||
} => Some((adapter.clone(), server.clone())),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
for (adapter, server) in servers {
|
||||
let workspace_config =
|
||||
cx.update(|cx| adapter.workspace_configuration(cx)).await;
|
||||
server
|
||||
.notify::<lsp::notification::DidChangeConfiguration>(
|
||||
lsp::DidChangeConfigurationParams {
|
||||
settings: workspace_config.clone(),
|
||||
},
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2615,7 +2627,6 @@ impl Project {
|
||||
let state = LanguageServerState::Starting({
|
||||
let adapter = adapter.clone();
|
||||
let server_name = adapter.name.0.clone();
|
||||
let languages = self.languages.clone();
|
||||
let language = language.clone();
|
||||
let key = key.clone();
|
||||
|
||||
@@ -2625,7 +2636,6 @@ impl Project {
|
||||
initialization_options,
|
||||
pending_server,
|
||||
adapter.clone(),
|
||||
languages,
|
||||
language.clone(),
|
||||
server_id,
|
||||
key,
|
||||
@@ -2729,7 +2739,6 @@ impl Project {
|
||||
initialization_options: Option<serde_json::Value>,
|
||||
pending_server: PendingLanguageServer,
|
||||
adapter: Arc<CachedLspAdapter>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
language: Arc<Language>,
|
||||
server_id: LanguageServerId,
|
||||
key: (WorktreeId, LanguageServerName),
|
||||
@@ -2740,7 +2749,6 @@ impl Project {
|
||||
initialization_options,
|
||||
pending_server,
|
||||
adapter.clone(),
|
||||
languages,
|
||||
server_id,
|
||||
cx,
|
||||
);
|
||||
@@ -2773,16 +2781,13 @@ impl Project {
|
||||
initialization_options: Option<serde_json::Value>,
|
||||
pending_server: PendingLanguageServer,
|
||||
adapter: Arc<CachedLspAdapter>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
server_id: LanguageServerId,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<Option<Arc<LanguageServer>>> {
|
||||
let workspace_config = cx.update(|cx| languages.workspace_configuration(cx)).await;
|
||||
let workspace_config = cx.update(|cx| adapter.workspace_configuration(cx)).await;
|
||||
let language_server = match pending_server.task.await? {
|
||||
Some(server) => server.initialize(initialization_options).await?,
|
||||
None => {
|
||||
return Ok(None);
|
||||
}
|
||||
Some(server) => server,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
language_server
|
||||
@@ -2821,12 +2826,12 @@ impl Project {
|
||||
|
||||
language_server
|
||||
.on_request::<lsp::request::WorkspaceConfiguration, _, _>({
|
||||
let languages = languages.clone();
|
||||
let adapter = adapter.clone();
|
||||
move |params, mut cx| {
|
||||
let languages = languages.clone();
|
||||
let adapter = adapter.clone();
|
||||
async move {
|
||||
let workspace_config =
|
||||
cx.update(|cx| languages.workspace_configuration(cx)).await;
|
||||
cx.update(|cx| adapter.workspace_configuration(cx)).await;
|
||||
Ok(params
|
||||
.items
|
||||
.into_iter()
|
||||
@@ -2932,6 +2937,8 @@ impl Project {
|
||||
})
|
||||
.detach();
|
||||
|
||||
let language_server = language_server.initialize(initialization_options).await?;
|
||||
|
||||
language_server
|
||||
.notify::<lsp::notification::DidChangeConfiguration>(
|
||||
lsp::DidChangeConfigurationParams {
|
||||
@@ -3892,7 +3899,7 @@ impl Project {
|
||||
let file = File::from_dyn(buffer.file())?;
|
||||
let buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
|
||||
let server = self
|
||||
.primary_language_servers_for_buffer(buffer, cx)
|
||||
.primary_language_server_for_buffer(buffer, cx)
|
||||
.map(|s| s.1.clone());
|
||||
Some((buffer_handle, buffer_abs_path, server))
|
||||
})
|
||||
@@ -4197,7 +4204,12 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetDefinition { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetDefinition { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn type_definition<T: ToPointUtf16>(
|
||||
@@ -4207,7 +4219,12 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<LocationLink>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetTypeDefinition { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetTypeDefinition { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn references<T: ToPointUtf16>(
|
||||
@@ -4217,7 +4234,12 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<Location>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetReferences { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetReferences { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn document_highlights<T: ToPointUtf16>(
|
||||
@@ -4227,7 +4249,12 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<DocumentHighlight>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetDocumentHighlights { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetDocumentHighlights { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn symbols(&self, query: &str, cx: &mut ModelContext<Self>) -> Task<Result<Vec<Symbol>>> {
|
||||
@@ -4455,17 +4482,66 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Option<Hover>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetHover { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetHover { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn completions<T: ToPointUtf16>(
|
||||
pub fn completions<T: ToOffset + ToPointUtf16>(
|
||||
&self,
|
||||
buffer: &ModelHandle<Buffer>,
|
||||
position: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<Completion>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer.clone(), GetCompletions { position }, cx)
|
||||
if self.is_local() {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
let scope = snapshot.language_scope_at(offset);
|
||||
|
||||
let server_ids: Vec<_> = self
|
||||
.language_servers_for_buffer(buffer.read(cx), cx)
|
||||
.filter(|(_, server)| server.capabilities().completion_provider.is_some())
|
||||
.filter(|(adapter, _)| {
|
||||
scope
|
||||
.as_ref()
|
||||
.map(|scope| scope.language_allowed(&adapter.name))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.map(|(_, server)| server.server_id())
|
||||
.collect();
|
||||
|
||||
let buffer = buffer.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let mut tasks = Vec::with_capacity(server_ids.len());
|
||||
this.update(&mut cx, |this, cx| {
|
||||
for server_id in server_ids {
|
||||
tasks.push(this.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Other(server_id),
|
||||
GetCompletions { position },
|
||||
cx,
|
||||
));
|
||||
}
|
||||
});
|
||||
|
||||
let mut completions = Vec::new();
|
||||
for task in tasks {
|
||||
if let Ok(new_completions) = task.await {
|
||||
completions.extend_from_slice(&new_completions);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(completions)
|
||||
})
|
||||
} else if let Some(project_id) = self.remote_id() {
|
||||
self.send_lsp_proto_request(buffer.clone(), project_id, GetCompletions { position }, cx)
|
||||
} else {
|
||||
Task::ready(Ok(Default::default()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_additional_edits_for_completion(
|
||||
@@ -4479,7 +4555,8 @@ impl Project {
|
||||
let buffer_id = buffer.remote_id();
|
||||
|
||||
if self.is_local() {
|
||||
let lang_server = match self.primary_language_servers_for_buffer(buffer, cx) {
|
||||
let server_id = completion.server_id;
|
||||
let lang_server = match self.language_server_for_buffer(buffer, server_id, cx) {
|
||||
Some((_, server)) => server.clone(),
|
||||
_ => return Task::ready(Ok(Default::default())),
|
||||
};
|
||||
@@ -4586,7 +4663,12 @@ impl Project {
|
||||
) -> Task<Result<Vec<CodeAction>>> {
|
||||
let buffer = buffer_handle.read(cx);
|
||||
let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end);
|
||||
self.request_lsp(buffer_handle.clone(), GetCodeActions { range }, cx)
|
||||
self.request_lsp(
|
||||
buffer_handle.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
GetCodeActions { range },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn apply_code_action(
|
||||
@@ -4942,7 +5024,12 @@ impl Project {
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Option<Range<Anchor>>>> {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(buffer, PrepareRename { position }, cx)
|
||||
self.request_lsp(
|
||||
buffer,
|
||||
LanguageServerToQuery::Primary,
|
||||
PrepareRename { position },
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn perform_rename<T: ToPointUtf16>(
|
||||
@@ -4956,6 +5043,7 @@ impl Project {
|
||||
let position = position.to_point_utf16(buffer.read(cx));
|
||||
self.request_lsp(
|
||||
buffer,
|
||||
LanguageServerToQuery::Primary,
|
||||
PerformRename {
|
||||
position,
|
||||
new_name,
|
||||
@@ -4983,6 +5071,7 @@ impl Project {
|
||||
});
|
||||
self.request_lsp(
|
||||
buffer.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
OnTypeFormatting {
|
||||
position,
|
||||
trigger,
|
||||
@@ -5008,7 +5097,12 @@ impl Project {
|
||||
let lsp_request = InlayHints { range };
|
||||
|
||||
if self.is_local() {
|
||||
let lsp_request_task = self.request_lsp(buffer_handle.clone(), lsp_request, cx);
|
||||
let lsp_request_task = self.request_lsp(
|
||||
buffer_handle.clone(),
|
||||
LanguageServerToQuery::Primary,
|
||||
lsp_request,
|
||||
cx,
|
||||
);
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
buffer_handle
|
||||
.update(&mut cx, |buffer, _| {
|
||||
@@ -5441,10 +5535,10 @@ impl Project {
|
||||
.await;
|
||||
}
|
||||
|
||||
// TODO: Wire this up to allow selecting a server?
|
||||
fn request_lsp<R: LspCommand>(
|
||||
&self,
|
||||
buffer_handle: ModelHandle<Buffer>,
|
||||
server: LanguageServerToQuery,
|
||||
request: R,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<R::Response>>
|
||||
@@ -5453,11 +5547,19 @@ impl Project {
|
||||
{
|
||||
let buffer = buffer_handle.read(cx);
|
||||
if self.is_local() {
|
||||
let language_server = match server {
|
||||
LanguageServerToQuery::Primary => {
|
||||
match self.primary_language_server_for_buffer(buffer, cx) {
|
||||
Some((_, server)) => Some(Arc::clone(server)),
|
||||
None => return Task::ready(Ok(Default::default())),
|
||||
}
|
||||
}
|
||||
LanguageServerToQuery::Other(id) => self
|
||||
.language_server_for_buffer(buffer, id, cx)
|
||||
.map(|(_, server)| Arc::clone(server)),
|
||||
};
|
||||
let file = File::from_dyn(buffer.file()).and_then(File::as_local);
|
||||
if let Some((file, language_server)) = file.zip(
|
||||
self.primary_language_servers_for_buffer(buffer, cx)
|
||||
.map(|(_, server)| server.clone()),
|
||||
) {
|
||||
if let (Some(file), Some(language_server)) = (file, language_server) {
|
||||
let lsp_params = request.to_lsp(&file.abs_path(cx), buffer, &language_server, cx);
|
||||
return cx.spawn(|this, cx| async move {
|
||||
if !request.check_capabilities(language_server.capabilities()) {
|
||||
@@ -5490,31 +5592,40 @@ impl Project {
|
||||
});
|
||||
}
|
||||
} else if let Some(project_id) = self.remote_id() {
|
||||
let rpc = self.client.clone();
|
||||
let message = request.to_proto(project_id, buffer);
|
||||
return cx.spawn_weak(|this, cx| async move {
|
||||
// Ensure the project is still alive by the time the task
|
||||
// is scheduled.
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("project dropped"))?;
|
||||
|
||||
let response = rpc.request(message).await?;
|
||||
|
||||
let this = this
|
||||
.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("project dropped"))?;
|
||||
if this.read_with(&cx, |this, _| this.is_read_only()) {
|
||||
Err(anyhow!("disconnected before completing request"))
|
||||
} else {
|
||||
request
|
||||
.response_from_proto(response, this, buffer_handle, cx)
|
||||
.await
|
||||
}
|
||||
});
|
||||
return self.send_lsp_proto_request(buffer_handle, project_id, request, cx);
|
||||
}
|
||||
|
||||
Task::ready(Ok(Default::default()))
|
||||
}
|
||||
|
||||
fn send_lsp_proto_request<R: LspCommand>(
|
||||
&self,
|
||||
buffer: ModelHandle<Buffer>,
|
||||
project_id: u64,
|
||||
request: R,
|
||||
cx: &mut ModelContext<'_, Project>,
|
||||
) -> Task<anyhow::Result<<R as LspCommand>::Response>> {
|
||||
let rpc = self.client.clone();
|
||||
let message = request.to_proto(project_id, buffer.read(cx));
|
||||
cx.spawn_weak(|this, cx| async move {
|
||||
// Ensure the project is still alive by the time the task
|
||||
// is scheduled.
|
||||
this.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("project dropped"))?;
|
||||
let response = rpc.request(message).await?;
|
||||
let this = this
|
||||
.upgrade(&cx)
|
||||
.ok_or_else(|| anyhow!("project dropped"))?;
|
||||
if this.read_with(&cx, |this, _| this.is_read_only()) {
|
||||
Err(anyhow!("disconnected before completing request"))
|
||||
} else {
|
||||
request
|
||||
.response_from_proto(response, this, buffer, cx)
|
||||
.await
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn sort_candidates_and_open_buffers(
|
||||
mut matching_paths_rx: Receiver<SearchMatchCandidate>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
@@ -7150,7 +7261,7 @@ impl Project {
|
||||
let buffer_version = buffer_handle.read_with(&cx, |buffer, _| buffer.version());
|
||||
let response = this
|
||||
.update(&mut cx, |this, cx| {
|
||||
this.request_lsp(buffer_handle, request, cx)
|
||||
this.request_lsp(buffer_handle, LanguageServerToQuery::Primary, request, cx)
|
||||
})
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
@@ -7867,7 +7978,7 @@ impl Project {
|
||||
})
|
||||
}
|
||||
|
||||
fn primary_language_servers_for_buffer(
|
||||
fn primary_language_server_for_buffer(
|
||||
&self,
|
||||
buffer: &Buffer,
|
||||
cx: &AppContext,
|
||||
@@ -8089,31 +8200,6 @@ impl LspAdapterDelegate for ProjectLspAdapterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_operations(
|
||||
mut operations: Vec<proto::Operation>,
|
||||
) -> impl Iterator<Item = Vec<proto::Operation>> {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
const CHUNK_SIZE: usize = 5;
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
|
||||
let mut done = false;
|
||||
std::iter::from_fn(move || {
|
||||
if done {
|
||||
return None;
|
||||
}
|
||||
|
||||
let operations = operations
|
||||
.drain(..cmp::min(CHUNK_SIZE, operations.len()))
|
||||
.collect::<Vec<_>>();
|
||||
if operations.is_empty() {
|
||||
done = true;
|
||||
}
|
||||
Some(operations)
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize_symbol(symbol: &Symbol) -> proto::Symbol {
|
||||
proto::Symbol {
|
||||
language_server_name: symbol.language_server_name.0.to_string(),
|
||||
|
||||
@@ -2272,7 +2272,18 @@ async fn test_completions_without_edit_ranges(cx: &mut gpui::TestAppContext) {
|
||||
},
|
||||
Some(tree_sitter_typescript::language_typescript()),
|
||||
);
|
||||
let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await;
|
||||
let mut fake_language_servers = language
|
||||
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
trigger_characters: Some(vec![":".to_string()]),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}))
|
||||
.await;
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
fs.insert_tree(
|
||||
@@ -2358,7 +2369,18 @@ async fn test_completions_with_carriage_returns(cx: &mut gpui::TestAppContext) {
|
||||
},
|
||||
Some(tree_sitter_typescript::language_typescript()),
|
||||
);
|
||||
let mut fake_language_servers = language.set_fake_lsp_adapter(Default::default()).await;
|
||||
let mut fake_language_servers = language
|
||||
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
|
||||
capabilities: lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
trigger_characters: Some(vec![":".to_string()]),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
}))
|
||||
.await;
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
fs.insert_tree(
|
||||
|
||||
@@ -225,15 +225,14 @@ impl SearchQuery {
|
||||
if self.as_str().is_empty() {
|
||||
return Default::default();
|
||||
}
|
||||
let language = buffer.language_at(0);
|
||||
|
||||
let range_offset = subrange.as_ref().map(|r| r.start).unwrap_or(0);
|
||||
let rope = if let Some(range) = subrange {
|
||||
buffer.as_rope().slice(range)
|
||||
} else {
|
||||
buffer.as_rope().clone()
|
||||
};
|
||||
|
||||
let kind = |c| char_kind(language, c);
|
||||
|
||||
let mut matches = Vec::new();
|
||||
match self {
|
||||
Self::Text {
|
||||
@@ -249,6 +248,9 @@ impl SearchQuery {
|
||||
|
||||
let mat = mat.unwrap();
|
||||
if *whole_word {
|
||||
let scope = buffer.language_scope_at(range_offset + mat.start());
|
||||
let kind = |c| char_kind(&scope, c);
|
||||
|
||||
let prev_kind = rope.reversed_chars_at(mat.start()).next().map(kind);
|
||||
let start_kind = kind(rope.chars_at(mat.start()).next().unwrap());
|
||||
let end_kind = kind(rope.reversed_chars_at(mat.end()).next().unwrap());
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
syntax = "proto3";
|
||||
package zed.messages;
|
||||
|
||||
// Looking for a number? Search "// Current max"
|
||||
|
||||
message PeerId {
|
||||
uint32 owner_id = 1;
|
||||
uint32 id = 2;
|
||||
@@ -139,7 +141,7 @@ message Envelope {
|
||||
RespondToChannelInvite respond_to_channel_invite = 123;
|
||||
UpdateChannels update_channels = 124;
|
||||
JoinChannel join_channel = 125;
|
||||
RemoveChannel remove_channel = 126;
|
||||
DeleteChannel delete_channel = 126;
|
||||
GetChannelMembers get_channel_members = 127;
|
||||
GetChannelMembersResponse get_channel_members_response = 128;
|
||||
SetChannelMemberAdmin set_channel_member_admin = 129;
|
||||
@@ -151,6 +153,12 @@ message Envelope {
|
||||
LeaveChannelBuffer leave_channel_buffer = 134;
|
||||
AddChannelBufferCollaborator add_channel_buffer_collaborator = 135;
|
||||
RemoveChannelBufferCollaborator remove_channel_buffer_collaborator = 136;
|
||||
UpdateChannelBufferCollaborator update_channel_buffer_collaborator = 139;
|
||||
RejoinChannelBuffers rejoin_channel_buffers = 140;
|
||||
RejoinChannelBuffersResponse rejoin_channel_buffers_response = 141;
|
||||
LinkChannel link_channel = 142;
|
||||
UnlinkChannel unlink_channel = 143;
|
||||
MoveChannel move_channel = 144; // Current max
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,6 +438,12 @@ message RemoveChannelBufferCollaborator {
|
||||
PeerId peer_id = 2;
|
||||
}
|
||||
|
||||
message UpdateChannelBufferCollaborator {
|
||||
uint64 channel_id = 1;
|
||||
PeerId old_peer_id = 2;
|
||||
PeerId new_peer_id = 3;
|
||||
}
|
||||
|
||||
message GetDefinition {
|
||||
uint64 project_id = 1;
|
||||
uint64 buffer_id = 2;
|
||||
@@ -616,6 +630,12 @@ message BufferVersion {
|
||||
repeated VectorClockEntry version = 2;
|
||||
}
|
||||
|
||||
message ChannelBufferVersion {
|
||||
uint64 channel_id = 1;
|
||||
repeated VectorClockEntry version = 2;
|
||||
uint64 epoch = 3;
|
||||
}
|
||||
|
||||
enum FormatTrigger {
|
||||
Save = 0;
|
||||
Manual = 1;
|
||||
@@ -657,7 +677,8 @@ message Completion {
|
||||
Anchor old_start = 1;
|
||||
Anchor old_end = 2;
|
||||
string new_text = 3;
|
||||
bytes lsp_completion = 4;
|
||||
uint64 server_id = 4;
|
||||
bytes lsp_completion = 5;
|
||||
}
|
||||
|
||||
message GetCodeActions {
|
||||
@@ -860,12 +881,12 @@ message ProjectTransaction {
|
||||
}
|
||||
|
||||
message Transaction {
|
||||
LocalTimestamp id = 1;
|
||||
repeated LocalTimestamp edit_ids = 2;
|
||||
LamportTimestamp id = 1;
|
||||
repeated LamportTimestamp edit_ids = 2;
|
||||
repeated VectorClockEntry start = 3;
|
||||
}
|
||||
|
||||
message LocalTimestamp {
|
||||
message LamportTimestamp {
|
||||
uint32 replica_id = 1;
|
||||
uint32 value = 2;
|
||||
}
|
||||
@@ -927,11 +948,17 @@ message LspDiskBasedDiagnosticsUpdated {}
|
||||
|
||||
message UpdateChannels {
|
||||
repeated Channel channels = 1;
|
||||
repeated uint64 remove_channels = 2;
|
||||
repeated Channel channel_invitations = 3;
|
||||
repeated uint64 remove_channel_invitations = 4;
|
||||
repeated ChannelParticipants channel_participants = 5;
|
||||
repeated ChannelPermission channel_permissions = 6;
|
||||
repeated ChannelEdge delete_channel_edge = 2;
|
||||
repeated uint64 delete_channels = 3;
|
||||
repeated Channel channel_invitations = 4;
|
||||
repeated uint64 remove_channel_invitations = 5;
|
||||
repeated ChannelParticipants channel_participants = 6;
|
||||
repeated ChannelPermission channel_permissions = 7;
|
||||
}
|
||||
|
||||
message ChannelEdge {
|
||||
uint64 channel_id = 1;
|
||||
uint64 parent_id = 2;
|
||||
}
|
||||
|
||||
message ChannelPermission {
|
||||
@@ -948,7 +975,7 @@ message JoinChannel {
|
||||
uint64 channel_id = 1;
|
||||
}
|
||||
|
||||
message RemoveChannel {
|
||||
message DeleteChannel {
|
||||
uint64 channel_id = 1;
|
||||
}
|
||||
|
||||
@@ -1003,16 +1030,48 @@ message RenameChannel {
|
||||
string name = 2;
|
||||
}
|
||||
|
||||
message LinkChannel {
|
||||
uint64 channel_id = 1;
|
||||
uint64 to = 2;
|
||||
}
|
||||
|
||||
message UnlinkChannel {
|
||||
uint64 channel_id = 1;
|
||||
optional uint64 from = 2;
|
||||
}
|
||||
|
||||
message MoveChannel {
|
||||
uint64 channel_id = 1;
|
||||
optional uint64 from = 2;
|
||||
uint64 to = 3;
|
||||
}
|
||||
|
||||
message JoinChannelBuffer {
|
||||
uint64 channel_id = 1;
|
||||
}
|
||||
|
||||
message RejoinChannelBuffers {
|
||||
repeated ChannelBufferVersion buffers = 1;
|
||||
}
|
||||
|
||||
message RejoinChannelBuffersResponse {
|
||||
repeated RejoinedChannelBuffer buffers = 1;
|
||||
}
|
||||
|
||||
message JoinChannelBufferResponse {
|
||||
uint64 buffer_id = 1;
|
||||
uint32 replica_id = 2;
|
||||
string base_text = 3;
|
||||
repeated Operation operations = 4;
|
||||
repeated Collaborator collaborators = 5;
|
||||
uint64 epoch = 6;
|
||||
}
|
||||
|
||||
message RejoinedChannelBuffer {
|
||||
uint64 channel_id = 1;
|
||||
repeated VectorClockEntry version = 2;
|
||||
repeated Operation operations = 3;
|
||||
repeated Collaborator collaborators = 4;
|
||||
}
|
||||
|
||||
message LeaveChannelBuffer {
|
||||
@@ -1279,7 +1338,7 @@ message Excerpt {
|
||||
|
||||
message Anchor {
|
||||
uint32 replica_id = 1;
|
||||
uint32 local_timestamp = 2;
|
||||
uint32 timestamp = 2;
|
||||
uint64 offset = 3;
|
||||
Bias bias = 4;
|
||||
optional uint64 buffer_id = 5;
|
||||
@@ -1323,19 +1382,17 @@ message Operation {
|
||||
|
||||
message Edit {
|
||||
uint32 replica_id = 1;
|
||||
uint32 local_timestamp = 2;
|
||||
uint32 lamport_timestamp = 3;
|
||||
repeated VectorClockEntry version = 4;
|
||||
repeated Range ranges = 5;
|
||||
repeated string new_text = 6;
|
||||
uint32 lamport_timestamp = 2;
|
||||
repeated VectorClockEntry version = 3;
|
||||
repeated Range ranges = 4;
|
||||
repeated string new_text = 5;
|
||||
}
|
||||
|
||||
message Undo {
|
||||
uint32 replica_id = 1;
|
||||
uint32 local_timestamp = 2;
|
||||
uint32 lamport_timestamp = 3;
|
||||
repeated VectorClockEntry version = 4;
|
||||
repeated UndoCount counts = 5;
|
||||
uint32 lamport_timestamp = 2;
|
||||
repeated VectorClockEntry version = 3;
|
||||
repeated UndoCount counts = 4;
|
||||
}
|
||||
|
||||
message UpdateSelections {
|
||||
@@ -1361,7 +1418,7 @@ message UndoMapEntry {
|
||||
|
||||
message UndoCount {
|
||||
uint32 replica_id = 1;
|
||||
uint32 local_timestamp = 2;
|
||||
uint32 lamport_timestamp = 2;
|
||||
uint32 count = 3;
|
||||
}
|
||||
|
||||
|
||||
@@ -229,13 +229,18 @@ messages!(
|
||||
(StartLanguageServer, Foreground),
|
||||
(SynchronizeBuffers, Foreground),
|
||||
(SynchronizeBuffersResponse, Foreground),
|
||||
(RejoinChannelBuffers, Foreground),
|
||||
(RejoinChannelBuffersResponse, Foreground),
|
||||
(Test, Foreground),
|
||||
(Unfollow, Foreground),
|
||||
(UnshareProject, Foreground),
|
||||
(UpdateBuffer, Foreground),
|
||||
(UpdateBufferFile, Foreground),
|
||||
(UpdateContacts, Foreground),
|
||||
(RemoveChannel, Foreground),
|
||||
(DeleteChannel, Foreground),
|
||||
(MoveChannel, Foreground),
|
||||
(LinkChannel, Foreground),
|
||||
(UnlinkChannel, Foreground),
|
||||
(UpdateChannels, Foreground),
|
||||
(UpdateDiagnosticSummary, Foreground),
|
||||
(UpdateFollowers, Foreground),
|
||||
@@ -257,6 +262,7 @@ messages!(
|
||||
(UpdateChannelBuffer, Foreground),
|
||||
(RemoveChannelBufferCollaborator, Foreground),
|
||||
(AddChannelBufferCollaborator, Foreground),
|
||||
(UpdateChannelBufferCollaborator, Foreground),
|
||||
);
|
||||
|
||||
request_messages!(
|
||||
@@ -312,13 +318,17 @@ request_messages!(
|
||||
(SetChannelMemberAdmin, Ack),
|
||||
(GetChannelMembers, GetChannelMembersResponse),
|
||||
(JoinChannel, JoinRoomResponse),
|
||||
(RemoveChannel, Ack),
|
||||
(DeleteChannel, Ack),
|
||||
(RenameProjectEntry, ProjectEntryResponse),
|
||||
(RenameChannel, ChannelResponse),
|
||||
(LinkChannel, Ack),
|
||||
(UnlinkChannel, Ack),
|
||||
(MoveChannel, Ack),
|
||||
(SaveBuffer, BufferSaved),
|
||||
(SearchProject, SearchProjectResponse),
|
||||
(ShareProject, ShareProjectResponse),
|
||||
(SynchronizeBuffers, SynchronizeBuffersResponse),
|
||||
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
|
||||
(Test, Test),
|
||||
(UpdateBuffer, Ack),
|
||||
(UpdateParticipantLocation, Ack),
|
||||
@@ -386,7 +396,8 @@ entity_messages!(
|
||||
channel_id,
|
||||
UpdateChannelBuffer,
|
||||
RemoveChannelBufferCollaborator,
|
||||
AddChannelBufferCollaborator
|
||||
AddChannelBufferCollaborator,
|
||||
UpdateChannelBufferCollaborator
|
||||
);
|
||||
|
||||
const KIB: usize = 1024;
|
||||
|
||||
@@ -6,4 +6,4 @@ pub use conn::Connection;
|
||||
pub use peer::*;
|
||||
mod macros;
|
||||
|
||||
pub const PROTOCOL_VERSION: u32 = 61;
|
||||
pub const PROTOCOL_VERSION: u32 = 62;
|
||||
|
||||
@@ -12,22 +12,19 @@ use editor::{
|
||||
SelectAll, MAX_TAB_TITLE_LEN,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
|
||||
use gpui::platform::PromptLevel;
|
||||
|
||||
use gpui::{
|
||||
actions, elements::*, platform::MouseButton, Action, AnyElement, AnyViewHandle, AppContext,
|
||||
Entity, ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle,
|
||||
WeakModelHandle, WeakViewHandle,
|
||||
actions,
|
||||
elements::*,
|
||||
platform::{MouseButton, PromptLevel},
|
||||
Action, AnyElement, AnyViewHandle, AppContext, Entity, ModelContext, ModelHandle, Subscription,
|
||||
Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
|
||||
};
|
||||
|
||||
use menu::Confirm;
|
||||
use postage::stream::Stream;
|
||||
use project::{
|
||||
search::{PathMatcher, SearchInputs, SearchQuery},
|
||||
Entry, Project,
|
||||
};
|
||||
use semantic_index::SemanticIndex;
|
||||
use semantic_index::{SemanticIndex, SemanticIndexStatus};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
any::{Any, TypeId},
|
||||
@@ -118,7 +115,7 @@ pub struct ProjectSearchView {
|
||||
model: ModelHandle<ProjectSearch>,
|
||||
query_editor: ViewHandle<Editor>,
|
||||
results_editor: ViewHandle<Editor>,
|
||||
semantic_state: Option<SemanticSearchState>,
|
||||
semantic_state: Option<SemanticState>,
|
||||
semantic_permissioned: Option<bool>,
|
||||
search_options: SearchOptions,
|
||||
panels_with_errors: HashSet<InputPanel>,
|
||||
@@ -131,10 +128,9 @@ pub struct ProjectSearchView {
|
||||
current_mode: SearchMode,
|
||||
}
|
||||
|
||||
struct SemanticSearchState {
|
||||
file_count: usize,
|
||||
outstanding_file_count: usize,
|
||||
_progress_task: Task<()>,
|
||||
struct SemanticState {
|
||||
index_status: SemanticIndexStatus,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
pub struct ProjectSearchBar {
|
||||
@@ -233,7 +229,7 @@ impl ProjectSearch {
|
||||
self.search_id += 1;
|
||||
self.match_ranges.clear();
|
||||
self.search_history.add(inputs.as_str().to_string());
|
||||
self.no_results = Some(true);
|
||||
self.no_results = None;
|
||||
self.pending_search = Some(cx.spawn(|this, mut cx| async move {
|
||||
let results = search?.await.log_err()?;
|
||||
let matches = results
|
||||
@@ -241,9 +237,10 @@ impl ProjectSearch {
|
||||
.map(|result| (result.buffer, vec![result.range.start..result.range.start]));
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.no_results = Some(true);
|
||||
this.excerpts.update(cx, |excerpts, cx| {
|
||||
excerpts.clear(cx);
|
||||
})
|
||||
});
|
||||
});
|
||||
for (buffer, ranges) in matches {
|
||||
let mut match_ranges = this.update(&mut cx, |this, cx| {
|
||||
@@ -318,19 +315,20 @@ impl View for ProjectSearchView {
|
||||
}
|
||||
};
|
||||
|
||||
let semantic_status = if let Some(semantic) = &self.semantic_state {
|
||||
if semantic.outstanding_file_count > 0 {
|
||||
format!(
|
||||
"Indexing: {} of {}...",
|
||||
semantic.file_count - semantic.outstanding_file_count,
|
||||
semantic.file_count
|
||||
)
|
||||
} else {
|
||||
"Indexing complete".to_string()
|
||||
let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
|
||||
let status = semantic.index_status;
|
||||
match status {
|
||||
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
|
||||
SemanticIndexStatus::Indexing { remaining_files } => {
|
||||
if remaining_files == 0 {
|
||||
Some(format!("Indexing..."))
|
||||
} else {
|
||||
Some(format!("Remaining files to index: {}", remaining_files))
|
||||
}
|
||||
}
|
||||
SemanticIndexStatus::NotIndexed => None,
|
||||
}
|
||||
} else {
|
||||
"Indexing: ...".to_string()
|
||||
};
|
||||
});
|
||||
|
||||
let minor_text = if let Some(no_results) = model.no_results {
|
||||
if model.pending_search.is_none() && no_results {
|
||||
@@ -340,12 +338,16 @@ impl View for ProjectSearchView {
|
||||
}
|
||||
} else {
|
||||
match current_mode {
|
||||
SearchMode::Semantic => vec![
|
||||
"".to_owned(),
|
||||
semantic_status,
|
||||
"Simply explain the code you are looking to find.".to_owned(),
|
||||
"ex. 'prompt user for permissions to index their project'".to_owned(),
|
||||
],
|
||||
SearchMode::Semantic => {
|
||||
let mut minor_text = Vec::new();
|
||||
minor_text.push("".into());
|
||||
minor_text.extend(semantic_status);
|
||||
minor_text.push("Simply explain the code you are looking to find.".into());
|
||||
minor_text.push(
|
||||
"ex. 'prompt user for permissions to index their project'".into(),
|
||||
);
|
||||
minor_text
|
||||
}
|
||||
_ => vec![
|
||||
"".to_owned(),
|
||||
"Include/exclude specific paths with the filter option.".to_owned(),
|
||||
@@ -641,40 +643,29 @@ impl ProjectSearchView {
|
||||
|
||||
let project = self.model.read(cx).project.clone();
|
||||
|
||||
let index_task = semantic_index.update(cx, |semantic_index, cx| {
|
||||
semantic_index.index_project(project, cx)
|
||||
semantic_index.update(cx, |semantic_index, cx| {
|
||||
semantic_index
|
||||
.index_project(project.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
|
||||
cx.spawn(|search_view, mut cx| async move {
|
||||
let (files_to_index, mut files_remaining_rx) = index_task.await?;
|
||||
self.semantic_state = Some(SemanticState {
|
||||
index_status: semantic_index.read(cx).status(&project),
|
||||
_subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
search_view.update(&mut cx, |search_view, cx| {
|
||||
cx.notify();
|
||||
search_view.semantic_state = Some(SemanticSearchState {
|
||||
file_count: files_to_index,
|
||||
outstanding_file_count: files_to_index,
|
||||
_progress_task: cx.spawn(|search_view, mut cx| async move {
|
||||
while let Some(count) = files_remaining_rx.recv().await {
|
||||
search_view
|
||||
.update(&mut cx, |search_view, cx| {
|
||||
if let Some(semantic_search_state) =
|
||||
&mut search_view.semantic_state
|
||||
{
|
||||
semantic_search_state.outstanding_file_count = count;
|
||||
cx.notify();
|
||||
if count == 0 {
|
||||
return;
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}),
|
||||
});
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
fn semantic_index_changed(
|
||||
&mut self,
|
||||
semantic_index: ModelHandle<SemanticIndex>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
let project = self.model.read(cx).project.clone();
|
||||
if let Some(semantic_state) = self.semantic_state.as_mut() {
|
||||
semantic_state.index_status = semantic_index.read(cx).status(&project);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -873,7 +864,7 @@ impl ProjectSearchView {
|
||||
SemanticIndex::global(cx)
|
||||
.map(|semantic| {
|
||||
let project = self.model.read(cx).project.clone();
|
||||
semantic.update(cx, |this, cx| this.project_previously_indexed(project, cx))
|
||||
semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
|
||||
})
|
||||
.unwrap_or(Task::ready(Ok(false)))
|
||||
}
|
||||
@@ -958,11 +949,7 @@ impl ProjectSearchView {
|
||||
let mode = self.current_mode;
|
||||
match mode {
|
||||
SearchMode::Semantic => {
|
||||
if let Some(semantic) = &mut self.semantic_state {
|
||||
if semantic.outstanding_file_count > 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
if self.semantic_state.is_some() {
|
||||
if let Some(query) = self.build_search_query(cx) {
|
||||
self.model
|
||||
.update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));
|
||||
|
||||
@@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
collections = { path = "../collections" }
|
||||
gpui = { path = "../gpui" }
|
||||
language = { path = "../language" }
|
||||
project = { path = "../project" }
|
||||
@@ -39,8 +40,10 @@ rand.workspace = true
|
||||
schemars.workspace = true
|
||||
globset.workspace = true
|
||||
sha1 = "0.10.5"
|
||||
parse_duration = "2.1.1"
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
language = { path = "../language", features = ["test-support"] }
|
||||
project = { path = "../project", features = ["test-support"] }
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
|
||||
use crate::{
|
||||
embedding::Embedding,
|
||||
parsing::{Span, SpanDigest},
|
||||
SEMANTIC_INDEX_VERSION,
|
||||
};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::executor;
|
||||
use project::{search::PathMatcher, Fs};
|
||||
use rpc::proto::Timestamp;
|
||||
use rusqlite::{
|
||||
params,
|
||||
types::{FromSql, FromSqlResult, ValueRef},
|
||||
};
|
||||
use rusqlite::params;
|
||||
use rusqlite::types::Value;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::HashMap,
|
||||
future::Future,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
time::SystemTime,
|
||||
};
|
||||
use util::TryFutureExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileRecord {
|
||||
@@ -23,286 +29,366 @@ pub struct FileRecord {
|
||||
pub mtime: Timestamp,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Embedding(pub Vec<f32>);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Sha1(pub Vec<u8>);
|
||||
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
return Ok(Embedding(embedding.unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
impl FromSql for Sha1 {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if sha1.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
|
||||
}
|
||||
return Ok(Sha1(sha1.unwrap()));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct VectorDatabase {
|
||||
db: rusqlite::Connection,
|
||||
path: Arc<Path>,
|
||||
transactions:
|
||||
smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
|
||||
}
|
||||
|
||||
impl VectorDatabase {
|
||||
pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
|
||||
pub async fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
path: Arc<Path>,
|
||||
executor: Arc<executor::Background>,
|
||||
) -> Result<Self> {
|
||||
if let Some(db_directory) = path.parent() {
|
||||
fs.create_dir(db_directory).await?;
|
||||
}
|
||||
|
||||
let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
|
||||
Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
|
||||
>();
|
||||
executor
|
||||
.spawn({
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let mut connection = rusqlite::Connection::open(&path)?;
|
||||
|
||||
connection.pragma_update(None, "journal_mode", "wal")?;
|
||||
connection.pragma_update(None, "synchronous", "normal")?;
|
||||
connection.pragma_update(None, "cache_size", 1000000)?;
|
||||
connection.pragma_update(None, "temp_store", "MEMORY")?;
|
||||
|
||||
while let Ok(transaction) = transactions_rx.recv().await {
|
||||
transaction(&mut connection);
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
})
|
||||
.detach();
|
||||
let this = Self {
|
||||
db: rusqlite::Connection::open(path.as_path())?,
|
||||
transactions: transactions_tx,
|
||||
path,
|
||||
};
|
||||
this.initialize_database()?;
|
||||
this.initialize_database().await?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn get_existing_version(&self) -> Result<i64> {
|
||||
let mut version_query = self
|
||||
.db
|
||||
.prepare("SELECT version from semantic_index_config")?;
|
||||
version_query
|
||||
.query_row([], |row| Ok(row.get::<_, i64>(0)?))
|
||||
.map_err(|err| anyhow!("version query failed: {err}"))
|
||||
pub fn path(&self) -> &Arc<Path> {
|
||||
&self.path
|
||||
}
|
||||
|
||||
fn initialize_database(&self) -> Result<()> {
|
||||
rusqlite::vtab::array::load_module(&self.db)?;
|
||||
|
||||
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
|
||||
if self
|
||||
.get_existing_version()
|
||||
.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
|
||||
{
|
||||
log::trace!("vector database schema up to date");
|
||||
return Ok(());
|
||||
fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
|
||||
where
|
||||
F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
|
||||
T: 'static + Send,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let transactions = self.transactions.clone();
|
||||
async move {
|
||||
if transactions
|
||||
.send(Box::new(|connection| {
|
||||
let result = connection
|
||||
.transaction()
|
||||
.map_err(|err| anyhow!(err))
|
||||
.and_then(|transaction| {
|
||||
let result = f(&transaction)?;
|
||||
transaction.commit()?;
|
||||
Ok(result)
|
||||
});
|
||||
let _ = tx.send(result);
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err(anyhow!("connection was dropped"))?;
|
||||
}
|
||||
rx.await?
|
||||
}
|
||||
|
||||
log::trace!("vector database schema out of date. updating...");
|
||||
self.db
|
||||
.execute("DROP TABLE IF EXISTS documents", [])
|
||||
.context("failed to drop 'documents' table")?;
|
||||
self.db
|
||||
.execute("DROP TABLE IF EXISTS files", [])
|
||||
.context("failed to drop 'files' table")?;
|
||||
self.db
|
||||
.execute("DROP TABLE IF EXISTS worktrees", [])
|
||||
.context("failed to drop 'worktrees' table")?;
|
||||
self.db
|
||||
.execute("DROP TABLE IF EXISTS semantic_index_config", [])
|
||||
.context("failed to drop 'semantic_index_config' table")?;
|
||||
|
||||
// Initialize Vector Databasing Tables
|
||||
self.db.execute(
|
||||
"CREATE TABLE semantic_index_config (
|
||||
version INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"INSERT INTO semantic_index_config (version) VALUES (?1)",
|
||||
params![SEMANTIC_INDEX_VERSION],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE worktrees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
absolute_path VARCHAR NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
|
||||
",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
worktree_id INTEGER NOT NULL,
|
||||
relative_path VARCHAR NOT NULL,
|
||||
mtime_seconds INTEGER NOT NULL,
|
||||
mtime_nanos INTEGER NOT NULL,
|
||||
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
self.db.execute(
|
||||
"CREATE TABLE documents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_id INTEGER NOT NULL,
|
||||
start_byte INTEGER NOT NULL,
|
||||
end_byte INTEGER NOT NULL,
|
||||
name VARCHAR NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
sha1 BLOB NOT NULL,
|
||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
|
||||
self.db.execute(
|
||||
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
|
||||
params![worktree_id, delete_path.to_str()],
|
||||
)?;
|
||||
Ok(())
|
||||
fn initialize_database(&self) -> impl Future<Output = Result<()>> {
|
||||
self.transact(|db| {
|
||||
rusqlite::vtab::array::load_module(&db)?;
|
||||
|
||||
// Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
|
||||
let version_query = db.prepare("SELECT version from semantic_index_config");
|
||||
let version = version_query
|
||||
.and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
|
||||
if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
|
||||
log::trace!("vector database schema up to date");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::trace!("vector database schema out of date. updating...");
|
||||
// We renamed the `documents` table to `spans`, so we want to drop
|
||||
// `documents` without recreating it if it exists.
|
||||
db.execute("DROP TABLE IF EXISTS documents", [])
|
||||
.context("failed to drop 'documents' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS spans", [])
|
||||
.context("failed to drop 'spans' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS files", [])
|
||||
.context("failed to drop 'files' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS worktrees", [])
|
||||
.context("failed to drop 'worktrees' table")?;
|
||||
db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
|
||||
.context("failed to drop 'semantic_index_config' table")?;
|
||||
|
||||
// Initialize Vector Databasing Tables
|
||||
db.execute(
|
||||
"CREATE TABLE semantic_index_config (
|
||||
version INTEGER NOT NULL
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"INSERT INTO semantic_index_config (version) VALUES (?1)",
|
||||
params![SEMANTIC_INDEX_VERSION],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE worktrees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
absolute_path VARCHAR NOT NULL
|
||||
);
|
||||
CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
|
||||
",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
worktree_id INTEGER NOT NULL,
|
||||
relative_path VARCHAR NOT NULL,
|
||||
mtime_seconds INTEGER NOT NULL,
|
||||
mtime_nanos INTEGER NOT NULL,
|
||||
FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
db.execute(
|
||||
"CREATE TABLE spans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_id INTEGER NOT NULL,
|
||||
start_byte INTEGER NOT NULL,
|
||||
end_byte INTEGER NOT NULL,
|
||||
name VARCHAR NOT NULL,
|
||||
embedding BLOB NOT NULL,
|
||||
digest BLOB NOT NULL,
|
||||
FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
log::trace!("vector database initialized with updated schema.");
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
delete_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
db.execute(
|
||||
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
|
||||
params![worktree_id, delete_path.to_str()],
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert_file(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
path: PathBuf,
|
||||
path: Arc<Path>,
|
||||
mtime: SystemTime,
|
||||
documents: Vec<Document>,
|
||||
) -> Result<()> {
|
||||
// Return the existing ID, if both the file and mtime match
|
||||
let mtime = Timestamp::from(mtime);
|
||||
let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
|
||||
let existing_id = existing_id_query
|
||||
.query_row(
|
||||
spans: Vec<Span>,
|
||||
) -> impl Future<Output = Result<()>> {
|
||||
self.transact(move |db| {
|
||||
// Return the existing ID, if both the file and mtime match
|
||||
let mtime = Timestamp::from(mtime);
|
||||
|
||||
db.execute(
|
||||
"
|
||||
REPLACE INTO files
|
||||
(worktree_id, relative_path, mtime_seconds, mtime_nanos)
|
||||
VALUES (?1, ?2, ?3, ?4)
|
||||
",
|
||||
params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
|
||||
|row| Ok(row.get::<_, i64>(0)?),
|
||||
)
|
||||
.map_err(|err| anyhow!(err));
|
||||
let file_id = if existing_id.is_ok() {
|
||||
// If already exists, just return the existing id
|
||||
existing_id.unwrap()
|
||||
} else {
|
||||
// Delete Existing Row
|
||||
self.db.execute(
|
||||
"DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
|
||||
params![worktree_id, path.to_str()],
|
||||
)?;
|
||||
self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
|
||||
self.db.last_insert_rowid()
|
||||
};
|
||||
|
||||
// Currently inserting at approximately 3400 documents a second
|
||||
// I imagine we can speed this up with a bulk insert of some kind.
|
||||
for document in documents {
|
||||
let embedding_blob = bincode::serialize(&document.embedding)?;
|
||||
let sha_blob = bincode::serialize(&document.sha1)?;
|
||||
let file_id = db.last_insert_rowid();
|
||||
|
||||
self.db.execute(
|
||||
"INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
INSERT INTO spans
|
||||
(file_id, start_byte, end_byte, name, embedding, digest)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
",
|
||||
)?;
|
||||
|
||||
for span in spans {
|
||||
query.execute(params![
|
||||
file_id,
|
||||
document.range.start.to_string(),
|
||||
document.range.end.to_string(),
|
||||
document.name,
|
||||
embedding_blob,
|
||||
sha_blob
|
||||
],
|
||||
)?;
|
||||
}
|
||||
span.range.start.to_string(),
|
||||
span.range.end.to_string(),
|
||||
span.name,
|
||||
span.embedding,
|
||||
span.digest
|
||||
])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
|
||||
let mut worktree_query = self
|
||||
.db
|
||||
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
|
||||
Ok(row.get::<_, i64>(0)?)
|
||||
})
|
||||
.map_err(|err| anyhow!(err));
|
||||
pub fn worktree_previously_indexed(
|
||||
&self,
|
||||
worktree_root_path: &Path,
|
||||
) -> impl Future<Output = Result<bool>> {
|
||||
let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(true);
|
||||
} else {
|
||||
return Ok(false);
|
||||
}
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(true);
|
||||
} else {
|
||||
return Ok(false);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
|
||||
// Check that the absolute path doesnt exist
|
||||
let mut worktree_query = self
|
||||
.db
|
||||
.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
|
||||
Ok(row.get::<_, i64>(0)?)
|
||||
})
|
||||
.map_err(|err| anyhow!(err));
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return worktree_id;
|
||||
}
|
||||
|
||||
// If worktree_id is Err, insert new worktree
|
||||
self.db.execute(
|
||||
"
|
||||
INSERT into worktrees (absolute_path) VALUES (?1)
|
||||
pub fn embeddings_for_files(
|
||||
&self,
|
||||
worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
|
||||
) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
|
||||
self.transact(move |db| {
|
||||
let mut query = db.prepare(
|
||||
"
|
||||
SELECT digest, embedding
|
||||
FROM spans
|
||||
LEFT JOIN files ON files.id = spans.file_id
|
||||
WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
|
||||
",
|
||||
params![worktree_root_path.to_string_lossy()],
|
||||
)?;
|
||||
Ok(self.db.last_insert_rowid())
|
||||
)?;
|
||||
let mut embeddings_by_digest = HashMap::default();
|
||||
for (worktree_id, file_paths) in worktree_id_file_paths {
|
||||
let file_paths = Rc::new(
|
||||
file_paths
|
||||
.into_iter()
|
||||
.map(|p| Value::Text(p.to_string_lossy().into_owned()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let rows = query.query_map(params![worktree_id, file_paths], |row| {
|
||||
Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?;
|
||||
|
||||
for row in rows {
|
||||
if let Ok(row) = row {
|
||||
embeddings_by_digest.insert(row.0, row.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings_by_digest)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
|
||||
let mut statement = self.db.prepare(
|
||||
"
|
||||
SELECT relative_path, mtime_seconds, mtime_nanos
|
||||
FROM files
|
||||
WHERE worktree_id = ?1
|
||||
ORDER BY relative_path",
|
||||
)?;
|
||||
let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
|
||||
for row in statement.query_map(params![worktree_id], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?.into(),
|
||||
Timestamp {
|
||||
seconds: row.get(1)?,
|
||||
nanos: row.get(2)?,
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
})? {
|
||||
let row = row?;
|
||||
result.insert(row.0, row.1);
|
||||
}
|
||||
Ok(result)
|
||||
pub fn find_or_create_worktree(
|
||||
&self,
|
||||
worktree_root_path: Arc<Path>,
|
||||
) -> impl Future<Output = Result<i64>> {
|
||||
self.transact(move |db| {
|
||||
let mut worktree_query =
|
||||
db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
|
||||
let worktree_id = worktree_query
|
||||
.query_row(params![worktree_root_path.to_string_lossy()], |row| {
|
||||
Ok(row.get::<_, i64>(0)?)
|
||||
});
|
||||
|
||||
if worktree_id.is_ok() {
|
||||
return Ok(worktree_id?);
|
||||
}
|
||||
|
||||
// If worktree_id is Err, insert new worktree
|
||||
db.execute(
|
||||
"INSERT into worktrees (absolute_path) VALUES (?1)",
|
||||
params![worktree_root_path.to_string_lossy()],
|
||||
)?;
|
||||
Ok(db.last_insert_rowid())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_file_mtimes(
|
||||
&self,
|
||||
worktree_id: i64,
|
||||
) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT relative_path, mtime_seconds, mtime_nanos
|
||||
FROM files
|
||||
WHERE worktree_id = ?1
|
||||
ORDER BY relative_path",
|
||||
)?;
|
||||
let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
|
||||
for row in statement.query_map(params![worktree_id], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?.into(),
|
||||
Timestamp {
|
||||
seconds: row.get(1)?,
|
||||
nanos: row.get(2)?,
|
||||
}
|
||||
.into(),
|
||||
))
|
||||
})? {
|
||||
let row = row?;
|
||||
result.insert(row.0, row.1);
|
||||
}
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
query_embedding: &Vec<f32>,
|
||||
query_embedding: &Embedding,
|
||||
limit: usize,
|
||||
file_ids: &[i64],
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
self.for_each_document(file_ids, |id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let ix = match results
|
||||
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
})?;
|
||||
) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
|
||||
let query_embedding = query_embedding.clone();
|
||||
let file_ids = file_ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
Self::for_each_span(db, &file_ids, |id, embedding| {
|
||||
let similarity = embedding.similarity(&query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
})?;
|
||||
|
||||
Ok(results)
|
||||
anyhow::Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn retrieve_included_file_ids(
|
||||
@@ -310,42 +396,51 @@ impl VectorDatabase {
|
||||
worktree_ids: &[i64],
|
||||
includes: &[PathMatcher],
|
||||
excludes: &[PathMatcher],
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut file_query = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, relative_path
|
||||
FROM
|
||||
files
|
||||
WHERE
|
||||
worktree_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
) -> impl Future<Output = Result<Vec<i64>>> {
|
||||
let worktree_ids = worktree_ids.to_vec();
|
||||
let includes = includes.to_vec();
|
||||
let excludes = excludes.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut file_query = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, relative_path
|
||||
FROM
|
||||
files
|
||||
WHERE
|
||||
worktree_id IN rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
let included =
|
||||
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
|
||||
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
|
||||
if included && !excluded {
|
||||
file_ids.push(file_id);
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
let included =
|
||||
includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
|
||||
let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
|
||||
if included && !excluded {
|
||||
file_ids.push(file_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(file_ids)
|
||||
anyhow::Ok(file_ids)
|
||||
})
|
||||
}
|
||||
|
||||
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
|
||||
let mut query_statement = self.db.prepare(
|
||||
fn for_each_span(
|
||||
db: &rusqlite::Connection,
|
||||
file_ids: &[i64],
|
||||
mut f: impl FnMut(i64, Embedding),
|
||||
) -> Result<()> {
|
||||
let mut query_statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
id, embedding
|
||||
FROM
|
||||
documents
|
||||
spans
|
||||
WHERE
|
||||
file_id IN rarray(?)
|
||||
",
|
||||
@@ -356,51 +451,57 @@ impl VectorDatabase {
|
||||
Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
|
||||
})?
|
||||
.filter_map(|row| row.ok())
|
||||
.for_each(|(id, embedding)| f(id, embedding.0));
|
||||
.for_each(|(id, embedding)| f(id, embedding));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
let mut statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
documents.id,
|
||||
files.worktree_id,
|
||||
files.relative_path,
|
||||
documents.start_byte,
|
||||
documents.end_byte
|
||||
FROM
|
||||
documents, files
|
||||
WHERE
|
||||
documents.file_id = files.id AND
|
||||
documents.id in rarray(?)
|
||||
",
|
||||
)?;
|
||||
pub fn spans_for_ids(
|
||||
&self,
|
||||
ids: &[i64],
|
||||
) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
|
||||
let ids = ids.to_vec();
|
||||
self.transact(move |db| {
|
||||
let mut statement = db.prepare(
|
||||
"
|
||||
SELECT
|
||||
spans.id,
|
||||
files.worktree_id,
|
||||
files.relative_path,
|
||||
spans.start_byte,
|
||||
spans.end_byte
|
||||
FROM
|
||||
spans, files
|
||||
WHERE
|
||||
spans.file_id = files.id AND
|
||||
spans.id in rarray(?)
|
||||
",
|
||||
)?;
|
||||
|
||||
let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)?,
|
||||
row.get::<_, i64>(1)?,
|
||||
row.get::<_, String>(2)?.into(),
|
||||
row.get(3)?..row.get(4)?,
|
||||
))
|
||||
})?;
|
||||
let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
|
||||
Ok((
|
||||
row.get::<_, i64>(0)?,
|
||||
row.get::<_, i64>(1)?,
|
||||
row.get::<_, String>(2)?.into(),
|
||||
row.get(3)?..row.get(4)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
|
||||
for row in result_iter {
|
||||
let (id, worktree_id, path, range) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, range));
|
||||
}
|
||||
let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
|
||||
for row in result_iter {
|
||||
let (id, worktree_id, path, range) = row?;
|
||||
values_by_id.insert(id, (worktree_id, path, range));
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(ids.len());
|
||||
for id in ids {
|
||||
let value = values_by_id
|
||||
.remove(id)
|
||||
.ok_or(anyhow!("missing document id {}", id))?;
|
||||
results.push(value);
|
||||
}
|
||||
let mut results = Vec::with_capacity(ids.len());
|
||||
for id in &ids {
|
||||
let value = values_by_id
|
||||
.remove(id)
|
||||
.ok_or(anyhow!("missing span id {}", id))?;
|
||||
results.push(value);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -412,29 +513,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
|
||||
let len = vec_a.len();
|
||||
assert_eq!(len, vec_b.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
vec_a.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
vec_b.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
@@ -7,6 +7,9 @@ use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use lazy_static::lazy_static;
|
||||
use parse_duration::parse;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
@@ -19,6 +22,62 @@ lazy_static! {
|
||||
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(Vec<f32>);
|
||||
|
||||
impl From<Vec<f32>> for Embedding {
|
||||
fn from(value: Vec<f32>) -> Self {
|
||||
Embedding(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> f32 {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
self.0.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
other.0.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
|
||||
if embedding.is_err() {
|
||||
return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
|
||||
}
|
||||
Ok(Embedding(embedding.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAIEmbeddings {
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
@@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage {
|
||||
|
||||
#[async_trait]
|
||||
pub trait EmbeddingProvider: Sync + Send {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn truncate(&self, span: &str) -> (String, usize);
|
||||
}
|
||||
|
||||
pub struct DummyEmbeddings {}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for DummyEmbeddings {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
// 1024 is the OpenAI Embeddings size for ada models.
|
||||
// the model we will likely be starting with.
|
||||
let dummy_vec = vec![0.32 as f32; 1536];
|
||||
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
|
||||
return Ok(vec![dummy_vec; spans.len()]);
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
OPENAI_INPUT_LIMIT
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let token_count = tokens.len();
|
||||
let output = if token_count > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
|
||||
new_input.ok().unwrap_or_else(|| span.to_string())
|
||||
} else {
|
||||
span.to_string()
|
||||
};
|
||||
|
||||
(output, tokens.len())
|
||||
}
|
||||
}
|
||||
|
||||
const OPENAI_INPUT_LIMIT: usize = 8190;
|
||||
|
||||
impl OpenAIEmbeddings {
|
||||
fn truncate(span: String) -> String {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
|
||||
if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
|
||||
if result.is_ok() {
|
||||
let transformed = result.unwrap();
|
||||
return transformed;
|
||||
}
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
|
||||
async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post("https://api.openai.com/v1/embeddings")
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(4))
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
@@ -105,7 +175,26 @@ impl OpenAIEmbeddings {
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
|
||||
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
|
||||
tokens.truncate(OPENAI_INPUT_LIMIT);
|
||||
OPENAI_BPE_TOKENIZER
|
||||
.decode(tokens.clone())
|
||||
.ok()
|
||||
.unwrap_or_else(|| span.to_string())
|
||||
} else {
|
||||
span.to_string()
|
||||
};
|
||||
|
||||
(output, tokens.len())
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
@@ -114,45 +203,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
.ok_or_else(|| anyhow!("no api key"))?;
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut truncated = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(api_key, spans.iter().map(|x| &**x).collect())
|
||||
.send_request(
|
||||
api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
request_number += 1;
|
||||
|
||||
if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
|
||||
return Err(anyhow!(
|
||||
"openai max retries, error: {:?}",
|
||||
&response.status()
|
||||
));
|
||||
}
|
||||
|
||||
match response.status() {
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
log::trace!(
|
||||
"open ai rate limiting, delaying request by {:?} seconds",
|
||||
delay.as_secs()
|
||||
);
|
||||
self.executor.timer(delay).await;
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
// Only truncate if it hasnt been truncated before
|
||||
if !truncated {
|
||||
for span in spans.iter_mut() {
|
||||
*span = Self::truncate(span.clone());
|
||||
}
|
||||
truncated = true;
|
||||
} else {
|
||||
// If failing once already truncated, log the error and break the loop
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
|
||||
break;
|
||||
}
|
||||
StatusCode::REQUEST_TIMEOUT => {
|
||||
request_timeout += 5;
|
||||
}
|
||||
StatusCode::OK => {
|
||||
let mut body = String::new();
|
||||
@@ -163,18 +228,96 @@ impl EmbeddingProvider for OpenAIEmbeddings {
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| embedding.embedding)
|
||||
.map(|embedding| Embedding::from(embedding.embedding))
|
||||
.collect());
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
let delay_duration = {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
if let Some(time_to_reset) =
|
||||
response.headers().get("x-ratelimit-reset-tokens")
|
||||
{
|
||||
if let Ok(time_str) = time_to_reset.to_str() {
|
||||
parse(time_str).unwrap_or(delay)
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
);
|
||||
|
||||
self.executor.timer(delay_duration).await;
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("openai embedding failed {}", response.status()));
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"open ai bad request: {:?} {:?}",
|
||||
&response.status(),
|
||||
body
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow!("openai embedding failed"))
|
||||
Err(anyhow!("openai max retries"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_similarity(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
Embedding::from(vec![1., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
|
||||
0.
|
||||
);
|
||||
assert_eq!(
|
||||
Embedding::from(vec![2., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
|
||||
6.
|
||||
);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
let a = Embedding::from(a);
|
||||
let b = Embedding::from(b);
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(a.similarity(&b), 1),
|
||||
round_to_decimals(reference_dot(&a.0, &b.0), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
165
crates/semantic_index/src/embedding_queue.rs
Normal file
165
crates/semantic_index/src/embedding_queue.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
|
||||
use gpui::executor::Background;
|
||||
use parking_lot::Mutex;
|
||||
use smol::channel;
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileToEmbed {
|
||||
pub worktree_id: i64,
|
||||
pub path: Arc<Path>,
|
||||
pub mtime: SystemTime,
|
||||
pub spans: Vec<Span>,
|
||||
pub job_handle: JobHandle,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for FileToEmbed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("FileToEmbed")
|
||||
.field("worktree_id", &self.worktree_id)
|
||||
.field("path", &self.path)
|
||||
.field("mtime", &self.mtime)
|
||||
.field("spans", &self.spans)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for FileToEmbed {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.worktree_id == other.worktree_id
|
||||
&& self.path == other.path
|
||||
&& self.mtime == other.mtime
|
||||
&& self.spans == other.spans
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EmbeddingQueue {
|
||||
embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
pending_batch: Vec<FileFragmentToEmbed>,
|
||||
executor: Arc<Background>,
|
||||
pending_batch_token_count: usize,
|
||||
finished_files_tx: channel::Sender<FileToEmbed>,
|
||||
finished_files_rx: channel::Receiver<FileToEmbed>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FileFragmentToEmbed {
|
||||
file: Arc<Mutex<FileToEmbed>>,
|
||||
span_range: Range<usize>,
|
||||
}
|
||||
|
||||
impl EmbeddingQueue {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
|
||||
let (finished_files_tx, finished_files_rx) = channel::unbounded();
|
||||
Self {
|
||||
embedding_provider,
|
||||
executor,
|
||||
pending_batch: Vec::new(),
|
||||
pending_batch_token_count: 0,
|
||||
finished_files_tx,
|
||||
finished_files_rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, file: FileToEmbed) {
|
||||
if file.spans.is_empty() {
|
||||
self.finished_files_tx.try_send(file).unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
let file = Arc::new(Mutex::new(file));
|
||||
|
||||
self.pending_batch.push(FileFragmentToEmbed {
|
||||
file: file.clone(),
|
||||
span_range: 0..0,
|
||||
});
|
||||
|
||||
let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
|
||||
for (ix, span) in file.lock().spans.iter().enumerate() {
|
||||
let span_token_count = if span.embedding.is_none() {
|
||||
span.token_count
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let next_token_count = self.pending_batch_token_count + span_token_count;
|
||||
if next_token_count > self.embedding_provider.max_tokens_per_batch() {
|
||||
let range_end = fragment_range.end;
|
||||
self.flush();
|
||||
self.pending_batch.push(FileFragmentToEmbed {
|
||||
file: file.clone(),
|
||||
span_range: range_end..range_end,
|
||||
});
|
||||
fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
|
||||
}
|
||||
|
||||
fragment_range.end = ix + 1;
|
||||
self.pending_batch_token_count += span_token_count;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) {
|
||||
let batch = mem::take(&mut self.pending_batch);
|
||||
self.pending_batch_token_count = 0;
|
||||
if batch.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let finished_files_tx = self.finished_files_tx.clone();
|
||||
let embedding_provider = self.embedding_provider.clone();
|
||||
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let mut spans = Vec::new();
|
||||
for fragment in &batch {
|
||||
let file = fragment.file.lock();
|
||||
spans.extend(
|
||||
file.spans[fragment.span_range.clone()]
|
||||
.iter()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
.map(|d| d.content.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
// If spans is 0, just send the fragment to the finished files if its the last one.
|
||||
if spans.is_empty() {
|
||||
for fragment in batch.clone() {
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
match embedding_provider.embed_batch(spans).await {
|
||||
Ok(embeddings) => {
|
||||
let mut embeddings = embeddings.into_iter();
|
||||
for fragment in batch {
|
||||
for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
|
||||
.iter_mut()
|
||||
.filter(|d| d.embedding.is_none())
|
||||
{
|
||||
if let Some(embedding) = embeddings.next() {
|
||||
span.embedding = Some(embedding);
|
||||
} else {
|
||||
log::error!("number of embeddings != number of documents");
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(file) = Arc::into_inner(fragment.file) {
|
||||
finished_files_tx.try_send(file.into_inner()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("{:?}", error);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
|
||||
self.finished_files_rx.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,10 @@
|
||||
use anyhow::{anyhow, Ok, Result};
|
||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use anyhow::{anyhow, Result};
|
||||
use language::{Grammar, Language};
|
||||
use rusqlite::{
|
||||
types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
|
||||
ToSql,
|
||||
};
|
||||
use sha1::{Digest, Sha1};
|
||||
use std::{
|
||||
cmp::{self, Reverse},
|
||||
@@ -10,13 +15,44 @@ use std::{
|
||||
};
|
||||
use tree_sitter::{Parser, QueryCursor};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub struct SpanDigest([u8; 20]);
|
||||
|
||||
impl FromSql for SpanDigest {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let blob = value.as_blob()?;
|
||||
let bytes =
|
||||
blob.try_into()
|
||||
.map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
|
||||
expected_size: 20,
|
||||
blob_size: blob.len(),
|
||||
})?;
|
||||
return Ok(SpanDigest(bytes));
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for SpanDigest {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
self.0.to_sql()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&'_ str> for SpanDigest {
|
||||
fn from(value: &'_ str) -> Self {
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(value);
|
||||
Self(sha1.finalize().into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Document {
|
||||
pub struct Span {
|
||||
pub name: String,
|
||||
pub range: Range<usize>,
|
||||
pub content: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub sha1: [u8; 20],
|
||||
pub embedding: Option<Embedding>,
|
||||
pub digest: SpanDigest,
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
const CODE_CONTEXT_TEMPLATE: &str =
|
||||
@@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
|
||||
pub struct CodeContextRetriever {
|
||||
pub parser: Parser,
|
||||
pub cursor: QueryCursor,
|
||||
pub embedding_provider: Arc<dyn EmbeddingProvider>,
|
||||
}
|
||||
|
||||
// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
|
||||
@@ -47,10 +84,11 @@ pub struct CodeContextMatch {
|
||||
}
|
||||
|
||||
impl CodeContextRetriever {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
|
||||
Self {
|
||||
parser: Parser::new(),
|
||||
cursor: QueryCursor::new(),
|
||||
embedding_provider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,38 +97,36 @@ impl CodeContextRetriever {
|
||||
relative_path: &Path,
|
||||
language_name: Arc<str>,
|
||||
content: &str,
|
||||
) -> Result<Vec<Document>> {
|
||||
) -> Result<Vec<Span>> {
|
||||
let document_span = ENTIRE_FILE_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("<item>", &content);
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_span);
|
||||
|
||||
Ok(vec![Document {
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Vec::new(),
|
||||
embedding: Default::default(),
|
||||
name: language_name.to_string(),
|
||||
sha1: sha1.finalize().into(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
|
||||
fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
|
||||
let document_span = MARKDOWN_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace("<item>", &content);
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_span);
|
||||
|
||||
Ok(vec![Document {
|
||||
let digest = SpanDigest::from(document_span.as_str());
|
||||
let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
|
||||
Ok(vec![Span {
|
||||
range: 0..content.len(),
|
||||
content: document_span,
|
||||
embedding: Vec::new(),
|
||||
embedding: None,
|
||||
name: "Markdown".to_string(),
|
||||
sha1: sha1.finalize().into(),
|
||||
digest,
|
||||
token_count,
|
||||
}])
|
||||
}
|
||||
|
||||
@@ -155,26 +191,32 @@ impl CodeContextRetriever {
|
||||
relative_path: &Path,
|
||||
content: &str,
|
||||
language: Arc<Language>,
|
||||
) -> Result<Vec<Document>> {
|
||||
) -> Result<Vec<Span>> {
|
||||
let language_name = language.name();
|
||||
|
||||
if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
|
||||
return self.parse_entire_file(relative_path, language_name, &content);
|
||||
} else if &language_name.to_string() == &"Markdown".to_string() {
|
||||
} else if language_name.as_ref() == "Markdown" {
|
||||
return self.parse_markdown_file(relative_path, &content);
|
||||
}
|
||||
|
||||
let mut documents = self.parse_file(content, language)?;
|
||||
for document in &mut documents {
|
||||
document.content = CODE_CONTEXT_TEMPLATE
|
||||
let mut spans = self.parse_file(content, language)?;
|
||||
for span in &mut spans {
|
||||
let document_content = CODE_CONTEXT_TEMPLATE
|
||||
.replace("<path>", relative_path.to_string_lossy().as_ref())
|
||||
.replace("<language>", language_name.as_ref())
|
||||
.replace("item", &document.content);
|
||||
.replace("item", &span.content);
|
||||
|
||||
let (document_content, token_count) =
|
||||
self.embedding_provider.truncate(&document_content);
|
||||
|
||||
span.content = document_content;
|
||||
span.token_count = token_count;
|
||||
}
|
||||
Ok(documents)
|
||||
Ok(spans)
|
||||
}
|
||||
|
||||
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
|
||||
pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
|
||||
let grammar = language
|
||||
.grammar()
|
||||
.ok_or_else(|| anyhow!("no grammar for language"))?;
|
||||
@@ -185,7 +227,7 @@ impl CodeContextRetriever {
|
||||
let language_scope = language.default_scope();
|
||||
let placeholder = language_scope.collapsed_placeholder();
|
||||
|
||||
let mut documents = Vec::new();
|
||||
let mut spans = Vec::new();
|
||||
let mut collapsed_ranges_within = Vec::new();
|
||||
let mut parsed_name_ranges = HashSet::new();
|
||||
for (i, context_match) in matches.iter().enumerate() {
|
||||
@@ -225,22 +267,22 @@ impl CodeContextRetriever {
|
||||
|
||||
collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
|
||||
|
||||
let mut document_content = String::new();
|
||||
let mut span_content = String::new();
|
||||
for context_range in &context_match.context_ranges {
|
||||
add_content_from_range(
|
||||
&mut document_content,
|
||||
&mut span_content,
|
||||
content,
|
||||
context_range.clone(),
|
||||
context_match.start_col,
|
||||
);
|
||||
document_content.push_str("\n");
|
||||
span_content.push_str("\n");
|
||||
}
|
||||
|
||||
let mut offset = item_range.start;
|
||||
for collapsed_range in &collapsed_ranges_within {
|
||||
if collapsed_range.start > offset {
|
||||
add_content_from_range(
|
||||
&mut document_content,
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..collapsed_range.start,
|
||||
context_match.start_col,
|
||||
@@ -249,33 +291,32 @@ impl CodeContextRetriever {
|
||||
}
|
||||
|
||||
if collapsed_range.end > offset {
|
||||
document_content.push_str(placeholder);
|
||||
span_content.push_str(placeholder);
|
||||
offset = collapsed_range.end;
|
||||
}
|
||||
}
|
||||
|
||||
if offset < item_range.end {
|
||||
add_content_from_range(
|
||||
&mut document_content,
|
||||
&mut span_content,
|
||||
content,
|
||||
offset..item_range.end,
|
||||
context_match.start_col,
|
||||
);
|
||||
}
|
||||
|
||||
let mut sha1 = Sha1::new();
|
||||
sha1.update(&document_content);
|
||||
|
||||
documents.push(Document {
|
||||
let sha1 = SpanDigest::from(span_content.as_str());
|
||||
spans.push(Span {
|
||||
name,
|
||||
content: document_content,
|
||||
content: span_content,
|
||||
range: item_range.clone(),
|
||||
embedding: vec![],
|
||||
sha1: sha1.finalize().into(),
|
||||
embedding: None,
|
||||
digest: sha1,
|
||||
token_count: 0,
|
||||
})
|
||||
}
|
||||
|
||||
return Ok(documents);
|
||||
return Ok(spans);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,15 @@
|
||||
use crate::{
|
||||
db::dot,
|
||||
embedding::EmbeddingProvider,
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Document},
|
||||
embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
|
||||
embedding_queue::EmbeddingQueue,
|
||||
parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
|
||||
semantic_index_settings::SemanticIndexSettings,
|
||||
SearchResult, SemanticIndex,
|
||||
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gpui::{Task, TestAppContext};
|
||||
use gpui::{executor::Deterministic, Task, TestAppContext};
|
||||
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
@@ -20,8 +21,10 @@ use std::{
|
||||
atomic::{self, AtomicUsize},
|
||||
Arc,
|
||||
},
|
||||
time::SystemTime,
|
||||
};
|
||||
use unindent::Unindent;
|
||||
use util::RandomCharIter;
|
||||
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
@@ -31,12 +34,8 @@ fn init_logger() {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
settings::register::<ProjectSettings>(cx);
|
||||
});
|
||||
async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.background());
|
||||
fs.insert_tree(
|
||||
@@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
fn bbb() {
|
||||
println!(\"bbbbbbbbbbbbb!\");
|
||||
}
|
||||
struct pqpqpqp {}
|
||||
".unindent(),
|
||||
"file3.toml": "
|
||||
ZZZZZZZZZZZZZZZZZZ = 5
|
||||
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
let db_path = db_dir.path().join("db.sqlite");
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
let store = SemanticIndex::new(
|
||||
let semantic_index = SemanticIndex::new(
|
||||
fs.clone(),
|
||||
db_path,
|
||||
embedding_provider.clone(),
|
||||
@@ -87,34 +87,24 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
|
||||
let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
|
||||
|
||||
let _ = store
|
||||
.update(cx, |store, cx| {
|
||||
store.initialize_project(project.clone(), cx)
|
||||
})
|
||||
.await;
|
||||
|
||||
let (file_count, outstanding_file_count) = store
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(file_count, 3);
|
||||
cx.foreground().run_until_parked();
|
||||
assert_eq!(*outstanding_file_count.borrow(), 0);
|
||||
|
||||
let search_results = store
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
"aaaaaabbbbzz".to_string(),
|
||||
5,
|
||||
vec![],
|
||||
vec![],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let search_results = semantic_index.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
"aaaaaabbbbzz".to_string(),
|
||||
5,
|
||||
vec![],
|
||||
vec![],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let pending_file_count =
|
||||
semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
|
||||
deterministic.run_until_parked();
|
||||
assert_eq!(*pending_file_count.borrow(), 3);
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
assert_eq!(*pending_file_count.borrow(), 0);
|
||||
|
||||
let search_results = search_results.await.unwrap();
|
||||
assert_search_results(
|
||||
&search_results,
|
||||
&[
|
||||
@@ -122,6 +112,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
(Path::new("src/file2.rs").into(), 0),
|
||||
(Path::new("src/file3.toml").into(), 0),
|
||||
(Path::new("src/file1.rs").into(), 45),
|
||||
(Path::new("src/file2.rs").into(), 45),
|
||||
],
|
||||
cx,
|
||||
);
|
||||
@@ -129,7 +120,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
// Test Include Files Functonality
|
||||
let include_files = vec![PathMatcher::new("*.rs").unwrap()];
|
||||
let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
|
||||
let rust_only_search_results = store
|
||||
let rust_only_search_results = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
@@ -149,11 +140,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
(Path::new("src/file1.rs").into(), 0),
|
||||
(Path::new("src/file2.rs").into(), 0),
|
||||
(Path::new("src/file1.rs").into(), 45),
|
||||
(Path::new("src/file2.rs").into(), 45),
|
||||
],
|
||||
cx,
|
||||
);
|
||||
|
||||
let no_rust_search_results = store
|
||||
let no_rust_search_results = semantic_index
|
||||
.update(cx, |store, cx| {
|
||||
store.search_project(
|
||||
project.clone(),
|
||||
@@ -186,24 +178,85 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
|
||||
let prev_embedding_count = embedding_provider.embedding_count();
|
||||
let (file_count, outstanding_file_count) = store
|
||||
.update(cx, |store, cx| store.index_project(project.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(file_count, 1);
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
assert_eq!(*outstanding_file_count.borrow(), 0);
|
||||
let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
|
||||
deterministic.run_until_parked();
|
||||
assert_eq!(*pending_file_count.borrow(), 1);
|
||||
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
|
||||
assert_eq!(*pending_file_count.borrow(), 0);
|
||||
index.await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
embedding_provider.embedding_count() - prev_embedding_count,
|
||||
2
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
let (outstanding_job_count, _) = postage::watch::channel_with(0);
|
||||
let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
|
||||
|
||||
let files = (1..=3)
|
||||
.map(|file_ix| FileToEmbed {
|
||||
worktree_id: 5,
|
||||
path: Path::new(&format!("path-{file_ix}")).into(),
|
||||
mtime: SystemTime::now(),
|
||||
spans: (0..rng.gen_range(4..22))
|
||||
.map(|document_ix| {
|
||||
let content_len = rng.gen_range(10..100);
|
||||
let content = RandomCharIter::new(&mut rng)
|
||||
.with_simple_text()
|
||||
.take(content_len)
|
||||
.collect::<String>();
|
||||
let digest = SpanDigest::from(content.as_str());
|
||||
Span {
|
||||
range: 0..10,
|
||||
embedding: None,
|
||||
name: format!("document {document_ix}"),
|
||||
content,
|
||||
digest,
|
||||
token_count: rng.gen_range(10..30),
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
job_handle: JobHandle::new(&outstanding_job_count),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
|
||||
|
||||
let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
|
||||
for file in &files {
|
||||
queue.push(file.clone());
|
||||
}
|
||||
queue.flush();
|
||||
|
||||
cx.foreground().run_until_parked();
|
||||
let finished_files = queue.finished_files();
|
||||
let mut embedded_files: Vec<_> = files
|
||||
.iter()
|
||||
.map(|_| finished_files.try_recv().expect("no finished file"))
|
||||
.collect();
|
||||
|
||||
let expected_files: Vec<_> = files
|
||||
.iter()
|
||||
.map(|file| {
|
||||
let mut file = file.clone();
|
||||
for doc in &mut file.spans {
|
||||
doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
|
||||
}
|
||||
file
|
||||
})
|
||||
.collect();
|
||||
|
||||
embedded_files.sort_by_key(|f| f.path.clone());
|
||||
|
||||
assert_eq!(embedded_files, expected_files);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn assert_search_results(
|
||||
actual: &[SearchResult],
|
||||
@@ -227,7 +280,8 @@ fn assert_search_results(
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_rust() {
|
||||
let language = rust_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
/// A doc comment
|
||||
@@ -314,7 +368,8 @@ async fn test_code_context_retrieval_rust() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_json() {
|
||||
let language = json_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
{
|
||||
@@ -382,7 +437,7 @@ async fn test_code_context_retrieval_json() {
|
||||
}
|
||||
|
||||
fn assert_documents_eq(
|
||||
documents: &[Document],
|
||||
documents: &[Span],
|
||||
expected_contents_and_start_offsets: &[(String, usize)],
|
||||
) {
|
||||
assert_eq!(
|
||||
@@ -397,7 +452,8 @@ fn assert_documents_eq(
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_javascript() {
|
||||
let language = js_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
/* globals importScripts, backend */
|
||||
@@ -495,7 +551,8 @@ async fn test_code_context_retrieval_javascript() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_lua() {
|
||||
let language = lua_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
-- Creates a new class
|
||||
@@ -568,7 +625,8 @@ async fn test_code_context_retrieval_lua() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_elixir() {
|
||||
let language = elixir_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
defmodule File.Stream do
|
||||
@@ -684,7 +742,8 @@ async fn test_code_context_retrieval_elixir() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_cpp() {
|
||||
let language = cpp_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = "
|
||||
/**
|
||||
@@ -836,7 +895,8 @@ async fn test_code_context_retrieval_cpp() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_ruby() {
|
||||
let language = ruby_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
# This concern is inspired by "sudo mode" on GitHub. It
|
||||
@@ -1026,7 +1086,8 @@ async fn test_code_context_retrieval_ruby() {
|
||||
#[gpui::test]
|
||||
async fn test_code_context_retrieval_php() {
|
||||
let language = php_lang();
|
||||
let mut retriever = CodeContextRetriever::new();
|
||||
let embedding_provider = Arc::new(DummyEmbeddings {});
|
||||
let mut retriever = CodeContextRetriever::new(embedding_provider);
|
||||
|
||||
let text = r#"
|
||||
<?php
|
||||
@@ -1173,36 +1234,6 @@ async fn test_code_context_retrieval_php() {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_dot_product(mut rng: StdRng) {
|
||||
assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
|
||||
assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(dot(&a, &b), 1),
|
||||
round_to_decimals(reference_dot(&a, &b), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
|
||||
let factor = (10.0 as f32).powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
|
||||
a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct FakeEmbeddingProvider {
|
||||
embedding_count: AtomicUsize,
|
||||
@@ -1212,35 +1243,42 @@ impl FakeEmbeddingProvider {
|
||||
fn embedding_count(&self) -> usize {
|
||||
self.embedding_count.load(atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn embed_sync(&self, span: &str) -> Embedding {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for FakeEmbeddingProvider {
|
||||
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
|
||||
fn truncate(&self, span: &str) -> (String, usize) {
|
||||
(span.to_string(), 1)
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
200
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
self.embedding_count
|
||||
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
|
||||
Ok(spans
|
||||
.iter()
|
||||
.map(|span| {
|
||||
let mut result = vec![1.0; 26];
|
||||
for letter in span.chars() {
|
||||
let letter = letter.to_ascii_lowercase();
|
||||
if letter as u32 >= 'a' as u32 {
|
||||
let ix = (letter as u32) - ('a' as u32);
|
||||
if ix < 26 {
|
||||
result[ix as usize] += 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
for x in &mut result {
|
||||
*x /= norm;
|
||||
}
|
||||
|
||||
result
|
||||
})
|
||||
.collect())
|
||||
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1684,3 +1722,11 @@ fn test_subtract_ranges() {
|
||||
|
||||
assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
settings::register::<SemanticIndexSettings>(cx);
|
||||
settings::register::<ProjectSettings>(cx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@ Design notes:
|
||||
|
||||
This crate is split into two conceptual halves:
|
||||
- The terminal.rs file and the src/mappings/ folder, these contain the code for interacting with Alacritty and maintaining the pty event loop. Some behavior in this file is constrained by terminal protocols and standards. The Zed init function is also placed here.
|
||||
- Everything else. These other files integrate the `Terminal` struct created in terminal.rs into the rest of GPUI. The main entry point for GPUI is the terminal_view.rs file and the modal.rs file.
|
||||
- Everything else. These other files integrate the `Terminal` struct created in terminal.rs into the rest of GPUI. The main entry point for GPUI is the terminal_view.rs file and the modal.rs file.
|
||||
|
||||
ttys are created externally, and so can fail in unexpected ways. However, GPUI currently does not have an API for models than can fail to instantiate. `TerminalBuilder` solves this by using Rust's type system to split tty instantiation into a 2 step process: first attempt to create the file handles with `TerminalBuilder::new()`, check the result, then call `TerminalBuilder::subscribe(cx)` from within a model context.
|
||||
|
||||
The TerminalView struct abstracts over failed and successful terminals, passing focus through to the associated view and allowing clients to build a terminal without worrying about errors.
|
||||
|
||||
#Input
|
||||
#Input
|
||||
|
||||
There are currently many distinct paths for getting keystrokes to the terminal:
|
||||
|
||||
@@ -18,6 +18,6 @@ There are currently many distinct paths for getting keystrokes to the terminal:
|
||||
|
||||
3. IME text. When the special character mappings fail, we pass the keystroke back to GPUI to hand it to the IME system. This comes back to us in the `View::replace_text_in_range()` method, and we then send that to the terminal directly, bypassing `try_keystroke()`.
|
||||
|
||||
4. Pasted text has a separate pathway.
|
||||
4. Pasted text has a separate pathway.
|
||||
|
||||
Generally, there's a distinction between 'keystrokes that need to be mapped' and 'strings which need to be written'. I've attempted to unify these under the '.try_keystroke()' API and the `.input()` API (which try_keystroke uses) so we have consistent input handling across the terminal
|
||||
Generally, there's a distinction between 'keystrokes that need to be mapped' and 'strings which need to be written'. I've attempted to unify these under the '.try_keystroke()' API and the `.input()` API (which try_keystroke uses) so we have consistent input handling across the terminal
|
||||
|
||||
@@ -283,7 +283,12 @@ impl TerminalView {
|
||||
pub fn deploy_context_menu(&mut self, position: Vector2F, cx: &mut ViewContext<Self>) {
|
||||
let menu_entries = vec![
|
||||
ContextMenuItem::action("Clear", Clear),
|
||||
ContextMenuItem::action("Close", pane::CloseActiveItem),
|
||||
ContextMenuItem::action(
|
||||
"Close",
|
||||
pane::CloseActiveItem {
|
||||
save_behavior: None,
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
self.context_menu.update(cx, |menu, cx| {
|
||||
|
||||
@@ -31,6 +31,7 @@ regex.workspace = true
|
||||
[dev-dependencies]
|
||||
collections = { path = "../collections", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
util = { path = "../util", features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
env_logger.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -8,7 +8,7 @@ use sum_tree::Bias;
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash, Default)]
|
||||
pub struct Anchor {
|
||||
pub timestamp: clock::Local,
|
||||
pub timestamp: clock::Lamport,
|
||||
pub offset: usize,
|
||||
pub bias: Bias,
|
||||
pub buffer_id: Option<u64>,
|
||||
@@ -16,14 +16,14 @@ pub struct Anchor {
|
||||
|
||||
impl Anchor {
|
||||
pub const MIN: Self = Self {
|
||||
timestamp: clock::Local::MIN,
|
||||
timestamp: clock::Lamport::MIN,
|
||||
offset: usize::MIN,
|
||||
bias: Bias::Left,
|
||||
buffer_id: None,
|
||||
};
|
||||
|
||||
pub const MAX: Self = Self {
|
||||
timestamp: clock::Local::MAX,
|
||||
timestamp: clock::Lamport::MAX,
|
||||
offset: usize::MAX,
|
||||
bias: Bias::Right,
|
||||
buffer_id: None,
|
||||
|
||||
@@ -46,18 +46,16 @@ lazy_static! {
|
||||
static ref LINE_SEPARATORS_REGEX: Regex = Regex::new("\r\n|\r|\u{2028}|\u{2029}").unwrap();
|
||||
}
|
||||
|
||||
pub type TransactionId = clock::Local;
|
||||
pub type TransactionId = clock::Lamport;
|
||||
|
||||
pub struct Buffer {
|
||||
snapshot: BufferSnapshot,
|
||||
history: History,
|
||||
deferred_ops: OperationQueue<Operation>,
|
||||
deferred_replicas: HashSet<ReplicaId>,
|
||||
replica_id: ReplicaId,
|
||||
local_clock: clock::Local,
|
||||
pub lamport_clock: clock::Lamport,
|
||||
subscriptions: Topic,
|
||||
edit_id_resolvers: HashMap<clock::Local, Vec<oneshot::Sender<()>>>,
|
||||
edit_id_resolvers: HashMap<clock::Lamport, Vec<oneshot::Sender<()>>>,
|
||||
wait_for_version_txs: Vec<(clock::Global, oneshot::Sender<()>)>,
|
||||
}
|
||||
|
||||
@@ -85,7 +83,7 @@ pub struct HistoryEntry {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Transaction {
|
||||
pub id: TransactionId,
|
||||
pub edit_ids: Vec<clock::Local>,
|
||||
pub edit_ids: Vec<clock::Lamport>,
|
||||
pub start: clock::Global,
|
||||
}
|
||||
|
||||
@@ -97,8 +95,8 @@ impl HistoryEntry {
|
||||
|
||||
struct History {
|
||||
base_text: Rope,
|
||||
operations: TreeMap<clock::Local, Operation>,
|
||||
insertion_slices: HashMap<clock::Local, Vec<InsertionSlice>>,
|
||||
operations: TreeMap<clock::Lamport, Operation>,
|
||||
insertion_slices: HashMap<clock::Lamport, Vec<InsertionSlice>>,
|
||||
undo_stack: Vec<HistoryEntry>,
|
||||
redo_stack: Vec<HistoryEntry>,
|
||||
transaction_depth: usize,
|
||||
@@ -107,7 +105,7 @@ struct History {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct InsertionSlice {
|
||||
insertion_id: clock::Local,
|
||||
insertion_id: clock::Lamport,
|
||||
range: Range<usize>,
|
||||
}
|
||||
|
||||
@@ -129,18 +127,18 @@ impl History {
|
||||
}
|
||||
|
||||
fn push(&mut self, op: Operation) {
|
||||
self.operations.insert(op.local_timestamp(), op);
|
||||
self.operations.insert(op.timestamp(), op);
|
||||
}
|
||||
|
||||
fn start_transaction(
|
||||
&mut self,
|
||||
start: clock::Global,
|
||||
now: Instant,
|
||||
local_clock: &mut clock::Local,
|
||||
clock: &mut clock::Lamport,
|
||||
) -> Option<TransactionId> {
|
||||
self.transaction_depth += 1;
|
||||
if self.transaction_depth == 1 {
|
||||
let id = local_clock.tick();
|
||||
let id = clock.tick();
|
||||
self.undo_stack.push(HistoryEntry {
|
||||
transaction: Transaction {
|
||||
id,
|
||||
@@ -251,7 +249,7 @@ impl History {
|
||||
self.redo_stack.clear();
|
||||
}
|
||||
|
||||
fn push_undo(&mut self, op_id: clock::Local) {
|
||||
fn push_undo(&mut self, op_id: clock::Lamport) {
|
||||
assert_ne!(self.transaction_depth, 0);
|
||||
if let Some(Operation::Edit(_)) = self.operations.get(&op_id) {
|
||||
let last_transaction = self.undo_stack.last_mut().unwrap();
|
||||
@@ -412,37 +410,14 @@ impl<D1, D2> Edit<(D1, D2)> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
|
||||
pub struct InsertionTimestamp {
|
||||
pub replica_id: ReplicaId,
|
||||
pub local: clock::Seq,
|
||||
pub lamport: clock::Seq,
|
||||
}
|
||||
|
||||
impl InsertionTimestamp {
|
||||
pub fn local(&self) -> clock::Local {
|
||||
clock::Local {
|
||||
replica_id: self.replica_id,
|
||||
value: self.local,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lamport(&self) -> clock::Lamport {
|
||||
clock::Lamport {
|
||||
replica_id: self.replica_id,
|
||||
value: self.lamport,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq, Clone, Debug)]
|
||||
pub struct Fragment {
|
||||
pub id: Locator,
|
||||
pub insertion_timestamp: InsertionTimestamp,
|
||||
pub timestamp: clock::Lamport,
|
||||
pub insertion_offset: usize,
|
||||
pub len: usize,
|
||||
pub visible: bool,
|
||||
pub deletions: HashSet<clock::Local>,
|
||||
pub deletions: HashSet<clock::Lamport>,
|
||||
pub max_undos: clock::Global,
|
||||
}
|
||||
|
||||
@@ -470,29 +445,26 @@ impl<'a> sum_tree::Dimension<'a, FragmentSummary> for FragmentTextSummary {
|
||||
|
||||
#[derive(Eq, PartialEq, Clone, Debug)]
|
||||
struct InsertionFragment {
|
||||
timestamp: clock::Local,
|
||||
timestamp: clock::Lamport,
|
||||
split_offset: usize,
|
||||
fragment_id: Locator,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct InsertionFragmentKey {
|
||||
timestamp: clock::Local,
|
||||
timestamp: clock::Lamport,
|
||||
split_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum Operation {
|
||||
Edit(EditOperation),
|
||||
Undo {
|
||||
undo: UndoOperation,
|
||||
lamport_timestamp: clock::Lamport,
|
||||
},
|
||||
Undo(UndoOperation),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct EditOperation {
|
||||
pub timestamp: InsertionTimestamp,
|
||||
pub timestamp: clock::Lamport,
|
||||
pub version: clock::Global,
|
||||
pub ranges: Vec<Range<FullOffset>>,
|
||||
pub new_text: Vec<Arc<str>>,
|
||||
@@ -500,9 +472,9 @@ pub struct EditOperation {
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct UndoOperation {
|
||||
pub id: clock::Local,
|
||||
pub counts: HashMap<clock::Local, u32>,
|
||||
pub timestamp: clock::Lamport,
|
||||
pub version: clock::Global,
|
||||
pub counts: HashMap<clock::Lamport, u32>,
|
||||
}
|
||||
|
||||
impl Buffer {
|
||||
@@ -514,24 +486,21 @@ impl Buffer {
|
||||
let mut fragments = SumTree::new();
|
||||
let mut insertions = SumTree::new();
|
||||
|
||||
let mut local_clock = clock::Local::new(replica_id);
|
||||
let mut lamport_clock = clock::Lamport::new(replica_id);
|
||||
let mut version = clock::Global::new();
|
||||
|
||||
let visible_text = history.base_text.clone();
|
||||
if !visible_text.is_empty() {
|
||||
let insertion_timestamp = InsertionTimestamp {
|
||||
let insertion_timestamp = clock::Lamport {
|
||||
replica_id: 0,
|
||||
local: 1,
|
||||
lamport: 1,
|
||||
value: 1,
|
||||
};
|
||||
local_clock.observe(insertion_timestamp.local());
|
||||
lamport_clock.observe(insertion_timestamp.lamport());
|
||||
version.observe(insertion_timestamp.local());
|
||||
lamport_clock.observe(insertion_timestamp);
|
||||
version.observe(insertion_timestamp);
|
||||
let fragment_id = Locator::between(&Locator::min(), &Locator::max());
|
||||
let fragment = Fragment {
|
||||
id: fragment_id,
|
||||
insertion_timestamp,
|
||||
timestamp: insertion_timestamp,
|
||||
insertion_offset: 0,
|
||||
len: visible_text.len(),
|
||||
visible: true,
|
||||
@@ -557,8 +526,6 @@ impl Buffer {
|
||||
history,
|
||||
deferred_ops: OperationQueue::new(),
|
||||
deferred_replicas: HashSet::default(),
|
||||
replica_id,
|
||||
local_clock,
|
||||
lamport_clock,
|
||||
subscriptions: Default::default(),
|
||||
edit_id_resolvers: Default::default(),
|
||||
@@ -575,7 +542,7 @@ impl Buffer {
|
||||
}
|
||||
|
||||
pub fn replica_id(&self) -> ReplicaId {
|
||||
self.local_clock.replica_id
|
||||
self.lamport_clock.replica_id
|
||||
}
|
||||
|
||||
pub fn remote_id(&self) -> u64 {
|
||||
@@ -602,16 +569,12 @@ impl Buffer {
|
||||
.map(|(range, new_text)| (range, new_text.into()));
|
||||
|
||||
self.start_transaction();
|
||||
let timestamp = InsertionTimestamp {
|
||||
replica_id: self.replica_id,
|
||||
local: self.local_clock.tick().value,
|
||||
lamport: self.lamport_clock.tick().value,
|
||||
};
|
||||
let timestamp = self.lamport_clock.tick();
|
||||
let operation = Operation::Edit(self.apply_local_edit(edits, timestamp));
|
||||
|
||||
self.history.push(operation.clone());
|
||||
self.history.push_undo(operation.local_timestamp());
|
||||
self.snapshot.version.observe(operation.local_timestamp());
|
||||
self.history.push_undo(operation.timestamp());
|
||||
self.snapshot.version.observe(operation.timestamp());
|
||||
self.end_transaction();
|
||||
operation
|
||||
}
|
||||
@@ -619,7 +582,7 @@ impl Buffer {
|
||||
fn apply_local_edit<S: ToOffset, T: Into<Arc<str>>>(
|
||||
&mut self,
|
||||
edits: impl ExactSizeIterator<Item = (Range<S>, T)>,
|
||||
timestamp: InsertionTimestamp,
|
||||
timestamp: clock::Lamport,
|
||||
) -> EditOperation {
|
||||
let mut edits_patch = Patch::default();
|
||||
let mut edit_op = EditOperation {
|
||||
@@ -696,7 +659,7 @@ impl Buffer {
|
||||
.item()
|
||||
.map_or(&Locator::max(), |old_fragment| &old_fragment.id),
|
||||
),
|
||||
insertion_timestamp: timestamp,
|
||||
timestamp,
|
||||
insertion_offset,
|
||||
len: new_text.len(),
|
||||
deletions: Default::default(),
|
||||
@@ -726,7 +689,7 @@ impl Buffer {
|
||||
intersection.insertion_offset += fragment_start - old_fragments.start().visible;
|
||||
intersection.id =
|
||||
Locator::between(&new_fragments.summary().max_id, &intersection.id);
|
||||
intersection.deletions.insert(timestamp.local());
|
||||
intersection.deletions.insert(timestamp);
|
||||
intersection.visible = false;
|
||||
}
|
||||
if intersection.len > 0 {
|
||||
@@ -781,7 +744,7 @@ impl Buffer {
|
||||
self.subscriptions.publish_mut(&edits_patch);
|
||||
self.history
|
||||
.insertion_slices
|
||||
.insert(timestamp.local(), insertion_slices);
|
||||
.insert(timestamp, insertion_slices);
|
||||
edit_op
|
||||
}
|
||||
|
||||
@@ -808,28 +771,23 @@ impl Buffer {
|
||||
fn apply_op(&mut self, op: Operation) -> Result<()> {
|
||||
match op {
|
||||
Operation::Edit(edit) => {
|
||||
if !self.version.observed(edit.timestamp.local()) {
|
||||
if !self.version.observed(edit.timestamp) {
|
||||
self.apply_remote_edit(
|
||||
&edit.version,
|
||||
&edit.ranges,
|
||||
&edit.new_text,
|
||||
edit.timestamp,
|
||||
);
|
||||
self.snapshot.version.observe(edit.timestamp.local());
|
||||
self.local_clock.observe(edit.timestamp.local());
|
||||
self.lamport_clock.observe(edit.timestamp.lamport());
|
||||
self.resolve_edit(edit.timestamp.local());
|
||||
self.snapshot.version.observe(edit.timestamp);
|
||||
self.lamport_clock.observe(edit.timestamp);
|
||||
self.resolve_edit(edit.timestamp);
|
||||
}
|
||||
}
|
||||
Operation::Undo {
|
||||
undo,
|
||||
lamport_timestamp,
|
||||
} => {
|
||||
if !self.version.observed(undo.id) {
|
||||
Operation::Undo(undo) => {
|
||||
if !self.version.observed(undo.timestamp) {
|
||||
self.apply_undo(&undo)?;
|
||||
self.snapshot.version.observe(undo.id);
|
||||
self.local_clock.observe(undo.id);
|
||||
self.lamport_clock.observe(lamport_timestamp);
|
||||
self.snapshot.version.observe(undo.timestamp);
|
||||
self.lamport_clock.observe(undo.timestamp);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -849,7 +807,7 @@ impl Buffer {
|
||||
version: &clock::Global,
|
||||
ranges: &[Range<FullOffset>],
|
||||
new_text: &[Arc<str>],
|
||||
timestamp: InsertionTimestamp,
|
||||
timestamp: clock::Lamport,
|
||||
) {
|
||||
if ranges.is_empty() {
|
||||
return;
|
||||
@@ -916,9 +874,7 @@ impl Buffer {
|
||||
// Skip over insertions that are concurrent to this edit, but have a lower lamport
|
||||
// timestamp.
|
||||
while let Some(fragment) = old_fragments.item() {
|
||||
if fragment_start == range.start
|
||||
&& fragment.insertion_timestamp.lamport() > timestamp.lamport()
|
||||
{
|
||||
if fragment_start == range.start && fragment.timestamp > timestamp {
|
||||
new_ropes.push_fragment(fragment, fragment.visible);
|
||||
new_fragments.push(fragment.clone(), &None);
|
||||
old_fragments.next(&cx);
|
||||
@@ -955,7 +911,7 @@ impl Buffer {
|
||||
.item()
|
||||
.map_or(&Locator::max(), |old_fragment| &old_fragment.id),
|
||||
),
|
||||
insertion_timestamp: timestamp,
|
||||
timestamp,
|
||||
insertion_offset,
|
||||
len: new_text.len(),
|
||||
deletions: Default::default(),
|
||||
@@ -986,7 +942,7 @@ impl Buffer {
|
||||
fragment_start - old_fragments.start().0.full_offset();
|
||||
intersection.id =
|
||||
Locator::between(&new_fragments.summary().max_id, &intersection.id);
|
||||
intersection.deletions.insert(timestamp.local());
|
||||
intersection.deletions.insert(timestamp);
|
||||
intersection.visible = false;
|
||||
insertion_slices.push(intersection.insertion_slice());
|
||||
}
|
||||
@@ -1038,13 +994,13 @@ impl Buffer {
|
||||
self.snapshot.insertions.edit(new_insertions, &());
|
||||
self.history
|
||||
.insertion_slices
|
||||
.insert(timestamp.local(), insertion_slices);
|
||||
.insert(timestamp, insertion_slices);
|
||||
self.subscriptions.publish_mut(&edits_patch)
|
||||
}
|
||||
|
||||
fn fragment_ids_for_edits<'a>(
|
||||
&'a self,
|
||||
edit_ids: impl Iterator<Item = &'a clock::Local>,
|
||||
edit_ids: impl Iterator<Item = &'a clock::Lamport>,
|
||||
) -> Vec<&'a Locator> {
|
||||
// Get all of the insertion slices changed by the given edits.
|
||||
let mut insertion_slices = Vec::new();
|
||||
@@ -1105,7 +1061,7 @@ impl Buffer {
|
||||
let fragment_was_visible = fragment.visible;
|
||||
|
||||
fragment.visible = fragment.is_visible(&self.undo_map);
|
||||
fragment.max_undos.observe(undo.id);
|
||||
fragment.max_undos.observe(undo.timestamp);
|
||||
|
||||
let old_start = old_fragments.start().1;
|
||||
let new_start = new_fragments.summary().text.visible;
|
||||
@@ -1159,10 +1115,10 @@ impl Buffer {
|
||||
if self.deferred_replicas.contains(&op.replica_id()) {
|
||||
false
|
||||
} else {
|
||||
match op {
|
||||
Operation::Edit(edit) => self.version.observed_all(&edit.version),
|
||||
Operation::Undo { undo, .. } => self.version.observed_all(&undo.version),
|
||||
}
|
||||
self.version.observed_all(match op {
|
||||
Operation::Edit(edit) => &edit.version,
|
||||
Operation::Undo(undo) => &undo.version,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1180,7 +1136,7 @@ impl Buffer {
|
||||
|
||||
pub fn start_transaction_at(&mut self, now: Instant) -> Option<TransactionId> {
|
||||
self.history
|
||||
.start_transaction(self.version.clone(), now, &mut self.local_clock)
|
||||
.start_transaction(self.version.clone(), now, &mut self.lamport_clock)
|
||||
}
|
||||
|
||||
pub fn end_transaction(&mut self) -> Option<(TransactionId, clock::Global)> {
|
||||
@@ -1209,7 +1165,7 @@ impl Buffer {
|
||||
&self.history.base_text
|
||||
}
|
||||
|
||||
pub fn operations(&self) -> &TreeMap<clock::Local, Operation> {
|
||||
pub fn operations(&self) -> &TreeMap<clock::Lamport, Operation> {
|
||||
&self.history.operations
|
||||
}
|
||||
|
||||
@@ -1289,16 +1245,13 @@ impl Buffer {
|
||||
}
|
||||
|
||||
let undo = UndoOperation {
|
||||
id: self.local_clock.tick(),
|
||||
timestamp: self.lamport_clock.tick(),
|
||||
version: self.version(),
|
||||
counts,
|
||||
};
|
||||
self.apply_undo(&undo)?;
|
||||
let operation = Operation::Undo {
|
||||
undo,
|
||||
lamport_timestamp: self.lamport_clock.tick(),
|
||||
};
|
||||
self.snapshot.version.observe(operation.local_timestamp());
|
||||
self.snapshot.version.observe(undo.timestamp);
|
||||
let operation = Operation::Undo(undo);
|
||||
self.history.push(operation.clone());
|
||||
Ok(operation)
|
||||
}
|
||||
@@ -1363,7 +1316,7 @@ impl Buffer {
|
||||
|
||||
pub fn wait_for_edits(
|
||||
&mut self,
|
||||
edit_ids: impl IntoIterator<Item = clock::Local>,
|
||||
edit_ids: impl IntoIterator<Item = clock::Lamport>,
|
||||
) -> impl 'static + Future<Output = Result<()>> {
|
||||
let mut futures = Vec::new();
|
||||
for edit_id in edit_ids {
|
||||
@@ -1435,7 +1388,7 @@ impl Buffer {
|
||||
self.wait_for_version_txs.clear();
|
||||
}
|
||||
|
||||
fn resolve_edit(&mut self, edit_id: clock::Local) {
|
||||
fn resolve_edit(&mut self, edit_id: clock::Lamport) {
|
||||
for mut tx in self
|
||||
.edit_id_resolvers
|
||||
.remove(&edit_id)
|
||||
@@ -1513,7 +1466,7 @@ impl Buffer {
|
||||
.insertions
|
||||
.get(
|
||||
&InsertionFragmentKey {
|
||||
timestamp: fragment.insertion_timestamp.local(),
|
||||
timestamp: fragment.timestamp,
|
||||
split_offset: fragment.insertion_offset,
|
||||
},
|
||||
&(),
|
||||
@@ -1996,7 +1949,7 @@ impl BufferSnapshot {
|
||||
let fragment = fragment_cursor.item().unwrap();
|
||||
let overshoot = offset - *fragment_cursor.start();
|
||||
Anchor {
|
||||
timestamp: fragment.insertion_timestamp.local(),
|
||||
timestamp: fragment.timestamp,
|
||||
offset: fragment.insertion_offset + overshoot,
|
||||
bias,
|
||||
buffer_id: Some(self.remote_id),
|
||||
@@ -2188,15 +2141,14 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo
|
||||
break;
|
||||
}
|
||||
|
||||
let timestamp = fragment.insertion_timestamp.local();
|
||||
let start_anchor = Anchor {
|
||||
timestamp,
|
||||
timestamp: fragment.timestamp,
|
||||
offset: fragment.insertion_offset,
|
||||
bias: Bias::Right,
|
||||
buffer_id: Some(self.buffer_id),
|
||||
};
|
||||
let end_anchor = Anchor {
|
||||
timestamp,
|
||||
timestamp: fragment.timestamp,
|
||||
offset: fragment.insertion_offset + fragment.len,
|
||||
bias: Bias::Left,
|
||||
buffer_id: Some(self.buffer_id),
|
||||
@@ -2269,19 +2221,17 @@ impl<'a, D: TextDimension + Ord, F: FnMut(&FragmentSummary) -> bool> Iterator fo
|
||||
impl Fragment {
|
||||
fn insertion_slice(&self) -> InsertionSlice {
|
||||
InsertionSlice {
|
||||
insertion_id: self.insertion_timestamp.local(),
|
||||
insertion_id: self.timestamp,
|
||||
range: self.insertion_offset..self.insertion_offset + self.len,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_visible(&self, undos: &UndoMap) -> bool {
|
||||
!undos.is_undone(self.insertion_timestamp.local())
|
||||
&& self.deletions.iter().all(|d| undos.is_undone(*d))
|
||||
!undos.is_undone(self.timestamp) && self.deletions.iter().all(|d| undos.is_undone(*d))
|
||||
}
|
||||
|
||||
fn was_visible(&self, version: &clock::Global, undos: &UndoMap) -> bool {
|
||||
(version.observed(self.insertion_timestamp.local())
|
||||
&& !undos.was_undone(self.insertion_timestamp.local(), version))
|
||||
(version.observed(self.timestamp) && !undos.was_undone(self.timestamp, version))
|
||||
&& self
|
||||
.deletions
|
||||
.iter()
|
||||
@@ -2294,14 +2244,14 @@ impl sum_tree::Item for Fragment {
|
||||
|
||||
fn summary(&self) -> Self::Summary {
|
||||
let mut max_version = clock::Global::new();
|
||||
max_version.observe(self.insertion_timestamp.local());
|
||||
max_version.observe(self.timestamp);
|
||||
for deletion in &self.deletions {
|
||||
max_version.observe(*deletion);
|
||||
}
|
||||
max_version.join(&self.max_undos);
|
||||
|
||||
let mut min_insertion_version = clock::Global::new();
|
||||
min_insertion_version.observe(self.insertion_timestamp.local());
|
||||
min_insertion_version.observe(self.timestamp);
|
||||
let max_insertion_version = min_insertion_version.clone();
|
||||
if self.visible {
|
||||
FragmentSummary {
|
||||
@@ -2378,7 +2328,7 @@ impl sum_tree::KeyedItem for InsertionFragment {
|
||||
impl InsertionFragment {
|
||||
fn new(fragment: &Fragment) -> Self {
|
||||
Self {
|
||||
timestamp: fragment.insertion_timestamp.local(),
|
||||
timestamp: fragment.timestamp,
|
||||
split_offset: fragment.insertion_offset,
|
||||
fragment_id: fragment.id.clone(),
|
||||
}
|
||||
@@ -2501,10 +2451,10 @@ impl Operation {
|
||||
operation_queue::Operation::lamport_timestamp(self).replica_id
|
||||
}
|
||||
|
||||
pub fn local_timestamp(&self) -> clock::Local {
|
||||
pub fn timestamp(&self) -> clock::Lamport {
|
||||
match self {
|
||||
Operation::Edit(edit) => edit.timestamp.local(),
|
||||
Operation::Undo { undo, .. } => undo.id,
|
||||
Operation::Edit(edit) => edit.timestamp,
|
||||
Operation::Undo(undo) => undo.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2523,10 +2473,8 @@ impl Operation {
|
||||
impl operation_queue::Operation for Operation {
|
||||
fn lamport_timestamp(&self) -> clock::Lamport {
|
||||
match self {
|
||||
Operation::Edit(edit) => edit.timestamp.lamport(),
|
||||
Operation::Undo {
|
||||
lamport_timestamp, ..
|
||||
} => *lamport_timestamp,
|
||||
Operation::Edit(edit) => edit.timestamp,
|
||||
Operation::Undo(undo) => undo.timestamp,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,8 @@ impl sum_tree::KeyedItem for UndoMapEntry {
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct UndoMapKey {
|
||||
edit_id: clock::Local,
|
||||
undo_id: clock::Local,
|
||||
edit_id: clock::Lamport,
|
||||
undo_id: clock::Lamport,
|
||||
}
|
||||
|
||||
impl sum_tree::Summary for UndoMapKey {
|
||||
@@ -50,7 +50,7 @@ impl UndoMap {
|
||||
sum_tree::Edit::Insert(UndoMapEntry {
|
||||
key: UndoMapKey {
|
||||
edit_id: *edit_id,
|
||||
undo_id: undo.id,
|
||||
undo_id: undo.timestamp,
|
||||
},
|
||||
undo_count: *count,
|
||||
})
|
||||
@@ -59,11 +59,11 @@ impl UndoMap {
|
||||
self.0.edit(edits, &());
|
||||
}
|
||||
|
||||
pub fn is_undone(&self, edit_id: clock::Local) -> bool {
|
||||
pub fn is_undone(&self, edit_id: clock::Lamport) -> bool {
|
||||
self.undo_count(edit_id) % 2 == 1
|
||||
}
|
||||
|
||||
pub fn was_undone(&self, edit_id: clock::Local, version: &clock::Global) -> bool {
|
||||
pub fn was_undone(&self, edit_id: clock::Lamport, version: &clock::Global) -> bool {
|
||||
let mut cursor = self.0.cursor::<UndoMapKey>();
|
||||
cursor.seek(
|
||||
&UndoMapKey {
|
||||
@@ -88,7 +88,7 @@ impl UndoMap {
|
||||
undo_count % 2 == 1
|
||||
}
|
||||
|
||||
pub fn undo_count(&self, edit_id: clock::Local) -> u32 {
|
||||
pub fn undo_count(&self, edit_id: clock::Lamport) -> u32 {
|
||||
let mut cursor = self.0.cursor::<UndoMapKey>();
|
||||
cursor.seek(
|
||||
&UndoMapKey {
|
||||
|
||||
@@ -408,6 +408,7 @@ pub struct Toolbar {
|
||||
pub height: f32,
|
||||
pub item_spacing: f32,
|
||||
pub toggleable_tool: Toggleable<Interactive<IconButton>>,
|
||||
pub toggleable_text_tool: Toggleable<Interactive<ContainedText>>,
|
||||
pub breadcrumb_height: f32,
|
||||
pub breadcrumbs: Interactive<ContainedText>,
|
||||
}
|
||||
@@ -834,6 +835,9 @@ pub struct AutocompleteStyle {
|
||||
pub selected_item: ContainerStyle,
|
||||
pub hovered_item: ContainerStyle,
|
||||
pub match_highlight: HighlightStyle,
|
||||
pub server_name_container: ContainerStyle,
|
||||
pub server_name_color: Color,
|
||||
pub server_name_size_percent: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Default, Deserialize, JsonSchema)]
|
||||
|
||||
@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
|
||||
Defer(Some(f))
|
||||
}
|
||||
|
||||
pub struct RandomCharIter<T: Rng>(T);
|
||||
pub struct RandomCharIter<T: Rng> {
|
||||
rng: T,
|
||||
simple_text: bool,
|
||||
}
|
||||
|
||||
impl<T: Rng> RandomCharIter<T> {
|
||||
pub fn new(rng: T) -> Self {
|
||||
Self(rng)
|
||||
Self {
|
||||
rng,
|
||||
simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_simple_text(mut self) -> Self {
|
||||
self.simple_text = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
|
||||
type Item = char;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
|
||||
return if self.0.gen_range(0..100) < 5 {
|
||||
if self.simple_text {
|
||||
return if self.rng.gen_range(0..100) < 5 {
|
||||
Some('\n')
|
||||
} else {
|
||||
Some(self.0.gen_range(b'a'..b'z' + 1).into())
|
||||
Some(self.rng.gen_range(b'a'..b'z' + 1).into())
|
||||
};
|
||||
}
|
||||
|
||||
match self.0.gen_range(0..100) {
|
||||
match self.rng.gen_range(0..100) {
|
||||
// whitespace
|
||||
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
|
||||
0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
|
||||
// two-byte greek letters
|
||||
20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
|
||||
20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
|
||||
// // three-byte characters
|
||||
33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
|
||||
33..=45 => ['✋', '✅', '❌', '❎', '⭐']
|
||||
.choose(&mut self.rng)
|
||||
.copied(),
|
||||
// // four-byte characters
|
||||
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
|
||||
46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
|
||||
// ascii letters
|
||||
_ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
|
||||
_ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ language_selector = { path = "../language_selector"}
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
parking_lot.workspace = true
|
||||
futures.workspace = true
|
||||
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
gpui = { path = "../gpui", features = ["test-support"] }
|
||||
@@ -47,3 +48,4 @@ util = { path = "../util", features = ["test-support"] }
|
||||
settings = { path = "../settings" }
|
||||
workspace = { path = "../workspace", features = ["test-support"] }
|
||||
theme = { path = "../theme", features = ["test-support"] }
|
||||
lsp = { path = "../lsp", features = ["test-support"] }
|
||||
|
||||
@@ -34,6 +34,7 @@ fn focused(EditorFocused(editor): &EditorFocused, cx: &mut AppContext) {
|
||||
fn blurred(EditorBlurred(editor): &EditorBlurred, cx: &mut AppContext) {
|
||||
editor.window().update(cx, |cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.workspace_state.recording = false;
|
||||
if let Some(previous_editor) = vim.active_editor.clone() {
|
||||
if previous_editor == editor.clone() {
|
||||
vim.active_editor = None;
|
||||
|
||||
@@ -11,8 +11,9 @@ pub fn init(cx: &mut AppContext) {
|
||||
}
|
||||
|
||||
fn normal_before(_: &mut Workspace, _: &NormalBefore, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |state, cx| {
|
||||
state.update_active_editor(cx, |editor, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.stop_recording();
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
|
||||
s.move_cursors_with(|map, mut cursor, _| {
|
||||
*cursor.column_mut() = cursor.column().saturating_sub(1);
|
||||
@@ -20,7 +21,7 @@ fn normal_before(_: &mut Workspace, _: &NormalBefore, cx: &mut ViewContext<Works
|
||||
});
|
||||
});
|
||||
});
|
||||
state.switch_mode(Mode::Normal, false, cx);
|
||||
vim.switch_mode(Mode::Normal, false, cx);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::{cmp, sync::Arc};
|
||||
use std::cmp;
|
||||
|
||||
use editor::{
|
||||
char_kind,
|
||||
display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint},
|
||||
movement, Bias, CharKind, DisplayPoint, ToOffset,
|
||||
movement::{self, find_boundary, find_preceding_boundary, FindRange},
|
||||
Bias, CharKind, DisplayPoint, ToOffset,
|
||||
};
|
||||
use gpui::{actions, impl_actions, AppContext, WindowContext};
|
||||
use language::{Point, Selection, SelectionGoal};
|
||||
@@ -36,8 +37,8 @@ pub enum Motion {
|
||||
StartOfDocument,
|
||||
EndOfDocument,
|
||||
Matching,
|
||||
FindForward { before: bool, text: Arc<str> },
|
||||
FindBackward { after: bool, text: Arc<str> },
|
||||
FindForward { before: bool, char: char },
|
||||
FindBackward { after: bool, char: char },
|
||||
NextLineStart,
|
||||
}
|
||||
|
||||
@@ -64,9 +65,9 @@ struct PreviousWordStart {
|
||||
|
||||
#[derive(Clone, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct Up {
|
||||
pub(crate) struct Up {
|
||||
#[serde(default)]
|
||||
display_lines: bool,
|
||||
pub(crate) display_lines: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, PartialEq)]
|
||||
@@ -92,9 +93,9 @@ struct EndOfLine {
|
||||
|
||||
#[derive(Clone, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct StartOfLine {
|
||||
pub struct StartOfLine {
|
||||
#[serde(default)]
|
||||
display_lines: bool,
|
||||
pub(crate) display_lines: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, PartialEq)]
|
||||
@@ -232,25 +233,25 @@ pub(crate) fn motion(motion: Motion, cx: &mut WindowContext) {
|
||||
|
||||
fn repeat_motion(backwards: bool, cx: &mut WindowContext) {
|
||||
let find = match Vim::read(cx).workspace_state.last_find.clone() {
|
||||
Some(Motion::FindForward { before, text }) => {
|
||||
Some(Motion::FindForward { before, char }) => {
|
||||
if backwards {
|
||||
Motion::FindBackward {
|
||||
after: before,
|
||||
text,
|
||||
char,
|
||||
}
|
||||
} else {
|
||||
Motion::FindForward { before, text }
|
||||
Motion::FindForward { before, char }
|
||||
}
|
||||
}
|
||||
|
||||
Some(Motion::FindBackward { after, text }) => {
|
||||
Some(Motion::FindBackward { after, char }) => {
|
||||
if backwards {
|
||||
Motion::FindForward {
|
||||
before: after,
|
||||
text,
|
||||
char,
|
||||
}
|
||||
} else {
|
||||
Motion::FindBackward { after, text }
|
||||
Motion::FindBackward { after, char }
|
||||
}
|
||||
}
|
||||
_ => return,
|
||||
@@ -402,12 +403,12 @@ impl Motion {
|
||||
SelectionGoal::None,
|
||||
),
|
||||
Matching => (matching(map, point), SelectionGoal::None),
|
||||
FindForward { before, text } => (
|
||||
find_forward(map, point, *before, text.clone(), times),
|
||||
FindForward { before, char } => (
|
||||
find_forward(map, point, *before, *char, times),
|
||||
SelectionGoal::None,
|
||||
),
|
||||
FindBackward { after, text } => (
|
||||
find_backward(map, point, *after, text.clone(), times),
|
||||
FindBackward { after, char } => (
|
||||
find_backward(map, point, *after, *char, times),
|
||||
SelectionGoal::None,
|
||||
),
|
||||
NextLineStart => (next_line_start(map, point, times), SelectionGoal::None),
|
||||
@@ -589,12 +590,12 @@ pub(crate) fn next_word_start(
|
||||
ignore_punctuation: bool,
|
||||
times: usize,
|
||||
) -> DisplayPoint {
|
||||
let language = map.buffer_snapshot.language_at(point.to_point(map));
|
||||
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
|
||||
for _ in 0..times {
|
||||
let mut crossed_newline = false;
|
||||
point = movement::find_boundary(map, point, |left, right| {
|
||||
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
|
||||
point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
|
||||
let at_newline = right == '\n';
|
||||
|
||||
let found = (left_kind != right_kind && right_kind != CharKind::Whitespace)
|
||||
@@ -614,12 +615,17 @@ fn next_word_end(
|
||||
ignore_punctuation: bool,
|
||||
times: usize,
|
||||
) -> DisplayPoint {
|
||||
let language = map.buffer_snapshot.language_at(point.to_point(map));
|
||||
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
|
||||
for _ in 0..times {
|
||||
*point.column_mut() += 1;
|
||||
point = movement::find_boundary(map, point, |left, right| {
|
||||
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
|
||||
if point.column() < map.line_len(point.row()) {
|
||||
*point.column_mut() += 1;
|
||||
} else if point.row() < map.max_buffer_row() {
|
||||
*point.row_mut() += 1;
|
||||
*point.column_mut() = 0;
|
||||
}
|
||||
point = movement::find_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
|
||||
|
||||
left_kind != right_kind && left_kind != CharKind::Whitespace
|
||||
});
|
||||
@@ -645,16 +651,17 @@ fn previous_word_start(
|
||||
ignore_punctuation: bool,
|
||||
times: usize,
|
||||
) -> DisplayPoint {
|
||||
let language = map.buffer_snapshot.language_at(point.to_point(map));
|
||||
let scope = map.buffer_snapshot.language_scope_at(point.to_point(map));
|
||||
for _ in 0..times {
|
||||
// This works even though find_preceding_boundary is called for every character in the line containing
|
||||
// cursor because the newline is checked only once.
|
||||
point = movement::find_preceding_boundary(map, point, |left, right| {
|
||||
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
|
||||
point =
|
||||
movement::find_preceding_boundary(map, point, FindRange::MultiLine, |left, right| {
|
||||
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
|
||||
|
||||
(left_kind != right_kind && !right.is_whitespace()) || left == '\n'
|
||||
});
|
||||
(left_kind != right_kind && !right.is_whitespace()) || left == '\n'
|
||||
});
|
||||
}
|
||||
point
|
||||
}
|
||||
@@ -665,7 +672,7 @@ fn first_non_whitespace(
|
||||
from: DisplayPoint,
|
||||
) -> DisplayPoint {
|
||||
let mut last_point = start_of_line(map, display_lines, from);
|
||||
let language = map.buffer_snapshot.language_at(from.to_point(map));
|
||||
let scope = map.buffer_snapshot.language_scope_at(from.to_point(map));
|
||||
for (ch, point) in map.chars_at(last_point) {
|
||||
if ch == '\n' {
|
||||
return from;
|
||||
@@ -673,7 +680,7 @@ fn first_non_whitespace(
|
||||
|
||||
last_point = point;
|
||||
|
||||
if char_kind(language, ch) != CharKind::Whitespace {
|
||||
if char_kind(&scope, ch) != CharKind::Whitespace {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -786,44 +793,55 @@ fn find_forward(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
before: bool,
|
||||
target: Arc<str>,
|
||||
target: char,
|
||||
times: usize,
|
||||
) -> DisplayPoint {
|
||||
map.find_while(from, target.as_ref(), |ch, _| ch != '\n')
|
||||
.skip_while(|found_at| found_at == &from)
|
||||
.nth(times - 1)
|
||||
.map(|mut found| {
|
||||
if before {
|
||||
*found.column_mut() -= 1;
|
||||
found = map.clip_point(found, Bias::Right);
|
||||
found
|
||||
} else {
|
||||
found
|
||||
}
|
||||
})
|
||||
.unwrap_or(from)
|
||||
let mut to = from;
|
||||
let mut found = false;
|
||||
|
||||
for _ in 0..times {
|
||||
found = false;
|
||||
to = find_boundary(map, to, FindRange::SingleLine, |_, right| {
|
||||
found = right == target;
|
||||
found
|
||||
});
|
||||
}
|
||||
|
||||
if found {
|
||||
if before && to.column() > 0 {
|
||||
*to.column_mut() -= 1;
|
||||
map.clip_point(to, Bias::Left)
|
||||
} else {
|
||||
to
|
||||
}
|
||||
} else {
|
||||
from
|
||||
}
|
||||
}
|
||||
|
||||
fn find_backward(
|
||||
map: &DisplaySnapshot,
|
||||
from: DisplayPoint,
|
||||
after: bool,
|
||||
target: Arc<str>,
|
||||
target: char,
|
||||
times: usize,
|
||||
) -> DisplayPoint {
|
||||
map.reverse_find_while(from, target.as_ref(), |ch, _| ch != '\n')
|
||||
.skip_while(|found_at| found_at == &from)
|
||||
.nth(times - 1)
|
||||
.map(|mut found| {
|
||||
if after {
|
||||
*found.column_mut() += 1;
|
||||
found = map.clip_point(found, Bias::Left);
|
||||
found
|
||||
} else {
|
||||
found
|
||||
}
|
||||
})
|
||||
.unwrap_or(from)
|
||||
let mut to = from;
|
||||
|
||||
for _ in 0..times {
|
||||
to = find_preceding_boundary(map, to, FindRange::SingleLine, |_, right| right == target);
|
||||
}
|
||||
|
||||
if map.buffer_snapshot.chars_at(to.to_point(map)).next() == Some(target) {
|
||||
if after {
|
||||
*to.column_mut() += 1;
|
||||
map.clip_point(to, Bias::Right)
|
||||
} else {
|
||||
to
|
||||
}
|
||||
} else {
|
||||
from
|
||||
}
|
||||
}
|
||||
|
||||
fn next_line_start(map: &DisplaySnapshot, point: DisplayPoint, times: usize) -> DisplayPoint {
|
||||
|
||||
@@ -2,6 +2,7 @@ mod case;
|
||||
mod change;
|
||||
mod delete;
|
||||
mod paste;
|
||||
mod repeat;
|
||||
mod scroll;
|
||||
mod search;
|
||||
pub mod substitute;
|
||||
@@ -27,7 +28,6 @@ use self::{
|
||||
case::change_case,
|
||||
change::{change_motion, change_object},
|
||||
delete::{delete_motion, delete_object},
|
||||
substitute::substitute,
|
||||
yank::{yank_motion, yank_object},
|
||||
};
|
||||
|
||||
@@ -35,6 +35,7 @@ actions!(
|
||||
vim,
|
||||
[
|
||||
InsertAfter,
|
||||
InsertBefore,
|
||||
InsertFirstNonWhitespace,
|
||||
InsertEndOfLine,
|
||||
InsertLineAbove,
|
||||
@@ -44,39 +45,43 @@ actions!(
|
||||
ChangeToEndOfLine,
|
||||
DeleteToEndOfLine,
|
||||
Yank,
|
||||
Substitute,
|
||||
ChangeCase,
|
||||
JoinLines,
|
||||
]
|
||||
);
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
paste::init(cx);
|
||||
repeat::init(cx);
|
||||
scroll::init(cx);
|
||||
search::init(cx);
|
||||
substitute::init(cx);
|
||||
|
||||
cx.add_action(insert_after);
|
||||
cx.add_action(insert_before);
|
||||
cx.add_action(insert_first_non_whitespace);
|
||||
cx.add_action(insert_end_of_line);
|
||||
cx.add_action(insert_line_above);
|
||||
cx.add_action(insert_line_below);
|
||||
cx.add_action(change_case);
|
||||
search::init(cx);
|
||||
cx.add_action(|_: &mut Workspace, _: &Substitute, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
let times = vim.pop_number_operator(cx);
|
||||
substitute(vim, times, cx);
|
||||
})
|
||||
});
|
||||
|
||||
cx.add_action(|_: &mut Workspace, _: &DeleteLeft, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
let times = vim.pop_number_operator(cx);
|
||||
delete_motion(vim, Motion::Left, times, cx);
|
||||
})
|
||||
});
|
||||
cx.add_action(|_: &mut Workspace, _: &DeleteRight, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
let times = vim.pop_number_operator(cx);
|
||||
delete_motion(vim, Motion::Right, times, cx);
|
||||
})
|
||||
});
|
||||
cx.add_action(|_: &mut Workspace, _: &ChangeToEndOfLine, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
let times = vim.pop_number_operator(cx);
|
||||
change_motion(
|
||||
vim,
|
||||
@@ -90,6 +95,7 @@ pub fn init(cx: &mut AppContext) {
|
||||
});
|
||||
cx.add_action(|_: &mut Workspace, _: &DeleteToEndOfLine, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
let times = vim.pop_number_operator(cx);
|
||||
delete_motion(
|
||||
vim,
|
||||
@@ -101,8 +107,26 @@ pub fn init(cx: &mut AppContext) {
|
||||
);
|
||||
})
|
||||
});
|
||||
scroll::init(cx);
|
||||
paste::init(cx);
|
||||
cx.add_action(|_: &mut Workspace, _: &JoinLines, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
let mut times = vim.pop_number_operator(cx).unwrap_or(1);
|
||||
if vim.state().mode.is_visual() {
|
||||
times = 1;
|
||||
} else if times > 1 {
|
||||
// 2J joins two lines together (same as J or 1J)
|
||||
times -= 1;
|
||||
}
|
||||
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
for _ in 0..times {
|
||||
editor.join_lines(&Default::default(), cx)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn normal_motion(
|
||||
@@ -158,6 +182,7 @@ fn move_cursor(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut Win
|
||||
|
||||
fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
|
||||
@@ -169,12 +194,20 @@ fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspa
|
||||
});
|
||||
}
|
||||
|
||||
fn insert_before(_: &mut Workspace, _: &InsertBefore, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn insert_first_non_whitespace(
|
||||
_: &mut Workspace,
|
||||
_: &InsertFirstNonWhitespace,
|
||||
cx: &mut ViewContext<Workspace>,
|
||||
) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
|
||||
@@ -191,6 +224,7 @@ fn insert_first_non_whitespace(
|
||||
|
||||
fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
|
||||
@@ -204,6 +238,7 @@ fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewConte
|
||||
|
||||
fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
@@ -236,6 +271,7 @@ fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContex
|
||||
|
||||
fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
vim.switch_mode(Mode::Insert, false, cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
@@ -267,6 +303,7 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
|
||||
|
||||
pub(crate) fn normal_replace(text: Arc<str>, cx: &mut WindowContext) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.stop_recording();
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
editor.set_clip_at_line_ends(false, cx);
|
||||
@@ -445,7 +482,7 @@ mod test {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_e(cx: &mut gpui::TestAppContext) {
|
||||
async fn test_end_of_word(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await.binding(["e"]);
|
||||
cx.assert_all(indoc! {"
|
||||
Thˇe quicˇkˇ-browˇn
|
||||
@@ -787,6 +824,7 @@ mod test {
|
||||
#[gpui::test]
|
||||
async fn test_f_and_t(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
for count in 1..=3 {
|
||||
let test_case = indoc! {"
|
||||
ˇaaaˇbˇ ˇbˇ ˇbˇbˇ aˇaaˇbaaa
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::{normal::ChangeCase, state::Mode, Vim};
|
||||
|
||||
pub fn change_case(_: &mut Workspace, _: &ChangeCase, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
let count = vim.pop_number_operator(cx).unwrap_or(1) as u32;
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
let mut ranges = Vec::new();
|
||||
@@ -21,10 +22,16 @@ pub fn change_case(_: &mut Workspace, _: &ChangeCase, cx: &mut ViewContext<Works
|
||||
ranges.push(start..end);
|
||||
cursor_positions.push(start..start);
|
||||
}
|
||||
Mode::Visual | Mode::VisualBlock => {
|
||||
Mode::Visual => {
|
||||
ranges.push(selection.start..selection.end);
|
||||
cursor_positions.push(selection.start..selection.start);
|
||||
}
|
||||
Mode::VisualBlock => {
|
||||
ranges.push(selection.start..selection.end);
|
||||
if cursor_positions.len() == 0 {
|
||||
cursor_positions.push(selection.start..selection.start);
|
||||
}
|
||||
}
|
||||
Mode::Insert | Mode::Normal => {
|
||||
let start = selection.start;
|
||||
let mut end = start;
|
||||
@@ -96,6 +103,11 @@ mod test {
|
||||
cx.simulate_shared_keystrokes(["shift-v", "~"]).await;
|
||||
cx.assert_shared_state("ˇABc\n").await;
|
||||
|
||||
// works in visual block mode
|
||||
cx.set_shared_state("ˇaa\nbb\ncc").await;
|
||||
cx.simulate_shared_keystrokes(["ctrl-v", "j", "~"]).await;
|
||||
cx.assert_shared_state("ˇAa\nBb\ncc").await;
|
||||
|
||||
// works with multiple cursors (zed only)
|
||||
cx.set_state("aˇßcdˇe\n", Mode::Normal);
|
||||
cx.simulate_keystroke("~");
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_content, Vim};
|
||||
use editor::{
|
||||
char_kind, display_map::DisplaySnapshot, movement, scroll::autoscroll::Autoscroll, CharKind,
|
||||
DisplayPoint,
|
||||
char_kind,
|
||||
display_map::DisplaySnapshot,
|
||||
movement::{self, FindRange},
|
||||
scroll::autoscroll::Autoscroll,
|
||||
CharKind, DisplayPoint,
|
||||
};
|
||||
use gpui::WindowContext;
|
||||
use language::Selection;
|
||||
@@ -86,22 +89,24 @@ fn expand_changed_word_selection(
|
||||
ignore_punctuation: bool,
|
||||
) -> bool {
|
||||
if times.is_none() || times.unwrap() == 1 {
|
||||
let language = map
|
||||
let scope = map
|
||||
.buffer_snapshot
|
||||
.language_at(selection.start.to_point(map));
|
||||
.language_scope_at(selection.start.to_point(map));
|
||||
let in_word = map
|
||||
.chars_at(selection.head())
|
||||
.next()
|
||||
.map(|(c, _)| char_kind(language, c) != CharKind::Whitespace)
|
||||
.map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace)
|
||||
.unwrap_or_default();
|
||||
|
||||
if in_word {
|
||||
selection.end = movement::find_boundary(map, selection.end, |left, right| {
|
||||
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
|
||||
selection.end =
|
||||
movement::find_boundary(map, selection.end, FindRange::MultiLine, |left, right| {
|
||||
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind =
|
||||
char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
|
||||
|
||||
left_kind != right_kind && left_kind != CharKind::Whitespace
|
||||
});
|
||||
left_kind != right_kind && left_kind != CharKind::Whitespace
|
||||
});
|
||||
true
|
||||
} else {
|
||||
Motion::NextWordStart { ignore_punctuation }
|
||||
|
||||
@@ -4,6 +4,7 @@ use editor::{display_map::ToDisplayPoint, scroll::autoscroll::Autoscroll, Bias};
|
||||
use gpui::WindowContext;
|
||||
|
||||
pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
|
||||
vim.stop_recording();
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
editor.set_clip_at_line_ends(false, cx);
|
||||
@@ -37,6 +38,7 @@ pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
|
||||
}
|
||||
|
||||
pub fn delete_object(vim: &mut Vim, object: Object, around: bool, cx: &mut WindowContext) {
|
||||
vim.stop_recording();
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
editor.set_clip_at_line_ends(false, cx);
|
||||
|
||||
@@ -28,6 +28,7 @@ pub(crate) fn init(cx: &mut AppContext) {
|
||||
|
||||
fn paste(_: &mut Workspace, action: &Paste, cx: &mut ViewContext<Workspace>) {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.record_current_action(cx);
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.transact(cx, |editor, cx| {
|
||||
editor.set_clip_at_line_ends(false, cx);
|
||||
|
||||
427
crates/vim/src/normal/repeat.rs
Normal file
427
crates/vim/src/normal/repeat.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
use crate::{
|
||||
motion::Motion,
|
||||
state::{Mode, RecordedSelection, ReplayableAction},
|
||||
visual::visual_motion,
|
||||
Vim,
|
||||
};
|
||||
use gpui::{actions, Action, AppContext};
|
||||
use workspace::Workspace;
|
||||
|
||||
actions!(vim, [Repeat, EndRepeat,]);
|
||||
|
||||
fn should_replay(action: &Box<dyn Action>) -> bool {
|
||||
// skip so that we don't leave the character palette open
|
||||
if editor::ShowCharacterPalette.id() == action.id() {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn init(cx: &mut AppContext) {
|
||||
cx.add_action(|_: &mut Workspace, _: &EndRepeat, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.workspace_state.replaying = false;
|
||||
vim.update_active_editor(cx, |editor, _| {
|
||||
editor.show_local_selections = true;
|
||||
});
|
||||
vim.switch_mode(Mode::Normal, false, cx)
|
||||
});
|
||||
});
|
||||
|
||||
cx.add_action(|_: &mut Workspace, _: &Repeat, cx| {
|
||||
let Some((actions, editor, selection)) = Vim::update(cx, |vim, cx| {
|
||||
let actions = vim.workspace_state.recorded_actions.clone();
|
||||
let Some(editor) = vim.active_editor.clone() else {
|
||||
return None;
|
||||
};
|
||||
let count = vim.pop_number_operator(cx);
|
||||
|
||||
vim.workspace_state.replaying = true;
|
||||
|
||||
let selection = vim.workspace_state.recorded_selection.clone();
|
||||
match selection {
|
||||
RecordedSelection::SingleLine { .. } | RecordedSelection::Visual { .. } => {
|
||||
vim.workspace_state.recorded_count = None;
|
||||
vim.switch_mode(Mode::Visual, false, cx)
|
||||
}
|
||||
RecordedSelection::VisualLine { .. } => {
|
||||
vim.workspace_state.recorded_count = None;
|
||||
vim.switch_mode(Mode::VisualLine, false, cx)
|
||||
}
|
||||
RecordedSelection::VisualBlock { .. } => {
|
||||
vim.workspace_state.recorded_count = None;
|
||||
vim.switch_mode(Mode::VisualBlock, false, cx)
|
||||
}
|
||||
RecordedSelection::None => {
|
||||
if let Some(count) = count {
|
||||
vim.workspace_state.recorded_count = Some(count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(editor) = editor.upgrade(cx) {
|
||||
editor.update(cx, |editor, _| {
|
||||
editor.show_local_selections = false;
|
||||
})
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((actions, editor, selection))
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match selection {
|
||||
RecordedSelection::SingleLine { cols } => {
|
||||
if cols > 1 {
|
||||
visual_motion(Motion::Right, Some(cols as usize - 1), cx)
|
||||
}
|
||||
}
|
||||
RecordedSelection::Visual { rows, cols } => {
|
||||
visual_motion(
|
||||
Motion::Down {
|
||||
display_lines: false,
|
||||
},
|
||||
Some(rows as usize),
|
||||
cx,
|
||||
);
|
||||
visual_motion(
|
||||
Motion::StartOfLine {
|
||||
display_lines: false,
|
||||
},
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
if cols > 1 {
|
||||
visual_motion(Motion::Right, Some(cols as usize - 1), cx)
|
||||
}
|
||||
}
|
||||
RecordedSelection::VisualBlock { rows, cols } => {
|
||||
visual_motion(
|
||||
Motion::Down {
|
||||
display_lines: false,
|
||||
},
|
||||
Some(rows as usize),
|
||||
cx,
|
||||
);
|
||||
if cols > 1 {
|
||||
visual_motion(Motion::Right, Some(cols as usize - 1), cx);
|
||||
}
|
||||
}
|
||||
RecordedSelection::VisualLine { rows } => {
|
||||
visual_motion(
|
||||
Motion::Down {
|
||||
display_lines: false,
|
||||
},
|
||||
Some(rows as usize),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
RecordedSelection::None => {}
|
||||
}
|
||||
|
||||
let window = cx.window();
|
||||
cx.app_context()
|
||||
.spawn(move |mut cx| async move {
|
||||
for action in actions {
|
||||
match action {
|
||||
ReplayableAction::Action(action) => {
|
||||
if should_replay(&action) {
|
||||
window
|
||||
.dispatch_action(editor.id(), action.as_ref(), &mut cx)
|
||||
.ok_or_else(|| anyhow::anyhow!("window was closed"))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
ReplayableAction::Insertion {
|
||||
text,
|
||||
utf16_range_to_replace,
|
||||
} => editor.update(&mut cx, |editor, cx| {
|
||||
editor.replay_insert_event(&text, utf16_range_to_replace.clone(), cx)
|
||||
}),
|
||||
}?
|
||||
}
|
||||
window
|
||||
.dispatch_action(editor.id(), &EndRepeat, &mut cx)
|
||||
.ok_or_else(|| anyhow::anyhow!("window was closed"))
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use editor::test::editor_lsp_test_context::EditorLspTestContext;
|
||||
use futures::StreamExt;
|
||||
use indoc::indoc;
|
||||
|
||||
use gpui::{executor::Deterministic, View};
|
||||
|
||||
use crate::{
|
||||
state::Mode,
|
||||
test::{NeovimBackedTestContext, VimTestContext},
|
||||
};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_dot_repeat(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
// "o"
|
||||
cx.set_shared_state("ˇhello").await;
|
||||
cx.simulate_shared_keystrokes(["o", "w", "o", "r", "l", "d", "escape"])
|
||||
.await;
|
||||
cx.assert_shared_state("hello\nworlˇd").await;
|
||||
cx.simulate_shared_keystrokes(["."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state("hello\nworld\nworlˇd").await;
|
||||
|
||||
// "d"
|
||||
cx.simulate_shared_keystrokes(["^", "d", "f", "o"]).await;
|
||||
cx.simulate_shared_keystrokes(["g", "g", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state("ˇ\nworld\nrld").await;
|
||||
|
||||
// "p" (note that it pastes the current clipboard)
|
||||
cx.simulate_shared_keystrokes(["j", "y", "y", "p"]).await;
|
||||
cx.simulate_shared_keystrokes(["shift-g", "y", "y", "."])
|
||||
.await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state("\nworld\nworld\nrld\nˇrld").await;
|
||||
|
||||
// "~" (note that counts apply to the action taken, not . itself)
|
||||
cx.set_shared_state("ˇthe quick brown fox").await;
|
||||
cx.simulate_shared_keystrokes(["2", "~", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.set_shared_state("THE ˇquick brown fox").await;
|
||||
cx.simulate_shared_keystrokes(["3", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.set_shared_state("THE QUIˇck brown fox").await;
|
||||
deterministic.run_until_parked();
|
||||
cx.simulate_shared_keystrokes(["."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.set_shared_state("THE QUICK ˇbrown fox").await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_repeat_ime(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = VimTestContext::new(cx, true).await;
|
||||
|
||||
cx.set_state("hˇllo", Mode::Normal);
|
||||
cx.simulate_keystrokes(["i"]);
|
||||
|
||||
// simulate brazilian input for ä.
|
||||
cx.update_editor(|editor, cx| {
|
||||
editor.replace_and_mark_text_in_range(None, "\"", Some(1..1), cx);
|
||||
editor.replace_text_in_range(None, "ä", cx);
|
||||
});
|
||||
cx.simulate_keystrokes(["escape"]);
|
||||
cx.assert_state("hˇällo", Mode::Normal);
|
||||
cx.simulate_keystrokes(["."]);
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_state("hˇäällo", Mode::Normal);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_repeat_completion(
|
||||
deterministic: Arc<Deterministic>,
|
||||
cx: &mut gpui::TestAppContext,
|
||||
) {
|
||||
let cx = EditorLspTestContext::new_rust(
|
||||
lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
trigger_characters: Some(vec![".".to_string(), ":".to_string()]),
|
||||
resolve_provider: Some(true),
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
let mut cx = VimTestContext::new_with_lsp(cx, true);
|
||||
|
||||
cx.set_state(
|
||||
indoc! {"
|
||||
onˇe
|
||||
two
|
||||
three
|
||||
"},
|
||||
Mode::Normal,
|
||||
);
|
||||
|
||||
let mut request =
|
||||
cx.handle_request::<lsp::request::Completion, _, _>(move |_, params, _| async move {
|
||||
let position = params.text_document_position.position;
|
||||
Ok(Some(lsp::CompletionResponse::Array(vec![
|
||||
lsp::CompletionItem {
|
||||
label: "first".to_string(),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
range: lsp::Range::new(position.clone(), position.clone()),
|
||||
new_text: "first".to_string(),
|
||||
})),
|
||||
..Default::default()
|
||||
},
|
||||
lsp::CompletionItem {
|
||||
label: "second".to_string(),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
range: lsp::Range::new(position.clone(), position.clone()),
|
||||
new_text: "second".to_string(),
|
||||
})),
|
||||
..Default::default()
|
||||
},
|
||||
])))
|
||||
});
|
||||
cx.simulate_keystrokes(["a", "."]);
|
||||
request.next().await;
|
||||
cx.condition(|editor, _| editor.context_menu_visible())
|
||||
.await;
|
||||
cx.simulate_keystrokes(["down", "enter", "!", "escape"]);
|
||||
|
||||
cx.assert_state(
|
||||
indoc! {"
|
||||
one.secondˇ!
|
||||
two
|
||||
three
|
||||
"},
|
||||
Mode::Normal,
|
||||
);
|
||||
cx.simulate_keystrokes(["j", "."]);
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_state(
|
||||
indoc! {"
|
||||
one.second!
|
||||
two.secondˇ!
|
||||
three
|
||||
"},
|
||||
Mode::Normal,
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_repeat_visual(deterministic: Arc<Deterministic>, cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
// single-line (3 columns)
|
||||
cx.set_shared_state(indoc! {
|
||||
"ˇthe quick brown
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["v", "i", "w", "s", "o", "escape"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {
|
||||
"ˇo quick brown
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "w", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"o quick brown
|
||||
fox ˇops over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["f", "r", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"o quick brown
|
||||
fox ops oveˇothe lazy dog"
|
||||
})
|
||||
.await;
|
||||
|
||||
// visual
|
||||
cx.set_shared_state(indoc! {
|
||||
"the ˇquick brown
|
||||
fox jumps over
|
||||
fox jumps over
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["v", "j", "x"]).await;
|
||||
cx.assert_shared_state(indoc! {
|
||||
"the ˇumps over
|
||||
fox jumps over
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"the ˇumps over
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["w", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"the umps ˇumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"the umps umps over
|
||||
the ˇog"
|
||||
})
|
||||
.await;
|
||||
|
||||
// block mode (3 rows)
|
||||
cx.set_shared_state(indoc! {
|
||||
"ˇthe quick brown
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["ctrl-v", "j", "j", "shift-i", "o", "escape"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {
|
||||
"ˇothe quick brown
|
||||
ofox jumps over
|
||||
othe lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "4", "l", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"othe quick brown
|
||||
ofoxˇo jumps over
|
||||
otheo lazy dog"
|
||||
})
|
||||
.await;
|
||||
|
||||
// line mode
|
||||
cx.set_shared_state(indoc! {
|
||||
"ˇthe quick brown
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["shift-v", "shift-r", "o", "escape"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {
|
||||
"ˇo
|
||||
fox jumps over
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "."]).await;
|
||||
deterministic.run_until_parked();
|
||||
cx.assert_shared_state(indoc! {
|
||||
"o
|
||||
ˇo
|
||||
the lazy dog"
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
@@ -67,7 +67,8 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex
|
||||
let top_anchor = editor.scroll_manager.anchor().anchor;
|
||||
|
||||
editor.change_selections(None, cx, |s| {
|
||||
s.move_heads_with(|map, head, goal| {
|
||||
s.move_with(|map, selection| {
|
||||
let head = selection.head();
|
||||
let top = top_anchor.to_display_point(map);
|
||||
let min_row = top.row() + VERTICAL_SCROLL_MARGIN as u32;
|
||||
let max_row = top.row() + visible_rows - VERTICAL_SCROLL_MARGIN as u32 - 1;
|
||||
@@ -79,7 +80,11 @@ fn scroll_editor(editor: &mut Editor, amount: &ScrollAmount, cx: &mut ViewContex
|
||||
} else {
|
||||
head
|
||||
};
|
||||
(new_head, goal)
|
||||
if selection.is_empty() {
|
||||
selection.collapse_to(new_head, selection.goal)
|
||||
} else {
|
||||
selection.set_head(new_head, selection.goal)
|
||||
};
|
||||
})
|
||||
});
|
||||
}
|
||||
@@ -90,12 +95,35 @@ mod test {
|
||||
use crate::{state::Mode, test::VimTestContext};
|
||||
use gpui::geometry::vector::vec2f;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_scroll(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = VimTestContext::new(cx, true).await;
|
||||
|
||||
cx.set_state(indoc! {"ˇa\nb\nc\nd\ne\n"}, Mode::Normal);
|
||||
let window = cx.window;
|
||||
let line_height =
|
||||
cx.editor(|editor, cx| editor.style(cx).text.line_height(cx.font_cache()));
|
||||
window.simulate_resize(vec2f(1000., 8.0 * line_height - 1.0), &mut cx);
|
||||
|
||||
cx.set_state(
|
||||
indoc!(
|
||||
"ˇone
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
six
|
||||
seven
|
||||
eight
|
||||
nine
|
||||
ten
|
||||
eleven
|
||||
twelve
|
||||
"
|
||||
),
|
||||
Mode::Normal,
|
||||
);
|
||||
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
|
||||
@@ -112,5 +140,33 @@ mod test {
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.))
|
||||
});
|
||||
|
||||
// does not select in normal mode
|
||||
cx.simulate_keystrokes(["g", "g"]);
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
|
||||
});
|
||||
cx.simulate_keystrokes(["ctrl-d"]);
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0));
|
||||
assert_eq!(
|
||||
editor.selections.newest(cx).range(),
|
||||
Point::new(5, 0)..Point::new(5, 0)
|
||||
)
|
||||
});
|
||||
|
||||
// does select in visual mode
|
||||
cx.simulate_keystrokes(["g", "g"]);
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 0.))
|
||||
});
|
||||
cx.simulate_keystrokes(["v", "ctrl-d"]);
|
||||
cx.update_editor(|editor, cx| {
|
||||
assert_eq!(editor.snapshot(cx).scroll_position(), vec2f(0., 2.0));
|
||||
assert_eq!(
|
||||
editor.selections.newest(cx).range(),
|
||||
Point::new(0, 0)..Point::new(5, 1)
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,34 @@
|
||||
use gpui::WindowContext;
|
||||
use editor::movement;
|
||||
use gpui::{actions, AppContext, WindowContext};
|
||||
use language::Point;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::{motion::Motion, utils::copy_selections_content, Mode, Vim};
|
||||
|
||||
pub fn substitute(vim: &mut Vim, count: Option<usize>, cx: &mut WindowContext) {
|
||||
let line_mode = vim.state().mode == Mode::VisualLine;
|
||||
actions!(vim, [Substitute, SubstituteLine]);
|
||||
|
||||
pub(crate) fn init(cx: &mut AppContext) {
|
||||
cx.add_action(|_: &mut Workspace, _: &Substitute, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
let count = vim.pop_number_operator(cx);
|
||||
substitute(vim, count, vim.state().mode == Mode::VisualLine, cx);
|
||||
})
|
||||
});
|
||||
|
||||
cx.add_action(|_: &mut Workspace, _: &SubstituteLine, cx| {
|
||||
Vim::update(cx, |vim, cx| {
|
||||
vim.start_recording(cx);
|
||||
if matches!(vim.state().mode, Mode::VisualBlock | Mode::Visual) {
|
||||
vim.switch_mode(Mode::VisualLine, false, cx)
|
||||
}
|
||||
let count = vim.pop_number_operator(cx);
|
||||
substitute(vim, count, true, cx)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn substitute(vim: &mut Vim, count: Option<usize>, line_mode: bool, cx: &mut WindowContext) {
|
||||
vim.update_active_editor(cx, |editor, cx| {
|
||||
editor.set_clip_at_line_ends(false, cx);
|
||||
editor.transact(cx, |editor, cx| {
|
||||
@@ -14,6 +38,11 @@ pub fn substitute(vim: &mut Vim, count: Option<usize>, cx: &mut WindowContext) {
|
||||
Motion::Right.expand_selection(map, selection, count, true);
|
||||
}
|
||||
if line_mode {
|
||||
// in Visual mode when the selection contains the newline at the end
|
||||
// of the line, we should exclude it.
|
||||
if !selection.is_empty() && selection.end.column() == 0 {
|
||||
selection.end = movement::left(map, selection.end);
|
||||
}
|
||||
Motion::CurrentLine.expand_selection(map, selection, None, false);
|
||||
if let Some((point, _)) = (Motion::FirstNonWhitespace {
|
||||
display_lines: false,
|
||||
@@ -166,4 +195,68 @@ mod test {
|
||||
the laˇzy dog"})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_substitute_line(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
let initial_state = indoc! {"
|
||||
The quick brown
|
||||
fox juˇmps over
|
||||
the lazy dog
|
||||
"};
|
||||
|
||||
// normal mode
|
||||
cx.set_shared_state(initial_state).await;
|
||||
cx.simulate_shared_keystrokes(["shift-s", "o"]).await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
The quick brown
|
||||
oˇ
|
||||
the lazy dog
|
||||
"})
|
||||
.await;
|
||||
|
||||
// visual mode
|
||||
cx.set_shared_state(initial_state).await;
|
||||
cx.simulate_shared_keystrokes(["v", "k", "shift-s", "o"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
oˇ
|
||||
the lazy dog
|
||||
"})
|
||||
.await;
|
||||
|
||||
// visual block mode
|
||||
cx.set_shared_state(initial_state).await;
|
||||
cx.simulate_shared_keystrokes(["ctrl-v", "j", "shift-s", "o"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
The quick brown
|
||||
oˇ
|
||||
"})
|
||||
.await;
|
||||
|
||||
// visual mode including newline
|
||||
cx.set_shared_state(initial_state).await;
|
||||
cx.simulate_shared_keystrokes(["v", "$", "shift-s", "o"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
The quick brown
|
||||
oˇ
|
||||
the lazy dog
|
||||
"})
|
||||
.await;
|
||||
|
||||
// indentation
|
||||
cx.set_neovim_option("shiftwidth=4").await;
|
||||
cx.set_shared_state(initial_state).await;
|
||||
cx.simulate_shared_keystrokes([">", ">", "shift-s", "o"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
The quick brown
|
||||
oˇ
|
||||
the lazy dog
|
||||
"})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use editor::{char_kind, display_map::DisplaySnapshot, movement, Bias, CharKind, DisplayPoint};
|
||||
use editor::{
|
||||
char_kind,
|
||||
display_map::DisplaySnapshot,
|
||||
movement::{self, FindRange},
|
||||
Bias, CharKind, DisplayPoint,
|
||||
};
|
||||
use gpui::{actions, impl_actions, AppContext, WindowContext};
|
||||
use language::Selection;
|
||||
use serde::Deserialize;
|
||||
@@ -177,18 +182,22 @@ fn in_word(
|
||||
ignore_punctuation: bool,
|
||||
) -> Option<Range<DisplayPoint>> {
|
||||
// Use motion::right so that we consider the character under the cursor when looking for the start
|
||||
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
|
||||
let start = movement::find_preceding_boundary_in_line(
|
||||
let scope = map
|
||||
.buffer_snapshot
|
||||
.language_scope_at(relative_to.to_point(map));
|
||||
let start = movement::find_preceding_boundary(
|
||||
map,
|
||||
right(map, relative_to, 1),
|
||||
movement::FindRange::SingleLine,
|
||||
|left, right| {
|
||||
char_kind(language, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
|
||||
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
|
||||
},
|
||||
);
|
||||
let end = movement::find_boundary_in_line(map, relative_to, |left, right| {
|
||||
char_kind(language, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
|
||||
|
||||
let end = movement::find_boundary(map, relative_to, FindRange::SingleLine, |left, right| {
|
||||
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
|
||||
});
|
||||
|
||||
Some(start..end)
|
||||
@@ -211,11 +220,13 @@ fn around_word(
|
||||
relative_to: DisplayPoint,
|
||||
ignore_punctuation: bool,
|
||||
) -> Option<Range<DisplayPoint>> {
|
||||
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
|
||||
let scope = map
|
||||
.buffer_snapshot
|
||||
.language_scope_at(relative_to.to_point(map));
|
||||
let in_word = map
|
||||
.chars_at(relative_to)
|
||||
.next()
|
||||
.map(|(c, _)| char_kind(language, c) != CharKind::Whitespace)
|
||||
.map(|(c, _)| char_kind(&scope, c) != CharKind::Whitespace)
|
||||
.unwrap_or(false);
|
||||
|
||||
if in_word {
|
||||
@@ -239,21 +250,24 @@ fn around_next_word(
|
||||
relative_to: DisplayPoint,
|
||||
ignore_punctuation: bool,
|
||||
) -> Option<Range<DisplayPoint>> {
|
||||
let language = map.buffer_snapshot.language_at(relative_to.to_point(map));
|
||||
let scope = map
|
||||
.buffer_snapshot
|
||||
.language_scope_at(relative_to.to_point(map));
|
||||
// Get the start of the word
|
||||
let start = movement::find_preceding_boundary_in_line(
|
||||
let start = movement::find_preceding_boundary(
|
||||
map,
|
||||
right(map, relative_to, 1),
|
||||
FindRange::SingleLine,
|
||||
|left, right| {
|
||||
char_kind(language, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(language, right).coerce_punctuation(ignore_punctuation)
|
||||
char_kind(&scope, left).coerce_punctuation(ignore_punctuation)
|
||||
!= char_kind(&scope, right).coerce_punctuation(ignore_punctuation)
|
||||
},
|
||||
);
|
||||
|
||||
let mut word_found = false;
|
||||
let end = movement::find_boundary(map, relative_to, |left, right| {
|
||||
let left_kind = char_kind(language, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(language, right).coerce_punctuation(ignore_punctuation);
|
||||
let end = movement::find_boundary(map, relative_to, FindRange::MultiLine, |left, right| {
|
||||
let left_kind = char_kind(&scope, left).coerce_punctuation(ignore_punctuation);
|
||||
let right_kind = char_kind(&scope, right).coerce_punctuation(ignore_punctuation);
|
||||
|
||||
let found = (word_found && left_kind != right_kind) || right == '\n' && left == '\n';
|
||||
|
||||
@@ -566,11 +580,18 @@ mod test {
|
||||
async fn test_visual_word_object(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
cx.set_shared_state("The quick ˇbrown\nfox").await;
|
||||
/*
|
||||
cx.set_shared_state("The quick ˇbrown\nfox").await;
|
||||
cx.simulate_shared_keystrokes(["v"]).await;
|
||||
cx.assert_shared_state("The quick «bˇ»rown\nfox").await;
|
||||
cx.simulate_shared_keystrokes(["i", "w"]).await;
|
||||
cx.assert_shared_state("The quick «brownˇ»\nfox").await;
|
||||
*/
|
||||
cx.set_shared_state("The quick brown\nˇ\nfox").await;
|
||||
cx.simulate_shared_keystrokes(["v"]).await;
|
||||
cx.assert_shared_state("The quick «bˇ»rown\nfox").await;
|
||||
cx.assert_shared_state("The quick brown\n«\nˇ»fox").await;
|
||||
cx.simulate_shared_keystrokes(["i", "w"]).await;
|
||||
cx.assert_shared_state("The quick «brownˇ»\nfox").await;
|
||||
cx.assert_shared_state("The quick brown\n«\nˇ»fox").await;
|
||||
|
||||
cx.assert_binding_matches_all(["v", "i", "w"], WORD_LOCATIONS)
|
||||
.await;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use gpui::keymap_matcher::KeymapContext;
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use gpui::{keymap_matcher::KeymapContext, Action};
|
||||
use language::CursorShape;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use workspace::searchable::Direction;
|
||||
@@ -48,10 +50,61 @@ pub struct EditorState {
|
||||
pub operator_stack: Vec<Operator>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug)]
|
||||
pub enum RecordedSelection {
|
||||
#[default]
|
||||
None,
|
||||
Visual {
|
||||
rows: u32,
|
||||
cols: u32,
|
||||
},
|
||||
SingleLine {
|
||||
cols: u32,
|
||||
},
|
||||
VisualBlock {
|
||||
rows: u32,
|
||||
cols: u32,
|
||||
},
|
||||
VisualLine {
|
||||
rows: u32,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct WorkspaceState {
|
||||
pub search: SearchState,
|
||||
pub last_find: Option<Motion>,
|
||||
|
||||
pub recording: bool,
|
||||
pub stop_recording_after_next_action: bool,
|
||||
pub replaying: bool,
|
||||
pub recorded_count: Option<usize>,
|
||||
pub recorded_actions: Vec<ReplayableAction>,
|
||||
pub recorded_selection: RecordedSelection,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ReplayableAction {
|
||||
Action(Box<dyn Action>),
|
||||
Insertion {
|
||||
text: Arc<str>,
|
||||
utf16_range_to_replace: Option<Range<isize>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Clone for ReplayableAction {
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Action(action) => Self::Action(action.boxed_clone()),
|
||||
Self::Insertion {
|
||||
text,
|
||||
utf16_range_to_replace,
|
||||
} => Self::Insertion {
|
||||
text: text.clone(),
|
||||
utf16_range_to_replace: utf16_range_to_replace.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -286,6 +286,55 @@ async fn test_word_characters(cx: &mut gpui::TestAppContext) {
|
||||
)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_join_lines(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
|
||||
cx.set_shared_state(indoc! {"
|
||||
ˇone
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
six
|
||||
"})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["shift-j"]).await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
oneˇ two
|
||||
three
|
||||
four
|
||||
five
|
||||
six
|
||||
"})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["3", "shift-j"]).await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
one two threeˇ four
|
||||
five
|
||||
six
|
||||
"})
|
||||
.await;
|
||||
|
||||
cx.set_shared_state(indoc! {"
|
||||
ˇone
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
six
|
||||
"})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "v", "3", "j", "shift-j"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
one
|
||||
two three fourˇ five
|
||||
six
|
||||
"})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) {
|
||||
let mut cx = NeovimBackedTestContext::new(cx).await;
|
||||
@@ -431,6 +480,31 @@ async fn test_wrapped_lines(cx: &mut gpui::TestAppContext) {
|
||||
twelve char
|
||||
"})
|
||||
.await;
|
||||
|
||||
// line wraps as:
|
||||
// fourteen ch
|
||||
// ar
|
||||
// fourteen ch
|
||||
// ar
|
||||
cx.set_shared_state(indoc! { "
|
||||
fourteen chaˇr
|
||||
fourteen char
|
||||
"})
|
||||
.await;
|
||||
|
||||
cx.simulate_shared_keystrokes(["d", "i", "w"]).await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
fourteenˇ•
|
||||
fourteen char
|
||||
"})
|
||||
.await;
|
||||
cx.simulate_shared_keystrokes(["j", "shift-f", "e", "f", "r"])
|
||||
.await;
|
||||
cx.assert_shared_state(indoc! {"
|
||||
fourteen•
|
||||
fourteen chaˇr
|
||||
"})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -153,6 +153,7 @@ impl<'a> NeovimBackedTestContext<'a> {
|
||||
}
|
||||
|
||||
pub async fn assert_shared_state(&mut self, marked_text: &str) {
|
||||
let marked_text = marked_text.replace("•", " ");
|
||||
let neovim = self.neovim_state().await;
|
||||
let editor = self.editor_state();
|
||||
if neovim == marked_text && neovim == editor {
|
||||
@@ -184,9 +185,9 @@ impl<'a> NeovimBackedTestContext<'a> {
|
||||
message,
|
||||
initial_state,
|
||||
self.recent_keystrokes.join(" "),
|
||||
marked_text,
|
||||
neovim,
|
||||
editor
|
||||
marked_text.replace(" \n", "•\n"),
|
||||
neovim.replace(" \n", "•\n"),
|
||||
editor.replace(" \n", "•\n")
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user