Compare commits

..

2 Commits

Author SHA1 Message Date
Thorsten Ball
de855fa4a7 Make debug actually do debug 2024-07-17 16:48:21 +02:00
Thorsten Ball
46239b7a04 debug: Debug a flaky test 2024-07-17 16:38:13 +02:00
421 changed files with 9520 additions and 19439 deletions

View File

@@ -40,15 +40,10 @@ jobs:
- name: Check spelling
run: |
if ! cargo install --list | grep "typos-cli v$TYPOS_CLI_VERSION" > /dev/null; then
echo "Installing typos-cli@$TYPOS_CLI_VERSION..."
cargo install "typos-cli@$TYPOS_CLI_VERSION"
else
echo "typos-cli@$TYPOS_CLI_VERSION is already installed."
if ! which typos > /dev/null; then
cargo install typos-cli
fi
typos
env:
TYPOS_CLI_VERSION: "1.23.3"
- name: Run style checks
uses: ./.github/actions/check_style
@@ -256,8 +251,6 @@ jobs:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-macos-x86_64.gz
target/zed-remote-server-macos-aarch64.gz
target/aarch64-apple-darwin/release/Zed-aarch64.dmg
target/x86_64-apple-darwin/release/Zed-x86_64.dmg
target/release/Zed.dmg
@@ -329,9 +322,7 @@ jobs:
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-linux-x86_64.gz
target/release/zed-linux-x86_64.tar.gz
files: target/release/zed-linux-x86_64.tar.gz
body: ""
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -414,9 +405,7 @@ jobs:
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: |
target/zed-remote-server-linux-aarch64.gz
target/release/zed-linux-aarch64.tar.gz
files: target/release/zed-linux-aarch64.tar.gz
body: ""
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -33,7 +33,6 @@ jobs:
- name: Run clippy
run: ./script/clippy
tests:
timeout-minutes: 60
name: Run tests
@@ -133,6 +132,7 @@ jobs:
name: Create a Linux *.tar.gz bundle for ARM
if: github.repository_owner == 'zed-industries'
runs-on:
- self-hosted
- hosted-linux-arm-1
needs: tests
env:

875
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,6 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
"crates/completion",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@@ -44,14 +43,13 @@ members = [
"crates/gpui_macros",
"crates/headless",
"crates/html_to_markdown",
"crates/http_client",
"crates/http",
"crates/image_viewer",
"crates/indexed_docs",
"crates/inline_completion_button",
"crates/install_cli",
"crates/journal",
"crates/language",
"crates/language_model",
"crates/language_selector",
"crates/language_tools",
"crates/languages",
@@ -81,8 +79,6 @@ members = [
"crates/refineable",
"crates/refineable/derive_refineable",
"crates/release_channel",
"crates/remote",
"crates/remote_server",
"crates/repl",
"crates/rich_text",
"crates/rope",
@@ -90,9 +86,7 @@ members = [
"crates/search",
"crates/semantic_index",
"crates/semantic_version",
"crates/session",
"crates/settings",
"crates/settings_ui",
"crates/snippet",
"crates/snippet_provider",
"crates/sqlez",
@@ -143,7 +137,6 @@ members = [
"extensions/php",
"extensions/prisma",
"extensions/purescript",
"extensions/ruff",
"extensions/ruby",
"extensions/snippets",
"extensions/svelte",
@@ -161,7 +154,6 @@ resolver = "2"
[workspace.dependencies]
activity_indicator = { path = "crates/activity_indicator" }
aho-corasick = "1.1"
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
@@ -181,7 +173,6 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@@ -204,14 +195,13 @@ gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
headless = { path = "crates/headless" }
html_to_markdown = { path = "crates/html_to_markdown" }
http_client = { path = "crates/http_client" }
http = { path = "crates/http" }
image_viewer = { path = "crates/image_viewer" }
indexed_docs = { path = "crates/indexed_docs" }
inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_model = { path = "crates/language_model" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }
@@ -241,8 +231,6 @@ proto = { path = "crates/proto" }
quick_action_bar = { path = "crates/quick_action_bar" }
recent_projects = { path = "crates/recent_projects" }
release_channel = { path = "crates/release_channel" }
remote = { path = "crates/remote" }
remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
rich_text = { path = "crates/rich_text" }
rope = { path = "crates/rope" }
@@ -250,9 +238,7 @@ rpc = { path = "crates/rpc" }
search = { path = "crates/search" }
semantic_index = { path = "crates/semantic_index" }
semantic_version = { path = "crates/semantic_version" }
session = { path = "crates/session" }
settings = { path = "crates/settings" }
settings_ui = { path = "crates/settings_ui" }
snippet = { path = "crates/snippet" }
snippet_provider = { path = "crates/snippet_provider" }
sqlez = { path = "crates/sqlez" }
@@ -292,18 +278,16 @@ ashpd = "0.9.1"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-dispatcher = { version = "0.1" }
async-fs = "1.6"
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
async-recursion = "1.0.0"
async-tar = "0.4.2"
async-trait = "0.1"
async-tungstenite = { version = "0.16" }
async-watch = "0.3.1"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
base64 = "0.13"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
blade-macros = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
blade-util = { git = "https://github.com/zed-industries/blade", rev = "7e497c534d5d4a30c18d9eb182cf39eaf0aaa25e" }
blade-graphics = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
blade-macros = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
blade-util = { git = "https://github.com/zed-industries/blade", rev = "a477c2008db27db0b9f745715e119b3ee7ab7818" }
cap-std = "3.0"
cargo_toml = "0.20"
chrono = { version = "0.4", features = ["serde"] }
@@ -317,7 +301,7 @@ dashmap = "5.5.3"
derive_more = "0.99.17"
dirs = "4.0"
emojis = "0.6.1"
env_logger = "0.10"
env_logger = "0.9"
exec = "0.3.1"
fork = "0.1.23"
futures = "0.3"
@@ -340,7 +324,7 @@ itertools = "0.11.0"
lazy_static = "1.4.0"
libc = "0.2"
linkify = "0.10.0"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
markup5ever_rcdom = "0.3.0"
nanoid = "0.4"
nix = "0.28"
@@ -361,13 +345,12 @@ rand = "0.8.5"
refineable = { path = "./crates/refineable" }
regex = "1.5"
repair_json = "0.1.0"
rsa = "0.9.6"
runtimelib = { version = "0.14", default-features = false, features = [
runtimelib = { version = "0.12", default-features = false, features = [
"async-dispatcher-runtime",
] }
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
rust-embed = { version = "8.4", features = ["include-exclude"] }
schemars = {version = "0.8", features = ["impl_json_schema"]}
schemars = "0.8"
semver = "1.0"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
@@ -398,7 +381,6 @@ time = { version = "0.3", features = [
"serde-well-known",
"formatting",
] }
tiny_http = "0.8"
toml = "0.8"
tokio = { version = "1", features = ["full"] }
tower-http = "0.4.4"
@@ -432,14 +414,14 @@ url = "2.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "serde"] }
wasmparser = "0.201"
wasm-encoder = "0.201"
wasmtime = { version = "19.0.2", default-features = false, features = [
wasmtime = { version = "19.0.0", default-features = false, features = [
"async",
"demangle",
"runtime",
"cranelift",
"component-model",
] }
wasmtime-wasi = "19.0.2"
wasmtime-wasi = "19.0.0"
which = "6.0.0"
wit-component = "0.201"
sys-locale = "0.3.1"
@@ -475,6 +457,7 @@ features = [
"Win32_System_SystemInformation",
"Win32_System_SystemServices",
"Win32_System_Threading",
"Win32_System_Time",
"Win32_System_WinRT",
"Win32_UI_Controls",
"Win32_UI_HiDpi",
@@ -486,6 +469,8 @@ features = [
[patch.crates-io]
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7b4894ba2ae81b988846676f54c0988d4027ef4f" }
# Workaround for a broken nightly build of gpui: See #7644 and revisit once 0.5.3 is released.
pathfinder_simd = { git = "https://github.com/servo/pathfinder.git", rev = "4968e819c0d9b015437ffc694511e175801a17c7" }
[profile.dev]
split-debuginfo = "unpacked"

View File

@@ -1,3 +0,0 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 6C5.69062 6.30938 4.56159 6.55977 3.51192 6.73263C3.27345 6.7719 3.27345 7.2281 3.51192 7.26737C4.56159 7.44023 5.69062 7.69062 6 8C6.30938 8.30938 6.55977 9.43841 6.73263 10.4881C6.7719 10.7266 7.2281 10.7266 7.26737 10.4881C7.44023 9.43841 7.69062 8.30938 8 8C8.30938 7.69062 9.43841 7.44023 10.4881 7.26737C10.7266 7.2281 10.7266 6.7719 10.4881 6.73263C9.43841 6.55977 8.30938 6.30938 8 6C7.69062 5.69062 7.44023 4.56159 7.26737 3.51192C7.2281 3.27345 6.7719 3.27345 6.73263 3.51192C6.55977 4.56159 6.30938 5.69062 6 6Z" stroke="black" stroke-width="1.25" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 700 B

View File

@@ -1,3 +1,3 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 9.8V4.2C4 4.08954 4.08954 4 4.2 4H9.8C9.91046 4 10 4.08954 10 4.2V9.8C10 9.91046 9.91046 10 9.8 10H4.2C4.08954 10 4 9.91046 4 9.8Z" stroke="#C56757" stroke-width="1.25" stroke-linejoin="round"/>
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M9.88889 1H2.11111C1.49746 1 1 1.49746 1 2.11111V9.88889C1 10.5025 1.49746 11 2.11111 11H9.88889C10.5025 11 11 10.5025 11 9.88889V2.11111C11 1.49746 10.5025 1 9.88889 1Z" stroke="#C56757" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 310 B

After

Width:  |  Height:  |  Size: 369 B

View File

@@ -40,6 +40,7 @@
"backspace": "editor::Backspace",
"shift-backspace": "editor::Backspace",
"delete": "editor::Delete",
"ctrl-d": "editor::Delete",
"tab": "editor::Tab",
"shift-tab": "editor::TabPrev",
"ctrl-k": "editor::CutToEndOfLine",
@@ -106,6 +107,7 @@
"enter": "editor::Newline",
"shift-enter": "editor::Newline",
"ctrl-shift-enter": "editor::NewlineBelow",
"ctrl-enter": "editor::NewlineAbove",
"alt-z": "editor::ToggleSoftWrap",
"ctrl-f": "buffer_search::Deploy",
"ctrl-h": ["buffer_search::Deploy", { "replace_enabled": true }],
@@ -115,12 +117,6 @@
"ctrl-alt-e": "editor::SelectEnclosingSymbol"
}
},
{
"context": "Editor && mode == full && !jupyter",
"bindings": {
"ctrl-enter": "editor::NewlineAbove"
}
},
{
"context": "Editor && mode == full && inline_completion",
"bindings": {
@@ -209,7 +205,7 @@
}
},
{
"context": "ProjectSearchBar && in_replace > Editor",
"context": "ProjectSearchBar && in_replace",
"bindings": {
"enter": "search::ReplaceNext",
"ctrl-alt-enter": "search::ReplaceAll"
@@ -249,6 +245,13 @@
"ctrl-alt-shift-x": "search::ToggleRegex"
}
},
{
"context": "Terminal",
"bindings": {
"ctrl-w": ["terminal::SendKeystroke", "ctrl-w"],
"ctrl-e": ["terminal::SendKeystroke", "ctrl-e"]
}
},
// Bindings from VS Code
{
"context": "Editor",
@@ -266,7 +269,6 @@
"alt-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink Selection
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }],
"ctrl-shift-down": ["editor::SelectNext", { "replace_newest": false }], // Add selection to Next Find Match
"ctrl-shift-up": ["editor::SelectPrevious", { "replace_newest": false }],
"ctrl-k ctrl-d": ["editor::SelectNext", { "replace_newest": true }],
@@ -458,16 +460,12 @@
{
"bindings": {
"ctrl-alt-shift-f": "workspace::FollowNextCollaborator",
// TODO: Move this to a dock open action
"ctrl-shift-c": "collab_panel::ToggleFocus",
"ctrl-alt-i": "zed::DebugElements",
"ctrl-:": "editor::ToggleInlayHints"
}
},
{
"context": "!Terminal",
"bindings": {
"ctrl-shift-c": "collab_panel::ToggleFocus"
}
},
{
"context": "Editor && mode == full",
"bindings": {
@@ -479,12 +477,6 @@
"ctrl-enter": "assistant::InlineAssist"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"ctrl-shift-enter": "repl::Run"
}
},
{
"context": "ContextEditor > Editor",
"bindings": {
@@ -601,14 +593,11 @@
"context": "Terminal",
"bindings": {
"ctrl-alt-space": "terminal::ShowCharacterPalette",
"ctrl-shift-c": "terminal::Copy",
"shift-ctrl-c": "terminal::Copy",
"ctrl-insert": "terminal::Copy",
// "ctrl-a": "editor::SelectAll", // conflicts with readline
"ctrl-shift-v": "terminal::Paste",
"shift-ctrl-v": "terminal::Paste",
"shift-insert": "terminal::Paste",
"ctrl-enter": "assistant::InlineAssist",
"ctrl-w": ["terminal::SendKeystroke", "ctrl-w"],
"ctrl-e": ["terminal::SendKeystroke", "ctrl-e"],
"up": ["terminal::SendKeystroke", "up"],
"pageup": ["terminal::SendKeystroke", "pageup"],
"down": ["terminal::SendKeystroke", "down"],

View File

@@ -180,12 +180,6 @@
"cmd-c": "markdown::Copy"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"ctrl-shift-enter": "repl::Run"
}
},
{
"context": "AssistantPanel",
"bindings": {
@@ -261,7 +255,7 @@
}
},
{
"context": "ProjectSearchBar && in_replace > Editor",
"context": "ProjectSearchBar && in_replace",
"bindings": {
"enter": "search::ReplaceNext",
"cmd-enter": "search::ReplaceAll"
@@ -298,6 +292,7 @@
"alt-cmd-c": "search::ToggleCaseSensitive",
"alt-cmd-w": "search::ToggleWholeWord",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ToggleRegex",
"alt-cmd-x": "search::ToggleRegex"
}
},
@@ -575,6 +570,12 @@
"space": "project_panel::Open"
}
},
{
"context": "Editor && jupyter && !ContextEditor",
"bindings": {
"cmd-enter": "repl::Run"
}
},
{
"context": "CollabPanel && not_editing",
"bindings": {
@@ -627,7 +628,6 @@
"ctrl-cmd-space": "terminal::ShowCharacterPalette",
"cmd-c": "terminal::Copy",
"cmd-v": "terminal::Paste",
"cmd-a": "editor::SelectAll",
"cmd-k": "terminal::Clear",
"ctrl-enter": "assistant::InlineAssist",
// Some nice conveniences

View File

@@ -12,8 +12,8 @@
{
"context": "Editor",
"bindings": {
"ctrl-shift-up": "editor::MoveLineUp",
"ctrl-shift-down": "editor::MoveLineDown",
"ctrl-shift-up": "editor::AddSelectionAbove",
"ctrl-shift-down": "editor::AddSelectionBelow",
"ctrl-shift-m": "editor::SelectLargerSyntaxNode",
"ctrl-shift-l": "editor::SplitSelectionIntoLines",
"ctrl-shift-a": "editor::SelectLargerSyntaxNode",

View File

@@ -14,8 +14,6 @@
"bindings": {
"ctrl-shift-up": "editor::AddSelectionAbove",
"ctrl-shift-down": "editor::AddSelectionBelow",
"cmd-ctrl-up": "editor::MoveLineUp",
"cmd-ctrl-down": "editor::MoveLineDown",
"cmd-shift-space": "editor::SelectAll",
"ctrl-shift-m": "editor::SelectLargerSyntaxNode",
"cmd-shift-l": "editor::SplitSelectionIntoLines",

View File

@@ -10,7 +10,6 @@
"backspace": "vim::Backspace",
"j": "vim::Down",
"down": "vim::Down",
"ctrl-j": "vim::Down",
"enter": "vim::NextLineStart",
"ctrl-m": "vim::NextLineStart",
"+": "vim::NextLineStart",
@@ -216,7 +215,7 @@
"shift-d": "vim::DeleteToEndOfLine",
"shift-j": "vim::JoinLines",
"y": ["vim::PushOperator", "Yank"],
"shift-y": "vim::YankToEndOfLine",
"shift-y": "vim::YankLine",
"i": "vim::InsertBefore",
"shift-i": "vim::InsertFirstNonWhitespace",
"a": "vim::InsertAfter",
@@ -253,7 +252,7 @@
"[ d": "editor::GoToPrevDiagnostic",
"] c": "editor::GoToHunk",
"[ c": "editor::GoToPrevHunk",
"g c": ["vim::PushOperator", "ToggleComments"]
"g c c": "vim::ToggleComments"
}
},
{
@@ -434,12 +433,6 @@
"<": "vim::CurrentLine"
}
},
{
"context": "vim_operator == gc",
"bindings": {
"c": "vim::CurrentLine"
}
},
{
"context": "BufferSearchBar && !in_replace",
"bindings": {

View File

@@ -1,241 +0,0 @@
Your task is to map a step from the conversation above to operations on symbols inside the provided source files.
Guidelines:
- There's no need to describe *what* to do, just *where* to do it.
- If creating a file, assume any subsequent updates are included at the time of creation.
- Don't create and then update a file.
- We'll create it in one shot.
- Prefer updating symbols lower in the syntax tree if possible.
- Never include operations on a parent symbol and one of its children in the same <operations> block.
- Never nest an operation with another operation or include CDATA or other content. All operations are leaf nodes.
- Include a description attribute for each operation with a brief, one-line description of the change to perform.
- Descriptions are required for all operations except delete.
- When generating multiple operations, ensure the descriptions are specific to each individual operation.
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
- Don't generate multiple operations at the same location. Instead, combine them together in a single operation with a succinct combined description.
The available operation types are:
1. <update>: Modify an existing symbol in a file.
2. <create_file>: Create a new file.
3. <insert_sibling_after>: Add a new symbol as sibling after an existing symbol in a file.
4. <append_child>: Add a new symbol as the last child of an existing symbol in a file.
5. <prepend_child>: Add a new symbol as the first child of an existing symbol in a file.
6. <delete>: Remove an existing symbol from a file. The `description` attribute is invalid for delete, but required for other ops.
All operations *require* a path.
Operations that *require* a symbol: <update>, <insert_sibling_after>, <delete>
Operations that don't allow a symbol: <create>
Operations that have an *optional* symbol: <prepend_child>, <append_child>
Example 1:
User:
```rs src/rectangle.rs
struct Rectangle {
width: f64,
height: f64,
}
impl Rectangle {
fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
}
```
Symbols for src/rectangle.rs:
- struct Rectangle
- impl Rectangle
- impl Rectangle fn new
<step>Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct</step>
<step>Implement the 'Display' trait for the Rectangle struct</step>
What are the operations for the step: <step>Add a new method 'calculate_area' to the Rectangle struct</step>
Assistant (wrong):
<operations>
<append_child path="src/shapes.rs" symbol="impl Rectangle" description="Add calculate_area method" />
<append_child path="src/shapes.rs" symbol="impl Rectangle" description="Add calculate_perimeter method" />
</operations>
This demonstrates what NOT to do. NEVER append multiple children at the same location.
Assistant (corrected):
<operations>
<append_child path="src/shapes.rs" symbol="impl Rectangle" description="Add calculate area and perimeter methods" />
</operations>
User:
What are the operations for the step: <step>Implement the 'Display' trait for the Rectangle struct</step>
Assistant:
<operations>
<insert_sibling_after path="src/shapes.rs" symbol="impl Rectangle" description="Implement Display trait for Rectangle"/>
</operations>
Example 2:
User:
```rs src/user.rs
struct User {
pub name: String,
age: u32,
email: String,
}
impl User {
fn new(name: String, age: u32, email: String) -> Self {
User { name, age, email }
}
pub fn print_info(&self) {
println!("Name: {}, Age: {}, Email: {}", self.name, self.age, self.email);
}
}
```
Symbols for src/user.rs:
- struct User
- struct User pub name
- struct User age
- struct User email
- impl User
- impl User fn new
- impl User pub fn print_info
<step>Update the 'print_info' method to use formatted output</step>
<step>Remove the 'email' field from the User struct</step>
What are the operations for the step: <step>Update the 'print_info' method to use formatted output</step>
Assistant:
<operations>
<update path="src/user.rs" symbol="impl User fn print_info" description="Use formatted output" />
</operations>
User:
What are the operations for the step: <step>Remove the 'email' field from the User struct</step>
Assistant:
<operations>
<delete path="src/user.rs" symbol="struct User email" description="Remove the email field" />
</operations>
Example 3:
User:
```rs src/vehicle.rs
struct Vehicle {
make: String,
model: String,
year: u32,
}
impl Vehicle {
fn new(make: String, model: String, year: u32) -> Self {
Vehicle { make, model, year }
}
fn print_year(&self) {
println!("Year: {}", self.year);
}
}
```
Symbols for src/vehicle.rs:
- struct Vehicle
- struct Vehicle make
- struct Vehicle model
- struct Vehicle year
- impl Vehicle
- impl Vehicle fn new
- impl Vehicle fn print_year
<step>Add a 'use std::fmt;' statement at the beginning of the file</step>
<step>Add a new method 'start_engine' in the Vehicle impl block</step>
What are the operations for the step: <step>Add a 'use std::fmt;' statement at the beginning of the file</step>
Assistant:
<operations>
<prepend_child path="src/vehicle.rs" description="Add 'use std::fmt' statement" />
</operations>
User:
What are the operations for the step: <step>Add a new method 'start_engine' in the Vehicle impl block</step>
Assistant:
<operations>
<insert_sibling_after path="src/vehicle.rs" symbol="impl Vehicle fn new" description="Add start_engine method"/>
</operations>
Example 4:
User:
```rs src/employee.rs
struct Employee {
name: String,
position: String,
salary: u32,
department: String,
}
impl Employee {
fn new(name: String, position: String, salary: u32, department: String) -> Self {
Employee { name, position, salary, department }
}
fn print_details(&self) {
println!("Name: {}, Position: {}, Salary: {}, Department: {}",
self.name, self.position, self.salary, self.department);
}
fn give_raise(&mut self, amount: u32) {
self.salary += amount;
}
}
```
Symbols for src/employee.rs:
- struct Employee
- struct Employee name
- struct Employee position
- struct Employee salary
- struct Employee department
- impl Employee
- impl Employee fn new
- impl Employee fn print_details
- impl Employee fn give_raise
<step>Make salary an f32</step>
What are the operations for the step: <step>Make salary an f32</step>
A (wrong):
<operations>
<update path="src/employee.rs" symbol="struct Employee" description="Change the type of salary to an f32" />
<update path="src/employee.rs" symbol="struct Employee salary" description="Change the type to an f32" />
</operations>
This example demonstrates what not to do. `struct Employee salary` is a child of `struct Employee`.
A (corrected):
<operations>
<update path="src/employee.rs" symbol="struct Employee salary" description="Change the type to an f32" />
</operations>
User:
What are the correct operations for the step: <step>Remove the 'department' field and update the 'print_details' method</step>
A:
<operations>
<delete path="src/employee.rs" symbol="struct Employee department" />
<update path="src/employee.rs" symbol="impl Employee fn print_details" description="Don't print the 'department' field" />
</operations>
Now generate the operations for the following step.
Output only valid XML containing valid operations with their required attributes.
NEVER output code or any other text inside <operation> tags. If you do, you will replaced with another model.
Your response *MUST* begin with <operations> and end with </operations>:

View File

@@ -82,7 +82,7 @@
// Whether to confirm before quitting Zed.
"confirm_quit": false,
// Whether to restore last closed project when fresh Zed instance is opened.
"restore_on_startup": "last_session",
"restore_on_startup": "last_workspace",
// Size of the drop target in the editor.
"drop_target_size": 0.2,
// Whether the window should be closed when using 'close active item' on a window with no tabs.
@@ -375,7 +375,7 @@
},
"assistant": {
// Version of this setting.
"version": "2",
"version": "1",
// Whether the assistant is enabled.
"enabled": true,
// Whether to show the assistant panel button in the status bar.
@@ -386,12 +386,17 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// The default model to use when creating new contexts.
"default_model": {
// The provider to use.
"provider": "openai",
// The model to use.
"model": "gpt-4o"
// AI provider.
"provider": {
"name": "openai",
// The default model to use when creating new contexts. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo"
// 2. "gpt-4"
// 3. "gpt-4-turbo-preview"
// 4. "gpt-4o"
"default_model": "gpt-4o"
}
},
// Whether the screen sharing icon is shown in the os status bar.
@@ -699,7 +704,7 @@
//
"file_types": {
"JSON": ["flake.lock"],
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "tsconfig.json"]
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json"]
},
// The extensions that Zed should automatically install on startup.
//
@@ -737,9 +742,6 @@
"Elixir": {
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
},
"Erlang": {
"language_servers": ["erlang-ls", "!elp", "..."]
},
"Go": {
"code_actions_on_format": {
"source.organizeImports": true
@@ -796,7 +798,7 @@
}
},
"Ruby": {
"language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "..."]
"language_servers": ["solargraph", "!ruby-lsp", "..."]
},
"SCSS": {
"prettier": {
@@ -809,9 +811,6 @@
"plugins": ["prettier-plugin-sql"]
}
},
"Starlark": {
"language_servers": ["starpls", "!buck2-lsp", "..."]
},
"Svelte": {
"prettier": {
"allowed": true,
@@ -852,18 +851,6 @@
}
}
},
// Different settings for specific language models.
"language_models": {
"anthropic": {
"api_url": "https://api.anthropic.com"
},
"openai": {
"api_url": "https://api.openai.com/v1"
},
"ollama": {
"api_url": "http://localhost:11434"
}
},
// Zed's Prettier integration settings.
// Allows to enable/disable formatting with Prettier
// and configure default Prettier, used when no project-level Prettier installation is found.
@@ -896,15 +883,6 @@
// }
// }
},
// Jupyter settings
"jupyter": {
"enabled": true
// Specify the language name as the key and the kernel name as the value.
// "kernel_selections": {
// "python": "conda-base"
// "typescript": "deno"
// }
},
// Vim settings
"vim": {
"use_system_clipboard": "always",
@@ -957,13 +935,5 @@
// Examples:
// - "proxy": "socks5://localhost:10808"
// - "proxy": "http://127.0.0.1:10809"
"proxy": null,
// Set to configure aliases for the command palette.
// When typing a query which is a key of this object, the value will be used instead.
//
// Examples:
// {
// "W": "workspace::Save"
// }
"command_aliases": {}
"proxy": null
}

View File

@@ -17,27 +17,6 @@
// What to do with the terminal pane and tab, after the command was started:
// * `always` — always show the terminal pane, add and focus the corresponding task's tab in it (default)
// * `never` — avoid changing current terminal pane focus, but still add/reuse the task's tab there
"reveal": "always",
// What to do with the terminal pane and tab, after the command had finished:
// * `never` — Do nothing when the command finishes (default)
// * `always` — always hide the terminal tab, hide the pane also if it was the last tab in it
// * `on_success` — hide the terminal tab on task success only, otherwise behaves similar to `always`
"hide": "never",
// Which shell to use when running a task inside the terminal.
// May take 3 values:
// 1. (default) Use the system's default terminal configuration in /etc/passwd
// "shell": "system"
// 2. A program:
// "shell": {
// "program": "sh"
// }
// 3. A program with arguments:
// "shell": {
// "with_arguments": {
// "program": "/bin/bash",
// "arguments": ["--login"]
// }
// }
"shell": "system"
"reveal": "always"
}
]

View File

@@ -107,7 +107,6 @@ impl ActivityIndicator {
Editor::for_buffer(buffer, Some(project.clone()), cx)
})),
None,
true,
cx,
);
})?;

View File

@@ -18,7 +18,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
http.workspace = true
isahc.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true

View File

