Compare commits

..

1 Commits

Author SHA1 Message Date
Max Brunsfeld
961e548bd6 Introduce extension-cli for packaging extensions in CI 2024-03-18 16:57:37 -07:00
650 changed files with 20320 additions and 32151 deletions

View File

@@ -23,6 +23,12 @@ body:
description: Run the `copy system specs into clipboard` command palette action and paste the output in the field below.
validations:
required: true
- type: textarea
attributes:
label: If applicable, add mockups / screenshots to help explain present your vision of the feature
description: Drag issues into the text input below
validations:
required: false
- type: textarea
attributes:
label: If applicable, attach your `~/Library/Logs/Zed/Zed.log` file to this issue.

View File

@@ -54,9 +54,6 @@ jobs:
- name: Check unused dependencies
uses: bnjbvr/cargo-machete@main
- name: Check license generation
run: script/generate-licenses /tmp/zed_licenses_output
- name: Ensure fresh merge
shell: bash -euxo pipefail {0}
run: |

View File

@@ -1,40 +0,0 @@
name: Publish zed-extension CLI
on:
push:
tags:
- extension-cli
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
jobs:
publish:
name: Publish zed-extension CLI
runs-on:
- ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
clean: false
submodules: "recursive"
- name: Cache dependencies
uses: swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
- name: Configure linux
shell: bash -euxo pipefail {0}
run: script/linux
- name: Build extension CLI
run: cargo build --release --package extension_cli
- name: Upload binary
env:
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
run: script/upload-extension-cli ${{ github.sha }}

View File

@@ -9,10 +9,10 @@ jobs:
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/setup-python@v4
with:
python-version: "3.11"
python-version: "3.10.5"
architecture: "x64"
cache: "pip"
- run: pip install -r script/update_top_ranking_issues/requirements.txt
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
- run: python script/update_top_ranking_issues/main.py 5393 --github-token ${{ secrets.GITHUB_TOKEN }} --prod

View File

@@ -9,10 +9,10 @@ jobs:
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/setup-python@v4
with:
python-version: "3.11"
python-version: "3.10.5"
architecture: "x64"
cache: "pip"
- run: pip install -r script/update_top_ranking_issues/requirements.txt
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
- run: python script/update_top_ranking_issues/main.py 6952 --github-token ${{ secrets.GITHUB_TOKEN }} --prod --query-day-interval 7

View File

@@ -15,10 +15,6 @@
"JSON": {
"tab_size": 2,
"formatter": "prettier"
},
"JavaScript": {
"tab_size": 2,
"formatter": "prettier"
}
},
"formatter": "auto"

View File

1342
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/anthropic",
"crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
@@ -29,16 +29,13 @@ members = [
"crates/feature_flags",
"crates/feedback",
"crates/file_finder",
"crates/file_icons",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
"crates/google_ai",
"crates/gpui",
"crates/gpui_macros",
"crates/image_viewer",
"crates/install_cli",
"crates/journal",
"crates/language",
@@ -54,7 +51,6 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/open_ai",
"crates/outline",
"crates/picker",
"crates/prettier",
@@ -72,7 +68,7 @@ members = [
"crates/task",
"crates/tasks_ui",
"crates/search",
"crates/semantic_version",
"crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
@@ -80,11 +76,9 @@ members = [
"crates/story",
"crates/storybook",
"crates/sum_tree",
"crates/tab_switcher",
"crates/terminal",
"crates/terminal_view",
"crates/text",
"crates/text-eg",
"crates/theme",
"crates/theme_importer",
"crates/theme_selector",
@@ -100,21 +94,8 @@ members = [
"crates/zed",
"crates/zed_actions",
"extensions/astro",
"extensions/clojure",
"extensions/csharp",
"extensions/emmet",
"extensions/erlang",
"extensions/gleam",
"extensions/haskell",
"extensions/html",
"extensions/php",
"extensions/prisma",
"extensions/purescript",
"extensions/svelte",
"extensions/toml",
"extensions/uiua",
"extensions/zig",
"tooling/xtask",
]
@@ -124,7 +105,6 @@ resolver = "2"
[workspace.dependencies]
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
audio = { path = "crates/audio" }
@@ -152,17 +132,14 @@ extensions_ui = { path = "crates/extensions_ui" }
feature_flags = { path = "crates/feature_flags" }
feedback = { path = "crates/feedback" }
file_finder = { path = "crates/file_finder" }
file_icons = { path = "crates/file_icons" }
fs = { path = "crates/fs" }
fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }
git = { path = "crates/git" }
go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
install_cli = { path = "crates/install_cli" }
image_viewer = { path = "crates/image_viewer" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_selector = { path = "crates/language_selector" }
@@ -177,7 +154,6 @@ menu = { path = "crates/menu" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
@@ -196,7 +172,7 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
semantic_version = { path = "crates/semantic_version" }
semantic_index = { path = "crates/semantic_index" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
sqlez = { path = "crates/sqlez" }
@@ -204,7 +180,6 @@ sqlez_macros = { path = "crates/sqlez_macros" }
story = { path = "crates/story" }
storybook = { path = "crates/storybook" }
sum_tree = { path = "crates/sum_tree" }
tab_switcher = { path = "crates/tab_switcher" }
terminal = { path = "crates/terminal" }
terminal_view = { path = "crates/terminal_view" }
text = { path = "crates/text" }
@@ -223,7 +198,6 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
anyhow = "1.0.57"
any_vec = "0.13"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-fs = "1.6"
async-recursion = "1.0.0"
@@ -233,7 +207,7 @@ bitflags = "2.4.2"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "61cbd6b2c224791d52b150fe535cee665cc91bb2" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "61cbd6b2c224791d52b150fe535cee665cc91bb2" }
blade-rwh = { package = "raw-window-handle", version = "0.5" }
cap-std = "3.0"
cap-std = "2.0"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive"] }
clickhouse = { version = "0.11.6" }
@@ -265,9 +239,7 @@ parking_lot = "0.12.1"
profiling = "1"
postage = { version = "0.5", features = ["futures-traits"] }
pretty_assertions = "1.3.0"
prost = "0.9"
prost-build = "0.9"
prost-types = "0.9"
prost = "0.8"
pulldown-cmark = { version = "0.10.0", default-features = false }
rand = "0.8.5"
refineable = { path = "./crates/refineable" }
@@ -295,8 +267,6 @@ tempfile = "3.9.0"
thiserror = "1.0.29"
tiktoken-rs = "0.5.7"
time = { version = "0.3", features = [
"macros",
"parsing",
"serde",
"serde-well-known",
"formatting",
@@ -305,18 +275,25 @@ toml = "0.8"
tokio = { version = "1", features = ["full"] }
tower-http = "0.4.4"
tree-sitter = { version = "0.20", features = ["wasm"] }
tree-sitter-astro = { git = "https://github.com/virchau13/tree-sitter-astro.git", rev = "e924787e12e8a03194f36a113290ac11d6dc10f3" }
tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "7331995b19b8f8aba2d5e26deb51d2195c18bc94" }
tree-sitter-c = "0.20.1"
tree-sitter-clojure = { git = "https://github.com/prcastro/tree-sitter-clojure", branch = "update-ts" }
tree-sitter-c-sharp = { git = "https://github.com/tree-sitter/tree-sitter-c-sharp", rev = "dd5e59721a5f8dae34604060833902b882023aaf" }
tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev = "f44509141e7e483323d2ec178f2d2e6c0fc041c1" }
tree-sitter-css = { git = "https://github.com/tree-sitter/tree-sitter-css", rev = "769203d0f9abe1a9a691ac2b9fe4bb4397a73c51" }
tree-sitter-dockerfile = { git = "https://github.com/camdencheek/tree-sitter-dockerfile", rev = "33e22c33bcdbfc33d42806ee84cfd0b1248cc392" }
tree-sitter-dart = { git = "https://github.com/agent3bood/tree-sitter-dart", rev = "48934e3bf757a9b78f17bdfaa3e2b4284656fdc7" }
tree-sitter-elixir = { git = "https://github.com/elixir-lang/tree-sitter-elixir", rev = "a2861e88a730287a60c11ea9299c033c7d076e30" }
tree-sitter-elm = { git = "https://github.com/elm-tooling/tree-sitter-elm", rev = "692c50c0b961364c40299e73c1306aecb5d20f40" }
tree-sitter-embedded-template = "0.20.0"
tree-sitter-erlang = "0.4.0"
tree-sitter-gleam = { git = "https://github.com/gleam-lang/tree-sitter-gleam", rev = "58b7cac8fc14c92b0677c542610d8738c373fa81" }
tree-sitter-glsl = { git = "https://github.com/theHamsta/tree-sitter-glsl", rev = "2a56fb7bc8bb03a1892b4741279dd0a8758b7fb3" }
tree-sitter-go = { git = "https://github.com/tree-sitter/tree-sitter-go", rev = "aeb2f33b366fd78d5789ff104956ce23508b85db" }
tree-sitter-gomod = { git = "https://github.com/camdencheek/tree-sitter-go-mod" }
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work" }
tree-sitter-haskell = { git = "https://github.com/tree-sitter/tree-sitter-haskell", rev = "8a99848fc734f9c4ea523b3f2a07df133cbbcec2" }
tree-sitter-hcl = { git = "https://github.com/MichaHoffmann/tree-sitter-hcl", rev = "v1.1.0" }
rustc-demangle = "0.1.23"
tree-sitter-heex = { git = "https://github.com/phoenixframework/tree-sitter-heex", rev = "2e1348c3cf2c9323e87c2744796cf3f3868aa82a" }
@@ -328,32 +305,38 @@ tree-sitter-markdown = { git = "https://github.com/MDeiml/tree-sitter-markdown",
tree-sitter-nix = { git = "https://github.com/nix-community/tree-sitter-nix", rev = "66e3e9ce9180ae08fc57372061006ef83f0abde7" }
tree-sitter-nu = { git = "https://github.com/nushell/tree-sitter-nu", rev = "7dd29f9616822e5fc259f5b4ae6c4ded9a71a132" }
tree-sitter-ocaml = { git = "https://github.com/tree-sitter/tree-sitter-ocaml", rev = "4abfdc1c7af2c6c77a370aee974627be1c285b3b" }
tree-sitter-php = "0.21.1"
tree-sitter-prisma-io = { git = "https://github.com/victorhqc/tree-sitter-prisma" }
tree-sitter-proto = { git = "https://github.com/rewinfrey/tree-sitter-proto", rev = "36d54f288aee112f13a67b550ad32634d0c2cb52" }
tree-sitter-purescript = { git = "https://github.com/postsolar/tree-sitter-purescript", rev = "v0.1.0" }
tree-sitter-python = "0.20.2"
tree-sitter-racket = { git = "https://github.com/zed-industries/tree-sitter-racket", rev = "eb010cf2c674c6fd9a6316a84e28ef90190fe51a" }
tree-sitter-regex = "0.20.0"
tree-sitter-ruby = "0.20.0"
tree-sitter-rust = "0.20.3"
tree-sitter-scheme = { git = "https://github.com/6cdh/tree-sitter-scheme", rev = "af0fd1fa452cb2562dc7b5c8a8c55551c39273b9" }
tree-sitter-svelte = { git = "https://github.com/Himujjal/tree-sitter-svelte", rev = "bd60db7d3d06f89b6ec3b287c9a6e9190b5564bd" }
tree-sitter-toml = { git = "https://github.com/tree-sitter/tree-sitter-toml", rev = "342d9be207c2dba869b9967124c679b5e6fd0ebe" }
tree-sitter-typescript = { git = "https://github.com/tree-sitter/tree-sitter-typescript", rev = "5d20856f34315b068c41edaee2ac8a100081d259" }
tree-sitter-vue = { git = "https://github.com/zed-industries/tree-sitter-vue", rev = "6608d9d60c386f19d80af7d8132322fa11199c42" }
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "f545a41f57502e1b5ddf2a6668896c1b0620f930" }
tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig", rev = "0d08703e4c3f426ec61695d7617415fff97029bd" }
unindent = "0.1.7"
unicase = "2.6"
url = "2.2"
uuid = { version = "1.1.2", features = ["v4"] }
wasmparser = "0.201"
wasm-encoder = "0.201"
wasmtime = { version = "19.0.0", default-features = false, features = [
wasmparser = "0.121"
wasm-encoder = "0.41"
wasmtime = { version = "18.0", default-features = false, features = [
"async",
"demangle",
"runtime",
"cranelift",
"component-model",
] }
wasmtime-wasi = "19.0.0"
wasmtime-wasi = "18.0"
which = "6.0.0"
wit-component = "0.201"
wit-component = "0.20"
sys-locale = "0.3.1"
[workspace.dependencies.windows]
@@ -368,7 +351,6 @@ features = [
"Win32_Security",
"Win32_Security_Credentials",
"Win32_Storage_FileSystem",
"Win32_System_LibraryLoader",
"Win32_System_Com",
"Win32_System_Com_StructuredStorage",
"Win32_System_DataExchange",
@@ -387,7 +369,7 @@ features = [
]
[patch.crates-io]
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7f21c3b98c0749ac192da67a0d65dfe3eabc4a63" }
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "4294e59279205f503eb14348dd5128bd5910c8fb" }
# Workaround for a broken nightly build of gpui: See #7644 and revisit once 0.5.3 is released.
pathfinder_simd = { git = "https://github.com/servo/pathfinder.git", rev = "30419d07660dc11a21e42ef4a7fa329600cff152" }
@@ -398,20 +380,15 @@ debug = "limited"
[profile.dev.package]
taffy = { opt-level = 3 }
cranelift-codegen = { opt-level = 3 }
resvg = { opt-level = 3 }
rustybuzz = { opt-level = 3 }
ttf-parser = { opt-level = 3 }
wasmtime-cranelift = { opt-level = 3 }
wasmtime = { opt-level = 3 }
[profile.release]
debug = "limited"
lto = "thin"
codegen-units = 1
[profile.release.package]
zed = { codegen-units = 16 }
[workspace.lints.clippy]
dbg_macro = "deny"
todo = "deny"

View File

@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2
FROM rust:1.77-bookworm as builder
FROM rust:1.76-bookworm as builder
WORKDIR app
COPY . .

View File

@@ -1,4 +0,0 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18 10L21 7L17 3L14 6M18 10L8 20H4V16L14 6M18 10L14 6" stroke="#000000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 379 B

View File

@@ -1,6 +1,4 @@
[
// todo(linux): Review the editor bindings
// Standard Linux bindings
{
"bindings": {
"up": "menu::SelectPrev",
@@ -11,14 +9,14 @@
"pagedown": "menu::SelectLast",
"shift-pagedown": "menu::SelectFirst",
"ctrl-n": "menu::SelectNext",
"ctrl-up": "menu::SelectFirst",
"ctrl-down": "menu::SelectLast",
"enter": "menu::Confirm",
"shift-f10": "menu::ShowContextMenu",
"ctrl-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"ctrl-escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",
"shift-enter": "picker::UseSelectedQuery",
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
"ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
"shift-enter": "menu::UseSelectedQuery",
"ctrl-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
"ctrl-o": "workspace::Open",
@@ -29,6 +27,8 @@
"ctrl-,": "zed::OpenSettings",
"ctrl-q": "zed::Quit",
"ctrl-h": "zed::Hide",
"alt-ctrl-h": "zed::HideOthers",
"ctrl-m": "zed::Minimize",
"f11": "zed::ToggleFullScreen"
}
},
@@ -45,101 +45,80 @@
"shift-tab": "editor::TabPrev",
"ctrl-k": "editor::CutToEndOfLine",
"ctrl-t": "editor::Transpose",
// "ctrl-backspace": "editor::DeleteToBeginningOfLine",
// "ctrl-delete": "editor::DeleteToEndOfLine",
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
// "ctrl-w": "editor::DeleteToPreviousWordStart",
"ctrl-delete": "editor::DeleteToNextWordEnd",
// "alt-h": "editor::DeleteToPreviousWordStart",
// "alt-d": "editor::DeleteToNextWordEnd",
"ctrl-backspace": "editor::DeleteToBeginningOfLine",
"ctrl-delete": "editor::DeleteToEndOfLine",
"alt-backspace": "editor::DeleteToPreviousWordStart",
"alt-delete": "editor::DeleteToNextWordEnd",
"alt-h": "editor::DeleteToPreviousWordStart",
"alt-d": "editor::DeleteToNextWordEnd",
"ctrl-x": "editor::Cut",
"ctrl-c": "editor::Copy",
"ctrl-v": "editor::Paste",
"ctrl-z": "editor::Undo",
"ctrl-shift-z": "editor::Redo",
"ctrl-y": "editor::Redo",
"up": "editor::MoveUp",
// "ctrl-up": "editor::MoveToStartOfParagraph", todo(linux) Should be "scroll down by 1 line"
"ctrl-up": "editor::MoveToStartOfParagraph",
"pageup": "editor::PageUp",
// "shift-pageup": "editor::MovePageUp", todo(linux) should be 'select page up'
"shift-pageup": "editor::MovePageUp",
"home": "editor::MoveToBeginningOfLine",
"down": "editor::MoveDown",
// "ctrl-down": "editor::MoveToEndOfParagraph", todo(linux) should be "scroll up by 1 line"
"ctrl-down": "editor::MoveToEndOfParagraph",
"pagedown": "editor::PageDown",
// "shift-pagedown": "editor::MovePageDown", todo(linux) should be 'select page down'
"shift-pagedown": "editor::MovePageDown",
"end": "editor::MoveToEndOfLine",
"left": "editor::MoveLeft",
"right": "editor::MoveRight",
"ctrl-left": "editor::MoveToPreviousWordStart",
// "alt-b": "editor::MoveToPreviousWordStart",
"ctrl-right": "editor::MoveToNextWordEnd",
// "alt-f": "editor::MoveToNextWordEnd",
// "cmd-left": "editor::MoveToBeginningOfLine",
// "ctrl-a": "editor::MoveToBeginningOfLine",
// "cmd-right": "editor::MoveToEndOfLine",
// "ctrl-e": "editor::MoveToEndOfLine",
"ctrl-p": "editor::MoveUp",
"ctrl-n": "editor::MoveDown",
"ctrl-b": "editor::MoveLeft",
"ctrl-f": "editor::MoveRight",
"ctrl-shift-l": "editor::NextScreen", // todo(linux): What is this
"alt-left": "editor::MoveToPreviousWordStart",
"alt-b": "editor::MoveToPreviousWordStart",
"alt-right": "editor::MoveToNextWordEnd",
"alt-f": "editor::MoveToNextWordEnd",
"ctrl-e": "editor::MoveToEndOfLine",
"ctrl-home": "editor::MoveToBeginning",
"ctrl-end": "editor::MoveToEnd",
"ctrl-=end": "editor::MoveToEnd",
"shift-up": "editor::SelectUp",
"shift-down": "editor::SelectDown",
"ctrl-shift-n": "editor::SelectDown",
"shift-left": "editor::SelectLeft",
"ctrl-shift-b": "editor::SelectLeft",
"shift-right": "editor::SelectRight",
"ctrl-shift-left": "editor::SelectToPreviousWordStart",
"ctrl-shift-right": "editor::SelectToNextWordEnd",
"ctrl-shift-up": "editor::AddSelectionAbove",
"ctrl-shift-down": "editor::AddSelectionBelow",
// "ctrl-shift-up": "editor::SelectToStartOfParagraph",
// "ctrl-shift-down": "editor::SelectToEndOfParagraph",
"ctrl-shift-f": "editor::SelectRight",
"alt-shift-left": "editor::SelectToPreviousWordStart",
"alt-shift-b": "editor::SelectToPreviousWordStart",
"alt-shift-right": "editor::SelectToNextWordEnd",
"alt-shift-f": "editor::SelectToNextWordEnd",
"ctrl-shift-up": "editor::SelectToStartOfParagraph",
"ctrl-shift-down": "editor::SelectToEndOfParagraph",
"ctrl-shift-home": "editor::SelectToBeginning",
"ctrl-shift-end": "editor::SelectToEnd",
"ctrl-a": "editor::SelectAll",
"ctrl-l": "editor::SelectLine",
"ctrl-shift-i": "editor::Format",
// "cmd-shift-left": [
// "editor::SelectToBeginningOfLine",
// {
// "stop_at_soft_wraps": true
// }
// ],
"shift-home": [
"editor::SelectToBeginningOfLine",
{
"stop_at_soft_wraps": true
}
],
// "ctrl-shift-a": [
// "editor::SelectToBeginningOfLine",
// {
// "stop_at_soft_wraps": true
// }
// ],
// "cmd-shift-right": [
// "editor::SelectToEndOfLine",
// {
// "stop_at_soft_wraps": true
// }
// ],
"shift-end": [
"editor::SelectToEndOfLine",
{
"stop_at_soft_wraps": true
}
],
// "ctrl-shift-e": [
// "editor::SelectToEndOfLine",
// {
// "stop_at_soft_wraps": true
// }
// ],
// "alt-v": [
// "editor::MovePageUp",
// {
// "center_cursor": true
// }
// ],
"ctrl-alt-space": "editor::ShowCharacterPalette",
"ctrl-shift-e": [
"editor::SelectToEndOfLine",
{
"stop_at_soft_wraps": true
}
],
"ctrl-;": "editor::ToggleLineNumbers",
"ctrl-k ctrl-r": "editor::RevertSelectedHunks",
"ctrl-alt-g b": "editor::ToggleGitBlame"
"ctrl-alt-z": "editor::RevertSelectedHunks"
}
},
{
@@ -147,8 +126,8 @@
"bindings": {
"enter": "editor::Newline",
"shift-enter": "editor::Newline",
"ctrl-shift-enter": "editor::NewlineBelow",
"ctrl-enter": "editor::NewlineAbove",
"ctrl-shift-enter": "editor::NewlineAbove",
"ctrl-enter": "editor::NewlineBelow",
"alt-z": "editor::ToggleSoftWrap",
"ctrl-f": [
"buffer_search::Deploy",
@@ -156,27 +135,21 @@
"focus": true
}
],
// "cmd-e": [
// "buffer_search::Deploy",
// {
// "focus": false
// }
// ],
"ctrl->": "assistant::QuoteSelection"
}
},
{
"context": "Editor && mode == full && inline_completion",
"context": "Editor && mode == full && copilot_suggestion",
"bindings": {
"alt-]": "editor::NextInlineCompletion",
"alt-[": "editor::PreviousInlineCompletion",
"alt-right": "editor::AcceptPartialInlineCompletion"
"alt-]": "copilot::NextSuggestion",
"alt-[": "copilot::PreviousSuggestion",
"alt-right": "editor::AcceptPartialCopilotSuggestion"
}
},
{
"context": "Editor && !inline_completion",
"context": "Editor && !copilot_suggestion",
"bindings": {
"alt-\\": "editor::ShowInlineCompletion"
"alt-\\": "copilot::Suggest"
}
},
{
@@ -190,8 +163,8 @@
{
"context": "AssistantPanel",
"bindings": {
"ctrl-g": "search::SelectNextMatch",
"ctrl-shift-g": "search::SelectPrevMatch"
"f3": "search::SelectNextMatch",
"shift-f3": "search::SelectPrevMatch"
}
},
{
@@ -235,8 +208,9 @@
"escape": "project_search::ToggleFocus",
"alt-tab": "search::CycleMode",
"ctrl-shift-h": "search::ToggleReplace",
"alt-ctrl-g": "search::ActivateRegexMode",
"alt-ctrl-x": "search::ActivateTextMode"
"ctrl-alt-g": "search::ActivateRegexMode",
"ctrl-alt-s": "search::ActivateSemanticMode",
"ctrl-alt-x": "search::ActivateTextMode"
}
},
{
@@ -250,7 +224,7 @@
"context": "ProjectSearchBar && in_replace",
"bindings": {
"enter": "search::ReplaceNext",
"ctrl-alt-enter": "search::ReplaceAll"
"ctrl-enter": "search::ReplaceAll"
}
},
{
@@ -259,31 +233,35 @@
"escape": "project_search::ToggleFocus",
"alt-tab": "search::CycleMode",
"ctrl-shift-h": "search::ToggleReplace",
"alt-ctrl-g": "search::ActivateRegexMode",
"alt-ctrl-x": "search::ActivateTextMode"
"ctrl-alt-g": "search::ActivateRegexMode",
"ctrl-alt-s": "search::ActivateSemanticMode",
"ctrl-alt-x": "search::ActivateTextMode"
}
},
{
"context": "Pane",
"bindings": {
"ctrl-pageup": "pane::ActivatePrevItem",
"ctrl-pagedown": "pane::ActivateNextItem",
"ctrl-{": "pane::ActivatePrevItem",
"ctrl-}": "pane::ActivateNextItem",
"ctrl-alt-left": "pane::ActivatePrevItem",
"ctrl-alt-right": "pane::ActivateNextItem",
"ctrl-w": "pane::CloseActiveItem",
"alt-ctrl-t": "pane::CloseInactiveItems",
"alt-ctrl-shift-w": "workspace::CloseInactiveTabsAndPanes",
"ctrl-alt-t": "pane::CloseInactiveItems",
"ctrl-alt-shift-w": "workspace::CloseInactiveTabsAndPanes",
"ctrl-k u": "pane::CloseCleanItems",
"ctrl-k w": "pane::CloseAllItems",
"ctrl-shift-f": "project_search::ToggleFocus",
"ctrl-alt-g": "search::SelectNextMatch",
"ctrl-alt-shift-g": "search::SelectPrevMatch",
"ctrl-alt-shift-h": "search::ToggleReplace",
"ctrl-k ctrl-w": "pane::CloseAllItems",
"ctrl-f": "project_search::ToggleFocus",
"f3": "search::SelectNextMatch",
"shift-f3": "search::SelectPrevMatch",
"ctrl-shift-h": "search::ToggleReplace",
"alt-enter": "search::SelectAllMatches",
"alt-c": "search::ToggleCaseSensitive",
"alt-w": "search::ToggleWholeWord",
"alt-r": "search::CycleMode",
"alt-ctrl-f": "project_search::ToggleFilters",
"ctrl-alt-shift-r": "search::ActivateRegexMode",
"ctrl-alt-shift-x": "search::ActivateTextMode"
"ctrl-alt-c": "search::ToggleCaseSensitive",
"ctrl-alt-w": "search::ToggleWholeWord",
"alt-tab": "search::CycleMode",
"ctrl-alt-f": "project_search::ToggleFilters",
"ctrl-alt-g": "search::ActivateRegexMode",
"ctrl-alt-s": "search::ActivateSemanticMode",
"ctrl-alt-x": "search::ActivateTextMode"
}
},
// Bindings from VS Code
@@ -292,22 +270,8 @@
"bindings": {
"ctrl-[": "editor::Outdent",
"ctrl-]": "editor::Indent",
"shift-alt-up": "editor::AddSelectionAbove",
"shift-alt-down": "editor::AddSelectionBelow",
"ctrl-shift-k": "editor::DeleteLine",
"alt-up": "editor::MoveLineUp",
"alt-down": "editor::MoveLineDown",
"ctrl-alt-shift-up": [
"editor::DuplicateLine",
{
"move_upwards": true
}
],
"ctrl-alt-shift-down": "editor::DuplicateLine",
"ctrl-shift-left": "editor::SelectToPreviousWordStart",
"ctrl-shift-right": "editor::SelectToNextWordEnd",
"ctrl-shift-up": "editor::SelectLargerSyntaxNode", //todo(linux) tmp keybinding
"ctrl-shift-down": "editor::SelectSmallerSyntaxNode", //todo(linux) tmp keybinding
"ctrl-alt-up": "editor::AddSelectionAbove",
"ctrl-alt-down": "editor::AddSelectionBelow",
"ctrl-d": [
"editor::SelectNext",
{
@@ -340,6 +304,8 @@
"advance_downwards": false
}
],
"alt-up": "editor::SelectLargerSyntaxNode",
"alt-down": "editor::SelectSmallerSyntaxNode",
"ctrl-u": "editor::UndoSelection",
"ctrl-shift-u": "editor::RedoSelection",
"f8": "editor::GoToDiagnostic",
@@ -348,16 +314,15 @@
"f12": "editor::GoToDefinition",
"alt-f12": "editor::GoToDefinitionSplit",
"ctrl-f12": "editor::GoToTypeDefinition",
"shift-f12": "editor::GoToImplementation",
"alt-ctrl-f12": "editor::GoToTypeDefinitionSplit",
"ctrl-alt-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences",
"ctrl-m": "editor::MoveToEnclosingBracket",
"ctrl-shift-[": "editor::Fold",
"ctrl-shift-]": "editor::UnfoldLines",
"ctrl-alt-[": "editor::Fold",
"ctrl-alt-]": "editor::UnfoldLines",
"ctrl-space": "editor::ShowCompletions",
"ctrl-.": "editor::ToggleCodeActions",
"alt-ctrl-r": "editor::RevealInFinder",
"ctrl-alt-shift-c": "editor::DisplayCursorNames"
"ctrl-alt-r": "editor::RevealInFinder",
"ctrl-alt-c": "editor::DisplayCursorNames"
}
},
{
@@ -370,18 +335,18 @@
{
"context": "Pane",
"bindings": {
"alt-1": ["pane::ActivateItem", 0],
"alt-2": ["pane::ActivateItem", 1],
"alt-3": ["pane::ActivateItem", 2],
"alt-4": ["pane::ActivateItem", 3],
"alt-5": ["pane::ActivateItem", 4],
"alt-6": ["pane::ActivateItem", 5],
"alt-7": ["pane::ActivateItem", 6],
"alt-8": ["pane::ActivateItem", 7],
"alt-9": ["pane::ActivateItem", 8],
"alt-0": "pane::ActivateLastItem",
"ctrl-alt--": "pane::GoBack",
"ctrl-alt-_": "pane::GoForward",
"ctrl-1": ["pane::ActivateItem", 0],
"ctrl-2": ["pane::ActivateItem", 1],
"ctrl-3": ["pane::ActivateItem", 2],
"ctrl-4": ["pane::ActivateItem", 3],
"ctrl-5": ["pane::ActivateItem", 4],
"ctrl-6": ["pane::ActivateItem", 5],
"ctrl-7": ["pane::ActivateItem", 6],
"ctrl-8": ["pane::ActivateItem", 7],
"ctrl-9": ["pane::ActivateItem", 8],
"ctrl-0": "pane::ActivateLastItem",
"ctrl--": "pane::GoBack",
"ctrl-_": "pane::GoForward",
"ctrl-shift-t": "pane::ReopenClosedItem",
"ctrl-shift-f": "project_search::ToggleFocus"
}
@@ -396,8 +361,8 @@
// "create_new_window": true
// }
// ]
"alt-ctrl-o": "projects::OpenRecent",
"alt-ctrl-shift-b": "branches::OpenRecent",
"ctrl-alt-o": "projects::OpenRecent",
"ctrl-alt-b": "branches::OpenRecent",
"ctrl-~": "workspace::NewTerminal",
"ctrl-s": "workspace::Save",
"ctrl-k s": "workspace::SaveWithoutFormat",
@@ -405,27 +370,24 @@
"ctrl-n": "workspace::NewFile",
"ctrl-shift-n": "workspace::NewWindow",
"ctrl-`": "terminal_panel::ToggleFocus",
"alt-1": ["workspace::ActivatePane", 0],
"alt-2": ["workspace::ActivatePane", 1],
"alt-3": ["workspace::ActivatePane", 2],
"alt-4": ["workspace::ActivatePane", 3],
"alt-5": ["workspace::ActivatePane", 4],
"alt-6": ["workspace::ActivatePane", 5],
"alt-7": ["workspace::ActivatePane", 6],
"alt-8": ["workspace::ActivatePane", 7],
"alt-9": ["workspace::ActivatePane", 8],
"ctrl-alt-b": "workspace::ToggleLeftDock",
"ctrl-b": "workspace::ToggleRightDock",
"ctrl-1": ["workspace::ActivatePane", 0],
"ctrl-2": ["workspace::ActivatePane", 1],
"ctrl-3": ["workspace::ActivatePane", 2],
"ctrl-4": ["workspace::ActivatePane", 3],
"ctrl-5": ["workspace::ActivatePane", 4],
"ctrl-6": ["workspace::ActivatePane", 5],
"ctrl-7": ["workspace::ActivatePane", 6],
"ctrl-8": ["workspace::ActivatePane", 7],
"ctrl-9": ["workspace::ActivatePane", 8],
"ctrl-b": "workspace::ToggleLeftDock",
"ctrl-r": "workspace::ToggleRightDock",
"ctrl-j": "workspace::ToggleBottomDock",
"ctrl-alt-y": "workspace::CloseAllDocks",
"ctrl-shift-f": "pane::DeploySearch",
"ctrl-k ctrl-s": "zed::OpenKeymap",
"ctrl-k ctrl-t": "theme_selector::Toggle",
"ctrl-shift-t": "project_symbols::Toggle",
"ctrl-k ctrl-s": "zed::OpenKeymap",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
"ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
"ctrl-e": "file_finder::Toggle",
"ctrl-shift-p": "command_palette::Toggle",
"ctrl-shift-m": "diagnostics::Deploy",
"ctrl-shift-e": "project_panel::ToggleFocus",
@@ -446,12 +408,15 @@
}
},
// Bindings from Sublime Text
// todo(linux) make sure these match linux bindings or remove above comment?
{
"context": "Editor",
"bindings": {
"ctrl-shift-k": "editor::DeleteLine",
"ctrl-shift-d": "editor::DuplicateLineDown",
"ctrl-shift-d": "editor::DuplicateLine",
"ctrl-j": "editor::JoinLines",
"ctrl-alt-up": "editor::MoveLineUp",
"ctrl-alt-down": "editor::MoveLineDown",
"ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart",
"ctrl-alt-h": "editor::DeleteToPreviousSubwordStart",
"ctrl-alt-delete": "editor::DeleteToNextSubwordEnd",
@@ -467,6 +432,7 @@
}
},
// Bindings from Atom
// todo(linux) make sure these match linux bindings or remove above comment?
{
"context": "Pane",
"bindings": {
@@ -512,7 +478,7 @@
"bindings": {
"ctrl-alt-shift-f": "workspace::FollowNextCollaborator",
// TODO: Move this to a dock open action
"ctrl-shift-c": "collab_panel::ToggleFocus",
"ctrl-alt-c": "collab_panel::ToggleFocus",
"ctrl-alt-i": "zed::DebugElements",
"ctrl-:": "editor::ToggleInlayHints"
}
@@ -539,19 +505,19 @@
"left": "project_panel::CollapseSelectedEntry",
"right": "project_panel::ExpandSelectedEntry",
"ctrl-n": "project_panel::NewFile",
"alt-ctrl-n": "project_panel::NewDirectory",
"ctrl-alt-n": "project_panel::NewDirectory",
"ctrl-x": "project_panel::Cut",
"ctrl-c": "project_panel::Copy",
"ctrl-v": "project_panel::Paste",
"ctrl-alt-c": "project_panel::CopyPath",
"alt-ctrl-shift-c": "project_panel::CopyRelativePath",
"ctrl-alt-shift-c": "project_panel::CopyRelativePath",
"f2": "project_panel::Rename",
"enter": "project_panel::Rename",
"backspace": "project_panel::Delete",
"delete": "project_panel::Delete",
"ctrl-backspace": ["project_panel::Delete", { "skip_prompt": true }],
"ctrl-delete": ["project_panel::Delete", { "skip_prompt": true }],
"alt-ctrl-r": "project_panel::RevealInFinder",
"ctrl-alt-r": "project_panel::RevealInFinder",
"alt-shift-f": "project_panel::NewSearchInDirectory"
}
},
@@ -592,35 +558,29 @@
"escape": "chat_panel::CloseReplyPreview"
}
},
{
"context": "FileFinder",
"bindings": { "ctrl-shift-p": "file_finder::SelectPrev" }
},
{
"context": "TabSwitcher",
"bindings": {
"ctrl-shift-tab": "menu::SelectPrev",
"ctrl-backspace": "tab_switcher::CloseSelectedItem"
}
},
{
"context": "Terminal",
"bindings": {
"ctrl-alt-space": "terminal::ShowCharacterPalette",
"shift-ctrl-c": "terminal::Copy",
"shift-ctrl-v": "terminal::Paste",
"ctrl-shift-c": "terminal::Copy",
"ctrl-shift-v": "terminal::Paste",
"ctrl-k": "terminal::Clear",
// Some nice conveniences
"ctrl-backspace": ["terminal::SendText", "\u0015"],
"ctrl-right": ["terminal::SendText", "\u0005"],
"ctrl-left": ["terminal::SendText", "\u0001"],
// Terminal.app compatibility
"alt-left": ["terminal::SendText", "\u001bb"],
"alt-right": ["terminal::SendText", "\u001bf"],
// There are conflicting bindings for these keys in the global context.
// these bindings override them, remove at your own risk:
"up": ["terminal::SendKeystroke", "up"],
"pageup": ["terminal::SendKeystroke", "pageup"],
"down": ["terminal::SendKeystroke", "down"],
"pagedown": ["terminal::SendKeystroke", "pagedown"],
"escape": ["terminal::SendKeystroke", "escape"],
"enter": ["terminal::SendKeystroke", "enter"],
"ctrl-c": ["terminal::SendKeystroke", "ctrl-c"],
// Some nice conveniences
"ctrl-backspace": ["terminal::SendText", "\u0015"],
"ctrl-right": ["terminal::SendText", "\u0005"],
"ctrl-left": ["terminal::SendText", "\u0001"]
"ctrl-c": ["terminal::SendKeystroke", "ctrl-c"]
}
}
]

