Compare commits

..

3 Commits

Author SHA1 Message Date
KCaverly
a29b70552f working inline with semantic index, this should never merge 2023-09-29 14:02:41 -04:00
KCaverly
3cc499ee7c Merge branch 'main' of github.com:zed-industries/zed into semantic_codegen 2023-09-27 10:39:45 -04:00
KCaverly
e21a92284f bring in useful pieces from better context work 2023-09-27 10:17:32 -04:00
273 changed files with 7206 additions and 16761 deletions

View File

@@ -2,4 +2,11 @@
Release Notes:
- N/A
or
- (Added|Fixed|Improved) ... ([#<public_issue_number_if_exists>](https://github.com/zed-industries/community/issues/<public_issue_number_if_exists>)).
If the release notes are only intended for a specific release channel only, add `(<release_channel>-only)` to the end of the release note line.
These will be removed by the person making the release.

View File

@@ -6,8 +6,8 @@ jobs:
discord_release:
runs-on: ubuntu-latest
steps:
- name: Get release URL
id: get-release-url
- name: Get appropriate URL
id: get-appropriate-url
run: |
if [ "${{ github.event.release.prerelease }}" == "true" ]; then
URL="https://zed.dev/releases/preview/latest"
@@ -15,19 +15,14 @@ jobs:
URL="https://zed.dev/releases/stable/latest"
fi
echo "::set-output name=URL::$URL"
- name: Get content
uses: 2428392/gh-truncate-string-action@v1.2.0
id: get-content
with:
stringToTruncate: |
📣 Zed ${{ github.event.release.tag_name }} was just released!
Restart your Zed or head to ${{ steps.get-release-url.outputs.URL }} to grab it.
${{ github.event.release.body }}
maxLength: 2000
- name: Discord Webhook Action
uses: tsickert/discord-webhook@v5.3.0
with:
webhook-url: ${{ secrets.DISCORD_WEBHOOK_URL }}
content: ${{ steps.get-content.outputs.string }}
content: |
📣 Zed ${{ github.event.release.tag_name }} was just released!
Restart your Zed or head to ${{ steps.get-appropriate-url.outputs.URL }} to grab it.
${{ github.event.release.body }}

707
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -64,7 +64,6 @@ members = [
"crates/sqlez",
"crates/sqlez_macros",
"crates/feature_flags",
"crates/rich_text",
"crates/storybook",
"crates/sum_tree",
"crates/terminal",
@@ -106,14 +105,12 @@ rand = { version = "0.8.5" }
refineable = { path = "./crates/refineable" }
regex = { version = "1.5" }
rust-embed = { version = "8.0", features = ["include-exclude"] }
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
schemars = { version = "0.8" }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
smallvec = { version = "1.6", features = ["union"] }
smol = { version = "1.2" }
sysinfo = "0.29.10"
tempdir = { version = "0.3.7" }
thiserror = { version = "1.0.29" }
time = { version = "0.3", features = ["serde", "serde-well-known"] }
@@ -122,7 +119,6 @@ tree-sitter = "0.20"
unindent = { version = "0.1.7" }
pretty_assertions = "1.3.0"
git2 = { version = "0.15", default-features = false}
uuid = { version = "1.1.2", features = ["v4"] }
tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "1b0321ee85701d5036c334a6f04761cdc672e64c" }
tree-sitter-c = "0.20.1"

View File

@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2
FROM rust:1.73-bullseye as builder
FROM rust:1.72-bullseye as builder
WORKDIR app
COPY . .

View File

@@ -1,4 +1,4 @@
web: cd ../zed.dev && PORT=3000 npm run dev
collab: cd crates/collab && RUST_LOG=${RUST_LOG:-collab=info} cargo run serve
web: cd ../zed.dev && PORT=3000 npx vercel dev
collab: cd crates/collab && cargo run serve
livekit: livekit-server --dev
postgrest: postgrest crates/collab/admin_api.conf

View File

@@ -83,7 +83,9 @@ foreman start
If you want to run Zed pointed at the local servers, you can run:
```
script/zed-local
script/zed-with-local-servers
# or...
script/zed-with-local-servers --release
```
### Dump element JSON

View File

@@ -250,7 +250,6 @@
"bindings": {
"escape": "project_search::ToggleFocus",
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
@@ -275,7 +274,6 @@
"bindings": {
"escape": "project_search::ToggleFocus",
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
@@ -287,7 +285,6 @@
"cmd-f": "project_search::ToggleFocus",
"cmd-g": "search::SelectNextMatch",
"cmd-shift-g": "search::SelectPrevMatch",
"cmd-shift-h": "search::ToggleReplace",
"alt-enter": "search::SelectAllMatches",
"alt-cmd-c": "search::ToggleCaseSensitive",
"alt-cmd-w": "search::ToggleWholeWord",

View File

@@ -95,7 +95,6 @@
}
],
"ctrl-o": "pane::GoBack",
"ctrl-i": "pane::GoForward",
"ctrl-]": "editor::GoToDefinition",
"escape": [
"vim::SwitchMode",
@@ -146,7 +145,6 @@
"g shift-s": "project_symbols::Toggle",
"g .": "editor::ToggleCodeActions", // zed specific
"g shift-a": "editor::FindAllReferences", // zed specific
"g space": "editor::OpenExcerpts", // zed specific
"g *": [
"vim::MoveToNext",
{
@@ -408,7 +406,6 @@
"vim::PushOperator",
"Yank"
],
"shift-y": "vim::YankLine",
"i": "vim::InsertBefore",
"shift-i": "vim::InsertFirstNonWhitespace",
"a": "vim::InsertAfter",
@@ -418,8 +415,6 @@
"o": "vim::InsertLineBelow",
"shift-o": "vim::InsertLineAbove",
"~": "vim::ChangeCase",
"ctrl-a": "vim::Increment",
"ctrl-x": "vim::Decrement",
"p": "vim::Paste",
"shift-p": [
"vim::Paste",
@@ -531,20 +526,6 @@
"shift-r": "vim::SubstituteLine",
"c": "vim::Substitute",
"~": "vim::ChangeCase",
"ctrl-a": "vim::Increment",
"ctrl-x": "vim::Decrement",
"g ctrl-a": [
"vim::Increment",
{
"step": true
}
],
"g ctrl-x": [
"vim::Decrement",
{
"step": true
}
],
"shift-i": "vim::InsertBefore",
"shift-a": "vim::InsertAfter",
"shift-j": "vim::JoinLines",
@@ -585,16 +566,11 @@
}
},
{
"context": "Editor && vim_mode == insert",
"context": "Editor && vim_mode == insert && !menu",
"bindings": {
"escape": "vim::NormalBefore",
"ctrl-c": "vim::NormalBefore",
"ctrl-[": "vim::NormalBefore",
"ctrl-x ctrl-o": "editor::ShowCompletions",
"ctrl-x ctrl-a": "assistant::InlineAssist", // zed specific
"ctrl-x ctrl-c": "copilot::Suggest", // zed specific
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
"ctrl-x ctrl-z": "editor::Cancel"
"ctrl-[": "vim::NormalBefore"
}
},
{

View File

@@ -76,7 +76,7 @@
// Settings related to calls in Zed
"calls": {
// Join calls with the microphone muted by default
"mute_on_join": false
"mute_on_join": true
},
// Scrollbar related settings
"scrollbar": {
@@ -227,11 +227,6 @@
},
// Automatically update Zed
"auto_update": true,
// Diagnostics configuration.
"diagnostics": {
// Whether to show warnings or not by default.
"include_warnings": true
},
// Git gutter behavior configuration.
"git": {
// Control whether the git gutter is shown. May take 2 values:
@@ -361,7 +356,7 @@
".venv",
"venv"
],
// Can also be 'csh', 'fish', and `nushell`
// Can also be 'csh' and 'fish'
"activate_script": "default"
}
}
@@ -375,28 +370,28 @@
},
// Difference settings for semantic_index
"semantic_index": {
"enabled": true
"enabled": false
},
// Settings specific to our elixir integration
"elixir": {
// Change the LSP zed uses for elixir.
// Set Zed to use the experimental Next LS LSP server.
// Note that changing this setting requires a restart of Zed
// to take effect.
//
// May take 3 values:
// 1. Use the standard ElixirLS, this is the default
// "lsp": "elixir_ls"
// 2. Use the experimental NextLs
// "lsp": "next_ls",
// 3. Use a language server installed locally on your machine:
// "lsp": {
// 1. Use the standard elixir-ls LSP server
// "next": "off"
// 2. Use a bundled version of the next Next LS LSP server
// "next": "on",
// 3. Use a local build of the next Next LS LSP server:
// "next": {
// "local": {
// "path": "~/next-ls/bin/start",
// "arguments": ["--stdio"]
// }
// },
//
"lsp": "elixir_ls"
"next": "off"
},
// Different settings for specific languages.
"languages": {

View File

@@ -11,6 +11,7 @@ doctest = false
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
project = { path = "../project" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
@@ -27,8 +28,9 @@ log.workspace = true
parse_duration = "2.1.1"
tiktoken-rs = "0.5.0"
matrixmultiply = "0.3.7"
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
erased-serde = "0.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }

View File

@@ -1,2 +1,50 @@
pub mod completion;
pub mod embedding;
pub mod function_calling;
pub mod skills;
use core::fmt;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

View File

@@ -1,3 +1,4 @@
use crate::{OpenAIUsage, RequestMessage, Role};
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
@@ -6,48 +7,10 @@ use futures::{
use gpui::executor::Background;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::{
fmt::{self, Display},
io,
sync::Arc,
};
use std::{io, sync::Arc};
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAIRequest {
pub model: String,
@@ -61,13 +24,6 @@ pub struct ResponseMessage {
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAIUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,

View File

@@ -0,0 +1,176 @@
use crate::{OpenAIUsage, RequestMessage, Role};
use anyhow::anyhow;
use erased_serde::serialize_trait_object;
use futures::AsyncReadExt;
use isahc::{http::StatusCode, Request, RequestExt};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
pub trait OpenAIFunction: erased_serde::Serialize {
fn name(&self) -> String;
fn description(&self) -> String;
fn system_prompt(&self) -> String;
fn parameters(&self) -> serde_json::Value;
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String>;
}
serialize_trait_object!(OpenAIFunction);
#[derive(Serialize)]
struct OpenAIFunctionCallingRequest {
model: String,
messages: Vec<RequestMessage>,
functions: Vec<Box<dyn OpenAIFunction>>,
temperature: f32,
}
#[derive(Debug, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
impl FunctionCall {
fn arguments(&self) -> anyhow::Result<serde_json::Value> {
serde_json::from_str(&self.arguments).map_err(|err| anyhow!(err))
}
}
#[derive(Debug, Deserialize)]
pub struct FunctionCallingMessage {
pub role: Option<Role>,
pub content: Option<String>,
pub function_call: FunctionCall,
}
#[derive(Debug, Deserialize)]
pub struct FunctionCallingChoice {
pub message: FunctionCallingMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct OpenAIFunctionCallingResponse {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<FunctionCallingChoice>,
pub usage: OpenAIUsage,
}
impl OpenAIFunctionCallingResponse {
fn get_details(&self) -> anyhow::Result<FunctionCallDetails> {
if let Some(choice) = self.choices.first() {
let name = choice.message.function_call.name.clone();
let arguments = choice.message.function_call.arguments()?;
Ok(FunctionCallDetails {
name,
arguments,
message: choice.message.content.clone(),
})
} else {
Err(anyhow!("no function call details available"))
}
}
}
#[derive(Debug)]
pub struct FunctionCallDetails {
pub name: String, // name of function to call
pub message: Option<String>, // message if provided
pub arguments: serde_json::Value, // json object respresenting provided arguments
}
#[derive(Clone)]
pub struct OpenAIFunctionCallingProvider {
api_key: String,
}
impl OpenAIFunctionCallingProvider {
pub fn new(api_key: String) -> Self {
Self { api_key }
}
fn generate_system_message(
&self,
messages: &Vec<RequestMessage>,
functions: &Vec<Box<dyn OpenAIFunction>>,
) -> RequestMessage {
let mut system_message = messages
.iter()
.filter_map(|message| {
if message.role == Role::System {
Some(message.content.clone())
} else {
None
}
})
.collect::<Vec<String>>()
.join("\n");
let function_string = functions
.iter()
.map(|function| format!("'{}'", function.name()))
.collect::<Vec<String>>()
.join(",");
writeln!(
system_message,
"You have access to the following functions: {function_string} you MUST return a function calling response using at one of the below functions."
)
.unwrap();
for function in functions {
writeln!(system_message, "\n{}", function.system_prompt()).unwrap();
}
RequestMessage {
role: Role::System,
content: system_message,
}
}
pub async fn complete(
&self,
model: String,
mut messages: Vec<RequestMessage>,
functions: Vec<Box<dyn OpenAIFunction>>,
) -> anyhow::Result<FunctionCallDetails> {
// TODO: Rename all this.
let mut system_message = vec![self.generate_system_message(&messages, &functions)];
messages.retain(|message| message.role != Role::System);
system_message.extend(messages);
// Lower temperature values, result in less randomness,
// this is helping keep the function calling consistent
let request = OpenAIFunctionCallingRequest {
model,
messages: system_message,
functions,
temperature: 0.0,
};
let json_data = serde_json::to_string(&request)?;
println!("\nREQUEST: {:?}\n", &json_data);
let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.body(json_data)?
.send_async()
.await?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
println!("\nRESPONSE: {:?}\n", &body);
match response.status() {
StatusCode::OK => {
let response_data: OpenAIFunctionCallingResponse = serde_json::from_str(&body)?;
response_data.get_details()
}
_ => Err(anyhow!("open ai function calling failed: {:?}", body)),
}
}
}

50
crates/ai/src/skills.rs Normal file
View File

@@ -0,0 +1,50 @@
use crate::function_calling::OpenAIFunction;
use gpui::{AppContext, ModelHandle};
use project::Project;
use serde::{Serialize, Serializer};
use serde_json::json;
pub struct RewritePrompt;
impl Serialize for RewritePrompt {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
json!({"name": self.name(),
"description": self.description(),
"parameters": self.parameters()})
.serialize(serializer)
}
}
impl RewritePrompt {
pub fn load() -> Self {
Self {}
}
}
impl OpenAIFunction for RewritePrompt {
fn name(&self) -> String {
"rewrite_prompt".to_string()
}
fn description(&self) -> String {
"Rewrite prompt given prompt from user".to_string()
}
fn system_prompt(&self) -> String {
"'rewrite_prompt':
If all information is available in the above prompt, and you need no further information.
Rewrite the entire prompt to clarify what should be generated, do not actually complete the users request.
Assume this rewritten message will be passed to another completion agent, to fulfill the users request.".to_string()
}
fn parameters(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"prompt": {}
}
})
}
fn complete(&self, arguments: serde_json::Value) -> anyhow::Result<String> {
Ok(arguments.get("prompt").unwrap().to_string())
}
}

View File

@@ -21,8 +21,10 @@ search = { path = "../search" }
settings = { path = "../settings" }
theme = { path = "../theme" }
util = { path = "../util" }
uuid = { version = "1.1.2", features = ["v4"] }
workspace = { path = "../workspace" }
uuid.workspace = true
semantic_index = { path = "../semantic_index" }
project = { path = "../project" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }

View File

@@ -4,7 +4,7 @@ mod codegen;
mod prompts;
mod streaming_diff;
use ai::completion::Role;
use ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;

View File

@@ -1,13 +1,16 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
prompts::generate_content_prompt,
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
use ai::completion::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
use ai::{
completion::{stream_completion, OpenAICompletionProvider, OpenAIRequest, OPENAI_API_URL},
function_calling::OpenAIFunctionCallingProvider,
skills::RewritePrompt,
};
use ai::{function_calling::OpenAIFunction, RequestMessage};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -17,7 +20,7 @@ use editor::{
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
},
scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset, ToPoint,
Anchor, Editor, MoveDown, MoveUp, MultiBufferSnapshot, ToOffset,
};
use fs::Fs;
use futures::StreamExt;
@@ -35,7 +38,9 @@ use gpui::{
WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
use project::Project;
use search::BufferSearchBar;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
@@ -145,6 +150,8 @@ pub struct AssistantPanel {
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
semantic_index: ModelHandle<SemanticIndex>,
project: ModelHandle<Project>,
}
impl AssistantPanel {
@@ -154,6 +161,7 @@ impl AssistantPanel {
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
) -> Task<Result<ViewHandle<Self>>> {
let index = cx.read(|cx| SemanticIndex::global(cx).unwrap());
cx.spawn(|mut cx| async move {
let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
let saved_conversations = SavedConversationMetadata::list(fs.clone())
@@ -191,6 +199,9 @@ impl AssistantPanel {
toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
toolbar
});
let project = workspace.project().clone();
let mut this = Self {
workspace: workspace_handle,
active_editor_index: Default::default(),
@@ -215,6 +226,8 @@ impl AssistantPanel {
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
semantic_index: index,
project,
};
let mut old_dock_position = this.position(cx);
@@ -274,42 +287,35 @@ impl AssistantPanel {
return;
};
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id() != selection.end.excerpt_id() {
return;
}
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
// Extend the selection to the start and the end of the line.
let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
if point_selection.end > point_selection.start {
point_selection.start.column = 0;
// If the selection ends at the start of the line, we don't want to include it.
if point_selection.end.column == 0 {
point_selection.end.row -= 1;
}
point_selection.end.column = snapshot.line_len(point_selection.end.row);
}
let codegen_kind = if point_selection.start == point_selection.end {
let provider = Arc::new(OpenAICompletionProvider::new(
api_key.clone(),
cx.background().clone(),
));
let fc_provider = OpenAIFunctionCallingProvider::new(api_key);
let selection = editor.read(cx).selections.newest_anchor().clone();
let codegen_kind = if editor.read(cx).selections.newest::<usize>(cx).is_empty() {
CodegenKind::Generate {
position: snapshot.anchor_after(point_selection.start),
position: selection.start,
}
} else {
CodegenKind::Transform {
range: snapshot.anchor_before(point_selection.start)
..snapshot.anchor_after(point_selection.end),
range: selection.start..selection.end,
}
};
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
api_key,
cx.background().clone(),
));
let project = self.project.clone();
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
Codegen::new(
editor.read(cx).buffer().clone(),
codegen_kind,
provider,
fc_provider,
cx,
project.clone(),
)
});
let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
@@ -333,7 +339,7 @@ impl AssistantPanel {
editor.insert_blocks(
[BlockProperties {
style: BlockStyle::Flex,
position: snapshot.anchor_before(point_selection.head()),
position: selection.head().bias_left(&snapshot),
height: 2,
render: Arc::new({
let inline_assistant = inline_assistant.clone();
@@ -560,26 +566,25 @@ impl AssistantPanel {
self.inline_prompt_history.pop_front();
}
let codegen = pending_assist.codegen.clone();
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = codegen.read(cx).range();
let start = snapshot.point_to_buffer_offset(range.start);
let end = snapshot.point_to_buffer_offset(range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
(start_buffer.clone(), start_buffer_offset..end_buffer_offset)
} else {
self.finish_inline_assist(inline_assist_id, false, cx);
return;
}
let multi_buffer = editor.read(cx).buffer().read(cx);
let multi_buffer_snapshot = multi_buffer.snapshot(cx);
let snapshot = if multi_buffer.is_singleton() {
multi_buffer.as_singleton().unwrap().read(cx).snapshot()
} else {
self.finish_inline_assist(inline_assist_id, false, cx);
return;
};
let language = buffer.language_at(range.start);
let range = pending_assist.codegen.read(cx).range();
let language_range = snapshot.anchor_at(
range.start.to_offset(&multi_buffer_snapshot),
language::Bias::Left,
)
..snapshot.anchor_at(
range.end.to_offset(&multi_buffer_snapshot),
language::Bias::Right,
);
let language = snapshot.language_at(language_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@@ -589,45 +594,77 @@ impl AssistantPanel {
} else {
None
};
let language_name = language_name.as_deref();
let codegen_kind = codegen.read(cx).kind().clone();
let user_prompt = user_prompt.to_string();
let codegen_kind = pending_assist.codegen.read(cx).kind().clone();
let index = self.semantic_index.clone();
let mut messages = Vec::new();
let mut model = settings::get::<AssistantSettings>(cx)
.default_open_ai_model
.clone();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
messages.extend(
conversation
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
model = conversation.model.clone();
}
let prompt = cx.background().spawn(async move {
let language_name = language_name.as_deref();
generate_content_prompt(user_prompt, language_name, &buffer, range, codegen_kind)
pending_assist.codegen.update(cx, |codegen, cx| {
codegen.start(
user_prompt.to_string(),
cx,
language_name,
snapshot,
language_range.clone(),
codegen_kind.clone(),
index,
)
});
cx.spawn(|_, mut cx| async move {
let prompt = prompt.await;
// let api_key = self.api_key.as_ref().clone().into_inner().clone().unwrap();
// let function_provider = OpenAIFunctionCallingProvider::new(api_key);
messages.push(RequestMessage {
role: Role::User,
content: prompt,
});
let request = OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
};
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
})
.detach();
// let planning_messages = vec![RequestMessage {
// role: Role::User,
// content: planning_prompt,
// }];
// println!("GETTING HERE");
// let function_call = cx
// .spawn(|this, mut cx| async move {
// let result = function_provider
// .complete("gpt-4".to_string(), planning_messages, functions)
// .await;
// dbg!(&result);
// result
// })
// .detach();
// let function_name = function_call.name.as_str();
// let prompt = match function_name {
// "rewrite_prompt" => {
// let user_prompt = RewritePrompt::load()
// .complete(function_call.arguments)
// .unwrap();
// generate_content_prompt(
// user_prompt.to_string(),
// language_name,
// &snapshot,
// language_range,
// cx,
// codegen_kind,
// )
// }
// _ => {
// todo!();
// }
// };
// let mut messages = Vec::new();
// let mut model = settings::get::<AssistantSettings>(cx)
// .default_open_ai_model
// .clone();
// if let Some(conversation) = conversation {
// let conversation = conversation.read(cx);
// let buffer = conversation.buffer.read(cx);
// messages.extend(
// conversation
// .messages(cx)
// .map(|message| message.to_open_ai_message(buffer)),
// );
// model = conversation.model.clone();
// }
}
fn update_highlights_for_editor(

View File

@@ -1,10 +1,22 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, OpenAIRequest};
use crate::{
prompts::{generate_codegen_planning_prompt, generate_content_prompt},
streaming_diff::{Hunk, StreamingDiff},
};
use ai::{
completion::{CompletionProvider, OpenAIRequest},
function_calling::{OpenAIFunction, OpenAIFunctionCallingProvider},
skills::RewritePrompt,
RequestMessage, Role,
};
use anyhow::Result;
use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use gpui::{BorrowAppContext, Entity, ModelContext, ModelHandle, Task};
use language::{BufferSnapshot, Rope, TransactionId};
use project::Project;
use semantic_index::{skills::RepositoryContextRetriever, SemanticIndex};
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
@@ -20,6 +32,7 @@ pub enum CodegenKind {
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
buffer: ModelHandle<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@@ -29,6 +42,7 @@ pub struct Codegen {
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
project: ModelHandle<Project>,
}
impl Entity for Codegen {
@@ -38,13 +52,31 @@ impl Entity for Codegen {
impl Codegen {
pub fn new(
buffer: ModelHandle<MultiBuffer>,
kind: CodegenKind,
mut kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
fc_provider: OpenAIFunctionCallingProvider,
cx: &mut ModelContext<Self>,
project: ModelHandle<Project>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
match &mut kind {
CodegenKind::Transform { range } => {
let mut point_range = range.to_point(&snapshot);
point_range.start.column = 0;
if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
point_range.end.column = snapshot.line_len(point_range.end.row);
}
range.start = snapshot.anchor_before(point_range.start);
range.end = snapshot.anchor_after(point_range.end);
}
CodegenKind::Generate { position } => {
*position = position.bias_right(&snapshot);
}
}
Self {
provider,
fc_provider,
buffer: buffer.clone(),
snapshot,
kind,
@@ -54,6 +86,7 @@ impl Codegen {
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
project,
}
}
@@ -95,7 +128,17 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
pub fn start(
&mut self,
prompt: String,
cx: &mut ModelContext<Self>,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<language::Anchor>,
kind: CodegenKind,
index: ModelHandle<SemanticIndex>,
) {
let language_range = range.clone();
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -109,9 +152,101 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
let messages = vec![RequestMessage {
role: Role::User,
content: prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let (planning_prompt, outline) = generate_codegen_planning_prompt(
prompt.clone(),
language_name.clone(),
&buffer,
language_range.clone(),
cx,
kind.clone(),
);
let project = self.project.clone();
self.generation = cx.spawn_weak(|this, mut cx| {
// Plan Ahead
let planning_messages = vec![RequestMessage {
role: Role::User,
content: planning_prompt,
}];
let repo_retriever = RepositoryContextRetriever::load(index, project);
let functions: Vec<Box<dyn OpenAIFunction>> = vec![
Box::new(RewritePrompt::load()),
Box::new(repo_retriever.clone()),
];
let completion_provider = self.provider.clone();
let fc_provider = self.fc_provider.clone();
let language_name = language_name.clone();
let language_name = if let Some(language_name) = language_name.clone() {
Some(language_name.to_string())
} else {
None
};
let kind = kind.clone();
async move {
let mut user_prompt = prompt.clone();
let user_prompt = if let Ok(function_call) = fc_provider
.complete("gpt-4".to_string(), planning_messages, functions)
.await
{
let function_name = function_call.name.as_str();
println!("FUNCTION NAME: {:?}", function_name);
let user_prompt = match function_name {
"rewrite_prompt" => {
let user_prompt = RewritePrompt::load()
.complete(function_call.arguments)
.unwrap();
generate_content_prompt(
user_prompt,
language_name,
outline,
kind,
vec![],
)
}
_ => {
let arguments = function_call.arguments.clone();
let snippet = repo_retriever
.complete_test(arguments, &mut cx)
.await
.unwrap();
let snippet = vec![snippet];
generate_content_prompt(prompt, language_name, outline, kind, snippet)
}
};
user_prompt
} else {
user_prompt
};
println!("{:?}", user_prompt.clone());
let messages = vec![RequestMessage {
role: Role::User,
content: user_prompt.clone(),
}];
let request = OpenAIRequest {
model: "gpt-4".to_string(),
messages: messages.clone(),
stream: true,
};
let response = completion_provider.complete(request);
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
@@ -332,315 +467,317 @@ fn strip_markdown_codeblock(
})
}
#[cfg(test)]
mod tests {
use super::*;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
use smol::future::FutureExt;
// #[cfg(test)]
// mod tests {
// use super::*;
// use futures::{
// future::BoxFuture,
// stream::{self, BoxStream},
// };
// use gpui::{executor::Deterministic, TestAppContext};
// use indoc::indoc;
// use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
// use parking_lot::Mutex;
// use rand::prelude::*;
// use settings::SettingsStore;
// use smol::future::FutureExt;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_transform_autoindent(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// let x = 0;
// for _ in 0..10 {
// x += 1;
// }
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let range = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let fc_provider = OpenAIFunctionCallingProvider::new("".to_string());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Transform { range },
// provider.clone(),
// fc_provider,
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// " let mut x = 0;\n",
// " while x < 10 {\n",
// " x += 1;\n",
// " }",
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_past_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = indoc! {"
// fn main() {
// le
// }
// "};
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 6))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "t mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
// #[gpui::test(iterations = 10)]
// async fn test_autoindent_when_generating_before_indentation(
// cx: &mut TestAppContext,
// mut rng: StdRng,
// deterministic: Arc<Deterministic>,
// ) {
// cx.set_global(cx.read(SettingsStore::test));
// cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
// let text = concat!(
// "fn main() {\n",
// " \n",
// "}\n" //
// );
// let buffer =
// cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
// let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
// let position = buffer.read_with(cx, |buffer, cx| {
// let snapshot = buffer.snapshot(cx);
// snapshot.anchor_before(Point::new(1, 2))
// });
// let provider = Arc::new(TestCompletionProvider::new());
// let codegen = cx.add_model(|cx| {
// Codegen::new(
// buffer.clone(),
// CodegenKind::Generate { position },
// provider.clone(),
// cx,
// )
// });
// codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
// let mut new_text = concat!(
// "let mut x = 0;\n",
// "while x < 10 {\n",
// " x += 1;\n",
// "}", //
// );
// while !new_text.is_empty() {
// let max_len = cmp::min(new_text.len(), 10);
// let len = rng.gen_range(1..=max_len);
// let (chunk, suffix) = new_text.split_at(len);
// provider.send_completion(chunk);
// new_text = suffix;
// deterministic.run_until_parked();
// }
// provider.finish_completion();
// deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
// assert_eq!(
// buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
// indoc! {"
// fn main() {
// let mut x = 0;
// while x < 10 {
// x += 1;
// }
// }
// "}
// );
// }
#[gpui::test]
async fn test_strip_markdown_codeblock() {
assert_eq!(
strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
// #[gpui::test]
// async fn test_strip_markdown_codeblock() {
// assert_eq!(
// strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "Lorem ipsum dolor"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "```js\nLorem ipsum dolor\n```"
// );
// assert_eq!(
// strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
// .map(|chunk| chunk.unwrap())
// .collect::<String>()
// .await,
// "``\nLorem ipsum dolor\n```"
// );
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
// fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
// stream::iter(
// text.chars()
// .collect::<Vec<_>>()
// .chunks(size)
// .map(|chunk| Ok(chunk.iter().collect::<String>()))
// .collect::<Vec<_>>(),
// )
// }
// }
struct TestCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
// struct TestCompletionProvider {
// last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
// }
impl TestCompletionProvider {
fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
// impl TestCompletionProvider {
// fn new() -> Self {
// Self {
// last_completion_tx: Mutex::new(None),
// }
// }
fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
// fn send_completion(&self, completion: impl Into<String>) {
// let mut tx = self.last_completion_tx.lock();
// tx.as_mut().unwrap().try_send(completion.into()).unwrap();
// }
fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
// fn finish_completion(&self) {
// self.last_completion_tx.lock().take().unwrap();
// }
// }
impl CompletionProvider for TestCompletionProvider {
fn complete(
&self,
_prompt: OpenAIRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
}
// impl CompletionProvider for TestCompletionProvider {
// fn complete(
// &self,
// _prompt: OpenAIRequest,
// ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// let (tx, rx) = mpsc::channel(1);
// *self.last_completion_tx.lock() = Some(tx);
// async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
// }
// }
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}
// fn rust_lang() -> Language {
// Language::new(
// LanguageConfig {
// name: "Rust".into(),
// path_suffixes: vec!["rs".to_string()],
// ..Default::default()
// },
// Some(tree_sitter_rust::language()),
// )
// .with_indents_query(
// r#"
// (call_expression) @indent
// (field_expression) @indent
// (_ "(" ")" @end) @indent
// (_ "{" "}" @end) @indent
// "#,
// )
// .unwrap()
// }
// }

View File

@@ -1,160 +1,180 @@
use crate::codegen::CodegenKind;
use gpui::{AppContext, AsyncAppContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::fmt::Write;
use std::cmp;
use std::ops::Range;
use std::{fmt::Write, iter};
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
use crate::codegen::CodegenKind;
let selected_range = selected_range.to_offset(buffer);
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = ts_matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = Vec::new();
while let Some(mat) = ts_matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
fn outline_for_prompt(
buffer: &BufferSnapshot,
range: Range<language::Anchor>,
cx: &AppContext,
) -> Option<String> {
let indent = buffer
.language_indent_size_at(0, cx)
.chars()
.collect::<String>();
let outline = buffer.outline(None)?;
let range = range.to_offset(buffer);
let mut text = String::new();
let mut items = outline.items.into_iter().peekable();
let mut intersected = false;
let mut intersection_indent = 0;
let mut extended_range = range.clone();
while let Some(item) = items.next() {
let item_range = item.range.to_offset(buffer);
if item_range.end < range.start || item_range.start > range.end {
text.extend(iter::repeat(indent.as_str()).take(item.depth));
text.push_str(&item.text);
text.push('\n');
} else {
intersected = true;
let is_terminal = items
.peek()
.map_or(true, |next_item| next_item.depth <= item.depth);
if is_terminal {
if item_range.start <= extended_range.start {
extended_range.start = item_range.start;
intersection_indent = item.depth;
}
extended_range.end = cmp::max(extended_range.end, item_range.end);
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
let name_start = item_range.start + item.name_ranges.first().unwrap().start;
let name_end = item_range.start + item.name_ranges.last().unwrap().end;
if range.start > name_end {
text.extend(iter::repeat(indent.as_str()).take(item.depth));
text.push_str(&item.text);
text.push('\n');
} else {
continue;
if name_start <= extended_range.start {
extended_range.start = item_range.start;
intersection_indent = item.depth;
}
extended_range.end = cmp::max(extended_range.end, name_end);
}
}
ts_matches.advance();
matches.push(Match { collapse, keep });
} else {
ts_matches.advance();
}
}
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
let mut matches = matches.into_iter().peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if mat.collapse.start <= next_mat.collapse.start
&& mat.collapse.end >= next_mat.collapse.end
{
matches.next().unwrap();
if intersected
&& items.peek().map_or(true, |next_item| {
next_item.range.start.to_offset(buffer) > range.end
})
{
intersected = false;
text.extend(iter::repeat(indent.as_str()).take(intersection_indent));
text.extend(buffer.text_for_range(extended_range.start..range.start));
text.push_str("<|START|");
text.extend(buffer.text_for_range(range.clone()));
if range.start != range.end {
text.push_str("|END|>");
} else {
break;
text.push_str(">");
}
text.extend(buffer.text_for_range(range.end..extended_range.end));
text.push('\n');
}
if offset > mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|START|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|END|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|START|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|END|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
Some(text)
}
pub fn generate_content_prompt(
pub fn generate_codegen_planning_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: &BufferSnapshot,
range: Range<impl ToOffset>,
range: Range<language::Anchor>,
cx: &AppContext,
kind: CodegenKind,
) -> String {
let range = range.to_offset(buffer);
) -> (String, Option<String>) {
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
} else {
writeln!(prompt, "You're an expert engineer.\n").unwrap();
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
}
let mut content = String::new();
content.extend(buffer.text_for_range(0..range.start));
if range.start == range.end {
content.push_str("<|START|>");
} else {
content.push_str("<|START|");
let outline = outline_for_prompt(buffer, range.clone(), cx);
if let Some(outline) = outline.clone() {
writeln!(
prompt,
"You're currently working inside the Zed editor on a file with the following outline:"
)
.unwrap();
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {
writeln!(prompt, "```\n{outline}\n```").unwrap();
}
}
content.extend(buffer.text_for_range(range.clone()));
if range.start != range.end {
content.push_str("|END|>");
match kind {
CodegenKind::Generate { position: _ } => {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.").unwrap();
}
CodegenKind::Transform { range: _ } => {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
content.extend(buffer.text_for_range(range.end..buffer.len()));
writeln!(
prompt,
"The file you are currently working on has the following content:"
"The user has provided the following prompt: '{user_prompt}'\n"
)
.unwrap();
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{content}\n```").unwrap();
writeln!(
prompt,
"It is your task to identify if any additional context is needed from the repository"
)
.unwrap();
(prompt, outline)
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<String>,
outline: Option<String>,
kind: CodegenKind,
snippet: Vec<String>,
) -> String {
let mut prompt = String::new();
// General Preamble
if let Some(language_name) = language_name.clone() {
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
} else {
writeln!(prompt, "```\n{content}\n```").unwrap();
writeln!(prompt, "You're an expert software engineer.\n").unwrap();
}
if snippet.len() > 0 {
writeln!(
prompt,
"Here are a few snippets from the codebase which may help: "
);
}
for snip in snippet {
writeln!(prompt, "{snip}");
}
if let Some(outline) = outline {
writeln!(
prompt,
"The file you are currently working on has the following outline:"
)
.unwrap();
if let Some(language_name) = language_name.clone() {
let language_name = language_name.to_lowercase();
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
} else {
writeln!(prompt, "```\n{outline}\n```").unwrap();
}
}
match kind {
@@ -220,37 +240,39 @@ pub(crate) mod tests {
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
.with_indents_query(
r#"
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
.with_outline_query(
r#"
(struct_item
"struct" @context
name: (_) @name) @item
(enum_item
"enum" @context
name: (_) @name) @item
(enum_variant
name: (_) @name) @item
(field_declaration
name: (_) @name) @item
(impl_item
"impl" @context
trait: (_)? @name
"for"? @context
type: (_) @name) @item
(function_item
"fn" @context
name: (_) @name) @item
(mod_item
"mod" @context
name: (_) @name) @item
"#,
)
.unwrap()
}
@@ -286,133 +308,132 @@ pub(crate) mod tests {
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_before(Point::new(1, 4)),
cx,
);
assert_eq!(
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|START|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
outline.as_deref(),
Some(indoc! {"
struct X
<|START|>a: usize
b
impl X
fn new
fn a
fn b
"})
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(8, 14)),
cx,
);
assert_eq!(
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
fn new() -> Self {
let <|START|a |END|>= 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
fn a
fn b
"})
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(6, 0))..snapshot.anchor_before(Point::new(6, 0)),
cx,
);
assert_eq!(
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
<|START|>
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
fn new
fn a
fn b
"})
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(8, 12))..snapshot.anchor_before(Point::new(13, 9)),
cx,
);
assert_eq!(
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|START|>"}
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
fn new() -> Self {
let <|START|a = 1;
let b = 2;
Self { a, b }
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
pub f|END|>n a(&self, param: bool) -> usize {
self.a
}
fn b
"})
);
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(5, 6))..snapshot.anchor_before(Point::new(12, 0)),
cx,
);
assert_eq!(
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|START|>struct X {
a: usize,
b: usize,
}
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X<|START| {
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
|END|>
fn a
fn b
"})
);
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
let outline = outline_for_prompt(
&snapshot,
snapshot.anchor_before(Point::new(18, 8))..snapshot.anchor_before(Point::new(18, 8)),
cx,
);
assert_eq!(
outline.as_deref(),
Some(indoc! {"
struct X
a
b
impl X
fn new
fn a
pub fn b(&self) -> usize {
<|START|>self.b
}
"})
);
}
}

View File

@@ -115,15 +115,13 @@ pub fn check(_: &Check, cx: &mut AppContext) {
fn view_release_notes(_: &ViewReleaseNotes, cx: &mut AppContext) {
if let Some(auto_updater) = AutoUpdater::get(cx) {
let auto_updater = auto_updater.read(cx);
let server_url = &auto_updater.server_url;
let current_version = auto_updater.current_version;
let server_url = &auto_updater.read(cx).server_url;
let latest_release_url = if cx.has_global::<ReleaseChannel>()
&& *cx.global::<ReleaseChannel>() == ReleaseChannel::Preview
{
format!("{server_url}/releases/preview/{current_version}")
format!("{server_url}/releases/preview/latest")
} else {
format!("{server_url}/releases/stable/{current_version}")
format!("{server_url}/releases/stable/latest")
};
cx.platform().open_url(&latest_release_url);
}

View File

@@ -20,6 +20,7 @@ test-support = [
[dependencies]
audio = { path = "../audio" }
channel = { path = "../channel" }
client = { path = "../client" }
collections = { path = "../collections" }
gpui = { path = "../gpui" }

View File

@@ -2,22 +2,22 @@ pub mod call_settings;
pub mod participant;
pub mod room;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use audio::Audio;
use call_settings::CallSettings;
use client::{
proto, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore,
ZED_ALWAYS_ACTIVE,
};
use channel::ChannelId;
use client::{proto, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore};
use collections::HashSet;
use futures::{future::Shared, FutureExt};
use postage::watch;
use gpui::{
AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Subscription, Task,
WeakModelHandle,
};
use postage::watch;
use project::Project;
use std::sync::Arc;
pub use participant::ParticipantLocation;
pub use room::Room;
@@ -68,7 +68,6 @@ impl ActiveCall {
location: None,
pending_invites: Default::default(),
incoming_call: watch::channel(),
_subscriptions: vec![
client.add_request_handler(cx.handle(), Self::handle_incoming_call),
client.add_message_handler(cx.handle(), Self::handle_call_canceled),
@@ -78,7 +77,7 @@ impl ActiveCall {
}
}
pub fn channel_id(&self, cx: &AppContext) -> Option<u64> {
pub fn channel_id(&self, cx: &AppContext) -> Option<ChannelId> {
self.room()?.read(cx).channel_id()
}
@@ -290,10 +289,10 @@ impl ActiveCall {
&mut self,
channel_id: u64,
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<Room>>> {
) -> Task<Result<()>> {
if let Some(room) = self.room().cloned() {
if room.read(cx).channel_id() == Some(channel_id) {
return Task::ready(Ok(room));
return Task::ready(Ok(()));
} else {
room.update(cx, |room, cx| room.clear_state(cx));
}
@@ -308,7 +307,7 @@ impl ActiveCall {
this.update(&mut cx, |this, cx| {
this.report_call_event("join channel", cx)
});
Ok(room)
Ok(())
})
}
@@ -349,22 +348,17 @@ impl ActiveCall {
}
}
pub fn location(&self) -> Option<&WeakModelHandle<Project>> {
self.location.as_ref()
}
pub fn set_location(
&mut self,
project: Option<&ModelHandle<Project>>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if project.is_some() || !*ZED_ALWAYS_ACTIVE {
self.location = project.map(|project| project.downgrade());
if let Some((room, _)) = self.room.as_ref() {
return room.update(cx, |room, cx| room.set_location(project, cx));
}
self.location = project.map(|project| project.downgrade());
if let Some((room, _)) = self.room.as_ref() {
room.update(cx, |room, cx| room.set_location(project, cx))
} else {
Task::ready(Ok(()))
}
Task::ready(Ok(()))
}
fn set_room(

View File

@@ -1,5 +1,4 @@
use anyhow::{anyhow, Result};
use client::ParticipantIndex;
use client::{proto, User};
use collections::HashMap;
use gpui::WeakModelHandle;
@@ -44,7 +43,6 @@ pub struct RemoteParticipant {
pub peer_id: proto::PeerId,
pub projects: Vec<proto::ParticipantProject>,
pub location: ParticipantLocation,
pub participant_index: ParticipantIndex,
pub muted: bool,
pub speaking: bool,
pub video_tracks: HashMap<live_kit_client::Sid, Arc<RemoteVideoTrack>>,

View File

@@ -7,7 +7,7 @@ use anyhow::{anyhow, Result};
use audio::{Audio, Sound};
use client::{
proto::{self, PeerId},
Client, ParticipantIndex, TypedEnvelope, User, UserStore,
Client, TypedEnvelope, User, UserStore,
};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
@@ -18,7 +18,7 @@ use live_kit_client::{
LocalAudioTrack, LocalTrackPublication, LocalVideoTrack, RemoteAudioTrackUpdate,
RemoteVideoTrackUpdate,
};
use postage::{sink::Sink, stream::Stream, watch};
use postage::stream::Stream;
use project::Project;
use std::{future::Future, mem, pin::Pin, sync::Arc, time::Duration};
use util::{post_inc, ResultExt, TryFutureExt};
@@ -44,12 +44,6 @@ pub enum Event {
RemoteProjectUnshared {
project_id: u64,
},
RemoteProjectJoined {
project_id: u64,
},
RemoteProjectInvitationDiscarded {
project_id: u64,
},
Left,
}
@@ -70,8 +64,6 @@ pub struct Room {
user_store: ModelHandle<UserStore>,
follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec<PeerId>>,
subscriptions: Vec<client::Subscription>,
room_update_completed_tx: watch::Sender<Option<()>>,
room_update_completed_rx: watch::Receiver<Option<()>>,
pending_room_update: Option<Task<()>>,
maintain_connection: Option<Task<Option<()>>>,
}
@@ -106,10 +98,6 @@ impl Room {
self.channel_id
}
pub fn is_sharing_project(&self) -> bool {
!self.shared_projects.is_empty()
}
#[cfg(any(test, feature = "test-support"))]
pub fn is_connected(&self) -> bool {
if let Some(live_kit) = self.live_kit.as_ref() {
@@ -213,8 +201,6 @@ impl Room {
Audio::play_sound(Sound::Joined, cx);
let (room_update_completed_tx, room_update_completed_rx) = watch::channel();
Self {
id,
channel_id,
@@ -234,8 +220,6 @@ impl Room {
user_store,
follows_by_leader_id_project_id: Default::default(),
maintain_connection: Some(maintain_connection),
room_update_completed_tx,
room_update_completed_rx,
}
}
@@ -604,43 +588,6 @@ impl Room {
.map_or(&[], |v| v.as_slice())
}
/// Returns the most 'active' projects, defined as most people in the project
pub fn most_active_project(&self, cx: &AppContext) -> Option<(u64, u64)> {
let mut project_hosts_and_guest_counts = HashMap::<u64, (Option<u64>, u32)>::default();
for participant in self.remote_participants.values() {
match participant.location {
ParticipantLocation::SharedProject { project_id } => {
project_hosts_and_guest_counts
.entry(project_id)
.or_default()
.1 += 1;
}
ParticipantLocation::External | ParticipantLocation::UnsharedProject => {}
}
for project in &participant.projects {
project_hosts_and_guest_counts
.entry(project.id)
.or_default()
.0 = Some(participant.user.id);
}
}
if let Some(user) = self.user_store.read(cx).current_user() {
for project in &self.local_participant.projects {
project_hosts_and_guest_counts
.entry(project.id)
.or_default()
.0 = Some(user.id);
}
}
project_hosts_and_guest_counts
.into_iter()
.filter_map(|(id, (host, guest_count))| Some((id, host?, guest_count)))
.max_by_key(|(_, _, guest_count)| *guest_count)
.map(|(id, host, _)| (id, host))
}
async fn handle_room_updated(
this: ModelHandle<Self>,
envelope: TypedEnvelope<proto::RoomUpdated>,
@@ -704,7 +651,6 @@ impl Room {
let Some(peer_id) = participant.peer_id else {
continue;
};
let participant_index = ParticipantIndex(participant.participant_index);
this.participant_user_ids.insert(participant.user_id);
let old_projects = this
@@ -755,9 +701,8 @@ impl Room {
if let Some(remote_participant) =
this.remote_participants.get_mut(&participant.user_id)
{
remote_participant.peer_id = peer_id;
remote_participant.projects = participant.projects;
remote_participant.participant_index = participant_index;
remote_participant.peer_id = peer_id;
if location != remote_participant.location {
remote_participant.location = location;
cx.emit(Event::ParticipantLocationChanged {
@@ -769,7 +714,6 @@ impl Room {
participant.user_id,
RemoteParticipant {
user: user.clone(),
participant_index,
peer_id,
projects: participant.projects,
location,
@@ -863,17 +807,7 @@ impl Room {
let _ = this.leave(cx);
}
this.user_store.update(cx, |user_store, cx| {
let participant_indices_by_user_id = this
.remote_participants
.iter()
.map(|(user_id, participant)| (*user_id, participant.participant_index))
.collect();
user_store.set_participant_indices(participant_indices_by_user_id, cx);
});
this.check_invariants();
this.room_update_completed_tx.try_send(Some(())).ok();
cx.notify();
});
}));
@@ -882,17 +816,6 @@ impl Room {
Ok(())
}
pub fn room_update_completed(&mut self) -> impl Future<Output = ()> {
let mut done_rx = self.room_update_completed_rx.clone();
async move {
while let Some(result) = done_rx.next().await {
if result.is_some() {
break;
}
}
}
}
fn remote_video_track_updated(
&mut self,
change: RemoteVideoTrackUpdate,
@@ -1080,7 +1003,6 @@ impl Room {
) -> Task<Result<ModelHandle<Project>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
cx.emit(Event::RemoteProjectJoined { project_id: id });
cx.spawn(|this, mut cx| async move {
let project =
Project::remote(id, client, user_store, language_registry, fs, cx.clone()).await?;

View File

@@ -23,7 +23,6 @@ language = { path = "../language" }
settings = { path = "../settings" }
feature_flags = { path = "../feature_flags" }
sum_tree = { path = "../sum_tree" }
clock = { path = "../clock" }
anyhow.workspace = true
futures.workspace = true
@@ -39,7 +38,7 @@ smol.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
uuid.workspace = true
uuid = { version = "1.1.2", features = ["v4"] }
url = "2.2"
serde.workspace = true
serde_derive.workspace = true

View File

@@ -2,21 +2,19 @@ mod channel_buffer;
mod channel_chat;
mod channel_store;
use client::{Client, UserStore};
use gpui::{AppContext, ModelHandle};
use std::sync::Arc;
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent};
pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId};
pub use channel_store::{
Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore,
};
use client::Client;
use std::sync::Arc;
#[cfg(test)]
mod channel_store_tests;
pub fn init(client: &Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
channel_store::init(client, user_store, cx);
pub fn init(client: &Arc<Client>) {
channel_buffer::init(client);
channel_chat::init(client);
}

View File

@@ -1,39 +1,31 @@
use crate::Channel;
use anyhow::Result;
use client::{Client, Collaborator, UserStore};
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use language::proto::serialize_version;
use rpc::{
proto::{self, PeerId},
TypedEnvelope,
};
use std::{sync::Arc, time::Duration};
use client::Client;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle};
use rpc::{proto, TypedEnvelope};
use std::sync::Arc;
use util::ResultExt;
pub const ACKNOWLEDGE_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(250);
pub(crate) fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer);
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborators);
client.add_model_message_handler(ChannelBuffer::handle_add_channel_buffer_collaborator);
client.add_model_message_handler(ChannelBuffer::handle_remove_channel_buffer_collaborator);
client.add_model_message_handler(ChannelBuffer::handle_update_channel_buffer_collaborator);
}
pub struct ChannelBuffer {
pub(crate) channel: Arc<Channel>,
connected: bool,
collaborators: HashMap<PeerId, Collaborator>,
user_store: ModelHandle<UserStore>,
collaborators: Vec<proto::Collaborator>,
buffer: ModelHandle<language::Buffer>,
buffer_epoch: u64,
client: Arc<Client>,
subscription: Option<client::Subscription>,
acknowledge_task: Option<Task<Result<()>>>,
}
pub enum ChannelBufferEvent {
CollaboratorsChanged,
Disconnected,
BufferEdited,
}
impl Entity for ChannelBuffer {
@@ -41,9 +33,6 @@ impl Entity for ChannelBuffer {
fn release(&mut self, _: &mut AppContext) {
if self.connected {
if let Some(task) = self.acknowledge_task.take() {
task.detach();
}
self.client
.send(proto::LeaveChannelBuffer {
channel_id: self.channel.id,
@@ -57,7 +46,6 @@ impl ChannelBuffer {
pub(crate) async fn new(
channel: Arc<Channel>,
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
mut cx: AsyncAppContext,
) -> Result<ModelHandle<Self>> {
let response = client
@@ -73,6 +61,8 @@ impl ChannelBuffer {
.map(language::proto::deserialize_operation)
.collect::<Result<Vec<_>, _>>()?;
let collaborators = response.collaborators;
let buffer = cx.add_model(|_| {
language::Buffer::remote(response.buffer_id, response.replica_id as u16, base_text)
});
@@ -83,50 +73,34 @@ impl ChannelBuffer {
anyhow::Ok(cx.add_model(|cx| {
cx.subscribe(&buffer, Self::on_buffer_update).detach();
let mut this = Self {
Self {
buffer,
buffer_epoch: response.epoch,
client,
connected: true,
collaborators: Default::default(),
acknowledge_task: None,
collaborators,
channel,
subscription: Some(subscription.set_model(&cx.handle(), &mut cx.to_async())),
user_store,
};
this.replace_collaborators(response.collaborators, cx);
this
}
}))
}
pub fn remote_id(&self, cx: &AppContext) -> u64 {
self.buffer.read(cx).remote_id()
}
pub fn user_store(&self) -> &ModelHandle<UserStore> {
&self.user_store
}
pub(crate) fn replace_collaborators(
&mut self,
collaborators: Vec<proto::Collaborator>,
cx: &mut ModelContext<Self>,
) {
let mut new_collaborators = HashMap::default();
for collaborator in collaborators {
if let Ok(collaborator) = Collaborator::from_proto(collaborator) {
new_collaborators.insert(collaborator.peer_id, collaborator);
}
}
for (_, old_collaborator) in &self.collaborators {
if !new_collaborators.contains_key(&old_collaborator.peer_id) {
for old_collaborator in &self.collaborators {
if collaborators
.iter()
.any(|c| c.replica_id == old_collaborator.replica_id)
{
self.buffer.update(cx, |buffer, cx| {
buffer.remove_peer(old_collaborator.replica_id as u16, cx)
});
}
}
self.collaborators = new_collaborators;
self.collaborators = collaborators;
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
}
@@ -153,14 +127,64 @@ impl ChannelBuffer {
Ok(())
}
async fn handle_update_channel_buffer_collaborators(
async fn handle_add_channel_buffer_collaborator(
this: ModelHandle<Self>,
message: TypedEnvelope<proto::UpdateChannelBufferCollaborators>,
envelope: TypedEnvelope<proto::AddChannelBufferCollaborator>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
let collaborator = envelope.payload.collaborator.ok_or_else(|| {
anyhow::anyhow!(
"Should have gotten a collaborator in the AddChannelBufferCollaborator message"
)
})?;
this.update(&mut cx, |this, cx| {
this.collaborators.push(collaborator);
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
});
Ok(())
}
async fn handle_remove_channel_buffer_collaborator(
this: ModelHandle<Self>,
message: TypedEnvelope<proto::RemoveChannelBufferCollaborator>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
this.replace_collaborators(message.payload.collaborators, cx);
this.collaborators.retain(|collaborator| {
if collaborator.peer_id == message.payload.peer_id {
this.buffer.update(cx, |buffer, cx| {
buffer.remove_peer(collaborator.replica_id as u16, cx)
});
false
} else {
true
}
});
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
});
Ok(())
}
async fn handle_update_channel_buffer_collaborator(
this: ModelHandle<Self>,
message: TypedEnvelope<proto::UpdateChannelBufferCollaborator>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
this.update(&mut cx, |this, cx| {
for collaborator in &mut this.collaborators {
if collaborator.peer_id == message.payload.old_peer_id {
collaborator.peer_id = message.payload.new_peer_id;
break;
}
}
cx.emit(ChannelBufferEvent::CollaboratorsChanged);
cx.notify();
});
@@ -172,43 +196,17 @@ impl ChannelBuffer {
&mut self,
_: ModelHandle<language::Buffer>,
event: &language::Event,
cx: &mut ModelContext<Self>,
_: &mut ModelContext<Self>,
) {
match event {
language::Event::Operation(operation) => {
let operation = language::proto::serialize_operation(operation);
self.client
.send(proto::UpdateChannelBuffer {
channel_id: self.channel.id,
operations: vec![operation],
})
.log_err();
}
language::Event::Edited => {
cx.emit(ChannelBufferEvent::BufferEdited);
}
_ => {}
}
}
pub fn acknowledge_buffer_version(&mut self, cx: &mut ModelContext<'_, ChannelBuffer>) {
let buffer = self.buffer.read(cx);
let version = buffer.version();
let buffer_id = buffer.remote_id();
let client = self.client.clone();
let epoch = self.epoch();
self.acknowledge_task = Some(cx.spawn_weak(|_, cx| async move {
cx.background().timer(ACKNOWLEDGE_DEBOUNCE_INTERVAL).await;
client
.send(proto::AckBufferOperation {
buffer_id,
epoch,
version: serialize_version(&version),
if let language::Event::Operation(operation) = event {
let operation = language::proto::serialize_operation(operation);
self.client
.send(proto::UpdateChannelBuffer {
channel_id: self.channel.id,
operations: vec![operation],
})
.ok();
Ok(())
}));
.log_err();
}
}
pub fn epoch(&self) -> u64 {
@@ -219,7 +217,7 @@ impl ChannelBuffer {
self.buffer.clone()
}
pub fn collaborators(&self) -> &HashMap<PeerId, Collaborator> {
pub fn collaborators(&self) -> &[proto::Collaborator] {
&self.collaborators
}

View File

@@ -1,4 +1,4 @@
use crate::{Channel, ChannelId, ChannelStore};
use crate::Channel;
use anyhow::{anyhow, Result};
use client::{
proto,
@@ -16,9 +16,7 @@ use util::{post_inc, ResultExt as _, TryFutureExt};
pub struct ChannelChat {
channel: Arc<Channel>,
messages: SumTree<ChannelMessage>,
channel_store: ModelHandle<ChannelStore>,
loaded_all_messages: bool,
last_acknowledged_id: Option<u64>,
next_pending_message_id: usize,
user_store: ModelHandle<UserStore>,
rpc: Arc<Client>,
@@ -36,7 +34,7 @@ pub struct ChannelMessage {
pub nonce: u128,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ChannelMessageId {
Saved(u64),
Pending(usize),
@@ -57,10 +55,6 @@ pub enum ChannelChatEvent {
old_range: Range<usize>,
new_count: usize,
},
NewMessage {
channel_id: ChannelId,
message_id: u64,
},
}
pub fn init(client: &Arc<Client>) {
@@ -83,7 +77,6 @@ impl Entity for ChannelChat {
impl ChannelChat {
pub async fn new(
channel: Arc<Channel>,
channel_store: ModelHandle<ChannelStore>,
user_store: ModelHandle<UserStore>,
client: Arc<Client>,
mut cx: AsyncAppContext,
@@ -101,13 +94,11 @@ impl ChannelChat {
let mut this = Self {
channel,
user_store,
channel_store,
rpc: client,
outgoing_messages_lock: Default::default(),
messages: Default::default(),
loaded_all_messages,
next_pending_message_id: 0,
last_acknowledged_id: None,
rng: StdRng::from_entropy(),
_subscription: subscription.set_model(&cx.handle(), &mut cx.to_async()),
};
@@ -228,26 +219,6 @@ impl ChannelChat {
false
}
pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
if let ChannelMessageId::Saved(latest_message_id) = self.messages.summary().max_id {
if self
.last_acknowledged_id
.map_or(true, |acknowledged_id| acknowledged_id < latest_message_id)
{
self.rpc
.send(proto::AckChannelMessage {
channel_id: self.channel.id,
message_id: latest_message_id,
})
.ok();
self.last_acknowledged_id = Some(latest_message_id);
self.channel_store.update(cx, |store, cx| {
store.acknowledge_message_id(self.channel.id, latest_message_id, cx);
});
}
}
}
pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
@@ -342,15 +313,10 @@ impl ChannelChat {
.payload
.message
.ok_or_else(|| anyhow!("empty message"))?;
let message_id = message.id;
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
cx.emit(ChannelChatEvent::NewMessage {
channel_id: this.channel.id,
message_id,
})
this.insert_messages(SumTree::from_item(message, &()), cx)
});
Ok(())
@@ -422,7 +388,6 @@ impl ChannelChat {
old_range: start_ix..end_ix,
new_count,
});
cx.notify();
}
}

View File

@@ -2,10 +2,8 @@ mod channel_index;
use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat};
use anyhow::{anyhow, Result};
use channel_index::ChannelIndex;
use client::{Client, Subscription, User, UserId, UserStore};
use collections::{hash_map, HashMap, HashSet};
use db::RELEASE_CHANNEL;
use futures::{channel::mpsc, future::Shared, Future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use rpc::{
@@ -16,11 +14,7 @@ use serde_derive::{Deserialize, Serialize};
use std::{borrow::Cow, hash::Hash, mem, ops::Deref, sync::Arc, time::Duration};
use util::ResultExt;
pub fn init(client: &Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
let channel_store =
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
cx.set_global(channel_store);
}
use self::channel_index::ChannelIndex;
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@@ -49,28 +43,6 @@ pub type ChannelData = (Channel, ChannelPath);
pub struct Channel {
pub id: ChannelId,
pub name: String,
pub unseen_note_version: Option<(u64, clock::Global)>,
pub unseen_message_id: Option<u64>,
}
impl Channel {
pub fn link(&self) -> String {
RELEASE_CHANNEL.link_prefix().to_owned()
+ "channel/"
+ &self.slug()
+ "-"
+ &self.id.to_string()
}
pub fn slug(&self) -> String {
let slug: String = self
.name
.chars()
.map(|c| if c.is_alphanumeric() { c } else { '-' })
.collect();
slug.trim_matches(|c| c == '-').to_string()
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
@@ -97,10 +69,6 @@ enum OpenedModelHandle<E: Entity> {
}
impl ChannelStore {
pub fn global(cx: &AppContext) -> ModelHandle<Self> {
cx.global::<ModelHandle<Self>>().clone()
}
pub fn new(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
@@ -114,21 +82,12 @@ impl ChannelStore {
let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
while let Some(status) = connection_status.next().await {
let this = this.upgrade(&cx)?;
match status {
client::Status::Connected { .. } => {
this.update(&mut cx, |this, cx| this.handle_connect(cx))
.await
.log_err()?;
}
client::Status::SignedOut | client::Status::UpgradeRequired => {
this.update(&mut cx, |this, cx| this.handle_disconnect(false, cx));
}
_ => {
this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx));
}
}
if status.is_connected() {
this.update(&mut cx, |this, cx| this.handle_connect(cx))
.await
.log_err()?;
} else {
this.update(&mut cx, |this, cx| this.handle_disconnect(cx));
}
}
Some(())
@@ -239,73 +198,14 @@ impl ChannelStore {
cx: &mut ModelContext<Self>,
) -> Task<Result<ModelHandle<ChannelBuffer>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
self.open_channel_resource(
channel_id,
|this| &mut this.opened_buffers,
|channel, cx| ChannelBuffer::new(channel, client, user_store, cx),
|channel, cx| ChannelBuffer::new(channel, client, cx),
cx,
)
}
pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option<bool> {
self.channel_index
.by_id()
.get(&channel_id)
.map(|channel| channel.unseen_note_version.is_some())
}
pub fn has_new_messages(&self, channel_id: ChannelId) -> Option<bool> {
self.channel_index
.by_id()
.get(&channel_id)
.map(|channel| channel.unseen_message_id.is_some())
}
pub fn notes_changed(
&mut self,
channel_id: ChannelId,
epoch: u64,
version: &clock::Global,
cx: &mut ModelContext<Self>,
) {
self.channel_index.note_changed(channel_id, epoch, version);
cx.notify();
}
pub fn new_message(
&mut self,
channel_id: ChannelId,
message_id: u64,
cx: &mut ModelContext<Self>,
) {
self.channel_index.new_message(channel_id, message_id);
cx.notify();
}
pub fn acknowledge_message_id(
&mut self,
channel_id: ChannelId,
message_id: u64,
cx: &mut ModelContext<Self>,
) {
self.channel_index
.acknowledge_message_id(channel_id, message_id);
cx.notify();
}
pub fn acknowledge_notes_version(
&mut self,
channel_id: ChannelId,
epoch: u64,
version: &clock::Global,
cx: &mut ModelContext<Self>,
) {
self.channel_index
.acknowledge_note_version(channel_id, epoch, version);
cx.notify();
}
pub fn open_channel_chat(
&mut self,
channel_id: ChannelId,
@@ -313,11 +213,10 @@ impl ChannelStore {
) -> Task<Result<ModelHandle<ChannelChat>>> {
let client = self.client.clone();
let user_store = self.user_store.clone();
let this = cx.handle();
self.open_channel_resource(
channel_id,
|this| &mut this.opened_chats,
|channel, cx| ChannelChat::new(channel, this, user_store, client, cx),
|channel, cx| ChannelChat::new(channel, user_store, client, cx),
cx,
)
}
@@ -832,7 +731,7 @@ impl ChannelStore {
})
}
fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext<Self>) {
fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
self.channel_index.clear();
self.channel_invitations.clear();
self.channel_participants.clear();
@@ -843,10 +742,7 @@ impl ChannelStore {
self.disconnect_channel_buffers_task.get_or_insert_with(|| {
cx.spawn_weak(|this, mut cx| async move {
if wait_for_reconnect {
cx.background().timer(RECONNECT_TIMEOUT).await;
}
cx.background().timer(RECONNECT_TIMEOUT).await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
for (_, buffer) in this.opened_buffers.drain() {
@@ -882,8 +778,6 @@ impl ChannelStore {
Arc::new(Channel {
id: channel.id,
name: channel.name,
unseen_note_version: None,
unseen_message_id: None,
}),
),
}
@@ -892,9 +786,7 @@ impl ChannelStore {
let channels_changed = !payload.channels.is_empty()
|| !payload.delete_channels.is_empty()
|| !payload.insert_edge.is_empty()
|| !payload.delete_edge.is_empty()
|| !payload.unseen_channel_messages.is_empty()
|| !payload.unseen_channel_buffer_changes.is_empty();
|| !payload.delete_edge.is_empty();
if channels_changed {
if !payload.delete_channels.is_empty() {
@@ -921,22 +813,6 @@ impl ChannelStore {
index.insert(channel)
}
for unseen_buffer_change in payload.unseen_channel_buffer_changes {
let version = language::proto::deserialize_version(&unseen_buffer_change.version);
index.note_changed(
unseen_buffer_change.channel_id,
unseen_buffer_change.epoch,
&version,
);
}
for unseen_channel_message in payload.unseen_channel_messages {
index.new_messages(
unseen_channel_message.channel_id,
unseen_channel_message.message_id,
);
}
for edge in payload.insert_edge {
index.insert_edge(edge.channel_id, edge.parent_id);
}

View File

@@ -38,43 +38,6 @@ impl ChannelIndex {
channels_by_id: &mut self.channels_by_id,
}
}
pub fn acknowledge_note_version(
&mut self,
channel_id: ChannelId,
epoch: u64,
version: &clock::Global,
) {
if let Some(channel) = self.channels_by_id.get_mut(&channel_id) {
let channel = Arc::make_mut(channel);
if let Some((unseen_epoch, unseen_version)) = &channel.unseen_note_version {
if epoch > *unseen_epoch
|| epoch == *unseen_epoch && version.observed_all(unseen_version)
{
channel.unseen_note_version = None;
}
}
}
}
pub fn acknowledge_message_id(&mut self, channel_id: ChannelId, message_id: u64) {
if let Some(channel) = self.channels_by_id.get_mut(&channel_id) {
let channel = Arc::make_mut(channel);
if let Some(unseen_message_id) = channel.unseen_message_id {
if message_id >= unseen_message_id {
channel.unseen_message_id = None;
}
}
}
}
pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) {
insert_note_changed(&mut self.channels_by_id, channel_id, epoch, version);
}
pub fn new_message(&mut self, channel_id: ChannelId, message_id: u64) {
insert_new_message(&mut self.channels_by_id, channel_id, message_id)
}
}
impl Deref for ChannelIndex {
@@ -113,14 +76,6 @@ impl<'a> ChannelPathsInsertGuard<'a> {
}
}
pub fn note_changed(&mut self, channel_id: ChannelId, epoch: u64, version: &clock::Global) {
insert_note_changed(&mut self.channels_by_id, channel_id, epoch, &version);
}
pub fn new_messages(&mut self, channel_id: ChannelId, message_id: u64) {
insert_new_message(&mut self.channels_by_id, channel_id, message_id)
}
pub fn insert(&mut self, channel_proto: proto::Channel) {
if let Some(existing_channel) = self.channels_by_id.get_mut(&channel_proto.id) {
Arc::make_mut(existing_channel).name = channel_proto.name;
@@ -130,8 +85,6 @@ impl<'a> ChannelPathsInsertGuard<'a> {
Arc::new(Channel {
id: channel_proto.id,
name: channel_proto.name,
unseen_note_version: None,
unseen_message_id: None,
}),
);
self.insert_root(channel_proto.id);
@@ -207,32 +160,3 @@ fn channel_path_sorting_key<'a>(
path.iter()
.map(|id| Some(channels_by_id.get(id)?.name.as_str()))
}
fn insert_note_changed(
channels_by_id: &mut BTreeMap<ChannelId, Arc<Channel>>,
channel_id: u64,
epoch: u64,
version: &clock::Global,
) {
if let Some(channel) = channels_by_id.get_mut(&channel_id) {
let unseen_version = Arc::make_mut(channel)
.unseen_note_version
.get_or_insert((0, clock::Global::new()));
if epoch > unseen_version.0 {
*unseen_version = (epoch, version.clone());
} else {
unseen_version.1.join(&version);
}
}
}
fn insert_new_message(
channels_by_id: &mut BTreeMap<ChannelId, Arc<Channel>>,
channel_id: u64,
message_id: u64,
) {
if let Some(channel) = channels_by_id.get_mut(&channel_id) {
let unseen_message_id = Arc::make_mut(channel).unseen_message_id.get_or_insert(0);
*unseen_message_id = message_id.max(*unseen_message_id);
}
}

View File

@@ -340,10 +340,10 @@ fn init_test(cx: &mut AppContext) -> ModelHandle<ChannelStore> {
cx.foreground().forbid_parking();
cx.set_global(SettingsStore::test(cx));
crate::init(&client);
client::init(&client, cx);
crate::init(&client, user_store, cx);
ChannelStore::global(cx)
cx.add_model(|cx| ChannelStore::new(client, user_store, cx))
}
fn update_channels(

View File

@@ -182,7 +182,6 @@ impl Bundle {
kCFStringEncodingUTF8,
ptr::null(),
));
// equivalent to: open zed-cli:... -a /Applications/Zed\ Preview.app
let urls_to_open = CFArray::from_copyable(&[url_to_open.as_concrete_TypeRef()]);
LSOpenFromURLSpec(
&LSLaunchURLSpec {

View File

@@ -33,16 +33,15 @@ parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
schemars.workspace = true
serde.workspace = true
serde_derive.workspace = true
smol.workspace = true
sysinfo.workspace = true
tempfile = "3"
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
uuid.workspace = true
uuid = { version = "1.1.2", features = ["v4"] }
url = "2.2"
serde.workspace = true
serde_derive.workspace = true
tempfile = "3"
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }

View File

@@ -34,7 +34,7 @@ use std::{
future::Future,
marker::PhantomData,
path::PathBuf,
sync::{atomic::AtomicU64, Arc, Weak},
sync::{Arc, Weak},
time::{Duration, Instant},
};
use telemetry::Telemetry;
@@ -62,15 +62,13 @@ lazy_static! {
.and_then(|v| v.parse().ok());
pub static ref ZED_APP_PATH: Option<PathBuf> =
std::env::var("ZED_APP_PATH").ok().map(PathBuf::from);
pub static ref ZED_ALWAYS_ACTIVE: bool =
std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0);
}
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
actions!(client, [SignIn, SignOut, Reconnect]);
actions!(client, [SignIn, SignOut]);
pub fn init_settings(cx: &mut AppContext) {
settings::register::<TelemetrySettings>(cx);
@@ -102,21 +100,10 @@ pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
}
}
});
cx.add_global_action({
let client = client.clone();
move |_: &Reconnect, cx| {
if let Some(client) = client.upgrade() {
cx.spawn(|cx| async move {
client.reconnect(&cx);
})
.detach();
}
}
});
}
pub struct Client {
id: AtomicU64,
id: usize,
peer: Arc<Peer>,
http: Arc<dyn HttpClient>,
telemetry: Arc<Telemetry>,
@@ -385,7 +372,7 @@ impl settings::Setting for TelemetrySettings {
impl Client {
pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
Arc::new(Self {
id: AtomicU64::new(0),
id: 0,
peer: Peer::new(0),
telemetry: Telemetry::new(http.clone(), cx),
http,
@@ -398,16 +385,17 @@ impl Client {
})
}
pub fn id(&self) -> u64 {
self.id.load(std::sync::atomic::Ordering::SeqCst)
pub fn id(&self) -> usize {
self.id
}
pub fn http_client(&self) -> Arc<dyn HttpClient> {
self.http.clone()
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
#[cfg(any(test, feature = "test-support"))]
pub fn set_id(&mut self, id: usize) -> &Self {
self.id = id;
self
}
@@ -464,7 +452,7 @@ impl Client {
}
fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
log::info!("set status on client {}: {:?}", self.id(), status);
log::info!("set status on client {}: {:?}", self.id, status);
let mut state = self.state.write();
*state.status.0.borrow_mut() = status;
@@ -815,7 +803,6 @@ impl Client {
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
if was_disconnected {
self.set_status(Status::Connecting, cx);
@@ -1223,11 +1210,6 @@ impl Client {
self.set_status(Status::SignedOut, cx);
}
pub fn reconnect(self: &Arc<Self>, cx: &AsyncAppContext) {
self.peer.teardown();
self.set_status(Status::ConnectionLost, cx);
}
fn connection_id(&self) -> Result<ConnectionId> {
if let Status::Connected { connection_id, .. } = *self.status().borrow() {
Ok(connection_id)
@@ -1237,7 +1219,7 @@ impl Client {
}
pub fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME);
log::debug!("rpc send. client_id:{}, name:{}", self.id, T::NAME);
self.peer.send(self.connection_id()?, message)
}
@@ -1253,7 +1235,7 @@ impl Client {
&self,
request: T,
) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
let client_id = self.id();
let client_id = self.id;
log::debug!(
"rpc request start. client_id:{}. name:{}",
client_id,
@@ -1274,7 +1256,7 @@ impl Client {
}
fn respond<T: RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) -> Result<()> {
log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
self.peer.respond(receipt, response)
}
@@ -1283,7 +1265,7 @@ impl Client {
receipt: Receipt<T>,
error: proto::Error,
) -> Result<()> {
log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME);
log::debug!("rpc respond. client_id:{}. name:{}", self.id, T::NAME);
self.peer.respond_with_error(receipt, error)
}
@@ -1352,7 +1334,7 @@ impl Client {
if let Some(handler) = handler {
let future = handler(subscriber, message, &self, cx.clone());
let client_id = self.id();
let client_id = self.id;
log::debug!(
"rpc message received. client_id:{}, sender_id:{:?}, type:{}",
client_id,

View File

@@ -4,9 +4,6 @@ use lazy_static::lazy_static;
use parking_lot::Mutex;
use serde::Serialize;
use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{
CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt,
};
use tempfile::NamedTempFile;
use util::http::HttpClient;
use util::{channel::ReleaseChannel, TryFutureExt};
@@ -20,8 +17,7 @@ pub struct Telemetry {
#[derive(Default)]
struct TelemetryState {
metrics_id: Option<Arc<str>>, // Per logged-in user
installation_id: Option<Arc<str>>, // Per app installation (different for dev, preview, and stable)
session_id: Option<Arc<str>>, // Per app launch
installation_id: Option<Arc<str>>, // Per app installation
app_version: Option<Arc<str>>,
release_channel: Option<&'static str>,
os_name: &'static str,
@@ -44,7 +40,6 @@ lazy_static! {
struct ClickhouseEventRequestBody {
token: &'static str,
installation_id: Option<Arc<str>>,
session_id: Option<Arc<str>>,
is_staff: Option<bool>,
app_version: Option<Arc<str>>,
os_name: &'static str,
@@ -93,14 +88,6 @@ pub enum ClickhouseEvent {
kind: AssistantKind,
model: &'static str,
},
Cpu {
usage_as_percentage: f32,
core_count: u32,
},
Memory {
memory_in_bytes: u64,
virtual_memory_in_bytes: u64,
},
}
#[cfg(debug_assertions)]
@@ -135,7 +122,6 @@ impl Telemetry {
release_channel,
installation_id: None,
metrics_id: None,
session_id: None,
clickhouse_events_queue: Default::default(),
flush_clickhouse_events_task: Default::default(),
log_file: None,
@@ -150,68 +136,15 @@ impl Telemetry {
Some(self.state.lock().log_file.as_ref()?.path().to_path_buf())
}
pub fn start(
self: &Arc<Self>,
installation_id: Option<String>,
session_id: String,
cx: &mut AppContext,
) {
pub fn start(self: &Arc<Self>, installation_id: Option<String>) {
let mut state = self.state.lock();
state.installation_id = installation_id.map(|id| id.into());
state.session_id = Some(session_id.into());
let has_clickhouse_events = !state.clickhouse_events_queue.is_empty();
drop(state);
if has_clickhouse_events {
self.flush_clickhouse_events();
}
let this = self.clone();
cx.spawn(|mut cx| async move {
// Avoiding calling `System::new_all()`, as there have been crashes related to it
let refresh_kind = RefreshKind::new()
.with_memory() // For memory usage
.with_processes(ProcessRefreshKind::everything()) // For process usage
.with_cpu(CpuRefreshKind::everything()); // For core count
let mut system = System::new_with_specifics(refresh_kind);
// Avoiding calling `refresh_all()`, just update what we need
system.refresh_specifics(refresh_kind);
loop {
// Waiting some amount of time before the first query is important to get a reasonable value
// https://docs.rs/sysinfo/0.29.10/sysinfo/trait.ProcessExt.html#tymethod.cpu_usage
const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60);
smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await;
system.refresh_specifics(refresh_kind);
let current_process = Pid::from_u32(std::process::id());
let Some(process) = system.processes().get(&current_process) else {
let process = current_process;
log::error!("Failed to find own process {process:?} in system process table");
// TODO: Fire an error telemetry event
return;
};
let memory_event = ClickhouseEvent::Memory {
memory_in_bytes: process.memory(),
virtual_memory_in_bytes: process.virtual_memory(),
};
let cpu_event = ClickhouseEvent::Cpu {
usage_as_percentage: process.cpu_usage(),
core_count: system.cpus().len() as u32,
};
let telemetry_settings = cx.update(|cx| *settings::get::<TelemetrySettings>(cx));
this.report_clickhouse_event(memory_event, telemetry_settings);
this.report_clickhouse_event(cpu_event, telemetry_settings);
}
})
.detach();
}
pub fn set_authenticated_user_info(
@@ -297,21 +230,22 @@ impl Telemetry {
{
let state = this.state.lock();
let request_body = ClickhouseEventRequestBody {
token: ZED_SECRET_CLIENT_TOKEN,
installation_id: state.installation_id.clone(),
session_id: state.session_id.clone(),
is_staff: state.is_staff.clone(),
app_version: state.app_version.clone(),
os_name: state.os_name,
os_version: state.os_version.clone(),
architecture: state.architecture,
release_channel: state.release_channel,
events,
};
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, &request_body)?;
serde_json::to_writer(
&mut json_bytes,
&ClickhouseEventRequestBody {
token: ZED_SECRET_CLIENT_TOKEN,
installation_id: state.installation_id.clone(),
is_staff: state.is_staff.clone(),
app_version: state.app_version.clone(),
os_name: state.os_name,
os_version: state.os_version.clone(),
architecture: state.architecture,
release_channel: state.release_channel,
events,
},
)?;
}
this.http_client

View File

@@ -7,15 +7,11 @@ use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
use postage::{sink::Sink, watch};
use rpc::proto::{RequestMessage, UsersResponse};
use std::sync::{Arc, Weak};
use text::ReplicaId;
use util::http::HttpClient;
use util::TryFutureExt as _;
pub type UserId = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParticipantIndex(pub u32);
#[derive(Default, Debug)]
pub struct User {
pub id: UserId,
@@ -23,13 +19,6 @@ pub struct User {
pub avatar: Option<Arc<ImageData>>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Collaborator {
pub peer_id: proto::PeerId,
pub replica_id: ReplicaId,
pub user_id: UserId,
}
impl PartialOrd for User {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
@@ -67,7 +56,6 @@ pub enum ContactRequestStatus {
pub struct UserStore {
users: HashMap<u64, Arc<User>>,
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_user: watch::Receiver<Option<Arc<User>>>,
contacts: Vec<Arc<Contact>>,
@@ -93,7 +81,6 @@ pub enum Event {
kind: ContactEventKind,
},
ShowContacts,
ParticipantIndicesChanged,
}
#[derive(Clone, Copy)]
@@ -131,7 +118,6 @@ impl UserStore {
current_user: current_user_rx,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
outgoing_contact_requests: Default::default(),
invite_info: None,
client: Arc::downgrade(&client),
@@ -595,10 +581,6 @@ impl UserStore {
self.load_users(proto::FuzzySearchUsers { query }, cx)
}
pub fn get_cached_user(&self, user_id: u64) -> Option<Arc<User>> {
self.users.get(&user_id).cloned()
}
pub fn get_user(
&mut self,
user_id: u64,
@@ -659,21 +641,6 @@ impl UserStore {
}
})
}
pub fn set_participant_indices(
&mut self,
participant_indices: HashMap<u64, ParticipantIndex>,
cx: &mut ModelContext<Self>,
) {
if participant_indices != self.participant_indices {
self.participant_indices = participant_indices;
cx.emit(Event::ParticipantIndicesChanged);
}
}
pub fn participant_indices(&self) -> &HashMap<u64, ParticipantIndex> {
&self.participant_indices
}
}
impl User {
@@ -705,16 +672,6 @@ impl Contact {
}
}
impl Collaborator {
pub fn from_proto(message: proto::Collaborator) -> Result<Self> {
Ok(Self {
peer_id: message.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?,
replica_id: message.replica_id as ReplicaId,
user_id: message.user_id as UserId,
})
}
}
async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result<Arc<ImageData>> {
let mut response = http
.get(url, Default::default(), true)

View File

@@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
default-run = "collab"
edition = "2021"
name = "collab"
version = "0.23.3"
version = "0.22.1"
publish = false
[[bin]]
@@ -42,12 +42,14 @@ rand.workspace = true
reqwest = { version = "0.11", features = ["json"], optional = true }
scrypt = "0.7"
smallvec.workspace = true
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
# Remove fork dependency when a version with https://github.com/SeaQL/sea-orm/pull/1283 is released.
sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] }
sea-query = "0.27"
serde.workspace = true
serde_derive.workspace = true
serde_json.workspace = true
sha-1 = "0.9"
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] }
time.workspace = true
tokio = { version = "1", features = ["full"] }
tokio-tungstenite = "0.17"
@@ -57,7 +59,6 @@ toml.workspace = true
tracing = "0.1.34"
tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] }
uuid.workspace = true
[dev-dependencies]
audio = { path = "../audio" }
@@ -86,9 +87,9 @@ env_logger.workspace = true
indoc.workspace = true
util = { path = "../util" }
lazy_static.workspace = true
sea-orm = { version = "0.12.x", features = ["sqlx-sqlite"] }
sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }
serde_json.workspace = true
sqlx = { version = "0.7", features = ["sqlite"] }
sqlx = { version = "0.6", features = ["sqlite"] }
unindent.workspace = true
[features]

View File

@@ -37,10 +37,8 @@ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");
CREATE TABLE "rooms" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"live_kit_room" VARCHAR NOT NULL,
"enviroment" VARCHAR,
"channel_id" INTEGER REFERENCES channels (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id");
CREATE TABLE "projects" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -160,8 +158,7 @@ CREATE TABLE "room_participants" (
"initial_project_id" INTEGER,
"calling_user_id" INTEGER NOT NULL REFERENCES users (id),
"calling_connection_id" INTEGER NOT NULL,
"calling_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE SET NULL,
"participant_index" INTEGER
"calling_connection_server_id" INTEGER REFERENCES servers (id) ON DELETE SET NULL
);
CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id");
CREATE INDEX "index_room_participants_on_room_id" ON "room_participants" ("room_id");
@@ -291,24 +288,3 @@ CREATE TABLE "user_features" (
CREATE UNIQUE INDEX "index_user_features_user_id_and_feature_id" ON "user_features" ("user_id", "feature_id");
CREATE INDEX "index_user_features_on_user_id" ON "user_features" ("user_id");
CREATE INDEX "index_user_features_on_feature_id" ON "user_features" ("feature_id");
CREATE TABLE "observed_buffer_edits" (
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
"replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);
CREATE UNIQUE INDEX "index_observed_buffers_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id");
CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"channel_message_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, channel_id)
);
CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");

View File

@@ -1,19 +0,0 @@
CREATE TABLE IF NOT EXISTS "observed_buffer_edits" (
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
"replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);
CREATE UNIQUE INDEX "index_observed_buffer_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id");
CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
"user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
"channel_message_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, channel_id)
);
CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");

View File

@@ -1 +0,0 @@
ALTER TABLE room_participants ADD COLUMN participant_index INTEGER;

View File

@@ -1 +0,0 @@
ALTER TABLE rooms ADD COLUMN enviroment TEXT;

View File

@@ -1 +0,0 @@
CREATE UNIQUE INDEX "index_rooms_on_channel_id" ON "rooms" ("channel_id");

View File

@@ -19,12 +19,11 @@ use rpc::{
ConnectionId,
};
use sea_orm::{
entity::prelude::*,
sea_query::{Alias, Expr, OnConflict, Query},
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
entity::prelude::*, ActiveValue, Condition, ConnectionTrait, DatabaseConnection,
DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType,
QueryOrder, QuerySelect, Statement, TransactionTrait,
};
use sea_query::{Alias, Expr, OnConflict, Query};
use serde::{Deserialize, Serialize};
use sqlx::{
migrate::{Migrate, Migration, MigrationSource},
@@ -63,7 +62,6 @@ pub struct Database {
// separate files in the `queries` folder.
impl Database {
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
sqlx::any::install_default_drivers();
Ok(Self {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
@@ -121,7 +119,7 @@ impl Database {
Ok(new_migrations)
}
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@@ -323,7 +321,7 @@ fn is_serialization_error(error: &Error) -> bool {
}
}
pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
impl Deref for TransactionHandle {
type Target = DatabaseTransaction;
@@ -439,8 +437,6 @@ pub struct ChannelsForUser {
pub channels: ChannelGraph,
pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
pub channels_with_admin_privileges: HashSet<ChannelId>,
pub unseen_buffer_changes: Vec<proto::UnseenChannelBufferChange>,
pub channel_messages: Vec<proto::UnseenChannelMessage>,
}
#[derive(Debug)]
@@ -514,7 +510,7 @@ pub struct RefreshedRoom {
pub struct RefreshedChannelBuffer {
pub connection_ids: Vec<ConnectionId>,
pub collaborators: Vec<proto::Collaborator>,
pub removed_collaborators: Vec<proto::RemoveChannelBufferCollaborator>,
}
pub struct Project {

View File

@@ -1,5 +1,6 @@
use crate::Result;
use sea_orm::{entity::prelude::*, DbErr};
use sea_orm::DbErr;
use sea_query::{Value, ValueTypeErr};
use serde::{Deserialize, Serialize};
macro_rules! id_type {
@@ -16,7 +17,6 @@ macro_rules! id_type {
Hash,
Serialize,
Deserialize,
DeriveValueType,
)]
#[serde(transparent)]
pub struct $name(pub i32);
@@ -42,6 +42,40 @@ macro_rules! id_type {
}
}
impl From<$name> for sea_query::Value {
fn from(value: $name) -> Self {
sea_query::Value::Int(Some(value.0))
}
}
impl sea_orm::TryGetable for $name {
fn try_get(
res: &sea_orm::QueryResult,
pre: &str,
col: &str,
) -> Result<Self, sea_orm::TryGetError> {
Ok(Self(i32::try_get(res, pre, col)?))
}
}
impl sea_query::ValueType for $name {
fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
Ok(Self(value_to_integer(v)?))
}
fn type_name() -> String {
stringify!($name).into()
}
fn array_type() -> sea_query::ArrayType {
sea_query::ArrayType::Int
}
fn column_type() -> sea_query::ColumnType {
sea_query::ColumnType::Integer(None)
}
}
impl sea_orm::TryFromU64 for $name {
fn try_from_u64(n: u64) -> Result<Self, DbErr> {
Ok(Self(n.try_into().map_err(|_| {
@@ -54,7 +88,7 @@ macro_rules! id_type {
}
}
impl sea_orm::sea_query::Nullable for $name {
impl sea_query::Nullable for $name {
fn null() -> Value {
Value::Int(None)
}
@@ -62,6 +96,20 @@ macro_rules! id_type {
};
}
fn value_to_integer(v: Value) -> Result<i32, ValueTypeErr> {
match v {
Value::TinyInt(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::SmallInt(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::Int(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::BigInt(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::TinyUnsigned(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::SmallUnsigned(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::Unsigned(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
Value::BigUnsigned(Some(int)) => int.try_into().map_err(|_| ValueTypeErr),
_ => Err(ValueTypeErr),
}
}
id_type!(BufferId);
id_type!(AccessTokenId);
id_type!(ChannelChatParticipantId);

View File

@@ -2,12 +2,6 @@ use super::*;
use prost::Message;
use text::{EditOperation, UndoOperation};
pub struct LeftChannelBuffer {
pub channel_id: ChannelId,
pub collaborators: Vec<proto::Collaborator>,
pub connections: Vec<ConnectionId>,
}
impl Database {
pub async fn join_channel_buffer(
&self,
@@ -74,32 +68,7 @@ impl Database {
.await?;
collaborators.push(collaborator);
let (base_text, operations, max_operation) =
self.get_buffer_state(&buffer, &tx).await?;
// Save the last observed operation
if let Some(op) = max_operation {
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer.id),
epoch: ActiveValue::Set(op.epoch),
lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
replica_id: ActiveValue::Set(op.replica_id),
})
.on_conflict(
OnConflict::columns([
observed_buffer_edits::Column::UserId,
observed_buffer_edits::Column::BufferId,
])
.update_columns([
observed_buffer_edits::Column::Epoch,
observed_buffer_edits::Column::LamportTimestamp,
])
.to_owned(),
)
.exec(&*tx)
.await?;
}
let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
Ok(proto::JoinChannelBufferResponse {
buffer_id: buffer.id.to_proto(),
@@ -235,26 +204,23 @@ impl Database {
server_id: ServerId,
) -> Result<RefreshedChannelBuffer> {
self.transaction(|tx| async move {
let db_collaborators = channel_buffer_collaborator::Entity::find()
let collaborators = channel_buffer_collaborator::Entity::find()
.filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
.all(&*tx)
.await?;
let mut connection_ids = Vec::new();
let mut collaborators = Vec::new();
let mut removed_collaborators = Vec::new();
let mut collaborator_ids_to_remove = Vec::new();
for db_collaborator in &db_collaborators {
if !db_collaborator.connection_lost
&& db_collaborator.connection_server_id == server_id
{
connection_ids.push(db_collaborator.connection());
collaborators.push(proto::Collaborator {
peer_id: Some(db_collaborator.connection().into()),
replica_id: db_collaborator.replica_id.0 as u32,
user_id: db_collaborator.user_id.to_proto(),
})
for collaborator in &collaborators {
if !collaborator.connection_lost && collaborator.connection_server_id == server_id {
connection_ids.push(collaborator.connection());
} else {
collaborator_ids_to_remove.push(db_collaborator.id);
removed_collaborators.push(proto::RemoveChannelBufferCollaborator {
channel_id: channel_id.to_proto(),
peer_id: Some(collaborator.connection().into()),
});
collaborator_ids_to_remove.push(collaborator.id);
}
}
@@ -265,7 +231,7 @@ impl Database {
Ok(RefreshedChannelBuffer {
connection_ids,
collaborators,
removed_collaborators,
})
})
.await
@@ -275,7 +241,7 @@ impl Database {
&self,
channel_id: ChannelId,
connection: ConnectionId,
) -> Result<LeftChannelBuffer> {
) -> Result<Vec<ConnectionId>> {
self.transaction(|tx| async move {
self.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await
@@ -309,7 +275,7 @@ impl Database {
pub async fn leave_channel_buffers(
&self,
connection: ConnectionId,
) -> Result<Vec<LeftChannelBuffer>> {
) -> Result<Vec<(ChannelId, Vec<ConnectionId>)>> {
self.transaction(|tx| async move {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryChannelIds {
@@ -328,10 +294,10 @@ impl Database {
let mut result = Vec::new();
for channel_id in channel_ids {
let left_channel_buffer = self
let collaborators = self
.leave_channel_buffer_internal(channel_id, connection, &*tx)
.await?;
result.push(left_channel_buffer);
result.push((channel_id, collaborators));
}
Ok(result)
@@ -344,7 +310,7 @@ impl Database {
channel_id: ChannelId,
connection: ConnectionId,
tx: &DatabaseTransaction,
) -> Result<LeftChannelBuffer> {
) -> Result<Vec<ConnectionId>> {
let result = channel_buffer_collaborator::Entity::delete_many()
.filter(
Condition::all()
@@ -361,7 +327,6 @@ impl Database {
Err(anyhow!("not a collaborator on this project"))?;
}
let mut collaborators = Vec::new();
let mut connections = Vec::new();
let mut rows = channel_buffer_collaborator::Entity::find()
.filter(
@@ -371,26 +336,19 @@ impl Database {
.await?;
while let Some(row) = rows.next().await {
let row = row?;
let connection = row.connection();
connections.push(connection);
collaborators.push(proto::Collaborator {
peer_id: Some(connection.into()),
replica_id: row.replica_id.0 as u32,
user_id: row.user_id.to_proto(),
connections.push(ConnectionId {
id: row.connection_id as u32,
owner_id: row.connection_server_id.0 as u32,
});
}
drop(rows);
if collaborators.is_empty() {
if connections.is_empty() {
self.snapshot_channel_buffer(channel_id, &tx).await?;
}
Ok(LeftChannelBuffer {
channel_id,
collaborators,
connections,
})
Ok(connections)
}
pub async fn get_channel_buffer_collaborators(
@@ -398,46 +356,33 @@ impl Database {
channel_id: ChannelId,
) -> Result<Vec<UserId>> {
self.transaction(|tx| async move {
self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
.await
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryUserIds {
UserId,
}
let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::UserId)
.filter(
Condition::all()
.add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
)
.into_values::<_, QueryUserIds>()
.all(&*tx)
.await?;
Ok(users)
})
.await
}
async fn get_channel_buffer_collaborators_internal(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
) -> Result<Vec<UserId>> {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryUserIds {
UserId,
}
let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
.select_only()
.column(channel_buffer_collaborator::Column::UserId)
.filter(
Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
)
.into_values::<_, QueryUserIds>()
.all(&*tx)
.await?;
Ok(users)
}
pub async fn update_channel_buffer(
&self,
channel_id: ChannelId,
user: UserId,
operations: &[proto::Operation],
) -> Result<(
Vec<ConnectionId>,
Vec<UserId>,
i32,
Vec<proto::VectorClockEntry>,
)> {
) -> Result<Vec<ConnectionId>> {
self.transaction(move |tx| async move {
self.check_user_is_channel_member(channel_id, user, &*tx)
.await?;
@@ -456,38 +401,7 @@ impl Database {
.iter()
.filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
.collect::<Vec<_>>();
let mut channel_members;
let max_version;
if !operations.is_empty() {
let max_operation = operations
.iter()
.max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
.unwrap();
max_version = vec![proto::VectorClockEntry {
replica_id: *max_operation.replica_id.as_ref() as u32,
timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
}];
// get current channel participants and save the max operation above
self.save_max_operation(
user,
buffer.id,
buffer.epoch,
*max_operation.replica_id.as_ref(),
*max_operation.lamport_timestamp.as_ref(),
&*tx,
)
.await?;
channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
let collaborators = self
.get_channel_buffer_collaborators_internal(channel_id, &*tx)
.await?;
channel_members.retain(|member| !collaborators.contains(member));
buffer_operation::Entity::insert_many(operations)
.on_conflict(
OnConflict::columns([
@@ -501,9 +415,6 @@ impl Database {
)
.exec(&*tx)
.await?;
} else {
channel_members = Vec::new();
max_version = Vec::new();
}
let mut connections = Vec::new();
@@ -522,53 +433,11 @@ impl Database {
});
}
Ok((connections, channel_members, buffer.epoch, max_version))
Ok(connections)
})
.await
}
async fn save_max_operation(
&self,
user_id: UserId,
buffer_id: BufferId,
epoch: i32,
replica_id: i32,
lamport_timestamp: i32,
tx: &DatabaseTransaction,
) -> Result<()> {
use observed_buffer_edits::Column;
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer_id),
epoch: ActiveValue::Set(epoch),
replica_id: ActiveValue::Set(replica_id),
lamport_timestamp: ActiveValue::Set(lamport_timestamp),
})
.on_conflict(
OnConflict::columns([Column::UserId, Column::BufferId])
.update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
.action_cond_where(
Condition::any().add(Column::Epoch.lt(epoch)).add(
Condition::all().add(Column::Epoch.eq(epoch)).add(
Condition::any()
.add(Column::LamportTimestamp.lt(lamport_timestamp))
.add(
Column::LamportTimestamp
.eq(lamport_timestamp)
.and(Column::ReplicaId.lt(replica_id)),
),
),
),
)
.to_owned(),
)
.exec_without_returning(tx)
.await?;
Ok(())
}
async fn get_buffer_operation_serialization_version(
&self,
buffer_id: BufferId,
@@ -586,7 +455,7 @@ impl Database {
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
}
pub async fn get_channel_buffer(
async fn get_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
@@ -605,11 +474,7 @@ impl Database {
&self,
buffer: &buffer::Model,
tx: &DatabaseTransaction,
) -> Result<(
String,
Vec<proto::Operation>,
Option<buffer_operation::Model>,
)> {
) -> Result<(String, Vec<proto::Operation>)> {
let id = buffer.id;
let (base_text, version) = if buffer.epoch > 0 {
let snapshot = buffer_snapshot::Entity::find()
@@ -634,28 +499,16 @@ impl Database {
.eq(id)
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
)
.order_by_asc(buffer_operation::Column::LamportTimestamp)
.order_by_asc(buffer_operation::Column::ReplicaId)
.stream(&*tx)
.await?;
let mut operations = Vec::new();
let mut last_row = None;
while let Some(row) = rows.next().await {
let row = row?;
last_row = Some(buffer_operation::Model {
buffer_id: row.buffer_id,
epoch: row.epoch,
lamport_timestamp: row.lamport_timestamp,
replica_id: row.lamport_timestamp,
value: Default::default(),
});
operations.push(proto::Operation {
variant: Some(operation_from_storage(row, version)?),
});
variant: Some(operation_from_storage(row?, version)?),
})
}
Ok((base_text, operations, last_row))
Ok((base_text, operations))
}
async fn snapshot_channel_buffer(
@@ -664,7 +517,7 @@ impl Database {
tx: &DatabaseTransaction,
) -> Result<()> {
let buffer = self.get_channel_buffer(channel_id, tx).await?;
let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
if operations.is_empty() {
return Ok(());
}
@@ -697,150 +550,6 @@ impl Database {
Ok(())
}
pub async fn observe_buffer_version(
&self,
buffer_id: BufferId,
user_id: UserId,
epoch: i32,
version: &[proto::VectorClockEntry],
) -> Result<()> {
self.transaction(|tx| async move {
// For now, combine concurrent operations.
let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
return Ok(());
};
self.save_max_operation(
user_id,
buffer_id,
epoch,
component.replica_id as i32,
component.timestamp as i32,
&*tx,
)
.await?;
Ok(())
})
.await
}
pub async fn unseen_channel_buffer_changes(
&self,
user_id: UserId,
channel_ids: &[ChannelId],
tx: &DatabaseTransaction,
) -> Result<Vec<proto::UnseenChannelBufferChange>> {
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryIds {
ChannelId,
Id,
}
let mut channel_ids_by_buffer_id = HashMap::default();
let mut rows = buffer::Entity::find()
.filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
channel_ids_by_buffer_id.insert(row.id, row.channel_id);
}
drop(rows);
let mut observed_edits_by_buffer_id = HashMap::default();
let mut rows = observed_buffer_edits::Entity::find()
.filter(observed_buffer_edits::Column::UserId.eq(user_id))
.filter(
observed_buffer_edits::Column::BufferId
.is_in(channel_ids_by_buffer_id.keys().copied()),
)
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
observed_edits_by_buffer_id.insert(row.buffer_id, row);
}
drop(rows);
let latest_operations = self
.get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
.await?;
let mut changes = Vec::default();
for latest in latest_operations {
if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
if (
observed.epoch,
observed.lamport_timestamp,
observed.replica_id,
) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
{
continue;
}
}
if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
changes.push(proto::UnseenChannelBufferChange {
channel_id: channel_id.to_proto(),
epoch: latest.epoch as u64,
version: vec![proto::VectorClockEntry {
replica_id: latest.replica_id as u32,
timestamp: latest.lamport_timestamp as u32,
}],
});
}
}
Ok(changes)
}
pub async fn get_latest_operations_for_buffers(
&self,
buffer_ids: impl IntoIterator<Item = BufferId>,
tx: &DatabaseTransaction,
) -> Result<Vec<buffer_operation::Model>> {
let mut values = String::new();
for id in buffer_ids {
if !values.is_empty() {
values.push_str(", ");
}
write!(&mut values, "({})", id).unwrap();
}
if values.is_empty() {
return Ok(Vec::default());
}
let sql = format!(
r#"
SELECT
*
FROM
(
SELECT
*,
row_number() OVER (
PARTITION BY buffer_id
ORDER BY
epoch DESC,
lamport_timestamp DESC,
replica_id DESC
) as row_number
FROM buffer_operations
WHERE
buffer_id in ({values})
) AS last_operations
WHERE
row_number = 1
"#,
);
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
Ok(buffer_operation::Entity::find()
.from_raw_sql(stmt)
.all(&*tx)
.await?)
}
}
fn operation_to_storage(

View File

@@ -1,7 +1,8 @@
use super::*;
use rpc::proto::ChannelEdge;
use smallvec::SmallVec;
use super::*;
type ChannelDescendants = HashMap<ChannelId, SmallSet<ChannelId>>;
impl Database {
@@ -19,14 +20,21 @@ impl Database {
.await
}
pub async fn create_root_channel(&self, name: &str, creator_id: UserId) -> Result<ChannelId> {
self.create_channel(name, None, creator_id).await
pub async fn create_root_channel(
&self,
name: &str,
live_kit_room: &str,
creator_id: UserId,
) -> Result<ChannelId> {
self.create_channel(name, None, live_kit_room, creator_id)
.await
}
pub async fn create_channel(
&self,
name: &str,
parent: Option<ChannelId>,
live_kit_room: &str,
creator_id: UserId,
) -> Result<ChannelId> {
let name = Self::sanitize_channel_name(name)?;
@@ -83,6 +91,14 @@ impl Database {
.insert(&*tx)
.await?;
room::ActiveModel {
channel_id: ActiveValue::Set(Some(channel.id)),
live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
..Default::default()
}
.insert(&*tx)
.await?;
Ok(channel.id)
})
.await
@@ -375,8 +391,7 @@ impl Database {
.all(&*tx)
.await?;
self.get_user_channels(user_id, channel_memberships, &tx)
.await
self.get_user_channels(channel_memberships, &tx).await
})
.await
}
@@ -399,15 +414,13 @@ impl Database {
.all(&*tx)
.await?;
self.get_user_channels(user_id, channel_membership, &tx)
.await
self.get_user_channels(channel_membership, &tx).await
})
.await
}
pub async fn get_user_channels(
&self,
user_id: UserId,
channel_memberships: Vec<channel_member::Model>,
tx: &DatabaseTransaction,
) -> Result<ChannelsForUser> {
@@ -447,21 +460,10 @@ impl Database {
}
}
let channel_ids = graph.channels.iter().map(|c| c.id).collect::<Vec<_>>();
let channel_buffer_changes = self
.unseen_channel_buffer_changes(user_id, &channel_ids, &*tx)
.await?;
let unseen_messages = self
.unseen_channel_messages(user_id, &channel_ids, &*tx)
.await?;
Ok(ChannelsForUser {
channels: graph,
channel_participants,
channels_with_admin_privileges,
unseen_buffer_changes: channel_buffer_changes,
channel_messages: unseen_messages,
})
}
@@ -643,7 +645,7 @@ impl Database {
) -> Result<Vec<ChannelId>> {
let paths = channel_path::Entity::find()
.filter(channel_path::Column::ChannelId.eq(channel_id))
.order_by(channel_path::Column::IdPath, sea_orm::Order::Desc)
.order_by(channel_path::Column::IdPath, sea_query::Order::Desc)
.all(tx)
.await?;
let mut channel_ids = Vec::new();
@@ -782,36 +784,18 @@ impl Database {
.await
}
pub async fn get_or_create_channel_room(
&self,
channel_id: ChannelId,
live_kit_room: &str,
enviroment: &str,
) -> Result<RoomId> {
pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
self.transaction(|tx| async move {
let tx = tx;
let room = room::Entity::find()
.filter(room::Column::ChannelId.eq(channel_id))
.one(&*tx)
.await?;
let room_id = if let Some(room) = room {
room.id
} else {
let result = room::Entity::insert(room::ActiveModel {
channel_id: ActiveValue::Set(Some(channel_id)),
live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
enviroment: ActiveValue::Set(Some(enviroment.to_string())),
..Default::default()
})
.exec(&*tx)
.await?;
result.last_insert_id
};
Ok(room_id)
let room = channel::Model {
id: channel_id,
..Default::default()
}
.find_related(room::Entity)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("invalid channel"))?;
Ok(room.id)
})
.await
}

View File

@@ -18,12 +18,12 @@ impl Database {
let user_b_participant = Alias::new("user_b_participant");
let mut db_contacts = contact::Entity::find()
.column_as(
Expr::col((user_a_participant.clone(), room_participant::Column::Id))
Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
.is_not_null(),
"user_a_busy",
)
.column_as(
Expr::col((user_b_participant.clone(), room_participant::Column::Id))
Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
.is_not_null(),
"user_b_busy",
)

View File

@@ -89,7 +89,6 @@ impl Database {
let mut rows = channel_message::Entity::find()
.filter(condition)
.order_by_desc(channel_message::Column::Id)
.limit(count as u64)
.stream(&*tx)
.await?;
@@ -109,8 +108,7 @@ impl Database {
}),
});
}
drop(rows);
messages.reverse();
Ok(messages)
})
.await
@@ -123,7 +121,7 @@ impl Database {
body: &str,
timestamp: OffsetDateTime,
nonce: u128,
) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
) -> Result<(MessageId, Vec<ConnectionId>)> {
self.transaction(|tx| async move {
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
@@ -132,13 +130,11 @@ impl Database {
let mut is_participant = false;
let mut participant_connection_ids = Vec::new();
let mut participant_user_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_user_ids.push(row.user_id);
participant_connection_ids.push(row.connection());
}
drop(rows);
@@ -171,141 +167,11 @@ impl Database {
ConnectionId,
}
// Observe this message for the sender
self.observe_channel_message_internal(
channel_id,
user_id,
message.last_insert_id,
&*tx,
)
.await?;
let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
channel_members.retain(|member| !participant_user_ids.contains(member));
Ok((
message.last_insert_id,
participant_connection_ids,
channel_members,
))
Ok((message.last_insert_id, participant_connection_ids))
})
.await
}
pub async fn observe_channel_message(
&self,
channel_id: ChannelId,
user_id: UserId,
message_id: MessageId,
) -> Result<()> {
self.transaction(|tx| async move {
self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
.await?;
Ok(())
})
.await
}
async fn observe_channel_message_internal(
&self,
channel_id: ChannelId,
user_id: UserId,
message_id: MessageId,
tx: &DatabaseTransaction,
) -> Result<()> {
observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
user_id: ActiveValue::Set(user_id),
channel_id: ActiveValue::Set(channel_id),
channel_message_id: ActiveValue::Set(message_id),
})
.on_conflict(
OnConflict::columns([
observed_channel_messages::Column::ChannelId,
observed_channel_messages::Column::UserId,
])
.update_column(observed_channel_messages::Column::ChannelMessageId)
.action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
.to_owned(),
)
// TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
.exec_without_returning(&*tx)
.await?;
Ok(())
}
pub async fn unseen_channel_messages(
&self,
user_id: UserId,
channel_ids: &[ChannelId],
tx: &DatabaseTransaction,
) -> Result<Vec<proto::UnseenChannelMessage>> {
let mut observed_messages_by_channel_id = HashMap::default();
let mut rows = observed_channel_messages::Entity::find()
.filter(observed_channel_messages::Column::UserId.eq(user_id))
.filter(observed_channel_messages::Column::ChannelId.is_in(channel_ids.iter().copied()))
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await {
let row = row?;
observed_messages_by_channel_id.insert(row.channel_id, row);
}
drop(rows);
let mut values = String::new();
for id in channel_ids {
if !values.is_empty() {
values.push_str(", ");
}
write!(&mut values, "({})", id).unwrap();
}
if values.is_empty() {
return Ok(Default::default());
}
let sql = format!(
r#"
SELECT
*
FROM (
SELECT
*,
row_number() OVER (
PARTITION BY channel_id
ORDER BY id DESC
) as row_number
FROM channel_messages
WHERE
channel_id in ({values})
) AS messages
WHERE
row_number = 1
"#,
);
let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
let last_messages = channel_message::Model::find_by_statement(stmt)
.all(&*tx)
.await?;
let mut changes = Vec::new();
for last_message in last_messages {
if let Some(observed_message) =
observed_messages_by_channel_id.get(&last_message.channel_id)
{
if observed_message.channel_message_id == last_message.id {
continue;
}
}
changes.push(proto::UnseenChannelMessage {
channel_id: last_message.channel_id.to_proto(),
message_id: last_message.id.to_proto(),
});
}
Ok(changes)
}
pub async fn remove_channel_message(
&self,
channel_id: ChannelId,

View File

@@ -738,7 +738,7 @@ impl Database {
Condition::any()
.add(
Condition::all()
.add(follower::Column::ProjectId.eq(Some(project_id)))
.add(follower::Column::ProjectId.eq(project_id))
.add(
follower::Column::LeaderConnectionServerId
.eq(connection.owner_id),
@@ -747,7 +747,7 @@ impl Database {
)
.add(
Condition::all()
.add(follower::Column::ProjectId.eq(Some(project_id)))
.add(follower::Column::ProjectId.eq(project_id))
.add(
follower::Column::FollowerConnectionServerId
.eq(connection.owner_id),
@@ -862,46 +862,13 @@ impl Database {
.await
}
pub async fn check_room_participants(
&self,
room_id: RoomId,
leader_id: ConnectionId,
follower_id: ConnectionId,
) -> Result<()> {
self.transaction(|tx| async move {
use room_participant::Column;
let count = room_participant::Entity::find()
.filter(
Condition::all().add(Column::RoomId.eq(room_id)).add(
Condition::any()
.add(Column::AnsweringConnectionId.eq(leader_id.id as i32).and(
Column::AnsweringConnectionServerId.eq(leader_id.owner_id as i32),
))
.add(Column::AnsweringConnectionId.eq(follower_id.id as i32).and(
Column::AnsweringConnectionServerId.eq(follower_id.owner_id as i32),
)),
),
)
.count(&*tx)
.await?;
if count < 2 {
Err(anyhow!("not room participants"))?;
}
Ok(())
})
.await
}
pub async fn follow(
&self,
room_id: RoomId,
project_id: ProjectId,
leader_connection: ConnectionId,
follower_connection: ConnectionId,
) -> Result<RoomGuard<proto::Room>> {
let room_id = self.room_id_for_project(project_id).await?;
self.room_transaction(room_id, |tx| async move {
follower::ActiveModel {
room_id: ActiveValue::set(room_id),
@@ -927,16 +894,15 @@ impl Database {
pub async fn unfollow(
&self,
room_id: RoomId,
project_id: ProjectId,
leader_connection: ConnectionId,
follower_connection: ConnectionId,
) -> Result<RoomGuard<proto::Room>> {
let room_id = self.room_id_for_project(project_id).await?;
self.room_transaction(room_id, |tx| async move {
follower::Entity::delete_many()
.filter(
Condition::all()
.add(follower::Column::RoomId.eq(room_id))
.add(follower::Column::ProjectId.eq(project_id))
.add(
follower::Column::LeaderConnectionServerId

View File

@@ -107,12 +107,10 @@ impl Database {
user_id: UserId,
connection: ConnectionId,
live_kit_room: &str,
release_channel: &str,
) -> Result<proto::Room> {
self.transaction(|tx| async move {
let room = room::ActiveModel {
live_kit_room: ActiveValue::set(live_kit_room.into()),
enviroment: ActiveValue::set(Some(release_channel.to_string())),
..Default::default()
}
.insert(&*tx)
@@ -130,7 +128,6 @@ impl Database {
calling_connection_server_id: ActiveValue::set(Some(ServerId(
connection.owner_id as i32,
))),
participant_index: ActiveValue::set(Some(0)),
..Default::default()
}
.insert(&*tx)
@@ -155,7 +152,6 @@ impl Database {
room_id: ActiveValue::set(room_id),
user_id: ActiveValue::set(called_user_id),
answering_connection_lost: ActiveValue::set(false),
participant_index: ActiveValue::NotSet,
calling_user_id: ActiveValue::set(calling_user_id),
calling_connection_id: ActiveValue::set(calling_connection.id as i32),
calling_connection_server_id: ActiveValue::set(Some(ServerId(
@@ -272,52 +268,20 @@ impl Database {
room_id: RoomId,
user_id: UserId,
connection: ConnectionId,
enviroment: &str,
) -> Result<RoomGuard<JoinRoom>> {
self.room_transaction(room_id, |tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryChannelIdAndEnviroment {
enum QueryChannelId {
ChannelId,
Enviroment,
}
let (channel_id, release_channel): (Option<ChannelId>, Option<String>) =
room::Entity::find()
.select_only()
.column(room::Column::ChannelId)
.column(room::Column::Enviroment)
.filter(room::Column::Id.eq(room_id))
.into_values::<_, QueryChannelIdAndEnviroment>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
if let Some(release_channel) = release_channel {
if &release_channel != enviroment {
Err(anyhow!("must join using the {} release", release_channel))?;
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryParticipantIndices {
ParticipantIndex,
}
let existing_participant_indices: Vec<i32> = room_participant::Entity::find()
.filter(
room_participant::Column::RoomId
.eq(room_id)
.and(room_participant::Column::ParticipantIndex.is_not_null()),
)
let channel_id: Option<ChannelId> = room::Entity::find()
.select_only()
.column(room_participant::Column::ParticipantIndex)
.into_values::<_, QueryParticipantIndices>()
.all(&*tx)
.await?;
let mut participant_index = 0;
while existing_participant_indices.contains(&participant_index) {
participant_index += 1;
}
.column(room::Column::ChannelId)
.filter(room::Column::Id.eq(room_id))
.into_values::<_, QueryChannelId>()
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such room"))?;
if let Some(channel_id) = channel_id {
self.check_user_is_channel_member(channel_id, user_id, &*tx)
@@ -336,7 +300,6 @@ impl Database {
calling_connection_server_id: ActiveValue::set(Some(ServerId(
connection.owner_id as i32,
))),
participant_index: ActiveValue::Set(Some(participant_index)),
..Default::default()
}])
.on_conflict(
@@ -345,7 +308,6 @@ impl Database {
room_participant::Column::AnsweringConnectionId,
room_participant::Column::AnsweringConnectionServerId,
room_participant::Column::AnsweringConnectionLost,
room_participant::Column::ParticipantIndex,
])
.to_owned(),
)
@@ -360,7 +322,6 @@ impl Database {
.add(room_participant::Column::AnsweringConnectionId.is_null()),
)
.set(room_participant::ActiveModel {
participant_index: ActiveValue::Set(Some(participant_index)),
answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
answering_connection_server_id: ActiveValue::set(Some(ServerId(
connection.owner_id as i32,
@@ -832,7 +793,10 @@ impl Database {
let (channel_id, room) = self.get_channel_room(room_id, &tx).await?;
let deleted = if room.participants.is_empty() {
let result = room::Entity::delete_by_id(room_id).exec(&*tx).await?;
let result = room::Entity::delete_by_id(room_id)
.filter(room::Column::ChannelId.is_null())
.exec(&*tx)
.await?;
result.rows_affected > 0
} else {
false
@@ -996,39 +960,6 @@ impl Database {
Ok(room)
}
pub async fn room_connection_ids(
&self,
room_id: RoomId,
connection_id: ConnectionId,
) -> Result<RoomGuard<HashSet<ConnectionId>>> {
self.room_transaction(room_id, |tx| async move {
let mut participants = room_participant::Entity::find()
.filter(room_participant::Column::RoomId.eq(room_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut connection_ids = HashSet::default();
while let Some(participant) = participants.next().await {
let participant = participant?;
if let Some(answering_connection) = participant.answering_connection() {
if answering_connection == connection_id {
is_participant = true;
} else {
connection_ids.insert(answering_connection);
}
}
}
if !is_participant {
Err(anyhow!("not a room participant"))?;
}
Ok(connection_ids)
})
.await
}
async fn get_channel_room(
&self,
room_id: RoomId,
@@ -1047,15 +978,10 @@ impl Database {
let mut pending_participants = Vec::new();
while let Some(db_participant) = db_participants.next().await {
let db_participant = db_participant?;
if let (
Some(answering_connection_id),
Some(answering_connection_server_id),
Some(participant_index),
) = (
db_participant.answering_connection_id,
db_participant.answering_connection_server_id,
db_participant.participant_index,
) {
if let Some((answering_connection_id, answering_connection_server_id)) = db_participant
.answering_connection_id
.zip(db_participant.answering_connection_server_id)
{
let location = match (
db_participant.location_kind,
db_participant.location_project_id,
@@ -1086,7 +1012,6 @@ impl Database {
peer_id: Some(answering_connection.into()),
projects: Default::default(),
location: Some(proto::ParticipantLocation { variant: location }),
participant_index: participant_index as u32,
},
);
} else {

View File

@@ -184,7 +184,7 @@ impl Database {
Ok(user::Entity::find()
.from_raw_sql(Statement::from_sql_and_values(
self.pool.get_database_backend(),
query,
query.into(),
vec![like_string.into(), name_query.into(), limit.into()],
))
.all(&*tx)

View File

@@ -12,8 +12,6 @@ pub mod contact;
pub mod feature_flag;
pub mod follower;
pub mod language_server;
pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod project;
pub mod project_collaborator;
pub mod room;

View File

@@ -1,43 +0,0 @@
use crate::db::{BufferId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "observed_buffer_edits")]
pub struct Model {
#[sea_orm(primary_key)]
pub user_id: UserId,
pub buffer_id: BufferId,
pub epoch: i32,
pub lamport_timestamp: i32,
pub replica_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::buffer::Entity",
from = "Column::BufferId",
to = "super::buffer::Column::Id"
)]
Buffer,
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::buffer::Entity> for Entity {
fn to() -> RelationDef {
Relation::Buffer.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -1,41 +0,0 @@
use crate::db::{ChannelId, MessageId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "observed_channel_messages")]
pub struct Model {
#[sea_orm(primary_key)]
pub user_id: UserId,
pub channel_id: ChannelId,
pub channel_message_id: MessageId,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::channel::Entity",
from = "Column::ChannelId",
to = "super::channel::Column::Id"
)]
Channel,
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::channel::Entity> for Entity {
fn to() -> RelationDef {
Relation::Channel.def()
}
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View File

@@ -8,7 +8,6 @@ pub struct Model {
pub id: RoomId,
pub live_kit_room: String,
pub channel_id: Option<ChannelId>,
pub enviroment: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@@ -1,5 +1,4 @@
use crate::db::{ProjectId, RoomId, RoomParticipantId, ServerId, UserId};
use rpc::ConnectionId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
@@ -18,16 +17,6 @@ pub struct Model {
pub calling_user_id: UserId,
pub calling_connection_id: i32,
pub calling_connection_server_id: Option<ServerId>,
pub participant_index: Option<i32>,
}
impl Model {
pub fn answering_connection(&self) -> Option<ConnectionId> {
Some(ConnectionId {
owner_id: self.answering_connection_server_id?.0 as u32,
id: self.answering_connection_id? as u32,
})
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

View File

@@ -12,8 +12,6 @@ use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::sync::Arc;
const TEST_RELEASE_CHANNEL: &'static str = "test";
pub struct TestDb {
pub db: Option<Arc<Database>>,
pub connection: Option<sqlx::AnyConnection>,
@@ -41,7 +39,7 @@ impl TestDb {
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
sql,
sql.into(),
))
.await
.unwrap();
@@ -136,7 +134,7 @@ impl Drop for TestDb {
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
query,
query.into(),
))
.await
.log_err();

View File

@@ -1,6 +1,6 @@
use super::*;
use crate::test_both_dbs;
use language::proto::{self, serialize_version};
use language::proto;
use text::Buffer;
test_both_dbs!(
@@ -54,7 +54,7 @@ async fn test_channel_buffers(db: &Arc<Database>) {
let owner_id = db.create_server("production").await.unwrap().0 as u32;
let zed_id = db.create_root_channel("zed", a_id).await.unwrap();
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
db.invite_channel_member(zed_id, b_id, a_id, false)
.await
@@ -134,14 +134,14 @@ async fn test_channel_buffers(db: &Arc<Database>) {
let zed_collaborats = db.get_channel_buffer_collaborators(zed_id).await.unwrap();
assert_eq!(zed_collaborats, &[a_id, b_id]);
let left_buffer = db
let collaborators = db
.leave_channel_buffer(zed_id, connection_id_b)
.await
.unwrap();
assert_eq!(left_buffer.connections, &[connection_id_a],);
assert_eq!(collaborators, &[connection_id_a],);
let cargo_id = db.create_root_channel("cargo", a_id).await.unwrap();
let cargo_id = db.create_root_channel("cargo", "2", a_id).await.unwrap();
let _ = db
.join_channel_buffer(cargo_id, a_id, connection_id_a)
.await
@@ -163,349 +163,3 @@ async fn test_channel_buffers(db: &Arc<Database>) {
assert_eq!(buffer_response_b.base_text, "hello, cruel world");
assert_eq!(buffer_response_b.operations, &[]);
}
test_both_dbs!(
test_channel_buffers_last_operations,
test_channel_buffers_last_operations_postgres,
test_channel_buffers_last_operations_sqlite
);
async fn test_channel_buffers_last_operations(db: &Database) {
let user_id = db
.create_user(
"user_a@example.com",
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 101,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let observer_id = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_b".into(),
github_user_id: 102,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let owner_id = db.create_server("production").await.unwrap().0 as u32;
let connection_id = ConnectionId {
owner_id,
id: user_id.0 as u32,
};
let mut buffers = Vec::new();
let mut text_buffers = Vec::new();
for i in 0..3 {
let channel = db
.create_root_channel(&format!("channel-{i}"), user_id)
.await
.unwrap();
db.invite_channel_member(channel, observer_id, user_id, false)
.await
.unwrap();
db.respond_to_channel_invite(channel, observer_id, true)
.await
.unwrap();
db.join_channel_buffer(channel, user_id, connection_id)
.await
.unwrap();
buffers.push(
db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await })
.await
.unwrap(),
);
text_buffers.push(Buffer::new(0, 0, "".to_string()));
}
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_latest_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
.await
}
})
.await
.unwrap();
assert!(operations.is_empty());
update_buffer(
buffers[0].channel_id,
user_id,
db,
vec![
text_buffers[0].edit([(0..0, "a")]),
text_buffers[0].edit([(0..0, "b")]),
text_buffers[0].edit([(0..0, "c")]),
],
)
.await;
update_buffer(
buffers[1].channel_id,
user_id,
db,
vec![
text_buffers[1].edit([(0..0, "d")]),
text_buffers[1].edit([(1..1, "e")]),
text_buffers[1].edit([(2..2, "f")]),
],
)
.await;
// cause buffer 1's epoch to increment.
db.leave_channel_buffer(buffers[1].channel_id, connection_id)
.await
.unwrap();
db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id)
.await
.unwrap();
text_buffers[1] = Buffer::new(1, 0, "def".to_string());
update_buffer(
buffers[1].channel_id,
user_id,
db,
vec![
text_buffers[1].edit([(0..0, "g")]),
text_buffers[1].edit([(0..0, "h")]),
],
)
.await;
update_buffer(
buffers[2].channel_id,
user_id,
db,
vec![text_buffers[2].edit([(0..0, "i")])],
)
.await;
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_latest_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
.await
}
})
.await
.unwrap();
assert_operations(
&operations,
&[
(buffers[1].id, 1, &text_buffers[1]),
(buffers[2].id, 0, &text_buffers[2]),
],
);
let operations = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.get_latest_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
.await
}
})
.await
.unwrap();
assert_operations(
&operations,
&[
(buffers[0].id, 0, &text_buffers[0]),
(buffers[1].id, 1, &text_buffers[1]),
],
);
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
buffers[1].channel_id,
buffers[2].channel_id,
],
&*tx,
)
.await
}
})
.await
.unwrap();
pretty_assertions::assert_eq!(
buffer_changes,
[
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[1].channel_id.to_proto(),
epoch: 1,
version: serialize_version(&text_buffers[1].version())
.into_iter()
.filter(|vector| vector.replica_id
== buffer_changes[1].version.first().unwrap().replica_id)
.collect::<Vec<_>>(),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
);
db.observe_buffer_version(
buffers[1].id,
observer_id,
1,
serialize_version(&text_buffers[1].version()).as_slice(),
)
.await
.unwrap();
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
buffers[1].channel_id,
buffers[2].channel_id,
],
&*tx,
)
.await
}
})
.await
.unwrap();
assert_eq!(
buffer_changes,
[
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
);
// Observe an earlier version of the buffer.
db.observe_buffer_version(
buffers[1].id,
observer_id,
1,
&[rpc::proto::VectorClockEntry {
replica_id: 0,
timestamp: 0,
}],
)
.await
.unwrap();
let buffer_changes = db
.transaction(|tx| {
let buffers = &buffers;
async move {
db.unseen_channel_buffer_changes(
observer_id,
&[
buffers[0].channel_id,
buffers[1].channel_id,
buffers[2].channel_id,
],
&*tx,
)
.await
}
})
.await
.unwrap();
assert_eq!(
buffer_changes,
[
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[0].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[0].version()),
},
rpc::proto::UnseenChannelBufferChange {
channel_id: buffers[2].channel_id.to_proto(),
epoch: 0,
version: serialize_version(&text_buffers[2].version()),
},
]
);
}
async fn update_buffer(
channel_id: ChannelId,
user_id: UserId,
db: &Database,
operations: Vec<text::Operation>,
) {
let operations = operations
.into_iter()
.map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
.collect::<Vec<_>>();
db.update_channel_buffer(channel_id, user_id, &operations)
.await
.unwrap();
}
fn assert_operations(
operations: &[buffer_operation::Model],
expected: &[(BufferId, i32, &text::Buffer)],
) {
let actual = operations
.iter()
.map(|op| buffer_operation::Model {
buffer_id: op.buffer_id,
epoch: op.epoch,
lamport_timestamp: op.lamport_timestamp,
replica_id: op.replica_id,
value: vec![],
})
.collect::<Vec<_>>();
let expected = expected
.iter()
.map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
buffer_id: *buffer_id,
epoch: *epoch,
lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
replica_id: buffer.replica_id() as i32,
value: vec![],
})
.collect::<Vec<_>>();
assert_eq!(actual, expected, "unexpected operations")
}

View File

@@ -5,11 +5,7 @@ use rpc::{
};
use crate::{
db::{
queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL},
ChannelId, Database, NewUserParams,
},
db::{queries::channels::ChannelGraph, tests::graph, ChannelId, Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
@@ -45,7 +41,7 @@ async fn test_channels(db: &Arc<Database>) {
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", a_id).await.unwrap();
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
// Make sure that people cannot read channels they haven't been invited to
assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
@@ -58,13 +54,16 @@ async fn test_channels(db: &Arc<Database>) {
.await
.unwrap();
let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap();
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(zed_id), a_id)
.create_channel("livestreaming", Some(zed_id), "3", a_id)
.await
.unwrap();
let replace_id = db
.create_channel("replace", Some(zed_id), a_id)
.create_channel("replace", Some(zed_id), "4", a_id)
.await
.unwrap();
@@ -72,14 +71,14 @@ async fn test_channels(db: &Arc<Database>) {
members.sort();
assert_eq!(members, &[a_id, b_id]);
let rust_id = db.create_root_channel("rust", a_id).await.unwrap();
let rust_id = db.create_root_channel("rust", "5", a_id).await.unwrap();
let cargo_id = db
.create_channel("cargo", Some(rust_id), a_id)
.create_channel("cargo", Some(rust_id), "6", a_id)
.await
.unwrap();
let cargo_ra_id = db
.create_channel("cargo-ra", Some(cargo_id), a_id)
.create_channel("cargo-ra", Some(cargo_id), "7", a_id)
.await
.unwrap();
@@ -199,20 +198,15 @@ async fn test_joining_channels(db: &Arc<Database>) {
.unwrap()
.user_id;
let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap();
let room_1 = db
.get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL)
let channel_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
// can join a room with membership to its channel
let joined_room = db
.join_room(
room_1,
user_1,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL,
)
.join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
.await
.unwrap();
assert_eq!(joined_room.room.participants.len(), 1);
@@ -220,12 +214,7 @@ async fn test_joining_channels(db: &Arc<Database>) {
drop(joined_room);
// cannot join a room without membership to its channel
assert!(db
.join_room(
room_1,
user_2,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL
)
.join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
.await
.is_err());
}
@@ -280,9 +269,15 @@ async fn test_channel_invites(db: &Arc<Database>) {
.unwrap()
.user_id;
let channel_1_1 = db.create_root_channel("channel_1", user_1).await.unwrap();
let channel_1_1 = db
.create_root_channel("channel_1", "1", user_1)
.await
.unwrap();
let channel_1_2 = db.create_root_channel("channel_2", user_1).await.unwrap();
let channel_1_2 = db
.create_root_channel("channel_2", "2", user_1)
.await
.unwrap();
db.invite_channel_member(channel_1_1, user_2, user_1, false)
.await
@@ -344,7 +339,7 @@ async fn test_channel_invites(db: &Arc<Database>) {
.unwrap();
let channel_1_3 = db
.create_channel("channel_3", Some(channel_1_1), user_1)
.create_channel("channel_3", Some(channel_1_1), "1", user_1)
.await
.unwrap();
@@ -406,7 +401,7 @@ async fn test_channel_renames(db: &Arc<Database>) {
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", user_1).await.unwrap();
let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
db.rename_channel(zed_id, user_1, "#zed-archive")
.await
@@ -451,22 +446,25 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", a_id).await.unwrap();
let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
let crdb_id = db.create_channel("crdb", Some(zed_id), a_id).await.unwrap();
let crdb_id = db
.create_channel("crdb", Some(zed_id), "2", a_id)
.await
.unwrap();
let gpui2_id = db
.create_channel("gpui2", Some(zed_id), a_id)
.create_channel("gpui2", Some(zed_id), "3", a_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(crdb_id), a_id)
.create_channel("livestreaming", Some(crdb_id), "4", a_id)
.await
.unwrap();
let livestreaming_dag_id = db
.create_channel("livestreaming_dag", Some(livestreaming_id), a_id)
.create_channel("livestreaming_dag", Some(livestreaming_id), "5", a_id)
.await
.unwrap();
@@ -519,7 +517,12 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
// ========================================================================
// Create a new channel below a channel with multiple parents
let livestreaming_dag_sub_id = db
.create_channel("livestreaming_dag_sub", Some(livestreaming_dag_id), a_id)
.create_channel(
"livestreaming_dag_sub",
Some(livestreaming_dag_id),
"6",
a_id,
)
.await
.unwrap();
@@ -809,15 +812,15 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
.unwrap()
.user_id;
let zed_id = db.create_root_channel("zed", user_id).await.unwrap();
let zed_id = db.create_root_channel("zed", "1", user_id).await.unwrap();
let projects_id = db
.create_channel("projects", Some(zed_id), user_id)
.create_channel("projects", Some(zed_id), "2", user_id)
.await
.unwrap();
let livestreaming_id = db
.create_channel("livestreaming", Some(projects_id), user_id)
.create_channel("livestreaming", Some(projects_id), "3", user_id)
.await
.unwrap();

View File

@@ -479,7 +479,7 @@ async fn test_project_count(db: &Arc<Database>) {
.unwrap();
let room_id = RoomId::from_proto(
db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "", "dev")
db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "")
.await
.unwrap()
.id,
@@ -493,14 +493,9 @@ async fn test_project_count(db: &Arc<Database>) {
)
.await
.unwrap();
db.join_room(
room_id,
user2.user_id,
ConnectionId { owner_id, id: 1 },
"dev",
)
.await
.unwrap();
db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 })
.await
.unwrap();
assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
@@ -580,85 +575,6 @@ async fn test_fuzzy_search_users() {
}
}
test_both_dbs!(
test_non_matching_release_channels,
test_non_matching_release_channels_postgres,
test_non_matching_release_channels_sqlite
);
async fn test_non_matching_release_channels(db: &Arc<Database>) {
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user1 = db
.create_user(
&format!("admin@example.com"),
true,
NewUserParams {
github_login: "admin".into(),
github_user_id: 0,
invite_count: 0,
},
)
.await
.unwrap();
let user2 = db
.create_user(
&format!("user@example.com"),
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap();
let room = db
.create_room(
user1.user_id,
ConnectionId { owner_id, id: 0 },
"",
"stable",
)
.await
.unwrap();
db.call(
RoomId::from_proto(room.id),
user1.user_id,
ConnectionId { owner_id, id: 0 },
user2.user_id,
None,
)
.await
.unwrap();
// User attempts to join from preview
let result = db
.join_room(
RoomId::from_proto(room.id),
user2.user_id,
ConnectionId { owner_id, id: 1 },
"preview",
)
.await;
assert!(result.is_err());
// User switches to stable
let result = db
.join_room(
RoomId::from_proto(room.id),
user2.user_id,
ConnectionId { owner_id, id: 1 },
"stable",
)
.await;
assert!(result.is_ok())
}
fn build_background_executor() -> Arc<Background> {
Deterministic::new(0).build_background()
}

View File

@@ -1,72 +1,10 @@
use crate::{
db::{Database, MessageId, NewUserParams},
db::{Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;
use time::OffsetDateTime;
test_both_dbs!(
test_channel_message_retrieval,
test_channel_message_retrieval_postgres,
test_channel_message_retrieval_sqlite
);
async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = db
.create_user(
"user@example.com",
false,
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel = db.create_channel("channel", None, user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user)
.await
.unwrap();
let mut all_messages = Vec::new();
for i in 0..10 {
all_messages.push(
db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
.await
.unwrap()
.0
.to_proto(),
);
}
let messages = db
.get_channel_messages(channel, user, 3, None)
.await
.unwrap()
.into_iter()
.map(|message| message.id)
.collect::<Vec<_>>();
assert_eq!(messages, &all_messages[7..10]);
let messages = db
.get_channel_messages(
channel,
user,
4,
Some(MessageId::from_proto(all_messages[6])),
)
.await
.unwrap()
.into_iter()
.map(|message| message.id)
.collect::<Vec<_>>();
assert_eq!(messages, &all_messages[2..6]);
}
test_both_dbs!(
test_channel_message_nonces,
test_channel_message_nonces_postgres,
@@ -87,7 +25,10 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
.await
.unwrap()
.user_id;
let channel = db.create_channel("channel", None, user).await.unwrap();
let channel = db
.create_channel("channel", None, "room", user)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
@@ -116,182 +57,3 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
assert_eq!(msg1_id, msg3_id);
assert_eq!(msg2_id, msg4_id);
}
test_both_dbs!(
test_channel_message_new_notification,
test_channel_message_new_notification_postgres,
test_channel_message_new_notification_sqlite
);
async fn test_channel_message_new_notification(db: &Arc<Database>) {
let user = db
.create_user(
"user_a@example.com",
false,
NewUserParams {
github_login: "user_a".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let observer = db
.create_user(
"user_b@example.com",
false,
NewUserParams {
github_login: "user_b".into(),
github_user_id: 1,
invite_count: 0,
},
)
.await
.unwrap()
.user_id;
let channel_1 = db.create_channel("channel", None, user).await.unwrap();
let channel_2 = db.create_channel("channel-2", None, user).await.unwrap();
db.invite_channel_member(channel_1, observer, user, false)
.await
.unwrap();
db.respond_to_channel_invite(channel_1, observer, true)
.await
.unwrap();
db.invite_channel_member(channel_2, observer, user, false)
.await
.unwrap();
db.respond_to_channel_invite(channel_2, observer, true)
.await
.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
let user_connection_id = rpc::ConnectionId { owner_id, id: 0 };
db.join_channel_chat(channel_1, user_connection_id, user)
.await
.unwrap();
let _ = db
.create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1)
.await
.unwrap();
let (second_message, _, _) = db
.create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2)
.await
.unwrap();
let (third_message, _, _) = db
.create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3)
.await
.unwrap();
db.join_channel_chat(channel_2, user_connection_id, user)
.await
.unwrap();
let (fourth_message, _, _) = db
.create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4)
.await
.unwrap();
// Check that observer has new messages
let unseen_messages = db
.transaction(|tx| async move {
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
unseen_messages,
[
rpc::proto::UnseenChannelMessage {
channel_id: channel_1.to_proto(),
message_id: third_message.to_proto(),
},
rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
},
]
);
// Observe the second message
db.observe_channel_message(channel_1, observer, second_message)
.await
.unwrap();
// Make sure the observer still has a new message
let unseen_messages = db
.transaction(|tx| async move {
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
unseen_messages,
[
rpc::proto::UnseenChannelMessage {
channel_id: channel_1.to_proto(),
message_id: third_message.to_proto(),
},
rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
},
]
);
// Observe the third message,
db.observe_channel_message(channel_1, observer, third_message)
.await
.unwrap();
// Make sure the observer does not have a new method
let unseen_messages = db
.transaction(|tx| async move {
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
unseen_messages,
[rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
}]
);
// Observe the second message again, should not regress our observed state
db.observe_channel_message(channel_1, observer, second_message)
.await
.unwrap();
// Make sure the observer does not have a new message
let unseen_messages = db
.transaction(|tx| async move {
db.unseen_channel_messages(observer, &[channel_1, channel_2], &*tx)
.await
})
.await
.unwrap();
assert_eq!(
unseen_messages,
[rpc::proto::UnseenChannelMessage {
channel_id: channel_2.to_proto(),
message_id: fourth_message.to_proto(),
}]
);
}

View File

@@ -3,8 +3,8 @@ mod connection_pool;
use crate::{
auth,
db::{
self, BufferId, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId,
ServerId, User, UserId,
self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
UserId,
},
executor::Executor,
AppState, Result,
@@ -38,8 +38,8 @@ use lazy_static::lazy_static;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
self, Ack, AddChannelBufferCollaborator, AnyTypedEnvelope, ChannelEdge, EntityMessage,
EnvelopedMessage, LiveKitConnectionInfo, RequestMessage,
},
Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
};
@@ -63,7 +63,6 @@ use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument};
use util::channel::RELEASE_CHANNEL_NAME;
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@@ -275,9 +274,7 @@ impl Server {
.add_message_handler(unfollow)
.add_message_handler(update_followers)
.add_message_handler(update_diff_base)
.add_request_handler(get_private_user_info)
.add_message_handler(acknowledge_channel_message)
.add_message_handler(acknowledge_buffer_version);
.add_request_handler(get_private_user_info);
Arc::new(server)
}
@@ -316,16 +313,9 @@ impl Server {
.trace_err()
{
for connection_id in refreshed_channel_buffer.connection_ids {
peer.send(
connection_id,
proto::UpdateChannelBufferCollaborators {
channel_id: channel_id.to_proto(),
collaborators: refreshed_channel_buffer
.collaborators
.clone(),
},
)
.trace_err();
for message in &refreshed_channel_buffer.removed_collaborators {
peer.send(connection_id, message.clone()).trace_err();
}
}
}
}
@@ -938,6 +928,11 @@ async fn create_room(
util::async_iife!({
let live_kit = live_kit?;
live_kit
.create_room(live_kit_room.clone())
.await
.trace_err()?;
let token = live_kit
.room_token(&live_kit_room, &session.user_id.to_string())
.trace_err()?;
@@ -953,12 +948,7 @@ async fn create_room(
let room = session
.db()
.await
.create_room(
session.user_id,
session.connection_id,
&live_kit_room,
RELEASE_CHANNEL_NAME.as_str(),
)
.create_room(session.user_id, session.connection_id, &live_kit_room)
.await?;
response.send(proto::CreateRoomResponse {
@@ -980,12 +970,7 @@ async fn join_room(
let room = session
.db()
.await
.join_room(
room_id,
session.user_id,
session.connection_id,
RELEASE_CHANNEL_NAME.as_str(),
)
.join_room(room_id, session.user_id, session.connection_id)
.await?;
room_updated(&room.room, &session.peer);
room.into_inner()
@@ -1898,94 +1883,94 @@ async fn follow(
response: Response<proto::Follow>,
session: Session,
) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
let project_id = ProjectId::from_proto(request.project_id);
let leader_id = request
.leader_id
.ok_or_else(|| anyhow!("invalid leader id"))?
.into();
let follower_id = session.connection_id;
session
.db()
.await
.check_room_participants(room_id, leader_id, session.connection_id)
.await?;
{
let project_connection_ids = session
.db()
.await
.project_connection_ids(project_id, session.connection_id)
.await?;
let response_payload = session
if !project_connection_ids.contains(&leader_id) {
Err(anyhow!("no such peer"))?;
}
}
let mut response_payload = session
.peer
.forward_request(session.connection_id, leader_id, request)
.await?;
response_payload
.views
.retain(|view| view.leader_id != Some(follower_id.into()));
response.send(response_payload)?;
if let Some(project_id) = project_id {
let room = session
.db()
.await
.follow(room_id, project_id, leader_id, follower_id)
.await?;
room_updated(&room, &session.peer);
}
let room = session
.db()
.await
.follow(project_id, leader_id, follower_id)
.await?;
room_updated(&room, &session.peer);
Ok(())
}
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let project_id = request.project_id.map(ProjectId::from_proto);
let project_id = ProjectId::from_proto(request.project_id);
let leader_id = request
.leader_id
.ok_or_else(|| anyhow!("invalid leader id"))?
.into();
let follower_id = session.connection_id;
session
if !session
.db()
.await
.check_room_participants(room_id, leader_id, session.connection_id)
.await?;
.project_connection_ids(project_id, session.connection_id)
.await?
.contains(&leader_id)
{
Err(anyhow!("no such peer"))?;
}
session
.peer
.forward_send(session.connection_id, leader_id, request)?;
if let Some(project_id) = project_id {
let room = session
.db()
.await
.unfollow(room_id, project_id, leader_id, follower_id)
.await?;
room_updated(&room, &session.peer);
}
let room = session
.db()
.await
.unfollow(project_id, leader_id, follower_id)
.await?;
room_updated(&room, &session.peer);
Ok(())
}
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
let room_id = RoomId::from_proto(request.room_id);
let database = session.db.lock().await;
let project_id = ProjectId::from_proto(request.project_id);
let project_connection_ids = session
.db
.lock()
.await
.project_connection_ids(project_id, session.connection_id)
.await?;
let connection_ids = if let Some(project_id) = request.project_id {
let project_id = ProjectId::from_proto(project_id);
database
.project_connection_ids(project_id, session.connection_id)
.await?
} else {
database
.room_connection_ids(room_id, session.connection_id)
.await?
};
// For now, don't send view update messages back to that view's current leader.
let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
let leader_id = request.variant.as_ref().and_then(|variant| match variant {
proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
_ => None,
proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id,
});
for follower_peer_id in request.follower_ids.iter().copied() {
let follower_connection_id = follower_peer_id.into();
if Some(follower_peer_id) != connection_id_to_omit
&& connection_ids.contains(&follower_connection_id)
if project_connection_ids.contains(&follower_connection_id)
&& Some(follower_peer_id) != leader_id
{
session.peer.forward_send(
session.connection_id,
@@ -2201,10 +2186,15 @@ async fn create_channel(
session: Session,
) -> Result<()> {
let db = session.db().await;
let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
if let Some(live_kit) = session.live_kit_client.as_ref() {
live_kit.create_room(live_kit_room.clone()).await?;
}
let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
let id = db
.create_channel(&request.name, parent_id, session.user_id)
.create_channel(&request.name, parent_id, &live_kit_room, session.user_id)
.await?;
let channel = proto::Channel {
@@ -2571,8 +2561,6 @@ async fn respond_to_channel_invite(
name: channel.name,
}),
);
update.unseen_channel_messages = result.channel_messages;
update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
update.insert_edge = result.channels.edges;
update
.channel_participants
@@ -2609,23 +2597,15 @@ async fn join_channel(
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
let joined_room = {
leave_room_for_session(&session).await?;
let db = session.db().await;
let room_id = db
.get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
.await?;
let room_id = db.room_id_for_channel(channel_id).await?;
let joined_room = db
.join_room(
room_id,
session.user_id,
session.connection_id,
RELEASE_CHANNEL_NAME.as_str(),
)
.join_room(room_id, session.user_id, session.connection_id)
.await?;
let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
@@ -2678,12 +2658,18 @@ async fn join_channel_buffer(
.join_channel_buffer(channel_id, session.user_id, session.connection_id)
.await?;
let replica_id = open_response.replica_id;
let collaborators = open_response.collaborators.clone();
response.send(open_response)?;
let update = UpdateChannelBufferCollaborators {
let update = AddChannelBufferCollaborator {
channel_id: channel_id.to_proto(),
collaborators: collaborators.clone(),
collaborator: Some(proto::Collaborator {
user_id: session.user_id.to_proto(),
peer_id: Some(session.connection_id.into()),
replica_id,
}),
};
channel_buffer_updated(
session.connection_id,
@@ -2704,7 +2690,7 @@ async fn update_channel_buffer(
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let (collaborators, non_collaborators, epoch, version) = db
let collaborators = db
.update_channel_buffer(channel_id, session.user_id, &request.operations)
.await?;
@@ -2717,29 +2703,6 @@ async fn update_channel_buffer(
},
&session.peer,
);
let pool = &*session.connection_pool().await;
broadcast(
None,
non_collaborators
.iter()
.flat_map(|user_id| pool.user_connection_ids(*user_id)),
|peer_id| {
session.peer.send(
peer_id.into(),
proto::UpdateChannels {
unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
channel_id: channel_id.to_proto(),
epoch: epoch as u64,
version: version.clone(),
}],
..Default::default()
},
)
},
);
Ok(())
}
@@ -2753,8 +2716,8 @@ async fn rejoin_channel_buffers(
.rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
.await?;
for rejoined_buffer in &buffers {
let collaborators_to_notify = rejoined_buffer
for buffer in &buffers {
let collaborators_to_notify = buffer
.buffer
.collaborators
.iter()
@@ -2762,9 +2725,10 @@ async fn rejoin_channel_buffers(
channel_buffer_updated(
session.connection_id,
collaborators_to_notify,
&proto::UpdateChannelBufferCollaborators {
channel_id: rejoined_buffer.buffer.channel_id,
collaborators: rejoined_buffer.buffer.collaborators.clone(),
&proto::UpdateChannelBufferCollaborator {
channel_id: buffer.buffer.channel_id,
old_peer_id: Some(buffer.old_connection_id.into()),
new_peer_id: Some(session.connection_id.into()),
},
&session.peer,
);
@@ -2785,7 +2749,7 @@ async fn leave_channel_buffer(
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let left_buffer = db
let collaborators_to_notify = db
.leave_channel_buffer(channel_id, session.connection_id)
.await?;
@@ -2793,10 +2757,10 @@ async fn leave_channel_buffer(
channel_buffer_updated(
session.connection_id,
left_buffer.connections,
&proto::UpdateChannelBufferCollaborators {
collaborators_to_notify,
&proto::RemoveChannelBufferCollaborator {
channel_id: channel_id.to_proto(),
collaborators: left_buffer.collaborators,
peer_id: Some(session.connection_id.into()),
},
&session.peer,
);
@@ -2835,7 +2799,7 @@ async fn send_channel_message(
.ok_or_else(|| anyhow!("nonce can't be blank"))?;
let channel_id = ChannelId::from_proto(request.channel_id);
let (message_id, connection_ids, non_participants) = session
let (message_id, connection_ids) = session
.db()
.await
.create_channel_message(
@@ -2865,27 +2829,6 @@ async fn send_channel_message(
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})?;
let pool = &*session.connection_pool().await;
broadcast(
None,
non_participants
.iter()
.flat_map(|user_id| pool.user_connection_ids(*user_id)),
|peer_id| {
session.peer.send(
peer_id.into(),
proto::UpdateChannels {
unseen_channel_messages: vec![proto::UnseenChannelMessage {
channel_id: channel_id.to_proto(),
message_id: message_id.to_proto(),
}],
..Default::default()
},
)
},
);
Ok(())
}
@@ -2908,38 +2851,6 @@ async fn remove_channel_message(
Ok(())
}
async fn acknowledge_channel_message(
request: proto::AckChannelMessage,
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
session
.db()
.await
.observe_channel_message(channel_id, session.user_id, message_id)
.await?;
Ok(())
}
async fn acknowledge_buffer_version(
request: proto::AckBufferOperation,
session: Session,
) -> Result<()> {
let buffer_id = BufferId::from_proto(request.buffer_id);
session
.db()
.await
.observe_buffer_version(
buffer_id,
session.user_id,
request.epoch as i32,
&request.version,
)
.await?;
Ok(())
}
async fn join_channel_chat(
request: proto::JoinChannelChat,
response: Response<proto::JoinChannelChat>,
@@ -3075,8 +2986,6 @@ fn build_initial_channels_update(
});
}
update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
update.unseen_channel_messages = channels.channel_messages;
update.insert_edge = channels.channels.edges;
for (channel_id, participants) in channels.channel_participants {
@@ -3326,13 +3235,13 @@ async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
.leave_channel_buffers(session.connection_id)
.await?;
for left_buffer in left_channel_buffers {
for (channel_id, connections) in left_channel_buffers {
channel_buffer_updated(
session.connection_id,
left_buffer.connections,
&proto::UpdateChannelBufferCollaborators {
channel_id: left_buffer.channel_id.to_proto(),
collaborators: left_buffer.collaborators,
connections,
&proto::RemoveChannelBufferCollaborator {
channel_id: channel_id.to_proto(),
peer_id: Some(session.connection_id.into()),
},
&session.peer,
);

View File

@@ -4,7 +4,6 @@ use gpui::{ModelHandle, TestAppContext};
mod channel_buffer_tests;
mod channel_message_tests;
mod channel_tests;
mod following_tests;
mod integration_tests;
mod random_channel_buffer_tests;
mod random_project_collaboration_tests;

View File

@@ -3,17 +3,15 @@ use crate::{
tests::TestServer,
};
use call::ActiveCall;
use channel::{Channel, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
use client::ParticipantIndex;
use client::{Collaborator, UserId};
use channel::Channel;
use client::UserId;
use collab_ui::channel_view::ChannelView;
use collections::HashMap;
use editor::{Anchor, Editor, ToOffset};
use futures::future;
use gpui::{executor::Deterministic, ModelHandle, TestAppContext, ViewContext};
use rpc::{proto::PeerId, RECEIVE_TIMEOUT};
use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
use rpc::{proto, RECEIVE_TIMEOUT};
use serde_json::json;
use std::{ops::Range, sync::Arc};
use std::sync::Arc;
#[gpui::test]
async fn test_core_channel_buffers(
@@ -102,7 +100,7 @@ async fn test_core_channel_buffers(
channel_buffer_b.read_with(cx_b, |buffer, _| {
assert_collaborators(
&buffer.collaborators(),
&[client_a.user_id(), client_b.user_id()],
&[client_b.user_id(), client_a.user_id()],
);
});
@@ -122,10 +120,10 @@ async fn test_core_channel_buffers(
}
#[gpui::test]
async fn test_channel_notes_participant_indices(
async fn test_channel_buffer_replica_ids(
deterministic: Arc<Deterministic>,
mut cx_a: &mut TestAppContext,
mut cx_b: &mut TestAppContext,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
@@ -134,13 +132,6 @@ async fn test_channel_notes_participant_indices(
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
let active_call_a = cx_a.read(ActiveCall::global);
let active_call_b = cx_b.read(ActiveCall::global);
cx_a.update(editor::init);
cx_b.update(editor::init);
cx_c.update(editor::init);
let channel_id = server
.make_channel(
"the-channel",
@@ -150,173 +141,140 @@ async fn test_channel_notes_participant_indices(
)
.await;
client_a
let active_call_a = cx_a.read(ActiveCall::global);
let active_call_b = cx_b.read(ActiveCall::global);
let active_call_c = cx_c.read(ActiveCall::global);
// Clients A and B join a channel.
active_call_a
.update(cx_a, |call, cx| call.join_channel(channel_id, cx))
.await
.unwrap();
active_call_b
.update(cx_b, |call, cx| call.join_channel(channel_id, cx))
.await
.unwrap();
// Clients A, B, and C join a channel buffer
// C first so that the replica IDs in the project and the channel buffer are different
let channel_buffer_c = client_c
.channel_store()
.update(cx_c, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_b = client_b
.channel_store()
.update(cx_b, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
// Client B shares a project
client_b
.fs()
.insert_tree("/root", json!({"file.txt": "123"}))
.insert_tree("/dir", json!({ "file.txt": "contents" }))
.await;
let (project_a, worktree_id_a) = client_a.build_local_project("/root", cx_a).await;
let project_b = client_b.build_empty_local_project(cx_b);
let project_c = client_c.build_empty_local_project(cx_c);
let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a);
let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b);
let workspace_c = client_c.build_workspace(&project_c, cx_c).root(cx_c);
// Clients A, B, and C open the channel notes
let channel_view_a = cx_a
.update(|cx| ChannelView::open(channel_id, workspace_a.clone(), cx))
.await
.unwrap();
let channel_view_b = cx_b
.update(|cx| ChannelView::open(channel_id, workspace_b.clone(), cx))
.await
.unwrap();
let channel_view_c = cx_c
.update(|cx| ChannelView::open(channel_id, workspace_c.clone(), cx))
let (project_b, _) = client_b.build_local_project("/dir", cx_b).await;
let shared_project_id = active_call_b
.update(cx_b, |call, cx| call.share_project(project_b.clone(), cx))
.await
.unwrap();
// Clients A, B, and C all insert and select some text
channel_view_a.update(cx_a, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
editor.insert("a", cx);
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![0..1]);
});
});
});
deterministic.run_until_parked();
channel_view_b.update(cx_b, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
editor.move_down(&Default::default(), cx);
editor.insert("b", cx);
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![1..2]);
});
});
});
deterministic.run_until_parked();
channel_view_c.update(cx_c, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
editor.move_down(&Default::default(), cx);
editor.insert("c", cx);
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![2..3]);
});
});
});
// Client A sees clients B and C without assigned colors, because they aren't
// in a call together.
deterministic.run_until_parked();
channel_view_a.update(cx_a, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
assert_remote_selections(editor, &[(None, 1..2), (None, 2..3)], cx);
});
});
// Clients A and B join the same call.
for (call, cx) in [(&active_call_a, &mut cx_a), (&active_call_b, &mut cx_b)] {
call.update(*cx, |call, cx| call.join_channel(channel_id, cx))
.await
.unwrap();
}
// Clients A and B see each other with two different assigned colors. Client C
// still doesn't have a color.
deterministic.run_until_parked();
channel_view_a.update(cx_a, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
assert_remote_selections(
editor,
&[(Some(ParticipantIndex(1)), 1..2), (None, 2..3)],
cx,
);
});
});
channel_view_b.update(cx_b, |notes, cx| {
notes.editor.update(cx, |editor, cx| {
assert_remote_selections(
editor,
&[(Some(ParticipantIndex(0)), 0..1), (None, 2..3)],
cx,
);
});
});
// Client A shares a project, and client B joins.
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
.unwrap();
let project_b = client_b.build_remote_project(project_id, cx_b).await;
let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b);
// Clients A and B open the same file.
let editor_a = workspace_a
.update(cx_a, |workspace, cx| {
workspace.open_path((worktree_id_a, "file.txt"), None, true, cx)
})
.await
.unwrap()
.downcast::<Editor>()
.unwrap();
let editor_b = workspace_b
.update(cx_b, |workspace, cx| {
workspace.open_path((worktree_id_a, "file.txt"), None, true, cx)
})
.await
.unwrap()
.downcast::<Editor>()
.unwrap();
editor_a.update(cx_a, |editor, cx| {
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![0..1]);
});
});
editor_b.update(cx_b, |editor, cx| {
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![2..3]);
});
});
// Client A joins the project
let project_a = client_a.build_remote_project(shared_project_id, cx_a).await;
deterministic.run_until_parked();
// Clients A and B see each other with the same colors as in the channel notes.
editor_a.update(cx_a, |editor, cx| {
assert_remote_selections(editor, &[(Some(ParticipantIndex(1)), 2..3)], cx);
});
editor_b.update(cx_b, |editor, cx| {
assert_remote_selections(editor, &[(Some(ParticipantIndex(0)), 0..1)], cx);
});
}
// Client C is in a separate project.
client_c.fs().insert_tree("/dir", json!({})).await;
let (separate_project_c, _) = client_c.build_local_project("/dir", cx_c).await;
#[track_caller]
fn assert_remote_selections(
editor: &mut Editor,
expected_selections: &[(Option<ParticipantIndex>, Range<usize>)],
cx: &mut ViewContext<Editor>,
) {
let snapshot = editor.snapshot(cx);
let range = Anchor::min()..Anchor::max();
let remote_selections = snapshot
.remote_selections_in_range(&range, editor.collaboration_hub().unwrap(), cx)
.map(|s| {
let start = s.selection.start.to_offset(&snapshot.buffer_snapshot);
let end = s.selection.end.to_offset(&snapshot.buffer_snapshot);
(s.participant_index, start..end)
})
.collect::<Vec<_>>();
assert_eq!(
remote_selections, expected_selections,
"incorrect remote selections"
);
// Note that each user has a different replica id in the projects vs the
// channel buffer.
channel_buffer_a.read_with(cx_a, |channel_buffer, cx| {
assert_eq!(project_a.read(cx).replica_id(), 1);
assert_eq!(channel_buffer.buffer().read(cx).replica_id(), 2);
});
channel_buffer_b.read_with(cx_b, |channel_buffer, cx| {
assert_eq!(project_b.read(cx).replica_id(), 0);
assert_eq!(channel_buffer.buffer().read(cx).replica_id(), 1);
});
channel_buffer_c.read_with(cx_c, |channel_buffer, cx| {
// C is not in the project
assert_eq!(channel_buffer.buffer().read(cx).replica_id(), 0);
});
let channel_window_a =
cx_a.add_window(|cx| ChannelView::new(project_a.clone(), channel_buffer_a.clone(), cx));
let channel_window_b =
cx_b.add_window(|cx| ChannelView::new(project_b.clone(), channel_buffer_b.clone(), cx));
let channel_window_c = cx_c.add_window(|cx| {
ChannelView::new(separate_project_c.clone(), channel_buffer_c.clone(), cx)
});
let channel_view_a = channel_window_a.root(cx_a);
let channel_view_b = channel_window_b.root(cx_b);
let channel_view_c = channel_window_c.root(cx_c);
// For clients A and B, the replica ids in the channel buffer are mapped
// so that they match the same users' replica ids in their shared project.
channel_view_a.read_with(cx_a, |view, cx| {
assert_eq!(
view.editor.read(cx).replica_id_map().unwrap(),
&[(1, 0), (2, 1)].into_iter().collect::<HashMap<_, _>>()
);
});
channel_view_b.read_with(cx_b, |view, cx| {
assert_eq!(
view.editor.read(cx).replica_id_map().unwrap(),
&[(1, 0), (2, 1)].into_iter().collect::<HashMap<u16, u16>>(),
)
});
// Client C only sees themself, as they're not part of any shared project
channel_view_c.read_with(cx_c, |view, cx| {
assert_eq!(
view.editor.read(cx).replica_id_map().unwrap(),
&[(0, 0)].into_iter().collect::<HashMap<u16, u16>>(),
);
});
// Client C joins the project that clients A and B are in.
active_call_c
.update(cx_c, |call, cx| call.join_channel(channel_id, cx))
.await
.unwrap();
let project_c = client_c.build_remote_project(shared_project_id, cx_c).await;
deterministic.run_until_parked();
project_c.read_with(cx_c, |project, _| {
assert_eq!(project.replica_id(), 2);
});
// For clients A and B, client C's replica id in the channel buffer is
// now mapped to their replica id in the shared project.
channel_view_a.read_with(cx_a, |view, cx| {
assert_eq!(
view.editor.read(cx).replica_id_map().unwrap(),
&[(1, 0), (2, 1), (0, 2)]
.into_iter()
.collect::<HashMap<_, _>>()
);
});
channel_view_b.read_with(cx_b, |view, cx| {
assert_eq!(
view.editor.read(cx).replica_id_map().unwrap(),
&[(1, 0), (2, 1), (0, 2)]
.into_iter()
.collect::<HashMap<_, _>>(),
)
});
}
#[gpui::test]
async fn test_multiple_handles_to_channel_buffer(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
) {
async fn test_reopen_channel_buffer(deterministic: Arc<Deterministic>, cx_a: &mut TestAppContext) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
@@ -410,7 +368,10 @@ async fn test_channel_buffer_disconnect(
channel_buffer_a.update(cx_a, |buffer, _| {
assert_eq!(
buffer.channel().as_ref(),
&channel(channel_id, "the-channel")
&Channel {
id: channel_id,
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
});
@@ -435,21 +396,15 @@ async fn test_channel_buffer_disconnect(
channel_buffer_b.update(cx_b, |buffer, _| {
assert_eq!(
buffer.channel().as_ref(),
&channel(channel_id, "the-channel")
&Channel {
id: channel_id,
name: "the-channel".to_string()
}
);
assert!(!buffer.is_connected());
});
}
fn channel(id: u64, name: &'static str) -> Channel {
Channel {
id,
name: name.to_string(),
unseen_note_version: None,
unseen_message_id: None,
}
}
#[gpui::test]
async fn test_rejoin_channel_buffer(
deterministic: Arc<Deterministic>,
@@ -610,284 +565,26 @@ async fn test_channel_buffers_and_server_restarts(
channel_buffer_a.read_with(cx_a, |buffer_a, _| {
channel_buffer_b.read_with(cx_b, |buffer_b, _| {
assert_collaborators(
buffer_a.collaborators(),
&[client_a.user_id(), client_b.user_id()],
assert_eq!(
buffer_a
.collaborators()
.iter()
.map(|c| c.user_id)
.collect::<Vec<_>>(),
vec![client_a.user_id().unwrap(), client_b.user_id().unwrap()]
);
assert_eq!(buffer_a.collaborators(), buffer_b.collaborators());
});
});
}
#[gpui::test(iterations = 10)]
async fn test_following_to_channel_notes_without_a_shared_project(
deterministic: Arc<Deterministic>,
mut cx_a: &mut TestAppContext,
mut cx_b: &mut TestAppContext,
mut cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let client_c = server.create_client(cx_c, "user_c").await;
cx_a.update(editor::init);
cx_b.update(editor::init);
cx_c.update(editor::init);
cx_a.update(collab_ui::channel_view::init);
cx_b.update(collab_ui::channel_view::init);
cx_c.update(collab_ui::channel_view::init);
let channel_1_id = server
.make_channel(
"channel-1",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
let channel_2_id = server
.make_channel(
"channel-2",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
// Clients A, B, and C join a channel.
let active_call_a = cx_a.read(ActiveCall::global);
let active_call_b = cx_b.read(ActiveCall::global);
let active_call_c = cx_c.read(ActiveCall::global);
for (call, cx) in [
(&active_call_a, &mut cx_a),
(&active_call_b, &mut cx_b),
(&active_call_c, &mut cx_c),
] {
call.update(*cx, |call, cx| call.join_channel(channel_1_id, cx))
.await
.unwrap();
}
deterministic.run_until_parked();
// Clients A, B, and C all open their own unshared projects.
client_a.fs().insert_tree("/a", json!({})).await;
client_b.fs().insert_tree("/b", json!({})).await;
client_c.fs().insert_tree("/c", json!({})).await;
let (project_a, _) = client_a.build_local_project("/a", cx_a).await;
let (project_b, _) = client_b.build_local_project("/b", cx_b).await;
let (project_c, _) = client_b.build_local_project("/c", cx_c).await;
let workspace_a = client_a.build_workspace(&project_a, cx_a).root(cx_a);
let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b);
let _workspace_c = client_c.build_workspace(&project_c, cx_c).root(cx_c);
active_call_a
.update(cx_a, |call, cx| call.set_location(Some(&project_a), cx))
.await
.unwrap();
// Client A opens the notes for channel 1.
let channel_view_1_a = cx_a
.update(|cx| ChannelView::open(channel_1_id, workspace_a.clone(), cx))
.await
.unwrap();
channel_view_1_a.update(cx_a, |notes, cx| {
assert_eq!(notes.channel(cx).name, "channel-1");
notes.editor.update(cx, |editor, cx| {
editor.insert("Hello from A.", cx);
editor.change_selections(None, cx, |selections| {
selections.select_ranges(vec![3..4]);
});
});
});
// Client B follows client A.
workspace_b
.update(cx_b, |workspace, cx| {
workspace.follow(client_a.peer_id().unwrap(), cx).unwrap()
})
.await
.unwrap();
// Client B is taken to the notes for channel 1, with the same
// text selected as client A.
deterministic.run_until_parked();
let channel_view_1_b = workspace_b.read_with(cx_b, |workspace, cx| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_a.peer_id().unwrap())
);
workspace
.active_item(cx)
.expect("no active item")
.downcast::<ChannelView>()
.expect("active item is not a channel view")
});
channel_view_1_b.read_with(cx_b, |notes, cx| {
assert_eq!(notes.channel(cx).name, "channel-1");
let editor = notes.editor.read(cx);
assert_eq!(editor.text(cx), "Hello from A.");
assert_eq!(editor.selections.ranges::<usize>(cx), &[3..4]);
});
// Client A opens the notes for channel 2.
let channel_view_2_a = cx_a
.update(|cx| ChannelView::open(channel_2_id, workspace_a.clone(), cx))
.await
.unwrap();
channel_view_2_a.read_with(cx_a, |notes, cx| {
assert_eq!(notes.channel(cx).name, "channel-2");
});
// Client B is taken to the notes for channel 2.
deterministic.run_until_parked();
let channel_view_2_b = workspace_b.read_with(cx_b, |workspace, cx| {
assert_eq!(
workspace.leader_for_pane(workspace.active_pane()),
Some(client_a.peer_id().unwrap())
);
workspace
.active_item(cx)
.expect("no active item")
.downcast::<ChannelView>()
.expect("active item is not a channel view")
});
channel_view_2_b.read_with(cx_b, |notes, cx| {
assert_eq!(notes.channel(cx).name, "channel-2");
});
}
#[gpui::test]
async fn test_channel_buffer_changes(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
let channel_buffer_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_buffer(channel_id, cx))
.await
.unwrap();
// Client A makes an edit, and client B should see that the note has changed.
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "1")], None, cx);
})
});
deterministic.run_until_parked();
let has_buffer_changed = cx_b.read(|cx| {
client_b
.channel_store()
.read(cx)
.has_channel_buffer_changed(channel_id)
.unwrap()
});
assert!(has_buffer_changed);
// Opening the buffer should clear the changed flag.
let project_b = client_b.build_empty_local_project(cx_b);
let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b);
let channel_view_b = cx_b
.update(|cx| ChannelView::open(channel_id, workspace_b.clone(), cx))
.await
.unwrap();
deterministic.run_until_parked();
let has_buffer_changed = cx_b.read(|cx| {
client_b
.channel_store()
.read(cx)
.has_channel_buffer_changed(channel_id)
.unwrap()
});
assert!(!has_buffer_changed);
// Editing the channel while the buffer is open should not show that the buffer has changed.
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "2")], None, cx);
})
});
deterministic.run_until_parked();
let has_buffer_changed = cx_b.read(|cx| {
client_b
.channel_store()
.read(cx)
.has_channel_buffer_changed(channel_id)
.unwrap()
});
assert!(!has_buffer_changed);
deterministic.advance_clock(ACKNOWLEDGE_DEBOUNCE_INTERVAL);
// Test that the server is tracking things correctly, and we retain our 'not changed'
// state across a disconnect
server.simulate_long_connection_interruption(client_b.peer_id().unwrap(), &deterministic);
let has_buffer_changed = cx_b.read(|cx| {
client_b
.channel_store()
.read(cx)
.has_channel_buffer_changed(channel_id)
.unwrap()
});
assert!(!has_buffer_changed);
// Closing the buffer should re-enable change tracking
cx_b.update(|cx| {
workspace_b.update(cx, |workspace, cx| {
workspace.close_all_items_and_panes(&Default::default(), cx)
});
drop(channel_view_b)
});
deterministic.run_until_parked();
channel_buffer_a.update(cx_a, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.edit([(0..0, "3")], None, cx);
})
});
deterministic.run_until_parked();
let has_buffer_changed = cx_b.read(|cx| {
client_b
.channel_store()
.read(cx)
.has_channel_buffer_changed(channel_id)
.unwrap()
});
assert!(has_buffer_changed);
}
#[track_caller]
fn assert_collaborators(collaborators: &HashMap<PeerId, Collaborator>, ids: &[Option<UserId>]) {
let mut user_ids = collaborators
.values()
.map(|collaborator| collaborator.user_id)
.collect::<Vec<_>>();
user_ids.sort();
fn assert_collaborators(collaborators: &[proto::Collaborator], ids: &[Option<UserId>]) {
assert_eq!(
user_ids,
collaborators
.into_iter()
.map(|collaborator| collaborator.user_id)
.collect::<Vec<_>>(),
ids.into_iter().map(|id| id.unwrap()).collect::<Vec<_>>()
);
}

View File

@@ -1,9 +1,7 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use channel::{ChannelChat, ChannelMessageId};
use collab_ui::chat_panel::ChatPanel;
use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext};
use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
use std::sync::Arc;
use workspace::dock::Panel;
#[gpui::test]
async fn test_basic_channel_messages(
@@ -225,136 +223,3 @@ fn assert_messages(chat: &ModelHandle<ChannelChat>, messages: &[&str], cx: &mut
messages
);
}
#[gpui::test]
async fn test_channel_message_changes(
deterministic: Arc<Deterministic>,
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
&mut [(&client_b, cx_b)],
)
.await;
// Client A sends a message, client B should see that there is a new message.
let channel_chat_a = client_a
.channel_store()
.update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
.await
.unwrap();
channel_chat_a
.update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
let b_has_messages = cx_b.read_with(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
.unwrap()
});
assert!(b_has_messages);
// Opening the chat should clear the changed flag.
cx_b.update(|cx| {
collab_ui::init(&client_b.app_state, cx);
});
let project_b = client_b.build_empty_local_project(cx_b);
let workspace_b = client_b.build_workspace(&project_b, cx_b).root(cx_b);
let chat_panel_b = workspace_b.update(cx_b, |workspace, cx| ChatPanel::new(workspace, cx));
chat_panel_b
.update(cx_b, |chat_panel, cx| {
chat_panel.set_active(true, cx);
chat_panel.select_channel(channel_id, cx)
})
.await
.unwrap();
deterministic.run_until_parked();
let b_has_messages = cx_b.read_with(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
.unwrap()
});
assert!(!b_has_messages);
// Sending a message while the chat is open should not change the flag.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
let b_has_messages = cx_b.read_with(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
.unwrap()
});
assert!(!b_has_messages);
// Sending a message while the chat is closed should change the flag.
chat_panel_b.update(cx_b, |chat_panel, cx| {
chat_panel.set_active(false, cx);
});
// Sending a message while the chat is open should not change the flag.
channel_chat_a
.update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
let b_has_messages = cx_b.read_with(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
.unwrap()
});
assert!(b_has_messages);
// Closing the chat should re-enable change tracking
cx_b.update(|_| drop(chat_panel_b));
channel_chat_a
.update(cx_a, |c, cx| c.send_message("four".into(), cx).unwrap())
.await
.unwrap();
deterministic.run_until_parked();
let b_has_messages = cx_b.read_with(|cx| {
client_b
.channel_store()
.read(cx)
.has_new_messages(channel_id)
.unwrap()
});
assert!(b_has_messages);
}

View File

@@ -380,8 +380,6 @@ async fn test_channel_room(
// Give everyone a chance to observe user A joining
deterministic.run_until_parked();
let room_a = active_call_a.read_with(cx_a, |call, _| call.room().unwrap().clone());
room_a.read_with(cx_a, |room, _| assert!(room.is_connected()));
client_a.channel_store().read_with(cx_a, |channels, _| {
assert_participants_eq(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,12 @@ impl RandomizedTest for RandomChannelBufferTest {
let db = &server.app_state.db;
for ix in 0..CHANNEL_COUNT {
let id = db
.create_channel(&format!("channel-{ix}"), None, users[0].user_id)
.create_channel(
&format!("channel-{ix}"),
None,
&format!("livekit-room-{ix}"),
users[0].user_id,
)
.await
.unwrap();
for user in &users[1..] {
@@ -268,7 +273,7 @@ impl RandomizedTest for RandomChannelBufferTest {
// channel buffer.
let collaborators = channel_buffer.collaborators();
let mut user_ids =
collaborators.values().map(|c| c.user_id).collect::<Vec<_>>();
collaborators.iter().map(|c| c.user_id).collect::<Vec<_>>();
user_ids.sort();
assert_eq!(
user_ids,

View File

@@ -1,7 +1,7 @@
use crate::{
db::{tests::TestDb, NewUserParams, UserId},
executor::Executor,
rpc::{Server, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
rpc::{Server, CLEANUP_TIMEOUT},
AppState,
};
use anyhow::anyhow;
@@ -17,7 +17,6 @@ use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHan
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use rpc::RECEIVE_TIMEOUT;
use settings::SettingsStore;
use std::{
cell::{Ref, RefCell, RefMut},
@@ -30,7 +29,7 @@ use std::{
},
};
use util::http::FakeHttpClient;
use workspace::{Workspace, WorkspaceStore};
use workspace::Workspace;
pub struct TestServer {
pub app_state: Arc<AppState>,
@@ -44,7 +43,6 @@ pub struct TestServer {
pub struct TestClient {
pub username: String,
pub app_state: Arc<workspace::AppState>,
channel_store: ModelHandle<ChannelStore>,
state: RefCell<TestClientState>,
}
@@ -153,12 +151,12 @@ impl TestServer {
Arc::get_mut(&mut client)
.unwrap()
.set_id(user_id.to_proto())
.set_id(user_id.0 as usize)
.override_authenticate(move |cx| {
cx.spawn(|_| async move {
let access_token = "the-token".to_string();
Ok(Credentials {
user_id: user_id.to_proto(),
user_id: user_id.0 as u64,
access_token,
})
})
@@ -206,14 +204,13 @@ impl TestServer {
let fs = FakeFs::new(cx.background());
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
let workspace_store = cx.add_model(|cx| WorkspaceStore::new(client.clone(), cx));
let mut language_registry = LanguageRegistry::test();
language_registry.set_executor(cx.background());
let channel_store =
cx.add_model(|cx| ChannelStore::new(client.clone(), user_store.clone(), cx));
let app_state = Arc::new(workspace::AppState {
client: client.clone(),
user_store: user_store.clone(),
workspace_store,
languages: Arc::new(language_registry),
channel_store: channel_store.clone(),
languages: Arc::new(LanguageRegistry::test()),
fs: fs.clone(),
build_window_options: |_, _, _| Default::default(),
initialize_workspace: |_, _, _, _| Task::ready(Ok(())),
@@ -229,7 +226,7 @@ impl TestServer {
workspace::init(app_state.clone(), cx);
audio::init((), cx);
call::init(client.clone(), user_store.clone(), cx);
channel::init(&client, user_store, cx);
channel::init(&client);
});
client
@@ -240,7 +237,6 @@ impl TestServer {
let client = TestClient {
app_state,
username: name.to_string(),
channel_store: cx.read(ChannelStore::global).clone(),
state: Default::default(),
};
client.wait_for_current_user(cx).await;
@@ -255,19 +251,6 @@ impl TestServer {
.store(true, SeqCst);
}
pub fn simulate_long_connection_interruption(
&self,
peer_id: PeerId,
deterministic: &Arc<Deterministic>,
) {
self.forbid_connections();
self.disconnect_client(peer_id);
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
self.allow_connections();
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
deterministic.run_until_parked();
}
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}
@@ -309,9 +292,10 @@ impl TestServer {
admin: (&TestClient, &mut TestAppContext),
members: &mut [(&TestClient, &mut TestAppContext)],
) -> u64 {
let (_, admin_cx) = admin;
let channel_id = admin_cx
.read(ChannelStore::global)
let (admin_client, admin_cx) = admin;
let channel_id = admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.create_channel(channel, parent, cx)
})
@@ -319,8 +303,9 @@ impl TestServer {
.unwrap();
for (member_client, member_cx) in members {
admin_cx
.read(ChannelStore::global)
admin_client
.app_state
.channel_store
.update(admin_cx, |channel_store, cx| {
channel_store.invite_member(
channel_id,
@@ -334,8 +319,9 @@ impl TestServer {
admin_cx.foreground().run_until_parked();
member_cx
.read(ChannelStore::global)
member_client
.app_state
.channel_store
.update(*member_cx, |channels, _| {
channels.respond_to_channel_invite(channel_id, true)
})
@@ -443,7 +429,7 @@ impl TestClient {
}
pub fn channel_store(&self) -> &ModelHandle<ChannelStore> {
&self.channel_store
&self.app_state.channel_store
}
pub fn user_store(&self) -> &ModelHandle<UserStore> {
@@ -550,7 +536,15 @@ impl TestClient {
root_path: impl AsRef<Path>,
cx: &mut TestAppContext,
) -> (ModelHandle<Project>, WorktreeId) {
let project = self.build_empty_local_project(cx);
let project = cx.update(|cx| {
Project::local(
self.client().clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
});
let (worktree, _) = project
.update(cx, |p, cx| {
p.find_or_create_local_worktree(root_path, true, cx)
@@ -563,18 +557,6 @@ impl TestClient {
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub fn build_empty_local_project(&self, cx: &mut TestAppContext) -> ModelHandle<Project> {
cx.update(|cx| {
Project::local(
self.client().clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
})
}
pub async fn build_remote_project(
&self,
host_project_id: u64,
@@ -610,8 +592,8 @@ impl TestClient {
) {
let (other_client, other_cx) = user;
cx_self
.read(ChannelStore::global)
self.app_state
.channel_store
.update(cx_self, |channel_store, cx| {
channel_store.invite_member(channel, other_client.user_id().unwrap(), true, cx)
})
@@ -620,10 +602,11 @@ impl TestClient {
cx_self.foreground().run_until_parked();
other_cx
.read(ChannelStore::global)
.update(other_cx, |channel_store, _| {
channel_store.respond_to_channel_invite(channel, true)
other_client
.app_state
.channel_store
.update(other_cx, |channels, _| {
channels.respond_to_channel_invite(channel, true)
})
.await
.unwrap();

View File

@@ -37,7 +37,6 @@ fuzzy = { path = "../fuzzy" }
gpui = { path = "../gpui" }
language = { path = "../language" }
menu = { path = "../menu" }
rich_text = { path = "../rich_text" }
picker = { path = "../picker" }
project = { path = "../project" }
recent_projects = {path = "../recent_projects"}

View File

@@ -1,12 +1,10 @@
use anyhow::{anyhow, Result};
use call::report_call_event_for_channel;
use channel::{Channel, ChannelBuffer, ChannelBufferEvent, ChannelId, ChannelStore};
use client::{
proto::{self, PeerId},
Collaborator, ParticipantIndex,
};
use channel::{ChannelBuffer, ChannelBufferEvent, ChannelId};
use client::proto;
use clock::ReplicaId;
use collections::HashMap;
use editor::{CollaborationHub, Editor};
use editor::Editor;
use gpui::{
actions,
elements::{ChildView, Label},
@@ -15,41 +13,32 @@ use gpui::{
ViewContext, ViewHandle,
};
use project::Project;
use std::{
any::{Any, TypeId},
sync::Arc,
};
use util::ResultExt;
use std::any::{Any, TypeId};
use workspace::{
item::{FollowableItem, Item, ItemHandle},
register_followable_item,
searchable::SearchableItemHandle,
ItemNavHistory, Pane, SaveIntent, ViewId, Workspace, WorkspaceId,
ItemNavHistory, Pane, ViewId, Workspace, WorkspaceId,
};
actions!(channel_view, [Deploy]);
pub fn init(cx: &mut AppContext) {
pub(crate) fn init(cx: &mut AppContext) {
register_followable_item::<ChannelView>(cx)
}
pub struct ChannelView {
pub editor: ViewHandle<Editor>,
project: ModelHandle<Project>,
channel_store: ModelHandle<ChannelStore>,
channel_buffer: ModelHandle<ChannelBuffer>,
remote_id: Option<ViewId>,
_editor_event_subscription: Subscription,
}
impl ChannelView {
pub fn open(
channel_id: ChannelId,
workspace: ViewHandle<Workspace>,
cx: &mut AppContext,
) -> Task<Result<ViewHandle<Self>>> {
pub fn deploy(channel_id: ChannelId, workspace: ViewHandle<Workspace>, cx: &mut AppContext) {
let pane = workspace.read(cx).active_pane().clone();
let channel_view = Self::open_in_pane(channel_id, pane.clone(), workspace.clone(), cx);
let channel_view = Self::open(channel_id, pane.clone(), workspace.clone(), cx);
cx.spawn(|mut cx| async move {
let channel_view = channel_view.await?;
pane.update(&mut cx, |pane, cx| {
@@ -59,13 +48,14 @@ impl ChannelView {
&workspace.read(cx).app_state().client,
cx,
);
pane.add_item(Box::new(channel_view.clone()), true, true, None, cx);
pane.add_item(Box::new(channel_view), true, true, None, cx);
});
anyhow::Ok(channel_view)
anyhow::Ok(())
})
.detach();
}
pub fn open_in_pane(
pub fn open(
channel_id: ChannelId,
pane: ViewHandle<Pane>,
workspace: ViewHandle<Workspace>,
@@ -73,7 +63,7 @@ impl ChannelView {
) -> Task<Result<ViewHandle<Self>>> {
let workspace = workspace.read(cx);
let project = workspace.project().to_owned();
let channel_store = ChannelStore::global(cx);
let channel_store = workspace.app_state().channel_store.clone();
let markdown = workspace
.app_state()
.languages
@@ -84,45 +74,17 @@ impl ChannelView {
cx.spawn(|mut cx| async move {
let channel_buffer = channel_buffer.await?;
if let Some(markdown) = markdown.await.log_err() {
channel_buffer.update(&mut cx, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx);
})
});
}
let markdown = markdown.await?;
channel_buffer.update(&mut cx, |buffer, cx| {
buffer.buffer().update(cx, |buffer, cx| {
buffer.set_language(Some(markdown), cx);
})
});
pane.update(&mut cx, |pane, cx| {
let buffer_id = channel_buffer.read(cx).remote_id(cx);
let existing_view = pane
.items_of_type::<Self>()
.find(|view| view.read(cx).channel_buffer.read(cx).remote_id(cx) == buffer_id);
// If this channel buffer is already open in this pane, just return it.
if let Some(existing_view) = existing_view.clone() {
if existing_view.read(cx).channel_buffer == channel_buffer {
return existing_view;
}
}
let view = cx.add_view(|cx| {
let mut this = Self::new(project, channel_store, channel_buffer, cx);
this.acknowledge_buffer_version(cx);
this
});
// If the pane contained a disconnected view for this channel buffer,
// replace that.
if let Some(existing_item) = existing_view {
if let Some(ix) = pane.index_for_item(&existing_item) {
pane.close_item_by_id(existing_item.id(), SaveIntent::Skip, cx)
.detach();
pane.add_item(Box::new(view.clone()), true, true, Some(ix), cx);
}
}
view
pane.items_of_type::<Self>()
.find(|channel_view| channel_view.read(cx).channel_buffer == channel_buffer)
.unwrap_or_else(|| cx.add_view(|cx| Self::new(project, channel_buffer, cx)))
})
.ok_or_else(|| anyhow!("pane was dropped"))
})
@@ -130,35 +92,44 @@ impl ChannelView {
pub fn new(
project: ModelHandle<Project>,
channel_store: ModelHandle<ChannelStore>,
channel_buffer: ModelHandle<ChannelBuffer>,
cx: &mut ViewContext<Self>,
) -> Self {
let buffer = channel_buffer.read(cx).buffer();
let editor = cx.add_view(|cx| {
let mut editor = Editor::for_buffer(buffer, None, cx);
editor.set_collaboration_hub(Box::new(ChannelBufferCollaborationHub(
channel_buffer.clone(),
)));
editor
});
let editor = cx.add_view(|cx| Editor::for_buffer(buffer, None, cx));
let _editor_event_subscription = cx.subscribe(&editor, |_, _, e, cx| cx.emit(e.clone()));
cx.subscribe(&project, Self::handle_project_event).detach();
cx.subscribe(&channel_buffer, Self::handle_channel_buffer_event)
.detach();
Self {
let this = Self {
editor,
project,
channel_store,
channel_buffer,
remote_id: None,
_editor_event_subscription,
}
};
this.refresh_replica_id_map(cx);
this
}
pub fn channel(&self, cx: &AppContext) -> Arc<Channel> {
self.channel_buffer.read(cx).channel()
fn handle_project_event(
&mut self,
_: ModelHandle<Project>,
event: &project::Event,
cx: &mut ViewContext<Self>,
) {
match event {
project::Event::RemoteIdChanged(_) => {}
project::Event::DisconnectedFromHost => {}
project::Event::Closed => {}
project::Event::CollaboratorUpdated { .. } => {}
project::Event::CollaboratorLeft(_) => {}
project::Event::CollaboratorJoined(_) => {}
_ => return,
}
self.refresh_replica_id_map(cx);
}
fn handle_channel_buffer_event(
@@ -168,41 +139,48 @@ impl ChannelView {
cx: &mut ViewContext<Self>,
) {
match event {
ChannelBufferEvent::CollaboratorsChanged => {
self.refresh_replica_id_map(cx);
}
ChannelBufferEvent::Disconnected => self.editor.update(cx, |editor, cx| {
editor.set_read_only(true);
cx.notify();
}),
ChannelBufferEvent::BufferEdited => {
if cx.is_self_focused() || self.editor.is_focused(cx) {
self.acknowledge_buffer_version(cx);
} else {
self.channel_store.update(cx, |store, cx| {
let channel_buffer = self.channel_buffer.read(cx);
store.notes_changed(
channel_buffer.channel().id,
channel_buffer.epoch(),
&channel_buffer.buffer().read(cx).version(),
cx,
)
});
}
}
_ => {}
}
}
fn acknowledge_buffer_version(&mut self, cx: &mut ViewContext<'_, '_, ChannelView>) {
self.channel_store.update(cx, |store, cx| {
let channel_buffer = self.channel_buffer.read(cx);
store.acknowledge_notes_version(
channel_buffer.channel().id,
channel_buffer.epoch(),
&channel_buffer.buffer().read(cx).version(),
cx,
)
});
self.channel_buffer.update(cx, |buffer, cx| {
buffer.acknowledge_buffer_version(cx);
/// Build a mapping of channel buffer replica ids to the corresponding
/// replica ids in the current project.
///
/// Using this mapping, a given user can be displayed with the same color
/// in the channel buffer as in other files in the project. Users who are
/// in the channel buffer but not the project will not have a color.
fn refresh_replica_id_map(&self, cx: &mut ViewContext<Self>) {
let mut project_replica_ids_by_channel_buffer_replica_id = HashMap::default();
let project = self.project.read(cx);
let channel_buffer = self.channel_buffer.read(cx);
project_replica_ids_by_channel_buffer_replica_id
.insert(channel_buffer.replica_id(cx), project.replica_id());
project_replica_ids_by_channel_buffer_replica_id.extend(
channel_buffer
.collaborators()
.iter()
.filter_map(|channel_buffer_collaborator| {
project
.collaborators()
.values()
.find_map(|project_collaborator| {
(project_collaborator.user_id == channel_buffer_collaborator.user_id)
.then_some((
channel_buffer_collaborator.replica_id as ReplicaId,
project_collaborator.replica_id,
))
})
}),
);
self.editor.update(cx, |editor, cx| {
editor.set_replica_id_map(Some(project_replica_ids_by_channel_buffer_replica_id), cx)
});
}
}
@@ -222,7 +200,6 @@ impl View for ChannelView {
fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext<Self>) {
if cx.is_self_focused() {
self.acknowledge_buffer_version(cx);
cx.focus(self.editor.as_any())
}
}
@@ -262,7 +239,6 @@ impl Item for ChannelView {
fn clone_on_split(&self, _: WorkspaceId, cx: &mut ViewContext<Self>) -> Option<Self> {
Some(Self::new(
self.project.clone(),
self.channel_store.clone(),
self.channel_buffer.clone(),
cx,
))
@@ -306,14 +282,10 @@ impl FollowableItem for ChannelView {
}
fn to_state_proto(&self, cx: &AppContext) -> Option<proto::view::Variant> {
let channel_buffer = self.channel_buffer.read(cx);
if !channel_buffer.is_connected() {
return None;
}
let channel = self.channel_buffer.read(cx).channel();
Some(proto::view::Variant::ChannelView(
proto::view::ChannelView {
channel_id: channel_buffer.channel().id,
channel_id: channel.id,
editor: if let Some(proto::view::Variant::Editor(proto)) =
self.editor.read(cx).to_state_proto(cx)
{
@@ -339,7 +311,7 @@ impl FollowableItem for ChannelView {
unreachable!()
};
let open = ChannelView::open_in_pane(state.channel_id, pane, workspace, cx);
let open = ChannelView::open(state.channel_id, pane, workspace, cx);
Some(cx.spawn(|mut cx| async move {
let this = open.await?;
@@ -399,32 +371,17 @@ impl FollowableItem for ChannelView {
})
}
fn set_leader_peer_id(&mut self, leader_peer_id: Option<PeerId>, cx: &mut ViewContext<Self>) {
fn set_leader_replica_id(
&mut self,
leader_replica_id: Option<u16>,
cx: &mut ViewContext<Self>,
) {
self.editor.update(cx, |editor, cx| {
editor.set_leader_peer_id(leader_peer_id, cx)
editor.set_leader_replica_id(leader_replica_id, cx)
})
}
fn should_unfollow_on_event(event: &Self::Event, cx: &AppContext) -> bool {
Editor::should_unfollow_on_event(event, cx)
}
fn is_project_item(&self, _cx: &AppContext) -> bool {
false
}
}
struct ChannelBufferCollaborationHub(ModelHandle<ChannelBuffer>);
impl CollaborationHub for ChannelBufferCollaborationHub {
fn collaborators<'a>(&self, cx: &'a AppContext) -> &'a HashMap<PeerId, Collaborator> {
self.0.read(cx).collaborators()
}
fn user_participant_indices<'a>(
&self,
cx: &'a AppContext,
) -> &'a HashMap<u64, ParticipantIndex> {
self.0.read(cx).user_store().read(cx).participant_indices()
}
}

View File

@@ -3,7 +3,6 @@ use anyhow::Result;
use call::ActiveCall;
use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore};
use client::Client;
use collections::HashMap;
use db::kvp::KEY_VALUE_STORE;
use editor::Editor;
use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
@@ -13,13 +12,12 @@ use gpui::{
platform::{CursorStyle, MouseButton},
serde_json,
views::{ItemType, Select, SelectStyle},
AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task,
View, ViewContext, ViewHandle, WeakViewHandle,
AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View,
ViewContext, ViewHandle, WeakViewHandle,
};
use language::{language_settings::SoftWrap, LanguageRegistry};
use language::language_settings::SoftWrap;
use menu::Confirm;
use project::Fs;
use rich_text::RichText;
use serde::{Deserialize, Serialize};
use settings::SettingsStore;
use std::sync::Arc;
@@ -37,7 +35,6 @@ const CHAT_PANEL_KEY: &'static str = "ChatPanel";
pub struct ChatPanel {
client: Arc<Client>,
channel_store: ModelHandle<ChannelStore>,
languages: Arc<LanguageRegistry>,
active_chat: Option<(ModelHandle<ChannelChat>, Subscription)>,
message_list: ListState<ChatPanel>,
input_editor: ViewHandle<Editor>,
@@ -45,12 +42,10 @@ pub struct ChatPanel {
local_timezone: UtcOffset,
fs: Arc<dyn Fs>,
width: Option<f32>,
active: bool,
pending_serialization: Task<Option<()>>,
subscriptions: Vec<gpui::Subscription>,
workspace: WeakViewHandle<Workspace>,
has_focus: bool,
markdown_data: HashMap<ChannelMessageId, RichText>,
}
#[derive(Serialize, Deserialize)]
@@ -81,8 +76,7 @@ impl ChatPanel {
pub fn new(workspace: &mut Workspace, cx: &mut ViewContext<Workspace>) -> ViewHandle<Self> {
let fs = workspace.app_state().fs.clone();
let client = workspace.app_state().client.clone();
let channel_store = ChannelStore::global(cx);
let languages = workspace.app_state().languages.clone();
let channel_store = workspace.app_state().channel_store.clone();
let input_editor = cx.add_view(|cx| {
let mut editor = Editor::auto_height(
@@ -135,8 +129,6 @@ impl ChatPanel {
fs,
client,
channel_store,
languages,
active_chat: Default::default(),
pending_serialization: Task::ready(None),
message_list,
@@ -146,9 +138,7 @@ impl ChatPanel {
has_focus: false,
subscriptions: Vec::new(),
workspace: workspace_handle,
active: false,
width: None,
markdown_data: Default::default(),
};
let mut old_dock_position = this.position(cx);
@@ -164,9 +154,9 @@ impl ChatPanel {
}),
);
this.update_channel_count(cx);
this.init_active_channel(cx);
cx.observe(&this.channel_store, |this, _, cx| {
this.update_channel_count(cx)
this.init_active_channel(cx);
})
.detach();
@@ -185,33 +175,10 @@ impl ChatPanel {
})
.detach();
let markdown = this.languages.language_for_name("Markdown");
cx.spawn(|this, mut cx| async move {
let markdown = markdown.await?;
this.update(&mut cx, |this, cx| {
this.input_editor.update(cx, |editor, cx| {
editor.buffer().update(cx, |multi_buffer, cx| {
multi_buffer
.as_singleton()
.unwrap()
.update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx))
})
})
})?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
this
})
}
pub fn active_chat(&self) -> Option<ModelHandle<ChannelChat>> {
self.active_chat.as_ref().map(|(chat, _)| chat.clone())
}
pub fn load(
workspace: WeakViewHandle<Workspace>,
cx: AsyncAppContext,
@@ -258,8 +225,10 @@ impl ChatPanel {
);
}
fn update_channel_count(&mut self, cx: &mut ViewContext<Self>) {
fn init_active_channel(&mut self, cx: &mut ViewContext<Self>) {
let channel_count = self.channel_store.read(cx).channel_count();
self.message_list.reset(0);
self.active_chat = None;
self.channel_select.update(cx, |select, cx| {
select.set_item_count(channel_count, cx);
});
@@ -278,7 +247,6 @@ impl ChatPanel {
}
let subscription = cx.subscribe(&chat, Self::channel_did_change);
self.active_chat = Some((chat, subscription));
self.acknowledge_last_message(cx);
self.channel_select.update(cx, |select, cx| {
if let Some(ix) = self.channel_store.read(cx).index_of_channel(id) {
select.set_selected_index(ix, cx);
@@ -300,34 +268,11 @@ impl ChatPanel {
new_count,
} => {
self.message_list.splice(old_range.clone(), *new_count);
if self.active {
self.acknowledge_last_message(cx);
}
}
ChannelChatEvent::NewMessage {
channel_id,
message_id,
} => {
if !self.active {
self.channel_store.update(cx, |store, cx| {
store.new_message(*channel_id, *message_id, cx)
})
}
}
}
cx.notify();
}
fn acknowledge_last_message(&mut self, cx: &mut ViewContext<'_, '_, ChatPanel>) {
if self.active {
if let Some((chat, _)) = &self.active_chat {
chat.update(cx, |chat, cx| {
chat.acknowledge_last_message(cx);
});
}
}
}
fn render_channel(&self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let theme = theme::current(cx);
Flex::column()
@@ -354,33 +299,13 @@ impl ChatPanel {
messages.flex(1., true).into_any()
}
fn render_message(&mut self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let (message, is_continuation, is_last) = {
let active_chat = self.active_chat.as_ref().unwrap().0.read(cx);
let last_message = active_chat.message(ix.saturating_sub(1));
let this_message = active_chat.message(ix);
let is_continuation = last_message.id != this_message.id
&& this_message.sender.id == last_message.sender.id;
(
active_chat.message(ix).clone(),
is_continuation,
active_chat.message_count() == ix + 1,
)
};
let is_pending = message.is_pending();
let text = self
.markdown_data
.entry(message.id)
.or_insert_with(|| rich_text::render_markdown(message.body, &self.languages, None));
fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let message = self.active_chat.as_ref().unwrap().0.read(cx).message(ix);
let now = OffsetDateTime::now_utc();
let theme = theme::current(cx);
let style = if is_pending {
let style = if message.is_pending() {
&theme.chat_panel.pending_message
} else if is_continuation {
&theme.chat_panel.continuation_message
} else {
&theme.chat_panel.message
};
@@ -393,90 +318,52 @@ impl ChatPanel {
None
};
enum MessageBackgroundHighlight {}
MouseEventHandler::new::<MessageBackgroundHighlight, _>(ix, cx, |state, cx| {
let container = style.container.style_for(state);
if is_continuation {
enum DeleteMessage {}
let body = message.body.clone();
Flex::column()
.with_child(
Flex::row()
.with_child(
text.element(
theme.editor.syntax.clone(),
style.body.clone(),
theme.editor.document_highlight_read_background,
cx,
Label::new(
message.sender.github_login.clone(),
style.sender.text.clone(),
)
.flex(1., true),
)
.with_child(render_remove(message_id_to_remove, cx, &theme))
.contained()
.with_style(*container)
.with_margin_bottom(if is_last {
theme.chat_panel.last_message_bottom_spacing
} else {
0.
})
.into_any()
} else {
Flex::column()
.with_child(
Flex::row()
.with_child(
Flex::row()
.with_child(render_avatar(
message.sender.avatar.clone(),
&theme,
))
.with_child(
Label::new(
message.sender.github_login.clone(),
style.sender.text.clone(),
)
.contained()
.with_style(style.sender.container),
)
.with_child(
Label::new(
format_timestamp(
message.timestamp,
now,
self.local_timezone,
),
style.timestamp.text.clone(),
)
.contained()
.with_style(style.timestamp.container),
)
.align_children_center()
.flex(1., true),
)
.with_child(render_remove(message_id_to_remove, cx, &theme))
.align_children_center(),
.contained()
.with_style(style.sender.container),
)
.with_child(
Flex::row()
.with_child(
text.element(
theme.editor.syntax.clone(),
style.body.clone(),
theme.editor.document_highlight_read_background,
cx,
)
.flex(1., true),
)
// Add a spacer to make everything line up
.with_child(render_remove(None, cx, &theme)),
Label::new(
format_timestamp(message.timestamp, now, self.local_timezone),
style.timestamp.text.clone(),
)
.contained()
.with_style(style.timestamp.container),
)
.contained()
.with_style(*container)
.with_margin_bottom(if is_last {
theme.chat_panel.last_message_bottom_spacing
} else {
0.
})
.into_any()
}
})
.into_any()
.with_children(message_id_to_remove.map(|id| {
MouseEventHandler::new::<DeleteMessage, _>(
id as usize,
cx,
|mouse_state, _| {
let button_style =
theme.chat_panel.icon_button.style_for(mouse_state);
render_icon_button(button_style, "icons/x.svg")
.aligned()
.into_any()
},
)
.with_padding(Padding::uniform(2.))
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
this.remove_message(id, cx);
})
.flex_float()
})),
)
.with_child(Text::new(body, style.body.clone()))
.contained()
.with_style(style.container)
.into_any()
}
fn render_input_box(&self, theme: &Arc<Theme>, cx: &AppContext) -> AnyElement<Self> {
@@ -522,7 +409,7 @@ impl ChatPanel {
})
.on_click(MouseButton::Left, move |_, _, cx| {
if let Some(workspace) = workspace.upgrade(cx) {
ChannelView::open(channel_id, workspace, cx).detach();
ChannelView::deploy(channel_id, workspace, cx);
}
})
.with_tooltip::<OpenChannelNotes>(
@@ -650,7 +537,6 @@ impl ChatPanel {
cx.spawn(|this, mut cx| async move {
let chat = open_chat.await?;
this.update(&mut cx, |this, cx| {
this.markdown_data = Default::default();
this.set_active_chat(chat, cx);
})
})
@@ -660,7 +546,7 @@ impl ChatPanel {
if let Some((chat, _)) = &self.active_chat {
let channel_id = chat.read(cx).channel().id;
if let Some(workspace) = self.workspace.upgrade(cx) {
ChannelView::open(channel_id, workspace, cx).detach();
ChannelView::deploy(channel_id, workspace, cx);
}
}
}
@@ -675,72 +561,6 @@ impl ChatPanel {
}
}
fn render_avatar(avatar: Option<Arc<ImageData>>, theme: &Arc<Theme>) -> AnyElement<ChatPanel> {
let avatar_style = theme.chat_panel.avatar;
avatar
.map(|avatar| {
Image::from_data(avatar)
.with_style(avatar_style.image)
.aligned()
.contained()
.with_corner_radius(avatar_style.outer_corner_radius)
.constrained()
.with_width(avatar_style.outer_width)
.with_height(avatar_style.outer_width)
.into_any()
})
.unwrap_or_else(|| {
Empty::new()
.constrained()
.with_width(avatar_style.outer_width)
.into_any()
})
.contained()
.with_style(theme.chat_panel.avatar_container)
.into_any()
}
fn render_remove(
message_id_to_remove: Option<u64>,
cx: &mut ViewContext<'_, '_, ChatPanel>,
theme: &Arc<Theme>,
) -> AnyElement<ChatPanel> {
enum DeleteMessage {}
message_id_to_remove
.map(|id| {
MouseEventHandler::new::<DeleteMessage, _>(id as usize, cx, |mouse_state, _| {
let button_style = theme.chat_panel.icon_button.style_for(mouse_state);
render_icon_button(button_style, "icons/x.svg")
.aligned()
.into_any()
})
.with_padding(Padding::uniform(2.))
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
this.remove_message(id, cx);
})
.flex_float()
.into_any()
})
.unwrap_or_else(|| {
let style = theme.chat_panel.icon_button.default;
Empty::new()
.constrained()
.with_width(style.icon_width)
.aligned()
.constrained()
.with_width(style.button_width)
.with_height(style.button_width)
.contained()
.with_uniform_padding(2.)
.flex_float()
.into_any()
})
}
impl Entity for ChatPanel {
type Event = Event;
}
@@ -807,12 +627,8 @@ impl Panel for ChatPanel {
}
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
self.active = active;
if active {
self.acknowledge_last_message(cx);
if !is_chat_feature_enabled(cx) {
cx.emit(Event::Dismissed);
}
if active && !is_chat_feature_enabled(cx) {
cx.emit(Event::Dismissed);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -215,13 +215,7 @@ impl CollabTitlebarItem {
let git_style = theme.titlebar.git_menu_button.clone();
let item_spacing = theme.titlebar.item_spacing;
let mut ret = Flex::row();
if let Some(project_host) = self.collect_project_host(theme.clone(), cx) {
ret = ret.with_child(project_host)
}
ret = ret.with_child(
let mut ret = Flex::row().with_child(
Stack::new()
.with_child(
MouseEventHandler::new::<ToggleProjectMenu, _>(0, cx, |mouse_state, cx| {
@@ -289,71 +283,6 @@ impl CollabTitlebarItem {
ret.into_any()
}
fn collect_project_host(
&self,
theme: Arc<Theme>,
cx: &mut ViewContext<Self>,
) -> Option<AnyElement<Self>> {
if ActiveCall::global(cx).read(cx).room().is_none() {
return None;
}
let project = self.project.read(cx);
let user_store = self.user_store.read(cx);
if project.is_local() {
return None;
}
let Some(host) = project.host() else {
return None;
};
let (Some(host_user), Some(participant_index)) = (
user_store.get_cached_user(host.user_id),
user_store.participant_indices().get(&host.user_id),
) else {
return None;
};
enum ProjectHost {}
enum ProjectHostTooltip {}
let host_style = theme.titlebar.project_host.clone();
let selection_style = theme
.editor
.selection_style_for_room_participant(participant_index.0);
let peer_id = host.peer_id.clone();
Some(
MouseEventHandler::new::<ProjectHost, _>(0, cx, |mouse_state, _| {
let mut host_style = host_style.style_for(mouse_state).clone();
host_style.text.color = selection_style.cursor;
Label::new(host_user.github_login.clone(), host_style.text)
.contained()
.with_style(host_style.container)
.aligned()
.left()
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
if let Some(workspace) = this.workspace.upgrade(cx) {
if let Some(task) =
workspace.update(cx, |workspace, cx| workspace.follow(peer_id, cx))
{
task.detach_and_log_err(cx);
}
}
})
.with_tooltip::<ProjectHostTooltip>(
0,
host_user.github_login.clone() + " is sharing this project. Click to follow.",
None,
theme.tooltip.clone(),
cx,
)
.into_any_named("project-host"),
)
}
fn window_activation_changed(&mut self, active: bool, cx: &mut ViewContext<Self>) {
let project = if active {
Some(self.project.clone())
@@ -948,7 +877,7 @@ impl CollabTitlebarItem {
fn render_face_pile(
&self,
user: &User,
_replica_id: Option<ReplicaId>,
replica_id: Option<ReplicaId>,
peer_id: PeerId,
location: Option<ParticipantLocation>,
muted: bool,
@@ -957,20 +886,23 @@ impl CollabTitlebarItem {
theme: &Theme,
cx: &mut ViewContext<Self>,
) -> AnyElement<Self> {
let user_id = user.id;
let project_id = workspace.read(cx).project().read(cx).remote_id();
let room = ActiveCall::global(cx).read(cx).room().cloned();
let self_peer_id = workspace.read(cx).client().peer_id();
let self_following = workspace.read(cx).is_being_followed(peer_id);
let self_following_initialized = self_following
&& room.as_ref().map_or(false, |room| match project_id {
None => true,
Some(project_id) => room
.read(cx)
.followers_for(peer_id, project_id)
.iter()
.any(|&follower| Some(follower) == self_peer_id),
});
let room = ActiveCall::global(cx).read(cx).room();
let is_being_followed = workspace.read(cx).is_being_followed(peer_id);
let followed_by_self = room
.and_then(|room| {
Some(
is_being_followed
&& room
.read(cx)
.followers_for(peer_id, project_id?)
.iter()
.any(|&follower| {
Some(follower) == workspace.read(cx).client().peer_id()
}),
)
})
.unwrap_or(false);
let leader_style = theme.titlebar.leader_avatar;
let follower_style = theme.titlebar.follower_avatar;
@@ -989,131 +921,147 @@ impl CollabTitlebarItem {
.background_color
.unwrap_or_default();
let participant_index = self
.user_store
.read(cx)
.participant_indices()
.get(&user_id)
.copied();
if let Some(participant_index) = participant_index {
if self_following_initialized {
let selection = theme
.editor
.selection_style_for_room_participant(participant_index.0)
.selection;
if let Some(replica_id) = replica_id {
if followed_by_self {
let selection = theme.editor.replica_selection_style(replica_id).selection;
background_color = Color::blend(selection, background_color);
background_color.a = 255;
}
}
enum TitlebarParticipant {}
let mut content = Stack::new()
.with_children(user.avatar.as_ref().map(|avatar| {
let face_pile = FacePile::new(theme.titlebar.follower_avatar_overlap)
.with_child(Self::render_face(
avatar.clone(),
Self::location_style(workspace, location, leader_style, cx),
background_color,
microphone_state,
))
.with_children(
(|| {
let project_id = project_id?;
let room = room?.read(cx);
let followers = room.followers_for(peer_id, project_id);
let content = MouseEventHandler::new::<TitlebarParticipant, _>(
peer_id.as_u64() as usize,
cx,
move |_, cx| {
Stack::new()
.with_children(user.avatar.as_ref().map(|avatar| {
let face_pile = FacePile::new(theme.titlebar.follower_avatar_overlap)
.with_child(Self::render_face(
avatar.clone(),
Self::location_style(workspace, location, leader_style, cx),
background_color,
microphone_state,
))
.with_children(
(|| {
let project_id = project_id?;
let room = room?.read(cx);
let followers = room.followers_for(peer_id, project_id);
Some(followers.into_iter().filter_map(|&follower| {
if Some(follower) == self_peer_id {
return None;
Some(followers.into_iter().flat_map(|&follower| {
let remote_participant =
room.remote_participant_for_peer_id(follower);
let avatar = remote_participant
.and_then(|p| p.user.avatar.clone())
.or_else(|| {
if follower == workspace.read(cx).client().peer_id()? {
workspace
.read(cx)
.user_store()
.read(cx)
.current_user()?
.avatar
.clone()
} else {
None
}
let participant =
room.remote_participant_for_peer_id(follower)?;
Some(Self::render_face(
participant.user.avatar.clone()?,
follower_style,
background_color,
None,
))
}))
})()
.into_iter()
.flatten(),
)
.with_children(
self_following_initialized
.then(|| self.user_store.read(cx).current_user())
.and_then(|user| {
Some(Self::render_face(
user?.avatar.clone()?,
follower_style,
background_color,
None,
))
}),
);
})?;
let mut container = face_pile
.contained()
.with_style(theme.titlebar.leader_selection);
Some(Self::render_face(
avatar.clone(),
follower_style,
background_color,
None,
))
}))
})()
.into_iter()
.flatten(),
);
if let Some(participant_index) = participant_index {
if self_following_initialized {
let color = theme
.editor
.selection_style_for_room_participant(participant_index.0)
.selection;
container = container.with_background_color(color);
}
}
let mut container = face_pile
.contained()
.with_style(theme.titlebar.leader_selection);
container
}))
.with_children((|| {
let participant_index = participant_index?;
let color = theme
.editor
.selection_style_for_room_participant(participant_index.0)
.cursor;
Some(
AvatarRibbon::new(color)
.constrained()
.with_width(theme.titlebar.avatar_ribbon.width)
.with_height(theme.titlebar.avatar_ribbon.height)
.aligned()
.bottom(),
)
})())
},
);
if Some(peer_id) == self_peer_id {
return content.into_any();
}
content
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
let Some(workspace) = this.workspace.upgrade(cx) else {
return;
};
if let Some(task) =
workspace.update(cx, |workspace, cx| workspace.follow(peer_id, cx))
{
task.detach_and_log_err(cx);
if let Some(replica_id) = replica_id {
if followed_by_self {
let color = theme.editor.replica_selection_style(replica_id).selection;
container = container.with_background_color(color);
}
}
})
.with_tooltip::<TitlebarParticipant>(
peer_id.as_u64() as usize,
format!("Follow {}", user.github_login),
Some(Box::new(FollowNextCollaborator)),
theme.tooltip.clone(),
cx,
)
.into_any()
container
}))
.with_children((|| {
let replica_id = replica_id?;
let color = theme.editor.replica_selection_style(replica_id).cursor;
Some(
AvatarRibbon::new(color)
.constrained()
.with_width(theme.titlebar.avatar_ribbon.width)
.with_height(theme.titlebar.avatar_ribbon.height)
.aligned()
.bottom(),
)
})())
.into_any();
if let Some(location) = location {
if let Some(replica_id) = replica_id {
enum ToggleFollow {}
content = MouseEventHandler::new::<ToggleFollow, _>(
replica_id.into(),
cx,
move |_, _| content,
)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, item, cx| {
if let Some(workspace) = item.workspace.upgrade(cx) {
if let Some(task) = workspace
.update(cx, |workspace, cx| workspace.toggle_follow(peer_id, cx))
{
task.detach_and_log_err(cx);
}
}
})
.with_tooltip::<ToggleFollow>(
peer_id.as_u64() as usize,
if is_being_followed {
format!("Unfollow {}", user.github_login)
} else {
format!("Follow {}", user.github_login)
},
Some(Box::new(FollowNextCollaborator)),
theme.tooltip.clone(),
cx,
)
.into_any();
} else if let ParticipantLocation::SharedProject { project_id } = location {
enum JoinProject {}
let user_id = user.id;
content = MouseEventHandler::new::<JoinProject, _>(
peer_id.as_u64() as usize,
cx,
move |_, _| content,
)
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, this, cx| {
if let Some(workspace) = this.workspace.upgrade(cx) {
let app_state = workspace.read(cx).app_state().clone();
workspace::join_remote_project(project_id, user_id, app_state, cx)
.detach_and_log_err(cx);
}
})
.with_tooltip::<JoinProject>(
peer_id.as_u64() as usize,
format!("Follow {} into external project", user.github_login),
Some(Box::new(FollowNextCollaborator)),
theme.tooltip.clone(),
cx,
)
.into_any();
}
}
content
}
fn location_style(

View File

@@ -7,7 +7,7 @@ mod face_pile;
mod incoming_call_notification;
mod notifications;
mod panel_settings;
pub mod project_shared_notification;
mod project_shared_notification;
mod sharing_status_indicator;
use call::{report_call_event_for_room, ActiveCall, Room};

View File

@@ -40,9 +40,7 @@ pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) {
.push(window);
}
}
room::Event::RemoteProjectUnshared { project_id }
| room::Event::RemoteProjectJoined { project_id }
| room::Event::RemoteProjectInvitationDiscarded { project_id } => {
room::Event::RemoteProjectUnshared { project_id } => {
if let Some(windows) = notification_windows.remove(&project_id) {
for window in windows {
window.remove(cx);
@@ -84,6 +82,7 @@ impl ProjectSharedNotification {
}
fn join(&mut self, cx: &mut ViewContext<Self>) {
cx.remove_window();
if let Some(app_state) = self.app_state.upgrade() {
workspace::join_remote_project(self.project_id, self.owner.id, app_state, cx)
.detach_and_log_err(cx);
@@ -91,15 +90,7 @@ impl ProjectSharedNotification {
}
fn dismiss(&mut self, cx: &mut ViewContext<Self>) {
if let Some(active_room) =
ActiveCall::global(cx).read_with(cx, |call, _| call.room().cloned())
{
active_room.update(cx, |_, cx| {
cx.emit(room::Event::RemoteProjectInvitationDiscarded {
project_id: self.project_id,
});
});
}
cx.remove_window();
}
fn render_owner(&self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {

View File

@@ -265,7 +265,7 @@ impl PickerDelegate for CommandPaletteDelegate {
.with_children(
[
(keystroke.ctrl, "^"),
(keystroke.alt, ""),
(keystroke.alt, ""),
(keystroke.cmd, ""),
(keystroke.shift, ""),
]

View File

@@ -21,9 +21,6 @@ util = { path = "../util" }
workspace = { path = "../workspace" }
anyhow.workspace = true
schemars.workspace = true
serde.workspace = true
serde_derive.workspace = true
smallvec.workspace = true
postage.workspace = true

View File

@@ -1,6 +1,4 @@
pub mod items;
mod project_diagnostics_settings;
mod toolbar_controls;
use anyhow::Result;
use collections::{BTreeSet, HashSet};
@@ -21,7 +19,6 @@ use language::{
};
use lsp::LanguageServerId;
use project::{DiagnosticSummary, Project, ProjectPath};
use project_diagnostics_settings::ProjectDiagnosticsSettings;
use serde_json::json;
use smallvec::SmallVec;
use std::{
@@ -33,21 +30,18 @@ use std::{
sync::Arc,
};
use theme::ThemeSettings;
pub use toolbar_controls::ToolbarControls;
use util::TryFutureExt;
use workspace::{
item::{BreadcrumbText, Item, ItemEvent, ItemHandle},
ItemNavHistory, Pane, PaneBackdrop, ToolbarItemLocation, Workspace,
};
actions!(diagnostics, [Deploy, ToggleWarnings]);
actions!(diagnostics, [Deploy]);
const CONTEXT_LINE_COUNT: u32 = 1;
pub fn init(cx: &mut AppContext) {
settings::register::<ProjectDiagnosticsSettings>(cx);
cx.add_action(ProjectDiagnosticsEditor::deploy);
cx.add_action(ProjectDiagnosticsEditor::toggle_warnings);
items::init(cx);
}
@@ -61,7 +55,6 @@ struct ProjectDiagnosticsEditor {
excerpts: ModelHandle<MultiBuffer>,
path_states: Vec<PathState>,
paths_to_update: BTreeSet<(ProjectPath, LanguageServerId)>,
include_warnings: bool,
}
struct PathState {
@@ -194,7 +187,6 @@ impl ProjectDiagnosticsEditor {
editor,
path_states: Default::default(),
paths_to_update,
include_warnings: settings::get::<ProjectDiagnosticsSettings>(cx).include_warnings,
};
this.update_excerpts(None, cx);
this
@@ -212,18 +204,6 @@ impl ProjectDiagnosticsEditor {
}
}
fn toggle_warnings(&mut self, _: &ToggleWarnings, cx: &mut ViewContext<Self>) {
self.include_warnings = !self.include_warnings;
self.paths_to_update = self
.project
.read(cx)
.diagnostic_summaries(cx)
.map(|(path, server_id, _)| (path, server_id))
.collect();
self.update_excerpts(None, cx);
cx.notify();
}
fn update_excerpts(
&mut self,
language_server_id: Option<LanguageServerId>,
@@ -297,18 +277,14 @@ impl ProjectDiagnosticsEditor {
let mut blocks_to_add = Vec::new();
let mut blocks_to_remove = HashSet::default();
let mut first_excerpt_id = None;
let max_severity = if self.include_warnings {
DiagnosticSeverity::WARNING
} else {
DiagnosticSeverity::ERROR
};
let excerpts_snapshot = self.excerpts.update(cx, |excerpts, excerpts_cx| {
let mut old_groups = path_state.diagnostic_groups.iter().enumerate().peekable();
let mut new_groups = snapshot
.diagnostic_groups(language_server_id)
.into_iter()
.filter(|(_, group)| {
group.entries[group.primary_ix].diagnostic.severity <= max_severity
group.entries[group.primary_ix].diagnostic.severity
<= DiagnosticSeverity::WARNING
})
.peekable();
loop {
@@ -1525,7 +1501,6 @@ mod tests {
client::init_settings(cx);
workspace::init_settings(cx);
Project::init_settings(cx);
crate::init(cx);
});
}

View File

@@ -38,10 +38,6 @@ impl DiagnosticIndicator {
this.in_progress_checks.remove(language_server_id);
cx.notify();
}
project::Event::DiagnosticsUpdated { .. } => {
this.summary = project.read(cx).diagnostic_summary(cx);
cx.notify();
}
_ => {}
})
.detach();

View File

@@ -1,28 +0,0 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Debug)]
pub struct ProjectDiagnosticsSettings {
pub include_warnings: bool,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct ProjectDiagnosticsSettingsContent {
include_warnings: Option<bool>,
}
impl settings::Setting for ProjectDiagnosticsSettings {
const KEY: Option<&'static str> = Some("diagnostics");
type FileContent = ProjectDiagnosticsSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_cx: &gpui::AppContext,
) -> anyhow::Result<Self>
where
Self: Sized,
{
Self::load_via_json_merge(default_value, user_values)
}
}

View File

@@ -1,115 +0,0 @@
use crate::{ProjectDiagnosticsEditor, ToggleWarnings};
use gpui::{
elements::*,
platform::{CursorStyle, MouseButton},
Action, Entity, EventContext, View, ViewContext, WeakViewHandle,
};
use workspace::{item::ItemHandle, ToolbarItemLocation, ToolbarItemView};
pub struct ToolbarControls {
editor: Option<WeakViewHandle<ProjectDiagnosticsEditor>>,
}
impl Entity for ToolbarControls {
type Event = ();
}
impl View for ToolbarControls {
fn ui_name() -> &'static str {
"ToolbarControls"
}
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let include_warnings = self
.editor
.as_ref()
.and_then(|editor| editor.upgrade(cx))
.map(|editor| editor.read(cx).include_warnings)
.unwrap_or(false);
let tooltip = if include_warnings {
"Exclude Warnings".into()
} else {
"Include Warnings".into()
};
Flex::row()
.with_child(render_toggle_button(
0,
"icons/warning.svg",
include_warnings,
(tooltip, Some(Box::new(ToggleWarnings))),
cx,
move |this, cx| {
if let Some(editor) = this.editor.and_then(|editor| editor.upgrade(cx)) {
editor.update(cx, |editor, cx| {
editor.toggle_warnings(&Default::default(), cx)
});
}
},
))
.into_any()
}
}
impl ToolbarItemView for ToolbarControls {
fn set_active_pane_item(
&mut self,
active_pane_item: Option<&dyn ItemHandle>,
_: &mut ViewContext<Self>,
) -> ToolbarItemLocation {
if let Some(pane_item) = active_pane_item.as_ref() {
if let Some(editor) = pane_item.downcast::<ProjectDiagnosticsEditor>() {
self.editor = Some(editor.downgrade());
ToolbarItemLocation::PrimaryRight { flex: None }
} else {
ToolbarItemLocation::Hidden
}
} else {
ToolbarItemLocation::Hidden
}
}
}
impl ToolbarControls {
pub fn new() -> Self {
ToolbarControls { editor: None }
}
}
fn render_toggle_button<
F: 'static + Fn(&mut ToolbarControls, &mut EventContext<ToolbarControls>),
>(
index: usize,
icon: &'static str,
toggled: bool,
tooltip: (String, Option<Box<dyn Action>>),
cx: &mut ViewContext<ToolbarControls>,
on_click: F,
) -> AnyElement<ToolbarControls> {
enum Button {}
let theme = theme::current(cx);
let (tooltip_text, action) = tooltip;
MouseEventHandler::new::<Button, _>(index, cx, |mouse_state, _| {
let style = theme
.workspace
.toolbar
.toggleable_tool
.in_state(toggled)
.style_for(mouse_state);
Svg::new(icon)
.with_color(style.color)
.constrained()
.with_width(style.icon_width)
.aligned()
.constrained()
.with_width(style.button_width)
.with_height(style.button_width)
.contained()
.with_style(style.container)
})
.with_cursor_style(CursorStyle::PointingHand)
.on_click(MouseButton::Left, move |_, view, cx| on_click(view, cx))
.with_tooltip::<Button>(index, tooltip_text, action, theme.tooltip.clone(), cx)
.into_any_named("quick action bar button")
}

View File

@@ -36,7 +36,6 @@ language = { path = "../language" }
lsp = { path = "../lsp" }
project = { path = "../project" }
rpc = { path = "../rpc" }
rich_text = { path = "../rich_text" }
settings = { path = "../settings" }
snippet = { path = "../snippet" }
sum_tree = { path = "../sum_tree" }

View File

@@ -25,7 +25,7 @@ use ::git::diff::DiffHunk;
use aho_corasick::AhoCorasick;
use anyhow::{anyhow, Context, Result};
use blink_manager::BlinkManager;
use client::{ClickhouseEvent, Collaborator, ParticipantIndex, TelemetrySettings};
use client::{ClickhouseEvent, TelemetrySettings};
use clock::{Global, ReplicaId};
use collections::{BTreeMap, Bound, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
@@ -79,7 +79,6 @@ pub use multi_buffer::{
use ordered_float::OrderedFloat;
use project::{FormatTrigger, Location, Project, ProjectPath, ProjectTransaction};
use rand::{seq::SliceRandom, thread_rng};
use rpc::proto::PeerId;
use scroll::{
autoscroll::Autoscroll, OngoingScroll, ScrollAnchor, ScrollManager, ScrollbarAutoHide,
};
@@ -582,11 +581,11 @@ pub struct Editor {
get_field_editor_theme: Option<Arc<GetFieldEditorTheme>>,
override_text_style: Option<Box<OverrideTextStyle>>,
project: Option<ModelHandle<Project>>,
collaboration_hub: Option<Box<dyn CollaborationHub>>,
focused: bool,
blink_manager: ModelHandle<BlinkManager>,
pub show_local_selections: bool,
mode: EditorMode,
replica_id_mapping: Option<HashMap<ReplicaId, ReplicaId>>,
show_gutter: bool,
show_wrap_guides: Option<bool>,
placeholder_text: Option<Arc<str>>,
@@ -610,7 +609,7 @@ pub struct Editor {
keymap_context_layers: BTreeMap<TypeId, KeymapContext>,
input_enabled: bool,
read_only: bool,
leader_peer_id: Option<PeerId>,
leader_replica_id: Option<u16>,
remote_id: Option<ViewId>,
hover_state: HoverState,
gutter_hovered: bool,
@@ -632,15 +631,6 @@ pub struct EditorSnapshot {
ongoing_scroll: OngoingScroll,
}
pub struct RemoteSelection {
pub replica_id: ReplicaId,
pub selection: Selection<Anchor>,
pub cursor_shape: CursorShape,
pub peer_id: PeerId,
pub line_mode: bool,
pub participant_index: Option<ParticipantIndex>,
}
#[derive(Clone, Debug)]
struct SelectionHistoryEntry {
selections: Arc<[Selection<Anchor>]>,
@@ -1057,8 +1047,7 @@ impl CompletionsMenu {
item_ix: Some(item_ix),
},
cx,
)
.map(|task| task.detach());
);
})
.into_any(),
);
@@ -1550,12 +1539,12 @@ impl Editor {
active_diagnostics: None,
soft_wrap_mode_override,
get_field_editor_theme,
collaboration_hub: project.clone().map(|project| Box::new(project) as _),
project,
focused: false,
blink_manager: blink_manager.clone(),
show_local_selections: true,
mode,
replica_id_mapping: None,
show_gutter: mode == EditorMode::Full,
show_wrap_guides: None,
placeholder_text: None,
@@ -1582,7 +1571,7 @@ impl Editor {
keymap_context_layers: Default::default(),
input_enabled: true,
read_only: false,
leader_peer_id: None,
leader_replica_id: None,
remote_id: None,
hover_state: Default::default(),
link_go_to_definition_state: Default::default(),
@@ -1669,8 +1658,8 @@ impl Editor {
self.buffer.read(cx).replica_id()
}
pub fn leader_peer_id(&self) -> Option<PeerId> {
self.leader_peer_id
pub fn leader_replica_id(&self) -> Option<ReplicaId> {
self.leader_replica_id
}
pub fn buffer(&self) -> &ModelHandle<MultiBuffer> {
@@ -1734,14 +1723,6 @@ impl Editor {
self.mode
}
pub fn collaboration_hub(&self) -> Option<&dyn CollaborationHub> {
self.collaboration_hub.as_deref()
}
pub fn set_collaboration_hub(&mut self, hub: Box<dyn CollaborationHub>) {
self.collaboration_hub = Some(hub);
}
pub fn set_placeholder_text(
&mut self,
placeholder_text: impl Into<Arc<str>>,
@@ -1818,13 +1799,26 @@ impl Editor {
cx.notify();
}
pub fn replica_id_map(&self) -> Option<&HashMap<ReplicaId, ReplicaId>> {
self.replica_id_mapping.as_ref()
}
pub fn set_replica_id_map(
&mut self,
mapping: Option<HashMap<ReplicaId, ReplicaId>>,
cx: &mut ViewContext<Self>,
) {
self.replica_id_mapping = mapping;
cx.notify();
}
fn selections_did_change(
&mut self,
local: bool,
old_cursor_position: &Anchor,
cx: &mut ViewContext<Self>,
) {
if self.focused && self.leader_peer_id.is_none() {
if self.focused && self.leader_replica_id.is_none() {
self.buffer.update(cx, |buffer, cx| {
buffer.set_active_selections(
&self.selections.disjoint_anchors(),
@@ -2454,13 +2448,7 @@ impl Editor {
let snapshot = this.buffer.read(cx).read(cx);
let new_selections = resolve_multiple::<usize, _>(new_anchor_selections, &snapshot)
.zip(new_selection_deltas)
.map(|(selection, delta)| Selection {
id: selection.id,
start: selection.start + delta,
end: selection.end + delta,
reversed: selection.reversed,
goal: SelectionGoal::None,
})
.map(|(selection, delta)| selection.map(|e| e + delta))
.collect::<Vec<_>>();
let mut i = 0;
@@ -4665,13 +4653,7 @@ impl Editor {
}
pub fn convert_to_title_case(&mut self, _: &ConvertToTitleCase, cx: &mut ViewContext<Self>) {
self.manipulate_text(cx, |text| {
// Hack to get around the fact that to_case crate doesn't support '\n' as a word boundary
// https://github.com/rutrum/convert-case/issues/16
text.split("\n")
.map(|line| line.to_case(Case::Title))
.join("\n")
})
self.manipulate_text(cx, |text| text.to_case(Case::Title))
}
pub fn convert_to_snake_case(&mut self, _: &ConvertToSnakeCase, cx: &mut ViewContext<Self>) {
@@ -4687,13 +4669,7 @@ impl Editor {
_: &ConvertToUpperCamelCase,
cx: &mut ViewContext<Self>,
) {
self.manipulate_text(cx, |text| {
// Hack to get around the fact that to_case crate doesn't support '\n' as a word boundary
// https://github.com/rutrum/convert-case/issues/16
text.split("\n")
.map(|line| line.to_case(Case::UpperCamel))
.join("\n")
})
self.manipulate_text(cx, |text| text.to_case(Case::UpperCamel))
}
pub fn convert_to_lower_camel_case(
@@ -8637,27 +8613,6 @@ impl Editor {
}
}
pub trait CollaborationHub {
fn collaborators<'a>(&self, cx: &'a AppContext) -> &'a HashMap<PeerId, Collaborator>;
fn user_participant_indices<'a>(
&self,
cx: &'a AppContext,
) -> &'a HashMap<u64, ParticipantIndex>;
}
impl CollaborationHub for ModelHandle<Project> {
fn collaborators<'a>(&self, cx: &'a AppContext) -> &'a HashMap<PeerId, Collaborator> {
self.read(cx).collaborators()
}
fn user_participant_indices<'a>(
&self,
cx: &'a AppContext,
) -> &'a HashMap<u64, ParticipantIndex> {
self.read(cx).user_store().read(cx).participant_indices()
}
}
fn inlay_hint_settings(
location: Anchor,
snapshot: &MultiBufferSnapshot,
@@ -8701,34 +8656,6 @@ fn ending_row(next_selection: &Selection<Point>, display_map: &DisplaySnapshot)
}
impl EditorSnapshot {
pub fn remote_selections_in_range<'a>(
&'a self,
range: &'a Range<Anchor>,
collaboration_hub: &dyn CollaborationHub,
cx: &'a AppContext,
) -> impl 'a + Iterator<Item = RemoteSelection> {
let participant_indices = collaboration_hub.user_participant_indices(cx);
let collaborators_by_peer_id = collaboration_hub.collaborators(cx);
let collaborators_by_replica_id = collaborators_by_peer_id
.iter()
.map(|(_, collaborator)| (collaborator.replica_id, collaborator))
.collect::<HashMap<_, _>>();
self.buffer_snapshot
.remote_selections_in_range(range)
.filter_map(move |(replica_id, line_mode, cursor_shape, selection)| {
let collaborator = collaborators_by_replica_id.get(&replica_id)?;
let participant_index = participant_indices.get(&collaborator.user_id).copied();
Some(RemoteSelection {
replica_id,
selection,
cursor_shape,
line_mode,
participant_index,
peer_id: collaborator.peer_id,
})
})
}
pub fn language_at<T: ToOffset>(&self, position: T) -> Option<&Arc<Language>> {
self.display_snapshot.buffer_snapshot.language_at(position)
}
@@ -8842,7 +8769,7 @@ impl View for Editor {
self.focused = true;
self.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx);
if self.leader_peer_id.is_none() {
if self.leader_replica_id.is_none() {
buffer.set_active_selections(
&self.selections.disjoint_anchors(),
self.selections.line_mode,

View File

@@ -1333,7 +1333,7 @@ async fn test_move_start_of_paragraph_end_of_paragraph(cx: &mut gpui::TestAppCon
cx.update_editor(|editor, cx| editor.move_to_end_of_paragraph(&MoveToEndOfParagraph, cx));
cx.assert_editor_state(
&r#"one
&r#"ˇone
two
three
@@ -1344,22 +1344,9 @@ async fn test_move_start_of_paragraph_end_of_paragraph(cx: &mut gpui::TestAppCon
.unindent(),
);
cx.update_editor(|editor, cx| editor.move_to_start_of_paragraph(&MoveToStartOfParagraph, cx));
cx.update_editor(|editor, cx| editor.move_to_end_of_paragraph(&MoveToEndOfParagraph, cx));
cx.assert_editor_state(
&r#"one
two
three
four
five
ˇ
six"#
.unindent(),
);
cx.update_editor(|editor, cx| editor.move_to_start_of_paragraph(&MoveToStartOfParagraph, cx));
cx.assert_editor_state(
&r#"one
&r#"ˇone
two
ˇ
three
@@ -1379,6 +1366,32 @@ async fn test_move_start_of_paragraph_end_of_paragraph(cx: &mut gpui::TestAppCon
four
five
sixˇ"#
.unindent(),
);
cx.update_editor(|editor, cx| editor.move_to_start_of_paragraph(&MoveToStartOfParagraph, cx));
cx.assert_editor_state(
&r#"one
two
three
four
five
ˇ
sixˇ"#
.unindent(),
);
cx.update_editor(|editor, cx| editor.move_to_start_of_paragraph(&MoveToStartOfParagraph, cx));
cx.assert_editor_state(
&r#"one
two
ˇ
three
four
five
ˇ
six"#
.unindent(),
);
@@ -2779,34 +2792,6 @@ async fn test_manipulate_text(cx: &mut TestAppContext) {
«hello worldˇ»
"});
// Test multiple line, single selection case
// Test code hack that covers the fact that to_case crate doesn't support '\n' as a word boundary
cx.set_state(indoc! {"
«The quick brown
fox jumps over
the lazy dogˇ»
"});
cx.update_editor(|e, cx| e.convert_to_title_case(&ConvertToTitleCase, cx));
cx.assert_editor_state(indoc! {"
«The Quick Brown
Fox Jumps Over
The Lazy Dogˇ»
"});
// Test multiple line, single selection case
// Test code hack that covers the fact that to_case crate doesn't support '\n' as a word boundary
cx.set_state(indoc! {"
«The quick brown
fox jumps over
the lazy dogˇ»
"});
cx.update_editor(|e, cx| e.convert_to_upper_camel_case(&ConvertToUpperCamelCase, cx));
cx.assert_editor_state(indoc! {"
«TheQuickBrown
FoxJumpsOver
TheLazyDogˇ»
"});
// From here on out, test more complex cases of manipulate_text()
// Test no selection case - should affect words cursors are in

View File

@@ -17,6 +17,7 @@ use crate::{
},
mouse_context_menu, EditorSettings, EditorStyle, GutterHover, UnfoldAt,
};
use clock::ReplicaId;
use collections::{BTreeMap, HashMap};
use git::diff::DiffHunkStatus;
use gpui::{
@@ -54,7 +55,6 @@ use std::{
sync::Arc,
};
use text::Point;
use theme::SelectionStyle;
use workspace::item::Item;
enum FoldMarkers {}
@@ -868,7 +868,14 @@ impl EditorElement {
let corner_radius = 0.15 * layout.position_map.line_height;
let mut invisible_display_ranges = SmallVec::<[Range<DisplayPoint>; 32]>::new();
for (selection_style, selections) in &layout.selections {
for (replica_id, selections) in &layout.selections {
let replica_id = *replica_id;
let selection_style = if let Some(replica_id) = replica_id {
style.replica_selection_style(replica_id)
} else {
&style.absent_selection
};
for selection in selections {
self.paint_highlighted_range(
selection.range.clone(),
@@ -2186,7 +2193,7 @@ impl Element<Editor> for EditorElement {
.anchor_before(DisplayPoint::new(end_row, 0).to_offset(&snapshot, Bias::Right))
};
let mut selections: Vec<(SelectionStyle, Vec<SelectionLayout>)> = Vec::new();
let mut selections: Vec<(Option<ReplicaId>, Vec<SelectionLayout>)> = Vec::new();
let mut active_rows = BTreeMap::new();
let mut fold_ranges = Vec::new();
let is_singleton = editor.is_singleton(cx);
@@ -2212,6 +2219,35 @@ impl Element<Editor> for EditorElement {
}),
);
let mut remote_selections = HashMap::default();
for (replica_id, line_mode, cursor_shape, selection) in snapshot
.buffer_snapshot
.remote_selections_in_range(&(start_anchor..end_anchor))
{
let replica_id = if let Some(mapping) = &editor.replica_id_mapping {
mapping.get(&replica_id).copied()
} else {
Some(replica_id)
};
// The local selections match the leader's selections.
if replica_id.is_some() && replica_id == editor.leader_replica_id {
continue;
}
remote_selections
.entry(replica_id)
.or_insert(Vec::new())
.push(SelectionLayout::new(
selection,
line_mode,
cursor_shape,
&snapshot.display_snapshot,
false,
false,
));
}
selections.extend(remote_selections);
let mut newest_selection_head = None;
if editor.show_local_selections {
@@ -2246,58 +2282,19 @@ impl Element<Editor> for EditorElement {
layouts.push(layout);
}
selections.push((style.selection, layouts));
}
if let Some(collaboration_hub) = &editor.collaboration_hub {
// When following someone, render the local selections in their color.
if let Some(leader_id) = editor.leader_peer_id {
if let Some(collaborator) = collaboration_hub.collaborators(cx).get(&leader_id) {
if let Some(participant_index) = collaboration_hub
.user_participant_indices(cx)
.get(&collaborator.user_id)
{
if let Some((local_selection_style, _)) = selections.first_mut() {
*local_selection_style =
style.selection_style_for_room_participant(participant_index.0);
}
}
}
}
let mut remote_selections = HashMap::default();
for selection in snapshot.remote_selections_in_range(
&(start_anchor..end_anchor),
collaboration_hub.as_ref(),
cx,
) {
let selection_style = if let Some(participant_index) = selection.participant_index {
style.selection_style_for_room_participant(participant_index.0)
// Render the local selections in the leader's color when following.
let local_replica_id = if let Some(leader_replica_id) = editor.leader_replica_id {
leader_replica_id
} else {
let replica_id = editor.replica_id(cx);
if let Some(mapping) = &editor.replica_id_mapping {
mapping.get(&replica_id).copied().unwrap_or(replica_id)
} else {
style.absent_selection
};
// Don't re-render the leader's selections, since the local selections
// match theirs.
if Some(selection.peer_id) == editor.leader_peer_id {
continue;
replica_id
}
};
remote_selections
.entry(selection.replica_id)
.or_insert((selection_style, Vec::new()))
.1
.push(SelectionLayout::new(
selection.selection,
selection.line_mode,
selection.cursor_shape,
&snapshot.display_snapshot,
false,
false,
));
}
selections.extend(remote_selections.into_values());
selections.push((Some(local_replica_id), layouts));
}
let scrollbar_settings = &settings::get::<EditorSettings>(cx).scrollbar;
@@ -2689,7 +2686,7 @@ pub struct LayoutState {
blocks: Vec<BlockLayout>,
highlighted_ranges: Vec<(Range<DisplayPoint>, Color)>,
fold_ranges: Vec<(BufferRow, Range<DisplayPoint>, Color)>,
selections: Vec<(SelectionStyle, Vec<SelectionLayout>)>,
selections: Vec<(Option<ReplicaId>, Vec<SelectionLayout>)>,
scrollbar_row_range: Range<f32>,
show_scrollbars: bool,
is_singleton: bool,

View File

@@ -8,12 +8,12 @@ use futures::FutureExt;
use gpui::{
actions,
elements::{Flex, MouseEventHandler, Padding, ParentElement, Text},
fonts::{HighlightStyle, Underline, Weight},
platform::{CursorStyle, MouseButton},
AnyElement, AppContext, Element, ModelHandle, Task, ViewContext,
AnyElement, AppContext, CursorRegion, Element, ModelHandle, MouseRegion, Task, ViewContext,
};
use language::{Bias, DiagnosticEntry, DiagnosticSeverity, Language, LanguageRegistry};
use project::{HoverBlock, HoverBlockKind, InlayHintLabelPart, Project};
use rich_text::{new_paragraph, render_code, render_markdown_mut, RichText};
use std::{ops::Range, sync::Arc, time::Duration};
use util::TryFutureExt;
@@ -346,25 +346,158 @@ fn show_hover(
}
fn render_blocks(
theme_id: usize,
blocks: &[HoverBlock],
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
) -> RichText {
let mut data = RichText {
text: Default::default(),
highlights: Default::default(),
region_ranges: Default::default(),
regions: Default::default(),
};
style: &EditorStyle,
) -> RenderedInfo {
let mut text = String::new();
let mut highlights = Vec::new();
let mut region_ranges = Vec::new();
let mut regions = Vec::new();
for block in blocks {
match &block.kind {
HoverBlockKind::PlainText => {
new_paragraph(&mut data.text, &mut Vec::new());
data.text.push_str(&block.text);
new_paragraph(&mut text, &mut Vec::new());
text.push_str(&block.text);
}
HoverBlockKind::Markdown => {
render_markdown_mut(&block.text, language_registry, language, &mut data)
use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag};
let mut bold_depth = 0;
let mut italic_depth = 0;
let mut link_url = None;
let mut current_language = None;
let mut list_stack = Vec::new();
for event in Parser::new_ext(&block.text, Options::all()) {
let prev_len = text.len();
match event {
Event::Text(t) => {
if let Some(language) = &current_language {
render_code(
&mut text,
&mut highlights,
t.as_ref(),
language,
style,
);
} else {
text.push_str(t.as_ref());
let mut style = HighlightStyle::default();
if bold_depth > 0 {
style.weight = Some(Weight::BOLD);
}
if italic_depth > 0 {
style.italic = Some(true);
}
if let Some(link_url) = link_url.clone() {
region_ranges.push(prev_len..text.len());
regions.push(RenderedRegion {
link_url: Some(link_url),
code: false,
});
style.underline = Some(Underline {
thickness: 1.0.into(),
..Default::default()
});
}
if style != HighlightStyle::default() {
let mut new_highlight = true;
if let Some((last_range, last_style)) = highlights.last_mut() {
if last_range.end == prev_len && last_style == &style {
last_range.end = text.len();
new_highlight = false;
}
}
if new_highlight {
highlights.push((prev_len..text.len(), style));
}
}
}
}
Event::Code(t) => {
text.push_str(t.as_ref());
region_ranges.push(prev_len..text.len());
if link_url.is_some() {
highlights.push((
prev_len..text.len(),
HighlightStyle {
underline: Some(Underline {
thickness: 1.0.into(),
..Default::default()
}),
..Default::default()
},
));
}
regions.push(RenderedRegion {
code: true,
link_url: link_url.clone(),
});
}
Event::Start(tag) => match tag {
Tag::Paragraph => new_paragraph(&mut text, &mut list_stack),
Tag::Heading(_, _, _) => {
new_paragraph(&mut text, &mut list_stack);
bold_depth += 1;
}
Tag::CodeBlock(kind) => {
new_paragraph(&mut text, &mut list_stack);
current_language = if let CodeBlockKind::Fenced(language) = kind {
language_registry
.language_for_name(language.as_ref())
.now_or_never()
.and_then(Result::ok)
} else {
language.cloned()
}
}
Tag::Emphasis => italic_depth += 1,
Tag::Strong => bold_depth += 1,
Tag::Link(_, url, _) => link_url = Some(url.to_string()),
Tag::List(number) => {
list_stack.push((number, false));
}
Tag::Item => {
let len = list_stack.len();
if let Some((list_number, has_content)) = list_stack.last_mut() {
*has_content = false;
if !text.is_empty() && !text.ends_with('\n') {
text.push('\n');
}
for _ in 0..len - 1 {
text.push_str(" ");
}
if let Some(number) = list_number {
text.push_str(&format!("{}. ", number));
*number += 1;
*has_content = false;
} else {
text.push_str("- ");
}
}
}
_ => {}
},
Event::End(tag) => match tag {
Tag::Heading(_, _, _) => bold_depth -= 1,
Tag::CodeBlock(_) => current_language = None,
Tag::Emphasis => italic_depth -= 1,
Tag::Strong => bold_depth -= 1,
Tag::Link(_, _, _) => link_url = None,
Tag::List(_) => drop(list_stack.pop()),
_ => {}
},
Event::HardBreak => text.push('\n'),
Event::SoftBreak => text.push(' '),
_ => {}
}
}
}
HoverBlockKind::Code { language } => {
if let Some(language) = language_registry
@@ -372,17 +505,62 @@ fn render_blocks(
.now_or_never()
.and_then(Result::ok)
{
render_code(&mut data.text, &mut data.highlights, &block.text, &language);
render_code(&mut text, &mut highlights, &block.text, &language, style);
} else {
data.text.push_str(&block.text);
text.push_str(&block.text);
}
}
}
}
data.text = data.text.trim().to_string();
RenderedInfo {
theme_id,
text: text.trim().to_string(),
highlights,
region_ranges,
regions,
}
}
data
fn render_code(
text: &mut String,
highlights: &mut Vec<(Range<usize>, HighlightStyle)>,
content: &str,
language: &Arc<Language>,
style: &EditorStyle,
) {
let prev_len = text.len();
text.push_str(content);
for (range, highlight_id) in language.highlight_text(&content.into(), 0..content.len()) {
if let Some(style) = highlight_id.style(&style.syntax) {
highlights.push((prev_len + range.start..prev_len + range.end, style));
}
}
}
fn new_paragraph(text: &mut String, list_stack: &mut Vec<(Option<u64>, bool)>) {
let mut is_subsequent_paragraph_of_list = false;
if let Some((_, has_content)) = list_stack.last_mut() {
if *has_content {
is_subsequent_paragraph_of_list = true;
} else {
*has_content = true;
return;
}
}
if !text.is_empty() {
if !text.ends_with('\n') {
text.push('\n');
}
text.push('\n');
}
for _ in 0..list_stack.len().saturating_sub(1) {
text.push_str(" ");
}
if is_subsequent_paragraph_of_list {
text.push_str(" ");
}
}
#[derive(Default)]
@@ -445,7 +623,22 @@ pub struct InfoPopover {
symbol_range: RangeInEditor,
pub blocks: Vec<HoverBlock>,
language: Option<Arc<Language>>,
rendered_content: Option<RichText>,
rendered_content: Option<RenderedInfo>,
}
#[derive(Debug, Clone)]
struct RenderedInfo {
theme_id: usize,
text: String,
highlights: Vec<(Range<usize>, HighlightStyle)>,
region_ranges: Vec<Range<usize>>,
regions: Vec<RenderedRegion>,
}
#[derive(Debug, Clone)]
struct RenderedRegion {
code: bool,
link_url: Option<String>,
}
impl InfoPopover {
@@ -454,24 +647,63 @@ impl InfoPopover {
style: &EditorStyle,
cx: &mut ViewContext<Editor>,
) -> AnyElement<Editor> {
if let Some(rendered) = &self.rendered_content {
if rendered.theme_id != style.theme_id {
self.rendered_content = None;
}
}
let rendered_content = self.rendered_content.get_or_insert_with(|| {
render_blocks(
style.theme_id,
&self.blocks,
self.project.read(cx).languages(),
self.language.as_ref(),
style,
)
});
MouseEventHandler::new::<InfoPopover, _>(0, cx, move |_, cx| {
MouseEventHandler::new::<InfoPopover, _>(0, cx, |_, cx| {
let mut region_id = 0;
let view_id = cx.view_id();
let code_span_background_color = style.document_highlight_read_background;
let regions = rendered_content.regions.clone();
Flex::column()
.scrollable::<HoverBlock>(1, None, cx)
.with_child(rendered_content.element(
style.syntax.clone(),
style.text.clone(),
code_span_background_color,
cx,
))
.with_child(
Text::new(rendered_content.text.clone(), style.text.clone())
.with_highlights(rendered_content.highlights.clone())
.with_custom_runs(
rendered_content.region_ranges.clone(),
move |ix, bounds, cx| {
region_id += 1;
let region = regions[ix].clone();
if let Some(url) = region.link_url {
cx.scene().push_cursor_region(CursorRegion {
bounds,
style: CursorStyle::PointingHand,
});
cx.scene().push_mouse_region(
MouseRegion::new::<Self>(view_id, region_id, bounds)
.on_click::<Editor, _>(
MouseButton::Left,
move |_, _, cx| cx.platform().open_url(&url),
),
);
}
if region.code {
cx.scene().push_quad(gpui::Quad {
bounds,
background: Some(code_span_background_color),
border: Default::default(),
corner_radii: (2.0).into(),
});
}
},
)
.with_soft_wrap(true),
)
.contained()
.with_style(style.hover_popover.container)
})
@@ -567,12 +799,11 @@ mod tests {
InlayId,
};
use collections::BTreeSet;
use gpui::fonts::{HighlightStyle, Underline, Weight};
use gpui::fonts::Weight;
use indoc::indoc;
use language::{language_settings::InlayHintSettings, Diagnostic, DiagnosticSet};
use lsp::LanguageServerId;
use project::{HoverBlock, HoverBlockKind};
use rich_text::Highlight;
use smol::stream::StreamExt;
use unindent::Unindent;
use util::test::marked_text_ranges;
@@ -783,7 +1014,7 @@ mod tests {
.await;
cx.condition(|editor, _| editor.hover_state.visible()).await;
cx.editor(|editor, _| {
cx.editor(|editor, cx| {
let blocks = editor.hover_state.info_popover.clone().unwrap().blocks;
assert_eq!(
blocks,
@@ -793,7 +1024,8 @@ mod tests {
}],
);
let rendered = render_blocks(&blocks, &Default::default(), None);
let style = editor.style(cx);
let rendered = render_blocks(0, &blocks, &Default::default(), None, &style);
assert_eq!(
rendered.text,
code_str.trim(),
@@ -985,7 +1217,7 @@ mod tests {
expected_styles,
} in &rows[0..]
{
let rendered = render_blocks(&blocks, &Default::default(), None);
let rendered = render_blocks(0, &blocks, &Default::default(), None, &style);
let (expected_text, ranges) = marked_text_ranges(expected_marked_text, false);
let expected_highlights = ranges
@@ -996,21 +1228,8 @@ mod tests {
rendered.text, expected_text,
"wrong text for input {blocks:?}"
);
let rendered_highlights: Vec<_> = rendered
.highlights
.iter()
.filter_map(|(range, highlight)| {
let style = match highlight {
Highlight::Id(id) => id.style(&style.syntax)?,
Highlight::Highlight(style) => style.clone(),
};
Some((range.clone(), style))
})
.collect();
assert_eq!(
rendered_highlights, expected_highlights,
rendered.highlights, expected_highlights,
"wrong highlights for input {blocks:?}"
);
}

View File

@@ -17,7 +17,7 @@ use language::{
SelectionGoal,
};
use project::{search::SearchQuery, FormatTrigger, Item as _, Project, ProjectPath};
use rpc::proto::{self, update_view, PeerId};
use rpc::proto::{self, update_view};
use smallvec::SmallVec;
use std::{
borrow::Cow,
@@ -156,9 +156,13 @@ impl FollowableItem for Editor {
}))
}
fn set_leader_peer_id(&mut self, leader_peer_id: Option<PeerId>, cx: &mut ViewContext<Self>) {
self.leader_peer_id = leader_peer_id;
if self.leader_peer_id.is_some() {
fn set_leader_replica_id(
&mut self,
leader_replica_id: Option<u16>,
cx: &mut ViewContext<Self>,
) {
self.leader_replica_id = leader_replica_id;
if self.leader_replica_id.is_some() {
self.buffer.update(cx, |buffer, cx| {
buffer.remove_active_selections(cx);
});
@@ -305,10 +309,6 @@ impl FollowableItem for Editor {
_ => false,
}
}
fn is_project_item(&self, _cx: &AppContext) -> bool {
true
}
}
async fn update_editor_from_message(

View File

@@ -234,7 +234,7 @@ pub fn start_of_paragraph(
) -> DisplayPoint {
let point = display_point.to_point(map);
if point.row == 0 {
return DisplayPoint::zero();
return map.max_point();
}
let mut found_non_blank_line = false;
@@ -261,7 +261,7 @@ pub fn end_of_paragraph(
) -> DisplayPoint {
let point = display_point.to_point(map);
if point.row == map.max_buffer_row() {
return map.max_point();
return DisplayPoint::zero();
}
let mut found_non_blank_line = false;

View File

@@ -33,7 +33,7 @@ lazy_static.workspace = true
postage.workspace = true
serde.workspace = true
serde_derive.workspace = true
sysinfo.workspace = true
sysinfo = "0.27.1"
tree-sitter-markdown = { git = "https://github.com/MDeiml/tree-sitter-markdown", rev = "330ecab87a3e3a7211ac69bbadc19eabecdb1cca" }
urlencoding = "2.1.2"

View File

@@ -10,7 +10,6 @@ doctest = false
[dependencies]
editor = { path = "../editor" }
collections = { path = "../collections" }
fuzzy = { path = "../fuzzy" }
gpui = { path = "../gpui" }
menu = { path = "../menu" }

View File

@@ -1,6 +1,5 @@
use collections::HashMap;
use editor::{scroll::autoscroll::Autoscroll, Bias, Editor};
use fuzzy::{CharBag, PathMatch, PathMatchCandidate};
use fuzzy::PathMatch;
use gpui::{
actions, elements::*, AppContext, ModelHandle, MouseState, Task, ViewContext, WeakViewHandle,
};
@@ -33,124 +32,38 @@ pub struct FileFinderDelegate {
history_items: Vec<FoundPath>,
}
#[derive(Debug, Default)]
struct Matches {
history: Vec<(FoundPath, Option<PathMatch>)>,
search: Vec<PathMatch>,
#[derive(Debug)]
enum Matches {
History(Vec<FoundPath>),
Search(Vec<PathMatch>),
}
#[derive(Debug)]
enum Match<'a> {
History(&'a FoundPath, Option<&'a PathMatch>),
History(&'a FoundPath),
Search(&'a PathMatch),
}
impl Matches {
fn len(&self) -> usize {
self.history.len() + self.search.len()
}
fn get(&self, index: usize) -> Option<Match<'_>> {
if index < self.history.len() {
self.history
.get(index)
.map(|(path, path_match)| Match::History(path, path_match.as_ref()))
} else {
self.search
.get(index - self.history.len())
.map(Match::Search)
match self {
Self::History(items) => items.len(),
Self::Search(items) => items.len(),
}
}
fn push_new_matches(
&mut self,
history_items: &Vec<FoundPath>,
query: &PathLikeWithPosition<FileSearchQuery>,
mut new_search_matches: Vec<PathMatch>,
extend_old_matches: bool,
) {
let matching_history_paths = matching_history_item_paths(history_items, query);
new_search_matches
.retain(|path_match| !matching_history_paths.contains_key(&path_match.path));
let history_items_to_show = history_items
.iter()
.filter_map(|history_item| {
Some((
history_item.clone(),
Some(
matching_history_paths
.get(&history_item.project.path)?
.clone(),
),
))
})
.collect::<Vec<_>>();
self.history = history_items_to_show;
if extend_old_matches {
self.search
.retain(|path_match| !matching_history_paths.contains_key(&path_match.path));
util::extend_sorted(
&mut self.search,
new_search_matches.into_iter(),
100,
|a, b| b.cmp(a),
)
} else {
self.search = new_search_matches;
fn get(&self, index: usize) -> Option<Match<'_>> {
match self {
Self::History(items) => items.get(index).map(Match::History),
Self::Search(items) => items.get(index).map(Match::Search),
}
}
}
fn matching_history_item_paths(
history_items: &Vec<FoundPath>,
query: &PathLikeWithPosition<FileSearchQuery>,
) -> HashMap<Arc<Path>, PathMatch> {
let history_items_by_worktrees = history_items
.iter()
.filter_map(|found_path| {
let candidate = PathMatchCandidate {
path: &found_path.project.path,
// Only match history items names, otherwise their paths may match too many queries, producing false positives.
// E.g. `foo` would match both `something/foo/bar.rs` and `something/foo/foo.rs` and if the former is a history item,
// it would be shown first always, despite the latter being a better match.
char_bag: CharBag::from_iter(
found_path
.project
.path
.file_name()?
.to_string_lossy()
.to_lowercase()
.chars(),
),
};
Some((found_path.project.worktree_id, candidate))
})
.fold(
HashMap::default(),
|mut candidates, (worktree_id, new_candidate)| {
candidates
.entry(worktree_id)
.or_insert_with(Vec::new)
.push(new_candidate);
candidates
},
);
let mut matching_history_paths = HashMap::default();
for (worktree, candidates) in history_items_by_worktrees {
let max_results = candidates.len() + 1;
matching_history_paths.extend(
fuzzy::match_fixed_path_set(
candidates,
worktree.to_usize(),
query.path_like.path_query(),
false,
max_results,
)
.into_iter()
.map(|path_match| (Arc::clone(&path_match.path), path_match)),
);
impl Default for Matches {
fn default() -> Self {
Self::History(Vec::new())
}
matching_history_paths
}
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -168,96 +81,66 @@ impl FoundPath {
actions!(file_finder, [Toggle]);
pub fn init(cx: &mut AppContext) {
cx.add_action(toggle_or_cycle_file_finder);
cx.add_action(toggle_file_finder);
FileFinder::init(cx);
}
const MAX_RECENT_SELECTIONS: usize = 20;
fn toggle_or_cycle_file_finder(
workspace: &mut Workspace,
_: &Toggle,
cx: &mut ViewContext<Workspace>,
) {
match workspace.modal::<FileFinder>() {
Some(file_finder) => file_finder.update(cx, |file_finder, cx| {
let current_index = file_finder.delegate().selected_index();
file_finder.select_next(&menu::SelectNext, cx);
let new_index = file_finder.delegate().selected_index();
if current_index == new_index {
file_finder.select_first(&menu::SelectFirst, cx);
}
}),
None => {
workspace.toggle_modal(cx, |workspace, cx| {
let project = workspace.project().read(cx);
fn toggle_file_finder(workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>) {
workspace.toggle_modal(cx, |workspace, cx| {
let project = workspace.project().read(cx);
let currently_opened_path = workspace
.active_item(cx)
.and_then(|item| item.project_path(cx))
.map(|project_path| {
let abs_path = project
.worktree_for_id(project_path.worktree_id, cx)
.map(|worktree| worktree.read(cx).abs_path().join(&project_path.path));
FoundPath::new(project_path, abs_path)
});
// if exists, bubble the currently opened path to the top
let history_items = currently_opened_path
.clone()
.into_iter()
.chain(
workspace
.recent_navigation_history(Some(MAX_RECENT_SELECTIONS), cx)
.into_iter()
.filter(|(history_path, _)| {
Some(history_path)
!= currently_opened_path
.as_ref()
.map(|found_path| &found_path.project)
})
.filter(|(_, history_abs_path)| {
history_abs_path.as_ref()
!= currently_opened_path
.as_ref()
.and_then(|found_path| found_path.absolute.as_ref())
})
.filter(|(_, history_abs_path)| match history_abs_path {
Some(abs_path) => history_file_exists(abs_path),
None => true,
})
.map(|(history_path, abs_path)| FoundPath::new(history_path, abs_path)),
)
.collect();
let project = workspace.project().clone();
let workspace = cx.handle().downgrade();
let finder = cx.add_view(|cx| {
Picker::new(
FileFinderDelegate::new(
workspace,
project,
currently_opened_path,
history_items,
cx,
),
cx,
)
});
finder
let currently_opened_path = workspace
.active_item(cx)
.and_then(|item| item.project_path(cx))
.map(|project_path| {
let abs_path = project
.worktree_for_id(project_path.worktree_id, cx)
.map(|worktree| worktree.read(cx).abs_path().join(&project_path.path));
FoundPath::new(project_path, abs_path)
});
}
}
}
#[cfg(not(test))]
fn history_file_exists(abs_path: &PathBuf) -> bool {
abs_path.exists()
}
// if exists, bubble the currently opened path to the top
let history_items = currently_opened_path
.clone()
.into_iter()
.chain(
workspace
.recent_navigation_history(Some(MAX_RECENT_SELECTIONS), cx)
.into_iter()
.filter(|(history_path, _)| {
Some(history_path)
!= currently_opened_path
.as_ref()
.map(|found_path| &found_path.project)
})
.filter(|(_, history_abs_path)| {
history_abs_path.as_ref()
!= currently_opened_path
.as_ref()
.and_then(|found_path| found_path.absolute.as_ref())
})
.map(|(history_path, abs_path)| FoundPath::new(history_path, abs_path)),
)
.collect();
#[cfg(test)]
fn history_file_exists(abs_path: &PathBuf) -> bool {
!abs_path.ends_with("nonexistent.rs")
let project = workspace.project().clone();
let workspace = cx.handle().downgrade();
let finder = cx.add_view(|cx| {
Picker::new(
FileFinderDelegate::new(
workspace,
project,
currently_opened_path,
history_items,
cx,
),
cx,
)
});
finder
});
}
pub enum Event {
@@ -372,14 +255,24 @@ impl FileFinderDelegate {
) {
if search_id >= self.latest_search_id {
self.latest_search_id = search_id;
let extend_old_matches = self.latest_search_did_cancel
if self.latest_search_did_cancel
&& Some(query.path_like.path_query())
== self
.latest_search_query
.as_ref()
.map(|query| query.path_like.path_query());
self.matches
.push_new_matches(&self.history_items, &query, matches, extend_old_matches);
.map(|query| query.path_like.path_query())
{
match &mut self.matches {
Matches::History(_) => self.matches = Matches::Search(matches),
Matches::Search(search_matches) => {
util::extend_sorted(search_matches, matches.into_iter(), 100, |a, b| {
b.cmp(a)
})
}
}
} else {
self.matches = Matches::Search(matches);
}
self.latest_search_query = Some(query);
self.latest_search_did_cancel = did_cancel;
cx.notify();
@@ -393,7 +286,7 @@ impl FileFinderDelegate {
ix: usize,
) -> (String, Vec<usize>, String, Vec<usize>) {
let (file_name, file_name_positions, full_path, full_path_positions) = match path_match {
Match::History(found_path, found_path_match) => {
Match::History(found_path) => {
let worktree_id = found_path.project.worktree_id;
let project_relative_path = &found_path.project.path;
let has_worktree = self
@@ -425,22 +318,14 @@ impl FileFinderDelegate {
path = Arc::from(absolute_path.as_path());
}
}
let mut path_match = PathMatch {
self.labels_for_path_match(&PathMatch {
score: ix as f64,
positions: Vec::new(),
worktree_id: worktree_id.to_usize(),
path,
path_prefix: "".into(),
distance_to_relative_ancestor: usize::MAX,
};
if let Some(found_path_match) = found_path_match {
path_match
.positions
.extend(found_path_match.positions.iter())
}
self.labels_for_path_match(&path_match)
})
}
Match::Search(path_match) => self.labels_for_path_match(path_match),
};
@@ -521,21 +406,23 @@ impl PickerDelegate for FileFinderDelegate {
if raw_query.is_empty() {
let project = self.project.read(cx);
self.latest_search_id = post_inc(&mut self.search_count);
self.matches = Matches {
history: self
.history_items
self.matches = Matches::History(
self.history_items
.iter()
.filter(|history_item| {
project
.worktree_for_id(history_item.project.worktree_id, cx)
.is_some()
|| (project.is_local() && history_item.absolute.is_some())
|| (project.is_local()
&& history_item
.absolute
.as_ref()
.filter(|abs_path| abs_path.exists())
.is_some())
})
.cloned()
.map(|p| (p, None))
.collect(),
search: Vec::new(),
};
);
cx.notify();
Task::ready(())
} else {
@@ -567,7 +454,7 @@ impl PickerDelegate for FileFinderDelegate {
}
};
match m {
Match::History(history_match, _) => {
Match::History(history_match) => {
let worktree_id = history_match.project.worktree_id;
if workspace
.project()
@@ -979,11 +866,11 @@ mod tests {
finder.update(cx, |finder, cx| {
let delegate = finder.delegate_mut();
assert!(
delegate.matches.history.is_empty(),
"Search matches expected"
);
let matches = delegate.matches.search.clone();
let matches = match &delegate.matches {
Matches::Search(path_matches) => path_matches,
_ => panic!("Search matches expected"),
}
.clone();
// Simulate a search being cancelled after the time limit,
// returning only a subset of the matches that would have been found.
@@ -1006,11 +893,12 @@ mod tests {
cx,
);
assert!(
delegate.matches.history.is_empty(),
"Search matches expected"
);
assert_eq!(delegate.matches.search.as_slice(), &matches[0..4]);
match &delegate.matches {
Matches::Search(new_matches) => {
assert_eq!(new_matches.as_slice(), &matches[0..4])
}
_ => panic!("Search matches expected"),
};
});
}
@@ -1118,11 +1006,10 @@ mod tests {
cx.read(|cx| {
let finder = finder.read(cx);
let delegate = finder.delegate();
assert!(
delegate.matches.history.is_empty(),
"Search matches expected"
);
let matches = delegate.matches.search.clone();
let matches = match &delegate.matches {
Matches::Search(path_matches) => path_matches,
_ => panic!("Search matches expected"),
};
assert_eq!(matches.len(), 1);
let (file_name, file_name_positions, full_path, full_path_positions) =
@@ -1201,11 +1088,10 @@ mod tests {
finder.read_with(cx, |f, _| {
let delegate = f.delegate();
assert!(
delegate.matches.history.is_empty(),
"Search matches expected"
);
let matches = delegate.matches.search.clone();
let matches = match &delegate.matches {
Matches::Search(path_matches) => path_matches,
_ => panic!("Search matches expected"),
};
assert_eq!(matches[0].path.as_ref(), Path::new("dir2/a.txt"));
assert_eq!(matches[1].path.as_ref(), Path::new("dir1/a.txt"));
});
@@ -1573,451 +1459,6 @@ mod tests {
);
}
#[gpui::test]
async fn test_toggle_panel_new_selections(
deterministic: Arc<gpui::executor::Deterministic>,
cx: &mut gpui::TestAppContext,
) {
let app_state = init_test(cx);
app_state
.fs
.as_fake()
.insert_tree(
"/src",
json!({
"test": {
"first.rs": "// First Rust file",
"second.rs": "// Second Rust file",
"third.rs": "// Third Rust file",
}
}),
)
.await;
let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await;
let window = cx.add_window(|cx| Workspace::test_new(project, cx));
let workspace = window.root(cx);
// generate some history to select from
open_close_queried_buffer(
"fir",
1,
"first.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"thi",
1,
"third.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
let current_history = open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
for expected_selected_index in 0..current_history.len() {
cx.dispatch_action(window.into(), Toggle);
let selected_index = cx.read(|cx| {
workspace
.read(cx)
.modal::<FileFinder>()
.unwrap()
.read(cx)
.delegate()
.selected_index()
});
assert_eq!(
selected_index, expected_selected_index,
"Should select the next item in the history"
);
}
cx.dispatch_action(window.into(), Toggle);
let selected_index = cx.read(|cx| {
workspace
.read(cx)
.modal::<FileFinder>()
.unwrap()
.read(cx)
.delegate()
.selected_index()
});
assert_eq!(
selected_index, 0,
"Should wrap around the history and start all over"
);
}
#[gpui::test]
async fn test_search_preserves_history_items(
deterministic: Arc<gpui::executor::Deterministic>,
cx: &mut gpui::TestAppContext,
) {
let app_state = init_test(cx);
app_state
.fs
.as_fake()
.insert_tree(
"/src",
json!({
"test": {
"first.rs": "// First Rust file",
"second.rs": "// Second Rust file",
"third.rs": "// Third Rust file",
"fourth.rs": "// Fourth Rust file",
}
}),
)
.await;
let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await;
let window = cx.add_window(|cx| Workspace::test_new(project, cx));
let workspace = window.root(cx);
let worktree_id = cx.read(|cx| {
let worktrees = workspace.read(cx).worktrees(cx).collect::<Vec<_>>();
assert_eq!(worktrees.len(), 1,);
WorktreeId::from_usize(worktrees[0].id())
});
// generate some history to select from
open_close_queried_buffer(
"fir",
1,
"first.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"thi",
1,
"third.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
cx.dispatch_action(window.into(), Toggle);
let first_query = "f";
let finder = cx.read(|cx| workspace.read(cx).modal::<FileFinder>().unwrap());
finder
.update(cx, |finder, cx| {
finder
.delegate_mut()
.update_matches(first_query.to_string(), cx)
})
.await;
finder.read_with(cx, |finder, _| {
let delegate = finder.delegate();
assert_eq!(delegate.matches.history.len(), 1, "Only one history item contains {first_query}, it should be present and others should be filtered out");
let history_match = delegate.matches.history.first().unwrap();
assert!(history_match.1.is_some(), "Should have path matches for history items after querying");
assert_eq!(history_match.0, FoundPath::new(
ProjectPath {
worktree_id,
path: Arc::from(Path::new("test/first.rs")),
},
Some(PathBuf::from("/src/test/first.rs"))
));
assert_eq!(delegate.matches.search.len(), 1, "Only one non-history item contains {first_query}, it should be present");
assert_eq!(delegate.matches.search.first().unwrap().path.as_ref(), Path::new("test/fourth.rs"));
});
let second_query = "fsdasdsa";
let finder = cx.read(|cx| workspace.read(cx).modal::<FileFinder>().unwrap());
finder
.update(cx, |finder, cx| {
finder
.delegate_mut()
.update_matches(second_query.to_string(), cx)
})
.await;
finder.read_with(cx, |finder, _| {
let delegate = finder.delegate();
assert!(
delegate.matches.history.is_empty(),
"No history entries should match {second_query}"
);
assert!(
delegate.matches.search.is_empty(),
"No search entries should match {second_query}"
);
});
let first_query_again = first_query;
let finder = cx.read(|cx| workspace.read(cx).modal::<FileFinder>().unwrap());
finder
.update(cx, |finder, cx| {
finder
.delegate_mut()
.update_matches(first_query_again.to_string(), cx)
})
.await;
finder.read_with(cx, |finder, _| {
let delegate = finder.delegate();
assert_eq!(delegate.matches.history.len(), 1, "Only one history item contains {first_query_again}, it should be present and others should be filtered out, even after non-matching query");
let history_match = delegate.matches.history.first().unwrap();
assert!(history_match.1.is_some(), "Should have path matches for history items after querying");
assert_eq!(history_match.0, FoundPath::new(
ProjectPath {
worktree_id,
path: Arc::from(Path::new("test/first.rs")),
},
Some(PathBuf::from("/src/test/first.rs"))
));
assert_eq!(delegate.matches.search.len(), 1, "Only one non-history item contains {first_query_again}, it should be present, even after non-matching query");
assert_eq!(delegate.matches.search.first().unwrap().path.as_ref(), Path::new("test/fourth.rs"));
});
}
#[gpui::test]
async fn test_history_items_vs_very_good_external_match(
deterministic: Arc<gpui::executor::Deterministic>,
cx: &mut gpui::TestAppContext,
) {
let app_state = init_test(cx);
app_state
.fs
.as_fake()
.insert_tree(
"/src",
json!({
"collab_ui": {
"first.rs": "// First Rust file",
"second.rs": "// Second Rust file",
"third.rs": "// Third Rust file",
"collab_ui.rs": "// Fourth Rust file",
}
}),
)
.await;
let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await;
let window = cx.add_window(|cx| Workspace::test_new(project, cx));
let workspace = window.root(cx);
// generate some history to select from
open_close_queried_buffer(
"fir",
1,
"first.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"thi",
1,
"third.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"sec",
1,
"second.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
cx.dispatch_action(window.into(), Toggle);
let query = "collab_ui";
let finder = cx.read(|cx| workspace.read(cx).modal::<FileFinder>().unwrap());
finder
.update(cx, |finder, cx| {
finder.delegate_mut().update_matches(query.to_string(), cx)
})
.await;
finder.read_with(cx, |finder, _| {
let delegate = finder.delegate();
assert!(
delegate.matches.history.is_empty(),
"History items should not math query {query}, they should be matched by name only"
);
let search_entries = delegate
.matches
.search
.iter()
.map(|path_match| path_match.path.to_path_buf())
.collect::<Vec<_>>();
assert_eq!(
search_entries,
vec![
PathBuf::from("collab_ui/collab_ui.rs"),
PathBuf::from("collab_ui/third.rs"),
PathBuf::from("collab_ui/first.rs"),
PathBuf::from("collab_ui/second.rs"),
],
"Despite all search results having the same directory name, the most matching one should be on top"
);
});
}
#[gpui::test]
async fn test_nonexistent_history_items_not_shown(
deterministic: Arc<gpui::executor::Deterministic>,
cx: &mut gpui::TestAppContext,
) {
let app_state = init_test(cx);
app_state
.fs
.as_fake()
.insert_tree(
"/src",
json!({
"test": {
"first.rs": "// First Rust file",
"nonexistent.rs": "// Second Rust file",
"third.rs": "// Third Rust file",
}
}),
)
.await;
let project = Project::test(app_state.fs.clone(), ["/src".as_ref()], cx).await;
let window = cx.add_window(|cx| Workspace::test_new(project, cx));
let workspace = window.root(cx);
// generate some history to select from
open_close_queried_buffer(
"fir",
1,
"first.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"non",
1,
"nonexistent.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"thi",
1,
"third.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
open_close_queried_buffer(
"fir",
1,
"first.rs",
window.into(),
&workspace,
&deterministic,
cx,
)
.await;
cx.dispatch_action(window.into(), Toggle);
let query = "rs";
let finder = cx.read(|cx| workspace.read(cx).modal::<FileFinder>().unwrap());
finder
.update(cx, |finder, cx| {
finder.delegate_mut().update_matches(query.to_string(), cx)
})
.await;
finder.read_with(cx, |finder, _| {
let delegate = finder.delegate();
let history_entries = delegate
.matches
.history
.iter()
.map(|(_, path_match)| path_match.as_ref().expect("should have a path match").path.to_path_buf())
.collect::<Vec<_>>();
assert_eq!(
history_entries,
vec![
PathBuf::from("test/first.rs"),
PathBuf::from("test/third.rs"),
],
"Should have all opened files in the history, except the ones that do not exist on disk"
);
});
}
async fn open_close_queried_buffer(
input: &str,
expected_matches: usize,

View File

@@ -9,10 +9,13 @@ path = "src/fs.rs"
[dependencies]
collections = { path = "../collections" }
gpui = { path = "../gpui" }
lsp = { path = "../lsp" }
rope = { path = "../rope" }
text = { path = "../text" }
util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
rpc = { path = "../rpc" }
anyhow.workspace = true
async-trait.workspace = true
@@ -31,10 +34,8 @@ log.workspace = true
libc = "0.2"
time.workspace = true
gpui = { path = "../gpui", optional = true}
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
[features]
test-support = ["gpui/test-support"]
test-support = []

View File

@@ -93,6 +93,33 @@ pub struct Metadata {
pub is_dir: bool,
}
impl From<lsp::CreateFileOptions> for CreateOptions {
fn from(options: lsp::CreateFileOptions) -> Self {
Self {
overwrite: options.overwrite.unwrap_or(false),
ignore_if_exists: options.ignore_if_exists.unwrap_or(false),
}
}
}
impl From<lsp::RenameFileOptions> for RenameOptions {
fn from(options: lsp::RenameFileOptions) -> Self {
Self {
overwrite: options.overwrite.unwrap_or(false),
ignore_if_exists: options.ignore_if_exists.unwrap_or(false),
}
}
}
impl From<lsp::DeleteFileOptions> for RemoveOptions {
fn from(options: lsp::DeleteFileOptions) -> Self {
Self {
recursive: options.recursive.unwrap_or(false),
ignore_if_not_exists: options.ignore_if_not_exists.unwrap_or(false),
}
}
}
pub struct RealFs;
#[async_trait::async_trait]

View File

@@ -2,6 +2,7 @@ use anyhow::Result;
use collections::HashMap;
use git2::{BranchType, StatusShow};
use parking_lot::Mutex;
use rpc::proto;
use serde_derive::{Deserialize, Serialize};
use std::{
cmp::Ordering,
@@ -22,7 +23,6 @@ pub struct Branch {
/// Timestamp of most recent commit, normalized to Unix Epoch format.
pub unix_timestamp: Option<i64>,
}
#[async_trait::async_trait]
pub trait GitRepository: Send {
fn reload_index(&self);
@@ -358,6 +358,24 @@ impl GitFileStatus {
}
}
}
pub fn from_proto(git_status: Option<i32>) -> Option<GitFileStatus> {
git_status.and_then(|status| {
proto::GitStatus::from_i32(status).map(|status| match status {
proto::GitStatus::Added => GitFileStatus::Added,
proto::GitStatus::Modified => GitFileStatus::Modified,
proto::GitStatus::Conflict => GitFileStatus::Conflict,
})
})
}
pub fn to_proto(self) -> i32 {
match self {
GitFileStatus::Added => proto::GitStatus::Added as i32,
GitFileStatus::Modified => proto::GitStatus::Modified as i32,
GitFileStatus::Conflict => proto::GitStatus::Conflict as i32,
}
}
}
#[derive(Clone, Debug, Ord, Hash, PartialOrd, Eq, PartialEq)]

View File

@@ -4,7 +4,5 @@ mod paths;
mod strings;
pub use char_bag::CharBag;
pub use paths::{
match_fixed_path_set, match_path_sets, PathMatch, PathMatchCandidate, PathMatchCandidateSet,
};
pub use paths::{match_path_sets, PathMatch, PathMatchCandidate, PathMatchCandidateSet};
pub use strings::{match_strings, StringMatch, StringMatchCandidate};

View File

@@ -441,7 +441,7 @@ mod tests {
score,
worktree_id: 0,
positions: Vec::new(),
path: Arc::from(candidate.path),
path: candidate.path.clone(),
path_prefix: "".into(),
distance_to_relative_ancestor: usize::MAX,
},

View File

@@ -14,7 +14,7 @@ use crate::{
#[derive(Clone, Debug)]
pub struct PathMatchCandidate<'a> {
pub path: &'a Path,
pub path: &'a Arc<Path>,
pub char_bag: CharBag,
}
@@ -90,44 +90,6 @@ impl Ord for PathMatch {
}
}
pub fn match_fixed_path_set(
candidates: Vec<PathMatchCandidate>,
worktree_id: usize,
query: &str,
smart_case: bool,
max_results: usize,
) -> Vec<PathMatch> {
let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
let query = query.chars().collect::<Vec<_>>();
let query_char_bag = CharBag::from(&lowercase_query[..]);
let mut matcher = Matcher::new(
&query,
&lowercase_query,
query_char_bag,
smart_case,
max_results,
);
let mut results = Vec::new();
matcher.match_candidates(
&[],
&[],
candidates.into_iter(),
&mut results,
&AtomicBool::new(false),
|candidate, score| PathMatch {
score,
worktree_id,
positions: Vec::new(),
path: Arc::from(candidate.path),
path_prefix: Arc::from(""),
distance_to_relative_ancestor: usize::MAX,
},
);
results
}
pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>(
candidate_sets: &'a [Set],
query: &str,
@@ -195,7 +157,7 @@ pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>(
score,
worktree_id,
positions: Vec::new(),
path: Arc::from(candidate.path),
path: candidate.path.clone(),
path_prefix: candidate_set.prefix(),
distance_to_relative_ancestor: relative_to.as_ref().map_or(
usize::MAX,

Some files were not shown because too many files have changed in this diff Show More