@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};
@@ -20,8 +20,6 @@ pub enum Model {
Claude3Sonnet,
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
#[serde(rename = "custom")]
Custom { name: String, max_tokens: usize },
}
impl Model {
@@ -35,38 +33,30 @@ impl Model {
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("invalid model id"))
Err(anyhow!("Invalid model id: {}", id))
}
}
pub fn id(&self) -> &str {
pub fn id(&self) -> &'static str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-opus-20240307",
Self::Custom { name, .. } => name,
}
}
pub fn display_name(&self) -> &str {
pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom { name, .. } => name,
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200_000,
Self::Custom { max_tokens, .. } => *max_tokens,
}
200_000
}
}
@@ -100,7 +90,6 @@ impl From<Role> for String {
#[derive(Debug, Serialize)]
pub struct Request {
#[serde(serialize_with = "serialize_request_model")]
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
@@ -108,13 +97,6 @@ pub struct Request {
pub max_tokens: u32,
}
fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&model.id())
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,

View File

@@ -23,16 +23,15 @@ test-support = [
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
assets.workspace = true
assistant_slash_command.workspace = true
async-watch.workspace = true
breadcrumbs.workspace = true
cargo_toml.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
completion.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -41,11 +40,10 @@ fuzzy.workspace = true
gpui.workspace = true
heed.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
http.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
@@ -65,20 +63,21 @@ serde_json.workspace = true
settings.workspace = true
similar.workspace = true
smol.workspace = true
strsim = "0.11"
strum.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
terminal_view.workspace = true
theme.workspace = true
tiktoken-rs.workspace = true
toml.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
picker.workspace = true
roxmltree = "0.20.0"
[dev-dependencies]
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View File

@@ -1,40 +1,42 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod completion_provider;
mod context;
pub mod context_store;
mod inline_assistant;
mod model_selector;
mod prompt_library;
mod prompts;
mod search;
mod slash_command;
mod streaming_diff;
mod terminal_inline_assistant;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
use completion::LanguageModelCompletionProvider;
pub use completion_provider::*;
pub use context::*;
pub use context_store::*;
use fs::Fs;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::{
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use settings::{Settings, SettingsStore};
use slash_command::{
active_command, default_command, diagnostics_command, docs_command, fetch_command,
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tabs_command, term_command,
};
use std::sync::Arc;
use std::{
fmt::{self, Display},
sync::Arc,
};
pub(crate) use streaming_diff::*;
actions!(
@@ -47,22 +49,16 @@ actions!(
InsertIntoEditor,
ToggleFocus,
ResetKey,
InlineAssist,
InsertActivePrompt,
DeployHistory,
DeployPromptLibrary,
ApplyEdit,
ConfirmCommand,
ToggleModelSelector,
DebugEditSteps
ToggleModelSelector
]
);
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
prompt: Option<String>,
}
impl_actions!(assistant, [InlineAssist]);
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MessageId(clock::Lamport);
@@ -72,6 +68,166 @@ impl MessageId {
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn from_proto(role: i32) -> Role {
match proto::LanguageModelRole::from_i32(role) {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
pub fn to_proto(&self) -> proto::LanguageModelRole {
match self {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
}
}
pub fn cycle(self) -> Role {
match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
Cloud(CloudModel),
OpenAi(OpenAiModel),
Anthropic(AnthropicModel),
Ollama(OllamaModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::Cloud(CloudModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => model.display_name().into(),
LanguageModel::Anthropic(model) => model.display_name().into(),
LanguageModel::Cloud(model) => model.display_name().into(),
LanguageModel::Ollama(model) => model.display_name().into(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::Anthropic(model) => model.max_token_count(),
LanguageModel::Cloud(model) => model.max_token_count(),
LanguageModel::Ollama(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::Anthropic(model) => model.id(),
LanguageModel::Cloud(model) => model.id(),
LanguageModel::Ollama(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: self.role.to_proto() as i32,
content: self.content.clone(),
tool_calls: Vec::new(),
tool_call_id: None,
}
}
}
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
}
}
/// Before we send the request to the server, we can perform fixups on it appropriate to the model.
pub fn preprocess(&mut self) {
match &self.model {
LanguageModel::OpenAi(_) => {}
LanguageModel::Anthropic(_) => {}
LanguageModel::Ollama(_) => {}
LanguageModel::Cloud(model) => match model {
CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku
| CloudModel::Claude3_5Sonnet => {
preprocess_anthropic_request(self);
}
_ => {}
},
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
@@ -165,16 +321,6 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
cx.set_global(Assistant::default());
AssistantSettings::register(cx);
// TODO: remove this when 0.148.0 is released.
if AssistantSettings::get_global(cx).using_outdated_settings_version {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let fs = fs.clone();
|content, cx| {
content.update_file(fs, cx);
}
});
}
cx.spawn(|mut cx| {
let client = client.clone();
async move {
@@ -192,7 +338,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
init_completion_provider(cx);
completion_provider::init(client.clone(), cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@@ -217,38 +363,6 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
fn init_completion_provider(cx: &mut AppContext) {
completion::init(cx);
update_active_language_model_from_settings(cx);
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
.detach();
cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
update_active_language_model_from_settings(cx)
})
.detach();
}
fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)
.read(cx)
.provider(&provider_name)
else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
completion_provider.set_active_model(model, cx);
});
}
}
fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,162 @@
use std::sync::Arc;
use std::fmt;
use anthropic::Model as AnthropicModel;
use fs::Fs;
use gpui::{AppContext, Pixels};
use language_model::{settings::AllLanguageModelSettings, CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsSources};
use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
Gpt4,
Gpt4Turbo,
#[default]
Gpt4Omni,
Claude3_5Sonnet,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
impl Serialize for CloudModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for CloudModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = CloudModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let model = CloudModel::iter()
.find(|model| model.id() == value)
.unwrap_or_else(|| CloudModel::Custom(value.to_string()));
Ok(model)
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for CloudModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = CloudModel::iter()
.filter_map(|model| {
let id = model.id();
if id.is_empty() {
None
} else {
Some(id.to_string())
}
})
.collect::<Vec<_>>();
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(CloudModel::default().id().into()),
examples: variants.into_iter().map(Into::into).collect(),
..Default::default()
})),
..Default::default()
})
}
}
impl CloudModel {
pub fn id(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Gpt4Omni => "gpt-4o",
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Gpt4Omni => "GPT 4 Omni",
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
Self::Claude3_5Sonnet
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
preprocess_anthropic_request(request)
}
_ => {}
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@@ -19,9 +167,43 @@ pub enum AssistantDockPosition {
Bottom,
}
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
ZedDotDev {
model: CloudModel,
},
OpenAi {
model: OpenAiModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
available_models: Vec<OpenAiModel>,
},
Anthropic {
model: AnthropicModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
Ollama {
model: OllamaModel,
api_url: String,
low_speed_timeout_in_seconds: Option<u64>,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::OpenAi {
model: OpenAiModel::default(),
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContentV1 {
pub enum AssistantProviderContent {
#[serde(rename = "zed.dev")]
ZedDotDev { default_model: Option<CloudModel> },
#[serde(rename = "openai")]
@@ -52,8 +234,7 @@ pub struct AssistantSettings {
pub dock: AssistantDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: AssistantDefaultModel,
pub using_outdated_settings_version: bool,
pub provider: AssistantProvider,
}
/// Assistant panel settings
@@ -85,142 +266,34 @@ impl Default for AssistantSettingsContent {
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
fn upgrade(&self) -> AssistantSettingsContentV1 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(_) => true,
}
}
pub fn update_file(&mut self, fs: Arc<dyn Fs>, cx: &AppContext) {
if let AssistantSettingsContent::Versioned(settings) = self {
if let VersionedAssistantSettingsContent::V1(settings) = settings {
if let Some(provider) = settings.provider.clone() {
match provider {
AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.anthropic.is_none() {
content.anthropic =
Some(language_model::settings::AnthropicSettingsContent {
api_url,
low_speed_timeout_in_seconds,
..Default::default()
});
}
},
),
AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.ollama.is_none() {
content.ollama =
Some(language_model::settings::OllamaSettingsContent {
api_url,
low_speed_timeout_in_seconds,
});
}
},
),
AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
} => update_settings_file::<AllLanguageModelSettings>(
fs,
cx,
move |content, _| {
if content.openai.is_none() {
content.openai =
Some(language_model::settings::OpenAiSettingsContent {
api_url,
low_speed_timeout_in_seconds,
available_models,
});
}
},
),
_ => {}
}
}
}
}
*self = AssistantSettingsContent::Versioned(VersionedAssistantSettingsContent::V2(
self.upgrade(),
));
}
fn upgrade(&self) -> AssistantSettingsContentV2 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_width,
default_model: settings
.provider
.clone()
.and_then(|provider| match provider {
AssistantProviderContentV1::ZedDotDev { default_model } => {
default_model.map(|model| AssistantDefaultModel {
provider: "zed.dev".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::OpenAi { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "openai".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Anthropic { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "anthropic".to_string(),
model: model.id().to_string(),
})
}
AssistantProviderContentV1::Ollama { default_model, .. } => {
default_model.map(|model| AssistantDefaultModel {
provider: "ollama".to_string(),
model: model.id().to_string(),
})
}
}),
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
enabled: None,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
default_model: Some(AssistantDefaultModel {
provider: "openai".to_string(),
model: settings
.default_open_ai_model
.clone()
.unwrap_or_default()
.id()
.to_string(),
}),
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProviderContent::OpenAi {
default_model: settings.default_open_ai_model.clone(),
api_url: Some(open_ai_api_url.clone()),
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProviderContent::OpenAi {
default_model: Some(open_ai_model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
}
})
},
},
}
}
@@ -231,9 +304,6 @@ impl AssistantSettingsContent {
VersionedAssistantSettingsContent::V1(settings) => {
settings.dock = Some(dock);
}
VersionedAssistantSettingsContent::V2(settings) => {
settings.dock = Some(dock);
}
},
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
@@ -241,78 +311,74 @@ impl AssistantSettingsContent {
}
}
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
let provider = language_model.provider_id().0.to_string();
pub fn set_model(&mut self, new_model: LanguageModel) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
default_model: CloudModel::from_id(&model).ok(),
});
VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
Some(AssistantProviderContent::ZedDotDev {
default_model: model,
}) => {
if let LanguageModel::Cloud(new_model) = new_model {
*model = Some(new_model);
}
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Anthropic {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Anthropic {
default_model: AnthropicModel::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
});
Some(AssistantProviderContent::OpenAi {
default_model: model,
..
}) => {
if let LanguageModel::OpenAi(new_model) = new_model {
*model = Some(new_model);
}
}
"ollama" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model)),
api_url,
low_speed_timeout_in_seconds,
});
Some(AssistantProviderContent::Anthropic {
default_model: model,
..
}) => {
if let LanguageModel::Anthropic(new_model) = new_model {
*model = Some(new_model);
}
}
"openai" => {
let (api_url, low_speed_timeout_in_seconds, available_models) =
match &settings.provider {
Some(AssistantProviderContentV1::OpenAi {
api_url,
low_speed_timeout_in_seconds,
available_models,
..
}) => (
api_url.clone(),
*low_speed_timeout_in_seconds,
available_models.clone(),
),
_ => (None, None, None),
};
settings.provider = Some(AssistantProviderContentV1::OpenAi {
default_model: open_ai::Model::from_id(&model).ok(),
api_url,
low_speed_timeout_in_seconds,
available_models,
});
Some(AssistantProviderContent::Ollama {
default_model: model,
..
}) => {
if let LanguageModel::Ollama(new_model) = new_model {
*model = Some(new_model);
}
}
_ => {}
provider => match new_model {
LanguageModel::Cloud(model) => {
*provider = Some(AssistantProviderContent::ZedDotDev {
default_model: Some(model),
})
}
LanguageModel::OpenAi(model) => {
*provider = Some(AssistantProviderContent::OpenAi {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
available_models: Some(Default::default()),
})
}
LanguageModel::Anthropic(model) => {
*provider = Some(AssistantProviderContent::Anthropic {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
LanguageModel::Ollama(model) => {
*provider = Some(AssistantProviderContent::Ollama {
default_model: Some(model),
api_url: None,
low_speed_timeout_in_seconds: None,
})
}
},
},
VersionedAssistantSettingsContent::V2(settings) => {
settings.default_model = Some(AssistantDefaultModel { provider, model });
}
},
AssistantSettingsContent::Legacy(settings) => {
if let Ok(model) = open_ai::Model::from_id(&language_model.id().0) {
if let LanguageModel::OpenAi(model) = new_model {
settings.default_open_ai_model = Some(model);
}
}
@@ -325,78 +391,21 @@ impl AssistantSettingsContent {
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
#[serde(rename = "2")]
V2(AssistantSettingsContentV2),
}
impl Default for VersionedAssistantSettingsContent {
fn default() -> Self {
Self::V2(AssistantSettingsContentV2 {
Self::V1(AssistantSettingsContentV1 {
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
default_model: None,
provider: None,
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV2 {
/// Whether the Assistant is enabled.
///
/// Default: true
enabled: Option<bool>,
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the assistant.
///
/// Default: right
dock: Option<AssistantDockPosition>,
/// Default width in pixels when the assistant is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the assistant is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The default model to use when creating new contexts.
default_model: Option<AssistantDefaultModel>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AssistantDefaultModel {
#[schemars(schema_with = "providers_schema")]
pub provider: String,
pub model: String,
}
fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
]),
..Default::default()
}
.into()
}
impl Default for AssistantDefaultModel {
fn default() -> Self {
Self {
provider: "openai".to_string(),
model: "gpt-4".to_string(),
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
@@ -423,7 +432,7 @@ pub struct AssistantSettingsContentV1 {
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProviderContentV1>,
provider: Option<AssistantProviderContent>,
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
@@ -466,10 +475,6 @@ impl Settings for AssistantSettings {
let mut settings = AssistantSettings::default();
for value in sources.defaults_and_customizations() {
if value.is_version_outdated() {
settings.using_outdated_settings_version = true;
}
let value = value.upgrade();
merge(&mut settings.enabled, value.enabled);
merge(&mut settings.button, value.button);
@@ -482,10 +487,123 @@ impl Settings for AssistantSettings {
&mut settings.default_height,
value.default_height.map(Into::into),
);
merge(
&mut settings.default_model,
value.default_model.map(Into::into),
);
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { model },
AssistantProviderContent::ZedDotDev {
default_model: model_override,
},
) => {
merge(model, model_override);
}
(
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
},
AssistantProviderContent::OpenAi {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
available_models: available_models_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
merge(available_models, available_models_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Anthropic {
default_model: model_override,
api_url: api_url_override,
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
},
) => {
merge(model, model_override);
merge(api_url, api_url_override);
if let Some(low_speed_timeout_in_seconds_override) =
low_speed_timeout_in_seconds_override
{
*low_speed_timeout_in_seconds =
Some(low_speed_timeout_in_seconds_override);
}
}
(provider, provider_override) => {
*provider = match provider_override {
AssistantProviderContent::ZedDotDev {
default_model: model,
} => AssistantProvider::ZedDotDev {
model: model.unwrap_or_default(),
},
AssistantProviderContent::OpenAi {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => AssistantProvider::OpenAi {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
low_speed_timeout_in_seconds,
available_models: available_models.unwrap_or_default(),
},
AssistantProviderContent::Anthropic {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Anthropic {
model: model.unwrap_or_default(),
api_url: api_url
.unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
low_speed_timeout_in_seconds,
},
AssistantProviderContent::Ollama {
default_model: model,
api_url,
low_speed_timeout_in_seconds,
} => AssistantProvider::Ollama {
model: model.unwrap_or_default(),
api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
low_speed_timeout_in_seconds,
},
};
}
}
}
}
Ok(settings)
@@ -498,103 +616,96 @@ fn merge<T>(target: &mut T, value: Option<T>) {
}
}
// #[cfg(test)]
// mod tests {
// use gpui::{AppContext, UpdateGlobal};
// use settings::SettingsStore;
#[cfg(test)]
mod tests {
use gpui::{AppContext, UpdateGlobal};
use settings::SettingsStore;
// use super::*;
use super::*;
// #[gpui::test]
// fn test_deserialize_assistant_settings(cx: &mut AppContext) {
// let store = settings::SettingsStore::test(cx);
// cx.set_global(store);
#[gpui::test]
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
// // Settings default to gpt-4-turbo.
// AssistantSettings::register(cx);
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// Settings default to gpt-4-turbo.
AssistantSettings::register(cx);
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// // Ensure backward-compatibility.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "openai_api_url": "test-url",
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::FourOmni,
// api_url: "test-url".into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "default_open_ai_model": "gpt-4-0613"
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::OpenAi {
// model: OpenAiModel::Four,
// api_url: open_ai::OPEN_AI_API_URL.into(),
// low_speed_timeout_in_seconds: None,
// available_models: Default::default(),
// }
// );
// Ensure backward-compatibility.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"openai_api_url": "test-url",
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::FourOmni,
api_url: "test-url".into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"default_open_ai_model": "gpt-4-0613"
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
model: OpenAiModel::Four,
api_url: open_ai::OPEN_AI_API_URL.into(),
low_speed_timeout_in_seconds: None,
available_models: Default::default(),
}
);
// // The new version supports setting a custom model when using zed.dev.
// SettingsStore::update_global(cx, |store, cx| {
// store
// .set_user_settings(
// r#"{
// "assistant": {
// "version": "1",
// "provider": {
// "name": "zed.dev",
// "default_model": {
// "custom": {
// "name": "custom-provider"
// }
// }
// }
// }
// }"#,
// cx,
// )
// .unwrap();
// });
// assert_eq!(
// AssistantSettings::get_global(cx).provider,
// AssistantProvider::ZedDotDev {
// model: CloudModel::Custom {
// name: "custom-provider".into(),
// max_tokens: None
// }
// }
// );
// }
// }
// The new version supports setting a custom model when using zed.dev.
SettingsStore::update_global(cx, |store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
}
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
model: CloudModel::Custom("custom".into())
}
);
}
}

View File

@@ -0,0 +1,373 @@
mod anthropic;
mod cloud;
#[cfg(any(test, feature = "test-support"))]
mod fake;
mod ollama;
mod open_ai;
pub use anthropic::*;
pub use cloud::*;
#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::time::Duration;
use std::{any::Any, sync::Arc};
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
fn choose_openai_model(
model: &::open_ai::Model,
available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
available_models
.iter()
.find(|&m| m == model)
.or_else(|| available_models.first())
.unwrap_or_else(|| model)
.clone()
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let provider = create_provider_from_settings(client.clone(), 0, cx);
cx.set_global(CompletionProvider::new(provider, Some(client)));
let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider.update_settings(settings_version, cx);
})
})
.detach();
}
pub struct CompletionResponse {
pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
_lock: SemaphoreGuardArc,
}
pub trait LanguageModelCompletionProvider: Send + Sync {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize;
fn is_authenticated(&self) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
fn model(&self) -> LanguageModel;
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>>;
fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn as_any_mut(&mut self) -> &mut dyn Any;
}
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
pub struct CompletionProvider {
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
client: Option<Arc<Client>>,
request_limiter: Arc<Semaphore>,
}
impl CompletionProvider {
pub fn new(
provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
client: Option<Arc<Client>>,
) -> Self {
Self {
provider,
client,
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
}
}
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
self.provider.read().available_models(cx)
}
pub fn settings_version(&self) -> usize {
self.provider.read().settings_version()
}
pub fn is_authenticated(&self) -> bool {
self.provider.read().is_authenticated()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.provider.read().authenticate(cx)
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
self.provider.read().authentication_prompt(cx)
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.provider.read().reset_credentials(cx)
}
pub fn model(&self) -> LanguageModel {
self.provider.read().model()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
self.provider.read().count_tokens(request, cx)
}
pub fn complete(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<CompletionResponse> {
let rate_limiter = self.request_limiter.clone();
let provider = self.provider.clone();
cx.background_executor().spawn(async move {
let lock = rate_limiter.acquire_arc().await;
let response = provider.read().complete(request);
CompletionResponse {
inner: response,
_lock: lock,
}
})
}
}
impl gpui::Global for CompletionProvider {}
impl CompletionProvider {
pub fn global(cx: &AppContext) -> &Self {
cx.global::<Self>()
}
pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
&mut self,
update: impl FnOnce(&mut T) -> R,
) -> Option<R> {
let mut provider = self.provider.write();
if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
Some(update(provider))
} else {
None
}
}
pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
let updated = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => self
.update_current_as::<_, CloudCompletionProvider>(|provider| {
provider.update(model.clone(), version);
}),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.update(
choose_openai_model(&model, &available_models),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
);
}),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.update(
model.clone(),
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
version,
cx,
);
}),
};
// Previously configured provider was changed to another one
if updated.is_none() {
if let Some(client) = self.client.clone() {
self.provider = create_provider_from_settings(client, version, cx);
} else {
log::warn!("completion provider cannot be created because client is not set");
}
}
}
}
fn create_provider_from_settings(
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
)),
AssistantProvider::OpenAi {
model,
api_url,
low_speed_timeout_in_seconds,
available_models,
} => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
choose_openai_model(&model, &available_models),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Anthropic {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
))),
AssistantProvider::Ollama {
model,
api_url,
low_speed_timeout_in_seconds,
} => Arc::new(RwLock::new(OllamaCompletionProvider::new(
model.clone(),
api_url.clone(),
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
))),
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use gpui::AppContext;
use parking_lot::RwLock;
use settings::SettingsStore;
use smol::stream::StreamExt;
use crate::{
completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
FakeCompletionProvider, LanguageModelRequest,
};
#[gpui::test]
fn test_rate_limiting(cx: &mut AppContext) {
SettingsStore::test(cx);
let fake_provider = FakeCompletionProvider::setup_test(cx);
let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.complete(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
},
cx,
);
cx.background_executor()
.spawn(async move {
let response = response.await;
let mut stream = response.inner.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
})
.detach();
}
cx.background_executor().run_until_parked();
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_provider
.running_completions()
.into_iter()
.next()
.unwrap();
fake_provider.finish_completion(&completion);
// Ensure that the number of in-flight completion requests is reduced.
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
cx.background_executor().run_until_parked();
// Ensure that another completion request was allowed to acquire the lock.
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Mark all completion requests as finished that are in flight.
for request in fake_provider.running_completions() {
fake_provider.finish_completion(&request);
}
assert_eq!(fake_provider.completion_count(), 0);
// Wait until the background tasks acquire the lock again.
cx.background_executor().run_until_parked();
assert_eq!(
fake_provider.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
// Finish all remaining completion requests.
for request in fake_provider.running_completions() {
fake_provider.finish_completion(&request);
}
cx.background_executor().run_until_parked();
assert_eq!(fake_provider.completion_count(), 0);
}
}

View File

@@ -1,122 +1,53 @@
use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role,
};
use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
WhiteSpace,
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
const PROVIDER_ID: &str = "anthropic";
const PROVIDER_NAME: &str = "Anthropic";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<anthropic::Model>,
}
pub struct AnthropicLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
}
struct State {
pub struct AnthropicCompletionProvider {
api_key: Option<String>,
_subscription: Subscription,
api_url: String,
model: AnthropicModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
}
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
}
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
}
}
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from anthropic::Model::iter()
for model in anthropic::Model::iter() {
if !matches!(model, anthropic::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.anthropic
.available_models
.iter()
{
models.insert(model.id().to_string(), model.clone());
}
models
.into_values()
.map(|model| {
Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
}) as Arc<dyn LanguageModel>
})
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
AnthropicModel::iter()
.map(LanguageModel::Anthropic)
.collect()
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
self.state.read(cx).api_key.is_some()
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.clone();
let state = self.state.clone();
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
api_key
} else {
let (_, api_key) = cx
@@ -125,126 +56,34 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
}
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
this.api_key = None;
cx.notify();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = None;
});
})
})
}
}
pub struct AnthropicModel {
id: LanguageModelId,
model: anthropic::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
}
impl AnthropicModel {
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let mut system_message = String::new();
if request
.messages
.first()
.map_or(false, |message| message.role == Role::System)
{
system_message = request.messages.remove(0).content;
}
Request {
model: self.model.clone(),
messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true,
system: system_message,
max_tokens: 4092,
}
}
}
pub fn count_anthropic_tokens(
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
cx.background_executor()
.spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.content),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
})
.boxed()
}
impl LanguageModel for AnthropicModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn telemetry_id(&self) -> String {
format!("anthropic/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
fn model(&self) -> LanguageModel {
LanguageModel::Anthropic(self.model.clone())
}
fn count_tokens(
@@ -252,29 +91,19 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_anthropic_tokens(request, cx)
count_open_ai_tokens(request, cx.background_executor())
}
fn stream_completion(
fn complete(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_anthropic_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
@@ -309,6 +138,79 @@ impl LanguageModel for AnthropicModel {
}
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl AnthropicCompletionProvider {
pub fn new(
model: AnthropicModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
}
}
pub fn update(
&mut self,
model: AnthropicModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let model = match request.model {
LanguageModel::Anthropic(model) => model,
_ => self.model.clone(),
};
let mut system_message = String::new();
if request
.messages
.first()
.map_or(false, |message| message.role == Role::System)
{
system_message = request.messages.remove(0).content;
}
Request {
model,
messages: request
.messages
.iter()
.map(|msg| RequestMessage {
role: match msg.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("filtered out by preprocess_request"),
},
content: msg.content.clone(),
})
.collect(),
stream: true,
system: system_message,
max_tokens: 4092,
}
}
}
pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
@@ -356,11 +258,11 @@ pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
struct AuthenticationPrompt {
api_key: View<Editor>,
state: gpui::Model<State>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
@@ -370,7 +272,7 @@ impl AuthenticationPrompt {
);
editor
}),
state,
api_url,
}
}
@@ -380,21 +282,13 @@ impl AuthenticationPrompt {
return;
}
let write_credentials = cx.write_credentials(
AllLanguageModelSettings::get_global(cx)
.anthropic
.api_url
.as_str(),
"Bearer",
api_key.as_bytes(),
);
let state = self.state.clone();
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
.detach_and_log_err(cx);

View File

@@ -0,0 +1,208 @@
use crate::{
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelCompletionProvider, LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
pub struct CloudCompletionProvider {
client: Arc<Client>,
model: CloudModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
}
impl CloudCompletionProvider {
pub fn new(
model: CloudModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.status = status;
});
});
}
});
Self {
client,
model,
settings_version,
status,
_maintain_client_status: maintain_client_status,
}
}
pub fn update(&mut self, model: CloudModel, settings_version: usize) {
self.model = model;
self.settings_version = settings_version;
}
}
impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {
None
};
CloudModel::iter()
.filter_map(move |model| {
if let CloudModel::Custom(_) = model {
Some(CloudModel::Custom(custom_model.take()?))
} else {
Some(model)
}
})
.map(LanguageModel::Cloud)
.collect()
}
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self) -> bool {
self.status.is_connected()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|_cx| AuthenticationPrompt).into()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model(&self) -> LanguageModel {
LanguageModel::Cloud(self.model.clone())
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
LanguageModel::Cloud(CloudModel::Gpt4)
| LanguageModel::Cloud(CloudModel::Gpt4Turbo)
| LanguageModel::Cloud(CloudModel::Gpt4Omni)
| LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(
CloudModel::Claude3_5Sonnet
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku,
) => {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::Cloud(CloudModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
});
async move {
let response = request.await?;
Ok(response.token_count as usize)
}
.boxed()
}
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
}
}
fn complete(
&self,
mut request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
request.preprocess();
let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
stop: request.stop,
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
};
self.client
.request_stream(request)
.map_ok(|stream| {
stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed()
})
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
struct AuthenticationPrompt;
impl Render for AuthenticationPrompt {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
v_flex()
.gap_2()
.child(
Button::new("sign_in", "Sign in")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.style(ButtonStyle::Filled)
.full_width()
.on_click(|_, cx| {
CompletionProvider::global(cx)
.authenticate(cx)
.detach_and_log_err(cx);
}),
)
.child(
div().flex().w_full().items_center().child(
Label::new("Sign in to enable collaboration.")
.color(Color::Muted)
.size(LabelSize::Small),
),
),
)
}
}

View File