View File

@@ -13,15 +13,11 @@
"cmd-up": "menu::SelectFirst",
"cmd-down": "menu::SelectLast",
"enter": "menu::Confirm",
"ctrl-enter": "menu::SecondaryConfirm",
"ctrl-enter": "menu::ShowContextMenu",
"cmd-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"cmd-escape": "menu::Cancel",
"ctrl-escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",
"shift-enter": "picker::UseSelectedQuery",
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
"cmd-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
"shift-enter": "menu::UseSelectedQuery",
"cmd-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
"cmd-o": "workspace::Open",
@@ -158,8 +154,7 @@
],
"ctrl-cmd-space": "editor::ShowCharacterPalette",
"cmd-;": "editor::ToggleLineNumbers",
"cmd-alt-z": "editor::RevertSelectedHunks",
"cmd-alt-g b": "editor::ToggleGitBlame"
"cmd-alt-z": "editor::RevertSelectedHunks"
}
},
{
@@ -186,17 +181,17 @@
}
},
{
"context": "Editor && mode == full && inline_completion",
"context": "Editor && mode == full && copilot_suggestion",
"bindings": {
"alt-]": "editor::NextInlineCompletion",
"alt-[": "editor::PreviousInlineCompletion",
"alt-right": "editor::AcceptPartialInlineCompletion"
"alt-]": "copilot::NextSuggestion",
"alt-[": "copilot::PreviousSuggestion",
"alt-right": "editor::AcceptPartialCopilotSuggestion"
}
},
{
"context": "Editor && !inline_completion",
"context": "Editor && !copilot_suggestion",
"bindings": {
"alt-\\": "editor::ShowInlineCompletion"
"alt-\\": "copilot::Suggest"
}
},
{
@@ -256,6 +251,7 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -280,6 +276,7 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -305,6 +302,7 @@
"alt-tab": "search::CycleMode",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -321,8 +319,13 @@
"cmd-shift-k": "editor::DeleteLine",
"alt-up": "editor::MoveLineUp",
"alt-down": "editor::MoveLineDown",
"alt-shift-up": "editor::DuplicateLineUp",
"alt-shift-down": "editor::DuplicateLineDown",
"alt-shift-up": [
"editor::DuplicateLine",
{
"move_upwards": true
}
],
"alt-shift-down": "editor::DuplicateLine",
"ctrl-shift-right": "editor::SelectLargerSyntaxNode",
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode",
"cmd-d": [
@@ -365,7 +368,6 @@
"f12": "editor::GoToDefinition",
"alt-f12": "editor::GoToDefinitionSplit",
"cmd-f12": "editor::GoToTypeDefinition",
"shift-f12": "editor::GoToImplementation",
"alt-cmd-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences",
"ctrl-m": "editor::MoveToEnclosingBracket",
@@ -440,8 +442,6 @@
"cmd-k cmd-t": "theme_selector::Toggle",
"cmd-t": "project_symbols::Toggle",
"cmd-p": "file_finder::Toggle",
"ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
"cmd-shift-p": "command_palette::Toggle",
"cmd-shift-m": "diagnostics::Deploy",
"cmd-shift-e": "project_panel::ToggleFocus",
@@ -601,14 +601,9 @@
}
},
{
"context": "FileFinder",
"bindings": { "cmd-shift-p": "file_finder::SelectPrev" }
},
{
"context": "TabSwitcher",
"context": "ChatPanel > MessageEditor",
"bindings": {
"ctrl-shift-tab": "menu::SelectPrev",
"ctrl-backspace": "tab_switcher::CloseSelectedItem"
"escape": "chat_panel::CloseReplyPreview"
}
},
{

View File

@@ -11,7 +11,7 @@
"ctrl->": "zed::IncreaseBufferFontSize",
"ctrl-<": "zed::DecreaseBufferFontSize",
"ctrl-shift-j": "editor::JoinLines",
"cmd-d": "editor::DuplicateLineDown",
"cmd-d": "editor::DuplicateLine",
"cmd-backspace": "editor::DeleteLine",
"cmd-pagedown": "editor::MovePageDown",
"cmd-pageup": "editor::MovePageUp",

View File

@@ -13,7 +13,7 @@
"cmd-up": "menu::SelectFirst",
"cmd-down": "menu::SelectLast",
"enter": "menu::Confirm",
"ctrl-enter": "menu::SecondaryConfirm",
"ctrl-enter": "menu::ShowContextMenu",
"cmd-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",

View File

@@ -9,7 +9,7 @@
"context": "Editor",
"bindings": {
"cmd-l": "go_to_line::Toggle",
"ctrl-shift-d": "editor::DuplicateLineDown",
"ctrl-shift-d": "editor::DuplicateLine",
"cmd-b": "editor::GoToDefinition",
"cmd-j": "editor::ScrollCursorCenter",
"cmd-enter": "editor::NewlineBelow",

View File

@@ -510,7 +510,7 @@
"ctrl-[": "vim::NormalBefore",
"ctrl-x ctrl-o": "editor::ShowCompletions",
"ctrl-x ctrl-a": "assistant::InlineAssist", // zed specific
"ctrl-x ctrl-c": "editor::ShowInlineCompletion", // zed specific
"ctrl-x ctrl-c": "copilot::Suggest", // zed specific
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
"ctrl-x ctrl-z": "editor::Cancel",
"ctrl-w": "editor::DeleteToPreviousWordStart",
@@ -546,12 +546,6 @@
"escape": "buffer_search::Dismiss"
}
},
{
"context": "EmptyPane || SharedScreen",
"bindings": {
":": "command_palette::Toggle"
}
},
{
// netrw compatibility
"context": "ProjectPanel && not_editing",

View File

@@ -48,8 +48,7 @@
// which gives the same size as all other panes.
"active_pane_magnification": 1.0,
// The key to use for adding multiple cursors
// Currently "alt" or "cmd_or_ctrl" (also aliased as
// "cmd" and "ctrl") are supported.
// Currently "alt" or "cmd" are supported.
"multi_cursor_modifier": "alt",
// Whether to enable vim modes and key bindings
"vim_mode": false,
@@ -93,12 +92,6 @@
// Whether to automatically type closing characters for you. For example,
// when you type (, Zed will automatically add a closing ) at the correct position.
"use_autoclose": true,
// Controls how the editor handles the autoclosed characters.
// When set to `false`(default), skipping over and auto-removing of the closing characters
// happen only for auto-inserted characters.
// Otherwise(when `true`), the closing characters are always skipped over and auto-removed
// no matter how they were inserted.
"always_treat_brackets_as_autoclosed": false,
// Controls whether copilot provides suggestion immediately
// or waits for a `copilot::Toggle`
"show_copilot_suggestions": true,
@@ -244,10 +237,6 @@
"default_width": 380
},
"assistant": {
// Version of this setting.
"version": "1",
// Whether the assistant is enabled.
"enabled": true,
// Whether to show the assistant panel button in the status bar.
"button": true,
// Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@@ -256,16 +245,28 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// AI provider.
// Deprecated: Please use `provider.api_url` instead.
// The default OpenAI API endpoint to use when starting new conversations.
"openai_api_url": "https://api.openai.com/v1",
// Deprecated: Please use `provider.default_model` instead.
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_open_ai_model": "gpt-4-1106-preview",
"provider": {
"name": "openai",
// The default model to use when starting new conversations. This
"type": "openai",
// The default OpenAI API endpoint to use when starting new conversations.
"api_url": "https://api.openai.com/v1",
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo"
// 2. "gpt-4"
// 3. "gpt-4-turbo-preview"
"default_model": "gpt-4-turbo-preview"
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_model": "gpt-4-1106-preview"
}
},
// Whether the screen sharing icon is shown in the os status bar.
@@ -504,6 +505,10 @@
// Existing terminals will not pick up this change until they are recreated.
// "max_scroll_history_lines": 10000,
},
// Difference settings for semantic_index
"semantic_index": {
"enabled": true
},
// Settings specific to our elixir integration
"elixir": {
// Change the LSP zed uses for elixir.
@@ -561,9 +566,6 @@
"source.organizeImports": true
}
},
"Make": {
"hard_tabs": true
},
"Markdown": {
"tab_size": 2,
"soft_wrap": "preferred_line_length"
@@ -591,13 +593,10 @@
},
"OCaml Interface": {
"tab_size": 2
},
"Prisma": {
"tab_size": 2
}
},
// Zed's Prettier integration settings.
// If Prettier is enabled, Zed will use this for its Prettier instance for any applicable file, if
// If Prettier is enabled, Zed will use this its Prettier instance for any applicable file, if
// project has no other Prettier installed.
"prettier": {
// Use regular Prettier json configuration:

View File

@@ -111,7 +111,7 @@
"hint": "#618399ff",
"hint.background": "#12231fff",
"hint.border": "#183934ff",
"ignored": "#6b6b73ff",
"ignored": "#aca8aeff",
"ignored.background": "#262933ff",
"ignored.border": "#2b2f38ff",
"info": "#10a793ff",

View File

@@ -111,7 +111,7 @@
"hint": "#706897ff",
"hint.background": "#161a35ff",
"hint.border": "#222953ff",
"ignored": "#756f7eff",
"ignored": "#898591ff",
"ignored.background": "#3a353fff",
"ignored.border": "#56505eff",
"info": "#566ddaff",
@@ -495,7 +495,7 @@
"hint": "#776d9dff",
"hint.background": "#e1e0f9ff",
"hint.border": "#c8c7f2ff",
"ignored": "#6e6876ff",
"ignored": "#5a5462ff",
"ignored.background": "#bfbcc5ff",
"ignored.border": "#8f8b96ff",
"info": "#586cdaff",
@@ -879,7 +879,7 @@
"hint": "#b17272ff",
"hint.background": "#171e38ff",
"hint.border": "#262f56ff",
"ignored": "#8f8b77ff",
"ignored": "#a4a08bff",
"ignored.background": "#45433bff",
"ignored.border": "#6c695cff",
"info": "#6684e0ff",
@@ -1263,7 +1263,7 @@
"hint": "#b37979ff",
"hint.background": "#e3e5faff",
"hint.border": "#cdd1f5ff",
"ignored": "#878471ff",
"ignored": "#706d5fff",
"ignored.background": "#cecab4ff",
"ignored.border": "#a8a48eff",
"info": "#6684dfff",
@@ -1647,7 +1647,7 @@
"hint": "#6f815aff",
"hint.background": "#142319ff",
"hint.border": "#1c3927ff",
"ignored": "#7d7c6aff",
"ignored": "#91907fff",
"ignored.background": "#424136ff",
"ignored.border": "#5d5c4cff",
"info": "#36a165ff",
@@ -2031,7 +2031,7 @@
"hint": "#758961ff",
"hint.background": "#d9ecdfff",
"hint.border": "#bbddc6ff",
"ignored": "#767463ff",
"ignored": "#61604fff",
"ignored.background": "#c5c4b9ff",
"ignored.border": "#969585ff",
"info": "#37a165ff",
@@ -2415,7 +2415,7 @@
"hint": "#a77087ff",
"hint.background": "#0f1c3dff",
"hint.border": "#182d5bff",
"ignored": "#8e8683ff",
"ignored": "#a79f9dff",
"ignored.background": "#443c39ff",
"ignored.border": "#665f5cff",
"info": "#407ee6ff",
@@ -2799,7 +2799,7 @@
"hint": "#a67287ff",
"hint.background": "#dfe3fbff",
"hint.border": "#c6cef7ff",
"ignored": "#837b78ff",
"ignored": "#6a6360ff",
"ignored.background": "#ccc7c5ff",
"ignored.border": "#aaa3a1ff",
"info": "#407ee6ff",
@@ -3183,7 +3183,7 @@
"hint": "#8d70a8ff",
"hint.background": "#0d1a43ff",
"hint.border": "#192961ff",
"ignored": "#908190ff",
"ignored": "#a899a8ff",
"ignored.background": "#433a43ff",
"ignored.border": "#675b67ff",
"info": "#5169ebff",
@@ -3567,7 +3567,7 @@
"hint": "#8c70a6ff",
"hint.background": "#e2dffcff",
"hint.border": "#cac7faff",
"ignored": "#857785ff",
"ignored": "#6b5e6bff",
"ignored.background": "#c6b8c6ff",
"ignored.border": "#ad9dadff",
"info": "#5169ebff",
@@ -3951,7 +3951,7 @@
"hint": "#52809aff",
"hint.background": "#121c24ff",
"hint.border": "#1a2f3cff",
"ignored": "#688c9dff",
"ignored": "#7c9fb3ff",
"ignored.background": "#33444dff",
"ignored.border": "#4f6a78ff",
"info": "#267eadff",
@@ -4335,7 +4335,7 @@
"hint": "#5a87a0ff",
"hint.background": "#d8e4eeff",
"hint.border": "#b9cee0ff",
"ignored": "#628496ff",
"ignored": "#526f7dff",
"ignored.background": "#a6cadcff",
"ignored.border": "#80a4b6ff",
"info": "#267eadff",
@@ -4719,7 +4719,7 @@
"hint": "#8a647aff",
"hint.background": "#1c1b29ff",
"hint.border": "#2c2b45ff",
"ignored": "#756e6eff",
"ignored": "#898383ff",
"ignored.background": "#3b3535ff",
"ignored.border": "#564e4eff",
"info": "#7272caff",
@@ -5103,7 +5103,7 @@
"hint": "#91697fff",
"hint.background": "#e4e1f5ff",
"hint.border": "#cecaecff",
"ignored": "#6e6666ff",
"ignored": "#5a5252ff",
"ignored.background": "#c1bbbbff",
"ignored.border": "#8e8989ff",
"info": "#7272caff",
@@ -5487,7 +5487,7 @@
"hint": "#607e76ff",
"hint.background": "#151e20ff",
"hint.border": "#1f3233ff",
"ignored": "#6f7e74ff",
"ignored": "#859188ff",
"ignored.background": "#353f39ff",
"ignored.border": "#505e55ff",
"info": "#468b8fff",
@@ -5871,7 +5871,7 @@
"hint": "#66847cff",
"hint.background": "#dae7e8ff",
"hint.border": "#bed4d6ff",
"ignored": "#68766dff",
"ignored": "#546259ff",
"ignored.background": "#bcc5bfff",
"ignored.border": "#8b968eff",
"info": "#488b90ff",
@@ -6255,7 +6255,7 @@
"hint": "#008b9fff",
"hint.background": "#051949ff",
"hint.border": "#102667ff",
"ignored": "#778f77ff",
"ignored": "#8ba48bff",
"ignored.background": "#3b453bff",
"ignored.border": "#5c6c5cff",
"info": "#3e62f4ff",
@@ -6639,7 +6639,7 @@
"hint": "#008fa1ff",
"hint.background": "#e1ddfeff",
"hint.border": "#c9c4fdff",
"ignored": "#718771ff",
"ignored": "#5f705fff",
"ignored.background": "#b4ceb4ff",
"ignored.border": "#8ea88eff",
"info": "#3e61f4ff",
@@ -7023,7 +7023,7 @@
"hint": "#6c81a5ff",
"hint.background": "#161f2bff",
"hint.border": "#203348ff",
"ignored": "#7e849eff",
"ignored": "#959bb2ff",
"ignored.background": "#3e4769ff",
"ignored.border": "#5b6385ff",
"info": "#3e8ed0ff",
@@ -7407,7 +7407,7 @@
"hint": "#7087b2ff",
"hint.background": "#dde7f6ff",
"hint.border": "#c2d5efff",
"ignored": "#767d9aff",
"ignored": "#5f6789ff",
"ignored.background": "#c1c5d8ff",
"ignored.border": "#9a9fb6ff",
"info": "#3e8fd0ff",

View File

@@ -111,7 +111,7 @@
"hint": "#628b80ff",
"hint.background": "#0d2f4eff",
"hint.border": "#1b4a6eff",
"ignored": "#696a6aff",
"ignored": "#8a8986ff",
"ignored.background": "#313337ff",
"ignored.border": "#3f4043ff",
"info": "#5ac1feff",
@@ -480,7 +480,7 @@
"hint": "#8ca7c2ff",
"hint.background": "#deebfaff",
"hint.border": "#c4daf6ff",
"ignored": "#a9acaeff",
"ignored": "#8b8e92ff",
"ignored.background": "#dcdddeff",
"ignored.border": "#cfd1d2ff",
"info": "#3b9ee5ff",
@@ -849,7 +849,7 @@
"hint": "#7399a3ff",
"hint.background": "#123950ff",
"hint.border": "#24556fff",
"ignored": "#7b7d7fff",
"ignored": "#9a9a98ff",
"ignored.background": "#464a52ff",
"ignored.border": "#53565dff",
"info": "#72cffeff",

View File

@@ -111,7 +111,7 @@
"hint": "#8c957dff",
"hint.background": "#1e2321ff",
"hint.border": "#303a36ff",
"ignored": "#998b78ff",
"ignored": "#c5b597ff",
"ignored.background": "#4c4642ff",
"ignored.border": "#5b534dff",
"info": "#83a598ff",
@@ -485,7 +485,7 @@
"hint": "#6a695bff",
"hint.background": "#1e2321ff",
"hint.border": "#303a36ff",
"ignored": "#998b78ff",
"ignored": "#c5b597ff",
"ignored.background": "#4c4642ff",
"ignored.border": "#5b534dff",
"info": "#83a598ff",
@@ -859,7 +859,7 @@
"hint": "#8c957dff",
"hint.background": "#1e2321ff",
"hint.border": "#303a36ff",
"ignored": "#998b78ff",
"ignored": "#c5b597ff",
"ignored.background": "#4c4642ff",
"ignored.border": "#5b534dff",
"info": "#83a598ff",
@@ -1233,7 +1233,7 @@
"hint": "#677562ff",
"hint.background": "#d2dee2ff",
"hint.border": "#adc5ccff",
"ignored": "#897b6eff",
"ignored": "#5f5650ff",
"ignored.background": "#d9c8a4ff",
"ignored.border": "#c8b899ff",
"info": "#0b6678ff",
@@ -1607,7 +1607,7 @@
"hint": "#677562ff",
"hint.background": "#d2dee2ff",
"hint.border": "#adc5ccff",
"ignored": "#897b6eff",
"ignored": "#5f5650ff",
"ignored.background": "#d9c8a4ff",
"ignored.border": "#c8b899ff",
"info": "#0b6678ff",
@@ -1981,7 +1981,7 @@
"hint": "#677562ff",
"hint.background": "#d2dee2ff",
"hint.border": "#adc5ccff",
"ignored": "#897b6eff",
"ignored": "#5f5650ff",
"ignored.background": "#d9c8a4ff",
"ignored.border": "#c8b899ff",
"info": "#0b6678ff",

View File

@@ -111,7 +111,7 @@
"hint": "#5a6f89ff",
"hint.background": "#18243dff",
"hint.border": "#293b5bff",
"ignored": "#555a63ff",
"ignored": "#838994ff",
"ignored.background": "#3b414dff",
"ignored.border": "#464b57ff",
"info": "#74ade8ff",
@@ -485,7 +485,7 @@
"hint": "#9294beff",
"hint.background": "#e2e2faff",
"hint.border": "#cbcdf6ff",
"ignored": "#a1a1a3ff",
"ignored": "#7e8087ff",
"ignored.background": "#dcdcddff",
"ignored.border": "#c9c9caff",
"info": "#5c78e2ff",

View File

@@ -111,7 +111,7 @@
"hint": "#5e768cff",
"hint.background": "#2f3639ff",
"hint.border": "#435255ff",
"ignored": "#2f2b43ff",
"ignored": "#74708dff",
"ignored.background": "#292738ff",
"ignored.border": "#423f55ff",
"info": "#9bced6ff",
@@ -490,7 +490,7 @@
"hint": "#7a92aaff",
"hint.background": "#dde9ebff",
"hint.border": "#c3d7dbff",
"ignored": "#938fa3ff",
"ignored": "#706c8cff",
"ignored.background": "#dcd8d8ff",
"ignored.border": "#dcd6d5ff",
"info": "#57949fff",
@@ -869,7 +869,7 @@
"hint": "#728aa2ff",
"hint.background": "#2f3639ff",
"hint.border": "#435255ff",
"ignored": "#605d7aff",
"ignored": "#85819eff",
"ignored.background": "#38354eff",
"ignored.border": "#504c68ff",
"info": "#9bced6ff",

View File

@@ -111,7 +111,7 @@
"hint": "#727d68ff",
"hint.background": "#171e1eff",
"hint.border": "#223131ff",
"ignored": "#827568ff",
"ignored": "#a69782ff",
"ignored.background": "#333944ff",
"ignored.border": "#3d4350ff",
"info": "#518b8bff",

View File

@@ -111,7 +111,7 @@
"hint": "#4f8297ff",
"hint.background": "#141f2cff",
"hint.border": "#1b3149ff",
"ignored": "#6f8389ff",
"ignored": "#93a1a1ff",
"ignored.background": "#073743ff",
"ignored.border": "#2b4e58ff",
"info": "#278ad1ff",
@@ -480,7 +480,7 @@
"hint": "#5789a3ff",
"hint.background": "#dbe6f6ff",
"hint.border": "#bfd3efff",
"ignored": "#6a7f86ff",
"ignored": "#34555eff",
"ignored.background": "#cfd0c4ff",
"ignored.border": "#9faaa8ff",
"info": "#288bd1ff",

View File

@@ -111,7 +111,7 @@
"hint": "#246e61ff",
"hint.background": "#0e2242ff",
"hint.border": "#193760ff",
"ignored": "#4c4735ff",
"ignored": "#736e55ff",
"ignored.background": "#2a261cff",
"ignored.border": "#302c21ff",
"info": "#499befff",

View File

@@ -16,7 +16,6 @@ doctest = false
anyhow.workspace = true
auto_update.workspace = true
editor.workspace = true
extension.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true

View File

@@ -1,6 +1,5 @@
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage};
use editor::Editor;
use extension::ExtensionStore;
use futures::StreamExt;
use gpui::{
actions, svg, AppContext, CursorStyle, EventEmitter, InteractiveElement as _, Model,
@@ -206,7 +205,7 @@ impl ActivityIndicator {
}
LanguageServerBinaryStatus::Downloading => downloading.push(status.name.0.as_ref()),
LanguageServerBinaryStatus::Failed { .. } => failed.push(status.name.0.as_ref()),
LanguageServerBinaryStatus::None => {}
LanguageServerBinaryStatus::Downloaded | LanguageServerBinaryStatus::Cached => {}
}
}
@@ -289,18 +288,6 @@ impl ActivityIndicator {
};
}
if let Some(extension_store) =
ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx))
{
if let Some(extension_id) = extension_store.outstanding_operations().keys().next() {
return Content {
icon: Some(DOWNLOAD_ICON),
message: format!("Updating {extension_id} extension…"),
on_click: None,
};
}
}
Default::default()
}
}

