Compare commits
1 commit: text-eg ... extension-

| Author | SHA1 | Date |
|---|---|---|
|  | 961e548bd6 |  |
6  .github/ISSUE_TEMPLATE/2_crash_report.yml (vendored)
@@ -23,6 +23,12 @@ body:
description: Run the `copy system specs into clipboard` command palette action and paste the output in the field below.
validations:
required: true
- type: textarea
attributes:
label: If applicable, add mockups / screenshots to help explain present your vision of the feature
description: Drag issues into the text input below
validations:
required: false
- type: textarea
attributes:
label: If applicable, attach your `~/Library/Logs/Zed/Zed.log` file to this issue.
3  .github/workflows/ci.yml (vendored)

@@ -54,9 +54,6 @@ jobs:
- name: Check unused dependencies
uses: bnjbvr/cargo-machete@main

- name: Check license generation
run: script/generate-licenses /tmp/zed_licenses_output

- name: Ensure fresh merge
shell: bash -euxo pipefail {0}
run: |
40  .github/workflows/publish_extension_cli.yml (vendored)

@@ -1,40 +0,0 @@
name: Publish zed-extension CLI

on:
push:
tags:
- extension-cli

env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0

jobs:
publish:
name: Publish zed-extension CLI
runs-on:
- ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
clean: false
submodules: "recursive"

- name: Cache dependencies
uses: swatinem/rust-cache@v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}

- name: Configure linux
shell: bash -euxo pipefail {0}
run: script/linux

- name: Build extension CLI
run: cargo build --release --package extension_cli

- name: Upload binary
env:
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
DIGITALOCEAN_SPACES_SECRET_KEY: ${{ secrets.DIGITALOCEAN_SPACES_SECRET_KEY }}
run: script/upload-extension-cli ${{ github.sha }}
@@ -9,10 +9,10 @@ jobs:
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/setup-python@v4
with:
python-version: "3.11"
python-version: "3.10.5"
architecture: "x64"
cache: "pip"
- run: pip install -r script/update_top_ranking_issues/requirements.txt
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 5393
- run: python script/update_top_ranking_issues/main.py 5393 --github-token ${{ secrets.GITHUB_TOKEN }} --prod
@@ -9,10 +9,10 @@ jobs:
if: github.repository_owner == 'zed-industries'
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/setup-python@v4
with:
python-version: "3.11"
python-version: "3.10.5"
architecture: "x64"
cache: "pip"
- run: pip install -r script/update_top_ranking_issues/requirements.txt
- run: python script/update_top_ranking_issues/main.py --github-token ${{ secrets.GITHUB_TOKEN }} --issue-reference-number 6952 --query-day-interval 7
- run: python script/update_top_ranking_issues/main.py 6952 --github-token ${{ secrets.GITHUB_TOKEN }} --prod --query-day-interval 7
@@ -15,10 +15,6 @@
"JSON": {
"tab_size": 2,
"formatter": "prettier"
},
"JavaScript": {
"tab_size": 2,
"formatter": "prettier"
}
},
"formatter": "auto"
1342  Cargo.lock (generated)
File diff suppressed because it is too large.
71  Cargo.toml

@@ -1,7 +1,7 @@
[workspace]
members = [
"crates/activity_indicator",
"crates/anthropic",
"crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
@@ -29,16 +29,13 @@ members = [
"crates/feature_flags",
"crates/feedback",
"crates/file_finder",
"crates/file_icons",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
"crates/google_ai",
"crates/gpui",
"crates/gpui_macros",
"crates/image_viewer",
"crates/install_cli",
"crates/journal",
"crates/language",
@@ -54,7 +51,6 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
"crates/open_ai",
"crates/outline",
"crates/picker",
"crates/prettier",
@@ -72,7 +68,7 @@ members = [
"crates/task",
"crates/tasks_ui",
"crates/search",
"crates/semantic_version",
"crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
@@ -80,11 +76,9 @@ members = [
"crates/story",
"crates/storybook",
"crates/sum_tree",
"crates/tab_switcher",
"crates/terminal",
"crates/terminal_view",
"crates/text",
"crates/text-eg",
"crates/theme",
"crates/theme_importer",
"crates/theme_selector",
@@ -100,21 +94,8 @@ members = [
"crates/zed",
"crates/zed_actions",

"extensions/astro",
"extensions/clojure",
"extensions/csharp",
"extensions/emmet",
"extensions/erlang",
"extensions/gleam",
"extensions/haskell",
"extensions/html",
"extensions/php",
"extensions/prisma",
"extensions/purescript",
"extensions/svelte",
"extensions/toml",
"extensions/uiua",
"extensions/zig",

"tooling/xtask",
]
@@ -124,7 +105,6 @@ resolver = "2"
[workspace.dependencies]
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
audio = { path = "crates/audio" }
@@ -152,17 +132,14 @@ extensions_ui = { path = "crates/extensions_ui" }
feature_flags = { path = "crates/feature_flags" }
feedback = { path = "crates/feedback" }
file_finder = { path = "crates/file_finder" }
file_icons = { path = "crates/file_icons" }
fs = { path = "crates/fs" }
fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }
git = { path = "crates/git" }
go_to_line = { path = "crates/go_to_line" }
google_ai = { path = "crates/google_ai" }
gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
install_cli = { path = "crates/install_cli" }
image_viewer = { path = "crates/image_viewer" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
language_selector = { path = "crates/language_selector" }
@@ -177,7 +154,6 @@ menu = { path = "crates/menu" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
@@ -196,7 +172,7 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
semantic_version = { path = "crates/semantic_version" }
semantic_index = { path = "crates/semantic_index" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
sqlez = { path = "crates/sqlez" }
@@ -204,7 +180,6 @@ sqlez_macros = { path = "crates/sqlez_macros" }
story = { path = "crates/story" }
storybook = { path = "crates/storybook" }
sum_tree = { path = "crates/sum_tree" }
tab_switcher = { path = "crates/tab_switcher" }
terminal = { path = "crates/terminal" }
terminal_view = { path = "crates/terminal_view" }
text = { path = "crates/text" }
@@ -223,7 +198,6 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }

anyhow = "1.0.57"
any_vec = "0.13"
async-compression = { version = "0.4", features = ["gzip", "futures-io"] }
async-fs = "1.6"
async-recursion = "1.0.0"
@@ -233,7 +207,7 @@ bitflags = "2.4.2"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "61cbd6b2c224791d52b150fe535cee665cc91bb2" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "61cbd6b2c224791d52b150fe535cee665cc91bb2" }
blade-rwh = { package = "raw-window-handle", version = "0.5" }
cap-std = "3.0"
cap-std = "2.0"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive"] }
clickhouse = { version = "0.11.6" }
@@ -265,9 +239,7 @@ parking_lot = "0.12.1"
profiling = "1"
postage = { version = "0.5", features = ["futures-traits"] }
pretty_assertions = "1.3.0"
prost = "0.9"
prost-build = "0.9"
prost-types = "0.9"
prost = "0.8"
pulldown-cmark = { version = "0.10.0", default-features = false }
rand = "0.8.5"
refineable = { path = "./crates/refineable" }
@@ -295,8 +267,6 @@ tempfile = "3.9.0"
thiserror = "1.0.29"
tiktoken-rs = "0.5.7"
time = { version = "0.3", features = [
"macros",
"parsing",
"serde",
"serde-well-known",
"formatting",
@@ -305,18 +275,25 @@ toml = "0.8"
tokio = { version = "1", features = ["full"] }
tower-http = "0.4.4"
tree-sitter = { version = "0.20", features = ["wasm"] }
tree-sitter-astro = { git = "https://github.com/virchau13/tree-sitter-astro.git", rev = "e924787e12e8a03194f36a113290ac11d6dc10f3" }
tree-sitter-bash = { git = "https://github.com/tree-sitter/tree-sitter-bash", rev = "7331995b19b8f8aba2d5e26deb51d2195c18bc94" }
tree-sitter-c = "0.20.1"
tree-sitter-clojure = { git = "https://github.com/prcastro/tree-sitter-clojure", branch = "update-ts" }
tree-sitter-c-sharp = { git = "https://github.com/tree-sitter/tree-sitter-c-sharp", rev = "dd5e59721a5f8dae34604060833902b882023aaf" }
tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev = "f44509141e7e483323d2ec178f2d2e6c0fc041c1" }
tree-sitter-css = { git = "https://github.com/tree-sitter/tree-sitter-css", rev = "769203d0f9abe1a9a691ac2b9fe4bb4397a73c51" }
tree-sitter-dockerfile = { git = "https://github.com/camdencheek/tree-sitter-dockerfile", rev = "33e22c33bcdbfc33d42806ee84cfd0b1248cc392" }
tree-sitter-dart = { git = "https://github.com/agent3bood/tree-sitter-dart", rev = "48934e3bf757a9b78f17bdfaa3e2b4284656fdc7" }
tree-sitter-elixir = { git = "https://github.com/elixir-lang/tree-sitter-elixir", rev = "a2861e88a730287a60c11ea9299c033c7d076e30" }
tree-sitter-elm = { git = "https://github.com/elm-tooling/tree-sitter-elm", rev = "692c50c0b961364c40299e73c1306aecb5d20f40" }
tree-sitter-embedded-template = "0.20.0"
tree-sitter-erlang = "0.4.0"
tree-sitter-gleam = { git = "https://github.com/gleam-lang/tree-sitter-gleam", rev = "58b7cac8fc14c92b0677c542610d8738c373fa81" }
tree-sitter-glsl = { git = "https://github.com/theHamsta/tree-sitter-glsl", rev = "2a56fb7bc8bb03a1892b4741279dd0a8758b7fb3" }
tree-sitter-go = { git = "https://github.com/tree-sitter/tree-sitter-go", rev = "aeb2f33b366fd78d5789ff104956ce23508b85db" }
tree-sitter-gomod = { git = "https://github.com/camdencheek/tree-sitter-go-mod" }
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work" }
tree-sitter-haskell = { git = "https://github.com/tree-sitter/tree-sitter-haskell", rev = "8a99848fc734f9c4ea523b3f2a07df133cbbcec2" }
tree-sitter-hcl = { git = "https://github.com/MichaHoffmann/tree-sitter-hcl", rev = "v1.1.0" }
rustc-demangle = "0.1.23"
tree-sitter-heex = { git = "https://github.com/phoenixframework/tree-sitter-heex", rev = "2e1348c3cf2c9323e87c2744796cf3f3868aa82a" }
@@ -328,32 +305,38 @@ tree-sitter-markdown = { git = "https://github.com/MDeiml/tree-sitter-markdown",
tree-sitter-nix = { git = "https://github.com/nix-community/tree-sitter-nix", rev = "66e3e9ce9180ae08fc57372061006ef83f0abde7" }
tree-sitter-nu = { git = "https://github.com/nushell/tree-sitter-nu", rev = "7dd29f9616822e5fc259f5b4ae6c4ded9a71a132" }
tree-sitter-ocaml = { git = "https://github.com/tree-sitter/tree-sitter-ocaml", rev = "4abfdc1c7af2c6c77a370aee974627be1c285b3b" }
tree-sitter-php = "0.21.1"
tree-sitter-prisma-io = { git = "https://github.com/victorhqc/tree-sitter-prisma" }
tree-sitter-proto = { git = "https://github.com/rewinfrey/tree-sitter-proto", rev = "36d54f288aee112f13a67b550ad32634d0c2cb52" }
tree-sitter-purescript = { git = "https://github.com/postsolar/tree-sitter-purescript", rev = "v0.1.0" }
tree-sitter-python = "0.20.2"
tree-sitter-racket = { git = "https://github.com/zed-industries/tree-sitter-racket", rev = "eb010cf2c674c6fd9a6316a84e28ef90190fe51a" }
tree-sitter-regex = "0.20.0"
tree-sitter-ruby = "0.20.0"
tree-sitter-rust = "0.20.3"
tree-sitter-scheme = { git = "https://github.com/6cdh/tree-sitter-scheme", rev = "af0fd1fa452cb2562dc7b5c8a8c55551c39273b9" }
tree-sitter-svelte = { git = "https://github.com/Himujjal/tree-sitter-svelte", rev = "bd60db7d3d06f89b6ec3b287c9a6e9190b5564bd" }
tree-sitter-toml = { git = "https://github.com/tree-sitter/tree-sitter-toml", rev = "342d9be207c2dba869b9967124c679b5e6fd0ebe" }
tree-sitter-typescript = { git = "https://github.com/tree-sitter/tree-sitter-typescript", rev = "5d20856f34315b068c41edaee2ac8a100081d259" }
tree-sitter-vue = { git = "https://github.com/zed-industries/tree-sitter-vue", rev = "6608d9d60c386f19d80af7d8132322fa11199c42" }
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "f545a41f57502e1b5ddf2a6668896c1b0620f930" }
tree-sitter-zig = { git = "https://github.com/maxxnino/tree-sitter-zig", rev = "0d08703e4c3f426ec61695d7617415fff97029bd" }
unindent = "0.1.7"
unicase = "2.6"
url = "2.2"
uuid = { version = "1.1.2", features = ["v4"] }
wasmparser = "0.201"
wasm-encoder = "0.201"
wasmtime = { version = "19.0.0", default-features = false, features = [
wasmparser = "0.121"
wasm-encoder = "0.41"
wasmtime = { version = "18.0", default-features = false, features = [
"async",
"demangle",
"runtime",
"cranelift",
"component-model",
] }
wasmtime-wasi = "19.0.0"
wasmtime-wasi = "18.0"
which = "6.0.0"
wit-component = "0.201"
wit-component = "0.20"
sys-locale = "0.3.1"

[workspace.dependencies.windows]
@@ -368,7 +351,6 @@ features = [
"Win32_Security",
"Win32_Security_Credentials",
"Win32_Storage_FileSystem",
"Win32_System_LibraryLoader",
"Win32_System_Com",
"Win32_System_Com_StructuredStorage",
"Win32_System_DataExchange",
@@ -387,7 +369,7 @@ features = [
]

[patch.crates-io]
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "7f21c3b98c0749ac192da67a0d65dfe3eabc4a63" }
tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "4294e59279205f503eb14348dd5128bd5910c8fb" }
# Workaround for a broken nightly build of gpui: See #7644 and revisit once 0.5.3 is released.
pathfinder_simd = { git = "https://github.com/servo/pathfinder.git", rev = "30419d07660dc11a21e42ef4a7fa329600cff152" }

@@ -398,20 +380,15 @@ debug = "limited"
[profile.dev.package]
taffy = { opt-level = 3 }
cranelift-codegen = { opt-level = 3 }
resvg = { opt-level = 3 }
rustybuzz = { opt-level = 3 }
ttf-parser = { opt-level = 3 }
wasmtime-cranelift = { opt-level = 3 }
wasmtime = { opt-level = 3 }

[profile.release]
debug = "limited"
lto = "thin"
codegen-units = 1

[profile.release.package]
zed = { codegen-units = 16 }

[workspace.lints.clippy]
dbg_macro = "deny"
todo = "deny"
@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2

FROM rust:1.77-bookworm as builder
FROM rust:1.76-bookworm as builder
WORKDIR app
COPY . .
@@ -1,4 +0,0 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18 10L21 7L17 3L14 6M18 10L8 20H4V16L14 6M18 10L14 6" stroke="#000000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before: Size: 379 B
@@ -1,6 +1,4 @@
[
// todo(linux): Review the editor bindings
// Standard Linux bindings
{
"bindings": {
"up": "menu::SelectPrev",
@@ -11,14 +9,14 @@
"pagedown": "menu::SelectLast",
"shift-pagedown": "menu::SelectFirst",
"ctrl-n": "menu::SelectNext",
"ctrl-up": "menu::SelectFirst",
"ctrl-down": "menu::SelectLast",
"enter": "menu::Confirm",
"shift-f10": "menu::ShowContextMenu",
"ctrl-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"ctrl-escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",
"shift-enter": "picker::UseSelectedQuery",
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
"ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
"shift-enter": "menu::UseSelectedQuery",
"ctrl-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
"ctrl-o": "workspace::Open",
@@ -29,6 +27,8 @@
"ctrl-,": "zed::OpenSettings",
"ctrl-q": "zed::Quit",
"ctrl-h": "zed::Hide",
"alt-ctrl-h": "zed::HideOthers",
"ctrl-m": "zed::Minimize",
"f11": "zed::ToggleFullScreen"
}
},
@@ -45,101 +45,80 @@
|
||||
"shift-tab": "editor::TabPrev",
|
||||
"ctrl-k": "editor::CutToEndOfLine",
|
||||
"ctrl-t": "editor::Transpose",
|
||||
// "ctrl-backspace": "editor::DeleteToBeginningOfLine",
|
||||
// "ctrl-delete": "editor::DeleteToEndOfLine",
|
||||
"ctrl-backspace": "editor::DeleteToPreviousWordStart",
|
||||
// "ctrl-w": "editor::DeleteToPreviousWordStart",
|
||||
"ctrl-delete": "editor::DeleteToNextWordEnd",
|
||||
// "alt-h": "editor::DeleteToPreviousWordStart",
|
||||
// "alt-d": "editor::DeleteToNextWordEnd",
|
||||
"ctrl-backspace": "editor::DeleteToBeginningOfLine",
|
||||
"ctrl-delete": "editor::DeleteToEndOfLine",
|
||||
"alt-backspace": "editor::DeleteToPreviousWordStart",
|
||||
"alt-delete": "editor::DeleteToNextWordEnd",
|
||||
"alt-h": "editor::DeleteToPreviousWordStart",
|
||||
"alt-d": "editor::DeleteToNextWordEnd",
|
||||
"ctrl-x": "editor::Cut",
|
||||
"ctrl-c": "editor::Copy",
|
||||
"ctrl-v": "editor::Paste",
|
||||
"ctrl-z": "editor::Undo",
|
||||
"ctrl-shift-z": "editor::Redo",
|
||||
"ctrl-y": "editor::Redo",
|
||||
"up": "editor::MoveUp",
|
||||
// "ctrl-up": "editor::MoveToStartOfParagraph", todo(linux) Should be "scroll down by 1 line"
|
||||
"ctrl-up": "editor::MoveToStartOfParagraph",
|
||||
"pageup": "editor::PageUp",
|
||||
// "shift-pageup": "editor::MovePageUp", todo(linux) should be 'select page up'
|
||||
"shift-pageup": "editor::MovePageUp",
|
||||
"home": "editor::MoveToBeginningOfLine",
|
||||
"down": "editor::MoveDown",
|
||||
// "ctrl-down": "editor::MoveToEndOfParagraph", todo(linux) should be "scroll up by 1 line"
|
||||
"ctrl-down": "editor::MoveToEndOfParagraph",
|
||||
"pagedown": "editor::PageDown",
|
||||
// "shift-pagedown": "editor::MovePageDown", todo(linux) should be 'select page down'
|
||||
"shift-pagedown": "editor::MovePageDown",
|
||||
"end": "editor::MoveToEndOfLine",
|
||||
"left": "editor::MoveLeft",
|
||||
"right": "editor::MoveRight",
|
||||
"ctrl-left": "editor::MoveToPreviousWordStart",
|
||||
// "alt-b": "editor::MoveToPreviousWordStart",
|
||||
"ctrl-right": "editor::MoveToNextWordEnd",
|
||||
// "alt-f": "editor::MoveToNextWordEnd",
|
||||
// "cmd-left": "editor::MoveToBeginningOfLine",
|
||||
// "ctrl-a": "editor::MoveToBeginningOfLine",
|
||||
// "cmd-right": "editor::MoveToEndOfLine",
|
||||
// "ctrl-e": "editor::MoveToEndOfLine",
|
||||
"ctrl-p": "editor::MoveUp",
|
||||
"ctrl-n": "editor::MoveDown",
|
||||
"ctrl-b": "editor::MoveLeft",
|
||||
"ctrl-f": "editor::MoveRight",
|
||||
"ctrl-shift-l": "editor::NextScreen", // todo(linux): What is this
|
||||
"alt-left": "editor::MoveToPreviousWordStart",
|
||||
"alt-b": "editor::MoveToPreviousWordStart",
|
||||
"alt-right": "editor::MoveToNextWordEnd",
|
||||
"alt-f": "editor::MoveToNextWordEnd",
|
||||
"ctrl-e": "editor::MoveToEndOfLine",
|
||||
"ctrl-home": "editor::MoveToBeginning",
|
||||
"ctrl-end": "editor::MoveToEnd",
|
||||
"ctrl-=end": "editor::MoveToEnd",
|
||||
"shift-up": "editor::SelectUp",
|
||||
"shift-down": "editor::SelectDown",
|
||||
"ctrl-shift-n": "editor::SelectDown",
|
||||
"shift-left": "editor::SelectLeft",
|
||||
"ctrl-shift-b": "editor::SelectLeft",
|
||||
"shift-right": "editor::SelectRight",
|
||||
"ctrl-shift-left": "editor::SelectToPreviousWordStart",
|
||||
"ctrl-shift-right": "editor::SelectToNextWordEnd",
|
||||
"ctrl-shift-up": "editor::AddSelectionAbove",
|
||||
"ctrl-shift-down": "editor::AddSelectionBelow",
|
||||
// "ctrl-shift-up": "editor::SelectToStartOfParagraph",
|
||||
// "ctrl-shift-down": "editor::SelectToEndOfParagraph",
|
||||
"ctrl-shift-f": "editor::SelectRight",
|
||||
"alt-shift-left": "editor::SelectToPreviousWordStart",
|
||||
"alt-shift-b": "editor::SelectToPreviousWordStart",
|
||||
"alt-shift-right": "editor::SelectToNextWordEnd",
|
||||
"alt-shift-f": "editor::SelectToNextWordEnd",
|
||||
"ctrl-shift-up": "editor::SelectToStartOfParagraph",
|
||||
"ctrl-shift-down": "editor::SelectToEndOfParagraph",
|
||||
"ctrl-shift-home": "editor::SelectToBeginning",
|
||||
"ctrl-shift-end": "editor::SelectToEnd",
|
||||
"ctrl-a": "editor::SelectAll",
|
||||
"ctrl-l": "editor::SelectLine",
|
||||
"ctrl-shift-i": "editor::Format",
|
||||
// "cmd-shift-left": [
|
||||
// "editor::SelectToBeginningOfLine",
|
||||
// {
|
||||
// "stop_at_soft_wraps": true
|
||||
// }
|
||||
// ],
|
||||
"shift-home": [
|
||||
"editor::SelectToBeginningOfLine",
|
||||
{
|
||||
"stop_at_soft_wraps": true
|
||||
}
|
||||
],
|
||||
// "ctrl-shift-a": [
|
||||
// "editor::SelectToBeginningOfLine",
|
||||
// {
|
||||
// "stop_at_soft_wraps": true
|
||||
// }
|
||||
// ],
|
||||
// "cmd-shift-right": [
|
||||
// "editor::SelectToEndOfLine",
|
||||
// {
|
||||
// "stop_at_soft_wraps": true
|
||||
// }
|
||||
// ],
|
||||
"shift-end": [
|
||||
"editor::SelectToEndOfLine",
|
||||
{
|
||||
"stop_at_soft_wraps": true
|
||||
}
|
||||
],
|
||||
// "ctrl-shift-e": [
|
||||
// "editor::SelectToEndOfLine",
|
||||
// {
|
||||
// "stop_at_soft_wraps": true
|
||||
// }
|
||||
// ],
|
||||
// "alt-v": [
|
||||
// "editor::MovePageUp",
|
||||
// {
|
||||
// "center_cursor": true
|
||||
// }
|
||||
// ],
|
||||
"ctrl-alt-space": "editor::ShowCharacterPalette",
|
||||
"ctrl-shift-e": [
|
||||
"editor::SelectToEndOfLine",
|
||||
{
|
||||
"stop_at_soft_wraps": true
|
||||
}
|
||||
],
|
||||
"ctrl-;": "editor::ToggleLineNumbers",
|
||||
"ctrl-k ctrl-r": "editor::RevertSelectedHunks",
|
||||
"ctrl-alt-g b": "editor::ToggleGitBlame"
|
||||
"ctrl-alt-z": "editor::RevertSelectedHunks"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -147,8 +126,8 @@
|
||||
"bindings": {
|
||||
"enter": "editor::Newline",
|
||||
"shift-enter": "editor::Newline",
|
||||
"ctrl-shift-enter": "editor::NewlineBelow",
|
||||
"ctrl-enter": "editor::NewlineAbove",
|
||||
"ctrl-shift-enter": "editor::NewlineAbove",
|
||||
"ctrl-enter": "editor::NewlineBelow",
|
||||
"alt-z": "editor::ToggleSoftWrap",
|
||||
"ctrl-f": [
|
||||
"buffer_search::Deploy",
|
||||
@@ -156,27 +135,21 @@
|
||||
"focus": true
|
||||
}
|
||||
],
|
||||
// "cmd-e": [
|
||||
// "buffer_search::Deploy",
|
||||
// {
|
||||
// "focus": false
|
||||
// }
|
||||
// ],
|
||||
"ctrl->": "assistant::QuoteSelection"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full && inline_completion",
|
||||
"context": "Editor && mode == full && copilot_suggestion",
|
||||
"bindings": {
|
||||
"alt-]": "editor::NextInlineCompletion",
|
||||
"alt-[": "editor::PreviousInlineCompletion",
|
||||
"alt-right": "editor::AcceptPartialInlineCompletion"
|
||||
"alt-]": "copilot::NextSuggestion",
|
||||
"alt-[": "copilot::PreviousSuggestion",
|
||||
"alt-right": "editor::AcceptPartialCopilotSuggestion"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && !inline_completion",
|
||||
"context": "Editor && !copilot_suggestion",
|
||||
"bindings": {
|
||||
"alt-\\": "editor::ShowInlineCompletion"
|
||||
"alt-\\": "copilot::Suggest"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -190,8 +163,8 @@
|
||||
{
|
||||
"context": "AssistantPanel",
|
||||
"bindings": {
|
||||
"ctrl-g": "search::SelectNextMatch",
|
||||
"ctrl-shift-g": "search::SelectPrevMatch"
|
||||
"f3": "search::SelectNextMatch",
|
||||
"shift-f3": "search::SelectPrevMatch"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -235,8 +208,9 @@
|
||||
"escape": "project_search::ToggleFocus",
|
||||
"alt-tab": "search::CycleMode",
|
||||
"ctrl-shift-h": "search::ToggleReplace",
|
||||
"alt-ctrl-g": "search::ActivateRegexMode",
|
||||
"alt-ctrl-x": "search::ActivateTextMode"
|
||||
"ctrl-alt-g": "search::ActivateRegexMode",
|
||||
"ctrl-alt-s": "search::ActivateSemanticMode",
|
||||
"ctrl-alt-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -250,7 +224,7 @@
|
||||
"context": "ProjectSearchBar && in_replace",
|
||||
"bindings": {
|
||||
"enter": "search::ReplaceNext",
|
||||
"ctrl-alt-enter": "search::ReplaceAll"
|
||||
"ctrl-enter": "search::ReplaceAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -259,31 +233,35 @@
|
||||
"escape": "project_search::ToggleFocus",
|
||||
"alt-tab": "search::CycleMode",
|
||||
"ctrl-shift-h": "search::ToggleReplace",
|
||||
"alt-ctrl-g": "search::ActivateRegexMode",
|
||||
"alt-ctrl-x": "search::ActivateTextMode"
|
||||
"ctrl-alt-g": "search::ActivateRegexMode",
|
||||
"ctrl-alt-s": "search::ActivateSemanticMode",
|
||||
"ctrl-alt-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Pane",
|
||||
"bindings": {
|
||||
"ctrl-pageup": "pane::ActivatePrevItem",
|
||||
"ctrl-pagedown": "pane::ActivateNextItem",
|
||||
"ctrl-{": "pane::ActivatePrevItem",
|
||||
"ctrl-}": "pane::ActivateNextItem",
|
||||
"ctrl-alt-left": "pane::ActivatePrevItem",
|
||||
"ctrl-alt-right": "pane::ActivateNextItem",
|
||||
"ctrl-w": "pane::CloseActiveItem",
|
||||
"alt-ctrl-t": "pane::CloseInactiveItems",
|
||||
"alt-ctrl-shift-w": "workspace::CloseInactiveTabsAndPanes",
|
||||
"ctrl-alt-t": "pane::CloseInactiveItems",
|
||||
"ctrl-alt-shift-w": "workspace::CloseInactiveTabsAndPanes",
|
||||
"ctrl-k u": "pane::CloseCleanItems",
|
||||
"ctrl-k w": "pane::CloseAllItems",
|
||||
"ctrl-shift-f": "project_search::ToggleFocus",
|
||||
"ctrl-alt-g": "search::SelectNextMatch",
|
||||
"ctrl-alt-shift-g": "search::SelectPrevMatch",
|
||||
"ctrl-alt-shift-h": "search::ToggleReplace",
|
||||
"ctrl-k ctrl-w": "pane::CloseAllItems",
|
||||
"ctrl-f": "project_search::ToggleFocus",
|
||||
"f3": "search::SelectNextMatch",
|
||||
"shift-f3": "search::SelectPrevMatch",
|
||||
"ctrl-shift-h": "search::ToggleReplace",
|
||||
"alt-enter": "search::SelectAllMatches",
|
||||
"alt-c": "search::ToggleCaseSensitive",
|
||||
"alt-w": "search::ToggleWholeWord",
|
||||
"alt-r": "search::CycleMode",
|
||||
"alt-ctrl-f": "project_search::ToggleFilters",
|
||||
"ctrl-alt-shift-r": "search::ActivateRegexMode",
|
||||
"ctrl-alt-shift-x": "search::ActivateTextMode"
|
||||
"ctrl-alt-c": "search::ToggleCaseSensitive",
|
||||
"ctrl-alt-w": "search::ToggleWholeWord",
|
||||
"alt-tab": "search::CycleMode",
|
||||
"ctrl-alt-f": "project_search::ToggleFilters",
|
||||
"ctrl-alt-g": "search::ActivateRegexMode",
|
||||
"ctrl-alt-s": "search::ActivateSemanticMode",
|
||||
"ctrl-alt-x": "search::ActivateTextMode"
|
||||
}
|
||||
},
|
||||
// Bindings from VS Code
|
||||
@@ -292,22 +270,8 @@
|
||||
"bindings": {
|
||||
"ctrl-[": "editor::Outdent",
|
||||
"ctrl-]": "editor::Indent",
|
||||
"shift-alt-up": "editor::AddSelectionAbove",
|
||||
"shift-alt-down": "editor::AddSelectionBelow",
|
||||
"ctrl-shift-k": "editor::DeleteLine",
|
||||
"alt-up": "editor::MoveLineUp",
|
||||
"alt-down": "editor::MoveLineDown",
|
||||
"ctrl-alt-shift-up": [
|
||||
"editor::DuplicateLine",
|
||||
{
|
||||
"move_upwards": true
|
||||
}
|
||||
],
|
||||
"ctrl-alt-shift-down": "editor::DuplicateLine",
|
||||
"ctrl-shift-left": "editor::SelectToPreviousWordStart",
|
||||
"ctrl-shift-right": "editor::SelectToNextWordEnd",
|
||||
"ctrl-shift-up": "editor::SelectLargerSyntaxNode", //todo(linux) tmp keybinding
|
||||
"ctrl-shift-down": "editor::SelectSmallerSyntaxNode", //todo(linux) tmp keybinding
|
||||
"ctrl-alt-up": "editor::AddSelectionAbove",
|
||||
"ctrl-alt-down": "editor::AddSelectionBelow",
|
||||
"ctrl-d": [
|
||||
"editor::SelectNext",
|
||||
{
|
||||
@@ -340,6 +304,8 @@
|
||||
"advance_downwards": false
|
||||
}
|
||||
],
|
||||
"alt-up": "editor::SelectLargerSyntaxNode",
|
||||
"alt-down": "editor::SelectSmallerSyntaxNode",
|
||||
"ctrl-u": "editor::UndoSelection",
|
||||
"ctrl-shift-u": "editor::RedoSelection",
|
||||
"f8": "editor::GoToDiagnostic",
|
||||
@@ -348,16 +314,15 @@
|
||||
"f12": "editor::GoToDefinition",
|
||||
"alt-f12": "editor::GoToDefinitionSplit",
|
||||
"ctrl-f12": "editor::GoToTypeDefinition",
|
||||
"shift-f12": "editor::GoToImplementation",
|
||||
"alt-ctrl-f12": "editor::GoToTypeDefinitionSplit",
|
||||
"ctrl-alt-f12": "editor::GoToTypeDefinitionSplit",
|
||||
"alt-shift-f12": "editor::FindAllReferences",
|
||||
"ctrl-m": "editor::MoveToEnclosingBracket",
|
||||
"ctrl-shift-[": "editor::Fold",
|
||||
"ctrl-shift-]": "editor::UnfoldLines",
|
||||
"ctrl-alt-[": "editor::Fold",
|
||||
"ctrl-alt-]": "editor::UnfoldLines",
|
||||
"ctrl-space": "editor::ShowCompletions",
|
||||
"ctrl-.": "editor::ToggleCodeActions",
|
||||
"alt-ctrl-r": "editor::RevealInFinder",
|
||||
"ctrl-alt-shift-c": "editor::DisplayCursorNames"
|
||||
"ctrl-alt-r": "editor::RevealInFinder",
|
||||
"ctrl-alt-c": "editor::DisplayCursorNames"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -370,18 +335,18 @@
|
||||
{
|
||||
"context": "Pane",
|
||||
"bindings": {
|
||||
"alt-1": ["pane::ActivateItem", 0],
|
||||
"alt-2": ["pane::ActivateItem", 1],
|
||||
"alt-3": ["pane::ActivateItem", 2],
|
||||
"alt-4": ["pane::ActivateItem", 3],
|
||||
"alt-5": ["pane::ActivateItem", 4],
|
||||
"alt-6": ["pane::ActivateItem", 5],
|
||||
"alt-7": ["pane::ActivateItem", 6],
|
||||
"alt-8": ["pane::ActivateItem", 7],
|
||||
"alt-9": ["pane::ActivateItem", 8],
|
||||
"alt-0": "pane::ActivateLastItem",
|
||||
"ctrl-alt--": "pane::GoBack",
|
||||
"ctrl-alt-_": "pane::GoForward",
|
||||
"ctrl-1": ["pane::ActivateItem", 0],
|
||||
"ctrl-2": ["pane::ActivateItem", 1],
|
||||
"ctrl-3": ["pane::ActivateItem", 2],
|
||||
"ctrl-4": ["pane::ActivateItem", 3],
|
||||
"ctrl-5": ["pane::ActivateItem", 4],
|
||||
"ctrl-6": ["pane::ActivateItem", 5],
|
||||
"ctrl-7": ["pane::ActivateItem", 6],
|
||||
"ctrl-8": ["pane::ActivateItem", 7],
|
||||
"ctrl-9": ["pane::ActivateItem", 8],
|
||||
"ctrl-0": "pane::ActivateLastItem",
|
||||
"ctrl--": "pane::GoBack",
|
||||
"ctrl-_": "pane::GoForward",
|
||||
"ctrl-shift-t": "pane::ReopenClosedItem",
|
||||
"ctrl-shift-f": "project_search::ToggleFocus"
|
||||
}
|
||||
@@ -396,8 +361,8 @@
|
||||
// "create_new_window": true
|
||||
// }
|
||||
// ]
|
||||
"alt-ctrl-o": "projects::OpenRecent",
|
||||
"alt-ctrl-shift-b": "branches::OpenRecent",
|
||||
"ctrl-alt-o": "projects::OpenRecent",
|
||||
"ctrl-alt-b": "branches::OpenRecent",
|
||||
"ctrl-~": "workspace::NewTerminal",
|
||||
"ctrl-s": "workspace::Save",
|
||||
"ctrl-k s": "workspace::SaveWithoutFormat",
|
||||
@@ -405,27 +370,24 @@
|
||||
"ctrl-n": "workspace::NewFile",
|
||||
"ctrl-shift-n": "workspace::NewWindow",
|
||||
"ctrl-`": "terminal_panel::ToggleFocus",
|
||||
"alt-1": ["workspace::ActivatePane", 0],
|
||||
"alt-2": ["workspace::ActivatePane", 1],
|
||||
"alt-3": ["workspace::ActivatePane", 2],
|
||||
"alt-4": ["workspace::ActivatePane", 3],
|
||||
"alt-5": ["workspace::ActivatePane", 4],
|
||||
"alt-6": ["workspace::ActivatePane", 5],
|
||||
"alt-7": ["workspace::ActivatePane", 6],
|
||||
"alt-8": ["workspace::ActivatePane", 7],
|
||||
"alt-9": ["workspace::ActivatePane", 8],
|
||||
"ctrl-alt-b": "workspace::ToggleLeftDock",
|
||||
"ctrl-b": "workspace::ToggleRightDock",
|
||||
"ctrl-1": ["workspace::ActivatePane", 0],
|
||||
"ctrl-2": ["workspace::ActivatePane", 1],
|
||||
"ctrl-3": ["workspace::ActivatePane", 2],
|
||||
"ctrl-4": ["workspace::ActivatePane", 3],
|
||||
"ctrl-5": ["workspace::ActivatePane", 4],
|
||||
"ctrl-6": ["workspace::ActivatePane", 5],
|
||||
"ctrl-7": ["workspace::ActivatePane", 6],
|
||||
"ctrl-8": ["workspace::ActivatePane", 7],
|
||||
"ctrl-9": ["workspace::ActivatePane", 8],
|
||||
"ctrl-b": "workspace::ToggleLeftDock",
|
||||
"ctrl-r": "workspace::ToggleRightDock",
|
||||
"ctrl-j": "workspace::ToggleBottomDock",
|
||||
"ctrl-alt-y": "workspace::CloseAllDocks",
|
||||
"ctrl-shift-f": "pane::DeploySearch",
|
||||
"ctrl-k ctrl-s": "zed::OpenKeymap",
|
||||
"ctrl-k ctrl-t": "theme_selector::Toggle",
|
||||
"ctrl-shift-t": "project_symbols::Toggle",
|
||||
"ctrl-k ctrl-s": "zed::OpenKeymap",
|
||||
"ctrl-t": "project_symbols::Toggle",
|
||||
"ctrl-p": "file_finder::Toggle",
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
|
||||
"ctrl-e": "file_finder::Toggle",
|
||||
"ctrl-shift-p": "command_palette::Toggle",
|
||||
"ctrl-shift-m": "diagnostics::Deploy",
|
||||
"ctrl-shift-e": "project_panel::ToggleFocus",
|
||||
@@ -446,12 +408,15 @@
|
||||
}
|
||||
},
|
||||
// Bindings from Sublime Text
|
||||
// todo(linux) make sure these match linux bindings or remove above comment?
|
||||
{
|
||||
"context": "Editor",
|
||||
"bindings": {
|
||||
"ctrl-shift-k": "editor::DeleteLine",
|
||||
"ctrl-shift-d": "editor::DuplicateLineDown",
|
||||
"ctrl-shift-d": "editor::DuplicateLine",
|
||||
"ctrl-j": "editor::JoinLines",
|
||||
"ctrl-alt-up": "editor::MoveLineUp",
|
||||
"ctrl-alt-down": "editor::MoveLineDown",
|
||||
"ctrl-alt-backspace": "editor::DeleteToPreviousSubwordStart",
|
||||
"ctrl-alt-h": "editor::DeleteToPreviousSubwordStart",
|
||||
"ctrl-alt-delete": "editor::DeleteToNextSubwordEnd",
|
||||
@@ -467,6 +432,7 @@
|
||||
}
|
||||
},
|
||||
// Bindings from Atom
|
||||
// todo(linux) make sure these match linux bindings or remove above comment?
|
||||
{
|
||||
"context": "Pane",
|
||||
"bindings": {
|
||||
@@ -512,7 +478,7 @@
|
||||
"bindings": {
|
||||
"ctrl-alt-shift-f": "workspace::FollowNextCollaborator",
|
||||
// TODO: Move this to a dock open action
|
||||
"ctrl-shift-c": "collab_panel::ToggleFocus",
|
||||
"ctrl-alt-c": "collab_panel::ToggleFocus",
|
||||
"ctrl-alt-i": "zed::DebugElements",
|
||||
"ctrl-:": "editor::ToggleInlayHints"
|
||||
}
|
||||
@@ -539,19 +505,19 @@
|
||||
"left": "project_panel::CollapseSelectedEntry",
|
||||
"right": "project_panel::ExpandSelectedEntry",
|
||||
"ctrl-n": "project_panel::NewFile",
|
||||
"alt-ctrl-n": "project_panel::NewDirectory",
|
||||
"ctrl-alt-n": "project_panel::NewDirectory",
|
||||
"ctrl-x": "project_panel::Cut",
|
||||
"ctrl-c": "project_panel::Copy",
|
||||
"ctrl-v": "project_panel::Paste",
|
||||
"ctrl-alt-c": "project_panel::CopyPath",
|
||||
"alt-ctrl-shift-c": "project_panel::CopyRelativePath",
|
||||
"ctrl-alt-shift-c": "project_panel::CopyRelativePath",
|
||||
"f2": "project_panel::Rename",
|
||||
"enter": "project_panel::Rename",
|
||||
"backspace": "project_panel::Delete",
|
||||
"delete": "project_panel::Delete",
|
||||
"ctrl-backspace": ["project_panel::Delete", { "skip_prompt": true }],
|
||||
"ctrl-delete": ["project_panel::Delete", { "skip_prompt": true }],
|
||||
"alt-ctrl-r": "project_panel::RevealInFinder",
|
||||
"ctrl-alt-r": "project_panel::RevealInFinder",
|
||||
"alt-shift-f": "project_panel::NewSearchInDirectory"
|
||||
}
|
||||
},
|
||||
@@ -592,35 +558,29 @@
|
||||
"escape": "chat_panel::CloseReplyPreview"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "FileFinder",
|
||||
"bindings": { "ctrl-shift-p": "file_finder::SelectPrev" }
|
||||
},
|
||||
{
|
||||
"context": "TabSwitcher",
|
||||
"bindings": {
|
||||
"ctrl-shift-tab": "menu::SelectPrev",
|
||||
"ctrl-backspace": "tab_switcher::CloseSelectedItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Terminal",
|
||||
"bindings": {
|
||||
"ctrl-alt-space": "terminal::ShowCharacterPalette",
|
||||
"shift-ctrl-c": "terminal::Copy",
|
||||
"shift-ctrl-v": "terminal::Paste",
|
||||
"ctrl-shift-c": "terminal::Copy",
|
||||
"ctrl-shift-v": "terminal::Paste",
|
||||
"ctrl-k": "terminal::Clear",
|
||||
// Some nice conveniences
|
||||
"ctrl-backspace": ["terminal::SendText", "\u0015"],
|
||||
"ctrl-right": ["terminal::SendText", "\u0005"],
|
||||
"ctrl-left": ["terminal::SendText", "\u0001"],
|
||||
// Terminal.app compatibility
|
||||
"alt-left": ["terminal::SendText", "\u001bb"],
|
||||
"alt-right": ["terminal::SendText", "\u001bf"],
|
||||
// There are conflicting bindings for these keys in the global context.
|
||||
// these bindings override them, remove at your own risk:
|
||||
"up": ["terminal::SendKeystroke", "up"],
|
||||
"pageup": ["terminal::SendKeystroke", "pageup"],
|
||||
"down": ["terminal::SendKeystroke", "down"],
|
||||
"pagedown": ["terminal::SendKeystroke", "pagedown"],
|
||||
"escape": ["terminal::SendKeystroke", "escape"],
|
||||
"enter": ["terminal::SendKeystroke", "enter"],
|
||||
"ctrl-c": ["terminal::SendKeystroke", "ctrl-c"],
|
||||
|
||||
// Some nice conveniences
|
||||
"ctrl-backspace": ["terminal::SendText", "\u0015"],
|
||||
"ctrl-right": ["terminal::SendText", "\u0005"],
|
||||
"ctrl-left": ["terminal::SendText", "\u0001"]
|
||||
"ctrl-c": ["terminal::SendKeystroke", "ctrl-c"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -13,15 +13,11 @@
"cmd-up": "menu::SelectFirst",
"cmd-down": "menu::SelectLast",
"enter": "menu::Confirm",
"ctrl-enter": "menu::SecondaryConfirm",
"ctrl-enter": "menu::ShowContextMenu",
"cmd-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"cmd-escape": "menu::Cancel",
"ctrl-escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",
"shift-enter": "picker::UseSelectedQuery",
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
"cmd-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
"shift-enter": "menu::UseSelectedQuery",
"cmd-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
"cmd-o": "workspace::Open",
@@ -158,8 +154,7 @@
],
"ctrl-cmd-space": "editor::ShowCharacterPalette",
"cmd-;": "editor::ToggleLineNumbers",
"cmd-alt-z": "editor::RevertSelectedHunks",
"cmd-alt-g b": "editor::ToggleGitBlame"
"cmd-alt-z": "editor::RevertSelectedHunks"
}
},
{
@@ -186,17 +181,17 @@
}
},
{
"context": "Editor && mode == full && inline_completion",
"context": "Editor && mode == full && copilot_suggestion",
"bindings": {
"alt-]": "editor::NextInlineCompletion",
"alt-[": "editor::PreviousInlineCompletion",
"alt-right": "editor::AcceptPartialInlineCompletion"
"alt-]": "copilot::NextSuggestion",
"alt-[": "copilot::PreviousSuggestion",
"alt-right": "editor::AcceptPartialCopilotSuggestion"
}
},
{
"context": "Editor && !inline_completion",
"context": "Editor && !copilot_suggestion",
"bindings": {
"alt-\\": "editor::ShowInlineCompletion"
"alt-\\": "copilot::Suggest"
}
},
{
@@ -256,6 +251,7 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -280,6 +276,7 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -305,6 +302,7 @@
"alt-tab": "search::CycleMode",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ActivateRegexMode",
"alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -321,8 +319,13 @@
"cmd-shift-k": "editor::DeleteLine",
"alt-up": "editor::MoveLineUp",
"alt-down": "editor::MoveLineDown",
"alt-shift-up": "editor::DuplicateLineUp",
"alt-shift-down": "editor::DuplicateLineDown",
"alt-shift-up": [
"editor::DuplicateLine",
{
"move_upwards": true
}
],
"alt-shift-down": "editor::DuplicateLine",
"ctrl-shift-right": "editor::SelectLargerSyntaxNode",
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode",
"cmd-d": [
@@ -365,7 +368,6 @@
"f12": "editor::GoToDefinition",
"alt-f12": "editor::GoToDefinitionSplit",
"cmd-f12": "editor::GoToTypeDefinition",
"shift-f12": "editor::GoToImplementation",
"alt-cmd-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences",
"ctrl-m": "editor::MoveToEnclosingBracket",
@@ -440,8 +442,6 @@
"cmd-k cmd-t": "theme_selector::Toggle",
"cmd-t": "project_symbols::Toggle",
"cmd-p": "file_finder::Toggle",
"ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
"cmd-shift-p": "command_palette::Toggle",
"cmd-shift-m": "diagnostics::Deploy",
"cmd-shift-e": "project_panel::ToggleFocus",
@@ -601,14 +601,9 @@
}
},
{
"context": "FileFinder",
"bindings": { "cmd-shift-p": "file_finder::SelectPrev" }
},
{
"context": "TabSwitcher",
"context": "ChatPanel > MessageEditor",
"bindings": {
"ctrl-shift-tab": "menu::SelectPrev",
"ctrl-backspace": "tab_switcher::CloseSelectedItem"
"escape": "chat_panel::CloseReplyPreview"
}
},
{
@@ -11,7 +11,7 @@
"ctrl->": "zed::IncreaseBufferFontSize",
"ctrl-<": "zed::DecreaseBufferFontSize",
"ctrl-shift-j": "editor::JoinLines",
"cmd-d": "editor::DuplicateLineDown",
"cmd-d": "editor::DuplicateLine",
"cmd-backspace": "editor::DeleteLine",
"cmd-pagedown": "editor::MovePageDown",
"cmd-pageup": "editor::MovePageUp",
@@ -13,7 +13,7 @@
"cmd-up": "menu::SelectFirst",
"cmd-down": "menu::SelectLast",
"enter": "menu::Confirm",
"ctrl-enter": "menu::SecondaryConfirm",
"ctrl-enter": "menu::ShowContextMenu",
"cmd-enter": "menu::SecondaryConfirm",
"escape": "menu::Cancel",
"ctrl-c": "menu::Cancel",
@@ -9,7 +9,7 @@
"context": "Editor",
"bindings": {
"cmd-l": "go_to_line::Toggle",
"ctrl-shift-d": "editor::DuplicateLineDown",
"ctrl-shift-d": "editor::DuplicateLine",
"cmd-b": "editor::GoToDefinition",
"cmd-j": "editor::ScrollCursorCenter",
"cmd-enter": "editor::NewlineBelow",
@@ -510,7 +510,7 @@
"ctrl-[": "vim::NormalBefore",
"ctrl-x ctrl-o": "editor::ShowCompletions",
"ctrl-x ctrl-a": "assistant::InlineAssist", // zed specific
"ctrl-x ctrl-c": "editor::ShowInlineCompletion", // zed specific
"ctrl-x ctrl-c": "copilot::Suggest", // zed specific
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
"ctrl-x ctrl-z": "editor::Cancel",
"ctrl-w": "editor::DeleteToPreviousWordStart",
@@ -546,12 +546,6 @@
"escape": "buffer_search::Dismiss"
}
},
{
"context": "EmptyPane || SharedScreen",
"bindings": {
":": "command_palette::Toggle"
}
},
{
// netrw compatibility
"context": "ProjectPanel && not_editing",
@@ -48,8 +48,7 @@
// which gives the same size as all other panes.
"active_pane_magnification": 1.0,
// The key to use for adding multiple cursors
// Currently "alt" or "cmd_or_ctrl" (also aliased as
// "cmd" and "ctrl") are supported.
// Currently "alt" or "cmd" are supported.
"multi_cursor_modifier": "alt",
// Whether to enable vim modes and key bindings
"vim_mode": false,
@@ -93,12 +92,6 @@
// Whether to automatically type closing characters for you. For example,
// when you type (, Zed will automatically add a closing ) at the correct position.
"use_autoclose": true,
// Controls how the editor handles the autoclosed characters.
// When set to `false`(default), skipping over and auto-removing of the closing characters
// happen only for auto-inserted characters.
// Otherwise(when `true`), the closing characters are always skipped over and auto-removed
// no matter how they were inserted.
"always_treat_brackets_as_autoclosed": false,
// Controls whether copilot provides suggestion immediately
// or waits for a `copilot::Toggle`
"show_copilot_suggestions": true,
@@ -244,10 +237,6 @@
"default_width": 380
},
"assistant": {
// Version of this setting.
"version": "1",
// Whether the assistant is enabled.
"enabled": true,
// Whether to show the assistant panel button in the status bar.
"button": true,
// Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@@ -256,16 +245,28 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
// AI provider.
// Deprecated: Please use `provider.api_url` instead.
// The default OpenAI API endpoint to use when starting new conversations.
"openai_api_url": "https://api.openai.com/v1",
// Deprecated: Please use `provider.default_model` instead.
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_open_ai_model": "gpt-4-1106-preview",
"provider": {
"name": "openai",
// The default model to use when starting new conversations. This
"type": "openai",
// The default OpenAI API endpoint to use when starting new conversations.
"api_url": "https://api.openai.com/v1",
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
// 1. "gpt-3.5-turbo"
// 2. "gpt-4"
// 3. "gpt-4-turbo-preview"
"default_model": "gpt-4-turbo-preview"
// 1. "gpt-3.5-turbo-0613""
// 2. "gpt-4-0613""
// 3. "gpt-4-1106-preview"
"default_model": "gpt-4-1106-preview"
}
},
// Whether the screen sharing icon is shown in the os status bar.
@@ -504,6 +505,10 @@
// Existing terminals will not pick up this change until they are recreated.
// "max_scroll_history_lines": 10000,
},
// Difference settings for semantic_index
"semantic_index": {
"enabled": true
},
// Settings specific to our elixir integration
"elixir": {
// Change the LSP zed uses for elixir.
@@ -561,9 +566,6 @@
"source.organizeImports": true
}
},
"Make": {
"hard_tabs": true
},
"Markdown": {
"tab_size": 2,
"soft_wrap": "preferred_line_length"
@@ -591,13 +593,10 @@
},
"OCaml Interface": {
"tab_size": 2
},
"Prisma": {
"tab_size": 2
}
},
// Zed's Prettier integration settings.
// If Prettier is enabled, Zed will use this for its Prettier instance for any applicable file, if
// If Prettier is enabled, Zed will use this its Prettier instance for any applicable file, if
// project has no other Prettier installed.
"prettier": {
// Use regular Prettier json configuration:
@@ -111,7 +111,7 @@
|
||||
"hint": "#618399ff",
|
||||
"hint.background": "#12231fff",
|
||||
"hint.border": "#183934ff",
|
||||
"ignored": "#6b6b73ff",
|
||||
"ignored": "#aca8aeff",
|
||||
"ignored.background": "#262933ff",
|
||||
"ignored.border": "#2b2f38ff",
|
||||
"info": "#10a793ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#706897ff",
|
||||
"hint.background": "#161a35ff",
|
||||
"hint.border": "#222953ff",
|
||||
"ignored": "#756f7eff",
|
||||
"ignored": "#898591ff",
|
||||
"ignored.background": "#3a353fff",
|
||||
"ignored.border": "#56505eff",
|
||||
"info": "#566ddaff",
|
||||
@@ -495,7 +495,7 @@
|
||||
"hint": "#776d9dff",
|
||||
"hint.background": "#e1e0f9ff",
|
||||
"hint.border": "#c8c7f2ff",
|
||||
"ignored": "#6e6876ff",
|
||||
"ignored": "#5a5462ff",
|
||||
"ignored.background": "#bfbcc5ff",
|
||||
"ignored.border": "#8f8b96ff",
|
||||
"info": "#586cdaff",
|
||||
@@ -879,7 +879,7 @@
|
||||
"hint": "#b17272ff",
|
||||
"hint.background": "#171e38ff",
|
||||
"hint.border": "#262f56ff",
|
||||
"ignored": "#8f8b77ff",
|
||||
"ignored": "#a4a08bff",
|
||||
"ignored.background": "#45433bff",
|
||||
"ignored.border": "#6c695cff",
|
||||
"info": "#6684e0ff",
|
||||
@@ -1263,7 +1263,7 @@
|
||||
"hint": "#b37979ff",
|
||||
"hint.background": "#e3e5faff",
|
||||
"hint.border": "#cdd1f5ff",
|
||||
"ignored": "#878471ff",
|
||||
"ignored": "#706d5fff",
|
||||
"ignored.background": "#cecab4ff",
|
||||
"ignored.border": "#a8a48eff",
|
||||
"info": "#6684dfff",
|
||||
@@ -1647,7 +1647,7 @@
|
||||
"hint": "#6f815aff",
|
||||
"hint.background": "#142319ff",
|
||||
"hint.border": "#1c3927ff",
|
||||
"ignored": "#7d7c6aff",
|
||||
"ignored": "#91907fff",
|
||||
"ignored.background": "#424136ff",
|
||||
"ignored.border": "#5d5c4cff",
|
||||
"info": "#36a165ff",
|
||||
@@ -2031,7 +2031,7 @@
|
||||
"hint": "#758961ff",
|
||||
"hint.background": "#d9ecdfff",
|
||||
"hint.border": "#bbddc6ff",
|
||||
"ignored": "#767463ff",
|
||||
"ignored": "#61604fff",
|
||||
"ignored.background": "#c5c4b9ff",
|
||||
"ignored.border": "#969585ff",
|
||||
"info": "#37a165ff",
|
||||
@@ -2415,7 +2415,7 @@
|
||||
"hint": "#a77087ff",
|
||||
"hint.background": "#0f1c3dff",
|
||||
"hint.border": "#182d5bff",
|
||||
"ignored": "#8e8683ff",
|
||||
"ignored": "#a79f9dff",
|
||||
"ignored.background": "#443c39ff",
|
||||
"ignored.border": "#665f5cff",
|
||||
"info": "#407ee6ff",
|
||||
@@ -2799,7 +2799,7 @@
|
||||
"hint": "#a67287ff",
|
||||
"hint.background": "#dfe3fbff",
|
||||
"hint.border": "#c6cef7ff",
|
||||
"ignored": "#837b78ff",
|
||||
"ignored": "#6a6360ff",
|
||||
"ignored.background": "#ccc7c5ff",
|
||||
"ignored.border": "#aaa3a1ff",
|
||||
"info": "#407ee6ff",
|
||||
@@ -3183,7 +3183,7 @@
|
||||
"hint": "#8d70a8ff",
|
||||
"hint.background": "#0d1a43ff",
|
||||
"hint.border": "#192961ff",
|
||||
"ignored": "#908190ff",
|
||||
"ignored": "#a899a8ff",
|
||||
"ignored.background": "#433a43ff",
|
||||
"ignored.border": "#675b67ff",
|
||||
"info": "#5169ebff",
|
||||
@@ -3567,7 +3567,7 @@
|
||||
"hint": "#8c70a6ff",
|
||||
"hint.background": "#e2dffcff",
|
||||
"hint.border": "#cac7faff",
|
||||
"ignored": "#857785ff",
|
||||
"ignored": "#6b5e6bff",
|
||||
"ignored.background": "#c6b8c6ff",
|
||||
"ignored.border": "#ad9dadff",
|
||||
"info": "#5169ebff",
|
||||
@@ -3951,7 +3951,7 @@
|
||||
"hint": "#52809aff",
|
||||
"hint.background": "#121c24ff",
|
||||
"hint.border": "#1a2f3cff",
|
||||
"ignored": "#688c9dff",
|
||||
"ignored": "#7c9fb3ff",
|
||||
"ignored.background": "#33444dff",
|
||||
"ignored.border": "#4f6a78ff",
|
||||
"info": "#267eadff",
|
||||
@@ -4335,7 +4335,7 @@
|
||||
"hint": "#5a87a0ff",
|
||||
"hint.background": "#d8e4eeff",
|
||||
"hint.border": "#b9cee0ff",
|
||||
"ignored": "#628496ff",
|
||||
"ignored": "#526f7dff",
|
||||
"ignored.background": "#a6cadcff",
|
||||
"ignored.border": "#80a4b6ff",
|
||||
"info": "#267eadff",
|
||||
@@ -4719,7 +4719,7 @@
|
||||
"hint": "#8a647aff",
|
||||
"hint.background": "#1c1b29ff",
|
||||
"hint.border": "#2c2b45ff",
|
||||
"ignored": "#756e6eff",
|
||||
"ignored": "#898383ff",
|
||||
"ignored.background": "#3b3535ff",
|
||||
"ignored.border": "#564e4eff",
|
||||
"info": "#7272caff",
|
||||
@@ -5103,7 +5103,7 @@
|
||||
"hint": "#91697fff",
|
||||
"hint.background": "#e4e1f5ff",
|
||||
"hint.border": "#cecaecff",
|
||||
"ignored": "#6e6666ff",
|
||||
"ignored": "#5a5252ff",
|
||||
"ignored.background": "#c1bbbbff",
|
||||
"ignored.border": "#8e8989ff",
|
||||
"info": "#7272caff",
|
||||
@@ -5487,7 +5487,7 @@
|
||||
"hint": "#607e76ff",
|
||||
"hint.background": "#151e20ff",
|
||||
"hint.border": "#1f3233ff",
|
||||
"ignored": "#6f7e74ff",
|
||||
"ignored": "#859188ff",
|
||||
"ignored.background": "#353f39ff",
|
||||
"ignored.border": "#505e55ff",
|
||||
"info": "#468b8fff",
|
||||
@@ -5871,7 +5871,7 @@
|
||||
"hint": "#66847cff",
|
||||
"hint.background": "#dae7e8ff",
|
||||
"hint.border": "#bed4d6ff",
|
||||
"ignored": "#68766dff",
|
||||
"ignored": "#546259ff",
|
||||
"ignored.background": "#bcc5bfff",
|
||||
"ignored.border": "#8b968eff",
|
||||
"info": "#488b90ff",
|
||||
@@ -6255,7 +6255,7 @@
|
||||
"hint": "#008b9fff",
|
||||
"hint.background": "#051949ff",
|
||||
"hint.border": "#102667ff",
|
||||
"ignored": "#778f77ff",
|
||||
"ignored": "#8ba48bff",
|
||||
"ignored.background": "#3b453bff",
|
||||
"ignored.border": "#5c6c5cff",
|
||||
"info": "#3e62f4ff",
|
||||
@@ -6639,7 +6639,7 @@
|
||||
"hint": "#008fa1ff",
|
||||
"hint.background": "#e1ddfeff",
|
||||
"hint.border": "#c9c4fdff",
|
||||
"ignored": "#718771ff",
|
||||
"ignored": "#5f705fff",
|
||||
"ignored.background": "#b4ceb4ff",
|
||||
"ignored.border": "#8ea88eff",
|
||||
"info": "#3e61f4ff",
|
||||
@@ -7023,7 +7023,7 @@
|
||||
"hint": "#6c81a5ff",
|
||||
"hint.background": "#161f2bff",
|
||||
"hint.border": "#203348ff",
|
||||
"ignored": "#7e849eff",
|
||||
"ignored": "#959bb2ff",
|
||||
"ignored.background": "#3e4769ff",
|
||||
"ignored.border": "#5b6385ff",
|
||||
"info": "#3e8ed0ff",
|
||||
@@ -7407,7 +7407,7 @@
|
||||
"hint": "#7087b2ff",
|
||||
"hint.background": "#dde7f6ff",
|
||||
"hint.border": "#c2d5efff",
|
||||
"ignored": "#767d9aff",
|
||||
"ignored": "#5f6789ff",
|
||||
"ignored.background": "#c1c5d8ff",
|
||||
"ignored.border": "#9a9fb6ff",
|
||||
"info": "#3e8fd0ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#628b80ff",
|
||||
"hint.background": "#0d2f4eff",
|
||||
"hint.border": "#1b4a6eff",
|
||||
"ignored": "#696a6aff",
|
||||
"ignored": "#8a8986ff",
|
||||
"ignored.background": "#313337ff",
|
||||
"ignored.border": "#3f4043ff",
|
||||
"info": "#5ac1feff",
|
||||
@@ -480,7 +480,7 @@
|
||||
"hint": "#8ca7c2ff",
|
||||
"hint.background": "#deebfaff",
|
||||
"hint.border": "#c4daf6ff",
|
||||
"ignored": "#a9acaeff",
|
||||
"ignored": "#8b8e92ff",
|
||||
"ignored.background": "#dcdddeff",
|
||||
"ignored.border": "#cfd1d2ff",
|
||||
"info": "#3b9ee5ff",
|
||||
@@ -849,7 +849,7 @@
|
||||
"hint": "#7399a3ff",
|
||||
"hint.background": "#123950ff",
|
||||
"hint.border": "#24556fff",
|
||||
"ignored": "#7b7d7fff",
|
||||
"ignored": "#9a9a98ff",
|
||||
"ignored.background": "#464a52ff",
|
||||
"ignored.border": "#53565dff",
|
||||
"info": "#72cffeff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#8c957dff",
|
||||
"hint.background": "#1e2321ff",
|
||||
"hint.border": "#303a36ff",
|
||||
"ignored": "#998b78ff",
|
||||
"ignored": "#c5b597ff",
|
||||
"ignored.background": "#4c4642ff",
|
||||
"ignored.border": "#5b534dff",
|
||||
"info": "#83a598ff",
|
||||
@@ -485,7 +485,7 @@
|
||||
"hint": "#6a695bff",
|
||||
"hint.background": "#1e2321ff",
|
||||
"hint.border": "#303a36ff",
|
||||
"ignored": "#998b78ff",
|
||||
"ignored": "#c5b597ff",
|
||||
"ignored.background": "#4c4642ff",
|
||||
"ignored.border": "#5b534dff",
|
||||
"info": "#83a598ff",
|
||||
@@ -859,7 +859,7 @@
|
||||
"hint": "#8c957dff",
|
||||
"hint.background": "#1e2321ff",
|
||||
"hint.border": "#303a36ff",
|
||||
"ignored": "#998b78ff",
|
||||
"ignored": "#c5b597ff",
|
||||
"ignored.background": "#4c4642ff",
|
||||
"ignored.border": "#5b534dff",
|
||||
"info": "#83a598ff",
|
||||
@@ -1233,7 +1233,7 @@
|
||||
"hint": "#677562ff",
|
||||
"hint.background": "#d2dee2ff",
|
||||
"hint.border": "#adc5ccff",
|
||||
"ignored": "#897b6eff",
|
||||
"ignored": "#5f5650ff",
|
||||
"ignored.background": "#d9c8a4ff",
|
||||
"ignored.border": "#c8b899ff",
|
||||
"info": "#0b6678ff",
|
||||
@@ -1607,7 +1607,7 @@
|
||||
"hint": "#677562ff",
|
||||
"hint.background": "#d2dee2ff",
|
||||
"hint.border": "#adc5ccff",
|
||||
"ignored": "#897b6eff",
|
||||
"ignored": "#5f5650ff",
|
||||
"ignored.background": "#d9c8a4ff",
|
||||
"ignored.border": "#c8b899ff",
|
||||
"info": "#0b6678ff",
|
||||
@@ -1981,7 +1981,7 @@
|
||||
"hint": "#677562ff",
|
||||
"hint.background": "#d2dee2ff",
|
||||
"hint.border": "#adc5ccff",
|
||||
"ignored": "#897b6eff",
|
||||
"ignored": "#5f5650ff",
|
||||
"ignored.background": "#d9c8a4ff",
|
||||
"ignored.border": "#c8b899ff",
|
||||
"info": "#0b6678ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#5a6f89ff",
|
||||
"hint.background": "#18243dff",
|
||||
"hint.border": "#293b5bff",
|
||||
"ignored": "#555a63ff",
|
||||
"ignored": "#838994ff",
|
||||
"ignored.background": "#3b414dff",
|
||||
"ignored.border": "#464b57ff",
|
||||
"info": "#74ade8ff",
|
||||
@@ -485,7 +485,7 @@
|
||||
"hint": "#9294beff",
|
||||
"hint.background": "#e2e2faff",
|
||||
"hint.border": "#cbcdf6ff",
|
||||
"ignored": "#a1a1a3ff",
|
||||
"ignored": "#7e8087ff",
|
||||
"ignored.background": "#dcdcddff",
|
||||
"ignored.border": "#c9c9caff",
|
||||
"info": "#5c78e2ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#5e768cff",
|
||||
"hint.background": "#2f3639ff",
|
||||
"hint.border": "#435255ff",
|
||||
"ignored": "#2f2b43ff",
|
||||
"ignored": "#74708dff",
|
||||
"ignored.background": "#292738ff",
|
||||
"ignored.border": "#423f55ff",
|
||||
"info": "#9bced6ff",
|
||||
@@ -490,7 +490,7 @@
|
||||
"hint": "#7a92aaff",
|
||||
"hint.background": "#dde9ebff",
|
||||
"hint.border": "#c3d7dbff",
|
||||
"ignored": "#938fa3ff",
|
||||
"ignored": "#706c8cff",
|
||||
"ignored.background": "#dcd8d8ff",
|
||||
"ignored.border": "#dcd6d5ff",
|
||||
"info": "#57949fff",
|
||||
@@ -869,7 +869,7 @@
|
||||
"hint": "#728aa2ff",
|
||||
"hint.background": "#2f3639ff",
|
||||
"hint.border": "#435255ff",
|
||||
"ignored": "#605d7aff",
|
||||
"ignored": "#85819eff",
|
||||
"ignored.background": "#38354eff",
|
||||
"ignored.border": "#504c68ff",
|
||||
"info": "#9bced6ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#727d68ff",
|
||||
"hint.background": "#171e1eff",
|
||||
"hint.border": "#223131ff",
|
||||
"ignored": "#827568ff",
|
||||
"ignored": "#a69782ff",
|
||||
"ignored.background": "#333944ff",
|
||||
"ignored.border": "#3d4350ff",
|
||||
"info": "#518b8bff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#4f8297ff",
|
||||
"hint.background": "#141f2cff",
|
||||
"hint.border": "#1b3149ff",
|
||||
"ignored": "#6f8389ff",
|
||||
"ignored": "#93a1a1ff",
|
||||
"ignored.background": "#073743ff",
|
||||
"ignored.border": "#2b4e58ff",
|
||||
"info": "#278ad1ff",
|
||||
@@ -480,7 +480,7 @@
|
||||
"hint": "#5789a3ff",
|
||||
"hint.background": "#dbe6f6ff",
|
||||
"hint.border": "#bfd3efff",
|
||||
"ignored": "#6a7f86ff",
|
||||
"ignored": "#34555eff",
|
||||
"ignored.background": "#cfd0c4ff",
|
||||
"ignored.border": "#9faaa8ff",
|
||||
"info": "#288bd1ff",
|
||||
|
||||
@@ -111,7 +111,7 @@
|
||||
"hint": "#246e61ff",
|
||||
"hint.background": "#0e2242ff",
|
||||
"hint.border": "#193760ff",
|
||||
"ignored": "#4c4735ff",
|
||||
"ignored": "#736e55ff",
|
||||
"ignored.background": "#2a261cff",
|
||||
"ignored.border": "#302c21ff",
|
||||
"info": "#499befff",
|
||||
|
||||
@@ -16,7 +16,6 @@ doctest = false
|
||||
anyhow.workspace = true
|
||||
auto_update.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage};
|
||||
use editor::Editor;
|
||||
use extension::ExtensionStore;
|
||||
use futures::StreamExt;
|
||||
use gpui::{
|
||||
actions, svg, AppContext, CursorStyle, EventEmitter, InteractiveElement as _, Model,
|
||||
@@ -206,7 +205,7 @@ impl ActivityIndicator {
|
||||
}
|
||||
LanguageServerBinaryStatus::Downloading => downloading.push(status.name.0.as_ref()),
|
||||
LanguageServerBinaryStatus::Failed { .. } => failed.push(status.name.0.as_ref()),
|
||||
LanguageServerBinaryStatus::None => {}
|
||||
LanguageServerBinaryStatus::Downloaded | LanguageServerBinaryStatus::Cached => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,18 +288,6 @@ impl ActivityIndicator {
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(extension_store) =
|
||||
ExtensionStore::try_global(cx).map(|extension_store| extension_store.read(cx))
|
||||
{
|
||||
if let Some(extension_id) = extension_store.outstanding_operations().keys().next() {
|
||||
return Content {
|
||||
icon: Some(DOWNLOAD_ICON),
|
||||
message: format!("Updating {extension_id} extension…"),
|
||||
on_click: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
41 crates/ai/Cargo.toml (new file)
@@ -0,0 +1,41 @@
[package]
name = "ai"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"

[lints]
workspace = true

[lib]
path = "src/ai.rs"
doctest = false

[features]
test-support = []

[dependencies]
anyhow.workspace = true
async-trait.workspace = true
bincode = "1.3.3"
futures.workspace = true
gpui.workspace = true
isahc.workspace = true
language.workspace = true
log.workspace = true
matrixmultiply = "0.3.7"
ordered-float.workspace = true
parking_lot.workspace = true
parse_duration = "2.1.1"
postage.workspace = true
rand.workspace = true
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tiktoken-rs.workspace = true
util.workspace = true

[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
8 crates/ai/src/ai.rs (new file)
@@ -0,0 +1,8 @@
pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
23 crates/ai/src/auth.rs (new file)
@@ -0,0 +1,23 @@
use futures::future::BoxFuture;
use gpui::AppContext;

#[derive(Clone, Debug)]
pub enum ProviderCredential {
    Credentials { api_key: String },
    NoCredentials,
    NotNeeded,
}

pub trait CredentialProvider: Send + Sync {
    fn has_credentials(&self) -> bool;
    #[must_use]
    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
    #[must_use]
    fn save_credentials(
        &self,
        cx: &mut AppContext,
        credential: ProviderCredential,
    ) -> BoxFuture<()>;
    #[must_use]
    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
}
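The `CredentialProvider` trait above is what both the completion and embedding providers later in this diff implement. As a minimal sketch that is not part of this change (the `NullCredentialProvider` name is hypothetical, used only for illustration, and the code assumes it lives inside the `ai` crate), a provider that never needs credentials could satisfy the trait like this:

```rust
use futures::{future::BoxFuture, FutureExt};
use gpui::AppContext;

use crate::auth::{CredentialProvider, ProviderCredential};

/// Hypothetical provider that never needs credentials (illustration only).
struct NullCredentialProvider;

impl CredentialProvider for NullCredentialProvider {
    fn has_credentials(&self) -> bool {
        true
    }

    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        // Nothing to look up: report that credentials are not needed.
        async { ProviderCredential::NotNeeded }.boxed()
    }

    fn save_credentials(
        &self,
        _cx: &mut AppContext,
        _credential: ProviderCredential,
    ) -> BoxFuture<()> {
        // No keychain entry to write.
        async {}.boxed()
    }

    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
        // No keychain entry to remove.
        async {}.boxed()
    }
}
```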
23 crates/ai/src/completion.rs (new file)
@@ -0,0 +1,23 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};

use crate::{auth::CredentialProvider, models::LanguageModel};

pub trait CompletionRequest: Send + Sync {
    fn data(&self) -> serde_json::Result<String>;
}

pub trait CompletionProvider: CredentialProvider {
    fn base_model(&self) -> Box<dyn LanguageModel>;
    fn complete(
        &self,
        prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
    fn box_clone(&self) -> Box<dyn CompletionProvider>;
}

impl Clone for Box<dyn CompletionProvider> {
    fn clone(&self) -> Box<dyn CompletionProvider> {
        self.box_clone()
    }
}
121 crates/ai/src/embedding.rs (new file)
@@ -0,0 +1,121 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use ordered_float::OrderedFloat;
|
||||
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
|
||||
use rusqlite::ToSql;
|
||||
|
||||
use crate::auth::CredentialProvider;
|
||||
use crate::models::LanguageModel;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct Embedding(pub Vec<f32>);
|
||||
|
||||
// This is needed for semantic index functionality.
// Unfortunately it has to live wherever the `Embedding` struct is created.
// Keeping it here, though, introduces a `rusqlite` dependency into the `ai` crate,
// which is less than ideal.
|
||||
impl FromSql for Embedding {
|
||||
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
|
||||
let bytes = value.as_blob()?;
|
||||
let embedding =
|
||||
bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
|
||||
Ok(Embedding(embedding))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToSql for Embedding {
|
||||
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
|
||||
let bytes = bincode::serialize(&self.0)
|
||||
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
|
||||
Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
|
||||
}
|
||||
}
|
||||
impl From<Vec<f32>> for Embedding {
|
||||
fn from(value: Vec<f32>) -> Self {
|
||||
Embedding(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
|
||||
let len = self.0.len();
|
||||
assert_eq!(len, other.0.len());
|
||||
|
||||
let mut result = 0.0;
|
||||
unsafe {
|
||||
matrixmultiply::sgemm(
|
||||
1,
|
||||
len,
|
||||
1,
|
||||
1.0,
|
||||
self.0.as_ptr(),
|
||||
len as isize,
|
||||
1,
|
||||
other.0.as_ptr(),
|
||||
1,
|
||||
len as isize,
|
||||
0.0,
|
||||
&mut result as *mut f32,
|
||||
1,
|
||||
1,
|
||||
);
|
||||
}
|
||||
OrderedFloat(result)
|
||||
}
|
||||
}
|
||||
|
||||
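`Embedding::similarity` above computes a plain dot product: the `sgemm` call multiplies a 1×len row vector by a len×1 column vector, and the `reference_dot` helper in the tests below spells out the same computation without `matrixmultiply`. As a hedged aside (not part of this change): the dot product only reads as cosine similarity when the vectors are unit-length, so a hypothetical normalization helper would look like this:

```rust
use ordered_float::OrderedFloat;

// Same computation as `Embedding::similarity`, written as a plain loop.
fn dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
    assert_eq!(a.len(), b.len());
    OrderedFloat(a.iter().zip(b).map(|(x, y)| x * y).sum())
}

// Hypothetical helper: scale a vector to unit length so the dot product above
// can be interpreted as cosine similarity.
fn normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
```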
#[async_trait]
|
||||
pub trait EmbeddingProvider: CredentialProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel>;
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
|
||||
fn max_tokens_per_batch(&self) -> usize;
|
||||
fn rate_limit_expiration(&self) -> Option<Instant>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_similarity(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
Embedding::from(vec![1., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
|
||||
0.
|
||||
);
|
||||
assert_eq!(
|
||||
Embedding::from(vec![2., 0., 0., 0., 0.])
|
||||
.similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
|
||||
6.
|
||||
);
|
||||
|
||||
for _ in 0..100 {
|
||||
let size = 1536;
|
||||
let mut a = vec![0.; size];
|
||||
let mut b = vec![0.; size];
|
||||
for (a, b) in a.iter_mut().zip(b.iter_mut()) {
|
||||
*a = rng.gen();
|
||||
*b = rng.gen();
|
||||
}
|
||||
let a = Embedding::from(a);
|
||||
let b = Embedding::from(b);
|
||||
|
||||
assert_eq!(
|
||||
round_to_decimals(a.similarity(&b), 1),
|
||||
round_to_decimals(reference_dot(&a.0, &b.0), 1)
|
||||
);
|
||||
}
|
||||
|
||||
fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
|
||||
let factor = 10.0_f32.powi(decimal_places);
|
||||
(n * factor).round() / factor
|
||||
}
|
||||
|
||||
fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
|
||||
OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
|
||||
}
|
||||
}
|
||||
}
|
||||
16 crates/ai/src/models.rs (new file)
@@ -0,0 +1,16 @@
pub enum TruncationDirection {
    Start,
    End,
}

pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}
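The `PromptChain` tests further down reference a `crate::test::FakeLanguageModel { capacity }`, which is not included in this diff. Judging by how those tests behave (a capacity of 20 truncates "This is a test prompt template" to its first 20 characters), a minimal sketch of such a model, treating each character as one token, could look like this; the `name` string and struct shape are assumptions:

```rust
use crate::models::{LanguageModel, TruncationDirection};

// Hypothetical stand-in for the `FakeLanguageModel` used by the tests below:
// one character counts as one token, so truncation is a simple slice.
pub struct FakeLanguageModel {
    pub capacity: usize,
}

impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "fake".to_string()
    }

    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        Ok(content.chars().count())
    }

    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        let chars: Vec<char> = content.chars().collect();
        if chars.len() <= length {
            return Ok(content.to_string());
        }
        let truncated: String = match direction {
            // Keep the first `length` tokens…
            TruncationDirection::End => chars[..length].iter().collect(),
            // …or the last `length` tokens.
            TruncationDirection::Start => chars[chars.len() - length..].iter().collect(),
        };
        Ok(truncated)
    }

    fn capacity(&self) -> anyhow::Result<usize> {
        Ok(self.capacity)
    }
}
```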
337 crates/ai/src/prompts/base.rs (new file)
@@ -0,0 +1,337 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use language::BufferSnapshot;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::prompts::repository_context::PromptCodeSnippet;
|
||||
|
||||
pub(crate) enum PromptFileType {
|
||||
Text,
|
||||
Code,
|
||||
}
|
||||
|
||||
// TODO: Set this up to manage defaults well
|
||||
pub struct PromptArguments {
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
pub user_prompt: Option<String>,
|
||||
pub language_name: Option<String>,
|
||||
pub project_name: Option<String>,
|
||||
pub snippets: Vec<PromptCodeSnippet>,
|
||||
pub reserved_tokens: usize,
|
||||
pub buffer: Option<BufferSnapshot>,
|
||||
pub selected_range: Option<Range<usize>>,
|
||||
}
|
||||
|
||||
impl PromptArguments {
|
||||
pub(crate) fn get_file_type(&self) -> PromptFileType {
|
||||
if self
|
||||
.language_name
|
||||
.as_ref()
|
||||
.map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
|
||||
.unwrap_or(true)
|
||||
{
|
||||
PromptFileType::Code
|
||||
} else {
|
||||
PromptFileType::Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)>;
|
||||
}
|
||||
|
||||
#[repr(i8)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum PromptPriority {
|
||||
/// Ignores truncation.
|
||||
Mandatory,
|
||||
/// Truncates based on priority.
|
||||
Ordered { order: usize },
|
||||
}
|
||||
|
||||
impl PartialOrd for PromptPriority {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for PromptPriority {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
match (self, other) {
|
||||
(Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
|
||||
(Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
|
||||
(Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
|
||||
(Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
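The `Ord` implementation above is easy to misread: `Mandatory` compares greater than any `Ordered` value, and among `Ordered` values a *smaller* `order` compares greater. `PromptChain::generate` below sorts template indices by `Reverse(priority)`, so the token budget is granted to mandatory templates first and then to ordered templates in ascending `order`. A small sketch, assumed to live in the same module since it uses these types directly:

```rust
use std::cmp::Reverse;

fn priority_order_demo() {
    let priorities = [
        PromptPriority::Ordered { order: 1 },
        PromptPriority::Mandatory,
        PromptPriority::Ordered { order: 0 },
    ];

    // Mandatory outranks any Ordered priority, and order 0 outranks order 1.
    assert!(priorities[1] > priorities[2]);
    assert!(priorities[2] > priorities[0]);

    // Mirrors the sort in `PromptChain::generate`: highest priority first.
    let mut indices: Vec<usize> = (0..priorities.len()).collect();
    indices.sort_by_key(|&i| Reverse(&priorities[i]));
    assert_eq!(indices, vec![1, 2, 0]);
}
```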
pub struct PromptChain {
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
}
|
||||
|
||||
impl PromptChain {
|
||||
pub fn new(
|
||||
args: PromptArguments,
|
||||
templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
|
||||
) -> Self {
|
||||
PromptChain { args, templates }
|
||||
}
|
||||
|
||||
pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
|
||||
// Argsort based on Prompt Priority
|
||||
let separator = "\n";
|
||||
let separator_tokens = self.args.model.count_tokens(separator)?;
|
||||
let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
|
||||
sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
|
||||
|
||||
let mut tokens_outstanding = if truncate {
|
||||
Some(self.args.model.capacity()? - self.args.reserved_tokens)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut prompts = vec!["".to_string(); sorted_indices.len()];
|
||||
for idx in sorted_indices {
|
||||
let (_, template) = &self.templates[idx];
|
||||
|
||||
if let Some((template_prompt, prompt_token_count)) =
|
||||
template.generate(&self.args, tokens_outstanding).log_err()
|
||||
{
|
||||
if template_prompt != "" {
|
||||
prompts[idx] = template_prompt;
|
||||
|
||||
if let Some(remaining_tokens) = tokens_outstanding {
|
||||
let new_tokens = prompt_token_count + separator_tokens;
|
||||
tokens_outstanding = if remaining_tokens > new_tokens {
|
||||
Some(remaining_tokens - new_tokens)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompts.retain(|x| x != "");
|
||||
|
||||
let full_prompt = prompts.join(separator);
|
||||
let total_token_count = self.args.model.count_tokens(&full_prompt)?;
|
||||
anyhow::Ok((prompts.join(separator), total_token_count))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::test::FakeLanguageModel;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
pub fn test_prompt_chain() {
|
||||
struct TestPromptTemplate {}
|
||||
impl PromptTemplate for TestPromptTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
struct TestLowPriorityTemplate {}
|
||||
impl PromptTemplate for TestLowPriorityTemplate {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut content = "This is a low priority test prompt template".to_string();
|
||||
|
||||
let mut token_count = args.model.count_tokens(&content)?;
|
||||
if let Some(max_token_length) = max_token_length {
|
||||
if token_count > max_token_length {
|
||||
content = args.model.truncate(
|
||||
&content,
|
||||
max_token_length,
|
||||
TruncationDirection::End,
|
||||
)?;
|
||||
token_count = max_token_length;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((content, token_count))
|
||||
}
|
||||
}
|
||||
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation Off
|
||||
// Should ignore capacity and return all prompts
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a test prompt template\nThis is a low priority test prompt template"
|
||||
.to_string()
|
||||
);
|
||||
|
||||
assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
|
||||
|
||||
// Testing with Truncation On
// Should truncate the combined prompts down to the model's capacity
|
||||
let capacity = 20;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens: 0,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 2 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(prompt, "This is a test promp".to_string());
|
||||
assert_eq!(token_count, capacity);
|
||||
|
||||
// Change Ordering of Prompts Based on Priority
|
||||
let capacity = 120;
|
||||
let reserved_tokens = 10;
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
|
||||
let args = PromptArguments {
|
||||
model: model.clone(),
|
||||
language_name: None,
|
||||
project_name: None,
|
||||
snippets: Vec::new(),
|
||||
reserved_tokens,
|
||||
buffer: None,
|
||||
selected_range: None,
|
||||
user_prompt: None,
|
||||
};
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(TestPromptTemplate {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(TestLowPriorityTemplate {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
|
||||
let (prompt, token_count) = chain.generate(true).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
prompt,
|
||||
"This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
|
||||
.to_string()
|
||||
);
|
||||
assert_eq!(token_count, capacity - reserved_tokens);
|
||||
}
|
||||
}
|
||||
164 crates/ai/src/prompts/file_context.rs (new file)
@@ -0,0 +1,164 @@
|
||||
use anyhow::anyhow;
|
||||
use language::BufferSnapshot;
|
||||
use language::ToOffset;
|
||||
|
||||
use crate::models::LanguageModel;
|
||||
use crate::models::TruncationDirection;
|
||||
use crate::prompts::base::PromptArguments;
|
||||
use crate::prompts::base::PromptTemplate;
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn retrieve_context(
|
||||
buffer: &BufferSnapshot,
|
||||
selected_range: &Option<Range<usize>>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
max_token_count: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize, bool)> {
|
||||
let mut prompt = String::new();
|
||||
let mut truncated = false;
|
||||
if let Some(selected_range) = selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
let start_window = buffer.text_for_range(0..start).collect::<String>();
|
||||
|
||||
let mut selected_window = String::new();
|
||||
if start == end {
|
||||
write!(selected_window, "<|START|>").unwrap();
|
||||
} else {
|
||||
write!(selected_window, "<|START|").unwrap();
|
||||
}
|
||||
|
||||
write!(
|
||||
selected_window,
|
||||
"{}",
|
||||
buffer.text_for_range(start..end).collect::<String>()
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if start != end {
|
||||
write!(selected_window, "|END|>").unwrap();
|
||||
}
|
||||
|
||||
let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
|
||||
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
let selected_tokens = model.count_tokens(&selected_window)?;
|
||||
if selected_tokens > max_token_count {
|
||||
return Err(anyhow!(
|
||||
"selected range is greater than model context window, truncation not possible"
|
||||
));
|
||||
};
|
||||
|
||||
let mut remaining_tokens = max_token_count - selected_tokens;
|
||||
let start_window_tokens = model.count_tokens(&start_window)?;
|
||||
let end_window_tokens = model.count_tokens(&end_window)?;
|
||||
let outside_tokens = start_window_tokens + end_window_tokens;
|
||||
if outside_tokens > remaining_tokens {
|
||||
let (start_goal_tokens, end_goal_tokens) =
|
||||
if start_window_tokens < end_window_tokens {
|
||||
let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
|
||||
remaining_tokens -= start_goal_tokens;
|
||||
let end_goal_tokens = remaining_tokens.min(end_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
} else {
|
||||
let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
|
||||
remaining_tokens -= end_goal_tokens;
|
||||
let start_goal_tokens = remaining_tokens.min(start_window_tokens);
|
||||
(start_goal_tokens, end_goal_tokens)
|
||||
};
|
||||
|
||||
let truncated_start_window =
|
||||
model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
|
||||
let truncated_end_window =
|
||||
model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"{truncated_start_window}{selected_window}{truncated_end_window}"
|
||||
)
|
||||
.unwrap();
|
||||
truncated = true;
|
||||
} else {
|
||||
writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
|
||||
}
|
||||
} else {
|
||||
// If we don't have a selected range, include the entire file.
|
||||
writeln!(prompt, "{}", &buffer.text()).unwrap();
|
||||
|
||||
// Dumb truncation strategy
|
||||
if let Some(max_token_count) = max_token_count {
|
||||
if model.count_tokens(&prompt)? > max_token_count {
|
||||
truncated = true;
|
||||
prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let token_count = model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count, truncated))
|
||||
}
|
||||
|
||||
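When the text around the selection does not fit in the remaining budget, `retrieve_context` above splits that budget between the text before and after the selection: the smaller side is capped at half the budget first, and the other side receives whatever is left. A small sketch of that arithmetic, with made-up token counts:

```rust
// Mirrors the budget-splitting branch in `retrieve_context`; the token counts
// in the comment below are illustrative only.
fn split_budget(remaining: usize, start_window: usize, end_window: usize) -> (usize, usize) {
    if start_window < end_window {
        let start_goal = (remaining / 2).min(start_window);
        let end_goal = (remaining - start_goal).min(end_window);
        (start_goal, end_goal)
    } else {
        let end_goal = (remaining / 2).min(end_window);
        let start_goal = (remaining - end_goal).min(start_window);
        (start_goal, end_goal)
    }
}

// With 60 tokens left, a 10-token prefix and a 200-token suffix, the prefix is
// kept whole and the suffix is truncated to the remaining 50 tokens:
// split_budget(60, 10, 200) == (10, 50)
```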
pub struct FileContext {}
|
||||
|
||||
impl PromptTemplate for FileContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
if let Some(buffer) = &args.buffer {
|
||||
let mut prompt = String::new();
|
||||
// Add Initial Preamble
|
||||
// TODO: Do we want to add the path in here?
|
||||
writeln!(
|
||||
prompt,
|
||||
"The file you are currently working on has the following content:"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let language_name = args
|
||||
.language_name
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
.to_lowercase();
|
||||
|
||||
let (context, _, truncated) = retrieve_context(
|
||||
buffer,
|
||||
&args.selected_range,
|
||||
args.model.clone(),
|
||||
max_token_length,
|
||||
)?;
|
||||
writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
|
||||
|
||||
if truncated {
|
||||
writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
|
||||
}
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
let start = selected_range.start.to_offset(buffer);
|
||||
let end = selected_range.end.to_offset(buffer);
|
||||
|
||||
if start == end {
|
||||
writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args
|
||||
.model
|
||||
.truncate(&prompt, max_tokens, TruncationDirection::End)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
} else {
|
||||
Err(anyhow!("no buffer provided to retrieve file context from"))
|
||||
}
|
||||
}
|
||||
}
|
||||
99 crates/ai/src/prompts/generate.rs (new file)
@@ -0,0 +1,99 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use anyhow::anyhow;
|
||||
use std::fmt::Write;
|
||||
|
||||
pub fn capitalize(s: &str) -> String {
|
||||
let mut c = s.chars();
|
||||
match c.next() {
|
||||
None => String::new(),
|
||||
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerateInlineContent {}
|
||||
|
||||
impl PromptTemplate for GenerateInlineContent {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let Some(user_prompt) = &args.user_prompt else {
|
||||
return Err(anyhow!("user prompt not provided"));
|
||||
};
|
||||
|
||||
let file_type = args.get_file_type();
|
||||
let content_type = match &file_type {
|
||||
PromptFileType::Code => "code",
|
||||
PromptFileType::Text => "text",
|
||||
};
|
||||
|
||||
let mut prompt = String::new();
|
||||
|
||||
if let Some(selected_range) = &args.selected_range {
|
||||
if selected_range.start == selected_range.end {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
capitalize(content_type)
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
|
||||
}
|
||||
} else {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(language_name) = &args.language_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
match file_type {
|
||||
PromptFileType::Code => {
|
||||
// writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Really dumb truncation strategy
|
||||
if let Some(max_tokens) = max_token_length {
|
||||
prompt = args.model.truncate(
|
||||
&prompt,
|
||||
max_tokens,
|
||||
crate::models::TruncationDirection::End,
|
||||
)?;
|
||||
}
|
||||
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
||||
5 crates/ai/src/prompts/mod.rs (new file)
@@ -0,0 +1,5 @@
pub mod base;
pub mod file_context;
pub mod generate;
pub mod preamble;
pub mod repository_context;
52 crates/ai/src/prompts/preamble.rs (new file)
@@ -0,0 +1,52 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
|
||||
pub struct EngineerPreamble {}
|
||||
|
||||
impl PromptTemplate for EngineerPreamble {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
let mut prompts = Vec::new();
|
||||
|
||||
match args.get_file_type() {
|
||||
PromptFileType::Code => {
|
||||
prompts.push(format!(
|
||||
"You are an expert {}engineer.",
|
||||
args.language_name.clone().unwrap_or("".to_string()) + " "
|
||||
));
|
||||
}
|
||||
PromptFileType::Text => {
|
||||
prompts.push("You are an expert engineer.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(project_name) = args.project_name.clone() {
|
||||
prompts.push(format!(
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(mut remaining_tokens) = max_token_length {
|
||||
let mut prompt = String::new();
|
||||
let mut total_count = 0;
|
||||
for prompt_piece in prompts {
|
||||
let prompt_token_count =
|
||||
args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
|
||||
if remaining_tokens > prompt_token_count {
|
||||
writeln!(prompt, "{prompt_piece}").unwrap();
|
||||
remaining_tokens -= prompt_token_count;
|
||||
total_count += prompt_token_count;
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok((prompt, total_count))
|
||||
} else {
|
||||
let prompt = prompts.join("\n");
|
||||
let token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, token_count))
|
||||
}
|
||||
}
|
||||
}
|
||||
96 crates/ai/src/prompts/repository_context.rs (new file)
@@ -0,0 +1,96 @@
|
||||
use crate::prompts::base::{PromptArguments, PromptTemplate};
|
||||
use std::fmt::Write;
|
||||
use std::{ops::Range, path::PathBuf};
|
||||
|
||||
use gpui::{AsyncAppContext, Model};
|
||||
use language::{Anchor, Buffer};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PromptCodeSnippet {
|
||||
path: Option<PathBuf>,
|
||||
language_name: Option<String>,
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl PromptCodeSnippet {
|
||||
pub fn new(
|
||||
buffer: Model<Buffer>,
|
||||
range: Range<Anchor>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
let content = snapshot.text_for_range(range.clone()).collect::<String>();
|
||||
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|language| language.name().to_string().to_lowercase());
|
||||
|
||||
let file_path = buffer.file().map(|file| file.path().to_path_buf());
|
||||
|
||||
(content, language_name, file_path)
|
||||
})?;
|
||||
|
||||
anyhow::Ok(PromptCodeSnippet {
|
||||
path: file_path,
|
||||
language_name,
|
||||
content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ToString for PromptCodeSnippet {
|
||||
fn to_string(&self) -> String {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
.unwrap_or("".to_string());
|
||||
let language_name = self.language_name.clone().unwrap_or("".to_string());
|
||||
let content = self.content.clone();
|
||||
|
||||
format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RepositoryContext {}
|
||||
|
||||
impl PromptTemplate for RepositoryContext {
|
||||
fn generate(
|
||||
&self,
|
||||
args: &PromptArguments,
|
||||
max_token_length: Option<usize>,
|
||||
) -> anyhow::Result<(String, usize)> {
|
||||
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||
let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
|
||||
let mut prompt = String::new();
|
||||
|
||||
let mut remaining_tokens = max_token_length;
|
||||
let separator_token_length = args.model.count_tokens("\n")?;
|
||||
for snippet in &args.snippets {
|
||||
let mut snippet_prompt = template.to_string();
|
||||
let content = snippet.to_string();
|
||||
writeln!(snippet_prompt, "{content}").unwrap();
|
||||
|
||||
let token_count = args.model.count_tokens(&snippet_prompt)?;
|
||||
if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||
if let Some(tokens_left) = remaining_tokens {
|
||||
if tokens_left >= token_count {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
remaining_tokens = if tokens_left >= (token_count + separator_token_length)
|
||||
{
|
||||
Some(tokens_left - token_count - separator_token_length)
|
||||
} else {
|
||||
Some(0)
|
||||
};
|
||||
}
|
||||
} else {
|
||||
writeln!(prompt, "{snippet_prompt}").unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_token_count = args.model.count_tokens(&prompt)?;
|
||||
anyhow::Ok((prompt, total_token_count))
|
||||
}
|
||||
}
|
||||
1 crates/ai/src/providers.rs (new file)
@@ -0,0 +1 @@
pub mod open_ai;
9 crates/ai/src/providers/open_ai.rs (new file)
@@ -0,0 +1,9 @@
pub mod completion;
pub mod embedding;
pub mod model;

pub use completion::*;
pub use embedding::*;
pub use model::OpenAiLanguageModel;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
421 crates/ai/src/providers/open_ai/completion.rs (new file)
@@ -0,0 +1,421 @@
|
||||
use std::{
|
||||
env,
|
||||
fmt::{self, Display},
|
||||
io,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{
|
||||
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
|
||||
Stream, StreamExt,
|
||||
};
|
||||
use gpui::{AppContext, BackgroundExecutor};
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use parking_lot::RwLock;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
|
||||
use crate::{
|
||||
auth::{CredentialProvider, ProviderCredential},
|
||||
completion::{CompletionProvider, CompletionRequest},
|
||||
models::LanguageModel,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "User"),
|
||||
Role::Assistant => write!(f, "Assistant"),
|
||||
Role::System => write!(f, "System"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct OpenAiRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl CompletionRequest for OpenAiRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAiUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct OpenAiResponseStreamEvent {
|
||||
pub id: Option<String>,
|
||||
pub object: String,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoiceDelta>,
|
||||
pub usage: Option<OpenAiUsage>,
|
||||
}
|
||||
|
||||
async fn stream_completion(
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
credential: ProviderCredential,
|
||||
executor: BackgroundExecutor,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
|
||||
let api_key = match credential {
|
||||
ProviderCredential::Credentials { api_key } => api_key,
|
||||
_ => {
|
||||
return Err(anyhow!("no credentials provider for completion"));
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
||||
|
||||
let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(kind.completions_endpoint_url(&api_url))
|
||||
.header("Content-Type", "application/json")
|
||||
.header(auth_header_name, auth_header_value)
|
||||
.body(json_data)?
|
||||
.send_async()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status == StatusCode::OK {
|
||||
executor
|
||||
.spawn(async move {
|
||||
let mut lines = BufReader::new(response.body_mut()).lines();
|
||||
|
||||
fn parse_line(
|
||||
line: Result<String, io::Error>,
|
||||
) -> Result<Option<OpenAiResponseStreamEvent>> {
|
||||
if let Some(data) = line?.strip_prefix("data: ") {
|
||||
let event = serde_json::from_str(data)?;
|
||||
Ok(Some(event))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(line) = lines.next().await {
|
||||
if let Some(event) = parse_line(line).transpose() {
|
||||
let done = event.as_ref().map_or(false, |event| {
|
||||
event
|
||||
.choices
|
||||
.last()
|
||||
.map_or(false, |choice| choice.finish_reason.is_some())
|
||||
});
|
||||
if tx.unbounded_send(event).is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
error: OpenAiError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenAiResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenAI API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
|
||||
pub enum AzureOpenAiApiVersion {
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-03-15-preview")]
|
||||
V2023_03_15Preview,
|
||||
#[serde(rename = "2023-05-15")]
|
||||
V2023_05_15,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-06-01-preview")]
|
||||
V2023_06_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-07-01-preview")]
|
||||
V2023_07_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-08-01-preview")]
|
||||
V2023_08_01Preview,
|
||||
/// Retiring April 2, 2024.
|
||||
#[serde(rename = "2023-09-01-preview")]
|
||||
V2023_09_01Preview,
|
||||
#[serde(rename = "2023-12-01-preview")]
|
||||
V2023_12_01Preview,
|
||||
#[serde(rename = "2024-02-15-preview")]
|
||||
V2024_02_15Preview,
|
||||
}
|
||||
|
||||
impl fmt::Display for AzureOpenAiApiVersion {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
Self::V2023_03_15Preview => "2023-03-15-preview",
|
||||
Self::V2023_05_15 => "2023-05-15",
|
||||
Self::V2023_06_01Preview => "2023-06-01-preview",
|
||||
Self::V2023_07_01Preview => "2023-07-01-preview",
|
||||
Self::V2023_08_01Preview => "2023-08-01-preview",
|
||||
Self::V2023_09_01Preview => "2023-09-01-preview",
|
||||
Self::V2023_12_01Preview => "2023-12-01-preview",
|
||||
Self::V2024_02_15Preview => "2024-02-15-preview",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum OpenAiCompletionProviderKind {
|
||||
OpenAi,
|
||||
AzureOpenAi {
|
||||
deployment_id: String,
|
||||
api_version: AzureOpenAiApiVersion,
|
||||
},
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProviderKind {
|
||||
/// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
|
||||
fn completions_endpoint_url(&self, api_url: &str) -> String {
|
||||
match self {
|
||||
Self::OpenAi => {
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
format!("{api_url}/chat/completions")
|
||||
}
|
||||
Self::AzureOpenAi {
|
||||
deployment_id,
|
||||
api_version,
|
||||
} => {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
|
||||
fn auth_header(&self, api_key: String) -> (&'static str, String) {
|
||||
match self {
|
||||
Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
|
||||
Self::AzureOpenAi { .. } => ("Api-Key", api_key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
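To make the difference between the two kinds concrete, here is a hedged sketch of the URLs and headers the methods above produce. It assumes it lives in the same module (both methods are private), and the Azure resource URL, deployment id, and keys are made up for illustration:

```rust
fn provider_kind_demo() {
    let open_ai = OpenAiCompletionProviderKind::OpenAi;
    assert_eq!(
        open_ai.completions_endpoint_url(OPEN_AI_API_URL),
        "https://api.openai.com/v1/chat/completions"
    );
    assert_eq!(
        open_ai.auth_header("sk-example".to_string()),
        ("Authorization", "Bearer sk-example".to_string())
    );

    let azure = OpenAiCompletionProviderKind::AzureOpenAi {
        // Hypothetical deployment id and resource URL.
        deployment_id: "my-gpt4-deployment".to_string(),
        api_version: AzureOpenAiApiVersion::V2023_05_15,
    };
    assert_eq!(
        azure.completions_endpoint_url("https://example-resource.openai.azure.com"),
        "https://example-resource.openai.azure.com/openai/deployments/my-gpt4-deployment/chat/completions?api-version=2023-05-15"
    );
    assert_eq!(
        azure.auth_header("azure-key".to_string()),
        ("Api-Key", "azure-key".to_string())
    );
}
```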
#[derive(Clone)]
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
model: OpenAiLanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub async fn new(
|
||||
api_url: String,
|
||||
kind: OpenAiCompletionProviderKind,
|
||||
model_name: String,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let model = executor
|
||||
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
||||
.await;
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
api_url,
|
||||
kind,
|
||||
model,
|
||||
credential,
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAiCompletionProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
let retrieved_credential = match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return async move { existing_credential }.boxed()
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||
} else {
|
||||
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
ProviderCredential::Credentials { api_key }
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
async move {
|
||||
let retrieved_credential = retrieved_credential.await;
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn save_credentials(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
credential: ProviderCredential,
|
||||
) -> BoxFuture<()> {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
let write_credentials = match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
async move {
|
||||
if let Some(write_credentials) = write_credentials {
|
||||
write_credentials.await.log_err();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
delete_credentials.await.log_err();
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for OpenAiCompletionProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: Box<dyn CompletionRequest>,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
// Currently the CompletionRequest for OpenAI includes a 'model' parameter.
// This means the model is determined by the CompletionRequest rather than by the
// CompletionProvider, even though the provider also holds a language model.
// At some point in the future we should rectify this.
|
||||
let credential = self.credential.read().clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let kind = self.kind.clone();
|
||||
let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn box_clone(&self) -> Box<dyn CompletionProvider> {
|
||||
Box::new((*self).clone())
|
||||
}
|
||||
}
|
||||
345 crates/ai/src/providers/open_ai/embedding.rs (new file)
@@ -0,0 +1,345 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::AsyncReadExt;
|
||||
use futures::FutureExt;
|
||||
use gpui::AppContext;
|
||||
use gpui::BackgroundExecutor;
|
||||
use isahc::http::StatusCode;
|
||||
use isahc::prelude::Configurable;
|
||||
use isahc::{AsyncBody, Response};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parse_duration::parse;
|
||||
use postage::watch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use std::env;
|
||||
use std::ops::Add;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tiktoken_rs::{cl100k_base, CoreBPE};
|
||||
use util::http::{HttpClient, Request};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::auth::{CredentialProvider, ProviderCredential};
|
||||
use crate::embedding::{Embedding, EmbeddingProvider};
|
||||
use crate::models::LanguageModel;
|
||||
use crate::providers::open_ai::OpenAiLanguageModel;
|
||||
|
||||
use crate::providers::open_ai::OPEN_AI_API_URL;
|
||||
|
||||
pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
|
||||
static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
|
||||
OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
|
||||
}
|
||||
|
||||
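The shared `cl100k_base` tokenizer above is what the OpenAI providers use for token counting. A hedged sketch of how it might be used, assuming the usual `tiktoken-rs` `encode_with_special_tokens` API (the helper name is hypothetical and not part of this diff):

```rust
// Count tokens the way the OpenAI models do: encode with the shared BPE and
// take the length of the resulting id list.
fn count_open_ai_tokens(content: &str) -> usize {
    open_ai_bpe_tokenizer()
        .encode_with_special_tokens(content)
        .len()
}
```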
#[derive(Clone)]
|
||||
pub struct OpenAiEmbeddingProvider {
|
||||
api_url: String,
|
||||
model: OpenAiLanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
pub executor: BackgroundExecutor,
|
||||
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
|
||||
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct OpenAiEmbeddingRequest<'a> {
|
||||
model: &'static str,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiEmbeddingResponse {
|
||||
data: Vec<OpenAiEmbedding>,
|
||||
usage: OpenAiEmbeddingUsage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbedding {
|
||||
embedding: Vec<f32>,
|
||||
index: usize,
|
||||
object: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenAiEmbeddingUsage {
|
||||
prompt_tokens: usize,
|
||||
total_tokens: usize,
|
||||
}
|
||||
|
||||
impl OpenAiEmbeddingProvider {
|
||||
pub async fn new(
|
||||
api_url: String,
|
||||
client: Arc<dyn HttpClient>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
// Loading the model is expensive, so ensure this runs off the main thread.
|
||||
let model = executor
|
||||
.spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
|
||||
.await;
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
|
||||
OpenAiEmbeddingProvider {
|
||||
api_url,
|
||||
model,
|
||||
credential,
|
||||
client,
|
||||
executor,
|
||||
rate_limit_count_rx,
|
||||
rate_limit_count_tx,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_api_key(&self) -> Result<String> {
|
||||
match self.credential.read().clone() {
|
||||
ProviderCredential::Credentials { api_key } => Ok(api_key),
|
||||
_ => Err(anyhow!("api credentials not provided")),
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_rate_limit(&self) {
|
||||
let reset_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
if let Some(reset_time) = reset_time {
|
||||
if Instant::now() >= reset_time {
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = None
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!(
|
||||
"resolving reset time: {:?}",
|
||||
*self.rate_limit_count_tx.lock().borrow()
|
||||
);
|
||||
}
|
||||
|
||||
fn update_reset_time(&self, reset_time: Instant) {
|
||||
let original_time = *self.rate_limit_count_tx.lock().borrow();
|
||||
|
||||
let updated_time = if let Some(original_time) = original_time {
|
||||
if reset_time < original_time {
|
||||
Some(reset_time)
|
||||
} else {
|
||||
Some(original_time)
|
||||
}
|
||||
} else {
|
||||
Some(reset_time)
|
||||
};
|
||||
|
||||
log::trace!("updating rate limit time: {:?}", updated_time);
|
||||
|
||||
*self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
|
||||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post(format!("{api_url}/embeddings"))
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(
|
||||
serde_json::to_string(&OpenAiEmbeddingRequest {
|
||||
input: spans.clone(),
|
||||
model: "text-embedding-ada-002",
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
Ok(self.client.send(request).await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialProvider for OpenAiEmbeddingProvider {
|
||||
fn has_credentials(&self) -> bool {
|
||||
match *self.credential.read() {
|
||||
ProviderCredential::Credentials { .. } => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
|
||||
let existing_credential = self.credential.read().clone();
|
||||
let retrieved_credential = match existing_credential {
|
||||
ProviderCredential::Credentials { .. } => {
|
||||
return async move { existing_credential }.boxed()
|
||||
}
|
||||
_ => {
|
||||
if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
|
||||
async move { ProviderCredential::Credentials { api_key } }.boxed()
|
||||
} else {
|
||||
let credentials = cx.read_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
if let Some(Some((_, api_key))) = credentials.await.log_err() {
|
||||
if let Some(api_key) = String::from_utf8(api_key).log_err() {
|
||||
ProviderCredential::Credentials { api_key }
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
} else {
|
||||
ProviderCredential::NoCredentials
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
async move {
|
||||
let retrieved_credential = retrieved_credential.await;
|
||||
*self.credential.write() = retrieved_credential.clone();
|
||||
retrieved_credential
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn save_credentials(
|
||||
&self,
|
||||
cx: &mut AppContext,
|
||||
credential: ProviderCredential,
|
||||
) -> BoxFuture<()> {
|
||||
*self.credential.write() = credential.clone();
|
||||
let credential = credential.clone();
|
||||
let write_credentials = match credential {
|
||||
ProviderCredential::Credentials { api_key } => {
|
||||
Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
async move {
|
||||
if let Some(write_credentials) = write_credentials {
|
||||
write_credentials.await.log_err();
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
|
||||
*self.credential.write() = ProviderCredential::NoCredentials;
|
||||
let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
|
||||
async move {
|
||||
delete_credentials.await.log_err();
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl EmbeddingProvider for OpenAiEmbeddingProvider {
|
||||
fn base_model(&self) -> Box<dyn LanguageModel> {
|
||||
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
|
||||
model
|
||||
}
|
||||
|
||||
fn max_tokens_per_batch(&self) -> usize {
|
||||
50000
|
||||
}
|
||||
|
||||
fn rate_limit_expiration(&self) -> Option<Instant> {
|
||||
*self.rate_limit_count_rx.borrow()
|
||||
}
|
||||
|
||||
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
|
||||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let api_url = self.api_url.as_str();
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut request_number = 0;
|
||||
let mut rate_limiting = false;
|
||||
let mut request_timeout: u64 = 15;
|
||||
let mut response: Response<AsyncBody>;
|
||||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
&api_url,
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
)
|
||||
.await?;
|
||||
|
||||
request_number += 1;
|
||||
|
||||
match response.status() {
|
||||
StatusCode::REQUEST_TIMEOUT => {
|
||||
request_timeout += 5;
|
||||
}
|
||||
StatusCode::OK => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::trace!(
|
||||
"openai embedding completed. tokens: {:?}",
|
||||
response.usage.total_tokens
|
||||
);
|
||||
|
||||
// If we complete a request successfully that was previously rate_limited
|
||||
// resolve the rate limit
|
||||
if rate_limiting {
|
||||
self.resolve_rate_limit()
|
||||
}
|
||||
|
||||
return Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|embedding| Embedding::from(embedding.embedding))
|
||||
.collect());
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
rate_limiting = true;
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
let delay_duration = {
|
||||
let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
|
||||
if let Some(time_to_reset) =
|
||||
response.headers().get("x-ratelimit-reset-tokens")
|
||||
{
|
||||
if let Ok(time_str) = time_to_reset.to_str() {
|
||||
parse(time_str).unwrap_or(delay)
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
} else {
|
||||
delay
|
||||
}
|
||||
};
|
||||
|
||||
// If we've previously rate limited, increment the duration but not the count
|
||||
let reset_time = Instant::now().add(delay_duration);
|
||||
self.update_reset_time(reset_time);
|
||||
|
||||
log::trace!(
|
||||
"openai rate limiting: waiting {:?} until lifted",
|
||||
&delay_duration
|
||||
);
|
||||
|
||||
self.executor.timer(delay_duration).await;
|
||||
}
|
||||
_ => {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"open ai bad request: {:?} {:?}",
|
||||
&response.status(),
|
||||
body
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(anyhow!("openai max retries"))
|
||||
}
|
||||
}
|
||||
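Reviewer note on the retry loop above: `embed_batch` falls back to the `BACKOFF_SECONDS` table unless the `x-ratelimit-reset-tokens` response header supplies an explicit delay. A minimal, self-contained sketch of that fallback order (the `parse_reset` helper is a hypothetical stand-in for the `parse` call used in the diff):

```rust
use std::time::Duration;

// Illustrative only: mirrors the fallback order used by `embed_batch` above.
// `header_value` stands in for the `x-ratelimit-reset-tokens` header and
// `parse_reset` for the duration parser the diff calls `parse` (assumed here).
fn backoff_delay(request_number: usize, header_value: Option<&str>) -> Duration {
    const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
    let default = Duration::from_secs(BACKOFF_SECONDS[request_number - 1]);
    header_value.and_then(parse_reset).unwrap_or(default)
}

fn parse_reset(value: &str) -> Option<Duration> {
    // Hypothetical parser: treat the header as a whole number of seconds, e.g. "30s".
    value.trim_end_matches('s').parse::<u64>().ok().map(Duration::from_secs)
}

fn main() {
    assert_eq!(backoff_delay(2, None), Duration::from_secs(5));
    assert_eq!(backoff_delay(1, Some("30s")), Duration::from_secs(30));
}
```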
59 crates/ai/src/providers/open_ai/model.rs (new file)
@@ -0,0 +1,59 @@
use anyhow::anyhow;
use tiktoken_rs::CoreBPE;

use crate::models::{LanguageModel, TruncationDirection};

use super::open_ai_bpe_tokenizer;

#[derive(Clone)]
pub struct OpenAiLanguageModel {
    name: String,
    bpe: Option<CoreBPE>,
}

impl OpenAiLanguageModel {
    pub fn load(model_name: &str) -> Self {
        let bpe = tiktoken_rs::get_bpe_from_model(model_name)
            .unwrap_or(open_ai_bpe_tokenizer().to_owned());
        OpenAiLanguageModel {
            name: model_name.to_string(),
            bpe: Some(bpe),
        }
    }
}

impl LanguageModel for OpenAiLanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                match direction {
                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
                }
            } else {
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("bpe for open ai model was not retrieved"))
        }
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}
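For orientation, `capacity()` above defers to `tiktoken_rs::model::get_context_size`, so context windows track the model name rather than being hard-coded. A tiny standalone check of that lookup (assumes only the `tiktoken-rs` dependency already used by this file):

```rust
// Standalone sketch: the same context-size lookup OpenAiLanguageModel::capacity uses.
fn main() {
    let size = tiktoken_rs::model::get_context_size("gpt-4");
    println!("gpt-4 context size: {size} tokens");
    assert!(size > 0);
}
```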
206 crates/ai/src/test.rs (new file)
@@ -0,0 +1,206 @@
use std::{
    sync::atomic::{self, AtomicUsize, Ordering},
    time::Instant,
};

use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    embedding::{Embedding, EmbeddingProvider},
    models::{LanguageModel, TruncationDirection},
};

#[derive(Clone)]
pub struct FakeLanguageModel {
    pub capacity: usize,
}

impl LanguageModel for FakeLanguageModel {
    fn name(&self) -> String {
        "dummy".to_string()
    }
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
    }
    fn truncate(
        &self,
        content: &str,
        length: usize,
        direction: TruncationDirection,
    ) -> anyhow::Result<String> {
        println!("TRYING TO TRUNCATE: {:?}", length.clone());

        if length > self.count_tokens(content)? {
            println!("NOT TRUNCATING");
            return anyhow::Ok(content.to_string());
        }

        anyhow::Ok(match direction {
            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                .into_iter()
                .collect::<String>(),
            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                .into_iter()
                .collect::<String>(),
        })
    }
    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(self.capacity)
    }
}

#[derive(Default)]
pub struct FakeEmbeddingProvider {
    pub embedding_count: AtomicUsize,
}

impl Clone for FakeEmbeddingProvider {
    fn clone(&self) -> Self {
        FakeEmbeddingProvider {
            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
        }
    }
}

impl FakeEmbeddingProvider {
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(atomic::Ordering::SeqCst)
    }

    pub fn embed_sync(&self, span: &str) -> Embedding {
        let mut result = vec![1.0; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if letter as u32 >= 'a' as u32 {
                let ix = (letter as u32) - ('a' as u32);
                if ix < 26 {
                    result[ix as usize] += 1.0;
                }
            }
        }

        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut result {
            *x /= norm;
        }

        result.into()
    }
}

impl CredentialProvider for FakeEmbeddingProvider {
    fn has_credentials(&self) -> bool {
        true
    }

    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        async { ProviderCredential::NotNeeded }.boxed()
    }

    fn save_credentials(
        &self,
        _cx: &mut AppContext,
        _credential: ProviderCredential,
    ) -> BoxFuture<()> {
        async {}.boxed()
    }

    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
        async {}.boxed()
    }
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        Box::new(FakeLanguageModel { capacity: 1000 })
    }
    fn max_tokens_per_batch(&self) -> usize {
        1000
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }

    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), atomic::Ordering::SeqCst);

        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
    }
}

pub struct FakeCompletionProvider {
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}

impl Clone for FakeCompletionProvider {
    fn clone(&self) -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }
}

impl FakeCompletionProvider {
    pub fn new() -> Self {
        Self {
            last_completion_tx: Mutex::new(None),
        }
    }

    pub fn send_completion(&self, completion: impl Into<String>) {
        let mut tx = self.last_completion_tx.lock();
        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
    }

    pub fn finish_completion(&self) {
        self.last_completion_tx.lock().take().unwrap();
    }
}

impl CredentialProvider for FakeCompletionProvider {
    fn has_credentials(&self) -> bool {
        true
    }

    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
        async { ProviderCredential::NotNeeded }.boxed()
    }

    fn save_credentials(
        &self,
        _cx: &mut AppContext,
        _credential: ProviderCredential,
    ) -> BoxFuture<()> {
        async {}.boxed()
    }

    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
        async {}.boxed()
    }
}

impl CompletionProvider for FakeCompletionProvider {
    fn base_model(&self) -> Box<dyn LanguageModel> {
        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
        model
    }
    fn complete(
        &self,
        _prompt: Box<dyn CompletionRequest>,
    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
        let (tx, rx) = mpsc::channel(1);
        *self.last_completion_tx.lock() = Some(tx);
        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
    }
    fn box_clone(&self) -> Box<dyn CompletionProvider> {
        Box::new((*self).clone())
    }
}
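For reference, a rough sketch of how the `FakeCompletionProvider` above is meant to be driven from a test, following the pattern of the codegen tests later in this change (the function name is illustrative; assumes the `ai` crate's test module is available, e.g. via its `test-support` feature):

```rust
use std::sync::Arc;

use ai::completion::{CompletionProvider, CompletionRequest};
use ai::test::FakeCompletionProvider;

// Illustrative helper: register a completion stream, push one chunk, then end it.
fn drive_fake_completion(request: Box<dyn CompletionRequest>) {
    let provider = Arc::new(FakeCompletionProvider::new());
    let _stream = provider.complete(request); // stores the sender before returning
    provider.send_completion("fn main() {}"); // chunk arrives on the returned stream
    provider.finish_completion(); // dropping the sender closes the stream
}
```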
@@ -1,22 +0,0 @@
[package]
name = "anthropic"
version = "0.1.0"
edition = "2021"
publish = false
license = "AGPL-3.0-or-later"

[lib]
path = "src/anthropic.rs"

[dependencies]
anyhow.workspace = true
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
util.workspace = true

[dev-dependencies]
tokio.workspace = true

[lints]
workspace = true
@@ -1 +0,0 @@
../../LICENSE-AGPL
@@ -1,234 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::TryFrom;
|
||||
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(rename = "claude-3-opus-20240229")]
|
||||
Claude3Opus,
|
||||
#[serde(rename = "claude-3-sonnet-20240229")]
|
||||
Claude3Sonnet,
|
||||
#[serde(rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
if id.starts_with("claude-3-opus") {
|
||||
Ok(Self::Claude3Opus)
|
||||
} else if id.starts_with("claude-3-sonnet") {
|
||||
Ok(Self::Claude3Sonnet)
|
||||
} else if id.starts_with("claude-3-haiku") {
|
||||
Ok(Self::Claude3Haiku)
|
||||
} else {
|
||||
Err(anyhow!("Invalid model id: {}", id))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
200_000
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct Request {
|
||||
pub model: Model,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
pub system: String,
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct RequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ResponseEvent {
|
||||
MessageStart {
|
||||
message: ResponseMessage,
|
||||
},
|
||||
ContentBlockStart {
|
||||
index: u32,
|
||||
content_block: ContentBlock,
|
||||
},
|
||||
Ping {},
|
||||
ContentBlockDelta {
|
||||
index: u32,
|
||||
delta: TextDelta,
|
||||
},
|
||||
ContentBlockStop {
|
||||
index: u32,
|
||||
},
|
||||
MessageDelta {
|
||||
delta: ResponseMessage,
|
||||
usage: Usage,
|
||||
},
|
||||
MessageStop {},
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ResponseMessage {
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: Option<String>,
|
||||
pub id: Option<String>,
|
||||
pub role: Option<String>,
|
||||
pub content: Option<Vec<String>>,
|
||||
pub model: Option<String>,
|
||||
pub stop_reason: Option<String>,
|
||||
pub stop_sequence: Option<String>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<u32>,
|
||||
pub output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
Text { text: String },
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum TextDelta {
|
||||
TextDelta { text: String },
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("Anthropic-Beta", "messages-2023-12-15")
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
match serde_json::from_str(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
match serde_json::from_str::<ResponseEvent>(body_str) {
|
||||
Ok(_) => Err(anyhow!(
|
||||
"Unexpected success response while expecting an error: {}",
|
||||
body_str,
|
||||
)),
|
||||
Err(_) => Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use util::http::IsahcHttpClient;
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stream_completion_success() {
|
||||
// let http_client = IsahcHttpClient::new().unwrap();
|
||||
|
||||
// let request = Request {
|
||||
// model: Model::Claude3Opus,
|
||||
// messages: vec![RequestMessage {
|
||||
// role: Role::User,
|
||||
// content: "Ping".to_string(),
|
||||
// }],
|
||||
// stream: true,
|
||||
// system: "Respond to ping with pong".to_string(),
|
||||
// max_tokens: 4096,
|
||||
// };
|
||||
|
||||
// let stream = stream_completion(
|
||||
// &http_client,
|
||||
// "https://api.anthropic.com",
|
||||
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
|
||||
// request,
|
||||
// )
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
// stream
|
||||
// .for_each(|event| async {
|
||||
// match event {
|
||||
// Ok(event) => println!("{:?}", event),
|
||||
// Err(e) => eprintln!("Error: {:?}", e),
|
||||
// }
|
||||
// })
|
||||
// .await;
|
||||
// }
|
||||
// }
|
||||
@@ -5,18 +5,19 @@ edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant.rs"
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
ai.workspace = true
|
||||
anyhow.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
editor.workspace = true
|
||||
file_icons.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
@@ -25,13 +26,12 @@ language.workspace = true
|
||||
log.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
search.workspace = true
|
||||
semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
@@ -45,6 +45,7 @@ uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ai = { workspace = true, features = ["test-support"] }
|
||||
ctor.workspace = true
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -1,27 +1,22 @@
|
||||
pub mod assistant_panel;
|
||||
pub mod assistant_settings;
|
||||
mod codegen;
|
||||
mod completion_provider;
|
||||
mod prompts;
|
||||
mod saved_conversation;
|
||||
mod streaming_diff;
|
||||
|
||||
mod embedded_scope;
|
||||
|
||||
use ai::providers::open_ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
|
||||
use assistant_settings::OpenAiModel;
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub(crate) use completion_provider::*;
|
||||
use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
|
||||
pub(crate) use saved_conversation::*;
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use gpui::{actions, AppContext, SharedString};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::{
|
||||
fmt::{self, Display},
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
|
||||
actions!(
|
||||
assistant,
|
||||
@@ -35,6 +30,7 @@ actions!(
|
||||
ResetKey,
|
||||
InlineAssist,
|
||||
ToggleIncludeConversation,
|
||||
ToggleRetrieveContext,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -43,134 +39,6 @@ actions!(
|
||||
)]
|
||||
struct MessageId(usize);
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn cycle(&mut self) {
|
||||
*self = match self {
|
||||
Role::User => Role::Assistant,
|
||||
Role::Assistant => Role::System,
|
||||
Role::System => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "user"),
|
||||
Role::Assistant => write!(f, "assistant"),
|
||||
Role::System => write!(f, "system"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum LanguageModel {
|
||||
ZedDotDev(ZedDotDevModel),
|
||||
OpenAi(OpenAiModel),
|
||||
}
|
||||
|
||||
impl Default for LanguageModel {
|
||||
fn default() -> Self {
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel {
|
||||
pub fn telemetry_id(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
||||
LanguageModel::ZedDotDev(model) => model.max_token_count(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.id(),
|
||||
LanguageModel::ZedDotDev(model) => model.id(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelRequestMessage {
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl LanguageModelRequestMessage {
|
||||
pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
|
||||
proto::LanguageModelRequestMessage {
|
||||
role: match self.role {
|
||||
Role::User => proto::LanguageModelRole::LanguageModelUser,
|
||||
Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
|
||||
Role::System => proto::LanguageModelRole::LanguageModelSystem,
|
||||
} as i32,
|
||||
content: self.content.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize)]
|
||||
pub struct LanguageModelRequest {
|
||||
pub model: LanguageModel,
|
||||
pub messages: Vec<LanguageModelRequestMessage>,
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
}
|
||||
|
||||
impl LanguageModelRequest {
|
||||
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
|
||||
proto::CompleteWithLanguageModel {
|
||||
model: self.model.id().to_string(),
|
||||
messages: self.messages.iter().map(|m| m.to_proto()).collect(),
|
||||
stop: self.stop.clone(),
|
||||
temperature: self.temperature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct LanguageModelResponseMessage {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LanguageModelUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LanguageModelChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: LanguageModelResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct MessageMetadata {
|
||||
role: Role,
|
||||
@@ -185,61 +53,72 @@ enum MessageStatus {
|
||||
Error(SharedString),
|
||||
}
|
||||
|
||||
/// The state pertaining to the Assistant.
|
||||
#[derive(Default)]
|
||||
struct Assistant {
|
||||
/// Whether the Assistant is enabled.
|
||||
enabled: bool,
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedMessage {
|
||||
id: MessageId,
|
||||
start: usize,
|
||||
}
|
||||
|
||||
impl Global for Assistant {}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedConversation {
|
||||
id: Option<String>,
|
||||
zed: String,
|
||||
version: String,
|
||||
text: String,
|
||||
messages: Vec<SavedMessage>,
|
||||
message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: String,
|
||||
api_url: Option<String>,
|
||||
model: OpenAiModel,
|
||||
}
|
||||
|
||||
impl Assistant {
|
||||
const NAMESPACE: &'static str = "assistant";
|
||||
impl SavedConversation {
|
||||
const VERSION: &'static str = "0.1.0";
|
||||
}
|
||||
|
||||
fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
|
||||
if self.enabled == enabled {
|
||||
return;
|
||||
struct SavedConversationMetadata {
|
||||
title: String,
|
||||
path: PathBuf,
|
||||
mtime: chrono::DateTime<chrono::Local>,
|
||||
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::<SavedConversationMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
self.enabled = enabled;
|
||||
|
||||
if !enabled {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(Self::NAMESPACE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.show_namespace(Self::NAMESPACE);
|
||||
});
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
cx.set_global(Assistant::default());
|
||||
AssistantSettings::register(cx);
|
||||
completion_provider::init(client, cx);
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
assistant_panel::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(Assistant::NAMESPACE);
|
||||
});
|
||||
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
});
|
||||
cx.observe_global::<SettingsStore>(|cx| {
|
||||
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,305 +1,169 @@
|
||||
use std::fmt;
|
||||
|
||||
use ai::providers::open_ai::{
|
||||
AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use gpui::Pixels;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
schema::{InstanceType, Metadata, Schema, SchemaObject},
|
||||
JsonSchema,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, Visitor},
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub enum ZedDotDevModel {
|
||||
Gpt3Point5Turbo,
|
||||
Gpt4,
|
||||
#[default]
|
||||
Gpt4Turbo,
|
||||
Claude3Opus,
|
||||
Claude3Sonnet,
|
||||
Claude3Haiku,
|
||||
Custom(String),
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum OpenAiModel {
|
||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
||||
ThreePointFiveTurbo,
|
||||
#[serde(rename = "gpt-4-0613")]
|
||||
Four,
|
||||
#[serde(rename = "gpt-4-1106-preview")]
|
||||
FourTurbo,
|
||||
}
|
||||
|
||||
impl Serialize for ZedDotDevModel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.id())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ZedDotDevModel {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct ZedDotDevModelVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
|
||||
type Value = ZedDotDevModel;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
match value {
|
||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
|
||||
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
|
||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
|
||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(ZedDotDevModelVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for ZedDotDevModel {
|
||||
fn schema_name() -> String {
|
||||
"ZedDotDevModel".to_owned()
|
||||
}
|
||||
|
||||
fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
let variants = vec![
|
||||
"gpt-3.5-turbo".to_owned(),
|
||||
"gpt-4".to_owned(),
|
||||
"gpt-4-turbo-preview".to_owned(),
|
||||
];
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
|
||||
metadata: Some(Box::new(Metadata {
|
||||
title: Some("ZedDotDevModel".to_owned()),
|
||||
default: Some(serde_json::json!("gpt-4-turbo-preview")),
|
||||
examples: vec![
|
||||
serde_json::json!("gpt-3.5-turbo"),
|
||||
serde_json::json!("gpt-4"),
|
||||
serde_json::json!("gpt-4-turbo-preview"),
|
||||
serde_json::json!("custom-model-name"),
|
||||
],
|
||||
..Default::default()
|
||||
})),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ZedDotDevModel {
|
||||
pub fn id(&self) -> &str {
|
||||
impl OpenAiModel {
|
||||
pub fn full_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4Turbo => "gpt-4-turbo-preview",
|
||||
Self::Claude3Opus => "claude-3-opus",
|
||||
Self::Claude3Sonnet => "claude-3-sonnet",
|
||||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
Self::Custom(id) => id,
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
|
||||
Self::Four => "gpt-4-0613",
|
||||
Self::FourTurbo => "gpt-4-1106-preview",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
pub fn short_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
|
||||
Self::Gpt4 => "GPT 4",
|
||||
Self::Gpt4Turbo => "GPT 4 Turbo",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
Self::Custom(id) => id.as_str(),
|
||||
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
Self::Four => "gpt-4",
|
||||
Self::FourTurbo => "gpt-4-turbo",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
pub fn cycle(&self) -> Self {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => 2048,
|
||||
Self::Gpt4 => 4096,
|
||||
Self::Gpt4Turbo => 128000,
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
Self::ThreePointFiveTurbo => Self::Four,
|
||||
Self::Four => Self::FourTurbo,
|
||||
Self::FourTurbo => Self::ThreePointFiveTurbo,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AssistantDockPosition {
|
||||
Left,
|
||||
#[default]
|
||||
Right,
|
||||
Bottom,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(tag = "name", rename_all = "snake_case")]
|
||||
pub enum AssistantProvider {
|
||||
#[serde(rename = "zed.dev")]
|
||||
ZedDotDev {
|
||||
#[serde(default)]
|
||||
default_model: ZedDotDevModel,
|
||||
},
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi {
|
||||
#[serde(default)]
|
||||
default_model: OpenAiModel,
|
||||
#[serde(default = "open_ai_url")]
|
||||
api_url: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for AssistantProvider {
|
||||
fn default() -> Self {
|
||||
Self::ZedDotDev {
|
||||
default_model: ZedDotDevModel::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn open_ai_url() -> String {
|
||||
"https://api.openai.com/v1".into()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Serialize)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AssistantSettings {
|
||||
pub enabled: bool,
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
pub button: bool,
|
||||
/// Where to dock the assistant.
|
||||
pub dock: AssistantDockPosition,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
|
||||
pub default_width: Pixels,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
|
||||
pub default_height: Pixels,
|
||||
pub provider: AssistantProvider,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
#[deprecated = "Please use `provider.default_model` instead."]
|
||||
pub default_open_ai_model: OpenAiModel,
|
||||
/// OpenAI API base URL to use when starting new conversations.
|
||||
#[deprecated = "Please use `provider.api_url` instead."]
|
||||
pub openai_api_url: String,
|
||||
/// The settings for the AI provider.
|
||||
pub provider: AiProviderSettings,
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum AssistantSettingsContent {
|
||||
Versioned(VersionedAssistantSettingsContent),
|
||||
Legacy(LegacyAssistantSettingsContent),
|
||||
}
|
||||
impl AssistantSettings {
|
||||
pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
|
||||
AiProviderSettings::AzureOpenAi(settings) => {
|
||||
let deployment_id = settings
|
||||
.deployment_id
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
|
||||
let api_version = settings
|
||||
.api_version
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
|
||||
|
||||
impl JsonSchema for AssistantSettingsContent {
|
||||
fn schema_name() -> String {
|
||||
VersionedAssistantSettingsContent::schema_name()
|
||||
}
|
||||
|
||||
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
|
||||
VersionedAssistantSettingsContent::json_schema(gen)
|
||||
}
|
||||
|
||||
fn is_referenceable() -> bool {
|
||||
VersionedAssistantSettingsContent::is_referenceable()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::Versioned(VersionedAssistantSettingsContent::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
fn upgrade(&self) -> AssistantSettingsContentV1 {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
|
||||
enabled: None,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
default_width: settings.default_width,
|
||||
default_height: settings.default_height,
|
||||
provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
|
||||
Some(AssistantProvider::OpenAi {
|
||||
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
|
||||
api_url: open_ai_api_url.clone(),
|
||||
})
|
||||
} else {
|
||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: open_ai_model,
|
||||
api_url: open_ai_url(),
|
||||
}
|
||||
})
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match settings {
|
||||
VersionedAssistantSettingsContent::V1(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
Ok(OpenAiCompletionProviderKind::AzureOpenAi {
|
||||
deployment_id,
|
||||
api_version,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[serde(tag = "version")]
|
||||
pub enum VersionedAssistantSettingsContent {
|
||||
#[serde(rename = "1")]
|
||||
V1(AssistantSettingsContentV1),
|
||||
}
|
||||
pub fn provider_api_url(&self) -> anyhow::Result<String> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => Ok(settings
|
||||
.api_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
|
||||
AiProviderSettings::AzureOpenAi(settings) => settings
|
||||
.api_url
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VersionedAssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::V1(AssistantSettingsContentV1 {
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
provider: None,
|
||||
})
|
||||
pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => {
|
||||
Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
|
||||
}
|
||||
AiProviderSettings::AzureOpenAi(settings) => {
|
||||
let deployment_id = settings
|
||||
.deployment_id
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
|
||||
|
||||
match deployment_id {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
|
||||
"gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
|
||||
"gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
|
||||
Ok(OpenAiModel::ThreePointFiveTurbo)
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"no matching OpenAI model found for deployment ID: '{deployment_id}'"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider_model_name(&self) -> anyhow::Result<String> {
|
||||
match &self.provider {
|
||||
AiProviderSettings::OpenAi(settings) => Ok(settings
|
||||
.default_model
|
||||
.unwrap_or(OpenAiModel::FourTurbo)
|
||||
.full_name()
|
||||
.to_string()),
|
||||
AiProviderSettings::AzureOpenAi(settings) => settings
|
||||
.deployment_id
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContentV1 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
/// Default: true
|
||||
enabled: Option<bool>,
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
/// Default: true
|
||||
button: Option<bool>,
|
||||
/// Where to dock the assistant.
|
||||
///
|
||||
/// Default: right
|
||||
dock: Option<AssistantDockPosition>,
|
||||
/// Default width in pixels when the assistant is docked to the left or right.
|
||||
///
|
||||
/// Default: 640
|
||||
default_width: Option<f32>,
|
||||
/// Default height in pixels when the assistant is docked to the bottom.
|
||||
///
|
||||
/// Default: 320
|
||||
default_height: Option<f32>,
|
||||
/// The provider of the assistant service.
|
||||
///
|
||||
/// This can either be the internal `zed.dev` service or an external `openai` service,
|
||||
/// each with their respective default models and configurations.
|
||||
provider: Option<AssistantProvider>,
|
||||
impl Settings for AssistantSettings {
|
||||
const KEY: Option<&'static str> = Some("assistant");
|
||||
|
||||
type FileContent = AssistantSettingsContent;
|
||||
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct LegacyAssistantSettingsContent {
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContent {
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
/// Default: true
|
||||
@@ -316,165 +180,88 @@ pub struct LegacyAssistantSettingsContent {
|
||||
///
|
||||
/// Default: 320
|
||||
pub default_height: Option<f32>,
|
||||
/// Deprecated: Please use `provider.default_model` instead.
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
///
|
||||
/// Default: gpt-4-1106-preview
|
||||
#[deprecated = "Please use `provider.default_model` instead."]
|
||||
pub default_open_ai_model: Option<OpenAiModel>,
|
||||
/// Deprecated: Please use `provider.api_url` instead.
|
||||
/// OpenAI API base URL to use when starting new conversations.
|
||||
///
|
||||
/// Default: https://api.openai.com/v1
|
||||
#[deprecated = "Please use `provider.api_url` instead."]
|
||||
pub openai_api_url: Option<String>,
|
||||
/// The settings for the AI provider.
|
||||
#[serde(default)]
|
||||
pub provider: AiProviderSettingsContent,
|
||||
}
|
||||
|
||||
impl Settings for AssistantSettings {
|
||||
const KEY: Option<&'static str> = Some("assistant");
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AiProviderSettings {
|
||||
/// The settings for the OpenAI provider.
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi(OpenAiProviderSettings),
|
||||
/// The settings for the Azure OpenAI provider.
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAi(AzureOpenAiProviderSettings),
|
||||
}
|
||||
|
||||
type FileContent = AssistantSettingsContent;
|
||||
/// The settings for the AI provider used by the Zed Assistant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AiProviderSettingsContent {
|
||||
/// The settings for the OpenAI provider.
|
||||
#[serde(rename = "openai")]
|
||||
OpenAi(OpenAiProviderSettingsContent),
|
||||
/// The settings for the Azure OpenAI provider.
|
||||
#[serde(rename = "azure_openai")]
|
||||
AzureOpenAi(AzureOpenAiProviderSettingsContent),
|
||||
}
|
||||
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut settings = AssistantSettings::default();
|
||||
|
||||
for value in [default_value].iter().chain(user_values) {
|
||||
let value = value.upgrade();
|
||||
merge(&mut settings.enabled, value.enabled);
|
||||
merge(&mut settings.button, value.button);
|
||||
merge(&mut settings.dock, value.dock);
|
||||
merge(
|
||||
&mut settings.default_width,
|
||||
value.default_width.map(Into::into),
|
||||
);
|
||||
merge(
|
||||
&mut settings.default_height,
|
||||
value.default_height.map(Into::into),
|
||||
);
|
||||
if let Some(provider) = value.provider.clone() {
|
||||
match (&mut settings.provider, provider) {
|
||||
(
|
||||
AssistantProvider::ZedDotDev { default_model },
|
||||
AssistantProvider::ZedDotDev {
|
||||
default_model: default_model_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
}
|
||||
(
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: default_model_override,
|
||||
api_url: api_url_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
*api_url = api_url_override;
|
||||
}
|
||||
(merged, provider_override) => {
|
||||
*merged = provider_override;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
impl Default for AiProviderSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::OpenAi(OpenAiProviderSettingsContent::default())
|
||||
}
|
||||
}
|
||||
|
||||
fn merge<T: Copy>(target: &mut T, value: Option<T>) {
|
||||
if let Some(value) = value {
|
||||
*target = value;
|
||||
}
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiProviderSettings {
|
||||
/// The OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
pub default_model: Option<OpenAiModel>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext, BorrowAppContext};
|
||||
use settings::SettingsStore;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_deserialize_assistant_settings(cx: &mut AppContext) {
|
||||
let store = settings::SettingsStore::test(cx);
|
||||
cx.set_global(store);
|
||||
|
||||
// Settings default to gpt-4-turbo.
|
||||
AssistantSettings::register(cx);
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// Ensure backward-compatibility.
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"openai_api_url": "test-url",
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: "test-url".into()
|
||||
}
|
||||
);
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"default_open_ai_model": "gpt-4-0613"
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::Four,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// The new version supports setting a custom model when using zed.dev.
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
"assistant": {
|
||||
"version": "1",
|
||||
"provider": {
|
||||
"name": "zed.dev",
|
||||
"default_model": "custom"
|
||||
}
|
||||
}
|
||||
}"#,
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
});
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::ZedDotDev {
|
||||
default_model: ZedDotDevModel::Custom("custom".into())
|
||||
}
|
||||
);
|
||||
}
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenAiProviderSettingsContent {
|
||||
/// The OpenAI API base URL to use when starting new conversations.
|
||||
///
|
||||
/// Default: https://api.openai.com/v1
|
||||
pub api_url: Option<String>,
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
///
|
||||
/// Default: gpt-4-1106-preview
|
||||
pub default_model: Option<OpenAiModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AzureOpenAiProviderSettings {
|
||||
/// The Azure OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The Azure OpenAI API version.
|
||||
pub api_version: Option<AzureOpenAiApiVersion>,
|
||||
/// The Azure OpenAI API deployment ID.
|
||||
pub deployment_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AzureOpenAiProviderSettingsContent {
|
||||
/// The Azure OpenAI API base URL to use when starting new conversations.
|
||||
pub api_url: Option<String>,
|
||||
/// The Azure OpenAI API version.
|
||||
pub api_version: Option<AzureOpenAiApiVersion>,
|
||||
/// The Azure OpenAI deployment ID.
|
||||
pub deployment_id: Option<String>,
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use crate::{
|
||||
streaming_diff::{Hunk, StreamingDiff},
|
||||
CompletionProvider, LanguageModelRequest,
|
||||
};
|
||||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||
use anyhow::Result;
|
||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{EventEmitter, Model, ModelContext, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use std::{cmp, future, ops::Range};
|
||||
use multi_buffer;
|
||||
use std::{cmp, future, ops::Range, sync::Arc};
|
||||
|
||||
pub enum Event {
|
||||
Finished,
|
||||
@@ -21,6 +20,7 @@ pub enum CodegenKind {
|
||||
}
|
||||
|
||||
pub struct Codegen {
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
buffer: Model<MultiBuffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
kind: CodegenKind,
|
||||
@@ -35,9 +35,15 @@ pub struct Codegen {
|
||||
impl EventEmitter<Event> for Codegen {}
|
||||
|
||||
impl Codegen {
|
||||
pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
kind: CodegenKind,
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
Self {
|
||||
provider,
|
||||
buffer: buffer.clone(),
|
||||
snapshot,
|
||||
kind,
|
||||
@@ -88,7 +94,7 @@ impl Codegen {
|
||||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
|
||||
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
@@ -102,7 +108,7 @@ impl Codegen {
|
||||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let response = CompletionProvider::global(cx).complete(prompt);
|
||||
let response = self.provider.complete(prompt);
|
||||
self.generation = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let generate = async {
|
||||
@@ -299,7 +305,7 @@ fn strip_invalid_spans_from_codeblock(
|
||||
}
|
||||
|
||||
if first_line {
|
||||
if buffer.is_empty() || buffer == "`" || buffer == "``" {
|
||||
if buffer == "" || buffer == "`" || buffer == "``" {
|
||||
return future::ready(None);
|
||||
} else if buffer.starts_with("```") {
|
||||
starts_with_markdown_codeblock = true;
|
||||
@@ -354,9 +360,8 @@ fn strip_invalid_spans_from_codeblock(
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::FakeCompletionProvider;
|
||||
|
||||
use super::*;
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::{Context, TestAppContext};
|
||||
use indoc::indoc;
|
||||
@@ -373,11 +378,15 @@ mod tests {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl CompletionRequest for DummyCompletionRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
@@ -396,10 +405,19 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Transform { range },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
@@ -412,7 +430,8 @@ mod tests {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk.into());
|
||||
println!("CHUNK: {:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
@@ -437,8 +456,6 @@ mod tests {
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
@@ -455,10 +472,19 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
@@ -471,7 +497,7 @@ mod tests {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk.into());
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
@@ -496,8 +522,6 @@ mod tests {
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
let provider = FakeCompletionProvider::default();
|
||||
cx.set_global(CompletionProvider::Fake(provider.clone()));
|
||||
cx.set_global(cx.update(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
@@ -514,10 +538,19 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
@@ -530,7 +563,8 @@ mod tests {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk.into());
|
||||
println!("{:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod fake;
|
||||
mod open_ai;
|
||||
mod zed;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use fake::*;
|
||||
pub use open_ai::*;
|
||||
pub use zed::*;
|
||||
|
||||
use crate::{
|
||||
assistant_settings::{AssistantProvider, AssistantSettings},
|
||||
LanguageModel, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let mut settings_version = 0;
|
||||
let provider = match &AssistantSettings::get_global(cx).provider {
|
||||
AssistantProvider::ZedDotDev { default_model } => {
|
||||
CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
client.clone(),
|
||||
settings_version,
|
||||
cx,
|
||||
))
|
||||
}
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
settings_version,
|
||||
)),
|
||||
};
|
||||
cx.set_global(provider);
|
||||
|
||||
cx.observe_global::<SettingsStore>(move |cx| {
|
||||
settings_version += 1;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, cx| {
|
||||
match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
|
||||
(
|
||||
CompletionProvider::OpenAi(provider),
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
) => {
|
||||
provider.update(default_model.clone(), api_url.clone(), settings_version);
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(provider),
|
||||
AssistantProvider::ZedDotDev { default_model },
|
||||
) => {
|
||||
provider.update(default_model.clone(), settings_version);
|
||||
}
|
||||
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
|
||||
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
client.clone(),
|
||||
settings_version,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(_),
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
},
|
||||
) => {
|
||||
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
settings_version,
|
||||
));
|
||||
}
|
||||
#[cfg(test)]
|
||||
(CompletionProvider::Fake(_), _) => unimplemented!(),
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub enum CompletionProvider {
|
||||
OpenAi(OpenAiCompletionProvider),
|
||||
ZedDotDev(ZedDotDevCompletionProvider),
|
||||
#[cfg(test)]
|
||||
Fake(FakeCompletionProvider),
|
||||
}
|
||||
|
||||
impl gpui::Global for CompletionProvider {}
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn global(cx: &AppContext) -> &Self {
|
||||
cx.global::<Self>()
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
|
||||
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> LanguageModel {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
||||
CompletionProvider::ZedDotDev(provider) => {
|
||||
LanguageModel::ZedDotDev(provider.default_model())
|
||||
}
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.complete(request),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(provider) => provider.complete(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
use anyhow::Result;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use std::sync::Arc;

#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
    current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
}

impl FakeCompletionProvider {
    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let (tx, rx) = mpsc::unbounded();
        *self.current_completion_tx.lock() = Some(tx);
        async move { Ok(rx.map(Ok).boxed()) }.boxed()
    }

    pub fn send_completion(&self, chunk: String) {
        self.current_completion_tx
            .lock()
            .as_ref()
            .unwrap()
            .unbounded_send(chunk)
            .unwrap();
    }

    pub fn finish_completion(&self) {
        self.current_completion_tx.lock().take();
    }
}
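For readers following the test changes above, here is a minimal, hypothetical sketch of how a test could drive this fake provider. It only uses the `complete`, `send_completion`, and `finish_completion` methods shown in the removed file; the async runtime and test harness around it are assumed.

```rust
use futures::StreamExt;

// Hypothetical test helper; assumes `FakeCompletionProvider` from the file above is in scope.
async fn drive_fake_provider(provider: &FakeCompletionProvider) -> String {
    // `complete` opens a fresh unbounded channel and hands back the receiving stream.
    let stream = provider.complete().await.unwrap();
    // Chunks pushed through `send_completion` appear on the stream in order.
    provider.send_completion("fn main() {".to_string());
    provider.send_completion("}".to_string());
    // Taking the sender via `finish_completion` terminates the stream.
    provider.finish_completion();
    let chunks: Vec<String> = stream.map(|chunk| chunk.unwrap()).collect().await;
    chunks.concat()
}
```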
@@ -1,301 +0,0 @@
|
||||
use crate::{
|
||||
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
||||
use settings::Settings;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::{http::HttpClient, ResultExt};
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
default_model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub fn new(
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
default_model,
|
||||
http_client,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
|
||||
self.default_model = default_model;
|
||||
self.api_url = api_url;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = None;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> OpenAiModel {
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_open_ai_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::ZedDotDev(_) => self.default_model(),
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
};
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| RequestMessage {
|
||||
role: msg.role.into(),
|
||||
content: msg.content,
|
||||
})
|
||||
.collect(),
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count_open_ai_tokens(
|
||||
request: LanguageModelRequest,
|
||||
background_executor: &gpui::BackgroundExecutor,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
background_executor
|
||||
.spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.content),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
impl From<Role> for open_ai::Role {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => OpenAiRole::User,
|
||||
Role::Assistant => OpenAiRole::Assistant,
|
||||
Role::System => OpenAiRole::System,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::OpenAi(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features,
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: FontWeight::NORMAL,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 6] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
|
||||
" - You can create an API key at: platform.openai.com/api-keys",
|
||||
" - Make sure your OpenAI account has credits",
|
||||
" - Having a subscription for another service like GitHub Copilot won't work.",
|
||||
"",
|
||||
"Paste your OpenAI API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
use crate::{
|
||||
assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use anyhow::{anyhow, Result};
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
|
||||
use gpui::{AnyView, AppContext, Task};
|
||||
use std::{future, sync::Arc};
|
||||
use ui::prelude::*;
|
||||
|
||||
pub struct ZedDotDevCompletionProvider {
|
||||
client: Arc<Client>,
|
||||
default_model: ZedDotDevModel,
|
||||
settings_version: usize,
|
||||
status: client::Status,
|
||||
_maintain_client_status: Task<()>,
|
||||
}
|
||||
|
||||
impl ZedDotDevCompletionProvider {
|
||||
pub fn new(
|
||||
default_model: ZedDotDevModel,
|
||||
client: Arc<Client>,
|
||||
settings_version: usize,
|
||||
cx: &mut AppContext,
|
||||
) -> Self {
|
||||
let mut status_rx = client.status();
|
||||
let status = *status_rx.borrow();
|
||||
let maintain_client_status = cx.spawn(|mut cx| async move {
|
||||
while let Some(status) = status_rx.next().await {
|
||||
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::ZedDotDev(provider) = provider {
|
||||
provider.status = status;
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
Self {
|
||||
client,
|
||||
default_model,
|
||||
settings_version,
|
||||
status,
|
||||
_maintain_client_status: maintain_client_status,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
|
||||
self.default_model = default_model;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> ZedDotDevModel {
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.status.is_connected()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let client = self.client.clone();
|
||||
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|_cx| AuthenticationPrompt).into()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::ZedDotDev(
|
||||
ZedDotDevModel::Claude3Opus
|
||||
| ZedDotDevModel::Claude3Sonnet
|
||||
| ZedDotDevModel::Claude3Haiku,
|
||||
) => {
|
||||
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
|
||||
let request = self.client.request(proto::CountTokensWithLanguageModel {
|
||||
model,
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
});
|
||||
async move {
|
||||
let response = request.await?;
|
||||
Ok(response.token_count as usize)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = proto::CompleteWithLanguageModel {
|
||||
model: request.model.id().to_string(),
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.to_proto())
|
||||
.collect(),
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
};
|
||||
|
||||
self.client
|
||||
.request_stream(request)
|
||||
.map_ok(|stream| {
|
||||
stream
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt;
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
|
||||
|
||||
v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("sign_in", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Filled)
|
||||
.full_width()
|
||||
.on_click(|_, cx| {
|
||||
CompletionProvider::global(cx)
|
||||
.authenticate(cx)
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to enable collaboration.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
use editor::MultiBuffer;
|
||||
use gpui::{AppContext, Model, ModelContext, Subscription};
|
||||
|
||||
use crate::{assistant_panel::Conversation, LanguageModelRequestMessage, Role};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct EmbeddedScope {
|
||||
active_buffer: Option<Model<MultiBuffer>>,
|
||||
active_buffer_enabled: bool,
|
||||
active_buffer_subscription: Option<Subscription>,
|
||||
}
|
||||
|
||||
impl EmbeddedScope {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_buffer: None,
|
||||
active_buffer_enabled: true,
|
||||
active_buffer_subscription: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_buffer(
|
||||
&mut self,
|
||||
buffer: Option<Model<MultiBuffer>>,
|
||||
cx: &mut ModelContext<Conversation>,
|
||||
) {
|
||||
self.active_buffer_subscription.take();
|
||||
|
||||
if let Some(active_buffer) = buffer.clone() {
|
||||
self.active_buffer_subscription =
|
||||
Some(cx.subscribe(&active_buffer, |conversation, _, e, cx| {
|
||||
if let multi_buffer::Event::Edited { .. } = e {
|
||||
conversation.count_remaining_tokens(cx)
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
self.active_buffer = buffer;
|
||||
}
|
||||
|
||||
pub fn active_buffer(&self) -> Option<&Model<MultiBuffer>> {
|
||||
self.active_buffer.as_ref()
|
||||
}
|
||||
|
||||
pub fn active_buffer_enabled(&self) -> bool {
|
||||
self.active_buffer_enabled
|
||||
}
|
||||
|
||||
pub fn set_active_buffer_enabled(&mut self, enabled: bool) {
|
||||
self.active_buffer_enabled = enabled;
|
||||
}
|
||||
|
||||
/// Provide a message for the language model based on the active buffer.
|
||||
pub fn message(&self, cx: &AppContext) -> Option<LanguageModelRequestMessage> {
|
||||
if !self.active_buffer_enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let active_buffer = self.active_buffer.as_ref()?;
|
||||
let buffer = active_buffer.read(cx);
|
||||
|
||||
if let Some(singleton) = buffer.as_singleton() {
|
||||
let singleton = singleton.read(cx);
|
||||
|
||||
let filename = singleton
|
||||
.file()
|
||||
.map(|file| file.path().to_string_lossy())
|
||||
.unwrap_or("Untitled".into());
|
||||
|
||||
let text = singleton.text();
|
||||
|
||||
let language = singleton
|
||||
.language()
|
||||
.map(|l| {
|
||||
let name = l.code_fence_block_name();
|
||||
name.to_string()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
            let markdown =
                format!("User's active file `{filename}`:\n\n```{language}\n{text}```\n\n");

            return Some(LanguageModelRequestMessage {
                role: Role::System,
                content: markdown,
            });
        }

        None
    }
}
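As a concrete illustration of the message shape built above, here is a small hypothetical example. The backtick fence is assembled at runtime only so this snippet can be quoted safely; the real code embeds the backticks directly in the format string.

```rust
fn main() {
    // Hypothetical values standing in for the active buffer's metadata.
    let filename = "main.rs";
    let language = "rust";
    let text = "fn main() {}\n";
    let fence = "`".repeat(3);
    // Same shape as the `markdown` string produced by `EmbeddedScope::message`.
    let markdown =
        format!("User's active file `{filename}`:\n\n{fence}{language}\n{text}{fence}\n\n");
    assert!(markdown.starts_with("User's active file `main.rs`"));
    println!("{markdown}");
}
```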
@@ -1,95 +1,394 @@
|
||||
use language::BufferSnapshot;
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use ai::models::LanguageModel;
|
||||
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::prompts::file_context::FileContext;
|
||||
use ai::prompts::generate::GenerateInlineContent;
|
||||
use ai::prompts::preamble::EngineerPreamble;
|
||||
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use ai::providers::open_ai::OpenAiLanguageModel;
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp::{self, Reverse};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
||||
#[derive(Debug)]
|
||||
struct Match {
|
||||
collapse: Range<usize>,
|
||||
keep: Vec<Range<usize>>,
|
||||
}
|
||||
|
||||
let selected_range = selected_range.to_offset(buffer);
|
||||
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
|
||||
Some(&grammar.embedding_config.as_ref()?.query)
|
||||
});
|
||||
let configs = ts_matches
|
||||
.grammars()
|
||||
.iter()
|
||||
.map(|g| g.embedding_config.as_ref().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
let mut matches = Vec::new();
|
||||
while let Some(mat) = ts_matches.peek() {
|
||||
let config = &configs[mat.grammar_index];
|
||||
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
|
||||
if Some(cap.index) == config.collapse_capture_ix {
|
||||
Some(cap.node.byte_range())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
let mut keep = Vec::new();
|
||||
for capture in mat.captures.iter() {
|
||||
if Some(capture.index) == config.keep_capture_ix {
|
||||
keep.push(capture.node.byte_range());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ts_matches.advance();
|
||||
matches.push(Match { collapse, keep });
|
||||
} else {
|
||||
ts_matches.advance();
|
||||
}
|
||||
}
|
||||
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
|
||||
let mut matches = matches.into_iter().peekable();
|
||||
|
||||
let mut summary = String::new();
|
||||
let mut offset = 0;
|
||||
let mut flushed_selection = false;
|
||||
while let Some(mat) = matches.next() {
|
||||
// Keep extending the collapsed range if the next match surrounds
|
||||
// the current one.
|
||||
while let Some(next_mat) = matches.peek() {
|
||||
if mat.collapse.start <= next_mat.collapse.start
|
||||
&& mat.collapse.end >= next_mat.collapse.end
|
||||
{
|
||||
matches.next().unwrap();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if offset > mat.collapse.start {
|
||||
// Skip collapsed nodes that have already been summarized.
|
||||
offset = cmp::max(offset, mat.collapse.end);
|
||||
continue;
|
||||
}
|
||||
|
||||
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
|
||||
if !flushed_selection {
|
||||
// The collapsed node ends after the selection starts, so we'll flush the selection first.
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
flushed_selection = true;
|
||||
}
|
||||
|
||||
// If the selection intersects the collapsed node, we won't collapse it.
|
||||
if selected_range.end >= mat.collapse.start {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
|
||||
for keep in mat.keep {
|
||||
summary.extend(buffer.text_for_range(keep));
|
||||
}
|
||||
offset = mat.collapse.end;
|
||||
}
|
||||
|
||||
// Flush selection if we haven't already done so.
|
||||
if !flushed_selection && offset <= selected_range.start {
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..buffer.len()));
|
||||
summary
|
||||
}
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
search_results: Vec<PromptCodeSnippet>,
|
||||
model: &str,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => {
|
||||
writeln!(prompt, "You are an expert engineer.")?;
|
||||
"Text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(prompt, "You are an expert {language_name} engineer.")?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)?;
|
||||
"Code"
|
||||
}
|
||||
// Using new Prompt Templates
|
||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
|
||||
let lang_name = if let Some(language_name) = language_name {
|
||||
Some(language_name.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(project_name) = project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
)?;
|
||||
}
|
||||
let args = PromptArguments {
|
||||
model: openai_model,
|
||||
language_name: lang_name.clone(),
|
||||
project_name,
|
||||
snippets: search_results.clone(),
|
||||
reserved_tokens: 1000,
|
||||
buffer: Some(buffer),
|
||||
selected_range: Some(range),
|
||||
user_prompt: Some(user_prompt.clone()),
|
||||
};
|
||||
|
||||
// Include file content.
|
||||
for chunk in buffer.text_for_range(0..range.start) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(RepositoryContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(FileContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(GenerateInlineContent {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
let (prompt, _) = chain.generate(true)?;
|
||||
|
||||
if range.is_empty() {
|
||||
prompt.push_str("<|START|>");
|
||||
} else {
|
||||
prompt.push_str("<|START|");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if !range.is_empty() {
|
||||
prompt.push_str("|END|>");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.end..buffer.len()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
prompt.push('\n');
|
||||
|
||||
if range.is_empty() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(prompt)
|
||||
anyhow::Ok(prompt)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context};
|
||||
use indoc::indoc;
|
||||
use language::{
|
||||
language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
|
||||
LanguageMatcher, Point,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_embedding_query(
|
||||
r#"
|
||||
(
|
||||
[(line_comment) (attribute_item)]* @context
|
||||
.
|
||||
[
|
||||
(struct_item
|
||||
name: (_) @name)
|
||||
|
||||
(enum_item
|
||||
name: (_) @name)
|
||||
|
||||
(impl_item
|
||||
trait: (_)? @name
|
||||
"for"? @name
|
||||
type: (_) @name)
|
||||
|
||||
(trait_item
|
||||
name: (_) @name)
|
||||
|
||||
(function_item
|
||||
name: (_) @name
|
||||
body: (block
|
||||
"{" @keep
|
||||
"}" @keep) @collapse)
|
||||
|
||||
(macro_definition
|
||||
name: (_) @name)
|
||||
] @item
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_outline_for_prompt(cx: &mut AppContext) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language_settings::init(cx);
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
self.a
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer = cx.new_model(|cx| {
|
||||
Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
|
||||
});
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
<|S|>a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let <|S|a |E|>= 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
<|S|>
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
<|S|>"}
|
||||
);
|
||||
|
||||
// Ensure nested functions get collapsed properly.
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
let a = 30;
|
||||
fn nested() -> usize {
|
||||
3
|
||||
}
|
||||
self.a + nested()
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
|
||||
indoc! {"
|
||||
<|S|>struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
cmp::Reverse,
|
||||
ffi::OsStr,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedMessage {
|
||||
pub id: MessageId,
|
||||
pub start: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedConversation {
|
||||
pub id: Option<String>,
|
||||
pub zed: String,
|
||||
pub version: String,
|
||||
pub text: String,
|
||||
pub messages: Vec<SavedMessage>,
|
||||
pub message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
pub summary: String,
|
||||
}
|
||||
|
||||
impl SavedConversation {
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
|
||||
let saved_conversation = fs.load(path).await?;
|
||||
let saved_conversation_json =
|
||||
serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
|
||||
match saved_conversation_json
|
||||
.get("version")
|
||||
.ok_or_else(|| anyhow!("version not found"))?
|
||||
{
|
||||
serde_json::Value::String(version) => match version.as_str() {
|
||||
Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
|
||||
"0.1.0" => {
|
||||
let saved_conversation =
|
||||
serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
|
||||
Ok(Self {
|
||||
id: saved_conversation.id,
|
||||
zed: saved_conversation.zed,
|
||||
version: saved_conversation.version,
|
||||
text: saved_conversation.text,
|
||||
messages: saved_conversation.messages,
|
||||
message_metadata: saved_conversation.message_metadata,
|
||||
summary: saved_conversation.summary,
|
||||
})
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"unrecognized saved conversation version: {}",
|
||||
version
|
||||
)),
|
||||
},
|
||||
_ => Err(anyhow!("version not found on saved conversation")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedConversationV0_1_0 {
|
||||
id: Option<String>,
|
||||
zed: String,
|
||||
version: String,
|
||||
text: String,
|
||||
messages: Vec<SavedMessage>,
|
||||
message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: String,
|
||||
api_url: Option<String>,
|
||||
model: OpenAiModel,
|
||||
}
|
||||
|
||||
pub struct SavedConversationMetadata {
|
||||
pub title: String,
|
||||
pub path: PathBuf,
|
||||
pub mtime: chrono::DateTime<chrono::Local>,
|
||||
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::<SavedConversationMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
@@ -197,10 +197,12 @@ impl StreamingDiff {
            } else {
                hunks.push(Hunk::Remove { len: char_len })
            }
        } else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
            *len += char_len;
        } else {
            hunks.push(Hunk::Keep { len: char_len })
            if let Some(Hunk::Keep { len }) = hunks.last_mut() {
                *len += char_len;
            } else {
                hunks.push(Hunk::Keep { len: char_len })
            }
        }
    }

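The hunk above changes how consecutive kept characters are recorded. A standalone, hypothetical illustration of the intended merging behavior follows; the real `Hunk` type in `streaming_diff` has more variants than shown here.

```rust
// Simplified stand-in for the real streaming_diff::Hunk type.
#[derive(Debug, PartialEq)]
enum Hunk {
    Keep { len: usize },
    Remove { len: usize },
}

// Extend the previous Keep hunk instead of pushing a run of single-character hunks.
fn push_keep(hunks: &mut Vec<Hunk>, char_len: usize) {
    if let Some(Hunk::Keep { len }) = hunks.last_mut() {
        *len += char_len;
    } else {
        hunks.push(Hunk::Keep { len: char_len });
    }
}

fn main() {
    let mut hunks = vec![Hunk::Remove { len: 2 }];
    push_keep(&mut hunks, 1);
    push_keep(&mut hunks, 3);
    assert_eq!(hunks, vec![Hunk::Remove { len: 2 }, Hunk::Keep { len: 4 }]);
}
```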
@@ -1,6 +1,6 @@
use assets::SoundRegistry;
use derive_more::{Deref, DerefMut};
use gpui::{AppContext, AssetSource, BorrowAppContext, Global};
use gpui::{AppContext, AssetSource, Global};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;

@@ -373,10 +373,7 @@ impl ActiveCall {
|
||||
self.report_call_event("hang up", cx);
|
||||
|
||||
Audio::end_call(cx);
|
||||
|
||||
let channel_id = self.channel_id(cx);
|
||||
if let Some((room, _)) = self.room.take() {
|
||||
cx.emit(Event::RoomLeft { channel_id });
|
||||
room.update(cx, |room, cx| room.leave(cx))
|
||||
} else {
|
||||
Task::ready(Ok(()))
|
||||
|
||||
@@ -52,7 +52,7 @@ pub enum Event {
|
||||
RemoteProjectInvitationDiscarded {
|
||||
project_id: u64,
|
||||
},
|
||||
RoomLeft {
|
||||
Left {
|
||||
channel_id: Option<ChannelId>,
|
||||
},
|
||||
}
|
||||
@@ -366,6 +366,9 @@ impl Room {
|
||||
|
||||
pub(crate) fn leave(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
cx.notify();
|
||||
cx.emit(Event::Left {
|
||||
channel_id: self.channel_id(),
|
||||
});
|
||||
self.leave_internal(cx)
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,6 @@ pub struct ChannelMessage {
|
||||
pub nonce: u128,
|
||||
pub mentions: Vec<(Range<usize>, UserId)>,
|
||||
pub reply_to_message_id: Option<u64>,
|
||||
pub edited_at: Option<OffsetDateTime>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
@@ -84,10 +83,6 @@ pub enum ChannelChatEvent {
|
||||
old_range: Range<usize>,
|
||||
new_count: usize,
|
||||
},
|
||||
UpdateMessage {
|
||||
message_id: ChannelMessageId,
|
||||
message_ix: usize,
|
||||
},
|
||||
NewMessage {
|
||||
channel_id: ChannelId,
|
||||
message_id: u64,
|
||||
@@ -98,7 +93,6 @@ impl EventEmitter<ChannelChatEvent> for ChannelChat {}
|
||||
pub fn init(client: &Arc<Client>) {
|
||||
client.add_model_message_handler(ChannelChat::handle_message_sent);
|
||||
client.add_model_message_handler(ChannelChat::handle_message_removed);
|
||||
client.add_model_message_handler(ChannelChat::handle_message_updated);
|
||||
}
|
||||
|
||||
impl ChannelChat {
|
||||
@@ -195,7 +189,6 @@ impl ChannelChat {
|
||||
mentions: message.mentions.clone(),
|
||||
nonce,
|
||||
reply_to_message_id: message.reply_to_message_id,
|
||||
edited_at: None,
|
||||
},
|
||||
&(),
|
||||
),
|
||||
@@ -222,9 +215,6 @@ impl ChannelChat {
|
||||
let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.insert_messages(SumTree::from_item(message, &()), cx);
|
||||
if this.first_loaded_message_id.is_none() {
|
||||
this.first_loaded_message_id = Some(id);
|
||||
}
|
||||
})?;
|
||||
Ok(id)
|
||||
}))
|
||||
@@ -244,35 +234,6 @@ impl ChannelChat {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn update_message(
|
||||
&mut self,
|
||||
id: u64,
|
||||
message: MessageParams,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Result<Task<Result<()>>> {
|
||||
self.message_update(
|
||||
ChannelMessageId::Saved(id),
|
||||
message.text.clone(),
|
||||
message.mentions.clone(),
|
||||
Some(OffsetDateTime::now_utc()),
|
||||
cx,
|
||||
);
|
||||
|
||||
let nonce: u128 = self.rng.gen();
|
||||
|
||||
let request = self.rpc.request(proto::UpdateChannelMessage {
|
||||
channel_id: self.channel_id.0,
|
||||
message_id: id,
|
||||
body: message.text,
|
||||
nonce: Some(nonce.into()),
|
||||
mentions: mentions_to_proto(&message.mentions),
|
||||
});
|
||||
Ok(cx.spawn(move |_, _| async move {
|
||||
request.await?;
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
|
||||
if self.loaded_all_messages {
|
||||
return None;
|
||||
@@ -562,32 +523,6 @@ impl ChannelChat {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_message_updated(
|
||||
this: Model<Self>,
|
||||
message: TypedEnvelope<proto::ChannelMessageUpdate>,
|
||||
_: Arc<Client>,
|
||||
mut cx: AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?;
|
||||
let message = message
|
||||
.payload
|
||||
.message
|
||||
.ok_or_else(|| anyhow!("empty message"))?;
|
||||
|
||||
let message = ChannelMessage::from_proto(message, &user_store, &mut cx).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.message_update(
|
||||
message.id,
|
||||
message.body,
|
||||
message.mentions,
|
||||
message.edited_at,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
|
||||
if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
|
||||
let nonces = messages
|
||||
@@ -664,38 +599,6 @@ impl ChannelChat {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn message_update(
|
||||
&mut self,
|
||||
id: ChannelMessageId,
|
||||
body: String,
|
||||
mentions: Vec<(Range<usize>, u64)>,
|
||||
edited_at: Option<OffsetDateTime>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let mut cursor = self.messages.cursor::<ChannelMessageId>();
|
||||
let mut messages = cursor.slice(&id, Bias::Left, &());
|
||||
let ix = messages.summary().count;
|
||||
|
||||
if let Some(mut message_to_update) = cursor.item().cloned() {
|
||||
message_to_update.body = body;
|
||||
message_to_update.mentions = mentions;
|
||||
message_to_update.edited_at = edited_at;
|
||||
messages.push(message_to_update, &());
|
||||
cursor.next(&());
|
||||
}
|
||||
|
||||
messages.append(cursor.suffix(&()), &());
|
||||
drop(cursor);
|
||||
self.messages = messages;
|
||||
|
||||
cx.emit(ChannelChatEvent::UpdateMessage {
|
||||
message_ix: ix,
|
||||
message_id: id,
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
async fn messages_from_proto(
|
||||
@@ -720,15 +623,6 @@ impl ChannelMessage {
|
||||
user_store.get_user(message.sender_id, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let edited_at = message.edited_at.and_then(|t| -> Option<OffsetDateTime> {
|
||||
if let Ok(a) = OffsetDateTime::from_unix_timestamp(t as i64) {
|
||||
return Some(a);
|
||||
}
|
||||
|
||||
None
|
||||
});
|
||||
|
||||
Ok(ChannelMessage {
|
||||
id: ChannelMessageId::Saved(message.id),
|
||||
body: message.body,
|
||||
@@ -747,7 +641,6 @@ impl ChannelMessage {
|
||||
.ok_or_else(|| anyhow!("nonce is required"))?
|
||||
.into(),
|
||||
reply_to_message_id: message.reply_to_message_id,
|
||||
edited_at,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ use rpc::{
|
||||
};
|
||||
use settings::Settings;
|
||||
use std::{mem, sync::Arc, time::Duration};
|
||||
use util::{maybe, ResultExt};
|
||||
use util::{async_maybe, maybe, ResultExt};
|
||||
|
||||
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
@@ -227,7 +227,7 @@ impl ChannelStore {
|
||||
_watch_connection_status: watch_connection_status,
|
||||
disconnect_channel_buffers_task: None,
|
||||
_update_channels: cx.spawn(|this, mut cx| async move {
|
||||
maybe!(async move {
|
||||
async_maybe!({
|
||||
while let Some(update_channels) = update_channels_rx.next().await {
|
||||
if let Some(this) = this.upgrade() {
|
||||
let update_task = this.update(&mut cx, |this, cx| {
|
||||
|
||||
@@ -186,7 +186,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
mentions: vec![],
|
||||
nonce: Some(1.into()),
|
||||
reply_to_message_id: None,
|
||||
edited_at: None,
|
||||
},
|
||||
proto::ChannelMessage {
|
||||
id: 11,
|
||||
@@ -196,7 +195,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
mentions: vec![],
|
||||
nonce: Some(2.into()),
|
||||
reply_to_message_id: None,
|
||||
edited_at: None,
|
||||
},
|
||||
],
|
||||
done: false,
|
||||
@@ -245,7 +243,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
mentions: vec![],
|
||||
nonce: Some(3.into()),
|
||||
reply_to_message_id: None,
|
||||
edited_at: None,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -300,7 +297,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
nonce: Some(4.into()),
|
||||
mentions: vec![],
|
||||
reply_to_message_id: None,
|
||||
edited_at: None,
|
||||
},
|
||||
proto::ChannelMessage {
|
||||
id: 9,
|
||||
@@ -310,7 +306,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
nonce: Some(5.into()),
|
||||
mentions: vec![],
|
||||
reply_to_message_id: None,
|
||||
edited_at: None,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#![cfg_attr(any(target_os = "linux", target_os = "windows"), allow(dead_code))]
|
||||
#![cfg_attr(target_os = "linux", allow(dead_code))]
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use clap::Parser;
|
||||
|
||||
@@ -13,12 +13,11 @@ use async_tungstenite::tungstenite::{
|
||||
use clock::SystemClock;
|
||||
use collections::HashMap;
|
||||
use futures::{
|
||||
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
|
||||
channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
|
||||
TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui::{
|
||||
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, BorrowAppContext, Global, Model,
|
||||
Task, WeakModel,
|
||||
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
|
||||
};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
@@ -28,8 +27,8 @@ use release_channel::{AppVersion, ReleaseChannel};
|
||||
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::fmt;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
convert::TryFrom,
|
||||
@@ -37,10 +36,7 @@ use std::{
|
||||
future::Future,
|
||||
marker::PhantomData,
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc, Weak,
|
||||
},
|
||||
sync::{atomic::AtomicU64, Arc, Weak},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use telemetry::Telemetry;
|
||||
@@ -53,15 +49,6 @@ pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct DevServerToken(pub String);
|
||||
|
||||
impl fmt::Display for DevServerToken {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
|
||||
static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
|
||||
@@ -287,22 +274,10 @@ enum WeakSubscriber {
|
||||
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Credentials {
    DevServer { token: DevServerToken },
    User { user_id: u64, access_token: String },
}

impl Credentials {
    pub fn authorization_header(&self) -> String {
        match self {
            Credentials::DevServer { token } => format!("dev-server-token {}", token),
            Credentials::User {
                user_id,
                access_token,
            } => format!("{} {}", user_id, access_token),
        }
    }
#[derive(Clone, Debug)]
pub struct Credentials {
    pub user_id: u64,
    pub access_token: String,
}
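For orientation, here is a minimal sketch of the two `Authorization` header shapes produced by `authorization_header()` above; the id and token values are made up for illustration and are not part of the diff:

```rust
// Illustrative values only; the format strings mirror authorization_header() above.
fn main() {
    let dev_server_header = format!("dev-server-token {}", "1.c2VjcmV0");
    let user_header = format!("{} {}", 42_u64, "access-token-value");
    assert_eq!(dev_server_header, "dev-server-token 1.c2VjcmV0");
    assert_eq!(user_header, "42 access-token-value");
}
```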
|
||||
|
||||
impl Default for ClientState {
|
||||
@@ -467,7 +442,7 @@ impl Client {
|
||||
}
|
||||
|
||||
pub fn id(&self) -> u64 {
|
||||
self.id.load(Ordering::SeqCst)
|
||||
self.id.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
|
||||
@@ -475,7 +450,7 @@ impl Client {
|
||||
}
|
||||
|
||||
pub fn set_id(&self, id: u64) -> &Self {
|
||||
self.id.store(id, Ordering::SeqCst);
|
||||
self.id.store(id, std::sync::atomic::Ordering::SeqCst);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -519,11 +494,11 @@ impl Client {
|
||||
}
|
||||
|
||||
pub fn user_id(&self) -> Option<u64> {
|
||||
if let Some(Credentials::User { user_id, .. }) = self.state.read().credentials.as_ref() {
|
||||
Some(*user_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
self.state
|
||||
.read()
|
||||
.credentials
|
||||
.as_ref()
|
||||
.map(|credentials| credentials.user_id)
|
||||
}
|
||||
|
||||
pub fn peer_id(&self) -> Option<PeerId> {
|
||||
@@ -768,10 +743,6 @@ impl Client {
|
||||
read_credentials_from_keychain(cx).await.is_some()
|
||||
}
|
||||
|
||||
pub fn set_dev_server_token(&self, token: DevServerToken) {
|
||||
self.state.write().credentials = Some(Credentials::DevServer { token });
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
pub async fn authenticate_and_connect(
|
||||
self: &Arc<Self>,
|
||||
@@ -822,9 +793,7 @@ impl Client {
|
||||
}
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
if let Credentials::User { user_id, .. } = &credentials {
|
||||
self.set_id(*user_id);
|
||||
}
|
||||
self.set_id(credentials.user_id);
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
@@ -840,9 +809,7 @@ impl Client {
|
||||
Ok(conn) => {
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
|
||||
if let Credentials::User{user_id, access_token} = credentials {
|
||||
write_credentials_to_keychain(user_id, access_token, cx).await.log_err();
|
||||
}
|
||||
write_credentials_to_keychain(credentials, cx).await.log_err();
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
@@ -1050,7 +1017,10 @@ impl Client {
|
||||
.unwrap_or_default();
|
||||
|
||||
let request = Request::builder()
|
||||
.header("Authorization", credentials.authorization_header())
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("{} {}", credentials.user_id, credentials.access_token),
|
||||
)
|
||||
.header("x-zed-protocol-version", rpc::PROTOCOL_VERSION)
|
||||
.header("x-zed-app-version", app_version)
|
||||
.header(
|
||||
@@ -1203,7 +1173,7 @@ impl Client {
|
||||
.decrypt_string(&access_token)
|
||||
.context("failed to decrypt access token")?;
|
||||
|
||||
Ok(Credentials::User {
|
||||
Ok(Credentials {
|
||||
user_id: user_id.parse()?,
|
||||
access_token,
|
||||
})
|
||||
@@ -1253,7 +1223,7 @@ impl Client {
|
||||
|
||||
// Use the admin API token to authenticate as the impersonated user.
|
||||
api_token.insert_str(0, "ADMIN_TOKEN:");
|
||||
Ok(Credentials::User {
|
||||
Ok(Credentials {
|
||||
user_id: response.user.id,
|
||||
access_token: api_token,
|
||||
})
|
||||
@@ -1290,30 +1260,6 @@ impl Client {
|
||||
.map_ok(|envelope| envelope.payload)
|
||||
}
|
||||
|
||||
pub fn request_stream<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
|
||||
let client_id = self.id.load(Ordering::SeqCst);
|
||||
log::debug!(
|
||||
"rpc request start. client_id:{}. name:{}",
|
||||
client_id,
|
||||
T::NAME
|
||||
);
|
||||
let response = self
|
||||
.connection_id()
|
||||
.map(|conn_id| self.peer.request_stream(conn_id, request));
|
||||
async move {
|
||||
let response = response?.await;
|
||||
log::debug!(
|
||||
"rpc request finish. client_id:{}. name:{}",
|
||||
client_id,
|
||||
T::NAME
|
||||
);
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_envelope<T: RequestMessage>(
|
||||
&self,
|
||||
request: T,
|
||||
@@ -1466,22 +1412,21 @@ async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credenti
|
||||
.await
|
||||
.log_err()??;
|
||||
|
||||
Some(Credentials::User {
|
||||
Some(Credentials {
|
||||
user_id: user_id.parse().ok()?,
|
||||
access_token: String::from_utf8(access_token).ok()?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_credentials_to_keychain(
|
||||
user_id: u64,
|
||||
access_token: String,
|
||||
credentials: Credentials,
|
||||
cx: &AsyncAppContext,
|
||||
) -> Result<()> {
|
||||
cx.update(move |cx| {
|
||||
cx.write_credentials(
|
||||
&ClientSettings::get_global(cx).server_url,
|
||||
&user_id.to_string(),
|
||||
access_token.as_bytes(),
|
||||
&credentials.user_id.to_string(),
|
||||
credentials.access_token.as_bytes(),
|
||||
)
|
||||
})?
|
||||
.await
|
||||
@@ -1586,7 +1531,7 @@ mod tests {
|
||||
// Time out when client tries to connect.
|
||||
client.override_authenticate(move |cx| {
|
||||
cx.background_executor().spawn(async move {
|
||||
Ok(Credentials::User {
|
||||
Ok(Credentials {
|
||||
user_id,
|
||||
access_token: "token".into(),
|
||||
})
|
||||
|
||||
@@ -15,8 +15,7 @@ use std::{env, mem, path::PathBuf, sync::Arc, time::Duration};
|
||||
use sysinfo::{CpuRefreshKind, MemoryRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CopilotEvent, CpuEvent,
|
||||
EditEvent, EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent,
|
||||
SettingEvent,
|
||||
EditEvent, EditorEvent, Event, EventRequestBody, EventWrapper, MemoryEvent, SettingEvent,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use util::http::{self, HttpClient, HttpClientWithUrl, Method};
|
||||
@@ -262,7 +261,7 @@ impl Telemetry {
|
||||
self: &Arc<Self>,
|
||||
conversation_id: Option<String>,
|
||||
kind: AssistantKind,
|
||||
model: String,
|
||||
model: &str,
|
||||
) {
|
||||
let event = Event::Assistant(AssistantEvent {
|
||||
conversation_id,
|
||||
@@ -327,13 +326,6 @@ impl Telemetry {
|
||||
self.report_event(event)
|
||||
}
|
||||
|
||||
pub fn report_extension_event(self: &Arc<Self>, extension_id: Arc<str>, version: Arc<str>) {
|
||||
self.report_event(Event::Extension(ExtensionEvent {
|
||||
extension_id,
|
||||
version,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str) {
|
||||
let mut state = self.state.lock();
|
||||
let period_data = state.event_coalescer.log_event(environment);
|
||||
@@ -478,11 +470,7 @@ impl Telemetry {
|
||||
|
||||
let request = http::Request::builder()
|
||||
.method(Method::POST)
|
||||
.uri(
|
||||
this.http_client
|
||||
.build_zed_api_url("/telemetry/events", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
.uri(this.http_client.build_zed_api_url("/telemetry/events"))
|
||||
.header("Content-Type", "text/plain")
|
||||
.header("x-zed-checksum", checksum)
|
||||
.body(json_bytes.into());
|
||||
@@ -590,10 +578,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_telemetry_flush_on_flush_interval(
|
||||
executor: BackgroundExecutor,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
async fn test_connection_timeout(executor: BackgroundExecutor, cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let clock = Arc::new(FakeSystemClock::new(
|
||||
Utc.with_ymd_and_hms(1990, 4, 12, 12, 0, 0).unwrap(),
|
||||
|
||||
@@ -48,7 +48,7 @@ impl FakeServer {
|
||||
let mut state = state.lock();
|
||||
state.auth_count += 1;
|
||||
let access_token = state.access_token.to_string();
|
||||
Ok(Credentials::User {
|
||||
Ok(Credentials {
|
||||
user_id: client_user_id,
|
||||
access_token,
|
||||
})
|
||||
@@ -71,12 +71,9 @@ impl FakeServer {
|
||||
)))?
|
||||
}
|
||||
|
||||
if credentials
|
||||
!= (Credentials::User {
|
||||
user_id: client_user_id,
|
||||
access_token: state.lock().access_token.to_string(),
|
||||
})
|
||||
{
|
||||
assert_eq!(credentials.user_id, client_user_id);
|
||||
|
||||
if credentials.access_token != state.lock().access_token.to_string() {
|
||||
Err(EstablishConnectionError::Unauthorized)?
|
||||
}
|
||||
|
||||
|
||||
8
crates/collab/.admins.default.json
Normal file
8
crates/collab/.admins.default.json
Normal file
@@ -0,0 +1,8 @@
[
  "nathansobo",
  "as-cii",
  "maxbrunsfeld",
  "iamnbutler",
  "mikayla-maki",
  "JosephTLyons"
]
|
||||
@@ -1,5 +1,4 @@
|
||||
DATABASE_URL = "postgres://postgres@localhost/zed"
|
||||
# DATABASE_URL = "sqlite:////home/zed/.config/zed/db.sqlite3?mode=rwc"
|
||||
DATABASE_MAX_CONNECTIONS = 5
|
||||
HTTP_PORT = 8080
|
||||
API_TOKEN = "secret"
|
||||
@@ -14,7 +13,6 @@ BLOB_STORE_BUCKET = "the-extensions-bucket"
|
||||
BLOB_STORE_URL = "http://127.0.0.1:9000"
|
||||
BLOB_STORE_REGION = "the-region"
|
||||
ZED_CLIENT_CHECKSUM_SEED = "development-checksum-seed"
|
||||
SEED_PATH = "crates/collab/seed.default.json"
|
||||
|
||||
# CLICKHOUSE_URL = ""
|
||||
# CLICKHOUSE_USER = "default"
|
||||
|
||||
@@ -13,12 +13,10 @@ workspace = true
|
||||
[[bin]]
|
||||
name = "collab"
|
||||
|
||||
[features]
|
||||
sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
|
||||
test-support = ["sqlite"]
|
||||
[[bin]]
|
||||
name = "seed"
|
||||
|
||||
[dependencies]
|
||||
anthropic.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-tungstenite = "0.16"
|
||||
aws-config = { version = "1.1.5" }
|
||||
@@ -33,12 +31,10 @@ collections.workspace = true
|
||||
dashmap = "5.4"
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
nanoid = "0.4"
|
||||
open_ai.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prometheus = "0.13"
|
||||
prost.workspace = true
|
||||
@@ -47,7 +43,6 @@ reqwest = { version = "0.11", features = ["json"] }
|
||||
rpc.workspace = true
|
||||
scrypt = "0.7"
|
||||
sea-orm = { version = "0.12.x", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls", "with-uuid"] }
|
||||
semantic_version.workspace = true
|
||||
semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_derive.workspace = true
|
||||
@@ -85,6 +80,7 @@ git = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
lazy_static.workspace = true
|
||||
live_kit_client = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
menu.workspace = true
|
||||
|
||||
@@ -6,21 +6,21 @@ It contains our back-end logic for collaboration, to which we connect from the Z

# Local Development

Detailed instructions on getting started are [here](https://zed.dev/docs/local-collaboration).
Detailed instructions on getting started are [here](https://zed.dev/docs/local-collaboration).

# Deployment

We run two instances of collab:

- Staging (https://staging-collab.zed.dev)
- Production (https://collab.zed.dev)
* Staging (https://staging-collab.zed.dev)
* Production (https://collab.zed.dev)

Both of these run on the Kubernetes cluster hosted in Digital Ocean.

Deployment is triggered by pushing to the `collab-staging` (or `collab-production`) tag in Github. The best way to do this is:

- `./script/deploy-collab staging`
- `./script/deploy-collab production`
* `./script/deploy-collab staging`
* `./script/deploy-collab production`

You can tell what is currently deployed with `./script/what-is-deployed`.

@@ -29,7 +29,7 @@ You can tell what is currently deployed with `./script/what-is-deployed`.
To create a new migration:

```
./script/create-migration <name>
./script/sqlx migrate add <name>
```

Migrations are run automatically on service start, so run `foreman start` again. The service will crash if the migrations fail.
|
||||
|
||||
12
crates/collab/basic.conf
Normal file
12
crates/collab/basic.conf
Normal file
@@ -0,0 +1,12 @@
|
||||
|
||||
[Interface]
|
||||
PrivateKey = B5Fp/yVfP0QYlb+YJv9ea+EMI1mWODPD3akh91cVjvc=
|
||||
Address = fdaa:0:2ce3:a7b:bea:0:a:2/120
|
||||
DNS = fdaa:0:2ce3::3
|
||||
|
||||
[Peer]
|
||||
PublicKey = RKAYPljEJiuaELNDdQIEJmQienT9+LRISfIHwH45HAw=
|
||||
AllowedIPs = fdaa:0:2ce3::/48
|
||||
Endpoint = ord1.gateway.6pn.dev:51820
|
||||
PersistentKeepalive = 15
|
||||
|
||||
@@ -47,6 +47,19 @@ spec:
|
||||
metadata:
|
||||
labels:
|
||||
app: ${ZED_SERVICE_NAME}
|
||||
annotations:
|
||||
ad.datadoghq.com/collab.check_names: |
|
||||
["openmetrics"]
|
||||
ad.datadoghq.com/collab.init_configs: |
|
||||
[{}]
|
||||
ad.datadoghq.com/collab.instances: |
|
||||
[
|
||||
{
|
||||
"openmetrics_endpoint": "http://%%host%%:%%port%%/metrics",
|
||||
"namespace": "collab_${ZED_KUBE_NAMESPACE}",
|
||||
"metrics": [".*"]
|
||||
}
|
||||
]
|
||||
spec:
|
||||
containers:
|
||||
- name: ${ZED_SERVICE_NAME}
|
||||
@@ -112,16 +125,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: livekit
|
||||
key: secret
|
||||
- name: OPENAI_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: openai
|
||||
key: api_key
|
||||
- name: ANTHROPIC_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: anthropic
|
||||
key: api_key
|
||||
- name: BLOB_STORE_ACCESS_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
|
||||
@@ -219,7 +219,6 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
|
||||
"sender_id" INTEGER NOT NULL REFERENCES users (id),
|
||||
"body" TEXT NOT NULL,
|
||||
"sent_at" TIMESTAMP,
|
||||
"edited_at" TIMESTAMP,
|
||||
"nonce" BLOB NOT NULL,
|
||||
"reply_to_message_id" INTEGER DEFAULT NULL
|
||||
);
|
||||
@@ -373,8 +372,6 @@ CREATE TABLE extension_versions (
|
||||
authors TEXT NOT NULL,
|
||||
repository TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
schema_version INTEGER NOT NULL DEFAULT 0,
|
||||
wasm_api_version TEXT,
|
||||
download_count INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (extension_id, version)
|
||||
);
|
||||
@@ -382,16 +379,6 @@ CREATE TABLE extension_versions (
|
||||
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
|
||||
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
|
||||
|
||||
CREATE TABLE rate_buckets (
|
||||
user_id INT NOT NULL,
|
||||
rate_limit_name VARCHAR(255) NOT NULL,
|
||||
token_count INT NOT NULL,
|
||||
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
PRIMARY KEY (user_id, rate_limit_name),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
|
||||
|
||||
CREATE TABLE hosted_projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id),
|
||||
@@ -401,11 +388,3 @@ CREATE TABLE hosted_projects (
|
||||
);
|
||||
CREATE INDEX idx_hosted_projects_on_channel_id ON hosted_projects (channel_id);
|
||||
CREATE UNIQUE INDEX uix_hosted_projects_on_channel_id_and_name ON hosted_projects (channel_id, name) WHERE (deleted_at IS NULL);
|
||||
|
||||
CREATE TABLE dev_servers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
channel_id INTEGER NOT NULL REFERENCES channels(id),
|
||||
name TEXT NOT NULL,
|
||||
hashed_token TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id);
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS rate_buckets (
|
||||
user_id INT NOT NULL,
|
||||
rate_limit_name VARCHAR(255) NOT NULL,
|
||||
token_count INT NOT NULL,
|
||||
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
PRIMARY KEY (user_id, rate_limit_name),
|
||||
CONSTRAINT fk_user
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE channel_messages ADD edited_at TIMESTAMP DEFAULT NULL;
|
||||
@@ -1,2 +0,0 @@
|
||||
-- Add migration script here
|
||||
ALTER TABLE extension_versions ADD COLUMN schema_version INTEGER NOT NULL DEFAULT 0;
|
||||
@@ -1,7 +0,0 @@
|
||||
CREATE TABLE dev_servers (
|
||||
id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
|
||||
channel_id INT NOT NULL REFERENCES channels(id),
|
||||
name TEXT NOT NULL,
|
||||
hashed_token TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX idx_dev_servers_on_channel_id ON dev_servers (channel_id);
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE extension_versions ADD COLUMN wasm_api_version TEXT;
|
||||
@@ -1,12 +0,0 @@
|
||||
{
|
||||
"admins": [
|
||||
"nathansobo",
|
||||
"as-cii",
|
||||
"maxbrunsfeld",
|
||||
"iamnbutler",
|
||||
"mikayla-maki",
|
||||
"JosephTLyons"
|
||||
],
|
||||
"channels": ["zed"],
|
||||
"number_of_users": 100
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use rpc::proto;
|
||||
|
||||
pub fn language_model_request_to_open_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<open_ai::Request> {
|
||||
Ok(open_ai::Request {
|
||||
model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
|
||||
messages: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
Ok(open_ai::RequestMessage {
|
||||
role: match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelAssistant => {
|
||||
open_ai::Role::Assistant
|
||||
}
|
||||
proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
|
||||
},
|
||||
content: message.content,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<open_ai::RequestMessage>>>()?,
|
||||
stream: true,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_to_google_ai(
|
||||
request: proto::CompleteWithLanguageModel,
|
||||
) -> Result<google_ai::GenerateContentRequest> {
|
||||
Ok(google_ai::GenerateContentRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
generation_config: None,
|
||||
safety_settings: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn language_model_request_message_to_google_ai(
|
||||
message: proto::LanguageModelRequestMessage,
|
||||
) -> Result<google_ai::Content> {
|
||||
let role = proto::LanguageModelRole::from_i32(message.role)
|
||||
.ok_or_else(|| anyhow!("invalid role {}", message.role))?;
|
||||
|
||||
Ok(google_ai::Content {
|
||||
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
|
||||
text: message.content,
|
||||
})],
|
||||
role: match role {
|
||||
proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
|
||||
proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
|
||||
proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn count_tokens_request_to_google_ai(
|
||||
request: proto::CountTokensWithLanguageModel,
|
||||
) -> Result<google_ai::CountTokensRequest> {
|
||||
Ok(google_ai::CountTokensRequest {
|
||||
contents: request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(language_model_request_message_to_google_ai)
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
})
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::ips_file::IpsFile;
|
||||
use crate::{api::slack, AppState, Error, Result};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use axum::{
|
||||
@@ -9,15 +9,17 @@ use axum::{
|
||||
routing::post,
|
||||
Extension, Router, TypedHeader,
|
||||
};
|
||||
use rpc::ExtensionMetadata;
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::{Serialize, Serializer};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, CallEvent, CopilotEvent, CpuEvent, EditEvent,
|
||||
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent, SettingEvent,
|
||||
EditorEvent, Event, EventRequestBody, EventWrapper, MemoryEvent, SettingEvent,
|
||||
};
|
||||
use util::SemanticVersion;
|
||||
|
||||
use crate::{api::slack, AppState, Error, Result};
|
||||
|
||||
use super::ips_file::IpsFile;
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
@@ -329,21 +331,6 @@ pub async fn post_events(
|
||||
&request_body,
|
||||
first_event_at,
|
||||
)),
|
||||
Event::Extension(event) => {
|
||||
let metadata = app
|
||||
.db
|
||||
.get_extension_version(&event.extension_id, &event.version)
|
||||
.await?;
|
||||
to_upload
|
||||
.extension_events
|
||||
.push(ExtensionEventRow::from_event(
|
||||
event.clone(),
|
||||
&wrapper,
|
||||
&request_body,
|
||||
metadata,
|
||||
first_event_at,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -365,7 +352,6 @@ struct ToUpload {
|
||||
memory_events: Vec<MemoryEventRow>,
|
||||
app_events: Vec<AppEventRow>,
|
||||
setting_events: Vec<SettingEventRow>,
|
||||
extension_events: Vec<ExtensionEventRow>,
|
||||
edit_events: Vec<EditEventRow>,
|
||||
action_events: Vec<ActionEventRow>,
|
||||
}
|
||||
@@ -424,15 +410,6 @@ impl ToUpload {
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{SETTING_EVENTS_TABLE}'"))?;
|
||||
|
||||
const EXTENSION_EVENTS_TABLE: &str = "extension_events";
|
||||
Self::upload_to_table(
|
||||
EXTENSION_EVENTS_TABLE,
|
||||
&self.extension_events,
|
||||
clickhouse_client,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{EXTENSION_EVENTS_TABLE}'"))?;
|
||||
|
||||
const EDIT_EVENTS_TABLE: &str = "edit_events";
|
||||
Self::upload_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
|
||||
.await
|
||||
@@ -459,12 +436,6 @@ impl ToUpload {
|
||||
}
|
||||
|
||||
insert.end().await?;
|
||||
|
||||
let event_count = rows.len();
|
||||
log::info!(
|
||||
"wrote {event_count} {event_specifier} to '{table}'",
|
||||
event_specifier = if event_count == 1 { "event" } else { "events" }
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -528,9 +499,9 @@ impl EditorEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
os_name: body.os_name.clone(),
|
||||
os_version: body.os_version.clone().unwrap_or_default(),
|
||||
@@ -590,9 +561,9 @@ impl CopilotEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
os_name: body.os_name.clone(),
|
||||
os_version: body.os_version.clone().unwrap_or_default(),
|
||||
@@ -645,9 +616,9 @@ impl CallEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone().unwrap_or_default(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -694,9 +665,9 @@ impl AssistantEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -738,9 +709,9 @@ impl CpuEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -785,9 +756,9 @@ impl MemoryEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -831,9 +802,9 @@ impl AppEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -876,9 +847,9 @@ impl SettingEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -890,68 +861,6 @@ impl SettingEventRow {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct ExtensionEventRow {
|
||||
// AppInfoBase
|
||||
app_version: String,
|
||||
major: Option<i32>,
|
||||
minor: Option<i32>,
|
||||
patch: Option<i32>,
|
||||
release_channel: String,
|
||||
|
||||
// ClientEventBase
|
||||
installation_id: Option<String>,
|
||||
session_id: Option<String>,
|
||||
is_staff: Option<bool>,
|
||||
time: i64,
|
||||
|
||||
// ExtensionEventRow
|
||||
extension_id: Arc<str>,
|
||||
extension_version: Arc<str>,
|
||||
dev: bool,
|
||||
schema_version: Option<i32>,
|
||||
wasm_api_version: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtensionEventRow {
|
||||
fn from_event(
|
||||
event: ExtensionEvent,
|
||||
wrapper: &EventWrapper,
|
||||
body: &EventRequestBody,
|
||||
extension_metadata: Option<ExtensionMetadata>,
|
||||
first_event_at: chrono::DateTime<chrono::Utc>,
|
||||
) -> Self {
|
||||
let semver = body.semver();
|
||||
let time =
|
||||
first_event_at + chrono::Duration::milliseconds(wrapper.milliseconds_since_first_event);
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
is_staff: body.is_staff,
|
||||
time: time.timestamp_millis(),
|
||||
extension_id: event.extension_id,
|
||||
extension_version: event.version,
|
||||
dev: extension_metadata.is_none(),
|
||||
schema_version: extension_metadata
|
||||
.as_ref()
|
||||
.and_then(|metadata| metadata.manifest.schema_version),
|
||||
wasm_api_version: extension_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.manifest
|
||||
.wasm_api_version
|
||||
.as_ref()
|
||||
.map(|version| version.to_string())
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct EditEventRow {
|
||||
// AppInfoBase
|
||||
@@ -991,9 +900,9 @@ impl EditEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
@@ -1040,9 +949,9 @@ impl ActionEventRow {
|
||||
|
||||
Self {
|
||||
app_version: body.app_version.clone(),
|
||||
major: semver.map(|v| v.major() as i32),
|
||||
minor: semver.map(|v| v.minor() as i32),
|
||||
patch: semver.map(|v| v.patch() as i32),
|
||||
major: semver.map(|s| s.major as i32),
|
||||
minor: semver.map(|s| s.minor as i32),
|
||||
patch: semver.map(|s| s.patch as i32),
|
||||
release_channel: body.release_channel.clone().unwrap_or_default(),
|
||||
installation_id: body.installation_id.clone(),
|
||||
session_id: body.session_id.clone(),
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::db::ExtensionVersionConstraints;
|
||||
use crate::{db::NewExtensionVersion, AppState, Error, Result};
|
||||
use crate::{
|
||||
db::{ExtensionMetadata, NewExtensionVersion},
|
||||
executor::Executor,
|
||||
AppState, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context as _};
|
||||
use aws_sdk_s3::presigning::PresigningConfig;
|
||||
use axum::{
|
||||
@@ -10,22 +13,14 @@ use axum::{
|
||||
Extension, Json, Router,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use rpc::{ExtensionApiManifest, GetExtensionsResponse};
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use time::PrimitiveDateTime;
|
||||
use util::{maybe, ResultExt};
|
||||
use util::ResultExt;
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/extensions", get(get_extensions))
|
||||
.route("/extensions/updates", get(get_extension_updates))
|
||||
.route("/extensions/:extension_id", get(get_extension_versions))
|
||||
.route(
|
||||
"/extensions/:extension_id/download",
|
||||
get(download_latest_extension),
|
||||
)
|
||||
.route(
|
||||
"/extensions/:extension_id/:version/download",
|
||||
get(download_extension),
|
||||
@@ -35,114 +30,6 @@ pub fn router() -> Router {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetExtensionsParams {
|
||||
filter: Option<String>,
|
||||
#[serde(default)]
|
||||
ids: Option<String>,
|
||||
#[serde(default)]
|
||||
max_schema_version: i32,
|
||||
}
|
||||
|
||||
async fn get_extensions(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetExtensionsParams>,
|
||||
) -> Result<Json<GetExtensionsResponse>> {
|
||||
let extension_ids = params
|
||||
.ids
|
||||
.as_ref()
|
||||
.map(|s| s.split(',').map(|s| s.trim()).collect::<Vec<_>>());
|
||||
|
||||
let extensions = if let Some(extension_ids) = extension_ids {
|
||||
app.db.get_extensions_by_ids(&extension_ids, None).await?
|
||||
} else {
|
||||
app.db
|
||||
.get_extensions(params.filter.as_deref(), params.max_schema_version, 500)
|
||||
.await?
|
||||
};
|
||||
|
||||
Ok(Json(GetExtensionsResponse { data: extensions }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetExtensionUpdatesParams {
|
||||
ids: String,
|
||||
min_schema_version: i32,
|
||||
max_schema_version: i32,
|
||||
min_wasm_api_version: SemanticVersion,
|
||||
max_wasm_api_version: SemanticVersion,
|
||||
}
|
||||
|
||||
async fn get_extension_updates(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetExtensionUpdatesParams>,
|
||||
) -> Result<Json<GetExtensionsResponse>> {
|
||||
let constraints = ExtensionVersionConstraints {
|
||||
schema_versions: params.min_schema_version..=params.max_schema_version,
|
||||
wasm_api_versions: params.min_wasm_api_version..=params.max_wasm_api_version,
|
||||
};
|
||||
|
||||
let extension_ids = params.ids.split(',').map(|s| s.trim()).collect::<Vec<_>>();
|
||||
|
||||
let extensions = app
|
||||
.db
|
||||
.get_extensions_by_ids(&extension_ids, Some(&constraints))
|
||||
.await?;
|
||||
|
||||
Ok(Json(GetExtensionsResponse { data: extensions }))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetExtensionVersionsParams {
|
||||
extension_id: String,
|
||||
}
|
||||
|
||||
async fn get_extension_versions(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Path(params): Path<GetExtensionVersionsParams>,
|
||||
) -> Result<Json<GetExtensionsResponse>> {
|
||||
let extension_versions = app.db.get_extension_versions(¶ms.extension_id).await?;
|
||||
|
||||
Ok(Json(GetExtensionsResponse {
|
||||
data: extension_versions,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DownloadLatestExtensionParams {
|
||||
extension_id: String,
|
||||
min_schema_version: Option<i32>,
|
||||
max_schema_version: Option<i32>,
|
||||
min_wasm_api_version: Option<SemanticVersion>,
|
||||
max_wasm_api_version: Option<SemanticVersion>,
|
||||
}
|
||||
|
||||
async fn download_latest_extension(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Path(params): Path<DownloadLatestExtensionParams>,
|
||||
) -> Result<Redirect> {
|
||||
let constraints = maybe!({
|
||||
let min_schema_version = params.min_schema_version?;
|
||||
let max_schema_version = params.max_schema_version?;
|
||||
let min_wasm_api_version = params.min_wasm_api_version?;
|
||||
let max_wasm_api_version = params.max_wasm_api_version?;
|
||||
|
||||
Some(ExtensionVersionConstraints {
|
||||
schema_versions: min_schema_version..=max_schema_version,
|
||||
wasm_api_versions: min_wasm_api_version..=max_wasm_api_version,
|
||||
})
|
||||
});
|
||||
|
||||
let extension = app
|
||||
.db
|
||||
.get_extension(¶ms.extension_id, constraints.as_ref())
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("unknown extension"))?;
|
||||
download_extension(
|
||||
Extension(app),
|
||||
Path(DownloadExtensionParams {
|
||||
extension_id: params.extension_id,
|
||||
version: extension.manifest.version.to_string(),
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -151,6 +38,28 @@ struct DownloadExtensionParams {
|
||||
version: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GetExtensionsResponse {
|
||||
pub data: Vec<ExtensionMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ExtensionManifest {
|
||||
name: String,
|
||||
version: String,
|
||||
description: Option<String>,
|
||||
authors: Vec<String>,
|
||||
repository: String,
|
||||
}
|
||||
|
||||
async fn get_extensions(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetExtensionsParams>,
|
||||
) -> Result<Json<GetExtensionsResponse>> {
|
||||
let extensions = app.db.get_extensions(params.filter.as_deref(), 500).await?;
|
||||
Ok(Json(GetExtensionsResponse { data: extensions }))
|
||||
}
|
||||
|
||||
async fn download_extension(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Path(params): Path<DownloadExtensionParams>,
|
||||
@@ -199,7 +108,7 @@ async fn download_extension(
|
||||
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
|
||||
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
|
||||
|
||||
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
|
||||
pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
|
||||
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
|
||||
log::info!("no blob store client");
|
||||
return;
|
||||
@@ -209,7 +118,6 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app_state.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
@@ -331,7 +239,7 @@ async fn fetch_extension_manifest(
|
||||
})?
|
||||
.to_vec();
|
||||
let manifest =
|
||||
serde_json::from_slice::<ExtensionApiManifest>(&manifest_bytes).with_context(|| {
|
||||
serde_json::from_slice::<ExtensionManifest>(&manifest_bytes).with_context(|| {
|
||||
format!(
|
||||
"invalid manifest for extension {extension_id} version {version}: {}",
|
||||
String::from_utf8_lossy(&manifest_bytes)
|
||||
@@ -351,8 +259,6 @@ async fn fetch_extension_manifest(
|
||||
description: manifest.description.unwrap_or_default(),
|
||||
authors: manifest.authors,
|
||||
repository: manifest.repository,
|
||||
schema_version: manifest.schema_version.unwrap_or(0),
|
||||
wasm_api_version: manifest.wasm_api_version,
|
||||
published_at,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use collections::HashMap;
|
||||
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_derive::Deserialize;
|
||||
use serde_derive::Serialize;
|
||||
use serde_json::Value;
|
||||
use util::SemanticVersion;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IpsFile {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::{
|
||||
db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId},
|
||||
rpc::Principal,
|
||||
db::{self, AccessTokenId, Database, UserId},
|
||||
AppState, Error, Result,
|
||||
};
|
||||
use anyhow::{anyhow, Context};
|
||||
@@ -20,11 +19,11 @@ use std::sync::OnceLock;
|
||||
use std::{sync::Arc, time::Instant};
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
/// Validates the authorization header and adds an Extension<Principal> to the request.
|
||||
/// Authorization: <user-id> <token>
|
||||
/// <token> can be an access_token attached to that user, or an access token of an admin
|
||||
/// or (in development) the string ADMIN:<config.api_token>.
|
||||
/// Authorization: "dev-server-token" <token>
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
||||
pub struct Impersonator(pub Option<db::User>);
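As a rough sketch of the `<user-id> <token>` header shape described in the doc comments above (the helper name is hypothetical; `validate_header` below does the real work):

```rust
// Sketch: splits "Authorization: <user-id> <token>" the way validate_header reads it.
fn parse_user_auth(header: &str) -> Option<(u64, &str)> {
    let mut parts = header.split_whitespace();
    let user_id: u64 = parts.next()?.parse().ok()?;
    let token = parts.next()?;
    Some((user_id, token))
}
```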
|
||||
|
||||
/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
|
||||
/// and one for the access tokens that we issue.
|
||||
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let mut auth_header = req
|
||||
.headers()
|
||||
@@ -38,26 +37,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
})?
|
||||
.split_whitespace();
|
||||
|
||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||
|
||||
let first = auth_header.next().unwrap_or("");
|
||||
if first == "dev-server-token" {
|
||||
let dev_server_token = auth_header.next().ok_or_else(|| {
|
||||
Error::Http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing dev-server-token token in authorization header".to_string(),
|
||||
)
|
||||
})?;
|
||||
let dev_server = verify_dev_server_token(dev_server_token, &state.db)
|
||||
.await
|
||||
.map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(Principal::DevServer(dev_server));
|
||||
return Ok::<_, Error>(next.run(req).await);
|
||||
}
|
||||
|
||||
let user_id = UserId(first.parse().map_err(|_| {
|
||||
let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
|
||||
Error::Http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"missing user id in authorization header".to_string(),
|
||||
@@ -71,6 +51,8 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
)
|
||||
})?;
|
||||
|
||||
let state = req.extensions().get::<Arc<AppState>>().unwrap();
|
||||
|
||||
// In development, allow impersonation using the admin API token.
|
||||
// Don't allow this in production because we can't tell who is doing
|
||||
// the impersonating.
|
||||
@@ -94,17 +76,18 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user {} not found", user_id))?;
|
||||
|
||||
if let Some(impersonator_id) = validate_result.impersonator_id {
|
||||
let admin = state
|
||||
let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
|
||||
let impersonator = state
|
||||
.db
|
||||
.get_user_by_id(impersonator_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
|
||||
req.extensions_mut()
|
||||
.insert(Principal::Impersonated { user, admin });
|
||||
Some(impersonator)
|
||||
} else {
|
||||
req.extensions_mut().insert(Principal::User(user));
|
||||
None
|
||||
};
|
||||
req.extensions_mut().insert(user);
|
||||
req.extensions_mut().insert(Impersonator(impersonator));
|
||||
return Ok::<_, Error>(next.run(req).await);
|
||||
}
|
||||
}
|
||||
@@ -230,33 +213,6 @@ pub async fn verify_access_token(
|
||||
})
|
||||
}
|
||||
|
||||
// a dev_server_token has the format <id>.<base64>. This is to make them
|
||||
// relatively easy to copy/paste around.
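A minimal parsing sketch under the `<id>.<base64>` shape stated in the comment above; the helper is hypothetical and only illustrates the token format:

```rust
// Sketch: splits a dev server token of the form "<id>.<base64>" into its two parts.
fn split_dev_server_token(token: &str) -> Option<(u64, &str)> {
    let mut parts = token.splitn(2, '.');
    let id: u64 = parts.next()?.parse().ok()?;
    let secret = parts.next()?;
    Some((id, secret))
}
```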
|
||||
pub async fn verify_dev_server_token(
|
||||
dev_server_token: &str,
|
||||
db: &Arc<Database>,
|
||||
) -> anyhow::Result<dev_server::Model> {
|
||||
let mut parts = dev_server_token.splitn(2, '.');
|
||||
let id = DevServerId(parts.next().unwrap_or_default().parse()?);
|
||||
let token = parts
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("invalid dev server token format"))?;
|
||||
|
||||
let token_hash = hash_access_token(&token);
|
||||
let server = db.get_dev_server(id).await?;
|
||||
|
||||
if server
|
||||
.hashed_token
|
||||
.as_bytes()
|
||||
.ct_eq(token_hash.as_ref())
|
||||
.into()
|
||||
{
|
||||
Ok(server)
|
||||
} else {
|
||||
Err(anyhow!("wrong token for dev server"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use rand::thread_rng;
|
||||
|
||||
97
crates/collab/src/bin/seed.rs
Normal file
97
crates/collab/src/bin/seed.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use collab::{
|
||||
db::{self, NewUserParams},
|
||||
env::load_dotenv,
|
||||
executor::Executor,
|
||||
};
|
||||
use db::{ConnectOptions, Database};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use std::{fmt::Write, fs};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GitHubUser {
|
||||
id: i32,
|
||||
login: String,
|
||||
email: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
load_dotenv().expect("failed to load .env.toml file");
|
||||
|
||||
let mut admin_logins = load_admins("crates/collab/.admins.default.json")
|
||||
.expect("failed to load default admins file");
|
||||
if let Ok(other_admins) = load_admins("./.admins.json") {
|
||||
admin_logins.extend(other_admins);
|
||||
}
|
||||
|
||||
let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var");
|
||||
let db = Database::new(ConnectOptions::new(database_url), Executor::Production)
|
||||
.await
|
||||
.expect("failed to connect to postgres database");
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Create admin users for all of the users in `.admins.json` or `.admins.default.json`.
|
||||
for admin_login in admin_logins {
|
||||
let user = fetch_github::<GitHubUser>(
|
||||
&client,
|
||||
&format!("https://api.github.com/users/{admin_login}"),
|
||||
)
|
||||
.await;
|
||||
db.create_user(
|
||||
&user.email.unwrap_or(format!("{admin_login}@example.com")),
|
||||
true,
|
||||
NewUserParams {
|
||||
github_login: user.login,
|
||||
github_user_id: user.id,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("failed to create admin user");
|
||||
}
|
||||
|
||||
// Fetch 100 other random users from GitHub and insert them into the database.
|
||||
let mut user_count = db
|
||||
.get_all_users(0, 200)
|
||||
.await
|
||||
.expect("failed to load users from db")
|
||||
.len();
|
||||
let mut last_user_id = None;
|
||||
while user_count < 100 {
|
||||
let mut uri = "https://api.github.com/users?per_page=100".to_string();
|
||||
if let Some(last_user_id) = last_user_id {
|
||||
write!(&mut uri, "&since={}", last_user_id).unwrap();
|
||||
}
|
||||
let users = fetch_github::<Vec<GitHubUser>>(&client, &uri).await;
|
||||
|
||||
for github_user in users {
|
||||
last_user_id = Some(github_user.id);
|
||||
user_count += 1;
|
||||
db.get_or_create_user_by_github_account(
|
||||
&github_user.login,
|
||||
Some(github_user.id),
|
||||
github_user.email.as_deref(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("failed to insert user");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_admins(path: &str) -> anyhow::Result<Vec<String>> {
|
||||
let file_content = fs::read_to_string(path)?;
|
||||
Ok(serde_json::from_str(&file_content)?)
|
||||
}
|
||||
|
||||
async fn fetch_github<T: DeserializeOwned>(client: &reqwest::Client, url: &str) -> T {
|
||||
let response = client
|
||||
.get(url)
|
||||
.header("user-agent", "zed")
|
||||
.send()
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("failed to fetch '{}'", url));
|
||||
response
|
||||
.json()
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("failed to deserialize github user from '{}'", url))
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use futures::StreamExt;
|
||||
use rand::{prelude::StdRng, Rng, SeedableRng};
|
||||
use rpc::{
|
||||
proto::{self},
|
||||
ConnectionId, ExtensionMetadata,
|
||||
ConnectionId,
|
||||
};
|
||||
use sea_orm::{
|
||||
entity::prelude::*,
|
||||
@@ -21,13 +21,11 @@ use sea_orm::{
|
||||
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
|
||||
TransactionTrait,
|
||||
};
|
||||
use semantic_version::SemanticVersion;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{ser::Error as _, Deserialize, Serialize, Serializer};
|
||||
use sqlx::{
|
||||
migrate::{Migrate, Migration, MigrationSource},
|
||||
Connection,
|
||||
};
|
||||
use std::ops::RangeInclusive;
|
||||
use std::{
|
||||
fmt::Write as _,
|
||||
future::Future,
|
||||
@@ -38,7 +36,7 @@ use std::{
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use time::PrimitiveDateTime;
|
||||
use time::{format_description::well_known::iso8601, PrimitiveDateTime};
|
||||
use tokio::sync::{Mutex, OwnedMutexGuard};
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -130,6 +128,12 @@ impl Database {
|
||||
Ok(new_migrations)
|
||||
}
|
||||
|
||||
/// Initializes static data that resides in the database by upserting it.
|
||||
pub async fn initialize_static_data(&mut self) -> Result<()> {
|
||||
self.initialize_notification_kinds().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Transaction runs things in a transaction. If you want to call other methods
|
||||
/// and pass the transaction around you need to reborrow the transaction at each
|
||||
/// call site with: `&*tx`.
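A minimal, self-contained sketch of the reborrow the comment above refers to; the types and helpers here are stand-ins, not part of the diff:

```rust
// Sketch: `&*guard` reborrows an owned guard so helpers can take `&Tx` without consuming it.
struct Tx;
fn first_helper(_tx: &Tx) {}
fn second_helper(_tx: &Tx) {}

fn main() {
    let tx = Box::new(Tx);   // stand-in for the owned transaction guard
    first_helper(&*tx);      // reborrow at each call site
    second_helper(&*tx);     // the guard itself is never moved
}
```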
|
||||
@@ -454,16 +458,6 @@ pub struct CreatedChannelMessage {
|
||||
pub notifications: NotificationBatch,
|
||||
}
|
||||
|
||||
pub struct UpdatedChannelMessage {
|
||||
pub message_id: MessageId,
|
||||
pub participant_connection_ids: Vec<ConnectionId>,
|
||||
pub notifications: NotificationBatch,
|
||||
pub reply_to_message_id: Option<MessageId>,
|
||||
pub timestamp: PrimitiveDateTime,
|
||||
pub deleted_mention_notification_ids: Vec<NotificationId>,
|
||||
pub updated_mention_notifications: Vec<rpc::proto::Notification>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
|
||||
pub struct Invite {
|
||||
pub email_address: String,
|
||||
@@ -552,7 +546,7 @@ pub struct Channel {
|
||||
}
|
||||
|
||||
impl Channel {
|
||||
pub fn from_model(value: channel::Model) -> Self {
|
||||
fn from_model(value: channel::Model) -> Self {
|
||||
Channel {
|
||||
id: value.id,
|
||||
visibility: value.visibility,
|
||||
@@ -610,14 +604,16 @@ pub struct RejoinedChannelBuffer {
|
||||
#[derive(Clone)]
|
||||
pub struct JoinRoom {
|
||||
pub room: proto::Room,
|
||||
pub channel: Option<channel::Model>,
|
||||
pub channel_id: Option<ChannelId>,
|
||||
pub channel_members: Vec<UserId>,
|
||||
}
|
||||
|
||||
pub struct RejoinedRoom {
|
||||
pub room: proto::Room,
|
||||
pub rejoined_projects: Vec<RejoinedProject>,
|
||||
pub reshared_projects: Vec<ResharedProject>,
|
||||
pub channel: Option<channel::Model>,
|
||||
pub channel_id: Option<ChannelId>,
|
||||
pub channel_members: Vec<UserId>,
|
||||
}
|
||||
|
||||
pub struct ResharedProject {
|
||||
@@ -653,7 +649,8 @@ pub struct RejoinedWorktree {
|
||||
|
||||
pub struct LeftRoom {
|
||||
pub room: proto::Room,
|
||||
pub channel: Option<channel::Model>,
|
||||
pub channel_id: Option<ChannelId>,
|
||||
pub channel_members: Vec<UserId>,
|
||||
pub left_projects: HashMap<ProjectId, LeftProject>,
|
||||
pub canceled_calls_to_user_ids: Vec<UserId>,
|
||||
pub deleted: bool,
|
||||
@@ -661,7 +658,8 @@ pub struct LeftRoom {
|
||||
|
||||
pub struct RefreshedRoom {
|
||||
pub room: proto::Room,
|
||||
pub channel: Option<channel::Model>,
|
||||
pub channel_id: Option<ChannelId>,
|
||||
pub channel_members: Vec<UserId>,
|
||||
pub stale_participant_user_ids: Vec<UserId>,
|
||||
pub canceled_calls_to_user_ids: Vec<UserId>,
|
||||
}
|
||||
@@ -729,12 +727,36 @@ pub struct NewExtensionVersion {
|
||||
pub description: String,
|
||||
pub authors: Vec<String>,
|
||||
pub repository: String,
|
||||
pub schema_version: i32,
|
||||
pub wasm_api_version: Option<String>,
|
||||
pub published_at: PrimitiveDateTime,
|
||||
}
|
||||
|
||||
pub struct ExtensionVersionConstraints {
|
||||
pub schema_versions: RangeInclusive<i32>,
|
||||
pub wasm_api_versions: RangeInclusive<SemanticVersion>,
|
||||
#[derive(Debug, Serialize, PartialEq)]
|
||||
pub struct ExtensionMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub authors: Vec<String>,
|
||||
pub description: String,
|
||||
pub repository: String,
|
||||
#[serde(serialize_with = "serialize_iso8601")]
|
||||
pub published_at: PrimitiveDateTime,
|
||||
pub download_count: u64,
|
||||
}
|
||||
|
||||
pub fn serialize_iso8601<S: Serializer>(
|
||||
datetime: &PrimitiveDateTime,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
const SERDE_CONFIG: iso8601::EncodedConfig = iso8601::Config::DEFAULT
|
||||
.set_year_is_six_digits(false)
|
||||
.set_time_precision(iso8601::TimePrecision::Second {
|
||||
decimal_digits: None,
|
||||
})
|
||||
.encode();
|
||||
|
||||
datetime
|
||||
.assume_utc()
|
||||
.format(&time::format_description::well_known::Iso8601::<SERDE_CONFIG>)
|
||||
.map_err(S::Error::custom)?
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
@@ -67,34 +67,31 @@ macro_rules! id_type {
|
||||
};
|
||||
}
|
||||
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
id_type!(ChannelId);
|
||||
id_type!(ChannelMemberId);
|
||||
id_type!(ContactId);
|
||||
id_type!(DevServerId);
|
||||
id_type!(ExtensionId);
|
||||
id_type!(FlagId);
|
||||
id_type!(FollowerId);
|
||||
id_type!(HostedProjectId);
|
||||
id_type!(MessageId);
|
||||
id_type!(NotificationId);
|
||||
id_type!(NotificationKindId);
|
||||
id_type!(ProjectCollaboratorId);
|
||||
id_type!(ProjectId);
|
||||
id_type!(ReplicaId);
|
||||
id_type!(ContactId);
|
||||
id_type!(FollowerId);
|
||||
id_type!(RoomId);
|
||||
id_type!(RoomParticipantId);
|
||||
id_type!(ProjectId);
|
||||
id_type!(ProjectCollaboratorId);
|
||||
id_type!(ReplicaId);
|
||||
id_type!(ServerId);
|
||||
id_type!(SignupId);
|
||||
id_type!(UserId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(FlagId);
|
||||
id_type!(ExtensionId);
|
||||
id_type!(NotificationId);
|
||||
id_type!(NotificationKindId);
|
||||
id_type!(HostedProjectId);
|
||||
|
||||
/// ChannelRole gives you permissions for both channels and calls.
|
||||
#[derive(
|
||||
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
|
||||
)]
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(None)")]
|
||||
pub enum ChannelRole {
|
||||
/// Admin can read/write and change permissions.
|
||||
|
||||
@@ -5,13 +5,11 @@ pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod dev_servers;
pub mod extensions;
pub mod hosted_projects;
pub mod messages;
pub mod notifications;
pub mod projects;
pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;

@@ -45,7 +45,11 @@ impl Database {
        name: &str,
        parent_channel_id: Option<ChannelId>,
        admin_id: UserId,
    ) -> Result<(channel::Model, Option<channel_member::Model>)> {
    ) -> Result<(
        Channel,
        Option<channel_member::Model>,
        Vec<channel_member::Model>,
    )> {
        let name = Self::sanitize_channel_name(name)?;
        self.transaction(move |tx| async move {
            let mut parent = None;
@@ -86,7 +90,12 @@ impl Database {
                );
            }

            Ok((channel, membership))
            let channel_members = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
                .all(&*tx)
                .await?;

            Ok((Channel::from_model(channel), membership, channel_members))
        })
        .await
    }
@@ -172,7 +181,7 @@ impl Database {
        channel_id: ChannelId,
        visibility: ChannelVisibility,
        admin_id: UserId,
    ) -> Result<channel::Model> {
    ) -> Result<(Channel, Vec<channel_member::Model>)> {
        self.transaction(move |tx| async move {
            let channel = self.get_channel_internal(channel_id, &tx).await?;
            self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -205,7 +214,12 @@ impl Database {
            model.visibility = ActiveValue::Set(visibility);
            let channel = model.update(&*tx).await?;

            Ok(channel)
            let channel_members = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
                .all(&*tx)
                .await?;

            Ok((Channel::from_model(channel), channel_members))
        })
        .await
    }
@@ -231,12 +245,21 @@ impl Database {
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<(ChannelId, Vec<ChannelId>)> {
    ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
        self.transaction(move |tx| async move {
            let channel = self.get_channel_internal(channel_id, &tx).await?;
            self.check_user_is_channel_admin(&channel, user_id, &tx)
                .await?;

            let members_to_notify: Vec<UserId> = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
                .select_only()
                .column(channel_member::Column::UserId)
                .distinct()
                .into_values::<_, QueryUserIds>()
                .all(&*tx)
                .await?;

            let channels_to_remove = self
                .get_channel_descendants_excluding_self([&channel], &tx)
                .await?
@@ -250,7 +273,7 @@ impl Database {
                .exec(&*tx)
                .await?;

            Ok((channel.root_id(), channels_to_remove))
            Ok((channels_to_remove, members_to_notify))
        })
        .await
    }
@@ -320,7 +343,7 @@ impl Database {
        channel_id: ChannelId,
        admin_id: UserId,
        new_name: &str,
    ) -> Result<channel::Model> {
    ) -> Result<(Channel, Vec<channel_member::Model>)> {
        self.transaction(move |tx| async move {
            let new_name = Self::sanitize_channel_name(new_name)?.to_string();

@@ -332,7 +355,12 @@ impl Database {
            model.name = ActiveValue::Set(new_name.clone());
            let channel = model.update(&*tx).await?;

            Ok(channel)
            let channel_members = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.eq(channel.root_id()))
                .all(&*tx)
                .await?;

            Ok((Channel::from_model(channel), channel_members))
        })
        .await
    }
@@ -956,7 +984,7 @@ impl Database {
        channel_id: ChannelId,
        new_parent_id: ChannelId,
        admin_id: UserId,
    ) -> Result<(ChannelId, Vec<Channel>)> {
    ) -> Result<(Vec<Channel>, Vec<channel_member::Model>)> {
        self.transaction(|tx| async move {
            let channel = self.get_channel_internal(channel_id, &tx).await?;
            self.check_user_is_channel_admin(&channel, admin_id, &tx)
@@ -1011,7 +1039,12 @@ impl Database {
                .map(|c| Channel::from_model(c))
                .collect::<Vec<_>>();

            Ok((root_id, channels))
            let channel_members = channel_member::Entity::find()
                .filter(channel_member::Column::ChannelId.eq(root_id))
                .all(&*tx)
                .await?;

            Ok((channels, channel_members))
        })
        .await
    }

@@ -1,18 +0,0 @@
use sea_orm::EntityTrait;

use super::{dev_server, Database, DevServerId};

impl Database {
    pub async fn get_dev_server(
        &self,
        dev_server_id: DevServerId,
    ) -> crate::Result<dev_server::Model> {
        self.transaction(|tx| async move {
            Ok(dev_server::Entity::find_by_id(dev_server_id)
                .one(&*tx)
                .await?
                .ok_or_else(|| anyhow::anyhow!("no dev server with id {}", dev_server_id))?)
        })
        .await
    }
}
@@ -1,202 +1,57 @@
use std::str::FromStr;

use chrono::Utc;
use sea_orm::sea_query::IntoCondition;
use util::ResultExt;

use super::*;

impl Database {
    pub async fn get_extensions(
        &self,
        filter: Option<&str>,
        max_schema_version: i32,
        limit: usize,
    ) -> Result<Vec<ExtensionMetadata>> {
        self.transaction(|tx| async move {
            let mut condition = Condition::all()
                .add(
                    extension::Column::LatestVersion
                        .into_expr()
                        .eq(extension_version::Column::Version.into_expr()),
                )
                .add(extension_version::Column::SchemaVersion.lte(max_schema_version));
            let mut condition = Condition::all();
            if let Some(filter) = filter {
                let fuzzy_name_filter = Self::fuzzy_like_string(filter);
                condition = condition.add(Expr::cust_with_expr("name ILIKE $1", fuzzy_name_filter));
            }

            self.get_extensions_where(condition, Some(limit as u64), &tx)
                .await
        })
        .await
    }

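The name filter above feeds `Self::fuzzy_like_string(filter)` into a raw `name ILIKE $1` expression. That helper is not part of this diff, so the following is only a plausible standalone sketch of that kind of pattern builder, not the actual implementation:

// Hypothetical fuzzy LIKE pattern builder: interleave wildcards so the query characters
// only have to appear in order, e.g. "vim" -> "%v%i%m%".
fn fuzzy_like_string(query: &str) -> String {
    let mut pattern = String::with_capacity(query.len() * 2 + 1);
    for ch in query.chars().filter(|ch| ch.is_alphanumeric()) {
        pattern.push('%');
        pattern.push(ch);
    }
    pattern.push('%');
    pattern
}

fn main() {
    assert_eq!(fuzzy_like_string("vim"), "%v%i%m%");
}
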
    pub async fn get_extensions_by_ids(
        &self,
        ids: &[&str],
        constraints: Option<&ExtensionVersionConstraints>,
    ) -> Result<Vec<ExtensionMetadata>> {
        self.transaction(|tx| async move {
            let extensions = extension::Entity::find()
                .filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
                .filter(condition)
                .order_by_desc(extension::Column::TotalDownloadCount)
                .order_by_asc(extension::Column::Name)
                .limit(Some(limit as u64))
                .filter(
                    extension::Column::LatestVersion
                        .into_expr()
                        .eq(extension_version::Column::Version.into_expr()),
                )
                .inner_join(extension_version::Entity)
                .select_also(extension_version::Entity)
                .all(&*tx)
                .await?;

            let mut max_versions = self
                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
                .await?;

            Ok(extensions
                .into_iter()
                .filter_map(|extension| {
                    let (version, _) = max_versions.remove(&extension.id)?;
                    Some(metadata_from_extension_and_version(extension, version))
                .filter_map(|(extension, latest_version)| {
                    let version = latest_version?;
                    Some(ExtensionMetadata {
                        id: extension.external_id,
                        name: extension.name,
                        version: version.version,
                        authors: version
                            .authors
                            .split(',')
                            .map(|author| author.trim().to_string())
                            .collect::<Vec<_>>(),
                        description: version.description,
                        repository: version.repository,
                        published_at: version.published_at,
                        download_count: extension.total_download_count as u64,
                    })
                })
                .collect())
        })
        .await
    }

    async fn get_latest_versions_for_extensions(
        &self,
        extensions: &[extension::Model],
        constraints: Option<&ExtensionVersionConstraints>,
        tx: &DatabaseTransaction,
    ) -> Result<HashMap<ExtensionId, (extension_version::Model, SemanticVersion)>> {
        let mut versions = extension_version::Entity::find()
            .filter(
                extension_version::Column::ExtensionId
                    .is_in(extensions.iter().map(|extension| extension.id)),
            )
            .stream(tx)
            .await?;

        let mut max_versions =
            HashMap::<ExtensionId, (extension_version::Model, SemanticVersion)>::default();
        while let Some(version) = versions.next().await {
            let version = version?;
            let Some(extension_version) = SemanticVersion::from_str(&version.version).log_err()
            else {
                continue;
            };

            if let Some((_, max_extension_version)) = &max_versions.get(&version.extension_id) {
                if max_extension_version > &extension_version {
                    continue;
                }
            }

            if let Some(constraints) = constraints {
                if !constraints
                    .schema_versions
                    .contains(&version.schema_version)
                {
                    continue;
                }

                if let Some(wasm_api_version) = version.wasm_api_version.as_ref() {
                    if let Some(version) = SemanticVersion::from_str(wasm_api_version).log_err() {
                        if !constraints.wasm_api_versions.contains(&version) {
                            continue;
                        }
                    } else {
                        continue;
                    }
                }
            }

            max_versions.insert(version.extension_id, (version, extension_version));
        }

        Ok(max_versions)
    }

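The loop above keeps, per extension, the highest published version that still satisfies the schema-version (and, when present, WASM API version) constraints. A standalone sketch of the same selection rule, using the public semver crate as a stand-in for the crate's own SemanticVersion type and omitting the WASM API check:

use std::collections::HashMap;
use std::ops::RangeInclusive;

use semver::Version; // stand-in for SemanticVersion

// Rows are (extension_id, version string, schema_version), mirroring the columns the
// query above streams out of extension_version.
fn pick_max_versions(
    rows: &[(u32, &str, i32)],
    schema_versions: &RangeInclusive<i32>,
) -> HashMap<u32, Version> {
    let mut max_versions: HashMap<u32, Version> = HashMap::new();
    for (extension_id, version, schema_version) in rows {
        // Skip rows whose version string does not parse instead of failing the whole query.
        let Ok(version) = Version::parse(version) else {
            continue;
        };
        // Skip rows that fall outside the requested schema-version constraint.
        if !schema_versions.contains(schema_version) {
            continue;
        }
        // Keep only the highest version seen so far for this extension.
        let is_newer = max_versions
            .get(extension_id)
            .map_or(true, |existing| version > *existing);
        if is_newer {
            max_versions.insert(*extension_id, version);
        }
    }
    max_versions
}

fn main() {
    let rows = [(1, "0.1.0", 1), (1, "0.2.0", 1), (1, "0.3.0", 99)];
    let max = pick_max_versions(&rows, &(0..=1));
    assert_eq!(max[&1].to_string(), "0.2.0"); // 0.3.0 is excluded by the schema constraint
}
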
    /// Returns all of the versions for the extension with the given ID.
    pub async fn get_extension_versions(
        &self,
        extension_id: &str,
    ) -> Result<Vec<ExtensionMetadata>> {
        self.transaction(|tx| async move {
            let condition = extension::Column::ExternalId
                .eq(extension_id)
                .into_condition();

            self.get_extensions_where(condition, None, &tx).await
        })
        .await
    }

    async fn get_extensions_where(
        &self,
        condition: Condition,
        limit: Option<u64>,
        tx: &DatabaseTransaction,
    ) -> Result<Vec<ExtensionMetadata>> {
        let extensions = extension::Entity::find()
            .inner_join(extension_version::Entity)
            .select_also(extension_version::Entity)
            .filter(condition)
            .order_by_desc(extension::Column::TotalDownloadCount)
            .order_by_asc(extension::Column::Name)
            .limit(limit)
            .all(tx)
            .await?;

        Ok(extensions
            .into_iter()
            .filter_map(|(extension, version)| {
                Some(metadata_from_extension_and_version(extension, version?))
            })
            .collect())
    }

    pub async fn get_extension(
        &self,
        extension_id: &str,
        constraints: Option<&ExtensionVersionConstraints>,
    ) -> Result<Option<ExtensionMetadata>> {
        self.transaction(|tx| async move {
            let extension = extension::Entity::find()
                .filter(extension::Column::ExternalId.eq(extension_id))
                .one(&*tx)
                .await?
                .ok_or_else(|| anyhow!("no such extension: {extension_id}"))?;

            let extensions = [extension];
            let mut versions = self
                .get_latest_versions_for_extensions(&extensions, constraints, &tx)
                .await?;
            let [extension] = extensions;

            Ok(versions.remove(&extension.id).map(|(max_version, _)| {
                metadata_from_extension_and_version(extension, max_version)
            }))
        })
        .await
    }

    pub async fn get_extension_version(
        &self,
        extension_id: &str,
        version: &str,
    ) -> Result<Option<ExtensionMetadata>> {
        self.transaction(|tx| async move {
            let extension = extension::Entity::find()
                .filter(extension::Column::ExternalId.eq(extension_id))
                .filter(extension_version::Column::Version.eq(version))
                .inner_join(extension_version::Entity)
                .select_also(extension_version::Entity)
                .one(&*tx)
                .await?;

            Ok(extension.and_then(|(extension, version)| {
                Some(metadata_from_extension_and_version(extension, version?))
            }))
        })
        .await
    }

    pub async fn get_known_extension_versions<'a>(&self) -> Result<HashMap<String, Vec<String>>> {
        self.transaction(|tx| async move {
            let mut extension_external_ids_by_id = HashMap::default();
@@ -280,8 +135,6 @@ impl Database {
                authors: ActiveValue::Set(version.authors.join(", ")),
                repository: ActiveValue::Set(version.repository.clone()),
                description: ActiveValue::Set(version.description.clone()),
                schema_version: ActiveValue::Set(version.schema_version),
                wasm_api_version: ActiveValue::Set(version.wasm_api_version.clone()),
                download_count: ActiveValue::NotSet,
            }
        }))
@@ -351,35 +204,3 @@ impl Database {
        .await
    }
}

fn metadata_from_extension_and_version(
    extension: extension::Model,
    version: extension_version::Model,
) -> ExtensionMetadata {
    ExtensionMetadata {
        id: extension.external_id.into(),
        manifest: rpc::ExtensionApiManifest {
            name: extension.name,
            version: version.version.into(),
            authors: version
                .authors
                .split(',')
                .map(|author| author.trim().to_string())
                .collect::<Vec<_>>(),
            description: Some(version.description),
            repository: version.repository,
            schema_version: Some(version.schema_version),
            wasm_api_version: version.wasm_api_version,
        },

        published_at: convert_time_to_chrono(version.published_at),
        download_count: extension.total_download_count as u64,
    }
}

pub fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> chrono::DateTime<Utc> {
    chrono::DateTime::from_naive_utc_and_offset(
        chrono::NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
        Utc,
    )
}

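A standalone usage sketch of the conversion above (the sample value relies on the `time` crate's `macros` feature; newer chrono releases also offer `DateTime::from_timestamp` as a non-deprecated alternative to `from_timestamp_opt`):

use chrono::{DateTime, NaiveDateTime, Utc};
use time::macros::datetime;

fn convert_time_to_chrono(time: time::PrimitiveDateTime) -> DateTime<Utc> {
    // Interpret the datetime as UTC, take its Unix timestamp (whole seconds only),
    // and rebuild a chrono DateTime from it; sub-second precision is dropped.
    DateTime::from_naive_utc_and_offset(
        NaiveDateTime::from_timestamp_opt(time.assume_utc().unix_timestamp(), 0).unwrap(),
        Utc,
    )
}

fn main() {
    let published_at = datetime!(2024-05-01 12:30:45);
    assert_eq!(
        convert_time_to_chrono(published_at).to_rfc3339(),
        "2024-05-01T12:30:45+00:00"
    );
}
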
@@ -1,8 +1,7 @@
use super::*;
use rpc::Notification;
use sea_orm::{SelectColumns, TryInsertResult};
use sea_orm::TryInsertResult;
use time::OffsetDateTime;
use util::ResultExt;

impl Database {
    /// Inserts a record representing a user joining the chat for a given channel.
@@ -163,9 +162,6 @@ impl Database {
                    lower_half: nonce.1,
                }),
                reply_to_message_id: row.reply_to_message_id.map(|id| id.to_proto()),
                edited_at: row
                    .edited_at
                    .map(|t| t.assume_utc().unix_timestamp() as u64),
            }
        })
        .collect::<Vec<_>>();
@@ -203,31 +199,6 @@ impl Database {
        Ok(messages)
    }

    fn format_mentions_to_entities(
        &self,
        message_id: MessageId,
        body: &str,
        mentions: &[proto::ChatMention],
    ) -> Result<Vec<tables::channel_message_mention::ActiveModel>> {
        Ok(mentions
            .iter()
            .filter_map(|mention| {
                let range = mention.range.as_ref()?;
                if !body.is_char_boundary(range.start as usize)
                    || !body.is_char_boundary(range.end as usize)
                {
                    return None;
                }
                Some(channel_message_mention::ActiveModel {
                    message_id: ActiveValue::Set(message_id),
                    start_offset: ActiveValue::Set(range.start as i32),
                    end_offset: ActiveValue::Set(range.end as i32),
                    user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
                })
            })
            .collect::<Vec<_>>())
    }

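The `is_char_boundary` guard above protects against mention offsets that would split a multi-byte character in the message body. A standalone illustration with plain std and a made-up message body:

fn mention_range_is_valid(body: &str, start: usize, end: usize) -> bool {
    // Offsets are byte offsets into the UTF-8 body; reject anything that does not land
    // on a character boundary, otherwise slicing the body later would panic.
    start <= end
        && end <= body.len()
        && body.is_char_boundary(start)
        && body.is_char_boundary(end)
}

fn main() {
    let body = "hi @zoë";
    assert!(mention_range_is_valid(body, 3, body.len()));
    // "ë" is two bytes long, so an offset one byte short of the end splits it.
    assert!(!mention_range_is_valid(body, 3, body.len() - 1));
}
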
    /// Creates a new channel message.
    #[allow(clippy::too_many_arguments)]
    pub async fn create_channel_message(
@@ -278,7 +249,6 @@ impl Database {
                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
                id: ActiveValue::NotSet,
                reply_to_message_id: ActiveValue::Set(reply_to_message_id),
                edited_at: ActiveValue::NotSet,
            })
            .on_conflict(
                OnConflict::columns([
@@ -300,7 +270,23 @@ impl Database {
            let mentioned_user_ids =
                mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();

            let mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
            let mentions = mentions
                .iter()
                .filter_map(|mention| {
                    let range = mention.range.as_ref()?;
                    if !body.is_char_boundary(range.start as usize)
                        || !body.is_char_boundary(range.end as usize)
                    {
                        return None;
                    }
                    Some(channel_message_mention::ActiveModel {
                        message_id: ActiveValue::Set(message_id),
                        start_offset: ActiveValue::Set(range.start as i32),
                        end_offset: ActiveValue::Set(range.end as i32),
                        user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
                    })
                })
                .collect::<Vec<_>>();
            if !mentions.is_empty() {
                channel_message_mention::Entity::insert_many(mentions)
                    .exec(&*tx)
@@ -481,20 +467,13 @@ impl Database {
        Ok(results)
    }

    fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option<i32> {
        self.notification_kinds_by_id
            .iter()
            .find(|(_, kind)| **kind == notification_kind)
            .map(|kind| kind.0 .0)
    }

    /// Removes the channel message with the given ID.
    pub async fn remove_channel_message(
        &self,
        channel_id: ChannelId,
        message_id: MessageId,
        user_id: UserId,
    ) -> Result<(Vec<ConnectionId>, Vec<NotificationId>)> {
    ) -> Result<Vec<ConnectionId>> {
        self.transaction(|tx| async move {
            let mut rows = channel_chat_participant::Entity::find()
                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
@@ -539,190 +518,7 @@ impl Database {
            }
        }

            let notification_kind_id =
                self.get_notification_kind_id_by_name("ChannelMessageMention");

            let existing_notifications = notification::Entity::find()
                .filter(notification::Column::EntityId.eq(message_id))
                .filter(notification::Column::Kind.eq(notification_kind_id))
                .select_column(notification::Column::Id)
                .all(&*tx)
                .await?;

            let existing_notification_ids = existing_notifications
                .into_iter()
                .map(|notification| notification.id)
                .collect();

            // remove all the mention notifications for this message
            notification::Entity::delete_many()
                .filter(notification::Column::EntityId.eq(message_id))
                .filter(notification::Column::Kind.eq(notification_kind_id))
                .exec(&*tx)
                .await?;

            Ok((participant_connection_ids, existing_notification_ids))
        })
        .await
    }

    /// Updates the channel message with the given ID, body and timestamp (edited_at).
    pub async fn update_channel_message(
        &self,
        channel_id: ChannelId,
        message_id: MessageId,
        user_id: UserId,
        body: &str,
        mentions: &[proto::ChatMention],
        edited_at: OffsetDateTime,
    ) -> Result<UpdatedChannelMessage> {
        self.transaction(|tx| async move {
            let channel = self.get_channel_internal(channel_id, &tx).await?;
            self.check_user_is_channel_participant(&channel, user_id, &tx)
                .await?;

            let mut rows = channel_chat_participant::Entity::find()
                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
                .stream(&*tx)
                .await?;

            let mut is_participant = false;
            let mut participant_connection_ids = Vec::new();
            let mut participant_user_ids = Vec::new();
            while let Some(row) = rows.next().await {
                let row = row?;
                if row.user_id == user_id {
                    is_participant = true;
                }
                participant_user_ids.push(row.user_id);
                participant_connection_ids.push(row.connection());
            }
            drop(rows);

            if !is_participant {
                Err(anyhow!("not a chat participant"))?;
            }

            let channel_message = channel_message::Entity::find_by_id(message_id)
                .filter(channel_message::Column::SenderId.eq(user_id))
                .one(&*tx)
                .await?;

            let Some(channel_message) = channel_message else {
                Err(anyhow!("Channel message not found"))?
            };

            let edited_at = edited_at.to_offset(time::UtcOffset::UTC);
            let edited_at = time::PrimitiveDateTime::new(edited_at.date(), edited_at.time());

            let updated_message = channel_message::ActiveModel {
                body: ActiveValue::Set(body.to_string()),
                edited_at: ActiveValue::Set(Some(edited_at)),
                reply_to_message_id: ActiveValue::Unchanged(channel_message.reply_to_message_id),
                id: ActiveValue::Unchanged(message_id),
                channel_id: ActiveValue::Unchanged(channel_id),
                sender_id: ActiveValue::Unchanged(user_id),
                sent_at: ActiveValue::Unchanged(channel_message.sent_at),
                nonce: ActiveValue::Unchanged(channel_message.nonce),
            };

            let result = channel_message::Entity::update_many()
                .set(updated_message)
                .filter(channel_message::Column::Id.eq(message_id))
                .filter(channel_message::Column::SenderId.eq(user_id))
                .exec(&*tx)
                .await?;
            if result.rows_affected == 0 {
                return Err(anyhow!(
                    "Attempted to edit a message (id: {message_id}) which does not exist anymore."
                ))?;
            }

            // We have to fetch the old mentions so we don't re-notify users who were
            // already mentioned when the message gets edited.
            let old_mentions = channel_message_mention::Entity::find()
                .filter(channel_message_mention::Column::MessageId.eq(message_id))
                .all(&*tx)
                .await?;

            // remove all existing mentions
            channel_message_mention::Entity::delete_many()
                .filter(channel_message_mention::Column::MessageId.eq(message_id))
                .exec(&*tx)
                .await?;

            let new_mentions = self.format_mentions_to_entities(message_id, body, mentions)?;
            if !new_mentions.is_empty() {
                // insert new mentions
                channel_message_mention::Entity::insert_many(new_mentions)
                    .exec(&*tx)
                    .await?;
            }

            let mut update_mention_user_ids = HashSet::default();
            let mut new_mention_user_ids =
                mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
            // Filter out users that were mentioned before
            for mention in &old_mentions {
                if new_mention_user_ids.contains(&mention.user_id.to_proto()) {
                    update_mention_user_ids.insert(mention.user_id.to_proto());
                }

                new_mention_user_ids.remove(&mention.user_id.to_proto());
            }

            let notification_kind_id =
                self.get_notification_kind_id_by_name("ChannelMessageMention");

            let existing_notifications = notification::Entity::find()
                .filter(notification::Column::EntityId.eq(message_id))
                .filter(notification::Column::Kind.eq(notification_kind_id))
                .all(&*tx)
                .await?;

            // determine which notifications should be updated or deleted
            let mut deleted_notification_ids = HashSet::default();
            let mut updated_mention_notifications = Vec::new();
            for notification in existing_notifications {
                if update_mention_user_ids.contains(&notification.recipient_id.to_proto()) {
                    if let Some(notification) =
                        self::notifications::model_to_proto(self, notification).log_err()
                    {
                        updated_mention_notifications.push(notification);
                    }
                } else {
                    deleted_notification_ids.insert(notification.id);
                }
            }

            let mut notifications = Vec::new();
            for mentioned_user in new_mention_user_ids {
                notifications.extend(
                    self.create_notification(
                        UserId::from_proto(mentioned_user),
                        rpc::Notification::ChannelMessageMention {
                            message_id: message_id.to_proto(),
                            sender_id: user_id.to_proto(),
                            channel_id: channel_id.to_proto(),
                        },
                        false,
                        &tx,
                    )
                    .await?,
                );
            }

            Ok(UpdatedChannelMessage {
                message_id,
                participant_connection_ids,
                notifications,
                reply_to_message_id: channel_message.reply_to_message_id,
                timestamp: channel_message.sent_at,
                deleted_mention_notification_ids: deleted_notification_ids
                    .into_iter()
                    .collect::<Vec<_>>(),
                updated_mention_notifications,
            })
            Ok(participant_connection_ids)
        })
        .await
    }

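The mention and notification bookkeeping above reduces to a set comparison between the old and new mention lists: users present in both keep their notification (updated in place), users only in the old list have theirs deleted, and users only in the new list get a fresh one. A standalone sketch of that classification with made-up user ids:

use std::collections::HashSet;

struct MentionDiff {
    updated: HashSet<u64>, // mentioned before and after: update the existing notification
    removed: HashSet<u64>, // mentioned before only: delete the stale notification
    created: HashSet<u64>, // mentioned after only: create a new notification
}

fn diff_mentions(old: &[u64], new: &[u64]) -> MentionDiff {
    let old: HashSet<u64> = old.iter().copied().collect();
    let new: HashSet<u64> = new.iter().copied().collect();
    MentionDiff {
        updated: old.intersection(&new).copied().collect(),
        removed: old.difference(&new).copied().collect(),
        created: new.difference(&old).copied().collect(),
    }
}

fn main() {
    let diff = diff_mentions(&[1, 2], &[2, 3]);
    assert_eq!(diff.updated, HashSet::from([2]));
    assert_eq!(diff.removed, HashSet::from([1]));
    assert_eq!(diff.created, HashSet::from([3]));
}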