Compare commits
8 Commits
tasks-moda
...
v0.135.2-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfd8d1860e | ||
|
|
ad5683bcdb | ||
|
|
a565c3a470 | ||
|
|
3f17540002 | ||
|
|
647c31e398 | ||
|
|
6d927a4991 | ||
|
|
3a84d5eaa3 | ||
|
|
3ad3a36643 |
@@ -1,15 +0,0 @@
|
||||
We have two cloudflare workers that let us serve some assets of this repo
|
||||
from Cloudflare.
|
||||
|
||||
* `open-source-website-assets` is used for `install.sh`
|
||||
* `docs-proxy` is used for `https://zed.dev/docs`
|
||||
|
||||
On push to `main`, both of these (and the files they depend on) are uploaded to Cloudflare.
|
||||
|
||||
### Deployment
|
||||
|
||||
These functions are deployed on push to main by the deploy_cloudflare.yml workflow. Worker Rules in Cloudflare intercept requests to zed.dev and proxy them to the appropriate workers.
|
||||
|
||||
### Testing
|
||||
|
||||
You can use [wrangler](https://developers.cloudflare.com/workers/cli-wrangler/install-update) to test these workers locally, or to deploy custom versions.
|
||||
@@ -1,14 +0,0 @@
|
||||
export default {
|
||||
async fetch(request, _env, _ctx) {
|
||||
const url = new URL(request.url);
|
||||
url.hostname = "docs-anw.pages.dev";
|
||||
|
||||
let res = await fetch(url, request);
|
||||
|
||||
if (res.status === 404) {
|
||||
res = await fetch("https://zed.dev/404");
|
||||
}
|
||||
|
||||
return res;
|
||||
},
|
||||
};
|
||||
@@ -1,8 +0,0 @@
|
||||
name = "docs-proxy"
|
||||
main = "src/worker.js"
|
||||
compatibility_date = "2024-05-03"
|
||||
workers_dev = true
|
||||
|
||||
[[routes]]
|
||||
pattern = "zed.dev/docs*"
|
||||
zone_name = "zed.dev"
|
||||
@@ -1,19 +0,0 @@
|
||||
export default {
|
||||
async fetch(request, env) {
|
||||
const url = new URL(request.url);
|
||||
const key = url.pathname.slice(1);
|
||||
|
||||
const object = await env.OPEN_SOURCE_WEBSITE_ASSETS_BUCKET.get(key);
|
||||
if (!object) {
|
||||
return await fetch("https://zed.dev/404");
|
||||
}
|
||||
|
||||
const headers = new Headers();
|
||||
object.writeHttpMetadata(headers);
|
||||
headers.set("etag", object.httpEtag);
|
||||
|
||||
return new Response(object.body, {
|
||||
headers,
|
||||
});
|
||||
},
|
||||
};
|
||||
@@ -1,8 +0,0 @@
|
||||
name = "open-source-website-assets"
|
||||
main = "src/worker.js"
|
||||
compatibility_date = "2024-05-15"
|
||||
workers_dev = true
|
||||
|
||||
[[r2_buckets]]
|
||||
binding = 'OPEN_SOURCE_WEBSITE_ASSETS_BUCKET'
|
||||
bucket_name = 'zed-open-source-website-assets'
|
||||
41
.github/workflows/ci.yml
vendored
41
.github/workflows/ci.yml
vendored
@@ -32,6 +32,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Remove untracked files
|
||||
@@ -86,6 +87,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: cargo clippy
|
||||
run: cargo xtask clippy
|
||||
@@ -102,17 +104,22 @@ jobs:
|
||||
# todo(linux): Actually run the tests
|
||||
linux_tests:
|
||||
name: (Linux) Run Clippy and tests
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- deploy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: configure linux
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: script/linux
|
||||
|
||||
- name: cargo clippy
|
||||
run: cargo xtask clippy
|
||||
@@ -123,12 +130,13 @@ jobs:
|
||||
# todo(windows): Actually run the tests
|
||||
windows_tests:
|
||||
name: (Windows) Run Clippy and tests
|
||||
runs-on: hosted-windows-1
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@v2
|
||||
@@ -171,6 +179,7 @@ jobs:
|
||||
# 25 was chosen arbitrarily.
|
||||
fetch-depth: 25
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Limit target directory size
|
||||
run: script/clear-target-dir-if-larger-than 100
|
||||
@@ -253,24 +262,26 @@ jobs:
|
||||
|
||||
bundle-linux:
|
||||
name: Create a Linux bundle
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- deploy
|
||||
runs-on: ubuntu-22.04 # keep the version fixed to avoid libc and dynamic linked library issues
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling') }}
|
||||
needs: [linux_tests]
|
||||
env:
|
||||
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
|
||||
steps:
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Limit target directory size
|
||||
run: script/clear-target-dir-if-larger-than 100
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Configure linux
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: script/linux
|
||||
|
||||
- name: Determine version and release channel
|
||||
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
|
||||
|
||||
56
.github/workflows/deploy_cloudflare.yml
vendored
56
.github/workflows/deploy_cloudflare.yml
vendored
@@ -1,56 +0,0 @@
|
||||
name: Deploy Docs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
deploy-docs:
|
||||
name: Deploy Docs
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Setup mdBook
|
||||
uses: peaceiris/actions-mdbook@v2
|
||||
with:
|
||||
mdbook-version: "0.4.37"
|
||||
|
||||
- name: Build book
|
||||
run: |
|
||||
set -euo pipefail
|
||||
mkdir -p target/deploy
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Deploy Docs
|
||||
uses: cloudflare/wrangler-action@v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: pages deploy target/deploy --project-name=docs
|
||||
|
||||
- name: Deploy Install
|
||||
uses: cloudflare/wrangler-action@v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: r2 object put -f script/install.sh zed-open-source-website-assets/install.sh
|
||||
|
||||
- name: Deploy Docs Workers
|
||||
uses: cloudflare/wrangler-action@v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: deploy .cloudflare/docs-proxy/src/worker.js
|
||||
|
||||
- name: Deploy Install Workers
|
||||
uses: cloudflare/wrangler-action@v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: deploy .cloudflare/docs-proxy/src/worker.js
|
||||
3
.github/workflows/deploy_collab.yml
vendored
3
.github/workflows/deploy_collab.yml
vendored
@@ -21,6 +21,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Run style checks
|
||||
@@ -40,6 +41,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install cargo nextest
|
||||
@@ -74,6 +76,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Build docker image
|
||||
run: docker build . --build-arg GITHUB_SHA=$GITHUB_SHA --tag registry.digitalocean.com/zed/collab:$GITHUB_SHA
|
||||
|
||||
35
.github/workflows/deploy_docs.yml
vendored
Normal file
35
.github/workflows/deploy_docs.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Deploy Docs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
deploy-docs:
|
||||
name: Deploy Docs
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Setup mdBook
|
||||
uses: peaceiris/actions-mdbook@v2
|
||||
with:
|
||||
mdbook-version: "0.4.37"
|
||||
|
||||
- name: Build book
|
||||
run: |
|
||||
set -euo pipefail
|
||||
mkdir -p target/deploy
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Deploy
|
||||
uses: cloudflare/wrangler-action@v3
|
||||
with:
|
||||
apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }}
|
||||
accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }}
|
||||
command: pages deploy target/deploy --project-name=docs
|
||||
1
.github/workflows/publish_extension_cli.yml
vendored
1
.github/workflows/publish_extension_cli.yml
vendored
@@ -19,6 +19,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@v2
|
||||
|
||||
1
.github/workflows/randomized_tests.yml
vendored
1
.github/workflows/randomized_tests.yml
vendored
@@ -31,6 +31,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Run randomized tests
|
||||
run: script/randomized-test-ci
|
||||
|
||||
18
.github/workflows/release_nightly.yml
vendored
18
.github/workflows/release_nightly.yml
vendored
@@ -25,6 +25,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Run style checks
|
||||
@@ -44,6 +45,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Run tests
|
||||
uses: ./.github/actions/run_tests
|
||||
@@ -73,6 +75,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Set release channel to nightly
|
||||
run: |
|
||||
@@ -93,9 +96,7 @@ jobs:
|
||||
bundle-deb:
|
||||
name: Create a Linux *.tar.gz bundle
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on:
|
||||
- self-hosted
|
||||
- deploy
|
||||
runs-on: ubuntu-22.04 # keep the version fixed to avoid libc and dynamic linked library issues
|
||||
needs: tests
|
||||
env:
|
||||
DIGITALOCEAN_SPACES_ACCESS_KEY: ${{ secrets.DIGITALOCEAN_SPACES_ACCESS_KEY }}
|
||||
@@ -106,9 +107,16 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
submodules: "recursive"
|
||||
|
||||
- name: Add Rust to the PATH
|
||||
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
|
||||
- name: Cache dependencies
|
||||
uses: swatinem/rust-cache@v2
|
||||
with:
|
||||
save-if: ${{ github.ref == 'refs/heads/main' }}
|
||||
|
||||
- name: Configure linux
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: script/linux
|
||||
|
||||
- name: Set release channel to nightly
|
||||
run: |
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -24,4 +24,3 @@ DerivedData/
|
||||
.venv
|
||||
.blob_store
|
||||
.vscode
|
||||
.wrangler
|
||||
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "crates/live_kit_server/protocol"]
|
||||
path = crates/live_kit_server/protocol
|
||||
url = https://github.com/livekit/protocol
|
||||
9
.mailmap
9
.mailmap
@@ -15,12 +15,8 @@ Christian Bergschneider <christian.bergschneider@gmx.de>
|
||||
Christian Bergschneider <christian.bergschneider@gmx.de> <magiclake@gmx.de>
|
||||
Conrad Irwin <conrad@zed.dev>
|
||||
Conrad Irwin <conrad@zed.dev> <conrad.irwin@gmail.com>
|
||||
Fernando Tagawa <tagawafernando@gmail.com>
|
||||
Fernando Tagawa <tagawafernando@gmail.com> <fernando.tagawa.gamail.com@gmail.com>
|
||||
Greg Morenz <greg-morenz@droid.cafe>
|
||||
Greg Morenz <greg-morenz@droid.cafe> <morenzg@gmail.com>
|
||||
Ivan Žužak <izuzak@gmail.com>
|
||||
Ivan Žužak <izuzak@gmail.com> <ivan.zuzak@github.com>
|
||||
Joseph T. Lyons <JosephTLyons@gmail.com>
|
||||
Joseph T. Lyons <JosephTLyons@gmail.com> <JosephTLyons@users.noreply.github.com>
|
||||
Julia <floc@unpromptedtirade.com>
|
||||
@@ -33,9 +29,6 @@ Kirill Bulatov <kirill@zed.dev>
|
||||
Kirill Bulatov <kirill@zed.dev> <mail4score@gmail.com>
|
||||
Kyle Caverly <kylebcaverly@gmail.com>
|
||||
Kyle Caverly <kylebcaverly@gmail.com> <kyle@zed.dev>
|
||||
LoganDark <contact@logandark.mozmail.com>
|
||||
LoganDark <contact@logandark.mozmail.com> <git@logandark.mozmail.com>
|
||||
LoganDark <contact@logandark.mozmail.com> <github@logandark.mozmail.com>
|
||||
Marshall Bowers <elliott.codes@gmail.com>
|
||||
Marshall Bowers <elliott.codes@gmail.com> <marshall@zed.dev>
|
||||
Max Brunsfeld <maxbrunsfeld@gmail.com>
|
||||
@@ -48,8 +41,6 @@ Nate Butler <iamnbutler@gmail.com> <nate@zed.dev>
|
||||
Nathan Sobo <nathan@zed.dev>
|
||||
Nathan Sobo <nathan@zed.dev> <nathan@warp.dev>
|
||||
Nathan Sobo <nathan@zed.dev> <nathansobo@gmail.com>
|
||||
Petros Amoiridis <petros@hey.com>
|
||||
Petros Amoiridis <petros@hey.com> <petros@zed.dev>
|
||||
Piotr Osiewicz <piotr@zed.dev>
|
||||
Piotr Osiewicz <piotr@zed.dev> <24362066+osiewicz@users.noreply.github.com>
|
||||
Robert Clover <git@clo4.net>
|
||||
|
||||
@@ -3,5 +3,10 @@
|
||||
"label": "clippy",
|
||||
"command": "cargo",
|
||||
"args": ["xtask", "clippy"]
|
||||
},
|
||||
{
|
||||
"label": "assistant2",
|
||||
"command": "cargo",
|
||||
"args": ["run", "-p", "assistant2", "--example", "assistant_example"]
|
||||
}
|
||||
]
|
||||
|
||||
745
Cargo.lock
generated
745
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
38
Cargo.toml
38
Cargo.toml
@@ -4,8 +4,8 @@ members = [
|
||||
"crates/anthropic",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/assistant2",
|
||||
"crates/assistant_tooling",
|
||||
"crates/assistant2",
|
||||
"crates/audio",
|
||||
"crates/auto_update",
|
||||
"crates/breadcrumbs",
|
||||
@@ -41,7 +41,6 @@ members = [
|
||||
"crates/gpui",
|
||||
"crates/gpui_macros",
|
||||
"crates/headless",
|
||||
"crates/http",
|
||||
"crates/image_viewer",
|
||||
"crates/inline_completion_button",
|
||||
"crates/install_cli",
|
||||
@@ -53,7 +52,6 @@ members = [
|
||||
"crates/live_kit_client",
|
||||
"crates/live_kit_server",
|
||||
"crates/lsp",
|
||||
"crates/markdown",
|
||||
"crates/markdown_preview",
|
||||
"crates/media",
|
||||
"crates/menu",
|
||||
@@ -128,7 +126,6 @@ members = [
|
||||
"extensions/php",
|
||||
"extensions/prisma",
|
||||
"extensions/purescript",
|
||||
"extensions/ruby",
|
||||
"extensions/svelte",
|
||||
"extensions/terraform",
|
||||
"extensions/toml",
|
||||
@@ -184,7 +181,6 @@ google_ai = { path = "crates/google_ai" }
|
||||
gpui = { path = "crates/gpui" }
|
||||
gpui_macros = { path = "crates/gpui_macros" }
|
||||
headless = { path = "crates/headless" }
|
||||
http = { path = "crates/http" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
image_viewer = { path = "crates/image_viewer" }
|
||||
inline_completion_button = { path = "crates/inline_completion_button" }
|
||||
@@ -196,7 +192,6 @@ languages = { path = "crates/languages" }
|
||||
live_kit_client = { path = "crates/live_kit_client" }
|
||||
live_kit_server = { path = "crates/live_kit_server" }
|
||||
lsp = { path = "crates/lsp" }
|
||||
markdown = { path = "crates/markdown" }
|
||||
markdown_preview = { path = "crates/markdown_preview" }
|
||||
media = { path = "crates/media" }
|
||||
menu = { path = "crates/menu" }
|
||||
@@ -230,7 +225,7 @@ snippet = { path = "crates/snippet" }
|
||||
sqlez = { path = "crates/sqlez" }
|
||||
sqlez_macros = { path = "crates/sqlez_macros" }
|
||||
supermaven = { path = "crates/supermaven" }
|
||||
supermaven_api = { path = "crates/supermaven_api" }
|
||||
supermaven_api = { path = "crates/supermaven_api"}
|
||||
story = { path = "crates/story" }
|
||||
storybook = { path = "crates/storybook" }
|
||||
sum_tree = { path = "crates/sum_tree" }
|
||||
@@ -260,24 +255,20 @@ async-fs = "1.6"
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.4.2"
|
||||
async-trait = "0.1"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
bitflags = "2.4.2"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e35b2d41f221a48b75f7cf2e78a81e7ecb7a383c" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "e35b2d41f221a48b75f7cf2e78a81e7ecb7a383c" }
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "f5766863de9dcc092e90fdbbc5e0007a99e7f9bf" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "f5766863de9dcc092e90fdbbc5e0007a99e7f9bf" }
|
||||
cap-std = "3.0"
|
||||
cargo_toml = "0.20"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
clickhouse = { version = "0.11.6" }
|
||||
ctor = "0.2.6"
|
||||
signal-hook = "0.3.17"
|
||||
ctrlc = "3.4.4"
|
||||
core-foundation = { version = "0.9.3" }
|
||||
core-foundation-sys = "0.8.6"
|
||||
derive_more = "0.99.17"
|
||||
emojis = "0.6.1"
|
||||
env_logger = "0.9"
|
||||
exec = "0.3.1"
|
||||
fork = "0.1.23"
|
||||
futures = "0.3"
|
||||
futures-batch = "0.6.1"
|
||||
futures-lite = "1.13"
|
||||
@@ -296,12 +287,10 @@ isahc = { version = "1.7.2", default-features = false, features = [
|
||||
] }
|
||||
itertools = "0.11.0"
|
||||
lazy_static = "1.4.0"
|
||||
libc = "0.2"
|
||||
linkify = "0.10.0"
|
||||
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
|
||||
nanoid = "0.4"
|
||||
nix = "0.28"
|
||||
once_cell = "1.19.0"
|
||||
ordered-float = "2.1.1"
|
||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||
parking_lot = "0.12.1"
|
||||
@@ -315,9 +304,8 @@ pulldown-cmark = { version = "0.10.0", default-features = false }
|
||||
rand = "0.8.5"
|
||||
refineable = { path = "./crates/refineable" }
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
|
||||
rust-embed = { version = "8.4", features = ["include-exclude"] }
|
||||
rust-embed = { version = "8.0", features = ["include-exclude"] }
|
||||
schemars = "0.8"
|
||||
semver = "1.0"
|
||||
serde = { version = "1.0", features = ["derive", "rc"] }
|
||||
@@ -330,7 +318,6 @@ serde_json_lenient = { version = "0.1", features = [
|
||||
serde_repr = "0.1"
|
||||
sha2 = "0.10"
|
||||
shellexpand = "2.1.0"
|
||||
shlex = "1.3.0"
|
||||
smallvec = { version = "1.6", features = ["union"] }
|
||||
smol = "1.2"
|
||||
strum = { version = "0.25.0", features = ["derive"] }
|
||||
@@ -338,7 +325,7 @@ subtle = "2.5.0"
|
||||
sysinfo = "0.30.7"
|
||||
tempfile = "3.9.0"
|
||||
thiserror = "1.0.29"
|
||||
tiktoken-rs = "0.5.9"
|
||||
tiktoken-rs = "0.5.7"
|
||||
time = { version = "0.3", features = [
|
||||
"macros",
|
||||
"parsing",
|
||||
@@ -356,7 +343,7 @@ tree-sitter-cpp = { git = "https://github.com/tree-sitter/tree-sitter-cpp", rev
|
||||
tree-sitter-css = { git = "https://github.com/tree-sitter/tree-sitter-css", rev = "769203d0f9abe1a9a691ac2b9fe4bb4397a73c51" }
|
||||
tree-sitter-elixir = { git = "https://github.com/elixir-lang/tree-sitter-elixir", rev = "a2861e88a730287a60c11ea9299c033c7d076e30" }
|
||||
tree-sitter-embedded-template = "0.20.0"
|
||||
tree-sitter-go = { git = "https://github.com/tree-sitter/tree-sitter-go", rev = "b82ab803d887002a0af11f6ce63d72884580bf33" }
|
||||
tree-sitter-go = { git = "https://github.com/tree-sitter/tree-sitter-go", rev = "aeb2f33b366fd78d5789ff104956ce23508b85db" }
|
||||
tree-sitter-gomod = { git = "https://github.com/camdencheek/tree-sitter-go-mod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/d1y/tree-sitter-go-work" }
|
||||
rustc-demangle = "0.1.23"
|
||||
@@ -376,7 +363,7 @@ unindent = "0.1.7"
|
||||
unicase = "2.6"
|
||||
unicode-segmentation = "1.10"
|
||||
url = "2.2"
|
||||
uuid = { version = "1.1.2", features = ["v4", "v5", "serde"] }
|
||||
uuid = { version = "1.1.2", features = ["v4", "v5"] }
|
||||
wasmparser = "0.201"
|
||||
wasm-encoder = "0.201"
|
||||
wasmtime = { version = "19.0.0", default-features = false, features = [
|
||||
@@ -392,22 +379,20 @@ wit-component = "0.201"
|
||||
sys-locale = "0.3.1"
|
||||
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.56.0"
|
||||
version = "0.53.0"
|
||||
features = [
|
||||
"implement",
|
||||
"Foundation_Numerics",
|
||||
"System",
|
||||
"System_Threading",
|
||||
"Wdk_System_SystemServices",
|
||||
"Win32_Globalization",
|
||||
"Win32_Graphics_Direct2D",
|
||||
"Win32_Graphics_Direct2D_Common",
|
||||
"Win32_Graphics_DirectWrite",
|
||||
"Win32_Graphics_Dwm",
|
||||
"Win32_Graphics_Dxgi_Common",
|
||||
"Win32_Graphics_Gdi",
|
||||
"Win32_Graphics_Imaging",
|
||||
"Win32_Graphics_Imaging_D2D",
|
||||
"Win32_Media",
|
||||
"Win32_Security",
|
||||
"Win32_Security_Credentials",
|
||||
"Win32_Storage_FileSystem",
|
||||
@@ -421,7 +406,6 @@ features = [
|
||||
"Win32_System_SystemServices",
|
||||
"Win32_System_Threading",
|
||||
"Win32_System_Time",
|
||||
"Win32_System_WinRT",
|
||||
"Win32_UI_Controls",
|
||||
"Win32_UI_HiDpi",
|
||||
"Win32_UI_Input_Ime",
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M13.15 7.49998C13.15 4.66458 10.9402 1.84998 7.50002 1.84998C4.7217 1.84998 3.34851 3.90636 2.76336 4.99997H4.5C4.77614 4.99997 5 5.22383 5 5.49997C5 5.77611 4.77614 5.99997 4.5 5.99997H1.5C1.22386 5.99997 1 5.77611 1 5.49997V2.49997C1 2.22383 1.22386 1.99997 1.5 1.99997C1.77614 1.99997 2 2.22383 2 2.49997V4.31318C2.70453 3.07126 4.33406 0.849976 7.50002 0.849976C11.5628 0.849976 14.15 4.18537 14.15 7.49998C14.15 10.8146 11.5628 14.15 7.50002 14.15C5.55618 14.15 3.93778 13.3808 2.78548 12.2084C2.16852 11.5806 1.68668 10.839 1.35816 10.0407C1.25306 9.78536 1.37488 9.49315 1.63024 9.38806C1.8856 9.28296 2.17781 9.40478 2.2829 9.66014C2.56374 10.3425 2.97495 10.9745 3.4987 11.5074C4.47052 12.4963 5.83496 13.15 7.50002 13.15C10.9402 13.15 13.15 10.3354 13.15 7.49998ZM7 10V5.00001H8V10H7Z" fill="currentColor" fill-rule="evenodd" clip-rule="evenodd"></path></svg>
|
||||
|
Before Width: | Height: | Size: 974 B |
@@ -57,9 +57,7 @@
|
||||
"gitkeep": "vcs",
|
||||
"gitmodules": "vcs",
|
||||
"go": "go",
|
||||
"gql": "graphql",
|
||||
"graphql": "graphql",
|
||||
"graphqls": "graphql",
|
||||
"h": "c",
|
||||
"hpp": "cpp",
|
||||
"handlebars": "code",
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M1.5 6C1.5 6.89002 1.76392 7.76004 2.25839 8.50007C2.75285 9.24009 3.45566 9.81686 4.27792 10.1575C5.10019 10.4981 6.00499 10.5872 6.87791 10.4135C7.75082 10.2399 8.55264 9.81132 9.18198 9.18198C9.81132 8.55264 10.2399 7.75082 10.4135 6.87791C10.5872 6.00499 10.4981 5.10019 10.1575 4.27792C9.81686 3.45566 9.24009 2.75285 8.50007 2.25839C7.76004 1.76392 6.89002 1.5 6 1.5C4.74198 1.50473 3.53448 1.99561 2.63 2.87L1.5 4" stroke="#919081" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M1.5 1.5V4H4" stroke="#919081" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6 3.5V6L8 7" stroke="#919081" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 778 B |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-library"><path d="m16 6 4 14"/><path d="M12 6v14"/><path d="M8 8v12"/><path d="M4 4v16"/></svg>
|
||||
|
Before Width: | Height: | Size: 298 B |
@@ -53,9 +53,7 @@
|
||||
// "alt-d": "editor::DeleteToNextWordEnd",
|
||||
"ctrl-x": "editor::Cut",
|
||||
"ctrl-c": "editor::Copy",
|
||||
"ctrl-insert": "editor::Copy",
|
||||
"ctrl-v": "editor::Paste",
|
||||
"shift-insert": "editor::Paste",
|
||||
"ctrl-z": "editor::Undo",
|
||||
"ctrl-shift-z": "editor::Redo",
|
||||
"up": "editor::MoveUp",
|
||||
@@ -191,12 +189,6 @@
|
||||
"ctrl-shift-enter": "editor::NewlineBelow"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Markdown",
|
||||
"bindings": {
|
||||
"ctrl-c": "markdown::Copy"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "AssistantPanel",
|
||||
"bindings": {
|
||||
@@ -501,12 +493,6 @@
|
||||
"tab": "editor::ConfirmCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && inline_completion && !showing_completions",
|
||||
"bindings": {
|
||||
"tab": "editor::AcceptInlineCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_code_actions",
|
||||
"bindings": {
|
||||
@@ -560,9 +546,7 @@
|
||||
"alt-ctrl-n": "project_panel::NewDirectory",
|
||||
"ctrl-x": "project_panel::Cut",
|
||||
"ctrl-c": "project_panel::Copy",
|
||||
"ctrl-insert": "project_panel::Copy",
|
||||
"ctrl-v": "project_panel::Paste",
|
||||
"shift-insert": "project_panel::Paste",
|
||||
"ctrl-alt-c": "project_panel::CopyPath",
|
||||
"alt-ctrl-shift-c": "project_panel::CopyRelativePath",
|
||||
"f2": "project_panel::Rename",
|
||||
@@ -624,9 +608,7 @@
|
||||
"bindings": {
|
||||
"ctrl-alt-space": "terminal::ShowCharacterPalette",
|
||||
"shift-ctrl-c": "terminal::Copy",
|
||||
"ctrl-insert": "terminal::Copy",
|
||||
"shift-ctrl-v": "terminal::Paste",
|
||||
"shift-insert": "terminal::Paste",
|
||||
"up": ["terminal::SendKeystroke", "up"],
|
||||
"pageup": ["terminal::SendKeystroke", "pageup"],
|
||||
"down": ["terminal::SendKeystroke", "down"],
|
||||
|
||||
@@ -19,6 +19,9 @@
|
||||
"cmd-escape": "menu::Cancel",
|
||||
"ctrl-escape": "menu::Cancel",
|
||||
"ctrl-c": "menu::Cancel",
|
||||
"shift-enter": "picker::UseSelectedQuery",
|
||||
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
|
||||
"cmd-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
|
||||
"cmd-shift-w": "workspace::CloseWindow",
|
||||
"shift-escape": "workspace::ToggleZoom",
|
||||
"cmd-o": "workspace::Open",
|
||||
@@ -208,9 +211,11 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Markdown",
|
||||
"context": "AssistantChat > Editor", // Used in the assistant2 crate
|
||||
"bindings": {
|
||||
"cmd-c": "markdown::Copy"
|
||||
"enter": ["assistant2::Submit", "Simple"],
|
||||
"cmd-enter": ["assistant2::Submit", "Codebase"],
|
||||
"escape": "assistant2::Cancel"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -515,12 +520,6 @@
|
||||
"tab": "editor::ConfirmCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && inline_completion && !showing_completions",
|
||||
"bindings": {
|
||||
"tab": "editor::AcceptInlineCompletion"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && showing_code_actions",
|
||||
"bindings": {
|
||||
@@ -577,6 +576,7 @@
|
||||
"cmd-v": "project_panel::Paste",
|
||||
"cmd-alt-c": "project_panel::CopyPath",
|
||||
"alt-cmd-shift-c": "project_panel::CopyRelativePath",
|
||||
"f2": "project_panel::Rename",
|
||||
"enter": "project_panel::Rename",
|
||||
"backspace": "project_panel::Trash",
|
||||
"delete": "project_panel::Trash",
|
||||
@@ -630,14 +630,6 @@
|
||||
"ctrl-backspace": "tab_switcher::CloseSelectedItem"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Picker",
|
||||
"bindings": {
|
||||
"alt-e": "picker::UseSelectedQuery",
|
||||
"alt-enter": ["picker::ConfirmInput", { "secondary": false }],
|
||||
"cmd-alt-enter": ["picker::ConfirmInput", { "secondary": true }]
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Terminal",
|
||||
"bindings": {
|
||||
|
||||
@@ -1,10 +1,4 @@
|
||||
[
|
||||
{
|
||||
"context": "ProjectPanel || Editor",
|
||||
"bindings": {
|
||||
"ctrl-6": "pane::AlternateFile"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && VimControl && !VimWaiting && !menu",
|
||||
"bindings": {
|
||||
@@ -123,9 +117,6 @@
|
||||
}
|
||||
}
|
||||
],
|
||||
"m": ["vim::PushOperator", "Mark"],
|
||||
"'": ["vim::PushOperator", { "Jump": { "line": true } }],
|
||||
"`": ["vim::PushOperator", { "Jump": { "line": false } }],
|
||||
";": "vim::RepeatFind",
|
||||
",": "vim::RepeatFindReversed",
|
||||
"ctrl-o": "pane::GoBack",
|
||||
@@ -246,9 +237,6 @@
|
||||
],
|
||||
"g ]": "editor::GoToDiagnostic",
|
||||
"g [": "editor::GoToPrevDiagnostic",
|
||||
"g i": ["workspace::SendKeystrokes", "` ^ i"],
|
||||
"g ,": "vim::ChangeListNewer",
|
||||
"g ;": "vim::ChangeListOlder",
|
||||
"shift-h": "vim::WindowTop",
|
||||
"shift-m": "vim::WindowMiddle",
|
||||
"shift-l": "vim::WindowBottom",
|
||||
@@ -493,8 +481,8 @@
|
||||
"shift-o": "vim::OtherEnd",
|
||||
"d": "vim::VisualDelete",
|
||||
"x": "vim::VisualDelete",
|
||||
"shift-d": "vim::VisualDeleteLine",
|
||||
"shift-x": "vim::VisualDeleteLine",
|
||||
"shift-d": "vim::VisualDelete",
|
||||
"shift-x": "vim::VisualDelete",
|
||||
"y": "vim::VisualYank",
|
||||
"shift-y": "vim::VisualYank",
|
||||
"p": "vim::Paste",
|
||||
|
||||
@@ -1,18 +1,5 @@
|
||||
{
|
||||
// The name of the Zed theme to use for the UI.
|
||||
//
|
||||
// The theme can also be set to follow system preferences:
|
||||
//
|
||||
// "theme": {
|
||||
// "mode": "system",
|
||||
// "light": "One Light",
|
||||
// "dark": "One Dark"
|
||||
// }
|
||||
//
|
||||
// Where `mode` is one of:
|
||||
// - "system": Use the theme that corresponds to the system's appearance
|
||||
// - "light": Use the theme indicated by the "light" field
|
||||
// - "dark": Use the theme indicated by the "dark" field
|
||||
// The name of the Zed theme to use for the UI
|
||||
"theme": "One Dark",
|
||||
// The name of a base set of key bindings to use.
|
||||
// This setting can take four values, each named after another
|
||||
@@ -84,28 +71,8 @@
|
||||
"restore_on_startup": "last_workspace",
|
||||
// Size of the drop target in the editor.
|
||||
"drop_target_size": 0.2,
|
||||
// Whether the window should be closed when using 'close active item' on a window with no tabs.
|
||||
// May take 3 values:
|
||||
// 1. Use the current platform's convention
|
||||
// "when_closing_with_no_tabs": "platform_default"
|
||||
// 2. Always close the window:
|
||||
// "when_closing_with_no_tabs": "close_window",
|
||||
// 3. Never close the window
|
||||
// "when_closing_with_no_tabs": "keep_window_open",
|
||||
"when_closing_with_no_tabs": "platform_default",
|
||||
// Whether the cursor blinks in the editor.
|
||||
"cursor_blink": true,
|
||||
// How to highlight the current line in the editor.
|
||||
//
|
||||
// 1. Don't highlight the current line:
|
||||
// "none"
|
||||
// 2. Highlight the gutter area:
|
||||
// "gutter"
|
||||
// 3. Highlight the editor area:
|
||||
// "line"
|
||||
// 4. Highlight the full line (default):
|
||||
// "all"
|
||||
"current_line_highlight": "all",
|
||||
// Whether to pop the completions menu while typing in an editor without
|
||||
// explicitly requesting it.
|
||||
"show_completions_on_input": true,
|
||||
@@ -316,14 +283,13 @@
|
||||
// AI provider.
|
||||
"provider": {
|
||||
"name": "openai",
|
||||
// The default model to use when creating new contexts. This
|
||||
// The default model to use when starting new conversations. This
|
||||
// setting can take three values:
|
||||
//
|
||||
// 1. "gpt-3.5-turbo"
|
||||
// 2. "gpt-4"
|
||||
// 3. "gpt-4-turbo-preview"
|
||||
// 4. "gpt-4o"
|
||||
"default_model": "gpt-4o"
|
||||
"default_model": "gpt-4-turbo-preview"
|
||||
}
|
||||
},
|
||||
// Whether the screen sharing icon is shown in the os status bar.
|
||||
@@ -333,7 +299,9 @@
|
||||
// The list of language servers to use (or disable) for all languages.
|
||||
//
|
||||
// This is typically customized on a per-language basis.
|
||||
"language_servers": ["..."],
|
||||
"language_servers": [
|
||||
"..."
|
||||
],
|
||||
// When to automatically save edited buffers. This setting can
|
||||
// take four values.
|
||||
//
|
||||
@@ -348,8 +316,6 @@
|
||||
"autosave": "off",
|
||||
// Settings related to the editor's tab bar.
|
||||
"tab_bar": {
|
||||
// Whether or not to show the tab bar in the editor
|
||||
"show": true,
|
||||
// Whether or not to show the navigation history buttons.
|
||||
"show_nav_history_buttons": true
|
||||
},
|
||||
@@ -381,8 +347,6 @@
|
||||
// when saving it.
|
||||
"ensure_final_newline_on_save": true,
|
||||
// Whether or not to perform a buffer format before saving
|
||||
//
|
||||
// Keep in mind, if the autosave with delay is enabled, format_on_save will be ignored
|
||||
"format_on_save": "on",
|
||||
// How to perform a buffer format. This setting can take 4 values:
|
||||
//
|
||||
@@ -468,7 +432,9 @@
|
||||
"copilot": {
|
||||
// The set of glob patterns for which copilot should be disabled
|
||||
// in any matching file.
|
||||
"disabled_globs": [".env"]
|
||||
"disabled_globs": [
|
||||
".env"
|
||||
]
|
||||
},
|
||||
// Settings specific to journaling
|
||||
"journal": {
|
||||
@@ -497,7 +463,7 @@
|
||||
// }
|
||||
// }
|
||||
"shell": "system",
|
||||
// Where to dock terminals panel. Can be `left`, `right`, `bottom`.
|
||||
// Where to dock terminals panel. Can be 'left', 'right', 'bottom'.
|
||||
"dock": "bottom",
|
||||
// Default width when the terminal is docked to the left or right.
|
||||
"default_width": 640,
|
||||
@@ -579,8 +545,13 @@
|
||||
// Default directories to search for virtual environments, relative
|
||||
// to the current working directory. We recommend overriding this
|
||||
// in your project's settings, rather than globally.
|
||||
"directories": [".env", "env", ".venv", "venv"],
|
||||
// Can also be `csh`, `fish`, and `nushell`
|
||||
"directories": [
|
||||
".env",
|
||||
"env",
|
||||
".venv",
|
||||
"venv"
|
||||
],
|
||||
// Can also be 'csh', 'fish', and `nushell`
|
||||
"activate_script": "default"
|
||||
}
|
||||
},
|
||||
@@ -605,7 +576,7 @@
|
||||
// use those languages.
|
||||
//
|
||||
// For example, to treat files like `foo.notjs` as JavaScript,
|
||||
// and `Embargo.lock` as TOML:
|
||||
// and 'Embargo.lock' as TOML:
|
||||
//
|
||||
// {
|
||||
// "JavaScript": ["notjs"],
|
||||
@@ -622,30 +593,19 @@
|
||||
},
|
||||
// Different settings for specific languages.
|
||||
"languages": {
|
||||
"Astro": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-astro"]
|
||||
}
|
||||
},
|
||||
"Blade": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"C++": {
|
||||
"format_on_save": "off"
|
||||
},
|
||||
"C": {
|
||||
"format_on_save": "off"
|
||||
},
|
||||
"C++": {
|
||||
"format_on_save": "off"
|
||||
},
|
||||
"CSS": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Elixir": {
|
||||
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
|
||||
"language_servers": [
|
||||
"elixir-ls",
|
||||
"!next-ls",
|
||||
"!lexical",
|
||||
"..."
|
||||
]
|
||||
},
|
||||
"Gleam": {
|
||||
"tab_size": 2
|
||||
@@ -655,120 +615,26 @@
|
||||
"source.organizeImports": true
|
||||
}
|
||||
},
|
||||
"GraphQL": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"HEEX": {
|
||||
"language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."]
|
||||
},
|
||||
"HTML": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Java": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-java"]
|
||||
}
|
||||
},
|
||||
"JavaScript": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"JSON": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
"language_servers": [
|
||||
"elixir-ls",
|
||||
"!next-ls",
|
||||
"!lexical",
|
||||
"..."
|
||||
]
|
||||
},
|
||||
"Make": {
|
||||
"hard_tabs": true
|
||||
},
|
||||
"Markdown": {
|
||||
"format_on_save": "off",
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"PHP": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-php"]
|
||||
}
|
||||
},
|
||||
"Prisma": {
|
||||
"tab_size": 2
|
||||
},
|
||||
"Ruby": {
|
||||
"language_servers": ["solargraph", "!ruby-lsp", "..."]
|
||||
},
|
||||
"SCSS": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"SQL": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-sql"]
|
||||
}
|
||||
},
|
||||
"Svelte": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["prettier-plugin-svelte"]
|
||||
}
|
||||
},
|
||||
"TSX": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Twig": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"TypeScript": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"Vue.js": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
},
|
||||
"XML": {
|
||||
"prettier": {
|
||||
"allowed": true,
|
||||
"plugins": ["@prettier/plugin-xml"]
|
||||
}
|
||||
},
|
||||
"YAML": {
|
||||
"prettier": {
|
||||
"allowed": true
|
||||
}
|
||||
}
|
||||
},
|
||||
// Zed's Prettier integration settings.
|
||||
// Allows to enable/disable formatting with Prettier
|
||||
// and configure default Prettier, used when no project-level Prettier installation is found.
|
||||
// If Prettier is enabled, Zed will use this for its Prettier instance for any applicable file, if
|
||||
// project has no other Prettier installed.
|
||||
"prettier": {
|
||||
// // Whether to consider prettier formatter or not when attempting to format a file.
|
||||
// "allowed": false,
|
||||
//
|
||||
// // Use regular Prettier json configuration.
|
||||
// // If Prettier is allowed, Zed will use this for its Prettier instance for any applicable file, if
|
||||
// // the project has no other Prettier installed.
|
||||
// "plugins": [],
|
||||
//
|
||||
// // Use regular Prettier json configuration.
|
||||
// // If Prettier is allowed, Zed will use this for its Prettier instance for any applicable file, if
|
||||
// // the project has no other Prettier installed.
|
||||
// Use regular Prettier json configuration:
|
||||
// "trailingComma": "es5",
|
||||
// "tabWidth": 4,
|
||||
// "semi": false,
|
||||
@@ -826,17 +692,5 @@
|
||||
// - `short`: "2 s, 15 l, 32 c"
|
||||
// - `long`: "2 selections, 15 lines, 32 characters"
|
||||
// Default: long
|
||||
"line_indicator_format": "long",
|
||||
// Set a proxy to use. The proxy protocol is specified by the URI scheme.
|
||||
//
|
||||
// Supported URI scheme: `http`, `https`, `socks4`, `socks4a`, `socks5`,
|
||||
// `socks5h`. `http` will be used when no scheme is specified.
|
||||
//
|
||||
// By default no proxy will be used, or Zed will try get proxy settings from
|
||||
// environment variables.
|
||||
//
|
||||
// Examples:
|
||||
// - "proxy": "socks5://localhost:10808"
|
||||
// - "proxy": "http://127.0.0.1:10809"
|
||||
"proxy": null
|
||||
"line_indicator_format": "long"
|
||||
}
|
||||
|
||||
@@ -281,14 +281,11 @@ impl ActivityIndicator {
|
||||
message: "Installing Zed update…".to_string(),
|
||||
on_click: None,
|
||||
},
|
||||
AutoUpdateStatus::Updated { binary_path } => Content {
|
||||
AutoUpdateStatus::Updated => Content {
|
||||
icon: None,
|
||||
message: "Click to restart and update Zed".to_string(),
|
||||
on_click: Some(Arc::new({
|
||||
let restart = workspace::Restart {
|
||||
binary_path: Some(binary_path.clone()),
|
||||
};
|
||||
move |_, cx| workspace::restart(&restart, cx)
|
||||
on_click: Some(Arc::new(|_, cx| {
|
||||
workspace::restart(&Default::default(), cx)
|
||||
})),
|
||||
},
|
||||
AutoUpdateStatus::Errored => Content {
|
||||
|
||||
@@ -5,10 +5,6 @@ edition = "2021"
|
||||
publish = false
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -18,11 +14,9 @@ path = "src/anthropic.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http.workspace = true
|
||||
isahc.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio.workspace = true
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
use std::{convert::TryFrom, sync::Arc};
|
||||
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
|
||||
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
|
||||
#[serde(rename = "claude-3-opus-20240229")]
|
||||
Claude3Opus,
|
||||
#[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
|
||||
#[serde(rename = "claude-3-sonnet-20240229")]
|
||||
Claude3Sonnet,
|
||||
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
|
||||
#[serde(rename = "claude-3-haiku-20240307")]
|
||||
Claude3Haiku,
|
||||
}
|
||||
|
||||
@@ -32,14 +28,6 @@ impl Model {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
Model::Claude3Opus => "claude-3-opus-20240229",
|
||||
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
|
||||
Model::Claude3Haiku => "claude-3-opus-20240307",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
@@ -153,24 +141,20 @@ pub enum TextDelta {
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
client: Arc<dyn HttpClient>,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
|
||||
let uri = format!("{api_url}/v1/messages");
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Anthropic-Version", "2023-06-01")
|
||||
.header("Anthropic-Beta", "tools-2024-04-04")
|
||||
.header("Anthropic-Beta", "messages-2023-12-15")
|
||||
.header("X-Api-Key", api_key)
|
||||
.header("Content-Type", "application/json");
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
}
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
.header("Content-Type", "application/json")
|
||||
.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
@@ -212,7 +196,7 @@ pub async fn stream_completion(
|
||||
// #[cfg(test)]
|
||||
// mod tests {
|
||||
// use super::*;
|
||||
// use http::IsahcHttpClient;
|
||||
// use util::http::IsahcHttpClient;
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn stream_completion_success() {
|
||||
|
||||
@@ -11,8 +11,6 @@ doctest = false
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
cargo_toml.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
@@ -21,9 +19,7 @@ editor.workspace = true
|
||||
file_icons.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
http.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
@@ -34,24 +30,19 @@ ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
search.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim = "0.11"
|
||||
telemetry_events.workspace = true
|
||||
theme.workspace = true
|
||||
tiktoken-rs.workspace = true
|
||||
toml.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
picker.workspace = true
|
||||
gray_matter = "0.2.7"
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
@@ -60,4 +51,3 @@ env_logger.workspace = true
|
||||
log.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
unindent.workspace = true
|
||||
|
||||
3
crates/assistant/features.zmd
Normal file
3
crates/assistant/features.zmd
Normal file
@@ -0,0 +1,3 @@
|
||||
Push content to a deeper layer.
|
||||
A context can have multiple sublayers.
|
||||
You can enable or disable arbitrary sublayers at arbitrary nesting depths when viewing the document.
|
||||
@@ -1,30 +0,0 @@
|
||||
mod current_project;
|
||||
mod recent_buffers;
|
||||
|
||||
pub use current_project::*;
|
||||
pub use recent_buffers::*;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AmbientContext {
|
||||
pub recent_buffers: RecentBuffersContext,
|
||||
pub current_project: CurrentProjectContext,
|
||||
}
|
||||
|
||||
impl AmbientContext {
|
||||
pub fn snapshot(&self) -> AmbientContextSnapshot {
|
||||
AmbientContextSnapshot {
|
||||
recent_buffers: self.recent_buffers.snapshot.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct AmbientContextSnapshot {
|
||||
pub recent_buffers: RecentBuffersSnapshot,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
|
||||
pub enum ContextUpdated {
|
||||
Updating,
|
||||
Disabled,
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
use std::fmt::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use fs::Fs;
|
||||
use gpui::{AsyncAppContext, ModelContext, Task, WeakModel};
|
||||
use project::{Project, ProjectPath};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::ambient_context::ContextUpdated;
|
||||
use crate::assistant_panel::Conversation;
|
||||
use crate::{LanguageModelRequestMessage, Role};
|
||||
|
||||
/// Ambient context about the current project.
|
||||
pub struct CurrentProjectContext {
|
||||
pub enabled: bool,
|
||||
pub message: String,
|
||||
pub pending_message: Option<Task<()>>,
|
||||
}
|
||||
|
||||
#[allow(clippy::derivable_impls)]
|
||||
impl Default for CurrentProjectContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
message: String::new(),
|
||||
pending_message: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CurrentProjectContext {
|
||||
/// Returns the [`CurrentProjectContext`] as a message to the language model.
|
||||
pub fn to_message(&self) -> Option<LanguageModelRequestMessage> {
|
||||
self.enabled
|
||||
.then(|| LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: self.message.clone(),
|
||||
})
|
||||
.filter(|message| !message.content.is_empty())
|
||||
}
|
||||
|
||||
/// Updates the [`CurrentProjectContext`] for the given [`Project`].
|
||||
pub fn update(
|
||||
&mut self,
|
||||
fs: Arc<dyn Fs>,
|
||||
project: WeakModel<Project>,
|
||||
cx: &mut ModelContext<Conversation>,
|
||||
) -> ContextUpdated {
|
||||
if !self.enabled {
|
||||
self.message.clear();
|
||||
self.pending_message = None;
|
||||
cx.notify();
|
||||
return ContextUpdated::Disabled;
|
||||
}
|
||||
|
||||
self.pending_message = Some(cx.spawn(|conversation, mut cx| async move {
|
||||
const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||
|
||||
let Some(path_to_cargo_toml) = Self::path_to_cargo_toml(project, &mut cx).log_err()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(path_to_cargo_toml) = path_to_cargo_toml
|
||||
.ok_or_else(|| anyhow!("no Cargo.toml"))
|
||||
.log_err()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let message_task = cx
|
||||
.background_executor()
|
||||
.spawn(async move { Self::build_message(fs, &path_to_cargo_toml).await });
|
||||
|
||||
if let Some(message) = message_task.await.log_err() {
|
||||
conversation
|
||||
.update(&mut cx, |conversation, cx| {
|
||||
conversation.ambient_context.current_project.message = message;
|
||||
conversation.count_remaining_tokens(cx);
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
|
||||
ContextUpdated::Updating
|
||||
}
|
||||
|
||||
async fn build_message(fs: Arc<dyn Fs>, path_to_cargo_toml: &Path) -> Result<String> {
|
||||
let buffer = fs.load(path_to_cargo_toml).await?;
|
||||
let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?;
|
||||
|
||||
let mut message = String::new();
|
||||
writeln!(message, "You are in a Rust project.")?;
|
||||
|
||||
if let Some(workspace) = cargo_toml.workspace {
|
||||
writeln!(
|
||||
message,
|
||||
"The project is a Cargo workspace with the following members:"
|
||||
)?;
|
||||
for member in workspace.members {
|
||||
writeln!(message, "- {member}")?;
|
||||
}
|
||||
|
||||
if !workspace.default_members.is_empty() {
|
||||
writeln!(message, "The default members are:")?;
|
||||
for member in workspace.default_members {
|
||||
writeln!(message, "- {member}")?;
|
||||
}
|
||||
}
|
||||
|
||||
if !workspace.dependencies.is_empty() {
|
||||
writeln!(
|
||||
message,
|
||||
"The following workspace dependencies are installed:"
|
||||
)?;
|
||||
for dependency in workspace.dependencies.keys() {
|
||||
writeln!(message, "- {dependency}")?;
|
||||
}
|
||||
}
|
||||
} else if let Some(package) = cargo_toml.package {
|
||||
writeln!(
|
||||
message,
|
||||
"The project name is \"{name}\".",
|
||||
name = package.name
|
||||
)?;
|
||||
|
||||
let description = package
|
||||
.description
|
||||
.as_ref()
|
||||
.and_then(|description| description.get().ok().cloned());
|
||||
if let Some(description) = description.as_ref() {
|
||||
writeln!(message, "It describes itself as \"{description}\".")?;
|
||||
}
|
||||
|
||||
if !cargo_toml.dependencies.is_empty() {
|
||||
writeln!(message, "The following dependencies are installed:")?;
|
||||
for dependency in cargo_toml.dependencies.keys() {
|
||||
writeln!(message, "- {dependency}")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
fn path_to_cargo_toml(
|
||||
project: WeakModel<Project>,
|
||||
cx: &mut AsyncAppContext,
|
||||
) -> Result<Option<PathBuf>> {
|
||||
cx.update(|cx| {
|
||||
let worktree = project.update(cx, |project, _cx| {
|
||||
project
|
||||
.worktrees()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow!("no worktree"))
|
||||
})??;
|
||||
|
||||
let path_to_cargo_toml = worktree.update(cx, |worktree, _cx| {
|
||||
let cargo_toml = worktree.entry_for_path("Cargo.toml")?;
|
||||
Some(ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: cargo_toml.path.clone(),
|
||||
})
|
||||
});
|
||||
let path_to_cargo_toml = path_to_cargo_toml.and_then(|path| {
|
||||
project
|
||||
.update(cx, |project, cx| project.absolute_path(&path, cx))
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
Ok(path_to_cargo_toml)
|
||||
})?
|
||||
}
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
use crate::{assistant_panel::Conversation, LanguageModelRequestMessage, Role};
|
||||
use gpui::{ModelContext, Subscription, Task, WeakModel};
|
||||
use language::{Buffer, BufferSnapshot, Rope};
|
||||
use std::{fmt::Write, path::PathBuf, time::Duration};
|
||||
|
||||
use super::ContextUpdated;
|
||||
|
||||
pub struct RecentBuffersContext {
|
||||
pub enabled: bool,
|
||||
pub buffers: Vec<RecentBuffer>,
|
||||
pub snapshot: RecentBuffersSnapshot,
|
||||
pub pending_message: Option<Task<()>>,
|
||||
}
|
||||
|
||||
pub struct RecentBuffer {
|
||||
pub buffer: WeakModel<Buffer>,
|
||||
pub _subscription: Subscription,
|
||||
}
|
||||
|
||||
impl Default for RecentBuffersContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
buffers: Vec::new(),
|
||||
snapshot: RecentBuffersSnapshot::default(),
|
||||
pending_message: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecentBuffersContext {
|
||||
pub fn update(&mut self, cx: &mut ModelContext<Conversation>) -> ContextUpdated {
|
||||
let source_buffers = self
|
||||
.buffers
|
||||
.iter()
|
||||
.filter_map(|recent| {
|
||||
let (full_path, snapshot) = recent
|
||||
.buffer
|
||||
.read_with(cx, |buffer, cx| {
|
||||
(
|
||||
buffer.file().map(|file| file.full_path(cx)),
|
||||
buffer.snapshot(),
|
||||
)
|
||||
})
|
||||
.ok()?;
|
||||
Some(SourceBufferSnapshot {
|
||||
full_path,
|
||||
model: recent.buffer.clone(),
|
||||
snapshot,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !self.enabled || source_buffers.is_empty() {
|
||||
self.snapshot.message = Default::default();
|
||||
self.snapshot.source_buffers.clear();
|
||||
self.pending_message = None;
|
||||
cx.notify();
|
||||
ContextUpdated::Disabled
|
||||
} else {
|
||||
self.pending_message = Some(cx.spawn(|this, mut cx| async move {
|
||||
const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
|
||||
|
||||
let message = if source_buffers.is_empty() {
|
||||
Rope::new()
|
||||
} else {
|
||||
cx.background_executor()
|
||||
.spawn({
|
||||
let source_buffers = source_buffers.clone();
|
||||
async move { message_for_recent_buffers(source_buffers) }
|
||||
})
|
||||
.await
|
||||
};
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.ambient_context.recent_buffers.snapshot.source_buffers = source_buffers;
|
||||
this.ambient_context.recent_buffers.snapshot.message = message;
|
||||
this.count_remaining_tokens(cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
}));
|
||||
|
||||
ContextUpdated::Updating
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`RecentBuffersContext`] as a message to the language model.
|
||||
pub fn to_message(&self) -> Option<LanguageModelRequestMessage> {
|
||||
self.enabled
|
||||
.then(|| LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: self.snapshot.message.to_string(),
|
||||
})
|
||||
.filter(|message| !message.content.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct RecentBuffersSnapshot {
|
||||
pub message: Rope,
|
||||
pub source_buffers: Vec<SourceBufferSnapshot>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SourceBufferSnapshot {
|
||||
pub full_path: Option<PathBuf>,
|
||||
pub model: WeakModel<Buffer>,
|
||||
pub snapshot: BufferSnapshot,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SourceBufferSnapshot {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SourceBufferSnapshot")
|
||||
.field("full_path", &self.full_path)
|
||||
.field("model (entity id)", &self.model.entity_id())
|
||||
.field("snapshot (text)", &self.snapshot.text())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn message_for_recent_buffers(buffers: Vec<SourceBufferSnapshot>) -> Rope {
|
||||
let mut message = String::new();
|
||||
writeln!(
|
||||
message,
|
||||
"The following is a list of recent buffers that the user has opened."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for buffer in buffers {
|
||||
if let Some(path) = buffer.full_path {
|
||||
writeln!(message, "```{}", path.display()).unwrap();
|
||||
} else {
|
||||
writeln!(message, "```untitled").unwrap();
|
||||
}
|
||||
|
||||
for chunk in buffer.snapshot.chunks(0..buffer.snapshot.len(), false) {
|
||||
message.push_str(chunk.text);
|
||||
}
|
||||
if !message.ends_with('\n') {
|
||||
message.push('\n');
|
||||
}
|
||||
message.push_str("```\n");
|
||||
}
|
||||
|
||||
Rope::from(message.as_str())
|
||||
}
|
||||
@@ -1,23 +1,20 @@
|
||||
mod ambient_context;
|
||||
pub mod assistant_panel;
|
||||
pub mod assistant_settings;
|
||||
mod codegen;
|
||||
mod completion_provider;
|
||||
mod omit_ranges;
|
||||
mod prompts;
|
||||
mod saved_conversation;
|
||||
mod search;
|
||||
mod slash_command;
|
||||
mod streaming_diff;
|
||||
|
||||
use ambient_context::AmbientContextSnapshot;
|
||||
mod embedded_scope;
|
||||
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
|
||||
use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
|
||||
use chrono::{DateTime, Local};
|
||||
use client::{proto, Client};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub(crate) use completion_provider::*;
|
||||
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
|
||||
pub(crate) use prompts::prompt_library::*;
|
||||
use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
|
||||
pub(crate) use saved_conversation::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
@@ -29,6 +26,7 @@ use std::{
|
||||
actions!(
|
||||
assistant,
|
||||
[
|
||||
NewConversation,
|
||||
Assist,
|
||||
Split,
|
||||
CycleMessageRole,
|
||||
@@ -36,10 +34,7 @@ actions!(
|
||||
ToggleFocus,
|
||||
ResetKey,
|
||||
InlineAssist,
|
||||
InsertActivePrompt,
|
||||
ToggleIncludeConversation,
|
||||
ToggleHistory,
|
||||
ApplyEdit
|
||||
]
|
||||
);
|
||||
|
||||
@@ -80,7 +75,6 @@ impl Display for Role {
|
||||
pub enum LanguageModel {
|
||||
ZedDotDev(ZedDotDevModel),
|
||||
OpenAi(OpenAiModel),
|
||||
Anthropic(AnthropicModel),
|
||||
}
|
||||
|
||||
impl Default for LanguageModel {
|
||||
@@ -93,23 +87,20 @@ impl LanguageModel {
|
||||
pub fn telemetry_id(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
|
||||
LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.display_name().into(),
|
||||
LanguageModel::Anthropic(model) => model.display_name().into(),
|
||||
LanguageModel::ZedDotDev(model) => model.display_name().into(),
|
||||
LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
|
||||
LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.max_token_count(),
|
||||
LanguageModel::Anthropic(model) => model.max_token_count(),
|
||||
LanguageModel::ZedDotDev(model) => model.max_token_count(),
|
||||
}
|
||||
}
|
||||
@@ -117,7 +108,6 @@ impl LanguageModel {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
LanguageModel::OpenAi(model) => model.id(),
|
||||
LanguageModel::Anthropic(model) => model.id(),
|
||||
LanguageModel::ZedDotDev(model) => model.id(),
|
||||
}
|
||||
}
|
||||
@@ -188,10 +178,8 @@ pub struct LanguageModelChoiceDelta {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct MessageMetadata {
|
||||
role: Role,
|
||||
sent_at: DateTime<Local>,
|
||||
status: MessageStatus,
|
||||
// todo!("delete this")
|
||||
#[serde(skip)]
|
||||
ambient_context: AmbientContextSnapshot,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
@@ -243,13 +231,13 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(Assistant::NAMESPACE);
|
||||
});
|
||||
Assistant::update_global(cx, |assistant, cx| {
|
||||
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
});
|
||||
cx.observe_global::<SettingsStore>(|cx| {
|
||||
Assistant::update_global(cx, |assistant, cx| {
|
||||
cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
||||
use std::fmt;
|
||||
|
||||
pub use anthropic::Model as AnthropicModel;
|
||||
use gpui::Pixels;
|
||||
pub use open_ai::Model as OpenAiModel;
|
||||
use schemars::{
|
||||
@@ -17,9 +16,8 @@ use settings::{Settings, SettingsSources};
|
||||
pub enum ZedDotDevModel {
|
||||
Gpt3Point5Turbo,
|
||||
Gpt4,
|
||||
Gpt4Turbo,
|
||||
#[default]
|
||||
Gpt4Omni,
|
||||
Gpt4Turbo,
|
||||
Claude3Opus,
|
||||
Claude3Sonnet,
|
||||
Claude3Haiku,
|
||||
@@ -57,7 +55,6 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
|
||||
"gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
|
||||
"gpt-4" => Ok(ZedDotDevModel::Gpt4),
|
||||
"gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
|
||||
"gpt-4o" => Ok(ZedDotDevModel::Gpt4Omni),
|
||||
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
|
||||
}
|
||||
}
|
||||
@@ -77,7 +74,6 @@ impl JsonSchema for ZedDotDevModel {
|
||||
"gpt-3.5-turbo".to_owned(),
|
||||
"gpt-4".to_owned(),
|
||||
"gpt-4-turbo-preview".to_owned(),
|
||||
"gpt-4o".to_owned(),
|
||||
];
|
||||
Schema::Object(SchemaObject {
|
||||
instance_type: Some(InstanceType::String.into()),
|
||||
@@ -104,7 +100,6 @@ impl ZedDotDevModel {
|
||||
Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4Turbo => "gpt-4-turbo-preview",
|
||||
Self::Gpt4Omni => "gpt-4o",
|
||||
Self::Claude3Opus => "claude-3-opus",
|
||||
Self::Claude3Sonnet => "claude-3-sonnet",
|
||||
Self::Claude3Haiku => "claude-3-haiku",
|
||||
@@ -117,7 +112,6 @@ impl ZedDotDevModel {
|
||||
Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
|
||||
Self::Gpt4 => "GPT 4",
|
||||
Self::Gpt4Turbo => "GPT 4 Turbo",
|
||||
Self::Gpt4Omni => "GPT 4 Omni",
|
||||
Self::Claude3Opus => "Claude 3 Opus",
|
||||
Self::Claude3Sonnet => "Claude 3 Sonnet",
|
||||
Self::Claude3Haiku => "Claude 3 Haiku",
|
||||
@@ -129,7 +123,7 @@ impl ZedDotDevModel {
|
||||
match self {
|
||||
Self::Gpt3Point5Turbo => 2048,
|
||||
Self::Gpt4 => 4096,
|
||||
Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
|
||||
Self::Gpt4Turbo => 128000,
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
|
||||
Self::Custom(_) => 4096, // TODO: Make this configurable
|
||||
}
|
||||
@@ -159,17 +153,6 @@ pub enum AssistantProvider {
|
||||
default_model: OpenAiModel,
|
||||
#[serde(default = "open_ai_url")]
|
||||
api_url: String,
|
||||
#[serde(default)]
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
},
|
||||
#[serde(rename = "anthropic")]
|
||||
Anthropic {
|
||||
#[serde(default)]
|
||||
default_model: AnthropicModel,
|
||||
#[serde(default = "anthropic_api_url")]
|
||||
api_url: String,
|
||||
#[serde(default)]
|
||||
low_speed_timeout_in_seconds: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -182,11 +165,7 @@ impl Default for AssistantProvider {
|
||||
}
|
||||
|
||||
fn open_ai_url() -> String {
|
||||
open_ai::OPEN_AI_API_URL.to_string()
|
||||
}
|
||||
|
||||
fn anthropic_api_url() -> String {
|
||||
anthropic::ANTHROPIC_API_URL.to_string()
|
||||
"https://api.openai.com/v1".into()
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Serialize)]
|
||||
@@ -243,14 +222,12 @@ impl AssistantSettingsContent {
|
||||
Some(AssistantProvider::OpenAi {
|
||||
default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
|
||||
api_url: open_ai_api_url.clone(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
})
|
||||
} else {
|
||||
settings.default_open_ai_model.clone().map(|open_ai_model| {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: open_ai_model,
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
}
|
||||
})
|
||||
},
|
||||
@@ -339,11 +316,11 @@ pub struct LegacyAssistantSettingsContent {
|
||||
///
|
||||
/// Default: 320
|
||||
pub default_height: Option<f32>,
|
||||
/// The default OpenAI model to use when creating new contexts.
|
||||
/// The default OpenAI model to use when starting new conversations.
|
||||
///
|
||||
/// Default: gpt-4-1106-preview
|
||||
pub default_open_ai_model: Option<OpenAiModel>,
|
||||
/// OpenAI API base URL to use when creating new contexts.
|
||||
/// OpenAI API base URL to use when starting new conversations.
|
||||
///
|
||||
/// Default: https://api.openai.com/v1
|
||||
pub openai_api_url: Option<String>,
|
||||
@@ -387,17 +364,14 @@ impl Settings for AssistantSettings {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: default_model_override,
|
||||
api_url: api_url_override,
|
||||
low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
|
||||
},
|
||||
) => {
|
||||
*default_model = default_model_override;
|
||||
*api_url = api_url_override;
|
||||
*low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
|
||||
}
|
||||
(merged, provider_override) => {
|
||||
*merged = provider_override;
|
||||
@@ -418,7 +392,7 @@ fn merge<T: Copy>(target: &mut T, value: Option<T>) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::{AppContext, UpdateGlobal};
|
||||
use gpui::{AppContext, BorrowAppContext};
|
||||
use settings::SettingsStore;
|
||||
|
||||
use super::*;
|
||||
@@ -433,14 +407,13 @@ mod tests {
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourOmni,
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// Ensure backward-compatibility.
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
@@ -455,12 +428,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::FourOmni,
|
||||
api_url: "test-url".into(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
default_model: OpenAiModel::FourTurbo,
|
||||
api_url: "test-url".into()
|
||||
}
|
||||
);
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
@@ -476,13 +448,12 @@ mod tests {
|
||||
AssistantSettings::get_global(cx).provider,
|
||||
AssistantProvider::OpenAi {
|
||||
default_model: OpenAiModel::Four,
|
||||
api_url: open_ai_url(),
|
||||
low_speed_timeout_in_seconds: None,
|
||||
api_url: open_ai_url()
|
||||
}
|
||||
);
|
||||
|
||||
// The new version supports setting a custom model when using zed.dev.
|
||||
SettingsStore::update_global(cx, |store, cx| {
|
||||
cx.update_global::<SettingsStore, _>(|store, cx| {
|
||||
store
|
||||
.set_user_settings(
|
||||
r#"{
|
||||
|
||||
@@ -3,13 +3,11 @@ use crate::{
|
||||
CompletionProvider, LanguageModelRequest,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use client::telemetry::Telemetry;
|
||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{EventEmitter, Model, ModelContext, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
|
||||
use std::{cmp, future, ops::Range};
|
||||
|
||||
pub enum Event {
|
||||
Finished,
|
||||
@@ -31,19 +29,13 @@ pub struct Codegen {
|
||||
error: Option<anyhow::Error>,
|
||||
generation: Task<()>,
|
||||
idle: bool,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for Codegen {}
|
||||
|
||||
impl Codegen {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
kind: CodegenKind,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
Self {
|
||||
buffer: buffer.clone(),
|
||||
@@ -54,7 +46,6 @@ impl Codegen {
|
||||
error: Default::default(),
|
||||
idle: true,
|
||||
generation: Task::ready(()),
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
}
|
||||
}
|
||||
@@ -109,11 +100,9 @@ impl Codegen {
|
||||
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
|
||||
.into_values()
|
||||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let model_telemetry_id = prompt.model.telemetry_id();
|
||||
let response = CompletionProvider::global(cx).complete(prompt);
|
||||
let telemetry = self.telemetry.clone();
|
||||
self.generation = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
let generate = async {
|
||||
@@ -121,89 +110,68 @@ impl Codegen {
|
||||
|
||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||
let diff = cx.background_executor().spawn(async move {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = strip_invalid_spans_from_codeblock(response.await?);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let chunks = strip_invalid_spans_from_codeblock(response.await?);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
|
||||
let mut new_text = String::new();
|
||||
let mut base_indent = None;
|
||||
let mut line_indent = None;
|
||||
let mut first_line = true;
|
||||
let mut new_text = String::new();
|
||||
let mut base_indent = None;
|
||||
let mut line_indent = None;
|
||||
let mut first_line = true;
|
||||
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
if response_latency.is_none() {
|
||||
response_latency = Some(request_start.elapsed());
|
||||
}
|
||||
let chunk = chunk?;
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
|
||||
let mut lines = chunk.split('\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
new_text.push_str(line);
|
||||
if line_indent.is_none() {
|
||||
if let Some(non_whitespace_ch_ix) =
|
||||
new_text.find(|ch: char| !ch.is_whitespace())
|
||||
{
|
||||
line_indent = Some(non_whitespace_ch_ix);
|
||||
base_indent = base_indent.or(line_indent);
|
||||
let mut lines = chunk.split('\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
new_text.push_str(line);
|
||||
if line_indent.is_none() {
|
||||
if let Some(non_whitespace_ch_ix) =
|
||||
new_text.find(|ch: char| !ch.is_whitespace())
|
||||
{
|
||||
line_indent = Some(non_whitespace_ch_ix);
|
||||
base_indent = base_indent.or(line_indent);
|
||||
|
||||
let line_indent = line_indent.unwrap();
|
||||
let base_indent = base_indent.unwrap();
|
||||
let indent_delta =
|
||||
line_indent as i32 - base_indent as i32;
|
||||
let mut corrected_indent_len = cmp::max(
|
||||
0,
|
||||
suggested_line_indent.len as i32 + indent_delta,
|
||||
)
|
||||
as usize;
|
||||
if first_line {
|
||||
corrected_indent_len = corrected_indent_len
|
||||
.saturating_sub(
|
||||
selection_start.column as usize,
|
||||
);
|
||||
}
|
||||
|
||||
let indent_char = suggested_line_indent.char();
|
||||
let mut indent_buffer = [0; 4];
|
||||
let indent_str =
|
||||
indent_char.encode_utf8(&mut indent_buffer);
|
||||
new_text.replace_range(
|
||||
..line_indent,
|
||||
&indent_str.repeat(corrected_indent_len),
|
||||
);
|
||||
let line_indent = line_indent.unwrap();
|
||||
let base_indent = base_indent.unwrap();
|
||||
let indent_delta = line_indent as i32 - base_indent as i32;
|
||||
let mut corrected_indent_len = cmp::max(
|
||||
0,
|
||||
suggested_line_indent.len as i32 + indent_delta,
|
||||
)
|
||||
as usize;
|
||||
if first_line {
|
||||
corrected_indent_len = corrected_indent_len
|
||||
.saturating_sub(selection_start.column as usize);
|
||||
}
|
||||
}
|
||||
|
||||
if line_indent.is_some() {
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
new_text.clear();
|
||||
let indent_char = suggested_line_indent.char();
|
||||
let mut indent_buffer = [0; 4];
|
||||
let indent_str =
|
||||
indent_char.encode_utf8(&mut indent_buffer);
|
||||
new_text.replace_range(
|
||||
..line_indent,
|
||||
&indent_str.repeat(corrected_indent_len),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if lines.peek().is_some() {
|
||||
hunks_tx.send(diff.push_new("\n")).await?;
|
||||
line_indent = None;
|
||||
first_line = false;
|
||||
}
|
||||
if line_indent.is_some() {
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
new_text.clear();
|
||||
}
|
||||
|
||||
if lines.peek().is_some() {
|
||||
hunks_tx.send(diff.push_new("\n")).await?;
|
||||
line_indent = None;
|
||||
first_line = false;
|
||||
}
|
||||
}
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
hunks_tx.send(diff.finish()).await?;
|
||||
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
let error_message = diff.await.err().map(|error| error.to_string());
|
||||
if let Some(telemetry) = telemetry {
|
||||
telemetry.report_assistant_event(
|
||||
None,
|
||||
telemetry_events::AssistantKind::Inline,
|
||||
model_telemetry_id,
|
||||
response_latency,
|
||||
error_message,
|
||||
);
|
||||
}
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
hunks_tx.send(diff.finish()).await?;
|
||||
|
||||
anyhow::Ok(())
|
||||
});
|
||||
|
||||
while let Some(hunks) = hunks_rx.next().await {
|
||||
@@ -266,8 +234,7 @@ impl Codegen {
|
||||
})?;
|
||||
}
|
||||
|
||||
diff.await;
|
||||
|
||||
diff.await?;
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
@@ -428,9 +395,8 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
@@ -487,9 +453,8 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
@@ -546,9 +511,8 @@ mod tests {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let codegen = cx.new_model(|cx| {
|
||||
Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
|
||||
});
|
||||
let codegen =
|
||||
cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
|
||||
|
||||
let request = LanguageModelRequest::default();
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
mod anthropic;
|
||||
#[cfg(test)]
|
||||
mod fake;
|
||||
mod open_ai;
|
||||
mod zed;
|
||||
|
||||
pub use anthropic::*;
|
||||
#[cfg(test)]
|
||||
pub use fake::*;
|
||||
pub use open_ai::*;
|
||||
@@ -20,7 +18,6 @@ use futures::{future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
let mut settings_version = 0;
|
||||
@@ -36,23 +33,10 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
)),
|
||||
AssistantProvider::Anthropic {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
} => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
)),
|
||||
};
|
||||
@@ -67,30 +51,9 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
provider.update(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
);
|
||||
}
|
||||
(
|
||||
CompletionProvider::Anthropic(provider),
|
||||
AssistantProvider::Anthropic {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
provider.update(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
);
|
||||
provider.update(default_model.clone(), api_url.clone(), settings_version);
|
||||
}
|
||||
(
|
||||
CompletionProvider::ZedDotDev(provider),
|
||||
@@ -98,7 +61,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
) => {
|
||||
provider.update(default_model.clone(), settings_version);
|
||||
}
|
||||
(_, AssistantProvider::ZedDotDev { default_model }) => {
|
||||
(CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
|
||||
*provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
client.clone(),
|
||||
@@ -107,37 +70,21 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
));
|
||||
}
|
||||
(
|
||||
_,
|
||||
CompletionProvider::ZedDotDev(_),
|
||||
AssistantProvider::OpenAi {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
*provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
));
|
||||
}
|
||||
(
|
||||
_,
|
||||
AssistantProvider::Anthropic {
|
||||
default_model,
|
||||
api_url,
|
||||
low_speed_timeout_in_seconds,
|
||||
},
|
||||
) => {
|
||||
*provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
|
||||
default_model.clone(),
|
||||
api_url.clone(),
|
||||
client.http_client(),
|
||||
low_speed_timeout_in_seconds.map(Duration::from_secs),
|
||||
settings_version,
|
||||
));
|
||||
}
|
||||
#[cfg(test)]
|
||||
(CompletionProvider::Fake(_), _) => unimplemented!(),
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -146,7 +93,6 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
|
||||
pub enum CompletionProvider {
|
||||
OpenAi(OpenAiCompletionProvider),
|
||||
Anthropic(AnthropicCompletionProvider),
|
||||
ZedDotDev(ZedDotDevCompletionProvider),
|
||||
#[cfg(test)]
|
||||
Fake(FakeCompletionProvider),
|
||||
@@ -162,7 +108,6 @@ impl CompletionProvider {
|
||||
pub fn settings_version(&self) -> usize {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.settings_version(),
|
||||
CompletionProvider::Anthropic(provider) => provider.settings_version(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
@@ -172,7 +117,6 @@ impl CompletionProvider {
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
|
||||
CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => true,
|
||||
@@ -182,7 +126,6 @@ impl CompletionProvider {
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
|
||||
CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
@@ -192,7 +135,6 @@ impl CompletionProvider {
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
|
||||
CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
@@ -202,7 +144,6 @@ impl CompletionProvider {
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
|
||||
CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
|
||||
CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => Task::ready(Ok(())),
|
||||
@@ -212,9 +153,6 @@ impl CompletionProvider {
|
||||
pub fn default_model(&self) -> LanguageModel {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
|
||||
CompletionProvider::Anthropic(provider) => {
|
||||
LanguageModel::Anthropic(provider.default_model())
|
||||
}
|
||||
CompletionProvider::ZedDotDev(provider) => {
|
||||
LanguageModel::ZedDotDev(provider.default_model())
|
||||
}
|
||||
@@ -230,10 +168,9 @@ impl CompletionProvider {
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
|
||||
CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
|
||||
CompletionProvider::Fake(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +180,6 @@ impl CompletionProvider {
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
match self {
|
||||
CompletionProvider::OpenAi(provider) => provider.complete(request),
|
||||
CompletionProvider::Anthropic(provider) => provider.complete(request),
|
||||
CompletionProvider::ZedDotDev(provider) => provider.complete(request),
|
||||
#[cfg(test)]
|
||||
CompletionProvider::Fake(provider) => provider.complete(),
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
use crate::count_open_ai_tokens;
|
||||
use crate::{
|
||||
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
|
||||
Role,
|
||||
};
|
||||
use anthropic::{stream_completion, Request, RequestMessage, Role as AnthropicRole};
|
||||
use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct AnthropicCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
default_model: AnthropicModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
impl AnthropicCompletionProvider {
|
||||
pub fn new(
|
||||
default_model: AnthropicModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: None,
|
||||
api_url,
|
||||
default_model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
default_model: AnthropicModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
self.default_model = default_model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
pub fn settings_version(&self) -> usize {
|
||||
self.settings_version
|
||||
}
|
||||
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
if self.is_authenticated() {
|
||||
Task::ready(Ok(()))
|
||||
} else {
|
||||
let api_url = self.api_url.clone();
|
||||
cx.spawn(|mut cx| async move {
|
||||
let api_key = if let Ok(api_key) = env::var("ANTHROPIC_API_KEY") {
|
||||
api_key
|
||||
} else {
|
||||
let (_, api_key) = cx
|
||||
.update(|cx| cx.read_credentials(&api_url))?
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("credentials not found"))?;
|
||||
String::from_utf8(api_key)?
|
||||
};
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::Anthropic(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
|
||||
let delete_credentials = cx.delete_credentials(&self.api_url);
|
||||
cx.spawn(|mut cx| async move {
|
||||
delete_credentials.await.log_err();
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::Anthropic(provider) = provider {
|
||||
provider.api_key = None;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
|
||||
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
pub fn default_model(&self) -> AnthropicModel {
|
||||
self.default_model.clone()
|
||||
}
|
||||
|
||||
pub fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
|
||||
pub fn complete(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||
let request = self.to_anthropic_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
match response {
|
||||
Ok(response) => match response {
|
||||
anthropic::ResponseEvent::ContentBlockStart {
|
||||
content_block, ..
|
||||
} => match content_block {
|
||||
anthropic::ContentBlock::Text { text } => Some(Ok(text)),
|
||||
},
|
||||
anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
|
||||
match delta {
|
||||
anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
Ok(stream)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn to_anthropic_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::Anthropic(model) => model,
|
||||
_ => self.default_model(),
|
||||
};
|
||||
|
||||
let mut system_message = String::new();
|
||||
|
||||
let mut messages: Vec<RequestMessage> = Vec::new();
|
||||
for message in request.messages {
|
||||
if message.content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match message.role {
|
||||
Role::User | Role::Assistant => {
|
||||
let role = match message.role {
|
||||
Role::User => AnthropicRole::User,
|
||||
Role::Assistant => AnthropicRole::Assistant,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if last_message.role == role {
|
||||
last_message.content.push_str("\n\n");
|
||||
last_message.content.push_str(&message.content);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(RequestMessage {
|
||||
role,
|
||||
content: message.content,
|
||||
});
|
||||
}
|
||||
Role::System => {
|
||||
if !system_message.is_empty() {
|
||||
system_message.push_str("\n\n");
|
||||
}
|
||||
system_message.push_str(&message.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Request {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AuthenticationPrompt {
|
||||
api_key: View<Editor>,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl AuthenticationPrompt {
|
||||
fn new(api_url: String, cx: &mut WindowContext) -> Self {
|
||||
Self {
|
||||
api_key: cx.new_view(|cx| {
|
||||
let mut editor = Editor::single_line(cx);
|
||||
editor.set_placeholder_text(
|
||||
"sk-000000000000000000000000000000000000000000000000",
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
}),
|
||||
api_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
|
||||
let api_key = self.api_key.read(cx).text(cx);
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
|
||||
cx.spawn(|_, mut cx| async move {
|
||||
write_credentials.await?;
|
||||
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
|
||||
if let CompletionProvider::Anthropic(provider) = provider {
|
||||
provider.api_key = Some(api_key);
|
||||
}
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.ui_font.family.clone(),
|
||||
font_features: settings.ui_font.features.clone(),
|
||||
font_size: rems(0.875).into(),
|
||||
font_weight: FontWeight::NORMAL,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: relative(1.3),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
EditorElement::new(
|
||||
&self.api_key,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AuthenticationPrompt {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
const INSTRUCTIONS: [&str; 4] = [
|
||||
"To use the assistant panel or inline assistant, you need to add your Anthropic API key.",
|
||||
"You can create an API key at: https://console.anthropic.com/settings/keys",
|
||||
"",
|
||||
"Paste your Anthropic API key below and hit enter to use the assistant:",
|
||||
];
|
||||
|
||||
v_flex()
|
||||
.p_4()
|
||||
.size_full()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.children(
|
||||
INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.my_2()
|
||||
.px_2()
|
||||
.py_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(self.render_api_key_editor(cx)),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
|
||||
)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Label::new("Click on").size(LabelSize::Small))
|
||||
.child(Icon::new(IconName::Ai).size(IconSize::XSmall))
|
||||
.child(
|
||||
Label::new("in the status bar to close this panel.").size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
use crate::assistant_settings::ZedDotDevModel;
|
||||
use crate::{
|
||||
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
|
||||
};
|
||||
@@ -6,21 +5,18 @@ use anyhow::{anyhow, Result};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
|
||||
use http::HttpClient;
|
||||
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use std::{env, sync::Arc};
|
||||
use theme::ThemeSettings;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use util::{http::HttpClient, ResultExt};
|
||||
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_key: Option<String>,
|
||||
api_url: String,
|
||||
default_model: OpenAiModel,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
}
|
||||
|
||||
@@ -29,7 +25,6 @@ impl OpenAiCompletionProvider {
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -37,21 +32,13 @@ impl OpenAiCompletionProvider {
|
||||
api_url,
|
||||
default_model,
|
||||
http_client,
|
||||
low_speed_timeout,
|
||||
settings_version,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(
|
||||
&mut self,
|
||||
default_model: OpenAiModel,
|
||||
api_url: String,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
settings_version: usize,
|
||||
) {
|
||||
pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
|
||||
self.default_model = default_model;
|
||||
self.api_url = api_url;
|
||||
self.low_speed_timeout = low_speed_timeout;
|
||||
self.settings_version = settings_version;
|
||||
}
|
||||
|
||||
@@ -125,16 +112,9 @@ impl OpenAiCompletionProvider {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
let low_speed_timeout = self.low_speed_timeout;
|
||||
async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
||||
let request = stream_completion(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
&api_key,
|
||||
request,
|
||||
low_speed_timeout,
|
||||
);
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
.filter_map(|response| async move {
|
||||
@@ -151,8 +131,8 @@ impl OpenAiCompletionProvider {
|
||||
|
||||
fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
|
||||
let model = match request.model {
|
||||
LanguageModel::ZedDotDev(_) => self.default_model(),
|
||||
LanguageModel::OpenAi(model) => model,
|
||||
_ => self.default_model(),
|
||||
};
|
||||
|
||||
Request {
|
||||
@@ -203,17 +183,7 @@ pub fn count_open_ai_tokens(
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match request.model {
|
||||
LanguageModel::Anthropic(_)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
|
||||
}
|
||||
_ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
|
||||
}
|
||||
tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
@@ -78,9 +78,9 @@ impl ZedDotDevCompletionProvider {
|
||||
cx: &AppContext,
|
||||
) -> BoxFuture<'static, Result<usize>> {
|
||||
match request.model {
|
||||
LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
|
||||
| LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
|
||||
count_open_ai_tokens(request, cx.background_executor())
|
||||
}
|
||||
@@ -107,7 +107,6 @@ impl ZedDotDevCompletionProvider {
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
_ => future::ready(Err(anyhow!("invalid model"))).boxed(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
91
crates/assistant/src/embedded_scope.rs
Normal file
91
crates/assistant/src/embedded_scope.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use editor::MultiBuffer;
|
||||
use gpui::{AppContext, Model, ModelContext, Subscription};
|
||||
|
||||
use crate::{assistant_panel::Conversation, LanguageModelRequestMessage, Role};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct EmbeddedScope {
|
||||
active_buffer: Option<Model<MultiBuffer>>,
|
||||
active_buffer_enabled: bool,
|
||||
active_buffer_subscription: Option<Subscription>,
|
||||
}
|
||||
|
||||
impl EmbeddedScope {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_buffer: None,
|
||||
active_buffer_enabled: true,
|
||||
active_buffer_subscription: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_buffer(
|
||||
&mut self,
|
||||
buffer: Option<Model<MultiBuffer>>,
|
||||
cx: &mut ModelContext<Conversation>,
|
||||
) {
|
||||
self.active_buffer_subscription.take();
|
||||
|
||||
if let Some(active_buffer) = buffer.clone() {
|
||||
self.active_buffer_subscription =
|
||||
Some(cx.subscribe(&active_buffer, |conversation, _, e, cx| {
|
||||
if let multi_buffer::Event::Edited { .. } = e {
|
||||
conversation.count_remaining_tokens(cx)
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
self.active_buffer = buffer;
|
||||
}
|
||||
|
||||
pub fn active_buffer(&self) -> Option<&Model<MultiBuffer>> {
|
||||
self.active_buffer.as_ref()
|
||||
}
|
||||
|
||||
pub fn active_buffer_enabled(&self) -> bool {
|
||||
self.active_buffer_enabled
|
||||
}
|
||||
|
||||
pub fn set_active_buffer_enabled(&mut self, enabled: bool) {
|
||||
self.active_buffer_enabled = enabled;
|
||||
}
|
||||
|
||||
/// Provide a message for the language model based on the active buffer.
|
||||
pub fn message(&self, cx: &AppContext) -> Option<LanguageModelRequestMessage> {
|
||||
if !self.active_buffer_enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let active_buffer = self.active_buffer.as_ref()?;
|
||||
let buffer = active_buffer.read(cx);
|
||||
|
||||
if let Some(singleton) = buffer.as_singleton() {
|
||||
let singleton = singleton.read(cx);
|
||||
|
||||
let filename = singleton
|
||||
.file()
|
||||
.map(|file| file.path().to_string_lossy())
|
||||
.unwrap_or("Untitled".into());
|
||||
|
||||
let text = singleton.text();
|
||||
|
||||
let language = singleton
|
||||
.language()
|
||||
.map(|l| {
|
||||
let name = l.code_fence_block_name();
|
||||
name.to_string()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let markdown =
|
||||
format!("User's active file `{filename}`:\n\n```{language}\n{text}```\n\n");
|
||||
|
||||
return Some(LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: markdown,
|
||||
});
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
use rope::Rope;
|
||||
use std::{cmp::Ordering, ops::Range};
|
||||
|
||||
pub(crate) fn text_in_range_omitting_ranges(
|
||||
rope: &Rope,
|
||||
range: Range<usize>,
|
||||
omit_ranges: &[Range<usize>],
|
||||
) -> String {
|
||||
let mut content = String::with_capacity(range.len());
|
||||
let mut omit_ranges = omit_ranges
|
||||
.iter()
|
||||
.skip_while(|omit_range| omit_range.end <= range.start)
|
||||
.peekable();
|
||||
let mut offset = range.start;
|
||||
let mut chunks = rope.chunks_in_range(range.clone());
|
||||
while let Some(chunk) = chunks.next() {
|
||||
if let Some(omit_range) = omit_ranges.peek() {
|
||||
match offset.cmp(&omit_range.start) {
|
||||
Ordering::Less => {
|
||||
let max_len = omit_range.start - offset;
|
||||
if chunk.len() < max_len {
|
||||
content.push_str(chunk);
|
||||
offset += chunk.len();
|
||||
} else {
|
||||
content.push_str(&chunk[..max_len]);
|
||||
chunks.seek(omit_range.end.min(range.end));
|
||||
offset = omit_range.end;
|
||||
omit_ranges.next();
|
||||
}
|
||||
}
|
||||
Ordering::Equal | Ordering::Greater => {
|
||||
chunks.seek(omit_range.end.min(range.end));
|
||||
offset = omit_range.end;
|
||||
omit_ranges.next();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
content.push_str(chunk);
|
||||
offset += chunk.len();
|
||||
}
|
||||
}
|
||||
|
||||
content
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::{rngs::StdRng, Rng as _};
|
||||
use util::RandomCharIter;
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
fn test_text_in_range_omitting_ranges(mut rng: StdRng) {
|
||||
let text = RandomCharIter::new(&mut rng).take(1024).collect::<String>();
|
||||
let rope = Rope::from(text.as_str());
|
||||
|
||||
let mut start = rng.gen_range(0..=text.len() / 2);
|
||||
let mut end = rng.gen_range(text.len() / 2..=text.len());
|
||||
while !text.is_char_boundary(start) {
|
||||
start -= 1;
|
||||
}
|
||||
while !text.is_char_boundary(end) {
|
||||
end += 1;
|
||||
}
|
||||
let range = start..end;
|
||||
|
||||
let mut ix = 0;
|
||||
let mut omit_ranges = Vec::new();
|
||||
for _ in 0..rng.gen_range(0..10) {
|
||||
let mut start = rng.gen_range(ix..=text.len());
|
||||
while !text.is_char_boundary(start) {
|
||||
start += 1;
|
||||
}
|
||||
let mut end = rng.gen_range(start..=text.len());
|
||||
while !text.is_char_boundary(end) {
|
||||
end += 1;
|
||||
}
|
||||
omit_ranges.push(start..end);
|
||||
ix = end;
|
||||
if ix == text.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut expected_text = text[range.clone()].to_string();
|
||||
for omit_range in omit_ranges.iter().rev() {
|
||||
let start = omit_range
|
||||
.start
|
||||
.saturating_sub(range.start)
|
||||
.min(range.len());
|
||||
let end = omit_range.end.saturating_sub(range.start).min(range.len());
|
||||
expected_text.replace_range(start..end, "");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
text_in_range_omitting_ranges(&rope, range.clone(), &omit_ranges),
|
||||
expected_text,
|
||||
"text: {text:?}\nrange: {range:?}\nomit_ranges: {omit_ranges:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,95 @@
|
||||
pub mod prompt;
|
||||
pub mod prompt_library;
|
||||
pub mod prompt_manager;
|
||||
use language::BufferSnapshot;
|
||||
use std::{fmt::Write, ops::Range};
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => {
|
||||
writeln!(prompt, "You are an expert engineer.")?;
|
||||
"Text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(prompt, "You are an expert {language_name} engineer.")?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)?;
|
||||
"Code"
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(project_name) = project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
)?;
|
||||
}
|
||||
|
||||
// Include file content.
|
||||
for chunk in buffer.text_for_range(0..range.start) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if range.is_empty() {
|
||||
prompt.push_str("<|START|>");
|
||||
} else {
|
||||
prompt.push_str("<|START|");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if !range.is_empty() {
|
||||
prompt.push_str("|END|>");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.end..buffer.len()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
prompt.push('\n');
|
||||
|
||||
if range.is_empty() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
@@ -1,278 +0,0 @@
|
||||
use language::BufferSnapshot;
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use ui::SharedString;
|
||||
|
||||
use gray_matter::{engine::YAML, Matter};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct StaticPromptFrontmatter {
|
||||
title: String,
|
||||
version: String,
|
||||
author: String,
|
||||
#[serde(default)]
|
||||
languages: Vec<String>,
|
||||
#[serde(default)]
|
||||
dependencies: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for StaticPromptFrontmatter {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
title: "New Prompt".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
author: "No Author".to_string(),
|
||||
languages: vec!["*".to_string()],
|
||||
dependencies: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StaticPromptFrontmatter {
|
||||
pub fn title(&self) -> SharedString {
|
||||
self.title.clone().into()
|
||||
}
|
||||
|
||||
// pub fn version(&self) -> SharedString {
|
||||
// self.version.clone().into()
|
||||
// }
|
||||
|
||||
// pub fn author(&self) -> SharedString {
|
||||
// self.author.clone().into()
|
||||
// }
|
||||
|
||||
// pub fn languages(&self) -> Vec<SharedString> {
|
||||
// self.languages
|
||||
// .clone()
|
||||
// .into_iter()
|
||||
// .map(|s| s.into())
|
||||
// .collect()
|
||||
// }
|
||||
|
||||
// pub fn dependencies(&self) -> Vec<SharedString> {
|
||||
// self.dependencies
|
||||
// .clone()
|
||||
// .into_iter()
|
||||
// .map(|s| s.into())
|
||||
// .collect()
|
||||
// }
|
||||
}
|
||||
|
||||
/// A statuc prompt that can be loaded into the prompt library
|
||||
/// from Markdown with a frontmatter header
|
||||
///
|
||||
/// Examples:
|
||||
///
|
||||
/// ### Globally available prompt
|
||||
///
|
||||
/// ```markdown
|
||||
/// ---
|
||||
/// title: Foo
|
||||
/// version: 1.0
|
||||
/// author: Jane Kim <jane@kim.com
|
||||
/// languages: ["*"]
|
||||
/// dependencies: []
|
||||
/// ---
|
||||
///
|
||||
/// Foo and bar are terms used in programming to describe generic concepts.
|
||||
/// ```
|
||||
///
|
||||
/// ### Language-specific prompt
|
||||
///
|
||||
/// ```markdown
|
||||
/// ---
|
||||
/// title: UI with GPUI
|
||||
/// version: 1.0
|
||||
/// author: Nate Butler <iamnbutler@gmail.com>
|
||||
/// languages: ["rust"]
|
||||
/// dependencies: ["gpui"]
|
||||
/// ---
|
||||
///
|
||||
/// When building a UI with GPUI, ensure you...
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct StaticPrompt {
|
||||
content: String,
|
||||
file_name: Option<String>,
|
||||
}
|
||||
|
||||
impl StaticPrompt {
|
||||
pub fn new(content: String) -> Self {
|
||||
StaticPrompt {
|
||||
content,
|
||||
file_name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn title(&self) -> Option<SharedString> {
|
||||
self.metadata().map(|m| m.title())
|
||||
}
|
||||
|
||||
// pub fn version(&self) -> Option<SharedString> {
|
||||
// self.metadata().map(|m| m.version())
|
||||
// }
|
||||
|
||||
// pub fn author(&self) -> Option<SharedString> {
|
||||
// self.metadata().map(|m| m.author())
|
||||
// }
|
||||
|
||||
// pub fn languages(&self) -> Vec<SharedString> {
|
||||
// self.metadata().map(|m| m.languages()).unwrap_or_default()
|
||||
// }
|
||||
|
||||
// pub fn dependencies(&self) -> Vec<SharedString> {
|
||||
// self.metadata()
|
||||
// .map(|m| m.dependencies())
|
||||
// .unwrap_or_default()
|
||||
// }
|
||||
|
||||
// pub fn load(fs: Arc<Fs>, file_name: String) -> anyhow::Result<Self> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn save(&self, fs: Arc<Fs>) -> anyhow::Result<()> {
|
||||
// todo!()
|
||||
// }
|
||||
|
||||
// pub fn rename(&self, new_file_name: String, fs: Arc<Fs>) -> anyhow::Result<()> {
|
||||
// todo!()
|
||||
// }
|
||||
}
|
||||
|
||||
impl StaticPrompt {
|
||||
// pub fn update(&mut self, contents: String) -> &mut Self {
|
||||
// self.content = contents;
|
||||
// self
|
||||
// }
|
||||
|
||||
/// Sets the file name of the prompt
|
||||
pub fn file_name(&mut self, file_name: String) -> &mut Self {
|
||||
self.file_name = Some(file_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the file name of the prompt based on the title
|
||||
// pub fn file_name_from_title(&mut self) -> &mut Self {
|
||||
// if let Some(title) = self.title() {
|
||||
// let file_name = title.to_lowercase().replace(" ", "_");
|
||||
// if !file_name.is_empty() {
|
||||
// self.file_name = Some(file_name);
|
||||
// }
|
||||
// }
|
||||
// self
|
||||
// }
|
||||
|
||||
/// Returns the prompt's content
|
||||
pub fn content(&self) -> &String {
|
||||
&self.content
|
||||
}
|
||||
fn parse(&self) -> anyhow::Result<(StaticPromptFrontmatter, String)> {
|
||||
let matter = Matter::<YAML>::new();
|
||||
let result = matter.parse(self.content.as_str());
|
||||
match result.data {
|
||||
Some(data) => {
|
||||
let front_matter: StaticPromptFrontmatter = data.deserialize()?;
|
||||
let body = result.content;
|
||||
Ok((front_matter, body))
|
||||
}
|
||||
None => Err(anyhow::anyhow!("Failed to parse frontmatter")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn metadata(&self) -> Option<StaticPromptFrontmatter> {
|
||||
self.parse().ok().map(|(front_matter, _)| front_matter)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut prompt = String::new();
|
||||
|
||||
let content_type = match language_name {
|
||||
None | Some("Markdown" | "Plain Text") => {
|
||||
writeln!(prompt, "You are an expert engineer.")?;
|
||||
"Text"
|
||||
}
|
||||
Some(language_name) => {
|
||||
writeln!(prompt, "You are an expert {language_name} engineer.")?;
|
||||
writeln!(
|
||||
prompt,
|
||||
"Your answer MUST always and only be valid {}.",
|
||||
language_name
|
||||
)?;
|
||||
"Code"
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(project_name) = project_name {
|
||||
writeln!(
|
||||
prompt,
|
||||
"You are currently working inside the '{project_name}' project in code editor Zed."
|
||||
)?;
|
||||
}
|
||||
|
||||
// Include file content.
|
||||
for chunk in buffer.text_for_range(0..range.start) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if range.is_empty() {
|
||||
prompt.push_str("<|START|>");
|
||||
} else {
|
||||
prompt.push_str("<|START|");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.clone()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
if !range.is_empty() {
|
||||
prompt.push_str("|END|>");
|
||||
}
|
||||
|
||||
for chunk in buffer.text_for_range(range.end..buffer.len()) {
|
||||
prompt.push_str(chunk);
|
||||
}
|
||||
|
||||
prompt.push('\n');
|
||||
|
||||
if range.is_empty() {
|
||||
writeln!(
|
||||
prompt,
|
||||
"Assume the cursor is located where the `<|START|>` span is."
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
|
||||
)
|
||||
.unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Generate {content_type} based on the users prompt: {user_prompt}",
|
||||
)
|
||||
.unwrap();
|
||||
} else {
|
||||
writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
|
||||
writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Double check that you only return code and not the '<|START|' and '|END|'> spans"
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
||||
writeln!(
|
||||
prompt,
|
||||
"Do not return anything else, except the generated {content_type}."
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
use anyhow::Context;
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::sync::Arc;
|
||||
use util::paths::PROMPTS_DIR;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::prompt::StaticPrompt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
pub struct PromptId(pub Uuid);
|
||||
|
||||
#[allow(unused)]
|
||||
impl PromptId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct PromptLibraryState {
|
||||
/// A set of prompts that all assistant contexts will start with
|
||||
default_prompt: Vec<PromptId>,
|
||||
/// All [Prompt]s loaded into the library
|
||||
prompts: HashMap<PromptId, StaticPrompt>,
|
||||
/// Prompts that have been changed but haven't been
|
||||
/// saved back to the file system
|
||||
dirty_prompts: Vec<PromptId>,
|
||||
version: usize,
|
||||
}
|
||||
|
||||
pub struct PromptLibrary {
|
||||
state: RwLock<PromptLibraryState>,
|
||||
}
|
||||
|
||||
impl Default for PromptLibrary {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptLibrary {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
state: RwLock::new(PromptLibraryState::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompts(&self) -> Vec<(PromptId, StaticPrompt)> {
|
||||
let state = self.state.read();
|
||||
state
|
||||
.prompts
|
||||
.iter()
|
||||
.map(|(id, prompt)| (*id, prompt.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn first_prompt_id(&self) -> Option<PromptId> {
|
||||
let state = self.state.read();
|
||||
state.prompts.keys().next().cloned()
|
||||
}
|
||||
|
||||
pub fn prompt(&self, id: PromptId) -> Option<StaticPrompt> {
|
||||
let state = self.state.read();
|
||||
state.prompts.get(&id).cloned()
|
||||
}
|
||||
|
||||
/// Save the current state of the prompt library to the
|
||||
/// file system as a JSON file
|
||||
pub async fn save(&self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
|
||||
fs.create_dir(&PROMPTS_DIR).await?;
|
||||
|
||||
let path = PROMPTS_DIR.join("index.json");
|
||||
|
||||
let json = {
|
||||
let state = self.state.read();
|
||||
serde_json::to_string(&*state)?
|
||||
};
|
||||
|
||||
fs.atomic_write(path, json).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load the state of the prompt library from the file system
|
||||
/// or create a new one if it doesn't exist
|
||||
pub async fn load(fs: Arc<dyn Fs>) -> anyhow::Result<Self> {
|
||||
let path = PROMPTS_DIR.join("index.json");
|
||||
|
||||
let state = if fs.is_file(&path).await {
|
||||
let json = fs.load(&path).await?;
|
||||
serde_json::from_str(&json)?
|
||||
} else {
|
||||
PromptLibraryState::default()
|
||||
};
|
||||
|
||||
let mut prompt_library = Self {
|
||||
state: RwLock::new(state),
|
||||
};
|
||||
|
||||
prompt_library.load_prompts(fs).await?;
|
||||
|
||||
Ok(prompt_library)
|
||||
}
|
||||
|
||||
/// Load all prompts from the file system
|
||||
/// adding them to the library if they don't already exist
|
||||
pub async fn load_prompts(&mut self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
|
||||
// let current_prompts = self.all_prompt_contents().clone();
|
||||
|
||||
// For now, we'll just clear the prompts and reload them all
|
||||
self.state.get_mut().prompts.clear();
|
||||
|
||||
let mut prompt_paths = fs.read_dir(&PROMPTS_DIR).await?;
|
||||
|
||||
while let Some(prompt_path) = prompt_paths.next().await {
|
||||
let prompt_path = prompt_path.with_context(|| "Failed to read prompt path")?;
|
||||
|
||||
if !fs.is_file(&prompt_path).await
|
||||
|| prompt_path.extension().and_then(|ext| ext.to_str()) != Some("md")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let json = fs
|
||||
.load(&prompt_path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to load prompt {:?}", prompt_path))?;
|
||||
let mut static_prompt = StaticPrompt::new(json);
|
||||
|
||||
if let Some(file_name) = prompt_path.file_name() {
|
||||
let file_name = file_name.to_string_lossy().into_owned();
|
||||
static_prompt.file_name(file_name);
|
||||
}
|
||||
|
||||
let state = self.state.get_mut();
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
state.prompts.insert(PromptId(id), static_prompt);
|
||||
state.version += 1;
|
||||
}
|
||||
|
||||
// Write any changes back to the file system
|
||||
self.save(fs.clone()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
use collections::HashMap;
|
||||
use editor::Editor;
|
||||
use fs::Fs;
|
||||
use gpui::{prelude::FluentBuilder, *};
|
||||
use language::{language_settings, Buffer, LanguageRegistry};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use std::sync::Arc;
|
||||
use ui::{prelude::*, IconButtonShape, ListItem, ListItemSpacing};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
use workspace::ModalView;
|
||||
|
||||
use super::prompt_library::{PromptId, PromptLibrary};
|
||||
use crate::prompts::prompt::StaticPrompt;
|
||||
|
||||
pub struct PromptManager {
|
||||
focus_handle: FocusHandle,
|
||||
prompt_library: Arc<PromptLibrary>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
#[allow(dead_code)]
|
||||
fs: Arc<dyn Fs>,
|
||||
picker: View<Picker<PromptManagerDelegate>>,
|
||||
prompt_editors: HashMap<PromptId, View<Editor>>,
|
||||
active_prompt_id: Option<PromptId>,
|
||||
}
|
||||
|
||||
impl PromptManager {
|
||||
pub fn new(
|
||||
prompt_library: Arc<PromptLibrary>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let prompt_manager = cx.view().downgrade();
|
||||
let picker = cx.new_view(|cx| {
|
||||
Picker::uniform_list(
|
||||
PromptManagerDelegate {
|
||||
prompt_manager,
|
||||
matching_prompts: vec![],
|
||||
matching_prompt_ids: vec![],
|
||||
prompt_library: prompt_library.clone(),
|
||||
selected_index: 0,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.max_height(rems(35.75))
|
||||
.modal(false)
|
||||
});
|
||||
|
||||
let focus_handle = picker.focus_handle(cx);
|
||||
|
||||
let mut manager = Self {
|
||||
focus_handle,
|
||||
prompt_library,
|
||||
language_registry,
|
||||
fs,
|
||||
picker,
|
||||
prompt_editors: HashMap::default(),
|
||||
active_prompt_id: None,
|
||||
};
|
||||
|
||||
manager.active_prompt_id = manager.prompt_library.first_prompt_id();
|
||||
|
||||
manager
|
||||
}
|
||||
|
||||
pub fn set_active_prompt(&mut self, prompt_id: Option<PromptId>, cx: &mut ViewContext<Self>) {
|
||||
self.active_prompt_id = prompt_id;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn focus_active_editor(&self, cx: &mut ViewContext<Self>) {
|
||||
if let Some(active_prompt_id) = self.active_prompt_id {
|
||||
if let Some(editor) = self.prompt_editors.get(&active_prompt_id) {
|
||||
let focus_handle = editor.focus_handle(cx);
|
||||
|
||||
cx.focus(&focus_handle)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dismiss(&mut self, _: &menu::Cancel, cx: &mut ViewContext<Self>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_prompt_list(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let picker = self.picker.clone();
|
||||
|
||||
v_flex()
|
||||
.id("prompt-list")
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.h_full()
|
||||
.w_2_5()
|
||||
.child(
|
||||
h_flex()
|
||||
.bg(cx.theme().colors().background)
|
||||
.p(Spacing::Small.rems(cx))
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.h(rems(1.75))
|
||||
.w_full()
|
||||
.flex_none()
|
||||
.justify_between()
|
||||
.child(Label::new("Prompt Library").size(LabelSize::Small))
|
||||
.child(IconButton::new("new-prompt", IconName::Plus).disabled(true)),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.h(rems(38.25))
|
||||
.flex_grow()
|
||||
.justify_start()
|
||||
.child(picker),
|
||||
)
|
||||
}
|
||||
|
||||
fn set_editor_for_prompt(
|
||||
&mut self,
|
||||
prompt_id: PromptId,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> impl IntoElement {
|
||||
let prompt_library = self.prompt_library.clone();
|
||||
|
||||
let editor_for_prompt = self.prompt_editors.entry(prompt_id).or_insert_with(|| {
|
||||
cx.new_view(|cx| {
|
||||
let text = if let Some(prompt) = prompt_library.prompt(prompt_id) {
|
||||
prompt.content().to_owned()
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let buffer = cx.new_model(|cx| {
|
||||
let mut buffer = Buffer::local(text, cx);
|
||||
let markdown = self.language_registry.language_for_name("Markdown");
|
||||
cx.spawn(|buffer, mut cx| async move {
|
||||
if let Some(markdown) = markdown.await.log_err() {
|
||||
_ = buffer.update(&mut cx, |buffer, cx| {
|
||||
buffer.set_language(Some(markdown), cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
buffer.set_language_registry(self.language_registry.clone());
|
||||
buffer
|
||||
});
|
||||
let mut editor = Editor::for_buffer(buffer, None, cx);
|
||||
editor.set_soft_wrap_mode(language_settings::SoftWrap::EditorWidth, cx);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor
|
||||
})
|
||||
});
|
||||
editor_for_prompt.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for PromptManager {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
h_flex()
|
||||
.key_context("PromptManager")
|
||||
.track_focus(&self.focus_handle)
|
||||
.on_action(cx.listener(Self::dismiss))
|
||||
// .on_action(cx.listener(Self::save_active_prompt))
|
||||
.elevation_3(cx)
|
||||
.size_full()
|
||||
.flex_none()
|
||||
.w(rems(64.))
|
||||
.h(rems(40.))
|
||||
.overflow_hidden()
|
||||
.child(self.render_prompt_list(cx))
|
||||
.child(
|
||||
div().w_3_5().h_full().child(
|
||||
v_flex()
|
||||
.id("prompt-editor")
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.size_full()
|
||||
.flex_none()
|
||||
.min_w_64()
|
||||
.h_full()
|
||||
.child(
|
||||
h_flex()
|
||||
.bg(cx.theme().colors().background)
|
||||
.p(Spacing::Small.rems(cx))
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.h_7()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.child(div())
|
||||
.child(
|
||||
IconButton::new("dismiss", IconName::Close)
|
||||
.shape(IconButtonShape::Square)
|
||||
.on_click(|_, cx| {
|
||||
cx.dispatch_action(menu::Cancel.boxed_clone());
|
||||
}),
|
||||
),
|
||||
)
|
||||
.when_some(self.active_prompt_id, |this, active_prompt_id| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.flex_1()
|
||||
.w_full()
|
||||
.py(Spacing::Large.rems(cx))
|
||||
.px(Spacing::XLarge.rems(cx))
|
||||
.child(self.set_editor_for_prompt(active_prompt_id, cx)),
|
||||
)
|
||||
}),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for PromptManager {}
|
||||
impl ModalView for PromptManager {}
|
||||
|
||||
impl FocusableView for PromptManager {
|
||||
fn focus_handle(&self, _cx: &AppContext) -> gpui::FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PromptManagerDelegate {
|
||||
prompt_manager: WeakView<PromptManager>,
|
||||
matching_prompts: Vec<Arc<StaticPrompt>>,
|
||||
matching_prompt_ids: Vec<PromptId>,
|
||||
prompt_library: Arc<PromptLibrary>,
|
||||
selected_index: usize,
|
||||
}
|
||||
|
||||
impl PickerDelegate for PromptManagerDelegate {
|
||||
type ListItem = ListItem;
|
||||
|
||||
fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
|
||||
"Find a prompt…".into()
|
||||
}
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.matching_prompt_ids.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(&mut self, ix: usize, _: &mut ViewContext<Picker<Self>>) {
|
||||
self.selected_index = ix;
|
||||
}
|
||||
|
||||
fn selected_index_changed(
|
||||
&self,
|
||||
ix: usize,
|
||||
_cx: &mut ViewContext<Picker<Self>>,
|
||||
) -> Option<Box<dyn Fn(&mut WindowContext) + 'static>> {
|
||||
let prompt_id = self.matching_prompt_ids.get(ix).copied()?;
|
||||
let prompt_manager = self.prompt_manager.upgrade()?;
|
||||
|
||||
Some(Box::new(move |cx| {
|
||||
prompt_manager.update(cx, |manager, cx| {
|
||||
manager.set_active_prompt(Some(prompt_id), cx);
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
|
||||
let prompt_library = self.prompt_library.clone();
|
||||
cx.spawn(|picker, mut cx| async move {
|
||||
async {
|
||||
let prompts = prompt_library.prompts();
|
||||
let matching_prompts = prompts
|
||||
.into_iter()
|
||||
.filter(|(_, prompt)| {
|
||||
prompt
|
||||
.content()
|
||||
.to_lowercase()
|
||||
.contains(&query.to_lowercase())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
picker.update(&mut cx, |picker, cx| {
|
||||
picker.delegate.matching_prompt_ids =
|
||||
matching_prompts.iter().map(|(id, _)| *id).collect();
|
||||
picker.delegate.matching_prompts = matching_prompts
|
||||
.into_iter()
|
||||
.map(|(_, prompt)| Arc::new(prompt))
|
||||
.collect();
|
||||
cx.notify();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
.await;
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _: bool, cx: &mut ViewContext<Picker<Self>>) {
|
||||
let prompt_manager = self.prompt_manager.upgrade().unwrap();
|
||||
prompt_manager.update(cx, move |manager, cx| manager.focus_active_editor(cx));
|
||||
}
|
||||
|
||||
fn should_dismiss(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, cx: &mut ViewContext<Picker<Self>>) {
|
||||
self.prompt_manager
|
||||
.update(cx, |_, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
_cx: &mut ViewContext<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let matching_prompt = self.matching_prompts.get(ix)?;
|
||||
let prompt = matching_prompt.clone();
|
||||
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.selected(selected)
|
||||
.child(Label::new(prompt.title().unwrap_or_default().clone())),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
use language::Rope;
|
||||
use std::ops::Range;
|
||||
|
||||
/// Search the given buffer for the given substring, ignoring any differences
|
||||
/// in line indentation between the query and the buffer.
|
||||
///
|
||||
/// Returns a vector of ranges of byte offsets in the buffer corresponding
|
||||
/// to the entire lines of the buffer.
|
||||
pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Option<Range<usize>> {
|
||||
const SIMILARITY_THRESHOLD: f64 = 0.8;
|
||||
|
||||
let mut best_match: Option<(Range<usize>, f64)> = None; // (range, score)
|
||||
let mut haystack_lines = haystack.chunks().lines();
|
||||
let mut haystack_line_start = 0;
|
||||
while let Some(mut haystack_line) = haystack_lines.next() {
|
||||
let next_haystack_line_start = haystack_line_start + haystack_line.len() + 1;
|
||||
let mut advanced_to_next_haystack_line = false;
|
||||
|
||||
let mut matched = true;
|
||||
let match_start = haystack_line_start;
|
||||
let mut match_end = next_haystack_line_start;
|
||||
let mut match_score = 0.0;
|
||||
let mut needle_lines = needle.lines().peekable();
|
||||
while let Some(needle_line) = needle_lines.next() {
|
||||
let similarity = line_similarity(haystack_line, needle_line);
|
||||
if similarity >= SIMILARITY_THRESHOLD {
|
||||
match_end = haystack_lines.offset();
|
||||
match_score += similarity;
|
||||
|
||||
if needle_lines.peek().is_some() {
|
||||
if let Some(next_haystack_line) = haystack_lines.next() {
|
||||
advanced_to_next_haystack_line = true;
|
||||
haystack_line = next_haystack_line;
|
||||
} else {
|
||||
matched = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
matched = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if matched
|
||||
&& best_match
|
||||
.as_ref()
|
||||
.map(|(_, best_score)| match_score > *best_score)
|
||||
.unwrap_or(true)
|
||||
{
|
||||
best_match = Some((match_start..match_end, match_score));
|
||||
}
|
||||
|
||||
if advanced_to_next_haystack_line {
|
||||
haystack_lines.seek(next_haystack_line_start);
|
||||
}
|
||||
haystack_line_start = next_haystack_line_start;
|
||||
}
|
||||
|
||||
best_match.map(|(range, _)| range)
|
||||
}
|
||||
|
||||
/// Calculates the similarity between two lines, ignoring leading and trailing whitespace,
|
||||
/// using the Jaro-Winkler distance.
|
||||
///
|
||||
/// Returns a value between 0.0 and 1.0, where 1.0 indicates an exact match.
|
||||
fn line_similarity(line1: &str, line2: &str) -> f64 {
|
||||
strsim::jaro_winkler(line1.trim(), line2.trim())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use gpui::{AppContext, Context as _};
|
||||
use language::Buffer;
|
||||
use unindent::Unindent as _;
|
||||
use util::test::marked_text_ranges;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_fuzzy_search_lines(cx: &mut AppContext) {
|
||||
let (text, expected_ranges) = marked_text_ranges(
|
||||
&r#"
|
||||
fn main() {
|
||||
if a() {
|
||||
assert_eq!(
|
||||
1 + 2,
|
||||
does_not_match,
|
||||
);
|
||||
}
|
||||
|
||||
println!("hi");
|
||||
|
||||
assert_eq!(
|
||||
1 + 2,
|
||||
3,
|
||||
); // this last line does not match
|
||||
|
||||
« assert_eq!(
|
||||
1 + 2,
|
||||
3,
|
||||
);
|
||||
»
|
||||
|
||||
« assert_eq!(
|
||||
"something",
|
||||
"else",
|
||||
);
|
||||
»
|
||||
}
|
||||
"#
|
||||
.unindent(),
|
||||
false,
|
||||
);
|
||||
|
||||
let buffer = cx.new_model(|cx| Buffer::local(&text, cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let actual_range = fuzzy_search_lines(
|
||||
snapshot.as_rope(),
|
||||
&"
|
||||
assert_eq!(
|
||||
1 + 2,
|
||||
3,
|
||||
);
|
||||
"
|
||||
.unindent(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(actual_range, expected_ranges[0]);
|
||||
|
||||
let actual_range = fuzzy_search_lines(
|
||||
snapshot.as_rope(),
|
||||
&"
|
||||
assert_eq!(
|
||||
1 + 2,
|
||||
3,
|
||||
);
|
||||
"
|
||||
.unindent(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(actual_range, expected_ranges[0]);
|
||||
|
||||
let actual_range = fuzzy_search_lines(
|
||||
snapshot.as_rope(),
|
||||
&"
|
||||
asst_eq!(
|
||||
\"something\",
|
||||
\"els\"
|
||||
)
|
||||
"
|
||||
.unindent(),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(actual_range, expected_ranges[1]);
|
||||
|
||||
let actual_range = fuzzy_search_lines(
|
||||
snapshot.as_rope(),
|
||||
&"
|
||||
assert_eq!(
|
||||
2 + 1,
|
||||
3,
|
||||
);
|
||||
"
|
||||
.unindent(),
|
||||
);
|
||||
assert_eq!(actual_range, None);
|
||||
}
|
||||
}
|
||||
@@ -1,319 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use editor::{CompletionProvider, Editor};
|
||||
use futures::channel::oneshot;
|
||||
use fuzzy::{match_strings, StringMatchCandidate};
|
||||
use gpui::{AppContext, Model, Task, ViewContext, WindowHandle};
|
||||
use language::{Anchor, Buffer, CodeLabel, Documentation, LanguageServerId, ToPoint};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use project::Project;
|
||||
use rope::Point;
|
||||
use std::{
|
||||
ops::Range,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering::SeqCst},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::PromptLibrary;
|
||||
|
||||
mod current_file_command;
|
||||
mod file_command;
|
||||
mod prompt_command;
|
||||
|
||||
pub(crate) struct SlashCommandCompletionProvider {
|
||||
commands: Arc<SlashCommandRegistry>,
|
||||
cancel_flag: Mutex<Arc<AtomicBool>>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct SlashCommandRegistry {
|
||||
commands: HashMap<String, Box<dyn SlashCommand>>,
|
||||
}
|
||||
|
||||
pub(crate) trait SlashCommand: 'static + Send + Sync {
|
||||
fn name(&self) -> String;
|
||||
fn description(&self) -> String;
|
||||
fn complete_argument(
|
||||
&self,
|
||||
query: String,
|
||||
cancel: Arc<AtomicBool>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<String>>>;
|
||||
fn requires_argument(&self) -> bool;
|
||||
fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation;
|
||||
}
|
||||
|
||||
pub(crate) struct SlashCommandInvocation {
|
||||
pub output: Task<Result<String>>,
|
||||
pub invalidated: oneshot::Receiver<()>,
|
||||
pub cleanup: SlashCommandCleanup,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct SlashCommandCleanup(Option<Box<dyn FnOnce()>>);
|
||||
|
||||
impl SlashCommandCleanup {
|
||||
pub fn new(cleanup: impl FnOnce() + 'static) -> Self {
|
||||
Self(Some(Box::new(cleanup)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SlashCommandCleanup {
|
||||
fn drop(&mut self) {
|
||||
if let Some(cleanup) = self.0.take() {
|
||||
cleanup();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SlashCommandLine {
|
||||
/// The range within the line containing the command name.
|
||||
pub name: Range<usize>,
|
||||
/// The range within the line containing the command argument.
|
||||
pub argument: Option<Range<usize>>,
|
||||
}
|
||||
|
||||
impl SlashCommandRegistry {
|
||||
pub fn new(
|
||||
project: Model<Project>,
|
||||
prompt_library: Arc<PromptLibrary>,
|
||||
window: Option<WindowHandle<Workspace>>,
|
||||
) -> Arc<Self> {
|
||||
let mut this = Self {
|
||||
commands: HashMap::default(),
|
||||
};
|
||||
|
||||
this.register_command(file_command::FileSlashCommand::new(project));
|
||||
this.register_command(prompt_command::PromptSlashCommand::new(prompt_library));
|
||||
if let Some(window) = window {
|
||||
this.register_command(current_file_command::CurrentFileSlashCommand::new(window));
|
||||
}
|
||||
|
||||
Arc::new(this)
|
||||
}
|
||||
|
||||
fn register_command(&mut self, command: impl SlashCommand) {
|
||||
self.commands.insert(command.name(), Box::new(command));
|
||||
}
|
||||
|
||||
fn command_names(&self) -> impl Iterator<Item = &String> {
|
||||
self.commands.keys()
|
||||
}
|
||||
|
||||
pub(crate) fn command(&self, name: &str) -> Option<&dyn SlashCommand> {
|
||||
self.commands.get(name).map(|b| &**b)
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommandCompletionProvider {
|
||||
pub fn new(commands: Arc<SlashCommandRegistry>) -> Self {
|
||||
Self {
|
||||
cancel_flag: Mutex::new(Arc::new(AtomicBool::new(false))),
|
||||
commands,
|
||||
}
|
||||
}
|
||||
|
||||
fn complete_command_name(
|
||||
&self,
|
||||
command_name: &str,
|
||||
range: Range<Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<project::Completion>>> {
|
||||
let candidates = self
|
||||
.commands
|
||||
.command_names()
|
||||
.enumerate()
|
||||
.map(|(ix, def)| StringMatchCandidate {
|
||||
id: ix,
|
||||
string: def.clone(),
|
||||
char_bag: def.as_str().into(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let commands = self.commands.clone();
|
||||
let command_name = command_name.to_string();
|
||||
let executor = cx.background_executor().clone();
|
||||
executor.clone().spawn(async move {
|
||||
let matches = match_strings(
|
||||
&candidates,
|
||||
&command_name,
|
||||
true,
|
||||
usize::MAX,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(matches
|
||||
.into_iter()
|
||||
.filter_map(|mat| {
|
||||
let command = commands.command(&mat.string)?;
|
||||
let mut new_text = mat.string.clone();
|
||||
if command.requires_argument() {
|
||||
new_text.push(' ');
|
||||
}
|
||||
|
||||
Some(project::Completion {
|
||||
old_range: range.clone(),
|
||||
documentation: Some(Documentation::SingleLine(command.description())),
|
||||
new_text,
|
||||
label: CodeLabel::plain(mat.string, None),
|
||||
server_id: LanguageServerId(0),
|
||||
lsp_completion: Default::default(),
|
||||
})
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
fn complete_command_argument(
|
||||
&self,
|
||||
command_name: &str,
|
||||
argument: String,
|
||||
range: Range<Anchor>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<project::Completion>>> {
|
||||
let new_cancel_flag = Arc::new(AtomicBool::new(false));
|
||||
let mut flag = self.cancel_flag.lock();
|
||||
flag.store(true, SeqCst);
|
||||
*flag = new_cancel_flag.clone();
|
||||
|
||||
if let Some(command) = self.commands.command(command_name) {
|
||||
let completions = command.complete_argument(argument, new_cancel_flag.clone(), cx);
|
||||
cx.background_executor().spawn(async move {
|
||||
Ok(completions
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|arg| project::Completion {
|
||||
old_range: range.clone(),
|
||||
label: CodeLabel::plain(arg.clone(), None),
|
||||
new_text: arg.clone(),
|
||||
documentation: None,
|
||||
server_id: LanguageServerId(0),
|
||||
lsp_completion: Default::default(),
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
} else {
|
||||
cx.background_executor()
|
||||
.spawn(async move { Ok(Vec::new()) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionProvider for SlashCommandCompletionProvider {
|
||||
fn completions(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
buffer_position: Anchor,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> Task<Result<Vec<project::Completion>>> {
|
||||
let task = buffer.update(cx, |buffer, cx| {
|
||||
let position = buffer_position.to_point(buffer);
|
||||
let line_start = Point::new(position.row, 0);
|
||||
let mut lines = buffer.text_for_range(line_start..position).lines();
|
||||
let line = lines.next()?;
|
||||
let call = SlashCommandLine::parse(line)?;
|
||||
|
||||
let name = &line[call.name.clone()];
|
||||
if let Some(argument) = call.argument {
|
||||
let start = buffer.anchor_after(Point::new(position.row, argument.start as u32));
|
||||
let argument = line[argument.clone()].to_string();
|
||||
Some(self.complete_command_argument(name, argument, start..buffer_position, cx))
|
||||
} else {
|
||||
let start = buffer.anchor_after(Point::new(position.row, call.name.start as u32));
|
||||
Some(self.complete_command_name(name, start..buffer_position, cx))
|
||||
}
|
||||
});
|
||||
|
||||
task.unwrap_or_else(|| Task::ready(Ok(Vec::new())))
|
||||
}
|
||||
|
||||
fn resolve_completions(
|
||||
&self,
|
||||
_: Model<Buffer>,
|
||||
_: Vec<usize>,
|
||||
_: Arc<RwLock<Box<[project::Completion]>>>,
|
||||
_: &mut ViewContext<Editor>,
|
||||
) -> Task<Result<bool>> {
|
||||
Task::ready(Ok(true))
|
||||
}
|
||||
|
||||
fn apply_additional_edits_for_completion(
|
||||
&self,
|
||||
_: Model<Buffer>,
|
||||
_: project::Completion,
|
||||
_: bool,
|
||||
_: &mut ViewContext<Editor>,
|
||||
) -> Task<Result<Option<language::Transaction>>> {
|
||||
Task::ready(Ok(None))
|
||||
}
|
||||
|
||||
fn is_completion_trigger(
|
||||
&self,
|
||||
buffer: &Model<Buffer>,
|
||||
position: language::Anchor,
|
||||
_text: &str,
|
||||
_trigger_in_words: bool,
|
||||
cx: &mut ViewContext<Editor>,
|
||||
) -> bool {
|
||||
let buffer = buffer.read(cx);
|
||||
let position = position.to_point(buffer);
|
||||
let line_start = Point::new(position.row, 0);
|
||||
let mut lines = buffer.text_for_range(line_start..position).lines();
|
||||
if let Some(line) = lines.next() {
|
||||
SlashCommandLine::parse(line).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommandLine {
|
||||
pub(crate) fn parse(line: &str) -> Option<Self> {
|
||||
let mut call: Option<Self> = None;
|
||||
let mut ix = 0;
|
||||
for c in line.chars() {
|
||||
let next_ix = ix + c.len_utf8();
|
||||
if let Some(call) = &mut call {
|
||||
// The command arguments start at the first non-whitespace character
|
||||
// after the command name, and continue until the end of the line.
|
||||
if let Some(argument) = &mut call.argument {
|
||||
if (*argument).is_empty() && c.is_whitespace() {
|
||||
argument.start = next_ix;
|
||||
}
|
||||
argument.end = next_ix;
|
||||
}
|
||||
// The command name ends at the first whitespace character.
|
||||
else if !call.name.is_empty() {
|
||||
if c.is_whitespace() {
|
||||
call.argument = Some(next_ix..next_ix);
|
||||
} else {
|
||||
call.name.end = next_ix;
|
||||
}
|
||||
}
|
||||
// The command name must begin with a letter.
|
||||
else if c.is_alphabetic() {
|
||||
call.name.end = next_ix;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
// Commands start with a slash.
|
||||
else if c == '/' {
|
||||
call = Some(SlashCommandLine {
|
||||
name: next_ix..next_ix,
|
||||
argument: None,
|
||||
});
|
||||
}
|
||||
// The line can't contain anything before the slash except for whitespace.
|
||||
else if !c.is_whitespace() {
|
||||
return None;
|
||||
}
|
||||
ix = next_ix;
|
||||
}
|
||||
call
|
||||
}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
use std::{borrow::Cow, cell::Cell, rc::Rc};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use editor::Editor;
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext, Entity, Subscription, Task, WindowHandle};
|
||||
use workspace::{Event as WorkspaceEvent, Workspace};
|
||||
|
||||
use super::{SlashCommand, SlashCommandCleanup, SlashCommandInvocation};
|
||||
|
||||
pub(crate) struct CurrentFileSlashCommand {
|
||||
workspace: WindowHandle<Workspace>,
|
||||
}
|
||||
|
||||
impl CurrentFileSlashCommand {
|
||||
pub fn new(workspace: WindowHandle<Workspace>) -> Self {
|
||||
Self { workspace }
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommand for CurrentFileSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"current_file".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert the current file".into()
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
&self,
|
||||
_query: String,
|
||||
_cancel: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||
_cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<String>>> {
|
||||
Task::ready(Err(anyhow!("this command does not require argument")))
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn run(&self, _argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation {
|
||||
let (invalidate_tx, invalidate_rx) = oneshot::channel();
|
||||
let invalidate_tx = Rc::new(Cell::new(Some(invalidate_tx)));
|
||||
let mut subscriptions: Vec<Subscription> = Vec::new();
|
||||
let output = self.workspace.update(cx, |workspace, cx| {
|
||||
let mut timestamps_by_entity_id = HashMap::default();
|
||||
for pane in workspace.panes() {
|
||||
let pane = pane.read(cx);
|
||||
for entry in pane.activation_history() {
|
||||
timestamps_by_entity_id.insert(entry.entity_id, entry.timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
let mut most_recent_buffer = None;
|
||||
for editor in workspace.items_of_type::<Editor>(cx) {
|
||||
let Some(buffer) = editor.read(cx).buffer().read(cx).as_singleton() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let timestamp = timestamps_by_entity_id
|
||||
.get(&editor.entity_id())
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
if most_recent_buffer
|
||||
.as_ref()
|
||||
.map_or(true, |(_, prev_timestamp)| timestamp > *prev_timestamp)
|
||||
{
|
||||
most_recent_buffer = Some((buffer, timestamp));
|
||||
}
|
||||
}
|
||||
|
||||
subscriptions.push({
|
||||
let workspace_view = cx.view().clone();
|
||||
let invalidate_tx = invalidate_tx.clone();
|
||||
cx.window_context()
|
||||
.subscribe(&workspace_view, move |_workspace, event, _cx| match event {
|
||||
WorkspaceEvent::ActiveItemChanged
|
||||
| WorkspaceEvent::ItemAdded
|
||||
| WorkspaceEvent::ItemRemoved
|
||||
| WorkspaceEvent::PaneAdded(_)
|
||||
| WorkspaceEvent::PaneRemoved => {
|
||||
if let Some(invalidate_tx) = invalidate_tx.take() {
|
||||
_ = invalidate_tx.send(());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
});
|
||||
|
||||
if let Some((buffer, _)) = most_recent_buffer {
|
||||
subscriptions.push({
|
||||
let invalidate_tx = invalidate_tx.clone();
|
||||
cx.window_context().observe(&buffer, move |_buffer, _cx| {
|
||||
if let Some(invalidate_tx) = invalidate_tx.take() {
|
||||
_ = invalidate_tx.send(());
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let path = snapshot.resolve_file_path(cx, true);
|
||||
cx.background_executor().spawn(async move {
|
||||
let path = path
|
||||
.as_ref()
|
||||
.map(|path| path.to_string_lossy())
|
||||
.unwrap_or_else(|| Cow::Borrowed("untitled"));
|
||||
|
||||
let mut output = String::with_capacity(path.len() + snapshot.len() + 9);
|
||||
output.push_str("```");
|
||||
output.push_str(&path);
|
||||
output.push('\n');
|
||||
for chunk in snapshot.as_rope().chunks() {
|
||||
output.push_str(chunk);
|
||||
}
|
||||
if !output.ends_with('\n') {
|
||||
output.push('\n');
|
||||
}
|
||||
output.push_str("```");
|
||||
Ok(output)
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("no recent buffer found")))
|
||||
}
|
||||
});
|
||||
|
||||
SlashCommandInvocation {
|
||||
output: output.unwrap_or_else(|error| Task::ready(Err(error))),
|
||||
invalidated: invalidate_rx,
|
||||
cleanup: SlashCommandCleanup::new(move || drop(subscriptions)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
use super::{SlashCommand, SlashCommandCleanup, SlashCommandInvocation};
|
||||
use anyhow::Result;
|
||||
use futures::channel::oneshot;
|
||||
use fuzzy::PathMatch;
|
||||
use gpui::{AppContext, Model, Task};
|
||||
use project::{PathMatchCandidateSet, Project};
|
||||
use std::{
|
||||
path::Path,
|
||||
sync::{atomic::AtomicBool, Arc},
|
||||
};
|
||||
|
||||
pub(crate) struct FileSlashCommand {
|
||||
project: Model<Project>,
|
||||
}
|
||||
|
||||
impl FileSlashCommand {
|
||||
pub fn new(project: Model<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
|
||||
fn search_paths(
|
||||
&self,
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Vec<PathMatch>> {
|
||||
let worktrees = self
|
||||
.project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.collect::<Vec<_>>();
|
||||
let include_root_name = worktrees.len() > 1;
|
||||
let candidate_sets = worktrees
|
||||
.into_iter()
|
||||
.map(|worktree| {
|
||||
let worktree = worktree.read(cx);
|
||||
PathMatchCandidateSet {
|
||||
snapshot: worktree.snapshot(),
|
||||
include_ignored: worktree
|
||||
.root_entry()
|
||||
.map_or(false, |entry| entry.is_ignored),
|
||||
include_root_name,
|
||||
directories_only: false,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let executor = cx.background_executor().clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
fuzzy::match_path_sets(
|
||||
candidate_sets.as_slice(),
|
||||
query.as_str(),
|
||||
None,
|
||||
false,
|
||||
100,
|
||||
&cancellation_flag,
|
||||
executor,
|
||||
)
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommand for FileSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"file".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert an entire file".into()
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
&self,
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
cx: &mut AppContext,
|
||||
) -> gpui::Task<Result<Vec<String>>> {
|
||||
let paths = self.search_paths(query, cancellation_flag, cx);
|
||||
cx.background_executor().spawn(async move {
|
||||
Ok(paths
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|path_match| {
|
||||
format!(
|
||||
"{}{}",
|
||||
path_match.path_prefix,
|
||||
path_match.path.to_string_lossy()
|
||||
)
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
fn run(&self, argument: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation {
|
||||
let project = self.project.read(cx);
|
||||
let Some(argument) = argument else {
|
||||
return SlashCommandInvocation {
|
||||
output: Task::ready(Err(anyhow::anyhow!("missing path"))),
|
||||
invalidated: oneshot::channel().1,
|
||||
cleanup: SlashCommandCleanup::default(),
|
||||
};
|
||||
};
|
||||
|
||||
let path = Path::new(argument);
|
||||
let abs_path = project.worktrees().find_map(|worktree| {
|
||||
let worktree = worktree.read(cx);
|
||||
worktree.entry_for_path(path)?;
|
||||
worktree.absolutize(path).ok()
|
||||
});
|
||||
|
||||
let Some(abs_path) = abs_path else {
|
||||
return SlashCommandInvocation {
|
||||
output: Task::ready(Err(anyhow::anyhow!("missing path"))),
|
||||
invalidated: oneshot::channel().1,
|
||||
cleanup: SlashCommandCleanup::default(),
|
||||
};
|
||||
};
|
||||
|
||||
let fs = project.fs().clone();
|
||||
let argument = argument.to_string();
|
||||
let output = cx.background_executor().spawn(async move {
|
||||
let content = fs.load(&abs_path).await?;
|
||||
let mut output = String::with_capacity(argument.len() + content.len() + 9);
|
||||
output.push_str("```");
|
||||
output.push_str(&argument);
|
||||
output.push('\n');
|
||||
output.push_str(&content);
|
||||
if !output.ends_with('\n') {
|
||||
output.push('\n');
|
||||
}
|
||||
output.push_str("```");
|
||||
Ok(output)
|
||||
});
|
||||
SlashCommandInvocation {
|
||||
output,
|
||||
invalidated: oneshot::channel().1,
|
||||
cleanup: SlashCommandCleanup::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
use super::{SlashCommand, SlashCommandCleanup, SlashCommandInvocation};
|
||||
use crate::prompts::prompt_library::PromptLibrary;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use futures::channel::oneshot;
|
||||
use fuzzy::StringMatchCandidate;
|
||||
use gpui::{AppContext, Task};
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
pub(crate) struct PromptSlashCommand {
|
||||
library: Arc<PromptLibrary>,
|
||||
}
|
||||
|
||||
impl PromptSlashCommand {
|
||||
pub fn new(library: Arc<PromptLibrary>) -> Self {
|
||||
Self { library }
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommand for PromptSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
"prompt".into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert a prompt from the library".into()
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
&self,
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
cx: &mut AppContext,
|
||||
) -> Task<Result<Vec<String>>> {
|
||||
let library = self.library.clone();
|
||||
let executor = cx.background_executor().clone();
|
||||
cx.background_executor().spawn(async move {
|
||||
let candidates = library
|
||||
.prompts()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(ix, prompt)| {
|
||||
prompt
|
||||
.1
|
||||
.title()
|
||||
.map(|title| StringMatchCandidate::new(ix, title.into()))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let matches = fuzzy::match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
false,
|
||||
100,
|
||||
&cancellation_flag,
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
Ok(matches
|
||||
.into_iter()
|
||||
.map(|mat| candidates[mat.candidate_id].string.clone())
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
fn run(&self, title: Option<&str>, cx: &mut AppContext) -> SlashCommandInvocation {
|
||||
let Some(title) = title else {
|
||||
return SlashCommandInvocation {
|
||||
output: Task::ready(Err(anyhow!("missing prompt name"))),
|
||||
invalidated: oneshot::channel().1,
|
||||
cleanup: SlashCommandCleanup::default(),
|
||||
};
|
||||
};
|
||||
|
||||
let library = self.library.clone();
|
||||
let title = title.to_string();
|
||||
let output = cx.background_executor().spawn(async move {
|
||||
let prompt = library
|
||||
.prompts()
|
||||
.into_iter()
|
||||
.filter_map(|prompt| prompt.1.title().map(|title| (title, prompt)))
|
||||
.find(|(t, _)| t == &title)
|
||||
.with_context(|| format!("no prompt found with title {:?}", title))?
|
||||
.1;
|
||||
Ok(prompt.1.content().to_owned())
|
||||
});
|
||||
SlashCommandInvocation {
|
||||
output,
|
||||
invalidated: oneshot::channel().1,
|
||||
cleanup: SlashCommandCleanup::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
When the user asks you to suggest edits for a buffer, use a strict template consisting of:
|
||||
|
||||
* A markdown code block with the file path as the language identifier.
|
||||
* The original code that should be replaced
|
||||
* A separator line (`---`)
|
||||
* The new text that should replace the original lines
|
||||
|
||||
Each code block may only contain an edit for one single contiguous range of text. Use multiple code blocks for multiple edits.
|
||||
|
||||
## Example
|
||||
|
||||
If you have a buffer with the following lines:
|
||||
|
||||
```path/to/file.rs
|
||||
fn quicksort(arr: &mut [i32]) {
|
||||
if arr.len() <= 1 {
|
||||
return;
|
||||
}
|
||||
let pivot_index = partition(arr);
|
||||
let (left, right) = arr.split_at_mut(pivot_index);
|
||||
quicksort(left);
|
||||
quicksort(&mut right[1..]);
|
||||
}
|
||||
|
||||
fn partition(arr: &mut [i32]) -> usize {
|
||||
let last_index = arr.len() - 1;
|
||||
let pivot = arr[last_index];
|
||||
let mut i = 0;
|
||||
for j in 0..last_index {
|
||||
if arr[j] <= pivot {
|
||||
arr.swap(i, j);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
arr.swap(i, last_index);
|
||||
i
|
||||
}
|
||||
```
|
||||
|
||||
And you want to replace the for loop inside `partition`, output the following.
|
||||
|
||||
```edit path/to/file.rs
|
||||
for j in 0..last_index {
|
||||
if arr[j] <= pivot {
|
||||
arr.swap(i, j);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
---
|
||||
let mut j = 0;
|
||||
while j < last_index {
|
||||
if arr[j] <= pivot {
|
||||
arr.swap(i, j);
|
||||
i += 1;
|
||||
}
|
||||
j += 1;
|
||||
}
|
||||
```
|
||||
|
||||
If you wanted to insert comments above the partition function, output the following:
|
||||
|
||||
```edit path/to/file.rs
|
||||
fn partition(arr: &mut [i32]) -> usize {
|
||||
---
|
||||
// A helper function used for quicksort.
|
||||
fn partition(arr: &mut [i32]) -> usize {
|
||||
```
|
||||
|
||||
If you wanted to delete the partition function, output the following:
|
||||
|
||||
```edit path/to/file.rs
|
||||
fn partition(arr: &mut [i32]) -> usize {
|
||||
let last_index = arr.len() - 1;
|
||||
let pivot = arr[last_index];
|
||||
let mut i = 0;
|
||||
for j in 0..last_index {
|
||||
if arr[j] <= pivot {
|
||||
arr.swap(i, j);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
arr.swap(i, last_index);
|
||||
i
|
||||
}
|
||||
---
|
||||
```
|
||||
@@ -19,22 +19,20 @@ stories = ["dep:story"]
|
||||
anyhow.workspace = true
|
||||
assistant_tooling.workspace = true
|
||||
client.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
file_icons.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
nanoid.workspace = true
|
||||
open_ai.workspace = true
|
||||
picker.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
rich_text.workspace = true
|
||||
schemars.workspace = true
|
||||
semantic_index.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -44,7 +42,6 @@ story = { workspace = true, optional = true }
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
unindent.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -54,7 +51,6 @@ env_logger.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
languages.workspace = true
|
||||
markdown = { workspace = true, features = ["test-support"] }
|
||||
node_runtime.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
@@ -62,5 +58,4 @@ release_channel.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
theme = { workspace = true, features = ["test-support"] }
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
http = { workspace = true, features = ["test-support"] }
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -2,41 +2,41 @@ mod assistant_settings;
|
||||
mod attachments;
|
||||
mod completion_provider;
|
||||
mod saved_conversation;
|
||||
mod saved_conversations;
|
||||
mod saved_conversation_picker;
|
||||
mod tools;
|
||||
pub mod ui;
|
||||
|
||||
use crate::saved_conversation::SavedConversationMetadata;
|
||||
use crate::ui::UserOrAssistant;
|
||||
use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
|
||||
use crate::saved_conversation_picker::SavedConversationPicker;
|
||||
use crate::{
|
||||
attachments::ActiveEditorAttachmentTool,
|
||||
tools::{CreateBufferTool, ProjectIndexTool},
|
||||
ui::UserOrAssistant,
|
||||
};
|
||||
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
|
||||
use anyhow::{Context, Result};
|
||||
use assistant_tooling::{
|
||||
AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
|
||||
};
|
||||
use attachments::ActiveEditorAttachmentTool;
|
||||
use client::{proto, Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use completion_provider::*;
|
||||
use editor::Editor;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::{future::join_all, StreamExt};
|
||||
use gpui::{
|
||||
list, AnyElement, AppContext, AsyncWindowContext, ClickEvent, EventEmitter, FocusHandle,
|
||||
FocusableView, ListAlignment, ListState, Model, ReadGlobal, Render, Task, UpdateGlobal, View,
|
||||
WeakView,
|
||||
FocusableView, ListAlignment, ListState, Model, Render, Task, View, WeakView,
|
||||
};
|
||||
use language::{language_settings::SoftWrap, LanguageRegistry};
|
||||
use markdown::{Markdown, MarkdownStyle};
|
||||
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
|
||||
use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
|
||||
use saved_conversations::SavedConversations;
|
||||
use rich_text::RichText;
|
||||
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use std::sync::Arc;
|
||||
use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
|
||||
use tools::AnnotationTool;
|
||||
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
|
||||
@@ -63,7 +63,15 @@ pub enum SubmitMode {
|
||||
Codebase,
|
||||
}
|
||||
|
||||
gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex,]);
|
||||
gpui::actions!(
|
||||
assistant2,
|
||||
[
|
||||
Cancel,
|
||||
ToggleFocus,
|
||||
DebugProjectIndex,
|
||||
ToggleSavedConversations
|
||||
]
|
||||
);
|
||||
gpui::impl_actions!(assistant2, [Submit]);
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
@@ -103,6 +111,8 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
cx.observe_new_views(SavedConversationPicker::register)
|
||||
.detach();
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &AppContext) -> bool {
|
||||
@@ -125,25 +135,22 @@ impl AssistantPanel {
|
||||
})?;
|
||||
|
||||
cx.new_view(|cx| {
|
||||
let project_index = SemanticIndex::update_global(cx, |semantic_index, cx| {
|
||||
let project_index = cx.update_global(|semantic_index: &mut SemanticIndex, cx| {
|
||||
semantic_index.project_index(project.clone(), cx)
|
||||
});
|
||||
|
||||
// Used in tools to render file icons
|
||||
cx.observe_global::<FileIcons>(|_, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let mut tool_registry = ToolRegistry::new();
|
||||
tool_registry
|
||||
.register(ProjectIndexTool::new(project_index.clone()))
|
||||
.register(ProjectIndexTool::new(project_index.clone()), cx)
|
||||
.unwrap();
|
||||
tool_registry
|
||||
.register(CreateBufferTool::new(workspace.clone(), project.clone()))
|
||||
.register(
|
||||
CreateBufferTool::new(workspace.clone(), project.clone()),
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
tool_registry
|
||||
.register(AnnotationTool::new(workspace.clone(), project.clone()))
|
||||
.register(AnnotationTool::new(workspace.clone(), project.clone()), cx)
|
||||
.unwrap();
|
||||
|
||||
let mut attachment_registry = AttachmentRegistry::new();
|
||||
@@ -257,8 +264,6 @@ pub struct AssistantChat {
|
||||
fs: Arc<dyn Fs>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
composer_editor: View<Editor>,
|
||||
saved_conversations: View<SavedConversations>,
|
||||
saved_conversations_open: bool,
|
||||
project_index_button: View<ProjectIndexButton>,
|
||||
active_file_button: Option<View<ActiveFileButton>>,
|
||||
user_store: Model<UserStore>,
|
||||
@@ -269,11 +274,11 @@ pub struct AssistantChat {
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
attachment_registry: Arc<AttachmentRegistry>,
|
||||
project_index: Model<ProjectIndex>,
|
||||
markdown_style: MarkdownStyle,
|
||||
}
|
||||
|
||||
struct EditingMessage {
|
||||
id: MessageId,
|
||||
old_body: Arc<str>,
|
||||
body: View<Editor>,
|
||||
}
|
||||
|
||||
@@ -289,7 +294,7 @@ impl AssistantChat {
|
||||
workspace: WeakView<Workspace>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Self {
|
||||
let model = CompletionProvider::global(cx).default_model();
|
||||
let model = CompletionProvider::get(cx).default_model();
|
||||
let view = cx.view().downgrade();
|
||||
let list_state = ListState::new(
|
||||
0,
|
||||
@@ -314,24 +319,6 @@ impl AssistantChat {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let saved_conversations = cx.new_view(|cx| SavedConversations::new(cx));
|
||||
cx.spawn({
|
||||
let fs = fs.clone();
|
||||
let saved_conversations = saved_conversations.downgrade();
|
||||
|_assistant_chat, mut cx| async move {
|
||||
let saved_conversation_metadata = SavedConversationMetadata::list(fs).await?;
|
||||
|
||||
cx.update(|cx| {
|
||||
saved_conversations.update(cx, |this, cx| {
|
||||
this.init(saved_conversation_metadata, cx);
|
||||
})
|
||||
})??;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
model,
|
||||
messages: Vec::new(),
|
||||
@@ -341,8 +328,6 @@ impl AssistantChat {
|
||||
editor.set_placeholder_text("Send a message…", cx);
|
||||
editor
|
||||
}),
|
||||
saved_conversations,
|
||||
saved_conversations_open: false,
|
||||
list_state,
|
||||
user_store,
|
||||
fs,
|
||||
@@ -356,60 +341,30 @@ impl AssistantChat {
|
||||
pending_completion: None,
|
||||
attachment_registry,
|
||||
tool_registry,
|
||||
markdown_style: MarkdownStyle {
|
||||
code_block: gpui::TextStyleRefinement {
|
||||
font_family: Some("Zed Mono".into()),
|
||||
color: Some(cx.theme().colors().editor_foreground),
|
||||
background_color: Some(cx.theme().colors().editor_background),
|
||||
..Default::default()
|
||||
},
|
||||
inline_code: gpui::TextStyleRefinement {
|
||||
font_family: Some("Zed Mono".into()),
|
||||
// @nate: Could we add inline-code specific styles to the theme?
|
||||
color: Some(cx.theme().colors().editor_foreground),
|
||||
background_color: Some(cx.theme().colors().editor_background),
|
||||
..Default::default()
|
||||
},
|
||||
rule_color: Color::Muted.color(cx),
|
||||
block_quote_border_color: Color::Muted.color(cx),
|
||||
block_quote: gpui::TextStyleRefinement {
|
||||
color: Some(Color::Muted.color(cx)),
|
||||
..Default::default()
|
||||
},
|
||||
link: gpui::TextStyleRefinement {
|
||||
color: Some(Color::Accent.color(cx)),
|
||||
underline: Some(gpui::UnderlineStyle {
|
||||
thickness: px(1.),
|
||||
color: Some(Color::Accent.color(cx)),
|
||||
wavy: false,
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
selection_background_color: {
|
||||
let mut selection = cx.theme().players().local().selection;
|
||||
selection.fade_out(0.7);
|
||||
selection
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn message_for_id(&self, id: MessageId) -> Option<&ChatMessage> {
|
||||
self.messages.iter().find(|message| match message {
|
||||
ChatMessage::User(message) => message.id == id,
|
||||
ChatMessage::Assistant(message) => message.id == id,
|
||||
})
|
||||
fn editing_message_id(&self) -> Option<MessageId> {
|
||||
self.editing_message.as_ref().map(|message| message.id)
|
||||
}
|
||||
|
||||
fn toggle_saved_conversations(&mut self) {
|
||||
self.saved_conversations_open = !self.saved_conversations_open;
|
||||
fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
|
||||
self.messages.iter().find_map(|message| match message {
|
||||
ChatMessage::User(message) => message
|
||||
.body
|
||||
.focus_handle(cx)
|
||||
.contains_focused(cx)
|
||||
.then_some(message.id),
|
||||
ChatMessage::Assistant(_) => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
|
||||
// If we're currently editing a message, cancel the edit.
|
||||
if self.editing_message.take().is_some() {
|
||||
cx.notify();
|
||||
if let Some(editing_message) = self.editing_message.take() {
|
||||
editing_message
|
||||
.body
|
||||
.update(cx, |body, cx| body.set_text(editing_message.old_body, cx));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -426,7 +381,14 @@ impl AssistantChat {
|
||||
}
|
||||
|
||||
fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
|
||||
if self.composer_editor.focus_handle(cx).is_focused(cx) {
|
||||
if let Some(focused_message_id) = self.focused_message_id(cx) {
|
||||
self.truncate_messages(focused_message_id, cx);
|
||||
self.pending_completion.take();
|
||||
self.composer_editor.focus_handle(cx).focus(cx);
|
||||
if self.editing_message_id() == Some(focused_message_id) {
|
||||
self.editing_message.take();
|
||||
}
|
||||
} else if self.composer_editor.focus_handle(cx).is_focused(cx) {
|
||||
// Don't allow multiple concurrent completions.
|
||||
if self.pending_completion.is_some() {
|
||||
cx.propagate();
|
||||
@@ -437,12 +399,10 @@ impl AssistantChat {
|
||||
let text = composer_editor.text(cx);
|
||||
let id = self.next_message_id.post_inc();
|
||||
let body = cx.new_view(|cx| {
|
||||
Markdown::new(
|
||||
text,
|
||||
self.markdown_style.clone(),
|
||||
Some(self.language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
let mut editor = Editor::auto_height(80, cx);
|
||||
editor.set_text(text, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
|
||||
editor
|
||||
});
|
||||
composer_editor.clear(cx);
|
||||
|
||||
@@ -453,26 +413,6 @@ impl AssistantChat {
|
||||
})
|
||||
});
|
||||
self.push_message(message, cx);
|
||||
} else if let Some(editing_message) = self.editing_message.as_ref() {
|
||||
let focus_handle = editing_message.body.focus_handle(cx);
|
||||
if focus_handle.contains_focused(cx) {
|
||||
if let Some(ChatMessage::User(user_message)) =
|
||||
self.message_for_id(editing_message.id)
|
||||
{
|
||||
user_message.body.update(cx, |body, cx| {
|
||||
body.reset(editing_message.body.read(cx).text(cx), cx);
|
||||
});
|
||||
}
|
||||
|
||||
self.truncate_messages(editing_message.id, cx);
|
||||
|
||||
self.pending_completion.take();
|
||||
self.composer_editor.focus_handle(cx).focus(cx);
|
||||
self.editing_message.take();
|
||||
} else {
|
||||
log::error!("unexpected state: no user message editor is focused.");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
log::error!("unexpected state: no user message editor is focused.");
|
||||
return;
|
||||
@@ -551,7 +491,7 @@ impl AssistantChat {
|
||||
let messages = messages.await?;
|
||||
|
||||
let completion = cx.update(|cx| {
|
||||
CompletionProvider::global(cx).complete(
|
||||
CompletionProvider::get(cx).complete(
|
||||
model_name,
|
||||
messages,
|
||||
Vec::new(),
|
||||
@@ -561,22 +501,18 @@ impl AssistantChat {
|
||||
});
|
||||
|
||||
let mut stream = completion?.await?;
|
||||
let mut body = String::new();
|
||||
while let Some(delta) = stream.next().await {
|
||||
let delta = delta?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
messages,
|
||||
..
|
||||
})) = this.messages.last_mut()
|
||||
{
|
||||
if messages.is_empty() {
|
||||
messages.push(AssistantMessagePart {
|
||||
body: cx.new_view(|cx| {
|
||||
Markdown::new(
|
||||
"".into(),
|
||||
this.markdown_style.clone(),
|
||||
Some(this.language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
}
|
||||
@@ -584,37 +520,35 @@ impl AssistantChat {
|
||||
let message = messages.last_mut().unwrap();
|
||||
|
||||
if let Some(content) = &delta.content {
|
||||
message
|
||||
.body
|
||||
.update(cx, |message, cx| message.append(&content, cx));
|
||||
body.push_str(content);
|
||||
}
|
||||
|
||||
for tool_call_delta in delta.tool_calls {
|
||||
let index = tool_call_delta.index as usize;
|
||||
for tool_call in delta.tool_calls {
|
||||
let index = tool_call.index as usize;
|
||||
if index >= message.tool_calls.len() {
|
||||
message.tool_calls.resize_with(index + 1, Default::default);
|
||||
}
|
||||
let tool_call = &mut message.tool_calls[index];
|
||||
let call = &mut message.tool_calls[index];
|
||||
|
||||
if let Some(id) = &tool_call_delta.id {
|
||||
tool_call.id.push_str(id);
|
||||
if let Some(id) = &tool_call.id {
|
||||
call.id.push_str(id);
|
||||
}
|
||||
|
||||
match tool_call_delta.variant {
|
||||
Some(proto::tool_call_delta::Variant::Function(
|
||||
tool_call_delta,
|
||||
)) => {
|
||||
this.tool_registry.update_tool_call(
|
||||
tool_call,
|
||||
tool_call_delta.name.as_deref(),
|
||||
tool_call_delta.arguments.as_deref(),
|
||||
cx,
|
||||
);
|
||||
match tool_call.variant {
|
||||
Some(proto::tool_call_delta::Variant::Function(tool_call)) => {
|
||||
if let Some(name) = &tool_call.name {
|
||||
call.name.push_str(name);
|
||||
}
|
||||
if let Some(arguments) = &tool_call.arguments {
|
||||
call.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
message.body =
|
||||
RichText::new(body.clone(), &[], &this.language_registry);
|
||||
cx.notify();
|
||||
} else {
|
||||
unreachable!()
|
||||
@@ -628,7 +562,7 @@ impl AssistantChat {
|
||||
|
||||
let mut tool_tasks = Vec::new();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
error: message_error,
|
||||
messages,
|
||||
..
|
||||
@@ -639,54 +573,54 @@ impl AssistantChat {
|
||||
cx.notify();
|
||||
} else {
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
for tool_call in current_message.tool_calls.iter_mut() {
|
||||
tool_tasks
|
||||
.extend(this.tool_registry.execute_tool_call(tool_call, cx));
|
||||
for tool_call in current_message.tool_calls.iter() {
|
||||
tool_tasks.push(this.tool_registry.call(tool_call, cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
// This ends recursion on calling for responses after tools
|
||||
if tool_tasks.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
join_all(tool_tasks.into_iter()).await;
|
||||
let tools = join_all(tool_tasks.into_iter()).await;
|
||||
// If the WindowContext went away for any tool's view we don't include it
|
||||
// especially since the below call would fail for the same reason.
|
||||
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
this.messages.last_mut()
|
||||
{
|
||||
if let Some(current_message) = messages.last_mut() {
|
||||
current_message.tool_calls = tools;
|
||||
cx.notify();
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
|
||||
// If the last message is a grouped assistant message, add to the grouped message
|
||||
if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
|
||||
if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
|
||||
self.messages.last_mut()
|
||||
{
|
||||
messages.push(AssistantMessagePart {
|
||||
body: cx.new_view(|cx| {
|
||||
Markdown::new(
|
||||
"".into(),
|
||||
self.markdown_style.clone(),
|
||||
Some(self.language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
messages.push(AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let message = ChatMessage::Assistant(AssistantMessage {
|
||||
let message = ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id: self.next_message_id.post_inc(),
|
||||
messages: vec![AssistantMessagePart {
|
||||
body: cx.new_view(|cx| {
|
||||
Markdown::new(
|
||||
"".into(),
|
||||
self.markdown_style.clone(),
|
||||
Some(self.language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
messages: vec![AssistantMessage {
|
||||
body: RichText::default(),
|
||||
tool_calls: Vec::new(),
|
||||
}],
|
||||
error: None,
|
||||
@@ -734,30 +668,40 @@ impl AssistantChat {
|
||||
*entry = !*entry;
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.messages.clear();
|
||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let messages = self
|
||||
.messages
|
||||
.drain(..)
|
||||
.map(|message| {
|
||||
let text = match &message {
|
||||
ChatMessage::User(message) => message.body.read(cx).text(cx),
|
||||
ChatMessage::Assistant(message) => message
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| message.body.text.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n"),
|
||||
};
|
||||
|
||||
SavedMessage {
|
||||
id: message.id(),
|
||||
role: match message {
|
||||
ChatMessage::User(_) => SavedMessageRole::User,
|
||||
ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
|
||||
},
|
||||
text,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Reset the chat for the new conversation.
|
||||
self.list_state.reset(0);
|
||||
self.editing_message.take();
|
||||
self.collapsed_messages.clear();
|
||||
}
|
||||
|
||||
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
|
||||
let messages = std::mem::take(&mut self.messages)
|
||||
.into_iter()
|
||||
.map(|message| self.serialize_message(message, cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.reset();
|
||||
|
||||
let title = messages
|
||||
.first()
|
||||
.map(|message| match message {
|
||||
SavedChatMessage::User { body, .. } => body.clone(),
|
||||
SavedChatMessage::Assistant { messages, .. } => messages
|
||||
.first()
|
||||
.map(|message| message.body.to_string())
|
||||
.unwrap_or_default(),
|
||||
})
|
||||
.map(|message| message.text.clone())
|
||||
.unwrap_or_else(|| "A conversation with the assistant.".to_string());
|
||||
|
||||
let saved_conversation = SavedConversation {
|
||||
@@ -829,72 +773,69 @@ impl AssistantChat {
|
||||
.id(SharedString::from(format!("message-{}-container", id.0)))
|
||||
.when(is_first, |this| this.pt(padding))
|
||||
.map(|element| {
|
||||
if let Some(editing_message) = self.editing_message.as_ref() {
|
||||
if editing_message.id == *id {
|
||||
return element.child(Composer::new(
|
||||
editing_message.body.clone(),
|
||||
self.project_index_button.clone(),
|
||||
self.active_file_button.clone(),
|
||||
crate::ui::ModelSelector::new(
|
||||
cx.view().downgrade(),
|
||||
self.model.clone(),
|
||||
)
|
||||
.into_any_element(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
element
|
||||
.on_click(cx.listener({
|
||||
let id = *id;
|
||||
let body = body.clone();
|
||||
move |assistant_chat, event: &ClickEvent, cx| {
|
||||
if event.up.click_count == 2 {
|
||||
let body = cx.new_view(|cx| {
|
||||
let mut editor = Editor::auto_height(80, cx);
|
||||
let source = Arc::from(body.read(cx).source());
|
||||
editor.set_text(source, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
|
||||
editor
|
||||
});
|
||||
assistant_chat.editing_message = Some(EditingMessage {
|
||||
id,
|
||||
body: body.clone(),
|
||||
});
|
||||
body.focus_handle(cx).focus(cx);
|
||||
}
|
||||
}
|
||||
}))
|
||||
.child(
|
||||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::User(self.user_store.read(cx).current_user()),
|
||||
// todo!(): clean up the vec usage
|
||||
vec![
|
||||
body.clone().into_any_element(),
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.children(
|
||||
attachments
|
||||
.iter()
|
||||
.map(|attachment| attachment.view.clone()),
|
||||
)
|
||||
.into_any_element(),
|
||||
],
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
move |assistant_chat, _event, _cx| {
|
||||
assistant_chat.toggle_message_collapsed(id)
|
||||
}
|
||||
})),
|
||||
if self.editing_message_id() == Some(*id) {
|
||||
element.child(Composer::new(
|
||||
body.clone(),
|
||||
self.project_index_button.clone(),
|
||||
self.active_file_button.clone(),
|
||||
crate::ui::ModelSelector::new(
|
||||
cx.view().downgrade(),
|
||||
self.model.clone(),
|
||||
)
|
||||
// TODO: Wire up selections.
|
||||
.selected(is_last),
|
||||
)
|
||||
.into_any_element(),
|
||||
))
|
||||
} else {
|
||||
element
|
||||
.on_click(cx.listener({
|
||||
let id = *id;
|
||||
let body = body.clone();
|
||||
move |assistant_chat, event: &ClickEvent, cx| {
|
||||
if event.up.click_count == 2 {
|
||||
assistant_chat.editing_message = Some(EditingMessage {
|
||||
id,
|
||||
body: body.clone(),
|
||||
old_body: body.read(cx).text(cx).into(),
|
||||
});
|
||||
body.focus_handle(cx).focus(cx);
|
||||
}
|
||||
}
|
||||
}))
|
||||
.child(
|
||||
crate::ui::ChatMessage::new(
|
||||
*id,
|
||||
UserOrAssistant::User(self.user_store.read(cx).current_user()),
|
||||
// todo!(): clean up the vec usage
|
||||
vec![
|
||||
RichText::new(
|
||||
body.read(cx).text(cx),
|
||||
&[],
|
||||
&self.language_registry,
|
||||
)
|
||||
.element(ElementId::from(id.0), cx),
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.children(
|
||||
attachments
|
||||
.iter()
|
||||
.map(|attachment| attachment.view.clone()),
|
||||
)
|
||||
.into_any_element(),
|
||||
],
|
||||
self.is_message_collapsed(id),
|
||||
Box::new(cx.listener({
|
||||
let id = *id;
|
||||
move |assistant_chat, _event, _cx| {
|
||||
assistant_chat.toggle_message_collapsed(id)
|
||||
}
|
||||
})),
|
||||
)
|
||||
// TODO: Wire up selections.
|
||||
.selected(is_last),
|
||||
)
|
||||
}
|
||||
})
|
||||
.into_any(),
|
||||
ChatMessage::Assistant(AssistantMessage {
|
||||
ChatMessage::Assistant(GroupedAssistantMessage {
|
||||
id,
|
||||
messages,
|
||||
error,
|
||||
@@ -903,25 +844,26 @@ impl AssistantChat {
|
||||
let mut message_elements = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
if !message.body.read(cx).source().is_empty() {
|
||||
message_elements.push(div().child(message.body.clone()).into_any())
|
||||
if !message.body.text.is_empty() {
|
||||
message_elements.push(
|
||||
div()
|
||||
// todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
|
||||
.child(message.body.element(ElementId::from(id.0), cx))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
let tools = message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
|
||||
.collect::<Vec<AnyElement>>();
|
||||
|
||||
if !tools.is_empty() {
|
||||
message_elements.push(div().children(tools).into_any())
|
||||
message_elements.push(div().children(tools).into_any_element())
|
||||
}
|
||||
}
|
||||
|
||||
if message_elements.is_empty() {
|
||||
message_elements.push(::ui::Label::new("Researching...").into_any_element())
|
||||
}
|
||||
|
||||
div()
|
||||
.when(is_first, |this| this.pt(padding))
|
||||
.child(
|
||||
@@ -967,14 +909,14 @@ impl AssistantChat {
|
||||
|
||||
// Show user's message last so that the assistant is grounded in the user's request
|
||||
completion_messages.push(CompletionMessage::User {
|
||||
content: body.read(cx).source().to_string(),
|
||||
content: body.read(cx).text(cx),
|
||||
});
|
||||
}
|
||||
ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
|
||||
ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
|
||||
for message in messages {
|
||||
let body = message.body.clone();
|
||||
|
||||
if body.read(cx).source().is_empty() && message.tool_calls.is_empty() {
|
||||
if body.text.is_empty() && message.tool_calls.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -993,17 +935,19 @@ impl AssistantChat {
|
||||
.collect();
|
||||
|
||||
completion_messages.push(CompletionMessage::Assistant {
|
||||
content: Some(body.read(cx).source().to_string()),
|
||||
content: Some(body.text.to_string()),
|
||||
tool_calls: tool_calls_from_assistant,
|
||||
});
|
||||
|
||||
for tool_call in &message.tool_calls {
|
||||
// Every tool call _must_ have a result by ID, otherwise OpenAI will error.
|
||||
let content = self.tool_registry.content_for_tool_call(
|
||||
tool_call,
|
||||
&mut project_context,
|
||||
cx,
|
||||
);
|
||||
let content = match &tool_call.result {
|
||||
Some(result) => {
|
||||
result.generate(&tool_call.name, &mut project_context, cx)
|
||||
}
|
||||
None => "".to_string(),
|
||||
};
|
||||
|
||||
completion_messages.push(CompletionMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
@@ -1022,53 +966,10 @@ impl AssistantChat {
|
||||
Ok(completion_messages)
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize_message(
|
||||
&self,
|
||||
message: ChatMessage,
|
||||
cx: &mut ViewContext<AssistantChat>,
|
||||
) -> SavedChatMessage {
|
||||
match message {
|
||||
ChatMessage::User(message) => SavedChatMessage::User {
|
||||
id: message.id,
|
||||
body: message.body.read(cx).source().into(),
|
||||
attachments: message
|
||||
.attachments
|
||||
.iter()
|
||||
.map(|attachment| {
|
||||
self.attachment_registry
|
||||
.serialize_user_attachment(attachment)
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
|
||||
id: message.id,
|
||||
error: message.error,
|
||||
messages: message
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| SavedAssistantMessagePart {
|
||||
body: message.body.read(cx).source().to_string().into(),
|
||||
tool_calls: message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.filter_map(|tool_call| {
|
||||
self.tool_registry
|
||||
.serialize_tool_call(tool_call, cx)
|
||||
.log_err()
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantChat {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let header_height = Spacing::Small.rems(cx) * 2.0 + ButtonSize::Default.rems();
|
||||
|
||||
div()
|
||||
.relative()
|
||||
.flex_1()
|
||||
@@ -1077,60 +978,23 @@ impl Render for AssistantChat {
|
||||
.on_action(cx.listener(Self::submit))
|
||||
.on_action(cx.listener(Self::cancel))
|
||||
.text_color(Color::Default.color(cx))
|
||||
.child(list(self.list_state.clone()).flex_1().pt(header_height))
|
||||
.child(
|
||||
h_flex()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.h(header_height)
|
||||
.p(Spacing::Small.rems(cx))
|
||||
.gap_2()
|
||||
.child(
|
||||
IconButton::new(
|
||||
"toggle-saved-conversations",
|
||||
if self.saved_conversations_open {
|
||||
IconName::ChevronRight
|
||||
} else {
|
||||
IconName::ChevronLeft
|
||||
},
|
||||
)
|
||||
.on_click(cx.listener(|this, _event, _cx| {
|
||||
this.toggle_saved_conversations();
|
||||
}))
|
||||
.tooltip(move |cx| Tooltip::text("Switch Conversations", cx)),
|
||||
Button::new("open-saved-conversations", "Saved Conversations").on_click(
|
||||
|_event, cx| cx.dispatch_action(Box::new(ToggleSavedConversations)),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap(Spacing::Large.rems(cx))
|
||||
.child(
|
||||
IconButton::new("new-conversation", IconName::Plus)
|
||||
.on_click(cx.listener(move |this, _event, cx| {
|
||||
this.new_conversation(cx);
|
||||
}))
|
||||
.tooltip(move |cx| Tooltip::text("New Context", cx)),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("assistant-menu", IconName::Menu)
|
||||
.disabled(true)
|
||||
.tooltip(move |cx| {
|
||||
Tooltip::text(
|
||||
"Coming soon – Assistant settings & controls",
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
),
|
||||
IconButton::new("new-conversation", IconName::Plus)
|
||||
.on_click(cx.listener(move |this, _event, cx| {
|
||||
this.new_conversation(cx);
|
||||
}))
|
||||
.tooltip(move |cx| Tooltip::text("New Conversation", cx)),
|
||||
),
|
||||
)
|
||||
.when(self.saved_conversations_open, |element| {
|
||||
element.child(
|
||||
h_flex()
|
||||
.absolute()
|
||||
.top(header_height)
|
||||
.w_full()
|
||||
.child(self.saved_conversations.clone()),
|
||||
)
|
||||
})
|
||||
.child(list(self.list_state.clone()).flex_1())
|
||||
.child(Composer::new(
|
||||
self.composer_editor.clone(),
|
||||
self.project_index_button.clone(),
|
||||
@@ -1154,31 +1018,38 @@ impl MessageId {
|
||||
|
||||
enum ChatMessage {
|
||||
User(UserMessage),
|
||||
Assistant(AssistantMessage),
|
||||
Assistant(GroupedAssistantMessage),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn id(&self) -> MessageId {
|
||||
match self {
|
||||
ChatMessage::User(message) => message.id,
|
||||
ChatMessage::Assistant(message) => message.id,
|
||||
}
|
||||
}
|
||||
|
||||
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
|
||||
match self {
|
||||
ChatMessage::User(message) => Some(message.body.focus_handle(cx)),
|
||||
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
|
||||
ChatMessage::Assistant(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct UserMessage {
|
||||
pub id: MessageId,
|
||||
pub body: View<Markdown>,
|
||||
pub attachments: Vec<UserAttachment>,
|
||||
}
|
||||
|
||||
struct AssistantMessagePart {
|
||||
pub body: View<Markdown>,
|
||||
pub tool_calls: Vec<ToolFunctionCall>,
|
||||
id: MessageId,
|
||||
body: View<Editor>,
|
||||
attachments: Vec<UserAttachment>,
|
||||
}
|
||||
|
||||
struct AssistantMessage {
|
||||
pub id: MessageId,
|
||||
pub messages: Vec<AssistantMessagePart>,
|
||||
pub error: Option<SharedString>,
|
||||
body: RichText,
|
||||
tool_calls: Vec<ToolFunctionCall>,
|
||||
}
|
||||
|
||||
struct GroupedAssistantMessage {
|
||||
id: MessageId,
|
||||
messages: Vec<AssistantMessage>,
|
||||
error: Option<SharedString>,
|
||||
}
|
||||
|
||||
@@ -1,68 +1,64 @@
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tooling::{AttachmentOutput, LanguageModelAttachment, ProjectContext};
|
||||
use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
|
||||
use editor::Editor;
|
||||
use gpui::{Render, Task, View, WeakModel, WeakView};
|
||||
use language::Buffer;
|
||||
use project::ProjectPath;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
|
||||
use util::maybe;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ActiveEditorAttachment {
|
||||
#[serde(skip)]
|
||||
buffer: Option<WeakModel<Buffer>>,
|
||||
path: Option<PathBuf>,
|
||||
buffer: WeakModel<Buffer>,
|
||||
path: Option<ProjectPath>,
|
||||
}
|
||||
|
||||
pub struct FileAttachmentView {
|
||||
project_path: Option<ProjectPath>,
|
||||
buffer: Option<WeakModel<Buffer>>,
|
||||
error: Option<anyhow::Error>,
|
||||
output: Result<ActiveEditorAttachment>,
|
||||
}
|
||||
|
||||
impl Render for FileAttachmentView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
if let Some(error) = &self.error {
|
||||
return div().child(error.to_string()).into_any_element();
|
||||
match &self.output {
|
||||
Ok(attachment) => {
|
||||
let filename: SharedString = attachment
|
||||
.path
|
||||
.as_ref()
|
||||
.and_then(|p| p.path.file_name()?.to_str())
|
||||
.unwrap_or("Untitled")
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
// todo!(): make the button link to the actual file to open
|
||||
ButtonLike::new("file-attachment")
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(ui::Icon::new(IconName::File))
|
||||
.child(filename.clone()),
|
||||
)
|
||||
.tooltip({
|
||||
move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
Err(err) => div().child(err.to_string()).into_any_element(),
|
||||
}
|
||||
|
||||
let filename: SharedString = self
|
||||
.project_path
|
||||
.as_ref()
|
||||
.and_then(|p| p.path.file_name()?.to_str())
|
||||
.unwrap_or("Untitled")
|
||||
.to_string()
|
||||
.into();
|
||||
|
||||
ButtonLike::new("file-attachment")
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_md()
|
||||
.child(ui::Icon::new(IconName::File))
|
||||
.child(filename.clone()),
|
||||
)
|
||||
.tooltip(move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx))
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
impl AttachmentOutput for FileAttachmentView {
|
||||
impl ToolOutput for FileAttachmentView {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||
if let Some(path) = &self.project_path {
|
||||
project.add_file(path.clone());
|
||||
return format!("current file: {}", path.path.display());
|
||||
if let Ok(result) = &self.output {
|
||||
if let Some(path) = &result.path {
|
||||
project.add_file(path.clone());
|
||||
return format!("current file: {}", path.path.display());
|
||||
} else if let Some(buffer) = result.buffer.upgrade() {
|
||||
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(buffer) = self.buffer.as_ref().and_then(|buffer| buffer.upgrade()) {
|
||||
return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
@@ -81,10 +77,6 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
type Output = ActiveEditorAttachment;
|
||||
type View = FileAttachmentView;
|
||||
|
||||
fn name(&self) -> Arc<str> {
|
||||
"active-editor-attachment".into()
|
||||
}
|
||||
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<ActiveEditorAttachment>> {
|
||||
Task::ready(maybe!({
|
||||
let active_buffer = self
|
||||
@@ -99,10 +91,13 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
let buffer = active_buffer.read(cx);
|
||||
|
||||
if let Some(buffer) = buffer.as_singleton() {
|
||||
let path = project::File::from_dyn(buffer.read(cx).file())
|
||||
.and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok());
|
||||
let path =
|
||||
project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
|
||||
worktree_id: file.worktree_id(cx),
|
||||
path: file.path.clone(),
|
||||
});
|
||||
return Ok(ActiveEditorAttachment {
|
||||
buffer: Some(buffer.downgrade()),
|
||||
buffer: buffer.downgrade(),
|
||||
path,
|
||||
});
|
||||
} else {
|
||||
@@ -111,34 +106,7 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
|
||||
}))
|
||||
}
|
||||
|
||||
fn view(
|
||||
&self,
|
||||
output: Result<ActiveEditorAttachment>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
let error;
|
||||
let project_path;
|
||||
let buffer;
|
||||
match output {
|
||||
Ok(output) => {
|
||||
error = None;
|
||||
let workspace = self.workspace.upgrade().unwrap();
|
||||
let project = workspace.read(cx).project();
|
||||
project_path = output
|
||||
.path
|
||||
.and_then(|path| project.read(cx).project_path_for_absolute_path(&path, cx));
|
||||
buffer = output.buffer;
|
||||
}
|
||||
Err(err) => {
|
||||
error = Some(err);
|
||||
buffer = None;
|
||||
project_path = None;
|
||||
}
|
||||
}
|
||||
cx.new_view(|_cx| FileAttachmentView {
|
||||
project_path,
|
||||
buffer,
|
||||
error,
|
||||
})
|
||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| FileAttachmentView { output })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use anyhow::Result;
|
||||
use assistant_tooling::ToolFunctionDefinition;
|
||||
use client::{proto, Client};
|
||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||
use gpui::Global;
|
||||
use gpui::{AppContext, Global};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use open_ai::RequestMessage as CompletionMessage;
|
||||
@@ -11,6 +11,10 @@ pub use open_ai::RequestMessage as CompletionMessage;
|
||||
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);
|
||||
|
||||
impl CompletionProvider {
|
||||
pub fn get(cx: &AppContext) -> &Self {
|
||||
cx.global::<CompletionProvider>()
|
||||
}
|
||||
|
||||
pub fn new(backend: impl CompletionProviderBackend) -> Self {
|
||||
Self(Arc::new(backend))
|
||||
}
|
||||
|
||||
@@ -1,16 +1,4 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::ffi::OsStr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{SavedToolFunctionCall, SavedUserAttachment};
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use gpui::SharedString;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
|
||||
use crate::MessageId;
|
||||
|
||||
@@ -20,71 +8,42 @@ pub struct SavedConversation {
|
||||
pub version: String,
|
||||
/// The title of the conversation, generated by the Assistant.
|
||||
pub title: String,
|
||||
pub messages: Vec<SavedChatMessage>,
|
||||
pub messages: Vec<SavedMessage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum SavedChatMessage {
|
||||
User {
|
||||
id: MessageId,
|
||||
body: String,
|
||||
attachments: Vec<SavedUserAttachment>,
|
||||
},
|
||||
Assistant {
|
||||
id: MessageId,
|
||||
messages: Vec<SavedAssistantMessagePart>,
|
||||
error: Option<SharedString>,
|
||||
},
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SavedMessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedAssistantMessagePart {
|
||||
pub body: SharedString,
|
||||
pub tool_calls: Vec<SavedToolFunctionCall>,
|
||||
pub struct SavedMessage {
|
||||
pub id: MessageId,
|
||||
pub role: SavedMessageRole,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub struct SavedConversationMetadata {
|
||||
pub title: String,
|
||||
pub path: PathBuf,
|
||||
pub mtime: chrono::DateTime<chrono::Local>,
|
||||
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.\d.\d.\d.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
// This is used to filter out conversations saved by the old assistant.
|
||||
if !re.is_match(file_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
/// Returns a list of placeholder conversations for mocking the UI.
|
||||
///
|
||||
/// Once we have real saved conversations to pull from we can use those instead.
|
||||
pub fn placeholder_conversations() -> Vec<SavedConversation> {
|
||||
vec![
|
||||
SavedConversation {
|
||||
version: "0.3.0".to_string(),
|
||||
title: "How to get a list of exported functions in an Erlang module".to_string(),
|
||||
messages: vec![],
|
||||
},
|
||||
SavedConversation {
|
||||
version: "0.3.0".to_string(),
|
||||
title: "7 wonders of the ancient world".to_string(),
|
||||
messages: vec![],
|
||||
},
|
||||
SavedConversation {
|
||||
version: "0.3.0".to_string(),
|
||||
title: "Size difference between u8 and a reference to u8 in Rust".to_string(),
|
||||
messages: vec![],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -5,66 +5,57 @@ use gpui::{AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, V
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use ui::{prelude::*, HighlightedLabel, ListItem, ListItemSpacing};
|
||||
use util::ResultExt;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::saved_conversation::SavedConversationMetadata;
|
||||
use crate::saved_conversation::{self, SavedConversation};
|
||||
use crate::ToggleSavedConversations;
|
||||
|
||||
pub struct SavedConversations {
|
||||
focus_handle: FocusHandle,
|
||||
picker: Option<View<Picker<SavedConversationPickerDelegate>>>,
|
||||
pub struct SavedConversationPicker {
|
||||
picker: View<Picker<SavedConversationPickerDelegate>>,
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for SavedConversations {}
|
||||
impl EventEmitter<DismissEvent> for SavedConversationPicker {}
|
||||
|
||||
impl FocusableView for SavedConversations {
|
||||
impl ModalView for SavedConversationPicker {}
|
||||
|
||||
impl FocusableView for SavedConversationPicker {
|
||||
fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
|
||||
if let Some(picker) = self.picker.as_ref() {
|
||||
picker.focus_handle(cx)
|
||||
} else {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
self.picker.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl SavedConversations {
|
||||
pub fn new(cx: &mut ViewContext<Self>) -> Self {
|
||||
Self {
|
||||
focus_handle: cx.focus_handle(),
|
||||
picker: None,
|
||||
}
|
||||
impl SavedConversationPicker {
|
||||
pub fn register(workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>) {
|
||||
workspace.register_action(|workspace, _: &ToggleSavedConversations, cx| {
|
||||
workspace.toggle_modal(cx, move |cx| {
|
||||
let delegate = SavedConversationPickerDelegate::new(cx.view().downgrade());
|
||||
Self::new(delegate, cx)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn init(
|
||||
&mut self,
|
||||
saved_conversations: Vec<SavedConversationMetadata>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) {
|
||||
let delegate =
|
||||
SavedConversationPickerDelegate::new(cx.view().downgrade(), saved_conversations);
|
||||
self.picker = Some(cx.new_view(|cx| Picker::uniform_list(delegate, cx).modal(false)));
|
||||
pub fn new(delegate: SavedConversationPickerDelegate, cx: &mut ViewContext<Self>) -> Self {
|
||||
let picker = cx.new_view(|cx| Picker::uniform_list(delegate, cx));
|
||||
Self { picker }
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for SavedConversations {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.children(self.picker.clone())
|
||||
impl Render for SavedConversationPicker {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
v_flex().w(rems(34.)).child(self.picker.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SavedConversationPickerDelegate {
|
||||
view: WeakView<SavedConversations>,
|
||||
saved_conversations: Vec<SavedConversationMetadata>,
|
||||
view: WeakView<SavedConversationPicker>,
|
||||
saved_conversations: Vec<SavedConversation>,
|
||||
selected_index: usize,
|
||||
matches: Vec<StringMatch>,
|
||||
}
|
||||
|
||||
impl SavedConversationPickerDelegate {
|
||||
pub fn new(
|
||||
weak_view: WeakView<SavedConversations>,
|
||||
saved_conversations: Vec<SavedConversationMetadata>,
|
||||
) -> Self {
|
||||
pub fn new(weak_view: WeakView<SavedConversationPicker>) -> Self {
|
||||
let saved_conversations = saved_conversation::placeholder_conversations();
|
||||
let matches = saved_conversations
|
||||
.iter()
|
||||
.map(|conversation| StringMatch {
|
||||
@@ -185,6 +176,7 @@ impl PickerDelegate for SavedConversationPickerDelegate {
|
||||
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.selected(selected)
|
||||
.child(HighlightedLabel::new(
|
||||
@@ -1,13 +1,12 @@
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolView};
|
||||
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
|
||||
use editor::{
|
||||
display_map::{BlockContext, BlockDisposition, BlockProperties, BlockStyle},
|
||||
Editor, MultiBuffer,
|
||||
};
|
||||
use futures::{channel::mpsc::UnboundedSender, StreamExt as _};
|
||||
use gpui::{prelude::*, AnyElement, AsyncWindowContext, Model, Task, View, WeakView};
|
||||
use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
|
||||
use language::ToPoint;
|
||||
use project::{search::SearchQuery, Project, ProjectPath};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
@@ -26,30 +25,26 @@ impl AnnotationTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, JsonSchema, Clone)]
|
||||
#[derive(Debug, Deserialize, JsonSchema, Clone)]
|
||||
pub struct AnnotationInput {
|
||||
/// Name for this set of annotations
|
||||
#[serde(default = "default_title")]
|
||||
title: String,
|
||||
/// Excerpts from the file to show to the user.
|
||||
excerpts: Vec<Excerpt>,
|
||||
}
|
||||
|
||||
fn default_title() -> String {
|
||||
"Untitled".to_string()
|
||||
annotations: Vec<Annotation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema, Clone)]
|
||||
struct Excerpt {
|
||||
struct Annotation {
|
||||
/// Path to the file
|
||||
path: String,
|
||||
/// A short, distinctive string that appears in the file, used to define a location in the file.
|
||||
text_passage: String,
|
||||
/// Text to display above the code excerpt
|
||||
annotation: String,
|
||||
/// Name of a symbol in the code
|
||||
symbol_name: String,
|
||||
/// Text to display near the symbol definition
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl LanguageModelTool for AnnotationTool {
|
||||
type Input = AnnotationInput;
|
||||
type Output = String;
|
||||
type View = AnnotationResultView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
@@ -60,137 +55,105 @@ impl LanguageModelTool for AnnotationTool {
|
||||
"Dynamically annotate symbols in the current codebase. Opens a buffer in a panel in their editor, to the side of the conversation. The annotations are shown in the editor as a block decoration.".to_string()
|
||||
}
|
||||
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|cx| {
|
||||
let (tx, mut rx) = futures::channel::mpsc::unbounded();
|
||||
cx.spawn(|view, mut cx| async move {
|
||||
while let Some(excerpt) = rx.next().await {
|
||||
AnnotationResultView::add_excerpt(view.clone(), excerpt, &mut cx).await?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
AnnotationResultView {
|
||||
project: self.project.clone(),
|
||||
workspace: self.workspace.clone(),
|
||||
tx,
|
||||
pending_excerpt: None,
|
||||
added_editor_to_workspace: false,
|
||||
editor: None,
|
||||
error: None,
|
||||
rendered_excerpt_count: 0,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AnnotationResultView {
|
||||
workspace: WeakView<Workspace>,
|
||||
project: Model<Project>,
|
||||
pending_excerpt: Option<Excerpt>,
|
||||
added_editor_to_workspace: bool,
|
||||
editor: Option<View<Editor>>,
|
||||
tx: UnboundedSender<Excerpt>,
|
||||
error: Option<anyhow::Error>,
|
||||
rendered_excerpt_count: usize,
|
||||
}
|
||||
|
||||
impl AnnotationResultView {
|
||||
async fn add_excerpt(
|
||||
this: WeakView<Self>,
|
||||
excerpt: Excerpt,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Result<()> {
|
||||
let project = this.update(cx, |this, _cx| this.project.clone())?;
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||
let workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let excerpts = input.annotations.clone();
|
||||
let title = input.title.clone();
|
||||
|
||||
let worktree_id = project.update(cx, |project, cx| {
|
||||
let worktree = project.worktrees().next()?;
|
||||
let worktree_id = worktree.read(cx).id();
|
||||
Some(worktree_id)
|
||||
})?;
|
||||
});
|
||||
|
||||
let worktree_id = if let Some(worktree_id) = worktree_id {
|
||||
worktree_id
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("No worktree found"));
|
||||
return Task::ready(Err(anyhow::anyhow!("No worktree found")));
|
||||
};
|
||||
|
||||
let buffer_task = project.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id,
|
||||
path: Path::new(&excerpt.path).into(),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let buffer = match buffer_task.await {
|
||||
Ok(buffer) => buffer,
|
||||
Err(error) => {
|
||||
return this.update(cx, |this, cx| {
|
||||
this.error = Some(error);
|
||||
cx.notify();
|
||||
let buffer_tasks = project.update(cx, |project, cx| {
|
||||
let excerpts = excerpts.clone();
|
||||
excerpts
|
||||
.iter()
|
||||
.map(|excerpt| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: Path::new(&excerpt.path).into(),
|
||||
};
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})
|
||||
}
|
||||
};
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let snapshot = buffer.update(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
let query = SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
|
||||
let matches = query.search(&snapshot, None).await;
|
||||
let Some(first_match) = matches.first() else {
|
||||
log::warn!(
|
||||
"text {:?} does not appear in '{}'",
|
||||
excerpt.text_passage,
|
||||
excerpt.path
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let buffers = futures::future::try_join_all(buffer_tasks).await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let mut start = first_match.start.to_point(&snapshot);
|
||||
start.column = 0;
|
||||
let multibuffer = cx.new_model(|_cx| {
|
||||
MultiBuffer::new(0, language::Capability::ReadWrite).with_title(title)
|
||||
})?;
|
||||
let editor =
|
||||
cx.new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), cx))?;
|
||||
|
||||
if let Some(editor) = &this.editor {
|
||||
editor.update(cx, |editor, cx| {
|
||||
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
|
||||
multibuffer.push_excerpts_with_context_lines(
|
||||
buffer.clone(),
|
||||
vec![start..start],
|
||||
5,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
|
||||
let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let annotation = SharedString::from(excerpt.annotation);
|
||||
editor.insert_blocks(
|
||||
[BlockProperties {
|
||||
position: ranges[0].start,
|
||||
height: annotation.split('\n').count() as u8 + 1,
|
||||
style: BlockStyle::Fixed,
|
||||
render: Box::new(move |cx| Self::render_note_block(&annotation, cx)),
|
||||
disposition: BlockDisposition::Above,
|
||||
}],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
if !this.added_editor_to_workspace {
|
||||
this.added_editor_to_workspace = true;
|
||||
this.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
|
||||
})
|
||||
.log_err();
|
||||
if let Some(outline) = snapshot.outline(None) {
|
||||
let matches = outline
|
||||
.search(&excerpt.symbol_name, cx.background_executor().clone())
|
||||
.await;
|
||||
if let Some(mat) = matches.first() {
|
||||
let item = &outline.items[mat.candidate_id];
|
||||
let start = item.range.start.to_point(&snapshot);
|
||||
editor.update(&mut cx, |editor, cx| {
|
||||
let ranges = editor.buffer().update(cx, |multibuffer, cx| {
|
||||
multibuffer.push_excerpts_with_context_lines(
|
||||
buffer.clone(),
|
||||
vec![start..start],
|
||||
5,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let explanation = SharedString::from(excerpt.text.clone());
|
||||
editor.insert_blocks(
|
||||
[BlockProperties {
|
||||
position: ranges[0].start,
|
||||
height: 2,
|
||||
style: BlockStyle::Fixed,
|
||||
render: Box::new(move |cx| {
|
||||
Self::render_note_block(&explanation, cx)
|
||||
}),
|
||||
disposition: BlockDisposition::Above,
|
||||
}],
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
workspace
|
||||
.update(&mut cx, |workspace, cx| {
|
||||
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
anyhow::Ok("showed comments to users in a new view".into())
|
||||
})
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
_: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
cx.new_view(|_cx| AnnotationResultView { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl AnnotationTool {
|
||||
fn render_note_block(explanation: &SharedString, cx: &mut BlockContext) -> AnyElement {
|
||||
let anchor_x = cx.anchor_x;
|
||||
let gutter_width = cx.gutter_dimensions.width;
|
||||
@@ -216,89 +179,24 @@ impl AnnotationResultView {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AnnotationResultView {
|
||||
output: Result<String>,
|
||||
}
|
||||
|
||||
impl Render for AnnotationResultView {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
if let Some(error) = &self.error {
|
||||
ui::Label::new(error.to_string()).into_any_element()
|
||||
} else {
|
||||
ui::Label::new(SharedString::from(format!(
|
||||
"Opened a buffer with {} excerpts",
|
||||
self.rendered_excerpt_count
|
||||
)))
|
||||
.into_any_element()
|
||||
match &self.output {
|
||||
Ok(output) => div().child(output.clone().into_any_element()),
|
||||
Err(error) => div().child(format!("failed to open path: {:?}", error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolView for AnnotationResultView {
|
||||
type Input = AnnotationInput;
|
||||
type SerializedState = Option<String>;
|
||||
|
||||
fn generate(&self, _: &mut ProjectContext, _: &mut ViewContext<Self>) -> String {
|
||||
if let Some(error) = &self.error {
|
||||
format!("Failed to create buffer: {error:?}")
|
||||
} else {
|
||||
format!(
|
||||
"opened {} excerpts in a buffer",
|
||||
self.rendered_excerpt_count
|
||||
)
|
||||
impl ToolOutput for AnnotationResultView {
|
||||
fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
|
||||
match &self.output {
|
||||
Ok(output) => output.clone(),
|
||||
Err(err) => format!("Failed to create buffer: {err:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input(&mut self, mut input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
let editor = if let Some(editor) = &self.editor {
|
||||
editor.clone()
|
||||
} else {
|
||||
let multibuffer = cx.new_model(|_cx| {
|
||||
MultiBuffer::new(0, language::Capability::ReadWrite).with_title(String::new())
|
||||
});
|
||||
let editor = cx.new_view(|cx| {
|
||||
Editor::for_multibuffer(multibuffer.clone(), Some(self.project.clone()), cx)
|
||||
});
|
||||
|
||||
self.editor = Some(editor.clone());
|
||||
editor
|
||||
};
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.buffer().update(cx, |multibuffer, cx| {
|
||||
if multibuffer.title(cx) != input.title {
|
||||
multibuffer.set_title(input.title.clone(), cx);
|
||||
}
|
||||
});
|
||||
|
||||
self.pending_excerpt = input.excerpts.pop();
|
||||
for excerpt in input.excerpts.iter().skip(self.rendered_excerpt_count) {
|
||||
self.tx.unbounded_send(excerpt.clone()).ok();
|
||||
}
|
||||
self.rendered_excerpt_count = input.excerpts.len();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
if let Some(excerpt) = self.pending_excerpt.take() {
|
||||
self.rendered_excerpt_count += 1;
|
||||
self.tx.unbounded_send(excerpt.clone()).ok();
|
||||
}
|
||||
|
||||
self.tx.close_channel();
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
self.error.as_ref().map(|error| error.to_string())
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
if let Some(error_message) = output {
|
||||
self.error = Some(anyhow::anyhow!("{}", error_message));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolView};
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{LanguageModelTool, ProjectContext, ToolOutput};
|
||||
use editor::Editor;
|
||||
use gpui::{prelude::*, Model, Task, View, WeakView};
|
||||
use project::Project;
|
||||
@@ -20,7 +20,7 @@ impl CreateBufferTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct CreateBufferInput {
|
||||
/// The contents of the buffer.
|
||||
text: String,
|
||||
@@ -32,70 +32,25 @@ pub struct CreateBufferInput {
|
||||
}
|
||||
|
||||
impl LanguageModelTool for CreateBufferTool {
|
||||
type Input = CreateBufferInput;
|
||||
type Output = ();
|
||||
type View = CreateBufferView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"create_file".to_string()
|
||||
"create_buffer".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Create a new untitled file in the current codebase. Side effect: opens it in a new pane/tab for the user to edit.".to_string()
|
||||
"Create a new buffer in the current codebase".to_string()
|
||||
}
|
||||
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| CreateBufferView {
|
||||
workspace: self.workspace.clone(),
|
||||
project: self.project.clone(),
|
||||
input: None,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CreateBufferView {
|
||||
workspace: WeakView<Workspace>,
|
||||
project: Model<Project>,
|
||||
input: Option<CreateBufferInput>,
|
||||
error: Option<anyhow::Error>,
|
||||
}
|
||||
|
||||
impl Render for CreateBufferView {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
ui::Label::new("Opening a buffer")
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolView for CreateBufferView {
|
||||
type Input = CreateBufferInput;
|
||||
|
||||
type SerializedState = ();
|
||||
|
||||
fn generate(&self, _project: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
|
||||
let Some(input) = self.input.as_ref() else {
|
||||
return "No input".to_string();
|
||||
};
|
||||
|
||||
match &self.error {
|
||||
None => format!("Created a new {} buffer", input.language),
|
||||
Some(err) => format!("Failed to create buffer: {err:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
self.input = Some(input);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||
cx.spawn({
|
||||
let workspace = self.workspace.clone();
|
||||
let project = self.project.clone();
|
||||
let input = self.input.clone();
|
||||
|_this, mut cx| async move {
|
||||
let input = input.ok_or_else(|| anyhow!("no input"))?;
|
||||
|
||||
let text = input.text.clone();
|
||||
let language_name = input.language.clone();
|
||||
let text = input.text.clone();
|
||||
let language_name = input.language.clone();
|
||||
|mut cx| async move {
|
||||
let language = cx
|
||||
.update(|cx| {
|
||||
project
|
||||
@@ -131,15 +86,34 @@ impl ToolView for CreateBufferView {
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
()
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
_output: Self::SerializedState,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
Ok(())
|
||||
fn output_view(
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
cx.new_view(|_cx| CreateBufferView {
|
||||
language: input.language,
|
||||
output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CreateBufferView {
|
||||
language: String,
|
||||
output: Result<()>,
|
||||
}
|
||||
|
||||
impl Render for CreateBufferView {
|
||||
fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
div().child("Opening a buffer")
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolOutput for CreateBufferView {
|
||||
fn generate(&self, _: &mut ProjectContext, _: &mut WindowContext) -> String {
|
||||
match &self.output {
|
||||
Ok(_) => format!("Created a new {} buffer", self.language),
|
||||
Err(err) => format!("Failed to create buffer: {err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{LanguageModelTool, ToolView};
|
||||
use assistant_tooling::{LanguageModelTool, ToolOutput};
|
||||
use collections::BTreeMap;
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{prelude::*, AnyElement, Model, Task};
|
||||
use gpui::{prelude::*, Model, Task};
|
||||
use project::ProjectPath;
|
||||
use schemars::JsonSchema;
|
||||
use semantic_index::{ProjectIndex, Status};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::Write as _,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr as _,
|
||||
sync::Arc,
|
||||
};
|
||||
use ui::{prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
|
||||
use serde::Deserialize;
|
||||
use std::{fmt::Write as _, ops::Range};
|
||||
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
|
||||
|
||||
const DEFAULT_SEARCH_LIMIT: usize = 20;
|
||||
|
||||
@@ -22,373 +15,126 @@ pub struct ProjectIndexTool {
|
||||
project_index: Model<ProjectIndex>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum ProjectIndexToolState {
|
||||
#[default]
|
||||
CollectingQuery,
|
||||
Searching,
|
||||
Error(anyhow::Error),
|
||||
Finished {
|
||||
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
|
||||
index_status: Status,
|
||||
},
|
||||
}
|
||||
// Note: Comments on a `LanguageModelTool::Input` become descriptions on the generated JSON schema as shown to the language model.
|
||||
// Any changes or deletions to the `CodebaseQuery` comments will change model behavior.
|
||||
|
||||
pub struct ProjectIndexView {
|
||||
project_index: Model<ProjectIndex>,
|
||||
input: CodebaseQuery,
|
||||
expanded_header: bool,
|
||||
state: ProjectIndexToolState,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, JsonSchema)]
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
pub struct CodebaseQuery {
|
||||
/// Semantic search query
|
||||
query: String,
|
||||
/// Criteria to include results
|
||||
includes: Option<SearchFilter>,
|
||||
/// Criteria to exclude results
|
||||
excludes: Option<SearchFilter>,
|
||||
/// Maximum number of results to return, defaults to 20
|
||||
limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Clone, Default)]
|
||||
pub struct SearchFilter {
|
||||
/// Filter by file path prefix
|
||||
prefix_path: Option<String>,
|
||||
/// Filter by file extension
|
||||
extension: Option<String>,
|
||||
// Note: we possibly can't do content filtering very easily given the project context handling
|
||||
// the final results, so we're leaving out direct string matches for now
|
||||
pub struct ProjectIndexView {
|
||||
input: CodebaseQuery,
|
||||
output: Result<ProjectIndexOutput>,
|
||||
element_id: ElementId,
|
||||
expanded_header: bool,
|
||||
}
|
||||
|
||||
fn project_starts_with(prefix_path: Option<String>, project_path: ProjectPath) -> bool {
|
||||
if let Some(path) = &prefix_path {
|
||||
if let Some(path) = PathBuf::from_str(path).ok() {
|
||||
return project_path.path.starts_with(path);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
impl SearchFilter {
|
||||
fn matches(&self, project_path: &ProjectPath) -> bool {
|
||||
let path_match = project_starts_with(self.prefix_path.clone(), project_path.clone());
|
||||
|
||||
path_match
|
||||
&& (if let Some(extension) = &self.extension {
|
||||
project_path
|
||||
.path
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.map(|ext| ext == extension)
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SerializedState {
|
||||
index_status: Status,
|
||||
error_message: Option<String>,
|
||||
worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
struct WorktreeIndexOutput {
|
||||
excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
|
||||
pub struct ProjectIndexOutput {
|
||||
status: Status,
|
||||
excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl ProjectIndexView {
|
||||
fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
|
||||
let element_id = ElementId::Name(nanoid::nanoid!().into());
|
||||
|
||||
Self {
|
||||
input,
|
||||
output,
|
||||
element_id,
|
||||
expanded_header: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_header(&mut self, cx: &mut ViewContext<Self>) {
|
||||
self.expanded_header = !self.expanded_header;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_filter_section(
|
||||
&mut self,
|
||||
heading: &str,
|
||||
filter: Option<SearchFilter>,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Option<AnyElement> {
|
||||
let filter = match filter {
|
||||
Some(filter) => filter,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
// Any of the filter fields can be empty. We'll show nothing if they're all empty.
|
||||
let path = filter.prefix_path.as_ref().map(|path| {
|
||||
let icon_path = FileIcons::get_icon(Path::new(path), cx)
|
||||
.map(SharedString::from)
|
||||
.unwrap_or_else(|| SharedString::from("icons/file_icons/file.svg"));
|
||||
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child("Paths: ")
|
||||
.child(Icon::from_path(icon_path))
|
||||
.child(ui::Label::new(path.clone()).color(Color::Muted))
|
||||
});
|
||||
|
||||
let extension = filter.extension.as_ref().map(|extension| {
|
||||
let icon_path = FileIcons::get_icon(Path::new(extension), cx)
|
||||
.map(SharedString::from)
|
||||
.unwrap_or_else(|| SharedString::from("icons/file_icons/file.svg"));
|
||||
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child("Extensions: ")
|
||||
.child(Icon::from_path(icon_path))
|
||||
.child(ui::Label::new(extension.clone()).color(Color::Muted))
|
||||
});
|
||||
|
||||
if path.is_none() && extension.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.child(ui::Label::new(heading.to_string()))
|
||||
.gap_1()
|
||||
.children(path)
|
||||
.children(extension)
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ProjectIndexView {
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
|
||||
let query = self.input.query.clone();
|
||||
|
||||
let (header_text, content) = match &self.state {
|
||||
ProjectIndexToolState::Error(error) => {
|
||||
return format!("failed to search: {error:?}").into_any_element()
|
||||
}
|
||||
ProjectIndexToolState::CollectingQuery | ProjectIndexToolState::Searching => {
|
||||
("Searching...".to_string(), div())
|
||||
}
|
||||
ProjectIndexToolState::Finished { excerpts, .. } => {
|
||||
let file_count = excerpts.len();
|
||||
let result = &self.output;
|
||||
|
||||
if excerpts.is_empty() {
|
||||
("No results found".to_string(), div())
|
||||
} else {
|
||||
let header_text = format!(
|
||||
"Read {} {}",
|
||||
file_count,
|
||||
if file_count == 1 { "file" } else { "files" }
|
||||
);
|
||||
|
||||
let el = v_flex().gap_2().children(excerpts.keys().map(|path| {
|
||||
h_flex().gap_2().child(Icon::new(IconName::File)).child(
|
||||
Label::new(path.path.to_string_lossy().to_string()).color(Color::Muted),
|
||||
)
|
||||
}));
|
||||
|
||||
(header_text, el)
|
||||
}
|
||||
let output = match result {
|
||||
Err(err) => {
|
||||
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
|
||||
}
|
||||
Ok(output) => output,
|
||||
};
|
||||
|
||||
let file_count = output.excerpts.len();
|
||||
|
||||
let header = h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::File))
|
||||
.child(header_text);
|
||||
.child(format!(
|
||||
"Read {} {}",
|
||||
file_count,
|
||||
if file_count == 1 { "file" } else { "files" }
|
||||
));
|
||||
|
||||
v_flex()
|
||||
.gap_3()
|
||||
.child(
|
||||
CollapsibleContainer::new("collapsible-container", self.expanded_header)
|
||||
.start_slot(header)
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
this.toggle_header(cx);
|
||||
}))
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_3()
|
||||
.p_3()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::MagnifyingGlass))
|
||||
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
|
||||
)
|
||||
.children(self.render_filter_section(
|
||||
"Includes",
|
||||
self.input.includes.clone(),
|
||||
cx,
|
||||
))
|
||||
.children(self.render_filter_section(
|
||||
"Excludes",
|
||||
self.input.excludes.clone(),
|
||||
cx,
|
||||
))
|
||||
.child(content),
|
||||
),
|
||||
)
|
||||
.into_any_element()
|
||||
v_flex().gap_3().child(
|
||||
CollapsibleContainer::new(self.element_id.clone(), self.expanded_header)
|
||||
.start_slot(header)
|
||||
.on_click(cx.listener(move |this, _, cx| {
|
||||
this.toggle_header(cx);
|
||||
}))
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_3()
|
||||
.p_3()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::MagnifyingGlass))
|
||||
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.children(output.excerpts.keys().map(|path| {
|
||||
h_flex().gap_2().child(Icon::new(IconName::File)).child(
|
||||
Label::new(path.path.to_string_lossy().to_string())
|
||||
.color(Color::Muted),
|
||||
)
|
||||
})),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolView for ProjectIndexView {
|
||||
type Input = CodebaseQuery;
|
||||
type SerializedState = SerializedState;
|
||||
|
||||
impl ToolOutput for ProjectIndexView {
|
||||
fn generate(
|
||||
&self,
|
||||
context: &mut assistant_tooling::ProjectContext,
|
||||
_: &mut ViewContext<Self>,
|
||||
_: &mut WindowContext,
|
||||
) -> String {
|
||||
match &self.state {
|
||||
ProjectIndexToolState::CollectingQuery => String::new(),
|
||||
ProjectIndexToolState::Searching => String::new(),
|
||||
ProjectIndexToolState::Error(error) => format!("failed to search: {error:?}"),
|
||||
ProjectIndexToolState::Finished {
|
||||
excerpts,
|
||||
index_status,
|
||||
} => {
|
||||
match &self.output {
|
||||
Ok(output) => {
|
||||
let mut body = "found results in the following paths:\n".to_string();
|
||||
|
||||
for (project_path, ranges) in excerpts {
|
||||
for (project_path, ranges) in &output.excerpts {
|
||||
context.add_excerpts(project_path.clone(), ranges);
|
||||
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
|
||||
}
|
||||
|
||||
if *index_status != Status::Idle {
|
||||
if output.status != Status::Idle {
|
||||
body.push_str("Still indexing. Results may be incomplete.\n");
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
Err(err) => format!("Error: {}", err),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
self.input = input;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
self.state = ProjectIndexToolState::Searching;
|
||||
cx.notify();
|
||||
|
||||
let project_index = self.project_index.read(cx);
|
||||
let index_status = project_index.status();
|
||||
|
||||
// TODO: wire the filters into the search here instead of processing after.
|
||||
// Otherwise we'll get zero results sometimes.
|
||||
let search = project_index.search(self.input.query.clone(), DEFAULT_SEARCH_LIMIT, cx);
|
||||
|
||||
let includes = self.input.includes.clone();
|
||||
let excludes = self.input.excludes.clone();
|
||||
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let search_result = search.await;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
match search_result {
|
||||
Ok(search_results) => {
|
||||
let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
|
||||
for search_result in search_results {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id: search_result.worktree.read(cx).id(),
|
||||
path: search_result.path,
|
||||
};
|
||||
|
||||
if let Some(includes) = &includes {
|
||||
if !includes.matches(&project_path) {
|
||||
continue;
|
||||
}
|
||||
} else if let Some(excludes) = &excludes {
|
||||
if excludes.matches(&project_path) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
excerpts
|
||||
.entry(project_path)
|
||||
.or_default()
|
||||
.push(search_result.range);
|
||||
}
|
||||
this.state = ProjectIndexToolState::Finished {
|
||||
excerpts,
|
||||
index_status,
|
||||
};
|
||||
}
|
||||
Err(error) => {
|
||||
this.state = ProjectIndexToolState::Error(error);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
let mut serialized = SerializedState {
|
||||
error_message: None,
|
||||
index_status: Status::Idle,
|
||||
worktrees: Default::default(),
|
||||
};
|
||||
match &self.state {
|
||||
ProjectIndexToolState::Error(err) => serialized.error_message = Some(err.to_string()),
|
||||
ProjectIndexToolState::Finished {
|
||||
excerpts,
|
||||
index_status,
|
||||
} => {
|
||||
serialized.index_status = *index_status;
|
||||
if let Some(project) = self.project_index.read(cx).project().upgrade() {
|
||||
let project = project.read(cx);
|
||||
for (project_path, excerpts) in excerpts {
|
||||
if let Some(worktree) =
|
||||
project.worktree_for_id(project_path.worktree_id, cx)
|
||||
{
|
||||
let worktree_path = worktree.read(cx).abs_path();
|
||||
serialized
|
||||
.worktrees
|
||||
.entry(worktree_path)
|
||||
.or_default()
|
||||
.excerpts
|
||||
.insert(project_path.path.clone(), excerpts.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
serialized
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
serialized: Self::SerializedState,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
if !serialized.worktrees.is_empty() {
|
||||
let mut excerpts = BTreeMap::<ProjectPath, Vec<Range<usize>>>::new();
|
||||
if let Some(project) = self.project_index.read(cx).project().upgrade() {
|
||||
let project = project.read(cx);
|
||||
for (worktree_path, worktree_state) in serialized.worktrees {
|
||||
if let Some(worktree) = project
|
||||
.worktrees()
|
||||
.find(|worktree| worktree.read(cx).abs_path() == worktree_path)
|
||||
{
|
||||
let worktree_id = worktree.read(cx).id();
|
||||
for (path, serialized_excerpts) in worktree_state.excerpts {
|
||||
excerpts.insert(ProjectPath { worktree_id, path }, serialized_excerpts);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.state = ProjectIndexToolState::Finished {
|
||||
excerpts,
|
||||
index_status: serialized.index_status,
|
||||
};
|
||||
}
|
||||
cx.notify();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ProjectIndexTool {
|
||||
@@ -398,31 +144,66 @@ impl ProjectIndexTool {
|
||||
}
|
||||
|
||||
impl LanguageModelTool for ProjectIndexTool {
|
||||
type Input = CodebaseQuery;
|
||||
type Output = ProjectIndexOutput;
|
||||
type View = ProjectIndexView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"semantic_search_codebase".to_string()
|
||||
"query_codebase".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
unindent::unindent(
|
||||
r#"This search tool uses a semantic index to perform search queries across your codebase, identifying and returning excerpts of text and code possibly related to the query.
|
||||
|
||||
Ideal for:
|
||||
- Discovering implementations of similar logic within the project
|
||||
- Finding usage examples of functions, classes/structures, libraries, and other code elements
|
||||
- Developing understanding of the codebase's architecture and design
|
||||
|
||||
Note: The search's effectiveness is directly related to the current state of the codebase and the specificity of your query. It is recommended that you use snippets of code that are similar to the code you wish to find."#,
|
||||
)
|
||||
"Semantic search against the user's current codebase, returning excerpts related to the query by computing a dot product against embeddings of code chunks in the code base and an embedding of the query.".to_string()
|
||||
}
|
||||
|
||||
fn view(&self, cx: &mut WindowContext) -> gpui::View<Self::View> {
|
||||
cx.new_view(|_| ProjectIndexView {
|
||||
state: ProjectIndexToolState::CollectingQuery,
|
||||
input: Default::default(),
|
||||
expanded_header: false,
|
||||
project_index: self.project_index.clone(),
|
||||
fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
|
||||
let project_index = self.project_index.read(cx);
|
||||
let status = project_index.status();
|
||||
let search = project_index.search(
|
||||
query.query.clone(),
|
||||
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(|mut cx| async move {
|
||||
let search_results = search.await?;
|
||||
|
||||
cx.update(|cx| {
|
||||
let mut output = ProjectIndexOutput {
|
||||
status,
|
||||
excerpts: Default::default(),
|
||||
};
|
||||
|
||||
for search_result in search_results {
|
||||
let path = ProjectPath {
|
||||
worktree_id: search_result.worktree.read(cx).id(),
|
||||
path: search_result.path.clone(),
|
||||
};
|
||||
|
||||
let excerpts_for_path = output.excerpts.entry(path).or_default();
|
||||
let ix = match excerpts_for_path
|
||||
.binary_search_by_key(&search_result.range.start, |r| r.start)
|
||||
{
|
||||
Ok(ix) | Err(ix) => ix,
|
||||
};
|
||||
excerpts_for_path.insert(ix, search_result.range);
|
||||
}
|
||||
|
||||
output
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> gpui::View<Self::View> {
|
||||
cx.new_view(|_cx| ProjectIndexView::new(input, output))
|
||||
}
|
||||
|
||||
fn render_running(_: &mut WindowContext) -> impl IntoElement {
|
||||
CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
|
||||
.start_slot("Searching code base")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ impl RenderOnce for ChatMessage {
|
||||
)
|
||||
.when(self.messages.len() > 0, |el| {
|
||||
el.child(
|
||||
h_flex().w_full().child(
|
||||
h_flex().child(
|
||||
v_flex()
|
||||
.relative()
|
||||
.overflow_hidden()
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
AssistantChat, CompletionProvider,
|
||||
};
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use gpui::{AnyElement, FontStyle, FontWeight, ReadGlobal, TextStyle, View, WeakView, WhiteSpace};
|
||||
use gpui::{AnyElement, FontStyle, FontWeight, TextStyle, View, WeakView, WhiteSpace};
|
||||
use settings::Settings;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{popover_menu, prelude::*, ButtonLike, ContextMenu, Divider, TextSize, Tooltip};
|
||||
@@ -48,73 +48,68 @@ impl RenderOnce for Composer {
|
||||
fn render(mut self, cx: &mut WindowContext) -> impl IntoElement {
|
||||
let font_size = TextSize::Default.rems(cx);
|
||||
let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
|
||||
let mut editor_border = cx.theme().colors().text;
|
||||
editor_border.fade_out(0.90);
|
||||
|
||||
// Remove the extra 1px added by the border
|
||||
let padding = Spacing::XLarge.rems(cx) - rems_from_px(1.);
|
||||
|
||||
h_flex()
|
||||
.p(Spacing::Small.rems(cx))
|
||||
.w_full()
|
||||
.items_start()
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.rounded_lg()
|
||||
.p(padding)
|
||||
.border_1()
|
||||
.border_color(editor_border)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(
|
||||
v_flex()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().editor_foreground,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
font_weight: FontWeight::NORMAL,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: line_height.into(),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
v_flex().size_full().gap_1().child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.p_3()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.rounded_lg()
|
||||
.child(
|
||||
v_flex()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().editor_foreground,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
font_weight: FontWeight::NORMAL,
|
||||
font_style: FontStyle::Normal,
|
||||
line_height: line_height.into(),
|
||||
background_color: None,
|
||||
underline: None,
|
||||
strikethrough: None,
|
||||
white_space: WhiteSpace::Normal,
|
||||
};
|
||||
|
||||
EditorElement::new(
|
||||
&self.editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.flex_none()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.child(
|
||||
h_flex().gap_1().child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(self.render_tools(cx))
|
||||
.child(Divider::vertical())
|
||||
.child(self.render_attachment_tools(cx)),
|
||||
),
|
||||
EditorElement::new(
|
||||
&self.editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.child(h_flex().gap_1().child(self.model_selector)),
|
||||
),
|
||||
),
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.flex_none()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.w_full()
|
||||
.child(
|
||||
h_flex().gap_1().child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(self.render_tools(cx))
|
||||
.child(Divider::vertical())
|
||||
.child(self.render_attachment_tools(cx)),
|
||||
),
|
||||
)
|
||||
.child(h_flex().gap_1().child(self.model_selector)),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -139,7 +134,7 @@ impl RenderOnce for ModelSelector {
|
||||
popover_menu("model-switcher")
|
||||
.menu(move |cx| {
|
||||
ContextMenu::build(cx, |mut menu, cx| {
|
||||
for model in CompletionProvider::global(cx).available_models() {
|
||||
for model in CompletionProvider::get(cx).available_models() {
|
||||
menu = menu.custom_entry(
|
||||
{
|
||||
let model = model.clone();
|
||||
|
||||
@@ -16,14 +16,11 @@ anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
repair_json.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
sum_tree.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# Assistant Tooling
|
||||
|
||||
Bringing Language Model tool calling to GPUI.
|
||||
Bringing OpenAI compatible tool calling to GPUI.
|
||||
|
||||
This unlocks:
|
||||
|
||||
- **Structured Extraction** of model responses
|
||||
- **Validation** of model inputs
|
||||
- **Execution** of chosen tools
|
||||
- **Execution** of chosen toolsn
|
||||
|
||||
## Overview
|
||||
|
||||
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
|
||||
Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When make a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
|
||||
|
||||
> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
|
||||
>
|
||||
@@ -22,64 +22,187 @@ Language Models can produce structured outputs that are perfect for calling func
|
||||
>
|
||||
> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
|
||||
|
||||
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`.
|
||||
This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with a simple trait, `LanguageModelTool`.
|
||||
|
||||
## Using the Tool Registry
|
||||
## Example
|
||||
|
||||
Let's expose querying a semantic index directly by the model. First, we'll set up some _necessary_ imports
|
||||
|
||||
```rust
|
||||
let mut tool_registry = ToolRegistry::new();
|
||||
tool_registry
|
||||
.register(WeatherTool { api_client },
|
||||
})
|
||||
.unwrap(); // You can only register one tool per name
|
||||
use anyhow::Result;
|
||||
use assistant_tooling::{LanguageModelTool, ToolRegistry};
|
||||
use gpui::{App, AppContext, Task};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
```
|
||||
|
||||
let completion = cx.update(|cx| {
|
||||
CompletionProvider::get(cx).complete(
|
||||
model_name,
|
||||
messages,
|
||||
Vec::new(),
|
||||
1.0,
|
||||
// The definitions get passed directly to OpenAI when you want
|
||||
// the model to be able to call your tool
|
||||
tool_registry.definitions(),
|
||||
)
|
||||
});
|
||||
Then we'll define the query structure the model must fill in. This _must_ derive `Deserialize` from `serde` and `JsonSchema` from the `schemars` crate.
|
||||
|
||||
let mut stream = completion?.await?;
|
||||
```rust
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct CodebaseQuery {
|
||||
query: String,
|
||||
}
|
||||
```
|
||||
|
||||
let mut message = AssistantMessage::new();
|
||||
After that we can define our tool, with the expectation that it will need a `ProjectIndex` to search against. For this example, the index uses the same interface as `semantic_index::ProjectIndex`.
|
||||
|
||||
while let Some(delta) = stream.next().await {
|
||||
// As messages stream in, you'll get both assistant content
|
||||
if let Some(content) = &delta.content {
|
||||
message
|
||||
.body
|
||||
.update(cx, |message, cx| message.append(&content, cx));
|
||||
```rust
|
||||
struct ProjectIndex {}
|
||||
|
||||
impl ProjectIndex {
|
||||
fn new() -> Self {
|
||||
ProjectIndex {}
|
||||
}
|
||||
|
||||
// And tool calls!
|
||||
for tool_call_delta in delta.tool_calls {
|
||||
let index = tool_call_delta.index as usize;
|
||||
if index >= message.tool_calls.len() {
|
||||
message.tool_calls.resize_with(index + 1, Default::default);
|
||||
fn search(&self, _query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
|
||||
// Instead of hooking up a real index, we're going to fake it
|
||||
if _query.contains("gpui") {
|
||||
return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
|
||||
//! # Welcome to GPUI!
|
||||
//!
|
||||
//! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
|
||||
//! for Rust, designed to support a wide variety of applications
|
||||
"#
|
||||
.to_string()]));
|
||||
}
|
||||
let tool_call = &mut message.tool_calls[index];
|
||||
return Task::ready(Ok(vec![]));
|
||||
}
|
||||
}
|
||||
|
||||
// Build up an ID
|
||||
if let Some(id) = &tool_call_delta.id {
|
||||
tool_call.id.push_str(id);
|
||||
}
|
||||
struct ProjectIndexTool {
|
||||
project_index: ProjectIndex,
|
||||
}
|
||||
```
|
||||
|
||||
tool_registry.update_tool_call(
|
||||
tool_call,
|
||||
tool_call_delta.name.as_deref(),
|
||||
tool_call_delta.arguments.as_deref(),
|
||||
cx,
|
||||
);
|
||||
Now we can implement the `LanguageModelTool` trait for our tool by:
|
||||
|
||||
- Defining the `Input` from the model, which is `CodebaseQuery`
|
||||
- Defining the `Output`
|
||||
- Implementing the `name` and `description` functions to provide the model information when it's choosing a tool
|
||||
- Implementing the `execute` function to run the tool
|
||||
|
||||
```rust
|
||||
impl LanguageModelTool for ProjectIndexTool {
|
||||
type Input = CodebaseQuery;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> String {
|
||||
"query_codebase".to_string()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Executes a query against the codebase, returning excerpts related to the query".to_string()
|
||||
}
|
||||
|
||||
fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
|
||||
let results = self.project_index.search(query.query.as_str(), 10, cx);
|
||||
|
||||
cx.spawn(|_cx| async move {
|
||||
let results = results.await?;
|
||||
|
||||
if !results.is_empty() {
|
||||
Ok(results.join("\n"))
|
||||
} else {
|
||||
Ok("No results".to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Once the stream of tokens is complete, you can exexute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task<Result<()>>`.
|
||||
For the sake of this example, let's look at the types that OpenAI will be passing to us
|
||||
|
||||
As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`.
|
||||
```rust
|
||||
// OpenAI definitions, shown here for demonstration
|
||||
#[derive(Deserialize)]
|
||||
struct FunctionCall {
|
||||
name: String,
|
||||
args: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Eq, PartialEq)]
|
||||
enum ToolCallType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
|
||||
struct ToolCallId(String);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ToolCall {
|
||||
Function {
|
||||
#[allow(dead_code)]
|
||||
id: ToolCallId,
|
||||
function: FunctionCall,
|
||||
},
|
||||
Other {
|
||||
#[allow(dead_code)]
|
||||
id: ToolCallId,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AssistantMessage {
|
||||
role: String,
|
||||
content: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
```
|
||||
|
||||
When the model wants to call tools, it will pass a list of `ToolCall`s. When those are `function`s that we can handle, we'll pass them to our `ToolRegistry` to get a future that we can await.
|
||||
|
||||
```rust
|
||||
// Inside `fn main()`
|
||||
App::new().run(|cx: &mut AppContext| {
|
||||
let tool = ProjectIndexTool {
|
||||
project_index: ProjectIndex::new(),
|
||||
};
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
let registered = registry.register(tool);
|
||||
assert!(registered.is_ok());
|
||||
```
|
||||
|
||||
Let's pretend the model sent us back a message requesting
|
||||
|
||||
```rust
|
||||
let model_response = json!({
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "query_codebase",
|
||||
"args": r#"{"query":"GPUI Task background_executor"}"#
|
||||
},
|
||||
"type": "function"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let message: AssistantMessage = serde_json::from_value(model_response).unwrap();
|
||||
|
||||
// We know there's a tool call, so let's skip straight to it for this example
|
||||
let tool_calls = message.tool_calls.as_ref().unwrap();
|
||||
let tool_call = tool_calls.get(0).unwrap();
|
||||
```
|
||||
|
||||
We can now use our registry to call the tool.
|
||||
|
||||
```rust
|
||||
let task = registry.call(
|
||||
tool_call.name,
|
||||
tool_call.args,
|
||||
);
|
||||
|
||||
cx.spawn(|_cx| async move {
|
||||
let result = task.await?;
|
||||
println!("{}", result.unwrap());
|
||||
Ok(())
|
||||
})
|
||||
```
|
||||
|
||||
@@ -2,12 +2,8 @@ mod attachment_registry;
|
||||
mod project_context;
|
||||
mod tool_registry;
|
||||
|
||||
pub use attachment_registry::{
|
||||
AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
|
||||
UserAttachment,
|
||||
};
|
||||
pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
|
||||
pub use project_context::ProjectContext;
|
||||
pub use tool_registry::{
|
||||
LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition,
|
||||
ToolRegistry, ToolView,
|
||||
LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
|
||||
};
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use crate::ProjectContext;
|
||||
use crate::{ProjectContext, ToolOutput};
|
||||
use anyhow::{anyhow, Result};
|
||||
use collections::HashMap;
|
||||
use futures::future::join_all;
|
||||
use gpui::{AnyView, Render, Task, View, WindowContext};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
sync::{
|
||||
@@ -18,39 +16,25 @@ pub struct AttachmentRegistry {
|
||||
registered_attachments: HashMap<TypeId, RegisteredAttachment>,
|
||||
}
|
||||
|
||||
pub trait AttachmentOutput {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
}
|
||||
|
||||
pub trait LanguageModelAttachment {
|
||||
type Output: DeserializeOwned + Serialize + 'static;
|
||||
type View: Render + AttachmentOutput;
|
||||
type Output: 'static;
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
fn name(&self) -> Arc<str>;
|
||||
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
|
||||
fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
|
||||
}
|
||||
|
||||
/// A collected attachment from running an attachment tool
|
||||
pub struct UserAttachment {
|
||||
pub view: AnyView,
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SavedUserAttachment {
|
||||
name: Arc<str>,
|
||||
serialized_output: Result<Box<RawValue>, String>,
|
||||
}
|
||||
|
||||
/// Internal representation of an attachment tool to allow us to treat them dynamically
|
||||
struct RegisteredAttachment {
|
||||
name: Arc<str>,
|
||||
enabled: AtomicBool,
|
||||
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
|
||||
deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
|
||||
}
|
||||
|
||||
impl AttachmentRegistry {
|
||||
@@ -61,65 +45,24 @@ impl AttachmentRegistry {
|
||||
}
|
||||
|
||||
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
|
||||
let attachment = Arc::new(attachment);
|
||||
let call = Box::new(move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
|
||||
let call = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |cx: &mut WindowContext| {
|
||||
let result = attachment.run(cx);
|
||||
let attachment = attachment.clone();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let serialized_output =
|
||||
result
|
||||
.as_ref()
|
||||
.map_err(ToString::to_string)
|
||||
.and_then(|output| {
|
||||
Ok(RawValue::from_string(
|
||||
serde_json::to_string(output).map_err(|e| e.to_string())?,
|
||||
)
|
||||
.unwrap())
|
||||
});
|
||||
|
||||
let view = cx.update(|cx| attachment.view(result, cx))?;
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: attachment.name(),
|
||||
view: view.into(),
|
||||
generate_fn: generate::<A>,
|
||||
serialized_output,
|
||||
})
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
let deserialize = Box::new({
|
||||
let attachment = attachment.clone();
|
||||
move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
|
||||
let serialized_output = saved_attachment.serialized_output.clone();
|
||||
let output = match &serialized_output {
|
||||
Ok(serialized_output) => {
|
||||
Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
|
||||
}
|
||||
Err(error) => Err(anyhow!("{error}")),
|
||||
};
|
||||
let view = attachment.view(output, cx).into();
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<A::Output> = result.await;
|
||||
let view = cx.update(|cx| A::view(result, cx))?;
|
||||
|
||||
Ok(UserAttachment {
|
||||
name: saved_attachment.name.clone(),
|
||||
view,
|
||||
serialized_output,
|
||||
view: view.into(),
|
||||
generate_fn: generate::<A>,
|
||||
})
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
self.registered_attachments.insert(
|
||||
TypeId::of::<A>(),
|
||||
RegisteredAttachment {
|
||||
name: attachment.name(),
|
||||
call,
|
||||
deserialize,
|
||||
enabled: AtomicBool::new(true),
|
||||
},
|
||||
);
|
||||
@@ -191,35 +134,6 @@ impl AttachmentRegistry {
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn serialize_user_attachment(
|
||||
&self,
|
||||
user_attachment: &UserAttachment,
|
||||
) -> SavedUserAttachment {
|
||||
SavedUserAttachment {
|
||||
name: user_attachment.name.clone(),
|
||||
serialized_output: user_attachment.serialized_output.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_user_attachment(
|
||||
&self,
|
||||
saved_user_attachment: SavedUserAttachment,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<UserAttachment> {
|
||||
if let Some(registered_attachment) = self
|
||||
.registered_attachments
|
||||
.values()
|
||||
.find(|attachment| attachment.name == saved_user_attachment.name)
|
||||
{
|
||||
(registered_attachment.deserialize)(&saved_user_attachment, cx)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"no attachment tool for name {}",
|
||||
saved_user_attachment.name
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAttachment {
|
||||
|
||||
@@ -1,67 +1,41 @@
|
||||
use crate::ProjectContext;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
|
||||
use repair_json::repair;
|
||||
use gpui::{
|
||||
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
|
||||
};
|
||||
use schemars::{schema::RootSchema, schema_for, JsonSchema};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_json::value::RawValue;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
collections::HashMap,
|
||||
fmt::Display,
|
||||
mem,
|
||||
sync::atomic::{AtomicBool, Ordering::SeqCst},
|
||||
};
|
||||
use ui::ViewContext;
|
||||
|
||||
use crate::ProjectContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
registered_tools: HashMap<String, RegisteredTool>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Default, Deserialize)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
state: ToolFunctionCallState,
|
||||
#[serde(skip)]
|
||||
pub result: Option<ToolFunctionCallResult>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
enum ToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
pub enum ToolFunctionCallResult {
|
||||
NoSuchTool,
|
||||
KnownTool(Box<dyn InternalToolView>),
|
||||
ExecutedTool(Box<dyn InternalToolView>),
|
||||
ParsingFailed,
|
||||
Finished {
|
||||
view: AnyView,
|
||||
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
|
||||
},
|
||||
}
|
||||
|
||||
trait InternalToolView {
|
||||
fn view(&self) -> AnyView;
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext);
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
|
||||
fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct SavedToolFunctionCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
state: SavedToolFunctionCallState,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
enum SavedToolFunctionCallState {
|
||||
#[default]
|
||||
Initializing,
|
||||
NoSuchTool,
|
||||
KnownTool,
|
||||
ExecutedTool(Box<RawValue>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone)]
|
||||
pub struct ToolFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
@@ -69,7 +43,14 @@ pub struct ToolFunctionDefinition {
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool {
|
||||
type View: ToolView;
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type Output: 'static;
|
||||
|
||||
type View: Render + ToolOutput;
|
||||
|
||||
/// Returns the name of the tool.
|
||||
///
|
||||
@@ -85,7 +66,7 @@ pub trait LanguageModelTool {
|
||||
|
||||
/// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
|
||||
fn definition(&self) -> ToolFunctionDefinition {
|
||||
let root_schema = schema_for!(<Self::View as ToolView>::Input);
|
||||
let root_schema = schema_for!(Self::Input);
|
||||
|
||||
ToolFunctionDefinition {
|
||||
name: self.name(),
|
||||
@@ -94,34 +75,29 @@ pub trait LanguageModelTool {
|
||||
}
|
||||
}
|
||||
|
||||
/// A view of the output of running the tool, for displaying to the user.
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
|
||||
/// Executes the tool with the given input.
|
||||
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
|
||||
|
||||
fn output_view(
|
||||
input: Self::Input,
|
||||
output: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View>;
|
||||
|
||||
fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
|
||||
div()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ToolView: Render {
|
||||
/// The input type that will be passed in to `execute` when the tool is called
|
||||
/// by the language model.
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
|
||||
/// The output returned by executing the tool.
|
||||
type SerializedState: DeserializeOwned + Serialize;
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
|
||||
fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
|
||||
|
||||
fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
cx: &mut ViewContext<Self>,
|
||||
) -> Result<()>;
|
||||
pub trait ToolOutput: Sized {
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
|
||||
}
|
||||
|
||||
struct RegisteredTool {
|
||||
enabled: AtomicBool,
|
||||
type_id: TypeId,
|
||||
build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
|
||||
call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
|
||||
render_running: fn(&mut WindowContext) -> gpui::AnyElement,
|
||||
definition: ToolFunctionDefinition,
|
||||
}
|
||||
|
||||
@@ -158,141 +134,68 @@ impl ToolRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn update_tool_call(
|
||||
&self,
|
||||
call: &mut ToolFunctionCall,
|
||||
name: Option<&str>,
|
||||
arguments: Option<&str>,
|
||||
cx: &mut WindowContext,
|
||||
) {
|
||||
if let Some(name) = name {
|
||||
call.name.push_str(name);
|
||||
}
|
||||
if let Some(arguments) = arguments {
|
||||
if call.arguments.is_empty() {
|
||||
if let Some(tool) = self.registered_tools.get(&call.name) {
|
||||
let view = (tool.build_view)(cx);
|
||||
call.state = ToolFunctionCallState::KnownTool(view);
|
||||
} else {
|
||||
call.state = ToolFunctionCallState::NoSuchTool;
|
||||
}
|
||||
}
|
||||
call.arguments.push_str(arguments);
|
||||
|
||||
if let ToolFunctionCallState::KnownTool(view) = &call.state {
|
||||
if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
|
||||
view.try_set_input(&repaired_arguments, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_tool_call(
|
||||
&self,
|
||||
tool_call: &mut ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Option<Task<Result<()>>> {
|
||||
if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
|
||||
let task = view.execute(cx);
|
||||
tool_call.state = ToolFunctionCallState::ExecutedTool(view);
|
||||
Some(task)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> AnyElement {
|
||||
match &tool_call.result {
|
||||
Some(result) => div()
|
||||
.p_2()
|
||||
.child(result.into_any_element(&tool_call.name))
|
||||
.into_any_element(),
|
||||
None => self
|
||||
.registered_tools
|
||||
.get(&tool_call.name)
|
||||
.map(|tool| (tool.render_running)(cx))
|
||||
.unwrap_or_else(|| div().into_any_element()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Option<AnyElement> {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
Some(ui::Label::new("No such tool").into_any_element())
|
||||
}
|
||||
ToolFunctionCallState::Initializing => None,
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
Some(view.view().into_any_element())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn content_for_tool_call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
project_context: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
match &tool_call.state {
|
||||
ToolFunctionCallState::Initializing => String::new(),
|
||||
ToolFunctionCallState::NoSuchTool => {
|
||||
format!("No such tool: {}", tool_call.name)
|
||||
}
|
||||
ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
|
||||
view.generate(project_context, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_tool_call(
|
||||
&self,
|
||||
call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<SavedToolFunctionCall> {
|
||||
Ok(SavedToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
state: match &call.state {
|
||||
ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
|
||||
ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
|
||||
ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
|
||||
ToolFunctionCallState::ExecutedTool(view) => {
|
||||
SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn deserialize_tool_call(
|
||||
&self,
|
||||
call: &SavedToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Result<ToolFunctionCall> {
|
||||
let Some(tool) = self.registered_tools.get(&call.name) else {
|
||||
return Err(anyhow!("no such tool {}", call.name));
|
||||
};
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id: call.id.clone(),
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
state: match &call.state {
|
||||
SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
|
||||
SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
|
||||
SavedToolFunctionCallState::KnownTool => {
|
||||
log::error!("Deserialized tool that had not executed");
|
||||
let view = (tool.build_view)(cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
ToolFunctionCallState::KnownTool(view)
|
||||
}
|
||||
SavedToolFunctionCallState::ExecutedTool(output) => {
|
||||
let view = (tool.build_view)(cx);
|
||||
view.try_set_input(&call.arguments, cx);
|
||||
view.deserialize_output(output, cx)?;
|
||||
ToolFunctionCallState::ExecutedTool(view)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
|
||||
) -> Result<()> {
|
||||
let name = tool.name();
|
||||
let registered_tool = RegisteredTool {
|
||||
type_id: TypeId::of::<T>(),
|
||||
definition: tool.definition(),
|
||||
enabled: AtomicBool::new(true),
|
||||
build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
|
||||
call: Box::new(
|
||||
move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::ParsingFailed),
|
||||
}));
|
||||
};
|
||||
|
||||
let result = tool.execute(&input, cx);
|
||||
|
||||
cx.spawn(move |mut cx| async move {
|
||||
let result: Result<T::Output> = result.await;
|
||||
let view = cx.update(|cx| T::output_view(input, result, cx))?;
|
||||
|
||||
Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::Finished {
|
||||
view: view.into(),
|
||||
generate_fn: generate::<T>,
|
||||
}),
|
||||
})
|
||||
})
|
||||
},
|
||||
),
|
||||
render_running: render_running::<T>,
|
||||
};
|
||||
|
||||
let previous = self.registered_tools.insert(name.clone(), registered_tool);
|
||||
@@ -301,40 +204,77 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ToolView> InternalToolView for View<T> {
|
||||
fn view(&self) -> AnyView {
|
||||
self.clone().into()
|
||||
}
|
||||
fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
|
||||
T::render_running(cx).into_any_element()
|
||||
}
|
||||
|
||||
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
|
||||
self.update(cx, |view, cx| view.generate(project, cx))
|
||||
}
|
||||
|
||||
fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
|
||||
if let Ok(input) = serde_json::from_str::<T::Input>(input) {
|
||||
self.update(cx, |view, cx| {
|
||||
view.set_input(input, cx);
|
||||
cx.notify();
|
||||
});
|
||||
fn generate<T: LanguageModelTool>(
|
||||
view: AnyView,
|
||||
project: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
view.downcast::<T::View>()
|
||||
.unwrap()
|
||||
.update(cx, |view, cx| T::View::generate(view, project, cx))
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
|
||||
self.update(cx, |view, cx| view.execute(cx))
|
||||
/// Task yields an error if the window for the given WindowContext is closed before the task completes.
|
||||
pub fn call(
|
||||
&self,
|
||||
tool_call: &ToolFunctionCall,
|
||||
cx: &mut WindowContext,
|
||||
) -> Task<Result<ToolFunctionCall>> {
|
||||
let name = tool_call.name.clone();
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let id = tool_call.id.clone();
|
||||
|
||||
let tool = match self.registered_tools.get(&name) {
|
||||
Some(tool) => tool,
|
||||
None => {
|
||||
let name = name.clone();
|
||||
return Task::ready(Ok(ToolFunctionCall {
|
||||
id,
|
||||
name: name.clone(),
|
||||
arguments,
|
||||
result: Some(ToolFunctionCallResult::NoSuchTool),
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
(tool.call)(tool_call, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolFunctionCallResult {
|
||||
pub fn generate(
|
||||
&self,
|
||||
name: &String,
|
||||
project: &mut ProjectContext,
|
||||
cx: &mut WindowContext,
|
||||
) -> String {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Unable to parse arguments for {name}")
|
||||
}
|
||||
ToolFunctionCallResult::Finished { generate_fn, view } => {
|
||||
(generate_fn)(view.clone(), project, cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
|
||||
let output = self.update(cx, |view, cx| view.serialize(cx));
|
||||
Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
|
||||
}
|
||||
|
||||
fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
|
||||
let state = serde_json::from_str::<T::SerializedState>(output.get())?;
|
||||
self.update(cx, |view, cx| view.deserialize(state, cx))?;
|
||||
Ok(())
|
||||
fn into_any_element(&self, name: &String) -> AnyElement {
|
||||
match self {
|
||||
ToolFunctionCallResult::NoSuchTool => {
|
||||
format!("Language Model attempted to call {name}").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::ParsingFailed => {
|
||||
format!("Language Model called {name} with bad arguments").into_any_element()
|
||||
}
|
||||
ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,6 +293,7 @@ mod test {
|
||||
use super::*;
|
||||
use gpui::{div, prelude::*, Render, TestAppContext};
|
||||
use gpui::{EmptyView, View};
|
||||
use schemars::schema_for;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
@@ -363,6 +304,10 @@ mod test {
|
||||
unit: String,
|
||||
}
|
||||
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
|
||||
struct WeatherResult {
|
||||
location: String,
|
||||
@@ -371,81 +316,24 @@ mod test {
|
||||
}
|
||||
|
||||
struct WeatherView {
|
||||
input: Option<WeatherQuery>,
|
||||
result: Option<WeatherResult>,
|
||||
|
||||
// Fake API call
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
struct WeatherTool {
|
||||
current_weather: WeatherResult,
|
||||
}
|
||||
|
||||
impl WeatherView {
|
||||
fn new(current_weather: WeatherResult) -> Self {
|
||||
Self {
|
||||
input: None,
|
||||
result: None,
|
||||
current_weather,
|
||||
}
|
||||
}
|
||||
result: WeatherResult,
|
||||
}
|
||||
|
||||
impl Render for WeatherView {
|
||||
fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
|
||||
match self.result {
|
||||
Some(ref result) => div()
|
||||
.child(format!("temperature: {}", result.temperature))
|
||||
.into_any_element(),
|
||||
None => div().child("Calculating weather...").into_any_element(),
|
||||
}
|
||||
div().child(format!("temperature: {}", self.result.temperature))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolView for WeatherView {
|
||||
type Input = WeatherQuery;
|
||||
|
||||
type SerializedState = WeatherResult;
|
||||
|
||||
fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
|
||||
impl ToolOutput for WeatherView {
|
||||
fn generate(&self, _output: &mut ProjectContext, _cx: &mut WindowContext) -> String {
|
||||
serde_json::to_string(&self.result).unwrap()
|
||||
}
|
||||
|
||||
fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
|
||||
self.input = Some(input);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
|
||||
let input = self.input.as_ref().unwrap();
|
||||
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
self.result = Some(weather);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
|
||||
self.current_weather.clone()
|
||||
}
|
||||
|
||||
fn deserialize(
|
||||
&mut self,
|
||||
output: Self::SerializedState,
|
||||
_cx: &mut ViewContext<Self>,
|
||||
) -> Result<()> {
|
||||
self.current_weather = output;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelTool for WeatherTool {
|
||||
type Input = WeatherQuery;
|
||||
type Output = WeatherResult;
|
||||
type View = WeatherView;
|
||||
|
||||
fn name(&self) -> String {
|
||||
@@ -456,71 +344,88 @@ mod test {
|
||||
"Fetches the current weather for a given location.".to_string()
|
||||
}
|
||||
|
||||
fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
|
||||
cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
|
||||
fn execute(
|
||||
&self,
|
||||
input: &Self::Input,
|
||||
_cx: &mut WindowContext,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let _location = input.location.clone();
|
||||
let _unit = input.unit.clone();
|
||||
|
||||
let weather = self.current_weather.clone();
|
||||
|
||||
Task::ready(Ok(weather))
|
||||
}
|
||||
|
||||
fn output_view(
|
||||
_input: Self::Input,
|
||||
result: Result<Self::Output>,
|
||||
cx: &mut WindowContext,
|
||||
) -> View<Self::View> {
|
||||
cx.new_view(|_cx| {
|
||||
let result = result.unwrap();
|
||||
WeatherView { result }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_openai_weather_example(cx: &mut TestAppContext) {
|
||||
cx.background_executor.run_until_parked();
|
||||
let (_, cx) = cx.add_window_view(|_cx| EmptyView);
|
||||
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry
|
||||
.register(WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let definitions = registry.definitions();
|
||||
assert_eq!(
|
||||
definitions,
|
||||
[ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: serde_json::from_value(json!({
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
}))
|
||||
.unwrap(),
|
||||
}]
|
||||
);
|
||||
|
||||
let mut call = ToolFunctionCall {
|
||||
id: "the-id".to_string(),
|
||||
name: "get_cur".to_string(),
|
||||
..Default::default()
|
||||
let tool = WeatherTool {
|
||||
current_weather: WeatherResult {
|
||||
location: "San Francisco".to_string(),
|
||||
temperature: 21.0,
|
||||
unit: "Celsius".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let task = cx.update(|cx| {
|
||||
registry.update_tool_call(
|
||||
&mut call,
|
||||
Some("rent_weather"),
|
||||
Some(r#"{"location": "San Francisco","#),
|
||||
cx,
|
||||
);
|
||||
registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
|
||||
registry.execute_tool_call(&mut call, cx).unwrap()
|
||||
});
|
||||
task.await.unwrap();
|
||||
let tools = vec![tool.definition()];
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
match &call.state {
|
||||
ToolFunctionCallState::ExecutedTool(_view) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
let expected = ToolFunctionDefinition {
|
||||
name: "get_current_weather".to_string(),
|
||||
description: "Fetches the current weather for a given location.".to_string(),
|
||||
parameters: schema_for!(WeatherQuery),
|
||||
};
|
||||
|
||||
assert_eq!(tools[0].name, expected.name);
|
||||
assert_eq!(tools[0].description, expected.description);
|
||||
|
||||
let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
expected_schema,
|
||||
json!({
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "WeatherQuery",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
})
|
||||
);
|
||||
|
||||
let args = json!({
|
||||
"location": "San Francisco",
|
||||
"unit": "Celsius"
|
||||
});
|
||||
|
||||
let query: WeatherQuery = serde_json::from_value(args).unwrap();
|
||||
|
||||
let result = cx.update(|cx| tool.execute(&query, cx)).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
|
||||
assert_eq!(result, tool.current_weather);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ client.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
gpui.workspace = true
|
||||
http.workspace = true
|
||||
isahc.workspace = true
|
||||
log.workspace = true
|
||||
markdown_preview.workspace = true
|
||||
|
||||
@@ -20,7 +20,6 @@ use smol::{fs, io::AsyncReadExt};
|
||||
use settings::{Settings, SettingsSources, SettingsStore};
|
||||
use smol::{fs::File, process::Command};
|
||||
|
||||
use http::{HttpClient, HttpClientWithUrl};
|
||||
use release_channel::{AppCommitSha, AppVersion, ReleaseChannel};
|
||||
use std::{
|
||||
env::consts::{ARCH, OS},
|
||||
@@ -30,7 +29,10 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
use update_notification::UpdateNotification;
|
||||
use util::ResultExt;
|
||||
use util::{
|
||||
http::{HttpClient, HttpClientWithUrl},
|
||||
ResultExt,
|
||||
};
|
||||
use workspace::notifications::NotificationId;
|
||||
use workspace::Workspace;
|
||||
|
||||
@@ -54,22 +56,16 @@ struct UpdateRequestBody {
|
||||
telemetry: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AutoUpdateStatus {
|
||||
Idle,
|
||||
Checking,
|
||||
Downloading,
|
||||
Installing,
|
||||
Updated { binary_path: PathBuf },
|
||||
Updated,
|
||||
Errored,
|
||||
}
|
||||
|
||||
impl AutoUpdateStatus {
|
||||
pub fn is_updated(&self) -> bool {
|
||||
matches!(self, Self::Updated { .. })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AutoUpdater {
|
||||
status: AutoUpdateStatus,
|
||||
current_version: SemanticVersion,
|
||||
@@ -310,7 +306,7 @@ impl AutoUpdater {
|
||||
}
|
||||
|
||||
pub fn poll(&mut self, cx: &mut ModelContext<Self>) {
|
||||
if self.pending_poll.is_some() || self.status.is_updated() {
|
||||
if self.pending_poll.is_some() || self.status == AutoUpdateStatus::Updated {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -332,7 +328,7 @@ impl AutoUpdater {
|
||||
}
|
||||
|
||||
pub fn status(&self) -> AutoUpdateStatus {
|
||||
self.status.clone()
|
||||
self.status
|
||||
}
|
||||
|
||||
pub fn dismiss_error(&mut self, cx: &mut ModelContext<Self>) {
|
||||
@@ -408,11 +404,6 @@ impl AutoUpdater {
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
// We store the path of our current binary, before we install, since installation might
|
||||
// delete it. Once deleted, it's hard to get the path to our binary on Linux.
|
||||
// So we cache it here, which allows us to then restart later on.
|
||||
let binary_path = cx.update(|cx| cx.app_path())??;
|
||||
|
||||
match OS {
|
||||
"macos" => install_release_macos(&temp_dir, downloaded_asset, &cx).await,
|
||||
"linux" => install_release_linux(&temp_dir, downloaded_asset, &cx).await,
|
||||
@@ -422,7 +413,7 @@ impl AutoUpdater {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.set_should_show_update_notification(true, cx)
|
||||
.detach_and_log_err(cx);
|
||||
this.status = AutoUpdateStatus::Updated { binary_path };
|
||||
this.status = AutoUpdateStatus::Updated;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
|
||||
@@ -50,4 +50,3 @@ language = { workspace = true, features = ["test-support"] }
|
||||
live_kit_client = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
http = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -40,4 +40,3 @@ rpc = { workspace = true, features = ["test-support"] }
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
http = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -123,7 +123,6 @@ impl Channel {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelMembership {
|
||||
pub user: Arc<User>,
|
||||
pub kind: proto::channel_member::Kind,
|
||||
@@ -816,11 +815,9 @@ impl ChannelStore {
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
pub fn fuzzy_search_members(
|
||||
pub fn get_channel_member_details(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
query: String,
|
||||
limit: u16,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Task<Result<Vec<ChannelMembership>>> {
|
||||
let client = self.client.clone();
|
||||
@@ -829,24 +826,26 @@ impl ChannelStore {
|
||||
let response = client
|
||||
.request(proto::GetChannelMembers {
|
||||
channel_id: channel_id.0,
|
||||
query,
|
||||
limit: limit as u64,
|
||||
})
|
||||
.await?;
|
||||
user_store.update(&mut cx, |user_store, _| {
|
||||
user_store.insert(response.users);
|
||||
response
|
||||
.members
|
||||
.into_iter()
|
||||
.filter_map(|member| {
|
||||
Some(ChannelMembership {
|
||||
user: user_store.get_cached_user(member.user_id)?,
|
||||
role: member.role(),
|
||||
kind: member.kind(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
||||
let user_ids = response.members.iter().map(|m| m.user_id).collect();
|
||||
let user_store = user_store
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("user store dropped"))?;
|
||||
let users = user_store
|
||||
.update(&mut cx, |user_store, cx| user_store.get_users(user_ids, cx))?
|
||||
.await?;
|
||||
|
||||
Ok(users
|
||||
.into_iter()
|
||||
.zip(response.members)
|
||||
.map(|(user, member)| ChannelMembership {
|
||||
user,
|
||||
role: member.role(),
|
||||
kind: member.kind(),
|
||||
})
|
||||
.collect())
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ use super::*;
|
||||
use client::{test::FakeServer, Client, UserStore};
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::{AppContext, Context, Model, TestAppContext};
|
||||
use http::FakeHttpClient;
|
||||
use rpc::proto::{self};
|
||||
use settings::SettingsStore;
|
||||
use util::http::FakeHttpClient;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_update_channels(cx: &mut AppContext) {
|
||||
|
||||
@@ -19,17 +19,11 @@ path = "src/main.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
clap.workspace = true
|
||||
libc.workspace = true
|
||||
ipc-channel = "0.18"
|
||||
once_cell.workspace = true
|
||||
release_channel.workspace = true
|
||||
serde.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
exec.workspace = true
|
||||
fork.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
core-foundation.workspace = true
|
||||
core-services = "0.2"
|
||||
|
||||
@@ -13,7 +13,6 @@ pub enum CliRequest {
|
||||
paths: Vec<String>,
|
||||
wait: bool,
|
||||
open_new_workspace: Option<bool>,
|
||||
dev_server_token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,24 +1,17 @@
|
||||
#![cfg_attr(any(target_os = "linux", target_os = "windows"), allow(dead_code))]
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use clap::Parser;
|
||||
use cli::{ipc::IpcOneShotServer, CliRequest, CliResponse, IpcHandshake};
|
||||
use cli::{CliRequest, CliResponse};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
env, fs, io,
|
||||
env,
|
||||
ffi::OsStr,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
process::ExitStatus,
|
||||
thread::{self, JoinHandle},
|
||||
};
|
||||
use util::paths::PathLikeWithPosition;
|
||||
|
||||
struct Detect;
|
||||
|
||||
trait InstalledApp {
|
||||
fn zed_version_string(&self) -> String;
|
||||
fn launch(&self, ipc_url: String) -> anyhow::Result<()>;
|
||||
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus>;
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zed", disable_version_flag = true)]
|
||||
struct Args {
|
||||
@@ -40,12 +33,9 @@ struct Args {
|
||||
/// Print Zed's version and the app path.
|
||||
#[arg(short, long)]
|
||||
version: bool,
|
||||
/// Run zed in the foreground (useful for debugging)
|
||||
#[arg(long)]
|
||||
foreground: bool,
|
||||
/// Custom path to Zed.app or the zed binary
|
||||
#[arg(long)]
|
||||
zed: Option<PathBuf>,
|
||||
/// Custom Zed.app path
|
||||
#[arg(short, long)]
|
||||
bundle_path: Option<PathBuf>,
|
||||
/// Run zed in dev-server mode
|
||||
#[arg(long)]
|
||||
dev_server_token: Option<String>,
|
||||
@@ -59,6 +49,12 @@ fn parse_path_with_position(
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct InfoPlist {
|
||||
#[serde(rename = "CFBundleShortVersionString")]
|
||||
bundle_short_version_string: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
// Intercept version designators
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -72,10 +68,14 @@ fn main() -> Result<()> {
|
||||
}
|
||||
let args = Args::parse();
|
||||
|
||||
let app = Detect::detect(args.zed.as_deref()).context("Bundle detection")?;
|
||||
let bundle = Bundle::detect(args.bundle_path.as_deref()).context("Bundle detection")?;
|
||||
|
||||
if let Some(dev_server_token) = args.dev_server_token {
|
||||
return bundle.spawn(vec!["--dev-server-token".into(), dev_server_token]);
|
||||
}
|
||||
|
||||
if args.version {
|
||||
println!("{}", app.zed_version_string());
|
||||
println!("{}", bundle.zed_version_string());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -101,10 +101,7 @@ fn main() -> Result<()> {
|
||||
paths.push(canonicalized.to_string(|path| path.display().to_string()))
|
||||
}
|
||||
|
||||
let (server, server_name) =
|
||||
IpcOneShotServer::<IpcHandshake>::new().context("Handshake before Zed spawn")?;
|
||||
let url = format!("zed-cli://{server_name}");
|
||||
|
||||
let (tx, rx) = bundle.launch()?;
|
||||
let open_new_workspace = if args.new {
|
||||
Some(true)
|
||||
} else if args.add {
|
||||
@@ -113,164 +110,78 @@ fn main() -> Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
let sender: JoinHandle<anyhow::Result<()>> = thread::spawn(move || {
|
||||
let (_, handshake) = server.accept().context("Handshake after Zed spawn")?;
|
||||
let (tx, rx) = (handshake.requests, handshake.responses);
|
||||
tx.send(CliRequest::Open {
|
||||
paths,
|
||||
wait: args.wait,
|
||||
open_new_workspace,
|
||||
dev_server_token: args.dev_server_token,
|
||||
})?;
|
||||
tx.send(CliRequest::Open {
|
||||
paths,
|
||||
wait: args.wait,
|
||||
open_new_workspace,
|
||||
})?;
|
||||
|
||||
while let Ok(response) = rx.recv() {
|
||||
match response {
|
||||
CliResponse::Ping => {}
|
||||
CliResponse::Stdout { message } => println!("{message}"),
|
||||
CliResponse::Stderr { message } => eprintln!("{message}"),
|
||||
CliResponse::Exit { status } => std::process::exit(status),
|
||||
}
|
||||
while let Ok(response) = rx.recv() {
|
||||
match response {
|
||||
CliResponse::Ping => {}
|
||||
CliResponse::Stdout { message } => println!("{message}"),
|
||||
CliResponse::Stderr { message } => eprintln!("{message}"),
|
||||
CliResponse::Exit { status } => std::process::exit(status),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
if args.foreground {
|
||||
app.run_foreground(url)?;
|
||||
} else {
|
||||
app.launch(url)?;
|
||||
sender.join().unwrap()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum Bundle {
|
||||
App {
|
||||
app_bundle: PathBuf,
|
||||
plist: InfoPlist,
|
||||
},
|
||||
LocalPath {
|
||||
executable: PathBuf,
|
||||
plist: InfoPlist,
|
||||
},
|
||||
}
|
||||
|
||||
fn locate_bundle() -> Result<PathBuf> {
|
||||
let cli_path = std::env::current_exe()?.canonicalize()?;
|
||||
let mut app_path = cli_path.clone();
|
||||
while app_path.extension() != Some(OsStr::new("app")) {
|
||||
if !app_path.pop() {
|
||||
return Err(anyhow!("cannot find app bundle containing {:?}", cli_path));
|
||||
}
|
||||
}
|
||||
Ok(app_path)
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
mod linux {
|
||||
use std::{
|
||||
env,
|
||||
ffi::OsString,
|
||||
io,
|
||||
os::{
|
||||
linux::net::SocketAddrExt,
|
||||
unix::net::{SocketAddr, UnixDatagram},
|
||||
},
|
||||
path::{Path, PathBuf},
|
||||
process::{self, ExitStatus},
|
||||
thread,
|
||||
time::Duration,
|
||||
};
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
|
||||
use fork::Fork;
|
||||
use once_cell::sync::Lazy;
|
||||
use cli::{CliRequest, CliResponse};
|
||||
use ipc_channel::ipc::{IpcReceiver, IpcSender};
|
||||
|
||||
use crate::{Detect, InstalledApp};
|
||||
use crate::{Bundle, InfoPlist};
|
||||
|
||||
static RELEASE_CHANNEL: Lazy<String> =
|
||||
Lazy::new(|| include_str!("../../zed/RELEASE_CHANNEL").trim().to_string());
|
||||
|
||||
struct App(PathBuf);
|
||||
|
||||
impl Detect {
|
||||
pub fn detect(path: Option<&Path>) -> anyhow::Result<impl InstalledApp> {
|
||||
let path = if let Some(path) = path {
|
||||
path.to_path_buf().canonicalize()
|
||||
} else {
|
||||
let cli = env::current_exe()?;
|
||||
let dir = cli
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("no parent path for cli"))?;
|
||||
|
||||
match dir.join("zed").canonicalize() {
|
||||
Ok(path) => Ok(path),
|
||||
// development builds have Zed capitalized
|
||||
Err(e) => match dir.join("Zed").canonicalize() {
|
||||
Ok(path) => Ok(path),
|
||||
Err(_) => Err(e),
|
||||
},
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(App(path))
|
||||
}
|
||||
}
|
||||
|
||||
impl InstalledApp for App {
|
||||
fn zed_version_string(&self) -> String {
|
||||
format!(
|
||||
"Zed {}{} – {}",
|
||||
if *RELEASE_CHANNEL == "stable" {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!(" {} ", *RELEASE_CHANNEL)
|
||||
},
|
||||
option_env!("RELEASE_VERSION").unwrap_or_default(),
|
||||
self.0.display(),
|
||||
)
|
||||
impl Bundle {
|
||||
pub fn detect(_args_bundle_path: Option<&Path>) -> anyhow::Result<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
|
||||
let uid: u32 = unsafe { libc::getuid() };
|
||||
let sock_addr =
|
||||
SocketAddr::from_abstract_name(format!("zed-{}-{}", *RELEASE_CHANNEL, uid))?;
|
||||
|
||||
let sock = UnixDatagram::unbound()?;
|
||||
if sock.connect_addr(&sock_addr).is_err() {
|
||||
self.boot_background(ipc_url)?;
|
||||
} else {
|
||||
sock.send(ipc_url.as_bytes())?;
|
||||
}
|
||||
Ok(())
|
||||
pub fn plist(&self) -> &InfoPlist {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
|
||||
std::process::Command::new(self.0.clone())
|
||||
.arg(ipc_url)
|
||||
.status()
|
||||
}
|
||||
}
|
||||
|
||||
impl App {
|
||||
fn boot_background(&self, ipc_url: String) -> anyhow::Result<()> {
|
||||
let path = &self.0;
|
||||
|
||||
match fork::fork() {
|
||||
Ok(Fork::Parent(_)) => Ok(()),
|
||||
Ok(Fork::Child) => {
|
||||
std::env::set_var(FORCE_CLI_MODE_ENV_VAR_NAME, "");
|
||||
if let Err(_) = fork::setsid() {
|
||||
eprintln!("failed to setsid: {}", std::io::Error::last_os_error());
|
||||
process::exit(1);
|
||||
}
|
||||
if std::env::var("ZED_KEEP_FD").is_err() {
|
||||
if let Err(_) = fork::close_fd() {
|
||||
eprintln!("failed to close_fd: {}", std::io::Error::last_os_error());
|
||||
}
|
||||
}
|
||||
let error =
|
||||
exec::execvp(path.clone(), &[path.as_os_str(), &OsString::from(ipc_url)]);
|
||||
// if exec succeeded, we never get here.
|
||||
eprintln!("failed to exec {:?}: {}", path, error);
|
||||
process::exit(1)
|
||||
}
|
||||
Err(_) => Err(anyhow!(io::Error::last_os_error())),
|
||||
}
|
||||
pub fn path(&self) -> &Path {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn wait_for_socket(
|
||||
&self,
|
||||
sock_addr: &SocketAddr,
|
||||
sock: &mut UnixDatagram,
|
||||
) -> Result<(), std::io::Error> {
|
||||
for _ in 0..100 {
|
||||
thread::sleep(Duration::from_millis(10));
|
||||
if sock.connect_addr(&sock_addr).is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
sock.connect_addr(&sock_addr)
|
||||
pub fn launch(&self) -> anyhow::Result<(IpcSender<CliRequest>, IpcReceiver<CliResponse>)> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn spawn(&self, _args: Vec<String>) -> anyhow::Result<()> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn zed_version_string(&self) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -278,84 +189,59 @@ mod linux {
|
||||
// todo("windows")
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows {
|
||||
use crate::{Detect, InstalledApp};
|
||||
use std::io;
|
||||
use std::path::Path;
|
||||
use std::process::ExitStatus;
|
||||
|
||||
struct App;
|
||||
impl InstalledApp for App {
|
||||
fn zed_version_string(&self) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
fn launch(&self, _ipc_url: String) -> anyhow::Result<()> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn run_foreground(&self, _ipc_url: String) -> io::Result<ExitStatus> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
use cli::{CliRequest, CliResponse};
|
||||
use ipc_channel::ipc::{IpcReceiver, IpcSender};
|
||||
|
||||
impl Detect {
|
||||
pub fn detect(_path: Option<&Path>) -> anyhow::Result<impl InstalledApp> {
|
||||
Ok(App)
|
||||
use crate::{Bundle, InfoPlist};
|
||||
|
||||
impl Bundle {
|
||||
pub fn detect(_args_bundle_path: Option<&Path>) -> anyhow::Result<Self> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn plist(&self) -> &InfoPlist {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Path {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn launch(&self) -> anyhow::Result<(IpcSender<CliRequest>, IpcReceiver<CliResponse>)> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn spawn(&self, _args: Vec<String>) -> anyhow::Result<()> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
pub fn zed_version_string(&self) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
mod mac_os {
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use anyhow::{Context, Result};
|
||||
use core_foundation::{
|
||||
array::{CFArray, CFIndex},
|
||||
string::kCFStringEncodingUTF8,
|
||||
url::{CFURLCreateWithBytes, CFURL},
|
||||
};
|
||||
use core_services::{kLSLaunchDefaults, LSLaunchURLSpec, LSOpenFromURLSpec, TCFType};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
ffi::OsStr,
|
||||
fs, io,
|
||||
path::{Path, PathBuf},
|
||||
process::{Command, ExitStatus},
|
||||
ptr,
|
||||
};
|
||||
use std::{fs, path::Path, process::Command, ptr};
|
||||
|
||||
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
|
||||
use cli::{CliRequest, CliResponse, IpcHandshake, FORCE_CLI_MODE_ENV_VAR_NAME};
|
||||
use ipc_channel::ipc::{IpcOneShotServer, IpcReceiver, IpcSender};
|
||||
|
||||
use crate::{Detect, InstalledApp};
|
||||
use crate::{locate_bundle, Bundle, InfoPlist};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct InfoPlist {
|
||||
#[serde(rename = "CFBundleShortVersionString")]
|
||||
bundle_short_version_string: String,
|
||||
}
|
||||
|
||||
enum Bundle {
|
||||
App {
|
||||
app_bundle: PathBuf,
|
||||
plist: InfoPlist,
|
||||
},
|
||||
LocalPath {
|
||||
executable: PathBuf,
|
||||
plist: InfoPlist,
|
||||
},
|
||||
}
|
||||
|
||||
fn locate_bundle() -> Result<PathBuf> {
|
||||
let cli_path = std::env::current_exe()?.canonicalize()?;
|
||||
let mut app_path = cli_path.clone();
|
||||
while app_path.extension() != Some(OsStr::new("app")) {
|
||||
if !app_path.pop() {
|
||||
return Err(anyhow!("cannot find app bundle containing {:?}", cli_path));
|
||||
}
|
||||
}
|
||||
Ok(app_path)
|
||||
}
|
||||
|
||||
impl Detect {
|
||||
pub fn detect(path: Option<&Path>) -> anyhow::Result<impl InstalledApp> {
|
||||
let bundle_path = if let Some(bundle_path) = path {
|
||||
impl Bundle {
|
||||
pub fn detect(args_bundle_path: Option<&Path>) -> anyhow::Result<Self> {
|
||||
let bundle_path = if let Some(bundle_path) = args_bundle_path {
|
||||
bundle_path
|
||||
.canonicalize()
|
||||
.with_context(|| format!("Args bundle path {bundle_path:?} canonicalization"))?
|
||||
@@ -370,7 +256,7 @@ mod mac_os {
|
||||
plist::from_file::<_, InfoPlist>(&plist_path).with_context(|| {
|
||||
format!("Reading *.app bundle plist file at {plist_path:?}")
|
||||
})?;
|
||||
Ok(Bundle::App {
|
||||
Ok(Self::App {
|
||||
app_bundle: bundle_path,
|
||||
plist,
|
||||
})
|
||||
@@ -385,27 +271,42 @@ mod mac_os {
|
||||
plist::from_file::<_, InfoPlist>(&plist_path).with_context(|| {
|
||||
format!("Reading dev bundle plist file at {plist_path:?}")
|
||||
})?;
|
||||
Ok(Bundle::LocalPath {
|
||||
Ok(Self::LocalPath {
|
||||
executable: bundle_path,
|
||||
plist,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InstalledApp for Bundle {
|
||||
fn zed_version_string(&self) -> String {
|
||||
let is_dev = matches!(self, Self::LocalPath { .. });
|
||||
format!(
|
||||
"Zed {}{} – {}",
|
||||
self.plist().bundle_short_version_string,
|
||||
if is_dev { " (dev)" } else { "" },
|
||||
self.path().display(),
|
||||
)
|
||||
fn plist(&self) -> &InfoPlist {
|
||||
match self {
|
||||
Self::App { plist, .. } => plist,
|
||||
Self::LocalPath { plist, .. } => plist,
|
||||
}
|
||||
}
|
||||
|
||||
fn launch(&self, url: String) -> anyhow::Result<()> {
|
||||
fn path(&self) -> &Path {
|
||||
match self {
|
||||
Self::App { app_bundle, .. } => app_bundle,
|
||||
Self::LocalPath { executable, .. } => executable,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn(&self, args: Vec<String>) -> Result<()> {
|
||||
let path = match self {
|
||||
Self::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed"),
|
||||
Self::LocalPath { executable, .. } => executable.clone(),
|
||||
};
|
||||
Command::new(path).args(args).status()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn launch(&self) -> anyhow::Result<(IpcSender<CliRequest>, IpcReceiver<CliResponse>)> {
|
||||
let (server, server_name) =
|
||||
IpcOneShotServer::<IpcHandshake>::new().context("Handshake before Zed spawn")?;
|
||||
let url = format!("zed-cli://{server_name}");
|
||||
|
||||
match self {
|
||||
Self::App { app_bundle, .. } => {
|
||||
let app_path = app_bundle;
|
||||
@@ -467,32 +368,18 @@ mod mac_os {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
let (_, handshake) = server.accept().context("Handshake after Zed spawn")?;
|
||||
Ok((handshake.requests, handshake.responses))
|
||||
}
|
||||
|
||||
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
|
||||
let path = match self {
|
||||
Bundle::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed"),
|
||||
Bundle::LocalPath { executable, .. } => executable.clone(),
|
||||
};
|
||||
|
||||
std::process::Command::new(path).arg(ipc_url).status()
|
||||
}
|
||||
}
|
||||
|
||||
impl Bundle {
|
||||
fn plist(&self) -> &InfoPlist {
|
||||
match self {
|
||||
Self::App { plist, .. } => plist,
|
||||
Self::LocalPath { plist, .. } => plist,
|
||||
}
|
||||
}
|
||||
|
||||
fn path(&self) -> &Path {
|
||||
match self {
|
||||
Self::App { app_bundle, .. } => app_bundle,
|
||||
Self::LocalPath { executable, .. } => executable,
|
||||
}
|
||||
pub fn zed_version_string(&self) -> String {
|
||||
let is_dev = matches!(self, Self::LocalPath { .. });
|
||||
format!(
|
||||
"Zed {}{} – {}",
|
||||
self.plist().bundle_short_version_string,
|
||||
if is_dev { " (dev)" } else { "" },
|
||||
self.path().display(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,17 +19,15 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { version = "0.16", features = ["async-std", "async-native-tls"] }
|
||||
async-native-tls = { version = "0.5.0", features = ["vendored"] }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http.workspace = true
|
||||
lazy_static.workspace = true
|
||||
log.workspace = true
|
||||
once_cell.workspace = true
|
||||
once_cell = "1.19.0"
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
@@ -58,11 +56,3 @@ gpui = { workspace = true, features = ["test-support"] }
|
||||
rpc = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
http = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
async-native-tls = {"version" = "0.5.0", features = ["vendored"]}
|
||||
# This is an indirect dependency of async-tungstenite that is included
|
||||
# here so we can vendor libssl with the feature flag.
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["async-native-tls"]
|
||||
|
||||
@@ -17,9 +17,9 @@ use futures::{
|
||||
TryFutureExt as _, TryStreamExt,
|
||||
};
|
||||
use gpui::{
|
||||
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Global, Model, Task, WeakModel,
|
||||
actions, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, BorrowAppContext, Global, Model,
|
||||
Task, WeakModel,
|
||||
};
|
||||
use http::{HttpClient, HttpClientWithUrl};
|
||||
use lazy_static::lazy_static;
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
@@ -28,7 +28,7 @@ use release_channel::{AppVersion, ReleaseChannel};
|
||||
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use settings::{Settings, SettingsSources, SettingsStore};
|
||||
use std::fmt;
|
||||
use std::pin::Pin;
|
||||
use std::{
|
||||
@@ -47,6 +47,7 @@ use std::{
|
||||
use telemetry::Telemetry;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
use util::http::{HttpClient, HttpClientWithUrl};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
|
||||
pub use rpc::*;
|
||||
@@ -85,7 +86,7 @@ lazy_static! {
|
||||
}
|
||||
|
||||
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100);
|
||||
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
actions!(client, [SignIn, SignOut, Reconnect]);
|
||||
|
||||
@@ -113,35 +114,11 @@ impl Settings for ClientSettings {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ProxySettingsContent {
|
||||
proxy: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
pub struct ProxySettings {
|
||||
pub proxy: Option<String>,
|
||||
}
|
||||
|
||||
impl Settings for ProxySettings {
|
||||
const KEY: Option<&'static str> = None;
|
||||
|
||||
type FileContent = ProxySettingsContent;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
|
||||
Ok(Self {
|
||||
proxy: sources
|
||||
.user
|
||||
.and_then(|value| value.proxy.clone())
|
||||
.or(sources.default.proxy.clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_settings(cx: &mut AppContext) {
|
||||
TelemetrySettings::register(cx);
|
||||
ClientSettings::register(cx);
|
||||
ProxySettings::register(cx);
|
||||
cx.update_global(|store: &mut SettingsStore, cx| {
|
||||
store.register_setting::<ClientSettings>(cx);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn init(client: &Arc<Client>, cx: &mut AppContext) {
|
||||
@@ -227,7 +204,7 @@ pub enum EstablishConnectionError {
|
||||
#[error("{0}")]
|
||||
Other(#[from] anyhow::Error),
|
||||
#[error("{0}")]
|
||||
Http(#[from] http::Error),
|
||||
Http(#[from] util::http::Error),
|
||||
#[error("{0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
@@ -535,7 +512,6 @@ impl Client {
|
||||
let clock = Arc::new(clock::RealSystemClock);
|
||||
let http = Arc::new(HttpClientWithUrl::new(
|
||||
&ClientSettings::get_global(cx).server_url,
|
||||
ProxySettings::get_global(cx).proxy.clone(),
|
||||
));
|
||||
Self::new(clock, http.clone(), cx)
|
||||
}
|
||||
@@ -1703,10 +1679,10 @@ mod tests {
|
||||
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::{BackgroundExecutor, Context, TestAppContext};
|
||||
use http::FakeHttpClient;
|
||||
use parking_lot::Mutex;
|
||||
use settings::SettingsStore;
|
||||
use std::future;
|
||||
use util::http::FakeHttpClient;
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_reconnection(cx: &mut TestAppContext) {
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -5,7 +5,6 @@ use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use futures::Future;
|
||||
use gpui::{AppContext, AppMetadata, BackgroundExecutor, Task};
|
||||
use http::{self, HttpClient, HttpClientWithUrl, Method};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
use release_channel::ReleaseChannel;
|
||||
@@ -15,11 +14,12 @@ use std::io::Write;
|
||||
use std::{env, mem, path::PathBuf, sync::Arc, time::Duration};
|
||||
use sysinfo::{CpuRefreshKind, Pid, ProcessRefreshKind, RefreshKind, System};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CpuEvent, EditEvent,
|
||||
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent,
|
||||
MemoryEvent, SettingEvent,
|
||||
ActionEvent, AppEvent, AssistantEvent, AssistantKind, CallEvent, CopilotEvent, CpuEvent,
|
||||
EditEvent, EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent,
|
||||
SettingEvent,
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
use util::http::{self, HttpClient, HttpClientWithUrl, Method};
|
||||
#[cfg(not(debug_assertions))]
|
||||
use util::ResultExt;
|
||||
use util::TryFutureExt;
|
||||
@@ -241,14 +241,14 @@ impl Telemetry {
|
||||
self.report_event(event)
|
||||
}
|
||||
|
||||
pub fn report_inline_completion_event(
|
||||
pub fn report_copilot_event(
|
||||
self: &Arc<Self>,
|
||||
provider: String,
|
||||
suggestion_id: Option<String>,
|
||||
suggestion_accepted: bool,
|
||||
file_extension: Option<String>,
|
||||
) {
|
||||
let event = Event::InlineCompletion(InlineCompletionEvent {
|
||||
provider,
|
||||
let event = Event::Copilot(CopilotEvent {
|
||||
suggestion_id,
|
||||
suggestion_accepted,
|
||||
file_extension,
|
||||
});
|
||||
@@ -261,15 +261,11 @@ impl Telemetry {
|
||||
conversation_id: Option<String>,
|
||||
kind: AssistantKind,
|
||||
model: String,
|
||||
response_latency: Option<Duration>,
|
||||
error_message: Option<String>,
|
||||
) {
|
||||
let event = Event::Assistant(AssistantEvent {
|
||||
conversation_id,
|
||||
kind,
|
||||
model: model.to_string(),
|
||||
response_latency,
|
||||
error_message,
|
||||
});
|
||||
|
||||
self.report_event(event)
|
||||
@@ -501,7 +497,7 @@ mod tests {
|
||||
use chrono::TimeZone;
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::TestAppContext;
|
||||
use http::FakeHttpClient;
|
||||
use util::http::FakeHttpClient;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_telemetry_flush_on_max_queue_size(cx: &mut TestAppContext) {
|
||||
|
||||
@@ -89,7 +89,6 @@ pub enum ContactRequestStatus {
|
||||
|
||||
pub struct UserStore {
|
||||
users: HashMap<u64, Arc<User>>,
|
||||
by_github_login: HashMap<String, u64>,
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_user: watch::Receiver<Option<Arc<User>>>,
|
||||
@@ -145,7 +144,6 @@ impl UserStore {
|
||||
];
|
||||
Self {
|
||||
users: Default::default(),
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
contacts: Default::default(),
|
||||
incoming_contact_requests: Default::default(),
|
||||
@@ -233,7 +231,6 @@ impl UserStore {
|
||||
#[cfg(feature = "test-support")]
|
||||
pub fn clear_cache(&mut self) {
|
||||
self.users.clear();
|
||||
self.by_github_login.clear();
|
||||
}
|
||||
|
||||
async fn handle_update_invite_info(
|
||||
@@ -647,12 +644,6 @@ impl UserStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cached_user_by_github_login(&self, github_login: &str) -> Option<Arc<User>> {
|
||||
self.by_github_login
|
||||
.get(github_login)
|
||||
.and_then(|id| self.users.get(id).cloned())
|
||||
}
|
||||
|
||||
pub fn current_user(&self) -> Option<Arc<User>> {
|
||||
self.current_user.borrow().clone()
|
||||
}
|
||||
@@ -670,31 +661,26 @@ impl UserStore {
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
if let Some(rpc) = client.upgrade() {
|
||||
let response = rpc.request(request).await.context("error loading users")?;
|
||||
let users = response.users;
|
||||
let users = response
|
||||
.users
|
||||
.into_iter()
|
||||
.map(User::new)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update(&mut cx, |this, _| this.insert(users))
|
||||
this.update(&mut cx, |this, _| {
|
||||
for user in &users {
|
||||
this.users.insert(user.id, user.clone());
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
|
||||
Ok(users)
|
||||
} else {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, users: Vec<proto::User>) -> Vec<Arc<User>> {
|
||||
let mut ret = Vec::with_capacity(users.len());
|
||||
for user in users {
|
||||
let user = User::new(user);
|
||||
if let Some(old) = self.users.insert(user.id, user.clone()) {
|
||||
if old.github_login != user.github_login {
|
||||
self.by_github_login.remove(&old.github_login);
|
||||
}
|
||||
}
|
||||
self.by_github_login
|
||||
.insert(user.github_login.clone(), user.id);
|
||||
ret.push(user)
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
pub fn set_participant_indices(
|
||||
&mut self,
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
|
||||
@@ -35,7 +35,6 @@ envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
http.workspace = true
|
||||
live_kit_server.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
@@ -91,7 +90,6 @@ language = { workspace = true, features = ["test-support"] }
|
||||
live_kit_client = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
menu.workspace = true
|
||||
multi_buffer = { workspace = true, features = ["test-support"] }
|
||||
node_runtime.workspace = true
|
||||
notifications = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
|
||||
@@ -407,7 +407,6 @@ CREATE TABLE dev_servers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
name TEXT NOT NULL,
|
||||
ssh_connection_string TEXT,
|
||||
hashed_token TEXT NOT NULL
|
||||
);
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
ALTER TABLE dev_servers ADD COLUMN ssh_connection_string TEXT;
|
||||
@@ -15,9 +15,8 @@ use serde::{Serialize, Serializer};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use telemetry_events::{
|
||||
ActionEvent, AppEvent, AssistantEvent, CallEvent, CpuEvent, EditEvent, EditorEvent, Event,
|
||||
EventRequestBody, EventWrapper, ExtensionEvent, InlineCompletionEvent, MemoryEvent,
|
||||
SettingEvent,
|
||||
ActionEvent, AppEvent, AssistantEvent, CallEvent, CopilotEvent, CpuEvent, EditEvent,
|
||||
EditorEvent, Event, EventRequestBody, EventWrapper, ExtensionEvent, MemoryEvent, SettingEvent,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -27,7 +26,6 @@ pub fn router() -> Router {
|
||||
Router::new()
|
||||
.route("/telemetry/events", post(post_events))
|
||||
.route("/telemetry/crashes", post(post_crash))
|
||||
.route("/telemetry/panics", post(post_panic))
|
||||
.route("/telemetry/hangs", post(post_hang))
|
||||
}
|
||||
|
||||
@@ -282,77 +280,30 @@ pub async fn post_hang(
|
||||
backtrace = %backtrace,
|
||||
"hang report");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn post_panic(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
TypedHeader(ZedChecksumHeader(checksum)): TypedHeader<ZedChecksumHeader>,
|
||||
body: Bytes,
|
||||
) -> Result<()> {
|
||||
let Some(expected) = calculate_json_checksum(app.clone(), &body) else {
|
||||
return Err(Error::Http(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"events not enabled".into(),
|
||||
))?;
|
||||
};
|
||||
|
||||
if checksum != expected {
|
||||
return Err(Error::Http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"invalid checksum".into(),
|
||||
))?;
|
||||
}
|
||||
|
||||
let report: telemetry_events::PanicRequest = serde_json::from_slice(&body)
|
||||
.map_err(|_| Error::Http(StatusCode::BAD_REQUEST, "invalid json".into()))?;
|
||||
let panic = report.panic;
|
||||
|
||||
tracing::error!(
|
||||
service = "client",
|
||||
version = %panic.app_version,
|
||||
os_name = %panic.os_name,
|
||||
os_version = %panic.os_version.clone().unwrap_or_default(),
|
||||
installation_id = %panic.installation_id.unwrap_or_default(),
|
||||
description = %panic.payload,
|
||||
backtrace = %panic.backtrace.join("\n"),
|
||||
"panic report");
|
||||
|
||||
let backtrace = if panic.backtrace.len() > 25 {
|
||||
let total = panic.backtrace.len();
|
||||
format!(
|
||||
"{}\n and {} more",
|
||||
panic
|
||||
.backtrace
|
||||
.iter()
|
||||
.take(20)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
total - 20
|
||||
)
|
||||
} else {
|
||||
panic.backtrace.join("\n")
|
||||
};
|
||||
let backtrace_with_summary = panic.payload + "\n" + &backtrace;
|
||||
|
||||
if let Some(slack_panics_webhook) = app.config.slack_panics_webhook.clone() {
|
||||
let payload = slack::WebhookBody::new(|w| {
|
||||
w.add_section(|s| s.text(slack::Text::markdown("Panic request".to_string())))
|
||||
w.add_section(|s| s.text(slack::Text::markdown("Possible Hang".to_string())))
|
||||
.add_section(|s| {
|
||||
s.add_field(slack::Text::markdown(format!(
|
||||
"*Version:*\n {} ",
|
||||
panic.app_version
|
||||
report.app_version.unwrap_or_default()
|
||||
)))
|
||||
.add_field({
|
||||
let hostname = app.config.blob_store_url.clone().unwrap_or_default();
|
||||
let hostname = hostname.strip_prefix("https://").unwrap_or_else(|| {
|
||||
hostname.strip_prefix("http://").unwrap_or_default()
|
||||
});
|
||||
|
||||
slack::Text::markdown(format!(
|
||||
"*OS:*\n{} {}",
|
||||
panic.os_name,
|
||||
panic.os_version.unwrap_or_default()
|
||||
"*Incident:*\n<https://{}.{}/{}.hang.json|{}…>",
|
||||
CRASH_REPORTS_BUCKET,
|
||||
hostname,
|
||||
incident_id,
|
||||
incident_id.chars().take(8).collect::<String>(),
|
||||
))
|
||||
})
|
||||
})
|
||||
.add_rich_text(|r| r.add_preformatted(|p| p.add_text(backtrace_with_summary)))
|
||||
.add_rich_text(|r| r.add_preformatted(|p| p.add_text(backtrace)))
|
||||
});
|
||||
let payload_json = serde_json::to_string(&payload).map_err(|err| {
|
||||
log::error!("Failed to serialize payload to JSON: {err}");
|
||||
@@ -425,19 +376,13 @@ pub async fn post_events(
|
||||
first_event_at,
|
||||
country_code.clone(),
|
||||
)),
|
||||
// Needed for clients sending old copilot_event types
|
||||
Event::Copilot(_) => {}
|
||||
Event::InlineCompletion(event) => {
|
||||
to_upload
|
||||
.inline_completion_events
|
||||
.push(InlineCompletionEventRow::from_event(
|
||||
event.clone(),
|
||||
&wrapper,
|
||||
&request_body,
|
||||
first_event_at,
|
||||
country_code.clone(),
|
||||
))
|
||||
}
|
||||
Event::Copilot(event) => to_upload.copilot_events.push(CopilotEventRow::from_event(
|
||||
event.clone(),
|
||||
&wrapper,
|
||||
&request_body,
|
||||
first_event_at,
|
||||
country_code.clone(),
|
||||
)),
|
||||
Event::Call(event) => to_upload.call_events.push(CallEventRow::from_event(
|
||||
event.clone(),
|
||||
&wrapper,
|
||||
@@ -519,7 +464,7 @@ pub async fn post_events(
|
||||
#[derive(Default)]
|
||||
struct ToUpload {
|
||||
editor_events: Vec<EditorEventRow>,
|
||||
inline_completion_events: Vec<InlineCompletionEventRow>,
|
||||
copilot_events: Vec<CopilotEventRow>,
|
||||
assistant_events: Vec<AssistantEventRow>,
|
||||
call_events: Vec<CallEventRow>,
|
||||
cpu_events: Vec<CpuEventRow>,
|
||||
@@ -538,14 +483,14 @@ impl ToUpload {
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{EDITOR_EVENTS_TABLE}'"))?;
|
||||
|
||||
const INLINE_COMPLETION_EVENTS_TABLE: &str = "inline_completion_events";
|
||||
const COPILOT_EVENTS_TABLE: &str = "copilot_events";
|
||||
Self::upload_to_table(
|
||||
INLINE_COMPLETION_EVENTS_TABLE,
|
||||
&self.inline_completion_events,
|
||||
COPILOT_EVENTS_TABLE,
|
||||
&self.copilot_events,
|
||||
clickhouse_client,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("failed to upload to table '{INLINE_COMPLETION_EVENTS_TABLE}'"))?;
|
||||
.with_context(|| format!("failed to upload to table '{COPILOT_EVENTS_TABLE}'"))?;
|
||||
|
||||
const ASSISTANT_EVENTS_TABLE: &str = "assistant_events";
|
||||
Self::upload_to_table(
|
||||
@@ -645,7 +590,7 @@ where
|
||||
|
||||
let country_code = country_code.as_bytes();
|
||||
|
||||
serializer.serialize_u16(((country_code[1] as u16) << 8) + country_code[0] as u16)
|
||||
serializer.serialize_u16(((country_code[0] as u16) << 8) + country_code[1] as u16)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
@@ -715,9 +660,9 @@ impl EditorEventRow {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, clickhouse::Row)]
|
||||
pub struct InlineCompletionEventRow {
|
||||
pub struct CopilotEventRow {
|
||||
pub installation_id: String,
|
||||
pub provider: String,
|
||||
pub suggestion_id: String,
|
||||
pub suggestion_accepted: bool,
|
||||
pub app_version: String,
|
||||
pub file_extension: String,
|
||||
@@ -737,9 +682,9 @@ pub struct InlineCompletionEventRow {
|
||||
pub patch: Option<i32>,
|
||||
}
|
||||
|
||||
impl InlineCompletionEventRow {
|
||||
impl CopilotEventRow {
|
||||
fn from_event(
|
||||
event: InlineCompletionEvent,
|
||||
event: CopilotEvent,
|
||||
wrapper: &EventWrapper,
|
||||
body: &EventRequestBody,
|
||||
first_event_at: chrono::DateTime<chrono::Utc>,
|
||||
@@ -766,7 +711,7 @@ impl InlineCompletionEventRow {
|
||||
country_code: country_code.unwrap_or("XX".to_string()),
|
||||
region_code: "".to_string(),
|
||||
city: "".to_string(),
|
||||
provider: event.provider,
|
||||
suggestion_id: event.suggestion_id.unwrap_or_default(),
|
||||
suggestion_accepted: event.suggestion_accepted,
|
||||
}
|
||||
}
|
||||
@@ -840,8 +785,6 @@ pub struct AssistantEventRow {
|
||||
conversation_id: String,
|
||||
kind: String,
|
||||
model: String,
|
||||
response_latency_in_ms: Option<i64>,
|
||||
error_message: Option<String>,
|
||||
}
|
||||
|
||||
impl AssistantEventRow {
|
||||
@@ -868,10 +811,6 @@ impl AssistantEventRow {
|
||||
conversation_id: event.conversation_id.unwrap_or_default(),
|
||||
kind: event.kind.to_string(),
|
||||
model: event.model,
|
||||
response_latency_in_ms: event
|
||||
.response_latency
|
||||
.map(|latency| latency.as_millis() as i64),
|
||||
error_message: event.error_message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,7 +415,7 @@ impl Database {
|
||||
if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
|
||||
let base_delay = SLEEPS[prev_attempt_count];
|
||||
let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
|
||||
log::warn!(
|
||||
log::info!(
|
||||
"retrying transaction after serialization error. delay: {} ms.",
|
||||
randomized_delay
|
||||
);
|
||||
@@ -509,7 +509,8 @@ pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
|
||||
|
||||
pub struct CreatedChannelMessage {
|
||||
pub message_id: MessageId,
|
||||
pub participant_connection_ids: HashSet<ConnectionId>,
|
||||
pub participant_connection_ids: Vec<ConnectionId>,
|
||||
pub channel_members: Vec<UserId>,
|
||||
pub notifications: NotificationBatch,
|
||||
}
|
||||
|
||||
|
||||
@@ -440,7 +440,12 @@ impl Database {
|
||||
channel_id: ChannelId,
|
||||
user: UserId,
|
||||
operations: &[proto::Operation],
|
||||
) -> Result<(HashSet<ConnectionId>, i32, Vec<proto::VectorClockEntry>)> {
|
||||
) -> Result<(
|
||||
Vec<ConnectionId>,
|
||||
Vec<UserId>,
|
||||
i32,
|
||||
Vec<proto::VectorClockEntry>,
|
||||
)> {
|
||||
self.transaction(move |tx| async move {
|
||||
let channel = self.get_channel_internal(channel_id, &tx).await?;
|
||||
|
||||
@@ -474,6 +479,7 @@ impl Database {
|
||||
.filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut channel_members;
|
||||
let max_version;
|
||||
|
||||
if !operations.is_empty() {
|
||||
@@ -498,6 +504,12 @@ impl Database {
|
||||
)
|
||||
.await?;
|
||||
|
||||
channel_members = self.get_channel_participants(&channel, &tx).await?;
|
||||
let collaborators = self
|
||||
.get_channel_buffer_collaborators_internal(channel_id, &tx)
|
||||
.await?;
|
||||
channel_members.retain(|member| !collaborators.contains(member));
|
||||
|
||||
buffer_operation::Entity::insert_many(operations)
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
@@ -512,10 +524,11 @@ impl Database {
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
} else {
|
||||
channel_members = Vec::new();
|
||||
max_version = Vec::new();
|
||||
}
|
||||
|
||||
let mut connections = HashSet::default();
|
||||
let mut connections = Vec::new();
|
||||
let mut rows = channel_buffer_collaborator::Entity::find()
|
||||
.filter(
|
||||
Condition::all()
|
||||
@@ -525,13 +538,13 @@ impl Database {
|
||||
.await?;
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
connections.insert(ConnectionId {
|
||||
connections.push(ConnectionId {
|
||||
id: row.connection_id as u32,
|
||||
owner_id: row.connection_server_id.0 as u32,
|
||||
});
|
||||
}
|
||||
|
||||
Ok((connections, buffer.epoch, max_version))
|
||||
Ok((connections, channel_members, buffer.epoch, max_version))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use rpc::{
|
||||
proto::{channel_member::Kind, ChannelBufferVersion, VectorClockEntry},
|
||||
ErrorCode, ErrorCodeExt,
|
||||
};
|
||||
use sea_orm::{DbBackend, TryGetableMany};
|
||||
use sea_orm::TryGetableMany;
|
||||
|
||||
impl Database {
|
||||
#[cfg(test)]
|
||||
@@ -700,73 +700,77 @@ impl Database {
|
||||
pub async fn get_channel_participant_details(
|
||||
&self,
|
||||
channel_id: ChannelId,
|
||||
filter: &str,
|
||||
limit: u64,
|
||||
user_id: UserId,
|
||||
) -> Result<(Vec<proto::ChannelMember>, Vec<proto::User>)> {
|
||||
let members = self
|
||||
) -> Result<Vec<proto::ChannelMember>> {
|
||||
let (role, members) = self
|
||||
.transaction(move |tx| async move {
|
||||
let channel = self.get_channel_internal(channel_id, &tx).await?;
|
||||
self.check_user_is_channel_participant(&channel, user_id, &tx)
|
||||
let role = self
|
||||
.check_user_is_channel_participant(&channel, user_id, &tx)
|
||||
.await?;
|
||||
let mut query = channel_member::Entity::find()
|
||||
.find_also_related(user::Entity)
|
||||
.filter(channel_member::Column::ChannelId.eq(channel.root_id()));
|
||||
|
||||
if cfg!(any(test, sqlite)) && self.pool.get_database_backend() == DbBackend::Sqlite {
|
||||
query = query.filter(Expr::cust_with_values(
|
||||
"UPPER(github_login) LIKE ?",
|
||||
[Self::fuzzy_like_string(&filter.to_uppercase())],
|
||||
))
|
||||
} else {
|
||||
query = query.filter(Expr::cust_with_values(
|
||||
"github_login ILIKE $1",
|
||||
[Self::fuzzy_like_string(filter)],
|
||||
))
|
||||
}
|
||||
let members = query.order_by(
|
||||
Expr::cust(
|
||||
"not role = 'admin', not role = 'member', not role = 'guest', not accepted, github_login",
|
||||
),
|
||||
sea_orm::Order::Asc,
|
||||
)
|
||||
.limit(limit)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(members)
|
||||
Ok((
|
||||
role,
|
||||
self.get_channel_participant_details_internal(&channel, &tx)
|
||||
.await?,
|
||||
))
|
||||
})
|
||||
.await?;
|
||||
|
||||
let mut users: Vec<proto::User> = Vec::with_capacity(members.len());
|
||||
|
||||
let members = members
|
||||
.into_iter()
|
||||
.map(|(member, user)| {
|
||||
if let Some(user) = user {
|
||||
users.push(proto::User {
|
||||
id: user.id.to_proto(),
|
||||
avatar_url: format!(
|
||||
"https://github.com/{}.png?size=128",
|
||||
user.github_login
|
||||
),
|
||||
github_login: user.github_login,
|
||||
})
|
||||
}
|
||||
proto::ChannelMember {
|
||||
role: member.role.into(),
|
||||
user_id: member.user_id.to_proto(),
|
||||
kind: if member.accepted {
|
||||
if role == ChannelRole::Admin {
|
||||
Ok(members
|
||||
.into_iter()
|
||||
.map(|channel_member| proto::ChannelMember {
|
||||
role: channel_member.role.into(),
|
||||
user_id: channel_member.user_id.to_proto(),
|
||||
kind: if channel_member.accepted {
|
||||
Kind::Member
|
||||
} else {
|
||||
Kind::Invitee
|
||||
}
|
||||
.into(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
})
|
||||
.collect())
|
||||
} else {
|
||||
return Ok(members
|
||||
.into_iter()
|
||||
.filter_map(|member| {
|
||||
if !member.accepted {
|
||||
return None;
|
||||
}
|
||||
Some(proto::ChannelMember {
|
||||
role: member.role.into(),
|
||||
user_id: member.user_id.to_proto(),
|
||||
kind: Kind::Member.into(),
|
||||
})
|
||||
})
|
||||
.collect());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((members, users))
|
||||
async fn get_channel_participant_details_internal(
|
||||
&self,
|
||||
channel: &channel::Model,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<Vec<channel_member::Model>> {
|
||||
Ok(channel_member::Entity::find()
|
||||
.filter(channel_member::Column::ChannelId.eq(channel.root_id()))
|
||||
.all(tx)
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Returns the participants in the given channel.
|
||||
pub async fn get_channel_participants(
|
||||
&self,
|
||||
channel: &channel::Model,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<Vec<UserId>> {
|
||||
let participants = self
|
||||
.get_channel_participant_details_internal(channel, tx)
|
||||
.await?;
|
||||
Ok(participants
|
||||
.into_iter()
|
||||
.map(|member| member.user_id)
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Returns whether the given user is an admin in the specified channel.
|
||||
|
||||
@@ -73,7 +73,6 @@ impl Database {
|
||||
pub async fn create_dev_server(
|
||||
&self,
|
||||
name: &str,
|
||||
ssh_connection_string: Option<&str>,
|
||||
hashed_access_token: &str,
|
||||
user_id: UserId,
|
||||
) -> crate::Result<(dev_server::Model, proto::DevServerProjectsUpdate)> {
|
||||
@@ -87,9 +86,6 @@ impl Database {
|
||||
hashed_token: ActiveValue::Set(hashed_access_token.to_string()),
|
||||
name: ActiveValue::Set(name.trim().to_string()),
|
||||
user_id: ActiveValue::Set(user_id),
|
||||
ssh_connection_string: ActiveValue::Set(
|
||||
ssh_connection_string.map(ToOwned::to_owned),
|
||||
),
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
@@ -137,7 +133,6 @@ impl Database {
|
||||
&self,
|
||||
id: DevServerId,
|
||||
name: &str,
|
||||
ssh_connection_string: Option<&str>,
|
||||
user_id: UserId,
|
||||
) -> crate::Result<proto::DevServerProjectsUpdate> {
|
||||
self.transaction(|tx| async move {
|
||||
@@ -150,9 +145,6 @@ impl Database {
|
||||
|
||||
dev_server::Entity::update(dev_server::ActiveModel {
|
||||
name: ActiveValue::Set(name.trim().to_string()),
|
||||
ssh_connection_string: ActiveValue::Set(
|
||||
ssh_connection_string.map(ToOwned::to_owned),
|
||||
),
|
||||
..dev_server.clone().into_active_model()
|
||||
})
|
||||
.exec(&*tx)
|
||||
|
||||
@@ -251,7 +251,7 @@ impl Database {
|
||||
.await?;
|
||||
|
||||
let mut is_participant = false;
|
||||
let mut participant_connection_ids = HashSet::default();
|
||||
let mut participant_connection_ids = Vec::new();
|
||||
let mut participant_user_ids = Vec::new();
|
||||
while let Some(row) = rows.next().await {
|
||||
let row = row?;
|
||||
@@ -259,7 +259,7 @@ impl Database {
|
||||
is_participant = true;
|
||||
}
|
||||
participant_user_ids.push(row.user_id);
|
||||
participant_connection_ids.insert(row.connection());
|
||||
participant_connection_ids.push(row.connection());
|
||||
}
|
||||
drop(rows);
|
||||
|
||||
@@ -336,9 +336,13 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
let mut channel_members = self.get_channel_participants(&channel, &tx).await?;
|
||||
channel_members.retain(|member| !participant_user_ids.contains(member));
|
||||
|
||||
Ok(CreatedChannelMessage {
|
||||
message_id,
|
||||
participant_connection_ids,
|
||||
channel_members,
|
||||
notifications,
|
||||
})
|
||||
})
|
||||
|
||||
@@ -130,21 +130,13 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_project(&self, project_id: ProjectId) -> Result<()> {
|
||||
self.weak_transaction(|tx| async move {
|
||||
project::Entity::delete_by_id(project_id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Unshares the given project.
|
||||
pub async fn unshare_project(
|
||||
&self,
|
||||
project_id: ProjectId,
|
||||
connection: ConnectionId,
|
||||
user_id: Option<UserId>,
|
||||
) -> Result<TransactionGuard<(bool, Option<proto::Room>, Vec<ConnectionId>)>> {
|
||||
) -> Result<TransactionGuard<(Option<proto::Room>, Vec<ConnectionId>)>> {
|
||||
self.project_transaction(project_id, |tx| async move {
|
||||
let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
|
||||
let project = project::Entity::find_by_id(project_id)
|
||||
@@ -157,7 +149,10 @@ impl Database {
|
||||
None
|
||||
};
|
||||
if project.host_connection()? == connection {
|
||||
return Ok((true, room, guest_connection_ids));
|
||||
project::Entity::delete(project.into_active_model())
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
return Ok((room, guest_connection_ids));
|
||||
}
|
||||
if let Some(dev_server_project_id) = project.dev_server_project_id {
|
||||
if let Some(user_id) = user_id {
|
||||
@@ -174,7 +169,7 @@ impl Database {
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
return Ok((false, room, guest_connection_ids));
|
||||
return Ok((room, guest_connection_ids));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -48,9 +48,6 @@ impl Database {
|
||||
|
||||
/// Returns all users by ID. There are no access checks here, so this should only be used internally.
|
||||
pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
|
||||
if ids.len() >= 10000_usize {
|
||||
return Err(anyhow!("too many users"))?;
|
||||
}
|
||||
self.transaction(|tx| async {
|
||||
let tx = tx;
|
||||
Ok(user::Entity::find()
|
||||
|
||||
@@ -10,7 +10,6 @@ pub struct Model {
|
||||
pub name: String,
|
||||
pub user_id: UserId,
|
||||
pub hashed_token: String,
|
||||
pub ssh_connection_string: Option<String>,
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -33,7 +32,6 @@ impl Model {
|
||||
dev_server_id: self.id.to_proto(),
|
||||
name: self.name.clone(),
|
||||
status: status as i32,
|
||||
ssh_connection_string: self.ssh_connection_string.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
db::{
|
||||
tests::{channel_tree, new_test_connection, new_test_user},
|
||||
Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
|
||||
Channel, ChannelId, ChannelRole, Database, NewUserParams, RoomId,
|
||||
},
|
||||
test_both_dbs,
|
||||
};
|
||||
@@ -40,15 +40,15 @@ async fn test_channels(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (members, _) = db
|
||||
.get_channel_participant_details(replace_id, "", 10, a_id)
|
||||
let mut members = db
|
||||
.transaction(|tx| async move {
|
||||
let channel = db.get_channel_internal(replace_id, &tx).await?;
|
||||
db.get_channel_participants(&channel, &tx).await
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let ids = members
|
||||
.into_iter()
|
||||
.map(|m| UserId::from_proto(m.user_id))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(ids, &[a_id, b_id]);
|
||||
members.sort();
|
||||
assert_eq!(members, &[a_id, b_id]);
|
||||
|
||||
let rust_id = db.create_root_channel("rust", a_id).await.unwrap();
|
||||
let cargo_id = db.create_sub_channel("cargo", rust_id, a_id).await.unwrap();
|
||||
@@ -195,8 +195,8 @@ async fn test_channel_invites(db: &Arc<Database>) {
|
||||
|
||||
assert_eq!(user_3_invites, &[channel_1_1]);
|
||||
|
||||
let (mut members, _) = db
|
||||
.get_channel_participant_details(channel_1_1, "", 100, user_1)
|
||||
let mut members = db
|
||||
.get_channel_participant_details(channel_1_1, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -231,8 +231,8 @@ async fn test_channel_invites(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (members, _) = db
|
||||
.get_channel_participant_details(channel_1_3, "", 100, user_1)
|
||||
let members = db
|
||||
.get_channel_participant_details(channel_1_3, user_1)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
@@ -243,16 +243,16 @@ async fn test_channel_invites(db: &Arc<Database>) {
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
role: proto::ChannelRole::Admin.into(),
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_3.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
role: proto::ChannelRole::Admin.into(),
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_2.to_proto(),
|
||||
kind: proto::channel_member::Kind::Member.into(),
|
||||
role: proto::ChannelRole::Member.into(),
|
||||
},
|
||||
proto::ChannelMember {
|
||||
user_id: user_3.to_proto(),
|
||||
kind: proto::channel_member::Kind::Invitee.into(),
|
||||
role: proto::ChannelRole::Admin.into(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
@@ -482,8 +482,8 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (mut members, _) = db
|
||||
.get_channel_participant_details(public_channel_id, "", 100, admin)
|
||||
let mut members = db
|
||||
.get_channel_participant_details(public_channel_id, admin)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -557,8 +557,8 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
|
||||
.await
|
||||
.is_err());
|
||||
|
||||
let (mut members, _) = db
|
||||
.get_channel_participant_details(public_channel_id, "", 100, admin)
|
||||
let mut members = db
|
||||
.get_channel_participant_details(public_channel_id, admin)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -594,8 +594,8 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
|
||||
.unwrap();
|
||||
|
||||
// currently people invited to parent channels are not shown here
|
||||
let (mut members, _) = db
|
||||
.get_channel_participant_details(public_channel_id, "", 100, admin)
|
||||
let mut members = db
|
||||
.get_channel_participant_details(public_channel_id, admin)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -663,8 +663,8 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (mut members, _) = db
|
||||
.get_channel_participant_details(public_channel_id, "", 100, admin)
|
||||
let mut members = db
|
||||
.get_channel_participant_details(public_channel_id, admin)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ use futures::{
|
||||
stream::FuturesUnordered,
|
||||
FutureExt, SinkExt, StreamExt, TryStreamExt,
|
||||
};
|
||||
use http::IsahcHttpClient;
|
||||
use prometheus::{register_int_gauge, IntGauge};
|
||||
use rpc::{
|
||||
proto::{
|
||||
@@ -74,6 +73,7 @@ use tracing::{
|
||||
field::{self},
|
||||
info_span, instrument, Instrument,
|
||||
};
|
||||
use util::http::IsahcHttpClient;
|
||||
|
||||
use self::connection_pool::VersionedMessage;
|
||||
|
||||
@@ -2032,34 +2032,23 @@ async fn unshare_project_internal(
|
||||
user_id: Option<UserId>,
|
||||
session: &Session,
|
||||
) -> Result<()> {
|
||||
let delete = {
|
||||
let room_guard = session
|
||||
.db()
|
||||
.await
|
||||
.unshare_project(project_id, connection_id, user_id)
|
||||
.await?;
|
||||
let (room, guest_connection_ids) = &*session
|
||||
.db()
|
||||
.await
|
||||
.unshare_project(project_id, connection_id, user_id)
|
||||
.await?;
|
||||
|
||||
let (delete, room, guest_connection_ids) = &*room_guard;
|
||||
|
||||
let message = proto::UnshareProject {
|
||||
project_id: project_id.to_proto(),
|
||||
};
|
||||
|
||||
broadcast(
|
||||
Some(connection_id),
|
||||
guest_connection_ids.iter().copied(),
|
||||
|conn_id| session.peer.send(conn_id, message.clone()),
|
||||
);
|
||||
if let Some(room) = room {
|
||||
room_updated(room, &session.peer);
|
||||
}
|
||||
|
||||
*delete
|
||||
let message = proto::UnshareProject {
|
||||
project_id: project_id.to_proto(),
|
||||
};
|
||||
|
||||
if delete {
|
||||
let db = session.db().await;
|
||||
db.delete_project(project_id).await?;
|
||||
broadcast(
|
||||
Some(connection_id),
|
||||
guest_connection_ids.iter().copied(),
|
||||
|conn_id| session.peer.send(conn_id, message.clone()),
|
||||
);
|
||||
if let Some(room) = room {
|
||||
room_updated(room, &session.peer);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -2365,12 +2354,7 @@ async fn create_dev_server(
|
||||
let (dev_server, status) = session
|
||||
.db()
|
||||
.await
|
||||
.create_dev_server(
|
||||
&request.name,
|
||||
request.ssh_connection_string.as_deref(),
|
||||
&hashed_access_token,
|
||||
session.user_id(),
|
||||
)
|
||||
.create_dev_server(&request.name, &hashed_access_token, session.user_id())
|
||||
.await?;
|
||||
|
||||
send_dev_server_projects_update(session.user_id(), status, &session).await;
|
||||
@@ -2378,7 +2362,7 @@ async fn create_dev_server(
|
||||
response.send(proto::CreateDevServerResponse {
|
||||
dev_server_id: dev_server.id.0 as u64,
|
||||
access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
|
||||
name: request.name,
|
||||
name: request.name.clone(),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -2439,12 +2423,7 @@ async fn rename_dev_server(
|
||||
let status = session
|
||||
.db()
|
||||
.await
|
||||
.rename_dev_server(
|
||||
dev_server_id,
|
||||
&request.name,
|
||||
request.ssh_connection_string.as_deref(),
|
||||
session.user_id(),
|
||||
)
|
||||
.rename_dev_server(dev_server_id, &request.name, session.user_id())
|
||||
.await?;
|
||||
|
||||
send_dev_server_projects_update(session.user_id(), status, &session).await;
|
||||
@@ -3693,15 +3672,10 @@ async fn get_channel_members(
|
||||
) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
let limit = if request.limit == 0 {
|
||||
u16::MAX as u64
|
||||
} else {
|
||||
request.limit
|
||||
};
|
||||
let (members, users) = db
|
||||
.get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
|
||||
let members = db
|
||||
.get_channel_participant_details(channel_id, session.user_id())
|
||||
.await?;
|
||||
response.send(proto::GetChannelMembersResponse { members, users })?;
|
||||
response.send(proto::GetChannelMembersResponse { members })?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3901,13 +3875,13 @@ async fn update_channel_buffer(
|
||||
let db = session.db().await;
|
||||
let channel_id = ChannelId::from_proto(request.channel_id);
|
||||
|
||||
let (collaborators, epoch, version) = db
|
||||
let (collaborators, non_collaborators, epoch, version) = db
|
||||
.update_channel_buffer(channel_id, session.user_id(), &request.operations)
|
||||
.await?;
|
||||
|
||||
channel_buffer_updated(
|
||||
session.connection_id,
|
||||
collaborators.clone(),
|
||||
collaborators,
|
||||
&proto::UpdateChannelBuffer {
|
||||
channel_id: channel_id.to_proto(),
|
||||
operations: request.operations,
|
||||
@@ -3917,29 +3891,25 @@ async fn update_channel_buffer(
|
||||
|
||||
let pool = &*session.connection_pool().await;
|
||||
|
||||
let non_collaborators =
|
||||
pool.channel_connection_ids(channel_id)
|
||||
.filter_map(|(connection_id, _)| {
|
||||
if collaborators.contains(&connection_id) {
|
||||
None
|
||||
} else {
|
||||
Some(connection_id)
|
||||
}
|
||||
});
|
||||
|
||||
broadcast(None, non_collaborators, |peer_id| {
|
||||
session.peer.send(
|
||||
peer_id,
|
||||
proto::UpdateChannels {
|
||||
latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
|
||||
channel_id: channel_id.to_proto(),
|
||||
epoch: epoch as u64,
|
||||
version: version.clone(),
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
});
|
||||
broadcast(
|
||||
None,
|
||||
non_collaborators
|
||||
.iter()
|
||||
.flat_map(|user_id| pool.user_connection_ids(*user_id)),
|
||||
|peer_id| {
|
||||
session.peer.send(
|
||||
peer_id,
|
||||
proto::UpdateChannels {
|
||||
latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
|
||||
channel_id: channel_id.to_proto(),
|
||||
epoch: epoch as u64,
|
||||
version: version.clone(),
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -4067,6 +4037,7 @@ async fn send_channel_message(
|
||||
let CreatedChannelMessage {
|
||||
message_id,
|
||||
participant_connection_ids,
|
||||
channel_members,
|
||||
notifications,
|
||||
} = session
|
||||
.db()
|
||||
@@ -4097,7 +4068,7 @@ async fn send_channel_message(
|
||||
};
|
||||
broadcast(
|
||||
Some(session.connection_id),
|
||||
participant_connection_ids.clone(),
|
||||
participant_connection_ids,
|
||||
|connection| {
|
||||
session.peer.send(
|
||||
connection,
|
||||
@@ -4113,27 +4084,24 @@ async fn send_channel_message(
|
||||
})?;
|
||||
|
||||
let pool = &*session.connection_pool().await;
|
||||
let non_participants =
|
||||
pool.channel_connection_ids(channel_id)
|
||||
.filter_map(|(connection_id, _)| {
|
||||
if participant_connection_ids.contains(&connection_id) {
|
||||
None
|
||||
} else {
|
||||
Some(connection_id)
|
||||
}
|
||||
});
|
||||
broadcast(None, non_participants, |peer_id| {
|
||||
session.peer.send(
|
||||
peer_id,
|
||||
proto::UpdateChannels {
|
||||
latest_channel_message_ids: vec![proto::ChannelMessageId {
|
||||
channel_id: channel_id.to_proto(),
|
||||
message_id: message_id.to_proto(),
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
});
|
||||
broadcast(
|
||||
None,
|
||||
channel_members
|
||||
.iter()
|
||||
.flat_map(|user_id| pool.user_connection_ids(*user_id)),
|
||||
|peer_id| {
|
||||
session.peer.send(
|
||||
peer_id,
|
||||
proto::UpdateChannels {
|
||||
latest_channel_message_ids: vec![proto::ChannelMessageId {
|
||||
channel_id: channel_id.to_proto(),
|
||||
message_id: message_id.to_proto(),
|
||||
}],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
},
|
||||
);
|
||||
send_notifications(pool, &session.peer, notifications);
|
||||
|
||||
Ok(())
|
||||
@@ -4365,7 +4333,6 @@ async fn complete_with_open_ai(
|
||||
OPEN_AI_API_URL,
|
||||
&api_key,
|
||||
crate::ai::language_model_request_to_open_ai(request)?,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.context("open_ai::stream_completion request failed within collab")?;
|
||||
@@ -4510,8 +4477,8 @@ async fn complete_with_anthropic(
|
||||
.collect();
|
||||
|
||||
let mut stream = anthropic::stream_completion(
|
||||
session.http_client.as_ref(),
|
||||
anthropic::ANTHROPIC_API_URL,
|
||||
session.http_client.clone(),
|
||||
"https://api.anthropic.com",
|
||||
&api_key,
|
||||
anthropic::Request {
|
||||
model,
|
||||
@@ -4520,7 +4487,6 @@ async fn complete_with_anthropic(
|
||||
system: system_message,
|
||||
max_tokens: 4092,
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ async fn test_core_channels(
|
||||
.channel_store()
|
||||
.update(cx_a, |store, cx| {
|
||||
assert!(!store.has_pending_channel_invite(channel_a_id, client_b.user_id().unwrap()));
|
||||
store.fuzzy_search_members(channel_a_id, "".to_string(), 10, cx)
|
||||
store.get_channel_member_details(channel_a_id, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -20,7 +20,7 @@ async fn test_dev_server(cx: &mut gpui::TestAppContext, cx2: &mut gpui::TestAppC
|
||||
|
||||
let resp = store
|
||||
.update(cx, |store, cx| {
|
||||
store.create_dev_server("server-1".to_string(), None, cx)
|
||||
store.create_dev_server("server-1".to_string(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -167,7 +167,7 @@ async fn create_dev_server_project(
|
||||
|
||||
let resp = store
|
||||
.update(cx, |store, cx| {
|
||||
store.create_dev_server("server-1".to_string(), None, cx)
|
||||
store.create_dev_server("server-1".to_string(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -352,7 +352,6 @@ async fn test_dev_server_rename(
|
||||
store.rename_dev_server(
|
||||
store.dev_servers().first().unwrap().id,
|
||||
"name-edited".to_string(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -522,7 +521,7 @@ async fn test_create_dev_server_project_path_validation(
|
||||
|
||||
let resp = store
|
||||
.update(cx1, |store, cx| {
|
||||
store.create_dev_server("server-2".to_string(), None, cx)
|
||||
store.create_dev_server("server-2".to_string(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user