41
crates/ai/Cargo.toml Normal file
View File

@@ -0,0 +1,41 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/ai.rs"
doctest = false
[features]
test-support = []
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
futures.workspace = true
gpui.workspace = true
isahc.workspace = true
language.workspace = true
log.workspace = true
matrixmultiply = "0.3.7"
ordered-float.workspace = true
parking_lot.workspace = true
parse_duration = "2.1.1"
postage.workspace = true
rand.workspace = true
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tiktoken-rs.workspace = true
util.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

8
crates/ai/src/ai.rs Normal file
View File

@@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

23
crates/ai/src/auth.rs Normal file
View File

@@ -0,0 +1,23 @@
use futures::future::BoxFuture;
use gpui::AppContext;
#[derive(Clone, Debug)]
pub enum ProviderCredential {
Credentials { api_key: String },
NoCredentials,
NotNeeded,
}
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
#[must_use]
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
#[must_use]
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()>;
#[must_use]
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
}

View File

@@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

121
crates/ai/src/embedding.rs Normal file
View File

@@ -0,0 +1,121 @@
use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
// This is needed for semantic index functionality
// Unfortunately it has to live wherever the "Embedding" struct is created.
// Keeping this in here though, introduces a 'rusqlite' dependency into AI
// which is less than ideal
impl FromSql for Embedding {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob()?;
let embedding =
bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
Ok(Embedding(embedding))
}
}
impl ToSql for Embedding {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
let bytes = bincode::serialize(&self.0)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
}
}
impl From<Vec<f32>> for Embedding {
fn from(value: Vec<f32>) -> Self {
Embedding(value)
}
}
impl Embedding {
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
let len = self.0.len();
assert_eq!(len, other.0.len());
let mut result = 0.0;
unsafe {
matrixmultiply::sgemm(
1,
len,
1,
1.0,
self.0.as_ptr(),
len as isize,
1,
other.0.as_ptr(),
1,
len as isize,
0.0,
&mut result as *mut f32,
1,
1,
);
}
OrderedFloat(result)
}
}
#[async_trait]
pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
#[cfg(test)]
mod tests {
use super::*;
use rand::prelude::*;
#[gpui::test]
fn test_similarity(mut rng: StdRng) {
assert_eq!(
Embedding::from(vec![1., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
0.
);
assert_eq!(
Embedding::from(vec![2., 0., 0., 0., 0.])
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
6.
);
for _ in 0..100 {
let size = 1536;
let mut a = vec![0.; size];
let mut b = vec![0.; size];
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
*a = rng.gen();
*b = rng.gen();
}
let a = Embedding::from(a);
let b = Embedding::from(b);
assert_eq!(
round_to_decimals(a.similarity(&b), 1),
round_to_decimals(reference_dot(&a.0, &b.0), 1)
);
}
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
let factor = 10.0_f32.powi(decimal_places);
(n * factor).round() / factor
}
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
}
}
}

16
crates/ai/src/models.rs Normal file
View File

@@ -0,0 +1,16 @@
pub enum TruncationDirection {
Start,
End,
}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}

View File

@@ -0,0 +1,337 @@
use std::cmp::Reverse;
use std::ops::Range;
use std::sync::Arc;
use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
Code,
}
// TODO: Set this up to manage for defaults well
pub struct PromptArguments {
pub model: Arc<dyn LanguageModel>,
pub user_prompt: Option<String>,
pub language_name: Option<String>,
pub project_name: Option<String>,
pub snippets: Vec<PromptCodeSnippet>,
pub reserved_tokens: usize,
pub buffer: Option<BufferSnapshot>,
pub selected_range: Option<Range<usize>>,
}
impl PromptArguments {
pub(crate) fn get_file_type(&self) -> PromptFileType {
if self
.language_name
.as_ref()
.map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
.unwrap_or(true)
{
PromptFileType::Code
} else {
PromptFileType::Text
}
}
}
pub trait PromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)>;
}
#[repr(i8)]
#[derive(PartialEq, Eq)]
pub enum PromptPriority {
/// Ignores truncation.
Mandatory,
/// Truncates based on priority.
Ordered { order: usize },
}
impl PartialOrd for PromptPriority {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PromptPriority {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
(Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
(Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
}
}
}
pub struct PromptChain {
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
}
impl PromptChain {
pub fn new(
args: PromptArguments,
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
) -> Self {
PromptChain { args, templates }
}
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
// Argsort based on Prompt Priority
let separator = "\n";
let separator_tokens = self.args.model.count_tokens(separator)?;
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
let mut tokens_outstanding = if truncate {
Some(self.args.model.capacity()? - self.args.reserved_tokens)
} else {
None
};
let mut prompts = vec!["".to_string(); sorted_indices.len()];
for idx in sorted_indices {
let (_, template) = &self.templates[idx];
if let Some((template_prompt, prompt_token_count)) =
template.generate(&self.args, tokens_outstanding).log_err()
{
if template_prompt != "" {
prompts[idx] = template_prompt;
if let Some(remaining_tokens) = tokens_outstanding {
let new_tokens = prompt_token_count + separator_tokens;
tokens_outstanding = if remaining_tokens > new_tokens {
Some(remaining_tokens - new_tokens)
} else {
Some(0)
};
}
}
}
}
prompts.retain(|x| x != "");
let full_prompt = prompts.join(separator);
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
anyhow::Ok((prompts.join(separator), total_token_count))
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
use crate::test::FakeLanguageModel;
use super::*;
#[test]
pub fn test_prompt_chain() {
struct TestPromptTemplate {}
impl PromptTemplate for TestPromptTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
struct TestLowPriorityTemplate {}
impl PromptTemplate for TestLowPriorityTemplate {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut content = "This is a low priority test prompt template".to_string();
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
}
anyhow::Ok((content, token_count))
}
}
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(false).unwrap();
assert_eq!(
prompt,
"This is a test prompt template\nThis is a low priority test prompt template"
.to_string()
);
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens: 0,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 2 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(prompt, "This is a test promp".to_string());
assert_eq!(token_count, capacity);
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
project_name: None,
snippets: Vec::new(),
reserved_tokens,
buffer: None,
selected_range: None,
user_prompt: None,
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(
PromptPriority::Mandatory,
Box::new(TestLowPriorityTemplate {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(TestPromptTemplate {}),
),
(
PromptPriority::Ordered { order: 1 },
Box::new(TestLowPriorityTemplate {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, token_count) = chain.generate(true).unwrap();
assert_eq!(
prompt,
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
.to_string()
);
assert_eq!(token_count, capacity - reserved_tokens);
}
}

View File

@@ -0,0 +1,164 @@
use anyhow::anyhow;
use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
use crate::models::TruncationDirection;
use crate::prompts::base::PromptArguments;
use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
fn retrieve_context(
buffer: &BufferSnapshot,
selected_range: &Option<Range<usize>>,
model: Arc<dyn LanguageModel>,
max_token_count: Option<usize>,
) -> anyhow::Result<(String, usize, bool)> {
let mut prompt = String::new();
let mut truncated = false;
if let Some(selected_range) = selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
let start_window = buffer.text_for_range(0..start).collect::<String>();
let mut selected_window = String::new();
if start == end {
write!(selected_window, "<|START|>").unwrap();
} else {
write!(selected_window, "<|START|").unwrap();
}
write!(
selected_window,
"{}",
buffer.text_for_range(start..end).collect::<String>()
)
.unwrap();
if start != end {
write!(selected_window, "|END|>").unwrap();
}
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
if let Some(max_token_count) = max_token_count {
let selected_tokens = model.count_tokens(&selected_window)?;
if selected_tokens > max_token_count {
return Err(anyhow!(
"selected range is greater than model context window, truncation not possible"
));
};
let mut remaining_tokens = max_token_count - selected_tokens;
let start_window_tokens = model.count_tokens(&start_window)?;
let end_window_tokens = model.count_tokens(&end_window)?;
let outside_tokens = start_window_tokens + end_window_tokens;
if outside_tokens > remaining_tokens {
let (start_goal_tokens, end_goal_tokens) =
if start_window_tokens < end_window_tokens {
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
remaining_tokens -= start_goal_tokens;
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
(start_goal_tokens, end_goal_tokens)
} else {
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
remaining_tokens -= end_goal_tokens;
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
(start_goal_tokens, end_goal_tokens)
};
let truncated_start_window =
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
let truncated_end_window =
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
)
.unwrap();
truncated = true;
} else {
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
}
} else {
// If we dont have a selected range, include entire file.
writeln!(prompt, "{}", &buffer.text()).unwrap();
// Dumb truncation strategy
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
}
let token_count = model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count, truncated))
}
pub struct FileContext {}
impl PromptTemplate for FileContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
if let Some(buffer) = &args.buffer {
let mut prompt = String::new();
// Add Initial Preamble
// TODO: Do we want to add the path in here?
writeln!(
prompt,
"The file you are currently working on has the following content:"
)
.unwrap();
let language_name = args
.language_name
.clone()
.unwrap_or("".to_string())
.to_lowercase();
let (context, _, truncated) = retrieve_context(
buffer,
&args.selected_range,
args.model.clone(),
max_token_length,
)?;
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
if truncated {
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
}
if let Some(selected_range) = &args.selected_range {
let start = selected_range.start.to_offset(buffer);
let end = selected_range.end.to_offset(buffer);
if start == end {
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
} else {
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args
.model
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
} else {
Err(anyhow!("no buffer provided to retrieve file context from"))
}
}
}

View File