@@ -0,0 +1,106 @@
use anyhow::Result;
use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, Task};
use std::sync::Arc;
use ui::WindowContext;
use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
impl FakeCompletionProvider {
pub fn setup_test(cx: &mut AppContext) -> Self {
use crate::CompletionProvider;
use parking_lot::RwLock;
let this = Self::default();
let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
cx.set_global(provider);
this
}
pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
.keys()
.map(|k| serde_json::from_str(k).unwrap())
.collect()
}
pub fn completion_count(&self) -> usize {
self.current_completion_txs.lock().len()
}
pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
let json = serde_json::to_string(request).unwrap();
self.current_completion_txs
.lock()
.get(&json)
.unwrap()
.unbounded_send(chunk)
.unwrap();
}
pub fn finish_completion(&self, request: &LanguageModelRequest) {
self.current_completion_txs
.lock()
.remove(&serde_json::to_string(request).unwrap());
}
}
impl LanguageModelCompletionProvider for FakeCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
vec![LanguageModel::default()]
}
fn settings_version(&self) -> usize {
0
}
fn is_authenticated(&self) -> bool {
true
}
fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
unimplemented!()
}
fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
fn model(&self) -> LanguageModel {
LanguageModel::default()
}
fn count_tokens(
&self,
_request: LanguageModelRequest,
_cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
futures::future::ready(Ok(0)).boxed()
}
fn complete(
&self,
_request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::unbounded();
self.current_completion_txs
.lock()
.insert(serde_json::to_string(&_request).unwrap(), tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}

View File

@@ -1,165 +1,50 @@
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Role as OllamaRole,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::sync::Arc;
use std::time::Duration;
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
}
pub struct OllamaLanguageModelProvider {
pub struct OllamaCompletionProvider {
api_url: String,
model: OllamaModel,
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
available_models: Vec<OllamaModel>,
}
struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
_subscription: Subscription,
}
impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.fetch_models(cx).detach();
cx.notify();
}),
}),
};
this.fetch_models(cx).detach();
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
}
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
self.state
.read(cx)
.available_models
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
self.available_models
.iter()
.map(|model| {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
}) as Arc<dyn LanguageModel>
})
.map(|m| LanguageModel::Ollama(m.clone()))
.collect()
}
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
fn settings_version(&self) -> usize {
self.settings_version
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
!self.state.read(cx).available_models.is_empty()
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
@@ -167,9 +52,14 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
let state = self.state.clone();
let fetch_models = Box::new(move |cx: &mut WindowContext| {
state.update(cx, |this, cx| this.fetch_models(cx))
cx.update_global::<CompletionProvider, _>(|provider, cx| {
provider
.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.fetch_models(cx)
})
.unwrap_or_else(|| Task::ready(Ok(())))
})
});
cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
@@ -179,68 +69,9 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.fetch_models(cx)
}
}
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
}
impl OllamaLanguageModel {
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
ChatRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.content,
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
},
Role::System => ChatMessage::System {
content: msg.content,
},
})
.collect(),
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(self.model.max_tokens),
stop: Some(request.stop),
temperature: Some(request.temperature),
..Default::default()
}),
}
}
}
impl LanguageModel for OllamaLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
fn model(&self) -> LanguageModel {
LanguageModel::Ollama(self.model.clone())
}
fn count_tokens(
@@ -260,21 +91,15 @@ impl LanguageModel for OllamaLanguageModel {
async move { Ok(token_count) }.boxed()
}
fn stream_completion(
fn complete(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let request =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
@@ -298,6 +123,153 @@ impl LanguageModel for OllamaLanguageModel {
}
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
impl OllamaCompletionProvider {
pub fn new(
model: OllamaModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) -> Self {
cx.spawn({
let api_url = api_url.clone();
let client = http_client.clone();
let model = model.name.clone();
|_| async move {
if model.is_empty() {
return Ok(());
}
preload_model(client.as_ref(), &api_url, &model).await
}
})
.detach_and_log_err(cx);
Self {
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
available_models: Default::default(),
}
}
pub fn update(
&mut self,
model: OllamaModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) {
cx.spawn({
let api_url = api_url.clone();
let client = self.http_client.clone();
let model = model.name.clone();
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
})
.detach_and_log_err(cx);
if model.name.is_empty() {
self.select_first_available_model()
} else {
self.model = model;
}
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
pub fn select_first_available_model(&mut self) {
if let Some(model) = self.available_models.first() {
self.model = model.clone();
}
}
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<OllamaModel> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| OllamaModel::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.available_models = models;
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
provider.select_first_available_model()
}
});
})
})
}
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
let model = match request.model {
LanguageModel::Ollama(model) => model,
_ => self.model.clone(),
};
ChatRequest {
model: model.name,
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.content,
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
},
Role::System => ChatMessage::System {
content: msg.content,
},
})
.collect(),
keep_alive: model.keep_alive.unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(model.max_tokens),
stop: Some(request.stop),
temperature: Some(request.temperature),
..Default::default()
}),
}
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
}
}
}
struct DownloadOllamaMessage {

View File

@@ -1,168 +1,71 @@
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
WhiteSpace,
use crate::assistant_settings::CloudModel;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use http_client::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<open_ai::Model>,
}
pub struct OpenAiLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
}
struct State {
pub struct OpenAiCompletionProvider {
api_key: Option<String>,
_subscription: Subscription,
}
impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
Some(cx.observe(&self.state, |_, _, cx| {
cx.notify();
}))
}
}
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
// Add base models from open_ai::Model::iter()
for model in open_ai::Model::iter() {
if !matches!(model, open_ai::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
// Override with available models from settings
for model in &AllLanguageModelSettings::get_global(cx)
.openai
.available_models
{
models.insert(model.id().to_string(), model.clone());
}
models
.into_values()
.map(|model| {
Arc::new(OpenAiLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
}
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
this.api_key = None;
cx.notify();
})
})
}
}
pub struct OpenAiLanguageModel {
id: LanguageModelId,
model: open_ai::Model,
state: gpui::Model<State>,
api_url: String,
model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
}
impl OpenAiLanguageModel {
impl OpenAiCompletionProvider {
pub fn new(
model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
model,
http_client,
low_speed_timeout,
settings_version,
}
}
pub fn update(
&mut self,
model: OpenAiModel,
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
) {
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
self.settings_version = settings_version;
}
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::OpenAi(model) => model,
_ => self.model.clone(),
};
Request {
model: self.model.clone(),
model,
messages: request
.messages
.into_iter()
@@ -188,29 +91,84 @@ impl OpenAiLanguageModel {
}
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
if let AssistantProvider::OpenAi {
available_models, ..
} = &AssistantSettings::get_global(cx).provider
{
if !available_models.is_empty() {
return available_models
.iter()
.cloned()
.map(LanguageModel::OpenAi)
.collect();
}
}
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
vec![self.model.clone()]
} else {
OpenAiModel::iter()
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.collect()
};
available_models
.into_iter()
.map(LanguageModel::OpenAi)
.collect()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
fn settings_version(&self) -> usize {
self.settings_version
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.api_key = Some(api_key);
});
})
})
}
}
fn telemetry_id(&self) -> String {
format!("openai/{}", self.model.id())
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, Self>(|provider| {
provider.api_key = None;
});
})
})
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
fn model(&self) -> LanguageModel {
LanguageModel::OpenAi(self.model.clone())
}
fn count_tokens(
@@ -218,28 +176,19 @@ impl LanguageModel for OpenAiLanguageModel {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, self.model.clone(), cx)
count_open_ai_tokens(request, cx.background_executor())
}
fn stream_completion(
fn complete(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
settings.api_url.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
let low_speed_timeout = self.low_speed_timeout;
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(
@@ -262,14 +211,17 @@ impl LanguageModel for OpenAiLanguageModel {
}
.boxed()
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: open_ai::Model,
cx: &AppContext,
background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
cx.background_executor()
background_executor
.spawn(async move {
let messages = request
.messages
@@ -286,22 +238,40 @@ pub fn count_open_ai_tokens(
})
.collect::<Vec<_>>();
if let open_ai::Model::Custom { .. } = model {
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
} else {
tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
match request.model {
LanguageModel::Anthropic(_)
| LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
}
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
}
})
.boxed()
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
state: gpui::Model<State>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
@@ -311,7 +281,7 @@ impl AuthenticationPrompt {
);
editor
}),
state,
api_url,
}
}
@@ -321,15 +291,13 @@ impl AuthenticationPrompt {
return;
}
let settings = &AllLanguageModelSettings::get_global(cx).openai;
let write_credentials =
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
let state = self.state.clone();
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
state.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
});
})
})
.detach_and_log_err(cx);

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
use crate::{
humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
LanguageModelRequestMessage, Role, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@@ -8,36 +9,27 @@ use collections::{hash_map, HashMap, HashSet, VecDeque};
use editor::{
actions::{MoveDown, MoveUp, SelectAll},
display_map::{
BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
ToDisplayPoint,
},
Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use fs::Fs;
use futures::{
channel::mpsc,
future::LocalBoxFuture,
stream::{self, BoxStream},
SinkExt, Stream, StreamExt,
};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{
point, AppContext, EventEmitter, FocusHandle, FocusableView, Global, HighlightStyle, Model,
ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView,
WindowContext,
point, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, Global, HighlightStyle,
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView,
WhiteSpace, WindowContext,
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use language::{Buffer, Point, Selection, TransactionId};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
use settings::Settings;
use settings::{update_settings_file, Settings};
use similar::TextDiff;
use smol::future::FutureExt;
use std::{
cmp,
future::Future,
mem,
cmp, mem,
ops::{Range, RangeInclusive},
pin::Pin,
sync::Arc,
@@ -45,7 +37,7 @@ use std::{
time::{Duration, Instant},
};
use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip};
use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
use util::RangeExt;
use workspace::{notifications::NotificationId, Toast, Workspace};
@@ -87,7 +79,6 @@ impl InlineAssistant {
editor: &View<Editor>,
workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>,
initial_prompt: Option<String>,
cx: &mut WindowContext,
) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
@@ -139,11 +130,11 @@ impl InlineAssistant {
}
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| Buffer::local("", cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
let mut assists = Vec::new();
let mut assist_blocks = Vec::new();
let mut assist_to_focus = None;
for range in codegen_ranges {
let assist_id = self.next_assist_id.post_inc();
@@ -151,7 +142,6 @@ impl InlineAssistant {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
None,
self.telemetry.clone(),
cx,
)
@@ -184,18 +174,42 @@ impl InlineAssistant {
}
}
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
assist_blocks.push(BlockProperties {
style: BlockStyle::Sticky,
position: range.start,
height: prompt_editor.read(cx).height_in_lines,
render: build_assist_editor_renderer(&prompt_editor),
disposition: BlockDisposition::Above,
});
assist_blocks.push(BlockProperties {
style: BlockStyle::Sticky,
position: range.end,
height: 1,
render: Box::new(|cx| {
v_flex()
.h_full()
.w_full()
.border_t_1()
.border_color(cx.theme().status().info_border)
.into_any_element()
}),
disposition: BlockDisposition::Below,
});
assists.push((assist_id, prompt_editor));
}
let assist_block_ids = editor.update(cx, |editor, cx| {
editor.insert_blocks(assist_blocks, None, cx)
});
let editor_assists = self
.assists_by_editor
.entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new();
for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
for ((assist_id, prompt_editor), block_ids) in
assists.into_iter().zip(assist_block_ids.chunks_exact(2))
{
self.assists.insert(
assist_id,
InlineAssist::new(
@@ -204,8 +218,8 @@ impl InlineAssistant {
assistant_panel.is_some(),
editor,
&prompt_editor,
prompt_block_id,
end_block_id,
block_ids[0],
block_ids[1],
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
@@ -221,128 +235,6 @@ impl InlineAssistant {
}
}
#[allow(clippy::too_many_arguments)]
pub fn suggest_assist(
&mut self,
editor: &View<Editor>,
mut range: Range<Anchor>,
initial_prompt: String,
initial_insertion: Option<String>,
workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>,
cx: &mut WindowContext,
) -> InlineAssistId {
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
let assist_id = self.next_assist_id.post_inc();
let buffer = editor.read(cx).buffer().clone();
let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
buffer.update(cx, |buffer, cx| {
buffer.start_transaction(cx);
buffer.edit([(range.start..range.start, initial_insertion)], None, cx);
buffer.end_transaction(cx)
})
});
range.start = range.start.bias_left(&buffer.read(cx).read(cx));
range.end = range.end.bias_right(&buffer.read(cx).read(cx));
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
prepend_transaction_id,
self.telemetry.clone(),
cx,
)
});
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
PromptEditor::new(
assist_id,
gutter_dimensions.clone(),
self.prompt_history.clone(),
prompt_buffer.clone(),
codegen.clone(),
editor,
assistant_panel,
workspace.clone(),
self.fs.clone(),
cx,
)
});
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
let editor_assists = self
.assists_by_editor
.entry(editor.downgrade())
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new();
self.assists.insert(
assist_id,
InlineAssist::new(
assist_id,
assist_group_id,
assistant_panel.is_some(),
editor,
&prompt_editor,
prompt_block_id,
end_block_id,
prompt_editor.read(cx).codegen.clone(),
workspace.clone(),
cx,
),
);
assist_group.assist_ids.push(assist_id);
editor_assists.assist_ids.push(assist_id);
self.assist_groups.insert(assist_group_id, assist_group);
assist_id
}
fn insert_assist_blocks(
&self,
editor: &View<Editor>,
range: &Range<Anchor>,
prompt_editor: &View<PromptEditor>,
cx: &mut WindowContext,
) -> [CustomBlockId; 2] {
let assist_blocks = vec![
BlockProperties {
style: BlockStyle::Sticky,
position: range.start,
height: prompt_editor.read(cx).height_in_lines,
render: build_assist_editor_renderer(prompt_editor),
disposition: BlockDisposition::Above,
},
BlockProperties {
style: BlockStyle::Sticky,
position: range.end,
height: 1,
render: Box::new(|cx| {
v_flex()
.h_full()
.w_full()
.border_t_1()
.border_color(cx.theme().status().info_border)
.into_any_element()
}),
disposition: BlockDisposition::Below,
},
];
editor.update(cx, |editor, cx| {
let block_ids = editor.insert_blocks(assist_blocks, None, cx);
[block_ids[0], block_ids[1]]
})
}
fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
let assist = &self.assists[&assist_id];
let Some(decorations) = assist.decorations.as_ref() else {
@@ -487,14 +379,6 @@ impl InlineAssistant {
cx.propagate();
}
fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
for assist_id in editor_assists.assist_ids.clone() {
self.finish_assist(assist_id, true, cx);
}
}
}
fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
return;
@@ -814,7 +698,7 @@ impl InlineAssistant {
assist_group.assist_ids.clone()
}
pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
assist
} else {
@@ -843,32 +727,16 @@ impl InlineAssistant {
self.prompt_history.pop_front();
}
assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
let codegen = assist.codegen.clone();
let telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request = self.request_for_inline_assist(assist_id, cx);
let mut cx = cx.to_async();
async move {
let request = request.await?;
let chunks = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx)
.stream_completion(request, cx)
})?
.await?;
Ok(chunks.boxed())
}
.boxed_local()
};
codegen.update(cx, |codegen, cx| {
codegen.start(telemetry_id, chunks, cx);
});
let request = self.request_for_inline_assist(assist_id, cx);
cx.spawn(|mut cx| async move {
let request = request.await?;
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn request_for_inline_assist(
@@ -877,8 +745,8 @@ impl InlineAssistant {
cx: &mut WindowContext,
) -> Task<Result<LanguageModelRequest>> {
cx.spawn(|mut cx| async move {
let (user_prompt, context_request, project_name, buffer, range) =
cx.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let (user_prompt, context_request, project_name, buffer, range, model) = cx
.read_global(|this: &InlineAssistant, cx: &WindowContext| {
let assist = this.assists.get(&assist_id).context("invalid assist")?;
let decorations = assist.decorations.as_ref().context("invalid assist")?;
let editor = assist.editor.upgrade().context("invalid assist")?;
@@ -912,7 +780,15 @@ impl InlineAssistant {
});
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let range = assist.codegen.read(cx).range.clone();
anyhow::Ok((user_prompt, context_request, project_name, buffer, range))
let model = CompletionProvider::global(cx).model();
anyhow::Ok((
user_prompt,
context_request,
project_name,
buffer,
range,
model,
))
})??;
let language = buffer.language_at(range.start);
@@ -971,6 +847,7 @@ impl InlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: vec!["|END|>".to_string()],
temperature,
@@ -978,7 +855,7 @@ impl InlineAssistant {
})
}
pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
assist
} else {
@@ -1197,14 +1074,6 @@ impl EditorInlineAssists {
}
}),
_subscriptions: vec![
cx.observe_release(editor, {
let editor = editor.downgrade();
|_, cx| {
InlineAssistant::update_global(cx, |this, cx| {
this.handle_editor_release(editor, cx);
})
}
}),
cx.observe(editor, move |editor, cx| {
InlineAssistant::update_global(cx, |this, cx| {
this.handle_editor_change(editor, cx)
@@ -1269,7 +1138,7 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
}
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
struct InlineAssistId(usize);
impl InlineAssistId {
fn post_inc(&mut self) -> InlineAssistId {
@@ -1323,19 +1192,22 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
impl Render for PromptEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let gutter_dimensions = *self.gutter_dimensions.lock();
let fs = self.fs.clone();
let buttons = match &self.codegen.read(cx).status {
CodegenStatus::Idle => {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("start", IconName::SparkleAlt)
IconButton::new("start", IconName::Sparkle)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
@@ -1346,14 +1218,15 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::text("Cancel Assist", cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("stop", IconName::Stop)
.icon_color(Color::Error)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| {
Tooltip::with_meta(
"Interrupt Transformation",
@@ -1371,7 +1244,7 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
@@ -1379,7 +1252,8 @@ impl Render for PromptEditor {
if self.edited_since_done {
IconButton::new("restart", IconName::RotateCw)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.size(ButtonSize::None)
.tooltip(|cx| {
Tooltip::with_meta(
"Restart Transformation",
@@ -1394,7 +1268,7 @@ impl Render for PromptEditor {
} else {
IconButton::new("confirm", IconName::Check)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
.on_click(cx.listener(|_, _, cx| {
cx.emit(PromptEditorEvent::ConfirmRequested);
@@ -1420,27 +1294,59 @@ impl Render for PromptEditor {
.w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
.justify_center()
.gap_2()
.child(ModelSelector::new(
self.fs.clone(),
IconButton::new("context", IconName::Settings)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
cx,
)
}),
))
.child(
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
}
},
{
let fs = fs.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
);
}
},
);
}
menu
})
.into()
})
.trigger(
IconButton::new("context", IconName::Settings)
.size(ButtonSize::None)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
),
None,
"Click to Change Model",
cx,
)
}),
)
.anchor(gpui::AnchorCorner::BottomRight),
)
.children(
if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
let error_message = SharedString::from(error.to_string());
@@ -1463,7 +1369,7 @@ impl Render for PromptEditor {
.child(
h_flex()
.gap_2()
.pr_6()
.pr_4()
.children(self.render_token_count(cx))
.children(buttons),
)
@@ -1629,9 +1535,7 @@ impl PromptEditor {
.await?;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@@ -1759,7 +1663,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = CompletionProvider::global(cx).model();
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@@ -1827,8 +1731,12 @@ impl PromptEditor {
font_features: settings.ui_font.features.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.editor,
@@ -1860,8 +1768,8 @@ impl InlineAssist {
include_context: bool,
editor: &View<Editor>,
prompt_editor: &View<PromptEditor>,
prompt_block_id: CustomBlockId,
end_block_id: CustomBlockId,
prompt_block_id: BlockId,
end_block_id: BlockId,
codegen: Model<Codegen>,
workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
@@ -1955,10 +1863,10 @@ impl InlineAssist {
}
struct InlineAssistDecorations {
prompt_block_id: CustomBlockId,
prompt_block_id: BlockId,
prompt_editor: View<PromptEditor>,
removed_line_block_ids: HashSet<CustomBlockId>,
end_block_id: CustomBlockId,
removed_line_block_ids: HashSet<BlockId>,
end_block_id: BlockId,
}
#[derive(Debug)]
@@ -1974,8 +1882,7 @@ pub struct Codegen {
range: Range<Anchor>,
edit_position: Anchor,
last_equal_ranges: Vec<Range<Anchor>>,
prepend_transaction_id: Option<TransactionId>,
generation_transaction_id: Option<TransactionId>,
transaction_id: Option<TransactionId>,
status: CodegenStatus,
generation: Task<()>,
diff: Diff,
@@ -2004,7 +1911,6 @@ impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
prepend_transaction_id: Option<TransactionId>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -2037,8 +1943,7 @@ impl Codegen {
range,
snapshot,
last_equal_ranges: Default::default(),
prepend_transaction_id,
generation_transaction_id: None,
transaction_id: Default::default(),
status: CodegenStatus::Idle,
generation: Task::ready(()),
diff: Diff::default(),
@@ -2054,13 +1959,8 @@ impl Codegen {
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.generation_transaction_id == Some(*transaction_id) {
self.generation_transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
} else if self.prepend_transaction_id == Some(*transaction_id) {
self.prepend_transaction_id = None;
self.generation_transaction_id = None;
if self.transaction_id == Some(*transaction_id) {
self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(CodegenEvent::Undone);
}
@@ -2071,12 +1971,7 @@ impl Codegen {
&self.last_equal_ranges
}
pub fn start(
&mut self,
telemetry_id: String,
stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
cx: &mut ModelContext<Self>,
) {
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
let range = self.range.clone();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -2084,37 +1979,21 @@ impl Codegen {
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
// Start with the indentation of the first line in the selection
let mut suggested_line_indent = snapshot
.suggested_indents(selection_start.row..=selection_start.row, cx)
let suggested_line_indent = snapshot
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
.into_values()
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
// If the first line in the selection does not have indentation, check the following lines
if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
for row in selection_start.row..=range.end.to_point(&snapshot).row {
let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
// Prefer tabs if a line in the selection uses tabs as indentation
if line_indent.kind == IndentKind::Tab {
suggested_line_indent.kind = IndentKind::Tab;
break;
}
}
}
let model_telemetry_id = prompt.model.telemetry_id();
let response = CompletionProvider::global(cx).complete(prompt, cx);
let telemetry = self.telemetry.clone();
self.edit_position = range.start;
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
if let Some(transaction_id) = self.generation_transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
self.generation = cx.spawn(|this, mut cx| {
async move {
let chunks = stream.await;
let response = response.await;
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
@@ -2124,7 +2003,7 @@ impl Codegen {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
let chunks = StripInvalidSpans::new(chunks?);
let chunks = StripInvalidSpans::new(response.inner.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
@@ -2207,7 +2086,7 @@ impl Codegen {
telemetry.report_assistant_event(
None,
telemetry_events::AssistantKind::Inline,
telemetry_id,
model_telemetry_id,
response_latency,
error_message,
);
@@ -2257,7 +2136,7 @@ impl Codegen {
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = this.generation_transaction_id {
if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
@@ -2267,7 +2146,7 @@ impl Codegen {
)
});
} else {
this.generation_transaction_id = Some(transaction);
this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
@@ -2310,12 +2189,7 @@ impl Codegen {
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
if let Some(transaction_id) = self.prepend_transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
if let Some(transaction_id) = self.generation_transaction_id.take() {
if let Some(transaction_id) = self.transaction_id.take() {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
@@ -2577,6 +2451,10 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use crate::FakeCompletionProvider;
use super::*;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
@@ -2585,11 +2463,9 @@ mod tests {
language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
Point,
};
use language_model::LanguageModelRegistry;
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
@@ -2599,8 +2475,7 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(language_settings::init);
let text = indoc! {"
@@ -2618,17 +2493,14 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
String::new(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
codegen.start(LanguageModelRequest::default(), cx)
});
cx.background_executor.run_until_parked();
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
@@ -2639,11 +2511,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -2664,6 +2536,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -2679,16 +2552,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
String::new(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
cx.background_executor.run_until_parked();
@@ -2702,11 +2569,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -2727,8 +2594,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -2744,16 +2610,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
String::new(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
cx.background_executor.run_until_parked();
@@ -2767,11 +2627,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -2787,62 +2647,6 @@ mod tests {
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.update(completion::LanguageModelCompletionProvider::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
func main() {
\tx := 0
\tfor i := 0; i < 10; i++ {
\t\tx++
\t}
}
"};
let buffer = cx.new_model(|cx| Buffer::local(text, cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.start(
String::new(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let new_text = concat!(
"func main() {\n",
"\tx := 0\n",
"\tfor x < 10 {\n",
"\t\tx++\n",
"\t}", //
);
chunks_tx.unbounded_send(new_text.to_string()).unwrap();
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
func main() {
\tx := 0
\tfor x < 10 {
\t\tx++
\t}
}
"}
);
}
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;

View File

@@ -1,128 +1,82 @@
use std::sync::Arc;
use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
use crate::{assistant_settings::AssistantSettings, CompletionProvider, ToggleModelSelector};
use fs::Fs;
use language_model::LanguageModelRegistry;
use settings::update_settings_file;
use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
use ui::{prelude::*, ButtonLike, ContextMenu, PopoverMenu, PopoverMenuHandle, Tooltip};
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
handle: Option<PopoverMenuHandle<ContextMenu>>,
pub struct ModelSelector {
handle: PopoverMenuHandle<ContextMenu>,
fs: Arc<dyn Fs>,
trigger: T,
}
impl<T: PopoverTrigger> ModelSelector<T> {
pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
ModelSelector {
handle: None,
fs,
trigger,
}
}
pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
self.handle = Some(handle);
self
impl ModelSelector {
pub fn new(handle: PopoverMenuHandle<ContextMenu>, fs: Arc<dyn Fs>) -> Self {
ModelSelector { handle, fs }
}
}
impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
fn render(self, _: &mut WindowContext) -> impl IntoElement {
let mut menu = PopoverMenu::new("model-switcher");
if let Some(handle) = self.handle {
menu = menu.with_handle(handle);
}
menu.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for (index, provider) in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.enumerate()
{
if index > 0 {
menu = menu.separator();
}
menu = menu.header(provider.name().0);
let available_models = provider.provided_models(cx);
if available_models.is_empty() {
impl RenderOnce for ModelSelector {
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
PopoverMenu::new("model-switcher")
.with_handle(self.handle)
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx) {
menu = menu.custom_entry(
{
move |_| {
h_flex()
.w_full()
.gap_1()
.child(Icon::new(IconName::Settings))
.child(Label::new("Configure"))
.into_any()
}
},
{
let provider = provider.id();
move |cx| {
LanguageModelCompletionProvider::global(cx).update(
cx,
|completion_provider, cx| {
completion_provider
.set_active_provider(provider.clone(), cx)
},
);
}
},
);
}
let selected_model = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.id());
let selected_provider = LanguageModelCompletionProvider::read_global(cx)
.active_provider()
.map(|m| m.id());
for available_model in available_models {
menu = menu.custom_entry(
{
let id = available_model.id();
let provider_id = available_model.provider_id();
let model_name = available_model.name().0.clone();
let selected_model = selected_model.clone();
let selected_provider = selected_provider.clone();
move |_| {
h_flex()
.w_full()
.justify_between()
.child(Label::new(model_name.clone()))
.when(
selected_model.as_ref() == Some(&id)
&& selected_provider.as_ref() == Some(&provider_id),
|this| this.child(Icon::new(IconName::Check)),
)
.into_any()
}
let model = model.clone();
move |_| Label::new(model.display_name()).into_any_element()
},
{
let fs = self.fs.clone();
let model = available_model.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings, _| settings.set_model(model),
move |settings| settings.set_model(model),
);
}
},
);
}
}
menu
menu
})
.into()
})
.into()
})
.trigger(self.trigger)
.attach(gpui::AnchorCorner::BottomLeft)
.trigger(
ButtonLike::new("active-model")
.style(ButtonStyle::Subtle)
.child(
h_flex()
.w_full()
.gap_0p5()
.child(
div()
.overflow_x_hidden()
.flex_grow()
.whitespace_nowrap()
.child(
Label::new(
CompletionProvider::global(cx).model().display_name(),
)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
)
.tooltip(move |cx| {
Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
}),
)
.attach(gpui::AnchorCorner::BottomLeft)
}
}

View File

@@ -1,9 +1,8 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
LanguageModelCompletionProvider,
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use anyhow::{anyhow, Result};
use assets::Assets;
use chrono::{DateTime, Utc};
use collections::{HashMap, HashSet};
use editor::{actions::Tab, CurrentLineHighlight, Editor, EditorElement, EditorEvent, EditorStyle};
@@ -13,13 +12,12 @@ use futures::{
};
use fuzzy::StringMatchCandidate;
use gpui::{
actions, point, size, transparent_black, AppContext, AssetSource, BackgroundExecutor, Bounds,
EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
actions, point, size, transparent_black, AppContext, BackgroundExecutor, Bounds, EventEmitter,
Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions,
};
use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use parking_lot::RwLock;
use picker::{Picker, PickerDelegate};
use rope::Rope;
@@ -629,18 +627,17 @@ impl PromptLibrary {
self.picker.update(cx, |picker, cx| picker.focus(cx));
}
pub fn inline_assist(&mut self, action: &InlineAssist, cx: &mut ViewContext<Self>) {
pub fn inline_assist(&mut self, _: &InlineAssist, cx: &mut ViewContext<Self>) {
let Some(active_prompt_id) = self.active_prompt_id else {
cx.propagate();
return;
};
let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
let provider = LanguageModelCompletionProvider::read_global(cx);
let initial_prompt = action.prompt.clone();
if provider.is_authenticated(cx) {
let provider = CompletionProvider::global(cx);
if provider.is_authenticated() {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, None, None, initial_prompt, cx)
assistant.assist(&prompt_editor, None, None, cx)
})
} else {
for window in cx.windows() {
@@ -736,8 +733,11 @@ impl PromptLibrary {
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(
let provider = CompletionProvider::global(cx);
let model = provider.model();
provider.count_tokens(
LanguageModelRequest {
model,
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: body.to_string(),
@@ -803,7 +803,7 @@ impl PromptLibrary {
let prompt_metadata = self.store.metadata(prompt_id)?;
let prompt_editor = &self.prompt_editors[&prompt_id];
let focus_handle = prompt_editor.body_editor.focus_handle(cx);
let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
let current_model = CompletionProvider::global(cx).model();
let settings = ThemeSettings::get_global(cx);
Some(
@@ -914,11 +914,7 @@ impl PromptLibrary {
format!(
"Model: {}",
current_model
.as_ref()
.map(|model| model
.name()
.0)
.unwrap_or_default()
.display_name()
),
cx,
)
@@ -1300,17 +1296,6 @@ impl PromptStore {
fn first(&self) -> Option<PromptMetadata> {
self.metadata_cache.read().metadata.first().cloned()
}
pub fn operations_prompt(&self) -> String {
String::from_utf8(
Assets
.load("prompts/operations.md")
.unwrap()
.unwrap()
.to_vec(),
)
.unwrap()
}
}
/// Wraps a shared future to a prompt store so it can be assigned as a context global.

View File

@@ -0,0 +1,171 @@
use language::Rope;
use std::ops::Range;
/// Search the given buffer for the given substring, ignoring any differences
/// in line indentation between the query and the buffer.
///
/// Returns a vector of ranges of byte offsets in the buffer corresponding
/// to the entire lines of the buffer.
pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Option<Range<usize>> {
const SIMILARITY_THRESHOLD: f64 = 0.8;
let mut best_match: Option<(Range<usize>, f64)> = None; // (range, score)
let mut haystack_lines = haystack.chunks().lines();
let mut haystack_line_start = 0;
while let Some(mut haystack_line) = haystack_lines.next() {
let next_haystack_line_start = haystack_line_start + haystack_line.len() + 1;
let mut advanced_to_next_haystack_line = false;
let mut matched = true;
let match_start = haystack_line_start;
let mut match_end = next_haystack_line_start;
let mut match_score = 0.0;
let mut needle_lines = needle.lines().peekable();
while let Some(needle_line) = needle_lines.next() {
let similarity = line_similarity(haystack_line, needle_line);
if similarity >= SIMILARITY_THRESHOLD {
match_end = haystack_lines.offset();
match_score += similarity;
if needle_lines.peek().is_some() {
if let Some(next_haystack_line) = haystack_lines.next() {
advanced_to_next_haystack_line = true;
haystack_line = next_haystack_line;
} else {
matched = false;
break;
}
} else {
break;
}
} else {
matched = false;
break;
}
}
if matched
&& best_match
.as_ref()
.map(|(_, best_score)| match_score > *best_score)
.unwrap_or(true)
{
best_match = Some((match_start..match_end, match_score));
}
if advanced_to_next_haystack_line {
haystack_lines.seek(next_haystack_line_start);
}
haystack_line_start = next_haystack_line_start;
}
best_match.map(|(range, _)| range)
}
/// Calculates the similarity between two lines, ignoring leading and trailing whitespace,
/// using the Jaro-Winkler distance.
///
/// Returns a value between 0.0 and 1.0, where 1.0 indicates an exact match.
fn line_similarity(line1: &str, line2: &str) -> f64 {
strsim::jaro_winkler(line1.trim(), line2.trim())
}
#[cfg(test)]
mod test {
use super::*;
use gpui::{AppContext, Context as _};
use language::Buffer;
use unindent::Unindent as _;
use util::test::marked_text_ranges;
#[gpui::test]
fn test_fuzzy_search_lines(cx: &mut AppContext) {
let (text, expected_ranges) = marked_text_ranges(
&r#"
fn main() {
if a() {
assert_eq!(
1 + 2,
does_not_match,
);
}
println!("hi");
assert_eq!(
1 + 2,
3,
); // this last line does not match
« assert_eq!(
1 + 2,
3,
);
»
« assert_eq!(
"something",
"else",
);
»
}
"#
.unindent(),
false,
);
let buffer = cx.new_model(|cx| Buffer::local(&text, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let actual_range = fuzzy_search_lines(
snapshot.as_rope(),
&"
assert_eq!(
1 + 2,
3,
);
"
.unindent(),
)
.unwrap();
assert_eq!(actual_range, expected_ranges[0]);
let actual_range = fuzzy_search_lines(
snapshot.as_rope(),
&"
assert_eq!(
1 + 2,
3,
);
"
.unindent(),
)
.unwrap();
assert_eq!(actual_range, expected_ranges[0]);
let actual_range = fuzzy_search_lines(
snapshot.as_rope(),
&"
asst_eq!(
\"something\",
\"els\"
)
"
.unindent(),
)
.unwrap();
assert_eq!(actual_range, expected_ranges[1]);
let actual_range = fuzzy_search_lines(
snapshot.as_rope(),
&"
assert_eq!(
2 + 1,
3,
);
"
.unindent(),
);
assert_eq!(actual_range, None);
}
}

View File

@@ -284,7 +284,7 @@ fn collect_diagnostics(
PathBuf::try_from(path)
.ok()
.and_then(|path| {
project.read(cx).worktrees(cx).find_map(|worktree| {
project.read(cx).worktrees().find_map(|worktree| {
let worktree = worktree.read(cx);
let worktree_root_path = Path::new(worktree.root_name());
let relative_path = path.strip_prefix(worktree_root_path).ok()?;

View File

@@ -1,13 +1,12 @@
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, bail, Result};
use assistant_slash_command::{
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
};
use gpui::{AppContext, BackgroundExecutor, Model, Task, WeakView};
use gpui::{AppContext, Model, Task, WeakView};
use indexed_docs::{
DocsDotRsProvider, IndexedDocsRegistry, IndexedDocsStore, LocalRustdocProvider, PackageName,
ProviderId,
@@ -24,7 +23,7 @@ impl DocsSlashCommand {
pub const NAME: &'static str = "docs";
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = project.read(cx).worktrees().next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {
@@ -91,55 +90,6 @@ impl DocsSlashCommand {
}
}
}
/// Runs just-in-time indexing for a given package, in case the slash command
/// is run without any entries existing in the index.
fn run_just_in_time_indexing(
store: Arc<IndexedDocsStore>,
key: String,
package: PackageName,
executor: BackgroundExecutor,
) -> Task<()> {
executor.clone().spawn(async move {
let (prefix, needs_full_index) = if let Some((prefix, _)) = key.split_once('*') {
// If we have a wildcard in the search, we want to wait until
// we've completely finished indexing so we get a full set of
// results for the wildcard.
(prefix.to_string(), true)
} else {
(key, false)
};
// If we already have some entries, we assume that we've indexed the package before
// and don't need to do it again.
let has_any_entries = store
.any_with_prefix(prefix.clone())
.await
.unwrap_or_default();
if has_any_entries {
return ();
};
let index_task = store.clone().index(package.clone());
if needs_full_index {
_ = index_task.await;
} else {
loop {
executor.timer(Duration::from_millis(200)).await;
if store
.any_with_prefix(prefix.clone())
.await
.unwrap_or_default()
|| !store.is_indexing(&package)
{
break;
}
}
}
})
}
}
impl SlashCommand for DocsSlashCommand {
@@ -250,14 +200,13 @@ impl SlashCommand for DocsSlashCommand {
};
let args = DocsSlashCommandArgs::parse(argument);
let executor = cx.background_executor().clone();
let task = cx.background_executor().spawn({
let store = args
.provider()
.ok_or_else(|| anyhow!("no docs provider specified"))
.and_then(|provider| IndexedDocsStore::try_global(provider, cx));
async move {
let (provider, key) = match args.clone() {
let (provider, key) = match args {
DocsSlashCommandArgs::NoProvider => bail!("no docs provider specified"),
DocsSlashCommandArgs::SearchPackageDocs {
provider, package, ..
@@ -270,12 +219,6 @@ impl SlashCommand for DocsSlashCommand {
};
let store = store?;
if let Some(package) = args.package() {
Self::run_just_in_time_indexing(store.clone(), key.clone(), package, executor)
.await;
}
let (text, ranges) = if let Some((prefix, _)) = key.split_once('*') {
let docs = store.load_many_by_prefix(prefix.to_string()).await?;
@@ -326,7 +269,7 @@ fn is_item_path_delimiter(char: char) -> bool {
!char.is_alphanumeric() && char != '-' && char != '_'
}
#[derive(Debug, PartialEq, Clone)]
#[derive(Debug, PartialEq)]
pub(crate) enum DocsSlashCommandArgs {
NoProvider,
SearchPackageDocs {

View File

@@ -10,7 +10,7 @@ use assistant_slash_command::{
use futures::AsyncReadExt;
use gpui::{AppContext, Task, WeakView};
use html_to_markdown::{convert_html_to_markdown, markdown, TagHandler};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use http::{AsyncBody, HttpClient, HttpClientWithUrl};
use language::LspAdapterDelegate;
use ui::prelude::*;
use workspace::Workspace;

View File

@@ -188,7 +188,7 @@ fn collect_files(
let project_handle = project.downgrade();
let snapshots = project
.read(cx)
.worktrees(cx)
.worktrees()
.map(|worktree| worktree.read(cx).snapshot())
.collect::<Vec<_>>();
cx.spawn(|mut cx| async move {

View File

@@ -75,7 +75,7 @@ impl ProjectSlashCommand {
}
fn path_to_cargo_toml(project: Model<Project>, cx: &mut AppContext) -> Option<Arc<Path>> {
let worktree = project.read(cx).worktrees(cx).next()?;
let worktree = project.read(cx).worktrees().next()?;
let worktree = worktree.read(cx);
let entry = worktree.entry_for_path("Cargo.toml")?;
let path = ProjectPath {

View File

@@ -1,6 +1,7 @@
use crate::{
humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@@ -12,12 +13,11 @@ use editor::{
use fs::Fs;
use futures::{channel::mpsc, SinkExt, StreamExt};
use gpui::{
AppContext, Context, EventEmitter, FocusHandle, FocusableView, Global, Model, ModelContext,
Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
AppContext, Context, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global,
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
};
use language::Buffer;
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use settings::Settings;
use settings::{update_settings_file, Settings};
use std::{
cmp,
sync::Arc,
@@ -26,7 +26,7 @@ use std::{
use terminal::Terminal;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::{prelude::*, IconButtonShape, Tooltip};
use ui::{prelude::*, ContextMenu, PopoverMenu, Tooltip};
use util::ResultExt;
use workspace::{notifications::NotificationId, Toast, Workspace};
@@ -73,13 +73,11 @@ impl TerminalInlineAssistant {
terminal_view: &View<TerminalView>,
workspace: Option<WeakView<Workspace>>,
assistant_panel: Option<&View<AssistantPanel>>,
initial_prompt: Option<String>,
cx: &mut WindowContext,
) {
let terminal = terminal_view.read(cx).terminal().clone();
let assist_id = self.next_assist_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| Buffer::local("", cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
let codegen = cx.new_model(|_| Codegen::new(terminal, self.telemetry.clone()));
@@ -214,6 +212,8 @@ impl TerminalInlineAssistant {
) -> Result<LanguageModelRequest> {
let assist = self.assists.get(&assist_id).context("invalid assist")?;
let model = CompletionProvider::global(cx).model();
let shell = std::env::var("SHELL").ok();
let working_directory = assist
.terminal
@@ -265,6 +265,7 @@ impl TerminalInlineAssistant {
});
Ok(LanguageModelRequest {
model,
messages,
stop: Vec::new(),
temperature: 1.0,
@@ -447,19 +448,22 @@ impl EventEmitter<PromptEditorEvent> for PromptEditor {}
impl Render for PromptEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let fs = self.fs.clone();
let buttons = match &self.codegen.read(cx).status {
CodegenStatus::Idle => {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("start", IconName::SparkleAlt)
IconButton::new("start", IconName::Sparkle)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| Tooltip::for_action("Generate", &menu::Confirm, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
@@ -470,14 +474,15 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::text("Cancel Assist", cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
),
IconButton::new("stop", IconName::Stop)
.icon_color(Color::Error)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.icon_size(IconSize::XSmall)
.tooltip(|cx| {
Tooltip::with_meta(
"Interrupt Generation",
@@ -495,7 +500,7 @@ impl Render for PromptEditor {
vec![
IconButton::new("cancel", IconName::Close)
.icon_color(Color::Muted)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
.on_click(
cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
@@ -503,7 +508,8 @@ impl Render for PromptEditor {
if self.edited_since_done {
IconButton::new("restart", IconName::RotateCw)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.size(ButtonSize::None)
.tooltip(|cx| {
Tooltip::with_meta(
"Restart Generation",
@@ -518,7 +524,7 @@ impl Render for PromptEditor {
} else {
IconButton::new("confirm", IconName::Play)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.size(ButtonSize::None)
.tooltip(|cx| {
Tooltip::for_action("Execute generated command", &menu::Confirm, cx)
})
@@ -546,27 +552,59 @@ impl Render for PromptEditor {
.w_12()
.justify_center()
.gap_2()
.child(ModelSelector::new(
self.fs.clone(),
IconButton::new("context", IconName::Settings)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|model| model.name().0)
.unwrap_or_else(|| "No model selected".into()),
),
None,
"Change Model",
cx,
)
}),
))
.child(
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
for model in CompletionProvider::global(cx).available_models(cx)
{
menu = menu.custom_entry(
{
let model = model.clone();
move |_| {
Label::new(model.display_name())
.into_any_element()
}
},
{
let fs = fs.clone();
let model = model.clone();
move |cx| {
let model = model.clone();
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
move |settings| settings.set_model(model),
);
}
},
);
}
menu
})
.into()
})
.trigger(
IconButton::new("context", IconName::Settings)
.size(ButtonSize::None)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(move |cx| {
Tooltip::with_meta(
format!(
"Using {}",
CompletionProvider::global(cx)
.model()
.display_name()
),
None,
"Click to Change Model",
cx,
)
}),
)
.anchor(gpui::AnchorCorner::BottomRight),
)
.children(
if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
let error_message = SharedString::from(error.to_string());
@@ -708,9 +746,7 @@ impl PromptEditor {
})??;
let token_count = cx
.update(|cx| {
LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
})?
.update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))?
.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
@@ -840,7 +876,7 @@ impl PromptEditor {
}
fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
let model = CompletionProvider::global(cx).model();
let token_count = self.token_count?;
let max_token_count = model.max_token_count();
@@ -907,9 +943,13 @@ impl PromptEditor {
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_weight: FontWeight::NORMAL,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.editor,
@@ -985,12 +1025,8 @@ impl Codegen {
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
let telemetry = self.telemetry.clone();
let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
.active_model()
.map(|m| m.telemetry_id())
.unwrap_or_default();
let response =
LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
let model_telemetry_id = prompt.model.telemetry_id();
let response = CompletionProvider::global(cx).complete(prompt, cx);
self.generation = cx.spawn(|this, mut cx| async move {
let response = response.await;
@@ -1001,8 +1037,8 @@ impl Codegen {
let mut response_latency = None;
let request_start = Instant::now();
let task = async {
let mut chunks = response?;
while let Some(chunk) = chunks.next().await {
let mut response = response.inner.await?;
while let Some(chunk) = response.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}

View File

@@ -222,7 +222,7 @@ mod tests {
let worktree_ids = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.worktrees()
.map(|worktree| worktree.read(cx).id())
.collect::<Vec<_>>()
});

View File

@@ -18,12 +18,11 @@ client.workspace = true
db.workspace = true
editor.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
isahc.workspace = true
log.workspace = true
markdown_preview.workspace = true
menu.workspace = true
paths.workspace = true
release_channel.workspace = true
schemars.workspace = true
serde.workspace = true

View File

@@ -1,7 +1,7 @@
mod update_notification;
use anyhow::{anyhow, Context, Result};
use client::{Client, TelemetrySettings};
use client::{Client, TelemetrySettings, ZED_APP_PATH};
use db::kvp::KEY_VALUE_STORE;
use db::RELEASE_CHANNEL;
use editor::{Editor, MultiBuffer};
@@ -20,7 +20,7 @@ use smol::{fs, io::AsyncReadExt};
use settings::{Settings, SettingsSources, SettingsStore};
use smol::{fs::File, process::Command};
use http_client::{HttpClient, HttpClientWithUrl};
use http::{HttpClient, HttpClientWithUrl};
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
use std::{
env::{
@@ -28,7 +28,7 @@ use std::{
consts::{ARCH, OS},
},
ffi::OsString,
path::{Path, PathBuf},
path::PathBuf,
sync::Arc,
time::Duration,
};
@@ -55,8 +55,6 @@ struct UpdateRequestBody {
installation_id: Option<Arc<str>>,
release_channel: Option<&'static str>,
telemetry: bool,
is_staff: Option<bool>,
destination: &'static str,
}
#[derive(Clone, PartialEq, Eq)]
@@ -290,12 +288,7 @@ fn view_release_notes_locally(workspace: &mut Workspace, cx: &mut ViewContext<Wo
Some(tab_description),
cx,
);
workspace.add_item_to_active_pane(
Box::new(view.clone()),
None,
true,
cx,
);
workspace.add_item_to_active_pane(Box::new(view.clone()), None, cx);
cx.notify();
})
.log_err();
@@ -361,6 +354,7 @@ impl AutoUpdater {
return;
}
self.status = AutoUpdateStatus::Checking;
cx.notify();
self.pending_poll = Some(cx.spawn(|this, mut cx| async move {
@@ -386,65 +380,29 @@ impl AutoUpdater {
cx.notify();
}
pub async fn get_latest_remote_server_release(
os: &str,
arch: &str,
mut release_channel: ReleaseChannel,
cx: &mut AsyncAppContext,
) -> Result<PathBuf> {
let this = cx.update(|cx| {
cx.default_global::<GlobalAutoUpdate>()
.0
.clone()
.ok_or_else(|| anyhow!("auto-update not initialized"))
})??;
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, current_version) = this.read_with(&cx, |this, _| {
(this.http_client.clone(), this.current_version)
})?;
if release_channel == ReleaseChannel::Dev {
release_channel = ReleaseChannel::Nightly;
}
let asset = match OS {
"linux" => format!("zed-linux-{}.tar.gz", ARCH),
"macos" => "Zed.dmg".into(),
_ => return Err(anyhow!("auto-update not supported for OS {:?}", OS)),
};
let release = Self::get_latest_release(
&this,
"zed-remote-server",
os,
arch,
Some(release_channel),
cx,
)
.await?;
let servers_dir = paths::remote_servers_dir();
let channel_dir = servers_dir.join(release_channel.dev_name());
let platform_dir = channel_dir.join(format!("{}-{}", os, arch));
let version_path = platform_dir.join(format!("{}.gz", release.version));
smol::fs::create_dir_all(&platform_dir).await.ok();
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
if smol::fs::metadata(&version_path).await.is_err() {
log::info!("downloading zed-remote-server {os} {arch}");
download_remote_server_binary(&version_path, release, client, cx).await?;
}
Ok(version_path)
}
async fn get_latest_release(
this: &Model<Self>,
asset: &str,
os: &str,
arch: &str,
release_channel: Option<ReleaseChannel>,
cx: &mut AsyncAppContext,
) -> Result<JsonRelease> {
let client = this.read_with(cx, |this, _| this.http_client.clone())?;
let mut url_string = client.build_url(&format!(
"/api/releases/latest?asset={}&os={}&arch={}",
asset, os, arch
asset, OS, ARCH
));
if let Some(param) = release_channel.and_then(|c| c.release_query_param()) {
url_string += "&";
url_string += param;
}
cx.update(|cx| {
if let Some(param) = ReleaseChannel::try_global(cx)
.and_then(|release_channel| release_channel.release_query_param())
{
url_string += "&";
url_string += param;
}
})?;
let mut response = client.get(&url_string, Default::default(), true).await?;
@@ -455,34 +413,8 @@ impl AutoUpdater {
.await
.context("error reading release")?;
if !response.status().is_success() {
Err(anyhow!(
"failed to fetch release: {:?}",
String::from_utf8_lossy(&body),
))?;
}
serde_json::from_slice(body.as_slice()).with_context(|| {
format!(
"error deserializing release {:?}",
String::from_utf8_lossy(&body),
)
})
}
async fn update(this: Model<Self>, mut cx: AsyncAppContext) -> Result<()> {
let (client, current_version, release_channel) = this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Checking;
cx.notify();
(
this.http_client.clone(),
this.current_version,
ReleaseChannel::try_global(cx),
)
})?;
let release =
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
let release: JsonRelease =
serde_json::from_slice(body.as_slice()).context("error deserializing release")?;
let should_download = match *RELEASE_CHANNEL {
ReleaseChannel::Nightly => cx
@@ -509,21 +441,19 @@ impl AutoUpdater {
let temp_dir = tempfile::Builder::new()
.prefix("zed-auto-update")
.tempdir()?;
let filename = match OS {
"macos" => Ok("Zed.dmg"),
"linux" => Ok("zed.tar.gz"),
_ => Err(anyhow!("not supported: {:?}", OS)),
}?;
let downloaded_asset = temp_dir.path().join(filename);
download_release(&downloaded_asset, release, client, &cx).await?;
let downloaded_asset = download_release(&temp_dir, release, &asset, client, &cx).await?;
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Installing;
cx.notify();
})?;
let binary_path = match OS {
// We store the path of our current binary, before we install, since installation might
// delete it. Once deleted, it's hard to get the path to our binary on Linux.
// So we cache it here, which allows us to then restart later on.
let binary_path = cx.update(|cx| cx.app_path())??;
match OS {
"macos" => install_release_macos(&temp_dir, downloaded_asset, &cx).await,
"linux" => install_release_linux(&temp_dir, downloaded_asset, &cx).await,
_ => Err(anyhow!("not supported: {:?}", OS)),
@@ -570,88 +500,45 @@ impl AutoUpdater {
}
}
async fn download_remote_server_binary(
target_path: &PathBuf,
release: JsonRelease,
client: Arc<HttpClientWithUrl>,
cx: &AsyncAppContext,
) -> Result<()> {
let mut target_file = File::create(&target_path).await?;
let (installation_id, release_channel, telemetry_enabled, is_staff) = cx.update(|cx| {
let telemetry = Client::global(cx).telemetry().clone();
let is_staff = telemetry.is_staff();
let installation_id = telemetry.installation_id();
let release_channel =
ReleaseChannel::try_global(cx).map(|release_channel| release_channel.display_name());
let telemetry_enabled = TelemetrySettings::get_global(cx).metrics;
(
installation_id,
release_channel,
telemetry_enabled,
is_staff,
)
})?;
let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
installation_id,
release_channel,
telemetry: telemetry_enabled,
is_staff,
destination: "remote",
})?);
let mut response = client.get(&release.url, request_body, true).await?;
smol::io::copy(response.body_mut(), &mut target_file).await?;
Ok(())
}
async fn download_release(
target_path: &Path,
temp_dir: &tempfile::TempDir,
release: JsonRelease,
target_filename: &str,
client: Arc<HttpClientWithUrl>,
cx: &AsyncAppContext,
) -> Result<()> {
) -> Result<PathBuf> {
let target_path = temp_dir.path().join(target_filename);
let mut target_file = File::create(&target_path).await?;
let (installation_id, release_channel, telemetry_enabled, is_staff) = cx.update(|cx| {
let telemetry = Client::global(cx).telemetry().clone();
let is_staff = telemetry.is_staff();
let installation_id = telemetry.installation_id();
let (installation_id, release_channel, telemetry) = cx.update(|cx| {
let installation_id = Client::global(cx).telemetry().installation_id();
let release_channel =
ReleaseChannel::try_global(cx).map(|release_channel| release_channel.display_name());
let telemetry_enabled = TelemetrySettings::get_global(cx).metrics;
let telemetry = TelemetrySettings::get_global(cx).metrics;
(
installation_id,
release_channel,
telemetry_enabled,
is_staff,
)
(installation_id, release_channel, telemetry)
})?;
let request_body = AsyncBody::from(serde_json::to_string(&UpdateRequestBody {
installation_id,
release_channel,
telemetry: telemetry_enabled,
is_staff,
destination: "local",
telemetry,
})?);
let mut response = client.get(&release.url, request_body, true).await?;
smol::io::copy(response.body_mut(), &mut target_file).await?;
log::info!("downloaded update. path:{:?}", target_path);
Ok(())
Ok(target_path)
}
async fn install_release_linux(
temp_dir: &tempfile::TempDir,
downloaded_tar_gz: PathBuf,
cx: &AsyncAppContext,
) -> Result<PathBuf> {
) -> Result<()> {
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?;
let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?);
let running_app_path = cx.update(|cx| cx.app_path())??;
let extracted = temp_dir.path().join("zed");
fs::create_dir_all(&extracted)
@@ -682,16 +569,7 @@ async fn install_release_linux(
let app_folder_name = format!("zed{}.app", suffix);
let from = extracted.join(&app_folder_name);
let mut to = home_dir.join(".local");
let expected_suffix = format!("{}/libexec/zed-editor", app_folder_name);
if let Some(prefix) = running_app_path
.to_str()
.and_then(|str| str.strip_suffix(&expected_suffix))
{
to = PathBuf::from(prefix);
}
let to = home_dir.join(".local");
let output = Command::new("rsync")
.args(&["-av", "--delete"])
@@ -708,15 +586,17 @@ async fn install_release_linux(
String::from_utf8_lossy(&output.stderr)
);
Ok(to.join(expected_suffix))
Ok(())
}
async fn install_release_macos(
temp_dir: &tempfile::TempDir,
downloaded_dmg: PathBuf,
cx: &AsyncAppContext,
) -> Result<PathBuf> {
let running_app_path = cx.update(|cx| cx.app_path())??;
) -> Result<()> {
let running_app_path = ZED_APP_PATH
.clone()
.map_or_else(|| cx.update(|cx| cx.app_path())?, Ok)?;
let running_app_filename = running_app_path
.file_name()
.ok_or_else(|| anyhow!("invalid running app path"))?;
@@ -764,5 +644,5 @@ async fn install_release_macos(
String::from_utf8_lossy(&output.stderr)
);
Ok(running_app_path)
Ok(())
}

View File

@@ -113,30 +113,29 @@ impl ToolbarItemView for Breadcrumbs {
) -> ToolbarItemLocation {
cx.notify();
self.active_item = None;
let Some(item) = active_pane_item else {
return ToolbarItemLocation::Hidden;
};
let this = cx.view().downgrade();
self.subscription = Some(item.subscribe_to_item_events(
cx,
Box::new(move |event, cx| {
if let ItemEvent::UpdateBreadcrumbs = event {
this.update(cx, |this, cx| {
cx.notify();
if let Some(active_item) = this.active_item.as_ref() {
cx.emit(ToolbarItemEvent::ChangeLocation(
active_item.breadcrumb_location(cx),
))
}
})
.ok();
}
}),
));
self.active_item = Some(item.boxed_clone());
item.breadcrumb_location(cx)
if let Some(item) = active_pane_item {
let this = cx.view().downgrade();
self.subscription = Some(item.subscribe_to_item_events(
cx,
Box::new(move |event, cx| {
if let ItemEvent::UpdateBreadcrumbs = event {
this.update(cx, |this, cx| {
cx.notify();
if let Some(active_item) = this.active_item.as_ref() {
cx.emit(ToolbarItemEvent::ChangeLocation(
active_item.breadcrumb_location(cx),
))
}
})
.ok();
}
}),
));
self.active_item = Some(item.boxed_clone());
item.breadcrumb_location(cx)
} else {
ToolbarItemLocation::Hidden
}
}
fn pane_focus_update(&mut self, pane_focused: bool, _: &mut ViewContext<Self>) {

View File

@@ -51,4 +51,4 @@ language = { workspace = true, features = ["test-support"] }
live_kit_client = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -526,7 +526,7 @@ impl Room {
rejoined_projects.push(proto::RejoinProject {
id: project_id,
worktrees: project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
proto::RejoinWorktree {

View File

@@ -40,4 +40,4 @@ rpc = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -4,7 +4,7 @@ use super::*;
use client::{test::FakeServer, Client, UserStore};
use clock::FakeSystemClock;
use gpui::{AppContext, Context, Model, SemanticVersion, TestAppContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use rpc::proto::{self};
use settings::SettingsStore;

View File

@@ -129,7 +129,6 @@ fn main() -> Result<()> {
|| path.starts_with("http://")
|| path.starts_with("https://")
|| path.starts_with("file://")
|| path.starts_with("ssh://")
{
urls.push(path.to_string());
} else {

View File

@@ -18,7 +18,7 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
[dependencies]
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["async-std", "async-native-tls"] }
async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
collections.workspace = true
@@ -26,7 +26,7 @@ feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
lazy_static.workspace = true
log.workspace = true
once_cell.workspace = true
@@ -60,11 +60,12 @@ gpui = { workspace = true, features = ["test-support"] }
rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
cocoa.workspace = true
isahc = { workspace = true, features = ["static-curl"] }
async-native-tls = { version = "0.5.0", features = ["vendored"] }

View File

@@ -20,7 +20,7 @@ use futures::{
use gpui::{
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use http::{HttpClient, HttpClientWithUrl};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use postage::watch;
@@ -233,7 +233,7 @@ pub enum EstablishConnectionError {
#[error("{0}")]
Other(#[from] anyhow::Error),
#[error("{0}")]
Http(#[from] http_client::Error),
Http(#[from] http::Error),
#[error("{0}")]
Io(#[from] std::io::Error),
#[error("{0}")]
@@ -1351,7 +1351,7 @@ impl Client {
let mut url = self.rpc_url(http.clone(), None).await?;
url.set_path("/user");
url.set_query(Some(&format!("github_login={login}")));
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
let request = Request::get(url.as_str())
.header("Authorization", format!("token {api_token}"))
.body("".into())?;
@@ -1783,7 +1783,7 @@ mod tests {
use clock::FakeSystemClock;
use gpui::{BackgroundExecutor, Context, TestAppContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use parking_lot::Mutex;
use proto::TypedEnvelope;
use settings::SettingsStore;

View File

@@ -6,7 +6,7 @@ use clock::SystemClock;
use collections::{HashMap, HashSet};
use futures::Future;
use gpui::{AppContext, BackgroundExecutor, Task};
use http_client::{self, HttpClient, HttpClientWithUrl, Method};
use http::{self, HttpClient, HttpClientWithUrl, Method};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use release_channel::ReleaseChannel;
@@ -18,7 +18,7 @@ use sysinfo::{CpuRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CpuEvent, EditEvent,
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent,
MemoryEvent, ReplEvent, SettingEvent,
MemoryEvent, SettingEvent,
};
use tempfile::NamedTempFile;
#[cfg(not(debug_assertions))]
@@ -531,21 +531,6 @@ impl Telemetry {
}
}
pub fn report_repl_event(
self: &Arc<Self>,
kernel_language: String,
kernel_status: String,
repl_session_id: String,
) {
let event = Event::Repl(ReplEvent {
kernel_language,
kernel_status,
repl_session_id,
});
self.report_event(event)
}
fn report_event(self: &Arc<Self>, event: Event) {
let mut state = self.state.lock();
@@ -647,7 +632,7 @@ impl Telemetry {
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
let request = http_client::Request::builder()
let request = http::Request::builder()
.method(Method::POST)
.uri(
this.http_client
@@ -676,7 +661,7 @@ mod tests {
use chrono::TimeZone;
use clock::FakeSystemClock;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use http::FakeHttpClient;
#[gpui::test]
fn test_telemetry_flush_on_max_queue_size(cx: &mut TestAppContext) {

View File

@@ -20,7 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-tungstenite.workspace = true
async-tungstenite = "0.16"
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
axum = { version = "0.6", features = ["json", "headers", "ws"] }
@@ -35,7 +35,7 @@ envy = "0.4.2"
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
http.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid.workspace = true
@@ -79,7 +79,6 @@ channel.workspace = true
client = { workspace = true, features = ["test-support"] }
collab_ui = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
@@ -90,7 +89,6 @@ git_hosting_providers.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
live_kit_client = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
menu.workspace = true
@@ -101,13 +99,10 @@ pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
recent_projects = { workspace = true }
release_channel.workspace = true
remote = { workspace = true, features = ["test-support"] }
remote_server.workspace = true
dev_server_projects.workspace = true
rpc = { workspace = true, features = ["test-support"] }
sea-orm = { version = "0.12.x", features = ["sqlx-sqlite"] }
serde_json.workspace = true
session = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
sqlx = { version = "0.7", features = ["sqlite"] }
theme.workspace = true

View File

@@ -16,7 +16,7 @@ use sha2::{Digest, Sha256};
use std::sync::{Arc, OnceLock};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, CallEvent, CpuEvent, EditEvent, EditorEvent, Event,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent, ReplEvent,
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent,
SettingEvent,
};
use uuid::Uuid;
@@ -518,13 +518,6 @@ pub async fn post_events(
checksum_matched,
))
}
Event::Repl(event) => to_upload.repl_events.push(ReplEventRow::from_event(
event.clone(),
&wrapper,
&request_body,
first_event_at,
checksum_matched,
)),
}
}
@@ -549,7 +542,6 @@ struct ToUpload {
extension_events: Vec<ExtensionEventRow>,
edit_events: Vec<EditEventRow>,
action_events: Vec<ActionEventRow>,
repl_events: Vec<ReplEventRow>,
}
impl ToUpload {
@@ -625,11 +617,6 @@ impl ToUpload {
.await
.with_context(|| format!("failed to upload to table '{ACTION_EVENTS_TABLE}'"))?;
const REPL_EVENTS_TABLE: &str = "repl_events";
Self::upload_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{REPL_EVENTS_TABLE}'"))?;
Ok(())
}
@@ -638,24 +625,22 @@ impl ToUpload {
rows: &[T],
clickhouse_client: &clickhouse::Client,
) -> anyhow::Result<()> {
if rows.is_empty() {
return Ok(());
if !rows.is_empty() {
let mut insert = clickhouse_client.insert(table)?;
for event in rows {
insert.write(event).await?;
}
insert.end().await?;
let event_count = rows.len();
log::info!(
"wrote {event_count} {event_specifier} to '{table}'",
event_specifier = if event_count == 1 { "event" } else { "events" }
);
}
let mut insert = clickhouse_client.insert(table)?;
for event in rows {
insert.write(event).await?;
}
insert.end().await?;
let event_count = rows.len();
log::info!(
"wrote {event_count} {event_specifier} to '{table}'",
event_specifier = if event_count == 1 { "event" } else { "events" }
);
Ok(())
}
}
@@ -1204,62 +1189,6 @@ impl ExtensionEventRow {
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct ReplEventRow {
// AppInfoBase
app_version: String,
major: Option<i32>,
minor: Option<i32>,
patch: Option<i32>,
checksum_matched: bool,
release_channel: String,
os_name: String,
os_version: String,
// ClientEventBase
installation_id: Option<String>,
session_id: Option<String>,
is_staff: Option<bool>,
time: i64,
// ReplEventRow
kernel_language: String,
kernel_status: String,
repl_session_id: String,
}
impl ReplEventRow {
fn from_event(
event: ReplEvent,
wrapper: &EventWrapper,
body: &EventRequestBody,
first_event_at: chrono::DateTime<chrono::Utc>,
checksum_matched: bool,
) -> Self {
let semver = body.semver();
let time =
first_event_at + chrono::Duration::milliseconds(wrapper.milliseconds_since_first_event);
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
checksum_matched,
release_channel: body.release_channel.clone().unwrap_or_default(),
os_name: body.os_name.clone(),
os_version: body.os_version.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
is_staff: body.is_staff,
time: time.timestamp_millis(),
kernel_language: event.kernel_language,
kernel_status: event.kernel_status,
repl_session_id: event.repl_session_id,
}
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct EditEventRow {
// AppInfoBase

View File

@@ -164,21 +164,10 @@ pub fn hash_access_token(token: &str) -> String {
/// Encrypts the given access token with the given public key to avoid leaking it on the way
/// to the client.
pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
use rpc::auth::EncryptionFormat;
/// The encryption format to use for the access token.
///
/// Currently we're using the original encryption format to avoid
/// breaking compatibility with older clients.
///
/// Once enough clients are capable of decrypting the newer encryption
/// format we can start encrypting with `EncryptionFormat::V1`.
const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
let native_app_public_key =
rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
let encrypted_access_token = native_app_public_key
.encrypt_string(access_token, ENCRYPTION_FORMAT)
.encrypt_string(access_token)
.context("failed to encrypt access token with public key")?;
Ok(encrypted_access_token)
}

View File

@@ -0,0 +1 @@

View File

@@ -42,7 +42,7 @@ use futures::{
stream::FuturesUnordered,
FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use http_client::IsahcHttpClient;
use http::IsahcHttpClient;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
@@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
mut request: proto::CompleteWithLanguageModel,
request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
@@ -4530,43 +4530,18 @@ async fn complete_with_language_model(
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
let mut provider_and_model = request.model.split('/');
let (provider, model) = match (
provider_and_model.next().unwrap(),
provider_and_model.next(),
) {
(provider, Some(model)) => (provider, model),
(model, None) => {
if model.starts_with("gpt") {
("openai", model)
} else if model.starts_with("gemini") {
("google", model)
} else if model.starts_with("claude") {
("anthropic", model)
} else {
("unknown", model)
}
}
};
let provider = provider.to_string();
request.model = model.to_string();
match provider.as_str() {
"openai" => {
let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
complete_with_open_ai(request, response, session, api_key).await?;
}
"anthropic" => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
complete_with_anthropic(request, response, session, api_key).await?;
}
"google" => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
complete_with_google_ai(request, response, session, api_key).await?;
}
provider => return Err(anyhow!("unknown provider {:?}", provider))?,
if request.model.starts_with("gpt") {
let api_key =
open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
complete_with_open_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("gemini") {
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
} else if request.model.starts_with("claude") {
let api_key = anthropic_api_key
.ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
complete_with_anthropic(request, response, session, api_key).await?;
}
Ok(())

View File

@@ -16,7 +16,6 @@ mod notification_tests;
mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod test_server;
use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};

View File

@@ -52,7 +52,7 @@ async fn test_channel_guests(
assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
assert!(project_b
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.create_entry((worktree_id, "b.txt"), false, cx)
})
.await

View File

@@ -76,7 +76,7 @@ async fn test_host_disconnect(
let active_call_a = cx_a.read(ActiveCall::global);
let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
@@ -1144,7 +1144,7 @@ async fn test_share_project(
});
project_b.read_with(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap().read(cx);
let worktree = project.worktrees().next().unwrap().read(cx);
assert_eq!(
worktree.paths().map(AsRef::as_ref).collect::<Vec<_>>(),
[
@@ -1158,7 +1158,7 @@ async fn test_share_project(
project_b
.update(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap();
let worktree = project.worktrees().next().unwrap();
let entry = worktree.read(cx).entry_for_path("ignored-dir").unwrap();
project.expand_entry(worktree_id, entry.id, cx).unwrap()
})
@@ -1166,7 +1166,7 @@ async fn test_share_project(
.unwrap();
project_b.read_with(cx_b, |project, cx| {
let worktree = project.worktrees(cx).next().unwrap().read(cx);
let worktree = project.worktrees().next().unwrap().read(cx);
assert_eq!(
worktree.paths().map(AsRef::as_ref).collect::<Vec<_>>(),
[

View File

@@ -266,7 +266,7 @@ async fn test_basic_following(
// When client A activates a different editor, client B does so as well.
workspace_a.update(cx_a, |workspace, cx| {
workspace.activate_item(&editor_a1, true, true, cx)
workspace.activate_item(&editor_a1, cx)
});
executor.run_until_parked();
workspace_b.update(cx_b, |workspace, cx| {
@@ -311,7 +311,7 @@ async fn test_basic_following(
let editor = cx.new_view(|cx| {
Editor::for_multibuffer(multibuffer_a, Some(project_a.clone()), true, cx)
});
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, cx);
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
editor
});
executor.run_until_parked();
@@ -401,7 +401,7 @@ async fn test_basic_following(
workspace.unfollow(peer_id_a, cx).unwrap()
});
workspace_a.update(cx_a, |workspace, cx| {
workspace.activate_item(&editor_a2, true, true, cx)
workspace.activate_item(&editor_a2, cx)
});
executor.run_until_parked();
assert_eq!(
@@ -466,7 +466,7 @@ async fn test_basic_following(
// Client B activates a multibuffer that was created by following client A. Client A returns to that multibuffer.
workspace_b.update(cx_b, |workspace, cx| {
workspace.activate_item(&multibuffer_editor_b, true, true, cx)
workspace.activate_item(&multibuffer_editor_b, cx)
});
executor.run_until_parked();
workspace_a.update(cx_a, |workspace, cx| {

View File

@@ -18,9 +18,7 @@ use gpui::{
TestAppContext, UpdateGlobal,
};
use language::{
language_settings::{
AllLanguageSettings, Formatter, FormatterList, PrettierSettings, SelectedFormatter,
},
language_settings::{AllLanguageSettings, Formatter, PrettierSettings},
tree_sitter_rust, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language, LanguageConfig,
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
};
@@ -1377,7 +1375,7 @@ async fn test_unshare_project(
.await
.unwrap();
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
executor.run_until_parked();
@@ -1505,8 +1503,7 @@ async fn test_project_reconnect(
let (project_a1, _) = client_a.build_local_project("/root-1/dir1", cx_a).await;
let (project_a2, _) = client_a.build_local_project("/root-2", cx_a).await;
let (project_a3, _) = client_a.build_local_project("/root-3", cx_a).await;
let worktree_a1 =
project_a1.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a1 = project_a1.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let project1_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a1.clone(), cx))
.await
@@ -2309,7 +2306,7 @@ async fn test_propagate_saves_and_fs_changes(
.await;
let (project_a, worktree_id) = client_a.build_local_project("/a", cx_a).await;
let worktree_a = project_a.read_with(cx_a, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |p, _| p.worktrees().next().unwrap());
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
@@ -2319,9 +2316,9 @@ async fn test_propagate_saves_and_fs_changes(
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let project_c = client_c.build_dev_server_project(project_id, cx_c).await;
let worktree_b = project_b.read_with(cx_b, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |p, _| p.worktrees().next().unwrap());
let worktree_c = project_c.read_with(cx_c, |p, cx| p.worktrees(cx).next().unwrap());
let worktree_c = project_c.read_with(cx_c, |p, _| p.worktrees().next().unwrap());
// Open and edit a buffer as both guests B and C.
let buffer_b = project_b
@@ -3023,8 +3020,8 @@ async fn test_fs_operations(
.unwrap();
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_a = project_a.read_with(cx_a, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_a = project_a.read_with(cx_a, |project, _| project.worktrees().next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap());
let entry = project_b
.update(cx_b, |project, cx| {
@@ -3324,7 +3321,7 @@ async fn test_local_settings(
// As client B, join that project and observe the local settings.
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_b = project_b.read_with(cx_b, |project, cx| project.worktrees(cx).next().unwrap());
let worktree_b = project_b.read_with(cx_b, |project, _| project.worktrees().next().unwrap());
executor.run_until_parked();
cx_b.read(|cx| {
let store = cx.global::<SettingsStore>();
@@ -3736,7 +3733,7 @@ async fn test_leaving_project(
// Client B opens a buffer.
let buffer_b1 = project_b1
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.open_buffer((worktree_id, "a.txt"), cx)
})
.await
@@ -3774,7 +3771,7 @@ async fn test_leaving_project(
let buffer_b2 = project_b2
.update(cx_b, |project, cx| {
let worktree_id = project.worktrees(cx).next().unwrap().read(cx).id();
let worktree_id = project.worktrees().next().unwrap().read(cx).id();
project.open_buffer((worktree_id, "a.txt"), cx)
})
.await
@@ -4412,13 +4409,10 @@ async fn test_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::External {
command: "awk".into(),
arguments: vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
}]
.into(),
)));
file.defaults.formatter = Some(Formatter::External {
command: "awk".into(),
arguments: vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
});
});
});
});
@@ -4499,7 +4493,7 @@ async fn test_prettier_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::Auto);
file.defaults.formatter = Some(Formatter::Auto);
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()
@@ -4510,9 +4504,7 @@ async fn test_prettier_formatting_buffer(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::LanguageServer { name: None }].into(),
)));
file.defaults.formatter = Some(Formatter::LanguageServer);
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
..PrettierSettings::default()
@@ -4628,7 +4620,7 @@ async fn test_definition(
.unwrap();
cx_b.read(|cx| {
assert_eq!(definitions_1.len(), 1);
assert_eq!(project_b.read(cx).worktrees(cx).count(), 2);
assert_eq!(project_b.read(cx).worktrees().count(), 2);
let target_buffer = definitions_1[0].target.buffer.read(cx);
assert_eq!(
target_buffer.text(),
@@ -4657,7 +4649,7 @@ async fn test_definition(
.unwrap();
cx_b.read(|cx| {
assert_eq!(definitions_2.len(), 1);
assert_eq!(project_b.read(cx).worktrees(cx).count(), 2);
assert_eq!(project_b.read(cx).worktrees().count(), 2);
let target_buffer = definitions_2[0].target.buffer.read(cx);
assert_eq!(
target_buffer.text(),
@@ -4815,7 +4807,7 @@ async fn test_references(
assert!(status.pending_work.is_empty());
assert_eq!(references.len(), 3);
assert_eq!(project.worktrees(cx).count(), 2);
assert_eq!(project.worktrees().count(), 2);
let two_buffer = references[0].buffer.read(cx);
let three_buffer = references[2].buffer.read(cx);
@@ -6200,7 +6192,7 @@ async fn test_preview_tabs(cx: &mut TestAppContext) {
let project = workspace.update(cx, |workspace, _| workspace.project().clone());
let worktree_id = project.update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
project.worktrees().next().unwrap().read(cx).id()
});
let path_1 = ProjectPath {

View File

@@ -301,7 +301,7 @@ impl RandomizedTest for ProjectCollaborationTest {
let is_local = project.read_with(cx, |project, _| project.is_local());
let worktree = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.worktrees()
.filter(|worktree| {
let worktree = worktree.read(cx);
worktree.is_visible()
@@ -423,7 +423,7 @@ impl RandomizedTest for ProjectCollaborationTest {
81.. => {
let worktree = project.read_with(cx, |project, cx| {
project
.worktrees(cx)
.worktrees()
.filter(|worktree| {
let worktree = worktree.read(cx);
worktree.is_visible()
@@ -1172,7 +1172,7 @@ impl RandomizedTest for ProjectCollaborationTest {
let host_worktree_snapshots =
host_project.read_with(host_cx, |host_project, cx| {
host_project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
(worktree.id(), worktree.snapshot())
@@ -1180,7 +1180,7 @@ impl RandomizedTest for ProjectCollaborationTest {
.collect::<BTreeMap<_, _>>()
});
let guest_worktree_snapshots = guest_project
.worktrees(cx)
.worktrees()
.map(|worktree| {
let worktree = worktree.read(cx);
(worktree.id(), worktree.snapshot())
@@ -1538,7 +1538,7 @@ fn project_path_for_full_path(
let root_name = components.next().unwrap().as_os_str().to_str().unwrap();
let path = components.as_path().into();
let worktree_id = project.read_with(cx, |project, cx| {
project.worktrees(cx).find_map(|worktree| {
project.worktrees().find_map(|worktree| {
let worktree = worktree.read(cx);
if worktree.root_name() == root_name {
Some(worktree.id())

View File

@@ -1,102 +0,0 @@
use crate::tests::TestServer;
use call::ActiveCall;
use fs::{FakeFs, Fs as _};
use gpui::{Context as _, TestAppContext};
use remote::SshSession;
use remote_server::HeadlessProject;
use serde_json::json;
use std::{path::Path, sync::Arc};
#[gpui::test]
async fn test_sharing_an_ssh_remote_project(
cx_a: &mut TestAppContext,
cx_b: &mut TestAppContext,
server_cx: &mut TestAppContext,
) {
let executor = cx_a.executor();
let mut server = TestServer::start(executor.clone()).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
server
.create_room(&mut [(&client_a, cx_a), (&client_b, cx_b)])
.await;
// Set up project on remote FS
let (client_ssh, server_ssh) = SshSession::fake(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
"/code",
json!({
"project1": {
"README.md": "# project 1",
"src": {
"lib.rs": "fn one() -> usize { 1 }"
}
},
"project2": {
"README.md": "# project 2",
},
}),
)
.await;
// User A connects to the remote project via SSH.
server_cx.update(HeadlessProject::init);
let _headless_project =
server_cx.new_model(|cx| HeadlessProject::new(server_ssh, remote_fs.clone(), cx));
let (project_a, worktree_id) = client_a
.build_ssh_project("/code/project1", client_ssh, cx_a)
.await;
// User A shares the remote project.
let active_call_a = cx_a.read(ActiveCall::global);
let project_id = active_call_a
.update(cx_a, |call, cx| call.share_project(project_a.clone(), cx))
.await
.unwrap();
// User B joins the project.
let project_b = client_b.build_dev_server_project(project_id, cx_b).await;
let worktree_b = project_b
.update(cx_b, |project, cx| project.worktree_for_id(worktree_id, cx))
.unwrap();
executor.run_until_parked();
worktree_b.update(cx_b, |worktree, _cx| {
assert_eq!(
worktree.paths().map(Arc::as_ref).collect::<Vec<_>>(),
vec![
Path::new("README.md"),
Path::new("src"),
Path::new("src/lib.rs"),
]
);
});
// User B can open buffers in the remote project.
let buffer_b = project_b
.update(cx_b, |project, cx| {
project.open_buffer((worktree_id, "src/lib.rs"), cx)
})
.await
.unwrap();
buffer_b.update(cx_b, |buffer, cx| {
assert_eq!(buffer.text(), "fn one() -> usize { 1 }");
let ix = buffer.text().find('1').unwrap();
buffer.edit([(ix..ix + 1, "100")], None, cx);
});
project_b
.update(cx_b, |project, cx| project.save_buffer(buffer_b, cx))
.await
.unwrap();
assert_eq!(
remote_fs
.load("/code/project1/src/lib.rs".as_ref())
.await
.unwrap(),
"fn one() -> usize { 100 }"
);
}

View File

@@ -19,20 +19,18 @@ use fs::FakeFs;
use futures::{channel::oneshot, StreamExt as _};
use git::GitHostingProviderRegistry;
use gpui::{BackgroundExecutor, Context, Model, Task, TestAppContext, View, VisualTestContext};
use http_client::FakeHttpClient;
use http::FakeHttpClient;
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
use notifications::NotificationStore;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use remote::SshSession;
use rpc::{
proto::{self, ChannelRole},
RECEIVE_TIMEOUT,
};
use semantic_version::SemanticVersion;
use serde_json::json;
use session::Session;
use settings::SettingsStore;
use std::{
cell::{Ref, RefCell, RefMut},
@@ -157,8 +155,6 @@ impl TestServer {
}
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
let fs = FakeFs::new(cx.executor());
cx.update(|cx| {
if cx.has_global::<SettingsStore>() {
panic!("Same cx used to create two test clients")
@@ -267,6 +263,7 @@ impl TestServer {
git_hosting_provider_registry
.register_hosting_provider(Arc::new(git_hosting_providers::Github));
let fs = FakeFs::new(cx.executor());
let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
let workspace_store = cx.new_model(|cx| WorkspaceStore::new(client.clone(), cx));
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
@@ -278,7 +275,6 @@ impl TestServer {
fs: fs.clone(),
build_window_options: |_, _| Default::default(),
node_runtime: FakeNodeRuntime::new(),
session: Session::test(),
});
let os_keymap = "keymaps/default-macos.json";
@@ -298,8 +294,7 @@ impl TestServer {
menu::init();
dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
language_model::LanguageModelRegistry::test(cx);
completion::init(cx);
assistant::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client);
});
@@ -407,7 +402,6 @@ impl TestServer {
fs: fs.clone(),
build_window_options: |_, _| Default::default(),
node_runtime: FakeNodeRuntime::new(),
session: Session::test(),
});
cx.update(|cx| {
@@ -820,30 +814,6 @@ impl TestClient {
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub async fn build_ssh_project(
&self,
root_path: impl AsRef<Path>,
ssh: Arc<SshSession>,
cx: &mut TestAppContext,
) -> (Model<Project>, WorktreeId) {
let project = cx.update(|cx| {
Project::ssh(
ssh,
self.client().clone(),
self.app_state.node_runtime.clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.fs.clone(),
cx,
)
});
let (worktree, _) = project
.update(cx, |p, cx| p.find_or_create_worktree(root_path, true, cx))
.await
.unwrap();
(project, worktree.read_with(cx, |tree, _| tree.id()))
}
pub async fn build_test_project(&self, cx: &mut TestAppContext) -> Model<Project> {
self.fs()
.insert_tree(

View File

@@ -25,14 +25,13 @@ test-support = [
"settings/test-support",
"util/test-support",
"workspace/test-support",
"http_client/test-support",
"http/test-support",
]
[dependencies]
anyhow.workspace = true
call.workspace = true
channel.workspace = true
chrono.workspace = true
client.workspace = true
collections.workspace = true
db.workspace = true
@@ -80,5 +79,5 @@ rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
tree-sitter-markdown.workspace = true
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }

View File

@@ -11,22 +11,20 @@ use editor::{
EditorEvent,
};
use gpui::{
actions, AnyView, AppContext, ClipboardItem, Entity as _, EventEmitter, FocusableView, Model,
Pixels, Point, Render, Subscription, Task, View, ViewContext, VisualContext as _, WeakView,
WindowContext,
actions, AnyElement, AnyView, AppContext, ClipboardItem, Entity as _, EventEmitter,
FocusableView, IntoElement as _, Model, Pixels, Point, Render, Subscription, Task, View,
ViewContext, VisualContext as _, WeakView, WindowContext,
};
use project::Project;
use rpc::proto::ChannelVisibility;
use std::{
any::{Any, TypeId},
sync::Arc,
};
use ui::prelude::*;
use ui::{prelude::*, Label};
use util::ResultExt;
use workspace::item::TabContentParams;
use workspace::{item::Dedup, notifications::NotificationId};
use workspace::{
item::{FollowableItem, Item, ItemEvent, ItemHandle},
item::{FollowableItem, Item, ItemEvent, ItemHandle, TabContentParams},
searchable::SearchableItemHandle,
ItemNavHistory, Pane, SaveIntent, Toast, ViewId, Workspace, WorkspaceId,
};
@@ -387,45 +385,24 @@ impl Item for ChannelView {
}
}
fn tab_icon(&self, cx: &WindowContext) -> Option<Icon> {
let channel = self.channel(cx)?;
let icon = match channel.visibility {
ChannelVisibility::Public => IconName::Public,
ChannelVisibility::Members => IconName::Hash,
};
Some(Icon::new(icon))
}
fn tab_content(&self, params: TabContentParams, cx: &WindowContext) -> gpui::AnyElement {
let (channel_name, status) = if let Some(channel) = self.channel(cx) {
let status = match (
fn tab_content(&self, params: TabContentParams, cx: &WindowContext) -> AnyElement {
let label = if let Some(channel) = self.channel(cx) {
match (
self.channel_buffer.read(cx).buffer().read(cx).read_only(),
self.channel_buffer.read(cx).is_connected(),
) {
(false, true) => None,
(true, true) => Some("read-only"),
(_, false) => Some("disconnected"),
};
(channel.name.clone(), status)
(false, true) => format!("#{}", channel.name),
(true, true) => format!("#{} (read-only)", channel.name),
(_, false) => format!("#{} (disconnected)", channel.name),
}
} else {
("<unknown>".into(), Some("disconnected"))
"channel notes (disconnected)".to_string()
};
h_flex()
.gap_2()
.child(
Label::new(channel_name)
.color(params.text_color())
.italic(params.preview),
)
.when_some(status, |element, status| {
element.child(
Label::new(status)
.size(LabelSize::XSmall)
.color(Color::Muted),
)
Label::new(label)
.color(if params.selected {
Color::Default
} else {
Color::Muted
})
.into_any_element()
}

View File

@@ -111,7 +111,6 @@ impl ChatPanel {
this.is_scrolled_to_bottom = !event.is_scrolled;
}));
let local_offset = chrono::Local::now().offset().local_minus_utc();
let mut this = Self {
fs,
client,
@@ -121,7 +120,7 @@ impl ChatPanel {
active_chat: Default::default(),
pending_serialization: Task::ready(None),
message_editor: input_editor,
local_timezone: UtcOffset::from_whole_seconds(local_offset).unwrap(),
local_timezone: cx.local_timezone(),
subscriptions: Vec::new(),
is_scrolled_to_bottom: true,
active: false,
@@ -1107,11 +1106,9 @@ impl Panel for ChatPanel {
}
fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
settings::update_settings_file::<ChatPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
);
settings::update_settings_file::<ChatPanelSettings>(self.fs.clone(), cx, move |settings| {
settings.dock = Some(position)
});
}
fn size(&self, cx: &gpui::WindowContext) -> Pixels {

View File

@@ -6,7 +6,7 @@ use editor::{AnchorRangeExt, CompletionProvider, Editor, EditorElement, EditorSt
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
AsyncWindowContext, FocusableView, FontStyle, FontWeight, HighlightStyle, IntoElement, Model,
Render, Task, TextStyle, View, ViewContext, WeakView,
Render, Task, TextStyle, View, ViewContext, WeakView, WhiteSpace,
};
use language::{
language_settings::SoftWrap, Anchor, Buffer, BufferSnapshot, CodeLabel, LanguageRegistry,
@@ -537,7 +537,10 @@ impl Render for MessageEditor {
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
div()

View File

@@ -16,7 +16,7 @@ use gpui::{
EventEmitter, FocusHandle, FocusableView, FontStyle, InteractiveElement, IntoElement,
ListOffset, ListState, Model, MouseDownEvent, ParentElement, Pixels, Point, PromptLevel,
Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
WeakView,
WeakView, WhiteSpace,
};
use menu::{Cancel, Confirm, SecondaryConfirm, SelectNext, SelectPrev};
use project::{Fs, Project};
@@ -2194,7 +2194,10 @@ impl CollabPanel {
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
@@ -2806,7 +2809,7 @@ impl Panel for CollabPanel {
settings::update_settings_file::<CollaborationPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
move |settings| settings.dock = Some(position),
);
}

View File

@@ -5,6 +5,7 @@ use gpui::{
};
use picker::{Picker, PickerDelegate};
use std::sync::Arc;
use theme::ActiveTheme as _;
use ui::{prelude::*, Avatar, ListItem, ListItemSpacing};
use util::{ResultExt as _, TryFutureExt};
use workspace::ModalView;

View File

@@ -127,12 +127,11 @@ impl NotificationPanel {
},
));
let local_offset = chrono::Local::now().offset().local_minus_utc();
let mut this = Self {
fs,
client,
user_store,
local_timezone: UtcOffset::from_whole_seconds(local_offset).unwrap(),
local_timezone: cx.local_timezone(),
channel_store: ChannelStore::global(cx),
notification_store: NotificationStore::global(cx),
notification_list,
@@ -672,7 +671,7 @@ impl Panel for NotificationPanel {
settings::update_settings_file::<NotificationPanelSettings>(
self.fs.clone(),
cx,
move |settings, _| settings.dock = Some(position),
move |settings| settings.dock = Some(position),
);
}

View File

@@ -17,10 +17,9 @@ use gpui::{
use picker::{Picker, PickerDelegate};
use postage::{sink::Sink, stream::Stream};
use settings::Settings;
use ui::{h_flex, prelude::*, v_flex, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing};
use util::ResultExt;
use workspace::{ModalView, Workspace, WorkspaceSettings};
use workspace::{ModalView, Workspace};
use zed_actions::OpenZedUrl;
actions!(command_palette, [Toggle]);
@@ -249,13 +248,9 @@ impl PickerDelegate for CommandPaletteDelegate {
fn update_matches(
&mut self,
mut query: String,
query: String,
cx: &mut ViewContext<Picker<Self>>,
) -> gpui::Task<()> {
let settings = WorkspaceSettings::get_global(cx);
if let Some(alias) = settings.command_aliases.get(&query) {
query = alias.to_string();
}
let (mut tx, mut rx) = postage::dispatch::channel(1);
let task = cx.background_executor().spawn({
let mut commands = self.all_commands.clone();
@@ -482,7 +477,7 @@ mod tests {
});
workspace.update(cx, |workspace, cx| {
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, cx);
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
editor.update(cx, |editor, cx| editor.focus(cx))
});

View File

@@ -1,43 +0,0 @@
[package]
name = "completion"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/completion.rs"
doctest = false
[features]
test-support = [
"editor/test-support",
"language/test-support",
"language_model/test-support",
"project/test-support",
"text/test-support",
]
[dependencies]
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
serde.workspace = true
settings.workspace = true
smol.workspace = true
ui.workspace = true
[dev-dependencies]
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
rand.workspace = true
text = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,286 +0,0 @@
use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
use std::{pin::Pin, sync::Arc, task::Poll};
use ui::Context;
pub fn init(cx: &mut AppContext) {
let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
}
struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
impl Global for GlobalLanguageModelCompletionProvider {}
pub struct LanguageModelCompletionProvider {
active_provider: Option<Arc<dyn LanguageModelProvider>>,
active_model: Option<Arc<dyn LanguageModel>>,
request_limiter: Arc<Semaphore>,
}
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
pub struct LanguageModelCompletionResponse {
pub inner: BoxStream<'static, Result<String>>,
_lock: SemaphoreGuardArc,
}
impl futures::Stream for LanguageModelCompletionResponse {
type Item = Result<String>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_next(cx)
}
}
impl LanguageModelCompletionProvider {
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.clone()
}
pub fn read_global(cx: &AppContext) -> &Self {
cx.global::<GlobalLanguageModelCompletionProvider>()
.0
.read(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut AppContext) {
let provider = cx.new_model(|cx| {
let mut this = Self::new(cx);
let available_model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.unwrap()
.clone();
this.set_active_model(available_model, cx);
this
});
cx.set_global(GlobalLanguageModelCompletionProvider(provider));
}
pub fn new(cx: &mut ModelContext<Self>) -> Self {
cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
cx.notify();
})
.detach();
Self {
active_provider: None,
active_model: None,
request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
}
}
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
self.active_provider.clone()
}
pub fn set_active_provider(
&mut self,
provider_id: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
self.active_model = None;
cx.notify();
}
pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
self.active_model.clone()
}
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
if self.active_model.as_ref().map_or(false, |m| {
m.id() == model.id() && m.provider_id() == model.provider_id()
}) {
return;
}
self.active_provider =
LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
self.active_model = Some(model.clone());
if let Some(provider) = self.active_provider.as_ref() {
provider.load_model(model, cx);
}
cx.notify();
}
pub fn is_authenticated(&self, cx: &AppContext) -> bool {
self.active_provider
.as_ref()
.map_or(false, |provider| provider.is_authenticated(cx))
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.active_provider
.as_ref()
.map_or(Task::ready(Ok(())), |provider| {
provider.reset_credentials(cx)
})
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = self.active_model() {
model.count_tokens(request, cx)
} else {
std::future::ready(Err(anyhow!("No active model set"))).boxed()
}
}
pub fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> Task<Result<LanguageModelCompletionResponse>> {
if let Some(language_model) = self.active_model() {
let rate_limiter = self.request_limiter.clone();
cx.spawn(|cx| async move {
let lock = rate_limiter.acquire_arc().await;
let response = language_model.stream_completion(request, &cx).await?;
Ok(LanguageModelCompletionResponse {
inner: response,
_lock: lock,
})
})
} else {
Task::ready(Err(anyhow!("No active model set")))
}
}
pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
let response = self.stream_completion(request, cx);
cx.foreground_executor().spawn(async move {
let mut chunks = response.await?;
let mut completion = String::new();
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
completion.push_str(&chunk);
}
Ok(completion)
})
}
}
#[cfg(test)]
mod tests {
use futures::StreamExt;
use gpui::AppContext;
use settings::SettingsStore;
use ui::Context;
use crate::{
LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
};
use language_model::LanguageModelRegistry;
#[gpui::test]
fn test_rate_limiting(cx: &mut AppContext) {
SettingsStore::test(cx);
let fake_provider = LanguageModelRegistry::test(cx);
let model = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.first()
.cloned()
.unwrap();
let provider = cx.new_model(|cx| {
let mut provider = LanguageModelCompletionProvider::new(cx);
provider.set_active_model(model.clone(), cx);
provider
});
let fake_model = fake_provider.test_model();
// Enqueue some requests
for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
let response = provider.read(cx).stream_completion(
LanguageModelRequest {
temperature: i as f32 / 10.0,
..Default::default()
},
cx,
);
cx.background_executor()
.spawn(async move {
let mut stream = response.await.unwrap();
while let Some(message) = stream.next().await {
message.unwrap();
}
})
.detach();
}
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Get the first completion request that is in flight and mark it as completed.
let completion = fake_model.pending_completions().into_iter().next().unwrap();
fake_model.finish_completion(&completion);
// Ensure that the number of in-flight completion requests is reduced.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
cx.background_executor().run_until_parked();
// Ensure that another completion request was allowed to acquire the lock.
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS
);
// Mark all completion requests as finished that are in flight.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
assert_eq!(fake_model.completion_count(), 0);
// Wait until the background tasks acquire the lock again.
cx.background_executor().run_until_parked();
assert_eq!(
fake_model.completion_count(),
MAX_CONCURRENT_COMPLETION_REQUESTS - 1
);
// Finish all remaining completion requests.
for request in fake_model.pending_completions() {
fake_model.finish_completion(&request);
}
cx.background_executor().run_until_parked();
assert_eq!(fake_model.completion_count(), 0);
}
}

View File

@@ -32,7 +32,7 @@ command_palette_hooks.workspace = true
editor.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
language.workspace = true
lsp.workspace = true
menu.workspace = true
@@ -65,4 +65,4 @@ rpc = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -12,8 +12,8 @@ use gpui::{
actions, AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Global, Model,
ModelContext, Task, WeakModel,
};
use http_client::github::latest_github_release;
use http_client::HttpClient;
use http::github::latest_github_release;
use http::HttpClient;
use language::{
language_settings::{all_language_settings, language_settings, InlineCompletionProvider},
point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16,
@@ -393,7 +393,7 @@ impl Copilot {
Default::default(),
cx.to_async(),
);
let http = http_client::FakeHttpClient::create(|_| async { unreachable!() });
let http = http::FakeHttpClient::create(|_| async { unreachable!() });
let node_runtime = FakeNodeRuntime::new();
let this = cx.new_model(|cx| Self {
server_id: LanguageServerId(0),
@@ -1236,7 +1236,7 @@ mod tests {
unimplemented!()
}
fn to_proto(&self, _: &AppContext) -> rpc::proto::File {
fn to_proto(&self) -> rpc::proto::File {
unimplemented!()
}

View File

@@ -8,7 +8,7 @@ use language::{
Buffer, OffsetRangeExt, ToOffset,
};
use settings::Settings;
use std::{ops::Range, path::Path, sync::Arc, time::Duration};
use std::{path::Path, sync::Arc, time::Duration};
pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
@@ -239,7 +239,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,
) -> Option<(&'a str, Option<Range<language::Anchor>>)> {
) -> Option<&'a str> {
let buffer_id = buffer.entity_id();
let buffer = buffer.read(cx);
let completion = self.active_completion()?;
@@ -269,7 +269,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
if completion_text.trim().is_empty() {
None
} else {
Some((completion_text, None))
Some(completion_text)
}
} else {
None

View File

@@ -4,13 +4,13 @@ mod toolbar_controls;
#[cfg(test)]
mod diagnostics_tests;
pub(crate) mod grouped_diagnostics;
mod grouped_diagnostics;
use anyhow::Result;
use collections::{BTreeSet, HashSet};
use editor::{
diagnostic_block_renderer,
display_map::{BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock},
display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock},
highlight_diagnostic_message,
scroll::Autoscroll,
Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, ToOffset,
@@ -45,7 +45,7 @@ use ui::{h_flex, prelude::*, Icon, IconName, Label};
use util::ResultExt;
use workspace::{
item::{BreadcrumbText, Item, ItemEvent, ItemHandle, TabContentParams},
ItemNavHistory, ToolbarItemLocation, Workspace,
ItemNavHistory, Pane, ToolbarItemLocation, Workspace,
};
actions!(diagnostics, [Deploy, ToggleWarnings]);
@@ -85,7 +85,7 @@ struct DiagnosticGroupState {
primary_diagnostic: DiagnosticEntry<language::Anchor>,
primary_excerpt_ix: usize,
excerpts: Vec<ExcerptId>,
blocks: HashSet<CustomBlockId>,
blocks: HashSet<BlockId>,
block_count: usize,
}
@@ -237,13 +237,13 @@ impl ProjectDiagnosticsEditor {
fn deploy(workspace: &mut Workspace, _: &Deploy, cx: &mut ViewContext<Workspace>) {
if let Some(existing) = workspace.item_of_type::<ProjectDiagnosticsEditor>(cx) {
workspace.activate_item(&existing, true, true, cx);
workspace.activate_item(&existing, cx);
} else {
let workspace_handle = cx.view().downgrade();
let diagnostics = cx.new_view(|cx| {
ProjectDiagnosticsEditor::new(workspace.project().clone(), workspace_handle, cx)
});
workspace.add_item_to_active_pane(Box::new(diagnostics), None, true, cx);
workspace.add_item_to_active_pane(Box::new(diagnostics), None, cx);
}
}
@@ -649,7 +649,11 @@ impl Item for ProjectDiagnosticsEditor {
fn tab_content(&self, params: TabContentParams, _: &WindowContext) -> AnyElement {
if self.summary.error_count == 0 && self.summary.warning_count == 0 {
Label::new("No problems")
.color(params.text_color())
.color(if params.selected {
Color::Default
} else {
Color::Muted
})
.into_any_element()
} else {
h_flex()
@@ -659,10 +663,13 @@ impl Item for ProjectDiagnosticsEditor {
h_flex()
.gap_1()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new(self.summary.error_count.to_string())
.color(params.text_color()),
),
.child(Label::new(self.summary.error_count.to_string()).color(
if params.selected {
Color::Default
} else {
Color::Muted
},
)),
)
})
.when(self.summary.warning_count > 0, |then| {
@@ -670,10 +677,13 @@ impl Item for ProjectDiagnosticsEditor {
h_flex()
.gap_1()
.child(Icon::new(IconName::ExclamationTriangle).color(Color::Warning))
.child(
Label::new(self.summary.warning_count.to_string())
.color(params.text_color()),
),
.child(Label::new(self.summary.warning_count.to_string()).color(
if params.selected {
Color::Default
} else {
Color::Muted
},
)),
)
})
.into_any_element()
@@ -776,6 +786,20 @@ impl Item for ProjectDiagnosticsEditor {
self.editor
.update(cx, |editor, cx| editor.added_to_workspace(workspace, cx));
}
fn serialized_item_kind() -> Option<&'static str> {
Some("diagnostics")
}
fn deserialize(
project: Model<Project>,
workspace: WeakView<Workspace>,
_workspace_id: workspace::WorkspaceId,
_item_id: workspace::ItemId,
cx: &mut ViewContext<Pane>,
) -> Task<Result<View<Self>>> {
Task::ready(Ok(cx.new_view(|cx| Self::new(project, workspace, cx))))
}
}
const DIAGNOSTIC_HEADER: &'static str = "diagnostic header";

View File

@@ -1,7 +1,7 @@
use super::*;
use collections::HashMap;
use editor::{
display_map::{Block, BlockContext, DisplayRow},
display_map::{BlockContext, DisplayRow, TransformBlock},
DisplayPoint, GutterDimensions,
};
use gpui::{px, AvailableSpace, Stateful, TestAppContext, VisualTestContext};
@@ -954,7 +954,6 @@ fn random_diagnostic(
is_primary,
is_disk_based: false,
is_unnecessary: false,
data: None,
},
}
}
@@ -975,9 +974,9 @@ fn editor_blocks(
snapshot
.blocks_in_range(DisplayRow(0)..snapshot.max_point().row())
.filter_map(|(row, block)| {
let block_id = block.id();
let transform_block_id = block.id();
let name: SharedString = match block {
Block::Custom(block) => {
TransformBlock::Custom(block) => {
let mut element = block.render(&mut BlockContext {
context: cx,
anchor_x: px(0.),
@@ -985,7 +984,7 @@ fn editor_blocks(
line_height: px(0.),
em_width: px(0.),
max_width: px(0.),
block_id,
transform_block_id,
editor_style: &editor::EditorStyle::default(),
});
let element = element.downcast_mut::<Stateful<Div>>().unwrap();
@@ -997,7 +996,7 @@ fn editor_blocks(
.ok()?
}
Block::ExcerptHeader {
TransformBlock::ExcerptHeader {
starts_new_buffer, ..
} => {
if *starts_new_buffer {
@@ -1006,7 +1005,7 @@ fn editor_blocks(
EXCERPT_HEADER.into()
}
}
Block::ExcerptFooter { .. } => EXCERPT_FOOTER.into(),
TransformBlock::ExcerptFooter { .. } => EXCERPT_FOOTER.into(),
};
Some((row, name))

View File

@@ -3,8 +3,8 @@ use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use editor::{
diagnostic_block_renderer,
display_map::{
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, CustomBlockId,
RenderBlock,
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
TransformBlockId,
},
scroll::Autoscroll,
Bias, Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, ToPoint,
@@ -40,7 +40,7 @@ use ui::{h_flex, prelude::*, Icon, IconName, Label};
use util::{debug_panic, ResultExt};
use workspace::{
item::{BreadcrumbText, Item, ItemEvent, ItemHandle, TabContentParams},
ItemNavHistory, ToolbarItemLocation, Workspace,
ItemNavHistory, Pane, ToolbarItemLocation, Workspace,
};
use crate::project_diagnostics_settings::ProjectDiagnosticsSettings;
@@ -51,18 +51,18 @@ pub fn init(cx: &mut AppContext) {
.detach();
}
pub struct GroupedDiagnosticsEditor {
pub project: Model<Project>,
struct GroupedDiagnosticsEditor {
project: Model<Project>,
workspace: WeakView<Workspace>,
focus_handle: FocusHandle,
editor: View<Editor>,
summary: DiagnosticSummary,
excerpts: Model<MultiBuffer>,
path_states: Vec<PathState>,
pub paths_to_update: BTreeSet<(ProjectPath, LanguageServerId)>,
pub include_warnings: bool,
paths_to_update: BTreeSet<(ProjectPath, LanguageServerId)>,
include_warnings: bool,
context: u32,
pub update_paths_tx: UnboundedSender<(ProjectPath, Option<LanguageServerId>)>,
update_paths_tx: UnboundedSender<(ProjectPath, Option<LanguageServerId>)>,
_update_excerpts_task: Task<Result<()>>,
_subscription: Subscription,
}
@@ -71,7 +71,7 @@ struct PathState {
path: ProjectPath,
first_excerpt_id: Option<ExcerptId>,
last_excerpt_id: Option<ExcerptId>,
diagnostics: Vec<(DiagnosticData, CustomBlockId)>,
diagnostics: Vec<(DiagnosticData, BlockId)>,
}
#[derive(Debug, Clone)]
@@ -250,17 +250,17 @@ impl GroupedDiagnosticsEditor {
fn deploy(workspace: &mut Workspace, _: &Deploy, cx: &mut ViewContext<Workspace>) {
if let Some(existing) = workspace.item_of_type::<GroupedDiagnosticsEditor>(cx) {
workspace.activate_item(&existing, true, true, cx);
workspace.activate_item(&existing, cx);
} else {
let workspace_handle = cx.view().downgrade();
let diagnostics = cx.new_view(|cx| {
GroupedDiagnosticsEditor::new(workspace.project().clone(), workspace_handle, cx)
});
workspace.add_item_to_active_pane(Box::new(diagnostics), None, true, cx);
workspace.add_item_to_active_pane(Box::new(diagnostics), None, cx);
}
}
pub fn toggle_warnings(&mut self, _: &ToggleWarnings, cx: &mut ViewContext<Self>) {
fn toggle_warnings(&mut self, _: &ToggleWarnings, cx: &mut ViewContext<Self>) {
self.include_warnings = !self.include_warnings;
self.enqueue_update_all_excerpts(cx);
cx.notify();
@@ -297,7 +297,7 @@ impl GroupedDiagnosticsEditor {
/// to have changed. If a language server id is passed, then only the excerpts for
/// that language server's diagnostics will be updated. Otherwise, all stale excerpts
/// will be refreshed.
pub fn enqueue_update_stale_excerpts(&mut self, language_server_id: Option<LanguageServerId>) {
fn enqueue_update_stale_excerpts(&mut self, language_server_id: Option<LanguageServerId>) {
for (path, server_id) in &self.paths_to_update {
if language_server_id.map_or(true, |id| id == *server_id) {
self.update_paths_tx
@@ -319,8 +319,8 @@ impl GroupedDiagnosticsEditor {
|| server_to_update.map_or(false, |to_update| *server_id != to_update)
});
// TODO change selections as in the old panel, to the next primary diagnostics
// TODO make [shift-]f8 to work, jump to the next block group
// TODO kb change selections as in the old panel, to the next primary diagnostics
// TODO kb make [shift-]f8 to work, jump to the next block group
let _was_empty = self.path_states.is_empty();
let path_ix = match self.path_states.binary_search_by(|probe| {
project::compare_paths((&probe.path.path, true), (&path_to_update.path, true))
@@ -340,6 +340,7 @@ impl GroupedDiagnosticsEditor {
}
};
// TODO kb when warnings are turned off, there's a lot of refresh for many paths happening, why?
let max_severity = if self.include_warnings {
DiagnosticSeverity::WARNING
} else {
@@ -465,7 +466,11 @@ impl Item for GroupedDiagnosticsEditor {
fn tab_content(&self, params: TabContentParams, _: &WindowContext) -> AnyElement {
if self.summary.error_count == 0 && self.summary.warning_count == 0 {
Label::new("No problems")
.color(params.text_color())
.color(if params.selected {
Color::Default
} else {
Color::Muted
})
.into_any_element()
} else {
h_flex()
@@ -475,10 +480,13 @@ impl Item for GroupedDiagnosticsEditor {
h_flex()
.gap_1()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new(self.summary.error_count.to_string())
.color(params.text_color()),
),
.child(Label::new(self.summary.error_count.to_string()).color(
if params.selected {
Color::Default
} else {
Color::Muted
},
)),
)
})
.when(self.summary.warning_count > 0, |then| {
@@ -486,10 +494,13 @@ impl Item for GroupedDiagnosticsEditor {
h_flex()
.gap_1()
.child(Icon::new(IconName::ExclamationTriangle).color(Color::Warning))
.child(
Label::new(self.summary.warning_count.to_string())
.color(params.text_color()),
),
.child(Label::new(self.summary.warning_count.to_string()).color(
if params.selected {
Color::Default
} else {
Color::Muted
},
)),
)
})
.into_any_element()
@@ -592,6 +603,20 @@ impl Item for GroupedDiagnosticsEditor {
self.editor
.update(cx, |editor, cx| editor.added_to_workspace(workspace, cx));
}
fn serialized_item_kind() -> Option<&'static str> {
Some("diagnostics")
}
fn deserialize(
project: Model<Project>,
workspace: WeakView<Workspace>,
_workspace_id: workspace::WorkspaceId,
_item_id: workspace::ItemId,
cx: &mut ViewContext<Pane>,
) -> Task<Result<View<Self>>> {
Task::ready(Ok(cx.new_view(|cx| Self::new(project, workspace, cx))))
}
}
fn compare_data_locations(
@@ -632,7 +657,7 @@ fn compare_diagnostic_ranges(
})
}
// TODO wrong? What to do here instead?
// TODO kb wrong? What to do here instead?
fn compare_diagnostic_range_edges(
old: &Range<language::Anchor>,
new: &Range<language::Anchor>,
@@ -657,10 +682,10 @@ fn compare_diagnostic_range_edges(
struct PathUpdate {
path_excerpts_borders: (Option<ExcerptId>, Option<ExcerptId>),
latest_excerpt_id: ExcerptId,
new_diagnostics: Vec<(DiagnosticData, Option<CustomBlockId>)>,
new_diagnostics: Vec<(DiagnosticData, Option<BlockId>)>,
diagnostics_by_row_label: BTreeMap<MultiBufferRow, (editor::Anchor, Vec<usize>)>,
blocks_to_remove: HashSet<CustomBlockId>,
unchanged_blocks: HashMap<usize, CustomBlockId>,
blocks_to_remove: HashSet<BlockId>,
unchanged_blocks: HashMap<usize, BlockId>,
excerpts_with_new_diagnostics: HashSet<ExcerptId>,
excerpts_to_remove: Vec<ExcerptId>,
excerpt_expands: HashMap<(ExpandExcerptDirection, u32), Vec<ExcerptId>>,
@@ -749,7 +774,7 @@ impl PathUpdate {
context: u32,
multi_buffer_snapshot: MultiBufferSnapshot,
buffer_snapshot: BufferSnapshot,
current_diagnostics: impl Iterator<Item = &'a (DiagnosticData, CustomBlockId)> + 'a,
current_diagnostics: impl Iterator<Item = &'a (DiagnosticData, BlockId)> + 'a,
) {
let mut current_diagnostics = current_diagnostics.fuse().peekable();
let mut excerpts_to_expand =
@@ -1234,10 +1259,7 @@ impl PathUpdate {
.collect()
}
fn new_blocks(
mut self,
new_block_ids: Vec<CustomBlockId>,
) -> Vec<(DiagnosticData, CustomBlockId)> {
fn new_blocks(mut self, new_block_ids: Vec<BlockId>) -> Vec<(DiagnosticData, BlockId)> {
let mut new_block_ids = new_block_ids.into_iter().fuse();
for (_, (_, grouped_diagnostics)) in self.diagnostics_by_row_label {
let mut created_block_id = None;
@@ -1288,8 +1310,8 @@ fn render_same_line_diagnostics(
folded_block_height: u8,
) -> RenderBlock {
Box::new(move |cx: &mut BlockContext| {
let block_id = match cx.block_id {
BlockId::Custom(block_id) => block_id,
let block_id = match cx.transform_block_id {
TransformBlockId::Block(block_id) => block_id,
_ => {
debug_panic!("Expected a block id for the diagnostics block");
return div().into_any_element();
@@ -1318,67 +1340,59 @@ fn render_same_line_diagnostics(
.map(|diagnostic| diagnostic_text_lines(diagnostic))
.sum::<u8>();
let editor_handle = editor_handle.clone();
let parent = h_flex()
.items_start()
.child(v_flex().size_full().when_some_else(
toggle_expand_label,
|parent, label| {
parent.child(Button::new(cx.block_id, label).on_click({
let diagnostics = Arc::clone(&diagnostics);
move |_, cx| {
let new_expanded = !expanded;
button_expanded.store(new_expanded, atomic::Ordering::Release);
let new_size = if new_expanded {
expanded_block_height
} else {
folded_block_height
};
editor_handle.update(cx, |editor, cx| {
editor.replace_blocks(
HashMap::from_iter(Some((
block_id,
(
Some(new_size),
render_same_line_diagnostics(
Arc::clone(&button_expanded),
Arc::clone(&diagnostics),
editor_handle.clone(),
folded_block_height,
let mut parent = v_flex();
let mut diagnostics_iter = diagnostics.iter().fuse();
if let Some(first_diagnostic) = diagnostics_iter.next() {
let mut renderer = diagnostic_block_renderer(
first_diagnostic.clone(),
Some(folded_block_height),
false,
true,
);
parent = parent.child(
h_flex()
.when_some(toggle_expand_label, |parent, label| {
parent.child(Button::new(cx.transform_block_id, label).on_click({
let diagnostics = Arc::clone(&diagnostics);
move |_, cx| {
let new_expanded = !expanded;
button_expanded.store(new_expanded, atomic::Ordering::Release);
let new_size = if new_expanded {
expanded_block_height
} else {
folded_block_height
};
editor_handle.update(cx, |editor, cx| {
editor.replace_blocks(
HashMap::from_iter(Some((
block_id,
(
Some(new_size),
render_same_line_diagnostics(
Arc::clone(&button_expanded),
Arc::clone(&diagnostics),
editor_handle.clone(),
folded_block_height,
),
),
),
))),
None,
cx,
)
});
}
}))
},
|parent| {
parent.child(
h_flex()
.size(IconSize::default().rems())
.invisible()
.flex_none(),
)
},
));
let max_message_rows = if expanded {
None
} else {
Some(folded_block_height)
};
let mut renderer =
diagnostic_block_renderer(first_diagnostic.clone(), max_message_rows, false, true);
let mut diagnostics_element = v_flex();
diagnostics_element = diagnostics_element.child(renderer(cx));
))),
None,
cx,
)
});
}
}))
})
.child(renderer(cx)),
);
}
if expanded {
for diagnostic in diagnostics.iter().skip(1) {
for diagnostic in diagnostics_iter {
let mut renderer = diagnostic_block_renderer(diagnostic.clone(), None, false, true);
diagnostics_element = diagnostics_element.child(renderer(cx));
parent = parent.child(renderer(cx));
}
}
parent.child(diagnostics_element).into_any_element()
parent.into_any_element()
})
}

View File

@@ -1,12 +1,11 @@
use crate::{grouped_diagnostics::GroupedDiagnosticsEditor, ProjectDiagnosticsEditor};
use futures::future::Either;
use gpui::{EventEmitter, ParentElement, Render, View, ViewContext, WeakView};
use crate::ProjectDiagnosticsEditor;
use gpui::{EventEmitter, ParentElement, Render, ViewContext, WeakView};
use ui::prelude::*;
use ui::{IconButton, IconName, Tooltip};
use workspace::{item::ItemHandle, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView};
pub struct ToolbarControls {
editor: Option<Either<WeakView<ProjectDiagnosticsEditor>, WeakView<GroupedDiagnosticsEditor>>>,
editor: Option<WeakView<ProjectDiagnosticsEditor>>,
}
impl Render for ToolbarControls {
@@ -15,33 +14,18 @@ impl Render for ToolbarControls {
let mut has_stale_excerpts = false;
let mut is_updating = false;
if let Some(editor) = self.editor() {
match editor {
Either::Left(editor) => {
let editor = editor.read(cx);
include_warnings = editor.include_warnings;
has_stale_excerpts = !editor.paths_to_update.is_empty();
is_updating = editor.update_paths_tx.len() > 0
|| editor
.project
.read(cx)
.language_servers_running_disk_based_diagnostics()
.next()
.is_some();
}
Either::Right(editor) => {
let editor = editor.read(cx);
include_warnings = editor.include_warnings;
has_stale_excerpts = !editor.paths_to_update.is_empty();
is_updating = editor.update_paths_tx.len() > 0
|| editor
.project
.read(cx)
.language_servers_running_disk_based_diagnostics()
.next()
.is_some();
}
}
if let Some(editor) = self.editor.as_ref().and_then(|editor| editor.upgrade()) {
let editor = editor.read(cx);
include_warnings = editor.include_warnings;
has_stale_excerpts = !editor.paths_to_update.is_empty();
is_updating = editor.update_paths_tx.len() > 0
|| editor
.project
.read(cx)
.language_servers_running_disk_based_diagnostics()
.next()
.is_some();
}
let tooltip = if include_warnings {
@@ -58,19 +42,12 @@ impl Render for ToolbarControls {
.disabled(is_updating)
.tooltip(move |cx| Tooltip::text("Update excerpts", cx))
.on_click(cx.listener(|this, _, cx| {
if let Some(editor) = this.editor() {
match editor {
Either::Left(editor) => {
editor.update(cx, |editor, _| {
editor.enqueue_update_stale_excerpts(None);
});
}
Either::Right(editor) => {
editor.update(cx, |editor, _| {
editor.enqueue_update_stale_excerpts(None);
});
}
}
if let Some(editor) =
this.editor.as_ref().and_then(|editor| editor.upgrade())
{
editor.update(cx, |editor, _| {
editor.enqueue_update_stale_excerpts(None);
});
}
})),
)
@@ -79,19 +56,12 @@ impl Render for ToolbarControls {
IconButton::new("toggle-warnings", IconName::ExclamationTriangle)
.tooltip(move |cx| Tooltip::text(tooltip, cx))
.on_click(cx.listener(|this, _, cx| {
if let Some(editor) = this.editor() {
match editor {
Either::Left(editor) => {
editor.update(cx, |editor, cx| {
editor.toggle_warnings(&Default::default(), cx);
});
}
Either::Right(editor) => {
editor.update(cx, |editor, cx| {
editor.toggle_warnings(&Default::default(), cx);
});
}
}
if let Some(editor) =
this.editor.as_ref().and_then(|editor| editor.upgrade())
{
editor.update(cx, |editor, cx| {
editor.toggle_warnings(&Default::default(), cx);
});
}
})),
)
@@ -108,10 +78,7 @@ impl ToolbarItemView for ToolbarControls {
) -> ToolbarItemLocation {
if let Some(pane_item) = active_pane_item.as_ref() {
if let Some(editor) = pane_item.downcast::<ProjectDiagnosticsEditor>() {
self.editor = Some(Either::Left(editor.downgrade()));
ToolbarItemLocation::PrimaryRight
} else if let Some(editor) = pane_item.downcast::<GroupedDiagnosticsEditor>() {
self.editor = Some(Either::Right(editor.downgrade()));
self.editor = Some(editor.downgrade());
ToolbarItemLocation::PrimaryRight
} else {
ToolbarItemLocation::Hidden
@@ -126,13 +93,4 @@ impl ToolbarControls {
pub fn new() -> Self {
ToolbarControls { editor: None }
}
fn editor(
&self,
) -> Option<Either<View<ProjectDiagnosticsEditor>, View<GroupedDiagnosticsEditor>>> {
Some(match self.editor.as_ref()? {
Either::Left(diagnostics) => Either::Left(diagnostics.upgrade()?),
Either::Right(grouped_diagnostics) => Either::Right(grouped_diagnostics.upgrade()?),
})
}
}

View File

@@ -28,22 +28,20 @@ test-support = [
]
[dependencies]
aho-corasick.workspace = true
aho-corasick = "1.1"
anyhow.workspace = true
assets.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
convert_case = "0.6.0"
db.workspace = true
emojis.workspace = true
file_icons.workspace = true
futures.workspace = true
fuzzy.workspace = true
git.workspace = true
gpui.workspace = true
http_client.workspace = true
http.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
@@ -98,4 +96,4 @@ tree-sitter-typescript.workspace = true
unindent.workspace = true
util = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
http = { workspace = true, features = ["test-support"] }

View File

@@ -129,7 +129,7 @@ pub struct ExpandExcerptsDown {
#[derive(PartialEq, Clone, Deserialize, Default)]
pub struct ShowCompletions {
#[serde(default)]
pub(super) trigger: Option<String>,
pub(super) trigger: Option<char>,
}
impl_actions!(

View File

@@ -2,14 +2,18 @@ use futures::Future;
use git::blame::BlameEntry;
use git::Oid;
use gpui::{
Asset, ClipboardItem, Element, ParentElement, Render, ScrollHandle, StatefulInteractiveElement,
WeakView, WindowContext,
Asset, Element, ParentElement, Render, ScrollHandle, StatefulInteractiveElement, WeakView,
WindowContext,
};
use settings::Settings;
use std::hash::Hash;
use theme::ThemeSettings;
use time::UtcOffset;
use ui::{prelude::*, tooltip_container, Avatar};
use theme::{ActiveTheme, ThemeSettings};
use ui::{
div, h_flex, tooltip_container, v_flex, Avatar, Button, ButtonStyle, Clickable as _, Color,
FluentBuilder, Icon, IconName, IconPosition, InteractiveElement as _, IntoElement,
SharedString, Styled as _, ViewContext,
};
use ui::{ButtonCommon, Disableable as _};
use workspace::Workspace;
use crate::git::blame::{CommitDetails, GitRemote};
@@ -125,8 +129,7 @@ impl Render for BlameEntryTooltip {
let author_email = self.blame_entry.author_mail.clone();
let short_commit_id = self.blame_entry.sha.display_short();
let full_sha = self.blame_entry.sha.to_string().clone();
let absolute_timestamp = blame_entry_absolute_timestamp(&self.blame_entry);
let absolute_timestamp = blame_entry_absolute_timestamp(&self.blame_entry, cx);
let message = self
.details
@@ -236,16 +239,6 @@ impl Render for BlameEntryTooltip {
})
},
),
)
.child(
IconButton::new("copy-sha-button", IconName::Copy)
.icon_color(Color::Muted)
.on_click(move |_, cx| {
cx.stop_propagation();
cx.write_to_clipboard(ClipboardItem::new(
full_sha.clone(),
))
}),
),
),
),
@@ -254,25 +247,30 @@ impl Render for BlameEntryTooltip {
}
}
fn blame_entry_timestamp(blame_entry: &BlameEntry, format: time_format::TimestampFormat) -> String {
fn blame_entry_timestamp(
blame_entry: &BlameEntry,
format: time_format::TimestampFormat,
cx: &WindowContext,
) -> String {
match blame_entry.author_offset_date_time() {
Ok(timestamp) => {
let local = chrono::Local::now().offset().local_minus_utc();
time_format::format_localized_timestamp(
timestamp,
time::OffsetDateTime::now_utc(),
UtcOffset::from_whole_seconds(local).unwrap(),
format,
)
}
Ok(timestamp) => time_format::format_localized_timestamp(
timestamp,
time::OffsetDateTime::now_utc(),
cx.local_timezone(),
format,
),
Err(_) => "Error parsing date".to_string(),
}
}
pub fn blame_entry_relative_timestamp(blame_entry: &BlameEntry) -> String {
blame_entry_timestamp(blame_entry, time_format::TimestampFormat::Relative)
pub fn blame_entry_relative_timestamp(blame_entry: &BlameEntry, cx: &WindowContext) -> String {
blame_entry_timestamp(blame_entry, time_format::TimestampFormat::Relative, cx)
}
fn blame_entry_absolute_timestamp(blame_entry: &BlameEntry) -> String {
blame_entry_timestamp(blame_entry, time_format::TimestampFormat::MediumAbsolute)
fn blame_entry_absolute_timestamp(blame_entry: &BlameEntry, cx: &WindowContext) -> String {
blame_entry_timestamp(
blame_entry,
time_format::TimestampFormat::MediumAbsolute,
cx,
)
}

View File

@@ -28,8 +28,9 @@ use crate::{
hover_links::InlayHighlight, movement::TextLayoutDetails, EditorStyle, InlayId, RowExt,
};
pub use block_map::{
Block, BlockBufferRows, BlockChunks as DisplayChunks, BlockContext, BlockDisposition, BlockId,
BlockMap, BlockPoint, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
BlockBufferRows, BlockChunks as DisplayChunks, BlockContext, BlockDisposition, BlockId,
BlockMap, BlockPoint, BlockProperties, BlockStyle, RenderBlock, TransformBlock,
TransformBlockId,
};
use block_map::{BlockRow, BlockSnapshot};
use collections::{HashMap, HashSet};
@@ -269,7 +270,7 @@ impl DisplayMap {
&mut self,
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
cx: &mut ModelContext<Self>,
) -> Vec<CustomBlockId> {
) -> Vec<BlockId> {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
let tab_size = Self::tab_size(&self.buffer, cx);
@@ -285,7 +286,7 @@ impl DisplayMap {
pub fn replace_blocks(
&mut self,
heights_and_renderers: HashMap<CustomBlockId, (Option<u8>, RenderBlock)>,
heights_and_renderers: HashMap<BlockId, (Option<u8>, RenderBlock)>,
cx: &mut ModelContext<Self>,
) {
//
@@ -306,8 +307,8 @@ impl DisplayMap {
// directly and the new behavior separately.
//
//
let mut only_renderers = HashMap::<CustomBlockId, RenderBlock>::default();
let mut full_replace = HashMap::<CustomBlockId, (u8, RenderBlock)>::default();
let mut only_renderers = HashMap::<BlockId, RenderBlock>::default();
let mut full_replace = HashMap::<BlockId, (u8, RenderBlock)>::default();
for (id, (height, render)) in heights_and_renderers {
if let Some(height) = height {
full_replace.insert(id, (height, render));
@@ -334,7 +335,7 @@ impl DisplayMap {
block_map.replace(full_replace);
}
pub fn remove_blocks(&mut self, ids: HashSet<CustomBlockId>, cx: &mut ModelContext<Self>) {
pub fn remove_blocks(&mut self, ids: HashSet<BlockId>, cx: &mut ModelContext<Self>) {
let snapshot = self.buffer.read(cx).snapshot(cx);
let edits = self.buffer_subscription.consume().into_inner();
let tab_size = Self::tab_size(&self.buffer, cx);
@@ -350,7 +351,7 @@ impl DisplayMap {
pub fn row_for_block(
&mut self,
block_id: CustomBlockId,
block_id: BlockId,
cx: &mut ModelContext<Self>,
) -> Option<DisplayRow> {
let snapshot = self.buffer.read(cx).snapshot(cx);
@@ -885,16 +886,12 @@ impl DisplaySnapshot {
pub fn blocks_in_range(
&self,
rows: Range<DisplayRow>,
) -> impl Iterator<Item = (DisplayRow, &Block)> {
) -> impl Iterator<Item = (DisplayRow, &TransformBlock)> {
self.block_snapshot
.blocks_in_range(rows.start.0..rows.end.0)
.map(|(row, block)| (DisplayRow(row), block))
}
pub fn block_for_id(&self, id: BlockId) -> Option<Block> {
self.block_snapshot.block_for_id(id)
}
pub fn intersects_fold<T: ToOffset>(&self, offset: T) -> bool {
self.fold_snapshot.intersects_fold(offset)
}

View File

@@ -18,7 +18,7 @@ use std::{
Arc,
},
};
use sum_tree::{Bias, SumTree, TreeMap};
use sum_tree::{Bias, SumTree};
use text::Edit;
use ui::ElementId;
@@ -30,8 +30,7 @@ const NEWLINES: &[u8] = &[b'\n'; u8::MAX as usize];
pub struct BlockMap {
next_block_id: AtomicUsize,
wrap_snapshot: RefCell<WrapSnapshot>,
custom_blocks: Vec<Arc<CustomBlock>>,
custom_blocks_by_id: TreeMap<CustomBlockId, Arc<CustomBlock>>,
blocks: Vec<Arc<Block>>,
transforms: RefCell<SumTree<Transform>>,
show_excerpt_controls: bool,
buffer_header_height: u8,
@@ -40,7 +39,7 @@ pub struct BlockMap {
}
pub struct BlockMapReader<'a> {
blocks: &'a Vec<Arc<CustomBlock>>,
blocks: &'a Vec<Arc<Block>>,
pub snapshot: BlockSnapshot,
}
@@ -50,13 +49,12 @@ pub struct BlockMapWriter<'a>(&'a mut BlockMap);
pub struct BlockSnapshot {
wrap_snapshot: WrapSnapshot,
transforms: SumTree<Transform>,
custom_blocks_by_id: TreeMap<CustomBlockId, Arc<CustomBlock>>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CustomBlockId(usize);
pub struct BlockId(usize);
impl Into<ElementId> for CustomBlockId {
impl Into<ElementId> for BlockId {
fn into(self) -> ElementId {
ElementId::Integer(self.0)
}
@@ -73,8 +71,8 @@ struct WrapRow(u32);
pub type RenderBlock = Box<dyn Send + FnMut(&mut BlockContext) -> AnyElement>;
pub struct CustomBlock {
id: CustomBlockId,
pub struct Block {
id: BlockId,
position: Anchor,
height: u8,
style: BlockStyle,
@@ -115,41 +113,41 @@ pub struct BlockContext<'a, 'b> {
pub gutter_dimensions: &'b GutterDimensions,
pub em_width: Pixels,
pub line_height: Pixels,
pub block_id: BlockId,
pub transform_block_id: TransformBlockId,
pub editor_style: &'b EditorStyle,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlockId {
Custom(CustomBlockId),
pub enum TransformBlockId {
Block(BlockId),
ExcerptHeader(ExcerptId),
ExcerptFooter(ExcerptId),
}
impl From<BlockId> for EntityId {
fn from(value: BlockId) -> Self {
impl From<TransformBlockId> for EntityId {
fn from(value: TransformBlockId) -> Self {
match value {
BlockId::Custom(CustomBlockId(id)) => EntityId::from(id as u64),
BlockId::ExcerptHeader(id) => id.into(),
BlockId::ExcerptFooter(id) => id.into(),
TransformBlockId::Block(BlockId(id)) => EntityId::from(id as u64),
TransformBlockId::ExcerptHeader(id) => id.into(),
TransformBlockId::ExcerptFooter(id) => id.into(),
}
}
}
impl From<BlockId> for ElementId {
fn from(value: BlockId) -> Self {
match value {
BlockId::Custom(CustomBlockId(id)) => ("Block", id).into(),
BlockId::ExcerptHeader(id) => ("ExcerptHeader", EntityId::from(id)).into(),
BlockId::ExcerptFooter(id) => ("ExcerptFooter", EntityId::from(id)).into(),
impl Into<ElementId> for TransformBlockId {
fn into(self) -> ElementId {
match self {
Self::Block(BlockId(id)) => ("Block", id).into(),
Self::ExcerptHeader(id) => ("ExcerptHeader", EntityId::from(id)).into(),
Self::ExcerptFooter(id) => ("ExcerptFooter", EntityId::from(id)).into(),
}
}
}
impl std::fmt::Display for BlockId {
impl std::fmt::Display for TransformBlockId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Custom(id) => write!(f, "Block({id:?})"),
Self::Block(id) => write!(f, "Block({id:?})"),
Self::ExcerptHeader(id) => write!(f, "ExcerptHeader({id:?})"),
Self::ExcerptFooter(id) => write!(f, "ExcerptFooter({id:?})"),
}
@@ -166,11 +164,11 @@ pub enum BlockDisposition {
#[derive(Clone, Debug)]
struct Transform {
summary: TransformSummary,
block: Option<Block>,
block: Option<TransformBlock>,
}
pub(crate) enum BlockType {
Custom(CustomBlockId),
Custom(BlockId),
Header,
Footer,
}
@@ -182,8 +180,8 @@ pub(crate) trait BlockLike {
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum Block {
Custom(Arc<CustomBlock>),
pub enum TransformBlock {
Custom(Arc<Block>),
ExcerptHeader {
id: ExcerptId,
buffer: BufferSnapshot,
@@ -199,12 +197,12 @@ pub enum Block {
},
}
impl BlockLike for Block {
impl BlockLike for TransformBlock {
fn block_type(&self) -> BlockType {
match self {
Block::Custom(block) => BlockType::Custom(block.id),
Block::ExcerptHeader { .. } => BlockType::Header,
Block::ExcerptFooter { .. } => BlockType::Footer,
TransformBlock::Custom(block) => BlockType::Custom(block.id),
TransformBlock::ExcerptHeader { .. } => BlockType::Header,
TransformBlock::ExcerptFooter { .. } => BlockType::Footer,
}
}
@@ -213,41 +211,33 @@ impl BlockLike for Block {
}
}
impl Block {
pub fn id(&self) -> BlockId {
impl TransformBlock {
pub fn id(&self) -> TransformBlockId {
match self {
Block::Custom(block) => BlockId::Custom(block.id),
Block::ExcerptHeader { id, .. } => BlockId::ExcerptHeader(*id),
Block::ExcerptFooter { id, .. } => BlockId::ExcerptFooter(*id),
TransformBlock::Custom(block) => TransformBlockId::Block(block.id),
TransformBlock::ExcerptHeader { id, .. } => TransformBlockId::ExcerptHeader(*id),
TransformBlock::ExcerptFooter { id, .. } => TransformBlockId::ExcerptFooter(*id),
}
}
fn disposition(&self) -> BlockDisposition {
match self {
Block::Custom(block) => block.disposition,
Block::ExcerptHeader { .. } => BlockDisposition::Above,
Block::ExcerptFooter { disposition, .. } => *disposition,
TransformBlock::Custom(block) => block.disposition,
TransformBlock::ExcerptHeader { .. } => BlockDisposition::Above,
TransformBlock::ExcerptFooter { disposition, .. } => *disposition,
}
}
pub fn height(&self) -> u8 {
match self {
Block::Custom(block) => block.height,
Block::ExcerptHeader { height, .. } => *height,
Block::ExcerptFooter { height, .. } => *height,
}
}
pub fn style(&self) -> BlockStyle {
match self {
Block::Custom(block) => block.style,
Block::ExcerptHeader { .. } => BlockStyle::Sticky,
Block::ExcerptFooter { .. } => BlockStyle::Sticky,
TransformBlock::Custom(block) => block.height,
TransformBlock::ExcerptHeader { height, .. } => *height,
TransformBlock::ExcerptFooter { height, .. } => *height,
}
}
}
impl Debug for Block {
impl Debug for TransformBlock {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Custom(block) => f.debug_struct("Custom").field("block", block).finish(),
@@ -262,7 +252,7 @@ impl Debug for Block {
.field("path", &buffer.file().map(|f| f.path()))
.field("starts_new_buffer", &starts_new_buffer)
.finish(),
Block::ExcerptFooter {
TransformBlock::ExcerptFooter {
id, disposition, ..
} => f
.debug_struct("ExcerptFooter")
@@ -306,8 +296,7 @@ impl BlockMap {
let row_count = wrap_snapshot.max_point().row() + 1;
let map = Self {
next_block_id: AtomicUsize::new(0),
custom_blocks: Vec::new(),
custom_blocks_by_id: TreeMap::default(),
blocks: Vec::new(),
transforms: RefCell::new(SumTree::from_item(Transform::isomorphic(row_count), &())),
wrap_snapshot: RefCell::new(wrap_snapshot.clone()),
show_excerpt_controls,
@@ -329,11 +318,10 @@ impl BlockMap {
self.sync(&wrap_snapshot, edits);
*self.wrap_snapshot.borrow_mut() = wrap_snapshot.clone();
BlockMapReader {
blocks: &self.custom_blocks,
blocks: &self.blocks,
snapshot: BlockSnapshot {
wrap_snapshot,
transforms: self.transforms.borrow().clone(),
custom_blocks_by_id: self.custom_blocks_by_id.clone(),
},
}
}
@@ -455,26 +443,25 @@ impl BlockMap {
let new_buffer_start =
wrap_snapshot.to_point(WrapPoint::new(new_start.0, 0), Bias::Left);
let start_bound = Bound::Included(new_buffer_start);
let start_block_ix =
match self.custom_blocks[last_block_ix..].binary_search_by(|probe| {
probe
.position
.to_point(buffer)
.cmp(&new_buffer_start)
.then(Ordering::Greater)
}) {
Ok(ix) | Err(ix) => last_block_ix + ix,
};
let start_block_ix = match self.blocks[last_block_ix..].binary_search_by(|probe| {
probe
.position
.to_point(buffer)
.cmp(&new_buffer_start)
.then(Ordering::Greater)
}) {
Ok(ix) | Err(ix) => last_block_ix + ix,
};
let end_bound;
let end_block_ix = if new_end.0 > wrap_snapshot.max_point().row() {
end_bound = Bound::Unbounded;
self.custom_blocks.len()
self.blocks.len()
} else {
let new_buffer_end =
wrap_snapshot.to_point(WrapPoint::new(new_end.0, 0), Bias::Left);
end_bound = Bound::Excluded(new_buffer_end);
match self.custom_blocks[start_block_ix..].binary_search_by(|probe| {
match self.blocks[start_block_ix..].binary_search_by(|probe| {
probe
.position
.to_point(buffer)
@@ -487,22 +474,24 @@ impl BlockMap {
last_block_ix = end_block_ix;
debug_assert!(blocks_in_edit.is_empty());
blocks_in_edit.extend(self.custom_blocks[start_block_ix..end_block_ix].iter().map(
|block| {
let mut position = block.position.to_point(buffer);
match block.disposition {
BlockDisposition::Above => position.column = 0,
BlockDisposition::Below => {
position.column = buffer.line_len(MultiBufferRow(position.row))
blocks_in_edit.extend(
self.blocks[start_block_ix..end_block_ix]
.iter()
.map(|block| {
let mut position = block.position.to_point(buffer);
match block.disposition {
BlockDisposition::Above => position.column = 0,
BlockDisposition::Below => {
position.column = buffer.line_len(MultiBufferRow(position.row))
}
}
}
let position = wrap_snapshot.make_wrap_point(position, Bias::Left);
(position.row(), Block::Custom(block.clone()))
},
));
let position = wrap_snapshot.make_wrap_point(position, Bias::Left);
(position.row(), TransformBlock::Custom(block.clone()))
}),
);
if buffer.show_headers() {
blocks_in_edit.extend(BlockMap::header_and_footer_blocks(
blocks_in_edit.extend(BlockMap::header_blocks(
self.show_excerpt_controls,
self.excerpt_footer_height,
self.buffer_header_height,
@@ -549,8 +538,8 @@ impl BlockMap {
*transforms = new_transforms;
}
pub fn replace_renderers(&mut self, mut renderers: HashMap<CustomBlockId, RenderBlock>) {
for block in &mut self.custom_blocks {
pub fn replace_renderers(&mut self, mut renderers: HashMap<BlockId, RenderBlock>) {
for block in &mut self.blocks {
if let Some(render) = renderers.remove(&block.id) {
*block.render.lock() = render;
}
@@ -561,7 +550,7 @@ impl BlockMap {
self.show_excerpt_controls
}
pub fn header_and_footer_blocks<'a, 'b: 'a, 'c: 'a + 'b, R, T>(
pub fn header_blocks<'a, 'b: 'a, 'c: 'a + 'b, R, T>(
show_excerpt_controls: bool,
excerpt_footer_height: u8,
buffer_header_height: u8,
@@ -569,7 +558,7 @@ impl BlockMap {
buffer: &'b multi_buffer::MultiBufferSnapshot,
range: R,
wrap_snapshot: &'c WrapSnapshot,
) -> impl Iterator<Item = (u32, Block)> + 'b
) -> impl Iterator<Item = (u32, TransformBlock)> + 'b
where
R: RangeBounds<T>,
T: multi_buffer::ToOffset,
@@ -577,36 +566,24 @@ impl BlockMap {
buffer
.excerpt_boundaries_in_range(range)
.flat_map(move |excerpt_boundary| {
let mut wrap_row = wrap_snapshot
let wrap_row = wrap_snapshot
.make_wrap_point(Point::new(excerpt_boundary.row.0, 0), Bias::Left)
.row();
[
show_excerpt_controls
.then(|| {
let disposition;
if excerpt_boundary.next.is_some() {
disposition = BlockDisposition::Above;
} else {
wrap_row = wrap_snapshot
.make_wrap_point(
Point::new(
excerpt_boundary.row.0,
buffer.line_len(excerpt_boundary.row),
),
Bias::Left,
)
.row();
disposition = BlockDisposition::Below;
}
excerpt_boundary.prev.as_ref().map(|prev| {
(
wrap_row,
Block::ExcerptFooter {
TransformBlock::ExcerptFooter {
id: prev.id,
height: excerpt_footer_height,
disposition,
disposition: if excerpt_boundary.next.is_some() {
BlockDisposition::Above
} else {
BlockDisposition::Below
},
},
)
})
@@ -619,7 +596,7 @@ impl BlockMap {
(
wrap_row,
Block::ExcerptHeader {
TransformBlock::ExcerptHeader {
id: next.id,
buffer: next.buffer,
range: next.range,
@@ -715,7 +692,7 @@ impl<'a> DerefMut for BlockMapReader<'a> {
}
impl<'a> BlockMapReader<'a> {
pub fn row_for_block(&self, block_id: CustomBlockId) -> Option<BlockRow> {
pub fn row_for_block(&self, block_id: BlockId) -> Option<BlockRow> {
let block = self.blocks.iter().find(|block| block.id == block_id)?;
let buffer_row = block
.position
@@ -760,14 +737,14 @@ impl<'a> BlockMapWriter<'a> {
pub fn insert(
&mut self,
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
) -> Vec<CustomBlockId> {
) -> Vec<BlockId> {
let mut ids = Vec::new();
let mut edits = Patch::default();
let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
let buffer = wrap_snapshot.buffer_snapshot();
for block in blocks {
let id = CustomBlockId(self.0.next_block_id.fetch_add(1, SeqCst));
let id = BlockId(self.0.next_block_id.fetch_add(1, SeqCst));
ids.push(id);
let position = block.position;
@@ -782,21 +759,22 @@ impl<'a> BlockMapWriter<'a> {
let block_ix = match self
.0
.custom_blocks
.blocks
.binary_search_by(|probe| probe.position.cmp(&position, buffer))
{
Ok(ix) | Err(ix) => ix,
};
let new_block = Arc::new(CustomBlock {
id,
position,
height: block.height,
render: Mutex::new(block.render),
disposition: block.disposition,
style: block.style,
});
self.0.custom_blocks.insert(block_ix, new_block.clone());
self.0.custom_blocks_by_id.insert(id, new_block);
self.0.blocks.insert(
block_ix,
Arc::new(Block {
id,
position,
height: block.height,
render: Mutex::new(block.render),
disposition: block.disposition,
style: block.style,
}),
);
edits = edits.compose([Edit {
old: start_row..end_row,
@@ -808,19 +786,16 @@ impl<'a> BlockMapWriter<'a> {
ids
}
pub fn replace(
&mut self,
mut heights_and_renderers: HashMap<CustomBlockId, (u8, RenderBlock)>,
) {
pub fn replace(&mut self, mut heights_and_renderers: HashMap<BlockId, (u8, RenderBlock)>) {
let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
let buffer = wrap_snapshot.buffer_snapshot();
let mut edits = Patch::default();
let mut last_block_buffer_row = None;
for block in &mut self.0.custom_blocks {
for block in &mut self.0.blocks {
if let Some((new_height, render)) = heights_and_renderers.remove(&block.id) {
if block.height != new_height {
let new_block = CustomBlock {
let new_block = Block {
id: block.id,
position: block.position,
height: new_height,
@@ -828,9 +803,7 @@ impl<'a> BlockMapWriter<'a> {
render: Mutex::new(render),
disposition: block.disposition,
};
let new_block = Arc::new(new_block);
*block = new_block.clone();
self.0.custom_blocks_by_id.insert(block.id, new_block);
*block = Arc::new(new_block);
let buffer_row = block.position.to_point(buffer).row;
if last_block_buffer_row != Some(buffer_row) {
@@ -855,12 +828,12 @@ impl<'a> BlockMapWriter<'a> {
self.0.sync(wrap_snapshot, edits);
}
pub fn remove(&mut self, block_ids: HashSet<CustomBlockId>) {
pub fn remove(&mut self, block_ids: HashSet<BlockId>) {
let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
let buffer = wrap_snapshot.buffer_snapshot();
let mut edits = Patch::default();
let mut last_block_buffer_row = None;
self.0.custom_blocks.retain(|block| {
self.0.blocks.retain(|block| {
if block_ids.contains(&block.id) {
let buffer_row = block.position.to_point(buffer).row;
if last_block_buffer_row != Some(buffer_row) {
@@ -877,7 +850,6 @@ impl<'a> BlockMapWriter<'a> {
new: start_row..end_row,
})
}
self.0.custom_blocks_by_id.remove(&block.id);
false
} else {
true
@@ -962,7 +934,10 @@ impl BlockSnapshot {
}
}
pub fn blocks_in_range(&self, rows: Range<u32>) -> impl Iterator<Item = (u32, &Block)> {
pub fn blocks_in_range(
&self,
rows: Range<u32>,
) -> impl Iterator<Item = (u32, &TransformBlock)> {
let mut cursor = self.transforms.cursor::<BlockRow>();
cursor.seek(&BlockRow(rows.start), Bias::Right, &());
std::iter::from_fn(move || {
@@ -982,60 +957,6 @@ impl BlockSnapshot {
})
}
pub fn block_for_id(&self, block_id: BlockId) -> Option<Block> {
let buffer = self.wrap_snapshot.buffer_snapshot();
match block_id {
BlockId::Custom(custom_block_id) => {
let custom_block = self.custom_blocks_by_id.get(&custom_block_id)?;
Some(Block::Custom(custom_block.clone()))
}
BlockId::ExcerptHeader(excerpt_id) => {
let excerpt_range = buffer.range_for_excerpt::<Point>(excerpt_id)?;
let wrap_point = self
.wrap_snapshot
.make_wrap_point(excerpt_range.start, Bias::Left);
let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>();
cursor.seek(&WrapRow(wrap_point.row()), Bias::Left, &());
while let Some(transform) = cursor.item() {
if let Some(block) = transform.block.as_ref() {
if block.id() == block_id {
return Some(block.clone());
}
} else if cursor.start().0 > WrapRow(wrap_point.row()) {
break;
}
cursor.next(&());
}
None
}
BlockId::ExcerptFooter(excerpt_id) => {
let excerpt_range = buffer.range_for_excerpt::<Point>(excerpt_id)?;
let wrap_point = self
.wrap_snapshot
.make_wrap_point(excerpt_range.end, Bias::Left);
let mut cursor = self.transforms.cursor::<(WrapRow, BlockRow)>();
cursor.seek(&WrapRow(wrap_point.row()), Bias::Left, &());
while let Some(transform) = cursor.item() {
if let Some(block) = transform.block.as_ref() {
if block.id() == block_id {
return Some(block.clone());
}
} else if cursor.start().0 > WrapRow(wrap_point.row()) {
break;
}
cursor.next(&());
}
None
}
}
}
pub fn max_point(&self) -> BlockPoint {
let row = self.transforms.summary().output_rows - 1;
BlockPoint::new(row, self.line_len(BlockRow(row)))
@@ -1165,7 +1086,7 @@ impl Transform {
}
}
fn block(block: Block) -> Self {
fn block(block: TransformBlock) -> Self {
Self {
summary: TransformSummary {
input_rows: 0,
@@ -1314,7 +1235,7 @@ impl DerefMut for BlockContext<'_, '_> {
}
}
impl CustomBlock {
impl Block {
pub fn render(&self, cx: &mut BlockContext) -> AnyElement {
self.render.lock()(cx)
}
@@ -1328,7 +1249,7 @@ impl CustomBlock {
}
}
impl Debug for CustomBlock {
impl Debug for Block {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Block")
.field("id", &self.id)
@@ -1358,16 +1279,15 @@ fn offset_for_row(s: &str, target: u32) -> (u32, usize) {
#[cfg(test)]
mod tests {
use std::env;
use super::*;
use crate::display_map::{
fold_map::FoldMap, inlay_map::InlayMap, tab_map::TabMap, wrap_map::WrapMap,
};
use gpui::{div, font, px, AppContext, Context as _, Element};
use language::{Buffer, Capability};
use crate::display_map::inlay_map::InlayMap;
use crate::display_map::{fold_map::FoldMap, tab_map::TabMap, wrap_map::WrapMap};
use gpui::{div, font, px, Element};
use multi_buffer::MultiBuffer;
use rand::prelude::*;
use settings::SettingsStore;
use std::env;
use util::RandomCharIter;
#[gpui::test]
@@ -1554,89 +1474,6 @@ mod tests {
assert_eq!(snapshot.text(), "aaa\n\nb!!!\n\n\nbb\nccc\nddd\n\n\n");
}
#[gpui::test]
fn test_multibuffer_headers_and_footers(cx: &mut AppContext) {
init_test(cx);
let buffer1 = cx.new_model(|cx| Buffer::local("Buffer 1", cx));
let buffer2 = cx.new_model(|cx| Buffer::local("Buffer 2", cx));
let buffer3 = cx.new_model(|cx| Buffer::local("Buffer 3", cx));
let mut excerpt_ids = Vec::new();
let multi_buffer = cx.new_model(|cx| {
let mut multi_buffer = MultiBuffer::new(0, Capability::ReadWrite);
excerpt_ids.extend(multi_buffer.push_excerpts(
buffer1.clone(),
[ExcerptRange {
context: 0..buffer1.read(cx).len(),
primary: None,
}],
cx,
));
excerpt_ids.extend(multi_buffer.push_excerpts(
buffer2.clone(),
[ExcerptRange {
context: 0..buffer2.read(cx).len(),
primary: None,
}],
cx,
));
excerpt_ids.extend(multi_buffer.push_excerpts(
buffer3.clone(),
[ExcerptRange {
context: 0..buffer3.read(cx).len(),
primary: None,
}],
cx,
));
multi_buffer
});
let font = font("Helvetica");
let font_size = px(14.);
let font_id = cx.text_system().resolve_font(&font);
let mut wrap_width = px(0.);
for c in "Buff".chars() {
wrap_width += cx
.text_system()
.advance(font_id, font_size, c)
.unwrap()
.width;
}
let multi_buffer_snapshot = multi_buffer.read(cx).snapshot(cx);
let (_, inlay_snapshot) = InlayMap::new(multi_buffer_snapshot.clone());
let (_, fold_snapshot) = FoldMap::new(inlay_snapshot);
let (_, tab_snapshot) = TabMap::new(fold_snapshot, 4.try_into().unwrap());
let (_, wraps_snapshot) = WrapMap::new(tab_snapshot, font, font_size, Some(wrap_width), cx);
let block_map = BlockMap::new(wraps_snapshot.clone(), true, 1, 1, 1);
let snapshot = block_map.read(wraps_snapshot, Default::default());
// Each excerpt has a header above and footer below. Excerpts are also *separated* by a newline.
assert_eq!(
snapshot.text(),
"\nBuff\ner 1\n\n\nBuff\ner 2\n\n\nBuff\ner 3\n"
);
let blocks: Vec<_> = snapshot
.blocks_in_range(0..u32::MAX)
.map(|(row, block)| (row, block.id()))
.collect();
assert_eq!(
blocks,
vec![
(0, BlockId::ExcerptHeader(excerpt_ids[0])),
(3, BlockId::ExcerptFooter(excerpt_ids[0])),
(4, BlockId::ExcerptHeader(excerpt_ids[1])),
(7, BlockId::ExcerptFooter(excerpt_ids[1])),
(8, BlockId::ExcerptHeader(excerpt_ids[2])),
(11, BlockId::ExcerptFooter(excerpt_ids[2]))
]
);
}
#[gpui::test]
fn test_replace_with_heights(cx: &mut gpui::TestAppContext) {
let _update = cx.update(|cx| init_test(cx));
@@ -1970,7 +1807,7 @@ mod tests {
// Note that this needs to be synced with the related section in BlockMap::sync
expected_blocks.extend(
BlockMap::header_and_footer_blocks(
BlockMap::header_blocks(
true,
excerpt_footer_height,
buffer_start_header_height,
@@ -2074,16 +1911,6 @@ mod tests {
expected_block_positions
);
for (_, expected_block) in
blocks_snapshot.blocks_in_range(0..(expected_row_count as u32))
{
let actual_block = blocks_snapshot.block_for_id(expected_block.id());
assert_eq!(
actual_block.map(|block| block.id()),
Some(expected_block.id())
);
}
for (block_row, block) in expected_block_positions {
if let BlockType::Custom(block_id) = block.block_type() {
assert_eq!(
@@ -2180,7 +2007,7 @@ mod tests {
},
Custom {
disposition: BlockDisposition,
id: CustomBlockId,
id: BlockId,
height: u8,
},
}
@@ -2217,15 +2044,15 @@ mod tests {
}
}
impl From<Block> for ExpectedBlock {
fn from(block: Block) -> Self {
impl From<TransformBlock> for ExpectedBlock {
fn from(block: TransformBlock) -> Self {
match block {
Block::Custom(block) => ExpectedBlock::Custom {
TransformBlock::Custom(block) => ExpectedBlock::Custom {
id: block.id,
disposition: block.disposition,
height: block.height,
},
Block::ExcerptHeader {
TransformBlock::ExcerptHeader {
height,
starts_new_buffer,
..
@@ -2233,7 +2060,7 @@ mod tests {
height,
starts_new_buffer,
},
Block::ExcerptFooter {
TransformBlock::ExcerptFooter {
height,
disposition,
..
@@ -2253,12 +2080,12 @@ mod tests {
assets::Assets.load_test_fonts(cx);
}
impl Block {
fn as_custom(&self) -> Option<&CustomBlock> {
impl TransformBlock {
fn as_custom(&self) -> Option<&Block> {
match self {
Block::Custom(block) => Some(block),
Block::ExcerptHeader { .. } => None,
Block::ExcerptFooter { .. } => None,
TransformBlock::Custom(block) => Some(block),
TransformBlock::ExcerptHeader { .. } => None,
TransformBlock::ExcerptFooter { .. } => None,
}
}
}

View File

@@ -69,16 +69,16 @@ use gpui::{
div, impl_actions, point, prelude::*, px, relative, size, uniform_list, Action, AnyElement,
AppContext, AsyncWindowContext, AvailableSpace, BackgroundExecutor, Bounds, ClipboardItem,
Context, DispatchPhase, ElementId, EntityId, EventEmitter, FocusHandle, FocusOutEvent,
FocusableView, FontId, FontWeight, HighlightStyle, Hsla, InteractiveText, KeyContext,
ListSizingBehavior, Model, MouseButton, PaintQuad, ParentElement, Pixels, Render, SharedString,
Size, StrikethroughStyle, Styled, StyledText, Subscription, Task, TextStyle, UnderlineStyle,
UniformListScrollHandle, View, ViewContext, ViewInputHandler, VisualContext, WeakFocusHandle,
WeakView, WindowContext,
FocusableView, FontId, FontStyle, FontWeight, HighlightStyle, Hsla, InteractiveText,
KeyContext, ListSizingBehavior, Model, MouseButton, PaintQuad, ParentElement, Pixels, Render,
SharedString, Size, StrikethroughStyle, Styled, StyledText, Subscription, Task, TextStyle,
UnderlineStyle, UniformListScrollHandle, View, ViewContext, ViewInputHandler, VisualContext,
WeakFocusHandle, WeakView, WhiteSpace, WindowContext,
};
use highlight_matching_bracket::refresh_matching_bracket_highlights;
use hover_popover::{hide_hover, HoverState};
use hunk_diff::ExpandedHunks;
pub(crate) use hunk_diff::HoveredHunk;
pub(crate) use hunk_diff::HunkToExpand;
use indent_guides::ActiveIndentGuidesState;
use inlay_hint_cache::{InlayHintCache, InlaySplice, InvalidationStrategy};
pub use inline_completion_provider::*;
@@ -131,7 +131,7 @@ use std::{
mem,
num::NonZeroU32,
ops::{ControlFlow, Deref, DerefMut, Not as _, Range, RangeInclusive},
path::{Path, PathBuf},
path::Path,
rc::Rc,
sync::Arc,
time::{Duration, Instant},
@@ -272,8 +272,7 @@ pub fn init(cx: &mut AppContext) {
workspace::register_project_item::<Editor>(cx);
workspace::FollowableViewRegistry::register::<Editor>(cx);
workspace::register_serializable_item::<Editor>(cx);
workspace::register_deserializable_item::<Editor>(cx);
cx.observe_new_views(
|workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
workspace.register_action(Editor::new_file);
@@ -488,7 +487,6 @@ pub struct Editor {
mode: EditorMode,
show_breadcrumbs: bool,
show_gutter: bool,
redact_all: bool,
show_line_numbers: Option<bool>,
show_git_diff_gutter: Option<bool>,
show_code_actions: Option<bool>,
@@ -533,7 +531,7 @@ pub struct Editor {
gutter_hovered: bool,
hovered_link_state: Option<HoveredLinkState>,
inline_completion_provider: Option<RegisteredInlineCompletionProvider>,
active_inline_completion: Option<(Inlay, Option<Range<Anchor>>)>,
active_inline_completion: Option<Inlay>,
show_inline_completions: bool,
inlay_hint_cache: InlayHintCache,
expanded_hunks: ExpandedHunks,
@@ -552,7 +550,6 @@ pub struct Editor {
show_git_blame_inline: bool,
show_git_blame_inline_delay_task: Option<Task<()>>,
git_blame_inline_enabled: bool,
serialize_dirty_buffers: bool,
show_selection_menu: Option<bool>,
blame: Option<Model<GitBlame>>,
blame_subscription: Option<Subscription>,
@@ -569,7 +566,6 @@ pub struct Editor {
previous_search_ranges: Option<Arc<[Range<Anchor>]>>,
file_header_size: u8,
breadcrumb_header: Option<String>,
focused_block: Option<FocusedBlock>,
}
#[derive(Clone)]
@@ -787,7 +783,7 @@ pub struct RenameState {
pub range: Range<Anchor>,
pub old_name: Arc<str>,
pub editor: View<Editor>,
block_id: CustomBlockId,
block_id: BlockId,
}
struct InvalidationStack<T>(Vec<T>);
@@ -1539,7 +1535,7 @@ struct ActiveDiagnosticGroup {
primary_range: Range<Anchor>,
primary_message: String,
group_id: usize,
blocks: HashMap<CustomBlockId, Diagnostic>,
blocks: HashMap<BlockId, Diagnostic>,
is_valid: bool,
}
@@ -1587,11 +1583,6 @@ impl InlayHintRefreshReason {
}
}
pub(crate) struct FocusedBlock {
id: BlockId,
focus_handle: WeakFocusHandle,
}
impl Editor {
pub fn single_line(cx: &mut ViewContext<Self>) -> Self {
let buffer = cx.new_model(|cx| Buffer::local("", cx));
@@ -1823,7 +1814,6 @@ impl Editor {
show_code_actions: None,
show_runnables: None,
show_wrap_guides: None,
redact_all: false,
show_indent_guides,
placeholder_text: None,
highlight_order: 0,
@@ -1886,9 +1876,6 @@ impl Editor {
show_selection_menu: None,
show_git_blame_inline_delay_task: None,
git_blame_inline_enabled: ProjectSettings::get_global(cx).git.inline_blame_enabled(),
serialize_dirty_buffers: ProjectSettings::get_global(cx)
.session
.restore_unsaved_buffers,
blame: None,
blame_subscription: None,
file_header_size,
@@ -1916,7 +1903,6 @@ impl Editor {
linked_edit_ranges: Default::default(),
previous_search_ranges: None,
breadcrumb_header: None,
focused_block: None,
};
this.tasks_update_task = Some(this.refresh_runnables(cx));
this._subscriptions.extend(project_subscriptions);
@@ -1953,7 +1939,7 @@ impl Editor {
EditorMode::Full => "full",
};
if EditorSettings::jupyter_enabled(cx) {
if EditorSettings::get_global(cx).jupyter.enabled {
key_context.add("jupyter");
}
@@ -2001,35 +1987,28 @@ impl Editor {
_: &workspace::NewFile,
cx: &mut ViewContext<Workspace>,
) {
Self::new_in_workspace(workspace, cx).detach_and_prompt_err(
"Failed to create buffer",
cx,
|e, _| match e.error_code() {
ErrorCode::RemoteUpgradeRequired => Some(format!(
"The remote instance of Zed does not support this yet. It must be upgraded to {}",
e.error_tag("required").unwrap_or("the latest version")
)),
_ => None,
},
);
}
pub fn new_in_workspace(
workspace: &mut Workspace,
cx: &mut ViewContext<Workspace>,
) -> Task<Result<View<Editor>>> {
let project = workspace.project().clone();
let create = project.update(cx, |project, cx| project.create_buffer(cx));
cx.spawn(|workspace, mut cx| async move {
let buffer = create.await?;
workspace.update(&mut cx, |workspace, cx| {
let editor =
cx.new_view(|cx| Editor::for_buffer(buffer, Some(project.clone()), cx));
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, cx);
editor
workspace.add_item_to_active_pane(
Box::new(
cx.new_view(|cx| Editor::for_buffer(buffer, Some(project.clone()), cx)),
),
None,
cx,
)
})
})
.detach_and_prompt_err("Failed to create buffer", cx, |e, _| match e.error_code() {
ErrorCode::RemoteUpgradeRequired => Some(format!(
"The remote instance of Zed does not support this yet. It must be upgraded to {}",
e.error_tag("required").unwrap_or("the latest version")
)),
_ => None,
});
}
pub fn new_file_in_direction(
@@ -2882,10 +2861,7 @@ impl Editor {
}
pub fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
if self.clear_clicked_diff_hunks(cx) {
cx.notify();
return;
}
self.clear_expanded_diff_hunks(cx);
if self.dismiss_menus_and_popups(true, cx) {
return;
}
@@ -2920,10 +2896,6 @@ impl Editor {
return true;
}
if self.mouse_context_menu.take().is_some() {
return true;
}
if self.discard_inline_completion(should_report_inline_completion_event, cx) {
return true;
}
@@ -3661,7 +3633,7 @@ impl Editor {
if self.is_completion_trigger(text, trigger_in_words, cx) {
self.show_completions(
&ShowCompletions {
trigger: Some(text.to_owned()).filter(|x| !x.is_empty()),
trigger: text.chars().last(),
},
cx,
);
@@ -4085,18 +4057,15 @@ impl Editor {
Some(ContextMenu::Completions(_))
)
};
let trigger_kind = match (&options.trigger, is_followup_invoke) {
let trigger_kind = match (options.trigger, is_followup_invoke) {
(_, true) => CompletionTriggerKind::TRIGGER_FOR_INCOMPLETE_COMPLETIONS,
(Some(trigger), _) if buffer.read(cx).completion_triggers().contains(&trigger) => {
CompletionTriggerKind::TRIGGER_CHARACTER
}
(Some(_), _) => CompletionTriggerKind::TRIGGER_CHARACTER,
_ => CompletionTriggerKind::INVOKED,
};
let completion_context = CompletionContext {
trigger_character: options.trigger.as_ref().and_then(|trigger| {
trigger_character: options.trigger.and_then(|c| {
if trigger_kind == CompletionTriggerKind::TRIGGER_CHARACTER {
Some(String::from(trigger))
Some(String::from(c))
} else {
None
}
@@ -4681,7 +4650,7 @@ impl Editor {
let project = workspace.project().clone();
let editor =
cx.new_view(|cx| Editor::for_multibuffer(excerpt_buffer, Some(project), true, cx));
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, true, cx);
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
editor.update(cx, |editor, cx| {
editor.highlight_background::<Self>(
&ranges_to_highlight,
@@ -4953,7 +4922,7 @@ impl Editor {
_: &AcceptInlineCompletion,
cx: &mut ViewContext<Self>,
) {
let Some((completion, delete_range)) = self.take_active_inline_completion(cx) else {
let Some(completion) = self.take_active_inline_completion(cx) else {
return;
};
if let Some(provider) = self.inline_completion_provider() {
@@ -4964,10 +4933,6 @@ impl Editor {
utf16_range_to_replace: None,
text: completion.text.to_string().into(),
});
if let Some(range) = delete_range {
self.change_selections(None, cx, |s| s.select_ranges([range]))
}
self.insert_with_autoindent_mode(&completion.text.to_string(), None, cx);
self.refresh_inline_completion(true, cx);
cx.notify();
@@ -4979,7 +4944,7 @@ impl Editor {
cx: &mut ViewContext<Self>,
) {
if self.selections.count() == 1 && self.has_active_inline_completion(cx) {
if let Some((completion, delete_range)) = self.take_active_inline_completion(cx) {
if let Some(completion) = self.take_active_inline_completion(cx) {
let mut partial_completion = completion
.text
.chars()
@@ -4999,12 +4964,7 @@ impl Editor {
utf16_range_to_replace: None,
text: partial_completion.clone().into(),
});
if let Some(range) = delete_range {
self.change_selections(None, cx, |s| s.select_ranges([range]))
}
self.insert_with_autoindent_mode(&partial_completion, None, cx);
self.refresh_inline_completion(true, cx);
cx.notify();
}
@@ -5026,23 +4986,20 @@ impl Editor {
pub fn has_active_inline_completion(&self, cx: &AppContext) -> bool {
if let Some(completion) = self.active_inline_completion.as_ref() {
let buffer = self.buffer.read(cx).read(cx);
completion.0.position.is_valid(&buffer)
completion.position.is_valid(&buffer)
} else {
false
}
}
fn take_active_inline_completion(
&mut self,
cx: &mut ViewContext<Self>,
) -> Option<(Inlay, Option<Range<Anchor>>)> {
fn take_active_inline_completion(&mut self, cx: &mut ViewContext<Self>) -> Option<Inlay> {
let completion = self.active_inline_completion.take()?;
self.display_map.update(cx, |map, cx| {
map.splice_inlays(vec![completion.0.id], Default::default(), cx);
map.splice_inlays(vec![completion.id], Default::default(), cx);
});
let buffer = self.buffer.read(cx).read(cx);
if completion.0.position.is_valid(&buffer) {
if completion.position.is_valid(&buffer) {
Some(completion)
} else {
None
@@ -5053,8 +5010,6 @@ impl Editor {
let selection = self.selections.newest_anchor();
let cursor = selection.head();
let excerpt_id = cursor.excerpt_id;
if self.context_menu.read().is_none()
&& self.completion_tasks.is_empty()
&& selection.start == selection.end
@@ -5063,28 +5018,18 @@ impl Editor {
if let Some((buffer, cursor_buffer_position)) =
self.buffer.read(cx).text_anchor_for_position(cursor, cx)
{
if let Some((text, text_anchor_range)) =
if let Some(text) =
provider.active_completion_text(&buffer, cursor_buffer_position, cx)
{
let text = Rope::from(text);
let mut to_remove = Vec::new();
if let Some(completion) = self.active_inline_completion.take() {
to_remove.push(completion.0.id);
to_remove.push(completion.id);
}
let completion_inlay =
Inlay::suggestion(post_inc(&mut self.next_inlay_id), cursor, text);
let multibuffer_anchor_range = text_anchor_range.and_then(|range| {
let snapshot = self.buffer.read(cx).snapshot(cx);
Some(
snapshot.anchor_in_excerpt(excerpt_id, range.start)?
..snapshot.anchor_in_excerpt(excerpt_id, range.end)?,
)
});
self.active_inline_completion =
Some((completion_inlay.clone(), multibuffer_anchor_range));
self.active_inline_completion = Some(completion_inlay.clone());
self.display_map.update(cx, move |map, cx| {
map.splice_inlays(to_remove, vec![completion_inlay], cx)
});
@@ -5165,23 +5110,6 @@ impl Editor {
}))
}
fn render_close_hunk_diff_button(
&self,
hunk: HoveredHunk,
row: DisplayRow,
cx: &mut ViewContext<Self>,
) -> IconButton {
IconButton::new(
("close_hunk_diff_indicator", row.0 as usize),
ui::IconName::Close,
)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(|cx| Tooltip::for_action("Close hunk diff", &ToggleHunkDiff, cx))
.on_click(cx.listener(move |editor, _e, cx| editor.toggle_hovered_hunk(&hunk, cx)))
}
pub fn context_menu_visible(&self) -> bool {
self.context_menu
.read()
@@ -5936,7 +5864,22 @@ impl Editor {
let revert_changes = self.gather_revert_changes(&self.selections.disjoint_anchors(), cx);
if !revert_changes.is_empty() {
self.transact(cx, |editor, cx| {
editor.revert(revert_changes, cx);
editor.buffer().update(cx, |multi_buffer, cx| {
for (buffer_id, changes) in revert_changes {
if let Some(buffer) = multi_buffer.buffer(buffer_id) {
buffer.update(cx, |buffer, cx| {
buffer.edit(
changes.into_iter().map(|(range, text)| {
(range, text.to_string().map(Arc::<str>::from))
}),
None,
cx,
);
});
}
}
});
editor.change_selections(None, cx, |selections| selections.refresh());
});
}
}
@@ -5966,20 +5909,22 @@ impl Editor {
cx: &mut ViewContext<'_, Editor>,
) -> HashMap<BufferId, Vec<(Range<text::Anchor>, Rope)>> {
let mut revert_changes = HashMap::default();
let multi_buffer_snapshot = self.buffer.read(cx).snapshot(cx);
for hunk in hunks_for_selections(&multi_buffer_snapshot, selections) {
Self::prepare_revert_change(&mut revert_changes, self.buffer(), &hunk, cx);
}
self.buffer.update(cx, |multi_buffer, cx| {
let multi_buffer_snapshot = multi_buffer.snapshot(cx);
for hunk in hunks_for_selections(&multi_buffer_snapshot, selections) {
Self::prepare_revert_change(&mut revert_changes, &multi_buffer, &hunk, cx);
}
});
revert_changes
}
pub fn prepare_revert_change(
fn prepare_revert_change(
revert_changes: &mut HashMap<BufferId, Vec<(Range<text::Anchor>, Rope)>>,
multi_buffer: &Model<MultiBuffer>,
multi_buffer: &MultiBuffer,
hunk: &DiffHunk<MultiBufferRow>,
cx: &AppContext,
cx: &mut AppContext,
) -> Option<()> {
let buffer = multi_buffer.read(cx).buffer(hunk.buffer_id)?;
let buffer = multi_buffer.buffer(hunk.buffer_id)?;
let buffer = buffer.read(cx);
let original_text = buffer.diff_base()?.slice(hunk.diff_base_byte_range.clone());
let buffer_snapshot = buffer.snapshot();
@@ -9140,13 +9085,7 @@ impl Editor {
workspace.active_pane().clone()
};
workspace.open_project_item(
pane,
target.buffer.clone(),
true,
true,
cx,
)
workspace.open_project_item(pane, target.buffer.clone(), cx)
});
target_editor.update(cx, |target_editor, cx| {
// When selecting a definition in a different buffer, disable the nav history
@@ -9444,7 +9383,7 @@ impl Editor {
None
}
});
workspace.add_item_to_active_pane(item.clone(), destination_index, true, cx);
workspace.add_item_to_active_pane(item.clone(), destination_index, cx);
}
workspace.active_pane().update(cx, |pane, cx| {
pane.set_preview_item_id(Some(item_id), cx);
@@ -10183,7 +10122,7 @@ impl Editor {
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
autoscroll: Option<Autoscroll>,
cx: &mut ViewContext<Self>,
) -> Vec<CustomBlockId> {
) -> Vec<BlockId> {
let blocks = self
.display_map
.update(cx, |display_map, cx| display_map.insert_blocks(blocks, cx));
@@ -10195,7 +10134,7 @@ impl Editor {
pub fn replace_blocks(
&mut self,
blocks: HashMap<CustomBlockId, (Option<u8>, RenderBlock)>,
blocks: HashMap<BlockId, (Option<u8>, RenderBlock)>,
autoscroll: Option<Autoscroll>,
cx: &mut ViewContext<Self>,
) {
@@ -10208,7 +10147,7 @@ impl Editor {
pub fn remove_blocks(
&mut self,
block_ids: HashSet<CustomBlockId>,
block_ids: HashSet<BlockId>,
autoscroll: Option<Autoscroll>,
cx: &mut ViewContext<Self>,
) {
@@ -10222,21 +10161,13 @@ impl Editor {
pub fn row_for_block(
&self,
block_id: CustomBlockId,
block_id: BlockId,
cx: &mut ViewContext<Self>,
) -> Option<DisplayRow> {
self.display_map
.update(cx, |map, cx| map.row_for_block(block_id, cx))
}
pub(crate) fn set_focused_block(&mut self, focused_block: FocusedBlock) {
self.focused_block = Some(focused_block);
}
pub(crate) fn take_focused_block(&mut self) -> Option<FocusedBlock> {
self.focused_block.take()
}
pub fn insert_creases(
&mut self,
creases: impl IntoIterator<Item = Crease>,
@@ -10384,7 +10315,7 @@ impl Editor {
};
let fs = workspace.read(cx).app_state().fs.clone();
let current_show = TabBarSettings::get_global(cx).show;
update_settings_file::<TabBarSettings>(fs, cx, move |setting, _| {
update_settings_file::<TabBarSettings>(fs, cx, move |setting| {
setting.show = Some(!current_show);
});
}
@@ -10440,11 +10371,6 @@ impl Editor {
cx.notify();
}
pub fn set_redact_all(&mut self, redact_all: bool, cx: &mut ViewContext<Self>) {
self.redact_all = redact_all;
cx.notify();
}
pub fn set_show_wrap_guides(&mut self, show_wrap_guides: bool, cx: &mut ViewContext<Self>) {
self.show_wrap_guides = Some(show_wrap_guides);
cx.notify();
@@ -10455,22 +10381,6 @@ impl Editor {
cx.notify();
}
pub fn working_directory(&self, cx: &WindowContext) -> Option<PathBuf> {
if let Some(buffer) = self.buffer().read(cx).as_singleton() {
if let Some(file) = buffer.read(cx).file().and_then(|f| f.as_local()) {
if let Some(dir) = file.abs_path(cx).parent() {
return Some(dir.to_owned());
}
}
if let Some(project_path) = buffer.read(cx).project_path(cx) {
return Some(project_path.path.to_path_buf());
}
}
None
}
pub fn reveal_in_finder(&mut self, _: &RevealInFileManager, cx: &mut ViewContext<Self>) {
if let Some(buffer) = self.buffer().read(cx).as_singleton() {
if let Some(file) = buffer.read(cx).file().and_then(|f| f.as_local()) {
@@ -11139,10 +11049,6 @@ impl Editor {
display_snapshot: &DisplaySnapshot,
cx: &WindowContext,
) -> Vec<Range<DisplayPoint>> {
if self.redact_all {
return vec![DisplayPoint::zero()..display_snapshot.max_point()];
}
display_snapshot
.buffer_snapshot
.redacted_ranges(search_range, |file| {
@@ -11344,11 +11250,8 @@ impl Editor {
self.scroll_manager.vertical_scroll_margin = editor_settings.vertical_scroll_margin;
self.show_breadcrumbs = editor_settings.toolbar.breadcrumbs;
let project_settings = ProjectSettings::get_global(cx);
self.serialize_dirty_buffers = project_settings.session.restore_unsaved_buffers;
if self.mode == EditorMode::Full {
let inline_blame_enabled = project_settings.git.inline_blame_enabled();
let inline_blame_enabled = ProjectSettings::get_global(cx).git.inline_blame_enabled();
if self.git_blame_inline_enabled != inline_blame_enabled {
self.toggle_git_blame_inline_internal(false, cx);
}
@@ -11412,8 +11315,7 @@ impl Editor {
};
for (buffer, ranges) in new_selections_by_buffer {
let editor =
workspace.open_project_item::<Self>(pane.clone(), buffer, true, true, cx);
let editor = workspace.open_project_item::<Self>(pane.clone(), buffer, cx);
editor.update(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::newest()), cx, |s| {
s.select_ranges(ranges);
@@ -11794,66 +11696,15 @@ impl Editor {
pub fn file_header_size(&self) -> u8 {
self.file_header_size
}
pub fn revert(
&mut self,
revert_changes: HashMap<BufferId, Vec<(Range<text::Anchor>, Rope)>>,
cx: &mut ViewContext<Self>,
) {
self.buffer().update(cx, |multi_buffer, cx| {
for (buffer_id, changes) in revert_changes {
if let Some(buffer) = multi_buffer.buffer(buffer_id) {
buffer.update(cx, |buffer, cx| {
buffer.edit(
changes.into_iter().map(|(range, text)| {
(range, text.to_string().map(Arc::<str>::from))
}),
None,
cx,
);
});
}
}
});
self.change_selections(None, cx, |selections| selections.refresh());
}
pub fn to_pixel_point(
&mut self,
source: multi_buffer::Anchor,
editor_snapshot: &EditorSnapshot,
cx: &mut ViewContext<Self>,
) -> Option<gpui::Point<Pixels>> {
let source_point = source.to_display_point(editor_snapshot);
self.display_to_pixel_point(source_point, editor_snapshot, cx)
}
pub fn display_to_pixel_point(
&mut self,
source: DisplayPoint,
editor_snapshot: &EditorSnapshot,
cx: &mut ViewContext<Self>,
) -> Option<gpui::Point<Pixels>> {
let line_height = self.style()?.text.line_height_in_pixels(cx.rem_size());
let text_layout_details = self.text_layout_details(cx);
let scroll_top = text_layout_details
.scroll_anchor
.scroll_position(editor_snapshot)
.y;
if source.row().as_f32() < scroll_top.floor() {
return None;
}
let source_x = editor_snapshot.x_for_display_point(source, &text_layout_details);
let source_y = line_height * (source.row().as_f32() - scroll_top);
Some(gpui::Point::new(source_x, source_y))
}
}
fn hunks_for_selections(
multi_buffer_snapshot: &MultiBufferSnapshot,
selections: &[Selection<Anchor>],
) -> Vec<DiffHunk<MultiBufferRow>> {
let mut hunks = Vec::with_capacity(selections.len());
let mut processed_buffer_rows: HashMap<BufferId, HashSet<Range<text::Anchor>>> =
HashMap::default();
let buffer_rows_for_selections = selections.iter().map(|selection| {
let head = selection.head();
let tail = selection.tail();
@@ -11866,17 +11717,7 @@ fn hunks_for_selections(
}
});
hunks_for_rows(buffer_rows_for_selections, multi_buffer_snapshot)
}
pub fn hunks_for_rows(
rows: impl Iterator<Item = Range<MultiBufferRow>>,
multi_buffer_snapshot: &MultiBufferSnapshot,
) -> Vec<DiffHunk<MultiBufferRow>> {
let mut hunks = Vec::new();
let mut processed_buffer_rows: HashMap<BufferId, HashSet<Range<text::Anchor>>> =
HashMap::default();
for selected_multi_buffer_rows in rows {
for selected_multi_buffer_rows in buffer_rows_for_selections {
let query_rows =
selected_multi_buffer_rows.start..selected_multi_buffer_rows.end.next_row();
for hunk in multi_buffer_snapshot.git_diff_hunks_in_range(query_rows.clone()) {
@@ -12447,8 +12288,12 @@ impl Render for Editor {
font_features: settings.ui_font.features.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(settings.buffer_line_height.value()),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
},
EditorMode::Full => TextStyle {
color: cx.theme().colors().editor_foreground,
@@ -12456,8 +12301,12 @@ impl Render for Editor {
font_features: settings.buffer_font.features.clone(),
font_size: settings.buffer_font_size(cx).into(),
font_weight: settings.buffer_font.weight,
font_style: FontStyle::Normal,
line_height: relative(settings.buffer_line_height.value()),
..Default::default()
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
},
};
@@ -12857,7 +12706,7 @@ pub fn diagnostic_block_renderer(
highlight_diagnostic_message(&diagnostic, max_message_rows);
Box::new(move |cx: &mut BlockContext| {
let group_id: SharedString = cx.block_id.to_string().into();
let group_id: SharedString = cx.transform_block_id.to_string().into();
let mut text_style = cx.text_style().clone();
text_style.color = diagnostic_style(diagnostic.severity, cx.theme().status());
@@ -12869,7 +12718,7 @@ pub fn diagnostic_block_renderer(
let multi_line_diagnostic = diagnostic.message.contains('\n');
let buttons = |diagnostic: &Diagnostic, block_id: BlockId| {
let buttons = |diagnostic: &Diagnostic, block_id: TransformBlockId| {
if multi_line_diagnostic {
v_flex()
} else {
@@ -12900,12 +12749,12 @@ pub fn diagnostic_block_renderer(
)
};
let icon_size = buttons(&diagnostic, cx.block_id)
let icon_size = buttons(&diagnostic, cx.transform_block_id)
.into_any_element()
.layout_as_root(AvailableSpace::min_size(), cx);
h_flex()
.id(cx.block_id)
.id(cx.transform_block_id)
.group(group_id.clone())
.relative()
.size_full()
@@ -12917,7 +12766,7 @@ pub fn diagnostic_block_renderer(
.w(cx.anchor_x - cx.gutter_dimensions.width - icon_size.width)
.flex_shrink(),
)
.child(buttons(&diagnostic, cx.block_id))
.child(buttons(&diagnostic, cx.transform_block_id))
.child(div().flex().flex_shrink_0().child(
StyledText::new(text_without_backticks.clone()).with_highlights(
&text_style,
@@ -12951,31 +12800,23 @@ pub fn highlight_diagnostic_message(
let mut prev_offset = 0;
let mut in_code_block = false;
let has_row_limit = max_message_rows.is_some();
let mut newline_indices = diagnostic
.message
.match_indices('\n')
.filter(|_| has_row_limit)
.map(|(ix, _)| ix)
.fuse()
.peekable();
for (quote_ix, _) in diagnostic
for (ix, _) in diagnostic
.message
.match_indices('`')
.chain([(diagnostic.message.len(), "")])
{
let mut first_newline_ix = None;
let mut last_newline_ix = None;
while let Some(newline_ix) = newline_indices.peek() {
if *newline_ix < quote_ix {
if first_newline_ix.is_none() {
first_newline_ix = Some(*newline_ix);
}
last_newline_ix = Some(*newline_ix);
let mut trimmed_ix = ix;
while let Some(newline_index) = newline_indices.peek() {
if *newline_index < ix {
if let Some(rows_left) = &mut max_message_rows {
if *rows_left == 0 {
trimmed_ix = newline_index.saturating_sub(1);
break;
} else {
*rows_left -= 1;
@@ -12987,14 +12828,14 @@ pub fn highlight_diagnostic_message(
}
}
let prev_len = text_without_backticks.len();
let new_text = &diagnostic.message[prev_offset..first_newline_ix.unwrap_or(quote_ix)];
let new_text = &diagnostic.message[prev_offset..trimmed_ix];
text_without_backticks.push_str(new_text);
if in_code_block {
code_ranges.push(prev_len..text_without_backticks.len());
}
prev_offset = last_newline_ix.unwrap_or(quote_ix) + 1;
prev_offset = trimmed_ix + 1;
in_code_block = !in_code_block;
if first_newline_ix.map_or(false, |newline_ix| newline_ix < quote_ix) {
if trimmed_ix != ix {
text_without_backticks.push_str("...");
break;
}
@@ -13078,13 +12919,8 @@ pub(crate) fn split_words(text: &str) -> impl std::iter::Iterator<Item = &str> +
})
}
pub trait RangeToAnchorExt: Sized {
pub trait RangeToAnchorExt {
fn to_anchors(self, snapshot: &MultiBufferSnapshot) -> Range<Anchor>;
fn to_display_points(self, snapshot: &EditorSnapshot) -> Range<DisplayPoint> {
let anchor_range = self.to_anchors(&snapshot.buffer_snapshot);
anchor_range.start.to_display_point(&snapshot)..anchor_range.end.to_display_point(&snapshot)
}
}
impl<T: ToOffset> RangeToAnchorExt for Range<T> {

View File

@@ -28,6 +28,7 @@ pub struct EditorSettings {
pub search_wrap: bool,
pub auto_signature_help: bool,
pub show_signature_help_after_edits: bool,
#[serde(default)]
pub jupyter: Jupyter,
}
@@ -68,23 +69,15 @@ pub enum DoubleClickInMultibuffer {
Open,
}
#[derive(Debug, Clone, Deserialize)]
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct Jupyter {
/// Whether the Jupyter feature is enabled.
///
/// Default: true
/// Default: `false`
pub enabled: bool,
}
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct JupyterContent {
/// Whether the Jupyter feature is enabled.
///
/// Default: true
pub enabled: Option<bool>,
}
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct Toolbar {
pub breadcrumbs: bool,
@@ -254,7 +247,7 @@ pub struct EditorSettingsContent {
pub show_signature_help_after_edits: Option<bool>,
/// Jupyter REPL settings.
pub jupyter: Option<JupyterContent>,
pub jupyter: Option<Jupyter>,
}
// Toolbar related settings
@@ -325,12 +318,6 @@ pub struct GutterContent {
pub folds: Option<bool>,
}
impl EditorSettings {
pub fn jupyter_enabled(cx: &AppContext) -> bool {
EditorSettings::get_global(cx).jupyter.enabled
}
}
impl Settings for EditorSettings {
const KEY: Option<&'static str> = None;

View File

@@ -23,7 +23,7 @@ use language::{
FakeLspAdapter, IndentGuide, LanguageConfig, LanguageConfigOverride, LanguageMatcher, Override,
ParsedMarkdown, Point,
};
use language_settings::{Formatter, FormatterList, IndentGuideSettings};
use language_settings::IndentGuideSettings;
use multi_buffer::MultiBufferIndentGuide;
use parking_lot::Mutex;
use project::FakeFs;
@@ -6253,8 +6253,8 @@ async fn test_multibuffer_format_during_save(cx: &mut gpui::TestAppContext) {
},
);
let worktree = project.update(cx, |project, cx| {
let mut worktrees = project.worktrees(cx).collect::<Vec<_>>();
let worktree = project.update(cx, |project, _| {
let mut worktrees = project.worktrees().collect::<Vec<_>>();
assert_eq!(worktrees.len(), 1);
worktrees.pop().unwrap()
});
@@ -6559,9 +6559,7 @@ async fn test_range_format_during_save(cx: &mut gpui::TestAppContext) {
#[gpui::test]
async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
init_test(cx, |settings| {
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(
FormatterList(vec![Formatter::LanguageServer { name: None }].into()),
))
settings.defaults.formatter = Some(language_settings::Formatter::LanguageServer)
});
let fs = FakeFs::new(cx.executor());
@@ -6722,7 +6720,7 @@ async fn test_concurrent_format_requests(cx: &mut gpui::TestAppContext) {
#[gpui::test]
async fn test_strip_whitespace_and_format_via_lsp(cx: &mut gpui::TestAppContext) {
init_test(cx, |settings| {
settings.defaults.formatter = Some(language_settings::SelectedFormatter::Auto)
settings.defaults.formatter = Some(language_settings::Formatter::Auto)
});
let mut cx = EditorLspTestContext::new_rust(
@@ -9319,7 +9317,7 @@ async fn test_on_type_formatting_not_triggered(cx: &mut gpui::TestAppContext) {
let worktree_id = workspace
.update(cx, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
project.worktrees().next().unwrap().read(cx).id()
})
})
.unwrap();
@@ -9725,9 +9723,7 @@ async fn test_completions_in_languages_with_extra_word_characters(cx: &mut gpui:
#[gpui::test]
async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
init_test(cx, |settings| {
settings.defaults.formatter = Some(language_settings::SelectedFormatter::List(
FormatterList(vec![Formatter::Prettier].into()),
))
settings.defaults.formatter = Some(language_settings::Formatter::Prettier)
});
let fs = FakeFs::new(cx.executor());
@@ -9787,7 +9783,7 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
);
update_test_language_settings(cx, |settings| {
settings.defaults.formatter = Some(language_settings::SelectedFormatter::Auto)
settings.defaults.formatter = Some(language_settings::Formatter::Auto)
});
let format = editor.update(cx, |editor, cx| {
editor.perform_format(project.clone(), FormatTrigger::Manual, cx)
@@ -10449,12 +10445,7 @@ async fn test_mutlibuffer_in_navigation_history(cx: &mut gpui::TestAppContext) {
workspace.active_item(cx).is_none(),
"active item should be None before the first item is added"
);
workspace.add_item_to_active_pane(
Box::new(multi_buffer_editor.clone()),
None,
true,
cx,
);
workspace.add_item_to_active_pane(Box::new(multi_buffer_editor.clone()), None, cx);
let active_item = workspace
.active_item(cx)
.expect("should have an active item after adding the multi buffer");

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,7 @@ use git::{
parse_git_remote_url, GitHostingProvider, GitHostingProviderRegistry, Oid, PullRequest,
};
use gpui::{Model, ModelContext, Subscription, Task};
use http_client::HttpClient;
use http::HttpClient;
use language::{markdown, Bias, Buffer, BufferSnapshot, Edit, LanguageRegistry, ParsedMarkdown};
use multi_buffer::MultiBufferRow;
use project::{Item, Project};

View File

@@ -213,8 +213,22 @@ fn show_hover(
};
if !ignore_timeout {
if same_info_hover(editor, &snapshot, anchor)
|| same_diagnostic_hover(editor, &snapshot, anchor)
if editor
.hover_state
.info_popovers
.iter()
.any(|InfoPopover { symbol_range, .. }| {
symbol_range
.as_text_range()
.map(|range| {
let hover_range = range.to_offset(&snapshot.buffer_snapshot);
let offset = anchor.to_offset(&snapshot.buffer_snapshot);
// LSP returns a hover result for the end index of ranges that should be hovered, so we need to
// use an inclusive range here to check if we should dismiss the popover
(hover_range.start..=hover_range.end).contains(&offset)
})
.unwrap_or(false)
})
{
// Hover triggered from same location as last time. Don't show again.
return;
@@ -361,43 +375,6 @@ fn show_hover(
editor.hover_state.info_task = Some(task);
}
fn same_info_hover(editor: &Editor, snapshot: &EditorSnapshot, anchor: Anchor) -> bool {
editor
.hover_state
.info_popovers
.iter()
.any(|InfoPopover { symbol_range, .. }| {
symbol_range
.as_text_range()
.map(|range| {
let hover_range = range.to_offset(&snapshot.buffer_snapshot);
let offset = anchor.to_offset(&snapshot.buffer_snapshot);
// LSP returns a hover result for the end index of ranges that should be hovered, so we need to
// use an inclusive range here to check if we should dismiss the popover
(hover_range.start..=hover_range.end).contains(&offset)
})
.unwrap_or(false)
})
}
fn same_diagnostic_hover(editor: &Editor, snapshot: &EditorSnapshot, anchor: Anchor) -> bool {
editor
.hover_state
.diagnostic_popover
.as_ref()
.map(|diagnostic| {
let hover_range = diagnostic
.local_diagnostic
.range
.to_offset(&snapshot.buffer_snapshot);
let offset = anchor.to_offset(&snapshot.buffer_snapshot);
// Here we do basically the same as in `same_info_hover`, see comment there for an explanation
(hover_range.start..=hover_range.end).contains(&offset)
})
.unwrap_or(false)
}
async fn parse_blocks(
blocks: &[HoverBlock],
language_registry: &Arc<LanguageRegistry>,
@@ -545,7 +522,7 @@ impl HoverState {
pub fn focused(&self, cx: &mut ViewContext<Editor>) -> bool {
let mut hover_popover_is_focused = false;
for info_popover in &self.info_popovers {
if let Some(markdown_view) = &info_popover.parsed_content {
for markdown_view in &info_popover.parsed_content {
if markdown_view.focus_handle(cx).is_focused(cx) {
hover_popover_is_focused = true;
}

View File

@@ -5,31 +5,28 @@ use std::{
use collections::{hash_map, HashMap, HashSet};
use git::diff::{DiffHunk, DiffHunkStatus};
use gpui::{Action, AppContext, Hsla, Model, MouseButton, Subscription, Task, View};
use gpui::{AppContext, Hsla, Model, Task, View};
use language::Buffer;
use multi_buffer::{
Anchor, AnchorRangeExt, ExcerptRange, MultiBuffer, MultiBufferRow, MultiBufferSnapshot, ToPoint,
Anchor, ExcerptRange, MultiBuffer, MultiBufferRow, MultiBufferSnapshot, ToPoint,
};
use settings::SettingsStore;
use text::{BufferId, Point};
use ui::{
h_flex, v_flex, ActiveTheme, Context as _, ContextMenu, InteractiveElement, IntoElement,
ParentElement, Pixels, Styled, ViewContext, VisualContext,
div, ActiveTheme, Context as _, IntoElement, ParentElement, Styled, ViewContext, VisualContext,
};
use util::{debug_panic, RangeExt};
use crate::{
editor_settings::CurrentLineHighlight,
git::{diff_hunk_to_display, DisplayDiffHunk},
hunk_status, hunks_for_selections,
mouse_context_menu::MouseContextMenu,
BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, DiffRowHighlight, Editor,
EditorSnapshot, ExpandAllHunkDiffs, RangeToAnchorExt, RevertSelectedHunks, ToDisplayPoint,
ToggleHunkDiff,
hunk_status, hunks_for_selections, BlockDisposition, BlockId, BlockProperties, BlockStyle,
DiffRowHighlight, Editor, EditorSnapshot, ExpandAllHunkDiffs, RangeToAnchorExt,
RevertSelectedHunks, ToDisplayPoint, ToggleHunkDiff,
};
#[derive(Debug, Clone)]
pub(super) struct HoveredHunk {
pub(super) struct HunkToExpand {
pub multi_buffer_range: Range<Anchor>,
pub status: DiffHunkStatus,
pub diff_base_byte_range: Range<usize>,
@@ -58,7 +55,7 @@ impl ExpandedHunks {
#[derive(Debug, Clone)]
pub(super) struct ExpandedHunk {
pub block: Option<CustomBlockId>,
pub block: Option<BlockId>,
pub hunk_range: Range<Anchor>,
pub diff_base_byte_range: Range<usize>,
pub status: DiffHunkStatus,
@@ -66,123 +63,6 @@ pub(super) struct ExpandedHunk {
}
impl Editor {
pub(super) fn open_hunk_context_menu(
&mut self,
hovered_hunk: HoveredHunk,
clicked_point: gpui::Point<Pixels>,
cx: &mut ViewContext<Editor>,
) {
let focus_handle = self.focus_handle.clone();
let expanded = self
.expanded_hunks
.hunks(false)
.any(|expanded_hunk| expanded_hunk.hunk_range == hovered_hunk.multi_buffer_range);
let editor_handle = cx.view().clone();
let editor_snapshot = self.snapshot(cx);
let start_point = self
.to_pixel_point(hovered_hunk.multi_buffer_range.start, &editor_snapshot, cx)
.unwrap_or(clicked_point);
let end_point = self
.to_pixel_point(hovered_hunk.multi_buffer_range.start, &editor_snapshot, cx)
.unwrap_or(clicked_point);
let norm =
|a: gpui::Point<Pixels>, b: gpui::Point<Pixels>| (a.x - b.x).abs() + (a.y - b.y).abs();
let closest_source = if norm(start_point, clicked_point) < norm(end_point, clicked_point) {
hovered_hunk.multi_buffer_range.start
} else {
hovered_hunk.multi_buffer_range.end
};
self.mouse_context_menu = MouseContextMenu::pinned_to_editor(
self,
closest_source,
clicked_point,
ContextMenu::build(cx, move |menu, _| {
menu.on_blur_subscription(Subscription::new(|| {}))
.context(focus_handle)
.entry(
if expanded {
"Collapse Hunk"
} else {
"Expand Hunk"
},
Some(ToggleHunkDiff.boxed_clone()),
{
let editor = editor_handle.clone();
let hunk = hovered_hunk.clone();
move |cx| {
editor.update(cx, |editor, cx| {
editor.toggle_hovered_hunk(&hunk, cx);
});
}
},
)
.entry("Revert Hunk", Some(RevertSelectedHunks.boxed_clone()), {
let editor = editor_handle.clone();
let hunk = hovered_hunk.clone();
move |cx| {
let multi_buffer = editor.read(cx).buffer().clone();
let multi_buffer_snapshot = multi_buffer.read(cx).snapshot(cx);
let mut revert_changes = HashMap::default();
if let Some(hunk) =
crate::hunk_diff::to_diff_hunk(&hunk, &multi_buffer_snapshot)
{
Editor::prepare_revert_change(
&mut revert_changes,
&multi_buffer,
&hunk,
cx,
);
}
if !revert_changes.is_empty() {
editor.update(cx, |editor, cx| editor.revert(revert_changes, cx));
}
}
})
.entry("Revert File", None, {
let editor = editor_handle.clone();
move |cx| {
let mut revert_changes = HashMap::default();
let multi_buffer = editor.read(cx).buffer().clone();
let multi_buffer_snapshot = multi_buffer.read(cx).snapshot(cx);
for hunk in crate::hunks_for_rows(
Some(MultiBufferRow(0)..multi_buffer_snapshot.max_buffer_row())
.into_iter(),
&multi_buffer_snapshot,
) {
Editor::prepare_revert_change(
&mut revert_changes,
&multi_buffer,
&hunk,
cx,
);
}
if !revert_changes.is_empty() {
editor.update(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.revert(revert_changes, cx);
});
});
}
}
})
}),
cx,
)
}
pub(super) fn toggle_hovered_hunk(
&mut self,
hovered_hunk: &HoveredHunk,
cx: &mut ViewContext<Editor>,
) {
let editor_snapshot = self.snapshot(cx);
if let Some(diff_hunk) = to_diff_hunk(hovered_hunk, &editor_snapshot.buffer_snapshot) {
self.toggle_hunks_expanded(vec![diff_hunk], cx);
self.change_selections(None, cx, |selections| selections.refresh());
}
}
pub fn toggle_hunk_diff(&mut self, _: &ToggleHunkDiff, cx: &mut ViewContext<Self>) {
let multi_buffer_snapshot = self.buffer().read(cx).snapshot(cx);
let selections = self.selections.disjoint_anchors();
@@ -284,7 +164,7 @@ impl Editor {
retain = false;
break;
} else {
hunks_to_expand.push(HoveredHunk {
hunks_to_expand.push(HunkToExpand {
status,
multi_buffer_range,
diff_base_byte_range,
@@ -302,7 +182,7 @@ impl Editor {
let remaining_hunk_point_range =
Point::new(remaining_hunk.associated_range.start.0, 0)
..Point::new(remaining_hunk.associated_range.end.0, 0);
hunks_to_expand.push(HoveredHunk {
hunks_to_expand.push(HunkToExpand {
status: hunk_status(&remaining_hunk),
multi_buffer_range: remaining_hunk_point_range
.to_anchors(&snapshot.buffer_snapshot),
@@ -335,7 +215,7 @@ impl Editor {
pub(super) fn expand_diff_hunk(
&mut self,
diff_base_buffer: Option<Model<Buffer>>,
hunk: &HoveredHunk,
hunk: &HunkToExpand,
cx: &mut ViewContext<'_, Editor>,
) -> Option<()> {
let multi_buffer_snapshot = self.buffer().read(cx).snapshot(cx);
@@ -423,57 +303,28 @@ impl Editor {
&mut self,
diff_base_buffer: Model<Buffer>,
deleted_text_height: u8,
hunk: &HoveredHunk,
hunk: &HunkToExpand,
cx: &mut ViewContext<'_, Self>,
) -> Option<CustomBlockId> {
) -> Option<BlockId> {
let deleted_hunk_color = deleted_hunk_color(cx);
let (editor_height, editor_with_deleted_text) =
editor_with_deleted_text(diff_base_buffer, deleted_hunk_color, hunk, cx);
let editor = cx.view().clone();
let editor_model = cx.model().clone();
let hunk = hunk.clone();
let mut new_block_ids = self.insert_blocks(
Some(BlockProperties {
position: hunk.multi_buffer_range.start,
height: editor_height.max(deleted_text_height),
style: BlockStyle::Flex,
disposition: BlockDisposition::Above,
render: Box::new(move |cx| {
let close_button = editor.update(cx.context, |editor, cx| {
let editor_snapshot = editor.snapshot(cx);
let hunk_start_row = hunk
.multi_buffer_range
.start
.to_display_point(&editor_snapshot)
.row();
editor.render_close_hunk_diff_button(hunk.clone(), hunk_start_row, cx)
});
let gutter_dimensions = editor_model.read(cx).gutter_dimensions;
let click_editor = editor.clone();
h_flex()
div()
.bg(deleted_hunk_color)
.size_full()
.child(
v_flex()
.max_w(gutter_dimensions.full_width())
.min_w(gutter_dimensions.full_width())
.size_full()
.on_mouse_down(MouseButton::Left, {
let click_hunk = hunk.clone();
move |e, cx| {
let modifiers = e.modifiers;
if modifiers.control || modifiers.platform {
click_editor.update(cx, |editor, cx| {
editor.toggle_hovered_hunk(&click_hunk, cx);
});
}
}
})
.child(close_button),
)
.pl(gutter_dimensions.full_width())
.child(editor_with_deleted_text.clone())
.into_any_element()
}),
disposition: BlockDisposition::Above,
}),
None,
cx,
@@ -488,21 +339,16 @@ impl Editor {
}
}
pub(super) fn clear_clicked_diff_hunks(&mut self, cx: &mut ViewContext<'_, Editor>) -> bool {
pub(super) fn clear_expanded_diff_hunks(&mut self, cx: &mut ViewContext<'_, Editor>) {
self.expanded_hunks.hunk_update_tasks.clear();
self.clear_row_highlights::<DiffRowHighlight>();
let to_remove = self
.expanded_hunks
.hunks
.drain(..)
.filter_map(|expanded_hunk| expanded_hunk.block)
.collect::<HashSet<_>>();
if to_remove.is_empty() {
false
} else {
self.remove_blocks(to_remove, None, cx);
true
}
.collect();
self.clear_row_highlights::<DiffRowHighlight>();
self.remove_blocks(to_remove, None, cx);
}
pub(super) fn sync_expanded_diff_hunks(
@@ -611,7 +457,7 @@ impl Editor {
recalculated_hunks.next();
retain = true;
} else {
hunks_to_reexpand.push(HoveredHunk {
hunks_to_reexpand.push(HunkToExpand {
status,
multi_buffer_range,
diff_base_byte_range,
@@ -676,29 +522,6 @@ impl Editor {
}
}
fn to_diff_hunk(
hovered_hunk: &HoveredHunk,
multi_buffer_snapshot: &MultiBufferSnapshot,
) -> Option<DiffHunk<MultiBufferRow>> {
let buffer_id = hovered_hunk
.multi_buffer_range
.start
.buffer_id
.or_else(|| hovered_hunk.multi_buffer_range.end.buffer_id)?;
let buffer_range = hovered_hunk.multi_buffer_range.start.text_anchor
..hovered_hunk.multi_buffer_range.end.text_anchor;
let point_range = hovered_hunk
.multi_buffer_range
.to_point(&multi_buffer_snapshot);
Some(DiffHunk {
associated_range: MultiBufferRow(point_range.start.row)
..MultiBufferRow(point_range.end.row),
buffer_id,
buffer_range,
diff_base_byte_range: hovered_hunk.diff_base_byte_range.clone(),
})
}
fn create_diff_base_buffer(buffer: &Model<Buffer>, cx: &mut AppContext) -> Option<Model<Buffer>> {
buffer
.update(cx, |buffer, _| {
@@ -732,7 +555,7 @@ fn deleted_hunk_color(cx: &AppContext) -> Hsla {
fn editor_with_deleted_text(
diff_base_buffer: Model<Buffer>,
deleted_color: Hsla,
hunk: &HoveredHunk,
hunk: &HunkToExpand,
cx: &mut ViewContext<'_, Editor>,
) -> (u8, View<Editor>) {
let parent_editor = cx.view().downgrade();
@@ -790,12 +613,11 @@ fn editor_with_deleted_text(
}
}),
]);
let parent_editor_for_reverts = parent_editor.clone();
let original_multi_buffer_range = hunk.multi_buffer_range.clone();
let diff_base_range = hunk.diff_base_byte_range.clone();
editor
.register_action::<RevertSelectedHunks>(move |_, cx| {
parent_editor_for_reverts
parent_editor
.update(cx, |editor, cx| {
let Some((buffer, original_text)) =
editor.buffer().update(cx, |buffer, cx| {
@@ -823,16 +645,6 @@ fn editor_with_deleted_text(
.ok();
})
.detach();
let hunk = hunk.clone();
editor
.register_action::<ToggleHunkDiff>(move |_, cx| {
parent_editor
.update(cx, |editor, cx| {
editor.toggle_hovered_hunk(&hunk, cx);
})
.ok();
})
.detach();
editor
});

View File

@@ -2581,7 +2581,7 @@ pub mod tests {
);
let worktree_id = project.update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
project.worktrees().next().unwrap().read(cx).id()
});
let buffer_1 = project
@@ -2931,7 +2931,7 @@ pub mod tests {
);
let worktree_id = project.update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
project.worktrees().next().unwrap().read(cx).id()
});
let buffer_1 = project

View File

@@ -1,7 +1,6 @@
use crate::Direction;
use gpui::{AppContext, Model, ModelContext};
use language::Buffer;
use std::ops::Range;
pub trait InlineCompletionProvider: 'static + Sized {
fn name() -> &'static str;
@@ -32,7 +31,7 @@ pub trait InlineCompletionProvider: 'static + Sized {
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,
) -> Option<(&'a str, Option<Range<language::Anchor>>)>;
) -> Option<&'a str>;
}
pub trait InlineCompletionProviderHandle {
@@ -63,7 +62,7 @@ pub trait InlineCompletionProviderHandle {
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,
) -> Option<(&'a str, Option<Range<language::Anchor>>)>;
) -> Option<&'a str>;
}
impl<T> InlineCompletionProviderHandle for Model<T>
@@ -118,7 +117,7 @@ where
buffer: &Model<Buffer>,
cursor_position: language::Anchor,
cx: &'a AppContext,
) -> Option<(&'a str, Option<Range<language::Anchor>>)> {
) -> Option<&'a str> {
self.read(cx)
.active_completion_text(buffer, cursor_position, cx)
}

View File

@@ -5,7 +5,6 @@ use crate::{
};
use anyhow::{anyhow, Context as _, Result};
use collections::HashSet;
use file_icons::FileIcons;
use futures::future::try_join_all;
use git::repository::GitFileStatus;
use gpui::{
@@ -17,13 +16,10 @@ use language::{
proto::serialize_anchor as serialize_text_anchor, Bias, Buffer, CharKind, Point, SelectionGoal,
};
use multi_buffer::AnchorRangeExt;
use project::{
project_settings::ProjectSettings, search::SearchQuery, FormatTrigger, Item as _, Project,
ProjectPath,
};
use project::{search::SearchQuery, FormatTrigger, Item as _, Project, ProjectPath};
use rpc::proto::{self, update_view, PeerId};
use settings::Settings;
use workspace::item::{Dedup, ItemSettings, SerializableItem, TabContentParams};
use workspace::item::{Dedup, ItemSettings, TabContentParams};
use std::{
any::TypeId,
@@ -40,7 +36,7 @@ use ui::{h_flex, prelude::*, Label};
use util::{paths::PathExt, ResultExt, TryFutureExt};
use workspace::item::{BreadcrumbText, FollowEvent};
use workspace::{
item::{FollowableItem, Item, ItemEvent, ProjectItem},
item::{FollowableItem, Item, ItemEvent, ItemHandle, ProjectItem},
searchable::{Direction, SearchEvent, SearchableItem, SearchableItemHandle},
ItemId, ItemNavHistory, Pane, ToolbarItemLocation, ViewId, Workspace, WorkspaceId,
};
@@ -591,20 +587,6 @@ impl Item for Editor {
Some(path.to_string_lossy().to_string().into())
}
fn tab_icon(&self, cx: &WindowContext) -> Option<Icon> {
ItemSettings::get_global(cx)
.file_icons
.then(|| {
self.buffer
.read(cx)
.as_singleton()
.and_then(|buffer| buffer.read(cx).project_path(cx))
.and_then(|path| FileIcons::get_icon(path.path.as_ref(), cx))
})
.flatten()
.map(|icon| Icon::from_path(icon))
}
fn tab_content(&self, params: TabContentParams, cx: &WindowContext) -> AnyElement {
let label_color = if ItemSettings::get_global(cx).git_status {
self.buffer()
@@ -855,8 +837,54 @@ impl Item for Editor {
Some(breadcrumbs)
}
fn added_to_workspace(&mut self, workspace: &mut Workspace, _: &mut ViewContext<Self>) {
fn added_to_workspace(&mut self, workspace: &mut Workspace, cx: &mut ViewContext<Self>) {
self.workspace = Some((workspace.weak_handle(), workspace.database_id()));
let Some(workspace_id) = workspace.database_id() else {
return;
};
let item_id = cx.view().item_id().as_u64() as ItemId;
fn serialize(
buffer: Model<Buffer>,
workspace_id: WorkspaceId,
item_id: ItemId,
cx: &mut AppContext,
) {
if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) {
let path = file.abs_path(cx);
cx.background_executor()
.spawn(async move {
DB.save_path(item_id, workspace_id, path.clone())
.await
.log_err()
})
.detach();
}
}
if let Some(buffer) = self.buffer().read(cx).as_singleton() {
serialize(buffer.clone(), workspace_id, item_id, cx);
cx.subscribe(&buffer, |this, buffer, event, cx| {
if let Some((_, Some(workspace_id))) = this.workspace.as_ref() {
if let language::Event::FileHandleChanged = event {
serialize(
buffer,
*workspace_id,
cx.view().item_id().as_u64() as ItemId,
cx,
);
}
}
})
.detach();
}
}
fn serialized_item_kind() -> Option<&'static str> {
Some("Editor")
}
fn to_item_events(event: &EditorEvent, mut f: impl FnMut(ItemEvent)) {
@@ -892,20 +920,6 @@ impl Item for Editor {
_ => {}
}
}
}
impl SerializableItem for Editor {
fn serialized_item_kind() -> &'static str {
"Editor"
}
fn cleanup(
workspace_id: WorkspaceId,
alive_items: Vec<ItemId>,
cx: &mut WindowContext,
) -> Task<Result<()>> {
cx.spawn(|_| DB.delete_unloaded_items(workspace_id, alive_items))
}
fn deserialize(
project: Model<Project>,
@@ -914,171 +928,41 @@ impl SerializableItem for Editor {
item_id: ItemId,
cx: &mut ViewContext<Pane>,
) -> Task<Result<View<Self>>> {
let path_content_language = match DB
.get_path_and_contents(item_id, workspace_id)
.context("Failed to query editor state")
{
Ok(Some((path, content, language))) => {
if ProjectSettings::get_global(cx)
.session
.restore_unsaved_buffers
{
(path, content, language)
} else {
(path, None, None)
}
}
Ok(None) => {
return Task::ready(Err(anyhow!("No path or contents found for buffer")));
}
Err(error) => {
return Task::ready(Err(error));
}
};
let project_item: Result<_> = project.update(cx, |project, cx| {
// Look up the path with this key associated, create a self with that path
let path = DB
.get_path(item_id, workspace_id)?
.context("No path stored for this editor")?;
match path_content_language {
(None, Some(content), language_name) => cx.spawn(|_, mut cx| async move {
let language = if let Some(language_name) = language_name {
let language_registry =
project.update(&mut cx, |project, _| project.languages().clone())?;
let (worktree, path) = project
.find_worktree(&path, cx)
.with_context(|| format!("No worktree for path: {path:?}"))?;
let project_path = ProjectPath {
worktree_id: worktree.read(cx).id(),
path: path.into(),
};
Some(language_registry.language_for_name(&language_name).await?)
} else {
None
};
Ok(project.open_path(project_path, cx))
});
// First create the empty buffer
let buffer = project.update(&mut cx, |project, cx| {
project.create_local_buffer("", language, cx)
})?;
project_item
.map(|project_item| {
cx.spawn(|pane, mut cx| async move {
let (_, project_item) = project_item.await?;
let buffer = project_item
.downcast::<Buffer>()
.map_err(|_| anyhow!("Project item at stored path was not a buffer"))?;
pane.update(&mut cx, |_, cx| {
cx.new_view(|cx| {
let mut editor = Editor::for_buffer(buffer, Some(project), cx);
// Then set the text so that the dirty bit is set correctly
buffer.update(&mut cx, |buffer, cx| {
buffer.set_text(content, cx);
})?;
cx.new_view(|cx| {
let mut editor = Editor::for_buffer(buffer, Some(project), cx);
editor.read_scroll_position_from_db(item_id, workspace_id, cx);
editor
})
}),
(Some(path), contents, _) => {
let project_item = project.update(cx, |project, cx| {
let (worktree, path) = project
.find_worktree(&path, cx)
.with_context(|| format!("No worktree for path: {path:?}"))?;
let project_path = ProjectPath {
worktree_id: worktree.read(cx).id(),
path: path.into(),
};
Ok(project.open_path(project_path, cx))
});
project_item
.map(|project_item| {
cx.spawn(|pane, mut cx| async move {
let (_, project_item) = project_item.await?;
let buffer = project_item.downcast::<Buffer>().map_err(|_| {
anyhow!("Project item at stored path was not a buffer")
})?;
// This is a bit wasteful: we're loading the whole buffer from
// disk and then overwrite the content.
// But for now, it keeps the implementation of the content serialization
// simple, because we don't have to persist all of the metadata that we get
// by loading the file (git diff base, mtime, ...).
if let Some(buffer_text) = contents {
buffer.update(&mut cx, |buffer, cx| {
buffer.set_text(buffer_text, cx);
})?;
}
pane.update(&mut cx, |_, cx| {
cx.new_view(|cx| {
let mut editor = Editor::for_buffer(buffer, Some(project), cx);
editor.read_scroll_position_from_db(item_id, workspace_id, cx);
editor
})
})
editor.read_scroll_position_from_db(item_id, workspace_id, cx);
editor
})
})
.unwrap_or_else(|error| Task::ready(Err(error)))
}
_ => Task::ready(Err(anyhow!("No path or contents found for buffer"))),
}
}
fn serialize(
&mut self,
workspace: &mut Workspace,
item_id: ItemId,
closing: bool,
cx: &mut ViewContext<Self>,
) -> Option<Task<Result<()>>> {
let mut serialize_dirty_buffers = self.serialize_dirty_buffers;
let project = self.project.clone()?;
if project.read(cx).visible_worktrees(cx).next().is_none() {
// If we don't have a worktree, we don't serialize, because
// projects without worktrees aren't deserialized.
serialize_dirty_buffers = false;
}
if closing && !serialize_dirty_buffers {
return None;
}
let workspace_id = workspace.database_id()?;
let buffer = self.buffer().read(cx).as_singleton()?;
let is_dirty = buffer.read(cx).is_dirty();
let path = buffer
.read(cx)
.file()
.and_then(|file| file.as_local())
.map(|file| file.abs_path(cx));
let snapshot = buffer.read(cx).snapshot();
Some(cx.spawn(|_this, cx| async move {
cx.background_executor()
.spawn(async move {
if let Some(path) = path {
DB.save_path(item_id, workspace_id, path.clone())
.await
.context("failed to save path of buffer")?
}
if serialize_dirty_buffers {
let (contents, language) = if is_dirty {
let contents = snapshot.text();
let language = snapshot.language().map(|lang| lang.name().to_string());
(Some(contents), language)
} else {
(None, None)
};
DB.save_contents(item_id, workspace_id, contents, language)
.await?;
}
anyhow::Ok(())
})
.await
.context("failed to save contents of buffer")?;
Ok(())
}))
}
fn should_serialize(&self, event: &Self::Event) -> bool {
matches!(
event,
EditorEvent::Saved | EditorEvent::DirtyChanged | EditorEvent::BufferEdited
)
})
.unwrap_or_else(|error| Task::ready(Err(error)))
}
}

View File

@@ -10,72 +10,14 @@ use gpui::prelude::FluentBuilder;
use gpui::{DismissEvent, Pixels, Point, Subscription, View, ViewContext};
use workspace::OpenInTerminal;
#[derive(Debug)]
pub enum MenuPosition {
/// When the editor is scrolled, the context menu stays on the exact
/// same position on the screen, never disappearing.
PinnedToScreen(Point<Pixels>),
/// When the editor is scrolled, the context menu follows the position it is associated with.
/// Disappears when the position is no longer visible.
PinnedToEditor {
source: multi_buffer::Anchor,
offset_x: Pixels,
offset_y: Pixels,
},
}
pub struct MouseContextMenu {
pub(crate) position: MenuPosition,
pub(crate) position: Point<Pixels>,
pub(crate) context_menu: View<ui::ContextMenu>,
_subscription: Subscription,
}
impl std::fmt::Debug for MouseContextMenu {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MouseContextMenu")
.field("position", &self.position)
.field("context_menu", &self.context_menu)
.finish()
}
}
impl MouseContextMenu {
pub(crate) fn pinned_to_editor(
editor: &mut Editor,
source: multi_buffer::Anchor,
position: Point<Pixels>,
context_menu: View<ui::ContextMenu>,
cx: &mut ViewContext<Editor>,
) -> Option<Self> {
let context_menu_focus = context_menu.focus_handle(cx);
cx.focus(&context_menu_focus);
let _subscription = cx.subscribe(
&context_menu,
move |editor, _, _event: &DismissEvent, cx| {
editor.mouse_context_menu.take();
if context_menu_focus.contains_focused(cx) {
editor.focus(cx);
}
},
);
let editor_snapshot = editor.snapshot(cx);
let source_point = editor.to_pixel_point(source, &editor_snapshot, cx)?;
let offset = position - source_point;
Some(Self {
position: MenuPosition::PinnedToEditor {
source,
offset_x: offset.x,
offset_y: offset.y,
},
context_menu,
_subscription,
})
}
pub(crate) fn pinned_to_screen(
pub(crate) fn new(
position: Point<Pixels>,
context_menu: View<ui::ContextMenu>,
cx: &mut ViewContext<Editor>,
@@ -83,18 +25,16 @@ impl MouseContextMenu {
let context_menu_focus = context_menu.focus_handle(cx);
cx.focus(&context_menu_focus);
let _subscription = cx.subscribe(
&context_menu,
move |editor, _, _event: &DismissEvent, cx| {
editor.mouse_context_menu.take();
let _subscription =
cx.subscribe(&context_menu, move |this, _, _event: &DismissEvent, cx| {
this.mouse_context_menu.take();
if context_menu_focus.contains_focused(cx) {
editor.focus(cx);
this.focus(cx);
}
},
);
});
Self {
position: MenuPosition::PinnedToScreen(position),
position,
context_menu,
_subscription,
}
@@ -131,15 +71,13 @@ pub fn deploy_context_menu(
return;
}
let display_map = editor.selections.display_map(cx);
let source_anchor = display_map.display_point_to_anchor(point, text::Bias::Right);
let context_menu = if let Some(custom) = editor.custom_context_menu.take() {
let menu = custom(editor, point, cx);
editor.custom_context_menu = Some(custom);
let Some(menu) = menu else {
if menu.is_none() {
return;
};
menu
}
menu.unwrap()
} else {
// Don't show the context menu if there isn't a project associated with this editor
if editor.project.is_none() {
@@ -147,20 +85,17 @@ pub fn deploy_context_menu(
}
let display_map = editor.selections.display_map(cx);
let buffer = &editor.snapshot(cx).buffer_snapshot;
let anchor = buffer.anchor_before(point.to_point(&display_map));
if !display_ranges(&display_map, &editor.selections).any(|r| r.contains(&point)) {
// Move the cursor to the clicked location so that dispatched actions make sense
editor.change_selections(None, cx, |s| {
s.clear_disjoint();
s.set_pending_anchor_range(anchor..anchor, SelectMode::Character);
s.set_pending_display_range(point..point, SelectMode::Character);
});
}
let focus = cx.focused();
ui::ContextMenu::build(cx, |menu, _cx| {
let builder = menu
.on_blur_subscription(Subscription::new(|| {}))
.action("Rename Symbol", Box::new(Rename))
.action("Go to Definition", Box::new(GoToDefinition))
.action("Go to Type Definition", Box::new(GoToTypeDefinition))
@@ -191,9 +126,8 @@ pub fn deploy_context_menu(
}
})
};
editor.mouse_context_menu =
MouseContextMenu::pinned_to_editor(editor, source_anchor, position, context_menu, cx);
let mouse_context_menu = MouseContextMenu::new(position, context_menu, cx);
editor.mouse_context_menu = Some(mouse_context_menu);
cx.notify();
}

View File

@@ -1,5 +1,3 @@
use anyhow::Result;
use db::sqlez::statement::Statement;
use std::path::PathBuf;
use db::sqlez_macros::sql;
@@ -12,12 +10,10 @@ define_connection!(
// editors(
// item_id: usize,
// workspace_id: usize,
// path: Option<PathBuf>,
// path: PathBuf,
// scroll_top_row: usize,
// scroll_vertical_offset: f32,
// scroll_horizontal_offset: f32,
// content: Option<String>,
// language: Option<String>,
// )
pub static ref DB: EditorDb<WorkspaceDb> =
&[sql! (
@@ -35,39 +31,13 @@ define_connection!(
ALTER TABLE editors ADD COLUMN scroll_top_row INTEGER NOT NULL DEFAULT 0;
ALTER TABLE editors ADD COLUMN scroll_horizontal_offset REAL NOT NULL DEFAULT 0;
ALTER TABLE editors ADD COLUMN scroll_vertical_offset REAL NOT NULL DEFAULT 0;
),
sql! (
// Since sqlite3 doesn't support ALTER COLUMN, we create a new
// table, move the data over, drop the old table, rename new table.
CREATE TABLE new_editors_tmp (
item_id INTEGER NOT NULL,
workspace_id INTEGER NOT NULL,
path BLOB, // <-- No longer "NOT NULL"
scroll_top_row INTEGER NOT NULL DEFAULT 0,
scroll_horizontal_offset REAL NOT NULL DEFAULT 0,
scroll_vertical_offset REAL NOT NULL DEFAULT 0,
contents TEXT, // New
language TEXT, // New
PRIMARY KEY(item_id, workspace_id),
FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
ON DELETE CASCADE
ON UPDATE CASCADE
) STRICT;
INSERT INTO new_editors_tmp(item_id, workspace_id, path, scroll_top_row, scroll_horizontal_offset, scroll_vertical_offset)
SELECT item_id, workspace_id, path, scroll_top_row, scroll_horizontal_offset, scroll_vertical_offset
FROM editors;
DROP TABLE editors;
ALTER TABLE new_editors_tmp RENAME TO editors;
)];
);
impl EditorDb {
query! {
pub fn get_path_and_contents(item_id: ItemId, workspace_id: WorkspaceId) -> Result<Option<(Option<PathBuf>, Option<String>, Option<String>)>> {
SELECT path, contents, language FROM editors
pub fn get_path(item_id: ItemId, workspace_id: WorkspaceId) -> Result<Option<PathBuf>> {
SELECT path FROM editors
WHERE item_id = ? AND workspace_id = ?
}
}
@@ -85,20 +55,6 @@ impl EditorDb {
}
}
query! {
pub async fn save_contents(item_id: ItemId, workspace: WorkspaceId, contents: Option<String>, language: Option<String>) -> Result<()> {
INSERT INTO editors
(item_id, workspace_id, contents, language)
VALUES
(?1, ?2, ?3, ?4)
ON CONFLICT DO UPDATE SET
item_id = ?1,
workspace_id = ?2,
contents = ?3,
language = ?4
}
}
// Returns the scroll top row, and offset
query! {
pub fn get_scroll_position(item_id: ItemId, workspace_id: WorkspaceId) -> Result<Option<(u32, f32, f32)>> {
@@ -124,75 +80,4 @@ impl EditorDb {
WHERE item_id = ?1 AND workspace_id = ?2
}
}
pub async fn delete_unloaded_items(
&self,
workspace: WorkspaceId,
alive_items: Vec<ItemId>,
) -> Result<()> {
let placeholders = alive_items
.iter()
.map(|_| "?")
.collect::<Vec<&str>>()
.join(", ");
let query = format!(
"DELETE FROM editors WHERE workspace_id = ? AND item_id NOT IN ({placeholders})"
);
self.write(move |conn| {
let mut statement = Statement::prepare(conn, query)?;
let mut next_index = statement.bind(&workspace, 1)?;
for id in alive_items {
next_index = statement.bind(&id, next_index)?;
}
statement.exec()
})
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui;
#[gpui::test]
async fn test_saving_content() {
env_logger::try_init().ok();
let workspace_id = workspace::WORKSPACE_DB.next_id().await.unwrap();
// Sanity check: make sure there is no row in the `editors` table
assert_eq!(DB.get_path_and_contents(1234, workspace_id).unwrap(), None);
// Save content/language
DB.save_contents(
1234,
workspace_id,
Some("testing".into()),
Some("Go".into()),
)
.await
.unwrap();
// Check that it can be read from DB
let path_and_contents = DB.get_path_and_contents(1234, workspace_id).unwrap();
let (path, contents, language) = path_and_contents.unwrap();
assert!(path.is_none());
assert_eq!(contents, Some("testing".to_owned()));
assert_eq!(language, Some("Go".to_owned()));
// Update it with NULL
DB.save_contents(1234, workspace_id, None, None)
.await
.unwrap();
// Check that it worked
let path_and_contents = DB.get_path_and_contents(1234, workspace_id).unwrap();
let (path, contents, language) = path_and_contents.unwrap();
assert!(path.is_none());
assert!(contents.is_none());
assert!(language.is_none());
}
}

View File

@@ -113,7 +113,6 @@ pub fn expand_macro_recursively(
cx.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx)),
),
None,
true,
cx,
);
})

Some files were not shown because too many files have changed in this diff Show More