@@ -0,0 +1,99 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
pub fn capitalize(s: &str) -> String {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
}
pub struct GenerateInlineContent {}
impl PromptTemplate for GenerateInlineContent {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let Some(user_prompt) = &args.user_prompt else {
return Err(anyhow!("user prompt not provided"));
};
let file_type = args.get_file_type();
let content_type = match &file_type {
PromptFileType::Code => "code",
PromptFileType::Text => "text",
};
let mut prompt = String::new();
if let Some(selected_range) = &args.selected_range {
if selected_range.start == selected_range.end {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
capitalize(content_type)
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
}
} else {
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}"
)
.unwrap();
}
if let Some(language_name) = &args.language_name {
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
match file_type {
PromptFileType::Code => {
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
}
_ => {}
}
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
prompt = args.model.truncate(
&prompt,
max_tokens,
crate::models::TruncationDirection::End,
)?;
}
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}

View File

@@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;

View File

@@ -0,0 +1,52 @@
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
impl PromptTemplate for EngineerPreamble {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
let mut prompts = Vec::new();
match args.get_file_type() {
PromptFileType::Code => {
prompts.push(format!(
"You are an expert {}engineer.",
args.language_name.clone().unwrap_or("".to_string()) + " "
));
}
PromptFileType::Text => {
prompts.push("You are an expert engineer.".to_string());
}
}
if let Some(project_name) = args.project_name.clone() {
prompts.push(format!(
"You are currently working inside the '{project_name}' project in code editor Zed."
));
}
if let Some(mut remaining_tokens) = max_token_length {
let mut prompt = String::new();
let mut total_count = 0;
for prompt_piece in prompts {
let prompt_token_count =
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
if remaining_tokens > prompt_token_count {
writeln!(prompt, "{prompt_piece}").unwrap();
remaining_tokens -= prompt_token_count;
total_count += prompt_token_count;
}
}
anyhow::Ok((prompt, total_count))
} else {
let prompt = prompts.join("\n");
let token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, token_count))
}
}
}

View File

@@ -0,0 +1,96 @@
use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
use gpui::{AsyncAppContext, Model};
use language::{Anchor, Buffer};
#[derive(Clone)]
pub struct PromptCodeSnippet {
path: Option<PathBuf>,
language_name: Option<String>,
content: String,
}
impl PromptCodeSnippet {
pub fn new(
buffer: Model<Buffer>,
range: Range<Anchor>,
cx: &mut AsyncAppContext,
) -> anyhow::Result<Self> {
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
let language_name = buffer
.language()
.map(|language| language.name().to_string().to_lowercase());
let file_path = buffer.file().map(|file| file.path().to_path_buf());
(content, language_name, file_path)
})?;
anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
})
}
}
impl ToString for PromptCodeSnippet {
fn to_string(&self) -> String {
let path = self
.path
.as_ref()
.map(|path| path.to_string_lossy().to_string())
.unwrap_or("".to_string());
let language_name = self.language_name.clone().unwrap_or("".to_string());
let content = self.content.clone();
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
}
}
pub struct RepositoryContext {}
impl PromptTemplate for RepositoryContext {
fn generate(
&self,
args: &PromptArguments,
max_token_length: Option<usize>,
) -> anyhow::Result<(String, usize)> {
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
let mut prompt = String::new();
let mut remaining_tokens = max_token_length;
let separator_token_length = args.model.count_tokens("\n")?;
for snippet in &args.snippets {
let mut snippet_prompt = template.to_string();
let content = snippet.to_string();
writeln!(snippet_prompt, "{content}").unwrap();
let token_count = args.model.count_tokens(&snippet_prompt)?;
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
if let Some(tokens_left) = remaining_tokens {
if tokens_left >= token_count {
writeln!(prompt, "{snippet_prompt}").unwrap();
remaining_tokens = if tokens_left >= (token_count + separator_token_length)
{
Some(tokens_left - token_count - separator_token_length)
} else {
Some(0)
};
}
} else {
writeln!(prompt, "{snippet_prompt}").unwrap();
}
}
}
let total_token_count = args.model.count_tokens(&prompt)?;
anyhow::Ok((prompt, total_token_count))
}
}

View File

@@ -0,0 +1 @@
pub mod open_ai;

View File

@@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAiLanguageModel;
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

View File

@@ -0,0 +1,421 @@
use std::{
env,
fmt::{self, Display},
io,
sync::Arc,
};
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "User"),
Role::Assistant => write!(f, "Assistant"),
Role::System => write!(f, "System"),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Debug, Default, Serialize)]
pub struct OpenAiRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
}
impl CompletionRequest for OpenAiRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAiUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
pub index: u32,
pub delta: ResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct OpenAiResponseStreamEvent {
pub id: Option<String>,
pub object: String,
pub created: u32,
pub model: String,
pub choices: Vec<ChatChoiceDelta>,
pub usage: Option<OpenAiUsage>,
}
async fn stream_completion(
api_url: String,
kind: OpenAiCompletionProviderKind,
credential: ProviderCredential,
executor: BackgroundExecutor,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
let api_key = match credential {
ProviderCredential::Credentials { api_key } => api_key,
_ => {
return Err(anyhow!("no credentials provider for completion"));
}
};
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
let json_data = request.data()?;
let mut response = Request::post(kind.completions_endpoint_url(&api_url))
.header("Content-Type", "application/json")
.header(auth_header_name, auth_header_value)
.body(json_data)?
.send_async()
.await?;
let status = response.status();
if status == StatusCode::OK {
executor
.spawn(async move {
let mut lines = BufReader::new(response.body_mut()).lines();
fn parse_line(
line: Result<String, io::Error>,
) -> Result<Option<OpenAiResponseStreamEvent>> {
if let Some(data) = line?.strip_prefix("data: ") {
let event = serde_json::from_str(data)?;
Ok(Some(event))
} else {
Ok(None)
}
}
while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() {
let done = event.as_ref().map_or(false, |event| {
event
.choices
.last()
.map_or(false, |choice| choice.finish_reason.is_some())
});
if tx.unbounded_send(event).is_err() {
break;
}
if done {
break;
}
}
}
anyhow::Ok(())
})
.detach();
Ok(rx)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenAiResponse {
error: OpenAiError,
}
#[derive(Deserialize)]
struct OpenAiError {
message: String,
}
match serde_json::from_str::<OpenAiResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to OpenAI API: {}",
response.error.message,
)),
_ => Err(anyhow!(
"Failed to connect to OpenAI API: {} {}",
response.status(),
body,
)),
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
pub enum AzureOpenAiApiVersion {
/// Retiring April 2, 2024.
#[serde(rename = "2023-03-15-preview")]
V2023_03_15Preview,
#[serde(rename = "2023-05-15")]
V2023_05_15,
/// Retiring April 2, 2024.
#[serde(rename = "2023-06-01-preview")]
V2023_06_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-07-01-preview")]
V2023_07_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-08-01-preview")]
V2023_08_01Preview,
/// Retiring April 2, 2024.
#[serde(rename = "2023-09-01-preview")]
V2023_09_01Preview,
#[serde(rename = "2023-12-01-preview")]
V2023_12_01Preview,
#[serde(rename = "2024-02-15-preview")]
V2024_02_15Preview,
}
impl fmt::Display for AzureOpenAiApiVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
Self::V2023_03_15Preview => "2023-03-15-preview",
Self::V2023_05_15 => "2023-05-15",
Self::V2023_06_01Preview => "2023-06-01-preview",
Self::V2023_07_01Preview => "2023-07-01-preview",
Self::V2023_08_01Preview => "2023-08-01-preview",
Self::V2023_09_01Preview => "2023-09-01-preview",
Self::V2023_12_01Preview => "2023-12-01-preview",
Self::V2024_02_15Preview => "2024-02-15-preview",
}
)
}
}
#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
OpenAi,
AzureOpenAi {
deployment_id: String,
api_version: AzureOpenAiApiVersion,
},
}
impl OpenAiCompletionProviderKind {
/// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
fn completions_endpoint_url(&self, api_url: &str) -> String {
match self {
Self::OpenAi => {
// https://platform.openai.com/docs/api-reference/chat/create
format!("{api_url}/chat/completions")
}
Self::AzureOpenAi {
deployment_id,
api_version,
} => {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
}
}
}
/// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
fn auth_header(&self, api_key: String) -> (&'static str, String) {
match self {
Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
Self::AzureOpenAi { .. } => ("Api-Key", api_key),
}
}
}
#[derive(Clone)]
pub struct OpenAiCompletionProvider {
api_url: String,
kind: OpenAiCompletionProviderKind,
model: OpenAiLanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
executor: BackgroundExecutor,
}
impl OpenAiCompletionProvider {
pub async fn new(
api_url: String,
kind: OpenAiCompletionProviderKind,
model_name: String,
executor: BackgroundExecutor,
) -> Self {
let model = executor
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
.await;
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
api_url,
kind,
model,
credential,
executor,
}
}
}
impl CredentialProvider for OpenAiCompletionProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => {
return async move { existing_credential }.boxed()
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
async move { ProviderCredential::Credentials { api_key } }.boxed()
} else {
let credentials = cx.read_credentials(OPEN_AI_API_URL);
async move {
if let Some(Some((_, api_key))) = credentials.await.log_err() {
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
.boxed()
}
}
};
async move {
let retrieved_credential = retrieved_credential.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
.boxed()
}
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()> {
*self.credential.write() = credential.clone();
let credential = credential.clone();
let write_credentials = match credential {
ProviderCredential::Credentials { api_key } => {
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
}
_ => None,
};
async move {
if let Some(write_credentials) = write_credentials {
write_credentials.await.log_err();
}
}
.boxed()
}
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
*self.credential.write() = ProviderCredential::NoCredentials;
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
async move {
delete_credentials.await.log_err();
}
.boxed()
}
}
impl CompletionProvider for OpenAiCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the language model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let api_url = self.api_url.clone();
let kind = self.kind.clone();
let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@@ -0,0 +1,345 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::AsyncReadExt;
use futures::FutureExt;
use gpui::AppContext;
use gpui::BackgroundExecutor;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
use serde_json;
use std::env;
use std::ops::Add;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAiLanguageModel;
use crate::providers::open_ai::OPEN_AI_API_URL;
pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
}
#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
api_url: String,
model: OpenAiLanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: BackgroundExecutor,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
model: &'static str,
input: Vec<&'a str>,
}
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
data: Vec<OpenAiEmbedding>,
usage: OpenAiEmbeddingUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedding {
embedding: Vec<f32>,
index: usize,
object: String,
}
#[derive(Deserialize)]
struct OpenAiEmbeddingUsage {
prompt_tokens: usize,
total_tokens: usize,
}
impl OpenAiEmbeddingProvider {
pub async fn new(
api_url: String,
client: Arc<dyn HttpClient>,
executor: BackgroundExecutor,
) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
// Loading the model is expensive, so ensure this runs off the main thread.
let model = executor
.spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
.await;
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAiEmbeddingProvider {
api_url,
model,
credential,
client,
executor,
rate_limit_count_rx,
rate_limit_count_tx,
}
}
fn get_api_key(&self) -> Result<String> {
match self.credential.read().clone() {
ProviderCredential::Credentials { api_key } => Ok(api_key),
_ => Err(anyhow!("api credentials not provided")),
}
}
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
if let Some(reset_time) = reset_time {
if Instant::now() >= reset_time {
*self.rate_limit_count_tx.lock().borrow_mut() = None
}
}
log::trace!(
"resolving reset time: {:?}",
*self.rate_limit_count_tx.lock().borrow()
);
}
fn update_reset_time(&self, reset_time: Instant) {
let original_time = *self.rate_limit_count_tx.lock().borrow();
let updated_time = if let Some(original_time) = original_time {
if reset_time < original_time {
Some(reset_time)
} else {
Some(original_time)
}
} else {
Some(reset_time)
};
log::trace!("updating rate limit time: {:?}", updated_time);
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
}
async fn send_request(
&self,
api_url: &str,
api_key: &str,
spans: Vec<&str>,
request_timeout: u64,
) -> Result<Response<AsyncBody>> {
let request = Request::post(format!("{api_url}/embeddings"))
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.timeout(Duration::from_secs(request_timeout))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAiEmbeddingRequest {
input: spans.clone(),
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
Ok(self.client.send(request).await?)
}
}
impl CredentialProvider for OpenAiEmbeddingProvider {
fn has_credentials(&self) -> bool {
match *self.credential.read() {
ProviderCredential::Credentials { .. } => true,
_ => false,
}
}
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
let existing_credential = self.credential.read().clone();
let retrieved_credential = match existing_credential {
ProviderCredential::Credentials { .. } => {
return async move { existing_credential }.boxed()
}
_ => {
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
async move { ProviderCredential::Credentials { api_key } }.boxed()
} else {
let credentials = cx.read_credentials(OPEN_AI_API_URL);
async move {
if let Some(Some((_, api_key))) = credentials.await.log_err() {
if let Some(api_key) = String::from_utf8(api_key).log_err() {
ProviderCredential::Credentials { api_key }
} else {
ProviderCredential::NoCredentials
}
} else {
ProviderCredential::NoCredentials
}
}
.boxed()
}
}
};
async move {
let retrieved_credential = retrieved_credential.await;
*self.credential.write() = retrieved_credential.clone();
retrieved_credential
}
.boxed()
}
fn save_credentials(
&self,
cx: &mut AppContext,
credential: ProviderCredential,
) -> BoxFuture<()> {
*self.credential.write() = credential.clone();
let credential = credential.clone();
let write_credentials = match credential {
ProviderCredential::Credentials { api_key } => {
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
}
_ => None,
};
async move {
if let Some(write_credentials) = write_credentials {
write_credentials.await.log_err();
}
}
.boxed()
}
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
*self.credential.write() = ProviderCredential::NoCredentials;
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
async move {
delete_credentials.await.log_err();
}
.boxed()
}
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
fn max_tokens_per_batch(&self) -> usize {
50000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
let api_url = self.api_url.as_str();
let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
response = self
.send_request(
&api_url,
&api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
.await?;
request_number += 1;
match response.status() {
StatusCode::REQUEST_TIMEOUT => {
request_timeout += 5;
}
StatusCode::OK => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
log::trace!(
"openai embedding completed. tokens: {:?}",
response.usage.total_tokens
);
// If we complete a request successfully that was previously rate_limited
// resolve the rate limit
if rate_limiting {
self.resolve_rate_limit()
}
return Ok(response
.data
.into_iter()
.map(|embedding| Embedding::from(embedding.embedding))
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let delay_duration = {
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
if let Some(time_to_reset) =
response.headers().get("x-ratelimit-reset-tokens")
{
if let Ok(time_str) = time_to_reset.to_str() {
parse(time_str).unwrap_or(delay)
} else {
delay
}
} else {
delay
}
};
// If we've previously rate limited, increment the duration but not the count
let reset_time = Instant::now().add(delay_duration);
self.update_reset_time(reset_time);
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
);
self.executor.timer(delay_duration).await;
}
_ => {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"open ai bad request: {:?} {:?}",
&response.status(),
body
));
}
}
}
Err(anyhow!("openai max retries"))
}
}

View File

@@ -0,0 +1,59 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use crate::models::{LanguageModel, TruncationDirection};
use super::open_ai_bpe_tokenizer;
#[derive(Clone)]
pub struct OpenAiLanguageModel {
name: String,
bpe: Option<CoreBPE>,
}
impl OpenAiLanguageModel {
pub fn load(model_name: &str) -> Self {
let bpe = tiktoken_rs::get_bpe_from_model(model_name)
.unwrap_or(open_ai_bpe_tokenizer().to_owned());
OpenAiLanguageModel {
name: model_name.to_string(),
bpe: Some(bpe),
}
}
}
impl LanguageModel for OpenAiLanguageModel {
fn name(&self) -> String {
self.name.clone()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
if let Some(bpe) = &self.bpe {
anyhow::Ok(bpe.encode_with_special_tokens(content).len())
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
if let Some(bpe) = &self.bpe {
let tokens = bpe.encode_with_special_tokens(content);
if tokens.len() > length {
match direction {
TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
}
} else {
bpe.decode(tokens)
}
} else {
Err(anyhow!("bpe for open ai model was not retrieved"))
}
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
}
}

206
crates/ai/src/test.rs Normal file
View File

@@ -0,0 +1,206 @@
use std::{
sync::atomic::{self, AtomicUsize, Ordering},
time::Instant,
};
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
#[derive(Clone)]
pub struct FakeLanguageModel {
pub capacity: usize,
}
impl LanguageModel for FakeLanguageModel {
fn name(&self) -> String {
"dummy".to_string()
}
fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
anyhow::Ok(content.chars().collect::<Vec<char>>().len())
}
fn truncate(
&self,
content: &str,
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
anyhow::Ok(match direction {
TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
.into_iter()
.collect::<String>(),
TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
.into_iter()
.collect::<String>(),
})
}
fn capacity(&self) -> anyhow::Result<usize> {
anyhow::Ok(self.capacity)
}
}
#[derive(Default)]
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
}
}
}
impl FakeEmbeddingProvider {
pub fn embedding_count(&self) -> usize {
self.embedding_count.load(atomic::Ordering::SeqCst)
}
pub fn embed_sync(&self, span: &str) -> Embedding {
let mut result = vec![1.0; 26];
for letter in span.chars() {
let letter = letter.to_ascii_lowercase();
if letter as u32 >= 'a' as u32 {
let ix = (letter as u32) - ('a' as u32);
if ix < 26 {
result[ix as usize] += 1.0;
}
}
}
let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
for x in &mut result {
*x /= norm;
}
result.into()
}
}
impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
async { ProviderCredential::NotNeeded }.boxed()
}
fn save_credentials(
&self,
_cx: &mut AppContext,
_credential: ProviderCredential,
) -> BoxFuture<()> {
async {}.boxed()
}
fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
async {}.boxed()
}
}
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
fn max_tokens_per_batch(&self) -> usize {
1000
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
self.last_completion_tx.lock().take().unwrap();
}
}
impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
async { ProviderCredential::NotNeeded }.boxed()
}
fn save_credentials(
&self,
_cx: &mut AppContext,
_credential: ProviderCredential,
) -> BoxFuture<()> {
async {}.boxed()
}
fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
async {}.boxed()
}
}
impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@@ -1,22 +0,0 @@
[package]
name = "anthropic"
version = "0.1.0"
edition = "2021"
publish = false
license = "AGPL-3.0-or-later"
[lib]
path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
util.workspace = true
[dev-dependencies]
tokio.workspace = true
[lints]
workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-AGPL

View File

@@ -1,234 +0,0 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-opus-20240229")]
Claude3Opus,
#[serde(rename = "claude-3-sonnet-20240229")]
Claude3Sonnet,
#[serde(rename = "claude-3-haiku-20240307")]
Claude3Haiku,
}
impl Model {
pub fn from_id(id: &str) -> Result<Self> {
if id.starts_with("claude-3-opus") {
Ok(Self::Claude3Opus)
} else if id.starts_with("claude-3-sonnet") {
Ok(Self::Claude3Sonnet)
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
Err(anyhow!("Invalid model id: {}", id))
}
}
pub fn display_name(&self) -> &'static str {
match self {
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
}
}
pub fn max_token_count(&self) -> usize {
200_000
}
}
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
impl TryFrom<String> for Role {
type Error = anyhow::Error;
fn try_from(value: String) -> Result<Self> {
match value.as_str() {
"user" => Ok(Self::User),
"assistant" => Ok(Self::Assistant),
_ => Err(anyhow!("invalid role '{value}'")),
}
}
}
impl From<Role> for String {
fn from(val: Role) -> Self {
match val {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
}
}
}
#[derive(Debug, Serialize)]
pub struct Request {
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub system: String,
pub max_tokens: u32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
pub content: String,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
message: ResponseMessage,
},
ContentBlockStart {
index: u32,
content_block: ContentBlock,
},
Ping {},
ContentBlockDelta {
index: u32,
delta: TextDelta,
},
ContentBlockStop {
index: u32,
},
MessageDelta {
delta: ResponseMessage,
usage: Usage,
},
MessageStop {},
}
#[derive(Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
pub id: Option<String>,
pub role: Option<String>,
pub content: Option<Vec<String>>,
pub model: Option<String>,
pub stop_reason: Option<String>,
pub stop_sequence: Option<String>,
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let uri = format!("{api_url}/v1/messages");
let request = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
.header("Anthropic-Beta", "messages-2023-12-15")
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
.lines()
.filter_map(|line| async move {
match line {
Ok(line) => {
let line = line.strip_prefix("data: ")?;
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
}
}
Err(error) => Some(Err(anyhow!(error))),
}
})
.boxed())
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
match serde_json::from_str::<ResponseEvent>(body_str) {
Ok(_) => Err(anyhow!(
"Unexpected success response while expecting an error: {}",
body_str,
)),
Err(_) => Err(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str,
)),
}
}
}
// #[cfg(test)]
// mod tests {
// use super::*;
// use util::http::IsahcHttpClient;
// #[tokio::test]
// async fn stream_completion_success() {
// let http_client = IsahcHttpClient::new().unwrap();
// let request = Request {
// model: Model::Claude3Opus,
// messages: vec![RequestMessage {
// role: Role::User,
// content: "Ping".to_string(),
// }],
// stream: true,
// system: "Respond to ping with pong".to_string(),
// max_tokens: 4096,
// };
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// )
// .await
// .unwrap();
// stream
// .for_each(|event| async {
// match event {
// Ok(event) => println!("{:?}", event),
// Err(e) => eprintln!("Error: {:?}", e),
// }
// })
// .await;
// }
// }

View File

@@ -5,18 +5,19 @@ edition = "2021"
publish = false
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/assistant.rs"
doctest = false
[dependencies]
ai.workspace = true
anyhow.workspace = true
chrono.workspace = true
client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
file_icons.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -25,13 +26,12 @@ language.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
open_ai = { workspace = true, features = ["schemars"] }
ordered-float.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
search.workspace = true
semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -45,6 +45,7 @@ uuid.workspace = true
workspace.workspace = true
[dev-dependencies]
ai = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true

View File

@@ -1,27 +1,22 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
mod completion_provider;
mod prompts;
mod saved_conversation;
mod streaming_diff;
mod embedded_scope;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use assistant_settings::OpenAiModel;
use chrono::{DateTime, Local};
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
pub(crate) use saved_conversation::*;
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use gpui::{actions, AppContext, SharedString};
use regex::Regex;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{
fmt::{self, Display},
sync::Arc,
};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
actions!(
assistant,
@@ -35,6 +30,7 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
ToggleRetrieveContext,
]
);
@@ -43,134 +39,6 @@ actions!(
)]
struct MessageId(usize);
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl Role {
pub fn cycle(&mut self) {
*self = match self {
Role::User => Role::Assistant,
Role::Assistant => Role::System,
Role::System => Role::User,
}
}
}
impl Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::System => write!(f, "system"),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
ZedDotDev(ZedDotDevModel),
OpenAi(OpenAiModel),
}
impl Default for LanguageModel {
fn default() -> Self {
LanguageModel::ZedDotDev(ZedDotDevModel::default())
}
}
impl LanguageModel {
pub fn telemetry_id(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
}
}
pub fn display_name(&self) -> String {
match self {
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
}
}
pub fn max_token_count(&self) -> usize {
match self {
LanguageModel::OpenAi(model) => model.max_token_count(),
LanguageModel::ZedDotDev(model) => model.max_token_count(),
}
}
pub fn id(&self) -> &str {
match self {
LanguageModel::OpenAi(model) => model.id(),
LanguageModel::ZedDotDev(model) => model.id(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: String,
}
impl LanguageModelRequestMessage {
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
proto::LanguageModelRequestMessage {
role: match self.role {
Role::User => proto::LanguageModelRole::LanguageModelUser,
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
Role::System => proto::LanguageModelRole::LanguageModelSystem,
} as i32,
content: self.content.clone(),
}
}
}
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
}
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
model: self.model.id().to_string(),
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
stop: self.stop.clone(),
temperature: self.temperature,
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,
pub content: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
pub index: u32,
pub delta: LanguageModelResponseMessage,
pub finish_reason: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
@@ -185,61 +53,72 @@ enum MessageStatus {
Error(SharedString),
}
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
/// Whether the Assistant is enabled.
enabled: bool,
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
impl Global for Assistant {}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
}
impl Assistant {
const NAMESPACE: &'static str = "assistant";
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
if self.enabled == enabled {
return;
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
self.enabled = enabled;
if !enabled {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(Self::NAMESPACE);
});
return;
}
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.show_namespace(Self::NAMESPACE);
});
Ok(conversations)
}
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
cx.set_global(Assistant::default());
AssistantSettings::register(cx);
completion_provider::init(client, cx);
pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(Assistant::NAMESPACE);
});
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
let settings = AssistantSettings::get_global(cx);
assistant.set_enabled(settings.enabled, cx);
});
cx.observe_global::<SettingsStore>(|cx| {
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
let settings = AssistantSettings::get_global(cx);
assistant.set_enabled(settings.enabled, cx);
});
})
.detach();
}
#[cfg(test)]

File diff suppressed because it is too large Load Diff

View File

@@ -1,305 +1,169 @@
use std::fmt;
use ai::providers::open_ai::{
AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
};
use anyhow::anyhow;
use gpui::Pixels;
pub use open_ai::Model as OpenAiModel;
use schemars::{
schema::{InstanceType, Metadata, Schema, SchemaObject},
JsonSchema,
};
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
Gpt3Point5Turbo,
Gpt4,
#[default]
Gpt4Turbo,
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Custom(String),
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OpenAiModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
#[serde(rename = "gpt-4-0613")]
Four,
#[serde(rename = "gpt-4-1106-preview")]
FourTurbo,
}
impl Serialize for ZedDotDevModel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.id())
}
}
impl<'de> Deserialize<'de> for ZedDotDevModel {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ZedDotDevModelVisitor;
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
type Value = ZedDotDevModel;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
match value {
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
}
}
deserializer.deserialize_str(ZedDotDevModelVisitor)
}
}
impl JsonSchema for ZedDotDevModel {
fn schema_name() -> String {
"ZedDotDevModel".to_owned()
}
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
let variants = vec![
"gpt-3.5-turbo".to_owned(),
"gpt-4".to_owned(),
"gpt-4-turbo-preview".to_owned(),
];
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
metadata: Some(Box::new(Metadata {
title: Some("ZedDotDevModel".to_owned()),
default: Some(serde_json::json!("gpt-4-turbo-preview")),
examples: vec![
serde_json::json!("gpt-3.5-turbo"),
serde_json::json!("gpt-4"),
serde_json::json!("gpt-4-turbo-preview"),
serde_json::json!("custom-model-name"),
],
..Default::default()
})),
..Default::default()
})
}
}
impl ZedDotDevModel {
pub fn id(&self) -> &str {
impl OpenAiModel {
pub fn full_name(&self) -> &'static str {
match self {
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
Self::Gpt4Turbo => "gpt-4-turbo-preview",
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Custom(id) => id,
Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
Self::Four => "gpt-4-0613",
Self::FourTurbo => "gpt-4-1106-preview",
}
}
pub fn display_name(&self) -> &str {
pub fn short_name(&self) -> &'static str {
match self {
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
Self::Gpt4 => "GPT 4",
Self::Gpt4Turbo => "GPT 4 Turbo",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom(id) => id.as_str(),
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
Self::Four => "gpt-4",
Self::FourTurbo => "gpt-4-turbo",
}
}
pub fn max_token_count(&self) -> usize {
pub fn cycle(&self) -> Self {
match self {
Self::Gpt3Point5Turbo => 2048,
Self::Gpt4 => 4096,
Self::Gpt4Turbo => 128000,
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
Self::Custom(_) => 4096, // TODO: Make this configurable
Self::ThreePointFiveTurbo => Self::Four,
Self::Four => Self::FourTurbo,
Self::FourTurbo => Self::ThreePointFiveTurbo,
}
}
}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
Left,
#[default]
Right,
Bottom,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
#[serde(rename = "zed.dev")]
ZedDotDev {
#[serde(default)]
default_model: ZedDotDevModel,
},
#[serde(rename = "openai")]
OpenAi {
#[serde(default)]
default_model: OpenAiModel,
#[serde(default = "open_ai_url")]
api_url: String,
},
}
impl Default for AssistantProvider {
fn default() -> Self {
Self::ZedDotDev {
default_model: ZedDotDevModel::default(),
}
}
}
fn open_ai_url() -> String {
"https://api.openai.com/v1".into()
}
#[derive(Default, Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize)]
pub struct AssistantSettings {
pub enabled: bool,
/// Whether to show the assistant panel button in the status bar.
pub button: bool,
/// Where to dock the assistant.
pub dock: AssistantDockPosition,
/// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
/// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
pub provider: AssistantProvider,
/// The default OpenAI model to use when starting new conversations.
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: OpenAiModel,
/// OpenAI API base URL to use when starting new conversations.
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: String,
/// The settings for the AI provider.
pub provider: AiProviderSettings,
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
Versioned(VersionedAssistantSettingsContent),
Legacy(LegacyAssistantSettingsContent),
}
impl AssistantSettings {
pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
match &self.provider {
AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
AiProviderSettings::AzureOpenAi(settings) => {
let deployment_id = settings
.deployment_id
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
let api_version = settings
.api_version
.ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
impl JsonSchema for AssistantSettingsContent {
fn schema_name() -> String {
VersionedAssistantSettingsContent::schema_name()
}
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
VersionedAssistantSettingsContent::json_schema(gen)
}
fn is_referenceable() -> bool {
VersionedAssistantSettingsContent::is_referenceable()
}
}
impl Default for AssistantSettingsContent {
fn default() -> Self {
Self::Versioned(VersionedAssistantSettingsContent::default())
}
}
impl AssistantSettingsContent {
fn upgrade(&self) -> AssistantSettingsContentV1 {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
},
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
enabled: None,
button: settings.button,
dock: settings.dock,
default_width: settings.default_width,
default_height: settings.default_height,
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
Some(AssistantProvider::OpenAi {
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
api_url: open_ai_api_url.clone(),
})
} else {
settings.default_open_ai_model.clone().map(|open_ai_model| {
AssistantProvider::OpenAi {
default_model: open_ai_model,
api_url: open_ai_url(),
}
})
},
},
}
}
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => {
settings.dock = Some(dock);
}
},
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
Ok(OpenAiCompletionProviderKind::AzureOpenAi {
deployment_id,
api_version,
})
}
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
}
pub fn provider_api_url(&self) -> anyhow::Result<String> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => Ok(settings
.api_url
.clone()
.unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
AiProviderSettings::AzureOpenAi(settings) => settings
.api_url
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
}
}
impl Default for VersionedAssistantSettingsContent {
fn default() -> Self {
Self::V1(AssistantSettingsContentV1 {
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
provider: None,
})
pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => {
Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
}
AiProviderSettings::AzureOpenAi(settings) => {
let deployment_id = settings
.deployment_id
.as_deref()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
match deployment_id {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
"gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
"gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
Ok(OpenAiModel::ThreePointFiveTurbo)
}
_ => Err(anyhow!(
"no matching OpenAI model found for deployment ID: '{deployment_id}'"
)),
}
}
}
}
pub fn provider_model_name(&self) -> anyhow::Result<String> {
match &self.provider {
AiProviderSettings::OpenAi(settings) => Ok(settings
.default_model
.unwrap_or(OpenAiModel::FourTurbo)
.full_name()
.to_string()),
AiProviderSettings::AzureOpenAi(settings) => settings
.deployment_id
.clone()
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
}
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
///
/// Default: true
enabled: Option<bool>,
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
button: Option<bool>,
/// Where to dock the assistant.
///
/// Default: right
dock: Option<AssistantDockPosition>,
/// Default width in pixels when the assistant is docked to the left or right.
///
/// Default: 640
default_width: Option<f32>,
/// Default height in pixels when the assistant is docked to the bottom.
///
/// Default: 320
default_height: Option<f32>,
/// The provider of the assistant service.
///
/// This can either be the internal `zed.dev` service or an external `openai` service,
/// each with their respective default models and configurations.
provider: Option<AssistantProvider>,
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
type FileContent = AssistantSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
/// Assistant panel settings
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContent {
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
@@ -316,165 +180,88 @@ pub struct LegacyAssistantSettingsContent {
///
/// Default: 320
pub default_height: Option<f32>,
/// Deprecated: Please use `provider.default_model` instead.
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
#[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: Option<OpenAiModel>,
/// Deprecated: Please use `provider.api_url` instead.
/// OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
#[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: Option<String>,
/// The settings for the AI provider.
#[serde(default)]
pub provider: AiProviderSettingsContent,
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettings {
/// The settings for the OpenAI provider.
#[serde(rename = "openai")]
OpenAi(OpenAiProviderSettings),
/// The settings for the Azure OpenAI provider.
#[serde(rename = "azure_openai")]
AzureOpenAi(AzureOpenAiProviderSettings),
}
type FileContent = AssistantSettingsContent;
/// The settings for the AI provider used by the Zed Assistant.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AiProviderSettingsContent {
/// The settings for the OpenAI provider.
#[serde(rename = "openai")]
OpenAi(OpenAiProviderSettingsContent),
/// The settings for the Azure OpenAI provider.
#[serde(rename = "azure_openai")]
AzureOpenAi(AzureOpenAiProviderSettingsContent),
}
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
let mut settings = AssistantSettings::default();
for value in [default_value].iter().chain(user_values) {
let value = value.upgrade();
merge(&mut settings.enabled, value.enabled);
merge(&mut settings.button, value.button);
merge(&mut settings.dock, value.dock);
merge(
&mut settings.default_width,
value.default_width.map(Into::into),
);
merge(
&mut settings.default_height,
value.default_height.map(Into::into),
);
if let Some(provider) = value.provider.clone() {
match (&mut settings.provider, provider) {
(
AssistantProvider::ZedDotDev { default_model },
AssistantProvider::ZedDotDev {
default_model: default_model_override,
},
) => {
*default_model = default_model_override;
}
(
AssistantProvider::OpenAi {
default_model,
api_url,
},
AssistantProvider::OpenAi {
default_model: default_model_override,
api_url: api_url_override,
},
) => {
*default_model = default_model_override;
*api_url = api_url_override;
}
(merged, provider_override) => {
*merged = provider_override;
}
}
}
}
Ok(settings)
impl Default for AiProviderSettingsContent {
fn default() -> Self {
Self::OpenAi(OpenAiProviderSettingsContent::default())
}
}
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
if let Some(value) = value {
*target = value;
}
#[derive(Debug, Clone, Deserialize)]
pub struct OpenAiProviderSettings {
/// The OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The default OpenAI model to use when starting new conversations.
pub default_model: Option<OpenAiModel>,
}
#[cfg(test)]
mod tests {
use gpui::{AppContext, BorrowAppContext};
use settings::SettingsStore;
use super::*;
#[gpui::test]
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
let store = settings::SettingsStore::test(cx);
cx.set_global(store);
// Settings default to gpt-4-turbo.
AssistantSettings::register(cx);
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
api_url: open_ai_url()
}
);
// Ensure backward-compatibility.
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"openai_api_url": "test-url",
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::FourTurbo,
api_url: "test-url".into()
}
);
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"default_open_ai_model": "gpt-4-0613"
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::OpenAi {
default_model: OpenAiModel::Four,
api_url: open_ai_url()
}
);
// The new version supports setting a custom model when using zed.dev.
cx.update_global::<SettingsStore, _>(|store, cx| {
store
.set_user_settings(
r#"{
"assistant": {
"version": "1",
"provider": {
"name": "zed.dev",
"default_model": "custom"
}
}
}"#,
cx,
)
.unwrap();
});
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
default_model: ZedDotDevModel::Custom("custom".into())
}
);
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OpenAiProviderSettingsContent {
/// The OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
pub api_url: Option<String>,
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
pub default_model: Option<OpenAiModel>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AzureOpenAiProviderSettings {
/// The Azure OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The Azure OpenAI API version.
pub api_version: Option<AzureOpenAiApiVersion>,
/// The Azure OpenAI API deployment ID.
pub deployment_id: Option<String>,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AzureOpenAiProviderSettingsContent {
/// The Azure OpenAI API base URL to use when starting new conversations.
pub api_url: Option<String>,
/// The Azure OpenAI API version.
pub api_version: Option<AzureOpenAiApiVersion>,
/// The Azure OpenAI deployment ID.
pub deployment_id: Option<String>,
}

View File

@@ -1,13 +1,12 @@
use crate::{
streaming_diff::{Hunk, StreamingDiff},
CompletionProvider, LanguageModelRequest,
};
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range};
use multi_buffer;
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
Finished,
@@ -21,6 +20,7 @@ pub enum CodegenKind {
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@@ -35,9 +35,15 @@ pub struct Codegen {
impl EventEmitter<Event> for Codegen {}
impl Codegen {
pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
pub fn new(
buffer: Model<MultiBuffer>,
kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
Self {
provider,
buffer: buffer.clone(),
snapshot,
kind,
@@ -88,7 +94,7 @@ impl Codegen {
self.error.as_ref()
}
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -102,7 +108,7 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = CompletionProvider::global(cx).complete(prompt);
let response = self.provider.complete(prompt);
self.generation = cx.spawn(|this, mut cx| {
async move {
let generate = async {
@@ -299,7 +305,7 @@ fn strip_invalid_spans_from_codeblock(
}
if first_line {
if buffer.is_empty() || buffer == "`" || buffer == "``" {
if buffer == "" || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_markdown_codeblock = true;
@@ -354,9 +360,8 @@ fn strip_invalid_spans_from_codeblock(
mod tests {
use std::sync::Arc;
use crate::FakeCompletionProvider;
use super::*;
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@@ -373,11 +378,15 @@ mod tests {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
let provider = FakeCompletionProvider::default();
cx.set_global(cx.update(SettingsStore::test));
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.update(language_settings::init);
let text = indoc! {"
@@ -396,10 +405,19 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
let request = LanguageModelRequest::default();
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -412,7 +430,8 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk.into());
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
@@ -437,8 +456,6 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
let provider = FakeCompletionProvider::default();
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -455,10 +472,19 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = LanguageModelRequest::default();
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -471,7 +497,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk.into());
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
@@ -496,8 +522,6 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
let provider = FakeCompletionProvider::default();
cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -514,10 +538,19 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let codegen =
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = LanguageModelRequest::default();
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -530,7 +563,8 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk.into());
println!("{:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}

View File

@@ -1,188 +0,0 @@
#[cfg(test)]
mod fake;
mod open_ai;
mod zed;
#[cfg(test)]
pub use fake::*;
pub use open_ai::*;
pub use zed::*;
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
LanguageModel, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let mut settings_version = 0;
let provider = match &AssistantSettings::get_global(cx).provider {
AssistantProvider::ZedDotDev { default_model } => {
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
settings_version,
cx,
))
}
AssistantProvider::OpenAi {
default_model,
api_url,
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
settings_version,
)),
};
cx.set_global(provider);
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
(
CompletionProvider::OpenAi(provider),
AssistantProvider::OpenAi {
default_model,
api_url,
},
) => {
provider.update(default_model.clone(), api_url.clone(), settings_version);
}
(
CompletionProvider::ZedDotDev(provider),
AssistantProvider::ZedDotDev { default_model },
) => {
provider.update(default_model.clone(), settings_version);
}
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
default_model.clone(),
client.clone(),
settings_version,
cx,
));
}
(
CompletionProvider::ZedDotDev(_),
AssistantProvider::OpenAi {
default_model,
api_url,
},
) => {
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
default_model.clone(),
api_url.clone(),
client.http_client(),
settings_version,
));
}
#[cfg(test)]
(CompletionProvider::Fake(_), _) => unimplemented!(),
}
})
})
.detach();
}
pub enum CompletionProvider {
OpenAi(OpenAiCompletionProvider),
ZedDotDev(ZedDotDevCompletionProvider),
#[cfg(test)]
Fake(FakeCompletionProvider),
}
impl gpui::Global for CompletionProvider {}
impl CompletionProvider {
pub fn global(cx: &AppContext) -> &Self {
cx.global::<Self>()
}
pub fn settings_version(&self) -> usize {
match self {
CompletionProvider::OpenAi(provider) => provider.settings_version(),
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn is_authenticated(&self) -> bool {
match self {
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
#[cfg(test)]
CompletionProvider::Fake(_) => true,
}
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
match self {
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
match self {
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
#[cfg(test)]
CompletionProvider::Fake(_) => Task::ready(Ok(())),
}
}
pub fn default_model(&self) -> LanguageModel {
match self {
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
CompletionProvider::ZedDotDev(provider) => {
LanguageModel::ZedDotDev(provider.default_model())
}
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match self {
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
#[cfg(test)]
CompletionProvider::Fake(_) => unimplemented!(),
}
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match self {
CompletionProvider::OpenAi(provider) => provider.complete(request),
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
#[cfg(test)]
CompletionProvider::Fake(provider) => provider.complete(),
}
}
}

View File

@@ -1,29 +0,0 @@
use anyhow::Result;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use std::sync::Arc;
#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
}
impl FakeCompletionProvider {
pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::unbounded();
*self.current_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
pub fn send_completion(&self, chunk: String) {
self.current_completion_tx
.lock()
.as_ref()
.unwrap()
.unbounded_send(chunk)
.unwrap();
}
pub fn finish_completion(&self) {
self.current_completion_tx.lock().take();
}
}

View File

@@ -1,301 +0,0 @@
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::{env, sync::Arc};
use theme::ThemeSettings;
use ui::prelude::*;
use util::{http::HttpClient, ResultExt};
pub struct OpenAiCompletionProvider {
api_key: Option<String>,
api_url: String,
default_model: OpenAiModel,
http_client: Arc<dyn HttpClient>,
settings_version: usize,
}
impl OpenAiCompletionProvider {
pub fn new(
default_model: OpenAiModel,
api_url: String,
http_client: Arc<dyn HttpClient>,
settings_version: usize,
) -> Self {
Self {
api_key: None,
api_url,
default_model,
http_client,
settings_version,
}
}
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
self.default_model = default_model;
self.api_url = api_url;
self.settings_version = settings_version;
}
pub fn settings_version(&self) -> usize {
self.settings_version
}
pub fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
let api_url = self.api_url.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
api_key
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = Some(api_key);
}
})
})
}
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = None;
}
})
})
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
pub fn default_model(&self) -> OpenAiModel {
self.default_model.clone()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
count_open_ai_tokens(request, cx.background_executor())
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
let api_key = self.api_key.clone();
let api_url = self.api_url.clone();
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
}
.boxed()
}
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
let model = match request.model {
LanguageModel::ZedDotDev(_) => self.default_model(),
LanguageModel::OpenAi(model) => model,
};
Request {
model,
messages: request
.messages
.into_iter()
.map(|msg| RequestMessage {
role: msg.role.into(),
content: msg.content,
})
.collect(),
stream: true,
stop: request.stop,
temperature: request.temperature,
}
}
}
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
background_executor
.spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.content),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
})
.boxed()
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => OpenAiRole::User,
Role::Assistant => OpenAiRole::Assistant,
Role::System => OpenAiRole::System,
}
}
}
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
}
impl AuthenticationPrompt {
fn new(api_url: String, cx: &mut WindowContext) -> Self {
Self {
api_key: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text(
"sk-000000000000000000000000000000000000000000000000",
cx,
);
editor
}),
api_url,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key.read(cx).text(cx);
if api_key.is_empty() {
return;
}
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::OpenAi(provider) = provider {
provider.api_key = Some(api_key);
}
})
})
.detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features,
font_size: rems(0.875).into(),
font_weight: FontWeight::NORMAL,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
};
EditorElement::new(
&self.api_key,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
}
impl Render for AuthenticationPrompt {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
const INSTRUCTIONS: [&str; 6] = [
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
" - You can create an API key at: platform.openai.com/api-keys",
" - Make sure your OpenAI account has credits",
" - Having a subscription for another service like GitHub Copilot won't work.",
"",
"Paste your OpenAI API key below and hit enter to use the assistant:",
];
v_flex()
.p_4()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.children(
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
)
.size(LabelSize::Small),
)
.child(
h_flex()
.gap_2()
.child(Label::new("Click on").size(LabelSize::Small))
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
.child(
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
),
)
.into_any()
}
}

View File

@@ -1,175 +0,0 @@
use crate::{
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
use ui::prelude::*;
pub struct ZedDotDevCompletionProvider {
client: Arc<Client>,
default_model: ZedDotDevModel,
settings_version: usize,
status: client::Status,
_maintain_client_status: Task<()>,
}
impl ZedDotDevCompletionProvider {
pub fn new(
default_model: ZedDotDevModel,
client: Arc<Client>,
settings_version: usize,
cx: &mut AppContext,
) -> Self {
let mut status_rx = client.status();
let status = *status_rx.borrow();
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
if let CompletionProvider::ZedDotDev(provider) = provider {
provider.status = status;
} else {
unreachable!()
}
});
}
});
Self {
client,
default_model,
settings_version,
status,
_maintain_client_status: maintain_client_status,
}
}
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
self.default_model = default_model;
self.settings_version = settings_version;
}
pub fn settings_version(&self) -> usize {
self.settings_version
}
pub fn default_model(&self) -> ZedDotDevModel {
self.default_model.clone()
}
pub fn is_authenticated(&self) -> bool {
self.status.is_connected()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|_cx| AuthenticationPrompt).into()
}
pub fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::ZedDotDev(
ZedDotDevModel::Claude3Opus
| ZedDotDevModel::Claude3Sonnet
| ZedDotDevModel::Claude3Haiku,
) => {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
});
async move {
let response = request.await?;
Ok(response.token_count as usize)
}
.boxed()
}
}
}
pub fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let request = proto::CompleteWithLanguageModel {
model: request.model.id().to_string(),
messages: request
.messages
.iter()
.map(|message| message.to_proto())
.collect(),
stop: request.stop,
temperature: request.temperature,
};
self.client
.request_stream(request)
.map_ok(|stream| {
stream
.filter_map(|response| async move {
match response {
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
Err(error) => Some(Err(error)),
}
})
.boxed()
})
.boxed()
}
}
struct AuthenticationPrompt;
impl Render for AuthenticationPrompt {
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
v_flex()
.gap_2()
.child(
Button::new("sign_in", "Sign in")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.style(ButtonStyle::Filled)
.full_width()
.on_click(|_, cx| {
CompletionProvider::global(cx)
.authenticate(cx)
.detach_and_log_err(cx);
}),
)
.child(
div().flex().w_full().items_center().child(
Label::new("Sign in to enable collaboration.")
.color(Color::Muted)
.size(LabelSize::Small),
),
),
)
}
}

View File

@@ -1,91 +0,0 @@
use editor::MultiBuffer;
use gpui::{AppContext, Model, ModelContext, Subscription};
use crate::{assistant_panel::Conversation, LanguageModelRequestMessage, Role};
#[derive(Default)]
pub struct EmbeddedScope {
active_buffer: Option<Model<MultiBuffer>>,
active_buffer_enabled: bool,
active_buffer_subscription: Option<Subscription>,
}
impl EmbeddedScope {
pub fn new() -> Self {
Self {
active_buffer: None,
active_buffer_enabled: true,
active_buffer_subscription: None,
}
}
pub fn set_active_buffer(
&mut self,
buffer: Option<Model<MultiBuffer>>,
cx: &mut ModelContext<Conversation>,
) {
self.active_buffer_subscription.take();
if let Some(active_buffer) = buffer.clone() {
self.active_buffer_subscription =
Some(cx.subscribe(&active_buffer, |conversation, _, e, cx| {
if let multi_buffer::Event::Edited { .. } = e {
conversation.count_remaining_tokens(cx)
}
}));
}
self.active_buffer = buffer;
}
pub fn active_buffer(&self) -> Option<&Model<MultiBuffer>> {
self.active_buffer.as_ref()
}
pub fn active_buffer_enabled(&self) -> bool {
self.active_buffer_enabled
}
pub fn set_active_buffer_enabled(&mut self, enabled: bool) {
self.active_buffer_enabled = enabled;
}
/// Provide a message for the language model based on the active buffer.
pub fn message(&self, cx: &AppContext) -> Option<LanguageModelRequestMessage> {
if !self.active_buffer_enabled {
return None;
}
let active_buffer = self.active_buffer.as_ref()?;
let buffer = active_buffer.read(cx);
if let Some(singleton) = buffer.as_singleton() {
let singleton = singleton.read(cx);
let filename = singleton
.file()
.map(|file| file.path().to_string_lossy())
.unwrap_or("Untitled".into());
let text = singleton.text();
let language = singleton
.language()
.map(|l| {
let name = l.code_fence_block_name();
name.to_string()
})
.unwrap_or_default();
let markdown =
format!("User's active file `{filename}`:\n\n```{language}\n{text}```\n\n");
return Some(LanguageModelRequestMessage {
role: Role::System,
content: markdown,
});
}
None
}
}

View File

@@ -1,95 +1,394 @@
use language::BufferSnapshot;
use std::{fmt::Write, ops::Range};
use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAiLanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
let selected_range = selected_range.to_offset(buffer);
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = ts_matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = Vec::new();
while let Some(mat) = ts_matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
} else {
continue;
}
}
ts_matches.advance();
matches.push(Match { collapse, keep });
} else {
ts_matches.advance();
}
}
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
let mut matches = matches.into_iter().peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if mat.collapse.start <= next_mat.collapse.start
&& mat.collapse.end >= next_mat.collapse.end
{
matches.next().unwrap();
} else {
break;
}
}
if offset > mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
project_name: Option<String>,
) -> anyhow::Result<String> {
let mut prompt = String::new();
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => {
writeln!(prompt, "You are an expert engineer.")?;
"Text"
}
Some(language_name) => {
writeln!(prompt, "You are an expert {language_name} engineer.")?;
writeln!(
prompt,
"Your answer MUST always and only be valid {}.",
language_name
)?;
"Code"
}
// Using new Prompt Templates
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
let lang_name = if let Some(language_name) = language_name {
Some(language_name.to_string())
} else {
None
};
if let Some(project_name) = project_name {
writeln!(
prompt,
"You are currently working inside the '{project_name}' project in code editor Zed."
)?;
}
let args = PromptArguments {
model: openai_model,
language_name: lang_name.clone(),
project_name,
snippets: search_results.clone(),
reserved_tokens: 1000,
buffer: Some(buffer),
selected_range: Some(range),
user_prompt: Some(user_prompt.clone()),
};
// Include file content.
for chunk in buffer.text_for_range(0..range.start) {
prompt.push_str(chunk);
}
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
(
PromptPriority::Ordered { order: 1 },
Box::new(RepositoryContext {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(FileContext {}),
),
(
PromptPriority::Mandatory,
Box::new(GenerateInlineContent {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
if range.is_empty() {
prompt.push_str("<|START|>");
} else {
prompt.push_str("<|START|");
}
for chunk in buffer.text_for_range(range.clone()) {
prompt.push_str(chunk);
}
if !range.is_empty() {
prompt.push_str("|END|>");
}
for chunk in buffer.text_for_range(range.end..buffer.len()) {
prompt.push_str(chunk);
}
prompt.push('\n');
if range.is_empty() {
writeln!(
prompt,
"Assume the cursor is located where the `<|START|>` span is."
)
.unwrap();
writeln!(
prompt,
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
)
.unwrap();
writeln!(
prompt,
"Generate {content_type} based on the users prompt: {user_prompt}",
)
.unwrap();
} else {
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
writeln!(
prompt,
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
)
.unwrap();
}
writeln!(prompt, "Never make remarks about the output.").unwrap();
writeln!(
prompt,
"Do not return anything else, except the generated {content_type}."
)
.unwrap();
Ok(prompt)
anyhow::Ok(prompt)
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use gpui::{AppContext, Context};
use indoc::indoc;
use language::{
language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
LanguageMatcher, Point,
};
use settings::SettingsStore;
use std::sync::Arc;
pub(crate) fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
)
.unwrap()
}
#[gpui::test]
fn test_outline_for_prompt(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language_settings::init(cx);
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
self.a
}
pub fn b(&self) -> usize {
self.b
}
}
"};
let buffer = cx.new_model(|cx| {
Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
});
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|S|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let <|S|a |E|>= 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
<|S|>
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|S|>"}
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|S|>struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
}
}

View File

@@ -1,121 +0,0 @@
use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
use anyhow::{anyhow, Result};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{
cmp::Reverse,
ffi::OsStr,
path::{Path, PathBuf},
sync::Arc,
};
use util::paths::CONVERSATIONS_DIR;
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
pub id: MessageId,
pub start: usize,
}
#[derive(Serialize, Deserialize)]
pub struct SavedConversation {
pub id: Option<String>,
pub zed: String,
pub version: String,
pub text: String,
pub messages: Vec<SavedMessage>,
pub message_metadata: HashMap<MessageId, MessageMetadata>,
pub summary: String,
}
impl SavedConversation {
pub const VERSION: &'static str = "0.2.0";
pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
let saved_conversation = fs.load(path).await?;
let saved_conversation_json =
serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
match saved_conversation_json
.get("version")
.ok_or_else(|| anyhow!("version not found"))?
{
serde_json::Value::String(version) => match version.as_str() {
Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
"0.1.0" => {
let saved_conversation =
serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
Ok(Self {
id: saved_conversation.id,
zed: saved_conversation.zed,
version: saved_conversation.version,
text: saved_conversation.text,
messages: saved_conversation.messages,
message_metadata: saved_conversation.message_metadata,
summary: saved_conversation.summary,
})
}
_ => Err(anyhow!(
"unrecognized saved conversation version: {}",
version
)),
},
_ => Err(anyhow!("version not found on saved conversation")),
}
}
}
#[derive(Serialize, Deserialize)]
struct SavedConversationV0_1_0 {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
api_url: Option<String>,
model: OpenAiModel,
}
pub struct SavedConversationMetadata {
pub title: String,
pub path: PathBuf,
pub mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}

View File

@@ -197,10 +197,12 @@ impl StreamingDiff {
} else {
hunks.push(Hunk::Remove { len: char_len })
}
} else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Keep { len: char_len })
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Keep { len: char_len })
}
}
}

View File

@@ -1,6 +1,6 @@
use assets::SoundRegistry;
use derive_more::{Deref, DerefMut};
use gpui::{AppContext, AssetSource, BorrowAppContext, Global};
use gpui::{AppContext, AssetSource, Global};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;

View File

@@ -373,10 +373,7 @@ impl ActiveCall {
self.report_call_event("hang up", cx);
Audio::end_call(cx);
let channel_id = self.channel_id(cx);
if let Some((room, _)) = self.room.take() {
cx.emit(Event::RoomLeft { channel_id });
room.update(cx, |room, cx| room.leave(cx))
} else {
Task::ready(Ok(()))

View File

@@ -52,7 +52,7 @@ pub enum Event {
RemoteProjectInvitationDiscarded {
project_id: u64,
},
RoomLeft {
Left {
channel_id: Option<ChannelId>,
},
}
@@ -366,6 +366,9 @@ impl Room {
pub(crate) fn leave(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
cx.notify();
cx.emit(Event::Left {
channel_id: self.channel_id(),
});
self.leave_internal(cx)
}

View File

@@ -51,7 +51,6 @@ pub struct ChannelMessage {
pub nonce: u128,
pub mentions: Vec<(Range<usize>, UserId)>,
pub reply_to_message_id: Option<u64>,
pub edited_at: Option<OffsetDateTime>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -84,10 +83,6 @@ pub enum ChannelChatEvent {
old_range: Range<usize>,
new_count: usize,
},
UpdateMessage {
message_id: ChannelMessageId,
message_ix: usize,
},
NewMessage {
channel_id: ChannelId,
message_id: u64,
@@ -98,7 +93,6 @@ impl EventEmitter<ChannelChatEvent> for ChannelChat {}
pub fn init(client: &Arc<Client>) {
client.add_model_message_handler(ChannelChat::handle_message_sent);
client.add_model_message_handler(ChannelChat::handle_message_removed);
client.add_model_message_handler(ChannelChat::handle_message_updated);
}
impl ChannelChat {
@@ -195,7 +189,6 @@ impl ChannelChat {
mentions: message.mentions.clone(),
nonce,
reply_to_message_id: message.reply_to_message_id,
edited_at: None,
},
&(),
),
@@ -222,9 +215,6 @@ impl ChannelChat {
let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
if this.first_loaded_message_id.is_none() {
this.first_loaded_message_id = Some(id);
}
})?;
Ok(id)
}))
@@ -244,35 +234,6 @@ impl ChannelChat {
})
}
pub fn update_message(
&mut self,
id: u64,
message: MessageParams,
cx: &mut ModelContext<Self>,
) -> Result<Task<Result<()>>> {
self.message_update(
ChannelMessageId::Saved(id),
message.text.clone(),
message.mentions.clone(),
Some(OffsetDateTime::now_utc()),
cx,
);
let nonce: u128 = self.rng.gen();
let request = self.rpc.request(proto::UpdateChannelMessage {
channel_id: self.channel_id.0,
message_id: id,
body: message.text,
nonce: Some(nonce.into()),
mentions: mentions_to_proto(&message.mentions),
});
Ok(cx.spawn(move |_, _| async move {
request.await?;
Ok(())
}))
}
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
if self.loaded_all_messages {
return None;
@@ -562,32 +523,6 @@ impl ChannelChat {
Ok(())
}
async fn handle_message_updated(
this: Model<Self>,
message: TypedEnvelope<proto::ChannelMessageUpdate>,
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
let message = message
.payload
.message
.ok_or_else(|| anyhow!("empty message"))?;
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.message_update(
message.id,
message.body,
message.mentions,
message.edited_at,
cx,
)
})?;
Ok(())
}
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
let nonces = messages
@@ -664,38 +599,6 @@ impl ChannelChat {
}
}
}
fn message_update(
&mut self,
id: ChannelMessageId,
body: String,
mentions: Vec<(Range<usize>, u64)>,
edited_at: Option<OffsetDateTime>,
cx: &mut ModelContext<Self>,
) {
let mut cursor = self.messages.cursor::<ChannelMessageId>();
let mut messages = cursor.slice(&id, Bias::Left, &());
let ix = messages.summary().count;
if let Some(mut message_to_update) = cursor.item().cloned() {
message_to_update.body = body;
message_to_update.mentions = mentions;
message_to_update.edited_at = edited_at;
messages.push(message_to_update, &());
cursor.next(&());
}
messages.append(cursor.suffix(&()), &());
drop(cursor);
self.messages = messages;
cx.emit(ChannelChatEvent::UpdateMessage {
message_ix: ix,
message_id: id,
});
cx.notify();
}
}
async fn messages_from_proto(
@@ -720,15 +623,6 @@ impl ChannelMessage {
user_store.get_user(message.sender_id, cx)
})?
.await?;
let edited_at = message.edited_at.and_then(|t| -> Option<OffsetDateTime> {
if let Ok(a) = OffsetDateTime::from_unix_timestamp(t as i64) {
return Some(a);
}
None
});
Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id),
body: message.body,
@@ -747,7 +641,6 @@ impl ChannelMessage {
.ok_or_else(|| anyhow!("nonce is required"))?
.into(),
reply_to_message_id: message.reply_to_message_id,
edited_at,
})
}

View File

@@ -17,7 +17,7 @@ use rpc::{
};
use settings::Settings;
use std::{mem, sync::Arc, time::Duration};
use util::{maybe, ResultExt};
use util::{async_maybe, maybe, ResultExt};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@@ -227,7 +227,7 @@ impl ChannelStore {
_watch_connection_status: watch_connection_status,
disconnect_channel_buffers_task: None,
_update_channels: cx.spawn(|this, mut cx| async move {
maybe!(async move {
async_maybe!({
while let Some(update_channels) = update_channels_rx.next().await {
if let Some(this) = this.upgrade() {
let update_task = this.update(&mut cx, |this, cx| {

View File

@@ -186,7 +186,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
mentions: vec![],
nonce: Some(1.into()),
reply_to_message_id: None,
edited_at: None,
},
proto::ChannelMessage {
id: 11,
@@ -196,7 +195,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
mentions: vec![],
nonce: Some(2.into()),
reply_to_message_id: None,
edited_at: None,
},
],
done: false,
@@ -245,7 +243,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
mentions: vec![],
nonce: Some(3.into()),
reply_to_message_id: None,
edited_at: None,
}),
});
@@ -300,7 +297,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
nonce: Some(4.into()),
mentions: vec![],
reply_to_message_id: None,
edited_at: None,
},
proto::ChannelMessage {
id: 9,
@@ -310,7 +306,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
nonce: Some(5.into()),
mentions: vec![],
reply_to_message_id: None,
edited_at: None,
},
],
},

View File

@@ -1,4 +1,4 @@
#![cfg_attr(any(target_os = "linux", target_os = "windows"), allow(dead_code))]
#![cfg_attr(target_os = "linux", allow(dead_code))]
use anyhow::{anyhow, Context, Result};
use clap::Parser;

View File

@@ -13,12 +13,11 @@ use async_tungstenite::tungstenite::{
use clock::SystemClock;
use collections::HashMap;
use futures::{
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
TryFutureExt as _, TryStreamExt,
};
use gpui::{
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, BorrowAppContext, Global, Model,
Task, WeakModel,
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@@ -28,8 +27,8 @@ use release_channel::{AppVersion, ReleaseChannel};
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::fmt;
use std::{
any::TypeId,
convert::TryFrom,
@@ -37,10 +36,7 @@ use std::{
future::Future,
marker::PhantomData,
path::PathBuf,
sync::{
atomic::{AtomicU64, Ordering},
Arc, Weak,
},
sync::{atomic::AtomicU64, Arc, Weak},
time::{Duration, Instant},
};
use telemetry::Telemetry;
@@ -53,15 +49,6 @@ pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DevServerToken(pub String);
impl fmt::Display for DevServerToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
lazy_static! {
static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
@@ -287,22 +274,10 @@ enum WeakSubscriber {
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
DevServer { token: DevServerToken },
User { user_id: u64, access_token: String },
}
impl Credentials {
pub fn authorization_header(&self) -> String {
match self {
Credentials::DevServer { token } => format!("dev-server-token {}", token),
Credentials::User {
user_id,
access_token,
} => format!("{} {}", user_id, access_token),
}
}
#[derive(Clone, Debug)]
pub struct Credentials {
pub user_id: u64,
pub access_token: String,
}
impl Default for ClientState {
@@ -467,7 +442,7 @@ impl Client {
}
pub fn id(&self) -> u64 {
self.id.load(Ordering::SeqCst)
self.id.load(std::sync::atomic::Ordering::SeqCst)
}
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
@@ -475,7 +450,7 @@ impl Client {
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, Ordering::SeqCst);
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
self
}
@@ -519,11 +494,11 @@ impl Client {
}
pub fn user_id(&self) -> Option<u64> {
if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
Some(*user_id)
} else {
None
}
self.state
.read()
.credentials
.as_ref()
.map(|credentials| credentials.user_id)
}
pub fn peer_id(&self) -> Option<PeerId> {
@@ -768,10 +743,6 @@ impl Client {
read_credentials_from_keychain(cx).await.is_some()
}
pub fn set_dev_server_token(&self, token: DevServerToken) {
self.state.write().credentials = Some(Credentials::DevServer { token });
}
#[async_recursion(?Send)]
pub async fn authenticate_and_connect(
self: &Arc<Self>,
@@ -822,9 +793,7 @@ impl Client {
}
}
let credentials = credentials.unwrap();
if let Credentials::User { user_id, .. } = &credentials {
self.set_id(*user_id);
}
self.set_id(credentials.user_id);
if was_disconnected {
self.set_status(Status::Connecting, cx);
@@ -840,9 +809,7 @@ impl Client {
Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
if let Credentials::User{user_id, access_token} = credentials {
write_credentials_to_keychain(user_id, access_token, cx).await.log_err();
}
write_credentials_to_keychain(credentials, cx).await.log_err();
}
futures::select_biased! {
@@ -1050,7 +1017,10 @@ impl Client {
.unwrap_or_default();
let request = Request::builder()
.header("Authorization", credentials.authorization_header())
.header(
"Authorization",
format!("{} {}", credentials.user_id, credentials.access_token),
)
.header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
.header("x-zed-app-version", app_version)
.header(
@@ -1203,7 +1173,7 @@ impl Client {
.decrypt_string(&access_token)
.context("failed to decrypt access token")?;
Ok(Credentials::User {
Ok(Credentials {
user_id: user_id.parse()?,
access_token,
})
@@ -1253,7 +1223,7 @@ impl Client {
// Use the admin API token to authenticate as the impersonated user.
api_token.insert_str(0, "ADMIN_TOKEN:");
Ok(Credentials::User {
Ok(Credentials {
user_id: response.user.id,
access_token: api_token,
})
@@ -1290,30 +1260,6 @@ impl Client {
.map_ok(|envelope| envelope.payload)
}
pub fn request_stream<T: RequestMessage>(
&self,
request: T,
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
let client_id = self.id.load(Ordering::SeqCst);
log::debug!(
"rpc request start. client_id:{}. name:{}",
client_id,
T::NAME
);
let response = self
.connection_id()
.map(|conn_id| self.peer.request_stream(conn_id, request));
async move {
let response = response?.await;
log::debug!(
"rpc request finish. client_id:{}. name:{}",
client_id,
T::NAME
);
response
}
}
pub fn request_envelope<T: RequestMessage>(
&self,
request: T,
@@ -1466,22 +1412,21 @@ async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credenti
.await
.log_err()??;
Some(Credentials::User {
Some(Credentials {
user_id: user_id.parse().ok()?,
access_token: String::from_utf8(access_token).ok()?,
})
}
async fn write_credentials_to_keychain(
user_id: u64,
access_token: String,
credentials: Credentials,
cx: &AsyncAppContext,
) -> Result<()> {
cx.update(move |cx| {
cx.write_credentials(
&ClientSettings::get_global(cx).server_url,
&user_id.to_string(),
access_token.as_bytes(),
&credentials.user_id.to_string(),
credentials.access_token.as_bytes(),
)
})?
.await
@@ -1586,7 +1531,7 @@ mod tests {
// Time out when client tries to connect.
client.override_authenticate(move |cx| {
cx.background_executor().spawn(async move {
Ok(Credentials::User {
Ok(Credentials {
user_id,
access_token: "token".into(),
})

View File

@@ -15,8 +15,7 @@ use std::{env, mem, path::PathBuf, sync::Arc, time::Duration};
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CopilotEvent, CpuEvent,
EditEvent, EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent,
SettingEvent,
EditEvent, EditorEvent, Event, EventRequestBody, EventWrapper, MemoryEvent, SettingEvent,
};
use tempfile::NamedTempFile;
use util::http::{self, HttpClient, HttpClientWithUrl, Method};
@@ -262,7 +261,7 @@ impl Telemetry {
self: &Arc<Self>,
conversation_id: Option<String>,
kind: AssistantKind,
model: String,
model: &str,
) {
let event = Event::Assistant(AssistantEvent {
conversation_id,
@@ -327,13 +326,6 @@ impl Telemetry {
self.report_event(event)
}
pub fn report_extension_event(self: &Arc<Self>, extension_id: Arc<str>, version: Arc<str>) {
self.report_event(Event::Extension(ExtensionEvent {
extension_id,
version,
}))
}
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str) {
let mut state = self.state.lock();
let period_data = state.event_coalescer.log_event(environment);
@@ -478,11 +470,7 @@ impl Telemetry {
let request = http::Request::builder()
.method(Method::POST)
.uri(
this.http_client
.build_zed_api_url("/telemetry/events", &[])?
.as_ref(),
)
.uri(this.http_client.build_zed_api_url("/telemetry/events"))
.header("Content-Type", "text/plain")
.header("x-zed-checksum", checksum)
.body(json_bytes.into());
@@ -590,10 +578,7 @@ mod tests {
}
#[gpui::test]
async fn test_telemetry_flush_on_flush_interval(
executor: BackgroundExecutor,
cx: &mut TestAppContext,
) {
async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
init_test(cx);
let clock = Arc::new(FakeSystemClock::new(
Utc.with_ymd_and_hms(1990, 4, 12, 12, 0, 0).unwrap(),

View File

@@ -48,7 +48,7 @@ impl FakeServer {
let mut state = state.lock();
state.auth_count += 1;
let access_token = state.access_token.to_string();
Ok(Credentials::User {
Ok(Credentials {
user_id: client_user_id,
access_token,
})
@@ -71,12 +71,9 @@ impl FakeServer {
)))?
}
if credentials
!= (Credentials::User {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
assert_eq!(credentials.user_id, client_user_id);
if credentials.access_token != state.lock().access_token.to_string() {
Err(EstablishConnectionError::Unauthorized)?
}

View File

@@ -0,0 +1,8 @@
[
"nathansobo",
"as-cii",
"maxbrunsfeld",
"iamnbutler",
"mikayla-maki",
"JosephTLyons"
]

View File

@@ -1,5 +1,4 @@
DATABASE_URL = "postgres://postgres@localhost/zed"
# DATABASE_URL = "sqlite:////home/zed/.config/zed/db.sqlite3?mode=rwc"
DATABASE_MAX_CONNECTIONS = 5
HTTP_PORT = 8080
API_TOKEN = "secret"
@@ -14,7 +13,6 @@ BLOB_STORE_BUCKET = "the-extensions-bucket"
BLOB_STORE_URL = "http://127.0.0.1:9000"
BLOB_STORE_REGION = "the-region"
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
SEED_PATH = "crates/collab/seed.default.json"
# CLICKHOUSE_URL = ""
# CLICKHOUSE_USER = "default"

View File

@@ -13,12 +13,10 @@ workspace = true
[[bin]]
name = "collab"
[features]
sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]
[[bin]]
name = "seed"
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-tungstenite = "0.16"
aws-config = { version = "1.1.5" }
@@ -33,12 +31,10 @@ collections.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid = "0.4"
open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.13"
prost.workspace = true
@@ -47,7 +43,6 @@ reqwest = { version = "0.11", features = ["json"] }
rpc.workspace = true
scrypt = "0.7"
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
semantic_version.workspace = true
semver.workspace = true
serde.workspace = true
serde_derive.workspace = true
@@ -85,6 +80,7 @@ git = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
lazy_static.workspace = true
live_kit_client = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
menu.workspace = true

View File

@@ -6,21 +6,21 @@ It contains our back-end logic for collaboration, to which we connect from the Z
# Local Development
Detailed instructions on getting started are [here](https://zed.dev/docs/local-collaboration).
Detailed instructions on getting started are [here](https://zed.dev/docs/local-collaboration).
# Deployment
We run two instances of collab:
- Staging (https://staging-collab.zed.dev)
- Production (https://collab.zed.dev)
* Staging (https://staging-collab.zed.dev)
* Production (https://collab.zed.dev)
Both of these run on the Kubernetes cluster hosted in Digital Ocean.
Deployment is triggered by pushing to the `collab-staging` (or `collab-production`) tag in Github. The best way to do this is:
- `./script/deploy-collab staging`
- `./script/deploy-collab production`
* `./script/deploy-collab staging`
* `./script/deploy-collab production`
You can tell what is currently deployed with `./script/what-is-deployed`.
@@ -29,7 +29,7 @@ You can tell what is currently deployed with `./script/what-is-deployed`.
To create a new migration:
```
./script/create-migration <name>
./script/sqlx migrate add <name>
```
Migrations are run automatically on service start, so run `foreman start` again. The service will crash if the migrations fail.

12
crates/collab/basic.conf Normal file
View File

@@ -0,0 +1,12 @@
[Interface]
PrivateKey = B5Fp/yVfP0QYlb+YJv9ea+EMI1mWODPD3akh91cVjvc=
Address = fdaa:0:2ce3:a7b:bea:0:a:2/120
DNS = fdaa:0:2ce3::3
[Peer]
PublicKey = RKAYPljEJiuaELNDdQIEJmQienT9+LRISfIHwH45HAw=
AllowedIPs = fdaa:0:2ce3::/48
Endpoint = ord1.gateway.6pn.dev:51820
PersistentKeepalive = 15

View File

@@ -47,6 +47,19 @@ spec:
metadata:
labels:
app: ${ZED_SERVICE_NAME}
annotations:
ad.datadoghq.com/collab.check_names: |
["openmetrics"]
ad.datadoghq.com/collab.init_configs: |
[{}]
ad.datadoghq.com/collab.instances: |
[
{
"openmetrics_endpoint": "http://%%host%%:%%port%%/metrics",
"namespace": "collab_${ZED_KUBE_NAMESPACE}",
"metrics": [".*"]
}
]
spec:
containers:
- name: ${ZED_SERVICE_NAME}
@@ -112,16 +125,6 @@ spec:
secretKeyRef:
name: livekit
key: secret
- name: OPENAI_API_KEY
valueFrom:
secretKeyRef:
name: openai
key: api_key
- name: ANTHROPIC_API_KEY
valueFrom:
secretKeyRef:
name: anthropic
key: api_key
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:

View File

@@ -219,7 +219,6 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
"sender_id" INTEGER NOT NULL REFERENCES users (id),
"body" TEXT NOT NULL,
"sent_at" TIMESTAMP,
"edited_at" TIMESTAMP,
"nonce" BLOB NOT NULL,
"reply_to_message_id" INTEGER DEFAULT NULL
);
@@ -373,8 +372,6 @@ CREATE TABLE extension_versions (
authors TEXT NOT NULL,
repository TEXT NOT NULL,
description TEXT NOT NULL,
schema_version INTEGER NOT NULL DEFAULT 0,
wasm_api_version TEXT,
download_count INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (extension_id, version)
);
@@ -382,16 +379,6 @@ CREATE TABLE extension_versions (
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
CREATE TABLE rate_buckets (
user_id INT NOT NULL,
rate_limit_name VARCHAR(255) NOT NULL,
token_count INT NOT NULL,
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (user_id, rate_limit_name),
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
CREATE TABLE hosted_projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id),
@@ -401,11 +388,3 @@ CREATE TABLE hosted_projects (
);
CREATE INDEX idx_hosted_projects_on_channel_id ON hosted_projects (channel_id);
CREATE UNIQUE INDEX uix_hosted_projects_on_channel_id_and_name ON hosted_projects (channel_id, name) WHERE (deleted_at IS NULL);
CREATE TABLE dev_servers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id),
name TEXT NOT NULL,
hashed_token TEXT NOT NULL
);
CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id);

View File

@@ -1,11 +0,0 @@
CREATE TABLE IF NOT EXISTS rate_buckets (
user_id INT NOT NULL,
rate_limit_name VARCHAR(255) NOT NULL,
token_count INT NOT NULL,
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
PRIMARY KEY (user_id, rate_limit_name),
CONSTRAINT fk_user
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);

View File

@@ -1 +0,0 @@
ALTER TABLE channel_messages ADD edited_at TIMESTAMP DEFAULT NULL;

View File

@@ -1,2 +0,0 @@
-- Add migration script here
ALTER TABLE extension_versions ADD COLUMN schema_version INTEGER NOT NULL DEFAULT 0;

View File

@@ -1,7 +0,0 @@
CREATE TABLE dev_servers (
id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
channel_id INT NOT NULL REFERENCES channels(id),
name TEXT NOT NULL,
hashed_token TEXT NOT NULL
);
CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id);

View File

@@ -1 +0,0 @@
ALTER TABLE extension_versions ADD COLUMN wasm_api_version TEXT;

View File

@@ -1,12 +0,0 @@
{
"admins": [
"nathansobo",
"as-cii",
"maxbrunsfeld",
"iamnbutler",
"mikayla-maki",
"JosephTLyons"
],
"channels": ["zed"],
"number_of_users": 100
}

View File

@@ -1,75 +0,0 @@
use anyhow::{anyhow, Result};
use rpc::proto;
pub fn language_model_request_to_open_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
Ok(open_ai::Request {
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
messages: request
.messages
.into_iter()
.map(|message| {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(open_ai::RequestMessage {
role: match role {
proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => {
open_ai::Role::Assistant
}
proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
},
content: message.content,
})
})
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
stream: true,
stop: request.stop,
temperature: request.temperature,
})
}
pub fn language_model_request_to_google_ai(
request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
Ok(google_ai::GenerateContentRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
generation_config: None,
safety_settings: None,
})
}
pub fn language_model_request_message_to_google_ai(
message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
let role = proto::LanguageModelRole::from_i32(message.role)
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
Ok(google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: message.content,
})],
role: match role {
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
},
})
}
pub fn count_tokens_request_to_google_ai(
request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
Ok(google_ai::CountTokensRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
})
}

View File

@@ -1,5 +1,5 @@
use super::ips_file::IpsFile;
use crate::{api::slack, AppState, Error, Result};
use std::sync::{Arc, OnceLock};
use anyhow::{anyhow, Context};
use aws_sdk_s3::primitives::ByteStream;
use axum::{
@@ -9,15 +9,17 @@ use axum::{
routing::post,
Extension, Router, TypedHeader,
};
use rpc::ExtensionMetadata;
use semantic_version::SemanticVersion;
use serde::{Serialize, Serializer};
use sha2::{Digest, Sha256};
use std::sync::{Arc, OnceLock};
use telemetry_events::{
ActionEvent, AppEvent, AssistantEvent, CallEvent, CopilotEvent, CpuEvent, EditEvent,
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent, SettingEvent,
EditorEvent, Event, EventRequestBody, EventWrapper, MemoryEvent, SettingEvent,
};
use util::SemanticVersion;
use crate::{api::slack, AppState, Error, Result};
use super::ips_file::IpsFile;
pub fn router() -> Router {
Router::new()
@@ -329,21 +331,6 @@ pub async fn post_events(
&request_body,
first_event_at,
)),
Event::Extension(event) => {
let metadata = app
.db
.get_extension_version(&event.extension_id, &event.version)
.await?;
to_upload
.extension_events
.push(ExtensionEventRow::from_event(
event.clone(),
&wrapper,
&request_body,
metadata,
first_event_at,
))
}
}
}
@@ -365,7 +352,6 @@ struct ToUpload {
memory_events: Vec<MemoryEventRow>,
app_events: Vec<AppEventRow>,
setting_events: Vec<SettingEventRow>,
extension_events: Vec<ExtensionEventRow>,
edit_events: Vec<EditEventRow>,
action_events: Vec<ActionEventRow>,
}
@@ -424,15 +410,6 @@ impl ToUpload {
.await
.with_context(|| format!("failed to upload to table '{SETTING_EVENTS_TABLE}'"))?;
const EXTENSION_EVENTS_TABLE: &str = "extension_events";
Self::upload_to_table(
EXTENSION_EVENTS_TABLE,
&self.extension_events,
clickhouse_client,
)
.await
.with_context(|| format!("failed to upload to table '{EXTENSION_EVENTS_TABLE}'"))?;
const EDIT_EVENTS_TABLE: &str = "edit_events";
Self::upload_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
.await
@@ -459,12 +436,6 @@ impl ToUpload {
}
insert.end().await?;
let event_count = rows.len();
log::info!(
"wrote {event_count} {event_specifier} to '{table}'",
event_specifier = if event_count == 1 { "event" } else { "events" }
);
}
Ok(())
@@ -528,9 +499,9 @@ impl EditorEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
os_name: body.os_name.clone(),
os_version: body.os_version.clone().unwrap_or_default(),
@@ -590,9 +561,9 @@ impl CopilotEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
os_name: body.os_name.clone(),
os_version: body.os_version.clone().unwrap_or_default(),
@@ -645,9 +616,9 @@ impl CallEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone().unwrap_or_default(),
session_id: body.session_id.clone(),
@@ -694,9 +665,9 @@ impl AssistantEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -738,9 +709,9 @@ impl CpuEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -785,9 +756,9 @@ impl MemoryEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -831,9 +802,9 @@ impl AppEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -876,9 +847,9 @@ impl SettingEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -890,68 +861,6 @@ impl SettingEventRow {
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct ExtensionEventRow {
// AppInfoBase
app_version: String,
major: Option<i32>,
minor: Option<i32>,
patch: Option<i32>,
release_channel: String,
// ClientEventBase
installation_id: Option<String>,
session_id: Option<String>,
is_staff: Option<bool>,
time: i64,
// ExtensionEventRow
extension_id: Arc<str>,
extension_version: Arc<str>,
dev: bool,
schema_version: Option<i32>,
wasm_api_version: Option<String>,
}
impl ExtensionEventRow {
fn from_event(
event: ExtensionEvent,
wrapper: &EventWrapper,
body: &EventRequestBody,
extension_metadata: Option<ExtensionMetadata>,
first_event_at: chrono::DateTime<chrono::Utc>,
) -> Self {
let semver = body.semver();
let time =
first_event_at + chrono::Duration::milliseconds(wrapper.milliseconds_since_first_event);
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
is_staff: body.is_staff,
time: time.timestamp_millis(),
extension_id: event.extension_id,
extension_version: event.version,
dev: extension_metadata.is_none(),
schema_version: extension_metadata
.as_ref()
.and_then(|metadata| metadata.manifest.schema_version),
wasm_api_version: extension_metadata.as_ref().and_then(|metadata| {
metadata
.manifest
.wasm_api_version
.as_ref()
.map(|version| version.to_string())
}),
}
}
}
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct EditEventRow {
// AppInfoBase
@@ -991,9 +900,9 @@ impl EditEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),
@@ -1040,9 +949,9 @@ impl ActionEventRow {
Self {
app_version: body.app_version.clone(),
major: semver.map(|v| v.major() as i32),
minor: semver.map(|v| v.minor() as i32),
patch: semver.map(|v| v.patch() as i32),
major: semver.map(|s| s.major as i32),
minor: semver.map(|s| s.minor as i32),
patch: semver.map(|s| s.patch as i32),
release_channel: body.release_channel.clone().unwrap_or_default(),
installation_id: body.installation_id.clone(),
session_id: body.session_id.clone(),

View File

@@ -1,5 +1,8 @@
use crate::db::ExtensionVersionConstraints;
use crate::{db::NewExtensionVersion, AppState, Error, Result};
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
executor::Executor,
AppState, Error, Result,
};
use anyhow::{anyhow, Context as _};
use aws_sdk_s3::presigning::PresigningConfig;
use axum::{
@@ -10,22 +13,14 @@ use axum::{
Extension, Json, Router,
};
use collections::HashMap;
use rpc::{ExtensionApiManifest, GetExtensionsResponse};
use semantic_version::SemanticVersion;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::{sync::Arc, time::Duration};
use time::PrimitiveDateTime;
use util::{maybe, ResultExt};
use util::ResultExt;
pub fn router() -> Router {
Router::new()
.route("/extensions", get(get_extensions))
.route("/extensions/updates", get(get_extension_updates))
.route("/extensions/:extension_id", get(get_extension_versions))
.route(
"/extensions/:extension_id/download",
get(download_latest_extension),
)
.route(
"/extensions/:extension_id/:version/download",
get(download_extension),
@@ -35,114 +30,6 @@ pub fn router() -> Router {
#[derive(Debug, Deserialize)]
struct GetExtensionsParams {
filter: Option<String>,
#[serde(default)]
ids: Option<String>,
#[serde(default)]
max_schema_version: i32,
}
async fn get_extensions(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
let extension_ids = params
.ids
.as_ref()
.map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
let extensions = if let Some(extension_ids) = extension_ids {
app.db.get_extensions_by_ids(&extension_ids, None).await?
} else {
app.db
.get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
.await?
};
Ok(Json(GetExtensionsResponse { data: extensions }))
}
#[derive(Debug, Deserialize)]
struct GetExtensionUpdatesParams {
ids: String,
min_schema_version: i32,
max_schema_version: i32,
min_wasm_api_version: SemanticVersion,
max_wasm_api_version: SemanticVersion,
}
async fn get_extension_updates(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionUpdatesParams>,
) -> Result<Json<GetExtensionsResponse>> {
let constraints = ExtensionVersionConstraints {
schema_versions: params.min_schema_version..=params.max_schema_version,
wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
};
let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
let extensions = app
.db
.get_extensions_by_ids(&extension_ids, Some(&constraints))
.await?;
Ok(Json(GetExtensionsResponse { data: extensions }))
}
#[derive(Debug, Deserialize)]
struct GetExtensionVersionsParams {
extension_id: String,
}
async fn get_extension_versions(
Extension(app): Extension<Arc<AppState>>,
Path(params): Path<GetExtensionVersionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
let extension_versions = app.db.get_extension_versions(&params.extension_id).await?;
Ok(Json(GetExtensionsResponse {
data: extension_versions,
}))
}
#[derive(Debug, Deserialize)]
struct DownloadLatestExtensionParams {
extension_id: String,
min_schema_version: Option<i32>,
max_schema_version: Option<i32>,
min_wasm_api_version: Option<SemanticVersion>,
max_wasm_api_version: Option<SemanticVersion>,
}
async fn download_latest_extension(
Extension(app): Extension<Arc<AppState>>,
Path(params): Path<DownloadLatestExtensionParams>,
) -> Result<Redirect> {
let constraints = maybe!({
let min_schema_version = params.min_schema_version?;
let max_schema_version = params.max_schema_version?;
let min_wasm_api_version = params.min_wasm_api_version?;
let max_wasm_api_version = params.max_wasm_api_version?;
Some(ExtensionVersionConstraints {
schema_versions: min_schema_version..=max_schema_version,
wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
})
});
let extension = app
.db
.get_extension(&params.extension_id, constraints.as_ref())
.await?
.ok_or_else(|| anyhow!("unknown extension"))?;
download_extension(
Extension(app),
Path(DownloadExtensionParams {
extension_id: params.extension_id,
version: extension.manifest.version.to_string(),
}),
)
.await
}
#[derive(Debug, Deserialize)]
@@ -151,6 +38,28 @@ struct DownloadExtensionParams {
version: String,
}
#[derive(Debug, Serialize)]
struct GetExtensionsResponse {
pub data: Vec<ExtensionMetadata>,
}
#[derive(Deserialize)]
struct ExtensionManifest {
name: String,
version: String,
description: Option<String>,
authors: Vec<String>,
repository: String,
}
async fn get_extensions(
Extension(app): Extension<Arc<AppState>>,
Query(params): Query<GetExtensionsParams>,
) -> Result<Json<GetExtensionsResponse>> {
let extensions = app.db.get_extensions(params.filter.as_deref(), 500).await?;
Ok(Json(GetExtensionsResponse { data: extensions }))
}
async fn download_extension(
Extension(app): Extension<Arc<AppState>>,
Path(params): Path<DownloadExtensionParams>,
@@ -199,7 +108,7 @@ async fn download_extension(
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
log::info!("no blob store client");
return;
@@ -209,7 +118,6 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
return;
};
let executor = app_state.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
@@ -331,7 +239,7 @@ async fn fetch_extension_manifest(
})?
.to_vec();
let manifest =
serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
serde_json::from_slice::<ExtensionManifest>(&manifest_bytes).with_context(|| {
format!(
"invalid manifest for extension {extension_id} version {version}: {}",
String::from_utf8_lossy(&manifest_bytes)
@@ -351,8 +259,6 @@ async fn fetch_extension_manifest(
description: manifest.description.unwrap_or_default(),
authors: manifest.authors,
repository: manifest.repository,
schema_version: manifest.schema_version.unwrap_or(0),
wasm_api_version: manifest.wasm_api_version,
published_at,
})
}

View File

@@ -1,8 +1,9 @@
use collections::HashMap;
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use serde_derive::Deserialize;
use serde_derive::Serialize;
use serde_json::Value;
use util::SemanticVersion;
#[derive(Debug)]
pub struct IpsFile {

View File

@@ -1,6 +1,5 @@
use crate::{
db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId},
rpc::Principal,
db::{self, AccessTokenId, Database, UserId},
AppState, Error, Result,
};
use anyhow::{anyhow, Context};
@@ -20,11 +19,11 @@ use std::sync::OnceLock;
use std::{sync::Arc, time::Instant};
use subtle::ConstantTimeEq;
/// Validates the authorization header and adds an Extension<Principal> to the request.
/// Authorization: <user-id> <token>
/// <token> can be an access_token attached to that user, or an access token of an admin
/// or (in development) the string ADMIN:<config.api_token>.
/// Authorization: "dev-server-token" <token>
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Impersonator(pub Option<db::User>);
/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
/// and one for the access tokens that we issue.
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let mut auth_header = req
.headers()
@@ -38,26 +37,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
})?
.split_whitespace();
let state = req.extensions().get::<Arc<AppState>>().unwrap();
let first = auth_header.next().unwrap_or("");
if first == "dev-server-token" {
let dev_server_token = auth_header.next().ok_or_else(|| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing dev-server-token token in authorization header".to_string(),
)
})?;
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
.await
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
req.extensions_mut()
.insert(Principal::DevServer(dev_server));
return Ok::<_, Error>(next.run(req).await);
}
let user_id = UserId(first.parse().map_err(|_| {
let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
Error::Http(
StatusCode::BAD_REQUEST,
"missing user id in authorization header".to_string(),
@@ -71,6 +51,8 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
)
})?;
let state = req.extensions().get::<Arc<AppState>>().unwrap();
// In development, allow impersonation using the admin API token.
// Don't allow this in production because we can't tell who is doing
// the impersonating.
@@ -94,17 +76,18 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
.await?
.ok_or_else(|| anyhow!("user {} not found", user_id))?;
if let Some(impersonator_id) = validate_result.impersonator_id {
let admin = state
let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
let impersonator = state
.db
.get_user_by_id(impersonator_id)
.await?
.ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
req.extensions_mut()
.insert(Principal::Impersonated { user, admin });
Some(impersonator)
} else {
req.extensions_mut().insert(Principal::User(user));
None
};
req.extensions_mut().insert(user);
req.extensions_mut().insert(Impersonator(impersonator));
return Ok::<_, Error>(next.run(req).await);
}
}
@@ -230,33 +213,6 @@ pub async fn verify_access_token(
})
}
// a dev_server_token has the format <id>.<base64>. This is to make them
// relatively easy to copy/paste around.
pub async fn verify_dev_server_token(
dev_server_token: &str,
db: &Arc<Database>,
) -> anyhow::Result<dev_server::Model> {
let mut parts = dev_server_token.splitn(2, '.');
let id = DevServerId(parts.next().unwrap_or_default().parse()?);
let token = parts
.next()
.ok_or_else(|| anyhow!("invalid dev server token format"))?;
let token_hash = hash_access_token(&token);
let server = db.get_dev_server(id).await?;
if server
.hashed_token
.as_bytes()
.ct_eq(token_hash.as_ref())
.into()
{
Ok(server)
} else {
Err(anyhow!("wrong token for dev server"))
}
}
#[cfg(test)]
mod test {
use rand::thread_rng;

View File

@@ -0,0 +1,97 @@
use collab::{
db::{self, NewUserParams},
env::load_dotenv,
executor::Executor,
};
use db::{ConnectOptions, Database};
use serde::{de::DeserializeOwned, Deserialize};
use std::{fmt::Write, fs};
#[derive(Debug, Deserialize)]
struct GitHubUser {
id: i32,
login: String,
email: Option<String>,
}
#[tokio::main]
async fn main() {
load_dotenv().expect("failed to load .env.toml file");
let mut admin_logins = load_admins("crates/collab/.admins.default.json")
.expect("failed to load default admins file");
if let Ok(other_admins) = load_admins("./.admins.json") {
admin_logins.extend(other_admins);
}
let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var");
let db = Database::new(ConnectOptions::new(database_url), Executor::Production)
.await
.expect("failed to connect to postgres database");
let client = reqwest::Client::new();
// Create admin users for all of the users in `.admins.toml` or `.admins.default.toml`.
for admin_login in admin_logins {
let user = fetch_github::<GitHubUser>(
&client,
&format!("https://api.github.com/users/{admin_login}"),
)
.await;
db.create_user(
&user.email.unwrap_or(format!("{admin_login}@example.com")),
true,
NewUserParams {
github_login: user.login,
github_user_id: user.id,
},
)
.await
.expect("failed to create admin user");
}
// Fetch 100 other random users from GitHub and insert them into the database.
let mut user_count = db
.get_all_users(0, 200)
.await
.expect("failed to load users from db")
.len();
let mut last_user_id = None;
while user_count < 100 {
let mut uri = "https://api.github.com/users?per_page=100".to_string();
if let Some(last_user_id) = last_user_id {
write!(&mut uri, "&since={}", last_user_id).unwrap();
}
let users = fetch_github::<Vec<GitHubUser>>(&client, &uri).await;
for github_user in users {
last_user_id = Some(github_user.id);
user_count += 1;
db.get_or_create_user_by_github_account(
&github_user.login,
Some(github_user.id),
github_user.email.as_deref(),
None,
)
.await
.expect("failed to insert user");
}
}
}
fn load_admins(path: &str) -> anyhow::Result<Vec<String>> {
let file_content = fs::read_to_string(path)?;
Ok(serde_json::from_str(&file_content)?)
}
async fn fetch_github<T: DeserializeOwned>(client: &reqwest::Client, url: &str) -> T {
let response = client
.get(url)
.header("user-agent", "zed")
.send()
.await
.unwrap_or_else(|_| panic!("failed to fetch '{}'", url));
response
.json()
.await
.unwrap_or_else(|_| panic!("failed to deserialize github user from '{}'", url))
}

View File

@@ -12,7 +12,7 @@ use futures::StreamExt;
use rand::{prelude::StdRng, Rng, SeedableRng};
use rpc::{
proto::{self},
ConnectionId, ExtensionMetadata,
ConnectionId,
};
use sea_orm::{
entity::prelude::*,
@@ -21,13 +21,11 @@ use sea_orm::{
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
};
use semantic_version::SemanticVersion;
use serde::{Deserialize, Serialize};
use serde::{ser::Error as _, Deserialize, Serialize, Serializer};
use sqlx::{
migrate::{Migrate, Migration, MigrationSource},
Connection,
};
use std::ops::RangeInclusive;
use std::{
fmt::Write as _,
future::Future,
@@ -38,7 +36,7 @@ use std::{
sync::Arc,
time::Duration,
};
use time::PrimitiveDateTime;
use time::{format_description::well_known::iso8601, PrimitiveDateTime};
use tokio::sync::{Mutex, OwnedMutexGuard};
#[cfg(test)]
@@ -130,6 +128,12 @@ impl Database {
Ok(new_migrations)
}
/// Initializes static data that resides in the database by upserting it.
pub async fn initialize_static_data(&mut self) -> Result<()> {
self.initialize_notification_kinds().await?;
Ok(())
}
/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
@@ -454,16 +458,6 @@ pub struct CreatedChannelMessage {
pub notifications: NotificationBatch,
}
pub struct UpdatedChannelMessage {
pub message_id: MessageId,
pub participant_connection_ids: Vec<ConnectionId>,
pub notifications: NotificationBatch,
pub reply_to_message_id: Option<MessageId>,
pub timestamp: PrimitiveDateTime,
pub deleted_mention_notification_ids: Vec<NotificationId>,
pub updated_mention_notifications: Vec<rpc::proto::Notification>,
}
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
pub email_address: String,
@@ -552,7 +546,7 @@ pub struct Channel {
}
impl Channel {
pub fn from_model(value: channel::Model) -> Self {
fn from_model(value: channel::Model) -> Self {
Channel {
id: value.id,
visibility: value.visibility,
@@ -610,14 +604,16 @@ pub struct RejoinedChannelBuffer {
#[derive(Clone)]
pub struct JoinRoom {
pub room: proto::Room,
pub channel: Option<channel::Model>,
pub channel_id: Option<ChannelId>,
pub channel_members: Vec<UserId>,
}
pub struct RejoinedRoom {
pub room: proto::Room,
pub rejoined_projects: Vec<RejoinedProject>,
pub reshared_projects: Vec<ResharedProject>,
pub channel: Option<channel::Model>,
pub channel_id: Option<ChannelId>,
pub channel_members: Vec<UserId>,
}
pub struct ResharedProject {
@@ -653,7 +649,8 @@ pub struct RejoinedWorktree {
pub struct LeftRoom {
pub room: proto::Room,
pub channel: Option<channel::Model>,
pub channel_id: Option<ChannelId>,
pub channel_members: Vec<UserId>,
pub left_projects: HashMap<ProjectId, LeftProject>,
pub canceled_calls_to_user_ids: Vec<UserId>,
pub deleted: bool,
@@ -661,7 +658,8 @@ pub struct LeftRoom {
pub struct RefreshedRoom {
pub room: proto::Room,
pub channel: Option<channel::Model>,
pub channel_id: Option<ChannelId>,
pub channel_members: Vec<UserId>,
pub stale_participant_user_ids: Vec<UserId>,
pub canceled_calls_to_user_ids: Vec<UserId>,
}
@@ -729,12 +727,36 @@ pub struct NewExtensionVersion {
pub description: String,
pub authors: Vec<String>,
pub repository: String,
pub schema_version: i32,
pub wasm_api_version: Option<String>,
pub published_at: PrimitiveDateTime,
}
pub struct ExtensionVersionConstraints {
pub schema_versions: RangeInclusive<i32>,
pub wasm_api_versions: RangeInclusive<SemanticVersion>,
#[derive(Debug, Serialize, PartialEq)]
pub struct ExtensionMetadata {
pub id: String,
pub name: String,
pub version: String,
pub authors: Vec<String>,
pub description: String,
pub repository: String,
#[serde(serialize_with = "serialize_iso8601")]
pub published_at: PrimitiveDateTime,
pub download_count: u64,
}
pub fn serialize_iso8601<S: Serializer>(
datetime: &PrimitiveDateTime,
serializer: S,
) -> Result<S::Ok, S::Error> {
const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT
.set_year_is_six_digits(false)
.set_time_precision(iso8601::TimePrecision::Second {
decimal_digits: None,
})
.encode();
datetime
.assume_utc()
.format(&time::format_description::well_known::Iso8601::<SERDE_CONFIG>)
.map_err(S::Error::custom)?
.serialize(serializer)
}

View File

@@ -67,34 +67,31 @@ macro_rules! id_type {
};
}
id_type!(AccessTokenId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
id_type!(AccessTokenId);
id_type!(ChannelChatParticipantId);
id_type!(ChannelId);
id_type!(ChannelMemberId);
id_type!(ContactId);
id_type!(DevServerId);
id_type!(ExtensionId);
id_type!(FlagId);
id_type!(FollowerId);
id_type!(HostedProjectId);
id_type!(MessageId);
id_type!(NotificationId);
id_type!(NotificationKindId);
id_type!(ProjectCollaboratorId);
id_type!(ProjectId);
id_type!(ReplicaId);
id_type!(ContactId);
id_type!(FollowerId);
id_type!(RoomId);
id_type!(RoomParticipantId);
id_type!(ProjectId);
id_type!(ProjectCollaboratorId);
id_type!(ReplicaId);
id_type!(ServerId);
id_type!(SignupId);
id_type!(UserId);
id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId);
id_type!(ExtensionId);
id_type!(NotificationId);
id_type!(NotificationKindId);
id_type!(HostedProjectId);
/// ChannelRole gives you permissions for both channels and calls.
#[derive(
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
)]
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
#[sea_orm(rs_type = "String", db_type = "String(None)")]
pub enum ChannelRole {
/// Admin can read/write and change permissions.

View File

@@ -5,13 +5,11 @@ pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod dev_servers;
pub mod extensions;
pub mod hosted_projects;
pub mod messages;
pub mod notifications;
pub mod projects;
pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;

View File

@@ -45,7 +45,11 @@ impl Database {
name: &str,
parent_channel_id: Option<ChannelId>,
admin_id: UserId,
) -> Result<(channel::Model, Option<channel_member::Model>)> {
) -> Result<(
Channel,
Option<channel_member::Model>,
Vec<channel_member::Model>,
)> {
let name = Self::sanitize_channel_name(name)?;
self.transaction(move |tx| async move {
let mut parent = None;
@@ -86,7 +90,12 @@ impl Database {
);
}
Ok((channel, membership))
let channel_members = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.eq(channel.root_id()))
.all(&*tx)
.await?;
Ok((Channel::from_model(channel), membership, channel_members))
})
.await
}
@@ -172,7 +181,7 @@ impl Database {
channel_id: ChannelId,
visibility: ChannelVisibility,
admin_id: UserId,
) -> Result<channel::Model> {
) -> Result<(Channel, Vec<channel_member::Model>)> {
self.transaction(move |tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -205,7 +214,12 @@ impl Database {
model.visibility = ActiveValue::Set(visibility);
let channel = model.update(&*tx).await?;
Ok(channel)
let channel_members = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.eq(channel.root_id()))
.all(&*tx)
.await?;
Ok((Channel::from_model(channel), channel_members))
})
.await
}
@@ -231,12 +245,21 @@ impl Database {
&self,
channel_id: ChannelId,
user_id: UserId,
) -> Result<(ChannelId, Vec<ChannelId>)> {
) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
self.transaction(move |tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_admin(&channel, user_id, &tx)
.await?;
let members_to_notify: Vec<UserId> = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.eq(channel.root_id()))
.select_only()
.column(channel_member::Column::UserId)
.distinct()
.into_values::<_, QueryUserIds>()
.all(&*tx)
.await?;
let channels_to_remove = self
.get_channel_descendants_excluding_self([&channel], &tx)
.await?
@@ -250,7 +273,7 @@ impl Database {
.exec(&*tx)
.await?;
Ok((channel.root_id(), channels_to_remove))
Ok((channels_to_remove, members_to_notify))
})
.await
}
@@ -320,7 +343,7 @@ impl Database {
channel_id: ChannelId,
admin_id: UserId,
new_name: &str,
) -> Result<channel::Model> {
) -> Result<(Channel, Vec<channel_member::Model>)> {
self.transaction(move |tx| async move {
let new_name = Self::sanitize_channel_name(new_name)?.to_string();
@@ -332,7 +355,12 @@ impl Database {
model.name = ActiveValue::Set(new_name.clone());
let channel = model.update(&*tx).await?;
Ok(channel)
let channel_members = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.eq(channel.root_id()))
.all(&*tx)
.await?;
Ok((Channel::from_model(channel), channel_members))
})
.await
}
@@ -956,7 +984,7 @@ impl Database {
channel_id: ChannelId,
new_parent_id: ChannelId,
admin_id: UserId,
) -> Result<(ChannelId, Vec<Channel>)> {
) -> Result<(Vec<Channel>, Vec<channel_member::Model>)> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -1011,7 +1039,12 @@ impl Database {
.map(|c| Channel::from_model(c))
.collect::<Vec<_>>();
Ok((root_id, channels))
let channel_members = channel_member::Entity::find()
.filter(channel_member::Column::ChannelId.eq(root_id))
.all(&*tx)
.await?;
Ok((channels, channel_members))
})
.await
}

View File

@@ -1,18 +0,0 @@
use sea_orm::EntityTrait;
use super::{dev_server, Database, DevServerId};
impl Database {
pub async fn get_dev_server(
&self,
dev_server_id: DevServerId,
) -> crate::Result<dev_server::Model> {
self.transaction(|tx| async move {
Ok(dev_server::Entity::find_by_id(dev_server_id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow::anyhow!("no dev server with id {}", dev_server_id))?)
})
.await
}
}

View File

@@ -1,202 +1,57 @@
use std::str::FromStr;
use chrono::Utc;
use sea_orm::sea_query::IntoCondition;
use util::ResultExt;
use super::*;
impl Database {
pub async fn get_extensions(
&self,
filter: Option<&str>,
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let mut condition = Condition::all()
.add(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.add(extension_version::Column::SchemaVersion.lte(max_schema_version));
let mut condition = Condition::all();
if let Some(filter) = filter {
let fuzzy_name_filter = Self::fuzzy_like_string(filter);
condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
}
self.get_extensions_where(condition, Some(limit as u64), &tx)
.await
})
.await
}
pub async fn get_extensions_by_ids(
&self,
ids: &[&str],
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.filter(condition)
.order_by_desc(extension::Column::TotalDownloadCount)
.order_by_asc(extension::Column::Name)
.limit(Some(limit as u64))
.filter(
extension::Column::LatestVersion
.into_expr()
.eq(extension_version::Column::Version.into_expr()),
)
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.all(&*tx)
.await?;
let mut max_versions = self
.get_latest_versions_for_extensions(&extensions, constraints, &tx)
.await?;
Ok(extensions
.into_iter()
.filter_map(|extension| {
let (version, _) = max_versions.remove(&extension.id)?;
Some(metadata_from_extension_and_version(extension, version))
.filter_map(|(extension, latest_version)| {
let version = latest_version?;
Some(ExtensionMetadata {
id: extension.external_id,
name: extension.name,
version: version.version,
authors: version
.authors
.split(',')
.map(|author| author.trim().to_string())
.collect::<Vec<_>>(),
description: version.description,
repository: version.repository,
published_at: version.published_at,
download_count: extension.total_download_count as u64,
})
})
.collect())
})
.await
}
async fn get_latest_versions_for_extensions(
&self,
extensions: &[extension::Model],
constraints: Option<&ExtensionVersionConstraints>,
tx: &DatabaseTransaction,
) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
let mut versions = extension_version::Entity::find()
.filter(
extension_version::Column::ExtensionId
.is_in(extensions.iter().map(|extension| extension.id)),
)
.stream(tx)
.await?;
let mut max_versions =
HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
while let Some(version) = versions.next().await {
let version = version?;
let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
else {
continue;
};
if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
if max_extension_version > &extension_version {
continue;
}
}
if let Some(constraints) = constraints {
if !constraints
.schema_versions
.contains(&version.schema_version)
{
continue;
}
if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
if !constraints.wasm_api_versions.contains(&version) {
continue;
}
} else {
continue;
}
}
}
max_versions.insert(version.extension_id, (version, extension_version));
}
Ok(max_versions)
}
/// Returns all of the versions for the extension with the given ID.
pub async fn get_extension_versions(
&self,
extension_id: &str,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
let condition = extension::Column::ExternalId
.eq(extension_id)
.into_condition();
self.get_extensions_where(condition, None, &tx).await
})
.await
}
async fn get_extensions_where(
&self,
condition: Condition,
limit: Option<u64>,
tx: &DatabaseTransaction,
) -> Result<Vec<ExtensionMetadata>> {
let extensions = extension::Entity::find()
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.filter(condition)
.order_by_desc(extension::Column::TotalDownloadCount)
.order_by_asc(extension::Column::Name)
.limit(limit)
.all(tx)
.await?;
Ok(extensions
.into_iter()
.filter_map(|(extension, version)| {
Some(metadata_from_extension_and_version(extension, version?))
})
.collect())
}
pub async fn get_extension(
&self,
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;
let extensions = [extension];
let mut versions = self
.get_latest_versions_for_extensions(&extensions, constraints, &tx)
.await?;
let [extension] = extensions;
Ok(versions.remove(&extension.id).map(|(max_version, _)| {
metadata_from_extension_and_version(extension, max_version)
}))
})
.await
}
pub async fn get_extension_version(
&self,
extension_id: &str,
version: &str,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(extension_version::Column::Version.eq(version))
.inner_join(extension_version::Entity)
.select_also(extension_version::Entity)
.one(&*tx)
.await?;
Ok(extension.and_then(|(extension, version)| {
Some(metadata_from_extension_and_version(extension, version?))
}))
})
.await
}
pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
@@ -280,8 +135,6 @@ impl Database {
authors: ActiveValue::Set(version.authors.join(", ")),
repository: ActiveValue::Set(version.repository.clone()),
description: ActiveValue::Set(version.description.clone()),
schema_version: ActiveValue::Set(version.schema_version),
wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
download_count: ActiveValue::NotSet,
}
}))
@@ -351,35 +204,3 @@ impl Database {
.await
}
}
fn metadata_from_extension_and_version(
extension: extension::Model,
version: extension_version::Model,
) -> ExtensionMetadata {
ExtensionMetadata {
id: extension.external_id.into(),
manifest: rpc::ExtensionApiManifest {
name: extension.name,
version: version.version.into(),
authors: version
.authors
.split(',')
.map(|author| author.trim().to_string())
.collect::<Vec<_>>(),
description: Some(version.description),
repository: version.repository,
schema_version: Some(version.schema_version),
wasm_api_version: version.wasm_api_version,
},
published_at: convert_time_to_chrono(version.published_at),
download_count: extension.total_download_count as u64,
}
}
pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
chrono::DateTime::from_naive_utc_and_offset(
chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
Utc,
)
}

View File

@@ -1,8 +1,7 @@
use super::*;
use rpc::Notification;
use sea_orm::{SelectColumns, TryInsertResult};
use sea_orm::TryInsertResult;
use time::OffsetDateTime;
use util::ResultExt;
impl Database {
/// Inserts a record representing a user joining the chat for a given channel.
@@ -163,9 +162,6 @@ impl Database {
lower_half: nonce.1,
}),
reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
edited_at: row
.edited_at
.map(|t| t.assume_utc().unix_timestamp() as u64),
}
})
.collect::<Vec<_>>();
@@ -203,31 +199,6 @@ impl Database {
Ok(messages)
}
fn format_mentions_to_entities(
&self,
message_id: MessageId,
body: &str,
mentions: &[proto::ChatMention],
) -> Result<Vec<tables::channel_message_mention::ActiveModel>> {
Ok(mentions
.iter()
.filter_map(|mention| {
let range = mention.range.as_ref()?;
if !body.is_char_boundary(range.start as usize)
|| !body.is_char_boundary(range.end as usize)
{
return None;
}
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>())
}
/// Creates a new channel message.
#[allow(clippy::too_many_arguments)]
pub async fn create_channel_message(
@@ -278,7 +249,6 @@ impl Database {
nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
id: ActiveValue::NotSet,
reply_to_message_id: ActiveValue::Set(reply_to_message_id),
edited_at: ActiveValue::NotSet,
})
.on_conflict(
OnConflict::columns([
@@ -300,7 +270,23 @@ impl Database {
let mentioned_user_ids =
mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
let mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
let mentions = mentions
.iter()
.filter_map(|mention| {
let range = mention.range.as_ref()?;
if !body.is_char_boundary(range.start as usize)
|| !body.is_char_boundary(range.end as usize)
{
return None;
}
Some(channel_message_mention::ActiveModel {
message_id: ActiveValue::Set(message_id),
start_offset: ActiveValue::Set(range.start as i32),
end_offset: ActiveValue::Set(range.end as i32),
user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
})
})
.collect::<Vec<_>>();
if !mentions.is_empty() {
channel_message_mention::Entity::insert_many(mentions)
.exec(&*tx)
@@ -481,20 +467,13 @@ impl Database {
Ok(results)
}
fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option<i32> {
self.notification_kinds_by_id
.iter()
.find(|(_, kind)| **kind == notification_kind)
.map(|kind| kind.0 .0)
}
/// Removes the channel message with the given ID.
pub async fn remove_channel_message(
&self,
channel_id: ChannelId,
message_id: MessageId,
user_id: UserId,
) -> Result<(Vec<ConnectionId>, Vec<NotificationId>)> {
) -> Result<Vec<ConnectionId>> {
self.transaction(|tx| async move {
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
@@ -539,190 +518,7 @@ impl Database {
}
}
let notification_kind_id =
self.get_notification_kind_id_by_name("ChannelMessageMention");
let existing_notifications = notification::Entity::find()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.select_column(notification::Column::Id)
.all(&*tx)
.await?;
let existing_notification_ids = existing_notifications
.into_iter()
.map(|notification| notification.id)
.collect();
// remove all the mention notifications for this message
notification::Entity::delete_many()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.exec(&*tx)
.await?;
Ok((participant_connection_ids, existing_notification_ids))
})
.await
}
/// Updates the channel message with the given ID, body and timestamp(edited_at).
pub async fn update_channel_message(
&self,
channel_id: ChannelId,
message_id: MessageId,
user_id: UserId,
body: &str,
mentions: &[proto::ChatMention],
edited_at: OffsetDateTime,
) -> Result<UpdatedChannelMessage> {
self.transaction(|tx| async move {
let channel = self.get_channel_internal(channel_id, &tx).await?;
self.check_user_is_channel_participant(&channel, user_id, &tx)
.await?;
let mut rows = channel_chat_participant::Entity::find()
.filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
.stream(&*tx)
.await?;
let mut is_participant = false;
let mut participant_connection_ids = Vec::new();
let mut participant_user_ids = Vec::new();
while let Some(row) = rows.next().await {
let row = row?;
if row.user_id == user_id {
is_participant = true;
}
participant_user_ids.push(row.user_id);
participant_connection_ids.push(row.connection());
}
drop(rows);
if !is_participant {
Err(anyhow!("not a chat participant"))?;
}
let channel_message = channel_message::Entity::find_by_id(message_id)
.filter(channel_message::Column::SenderId.eq(user_id))
.one(&*tx)
.await?;
let Some(channel_message) = channel_message else {
Err(anyhow!("Channel message not found"))?
};
let edited_at = edited_at.to_offset(time::UtcOffset::UTC);
let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time());
let updated_message = channel_message::ActiveModel {
body: ActiveValue::Set(body.to_string()),
edited_at: ActiveValue::Set(Some(edited_at)),
reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id),
id: ActiveValue::Unchanged(message_id),
channel_id: ActiveValue::Unchanged(channel_id),
sender_id: ActiveValue::Unchanged(user_id),
sent_at: ActiveValue::Unchanged(channel_message.sent_at),
nonce: ActiveValue::Unchanged(channel_message.nonce),
};
let result = channel_message::Entity::update_many()
.set(updated_message)
.filter(channel_message::Column::Id.eq(message_id))
.filter(channel_message::Column::SenderId.eq(user_id))
.exec(&*tx)
.await?;
if result.rows_affected == 0 {
return Err(anyhow!(
"Attempted to edit a message (id: {message_id}) which does not exist anymore."
))?;
}
// we have to fetch the old mentions,
// so we don't send a notification when the message has been edited that you are mentioned in
let old_mentions = channel_message_mention::Entity::find()
.filter(channel_message_mention::Column::MessageId.eq(message_id))
.all(&*tx)
.await?;
// remove all existing mentions
channel_message_mention::Entity::delete_many()
.filter(channel_message_mention::Column::MessageId.eq(message_id))
.exec(&*tx)
.await?;
let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
if !new_mentions.is_empty() {
// insert new mentions
channel_message_mention::Entity::insert_many(new_mentions)
.exec(&*tx)
.await?;
}
let mut update_mention_user_ids = HashSet::default();
let mut new_mention_user_ids =
mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
// Filter out users that were mentioned before
for mention in &old_mentions {
if new_mention_user_ids.contains(&mention.user_id.to_proto()) {
update_mention_user_ids.insert(mention.user_id.to_proto());
}
new_mention_user_ids.remove(&mention.user_id.to_proto());
}
let notification_kind_id =
self.get_notification_kind_id_by_name("ChannelMessageMention");
let existing_notifications = notification::Entity::find()
.filter(notification::Column::EntityId.eq(message_id))
.filter(notification::Column::Kind.eq(notification_kind_id))
.all(&*tx)
.await?;
// determine which notifications should be updated or deleted
let mut deleted_notification_ids = HashSet::default();
let mut updated_mention_notifications = Vec::new();
for notification in existing_notifications {
if update_mention_user_ids.contains(&notification.recipient_id.to_proto()) {
if let Some(notification) =
self::notifications::model_to_proto(self, notification).log_err()
{
updated_mention_notifications.push(notification);
}
} else {
deleted_notification_ids.insert(notification.id);
}
}
let mut notifications = Vec::new();
for mentioned_user in new_mention_user_ids {
notifications.extend(
self.create_notification(
UserId::from_proto(mentioned_user),
rpc::Notification::ChannelMessageMention {
message_id: message_id.to_proto(),
sender_id: user_id.to_proto(),
channel_id: channel_id.to_proto(),
},
false,
&tx,
)
.await?,
);
}
Ok(UpdatedChannelMessage {
message_id,
participant_connection_ids,
notifications,
reply_to_message_id: channel_message.reply_to_message_id,
timestamp: channel_message.sent_at,
deleted_mention_notification_ids: deleted_notification_ids
.into_iter()
.collect::<Vec<_>>(),
updated_mention_notifications,
})
Ok(participant_connection_ids)
})
.await
}

Some files were not shown because too many files have changed in this diff Show More